diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index 190fdf393c..f2a6497a51 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -4,7 +4,7 @@ _validate_multiplier, _check_length_objective_funcs_multipliers, ) - +from typing import Callable import numpy as np from dask.distributed import Client, Future @@ -100,6 +100,9 @@ def _setter_broadcast(objfct, key, value): """ Broadcast a value to all workers. """ + if isinstance(value, Callable): + value = value(objfct) + if hasattr(objfct, key): setattr(objfct, key, value) @@ -565,3 +568,24 @@ def residuals(self, m, f=None): residuals += client.gather(future_residuals) return residuals + + def broadcast_updates(self, updates: dict): + """ + Set the attributes of the objective functions and simulations + """ + stores = [] + client = self.client + + for fun, (key, value) in updates.items(): + worker = client.who_has(fun)[fun.key] + stores.append( + client.submit( + _setter_broadcast, + fun, + key, + value, + workers=worker, + ) + ) + + self.client.gather(stores) # blocking call to ensure all models were stored diff --git a/simpeg/directives/_vector_models.py b/simpeg/directives/_vector_models.py index 8d131e7f36..dba2f319e5 100644 --- a/simpeg/directives/_vector_models.py +++ b/simpeg/directives/_vector_models.py @@ -1,5 +1,5 @@ import numpy as np - +from dask.distributed import Future from . import ( BaseSaveGeoH5, InversionDirective, @@ -56,8 +56,11 @@ def update(self): self.invProb.model = m self.opt.xc = self.invProb.model - for misfit in self.dmisfit.objfcts: - misfit.simulation.model = m + if isinstance(self.dmisfit.objfcts[0], Future): + self.dmisfit.model = m + else: + for misfit in self.dmisfit.objfcts: + misfit.simulation.model = m def _reproject(self, m): """ @@ -71,6 +74,15 @@ def _reproject(self, m): return m +def update_map(misfit): + if isinstance(misfit.simulation, MetaSimulation): + misfit.simulation.simulations[0].chiMap = ( + SphericalSystem() * misfit.simulation.simulations[0].chiMap + ) + else: + misfit.simulation.chiMap = SphericalSystem() * misfit.simulation.chiMap + + class VectorInversion(InversionDirective): """ Control a vector inversion from Cartesian to spherical coordinates. @@ -90,7 +102,7 @@ def __init__( self, simulations: list, regularizations: ComboObjectiveFunction, **kwargs ): self.reference_angles = (False, False, False) - self.simulations = simulations + self.misfits = simulations self.regularizations = regularizations set_kwargs(self, **kwargs) @@ -98,11 +110,7 @@ def __init__( @property def target(self): if getattr(self, "_target", None) is None: - nD = 0 - for survey in self.survey: - nD += survey.nD - - self._target = nD * self.chifact_target + self._target = np.hstack(self.invProb.dpred).shape[0] * self.chifact_target return self._target @@ -116,10 +124,6 @@ def initialize(self): self.reference_model = reg.reference_model - for dmisfit in self.dmisfit.objfcts: - if getattr(dmisfit.simulation, "coordinate_system", None) is not None: - dmisfit.simulation.coordinate_system = self.mode - def endIter(self): model = self.invProb.model.copy() @@ -217,20 +221,12 @@ def endIter(self): self.opt.upper[indices[nC:]] = np.inf updates = {} - for simulation in self.simulations: - if isinstance(simulation, MetaSimulation): - - if hasattr(self.dmisfit, "client"): - updates[simulation] = ( - "chiMap", - SphericalSystem() * simulation.simulations[0].chiMap, - ) - else: - simulation.simulations[0].chiMap = ( - SphericalSystem() * simulation.simulations[0].chiMap - ) + for misfit in self.misfits: + + if isinstance(misfit, Future): + updates[misfit] = ("", update_map) else: - simulation.chiMap = SphericalSystem() * simulation.chiMap + update_map(misfit) if hasattr(self.dmisfit, "client"): self.dmisfit.broadcast_updates(updates)