11import numpy as np
2+ import random
23from scipy .interpolate import interp1d
34
45class DataAugmenter :
56 """
67 Data augmentation class.
78 """
9+ def __init__ (self , augmentations , aug_chance = 0.5 ):
10+ self .augmentations = augmentations
11+ self .aug_chance = aug_chance
12+
13+ def apply_augmentations (self , data , data2 = None ):
14+ for a in self .augmentations :
15+ aug = self .get_augmentation_list ()[a ]
16+ if random .random () < self .aug_chance :
17+ if a == 'MIXUP' :
18+ data = aug (data , data2 )
19+ else :
20+ data = aug (data )
21+ return data
22+
823 def get_augmentation_list (self ):
924 """Gets a list of all available augmentations.
1025
@@ -29,8 +44,9 @@ def get_augmentation_list(self):
2944 'NONE' : self .augNONE ,
3045 }
3146
32- def augGNOISE (self , data , mag = 1 ):
33- gaussian_noise = np .random .normal (np .mean (data , axis = 0 ) * mag , np .std (data , axis = 0 ) * mag , data .shape )
47+ def augGNOISE (self , data , max_mag = 0.25 ):
48+ noise_factor = random .random ()
49+ gaussian_noise = np .random .normal (np .mean (data , axis = 0 ) * noise_factor * max_mag , np .std (data , axis = 0 ) * noise_factor * max_mag , data .shape )
3450 return data + gaussian_noise
3551
3652 def augCS (self , data ):
@@ -47,7 +63,7 @@ def augCROP(self, data, crop_percent=0.2):
4763 end_idx = start_idx + crop_length
4864 return data [start_idx :end_idx , :]
4965
50- def augMAG (self , data , min_mag = 0.5 , max_mag = 1.5 ):
66+ def augMAG (self , data , min_mag = 0.5 , max_mag = 2 ):
5167 mag = np .random .uniform (min_mag , max_mag )
5268 return data * mag
5369
@@ -81,13 +97,11 @@ def augNOISE(self, data, noise_level=1):
8197 def augCUTOUT (self , data , drop_prob = 0.2 ):
8298 num_channels = data .shape [1 ]
8399 channels_to_zero = np .random .choice (num_channels , size = int (num_channels * drop_prob ), replace = False )
84-
85- # Zero out the selected channels
86100 data [:, channels_to_zero ] = 0
87101
88102 return data
89103
90- def augMIXUP (self , data1 , data2 , alpha = 0.2 ):
104+ def augMIXUP (self , data1 , data2 , alpha = 0.8 ):
91105 lambda_ = np .random .beta (alpha , alpha )
92106 mixed_data = lambda_ * data1 + (1 - lambda_ ) * data2
93107 return mixed_data
0 commit comments