-
Notifications
You must be signed in to change notification settings - Fork 3
Context Parallelism #67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
oliverkinch
wants to merge
4
commits into
main
Choose a base branch
from
cp
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Collaborator
Author
|
From if parallel_dims.cp_enabled: # the following is necessary for CP w/ flex attention
from torch.distributed.tensor.experimental._attention import _set_cp_global_var, _DispatchMode, _cp_options
# set_rotate_method("alltoall") # alltoall or allgather (only allgather for flex)
_set_cp_global_var("cp_shard_dim", 2)
# _cp_options.enable_load_balance = True # no load balancing for flex
torch.distributed.tensor.experimental._attention._dispatch_mode = (
_DispatchMode.TORCH_FUNCTION
)
|
Collaborator
Author
|
Problems with with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., device='cuda:0', size=(1, 16, 4096, 192), dtype=torch.bfloat16,
grad_fn=<TransposeBackward0>), FakeTensor(..., device='cuda:0', size=(1, 16, 4096, 192), dtype=torch.bfloat16,
grad_fn=<TransposeBackward0>), FakeTensor(..., device='cuda:0', size=(1, 16, 4096, 128), dtype=torch.bfloat16,
grad_fn=<TransposeBackward0>)), **{'is_causal': True, 'scale': 0.07216878364870322}): got RuntimeError('No available kernel. Aborting execution.') |
Collaborator
Author
|
@rlrs Context parallelism now runs for gemma and llama |
New dcp script related to model where yarn has been used to extend the context length
Collaborator
Author
|
I have now also included the related to YaRN in this PR, see d36078d
|
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.

Implements CP for non MoE models. Implementing CP for MoEs will be in a separate PR.
Fix #31.
#38 will be redundant given this PR.