# --- src/graph_modeling/generate/__main__.py (commands added by this patch) ---
# The same patch also pins the setup.py dependency to "networkx~=2.6.3".
#
# NOTE(review): only `@_common_options` is visible on these two commands in
# the hunk; confirm they are also registered on the CLI entry point, otherwise
# they cannot be invoked from the command line.


@_common_options
def simple_cycle(outdir, **graph_config):
    """Writes out a directed cycle"""
    # All graph options are forwarded untouched; only the graph type is fixed.
    write_graph(outdir, type="simple_cycle", **graph_config)


@_common_options
def simple_cycle_with_reverse_edge(outdir, **graph_config):
    """Writes out a directed cycle with reverse edge"""
    # All graph options are forwarded untouched; only the graph type is fixed.
    write_graph(outdir, type="simple_cycle_with_reverse_edge", **graph_config)
# --- src/graph_modeling/generate/simple_cycle.py (new file in this patch) ---
#
# SPDX-License-Identifier: Apache-2.0

import networkx as nx

__all__ = [
    "generate",
]


def generate(log_num_nodes: int, **kwargs) -> nx.DiGraph:
    """Return a directed cycle graph with ``2 ** log_num_nodes`` nodes.

    :param log_num_nodes: base-2 logarithm of the number of nodes
    :param kwargs: accepted for signature compatibility; not used here
    :return: the graph produced by ``nx.cycle_graph`` with
        ``create_using=nx.DiGraph``
    """
    node_count = 2 ** log_num_nodes
    return nx.cycle_graph(node_count, create_using=nx.DiGraph)
# --- src/graph_modeling/generate/simple_cycle_with_reverse_edge.py (new file) ---
#
# SPDX-License-Identifier: Apache-2.0
# (in the patch this module's __all__ exports only "generate")


def generate(log_num_nodes: int, **kwargs) -> "nx.DiGraph":
    """Return a directed cycle on ``2 ** log_num_nodes`` nodes plus the edge ``1 -> 0``.

    NOTE(review): the extra edge is added unconditionally. For
    ``log_num_nodes == 1`` the 2-node cycle presumably already contains
    ``1 -> 0``, and for ``log_num_nodes == 0`` node ``1`` is not part of the
    cycle, so ``add_edge`` would create it. Confirm whether callers rely on
    these small cases.

    :param log_num_nodes: base-2 logarithm of the number of cycle nodes
    :param kwargs: accepted for signature compatibility; not used here
    :return: the directed graph
    """
    # Imported locally so this definition does not depend on a module-level import.
    import networkx as nx

    num_nodes = 2 ** log_num_nodes
    graph = nx.cycle_graph(num_nodes, create_using=nx.DiGraph)
    graph.add_edge(1, 0)
    return graph


# --- src/graph_modeling/models/box.py (classes added by this patch) ---
# NOTE: the patch also adds "GBCBox" and "VBCBox" to that module's __all__.

# Euler-Mascheroni constant, used in the offset subtracted from side lengths
# before the softplus volume.
_EULER_GAMMA = 0.57721566490153286060


def _gumbel_intersection(e1_min, e1_max, e2_min, e2_max, temperature):
    """Return the temperature-smoothed intersection of two boxes.

    The lower corner is ``T * logsumexp([e1_min / T, e2_min / T])`` and the
    upper corner is ``-T * logsumexp([-e1_max / T, -e2_max / T])``. Both are
    then clamped so the result never reaches outside the hard intersection
    ``[max(e1_min, e2_min), min(e1_max, e2_max)]``.

    :param e1_min: lower corner of the first box
    :param e1_max: upper corner of the first box
    :param e2_min: lower corner of the second box
    :param e2_max: upper corner of the second box
    :param temperature: smoothing temperature ``T``
    :return: tuple ``(meet_min, meet_max)``
    """
    meet_min = temperature * torch.logsumexp(
        torch.stack([e1_min / temperature, e2_min / temperature]), 0,
    )
    meet_max = -temperature * torch.logsumexp(
        torch.stack([-e1_max / temperature, -e2_max / temperature]), 0,
    )
    meet_min = torch.max(meet_min, torch.max(e1_min, e2_min))
    meet_max = torch.min(meet_max, torch.min(e1_max, e2_max))
    return meet_min, meet_max


