Changes from all commits
18 commits
0268f5b
Created the distributed module
yoshikisd Oct 19, 2025
fded573
Added additional functions to distributed
yoshikisd Oct 20, 2025
10229e4
Renamed cdtools.tools.distributed to cdtools.tools.multigpu. Also add…
yoshikisd Nov 8, 2025
8d8da15
Multi-GPU compatability added for base Reconstructor and AdamReconstr…
yoshikisd Nov 8, 2025
b0604d4
Prevent dataset from plotting/saving outside of Rank 0
yoshikisd Nov 8, 2025
0f1f990
Added a check for init_method in multigpu.setup
yoshikisd Nov 8, 2025
b784620
Multi-GPU compatibility added for CDIModel
yoshikisd Nov 8, 2025
6f15c36
Added example scripts for running multi-GPU jobs with spawn or torchrun
yoshikisd Nov 9, 2025
842bb49
Added example script for running a multi-GPU speed test
yoshikisd Nov 9, 2025
1a3d0cd
Modified runtime error message in multigpu.setup
yoshikisd Nov 9, 2025
cad3d9c
Fixed linting issues on multigpu
yoshikisd Nov 9, 2025
cd584fc
multigpu.setup now recognizes MASTER_ADDR and MASTER_PORT when settin…
yoshikisd Nov 9, 2025
a8d2bfa
Simplified the logic in MASTER_ADDR and MASTER_PORT definition in mul…
yoshikisd Nov 9, 2025
32e9256
Added multi-GPU pytests
yoshikisd Nov 9, 2025
e5996a4
Merge branch 'master' into multigpu
yoshikisd Nov 9, 2025
ebe5b54
Merge remote-tracking branch 'origin/master' into multigpu
yoshikisd Dec 12, 2025
9792b14
Merge remote-tracking branch 'origin/master' into multigpu
yoshikisd Dec 12, 2025
e059c18
Make the timeout default 30 seconds and enforce timeout in units of s…
yoshikisd Dec 18, 2025
91 changes: 91 additions & 0 deletions examples/gold_ball_ptycho_spawn.py
@@ -0,0 +1,91 @@
import cdtools
from matplotlib import pyplot as plt
import torch as t

# If you notice that the multi-GPU job hangs (especially with 100%
# GPU utilization across all participating devices), you might want to try
# disabling NCCL peer-to-peer communication by setting the environment
# variable NCCL_P2P_DISABLE to '1'.
import os
os.environ['NCCL_P2P_DISABLE'] = '1'


# The entire reconstruction script needs to be wrapped in a function
def reconstruct(rank, world_size):

    # In the multigpu setup, we need to explicitly define the rank and
    # world_size. The master address and master port should also be
    # defined if we don't specify the `init_method` parameter.
    cdtools.tools.multigpu.setup(rank=rank,
                                 world_size=world_size,
                                 master_addr='localhost',
                                 master_port='6666')

    filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi'
    dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)

    pad = 10
    dataset.pad(pad)
    dataset.inspect()
    model = cdtools.models.FancyPtycho.from_dataset(
        dataset,
        n_modes=3,
        probe_support_radius=50,
        propagation_distance=2e-6,
        units='um',
        probe_fourier_crop=pad
    )
    model.translation_offsets.data += 0.7 * \
        t.randn_like(model.translation_offsets)
    model.weights.requires_grad = False

    # We need to manually set the model's rank attribute, or else every
    # plot will be duplicated once per participating GPU.
    model.rank = rank

    device = 'cuda'
    model.to(device=device)
    dataset.get_as(device=device)

    # The rank and world_size also need to be passed explicitly here
    recon = cdtools.reconstructors.AdamReconstructor(model,
                                                     dataset,
                                                     rank=rank,
                                                     world_size=world_size)

    with model.save_on_exception(
            'example_reconstructions/gold_balls_earlyexit.h5', dataset):

        for loss in recon.optimize(20, lr=0.005, batch_size=50):
            if rank == 0:
                print(model.report())
            if model.epoch % 10 == 0:
                model.inspect(dataset)

        for loss in recon.optimize(50, lr=0.002, batch_size=100,
                                   schedule=True):
            if rank == 0:
                print(model.report())

            if model.epoch % 10 == 0:
                model.inspect(dataset)

        for loss in recon.optimize(100, lr=0.001, batch_size=100,
                                   schedule=True):
            if rank == 0:
                print(model.report())
            if model.epoch % 10 == 0:
                model.inspect(dataset)

    cdtools.tools.multigpu.cleanup()

    model.tidy_probes()
    model.save_to_h5('example_reconstructions/gold_balls.h5', dataset)
    model.inspect(dataset)
    model.compare(dataset)
    plt.show()


if __name__ == '__main__':
    # Specify the number of GPUs we want to use, then spawn the multi-GPU job
    ngpus = 2
    t.multiprocessing.spawn(reconstruct, args=(ngpus,), nprocs=ngpus)
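The setup and cleanup calls above hide the process-group bookkeeping. As a rough sketch only (not the actual cdtools implementation; the helper names below are hypothetical), a setup/cleanup pair like this typically wraps the standard torch.distributed calls:

# Illustrative sketch, assuming the standard torch.distributed workflow --
# not the cdtools implementation.
import os
import torch.distributed as dist

def sketch_setup(rank, world_size, master_addr='localhost', master_port='6666'):
    # With the default 'env://' init method, the rendezvous address and
    # port are read from these environment variables
    os.environ.setdefault('MASTER_ADDR', master_addr)
    os.environ.setdefault('MASTER_PORT', master_port)
    # One process per GPU joins the NCCL process group
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

def sketch_cleanup():
    # Tear down the process group once the reconstruction has finished
    dist.destroy_process_group()

When launching with torchrun (see the third example below), the rank, world size, and rendezvous values are instead supplied through environment variables, which is presumably why setup() takes no arguments in that script.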
95 changes: 95 additions & 0 deletions examples/gold_ball_ptycho_speedtest.py
@@ -0,0 +1,95 @@
import cdtools
import torch as t

# If you notice that the multi-GPU job hangs (especially with 100%
# GPU utilization across all participating devices), you might want to try
# disabling NCCL peer-to-peer communication by setting the environment
# variable NCCL_P2P_DISABLE to '1'.
import os
os.environ['NCCL_P2P_DISABLE'] = '1'


# For running a speed test, we need to add an additional `conn` parameter,
# which each worker uses to send its loss-versus-time curve back to the
# speed test function.
def reconstruct(rank, world_size, conn):

    # We set up the process group in the same manner as in the spawn example
    # (the speed test uses spawn)
    cdtools.tools.multigpu.setup(rank=rank,
                                 world_size=world_size,
                                 master_addr='localhost',
                                 master_port='6666')

    filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi'
    dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)

    pad = 10
    dataset.pad(pad)
    dataset.inspect()
    model = cdtools.models.FancyPtycho.from_dataset(
        dataset,
        n_modes=3,
        probe_support_radius=50,
        propagation_distance=2e-6,
        units='um',
        probe_fourier_crop=pad
    )
    model.translation_offsets.data += 0.7 * \
        t.randn_like(model.translation_offsets)
    model.weights.requires_grad = False

    # Unless you're plotting data within the reconstruct function, setting
    # the rank attribute is not necessary.
    # model.rank = rank

    device = 'cuda'
    model.to(device=device)
    dataset.get_as(device=device)

    # The rank and world_size also need to be passed explicitly here
    recon = cdtools.reconstructors.AdamReconstructor(model,
                                                     dataset,
                                                     rank=rank,
                                                     world_size=world_size)

    with model.save_on_exception(
            'example_reconstructions/gold_balls_earlyexit.h5', dataset):

        # It is recommended to comment out or remove all plotting-related
        # methods for the speed test.
        for loss in recon.optimize(20, lr=0.005, batch_size=50):
            if rank == 0:
                print(model.report())

        for loss in recon.optimize(50, lr=0.002, batch_size=100):
            if rank == 0:
                print(model.report())

        for loss in recon.optimize(100, lr=0.001, batch_size=100,
                                   schedule=True):
            if rank == 0:
                print(model.report())

    # We now use the conn parameter to send the loss-versus-time data to the
    # main process running the speed test.
    conn.send((model.loss_times, model.loss_history))

    # And, as always, we need to clean up at the end.
    cdtools.tools.multigpu.cleanup()


if __name__ == '__main__':
    # We call the run_speed_test function instead of calling
    # t.multiprocessing.spawn. We specify the number of runs
    # we want to perform per GPU count, along with which
    # GPU counts we want to test.
    #
    # Here, we test both 1 and 2 GPUs, using 3 runs for each.
    # With show_plot=True, the speed test will report the mean +- standard
    # deviation of the loss-versus-time/epoch curves across the 3 runs.
    # The plot will also show the mean +- standard deviation of
    # the GPU-count-dependent speedups across the 3 runs.
    cdtools.tools.multigpu.run_speed_test(reconstruct,
                                          gpu_counts=(1, 2),
                                          runs=3,
                                          show_plot=True)
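Because run_speed_test spawns the worker processes itself, this script is launched as an ordinary single-process Python run, for example (assuming it is saved under examples/):

python examples/gold_ball_ptycho_speedtest.py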
70 changes: 70 additions & 0 deletions examples/gold_ball_ptycho_torchrun.py
@@ -0,0 +1,70 @@
import cdtools
from matplotlib import pyplot as plt
import torch as t

# At the beginning of the script we need to set up the multi-GPU job
# by initializing the process group and synchronizing the RNG seed
# across all participating GPUs.
cdtools.tools.multigpu.setup()

# To avoid redundant print statements, we first grab the GPU "rank"
# (an ID number between 0 and the number of participating GPUs minus 1).
rank = cdtools.tools.multigpu.get_rank()

filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)

pad = 10
dataset.pad(pad)
dataset.inspect()
model = cdtools.models.FancyPtycho.from_dataset(
    dataset,
    n_modes=3,
    probe_support_radius=50,
    propagation_distance=2e-6,
    units='um',
    probe_fourier_crop=pad
)
model.translation_offsets.data += 0.7 * t.randn_like(model.translation_offsets)
model.weights.requires_grad = False
device = 'cuda'
model.to(device=device)
dataset.get_as(device=device)

recon = cdtools.reconstructors.AdamReconstructor(model, dataset)

with model.save_on_exception(
        'example_reconstructions/gold_balls_earlyexit.h5', dataset):

    for loss in recon.optimize(20, lr=0.005, batch_size=50):
        # We ensure that only the GPU with rank 0 runs the print statement.
        if rank == 0:
            print(model.report())

        # But we don't need to do rank checking for any plotting- or saving-
        # related methods; this checking is handled internally.
        if model.epoch % 10 == 0:
            model.inspect(dataset)

    for loss in recon.optimize(50, lr=0.002, batch_size=100):
        if rank == 0:
            print(model.report())

        if model.epoch % 10 == 0:
            model.inspect(dataset)

    for loss in recon.optimize(100, lr=0.001, batch_size=100, schedule=True):
        if rank == 0:
            print(model.report())
        if model.epoch % 10 == 0:
            model.inspect(dataset)

# After the reconstruction is completed, we need to clean things up by
# destroying the process group.
cdtools.tools.multigpu.cleanup()

model.tidy_probes()
model.save_to_h5('example_reconstructions/gold_balls.h5', dataset)
model.inspect(dataset)
model.compare(dataset)
plt.show()
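Unlike the spawn example, this script is not wrapped in a function: torchrun launches one copy of it per GPU and supplies the rank, world size, and rendezvous address through environment variables, which setup() presumably reads (the commit notes mention MASTER_ADDR and MASTER_PORT handling). A typical launch command, assuming a single node with two GPUs, would be:

torchrun --nproc_per_node=2 examples/gold_ball_ptycho_torchrun.py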
7 changes: 6 additions & 1 deletion src/cdtools/datasets/base.py
@@ -17,7 +17,7 @@
from copy import copy
import h5py
import pathlib
from cdtools.tools import data as cdtdata
from cdtools.tools import data as cdtdata, multigpu
from torch.utils import data as torchdata

__all__ = ['CDataset']
@@ -92,6 +92,11 @@ def __init__(

self.get_as(device='cpu')

# This attribute supports multi-GPU operation; it prevents
# saving/plotting functions from being executed on any GPU other than
# rank 0
self.rank = multigpu.get_rank()


def to(self, *args, **kwargs):
"""Sends the relevant data to the given device and dtype
18 changes: 16 additions & 2 deletions src/cdtools/datasets/ptycho_2d_dataset.py
@@ -154,6 +154,8 @@ def from_cxi(cls, cxi_file, cut_zeros=True, load_patterns=True):

# Generate a base dataset
dataset = CDataset.from_cxi(cxi_file)


# Mutate the class to this subclass (BasicPtychoDataset)
dataset.__class__ = cls

@@ -198,7 +200,11 @@ def to_cxi(self, cxi_file):
cxi_file : str, pathlib.Path, or h5py.File
The .cxi file to write to
"""

# FOR MULTI-GPU: Don't run this block of code if it isn't
# called by the rank 0 GPU
if self.rank != 0:
return

# If a bare string is passed
if isinstance(cxi_file, str) or isinstance(cxi_file, pathlib.Path):
with cdtdata.create_cxi(cxi_file) as f:
@@ -231,7 +237,10 @@ def inspect(
can display a base-10 log plot of the detector readout at each
position.
"""

# FOR MULTI-GPU: Don't run this block of code if it isn't
# called by the rank 0 GPU
if self.rank != 0:
return

def get_images(idx):
inputs, output = self[idx]
@@ -292,6 +301,11 @@ def plot_mean_pattern(self, log_offset=1):
level.

"""
# FOR MULTI-GPU: Don't run this block of code if it isn't
# called by the rank 0 GPU
if self.rank != 0:
return

mean_pattern, bins, ssnr = analysis.calc_spectral_info(self)
cmap_label = f'Log Base 10 of Intensity + {log_offset}'
title = 'Scaled mean diffraction pattern'