diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 31991a9..0000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "Prithvi-WxC"] - path = Prithvi-WxC - url = git@github.com:NASA-IMPACT/Prithvi-WxC.git diff --git a/README.md b/README.md index 9ae8b24..dd78977 100644 --- a/README.md +++ b/README.md @@ -8,10 +8,11 @@ This repository contains code and resources for training and inferring gravity w ## Setup -1. Clone the repository with submodules: +1. Clone the repository and download config: - git clone --recurse-submodules git@github.com:NASA-IMPACT/gravity-wave-finetuning.git gravity_wave_finetuning - cd gravity_wave_finetuning + git clone git@github.com:NASA-IMPACT/gravity-wave-finetuning.git + cd gravity-wave-finetuning + wget https://huggingface.co/Prithvi-WxC/Gravity_wave_Parameterization/resolve/main/config.yaml 2. Create and activate a Conda environment for the project: @@ -55,7 +56,7 @@ To run the training on a single node and a single GPU, execute the following com --nproc_per_node=1 \ --nnodes=1 \ --rdzv_backend=c10d \ - finetune_gravity_wave.py + prithviwxc/gravitywave/finetune_gravity_wave.py --split uvtp122 ### Multi-node Training @@ -71,7 +72,7 @@ After training, you can run inferences using the following command. Make sure to --nnodes=1 \ --nproc_per_node=1 \ --rdzv_backend=c10d \ - inference.py \ + prithviwxc/gravitywave/inference.py \ --split=uvtp122 \ --ckpt_path=/path/to/checkpoint \ --data_path=/path/to/data \ diff --git a/environment.yml b/environment.yml index 3cb204e..3bad4dc 100644 --- a/environment.yml +++ b/environment.yml @@ -208,3 +208,4 @@ dependencies: - wandb==0.17.7 - xarray==2024.7.0 - yacs==0.1.8 + - git+https://github.com/NASA-IMPACT/Prithvi-WxC.git diff --git a/prithviwxc/__init__.py b/prithviwxc/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/prithviwxc/__init__.py @@ -0,0 +1 @@ + diff --git a/prithviwxc/gravitywave/__init__.py b/prithviwxc/gravitywave/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/prithviwxc/gravitywave/__init__.py @@ -0,0 +1 @@ + diff --git a/config.py b/prithviwxc/gravitywave/config.py similarity index 59% rename from config.py rename to prithviwxc/gravitywave/config.py index db503ce..987877f 100644 --- a/config.py +++ b/prithviwxc/gravitywave/config.py @@ -2,17 +2,12 @@ _CN = CN() +# Declare all the configuration keys _CN.wandb_mode = "disabled" _CN.vartype = "uvtp122" -_CN.train_data_path = ( - "gravity_wave_flux/uvtp122" -) -_CN.valid_data_path = ( - "gravity_wave_flux/uvtp122/test" -) -_CN.singular_sharded_checkpoint = ( - "prithvi_wxc/v0.8.50.rollout_step3.1.pth" -) +_CN.train_data_path = "gravity_wave_flux/uvtp122" +_CN.valid_data_path = "gravity_wave_flux/uvtp122/test" +_CN.singular_sharded_checkpoint = "prithvi_wxc/v0.8.50.rollout_step3.1.pth" _CN.file_glob_pattern = "wxc_input_u_v_t_p_output_theta_uw_vw_era5_*.nc" _CN.lr = 0.0001 @@ -23,14 +18,14 @@ _CN.mask_unit_size_px = [8, 16] _CN.patch_size_px = [1, 1] - -### Training Params - +# Training parameters _CN.max_epochs = 100 _CN.batch_size = 12 _CN.num_data_workers = 8 +_CN.merge_from_file("config.yaml") +# Function to clone the config def get_cfg(): return _CN.clone() diff --git a/datamodule.py b/prithviwxc/gravitywave/datamodule.py similarity index 94% rename from datamodule.py rename to prithviwxc/gravitywave/datamodule.py index 04c09ca..dc8d581 100644 --- a/datamodule.py +++ b/prithviwxc/gravitywave/datamodule.py @@ -4,7 +4,7 @@ import xarray as xr from torch.utils.data import DataLoader, Dataset from torch.utils.data.distributed import DistributedSampler - +import lightning as pl def get_era5_uvtp122(ds: xr.Dataset, index: int = 0) -> dict[str, torch.Tensor]: """Retrieve climate data variables at 122 pressure levels. @@ -134,7 +134,7 @@ def __getitem__(self, index: int = 0) -> dict[str, torch.Tensor]: return batch -class ERA5DataModule: +class ERA5DataModule(pl.LightningDataModule): """ This module handles data loading, batching, and train/validation splits. @@ -143,7 +143,7 @@ class ERA5DataModule: valid_data_path: Path to validation data. file_glob_pattern: Pattern to match NetCDF files. batch_size: Size of each mini-batch. - num_workers: Number of subprocesses for data loading. + num_workers: Number of subprocesses for data loading. """ def __init__( @@ -171,24 +171,26 @@ def __init__( self.batch_size: int = batch_size self.num_workers: int = num_data_workers - def setup(self, stage: str | None = None) -> tuple[Dataset, Dataset]: - """Sets up the datasets for different stages - (train, validation, predict). + def prepare_data(self): + pass - Args: - stage: Stage for which the setup is performed ("fit", "predict"). - """ + def setup(self, stage: str | None = None) -> tuple[Dataset, Dataset]: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if stage == "fit": self.dataset_train = ERA5Dataset( data_path=self.train_data_path, file_glob_pattern=self.file_glob_pattern ) + self.dataset_train = self.dataset_train.to(device) self.dataset_val = ERA5Dataset( data_path=self.valid_data_path, file_glob_pattern=self.file_glob_pattern ) + self.dataset_val = self.dataset_val.to(device) elif stage == "predict": self.dataset_predict = ERA5Dataset( data_path=self.valid_data_path, file_glob_pattern=self.file_glob_pattern ) + self.dataset_predict = self.dataset_predict.to(device) + def train_dataloader(self) -> DataLoader: """Returns a DataLoader for the training data.""" diff --git a/distributed.py b/prithviwxc/gravitywave/distributed.py similarity index 100% rename from distributed.py rename to prithviwxc/gravitywave/distributed.py diff --git a/finetune_gravity_wave.py b/prithviwxc/gravitywave/finetune_gravity_wave.py similarity index 100% rename from finetune_gravity_wave.py rename to prithviwxc/gravitywave/finetune_gravity_wave.py diff --git a/gravity_wave_model.py b/prithviwxc/gravitywave/gravity_wave_model.py similarity index 99% rename from gravity_wave_model.py rename to prithviwxc/gravitywave/gravity_wave_model.py index 80e531f..91e0ca7 100644 --- a/gravity_wave_model.py +++ b/prithviwxc/gravitywave/gravity_wave_model.py @@ -2,9 +2,9 @@ import torch.nn as nn import importlib -model = importlib.import_module('Prithvi-WxC.PrithviWxC.model') +model = importlib.import_module('PrithviWxC.model') -from distributed import print0 +from prithviwxc.gravitywave.distributed import print0 torch.set_float32_matmul_precision("high") diff --git a/inference.py b/prithviwxc/gravitywave/inference.py similarity index 91% rename from inference.py rename to prithviwxc/gravitywave/inference.py index cd30adc..0ba40e1 100644 --- a/inference.py +++ b/prithviwxc/gravitywave/inference.py @@ -7,14 +7,14 @@ import tqdm import xarray as xr -from datamodule import ERA5DataModule -from gravity_wave_model import UNetWithTransformer +from prithviwxc.gravitywave.datamodule import ERA5DataModule +from prithviwxc.gravitywave.gravity_wave_model import UNetWithTransformer from torch.nn.parallel import DistributedDataParallel as DDP import torch.distributed as dist -local_rank = int(os.environ["LOCAL_RANK"]) -rank = int(os.environ["RANK"]) -device = f"cuda:{local_rank}" +local_rank = int(os.getenv("LOCAL_RANK",'0')) +rank = int(os.getenv("RANK",'0')) +device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") dtype = torch.float32 def setup(): @@ -60,7 +60,11 @@ def get_model(cfg, vartype,ckpt_singular: str) -> torch.nn.Module: patch_size_px=cfg.patch_size_px, device=device, ) - model = DDP(model.to(local_rank, dtype=dtype), device_ids=[local_rank]) + + if device==torch.device('cpu'): + model = DDP(model.to(device)) + else: + model = DDP(model.to(local_rank, dtype=dtype), device_ids=[local_rank]) model = load_checkpoint(model,ckpt_singular) return model diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6e7e3c3 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,99 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "prithviwxc-gravitywave" +version = "0.1.0" +description = "Gravity Wave Parameterization" +readme = { file = "README.md", content-type = "text/markdown" } +requires-python = ">=3.10" +license = { text = "MIT" } +authors = [ + { name = "Sujit Roy", email = "sr0197@uah.edu" } +] +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent" +] +dependencies = [ +"cachetools", +"cartopy", +"click", +"cloudpickle", +"contourpy", +"cycler", +"dask", +"docker-pycreds", +"fonttools", +"fsspec", +"gitdb", +"gitpython", +"h5netcdf", +"h5py", +"kiwisolver", +"locket", +"matplotlib", +"nvidia-ml-py", +"nvitop", +"pandas", +"partd", +"platformdirs", +"protobuf", +"pyhelpme", +"pyparsing", +"pyproj", +"pyshp", +"python-dateutil", +"pytz", +"sentry-sdk", +"setproctitle", +"shapely", +"smmap", +"tabulate", +"termcolor", +"toolz", +"tqdm", +"tzdata", +"wandb", +"xarray", +"yacs", +"asttokens", +"brotli", +"certifi", +"charset-normalizer", +"comm", +"debugpy", +"decorator", +"exceptiongroup", +"executing", +"filelock", +"idna", +"importlib-metadata", +"intel-openmp", +"ipykernel", +"ipython", +"jedi", +"jinja2", +"jupyter_client", +"jupyter_core", +"lightning", +"markupsafe", +"matplotlib-inline", +"mkl", +"mpmath", +"numpy", +"scipy", +"torch", +"torchaudio", +"torchvision", +"yacs" +] + +[project.urls] +homepage = "https://github.com/NASA-IMPACT/gravity-wave-finetuning" + +[tool.setuptools] +packages=["prithviwxc", "prithviwxc.gravitywave"] +include-package-data = true diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..fdbbbf4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,41 @@ +cachetools==5.5.0 +cartopy==0.23.0 +click==8.1.7 +cloudpickle==3.0.0 +contourpy==1.3.0 +cycler==0.12.1 +dask==2024.8.1 +docker-pycreds==0.4.0 +fonttools==4.53.1 +fsspec==2024.6.1 +gitdb==4.0.11 +gitpython==3.1.43 +h5netcdf==1.3.0 +h5py==3.11.0 +kiwisolver==1.4.7 +locket==1.0.0 +matplotlib==3.9.2 +nvidia-ml-py==12.535.161 +nvitop==1.3.2 +pandas==2.2.2 +partd==1.4.2 +platformdirs==4.2.2 +protobuf==5.27.3 +pyhelpme==0.1 +pyparsing==3.1.4 +pyproj==3.6.1 +pyshp==2.3.1 +python-dateutil==2.9.0.post0 +pytz==2024.1 +sentry-sdk==2.13.0 +setproctitle==1.3.3 +shapely==2.0.6 +smmap==5.0.1 +tabulate==0.9.0 +termcolor==2.4.0 +toolz==0.12.1 +tqdm==4.66.5 +tzdata==2024.1 +wandb==0.17.7 +xarray==2024.7.0 +yacs==0.1.8 diff --git a/scripts/train.sh b/scripts/train.sh index b4dba06..03e4f7f 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -74,6 +74,6 @@ torchrun \ --rdzv_id=$RDZV_ID \ --rdzv_endpoint "$RDZV_ADDR:$RDZV_PORT" \ --rdzv_backend=c10d \ - .py --split $data_args + ./prithviwxc/gravitywave/.py --split $data_args conda deactivate \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..4bc870a --- /dev/null +++ b/setup.py @@ -0,0 +1,9 @@ +from setuptools import setup, find_packages + +setup( + name="prithviwxc-gravitywave", + version="0.1.0", + description="Gravity Wave Parameterization", + packages=["prithviwxc", "prithviwxc.gravitywave"], + include_package_data=True, +)