Skip to content

Commit 62660be

Browse files
committed
Refactor tensors storing logic
Signed-off-by: Agrawal, Kush <kush.agrawal@intel.com>
1 parent 2987d29 commit 62660be

File tree

7 files changed

+363
-397
lines changed

7 files changed

+363
-397
lines changed

Task_1/FeTS_Challenge.py

Lines changed: 16 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -336,26 +336,19 @@ def clipped_aggregation(local_tensors,
336336
clip_to_percentile = 80
337337

338338
# first, we need to determine how much each local update has changed the tensor from the previous value
339-
# we'll use the tensor_db search function to find the
340-
previous_tensor_value = tensor_db.search(tensor_name=tensor_name, fl_round=fl_round, tags=('trained',), origin='aggregator')
341-
logger.info(f"Tensor Values {previous_tensor_value}")
342-
logger.info(f"Tensor Values Shape {previous_tensor_value.shape[0]}")
339+
# we'll use the tensor_db retrieve function to find the previous tensor value
340+
previous_tensor_value = tensor_db.retrieve(tensor_name=tensor_name, origin='aggregator', fl_round=fl_round - 1, tags=('aggregated',))
343341

344-
if previous_tensor_value.shape[0] > 1:
345-
logger.info(previous_tensor_value)
346-
raise ValueError(f'found multiple matching tensors for {tensor_name}, tags=(model,), origin=aggregator')
347-
348-
if previous_tensor_value.shape[0] < 1:
342+
if previous_tensor_value is None:
349343
# no previous tensor, so just return the weighted average
344+
logger.info(f"previous_tensor_value is None")
350345
return weighted_average_aggregation(local_tensors,
351346
tensor_db,
352347
tensor_name,
353348
fl_round,
354349
collaborators_chosen_each_round,
355350
collaborator_times_per_round)
356351

357-
previous_tensor_value = previous_tensor_value.nparray.iloc[0]
358-
359352
# compute the deltas for each collaborator
360353
deltas = [t.tensor - previous_tensor_value for t in local_tensors]
361354

@@ -428,21 +421,20 @@ def FedAvgM_Selection(local_tensors,
428421
if tensor_name not in tensor_db.search(tags=('weight_speeds',))['tensor_name']:
429422
#weight_speeds[tensor_name] = np.zeros_like(local_tensors[0].tensor) # weight_speeds[tensor_name] = np.zeros(local_tensors[0].tensor.shape)
430423
tensor_db.store(
431-
tensor_name=tensor_name,
424+
tensor_name=tensor_name,
432425
tags=('weight_speeds',),
433426
nparray=np.zeros_like(local_tensors[0].tensor),
434427
)
428+
435429
return new_tensor_weight
436430
else:
437431
if tensor_name.endswith("weight") or tensor_name.endswith("bias"):
438432
# Calculate aggregator's last value
439433
previous_tensor_value = None
440434
for _, record in tensor_db.iterrows():
441-
print(f'record tags {record["tags"]} record round {record["round"]} record tensor_name {record["tensor_name"]}')
442-
print(f'fl_round {fl_round} tensor_name {tensor_name}')
443-
if (record['round'] == fl_round
435+
if (record['round'] == fl_round - 1 # Fetching aggregated value for previous round
444436
and record["tensor_name"] == tensor_name
445-
and record["tags"] == ("aggregated",)):
437+
and record["tags"] == ('aggregated',)):
446438
previous_tensor_value = record['nparray']
447439
break
448440

@@ -457,7 +449,7 @@ def FedAvgM_Selection(local_tensors,
457449

458450
if tensor_name not in tensor_db.search(tags=('weight_speeds',))['tensor_name']:
459451
tensor_db.store(
460-
tensor_name=tensor_name,
452+
tensor_name=tensor_name,
461453
tags=('weight_speeds',),
462454
nparray=np.zeros_like(local_tensors[0].tensor),
463455
)
@@ -481,7 +473,7 @@ def FedAvgM_Selection(local_tensors,
481473
new_tensor_weight_speed = momentum * tensor_weight_speed + average_deltas # fix delete (1-momentum)
482474

483475
tensor_db.store(
484-
tensor_name=tensor_name,
476+
tensor_name=tensor_name,
485477
tags=('weight_speeds',),
486478
nparray=new_tensor_weight_speed
487479
)
@@ -516,7 +508,7 @@ def FedAvgM_Selection(local_tensors,
516508

517509

518510
# change any of these you wish to your custom functions. You may leave defaults if you wish.
519-
aggregation_function = FedAvgM_Selection
511+
aggregation_function = weighted_average_aggregation
520512
choose_training_collaborators = all_collaborators_train
521513
training_hyper_parameters_for_round = constant_hyper_parameters
522514

@@ -525,26 +517,26 @@ def FedAvgM_Selection(local_tensors,
525517
# to those you specify immediately above. Changing the below value to False will change
526518
# this fact, excluding the three hausdorff measurements. As hausdorff distance is
527519
# expensive to compute, excluding them will speed up your experiments.
528-
include_validation_with_hausdorff=True #TODO change it to True
520+
include_validation_with_hausdorff=True
529521

530522
# We encourage participants to experiment with partitioning_1 and partitioning_2, as well as to create
531523
# other partitionings to test your changes for generalization to multiple partitionings.
532524
#institution_split_csv_filename = 'partitioning_1.csv'
533525
institution_split_csv_filename = 'small_split.csv'
534526

535527
# change this to point to the parent directory of the data
536-
brats_training_data_parent_dir = '/home/ad_kagrawa2/Data/MICCAI_FeTS2022_TrainingData'
528+
brats_training_data_parent_dir = '/raid/datasets/FeTS22/MICCAI_FeTS2022_TrainingData'
537529

538530
# increase this if you need a longer history for your algorithms
539531
# decrease this if you need to reduce system RAM consumption
540-
db_store_rounds = 1 #TODO store the tensor db for these many rounds
532+
db_store_rounds = 1
541533

542534
# this is passed to PyTorch, so set it accordingly for your system
543535
device = 'cpu'
544536

545537
# you'll want to increase this most likely. You can set it as high as you like,
546538
# however, the experiment will exit once the simulated time exceeds one week.
547-
rounds_to_train = 2 #TODO change it to 5 before merging
539+
rounds_to_train = 5
548540

549541
# (bool) Determines whether checkpoints should be saved during the experiment.
550542
# The checkpoints can grow quite large (5-10GB) so only the latest will be saved when this parameter is enabled
@@ -612,7 +604,7 @@ def FedAvgM_Selection(local_tensors,
612604
# the data you want to run inference over (assumed to be the experiment that just completed)
613605

614606
#data_path = </PATH/TO/CHALLENGE_VALIDATION_DATA>
615-
data_path = '/home/ad_kagrawa2/Data/MICCAI_FeTS2022_ValidationData'
607+
data_path = '/raid/datasets/FeTS22/MICCAI_FeTS2022_ValidationData'
616608
validation_csv_filename = 'validation.csv'
617609

618610
# you can keep these the same if you wish

Task_1/fets_challenge/checkpoint_utils.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,6 @@ def save_checkpoint(checkpoint_folder, agg_tensor_db,
2828
best_dice_over_time_auc,
2929
collaborators_chosen_each_round,
3030
collaborator_times_per_round,
31-
tensor_keys_per_col,
3231
experiment_results,
3332
summary):
3433
"""
@@ -39,7 +38,7 @@ def save_checkpoint(checkpoint_folder, agg_tensor_db,
3938
with open(f'checkpoint/{checkpoint_folder}/state.pkl', 'wb') as f:
4039
pickle.dump([collaborator_names, round_num, collaborator_time_stats, total_simulated_time,
4140
best_dice, best_dice_over_time_auc, collaborators_chosen_each_round,
42-
collaborator_times_per_round, tensor_keys_per_col, experiment_results, summary], f)
41+
collaborator_times_per_round, experiment_results, summary], f)
4342

4443
def load_checkpoint(checkpoint_folder):
4544
"""

Task_1/fets_challenge/config/gandlf_config.yaml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@ model:
3131
final_layer: softmax
3232
ignore_label_validation: null
3333
norm_type: instance
34+
num_channels: 4
3435
nested_training:
3536
testing: 1
3637
validation: -5
@@ -56,7 +57,7 @@ scaling_factor: 1
5657
scheduler:
5758
type: triangle_modified
5859
track_memory_usage: false
59-
verbose: True
60+
verbose: False
6061
version:
6162
maximum: 0.1.0
6263
minimum: 0.0.14

Task_1/fets_challenge/experiment.py

Lines changed: 35 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -20,27 +20,28 @@
2020
from openfl.experimental.workflow.interface import Aggregator, Collaborator
2121
from openfl.experimental.workflow.runtime import LocalRuntime
2222

23+
from GANDLF.config_manager import ConfigManager
24+
2325
logger = getLogger(__name__)
2426
# This catches PyTorch UserWarnings for CPU
2527
warnings.filterwarnings("ignore", category=UserWarning)
2628

27-
def aggregator_private_attributes(
28-
aggregation_type, collaborator_names, db_store_rounds):
29-
return {"aggregation_type" : aggregation_type,
30-
"collaborator_names": collaborator_names,
31-
"checkpoint_folder":None,
32-
"db_store_rounds":db_store_rounds
33-
}
34-
35-
36-
def collaborator_private_attributes(
37-
index, gandlf_config, train_csv_path, val_csv_path):
38-
return {
39-
"index": index,
40-
"gandlf_config": gandlf_config,
41-
"train_csv_path": train_csv_path,
42-
"val_csv_path": val_csv_path
43-
}
29+
def aggregator_private_attributes(aggregation_type, collaborator_names, db_store_rounds):
30+
return {
31+
"aggregation_type" : aggregation_type,
32+
"collaborator_names": collaborator_names,
33+
"checkpoint_folder":None,
34+
"db_store_rounds":db_store_rounds,
35+
"agg_tensor_dict":{}
36+
}
37+
38+
39+
def collaborator_private_attributes(index, train_csv_path, val_csv_path):
40+
return {
41+
"index": index,
42+
"train_csv_path": train_csv_path,
43+
"val_csv_path": val_csv_path
44+
}
4445

4546

4647
def run_challenge_experiment(aggregation_function,
@@ -70,12 +71,20 @@ def run_challenge_experiment(aggregation_function,
7071
0.8,
7172
gandlf_csv_path)
7273

73-
print(f'Collaborator names for experiment : {collaborator_names}')
74+
logger.info(f'Collaborator names for experiment : {collaborator_names}')
7475

7576
aggregation_wrapper = CustomAggregationWrapper(aggregation_function)
7677

7778
transformed_csv_dict = extract_csv_partitions(os.path.join(work, 'gandlf_paths.csv'))
7879

80+
gandlf_conf = {}
81+
if isinstance(gandlf_config_path, str) and os.path.exists(gandlf_config_path):
82+
gandlf_conf = ConfigManager(gandlf_config_path)
83+
elif isinstance(gandlf_config_path, dict):
84+
gandlf_conf = gandlf_config_path
85+
else:
86+
exit("GANDLF config file not found. Exiting...")
87+
7988
collaborators = []
8089
for idx, col in enumerate(collaborator_names):
8190
col_dir = os.path.join(work, 'data', str(col))
@@ -96,9 +105,8 @@ def run_challenge_experiment(aggregation_function,
96105
# with ray backend with 2 collaborators
97106
num_cpus=4.0,
98107
num_gpus=0.0,
99-
# arguments required to pass to callable
108+
# private arguments required to pass to callable
100109
index=idx,
101-
gandlf_config=gandlf_config_path,
102110
train_csv_path=train_csv_path,
103111
val_csv_path=val_csv_path
104112
)
@@ -108,6 +116,7 @@ def run_challenge_experiment(aggregation_function,
108116
private_attributes_callable=aggregator_private_attributes,
109117
num_cpus=4.0,
110118
num_gpus=0.0,
119+
# private arguments required to pass to callable
111120
collaborator_names=collaborator_names,
112121
aggregation_type=aggregation_wrapper,
113122
db_store_rounds=db_store_rounds)
@@ -119,10 +128,12 @@ def run_challenge_experiment(aggregation_function,
119128
logger.info(f"Local runtime collaborators = {local_runtime.collaborators}")
120129

121130
params_dict = {"include_validation_with_hausdorff": include_validation_with_hausdorff,
122-
"choose_training_collaborators": choose_training_collaborators,
123-
"training_hyper_parameters_for_round": training_hyper_parameters_for_round,
124-
"restore_from_checkpoint_folder": restore_from_checkpoint_folder,
125-
"save_checkpoints": save_checkpoints}
131+
"use_pretrained_model": use_pretrained_model,
132+
"gandlf_config": gandlf_conf,
133+
"choose_training_collaborators": choose_training_collaborators,
134+
"training_hyper_parameters_for_round": training_hyper_parameters_for_round,
135+
"restore_from_checkpoint_folder": restore_from_checkpoint_folder,
136+
"save_checkpoints": save_checkpoints}
126137

127138
model = FeTSChallengeModel()
128139
flflow = FeTSFederatedFlow(
@@ -134,15 +145,4 @@ def run_challenge_experiment(aggregation_function,
134145

135146
flflow.runtime = local_runtime
136147
flflow.run()
137-
138-
# #TODO [Workflow - API] -> Commenting as pretrained model is not used.
139-
# if use_pretrained_model:
140-
# if device == 'cpu':
141-
# checkpoint = torch.load(f'{root}/pretrained_model/resunet_pretrained.pth',map_location=torch.device('cpu'))
142-
# task_runner.model.load_state_dict(checkpoint['model_state_dict'])
143-
# task_runner.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
144-
# else:
145-
# checkpoint = torch.load(f'{root}/pretrained_model/resunet_pretrained.pth')
146-
# task_runner.model.load_state_dict(checkpoint['model_state_dict'])
147-
# task_runner.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
148148
return aggregator.private_attributes["checkpoint_folder"]

0 commit comments

Comments
 (0)