Skip to content

Commit 0dcae52

Browse files
committed
Rearranged private functions
1 parent 91bd7ab commit 0dcae52

File tree

3 files changed

+197
-133
lines changed

3 files changed

+197
-133
lines changed

pySWATPlus/calibration.py

Lines changed: 57 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class Calibration(pymoo.core.problem.Problem): # type: ignore[misc]
8686
8787
!!! note
8888
The sub-key `usecols` should **not** be included here. Although no error will be raised, it will be ignored during class initialization
89-
because the `sim_col` sub-key from the `objectives` input is automatically used as `usecols`. Including it manually has no effect.
89+
because the `sim_col` sub-key from the `objective_config` input is automatically used as `usecols`. Including it manually has no effect.
9090
9191
```python
9292
extract_data = {
@@ -124,7 +124,7 @@ class Calibration(pymoo.core.problem.Problem): # type: ignore[misc]
124124
}
125125
```
126126
127-
objectives (dict[str, dict[str, str]]): A nested dictionary specifying objectives configuration. The top-level keys
127+
objective_config (dict[str, dict[str, str]]): A nested dictionary specifying the configuration of each objective. The top-level keys
128128
are the same as the keys of `extract_data` (e.g., `channel_sd_day.txt`). Each key must map to a non-empty dictionary containing the following sub-keys:
129129
130130
- `sim_col` (str): **Required.** Name of the column containing simulated values.
@@ -142,7 +142,7 @@ class Calibration(pymoo.core.problem.Problem): # type: ignore[misc]
142142
Avoid using `MARE` if `obs_col` contains zero values, as it will cause a division-by-zero error.
143143
144144
```python
145-
objectives = {
145+
objective_config = {
146146
'channel_sd_day.txt': {
147147
'sim_col': 'flo_out',
148148
'obs_col': 'discharge',
@@ -186,7 +186,7 @@ def __init__(
186186
txtinout_dir: str | pathlib.Path,
187187
extract_data: dict[str, dict[str, typing.Any]],
188188
observe_data: dict[str, dict[str, str]],
189-
objectives: dict[str, dict[str, str]],
189+
objective_config: dict[str, dict[str, str]],
190190
algorithm: str,
191191
n_gen: int,
192192
pop_size: int,
@@ -205,11 +205,17 @@ def __init__(
205205
calsim_dir = pathlib.Path(calsim_dir).resolve()
206206
txtinout_dir = pathlib.Path(txtinout_dir).resolve()
207207

208-
# Check same top-level keys in dictionaries
209-
if not (extract_data.keys() == observe_data.keys() == objectives.keys()):
210-
raise KeyError(
211-
'Mismatch of key names. Ensure extract_data, observe_data, and objectives have identical top-level keys.'
212-
)
208+
# Validate same top-level keys in dictionaries
209+
validators._dict_key_equal(
210+
extract_data=extract_data,
211+
observe_data=observe_data,
212+
objective_config=objective_config
213+
)
214+
215+
# Map each objective key to its DataFrame key name (text before the first '.' plus '_df')
216+
df_key = {
217+
obj: obj.split('.')[0] + '_df' for obj in objective_config
218+
}
213219

214220
# Validate initialization of TxtinoutReader class
215221
tmp_reader = TxtinoutReader(
@@ -232,29 +238,31 @@ def __init__(
232238
)
233239

234240
# Validate objectives configuration
235-
self._validate_objectives_config(
236-
objectives=objectives
241+
validators._metric_config(
242+
input_dict=objective_config,
243+
var_name='objective_config'
237244
)
245+
for obj in objective_config:
246+
if objective_config[obj]['indicator'] == 'PBIAS':
247+
raise ValueError(
248+
'Indicator "PBIAS" is invalid in objective_config; it lacks a defined optimization direction'
249+
)
238250

239251
# Validate observe_data configuration
240-
self._validate_observe_data_config(
252+
validators._observe_data_config(
241253
observe_data=observe_data
242254
)
243255

244256
# Dictionary of observed DataFrames
245-
observe_dict = {}
246-
for key in observe_data:
247-
key_df = utils._df_observe(
248-
obs_file=pathlib.Path(observe_data[key]['obs_file']).resolve(),
249-
date_format=observe_data[key]['date_format'],
250-
obs_col=objectives[key]['obs_col']
251-
)
252-
key_df.columns = ['date', 'obs']
253-
observe_dict[key.split('.')[0] + '_df'] = key_df
257+
observe_dict = utils._observe_data_dict(
258+
observe_data=observe_data,
259+
metric_config=objective_config,
260+
df_key=df_key
261+
)
254262

255263
# Validate extract_data configuration
256264
for key in extract_data:
257-
extract_data[key]['usecols'] = [objectives[key]['sim_col']]
265+
extract_data[key]['usecols'] = [objective_config[key]['sim_col']]
258266
validators._extract_data_config(
259267
extract_data=extract_data
260268
)
@@ -274,8 +282,9 @@ def __init__(
274282
# Initialize parameters
275283
self.params_bounds = params_bounds
276284
self.extract_data = extract_data
277-
self.objectives = objectives
285+
self.objective_config = objective_config
278286
self.var_names = var_names
287+
self.df_key = df_key
279288
self.observe_dict = observe_dict
280289
self.calsim_dir = calsim_dir
281290
self.txtinout_dir = txtinout_dir
@@ -289,7 +298,7 @@ def __init__(
289298
# Access properties and methods from Problem class
290299
super().__init__(
291300
n_var=len(params_bounds),
292-
n_obj=len(objectives),
301+
n_obj=len(objective_config),
293302
xl=numpy.array(var_lb),
294303
xu=numpy.array(var_ub)
295304
)
@@ -350,16 +359,12 @@ def _evaluate(
350359
# Simulation output for the population
351360
pop_sim = cpu_dict[tuple(pop)]
352361
# Iterate objectives
353-
for obj in self.objectives:
354-
# Objective indicator
355-
obj_ind = self.objectives[obj]['indicator']
356-
# Objective key name to extract simulated and obseved DataFrames
357-
obj_key = obj.split('.')[0] + '_df'
362+
for obj in self.objective_config:
358363
# Simulated DataFrame
359-
sim_df = pop_sim[obj_key]
364+
sim_df = pop_sim[self.df_key[obj]]
360365
sim_df.columns = ['date', 'sim']
361366
# Observed DataFrame
362-
obs_df = self.observe_dict[obj_key]
367+
obs_df = self.observe_dict[self.df_key[obj]]
363368
# Merge simulated and observed DataFrames by 'date' column
364369
merge_df = sim_df.merge(
365370
right=obs_df,
@@ -372,6 +377,7 @@ def _evaluate(
372377
norm_col='obs'
373378
)
374379
# Indicator method from abbreviation
380+
obj_ind = self.objective_config[obj]['indicator']
375381
indicator_method = getattr(
376382
PerformanceMetrics(),
377383
f'compute_{obj_ind.lower()}'
@@ -414,94 +420,6 @@ def _objectives_directions(
414420

415421
return objs_dirs
416422

417-
def _validate_observe_data_config(
418-
self,
419-
observe_data: dict[str, dict[str, str]],
420-
) -> None:
421-
'''
422-
Validate `observe_data` configuration.
423-
'''
424-
425-
# List of valid sub-keys of sub-dictionaries
426-
valid_subkeys = [
427-
'obs_file',
428-
'date_format'
429-
]
430-
431-
# Iterate dictionary
432-
for file_key, file_dict in observe_data.items():
433-
# Check type of a sub-dictionary
434-
if not isinstance(file_dict, dict):
435-
raise TypeError(
436-
f'Expected "{file_key}" in observe_data must be a dictionary, '
437-
f'but got type "{type(file_dict).__name__}"'
438-
)
439-
# Check sub-dictionary length
440-
if len(file_dict) != 2:
441-
raise ValueError(
442-
f'Length of "{file_key}" sub-dictionary in observe_data must be 2, '
443-
f'but got {len(file_dict)}'
444-
)
445-
# Iterate sub-key
446-
for sub_key in file_dict:
447-
# Check valid sub-key
448-
if sub_key not in valid_subkeys:
449-
raise KeyError(
450-
f'Invalid sub-key "{sub_key}" for "{file_key}" in observe_data; '
451-
f'expected sub-keys are {json.dumps(valid_subkeys)}'
452-
)
453-
454-
return None
455-
456-
def _validate_objectives_config(
457-
self,
458-
objectives: dict[str, dict[str, str]],
459-
) -> None:
460-
'''
461-
Validate `objectives` configuration.
462-
'''
463-
464-
# List of valid sub-keys of sub-dictionaries
465-
valid_subkeys = [
466-
'sim_col',
467-
'obs_col',
468-
'indicator'
469-
]
470-
471-
valid_indicators = [
472-
key for key in PerformanceMetrics().indicator_names if key != 'PBIAS'
473-
]
474-
475-
# Iterate dictionary
476-
for file_key, file_dict in objectives.items():
477-
# Check type of a sub-dictionary
478-
if not isinstance(file_dict, dict):
479-
raise TypeError(
480-
f'Expected "{file_key}" in "objectives" must be a dictionary, '
481-
f'but got type "{type(file_dict).__name__}"'
482-
)
483-
# Check sub-dictionary length
484-
if len(file_dict) != 3:
485-
raise ValueError(
486-
f'Length of "{file_key}" sub-dictionary in "objectives" must be 3, '
487-
f'but got {len(file_dict)}'
488-
)
489-
# Iterate sub-key
490-
for sub_key in file_dict:
491-
# Check valid sub-key
492-
if sub_key not in valid_subkeys:
493-
raise KeyError(
494-
f'Invalid sub-key "{sub_key}" for "{file_key}" in "objectives"; '
495-
f'expected sub-keys are {json.dumps(valid_subkeys)}'
496-
)
497-
if sub_key == 'indicator' and file_dict[sub_key] not in valid_indicators:
498-
raise ValueError(
499-
f'Invalid "indicator" value "{file_dict[sub_key]}" for "{file_key}" in "objectives"; '
500-
f'expected indicators are {valid_indicators}'
501-
)
502-
503-
return None
504-
505423
def _algorithm_class(
506424
self,
507425
algorithm: str
@@ -510,6 +428,9 @@ def _algorithm_class(
510428
Retrieve the optimization algorithm class from the `pymoo` package.
511429
'''
512430

431+
single_obj = ['GA', 'DE']
432+
multi_obj = ['NSGA2']
433+
513434
# Dictionary mapping between algorithm name and module
514435
api_module = {
515436
'GA': importlib.import_module('pymoo.algorithms.soo.nonconvex.ga'),
@@ -523,6 +444,12 @@ def _algorithm_class(
523444
f'Invalid algorithm "{algorithm}"; valid names are {list(api_module.keys())}'
524445
)
525446

447+
# Ensure a single-objective algorithm is not used with multiple objectives
448+
if len(self.objective_config) >= 2 and algorithm in single_obj:
449+
raise ValueError(
450+
f'Algorithm "{algorithm}" cannot handle multiple objectives; use one of {multi_obj}'
451+
)
452+
526453
# Algorithm class
527454
alg_class = typing.cast(
528455
type,
@@ -540,15 +467,15 @@ def parameter_optimization(
540467
This method executes the optimization process and returns a dictionary containing the optimized
541468
parameters, corresponding objective values, and total execution time.
542469
543-
Two JSON files are saved in the input directory `calsim_dir`.
470+
The following JSON files are saved in `calsim_dir`:
544471
545-
The file `optimization_history.json` stores the optimization history. Each key in this
546-
file is an integer starting from 1, representing the generation number. The corresponding value
547-
is a sub-dictionary with two keys: `pop` for the population data (decision variables) and `obj`
548-
for the objective function values. This file is useful for analyzing optimization progress,
549-
convergence trends, performance indicators, and visualization.
472+
- `optimization_history.json`: A dictionary containing the optimization history. Each key in this
473+
file is an integer starting from 1, representing the generation number. The corresponding value
474+
is a sub-dictionary with two keys: `pop` for the population data (decision variables) and `obj`
475+
for the objective function values. This file is useful for analyzing optimization progress,
476+
convergence trends, performance indicators, and visualization.
550477
551-
The file `optimization_result.json` contains the final output dictionary described below.
478+
- `optimization_result.json`: The final output dictionary described below.
552479
553480
Returns:
554481
Dictionary with the following keys:
@@ -599,7 +526,7 @@ def parameter_optimization(
599526
# Sign of objective directions
600527
objs_dirs = self._objectives_directions()
601528
dir_list = [
602-
objs_dirs[v['indicator']] for k, v in self.objectives.items()
529+
objs_dirs[v['indicator']] for k, v in self.objective_config.items()
603530
]
604531
dir_sign = numpy.where(numpy.array(dir_list) == 'max', -1, 1)
605532

@@ -610,18 +537,16 @@ def parameter_optimization(
610537
'pop': gen.pop.get('X').tolist(),
611538
'obj': (gen.pop.get('F') * dir_sign).tolist()
612539
}
613-
json_file = self.calsim_dir / 'optimization_history.json'
614-
with open(json_file, 'w') as output_write:
540+
with open(self.calsim_dir / 'optimization_history.json', 'w') as output_write:
615541
json.dump(opt_hist, output_write, indent=4)
616542

617543
# Optimized output of parameters, objectives, and execution times
618-
required_time = round(result.exec_time)
619544
opt_output = {
620545
'algorithm': self.algorithm,
621546
'generation': self.n_gen,
622547
'population': self.pop_size,
623548
'total_simulation': self.pop_size * self.n_gen,
624-
'time_sec': required_time,
549+
'time_sec': round(result.exec_time),
625550
'variables': result.X,
626551
'objectives': result.F * dir_sign
627552
}
@@ -631,8 +556,7 @@ def parameter_optimization(
631556
save_output = {
632557
k: v.tolist() if k.startswith(('var', 'obj')) else v for k, v in save_output.items()
633558
}
634-
json_file = self.calsim_dir / 'optimization_result.json'
635-
with open(json_file, 'w') as output_write:
559+
with open(self.calsim_dir / 'optimization_result.json', 'w') as output_write:
636560
json.dump(save_output, output_write, indent=4)
637561

638562
return opt_output

pySWATPlus/utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,3 +393,25 @@ def _parameters_name_with_counter(
393393
name_counter.append(f'{p_name}|{current_count[p_name]}')
394394

395395
return name_counter
396+
397+
398+
def _observe_data_dict(
399+
observe_data: dict[str, dict[str, str]],
400+
metric_config: dict[str, dict[str, str]],
401+
df_key: dict[str, str]
402+
) -> dict[str, pandas.DataFrame]:
403+
'''
404+
Generate a dictionary mapping each entry in `observe_data` to its corresponding `DataFrame`.
405+
'''
406+
407+
observe_dict = {}
408+
for obs in observe_data:
409+
obs_df = _df_observe(
410+
obs_file=pathlib.Path(observe_data[obs]['obs_file']).resolve(),
411+
date_format=observe_data[obs]['date_format'],
412+
obs_col=metric_config[obs]['obs_col']
413+
)
414+
obs_df.columns = ['date', 'obs']
415+
observe_dict[df_key[obs]] = obs_df
416+
417+
return observe_dict

0 commit comments

Comments
 (0)