diff --git a/src/nsls2ptycho/core/ptycho_recon.py b/src/nsls2ptycho/core/ptycho_recon.py index eb6c612..52b7f4f 100755 --- a/src/nsls2ptycho/core/ptycho_recon.py +++ b/src/nsls2ptycho/core/ptycho_recon.py @@ -10,10 +10,99 @@ import traceback import time + +import requests +import json + from .databroker_api import load_metadata, save_data from .utils import use_mpi_machinefile, set_flush_early from .ptycho.utils import save_config + +class RemoteJobHandler: + def __init__(self): + self.url = "https://orion-api-staging.nsls2.bnl.gov/api/v1/compute/orion/jobs" + self.api_key = os.getenv("APIKEY") + self.headers = { + "Content-Type": "application/json", + "x-api-key": f"{self.api_key}" + } + self.params = {"expand_info": False} + + def submit_job(self, remote_path, parent_module, param): + + fname_full = os.path.join(remote_path,'ptycho_'+str(param.scan_num)+'_'+param.sign) + srun_command = "python " + "-W " + "ignore " + "-m " + parent_module + ".ptycho.recon_ptycho_gui " + fname_full + #srun_command = "python -W ignore -m nsls2ptycho.core.ptycho.recon_ptycho_gui /nsls2/users/skarakuzu1/ptycho_test/remote_orion/ptycho_320045_t1" + + HOME = os.environ["HOME"] + + overrides = { + "name": "trial", + "partition": "normal", + "tasks": f"{len(param.gpus)}", + "time_limit": 720, + "tres_per_task": "cpu=1,gres/gpu=1", + "standard_output": "trial.out", + "standard_error": "trial.err", + } + + payload = { + "script": ( + "#!/bin/bash -l\n" + "module load orion/gpu\n" + "module unload openmpi\n" + "conda activate /nsls2/conda/envs/2025-2.0-py311-tiled/\n" + "nvidia-smi\n" + "echo $HOME\n" + "echo $(pwd)\n" + "echo $(which mpicc)\n" + #f"mpirun -n 2 {srun_command}\n" + f"srun --mpi=pmix {srun_command}\n" + ), + "working_dir_path": f"{remote_path}", + "environment": [ + "PATH=/usr/bin:/bin:/usr/sbin:/sbin", + f"HOME={HOME}", + "SLURM_EXPORT_ENV=ALL", + ], + "overrides": overrides, + } + + + sys.stdout.flush() + + response = requests.post(self.url, headers=self.headers, json=payload) + resp_json = response.json() + #print("response is ", response.status_code, resp_json) + + if response.status_code == 200: + self.remote_job_id = resp_json['job_id'] + print("submitted job with id ", self.remote_job_id) + + + return response.status_code + + + def cancel_job(self): + response = requests.delete(f"{self.url}/{self.remote_job_id}", headers=self.headers, params=self.params) + + resp_json = response.json() + #print("response is ", response.status_code, resp_json) + + return response.status_code + + def get_job_status(self): + response = requests.get(f"{self.url}/{self.remote_job_id}", headers=self.headers, params=self.params) + + resp_json = response.json() + #print("response is ", response.status_code, resp_json) + + if response.status_code == 200: + state = resp_json["jobs"][0]["state"][0] + return state + + class PtychoReconRemote(QtCore.QThread): update_signal = QtCore.pyqtSignal(int, object) # (interation number, chi arrays) @@ -28,12 +117,14 @@ def __init__(self, param:Param=None, parent=None): if not os.path.isdir(self.remote_path): os.mkdir(self.remote_path) - self.msg_file = os.path.join(os.path.join(self.remote_path,'msg')) - if not os.path.isfile(self.msg_file): - with open(self.msg_file,'w') as f: - pass - self.msg = open(self.msg_file,'r') - self.msg.readlines() + #self.msg_file = os.path.join(os.path.join(self.remote_path,'msg')) + #if not os.path.isfile(self.msg_file): + # with open(self.msg_file,'w') as f: + # pass + #self.msg = open(self.msg_file,'r') + #self.msg.readlines() + + self.remote_job_handler = RemoteJobHandler() def _parse_message(self, tokens): def _parser(current, upper_limit, target_list): @@ -98,11 +189,21 @@ def clear_slurm_header(self): os.remove(slurm_header) except: pass - + + def cleanup(self): + if self.fname_full and os.path.exists(self.fname_full): + os.remove(self.fname_full) + self.fname_full = None + if os.path.exists(os.path.join(self.remote_path,'prb_live.npy')): + os.remove(os.path.join(self.remote_path,'prb_live.npy')) + if os.path.exists(os.path.join(self.remote_path,'obj_live.npy')): + os.remove(os.path.join(self.remote_path,'obj_live.npy')) + def recon_remote(self, param:Param, update_fcn=None): self.fname_full = os.path.join(self.remote_path,'ptycho_'+str(param.scan_num)+'_'+param.sign) - + self.parent_module = '.'.join(self.__module__.rsplit('.', 2)[:-1]) # get parent module name to run the correct recon worker + if param.working_directory: param.working_directory = os.path.realpath(param.working_directory)+'/' if param.prb_dir: @@ -115,26 +216,39 @@ def recon_remote(self, param:Param, update_fcn=None): param.obj_path = os.path.realpath(param.obj_path) save_config(self.fname_full,param) - self.export_slurm_header() + #self.export_slurm_header() self.return_value = 0 # Assume the recon will succeed unless later detects failure and modify it. - # try: + status = self.remote_job_handler.submit_job(self.remote_path, self.parent_module, param) + print(f"Submitted job from the gui with status code {status} and reserved job id {self.remote_job_handler.remote_job_id}") + time.sleep(1) - out = self.msg.readlines() - while not out: + while self.remote_job_handler.get_job_status() != "RUNNING": print('Waiting for remote worker on %s to take the recon task...'%param.remote_srv) time.sleep(1) + + + file_name = f"slurm-{self.remote_job_handler.remote_job_id}.out" + self.msg_file = os.path.join(self.remote_path, file_name) + + self.msg = open(self.msg_file, "r") + out = self.msg.readlines() + pos = 0 + + time.sleep(1) + while not out: + print('Waiting for remote worker on %s to start writing...'%param.remote_srv) out = self.msg.readlines() - if os.path.isfile(os.path.join(self.remote_path,'abort')): - os.remove(os.path.join(self.remote_path,'abort')) - if os.path.isfile(os.path.join(self.remote_path,'msg')): - os.remove(os.path.join(self.remote_path,'msg')) - if os.path.isfile(self.fname_full): - os.remove(self.fname_full) - raise Exception('Remote recon aborted...') + time.sleep(1) + + time.sleep(1) while True: + self.msg.seek(pos) + out = self.msg.readlines() + pos = self.msg.tell() # remember where we stopped + for line in out: print(line, end='') # because the line already ends with '\n' tokens = line.split() @@ -145,19 +259,27 @@ def recon_remote(self, param:Param, update_fcn=None): #print(result['probe_chi']) if 'aborted' in line: self.return_value = 1 # Aborted - if not os.path.isfile(self.fname_full): break + # ask Slurm about job status + status = self.remote_job_handler.get_job_status() + + # stop when job is no longer running AND there was no new data + if status != "RUNNING": + size = os.path.getsize(self.msg_file) + if size == pos: + break + time.sleep(0.1) - out = self.msg.readlines() + # except: # pass # finally: # pass def run(self): - print('Ptycho thread started') + print('Ptycho thread started helloooo***') try: self.recon_remote(self.param, self.update_signal.emit) except IndexError: @@ -172,10 +294,15 @@ def run(self): self.update_signal.emit(self.param.n_iterations+1,None) finally: - self.clear_slurm_header() print('finally?') + #self.clear_slurm_header() + status = self.remote_job_handler.cancel_job() + print(f"Cancelled job with id {self.remote_job_handler.remote_job_id} from the gui with status code {status}") + self.cleanup() def kill(self): + self.remote_job_handler.cancel_job() + print(f"Cancelled job with id {self.remote_job_handler.remote_job_id} from the gui with status code {status}") if os.path.isdir(self.remote_path): with open(os.path.join(self.remote_path,'abort'),'w') as f: pass