Refactor/user function handling modules + Manager can run additional worker on thread #1216
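This PR makes two related changes, shown in the diffs below. In `libensemble/comms/comms.py`, the local-comm classes learn to run a plain user function: `QCommLocal.__init__` drops `nworkers` and records a `ufunc` flag, and `_qcomm_main` invokes its target either with an attached comm (the existing behavior) or as a bare `main(*fargs)` call. In the manager module, a small `_Worker` proxy class bundles each worker's state row and comm object behind one interface, and a new `manager_runs_additional_worker` option in `libE_specs` lets the manager start one extra worker on a thread in its own process.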
**`libensemble/comms/comms.py`**
```diff
@@ -146,10 +146,11 @@ def mail_flag(self):
 class QCommLocal(Comm):
-    def __init__(self, main, nworkers, *args, **kwargs):
+    def __init__(self, main, *args, **kwargs):
         self._result = None
         self._exception = None
         self._done = False
+        self._ufunc = kwargs.get("ufunc", False)

     def _is_result_msg(self, msg):
         """Return true if message indicates final result (and set result/except)."""
```
```diff
@@ -208,10 +209,13 @@ def result(self, timeout=None):
         return self._result

     @staticmethod
-    def _qcomm_main(comm, main, *args, **kwargs):
+    def _qcomm_main(comm, main, *fargs, **kwargs):
         """Main routine -- handles return values and exceptions."""
         try:
-            _result = main(comm, *args, **kwargs)
+            if not kwargs.get("ufunc"):
+                _result = main(comm, *fargs, **kwargs)
+            else:
+                _result = main(*fargs)
             comm.send(CommResult(_result))
         except Exception as e:
             comm.send(CommResultErr(str(e), format_exc()))
```
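The heart of the refactor is the dispatch above: when `ufunc` is set, the wrapped callable is treated as a plain user function and invoked as `main(*fargs)`, with no comm object and no kwargs. The rename from `*args` to `*fargs` also separates the user function's arguments from the constructor's own. A minimal standalone sketch of the same dispatch (hypothetical names, not the libensemble API):

```python
def _dispatch(comm, main, *fargs, **kwargs):
    """Call comm-aware entry points with the comm; plain user functions without."""
    if not kwargs.get("ufunc"):
        return main(comm, *fargs, **kwargs)  # existing comm-aware convention
    return main(*fargs)  # bare user function: positional args only


def comm_aware(comm, x):
    return x + 1


def plain(x):
    return x * 2


assert _dispatch(None, comm_aware, 1) == 2
assert _dispatch(None, plain, 3, ufunc=True) == 6
```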
```diff
@@ -233,12 +237,12 @@ def __exit__(self, etype, value, traceback):
 class QCommThread(QCommLocal):
     """Launch a user function in a thread with an attached QComm."""

-    def __init__(self, main, nworkers, *args, **kwargs):
+    def __init__(self, main, nworkers, *fargs, **kwargs):
         self.inbox = thread_queue.Queue()
         self.outbox = thread_queue.Queue()
-        super().__init__(self, main, nworkers, *args, **kwargs)
+        super().__init__(self, main, *fargs, **kwargs)
         comm = QComm(self.inbox, self.outbox, nworkers)
-        self.handle = Thread(target=QCommThread._qcomm_main, args=(comm, main) + args, kwargs=kwargs)
+        self.handle = Thread(target=QCommThread._qcomm_main, args=(comm, main) + fargs, kwargs=kwargs)

     def terminate(self, timeout=None):
         """Terminate the thread.
```
```diff
@@ -260,7 +264,7 @@ class QCommProcess(QCommLocal):
     def __init__(self, main, nworkers, *args, **kwargs):
         self.inbox = Queue()
         self.outbox = Queue()
-        super().__init__(self, main, nworkers, *args, **kwargs)
+        super().__init__(self, main, *args, **kwargs)
         comm = QComm(self.inbox, self.outbox, nworkers)
         self.handle = Process(target=QCommProcess._qcomm_main, args=(comm, main) + args, kwargs=kwargs)
```
**`libensemble/manager.py`**
```diff
@@ -18,7 +18,8 @@
 import numpy.typing as npt
 from numpy.lib.recfunctions import repack_fields

-from libensemble.comms.comms import CommFinishedException
+from libensemble.comms.comms import CommFinishedException, QCommThread
 from libensemble.executors.executor import Executor
 from libensemble.message_numbers import (
     EVAL_GEN_TAG,
     EVAL_SIM_TAG,
```
```diff
@@ -37,7 +38,7 @@
 from libensemble.utils.misc import extract_H_ranges
 from libensemble.utils.output_directory import EnsembleDirectory
 from libensemble.utils.timer import Timer
-from libensemble.worker import WorkerErrMsg
+from libensemble.worker import WorkerErrMsg, worker_main

 logger = logging.getLogger(__name__)
 # For debug messages - uncomment
```
```diff
@@ -154,6 +155,48 @@ def filter_nans(array: npt.NDArray) -> npt.NDArray:
     """


+class _Worker:
+    """Wrapper class for Worker array and worker comms"""
+
+    def __init__(self, W: npt.NDArray, wid: int, wcomms: list = []):
+        self.__dict__["_W"] = W
+        if 0 in W["worker_id"]:  # Contains "0" for manager. Otherwise first entry is Worker 1
+            self.__dict__["_wididx"] = wid
+        else:
+            self.__dict__["_wididx"] = wid - 1
+        self.__dict__["_wcomms"] = wcomms
+
+    def __setattr__(self, field, value):
+        self._W[self._wididx][field] = value
+
+    def __getattr__(self, field):
+        return self._W[self._wididx][field]
+
+    def update_state_on_alloc(self, Work: dict):
+        self.active = Work["tag"]
+        if "persistent" in Work["libE_info"]:
+            self.persis_state = Work["tag"]
+            if Work["libE_info"].get("active_recv", False):
+                self.active_recv = Work["tag"]
+        else:
+            assert "active_recv" not in Work["libE_info"], "active_recv worker must also be persistent"
+
+    def update_persistent_state(self):
+        self.persis_state = 0
+        if self.active_recv:
+            self.active = 0
+            self.active_recv = 0
+
+    def send(self, tag, data):
+        self._wcomms[self._wididx].send(tag, data)
+
+    def mail_flag(self):
+        return self._wcomms[self._wididx].mail_flag()
+
+    def recv(self):
+        return self._wcomms[self._wididx].recv()
+
+
 class Manager:
     """Manager class for libensemble."""
```
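`_Worker` is a thin proxy: `__setattr__`/`__getattr__` forward field access to one row of the structured `W` array, and `send`/`recv`/`mail_flag` forward to that worker's comm, replacing the `self.W[w - 1][...]` and `self.wcomms[w - 1]` indexing previously repeated at every call site. Internal state goes through `self.__dict__[...]` precisely because a plain `self._W = W` would route through the overridden `__setattr__` and try to write into the array. The `worker_id` check handles the new numbering: with an in-manager worker, ids start at 0 and index the array directly; otherwise the first row is Worker 1, hence `wid - 1`. A stripped-down sketch of the same pattern (simplified fields, not the real manager state):

```python
import numpy as np


class RowProxy:
    def __init__(self, arr, idx):
        self.__dict__["_arr"] = arr  # bypass the overridden __setattr__ below
        self.__dict__["_idx"] = idx

    def __setattr__(self, field, value):
        self._arr[self._idx][field] = value  # write through to the structured row

    def __getattr__(self, field):
        return self._arr[self._idx][field]


W = np.zeros(3, dtype=[("worker_id", int), ("active", int)])
W["worker_id"] = np.arange(1, 4)

w2 = RowProxy(W, 1)   # proxy for worker 2 (row index 1)
w2.active = 5
assert W[1]["active"] == 5
assert w2.worker_id == 2
```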
```diff
@@ -209,6 +252,30 @@ def __init__(
             (1, "stop_val", self.term_test_stop_val),
         ]

+        if self.libE_specs.get("manager_runs_additional_worker", False):
+
+            dtypes = {
+                EVAL_SIM_TAG: repack_fields(hist.H[sim_specs["in"]]).dtype,
+                EVAL_GEN_TAG: repack_fields(hist.H[gen_specs["in"]]).dtype,
+            }
+
+            self.W = np.zeros(len(self.wcomms) + 1, dtype=Manager.worker_dtype)
+            self.W["worker_id"] = np.arange(len(self.wcomms) + 1)
+            local_worker_comm = QCommThread(
+                worker_main,
+                len(self.wcomms),
+                sim_specs,
+                gen_specs,
+                libE_specs,
+                0,
+                False,
+                Resources.resources,
+                Executor.executor,
+            )
+            self.wcomms = [local_worker_comm] + self.wcomms
+            local_worker_comm.run()
+            local_worker_comm.send(0, dtypes)
+
         temp_EnsembleDirectory = EnsembleDirectory(libE_specs=libE_specs)
         self.resources = Resources.resources
         self.scheduler_opts = self.libE_specs.get("scheduler_opts", {})
```
```diff
@@ -266,7 +333,8 @@ def term_test(self, logged: bool = True) -> Union[bool, int]:
     def _kill_workers(self) -> None:
         """Kills the workers"""
         for w in self.W["worker_id"]:
-            self.wcomms[w - 1].send(STOP_TAG, MAN_SIGNAL_FINISH)
+            worker = _Worker(self.W, w, self.wcomms)
+            worker.send(STOP_TAG, MAN_SIGNAL_FINISH)

     # --- Checkpointing logic
```
```diff
@@ -320,15 +388,16 @@ def _init_every_k_save(self, complete=False) -> None:

     def _check_work_order(self, Work: dict, w: int, force: bool = False) -> None:
         """Checks validity of an allocation function order"""
-        assert w != 0, "Can't send to worker 0; this is the manager."
-        if self.W[w - 1]["active_recv"]:
+        # assert w != 0, "Can't send to worker 0; this is the manager."
+        worker = _Worker(self.W, w, self.wcomms)
+        if worker.active_recv:
             assert "active_recv" in Work["libE_info"], (
                 "Messages to a worker in active_recv mode should have active_recv"
                 f"set to True in libE_info. Work['libE_info'] is {Work['libE_info']}"
             )
         else:
             if not force:
-                assert self.W[w - 1]["active"] == 0, (
+                assert worker.active == 0, (
                     "Allocation function requested work be sent to worker %d, an already active worker." % w
                 )
         work_rows = Work["libE_info"]["H_rows"]
```
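Note that the guard `assert w != 0` is commented out rather than removed: with an in-manager worker, id 0 becomes a legitimate send target instead of being reserved for the manager.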
```diff
@@ -370,13 +439,15 @@ def _send_work_order(self, Work: dict, w: int) -> None:
         """Sends an allocation function order to a worker"""
         logger.debug(f"Manager sending work unit to worker {w}")

+        worker = _Worker(self.W, w, self.wcomms)
+
         if self.resources:
             self._set_resources(Work, w)

-        self.wcomms[w - 1].send(Work["tag"], Work)
+        worker.send(Work["tag"], Work)

         if Work["tag"] == EVAL_GEN_TAG:
-            self.W[w - 1]["gen_started_time"] = time.time()
+            worker.gen_started_time = time.time()

         work_rows = Work["libE_info"]["H_rows"]
         work_name = calc_type_strings[Work["tag"]]
```
```diff
@@ -386,18 +457,14 @@ def _send_work_order(self, Work: dict, w: int) -> None:
             H_to_be_sent = np.empty(len(work_rows), dtype=new_dtype)
             for i, row in enumerate(work_rows):
                 H_to_be_sent[i] = repack_fields(self.hist.H[Work["H_fields"]][row])
-            self.wcomms[w - 1].send(0, H_to_be_sent)
+            worker.send(0, H_to_be_sent)

     def _update_state_on_alloc(self, Work: dict, w: int):
         """Updates a workers' active/idle status following an allocation order"""
-        self.W[w - 1]["active"] = Work["tag"]
-        if "libE_info" in Work:
-            if "persistent" in Work["libE_info"]:
-                self.W[w - 1]["persis_state"] = Work["tag"]
-                if Work["libE_info"].get("active_recv", False):
-                    self.W[w - 1]["active_recv"] = Work["tag"]
-            else:
-                assert "active_recv" not in Work["libE_info"], "active_recv worker must also be persistent"
+        worker = _Worker(self.W, w, self.wcomms)
+        worker.update_state_on_alloc(Work)

         work_rows = Work["libE_info"]["H_rows"]
         if Work["tag"] == EVAL_SIM_TAG:
```
```diff
@@ -432,7 +499,8 @@ def _receive_from_workers(self, persis_info: dict) -> dict:
         while new_stuff:
             new_stuff = False
             for w in self.W["worker_id"]:
-                if self.wcomms[w - 1].mail_flag():
+                worker = _Worker(self.W, w, self.wcomms)
+                if worker.mail_flag():
                     new_stuff = True
                     self._handle_msg_from_worker(persis_info, w)
```
```diff
@@ -445,38 +513,37 @@ def _update_state_on_worker_msg(self, persis_info: dict, D_recv: dict, w: int) -> None:
         calc_status = D_recv["calc_status"]
         Manager._check_received_calc(D_recv)

+        worker = _Worker(self.W, w, self.wcomms)
+
         keep_state = D_recv["libE_info"].get("keep_state", False)
-        if w not in self.persis_pending and not self.W[w - 1]["active_recv"] and not keep_state:
-            self.W[w - 1]["active"] = 0
+        if w not in self.persis_pending and not worker.active_recv and not keep_state:
+            worker.active = 0

         if calc_status in [FINISHED_PERSISTENT_SIM_TAG, FINISHED_PERSISTENT_GEN_TAG]:
             final_data = D_recv.get("calc_out", None)
             if isinstance(final_data, np.ndarray):
                 if calc_status is FINISHED_PERSISTENT_GEN_TAG and self.libE_specs.get("use_persis_return_gen", False):
-                    self.hist.update_history_x_in(w, final_data, self.W[w - 1]["gen_started_time"])
+                    self.hist.update_history_x_in(w, final_data, worker.gen_started_time)
                 elif calc_status is FINISHED_PERSISTENT_SIM_TAG and self.libE_specs.get("use_persis_return_sim", False):
                     self.hist.update_history_f(D_recv, self.kill_canceled_sims)
                 else:
                     logger.info(_PERSIS_RETURN_WARNING)
-            self.W[w - 1]["persis_state"] = 0
-            if self.W[w - 1]["active_recv"]:
-                self.W[w - 1]["active"] = 0
-                self.W[w - 1]["active_recv"] = 0
+            worker.update_persistent_state()
             if w in self.persis_pending:
                 self.persis_pending.remove(w)
-                self.W[w - 1]["active"] = 0
+                worker.active = 0
             self._freeup_resources(w)
         else:
             if calc_type == EVAL_SIM_TAG:
                 self.hist.update_history_f(D_recv, self.kill_canceled_sims)
             if calc_type == EVAL_GEN_TAG:
-                self.hist.update_history_x_in(w, D_recv["calc_out"], self.W[w - 1]["gen_started_time"])
+                self.hist.update_history_x_in(w, D_recv["calc_out"], worker.gen_started_time)
                 assert (
-                    len(D_recv["calc_out"]) or np.any(self.W["active"]) or self.W[w - 1]["persis_state"]
+                    len(D_recv["calc_out"]) or np.any(self.W["active"]) or worker.persis_state
                 ), "Gen must return work when is is the only thing active and not persistent."
             if "libE_info" in D_recv and "persistent" in D_recv["libE_info"]:
                 # Now a waiting, persistent worker
-                self.W[w - 1]["persis_state"] = calc_type
+                worker.persis_state = calc_type
             else:
                 self._freeup_resources(w)
```
```diff
@@ -485,14 +552,15 @@ def _update_state_on_worker_msg(self, persis_info: dict, D_recv: dict, w: int) -> None:

     def _handle_msg_from_worker(self, persis_info: dict, w: int) -> None:
         """Handles a message from worker w"""
+        worker = _Worker(self.W, w, self.wcomms)
         try:
-            msg = self.wcomms[w - 1].recv()
+            msg = worker.recv()
             tag, D_recv = msg
         except CommFinishedException:
             logger.debug(f"Finalizing message from Worker {w}")
             return
         if isinstance(D_recv, WorkerErrMsg):
-            self.W[w - 1]["active"] = 0
+            worker.active = 0
             logger.debug(f"Manager received exception from worker {w}")
             if not self.WorkerExc:
                 self.WorkerExc = True
```
```diff
@@ -525,7 +593,8 @@ def _kill_cancelled_sims(self) -> None:
         kill_ids = self.hist.H["sim_id"][kill_sim_rows]
         kill_on_workers = self.hist.H["sim_worker"][kill_sim_rows]
         for w in kill_on_workers:
-            self.wcomms[w - 1].send(STOP_TAG, MAN_SIGNAL_KILL)
+            worker = _Worker(self.W, w, self.wcomms)
+            worker.send(STOP_TAG, MAN_SIGNAL_KILL)
         self.hist.H["kill_sent"][kill_ids] = True

     # --- Handle termination
```
```diff
@@ -542,6 +611,7 @@ def _final_receive_and_kill(self, persis_info: dict) -> (dict, int, int):
         # Send a handshake signal to each persistent worker.
         if any(self.W["persis_state"]):
             for w in self.W["worker_id"][self.W["persis_state"] > 0]:
+                worker = _Worker(self.W, w, self.wcomms)
                 logger.debug(f"Manager sending PERSIS_STOP to worker {w}")
                 if self.libE_specs.get("final_gen_send", False):
                     rows_to_send = np.where(self.hist.H["sim_ended"] & ~self.hist.H["gen_informed"])[0]
```
```diff
@@ -555,10 +625,10 @@ def _final_receive_and_kill(self, persis_info: dict) -> (dict, int, int):
                     self._send_work_order(work, w)
                     self.hist.update_history_to_gen(rows_to_send)
                 else:
-                    self.wcomms[w - 1].send(PERSIS_STOP, MAN_SIGNAL_KILL)
-                if not self.W[w - 1]["active"]:
+                    worker.send(PERSIS_STOP, MAN_SIGNAL_KILL)
+                if not worker.active:
                     # Re-activate if necessary
-                    self.W[w - 1]["active"] = self.W[w - 1]["persis_state"]
+                    worker.active = worker.persis_state
                 self.persis_pending.append(w)

         exit_flag = 0
```
```diff
@@ -601,6 +671,7 @@ def _get_alloc_libE_info(self) -> dict:
             "use_resource_sets": self.use_resource_sets,
             "gen_num_procs": self.gen_num_procs,
             "gen_num_gpus": self.gen_num_gpus,
+            "manager_additional_worker": self.libE_specs.get("manager_runs_additional_worker", False),
         }

     def _alloc_work(self, H: npt.NDArray, persis_info: dict) -> dict:
```
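Finally, the new mode is exposed to allocation functions through the alloc `libE_info` dict as `manager_additional_worker`, so an allocator can account for the extra worker and the shifted id range. A hypothetical helper showing how an alloc function might consume it (only the flag name comes from this diff; the helper and its use are illustrative):

```python
import numpy as np


def idle_workers(W: np.ndarray, libE_info: dict) -> list:
    """Idle worker ids; id 0 is the manager-hosted worker and only
    present when manager_additional_worker is set."""
    ids = [int(i) for i in W["worker_id"][W["active"] == 0]]
    if not libE_info.get("manager_additional_worker", False):
        assert 0 not in ids, "id 0 is reserved for the manager"
    return ids


W = np.zeros(3, dtype=[("worker_id", int), ("active", int)])
W["worker_id"] = np.arange(3)      # additional-worker mode: ids 0..2
W["active"][1] = 2                 # worker 1 busy with a gen (tag 2)
print(idle_workers(W, {"manager_additional_worker": True}))  # [0, 2]
```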