From e1c49fe19872af4694a25f86edf26010d669b544 Mon Sep 17 00:00:00 2001 From: wpbonelli Date: Wed, 10 Dec 2025 07:59:20 -0500 Subject: [PATCH 1/6] add test and replace=False param to set_data() --- autotest/test_mf6_set_data_replace.py | 293 ++++++++++++++++++++++++++ flopy/mf6/data/mfdataarray.py | 33 ++- flopy/mf6/data/mfdatalist.py | 32 ++- flopy/mf6/data/mfdataplist.py | 31 ++- 4 files changed, 377 insertions(+), 12 deletions(-) create mode 100644 autotest/test_mf6_set_data_replace.py diff --git a/autotest/test_mf6_set_data_replace.py b/autotest/test_mf6_set_data_replace.py new file mode 100644 index 0000000000..3992d4611a --- /dev/null +++ b/autotest/test_mf6_set_data_replace.py @@ -0,0 +1,293 @@ +""" +Test for set_data() replace parameter (issue #2663). + +Tests that the replace parameter correctly removes stress period data +for periods not included in the provided dictionary. +""" + +from pathlib import Path + +import pytest + +import flopy + +pytestmark = pytest.mark.mf6 + + +def count_stress_periods_in_file(file_path): + """Count the number of 'BEGIN period' statements in a file.""" + with open(file_path, "r") as f: + return sum(1 for line in f if line.strip().upper().startswith("BEGIN PERIOD")) + + +@pytest.mark.parametrize("replace", [False, True]) +def test_set_data_replace_array(function_tmpdir, replace): + """Test set_data() replace parameter with MFTransientArray (RCH package).""" + # Create a model with 48 stress periods + sim_name = "test_model" + sim_ws = Path(function_tmpdir) / "original" + sim_ws.mkdir(exist_ok=True) + + nper_original = 48 + nper_new = 12 + + # Create simulation with 48 stress periods + sim = flopy.mf6.MFSimulation( + sim_name=sim_name, + sim_ws=str(sim_ws), + exe_name="mf6", + ) + + # Create TDIS with 48 stress periods + tdis = flopy.mf6.ModflowTdis( + sim, + nper=nper_original, + perioddata=[(1.0, 1, 1.0) for _ in range(nper_original)], + ) + + # Create IMS + flopy.mf6.ModflowIms(sim) + + # Create groundwater flow model + gwf = flopy.mf6.ModflowGwf(sim, modelname=sim_name) + + # Create DIS + nlay, nrow, ncol = 1, 10, 10 + flopy.mf6.ModflowGwfdis( + gwf, + nlay=nlay, + nrow=nrow, + ncol=ncol, + delr=100.0, + delc=100.0, + top=100.0, + botm=0.0, + ) + + # Create IC + flopy.mf6.ModflowGwfic(gwf, strt=100.0) + + # Create NPF + flopy.mf6.ModflowGwfnpf(gwf, icelltype=1, k=10.0) + + # Create OC + flopy.mf6.ModflowGwfoc( + gwf, + budget_filerecord=f"{sim_name}.cbc", + head_filerecord=f"{sim_name}.hds", + saverecord=[("HEAD", "LAST"), ("BUDGET", "LAST")], + ) + + # Create RCH package with different recharge for each stress period + rch_data = {kper: 0.001 + kper * 0.0001 for kper in range(nper_original)} + flopy.mf6.ModflowGwfrcha(gwf, recharge=rch_data) + + # Write the original simulation + sim.write_simulation() + + # Count stress periods in original file + original_rch_file = sim_ws / f"{sim_name}.rcha" + original_sp_count = count_stress_periods_in_file(original_rch_file) + assert original_sp_count == nper_original + + # Load the simulation + sim2 = flopy.mf6.MFSimulation.load(sim_ws=str(sim_ws)) + gwf2 = sim2.get_model(sim_name) + rch2 = gwf2.get_package("RCHA") + + # Create new stress period dictionary with only 12 periods + new_rch_data = {kper: 0.002 + kper * 0.0002 for kper in range(nper_new)} + + # Update TDIS NPER + tdis2 = sim2.get_package("TDIS") + tdis2.nper = nper_new + tdis2.perioddata = [(1.0, 1, 1.0) for _ in range(nper_new)] + + # Use set_data() with the replace parameter + rch2.recharge.set_data(new_rch_data, replace=replace) + + # Write the modified simulation + sim2_ws = Path(function_tmpdir) / f"modified_replace_{replace}" + sim2_ws.mkdir(exist_ok=True) + sim2.set_sim_path(str(sim2_ws)) + sim2.write_simulation() + + # Count stress periods in modified file + modified_rch_file = sim2_ws / f"{sim_name}.rcha" + modified_sp_count = count_stress_periods_in_file(modified_rch_file) + + if replace: + # With replace=True, should only have 12 stress periods + # NOTE: Currently fails due to block header persistence issue + # When fixed, this should pass + assert modified_sp_count == nper_new, ( + f"Expected {nper_new} stress periods with replace=True, got {modified_sp_count}" + ) + else: + # With replace=False (backwards compatible), all 48 periods remain + # Periods 12-47 will be written as empty periods + assert modified_sp_count == nper_original, ( + f"Expected {nper_original} stress periods with replace=False, got {modified_sp_count}" + ) + + # Verify data values are correct for the new periods + with open(modified_rch_file, "r") as f: + content = f.read() + # Check that period 1 has the new recharge value + assert "0.00200000" in content or "2.00000000E-03" in content + # Check that period 12 has the new recharge value + assert "0.00420000" in content or "4.20000000E-03" in content + + +@pytest.mark.parametrize("replace", [False, True]) +def test_set_data_replace_list(function_tmpdir, replace): + """Test set_data() replace parameter with MFTransientList (WEL package).""" + # Create a model with 24 stress periods + sim_name = "test_wel_model" + sim_ws = Path(function_tmpdir) / "wel_original" + sim_ws.mkdir(exist_ok=True) + + nper_original = 24 + nper_new = 6 + + # Create simulation + sim = flopy.mf6.MFSimulation( + sim_name=sim_name, + sim_ws=str(sim_ws), + exe_name="mf6", + ) + + # Create TDIS + tdis = flopy.mf6.ModflowTdis( + sim, + nper=nper_original, + perioddata=[(1.0, 1, 1.0) for _ in range(nper_original)], + ) + + # Create IMS + flopy.mf6.ModflowIms(sim) + + # Create groundwater flow model + gwf = flopy.mf6.ModflowGwf(sim, modelname=sim_name) + + # Create DIS + nlay, nrow, ncol = 1, 10, 10 + flopy.mf6.ModflowGwfdis( + gwf, + nlay=nlay, + nrow=nrow, + ncol=ncol, + delr=100.0, + delc=100.0, + top=100.0, + botm=0.0, + ) + + # Create IC + flopy.mf6.ModflowGwfic(gwf, strt=100.0) + + # Create NPF + flopy.mf6.ModflowGwfnpf(gwf, icelltype=1, k=10.0) + + # Create OC + flopy.mf6.ModflowGwfoc( + gwf, + budget_filerecord=f"{sim_name}.cbc", + head_filerecord=f"{sim_name}.hds", + saverecord=[("HEAD", "LAST"), ("BUDGET", "LAST")], + ) + + # Create WEL package with different pumping rates for each stress period + wel_data = { + kper: [[(0, 5, 5), -1000.0 - kper * 10.0]] for kper in range(nper_original) + } + flopy.mf6.ModflowGwfwel(gwf, stress_period_data=wel_data) + + # Write the original simulation + sim.write_simulation() + + # Count stress periods in original file + original_wel_file = sim_ws / f"{sim_name}.wel" + original_sp_count = count_stress_periods_in_file(original_wel_file) + assert original_sp_count == nper_original + + # Load the simulation + sim2 = flopy.mf6.MFSimulation.load(sim_ws=str(sim_ws)) + gwf2 = sim2.get_model(sim_name) + wel2 = gwf2.get_package("WEL") + + # Create new stress period dictionary with only 6 periods + new_wel_data = { + kper: [[(0, 5, 5), -2000.0 - kper * 20.0]] for kper in range(nper_new) + } + + # Update TDIS NPER + tdis2 = sim2.get_package("TDIS") + tdis2.nper = nper_new + tdis2.perioddata = [(1.0, 1, 1.0) for _ in range(nper_new)] + + # Use set_data() with the replace parameter + wel2.stress_period_data.set_data(new_wel_data, replace=replace) + + # Write the modified simulation + sim2_ws = Path(function_tmpdir) / f"wel_modified_replace_{replace}" + sim2_ws.mkdir(exist_ok=True) + sim2.set_sim_path(str(sim2_ws)) + sim2.write_simulation() + + # Count stress periods in modified file + modified_wel_file = sim2_ws / f"{sim_name}.wel" + modified_sp_count = count_stress_periods_in_file(modified_wel_file) + + if replace: + # With replace=True, should only have 6 stress periods + # NOTE: Currently fails due to block header persistence issue + assert modified_sp_count == nper_new, ( + f"Expected {nper_new} stress periods with replace=True, got {modified_sp_count}" + ) + else: + # With replace=False, all 24 periods remain + assert modified_sp_count == nper_original, ( + f"Expected {nper_original} stress periods with replace=False, got {modified_sp_count}" + ) + + +def test_set_data_without_replace_backwards_compatible(function_tmpdir): + """Test that set_data() without replace parameter maintains backwards compatibility.""" + # This test ensures that existing code that relies on the "update" behavior + # continues to work as expected + sim_name = "test_compat" + sim_ws = Path(function_tmpdir) / "compat" + sim_ws.mkdir(exist_ok=True) + + # Create a simple model with 10 stress periods + sim = flopy.mf6.MFSimulation(sim_name=sim_name, sim_ws=str(sim_ws), exe_name="mf6") + flopy.mf6.ModflowTdis(sim, nper=10, perioddata=[(1.0, 1, 1.0) for _ in range(10)]) + flopy.mf6.ModflowIms(sim) + gwf = flopy.mf6.ModflowGwf(sim, modelname=sim_name) + flopy.mf6.ModflowGwfdis(gwf, nlay=1, nrow=10, ncol=10) + flopy.mf6.ModflowGwfic(gwf, strt=100.0) + flopy.mf6.ModflowGwfnpf(gwf, k=10.0) + flopy.mf6.ModflowGwfoc(gwf) + + # Create RCH with initial data for periods 0-4 + initial_data = dict.fromkeys(range(5), 0.001) + rch = flopy.mf6.ModflowGwfrcha(gwf, recharge=initial_data) + + # Update periods 5-9 using set_data without replace parameter + # This should ADD to the existing data, not replace it + additional_data = dict.fromkeys(range(5, 10), 0.002) + rch.recharge.set_data(additional_data) # replace defaults to False + + # Write simulation + sim.write_simulation() + + # Load and verify all 10 periods are present + sim2 = flopy.mf6.MFSimulation.load(sim_ws=str(sim_ws)) + gwf2 = sim2.get_model(sim_name) + rch2 = gwf2.get_package("RCHA") + + # Check that both sets of periods are present + for kper in range(10): + data = rch2.recharge.get_data(key=kper) + assert data is not None, f"Period {kper} should have data" diff --git a/flopy/mf6/data/mfdataarray.py b/flopy/mf6/data/mfdataarray.py index e00c877b1a..9405b64543 100644 --- a/flopy/mf6/data/mfdataarray.py +++ b/flopy/mf6/data/mfdataarray.py @@ -1890,7 +1890,7 @@ def _build_period_data( output[sp] = data return output - def set_record(self, data_record): + def set_record(self, data_record, replace=False): """Sets data and metadata at layer `layer` and time `key` to `data_record`. For unlayered data do not pass in `layer`. @@ -1902,10 +1902,15 @@ def set_record(self, data_record): and metadata (factor, iprn, filename, binary, data) for a given stress period. How to define the dictionary of data and metadata is described in the MFData class's set_record method. + replace : bool + If True, all existing stress period keys not present in the + dictionary will be removed. If False (default), existing keys + not in the dictionary are preserved. Default is False for + backwards compatibility. """ - self._set_data_record(data_record, is_record=True) + self._set_data_record(data_record, is_record=True, replace=replace) - def set_data(self, data, multiplier=None, layer=None, key=None): + def set_data(self, data, multiplier=None, layer=None, key=None, replace=False): """Sets the contents of the data at layer `layer` and time `key` to `data` with multiplier `multiplier`. For unlayered data do not pass in `layer`. @@ -1926,15 +1931,33 @@ def set_data(self, data, multiplier=None, layer=None, key=None): key : int Zero based stress period to assign data too. Does not apply if `data` is a dictionary. + replace : bool + If True and `data` is a dictionary, all existing stress period + keys not present in the dictionary will be removed. If False + (default), existing keys not in the dictionary are preserved. + This provides a way to completely replace stress period data + rather than update it. Default is False for backwards + compatibility. """ - self._set_data_record(data, multiplier, layer, key) + self._set_data_record(data, multiplier, layer, key, replace=replace) def _set_data_record( - self, data, multiplier=None, layer=None, key=None, is_record=False + self, data, multiplier=None, layer=None, key=None, is_record=False, replace=False ): if isinstance(data, dict): # each item in the dictionary is a list for one stress period # the dictionary key is the stress period the list is for + + # If replace=True, remove existing keys not in the new data + if replace and self._data_storage: + existing_keys = set(self._data_storage.keys()) + provided_keys = set(data.keys()) + keys_to_remove = existing_keys - provided_keys + for key_to_remove in keys_to_remove: + self.remove_transient_key(key_to_remove) + if key_to_remove in self.empty_keys: + del self.empty_keys[key_to_remove] + del_keys = [] for key, list_item in data.items(): if list_item is None: diff --git a/flopy/mf6/data/mfdatalist.py b/flopy/mf6/data/mfdatalist.py index 529177c41c..2f95c75a8e 100644 --- a/flopy/mf6/data/mfdatalist.py +++ b/flopy/mf6/data/mfdatalist.py @@ -1780,7 +1780,7 @@ def get_data(self, key=None, apply_mult=False, **kwargs): else: return None - def set_record(self, data_record, autofill=False, check_data=True): + def set_record(self, data_record, autofill=False, check_data=True, replace=False): """Sets the contents of the data based on the contents of 'data_record`. @@ -1795,15 +1795,21 @@ def set_record(self, data_record, autofill=False, check_data=True): Automatically correct data check_data : bool Whether to verify the data + replace : bool + If True, all existing stress period keys not present in the + dictionary will be removed. If False (default), existing keys + not in the dictionary are preserved. Default is False for + backwards compatibility. """ self._set_data_record( data_record, autofill=autofill, check_data=check_data, is_record=True, + replace=replace, ) - def set_data(self, data, key=None, autofill=False): + def set_data(self, data, key=None, autofill=False, replace=False): """Sets the contents of the data at time `key` to `data`. Parameters @@ -1819,17 +1825,35 @@ def set_data(self, data, key=None, autofill=False): if `data` is a dictionary. autofill : bool Automatically correct data. + replace : bool + If True and `data` is a dictionary, all existing stress period + keys not present in the dictionary will be removed. If False + (default), existing keys not in the dictionary are preserved. + This provides a way to completely replace stress period data + rather than update it. Default is False for backwards + compatibility. """ - self._set_data_record(data, key, autofill) + self._set_data_record(data, key, autofill, replace=replace) def _set_data_record( - self, data, key=None, autofill=False, check_data=False, is_record=False + self, data, key=None, autofill=False, check_data=False, is_record=False, replace=False ): self._cache_model_grid = True if isinstance(data, dict): if "filename" not in data and "data" not in data: # each item in the dictionary is a list for one stress period # the dictionary key is the stress period the list is for + + # If replace=True, remove existing keys not in the new data + if replace and self._data_storage: + existing_keys = set(self._data_storage.keys()) + provided_keys = set(data.keys()) + keys_to_remove = existing_keys - provided_keys + for key_to_remove in keys_to_remove: + self.remove_transient_key(key_to_remove) + if key_to_remove in self.empty_keys: + del self.empty_keys[key_to_remove] + del_keys = [] for key, list_item in data.items(): if list_item is None: diff --git a/flopy/mf6/data/mfdataplist.py b/flopy/mf6/data/mfdataplist.py index a7aef4625e..50e55ada6f 100644 --- a/flopy/mf6/data/mfdataplist.py +++ b/flopy/mf6/data/mfdataplist.py @@ -2271,7 +2271,7 @@ def get_data(self, key=None, apply_mult=False, dataframe=False, **kwargs): else: return None - def set_record(self, record, autofill=False, check_data=True): + def set_record(self, record, autofill=False, check_data=True, replace=False): """Sets the contents of the data based on the contents of 'record`. @@ -2286,15 +2286,21 @@ def set_record(self, record, autofill=False, check_data=True): Automatically correct data check_data : bool Whether to verify the data + replace : bool + If True, all existing stress period keys not present in the + dictionary will be removed. If False (default), existing keys + not in the dictionary are preserved. Default is False for + backwards compatibility. """ self._set_data_record( record, autofill=autofill, check_data=check_data, is_record=True, + replace=replace, ) - def set_data(self, data, key=None, autofill=False): + def set_data(self, data, key=None, autofill=False, replace=False): """Sets the contents of the data at time `key` to `data`. Parameters @@ -2310,8 +2316,15 @@ def set_data(self, data, key=None, autofill=False): if `data` is a dictionary. autofill : bool Automatically correct data. + replace : bool + If True and `data` is a dictionary, all existing stress period + keys not present in the dictionary will be removed. If False + (default), existing keys not in the dictionary are preserved. + This provides a way to completely replace stress period data + rather than update it. Default is False for backwards + compatibility. """ - self._set_data_record(data, key, autofill) + self._set_data_record(data, key, autofill, replace=replace) def masked_4D_arrays_itr(self): """Returns list data as an iterator of a masked 4D array.""" @@ -2339,12 +2352,24 @@ def _set_data_record( autofill=False, check_data=False, is_record=False, + replace=False, ): self._cache_model_grid = True if isinstance(data_record, dict): if "filename" not in data_record and "data" not in data_record: # each item in the dictionary is a list for one stress period # the dictionary key is the stress period the list is for + + # If replace=True, remove existing keys not in the new data + if replace and self._data_storage: + existing_keys = set(self._data_storage.keys()) + provided_keys = set(data_record.keys()) + keys_to_remove = existing_keys - provided_keys + for key_to_remove in keys_to_remove: + self.remove_transient_key(key_to_remove) + if key_to_remove in self.empty_keys: + del self.empty_keys[key_to_remove] + del_keys = [] for key, list_item in data_record.items(): list_item_record = False From 672ec91c82492876d3f6389dea2f0ce4a84ee2e8 Mon Sep 17 00:00:00 2001 From: wpbonelli Date: Wed, 10 Dec 2025 08:10:31 -0500 Subject: [PATCH 2/6] fix stale block headers --- flopy/mf6/mfpackage.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/flopy/mf6/mfpackage.py b/flopy/mf6/mfpackage.py index 702becfccd..5231e686e1 100644 --- a/flopy/mf6/mfpackage.py +++ b/flopy/mf6/mfpackage.py @@ -1286,6 +1286,32 @@ def write(self, fd, ext_file_action=ExtFileAction.copy_relative_paths): if self.structure.repeating(): repeating_datasets = self._find_repeating_datasets() for repeating_dataset in repeating_datasets: + # Clean up stale block headers that no longer have data + # This handles the case where stress periods were removed via set_data(replace=True) + # Get the set of stress periods that actually have data + active_keys = set() + for key_data in repeating_dataset.get_active_key_list(): + active_keys.add(key_data[0]) + # Also include empty keys that should be written + for key, value in repeating_dataset.empty_keys.items(): + if value: + active_keys.add(key) + + # Only clean up if we have multiple headers and active data + # This avoids breaking the initial write case where block_headers + # may have a template header with transient_key=None + if len(self.block_headers) > 1 and active_keys: + # Remove block headers for stress periods not in active_keys + headers_to_remove = [] + for i, header in enumerate(self.block_headers): + transient_key = header.get_transient_key() + if transient_key is not None and transient_key not in active_keys: + headers_to_remove.append(i) + + # Remove in reverse order to preserve indices + for i in reversed(headers_to_remove): + del self.block_headers[i] + # resolve any missing block headers self._add_missing_block_headers(repeating_dataset) for block_header in sorted(self.block_headers): From 1067261a4c41774297b8af455f67c37cc94032d8 Mon Sep 17 00:00:00 2001 From: wpbonelli Date: Wed, 10 Dec 2025 10:47:50 -0500 Subject: [PATCH 3/6] ruff, clean up test --- autotest/test_mf6_set_data_replace.py | 28 ++++++++++----------------- 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/autotest/test_mf6_set_data_replace.py b/autotest/test_mf6_set_data_replace.py index 3992d4611a..85c38fd58f 100644 --- a/autotest/test_mf6_set_data_replace.py +++ b/autotest/test_mf6_set_data_replace.py @@ -1,8 +1,6 @@ """ -Test for set_data() replace parameter (issue #2663). - -Tests that the replace parameter correctly removes stress period data -for periods not included in the provided dictionary. +Test set_data() replace parameter (issue #2663). This parameter +toggles whether .set_data() has update or replacement semantics. """ from pathlib import Path @@ -22,8 +20,6 @@ def count_stress_periods_in_file(file_path): @pytest.mark.parametrize("replace", [False, True]) def test_set_data_replace_array(function_tmpdir, replace): - """Test set_data() replace parameter with MFTransientArray (RCH package).""" - # Create a model with 48 stress periods sim_name = "test_model" sim_ws = Path(function_tmpdir) / "original" sim_ws.mkdir(exist_ok=True) @@ -118,16 +114,16 @@ def test_set_data_replace_array(function_tmpdir, replace): if replace: # With replace=True, should only have 12 stress periods - # NOTE: Currently fails due to block header persistence issue - # When fixed, this should pass assert modified_sp_count == nper_new, ( - f"Expected {nper_new} stress periods with replace=True, got {modified_sp_count}" + f"Expected {nper_new} stress periods " + f"with replace=True, got {modified_sp_count}" ) else: # With replace=False (backwards compatible), all 48 periods remain # Periods 12-47 will be written as empty periods assert modified_sp_count == nper_original, ( - f"Expected {nper_original} stress periods with replace=False, got {modified_sp_count}" + f"Expected {nper_original} stress periods " + f"with replace=False, got {modified_sp_count}" ) # Verify data values are correct for the new periods @@ -141,8 +137,6 @@ def test_set_data_replace_array(function_tmpdir, replace): @pytest.mark.parametrize("replace", [False, True]) def test_set_data_replace_list(function_tmpdir, replace): - """Test set_data() replace parameter with MFTransientList (WEL package).""" - # Create a model with 24 stress periods sim_name = "test_wel_model" sim_ws = Path(function_tmpdir) / "wel_original" sim_ws.mkdir(exist_ok=True) @@ -241,21 +235,19 @@ def test_set_data_replace_list(function_tmpdir, replace): if replace: # With replace=True, should only have 6 stress periods - # NOTE: Currently fails due to block header persistence issue assert modified_sp_count == nper_new, ( - f"Expected {nper_new} stress periods with replace=True, got {modified_sp_count}" + f"Expected {nper_new} stress periods with " + f"replace=True, got {modified_sp_count}" ) else: # With replace=False, all 24 periods remain assert modified_sp_count == nper_original, ( - f"Expected {nper_original} stress periods with replace=False, got {modified_sp_count}" + f"Expected {nper_original} stress periods with " + f"replace=False, got {modified_sp_count}" ) def test_set_data_without_replace_backwards_compatible(function_tmpdir): - """Test that set_data() without replace parameter maintains backwards compatibility.""" - # This test ensures that existing code that relies on the "update" behavior - # continues to work as expected sim_name = "test_compat" sim_ws = Path(function_tmpdir) / "compat" sim_ws.mkdir(exist_ok=True) From 36dfc9a45c0789bbbb45a07b9e11d80a8459cae9 Mon Sep 17 00:00:00 2001 From: wpbonelli Date: Wed, 10 Dec 2025 12:27:41 -0500 Subject: [PATCH 4/6] improve tests --- autotest/test_mf6_set_data_replace.py | 262 ++++++++++++-------------- 1 file changed, 120 insertions(+), 142 deletions(-) diff --git a/autotest/test_mf6_set_data_replace.py b/autotest/test_mf6_set_data_replace.py index 85c38fd58f..463bbf44b9 100644 --- a/autotest/test_mf6_set_data_replace.py +++ b/autotest/test_mf6_set_data_replace.py @@ -5,6 +5,7 @@ from pathlib import Path +import numpy as np import pytest import flopy @@ -12,44 +13,37 @@ pytestmark = pytest.mark.mf6 -def count_stress_periods_in_file(file_path): - """Count the number of 'BEGIN period' statements in a file.""" +def count_stress_periods(file_path): + """Count the number of 'BEGIN period' statements in an input file.""" with open(file_path, "r") as f: return sum(1 for line in f if line.strip().upper().startswith("BEGIN PERIOD")) -@pytest.mark.parametrize("replace", [False, True]) -def test_set_data_replace_array(function_tmpdir, replace): - sim_name = "test_model" - sim_ws = Path(function_tmpdir) / "original" - sim_ws.mkdir(exist_ok=True) +@pytest.mark.parametrize("replace", [False, True], ids=["replace", "no_replace"]) +@pytest.mark.parametrize("use_pandas", [False, True], ids=["use_pandas", "no_pandas"]) +def test_set_data_replace_array_based_pkg(function_tmpdir, replace, use_pandas): + name = "array_based" + og_ws = Path(function_tmpdir) / "original" + og_ws.mkdir(exist_ok=True) + nlay, nrow, ncol = 1, 10, 10 nper_original = 48 nper_new = 12 - # Create simulation with 48 stress periods sim = flopy.mf6.MFSimulation( - sim_name=sim_name, - sim_ws=str(sim_ws), + sim_name=name, + sim_ws=str(og_ws), exe_name="mf6", + use_pandas=use_pandas, ) - - # Create TDIS with 48 stress periods tdis = flopy.mf6.ModflowTdis( sim, nper=nper_original, perioddata=[(1.0, 1, 1.0) for _ in range(nper_original)], ) - - # Create IMS - flopy.mf6.ModflowIms(sim) - - # Create groundwater flow model - gwf = flopy.mf6.ModflowGwf(sim, modelname=sim_name) - - # Create DIS - nlay, nrow, ncol = 1, 10, 10 - flopy.mf6.ModflowGwfdis( + ims = flopy.mf6.ModflowIms(sim) + gwf = flopy.mf6.ModflowGwf(sim, modelname=name) + dis = flopy.mf6.ModflowGwfdis( gwf, nlay=nlay, nrow=nrow, @@ -59,58 +53,38 @@ def test_set_data_replace_array(function_tmpdir, replace): top=100.0, botm=0.0, ) - - # Create IC - flopy.mf6.ModflowGwfic(gwf, strt=100.0) - - # Create NPF - flopy.mf6.ModflowGwfnpf(gwf, icelltype=1, k=10.0) - - # Create OC - flopy.mf6.ModflowGwfoc( + ic = flopy.mf6.ModflowGwfic(gwf, strt=100.0) + npf = flopy.mf6.ModflowGwfnpf(gwf, icelltype=1, k=10.0) + oc = flopy.mf6.ModflowGwfoc( gwf, - budget_filerecord=f"{sim_name}.cbc", - head_filerecord=f"{sim_name}.hds", + budget_filerecord=f"{name}.cbc", + head_filerecord=f"{name}.hds", saverecord=[("HEAD", "LAST"), ("BUDGET", "LAST")], ) - - # Create RCH package with different recharge for each stress period rch_data = {kper: 0.001 + kper * 0.0001 for kper in range(nper_original)} - flopy.mf6.ModflowGwfrcha(gwf, recharge=rch_data) + rcha = flopy.mf6.ModflowGwfrcha(gwf, recharge=rch_data) - # Write the original simulation sim.write_simulation() - # Count stress periods in original file - original_rch_file = sim_ws / f"{sim_name}.rcha" - original_sp_count = count_stress_periods_in_file(original_rch_file) + original_rch_file = og_ws / f"{name}.rcha" + original_sp_count = count_stress_periods(original_rch_file) assert original_sp_count == nper_original - # Load the simulation - sim2 = flopy.mf6.MFSimulation.load(sim_ws=str(sim_ws)) - gwf2 = sim2.get_model(sim_name) - rch2 = gwf2.get_package("RCHA") - - # Create new stress period dictionary with only 12 periods + # Update RCH new_rch_data = {kper: 0.002 + kper * 0.0002 for kper in range(nper_new)} + rcha.recharge.set_data(new_rch_data, replace=replace) - # Update TDIS NPER - tdis2 = sim2.get_package("TDIS") - tdis2.nper = nper_new - tdis2.perioddata = [(1.0, 1, 1.0) for _ in range(nper_new)] - - # Use set_data() with the replace parameter - rch2.recharge.set_data(new_rch_data, replace=replace) + # Update TDIS + tdis.nper = nper_new + tdis.perioddata = [(1.0, 1, 1.0) for _ in range(nper_new)] - # Write the modified simulation - sim2_ws = Path(function_tmpdir) / f"modified_replace_{replace}" - sim2_ws.mkdir(exist_ok=True) - sim2.set_sim_path(str(sim2_ws)) - sim2.write_simulation() + mod_ws = Path(function_tmpdir) / f"modified_replace_{replace}" + mod_ws.mkdir(exist_ok=True) + sim.set_sim_path(str(mod_ws)) + sim.write_simulation() - # Count stress periods in modified file - modified_rch_file = sim2_ws / f"{sim_name}.rcha" - modified_sp_count = count_stress_periods_in_file(modified_rch_file) + modified_rch_file = mod_ws / f"{name}.rcha" + modified_sp_count = count_stress_periods(modified_rch_file) if replace: # With replace=True, should only have 12 stress periods @@ -120,53 +94,39 @@ def test_set_data_replace_array(function_tmpdir, replace): ) else: # With replace=False (backwards compatible), all 48 periods remain - # Periods 12-47 will be written as empty periods assert modified_sp_count == nper_original, ( f"Expected {nper_original} stress periods " f"with replace=False, got {modified_sp_count}" ) - # Verify data values are correct for the new periods with open(modified_rch_file, "r") as f: content = f.read() - # Check that period 1 has the new recharge value assert "0.00200000" in content or "2.00000000E-03" in content - # Check that period 12 has the new recharge value assert "0.00420000" in content or "4.20000000E-03" in content -@pytest.mark.parametrize("replace", [False, True]) -def test_set_data_replace_list(function_tmpdir, replace): - sim_name = "test_wel_model" +@pytest.mark.parametrize("replace", [False, True], ids=["replace", "no_replace"]) +@pytest.mark.parametrize("use_pandas", [False, True], ids=["use_pandas", "no_pandas"]) +def test_set_data_replace_list_based_pkg(function_tmpdir, replace, use_pandas): + name = "list_based" sim_ws = Path(function_tmpdir) / "wel_original" sim_ws.mkdir(exist_ok=True) + nlay, nrow, ncol = 1, 10, 10 nper_original = 24 nper_new = 6 - # Create simulation sim = flopy.mf6.MFSimulation( - sim_name=sim_name, - sim_ws=str(sim_ws), - exe_name="mf6", + sim_name=name, sim_ws=str(sim_ws), exe_name="mf6", use_pandas=use_pandas ) - - # Create TDIS tdis = flopy.mf6.ModflowTdis( sim, nper=nper_original, perioddata=[(1.0, 1, 1.0) for _ in range(nper_original)], ) - - # Create IMS - flopy.mf6.ModflowIms(sim) - - # Create groundwater flow model - gwf = flopy.mf6.ModflowGwf(sim, modelname=sim_name) - - # Create DIS - nlay, nrow, ncol = 1, 10, 10 - flopy.mf6.ModflowGwfdis( + ims = flopy.mf6.ModflowIms(sim) + gwf = flopy.mf6.ModflowGwf(sim, modelname=name) + dis = flopy.mf6.ModflowGwfdis( gwf, nlay=nlay, nrow=nrow, @@ -176,62 +136,42 @@ def test_set_data_replace_list(function_tmpdir, replace): top=100.0, botm=0.0, ) - - # Create IC - flopy.mf6.ModflowGwfic(gwf, strt=100.0) - - # Create NPF - flopy.mf6.ModflowGwfnpf(gwf, icelltype=1, k=10.0) - - # Create OC - flopy.mf6.ModflowGwfoc( + ic = flopy.mf6.ModflowGwfic(gwf, strt=100.0) + npf = flopy.mf6.ModflowGwfnpf(gwf, icelltype=1, k=10.0) + oc = flopy.mf6.ModflowGwfoc( gwf, - budget_filerecord=f"{sim_name}.cbc", - head_filerecord=f"{sim_name}.hds", + budget_filerecord=f"{name}.cbc", + head_filerecord=f"{name}.hds", saverecord=[("HEAD", "LAST"), ("BUDGET", "LAST")], ) - - # Create WEL package with different pumping rates for each stress period wel_data = { kper: [[(0, 5, 5), -1000.0 - kper * 10.0]] for kper in range(nper_original) } - flopy.mf6.ModflowGwfwel(gwf, stress_period_data=wel_data) + wel = flopy.mf6.ModflowGwfwel(gwf, stress_period_data=wel_data) - # Write the original simulation sim.write_simulation() - # Count stress periods in original file - original_wel_file = sim_ws / f"{sim_name}.wel" - original_sp_count = count_stress_periods_in_file(original_wel_file) + original_wel_file = sim_ws / f"{name}.wel" + original_sp_count = count_stress_periods(original_wel_file) assert original_sp_count == nper_original - # Load the simulation - sim2 = flopy.mf6.MFSimulation.load(sim_ws=str(sim_ws)) - gwf2 = sim2.get_model(sim_name) - wel2 = gwf2.get_package("WEL") - - # Create new stress period dictionary with only 6 periods + # Update WEL new_wel_data = { kper: [[(0, 5, 5), -2000.0 - kper * 20.0]] for kper in range(nper_new) } + wel.stress_period_data.set_data(new_wel_data, replace=replace) - # Update TDIS NPER - tdis2 = sim2.get_package("TDIS") - tdis2.nper = nper_new - tdis2.perioddata = [(1.0, 1, 1.0) for _ in range(nper_new)] - - # Use set_data() with the replace parameter - wel2.stress_period_data.set_data(new_wel_data, replace=replace) + # Update TDIS + tdis.nper = nper_new + tdis.perioddata = [(1.0, 1, 1.0) for _ in range(nper_new)] - # Write the modified simulation - sim2_ws = Path(function_tmpdir) / f"wel_modified_replace_{replace}" - sim2_ws.mkdir(exist_ok=True) - sim2.set_sim_path(str(sim2_ws)) - sim2.write_simulation() + mod_ws = Path(function_tmpdir) / f"wel_modified_replace_{replace}" + mod_ws.mkdir(exist_ok=True) + sim.set_sim_path(str(mod_ws)) + sim.write_simulation() - # Count stress periods in modified file - modified_wel_file = sim2_ws / f"{sim_name}.wel" - modified_sp_count = count_stress_periods_in_file(modified_wel_file) + modified_wel_file = mod_ws / f"{name}.wel" + modified_sp_count = count_stress_periods(modified_wel_file) if replace: # With replace=True, should only have 6 stress periods @@ -247,39 +187,77 @@ def test_set_data_replace_list(function_tmpdir, replace): ) -def test_set_data_without_replace_backwards_compatible(function_tmpdir): - sim_name = "test_compat" +def test_set_data_update_array_based_pkg(function_tmpdir): + name = "update_array_based" sim_ws = Path(function_tmpdir) / "compat" sim_ws.mkdir(exist_ok=True) - # Create a simple model with 10 stress periods - sim = flopy.mf6.MFSimulation(sim_name=sim_name, sim_ws=str(sim_ws), exe_name="mf6") - flopy.mf6.ModflowTdis(sim, nper=10, perioddata=[(1.0, 1, 1.0) for _ in range(10)]) - flopy.mf6.ModflowIms(sim) - gwf = flopy.mf6.ModflowGwf(sim, modelname=sim_name) - flopy.mf6.ModflowGwfdis(gwf, nlay=1, nrow=10, ncol=10) - flopy.mf6.ModflowGwfic(gwf, strt=100.0) - flopy.mf6.ModflowGwfnpf(gwf, k=10.0) - flopy.mf6.ModflowGwfoc(gwf) - - # Create RCH with initial data for periods 0-4 + sim = flopy.mf6.MFSimulation(sim_name=name, sim_ws=str(sim_ws), exe_name="mf6") + tdis = flopy.mf6.ModflowTdis( + sim, nper=10, perioddata=[(1.0, 1, 1.0) for _ in range(10)] + ) + ims = flopy.mf6.ModflowIms(sim) + gwf = flopy.mf6.ModflowGwf(sim, modelname=name) + dis = flopy.mf6.ModflowGwfdis(gwf, nlay=1, nrow=10, ncol=10) + ic = flopy.mf6.ModflowGwfic(gwf, strt=100.0) + npf = flopy.mf6.ModflowGwfnpf(gwf, k=10.0) + oc = flopy.mf6.ModflowGwfoc(gwf) + initial_data = dict.fromkeys(range(5), 0.001) rch = flopy.mf6.ModflowGwfrcha(gwf, recharge=initial_data) - # Update periods 5-9 using set_data without replace parameter - # This should ADD to the existing data, not replace it additional_data = dict.fromkeys(range(5, 10), 0.002) rch.recharge.set_data(additional_data) # replace defaults to False - # Write simulation sim.write_simulation() - # Load and verify all 10 periods are present sim2 = flopy.mf6.MFSimulation.load(sim_ws=str(sim_ws)) - gwf2 = sim2.get_model(sim_name) + gwf2 = sim2.get_model(name) rch2 = gwf2.get_package("RCHA") - # Check that both sets of periods are present for kper in range(10): data = rch2.recharge.get_data(key=kper) + assert np.allclose(data, 0.001 if kper < 5 else 0.002) + + +def test_set_data_update_list_based_pkg(function_tmpdir): + name = "update_list_based" + sim_ws = Path(function_tmpdir) / "wel_update" + sim_ws.mkdir(exist_ok=True) + + sim = flopy.mf6.MFSimulation(sim_name=name, sim_ws=str(sim_ws), exe_name="mf6") + tdis = flopy.mf6.ModflowTdis( + sim, nper=10, perioddata=[(1.0, 1, 1.0) for _ in range(10)] + ) + ims = flopy.mf6.ModflowIms(sim) + gwf = flopy.mf6.ModflowGwf(sim, modelname=name) + dis = flopy.mf6.ModflowGwfdis(gwf, nlay=1, nrow=10, ncol=10) + ic = flopy.mf6.ModflowGwfic(gwf, strt=100.0) + npf = flopy.mf6.ModflowGwfnpf(gwf, k=10.0) + oc = flopy.mf6.ModflowGwfoc(gwf) + + initial_data = {kper: [[(0, 5, 5), -1000.0]] for kper in range(5)} + wel = flopy.mf6.ModflowGwfwel(gwf, stress_period_data=initial_data) + + additional_data = {kper: [[(0, 7, 7), -2000.0]] for kper in range(5, 10)} + wel.stress_period_data.set_data(additional_data) # replace defaults to False + + sim.write_simulation() + + sim2 = flopy.mf6.MFSimulation.load(sim_ws=str(sim_ws)) + gwf2 = sim2.get_model(name) + wel2 = gwf2.get_package("WEL") + + for kper in range(10): + data = wel2.stress_period_data.get_data(key=kper) assert data is not None, f"Period {kper} should have data" + if kper < 5: + # Original data should be at (0, 5, 5) + assert len(data) == 1 + assert data[0]["cellid"] == (0, 5, 5) + assert data[0]["q"] == -1000.0 + else: + # Additional data should be at (0, 7, 7) + assert len(data) == 1 + assert data[0]["cellid"] == (0, 7, 7) + assert data[0]["q"] == -2000.0 From 1938bfe946b97d8dacf3dbbfad6626a6184d2268 Mon Sep 17 00:00:00 2001 From: wpbonelli Date: Wed, 10 Dec 2025 13:26:48 -0500 Subject: [PATCH 5/6] clean up, improve docstrings and comments --- flopy/mf6/data/mfdataarray.py | 33 +++++++++++++++------------------ flopy/mf6/data/mfdatalist.py | 33 +++++++++++++++------------------ flopy/mf6/data/mfdataplist.py | 33 +++++++++++++++------------------ flopy/mf6/mfpackage.py | 14 ++++++-------- 4 files changed, 51 insertions(+), 62 deletions(-) diff --git a/flopy/mf6/data/mfdataarray.py b/flopy/mf6/data/mfdataarray.py index 9405b64543..27d7e2a92c 100644 --- a/flopy/mf6/data/mfdataarray.py +++ b/flopy/mf6/data/mfdataarray.py @@ -1903,10 +1903,10 @@ def set_record(self, data_record, replace=False): stress period. How to define the dictionary of data and metadata is described in the MFData class's set_record method. replace : bool - If True, all existing stress period keys not present in the - dictionary will be removed. If False (default), existing keys - not in the dictionary are preserved. Default is False for - backwards compatibility. + Perform the operation with replacement semantics: all existing + stress period keys not present in the new dictionary will be + removed. If False, existing keys not in the new dictionary + will be preserved. Defaults False for backwards compatibility. """ self._set_data_record(data_record, is_record=True, replace=replace) @@ -1932,12 +1932,11 @@ def set_data(self, data, multiplier=None, layer=None, key=None, replace=False): Zero based stress period to assign data too. Does not apply if `data` is a dictionary. replace : bool - If True and `data` is a dictionary, all existing stress period - keys not present in the dictionary will be removed. If False - (default), existing keys not in the dictionary are preserved. - This provides a way to completely replace stress period data - rather than update it. Default is False for backwards - compatibility. + If True and `data` is a dictionary, perform the operation + with replacement semantics: all existing stress period keys + not present in the new dictionary will be removed. If False, + existing keys not in the new dictionary will be preserved. + Defaults False for backwards compatibility. """ self._set_data_record(data, multiplier, layer, key, replace=replace) @@ -1948,15 +1947,13 @@ def _set_data_record( # each item in the dictionary is a list for one stress period # the dictionary key is the stress period the list is for - # If replace=True, remove existing keys not in the new data + # If replacing, remove keys not in the new data if replace and self._data_storage: - existing_keys = set(self._data_storage.keys()) - provided_keys = set(data.keys()) - keys_to_remove = existing_keys - provided_keys - for key_to_remove in keys_to_remove: - self.remove_transient_key(key_to_remove) - if key_to_remove in self.empty_keys: - del self.empty_keys[key_to_remove] + keys_to_remove = set(self._data_storage.keys()) - set(data.keys()) + for k in keys_to_remove: + self.remove_transient_key(k) + if k in self.empty_keys: + del self.empty_keys[k] del_keys = [] for key, list_item in data.items(): diff --git a/flopy/mf6/data/mfdatalist.py b/flopy/mf6/data/mfdatalist.py index 2f95c75a8e..b31f47cec9 100644 --- a/flopy/mf6/data/mfdatalist.py +++ b/flopy/mf6/data/mfdatalist.py @@ -1796,10 +1796,10 @@ def set_record(self, data_record, autofill=False, check_data=True, replace=False check_data : bool Whether to verify the data replace : bool - If True, all existing stress period keys not present in the - dictionary will be removed. If False (default), existing keys - not in the dictionary are preserved. Default is False for - backwards compatibility. + Perform the operation with replacement semantics: all existing + stress period keys not present in the new dictionary will be + removed. If False, existing keys not in the new dictionary + will be preserved. Defaults False for backwards compatibility. """ self._set_data_record( data_record, @@ -1826,12 +1826,11 @@ def set_data(self, data, key=None, autofill=False, replace=False): autofill : bool Automatically correct data. replace : bool - If True and `data` is a dictionary, all existing stress period - keys not present in the dictionary will be removed. If False - (default), existing keys not in the dictionary are preserved. - This provides a way to completely replace stress period data - rather than update it. Default is False for backwards - compatibility. + If True and `data` is a dictionary, perform the operation + with replacement semantics: all existing stress period keys + not present in the new dictionary will be removed. If False, + existing keys not in the new dictionary will be preserved. + Defaults False for backwards compatibility. """ self._set_data_record(data, key, autofill, replace=replace) @@ -1844,15 +1843,13 @@ def _set_data_record( # each item in the dictionary is a list for one stress period # the dictionary key is the stress period the list is for - # If replace=True, remove existing keys not in the new data + # If replacing, remove keys not in the new data if replace and self._data_storage: - existing_keys = set(self._data_storage.keys()) - provided_keys = set(data.keys()) - keys_to_remove = existing_keys - provided_keys - for key_to_remove in keys_to_remove: - self.remove_transient_key(key_to_remove) - if key_to_remove in self.empty_keys: - del self.empty_keys[key_to_remove] + keys_to_remove = set(self._data_storage.keys()) - set(data.keys()) + for k in keys_to_remove: + self.remove_transient_key(k) + if k in self.empty_keys: + del self.empty_keys[k] del_keys = [] for key, list_item in data.items(): diff --git a/flopy/mf6/data/mfdataplist.py b/flopy/mf6/data/mfdataplist.py index 50e55ada6f..5b9e208252 100644 --- a/flopy/mf6/data/mfdataplist.py +++ b/flopy/mf6/data/mfdataplist.py @@ -2287,10 +2287,10 @@ def set_record(self, record, autofill=False, check_data=True, replace=False): check_data : bool Whether to verify the data replace : bool - If True, all existing stress period keys not present in the - dictionary will be removed. If False (default), existing keys - not in the dictionary are preserved. Default is False for - backwards compatibility. + Perform the operation with replacement semantics: all existing + stress period keys not present in the new dictionary will be + removed. If False, existing keys not in the new dictionary + will be preserved. Defaults False for backwards compatibility. """ self._set_data_record( record, @@ -2317,12 +2317,11 @@ def set_data(self, data, key=None, autofill=False, replace=False): autofill : bool Automatically correct data. replace : bool - If True and `data` is a dictionary, all existing stress period - keys not present in the dictionary will be removed. If False - (default), existing keys not in the dictionary are preserved. - This provides a way to completely replace stress period data - rather than update it. Default is False for backwards - compatibility. + If True and `data` is a dictionary, perform the operation + with replacement semantics: all existing stress period keys + not present in the new dictionary will be removed. If False, + existing keys not in the new dictionary will be preserved. + Defaults False for backwards compatibility. """ self._set_data_record(data, key, autofill, replace=replace) @@ -2360,15 +2359,13 @@ def _set_data_record( # each item in the dictionary is a list for one stress period # the dictionary key is the stress period the list is for - # If replace=True, remove existing keys not in the new data + # If replacing, remove keys not in the new data if replace and self._data_storage: - existing_keys = set(self._data_storage.keys()) - provided_keys = set(data_record.keys()) - keys_to_remove = existing_keys - provided_keys - for key_to_remove in keys_to_remove: - self.remove_transient_key(key_to_remove) - if key_to_remove in self.empty_keys: - del self.empty_keys[key_to_remove] + keys_to_remove = set(self._data_storage.keys()) - set(data_record.keys()) + for k in keys_to_remove: + self.remove_transient_key(k) + if k in self.empty_keys: + del self.empty_keys[k] del_keys = [] for key, list_item in data_record.items(): diff --git a/flopy/mf6/mfpackage.py b/flopy/mf6/mfpackage.py index 5231e686e1..c360f2032e 100644 --- a/flopy/mf6/mfpackage.py +++ b/flopy/mf6/mfpackage.py @@ -1286,26 +1286,24 @@ def write(self, fd, ext_file_action=ExtFileAction.copy_relative_paths): if self.structure.repeating(): repeating_datasets = self._find_repeating_datasets() for repeating_dataset in repeating_datasets: - # Clean up stale block headers that no longer have data - # This handles the case where stress periods were removed via set_data(replace=True) - # Get the set of stress periods that actually have data + # Get stress periods that actually have data, + # including deliberately empty stress periods active_keys = set() for key_data in repeating_dataset.get_active_key_list(): active_keys.add(key_data[0]) - # Also include empty keys that should be written for key, value in repeating_dataset.empty_keys.items(): if value: active_keys.add(key) # Only clean up if we have multiple headers and active data # This avoids breaking the initial write case where block_headers - # may have a template header with transient_key=None + # may have a template header with transient_key=None. Otherwise we + # get IndexError when _build_repeating_header tries to use index -1. if len(self.block_headers) > 1 and active_keys: - # Remove block headers for stress periods not in active_keys headers_to_remove = [] for i, header in enumerate(self.block_headers): - transient_key = header.get_transient_key() - if transient_key is not None and transient_key not in active_keys: + k = header.get_transient_key() + if k is not None and k not in active_keys: headers_to_remove.append(i) # Remove in reverse order to preserve indices From 37dd34880d03f4047c44ca81e8b751b71b8f7c8f Mon Sep 17 00:00:00 2001 From: wpbonelli Date: Thu, 11 Dec 2025 16:10:34 -0500 Subject: [PATCH 6/6] fix for multiple datasets in the same block --- flopy/mf6/mfpackage.py | 49 ++++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/flopy/mf6/mfpackage.py b/flopy/mf6/mfpackage.py index c360f2032e..270a9bb655 100644 --- a/flopy/mf6/mfpackage.py +++ b/flopy/mf6/mfpackage.py @@ -1285,31 +1285,38 @@ def write(self, fd, ext_file_action=ExtFileAction.copy_relative_paths): return if self.structure.repeating(): repeating_datasets = self._find_repeating_datasets() + + # First, collect active keys from ALL datasets in this block + # This is important for blocks with multiple datasets (e.g., storage package + # has both "steady-state" and "transient" datasets) that share block_headers. + # We need to preserve headers that are active in ANY dataset, not just the + # current one being processed. + all_active_keys = set() for repeating_dataset in repeating_datasets: - # Get stress periods that actually have data, - # including deliberately empty stress periods - active_keys = set() for key_data in repeating_dataset.get_active_key_list(): - active_keys.add(key_data[0]) + all_active_keys.add(key_data[0]) for key, value in repeating_dataset.empty_keys.items(): if value: - active_keys.add(key) - - # Only clean up if we have multiple headers and active data - # This avoids breaking the initial write case where block_headers - # may have a template header with transient_key=None. Otherwise we - # get IndexError when _build_repeating_header tries to use index -1. - if len(self.block_headers) > 1 and active_keys: - headers_to_remove = [] - for i, header in enumerate(self.block_headers): - k = header.get_transient_key() - if k is not None and k not in active_keys: - headers_to_remove.append(i) - - # Remove in reverse order to preserve indices - for i in reversed(headers_to_remove): - del self.block_headers[i] - + all_active_keys.add(key) + + # Clean up stale block headers once, using combined active keys from all datasets + # Only clean up if we have multiple headers and active data. + # This avoids breaking the initial write case where block_headers + # may have a template header with transient_key=None. Otherwise we + # get IndexError when _build_repeating_header tries to use index -1. + if len(self.block_headers) > 1 and all_active_keys: + headers_to_remove = [] + for i, header in enumerate(self.block_headers): + k = header.get_transient_key() + if k is not None and k not in all_active_keys: + headers_to_remove.append(i) + + # Remove in reverse order to preserve indices + for i in reversed(headers_to_remove): + del self.block_headers[i] + + # Now add missing block headers for each dataset + for repeating_dataset in repeating_datasets: # resolve any missing block headers self._add_missing_block_headers(repeating_dataset) for block_header in sorted(self.block_headers):