From 8fde925435bd90934e419831f7a119ef1f17ae29 Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Thu, 29 Aug 2024 15:52:23 +0200 Subject: [PATCH 1/8] Correcting typos in tutorials --- tutorials/developers/DT101_files.ipynb | 9 +++++-- tutorials/developers/DT121_overview.ipynb | 7 +++++- tutorials/developers/DT131_applications.ipynb | 4 ++-- tutorials/developers/DT141_models.ipynb | 2 +- tutorials/developers/DT151_components.ipynb | 12 +++++----- tutorials/developers/DT171_blocks.ipynb | 2 +- tutorials/developers/DT181_internals.ipynb | 2 +- tutorials/developers/dev_intro.ipynb | 8 +++---- .../getting-started/GS101_core_objects.ipynb | 4 ++-- tutorials/getting-started/GS121_modules.ipynb | 24 +++++++++---------- tutorials/getting-started/GS131_methods.ipynb | 2 +- .../getting-started/GS141_applications.ipynb | 8 +++---- tutorials/getting-started/GS151_models.ipynb | 2 +- .../getting-started/GS161_components.ipynb | 2 +- 14 files changed, 49 insertions(+), 39 deletions(-) diff --git a/tutorials/developers/DT101_files.ipynb b/tutorials/developers/DT101_files.ipynb index 444a101e..9091b3aa 100644 --- a/tutorials/developers/DT101_files.ipynb +++ b/tutorials/developers/DT101_files.ipynb @@ -14,7 +14,7 @@ "## Root Level Files\n", "\n", "Deeplay contains the following files at the root level:\n", - "- `.gitignore`: Contains the files to be ingnored by GIT.\n", + "- `.gitignore`: Contains the files to be ignored by GIT.\n", "- `.pylintrc`: Configuration file for the pylint tool. It contains the rules for code formatting and style.\n", "- `LICENSE.txt`: Deeplay's project license.\n", "- `README.md`: Deeplay's project README file\n", @@ -108,7 +108,7 @@ "\n", "- `blocks`\n", "\n", - " This directory contains the classes and functions related to blocks in the Deeplay library. Blocks are the building blocks of models in the Deeplay library. They are used to define the architecture of a model, and can be combined to create complex models. The most important block classes are in the subfolders `conv`, `linear`, `sequence` and in the files `base.py` anb `sequential.py`.\n", + " This directory contains the classes and functions related to blocks in the Deeplay library. Blocks are the building blocks of models in the Deeplay library. They are used to define the architecture of a model, and can be combined to create complex models. The most important block classes are in the subfolders `conv`, `linear`, `sequence` and in the files `base.py` and `sequential.py`.\n", "\n", "- `components`\n", "\n", @@ -148,6 +148,11 @@ "\n", " This directory contains the unit tests for the library. These are used to ensure that the library is working correctly and to catch any bugs that may arise." ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] } ], "metadata": { diff --git a/tutorials/developers/DT121_overview.ipynb b/tutorials/developers/DT121_overview.ipynb index b53b6ce5..3f3c69f4 100644 --- a/tutorials/developers/DT121_overview.ipynb +++ b/tutorials/developers/DT121_overview.ipynb @@ -42,8 +42,13 @@ "| Is the object a small structural object with a sequential forward pass, such as a layer activation? | Yes | `Block` |\n", "| Is the object a unit of computation, such as a convolution or a pooling operation? | Yes | `Operation` |\n", "\n", - "As a general rule of thumb, for objects derived from `Component`, the number of features in each layer should be defineable by the input arguments. For objects derived from `Model`, only the input and output features must be defineable by the input arguments. In both cases, it is recommended to subclass an existing model or component if possible. This will make it easier to implement the required methods and attributes, and will ensure that the new model or component is compatible with the rest of the library." + "As a general rule of thumb, for objects derived from `Component`, the number of features in each layer should be definable by the input arguments. For objects derived from `Model`, only the input and output features must be definable by the input arguments. In both cases, it is recommended to subclass an existing model or component if possible. This will make it easier to implement the required methods and attributes, and will ensure that the new model or component is compatible with the rest of the library." ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] } ], "metadata": { diff --git a/tutorials/developers/DT131_applications.ipynb b/tutorials/developers/DT131_applications.ipynb index 14e5785f..65d632c4 100644 --- a/tutorials/developers/DT131_applications.ipynb +++ b/tutorials/developers/DT131_applications.ipynb @@ -22,7 +22,7 @@ "Therefore, applications should strive to be as model agnostic as possible, so that they can be used with any model that fits the input and output shapes.\n", "\n", "Applications define the training and inference loops, and the loss function.\n", - "Applications may also define custom metrics, or specialmethods used for inference. Forexample, a classifier application may define a method to predict a hard label from the input, instead of the probabilities.\n", + "Applications may also define custom metrics, or special methods used for inference. For example, a classifier application may define a method to predict a hard label from the input, instead of the probabilities.\n", "\n", "Examples of applications are `Classifier`, `Regressor`, `Segmentor`, and `VanillaGAN`." ] @@ -56,7 +56,7 @@ "source": [ "### 1. Create a New File\n", "\n", - "The first step is to create a new file in the `deeplay/applications` directory. It should generally be in a subdirectory, named after the type of application. For example, a binary classifier application should be in `deeplay/applications/classificaiton/binary.py`.\n", + "The first step is to create a new file in the `deeplay/applications` directory. It should generally be in a subdirectory, named after the type of application. For example, a binary classifier application should be in `deeplay/applications/classification/binary.py`.\n", "\n", "**The base class: `Application`.** \n", "Applications should inherit from the `Application` class. This class is a subclass of both `DeeplayModule` and `lightning.LightningModule`. This is to ensure that the \n", diff --git a/tutorials/developers/DT141_models.ipynb b/tutorials/developers/DT141_models.ipynb index e65b4064..68d6da6e 100644 --- a/tutorials/developers/DT141_models.ipynb +++ b/tutorials/developers/DT141_models.ipynb @@ -117,7 +117,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You can now instatiate this block and verify its structure." + "You can now instantiate this block and verify its structure." ] }, { diff --git a/tutorials/developers/DT151_components.ipynb b/tutorials/developers/DT151_components.ipynb index 40c934b9..32a5a994 100644 --- a/tutorials/developers/DT151_components.ipynb +++ b/tutorials/developers/DT151_components.ipynb @@ -191,7 +191,7 @@ " The input tensor of shape (N, C, H, W).\n", " Where N is the batch size, C is the number of channels, H is the \n", " height, and W is the width.\n", - " Additial dimensions before C are allowed.\n", + " Additional dimensions before C are allowed.\n", " \n", " Output\n", " ------\n", @@ -199,7 +199,7 @@ " The output tensor of shape (N, out_channels, H', W').\n", " Where N is the batch size, out_channels is the number of output \n", " channels, H is the height, and W is the width.\n", - " Additial dimensions before out_channels will be preserved.\n", + " Additional dimensions before out_channels will be preserved.\n", "\n", " Evaluation\n", " ----------\n", @@ -338,7 +338,7 @@ " The input tensor of shape (N, C, H, W).\n", " Where N is the batch size, C is the number of channels, H is the \n", " height, and W is the width.\n", - " Additial dimensions before C are allowed.\n", + " Additional dimensions before C are allowed.\n", " \n", " Output\n", " ------\n", @@ -346,7 +346,7 @@ " The output tensor of shape (N, out_channels, H', W').\n", " Where N is the batch size, out_channels is the number of output \n", " channels, H is the height, and W is the width.\n", - " Additial dimensions before out_channels will be preserved.\n", + " Additional dimensions before out_channels will be preserved.\n", "\n", " Evaluation\n", " ----------\n", @@ -479,7 +479,7 @@ " The input tensor of shape (N, C, H, W).\n", " Where N is the batch size, C is the number of channels, H is the \n", " height, and W is the width.\n", - " Additial dimensions before C are allowed.\n", + " Additional dimensions before C are allowed.\n", " \n", " Output\n", " ------\n", @@ -487,7 +487,7 @@ " The output tensor of shape (N, out_channels, H', W').\n", " Where N is the batch size, out_channels is the number of output \n", " channels, H is the height, and W is the width.\n", - " Additial dimensions before out_channels will be preserved.\n", + " Additional dimensions before out_channels will be preserved.\n", "\n", " Evaluation\n", " ----------\n", diff --git a/tutorials/developers/DT171_blocks.ipynb b/tutorials/developers/DT171_blocks.ipynb index 31a4f447..11c48e79 100644 --- a/tutorials/developers/DT171_blocks.ipynb +++ b/tutorials/developers/DT171_blocks.ipynb @@ -21,7 +21,7 @@ "remove steps from the block. For example, adding an activation function, or\n", "removing a dropout layer. \n", "\n", - "Also, remember that blocks shouls be small and modular. If you are implementing\n", + "Also, remember that blocks should be small and modular. If you are implementing\n", "a block that is too big, you should consider breaking it down into smaller blocks.\n", "\n", "Finally, blocks should be pretty strict in terms of input and output. This is \n", diff --git a/tutorials/developers/DT181_internals.ipynb b/tutorials/developers/DT181_internals.ipynb index 1741af51..bbe8310f 100644 --- a/tutorials/developers/DT181_internals.ipynb +++ b/tutorials/developers/DT181_internals.ipynb @@ -388,7 +388,7 @@ "\n", "Since Deeplay modules requires an additional `build` step before the weights are created, so the default checkpointing system of `lightning` does not work.\n", "\n", - "We have solved this by storing the state of the `Application` object immediately before building as a a hyperparameter in the checkpoint. This is then loaded when the model is loaded from the checkpoint, and the `build` method is called with the same arguments as before before the weights are loaded." + "We have solved this by storing the state of the `Application` object immediately before building as a hyperparameter in the checkpoint. This is then loaded when the model is loaded from the checkpoint, and the `build` method is called with the same arguments as before the weights are loaded." ] } ], diff --git a/tutorials/developers/dev_intro.ipynb b/tutorials/developers/dev_intro.ipynb index 0f56368c..f48c1a22 100644 --- a/tutorials/developers/dev_intro.ipynb +++ b/tutorials/developers/dev_intro.ipynb @@ -6,9 +6,9 @@ "source": [ "# Introduction for Developers\n", "\n", - "This notebook presents a minimal example of how to implement a neural network and and an application with deeplay.\n", + "This notebook presents a minimal example of how to implement a neural network and an application with deeplay.\n", "Specifically, it implements the classes for a multilayer perceptron and a classifier.\n", - "Then, it combines them to demonstrate how tehy can be used for a simple classification task.\n", + "Then, it combines them to demonstrate how they can be used for a simple classification task.\n", "Finally, it upgrades these classes adding functionalities that are required to improved the user experience when using an IDE." ] }, @@ -20,7 +20,7 @@ "\n", "Here, we implement the minimal class `SimpleMLP`. It extends directly `dl.DeeplayModule`, which is the base class for all modules in `deeplay`.\n", "\n", - "It represents a multilayer perceptron with a certain umber of inputs (`ìn_features`, which is an integer), a series of hidden layers with a certain number of neurons (`hidden_features`, a vector with the number of neurons for each layer), and a certain numebr of outputs (`out_features`, which is an integer).\n", + "It represents a multilayer perceptron with a certain umber of inputs (`ìn_features`, which is an integer), a series of hidden layers with a certain number of neurons (`hidden_features`, a vector with the number of neurons for each layer), and a certain number of outputs (`out_features`, which is an integer).\n", "\n", "The constructor initializes the MLP by creating a sequence of linear and ReLU activation layers.\n", "\n", @@ -229,7 +229,7 @@ "source": [ "## Example\n", "\n", - "We'll now use `classifier` for the simple task of determinig whether the sum of two numbers is larger or smaller than 0." + "We'll now use `classifier` for the simple task of determining whether the sum of two numbers is larger or smaller than 0." ] }, { diff --git a/tutorials/getting-started/GS101_core_objects.ipynb b/tutorials/getting-started/GS101_core_objects.ipynb index 6dfaa3e6..7b409f1e 100644 --- a/tutorials/getting-started/GS101_core_objects.ipynb +++ b/tutorials/getting-started/GS101_core_objects.ipynb @@ -18,7 +18,7 @@ "\n", "- **Layer:** A layer consists of a single torch layer, such as `torch.nn.Linear` and `torch.nn.Conv2d`. Layers are the most basic building blocks in Deeplay.\n", "\n", - "In the following sections, you'll create some examples of these obejcts." + "In the following sections, you'll create some examples of these objects." ] }, { @@ -27,7 +27,7 @@ "source": [ "## Importing Deeplay\n", "\n", - "Import `deeplay` (shortened to `dl`, as an abbreaviation of both *deeplay* and *deep learning*) ... " + "Import `deeplay` (shortened to `dl`, as an abbreviation of both *deeplay* and *deep learning*) ... " ] }, { diff --git a/tutorials/getting-started/GS121_modules.ipynb b/tutorials/getting-started/GS121_modules.ipynb index 7f25df72..fb2c1a83 100644 --- a/tutorials/getting-started/GS121_modules.ipynb +++ b/tutorials/getting-started/GS121_modules.ipynb @@ -6,7 +6,7 @@ "source": [ "# Working with Deeplay Modules\n", "\n", - "In this section, you'll learn how to create and build Deeplay modules as well as how to configure their properties. You'll also understand the difference between Deeplay and PyTorch modules. You'll " + "In this section, you'll learn how to create and build Deeplay modules as well as how to configure their properties. You'll also understand the difference between Deeplay and PyTorch modules." ] }, { @@ -275,7 +275,7 @@ "MultiLayerPerceptron(\n", " (blocks): LayerList(\n", " (0): LinearBlock(\n", - " (layer): Layer[Linear](in_features=728, out_features=32, bias=True)\n", + " (layer): Layer[Linear](in_features=784, out_features=32, bias=True)\n", " (activation): Layer[Tanh]()\n", " )\n", " (1): LinearBlock(\n", @@ -294,7 +294,7 @@ "source": [ "import torch.nn as nn\n", "\n", - "mlp = dl.MultiLayerPerceptron(728, [32, 16], 10)\n", + "mlp = dl.MultiLayerPerceptron(784, [32, 16], 10)\n", "mlp[\"blocks\", :].all.configure(activation=dl.Layer(nn.Tanh))\n", "\n", "print(mlp)" @@ -319,7 +319,7 @@ "MultiLayerPerceptron(\n", " (blocks): LayerList(\n", " (0): LinearBlock(\n", - " (layer): Layer[Linear](in_features=728, out_features=32, bias=True)\n", + " (layer): Layer[Linear](in_features=784, out_features=32, bias=True)\n", " (activation): Layer[Tanh]()\n", " )\n", " (1): LinearBlock(\n", @@ -336,7 +336,7 @@ } ], "source": [ - "mlp = dl.MultiLayerPerceptron(728, [32, 16], 10)\n", + "mlp = dl.MultiLayerPerceptron(784, [32, 16], 10)\n", "mlp[\"blocks\", :].all.activated(nn.Tanh)\n", "\n", "print(mlp)" @@ -354,7 +354,7 @@ "MultiLayerPerceptron(\n", " (blocks): LayerList(\n", " (0): LinearBlock(\n", - " (layer): Layer[Linear](in_features=728, out_features=32, bias=True)\n", + " (layer): Layer[Linear](in_features=784, out_features=32, bias=True)\n", " (activation): Layer[Tanh]()\n", " )\n", " (1): LinearBlock(\n", @@ -371,7 +371,7 @@ } ], "source": [ - "mlp = dl.MultiLayerPerceptron(728, [32, 16], 10)\n", + "mlp = dl.MultiLayerPerceptron(784, [32, 16], 10)\n", "mlp[...].hasattr(\"activated\").all.activated(nn.Tanh)\n", "\n", "print(mlp)" @@ -389,7 +389,7 @@ "MultiLayerPerceptron(\n", " (blocks): LayerList(\n", " (0): LinearBlock(\n", - " (layer): Layer[Linear](in_features=728, out_features=32, bias=True)\n", + " (layer): Layer[Linear](in_features=784, out_features=32, bias=True)\n", " (activation): Layer[Tanh]()\n", " )\n", " (1): LinearBlock(\n", @@ -406,7 +406,7 @@ } ], "source": [ - "mlp = dl.MultiLayerPerceptron(728, [32, 16], 10)\n", + "mlp = dl.MultiLayerPerceptron(784, [32, 16], 10)\n", "mlp[...].isinstance(dl.LinearBlock).all.activated(nn.Tanh)\n", "\n", "print(mlp)" @@ -424,7 +424,7 @@ "MultiLayerPerceptron(\n", " (blocks): LayerList(\n", " (0): LinearBlock(\n", - " (layer): Layer[Linear](in_features=728, out_features=32, bias=True)\n", + " (layer): Layer[Linear](in_features=784, out_features=32, bias=True)\n", " (activation): Layer[Tanh]()\n", " )\n", " (1): LinearBlock(\n", @@ -441,7 +441,7 @@ } ], "source": [ - "mlp = dl.MultiLayerPerceptron(728, [32, 16], 10)\n", + "mlp = dl.MultiLayerPerceptron(784, [32, 16], 10)\n", "for block in mlp.blocks:\n", " block.activated(nn.Tanh)\n", " \n", @@ -618,7 +618,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.2" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/tutorials/getting-started/GS131_methods.ipynb b/tutorials/getting-started/GS131_methods.ipynb index a1687cef..99d85771 100644 --- a/tutorials/getting-started/GS131_methods.ipynb +++ b/tutorials/getting-started/GS131_methods.ipynb @@ -1137,7 +1137,7 @@ "source": [ "## `DeeplayModule.predict()`\n", "\n", - "The `DeeplayModule.predict()` method is a convienient way to get predictions from the model on a large dataset. It does the batching, moving the data to the correct device and returning the predictions in a single call to the model without gradients." + "The `DeeplayModule.predict()` method is a convenient way to get predictions from the model on a large dataset. It does the batching, moving the data to the correct device and returning the predictions in a single call to the model without gradients." ] }, { diff --git a/tutorials/getting-started/GS141_applications.ipynb b/tutorials/getting-started/GS141_applications.ipynb index 2795d10c..74eda924 100644 --- a/tutorials/getting-started/GS141_applications.ipynb +++ b/tutorials/getting-started/GS141_applications.ipynb @@ -449,14 +449,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "**NOTE:** In this case, however, the `.loss` atttribute will still be set to the default loss, so you won't see the correct loss when you print it." + "**NOTE:** In this case, however, the `.loss` attribute will still be set to the default loss, so you won't see the correct loss when you print it." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Copntrolling the Optimizer\n", + "## Controlling the Optimizer\n", "\n", "Deeplay provides some wrappers around the PyTorch optimizers. Their main use is to be able to create the optimizer before the model is built and the parameters are known. Examples of optimizers available in Deeplay are:" ] @@ -982,7 +982,7 @@ ], "source": [ "x_numpy_test = numpy.random.randn(100, 10) # Input test data.\n", - "y_numpy_test = x_numpy_val.max(axis=1, keepdims=True) # Target test data.\n", + "y_numpy_test = x_numpy_test.max(axis=1, keepdims=True) # Target test data.\n", "\n", "test_results = regressor.test((x_numpy_test, y_numpy_test))\n", "print(f\"Test results: {test_results}\\n\") # Benjamin: Should this also work? If not, why not?" @@ -1084,7 +1084,7 @@ ], "source": [ "x_numpy_test = numpy.random.randn(100, 10) # Input test data.\n", - "y_numpy_test = x_numpy_val.max(axis=1, keepdims=True) # Target test data.\n", + "y_numpy_test = x_numpy_test.max(axis=1, keepdims=True) # Target test data.\n", "\n", "x_torch_test = torch.from_numpy(x_numpy_test).float()\n", "y_torch_test = torch.from_numpy(y_numpy_test).float()\n", diff --git a/tutorials/getting-started/GS151_models.ipynb b/tutorials/getting-started/GS151_models.ipynb index 95996903..f859c6f2 100644 --- a/tutorials/getting-started/GS151_models.ipynb +++ b/tutorials/getting-started/GS151_models.ipynb @@ -1003,7 +1003,7 @@ "source": [ "### Making a Model by Subclassing\n", "\n", - "A model is just a `DeeplayModule` sublass like any other. For some applications, it might be more convenient to subclass the model and implement the `.forward()` method." + "A model is just a `DeeplayModule` subclass like any other. For some applications, it might be more convenient to subclass the model and implement the `.forward()` method." ] }, { diff --git a/tutorials/getting-started/GS161_components.ipynb b/tutorials/getting-started/GS161_components.ipynb index 0a4d16b9..33005e17 100644 --- a/tutorials/getting-started/GS161_components.ipynb +++ b/tutorials/getting-started/GS161_components.ipynb @@ -353,7 +353,7 @@ " out_activation=dl.Layer(torch.nn.Tanh),\n", ")\n", "\n", - "encdec" + "print(encdec)" ] }, { From 31282f1091e73aa184bdff8da66d28ffa7d6379a Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Tue, 12 Nov 2024 09:54:26 +0100 Subject: [PATCH 2/8] Mg/Graph autoencoder --- deeplay/applications/autoencoders/__init__.py | 1 + deeplay/applications/autoencoders/vgae.py | 135 +++++ deeplay/components/gnn/__init__.py | 1 + deeplay/components/gnn/gcn/gcn.py | 4 +- deeplay/components/gnn/gcn/normalization.py | 6 +- deeplay/components/gnn/graphencdec.py | 536 ++++++++++++++++++ deeplay/components/gnn/mpn/__init__.py | 1 + .../components/gnn/mpn/get_edge_features.py | 16 + deeplay/components/gnn/pooling/__init__.py | 2 + deeplay/components/gnn/pooling/graph_pool.py | 147 +++++ deeplay/components/gnn/pooling/mincut.py | 241 ++++++++ 11 files changed, 1087 insertions(+), 3 deletions(-) create mode 100644 deeplay/applications/autoencoders/vgae.py create mode 100644 deeplay/components/gnn/graphencdec.py create mode 100644 deeplay/components/gnn/mpn/get_edge_features.py create mode 100644 deeplay/components/gnn/pooling/__init__.py create mode 100644 deeplay/components/gnn/pooling/graph_pool.py create mode 100644 deeplay/components/gnn/pooling/mincut.py diff --git a/deeplay/applications/autoencoders/__init__.py b/deeplay/applications/autoencoders/__init__.py index 191aac42..d4f255af 100644 --- a/deeplay/applications/autoencoders/__init__.py +++ b/deeplay/applications/autoencoders/__init__.py @@ -1,2 +1,3 @@ from .vae import VariationalAutoEncoder from .wae import WassersteinAutoEncoder +from .vgae import VariationalGraphAutoEncoder diff --git a/deeplay/applications/autoencoders/vgae.py b/deeplay/applications/autoencoders/vgae.py new file mode 100644 index 00000000..3e7a3e62 --- /dev/null +++ b/deeplay/applications/autoencoders/vgae.py @@ -0,0 +1,135 @@ +from typing import Optional, Sequence, Callable, List + +from deeplay.components import ConvolutionalEncoder2d, ConvolutionalDecoder2d +from deeplay.applications import Application +from deeplay.external import External, Optimizer, Adam + +from deeplay import ( + DeeplayModule, + Layer, +) + + +import torch +import torch.nn as nn + + +class VariationalGraphAutoEncoder(Application): + channels: list + latent_dim: int + encoder: torch.nn.Module + decoder: torch.nn.Module + beta: float + reconstruction_loss: torch.nn.Module + metrics: list + optimizer: Optimizer + + def __init__( + self, + channels: Optional[int] = 96, + encoder: Optional[nn.Module] = None, + decoder: Optional[nn.Module] = None, + reconstruction_loss: Optional[Callable] = nn.L1Loss(), + latent_dim=int, + alpha: Optional[int] = 0, + beta: Optional[int] = 1e-7, + gamma: Optional[int] = 10, + delta: Optional[int] = 1, + optimizer=None, + **kwargs, + ): + self.encoder = encoder + + self.fc_mu = Layer(nn.Linear, channels, latent_dim) + self.fc_mu.set_input_map('x') + self.fc_mu.set_output_map('mu') + + self.fc_var = Layer(nn.Linear, channels, latent_dim) + self.fc_var.set_input_map('x') + self.fc_var.set_output_map('log_var') + + self.fc_dec = Layer(nn.Linear, latent_dim, channels) + self.fc_dec.set_input_map('z') + self.fc_dec.set_output_map('x') + + self.decoder = decoder + + self.reconstruction_loss = reconstruction_loss or nn.L1Loss() + self.latent_dim = latent_dim + self.alpha = alpha # node feature reconstruction loss weight + self.beta = beta # KL loss weight + self.gamma = gamma # edge feature reconstruction loss weight + self.delta = delta # MinCut loss weight + + super().__init__(**kwargs) + + class Reparameterize(DeeplayModule): + def forward(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + self.reparameterize = Reparameterize() + self.reparameterize.set_input_map('mu', 'log_var') + self.reparameterize.set_output_map('z') + + self.optimizer = optimizer or Adam(lr=1e-3) + + @self.optimizer.params + def params(self): + return self.parameters() + + + def encode(self, x): + x = self.encoder(x) + x = self.fc_mu(x) + x = self.fc_var(x) + return x + + def decode(self, x): + x = self.fc_dec(x) + x = self.decoder(x) + return x + + def training_step(self, batch, batch_idx): + x, y = self.train_preprocess(batch) + node_features, edge_features = y + x = self(x) + node_features_hat = x['x'] + edge_features_hat = x['edge_attr'] + mu = x['mu'] + log_var = x['log_var'] + mincut_cut_loss = sum(value for key, value in x.items() if key.startswith('L_cut')) + mincut_ortho_loss = sum(value for key, value in x.items() if key.startswith('L_ortho')) + rec_loss_nodes, rec_loss_edges, KLD = self.compute_loss(node_features_hat, node_features, edge_features_hat, edge_features, mu, log_var) + + tot_loss = self.alpha * rec_loss_nodes + self.gamma * rec_loss_edges + self.beta * KLD + self.delta * (mincut_cut_loss + mincut_ortho_loss) + + loss = {"rec_loss_nodes": rec_loss_nodes, "rec_loss_edges": rec_loss_edges, "KL": KLD, + "MinCut cut loss": mincut_cut_loss, "MinCut orthogonality loss": mincut_ortho_loss, + "total_loss": tot_loss} + for name, v in loss.items(): + self.log( + f"train_{name}", + v, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return tot_loss + + def compute_loss(self, n_hat, n, e_hat, e, mu, log_var): + + rec_loss_nodes = self.reconstruction_loss(n_hat, n) + rec_loss_edges = self.reconstruction_loss(e_hat, e) + + KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) + + return rec_loss_nodes, rec_loss_edges, KLD + + def forward(self, x): + x = self.encode(x) + x = self.reparameterize(x) + x = self.decode(x) + return x diff --git a/deeplay/components/gnn/__init__.py b/deeplay/components/gnn/__init__.py index 1f0bd47c..1bf368e6 100644 --- a/deeplay/components/gnn/__init__.py +++ b/deeplay/components/gnn/__init__.py @@ -1,3 +1,4 @@ from .gcn import * from .mpn import * from .tpu import * +from .graphencdec import GraphEncoderBlock, GraphDecoderBlock, GraphEncoder, GraphDecoder \ No newline at end of file diff --git a/deeplay/components/gnn/gcn/gcn.py b/deeplay/components/gnn/gcn/gcn.py index 0490b67e..dd60ab4c 100644 --- a/deeplay/components/gnn/gcn/gcn.py +++ b/deeplay/components/gnn/gcn/gcn.py @@ -82,7 +82,7 @@ def __init__( self.normalize = Layer(sparse_laplacian_normalization) self.normalize.set_input_map("x", "edge_index") - self.normalize.set_output_map("edge_index") + self.normalize.set_output_map("laplacian") class Propagate(DeeplayModule): def forward(self, x, A): @@ -115,7 +115,7 @@ def forward(self, x, A): transform.set_output_map("x") propagate = Layer(Propagate) - propagate.set_input_map("x", "edge_index") + propagate.set_input_map("x", "laplacian") propagate.set_output_map("x") update = Layer(nn.ReLU) if i < len(self.hidden_features) else out_activation diff --git a/deeplay/components/gnn/gcn/normalization.py b/deeplay/components/gnn/gcn/normalization.py index 28607349..98c2ddf5 100644 --- a/deeplay/components/gnn/gcn/normalization.py +++ b/deeplay/components/gnn/gcn/normalization.py @@ -12,7 +12,11 @@ def add_self_loops(self, A, num_nodes): loop_index = torch.arange(num_nodes, device=A.device) loop_index = loop_index.unsqueeze(0).repeat(2, 1) - A = torch.cat([A, loop_index], dim=1) + + if A.is_sparse: # changed this to ensure that it works even if the format of A varies. Previous: A = torch.cat([A, loop_index], dim=1) + A = torch.cat([A.indices(), loop_index], dim=1) + elif (not A.is_sparse) & (A.size(0) == 2): + A = torch.cat([A, loop_index], dim=1) return A diff --git a/deeplay/components/gnn/graphencdec.py b/deeplay/components/gnn/graphencdec.py new file mode 100644 index 00000000..657acb02 --- /dev/null +++ b/deeplay/components/gnn/graphencdec.py @@ -0,0 +1,536 @@ +from __future__ import annotations +from typing import Optional, Sequence, Type, Union +import warnings + +from deeplay import ( + DeeplayModule, + Layer, + LayerList, +) +from deeplay.components.gnn import GraphConvolutionalNeuralNetwork, MessagePassingNeuralNetwork +from deeplay.components.gnn.pooling import MinCutPooling +from deeplay.ops import Cat +# from deeplay.deeplay.components.gnn.pooling.graph_pool import GlobalGraphPooling, GlobalGraphUpsampling, MinCutUpsampling +from deeplay.components.gnn.pooling import GlobalGraphPooling, GlobalGraphUpsampling, MinCutUpsampling +from deeplay.components.gnn.mpn import GetEdgeFeatures +from deeplay.components.mlp import MultiLayerPerceptron + +import torch.nn as nn + + +class GraphEncoder(DeeplayModule): + """ A Graph Encoder module that leverages multiple graph processing blocks to learn representations + from graph-structured data. This module supports graph convolution and pooling operations, enabling + effective encoding of graph information for downstream tasks. + + Parameters + ---------- + hidden_features: int + The number of hidden features in the hidden layers, both in the gcn and pooling, of the encoder. + num_blocks: int + The number of processing blocks in the encoder. + num_clusters: list[int] + The number of clusters the graph is pooled to in each processing block. + thresholds: list[float] + The threshold values for the connectivity in the clustering process. + poolings: template-like + A list of pooling layers to use. Default is using MinCut pooling for all layers, except for the + last, which is global pooling. + save_intermediates: bool + Flag indicating whether to save intermediate adjacency matrices and other information, useful + when using it together with the GraphDecoder. Default is True. + + + Configurables + ------------- + - hidden features (int): Number of features of the hidden layers. + - num_blocks: (int): Number of processing blocks in the encoder. + - num_clusters: list[int]: Number of clusters the graph is pooled to in each processing block. + - thresholds list[int]: The threshold values for the connectivity in the clustering process. + + + Evaluation + ---------- + >>> encoder = dl.GraphEncoder(hidden_features=96, num_blocks=3, num_clusters=[5, 3, 1], thresholds=[0.1, 0.2, None], save_intermediates=False).build() + >>> inp = {} + >>> inp["x"] = torch.randn(10, 16) + >>> inp['batch'] = torch.zeros(10, dtype=int) + >>> inp["edge_index"] = torch.randint(0, 10, (2, 20)) + >>> inp["edge_attr"] = torch.randn(20, 8) + >>> output = encoder(inp) + + + Return Values + ------------- + The forward method returns a mapping object with the updated node features, edge_index, edge_attributes, + and the cut and orthogonality losses from the MinCut pooling. + + """ + hidden_features: int + num_blocks: int + num_clusters: Sequence[int] + thresholds: Optional[Sequence[float]] + poolings: Optional[Sequence[nn.Module]] + save_intermediates: Optional[bool] + + def __init__( + self, + hidden_features: int, + num_blocks: int, + num_clusters: Sequence[int], + thresholds: Optional[Sequence[float]] = None, + poolings: Optional[Sequence[Union[Type[nn.Module], nn.Module]]] = None, + save_intermediates: Optional[bool] = True, + ): + super().__init__( + hidden_features = hidden_features, + num_blocks = num_blocks, + num_clusters = num_clusters, + thresholds = thresholds, + poolings = poolings, + save_intermediates = save_intermediates, + ) + + if not isinstance(hidden_features, int) or hidden_features <= 0: + raise ValueError(f"hidden_features must be a positive integer, got {hidden_features}") + + if poolings is None: + poolings = [MinCutPooling] * (num_blocks - 1) + [GlobalGraphPooling] + + assert len(poolings) == num_blocks, "Number of poolings should match num_blocks." + assert len(num_clusters) == num_blocks, "Lenght of number of clusters should match num_blocks." + + + self.message_passing = MessagePassingNeuralNetwork( + hidden_features=[hidden_features], + out_features=hidden_features, + out_activation=nn.ReLU + ) + + self.dense = Layer(nn.Linear, hidden_features, hidden_features) + self.dense.set_input_map('x') + self.dense.set_output_map('x') + + self.activate = Layer(nn.ReLU) + self.activate.set_input_map('x') + self.activate.set_output_map('x') + + + self.blocks = LayerList() + + for i in range(num_blocks): + pool = poolings[i] + + if save_intermediates == True: + edge_index_map = "edge_index" if i == 0 else f"edge_index_{i}" + select_output_map = f"s_{i}" + connect_output_map = f"edge_index_{i+1}" + batch_input_map = "batch" if i == 0 else f"batch_{i}" + batch_output_map = f"batch_{i+1}" + mincut_cut_loss_map = f"L_cut_{i}" + mincut_ortho_loss_map = f"L_ortho_{i}" + + block = GraphEncoderBlock( + in_features=hidden_features, + out_features=hidden_features, + num_clusters=num_clusters[i], + threshold=thresholds[i] if thresholds is not None else None, + pool=pool, + edge_index_map=edge_index_map, + select_output_map=select_output_map, + connect_output_map=connect_output_map, + batch_input_map=batch_input_map, + batch_output_map=batch_output_map, + mincut_cut_loss_map=mincut_cut_loss_map, + mincut_ortho_loss_map=mincut_ortho_loss_map, + ) + + else: + mincut_cut_loss_map = f"L_cut_{i}" + mincut_ortho_loss_map = f"L_ortho_{i}" + + block = GraphEncoderBlock( + in_features=hidden_features, + out_features=hidden_features, + num_clusters=num_clusters[i], + threshold=thresholds[i] if thresholds is not None else None, + pool=pool, + mincut_cut_loss_map=mincut_cut_loss_map, + mincut_ortho_loss_map=mincut_ortho_loss_map, + ) + + self.blocks.append(block) + + def forward(self, x): + x = self.message_passing(x) + x = self.dense(x) + x = self.activate(x) + for block in self.blocks: + x = block(x) + return x + + +class GraphDecoder(DeeplayModule): + """ + A Graph Decoder module that reconstructs graph structures from learned representations generated + by the GraphEncoder. This module aims to decode the latent graph features back into graph node + and edge attributes. + + Attributes: + hidden_features: int + The dimensionality of the hidden layers of the decoder. This should match the hidden + features from the corresponding GraphEncoder. + num_blocks: int + The number of processing blocks in the decoder. This should match the number of blocks + of the GraphEncoder. + output_node_dim: int + The dimensionality of the output node features. This should match the original dimensionallity + of the input node features of the GraphEncoder. + output_edge_dim: int + The dimensionality of the output edge features. This should match the original dimensionallity + of the input edge attributes of the GraphEncoder. + upsamplings: template-like + A list of upsampling layers to use. Default is using MinCut upsampling for all layers, except for the + first, which is global upsampling. This should reflect the pooling layers of the GraphEncoder. + + + Configurables + ------------- + - hidden features (int): Number of features of the hidden layers. + - num_blocks: (int): Number of processing blocks in the decoder. + - output_node_dim: (int): Number of dimensions of the output node features. + - output_edge_dim: (int): Number of dimensions of the output edge attributes. + + Evaluation + ---------- + >>> encoder = dl.GraphEncoder(hidden_features=96, num_blocks=3, num_clusters=[20, 5, 1], thresholds=[0.1, 0.5, None], save_intermediates=False).build() + >>> inp = {} + >>> inp["x"] = torch.randn(10, 16) + >>> inp['batch'] = torch.zeros(10, dtype=int) + >>> inp["edge_index"] = torch.randint(0, 10, (2, 20)) + >>> inp["edge_attr"] = torch.randn(20, 8) + >>> encoder_output = encoder(inp) + >>> decoder = dl.GraphDecoder(hidden_features=96, num_blocks=3, output_node_dim=16, output_edge_dim=8).build() + >>> decoder_output = decoder(encoder_output) + + Return Values + ------------- + The forward method returns a mapping object with the reconstructed node features and edge attributes. + + """ + + hidden_features: int + num_blocks: int + output_node_dim: int + output_edge_dim: int + upsamplings: Optional[Sequence[nn.Module]] + + def __init__( + self, + hidden_features: int, + num_blocks: int, + output_node_dim: int, + output_edge_dim: int, + upsamplings: Optional[Sequence[Union[Type[nn.Module], nn.Module]]] = None, + ): + super().__init__( + hidden_features = hidden_features, + output_node_dim = output_node_dim, + output_edge_dim = output_edge_dim, + num_blocks = num_blocks, + upsamplings = upsamplings, + ) + + if not isinstance(hidden_features, int) or hidden_features <= 0: + raise ValueError(f"hidden_features must be a positive integer, got {hidden_features}") + + if upsamplings is None: + upsamplings = [GlobalGraphUpsampling] + [MinCutUpsampling] * (num_blocks - 1) + + assert len(upsamplings) == num_blocks, "Number of upsamplings should match num_blocks." + + self.blocks = LayerList() + + for i in range(num_blocks): + upsample = upsamplings[i] + edge_index_map = "edge_index" if i == num_blocks-1 else f"edge_index_{num_blocks-1-i}" + select_input_map = f"s_{num_blocks-1-i}" + connect_input_map = f"edge_index_{num_blocks-i}" + + block = GraphDecoderBlock( + in_features=hidden_features, + out_features=hidden_features, + upsample=upsample, + edge_index_map=edge_index_map, + select_input_map=select_input_map, + connect_input_map=connect_input_map, + ) + + self.blocks.append(block) + + self.dense = Layer(nn.Linear, hidden_features, hidden_features) + self.dense.set_input_map('x') + self.dense.set_output_map('x') + + self.activate = Layer(nn.ReLU) + self.activate.set_input_map('x') + self.activate.set_output_map('x') + + # get the edge features: + self.get_edge_attr = GetEdgeFeatures( + combine=Cat(), + layer=Layer(nn.LazyLinear, hidden_features), + activation=Layer(nn.ReLU), + ) + self.get_edge_attr.set_input_map("x", "edge_index") + self.get_edge_attr.set_output_map("edge_attr") + + self.edge_mlp = MultiLayerPerceptron( + in_features=hidden_features, + hidden_features=[hidden_features], + out_features=output_edge_dim, + out_activation=None, + ) + self.edge_mlp.set_input_map("edge_attr") + self.edge_mlp.set_output_map("edge_attr") + + # get the node features: + self.node_mlp = MultiLayerPerceptron( + in_features = hidden_features, + hidden_features = [hidden_features, hidden_features], + out_features = output_node_dim, + out_activation = None, + ) + self.node_mlp.set_input_map('x') + self.node_mlp.set_output_map('x') + + def forward(self, x): + for block in self.blocks: + x = block(x) + + x = self.dense(x) + x = self.activate(x) + x = self.get_edge_attr(x) + x = self.edge_mlp(x) + x = self.node_mlp(x) + return x + + +class GraphEncoderBlock(DeeplayModule): + """ + A Graph Encoder Block that processes graph data through a Graph Convolutional Neural Network (GCN) + and applies pooling operations to generate encoded representations of the graph structure. + This block is a fundamental component of the GraphEncoder, enabling hierarchical feature extraction. + + Parameters + ---------- + in_features: int + The number of input features for each node in the graph. + out_features: int + The number of output features for each node after processing. + pool: Optional[template-like] + The pooling operation to be used. Defaults to MinCutPooling. + num_clusters: Optional[int] + The number of clusters for MinCutPooling. Must be provided if using MinCutPooling. + threshold: Optional[float] + Threshold value for pooling operations. + edge_index_map: Optional[str] + The mapping for edge index inputs. Defaults to "edge_index". + select_output_map: Optional[str] + The mapping for the selection outputs from the pooling layer. Defaults to "s". + connect_output_map: Optional[str] + The mapping for connecting outputs to subsequent layers. Defaults to "edge_index_pool". + batch_input_map: Optional[str] + The mapping for batch input. Defaults to "batch". + batch_output_map: Optional[str] + The mapping for batch output. Defaults to "batch". + mincut_cut_loss_map: Optional[str] + The mapping for saving the mincut cut loss. Defaults to "L_cut". + mincut_ortho_loss_map: Optional[str] + The mapping for saving the mincut orthogonallity loss. Defaults to "L_ortho". + + + Evaluation + ---------- + >>> block = dl.GraphEncoderBlock(in_features=16, out_features=16, num_clusters=5, threshold=0.1).build() + >>> inp = {} + >>> inp["x"] = torch.randn(10, 16) + >>> inp['batch'] = torch.zeros(10, dtype=int) + >>> inp["edge_index"] = torch.randint(0, 10, (2, 20)) + >>> inp["edge_attr"] = torch.randn(20, 8) + >>> output = block(inp) + """ + + in_features: Optional[int] + hidden_features: Sequence[Optional[int]] + out_features: int + pool: Optional[nn.Module] + num_clusters: Optional[int] + threshold: Optional[float] + edge_index_map: Optional[str] + select_output_map: Optional[str] + connect_output_map: Optional[str] + batch_input_map: Optional[str] + batch_output_map: Optional[str] + mincut_cut_loss_map: Optional[str] + mincut_ortho_loss_map: Optional[str] + + def __init__( + self, + in_features: int, + out_features: int, + pool: Optional[Union[Type[nn.Module], nn.Module, None]] = MinCutPooling, + num_clusters: Optional[int] = None, + threshold: Optional[float] = None, + edge_index_map: Optional[str] = "edge_index", + select_output_map: Optional[str] = "s", + connect_output_map: Optional[str] = "edge_index_pool", + batch_input_map: Optional[str] = "batch", + batch_output_map: Optional[str] = "batch", + mincut_cut_loss_map: Optional[str] = 'L_cut', + mincut_ortho_loss_map: Optional[str] = 'L_ortho', + ): + super().__init__( + in_features = in_features, + num_clusters = num_clusters, + threshold = threshold, + out_features = out_features, + pool = pool, + ) + self.edge_index_map = edge_index_map + self.connect_output_map = connect_output_map + + self.gcn = GraphConvolutionalNeuralNetwork( + in_features=in_features, + hidden_features=[], + out_features=out_features, + out_activation=nn.ReLU, + ) + + self.gcn.normalize.set_input_map('x', edge_index_map) + + if pool == MinCutPooling: + if num_clusters is None: + raise ValueError("num_clusters must be provided for MinCutPooling") + + self.pool = pool(hidden_features=[out_features], num_clusters=num_clusters, threshold=threshold) + self.pool.mincut_loss.set_input_map(edge_index_map, select_output_map) + self.pool.mincut_loss.set_output_map(mincut_cut_loss_map, mincut_ortho_loss_map) + else: + self.pool = pool() + + self.pool.select.set_output_map(select_output_map) + + if hasattr(self.pool, 'reduce'): + self.pool.reduce.set_input_map('x', select_output_map) + if hasattr(self.pool, 'batch_compatible'): + self.pool.batch_compatible.set_input_map(select_output_map, batch_input_map) + self.pool.batch_compatible.set_output_map(select_output_map, batch_output_map) + if hasattr(self.pool, 'connect'): + self.pool.connect.set_input_map(edge_index_map, select_output_map) + self.pool.connect.set_output_map(connect_output_map) + if hasattr(self.pool, 'red_self_con') and self.pool.red_self_con is not None: + self.pool.red_self_con.set_input_map(connect_output_map) + self.pool.red_self_con.set_output_map(connect_output_map) + if hasattr(self.pool, 'apply_threshold') and self.pool.apply_threshold is not None: + self.pool.apply_threshold.set_input_map(connect_output_map) + self.pool.apply_threshold.set_output_map(connect_output_map) + if hasattr(self.pool, 'sparse'): + self.pool.sparse.set_input_map(connect_output_map) + self.pool.sparse.set_output_map(connect_output_map) + + def forward(self, x): + x = self.gcn(x) + x = self.pool(x) + return x + + +class GraphDecoderBlock(DeeplayModule): + """ + A Graph Decoder Block that upsamples a graph and applies a Graph Convolutional Neural Network (GCN). + This block is a fundamental component of the GraphDecoder, enabling the reconstruction of graph features + in a Graph Encoder Decoder model. + + + Parameters + ---------- + in_features: int + The number of input features for each node in the graph. + out_features: int + The number of output features for each node after processing. + upsample: Optional[template-like] + The upsampling operation to be used. Defaults to MinCutUpsampling. + edge_index_map: Optional[str] + The mapping for edge index inputs. Defaults to "edge_index". + select_input_map: Optional[str] + The mapping for selection inputs for the upsampling layer. Defaults to "s". + connect_input_map: Optional[str] + The mapping for the connectivity for the upsampling layer. Defaults to "edge_index_pool". + connect_output_map: Optional[str] + The mapping for the connectivity outputs of the upsampling layer. Defaults to "-". + batch_map: Optional[str] + The mapping for batch inputs or outputs. Defaults to "batch". + + + Evaluation + ---------- + >>> encoderblock = dl.GraphEncoderBlock(in_features=16, out_features=16, num_clusters=5, threshold=0.2).build() + >>> inp = {} + >>> inp["x"] = torch.randn(10, 16) + >>> inp['batch'] = torch.zeros(10, dtype=int) + >>> inp["edge_index"] = torch.randint(0, 10, (2, 20)) + >>> inp["edge_attr"] = torch.randn(20, 8) + >>> encoderblock_output = encoderblock(inp) + >>> decoderblock = dl.GraphDecoderBlock(in_features=16, out_features=16).build() + >>> decoderblock_output = decoderblock(encoderblock_output) + """ + in_features: int + out_features: int + upsample: Optional[nn.Module] + edge_index_map: Optional[str] + select_input_map: Optional[str] + connect_input_map: Optional[str] + connect_output_map: Optional[str] + batch_map: Optional[str] + + def __init__( + self, + in_features: int, + out_features: int, + upsample: Optional[Union[Type[nn.Module], nn.Module, None]] = MinCutUpsampling, + edge_index_map: Optional[str] = "edge_index", + select_input_map: Optional[str] = "s", + connect_input_map: Optional[str] = "edge_index_pool", + connect_output_map: Optional[str] = "-", + ): + super().__init__( + in_features = in_features, + out_features = out_features, + upsample = upsample, + edge_index_map=edge_index_map, + select_input_map=select_input_map, + connect_input_map=connect_input_map, + ) + + if upsample == MinCutUpsampling: + self.upsample = upsample() + self.upsample.upsample.set_input_map('x', connect_input_map, select_input_map) + self.upsample.upsample.set_output_map('x', connect_output_map) + + else: + self.upsample = upsample() + self.upsample.upsample.set_input_map('x', select_input_map) + + self.gcn = GraphConvolutionalNeuralNetwork( + in_features=in_features, + hidden_features=[], + out_features=out_features, + out_activation=nn.ReLU, + ) + + self.gcn.normalize.set_input_map('x', edge_index_map) + + def forward(self, x): + x = self.upsample(x) + x = self.gcn(x) + return x \ No newline at end of file diff --git a/deeplay/components/gnn/mpn/__init__.py b/deeplay/components/gnn/mpn/__init__.py index be5bf416..fbd18602 100644 --- a/deeplay/components/gnn/mpn/__init__.py +++ b/deeplay/components/gnn/mpn/__init__.py @@ -4,6 +4,7 @@ from .transformation import * from .propagation import Sum, WeightedSum, Mean, Max, Min, Prod from .update import * +from .get_edge_features import * from .cla import CombineLayerActivation from .ldw import LearnableDistancewWeighting diff --git a/deeplay/components/gnn/mpn/get_edge_features.py b/deeplay/components/gnn/mpn/get_edge_features.py new file mode 100644 index 00000000..a284de39 --- /dev/null +++ b/deeplay/components/gnn/mpn/get_edge_features.py @@ -0,0 +1,16 @@ +from .cla import CombineLayerActivation + + +class GetEdgeFeatures(CombineLayerActivation): + """""" + + def get_forward_args(self, x): # maybe use Tranform instead, and just take the first two ouputs + """Get the arguments for the ... module. + An MPN ... module takes the following arguments: + - node features of sender nodes (x[A[0]]) + - node features of receiver nodes (x[A[1]]) + + A is the adjacency matrix of the graph. + """ + x, edge_index = x + return x[edge_index[0]], x[edge_index[1]] diff --git a/deeplay/components/gnn/pooling/__init__.py b/deeplay/components/gnn/pooling/__init__.py new file mode 100644 index 00000000..1e2c8651 --- /dev/null +++ b/deeplay/components/gnn/pooling/__init__.py @@ -0,0 +1,2 @@ +from .mincut import MinCutPooling +from .graph_pool import * \ No newline at end of file diff --git a/deeplay/components/gnn/pooling/graph_pool.py b/deeplay/components/gnn/pooling/graph_pool.py new file mode 100644 index 00000000..df1a6554 --- /dev/null +++ b/deeplay/components/gnn/pooling/graph_pool.py @@ -0,0 +1,147 @@ +from typing import Optional + +import torch.nn as nn +import torch +from deeplay.module import DeeplayModule + +class GlobalGraphPooling(DeeplayModule): + """ + Pools all the nodes of the graph to a single cluster. + + (Inspired by MinCut-pooling ('Spectral Clustering with Graph Neural Networks for Graph Pooling'): + but with the assignment matrix S being deterministic (all nodes are pooled into one cluster)) + + Input + ----- + X: float (Any, Any) #(number of nodes, number of features) + + Output + ------ + X: float (1, Any) #(number of clusters, number of features) + S: float (Any, 1) #(number of nodes, number of clusters) + """ + # select_output_map: Optional[str] + + def __init__( + self, + # select_output_map: Optional[str] = "s", + ): + super().__init__() + + # self.select_output_map = select_output_map + + class Select(DeeplayModule): + def forward(self, x): + return torch.ones((x.shape[0], 1)) # is this the right dim even if we use batches? + + class ClusterMatrixForBatch(DeeplayModule): + def forward(self, S, B): + unique_graphs = torch.unique(B) + num_graphs = len(unique_graphs) + + S_ = torch.zeros(S.shape[0] * num_graphs) + + row_indices = torch.arange(S.shape[0]) + col_indices = B + + S_[row_indices * num_graphs + col_indices] = S.view(-1) + B_ = torch.arange(num_graphs) + + return S_.reshape([S.shape[0], -1]), B_ + + + class Reduce(DeeplayModule): + def forward(self, x, s): + # return torch.sum(x, dim=0, keepdim=True) + return torch.matmul(s.transpose(-2,-1), x) + + self.select = Select() + self.select.set_input_map('x') + self.select.set_output_map('s') #self.select_output_map) + + self.batch_compatible = ClusterMatrixForBatch() + self.batch_compatible.set_input_map('s', 'batch') + self.batch_compatible.set_output_map('s', 'batch') + + self.reduce = Reduce() + self.reduce.set_input_map('x', 's') + self.reduce.set_output_map('x') + + def forward(self, x): + x = self.select(x) + x = self.batch_compatible(x) + x = self.reduce(x) + return (x) + + +class GlobalGraphUpsampling(DeeplayModule): + """ + Reverse of GlobalGraphPooling. + Only upsampling the node features. + """ + # select_input_map: Optional[str] + + def __init__( + self, + # select_input_map: Optional[str] = "s", + ): + super().__init__() + # self.select_input_map = select_input_map + + class Upsample(DeeplayModule): + def forward(self, x, s): + return torch.matmul(s, x) + + self.upsample = Upsample() + # self.upsample.set_input_map('x_pool', 's') + self.upsample.set_input_map('x', 's')#self.select_input_map) + self.upsample.set_output_map('x') + + def forward(self, x): + x = self.upsample(x) + return x + + +class MinCutUpsampling(DeeplayModule): + """ + Reverse of MinCutPooling as described in 'Spectral Clustering with Graph Neural Networks for Graph Pooling'. + """ + # select_input_map: Optional[str] + # connect_input_map: Optional[str] + + def __init__( + self, + # select_input_map: Optional[str] = "s", + # connect_input_map: Optional[str] = "edge_index", + ): + super().__init__() + # self.select_input_map = select_input_map + # self.connect_input_map = connect_input_map + + class Upsample(DeeplayModule): + def forward(self, x_pool, a_pool, s): + x = torch.matmul(s, x_pool) + + if a_pool.is_sparse: + a = torch.spmm(a_pool, s.T) + elif (not a_pool.is_sparse) & (a_pool.size(0) == 2): + a_pool = torch.sparse_coo_tensor( + a_pool, + torch.ones(a_pool.size(1)), + ((s.T).size(0),) * 2, + device=a_pool.device, + ) + a = torch.spmm(a_pool, s.T) + elif (not a_pool.is_sparse) & len({a_pool.size(0), a_pool.size(1), (s.T).size(0)}) == 1: + a = a_pool.type(s.dtype) @ s.T + + return x, a + + self.upsample = Upsample() + self.upsample.set_input_map('x', 'edge_index_pool', 's') + # self.upsample.set_input_map('x', self.connect_input_map, self.select_input_map) + self.upsample.set_output_map('x', 'edge_index_') + + def forward(self, x): + x = self.upsample(x) + return x \ No newline at end of file diff --git a/deeplay/components/gnn/pooling/mincut.py b/deeplay/components/gnn/pooling/mincut.py new file mode 100644 index 00000000..0df14dd4 --- /dev/null +++ b/deeplay/components/gnn/pooling/mincut.py @@ -0,0 +1,241 @@ +from typing import Sequence, Optional + +from deeplay import DeeplayModule + +from deeplay.components.mlp import MultiLayerPerceptron + +import torch +import torch.nn as nn + +class MinCutPooling(DeeplayModule): + """ + MinCut graph pooling as described in 'Spectral Clustering with Graph Neural Networks for Graph Pooling'. + + Parameters + ---------- + num_clusters: int + The number of clusters to which each graph is pooled. + hidden_features: Sequence[int] + The number of hidden features for the Multi-Layer Perceptron (MLP) used for selecting clusters for the pooling. + reduce_self_connection: Optional[bool] + Whether to reduce self-connections in the adjacency matrix. Defaults to True. + threshold: Optional[float] + A threshold value to apply to the adjacency matrix to binarize the edges. If None, no threshold is applied. + + Evaluation + ---------- + >>> MinCut = dl.components.gnn.pooling.MinCutPooling(hidden_features = [8], num_clusters = 5, reduce_self_connection = True, threshold = 0.25).build() + >>> inp = {} + >>> inp["x"] = torch.randn(10, 16) + >>> inp['batch'] = torch.zeros(10, dtype=int) + >>> inp["edge_index"] = torch.randint(0, 10, (2, 20)) + >>> output = MinCut(inp) + + """ + + num_clusters: int + hidden_features: Sequence[int] + reduce_self_connection: Optional[bool] + threshold: Optional[float] + + def __init__( + self, + num_clusters: int, + hidden_features: Sequence[int], + reduce_self_connection: Optional[bool] = True, + threshold: Optional[float] = None, + ): + super().__init__() + + self.num_clusters = num_clusters + self.reduce_self_connection = reduce_self_connection + self.threshold = threshold + + class ClusterMatrixForBatch(DeeplayModule): + def forward(self, S, B): + + unique_graphs = torch.unique(B) + num_graphs = len(unique_graphs) + + S_ = torch.zeros(S.shape[0] * S.shape[1] * num_graphs) + + row_indices = torch.arange(S.shape[0]).repeat_interleave(S.shape[1]) + col_indices = B.repeat_interleave(S.shape[1]) * S.shape[1] + torch.arange(S.shape[1]).repeat(S.shape[0]) + + S_[row_indices * (S.shape[1] * num_graphs) + col_indices] = S.view(-1) + + B_ = torch.arange(num_graphs).repeat_interleave(S.shape[1]) + + return S_.reshape([S.shape[0], -1]), B_ + + class Reduce(DeeplayModule): + def forward(self, x, s): + return torch.matmul(s.transpose(-2,-1), x) + + class Connect(DeeplayModule): + def forward(self, A, s): + if A.is_sparse: + return torch.spmm(s.transpose(-2,-1), torch.spmm(A, s)) + elif (not A.is_sparse) & (A.size(0) == 2): + A = torch.sparse_coo_tensor( + A, + torch.ones(A.size(1)), + (s.size(0),) * 2, + device=A.device, + ) + return torch.spmm(s.transpose(-2,-1), torch.spmm(A, s)) + elif (not A.is_sparse) & len({A.size(0), A.size(1), s.size(0)}) == 1: + return s.transpose(-2,-1) @ A.type(s.dtype) @ s + else: + raise ValueError( + "Unsupported adjacency matrix format.", + "Ensure it is a pytorch sparse tensor, an edge index tensor, or a square dense tensor.", + "Consider updating the propagate layer to handle alternative formats.", + ) + + class ReduceSelfConnection(DeeplayModule): + def __init__( + self, + eps: Optional[float] = 1e-15, + ): + super().__init__() + self.eps = eps + + def forward(self, A): + ind = torch.arange(A.shape[0]) + A[ind, ind] = 0 + d = torch.einsum('jk->j', A) + d = torch.sqrt(d)[None] #+ self.eps + d = torch.where(torch.isinf(d), torch.tensor(0.0, device=d.device), d) # Replace infinities in `d` with 0 + A = (A / d) / d.transpose(-2,-1) + return A + + class MinCutLoss(DeeplayModule): + def forward(self, A, S): + n_nodes = S.size(0) # number of nodes + n_clusters = S.size(1) # number of clusters in total (= number of clusters per graph * num graphs) + + if A.is_sparse: + degree = torch.sum(A, dim=0) + elif (not A.is_sparse) & (A.size(0) == 2): + A = torch.sparse_coo_tensor( + A, + torch.ones(A.size(1)), + (n_nodes,) * 2, + device=A.device, + ) + degree = torch.sum(A, dim=0) + elif (not A.is_sparse) & len({A.size(0), A.size(1)}) == 1: + degree = torch.sum(A, dim=0) + else: + raise ValueError( + "Unsupported adjacency matrix format.", + "Ensure it is a pytorch sparse tensor, an edge index tensor, or a square dense tensor.", + "Consider updating the propagate layer to handle alternative formats.", + ) + + D = torch.eye(n_nodes) * degree + + denominator_trace = torch.trace(torch.matmul(S.transpose(-2, -1), torch.matmul(D, S))) + denominator_trace[denominator_trace == float('inf')] = 0 # Replace infinities in `D` with 0 + + # cut loss: + # L_cut = - torch.trace(torch.matmul(S.transpose(-2,-1), torch.matmul(A, S))) / (torch.trace(torch.matmul(S.transpose(-2,-1), torch.matmul(D, S)))) + L_cut = - torch.trace(torch.matmul(S.transpose(-2,-1), torch.matmul(A, S))) / denominator_trace + + # orthogonality loss: + L_ortho = torch.linalg.norm( + (torch.matmul(S.transpose(-2,-1), S) / torch.linalg.norm(torch.matmul(S.transpose(-2,-1), S), ord = 'fro')) + - (torch.eye(n_clusters) / torch.sqrt(torch.tensor(n_clusters))), + ord = 'fro') + + + return L_cut, L_ortho + + class ApplyThreshold(DeeplayModule): + def __init__(self, threshold: float): + super().__init__() + self.threshold = threshold + + def forward(self, A): + return torch.where(A >= threshold, 1.0, 0.0) + + + class SparseEdgeIndex(DeeplayModule): + """ output edge index as a sparse tensor """ + def forward(self, A): + if A.is_sparse: + edge_index = A + return edge_index + else: + edge_index = A.to_sparse() + return edge_index + + + # select: S = MLP(X) + self.select = MultiLayerPerceptron( + in_features=None, + hidden_features=hidden_features, + out_features=num_clusters, + out_activation=nn.Softmax(dim=1)) + self.select.set_input_map("x") + self.select.set_output_map('s') + + # make S compatible with batches: + self.batch_compatible = ClusterMatrixForBatch() + self.batch_compatible.set_input_map("s", "batch") + self.batch_compatible.set_output_map("s", "batch") + + # mincut loss + self.mincut_loss = MinCutLoss() + self.mincut_loss.set_input_map('edge_index', 's') + self.mincut_loss.set_output_map('L_cut', 'L_ortho') + + # reduce: X' = S^T * X + self.reduce = Reduce() + self.reduce.set_input_map("x", 's') + self.reduce.set_output_map("x") + + # connect: A' = S^T * A * S + self.connect = Connect() + self.connect.set_input_map('edge_index', 's') + self.connect.set_output_map("edge_index") + + # reduce self connection + self.red_self_con = None + if reduce_self_connection: + self.red_self_con = ReduceSelfConnection(self.num_clusters) + self.red_self_con.set_input_map('edge_index') + self.red_self_con.set_output_map('edge_index') + + # apply threshold to A + self.apply_threshold = None + if threshold is not None: + self.apply_threshold = ApplyThreshold(self.threshold) + self.apply_threshold.set_input_map('edge_index') + self.apply_threshold.set_output_map('edge_index') + + # # make A sparse + self.sparse = SparseEdgeIndex() + self.sparse.set_input_map('edge_index') + self.sparse.set_output_map('edge_index') + + + def forward(self, x): + x = self.select(x) + x = self.batch_compatible(x) + x = self.mincut_loss(x) + x = self.reduce(x) + x = self.connect(x) + + if self.red_self_con is not None: + x = self.red_self_con(x) + + if self.apply_threshold is not None: + x = self.apply_threshold(x) + + x = self.sparse(x) + + return x + + \ No newline at end of file From 7d53b7895a63e0f2d715fad2f32b42c18a433760 Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Tue, 12 Nov 2024 14:51:53 +0100 Subject: [PATCH 3/8] Mg/Graph autoencoder --- deeplay/applications/autoencoders/vgae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deeplay/applications/autoencoders/vgae.py b/deeplay/applications/autoencoders/vgae.py index 3e7a3e62..dd394bb6 100644 --- a/deeplay/applications/autoencoders/vgae.py +++ b/deeplay/applications/autoencoders/vgae.py @@ -31,9 +31,9 @@ def __init__( decoder: Optional[nn.Module] = None, reconstruction_loss: Optional[Callable] = nn.L1Loss(), latent_dim=int, - alpha: Optional[int] = 0, + alpha: Optional[int] = 1, beta: Optional[int] = 1e-7, - gamma: Optional[int] = 10, + gamma: Optional[int] = 1, delta: Optional[int] = 1, optimizer=None, **kwargs, From 5c11cb6205e565330a6a862841893b680096702b Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Tue, 12 Nov 2024 15:53:10 +0100 Subject: [PATCH 4/8] smaller corrections --- deeplay/components/gnn/gcn/normalization.py | 3 +-- deeplay/components/gnn/graphencdec.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/deeplay/components/gnn/gcn/normalization.py b/deeplay/components/gnn/gcn/normalization.py index 98c2ddf5..7d44226e 100644 --- a/deeplay/components/gnn/gcn/normalization.py +++ b/deeplay/components/gnn/gcn/normalization.py @@ -11,9 +11,8 @@ def add_self_loops(self, A, num_nodes): """ loop_index = torch.arange(num_nodes, device=A.device) loop_index = loop_index.unsqueeze(0).repeat(2, 1) - - if A.is_sparse: # changed this to ensure that it works even if the format of A varies. Previous: A = torch.cat([A, loop_index], dim=1) + if A.is_sparse: A = torch.cat([A.indices(), loop_index], dim=1) elif (not A.is_sparse) & (A.size(0) == 2): A = torch.cat([A, loop_index], dim=1) diff --git a/deeplay/components/gnn/graphencdec.py b/deeplay/components/gnn/graphencdec.py index 657acb02..259514e6 100644 --- a/deeplay/components/gnn/graphencdec.py +++ b/deeplay/components/gnn/graphencdec.py @@ -10,7 +10,6 @@ from deeplay.components.gnn import GraphConvolutionalNeuralNetwork, MessagePassingNeuralNetwork from deeplay.components.gnn.pooling import MinCutPooling from deeplay.ops import Cat -# from deeplay.deeplay.components.gnn.pooling.graph_pool import GlobalGraphPooling, GlobalGraphUpsampling, MinCutUpsampling from deeplay.components.gnn.pooling import GlobalGraphPooling, GlobalGraphUpsampling, MinCutUpsampling from deeplay.components.gnn.mpn import GetEdgeFeatures from deeplay.components.mlp import MultiLayerPerceptron From 0700ceeb0c3e421736ca498c7918e183532b92b4 Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Thu, 14 Nov 2024 16:04:03 +0100 Subject: [PATCH 5/8] Mg/Graph autoencoder --- deeplay/components/gnn/graphencdec.py | 4 +- .../components/gnn/mpn/get_edge_features.py | 11 ++-- deeplay/components/gnn/pooling/__init__.py | 4 +- .../pooling/{graph_pool.py => globalpool.py} | 62 ++--------------- deeplay/components/gnn/pooling/mincut.py | 66 ++++++++++++++++--- 5 files changed, 69 insertions(+), 78 deletions(-) rename deeplay/components/gnn/pooling/{graph_pool.py => globalpool.py} (51%) diff --git a/deeplay/components/gnn/graphencdec.py b/deeplay/components/gnn/graphencdec.py index 259514e6..e78ec4ac 100644 --- a/deeplay/components/gnn/graphencdec.py +++ b/deeplay/components/gnn/graphencdec.py @@ -8,9 +8,9 @@ LayerList, ) from deeplay.components.gnn import GraphConvolutionalNeuralNetwork, MessagePassingNeuralNetwork -from deeplay.components.gnn.pooling import MinCutPooling +from deeplay.components.gnn.pooling import MinCutPooling, MinCutUpsampling from deeplay.ops import Cat -from deeplay.components.gnn.pooling import GlobalGraphPooling, GlobalGraphUpsampling, MinCutUpsampling +from deeplay.components.gnn.pooling import GlobalGraphPooling, GlobalGraphUpsampling from deeplay.components.gnn.mpn import GetEdgeFeatures from deeplay.components.mlp import MultiLayerPerceptron diff --git a/deeplay/components/gnn/mpn/get_edge_features.py b/deeplay/components/gnn/mpn/get_edge_features.py index a284de39..05379ef8 100644 --- a/deeplay/components/gnn/mpn/get_edge_features.py +++ b/deeplay/components/gnn/mpn/get_edge_features.py @@ -4,13 +4,12 @@ class GetEdgeFeatures(CombineLayerActivation): """""" - def get_forward_args(self, x): # maybe use Tranform instead, and just take the first two ouputs - """Get the arguments for the ... module. - An MPN ... module takes the following arguments: - - node features of sender nodes (x[A[0]]) - - node features of receiver nodes (x[A[1]]) + def get_forward_args(self, x): + """Get the node features of neighboring nodes for each edge. + - node features of sender nodes (x[edge_index[0]]) + - node features of receiver nodes (x[edge_index[1]]) - A is the adjacency matrix of the graph. + edge_index denote the connectivity of the graph. """ x, edge_index = x return x[edge_index[0]], x[edge_index[1]] diff --git a/deeplay/components/gnn/pooling/__init__.py b/deeplay/components/gnn/pooling/__init__.py index 1e2c8651..e13828e0 100644 --- a/deeplay/components/gnn/pooling/__init__.py +++ b/deeplay/components/gnn/pooling/__init__.py @@ -1,2 +1,2 @@ -from .mincut import MinCutPooling -from .graph_pool import * \ No newline at end of file +from .mincut import * +from .globalpool import * \ No newline at end of file diff --git a/deeplay/components/gnn/pooling/graph_pool.py b/deeplay/components/gnn/pooling/globalpool.py similarity index 51% rename from deeplay/components/gnn/pooling/graph_pool.py rename to deeplay/components/gnn/pooling/globalpool.py index df1a6554..1f79093d 100644 --- a/deeplay/components/gnn/pooling/graph_pool.py +++ b/deeplay/components/gnn/pooling/globalpool.py @@ -20,19 +20,15 @@ class GlobalGraphPooling(DeeplayModule): X: float (1, Any) #(number of clusters, number of features) S: float (Any, 1) #(number of nodes, number of clusters) """ - # select_output_map: Optional[str] def __init__( self, - # select_output_map: Optional[str] = "s", ): super().__init__() - # self.select_output_map = select_output_map - class Select(DeeplayModule): def forward(self, x): - return torch.ones((x.shape[0], 1)) # is this the right dim even if we use batches? + return torch.ones((x.shape[0], 1)) class ClusterMatrixForBatch(DeeplayModule): def forward(self, S, B): @@ -52,12 +48,11 @@ def forward(self, S, B): class Reduce(DeeplayModule): def forward(self, x, s): - # return torch.sum(x, dim=0, keepdim=True) return torch.matmul(s.transpose(-2,-1), x) self.select = Select() self.select.set_input_map('x') - self.select.set_output_map('s') #self.select_output_map) + self.select.set_output_map('s') self.batch_compatible = ClusterMatrixForBatch() self.batch_compatible.set_input_map('s', 'batch') @@ -79,69 +74,20 @@ class GlobalGraphUpsampling(DeeplayModule): Reverse of GlobalGraphPooling. Only upsampling the node features. """ - # select_input_map: Optional[str] def __init__( self, - # select_input_map: Optional[str] = "s", ): super().__init__() - # self.select_input_map = select_input_map - + class Upsample(DeeplayModule): def forward(self, x, s): return torch.matmul(s, x) self.upsample = Upsample() - # self.upsample.set_input_map('x_pool', 's') - self.upsample.set_input_map('x', 's')#self.select_input_map) + self.upsample.set_input_map('x', 's') self.upsample.set_output_map('x') - def forward(self, x): - x = self.upsample(x) - return x - - -class MinCutUpsampling(DeeplayModule): - """ - Reverse of MinCutPooling as described in 'Spectral Clustering with Graph Neural Networks for Graph Pooling'. - """ - # select_input_map: Optional[str] - # connect_input_map: Optional[str] - - def __init__( - self, - # select_input_map: Optional[str] = "s", - # connect_input_map: Optional[str] = "edge_index", - ): - super().__init__() - # self.select_input_map = select_input_map - # self.connect_input_map = connect_input_map - - class Upsample(DeeplayModule): - def forward(self, x_pool, a_pool, s): - x = torch.matmul(s, x_pool) - - if a_pool.is_sparse: - a = torch.spmm(a_pool, s.T) - elif (not a_pool.is_sparse) & (a_pool.size(0) == 2): - a_pool = torch.sparse_coo_tensor( - a_pool, - torch.ones(a_pool.size(1)), - ((s.T).size(0),) * 2, - device=a_pool.device, - ) - a = torch.spmm(a_pool, s.T) - elif (not a_pool.is_sparse) & len({a_pool.size(0), a_pool.size(1), (s.T).size(0)}) == 1: - a = a_pool.type(s.dtype) @ s.T - - return x, a - - self.upsample = Upsample() - self.upsample.set_input_map('x', 'edge_index_pool', 's') - # self.upsample.set_input_map('x', self.connect_input_map, self.select_input_map) - self.upsample.set_output_map('x', 'edge_index_') - def forward(self, x): x = self.upsample(x) return x \ No newline at end of file diff --git a/deeplay/components/gnn/pooling/mincut.py b/deeplay/components/gnn/pooling/mincut.py index 0df14dd4..9766e1ca 100644 --- a/deeplay/components/gnn/pooling/mincut.py +++ b/deeplay/components/gnn/pooling/mincut.py @@ -104,13 +104,19 @@ def __init__( def forward(self, A): ind = torch.arange(A.shape[0]) A[ind, ind] = 0 - d = torch.einsum('jk->j', A) - d = torch.sqrt(d)[None] #+ self.eps - d = torch.where(torch.isinf(d), torch.tensor(0.0, device=d.device), d) # Replace infinities in `d` with 0 - A = (A / d) / d.transpose(-2,-1) + D = torch.einsum('jk->j', A) + D = torch.sqrt(D)[None] + self.eps + A = (A / D) / D.transpose(-2,-1) return A class MinCutLoss(DeeplayModule): + def __init__( + self, + eps: Optional[float] = 1e-15, + ): + super().__init__() + self.eps = eps + def forward(self, A, S): n_nodes = S.size(0) # number of nodes n_clusters = S.size(1) # number of clusters in total (= number of clusters per graph * num graphs) @@ -132,16 +138,18 @@ def forward(self, A, S): "Unsupported adjacency matrix format.", "Ensure it is a pytorch sparse tensor, an edge index tensor, or a square dense tensor.", "Consider updating the propagate layer to handle alternative formats.", - ) + ) - D = torch.eye(n_nodes) * degree + eps = torch.sparse_coo_tensor( + indices=torch.arange(n_nodes).repeat(2, 1), + values=torch.zeros(n_nodes) + self.eps, + size=(n_nodes, n_nodes), + ) - denominator_trace = torch.trace(torch.matmul(S.transpose(-2, -1), torch.matmul(D, S))) - denominator_trace[denominator_trace == float('inf')] = 0 # Replace infinities in `D` with 0 + D = torch.eye(n_nodes) * degree + eps # cut loss: - # L_cut = - torch.trace(torch.matmul(S.transpose(-2,-1), torch.matmul(A, S))) / (torch.trace(torch.matmul(S.transpose(-2,-1), torch.matmul(D, S)))) - L_cut = - torch.trace(torch.matmul(S.transpose(-2,-1), torch.matmul(A, S))) / denominator_trace + L_cut = - torch.trace(torch.matmul(S.transpose(-2,-1), torch.matmul(A, S))) / (torch.trace(torch.matmul(S.transpose(-2,-1), torch.matmul(D, S)))) # orthogonality loss: L_ortho = torch.linalg.norm( @@ -238,4 +246,42 @@ def forward(self, x): return x + +class MinCutUpsampling(DeeplayModule): + """ + Reverse of MinCutPooling as described in 'Spectral Clustering with Graph Neural Networks for Graph Pooling'. + """ + + def __init__( + self, + ): + super().__init__() + + class Upsample(DeeplayModule): + def forward(self, x_pool, a_pool, s): + x = torch.matmul(s, x_pool) + + if a_pool.is_sparse: + a = torch.spmm(a_pool, s.T) + elif (not a_pool.is_sparse) & (a_pool.size(0) == 2): + a_pool = torch.sparse_coo_tensor( + a_pool, + torch.ones(a_pool.size(1)), + ((s.T).size(0),) * 2, + device=a_pool.device, + ) + a = torch.spmm(a_pool, s.T) + elif (not a_pool.is_sparse) & len({a_pool.size(0), a_pool.size(1), (s.T).size(0)}) == 1: + a = a_pool.type(s.dtype) @ s.T + + return x, a + + self.upsample = Upsample() + self.upsample.set_input_map('x', 'edge_index_pool', 's') + self.upsample.set_output_map('x', 'edge_index_') + + def forward(self, x): + x = self.upsample(x) + return x + \ No newline at end of file From 4d08a7b0042bfeae4e3b468e7aa3312bd638ba6d Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Fri, 15 Nov 2024 17:19:51 +0100 Subject: [PATCH 6/8] change of mappings in gcn --- deeplay/components/gnn/gcn/gcn.py | 4 ++-- deeplay/components/gnn/graphencdec.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/deeplay/components/gnn/gcn/gcn.py b/deeplay/components/gnn/gcn/gcn.py index dd60ab4c..0490b67e 100644 --- a/deeplay/components/gnn/gcn/gcn.py +++ b/deeplay/components/gnn/gcn/gcn.py @@ -82,7 +82,7 @@ def __init__( self.normalize = Layer(sparse_laplacian_normalization) self.normalize.set_input_map("x", "edge_index") - self.normalize.set_output_map("laplacian") + self.normalize.set_output_map("edge_index") class Propagate(DeeplayModule): def forward(self, x, A): @@ -115,7 +115,7 @@ def forward(self, x, A): transform.set_output_map("x") propagate = Layer(Propagate) - propagate.set_input_map("x", "laplacian") + propagate.set_input_map("x", "edge_index") propagate.set_output_map("x") update = Layer(nn.ReLU) if i < len(self.hidden_features) else out_activation diff --git a/deeplay/components/gnn/graphencdec.py b/deeplay/components/gnn/graphencdec.py index e78ec4ac..48766651 100644 --- a/deeplay/components/gnn/graphencdec.py +++ b/deeplay/components/gnn/graphencdec.py @@ -407,6 +407,8 @@ def __init__( ) self.gcn.normalize.set_input_map('x', edge_index_map) + self.gcn.normalize.set_output_map('laplacian') + self.gcn.propagate.set_input_map("x", "laplacian") if pool == MinCutPooling: if num_clusters is None: @@ -528,6 +530,8 @@ def __init__( ) self.gcn.normalize.set_input_map('x', edge_index_map) + self.gcn.normalize.set_output_map('laplacian') + self.gcn.propagate.set_input_map("x", "laplacian") def forward(self, x): x = self.upsample(x) From d8532f51a702bd7fee48806f1a544e915dd66b39 Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Thu, 21 Nov 2024 14:44:33 +0100 Subject: [PATCH 7/8] Added tests and documentation --- deeplay/applications/autoencoders/vgae.py | 57 +- deeplay/components/gnn/__init__.py | 1 + deeplay/components/gnn/graphencdec.py | 147 +-- deeplay/components/gnn/pooling/__init__.py | 4 +- deeplay/components/gnn/pooling/globalpool.py | 44 +- deeplay/components/gnn/pooling/mincut.py | 49 +- deeplay/tests/test_gnn.py | 124 +++ deeplay/tests/test_testing.ipynb | 1043 ++++++++++++++++++ 8 files changed, 1365 insertions(+), 104 deletions(-) create mode 100644 deeplay/tests/test_testing.ipynb diff --git a/deeplay/applications/autoencoders/vgae.py b/deeplay/applications/autoencoders/vgae.py index dd394bb6..b0aeae85 100644 --- a/deeplay/applications/autoencoders/vgae.py +++ b/deeplay/applications/autoencoders/vgae.py @@ -9,46 +9,71 @@ Layer, ) - import torch import torch.nn as nn class VariationalGraphAutoEncoder(Application): - channels: list + """ Variational Auto Encoder for Graphs + + Parameters + ---------- + encoder : nn.Module + decoder : nn.Module + hidden_features : int + Number of features of the hidden layers latent_dim: int + Number of latent dimensions + alpha: float + Weighting for the node feature reconstruction loss. Defaults to 1 + beta: float + Weighting for the KL loss. Defaults to 1e-7 + gamma: float + Weighting for the edge feature reconstruction loss. Defaults to 1 + delta: float + Weighting for the MinCut loss. Defaults to 1 + reconstruction_loss: Reconstruction loss + Loss metric for the reconstruction of the node and edge features. Defaults to L1 (Mean absolute error). + optimizer: Optimizer + Optimizer to use for training. + """ + encoder: torch.nn.Module decoder: torch.nn.Module + hidden_features: int + latent_dim: int + alpha: float beta: float + gamma: float + delta: float reconstruction_loss: torch.nn.Module - metrics: list optimizer: Optimizer def __init__( self, - channels: Optional[int] = 96, encoder: Optional[nn.Module] = None, decoder: Optional[nn.Module] = None, - reconstruction_loss: Optional[Callable] = nn.L1Loss(), + hidden_features: Optional[int] = 96, latent_dim=int, - alpha: Optional[int] = 1, - beta: Optional[int] = 1e-7, - gamma: Optional[int] = 1, - delta: Optional[int] = 1, + alpha: Optional[float] = 1, + beta: Optional[float] = 1e-7, + gamma: Optional[float] = 1, + delta: Optional[float] = 1, + reconstruction_loss: Optional[Callable] = nn.L1Loss(), optimizer=None, **kwargs, ): self.encoder = encoder - self.fc_mu = Layer(nn.Linear, channels, latent_dim) + self.fc_mu = Layer(nn.Linear, hidden_features, latent_dim) self.fc_mu.set_input_map('x') self.fc_mu.set_output_map('mu') - self.fc_var = Layer(nn.Linear, channels, latent_dim) + self.fc_var = Layer(nn.Linear, hidden_features, latent_dim) self.fc_var.set_input_map('x') self.fc_var.set_output_map('log_var') - self.fc_dec = Layer(nn.Linear, latent_dim, channels) + self.fc_dec = Layer(nn.Linear, latent_dim, hidden_features) self.fc_dec.set_input_map('z') self.fc_dec.set_output_map('x') @@ -56,10 +81,10 @@ def __init__( self.reconstruction_loss = reconstruction_loss or nn.L1Loss() self.latent_dim = latent_dim - self.alpha = alpha # node feature reconstruction loss weight - self.beta = beta # KL loss weight - self.gamma = gamma # edge feature reconstruction loss weight - self.delta = delta # MinCut loss weight + self.alpha = alpha + self.beta = beta + self.gamma = gamma + self.delta = delta super().__init__(**kwargs) diff --git a/deeplay/components/gnn/__init__.py b/deeplay/components/gnn/__init__.py index 1bf368e6..f6b57a88 100644 --- a/deeplay/components/gnn/__init__.py +++ b/deeplay/components/gnn/__init__.py @@ -1,4 +1,5 @@ from .gcn import * from .mpn import * from .tpu import * +from .pooling import * from .graphencdec import GraphEncoderBlock, GraphDecoderBlock, GraphEncoder, GraphDecoder \ No newline at end of file diff --git a/deeplay/components/gnn/graphencdec.py b/deeplay/components/gnn/graphencdec.py index 48766651..06f1fa16 100644 --- a/deeplay/components/gnn/graphencdec.py +++ b/deeplay/components/gnn/graphencdec.py @@ -32,23 +32,27 @@ class GraphEncoder(DeeplayModule): The number of clusters the graph is pooled to in each processing block. thresholds: list[float] The threshold values for the connectivity in the clustering process. - poolings: template-like - A list of pooling layers to use. Default is using MinCut pooling for all layers, except for the - last, which is global pooling. - save_intermediates: bool - Flag indicating whether to save intermediate adjacency matrices and other information, useful - when using it together with the GraphDecoder. Default is True. - Configurables ------------- - hidden features (int): Number of features of the hidden layers. - - num_blocks: (int): Number of processing blocks in the encoder. - - num_clusters: list[int]: Number of clusters the graph is pooled to in each processing block. + - num_blocks (int): Number of processing blocks in the encoder. + - num_clusters list[int]: Number of clusters the graph is pooled to in each processing block. - thresholds list[int]: The threshold values for the connectivity in the clustering process. + - poolings (template-like):A list of pooling layers to use. Default is using MinCut pooling for all layers, + except for the last, which is global pooling. + - save_intermediates (bool): Flag indicating whether to save intermediate adjacency matrices and other information, useful + when using it together with the GraphDecoder. Default is True. + Constraints + ----------- + - input: Dict[str, Any] or torch-geometric Data object containing the following attributes: + - x: torch.Tensor of shape (num_nodes, node_features) + - edge_index: torch.Tensor of shape (2, num_edges) + - batch: torch.Tensor of shape (num_nodes) + - edge_attr: torch.Tensor of shape (num_edges, edge_features) - Evaluation + Example ---------- >>> encoder = dl.GraphEncoder(hidden_features=96, num_blocks=3, num_clusters=[5, 3, 1], thresholds=[0.1, 0.2, None], save_intermediates=False).build() >>> inp = {} @@ -175,32 +179,39 @@ class GraphDecoder(DeeplayModule): by the GraphEncoder. This module aims to decode the latent graph features back into graph node and edge attributes. - Attributes: - hidden_features: int - The dimensionality of the hidden layers of the decoder. This should match the hidden - features from the corresponding GraphEncoder. - num_blocks: int - The number of processing blocks in the decoder. This should match the number of blocks - of the GraphEncoder. - output_node_dim: int - The dimensionality of the output node features. This should match the original dimensionallity - of the input node features of the GraphEncoder. - output_edge_dim: int - The dimensionality of the output edge features. This should match the original dimensionallity - of the input edge attributes of the GraphEncoder. - upsamplings: template-like - A list of upsampling layers to use. Default is using MinCut upsampling for all layers, except for the - first, which is global upsampling. This should reflect the pooling layers of the GraphEncoder. - + Parameters + ---------- + hidden_features: int + The dimensionality of the hidden layers of the decoder. This should match the hidden + features from the corresponding GraphEncoder. + num_blocks: int + The number of processing blocks in the decoder. This should match the number of blocks + of the GraphEncoder. + output_node_dim: int + The dimensionality of the output node features. This should match the original dimensionallity + of the input node features of the GraphEncoder. + output_edge_dim: int + The dimensionality of the output edge features. This should match the original dimensionallity + of the input edge attributes of the GraphEncoder. Configurables ------------- - hidden features (int): Number of features of the hidden layers. - - num_blocks: (int): Number of processing blocks in the decoder. - - output_node_dim: (int): Number of dimensions of the output node features. - - output_edge_dim: (int): Number of dimensions of the output edge attributes. + - num_blocks (int): Number of processing blocks in the decoder. + - output_node_dim (int): Number of dimensions of the output node features. + - output_edge_dim (int): Number of dimensions of the output edge attributes. + - upsamplings (template-like): A list of upsampling layers to use. Default is using MinCut upsampling + for all layers, except for the first, which is global upsampling. This should reflect the pooling + layers of the GraphEncoder. + + Constraints + ----------- + - input: Dict[str, Any] or torch-geometric Data object containing the following attributes: + - x: torch.Tensor of shape (num_nodes, node_features) + - edge_index: torch.Tensor of shape (2, num_edges) + - batch: torch.Tensor of shape (num_nodes) - Evaluation + Example ---------- >>> encoder = dl.GraphEncoder(hidden_features=96, num_blocks=3, num_clusters=[20, 5, 1], thresholds=[0.1, 0.5, None], save_intermediates=False).build() >>> inp = {} @@ -327,29 +338,31 @@ class GraphEncoderBlock(DeeplayModule): The number of input features for each node in the graph. out_features: int The number of output features for each node after processing. - pool: Optional[template-like] - The pooling operation to be used. Defaults to MinCutPooling. - num_clusters: Optional[int] - The number of clusters for MinCutPooling. Must be provided if using MinCutPooling. - threshold: Optional[float] - Threshold value for pooling operations. - edge_index_map: Optional[str] - The mapping for edge index inputs. Defaults to "edge_index". - select_output_map: Optional[str] - The mapping for the selection outputs from the pooling layer. Defaults to "s". - connect_output_map: Optional[str] - The mapping for connecting outputs to subsequent layers. Defaults to "edge_index_pool". - batch_input_map: Optional[str] - The mapping for batch input. Defaults to "batch". - batch_output_map: Optional[str] - The mapping for batch output. Defaults to "batch". - mincut_cut_loss_map: Optional[str] - The mapping for saving the mincut cut loss. Defaults to "L_cut". - mincut_ortho_loss_map: Optional[str] - The mapping for saving the mincut orthogonallity loss. Defaults to "L_ortho". - - Evaluation + Configurables + ------------- + - in_features (int): The number of input features for each node in the graph. + - out_features (int): The number of output features for each node after processing. + - pool (template-like): The pooling operation to be used. Defaults to MinCutPooling. + - num_clusters (int): The number of clusters for MinCutPooling. Must be provided if using MinCutPooling. + - threshold (float): Threshold value for pooling operations. + - edge_index_map (str): The mapping for edge index inputs. Defaults to "edge_index". + - select_output_map (str): The mapping for the selection outputs from the pooling layer. Defaults to "s". + - connect_output_map (str): The mapping for connecting outputs to subsequent layers. Defaults to "edge_index_pool". + - batch_input_map (str): The mapping for batch input. Defaults to "batch". + - batch_output_map (str): The mapping for batch output. Defaults to "batch". + - mincut_cut_loss_map (str): The mapping for saving the mincut cut loss. Defaults to "L_cut". + - mincut_ortho_loss_map (str): The mapping for saving the mincut orthogonallity loss. Defaults to "L_ortho". + + Constraints + ----------- + - input: Dict[str, Any] or torch-geometric Data object containing the following attributes: + - x: torch.Tensor of shape (num_nodes, node_features) + - edge_index: torch.Tensor of shape (2, num_edges) + - batch: torch.Tensor of shape (num_nodes) + - edge_attr: torch.Tensor of shape (num_edges, edge_features) + + Example ---------- >>> block = dl.GraphEncoderBlock(in_features=16, out_features=16, num_clusters=5, threshold=0.1).build() >>> inp = {} @@ -452,28 +465,25 @@ class GraphDecoderBlock(DeeplayModule): This block is a fundamental component of the GraphDecoder, enabling the reconstruction of graph features in a Graph Encoder Decoder model. - Parameters ---------- in_features: int The number of input features for each node in the graph. out_features: int The number of output features for each node after processing. - upsample: Optional[template-like] - The upsampling operation to be used. Defaults to MinCutUpsampling. - edge_index_map: Optional[str] - The mapping for edge index inputs. Defaults to "edge_index". - select_input_map: Optional[str] - The mapping for selection inputs for the upsampling layer. Defaults to "s". - connect_input_map: Optional[str] - The mapping for the connectivity for the upsampling layer. Defaults to "edge_index_pool". - connect_output_map: Optional[str] - The mapping for the connectivity outputs of the upsampling layer. Defaults to "-". - batch_map: Optional[str] - The mapping for batch inputs or outputs. Defaults to "batch". - + + Configurables + ------------- + - in_features (int): The number of input features for each node in the graph. + - out_features (int): The number of output features for each node after processing. + - upsample (template-like): The upsampling operation to be used. Defaults to MinCutUpsampling. + - edge_index_map (str): The mapping for edge index inputs. Defaults to "edge_index". + - select_input_map (str): The mapping for selection inputs for the upsampling layer. Defaults to "s". + - connect_input_map (str): The mapping for the connectivity for the upsampling layer. Defaults to "edge_index_pool". + - connect_output_map (str): The mapping for the connectivity outputs of the upsampling layer. Defaults to "-". + - batch_map (str): The mapping for batch inputs or outputs. Defaults to "batch". - Evaluation + Example ---------- >>> encoderblock = dl.GraphEncoderBlock(in_features=16, out_features=16, num_clusters=5, threshold=0.2).build() >>> inp = {} @@ -484,6 +494,7 @@ class GraphDecoderBlock(DeeplayModule): >>> encoderblock_output = encoderblock(inp) >>> decoderblock = dl.GraphDecoderBlock(in_features=16, out_features=16).build() >>> decoderblock_output = decoderblock(encoderblock_output) + """ in_features: int out_features: int diff --git a/deeplay/components/gnn/pooling/__init__.py b/deeplay/components/gnn/pooling/__init__.py index e13828e0..96b025f7 100644 --- a/deeplay/components/gnn/pooling/__init__.py +++ b/deeplay/components/gnn/pooling/__init__.py @@ -1,2 +1,2 @@ -from .mincut import * -from .globalpool import * \ No newline at end of file +from .mincut import MinCutPooling, MinCutUpsampling +from .globalpool import GlobalGraphPooling, GlobalGraphUpsampling \ No newline at end of file diff --git a/deeplay/components/gnn/pooling/globalpool.py b/deeplay/components/gnn/pooling/globalpool.py index 1f79093d..55f06546 100644 --- a/deeplay/components/gnn/pooling/globalpool.py +++ b/deeplay/components/gnn/pooling/globalpool.py @@ -11,14 +11,23 @@ class GlobalGraphPooling(DeeplayModule): (Inspired by MinCut-pooling ('Spectral Clustering with Graph Neural Networks for Graph Pooling'): but with the assignment matrix S being deterministic (all nodes are pooled into one cluster)) - Input - ----- - X: float (Any, Any) #(number of nodes, number of features) - - Output - ------ - X: float (1, Any) #(number of clusters, number of features) - S: float (Any, 1) #(number of nodes, number of clusters) + Constraints + ----------- + - input: Dict[str, Any] or torch-geometric Data object containing the following attributes: + - x: torch.Tensor of shape (num_nodes, node_features) + + - output: Dict[str, Any] or torch-geometric Data object containing the following attributes: + - x: torch.Tensor of shape (num_clusters, node_features) + - s: torch.Tensor of shape (num_nodes, num_clusters) + + Examples + -------- + >>> global_pool = GlobalGraphPooling().build() + >>> inp = {} + >>> inp["x"] = torch.randn(3, 2) + >>> inp["batch"] = torch.zeros(3, dtype=int) + >>> inp["edge_index"] = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + >>> out = global_pool(inp) """ def __init__( @@ -73,6 +82,25 @@ class GlobalGraphUpsampling(DeeplayModule): """ Reverse of GlobalGraphPooling. Only upsampling the node features. + + Constraints + ----------- + - input: Dict[str, Any] or torch-geometric Data object containing the following attributes: + - x: torch.Tensor of shape (num_clusters, node_features) + - s: torch.Tensor of shape (num_nodes, num_clusters) + + - output: Dict[str, Any] or torch-geometric Data object containing the following attributes: + - x: torch.Tensor of shape (num_nodes, node_features) + + Examples + -------- + >>> global_upsampling = GlobalGraphUpsampling() + >>> global_upsampling = global_upsampling.build() + + >>> inp = {} + >>> inp["x"] = torch.randn(1, 2) + >>> inp["s"] = torch.ones((3, 1)) + >>> out = global_upsampling(inp) """ def __init__( diff --git a/deeplay/components/gnn/pooling/mincut.py b/deeplay/components/gnn/pooling/mincut.py index 9766e1ca..f2ead2d1 100644 --- a/deeplay/components/gnn/pooling/mincut.py +++ b/deeplay/components/gnn/pooling/mincut.py @@ -17,12 +17,22 @@ class MinCutPooling(DeeplayModule): The number of clusters to which each graph is pooled. hidden_features: Sequence[int] The number of hidden features for the Multi-Layer Perceptron (MLP) used for selecting clusters for the pooling. - reduce_self_connection: Optional[bool] - Whether to reduce self-connections in the adjacency matrix. Defaults to True. - threshold: Optional[float] - A threshold value to apply to the adjacency matrix to binarize the edges. If None, no threshold is applied. - - Evaluation + + Configurables + ------------- + - num_clusters (int): The number of clusters to which each graph is pooled. + - hidden_features (list[int]): The number of hidden features for the Multi-Layer Perceptron (MLP) used for selecting clusters for the pooling. + - reduce_self_connection (bool): Whether to reduce self-connections in the adjacency matrix. Defaults to True. + - threshold (float): A threshold value to apply to the adjacency matrix to binarize the edges. If None, no threshold is applied. Default is None. + + Constraints + ----------- + - input: Dict[str, Any] or torch-geometric Data object containing the following attributes: + - x: torch.Tensor of shape (num_nodes, node_features) + - edge_index: torch.Tensor of shape (2, num_edges) + - batch: torch.Tensor of shape (num_nodes) + + Example ---------- >>> MinCut = dl.components.gnn.pooling.MinCutPooling(hidden_features = [8], num_clusters = 5, reduce_self_connection = True, threshold = 0.25).build() >>> inp = {} @@ -250,6 +260,25 @@ def forward(self, x): class MinCutUpsampling(DeeplayModule): """ Reverse of MinCutPooling as described in 'Spectral Clustering with Graph Neural Networks for Graph Pooling'. + + Constraints + ----------- + - input: Dict[str, Any] or torch-geometric Data object containing the following attributes: + - x: torch.Tensor of shape (num_clusters, node_features). + - edge_index_pool: torch.Tensor of shape (2, num_edges). + - batch: torch.Tensor of shape (num_clusters). + - s: torch.Tensor of shape (num_nodes, num_clusters) + + Example + ---------- + >>> mincut_upsample = MinCutUpsampling().build() + >>> inp = {} + >>> inp["x"] = torch.randn(2, 1) + >>> inp["batch"] = torch.zeros(2, dtype=int) + >>> inp['s'] = torch.tensor([[1.0, 0], [0, 1.0], [1.0, 0]]) + >>> inp["edge_index_pool"] = torch.tensor([[0, 1], [1, 0]]) + >>> out = mincut_upsample(inp) + """ def __init__( @@ -262,7 +291,7 @@ def forward(self, x_pool, a_pool, s): x = torch.matmul(s, x_pool) if a_pool.is_sparse: - a = torch.spmm(a_pool, s.T) + a = torch.spmm(s, torch.spmm(a_pool, s.T)) elif (not a_pool.is_sparse) & (a_pool.size(0) == 2): a_pool = torch.sparse_coo_tensor( a_pool, @@ -270,15 +299,15 @@ def forward(self, x_pool, a_pool, s): ((s.T).size(0),) * 2, device=a_pool.device, ) - a = torch.spmm(a_pool, s.T) + a = torch.spmm(s, torch.spmm(a_pool, s.T)) elif (not a_pool.is_sparse) & len({a_pool.size(0), a_pool.size(1), (s.T).size(0)}) == 1: - a = a_pool.type(s.dtype) @ s.T + a = s @ a_pool.type(s.dtype) @ s.T return x, a self.upsample = Upsample() self.upsample.set_input_map('x', 'edge_index_pool', 's') - self.upsample.set_output_map('x', 'edge_index_') + self.upsample.set_output_map('x', 'edge_index') def forward(self, x): x = self.upsample(x) diff --git a/deeplay/tests/test_gnn.py b/deeplay/tests/test_gnn.py index 91862992..f7ef1eb7 100644 --- a/deeplay/tests/test_gnn.py +++ b/deeplay/tests/test_gnn.py @@ -22,6 +22,14 @@ Max, Layer, GlobalMeanPool, + GlobalGraphPooling, + GlobalGraphUpsampling, + MinCutPooling, + MinCutUpsampling, + GraphEncoder, + GraphDecoder, + GraphEncoderBlock, + GraphDecoderBlock, ) import itertools @@ -729,3 +737,119 @@ def test_gtoempm_defaults(self): out = model(inp) self.assertEqual(out.shape, (20, 1)) + + +class TestComponentPool(unittest.TestCase): + def test_global_pool(self): + global_pool = GlobalGraphPooling() + global_pool = global_pool.build() + + inp = {} + inp["x"] = torch.randn(3, 2) + inp["batch"] = torch.zeros(3, dtype=int) + inp["edge_index"] = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + out = global_pool(inp) + self.assertEqual(out["x"].shape, (1, 2)) + + def test_global_upsampling(self): + global_upsampling = GlobalGraphUpsampling() + global_upsampling = global_upsampling.build() + + inp = {} + inp["x"] = torch.randn(1, 2) + inp["s"] = torch.ones((3, 1)) + out = global_upsampling(inp) + self.assertEqual(out["x"].shape, (3, 2)) + + def test_mincut_pool(self): + mincut = MinCutPooling(num_clusters = 2, hidden_features = [5]) + mincut = mincut.build() + + inp = {} + inp["x"] = torch.randn(3, 1) + inp["batch"] = torch.zeros(3, dtype=int) + inp["edge_index"] = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + out = mincut(inp) + + self.assertEqual(out["x"].shape, (2, 1)) + self.assertEqual(out['edge_index'].shape, (2,2)) + self.assertEqual(out["s"].shape, (3, 2)) + self.assertTrue((torch.sum(out['s'], axis=1) - torch.tensor([1., 1., 1.])).sum() < 1e-5) + + def test_mincut_upsample(self): + mincut_upsample = MinCutUpsampling() + mincut_upsample = mincut_upsample.build() + + inp = {} + inp["x"] = torch.randn(2, 1) + inp["batch"] = torch.zeros(2, dtype=int) + inp['s'] = torch.tensor([[1.0, 0], [0, 1.0], [1.0, 0]]) + inp["edge_index_pool"] = torch.tensor([[0, 1], [1, 0]]) + out = mincut_upsample(inp) + + self.assertEqual(out["x"].shape, (3, 1)) + + +class TestComponentsGraphEncoderDecoder(unittest.TestCase): + def test_graph_encoder_block(self): + encoder_block = GraphEncoderBlock(in_features=1, out_features=4, num_clusters=2) + encoder_block = encoder_block.build() + + inp = {} + inp["x"] = torch.randn(3, 1) + inp["batch"] = torch.zeros(3, dtype=int) + inp["edge_index"] = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + out = encoder_block(inp) + + self.assertEqual(out["x"].shape, (2, 4)) + self.assertEqual(out["s"].shape, (3, 2)) + self.assertEqual(out["edge_index_pool"].shape, (2, 2)) + + def test_graph_decoder_block(self): + decoder_block = GraphDecoderBlock(in_features=1, out_features=4) + decoder_block = decoder_block.build() + + inp = {} + inp["x"] = torch.randn(2, 1) + inp["batch"] = torch.zeros(2, dtype=int) + inp["edge_index_pool"] = torch.tensor([[0, 1], [1, 0]]) + inp["edge_index"] = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + inp['s'] = torch.tensor([[1.0, 0], [0, 1.0], [1.0, 0]]) + out = decoder_block(inp) + + self.assertEqual(out["x"].shape, (3, 4)) + self.assertTrue(torch.all(inp["edge_index"] == out["edge_index"])) + + def test_graph_encoder(self): + graph_encoder = GraphEncoder(hidden_features=2, num_blocks=3, num_clusters=[3,2,1]) + graph_encoder = graph_encoder.build() + + self.assertEqual(len(graph_encoder.blocks), 3) + + inp = {} + inp["x"] = torch.randn(4, 2) + inp["batch"] = torch.zeros(4, dtype=int) + inp["edge_index"] = torch.tensor([[0, 1, 1, 2, 1, 3], [1, 0, 2, 1, 3, 1]]) + inp["edge_attr"] = torch.randn(6, 1) + out = graph_encoder(inp) + + self.assertEqual(out["x"].shape, (1, 2)) + self.assertEqual(out["s_1"].shape, (3, 2)) + + def test_graph_decoder(self): + graph_decoder = GraphDecoder(hidden_features=2, num_blocks=2, output_node_dim=2, output_edge_dim=1) + graph_decoder = graph_decoder.build() + + self.assertEqual(len(graph_decoder.blocks), 2) + + inp = {} + inp["x"] = torch.randn(1, 2) + inp["batch"] = torch.zeros(1, dtype=int) + inp["edge_index_1"] = torch.tensor([[0, 1], [1, 0]]) + inp["edge_index"] = torch.tensor([[0, 1, 1, 2, 1, 3], [1, 0, 2, 1, 3, 1]]) + inp['s_1'] = torch.ones((2,1)) + inp['s_0'] = torch.tensor([[1.0, 0], [0, 1.0], [0, 1.0], [1.0, 0]]) + out = graph_decoder(inp) + + self.assertEqual(out["x"].shape, (4, 2)) + self.assertEqual(out["edge_attr"].shape, (6, 1)) diff --git a/deeplay/tests/test_testing.ipynb b/deeplay/tests/test_testing.ipynb new file mode 100644 index 00000000..42156bbe --- /dev/null +++ b/deeplay/tests/test_testing.ipynb @@ -0,0 +1,1043 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "c:\\Users/xgrmir/Documents/Deeplay/deeplay\\deeplay\\__init__.py\n" + ] + } + ], + "source": [ + "import sys\n", + "\n", + "sys.path.insert(0, '/Users/xgrmir/Documents/Deeplay/deeplay')\n", + "\n", + "import deeplay as dl\n", + "print(dl.__file__)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from deeplay import (\n", + " GraphConvolutionalNeuralNetwork,\n", + " GraphToGlobalMPM,\n", + " GraphToNodeMPM,\n", + " GraphToEdgeMPM,\n", + " GraphToEdgeMAGIK,\n", + " MessagePassingNeuralNetwork,\n", + " ResidualMessagePassingNeuralNetwork,\n", + " MultiLayerPerceptron,\n", + " dense_laplacian_normalization,\n", + " Sum,\n", + " WeightedSum,\n", + " Mean,\n", + " Prod,\n", + " Min,\n", + " Max,\n", + " Layer,\n", + " GlobalMeanPool,\n", + " GlobalGraphPooling,\n", + " GlobalGraphUpsampling,\n", + " MinCutPooling,\n", + " MinCutUpsampling,\n", + " GraphEncoder,\n", + " GraphDecoder,\n", + " GraphEncoderBlock,\n", + " GraphDecoderBlock,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Global pool" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "global_pool = GlobalGraphPooling()\n", + "global_pool = global_pool.build()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "inp = {}\n", + "inp[\"x\"] = torch.randn(3, 2)\n", + "inp[\"batch\"] = torch.zeros(3, dtype=int)\n", + "inp[\"edge_index\"] = torch.tensor([[0, 1, 1, 2, 1], [1, 0, 2, 1, 0]])\n", + "out = global_pool(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': tensor([[ 1.6337, -1.7212]]),\n", + " 'batch': tensor([0]),\n", + " 'edge_index': tensor([[0, 1, 1, 2, 1],\n", + " [1, 0, 2, 1, 0]]),\n", + " 's': tensor([[1.],\n", + " [1.],\n", + " [1.]])}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 2])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['x'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 1])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['s'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "global_upsampling = GlobalGraphUpsampling()\n", + "global_upsampling = global_upsampling.build()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "inp = {}\n", + "inp[\"x\"] = torch.randn(1, 2)\n", + "inp[\"s\"] = torch.ones((3,1))\n", + "out = global_upsampling(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 2])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['x'].shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Mincut" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "mincut = MinCutPooling(num_clusters = 2, hidden_features = [5])\n", + "mincut = mincut.build()" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "MinCutPooling(\n", + " (select): MultiLayerPerceptron(\n", + " (blocks): LayerList(\n", + " (0): LinearBlock(\n", + " (layer): LazyLinear(in_features=0, out_features=5, bias=True)\n", + " (activation): ReLU()\n", + " )\n", + " (1): LinearBlock(\n", + " (layer): Linear(in_features=5, out_features=2, bias=True)\n", + " (activation): Softmax(dim=1)\n", + " )\n", + " )\n", + " )\n", + " (batch_compatible): ClusterMatrixForBatch()\n", + " (mincut_loss): MinCutLoss()\n", + " (reduce): Reduce()\n", + " (connect): Connect()\n", + " (red_self_con): ReduceSelfConnection()\n", + " (sparse): SparseEdgeIndex()\n", + ")" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mincut" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "inp = {}\n", + "inp[\"x\"] = torch.randn(3, 1)\n", + "inp[\"batch\"] = torch.zeros(3, dtype=int)\n", + "inp[\"edge_index\"] = torch.tensor([[0, 1, 1, 2, 1], [1, 0, 2, 1, 0]])\n", + "out = mincut(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': tensor([[-1.4034],\n", + " [-1.7131]], grad_fn=),\n", + " 'batch': tensor([0, 0]),\n", + " 'edge_index': tensor(indices=tensor([[0, 1],\n", + " [1, 0]]),\n", + " values=tensor([0.1269, 0.1285]),\n", + " size=(2, 2), nnz=2, layout=torch.sparse_coo, grad_fn=),\n", + " 's': tensor([[0.4595, 0.5405],\n", + " [0.4439, 0.5561],\n", + " [0.4480, 0.5520]], grad_fn=),\n", + " 'L_cut': tensor(-1.0003, grad_fn=),\n", + " 'L_ortho': tensor(0.7652, grad_fn=)}" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 2])" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['s'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([1., 1., 1.], grad_fn=)" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.sum(out['s'], axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(indices=tensor([[0, 1],\n", + " [1, 0]]),\n", + " values=tensor([0.1209, 0.1388]),\n", + " size=(2, 2), nnz=2, layout=torch.sparse_coo, grad_fn=)" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['edge_index']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Add test on different configurations with the layers in MinCut?" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "mincut_upsample = MinCutUpsampling()\n", + "mincut_upsample = mincut_upsample.build()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "inp = {}\n", + "inp[\"x\"] = torch.randn(2, 1)\n", + "inp[\"batch\"] = torch.zeros(2, dtype=int)\n", + "inp['s'] = torch.tensor([\n", + " [1.0, 0],\n", + " [0, 1.0],\n", + " [1.0, 0]\n", + "])\n", + "inp[\"edge_index_pool\"] = torch.tensor([[0, 1], [1, 0]])\n", + "out = mincut_upsample(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': tensor([[1.3595],\n", + " [0.2606],\n", + " [1.3595]]),\n", + " 'batch': tensor([0, 0]),\n", + " 's': tensor([[1., 0.],\n", + " [0., 1.],\n", + " [1., 0.]]),\n", + " 'edge_index_pool': tensor([[0, 1],\n", + " [1, 0]]),\n", + " 'edge_index': tensor([[0., 1., 0.],\n", + " [1., 0., 1.],\n", + " [0., 1., 0.]])}" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0., 1., 0.],\n", + " [1., 0., 1.],\n", + " [0., 1., 0.]])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['edge_index']" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "a_pool = torch.tensor([[0, 1], [1, 0]])" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "s = torch.tensor([\n", + " [1, 0],\n", + " [0, 1],\n", + " [0.8, 0.2]\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "a_pool = torch.sparse_coo_tensor(\n", + " a_pool,\n", + " torch.ones(a_pool.size(1)),\n", + " ((s.T).size(0),) * 2,\n", + " device=a_pool.device,\n", + " )\n", + "a = torch.spmm(a_pool, s.T)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(indices=tensor([[0, 1],\n", + " [1, 0]]),\n", + " values=tensor([1., 1.]),\n", + " size=(2, 2), nnz=2, layout=torch.sparse_coo)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a_pool" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.0000, 1.0000, 0.2000],\n", + " [1.0000, 0.0000, 0.8000]])" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1.0000, 0.0000, 0.8000],\n", + " [0.0000, 1.0000, 0.2000]])" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "s.T" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "mincut = MinCutPooling(num_clusters=2, hidden_features=[5], reduce_self_connection=False)\n", + "mincut = mincut.build()" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "inp = {}\n", + "inp[\"x\"] = torch.randn(3, 1)\n", + "inp[\"batch\"] = torch.zeros(3, dtype=int)\n", + "inp[\"edge_index\"] = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n", + "out = mincut(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': tensor([[0.9204],\n", + " [1.3934]], grad_fn=),\n", + " 'batch': tensor([0, 0]),\n", + " 'edge_index': tensor(indices=tensor([[0, 0, 1, 1],\n", + " [0, 1, 0, 1]]),\n", + " values=tensor([0.6270, 0.9567, 0.9567, 1.4596]),\n", + " size=(2, 2), nnz=4, layout=torch.sparse_coo, grad_fn=),\n", + " 's': tensor([[0.3918, 0.6082],\n", + " [0.3950, 0.6050],\n", + " [0.4019, 0.5981]], grad_fn=),\n", + " 'L_cut': tensor(-0.9999, grad_fn=),\n", + " 'L_ortho': tensor(0.7653, grad_fn=)}" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Graph Encoder and decoder" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### blocks" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "encoder_block = GraphEncoderBlock(in_features=1, out_features=4, num_clusters=2)\n", + "encoder_block = encoder_block.build()" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "GraphEncoderBlock(\n", + " (gcn): GraphConvolutionalNeuralNetwork(\n", + " (normalize): sparse_laplacian_normalization()\n", + " (blocks): LayerList(\n", + " (0): TransformPropagateUpdate(\n", + " (transform): Linear(in_features=1, out_features=4, bias=True)\n", + " (propagate): Propagate()\n", + " (update): ReLU()\n", + " )\n", + " )\n", + " )\n", + " (pool): MinCutPooling(\n", + " (select): MultiLayerPerceptron(\n", + " (blocks): LayerList(\n", + " (0): LinearBlock(\n", + " (layer): LazyLinear(in_features=0, out_features=4, bias=True)\n", + " (activation): ReLU()\n", + " )\n", + " (1): LinearBlock(\n", + " (layer): Linear(in_features=4, out_features=2, bias=True)\n", + " (activation): Softmax(dim=1)\n", + " )\n", + " )\n", + " )\n", + " (batch_compatible): ClusterMatrixForBatch()\n", + " (mincut_loss): MinCutLoss()\n", + " (reduce): Reduce()\n", + " (connect): Connect()\n", + " (red_self_con): ReduceSelfConnection()\n", + " (sparse): SparseEdgeIndex()\n", + " )\n", + ")" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "encoder_block" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "inp = {}\n", + "inp[\"x\"] = torch.randn(3, 1)\n", + "inp[\"batch\"] = torch.zeros(3, dtype=int)\n", + "inp[\"edge_index\"] = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n", + "out = encoder_block(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': tensor([[0.4139, 0.0000, 1.0564, 1.0303],\n", + " [0.1099, 0.0000, 0.2809, 0.2738]], grad_fn=),\n", + " 'batch': tensor([0, 0]),\n", + " 'edge_index': tensor([[0, 1, 1, 2],\n", + " [1, 0, 2, 1]]),\n", + " 'laplacian': tensor(indices=tensor([[0, 0, 1, 1, 1, 2, 2],\n", + " [0, 1, 0, 1, 2, 1, 2]]),\n", + " values=tensor([0.5000, 0.4082, 0.4082, 0.3333, 0.4082, 0.4082, 0.5000]),\n", + " size=(3, 3), nnz=7, layout=torch.sparse_coo),\n", + " 's': tensor([[0.7871, 0.2129],\n", + " [0.7905, 0.2095],\n", + " [0.7918, 0.2082]], grad_fn=),\n", + " 'L_cut': tensor(-1.0000, grad_fn=),\n", + " 'L_ortho': tensor(0.7654, grad_fn=),\n", + " 'edge_index_pool': tensor(indices=tensor([[0, 1],\n", + " [1, 0]]),\n", + " values=tensor([0.0838, 0.0838]),\n", + " size=(2, 2), nnz=2, layout=torch.sparse_coo, grad_fn=)}" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 2])" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['s'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "decoder_block = GraphDecoderBlock(in_features=1, out_features=4)\n", + "decoder_block = decoder_block.build()\n", + "\n", + "inp = {}\n", + "inp[\"x\"] = torch.randn(2, 1)\n", + "inp[\"batch\"] = torch.zeros(2, dtype=int)\n", + "inp[\"edge_index_pool\"] = torch.tensor([[0, 1], [1, 0]])\n", + "inp[\"edge_index\"] = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n", + "inp['s'] = torch.tensor([\n", + " [1.0, 0],\n", + " [0, 1.0],\n", + " [1.0, 0]\n", + "])\n", + "out = decoder_block(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': tensor([[0.0000, 0.0000, 0.4061, 0.0000],\n", + " [0.0000, 0.0000, 0.5079, 0.0000],\n", + " [0.0000, 0.0000, 0.4061, 0.0000]], grad_fn=),\n", + " 'batch': tensor([0, 0]),\n", + " 'edge_index_pool': tensor([[0, 1],\n", + " [1, 0]]),\n", + " 'edge_index': tensor([[0, 1, 1, 2],\n", + " [1, 0, 2, 1]]),\n", + " 's': tensor([[1., 0.],\n", + " [0., 1.],\n", + " [1., 0.]]),\n", + " '-': tensor([[0., 1., 0.],\n", + " [1., 0., 1.],\n", + " [0., 1., 0.]]),\n", + " 'laplacian': tensor(indices=tensor([[0, 0, 1, 1, 1, 2, 2],\n", + " [0, 1, 0, 1, 2, 1, 2]]),\n", + " values=tensor([0.5000, 0.4082, 0.4082, 0.3333, 0.4082, 0.4082, 0.5000]),\n", + " size=(3, 3), nnz=7, layout=torch.sparse_coo)}" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 4])" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['x'].shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Encoder and decoder" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "graph_encoder = GraphEncoder(hidden_features=2, num_blocks=2, num_clusters=[2,1])\n", + "graph_encoder = graph_encoder.build()" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "inp = {}\n", + "inp[\"x\"] = torch.randn(4, 2)\n", + "inp[\"batch\"] = torch.zeros(4, dtype=int)\n", + "inp[\"edge_index\"] = torch.tensor([[0, 1, 1, 2, 1, 3], [1, 0, 2, 1, 3, 1]])\n", + "inp[\"edge_attr\"] = torch.randn(6, 1)\n", + "out = graph_encoder(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(graph_encoder.blocks)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': tensor([[0.8827, 0.0000]], grad_fn=),\n", + " 'batch': tensor([0, 0, 0, 0]),\n", + " 'edge_index': tensor([[0, 1, 1, 2, 1, 3],\n", + " [1, 0, 2, 1, 3, 1]]),\n", + " 'edge_attr': tensor([[0.0046, 0.0058],\n", + " [0.5006, 0.0829],\n", + " [0.3284, 0.0937],\n", + " [0.1720, 0.0512],\n", + " [0.3950, 0.0706],\n", + " [0.5021, 0.0839]], grad_fn=),\n", + " 'aggregate': tensor([[0.5006, 0.0829],\n", + " [0.6788, 0.1409],\n", + " [0.3284, 0.0937],\n", + " [0.3950, 0.0706]], grad_fn=),\n", + " 'laplacian': tensor(indices=tensor([[0, 0, 1, 1],\n", + " [0, 1, 0, 1]]),\n", + " values=tensor([0.5000, 0.5000, 0.5000, 0.5000]),\n", + " size=(2, 2), nnz=4, layout=torch.sparse_coo),\n", + " 's_0': tensor([[0.4963, 0.5037],\n", + " [0.4963, 0.5037],\n", + " [0.4963, 0.5037],\n", + " [0.4963, 0.5037]], grad_fn=),\n", + " 'batch_1': tensor([0, 0]),\n", + " 'L_cut_0': tensor(-1., grad_fn=),\n", + " 'L_ortho_0': tensor(0.7654, grad_fn=),\n", + " 'edge_index_1': tensor(indices=tensor([[0, 1],\n", + " [1, 0]]),\n", + " values=tensor([0.1442, 0.1442]),\n", + " size=(2, 2), nnz=2, layout=torch.sparse_coo, grad_fn=),\n", + " 's_1': tensor([[1.],\n", + " [1.]]),\n", + " 'batch_2': tensor([0])}" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 2])" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['edge_index_1'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2\n" + ] + } + ], + "source": [ + "graph_decoder = GraphDecoder(hidden_features=2, num_blocks=2, output_node_dim=2, output_edge_dim=1)\n", + "graph_decoder = graph_decoder.build()\n", + "\n", + "print(len(graph_decoder.blocks))\n", + "\n", + "inp = {}\n", + "inp[\"x\"] = torch.randn(1, 2)\n", + "inp[\"batch\"] = torch.zeros(1, dtype=int)\n", + "inp[\"edge_index_1\"] = torch.tensor([[0, 1], [1, 0]])\n", + "inp[\"edge_index\"] = torch.tensor([[0, 1, 1, 2, 1, 3], [1, 0, 2, 1, 3, 1]])\n", + "inp['s_1'] = torch.ones((2,1))\n", + "inp['s_0'] = torch.tensor([[1.0, 0], [0, 1.0], [0, 1.0], [1.0, 0]])\n", + "out = graph_decoder(inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.5838],\n", + " [0.5838],\n", + " [0.5838],\n", + " [0.5838],\n", + " [0.5838],\n", + " [0.5838]], grad_fn=)" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out['edge_attr']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From f989b22344f234656f525d87786820750746ea83 Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Mon, 16 Dec 2024 15:56:59 +0100 Subject: [PATCH 8/8] working version, but not nicely implemented --- deeplay/components/gnn/gcn/__init__.py | 1 + deeplay/components/gnn/gcn/gcn_concat.py | 184 +++++++++++++++++++ deeplay/components/gnn/graphencdec.py | 125 +++++++++---- deeplay/components/gnn/mpn/__init__.py | 1 + deeplay/components/gnn/mpn/mpn_gaudi.py | 150 +++++++++++++++ deeplay/components/gnn/mpn/transformation.py | 13 ++ deeplay/components/gnn/pooling/mincut.py | 7 +- deeplay/ops/__init__.py | 1 + deeplay/ops/get_edge_features.py | 14 ++ 9 files changed, 457 insertions(+), 39 deletions(-) create mode 100644 deeplay/components/gnn/gcn/gcn_concat.py create mode 100644 deeplay/components/gnn/mpn/mpn_gaudi.py create mode 100644 deeplay/ops/get_edge_features.py diff --git a/deeplay/components/gnn/gcn/__init__.py b/deeplay/components/gnn/gcn/__init__.py index 1271fefc..0d62ac12 100644 --- a/deeplay/components/gnn/gcn/__init__.py +++ b/deeplay/components/gnn/gcn/__init__.py @@ -1,2 +1,3 @@ from .gcn import GraphConvolutionalNeuralNetwork from .normalization import * +from .gcn_concat import GraphConvolutionalNeuralNetworkConcat \ No newline at end of file diff --git a/deeplay/components/gnn/gcn/gcn_concat.py b/deeplay/components/gnn/gcn/gcn_concat.py new file mode 100644 index 00000000..2b983213 --- /dev/null +++ b/deeplay/components/gnn/gcn/gcn_concat.py @@ -0,0 +1,184 @@ +from typing import List, Optional, Literal, Any, Sequence, Type, overload, Union + +from deeplay import DeeplayModule, Layer, LayerList + +from ..tpu import TransformPropagateUpdate +from deeplay.ops import Cat + +import torch +import torch.nn as nn + + +class GraphConvolutionalNeuralNetworkConcat(DeeplayModule): + in_features: Optional[int] + hidden_features: Sequence[Optional[int]] + out_features: int + blocks: LayerList[TransformPropagateUpdate] + + @property + def input(self): + """Return the input layer of the network. Equivalent to `.blocks[0]`.""" + return self.blocks[0] + + @property + def hidden(self): + """Return the hidden layers of the network. Equivalent to `.blocks[:-1]`""" + return self.blocks[:-1] + + @property + def output(self): + """Return the last layer of the network. Equivalent to `.blocks[-1]`.""" + return self.blocks[-1] + + @property + def transform(self) -> LayerList[Layer]: + """Return the layers of the network. Equivalent to `.blocks.layer`.""" + return self.blocks.transform + + @property + def propagate(self) -> LayerList[Layer]: + """Return the activations of the network. Equivalent to `.blocks.activation`.""" + return self.blocks.propagate + + @property + def update(self) -> LayerList[Layer]: + """Return the normalizations of the network. Equivalent to `.blocks.normalization`.""" + return self.blocks.update + + def __init__( + self, + in_features: int, + hidden_features: Sequence[int], + out_features: int, + out_activation: Union[Type[nn.Module], nn.Module, None] = None, + ): + super().__init__() + + self.in_features = in_features + self.hidden_features = hidden_features + self.out_features = out_features + + if in_features is None: + raise ValueError("in_features must be specified") + + if out_features is None: + raise ValueError("out_features must be specified") + + if in_features <= 0: + raise ValueError(f"in_features must be positive, got {in_features}") + + if out_features <= 0: + raise ValueError(f"out_features must be positive, got {out_features}") + + if any(h <= 0 for h in hidden_features): + raise ValueError( + f"all hidden_features must be positive, got {hidden_features}" + ) + + if out_activation is None: + out_activation = Layer(nn.Identity) + elif isinstance(out_activation, type) and issubclass(out_activation, nn.Module): + out_activation = Layer(out_activation) + + class Propagate(DeeplayModule): + def forward(self, x, A): + if A.is_sparse: + return torch.spmm(A, x) + elif (not A.is_sparse) & (A.size(0) == 2): + A = torch.sparse_coo_tensor( + A, + torch.ones(A.size(1)), + (x.size(0),) * 2, + device=A.device, + ) + return torch.spmm(A, x) + elif (not A.is_sparse) & len({A.size(0), A.size(1), x.size(0)}) == 1: + return A.type(x.dtype) @ x + else: + raise ValueError( + "Unsupported adjacency matrix format.", + "Ensure it is a pytorch sparse tensor, an edge index tensor, or a square dense tensor.", + "Consider updating the propagate layer to handle alternative formats.", + ) + + self.blocks = LayerList() + + for i, (c_in, c_out) in enumerate( + zip([in_features, *hidden_features], [*hidden_features, out_features]) + ): + transform = Layer(nn.Linear, c_in, c_out) + transform.set_input_map("x") + transform.set_output_map("x_prime") + + propagate = Layer(Propagate) + propagate.set_input_map("x_prime", "edge_index") + propagate.set_output_map("x_prime") + + update = Layer(nn.ReLU) if i < len(self.hidden_features) else out_activation + update.set_input_map("x_prime") + update.set_output_map("x_prime") + + block = TransformPropagateUpdate( + transform=transform, + propagate=propagate, + update=update, + order=["transform", "update", "propagate"] + ) + self.blocks.append(block) + + self.concat = Cat() + self.concat.set_input_map('x_prime', 'x') + self.concat.set_output_map('x') + + self.dense = Layer(nn.Linear, out_features*2, out_features) + self.dense.set_input_map('x') + self.dense.set_output_map('x') + + self.activate = Layer(nn.ReLU) + self.activate.set_input_map('x') + self.activate.set_output_map('x') + + def forward(self, x): + for block in self.blocks: + x = block(x) + + x = self.concat(x) + x = self.dense(x) + x = self.activate(x) + + return x + + @overload + def configure( + self, + /, + in_features: Optional[int] = None, + hidden_features: Optional[List[int]] = None, + out_features: Optional[int] = None, + out_activation: Union[Type[nn.Module], nn.Module, None] = None, + ) -> None: ... + + @overload + def configure( + self, + name: Literal["blocks"], + order: Optional[Sequence[str]] = None, + transform: Optional[Type[nn.Module]] = None, + propagate: Optional[Type[nn.Module]] = None, + update: Optional[Type[nn.Module]] = None, + **kwargs: Any, + ) -> None: ... + + @overload + def configure( + self, + name: Literal["blocks"], + index: Union[int, slice, List[Union[int, slice]]], + order: Optional[Sequence[str]] = None, + transform: Optional[Type[nn.Module]] = None, + propagate: Optional[Type[nn.Module]] = None, + update: Optional[Type[nn.Module]] = None, + **kwargs: Any, + ) -> None: ... + + configure = DeeplayModule.configure diff --git a/deeplay/components/gnn/graphencdec.py b/deeplay/components/gnn/graphencdec.py index 06f1fa16..e3c87a34 100644 --- a/deeplay/components/gnn/graphencdec.py +++ b/deeplay/components/gnn/graphencdec.py @@ -7,12 +7,14 @@ Layer, LayerList, ) -from deeplay.components.gnn import GraphConvolutionalNeuralNetwork, MessagePassingNeuralNetwork +from deeplay.components.gnn import MessagePassingNeuralNetworkGAUDI, GraphConvolutionalNeuralNetworkConcat from deeplay.components.gnn.pooling import MinCutPooling, MinCutUpsampling from deeplay.ops import Cat from deeplay.components.gnn.pooling import GlobalGraphPooling, GlobalGraphUpsampling -from deeplay.components.gnn.mpn import GetEdgeFeatures +# from deeplay.components.gnn.mpn import TransformOnlySenderNodes from deeplay.components.mlp import MultiLayerPerceptron +from deeplay.ops import GetEdgeFeaturesNew +# from deeplay.components.gnn.mpn.propagation import Mean import torch.nn as nn @@ -104,12 +106,22 @@ def __init__( assert len(num_clusters) == num_blocks, "Lenght of number of clusters should match num_blocks." - self.message_passing = MessagePassingNeuralNetwork( + self.message_passing = MessagePassingNeuralNetworkGAUDI( hidden_features=[hidden_features], out_features=hidden_features, out_activation=nn.ReLU ) + # self.message_passing.transform = TransformOnlySenderNodes( + # combine=Cat(), + # layer=Layer(nn.LazyLinear, hidden_features), + # activation=nn.ReLU, + # ) + + # self.message_passing.transform.set_input_map("x", "edge_index", "input_edge_attr") + # self.message_passing.propagate = Mean() + # self.message_passing.propagate.set_input_map("x", "edge_index", "edge_attr") + self.dense = Layer(nn.Linear, hidden_features, hidden_features) self.dense.set_input_map('x') self.dense.set_output_map('x') @@ -165,6 +177,8 @@ def __init__( self.blocks.append(block) def forward(self, x): + x['input_edge_index'] = x['edge_index'] # Do this in a nicer way + x['input_edge_attr'] = x['edge_attr'] x = self.message_passing(x) x = self.dense(x) x = self.activate(x) @@ -277,32 +291,60 @@ def __init__( ) self.blocks.append(block) - + self.dense = Layer(nn.Linear, hidden_features, hidden_features) self.dense.set_input_map('x') self.dense.set_output_map('x') self.activate = Layer(nn.ReLU) self.activate.set_input_map('x') - self.activate.set_output_map('x') - - # get the edge features: - self.get_edge_attr = GetEdgeFeatures( - combine=Cat(), - layer=Layer(nn.LazyLinear, hidden_features), - activation=Layer(nn.ReLU), - ) - self.get_edge_attr.set_input_map("x", "edge_index") - self.get_edge_attr.set_output_map("edge_attr") - - self.edge_mlp = MultiLayerPerceptron( - in_features=hidden_features, - hidden_features=[hidden_features], - out_features=output_edge_dim, - out_activation=None, - ) - self.edge_mlp.set_input_map("edge_attr") - self.edge_mlp.set_output_map("edge_attr") + self.activate.set_output_map('x') + + self.get_edge_attr = GetEdgeFeaturesNew() + self.get_edge_attr.set_input_map("x", "input_edge_index") + self.get_edge_attr.set_output_map("edge_attr_sender", "edge_attr_receiver") + + self.dense_sender = Layer(nn.Linear, hidden_features, hidden_features) + self.dense_sender.set_input_map('edge_attr_sender') + self.dense_sender.set_output_map('edge_attr_sender') + + self.activate_sender = Layer(nn.ReLU) + self.activate_sender.set_input_map('edge_attr_sender') + self.activate_sender.set_output_map('edge_attr_sender') + + self.dense_receiver = Layer(nn.Linear, hidden_features, hidden_features) + self.dense_receiver.set_input_map('edge_attr_receiver') + self.dense_receiver.set_output_map('edge_attr_receiver') + + self.activate_receiver = Layer(nn.ReLU) + self.activate_receiver.set_input_map('edge_attr_receiver') + self.activate_receiver.set_output_map('edge_attr_receiver') + + self.concat_edge_attr = Cat() + self.concat_edge_attr.set_input_map('edge_attr_sender', 'edge_attr_receiver') + self.concat_edge_attr.set_output_map('edge_attr') + + self.dense_edge_mlp_1 = Layer(nn.Linear, hidden_features * 2, hidden_features) + self.dense_edge_mlp_1.set_input_map('edge_attr') + self.dense_edge_mlp_1.set_output_map('edge_attr') + + self.activate_edge_mlp_1 = Layer(nn.ReLU) + self.activate_edge_mlp_1.set_input_map('edge_attr') + self.activate_edge_mlp_1.set_output_map('edge_attr') + + self.dense_edge_mlp_2 = Layer(nn.Linear, hidden_features, output_edge_dim) + self.dense_edge_mlp_2.set_input_map('edge_attr') + self.dense_edge_mlp_2.set_output_map('edge_attr') + + # # get the edge features: + # self.edge_mlp = MultiLayerPerceptron( + # in_features = hidden_features * 2, + # hidden_features = [hidden_features], + # out_features = output_edge_dim, + # out_activation = None, + # ) + # self.edge_mlp.set_input_map('edge_attr') + # self.edge_mlp.set_output_map('edge_attr') # get the node features: self.node_mlp = MultiLayerPerceptron( @@ -312,7 +354,7 @@ def __init__( out_activation = None, ) self.node_mlp.set_input_map('x') - self.node_mlp.set_output_map('x') + self.node_mlp.set_output_map('x') def forward(self, x): for block in self.blocks: @@ -320,8 +362,21 @@ def forward(self, x): x = self.dense(x) x = self.activate(x) + x = self.get_edge_attr(x) - x = self.edge_mlp(x) + x = self.dense_sender(x) + x = self.activate_sender(x) + x = self.dense_receiver(x) + x = self.activate_receiver(x) + x = self.concat_edge_attr(x) + + x = self.dense_edge_mlp_1(x) + x = self.activate_edge_mlp_1(x) + + x = self.dense_edge_mlp_2(x) + + # x = self.edge_mlp(x) + x = self.node_mlp(x) return x @@ -411,18 +466,16 @@ def __init__( ) self.edge_index_map = edge_index_map self.connect_output_map = connect_output_map - - self.gcn = GraphConvolutionalNeuralNetwork( + + self.gcn = GraphConvolutionalNeuralNetworkConcat( in_features=in_features, hidden_features=[], out_features=out_features, out_activation=nn.ReLU, ) - - self.gcn.normalize.set_input_map('x', edge_index_map) - self.gcn.normalize.set_output_map('laplacian') - self.gcn.propagate.set_input_map("x", "laplacian") - + + self.gcn.propagate.set_input_map("x", edge_index_map) + if pool == MinCutPooling: if num_clusters is None: raise ValueError("num_clusters must be provided for MinCutPooling") @@ -532,17 +585,15 @@ def __init__( else: self.upsample = upsample() self.upsample.upsample.set_input_map('x', select_input_map) - - self.gcn = GraphConvolutionalNeuralNetwork( + + self.gcn = GraphConvolutionalNeuralNetworkConcat( in_features=in_features, hidden_features=[], out_features=out_features, out_activation=nn.ReLU, ) - self.gcn.normalize.set_input_map('x', edge_index_map) - self.gcn.normalize.set_output_map('laplacian') - self.gcn.propagate.set_input_map("x", "laplacian") + self.gcn.propagate.set_input_map("x", edge_index_map) def forward(self, x): x = self.upsample(x) diff --git a/deeplay/components/gnn/mpn/__init__.py b/deeplay/components/gnn/mpn/__init__.py index fbd18602..bcee56ed 100644 --- a/deeplay/components/gnn/mpn/__init__.py +++ b/deeplay/components/gnn/mpn/__init__.py @@ -1,4 +1,5 @@ from .mpn import MessagePassingNeuralNetwork +from .mpn_gaudi import MessagePassingNeuralNetworkGAUDI from .rmpn import ResidualMessagePassingNeuralNetwork from .transformation import * diff --git a/deeplay/components/gnn/mpn/mpn_gaudi.py b/deeplay/components/gnn/mpn/mpn_gaudi.py new file mode 100644 index 00000000..6cdc9b02 --- /dev/null +++ b/deeplay/components/gnn/mpn/mpn_gaudi.py @@ -0,0 +1,150 @@ +from typing import List, Optional, Literal, Any, Sequence, Type, overload, Union + +from deeplay import DeeplayModule, Layer, LayerList +from deeplay.ops import Cat + +from ..tpu import TransformPropagateUpdate + +from .transformation import TransformOnlySenderNodes +from .propagation import Mean +from .update import Update + +import torch.nn as nn + + +class MessagePassingNeuralNetworkGAUDI(DeeplayModule): + hidden_features: Sequence[Optional[int]] + out_features: int + blocks: LayerList[TransformPropagateUpdate] + + @property + def input(self): + """Return the input layer of the network. Equivalent to `.blocks[0]`.""" + return self.blocks[0] + + @property + def hidden(self): + """Return the hidden layers of the network. Equivalent to `.blocks[:-1]`""" + return self.blocks[:-1] + + @property + def output(self): + """Return the last layer of the network. Equivalent to `.blocks[-1]`.""" + return self.blocks[-1] + + @property + def transform(self) -> LayerList[Layer]: + """Return the transform layers of the network. Equivalent to `.blocks.transform`.""" + return self.blocks.transform + + @property + def propagate(self) -> LayerList[Layer]: + """Return the propagate layers of the network. Equivalent to `.blocks.propagate`.""" + return self.blocks.propagate + + @property + def update(self) -> LayerList[Layer]: + """Return the update layers of the network. Equivalent to `.blocks.update`.""" + return self.blocks.update + + def __init__( + self, + hidden_features: Sequence[int], + out_features: int, + out_activation: Union[Type[nn.Module], nn.Module, None] = None, + ): + super().__init__() + + self.hidden_features = hidden_features + self.out_features = out_features + + if any(h <= 0 for h in hidden_features): + raise ValueError( + f"all hidden_channels must be positive, got {hidden_features}" + ) + + if out_features is None: + raise ValueError("out_features must be specified") + + if out_features <= 0: + raise ValueError( + f"Number of output features must be positive, got {out_features}" + ) + + if out_activation is None: + out_activation = Layer(nn.Identity) + elif isinstance(out_activation, type) and issubclass(out_activation, nn.Module): + out_activation = Layer(out_activation) + + self.blocks = LayerList() + for i, c_out in enumerate([*hidden_features, out_features]): + activation = ( + Layer(nn.ReLU) if i < len(hidden_features) - 1 else out_activation + ) + + transform = TransformOnlySenderNodes( + combine=Cat(), + layer=Layer(nn.LazyLinear, c_out), + activation=activation.new(), + ) + transform.set_input_map("x", "edge_index", "input_edge_attr") + transform.set_output_map("edge_attr") + + propagate = Mean() + propagate.set_input_map("x", "edge_index", "edge_attr") + propagate.set_output_map("aggregate") + + update = Update( + combine=Cat(), + layer=Layer(nn.LazyLinear, c_out), + activation=activation.new(), + ) + update.set_input_map("x", "aggregate") + update.set_output_map("x") + + block = TransformPropagateUpdate( + transform=transform, + propagate=propagate, + update=update, + ) + self.blocks.append(block) + + def forward(self, x): + for block in self.blocks: + x = block(x) + return x + + @overload + def configure( + self, + /, + in_features: Optional[int] = None, + hidden_features: Optional[List[int]] = None, + out_features: Optional[int] = None, + out_activation: Union[Type[nn.Module], nn.Module, None] = None, + ) -> None: ... + + @overload + def configure( + self, + name: Literal["blocks"], + order: Optional[Sequence[str]] = None, + transform: Optional[Type[nn.Module]] = None, + propagate: Optional[Type[nn.Module]] = None, + update: Optional[Type[nn.Module]] = None, + **kwargs: Any, + ) -> None: ... + + @overload + def configure( + self, + name: Literal["blocks"], + index: Union[int, slice, List[Union[int, slice]]], + order: Optional[Sequence[str]] = None, + transform: Optional[Type[nn.Module]] = None, + propagate: Optional[Type[nn.Module]] = None, + update: Optional[Type[nn.Module]] = None, + **kwargs: Any, + ) -> None: ... + + configure = DeeplayModule.configure diff --git a/deeplay/components/gnn/mpn/transformation.py b/deeplay/components/gnn/mpn/transformation.py index 7ee44740..7263915f 100644 --- a/deeplay/components/gnn/mpn/transformation.py +++ b/deeplay/components/gnn/mpn/transformation.py @@ -14,3 +14,16 @@ def get_forward_args(self, x): """ x, edge_index, edge_attr = x return x[edge_index[0]], x[edge_index[1]], edge_attr + +class TransformOnlySenderNodes(CombineLayerActivation): + """Transform module for MPN.""" + + def get_forward_args(self, x): + """Get the arguments for the Transform module. + An MPN Transform module takes the following arguments: + - node features of sender nodes (x[A[0]]) + - edge features (edgefeat) + A is the adjacency matrix of the graph. + """ + x, edge_index, edge_attr = x + return x[edge_index[0]], edge_attr diff --git a/deeplay/components/gnn/pooling/mincut.py b/deeplay/components/gnn/pooling/mincut.py index f2ead2d1..ce307d4e 100644 --- a/deeplay/components/gnn/pooling/mincut.py +++ b/deeplay/components/gnn/pooling/mincut.py @@ -115,8 +115,11 @@ def forward(self, A): ind = torch.arange(A.shape[0]) A[ind, ind] = 0 D = torch.einsum('jk->j', A) - D = torch.sqrt(D)[None] + self.eps - A = (A / D) / D.transpose(-2,-1) + D_inv_sq = torch.pow(D, -0.5) + D_inv_sq = torch.where(torch.isinf(D_inv_sq), torch.tensor(0.0), D_inv_sq) + D_inv_sq = torch.diag(D_inv_sq) + + A = D_inv_sq @ A @ D_inv_sq return A class MinCutLoss(DeeplayModule): diff --git a/deeplay/ops/__init__.py b/deeplay/ops/__init__.py index 7715eb2c..22b5e923 100644 --- a/deeplay/ops/__init__.py +++ b/deeplay/ops/__init__.py @@ -2,3 +2,4 @@ from .logs import FromLogs from .attention import * from .merge import * +from .get_edge_features import * diff --git a/deeplay/ops/get_edge_features.py b/deeplay/ops/get_edge_features.py new file mode 100644 index 00000000..a5b8e7aa --- /dev/null +++ b/deeplay/ops/get_edge_features.py @@ -0,0 +1,14 @@ + +from deeplay import DeeplayModule + +class GetEdgeFeaturesNew(DeeplayModule): + """""" + + def forward(self, x, edge_index): + """Get the node features of neighboring nodes for each edge. + - node features of sender nodes (x[edge_index[0]]) + - node features of receiver nodes (x[edge_index[1]]) + + edge_index denote the connectivity of the graph. + """ + return x[edge_index[0]], x[edge_index[1]]