Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
description="Scripts to generate graphs, train and evaluate graph representations",
install_requires=[
"Click>=7.1.2",
"networkx",
"networkx~=2.6.3",
"scipy",
"scikit-learn",
"numpy",
Expand Down
10 changes: 10 additions & 0 deletions src/graph_modeling/generate/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,13 @@ def hac(outdir, **graph_config):
def knn_graph(outdir, **graph_config):
"""Writes out a KNN graph"""
write_graph(outdir, type="knn_graph", **graph_config)

@_common_options
def simple_cycle(outdir, **graph_config):
"""Writes out a directed cycle"""
write_graph(outdir, type="simple_cycle", **graph_config)

@_common_options
def simple_cycle_with_reverse_edge(outdir, **graph_config):
"""Writes out a directed cycle with reverse edge"""
write_graph(outdir, type="simple_cycle_with_reverse_edge", **graph_config)
28 changes: 28 additions & 0 deletions src/graph_modeling/generate/simple_cycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2021 The Geometric Graph Embedding Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import networkx as nx

__all__ = [
"generate",
]


def generate(log_num_nodes: int, **kwargs) -> nx.DiGraph:
num_nodes = 2 ** log_num_nodes

D = nx.cycle_graph(num_nodes,create_using=nx.DiGraph)
return D
28 changes: 28 additions & 0 deletions src/graph_modeling/generate/simple_cycle_with_reverse_edge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2021 The Geometric Graph Embedding Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import networkx as nx

__all__ = [
"generate",
]

def generate(log_num_nodes: int, **kwargs) -> nx.DiGraph:
num_nodes = 2 ** log_num_nodes

D = nx.cycle_graph(num_nodes,create_using=nx.DiGraph)
D.add_edge(1,0)
return D
149 changes: 149 additions & 0 deletions src/graph_modeling/models/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
__all__ = [
"BoxMinDeltaSoftplus",
"TBox",
"GBCBox",
"VBCBox"
]


Expand Down Expand Up @@ -207,3 +209,150 @@ def forward(
overwrite=True,
)
return out


class GBCBox(Module):
def __init__(self, num_entity, dim, num_universe=1.0, volume_temp=1.0, intersection_temp=1.0):
super().__init__()
self.num_universe = num_universe
self.centers = torch.nn.Embedding(num_entity, dim)
self.centers.weight.data.uniform_(-0.1, 0.1)
self.sidelengths = torch.nn.Embedding(num_entity, dim)
self.sidelengths.weight.data.zero_()
self.codes = torch.nn.Embedding(self.num_universe, dim)
torch.nn.init.uniform_(self.codes.weight, -0.1, 0.1)

self.volume_temp = volume_temp
self.intersection_temp = intersection_temp
self.softplus = torch.nn.Softplus(beta=1 / self.volume_temp)
self.sigmoid = torch.nn.Sigmoid()
self.softplus_const = 2 * self.intersection_temp * 0.57721566490153286060

def log_volume(self, z, Z):
log_vol_per_dim = torch.log(self.softplus(
Z - z - self.softplus_const)).unsqueeze(-2)

if len(log_vol_per_dim.shape) == 4:
log_vol_per_subspace = torch.sum(
log_vol_per_dim * self.sigmoid(self.codes.weight[None, None, :, :]), -1) # ..., num_universe
if len(log_vol_per_dim.shape) == 3:
log_vol_per_subspace = torch.sum(
log_vol_per_dim * self.sigmoid(self.codes.weight[None, :, :]), -1) # ..., num_universe

return log_vol_per_subspace

def embedding_lookup(self, idx):
center = self.centers(idx)
length = self.softplus(self.sidelengths(idx))
z = center - length
Z = center + length
return z, Z

def gumbel_intersection(self, e1_min, e1_max, e2_min, e2_max):
meet_min = self.intersection_temp * torch.logsumexp(
torch.stack(
[e1_min / self.intersection_temp, e2_min / self.intersection_temp]
),
0,
)
meet_max = -self.intersection_temp * torch.logsumexp(
torch.stack(
[-e1_max / self.intersection_temp, -e2_max / self.intersection_temp]
),
0,
)
meet_min = torch.max(meet_min, torch.max(e1_min, e2_min))
meet_max = torch.min(meet_max, torch.min(e1_max, e2_max))
return meet_min, meet_max

def forward(self, idxs):
"""
:param idxs: Tensor of shape (..., 2) (N, K+1, 2) during training or (N, 2) during testing
:return: log prob of shape (..., )
"""
e1_min, e1_max = self.embedding_lookup(idxs[..., 0])
e2_min, e2_max = self.embedding_lookup(idxs[..., 1])

meet_min, meet_max = self.gumbel_intersection(e1_min, e1_max, e2_min, e2_max)

log_overlap_volume = self.log_volume(meet_min, meet_max)
log_rhs_volume = self.log_volume(e2_min, e2_max)
log_conditional = log_overlap_volume - log_rhs_volume
log_conditional = torch.max(log_conditional, -1)[0]

return log_conditional


class VBCBox(Module):
def __init__(self, num_entity, dim, dim_share=0, volume_temp=1.0, intersection_temp=1.0):
super().__init__()
self.dim_share = dim_share
self.centers = torch.nn.Embedding(num_entity, dim)
self.centers.weight.data.uniform_(-0.1, 0.1)
self.sidelengths = torch.nn.Embedding(num_entity, dim)
self.sidelengths.weight.data.zero_()
self.codes = torch.nn.Embedding(num_entity, dim - dim_share)
#self.ones = torch.nn.Embedding(num_entity, dim_share)
#torch.nn.init.ones_(self.ones.weight)
#self.ones.weight.requires_grad = False

self.volume_temp = volume_temp
self.intersection_temp = intersection_temp
self.softplus = torch.nn.Softplus(beta=1 / self.volume_temp)
self.sigmoid = torch.nn.Sigmoid()
self.softplus_const = 2 * self.intersection_temp * 0.57721566490153286060

def log_volume(self, z, Z, c):
# a code of near zero will make the corresponding dimension to have small affect on the final volume
log_vol = torch.sum(
torch.log(self.softplus(Z - z - self.softplus_const)) * c, dim=-1,
)

return log_vol

def embedding_lookup(self, idx):
center = self.centers(idx)
length = self.softplus(self.sidelengths(idx))
code = self.sigmoid(self.codes(idx)) # ..., dim - dim_share
#ones = self.ones(idx) # ..., dim_share
ones = center[...,:self.dim_share].abs() + 1.0
ones = ones / ones
code = torch.cat([code, ones], axis=-1)
z = center - length
Z = center + length
return z, Z, code

def gumbel_intersection(self, e1_min, e1_max, e2_min, e2_max):
meet_min = self.intersection_temp * torch.logsumexp(
torch.stack(
[e1_min / self.intersection_temp, e2_min / self.intersection_temp]
),
0,
)
meet_max = -self.intersection_temp * torch.logsumexp(
torch.stack(
[-e1_max / self.intersection_temp, -e2_max / self.intersection_temp]
),
0,
)
meet_min = torch.max(meet_min, torch.max(e1_min, e2_min))
meet_max = torch.min(meet_max, torch.min(e1_max, e2_max))
return meet_min, meet_max

def forward(self, idxs):
"""
:param idxs: Tensor of shape (..., 2) (N, K+1, 2) during training or (N, 2) during testing
:return: log prob of shape (..., )
"""
e1_min, e1_max, e1_code = self.embedding_lookup(idxs[..., 0])
e2_min, e2_max, e2_code = self.embedding_lookup(idxs[..., 1])

meet_min, meet_max = self.gumbel_intersection(e1_min, e1_max, e2_min, e2_max)

code_intersection = e1_code * e2_code

log_overlap_volume = self.log_volume(meet_min, meet_max, code_intersection)
log_rhs_volume = self.log_volume(e2_min, e2_max, code_intersection)
log_conditional = log_overlap_volume - log_rhs_volume

return log_conditional
8 changes: 8 additions & 0 deletions src/graph_modeling/training/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def convert(self, value, param, ctx):
type=click.Choice(
[
"tbox",
"gbcbox",
"vbcbox",
"gumbel_box",
"order_embeddings",
"partial_order_embeddings",
Expand Down Expand Up @@ -88,6 +90,12 @@ def convert(self, value, param, ctx):
@click.option(
"--dim", type=int, default=4, help="dimension for embedding space",
)
@click.option(
"--num_universe", type=int, default=2, help="number of universes in GBC-Box (unused otherwise)",
)
@click.option(
"--shared_dim", type=int, default=0, help="dimension for shared space for VBC-Box (unused otherwise)",
)
@click.option(
"--log_batch_size",
type=int,
Expand Down
20 changes: 19 additions & 1 deletion src/graph_modeling/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
MaxMarginOENegativeSamplingLoss,
)
from .. import metric_logger
from ..models.box import BoxMinDeltaSoftplus, TBox
from ..models.box import BoxMinDeltaSoftplus, TBox, GBCBox, VBCBox
from ..models.hyperbolic import (
Lorentzian,
LorentzianDistance,
Expand Down Expand Up @@ -203,6 +203,24 @@ def setup_model(
),
)
loss_func = BCEWithLogsNegativeSamplingLoss(config["negative_weight"])
elif model_type == "gbcbox":
model = GBCBox(
num_nodes,
config["dim"],
config["num_universe"],
volume_temp=config["box_volume_temp"],
intersection_temp=config["box_intersection_temp"],
)
loss_func = BCEWithLogsNegativeSamplingLoss(config["negative_weight"])
elif model_type == "vbcbox":
model = VBCBox(
num_nodes,
config["dim"],
config["shared_dim"],
volume_temp=config["box_volume_temp"],
intersection_temp=config["box_intersection_temp"],
)
loss_func = BCEWithLogsNegativeSamplingLoss(config["negative_weight"])
elif model_type == "order_embeddings":
model = OE(num_nodes, config["dim"])
loss_func = MaxMarginOENegativeSamplingLoss(
Expand Down