From 328a8209a6692fb066d69d09b2038b8c91200d86 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Thu, 23 Jan 2025 11:38:54 +0100 Subject: [PATCH 1/6] Added graph augmentations for MAGIK --- .../gnn/augmentations/augmentations.py | 153 ++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 deeplay/components/gnn/augmentations/augmentations.py diff --git a/deeplay/components/gnn/augmentations/augmentations.py b/deeplay/components/gnn/augmentations/augmentations.py new file mode 100644 index 00000000..a5666a5c --- /dev/null +++ b/deeplay/components/gnn/augmentations/augmentations.py @@ -0,0 +1,153 @@ +"""Graph augmentations for MAGIK. + +This module provides classes to augment data during training +with transformations, node dropouts, and noise. + +Module Structure +---------------- + +- `NoisyNode`: Adds random noise to each node. + +- `NodeDropout`: Randomly removes a small ammount of nodes and edges. + +- `RandomRotation`: Randomly rotates all nodes by the same angle. + +- `RandomFlip`: Flips nodes with a 0.5 chance. + +- `AugmentCentroids`: Random rotation and translation of nodes. + + +""" + +from math import sin, cos + +import numpy as np +import torch + +class NoisyNode: + """Class to add noise to node attributes. + + """ + + def __call__(self, graph): + + graph = graph.clone() + + node_feats = graph.x[:, :2] - 0.5 # Centered positions. + node_feats += np.random.randn(*node_feats.shape) * np.random.rand()* 0.1 + + graph.x[:, :2] = node_feats + 0.5 # Restored positions. + + return graph + + +class NodeDropout: + """Removal (dropout) of random nodes to simulate missing frames. + + """ + + def __call__(self, graph): + + # Ensure original graph is unchanged. + graph = graph.clone() + + # Specify node dropout rate. + dropout_rate = 0.05 + + # Get indices of random nodes. + idx = np.array(list(range(len(graph.x)))) + dropped_idx = idx[np.random.rand(len(graph.x)) < dropout_rate] + + # Compute connectivity matrix to dropped nodes. + for dropped_node in dropped_idx: + edges_connected_to_removed_node = np.any( + np.array(graph.edge_index) == dropped_node, axis=0 + ) + + # Remove edges, weights, labels connected to dropped nodes with the + # bitwise not operator '~'. + graph.edge_index = graph.edge_index[:, ~edges_connected_to_removed_node] + graph.edge_attr = graph.edge_attr[~edges_connected_to_removed_node] + graph.distance = graph.distance[~edges_connected_to_removed_node] + graph.y = graph.y[~edges_connected_to_removed_node] + + return graph + + +class RandomRotation: + """Random rotations to diversify training data. + + """ + + def __call__(self, graph): + + graph = graph.clone() + node_feats = graph.x[:, :2] - 0.5 # Centered positons. + angle = np.random.rand() * 2 * pi + rotation_matrix = torch.tensor( + [[cos(angle), -sin(angle)], [sin(angle), cos(angle)]] + ).float() + rotated_node_attr = torch.matmul(node_feats, rotation_matrix) + graph.x[:, :2] = rotated_node_attr + 0.5 # Restored positons. + return graph + + +class RandomFlip: + """Random flip to diversify training data. + + """ + + def __call__(self, graph): + + graph = graph.clone() + node_feats = graph.x[:, :2] - 0.5 # Centered positons. + if np.random.randint(2): node_feats[:, 0] *= -1 + if np.random.randint(2): node_feats[:, 1] *= -1 + graph.x[:, :2] = node_feats + 0.5 # Restored positons. + return graph + + +class AugmentCentroids: + """Translation and rotation to diversify training data. + + """ + + def __call__(self, graph): + + graph = graph.clone() + + # Centered positions. + centroids = graph.x[:, :2] - 0.5 + + angle = np.random.rand() * 2 * np.pi + translate = np.random.rand(1,2) + + # Rotate x component of centroids. + centroids_x = ( + centroids[:, 0] * np.cos(angle) + + centroids[:, 1] * np.sin(angle) + + translate[0] + ) + + # Rotate y component of centroids. + centroids_y = ( + centroids[:, 1] * np.cos(angle) + + centroids[:, 0] * np.sin(angle) + + translate[1] + ) + + # Flip centroids randomly. + flip = np.random.rand(1,2) + + if flip[0] > 0.5: + centroids_x *= -1 + + if flip[1] > 0.5: + centroids_y *= -1 + + # Restore positions. + graph.x[:, 0] = centroids_x + 0.5 + graph.x[:, 1] = centroids_y + 0.5 + + return graph + From e762570bfb5236ac4c8eca375ee550e20cdecabe Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Thu, 23 Jan 2025 11:41:31 +0100 Subject: [PATCH 2/6] uses numpy.pi now --- deeplay/components/gnn/augmentations/augmentations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeplay/components/gnn/augmentations/augmentations.py b/deeplay/components/gnn/augmentations/augmentations.py index a5666a5c..6ac2c5ac 100644 --- a/deeplay/components/gnn/augmentations/augmentations.py +++ b/deeplay/components/gnn/augmentations/augmentations.py @@ -83,7 +83,7 @@ def __call__(self, graph): graph = graph.clone() node_feats = graph.x[:, :2] - 0.5 # Centered positons. - angle = np.random.rand() * 2 * pi + angle = np.random.rand() * 2 * np.pi rotation_matrix = torch.tensor( [[cos(angle), -sin(angle)], [sin(angle), cos(angle)]] ).float() From 6f0a7c2677d0b3a7b1d4dffa6d7792cb3013c0b3 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Thu, 23 Jan 2025 11:51:50 +0100 Subject: [PATCH 3/6] commenting --- .../gnn/augmentations/augmentations.py | 64 ++++++++++++++----- 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/deeplay/components/gnn/augmentations/augmentations.py b/deeplay/components/gnn/augmentations/augmentations.py index 6ac2c5ac..232ce188 100644 --- a/deeplay/components/gnn/augmentations/augmentations.py +++ b/deeplay/components/gnn/augmentations/augmentations.py @@ -23,21 +23,27 @@ import numpy as np import torch +from torch_geometric.data import Data class NoisyNode: """Class to add noise to node attributes. """ - def __call__(self, graph): + def __call__( + self, + graph: Data, + ) -> Data : + # Ensure original graph is unchanged. graph = graph.clone() + + # Center positions. + node_feats = graph.x[:, :2] - 0.5 + node_feats += np.random.randn(*node_feats.shape) * np.random.rand()*0.1 - node_feats = graph.x[:, :2] - 0.5 # Centered positions. - node_feats += np.random.randn(*node_feats.shape) * np.random.rand()* 0.1 - - graph.x[:, :2] = node_feats + 0.5 # Restored positions. - + # Restore positions. + graph.x[:, :2] = node_feats + 0.5 return graph @@ -46,7 +52,10 @@ class NodeDropout: """ - def __call__(self, graph): + def __call__( + self, + graph: Data + ) -> Data: # Ensure original graph is unchanged. graph = graph.clone() @@ -79,16 +88,25 @@ class RandomRotation: """ - def __call__(self, graph): - + def __call__( + self, + graph: Data + ) -> Data: + # Ensure original graph is unchanged. graph = graph.clone() - node_feats = graph.x[:, :2] - 0.5 # Centered positons. + + # Center positons. + node_feats = graph.x[:, :2] - 0.5 angle = np.random.rand() * 2 * np.pi + rotation_matrix = torch.tensor( [[cos(angle), -sin(angle)], [sin(angle), cos(angle)]] ).float() rotated_node_attr = torch.matmul(node_feats, rotation_matrix) - graph.x[:, :2] = rotated_node_attr + 0.5 # Restored positons. + + # Restore positons. + graph.x[:, :2] = rotated_node_attr + 0.5 + return graph @@ -96,14 +114,23 @@ class RandomFlip: """Random flip to diversify training data. """ - - def __call__(self, graph): + def __call__( + self, + graph: Data + ) -> Data: + + # Ensure original graph is unchanged. graph = graph.clone() - node_feats = graph.x[:, :2] - 0.5 # Centered positons. + + # Center positons. + node_feats = graph.x[:, :2] - 0.5 + if np.random.randint(2): node_feats[:, 0] *= -1 if np.random.randint(2): node_feats[:, 1] *= -1 - graph.x[:, :2] = node_feats + 0.5 # Restored positons. + + # Restore positons. + graph.x[:, :2] = node_feats + 0.5 return graph @@ -112,11 +139,14 @@ class AugmentCentroids: """ - def __call__(self, graph): + def __call__( + self, + graph: Data + ) -> Data: graph = graph.clone() - # Centered positions. + # Center positions. centroids = graph.x[:, :2] - 0.5 angle = np.random.rand() * 2 * np.pi From 3eeddef1ab3132ad7369a6e1ad224b8fd43bbbab Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Sat, 5 Apr 2025 17:41:59 +0200 Subject: [PATCH 4/6] Clarified names, added parameter fields, uniform noise class, --- .../gnn/augmentations/augmentations.py | 111 ++++++++++++------ 1 file changed, 73 insertions(+), 38 deletions(-) diff --git a/deeplay/components/gnn/augmentations/augmentations.py b/deeplay/components/gnn/augmentations/augmentations.py index 232ce188..ece70dea 100644 --- a/deeplay/components/gnn/augmentations/augmentations.py +++ b/deeplay/components/gnn/augmentations/augmentations.py @@ -1,4 +1,4 @@ -"""Graph augmentations for MAGIK. +"""2D Graph augmentations. This module provides classes to augment data during training with transformations, node dropouts, and noise. @@ -25,47 +25,81 @@ import torch from torch_geometric.data import Data -class NoisyNode: - """Class to add noise to node attributes. +class NodeNormalNoise: + """Adds normal noise to node attributes. """ + def __init__(self, sigma: float = 1.0, mu: float = 0.0): + self.sigma = sigma + self.mu = mu def __call__( self, - graph: Data, - ) -> Data : + graph: torch_geometric.data.Data, + ) -> torch_geometric.data.Data : # Ensure original graph is unchanged. graph = graph.clone() # Center positions. node_feats = graph.x[:, :2] - 0.5 - node_feats += np.random.randn(*node_feats.shape) * np.random.rand()*0.1 + + # Add Normal noise. + node_feats += np.random.randn(*node_feats.shape) * self.sigma + self.mu + + # Restore positions. + graph.x[:, :2] = node_feats + 0.5 + return graph + +class NodeUniformNoise: + """Adds uniform noise to node attributes. + + """ + def __init__(self, low: float = 0.0, high: float = 1.0): + self.low = low + self.high = high + + def __call__( + self, + graph: torch_geometric.data.Data, + ) -> torch_geometric.data.Data : + + # Ensure original graph is unchanged. + graph = graph.clone() + + # Center positions. + node_feats = graph.x[:, :2] - 0.5 + + # Add Uniform noise. + node_feats += np.random.uniform( + self.low, + self.high, + size=node_feats.shape + ) # Restore positions. graph.x[:, :2] = node_feats + 0.5 return graph + class NodeDropout: - """Removal (dropout) of random nodes to simulate missing frames. + """Removal (dropout) of random nodes and edges with some probability.""" - """ + def __init__(self, dropout_rate: float = 0.05): + self.dropout_rate = dropout_rate. def __call__( self, - graph: Data - ) -> Data: + graph: torch_geometric.data.Data + ) -> torch_geometric.data.Data: # Ensure original graph is unchanged. graph = graph.clone() - # Specify node dropout rate. - dropout_rate = 0.05 - # Get indices of random nodes. idx = np.array(list(range(len(graph.x)))) - dropped_idx = idx[np.random.rand(len(graph.x)) < dropout_rate] + dropped_idx = idx[np.random.rand(len(graph.x)) < self.dropout_rate] # Compute connectivity matrix to dropped nodes. for dropped_node in dropped_idx: @@ -83,24 +117,27 @@ def __call__( return graph -class RandomRotation: - """Random rotations to diversify training data. - - """ +class NodeRotations2D: + """Random rotations to diversify training data""" def __call__( self, - graph: Data - ) -> Data: + graph: torch_geometric.data.Data + ) -> torch_geometric.data.Data: # Ensure original graph is unchanged. graph = graph.clone() # Center positons. node_feats = graph.x[:, :2] - 0.5 + + # Sample random angle. angle = np.random.rand() * 2 * np.pi rotation_matrix = torch.tensor( - [[cos(angle), -sin(angle)], [sin(angle), cos(angle)]] + [ + [cos(angle), -sin(angle)], + [sin(angle), cos(angle)] + ] ).float() rotated_node_attr = torch.matmul(node_feats, rotation_matrix) @@ -110,15 +147,13 @@ def __call__( return graph -class RandomFlip: - """Random flip to diversify training data. - - """ +class NodeFlips2D: + """Randomly flips nodes.""" def __call__( self, - graph: Data - ) -> Data: + graph: torch_geometric.data.Data + ) -> torch_geometric.data.Data: # Ensure original graph is unchanged. graph = graph.clone() @@ -126,18 +161,20 @@ def __call__( # Center positons. node_feats = graph.x[:, :2] - 0.5 - if np.random.randint(2): node_feats[:, 0] *= -1 - if np.random.randint(2): node_feats[:, 1] *= -1 + if np.random.randint(2): + node_feats[:, 0] *= -1 + + if np.random.randint(2): + node_feats[:, 1] *= -1 # Restore positons. graph.x[:, :2] = node_feats + 0.5 + return graph -class AugmentCentroids: - """Translation and rotation to diversify training data. - - """ +class NodeAugmentation2D: + """Translations and rotations to diversify training data.""" def __call__( self, @@ -166,13 +203,11 @@ def __call__( translate[1] ) - # Flip centroids randomly. - flip = np.random.rand(1,2) - - if flip[0] > 0.5: + # Flip centroids. + if np.random.randint(2): centroids_x *= -1 - if flip[1] > 0.5: + if np.random.randint(2): centroids_y *= -1 # Restore positions. From 989b8011d76976a503f88ca2ba87e7abd6a3f8e5 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Sat, 5 Apr 2025 17:47:18 +0200 Subject: [PATCH 5/6] fix: syntax error --- deeplay/components/gnn/augmentations/augmentations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeplay/components/gnn/augmentations/augmentations.py b/deeplay/components/gnn/augmentations/augmentations.py index ece70dea..1bb2cd91 100644 --- a/deeplay/components/gnn/augmentations/augmentations.py +++ b/deeplay/components/gnn/augmentations/augmentations.py @@ -87,7 +87,7 @@ class NodeDropout: """Removal (dropout) of random nodes and edges with some probability.""" def __init__(self, dropout_rate: float = 0.05): - self.dropout_rate = dropout_rate. + self.dropout_rate = dropout_rate def __call__( self, From 237e67bf44b30fcfdf3e86aac90d31cd817c5e96 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Sat, 5 Apr 2025 17:52:29 +0200 Subject: [PATCH 6/6] syntax --- .../gnn/augmentations/augmentations.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/deeplay/components/gnn/augmentations/augmentations.py b/deeplay/components/gnn/augmentations/augmentations.py index 1bb2cd91..0ea97de3 100644 --- a/deeplay/components/gnn/augmentations/augmentations.py +++ b/deeplay/components/gnn/augmentations/augmentations.py @@ -35,8 +35,8 @@ def __init__(self, sigma: float = 1.0, mu: float = 0.0): def __call__( self, - graph: torch_geometric.data.Data, - ) -> torch_geometric.data.Data : + graph: Data, + ) -> Data : # Ensure original graph is unchanged. graph = graph.clone() @@ -61,8 +61,8 @@ def __init__(self, low: float = 0.0, high: float = 1.0): def __call__( self, - graph: torch_geometric.data.Data, - ) -> torch_geometric.data.Data : + graph: Data, + ) -> Data : # Ensure original graph is unchanged. graph = graph.clone() @@ -91,8 +91,8 @@ def __init__(self, dropout_rate: float = 0.05): def __call__( self, - graph: torch_geometric.data.Data - ) -> torch_geometric.data.Data: + graph: Data + ) -> Data: # Ensure original graph is unchanged. graph = graph.clone() @@ -122,8 +122,8 @@ class NodeRotations2D: def __call__( self, - graph: torch_geometric.data.Data - ) -> torch_geometric.data.Data: + graph: Data + ) -> Data: # Ensure original graph is unchanged. graph = graph.clone() @@ -152,8 +152,8 @@ class NodeFlips2D: def __call__( self, - graph: torch_geometric.data.Data - ) -> torch_geometric.data.Data: + graph: Data + ) -> Data: # Ensure original graph is unchanged. graph = graph.clone()