29 changes: 23 additions & 6 deletions tensorrt_llm/_torch/models/modeling_speculative.py
@@ -30,8 +30,10 @@ def __init__(
self,
model_config: ModelConfig[LlamaConfig],
layer_idx: Optional[int] = None,
next_layer_regular: bool = False,
):
config = model_config.pretrained_config
self._next_layer_regular = next_layer_regular
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
@@ -78,8 +80,9 @@ def __init__(
) -> Tuple[torch.Tensor, torch.Tensor]:
super().__init__()
config = model_config.pretrained_config
eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {}
self.layer_idx = layer_idx
self._next_layer_regular = config.eagle_config.get("next_layer_regular", True) and not is_first_layer
self._next_layer_regular = (eagle_config.get("next_layer_regular", True) and not is_first_layer) or eagle_config.get("eh_proj_before_attn", False)
self.self_attn = Eagle3Attention(model_config, layer_idx, self._next_layer_regular)

if config.model_type == "llama4_text":
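As an aside, the condition introduced above can be read as a small standalone predicate. The sketch below is illustrative only, using the key names and defaults visible in the diff; the surrounding class wiring is omitted.

def next_layer_regular(eagle_config: dict, is_first_layer: bool) -> bool:
    # Mirrors the expression assigned to self._next_layer_regular above.
    return (eagle_config.get("next_layer_regular", True) and not is_first_layer) \
        or eagle_config.get("eh_proj_before_attn", False)

# With eh_proj_before_attn enabled, even the first layer is treated as regular,
# presumably because the embedding/hidden-state fusion already happens before
# the decoder stack (see the forward() change further down).
assert next_layer_regular({"eh_proj_before_attn": True}, is_first_layer=True)
assert not next_layer_regular({}, is_first_layer=True)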
@@ -153,18 +156,20 @@ def __init__(
super().__init__(model_config)

config = model_config.pretrained_config
eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {}
self.spec_config = model_config.spec_config
self.dtype = config.torch_dtype
self.hidden_size = config.hidden_size
self.mapping = model_config.mapping
self.num_layers = model_config.pretrained_config.num_hidden_layers
self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn", False)

if hasattr(config, "target_hidden_size"):
self.hidden_size_in = config.target_hidden_size
else:
self.hidden_size_in = config.hidden_size

self._return_hidden_post_norm = config.eagle_config.get("return_hidden_post_norm", False)
self._return_hidden_post_norm = eagle_config.get("return_hidden_post_norm", False)

if self.spec_config.num_capture_layers > 1:
self.fc = Linear(self.hidden_size_in *
@@ -189,6 +194,14 @@ def __init__(
self.d2t = nn.Parameter(torch.empty((config.draft_vocab_size, ),
dtype=torch.int32),
requires_grad=False)
if self._eh_proj_before_attn:
self.enorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=eagle_config.get("eh_proj_bias", False),
dtype=config.torch_dtype)

if self.hidden_size_in != config.hidden_size:
if model_config.mapping.enable_attention_dp:
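For context on the flags read in this constructor, a hypothetical eagle_config dictionary is sketched below; the key names and defaults come from the .get(...) calls in the diff, while how the dict reaches pretrained_config is assumed.

eagle_config = {
    "next_layer_regular": True,        # consumed by the decoder-layer __init__ above
    "eh_proj_before_attn": False,      # enables the enorm/eh_proj fusion added in this change
    "eh_proj_bias": False,             # bias of the eh_proj Linear
    "return_hidden_post_norm": False,  # stored as self._return_hidden_post_norm
}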
@@ -230,11 +243,15 @@ def forward(
inputs_embeds = self.embed_tokens(input_ids).to(self.dtype)

assert hidden_states is not None

# NOTE: If hidden states from the target model have to be concatenated,
# we expect that to happen outside the model definition. This helps us
# avoid data-dependent control flow and gives us better CUDA graph
# coverage.
# ideally, we expect that to happen outside the model definition. This
# helps us avoid data-dependent control flow and gives us better CUDA
# graph coverage.
if self._eh_proj_before_attn:
input_embeds = self.enorm(inputs_embeds)
hidden_states = torch.cat([input_embeds, hidden_states], dim=-1)
hidden_states = self.eh_proj(hidden_states)

residual = None
if self.num_layers > 1:
for layer in self.midlayer:
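Finally, a standalone sketch of the eh_proj_before_attn path added to forward(): normalize the draft-token embeddings, concatenate them with the captured target hidden states, and project 2*H back to H before the first draft layer. torch.nn.RMSNorm (available in recent PyTorch) stands in for TRT-LLM's RMSNorm, and the sizes are illustrative rather than taken from the PR.

import torch
import torch.nn as nn

hidden_size = 4096
enorm = nn.RMSNorm(hidden_size, eps=1e-5)                      # plays the role of self.enorm
eh_proj = nn.Linear(hidden_size * 2, hidden_size, bias=False)  # plays the role of self.eh_proj

inputs_embeds = torch.randn(2, 8, hidden_size)  # [batch, seq, H] draft-token embeddings
hidden_states = torch.randn(2, 8, hidden_size)  # [batch, seq, H] hidden states captured from the target model

fused = eh_proj(torch.cat([enorm(inputs_embeds), hidden_states], dim=-1))
assert fused.shape == (2, 8, hidden_size)       # back to [batch, seq, H] for the decoder layers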