RFC: Unified ParticleDataset for Sample and Trajectory-based Data Loading
Summary
This RFC proposes a unified ParticleDataset class that supports both sample-based and trajectory-based data loading for particle simulation data. The implementation leverages PyTorch's Dataset class and provides a flexible interface for handling different data formats and loading modes, including support for distributed training.
Motivation
Current particle simulation data loading methods often require separate implementations for sample-based and trajectory-based approaches. This leads to code duplication and potential inconsistencies. By unifying these approaches into a single class, we aim to:
- Simplify the codebase and reduce duplication.
- Provide a consistent interface for different data loading needs.
- Improve flexibility in handling various data formats (npz and h5).
- Enhance maintainability and extensibility of the data loading pipeline.
- Support distributed training scenarios.
The expected outcome is a more robust and versatile data loading system that can easily adapt to different research and production needs in particle simulation projects, including distributed training environments.
Design Detail
The core of this proposal is the ParticleDataset class and associated functions in the particle_data_loader.py file. Here's a detailed breakdown of its design:
- Data Loading:

      def load_data(path):
          """Load data stored in npz or h5 format."""
          # Implementation for loading npz and h5 files

- ParticleDataset Class:
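The `load_data` stub under the Data Loading bullet is left unspecified in this RFC. A minimal sketch is given below; it assumes both formats hold named arrays at the top level and returns them as a plain dict (the key layout and return type are assumptions, not taken from the actual implementation):

```python
import os
import numpy as np

def load_data(path):
    """Load data stored in npz or h5 format into a dict of numpy arrays (sketch)."""
    ext = os.path.splitext(path)[1]
    if ext == ".npz":
        with np.load(path, allow_pickle=True) as f:
            return {key: f[key] for key in f.files}
    elif ext in (".h5", ".hdf5"):
        import h5py  # imported lazily so npz-only users don't need h5py installed
        with h5py.File(path, "r") as f:
            return {key: f[key][()] for key in f.keys()}
    raise ValueError(f"Unsupported file format: {ext}")
```

Returning a plain dict keeps the two formats interchangeable downstream; the real implementation may use a different container.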
      class ParticleDataset(Dataset):
          def __init__(self, file_path, input_sequence_length=6, mode='sample'):
              # Initialize dataset

          def _preprocess_data(self):
              # Preprocess data based on mode

          def __len__(self):
              # Return length of dataset

          def __getitem__(self, idx):
              # Get item based on mode (sample or trajectory)

          def _get_sample(self, idx):
              # Get a single sample

          def _get_trajectory(self, idx):
              # Get a full trajectory

          def get_num_features(self):
              # Return the number of features in the dataset

- Collate Functions:
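To make the sample/trajectory split in the `ParticleDataset` skeleton concrete, here is a hedged sketch of how the two modes' indexing could work. It assumes positions are stored as a `(timesteps, particles, dims)` array; the attribute names, data layout, and windowing scheme are illustrative assumptions:

```python
import numpy as np
from torch.utils.data import Dataset

class ParticleDatasetSketch(Dataset):
    """Illustrative only: sample mode yields sliding windows, trajectory mode whole runs."""

    def __init__(self, positions, input_sequence_length=6, mode="sample"):
        self.positions = positions  # assumed shape: (timesteps, particles, dims)
        self.input_sequence_length = input_sequence_length
        self.mode = mode

    def __len__(self):
        if self.mode == "sample":
            # one sample per window that still has a next-step target
            return self.positions.shape[0] - self.input_sequence_length
        return 1  # trajectory mode: the whole run is a single item

    def __getitem__(self, idx):
        if self.mode == "sample":
            window = self.positions[idx : idx + self.input_sequence_length]
            target = self.positions[idx + self.input_sequence_length]
            return window, target
        return self.positions

    def get_num_features(self):
        return self.positions.shape[-1]
```

Keeping both branches in `__getitem__` is what allows mode switching without touching the stored arrays, at the cost of the extra branching the Drawbacks section notes.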
      def collate_fn_sample(batch):
          # Collate function for sample mode

      def collate_fn_trajectory(batch):
          # Collate function for trajectory mode

- Data Loader Creation:
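One plausible reading of `collate_fn_sample` above, assuming each dataset item is an `(input_window, target)` pair of numpy arrays (an assumption; the shipped collate functions may differ):

```python
import numpy as np
import torch

def collate_fn_sample_sketch(batch):
    """Stack (window, target) pairs into batched tensors; a sketch, not the shipped code."""
    windows, targets = zip(*batch)
    return (
        torch.as_tensor(np.stack(windows)),
        torch.as_tensor(np.stack(targets)),
    )
```

A trajectory-mode collate function would typically avoid stacking, since trajectories can have different lengths.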
      def get_data_loader(file_path, mode='sample', input_sequence_length=6, batch_size=32, shuffle=True, is_distributed=False):
          # Create and return appropriate DataLoader based on mode and distributed setting

Usage in training script:
      # Determine if we're using distributed training
      is_distributed = device == torch.device("cuda") and world_size > 1

      # Load training data
      dl = pdl.get_data_loader(
          file_path=f"{cfg.data.path}train.npz",
          mode='sample',
          input_sequence_length=cfg.data.input_sequence_length,
          batch_size=cfg.data.batch_size,
          is_distributed=is_distributed
      )

      # Get the number of features
      train_dataset = pdl.ParticleDataset(f"{cfg.data.path}train.npz")
      n_features = train_dataset.get_num_features()

      # Similar process for validation data

Drawbacks
- Increased complexity of a single class handling multiple modes and distributed scenarios.
- Potential for slightly increased memory usage due to storing both sample and trajectory-related attributes.
- Users familiar with separate implementations might need to adapt to the new unified interface.
Rationale and Alternatives
This design is optimal because:
- It provides a single, consistent interface for different data loading needs, including distributed training.
- It leverages PyTorch's existing Dataset and DistributedSampler classes, ensuring compatibility with the PyTorch ecosystem.
- It allows for easy switching between sample and trajectory modes without changing the underlying data structure.
- It supports both distributed and non-distributed training scenarios with a single interface.
Alternatives considered:
- Keeping separate classes for sample and trajectory loading. Rejected due to code duplication and lack of flexibility.
- Using a factory pattern to create different dataset types. Rejected as overly complex for the current needs.
- Having separate functions for distributed and non-distributed data loading. Rejected in favor of a unified interface with a flag for distributed training.
The impact of not implementing this change would be continued code duplication, potential inconsistencies between sample and trajectory implementations, and reduced flexibility in data loading options, especially in distributed training scenarios.
Prior Art
- PyTorch's Dataset and DataLoader classes: Our implementation builds directly on these established foundations.
- PyTorch's DistributedSampler: Used for handling data distribution in multi-GPU training.
- TensorFlow's tf.data API: Provides similar flexibility in data loading, though with a different approach.
- NVIDIA's DALI library: Offers high-performance data loading pipelines, though more complex than our needs.
Unresolved questions
- How to handle very large datasets that don't fit in memory?
- Should we implement lazy loading of data?
- How can we optimize the performance of data loading for large-scale simulations?
- Are there any specific optimizations needed for distributed training with very large datasets?
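On the lazy-loading and out-of-memory questions above, one common pattern is to open the h5 file lazily in each DataLoader worker and read only the slice each `__getitem__` needs. A hedged sketch under those assumptions (h5py assumed; the class and key names are hypothetical, not part of the proposed implementation):

```python
import h5py
from torch.utils.data import Dataset

class LazyH5Dataset(Dataset):
    """Sketch: open the file lazily per process so each DataLoader worker
    gets its own handle, and read only the slice that __getitem__ needs."""

    def __init__(self, file_path, key="positions", input_sequence_length=6):
        self.file_path = file_path
        self.key = key
        self.input_sequence_length = input_sequence_length
        self._file = None
        with h5py.File(file_path, "r") as f:  # touch once to record the length
            self._length = f[key].shape[0] - input_sequence_length

    def _data(self):
        if self._file is None:  # opened on first access, i.e. after worker fork
            self._file = h5py.File(self.file_path, "r")
        return self._file[self.key]

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        data = self._data()
        window = data[idx : idx + self.input_sequence_length]
        target = data[idx + self.input_sequence_length]
        return window, target
```

Whether this is fast enough for large-scale simulations (chunking, caching, per-rank sharding) remains an open question for the distributed case.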
Changelog
- Initial draft of the RFC.
- Updated to reflect the current implementation, including support for distributed training and the unified get_data_loader function.