Skip to content

Commit b1119cb

Browse files
committed
Updates
1 parent 45bc7d9 commit b1119cb

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

libemg/augmentations.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,25 @@
11
import numpy as np
2+
import random
23
from scipy.interpolate import interp1d
34

45
class DataAugmenter:
56
"""
67
Data augmentation class.
78
"""
9+
def __init__(self, augmentations, aug_chance=0.5):
10+
self.augmentations = augmentations
11+
self.aug_chance = aug_chance
12+
13+
def apply_augmentations(self, data, data2=None):
14+
for a in self.augmentations:
15+
aug = self.get_augmentation_list()[a]
16+
if random.random() < self.aug_chance:
17+
if a == 'MIXUP':
18+
data = aug(data, data2)
19+
else:
20+
data = aug(data)
21+
return data
22+
823
def get_augmentation_list(self):
924
"""Gets a list of all available augmentations.
1025
@@ -29,8 +44,9 @@ def get_augmentation_list(self):
2944
'NONE': self.augNONE,
3045
}
3146

32-
def augGNOISE(self, data, mag=1):
33-
gaussian_noise = np.random.normal(np.mean(data, axis=0) * mag, np.std(data, axis=0) * mag, data.shape)
47+
def augGNOISE(self, data, max_mag=0.25):
48+
noise_factor = random.random()
49+
gaussian_noise = np.random.normal(np.mean(data, axis=0) * noise_factor * max_mag, np.std(data, axis=0) * noise_factor * max_mag, data.shape)
3450
return data + gaussian_noise
3551

3652
def augCS(self, data):
@@ -47,7 +63,7 @@ def augCROP(self, data, crop_percent=0.2):
4763
end_idx = start_idx + crop_length
4864
return data[start_idx:end_idx, :]
4965

50-
def augMAG(self, data, min_mag=0.5, max_mag=1.5):
66+
def augMAG(self, data, min_mag=0.5, max_mag=2):
5167
mag = np.random.uniform(min_mag, max_mag)
5268
return data * mag
5369

@@ -81,13 +97,11 @@ def augNOISE(self, data, noise_level=1):
8197
def augCUTOUT(self, data, drop_prob=0.2):
8298
num_channels = data.shape[1]
8399
channels_to_zero = np.random.choice(num_channels, size=int(num_channels * drop_prob), replace=False)
84-
85-
# Zero out the selected channels
86100
data[:, channels_to_zero] = 0
87101

88102
return data
89103

90-
def augMIXUP(self, data1, data2, alpha=0.2):
104+
def augMIXUP(self, data1, data2, alpha=0.8):
91105
lambda_ = np.random.beta(alpha, alpha)
92106
mixed_data = lambda_ * data1 + (1 - lambda_) * data2
93107
return mixed_data

0 commit comments

Comments
 (0)