diff --git a/gusto/configuration.py b/gusto/configuration.py index 5aebf5a33..b0de48cb5 100644 --- a/gusto/configuration.py +++ b/gusto/configuration.py @@ -61,7 +61,7 @@ def __setattr__(self, name, value): When attributes are provided as floats or integers, these are converted to Firedrake :class:`Constant` objects, other than a handful of special - integers (dumpfreq, pddumpfreq, chkptfreq and log_level). + integers (dumpfreq, pddumpfreq and chkptfreq). Args: name: the attribute's name. diff --git a/gusto/equations.py b/gusto/equations.py index c07aed867..52e8db012 100644 --- a/gusto/equations.py +++ b/gusto/equations.py @@ -172,7 +172,8 @@ def __init__(self, domain, function_space, field_name, Vu=None, equation's prognostic is defined on. field_name (str): name of the prognostic field. Vu (:class:`FunctionSpace`, optional): the function space for the - velocity field. If this is Defaults to None. + velocity field. Defaults to None, in which case the + HDiv space is used. diffusion_parameters (:class:`DiffusionParameters`, optional): parameters describing the diffusion to be applied. """ diff --git a/gusto/io.py b/gusto/io.py index 40352085f..7dd841300 100644 --- a/gusto/io.py +++ b/gusto/io.py @@ -140,7 +140,8 @@ def dump(self, field_creator, t): class DiagnosticsOutput(object): """Object for outputting global diagnostic data.""" - def __init__(self, filename, diagnostics, description, comm, create=True): + def __init__(self, filename, diagnostics, description, comm, + ensemble_comm=None, create=True): """ Args: filename (str): name of file to output to. 
@@ -155,9 +156,14 @@ def __init__(self, filename, diagnostics, description, comm, create=True): self.filename = filename self.diagnostics = diagnostics self.comm = comm + self.ensemble_comm = ensemble_comm + if ensemble_comm is not None: + self.write_to_file = ensemble_comm.rank == 0 and comm.rank == 0 + else: + self.write_to_file = comm.rank == 0 if not create: return - if self.comm.rank == 0: + if self.write_to_file: with Dataset(filename, "w") as dataset: dataset.description = "Diagnostics data for simulation {desc}".format(desc=description) dataset.history = "Created {t}".format(t=time.ctime()) @@ -185,7 +191,7 @@ def dump(self, state_fields, t): diagnostic = getattr(self.diagnostics, dname) diagnostics.append((fname, dname, diagnostic(field))) - if self.comm.rank == 0: + if self.write_to_file: with Dataset(self.filename, "a") as dataset: idx = dataset.dimensions["time"].size dataset.variables["time"][idx:idx + 1] = t @@ -354,7 +360,7 @@ def setup_diagnostics(self, state_fields): if fname in state_fields.to_dump: self.diagnostics.register(fname) - def setup_dump(self, state_fields, t, pick_up=False): + def setup_dump(self, state_fields, t, ensemble, pick_up=False): """ Sets up a series of things used for outputting. 
@@ -377,6 +383,18 @@ def setup_dump(self, state_fields, t, pick_up=False): raise_parallel_exception = 0 error = None + if ensemble is not None: + ens_comm = ensemble.ensemble_comm + comm = ensemble.comm + create_dir = ens_comm.rank + comm.rank == 0 + create_files = ens_comm.rank == 0 + else: + ens_comm = None + comm = self.mesh.comm + create_dir = comm.Get_rank() == 0 + create_files = True + self.ensemble = ensemble + if any([self.output.dump_vtus, self.output.dump_nc, self.output.dumplist_latlon, self.output.dump_diagnostics, self.output.point_data, self.output.checkpoint and not pick_up]): @@ -385,7 +403,7 @@ def setup_dump(self, state_fields, t, pick_up=False): running_tests = '--running-tests' in sys.argv or "pytest" in self.output.dirname # Raising exceptions needs to be done in parallel - if self.mesh.comm.Get_rank() == 0: + if create_dir: # Create results directory if it doesn't already exist if not path.exists(self.dumpdir): try: @@ -400,7 +418,10 @@ def setup_dump(self, state_fields, t, pick_up=False): # Gather errors from each rank and raise appropriate error everywhere # This allreduce also ensures that all ranks are in sync wrt the results dir - raise_exception = self.mesh.comm.allreduce(raise_parallel_exception, op=MPI.MAX) + raise_exception = comm.allreduce(raise_parallel_exception, op=MPI.MAX) + if ensemble is not None: + raise_exception = ens_comm.allreduce(raise_exception, op=MPI.MAX) + if raise_exception == 1: raise GustoIOError(f'results directory {self.dumpdir} already exists') elif raise_exception == 2: @@ -421,12 +442,12 @@ def setup_dump(self, state_fields, t, pick_up=False): if pick_up: next(self.dumpcount) - if self.output.dump_vtus: + if self.output.dump_vtus and create_files: # setup pvd output file outfile_pvd = path.join(self.dumpdir, "field_output.pvd") self.pvd_dumpfile = VTKFile( outfile_pvd, project_output=self.output.project_fields, - comm=self.mesh.comm) + comm=comm) if self.output.dump_nc: self.nc_filename = 
path.join(self.dumpdir, "field_output.nc") @@ -453,10 +474,11 @@ def setup_dump(self, state_fields, t, pick_up=False): # setup the latlon coordinate mesh and make output file if len(self.output.dumplist_latlon) > 0: mesh_ll = get_flat_latlon_mesh(self.mesh) - outfile_ll = path.join(self.dumpdir, "field_output_latlon.pvd") - self.dumpfile_ll = VTKFile(outfile_ll, - project_output=self.output.project_fields, - comm=self.mesh.comm) + if create_files: + outfile_ll = path.join(self.dumpdir, "field_output_latlon.pvd") + self.dumpfile_ll = VTKFile(outfile_ll, + project_output=self.output.project_fields, + comm=comm) # make functions on latlon mesh, as specified by dumplist_latlon self.to_dump_latlon = [] @@ -472,11 +494,12 @@ def setup_dump(self, state_fields, t, pick_up=False): # already exist, in which case we just need the filenames if self.output.dump_diagnostics: diagnostics_filename = self.dumpdir+"/diagnostics.nc" - to_create = not (path.isfile(diagnostics_filename) and pick_up) + to_create = not (path.isfile(diagnostics_filename) and pick_up) and create_files self.diagnostic_output = DiagnosticsOutput(diagnostics_filename, self.diagnostics, self.output.dirname, - self.mesh.comm, + comm=comm, + ensemble_comm=ens_comm, create=to_create) # if picking-up, don't do initial dump @@ -665,6 +688,10 @@ def dump(self, state_fields, t, step, initial_steps=None): completed by a multi-level time scheme. Defaults to None. 
""" output = self.output + if self.ensemble is not None: + write_file = self.ensemble.ensemble_comm.rank == 0 + else: + write_file = True # Diagnostics: # Compute diagnostic fields @@ -703,7 +730,7 @@ def dump(self, state_fields, t, step, initial_steps=None): # dump fields self.write_nc_dump(t) - if output.dump_vtus: + if output.dump_vtus and write_file: # dump fields self.pvd_dumpfile.write(*self.to_dump) diff --git a/gusto/timeloop.py b/gusto/timeloop.py index 717b9002e..6a3a89af1 100644 --- a/gusto/timeloop.py +++ b/gusto/timeloop.py @@ -24,7 +24,7 @@ class BaseTimestepper(object, metaclass=ABCMeta): """Base class for timesteppers.""" - def __init__(self, equation, io): + def __init__(self, equation, io, ensemble=None): """ Args: equation (:class:`PrognosticEquation`): the prognostic equation. @@ -33,6 +33,7 @@ def __init__(self, equation, io): self.equation = equation self.io = io + self.ensemble = ensemble self.dt = self.equation.domain.dt self.t = self.equation.domain.t self.reference_profiles_initialised = False @@ -177,7 +178,7 @@ def run(self, t, tmax, pick_up=False): # Set up dump, which may also include an initial dump with timed_stage("Dump output"): - self.io.setup_dump(self.fields, t, pick_up) + self.io.setup_dump(self.fields, t, self.ensemble, pick_up) self.t.assign(t) @@ -249,7 +250,7 @@ class Timestepper(BaseTimestepper): """ def __init__(self, equation, scheme, io, spatial_methods=None, - physics_parametrisations=None): + physics_parametrisations=None, ensemble=None): """ Args: equation (:class:`PrognosticEquation`): the prognostic equation @@ -284,7 +285,7 @@ def __init__(self, equation, scheme, io, spatial_methods=None, else: self.physics_parametrisations = [] - super().__init__(equation=equation, io=io) + super().__init__(equation=equation, io=io, ensemble=ensemble) @property def transporting_velocity(self): @@ -716,6 +717,7 @@ def timestep(self): with timed_stage("Apply forcing terms"): logger.info('SIQN: Explicit forcing') + # Put explicit 
forcing into xstar self.forcing.apply(x_after_slow, xn, xstar(self.field_name), "explicit") diff --git a/integration-tests/model/test_parallel_io.py b/integration-tests/model/test_parallel_io.py new file mode 100644 index 000000000..85b3dabb0 --- /dev/null +++ b/integration-tests/model/test_parallel_io.py @@ -0,0 +1,32 @@ +from firedrake import Ensemble, COMM_WORLD, PeriodicUnitSquareMesh +from gusto import * +import pytest + + +@pytest.mark.parallel(nprocs=4) +@pytest.mark.parametrize("spatial_parallelism", [True, False]) +def test_parallel_io(tmpdir, spatial_parallelism): + + if spatial_parallelism: + ensemble = Ensemble(COMM_WORLD, 2) + else: + ensemble = Ensemble(COMM_WORLD, 1) + + mesh = PeriodicUnitSquareMesh(10, 10, comm=ensemble.comm) + dt = 0.1 + domain = Domain(mesh, dt, "BDM", 1) + + # Equation + parameters = ShallowWaterParameters(H=100) + equation = ShallowWaterEquations(domain, parameters) + + # I/O + output = OutputParameters(dirname=str(tmpdir)) + io = IO(domain, output) + + # Time stepper + spatial_methods = [DGUpwind(equation, "u"), DGUpwind(equation, "D")] + stepper = Timestepper(equation, SSPRK3(domain), io, spatial_methods, + ensemble=ensemble) + + stepper.run(0, 3*dt)