Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions simpeg/dask/electromagnetics/frequency_domain/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def compute_J(self, m, f=None):
for block_derivs_chunks, addresses_chunks in zip(
blocks_receiver_derivs, blocks, strict=True
):
Jmatrix = parallel_block_compute(
parallel_block_compute(
simulation,
m,
Jmatrix,
Expand All @@ -301,7 +301,10 @@ def compute_J(self, m, f=None):
gc.collect()
if self.store_sensitivities == "disk":
del Jmatrix
Jmatrix = array.from_zarr(self.sensitivity_path)
return array.from_zarr(self.sensitivity_path)

if client:
return client.gather(Jmatrix)

return Jmatrix

Expand Down Expand Up @@ -367,11 +370,9 @@ def parallel_block_compute(
count += n_cols

if client:
client.gather(block_delayed)
return client.gather(block_delayed)
else:
compute(block_delayed)

return Jmatrix
return compute(block_delayed)


Sim.compute_J = compute_J
Expand Down
4 changes: 4 additions & 0 deletions simpeg/dask/electromagnetics/static/resistivity/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ def getSourceTerm(self):
)
blocks = []
for ind in indices:

if len(ind) == 0:
continue

blocks.append(
client.submit(source_eval, sim, future_list, ind, workers=worker)
)
Expand Down
26 changes: 13 additions & 13 deletions simpeg/directives/_save_geoh5.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,20 @@ def __init__(
)

def initialize(self):
if self.open_geoh5:
self._geoh5.open(mode="r+")
if self.open_geoh5 and not getattr(self._workspace, "_geoh5", None):
self._workspace.open(mode="r+")

self.write(0)

if self.close_geoh5:
self._geoh5.close()
self._workspace.close()

def endIter(self):
if self.open_geoh5:
self._geoh5.open(mode="r+")
if self.open_geoh5 and not getattr(self._workspace, "_geoh5", None):
self._workspace.open(mode="r+")
self.write(self.opt.iter)
if self.close_geoh5:
self._geoh5.close()
self._workspace.close()

def get_names(
self, component: str, channel: str, iteration: int
Expand Down Expand Up @@ -127,7 +127,7 @@ def h5_object(self, entity: ObjectBase):
)

self._h5_object = entity.uid
self._geoh5 = entity.workspace
self._workspace = entity.workspace

if getattr(entity, "n_cells", None) is not None:
self.association = "CELL"
Expand Down Expand Up @@ -263,7 +263,7 @@ def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noq
prop = self.apply_transformations(prop)

# Save results
with fetch_active_workspace(self._geoh5, mode="r+") as w_s:
with fetch_active_workspace(self._workspace, mode="r+") as w_s:
h5_object = w_s.get_entity(self.h5_object)[0]
for cc, component in enumerate(self.components):
if component not in self.data_type:
Expand Down Expand Up @@ -386,7 +386,7 @@ def joint_index(self, value: list[int] | None):
class SaveLogFilesGeoH5(BaseSaveGeoH5):

def write(self, iteration: int, **_):
dirpath = Path(self._geoh5.h5file).parent
dirpath = Path(self._workspace.h5file).parent
filepath = dirpath / "SimPEG.out"

if iteration == 0:
Expand All @@ -412,9 +412,9 @@ def save_log(self):
"""
Save iteration metrics to comments.
"""
dirpath = Path(self._geoh5.h5file).parent
dirpath = Path(self._workspace.h5file).parent

with fetch_active_workspace(self._geoh5, mode="r+") as w_s:
with fetch_active_workspace(self._workspace, mode="r+") as w_s:
h5_object = w_s.get_entity(self.h5_object)[0]

for file in ["SimPEG.out", "SimPEG.log", "ChiFactors.log"]:
Expand Down Expand Up @@ -452,7 +452,7 @@ def write(self, iteration: int, **_):
"""
Save the model to the geoh5 file
"""
with fetch_active_workspace(self._geoh5, mode="r+") as w_s:
with fetch_active_workspace(self._workspace, mode="r+") as w_s:
h5_object = w_s.get_entity(self.h5_object)[0]

for component in self.components:
Expand Down Expand Up @@ -560,7 +560,7 @@ def write(self, iteration: int, values: list[np.ndarray] | None = None):
petro_model = self.get_values(values)
petro_model = self.apply_transformations(petro_model).flatten()
channel_name, _ = self.get_names("petrophysics", "", iteration)
with fetch_active_workspace(self._geoh5, mode="r+") as w_s:
with fetch_active_workspace(self._workspace, mode="r+") as w_s:
h5_object = w_s.get_entity(self.h5_object)[0]
data = h5_object.add_data(
{
Expand Down