From fb60ab5f10e18bd36b59f375ee59d3f201681770 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 9 Apr 2024 15:43:45 -0700 Subject: [PATCH 01/11] Enabled extension of NonEquilibriumCyclingProtocol * Added new (de)compression + (de)serialize functions in feflow.utils.data * Include compressed System, State, and Integrator as a result for the main simulation unit. * SetupUnits now take `extends_data`, which, when populated, spoofs the running of the `SetupUnit._execute` method: the unit instead immediately returns results with values consistent with the end state of the extended SimulationUnit. * Added new test `test_pdr_extend` to test the extension functionality. --- feflow/protocols/nonequilibrium_cycling.py | 79 ++++++++++++++++++++- feflow/tests/conftest.py | 4 +- feflow/tests/test_nonequilibrium_cycling.py | 71 ++++++++++++++++++ feflow/utils/data.py | 41 ++++++++++- 4 files changed, 189 insertions(+), 6 deletions(-) diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index f3b25c6..2a4da98 100644 --- a/feflow/protocols/nonequilibrium_cycling.py +++ b/feflow/protocols/nonequilibrium_cycling.py @@ -25,7 +25,7 @@ from openff.units import unit from openff.units.openmm import to_openmm, from_openmm -from ..utils.data import serialize, deserialize +from ..utils.data import serialize, deserialize, serialize_and_compress, decompress_and_deserialize # Specific instance of logger for this module # logger = logging.getLogger(__name__) @@ -133,6 +133,31 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): from openfe.protocols.openmm_rfe import _rfe_utils from feflow.utils.hybrid_topology import HybridTopologyFactory + if extends_data := self.inputs.get('extends_data', None): + system_outfile = ctx.shared / "system.xml.bz2" + state_outfile = ctx.shared / "state.xml.bz2" + integrator_outfile = ctx.shared / "integrator.xml.bz2" + + def _write_xml(data, filename): + openmm_object = 
decompress_and_deserialize(data) + serialize(openmm_object, filename) + return filename + + extends_data['system'] = _write_xml( + extends_data['system'], + system_outfile, + ) + extends_data['state'] = _write_xml( + extends_data['state'], + state_outfile, + ) + extends_data['integrator'] = _write_xml( + extends_data['integrator'], + integrator_outfile, + ) + + return extends_data + # Check compatibility between states (same receptor and solvent) self._check_states_compatibility(state_a, state_b) @@ -687,7 +712,20 @@ def _execute(self, ctx, *, setup, settings, **inputs): "reverse_neq_final": reverse_neq_new_path, } finally: + compressed_state = serialize_and_compress( + context.getState(getPositions=True), + ) + + compressed_system = serialize_and_compress( + context.getSystem(), + ) + + compressed_integrator = serialize_and_compress( + context.getIntegrator(), + ) + # Explicit cleanup for GPU resources + del context, integrator return { @@ -696,6 +734,9 @@ def _execute(self, ctx, *, setup, settings, **inputs): "trajectory_paths": trajectory_paths, "log": output_log_path, "timing_info": timing_info, + "system": compressed_system, + "state": compressed_state, + "integrator": compressed_integrator, } @@ -890,10 +931,41 @@ def _create( # Handle parameters if mapping is None: raise ValueError("`mapping` is required for this Protocol") + if "ligand" not in mapping: raise ValueError("'ligand' must be specified in `mapping` dict") - if extends: - raise NotImplementedError("Can't extend simulations yet") + + extends_data = {} + if isinstance(extends, ProtocolDAGResult): + + if not all(map(lambda r: r.ok(), extends.protocol_unit_results)): + raise ValueError("Cannot extend units that failed") + + setup, simulation, _ = extends.protocol_units + r_setup, r_simulation, _ = extends.protocol_unit_results + + # confirm consistency + original_state_a = setup.inputs['state_a'].key + original_state_b = setup.inputs['state_b'].key + original_mapping = setup.inputs['mapping'] + + if 
original_state_a != stateA.key: + raise ValueError() + + if original_state_b != stateB.key: + raise ValueError() + + if original_mapping != mapping: + raise ValueError() + + extends_data = dict( + system=r_simulation.outputs['system'], + state=r_simulation.outputs['state'], + integrator=r_simulation.outputs['integrator'], + phase=r_setup.outputs["phase"], + initial_atom_indices=r_setup.outputs['initial_atom_indices'], + final_atom_indices=r_setup.outputs["final_atom_indices"], + ) # inputs to `ProtocolUnit.__init__` should either be `Gufe` objects # or JSON-serializable objects @@ -905,6 +977,7 @@ def _create( mapping=mapping, settings=self.settings, name="setup", + extends_data=extends_data, ) simulations = [ diff --git a/feflow/tests/conftest.py b/feflow/tests/conftest.py index d419e56..2f8f39c 100644 --- a/feflow/tests/conftest.py +++ b/feflow/tests/conftest.py @@ -74,8 +74,8 @@ def short_settings(): settings = NonEquilibriumCyclingProtocol.default_settings() settings.thermo_settings.temperature = 300 * unit.kelvin - settings.eq_steps = 25000 - settings.neq_steps = 25000 + settings.eq_steps = 1000 + settings.neq_steps = 1000 settings.work_save_frequency = 50 settings.traj_save_frequency = 250 settings.platform = "CPU" diff --git a/feflow/tests/test_nonequilibrium_cycling.py b/feflow/tests/test_nonequilibrium_cycling.py index e7e58f9..39990a1 100644 --- a/feflow/tests/test_nonequilibrium_cycling.py +++ b/feflow/tests/test_nonequilibrium_cycling.py @@ -100,6 +100,77 @@ def test_terminal_units(self, protocol_dag_result): assert isinstance(finals[0], ProtocolUnitResult) assert finals[0].name == "result" + def test_pdr_extend( + self, + protocol_short, + benzene_vacuum_system, + toluene_vacuum_system, + mapping_benzene_toluene, + tmpdir, + ): + dag = protocol_short.create( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + name="Short vacuum transformation", + mapping={"ligand": mapping_benzene_toluene}, + ) + + with tmpdir.as_cwd(): + + base_path = 
Path("original") + + shared = base_path / "shared" + shared.mkdir(parents=True) + + scratch = base_path / "scratch" + scratch.mkdir(parents=True) + + pdr: ProtocolDAGResult = execute_DAG( + dag, shared_basedir=shared, scratch_basedir=scratch + ) + + setup, simulation, result = pdr.protocol_units + r_setup, r_simulation, r_result = pdr.protocol_unit_results + + assert setup.inputs['extends_data'] == {} + assert isinstance(r_simulation.outputs['system'], bytes) + assert isinstance(r_simulation.outputs['state'], bytes) + assert isinstance(r_simulation.outputs['integrator'], bytes) + + end_state = r_simulation.outputs['state'] + + dag = protocol_short.create( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + name="Short vacuum transformation, but extended", + mapping={"ligand": mapping_benzene_toluene}, + extends=pdr, + ) + + with tmpdir.as_cwd(): + + base_path = Path("extended") + + shared = base_path / "shared" + shared.mkdir(parents=True) + + scratch = base_path / "scratch" + scratch.mkdir(parents=True) + pdr: ProtocolDAGResult = execute_DAG( + dag, shared_basedir=shared, scratch_basedir=scratch + ) + + setup, simulation, result = pdr.protocol_units + r_setup, r_simulation, r_result = pdr.protocol_unit_results + + assert r_setup.inputs['extends_data'] != {} + + assert isinstance(r_setup.inputs['extends_data']['system'], bytes) + assert isinstance(r_setup.inputs['extends_data']['state'], bytes) + assert isinstance(r_setup.inputs['extends_data']['integrator'], bytes) + + assert r_setup.inputs['extends_data']['state'] == end_state + def test_dag_execute_failure(self, protocol_dag_broken): protocol, dag, dagfailure = protocol_dag_broken diff --git a/feflow/utils/data.py b/feflow/utils/data.py index f829346..f417f35 100644 --- a/feflow/utils/data.py +++ b/feflow/utils/data.py @@ -1,5 +1,45 @@ import os import pathlib +import gzip + +from openmm import XmlSerializer + + +def serialize_and_compress(item): + """Serialize an OpenMM System, State, or 
Integrator and compress. + + Parameters + ---------- + item : System, State, or Integrator + The OpenMM object to serialize and compress. + + Returns + ------- + bytes : bytes + The compressed serialized OpenMM object. + """ + serialized = XmlSerializer.serialize(item).encode() + data = gzip.compress(serialized) + return data + + +def decompress_and_deserialize(data: bytes): + """Recover an OpenMM object from compression. + + Parameters + ---------- + data : bytes + Bytes containing a gzip compressed XML serialization + of an OpenMM object. + + Returns + ------- + deserialized + The deserialized OpenMM object. + """ + decompressed = gzip.decompress(data).decode() + deserialized = XmlSerializer.deserialize(decompressed) + return deserialized def serialize(item, filename: pathlib.Path): @@ -13,7 +53,6 @@ def serialize(item, filename: pathlib.Path): filename : str The filename to serialize to """ - from openmm import XmlSerializer # Create parent directory if it doesn't exist filename_basedir = filename.parent From 3991187a931735e6291797e35e3044f82c8c2bed Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 10 Apr 2024 11:54:13 -0700 Subject: [PATCH 02/11] Allow extending with replicates --- feflow/protocols/nonequilibrium_cycling.py | 75 +++++++++++++-------- feflow/tests/test_nonequilibrium_cycling.py | 45 ++++++++----- 2 files changed, 77 insertions(+), 43 deletions(-) diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index 2a4da98..7340cd9 100644 --- a/feflow/protocols/nonequilibrium_cycling.py +++ b/feflow/protocols/nonequilibrium_cycling.py @@ -133,28 +133,30 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): from openfe.protocols.openmm_rfe import _rfe_utils from feflow.utils.hybrid_topology import HybridTopologyFactory - if extends_data := self.inputs.get('extends_data', None): - system_outfile = ctx.shared / "system.xml.bz2" - state_outfile = ctx.shared / "state.xml.bz2" - 
integrator_outfile = ctx.shared / "integrator.xml.bz2" - + if extends_data := self.inputs.get('extends_data'): def _write_xml(data, filename): openmm_object = decompress_and_deserialize(data) serialize(openmm_object, filename) return filename - extends_data['system'] = _write_xml( - extends_data['system'], - system_outfile, - ) - extends_data['state'] = _write_xml( - extends_data['state'], - state_outfile, - ) - extends_data['integrator'] = _write_xml( - extends_data['integrator'], - integrator_outfile, - ) + for replicate in range(settings.num_replicates): + replicate = str(replicate) + system_outfile = ctx.shared / f"system_{replicate}.xml.bz2" + state_outfile = ctx.shared / f"state_{replicate}.xml.bz2" + integrator_outfile = ctx.shared / f"integrator_{replicate}.xml.bz2" + + extends_data['systems'][replicate] = _write_xml( + extends_data['systems'][replicate], + system_outfile, + ) + extends_data['states'][replicate] = _write_xml( + extends_data['states'][replicate], + state_outfile, + ) + extends_data['integrators'][replicate] = _write_xml( + extends_data['integrators'][replicate], + integrator_outfile, + ) return extends_data @@ -367,10 +369,14 @@ def _write_xml(data, filename): # Explicit cleanup for GPU resources del context, integrator + systems = {str(replicate_name): system_outfile for replicate_name in range(settings.num_replicates)} + states = {str(replicate_name): state_outfile for replicate_name in range(settings.num_replicates)} + integrators = {str(replicate_name): integrator_outfile for replicate_name in range(settings.num_replicates)} + return { - "system": system_outfile, - "state": state_outfile, - "integrator": integrator_outfile, + "systems": systems, + "states": states, + "integrators": integrators, "phase": phase, "initial_atom_indices": hybrid_factory.initial_atom_indices, "final_atom_indices": hybrid_factory.final_atom_indices, @@ -459,9 +465,9 @@ def _execute(self, ctx, *, setup, settings, **inputs): file_logger.addHandler(file_handler) # 
Get state, system, and integrator from setup unit - system = deserialize(setup.outputs["system"]) - state = deserialize(setup.outputs["state"]) - integrator = deserialize(setup.outputs["integrator"]) + system = deserialize(setup.outputs["systems"][self.name]) + state = deserialize(setup.outputs["states"][self.name]) + integrator = deserialize(setup.outputs["integrators"][self.name]) PeriodicNonequilibriumIntegrator.restore_interface(integrator) # Get atom indices for either end of the hybrid topology @@ -941,8 +947,11 @@ def _create( if not all(map(lambda r: r.ok(), extends.protocol_unit_results)): raise ValueError("Cannot extend units that failed") - setup, simulation, _ = extends.protocol_units - r_setup, r_simulation, _ = extends.protocol_unit_results + setup = extends.protocol_units[0] + simulations = extends.protocol_units[1:-1] + + r_setup = extends.protocol_unit_results[0] + r_simulations = extends.protocol_unit_results[1:-1] # confirm consistency original_state_a = setup.inputs['state_a'].key @@ -958,10 +967,20 @@ def _create( if original_mapping != mapping: raise ValueError() + systems = {} + states = {} + integrators = {} + + for r_simulation, simulation in zip(r_simulations, simulations): + sim_name = simulation.name + systems[sim_name] = r_simulation.outputs['system'] + states[sim_name] = r_simulation.outputs['state'] + integrators[sim_name] = r_simulation.outputs['integrator'] + extends_data = dict( - system=r_simulation.outputs['system'], - state=r_simulation.outputs['state'], - integrator=r_simulation.outputs['integrator'], + systems=systems, + states=states, + integrators=integrators, phase=r_setup.outputs["phase"], initial_atom_indices=r_setup.outputs['initial_atom_indices'], final_atom_indices=r_setup.outputs["final_atom_indices"], diff --git a/feflow/tests/test_nonequilibrium_cycling.py b/feflow/tests/test_nonequilibrium_cycling.py index 39990a1..878fcb9 100644 --- a/feflow/tests/test_nonequilibrium_cycling.py +++ 
b/feflow/tests/test_nonequilibrium_cycling.py @@ -100,15 +100,25 @@ def test_terminal_units(self, protocol_dag_result): assert isinstance(finals[0], ProtocolUnitResult) assert finals[0].name == "result" + @pytest.mark.parametrize( + "protocol", + [ + 'protocol_short', + "protocol_short_multiple_cycles", + ], + ) def test_pdr_extend( self, - protocol_short, + protocol, benzene_vacuum_system, toluene_vacuum_system, mapping_benzene_toluene, tmpdir, + request, ): - dag = protocol_short.create( + + protocol = request.getfixturevalue(protocol) + dag = protocol.create( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, name="Short vacuum transformation", @@ -129,17 +139,21 @@ def test_pdr_extend( dag, shared_basedir=shared, scratch_basedir=scratch ) - setup, simulation, result = pdr.protocol_units - r_setup, r_simulation, r_result = pdr.protocol_unit_results + setup = pdr.protocol_units[0] + r_setup = pdr.protocol_unit_results[0] assert setup.inputs['extends_data'] == {} - assert isinstance(r_simulation.outputs['system'], bytes) - assert isinstance(r_simulation.outputs['state'], bytes) - assert isinstance(r_simulation.outputs['integrator'], bytes) - end_state = r_simulation.outputs['state'] + end_states = {} + for simulation, r_simulation in zip(pdr.protocol_units[1:-1], pdr.protocol_unit_results[1:-1]): + assert isinstance(r_simulation.outputs['system'], bytes) + assert isinstance(r_simulation.outputs['state'], bytes) + assert isinstance(r_simulation.outputs['integrator'], bytes) - dag = protocol_short.create( + end_states[simulation.name] = r_simulation.outputs["state"] + + + dag = protocol.create( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, name="Short vacuum transformation, but extended", @@ -160,16 +174,17 @@ def test_pdr_extend( dag, shared_basedir=shared, scratch_basedir=scratch ) - setup, simulation, result = pdr.protocol_units - r_setup, r_simulation, r_result = pdr.protocol_unit_results + r_setup = pdr.protocol_unit_results[0] assert 
r_setup.inputs['extends_data'] != {} - assert isinstance(r_setup.inputs['extends_data']['system'], bytes) - assert isinstance(r_setup.inputs['extends_data']['state'], bytes) - assert isinstance(r_setup.inputs['extends_data']['integrator'], bytes) + for replicate in range(protocol.settings.num_replicates): + replicate = str(replicate) + assert isinstance(r_setup.inputs["extends_data"]["systems"][replicate], bytes) + assert isinstance(r_setup.inputs["extends_data"]["states"][replicate], bytes) + assert isinstance(r_setup.inputs["extends_data"]["integrators"][replicate], bytes) - assert r_setup.inputs['extends_data']['state'] == end_state + assert r_setup.inputs["extends_data"]["states"][replicate] == end_states[replicate] def test_dag_execute_failure(self, protocol_dag_broken): protocol, dag, dagfailure = protocol_dag_broken From 96f069373538cbc0228a41c31b018b381893cdae Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Apr 2024 22:34:28 +0000 Subject: [PATCH 03/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- feflow/protocols/nonequilibrium_cycling.py | 53 +++++++++++++-------- feflow/tests/test_nonequilibrium_cycling.py | 34 ++++++++----- 2 files changed, 56 insertions(+), 31 deletions(-) diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index 7340cd9..71876bb 100644 --- a/feflow/protocols/nonequilibrium_cycling.py +++ b/feflow/protocols/nonequilibrium_cycling.py @@ -25,7 +25,12 @@ from openff.units import unit from openff.units.openmm import to_openmm, from_openmm -from ..utils.data import serialize, deserialize, serialize_and_compress, decompress_and_deserialize +from ..utils.data import ( + serialize, + deserialize, + serialize_and_compress, + decompress_and_deserialize, +) # Specific instance of logger for this module # logger = logging.getLogger(__name__) @@ -133,7 +138,8 @@ def 
_execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): from openfe.protocols.openmm_rfe import _rfe_utils from feflow.utils.hybrid_topology import HybridTopologyFactory - if extends_data := self.inputs.get('extends_data'): + if extends_data := self.inputs.get("extends_data"): + def _write_xml(data, filename): openmm_object = decompress_and_deserialize(data) serialize(openmm_object, filename) @@ -145,16 +151,16 @@ def _write_xml(data, filename): state_outfile = ctx.shared / f"state_{replicate}.xml.bz2" integrator_outfile = ctx.shared / f"integrator_{replicate}.xml.bz2" - extends_data['systems'][replicate] = _write_xml( - extends_data['systems'][replicate], + extends_data["systems"][replicate] = _write_xml( + extends_data["systems"][replicate], system_outfile, ) - extends_data['states'][replicate] = _write_xml( - extends_data['states'][replicate], - state_outfile, + extends_data["states"][replicate] = _write_xml( + extends_data["states"][replicate], + state_outfile, ) - extends_data['integrators'][replicate] = _write_xml( - extends_data['integrators'][replicate], + extends_data["integrators"][replicate] = _write_xml( + extends_data["integrators"][replicate], integrator_outfile, ) @@ -369,9 +375,18 @@ def _write_xml(data, filename): # Explicit cleanup for GPU resources del context, integrator - systems = {str(replicate_name): system_outfile for replicate_name in range(settings.num_replicates)} - states = {str(replicate_name): state_outfile for replicate_name in range(settings.num_replicates)} - integrators = {str(replicate_name): integrator_outfile for replicate_name in range(settings.num_replicates)} + systems = { + str(replicate_name): system_outfile + for replicate_name in range(settings.num_replicates) + } + states = { + str(replicate_name): state_outfile + for replicate_name in range(settings.num_replicates) + } + integrators = { + str(replicate_name): integrator_outfile + for replicate_name in range(settings.num_replicates) + } return { 
"systems": systems, @@ -954,9 +969,9 @@ def _create( r_simulations = extends.protocol_unit_results[1:-1] # confirm consistency - original_state_a = setup.inputs['state_a'].key - original_state_b = setup.inputs['state_b'].key - original_mapping = setup.inputs['mapping'] + original_state_a = setup.inputs["state_a"].key + original_state_b = setup.inputs["state_b"].key + original_mapping = setup.inputs["mapping"] if original_state_a != stateA.key: raise ValueError() @@ -973,16 +988,16 @@ def _create( for r_simulation, simulation in zip(r_simulations, simulations): sim_name = simulation.name - systems[sim_name] = r_simulation.outputs['system'] - states[sim_name] = r_simulation.outputs['state'] - integrators[sim_name] = r_simulation.outputs['integrator'] + systems[sim_name] = r_simulation.outputs["system"] + states[sim_name] = r_simulation.outputs["state"] + integrators[sim_name] = r_simulation.outputs["integrator"] extends_data = dict( systems=systems, states=states, integrators=integrators, phase=r_setup.outputs["phase"], - initial_atom_indices=r_setup.outputs['initial_atom_indices'], + initial_atom_indices=r_setup.outputs["initial_atom_indices"], final_atom_indices=r_setup.outputs["final_atom_indices"], ) diff --git a/feflow/tests/test_nonequilibrium_cycling.py b/feflow/tests/test_nonequilibrium_cycling.py index 878fcb9..4223fb4 100644 --- a/feflow/tests/test_nonequilibrium_cycling.py +++ b/feflow/tests/test_nonequilibrium_cycling.py @@ -103,7 +103,7 @@ def test_terminal_units(self, protocol_dag_result): @pytest.mark.parametrize( "protocol", [ - 'protocol_short', + "protocol_short", "protocol_short_multiple_cycles", ], ) @@ -142,17 +142,18 @@ def test_pdr_extend( setup = pdr.protocol_units[0] r_setup = pdr.protocol_unit_results[0] - assert setup.inputs['extends_data'] == {} + assert setup.inputs["extends_data"] == {} end_states = {} - for simulation, r_simulation in zip(pdr.protocol_units[1:-1], pdr.protocol_unit_results[1:-1]): - assert 
isinstance(r_simulation.outputs['system'], bytes) - assert isinstance(r_simulation.outputs['state'], bytes) - assert isinstance(r_simulation.outputs['integrator'], bytes) + for simulation, r_simulation in zip( + pdr.protocol_units[1:-1], pdr.protocol_unit_results[1:-1] + ): + assert isinstance(r_simulation.outputs["system"], bytes) + assert isinstance(r_simulation.outputs["state"], bytes) + assert isinstance(r_simulation.outputs["integrator"], bytes) end_states[simulation.name] = r_simulation.outputs["state"] - dag = protocol.create( stateA=benzene_vacuum_system, stateB=toluene_vacuum_system, @@ -176,15 +177,24 @@ def test_pdr_extend( r_setup = pdr.protocol_unit_results[0] - assert r_setup.inputs['extends_data'] != {} + assert r_setup.inputs["extends_data"] != {} for replicate in range(protocol.settings.num_replicates): replicate = str(replicate) - assert isinstance(r_setup.inputs["extends_data"]["systems"][replicate], bytes) - assert isinstance(r_setup.inputs["extends_data"]["states"][replicate], bytes) - assert isinstance(r_setup.inputs["extends_data"]["integrators"][replicate], bytes) + assert isinstance( + r_setup.inputs["extends_data"]["systems"][replicate], bytes + ) + assert isinstance( + r_setup.inputs["extends_data"]["states"][replicate], bytes + ) + assert isinstance( + r_setup.inputs["extends_data"]["integrators"][replicate], bytes + ) - assert r_setup.inputs["extends_data"]["states"][replicate] == end_states[replicate] + assert ( + r_setup.inputs["extends_data"]["states"][replicate] + == end_states[replicate] + ) def test_dag_execute_failure(self, protocol_dag_broken): protocol, dag, dagfailure = protocol_dag_broken From 1aa6669abb0867b97ee2352b01ee83835d0dd953 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Fri, 12 Apr 2024 11:53:18 -0700 Subject: [PATCH 04/11] Updated serialization/compression of OpenMM objects GufeTokenizables must be JSON serializable, so including bytes is not an option. 
Instead, we take the compressed bytes and encode them into a Base64 string. --- feflow/tests/test_nonequilibrium_cycling.py | 14 ++++++------- feflow/utils/data.py | 23 ++++++++++++--------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/feflow/tests/test_nonequilibrium_cycling.py b/feflow/tests/test_nonequilibrium_cycling.py index 4223fb4..5454fa4 100644 --- a/feflow/tests/test_nonequilibrium_cycling.py +++ b/feflow/tests/test_nonequilibrium_cycling.py @@ -148,9 +148,9 @@ def test_pdr_extend( for simulation, r_simulation in zip( pdr.protocol_units[1:-1], pdr.protocol_unit_results[1:-1] ): - assert isinstance(r_simulation.outputs["system"], bytes) - assert isinstance(r_simulation.outputs["state"], bytes) - assert isinstance(r_simulation.outputs["integrator"], bytes) + assert isinstance(r_simulation.outputs["system"], str) + assert isinstance(r_simulation.outputs["state"], str) + assert isinstance(r_simulation.outputs["integrator"], str) end_states[simulation.name] = r_simulation.outputs["state"] @@ -159,7 +159,7 @@ def test_pdr_extend( stateB=toluene_vacuum_system, name="Short vacuum transformation, but extended", mapping={"ligand": mapping_benzene_toluene}, - extends=pdr, + extends=ProtocolDAGResult.from_dict(pdr.to_dict()), ) with tmpdir.as_cwd(): @@ -182,13 +182,13 @@ def test_pdr_extend( for replicate in range(protocol.settings.num_replicates): replicate = str(replicate) assert isinstance( - r_setup.inputs["extends_data"]["systems"][replicate], bytes + r_setup.inputs["extends_data"]["systems"][replicate], str ) assert isinstance( - r_setup.inputs["extends_data"]["states"][replicate], bytes + r_setup.inputs["extends_data"]["states"][replicate], str ) assert isinstance( - r_setup.inputs["extends_data"]["integrators"][replicate], bytes + r_setup.inputs["extends_data"]["integrators"][replicate], str ) assert ( diff --git a/feflow/utils/data.py b/feflow/utils/data.py index f417f35..64665d0 100644 --- a/feflow/utils/data.py +++ b/feflow/utils/data.py 
@@ -1,11 +1,12 @@ import os import pathlib -import gzip +import bz2 +import base64 from openmm import XmlSerializer -def serialize_and_compress(item): +def serialize_and_compress(item) -> str: """Serialize an OpenMM System, State, or Integrator and compress. Parameters @@ -15,21 +16,22 @@ def serialize_and_compress(item): Returns ------- - bytes : bytes - The compressed serialized OpenMM object. + b64string : str + The compressed serialized OpenMM object encoded in a Base64 string. """ serialized = XmlSerializer.serialize(item).encode() - data = gzip.compress(serialized) - return data + compressed = bz2.compress(serialized) + b64string = base64.b64encode(compressed).decode("ascii") + return b64string -def decompress_and_deserialize(data: bytes): +def decompress_and_deserialize(data: str): """Recover an OpenMM object from compression. Parameters ---------- - data : bytes - Bytes containing a gzip compressed XML serialization + data : str + String containing a Base64 encoded bzip2 compressed XML serialization of an OpenMM object. Returns @@ -37,7 +39,8 @@ def decompress_and_deserialize(data: bytes): deserialized The deserialized OpenMM object. 
""" - decompressed = gzip.decompress(data).decode() + compressed = base64.b64decode(data) + decompressed = bz2.decompress(compressed).decode("utf-8") deserialized = XmlSerializer.deserialize(decompressed) return deserialized From 67392bc960ce22aa845e6ac1768f1be0ef71028d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Apr 2024 18:53:35 +0000 Subject: [PATCH 05/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- feflow/tests/test_nonequilibrium_cycling.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/feflow/tests/test_nonequilibrium_cycling.py b/feflow/tests/test_nonequilibrium_cycling.py index 5454fa4..9682093 100644 --- a/feflow/tests/test_nonequilibrium_cycling.py +++ b/feflow/tests/test_nonequilibrium_cycling.py @@ -181,12 +181,8 @@ def test_pdr_extend( for replicate in range(protocol.settings.num_replicates): replicate = str(replicate) - assert isinstance( - r_setup.inputs["extends_data"]["systems"][replicate], str - ) - assert isinstance( - r_setup.inputs["extends_data"]["states"][replicate], str - ) + assert isinstance(r_setup.inputs["extends_data"]["systems"][replicate], str) + assert isinstance(r_setup.inputs["extends_data"]["states"][replicate], str) assert isinstance( r_setup.inputs["extends_data"]["integrators"][replicate], str ) From 8c69b31364e2751c0b30ad12c15a2e822434c5d9 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Fri, 12 Apr 2024 18:53:35 +0000 Subject: [PATCH 06/11] Updated ValueError messages when extending * When `mapping` is `None`, use the mapping provided by the ProtocolDAGResult --- feflow/protocols/nonequilibrium_cycling.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index 71876bb..092e25c 100644 --- a/feflow/protocols/nonequilibrium_cycling.py +++ 
b/feflow/protocols/nonequilibrium_cycling.py @@ -974,13 +974,16 @@ def _create( original_mapping = setup.inputs["mapping"] if original_state_a != stateA.key: - raise ValueError() + raise ValueError("'stateA' key is not the same as the key provided by the 'extends' ProtocolDAGResult.") if original_state_b != stateB.key: - raise ValueError() + raise ValueError("'stateB' key is not the same as the key provided by the 'extends' ProtocolDAGResult.") - if original_mapping != mapping: - raise ValueError() + if mapping is not None: + if original_mapping != mapping: + raise ValueError("'mapping' is not consistent with the mapping provided by the 'extnds' ProtocolDAGResult.") + else: + mapping = original_mapping systems = {} states = {} From c244ae9c17b1d1678eeed74385013c2aefdb1cbe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Apr 2024 20:42:19 +0000 Subject: [PATCH 07/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- feflow/protocols/nonequilibrium_cycling.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index 092e25c..871dd16 100644 --- a/feflow/protocols/nonequilibrium_cycling.py +++ b/feflow/protocols/nonequilibrium_cycling.py @@ -974,14 +974,20 @@ def _create( original_mapping = setup.inputs["mapping"] if original_state_a != stateA.key: - raise ValueError("'stateA' key is not the same as the key provided by the 'extends' ProtocolDAGResult.") + raise ValueError( + "'stateA' key is not the same as the key provided by the 'extends' ProtocolDAGResult." + ) if original_state_b != stateB.key: - raise ValueError("'stateB' key is not the same as the key provided by the 'extends' ProtocolDAGResult.") + raise ValueError( + "'stateB' key is not the same as the key provided by the 'extends' ProtocolDAGResult." 
+ ) if mapping is not None: if original_mapping != mapping: - raise ValueError("'mapping' is not consistent with the mapping provided by the 'extnds' ProtocolDAGResult.") + raise ValueError( + "'mapping' is not consistent with the mapping provided by the 'extnds' ProtocolDAGResult." + ) else: mapping = original_mapping From 714c4d05e8cc81023588427d1c3e0b9562d03481 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 15 Apr 2024 12:34:58 -0700 Subject: [PATCH 08/11] Reduced loop count --- feflow/protocols/nonequilibrium_cycling.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index 871dd16..51d3886 100644 --- a/feflow/protocols/nonequilibrium_cycling.py +++ b/feflow/protocols/nonequilibrium_cycling.py @@ -375,18 +375,13 @@ def _write_xml(data, filename): # Explicit cleanup for GPU resources del context, integrator - systems = { - str(replicate_name): system_outfile - for replicate_name in range(settings.num_replicates) - } - states = { - str(replicate_name): state_outfile - for replicate_name in range(settings.num_replicates) - } - integrators = { - str(replicate_name): integrator_outfile - for replicate_name in range(settings.num_replicates) - } + systems = dict() + states = dict() + integrators = dict() + for replicate_name in map(str, range(settings.num_replicates)): + systems[replicate_name] = system_outfile + states[replicate_name] = state_outfile + integrators[replicate_name] = integrator_outfile return { "systems": systems, From 1d46275e2a73461597be0e8ec04c212e1e15329f Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 16 Apr 2024 13:17:24 -0700 Subject: [PATCH 09/11] Use `gufe` methods to determine if extends is ok --- feflow/protocols/nonequilibrium_cycling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index 
51d3886..bde1691 100644 --- a/feflow/protocols/nonequilibrium_cycling.py +++ b/feflow/protocols/nonequilibrium_cycling.py @@ -954,8 +954,8 @@ def _create( extends_data = {} if isinstance(extends, ProtocolDAGResult): - if not all(map(lambda r: r.ok(), extends.protocol_unit_results)): - raise ValueError("Cannot extend units that failed") + if not extends.ok(): + raise ValueError("Cannot extend protocols that failed") setup = extends.protocol_units[0] simulations = extends.protocol_units[1:-1] From 9c42a256fb7191ea12fb661d3aba50dadc8a79a6 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 22 Apr 2024 11:24:48 -0700 Subject: [PATCH 10/11] Use dict literal syntax over constructor --- feflow/protocols/nonequilibrium_cycling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index bde1691..d23ffa7 100644 --- a/feflow/protocols/nonequilibrium_cycling.py +++ b/feflow/protocols/nonequilibrium_cycling.py @@ -375,9 +375,9 @@ def _write_xml(data, filename): # Explicit cleanup for GPU resources del context, integrator - systems = dict() - states = dict() - integrators = dict() + systems = {} + states = {} + integrators = {} for replicate_name in map(str, range(settings.num_replicates)): systems[replicate_name] = system_outfile states[replicate_name] = state_outfile From 093b76a4c4a43cc2466a4ca9ebf7e580029096f7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Jun 2024 19:55:19 +0000 Subject: [PATCH 11/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- feflow/protocols/nonequilibrium_cycling.py | 4 +--- feflow/tests/test_nonequilibrium_cycling.py | 9 ++------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index 592022d..533e6b2 100644 --- 
a/feflow/protocols/nonequilibrium_cycling.py +++ b/feflow/protocols/nonequilibrium_cycling.py @@ -1042,9 +1042,7 @@ def _create( ) simulations = [ - self._simulation_unit( - setup=setup, settings=self.settings, name=f"{cycle}" - ) + self._simulation_unit(setup=setup, settings=self.settings, name=f"{cycle}") for cycle in range(num_cycles) ] diff --git a/feflow/tests/test_nonequilibrium_cycling.py b/feflow/tests/test_nonequilibrium_cycling.py index 99834b3..d1b301b 100644 --- a/feflow/tests/test_nonequilibrium_cycling.py +++ b/feflow/tests/test_nonequilibrium_cycling.py @@ -182,14 +182,9 @@ def test_pdr_extend( cycle = str(cycle) assert isinstance(r_setup.inputs["extends_data"]["systems"][cycle], str) assert isinstance(r_setup.inputs["extends_data"]["states"][cycle], str) - assert isinstance( - r_setup.inputs["extends_data"]["integrators"][cycle], str - ) + assert isinstance(r_setup.inputs["extends_data"]["integrators"][cycle], str) - assert ( - r_setup.inputs["extends_data"]["states"][cycle] - == end_states[cycle] - ) + assert r_setup.inputs["extends_data"]["states"][cycle] == end_states[cycle] # TODO: We probably need to find failure test cases as control # def test_dag_execute_failure(self, protocol_dag_broken):