class GBCBox(Module):
    """Box model with ``num_universe`` global, learned per-dimension codes.

    Every entity has a box (center and side length per dimension). Each
    "universe" is one row of ``codes``; its sigmoid weights every dimension's
    log side length when computing a log volume. ``forward`` scores a pair by
    the best universe.
    """

    def __init__(self, num_entity, dim, num_universe=1, volume_temp=1.0, intersection_temp=1.0):
        """
        :param num_entity: number of entities (rows of the embedding tables)
        :param dim: number of box dimensions
        :param num_universe: number of code vectors; must be a whole number >= 1
        :param volume_temp: temperature of the softplus (``beta = 1 / volume_temp``)
        :param intersection_temp: temperature of the smoothed intersection
        :raises ValueError: if ``num_universe`` is not a whole number >= 1
        """
        super().__init__()
        # torch.nn.Embedding needs an integer row count; the former float
        # default (1.0) made the default constructor raise a TypeError.
        if int(num_universe) != num_universe or num_universe < 1:
            raise ValueError(
                f"num_universe must be a whole number >= 1, got {num_universe!r}"
            )
        self.num_universe = int(num_universe)
        self.centers = torch.nn.Embedding(num_entity, dim)
        self.centers.weight.data.uniform_(-0.1, 0.1)
        self.sidelengths = torch.nn.Embedding(num_entity, dim)
        self.sidelengths.weight.data.zero_()
        self.codes = torch.nn.Embedding(self.num_universe, dim)
        torch.nn.init.uniform_(self.codes.weight, -0.1, 0.1)

        self.volume_temp = volume_temp
        self.intersection_temp = intersection_temp
        self.softplus = torch.nn.Softplus(beta=1 / self.volume_temp)
        self.sigmoid = torch.nn.Sigmoid()
        self.softplus_const = 2 * self.intersection_temp * _EULER_GAMMA

    def log_volume(self, z, Z):
        """Return the code-weighted log volume of boxes for every universe.

        :param z: lower corners, shape (..., dim)
        :param Z: upper corners, shape (..., dim)
        :return: tensor of shape (..., num_universe)
        """
        # (..., 1, dim): the inserted axis broadcasts against the universes.
        log_vol_per_dim = torch.log(
            self.softplus(Z - z - self.softplus_const)
        ).unsqueeze(-2)
        # (num_universe, dim); broadcasting handles any number of leading
        # axes, so no per-rank branching is needed.
        gates = self.sigmoid(self.codes.weight)
        return torch.sum(log_vol_per_dim * gates, -1)

    def embedding_lookup(self, idx):
        """Return the lower and upper corners ``(z, Z)`` of the boxes at ``idx``."""
        center = self.centers(idx)
        # softplus keeps the half side length positive.
        length = self.softplus(self.sidelengths(idx))
        z = center - length
        Z = center + length
        return z, Z

    def gumbel_intersection(self, e1_min, e1_max, e2_min, e2_max):
        """Return ``(meet_min, meet_max)`` of the smoothed intersection of two boxes."""
        return _gumbel_intersection(
            e1_min, e1_max, e2_min, e2_max, self.intersection_temp
        )

    def forward(self, idxs):
        """
        :param idxs: Tensor of shape (..., 2) (N, K+1, 2) during training or (N, 2) during testing
        :return: log prob of shape (..., ): the maximum over universes of
            ``log_volume(e1 meet e2) - log_volume(e2)``
        """
        e1_min, e1_max = self.embedding_lookup(idxs[..., 0])
        e2_min, e2_max = self.embedding_lookup(idxs[..., 1])

        meet_min, meet_max = self.gumbel_intersection(e1_min, e1_max, e2_min, e2_max)

        log_overlap_volume = self.log_volume(meet_min, meet_max)
        log_rhs_volume = self.log_volume(e2_min, e2_max)
        log_conditional = log_overlap_volume - log_rhs_volume

        # Keep the score of the best universe for each pair.
        return log_conditional.max(dim=-1).values


