From abec17fbe5f75876b31fa4dbc77ee216d73a79b4 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 21 Oct 2025 10:51:59 -0700 Subject: [PATCH 1/3] Rename and handle workspace already opened --- simpeg/directives/_save_geoh5.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/simpeg/directives/_save_geoh5.py b/simpeg/directives/_save_geoh5.py index 9ad798ab67..73614411db 100644 --- a/simpeg/directives/_save_geoh5.py +++ b/simpeg/directives/_save_geoh5.py @@ -52,20 +52,20 @@ def __init__( ) def initialize(self): - if self.open_geoh5: - self._geoh5.open(mode="r+") + if self.open_geoh5 and not getattr(self._workspace, "_geoh5", None): + self._workspace.open(mode="r+") self.write(0) if self.close_geoh5: - self._geoh5.close() + self._workspace.close() def endIter(self): - if self.open_geoh5: - self._geoh5.open(mode="r+") + if self.open_geoh5 and not getattr(self._workspace, "_geoh5", None): + self._workspace.open(mode="r+") self.write(self.opt.iter) if self.close_geoh5: - self._geoh5.close() + self._workspace.close() def get_names( self, component: str, channel: str, iteration: int @@ -127,7 +127,7 @@ def h5_object(self, entity: ObjectBase): ) self._h5_object = entity.uid - self._geoh5 = entity.workspace + self._workspace = entity.workspace if getattr(entity, "n_cells", None) is not None: self.association = "CELL" @@ -263,7 +263,7 @@ def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noq prop = self.apply_transformations(prop) # Save results - with fetch_active_workspace(self._geoh5, mode="r+") as w_s: + with fetch_active_workspace(self._workspace, mode="r+") as w_s: h5_object = w_s.get_entity(self.h5_object)[0] for cc, component in enumerate(self.components): if component not in self.data_type: @@ -386,7 +386,7 @@ def joint_index(self, value: list[int] | None): class SaveLogFilesGeoH5(BaseSaveGeoH5): def write(self, iteration: int, **_): - dirpath = Path(self._geoh5.h5file).parent + dirpath = Path(self._workspace.h5file).parent filepath = dirpath / "SimPEG.out" if iteration == 0: @@ -412,9 +412,9 @@ def save_log(self): """ Save iteration metrics to comments. """ - dirpath = Path(self._geoh5.h5file).parent + dirpath = Path(self._workspace.h5file).parent - with fetch_active_workspace(self._geoh5, mode="r+") as w_s: + with fetch_active_workspace(self._workspace, mode="r+") as w_s: h5_object = w_s.get_entity(self.h5_object)[0] for file in ["SimPEG.out", "SimPEG.log", "ChiFactors.log"]: @@ -452,7 +452,7 @@ def write(self, iteration: int, **_): """ Save the model to the geoh5 file """ - with fetch_active_workspace(self._geoh5, mode="r+") as w_s: + with fetch_active_workspace(self._workspace, mode="r+") as w_s: h5_object = w_s.get_entity(self.h5_object)[0] for component in self.components: @@ -560,7 +560,7 @@ def write(self, iteration: int, values: list[np.ndarray] | None = None): petro_model = self.get_values(values) petro_model = self.apply_transformations(petro_model).flatten() channel_name, _ = self.get_names("petrophysics", "", iteration) - with fetch_active_workspace(self._geoh5, mode="r+") as w_s: + with fetch_active_workspace(self._workspace, mode="r+") as w_s: h5_object = w_s.get_entity(self.h5_object)[0] data = h5_object.add_data( { From 51615d0456677dabc0a07f6e8b1920e218bf6f26 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 21 Oct 2025 14:51:01 -0700 Subject: [PATCH 2/3] Fix scattering for FEM simulations --- .../electromagnetics/frequency_domain/simulation.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 023b1ec605..b98f3755d1 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -282,7 +282,7 @@ def compute_J(self, m, f=None): for block_derivs_chunks, addresses_chunks in zip( blocks_receiver_derivs, blocks, strict=True ): - Jmatrix = parallel_block_compute( + parallel_block_compute( simulation, m, Jmatrix, @@ -301,7 +301,10 @@ def compute_J(self, m, f=None): gc.collect() if self.store_sensitivities == "disk": del Jmatrix - Jmatrix = array.from_zarr(self.sensitivity_path) + return array.from_zarr(self.sensitivity_path) + + if client: + return client.gather(Jmatrix) return Jmatrix @@ -367,11 +370,9 @@ def parallel_block_compute( count += n_cols if client: - client.gather(block_delayed) + return client.gather(block_delayed) else: - compute(block_delayed) - - return Jmatrix + return compute(block_delayed) Sim.compute_J = compute_J From 284bdad5c71b4c60b543cc8df220dc5be1ac8732 Mon Sep 17 00:00:00 2001 From: domfournier Date: Tue, 21 Oct 2025 15:23:04 -0700 Subject: [PATCH 3/3] Small fix check for DC --- simpeg/dask/electromagnetics/static/resistivity/simulation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/simpeg/dask/electromagnetics/static/resistivity/simulation.py b/simpeg/dask/electromagnetics/static/resistivity/simulation.py index e3f6e5e29e..bae0a8606c 100644 --- a/simpeg/dask/electromagnetics/static/resistivity/simulation.py +++ b/simpeg/dask/electromagnetics/static/resistivity/simulation.py @@ -181,6 +181,10 @@ def getSourceTerm(self): ) blocks = [] for ind in indices: + + if len(ind) == 0: + continue + blocks.append( client.submit(source_eval, sim, future_list, ind, workers=worker) )