@@ -336,26 +336,19 @@ def clipped_aggregation(local_tensors,
336336 clip_to_percentile = 80
337337
338338 # first, we need to determine how much each local update has changed the tensor from the previous value
339- # we'll use the tensor_db search function to find the
340- previous_tensor_value = tensor_db.search(tensor_name=tensor_name, fl_round=fl_round, tags=('trained',), origin='aggregator')
341- logger.info(f"Tensor Values {previous_tensor_value}")
342- logger.info(f"Tensor Values Shape {previous_tensor_value.shape[0]}")
339+ # we'll use the tensor_db retrieve function to find the previous tensor value
340+ previous_tensor_value = tensor_db.retrieve(tensor_name=tensor_name, origin='aggregator', fl_round=fl_round - 1, tags=('aggregated',))
343341
344- if previous_tensor_value.shape[0] > 1:
345- logger.info(previous_tensor_value)
346- raise ValueError(f'found multiple matching tensors for {tensor_name}, tags=(model,), origin=aggregator')
347-
348- if previous_tensor_value.shape[0] < 1:
342+ if previous_tensor_value is None:
349343 # no previous tensor, so just return the weighted average
344+ logger.info(f"previous_tensor_value is None")
350345 return weighted_average_aggregation(local_tensors,
351346 tensor_db,
352347 tensor_name,
353348 fl_round,
354349 collaborators_chosen_each_round,
355350 collaborator_times_per_round)
356351
357- previous_tensor_value = previous_tensor_value.nparray.iloc[0]
358-
359352 # compute the deltas for each collaborator
360353 deltas = [t.tensor - previous_tensor_value for t in local_tensors]
361354
@@ -428,21 +421,20 @@ def FedAvgM_Selection(local_tensors,
428421 if tensor_name not in tensor_db.search(tags=('weight_speeds',))['tensor_name']:
429422 #weight_speeds[tensor_name] = np.zeros_like(local_tensors[0].tensor) # weight_speeds[tensor_name] = np.zeros(local_tensors[0].tensor.shape)
430423 tensor_db .store (
431- tensor_name = tensor_name ,
424+ tensor_name = tensor_name ,
432425 tags = ('weight_speeds' ,),
433426 nparray = np .zeros_like (local_tensors [0 ].tensor ),
434427 )
428+
435429 return new_tensor_weight
436430 else:
437431 if tensor_name.endswith("weight") or tensor_name.endswith("bias"):
438432 # Calculate aggregator's last value
439433 previous_tensor_value = None
440434 for _, record in tensor_db.iterrows():
441- print(f'record tags {record["tags"]} record round {record["round"]} record tensor_name {record["tensor_name"]}')
442- print(f'fl_round {fl_round} tensor_name {tensor_name}')
443- if (record['round'] == fl_round
435+ if (record['round'] == fl_round - 1 # Fetching aggregated value for previous round
444436 and record["tensor_name"] == tensor_name
445- and record["tags"] == (" aggregated",)):
437+ and record["tags"] == (' aggregated',)):
446438 previous_tensor_value = record['nparray']
447439 break
448440
@@ -457,7 +449,7 @@ def FedAvgM_Selection(local_tensors,
457449
458450 if tensor_name not in tensor_db .search (tags = ('weight_speeds' ,))['tensor_name' ]:
459451 tensor_db .store (
460- tensor_name = tensor_name ,
452+ tensor_name = tensor_name ,
461453 tags = ('weight_speeds' ,),
462454 nparray = np .zeros_like (local_tensors [0 ].tensor ),
463455 )
@@ -481,7 +473,7 @@ def FedAvgM_Selection(local_tensors,
481473 new_tensor_weight_speed = momentum * tensor_weight_speed + average_deltas # fix delete (1-momentum)
482474
483475 tensor_db .store (
484- tensor_name = tensor_name ,
476+ tensor_name = tensor_name ,
485477 tags = ('weight_speeds' ,),
486478 nparray = new_tensor_weight_speed
487479 )
@@ -516,7 +508,7 @@ def FedAvgM_Selection(local_tensors,
516508
517509
518510# change any of these you wish to your custom functions. You may leave defaults if you wish.
519- aggregation_function = FedAvgM_Selection
511+ aggregation_function = weighted_average_aggregation
520512choose_training_collaborators = all_collaborators_train
521513training_hyper_parameters_for_round = constant_hyper_parameters
522514
@@ -525,26 +517,26 @@ def FedAvgM_Selection(local_tensors,
525517# to those you specify immediately above. Changing the below value to False will change
526518# this fact, excluding the three hausdorff measurements. As hausdorff distance is
527519# expensive to compute, excluding them will speed up your experiments.
528- include_validation_with_hausdorff = True #TODO change it to True
520+ include_validation_with_hausdorff = True
529521
530522# We encourage participants to experiment with partitioning_1 and partitioning_2, as well as to create
531523# other partitionings to test your changes for generalization to multiple partitionings.
532524#institution_split_csv_filename = 'partitioning_1.csv'
533525institution_split_csv_filename = 'small_split.csv'
534526
535527# change this to point to the parent directory of the data
536- brats_training_data_parent_dir = '/home/ad_kagrawa2/Data/MICCAI_FeTS2022_TrainingData'
528+ brats_training_data_parent_dir = '/raid/datasets/FeTS22/MICCAI_FeTS2022_TrainingData'
537529
538530# increase this if you need a longer history for your algorithms
539531# decrease this if you need to reduce system RAM consumption
540- db_store_rounds = 1 #TODO store the tensor db for these many rounds
532+ db_store_rounds = 1
541533
542534# this is passed to PyTorch, so set it accordingly for your system
543535device = 'cpu'
544536
545537# you'll want to increase this most likely. You can set it as high as you like,
546538# however, the experiment will exit once the simulated time exceeds one week.
547- rounds_to_train = 2 #TODO change it to 5 before merging
539+ rounds_to_train = 5
548540
549541# (bool) Determines whether checkpoints should be saved during the experiment.
550542# The checkpoints can grow quite large (5-10GB) so only the latest will be saved when this parameter is enabled
@@ -612,7 +604,7 @@ def FedAvgM_Selection(local_tensors,
612604# the data you want to run inference over (assumed to be the experiment that just completed)
613605
614606#data_path = </PATH/TO/CHALLENGE_VALIDATION_DATA>
615- data_path = '/home/ad_kagrawa2/Data/MICCAI_FeTS2022_ValidationData'
607+ data_path = '/raid/datasets/FeTS22/MICCAI_FeTS2022_ValidationData'
616608validation_csv_filename = 'validation.csv'
617609
618610# you can keep these the same if you wish
0 commit comments