class VBCBox(Module):
    """Box model with a learned per-entity code over the box dimensions.

    The first ``dim - dim_share`` entries of an entity's code are the sigmoid
    of learned parameters; the last ``dim_share`` entries are fixed to one.
    A pair's volumes are weighted by the product of the two entities' codes.
    """

    def __init__(self, num_entity, dim, dim_share=0, volume_temp=1.0, intersection_temp=1.0):
        """
        :param num_entity: number of entities (rows of the embedding tables)
        :param dim: number of box dimensions
        :param dim_share: number of dimensions whose code is fixed to one
        :param volume_temp: temperature of the softplus (``beta = 1 / volume_temp``)
        :param intersection_temp: temperature of the smoothed intersection
        :raises ValueError: if ``dim_share`` is outside ``[0, dim]``
        """
        super().__init__()
        # A negative dim_share would slice from the end of the center vector
        # and produce a code whose width no longer matches ``dim``.
        if not 0 <= dim_share <= dim:
            raise ValueError(
                f"dim_share must be between 0 and dim={dim}, got {dim_share!r}"
            )
        self.dim_share = dim_share
        self.centers = torch.nn.Embedding(num_entity, dim)
        self.centers.weight.data.uniform_(-0.1, 0.1)
        self.sidelengths = torch.nn.Embedding(num_entity, dim)
        self.sidelengths.weight.data.zero_()
        self.codes = torch.nn.Embedding(num_entity, dim - dim_share)

        self.volume_temp = volume_temp
        self.intersection_temp = intersection_temp
        self.softplus = torch.nn.Softplus(beta=1 / self.volume_temp)
        self.sigmoid = torch.nn.Sigmoid()
        self.softplus_const = 2 * self.intersection_temp * _EULER_GAMMA

    def log_volume(self, z, Z, c):
        """Return the code-weighted log volume of boxes.

        :param z: lower corners, shape (..., dim)
        :param Z: upper corners, shape (..., dim)
        :param c: per-dimension weights, shape (..., dim)
        :return: tensor of shape (...)
        """
        # A code near zero gives the corresponding dimension only a small
        # effect on the final volume.
        log_vol = torch.sum(
            torch.log(self.softplus(Z - z - self.softplus_const)) * c, dim=-1,
        )

        return log_vol

    def embedding_lookup(self, idx):
        """Return ``(z, Z, code)`` for the entities at ``idx``.

        ``z`` and ``Z`` are the lower and upper box corners, and ``code`` has
        shape (..., dim) with its last ``dim_share`` entries equal to one.
        """
        center = self.centers(idx)
        length = self.softplus(self.sidelengths(idx))
        learned_code = self.sigmoid(self.codes(idx))  # (..., dim - dim_share)
        # Constant ones for the shared dimensions, matching the dtype and
        # device of the embeddings.
        shared_code = torch.ones_like(center[..., : self.dim_share])
        code = torch.cat([learned_code, shared_code], dim=-1)
        z = center - length
        Z = center + length
        return z, Z, code

    def gumbel_intersection(self, e1_min, e1_max, e2_min, e2_max):
        """Return ``(meet_min, meet_max)`` of the smoothed intersection of two boxes."""
        return _gumbel_intersection(
            e1_min, e1_max, e2_min, e2_max, self.intersection_temp
        )

    def forward(self, idxs):
        """
        :param idxs: Tensor of shape (..., 2) (N, K+1, 2) during training or (N, 2) during testing
        :return: log prob of shape (..., )
        """
        e1_min, e1_max, e1_code = self.embedding_lookup(idxs[..., 0])
        e2_min, e2_max, e2_code = self.embedding_lookup(idxs[..., 1])

        meet_min, meet_max = self.gumbel_intersection(e1_min, e1_max, e2_min, e2_max)

        # Both volumes are weighted by the product of the two entities' codes.
        code_intersection = e1_code * e2_code

        log_overlap_volume = self.log_volume(meet_min, meet_max, code_intersection)
        log_rhs_volume = self.log_volume(e2_min, e2_max, code_intersection)
        log_conditional = log_overlap_volume - log_rhs_volume

        return log_conditional
b/src/graph_modeling/training/__main__.py @@ -56,6 +56,8 @@ def convert(self, value, param, ctx): type=click.Choice( [ "tbox", + "gbcbox", + "vbcbox", "gumbel_box", "order_embeddings", "partial_order_embeddings", @@ -88,6 +90,12 @@ def convert(self, value, param, ctx): @click.option( "--dim", type=int, default=4, help="dimension for embedding space", ) +@click.option( + "--num_universe", type=int, default=2, help="number of universes in GBC-Box (unused otherwise)", +) +@click.option( + "--shared_dim", type=int, default=0, help="dimension for shared space for VBC-Box (unused otherwise)", +) @click.option( "--log_batch_size", type=int, diff --git a/src/graph_modeling/training/train.py b/src/graph_modeling/training/train.py index 8850364..ce729b3 100644 --- a/src/graph_modeling/training/train.py +++ b/src/graph_modeling/training/train.py @@ -53,7 +53,7 @@ MaxMarginOENegativeSamplingLoss, ) from .. import metric_logger -from ..models.box import BoxMinDeltaSoftplus, TBox +from ..models.box import BoxMinDeltaSoftplus, TBox, GBCBox, VBCBox from ..models.hyperbolic import ( Lorentzian, LorentzianDistance, @@ -203,6 +203,24 @@ def setup_model( ), ) loss_func = BCEWithLogsNegativeSamplingLoss(config["negative_weight"]) + elif model_type == "gbcbox": + model = GBCBox( + num_nodes, + config["dim"], + config["num_universe"], + volume_temp=config["box_volume_temp"], + intersection_temp=config["box_intersection_temp"], + ) + loss_func = BCEWithLogsNegativeSamplingLoss(config["negative_weight"]) + elif model_type == "vbcbox": + model = VBCBox( + num_nodes, + config["dim"], + config["shared_dim"], + volume_temp=config["box_volume_temp"], + intersection_temp=config["box_intersection_temp"], + ) + loss_func = BCEWithLogsNegativeSamplingLoss(config["negative_weight"]) elif model_type == "order_embeddings": model = OE(num_nodes, config["dim"]) loss_func = MaxMarginOENegativeSamplingLoss(