
Conversation

@oliverkinch (Collaborator) commented Nov 13, 2025

Implements CP for non-MoE models. Implementing CP for MoE models will be done in a separate PR.

Fixes #31.
#38 will be redundant given this PR.
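
For reviewers unfamiliar with the API: below is a minimal, hedged sketch of how context parallelism is typically wired into a training step with PyTorch's experimental context_parallel context manager (torchtitan-style). The mesh layout, the buffer names (inputs, labels) and the loss computation are placeholders for illustration, not code from this PR.

# Minimal CP sketch (assumptions: launched via torchrun, world size == cp_degree,
# tensors shaped [batch, seq, ...] so the sequence dimension is dim 1).
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel

cp_degree = 2
mesh = init_device_mesh("cuda", (cp_degree,), mesh_dim_names=("cp",))
cp_mesh = mesh["cp"]

def train_step(model, inputs, labels, loss_fn, optimizer):
    # Shard the sequence dimension of the listed buffers across CP ranks for the
    # duration of forward + backward; the buffers are restored on exit.
    with context_parallel(cp_mesh, buffers=[inputs, labels], buffer_seq_dims=[1, 1]):
        logits = model(inputs)          # placeholder forward
        loss = loss_fn(logits, labels)  # placeholder loss
        loss.backward()
    optimizer.step()
    optimizer.zero_grad()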

@oliverkinch (Collaborator, Author) commented:

From train.py we have:

if parallel_dims.cp_enabled: # the following is necessary for CP w/ flex attention
    from torch.distributed.tensor.experimental._attention import _set_cp_global_var, _DispatchMode, _cp_options

    # set_rotate_method("alltoall")  # alltoall or allgather (only allgather for flex)
    _set_cp_global_var("cp_shard_dim", 2)
    # _cp_options.enable_load_balance = True  # no load balancing for flex
    torch.distributed.tensor.experimental._attention._dispatch_mode = (
        _DispatchMode.TORCH_FUNCTION
    )

_set_cp_global_var is only available in torch 2.9.0, but if I force that version the code crashes when .backward() is called. Is _set_cp_global_var necessary?
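
One hedged option (not necessarily what this PR should do): guard the private helper by availability so the same code runs on torch versions with and without it. These are private, unstable APIs, so this is only a sketch.

# Sketch: call _set_cp_global_var only when the installed torch exposes it.
from torch.distributed.tensor.experimental import _attention

if hasattr(_attention, "_set_cp_global_var"):
    _attention._set_cp_global_var("cp_shard_dim", 2)

# Route flex attention through the TORCH_FUNCTION dispatch path, as in train.py above.
_attention._dispatch_mode = _attention._DispatchMode.TORCH_FUNCTION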

@oliverkinch (Collaborator, Author) commented:

Problems with FLASH_ATTENTION? It works with the MATH backend:

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)
torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., device='cuda:0', size=(1, 16, 4096, 192), dtype=torch.bfloat16,
           grad_fn=<TransposeBackward0>), FakeTensor(..., device='cuda:0', size=(1, 16, 4096, 192), dtype=torch.bfloat16,
           grad_fn=<TransposeBackward0>), FakeTensor(..., device='cuda:0', size=(1, 16, 4096, 128), dtype=torch.bfloat16,
           grad_fn=<TransposeBackward0>)), **{'is_causal': True, 'scale': 0.07216878364870322}): got RuntimeError('No available kernel. Aborting execution.')
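
One likely cause (unverified): the flash SDPA backend generally requires q, k and v to share the same head dimension, and here q/k have head dim 192 while v has 128, so no flash kernel can be selected. A hedged workaround sketch is to pass sdpa_kernel a list of backends so SDPA falls back to one that supports the shapes; this is illustrative, not the fix adopted in this PR.

import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

def attention(q, k, v, scale):
    # Try flash first, then the memory-efficient kernel, then the math reference
    # implementation, which handles the widest range of shapes.
    backends = [
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
    ]
    with sdpa_kernel(backends):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)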

@oliverkinch (Collaborator, Author) commented:

@rlrs Context parallelism now runs for Gemma and Llama.

@oliverkinch oliverkinch marked this pull request as ready for review November 13, 2025 09:59
@oliverkinch oliverkinch requested a review from rlrs November 13, 2025 09:59
New dcp script for models where YaRN has been used to extend the context length
@oliverkinch (Collaborator, Author) commented Dec 10, 2025

I have now also included the changes related to YaRN in this PR, see d36078d. A sketch of the corresponding rope-scaling configuration follows the plot legend below.

[image: plot comparing the three runs described below]
  • Blue: The base model.
  • Orange: The base model with its context window extended from 4k to 32k using YaRN, without any additional training.
  • Green: The same YaRN-extended model, further trained for 1,000 steps on long-context data (wiki_expanded).
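
For context, here is a hedged sketch of what a 4k → 32k YaRN extension looks like as a Hugging Face-style rope_scaling entry (factor = 32768 / 4096 = 8). Field names follow the transformers convention; the actual values and config format used in d36078d may differ.

# Hugging Face-style YaRN rope scaling sketch (illustrative values only).
rope_scaling = {
    "rope_type": "yarn",
    "factor": 8.0,                              # 32768 / 4096
    "original_max_position_embeddings": 4096,   # pre-extension context window
}
max_position_embeddings = 32768                 # extended context window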
