From 38249ecc4c4276c84e2e7ec1c0c1fe405b50bbe8 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Thu, 11 Sep 2025 19:50:39 -0700 Subject: [PATCH 01/48] Adding distributed batch normalization layers --- experiments/OGB-LSC/distributed_layers.py | 160 ++++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 experiments/OGB-LSC/distributed_layers.py diff --git a/experiments/OGB-LSC/distributed_layers.py b/experiments/OGB-LSC/distributed_layers.py new file mode 100644 index 0000000..05afd98 --- /dev/null +++ b/experiments/OGB-LSC/distributed_layers.py @@ -0,0 +1,160 @@ +import torch +from torch import nn +import torch.distributed as dist +from torch.autograd import Function + + +def _compute_bn_forward(input, learned_gamma=None, learned_beta=None): + local_sum = torch.mean(input, dim=0) + global_sum = local_sum.clone() + num_rows = torch.tensor([input.size(0)], dtype=torch.float32, device=input.device) + + global_num_rows = num_rows.clone() + + dist.all_reduce(global_num_rows, op=dist.ReduceOp.SUM) + global_mean = global_sum / global_num_rows + local_var = (input - global_mean) ** 2 + global_var = local_var.clone() + dist.all_reduce(global_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(global_var, op=dist.ReduceOp.SUM) + global_var = global_var / global_num_rows + + x_hat = (input - global_mean) / torch.sqrt(global_var + 1e-5) + if learned_gamma is not None and learned_beta is not None: + output = x_hat * learned_gamma + learned_beta + + return output, x_hat, global_mean, global_var, global_num_rows + + +class DistributedBN_with_Recompute(Function): + @staticmethod + def forward(ctx, input, learned_gamma=None, learned_beta=None): + ctx.save_for_backward(input) + ctx.learned_gamma = learned_gamma + ctx.learned_beta = learned_beta + output, _, global_mean, global_var, global_num_rows = _compute_bn_forward( + input, learned_gamma, learned_beta + ) + ctx.mean = global_mean + ctx.var = global_var + ctx.input = input + ctx.num_rows = global_num_rows + return output, global_mean, global_var + + @staticmethod + def backward(ctx, grad_output): + x = ctx.input + mean = ctx.mean + var = ctx.var + # recompute x_hat to save memory + x_hat = (x - mean) / torch.sqrt(var + 1e-5) + learned_gamma = ctx.learned_gamma + learned_beta = ctx.learned_beta + num_rows = ctx.num_rows + + if learned_gamma is not None and learned_beta is not None: + local_dbeta = torch.sum(grad_output, dim=0) + global_dbeta = local_dbeta.clone() + dist.all_reduce(global_dbeta, op=dist.ReduceOp.SUM) + local_dgamma = torch.sum(grad_output * x_hat, dim=0) + global_dgamma = local_dgamma.clone() + dist.all_reduce(global_dgamma, op=dist.ReduceOp.SUM) + dx_hat = grad_output * learned_gamma + else: + dx_hat = grad_output + global_dgamma = None + global_dbeta = None + + local_dvar = torch.sum(dx_hat * (x - mean) * -0.5 * (var + 1e-5) ** 2, dim=0) + global_dvar = local_dvar.clone() + dist.all_reduce(global_dvar, op=dist.ReduceOp.SUM) + + local_dmean = torch.sum( + dx_hat * -1 / torch.sqrt(var + 1e-5), dim=0 + ) + global_dvar * torch.mean(-2 * (x - mean), dim=0) + global_dmean = local_dmean.clone() + dist.all_reduce(global_dmean, op=dist.ReduceOp.SUM) + dx = ( + (dx_hat / torch.sqrt(var + 1e-5)) + + (global_dvar * 2 * (x - mean) / num_rows) + + (global_dmean / num_rows) + ) + + return dx, global_dgamma, global_dbeta + + +class DistributedBN_Impl(Function): + @staticmethod + def forward(ctx, input, learned_gamma=None, learned_beta=None): + output, x_hat, global_mean, global_var, global_num_rows = _compute_bn_forward( + input, learned_gamma, learned_beta + ) + + ctx.save_for_backward(x_hat) + ctx.learned_gamma = learned_gamma + ctx.learned_beta = learned_beta + ctx.mean = global_mean + ctx.var = global_var + ctx.num_rows = global_num_rows + ctx.input = input + ctx.x_hat = x_hat + return output, global_mean, global_var + + @staticmethod + def backward(ctx, grad_output): + + learned_gamma = ctx.learned_gamma + learned_beta = ctx.learned_beta + mean = ctx.mean + var = ctx.var + x_hat = ctx.x_hat + num_rows = ctx.num_rows + x = ctx.input + + if learned_gamma is not None and learned_beta is not None: + local_dbeta = torch.sum(grad_output, dim=0) + global_dbeta = local_dbeta.clone() + dist.all_reduce(global_dbeta, op=dist.ReduceOp.SUM) + local_dgamma = torch.sum(grad_output * x_hat, dim=0) + global_dgamma = local_dgamma.clone() + dist.all_reduce(global_dgamma, op=dist.ReduceOp.SUM) + dx_hat = grad_output * learned_gamma + else: + dx_hat = grad_output + global_dgamma = None + global_dbeta = None + + local_dvar = torch.sum(dx_hat * (x - mean) * -0.5 * (var + 1e-5) ** 2, dim=0) + global_dvar = local_dvar.clone() + dist.all_reduce(global_dvar, op=dist.ReduceOp.SUM) + + local_dmean = torch.sum( + dx_hat * -1 / torch.sqrt(var + 1e-5), dim=0 + ) + global_dvar * torch.mean(-2 * (x - mean), dim=0) + global_dmean = local_dmean.clone() + dist.all_reduce(global_dmean, op=dist.ReduceOp.SUM) + dx = ( + (dx_hat / torch.sqrt(var + 1e-5)) + + (global_dvar * 2 * (x - mean) / num_rows) + + (global_dmean / num_rows) + ) + + return dx, global_dgamma, global_dbeta + + +class DistributedBatchNorm1D(nn.Module): + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + ): + super(DistributedBatchNorm1D, self).__init__() + self.bn = nn.BatchNorm1d( + num_features, eps, momentum, affine, track_running_stats + ) + + def forward(self, x): + return self.bn(x) From b24ef7d6d254cbf50de5ca0f4eae8cef86df1bb1 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Thu, 11 Sep 2025 19:59:37 -0700 Subject: [PATCH 02/48] Wrap around implementation with torch nn module --- experiments/OGB-LSC/distributed_layers.py | 138 +++++++++++++--------- 1 file changed, 82 insertions(+), 56 deletions(-) diff --git a/experiments/OGB-LSC/distributed_layers.py b/experiments/OGB-LSC/distributed_layers.py index 05afd98..541fe16 100644 --- a/experiments/OGB-LSC/distributed_layers.py +++ b/experiments/OGB-LSC/distributed_layers.py @@ -2,6 +2,7 @@ from torch import nn import torch.distributed as dist from torch.autograd import Function +from typing import Callable def _compute_bn_forward(input, learned_gamma=None, learned_beta=None): @@ -26,6 +27,39 @@ def _compute_bn_forward(input, learned_gamma=None, learned_beta=None): return output, x_hat, global_mean, global_var, global_num_rows +def _compute_bn_backward( + grad_output, x, x_hat, mean, var, num_rows, learned_gamma=None, learned_beta=None +): + if learned_gamma is not None and learned_beta is not None: + local_dbeta = torch.sum(grad_output, dim=0) + global_dbeta = local_dbeta.clone() + dist.all_reduce(global_dbeta, op=dist.ReduceOp.SUM) + local_dgamma = torch.sum(grad_output * x_hat, dim=0) + global_dgamma = local_dgamma.clone() + dist.all_reduce(global_dgamma, op=dist.ReduceOp.SUM) + dx_hat = grad_output * learned_gamma + else: + dx_hat = grad_output + global_dgamma = None + global_dbeta = None + + local_dvar = torch.sum(dx_hat * (x - mean) * -0.5 * (var + 1e-5) ** 2, dim=0) + global_dvar = local_dvar.clone() + dist.all_reduce(global_dvar, op=dist.ReduceOp.SUM) + + local_dmean = torch.sum( + dx_hat * -1 / torch.sqrt(var + 1e-5), dim=0 + ) + global_dvar * torch.mean(-2 * (x - mean), dim=0) + global_dmean = local_dmean.clone() + dist.all_reduce(global_dmean, op=dist.ReduceOp.SUM) + dx = ( + (dx_hat / torch.sqrt(var + 1e-5)) + + (global_dvar * 2 * (x - mean) / num_rows) + + (global_dmean / num_rows) + ) + return dx, global_dgamma, global_dbeta + + class DistributedBN_with_Recompute(Function): @staticmethod def forward(ctx, input, learned_gamma=None, learned_beta=None): @@ -52,32 +86,8 @@ def backward(ctx, grad_output): learned_beta = ctx.learned_beta num_rows = ctx.num_rows - if learned_gamma is not None and learned_beta is not None: - local_dbeta = torch.sum(grad_output, dim=0) - global_dbeta = local_dbeta.clone() - dist.all_reduce(global_dbeta, op=dist.ReduceOp.SUM) - local_dgamma = torch.sum(grad_output * x_hat, dim=0) - global_dgamma = local_dgamma.clone() - dist.all_reduce(global_dgamma, op=dist.ReduceOp.SUM) - dx_hat = grad_output * learned_gamma - else: - dx_hat = grad_output - global_dgamma = None - global_dbeta = None - - local_dvar = torch.sum(dx_hat * (x - mean) * -0.5 * (var + 1e-5) ** 2, dim=0) - global_dvar = local_dvar.clone() - dist.all_reduce(global_dvar, op=dist.ReduceOp.SUM) - - local_dmean = torch.sum( - dx_hat * -1 / torch.sqrt(var + 1e-5), dim=0 - ) + global_dvar * torch.mean(-2 * (x - mean), dim=0) - global_dmean = local_dmean.clone() - dist.all_reduce(global_dmean, op=dist.ReduceOp.SUM) - dx = ( - (dx_hat / torch.sqrt(var + 1e-5)) - + (global_dvar * 2 * (x - mean) / num_rows) - + (global_dmean / num_rows) + dx, global_dgamma, global_dbeta = _compute_bn_backward( + grad_output, x, x_hat, mean, var, num_rows, learned_gamma, learned_beta ) return dx, global_dgamma, global_dbeta @@ -110,33 +120,8 @@ def backward(ctx, grad_output): x_hat = ctx.x_hat num_rows = ctx.num_rows x = ctx.input - - if learned_gamma is not None and learned_beta is not None: - local_dbeta = torch.sum(grad_output, dim=0) - global_dbeta = local_dbeta.clone() - dist.all_reduce(global_dbeta, op=dist.ReduceOp.SUM) - local_dgamma = torch.sum(grad_output * x_hat, dim=0) - global_dgamma = local_dgamma.clone() - dist.all_reduce(global_dgamma, op=dist.ReduceOp.SUM) - dx_hat = grad_output * learned_gamma - else: - dx_hat = grad_output - global_dgamma = None - global_dbeta = None - - local_dvar = torch.sum(dx_hat * (x - mean) * -0.5 * (var + 1e-5) ** 2, dim=0) - global_dvar = local_dvar.clone() - dist.all_reduce(global_dvar, op=dist.ReduceOp.SUM) - - local_dmean = torch.sum( - dx_hat * -1 / torch.sqrt(var + 1e-5), dim=0 - ) + global_dvar * torch.mean(-2 * (x - mean), dim=0) - global_dmean = local_dmean.clone() - dist.all_reduce(global_dmean, op=dist.ReduceOp.SUM) - dx = ( - (dx_hat / torch.sqrt(var + 1e-5)) - + (global_dvar * 2 * (x - mean) / num_rows) - + (global_dmean / num_rows) + dx, global_dgamma, global_dbeta = _compute_bn_backward( + grad_output, x, x_hat, mean, var, num_rows, learned_gamma, learned_beta ) return dx, global_dgamma, global_dbeta @@ -150,11 +135,52 @@ def __init__( momentum=0.1, affine=True, track_running_stats=True, + recompute=False, ): super(DistributedBatchNorm1D, self).__init__() - self.bn = nn.BatchNorm1d( - num_features, eps, momentum, affine, track_running_stats - ) + if affine: + self.gamma = nn.Parameter(torch.ones(num_features)) + self.beta = nn.Parameter(torch.zeros(num_features)) + else: + self.register_parameter("gamma", None) + self.register_parameter("beta", None) + self.eps = eps + self.momentum = momentum + self.track_running_stats = track_running_stats + if self.track_running_stats: + self.register_buffer("running_mean", torch.zeros(num_features)) + self.register_buffer("running_var", torch.ones(num_features)) + self.register_buffer( + "num_batches_tracked", torch.tensor(0, dtype=torch.long) + ) + else: + self.register_parameter("running_mean", None) + self.register_parameter("running_var", None) + self.register_parameter("num_batches_tracked", None) + self.recompute = recompute + if recompute: + self.bn: Callable = DistributedBN_with_Recompute.apply + else: + self.bn: Callable = DistributedBN_Impl.apply def forward(self, x): + if self.training: + if self.track_running_stats: + self.num_batches_tracked += 1 + y, mean, var = self.bn(x, self.gamma, self.beta) + if self.track_running_stats: + with torch.no_grad(): + self.running_mean = ( + 1 - self.momentum + ) * self.running_mean + self.momentum * mean + self.running_var = ( + 1 - self.momentum + ) * self.running_var + self.momentum * var + else: + y = (x - self.running_mean) / torch.sqrt(self.running_var + self.eps) + if self.gamma is not None and self.beta is not None: + y = y * self.gamma + self.beta + + return y + return self.bn(x) From 9cdadb6cda16556044770df20b96ed655b234d3a Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Thu, 11 Sep 2025 23:01:15 -0700 Subject: [PATCH 03/48] Add DGraph imlementation of RGat --- experiments/OGB-LSC/RGAT.py | 71 +++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 experiments/OGB-LSC/RGAT.py diff --git a/experiments/OGB-LSC/RGAT.py b/experiments/OGB-LSC/RGAT.py new file mode 100644 index 0000000..86e2a5e --- /dev/null +++ b/experiments/OGB-LSC/RGAT.py @@ -0,0 +1,71 @@ +import torch +import torch.nn as nn +import torch.distributed as dist + + +class ConvLayer(nn.Module): + def __init__(self, in_channels, out_channels): + super(ConvLayer, self).__init__() + self.conv = nn.Linear(in_channels, out_channels) + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.act(x) + return x + + +class CommAwareGAT(nn.Module): + def __init__(self, in_channels, out_channels, comm, bias=True, residual=False): + super(CommAwareGAT, self).__init__() + self.conv1 = nn.Linear(in_channels, out_channels, bias=False) + self.comm = comm + self.project_message = nn.Linear(2 * out_channels, 1) + self.leaky_relu = nn.LeakyReLU(0.2) + self.residual = residual + if self.residual: + self.res_net = nn.Linear(in_channels, out_channels, bias=False) + if bias: + self.bias = nn.Parameter(torch.empty(out_channels)) + nn.init.zeros_(self.bias) + else: + self.register_parameter("bias", None) + + def forward( + self, x, edge_index, rank_mapping, gather_cache=None, scatter_cache=None + ): + h = self.conv1(x) + _src_indices = edge_index[:, 0, :] + _dst_indices = edge_index[:, 1, :] + _src_rank_mappings = torch.cat( + [rank_mapping[0].unsqueeze(0), rank_mapping[0].unsqueeze(0)], dim=0 + ) + _dst_rank_mappings = torch.cat( + [rank_mapping[0].unsqueeze(0), rank_mapping[1].unsqueeze(0)], dim=0 + ) + h_i = self.comm.gather(h, _dst_indices, _dst_rank_mappings, cache=gather_cache) + h_j = self.comm.gather(h, _src_indices, _src_rank_mappings, cache=gather_cache) + messages = torch.cat([h_i, h_j], dim=-1) + edge_scores = self.leaky_relu(self.project_message(messages)).squeeze(-1) + numerator = torch.exp(edge_scores) + denominator = self.comm.scatter( + numerator, _dst_indices, _dst_rank_mappings, h.size(1), cache=scatter_cache + ) + denominator = self.comm.gather( + denominator, _src_indices, _src_rank_mappings, cache=gather_cache + ) + alpha_ij = numerator / (denominator + 1e-16) + attention_messages = h_j * alpha_ij.unsqueeze(-1) + out = self.comm.scatter( + attention_messages, + _src_indices, + _src_rank_mappings, + h.size(1), + cache=scatter_cache, + ) + if self.residual: + out = out + self.res_net(x) + if self.bias is not None: + out = out + self.bias + + return out From d44bc83b17ff6a1cd5a7173ddaa51f40c720b664 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Thu, 11 Sep 2025 23:19:56 -0700 Subject: [PATCH 04/48] Adding PyG sparse tensor wrapper if needed --- experiments/OGB-LSC/pyg_wrappers.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 experiments/OGB-LSC/pyg_wrappers.py diff --git a/experiments/OGB-LSC/pyg_wrappers.py b/experiments/OGB-LSC/pyg_wrappers.py new file mode 100644 index 0000000..0d76b6e --- /dev/null +++ b/experiments/OGB-LSC/pyg_wrappers.py @@ -0,0 +1,28 @@ +from torch_sparse import SparseTensor + + +class DGraphSparseTensor(SparseTensor): + def __init__( + self, + row, + col, + value=None, + sparse_sizes=None, + is_sorted=False, + comm=None, + rank_mapping=None, + **kwargs, + ): + super(DGraphSparseTensor, self).__init__( + row=row, + col=col, + value=value, + sparse_sizes=sparse_sizes, + is_sorted=is_sorted, + ) + assert comm is not None, "Comm object cannot be None" + assert rank_mapping is not None, "rank_mapping cannot be None" + self.comm = comm + self.rank_mapping = rank_mapping + self.world_size = comm.get_world_size() + self.rank = comm.get_rank() From b539e79444d786edb5960af93b1e3d520bb36c25 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Thu, 11 Sep 2025 23:34:17 -0700 Subject: [PATCH 05/48] Add RGat implementation --- experiments/OGB-LSC/RGAT.py | 69 ++++++++++++++++++++++++++++++++++++- 1 file changed, 68 insertions(+), 1 deletion(-) diff --git a/experiments/OGB-LSC/RGAT.py b/experiments/OGB-LSC/RGAT.py index 86e2a5e..6f0cc69 100644 --- a/experiments/OGB-LSC/RGAT.py +++ b/experiments/OGB-LSC/RGAT.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import torch.distributed as dist +from .distributed_layers import DistributedBatchNorm1D class ConvLayer(nn.Module): @@ -16,13 +17,16 @@ def forward(self, x): class CommAwareGAT(nn.Module): - def __init__(self, in_channels, out_channels, comm, bias=True, residual=False): + def __init__( + self, in_channels, out_channels, comm, heads=1, bias=True, residual=False + ): super(CommAwareGAT, self).__init__() self.conv1 = nn.Linear(in_channels, out_channels, bias=False) self.comm = comm self.project_message = nn.Linear(2 * out_channels, 1) self.leaky_relu = nn.LeakyReLU(0.2) self.residual = residual + self.heads = heads if self.residual: self.res_net = nn.Linear(in_channels, out_channels, bias=False) if bias: @@ -69,3 +73,66 @@ def forward( out = out + self.bias return out + + +class CommAwareRGAT(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + num_relations, + num_layers, + heads, + comm, + dropout=0.5, + ): + super(CommAwareRGAT, self).__init__() + self.layers = nn.ModuleList() + self.bn_layers = nn.ModuleList() + self.skip_layers = nn.ModuleList() + self.num_layers = num_layers + self.dropout = dropout + self.comm = comm + relation_specific_convs = nn.ModuleList() + for _ in range(num_relations): + relation_specific_convs.append( + CommAwareGAT( + in_channels, + hidden_channels, + heads=heads, + bias=True, + residual=True, + comm=comm, + ) + ) + self.layers.append(relation_specific_convs) + + for _ in range(num_layers - 1): + relation_specific_convs = nn.ModuleList() + for _ in range(num_relations): + relation_specific_convs.append( + CommAwareGAT( + hidden_channels * heads, + hidden_channels, + heads=heads, + bias=True, + residual=True, + comm=comm, + ) + ) + self.layers.append(relation_specific_convs) + + for _ in range(num_layers): + self.bn_layers.append(DistributedBatchNorm1D(hidden_channels)) + self.skip_layers.append(nn.Linear(in_channels, hidden_channels)) + for _ in range(num_layers - 1): + self.skip_layers.append(nn.Linear(hidden_channels, hidden_channels)) + + self.mlp = nn.Sequential( + nn.Linear(hidden_channels, hidden_channels), + DistributedBatchNorm1D(hidden_channels), + nn.ReLU(inplace=True), + nn.Dropout(dropout), + nn.Linear(hidden_channels, out_channels), + ) From 63109a3f8bc9ee9b8ad30a131000d7f77cd3675f Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Sun, 14 Sep 2025 22:18:08 -0700 Subject: [PATCH 06/48] Completed with RGAT implementation - Add datasets for synthetic and MAG240M graphs - Add cache generators for the comm - Fix implementation bug for GAT --- experiments/OGB-LSC/CacheGenerator.py | 84 +++++++++ experiments/OGB-LSC/README.md | 0 experiments/OGB-LSC/RGAT.py | 204 ++++++++++++++++++++-- experiments/OGB-LSC/Trainer.py | 68 ++++++++ experiments/OGB-LSC/config.py | 35 ++++ experiments/OGB-LSC/distributed_layers.py | 24 ++- experiments/OGB-LSC/main.py | 107 ++++++++++++ experiments/OGB-LSC/pyg_wrappers.py | 38 ++-- 8 files changed, 533 insertions(+), 27 deletions(-) create mode 100644 experiments/OGB-LSC/CacheGenerator.py create mode 100644 experiments/OGB-LSC/README.md create mode 100644 experiments/OGB-LSC/Trainer.py create mode 100644 experiments/OGB-LSC/config.py create mode 100644 experiments/OGB-LSC/main.py diff --git a/experiments/OGB-LSC/CacheGenerator.py b/experiments/OGB-LSC/CacheGenerator.py new file mode 100644 index 0000000..56a2604 --- /dev/null +++ b/experiments/OGB-LSC/CacheGenerator.py @@ -0,0 +1,84 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) + +import torch + +import os.path as osp +from DGraph.distributed.nccl._nccl_cache import ( + NCCLGatherCacheGenerator, + NCCLScatterCacheGenerator, +) + + +def get_cache( + src_gather_cache, + dest_gather_cache, + dest_scatter_cache, + src_gather_cache_file, + dest_gather_cache_file, + dest_scatter_cache_file, + rank, + world_size, + src_indices, + dest_indices, + edge_location, + src_data_mappings, + dest_data_mappings, + num_input_rows, + num_output_rows, +): + if src_gather_cache is None: + + _src_gather_cache = NCCLGatherCacheGenerator( + indices=src_indices, + edge_placement=edge_location, + edge_dest_ranks=src_data_mappings, + num_input_rows=num_input_rows, + rank=rank, + world_size=world_size, + ) + + torch.save(_src_gather_cache, src_gather_cache_file) + else: + _src_gather_cache = src_gather_cache + + if dest_scatter_cache is None: + _dest_scatter_cache = NCCLScatterCacheGenerator( + indices=dest_indices, + edge_placement=edge_location, + edge_dest_ranks=dest_data_mappings, + num_output_rows=num_output_rows, + rank=rank, + world_size=world_size, + ) + + torch.save(_dest_scatter_cache, dest_scatter_cache_file) + else: + _dest_scatter_cache = dest_scatter_cache + + if dest_gather_cache is None: + _dest_gather_cache = NCCLGatherCacheGenerator( + indices=dest_indices, + edge_placement=edge_location, + edge_dest_ranks=dest_data_mappings, + num_input_rows=num_input_rows, + rank=rank, + world_size=world_size, + ) + + torch.save(_dest_gather_cache, dest_gather_cache_file) + else: + _dest_gather_cache = dest_gather_cache + + return _src_gather_cache, _dest_scatter_cache, _dest_gather_cache diff --git a/experiments/OGB-LSC/README.md b/experiments/OGB-LSC/README.md new file mode 100644 index 0000000..e69de29 diff --git a/experiments/OGB-LSC/RGAT.py b/experiments/OGB-LSC/RGAT.py index 6f0cc69..034e0d3 100644 --- a/experiments/OGB-LSC/RGAT.py +++ b/experiments/OGB-LSC/RGAT.py @@ -1,7 +1,23 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) + import torch import torch.nn as nn import torch.distributed as dist from .distributed_layers import DistributedBatchNorm1D +import os.path as osp +from .CacheGenerator import get_cache class ConvLayer(nn.Module): @@ -18,7 +34,14 @@ def forward(self, x): class CommAwareGAT(nn.Module): def __init__( - self, in_channels, out_channels, comm, heads=1, bias=True, residual=False + self, + in_channels, + out_channels, + comm, + heads=1, + bias=True, + residual=False, + hetero=False, ): super(CommAwareGAT, self).__init__() self.conv1 = nn.Linear(in_channels, out_channels, bias=False) @@ -27,6 +50,7 @@ def __init__( self.leaky_relu = nn.LeakyReLU(0.2) self.residual = residual self.heads = heads + self.hetero = hetero if self.residual: self.res_net = nn.Linear(in_channels, out_channels, bias=False) if bias: @@ -36,9 +60,22 @@ def __init__( self.register_parameter("bias", None) def forward( - self, x, edge_index, rank_mapping, gather_cache=None, scatter_cache=None + self, + x, + edge_index, + rank_mapping, + x_j=None, + src_gather_cache=None, + dest_gather_cache=None, + dest_scatter_cache=None, ): h = self.conv1(x) + if self.hetero: + assert x_j is not None + h_j = self.conv1(x_j) + else: + h_j = h + _src_indices = edge_index[:, 0, :] _dst_indices = edge_index[:, 1, :] _src_rank_mappings = torch.cat( @@ -47,25 +84,34 @@ def forward( _dst_rank_mappings = torch.cat( [rank_mapping[0].unsqueeze(0), rank_mapping[1].unsqueeze(0)], dim=0 ) - h_i = self.comm.gather(h, _dst_indices, _dst_rank_mappings, cache=gather_cache) - h_j = self.comm.gather(h, _src_indices, _src_rank_mappings, cache=gather_cache) - messages = torch.cat([h_i, h_j], dim=-1) + h_i = self.comm.gather( + h, _dst_indices, _dst_rank_mappings, cache=dest_gather_cache + ) + h_j = self.comm.gather( + h_j, _src_indices, _src_rank_mappings, cache=src_gather_cache + ) + + messages = torch.cat([h_i, h_j], dim=1) edge_scores = self.leaky_relu(self.project_message(messages)).squeeze(-1) numerator = torch.exp(edge_scores) denominator = self.comm.scatter( - numerator, _dst_indices, _dst_rank_mappings, h.size(1), cache=scatter_cache + numerator, + _dst_indices, + _dst_rank_mappings, + h.size(1), + cache=dest_scatter_cache, ) denominator = self.comm.gather( - denominator, _src_indices, _src_rank_mappings, cache=gather_cache + denominator, _src_indices, _src_rank_mappings, cache=dest_gather_cache ) alpha_ij = numerator / (denominator + 1e-16) attention_messages = h_j * alpha_ij.unsqueeze(-1) out = self.comm.scatter( attention_messages, - _src_indices, - _src_rank_mappings, + _dst_indices, + _dst_rank_mappings, h.size(1), - cache=scatter_cache, + cache=dest_scatter_cache, ) if self.residual: out = out + self.res_net(x) @@ -86,6 +132,8 @@ def __init__( heads, comm, dropout=0.5, + use_cache=True, + cache_file_path="rgat_cache", ): super(CommAwareRGAT, self).__init__() self.layers = nn.ModuleList() @@ -94,7 +142,9 @@ def __init__( self.num_layers = num_layers self.dropout = dropout self.comm = comm - relation_specific_convs = nn.ModuleList() + self.use_cache = use_cache + relation_specific_convs = [] + for _ in range(num_relations): relation_specific_convs.append( CommAwareGAT( @@ -104,12 +154,13 @@ def __init__( bias=True, residual=True, comm=comm, + hetero=True, ) ) - self.layers.append(relation_specific_convs) + self.layers.append(nn.ModuleList(relation_specific_convs)) for _ in range(num_layers - 1): - relation_specific_convs = nn.ModuleList() + relation_specific_convs = [] for _ in range(num_relations): relation_specific_convs.append( CommAwareGAT( @@ -119,12 +170,14 @@ def __init__( bias=True, residual=True, comm=comm, + hetero=True, ) ) - self.layers.append(relation_specific_convs) + self.layers.append(nn.ModuleList(relation_specific_convs)) for _ in range(num_layers): self.bn_layers.append(DistributedBatchNorm1D(hidden_channels)) + self.skip_layers.append(nn.Linear(in_channels, hidden_channels)) for _ in range(num_layers - 1): self.skip_layers.append(nn.Linear(hidden_channels, hidden_channels)) @@ -136,3 +189,126 @@ def __init__( nn.Dropout(dropout), nn.Linear(hidden_channels, out_channels), ) + self.num_relations = num_relations + + # Caching for RGAT is a little bit tricky. There are three types of communication + # 1. Source gather (gathering source node features from source ranks) + # 2. Destination gather (gathering destination node features from destination ranks) + # 3. Destination scatter (scattering the messages to destination ranks) + # That gets repeated for each relation type. + # So we will have 3 * num_relations cache files + + self.src_gather_cache_files = [ + ( + f"{cache_file_path}_src_gather_cache_rel_{rel}_rank" + + f"_{comm.get_world_size()}_{comm.get_rank()}.pt" + ) + for rel in range(num_relations) + ] + + self.dest_scatter_cache_files = [ + ( + f"{cache_file_path}_dest_scatter_cache_rel_{rel}_rank" + + f"_{comm.get_world_size()}_{comm.get_rank()}.pt" + ) + for rel in range(num_relations) + ] + self.dest_gather_cache_files = [ + ( + f"{cache_file_path}_dest_gather_cache_rel_{rel}_rank" + + f"_{comm.get_world_size()}_{comm.get_rank()}.pt" + ) + for rel in range(num_relations) + ] + self.src_gather_caches = [] + self.dest_scatter_caches = [] + self.dest_gather_caches = [] + + if self.use_cache: + for caches in zip( + self.src_gather_cache_files, + self.dest_scatter_cache_files, + self.dest_gather_cache_files, + ): + ( + src_gather_cache_file, + dest_scatter_cache_file, + dest_gather_cache_file, + ) = caches + if ( + osp.exists(src_gather_cache_file) + and osp.exists(dest_scatter_cache_file) + and osp.exists(dest_gather_cache_file) + ): + _src_gather_cache = torch.load( + src_gather_cache_file, weights_only=False + ) + _dest_scatter_cache = torch.load( + dest_scatter_cache_file, weights_only=False + ) + _dest_gather_cache = torch.load( + dest_gather_cache_file, weights_only=False + ) + self.src_gather_caches.append(_src_gather_cache) + self.dest_scatter_caches.append(_dest_scatter_cache) + self.dest_gather_caches.append(_dest_gather_cache) + else: + self.src_gather_caches.append(None) + self.dest_scatter_caches.append(None) + self.dest_gather_caches.append(None) + + def forward(self, xs, adjts, edge_types, rank_mappings): + assert len(adjts) == len(edge_types) + assert len(adjts) == self.num_relations + + outs = xs + + for i in range(self.num_layers): + temp_outs = [self.skip_layers[i](outs[feat]) for feat in range(len(outs))] + for j, (edge_index, edge_type, rank_mapping) in enumerate( + zip(adjts, edge_types, rank_mappings) + ): + if self.use_cache: + caches = get_cache( + src_gather_cache=self.src_gather_caches[j], + dest_gather_cache=self.dest_gather_caches[j], + dest_scatter_cache=self.dest_scatter_caches[j], + src_gather_cache_file=self.src_gather_cache_files[j], + dest_scatter_cache_file=self.dest_scatter_cache_files[j], + dest_gather_cache_file=self.dest_gather_cache_files[j], + rank=self.comm.get_rank(), + world_size=self.comm.get_world_size(), + src_indices=edge_index[:, 0, :], + dest_indices=edge_index[:, 1, :], + edge_location=rank_mapping[0], + src_data_mappings=rank_mapping[0], + dest_data_mappings=rank_mapping[1], + num_input_rows=outs[0].size(0), + num_output_rows=outs[1].size(0), + ) + src_gather_cache, dest_scatter_cache, dest_gather_cache = caches + else: + src_gather_cache = None + dest_scatter_cache = None + dest_gather_cache = None + + src_edge_type, dst_edge_type = edge_type + temp_outs[dst_edge_type] += self.layers[i][j]( + temp_outs[dst_edge_type], + edge_index, + rank_mapping, + x_j=temp_outs[src_edge_type], + src_gather_cache=src_gather_cache, + dest_gather_cache=dest_gather_cache, + dest_scatter_cache=dest_scatter_cache, + ) + outs = [ + self.bn_layers[i](temp_outs[feat]) for feat in range(len(temp_outs)) + ] + outs = [torch.relu(outs[feat]) for feat in range(len(outs))] + outs = [ + torch.dropout(outs[feat], p=self.dropout, train=self.training) + for feat in range(len(outs)) + ] + + return self.mlp(outs) diff --git a/experiments/OGB-LSC/Trainer.py b/experiments/OGB-LSC/Trainer.py new file mode 100644 index 0000000..45008f4 --- /dev/null +++ b/experiments/OGB-LSC/Trainer.py @@ -0,0 +1,68 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) +import torch +from RGAT import CommAwareRGAT +from config import ModelConfig, TrainingConfig + + +class Trainer: + def __init__(self, dataset, comm): + self.dataset = dataset + self.comm = comm + self.model_config = ModelConfig() + self.training_config = TrainingConfig() + self.device = torch.device(f"cuda:{comm.get_local_rank()}") + self.model = CommAwareRGAT( + in_channels=dataset.num_features, + out_channels=dataset.num_classes, + hidden_channels=self.model_config.hidden_channels, + num_relations=dataset.num_relations, + num_layers=self.model_config.num_layers, + heads=self.model_config.heads, + comm=comm, + dropout=self.model_config.dropout, + ).to(self.device) + + def train(self): + self.model.train() + for epoch in range(1, self.training_config.epochs + 1): + out = self.model( + self.dataset.x, self.dataset.edge_index, self.dataset.rank_mapping + ) + loss = torch.nn.functional.cross_entropy( + out[self.dataset.train_mask], self.dataset.y[self.dataset.train_mask] + ) + loss.backward() + return loss.item() + + @torch.no_grad() + def evaluate(self): + self.model.eval() + out = self.model( + self.dataset.x, self.dataset.edge_index, self.dataset.rank_mapping + ) + y_true = self.dataset.y.cpu().numpy() + y_pred = out.argmax(dim=-1, keepdim=True).cpu().numpy() + + train_acc = ( + y_pred[self.dataset.train_mask] == y_true[self.dataset.train_mask] + ).sum() / int(self.dataset.train_mask.sum()) + val_acc = ( + y_pred[self.dataset.val_mask] == y_true[self.dataset.val_mask] + ).sum() / int(self.dataset.val_mask.sum()) + test_acc = ( + y_pred[self.dataset.test_mask] == y_true[self.dataset.test_mask] + ).sum() / int(self.dataset.test_mask.sum()) + + return train_acc, val_acc, test_acc diff --git a/experiments/OGB-LSC/config.py b/experiments/OGB-LSC/config.py new file mode 100644 index 0000000..013fcfb --- /dev/null +++ b/experiments/OGB-LSC/config.py @@ -0,0 +1,35 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) + +from dataclasses import dataclass + + +@dataclass +class ModelConfig: + hidden_channels: int = 1024 + dropout: float = 0.5 + num_layers: int = 2 + num_features: int = 768 + num_relations: int = 5 + num_classes: int = 153 + heads: int = 4 + use_cache: bool = True + + +@dataclass +class TrainingConfig: + epochs: int = 100 + lr: float = 0.0001 + lr_step_size: int = 25 + lr_gamma: float = 0.25 diff --git a/experiments/OGB-LSC/distributed_layers.py b/experiments/OGB-LSC/distributed_layers.py index 541fe16..4fd76d4 100644 --- a/experiments/OGB-LSC/distributed_layers.py +++ b/experiments/OGB-LSC/distributed_layers.py @@ -1,3 +1,17 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) + import torch from torch import nn import torch.distributed as dist @@ -164,6 +178,12 @@ def __init__( self.bn: Callable = DistributedBN_Impl.apply def forward(self, x): + if x.dim() == 3: + assert x.size(0) == 1, "only mini-batch size 1 is supported" + x = x.squeeze(0) + elif x.dim() != 2: + raise ValueError("Expected 2D or 3D input (got {}D input)".format(x.dim())) + if self.training: if self.track_running_stats: self.num_batches_tracked += 1 @@ -181,6 +201,6 @@ def forward(self, x): if self.gamma is not None and self.beta is not None: y = y * self.gamma + self.beta + if y.dim() == 2: + y = y.unsqueeze(0) return y - - return self.bn(x) diff --git a/experiments/OGB-LSC/main.py b/experiments/OGB-LSC/main.py new file mode 100644 index 0000000..6c858af --- /dev/null +++ b/experiments/OGB-LSC/main.py @@ -0,0 +1,107 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) +import fire +import torch +from functools import partial +import os.path as osp +import DGraph.Communicator as Comm +from Trainer import Trainer +from config import ModelConfig +import torch.distributed as dist + + +def main( + comm_type: str = "nccl", + dataset: str = "synthetic", + num_papers: int = 2048, + num_authors: int = 512, + num_institutions: int = 16, + paper_rank_mapping_file: str = "", + author_rank_mapping_file: str = "", + institution_rank_mapping_file: str = "", + data_dir: str = "mag240m/data/MAG240M", +): + """Main function to run DGraph experiments on OGB-LSC datasets. + + Args: + comm_type (str): Type of communicator to use. Options are 'nccl' and + 'nvshmem'. Default is 'nccl'. + dataset (str): Dataset to use. Options are 'synthetic' and 'mag240m'. + Default is 'synthetic'. + num_papers (int): Number of paper nodes to use in the synthetic dataset. + Default is 2048. + num_authors (int): Number of author nodes to use in the synthetic dataset. + Default is 512. + num_institutions (int): Number of institution nodes to use in the synthetic + dataset. Default is 16. + paper_rank_mapping_file (str): Path to the paper rank mapping file for + mag240m dataset. Default is ''. + author_rank_mapping_file (str): Path to the author rank mapping file for + mag240m dataset. Default is not set. + institution_rank_mapping_file (str): Path to the institution rank mapping + file for mag240m dataset. Default is not set. + data_dir (str): Path to the mag240m dataset directory. Default is + 'mag240m/data/MAG240M'. + """ + assert dataset in ["synthetic", "mag240m"] + if dataset == "synthetic": + from synthetic.synthetic_dataset import HeterogeneousDataset as Dataset + + graph_dataset = partial( + Dataset, + num_papers=num_papers, + num_authors=num_authors, + num_institutions=num_institutions, + num_features=ModelConfig().num_features, + num_classes=ModelConfig().num_classes, + ) + + elif dataset == "mag240m": + from mag240m.DGraph_MAG240M import DGraph_MAG240M as Dataset + + assert osp.exists(paper_rank_mapping_file) + assert osp.exists(author_rank_mapping_file) + assert osp.exists(institution_rank_mapping_file) + paper_rank_mapping = torch.load(paper_rank_mapping_file, weights_only=False) + author_rank_mapping = torch.load(author_rank_mapping_file, weights_only=False) + institution_rank_mapping = torch.load( + institution_rank_mapping_file, weights_only=False + ) + + graph_dataset = partial( + Dataset, + paper_rank_mappings=paper_rank_mapping, + author_rank_mappings=author_rank_mapping, + institution_rank_mappings=institution_rank_mapping, + data_dir=data_dir, + ) + else: + raise ValueError(f"Invalid dataset: {dataset}") + + assert comm_type in ["nccl", "nvshmem"] + comm = Comm.Communicator.init_process_group(comm_type) + + graph_dataset = graph_dataset(comm=comm) + trainer = Trainer(graph_dataset, comm) + trainer.train() + comm.destroy() + + if dist.is_initialized(): + dist.destroy_process_group() + + return 0 + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/experiments/OGB-LSC/pyg_wrappers.py b/experiments/OGB-LSC/pyg_wrappers.py index 0d76b6e..7f76d57 100644 --- a/experiments/OGB-LSC/pyg_wrappers.py +++ b/experiments/OGB-LSC/pyg_wrappers.py @@ -1,28 +1,44 @@ -from torch_sparse import SparseTensor +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) -class DGraphSparseTensor(SparseTensor): +class DGraphSparseTensor: def __init__( self, row, col, value=None, - sparse_sizes=None, - is_sorted=False, comm=None, rank_mapping=None, **kwargs, ): - super(DGraphSparseTensor, self).__init__( - row=row, - col=col, - value=value, - sparse_sizes=sparse_sizes, - is_sorted=is_sorted, - ) + super(DGraphSparseTensor, self).__init__() assert comm is not None, "Comm object cannot be None" assert rank_mapping is not None, "rank_mapping cannot be None" self.comm = comm self.rank_mapping = rank_mapping self.world_size = comm.get_world_size() self.rank = comm.get_rank() + self.row = row + self.col = col + + def to(self, device): + self.row = self.row.to(device) + self.col = self.col.to(device) + if self.rank_mapping is not None: + self.rank_mapping = self.rank_mapping.to(device) + + if self.value is not None: + self.value = self.value.to(device) + return self From 335a65ec8b60253392c04a0504694f1887701128 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Mon, 15 Sep 2025 08:48:52 -0700 Subject: [PATCH 07/48] Add the synthetic dataset for testing purposes --- .../OGB-LSC/synthetic/synthetic_dataset.py | 249 ++++++++++++++++++ 1 file changed, 249 insertions(+) create mode 100644 experiments/OGB-LSC/synthetic/synthetic_dataset.py diff --git a/experiments/OGB-LSC/synthetic/synthetic_dataset.py b/experiments/OGB-LSC/synthetic/synthetic_dataset.py new file mode 100644 index 0000000..3790054 --- /dev/null +++ b/experiments/OGB-LSC/synthetic/synthetic_dataset.py @@ -0,0 +1,249 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) +from DGraph.Communicator import Communicator +import torch + + +def _generate_paper_2_paper_edges(num_papers): + # Average degree of a paper is ~11 + num_edges = num_papers * 11 + coo_list = torch.randint( + low=0, high=num_papers, size=(2, num_edges), dtype=torch.long + ) + coo_list = torch.unique(coo_list, dim=1) + return coo_list + + +def _generate_paper_2_author_edges(num_papers, num_authors): + # Average number of authors per paper is ~3.5 + num_edges = int(num_papers * 3.5) + dest_papers = torch.randint( + low=0, high=num_papers, size=(1, num_edges), dtype=torch.long + ) + src_authors = torch.randint( + low=0, high=num_authors, size=(1, num_edges), dtype=torch.long + ) + coo_list = torch.cat([src_authors, dest_papers], dim=0) + coo_list = torch.unique(coo_list, dim=1) + return coo_list + + +def _generate_author_2_institution_edges(num_authors, num_institutions): + # Average number of institutions per author is ~0.35 + num_edges = int(num_authors * 0.35) + dest_num_institutions = torch.randint( + low=0, high=num_institutions, size=(1, num_edges), dtype=torch.long + ) + src_authors = torch.randint( + low=0, high=num_authors, size=(1, num_edges), dtype=torch.long + ) + coo_list = torch.cat([src_authors, dest_num_institutions], dim=0) + coo_list = torch.unique(coo_list, dim=1) + return coo_list + + +def _get_rank_mappings(num_vertices, world_size, rank): + vertices_per_rank = num_vertices // world_size + rank_mappings = torch.zeros(num_vertices, dtype=torch.uint8) + vertices_cur_rank = 0 + for r in range(world_size): + start = r * vertices_per_rank + end = (r + 1) * vertices_per_rank if r != world_size - 1 else num_vertices + rank_mappings[start:end] = r + if r == rank: + vertices_cur_rank = end - start + return rank_mappings, vertices_cur_rank + + +def edge_mapping_from_vertex_mapping(edge_index, rank_mappings): + # directed edges, so edge_index[0] -> edge_index[1] + src_indices = edge_index[0] + dest_indices = edge_index[1] + # We put the edge on the rank where the destination vertex is located + edge_placement = rank_mappings[dest_indices] + src_data_mappings = rank_mappings[src_indices] + dest_data_mappings = rank_mappings[dest_indices] + return (edge_placement, src_data_mappings, dest_data_mappings) + + +class HeterogeneousDataset: + def __init__( + self, + num_papers, + num_authors, + num_institutions, + num_features, + num_classes, + comm: Communicator, + ): + self.num_papers = num_papers + self.num_authors = num_authors + self.num_institutions = num_institutions + self.num_classes = num_classes + self.num_features = num_features + self.comm = comm + self.rank = comm.get_rank() + self.world_size = comm.get_world_size() + self.rank = comm.get_rank() + self.paper_vertex_rank_mapping, self.num_paper_vertices = _get_rank_mappings( + num_vertices=num_papers, world_size=self.world_size, rank=self.rank + ) + self.author_vertex_rank_mapping, self.num_author_vertices = _get_rank_mappings( + num_vertices=num_authors, world_size=self.world_size, rank=self.rank + ) + self.institution_vertex_rank_mapping, self.num_institution_vertices = ( + _get_rank_mappings( + num_vertices=num_institutions, + world_size=self.world_size, + rank=self.rank, + ) + ) + _vertices = torch.randperm(num_papers) + self.train_mask = _vertices[: int(0.7 * num_papers)] + self.val_mask = _vertices[int(0.7 * num_papers) : int(0.85 * num_papers)] + self.test_mask = _vertices[int(0.85 * num_papers) :] + self.y = torch.randint( + low=0, high=self.num_classes, size=(num_papers,), dtype=torch.long + ) + + self.paper_2_paper_edges = _generate_paper_2_paper_edges(num_papers) + + ( + paper_2_paper_edge_location, + paper_2_paper_src_data_mappings, + paper_2_paper_dest_data_mappings, + ) = edge_mapping_from_vertex_mapping( + edge_index=self.paper_2_paper_edges, + rank_mappings=self.paper_vertex_rank_mapping, + ) + + self.paper_edge_locations = paper_2_paper_edge_location + self.paper_src_data_mappings = paper_2_paper_src_data_mappings + self.paper_dest_data_mappings = paper_2_paper_dest_data_mappings + + self.paper_2_author_edges = _generate_paper_2_author_edges( + num_papers, num_authors + ) + + ( + paper_2_author_edge_location, + paper_2_author_src_data_mappings, + paper_2_author_dest_data_mappings, + ) = edge_mapping_from_vertex_mapping( + edge_index=self.paper_2_author_edges, + rank_mappings=self.author_vertex_rank_mapping, + ) + self.paper_2_author_edge_locations = paper_2_author_edge_location + self.paper_2_author_src_data_mappings = paper_2_author_src_data_mappings + self.paper_2_author_dest_data_mappings = paper_2_author_dest_data_mappings + + self.author_2_institution_edges = _generate_author_2_institution_edges( + num_authors, num_institutions + ) + + ( + author_2_institution_edge_location, + author_2_institution_src_data_mappings, + author_2_institution_dest_data_mappings, + ) = edge_mapping_from_vertex_mapping( + edge_index=self.author_2_institution_edges, + rank_mappings=self.institution_vertex_rank_mapping, + ) + self.author_2_institution_edge_locations = author_2_institution_edge_location + self.author_2_institution_src_data_mappings = ( + author_2_institution_src_data_mappings + ) + self.author_2_institution_dest_data_mappings = ( + author_2_institution_dest_data_mappings + ) + + paper_vertices_cur_rank = int( + (self.paper_vertex_rank_mapping == self.rank).sum() + ) + author_vertices_cur_rank = int( + (self.author_vertex_rank_mapping == self.rank).sum() + ) + institution_vertices_cur_rank = int( + (self.institution_vertex_rank_mapping == self.rank).sum() + ) + + self.paper_features = torch.randn( + (self.num_papers, paper_vertices_cur_rank), dtype=torch.float32 + ) + self.author_features = torch.randn( + (self.num_authors, author_vertices_cur_rank), dtype=torch.float32 + ) + self.institution_features = torch.randn( + (self.num_institutions, institution_vertices_cur_rank), dtype=torch.float32 + ) + + def get_validation_mask(self): + # Only papers are classified + validation_vertices_mappings = self.paper_vertex_rank_mapping[self.val_mask] + num_validation_vertices = (validation_vertices_mappings == self.rank).sum() + if num_validation_vertices > 0: + return self.val_mask[validation_vertices_mappings == self.rank] + else: + return torch.tensor([], dtype=torch.long) + + def get_test_mask(self): + # Only papers are classified + paper_vertices = self.paper_vertex_rank_mapping == self.rank + num_test_vertices = (paper_vertices[self.test_mask] == self.rank).sum() + if num_test_vertices > 0: + return self.test_mask[paper_vertices[self.test_mask] == self.rank] + else: + return torch.tensor([], dtype=torch.long) + + def __len__(self): + return 0 + + def __getitem__(self, idx): + # There are 5 relations: + # paper -> paper + # paper -> author + # author -> paper + # author -> institution + # institution -> author + edge_index = [ + self.paper_2_paper_edges, + self.paper_2_author_edges, + self.paper_2_author_edges.flip(0), + self.author_2_institution_edges, + self.author_2_institution_edges.flip(0), + ] + # Locations of the edges + rank_mappings = [ + [self.paper_edge_locations, self.paper_dest_data_mappings], + [self.paper_2_author_edge_locations, self.paper_2_author_src_data_mappings], + [ + self.paper_2_author_edge_locations, + self.paper_2_author_dest_data_mappings, + ], + [ + self.author_2_institution_edge_locations, + self.author_2_institution_src_data_mappings, + ], + [ + self.author_2_institution_dest_data_mappings, + self.author_2_institution_src_data_mappings, + ], + ] + edge_type = [(0, 0), (0, 1), (1, 0), (1, 2), (2, 1)] + features = [ + self.paper_features, + self.author_features, + self.institution_features, + ] + return (features, edge_index, edge_type, rank_mappings) From 2d9ff3f6c37f7139d139651f6770f91912deceb7 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Mon, 15 Sep 2025 09:42:14 -0700 Subject: [PATCH 08/48] Add MAG240M dataset --- experiments/OGB-LSC/README.md | 25 ++ experiments/OGB-LSC/mag240m/DGraph_MAG240M.py | 232 ++++++++++++++++++ experiments/OGB-LSC/mag240m/README.md | 29 +++ 3 files changed, 286 insertions(+) create mode 100644 experiments/OGB-LSC/mag240m/DGraph_MAG240M.py create mode 100644 experiments/OGB-LSC/mag240m/README.md diff --git a/experiments/OGB-LSC/README.md b/experiments/OGB-LSC/README.md index e69de29..0d79356 100644 --- a/experiments/OGB-LSC/README.md +++ b/experiments/OGB-LSC/README.md @@ -0,0 +1,25 @@ +# Directed Heterogeneous Graphs on DGraph + +`DGraph` supports arbitrary graph types, GNNs, and structures for distributed training. This example shows how to use `DGraph` to train a Relational Graph Attention Network ([RGAT](https://arxiv.org/abs/1703.06103)) on the [OGB-LSC MAG240M](https://ogb.stanford.edu/docs/lsc/mag240m/) dataset, which is a large-scale heterogeneous graph with three types of nodes (paper, author, institution) and three types of edges (paper->paper, paper->author, author->institution). + +## Requirements + + +## Data preparation +The dataset is fairly large (over 100GB). Please follow the instructions in the `mag240m` folder to download and preprocess the dataset. + +## Training +To train RGAT on a synthetic dataset, run the following command: + +```bash +torchrun-hpc -N -n main.py \ +--dataset synthetic --num_papers \ +--num_authors --num_institutions -n main.py --dataset mag240m \ +--data-path +``` diff --git a/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py b/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py new file mode 100644 index 0000000..5f6ca78 --- /dev/null +++ b/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py @@ -0,0 +1,232 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) +from ogb.lsc import MAG240MDataset +import torch +from typing import Optional +from torch_sparse import SparseTensor +import numpy as np +from tqdm import tqdm +import os.path as osp + + +def get_col_slice(x, start_row_idx, end_row_idx, start_col_idx, end_col_idx): + """Obtained from: + https://github.com/snap-stanford/ogb/blob/master/examples/lsc/mag240m/rgnn.py + """ + outs = [] + chunk = 100000 + for i in tqdm(range(start_row_idx, end_row_idx, chunk)): + j = min(i + chunk, end_row_idx) + outs.append(x[i:j, start_col_idx:end_col_idx].copy()) + return np.concatenate(outs, axis=0) + + +def save_col_slice( + x_src, x_dst, start_row_idx, end_row_idx, start_col_idx, end_col_idx +): + """Obtained from: + https://github.com/snap-stanford/ogb/blob/master/examples/lsc/mag240m/rgnn.py + """ + assert x_src.shape[0] == end_row_idx - start_row_idx + assert x_src.shape[1] == end_col_idx - start_col_idx + chunk, offset = 100000, start_row_idx + for i in tqdm(range(0, end_row_idx - start_row_idx, chunk)): + j = min(i + chunk, end_row_idx - start_row_idx) + x_dst[offset + i : offset + j, start_col_idx:end_col_idx] = x_src[i:j] + + +def get_rank_mappings(num_nodes, world_size, rank): + nodes_per_rank = num_nodes // world_size + print(f"Rank {rank}: nodes_per_rank = {nodes_per_rank}") + # Don't use uint8 if world_size > 256 + # Doing this to save memory + if world_size > 256: + raise ValueError("world_size > 256 not supported yet") + rank_mappings = torch.zeros(num_nodes, dtype=torch.uint8) + for r in range(world_size): + start = r * nodes_per_rank + end = (r + 1) * nodes_per_rank if r != world_size - 1 else num_nodes + rank_mappings[start:end] = r + return rank_mappings + + +def get_edge_mappings(src_indices, dst_indices, rank_mappings): + edge_mappings = torch.zeros_like(src_indices) + # The edges are mapped to the rank of the destination node + # Because that is the accumulation rank + edge_mappings = rank_mappings[dst_indices] + return edge_mappings + + +def _generate_features_from_paper_features( + out: np.memmap, + num_nodes: int, + num_papers: int, + paper_feat: np.ndarray, + edge_index: np.ndarray, + num_features: int, +): + + row, col = torch.from_numpy(edge_index) + adj = SparseTensor( + row=row, col=col, sparse_sizes=(num_nodes, num_papers), is_sorted=True + ) + + dim_chunk_size = 64 + + for i in tqdm(range(0, num_features, dim_chunk_size)): + j = min(i + dim_chunk_size, num_features) + inputs = get_col_slice( + paper_feat, + start_row_idx=0, + end_row_idx=num_papers, + start_col_idx=i, + end_col_idx=j, + ) + inputs = torch.from_numpy(inputs) + out_ = adj.matmul(inputs, reduce="mean").numpy() # type: ignore + del inputs + save_col_slice( + x_src=out_, + x_dst=out, + start_row_idx=0, + end_row_idx=num_nodes, + start_col_idx=i, + end_col_idx=j, + ) + del out_ + out.flush() + + +class DGraph_MAG240M: + def __init__( + self, + comm, + data_dir: str = "data/MAG240M", + paper_rank_mappings: Optional[torch.Tensor] = None, + author_rank_mappings: Optional[torch.Tensor] = None, + institution_rank_mappings: Optional[torch.Tensor] = None, + ): + self.rank = comm.get_rank() + self.world_size = comm.get_world_size() + self.comm = comm + self.dataset = MAG240MDataset(root=data_dir) + self.num_papers = self.dataset.num_papers + self.num_authors = self.dataset.num_authors + self.num_institutions = self.dataset.num_institutions + self.num_classes = self.dataset.num_classes + self.paper_rank_mappings = ( + paper_rank_mappings + if paper_rank_mappings is not None + else get_rank_mappings(self.num_papers, self.world_size, self.rank) + ) + self.author_rank_mappings = ( + author_rank_mappings + if author_rank_mappings is not None + else get_rank_mappings(self.num_authors, self.world_size, self.rank) + ) + self.institution_rank_mappings = ( + institution_rank_mappings + if institution_rank_mappings is not None + else get_rank_mappings(self.num_institutions, self.world_size, self.rank) + ) + + # authors -> paper + self.write_mappings = get_edge_mappings( + self.dataset.edge_index("author", "paper")[0], + self.dataset.edge_index("author", "paper")[1], + self.paper_rank_mappings, + ) + + # author -> institution + self.write_mappings_author_institution = get_edge_mappings( + self.dataset.edge_index("author", "institution")[0], + self.dataset.edge_index("author", "institution")[1], + self.institution_rank_mappings, + ) + self.num_features = 768 + # paper -> paper + self.process_feature_data() + + def process_feature_data(self): + dataset = self.dataset + # This function emulates the data processing step here: + # https://github.com/snap-stanford/ogb/blob/61e9784ca76edeaa6e259ba0f836099608ff0586/examples/lsc/mag240m/rgnn.py#L82 + + # The above function converts the heterogenous graph to a homogeneous graph + # So we will do the same here + + # Generate author features + # Mag240M author features are generated from paper features + num_authors = dataset.num_authors + num_papers = dataset.num_papers + path = dataset.dir + paper_feat = dataset.paper_feat + + if not osp.exists(path + "/author_feat.npy"): + print("Generating author features") + author_feat = np.memmap( + filename=path + "/author_feat.npy", + mode="w+", + dtype=np.float16, + shape=(num_authors, self.num_features), + ) + + _generate_features_from_paper_features( + out=author_feat, + num_nodes=num_authors, + num_papers=num_papers, + paper_feat=paper_feat, + edge_index=dataset.edge_index("author", "paper"), + num_features=self.num_features, + ) + + if not osp.exists(path + "/institution_feat.npy"): + print("Generating institution features") + # Generate institution features + num_institutions = dataset.num_institutions + institution_feat = np.memmap( + filename=path + "/institution_feat.npy", + mode="w+", + dtype=np.float16, + shape=(num_institutions, self.num_features), + ) + print("Generating institution features") + _generate_features_from_paper_features( + out=institution_feat, + num_nodes=num_institutions, + num_papers=num_papers, + paper_feat=paper_feat, + edge_index=dataset.edge_index("author", "institution"), + num_features=self.num_features, + ) + print("Data processing complete") + + +if __name__ == "__main__": + import fire + + def main(data_dir: str = "data/MAG240M"): + rank = 0 + world_size = 64 + # Python is so weird haha + COMM = type( + "dummy_comm", + (object,), + {"get_rank": lambda self: rank, "get_world_size": lambda self: world_size}, + ) + comm = COMM() + dgraph = DGraph_MAG240M(comm, data_dir=data_dir) + + fire.Fire(main) diff --git a/experiments/OGB-LSC/mag240m/README.md b/experiments/OGB-LSC/mag240m/README.md new file mode 100644 index 0000000..5be12e9 --- /dev/null +++ b/experiments/OGB-LSC/mag240m/README.md @@ -0,0 +1,29 @@ +# Processing OGB-LSC MAG240M Dataset + +This directory contains the code to preprocess and load the OGB-LSC MAG240M dataset to use with DGraph. + +## Prerequisites + +Make sure you have the following packages installed: +- `torch` +- `torch_geometric` +- `ogb` +- `torch_sparse` +- `numpy` +- `tqdm` +- `fire` + +## Preprocessing the dataset +The MAG240M dataset is a fairly large graph dataset and requires some preprocessing before it can be used with DGraph, and takes a while to process. The following script processes the dataset and saves the processed data in a directory. + +```bash +python DGraph_MAG240M.py --data_dir +``` + +Make sure to replace `` with the path where you want to store the processed data. The script will download the dataset if it is not already present in the specified directory. The processed data will be saved in the same directory. + +The processing machine requires at least `128GB` of RAM to process the dataset. + + + + From c2530ad54e6bf479e2a34724548de4d0e75f2697 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Mon, 15 Sep 2025 15:01:05 -0700 Subject: [PATCH 09/48] Bug fixes to get things running correctly --- experiments/OGB-LSC/CacheGenerator.py | 65 ++++++++++++ experiments/OGB-LSC/RGAT.py | 11 ++- experiments/OGB-LSC/Trainer.py | 22 +++-- experiments/OGB-LSC/config.py | 9 ++ experiments/OGB-LSC/main.py | 2 + .../OGB-LSC/synthetic/synthetic_dataset.py | 99 +++++++++++++++++-- 6 files changed, 188 insertions(+), 20 deletions(-) diff --git a/experiments/OGB-LSC/CacheGenerator.py b/experiments/OGB-LSC/CacheGenerator.py index 56a2604..b3caa83 100644 --- a/experiments/OGB-LSC/CacheGenerator.py +++ b/experiments/OGB-LSC/CacheGenerator.py @@ -82,3 +82,68 @@ def get_cache( _dest_gather_cache = dest_gather_cache return _src_gather_cache, _dest_scatter_cache, _dest_gather_cache + + +if __name__ == "__main__": + from fire import Fire + from functools import partial + from config import SyntheticDatasetConfig + + # Use this script to generate the caches prior to running the main training script + # This is useful because cache generation can take a long time and could cause issues + # with timeouts on some systems. + + def main(dataset): + assert dataset in ["synthetic", "mag240m"] + if dataset == "synthetic": + from synthetic.synthetic_dataset import HeterogeneousDataset as Dataset + + synthetic_config = SyntheticDatasetConfig() + graph_dataset = partial( + Dataset, + num_papers=synthetic_config.num_papers, + num_authors=synthetic_config.num_authors, + num_institutions=synthetic_config.num_institutions, + num_features=synthetic_config.num_features, + num_classes=synthetic_config.num_classes, + ) + elif dataset == "mag240m": + from mag240m.DGraph_MAG240M import DGraph_MAG240M as Dataset + + graph_dataset = partial(Dataset, data_dir="data/MAG240M") + + rank = 0 + world_size = 16 + COMM = type( + "dummy_comm", + (object,), + {"get_rank": lambda self: rank, "get_world_size": lambda self: world_size}, + ) + comm = COMM() + + dataset = graph_dataset( + comm=comm, + ) + + dataset = dataset.add_batch_dimension() + dataset = dataset.to("cpu") + xs, edge_index, edge_type, rank_mapping = dataset[0] + print("Dataset loaded") + + breakpoint() + # get_cache( + # src_gather_cache=None, + # dest_gather_cache=None, + # dest_scatter_cache=None, + # src_gather_cache_file="paper_src_gather_cache.pt", + # dest_gather_cache_file="paper_dest_gather_cache.pt", + # dest_scatter_cache_file="paper_dest_scatter_cache.pt", + # rank=rank, + # world_size=world_size, + # src_indices=edge_index[0][0][0], + # dest_indices=edge_index[0][0][1], + # edge_location=rank_mapping[0], + # src_data_mappings=rank_mapping[2], + # dest_data_mappings=rank_mapping[3],) + + Fire(main) diff --git a/experiments/OGB-LSC/RGAT.py b/experiments/OGB-LSC/RGAT.py index 034e0d3..7b384b7 100644 --- a/experiments/OGB-LSC/RGAT.py +++ b/experiments/OGB-LSC/RGAT.py @@ -15,9 +15,9 @@ import torch import torch.nn as nn import torch.distributed as dist -from .distributed_layers import DistributedBatchNorm1D +from distributed_layers import DistributedBatchNorm1D import os.path as osp -from .CacheGenerator import get_cache +from CacheGenerator import get_cache class ConvLayer(nn.Module): @@ -190,7 +190,11 @@ def __init__( nn.Linear(hidden_channels, out_channels), ) self.num_relations = num_relations + self._setup_caches(cache_file_path) + def _setup_caches(self, cache_file_path): + num_relations = self.num_relations + comm = self.comm # Caching for RGAT is a little bit tricky. There are three types of communication # 1. Source gather (gathering source node features from source ranks) # 2. Destination gather (gathering destination node features from destination ranks) @@ -268,6 +272,7 @@ def forward(self, xs, adjts, edge_types, rank_mappings): for j, (edge_index, edge_type, rank_mapping) in enumerate( zip(adjts, edge_types, rank_mappings) ): + if self.use_cache: caches = get_cache( src_gather_cache=self.src_gather_caches[j], @@ -311,4 +316,4 @@ def forward(self, xs, adjts, edge_types, rank_mappings): for feat in range(len(outs)) ] - return self.mlp(outs) + return self.mlp(outs[0]) diff --git a/experiments/OGB-LSC/Trainer.py b/experiments/OGB-LSC/Trainer.py index 45008f4..13c15b1 100644 --- a/experiments/OGB-LSC/Trainer.py +++ b/experiments/OGB-LSC/Trainer.py @@ -22,24 +22,32 @@ def __init__(self, dataset, comm): self.comm = comm self.model_config = ModelConfig() self.training_config = TrainingConfig() - self.device = torch.device(f"cuda:{comm.get_local_rank()}") + # TODO: We need some better way to set the device but + # difficult to do that since systems have different bindings. + # self.device = torch.device(f"cuda:{comm.get_local_rank()}") + self.device = torch.device("cuda") self.model = CommAwareRGAT( - in_channels=dataset.num_features, - out_channels=dataset.num_classes, + in_channels=self.model_config.num_features, + out_channels=self.model_config.num_classes, hidden_channels=self.model_config.hidden_channels, - num_relations=dataset.num_relations, + num_relations=self.model_config.num_relations, num_layers=self.model_config.num_layers, heads=self.model_config.heads, comm=comm, dropout=self.model_config.dropout, ).to(self.device) + def prepare_data(self): + self.dataset = self.dataset.add_batch_dimension() + self.dataset = self.dataset.to(self.device) + def train(self): self.model.train() + + xs, edge_index, edge_type, rank_mapping = self.dataset[0] + for epoch in range(1, self.training_config.epochs + 1): - out = self.model( - self.dataset.x, self.dataset.edge_index, self.dataset.rank_mapping - ) + out = self.model(xs, edge_index, edge_type, rank_mapping) loss = torch.nn.functional.cross_entropy( out[self.dataset.train_mask], self.dataset.y[self.dataset.train_mask] ) diff --git a/experiments/OGB-LSC/config.py b/experiments/OGB-LSC/config.py index 013fcfb..0c0a3f5 100644 --- a/experiments/OGB-LSC/config.py +++ b/experiments/OGB-LSC/config.py @@ -33,3 +33,12 @@ class TrainingConfig: lr: float = 0.0001 lr_step_size: int = 25 lr_gamma: float = 0.25 + + +@dataclass +class SyntheticDatasetConfig: + num_papers: int = 2048 + num_authors: int = 512 + num_institutions: int = 16 + num_features: int = 768 + num_classes: int = 153 diff --git a/experiments/OGB-LSC/main.py b/experiments/OGB-LSC/main.py index 6c858af..bc32cb6 100644 --- a/experiments/OGB-LSC/main.py +++ b/experiments/OGB-LSC/main.py @@ -93,7 +93,9 @@ def main( comm = Comm.Communicator.init_process_group(comm_type) graph_dataset = graph_dataset(comm=comm) + trainer = Trainer(graph_dataset, comm) + trainer.prepare_data() trainer.train() comm.destroy() diff --git a/experiments/OGB-LSC/synthetic/synthetic_dataset.py b/experiments/OGB-LSC/synthetic/synthetic_dataset.py index 3790054..59ab66d 100644 --- a/experiments/OGB-LSC/synthetic/synthetic_dataset.py +++ b/experiments/OGB-LSC/synthetic/synthetic_dataset.py @@ -66,14 +66,16 @@ def _get_rank_mappings(num_vertices, world_size, rank): return rank_mappings, vertices_cur_rank -def edge_mapping_from_vertex_mapping(edge_index, rank_mappings): +def edge_mapping_from_vertex_mapping(edge_index, src_rank_mappings, dst_rank_mappings): # directed edges, so edge_index[0] -> edge_index[1] src_indices = edge_index[0] dest_indices = edge_index[1] # We put the edge on the rank where the destination vertex is located - edge_placement = rank_mappings[dest_indices] - src_data_mappings = rank_mappings[src_indices] - dest_data_mappings = rank_mappings[dest_indices] + # Since heterogeneous graphs have different rank mappings for different + # vertex types. + edge_placement = dst_rank_mappings[dest_indices] + src_data_mappings = src_rank_mappings[src_indices] + dest_data_mappings = dst_rank_mappings[dest_indices] return (edge_placement, src_data_mappings, dest_data_mappings) @@ -92,6 +94,7 @@ def __init__( self.num_institutions = num_institutions self.num_classes = num_classes self.num_features = num_features + self.num_relations = 5 self.comm = comm self.rank = comm.get_rank() self.world_size = comm.get_world_size() @@ -125,7 +128,8 @@ def __init__( paper_2_paper_dest_data_mappings, ) = edge_mapping_from_vertex_mapping( edge_index=self.paper_2_paper_edges, - rank_mappings=self.paper_vertex_rank_mapping, + src_rank_mappings=self.paper_vertex_rank_mapping, + dst_rank_mappings=self.paper_vertex_rank_mapping, ) self.paper_edge_locations = paper_2_paper_edge_location @@ -142,7 +146,8 @@ def __init__( paper_2_author_dest_data_mappings, ) = edge_mapping_from_vertex_mapping( edge_index=self.paper_2_author_edges, - rank_mappings=self.author_vertex_rank_mapping, + src_rank_mappings=self.author_vertex_rank_mapping, + dst_rank_mappings=self.paper_vertex_rank_mapping, ) self.paper_2_author_edge_locations = paper_2_author_edge_location self.paper_2_author_src_data_mappings = paper_2_author_src_data_mappings @@ -158,7 +163,8 @@ def __init__( author_2_institution_dest_data_mappings, ) = edge_mapping_from_vertex_mapping( edge_index=self.author_2_institution_edges, - rank_mappings=self.institution_vertex_rank_mapping, + src_rank_mappings=self.author_vertex_rank_mapping, + dst_rank_mappings=self.institution_vertex_rank_mapping, ) self.author_2_institution_edge_locations = author_2_institution_edge_location self.author_2_institution_src_data_mappings = ( @@ -179,13 +185,13 @@ def __init__( ) self.paper_features = torch.randn( - (self.num_papers, paper_vertices_cur_rank), dtype=torch.float32 + (paper_vertices_cur_rank, num_features), dtype=torch.float32 ) self.author_features = torch.randn( - (self.num_authors, author_vertices_cur_rank), dtype=torch.float32 + (author_vertices_cur_rank, num_features), dtype=torch.float32 ) self.institution_features = torch.randn( - (self.num_institutions, institution_vertices_cur_rank), dtype=torch.float32 + (institution_vertices_cur_rank, num_features), dtype=torch.float32 ) def get_validation_mask(self): @@ -209,6 +215,58 @@ def get_test_mask(self): def __len__(self): return 0 + def add_batch_dimension(self): + """Add a batch dimension to all tensors. This is particularly useful + because we only have one graph and DGraph is built to handle batches of graphs. + We want to do this here because this allows us to avoid copying the data + and requiring a data loader. + """ + self.paper_features = self.paper_features.unsqueeze(0) + self.author_features = self.author_features.unsqueeze(0) + self.institution_features = self.institution_features.unsqueeze(0) + self.y = self.y.unsqueeze(0) + self.train_mask = self.train_mask.unsqueeze(0) + self.val_mask = self.val_mask.unsqueeze(0) + self.test_mask = self.test_mask.unsqueeze(0) + self.paper_2_paper_edges = self.paper_2_paper_edges.unsqueeze(0) + self.paper_2_author_edges = self.paper_2_author_edges.unsqueeze(0) + self.author_2_institution_edges = self.author_2_institution_edges.unsqueeze(0) + self.paper_edge_locations = self.paper_edge_locations.unsqueeze(0) + self.paper_src_data_mappings = self.paper_src_data_mappings.unsqueeze(0) + self.paper_dest_data_mappings = self.paper_dest_data_mappings.unsqueeze(0) + + self.paper_2_author_src_data_mappings = ( + self.paper_2_author_src_data_mappings.unsqueeze(0) + ) + self.paper_2_author_dest_data_mappings = ( + self.paper_2_author_dest_data_mappings.unsqueeze(0) + ) + self.author_2_institution_src_data_mappings = ( + self.author_2_institution_src_data_mappings.unsqueeze(0) + ) + self.author_2_institution_dest_data_mappings = ( + self.author_2_institution_dest_data_mappings.unsqueeze(0) + ) + return self + + def to(self, device): + """Move the dataset tensors to the specified device. + We want to do this here because this allows us to avoid + copying the data when the different individual tensors are + accessed. + """ + self.paper_features = self.paper_features.to(device) + self.author_features = self.author_features.to(device) + self.institution_features = self.institution_features.to(device) + self.y = self.y.to(device) + self.train_mask = self.train_mask.to(device) + self.val_mask = self.val_mask.to(device) + self.test_mask = self.test_mask.to(device) + self.paper_2_paper_edges = self.paper_2_paper_edges.to(device) + self.paper_2_author_edges = self.paper_2_author_edges.to(device) + self.author_2_institution_edges = self.author_2_institution_edges.to(device) + return self + def __getitem__(self, idx): # There are 5 relations: # paper -> paper @@ -247,3 +305,24 @@ def __getitem__(self, idx): self.institution_features, ] return (features, edge_index, edge_type, rank_mappings) + + +if __name__ == "__main__": + rank = 0 + world_size = 16 + COMM = type( + "dummy_comm", + (object,), + {"get_rank": lambda self: rank, "get_world_size": lambda self: world_size}, + ) + comm = COMM() + + dataset = HeterogeneousDataset( + num_papers=512, + num_authors=128, + num_institutions=32, + num_features=16, + num_classes=4, + comm=comm, + ) + print(dataset[0]) From 88c04104003358548c75b980b65bf7fa90152454 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Tue, 16 Sep 2025 10:43:29 -0700 Subject: [PATCH 10/48] Update fix for data type issues and cache generator, but experiencing hangs --- experiments/OGB-LSC/CacheGenerator.py | 46 +++++++++++-------- experiments/OGB-LSC/RGAT.py | 8 ++-- experiments/OGB-LSC/mag240m/DGraph_MAG240M.py | 6 +++ .../OGB-LSC/synthetic/synthetic_dataset.py | 17 +------ 4 files changed, 39 insertions(+), 38 deletions(-) diff --git a/experiments/OGB-LSC/CacheGenerator.py b/experiments/OGB-LSC/CacheGenerator.py index b3caa83..14a91fc 100644 --- a/experiments/OGB-LSC/CacheGenerator.py +++ b/experiments/OGB-LSC/CacheGenerator.py @@ -127,23 +127,33 @@ def main(dataset): dataset = dataset.add_batch_dimension() dataset = dataset.to("cpu") - xs, edge_index, edge_type, rank_mapping = dataset[0] - print("Dataset loaded") - - breakpoint() - # get_cache( - # src_gather_cache=None, - # dest_gather_cache=None, - # dest_scatter_cache=None, - # src_gather_cache_file="paper_src_gather_cache.pt", - # dest_gather_cache_file="paper_dest_gather_cache.pt", - # dest_scatter_cache_file="paper_dest_scatter_cache.pt", - # rank=rank, - # world_size=world_size, - # src_indices=edge_index[0][0][0], - # dest_indices=edge_index[0][0][1], - # edge_location=rank_mapping[0], - # src_data_mappings=rank_mapping[2], - # dest_data_mappings=rank_mapping[3],) + + xs, edge_indices, edge_types, rank_mappings = dataset[0] + + for edge_index, edge_type, rank_mapping in zip( + edge_indices, edge_types, rank_mappings + ): + print(f"Edge index shape: {edge_index.shape}") + print(f"Edge type shape: {edge_type}") + print(f"Rank mapping shape: {rank_mapping[0].shape}") + print(f"Rank mapping shape: {rank_mapping[1].shape}") + + get_cache( + src_gather_cache=None, + dest_gather_cache=None, + dest_scatter_cache=None, + src_gather_cache_file="paper_src_gather_cache.pt", + dest_gather_cache_file="paper_dest_gather_cache.pt", + dest_scatter_cache_file="paper_dest_scatter_cache.pt", + rank=rank, + world_size=world_size, + src_indices=edge_index[:, 0], + dest_indices=edge_index[:, 1], + edge_location=rank_mapping[0], + src_data_mappings=rank_mapping[0], + dest_data_mappings=rank_mapping[1], + num_input_rows=xs[edge_type[0]].shape[0], + num_output_rows=xs[edge_type[1]].shape[0], + ) Fire(main) diff --git a/experiments/OGB-LSC/RGAT.py b/experiments/OGB-LSC/RGAT.py index 7b384b7..49e0109 100644 --- a/experiments/OGB-LSC/RGAT.py +++ b/experiments/OGB-LSC/RGAT.py @@ -288,8 +288,8 @@ def forward(self, xs, adjts, edge_types, rank_mappings): edge_location=rank_mapping[0], src_data_mappings=rank_mapping[0], dest_data_mappings=rank_mapping[1], - num_input_rows=outs[0].size(0), - num_output_rows=outs[1].size(0), + num_input_rows=outs[edge_type[0]].size(0), + num_output_rows=outs[edge_type[1]].size(0), ) src_gather_cache, dest_scatter_cache, dest_gather_cache = caches else: @@ -299,10 +299,10 @@ def forward(self, xs, adjts, edge_types, rank_mappings): src_edge_type, dst_edge_type = edge_type temp_outs[dst_edge_type] += self.layers[i][j]( - temp_outs[dst_edge_type], + outs[dst_edge_type], edge_index, rank_mapping, - x_j=temp_outs[src_edge_type], + x_j=outs[src_edge_type], src_gather_cache=src_gather_cache, dest_gather_cache=dest_gather_cache, dest_scatter_cache=dest_scatter_cache, diff --git a/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py b/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py index 5f6ca78..35b4498 100644 --- a/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py +++ b/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py @@ -213,6 +213,12 @@ def process_feature_data(self): ) print("Data processing complete") + def add_batch_dimension(self): + return self + + def to(self, device): + return self + if __name__ == "__main__": import fire diff --git a/experiments/OGB-LSC/synthetic/synthetic_dataset.py b/experiments/OGB-LSC/synthetic/synthetic_dataset.py index 59ab66d..a3a95e2 100644 --- a/experiments/OGB-LSC/synthetic/synthetic_dataset.py +++ b/experiments/OGB-LSC/synthetic/synthetic_dataset.py @@ -55,7 +55,7 @@ def _generate_author_2_institution_edges(num_authors, num_institutions): def _get_rank_mappings(num_vertices, world_size, rank): vertices_per_rank = num_vertices // world_size - rank_mappings = torch.zeros(num_vertices, dtype=torch.uint8) + rank_mappings = torch.zeros(num_vertices, dtype=torch.long) vertices_cur_rank = 0 for r in range(world_size): start = r * vertices_per_rank @@ -231,22 +231,7 @@ def add_batch_dimension(self): self.paper_2_paper_edges = self.paper_2_paper_edges.unsqueeze(0) self.paper_2_author_edges = self.paper_2_author_edges.unsqueeze(0) self.author_2_institution_edges = self.author_2_institution_edges.unsqueeze(0) - self.paper_edge_locations = self.paper_edge_locations.unsqueeze(0) - self.paper_src_data_mappings = self.paper_src_data_mappings.unsqueeze(0) - self.paper_dest_data_mappings = self.paper_dest_data_mappings.unsqueeze(0) - self.paper_2_author_src_data_mappings = ( - self.paper_2_author_src_data_mappings.unsqueeze(0) - ) - self.paper_2_author_dest_data_mappings = ( - self.paper_2_author_dest_data_mappings.unsqueeze(0) - ) - self.author_2_institution_src_data_mappings = ( - self.author_2_institution_src_data_mappings.unsqueeze(0) - ) - self.author_2_institution_dest_data_mappings = ( - self.author_2_institution_dest_data_mappings.unsqueeze(0) - ) return self def to(self, device): From d44fc6e0260e5f16b15509d1d8a6535b3b9a5b45 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Tue, 16 Sep 2025 15:15:58 -0700 Subject: [PATCH 11/48] paper2paper layer running. Isolated hang to author2paper --- experiments/Benchmarks/README.md | 3 ++- experiments/OGB-LSC/RGAT.py | 23 ++++++++++++++++--- experiments/OGB-LSC/main.py | 3 +++ .../OGB-LSC/synthetic/synthetic_dataset.py | 2 ++ 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/experiments/Benchmarks/README.md b/experiments/Benchmarks/README.md index cabbb68..e60f14c 100644 --- a/experiments/Benchmarks/README.md +++ b/experiments/Benchmarks/README.md @@ -34,7 +34,8 @@ class ScatterGraphData: data_rank_mapping: torch.Tensor # Where each data is located edge_rank_placement: torch.Tensor # Where each edge is located edge_dst_rank: torch.Tensor # Rank of the destination vertex of each edge - edge_indices: torch.Tensor # Vertex index of the destination vertex of each num_local_vertices: int # Number of vertices on each rank + edge_indices: torch.Tensor # Vertex index of the destination vertex of each edge + num_local_vertices: int # Number of vertices on each rank ``` *** New communication patterns can be added to the benchmarking code by creating new instances of these dataclasses. *** diff --git a/experiments/OGB-LSC/RGAT.py b/experiments/OGB-LSC/RGAT.py index 49e0109..133b397 100644 --- a/experiments/OGB-LSC/RGAT.py +++ b/experiments/OGB-LSC/RGAT.py @@ -91,9 +91,10 @@ def forward( h_j, _src_indices, _src_rank_mappings, cache=src_gather_cache ) - messages = torch.cat([h_i, h_j], dim=1) - edge_scores = self.leaky_relu(self.project_message(messages)).squeeze(-1) + messages = torch.cat([h_i, h_j], dim=-1) + edge_scores = self.leaky_relu(self.project_message(messages)) numerator = torch.exp(edge_scores) + denominator = self.comm.scatter( numerator, _dst_indices, @@ -105,7 +106,7 @@ def forward( denominator, _src_indices, _src_rank_mappings, cache=dest_gather_cache ) alpha_ij = numerator / (denominator + 1e-16) - attention_messages = h_j * alpha_ij.unsqueeze(-1) + attention_messages = h_j * alpha_ij out = self.comm.scatter( attention_messages, _dst_indices, @@ -298,6 +299,18 @@ def forward(self, xs, adjts, edge_types, rank_mappings): dest_gather_cache = None src_edge_type, dst_edge_type = edge_type + self.comm.barrier() + if self.comm.get_rank() == 0: + print( + f"Layer {i} Relation {j} started on rank {self.comm.get_rank()}" + ) + print( + f"Edge index shape: {edge_index.shape}" + f" Edge type: {edge_type}", + f" src tensor shape: {outs[src_edge_type].shape}", + f" dst tensor shape: {outs[dst_edge_type].shape}", + ) + self.comm.barrier() temp_outs[dst_edge_type] += self.layers[i][j]( outs[dst_edge_type], edge_index, @@ -307,6 +320,10 @@ def forward(self, xs, adjts, edge_types, rank_mappings): dest_gather_cache=dest_gather_cache, dest_scatter_cache=dest_scatter_cache, ) + self.comm.barrier() + if self.comm.get_rank() == 0: + print(f"Layer {i} Relation {j} done on rank {self.comm.get_rank()}") + self.comm.barrier() outs = [ self.bn_layers[i](temp_outs[feat]) for feat in range(len(temp_outs)) ] diff --git a/experiments/OGB-LSC/main.py b/experiments/OGB-LSC/main.py index bc32cb6..ce4a071 100644 --- a/experiments/OGB-LSC/main.py +++ b/experiments/OGB-LSC/main.py @@ -92,6 +92,9 @@ def main( assert comm_type in ["nccl", "nvshmem"] comm = Comm.Communicator.init_process_group(comm_type) + comm.barrier() + print(f"Running with {comm.get_world_size()} ranks. Rank: {comm.get_rank()}") + graph_dataset = graph_dataset(comm=comm) trainer = Trainer(graph_dataset, comm) diff --git a/experiments/OGB-LSC/synthetic/synthetic_dataset.py b/experiments/OGB-LSC/synthetic/synthetic_dataset.py index a3a95e2..a2664c9 100644 --- a/experiments/OGB-LSC/synthetic/synthetic_dataset.py +++ b/experiments/OGB-LSC/synthetic/synthetic_dataset.py @@ -14,6 +14,8 @@ from DGraph.Communicator import Communicator import torch +torch.random.manual_seed(0) + def _generate_paper_2_paper_edges(num_papers): # Average degree of a paper is ~11 From e3f091e1e9801f598ef20ddfb475efaab258f3fe Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Wed, 17 Sep 2025 10:16:47 -0700 Subject: [PATCH 12/48] Updating the synthetic dataset to track down hang on directed relation edge between papers and authors --- DGraph/distributed/nccl/_nccl_cache.py | 3 ++ experiments/OGB-LSC/CacheGenerator.py | 36 +++++++++++++++---- experiments/OGB-LSC/RGAT.py | 26 ++++++++++++++ .../OGB-LSC/synthetic/synthetic_dataset.py | 20 +++++++---- 4 files changed, 71 insertions(+), 14 deletions(-) diff --git a/DGraph/distributed/nccl/_nccl_cache.py b/DGraph/distributed/nccl/_nccl_cache.py index 28a2d01..8d88d54 100644 --- a/DGraph/distributed/nccl/_nccl_cache.py +++ b/DGraph/distributed/nccl/_nccl_cache.py @@ -201,6 +201,7 @@ def NCCLScatterCacheGenerator( indices, edge_placement, remote_recv_mask, num_output_rows, rank, world_size ) + breakpoint() # Information for the backward pass # It's a gather operation so quite a bit simpler @@ -216,6 +217,8 @@ def NCCLScatterCacheGenerator( ) ) + breakpoint() + _cache = NCCLScatterCache( scatter_recv_local_placement=recv_placement, scatter_local_comm_mask=local_send_mask, diff --git a/experiments/OGB-LSC/CacheGenerator.py b/experiments/OGB-LSC/CacheGenerator.py index 14a91fc..cbf5696 100644 --- a/experiments/OGB-LSC/CacheGenerator.py +++ b/experiments/OGB-LSC/CacheGenerator.py @@ -38,6 +38,7 @@ def get_cache( num_input_rows, num_output_rows, ): + breakpoint() if src_gather_cache is None: _src_gather_cache = NCCLGatherCacheGenerator( @@ -101,9 +102,12 @@ def main(dataset): synthetic_config = SyntheticDatasetConfig() graph_dataset = partial( Dataset, - num_papers=synthetic_config.num_papers, - num_authors=synthetic_config.num_authors, - num_institutions=synthetic_config.num_institutions, + # num_papers=synthetic_config.num_papers, + num_papers=100, + # num_authors=synthetic_config.num_authors, + num_authors=32, + # num_institutions=synthetic_config.num_institutions, + num_institutions=16, num_features=synthetic_config.num_features, num_classes=synthetic_config.num_classes, ) @@ -130,6 +134,9 @@ def main(dataset): xs, edge_indices, edge_types, rank_mappings = dataset[0] + # for simulated_rank in range(world_size): + simulated_rank = 0 + rel = 1 for edge_index, edge_type, rank_mapping in zip( edge_indices, edge_types, rank_mappings ): @@ -142,10 +149,10 @@ def main(dataset): src_gather_cache=None, dest_gather_cache=None, dest_scatter_cache=None, - src_gather_cache_file="paper_src_gather_cache.pt", - dest_gather_cache_file="paper_dest_gather_cache.pt", - dest_scatter_cache_file="paper_dest_scatter_cache.pt", - rank=rank, + src_gather_cache_file=f"test_cache/synthetic_src_gather_cache_{rel}_{simulated_rank}_{world_size}.pt", + dest_gather_cache_file=f"test_cache/synthetic_dest_gather_cache_{rel}_{simulated_rank}_{world_size}.pt", + dest_scatter_cache_file=f"test_cache/synthetic_dest_scatter_cache_{rel}_{simulated_rank}_{world_size}.pt", + rank=simulated_rank, world_size=world_size, src_indices=edge_index[:, 0], dest_indices=edge_index[:, 1], @@ -155,5 +162,20 @@ def main(dataset): num_input_rows=xs[edge_type[0]].shape[0], num_output_rows=xs[edge_type[1]].shape[0], ) + rel += 1 + rel = 3 + synthetic_scatter_cache_1 = torch.load( + f"test_cache/synthetic_dest_scatter_cache_{rel}_1_{world_size}.pt", + weights_only=False, + ) + synthetic_scatter_cache_0 = torch.load( + f"test_cache/synthetic_dest_scatter_cache_{rel}_0_{world_size}.pt", + weights_only=False, + ) + + print(synthetic_scatter_cache_1.scatter_recv_local_placement) + print(synthetic_scatter_cache_0.scatter_recv_local_placement) + + breakpoint() Fire(main) diff --git a/experiments/OGB-LSC/RGAT.py b/experiments/OGB-LSC/RGAT.py index 133b397..fef2c76 100644 --- a/experiments/OGB-LSC/RGAT.py +++ b/experiments/OGB-LSC/RGAT.py @@ -78,23 +78,43 @@ def forward( _src_indices = edge_index[:, 0, :] _dst_indices = edge_index[:, 1, :] + self.comm.barrier() _src_rank_mappings = torch.cat( [rank_mapping[0].unsqueeze(0), rank_mapping[0].unsqueeze(0)], dim=0 ) + self.comm.barrier() + if self.comm.get_rank() == 0: + print("finished computing _src_rank_mappings") + self.comm.barrier() _dst_rank_mappings = torch.cat( [rank_mapping[0].unsqueeze(0), rank_mapping[1].unsqueeze(0)], dim=0 ) + self.comm.barrier() + if self.comm.get_rank() == 0: + print("finished computing _dst_rank_mappings") + self.comm.barrier() h_i = self.comm.gather( h, _dst_indices, _dst_rank_mappings, cache=dest_gather_cache ) + self.comm.barrier() + if self.comm.get_rank() == 0: + print("finished gathering h_i") + self.comm.barrier() + h_j = self.comm.gather( h_j, _src_indices, _src_rank_mappings, cache=src_gather_cache ) + self.comm.barrier() + if self.comm.get_rank() == 0: + print("finished gathering h_j") + self.comm.barrier() messages = torch.cat([h_i, h_j], dim=-1) edge_scores = self.leaky_relu(self.project_message(messages)) numerator = torch.exp(edge_scores) + self.comm.barrier() + denominator = self.comm.scatter( numerator, _dst_indices, @@ -102,9 +122,15 @@ def forward( h.size(1), cache=dest_scatter_cache, ) + self.comm.barrier() + if self.comm.get_rank() == 0: + print("finished scatter") + self.comm.barrier() denominator = self.comm.gather( denominator, _src_indices, _src_rank_mappings, cache=dest_gather_cache ) + self.comm.barrier() + alpha_ij = numerator / (denominator + 1e-16) attention_messages = h_j * alpha_ij out = self.comm.scatter( diff --git a/experiments/OGB-LSC/synthetic/synthetic_dataset.py b/experiments/OGB-LSC/synthetic/synthetic_dataset.py index a2664c9..7b5b0a4 100644 --- a/experiments/OGB-LSC/synthetic/synthetic_dataset.py +++ b/experiments/OGB-LSC/synthetic/synthetic_dataset.py @@ -24,12 +24,15 @@ def _generate_paper_2_paper_edges(num_papers): low=0, high=num_papers, size=(2, num_edges), dtype=torch.long ) coo_list = torch.unique(coo_list, dim=1) + transpose = coo_list.flip(0) + coo_list = torch.cat([coo_list, transpose], dim=1) + coo_list = torch.sort(coo_list, dim=1).values return coo_list -def _generate_paper_2_author_edges(num_papers, num_authors): +def _generate_author_2_paper_edges(num_authors, num_papers): # Average number of authors per paper is ~3.5 - num_edges = int(num_papers * 3.5) + num_edges = int(num_authors * 3.5) dest_papers = torch.randint( low=0, high=num_papers, size=(1, num_edges), dtype=torch.long ) @@ -138,8 +141,8 @@ def __init__( self.paper_src_data_mappings = paper_2_paper_src_data_mappings self.paper_dest_data_mappings = paper_2_paper_dest_data_mappings - self.paper_2_author_edges = _generate_paper_2_author_edges( - num_papers, num_authors + self.author_2_paper_edges = _generate_author_2_paper_edges( + num_authors, num_papers ) ( @@ -261,12 +264,15 @@ def __getitem__(self, idx): # author -> paper # author -> institution # institution -> author + edge_index = [ self.paper_2_paper_edges, self.paper_2_author_edges, - self.paper_2_author_edges.flip(0), + self.paper_2_author_edges.flip(self.paper_2_author_edges.dim() - 2), self.author_2_institution_edges, - self.author_2_institution_edges.flip(0), + self.author_2_institution_edges.flip( + self.author_2_institution_edges.dim() - 2 + ), ] # Locations of the edges rank_mappings = [ @@ -281,7 +287,7 @@ def __getitem__(self, idx): self.author_2_institution_src_data_mappings, ], [ - self.author_2_institution_dest_data_mappings, + self.author_2_institution_edge_locations, self.author_2_institution_src_data_mappings, ], ] From 1b45b7559b37faaed204da91a984887f7a9edfda Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Wed, 17 Sep 2025 13:46:14 -0700 Subject: [PATCH 13/48] Still debugging error on miscalculated gradient size --- DGraph/distributed/nccl/_nccl_cache.py | 9 +-- experiments/OGB-LSC/CacheGenerator.py | 62 ++++++++++--------- .../OGB-LSC/synthetic/synthetic_dataset.py | 43 ++++++------- 3 files changed, 59 insertions(+), 55 deletions(-) diff --git a/DGraph/distributed/nccl/_nccl_cache.py b/DGraph/distributed/nccl/_nccl_cache.py index 8d88d54..9d62ab2 100644 --- a/DGraph/distributed/nccl/_nccl_cache.py +++ b/DGraph/distributed/nccl/_nccl_cache.py @@ -102,7 +102,10 @@ def all_to_all_cache_helper( # No local sends continue _mask = (edge_vertex_ranks == rank) & (edge_placement == i) - _send_row = indices[0][_mask] % num_rows + try: + _send_row = indices[0][_mask] % num_rows + except: + breakpoint() send_local_placement[i] = _send_row @@ -201,9 +204,9 @@ def NCCLScatterCacheGenerator( indices, edge_placement, remote_recv_mask, num_output_rows, rank, world_size ) - breakpoint() # Information for the backward pass # It's a gather operation so quite a bit simpler + breakpoint() num_grad_output_rows = int(local_edges_mask.sum().item()) send_comm_vector, recv_comm_vector, send_local_placement, recv_local_placement = ( @@ -217,8 +220,6 @@ def NCCLScatterCacheGenerator( ) ) - breakpoint() - _cache = NCCLScatterCache( scatter_recv_local_placement=recv_placement, scatter_local_comm_mask=local_send_mask, diff --git a/experiments/OGB-LSC/CacheGenerator.py b/experiments/OGB-LSC/CacheGenerator.py index cbf5696..ce12ab4 100644 --- a/experiments/OGB-LSC/CacheGenerator.py +++ b/experiments/OGB-LSC/CacheGenerator.py @@ -38,7 +38,7 @@ def get_cache( num_input_rows, num_output_rows, ): - breakpoint() + # breakpoint() if src_gather_cache is None: _src_gather_cache = NCCLGatherCacheGenerator( @@ -136,33 +136,39 @@ def main(dataset): # for simulated_rank in range(world_size): simulated_rank = 0 - rel = 1 - for edge_index, edge_type, rank_mapping in zip( - edge_indices, edge_types, rank_mappings - ): - print(f"Edge index shape: {edge_index.shape}") - print(f"Edge type shape: {edge_type}") - print(f"Rank mapping shape: {rank_mapping[0].shape}") - print(f"Rank mapping shape: {rank_mapping[1].shape}") - - get_cache( - src_gather_cache=None, - dest_gather_cache=None, - dest_scatter_cache=None, - src_gather_cache_file=f"test_cache/synthetic_src_gather_cache_{rel}_{simulated_rank}_{world_size}.pt", - dest_gather_cache_file=f"test_cache/synthetic_dest_gather_cache_{rel}_{simulated_rank}_{world_size}.pt", - dest_scatter_cache_file=f"test_cache/synthetic_dest_scatter_cache_{rel}_{simulated_rank}_{world_size}.pt", - rank=simulated_rank, - world_size=world_size, - src_indices=edge_index[:, 0], - dest_indices=edge_index[:, 1], - edge_location=rank_mapping[0], - src_data_mappings=rank_mapping[0], - dest_data_mappings=rank_mapping[1], - num_input_rows=xs[edge_type[0]].shape[0], - num_output_rows=xs[edge_type[1]].shape[0], - ) - rel += 1 + for simulated_rank in [0, 1]: + rel = 0 + + for edge_index, edge_type, rank_mapping in zip( + edge_indices, edge_types, rank_mappings + ): + if rel < 4: + rel += 1 + continue + print(f"Edge index shape: {edge_index.shape}") + print(f"Edge type shape: {edge_type}") + print(f"Rank mapping shape: {rank_mapping[0].shape}") + print(f"Rank mapping shape: {rank_mapping[1].shape}") + + get_cache( + src_gather_cache=None, + dest_gather_cache=None, + dest_scatter_cache=None, + src_gather_cache_file=f"test_cache/synthetic_src_gather_cache_{rel}_{simulated_rank}_{world_size}.pt", + dest_gather_cache_file=f"test_cache/synthetic_dest_gather_cache_{rel}_{simulated_rank}_{world_size}.pt", + dest_scatter_cache_file=f"test_cache/synthetic_dest_scatter_cache_{rel}_{simulated_rank}_{world_size}.pt", + rank=simulated_rank, + world_size=world_size, + src_indices=edge_index[:, 0], + dest_indices=edge_index[:, 1], + edge_location=rank_mapping[0], + src_data_mappings=rank_mapping[0], + dest_data_mappings=rank_mapping[1], + num_input_rows=xs[edge_type[0]].shape[0], + num_output_rows=xs[edge_type[1]].shape[0], + ) + + rel += 1 rel = 3 synthetic_scatter_cache_1 = torch.load( f"test_cache/synthetic_dest_scatter_cache_{rel}_1_{world_size}.pt", diff --git a/experiments/OGB-LSC/synthetic/synthetic_dataset.py b/experiments/OGB-LSC/synthetic/synthetic_dataset.py index 7b5b0a4..52e96e7 100644 --- a/experiments/OGB-LSC/synthetic/synthetic_dataset.py +++ b/experiments/OGB-LSC/synthetic/synthetic_dataset.py @@ -78,10 +78,9 @@ def edge_mapping_from_vertex_mapping(edge_index, src_rank_mappings, dst_rank_map # We put the edge on the rank where the destination vertex is located # Since heterogeneous graphs have different rank mappings for different # vertex types. - edge_placement = dst_rank_mappings[dest_indices] src_data_mappings = src_rank_mappings[src_indices] dest_data_mappings = dst_rank_mappings[dest_indices] - return (edge_placement, src_data_mappings, dest_data_mappings) + return (src_data_mappings, dest_data_mappings) class HeterogeneousDataset: @@ -128,7 +127,6 @@ def __init__( self.paper_2_paper_edges = _generate_paper_2_paper_edges(num_papers) ( - paper_2_paper_edge_location, paper_2_paper_src_data_mappings, paper_2_paper_dest_data_mappings, ) = edge_mapping_from_vertex_mapping( @@ -137,7 +135,6 @@ def __init__( dst_rank_mappings=self.paper_vertex_rank_mapping, ) - self.paper_edge_locations = paper_2_paper_edge_location self.paper_src_data_mappings = paper_2_paper_src_data_mappings self.paper_dest_data_mappings = paper_2_paper_dest_data_mappings @@ -146,24 +143,21 @@ def __init__( ) ( - paper_2_author_edge_location, - paper_2_author_src_data_mappings, - paper_2_author_dest_data_mappings, + author_2_paper_src_data_mappings, + author_2_paper_dest_data_mappings, ) = edge_mapping_from_vertex_mapping( - edge_index=self.paper_2_author_edges, + edge_index=self.author_2_paper_edges, src_rank_mappings=self.author_vertex_rank_mapping, dst_rank_mappings=self.paper_vertex_rank_mapping, ) - self.paper_2_author_edge_locations = paper_2_author_edge_location - self.paper_2_author_src_data_mappings = paper_2_author_src_data_mappings - self.paper_2_author_dest_data_mappings = paper_2_author_dest_data_mappings + self.author_2_paper_src_data_mappings = author_2_paper_src_data_mappings + self.author_2_paper_dest_data_mappings = author_2_paper_dest_data_mappings self.author_2_institution_edges = _generate_author_2_institution_edges( num_authors, num_institutions ) ( - author_2_institution_edge_location, author_2_institution_src_data_mappings, author_2_institution_dest_data_mappings, ) = edge_mapping_from_vertex_mapping( @@ -171,7 +165,7 @@ def __init__( src_rank_mappings=self.author_vertex_rank_mapping, dst_rank_mappings=self.institution_vertex_rank_mapping, ) - self.author_2_institution_edge_locations = author_2_institution_edge_location + self.author_2_institution_src_data_mappings = ( author_2_institution_src_data_mappings ) @@ -234,7 +228,7 @@ def add_batch_dimension(self): self.val_mask = self.val_mask.unsqueeze(0) self.test_mask = self.test_mask.unsqueeze(0) self.paper_2_paper_edges = self.paper_2_paper_edges.unsqueeze(0) - self.paper_2_author_edges = self.paper_2_author_edges.unsqueeze(0) + self.author_2_paper_edges = self.author_2_paper_edges.unsqueeze(0) self.author_2_institution_edges = self.author_2_institution_edges.unsqueeze(0) return self @@ -253,7 +247,7 @@ def to(self, device): self.val_mask = self.val_mask.to(device) self.test_mask = self.test_mask.to(device) self.paper_2_paper_edges = self.paper_2_paper_edges.to(device) - self.paper_2_author_edges = self.paper_2_author_edges.to(device) + self.author_2_paper_edges = self.author_2_paper_edges.to(device) self.author_2_institution_edges = self.author_2_institution_edges.to(device) return self @@ -267,8 +261,8 @@ def __getitem__(self, idx): edge_index = [ self.paper_2_paper_edges, - self.paper_2_author_edges, - self.paper_2_author_edges.flip(self.paper_2_author_edges.dim() - 2), + self.author_2_paper_edges, + self.author_2_paper_edges.flip(self.author_2_paper_edges.dim() - 2), self.author_2_institution_edges, self.author_2_institution_edges.flip( self.author_2_institution_edges.dim() - 2 @@ -276,18 +270,21 @@ def __getitem__(self, idx): ] # Locations of the edges rank_mappings = [ - [self.paper_edge_locations, self.paper_dest_data_mappings], - [self.paper_2_author_edge_locations, self.paper_2_author_src_data_mappings], + [self.paper_src_data_mappings, self.paper_dest_data_mappings], + [ + self.author_2_paper_src_data_mappings, + self.author_2_paper_dest_data_mappings, + ], [ - self.paper_2_author_edge_locations, - self.paper_2_author_dest_data_mappings, + self.author_2_paper_dest_data_mappings, + self.author_2_paper_src_data_mappings, ], [ - self.author_2_institution_edge_locations, self.author_2_institution_src_data_mappings, + self.author_2_institution_dest_data_mappings, ], [ - self.author_2_institution_edge_locations, + self.author_2_institution_dest_data_mappings, self.author_2_institution_src_data_mappings, ], ] From f71f109b9487b49314372c8676430f2c0a14d8b1 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Wed, 17 Sep 2025 16:37:59 -0700 Subject: [PATCH 14/48] Fix for incorrect tensor shape --- DGraph/distributed/nccl/_nccl_cache.py | 7 ++----- experiments/OGB-LSC/CacheGenerator.py | 11 +++++------ experiments/OGB-LSC/RGAT.py | 4 ++-- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/DGraph/distributed/nccl/_nccl_cache.py b/DGraph/distributed/nccl/_nccl_cache.py index 9d62ab2..7ebef3c 100644 --- a/DGraph/distributed/nccl/_nccl_cache.py +++ b/DGraph/distributed/nccl/_nccl_cache.py @@ -102,10 +102,7 @@ def all_to_all_cache_helper( # No local sends continue _mask = (edge_vertex_ranks == rank) & (edge_placement == i) - try: - _send_row = indices[0][_mask] % num_rows - except: - breakpoint() + _send_row = indices[0][_mask] % num_rows send_local_placement[i] = _send_row @@ -200,13 +197,13 @@ def NCCLScatterCacheGenerator( receving_ranks = torch.unique(local_dest_ranks_slice[local_send_mask]) + breakpoint() recv_placement = _get_local_unique_recv_placement( indices, edge_placement, remote_recv_mask, num_output_rows, rank, world_size ) # Information for the backward pass # It's a gather operation so quite a bit simpler - breakpoint() num_grad_output_rows = int(local_edges_mask.sum().item()) send_comm_vector, recv_comm_vector, send_local_placement, recv_local_placement = ( diff --git a/experiments/OGB-LSC/CacheGenerator.py b/experiments/OGB-LSC/CacheGenerator.py index ce12ab4..6329722 100644 --- a/experiments/OGB-LSC/CacheGenerator.py +++ b/experiments/OGB-LSC/CacheGenerator.py @@ -82,6 +82,8 @@ def get_cache( else: _dest_gather_cache = dest_gather_cache + # Unit tests + return _src_gather_cache, _dest_scatter_cache, _dest_gather_cache @@ -117,7 +119,7 @@ def main(dataset): graph_dataset = partial(Dataset, data_dir="data/MAG240M") rank = 0 - world_size = 16 + world_size = 4 COMM = type( "dummy_comm", (object,), @@ -142,9 +144,6 @@ def main(dataset): for edge_index, edge_type, rank_mapping in zip( edge_indices, edge_types, rank_mappings ): - if rel < 4: - rel += 1 - continue print(f"Edge index shape: {edge_index.shape}") print(f"Edge type shape: {edge_type}") print(f"Rank mapping shape: {rank_mapping[0].shape}") @@ -164,8 +163,8 @@ def main(dataset): edge_location=rank_mapping[0], src_data_mappings=rank_mapping[0], dest_data_mappings=rank_mapping[1], - num_input_rows=xs[edge_type[0]].shape[0], - num_output_rows=xs[edge_type[1]].shape[0], + num_input_rows=xs[edge_type[0]].shape[1], + num_output_rows=xs[edge_type[1]].shape[1], ) rel += 1 diff --git a/experiments/OGB-LSC/RGAT.py b/experiments/OGB-LSC/RGAT.py index fef2c76..dfd8166 100644 --- a/experiments/OGB-LSC/RGAT.py +++ b/experiments/OGB-LSC/RGAT.py @@ -315,8 +315,8 @@ def forward(self, xs, adjts, edge_types, rank_mappings): edge_location=rank_mapping[0], src_data_mappings=rank_mapping[0], dest_data_mappings=rank_mapping[1], - num_input_rows=outs[edge_type[0]].size(0), - num_output_rows=outs[edge_type[1]].size(0), + num_input_rows=outs[edge_type[0]].size(1), + num_output_rows=outs[edge_type[1]].size(1), ) src_gather_cache, dest_scatter_cache, dest_gather_cache = caches else: From 076a692ce68845e9408c981b075dd0888e1eba6c Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Fri, 26 Sep 2025 14:29:16 -0700 Subject: [PATCH 15/48] Author 2 paper relation working. Only author 2 institution error remains. Incorrect local recv tensor size --- DGraph/distributed/RankLocalOps.py | 5 ++- DGraph/distributed/nccl/NCCLBackendEngine.py | 15 +++++++ DGraph/distributed/nccl/_nccl_cache.py | 6 +-- DGraph/distributed/nccl/alltoallv_impl.py | 29 +++++++++++++ experiments/OGB-LSC/CacheGenerator.py | 16 ++++--- experiments/OGB-LSC/RGAT.py | 43 +++++++++++++++---- experiments/OGB-LSC/Trainer.py | 4 ++ experiments/OGB-LSC/config.py | 6 +-- .../OGB-LSC/synthetic/synthetic_dataset.py | 2 +- 9 files changed, 102 insertions(+), 24 deletions(-) diff --git a/DGraph/distributed/RankLocalOps.py b/DGraph/distributed/RankLocalOps.py index c4b6de0..243ef16 100644 --- a/DGraph/distributed/RankLocalOps.py +++ b/DGraph/distributed/RankLocalOps.py @@ -16,6 +16,7 @@ """ import torch +import torch.distributed as dist try: from DGraph.torch_local import local_masked_gather, local_masked_scatter @@ -140,7 +141,9 @@ def RankLocalRenumberingWithMapping(_indices, rank_mapping): unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True) rank_mapping = rank_mapping.to(_indices.device) renumbered_indices = inverse_indices - unique_rank_mapping = torch.zeros_like(unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device) + unique_rank_mapping = torch.zeros_like( + unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device + ) unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping) return renumbered_indices, unique_indices, unique_rank_mapping diff --git a/DGraph/distributed/nccl/NCCLBackendEngine.py b/DGraph/distributed/nccl/NCCLBackendEngine.py index b3ea11a..55a0e5e 100644 --- a/DGraph/distributed/nccl/NCCLBackendEngine.py +++ b/DGraph/distributed/nccl/NCCLBackendEngine.py @@ -116,6 +116,14 @@ def forward( needs_comm = (local_recv_tensor != rank).any() + # For debugging: Delete later + dist.barrier() + for i in range(world_size): + if i == rank: + print(f"Rank {rank} reached local gather") + dist.barrier() + dist.barrier() + recv_tensor = OptimizedRankLocalMaskedGather( local_send_tensor, local_indices, @@ -123,6 +131,13 @@ def forward( recv_tensor, rank, ) + # For debugging: Delete later + dist.barrier() + for i in range(world_size): + if i == rank: + print(f"Rank {rank} finished local gather") + dist.barrier() + dist.barrier() if needs_comm: diff --git a/DGraph/distributed/nccl/_nccl_cache.py b/DGraph/distributed/nccl/_nccl_cache.py index 7ebef3c..b1d7bbc 100644 --- a/DGraph/distributed/nccl/_nccl_cache.py +++ b/DGraph/distributed/nccl/_nccl_cache.py @@ -82,6 +82,7 @@ def all_to_all_cache_helper( recv_local_placement = {} + breakpoint() for i, num_messages in enumerate(recv_comm_vector): if num_messages == 0: continue @@ -106,6 +107,7 @@ def all_to_all_cache_helper( send_local_placement[i] = _send_row + breakpoint() return ( send_comm_vector, recv_comm_vector, @@ -197,11 +199,9 @@ def NCCLScatterCacheGenerator( receving_ranks = torch.unique(local_dest_ranks_slice[local_send_mask]) - breakpoint() recv_placement = _get_local_unique_recv_placement( indices, edge_placement, remote_recv_mask, num_output_rows, rank, world_size ) - # Information for the backward pass # It's a gather operation so quite a bit simpler @@ -254,7 +254,7 @@ def NCCLGatherCacheGenerator( indices, edge_placement, edge_dest_ranks, num_input_rows, rank, world_size ) ) - + breakpoint() local_slice_mask = edge_placement == rank local_mask = edge_placement[local_slice_mask] diff --git a/DGraph/distributed/nccl/alltoallv_impl.py b/DGraph/distributed/nccl/alltoallv_impl.py index 060c390..c9cd783 100644 --- a/DGraph/distributed/nccl/alltoallv_impl.py +++ b/DGraph/distributed/nccl/alltoallv_impl.py @@ -18,6 +18,13 @@ def _nccl_alltoall_v( num_features = local_send_tensor.shape[2] num_src_rows = local_send_tensor.shape[1] + # For debugging: Delete later + dist.barrier() + for i in range(world_size): + if i == rank: + print(f"Rank {rank} starting comm") + dist.barrier() + recv_buffer_dict = {} if cache is None: @@ -77,6 +84,11 @@ def _nccl_alltoall_v( recv_local_placement = cache.gather_recv_local_placement send_local_placement = cache.gather_send_local_placement + dist.barrier() + if rank == 0: + breakpoint() + dist.barrier() + # Allocate the receive buffers for i, num_messages in enumerate(recv_comm_vector): if num_messages == 0: @@ -116,12 +128,29 @@ def _nccl_alltoall_v( recv_tensor = recv_buffer_dict[send_rank_index] p2p_op_list.append(dist.P2POp(dist.irecv, recv_tensor, send_rank_index)) + # For debugging: Delete later + dist.barrier() + for i in range(world_size): + if i == rank: + print(f"Rank {rank} starting batch_isend_irecv") + for op in p2p_op_list: + print(f"Rank {rank} {op.op.__name__} {op.tensor.shape} to {op.peer}") + for key, recv_buffer in recv_buffer_dict.items(): + print(f"Rank {rank} expecting {recv_buffer.shape} from {key}") + dist.barrier() if len(p2p_op_list) > 0: reqs = dist.batch_isend_irecv(p2p_op_list) for req in reqs: req.wait() + # For debugging: Delete later + dist.barrier() + for i in range(world_size): + if i == rank: + print(f"Rank {rank} reached here") + dist.barrier() + for key, recv_buffer in recv_buffer_dict.items(): local_recv_tensor[:, recv_local_placement[key].view(-1), :] = ( diff --git a/experiments/OGB-LSC/CacheGenerator.py b/experiments/OGB-LSC/CacheGenerator.py index 6329722..38dbd06 100644 --- a/experiments/OGB-LSC/CacheGenerator.py +++ b/experiments/OGB-LSC/CacheGenerator.py @@ -38,7 +38,7 @@ def get_cache( num_input_rows, num_output_rows, ): - # breakpoint() + if src_gather_cache is None: _src_gather_cache = NCCLGatherCacheGenerator( @@ -77,6 +77,7 @@ def get_cache( rank=rank, world_size=world_size, ) + breakpoint() torch.save(_dest_gather_cache, dest_gather_cache_file) else: @@ -104,12 +105,9 @@ def main(dataset): synthetic_config = SyntheticDatasetConfig() graph_dataset = partial( Dataset, - # num_papers=synthetic_config.num_papers, - num_papers=100, - # num_authors=synthetic_config.num_authors, - num_authors=32, - # num_institutions=synthetic_config.num_institutions, - num_institutions=16, + num_papers=synthetic_config.num_papers, + num_authors=synthetic_config.num_authors, + num_institutions=synthetic_config.num_institutions, num_features=synthetic_config.num_features, num_classes=synthetic_config.num_classes, ) @@ -144,11 +142,15 @@ def main(dataset): for edge_index, edge_type, rank_mapping in zip( edge_indices, edge_types, rank_mappings ): + if rel != 3: + rel += 1 + continue print(f"Edge index shape: {edge_index.shape}") print(f"Edge type shape: {edge_type}") print(f"Rank mapping shape: {rank_mapping[0].shape}") print(f"Rank mapping shape: {rank_mapping[1].shape}") + breakpoint() get_cache( src_gather_cache=None, dest_gather_cache=None, diff --git a/experiments/OGB-LSC/RGAT.py b/experiments/OGB-LSC/RGAT.py index dfd8166..c183eff 100644 --- a/experiments/OGB-LSC/RGAT.py +++ b/experiments/OGB-LSC/RGAT.py @@ -18,6 +18,7 @@ from distributed_layers import DistributedBatchNorm1D import os.path as osp from CacheGenerator import get_cache +import sys class ConvLayer(nn.Module): @@ -83,19 +84,37 @@ def forward( [rank_mapping[0].unsqueeze(0), rank_mapping[0].unsqueeze(0)], dim=0 ) self.comm.barrier() - if self.comm.get_rank() == 0: - print("finished computing _src_rank_mappings") - self.comm.barrier() _dst_rank_mappings = torch.cat( [rank_mapping[0].unsqueeze(0), rank_mapping[1].unsqueeze(0)], dim=0 ) + self.comm.barrier() if self.comm.get_rank() == 0: - print("finished computing _dst_rank_mappings") + print("starting gather") + self.comm.barrier() + + if self.comm.get_rank() == 0: + print(f"h shape: {h.shape}") + print(f"h_j shape: {h_j.shape}") + # breakpoint() self.comm.barrier() + + # sys.exit(0) # --- IGNORE --- + h_i = self.comm.gather( h, _dst_indices, _dst_rank_mappings, cache=dest_gather_cache ) + + if self.comm.get_rank() == 0: + print("finished computing _dst_rank_mappings") + self.comm.barrier() + + for i in range(self.comm.get_world_size()): + self.comm.barrier() + if self.comm.get_rank() == i: + print(f"Rank {i} h_i shape: {h_i.shape}") + self.comm.barrier() + self.comm.barrier() if self.comm.get_rank() == 0: print("finished gathering h_i") @@ -113,8 +132,13 @@ def forward( edge_scores = self.leaky_relu(self.project_message(messages)) numerator = torch.exp(edge_scores) - self.comm.barrier() + if self.comm.get_rank() == 0: + print(f"Numerator shape: {numerator.shape}") + self.comm.barrier() + if self.comm.get_rank() == 0: + print("starting scatter") + self.comm.barrier() denominator = self.comm.scatter( numerator, _dst_indices, @@ -299,7 +323,9 @@ def forward(self, xs, adjts, edge_types, rank_mappings): for j, (edge_index, edge_type, rank_mapping) in enumerate( zip(adjts, edge_types, rank_mappings) ): - + if j != 3: + continue + src_edge_type, dst_edge_type = edge_type if self.use_cache: caches = get_cache( src_gather_cache=self.src_gather_caches[j], @@ -315,8 +341,8 @@ def forward(self, xs, adjts, edge_types, rank_mappings): edge_location=rank_mapping[0], src_data_mappings=rank_mapping[0], dest_data_mappings=rank_mapping[1], - num_input_rows=outs[edge_type[0]].size(1), - num_output_rows=outs[edge_type[1]].size(1), + num_input_rows=outs[src_edge_type].size(1), + num_output_rows=outs[dst_edge_type].size(1), ) src_gather_cache, dest_scatter_cache, dest_gather_cache = caches else: @@ -324,7 +350,6 @@ def forward(self, xs, adjts, edge_types, rank_mappings): dest_scatter_cache = None dest_gather_cache = None - src_edge_type, dst_edge_type = edge_type self.comm.barrier() if self.comm.get_rank() == 0: print( diff --git a/experiments/OGB-LSC/Trainer.py b/experiments/OGB-LSC/Trainer.py index 13c15b1..0242b0d 100644 --- a/experiments/OGB-LSC/Trainer.py +++ b/experiments/OGB-LSC/Trainer.py @@ -25,6 +25,10 @@ def __init__(self, dataset, comm): # TODO: We need some better way to set the device but # difficult to do that since systems have different bindings. # self.device = torch.device(f"cuda:{comm.get_local_rank()}") + rank = comm.get_rank() + print(f"Rank {rank} using GPU {rank % torch.cuda.device_count()}") + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) self.device = torch.device("cuda") self.model = CommAwareRGAT( in_channels=self.model_config.num_features, diff --git a/experiments/OGB-LSC/config.py b/experiments/OGB-LSC/config.py index 0c0a3f5..a811160 100644 --- a/experiments/OGB-LSC/config.py +++ b/experiments/OGB-LSC/config.py @@ -17,10 +17,10 @@ @dataclass class ModelConfig: - hidden_channels: int = 1024 + hidden_channels: int = 16 dropout: float = 0.5 num_layers: int = 2 - num_features: int = 768 + num_features: int = 16 num_relations: int = 5 num_classes: int = 153 heads: int = 4 @@ -40,5 +40,5 @@ class SyntheticDatasetConfig: num_papers: int = 2048 num_authors: int = 512 num_institutions: int = 16 - num_features: int = 768 + num_features: int = 16 num_classes: int = 153 diff --git a/experiments/OGB-LSC/synthetic/synthetic_dataset.py b/experiments/OGB-LSC/synthetic/synthetic_dataset.py index 52e96e7..7d7b520 100644 --- a/experiments/OGB-LSC/synthetic/synthetic_dataset.py +++ b/experiments/OGB-LSC/synthetic/synthetic_dataset.py @@ -288,7 +288,7 @@ def __getitem__(self, idx): self.author_2_institution_src_data_mappings, ], ] - edge_type = [(0, 0), (0, 1), (1, 0), (1, 2), (2, 1)] + edge_type = [(0, 0), (1, 0), (0, 1), (1, 2), (2, 1)] features = [ self.paper_features, self.author_features, From 2d27416415b0f3ab217f43c606af046fb9f78ddc Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Fri, 26 Sep 2025 14:59:53 -0700 Subject: [PATCH 16/48] Remove extra breakpoints in cache generators --- DGraph/distributed/nccl/_nccl_cache.py | 2 -- experiments/OGB-LSC/CacheGenerator.py | 1 - experiments/OGB-LSC/README.md | 1 + 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/DGraph/distributed/nccl/_nccl_cache.py b/DGraph/distributed/nccl/_nccl_cache.py index b1d7bbc..0774e58 100644 --- a/DGraph/distributed/nccl/_nccl_cache.py +++ b/DGraph/distributed/nccl/_nccl_cache.py @@ -82,7 +82,6 @@ def all_to_all_cache_helper( recv_local_placement = {} - breakpoint() for i, num_messages in enumerate(recv_comm_vector): if num_messages == 0: continue @@ -107,7 +106,6 @@ def all_to_all_cache_helper( send_local_placement[i] = _send_row - breakpoint() return ( send_comm_vector, recv_comm_vector, diff --git a/experiments/OGB-LSC/CacheGenerator.py b/experiments/OGB-LSC/CacheGenerator.py index 38dbd06..eb5d4f6 100644 --- a/experiments/OGB-LSC/CacheGenerator.py +++ b/experiments/OGB-LSC/CacheGenerator.py @@ -77,7 +77,6 @@ def get_cache( rank=rank, world_size=world_size, ) - breakpoint() torch.save(_dest_gather_cache, dest_gather_cache_file) else: diff --git a/experiments/OGB-LSC/README.md b/experiments/OGB-LSC/README.md index 0d79356..4a35ad5 100644 --- a/experiments/OGB-LSC/README.md +++ b/experiments/OGB-LSC/README.md @@ -4,6 +4,7 @@ ## Requirements +- fire ## Data preparation The dataset is fairly large (over 100GB). Please follow the instructions in the `mag240m` folder to download and preprocess the dataset. From d12722039d965e89b013c1265caf983b8858334a Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Fri, 26 Sep 2025 16:00:14 -0700 Subject: [PATCH 17/48] Fix cache generator with correct input shape for destination gather --- experiments/OGB-LSC/CacheGenerator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/experiments/OGB-LSC/CacheGenerator.py b/experiments/OGB-LSC/CacheGenerator.py index eb5d4f6..946d7ab 100644 --- a/experiments/OGB-LSC/CacheGenerator.py +++ b/experiments/OGB-LSC/CacheGenerator.py @@ -73,7 +73,7 @@ def get_cache( indices=dest_indices, edge_placement=edge_location, edge_dest_ranks=dest_data_mappings, - num_input_rows=num_input_rows, + num_input_rows=num_output_rows, rank=rank, world_size=world_size, ) @@ -149,7 +149,6 @@ def main(dataset): print(f"Rank mapping shape: {rank_mapping[0].shape}") print(f"Rank mapping shape: {rank_mapping[1].shape}") - breakpoint() get_cache( src_gather_cache=None, dest_gather_cache=None, From a6962cc8e00883a121aadb9a5fa5283c2c4dd6c7 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Fri, 26 Sep 2025 16:24:35 -0700 Subject: [PATCH 18/48] Fix on batch norm to have correct local variance reduction --- experiments/OGB-LSC/distributed_layers.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/experiments/OGB-LSC/distributed_layers.py b/experiments/OGB-LSC/distributed_layers.py index 4fd76d4..b59d18e 100644 --- a/experiments/OGB-LSC/distributed_layers.py +++ b/experiments/OGB-LSC/distributed_layers.py @@ -28,7 +28,7 @@ def _compute_bn_forward(input, learned_gamma=None, learned_beta=None): dist.all_reduce(global_num_rows, op=dist.ReduceOp.SUM) global_mean = global_sum / global_num_rows - local_var = (input - global_mean) ** 2 + local_var = ((input - global_mean) ** 2).sum(dim=0) global_var = local_var.clone() dist.all_reduce(global_sum, op=dist.ReduceOp.SUM) dist.all_reduce(global_var, op=dist.ReduceOp.SUM) @@ -153,8 +153,8 @@ def __init__( ): super(DistributedBatchNorm1D, self).__init__() if affine: - self.gamma = nn.Parameter(torch.ones(num_features)) - self.beta = nn.Parameter(torch.zeros(num_features)) + self.gamma = nn.Parameter(torch.ones(1, num_features)) + self.beta = nn.Parameter(torch.zeros(1, num_features)) else: self.register_parameter("gamma", None) self.register_parameter("beta", None) @@ -162,8 +162,8 @@ def __init__( self.momentum = momentum self.track_running_stats = track_running_stats if self.track_running_stats: - self.register_buffer("running_mean", torch.zeros(num_features)) - self.register_buffer("running_var", torch.ones(num_features)) + self.register_buffer("running_mean", torch.zeros(1, num_features)) + self.register_buffer("running_var", torch.ones(1, num_features)) self.register_buffer( "num_batches_tracked", torch.tensor(0, dtype=torch.long) ) @@ -188,6 +188,7 @@ def forward(self, x): if self.track_running_stats: self.num_batches_tracked += 1 y, mean, var = self.bn(x, self.gamma, self.beta) + if self.track_running_stats: with torch.no_grad(): self.running_mean = ( From f85f133b97a8364c30ea0d1ba42cd6f583d37747 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Fri, 26 Sep 2025 18:37:48 -0700 Subject: [PATCH 19/48] Added additional parameters to batch norm for backprop --- experiments/OGB-LSC/distributed_layers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/experiments/OGB-LSC/distributed_layers.py b/experiments/OGB-LSC/distributed_layers.py index b59d18e..1387ad6 100644 --- a/experiments/OGB-LSC/distributed_layers.py +++ b/experiments/OGB-LSC/distributed_layers.py @@ -46,10 +46,10 @@ def _compute_bn_backward( ): if learned_gamma is not None and learned_beta is not None: local_dbeta = torch.sum(grad_output, dim=0) - global_dbeta = local_dbeta.clone() + global_dbeta = local_dbeta.clone().unsqueeze(0) dist.all_reduce(global_dbeta, op=dist.ReduceOp.SUM) local_dgamma = torch.sum(grad_output * x_hat, dim=0) - global_dgamma = local_dgamma.clone() + global_dgamma = local_dgamma.clone().unsqueeze(0) dist.all_reduce(global_dgamma, op=dist.ReduceOp.SUM) dx_hat = grad_output * learned_gamma else: @@ -90,7 +90,7 @@ def forward(ctx, input, learned_gamma=None, learned_beta=None): return output, global_mean, global_var @staticmethod - def backward(ctx, grad_output): + def backward(ctx, grad_output, grad_mean, grad_var): x = ctx.input mean = ctx.mean var = ctx.var @@ -125,7 +125,7 @@ def forward(ctx, input, learned_gamma=None, learned_beta=None): return output, global_mean, global_var @staticmethod - def backward(ctx, grad_output): + def backward(ctx, grad_output, grad_mean, grad_var): learned_gamma = ctx.learned_gamma learned_beta = ctx.learned_beta From 750d8e6ed1a0525d9828075fc9884b0b5e0bb752 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Sun, 28 Sep 2025 09:17:10 -0700 Subject: [PATCH 20/48] Adding helper functions to sync normalization values and fixed evaluation on trainer --- experiments/OGB-LSC/Trainer.py | 64 ++++++++++++++++++----- experiments/OGB-LSC/distributed_layers.py | 7 +++ 2 files changed, 57 insertions(+), 14 deletions(-) diff --git a/experiments/OGB-LSC/Trainer.py b/experiments/OGB-LSC/Trainer.py index 0242b0d..7c70399 100644 --- a/experiments/OGB-LSC/Trainer.py +++ b/experiments/OGB-LSC/Trainer.py @@ -14,6 +14,9 @@ import torch from RGAT import CommAwareRGAT from config import ModelConfig, TrainingConfig +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from distributed_layers import GetGlobalVal class Trainer: @@ -40,6 +43,10 @@ def __init__(self, dataset, comm): comm=comm, dropout=self.model_config.dropout, ).to(self.device) + self.model = DDP(self.model, device_ids=[rank % num_gpus]) + self.optimizer = torch.optim.Adam( + self.model.parameters(), lr=self.training_config.lr, weight_decay=5e-4 + ) def prepare_data(self): self.dataset = self.dataset.add_batch_dimension() @@ -52,29 +59,58 @@ def train(self): for epoch in range(1, self.training_config.epochs + 1): out = self.model(xs, edge_index, edge_type, rank_mapping) + train_mask = self.dataset.get_mask("train") + local_train_vertices = out[:, train_mask, :].squeeze(0) + target = self.dataset.get_target("train") + loss = torch.nn.functional.cross_entropy( - out[self.dataset.train_mask], self.dataset.y[self.dataset.train_mask] + local_train_vertices, target, reduction="sum" ) + local_num_targets = target.size(0) + global_num_targets = GetGlobalVal(local_num_targets) + loss = loss / global_num_targets # Average the loss + self.model.zero_grad() loss.backward() + self.optimizer.step() return loss.item() @torch.no_grad() def evaluate(self): self.model.eval() - out = self.model( - self.dataset.x, self.dataset.edge_index, self.dataset.rank_mapping - ) - y_true = self.dataset.y.cpu().numpy() + + xs, edge_index, edge_type, rank_mapping = self.dataset[0] + out = self.model(xs, edge_index, edge_type, rank_mapping) + y_pred = out.argmax(dim=-1, keepdim=True).cpu().numpy() + train_mask = self.dataset.get_mask("train").cpu().numpy() + val_mask = self.dataset.get_mask("val").cpu().numpy() + test_mask = self.dataset.get_mask("test").cpu().numpy() + y_true_train = self.dataset.get_target("train").cpu().numpy() + y_pred_val = self.dataset.get_target("val").cpu().numpy() + y_pred_test = self.dataset.get_target("test").cpu().numpy() + + train_acc = (y_pred[train_mask] == y_true_train).sum() / int(train_mask.sum()) + # Not guaranteed to have validation or test samples on every rank + num_local_val_samples = int(val_mask.sum()) + num_local_test_samples = int(test_mask.sum()) + if num_local_val_samples == 0: + val_acc = 0.0 + else: + val_acc = (y_pred[val_mask] == y_pred_val).sum().item() + val_acc = GetGlobalVal(val_acc) + + num_global_val_samples = GetGlobalVal(num_local_val_samples) + val_acc = val_acc / int(num_global_val_samples) + + if num_local_test_samples == 0: + test_acc = 0.0 + else: + test_acc = (y_pred[test_mask] == y_pred_test).sum().item() + + test_acc = GetGlobalVal(test_acc) + num_global_test_samples = GetGlobalVal(num_local_test_samples) + test_acc = test_acc / int(num_global_test_samples) - train_acc = ( - y_pred[self.dataset.train_mask] == y_true[self.dataset.train_mask] - ).sum() / int(self.dataset.train_mask.sum()) - val_acc = ( - y_pred[self.dataset.val_mask] == y_true[self.dataset.val_mask] - ).sum() / int(self.dataset.val_mask.sum()) - test_acc = ( - y_pred[self.dataset.test_mask] == y_true[self.dataset.test_mask] - ).sum() / int(self.dataset.test_mask.sum()) + # All ranks should have the same accuracy values return train_acc, val_acc, test_acc diff --git a/experiments/OGB-LSC/distributed_layers.py b/experiments/OGB-LSC/distributed_layers.py index 1387ad6..54408b1 100644 --- a/experiments/OGB-LSC/distributed_layers.py +++ b/experiments/OGB-LSC/distributed_layers.py @@ -205,3 +205,10 @@ def forward(self, x): if y.dim() == 2: y = y.unsqueeze(0) return y + + +def GetGlobalVal(local_val): + """Get the global sum of a local value across all ranks.""" + global_val = torch.tensor([local_val]).cuda() + dist.all_reduce(global_val, op=dist.ReduceOp.SUM) + return global_val.item() From 657111e35f2c8ba08029d08ba8058742d54e374b Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Fri, 3 Oct 2025 13:33:19 -0700 Subject: [PATCH 21/48] Latest changes to RGAT --- DGraph/distributed/nccl/NCCLBackendEngine.py | 1 + DGraph/distributed/nccl/alltoallv_impl.py | 22 ----- experiments/OGB-LSC/CacheGenerator.py | 16 ++-- experiments/OGB-LSC/RGAT.py | 9 +- .../OGB-LSC/synthetic/synthetic_dataset.py | 95 +++++++++++++++---- 5 files changed, 92 insertions(+), 51 deletions(-) diff --git a/DGraph/distributed/nccl/NCCLBackendEngine.py b/DGraph/distributed/nccl/NCCLBackendEngine.py index 55a0e5e..c10c1ee 100644 --- a/DGraph/distributed/nccl/NCCLBackendEngine.py +++ b/DGraph/distributed/nccl/NCCLBackendEngine.py @@ -485,6 +485,7 @@ def backward(ctx, grad_output): src_rank_loc=recv_ranks, rank=rank, world_size=world_size, + cache=cache, ) # if rank == 0: diff --git a/DGraph/distributed/nccl/alltoallv_impl.py b/DGraph/distributed/nccl/alltoallv_impl.py index c9cd783..939343d 100644 --- a/DGraph/distributed/nccl/alltoallv_impl.py +++ b/DGraph/distributed/nccl/alltoallv_impl.py @@ -84,11 +84,6 @@ def _nccl_alltoall_v( recv_local_placement = cache.gather_recv_local_placement send_local_placement = cache.gather_send_local_placement - dist.barrier() - if rank == 0: - breakpoint() - dist.barrier() - # Allocate the receive buffers for i, num_messages in enumerate(recv_comm_vector): if num_messages == 0: @@ -128,29 +123,12 @@ def _nccl_alltoall_v( recv_tensor = recv_buffer_dict[send_rank_index] p2p_op_list.append(dist.P2POp(dist.irecv, recv_tensor, send_rank_index)) - # For debugging: Delete later - dist.barrier() - for i in range(world_size): - if i == rank: - print(f"Rank {rank} starting batch_isend_irecv") - for op in p2p_op_list: - print(f"Rank {rank} {op.op.__name__} {op.tensor.shape} to {op.peer}") - for key, recv_buffer in recv_buffer_dict.items(): - print(f"Rank {rank} expecting {recv_buffer.shape} from {key}") - dist.barrier() if len(p2p_op_list) > 0: reqs = dist.batch_isend_irecv(p2p_op_list) for req in reqs: req.wait() - # For debugging: Delete later - dist.barrier() - for i in range(world_size): - if i == rank: - print(f"Rank {rank} reached here") - dist.barrier() - for key, recv_buffer in recv_buffer_dict.items(): local_recv_tensor[:, recv_local_placement[key].view(-1), :] = ( diff --git a/experiments/OGB-LSC/CacheGenerator.py b/experiments/OGB-LSC/CacheGenerator.py index 946d7ab..35f3566 100644 --- a/experiments/OGB-LSC/CacheGenerator.py +++ b/experiments/OGB-LSC/CacheGenerator.py @@ -35,17 +35,17 @@ def get_cache( edge_location, src_data_mappings, dest_data_mappings, - num_input_rows, - num_output_rows, + num_src_rows, + num_dest_rows, ): - + """ """ if src_gather_cache is None: _src_gather_cache = NCCLGatherCacheGenerator( indices=src_indices, edge_placement=edge_location, edge_dest_ranks=src_data_mappings, - num_input_rows=num_input_rows, + num_input_rows=num_src_rows, rank=rank, world_size=world_size, ) @@ -59,7 +59,7 @@ def get_cache( indices=dest_indices, edge_placement=edge_location, edge_dest_ranks=dest_data_mappings, - num_output_rows=num_output_rows, + num_output_rows=num_dest_rows, rank=rank, world_size=world_size, ) @@ -73,7 +73,7 @@ def get_cache( indices=dest_indices, edge_placement=edge_location, edge_dest_ranks=dest_data_mappings, - num_input_rows=num_output_rows, + num_input_rows=num_dest_rows, rank=rank, world_size=world_size, ) @@ -163,8 +163,8 @@ def main(dataset): edge_location=rank_mapping[0], src_data_mappings=rank_mapping[0], dest_data_mappings=rank_mapping[1], - num_input_rows=xs[edge_type[0]].shape[1], - num_output_rows=xs[edge_type[1]].shape[1], + num_src_rows=xs[edge_type[0]].shape[1], + num_dest_rows=xs[edge_type[1]].shape[1], ) rel += 1 diff --git a/experiments/OGB-LSC/RGAT.py b/experiments/OGB-LSC/RGAT.py index c183eff..9f7b8a2 100644 --- a/experiments/OGB-LSC/RGAT.py +++ b/experiments/OGB-LSC/RGAT.py @@ -215,7 +215,7 @@ def __init__( for _ in range(num_relations): relation_specific_convs.append( CommAwareGAT( - hidden_channels * heads, + hidden_channels, hidden_channels, heads=heads, bias=True, @@ -323,8 +323,7 @@ def forward(self, xs, adjts, edge_types, rank_mappings): for j, (edge_index, edge_type, rank_mapping) in enumerate( zip(adjts, edge_types, rank_mappings) ): - if j != 3: - continue + src_edge_type, dst_edge_type = edge_type if self.use_cache: caches = get_cache( @@ -341,8 +340,8 @@ def forward(self, xs, adjts, edge_types, rank_mappings): edge_location=rank_mapping[0], src_data_mappings=rank_mapping[0], dest_data_mappings=rank_mapping[1], - num_input_rows=outs[src_edge_type].size(1), - num_output_rows=outs[dst_edge_type].size(1), + num_src_rows=outs[src_edge_type].size(1), + num_dest_rows=outs[dst_edge_type].size(1), ) src_gather_cache, dest_scatter_cache, dest_gather_cache = caches else: diff --git a/experiments/OGB-LSC/synthetic/synthetic_dataset.py b/experiments/OGB-LSC/synthetic/synthetic_dataset.py index 7d7b520..5ecdb8b 100644 --- a/experiments/OGB-LSC/synthetic/synthetic_dataset.py +++ b/experiments/OGB-LSC/synthetic/synthetic_dataset.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: (Apache-2.0) from DGraph.Communicator import Communicator import torch +from typing import Tuple torch.random.manual_seed(0) @@ -182,6 +183,7 @@ def __init__( institution_vertices_cur_rank = int( (self.institution_vertex_rank_mapping == self.rank).sum() ) + self.paper_vertices_cur_rank = paper_vertices_cur_rank self.paper_features = torch.randn( (paper_vertices_cur_rank, num_features), dtype=torch.float32 @@ -193,23 +195,83 @@ def __init__( (institution_vertices_cur_rank, num_features), dtype=torch.float32 ) - def get_validation_mask(self): - # Only papers are classified - validation_vertices_mappings = self.paper_vertex_rank_mapping[self.val_mask] - num_validation_vertices = (validation_vertices_mappings == self.rank).sum() - if num_validation_vertices > 0: - return self.val_mask[validation_vertices_mappings == self.rank] + # def get_validation_mask(self): + # # Only papers are classified + # validation_vertices_mappings = self.paper_vertex_rank_mapping[self.val_mask] + # validation_vertices_mappings = validation_vertices_mappings.to( + # self.val_mask.device + # ) + # num_validation_vertices = (validation_vertices_mappings == self.rank).sum() + # if num_validation_vertices > 0: + # return ( + # self.val_mask[validation_vertices_mappings == self.rank] + # % self.paper_vertices_cur_rank + # ) + # else: + # return torch.tensor([], dtype=torch.long) + + # def get_test_mask(self): + # # Only papers are classified + + # paper_vertices = self.paper_vertex_rank_mapping == self.rank + # paper_vertices = paper_vertices.to(self.test_mask.device) + # num_test_vertices = (paper_vertices[self.test_mask] == self.rank).sum() + # if num_test_vertices > 0: + # return ( + # self.test_mask[paper_vertices[self.test_mask] == self.rank] + # % self.paper_vertices_cur_rank + # ) + # else: + # return torch.tensor([], dtype=torch.long) + + # def get_train_mask(self): + # # Only papers are classified + # paper_vertices = self.paper_vertex_rank_mapping == self.rank + # paper_vertices = paper_vertices.to(self.train_mask.device) + # num_train_vertices = (paper_vertices[self.train_mask] == self.rank).sum() + # if num_train_vertices > 0: + # return ( + # self.train_mask[paper_vertices[self.train_mask] == self.rank] + # % self.paper_vertices_cur_rank + # ) + # else: + # return torch.tensor([], dtype=torch.long) + + def get_vertex_rank_mask(self, mask_type: str) -> Tuple[torch.Tensor, torch.Tensor]: + if mask_type == "train": + global_int_mask = self.train_mask + elif mask_type == "val": + global_int_mask = self.val_mask + elif mask_type == "test": + global_int_mask = self.test_mask else: - return torch.tensor([], dtype=torch.long) - - def get_test_mask(self): - # Only papers are classified - paper_vertices = self.paper_vertex_rank_mapping == self.rank - num_test_vertices = (paper_vertices[self.test_mask] == self.rank).sum() - if num_test_vertices > 0: - return self.test_mask[paper_vertices[self.test_mask] == self.rank] - else: - return torch.tensor([], dtype=torch.long) + raise ValueError(f"Invalid mask type: {mask_type}") + + # Get the ranks of the vertices + # paper_vertex_rank_mapping -> vector of size num_papers, + # where each entry is the location / rank of the vertex + paper_vertex_rank_mapping = self.paper_vertex_rank_mapping.to( + global_int_mask.device + ) + vertex_ranks = paper_vertex_rank_mapping[global_int_mask] + # vertex_ranks is location of the vertices in the global_int_mask + vertex_ranks_mask = vertex_ranks == self.rank + return global_int_mask, vertex_ranks_mask + + def get_mask(self, mask_type: str) -> torch.Tensor: + + global_int_mask, vertex_ranks_mask = self.get_vertex_rank_mask(mask_type) + local_int_mask = global_int_mask[vertex_ranks_mask] + local_int_mask = local_int_mask % self.paper_vertices_cur_rank + return local_int_mask + + def get_target(self, _type: str) -> torch.Tensor: + global_int_mask, vertex_ranks_mask = self.get_vertex_rank_mask(_type) + + global_training_targets = self.y[:, global_int_mask.squeeze(0)] + local_training_targets = global_training_targets[vertex_ranks_mask] + + return local_training_targets def __len__(self): return 0 @@ -249,6 +311,7 @@ def to(self, device): self.paper_2_paper_edges = self.paper_2_paper_edges.to(device) self.author_2_paper_edges = self.author_2_paper_edges.to(device) self.author_2_institution_edges = self.author_2_institution_edges.to(device) + return self def __getitem__(self, idx): From 746693d92d3713aea9327946ac8646a2d3252ff3 Mon Sep 17 00:00:00 2001 From: Keita Iwabuchi Date: Fri, 17 Oct 2025 22:55:52 -0700 Subject: [PATCH 22/48] (OGB-LSC) Bugfix for geenrating _dest_scatter_cache --- experiments/OGB-LSC/CacheGenerator.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/experiments/OGB-LSC/CacheGenerator.py b/experiments/OGB-LSC/CacheGenerator.py index 35f3566..e99ad06 100644 --- a/experiments/OGB-LSC/CacheGenerator.py +++ b/experiments/OGB-LSC/CacheGenerator.py @@ -56,10 +56,10 @@ def get_cache( if dest_scatter_cache is None: _dest_scatter_cache = NCCLScatterCacheGenerator( - indices=dest_indices, + indices=src_indices, edge_placement=edge_location, - edge_dest_ranks=dest_data_mappings, - num_output_rows=num_dest_rows, + edge_dest_ranks=src_data_mappings, + num_output_rows=num_src_rows, rank=rank, world_size=world_size, ) @@ -181,6 +181,4 @@ def main(dataset): print(synthetic_scatter_cache_1.scatter_recv_local_placement) print(synthetic_scatter_cache_0.scatter_recv_local_placement) - breakpoint() - Fire(main) From 296fa2f93527379dd1e4ab42a65fce06d999397a Mon Sep 17 00:00:00 2001 From: Keita Iwabuchi Date: Fri, 17 Oct 2025 22:58:09 -0700 Subject: [PATCH 23/48] (OGB-LSC) Workaround for DDP's unsed parameter error --- experiments/OGB-LSC/RGAT.py | 9 +++++++++ experiments/OGB-LSC/Trainer.py | 9 ++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/experiments/OGB-LSC/RGAT.py b/experiments/OGB-LSC/RGAT.py index 9f7b8a2..82053ca 100644 --- a/experiments/OGB-LSC/RGAT.py +++ b/experiments/OGB-LSC/RGAT.py @@ -19,6 +19,7 @@ import os.path as osp from CacheGenerator import get_cache import sys +import os class ConvLayer(nn.Module): @@ -383,4 +384,12 @@ def forward(self, xs, adjts, edge_types, rank_mappings): for feat in range(len(outs)) ] + dummy_prameters_use = bool(int(os.getenv("RGAT_DUMMY_ALL_PARAMS_USE", "0"))) + if dummy_prameters_use: + # Dummy operation to touch all outs to avoid DDP's 'unused parameters' + dummy = torch.zeros(1, device=outs[0].device, dtype=outs[0].dtype) + for t in outs: + dummy = dummy + (t[0].sum() * 0.0) # zero-valued scalar that depends on t + outs[0][0] = outs[0][0] + dummy + return self.mlp(outs[0]) diff --git a/experiments/OGB-LSC/Trainer.py b/experiments/OGB-LSC/Trainer.py index 7c70399..65b54dd 100644 --- a/experiments/OGB-LSC/Trainer.py +++ b/experiments/OGB-LSC/Trainer.py @@ -17,6 +17,7 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from distributed_layers import GetGlobalVal +import os class Trainer: @@ -43,7 +44,13 @@ def __init__(self, dataset, comm): comm=comm, dropout=self.model_config.dropout, ).to(self.device) - self.model = DDP(self.model, device_ids=[rank % num_gpus]) + # Enable unused-parameter detection only if requested (reduces sync errors with moderate overhead) + ddp_find_unused = bool(int(os.getenv("RGAT_DDP_FIND_UNUSED", "0"))) + self.model = DDP( + self.model, + device_ids=[rank % num_gpus], + find_unused_parameters=ddp_find_unused, + ) self.optimizer = torch.optim.Adam( self.model.parameters(), lr=self.training_config.lr, weight_decay=5e-4 ) From fa585a32246c5cee73bd4450a59c5e3e7f1eec41 Mon Sep 17 00:00:00 2001 From: Keita Iwabuchi Date: Fri, 17 Oct 2025 23:00:35 -0700 Subject: [PATCH 24/48] (OGB-LSC) Some performance optimizations --- DGraph/distributed/nccl/_nccl_cache.py | 1 - experiments/OGB-LSC/Trainer.py | 13 ++++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/DGraph/distributed/nccl/_nccl_cache.py b/DGraph/distributed/nccl/_nccl_cache.py index 0774e58..6247aa3 100644 --- a/DGraph/distributed/nccl/_nccl_cache.py +++ b/DGraph/distributed/nccl/_nccl_cache.py @@ -252,7 +252,6 @@ def NCCLGatherCacheGenerator( indices, edge_placement, edge_dest_ranks, num_input_rows, rank, world_size ) ) - breakpoint() local_slice_mask = edge_placement == rank local_mask = edge_placement[local_slice_mask] diff --git a/experiments/OGB-LSC/Trainer.py b/experiments/OGB-LSC/Trainer.py index 65b54dd..b0744f3 100644 --- a/experiments/OGB-LSC/Trainer.py +++ b/experiments/OGB-LSC/Trainer.py @@ -64,11 +64,16 @@ def train(self): xs, edge_index, edge_type, rank_mapping = self.dataset[0] + # Fetch once; masks/targets are static across epochs + train_mask = self.dataset.get_mask("train") + target = self.dataset.get_target("train") + for epoch in range(1, self.training_config.epochs + 1): + # zero grads before forward to avoid dangling reduction state + self.optimizer.zero_grad(set_to_none=True) + out = self.model(xs, edge_index, edge_type, rank_mapping) - train_mask = self.dataset.get_mask("train") local_train_vertices = out[:, train_mask, :].squeeze(0) - target = self.dataset.get_target("train") loss = torch.nn.functional.cross_entropy( local_train_vertices, target, reduction="sum" @@ -76,9 +81,11 @@ def train(self): local_num_targets = target.size(0) global_num_targets = GetGlobalVal(local_num_targets) loss = loss / global_num_targets # Average the loss - self.model.zero_grad() + loss.backward() self.optimizer.step() + if self.comm.get_rank() == 0: + print(f"Epoch {epoch:03d} | loss {loss.item():.4f}") return loss.item() @torch.no_grad() From 483bc48c606fe526e2c536e8d677b156797bc686 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Wed, 29 Oct 2025 14:29:40 -0700 Subject: [PATCH 25/48] Fix DGraph Mag240M dataset __getitem__ method --- experiments/OGB-LSC/mag240m/DGraph_MAG240M.py | 54 ++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py b/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py index 35b4498..333235d 100644 --- a/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py +++ b/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py @@ -18,6 +18,7 @@ import numpy as np from tqdm import tqdm import os.path as osp +from DGraph.Communicator import Communicator def get_col_slice(x, start_row_idx, end_row_idx, start_col_idx, end_col_idx): @@ -112,7 +113,7 @@ def _generate_features_from_paper_features( class DGraph_MAG240M: def __init__( self, - comm, + comm: Communicator, data_dir: str = "data/MAG240M", paper_rank_mappings: Optional[torch.Tensor] = None, author_rank_mappings: Optional[torch.Tensor] = None, @@ -126,6 +127,7 @@ def __init__( self.num_authors = self.dataset.num_authors self.num_institutions = self.dataset.num_institutions self.num_classes = self.dataset.num_classes + self.paper_rank_mappings = ( paper_rank_mappings if paper_rank_mappings is not None @@ -159,6 +161,8 @@ def __init__( # paper -> paper self.process_feature_data() + self.paper_features = + def process_feature_data(self): dataset = self.dataset # This function emulates the data processing step here: @@ -219,6 +223,54 @@ def add_batch_dimension(self): def to(self, device): return self + def __len__(self): + return 1 + + def __getitem__(self, idx): + # There are 5 relations: + # paper -> paper + # paper -> author + # author -> paper + # author -> institution + # institution -> author + edge_index = [ + self.paper_2_paper_edges, + self.author_2_paper_edges, + self.author_2_paper_edges.flip(self.author_2_paper_edges.dim() - 2), + self.author_2_institution_edges, + self.author_2_institution_edges.flip( + self.author_2_institution_edges.dim() - 2 + ), + ] + # Locations of the edges + rank_mappings = [ + [self.paper_src_data_mappings, self.paper_dest_data_mappings], + [ + self.author_2_paper_src_data_mappings, + self.author_2_paper_dest_data_mappings, + ], + [ + self.author_2_paper_dest_data_mappings, + self.author_2_paper_src_data_mappings, + ], + [ + self.author_2_institution_src_data_mappings, + self.author_2_institution_dest_data_mappings, + ], + [ + self.author_2_institution_dest_data_mappings, + self.author_2_institution_src_data_mappings, + ], + ] + edge_type = [(0, 0), (1, 0), (0, 1), (1, 2), (2, 1)] + features = [ + self.paper_features, + self.author_features, + self.institution_features, + ] + + return (features, edge_index, edge_type, rank_mappings) + if __name__ == "__main__": import fire From 7930024482d2a2c0fac6fce61e632ae7bd11cb0b Mon Sep 17 00:00:00 2001 From: Keita Iwabuchi Date: Fri, 24 Oct 2025 08:08:37 -0700 Subject: [PATCH 26/48] Remove debug messages --- DGraph/distributed/nccl/NCCLBackendEngine.py | 18 ------ DGraph/distributed/nccl/alltoallv_impl.py | 7 --- experiments/OGB-LSC/RGAT.py | 63 +------------------- 3 files changed, 1 insertion(+), 87 deletions(-) diff --git a/DGraph/distributed/nccl/NCCLBackendEngine.py b/DGraph/distributed/nccl/NCCLBackendEngine.py index c10c1ee..299dd02 100644 --- a/DGraph/distributed/nccl/NCCLBackendEngine.py +++ b/DGraph/distributed/nccl/NCCLBackendEngine.py @@ -116,14 +116,6 @@ def forward( needs_comm = (local_recv_tensor != rank).any() - # For debugging: Delete later - dist.barrier() - for i in range(world_size): - if i == rank: - print(f"Rank {rank} reached local gather") - dist.barrier() - dist.barrier() - recv_tensor = OptimizedRankLocalMaskedGather( local_send_tensor, local_indices, @@ -131,13 +123,6 @@ def forward( recv_tensor, rank, ) - # For debugging: Delete later - dist.barrier() - for i in range(world_size): - if i == rank: - print(f"Rank {rank} finished local gather") - dist.barrier() - dist.barrier() if needs_comm: @@ -488,9 +473,6 @@ def backward(ctx, grad_output): cache=cache, ) - # if rank == 0: - # breakpoint() - # dist.barrier() # NOTE: even if the inputs are non-tensors, the number of backward outputs # must be the same as the number of inputs. send_tensor_grad = recv_tensor diff --git a/DGraph/distributed/nccl/alltoallv_impl.py b/DGraph/distributed/nccl/alltoallv_impl.py index 939343d..060c390 100644 --- a/DGraph/distributed/nccl/alltoallv_impl.py +++ b/DGraph/distributed/nccl/alltoallv_impl.py @@ -18,13 +18,6 @@ def _nccl_alltoall_v( num_features = local_send_tensor.shape[2] num_src_rows = local_send_tensor.shape[1] - # For debugging: Delete later - dist.barrier() - for i in range(world_size): - if i == rank: - print(f"Rank {rank} starting comm") - dist.barrier() - recv_buffer_dict = {} if cache is None: diff --git a/experiments/OGB-LSC/RGAT.py b/experiments/OGB-LSC/RGAT.py index 82053ca..4029475 100644 --- a/experiments/OGB-LSC/RGAT.py +++ b/experiments/OGB-LSC/RGAT.py @@ -80,66 +80,25 @@ def forward( _src_indices = edge_index[:, 0, :] _dst_indices = edge_index[:, 1, :] - self.comm.barrier() _src_rank_mappings = torch.cat( [rank_mapping[0].unsqueeze(0), rank_mapping[0].unsqueeze(0)], dim=0 ) - self.comm.barrier() _dst_rank_mappings = torch.cat( [rank_mapping[0].unsqueeze(0), rank_mapping[1].unsqueeze(0)], dim=0 ) - self.comm.barrier() - if self.comm.get_rank() == 0: - print("starting gather") - self.comm.barrier() - - if self.comm.get_rank() == 0: - print(f"h shape: {h.shape}") - print(f"h_j shape: {h_j.shape}") - # breakpoint() - self.comm.barrier() - - # sys.exit(0) # --- IGNORE --- - h_i = self.comm.gather( h, _dst_indices, _dst_rank_mappings, cache=dest_gather_cache ) - if self.comm.get_rank() == 0: - print("finished computing _dst_rank_mappings") - self.comm.barrier() - - for i in range(self.comm.get_world_size()): - self.comm.barrier() - if self.comm.get_rank() == i: - print(f"Rank {i} h_i shape: {h_i.shape}") - self.comm.barrier() - - self.comm.barrier() - if self.comm.get_rank() == 0: - print("finished gathering h_i") - self.comm.barrier() - h_j = self.comm.gather( h_j, _src_indices, _src_rank_mappings, cache=src_gather_cache ) - self.comm.barrier() - if self.comm.get_rank() == 0: - print("finished gathering h_j") - self.comm.barrier() messages = torch.cat([h_i, h_j], dim=-1) edge_scores = self.leaky_relu(self.project_message(messages)) numerator = torch.exp(edge_scores) - if self.comm.get_rank() == 0: - print(f"Numerator shape: {numerator.shape}") - - self.comm.barrier() - if self.comm.get_rank() == 0: - print("starting scatter") - self.comm.barrier() denominator = self.comm.scatter( numerator, _dst_indices, @@ -147,14 +106,10 @@ def forward( h.size(1), cache=dest_scatter_cache, ) - self.comm.barrier() - if self.comm.get_rank() == 0: - print("finished scatter") - self.comm.barrier() + denominator = self.comm.gather( denominator, _src_indices, _src_rank_mappings, cache=dest_gather_cache ) - self.comm.barrier() alpha_ij = numerator / (denominator + 1e-16) attention_messages = h_j * alpha_ij @@ -350,18 +305,6 @@ def forward(self, xs, adjts, edge_types, rank_mappings): dest_scatter_cache = None dest_gather_cache = None - self.comm.barrier() - if self.comm.get_rank() == 0: - print( - f"Layer {i} Relation {j} started on rank {self.comm.get_rank()}" - ) - print( - f"Edge index shape: {edge_index.shape}" - f" Edge type: {edge_type}", - f" src tensor shape: {outs[src_edge_type].shape}", - f" dst tensor shape: {outs[dst_edge_type].shape}", - ) - self.comm.barrier() temp_outs[dst_edge_type] += self.layers[i][j]( outs[dst_edge_type], edge_index, @@ -371,10 +314,6 @@ def forward(self, xs, adjts, edge_types, rank_mappings): dest_gather_cache=dest_gather_cache, dest_scatter_cache=dest_scatter_cache, ) - self.comm.barrier() - if self.comm.get_rank() == 0: - print(f"Layer {i} Relation {j} done on rank {self.comm.get_rank()}") - self.comm.barrier() outs = [ self.bn_layers[i](temp_outs[feat]) for feat in range(len(temp_outs)) ] From 09735a7663ffe16760ec4044764486f07a721031 Mon Sep 17 00:00:00 2001 From: Keita Iwabuchi Date: Fri, 24 Oct 2025 23:45:50 -0700 Subject: [PATCH 27/48] (OGB-LSC) Bugfix for mag240m dataset --- experiments/OGB-LSC/Trainer.py | 18 +- experiments/OGB-LSC/config.py | 9 +- experiments/OGB-LSC/mag240m/DGraph_MAG240M.py | 282 ++++++++++++++---- experiments/OGB-LSC/main.py | 24 +- .../OGB-LSC/synthetic/synthetic_dataset.py | 18 +- 5 files changed, 279 insertions(+), 72 deletions(-) diff --git a/experiments/OGB-LSC/Trainer.py b/experiments/OGB-LSC/Trainer.py index b0744f3..bd6cf3c 100644 --- a/experiments/OGB-LSC/Trainer.py +++ b/experiments/OGB-LSC/Trainer.py @@ -35,10 +35,10 @@ def __init__(self, dataset, comm): torch.cuda.set_device(rank % num_gpus) self.device = torch.device("cuda") self.model = CommAwareRGAT( - in_channels=self.model_config.num_features, - out_channels=self.model_config.num_classes, + in_channels=self.dataset.num_features, + out_channels=self.dataset.num_classes, + num_relations=self.dataset.num_relations, hidden_channels=self.model_config.hidden_channels, - num_relations=self.model_config.num_relations, num_layers=self.model_config.num_layers, heads=self.model_config.heads, comm=comm, @@ -64,6 +64,18 @@ def train(self): xs, edge_index, edge_type, rank_mapping = self.dataset[0] + # Early sanity check: first feature tensor last dim vs configured num_features + configured = self.dataset.num_features + actual = xs[0].size(-1) if isinstance(xs, (list, tuple)) else xs.size(-1) + if ( + configured != actual + and self.comm.get_rank() == 0 + ): + print( + f"[RGAT] Warning: configured in_channels={configured} but feature dim={actual}; " + f"layers will adapt lazily." + ) + # Fetch once; masks/targets are static across epochs train_mask = self.dataset.get_mask("train") target = self.dataset.get_target("train") diff --git a/experiments/OGB-LSC/config.py b/experiments/OGB-LSC/config.py index a811160..49d8164 100644 --- a/experiments/OGB-LSC/config.py +++ b/experiments/OGB-LSC/config.py @@ -17,14 +17,15 @@ @dataclass class ModelConfig: - hidden_channels: int = 16 + hidden_channels: int = 1024 dropout: float = 0.5 num_layers: int = 2 - num_features: int = 16 - num_relations: int = 5 - num_classes: int = 153 heads: int = 4 use_cache: bool = True + # Those numbers are available in the dataset classes (synthetic or mag240m) + # num_features: int = 768 + # num_relations: int = 5 + # num_classes: int = 153 @dataclass diff --git a/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py b/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py index 333235d..565fd6e 100644 --- a/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py +++ b/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py @@ -61,6 +61,16 @@ def get_rank_mappings(num_nodes, world_size, rank): rank_mappings[start:end] = r return rank_mappings +def edge_mapping_from_vertex_mapping(edge_index, src_rank_mappings, dst_rank_mappings): + # directed edges, so edge_index[0] -> edge_index[1] + src_indices = edge_index[0] + dest_indices = edge_index[1] + # We put the edge on the rank where the destination vertex is located + # Since heterogeneous graphs have different rank mappings for different + # vertex types. + src_data_mappings = src_rank_mappings[src_indices] + dest_data_mappings = dst_rank_mappings[dest_indices] + return (src_data_mappings, dest_data_mappings) def get_edge_mappings(src_indices, dst_indices, rank_mappings): edge_mappings = torch.zeros_like(src_indices) @@ -111,6 +121,8 @@ def _generate_features_from_paper_features( class DGraph_MAG240M: + + # data_dir must be the location where all ranks can access def __init__( self, comm: Communicator, @@ -126,8 +138,7 @@ def __init__( self.num_papers = self.dataset.num_papers self.num_authors = self.dataset.num_authors self.num_institutions = self.dataset.num_institutions - self.num_classes = self.dataset.num_classes - + # self.num_classes = self.dataset.num_classes self.paper_rank_mappings = ( paper_rank_mappings if paper_rank_mappings is not None @@ -146,31 +157,113 @@ def __init__( # authors -> paper self.write_mappings = get_edge_mappings( - self.dataset.edge_index("author", "paper")[0], - self.dataset.edge_index("author", "paper")[1], + torch.from_numpy(self.dataset.edge_index("author", "paper")[0]), + torch.from_numpy(self.dataset.edge_index("author", "paper")[1]), self.paper_rank_mappings, ) # author -> institution self.write_mappings_author_institution = get_edge_mappings( - self.dataset.edge_index("author", "institution")[0], - self.dataset.edge_index("author", "institution")[1], + torch.from_numpy(self.dataset.edge_index("author", "institution")[0]), + torch.from_numpy(self.dataset.edge_index("author", "institution")[1]), self.institution_rank_mappings, ) - self.num_features = 768 - # paper -> paper - self.process_feature_data() - self.paper_features = + _vertices = torch.randperm(self.num_papers) + self.train_mask = _vertices[: int(0.7 * self.num_papers)] + self.val_mask = _vertices[int(0.7 * self.num_papers) : int(0.85 * self.num_papers)] + self.test_mask = _vertices[int(0.85 * self.num_papers) :] + + local_papers_mask = self.paper_rank_mappings == self.rank + local_authors_mask = self.author_rank_mappings == self.rank + local_institutions_mask = self.institution_rank_mappings == self.rank + self.num_local_papers = int( + local_papers_mask.sum() + ) + + self.generate_feature_data() + + self.paper_features = torch.from_numpy(self.dataset.paper_feat[local_papers_mask]) + path = self.dataset.dir + self.author_features = torch.from_numpy(np.memmap( + filename=path + "/author_feat.npy", + mode="r", + dtype=np.float16, + shape=(self.num_authors, self.num_features), + )[local_authors_mask]) + self.institution_features = torch.from_numpy(np.memmap( + filename=path + "/institution_feat.npy", + mode="r", + dtype=np.float16, + shape=(self.num_institutions, self.num_features), + )[local_institutions_mask]) + self.y = torch.from_numpy(self.dataset.paper_label) + + self.paper_2_paper_edges = torch.from_numpy(self.dataset.edge_index('paper', 'cites', 'paper')) + ( + paper_2_paper_src_data_mappings, + paper_2_paper_dest_data_mappings, + ) = edge_mapping_from_vertex_mapping( + edge_index=self.paper_2_paper_edges, + src_rank_mappings=self.paper_rank_mappings, + dst_rank_mappings=self.paper_rank_mappings, + ) + self.paper_src_data_mappings = paper_2_paper_src_data_mappings + self.paper_dest_data_mappings = paper_2_paper_dest_data_mappings + + self.author_2_paper_edges = torch.from_numpy(self.dataset.edge_index('author', 'writes', 'paper')) + ( + author_2_paper_src_data_mappings, + author_2_paper_dest_data_mappings, + ) = edge_mapping_from_vertex_mapping( + edge_index=self.author_2_paper_edges, + src_rank_mappings=self.author_rank_mappings, + dst_rank_mappings=self.paper_rank_mappings, + ) + self.author_2_paper_src_data_mappings = author_2_paper_src_data_mappings + self.author_2_paper_dest_data_mappings = author_2_paper_dest_data_mappings + + self.author_2_institution_edges = torch.from_numpy(self.dataset.edge_index('author', 'institution')) + ( + author_2_institution_src_data_mappings, + author_2_institution_dest_data_mappings, + ) = edge_mapping_from_vertex_mapping( + edge_index=self.author_2_institution_edges, + src_rank_mappings=self.author_rank_mappings, + dst_rank_mappings=self.institution_rank_mappings, + ) - def process_feature_data(self): + self.author_2_institution_src_data_mappings = ( + author_2_institution_src_data_mappings + ) + self.author_2_institution_dest_data_mappings = ( + author_2_institution_dest_data_mappings + ) + + @property + def num_features(self) -> int: + # 768 + return self.dataset.num_paper_features + + @property + def num_classes(self) -> int: + # 153 + return self.dataset.num_classes + + @property + def num_relations(self) -> int: + # paper -> paper + # paper -> author + # author -> paper + # author -> institution + # institution -> author + return 5 + + def generate_feature_data(self): dataset = self.dataset - # This function emulates the data processing step here: + # This function emulates the author and institute features generation steps here # https://github.com/snap-stanford/ogb/blob/61e9784ca76edeaa6e259ba0f836099608ff0586/examples/lsc/mag240m/rgnn.py#L82 - # The above function converts the heterogenous graph to a homogeneous graph - # So we will do the same here - # Generate author features # Mag240M author features are generated from paper features num_authors = dataset.num_authors @@ -178,53 +271,135 @@ def process_feature_data(self): path = dataset.dir paper_feat = dataset.paper_feat + # Only one rank must do this work + if self.rank == 0: + if not osp.exists(path + "/author_feat.npy"): + print("Generating author features") + author_feat = np.memmap( + filename=path + "/author_feat.npy", + mode="w+", + dtype=np.float16, + shape=(num_authors, self.num_features), + ) + _generate_features_from_paper_features( + out=author_feat, + num_nodes=num_authors, + num_papers=num_papers, + paper_feat=paper_feat, + edge_index=dataset.edge_index("author", "paper"), + num_features=self.num_features, + ) + + if not osp.exists(path + "/institution_feat.npy"): + print("Generating institution features") + # Generate institution features + num_institutions = dataset.num_institutions + institution_feat = np.memmap( + filename=path + "/institution_feat.npy", + mode="w+", + dtype=np.float16, + shape=(num_institutions, self.num_features), + ) + _generate_features_from_paper_features( + out=institution_feat, + num_nodes=num_authors, + num_papers=num_institutions, + paper_feat=paper_feat, + edge_index=dataset.edge_index("author", "institution"), + num_features=self.num_features, + ) + self.comm.barrier() + + # Make sure all ranks can see the generated files if not osp.exists(path + "/author_feat.npy"): - print("Generating author features") - author_feat = np.memmap( - filename=path + "/author_feat.npy", - mode="w+", - dtype=np.float16, - shape=(num_authors, self.num_features), - ) - - _generate_features_from_paper_features( - out=author_feat, - num_nodes=num_authors, - num_papers=num_papers, - paper_feat=paper_feat, - edge_index=dataset.edge_index("author", "paper"), - num_features=self.num_features, - ) - + raise FileNotFoundError("author_feat.npy not found") if not osp.exists(path + "/institution_feat.npy"): - print("Generating institution features") - # Generate institution features - num_institutions = dataset.num_institutions - institution_feat = np.memmap( - filename=path + "/institution_feat.npy", - mode="w+", - dtype=np.float16, - shape=(num_institutions, self.num_features), - ) - print("Generating institution features") - _generate_features_from_paper_features( - out=institution_feat, - num_nodes=num_institutions, - num_papers=num_papers, - paper_feat=paper_feat, - edge_index=dataset.edge_index("author", "institution"), - num_features=self.num_features, - ) + raise FileNotFoundError("institution_feat.npy not found") + self.comm.barrier() + print("Data processing complete") + # Same as synthetic? + def get_vertex_rank_mask(self, mask_type: str) -> Tuple[torch.Tensor, torch.Tensor]: + if mask_type == "train": + global_int_mask = self.train_mask + elif mask_type == "val": + global_int_mask = self.val_mask + elif mask_type == "test": + global_int_mask = self.test_mask + else: + raise ValueError(f"Invalid mask type: {mask_type}") + + # Get the ranks of the vertices + # paper_vertex_rank_mapping -> vector of size num_papers, + # where each entry is the location / rank of the vertex + paper_rank_mappings = self.paper_rank_mappings.to( + global_int_mask.device + ) + vertex_ranks = paper_rank_mappings[global_int_mask] + # vertex_ranks is location of the vertices in the global_int_mask + vertex_ranks_mask = vertex_ranks == self.rank + return global_int_mask, vertex_ranks_mask + + # Same as synthetic? + def get_mask(self, mask_type: str) -> torch.Tensor: + + global_int_mask, vertex_ranks_mask = self.get_vertex_rank_mask(mask_type) + local_int_mask = global_int_mask[vertex_ranks_mask] + local_int_mask = local_int_mask % self.num_local_papers + return local_int_mask + + # Same as synthetic? + def get_target(self, _type: str) -> torch.Tensor: + global_int_mask, vertex_ranks_mask = self.get_vertex_rank_mask(_type) + + global_training_targets = self.y[:, global_int_mask.squeeze(0)] + local_training_targets = global_training_targets[vertex_ranks_mask] + + return local_training_targets + + def __len__(self): + return 0 + + # Same as synthetic? def add_batch_dimension(self): + """Add a batch dimension to all tensors. This is particularly useful + because we only have one graph and DGraph is built to handle batches of graphs. + We want to do this here because this allows us to avoid copying the data + and requiring a data loader. + """ + self.paper_features = self.paper_features.unsqueeze(0) + self.author_features = self.author_features.unsqueeze(0) + self.institution_features = self.institution_features.unsqueeze(0) + self.y = self.y.unsqueeze(0) + self.train_mask = self.train_mask.unsqueeze(0) + self.val_mask = self.val_mask.unsqueeze(0) + self.test_mask = self.test_mask.unsqueeze(0) + self.paper_2_paper_edges = self.paper_2_paper_edges.unsqueeze(0) + self.author_2_paper_edges = self.author_2_paper_edges.unsqueeze(0) + self.author_2_institution_edges = self.author_2_institution_edges.unsqueeze(0) + return self + # Same as synthetic? def to(self, device): - return self + """Move the dataset tensors to the specified device. + We want to do this here because this allows us to avoid + copying the data when the different individual tensors are + accessed. + """ + self.paper_features = self.paper_features.to(device, dtype=torch.float32) + self.author_features = self.author_features.to(device, dtype=torch.float32) + self.institution_features = self.institution_features.to(device, dtype=torch.float32) + self.y = self.y.to(device) + self.train_mask = self.train_mask.to(device) + self.val_mask = self.val_mask.to(device) + self.test_mask = self.test_mask.to(device) + self.paper_2_paper_edges = self.paper_2_paper_edges.to(device) + self.author_2_paper_edges = self.author_2_paper_edges.to(device) + self.author_2_institution_edges = self.author_2_institution_edges.to(device) - def __len__(self): - return 1 + return self def __getitem__(self, idx): # There are 5 relations: @@ -268,7 +443,6 @@ def __getitem__(self, idx): self.author_features, self.institution_features, ] - return (features, edge_index, edge_type, rank_mappings) diff --git a/experiments/OGB-LSC/main.py b/experiments/OGB-LSC/main.py index ce4a071..35deab3 100644 --- a/experiments/OGB-LSC/main.py +++ b/experiments/OGB-LSC/main.py @@ -70,14 +70,22 @@ def main( elif dataset == "mag240m": from mag240m.DGraph_MAG240M import DGraph_MAG240M as Dataset - assert osp.exists(paper_rank_mapping_file) - assert osp.exists(author_rank_mapping_file) - assert osp.exists(institution_rank_mapping_file) - paper_rank_mapping = torch.load(paper_rank_mapping_file, weights_only=False) - author_rank_mapping = torch.load(author_rank_mapping_file, weights_only=False) - institution_rank_mapping = torch.load( - institution_rank_mapping_file, weights_only=False - ) + paper_rank_mapping = None + if len(paper_rank_mapping_file) > 0: + assert osp.exists(paper_rank_mapping_file) + paper_rank_mapping = torch.load(paper_rank_mapping_file, weights_only=False) + + author_rank_mapping = None + if len(author_rank_mapping_file) > 0: + assert osp.exists(author_rank_mapping_file) + author_rank_mapping = torch.load(author_rank_mapping_file, weights_only=False) + + institution_rank_mapping = None + if len(institution_rank_mapping_file) > 0: + assert osp.exists(institution_rank_mapping_file) + institution_rank_mapping = torch.load( + institution_rank_mapping_file, weights_only=False + ) graph_dataset = partial( Dataset, diff --git a/experiments/OGB-LSC/synthetic/synthetic_dataset.py b/experiments/OGB-LSC/synthetic/synthetic_dataset.py index 5ecdb8b..05cb399 100644 --- a/experiments/OGB-LSC/synthetic/synthetic_dataset.py +++ b/experiments/OGB-LSC/synthetic/synthetic_dataset.py @@ -97,9 +97,9 @@ def __init__( self.num_papers = num_papers self.num_authors = num_authors self.num_institutions = num_institutions - self.num_classes = num_classes - self.num_features = num_features - self.num_relations = 5 + self._num_classes = num_classes + self._num_features = num_features + self._num_relations = 5 self.comm = comm self.rank = comm.get_rank() self.world_size = comm.get_world_size() @@ -195,6 +195,18 @@ def __init__( (institution_vertices_cur_rank, num_features), dtype=torch.float32 ) + @property + def num_features(self) -> int: + return self._num_features + + @property + def num_classes(self) -> int: + return self._num_classes + + @property + def num_relations(self) -> int: + return self._num_relations + # def get_validation_mask(self): # # Only papers are classified # validation_vertices_mappings = self.paper_vertex_rank_mapping[self.val_mask] From ee1f7255cdf963bdec54bcdb6161ab07cdd78d6d Mon Sep 17 00:00:00 2001 From: Keita Iwabuchi Date: Fri, 7 Nov 2025 17:10:51 -0800 Subject: [PATCH 28/48] (OGB-LSC) Remove debug message --- experiments/OGB-LSC/Trainer.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/experiments/OGB-LSC/Trainer.py b/experiments/OGB-LSC/Trainer.py index bd6cf3c..b3fea9a 100644 --- a/experiments/OGB-LSC/Trainer.py +++ b/experiments/OGB-LSC/Trainer.py @@ -64,18 +64,6 @@ def train(self): xs, edge_index, edge_type, rank_mapping = self.dataset[0] - # Early sanity check: first feature tensor last dim vs configured num_features - configured = self.dataset.num_features - actual = xs[0].size(-1) if isinstance(xs, (list, tuple)) else xs.size(-1) - if ( - configured != actual - and self.comm.get_rank() == 0 - ): - print( - f"[RGAT] Warning: configured in_channels={configured} but feature dim={actual}; " - f"layers will adapt lazily." - ) - # Fetch once; masks/targets are static across epochs train_mask = self.dataset.get_mask("train") target = self.dataset.get_target("train") From f334ad4134773d3b807da2b5a9d1204c47f1aaa2 Mon Sep 17 00:00:00 2001 From: Keita Iwabuchi Date: Mon, 10 Nov 2025 18:29:30 -0800 Subject: [PATCH 29/48] (OGB-LSC) Split dtaset using OGB's function --- experiments/OGB-LSC/mag240m/DGraph_MAG240M.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py b/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py index 565fd6e..c34db6c 100644 --- a/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py +++ b/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py @@ -169,10 +169,9 @@ def __init__( self.institution_rank_mappings, ) - _vertices = torch.randperm(self.num_papers) - self.train_mask = _vertices[: int(0.7 * self.num_papers)] - self.val_mask = _vertices[int(0.7 * self.num_papers) : int(0.85 * self.num_papers)] - self.test_mask = _vertices[int(0.85 * self.num_papers) :] + self.train_mask = self.dataset.get_idx_split('train') + self.val_mask = self.dataset.get_idx_split('valid') + self.test_mask = self.dataset.get_idx_split('test-dev') local_papers_mask = self.paper_rank_mappings == self.rank local_authors_mask = self.author_rank_mappings == self.rank From c6636260e8c6168a3a8c5ee162928e44be9f0aa0 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Thu, 20 Nov 2025 11:45:21 -0800 Subject: [PATCH 30/48] Updated torch bindings implementation for local-scatter-gather --- .gitignore | 24 + .../distributed/csrc/local_data_kernels.cuh | 109 +++++ .../distributed/csrc/torch_local_bindings.cpp | 1 + .../distributed/csrc/torch_local_kernels.cu | 63 ++- DGraph/distributed/nccl/NCCLBackendEngine.py | 455 ------------------ DGraph/distributed/nccl/_nccl_cache.py | 3 + DGraph/distributed/nccl/alltoallv_impl.py | 20 + 7 files changed, 219 insertions(+), 456 deletions(-) diff --git a/.gitignore b/.gitignore index 1d407a7..6ea5777 100644 --- a/.gitignore +++ b/.gitignore @@ -167,3 +167,27 @@ cython_debug/ # Use wildcards as well *~ *.o +# Miscallenous files generated by DGraph data processing +skbuild/ +.vscode/ +logs/ +torchrun_* +*.png +rdvz +*.pt +*.core +*.graph +*.out +*.gz +data_processed +*.zip +cache +graph_cache +*.nsys-rep +*.nsys +*.pth +*.pyc +*.npy +*.npz +*.sqlite +*.csv \ No newline at end of file diff --git a/DGraph/distributed/csrc/local_data_kernels.cuh b/DGraph/distributed/csrc/local_data_kernels.cuh index f12ca4a..e81a9b1 100644 --- a/DGraph/distributed/csrc/local_data_kernels.cuh +++ b/DGraph/distributed/csrc/local_data_kernels.cuh @@ -251,4 +251,113 @@ namespace Local } } } + + /** + * + * Masked Gather Kernel operation that performs the operation: + Y [mask[i]] = X [indices[i]] + + where Y is the output matrix, X is the input matrix, indices is the index matrix, and mask is the mask matrix. + */ + + __global__ void Masked_Scatter_Gather_Kernel( + const float *__restrict__ values, + const long *__restrict__ indices, + const long *__restrict__ mask, + float *__restrict__ output, + const int mini_batch_size, + const int num_indices, + const int num_cols, + const int num_output_rows) + { + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + + const size_t nthreadsx = gridDim.x * blockDim.x; + const size_t nthreadsy = gridDim.y * blockDim.y; + const size_t nthreadsz = gridDim.z * blockDim.z; + + for (size_t mb_i = gidz; mb_i < mini_batch_size; mb_i += nthreadsz) + { + const auto values_offset = mb_i * num_cols * num_indices; + const auto output_offset = mb_i * num_cols * num_output_rows; + const auto ind_offset = mb_i * num_indices; + const auto mask_offset = mb_i * num_indices; + + for (size_t row = gidy; row < num_indices; row += nthreadsy) + { + const auto output_row = mask[mask_offset + row]; + const auto input_row = indices[ind_offset + row]; + + for (size_t col = gidx; col < num_cols; col += nthreadsx) + { + output[output_offset + output_row * num_cols + col] = values[values_offset + input_row * num_cols + col]; + } + } + } + } + + /* + * + Optimized masked scatter gather kernel that performs the operation: + Y [mask[i]] = X [indices[i]] + + This kernel is optimized for the case where the num_cols is a multiple of 4. + + where Y is the output matrix, X is the input matrix, indices is the index matrix, and mask is the mask matrix. + */ + __global__ void Optimized_Masked_Scatter_Gather_Kernel( + const float *__restrict__ values, + const long *__restrict__ indices, + const long *__restrict__ mask, + float *__restrict__ output, + const int mini_batch_size, + const int num_indices, + const int num_cols, + const int num_output_rows) + { + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + + const size_t nthreadsx = gridDim.x * blockDim.x; + const size_t nthreadsy = gridDim.y * blockDim.y; + const size_t nthreadsz = gridDim.z * blockDim.z; + + // Grid-stride loop over mini-batches + for (size_t mb_i = gidz; mb_i < mini_batch_size; mb_i += nthreadsz) + { + const auto values_offset = mb_i * num_cols / 4 * num_indices; + const auto output_offset = mb_i * num_cols / 4 * num_output_rows; + const auto ind_offset = mb_i * num_indices; + const auto mask_offset = mb_i * num_indices; + + // Grid-stride loop over rows + for (size_t row = gidy; row < num_indices; row += nthreadsy) + { + long output_row, input_row; + + if (threadIdx.x == 0) { + output_row = mask[mask_offset + row]; + input_row = indices[ind_offset + row]; + } + + output_row = __shfl_sync(0xFFFFFFFF, output_row, 0); + input_row = __shfl_sync(0xFFFFFFFF, input_row, 0); + + output_row = mask[mask_offset + row]; + input_row = indices[ind_offset + row]; + + size_t col = gidx; + + for (; col < num_cols / 4; col += nthreadsx) + { + const float4 values_vec = reinterpret_cast(values)[values_offset + input_row * num_cols / 4 + col]; + float4& output_vec = reinterpret_cast(output)[output_offset + output_row * num_cols / 4 + col]; + output_vec = values_vec; + } + } + } + } } // namespace Local \ No newline at end of file diff --git a/DGraph/distributed/csrc/torch_local_bindings.cpp b/DGraph/distributed/csrc/torch_local_bindings.cpp index a91f516..6701e6a 100644 --- a/DGraph/distributed/csrc/torch_local_bindings.cpp +++ b/DGraph/distributed/csrc/torch_local_bindings.cpp @@ -21,4 +21,5 @@ PYBIND11_MODULE(torch_local, m) { m.def("local_masked_gather", &local_masked_gather, "Masked Gather"); m.def("local_masked_scatter", &local_masked_scatter, "Masked Scatter"); + m.def("local_masked_scatter_gather", &local_masked_scatter_gather, "Masked Scatter Gather"); } diff --git a/DGraph/distributed/csrc/torch_local_kernels.cu b/DGraph/distributed/csrc/torch_local_kernels.cu index b70bf36..a91593b 100644 --- a/DGraph/distributed/csrc/torch_local_kernels.cu +++ b/DGraph/distributed/csrc/torch_local_kernels.cu @@ -114,4 +114,65 @@ torch::Tensor local_masked_scatter(torch::Tensor input, rank); CUDACHECK(cudaGetLastError()); return output; -} \ No newline at end of file +} + +torch::Tensor local_masked_scatter_gather(torch::Tensor input, + torch::Tensor indices, + torch::Tensor mask, + torch::Tensor output, + const int num_batches, + const int num_values_rows, + const int num_cols, + const int num_output_rows, + const int rank) + { + CHECK_INPUT(input); + CHECK_INPUT(indices); + CHECK_INPUT(mask); + CHECK_INPUT(output); + + const float *input_ptr = input.data_ptr(); + const long *indices_ptr = indices.data_ptr(); + const float *mask_ptr = mask.data_ptr(); + float *output_ptr = output.data_ptr(); + + dim3 block_dims, grid_dims; + block_dims.x = 32; + block_dims.y = 32; + block_dims.z = 1; + + const auto num_grids_needed = (num_output_rows + block_dims.y - 1) / block_dims.y; + const auto num_col_grids_needed = (num_cols + block_dims.x - 1) / block_dims.x; + grid_dims.x = num_col_grids_needed < 65535 ? num_col_grids_needed : 65535; + grid_dims.y = num_grids_needed < 65535 ? num_grids_needed : 65535; + grid_dims.z = 1; + + at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(input.device().index()); + + if (num_cols % 4 != 0) + { + Local::Masked_Scatter_Gather_Kernel<<>>(input_ptr, + indices_ptr, + mask_ptr, + output_ptr, + num_batches, + num_values_rows, + num_cols, + num_output_rows, + rank); + } + else + { + Local::Optimized_Masked_Scatter_Gather_Kernel<<>>(input_ptr, + indices_ptr, + mask_ptr, + output_ptr, + num_batches, + num_values_rows, + num_cols, + num_output_rows, + rank); + } + CUDACHECK(cudaGetLastError()); + return output; + } \ No newline at end of file diff --git a/DGraph/distributed/nccl/NCCLBackendEngine.py b/DGraph/distributed/nccl/NCCLBackendEngine.py index 299dd02..beb08bf 100644 --- a/DGraph/distributed/nccl/NCCLBackendEngine.py +++ b/DGraph/distributed/nccl/NCCLBackendEngine.py @@ -38,462 +38,7 @@ TIMINGS = {"Gather_Index_Forward": [], "Gather_Forward_Local": []} -class GatherFunction(Function): - @staticmethod - def forward( - ctx, - local_send_tensor: torch.Tensor, - indices: torch.LongTensor, - # vertex_ranks: torch.Tensor, - edge_rank_loc: torch.Tensor, - edge_dest_ranks: torch.Tensor, - rank: int, - world_size: int, - cache: Optional[NCCLGatherCache] = None, - ): - num_local_input_rows = local_send_tensor.shape[1] - - if cache is not None: - # We have a cache, use it, don't need to save anything - ctx.has_cache = True - ctx.cache = cache - # TODO: Should we cash the indices as well? - S.Z - else: - ctx.has_cache = False - - ctx.save_for_backward( - indices, - edge_rank_loc, - edge_dest_ranks, - torch.tensor(num_local_input_rows), - torch.tensor(rank), - torch.tensor(world_size), - ) - - # Since NCCL is two-sided, we need to push from local rank and pull from - # remote rank to get the global gather - - # TODO: One possible optmization is cache all these calculations - # and only do the gather when the cache is invalidated. Essentially - # if we are working with static graphs, the indices and distribution pattern - # will not change and we can cache the communication pattern. - S.Z - - # We can also pre-compute this on the data ingestion side. Might - # be worth looking to some kind of cached communication pattern store - # that can be passed to the communicator. - S.Z - - batch_size = 1 - num_features = local_send_tensor.shape[2] - - if cache is not None: - local_indices = cache.gather_local_indices % local_send_tensor.shape[1] - local_gather_mask = cache.gather_local_comm_mask - needs_comm = cache.gather_needs_comm - local_output_rows = cache.gather_num_output_rows - local_rank_mapping = cache.gather_local_remapped_ranks - recv_tensor = torch.zeros(batch_size, local_output_rows, num_features).to( - local_send_tensor.device - ) - local_recv_tensor = cache.gather_local_recv_mapping - else: - # Get the edges that are local to the rank - - local_slice_mask = edge_rank_loc == rank - - num_local_output_rows = int(local_slice_mask.sum().item()) - - recv_tensor = torch.zeros( - batch_size, num_local_output_rows, num_features - ).to(local_send_tensor.device) - - local_indices_slice = indices[local_slice_mask.unsqueeze(0)] - local_rank_mapping = edge_rank_loc[local_slice_mask] - local_recv_tensor = edge_dest_ranks[local_slice_mask] - - # assert torch.all(local_recv_tensor == rank), local_recv_tensor - - local_indices = local_indices_slice % local_send_tensor.shape[1] - - needs_comm = (local_recv_tensor != rank).any() - - recv_tensor = OptimizedRankLocalMaskedGather( - local_send_tensor, - local_indices, - local_rank_mapping, - recv_tensor, - rank, - ) - - if needs_comm: - - recv_tensor = _nccl_alltoall_v( - local_send_tensor=local_send_tensor, - local_recv_tensor=recv_tensor, - indices=indices, - local_rank_mapping=local_recv_tensor, - edge_rank_loc=edge_rank_loc, - src_rank_loc=edge_dest_ranks, - rank=rank, - world_size=world_size, - cache=cache, - ) - - return recv_tensor - - @staticmethod - def backward(ctx, grad_output): - # We need to switch the send and recv ranks - ( - indices, - recv_ranks, - send_ranks, - # vertices_per_rank, - num_local_input_rows, - rank, - world_size, - ) = ctx.saved_tensors - - if ctx.has_cache: - cache: Optional[NCCLGatherCache] = ctx.cache - else: - cache = None - - num_local_output_rows = num_local_input_rows.item() - rank = rank.item() - world_size = world_size.item() - send_tensor = grad_output - - # Now it's a scatter operation - num_features = send_tensor.shape[-1] - device = send_tensor.device - local_rank_output = torch.zeros(1, num_local_output_rows, num_features).to( - device - ) - - indices = indices.view(-1) - local_slice_mask = recv_ranks == rank - local_indices_slice = indices[local_slice_mask] - local_dest_ranks = send_ranks[local_slice_mask] - - local_rank_output = RankLocalMaskedScatter( - send_tensor, - local_rank_output, - local_indices_slice, - local_dest_ranks, - rank, - ) - - if cache is not None: - local_comm_mask = cache.scatter_local_comm_mask - else: - local_comm_mask = local_dest_ranks != rank - - send_buffer_dict = {} - if torch.any(local_comm_mask): - # These rows need to be sent to other ranks - # First aggregate these into a single buffer - - if cache is not None: - num_remote_rows = cache.scatter_num_remote_rows - remapped_ranks = cache.scatter_local_remapped_ranks - renumbered_indices = cache.scatter_renumbered_indices - receiving_ranks = cache.scatter_remote_send_to_ranks - - else: - - local_comm_indices = local_indices_slice[local_comm_mask] - local_remote_dest_mappings = local_dest_ranks[local_comm_mask] - - renumbered_indices, unique_indices, remapped_ranks = ( - RankLocalRenumberingWithMapping( - local_comm_indices, local_remote_dest_mappings - ) - ) - receiving_ranks = torch.unique(local_dest_ranks[local_comm_mask]) - num_remote_rows = len(unique_indices) - - buffer = torch.zeros(1, num_remote_rows, num_features).to(device) - buffer.scatter_add_( - 1, - renumbered_indices.view(1, -1, 1).expand(1, -1, num_features), - send_tensor[:, local_comm_mask, :], - ) - - for _recv_rank in receiving_ranks: - _recv_indices = remapped_ranks == _recv_rank - send_buffer_dict[_recv_rank.item()] = buffer[:, _recv_indices, :] - - # Now we need to receive the data from the remote ranks - - recv_buffer_dict = {} - - recv_placement = {} - - if cache is not None: - recv_placement = cache.scatter_recv_local_placement - - # Allocate the receive buffers for the communication based on the - # size of the recv_placement indices. - for key, unique_send_indices in recv_placement.items(): - num_elements = unique_send_indices.shape[0] - recv_buffer_dict[key] = torch.zeros(1, num_elements, num_features).to( - device - ) - else: - send_to_rank = send_ranks # Pedantic variable name change - all_comm_mask = send_to_rank != recv_ranks - reciever_mask = send_to_rank == rank - receive_from_remote = all_comm_mask & reciever_mask - - if torch.any(receive_from_remote): - receive_from_ranks = recv_ranks[receive_from_remote] - - for _sender in range(world_size): - if _sender == rank: - continue - if torch.any(receive_from_ranks == _sender): - _send_mask = (recv_ranks == _sender) & receive_from_remote - _send_indices = indices[_send_mask] % num_local_output_rows - # TODO: This is brittle, look into a better way to do this - S.Z - - unique_send_indices = torch.unique(_send_indices) - num_elements = unique_send_indices.shape[0] - recv_buffer_dict[_sender] = torch.zeros( - 1, num_elements, num_features - ).cuda() - recv_placement[_sender] = unique_send_indices - - recv_buffer_dict = _nccl_alltoallv_with_dict( - send_buffer_dict, recv_buffer_dict, rank, world_size - ) - for key, recv_buffer in recv_buffer_dict.items(): - local_rank_output.scatter_add_( - 1, - recv_placement[key].view(1, -1, 1).expand(1, -1, num_features), - recv_buffer, - ) - - send_tensor_grad = local_rank_output - indices_grad = None - send_ranks_grad = None - recv_ranks_grad = None - rank_grad = None - world_size_grad = None - cache_grad = None - - return ( - send_tensor_grad, - indices_grad, - send_ranks_grad, - recv_ranks_grad, - rank_grad, - world_size_grad, - cache_grad, - ) - -class ScatterFunction(Function): - @staticmethod - def forward( - ctx, - send_tensor: torch.Tensor, - indices: torch.Tensor, - edge_src_ranks: torch.Tensor, - edge_dest_ranks: torch.Tensor, - num_local_output_rows: int, - rank: int, - world_size: int, - scatter_cache: Optional[NCCLScatterCache] = None, - ) -> torch.Tensor: - - ctx.save_for_backward( - indices, - edge_src_ranks, - edge_dest_ranks, - torch.tensor(num_local_output_rows), - torch.tensor(rank), - torch.tensor(world_size), - ) - use_cache = scatter_cache is not None - if use_cache: - ctx.scatter_cache = scatter_cache - ctx.has_cache = True - else: - ctx.has_cache = False - - num_features = send_tensor.shape[-1] - device = send_tensor.device - - local_rank_output = torch.zeros(1, num_local_output_rows, num_features).to( - device - ) - - indices = indices.view(-1) - - local_edge_mask = edge_src_ranks == rank - - local_indices_slice = indices[local_edge_mask] - local_dest_ranks = edge_dest_ranks[local_edge_mask] - - local_rank_output = RankLocalMaskedScatter( - send_tensor, - local_rank_output, - local_indices_slice, - local_dest_ranks, - rank, - ) - - if use_cache: - local_comm_mask = scatter_cache.scatter_local_comm_mask - else: - local_comm_mask = local_dest_ranks != rank - - all_comm_mask = edge_src_ranks != edge_dest_ranks - reciever_mask = edge_dest_ranks == rank - receive_from_remote_mask = all_comm_mask & reciever_mask - - send_buffer_dict = {} - - if torch.any(local_comm_mask): - - if use_cache: - num_remote_rows = scatter_cache.scatter_num_remote_rows - remapped_ranks = scatter_cache.scatter_local_remapped_ranks - renumbered_indices = scatter_cache.scatter_local_renumbered_indices - receving_ranks = scatter_cache.scatter_remote_send_to_ranks - - else: - # These rows need to be sent to other ranks - # First aggregate these into a single buffer - local_comm_indices = local_indices_slice[local_comm_mask] - local_remote_dest_mappings = local_dest_ranks[local_comm_mask] - # TODO: This is very slow, look into a better way to do this - S.Z - # Uncached is slow, should look into augmenting torch functions - # to speed this up - S.Z - renumbered_indices, unique_indices, remapped_ranks = ( - RankLocalRenumberingWithMapping( - local_comm_indices, local_remote_dest_mappings - ) - ) - num_remote_rows = len(unique_indices) - receving_ranks = torch.unique(local_dest_ranks[local_comm_mask]) - - buffer = torch.zeros(1, num_remote_rows, num_features).to(device) - buffer.scatter_add_( - 1, - renumbered_indices.view(1, -1, 1).expand(1, -1, num_features), - send_tensor[:, local_comm_mask, :], - ) - - for _recv_rank in receving_ranks: - _recv_indices = remapped_ranks == _recv_rank - send_buffer_dict[_recv_rank.item()] = buffer[:, _recv_indices, :] - - recv_buffer_dict = {} - recv_placement = {} - if use_cache: - recv_placement = scatter_cache.scatter_recv_local_placement - else: - recv_placement = _get_local_unique_recv_placement( - indices, - edge_src_ranks, - receive_from_remote_mask, - num_local_output_rows, - rank, - world_size, - ) - - # Allocate the receive buffers for the communication based on the - # size of the recv_placement indices. - for key, unique_send_indices in recv_placement.items(): - num_elements = unique_send_indices.shape[0] - recv_buffer_dict[key] = torch.zeros(1, num_elements, num_features).to( - device - ) - recv_buffer_dict = _nccl_alltoallv_with_dict( - send_buffer_dict, recv_buffer_dict, rank, world_size - ) - for key, recv_buffer in recv_buffer_dict.items(): - local_rank_output.scatter_add_( - 1, - recv_placement[key].view(1, -1, 1).expand(1, -1, num_features), - recv_buffer, - ) - return local_rank_output - - @staticmethod - def backward(ctx, grad_output): - # We need to switch the send and recv ranks - indices, recv_ranks, send_ranks, num_input_rows, rank, world_size = ( - ctx.saved_tensors - ) - - local_mask = recv_ranks == rank - if ctx.has_cache: - cache: NCCLScatterCache = ctx.scatter_cache - num_local_output_rows = cache.gather_num_output_rows - - else: - rank = int(rank.item()) - world_size = int(world_size.item()) - - indices = indices.view(1, -1) - - # Now it's a gather operation - - num_local_output_rows = int(local_mask.sum().item()) - - batch_size = 1 - num_features = grad_output.shape[2] - - recv_tensor = torch.zeros(batch_size, num_local_output_rows, num_features).to( - grad_output.device - ) - - local_indices_slice = indices[0][local_mask] - local_rank_mapping = send_ranks[local_mask] - - local_indices = local_indices_slice % grad_output.shape[1] - - if len(local_indices_slice) > 0: - - recv_tensor[:, local_rank_mapping == rank, :] = RankLocalMaskedGather( - grad_output, local_indices, local_rank_mapping, rank - ) - - recv_tensor = _nccl_alltoall_v( - local_send_tensor=grad_output, - local_recv_tensor=recv_tensor, - indices=indices, - local_rank_mapping=local_rank_mapping, - edge_rank_loc=send_ranks, - src_rank_loc=recv_ranks, - rank=rank, - world_size=world_size, - cache=cache, - ) - - # NOTE: even if the inputs are non-tensors, the number of backward outputs - # must be the same as the number of inputs. - send_tensor_grad = recv_tensor - indices_grad = None - send_ranks_grad = None - recv_ranks_grad = None - num_local_output_rows_grad = None - rank_grad = None - world_size_grad = None - scatter_cache_grad = None - - return ( - send_tensor_grad, - indices_grad, - send_ranks_grad, - recv_ranks_grad, - num_local_output_rows_grad, - rank_grad, - world_size_grad, - scatter_cache_grad, - ) class NCCLBackendEngine(BackendEngine): diff --git a/DGraph/distributed/nccl/_nccl_cache.py b/DGraph/distributed/nccl/_nccl_cache.py index 6247aa3..b378f1c 100644 --- a/DGraph/distributed/nccl/_nccl_cache.py +++ b/DGraph/distributed/nccl/_nccl_cache.py @@ -64,6 +64,9 @@ class NCCLScatterCache: world_size: int +# @dataclass +# class + def all_to_all_cache_helper( indices, edge_placement, edge_vertex_ranks, num_rows, rank, world_size ): diff --git a/DGraph/distributed/nccl/alltoallv_impl.py b/DGraph/distributed/nccl/alltoallv_impl.py index 060c390..d549cb9 100644 --- a/DGraph/distributed/nccl/alltoallv_impl.py +++ b/DGraph/distributed/nccl/alltoallv_impl.py @@ -159,3 +159,23 @@ def _nccl_alltoallv_with_dict(send_buffer_dict, recv_buffer_dict, rank, world_si for key, recv_buffer in recv_buffer_dict.items(): recv_buffer_dict[key] = recv_buffer.float() return recv_buffer_dict + + +def torch_alltoallv_with_comm_map(contiguous_send_tensor: torch.Tensor, + contiguous_recv_tensor: torch.Tensor, + send_comm_map: torch.Tensor, + recv_comm_map: torch.Tensor, + rank: int, + world_size: int): + assert len(send_comm_map) == world_size, "Send comm map should be of size world_size" + assert len(recv_comm_map) == world_size, "Recv comm map should be of size world_size" + + send_sizes = send_comm_map.tolist() + recv_sizes = recv_comm_map.tolist() + + send_list = list(torch.split(contiguous_send_tensor, send_sizes, dim=1)) + recv_list = list(torch.split(contiguous_recv_tensor, recv_sizes, dim=1)) + + dist.all_to_all(recv_list, send_list) + return recv_list + From 47f6ef32de1242401bb754f215d4546aace045d6 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Thu, 20 Nov 2025 11:46:25 -0800 Subject: [PATCH 31/48] New and improved concise dataplan with efficient connectivity data storage --- DGraph/distributed/nccl/_torch_func_impl.py | 514 ++++++++++++++++++++ 1 file changed, 514 insertions(+) create mode 100644 DGraph/distributed/nccl/_torch_func_impl.py diff --git a/DGraph/distributed/nccl/_torch_func_impl.py b/DGraph/distributed/nccl/_torch_func_impl.py new file mode 100644 index 0000000..f9f45ff --- /dev/null +++ b/DGraph/distributed/nccl/_torch_func_impl.py @@ -0,0 +1,514 @@ +import torch +from typing import Optional +from torch.autograd import Function +import torch.distributed as dist +from dataclasses import dataclass +from DGraph.distributed.nccl._nccl_cache import NCCLGatherCache, NCCLScatterCache +from DGraph.distributed.RankLocalOps import OptimizedRankLocalMaskedGather + +@dataclass +class GatherCommPlan: + """ + Class to store communication plan for distributed gather + + Attributes: + rank (int): Local rank + world_size (int): World size + local_tensor_size (int): Size of the return local tensor after local and global gather + send_comm_vector (torch.Tensor): Communication vector of shape [world_size] of messages to send to each rank + recv_comm_vector (torch.Tensor): Communication vector of shape [world_size] of messages to receive from each rank + local_send_buffer_size (int): Size of the local send buffer + local_recv_buffer_size (int): Size of the local recv buffer + local_output_rows (torch.Tensor): Local rows that don't need communication + local_gather_indices (torch.Tensor): Indices of rows in local input to write to local_output_rows + comm_output_rows (torch.Tensor): Local rows that need data from remote ranks + comm_output_indices (torch.Tensor): Indices of rows in local input to write to comm_output_rows + + """ + + rank: int + world_size: int + local_tensor_size: int + send_comm_vector: torch.Tensor + recv_comm_vector: torch.Tensor + local_send_buffer_size: int + local_recv_buffer_size: int + local_output_rows: torch.Tensor + local_gather_indices: torch.Tensor + comm_output_rows: torch.Tensor + comm_output_indices: torch.Tensor + + +class Cached_Static_GatherFunction(Function): + @staticmethod + def forward( + ctx, + local_send_tensor: torch.Tensor, + comm_plan: GatherCommPlan, + ): + """ + Forward pass for distributed gather + + Args: + ctx (torch.autograd.FunctionContext): Context object + local_send_tensor (torch.Tensor): Local send tensor + comm_plan (GatherCommPlan): Communication plan + """ + assert (len(local_send_tensor.shape) == 3), "Local send tensor must be of shape (batch_size, num_rows, num_features)" + ctx.comm_plan = comm_plan + num_features = local_send_tensor.shape[-1] + num_batches = local_send_tensor.shape[0] + + output_tensor = torch.zeros( + num_batches, comm_plan.local_tensor_size, num_features + ).to(local_send_tensor.device) + + return output_tensor + + + def backward(ctx, grad_output): + """ + Backward pass for distributed gather + + Args: + ctx (torch.autograd.FunctionContext): Context object + grad_output (torch.Tensor): Gradient of the output tensor + """ + comm_plan = ctx.comm_plan + + grad_output = torch.zeros_like(grad_output) + + return grad_output, None + + +class GatherFunction(Function): + @staticmethod + def forward( + ctx, + local_send_tensor: torch.Tensor, + indices: torch.LongTensor, + # vertex_ranks: torch.Tensor, + edge_rank_loc: torch.Tensor, + edge_dest_ranks: torch.Tensor, + rank: int, + world_size: int, + ): + num_local_input_rows = local_send_tensor.shape[1] + + ctx.save_for_backward( + indices, + edge_rank_loc, + edge_dest_ranks, + torch.tensor(num_local_input_rows), + torch.tensor(rank), + torch.tensor(world_size), + ) + + + # Since NCCL is two-sided, we need to push from local rank and pull from + # remote rank to get the global gather + + # TODO: One possible optmization is cache all these calculations + # and only do the gather when the cache is invalidated. Essentially + # if we are working with static graphs, the indices and distribution pattern + # will not change and we can cache the communication pattern. - S.Z + + # We can also pre-compute this on the data ingestion side. Might + # be worth looking to some kind of cached communication pattern store + # that can be passed to the communicator. - S.Z + + batch_size = 1 + num_features = local_send_tensor.shape[2] + + + local_slice_mask = edge_rank_loc == rank + + num_local_output_rows = int(local_slice_mask.sum().item()) + + recv_tensor = torch.zeros( + batch_size, num_local_output_rows, num_features + ).to(local_send_tensor.device) + + local_indices_slice = indices[local_slice_mask.unsqueeze(0)] + local_rank_mapping = edge_rank_loc[local_slice_mask] + local_recv_tensor = edge_dest_ranks[local_slice_mask] + + # assert torch.all(local_recv_tensor == rank), local_recv_tensor + + local_indices = local_indices_slice % local_send_tensor.shape[1] + + needs_comm = (local_recv_tensor != rank).any() + + recv_tensor = OptimizedRankLocalMaskedGather( + local_send_tensor, + local_indices, + local_rank_mapping, + recv_tensor, + rank, + ) + + if needs_comm: + + recv_tensor = _nccl_alltoall_v( + local_send_tensor=local_send_tensor, + local_recv_tensor=recv_tensor, + indices=indices, + local_rank_mapping=local_recv_tensor, + edge_rank_loc=edge_rank_loc, + src_rank_loc=edge_dest_ranks, + rank=rank, + world_size=world_size, + cache=cache, + ) + + return recv_tensor + + @staticmethod + def backward(ctx, grad_output): + # We need to switch the send and recv ranks + ( + indices, + recv_ranks, + send_ranks, + # vertices_per_rank, + num_local_input_rows, + rank, + world_size, + ) = ctx.saved_tensors + + num_local_output_rows = num_local_input_rows.item() + rank = rank.item() + world_size = world_size.item() + send_tensor = grad_output + + # Now it's a scatter operation + num_features = send_tensor.shape[-1] + device = send_tensor.device + local_rank_output = torch.zeros(1, num_local_output_rows, num_features).to( + device + ) + + indices = indices.view(-1) + local_slice_mask = recv_ranks == rank + local_indices_slice = indices[local_slice_mask] + local_dest_ranks = send_ranks[local_slice_mask] + + local_rank_output = RankLocalMaskedScatter( + send_tensor, + local_rank_output, + local_indices_slice, + local_dest_ranks, + rank, + ) + + if cache is not None: + local_comm_mask = cache.scatter_local_comm_mask + else: + local_comm_mask = local_dest_ranks != rank + + send_buffer_dict = {} + if torch.any(local_comm_mask): + # These rows need to be sent to other ranks + # First aggregate these into a single buffer + + if cache is not None: + num_remote_rows = cache.scatter_num_remote_rows + remapped_ranks = cache.scatter_local_remapped_ranks + renumbered_indices = cache.scatter_renumbered_indices + receiving_ranks = cache.scatter_remote_send_to_ranks + + else: + + local_comm_indices = local_indices_slice[local_comm_mask] + local_remote_dest_mappings = local_dest_ranks[local_comm_mask] + + renumbered_indices, unique_indices, remapped_ranks = ( + RankLocalRenumberingWithMapping( + local_comm_indices, local_remote_dest_mappings + ) + ) + receiving_ranks = torch.unique(local_dest_ranks[local_comm_mask]) + num_remote_rows = len(unique_indices) + + buffer = torch.zeros(1, num_remote_rows, num_features).to(device) + buffer.scatter_add_( + 1, + renumbered_indices.view(1, -1, 1).expand(1, -1, num_features), + send_tensor[:, local_comm_mask, :], + ) + + for _recv_rank in receiving_ranks: + _recv_indices = remapped_ranks == _recv_rank + send_buffer_dict[_recv_rank.item()] = buffer[:, _recv_indices, :] + + # Now we need to receive the data from the remote ranks + + recv_buffer_dict = {} + + recv_placement = {} + + if cache is not None: + recv_placement = cache.scatter_recv_local_placement + + # Allocate the receive buffers for the communication based on the + # size of the recv_placement indices. + for key, unique_send_indices in recv_placement.items(): + num_elements = unique_send_indices.shape[0] + recv_buffer_dict[key] = torch.zeros(1, num_elements, num_features).to( + device + ) + else: + send_to_rank = send_ranks # Pedantic variable name change + all_comm_mask = send_to_rank != recv_ranks + reciever_mask = send_to_rank == rank + receive_from_remote = all_comm_mask & reciever_mask + + if torch.any(receive_from_remote): + receive_from_ranks = recv_ranks[receive_from_remote] + + for _sender in range(world_size): + if _sender == rank: + continue + if torch.any(receive_from_ranks == _sender): + _send_mask = (recv_ranks == _sender) & receive_from_remote + _send_indices = indices[_send_mask] % num_local_output_rows + # TODO: This is brittle, look into a better way to do this - S.Z + + unique_send_indices = torch.unique(_send_indices) + num_elements = unique_send_indices.shape[0] + recv_buffer_dict[_sender] = torch.zeros( + 1, num_elements, num_features + ).cuda() + recv_placement[_sender] = unique_send_indices + + recv_buffer_dict = _nccl_alltoallv_with_dict( + send_buffer_dict, recv_buffer_dict, rank, world_size + ) + for key, recv_buffer in recv_buffer_dict.items(): + local_rank_output.scatter_add_( + 1, + recv_placement[key].view(1, -1, 1).expand(1, -1, num_features), + recv_buffer, + ) + + send_tensor_grad = local_rank_output + indices_grad = None + send_ranks_grad = None + recv_ranks_grad = None + rank_grad = None + world_size_grad = None + cache_grad = None + + return ( + send_tensor_grad, + indices_grad, + send_ranks_grad, + recv_ranks_grad, + rank_grad, + world_size_grad, + cache_grad, + ) + + +class ScatterFunction(Function): + @staticmethod + def forward( + ctx, + send_tensor: torch.Tensor, + indices: torch.Tensor, + edge_src_ranks: torch.Tensor, + edge_dest_ranks: torch.Tensor, + num_local_output_rows: int, + rank: int, + world_size: int, + scatter_cache: Optional[NCCLScatterCache] = None, + ) -> torch.Tensor: + + ctx.save_for_backward( + indices, + edge_src_ranks, + edge_dest_ranks, + torch.tensor(num_local_output_rows), + torch.tensor(rank), + torch.tensor(world_size), + ) + use_cache = scatter_cache is not None + if use_cache: + ctx.scatter_cache = scatter_cache + ctx.has_cache = True + else: + ctx.has_cache = False + + num_features = send_tensor.shape[-1] + device = send_tensor.device + + local_rank_output = torch.zeros(1, num_local_output_rows, num_features).to( + device + ) + + indices = indices.view(-1) + + local_edge_mask = edge_src_ranks == rank + + local_indices_slice = indices[local_edge_mask] + local_dest_ranks = edge_dest_ranks[local_edge_mask] + + local_rank_output = RankLocalMaskedScatter( + send_tensor, + local_rank_output, + local_indices_slice, + local_dest_ranks, + rank, + ) + + if use_cache: + local_comm_mask = scatter_cache.scatter_local_comm_mask + else: + local_comm_mask = local_dest_ranks != rank + + all_comm_mask = edge_src_ranks != edge_dest_ranks + reciever_mask = edge_dest_ranks == rank + receive_from_remote_mask = all_comm_mask & reciever_mask + + send_buffer_dict = {} + + if torch.any(local_comm_mask): + + if use_cache: + num_remote_rows = scatter_cache.scatter_num_remote_rows + remapped_ranks = scatter_cache.scatter_local_remapped_ranks + renumbered_indices = scatter_cache.scatter_local_renumbered_indices + receving_ranks = scatter_cache.scatter_remote_send_to_ranks + + else: + # These rows need to be sent to other ranks + # First aggregate these into a single buffer + local_comm_indices = local_indices_slice[local_comm_mask] + local_remote_dest_mappings = local_dest_ranks[local_comm_mask] + # TODO: This is very slow, look into a better way to do this - S.Z + # Uncached is slow, should look into augmenting torch functions + # to speed this up - S.Z + renumbered_indices, unique_indices, remapped_ranks = ( + RankLocalRenumberingWithMapping( + local_comm_indices, local_remote_dest_mappings + ) + ) + num_remote_rows = len(unique_indices) + receving_ranks = torch.unique(local_dest_ranks[local_comm_mask]) + + buffer = torch.zeros(1, num_remote_rows, num_features).to(device) + buffer.scatter_add_( + 1, + renumbered_indices.view(1, -1, 1).expand(1, -1, num_features), + send_tensor[:, local_comm_mask, :], + ) + + for _recv_rank in receving_ranks: + _recv_indices = remapped_ranks == _recv_rank + send_buffer_dict[_recv_rank.item()] = buffer[:, _recv_indices, :] + + recv_buffer_dict = {} + recv_placement = {} + if use_cache: + recv_placement = scatter_cache.scatter_recv_local_placement + else: + recv_placement = _get_local_unique_recv_placement( + indices, + edge_src_ranks, + receive_from_remote_mask, + num_local_output_rows, + rank, + world_size, + ) + + # Allocate the receive buffers for the communication based on the + # size of the recv_placement indices. + for key, unique_send_indices in recv_placement.items(): + num_elements = unique_send_indices.shape[0] + recv_buffer_dict[key] = torch.zeros(1, num_elements, num_features).to( + device + ) + recv_buffer_dict = _nccl_alltoallv_with_dict( + send_buffer_dict, recv_buffer_dict, rank, world_size + ) + for key, recv_buffer in recv_buffer_dict.items(): + local_rank_output.scatter_add_( + 1, + recv_placement[key].view(1, -1, 1).expand(1, -1, num_features), + recv_buffer, + ) + return local_rank_output + + @staticmethod + def backward(ctx, grad_output): + # We need to switch the send and recv ranks + indices, recv_ranks, send_ranks, num_input_rows, rank, world_size = ( + ctx.saved_tensors + ) + + local_mask = recv_ranks == rank + if ctx.has_cache: + cache: NCCLScatterCache = ctx.scatter_cache + num_local_output_rows = cache.gather_num_output_rows + + else: + rank = int(rank.item()) + world_size = int(world_size.item()) + + indices = indices.view(1, -1) + + # Now it's a gather operation + + num_local_output_rows = int(local_mask.sum().item()) + + batch_size = 1 + num_features = grad_output.shape[2] + + recv_tensor = torch.zeros(batch_size, num_local_output_rows, num_features).to( + grad_output.device + ) + + local_indices_slice = indices[0][local_mask] + local_rank_mapping = send_ranks[local_mask] + + local_indices = local_indices_slice % grad_output.shape[1] + + if len(local_indices_slice) > 0: + + recv_tensor[:, local_rank_mapping == rank, :] = RankLocalMaskedGather( + grad_output, local_indices, local_rank_mapping, rank + ) + + recv_tensor = _nccl_alltoall_v( + local_send_tensor=grad_output, + local_recv_tensor=recv_tensor, + indices=indices, + local_rank_mapping=local_rank_mapping, + edge_rank_loc=send_ranks, + src_rank_loc=recv_ranks, + rank=rank, + world_size=world_size, + cache=cache, + ) + + # NOTE: even if the inputs are non-tensors, the number of backward outputs + # must be the same as the number of inputs. + send_tensor_grad = recv_tensor + indices_grad = None + send_ranks_grad = None + recv_ranks_grad = None + num_local_output_rows_grad = None + rank_grad = None + world_size_grad = None + scatter_cache_grad = None + + return ( + send_tensor_grad, + indices_grad, + send_ranks_grad, + recv_ranks_grad, + num_local_output_rows_grad, + rank_grad, + world_size_grad, + scatter_cache_grad, + ) \ No newline at end of file From e4c9b2ed2c2030d14e5f7963dd6ce357d7af4225 Mon Sep 17 00:00:00 2001 From: Shehtab Date: Fri, 12 Dec 2025 14:34:21 -0500 Subject: [PATCH 32/48] Add updated kernels for local scatter-gather + NCCLCommPlan --- DGraph/distributed/Engine.py | 4 +- DGraph/distributed/RankLocalOps.py | 49 +++++- .../distributed/csrc/torch_local_kernels.cu | 9 +- DGraph/distributed/nccl/NCCLBackendEngine.py | 107 +++++++----- DGraph/distributed/nccl/_NCCLCommPlan.py | 142 ++++++++++++++++ DGraph/distributed/nccl/_torch_func_impl.py | 153 ++++++++++++------ 6 files changed, 361 insertions(+), 103 deletions(-) create mode 100644 DGraph/distributed/nccl/_NCCLCommPlan.py diff --git a/DGraph/distributed/Engine.py b/DGraph/distributed/Engine.py index 19e7774..547aada 100644 --- a/DGraph/distributed/Engine.py +++ b/DGraph/distributed/Engine.py @@ -50,7 +50,7 @@ def scatter( output_size: int, rank_mappings: Optional[torch.Tensor] = None, *args, - **kwargs + **kwargs, ) -> torch.Tensor: raise NotImplementedError @@ -60,7 +60,7 @@ def gather( indices: Union[torch.Tensor, torch.LongTensor], rank_mappings: Optional[torch.Tensor] = None, *args, - **kwargs + **kwargs, ) -> torch.Tensor: raise NotImplementedError diff --git a/DGraph/distributed/RankLocalOps.py b/DGraph/distributed/RankLocalOps.py index 243ef16..3d48e49 100644 --- a/DGraph/distributed/RankLocalOps.py +++ b/DGraph/distributed/RankLocalOps.py @@ -19,7 +19,11 @@ import torch.distributed as dist try: - from DGraph.torch_local import local_masked_gather, local_masked_scatter + from DGraph.torch_local import ( + local_masked_gather, + local_masked_scatter, + local_masked_scatter_gather, + ) _LOCAL_OPT_KERNELS_AVAILABLE = True except ImportError: @@ -82,6 +86,49 @@ def OptimizedRankLocalMaskedGather( return output +def OptimizedLocalScatterGather( + src: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + output: torch.Tensor, +): + """ + Performs the operation + + for i in range(len(src_indices)): + output[dst_indices[i]] = src[src_indices[i]] + Args: + src (torch.Tensor): Source tensor + src_indices (torch.Tensor): Source indices + dst_indices (torch.Tensor): Destination indices + output (torch.Tensor): Output tensor + Returns: + torch.Tensor: Output tensor after scatter-gather + """ + + if not _LOCAL_OPT_KERNELS_AVAILABLE: + warnings.warn( + "Optimized local kernels are not available. Falling back to the default implementation." + ) + output[dst_indices] = src[src_indices] + else: + bs = src.shape[0] + num_src_rows = src.shape[1] + num_features = src.shape[-1] + num_output_rows = output.shape[1] + local_masked_scatter_gather( + src, + src_indices.cuda(), + dst_indices.cuda(), + output, + bs, + num_src_rows, + num_features, + num_output_rows, + ) + return output + + def OutOfPlaceRankLocalMaskedGather( _src: torch.Tensor, indices: torch.Tensor, rank_mapping: torch.Tensor, rank: int ) -> torch.Tensor: diff --git a/DGraph/distributed/csrc/torch_local_kernels.cu b/DGraph/distributed/csrc/torch_local_kernels.cu index a91593b..89aad0e 100644 --- a/DGraph/distributed/csrc/torch_local_kernels.cu +++ b/DGraph/distributed/csrc/torch_local_kernels.cu @@ -123,8 +123,7 @@ torch::Tensor local_masked_scatter_gather(torch::Tensor input, const int num_batches, const int num_values_rows, const int num_cols, - const int num_output_rows, - const int rank) + const int num_output_rows) { CHECK_INPUT(input); CHECK_INPUT(indices); @@ -158,8 +157,7 @@ torch::Tensor local_masked_scatter_gather(torch::Tensor input, num_batches, num_values_rows, num_cols, - num_output_rows, - rank); + num_output_rows); } else { @@ -170,8 +168,7 @@ torch::Tensor local_masked_scatter_gather(torch::Tensor input, num_batches, num_values_rows, num_cols, - num_output_rows, - rank); + num_output_rows); } CUDACHECK(cudaGetLastError()); return output; diff --git a/DGraph/distributed/nccl/NCCLBackendEngine.py b/DGraph/distributed/nccl/NCCLBackendEngine.py index beb08bf..4af1eb8 100644 --- a/DGraph/distributed/nccl/NCCLBackendEngine.py +++ b/DGraph/distributed/nccl/NCCLBackendEngine.py @@ -31,16 +31,21 @@ RankLocalRenumberingWithMapping, OptimizedRankLocalMaskedGather, ) +from DGraph.distributed.nccl._torch_func_impl import ( + GatherFunction, + ScatterFunction, + Cached_Static_ScatterFunction, + Cached_Static_GatherFunction, +) + from torch.autograd import Function from DGraph.utils import largest_split +from typing import overload TIMINGS = {"Gather_Index_Forward": [], "Gather_Forward_Local": []} - - - class NCCLBackendEngine(BackendEngine): _is_initialized = False _rank = -1 @@ -102,57 +107,73 @@ def get_local_rank_slice(self, tensor: torch.Tensor, dim: int) -> torch.Tensor: end_index = start_index + local_size return tensor[:, start_index:end_index] + @overload def scatter( self, local_send_tensor: torch.Tensor, indices: torch.Tensor, rank_mappings: torch.Tensor, output_size: int, - cache: Optional[NCCLScatterCache] = None, - *args, - **kwargs, - ) -> torch.Tensor: - send_tensor_shape = local_send_tensor.shape - b_size = send_tensor_shape[0] + ) -> torch.Tensor: ... - world_size = self.get_world_size() - rank = self.get_rank() - assert b_size == 1, "Multi-batch gather disabled for testing" - assert len(send_tensor_shape) == 3, "Currently only support 3D tensors" - assert indices.shape[-1] == rank_mappings.shape[-1], ( - f"Indices shape: {indices.shape} and rank mappings shape: " - + f" {rank_mappings.shape} must match" - ) - assert rank_mappings.shape[0] == 2, ( - "Rank mappings shape[0] expected to be 2, " - + f"but got {rank_mappings.shape[0]}" - ) - assert ( - local_send_tensor.device.type == "cuda" - ), f"Device: {local_send_tensor.device.type} expected cuda" - assert output_size > 0, "Output size must be greater than 0" + @overload + def scatter( + self, + local_send_tensor: torch.Tensor, + *, + cache: NCCLScatterCache, + ) -> torch.Tensor: ... - src_ranks = rank_mappings[0] - dest_ranks = rank_mappings[1] + def scatter( + self, + local_send_tensor: torch.Tensor, + indices: Optional[torch.Tensor] = None, + rank_mappings: Optional[torch.Tensor] = None, + output_size: Optional[int] = None, + cache: Optional[NCCLScatterCache] = None, + ) -> torch.Tensor: - use_cache = cache is not None + if cache is not None: + return Cached_Static_ScatterFunction.apply(local_send_tensor, cache) # type: ignore - if use_cache: - assert type(cache) == NCCLScatterCache - scatter_cache = cache else: - scatter_cache = None - - output_tensor = ScatterFunction.apply( - local_send_tensor, - indices, - src_ranks, - dest_ranks, - output_size, - rank, - world_size, - scatter_cache, - ) + if indices is None or rank_mappings is None or output_size is None: + raise ValueError( + "Indices, rank mappings, and output size must be provided for NCCL backend" + ) + + send_tensor_shape = local_send_tensor.shape + b_size = send_tensor_shape[0] + + world_size = self.get_world_size() + rank = self.get_rank() + assert b_size == 1, "Multi-batch gather disabled for testing" + assert len(send_tensor_shape) == 3, "Currently only support 3D tensors" + assert indices.shape[-1] == rank_mappings.shape[-1], ( + f"Indices shape: {indices.shape} and rank mappings shape: " + + f" {rank_mappings.shape} must match" + ) + assert rank_mappings.shape[0] == 2, ( + "Rank mappings shape[0] expected to be 2, " + + f"but got {rank_mappings.shape[0]}" + ) + assert ( + local_send_tensor.device.type == "cuda" + ), f"Device: {local_send_tensor.device.type} expected cuda" + assert output_size > 0, "Output size must be greater than 0" + + src_ranks = rank_mappings[0] + dest_ranks = rank_mappings[1] + + output_tensor = ScatterFunction.apply( + local_send_tensor, + indices, + src_ranks, + dest_ranks, + output_size, + rank, + world_size, + ) return output_tensor # type: ignore diff --git a/DGraph/distributed/nccl/_NCCLCommPlan.py b/DGraph/distributed/nccl/_NCCLCommPlan.py new file mode 100644 index 0000000..5dcf079 --- /dev/null +++ b/DGraph/distributed/nccl/_NCCLCommPlan.py @@ -0,0 +1,142 @@ +import torch +from dataclasses import dataclass +from typing import List +import torch.distributed as dist + + +@dataclass +class NCCLGraphCommPlan: + """ + Class to store communication plan for distributed gather-scatter (vector addressing) + + Attributes: + rank (int): Local rank + world_size (int): World size + local_num_vertices (int): Number of local vertices + local_src_idx (torch.Tensor): Local source indices for scatter-sum + local_dst_idx (torch.Tensor): Local destination indices for scatter-sum + send_src_idx (torch.Tensor): Source indices to send to other ranks + send_buffer_idx (torch.Tensor): Buffer indices to store data to send to other ranks + send_comm_vector (torch.Tensor): Communication vector of shape [world_size] of messages to send to each rank + recv_dst_idx (torch.Tensor): Destination indices to receive from other ranks + recv_comm_vector (torch.Tensor): Communication vector of shape [world_size] of messages to + """ + + rank: int + world_size: int + + # Allocation meta data + num_local_vertices: int + num_local_edges: int + + # Local edge-vertex mapping + # + # Used for: + # 1) Local scatter-sum (edge -> vertex aggregation) + # y[local_vertex_idx] += x[local_edge_idx] + # 2) Local gather (vertex -> edge gathering) + # y[local_edge_idx] = x[local_vertex_idx] + + local_edge_idx: torch.Tensor + local_vertex_idx: torch.Tensor + + # Boundary edges (data must be sent/received to/from other ranks for gather/scatter) + + boundary_edge_idx: torch.Tensor + boundary_edge_buffer_map: torch.Tensor + boundary_edge_splits: List[int] + + # Boundary vertices (vertices that have edges on other ranks) + boundary_vertex_idx: torch.Tensor + boundary_vertex_splits: List[int] + + def to(self, device: torch.device): + self.local_edge_idx = self.local_edge_idx.to(device) + self.local_vertex_idx = self.local_vertex_idx.to(device) + self.boundary_edge_idx = self.boundary_edge_idx.to(device) + self.boundary_edge_buffer_map = self.boundary_edge_buffer_map.to(device) + self.boundary_vertex_idx = self.boundary_vertex_idx.to(device) + + +def COO_to_NCCLCommPlan( + rank: int, + world_size: int, + global_edges_src: torch.Tensor, + global_edges_dst: torch.Tensor, + vertex_rank_placement: torch.Tensor, + local_edge_list: torch.Tensor, + offset: torch.Tensor, +) -> NCCLGraphCommPlan: + device = local_edge_list.device + my_src_global = global_edges_src[local_edge_list].to(device) + my_dst_global = global_edges_dst[local_edge_list].to(device) + + my_start = offset[rank].item() + my_end = offset[rank + 1].item() + + nodes_per_rank = torch.bincount(vertex_rank_placement, minlength=world_size) + num_local_vertices = int(nodes_per_rank[rank].item()) + num_local_edges = local_edge_list.size(0) + + dest_ranks = torch.bucketize(global_edges_dst, offset, right=True) - 1 + + is_internal = dest_ranks == rank + internal_dst_global = my_dst_global[is_internal] + internal_node_idx = internal_dst_global - offset + + internal_edge_indices = torch.nonzero(is_internal, as_tuple=True)[0] + + remote_mask = ~is_internal + + boundary_edge_indices = torch.nonzero(remote_mask, as_tuple=True)[0] + + b_dst_global = my_dst_global[remote_mask] + b_dest_ranks = dest_ranks[remote_mask] + + sort_idx = torch.argsort(b_dest_ranks) + boundary_edge_indices = boundary_edge_indices[sort_idx] + b_dst_global = b_dst_global[sort_idx] + b_dest_ranks = b_dest_ranks[sort_idx] + + unique_dests, inverse_indices = torch.unique( + torch.stack([b_dest_ranks, b_dst_global]), dim=1, return_inverse=True + ) + unique_ranks = unique_dests[0] + unique_global_ids = unique_dests[1] + + boundary_edge_buffer_map = inverse_indices + + boundary_edge_splits = torch.bincount(unique_ranks, minlength=world_size).tolist() + + recv_counts_tensor = torch.empty(world_size, dtype=torch.long, device=device) + send_counts_tensor = torch.tensor( + boundary_edge_splits, dtype=torch.long, device=device + ) + dist.all_to_all_single(recv_counts_tensor, send_counts_tensor) + boundary_node_splits = recv_counts_tensor.tolist() + + total_recv_nodes = sum(boundary_node_splits) + recv_global_ids = torch.empty(total_recv_nodes, dtype=torch.long, device=device) + + dist.all_to_all_single( + recv_global_ids, + unique_global_ids, + output_split_sizes=boundary_node_splits, + input_split_sizes=boundary_edge_splits, + ) + + boundary_node_idx = recv_global_ids - my_start + + return NCCLGraphCommPlan( + rank=rank, + world_size=world_size, + num_local_vertices=num_local_vertices, + num_local_edges=num_local_edges, + local_edge_idx=internal_edge_indices, + local_vertex_idx=internal_node_idx, + boundary_edge_idx=boundary_edge_indices, + boundary_edge_buffer_map=boundary_edge_buffer_map, + boundary_edge_splits=boundary_edge_splits, + boundary_vertex_idx=boundary_node_idx, + boundary_vertex_splits=boundary_node_splits, + ) diff --git a/DGraph/distributed/nccl/_torch_func_impl.py b/DGraph/distributed/nccl/_torch_func_impl.py index f9f45ff..8e61290 100644 --- a/DGraph/distributed/nccl/_torch_func_impl.py +++ b/DGraph/distributed/nccl/_torch_func_impl.py @@ -4,78 +4,132 @@ import torch.distributed as dist from dataclasses import dataclass from DGraph.distributed.nccl._nccl_cache import NCCLGatherCache, NCCLScatterCache -from DGraph.distributed.RankLocalOps import OptimizedRankLocalMaskedGather - -@dataclass -class GatherCommPlan: - """ - Class to store communication plan for distributed gather - - Attributes: - rank (int): Local rank - world_size (int): World size - local_tensor_size (int): Size of the return local tensor after local and global gather - send_comm_vector (torch.Tensor): Communication vector of shape [world_size] of messages to send to each rank - recv_comm_vector (torch.Tensor): Communication vector of shape [world_size] of messages to receive from each rank - local_send_buffer_size (int): Size of the local send buffer - local_recv_buffer_size (int): Size of the local recv buffer - local_output_rows (torch.Tensor): Local rows that don't need communication - local_gather_indices (torch.Tensor): Indices of rows in local input to write to local_output_rows - comm_output_rows (torch.Tensor): Local rows that need data from remote ranks - comm_output_indices (torch.Tensor): Indices of rows in local input to write to comm_output_rows - - """ - - rank: int - world_size: int - local_tensor_size: int - send_comm_vector: torch.Tensor - recv_comm_vector: torch.Tensor - local_send_buffer_size: int - local_recv_buffer_size: int - local_output_rows: torch.Tensor - local_gather_indices: torch.Tensor - comm_output_rows: torch.Tensor - comm_output_indices: torch.Tensor - - +from DGraph.distributed.RankLocalOps import ( + OptimizedRankLocalMaskedGather, + OptimizedLocalScatterGather, +) +from DGraph.distributed.nccl._NCCLCommPlan import NCCLGraphCommPlan + + class Cached_Static_GatherFunction(Function): @staticmethod def forward( ctx, local_send_tensor: torch.Tensor, - comm_plan: GatherCommPlan, - ): + comm_plan: NCCLGraphCommPlan, + ) -> torch.Tensor: """ - Forward pass for distributed gather - + Forward pass for distributed gather using the common plan to effectively perform: + y[i] = x[indices[i]] + + The process is as follows: + 1) Perform local gather from local vertices to local edges + 2) Gather + Args: ctx (torch.autograd.FunctionContext): Context object local_send_tensor (torch.Tensor): Local send tensor comm_plan (GatherCommPlan): Communication plan """ - assert (len(local_send_tensor.shape) == 3), "Local send tensor must be of shape (batch_size, num_rows, num_features)" + assert ( + len(local_send_tensor.shape) == 3 + ), "Local send tensor must be of shape (batch_size, num_rows, num_features)" ctx.comm_plan = comm_plan + num_features = local_send_tensor.shape[-1] num_batches = local_send_tensor.shape[0] output_tensor = torch.zeros( - num_batches, comm_plan.local_tensor_size, num_features + num_batches, comm_plan.num_local_edges, num_features ).to(local_send_tensor.device) + # Local vertex to edge gather + output_tensor = OptimizedLocalScatterGather( + local_send_tensor, + output_tensor, + comm_plan.local_edge_idx, + comm_plan.local_vertex_idx, + ) + + # To do: Combine this with the local gather above to reduce kernel launches + send_buf = local_send_tensor[:, comm_plan.boundary_edge_idx, :] + + total_recv = sum(comm_plan.boundary_edge_splits) + + recv_buffer = torch.empty(num_batches, total_recv, num_features).to( + local_send_tensor.device + ) + dist.all_to_all_single( + recv_buffer, + send_buf, + output_split_sizes=comm_plan.boundary_edge_splits, + input_split_sizes=comm_plan.boundary_edge_splits, + ) + + output_tensor = OptimizedLocalScatterGather( + recv_buffer, + output_tensor, + comm_plan.boundary_edge_buffer_map, + comm_plan.boundary_vertex_idx, + ) + return output_tensor - + @staticmethod def backward(ctx, grad_output): """ Backward pass for distributed gather - + Args: ctx (torch.autograd.FunctionContext): Context object grad_output (torch.Tensor): Gradient of the output tensor """ comm_plan = ctx.comm_plan - + + grad_output = torch.zeros_like(grad_output) + + return grad_output, None + + +class Cached_Static_ScatterFunction(Function): + @staticmethod + def forward( + ctx, + local_send_tensor: torch.Tensor, + comm_plan: NCCLGraphCommPlan, + ) -> torch.Tensor: + """ + Forward pass for distributed scatter + + Args: + ctx (torch.autograd.FunctionContext): Context object + local_send_tensor (torch.Tensor): Local send tensor + comm_plan (NCCLGraphCommPlan): Communication plan + """ + assert ( + len(local_send_tensor.shape) == 3 + ), "Local send tensor must be of shape (batch_size, num_rows, num_features)" + ctx.comm_plan = comm_plan + num_features = local_send_tensor.shape[-1] + num_batches = local_send_tensor.shape[0] + + output_tensor = torch.zeros( + num_batches, comm_plan.local_tensor_size, num_features + ).to(local_send_tensor.device) + + return output_tensor + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass for distributed scatter + + Args: + ctx (torch.autograd.FunctionContext): Context object + grad_output (torch.Tensor): Gradient of the output tensor + """ + comm_plan = ctx.comm_plan + grad_output = torch.zeros_like(grad_output) return grad_output, None @@ -104,7 +158,6 @@ def forward( torch.tensor(world_size), ) - # Since NCCL is two-sided, we need to push from local rank and pull from # remote rank to get the global gather @@ -120,14 +173,13 @@ def forward( batch_size = 1 num_features = local_send_tensor.shape[2] - local_slice_mask = edge_rank_loc == rank num_local_output_rows = int(local_slice_mask.sum().item()) - recv_tensor = torch.zeros( - batch_size, num_local_output_rows, num_features - ).to(local_send_tensor.device) + recv_tensor = torch.zeros(batch_size, num_local_output_rows, num_features).to( + local_send_tensor.device + ) local_indices_slice = indices[local_slice_mask.unsqueeze(0)] local_rank_mapping = edge_rank_loc[local_slice_mask] @@ -321,7 +373,6 @@ def forward( num_local_output_rows: int, rank: int, world_size: int, - scatter_cache: Optional[NCCLScatterCache] = None, ) -> torch.Tensor: ctx.save_for_backward( @@ -511,4 +562,4 @@ def backward(ctx, grad_output): rank_grad, world_size_grad, scatter_cache_grad, - ) \ No newline at end of file + ) From 8535a409ed01a4b518c637425f41321db9813807 Mon Sep 17 00:00:00 2001 From: Shehtab Date: Fri, 12 Dec 2025 18:41:52 -0500 Subject: [PATCH 33/48] Update the scatter-gather impl to allow set and add aggregation --- .../distributed/csrc/local_data_kernels.cuh | 38 ++++- .../distributed/csrc/torch_local_bindings.cpp | 1 + .../distributed/csrc/torch_local_kernels.cu | 154 ++++++++++++------ DGraph/distributed/nccl/_torch_func_impl.py | 38 ++++- 4 files changed, 175 insertions(+), 56 deletions(-) diff --git a/DGraph/distributed/csrc/local_data_kernels.cuh b/DGraph/distributed/csrc/local_data_kernels.cuh index e81a9b1..e44931c 100644 --- a/DGraph/distributed/csrc/local_data_kernels.cuh +++ b/DGraph/distributed/csrc/local_data_kernels.cuh @@ -252,14 +252,42 @@ namespace Local } } + __device__ __forceinline__ float4 atomicAdd_float4(float4& cur_val, const float4 new_val) + { + atomicAdd(&(cur_val.x), new_val.x); + atomicAdd(&(cur_val.y), new_val.y); + atomicAdd(&(cur_val.z), new_val.z); + atomicAdd(&(cur_val.w), new_val.w); + return cur_val; + } + + __device__ __forceinline__ float4 set_float4(float& cur_val, const float new_val) + { + cur_val = new_val; + return reinterpret_cast(cur_val); + } + + __device__ __forceinline__ float atomicAdd_float(float& cur_val, const float new_val) + { + atomicAdd(&cur_val, new_val); + return cur_val; + } + + __device__ __forceinline__ float set_float(float& cur_val, const float new_val) + { + cur_val = new_val; + return cur_val; + } + /** * * Masked Gather Kernel operation that performs the operation: - Y [mask[i]] = X [indices[i]] + Y [mask[i]] = Op(Y [mask[i]], X [indices[i]]) where Y is the output matrix, X is the input matrix, indices is the index matrix, and mask is the mask matrix. */ + template __global__ void Masked_Scatter_Gather_Kernel( const float *__restrict__ values, const long *__restrict__ indices, @@ -292,7 +320,9 @@ namespace Local for (size_t col = gidx; col < num_cols; col += nthreadsx) { - output[output_offset + output_row * num_cols + col] = values[values_offset + input_row * num_cols + col]; + auto &output_val = output[output_offset + output_row * num_cols + col]; + const auto input_val = values[values_offset + input_row * num_cols + col]; + output_val = op(output_val, input_val); } } } @@ -307,6 +337,7 @@ namespace Local where Y is the output matrix, X is the input matrix, indices is the index matrix, and mask is the mask matrix. */ + template __global__ void Optimized_Masked_Scatter_Gather_Kernel( const float *__restrict__ values, const long *__restrict__ indices, @@ -355,9 +386,10 @@ namespace Local { const float4 values_vec = reinterpret_cast(values)[values_offset + input_row * num_cols / 4 + col]; float4& output_vec = reinterpret_cast(output)[output_offset + output_row * num_cols / 4 + col]; - output_vec = values_vec; + output_vec = op(output_vec, values_vec); } } } } + } // namespace Local \ No newline at end of file diff --git a/DGraph/distributed/csrc/torch_local_bindings.cpp b/DGraph/distributed/csrc/torch_local_bindings.cpp index 6701e6a..fe685b6 100644 --- a/DGraph/distributed/csrc/torch_local_bindings.cpp +++ b/DGraph/distributed/csrc/torch_local_bindings.cpp @@ -22,4 +22,5 @@ PYBIND11_MODULE(torch_local, m) m.def("local_masked_gather", &local_masked_gather, "Masked Gather"); m.def("local_masked_scatter", &local_masked_scatter, "Masked Scatter"); m.def("local_masked_scatter_gather", &local_masked_scatter_gather, "Masked Scatter Gather"); + m.def("local_masked_scatter_add_gather", &local_masked_scatter_add_gather, "Masked Scatter Add Gather"); } diff --git a/DGraph/distributed/csrc/torch_local_kernels.cu b/DGraph/distributed/csrc/torch_local_kernels.cu index 89aad0e..f586a50 100644 --- a/DGraph/distributed/csrc/torch_local_kernels.cu +++ b/DGraph/distributed/csrc/torch_local_kernels.cu @@ -124,52 +124,110 @@ torch::Tensor local_masked_scatter_gather(torch::Tensor input, const int num_values_rows, const int num_cols, const int num_output_rows) +{ + CHECK_INPUT(input); + CHECK_INPUT(indices); + CHECK_INPUT(mask); + CHECK_INPUT(output); + + const float *input_ptr = input.data_ptr(); + const long *indices_ptr = indices.data_ptr(); + const float *mask_ptr = mask.data_ptr(); + float *output_ptr = output.data_ptr(); + + dim3 block_dims, grid_dims; + block_dims.x = 32; + block_dims.y = 32; + block_dims.z = 1; + + const auto num_grids_needed = (num_output_rows + block_dims.y - 1) / block_dims.y; + const auto num_col_grids_needed = (num_cols + block_dims.x - 1) / block_dims.x; + grid_dims.x = num_col_grids_needed < 65535 ? num_col_grids_needed : 65535; + grid_dims.y = num_grids_needed < 65535 ? num_grids_needed : 65535; + grid_dims.z = 1; + + at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(input.device().index()); + + if (num_cols % 4 != 0) + { + Local::Masked_Scatter_Gather_Kernel<<>>(input_ptr, + indices_ptr, + mask_ptr, + output_ptr, + num_batches, + num_values_rows, + num_cols, + num_output_rows); + } + else { - CHECK_INPUT(input); - CHECK_INPUT(indices); - CHECK_INPUT(mask); - CHECK_INPUT(output); - - const float *input_ptr = input.data_ptr(); - const long *indices_ptr = indices.data_ptr(); - const float *mask_ptr = mask.data_ptr(); - float *output_ptr = output.data_ptr(); - - dim3 block_dims, grid_dims; - block_dims.x = 32; - block_dims.y = 32; - block_dims.z = 1; - - const auto num_grids_needed = (num_output_rows + block_dims.y - 1) / block_dims.y; - const auto num_col_grids_needed = (num_cols + block_dims.x - 1) / block_dims.x; - grid_dims.x = num_col_grids_needed < 65535 ? num_col_grids_needed : 65535; - grid_dims.y = num_grids_needed < 65535 ? num_grids_needed : 65535; - grid_dims.z = 1; - - at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(input.device().index()); - - if (num_cols % 4 != 0) - { - Local::Masked_Scatter_Gather_Kernel<<>>(input_ptr, - indices_ptr, - mask_ptr, - output_ptr, - num_batches, - num_values_rows, - num_cols, - num_output_rows); - } - else - { - Local::Optimized_Masked_Scatter_Gather_Kernel<<>>(input_ptr, - indices_ptr, - mask_ptr, - output_ptr, - num_batches, - num_values_rows, - num_cols, - num_output_rows); - } - CUDACHECK(cudaGetLastError()); - return output; - } \ No newline at end of file + Local::Optimized_Masked_Scatter_Gather_Kernel<<>>(input_ptr, + indices_ptr, + mask_ptr, + output_ptr, + num_batches, + num_values_rows, + num_cols, + num_output_rows); + } + CUDACHECK(cudaGetLastError()); + return output; +} + +torch::Tensor local_masked_scatter_add_gather(torch::Tensor input, + torch::Tensor indices, + torch::Tensor mask, + torch::Tensor output, + const int num_batches, + const int num_values_rows, + const int num_cols, + const int num_output_rows) +{ + CHECK_INPUT(input); + CHECK_INPUT(indices); + CHECK_INPUT(mask); + CHECK_INPUT(output); + + const float *input_ptr = input.data_ptr(); + const long *indices_ptr = indices.data_ptr(); + const float *mask_ptr = mask.data_ptr(); + float *output_ptr = output.data_ptr(); + + dim3 block_dims, grid_dims; + block_dims.x = 32; + block_dims.y = 32; + block_dims.z = 1; + + const auto num_grids_needed = (num_output_rows + block_dims.y - 1) / block_dims.y; + const auto num_col_grids_needed = (num_cols + block_dims.x - 1) / block_dims.x; + grid_dims.x = num_col_grids_needed < 65535 ? num_col_grids_needed : 65535; + grid_dims.y = num_grids_needed < 65535 ? num_grids_needed : 65535; + grid_dims.z = 1; + + at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(input.device().index()); + + if (num_cols % 4 != 0) + { + Local::Masked_Scatter_Gather_Kernel<<>>(input_ptr, + indices_ptr, + mask_ptr, + output_ptr, + num_batches, + num_values_rows, + num_cols, + num_output_rows); + } + else + { + Local::Optimized_Masked_Scatter_Gather_KernelM<<>>(input_ptr, + indices_ptr, + mask_ptr, + output_ptr, + num_batches, + num_values_rows, + num_cols, + num_output_rows); + } + CUDACHECK(cudaGetLastError()); + return output; +} \ No newline at end of file diff --git a/DGraph/distributed/nccl/_torch_func_impl.py b/DGraph/distributed/nccl/_torch_func_impl.py index 8e61290..b9ce3d5 100644 --- a/DGraph/distributed/nccl/_torch_func_impl.py +++ b/DGraph/distributed/nccl/_torch_func_impl.py @@ -11,7 +11,7 @@ from DGraph.distributed.nccl._NCCLCommPlan import NCCLGraphCommPlan -class Cached_Static_GatherFunction(Function): +class CommPlan_GatherFunction(Function): @staticmethod def forward( ctx, @@ -82,16 +82,44 @@ def backward(ctx, grad_output): Args: ctx (torch.autograd.FunctionContext): Context object - grad_output (torch.Tensor): Gradient of the output tensor + grad_output (torch.Tensor): Gradient of the output tensor. + Shape: (batch_size, num_local_edges, num_features) """ comm_plan = ctx.comm_plan + num_features = grad_output.shape[-1] + num_batches = grad_output.shape[0] + device = grad_output.device - grad_output = torch.zeros_like(grad_output) + grad_input = torch.zeros( + num_batches, comm_plan.num_local_vertices, num_features, device=device + ) - return grad_output, None + grad_input = OptimizedLocalScatterGather( + grad_output, + grad_input, + comm_plan.local_vertex_idx, + comm_plan.local_edge_idx, + ) + send_buf = grad_output[:, comm_plan.boundary_vertex_idx, :] + total_recv = sum(comm_plan.boundary_vertex_splits) + recv_buffer = torch.empty(num_batches, total_recv, num_features).to(device) + dist.all_to_all_single( + recv_buffer, + send_buf, + output_split_sizes=comm_plan.boundary_vertex_splits, + input_split_sizes=comm_plan.boundary_edge_splits, + ) + grad_input = OptimizedLocalScatterGather( + recv_buffer, + grad_input, + comm_plan.boundary_edge_buffer_map, + comm_plan.boundary_vertex_idx, + ) + + return grad_input, None -class Cached_Static_ScatterFunction(Function): +class CommPlan_ScatterFunction(Function): @staticmethod def forward( ctx, From 3280ee85b579a1c1822fe2521d7c47607d4d31eb Mon Sep 17 00:00:00 2001 From: Shehtab Date: Sat, 13 Dec 2025 15:31:59 -0500 Subject: [PATCH 34/48] Add ScatterSumGather python wrapper --- DGraph/distributed/RankLocalOps.py | 45 ++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/DGraph/distributed/RankLocalOps.py b/DGraph/distributed/RankLocalOps.py index 3d48e49..d071ce8 100644 --- a/DGraph/distributed/RankLocalOps.py +++ b/DGraph/distributed/RankLocalOps.py @@ -129,6 +129,51 @@ def OptimizedLocalScatterGather( return output +def OptimizedRankLocalScatterSumGather( + src: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + output: torch.Tensor, +): + """ + Performs the operation + + for i in range(len(src_indices)): + output[dst_indices[i]] += src[src_indices[i]] + Args: + src (torch.Tensor): Source tensor + src_indices (torch.Tensor): Source indices + dst_indices (torch.Tensor): Destination indices + output (torch.Tensor): Output tensor + Returns: + torch.Tensor: Output tensor after scatter-gather + """ + + if not _LOCAL_OPT_KERNELS_AVAILABLE: + warnings.warn( + "Optimized local kernels are not available. Falling back to the default implementation." + ) + for i in range(src_indices.shape[0]): + output[:, dst_indices[i], :] += src[:, src_indices[i], :] + else: + bs = src.shape[0] + num_src_rows = src.shape[1] + num_features = src.shape[-1] + num_output_rows = output.shape[1] + local_masked_scatter_gather( + src, + src_indices.cuda(), + dst_indices.cuda(), + output, + bs, + num_src_rows, + num_features, + num_output_rows, + scatter_add=True, + ) + return output + + def OutOfPlaceRankLocalMaskedGather( _src: torch.Tensor, indices: torch.Tensor, rank_mapping: torch.Tensor, rank: int ) -> torch.Tensor: From 13c1205e10ee81164306f28f5f481ae1e160b6fd Mon Sep 17 00:00:00 2001 From: Shehtab Date: Sat, 13 Dec 2025 17:16:55 -0500 Subject: [PATCH 35/48] Fix backward function call on StaticGather --- DGraph/distributed/RankLocalOps.py | 1 - DGraph/distributed/nccl/__init__.py | 7 +- DGraph/distributed/nccl/_torch_func_impl.py | 38 ++++++----- experiments/OGB-LSC/mag240m/DGraph_MAG240M.py | 65 +++++++++++-------- 4 files changed, 60 insertions(+), 51 deletions(-) diff --git a/DGraph/distributed/RankLocalOps.py b/DGraph/distributed/RankLocalOps.py index d071ce8..8cacecc 100644 --- a/DGraph/distributed/RankLocalOps.py +++ b/DGraph/distributed/RankLocalOps.py @@ -169,7 +169,6 @@ def OptimizedRankLocalScatterSumGather( num_src_rows, num_features, num_output_rows, - scatter_add=True, ) return output diff --git a/DGraph/distributed/nccl/__init__.py b/DGraph/distributed/nccl/__init__.py index cf28164..d3ca0b4 100644 --- a/DGraph/distributed/nccl/__init__.py +++ b/DGraph/distributed/nccl/__init__.py @@ -12,9 +12,4 @@ # # SPDX-License-Identifier: (Apache-2.0) from DGraph.distributed.nccl.NCCLBackendEngine import NCCLBackendEngine, TIMINGS -from DGraph.distributed.nccl._nccl_cache import ( - NCCLGatherCache, - NCCLScatterCache, - NCCLScatterCacheGenerator, - NCCLGatherCacheGenerator, -) +from DGraph.distributed.nccl._NCCLCommPlan import NCCLGraphCommPlan, COO_to_NCCLCommPlan diff --git a/DGraph/distributed/nccl/_torch_func_impl.py b/DGraph/distributed/nccl/_torch_func_impl.py index b9ce3d5..d4f4cc9 100644 --- a/DGraph/distributed/nccl/_torch_func_impl.py +++ b/DGraph/distributed/nccl/_torch_func_impl.py @@ -7,6 +7,7 @@ from DGraph.distributed.RankLocalOps import ( OptimizedRankLocalMaskedGather, OptimizedLocalScatterGather, + OptimizedRankLocalScatterSumGather, ) from DGraph.distributed.nccl._NCCLCommPlan import NCCLGraphCommPlan @@ -45,10 +46,10 @@ def forward( # Local vertex to edge gather output_tensor = OptimizedLocalScatterGather( - local_send_tensor, - output_tensor, - comm_plan.local_edge_idx, - comm_plan.local_vertex_idx, + src=local_send_tensor, + src_indices=comm_plan.local_edge_idx, + dst_indices=comm_plan.local_vertex_idx, + output=output_tensor, ) # To do: Combine this with the local gather above to reduce kernel launches @@ -67,10 +68,10 @@ def forward( ) output_tensor = OptimizedLocalScatterGather( - recv_buffer, - output_tensor, - comm_plan.boundary_edge_buffer_map, - comm_plan.boundary_vertex_idx, + src=recv_buffer, + src_indices=comm_plan.boundary_edge_buffer_map, + dst_indices=comm_plan.boundary_edge_idx, + output=output_tensor, ) return output_tensor @@ -94,12 +95,13 @@ def backward(ctx, grad_output): num_batches, comm_plan.num_local_vertices, num_features, device=device ) - grad_input = OptimizedLocalScatterGather( - grad_output, - grad_input, - comm_plan.local_vertex_idx, - comm_plan.local_edge_idx, + grad_input = OptimizedRankLocalScatterSumGather( + src=grad_output, + output=grad_input, + src_indices=comm_plan.local_edge_idx, + dst_indices=comm_plan.local_vertex_idx, ) + send_buf = grad_output[:, comm_plan.boundary_vertex_idx, :] total_recv = sum(comm_plan.boundary_vertex_splits) recv_buffer = torch.empty(num_batches, total_recv, num_features).to(device) @@ -109,11 +111,11 @@ def backward(ctx, grad_output): output_split_sizes=comm_plan.boundary_vertex_splits, input_split_sizes=comm_plan.boundary_edge_splits, ) - grad_input = OptimizedLocalScatterGather( - recv_buffer, - grad_input, - comm_plan.boundary_edge_buffer_map, - comm_plan.boundary_vertex_idx, + grad_input = OptimizedRankLocalScatterSumGather( + src=recv_buffer, + output=grad_input, + src_indices=comm_plan.boundary_edge_buffer_map, + dst_indices=comm_plan.boundary_vertex_idx, ) return grad_input, None diff --git a/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py b/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py index c34db6c..458048a 100644 --- a/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py +++ b/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py @@ -19,6 +19,7 @@ from tqdm import tqdm import os.path as osp from DGraph.Communicator import Communicator +from DGraph.distributed.nccl import NCCLGraphCommPlan, COO_to_NCCLCommPlan def get_col_slice(x, start_row_idx, end_row_idx, start_col_idx, end_col_idx): @@ -61,6 +62,7 @@ def get_rank_mappings(num_nodes, world_size, rank): rank_mappings[start:end] = r return rank_mappings + def edge_mapping_from_vertex_mapping(edge_index, src_rank_mappings, dst_rank_mappings): # directed edges, so edge_index[0] -> edge_index[1] src_indices = edge_index[0] @@ -72,6 +74,7 @@ def edge_mapping_from_vertex_mapping(edge_index, src_rank_mappings, dst_rank_map dest_data_mappings = dst_rank_mappings[dest_indices] return (src_data_mappings, dest_data_mappings) + def get_edge_mappings(src_indices, dst_indices, rank_mappings): edge_mappings = torch.zeros_like(src_indices) # The edges are mapped to the rank of the destination node @@ -169,36 +172,42 @@ def __init__( self.institution_rank_mappings, ) - self.train_mask = self.dataset.get_idx_split('train') - self.val_mask = self.dataset.get_idx_split('valid') - self.test_mask = self.dataset.get_idx_split('test-dev') + self.train_mask = self.dataset.get_idx_split("train") + self.val_mask = self.dataset.get_idx_split("valid") + self.test_mask = self.dataset.get_idx_split("test-dev") local_papers_mask = self.paper_rank_mappings == self.rank local_authors_mask = self.author_rank_mappings == self.rank local_institutions_mask = self.institution_rank_mappings == self.rank - self.num_local_papers = int( - local_papers_mask.sum() - ) + self.num_local_papers = int(local_papers_mask.sum()) self.generate_feature_data() - self.paper_features = torch.from_numpy(self.dataset.paper_feat[local_papers_mask]) + self.paper_features = torch.from_numpy( + self.dataset.paper_feat[local_papers_mask] + ) path = self.dataset.dir - self.author_features = torch.from_numpy(np.memmap( - filename=path + "/author_feat.npy", - mode="r", - dtype=np.float16, - shape=(self.num_authors, self.num_features), - )[local_authors_mask]) - self.institution_features = torch.from_numpy(np.memmap( - filename=path + "/institution_feat.npy", - mode="r", - dtype=np.float16, - shape=(self.num_institutions, self.num_features), - )[local_institutions_mask]) + self.author_features = torch.from_numpy( + np.memmap( + filename=path + "/author_feat.npy", + mode="r", + dtype=np.float16, + shape=(self.num_authors, self.num_features), + )[local_authors_mask] + ) + self.institution_features = torch.from_numpy( + np.memmap( + filename=path + "/institution_feat.npy", + mode="r", + dtype=np.float16, + shape=(self.num_institutions, self.num_features), + )[local_institutions_mask] + ) self.y = torch.from_numpy(self.dataset.paper_label) - self.paper_2_paper_edges = torch.from_numpy(self.dataset.edge_index('paper', 'cites', 'paper')) + self.paper_2_paper_edges = torch.from_numpy( + self.dataset.edge_index("paper", "cites", "paper") + ) ( paper_2_paper_src_data_mappings, paper_2_paper_dest_data_mappings, @@ -210,7 +219,9 @@ def __init__( self.paper_src_data_mappings = paper_2_paper_src_data_mappings self.paper_dest_data_mappings = paper_2_paper_dest_data_mappings - self.author_2_paper_edges = torch.from_numpy(self.dataset.edge_index('author', 'writes', 'paper')) + self.author_2_paper_edges = torch.from_numpy( + self.dataset.edge_index("author", "writes", "paper") + ) ( author_2_paper_src_data_mappings, author_2_paper_dest_data_mappings, @@ -222,7 +233,9 @@ def __init__( self.author_2_paper_src_data_mappings = author_2_paper_src_data_mappings self.author_2_paper_dest_data_mappings = author_2_paper_dest_data_mappings - self.author_2_institution_edges = torch.from_numpy(self.dataset.edge_index('author', 'institution')) + self.author_2_institution_edges = torch.from_numpy( + self.dataset.edge_index("author", "institution") + ) ( author_2_institution_src_data_mappings, author_2_institution_dest_data_mappings, @@ -332,9 +345,7 @@ def get_vertex_rank_mask(self, mask_type: str) -> Tuple[torch.Tensor, torch.Tens # Get the ranks of the vertices # paper_vertex_rank_mapping -> vector of size num_papers, # where each entry is the location / rank of the vertex - paper_rank_mappings = self.paper_rank_mappings.to( - global_int_mask.device - ) + paper_rank_mappings = self.paper_rank_mappings.to(global_int_mask.device) vertex_ranks = paper_rank_mappings[global_int_mask] # vertex_ranks is location of the vertices in the global_int_mask vertex_ranks_mask = vertex_ranks == self.rank @@ -389,7 +400,9 @@ def to(self, device): """ self.paper_features = self.paper_features.to(device, dtype=torch.float32) self.author_features = self.author_features.to(device, dtype=torch.float32) - self.institution_features = self.institution_features.to(device, dtype=torch.float32) + self.institution_features = self.institution_features.to( + device, dtype=torch.float32 + ) self.y = self.y.to(device) self.train_mask = self.train_mask.to(device) self.val_mask = self.val_mask.to(device) From ef2efbbf6535a8ec8bb07e00a351892200d67da7 Mon Sep 17 00:00:00 2001 From: Shehtab Date: Sat, 13 Dec 2025 17:53:14 -0500 Subject: [PATCH 36/48] Fixed Scatter forward --- DGraph/distributed/nccl/_torch_func_impl.py | 40 ++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/DGraph/distributed/nccl/_torch_func_impl.py b/DGraph/distributed/nccl/_torch_func_impl.py index d4f4cc9..628e399 100644 --- a/DGraph/distributed/nccl/_torch_func_impl.py +++ b/DGraph/distributed/nccl/_torch_func_impl.py @@ -140,13 +140,51 @@ def forward( len(local_send_tensor.shape) == 3 ), "Local send tensor must be of shape (batch_size, num_rows, num_features)" ctx.comm_plan = comm_plan + num_features = local_send_tensor.shape[-1] num_batches = local_send_tensor.shape[0] output_tensor = torch.zeros( - num_batches, comm_plan.local_tensor_size, num_features + num_batches, comm_plan.num_local_vertices, num_features ).to(local_send_tensor.device) + output_tensor = OptimizedRankLocalScatterSumGather( + src=local_send_tensor, + output=output_tensor, + src_indices=comm_plan.local_edge_idx, + dst_indices=comm_plan.local_vertex_idx, + ) + + total_send_rows = sum(comm_plan.boundary_edge_splits) + + send_buf = torch.zeros( + num_batches, total_send_rows, num_features, device=local_send_tensor.device + ) + + send_buf = OptimizedRankLocalScatterSumGather( + src=local_send_tensor, + output=send_buf, + src_indices=comm_plan.boundary_edge_idx, + dst_indices=comm_plan.boundary_edge_buffer_map, + ) + + total_recv_rows = sum(comm_plan.boundary_vertex_splits) + recv_buffer = torch.empty( + num_batches, total_recv_rows, num_features, device=local_send_tensor.device + ) + dist.all_to_all_single( + recv_buffer, + send_buf, + output_split_sizes=comm_plan.boundary_vertex_splits, + input_split_sizes=comm_plan.boundary_edge_splits, + ) + output_tensor = OptimizedRankLocalScatterSumGather( + src=recv_buffer, + output=output_tensor, + src_indices=comm_plan.boundary_edge_buffer_map, + dst_indices=comm_plan.boundary_vertex_idx, + ) + return output_tensor @staticmethod From 22ed522a8b3ad8fa3d84cad9a9bfae9e5a49c1c0 Mon Sep 17 00:00:00 2001 From: Shehtab Date: Tue, 16 Dec 2025 23:16:17 -0500 Subject: [PATCH 37/48] Updated scatter function impl --- DGraph/distributed/nccl/_torch_func_impl.py | 44 ++++++++++++++++++++- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/DGraph/distributed/nccl/_torch_func_impl.py b/DGraph/distributed/nccl/_torch_func_impl.py index 628e399..3bb51ef 100644 --- a/DGraph/distributed/nccl/_torch_func_impl.py +++ b/DGraph/distributed/nccl/_torch_func_impl.py @@ -197,10 +197,50 @@ def backward(ctx, grad_output): grad_output (torch.Tensor): Gradient of the output tensor """ comm_plan = ctx.comm_plan + num_features = grad_output.shape[-1] + num_batches = grad_output.shape[0] + device = grad_output.device + num_output_rows = comm_plan.num_local_edges - grad_output = torch.zeros_like(grad_output) + grad_input = torch.zeros( + num_batches, num_output_rows, num_features, device=device + ) + + grad_input = OptimizedLocalScatterGather( + src=grad_output, + src_indices=comm_plan.local_vertex_idx, + dst_indices=comm_plan.local_edge_idx, + output=grad_input, + ) - return grad_output, None + num_send_rows = sum(comm_plan.boundary_vertex_splits) + send_buf_locs = torch.arange(num_send_rows, device=device) + send_buf = torch.zeros(num_batches, num_send_rows, num_features, device=device) + send_buf = OptimizedLocalScatterGather( + src=grad_output, + src_indices=comm_plan.boundary_vertex_idx, + dst_indices=send_buf_locs, + output=send_buf, + ) + total_recv_rows = sum(comm_plan.boundary_edge_splits) + recv_buffer = torch.empty( + num_batches, total_recv_rows, num_features, device=device + ) + dist.all_to_all_single( + recv_buffer, + send_buf, + output_split_sizes=comm_plan.boundary_edge_splits, + input_split_sizes=comm_plan.boundary_vertex_splits, + ) + + grad_input = OptimizedLocalScatterGather( + src=recv_buffer, + src_indices=comm_plan.boundary_edge_idx, + dst_indices=comm_plan.boundary_edge_buffer_map, + output=grad_input, + ) + + return grad_input, None class GatherFunction(Function): From dc0ded2a9a1dee552375a004218518b22483d671 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Wed, 17 Dec 2025 11:09:03 -0800 Subject: [PATCH 38/48] Fix build issues and change op struct --- .../distributed/csrc/local_data_kernels.cuh | 101 +++++++++--------- .../distributed/csrc/torch_local_kernels.cu | 12 +-- DGraph/distributed/include/torch_local.hpp | 20 +++- 3 files changed, 75 insertions(+), 58 deletions(-) diff --git a/DGraph/distributed/csrc/local_data_kernels.cuh b/DGraph/distributed/csrc/local_data_kernels.cuh index e44931c..1b2ea2b 100644 --- a/DGraph/distributed/csrc/local_data_kernels.cuh +++ b/DGraph/distributed/csrc/local_data_kernels.cuh @@ -252,42 +252,36 @@ namespace Local } } - __device__ __forceinline__ float4 atomicAdd_float4(float4& cur_val, const float4 new_val) - { - atomicAdd(&(cur_val.x), new_val.x); - atomicAdd(&(cur_val.y), new_val.y); - atomicAdd(&(cur_val.z), new_val.z); - atomicAdd(&(cur_val.w), new_val.w); - return cur_val; - } + - __device__ __forceinline__ float4 set_float4(float& cur_val, const float new_val) + template + struct FloatAtomicAddOp { - cur_val = new_val; - return reinterpret_cast(cur_val); - } + __device__ __forceinline__ void operator()(T *cur_addr, const T new_val) + { + atomicAdd(cur_addr, new_val); + } + }; - __device__ __forceinline__ float atomicAdd_float(float& cur_val, const float new_val) + template + struct FloatSetOp { - atomicAdd(&cur_val, new_val); - return cur_val; - } + __device__ __forceinline__ void operator()(T *cur_addr, const T new_val) + { + *cur_addr = new_val; + } + }; - __device__ __forceinline__ float set_float(float& cur_val, const float new_val) - { - cur_val = new_val; - return cur_val; - } /** - * + * * Masked Gather Kernel operation that performs the operation: Y [mask[i]] = Op(Y [mask[i]], X [indices[i]]) - + where Y is the output matrix, X is the input matrix, indices is the index matrix, and mask is the mask matrix. */ - template + template __global__ void Masked_Scatter_Gather_Kernel( const float *__restrict__ values, const long *__restrict__ indices, @@ -306,6 +300,8 @@ namespace Local const size_t nthreadsy = gridDim.y * blockDim.y; const size_t nthreadsz = gridDim.z * blockDim.z; + Op op; + for (size_t mb_i = gidz; mb_i < mini_batch_size; mb_i += nthreadsz) { const auto values_offset = mb_i * num_cols * num_indices; @@ -317,12 +313,12 @@ namespace Local { const auto output_row = mask[mask_offset + row]; const auto input_row = indices[ind_offset + row]; - + for (size_t col = gidx; col < num_cols; col += nthreadsx) { - auto &output_val = output[output_offset + output_row * num_cols + col]; + auto *output_addr = &output[output_offset + output_row * num_cols + col]; const auto input_val = values[values_offset + input_row * num_cols + col]; - output_val = op(output_val, input_val); + op(output_addr, input_val); } } } @@ -334,20 +330,20 @@ namespace Local Y [mask[i]] = X [indices[i]] This kernel is optimized for the case where the num_cols is a multiple of 4. - + where Y is the output matrix, X is the input matrix, indices is the index matrix, and mask is the mask matrix. */ - template - __global__ void Optimized_Masked_Scatter_Gather_Kernel( - const float *__restrict__ values, - const long *__restrict__ indices, - const long *__restrict__ mask, - float *__restrict__ output, - const int mini_batch_size, - const int num_indices, - const int num_cols, - const int num_output_rows) - { + template + __global__ void Optimized_Masked_Scatter_Gather_Kernel( + const float *__restrict__ values, + const long *__restrict__ indices, + const long *__restrict__ mask, + float *__restrict__ output, + const int mini_batch_size, + const int num_indices, + const int num_cols, + const int num_output_rows) + { const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; @@ -357,6 +353,8 @@ namespace Local const size_t nthreadsz = gridDim.z * blockDim.z; // Grid-stride loop over mini-batches + + Op binary_operator; for (size_t mb_i = gidz; mb_i < mini_batch_size; mb_i += nthreadsz) { const auto values_offset = mb_i * num_cols / 4 * num_indices; @@ -368,28 +366,29 @@ namespace Local for (size_t row = gidy; row < num_indices; row += nthreadsy) { long output_row, input_row; - - if (threadIdx.x == 0) { - output_row = mask[mask_offset + row]; - input_row = indices[ind_offset + row]; + + if (threadIdx.x == 0) + { + output_row = mask[mask_offset + row]; + input_row = indices[ind_offset + row]; } output_row = __shfl_sync(0xFFFFFFFF, output_row, 0); input_row = __shfl_sync(0xFFFFFFFF, input_row, 0); - - output_row = mask[mask_offset + row]; + + output_row = mask[mask_offset + row]; input_row = indices[ind_offset + row]; size_t col = gidx; - + for (; col < num_cols / 4; col += nthreadsx) { - const float4 values_vec = reinterpret_cast(values)[values_offset + input_row * num_cols / 4 + col]; - float4& output_vec = reinterpret_cast(output)[output_offset + output_row * num_cols / 4 + col]; - output_vec = op(output_vec, values_vec); + const float4 values_vec = reinterpret_cast(values)[values_offset + input_row * num_cols / 4 + col]; + float4* output_addr = &reinterpret_cast(output)[output_offset + output_row * num_cols / 4 + col]; + binary_operator(output_addr, values_vec); } } } - } - + } + } // namespace Local \ No newline at end of file diff --git a/DGraph/distributed/csrc/torch_local_kernels.cu b/DGraph/distributed/csrc/torch_local_kernels.cu index f586a50..896050f 100644 --- a/DGraph/distributed/csrc/torch_local_kernels.cu +++ b/DGraph/distributed/csrc/torch_local_kernels.cu @@ -132,7 +132,7 @@ torch::Tensor local_masked_scatter_gather(torch::Tensor input, const float *input_ptr = input.data_ptr(); const long *indices_ptr = indices.data_ptr(); - const float *mask_ptr = mask.data_ptr(); + const long *mask_ptr = mask.data_ptr(); float *output_ptr = output.data_ptr(); dim3 block_dims, grid_dims; @@ -150,7 +150,7 @@ torch::Tensor local_masked_scatter_gather(torch::Tensor input, if (num_cols % 4 != 0) { - Local::Masked_Scatter_Gather_Kernel<<>>(input_ptr, + Local::Masked_Scatter_Gather_Kernel><<>>(input_ptr, indices_ptr, mask_ptr, output_ptr, @@ -161,7 +161,7 @@ torch::Tensor local_masked_scatter_gather(torch::Tensor input, } else { - Local::Optimized_Masked_Scatter_Gather_Kernel<<>>(input_ptr, + Local::Optimized_Masked_Scatter_Gather_Kernel><<>>(input_ptr, indices_ptr, mask_ptr, output_ptr, @@ -190,7 +190,7 @@ torch::Tensor local_masked_scatter_add_gather(torch::Tensor input, const float *input_ptr = input.data_ptr(); const long *indices_ptr = indices.data_ptr(); - const float *mask_ptr = mask.data_ptr(); + const long *mask_ptr = mask.data_ptr(); float *output_ptr = output.data_ptr(); dim3 block_dims, grid_dims; @@ -208,7 +208,7 @@ torch::Tensor local_masked_scatter_add_gather(torch::Tensor input, if (num_cols % 4 != 0) { - Local::Masked_Scatter_Gather_Kernel<<>>(input_ptr, + Local::Masked_Scatter_Gather_Kernel><<>>(input_ptr, indices_ptr, mask_ptr, output_ptr, @@ -219,7 +219,7 @@ torch::Tensor local_masked_scatter_add_gather(torch::Tensor input, } else { - Local::Optimized_Masked_Scatter_Gather_KernelM<<>>(input_ptr, + Local::Optimized_Masked_Scatter_Gather_Kernel><<>>(input_ptr, indices_ptr, mask_ptr, output_ptr, diff --git a/DGraph/distributed/include/torch_local.hpp b/DGraph/distributed/include/torch_local.hpp index f780160..4666d38 100644 --- a/DGraph/distributed/include/torch_local.hpp +++ b/DGraph/distributed/include/torch_local.hpp @@ -19,4 +19,22 @@ torch::Tensor local_masked_scatter(torch::Tensor input, const int num_values_rows, const int num_cols, const int num_output_rows, - const int rank); \ No newline at end of file + const int rank); + +torch::Tensor local_masked_scatter_gather(torch::Tensor input, + torch::Tensor indices, + torch::Tensor rank_local_placement, + torch::Tensor output, + const int num_batches, + const int num_values_rows, + const int num_cols, + const int num_output_rows); + +torch::Tensor local_masked_scatter_add_gather(torch::Tensor grad_output, + torch::Tensor input, + torch::Tensor indices, + torch::Tensor rank_local_placement, + const int num_batches, + const int num_values_rows, + const int num_cols, + const int num_output_rows); \ No newline at end of file From f559bcfb2828e93290cd205f3208a87d76982108 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Wed, 17 Dec 2025 11:30:08 -0800 Subject: [PATCH 39/48] Remove unnecesary imports and remove cache implementation --- DGraph/distributed/nccl/NCCLBackendEngine.py | 61 ++++++++++---------- 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/DGraph/distributed/nccl/NCCLBackendEngine.py b/DGraph/distributed/nccl/NCCLBackendEngine.py index 4af1eb8..b8d2fd0 100644 --- a/DGraph/distributed/nccl/NCCLBackendEngine.py +++ b/DGraph/distributed/nccl/NCCLBackendEngine.py @@ -16,26 +16,12 @@ import torch import torch.distributed as dist from DGraph.distributed.Engine import BackendEngine -from DGraph.distributed.nccl._indices_utils import ( - _generate_local_rank_mapping, - _get_local_unique_recv_placement, -) -from DGraph.distributed.nccl._nccl_cache import NCCLGatherCache, NCCLScatterCache -from DGraph.distributed.nccl.alltoallv_impl import ( - _nccl_alltoall_v, - _nccl_alltoallv_with_dict, -) -from DGraph.distributed.RankLocalOps import ( - RankLocalMaskedGather, - RankLocalMaskedScatter, - RankLocalRenumberingWithMapping, - OptimizedRankLocalMaskedGather, -) +from DGraph.distributed.nccl._NCCLCommPlan import NCCLGraphCommPlan from DGraph.distributed.nccl._torch_func_impl import ( GatherFunction, ScatterFunction, - Cached_Static_ScatterFunction, - Cached_Static_GatherFunction, + CommPlan_ScatterFunction, + CommPlan_GatherFunction, ) from torch.autograd import Function @@ -121,7 +107,7 @@ def scatter( self, local_send_tensor: torch.Tensor, *, - cache: NCCLScatterCache, + comm_plan: NCCLGraphCommPlan, ) -> torch.Tensor: ... def scatter( @@ -130,12 +116,11 @@ def scatter( indices: Optional[torch.Tensor] = None, rank_mappings: Optional[torch.Tensor] = None, output_size: Optional[int] = None, - cache: Optional[NCCLScatterCache] = None, + comm_plan: Optional[NCCLGraphCommPlan] = None, ) -> torch.Tensor: - if cache is not None: - return Cached_Static_ScatterFunction.apply(local_send_tensor, cache) # type: ignore - + if comm_plan is not None: + return CommPlan_ScatterFunction.apply(local_send_tensor, comm_plan) # type: ignore else: if indices is None or rank_mappings is None or output_size is None: raise ValueError( @@ -177,12 +162,30 @@ def scatter( return output_tensor # type: ignore + @overload def gather( self, local_send_tensor: torch.Tensor, indices: torch.Tensor, rank_mappings: torch.Tensor, - cache: Optional[NCCLGatherCache] = None, + **kwargs, + ) -> torch.Tensor: ... + + @overload + def gather( + self, + local_send_tensor: torch.Tensor, + *, + comm_plan: NCCLGraphCommPlan, + **kwargs, + ) -> torch.Tensor: ... + + def gather( + self, + local_send_tensor: torch.Tensor, + indices: torch.Tensor, + rank_mappings: torch.Tensor, + comm_plan: Optional[NCCLGraphCommPlan] = None, **kwargs, ) -> torch.Tensor: """Gather the distributed tensor across all ranks according to the indices @@ -208,6 +211,9 @@ def gather( rank_mappings (torch.Tensor): The rank mappings for the gather operation """ + if comm_plan is not None: + return CommPlan_GatherFunction.apply(local_send_tensor, comm_plan) # type: ignore + send_tensor_shape = local_send_tensor.shape b_size = send_tensor_shape[0] world_size = self.get_world_size() @@ -231,14 +237,6 @@ def gather( send_rank = rank_mappings[0] recv_rank = rank_mappings[1] - use_cache = cache is not None - - if use_cache: - assert type(cache) == NCCLGatherCache, f"Invalid cache type {type(cache)}" - gather_cache = cache - else: - gather_cache = None - output_tensor = GatherFunction.apply( local_send_tensor, indices, @@ -246,7 +244,6 @@ def gather( recv_rank, rank, world_size, - gather_cache, ) dist.barrier() From f2c3915cada8f38b61b21248a24b9cc23f481d66 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Wed, 17 Dec 2025 13:48:38 -0800 Subject: [PATCH 40/48] Decompose internal function to reduce memory usage - Allows GC to deallocate transient maasking arrays --- DGraph/distributed/nccl/_NCCLCommPlan.py | 81 ++++++++++++++++++------ 1 file changed, 60 insertions(+), 21 deletions(-) diff --git a/DGraph/distributed/nccl/_NCCLCommPlan.py b/DGraph/distributed/nccl/_NCCLCommPlan.py index 5dcf079..9ca39f4 100644 --- a/DGraph/distributed/nccl/_NCCLCommPlan.py +++ b/DGraph/distributed/nccl/_NCCLCommPlan.py @@ -58,6 +58,33 @@ def to(self, device: torch.device): self.boundary_vertex_idx = self.boundary_vertex_idx.to(device) +def compute_edge_slices(dest_ranks, rank, my_dst_global, offset): + + is_internal = dest_ranks == rank + internal_dst_global = my_dst_global[is_internal] + internal_node_idx = internal_dst_global - offset[rank + 1] + + internal_edge_indices = torch.nonzero(is_internal, as_tuple=True)[0] + + remote_mask = ~is_internal + + boundary_edge_indices = torch.nonzero(remote_mask, as_tuple=True)[0] + + print(f"rank {rank} has {torch.sum(is_internal).item()} internal edges") + print(f"rank {rank} has {torch.sum(remote_mask).item()} remote edges") + + b_dst_global = my_dst_global[remote_mask] + b_dest_ranks = dest_ranks[remote_mask] + + return ( + internal_node_idx, + internal_edge_indices, + b_dst_global, + b_dest_ranks, + boundary_edge_indices, + ) + + def COO_to_NCCLCommPlan( rank: int, world_size: int, @@ -67,6 +94,22 @@ def COO_to_NCCLCommPlan( local_edge_list: torch.Tensor, offset: torch.Tensor, ) -> NCCLGraphCommPlan: + """ + + Convert COO (Coordinate List) format graph to NCCLGraphCommPlan for distributed gather-scatter operations. + + Args: + rank (int): Local rank + world_size (int): World size + global_edges_src (torch.Tensor): Global source indices of edges + global_edges_dst (torch.Tensor): Global destination indices of edges + vertex_rank_placement (torch.Tensor): Rank placement of vertices + local_edge_list (torch.Tensor): List of indices of local edges + offset (torch.Tensor): Offset for each rank. + The vertices are partitioned among ranks in a contiguous manner. + All vertices in the range [offset[rank], offset[rank + 1]) are assigned to the rank. + + """ device = local_edge_list.device my_src_global = global_edges_src[local_edge_list].to(device) my_dst_global = global_edges_dst[local_edge_list].to(device) @@ -78,20 +121,16 @@ def COO_to_NCCLCommPlan( num_local_vertices = int(nodes_per_rank[rank].item()) num_local_edges = local_edge_list.size(0) - dest_ranks = torch.bucketize(global_edges_dst, offset, right=True) - 1 - - is_internal = dest_ranks == rank - internal_dst_global = my_dst_global[is_internal] - internal_node_idx = internal_dst_global - offset - - internal_edge_indices = torch.nonzero(is_internal, as_tuple=True)[0] + dest_ranks = torch.bucketize(my_dst_global, offset, right=True) - 1 - remote_mask = ~is_internal - - boundary_edge_indices = torch.nonzero(remote_mask, as_tuple=True)[0] - - b_dst_global = my_dst_global[remote_mask] - b_dest_ranks = dest_ranks[remote_mask] + # Seperate this out to reduce memory usage + ( + internal_node_idx, + internal_edge_indices, + b_dst_global, + b_dest_ranks, + boundary_edge_indices, + ) = compute_edge_slices(dest_ranks, rank, my_dst_global, offset) sort_idx = torch.argsort(b_dest_ranks) boundary_edge_indices = boundary_edge_indices[sort_idx] @@ -108,22 +147,22 @@ def COO_to_NCCLCommPlan( boundary_edge_splits = torch.bincount(unique_ranks, minlength=world_size).tolist() - recv_counts_tensor = torch.empty(world_size, dtype=torch.long, device=device) + recv_counts_tensor = torch.zeros(world_size, dtype=torch.long, device=device) send_counts_tensor = torch.tensor( boundary_edge_splits, dtype=torch.long, device=device ) - dist.all_to_all_single(recv_counts_tensor, send_counts_tensor) + # dist.all_to_all_single(recv_counts_tensor, send_counts_tensor) boundary_node_splits = recv_counts_tensor.tolist() total_recv_nodes = sum(boundary_node_splits) recv_global_ids = torch.empty(total_recv_nodes, dtype=torch.long, device=device) - dist.all_to_all_single( - recv_global_ids, - unique_global_ids, - output_split_sizes=boundary_node_splits, - input_split_sizes=boundary_edge_splits, - ) + # dist.all_to_all_single( + # recv_global_ids, + # unique_global_ids, + # output_split_sizes=boundary_node_splits, + # input_split_sizes=boundary_edge_splits, + # ) boundary_node_idx = recv_global_ids - my_start From bf8c6ac90c0b3397da50e50e84a2703cbe3b01c3 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Wed, 17 Dec 2025 14:17:51 -0800 Subject: [PATCH 41/48] Optimized CommPlan generator - Remove unnecessary arrays and memory inefficiencies --- DGraph/distributed/nccl/_NCCLCommPlan.py | 45 +++++++++++------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/DGraph/distributed/nccl/_NCCLCommPlan.py b/DGraph/distributed/nccl/_NCCLCommPlan.py index 9ca39f4..1c0232c 100644 --- a/DGraph/distributed/nccl/_NCCLCommPlan.py +++ b/DGraph/distributed/nccl/_NCCLCommPlan.py @@ -70,9 +70,6 @@ def compute_edge_slices(dest_ranks, rank, my_dst_global, offset): boundary_edge_indices = torch.nonzero(remote_mask, as_tuple=True)[0] - print(f"rank {rank} has {torch.sum(is_internal).item()} internal edges") - print(f"rank {rank} has {torch.sum(remote_mask).item()} remote edges") - b_dst_global = my_dst_global[remote_mask] b_dest_ranks = dest_ranks[remote_mask] @@ -85,12 +82,20 @@ def compute_edge_slices(dest_ranks, rank, my_dst_global, offset): ) +def fast_2D_unique(indices_1, indices_2): + packed_keys = indices_1.to(torch.int64) << 32 | indices_2.to(torch.int64) + unique_packed, inverse_indices = torch.unique( + packed_keys, return_inverse=True, sorted=False + ) + unique_1 = unique_packed >> 32 + unique_2 = unique_packed & 0xFFFFFFFF + return unique_1, unique_2, inverse_indices + + def COO_to_NCCLCommPlan( rank: int, world_size: int, - global_edges_src: torch.Tensor, global_edges_dst: torch.Tensor, - vertex_rank_placement: torch.Tensor, local_edge_list: torch.Tensor, offset: torch.Tensor, ) -> NCCLGraphCommPlan: @@ -111,14 +116,11 @@ def COO_to_NCCLCommPlan( """ device = local_edge_list.device - my_src_global = global_edges_src[local_edge_list].to(device) my_dst_global = global_edges_dst[local_edge_list].to(device) my_start = offset[rank].item() my_end = offset[rank + 1].item() - - nodes_per_rank = torch.bincount(vertex_rank_placement, minlength=world_size) - num_local_vertices = int(nodes_per_rank[rank].item()) + num_local_vertices = int(my_end - my_start) num_local_edges = local_edge_list.size(0) dest_ranks = torch.bucketize(my_dst_global, offset, right=True) - 1 @@ -132,16 +134,9 @@ def COO_to_NCCLCommPlan( boundary_edge_indices, ) = compute_edge_slices(dest_ranks, rank, my_dst_global, offset) - sort_idx = torch.argsort(b_dest_ranks) - boundary_edge_indices = boundary_edge_indices[sort_idx] - b_dst_global = b_dst_global[sort_idx] - b_dest_ranks = b_dest_ranks[sort_idx] - - unique_dests, inverse_indices = torch.unique( - torch.stack([b_dest_ranks, b_dst_global]), dim=1, return_inverse=True + unique_ranks, unique_global_ids, inverse_indices = fast_2D_unique( + b_dest_ranks, b_dst_global ) - unique_ranks = unique_dests[0] - unique_global_ids = unique_dests[1] boundary_edge_buffer_map = inverse_indices @@ -151,18 +146,18 @@ def COO_to_NCCLCommPlan( send_counts_tensor = torch.tensor( boundary_edge_splits, dtype=torch.long, device=device ) - # dist.all_to_all_single(recv_counts_tensor, send_counts_tensor) + dist.all_to_all_single(recv_counts_tensor, send_counts_tensor) boundary_node_splits = recv_counts_tensor.tolist() total_recv_nodes = sum(boundary_node_splits) recv_global_ids = torch.empty(total_recv_nodes, dtype=torch.long, device=device) - # dist.all_to_all_single( - # recv_global_ids, - # unique_global_ids, - # output_split_sizes=boundary_node_splits, - # input_split_sizes=boundary_edge_splits, - # ) + dist.all_to_all_single( + recv_global_ids, + unique_global_ids, + output_split_sizes=boundary_node_splits, + input_split_sizes=boundary_edge_splits, + ) boundary_node_idx = recv_global_ids - my_start From 934d532eb30aea2015a5ca30f0d0101ccc3a750e Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Wed, 17 Dec 2025 15:25:47 -0800 Subject: [PATCH 42/48] Fix node size check --- DGraph/distributed/nccl/_NCCLCommPlan.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/DGraph/distributed/nccl/_NCCLCommPlan.py b/DGraph/distributed/nccl/_NCCLCommPlan.py index 1c0232c..4d1828c 100644 --- a/DGraph/distributed/nccl/_NCCLCommPlan.py +++ b/DGraph/distributed/nccl/_NCCLCommPlan.py @@ -118,6 +118,11 @@ def COO_to_NCCLCommPlan( device = local_edge_list.device my_dst_global = global_edges_dst[local_edge_list].to(device) + if int(offset[-1].item()) > (2**32): + raise ValueError( + f"{offset[-1]}, Number of vertices exceeding {2**32}, which is not supported" + ) + my_start = offset[rank].item() my_end = offset[rank + 1].item() num_local_vertices = int(my_end - my_start) @@ -138,6 +143,14 @@ def COO_to_NCCLCommPlan( b_dest_ranks, b_dst_global ) + print(f"Rank {rank} has {len(boundary_edge_indices)} edges to send ") + print(f"Rank {rank} has {len(unique_ranks)} unique messages to send ") + + if len(unique_ranks) > 0: + print( + f"Rank {rank} message reduction ratio: {len(boundary_edge_indices)/len(unique_ranks)}" + ) + boundary_edge_buffer_map = inverse_indices boundary_edge_splits = torch.bincount(unique_ranks, minlength=world_size).tolist() From 25c3edff05df76a782b916665f736de98524415f Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Wed, 17 Dec 2025 23:21:29 -0800 Subject: [PATCH 43/48] Add edge-conditioned graph plan to hold full edge communication info --- DGraph/distributed/nccl/_NCCLCommPlan.py | 70 ++++++++++++++++++++ experiments/OGB-LSC/RGAT.py | 83 +++++++++++++++++++++++- 2 files changed, 151 insertions(+), 2 deletions(-) diff --git a/DGraph/distributed/nccl/_NCCLCommPlan.py b/DGraph/distributed/nccl/_NCCLCommPlan.py index 4d1828c..91875c0 100644 --- a/DGraph/distributed/nccl/_NCCLCommPlan.py +++ b/DGraph/distributed/nccl/_NCCLCommPlan.py @@ -56,6 +56,28 @@ def to(self, device: torch.device): self.boundary_edge_idx = self.boundary_edge_idx.to(device) self.boundary_edge_buffer_map = self.boundary_edge_buffer_map.to(device) self.boundary_vertex_idx = self.boundary_vertex_idx.to(device) + return self + + +@dataclass +class NCCLEdgeConditionedGraphCommPlan: + """ + Class to store communication plan for distributed gather-scatter for edge-conditioned + graphs where both source and destination vertices are needed. + + Attributes: + rank (int): Local rank + world_size (int): World size + + source_graph_plan (NCCLGraphCommPlan): Communication plan for source vertices + dest_graph_plan (NCCLGraphCommPlan): Communication plan for destination vertices + """ + + rank: int + world_size: int + + source_graph_plan: NCCLGraphCommPlan + dest_graph_plan: NCCLGraphCommPlan def compute_edge_slices(dest_ranks, rank, my_dst_global, offset): @@ -187,3 +209,51 @@ def COO_to_NCCLCommPlan( boundary_vertex_idx=boundary_node_idx, boundary_vertex_splits=boundary_node_splits, ) + + +def COO_to_NCCLEdgeConditionedCommPlan( + rank: int, + world_size: int, + global_edges_src: torch.Tensor, + global_edges_dst: torch.Tensor, + local_edge_list: torch.Tensor, + offset: torch.Tensor, +) -> NCCLEdgeConditionedGraphCommPlan: + """ + + Convert COO (Coordinate List) format graph to NCCLEdgeConditionedGraphCommPlan for distributed gather-scatter operations. + + Args: + rank (int): Local rank + world_size (int): World size + global_edges_src (torch.Tensor): Global source indices of edges + global_edges_dst (torch.Tensor): Global destination indices of edges + local_edge_list (torch.Tensor): List of indices of local edges + offset (torch.Tensor): Offset for each rank. + The vertices are partitioned among ranks in a contiguous manner. + All vertices in the range [offset[rank], offset[rank + 1]) are assigned to the rank. + """ + device = local_edge_list.device + + source_plan = COO_to_NCCLCommPlan( + rank, + world_size, + global_edges_src, + local_edge_list, + offset, + ) + + dest_plan = COO_to_NCCLCommPlan( + rank, + world_size, + global_edges_dst, + local_edge_list, + offset, + ) + + return NCCLEdgeConditionedGraphCommPlan( + rank=rank, + world_size=world_size, + source_graph_plan=source_plan, + dest_graph_plan=dest_plan, + ) diff --git a/experiments/OGB-LSC/RGAT.py b/experiments/OGB-LSC/RGAT.py index 4029475..f5d1a9a 100644 --- a/experiments/OGB-LSC/RGAT.py +++ b/experiments/OGB-LSC/RGAT.py @@ -18,8 +18,9 @@ from distributed_layers import DistributedBatchNorm1D import os.path as osp from CacheGenerator import get_cache -import sys import os +from typing import Any, Optional, overload +from DGraph.distributed.nccl import NCCLBackendEngine, NCCLGraphCommPlan class ConvLayer(nn.Module): @@ -61,7 +62,83 @@ def __init__( else: self.register_parameter("bias", None) + @overload def forward( + self, + x: torch.Tensor, + comm_plan: NCCLGraphCommPlan, + *, + x_j: Optional[torch.Tensor] = None, + ): ... + + @overload + def forward( + self, + x: torch.Tensor, + *, + edge_index: Any, + rank_mapping: Any, + x_j: Optional[torch.Tensor] = None, + src_gather_cache: Optional[Any] = None, + dest_gather_cache: Optional[Any] = None, + dest_scatter_cache: Optional[Any] = None, + ): ... + + def forward( + self, + x, + comm_plan=None, + *, + edge_index=None, + rank_mapping=None, + x_j=None, + src_gather_cache=None, + dest_gather_cache=None, + dest_scatter_cache=None, + ): + """Forward method that can use either a communication plan or COO format + + Args: + x: Node features tensor + comm_plan: Communication plan object (if available) + edge_index: Edge index tensor in COO format + rank_mapping: Rank mapping tensors + x_j: Optional source node features tensor (for hetero graphs) + src_gather_cache: Optional cache for source gather communication + dest_gather_cache: Optional cache for destination gather communication + dest_scatter_cache: Optional cache for destination scatter communication + + Returns: + out: Output node features tensor + """ + if comm_plan is not None: + return self._forward_comm_plan(x, comm_plan, x_j=x_j) + + return self._forward_coo( + x, + edge_index=edge_index, + rank_mapping=rank_mapping, + x_j=x_j, + src_gather_cache=src_gather_cache, + dest_gather_cache=dest_gather_cache, + dest_scatter_cache=dest_scatter_cache, + ) + + def _forward_comm_plan(self, x, comm_plan, x_j=None): + h = self.conv1(x) + + if self.hetero: + assert x_j is not None + h_j = self.conv1(x_j) + else: + h_j = h + + assert isinstance(self.comm.__backend_engine, NCCLBackendEngine) + + h_i = self.comm.__backend_engine.gather(h, comm_plan=comm_plan) + h_j = self.comm.__backend_engine.gather(h_j, comm_plan=comm_plan) + + def _forward_coo( self, x, edge_index, @@ -328,7 +405,9 @@ def forward(self, xs, adjts, edge_types, rank_mappings): # Dummy operation to touch all outs to avoid DDP's 'unused parameters' dummy = torch.zeros(1, device=outs[0].device, dtype=outs[0].dtype) for t in outs: - dummy = dummy + (t[0].sum() * 0.0) # zero-valued scalar that depends on t + dummy = dummy + ( + t[0].sum() * 0.0 + ) # zero-valued scalar that depends on t outs[0][0] = outs[0][0] + dummy return self.mlp(outs[0]) From 16dd5ad158737b04328617728892dee4f049443d Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Wed, 17 Dec 2025 23:34:28 -0800 Subject: [PATCH 44/48] Update GAT implementation with new comm plan --- DGraph/distributed/nccl/_NCCLCommPlan.py | 10 ++- DGraph/distributed/nccl/__init__.py | 6 +- experiments/OGB-LSC/RGAT.py | 81 ++++++++++++++++++++---- 3 files changed, 80 insertions(+), 17 deletions(-) diff --git a/DGraph/distributed/nccl/_NCCLCommPlan.py b/DGraph/distributed/nccl/_NCCLCommPlan.py index 91875c0..ff0d5e2 100644 --- a/DGraph/distributed/nccl/_NCCLCommPlan.py +++ b/DGraph/distributed/nccl/_NCCLCommPlan.py @@ -1,6 +1,6 @@ import torch from dataclasses import dataclass -from typing import List +from typing import List, Optional import torch.distributed as dist @@ -77,7 +77,13 @@ class NCCLEdgeConditionedGraphCommPlan: world_size: int source_graph_plan: NCCLGraphCommPlan - dest_graph_plan: NCCLGraphCommPlan + dest_graph_plan: Optional[NCCLGraphCommPlan] = None + + def to(self, device: torch.device): + self.source_graph_plan = self.source_graph_plan.to(device) + if self.dest_graph_plan is not None: + self.dest_graph_plan = self.dest_graph_plan.to(device) + return self def compute_edge_slices(dest_ranks, rank, my_dst_global, offset): diff --git a/DGraph/distributed/nccl/__init__.py b/DGraph/distributed/nccl/__init__.py index d3ca0b4..9202116 100644 --- a/DGraph/distributed/nccl/__init__.py +++ b/DGraph/distributed/nccl/__init__.py @@ -12,4 +12,8 @@ # # SPDX-License-Identifier: (Apache-2.0) from DGraph.distributed.nccl.NCCLBackendEngine import NCCLBackendEngine, TIMINGS -from DGraph.distributed.nccl._NCCLCommPlan import NCCLGraphCommPlan, COO_to_NCCLCommPlan +from DGraph.distributed.nccl._NCCLCommPlan import ( + NCCLGraphCommPlan, + NCCLEdgeConditionedGraphCommPlan, + COO_to_NCCLCommPlan, +) diff --git a/experiments/OGB-LSC/RGAT.py b/experiments/OGB-LSC/RGAT.py index f5d1a9a..778616e 100644 --- a/experiments/OGB-LSC/RGAT.py +++ b/experiments/OGB-LSC/RGAT.py @@ -20,7 +20,11 @@ from CacheGenerator import get_cache import os from typing import Any, Optional, overload -from DGraph.distributed.nccl import NCCLBackendEngine, NCCLGraphCommPlan +from DGraph.distributed.nccl import ( + NCCLBackendEngine, + NCCLGraphCommPlan, + NCCLEdgeConditionedGraphCommPlan, +) class ConvLayer(nn.Module): @@ -66,7 +70,7 @@ def __init__( def forward( self, x: torch.Tensor, - comm_plan: NCCLGraphCommPlan, + comm_plan: NCCLEdgeConditionedGraphCommPlan, *, x_j: Optional[torch.Tensor] = None, ): ... @@ -124,19 +128,72 @@ def forward( dest_scatter_cache=dest_scatter_cache, ) - def _forward_comm_plan(self, x, comm_plan, x_j=None): + def _process_messages( + self, + h, + h_j, + ): + messages = torch.cat([h, h_j], dim=-1) + edge_scores = self.leaky_relu(self.project_message(messages)) + numerator = torch.exp(edge_scores) + return numerator + + def _calc_attention_messages( + self, + neighbor_features, + numerator, + denominator, + ): + alpha_ij = numerator / (denominator + 1e-16) + attention_messages = neighbor_features * alpha_ij + return attention_messages + + def _apply_res_and_bias(self, out, x): + if self.residual: + out = out + self.res_net(x) + if self.bias is not None: + out = out + self.bias + return out + + def _forward_comm_plan( + self, x, comm_plan: NCCLEdgeConditionedGraphCommPlan, x_j=None + ): h = self.conv1(x) + source_graph_plan = comm_plan.source_graph_plan if self.hetero: assert x_j is not None h_j = self.conv1(x_j) + assert comm_plan.dest_graph_plan is not None + dest_graph_plan = comm_plan.dest_graph_plan else: h_j = h + dest_graph_plan = source_graph_plan assert isinstance(self.comm.__backend_engine, NCCLBackendEngine) - h_i = self.comm.__backend_engine.gather(h, comm_plan=comm_plan) - h_j = self.comm.__backend_engine.gather(h_j, comm_plan=comm_plan) + h_i = self.comm.__backend_engine.gather(h, comm_plan=source_graph_plan) + + h_j = self.comm.__backend_engine.gather(h_j, comm_plan=dest_graph_plan) + + numerator = self._process_messages(h_i, h_j) + + denominator = self.comm.__backend_engine.scatter( + numerator, comm_plan=source_graph_plan + ) + + denominator = self.comm.__backend_engine.gather( + denominator, comm_plan=dest_graph_plan + ) + + attention_messages = self._calc_attention_messages(h_j, numerator, denominator) + + out = self.comm.__backend_engine.scatter( + attention_messages, comm_plan=source_graph_plan + ) + out = self._apply_res_and_bias(out, x) + + return out def _forward_coo( self, @@ -172,9 +229,7 @@ def _forward_coo( h_j, _src_indices, _src_rank_mappings, cache=src_gather_cache ) - messages = torch.cat([h_i, h_j], dim=-1) - edge_scores = self.leaky_relu(self.project_message(messages)) - numerator = torch.exp(edge_scores) + numerator = self._process_messages(h_i, h_j) denominator = self.comm.scatter( numerator, @@ -188,8 +243,8 @@ def _forward_coo( denominator, _src_indices, _src_rank_mappings, cache=dest_gather_cache ) - alpha_ij = numerator / (denominator + 1e-16) - attention_messages = h_j * alpha_ij + attention_messages = self._calc_attention_messages(h_j, numerator, denominator) + out = self.comm.scatter( attention_messages, _dst_indices, @@ -197,10 +252,8 @@ def _forward_coo( h.size(1), cache=dest_scatter_cache, ) - if self.residual: - out = out + self.res_net(x) - if self.bias is not None: - out = out + self.bias + + out = self._apply_res_and_bias(out, x) return out From ceb6d906e2c78d0f16014d1894255ee306ed4f6f Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Wed, 17 Dec 2025 23:37:30 -0800 Subject: [PATCH 45/48] Remove pyg_wrapper function --- experiments/OGB-LSC/pyg_wrappers.py | 44 ----------------------------- 1 file changed, 44 deletions(-) delete mode 100644 experiments/OGB-LSC/pyg_wrappers.py diff --git a/experiments/OGB-LSC/pyg_wrappers.py b/experiments/OGB-LSC/pyg_wrappers.py deleted file mode 100644 index 7f76d57..0000000 --- a/experiments/OGB-LSC/pyg_wrappers.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. -# Produced at the Lawrence Livermore National Laboratory. -# Written by the LBANN Research Team (B. Van Essen, et al.) listed in -# the CONTRIBUTORS file. See the top-level LICENSE file for details. -# -# LLNL-CODE-697807. -# All rights reserved. -# -# This file is part of LBANN: Livermore Big Artificial Neural Network -# Toolkit. For details, see http://software.llnl.gov/LBANN or -# https://github.com/LBANN and https://github.com/LLNL/LBANN. -# -# SPDX-License-Identifier: (Apache-2.0) - - -class DGraphSparseTensor: - def __init__( - self, - row, - col, - value=None, - comm=None, - rank_mapping=None, - **kwargs, - ): - super(DGraphSparseTensor, self).__init__() - assert comm is not None, "Comm object cannot be None" - assert rank_mapping is not None, "rank_mapping cannot be None" - self.comm = comm - self.rank_mapping = rank_mapping - self.world_size = comm.get_world_size() - self.rank = comm.get_rank() - self.row = row - self.col = col - - def to(self, device): - self.row = self.row.to(device) - self.col = self.col.to(device) - if self.rank_mapping is not None: - self.rank_mapping = self.rank_mapping.to(device) - - if self.value is not None: - self.value = self.value.to(device) - return self From 673f8b6556685c96fa1fa5b599eac38d03e688fe Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Thu, 18 Dec 2025 10:02:22 -0800 Subject: [PATCH 46/48] Enable hetero-graphs in comm-plan --- DGraph/distributed/nccl/_NCCLCommPlan.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/DGraph/distributed/nccl/_NCCLCommPlan.py b/DGraph/distributed/nccl/_NCCLCommPlan.py index ff0d5e2..ad5a4e7 100644 --- a/DGraph/distributed/nccl/_NCCLCommPlan.py +++ b/DGraph/distributed/nccl/_NCCLCommPlan.py @@ -223,7 +223,8 @@ def COO_to_NCCLEdgeConditionedCommPlan( global_edges_src: torch.Tensor, global_edges_dst: torch.Tensor, local_edge_list: torch.Tensor, - offset: torch.Tensor, + src_offset: torch.Tensor, + dest_offset: Optional[torch.Tensor], ) -> NCCLEdgeConditionedGraphCommPlan: """ @@ -235,9 +236,12 @@ def COO_to_NCCLEdgeConditionedCommPlan( global_edges_src (torch.Tensor): Global source indices of edges global_edges_dst (torch.Tensor): Global destination indices of edges local_edge_list (torch.Tensor): List of indices of local edges - offset (torch.Tensor): Offset for each rank. + src_offset (torch.Tensor): Offset for each rank for source vertices. The vertices are partitioned among ranks in a contiguous manner. - All vertices in the range [offset[rank], offset[rank + 1]) are assigned to the rank. + All vertices in the range [src_offset[rank], src_offset[rank + 1]) are assigned to the rank. + dest_offset (Optional[torch.Tensor]): Offset for each rank for destination vertices. + The vertices are partitioned among ranks in a contiguous manner. + All vertices in the range [dest_offset[rank], dest_offset[rank + 1]) are assigned to the rank. """ device = local_edge_list.device @@ -246,15 +250,18 @@ def COO_to_NCCLEdgeConditionedCommPlan( world_size, global_edges_src, local_edge_list, - offset, + src_offset, ) + if dest_offset is None: + dest_offset = src_offset + dest_plan = COO_to_NCCLCommPlan( rank, world_size, global_edges_dst, local_edge_list, - offset, + dest_offset, ) return NCCLEdgeConditionedGraphCommPlan( From 88269a18cbe30d4e153f7871834d16dc64d91b36 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Thu, 18 Dec 2025 12:09:01 -0800 Subject: [PATCH 47/48] Fixed mismatched API on scatter_sum_gather --- DGraph/distributed/include/torch_local.hpp | 4 +- DGraph/distributed/nccl/__init__.py | 1 + experiments/OGB-LSC/mag240m/DGraph_MAG240M.py | 2 +- tests/test_local_kernels.py | 81 +++++++++++++++++++ 4 files changed, 85 insertions(+), 3 deletions(-) diff --git a/DGraph/distributed/include/torch_local.hpp b/DGraph/distributed/include/torch_local.hpp index 4666d38..7a4a258 100644 --- a/DGraph/distributed/include/torch_local.hpp +++ b/DGraph/distributed/include/torch_local.hpp @@ -30,10 +30,10 @@ torch::Tensor local_masked_scatter_gather(torch::Tensor input, const int num_cols, const int num_output_rows); -torch::Tensor local_masked_scatter_add_gather(torch::Tensor grad_output, - torch::Tensor input, +torch::Tensor local_masked_scatter_add_gather(torch::Tensor input, torch::Tensor indices, torch::Tensor rank_local_placement, + torch::Tensor output, const int num_batches, const int num_values_rows, const int num_cols, diff --git a/DGraph/distributed/nccl/__init__.py b/DGraph/distributed/nccl/__init__.py index 9202116..aae0291 100644 --- a/DGraph/distributed/nccl/__init__.py +++ b/DGraph/distributed/nccl/__init__.py @@ -16,4 +16,5 @@ NCCLGraphCommPlan, NCCLEdgeConditionedGraphCommPlan, COO_to_NCCLCommPlan, + COO_to_NCCLEdgeConditionedCommPlan, ) diff --git a/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py b/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py index 458048a..9b6313a 100644 --- a/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py +++ b/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: (Apache-2.0) from ogb.lsc import MAG240MDataset import torch -from typing import Optional +from typing import Optional, Tuple from torch_sparse import SparseTensor import numpy as np from tqdm import tqdm diff --git a/tests/test_local_kernels.py b/tests/test_local_kernels.py index 1544644..f3a44ac 100644 --- a/tests/test_local_kernels.py +++ b/tests/test_local_kernels.py @@ -67,3 +67,84 @@ def test_optimized_local_gather(): assert torch.allclose( out_tensor.cpu(), out_tensor_gt ), "Optimized local gather failed" + + +def test_optimized_scatter_gaher(): + try: + from torch_local import local_masked_scatter_gather + except ImportError as e: + pytest.fail(f"Failed to import local_masked_scatter_gather: {e}") + + num_src_rows = 8 + num_out_rows = 8 + bs = 1 + num_features = 4 + src_tensor = torch.randn(bs, num_src_rows, num_features) + src_indices = torch.tensor([0, 3, 2, 1]) + dst_indices = torch.tensor([1, 3, 5, 7]) + + out_tensor_gt = torch.zeros(bs, num_out_rows, num_features) + + for i in range(bs): + for j in range(len(src_indices)): + out_tensor_gt[i, dst_indices[j]] = src_tensor[i, src_indices[j]] + out_tensor_gt = out_tensor_gt.view(bs, num_out_rows, num_features) + out_tensor = torch.zeros_like(out_tensor_gt) + out_tensor = out_tensor.cuda() + src_tensor = src_tensor.cuda() + src_indices = src_indices.cuda().long() + dst_indices = dst_indices.cuda().long() + local_masked_scatter_gather( + src_tensor, + src_indices, + dst_indices, + out_tensor, + bs, + num_src_rows, + num_features, + num_out_rows, + ) + assert torch.allclose( + out_tensor.cpu(), out_tensor_gt + ), "Optimized local scatter-gather failed" + + +def test_optimized_scatter_add_gather(): + try: + from torch_local import local_masked_scatter_add_gather + except ImportError as e: + pytest.fail(f"Failed to import local_masked_scatter_add_gather: {e}") + + num_src_rows = 8 + num_out_rows = 8 + bs = 1 + num_features = 4 + src_tensor = torch.randn(bs, num_src_rows, num_features) + src_indices = torch.tensor([0, 3, 2, 1, 3]) + dst_indices = torch.tensor([1, 3, 5, 7, 3]) + + out_tensor_gt = torch.zeros(bs, num_out_rows, num_features) + + for i in range(bs): + for j in range(len(src_indices)): + out_tensor_gt[i, dst_indices[j]] += src_tensor[i, src_indices[j]] + + out_tensor_gt = out_tensor_gt.view(bs, num_out_rows, num_features) + out_tensor = torch.zeros_like(out_tensor_gt) + out_tensor = out_tensor.cuda() + src_tensor = src_tensor.cuda() + src_indices = src_indices.cuda().long() + dst_indices = dst_indices.cuda().long() + local_masked_scatter_add_gather( + src_tensor, + src_indices, + dst_indices, + out_tensor, + bs, + num_src_rows, + num_features, + num_out_rows, + ) + assert torch.allclose( + out_tensor.cpu(), out_tensor_gt + ), "Optimized local scatter-add-gather failed" From 0b17d4a7785753ae03a9e633339133abe5f714c5 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Thu, 18 Dec 2025 12:12:06 -0800 Subject: [PATCH 48/48] Update python bindings for localScatterSumGather --- DGraph/distributed/RankLocalOps.py | 5 +++-- DGraph/distributed/nccl/_torch_func_impl.py | 12 ++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/DGraph/distributed/RankLocalOps.py b/DGraph/distributed/RankLocalOps.py index 8cacecc..b7302f1 100644 --- a/DGraph/distributed/RankLocalOps.py +++ b/DGraph/distributed/RankLocalOps.py @@ -23,6 +23,7 @@ local_masked_gather, local_masked_scatter, local_masked_scatter_gather, + local_masked_scatter_add_gather, ) _LOCAL_OPT_KERNELS_AVAILABLE = True @@ -129,7 +130,7 @@ def OptimizedLocalScatterGather( return output -def OptimizedRankLocalScatterSumGather( +def OptimizedLocalScatterSumGather( src: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, @@ -160,7 +161,7 @@ def OptimizedRankLocalScatterSumGather( num_src_rows = src.shape[1] num_features = src.shape[-1] num_output_rows = output.shape[1] - local_masked_scatter_gather( + local_masked_scatter_add_gather( src, src_indices.cuda(), dst_indices.cuda(), diff --git a/DGraph/distributed/nccl/_torch_func_impl.py b/DGraph/distributed/nccl/_torch_func_impl.py index 3bb51ef..71880b7 100644 --- a/DGraph/distributed/nccl/_torch_func_impl.py +++ b/DGraph/distributed/nccl/_torch_func_impl.py @@ -7,7 +7,7 @@ from DGraph.distributed.RankLocalOps import ( OptimizedRankLocalMaskedGather, OptimizedLocalScatterGather, - OptimizedRankLocalScatterSumGather, + OptimizedLocalScatterSumGather, ) from DGraph.distributed.nccl._NCCLCommPlan import NCCLGraphCommPlan @@ -95,7 +95,7 @@ def backward(ctx, grad_output): num_batches, comm_plan.num_local_vertices, num_features, device=device ) - grad_input = OptimizedRankLocalScatterSumGather( + grad_input = OptimizedLocalScatterSumGather( src=grad_output, output=grad_input, src_indices=comm_plan.local_edge_idx, @@ -111,7 +111,7 @@ def backward(ctx, grad_output): output_split_sizes=comm_plan.boundary_vertex_splits, input_split_sizes=comm_plan.boundary_edge_splits, ) - grad_input = OptimizedRankLocalScatterSumGather( + grad_input = OptimizedLocalScatterSumGather( src=recv_buffer, output=grad_input, src_indices=comm_plan.boundary_edge_buffer_map, @@ -148,7 +148,7 @@ def forward( num_batches, comm_plan.num_local_vertices, num_features ).to(local_send_tensor.device) - output_tensor = OptimizedRankLocalScatterSumGather( + output_tensor = OptimizedLocalScatterSumGather( src=local_send_tensor, output=output_tensor, src_indices=comm_plan.local_edge_idx, @@ -161,7 +161,7 @@ def forward( num_batches, total_send_rows, num_features, device=local_send_tensor.device ) - send_buf = OptimizedRankLocalScatterSumGather( + send_buf = OptimizedLocalScatterSumGather( src=local_send_tensor, output=send_buf, src_indices=comm_plan.boundary_edge_idx, @@ -178,7 +178,7 @@ def forward( output_split_sizes=comm_plan.boundary_vertex_splits, input_split_sizes=comm_plan.boundary_edge_splits, ) - output_tensor = OptimizedRankLocalScatterSumGather( + output_tensor = OptimizedLocalScatterSumGather( src=recv_buffer, output=output_tensor, src_indices=comm_plan.boundary_edge_buffer_map,