From 60e336dcf168a6821ae65c44725aaa890c306b1a Mon Sep 17 00:00:00 2001 From: Steven Timm Date: Tue, 14 May 2024 15:08:58 -0500 Subject: [PATCH 1/4] add new method of determining token expiration instaed of relying on jwt.decode error code --- .../NERSC/sources/NerscSFApi.py | 74 ++++++++++++------- 1 file changed, 49 insertions(+), 25 deletions(-) diff --git a/src/decisionengine_modules/NERSC/sources/NerscSFApi.py b/src/decisionengine_modules/NERSC/sources/NerscSFApi.py index 5d38532b..a084b3a1 100644 --- a/src/decisionengine_modules/NERSC/sources/NerscSFApi.py +++ b/src/decisionengine_modules/NERSC/sources/NerscSFApi.py @@ -6,7 +6,7 @@ """ import json -# import time +import time import os import jwt @@ -36,7 +36,7 @@ def __init__(self, config): self.logger = self.logger.bind( class_module=__name__.split(".")[-1], ) - self.localmap = {"uscms": "m2612", "fife": "m3249"} + self.localmap = {"uscms": "m2612", "fife": "m4599", "dunepro": "m3249" } self.keys_list = ["hours_given", "hours_used", "id", "project_hours_given", "project_hours_used", "repo_name"] def check_accesstoken(self, nersc_user): @@ -62,14 +62,14 @@ def check_accesstoken(self, nersc_user): except KeyError: self.logger.error(f"Unknown user '{nersc_user}', exiting") return None - + print(nersc_user) rawfile = params_dict["rawfile"] pemfile = params_dict["private_key"] clientidfile = params_dict["client_id_file"] with open(clientidfile) as cifile: client_id = cifile.read() client_id = client_id.rstrip() - + atoken = None if not os.path.exists(rawfile): self.logger.debug(f"{rawfile} does not exist. Need to generate") else: @@ -77,28 +77,36 @@ def check_accesstoken(self, nersc_user): with open(rawfile) as afile: atoken = afile.read() atoken = atoken.rstrip() - # HK> If the access token is expired, the flow goes directly to except jwt.ExpiredSignatureError - try: - jwt.decode(atoken, options={"verify_signature": False}) - self.logger.debug("Access Token not expired. Returning without generating a new access token") - return atoken # This means the existing access token is not expired. - - except jwt.ExpiredSignatureError: - self.logger.debug("Access Token expired") - - certs = pem.parse_file(pemfile) - private_key = str(certs[0]) - client = OAuth2Session( - client_id=client_id, client_secret=private_key, token_endpoint_auth_method="private_key_jwt" - ) - client.register_client_auth_method(PrivateKeyJWT(token_url)) - resp = client.fetch_token(token_url, grant_type="client_credentials") + # HK> If the access token is expired, the flow goes directly to except jwt.ExpiredSign + + if atoken is not None: + rvalue = jwt.decode(atoken, options={"verify_signature": False}) + ctime = int(time.time()) + diff = ctime - rvalue['exp'] + print( diff ) + else: + self.logger.debug("there is no access token file, setting diff high to indicate expired") + diff=10000000 - newtoken = resp["access_token"] + if diff < 0: + self.logger.debug("Access Token not expired. Returning without generating a new access token") + return atoken # This means the existing access token is not expired. - with open(rawfile, "w") as myfile: - myfile.write(newtoken) - return newtoken + else: + self.logger.debug("Access Token expired") + + certs = pem.parse_file(pemfile) + private_key = str(certs[0]) + client = OAuth2Session( + client_id=client_id, client_secret=private_key, token_endpoint_auth_method="private_key_jwt" + ) + client.register_client_auth_method(PrivateKeyJWT(token_url)) + resp = client.fetch_token(token_url, grant_type="client_credentials") + newtoken = resp["access_token"] + + with open(rawfile, "w") as myfile: + myfile.write(newtoken) + return newtoken def get_headers2(self, access_token): headers = {} @@ -118,16 +126,32 @@ def requests_nersc(self, username): def send_query(self): results = [] + print(self.constraints.get("usernames", [])) for username in self.constraints.get("usernames", []): + self.logger.debug("in send_query %s",username) + print(username) returned_list = self.requests_nersc(username) + self.logger.debug(returned_list) + print(returned_list) for each_dict in returned_list: # HK> This if condition will choose only m3249 for fife and discard m3990 if each_dict["repo_name"] == self.localmap[username]: local_dict = {each_key: each_dict[each_key] for each_key in self.keys_list} - local_dict["real_name"] = username + local_dict['real_name'] = username results.append(local_dict) return results +#self.localmap = {"uscms": "m2612", "fife": "m3249"} +#self.keys_list = [ +# "hours_given", "hours_used", "id", "project_hours_given", "project_hours_used", "repo_name" ] + +#+----+---------------+--------------+-------+-----------------------+----------------------+-------------+ +#| | hours_given | hours_used | id | project_hours_given | project_hours_used | repo_name | +#|----+---------------+--------------+-------+-----------------------+----------------------+-------------| +#| 0 | 600000 | 473490 | 54807 | 600000 | 473946 | m2612 | +#| 1 | 19109.1 | 0 | 63322 | 95545.7 | 24722.7 | m3249 | +#+----+---------------+--------------+-------+-----------------------+----------------------+-------------+ + def acquire(self): self.logger.debug("in NerscSFApi acquire") return {"Nersc_Allocation_SFAPI": pd.DataFrame(self.send_query())} From 799b4e274b52fa564d65ecddaf0ab56f8921e039 Mon Sep 17 00:00:00 2001 From: stevenctimm Date: Thu, 4 Dec 2025 16:14:58 -0600 Subject: [PATCH 2/4] remove AWS mentions in glidein_requests.py --- .../glideinwms/transforms/glidein_requests.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/decisionengine_modules/glideinwms/transforms/glidein_requests.py b/src/decisionengine_modules/glideinwms/transforms/glidein_requests.py index fd68a52b..4c99a23b 100644 --- a/src/decisionengine_modules/glideinwms/transforms/glidein_requests.py +++ b/src/decisionengine_modules/glideinwms/transforms/glidein_requests.py @@ -20,11 +20,12 @@ "startd_manifests", "Grid_Figure_Of_Merit", "GCE_Figure_Of_Merit", - "AWS_Figure_Of_Merit", +# "AWS_Figure_Of_Merit", "Nersc_Figure_Of_Merit", ] -_SUPPORTED_ENTRY_TYPES = ["LCF", "AWS", "Grid", "GCE"] +#_SUPPORTED_ENTRY_TYPES = ["LCF", "AWS", "Grid", "GCE"] +_SUPPORTED_ENTRY_TYPES = ["LCF", "Grid", "GCE"] METRICS = { "NUMBER_OF_JOBS": Gauge("de_jobs_total", "Number of jobs seen by the Decision Engine"), @@ -117,7 +118,7 @@ def transform(self, datablock): foms = { "Grid_Figure_Of_Merit": self.Grid_Figure_Of_Merit(datablock), "GCE_Figure_Of_Merit": self.GCE_Figure_Of_Merit(datablock), - "AWS_Figure_Of_Merit": self.AWS_Figure_Of_Merit(datablock), +# "AWS_Figure_Of_Merit": self.AWS_Figure_Of_Merit(datablock), "Nersc_Figure_Of_Merit": self.Nersc_Figure_Of_Merit(datablock), } fom_entries = fom_eligible_resources( From f81f6da32d586e20d98601edebb5d3c69f70f5f9 Mon Sep 17 00:00:00 2001 From: stevenctimm Date: Fri, 5 Dec 2025 08:54:19 -0600 Subject: [PATCH 3/4] Revert "add new method of determining token expiration instaed of relying on jwt.decode error code" This reverts commit 60e336dcf168a6821ae65c44725aaa890c306b1a. --- .../NERSC/sources/NerscSFApi.py | 74 +++++++------------ 1 file changed, 25 insertions(+), 49 deletions(-) diff --git a/src/decisionengine_modules/NERSC/sources/NerscSFApi.py b/src/decisionengine_modules/NERSC/sources/NerscSFApi.py index a084b3a1..5d38532b 100644 --- a/src/decisionengine_modules/NERSC/sources/NerscSFApi.py +++ b/src/decisionengine_modules/NERSC/sources/NerscSFApi.py @@ -6,7 +6,7 @@ """ import json -import time +# import time import os import jwt @@ -36,7 +36,7 @@ def __init__(self, config): self.logger = self.logger.bind( class_module=__name__.split(".")[-1], ) - self.localmap = {"uscms": "m2612", "fife": "m4599", "dunepro": "m3249" } + self.localmap = {"uscms": "m2612", "fife": "m3249"} self.keys_list = ["hours_given", "hours_used", "id", "project_hours_given", "project_hours_used", "repo_name"] def check_accesstoken(self, nersc_user): @@ -62,14 +62,14 @@ def check_accesstoken(self, nersc_user): except KeyError: self.logger.error(f"Unknown user '{nersc_user}', exiting") return None - print(nersc_user) + rawfile = params_dict["rawfile"] pemfile = params_dict["private_key"] clientidfile = params_dict["client_id_file"] with open(clientidfile) as cifile: client_id = cifile.read() client_id = client_id.rstrip() - atoken = None + if not os.path.exists(rawfile): self.logger.debug(f"{rawfile} does not exist. Need to generate") else: @@ -77,36 +77,28 @@ def check_accesstoken(self, nersc_user): with open(rawfile) as afile: atoken = afile.read() atoken = atoken.rstrip() - # HK> If the access token is expired, the flow goes directly to except jwt.ExpiredSign - - if atoken is not None: - rvalue = jwt.decode(atoken, options={"verify_signature": False}) - ctime = int(time.time()) - diff = ctime - rvalue['exp'] - print( diff ) - else: - self.logger.debug("there is no access token file, setting diff high to indicate expired") - diff=10000000 - - if diff < 0: - self.logger.debug("Access Token not expired. Returning without generating a new access token") - return atoken # This means the existing access token is not expired. - - else: - self.logger.debug("Access Token expired") + # HK> If the access token is expired, the flow goes directly to except jwt.ExpiredSignatureError + try: + jwt.decode(atoken, options={"verify_signature": False}) + self.logger.debug("Access Token not expired. Returning without generating a new access token") + return atoken # This means the existing access token is not expired. + + except jwt.ExpiredSignatureError: + self.logger.debug("Access Token expired") + + certs = pem.parse_file(pemfile) + private_key = str(certs[0]) + client = OAuth2Session( + client_id=client_id, client_secret=private_key, token_endpoint_auth_method="private_key_jwt" + ) + client.register_client_auth_method(PrivateKeyJWT(token_url)) + resp = client.fetch_token(token_url, grant_type="client_credentials") - certs = pem.parse_file(pemfile) - private_key = str(certs[0]) - client = OAuth2Session( - client_id=client_id, client_secret=private_key, token_endpoint_auth_method="private_key_jwt" - ) - client.register_client_auth_method(PrivateKeyJWT(token_url)) - resp = client.fetch_token(token_url, grant_type="client_credentials") - newtoken = resp["access_token"] + newtoken = resp["access_token"] - with open(rawfile, "w") as myfile: - myfile.write(newtoken) - return newtoken + with open(rawfile, "w") as myfile: + myfile.write(newtoken) + return newtoken def get_headers2(self, access_token): headers = {} @@ -126,32 +118,16 @@ def requests_nersc(self, username): def send_query(self): results = [] - print(self.constraints.get("usernames", [])) for username in self.constraints.get("usernames", []): - self.logger.debug("in send_query %s",username) - print(username) returned_list = self.requests_nersc(username) - self.logger.debug(returned_list) - print(returned_list) for each_dict in returned_list: # HK> This if condition will choose only m3249 for fife and discard m3990 if each_dict["repo_name"] == self.localmap[username]: local_dict = {each_key: each_dict[each_key] for each_key in self.keys_list} - local_dict['real_name'] = username + local_dict["real_name"] = username results.append(local_dict) return results -#self.localmap = {"uscms": "m2612", "fife": "m3249"} -#self.keys_list = [ -# "hours_given", "hours_used", "id", "project_hours_given", "project_hours_used", "repo_name" ] - -#+----+---------------+--------------+-------+-----------------------+----------------------+-------------+ -#| | hours_given | hours_used | id | project_hours_given | project_hours_used | repo_name | -#|----+---------------+--------------+-------+-----------------------+----------------------+-------------| -#| 0 | 600000 | 473490 | 54807 | 600000 | 473946 | m2612 | -#| 1 | 19109.1 | 0 | 63322 | 95545.7 | 24722.7 | m3249 | -#+----+---------------+--------------+-------+-----------------------+----------------------+-------------+ - def acquire(self): self.logger.debug("in NerscSFApi acquire") return {"Nersc_Allocation_SFAPI": pd.DataFrame(self.send_query())} From d8ce9694ab9c6936382c33155cae6e460fd8d5d6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Dec 2025 15:41:25 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../glideinwms/transforms/glidein_requests.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/decisionengine_modules/glideinwms/transforms/glidein_requests.py b/src/decisionengine_modules/glideinwms/transforms/glidein_requests.py index 4c99a23b..45db8ebf 100644 --- a/src/decisionengine_modules/glideinwms/transforms/glidein_requests.py +++ b/src/decisionengine_modules/glideinwms/transforms/glidein_requests.py @@ -20,11 +20,11 @@ "startd_manifests", "Grid_Figure_Of_Merit", "GCE_Figure_Of_Merit", -# "AWS_Figure_Of_Merit", + # "AWS_Figure_Of_Merit", "Nersc_Figure_Of_Merit", ] -#_SUPPORTED_ENTRY_TYPES = ["LCF", "AWS", "Grid", "GCE"] +# _SUPPORTED_ENTRY_TYPES = ["LCF", "AWS", "Grid", "GCE"] _SUPPORTED_ENTRY_TYPES = ["LCF", "Grid", "GCE"] METRICS = { @@ -118,7 +118,7 @@ def transform(self, datablock): foms = { "Grid_Figure_Of_Merit": self.Grid_Figure_Of_Merit(datablock), "GCE_Figure_Of_Merit": self.GCE_Figure_Of_Merit(datablock), -# "AWS_Figure_Of_Merit": self.AWS_Figure_Of_Merit(datablock), + # "AWS_Figure_Of_Merit": self.AWS_Figure_Of_Merit(datablock), "Nersc_Figure_Of_Merit": self.Nersc_Figure_Of_Merit(datablock), } fom_entries = fom_eligible_resources(