Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions examples/hstu/test/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
import pytest
import torch
import torch.distributed as dist
from commons.checkpoint import get_unwrapped_module
from commons.distributed.finalize_model_grads import finalize_model_grads
from commons.pipeline.train_pipeline import (
JaggedMegatronPrefetchTrainPipelineSparseDist,
JaggedMegatronTrainNonePipeline,
JaggedMegatronTrainPipelineSparseDist,
)
from commons.utils.distributed_utils import collective_assert
from test_utils import create_model
from test_utils import compare_two_modules_state_dict, create_model


@pytest.mark.parametrize("contextual_feature_names", [["user0", "user1"], []])
Expand All @@ -37,7 +38,7 @@
"optimizer_type_str", ["sgd"]
) # adam does not work since torchrec does not save the optimizer state `step`.
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("use_dynamic_emb", [True, False])
@pytest.mark.parametrize("use_dynamic_emb", [False])
@pytest.mark.parametrize("pipeline_type", ["prefetch", "native"])
def test_pipeline(
pipeline_type: str,
Expand All @@ -47,8 +48,12 @@ def test_pipeline(
dtype: torch.dtype,
use_dynamic_emb: bool,
):
# TODO@jiashu, restore the test once bug is fixed in dynamic embedding.
if use_dynamic_emb and pipeline_type == "prefetch":
pytest.skip("Disable dynamic embedding with prefetch pipeline")
init.initialize_distributed()
init.initialize_model_parallel(1)

model, dense_optimizer, history_batches = create_model(
task_type="ranking",
contextual_feature_names=contextual_feature_names,
Expand Down Expand Up @@ -116,6 +121,9 @@ def test_pipeline(
)
iter_history_batches = iter(history_batches)
no_pipeline_batches = iter(history_batches)
hstu_block = get_unwrapped_module(model)._hstu_block
hstu_block_pipelined = get_unwrapped_module(pipelined_model)._hstu_block
compare_two_modules_state_dict(hstu_block_pipelined, hstu_block)
for i, batch in enumerate(history_batches):
reporting_loss, (_, logits, _, _) = no_pipeline.progress(no_pipeline_batches)
pipelined_reporting_loss, (
Expand All @@ -124,6 +132,7 @@ def test_pipeline(
_,
_,
) = target_pipeline.progress(iter_history_batches)

collective_assert(
torch.allclose(pipelined_reporting_loss, reporting_loss),
f"reporting loss mismatch",
Expand Down
10 changes: 9 additions & 1 deletion examples/hstu/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,15 @@ def get_tp_slice(tensor: Optional[torch.Tensor], mode="row"):
raise ValueError(f"mode {mode} is not supported")


def compare_two_modules_state_dict(module1, module2):
module1_state_dict = module1.state_dict()
module2_state_dict = module2.state_dict()
for name, param in module1_state_dict.items():
src = param
dst = module2_state_dict[name]
collective_assert(torch.allclose(src, dst), f"state dict mismatch at {name}")


# TODO: Add get_tp_slice for optimizer state.
def compare_tpN_to_debug_optimizer_state(
tpN_optimizer, debug_optimizer, debug_fp32_optimizer
Expand Down Expand Up @@ -194,7 +203,6 @@ def compare_tpN_to_debug_weights(
name = name.replace(
child_name, debug_module_path_to_tpN_module_path[child_name]
)

dst = tpN_module_params_map[name]
dst_grad = getattr(dst, "main_grad", None)
# model parallel embedding table weight is a TableBatchedEmbeddingSlice, which has no grad
Expand Down
2 changes: 1 addition & 1 deletion examples/hstu/training/configs/movielen_ranking.gin
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ TrainerArgs.log_interval = 100
TrainerArgs.seed = 1234
TrainerArgs.max_train_iters = 1000
TrainerArgs.profile = True
TrainerArgs.pipeline_type = "prefetch"
TrainerArgs.pipeline_type = "native"

DatasetArgs.dataset_name = 'ml-20m'
DatasetArgs.max_sequence_length = 200
Expand Down