Merged: changes from all commits (23 commits)
21 changes: 14 additions & 7 deletions fast_llm/engine/base_model/config.py
@@ -8,6 +8,8 @@
from fast_llm.utils import Assert, compare_nested, log

if typing.TYPE_CHECKING:
import torch

from fast_llm.engine.base_model.base_model import BaseModel


@@ -58,24 +60,29 @@ def _serialize_architecture_field(self, value: typing.Any) -> typing.Any:
return self._serialize_value(value)


def set_model_names(model: "torch.nn.Module"):
from fast_llm.tensor import ParameterMeta

for key, value in model.named_modules():
value.module_name = key
for key, value in model.named_parameters():
Assert.custom(isinstance, value, ParameterMeta)
# Rename to the parameter's full name
value.tensor_name = key


@config_class()
class BaseModelConfig(ModuleConfig):
"""
Abstract config class for a base model.
"""

def get_base_model(self, distributed_config: DistributedConfig) -> "BaseModel":
from fast_llm.tensor import ParameterMeta

model = self.base_model_class(self, distributed_config)
# Storing the global name of each module and tensor.
# Done here because it needs to run right after `model.__init__()`
for key, value in model.named_modules():
value.module_name = key
for key, value in model.named_parameters():
Assert.custom(isinstance, value, ParameterMeta)
# Rename to the parameter's full name
value.tensor_name = key
set_model_names(model)
return model

@property
1 change: 1 addition & 0 deletions fast_llm/engine/multi_stage/fsdp.py
@@ -47,6 +47,7 @@ def __init__(
):
self._name = name
self._parameter_metas = {parameter_meta.tensor_name: parameter_meta for parameter_meta in parameter_metas}
Assert.eq(len(self._parameter_metas), len(parameter_metas))  # `set_model_names` ensures unique names.
self._distributed_config = distributed_config
self._fsdp_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.data)
self._is_tied_weight_copy = is_tied_weight_copy
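For context on the new uniqueness check: `set_model_names` copies the fully qualified names returned by `named_modules()` / `named_parameters()` onto the metas, and those dotted names are unique per module and parameter, which is what makes the `Assert.eq` above safe. A minimal standalone sketch with a toy `torch.nn` model (illustration only, not Fast-LLM code):

```python
import torch

# Toy model, purely for illustration: named_modules()/named_parameters()
# return fully qualified dotted paths, one per module/parameter.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.Sequential(torch.nn.Linear(4, 4)),
)

module_names = [name for name, _ in model.named_modules()]
parameter_names = [name for name, _ in model.named_parameters()]

print(module_names)     # ['', '0', '1', '1.0']
print(parameter_names)  # ['0.weight', '0.bias', '1.0.weight', '1.0.bias']

# Uniqueness is what the new Assert.eq in FSDP.__init__ relies on: building a
# dict keyed by these names cannot silently drop a parameter.
assert len(set(parameter_names)) == len(parameter_names)
```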
47 changes: 9 additions & 38 deletions fast_llm/layers/attention/attention.py
@@ -11,11 +11,12 @@
from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.functional.autograd import wrap_forward_backward
from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs
from fast_llm.layers.attention.preprocessing import preprocess_for_varlen
from fast_llm.layers.block.config import BlockDimNames
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.decoder.block import BlockWithBias
from fast_llm.tensor import TensorMeta
from fast_llm.utils import Assert, div
from fast_llm.utils import div

try:
from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa
@@ -505,40 +506,10 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> None:
kwargs[AttentionKwargs.attention_mask_value] = self._backup_attention_mask_value

def _preprocess_for_flash_attention(self, kwargs: dict[str, typing.Any]) -> None:
"""
Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func:
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375
cu_seqlens_q and cu_seqlens_k are cumulative sequence lengths for the query and key/value tensors, respectively.
Assumes a flattened batch of documents. In absence of sequence_data_parallelism, cu_seqlens_q = cu_seqlens_k.
If sequence_data_parallelism > 1, query tensors contain tokens only from current micro-sequence, whereas key/value tensors additionally
also contain previous tokens from the first document in micro-sequence.
We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths.
"""
if self._config.cross_document_attention:
return
device = kwargs[AttentionKwargs.device] if AttentionKwargs.device in kwargs else self._distributed.device

# TODO: ====== Fix (need to know how much first sequence was cropped) ======
Assert.eq(
kwargs[AttentionKwargs.sequence_k_dim].global_size, kwargs[AttentionKwargs.sequence_q_dim].global_size
)

# TODO: Calculate these in batch preprocessing?
sequence_lengths_q = torch.tensor(
[
0,
*(
sequence_length
for sequence_lengths in kwargs[AttentionKwargs.sequence_lengths]
for sequence_length in sequence_lengths
),
],
dtype=torch.int32,
)
max_sequence_length = sequence_lengths_q.max().item()
cu_seqlens_q = sequence_lengths_q.cumsum_(0).to(device)
max_seqlen_q = cu_seqlens_q.new_full((1,), max_sequence_length)
kwargs[AttentionKwargs.cu_seqlens_q] = cu_seqlens_q
kwargs[AttentionKwargs.cu_seqlens_k] = cu_seqlens_q
kwargs[AttentionKwargs.max_seqlen_q] = max_seqlen_q
kwargs[AttentionKwargs.max_seqlen_k] = max_seqlen_q
if not self._config.cross_document_attention:
preprocess_for_varlen(
kwargs,
kwargs[AttentionKwargs.device] if AttentionKwargs.device in kwargs else self._distributed.device,
return_cu_seqlens=True,
return_max_seqlen=True,
)
15 changes: 10 additions & 5 deletions fast_llm/layers/attention/config.py
@@ -17,15 +17,20 @@
logger = logging.getLogger(__name__)


class AttentionKwargs(BlockKwargs):
rotary_freq_q = "rotary_freq_q"
rotary_freq_k = "rotary_freq_k"
attention_mask = "attention_mask"
attention_mask_value = "attention_mask_value"
class MixerKwargs(BlockKwargs):
cu_seqlens_q = "cu_seqlens_q"
cu_seqlens_k = "cu_seqlens_k"
max_seqlen_q = "max_seqlen_q"
max_seqlen_k = "max_seqlen_k"
seq_idx = "seq_idx"
position_ids = "position_ids"


class AttentionKwargs(MixerKwargs):
rotary_freq_q = "rotary_freq_q"
rotary_freq_k = "rotary_freq_k"
attention_mask = "attention_mask"
attention_mask_value = "attention_mask_value"
# TODO: Review these
presents = "presents"
past_key_values = "past_key_values"
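To make the refactor above concrete, here is a standalone mirror of the new hierarchy (Fast-LLM's `BlockKwargs` is stubbed so the snippet runs on its own): the variable-length keys now live on the shared `MixerKwargs` base, so attention layers and SSM mixers read the same `kwargs` entries.

```python
# Standalone mirror of the class layout above; BlockKwargs is stubbed here
# for illustration only.
class BlockKwargs:
    sequence_lengths = "sequence_lengths"


class MixerKwargs(BlockKwargs):
    cu_seqlens_q = "cu_seqlens_q"
    cu_seqlens_k = "cu_seqlens_k"
    max_seqlen_q = "max_seqlen_q"
    max_seqlen_k = "max_seqlen_k"
    seq_idx = "seq_idx"
    position_ids = "position_ids"


class AttentionKwargs(MixerKwargs):
    attention_mask = "attention_mask"


# An attention layer and a GDN/SSM mixer now agree on the key names:
assert AttentionKwargs.cu_seqlens_q == MixerKwargs.cu_seqlens_q == "cu_seqlens_q"
```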
58 changes: 58 additions & 0 deletions fast_llm/layers/attention/preprocessing.py
@@ -0,0 +1,58 @@
import typing

import torch

from fast_llm.layers.attention.config import MixerKwargs
from fast_llm.utils import Assert


def preprocess_for_varlen(
kwargs: dict[str, typing.Any],
device: torch.device,
return_cu_seqlens: bool = False,
return_max_seqlen: bool = False,
return_seq_idx: bool = False,
return_position_ids: bool = False,
) -> None:
"""
Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func:
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375
cu_seqlens_q and cu_seqlens_k are cumulative sequence lengths for the query and key/value tensors, respectively.
Assumes a flattened batch of documents. In absence of sequence_data_parallelism, cu_seqlens_q = cu_seqlens_k.
If sequence_data_parallelism > 1, query tensors contain tokens only from current micro-sequence, whereas key/value tensors additionally
also contain previous tokens from the first document in micro-sequence.
We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths.
"""

# TODO: ====== Fix (need to know how much first sequence was cropped) ======
Assert.eq(kwargs[MixerKwargs.sequence_k_dim].global_size, kwargs[MixerKwargs.sequence_q_dim].global_size)

sequence_lengths = [
sequence_length
for sequence_lengths in kwargs[MixerKwargs.sequence_lengths]
for sequence_length in sequence_lengths
]
if return_cu_seqlens:
cu_seqlens_q = torch.tensor([0] + sequence_lengths, dtype=torch.int32, device=device).cumsum(
0, dtype=torch.int32
)
kwargs[MixerKwargs.cu_seqlens_q] = cu_seqlens_q
kwargs[MixerKwargs.cu_seqlens_k] = cu_seqlens_q
if return_max_seqlen:
max_seqlen_q = torch.full((1,), max(sequence_lengths), dtype=torch.int32, device=device)
kwargs[MixerKwargs.max_seqlen_q] = max_seqlen_q
kwargs[MixerKwargs.max_seqlen_k] = max_seqlen_q
if return_seq_idx:
kwargs[MixerKwargs.seq_idx] = torch.cat(
[
torch.full((sequence_length,), i, dtype=torch.int32, device=device)
for i, sequence_length in enumerate(sequence_lengths)
]
)
if return_position_ids:
kwargs[MixerKwargs.position_ids] = torch.cat(
[
torch.arange(sequence_length, dtype=torch.int32, device=device)
for i, sequence_length in enumerate(sequence_lengths)
]
)
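To illustrate what the new helper produces, here is a standalone sketch (not Fast-LLM code; the `kwargs` plumbing is omitted) that repeats the tensor construction above for a packed micro-batch with documents of lengths 3, 2 and 4:

```python
import torch

# Standalone illustration of the varlen metadata built by preprocess_for_varlen
# for documents of lengths 3, 2 and 4 packed into one flattened sequence.
sequence_lengths = [3, 2, 4]

cu_seqlens = torch.tensor([0] + sequence_lengths, dtype=torch.int32).cumsum(0, dtype=torch.int32)
max_seqlen = torch.full((1,), max(sequence_lengths), dtype=torch.int32)
seq_idx = torch.cat(
    [torch.full((length,), i, dtype=torch.int32) for i, length in enumerate(sequence_lengths)]
)
position_ids = torch.cat(
    [torch.arange(length, dtype=torch.int32) for length in sequence_lengths]
)

print(cu_seqlens)    # tensor([0, 3, 5, 9], dtype=torch.int32)
print(max_seqlen)    # tensor([4], dtype=torch.int32)
print(seq_idx)       # tensor([0, 0, 0, 1, 1, 2, 2, 2, 2], dtype=torch.int32)
print(position_ids)  # tensor([0, 1, 2, 0, 1, 0, 1, 2, 3], dtype=torch.int32)
```

As the call sites in this PR show, attention requests `cu_seqlens`/`max_seqlen` (consumed by the flash-attn varlen path), while the GDN mixer requests `cu_seqlens`/`seq_idx` (passed to the causal-conv1d and gated-delta-rule kernels).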
6 changes: 5 additions & 1 deletion fast_llm/layers/common/linear/convolution.py
@@ -34,7 +34,11 @@ def __init__(
else self._forward_torch
)

def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor:
def _forward_torch(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
if kwargs:
raise NotImplementedError(
f"Arguments {tuple(kwargs)} not implemented for torch implementation of 1d convolution."
)
return self._activation.activation_fn(
torch.nn.functional.conv1d(
input_,
19 changes: 6 additions & 13 deletions fast_llm/layers/ssm/config.py
@@ -4,7 +4,6 @@
from fast_llm.config import Field, FieldHint, check_field, config_class
from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, LambdaInitializer
from fast_llm.engine.config_utils.parameter import ParameterConfig
from fast_llm.layers.block.config import BlockKwargs
from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig
from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig
from fast_llm.layers.decoder.config import MixerConfig
@@ -20,11 +19,6 @@
from fast_llm.tensor import ParameterMeta


class LinearAttentionKwargs(BlockKwargs):
cu_seqlens = "cu_seqlens"
seq_idx = "seq_idx"


@config_class(dynamic_type={MixerConfig: "gdn"})
class GatedDeltaNetConfig(MixerConfig):
"""
@@ -179,13 +173,6 @@ def layer_class(self) -> "type[KimiDeltaAttention]":

return KimiDeltaAttention

def _validate(self) -> None:
with self._set_implicit_default():
if "activation" not in self.normalization._explicit_fields:
self.normalization.activation = "sigmoid"

super()._validate()


@config_class()
class SSMConfig(MixerConfig):
@@ -334,6 +321,12 @@ class Mamba2Config(MambaBaseConfig):
desc="Whether to repeat x and B before (True) or after (False) the conv1d in Mamba2 blocks.",
hint=FieldHint.architecture,
)
cross_document_attention: bool = Field(
default=True,
desc="Allow for cross-document attention.",
doc="Disable to prevent attention between tokens belonging to different documents.",
hint=FieldHint.feature,
)

@property
def layer_class(self) -> "type[Mamba2]":
69 changes: 12 additions & 57 deletions fast_llm/layers/ssm/gdn.py
@@ -10,10 +10,12 @@
from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim
from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.functional.config import ActivationType
from fast_llm.layers.attention.config import MixerKwargs
from fast_llm.layers.attention.preprocessing import preprocess_for_varlen
from fast_llm.layers.block.config import BlockKwargs
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.decoder.block import BlockWithBias
from fast_llm.layers.ssm.config import GatedDeltaNetConfig, LinearAttentionKwargs
from fast_llm.layers.ssm.config import GatedDeltaNetConfig
from fast_llm.tensor import ParameterMeta, TensorMeta
from fast_llm.utils import div

@@ -293,9 +295,6 @@ def _forward(
# TODO: fuse some of the reshapes into rearranges
hidden_states = input_

cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None)
seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None)

projected_states_qkvz = self.in_proj_qkvz(hidden_states) # bs/seq x seq_len/bs x (qkvz)
projected_states_ba = self.in_proj_ba(hidden_states) # bs/seq x seq_len/bs x (b a)
if sequence_first:
@@ -315,9 +314,8 @@
mixed_qkv = torch.cat((query, key, value), dim=-1)
mixed_qkv = rearrange(mixed_qkv, "b s ... -> (b s) ...").unsqueeze(0) # 1 s d
mixed_qkv = rearrange(mixed_qkv, "b t d -> b d t") # mixed_qkv.transpose(1, 2)
mixed_qkv = self.convolution(
mixed_qkv, seq_idx=seq_idx
) # conv func. gets sequence dim as last dim, see https://github.com/Dao-AILab/causal-conv1d/blob/22a4577d8ace9d5703daea91a7fb56695492152b/causal_conv1d/causal_conv1d_interface.py#L110
# conv func. gets sequence dim as last dim, see https://github.com/Dao-AILab/causal-conv1d/blob/22a4577d8ace9d5703daea91a7fb56695492152b/causal_conv1d/causal_conv1d_interface.py#L110
mixed_qkv = self.convolution(mixed_qkv, seq_idx=kwargs[MixerKwargs.seq_idx].unsqueeze(0))
mixed_qkv = rearrange(mixed_qkv, "b d t -> b t d") # mixed_qkv.transpose(1, 2)
query, key, value = torch.split(
mixed_qkv,
@@ -351,7 +349,7 @@ def _forward(
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=True,
cu_seqlens=cu_seqlens,
cu_seqlens=kwargs[MixerKwargs.cu_seqlens_q],
)

z_shape_og = z.shape
Expand All @@ -368,56 +366,13 @@ def _forward(

return output

def _preprocess_for_varlen(self, kwargs: dict[str, typing.Any]) -> None:
"""
Creates seqlens and cu_seqlens for packed forward.
This assumes that the forward pass is performed on a fully packed sequence, i.e. where sequences are flattened out into BS = 1.
Note: padding tokens are always on the right and get their own entry in LinearAttentionKwargs.sequence_lengths --> they are treated as a separate sequence.

Sets:
- seq_idx to a [1, BS x T] tensor, where each element is the sequence index of the corresponding token
- cu_seqlens to an [N+1] tensor, where N is the total number of sequences in the batch; each element is the cumulative sequence length of the packed sequences so far
"""

sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths]
device = kwargs.get("device", None)
if sequence_lengths is None:
raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.")
seqlens = torch.tensor(
[
0,
*(
sequence_length
for sequence_lengths in kwargs[LinearAttentionKwargs.sequence_lengths]
for sequence_length in sequence_lengths
),
],
dtype=torch.int32,
)
cu_seqlens = seqlens.cumsum_(0).to(device)
# this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303
# also whenever cu_seqlens is used, batch size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347
kwargs[LinearAttentionKwargs.cu_seqlens] = cu_seqlens
# seq_idx has to be (bs, seqlen), but bs is forced to 1
kwargs[LinearAttentionKwargs.seq_idx] = (
(
torch.cat(
[
torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device)
for n in (torch.diff(cu_seqlens).to(torch.int32))
],
dim=0,
)
.eq(0)
.cumsum(0)
- 1
)
.to(torch.int32)
.unsqueeze(0)
)

def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
self._preprocess_for_varlen(kwargs)
preprocess_for_varlen(
kwargs,
kwargs[MixerKwargs.device] if MixerKwargs.device in kwargs else self._distributed.device,
return_cu_seqlens=True,
return_seq_idx=True,
)

def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
raise NotImplementedError()