48 commits
38249ec
Adding distributed batch normalization layers
szaman19 Sep 12, 2025
b24ef7d
Wrap around implementation with torch nn module
szaman19 Sep 12, 2025
9cdadb6
Add DGraph implementation of RGat
szaman19 Sep 12, 2025
d44bc83
Adding PyG sparse tensor wrapper if needed
szaman19 Sep 12, 2025
b539e79
Add RGat implementation
szaman19 Sep 12, 2025
63109a3
Completed with RGAT implementation
szaman19 Sep 15, 2025
335a65e
Add the synthetic dataset for testing purposes
szaman19 Sep 15, 2025
2d9ff3f
Add MAG240M dataset
szaman19 Sep 15, 2025
c2530ad
Bug fixes to get things running correctly
szaman19 Sep 15, 2025
88c0410
Update fix for data type issues and cache generator, but experiencing…
szaman19 Sep 16, 2025
d44fc6e
paper2paper layer running.
szaman19 Sep 16, 2025
e3f091e
Updating the synthetic dataset to track down hang on directed relatio…
szaman19 Sep 17, 2025
1b45b75
Still debugging error on miscalculated gradient size
szaman19 Sep 17, 2025
f71f109
Fix for incorrect tensor shape
szaman19 Sep 17, 2025
076a692
Author 2 paper relation working. Only author 2 institution error rema…
szaman19 Sep 26, 2025
2d27416
Remove extra breakpoints in cache generators
szaman19 Sep 26, 2025
d127220
Fix cache generator with correct input shape for destination gather
szaman19 Sep 26, 2025
a6962cc
Fix on batch norm to have correct local variance reduction
szaman19 Sep 26, 2025
f85f133
Added additional parameters to batch norm for backprop
szaman19 Sep 27, 2025
750d8e6
Adding helper functions to sync normalization values and fixed evalua…
szaman19 Sep 28, 2025
657111e
Latest changes to RGAT
szaman19 Oct 3, 2025
746693d
(OGB-LSC) Bugfix for generating _dest_scatter_cache
KIwabuchi Oct 18, 2025
296fa2f
(OGB-LSC) Workaround for DDP's unused parameter error
KIwabuchi Oct 18, 2025
fa585a3
(OGB-LSC) Some performance optimizations
KIwabuchi Oct 18, 2025
483bc48
Fix DGraph Mag240M dataset __getitem__ method
szaman19 Oct 29, 2025
7930024
Remove debug messages
KIwabuchi Oct 24, 2025
09735a7
(OGB-LSC) Bugfix for mag240m dataset
KIwabuchi Oct 25, 2025
ee1f725
(OGB-LSC) Remove debug message
KIwabuchi Nov 8, 2025
f334ad4
(OGB-LSC) Split dataset using OGB's function
KIwabuchi Nov 11, 2025
c663626
Updated torch bindings implementation for local-scatter-gather
szaman19 Nov 20, 2025
47f6ef3
New and improved concise dataplan with efficient connectivity data st…
szaman19 Nov 20, 2025
e4c9b2e
Add updated kernels for local scatter-gather + NCCLCommPlan
szaman19 Dec 12, 2025
8535a40
Update the scatter-gather impl to allow set and add aggregation
szaman19 Dec 12, 2025
3280ee8
Add ScatterSumGather python wrapper
szaman19 Dec 13, 2025
13c1205
Fix backward function call on StaticGather
szaman19 Dec 13, 2025
ef2efbb
Fixed Scatter forward
szaman19 Dec 13, 2025
22ed522
Updated scatter function impl
szaman19 Dec 17, 2025
dc0ded2
Fix build issues and change op struct
szaman19 Dec 17, 2025
f559bcf
Remove unnecessary imports and remove cache implementation
szaman19 Dec 17, 2025
f2c3915
Decompose internal function to reduce memory usage
szaman19 Dec 17, 2025
bf8c6ac
Optimized CommPlan generator
szaman19 Dec 17, 2025
934d532
Fix node size check
szaman19 Dec 17, 2025
25c3edf
Add edge-conditioned graph plan to hold full edge communication info
szaman19 Dec 18, 2025
16dd5ad
Update GAT implementation with new comm plan
szaman19 Dec 18, 2025
ceb6d90
Remove pyg_wrapper function
szaman19 Dec 18, 2025
673f8b6
Enable hetero-graphs in comm-plan
szaman19 Dec 18, 2025
88269a1
Fixed mismatched API on scatter_sum_gather
szaman19 Dec 18, 2025
0b17d4a
Update python bindings for localScatterSumGather
szaman19 Dec 18, 2025
24 changes: 24 additions & 0 deletions .gitignore
@@ -167,3 +167,27 @@ cython_debug/
# Use wildcards as well
*~
*.o
# Miscellaneous files generated by DGraph data processing
skbuild/
.vscode/
logs/
torchrun_*
*.png
rdvz
*.pt
*.core
*.graph
*.out
*.gz
data_processed
*.zip
cache
graph_cache
*.nsys-rep
*.nsys
*.pth
*.pyc
*.npy
*.npz
*.sqlite
*.csv
4 changes: 2 additions & 2 deletions DGraph/distributed/Engine.py
@@ -50,7 +50,7 @@ def scatter(
output_size: int,
rank_mappings: Optional[torch.Tensor] = None,
*args,
**kwargs
**kwargs,
) -> torch.Tensor:
raise NotImplementedError

@@ -60,7 +60,7 @@ def gather(
indices: Union[torch.Tensor, torch.LongTensor],
rank_mappings: Optional[torch.Tensor] = None,
*args,
**kwargs
**kwargs,
) -> torch.Tensor:
raise NotImplementedError

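For illustration, a minimal single-process engine matching these two abstract signatures could look like the sketch below (a hypothetical `LocalEngine`; the real Engine subclasses implement cross-rank communication, and the leading parameters of `scatter` are assumed from the visible hunk):

import torch

class LocalEngine:
    """Hypothetical single-process stand-in for an Engine subclass."""

    def scatter(self, src, indices, output_size, rank_mappings=None, *args, **kwargs):
        # Write rows of src into a fresh output tensor at the given row indices.
        output = torch.zeros(output_size, src.shape[-1], dtype=src.dtype, device=src.device)
        output[indices] = src
        return output

    def gather(self, src, indices, rank_mappings=None, *args, **kwargs):
        # Read rows of src at the given row indices.
        return src[indices]
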
99 changes: 97 additions & 2 deletions DGraph/distributed/RankLocalOps.py
@@ -16,9 +16,15 @@
"""

import torch
import warnings  # used by the fallback paths below
import torch.distributed as dist

try:
from DGraph.torch_local import local_masked_gather, local_masked_scatter
from DGraph.torch_local import (
local_masked_gather,
local_masked_scatter,
local_masked_scatter_gather,
local_masked_scatter_add_gather,
)

_LOCAL_OPT_KERNELS_AVAILABLE = True
except ImportError:
@@ -81,6 +87,93 @@ def OptimizedRankLocalMaskedGather(
return output


def OptimizedLocalScatterGather(
src: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
output: torch.Tensor,
):
"""
Performs the operation

for i in range(len(src_indices)):
output[dst_indices[i]] = src[src_indices[i]]
Args:
src (torch.Tensor): Source tensor
src_indices (torch.Tensor): Source indices
dst_indices (torch.Tensor): Destination indices
output (torch.Tensor): Output tensor
Returns:
torch.Tensor: Output tensor after scatter-gather
"""

if not _LOCAL_OPT_KERNELS_AVAILABLE:
warnings.warn(
"Optimized local kernels are not available. Falling back to the default implementation."
)
output[:, dst_indices] = src[:, src_indices]
else:
bs = src.shape[0]
num_src_rows = src.shape[1]
num_features = src.shape[-1]
num_output_rows = output.shape[1]
local_masked_scatter_gather(
src,
src_indices.cuda(),
dst_indices.cuda(),
output,
bs,
num_src_rows,
num_features,
num_output_rows,
)
return output
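
A small usage sketch, assuming the shapes implied by the kernel path above (src of shape (batch, num_src_rows, num_features) and a preallocated output with num_output_rows rows; the index values are illustrative):

import torch
from DGraph.distributed.RankLocalOps import OptimizedLocalScatterGather

src = torch.randn(1, 4, 8, device="cuda")     # (batch, num_src_rows, num_features)
output = torch.zeros(1, 6, 8, device="cuda")  # (batch, num_output_rows, num_features)
src_indices = torch.tensor([0, 2, 3])         # rows to read from src
dst_indices = torch.tensor([5, 1, 0])         # rows to write in output

out = OptimizedLocalScatterGather(src, src_indices, dst_indices, output)
# out[:, 5] == src[:, 0], out[:, 1] == src[:, 2], out[:, 0] == src[:, 3]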


def OptimizedLocalScatterSumGather(
src: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
output: torch.Tensor,
):
"""
Performs the operation

for i in range(len(src_indices)):
output[dst_indices[i]] += src[src_indices[i]]
Args:
src (torch.Tensor): Source tensor
src_indices (torch.Tensor): Source indices
dst_indices (torch.Tensor): Destination indices
output (torch.Tensor): Output tensor
Returns:
torch.Tensor: Output tensor after scatter-gather
"""

if not _LOCAL_OPT_KERNELS_AVAILABLE:
warnings.warn(
"Optimized local kernels are not available. Falling back to the default implementation."
)
for i in range(src_indices.shape[0]):
output[:, dst_indices[i], :] += src[:, src_indices[i], :]
else:
bs = src.shape[0]
num_src_rows = src.shape[1]
num_features = src.shape[-1]
num_output_rows = output.shape[1]
local_masked_scatter_add_gather(
src,
src_indices.cuda(),
dst_indices.cuda(),
output,
bs,
num_src_rows,
num_features,
num_output_rows,
)
return output
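
Unlike the set variant, duplicate destination indices accumulate here, which is the primitive behind neighbor aggregation. A sketch with illustrative values:

import torch
from DGraph.distributed.RankLocalOps import OptimizedLocalScatterSumGather

src = torch.ones(1, 3, 4, device="cuda")
output = torch.zeros(1, 2, 4, device="cuda")
src_indices = torch.tensor([0, 1, 2])
dst_indices = torch.tensor([0, 0, 1])  # source rows 0 and 1 both land on output row 0

out = OptimizedLocalScatterSumGather(src, src_indices, dst_indices, output)
# out[:, 0] is all 2.0 (two source rows summed); out[:, 1] is all 1.0

Note that the CUDA path accumulates with atomicAdd, so the order of floating-point additions over duplicate destinations is not deterministic and results can differ in the last bits across runs.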


def OutOfPlaceRankLocalMaskedGather(
_src: torch.Tensor, indices: torch.Tensor, rank_mapping: torch.Tensor, rank: int
) -> torch.Tensor:
@@ -140,7 +233,9 @@ def RankLocalRenumberingWithMapping(_indices, rank_mapping):
unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True)
rank_mapping = rank_mapping.to(_indices.device)
renumbered_indices = inverse_indices
unique_rank_mapping = torch.zeros_like(unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device)
unique_rank_mapping = torch.zeros_like(
unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device
)
unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping)

return renumbered_indices, unique_indices, unique_rank_mapping
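
A worked example of the renumbering (values chosen for illustration):

import torch

_indices = torch.tensor([7, 3, 7, 9])
rank_mapping = torch.tensor([1, 0, 1, 2])  # owning rank of each index

unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True)
# unique_indices  -> [3, 7, 9]
# inverse_indices -> [1, 0, 1, 2]  (each index renumbered into [0, num_unique))

unique_rank_mapping = torch.zeros_like(unique_indices)
unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping)
# unique_rank_mapping -> [0, 1, 2]  (one rank per unique index; duplicates
# write the same value, so the result is consistent)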
140 changes: 140 additions & 0 deletions DGraph/distributed/csrc/local_data_kernels.cuh
@@ -251,4 +251,144 @@ namespace Local
}
}
}



template <typename T>
struct FloatAtomicAddOp
{
__device__ __forceinline__ void operator()(T *cur_addr, const T new_val)
{
atomicAdd(cur_addr, new_val);
}
};

template <typename T>
struct FloatSetOp
{
__device__ __forceinline__ void operator()(T *cur_addr, const T new_val)
{
*cur_addr = new_val;
}
};


/**
* Masked scatter-gather kernel that performs the operation:
*
*     Y[mask[i]] = Op(Y[mask[i]], X[indices[i]])
*
* where Y is the output matrix, X is the input matrix, indices holds the
* source rows, and mask holds the destination rows.
*/

template <typename Op>
__global__ void Masked_Scatter_Gather_Kernel(
const float *__restrict__ values,
const long *__restrict__ indices,
const long *__restrict__ mask,
float *__restrict__ output,
const int mini_batch_size,
const int num_indices,
const int num_cols,
const int num_output_rows)
{
const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x;
const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y;
const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z;

const size_t nthreadsx = gridDim.x * blockDim.x;
const size_t nthreadsy = gridDim.y * blockDim.y;
const size_t nthreadsz = gridDim.z * blockDim.z;

Op op;

for (size_t mb_i = gidz; mb_i < mini_batch_size; mb_i += nthreadsz)
{
const auto values_offset = mb_i * num_cols * num_indices;
const auto output_offset = mb_i * num_cols * num_output_rows;
const auto ind_offset = mb_i * num_indices;
const auto mask_offset = mb_i * num_indices;

for (size_t row = gidy; row < num_indices; row += nthreadsy)
{
const auto output_row = mask[mask_offset + row];
const auto input_row = indices[ind_offset + row];

for (size_t col = gidx; col < num_cols; col += nthreadsx)
{
auto *output_addr = &output[output_offset + output_row * num_cols + col];
const auto input_val = values[values_offset + input_row * num_cols + col];
op(output_addr, input_val);
}
}
}
}
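
For reference, the kernel's semantics in plain PyTorch (a sketch useful for unit-testing the CUDA path; `op` stands in for the Op template parameter, and `indices`/`mask` are laid out as (mini_batch, n), matching the per-batch offsets in the kernel):

import torch

def masked_scatter_gather_reference(values, indices, mask, output, op):
    # values/output: (mini_batch, rows, cols); indices/mask: (mini_batch, n).
    for mb in range(values.shape[0]):
        for i in range(indices.shape[1]):
            dst, src = mask[mb, i], indices[mb, i]
            output[mb, dst] = op(output[mb, dst], values[mb, src])
    return output

# op mirrors the Op template parameter: FloatSetOp or FloatAtomicAddOp.
set_op = lambda cur, new: new
add_op = lambda cur, new: cur + new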

/*
* Optimized masked scatter-gather kernel that performs the operation:
*
*     Y[mask[i]] = Op(Y[mask[i]], X[indices[i]])
*
* This kernel vectorizes loads and stores with float4 and therefore
* requires num_cols to be a multiple of 4.
*
* where Y is the output matrix, X is the input matrix, indices holds the
* source rows, and mask holds the destination rows.
*/
template <typename Op>
__global__ void Optimized_Masked_Scatter_Gather_Kernel(
const float *__restrict__ values,
const long *__restrict__ indices,
const long *__restrict__ mask,
float *__restrict__ output,
const int mini_batch_size,
const int num_indices,
const int num_cols,
const int num_output_rows)
{
const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x;
const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y;
const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z;

const size_t nthreadsx = gridDim.x * blockDim.x;
const size_t nthreadsy = gridDim.y * blockDim.y;
const size_t nthreadsz = gridDim.z * blockDim.z;

// Grid-stride loop over mini-batches

Op binary_operator;
for (size_t mb_i = gidz; mb_i < mini_batch_size; mb_i += nthreadsz)
{
const auto values_offset = mb_i * num_cols / 4 * num_indices;
const auto output_offset = mb_i * num_cols / 4 * num_output_rows;
const auto ind_offset = mb_i * num_indices;
const auto mask_offset = mb_i * num_indices;

// Grid-stride loop over rows
for (size_t row = gidy; row < num_indices; row += nthreadsy)
{
// Load the destination and source rows for this index entry.
const long output_row = mask[mask_offset + row];
const long input_row = indices[ind_offset + row];

for (size_t col = gidx; col < num_cols / 4; col += nthreadsx)
{
const float4 values_vec = reinterpret_cast<const float4 *>(values)[values_offset + input_row * num_cols / 4 + col];
float4* output_addr = &reinterpret_cast<float4 *>(output)[output_offset + output_row * num_cols / 4 + col];
binary_operator(output_addr, values_vec);
}
}
}
}
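
Because this kernel loads float4 vectors, it is only safe when the feature width is a multiple of 4 and the tensors are contiguous, so that every row starts on a 16-byte boundary. Whether the extension dispatches on this is not shown in the diff; the check below is an assumption about what a caller should verify before choosing the optimized kernel:

import torch

def can_use_vectorized_kernel(src: torch.Tensor, output: torch.Tensor) -> bool:
    # float4 loads need num_cols % 4 == 0 and contiguous storage so that
    # each row begins 16-byte aligned.
    return (
        src.shape[-1] % 4 == 0
        and src.is_contiguous()
        and output.is_contiguous()
    )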

} // namespace Local
2 changes: 2 additions & 0 deletions DGraph/distributed/csrc/torch_local_bindings.cpp
@@ -21,4 +21,6 @@ PYBIND11_MODULE(torch_local, m)
{
m.def("local_masked_gather", &local_masked_gather, "Masked Gather");
m.def("local_masked_scatter", &local_masked_scatter, "Masked Scatter");
m.def("local_masked_scatter_gather", &local_masked_scatter_gather, "Masked Scatter Gather");
m.def("local_masked_scatter_add_gather", &local_masked_scatter_add_gather, "Masked Scatter Add Gather");
}
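
Once the extension is built, the new bindings are callable directly with the explicit shape arguments used by the RankLocalOps wrappers above; a minimal sketch (CUDA tensors assumed, shapes chosen for illustration):

import torch
from DGraph.torch_local import local_masked_scatter_add_gather

bs, num_src_rows, num_features, num_output_rows = 1, 4, 8, 6
src = torch.randn(bs, num_src_rows, num_features, device="cuda")
output = torch.zeros(bs, num_output_rows, num_features, device="cuda")
src_indices = torch.tensor([0, 1, 2], device="cuda")
dst_indices = torch.tensor([5, 5, 0], device="cuda")  # duplicates accumulate

local_masked_scatter_add_gather(
    src, src_indices, dst_indices, output,
    bs, num_src_rows, num_features, num_output_rows,
)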