From b45647d1640871d215a96a81b29937585886cc97 Mon Sep 17 00:00:00 2001 From: mashkenazi Date: Mon, 3 Nov 2025 17:37:50 +0000 Subject: [PATCH 1/2] added eh proj option to eagle draft model --- .../_torch/models/modeling_speculative.py | 40 +++++++++++++------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index bafb3c05e9d..993b714ae25 100755 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -30,8 +30,10 @@ def __init__( self, model_config: ModelConfig[LlamaConfig], layer_idx: Optional[int] = None, + eh_proj_before_attn: bool = False, ): config = model_config.pretrained_config + self._eh_proj_before_attn = eh_proj_before_attn super().__init__( hidden_size=config.hidden_size, num_attention_heads=config.num_attention_heads, @@ -52,7 +54,7 @@ def __init__( tp_size = 1 # Override the QKV projection. The number of input features # is twice as big for EAGLE3 draft models. - if not self._next_layer_regular: + if not self._eh_proj_before_attn: self.qkv_proj = Linear( 2 * self.hidden_size, tp_size * self.q_size + 2 * tp_size * self.kv_size, @@ -74,13 +76,13 @@ def __init__( self, model_config: LlamaConfig, layer_idx: int = 0, - is_first_layer: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: super().__init__() config = model_config.pretrained_config + eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {} self.layer_idx = layer_idx - self._next_layer_regular = config.eagle_config.get("next_layer_regular", True) and not is_first_layer - self.self_attn = Eagle3Attention(model_config, layer_idx, self._next_layer_regular) + self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn", False) + self.self_attn = Eagle3Attention(model_config, layer_idx, self._eh_proj_before_attn) if config.model_type == "llama4_text": inter_size = config.intermediate_size_mlp @@ -96,7 +98,7 @@ def __init__( overridden_tp_size=1 if model_config.mapping.enable_attention_dp else None, ) - if not self._next_layer_regular: + if not self._eh_proj_before_attn: self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype) @@ -120,7 +122,7 @@ def forward( residual = hidden_states hidden_states = self.hidden_norm(hidden_states) - if not self._next_layer_regular: + if not self._eh_proj_before_attn: embeds = self.input_layernorm(embeds) hidden_states = torch.cat([embeds, hidden_states], dim=-1) @@ -153,18 +155,20 @@ def __init__( super().__init__(model_config) config = model_config.pretrained_config + eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {} self.spec_config = model_config.spec_config self.dtype = config.torch_dtype self.hidden_size = config.hidden_size self.mapping = model_config.mapping self.num_layers = model_config.pretrained_config.num_hidden_layers + self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn", False) if hasattr(config, "target_hidden_size"): self.hidden_size_in = config.target_hidden_size else: self.hidden_size_in = config.hidden_size - self._return_hidden_post_norm = config.eagle_config.get("return_hidden_post_norm", False) + self._return_hidden_post_norm = eagle_config.get("return_hidden_post_norm", False) if self.spec_config.num_capture_layers > 1: self.fc = Linear(self.hidden_size_in * @@ -175,7 +179,7 @@ def __init__( if self.num_layers > 1: self.midlayer = nn.ModuleList([ - 
Eagle3DecoderLayer(model_config, start_layer_idx + i, i == 0)
+                Eagle3DecoderLayer(model_config, start_layer_idx + i)
                 for i in range(self.num_layers)
             ])
         else:
@@ -189,6 +193,14 @@ def __init__(
             self.d2t = nn.Parameter(torch.empty((config.draft_vocab_size, ),
                                                 dtype=torch.int32),
                                     requires_grad=False)
+        if self._eh_proj_before_attn:
+            self.enorm = RMSNorm(hidden_size=config.hidden_size,
+                                 eps=config.rms_norm_eps,
+                                 dtype=config.torch_dtype)
+            self.eh_proj = nn.Linear(config.hidden_size * 2,
+                                     config.hidden_size,
+                                     bias=eagle_config.get("eh_proj_bias", False),
+                                     dtype=config.torch_dtype)
 
         if self.hidden_size_in != config.hidden_size:
             if model_config.mapping.enable_attention_dp:
@@ -230,11 +242,15 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids).to(self.dtype)
 
         assert hidden_states is not None
-        # NOTE: If hidden states from the target model have to be concatenated,
-        # we expect that to happen outside the model definition. This helps us
-        # avoid data-dependent control flow and gives us better CUDA graph
-        # coverage.
+        # NOTE: Ideally, any concatenation of hidden states from the target
+        # model happens outside the model definition. This helps us avoid
+        # data-dependent control flow and gives us better CUDA graph coverage.
+        if self._eh_proj_before_attn:
+            input_embeds = self.enorm(inputs_embeds)
+            hidden_states = torch.cat([input_embeds, hidden_states], dim=-1)
+            hidden_states = self.eh_proj(hidden_states)
+
         residual = None
 
         if self.num_layers > 1:
             for layer in self.midlayer:

From 333b8b9153c358d68b424080aa3587553d7b6e20 Mon Sep 17 00:00:00 2001
From: mashkenazi
Date: Mon, 3 Nov 2025 17:58:48 +0000
Subject: [PATCH 2/2] added back next_layer_regular for multi-layer eagle

---
 .../_torch/models/modeling_speculative.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py
index 993b714ae25..9c04b5360e3 100755
--- a/tensorrt_llm/_torch/models/modeling_speculative.py
+++ b/tensorrt_llm/_torch/models/modeling_speculative.py
@@ -30,10 +30,10 @@ def __init__(
         self,
         model_config: ModelConfig[LlamaConfig],
         layer_idx: Optional[int] = None,
-        eh_proj_before_attn: bool = False,
+        next_layer_regular: bool = False,
     ):
         config = model_config.pretrained_config
-        self._eh_proj_before_attn = eh_proj_before_attn
+        self._next_layer_regular = next_layer_regular
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -52,7 +54,7 @@ def __init__(
             tp_size = 1
         # Override the QKV projection. The number of input features
         # is twice as big for EAGLE3 draft models.
- if not self._eh_proj_before_attn: + if not self._next_layer_regular: self.qkv_proj = Linear( 2 * self.hidden_size, tp_size * self.q_size + 2 * tp_size * self.kv_size, @@ -76,13 +76,14 @@ def __init__( self, model_config: LlamaConfig, layer_idx: int = 0, + is_first_layer: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: super().__init__() config = model_config.pretrained_config eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {} self.layer_idx = layer_idx - self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn", False) - self.self_attn = Eagle3Attention(model_config, layer_idx, self._eh_proj_before_attn) + self._next_layer_regular = (eagle_config.get("next_layer_regular", True) and not is_first_layer) or eagle_config.get("eh_proj_before_attn", False) + self.self_attn = Eagle3Attention(model_config, layer_idx, self._next_layer_regular) if config.model_type == "llama4_text": inter_size = config.intermediate_size_mlp @@ -98,7 +99,7 @@ def __init__( overridden_tp_size=1 if model_config.mapping.enable_attention_dp else None, ) - if not self._eh_proj_before_attn: + if not self._next_layer_regular: self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype) @@ -122,7 +123,7 @@ def forward( residual = hidden_states hidden_states = self.hidden_norm(hidden_states) - if not self._eh_proj_before_attn: + if not self._next_layer_regular: embeds = self.input_layernorm(embeds) hidden_states = torch.cat([embeds, hidden_states], dim=-1) @@ -179,7 +180,7 @@ def __init__( if self.num_layers > 1: self.midlayer = nn.ModuleList([ - Eagle3DecoderLayer(model_config, start_layer_idx + i) + Eagle3DecoderLayer(model_config, start_layer_idx + i, i == 0) for i in range(self.num_layers) ]) else:
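
With eh_proj_before_attn enabled, the draft model normalizes the token embeddings (enorm), concatenates them with the hidden states captured from the target model, and projects the result back to hidden_size (eh_proj) before the first decoder layer, so the decoder layers keep a regular-width qkv_proj and skip the per-layer input_layernorm/concat. The snippet below is a minimal standalone sketch of that data path, assuming plain PyTorch modules (nn.RMSNorm from PyTorch >= 2.4 standing in for TensorRT-LLM's RMSNorm) and illustrative hidden_size and shapes; it is not the classes touched by this patch.

    import torch
    from torch import nn


    class EhProjBeforeAttnSketch(nn.Module):
        """Embedding norm + concat + projection applied ahead of the decoder layers."""

        def __init__(self, hidden_size: int, eps: float = 1e-5, bias: bool = False):
            super().__init__()
            # Stand-in for the patch's enorm (TensorRT-LLM RMSNorm built with
            # config.rms_norm_eps).
            self.enorm = nn.RMSNorm(hidden_size, eps=eps)
            # Stand-in for eh_proj: maps the concatenated pair back to
            # hidden_size, so attention keeps a regular-width qkv_proj instead
            # of the 2x-wide EAGLE3 variant.
            self.eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=bias)

        def forward(self, inputs_embeds: torch.Tensor,
                    target_hidden_states: torch.Tensor) -> torch.Tensor:
            embeds = self.enorm(inputs_embeds)
            fused = torch.cat([embeds, target_hidden_states], dim=-1)
            return self.eh_proj(fused)


    # Illustrative shapes only.
    sketch = EhProjBeforeAttnSketch(hidden_size=4096)
    embeds = torch.randn(2, 8, 4096)         # draft token embeddings
    target_hidden = torch.randn(2, 8, 4096)  # hidden states from the target model
    fused_hidden = sketch(embeds, target_hidden)  # -> (2, 8, 4096)

Since the flag is fixed at construction time, the new branch in the draft model's forward depends only on the configuration, not on the data, which preserves the CUDA-graph-friendly property the comment describes.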
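
The option itself is read from eagle_config on the draft model's pretrained config, falling back to an empty dict when the attribute is absent. A hypothetical eagle_config enabling the new path could look like the sketch below; the keys and defaults are the ones queried by these patches, while the plumbing that places eagle_config on the pretrained config is assumed to exist as before.

    # Hypothetical eagle_config for an EAGLE3 draft model; only keys relevant
    # to these two patches are shown.
    eagle_config = {
        "eh_proj_before_attn": True,  # fuse embeddings and target hidden states before layer 0
        "eh_proj_bias": False,        # bias of the eh_proj Linear (default False)
        "next_layer_regular": True,   # layers after the first use the regular, non-concat path
    }

With eh_proj_before_attn left at its default of False, behavior is unchanged: the first decoder layer concatenates the normalized embeddings with the target hidden states itself, and its qkv_proj takes the 2x-wide input.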