diff --git a/.github/workflows/basic.yml b/.github/workflows/basic.yml index e0f9221b..970f15d5 100644 --- a/.github/workflows/basic.yml +++ b/.github/workflows/basic.yml @@ -3,17 +3,16 @@ name: Basic tests on: [push, pull_request] jobs: - shellcheck: runs-on: ubuntu-24.04 if: github.event_name != 'push' || github.repository == 'DIRACGrid/DIRAC' timeout-minutes: 10 steps: - - uses: actions/checkout@v3 - - name: Run shellcheck - run: | - find tests/CI -name '*.sh' -print0 | xargs -0 -n1 shellcheck --external-sources; + - uses: actions/checkout@v3 + - name: Run shellcheck + run: | + find tests/CI -name '*.sh' -print0 | xargs -0 -n1 shellcheck --external-sources; pycodestyle: runs-on: ubuntu-24.04 @@ -23,28 +22,26 @@ jobs: strategy: fail-fast: false matrix: - python: - - 2.7.18 + python: - 3.6.15 - 3.9.17 container: python:${{ matrix.python }}-slim steps: - - uses: actions/checkout@v3 - - name: Installing dependencies - run: | - python -m pip install pycodestyle - - name: Run pycodestyle - run: | - if [[ "${REFERENCE_BRANCH}" != "" ]]; then - git remote add upstream https://github.com/DIRACGrid/Pilot.git - git fetch --no-tags upstream "${REFERENCE_BRANCH}" - git branch -vv - git diff -U0 "upstream/${REFERENCE_BRANCH}" | pycodestyle --diff - fi - env: - REFERENCE_BRANCH: ${{ github['base_ref'] || github['head_ref'] }} - + - uses: actions/checkout@v3 + - name: Installing dependencies + run: | + python -m pip install pycodestyle + - name: Run pycodestyle + run: | + if [[ "${REFERENCE_BRANCH}" != "" ]]; then + git remote add upstream https://github.com/DIRACGrid/Pilot.git + git fetch --no-tags upstream "${REFERENCE_BRANCH}" + git branch -vv + git diff -U0 "upstream/${REFERENCE_BRANCH}" | pycodestyle --diff + fi + env: + REFERENCE_BRANCH: ${{ github['base_ref'] || github['head_ref'] }} pytest: runs-on: ubuntu-24.04 @@ -54,24 +51,22 @@ jobs: strategy: fail-fast: false matrix: - python: - - 2.7.18 + python: - 3.6.15 - 3.9.17 container: python:${{ matrix.python }}-slim steps: 
- - uses: actions/checkout@v3 - - name: Installing dependencies - run: | - echo 'deb http://archive.debian.org/debian stretch main' > /etc/apt/sources.list - echo 'deb http://archive.debian.org/debian-security stretch/updates main' >> /etc/apt/sources.list - apt-get update || true - python -m pip install pytest mock - apt install -y voms-clients - - name: Run pytest - run: pytest - + - uses: actions/checkout@v3 + - name: Installing dependencies + run: | + echo 'deb http://archive.debian.org/debian stretch main' > /etc/apt/sources.list + echo 'deb http://archive.debian.org/debian-security stretch/updates main' >> /etc/apt/sources.list + apt-get update || true + python -m pip install pytest mock + apt install -y voms-clients + - name: Run pytest + run: pytest pylint: runs-on: ubuntu-24.04 @@ -81,16 +76,15 @@ jobs: strategy: fail-fast: false matrix: - python: - - 2.7.18 + python: - 3.6.15 - 3.9.17 container: python:${{ matrix.python }}-slim steps: - - uses: actions/checkout@v3 - - name: Installing dependencies - run: | - python -m pip install pylint - - name: Run pylint - run: pylint -E Pilot/ + - uses: actions/checkout@v3 + - name: Installing dependencies + run: | + python -m pip install pylint + - name: Run pylint + run: pylint -E Pilot/ diff --git a/.pylintrc b/.pylintrc index 2bae85fd..7c158f91 100644 --- a/.pylintrc +++ b/.pylintrc @@ -18,7 +18,3 @@ dummy-variables=_ disable= invalid-name, line-too-long, # would be nice to remove this one - consider-using-f-string, # python2/3 support - unspecified-encoding, # python2/3 support - super-with-arguments, # python2/3 support - redefined-builtin, # python2/3 support \ No newline at end of file diff --git a/Pilot/__init__.py b/Pilot/__init__.py deleted file mode 100755 index e69de29b..00000000 diff --git a/Pilot/dirac-pilot.py b/Pilot/dirac-pilot.py index 9c434c97..b24ebd3f 100644 --- a/Pilot/dirac-pilot.py +++ b/Pilot/dirac-pilot.py @@ -19,36 +19,20 @@ But, as said, all the actions are actually configurable. 
""" -from __future__ import absolute_import, division, print_function - import os import sys import time +from io import StringIO + +from pilotTools import ( + Logger, + PilotParams, + RemoteLogger, + getCommand, + pythonPathCheck, +) +from proxyTools import revokePilotToken -############################ -# python 2 -> 3 "hacks" - -try: - from cStringIO import StringIO -except ImportError: - from io import StringIO - -try: - from Pilot.pilotTools import ( - Logger, - PilotParams, - RemoteLogger, - getCommand, - pythonPathCheck, - ) -except ImportError: - from pilotTools import ( - Logger, - PilotParams, - RemoteLogger, - getCommand, - pythonPathCheck, - ) ############################ if __name__ == "__main__": @@ -63,25 +47,34 @@ # print the buffer, so we have a "classic' logger back in sync. sys.stdout.write(bufContent) # now the remote logger. - remote = pilotParams.pilotLogging and (pilotParams.loggerURL is not None) - if remote: + remote = pilotParams.pilotLogging and pilotParams.diracXServer + if remote and pilotParams.jwt != {}: # In a remote logger enabled Dirac version we would have some classic logger content from a wrapper, # which we passed in: receivedContent = "" if not sys.stdin.isatty(): receivedContent = sys.stdin.read() + log = RemoteLogger( - pilotParams.loggerURL, + pilotParams.diracXServer, "Pilot", bufsize=pilotParams.loggerBufsize, pilotUUID=pilotParams.pilotUUID, debugFlag=pilotParams.debugFlag, - wnVO=pilotParams.wnVO, + jwt=pilotParams.jwt, + legacy_logging=pilotParams.isLegacyPilot, + clientID=pilotParams.clientID ) log.info("Remote logger activated") - log.buffer.write(receivedContent) + log.buffer.write(log.format_to_json( + "INFO", + receivedContent, + )) log.buffer.flush() - log.buffer.write(bufContent) + log.buffer.write(log.format_to_json( + "INFO", + bufContent, + )) else: log = Logger("Pilot", debugFlag=pilotParams.debugFlag) @@ -100,11 +93,13 @@ log.debug("PARAMETER [%s]" % ", ".join(map(str, pilotParams.optList))) if 
pilotParams.commandExtensions: - log.info("Requested command extensions: %s" % str(pilotParams.commandExtensions)) + log.info( + "Requested command extensions: %s" % str(pilotParams.commandExtensions) + ) log.info("Executing commands: %s" % str(pilotParams.commands)) - if remote: + if remote and pilotParams.jwt: # It's safer to cancel the timer here. Each command has got its own logger object with a timer cancelled by the # finaliser. No need for a timer in the "else" code segment below. try: @@ -122,5 +117,20 @@ log.error("Command %s could not be instantiated" % commandName) # send the last message and abandon ship. if remote: - log.buffer.flush() + log.buffer.flush(force=True) sys.exit(-1) + + log.info("Pilot tasks finished.") + + if pilotParams.jwt: + if remote: + log.buffer.flush(force=True) + + if not pilotParams.isLegacyPilot: + log.info("Revoking pilot token.") + revokePilotToken( + pilotParams.diracXServer, + pilotParams.pilotUUID, + pilotParams.jwt, + pilotParams.clientID, + ) diff --git a/Pilot/pilotCommands.py b/Pilot/pilotCommands.py index 5032a02d..7f52f15c 100644 --- a/Pilot/pilotCommands.py +++ b/Pilot/pilotCommands.py @@ -17,50 +17,25 @@ def __init__(self, pilotParams): execution. 
""" -from __future__ import absolute_import, division, print_function - import filecmp import os import platform import shutil import socket import stat -import sys import time import traceback -import subprocess from collections import Counter +from http.client import HTTPSConnection +from shlex import quote + +from pilotTools import ( + CommandBase, + getSubmitterInfo, + retrieveUrlTimeout, + safe_listdir, +) -############################ -# python 2 -> 3 "hacks" -try: - # For Python 3.0 and later - from http.client import HTTPSConnection -except ImportError: - # Fall back to Python 2 - from httplib import HTTPSConnection - -try: - from shlex import quote -except ImportError: - from pipes import quote - -try: - from Pilot.pilotTools import ( - CommandBase, - getSubmitterInfo, - retrieveUrlTimeout, - safe_listdir, - sendMessage, - ) -except ImportError: - from pilotTools import ( - CommandBase, - getSubmitterInfo, - retrieveUrlTimeout, - safe_listdir, - sendMessage, - ) ############################ @@ -90,18 +65,23 @@ def wrapper(self): # controlled exit pRef = self.pp.pilotReference self.log.info( - "Flushing the remote logger buffer for pilot on sys.exit(): %s (exit code:%s)" % (pRef, str(exCode)) + "Flushing the remote logger buffer for pilot on sys.exit(): %s (exit code:%s)" + % (pRef, str(exCode)) ) - self.log.buffer.flush() # flush the buffer unconditionally (on sys.exit()). + try: - sendMessage(self.log.url, self.log.pilotUUID, self.log.wnVO, "finaliseLogs", {"retCode": str(exCode)}) + self.log.error(str(exCode)) + self.log.error(traceback.format_exc()) + self.log.buffer.flush(force=True) except Exception as exc: self.log.error("Remote logger couldn't be finalised %s " % str(exc)) + raise except Exception as exc: # unexpected exit: document it and bail out. 
self.log.error(str(exc)) self.log.error(traceback.format_exc()) + self.log.buffer.flush(force=True) raise finally: self.log.buffer.cancelTimer() @@ -132,11 +112,12 @@ def __init__(self, pilotParams): @logFinalizer def execute(self): """Get host and local user info, and other basic checks, e.g. space available""" - self.log.info("Uname = %s" % " ".join(os.uname())) self.log.info("Host Name = %s" % socket.gethostname()) self.log.info("Host FQDN = %s" % socket.getfqdn()) - self.log.info("WorkingDir = %s" % self.pp.workingDir) # this could be different than rootPath + self.log.info( + "WorkingDir = %s" % self.pp.workingDir + ) # this could be different than rootPath fileName = "/etc/redhat-release" if os.path.exists(fileName): @@ -202,7 +183,8 @@ def execute(self): if diskSpace < self.pp.minDiskSpace: self.log.error( - "%s MB < %s MB, not enough local disk space available, exiting" % (diskSpace, self.pp.minDiskSpace) + "%s MB < %s MB, not enough local disk space available, exiting" + % (diskSpace, self.pp.minDiskSpace) ) self.exitWithError(1) @@ -213,19 +195,25 @@ class InstallDIRAC(CommandBase): def __init__(self, pilotParams): """c'tor""" super(InstallDIRAC, self).__init__(pilotParams) - self.pp.rootPath = self.pp.pilotRootPath def _sourceEnvironmentFile(self): """Source the $DIRAC_RC_FILE and save the created environment in self.pp.installEnv""" - retCode, output = self.executeAndGetOutput("bash -c 'source $DIRAC_RC_PATH && env'", self.pp.installEnv) + retCode, output = self.executeAndGetOutput( + "bash -c 'source $DIRAC_RC_PATH && env'", self.pp.installEnv + ) if retCode: - self.log.error("Could not parse the %s file [ERROR %d]" % (self.pp.installEnv["DIRAC_RC_PATH"], retCode)) + self.log.error( + "Could not parse the %s file [ERROR %d]" + % (self.pp.installEnv["DIRAC_RC_PATH"], retCode) + ) self.exitWithError(retCode) for line in output.split("\n"): try: var, value = [vx.strip() for vx in line.split("=", 1)] - if var == "_" or "SSH" in var or "{" in value or "}" 
in value: # Avoiding useless/confusing stuff + if ( + var == "_" or "SSH" in var or "{" in value or "}" in value + ): # Avoiding useless/confusing stuff continue self.pp.installEnv[var] = value except (IndexError, ValueError): @@ -241,7 +229,13 @@ def _saveEnvInFile(self, eFile="environmentSourceDirac"): with open(eFile, "w") as fd: for var, val in self.pp.installEnv.items(): - if var == "_" or var == "X509_USER_PROXY" or "SSH" in var or "{" in val or "}" in val: + if ( + var == "_" + or var == "X509_USER_PROXY" + or "SSH" in var + or "{" in val + or "}" in val + ): continue if " " in val and val[0] != '"': val = '"%s"' % val @@ -252,21 +246,29 @@ def _getPreinstalledEnvScript(self): """Get preinstalled environment script if any""" self.log.debug("self.pp.preinstalledEnv = %s" % self.pp.preinstalledEnv) - self.log.debug("self.pp.preinstalledEnvPrefix = %s" % self.pp.preinstalledEnvPrefix) + self.log.debug( + "self.pp.preinstalledEnvPrefix = %s" % self.pp.preinstalledEnvPrefix + ) self.log.debug("self.pp.CVMFS_locations = %s" % self.pp.CVMFS_locations) preinstalledEnvScript = self.pp.preinstalledEnv if not preinstalledEnvScript and self.pp.preinstalledEnvPrefix: version = self.pp.releaseVersion or "pro" arch = platform.system() + "-" + platform.machine() - preinstalledEnvScript = os.path.join(self.pp.preinstalledEnvPrefix, version, arch, "diracosrc") + preinstalledEnvScript = os.path.join( + self.pp.preinstalledEnvPrefix, version, arch, "diracosrc" + ) if not preinstalledEnvScript and self.pp.CVMFS_locations: for CVMFS_location in self.pp.CVMFS_locations: version = self.pp.releaseVersion or "pro" arch = platform.system() + "-" + platform.machine() preinstalledEnvScript = os.path.join( - CVMFS_location, self.pp.releaseProject.lower() + "dirac", version, arch, "diracosrc" + CVMFS_location, + self.pp.releaseProject.lower() + "dirac", + version, + arch, + "diracosrc", ) if os.path.isfile(preinstalledEnvScript): break @@ -283,7 +285,7 @@ def 
_getPreinstalledEnvScript(self): self.pp.installEnv["DIRAC_RC_PATH"] = preinstalledEnvScript def _localInstallDIRAC(self): - """Install python3 version of DIRAC client""" + """Install DIRAC client""" self.log.info("Installing DIRAC locally") @@ -296,10 +298,7 @@ def _localInstallDIRAC(self): # 1. Get the DIRACOS installer name # curl -O -L https://github.com/DIRACGrid/DIRACOS2/releases/latest/download/DIRACOS-Linux-$(uname -m).sh - try: - machine = os.uname().machine # py3 - except AttributeError: - machine = os.uname()[4] # py2 + machine = os.uname().machine installerName = "DIRACOS-Linux-%s.sh" % machine @@ -316,7 +315,8 @@ def _localInstallDIRAC(self): # 3. Get the installer from GitHub otherwise if not retrieveUrlTimeout( - "https://github.com/DIRACGrid/DIRACOS2/releases/latest/download/%s" % installerName, + "https://github.com/DIRACGrid/DIRACOS2/releases/latest/download/%s" + % installerName, installerName, self.log, ): @@ -326,7 +326,9 @@ def _localInstallDIRAC(self): shutil.rmtree("diracos") # 4. bash DIRACOS-Linux-$(uname -m).sh - retCode, _ = self.executeAndGetOutput("bash %s 2>&1" % installerName, installEnv) + retCode, _ = self.executeAndGetOutput( + "bash %s 2>&1" % installerName, installEnv + ) if retCode: self.log.error("Could not install DIRACOS [ERROR %d]" % retCode) self.exitWithError(retCode) @@ -339,8 +341,16 @@ def _localInstallDIRAC(self): if self.pp.userEnvVariables: userEnvVariables = dict( zip( - [name.split(":::")[0] for name in self.pp.userEnvVariables.replace(" ", "").split(",")], - [value.split(":::")[1] for value in self.pp.userEnvVariables.replace(" ", "").split(",")], + [ + name.split(":::")[0] + for name in self.pp.userEnvVariables.replace(" ", "").split(",") + ], + [ + value.split(":::")[1] + for value in self.pp.userEnvVariables.replace(" ", "").split( + "," + ) + ], ) ) lines = [] @@ -352,7 +362,9 @@ def _localInstallDIRAC(self): diracosrc.write("\n".join(lines)) # 6. 
source diracos/diracosrc - self.pp.installEnv["DIRAC_RC_PATH"] = os.path.join(os.getcwd(), "diracos/diracosrc") + self.pp.installEnv["DIRAC_RC_PATH"] = os.path.join( + os.getcwd(), "diracos/diracosrc" + ) self._sourceEnvironmentFile() self._saveEnvInFile() @@ -379,19 +391,35 @@ def _localInstallDIRAC(self): pipInstalling += "[pilot]" # pipInstalling = "pip install %s%s@%s#egg=%s[pilot]" % (prefix, url, branch, project) - retCode, output = self.executeAndGetOutput(pipInstalling, self.pp.installEnv) + retCode, output = self.executeAndGetOutput( + pipInstalling, self.pp.installEnv + ) if retCode: self.log.error("Could not %s [ERROR %d]" % (pipInstalling, retCode)) self.exitWithError(retCode) else: # pip install DIRAC[pilot]==version ExtensionDIRAC[pilot]==version_ext - if not self.releaseVersion or self.releaseVersion in ["master", "main", "integration"]: - cmd = "%s %sDIRAC[pilot]" % (pipInstallingPrefix, self.pp.releaseProject) + if not self.releaseVersion or self.releaseVersion in [ + "master", + "main", + "integration", + ]: + cmd = "%s %sDIRAC[pilot]" % ( + pipInstallingPrefix, + self.pp.releaseProject, + ) else: - cmd = "%s %sDIRAC[pilot]==%s" % (pipInstallingPrefix, self.pp.releaseProject, self.releaseVersion) + cmd = "%s %sDIRAC[pilot]==%s" % ( + pipInstallingPrefix, + self.pp.releaseProject, + self.releaseVersion, + ) retCode, output = self.executeAndGetOutput(cmd, self.pp.installEnv) if retCode: - self.log.error("Could not pip install %s [ERROR %d]" % (self.releaseVersion, retCode)) + self.log.error( + "Could not pip install %s [ERROR %d]" + % (self.releaseVersion, retCode) + ) self.exitWithError(retCode) @logFinalizer @@ -412,19 +440,30 @@ def execute(self): return # if we are here, we have a preinstalled environment self._sourceEnvironmentFile() - self.log.info("source DIRAC env DONE, for release %s" % self.pp.releaseVersion) + self.log.info( + "source DIRAC env DONE, for release %s" % self.pp.releaseVersion + ) # environment variables to add? 
if self.pp.userEnvVariables: # User-requested environment variables (comma-separated, name and value separated by ":::") - newEnvVars = dict(name.split(":::", 1) for name in self.pp.userEnvVariables.replace(" ", "").split(",")) - self.log.info("Adding env variable(s) to the environment : %s" % newEnvVars) + newEnvVars = dict( + name.split(":::", 1) + for name in self.pp.userEnvVariables.replace(" ", "").split(",") + ) + self.log.info( + "Adding env variable(s) to the environment : %s" % newEnvVars + ) self.pp.installEnv.update(newEnvVars) except OSError as e: - self.log.error("Exception when trying to source the DIRAC environment: %s" % str(e)) + self.log.error( + "Exception when trying to source the DIRAC environment: %s" % str(e) + ) if "cvmfsOnly" in self.pp.genericOption: self.exitWithError(1) - self.log.warn("Source of the DIRAC environment NOT DONE: starting traditional DIRAC installation") + self.log.warn( + "Source of the DIRAC environment NOT DONE: starting traditional DIRAC installation" + ) self._localInstallDIRAC() finally: @@ -465,7 +504,9 @@ def execute(self): VOs may want to replace/extend the _getBasicsCFG and _getSecurityCFG functions """ - self.pp.flavour, self.pp.pilotReference, self.pp.batchSystemInfo = getSubmitterInfo(self.pp.ceName) + self.pp.flavour, self.pp.pilotReference, self.pp.batchSystemInfo = ( + getSubmitterInfo(self.pp.ceName) + ) if not self.pp.pilotReference: self.pp.pilotReference = self.pp.pilotUUID @@ -478,11 +519,15 @@ def execute(self): if self.pp.localConfigFile: self.cfg.append("-O %s" % self.pp.localConfigFile) # here, only as output # Make sure that this configuration is available in the user job environment - self.pp.installEnv["DIRACSYSCONFIG"] = os.path.realpath(self.pp.localConfigFile) + self.pp.installEnv["DIRACSYSCONFIG"] = os.path.realpath( + self.pp.localConfigFile + ) configureCmd = "%s %s" % (self.pp.configureScript, " ".join(self.cfg)) - retCode, _configureOutData = self.executeAndGetOutput(configureCmd, 
self.pp.installEnv) + retCode, _configureOutData = self.executeAndGetOutput( + configureCmd, self.pp.installEnv + ) if retCode: self.log.error("Could not configure DIRAC basics [ERROR %d]" % retCode) @@ -509,30 +554,44 @@ def _getBasicsCFG(self): if self.pp.configServer: self.cfg.append('-C "%s"' % self.pp.configServer) if self.pp.preferredURLPatterns: - self.cfg.append("-o /DIRAC/PreferredURLPatterns=%s" % quote(",".join(self.pp.preferredURLPatterns))) + self.cfg.append( + "-o /DIRAC/PreferredURLPatterns=%s" + % quote(",".join(self.pp.preferredURLPatterns)) + ) if self.pp.releaseProject: self.cfg.append('-e "%s"' % self.pp.releaseProject) self.cfg.append("-o /LocalSite/ReleaseProject=%s" % self.pp.releaseProject) if self.pp.gateway: self.cfg.append('-W "%s"' % self.pp.gateway) if self.pp.userGroup: - self.cfg.append('-o /AgentJobRequirements/OwnerGroup="%s"' % self.pp.userGroup) + self.cfg.append( + '-o /AgentJobRequirements/OwnerGroup="%s"' % self.pp.userGroup + ) if self.pp.userDN: self.cfg.append('-o /AgentJobRequirements/OwnerDN="%s"' % self.pp.userDN) self.cfg.append("-o /LocalSite/ReleaseVersion=%s" % self.releaseVersion) # add the installation locations - self.cfg.append("-o /LocalSite/CVMFS_locations=%s" % ",".join(self.pp.CVMFS_locations)) + self.cfg.append( + "-o /LocalSite/CVMFS_locations=%s" % ",".join(self.pp.CVMFS_locations) + ) if self.pp.wnVO: - self.cfg.append('-o "/Resources/Computing/CEDefaults/VirtualOrganization=%s"' % self.pp.wnVO) + self.cfg.append( + '-o "/Resources/Computing/CEDefaults/VirtualOrganization=%s"' + % self.pp.wnVO + ) def _getSecurityCFG(self): """Sets security-related env variables, if needed""" # Need to know host cert and key location in case they are needed if self.pp.useServerCertificate: self.cfg.append("--UseServerCertificate") - self.cfg.append("-o /DIRAC/Security/CertFile=%s/hostcert.pem" % self.pp.certsLocation) - self.cfg.append("-o /DIRAC/Security/KeyFile=%s/hostkey.pem" % self.pp.certsLocation) + self.cfg.append( 
+ "-o /DIRAC/Security/CertFile=%s/hostcert.pem" % self.pp.certsLocation + ) + self.cfg.append( + "-o /DIRAC/Security/KeyFile=%s/hostkey.pem" % self.pp.certsLocation + ) # If DIRAC (or its extension) is installed in CVMFS do not download VOMS and CAs if self.pp.preinstalledEnv: @@ -552,7 +611,22 @@ def __init__(self, pilotParams): @logFinalizer def execute(self): - """Calls dirac-admin-add-pilot""" + """Calls dirac-admin-add-pilot + + Deprecated in DIRAC V8, new mechanism in V9 and DiracX.""" + + if self.pp.jwt: + if not self.pp.isLegacyPilot: + self.log.warn( + "Skipping module, normally it is already done via DiracX secret-exchange." + ) + return + + # If we're here, this is a legacy pilot with a DiracX token embedded in it. + # TODO: See if we do a dirac-admin-add-pilot in DiracX for legacy pilots + else: + # If we're here, this is a DIRAC only pilot without diracX token embedded in it. + pass if not self.pp.pilotReference: self.log.warn("Skipping module, no pilot reference found") @@ -572,7 +646,9 @@ def execute(self): ) retCode, _ = self.executeAndGetOutput(checkCmd, self.pp.installEnv) if retCode: - self.log.error("Could not get execute dirac-admin-add-pilot [ERROR %d]" % retCode) + self.log.error( + "Could not get execute dirac-admin-add-pilot [ERROR %d]" % retCode + ) class CheckCECapabilities(CommandBase): @@ -627,12 +703,18 @@ def execute(self): self.pp.queueParameters = resourceDict for queueParamName, queueParamValue in self.pp.queueParameters.items(): if isinstance(queueParamValue, list): # for the tags - queueParamValue = ",".join([str(qpv).strip() for qpv in queueParamValue]) - self.cfg.append("-o /LocalSite/%s=%s" % (queueParamName, quote(queueParamValue))) + queueParamValue = ",".join( + [str(qpv).strip() for qpv in queueParamValue] + ) + self.cfg.append( + "-o /LocalSite/%s=%s" % (queueParamName, quote(queueParamValue)) + ) if self.cfg: if self.pp.localConfigFile: - self.cfg.append("-O %s" % self.pp.localConfigFile) # this file is as output + 
self.cfg.append( + "-O %s" % self.pp.localConfigFile + ) # this file is as output self.cfg.append("-FDMH") @@ -640,13 +722,18 @@ def execute(self): self.cfg.append("-ddd") configureCmd = "%s %s" % (self.pp.configureScript, " ".join(self.cfg)) - retCode, _configureOutData = self.executeAndGetOutput(configureCmd, self.pp.installEnv) + retCode, _configureOutData = self.executeAndGetOutput( + configureCmd, self.pp.installEnv + ) if retCode: self.log.error("Could not configure DIRAC [ERROR %d]" % retCode) self.exitWithError(retCode) else: - self.log.debug("No CE parameters (tags) defined for %s/%s" % (self.pp.ceName, self.pp.queueName)) + self.log.debug( + "No CE parameters (tags) defined for %s/%s" + % (self.pp.ceName, self.pp.queueName) + ) class CheckWNCapabilities(CommandBase): @@ -698,12 +785,17 @@ def execute(self): self.pp.pilotProcessors = numberOfProcessorsOnWN self.log.info("pilotProcessors = %d" % self.pp.pilotProcessors) - self.cfg.append('-o "/Resources/Computing/CEDefaults/NumberOfProcessors=%d"' % self.pp.pilotProcessors) + self.cfg.append( + '-o "/Resources/Computing/CEDefaults/NumberOfProcessors=%d"' + % self.pp.pilotProcessors + ) maxRAM = self.pp.queueParameters.get("MaxRAM", maxRAM) if maxRAM: try: - self.cfg.append('-o "/Resources/Computing/CEDefaults/MaxRAM=%d"' % int(maxRAM)) + self.cfg.append( + '-o "/Resources/Computing/CEDefaults/MaxRAM=%d"' % int(maxRAM) + ) except ValueError: self.log.warn("MaxRAM is not an integer, will not fill it") else: @@ -711,17 +803,24 @@ def execute(self): if numberOfGPUs: self.log.info("numberOfGPUs = %d" % int(numberOfGPUs)) - self.cfg.append('-o "/Resources/Computing/CEDefaults/NumberOfGPUs=%d"' % int(numberOfGPUs)) + self.cfg.append( + '-o "/Resources/Computing/CEDefaults/NumberOfGPUs=%d"' + % int(numberOfGPUs) + ) # Add normal and required tags to the configuration self.pp.tags = list(set(self.pp.tags)) if self.pp.tags: - self.cfg.append('-o "/Resources/Computing/CEDefaults/Tag=%s"' % ",".join((str(x) for x in 
self.pp.tags))) + self.cfg.append( + '-o "/Resources/Computing/CEDefaults/Tag=%s"' + % ",".join((str(x) for x in self.pp.tags)) + ) self.pp.reqtags = list(set(self.pp.reqtags)) if self.pp.reqtags: self.cfg.append( - '-o "/Resources/Computing/CEDefaults/RequiredTag=%s"' % ",".join((str(x) for x in self.pp.reqtags)) + '-o "/Resources/Computing/CEDefaults/RequiredTag=%s"' + % ",".join((str(x) for x in self.pp.reqtags)) ) if self.pp.useServerCertificate: @@ -738,7 +837,9 @@ def execute(self): self.cfg.append("-FDMH") configureCmd = "%s %s" % (self.pp.configureScript, " ".join(self.cfg)) - retCode, _configureOutData = self.executeAndGetOutput(configureCmd, self.pp.installEnv) + retCode, _configureOutData = self.executeAndGetOutput( + configureCmd, self.pp.installEnv + ) if retCode: self.log.error("Could not configure DIRAC [ERROR %d]" % retCode) self.exitWithError(retCode) @@ -762,17 +863,31 @@ def execute(self): # Add batch system details to the configuration # Can be used by the pilot/job later on, to interact with the batch system - self.cfg.append("-o /LocalSite/BatchSystemInfo/Type=%s" % self.pp.batchSystemInfo.get("Type", "Unknown")) - self.cfg.append("-o /LocalSite/BatchSystemInfo/JobID=%s" % self.pp.batchSystemInfo.get("JobID", "Unknown")) + self.cfg.append( + "-o /LocalSite/BatchSystemInfo/Type=%s" + % self.pp.batchSystemInfo.get("Type", "Unknown") + ) + self.cfg.append( + "-o /LocalSite/BatchSystemInfo/JobID=%s" + % self.pp.batchSystemInfo.get("JobID", "Unknown") + ) batchSystemParams = self.pp.batchSystemInfo.get("Parameters", {}) - self.cfg.append("-o /LocalSite/BatchSystemInfo/Parameters/Queue=%s" % batchSystemParams.get("Queue", "Unknown")) self.cfg.append( - "-o /LocalSite/BatchSystemInfo/Parameters/BinaryPath=%s" % batchSystemParams.get("BinaryPath", "Unknown") + "-o /LocalSite/BatchSystemInfo/Parameters/Queue=%s" + % batchSystemParams.get("Queue", "Unknown") + ) + self.cfg.append( + "-o /LocalSite/BatchSystemInfo/Parameters/BinaryPath=%s" + % 
batchSystemParams.get("BinaryPath", "Unknown") ) - self.cfg.append("-o /LocalSite/BatchSystemInfo/Parameters/Host=%s" % batchSystemParams.get("Host", "Unknown")) self.cfg.append( - "-o /LocalSite/BatchSystemInfo/Parameters/InfoPath=%s" % batchSystemParams.get("InfoPath", "Unknown") + "-o /LocalSite/BatchSystemInfo/Parameters/Host=%s" + % batchSystemParams.get("Host", "Unknown") + ) + self.cfg.append( + "-o /LocalSite/BatchSystemInfo/Parameters/InfoPath=%s" + % batchSystemParams.get("InfoPath", "Unknown") ) self.cfg.append('-n "%s"' % self.pp.site) @@ -793,8 +908,12 @@ def execute(self): if self.pp.useServerCertificate: self.cfg.append("--UseServerCertificate") - self.cfg.append("-o /DIRAC/Security/CertFile=%s/hostcert.pem" % self.pp.certsLocation) - self.cfg.append("-o /DIRAC/Security/KeyFile=%s/hostkey.pem" % self.pp.certsLocation) + self.cfg.append( + "-o /DIRAC/Security/CertFile=%s/hostcert.pem" % self.pp.certsLocation + ) + self.cfg.append( + "-o /DIRAC/Security/KeyFile=%s/hostkey.pem" % self.pp.certsLocation + ) # these are needed as this is not the first time we call dirac-configure self.cfg.append("-FDMH") @@ -807,7 +926,9 @@ def execute(self): configureCmd = "%s %s" % (self.pp.configureScript, " ".join(self.cfg)) - retCode, _configureOutData = self.executeAndGetOutput(configureCmd, self.pp.installEnv) + retCode, _configureOutData = self.executeAndGetOutput( + configureCmd, self.pp.installEnv + ) if retCode: self.log.error("Could not configure DIRAC [ERROR %d]" % retCode) @@ -840,13 +961,22 @@ def execute(self): architectureCmd = "%s %s -ddd" % (archScript, " ".join(cfg)) if self.pp.architectureScript.split(" ")[0] == "dirac-apptainer-exec": - architectureCmd = "dirac-apptainer-exec '%s' %s" % (architectureCmd, " ".join(cfg)) + architectureCmd = "dirac-apptainer-exec '%s' %s" % ( + architectureCmd, + " ".join(cfg), + ) - retCode, localArchitecture = self.executeAndGetOutput(architectureCmd, self.pp.installEnv) + retCode, localArchitecture = 
self.executeAndGetOutput( + architectureCmd, self.pp.installEnv + ) if retCode: - self.log.error("There was an error getting the platform [ERROR %d]" % retCode) + self.log.error( + "There was an error getting the platform [ERROR %d]" % retCode + ) self.exitWithError(retCode) - self.log.info("Architecture determined: %s" % localArchitecture.strip().split("\n")[-1]) + self.log.info( + "Architecture determined: %s" % localArchitecture.strip().split("\n")[-1] + ) # standard options cfg = ["-FDMH"] # force update, skip CA checks, skip CA download, skip VOMS @@ -867,7 +997,9 @@ def execute(self): cfg.append("-o /LocalSite/Platform=%s" % platform.machine()) configureCmd = "%s %s" % (self.pp.configureScript, " ".join(cfg)) - retCode, _configureOutData = self.executeAndGetOutput(configureCmd, self.pp.installEnv) + retCode, _configureOutData = self.executeAndGetOutput( + configureCmd, self.pp.installEnv + ) if retCode: self.log.error("Configuration error [ERROR %d]" % retCode) self.exitWithError(retCode) @@ -925,7 +1057,9 @@ def execute(self): cfg.append("-o /LocalSite/Platform=%s" % platform.machine()) configureCmd = "%s %s" % (self.pp.configureScript, " ".join(cfg)) - retCode, _configureOutData = self.executeAndGetOutput(configureCmd, self.pp.installEnv) + retCode, _configureOutData = self.executeAndGetOutput( + configureCmd, self.pp.installEnv + ) if retCode: self.log.error("Configuration error [ERROR %d]" % retCode) self.exitWithError(retCode) @@ -948,7 +1082,11 @@ def execute(self): if self.pp.useServerCertificate: configFileArg = "-o /DIRAC/Security/UseServerCertificate=yes" if self.pp.localConfigFile: - configFileArg = "%s -R %s --cfg %s" % (configFileArg, self.pp.localConfigFile, self.pp.localConfigFile) + configFileArg = "%s -R %s --cfg %s" % ( + configFileArg, + self.pp.localConfigFile, + self.pp.localConfigFile, + ) retCode, cpuNormalizationFactorOutput = self.executeAndGetOutput( "dirac-wms-cpu-normalization -U %s -d" % configFileArg, self.pp.installEnv ) @@ 
-978,7 +1116,9 @@ def execute(self): ) if retCode: - self.log.error("Failed to determine cpu time left in the queue [ERROR %d]" % retCode) + self.log.error( + "Failed to determine cpu time left in the queue [ERROR %d]" % retCode + ) self.exitWithError(retCode) for line in cpuTimeOutput.split("\n"): @@ -991,7 +1131,10 @@ def execute(self): try: # determining the CPU time left (in HS06s) self.pp.jobCPUReq = float(cpuTime) * float(cpuNormalizationFactor) - self.log.info("Queue length (which is also set as CPUTimeLeft) is %f" % self.pp.jobCPUReq) + self.log.info( + "Queue length (which is also set as CPUTimeLeft) is %f" + % self.pp.jobCPUReq + ) except ValueError: self.log.error("Pilot command output does not have the correct format") self.exitWithError(1) @@ -1002,12 +1145,18 @@ def execute(self): if self.pp.localConfigFile: cfg.append("-O %s" % self.pp.localConfigFile) # our target file for pilots cfg.extend(["--cfg", self.pp.localConfigFile]) # this file is also input - cfg.append("-o /LocalSite/CPUTimeLeft=%s" % str(int(self.pp.jobCPUReq))) # the only real option + cfg.append( + "-o /LocalSite/CPUTimeLeft=%s" % str(int(self.pp.jobCPUReq)) + ) # the only real option configureCmd = "%s %s" % (self.pp.configureScript, " ".join(cfg)) - retCode, _configureOutData = self.executeAndGetOutput(configureCmd, self.pp.installEnv) + retCode, _configureOutData = self.executeAndGetOutput( + configureCmd, self.pp.installEnv + ) if retCode: - self.log.error("Failed to update CFG file for CPUTimeLeft [ERROR %d]" % retCode) + self.log.error( + "Failed to update CFG file for CPUTimeLeft [ERROR %d]" % retCode + ) self.exitWithError(retCode) @@ -1038,7 +1187,8 @@ def __setInnerCEOpts(self): "-o MaxCycles=5000", "-o PollingTime=%s" % min(20, self.pp.pollingTime), "-o StopOnApplicationFailure=False", - "-o StopAfterFailedMatches=%s" % max(self.pp.pilotProcessors, self.pp.stopAfterFailedMatches), + "-o StopAfterFailedMatches=%s" + % max(self.pp.pilotProcessors, 
self.pp.stopAfterFailedMatches), "-o FillingModeFlag=True", ] else: @@ -1071,7 +1221,9 @@ def __setInnerCEOpts(self): # The file pilot.cfg has to be created previously by ConfigureDIRAC if self.pp.localConfigFile: - self.innerCEOpts.append(" -o /AgentJobRequirements/ExtraOptions=%s" % self.pp.localConfigFile) + self.innerCEOpts.append( + " -o /AgentJobRequirements/ExtraOptions=%s" % self.pp.localConfigFile + ) self.innerCEOpts.extend(["--cfg", self.pp.localConfigFile]) def __startJobAgent(self): @@ -1084,13 +1236,21 @@ def __startJobAgent(self): extraCFG = [] for i in os.listdir(self.pp.rootPath): cfg = os.path.join(self.pp.rootPath, i) - if os.path.isfile(cfg) and cfg.endswith(".cfg") and not filecmp.cmp(self.pp.localConfigFile, cfg): + if ( + os.path.isfile(cfg) + and cfg.endswith(".cfg") + and not filecmp.cmp(self.pp.localConfigFile, cfg) + ): extraCFG.extend(["--cfg", cfg]) if self.pp.executeCmd: # Execute user command self.log.info("Executing user defined command: %s" % self.pp.executeCmd) - self.exitWithError(int(os.system("source diracos/diracosrc; %s" % self.pp.executeCmd) / 256)) + self.exitWithError( + int( + os.system("source diracos/diracosrc; %s" % self.pp.executeCmd) / 256 + ) + ) self.log.info("Starting JobAgent") os.environ["PYTHONUNBUFFERED"] = "yes" @@ -1117,8 +1277,6 @@ def execute(self): self.__setInnerCEOpts() self.__startJobAgent() - sys.exit(0) - class NagiosProbes(CommandBase): """Run one or more Nagios probe scripts that follow the Nagios Plugin API: @@ -1144,21 +1302,31 @@ def _setNagiosOptions(self): try: self.nagiosProbes = [ - str(pv).strip() for pv in self.pp.pilotJSON["Setups"][self.pp.setup]["NagiosProbes"].split(",") + str(pv).strip() + for pv in self.pp.pilotJSON["Setups"][self.pp.setup][ + "NagiosProbes" + ].split(",") ] except KeyError: try: self.nagiosProbes = [ - str(pv).strip() for pv in self.pp.pilotJSON["Setups"]["Defaults"]["NagiosProbes"].split(",") + str(pv).strip() + for pv in self.pp.pilotJSON["Setups"]["Defaults"][ + 
"NagiosProbes" + ].split(",") ] except KeyError: pass try: - self.nagiosPutURL = str(self.pp.pilotJSON["Setups"][self.pp.setup]["NagiosPutURL"]) + self.nagiosPutURL = str( + self.pp.pilotJSON["Setups"][self.pp.setup]["NagiosPutURL"] + ) except KeyError: try: - self.nagiosPutURL = str(self.pp.pilotJSON["Setups"]["Defaults"]["NagiosPutURL"]) + self.nagiosPutURL = str( + self.pp.pilotJSON["Setups"]["Defaults"]["NagiosPutURL"] + ) except KeyError: pass @@ -1191,7 +1359,9 @@ def _runNagiosProbes(self): retStatus = "warning" else: # retCode could be 2 (error) or 3 (unknown) or something we haven't thought of - self.log.error("Return code = %d: %s" % (retCode, str(output).split("\n", 1)[0])) + self.log.error( + "Return code = %d: %s" % (retCode, str(output).split("\n", 1)[0]) + ) retStatus = "error" # TODO: Do something with the retStatus (for example: log it?) @@ -1202,9 +1372,18 @@ def _runNagiosProbes(self): if self.nagiosPutURL: # Alternate logging of results to HTTPS PUT service too hostPort = self.nagiosPutURL.split("/")[2] - path = "/" + "/".join(self.nagiosPutURL.split("/")[3:]) + self.pp.ceName + "/" + probeCmd + path = ( + "/" + + "/".join(self.nagiosPutURL.split("/")[3:]) + + self.pp.ceName + + "/" + + probeCmd + ) - self.log.info("Putting %s Nagios output to https://%s%s" % (probeCmd, hostPort, path)) + self.log.info( + "Putting %s Nagios output to https://%s%s" + % (probeCmd, hostPort, path) + ) try: connection = HTTPSConnection( @@ -1214,21 +1393,29 @@ def _runNagiosProbes(self): cert_file=os.environ["X509_USER_PROXY"], ) - connection.request("PUT", path, str(retCode) + " " + str(int(time.time())) + "\n" + output) + connection.request( + "PUT", + path, + str(retCode) + " " + str(int(time.time())) + "\n" + output, + ) except Exception as e: - self.log.error("PUT of %s Nagios output fails with %s" % (probeCmd, str(e))) + self.log.error( + "PUT of %s Nagios output fails with %s" % (probeCmd, str(e)) + ) else: result = connection.getresponse() if 
int(result.status / 100) == 2: self.log.info( - "PUT of %s Nagios output succeeds with %d %s" % (probeCmd, result.status, result.reason) + "PUT of %s Nagios output succeeds with %d %s" + % (probeCmd, result.status, result.reason) ) else: self.log.error( - "PUT of %s Nagios output fails with %d %s" % (probeCmd, result.status, result.reason) + "PUT of %s Nagios output fails with %d %s" + % (probeCmd, result.status, result.reason) ) @logFinalizer diff --git a/Pilot/pilotTools.py b/Pilot/pilotTools.py index 11345f46..31b6f008 100644 --- a/Pilot/pilotTools.py +++ b/Pilot/pilotTools.py @@ -1,94 +1,40 @@ """A set of common tools to be used in pilot commands""" -from __future__ import absolute_import, division, print_function - import fcntl import getopt +import importlib.util import json import os import re import select import signal -import ssl import subprocess import sys import threading import warnings from datetime import datetime from functools import partial, wraps -from threading import RLock - -############################ -# python 2 -> 3 "hacks" -try: - from urllib.error import HTTPError, URLError - from urllib.parse import urlencode - from urllib.request import urlopen -except ImportError: - from urllib import urlencode - - from urllib2 import HTTPError, URLError, urlopen - -try: - import importlib.util - from importlib import import_module +from importlib import import_module +from threading import RLock, Timer +from urllib.error import HTTPError, URLError +from urllib.request import urlopen - def load_module_from_path(module_name, path_to_module): - spec = importlib.util.spec_from_file_location(module_name, path_to_module) # pylint: disable=no-member - module = importlib.util.module_from_spec(spec) # pylint: disable=no-member - spec.loader.exec_module(module) - return module -except ImportError: +def load_module_from_path(module_name, path_to_module): + spec = importlib.util.spec_from_file_location(module_name, path_to_module) # pylint: disable=no-member + 
module = importlib.util.module_from_spec(spec) # pylint: disable=no-member + spec.loader.exec_module(module) + return module - def import_module(module): - import imp - impData = imp.find_module(module) - return imp.load_module(module, *impData) - - def load_module_from_path(module_name, path_to_module): - import imp - - fp, pathname, description = imp.find_module(module_name, [path_to_module]) - try: - return imp.load_module(module_name, fp, pathname, description) - finally: - if fp: - fp.close() - - -try: - from cStringIO import StringIO -except ImportError: - from io import StringIO - -try: - basestring # pylint: disable=used-before-assignment -except NameError: - basestring = str - -try: - from Pilot.proxyTools import getVO -except ImportError: - from proxyTools import getVO - -try: - FileNotFoundError # pylint: disable=used-before-assignment - # because of https://github.com/PyCQA/pylint/issues/6748 -except NameError: - FileNotFoundError = OSError - -try: - IsADirectoryError # pylint: disable=used-before-assignment -except NameError: - IsADirectoryError = IOError - -# Timer 2.7 and < 3.3 versions issue where Timer is a function -if sys.version_info.major == 2 or sys.version_info.major == 3 and sys.version_info.minor < 3: - from threading import _Timer as Timer # pylint: disable=no-name-in-module -else: - from threading import Timer +from proxyTools import ( + BaseRequest, + extract_diracx_payload, + getVO, + refreshUserToken, + refreshPilotToken, + TokenBasedRequest, +) # Utilities functions @@ -98,7 +44,9 @@ def parseVersion(releaseVersion): :param str releaseVersion: The software version to use """ - VERSION_PATTERN = re.compile(r"^(?:v)?(\d+)[r\.](\d+)(?:[p\.](\d+))?(?:(?:-pre|a)?(\d+))?$") + VERSION_PATTERN = re.compile( + r"^(?:v)?(\d+)[r\.](\d+)(?:[p\.](\d+))?(?:(?:-pre|a)?(\d+))?$" + ) match = VERSION_PATTERN.match(releaseVersion) # If the regex fails just return the original version @@ -188,11 +136,17 @@ def retrieveUrlTimeout(url, fileName, log, 
timeout=0): signal.alarm(0) return False except URLError: - log.error('Timeout after %s seconds on transfer request for "%s"' % (str(timeout), url)) + log.error( + 'Timeout after %s seconds on transfer request for "%s"' + % (str(timeout), url) + ) return False except Exception as x: if x == "Timeout": - log.error('Timeout after %s seconds on transfer request for "%s"' % (str(timeout), url)) + log.error( + 'Timeout after %s seconds on transfer request for "%s"' + % (str(timeout), url) + ) if timeout: signal.alarm(0) raise x @@ -274,7 +228,9 @@ def getSubmitterInfo(ceName): if "SGE_TASK_ID" in os.environ: batchSystemType = "SGE" batchSystemJobID = os.environ["JOB_ID"] - batchSystemParameters["BinaryPath"] = os.environ.get("SGE_BINARY_PATH", "Unknown") + batchSystemParameters["BinaryPath"] = os.environ.get( + "SGE_BINARY_PATH", "Unknown" + ) batchSystemParameters["Queue"] = os.environ.get("QUEUE", "Unknown") flavour = "SSH%s" % batchSystemType @@ -307,7 +263,12 @@ def getSubmitterInfo(ceName): batchSystemParameters["InfoPath"] = os.environ["_CONDOR_JOB_AD"] flavour = "SSH%s" % batchSystemType - pilotReference = "sshcondor://" + ceName + "/" + os.environ.get("CONDOR_JOBID", pilotReference) + pilotReference = ( + "sshcondor://" + + ceName + + "/" + + os.environ.get("CONDOR_JOBID", pilotReference) + ) # # Local/SSH @@ -325,7 +286,12 @@ def getSubmitterInfo(ceName): if "SSHBATCH_JOBID" in os.environ and "SSH_NODE_HOST" in os.environ: flavour = "SSHBATCH" pilotReference = ( - "sshbatchhost://" + ceName + "/" + os.environ["SSH_NODE_HOST"] + "/" + os.environ["SSHBATCH_JOBID"] + "sshbatchhost://" + + ceName + + "/" + + os.environ["SSH_NODE_HOST"] + + "/" + + os.environ["SSHBATCH_JOBID"] ) # # CEs @@ -348,7 +314,11 @@ def getSubmitterInfo(ceName): return ( flavour, pilotReference, - {"Type": batchSystemType, "JobID": batchSystemJobID, "Parameters": batchSystemParameters}, + { + "Type": batchSystemType, + "JobID": batchSystemJobID, + "Parameters": batchSystemParameters, + }, ) 
@@ -358,7 +328,9 @@ def getFlavour(ceName): Please use getSubmitterInfo instead. """ warnings.warn( - "getFlavour() is deprecated. Please use getSubmitterInfo() instead.", category=DeprecationWarning, stacklevel=2 + "getFlavour() is deprecated. Please use getSubmitterInfo() instead.", + category=DeprecationWarning, + stacklevel=2, ) flavour, pilotReference, _ = getSubmitterInfo(ceName) return flavour, pilotReference @@ -387,7 +359,9 @@ def loadModule(self, modName, hideExceptions=False): if rootModule: impName = "%s.%s" % (rootModule, impName) self.log.debug("Trying to load %s" % impName) - module, parentPath = self.__recurseImport(impName, hideExceptions=hideExceptions) + module, parentPath = self.__recurseImport( + impName, hideExceptions=hideExceptions + ) # Error. Something cannot be imported. Return error if module is None: return None, None @@ -399,7 +373,7 @@ def loadModule(self, modName, hideExceptions=False): def __recurseImport(self, modName, parentModule=None, hideExceptions=False): """Internal function to load modules""" - if isinstance(modName, basestring): + if isinstance(modName, str): modName = modName.split(".") try: if parentModule: @@ -409,13 +383,18 @@ def __recurseImport(self, modName, parentModule=None, hideExceptions=False): except ImportError as excp: if str(excp).find("No module named %s" % modName[0]) == 0: return None, None - errMsg = "Can't load %s in %s" % (".".join(modName), parentModule.__path__[0]) + errMsg = "Can't load %s in %s" % ( + ".".join(modName), + parentModule.__path__[0], + ) if not hideExceptions: self.log.exception(errMsg) return None, None if len(modName) == 1: return impModule, parentModule.__path__[0] - return self.__recurseImport(modName[1:], impModule, hideExceptions=hideExceptions) + return self.__recurseImport( + modName[1:], impModule, hideExceptions=hideExceptions + ) def loadObject(self, package, moduleName, command): """Load an object from inside a module""" @@ -483,7 +462,9 @@ def __outputMessage(self, msg, 
level, header): with open(self.out, "a") as outputFile: for _line in str(msg).split("\n"): if header: - outLine = self.messageTemplate.format(level=level, message=_line) + outLine = self.messageTemplate.format( + level=level, message=_line + ) print(outLine) if self.out: outputFile.write(outLine + "\n") @@ -526,7 +507,9 @@ def __init__( pilotUUID="unknown", flushInterval=10, bufsize=1000, - wnVO="unknown", + jwt={}, + legacy_logging=False, + clientID="", ): """ c'tor @@ -536,10 +519,32 @@ def __init__( super(RemoteLogger, self).__init__(name, debugFlag, pilotOutput) self.url = url self.pilotUUID = pilotUUID - self.wnVO = wnVO self.isPilotLoggerOn = isPilotLoggerOn - sendToURL = partial(sendMessage, url, pilotUUID, wnVO, "sendMessage") - self.buffer = FixedSizeBuffer(sendToURL, bufsize=bufsize, autoflush=flushInterval) + sendToURL = partial(sendMessage, url, pilotUUID, legacy_logging, clientID) + self.buffer = FixedSizeBuffer( + sendToURL, bufsize=bufsize, autoflush=flushInterval, jwt=jwt + ) + + def format_to_json(self, level, message): + escaped = json.dumps(message)[1:-1] # remove outer quotes + + # Split on escaped newlines + splitted_message = escaped.split("\\n") + + output = [] + for mess in splitted_message: + if mess: + output.append( + { + "timestamp": datetime.utcnow().strftime( + "%Y-%m-%dT%H:%M:%S.%fZ" + ), + "severity": level, + "message": mess, + "scope": self.name, + } + ) + return output def debug(self, msg, header=True, _sendPilotLog=False): # TODO: Send pilot log remotely? @@ -547,25 +552,25 @@ def debug(self, msg, header=True, _sendPilotLog=False): if ( self.isPilotLoggerOn and self.debugFlag ): # the -d flag activates this debug flag in CommandBase via PilotParams - self.sendMessage(self.messageTemplate.format(level="DEBUG", message=msg)) + self.sendMessage(self.format_to_json(level="DEBUG", message=msg)) def error(self, msg, header=True, _sendPilotLog=False): # TODO: Send pilot log remotely? 
super(RemoteLogger, self).error(msg, header) if self.isPilotLoggerOn: - self.sendMessage(self.messageTemplate.format(level="ERROR", message=msg)) + self.sendMessage(self.format_to_json(level="ERROR", message=msg)) def warn(self, msg, header=True, _sendPilotLog=False): # TODO: Send pilot log remotely? super(RemoteLogger, self).warn(msg, header) if self.isPilotLoggerOn: - self.sendMessage(self.messageTemplate.format(level="WARNING", message=msg)) + self.sendMessage(self.format_to_json(level="WARNING", message=msg)) def info(self, msg, header=True, _sendPilotLog=False): # TODO: Send pilot log remotely? super(RemoteLogger, self).info(msg, header) if self.isPilotLoggerOn: - self.sendMessage(self.messageTemplate.format(level="INFO", message=msg)) + self.sendMessage(self.format_to_json(level="INFO", message=msg)) def sendMessage(self, msg): """ @@ -577,7 +582,7 @@ def sendMessage(self, msg): :rtype: None """ try: - self.buffer.write(msg + "\n") + self.buffer.write(msg) except Exception as err: super(RemoteLogger, self).error("Message not sent") super(RemoteLogger, self).error(str(err)) @@ -604,7 +609,7 @@ class FixedSizeBuffer(object): Once it's full, a message is sent to a remote server and the buffer is renewed. """ - def __init__(self, senderFunc, bufsize=1000, autoflush=10): + def __init__(self, senderFunc, bufsize=250, autoflush=10, jwt={}): """ Constructor. @@ -622,33 +627,33 @@ def __init__(self, senderFunc, bufsize=1000, autoflush=10): self._timer.start() else: self._timer = None - self.output = StringIO() + self.output = [] self.bufsize = bufsize self._nlines = 0 self.senderFunc = senderFunc + self.jwt = jwt + # A fixed buffer used by a remote buffer can be deactivated: + # If there's a 403/401 error, instead of crashing the pilot, + # we will deactivate the log sending, and prefer just running the pilot. + self.activated = True @synchronized - def write(self, text): + def write(self, content_json): """ Write text to a string buffer. 
Newline characters are counted and number of lines in the buffer is increased accordingly. - :param text: text string to write - :type text: str + :param content_json: Json to send, format following format_to_json + :type content_json: list[dict] :return: None :rtype: None """ - # reopen the buffer in a case we had to flush a partially filled buffer - if self.output.closed: - self.output = StringIO() - self.output.write(text) - self._nlines += max(1, text.count("\n")) - self.sendFullBuffer() + if not self.activated: + pass - @synchronized - def getValue(self): - content = self.output.getvalue() - return content + self.output.extend(content_json) + self._nlines += max(1, len(content_json)) + self.sendFullBuffer() @synchronized def sendFullBuffer(self): @@ -659,22 +664,26 @@ def sendFullBuffer(self): if self._nlines >= self.bufsize: self.flush() - self.output = StringIO() + self.output = [] @synchronized - def flush(self): + def flush(self, force=False): """ Flush the buffer and send log records to a remote server. The buffer is closed as well. :return: None :rtype: None """ - if not self.output.closed and self._nlines > 0: - self.output.flush() - buf = self.getValue() - self.senderFunc(buf) + if not self.activated: + pass + + if force or (self.output and self._nlines > 0): + try: + self.senderFunc(self.jwt, self.output) + except Exception as e: + print("Deactivating fixed size buffer due to", str(e)) + self.activated = False self._nlines = 0 - self.output.close() def cancelTimer(self): """ @@ -687,40 +696,48 @@ def cancelTimer(self): self._timer.cancel() -def sendMessage(url, pilotUUID, wnVO, method, rawMessage): +def sendMessage( + diracx_URL, pilotUUID, legacy=False, clientID="", jwt={}, rawMessage=[] +): """ Invoke a remote method on a Tornado server and pass a JSON message to it. 
:param str url: Server URL :param str pilotUUID: pilot unique ID - :param str wnVO: VO name, relevant only if not contained in a proxy :param str method: a method to be invoked :param str rawMessage: a message to be sent, in JSON format + :param dict jwt: JWT for the requests :return: None. """ + caPath = os.getenv("X509_CERT_DIR") - cert = os.getenv("X509_USER_PROXY") - context = ssl.create_default_context() - context.load_verify_locations(capath=caPath) + raw_data = {"pilot_stamp": pilotUUID, "lines": rawMessage} - message = json.dumps((json.dumps(rawMessage), pilotUUID, wnVO)) + if not diracx_URL.endswith("/"): + diracx_URL += "/" - try: - context.load_cert_chain(cert) # this is a proxy - raw_data = {"method": method, "args": message} - except IsADirectoryError: # assuming it'a dir containing cert and key - context.load_cert_chain(os.path.join(cert, "hostcert.pem"), os.path.join(cert, "hostkey.pem")) - raw_data = {"method": method, "args": message, "extraCredentials": '"hosts"'} - - if sys.version_info.major == 3: - data = urlencode(raw_data).encode("utf-8") # encode to bytes ! 
for python3 + if legacy: + endpoint_path = "api/pilots/legacy/message" + refresh_callback = partial( + refreshUserToken, diracx_URL, pilotUUID, jwt, clientID + ) else: - # Python2 - data = urlencode(raw_data) + endpoint_path = "api/pilots/internal/message" + refresh_callback = partial(refreshPilotToken, diracx_URL, pilotUUID, jwt) + + config = TokenBasedRequest( + diracx_URL=diracx_URL, + endpoint_path=endpoint_path, + caPath=caPath, + jwtData=jwt, + pilotUUID=pilotUUID, + ) - res = urlopen(url, data, context=context) - res.close() + # Do the request + _res = config.executeRequest( + raw_data=raw_data, json_output=False, refresh_callback=refresh_callback + ) class CommandBase(object): @@ -740,7 +757,7 @@ def __init__(self, pilotParams): self.debugFlag = pilotParams.debugFlag loggerURL = pilotParams.loggerURL # URL present and the flag is set: - isPilotLoggerOn = pilotParams.pilotLogging and (loggerURL is not None) + isPilotLoggerOn = pilotParams.pilotLogging and pilotParams.diracXServer interval = pilotParams.loggerTimerInterval bufsize = pilotParams.loggerBufsize @@ -749,13 +766,15 @@ def __init__(self, pilotParams): else: # remote logger self.log = RemoteLogger( - loggerURL, + self.pp.diracXServer, self.__class__.__name__, pilotUUID=pilotParams.pilotUUID, debugFlag=self.debugFlag, flushInterval=interval, bufsize=bufsize, - wnVO=pilotParams.wnVO, + jwt=pilotParams.jwt, + legacy_logging=pilotParams.isLegacyPilot, + clientID=pilotParams.clientID, ) self.log.isPilotLoggerOn = isPilotLoggerOn @@ -770,7 +789,12 @@ def executeAndGetOutput(self, cmd, environDict=None): self.log.info("Executing command %s" % cmd) _p = subprocess.Popen( - cmd, shell=True, env=environDict, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=False + cmd, + shell=True, + env=environDict, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + close_fds=False, ) # Use non-blocking I/O on the process pipes @@ -787,26 +811,18 @@ def executeAndGetOutput(self, cmd, environDict=None): if not 
outChunk: continue dataWasRead = True - if sys.version_info.major == 2: - # Ensure outChunk is unicode in Python 2 - if isinstance(outChunk, str): - outChunk = outChunk.decode("utf-8") - # Strip unicode replacement characters - # Ensure correct type conversion in Python 2 - outChunk = str(outChunk.replace("\ufffd", "")) - # Avoid potential str() issues in Py2 - outChunk = unicode(outChunk) # pylint: disable=undefined-variable - else: - outChunk = str(outChunk.replace("\ufffd", "")) # Python 3: Ensure it's a string if stream == _p.stderr: sys.stderr.write(outChunk) sys.stderr.flush() + # TODO: See if wee need also to log here else: sys.stdout.write(outChunk) sys.stdout.flush() if hasattr(self.log, "buffer") and self.log.isPilotLoggerOn: - self.log.buffer.write(outChunk) + self.log.buffer.write( + self.log.format_to_json("COMMAND", outChunk) + ) outData += outChunk # If no data was read on any of the pipes then the process has finished if not dataWasRead: @@ -832,7 +848,7 @@ def exitWithError(self, errorCode): self.log.info("List of child processes of current PID:") retCode, _outData = self.executeAndGetOutput( - "ps --forest -o pid,%%cpu,%%mem,tty,stat,time,cmd -g %d" % os.getpid() + "ps --forest -o pid,%%cpu,%%mem,tty,stat,time,cmd --ppid %d" % os.getpid() ) if retCode: self.log.error("Failed to issue ps [ERROR %d] " % retCode) @@ -852,7 +868,12 @@ def forkAndExecute(self, cmd, logFile, environDict=None): with open(logFile, "a+", 0) as fpLogFile: try: _p = subprocess.Popen( - "%s" % cmd, shell=True, env=environDict, close_fds=False, stdout=fpLogFile, stderr=fpLogFile + "%s" % cmd, + shell=True, + env=environDict, + close_fds=False, + stdout=fpLogFile, + stderr=fpLogFile, ) # return code @@ -909,10 +930,14 @@ def __init__(self): self.setup = "" self.configServer = "" self.preferredURLPatterns = "" + self.diracXServer = "" self.ceName = "" self.ceType = "" self.queueName = "" self.gridCEType = "" + self.pilotSecret = "" + self.clientID = "" + self.jwt = {} # 
maxNumberOfProcessors: the number of # processors allocated to the pilot which the pilot can allocate to one payload # used to set payloadProcessors unless other limits are reached (like the number of processors on the WN) @@ -947,8 +972,9 @@ def __init__(self): self.pilotCFGFile = "pilot.json" self.pilotLogging = False self.loggerURL = None + self.isLegacyPilot = False self.loggerTimerInterval = 0 - self.loggerBufsize = 1000 + self.loggerBufsize = 250 self.pilotUUID = "unknown" self.modules = "" self.userEnvVariables = "" @@ -964,7 +990,9 @@ def __init__(self): # Set number of allocatable processors from MJF if available try: - self.pilotProcessors = int(urlopen(os.path.join(os.environ["JOBFEATURES"], "allocated_cpu")).read()) + self.pilotProcessors = int( + urlopen(os.path.join(os.environ["JOBFEATURES"], "allocated_cpu")).read() + ) except Exception: self.pilotProcessors = 1 @@ -981,7 +1009,11 @@ def __init__(self): ("l:", "project=", "Project to install"), ("n:", "name=", "Set as Site Name"), ("o:", "option=", "Option=value to add"), - ("m:", "maxNumberOfProcessors=", "specify a max number of processors to use by the payload inside a pilot"), + ( + "m:", + "maxNumberOfProcessors=", + "specify a max number of processors to use by the payload inside a pilot", + ), ("", "modules=", "for installing non-released code"), ( "", @@ -997,6 +1029,7 @@ def __init__(self): ("y:", "CEType=", "CE Type (normally InProcess)"), ("z", "pilotLogging", "Activate pilot logging system"), ("C:", "configurationServer=", "Configuration servers to use"), + ("", "diracx_URL=", "DiracX Server URL to use"), ("D:", "disk=", "Require at least MB available"), ("E:", "commandExtensions=", "Python modules with extra commands"), ("F:", "pilotCFGFile=", "Specify pilot CFG file"), @@ -1004,7 +1037,11 @@ def __init__(self): ("K:", "certLocation=", "Specify server certificate location"), ("M:", "MaxCycles=", "Maximum Number of JobAgent cycles to run"), ("", "PollingTime=", "JobAgent execution 
frequency"), - ("", "StopOnApplicationFailure=", "Stop Job Agent when encounter an application failure"), + ( + "", + "StopOnApplicationFailure=", + "Stop Job Agent when encounter an application failure", + ), ("", "StopAfterFailedMatches=", "Stop Job Agent after N failed matches"), ("N:", "Name=", "CE Name"), ("O:", "OwnerDN=", "Pilot OwnerDN (for private pilots)"), @@ -1014,14 +1051,24 @@ def __init__(self): ("R:", "reference=", "Use this pilot reference"), ("S:", "setup=", "DIRAC Setup to use"), ("T:", "CPUTime=", "Requested CPU Time"), - ("W:", "gateway=", "Configure as DIRAC Gateway during installation"), + ( + "W:", + "gateway=", + "Configure as DIRAC Gateway during installation", + ), ("X:", "commands=", "Pilot commands to execute"), ("Z:", "commandOptions=", "Options parsed by command modules"), ("", "pilotUUID=", "pilot UUID"), ("", "preinstalledEnv=", "preinstalled pilot environment script location"), - ("", "preinstalledEnvPrefix=", "preinstalled pilot environment area prefix"), + ( + "", + "preinstalledEnvPrefix=", + "preinstalled pilot environment area prefix", + ), ("", "architectureScript=", "architecture script to use"), ("", "CVMFS_locations=", "comma-separated list of CVMS locations"), + ("", "pilotSecret=", "secret that the pilot uses with DiracX"), + ("", "clientID=", "client id used by DiracX to revoke a token"), ) # Possibly get Setup and JSON URL/filename from command line @@ -1048,6 +1095,74 @@ def __init__(self): self.installEnv["X509_USER_PROXY"] = self.certsLocation os.environ["X509_USER_PROXY"] = self.certsLocation + try: + self.__get_diracx_jwt() + except Exception as e: + self.log.error("Error setting DiracX: %s" % e) + # Remove all settings to prevent using it. + self.diracXServer = None + self.pilotSecret = None + self.loggerURL = None + self.jwt = {} + self.log.error("Won't use DiracX.") + + def __get_diracx_jwt(self): + # Pilot auth: two cases + # 1. Has a secret (DiracX Pilot), exchange for a token + # 2. 
Legacy Pilot, has a proxy with a DiracX section in it (extract the jwt from it) + if self.pilotUUID and self.pilotSecret and self.diracXServer: + self.log.info("Fetching JWT in DiracX (URL: %s)" % self.diracXServer) + + config = BaseRequest( + "%s/api/auth/secret-exchange" % (self.diracXServer), + os.getenv("X509_CERT_DIR"), + self.pilotUUID, + ) + + try: + self.jwt = config.executeRequest( + {"pilot_stamp": self.pilotUUID, "pilot_secret": self.pilotSecret} + ) + except HTTPError as e: + self.log.error("Request failed: %s" % str(e)) + self.log.error("Could not fetch pilot tokens.") + if e.code == 401: + # First test if the error occurred because of "bad pilot_stamp" + # If so, this pilot is in the vacuum case + # So we redo auth, but this time with the right data for vacuum cases + self.log.error("Retrying with vacuum case data...") + self.jwt = config.executeRequest( + { + "pilot_stamp": self.pilotUUID, + "pilot_secret": self.pilotSecret, + "vo": self.wnVO, + "grid_type": self.gridCEType, + "grid_site": self.site, + "status": "Running", + } + ) + else: + raise RuntimeError("Can't be a vacuum case.") + + self.log.info("Fetched the pilot token with the pilot secret.") + self.isLegacyPilot = False + elif self.pilotUUID and self.diracXServer: + # Try to extract a token for proxy + self.log.info("Trying to extract diracx token from proxy.") + + cert = os.getenv("X509_USER_PROXY") + if cert: + with open(cert, "rb") as fp: + self.jwt = extract_diracx_payload(fp.read()) + self.isLegacyPilot = True + self.log.info("Successfully extracted token from proxy.") + else: + raise RuntimeError("Could not locate a proxy via X509_USER_PROXY") + else: + self.log.info( + "PilotUUID, pilotSecret, and diracXServer are needed to support DiracX." 
+ ) + def __setSecurityDir(self, envName, dirLocation): """Set the environment variable of the `envName`, and add it also to the Pilot Parameters @@ -1095,7 +1210,8 @@ def __checkSecurityDir(self, envName, dirName): # If so, just return if envName in os.environ and safe_listdir(os.environ[envName]): self.log.debug( - "%s is set in the host environment as %s, aligning installEnv to it" % (envName, os.environ[envName]) + "%s is set in the host environment as %s, aligning installEnv to it" + % (envName, os.environ[envName]) ) else: # None of the candidates exists, stop the program. @@ -1106,7 +1222,9 @@ def __initCommandLine1(self): """Parses and interpret options on the command line: first pass (essential things)""" self.optList, __args__ = getopt.getopt( - sys.argv[1:], "".join([opt[0] for opt in self.cmdOpts]), [opt[1] for opt in self.cmdOpts] + sys.argv[1:], + "".join([opt[0] for opt in self.cmdOpts]), + [opt[1] for opt in self.cmdOpts], ) self.log.debug("Options list: %s" % self.optList) for o, v in self.optList: @@ -1132,7 +1250,9 @@ def __initCommandLine2(self): """ self.optList, __args__ = getopt.getopt( - sys.argv[1:], "".join([opt[0] for opt in self.cmdOpts]), [opt[1] for opt in self.cmdOpts] + sys.argv[1:], + "".join([opt[0] for opt in self.cmdOpts]), + [opt[1] for opt in self.cmdOpts], ) for o, v in self.optList: if o == "-E" or o == "--commandExtensions": @@ -1141,7 +1261,9 @@ def __initCommandLine2(self): self.commands = v.split(",") elif o == "-Z" or o == "--commandOptions": for i in v.split(","): - self.commandOptions[i.split("=", 1)[0].strip()] = i.split("=", 1)[1].strip() + self.commandOptions[i.split("=", 1)[0].strip()] = i.split("=", 1)[ + 1 + ].strip() elif o == "-e" or o == "--extraPackages": self.extensions = v.split(",") elif o == "-n" or o == "--name": @@ -1152,6 +1274,8 @@ def __initCommandLine2(self): self.keepPythonPath = True elif o in ("-C", "--configurationServer"): self.configServer = v + elif o == "--diracx_URL": + self.diracXServer = 
v elif o in ("-G", "--Group"): self.userGroup = v elif o in ("-x", "--execute"): @@ -1225,6 +1349,10 @@ def __initCommandLine2(self): self.architectureScript = v elif o == "--CVMFS_locations": self.CVMFS_locations = v.split(",") + elif o == "--pilotSecret": + self.pilotSecret = v + elif o == "--clientID": + self.clientID = v def __loadJSON(self): """ @@ -1260,27 +1388,40 @@ def __initJSON2(self): self.pilotLogging = pilotLogging.upper() == "TRUE" self.loggerURL = pilotOptions.get("RemoteLoggerURL") # logger buffer flush interval in seconds. - self.loggerTimerInterval = int(pilotOptions.get("RemoteLoggerTimerInterval", self.loggerTimerInterval)) + self.loggerTimerInterval = int( + pilotOptions.get("RemoteLoggerTimerInterval", self.loggerTimerInterval) + ) # logger buffer size in lines: - self.loggerBufsize = max(1, int(pilotOptions.get("RemoteLoggerBufsize", self.loggerBufsize))) + self.loggerBufsize = max( + 1, int(pilotOptions.get("RemoteLoggerBufsize", self.loggerBufsize)) + ) # logger CE white list loggerCEsWhiteList = pilotOptions.get("RemoteLoggerCEsWhiteList") # restrict remote logging to whitelisted CEs ([] or None => no restriction) self.log.debug("JSON: Remote logging CE white list: %s" % loggerCEsWhiteList) if loggerCEsWhiteList is not None: if not isinstance(loggerCEsWhiteList, list): - loggerCEsWhiteList = [elem.strip() for elem in loggerCEsWhiteList.split(",")] + loggerCEsWhiteList = [ + elem.strip() for elem in loggerCEsWhiteList.split(",") + ] if self.ceName not in loggerCEsWhiteList: self.pilotLogging = False - self.log.debug("JSON: Remote logging disabled for this CE: %s" % self.ceName) + self.log.debug( + "JSON: Remote logging disabled for this CE: %s" % self.ceName + ) pilotLogLevel = pilotOptions.get("PilotLogLevel", "INFO") if pilotLogLevel.lower() == "debug": self.debugFlag = True self.log.debug("JSON: Remote logging: %s" % self.pilotLogging) self.log.debug("JSON: Remote logging URL: %s" % self.loggerURL) - self.log.debug("JSON: Remote logging 
buffer flush interval in sec.(0: disabled): %s" % self.loggerTimerInterval) + self.log.debug( + "JSON: Remote logging buffer flush interval in sec.(0: disabled): %s" + % self.loggerTimerInterval + ) self.log.debug("JSON: Remote/local logging debug flag: %s" % self.debugFlag) - self.log.debug("JSON: Remote logging buffer size (lines): %s" % self.loggerBufsize) + self.log.debug( + "JSON: Remote logging buffer size (lines): %s" % self.loggerBufsize + ) # CE type if present, then Defaults, otherwise as defined in the code: if "Commands" in pilotOptions: @@ -1292,7 +1433,9 @@ def __initJSON2(self): else: # TODO: This is a workaround until the pilot JSON syncroniser is fixed self.commands = [elem.strip() for elem in commands.split(",")] - self.log.debug("Selecting commands from JSON for Grid CE type %s" % key) + self.log.debug( + "Selecting commands from JSON for Grid CE type %s" % key + ) break else: key = "CodeDefaults" @@ -1302,12 +1445,18 @@ def __initJSON2(self): # Command extensions for the commands above: commandExtOptions = pilotOptions.get("CommandExtensions") if commandExtOptions: - self.commandExtensions = [elem.strip() for elem in commandExtOptions.split(",")] + self.commandExtensions = [ + elem.strip() for elem in commandExtOptions.split(",") + ] # Configuration server (the synchroniser looks into gConfig.getServersList(), as before # the generic one (a list): - self.configServer = ",".join([str(pv).strip() for pv in self.pilotJSON["ConfigurationServers"]]) + self.configServer = ",".join( + [str(pv).strip() for pv in self.pilotJSON["ConfigurationServers"]] + ) - self.preferredURLPatterns = self.pilotJSON.get("PreferredURLPatterns", self.preferredURLPatterns) + self.preferredURLPatterns = self.pilotJSON.get( + "PreferredURLPatterns", self.preferredURLPatterns + ) # version(a comma separated values in a string). We take the first one. 
(the default value defined in the code) dVersion = pilotOptions.get("Version", self.releaseVersion) @@ -1317,13 +1466,19 @@ def __initJSON2(self): else: self.log.warn("Could not find a version in the JSON file configuration") - self.log.debug("Version: %s -> (release) %s" % (str(dVersion), self.releaseVersion)) + self.log.debug( + "Version: %s -> (release) %s" % (str(dVersion), self.releaseVersion) + ) - self.releaseProject = pilotOptions.get("Project", self.releaseProject) # default from the code. + self.releaseProject = pilotOptions.get( + "Project", self.releaseProject + ) # default from the code. self.log.debug("Release project: %s" % self.releaseProject) if "CVMFS_locations" in pilotOptions: - self.CVMFS_locations = pilotOptions["CVMFS_locations"].replace(" ", "").split(",") + self.CVMFS_locations = ( + pilotOptions["CVMFS_locations"].replace(" ", "").split(",") + ) self.log.debug("CVMFS locations: %s" % self.CVMFS_locations) def getPilotOptionsDict(self): @@ -1361,7 +1516,10 @@ def __getVO(self): with open(cert, "rb") as fp: return getVO(fp.read()) except IOError as err: - self.log.error("Could not read a proxy, setting vo to 'unknown': %s" % os.strerror(err.errno)) + self.log.error( + "Could not read a proxy, setting vo to 'unknown': %s" + % os.strerror(err.errno) + ) else: self.log.error("Could not locate a proxy via X509_USER_PROXY") @@ -1458,46 +1616,78 @@ def __initJSON(self): # Commands first # FIXME: pilotSynchronizer() should publish these as comma-separated lists. We are ready for that. 
try: - if isinstance(self.pilotJSON["Setups"][self.setup]["Commands"][self.gridCEType], basestring): + if isinstance( + self.pilotJSON["Setups"][self.setup]["Commands"][self.gridCEType], str + ): self.commands = [ str(pv).strip() - for pv in self.pilotJSON["Setups"][self.setup]["Commands"][self.gridCEType].split(",") + for pv in self.pilotJSON["Setups"][self.setup]["Commands"][ + self.gridCEType + ].split(",") ] else: self.commands = [ - str(pv).strip() for pv in self.pilotJSON["Setups"][self.setup]["Commands"][self.gridCEType] + str(pv).strip() + for pv in self.pilotJSON["Setups"][self.setup]["Commands"][ + self.gridCEType + ] ] except KeyError: try: - if isinstance(self.pilotJSON["Setups"][self.setup]["Commands"]["Defaults"], basestring): + if isinstance( + self.pilotJSON["Setups"][self.setup]["Commands"]["Defaults"], str + ): self.commands = [ str(pv).strip() - for pv in self.pilotJSON["Setups"][self.setup]["Commands"]["Defaults"].split(",") + for pv in self.pilotJSON["Setups"][self.setup]["Commands"][ + "Defaults" + ].split(",") ] else: self.commands = [ - str(pv).strip() for pv in self.pilotJSON["Setups"][self.setup]["Commands"]["Defaults"] + str(pv).strip() + for pv in self.pilotJSON["Setups"][self.setup]["Commands"][ + "Defaults" + ] ] except KeyError: try: - if isinstance(self.pilotJSON["Setups"]["Defaults"]["Commands"][self.gridCEType], basestring): + if isinstance( + self.pilotJSON["Setups"]["Defaults"]["Commands"][ + self.gridCEType + ], + str, + ): self.commands = [ str(pv).strip() - for pv in self.pilotJSON["Setups"]["Defaults"]["Commands"][self.gridCEType].split(",") + for pv in self.pilotJSON["Setups"]["Defaults"]["Commands"][ + self.gridCEType + ].split(",") ] else: self.commands = [ - str(pv).strip() for pv in self.pilotJSON["Setups"]["Defaults"]["Commands"][self.gridCEType] + str(pv).strip() + for pv in self.pilotJSON["Setups"]["Defaults"]["Commands"][ + self.gridCEType + ] ] except KeyError: try: - if 
isinstance(self.pilotJSON["Defaults"]["Commands"]["Defaults"], basestring): + if isinstance( + self.pilotJSON["Defaults"]["Commands"]["Defaults"], str + ): self.commands = [ - str(pv).strip() for pv in self.pilotJSON["Defaults"]["Commands"]["Defaults"].split(",") + str(pv).strip() + for pv in self.pilotJSON["Defaults"]["Commands"][ + "Defaults" + ].split(",") ] else: self.commands = [ - str(pv).strip() for pv in self.pilotJSON["Defaults"]["Commands"]["Defaults"] + str(pv).strip() + for pv in self.pilotJSON["Defaults"]["Commands"][ + "Defaults" + ] ] except KeyError: pass @@ -1507,26 +1697,36 @@ def __initJSON(self): # pilotSynchronizer() can publish this as a comma separated list. We are ready for that. try: if isinstance( - self.pilotJSON["Setups"][self.setup]["CommandExtensions"], basestring + self.pilotJSON["Setups"][self.setup]["CommandExtensions"], str ): # In the specific setup? self.commandExtensions = [ - str(pv).strip() for pv in self.pilotJSON["Setups"][self.setup]["CommandExtensions"].split(",") + str(pv).strip() + for pv in self.pilotJSON["Setups"][self.setup][ + "CommandExtensions" + ].split(",") ] else: self.commandExtensions = [ - str(pv).strip() for pv in self.pilotJSON["Setups"][self.setup]["CommandExtensions"] + str(pv).strip() + for pv in self.pilotJSON["Setups"][self.setup]["CommandExtensions"] ] except KeyError: try: if isinstance( - self.pilotJSON["Setups"]["Defaults"]["CommandExtensions"], basestring + self.pilotJSON["Setups"]["Defaults"]["CommandExtensions"], str ): # Or in the defaults section? 
self.commandExtensions = [ - str(pv).strip() for pv in self.pilotJSON["Setups"]["Defaults"]["CommandExtensions"].split(",") + str(pv).strip() + for pv in self.pilotJSON["Setups"]["Defaults"][ + "CommandExtensions" + ].split(",") ] else: self.commandExtensions = [ - str(pv).strip() for pv in self.pilotJSON["Setups"]["Defaults"]["CommandExtensions"] + str(pv).strip() + for pv in self.pilotJSON["Setups"]["Defaults"][ + "CommandExtensions" + ] ] except KeyError: pass @@ -1536,40 +1736,62 @@ def __initJSON(self): # pilotSynchronizer() can publish this as a comma separated list. We are ready for that try: if isinstance( - self.pilotJSON["ConfigurationServers"], basestring + self.pilotJSON["ConfigurationServers"], str ): # Generic, there may also be setup-specific ones self.configServer = ",".join( - [str(pv).strip() for pv in self.pilotJSON["ConfigurationServers"].split(",")] + [ + str(pv).strip() + for pv in self.pilotJSON["ConfigurationServers"].split(",") + ] ) else: # it's a list, we suppose - self.configServer = ",".join([str(pv).strip() for pv in self.pilotJSON["ConfigurationServers"]]) + self.configServer = ",".join( + [str(pv).strip() for pv in self.pilotJSON["ConfigurationServers"]] + ) except KeyError: pass try: # now trying to see if there is setup-specific ones if isinstance( - self.pilotJSON["Setups"][self.setup]["ConfigurationServer"], basestring + self.pilotJSON["Setups"][self.setup]["ConfigurationServer"], str ): # In the specific setup? 
self.configServer = ",".join( - [str(pv).strip() for pv in self.pilotJSON["Setups"][self.setup]["ConfigurationServer"].split(",")] + [ + str(pv).strip() + for pv in self.pilotJSON["Setups"][self.setup][ + "ConfigurationServer" + ].split(",") + ] ) else: # it's a list, we suppose self.configServer = ",".join( - [str(pv).strip() for pv in self.pilotJSON["Setups"][self.setup]["ConfigurationServer"]] + [ + str(pv).strip() + for pv in self.pilotJSON["Setups"][self.setup][ + "ConfigurationServer" + ] + ] ) except KeyError: # and if it doesn't exist try: if isinstance( - self.pilotJSON["Setups"]["Defaults"]["ConfigurationServer"], basestring + self.pilotJSON["Setups"]["Defaults"]["ConfigurationServer"], str ): # Is there one in the defaults section? self.configServer = ",".join( [ str(pv).strip() - for pv in self.pilotJSON["Setups"]["Defaults"]["ConfigurationServer"].split(",") + for pv in self.pilotJSON["Setups"]["Defaults"][ + "ConfigurationServer" + ].split(",") ] ) else: # it's a list, we suppose self.configServer = ",".join( - [str(pv).strip() for pv in self.pilotJSON["Setups"]["Defaults"]["ConfigurationServer"]] + [ + str(pv).strip() + for pv in self.pilotJSON["Setups"]["Defaults"][ + "ConfigurationServer" + ] + ] ) except KeyError: pass @@ -1579,10 +1801,18 @@ def __initJSON(self): # There may be a list of versions specified (in a string, comma separated). We just want the first one. 
dVersion = None try: - dVersion = [dv.strip() for dv in self.pilotJSON["Setups"][self.setup]["Version"].split(",", 1)] + dVersion = [ + dv.strip() + for dv in self.pilotJSON["Setups"][self.setup]["Version"].split(",", 1) + ] except KeyError: try: - dVersion = [dv.strip() for dv in self.pilotJSON["Setups"]["Defaults"]["Version"].split(",", 1)] + dVersion = [ + dv.strip() + for dv in self.pilotJSON["Setups"]["Defaults"]["Version"].split( + ",", 1 + ) + ] except KeyError: self.log.warn("Could not find a version in the JSON file configuration") if dVersion is not None: @@ -1593,7 +1823,9 @@ def __initJSON(self): self.releaseProject = str(self.pilotJSON["Setups"][self.setup]["Project"]) except KeyError: try: - self.releaseProject = str(self.pilotJSON["Setups"]["Defaults"]["Project"]) + self.releaseProject = str( + self.pilotJSON["Setups"]["Defaults"]["Project"] + ) except KeyError: pass self.log.debug("Release project: %s" % self.releaseProject) @@ -1613,7 +1845,9 @@ def __ceType(self): try: if not self.gridCEType: # We don't override a grid CEType given on the command line! - self.gridCEType = str(self.pilotJSON["CEs"][self.ceName]["GridCEType"]) + self.gridCEType = str( + self.pilotJSON["CEs"][self.ceName]["GridCEType"] + ) except KeyError: pass # This LocalCEType is like 'InProcess' or 'Pool' or 'Pool/Singularity' etc. 
@@ -1623,7 +1857,9 @@ def __ceType(self): except KeyError: pass try: - self.ceType = str(self.pilotJSON["CEs"][self.ceName][self.queueName]["LocalCEType"]) + self.ceType = str( + self.pilotJSON["CEs"][self.ceName][self.queueName]["LocalCEType"] + ) except KeyError: pass diff --git a/Pilot/proxyTools.py b/Pilot/proxyTools.py index a5fa652e..765ae465 100644 --- a/Pilot/proxyTools.py +++ b/Pilot/proxyTools.py @@ -1,15 +1,25 @@ -"""few functions for dealing with proxies""" - -from __future__ import absolute_import, division, print_function +"""few functions for dealing with proxies and authentication""" +import json +import os import re -from base64 import b16decode +import ssl +import sys +import time +from base64 import b16decode, b64decode +from random import randint from subprocess import PIPE, Popen +from urllib.error import HTTPError +from urllib.parse import urlencode +from urllib.request import Request, urlopen VOMS_FQANS_OID = b"1.3.6.1.4.1.8005.100.100.4" VOMS_EXTENSION_OID = b"1.3.6.1.4.1.8005.100.100.5" -RE_OPENSSL_ANS1_FORMAT = re.compile(br"^\s*\d+:d=(\d+)\s+hl=") +RE_OPENSSL_ANS1_FORMAT = re.compile(rb"^\s*\d+:d=(\d+)\s+hl=") + +MAX_REQUEST_RETRIES = 10 # If a request failed (503 error), we retry +MAX_TIME_BETWEEN_TRIES = 20 # 20 seconds max between each request def parseASN1(data): @@ -30,18 +40,17 @@ def findExtension(oid, lines): def getVO(proxy_data): """Fetches the VO in a chain certificate - Args: - proxy_data (bytes): Bytes for the proxy chain - - Raises: - Exception: Any error related to openssl - NotImplementedError: Not documented error - - Returns: - str: A VO + :param proxy_data: Bytes for the proxy chain + :type proxy_data: bytes + :return: A VO + :rtype: str """ - chain = re.findall(br"-----BEGIN CERTIFICATE-----\n.+?\n-----END CERTIFICATE-----", proxy_data, flags=re.DOTALL) + chain = re.findall( + rb"-----BEGIN CERTIFICATE-----\n.+?\n-----END CERTIFICATE-----", + proxy_data, + flags=re.DOTALL, + ) for cert in chain: proc = 
Popen(["openssl", "x509", "-outform", "der"], stdin=PIPE, stdout=PIPE) out, _ = proc.communicate(cert) @@ -52,16 +61,374 @@ def getVO(proxy_data): idx_voms_line = findExtension(VOMS_EXTENSION_OID, cert_info) if idx_voms_line is None: continue - voms_extension = parseASN1(b16decode(cert_info[idx_voms_line + 1].split(b":")[-1])) + voms_extension = parseASN1( + b16decode(cert_info[idx_voms_line + 1].split(b":")[-1]) + ) # Look for the attribute names idx_fqans = findExtension(VOMS_FQANS_OID, voms_extension) - (initial_depth,) = map(int, RE_OPENSSL_ANS1_FORMAT.match(voms_extension[idx_fqans - 1]).groups()) + (initial_depth,) = map( + int, RE_OPENSSL_ANS1_FORMAT.match(voms_extension[idx_fqans - 1]).groups() + ) for line in voms_extension[idx_fqans:]: (depth,) = map(int, RE_OPENSSL_ANS1_FORMAT.match(line).groups()) if depth <= initial_depth: break # Look for a role, if it exists the VO is the first element - match = re.search(br"OCTET STRING\s+:/([a-zA-Z0-9]+)/Role=", line) + match = re.search(rb"OCTET STRING\s+:/([a-zA-Z0-9]+)/Role=", line) if match: return match.groups()[0].decode() raise NotImplementedError("Something went very wrong") + + +def extract_diracx_payload(proxy_data): + """Extracts and decodes the DIRACX section from proxy data + + :param proxy_data: The full proxy content (str or bytes) + :return: Parsed DIRACX payload as dict + :rtype: dict + """ + if isinstance(proxy_data, bytes): + proxy_data = proxy_data.decode("utf-8") + + # 1. Extract the DIRACX block + match = re.search( + r"-----BEGIN DIRACX-----(.*?)-----END DIRACX-----", proxy_data, re.DOTALL + ) + if not match: + raise ValueError("DIRACX section not found") + + # 2. Remove whitespaces/newlines and base64-decode the inner content + b64_data = "".join(match.group(1).strip().splitlines()) + + # 3. Base64 decode + try: + decoded = b64decode(b64_data) + except Exception as e: + raise ValueError("Base64 decoding failed: %s" % str(e)) + + # 4. 
JSON decode + try: + payload = json.loads(decoded) + except Exception as e: + raise ValueError("JSON decoding failed: %s" % str(e)) + + return payload + + +class BaseRequest(object): + """This class helps supporting multiple kinds of requests that require connections""" + + def __init__(self, url, caPath, pilotUUID, name="unknown"): + self.name = name + self.url = url + self.caPath = caPath + self.headers = {"User-Agent": "Dirac Pilot [Unknown ID]"} + self.pilotUUID = pilotUUID + # We assume we have only one context, so this variable could be shared to avoid opening n times a cert. + # On the contrary, to avoid race conditions, we do avoid using "self.data" and "self.headers" + self._context = None + + self._prepareRequest() + + def generateUserAgent(self): + """To analyse the traffic, we can send a taylor-made User-Agent""" + self.addHeader("User-Agent", "Dirac Pilot [%s]" % self.pilotUUID) + + def _prepareRequest(self): + """As previously, loads the SSL certificates of the server (to avoid "unknown issuer")""" + # Load the SSL context + self._context = ssl.create_default_context() + self._context.load_verify_locations(capath=self.caPath) + + def addHeader(self, key, value): + """Add a header (key, value) into the request header""" + self.headers[key] = value + + def executeRequest( + self, raw_data, insecure=False, content_type="json", json_output=True + ): + tries_left = MAX_REQUEST_RETRIES + + while tries_left > 0: + try: + return self.__execute_raw_request( + raw_data=raw_data, + insecure=insecure, + content_type=content_type, + json_output=json_output, + ) + except HTTPError as e: + if e.code >= 500 and e.code < 600: + # If we have an 5XX error (server overloaded), we retry + # To avoid DOS-ing the server, we retry few seconds later + time.sleep(randint(1, MAX_TIME_BETWEEN_TRIES)) + else: + raise e + + tries_left -= 1 + + raise RuntimeError("Too much tries. 
Server down.") + + def __execute_raw_request( + self, raw_data, insecure=False, content_type="json", json_output=True + ): + """Execute a HTTP request with the data, headers, and the pre-defined data (SSL + auth) + + :param raw_data: Data to send + :type raw_data: dict + :param insecure: Deactivate proxy verification WARNING Debug ONLY + :type insecure: bool + :param content_type: Data format to send, either "json" or "x-www-form-urlencoded" or "query" + :type content_type: str + :param json_output: If we have an output + :type json_output: bool + :return: Parsed JSON response + :rtype: dict + """ + if content_type == "json": + data = json.dumps(raw_data).encode("utf-8") + self.addHeader("Content-Type", "application/json") + self.addHeader("Content-Length", str(len(data))) + else: + data = urlencode(raw_data) + + if content_type == "x-www-form-urlencoded": + if sys.version_info.major == 3: + data = urlencode(raw_data).encode( + "utf-8" + ) # encode to bytes ! for python3 + + self.addHeader("Content-Type", "application/x-www-form-urlencoded") + self.addHeader("Content-Length", str(len(data))) + elif content_type == "query": + self.url = self.url + "?" + data + data = None # No body + else: + raise ValueError( + "Invalid content_type. Use 'json' or 'x-www-form-urlencoded'." 
+ ) + + request = Request(self.url, data=data, headers=self.headers, method="POST") + + ctx = self._context # Save in case of an insecure request + + if insecure: + # DEBUG ONLY + # Overrides context + ctx = ssl.create_default_context() + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + + if sys.version_info.major == 3: + # Python 3 code + with urlopen(request, context=ctx) as res: + response_data = res.read().decode("utf-8") # Decode response bytes + else: + # Python 2 code + res = urlopen(request, context=ctx) + try: + response_data = res.read() + finally: + res.close() + + if json_output: + try: + return json.loads(response_data) # Parse JSON response + except ( + ValueError + ): # In Python 2, json.JSONDecodeError is a subclass of ValueError + raise ValueError("Invalid JSON response: %s" % response_data) + + +class TokenBasedRequest(BaseRequest): + """Connected Request with JWT support""" + + def __init__(self, diracx_URL, endpoint_path, caPath, jwtData, pilotUUID): + url = diracx_URL + endpoint_path + + super(TokenBasedRequest, self).__init__( + url, caPath, pilotUUID, "TokenBasedConnection" + ) + self.jwtData = jwtData + self.diracx_URL = diracx_URL + self.endpoint_path = endpoint_path + self.addJwtToHeader() + + def addJwtToHeader(self): + # Adds the JWT in the HTTP request (in the Bearer field) + self.headers["Authorization"] = "Bearer %s" % self.jwtData["access_token"] + + def executeRequest( + self, + raw_data, + insecure=False, + content_type="json", + json_output=True, + tries_left=1, + refresh_callback=None, + ): + while tries_left >= 0: + try: + return super(TokenBasedRequest, self).executeRequest( + raw_data, + insecure=insecure, + content_type=content_type, + json_output=json_output, + ) + except HTTPError as e: + if e.code != 401: + raise e + + # If we have an unauthorized error, then refresh and retry + if refresh_callback: + refresh_callback() + + self.addJwtToHeader() + + tries_left -= 1 + + raise RuntimeError("Too much tries. 
Can't refresh my token.") + + +class X509BasedRequest(BaseRequest): + """Connected Request with X509 support""" + + def __init__(self, url, caPath, certEnv, pilotUUID): + super(X509BasedRequest, self).__init__( + url, caPath, pilotUUID, "X509BasedConnection" + ) + + self.certEnv = certEnv + self._hasExtraCredentials = False + + # Load X509 once + try: + self._context.load_cert_chain(self.certEnv) + except IsADirectoryError: # assuming it'a dir containing cert and key + self._context.load_cert_chain( + os.path.join(self.certEnv, "hostcert.pem"), + os.path.join(self.certEnv, "hostkey.pem"), + ) + self._hasExtraCredentials = True + + def executeRequest( + self, raw_data, insecure=False, content_type="json", json_output=True + ): + # Adds a flag if the passed cert is a Directory + if self._hasExtraCredentials: + raw_data["extraCredentials"] = '"hosts"' + return super(X509BasedRequest, self).executeRequest( + raw_data, + insecure=insecure, + content_type=content_type, + json_output=json_output, + ) + + +def refreshUserToken(url, pilotUUID, jwt, clientID): + """ + Refresh the JWT token (as a user). 
+ + :param str url: Server URL + :param str pilotUUID: Pilot unique ID + :param dict jwt: Shared dict with current JWT; updated in-place + :return: None + """ + + # PRECONDITION: jwt must contain "refresh_token" + if not jwt or "refresh_token" not in jwt: + raise ValueError("To refresh a token, a pilot needs a JWT with refresh_token") + + # Get CA path from environment + caPath = os.getenv("X509_CERT_DIR") + + # Create request object with required configuration + config = BaseRequest( + url=url + "api/auth/token", + caPath=caPath, + pilotUUID=pilotUUID, + ) + + # Perform the request to refresh the token + response = config.executeRequest( + raw_data={ + "refresh_token": jwt["refresh_token"], + "grant_type": "refresh_token", + "client_id": clientID, + }, + content_type="x-www-form-urlencoded", + ) + + # Do NOT assign directly, because jwt is a reference, not a copy + jwt["access_token"] = response["access_token"] + jwt["refresh_token"] = response["refresh_token"] + + +def refreshPilotToken(url, pilotUUID, jwt, _=None): + """ + Refresh the JWT token (as a pilot). 
+
+    :param str url: Server URL
+    :param str pilotUUID: Pilot unique ID
+    :param dict jwt: Shared dict with current JWT; updated in-place
+    :return: None
+    """
+
+    # PRECONDITION: jwt must contain "refresh_token"
+    if not jwt or "refresh_token" not in jwt:
+        raise ValueError("To refresh a token, a pilot needs a JWT with refresh_token")
+
+    # Get CA path from environment
+    caPath = os.getenv("X509_CERT_DIR")
+
+    # Create request object with required configuration
+    config = BaseRequest(
+        url=url + "api/auth/pilot-token",
+        caPath=caPath,
+        pilotUUID=pilotUUID,
+    )
+
+    # Perform the request to refresh the token
+    response = config.executeRequest(
+        raw_data={"refresh_token": jwt["refresh_token"], "pilot_stamp": pilotUUID},
+        insecure=True,
+    )
+
+    # Do NOT assign directly, because jwt is a reference, not a copy
+    jwt["access_token"] = response["access_token"]
+    jwt["refresh_token"] = response["refresh_token"]
+
+
+def revokePilotToken(url, pilotUUID, jwt, clientID):
+    """
+    Revoke the JWT refresh token. 
+ + :param str url: Server URL + :param str pilotUUID: Pilot unique ID + :param str clientID: ClientID used to revoke tokens + :param dict jwt: Shared dict with current JWT; + :return: None + """ + + # PRECONDITION: jwt must contain "refresh_token" + if not jwt or "refresh_token" not in jwt: + raise ValueError("To refresh a token, a pilot needs a JWT with refresh_token") + + # Get CA path from environment + caPath = os.getenv("X509_CERT_DIR") + + if not url.endswith("/"): + url = url + "/" + + # Create request object with required configuration + config = BaseRequest( + url="%sapi/auth/revoke" % url, caPath=caPath, pilotUUID=pilotUUID + ) + + # Prepare refresh token payload + payload = {"refresh_token": jwt["refresh_token"], "client_id": clientID} + + # Perform the request to revoke the token + _response = config.executeRequest( + raw_data=payload, insecure=True, content_type="query", json_output=False + ) diff --git a/Pilot/tests/Test_Pilot.py b/Pilot/tests/Test_Pilot.py index 8a1b75a1..75b600d1 100644 --- a/Pilot/tests/Test_Pilot.py +++ b/Pilot/tests/Test_Pilot.py @@ -1,7 +1,5 @@ """Test class for Pilot""" -from __future__ import absolute_import, division, print_function - import json import os import shutil @@ -12,8 +10,10 @@ # imports import unittest -from Pilot.pilotCommands import CheckWorkerNode, ConfigureSite, NagiosProbes -from Pilot.pilotTools import PilotParams +sys.path.insert(0, os.getcwd() + "/Pilot") + +from pilotCommands import CheckWorkerNode, ConfigureSite, NagiosProbes +from pilotTools import PilotParams class PilotTestCase(unittest.TestCase): diff --git a/Pilot/tests/Test_proxyTools.py b/Pilot/tests/Test_proxyTools.py index 7a8688cb..86935c22 100644 --- a/Pilot/tests/Test_proxyTools.py +++ b/Pilot/tests/Test_proxyTools.py @@ -1,23 +1,12 @@ -from __future__ import absolute_import, division, print_function - import os -import shlex import shutil -import subprocess import sys import unittest +from unittest.mock import patch 
-############################ -# python 2 -> 3 "hacks" -try: - from Pilot.proxyTools import getVO, parseASN1 -except ImportError: - from proxyTools import getVO, parseASN1 +sys.path.insert(0, os.getcwd() + "/Pilot") -try: - from unittest.mock import patch -except ImportError: - from mock import patch +from proxyTools import getVO, parseASN1 class TestProxyTools(unittest.TestCase): @@ -31,7 +20,7 @@ def test_getVO(self): os.remove(cert) self.assertEqual(vo, "fakevo") - @patch("Pilot.proxyTools.Popen") + @patch("proxyTools.Popen") def test_getVOPopenFails(self, popenMock): """ Check if an exception is raised when Popen return code is not 0. @@ -59,7 +48,7 @@ def test_getVOPopenFails(self, popenMock): getVO(data) self.assertEqual(str(exc.exception), msg) - @patch("Pilot.proxyTools.Popen") + @patch("proxyTools.Popen") def test_parseASN1Fails(self, popenMock): """Should raise an exception when Popen return code is !=0""" @@ -92,46 +81,7 @@ def __createFakeProxy(self, proxyFile): """ Create a fake proxy locally. 
""" - basedir = os.path.dirname(__file__) - shutil.copy(basedir + "/certs/user/userkey.pem", basedir + "/certs/user/userkey400.pem") - os.chmod(basedir + "/certs/user/userkey400.pem", 0o400) - ret = self.createFakeProxy( - basedir + "/certs/user/usercert.pem", - basedir + "/certs/user/userkey400.pem", - "fakeserver.cern.ch:15000", - "fakevo", - basedir + "/certs//host/hostcert.pem", - basedir + "/certs/host/hostkey.pem", - basedir + "/certs/ca", - proxyFile, - ) - os.remove(basedir + "/certs/user/userkey400.pem") - return ret - - def createFakeProxy(self, usercert, userkey, serverURI, vo, hostcert, hostkey, CACertDir, proxyfile): - """ - voms-proxy-fake --cert usercert.pem - --key userkey.pem - -rfc - -fqan "/fakevo/Role=user/Capability=NULL" - -uri fakeserver.cern.ch:15000 - -voms fakevo - -hostcert hostcert.pem - -hostkey hostkey.pem - -certdir ca - """ - opt = ( - '--cert %s --key %s -rfc -fqan "/fakevo/Role=user/Capability=NULL" -uri %s -voms %s -hostcert %s' - " -hostkey %s -certdir %s -out %s" - % (usercert, userkey, serverURI, vo, hostcert, hostkey, CACertDir, proxyfile) - ) - proc = subprocess.Popen( - shlex.split("voms-proxy-fake " + opt), - bufsize=1, - stdout=sys.stdout, - stderr=sys.stderr, - universal_newlines=True, - ) - proc.communicate() - return proc.returncode + shutil.copy(basedir + "/certs/voms/proxy.pem", proxyFile) + return 0 + diff --git a/Pilot/tests/Test_simplePilotLogger.py b/Pilot/tests/Test_simplePilotLogger.py index df2ac0c2..3a236da1 100644 --- a/Pilot/tests/Test_simplePilotLogger.py +++ b/Pilot/tests/Test_simplePilotLogger.py @@ -1,25 +1,17 @@ #!/usr/bin/env python -from __future__ import absolute_import, division, print_function - import json import os import random import string import sys import tempfile - -try: - from Pilot.pilotTools import CommandBase, Logger, PilotParams -except ImportError: - from pilotTools import CommandBase, Logger, PilotParams - import unittest +from unittest.mock import patch + +sys.path.insert(0, 
os.getcwd() + "/Pilot") -try: - from unittest.mock import patch -except ImportError: - from mock import patch +from pilotTools import CommandBase, Logger, PilotParams class TestPilotParams(unittest.TestCase): @@ -146,16 +138,10 @@ def test_executeAndGetOutput(self, popenMock, argvmock): for size in [1000, 1024, 1025, 2005]: random_str = "".join(random.choice(string.ascii_letters + "\n") for i in range(size)) - if sys.version_info.major == 3: - random_bytes = random_str.encode("UTF-8") - self.stdout_mock.write(random_bytes) - else: - self.stdout_mock.write(random_str) + random_bytes = random_str.encode("UTF-8") + self.stdout_mock.write(random_bytes) self.stdout_mock.seek(0) - if sys.version_info.major == 3: - self.stderr_mock.write("Errare humanum est!".encode("UTF-8")) - else: - self.stderr_mock.write("Errare humanum est!") + self.stderr_mock.write("Errare humanum est!".encode("UTF-8")) self.stderr_mock.seek(0) pp = PilotParams() diff --git a/environment.yml b/environment.yml index 41e0a564..72e2765e 100644 --- a/environment.yml +++ b/environment.yml @@ -11,7 +11,6 @@ dependencies: - requests # testing and development - pycodestyle - - caniusepython3 - coverage - mock - pylint