diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py
index bafb3c05e9d..9c04b5360e3 100755
--- a/tensorrt_llm/_torch/models/modeling_speculative.py
+++ b/tensorrt_llm/_torch/models/modeling_speculative.py
@@ -30,8 +30,10 @@ def __init__(
         self,
         model_config: ModelConfig[LlamaConfig],
         layer_idx: Optional[int] = None,
+        next_layer_regular: bool = False,
     ):
         config = model_config.pretrained_config
+        self._next_layer_regular = next_layer_regular
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -78,8 +80,9 @@ def __init__(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         super().__init__()
         config = model_config.pretrained_config
+        eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {}
         self.layer_idx = layer_idx
-        self._next_layer_regular = config.eagle_config.get("next_layer_regular", True) and not is_first_layer
+        self._next_layer_regular = (eagle_config.get("next_layer_regular", True) and not is_first_layer) or eagle_config.get("eh_proj_before_attn", False)
         self.self_attn = Eagle3Attention(model_config, layer_idx, self._next_layer_regular)
 
         if config.model_type == "llama4_text":
@@ -153,18 +156,20 @@ def __init__(
         super().__init__(model_config)
 
         config = model_config.pretrained_config
+        eagle_config = config.eagle_config if hasattr(config, "eagle_config") else {}
         self.spec_config = model_config.spec_config
         self.dtype = config.torch_dtype
         self.hidden_size = config.hidden_size
         self.mapping = model_config.mapping
         self.num_layers = model_config.pretrained_config.num_hidden_layers
+        self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn", False)
 
         if hasattr(config, "target_hidden_size"):
             self.hidden_size_in = config.target_hidden_size
         else:
             self.hidden_size_in = config.hidden_size
-        self._return_hidden_post_norm = config.eagle_config.get("return_hidden_post_norm", False)
+        self._return_hidden_post_norm = eagle_config.get("return_hidden_post_norm", False)
 
         if self.spec_config.num_capture_layers > 1:
             self.fc = Linear(self.hidden_size_in *
@@ -189,6 +194,14 @@ def __init__(
             self.d2t = nn.Parameter(torch.empty((config.draft_vocab_size, ),
                                                 dtype=torch.int32),
                                     requires_grad=False)
+        if self._eh_proj_before_attn:
+            self.enorm = RMSNorm(hidden_size=config.hidden_size,
+                                 eps=config.rms_norm_eps,
+                                 dtype=config.torch_dtype)
+            self.eh_proj = nn.Linear(config.hidden_size * 2,
+                                     config.hidden_size,
+                                     bias=eagle_config.get("eh_proj_bias", False),
+                                     dtype=config.torch_dtype)
 
         if self.hidden_size_in != config.hidden_size:
             if model_config.mapping.enable_attention_dp:
@@ -230,11 +243,15 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids).to(self.dtype)
 
         assert hidden_states is not None
 
-        # NOTE: If hidden states from the target model have to be concatenated,
-        # we expect that to happen outside the model definition. This helps us
-        # avoid data-dependent control flow and gives us better CUDA graph
-        # coverage.
+        # NOTE: Ideally, concatenating hidden states from the target model
+        # happens outside the model definition. This helps us avoid
+        # data-dependent control flow and gives us better CUDA graph coverage.
+        if self._eh_proj_before_attn:
+            input_embeds = self.enorm(inputs_embeds)
+            hidden_states = torch.cat([input_embeds, hidden_states], dim=-1)
+            hidden_states = self.eh_proj(hidden_states)
+
         residual = None
         if self.num_layers > 1:
             for layer in self.midlayer:
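
A minimal standalone sketch of the eh_proj_before_attn path this diff introduces: RMS-normalize the token embeddings, concatenate them with the target-model hidden states along the feature dimension, and project 2*hidden_size back down to hidden_size before the first attention layer. This is an illustration only, not the file's actual API: torch.nn.RMSNorm (PyTorch >= 2.4) stands in for the repo's RMSNorm, and the module name, shapes, and defaults are assumptions.

    import torch
    import torch.nn as nn

    class EhProjBeforeAttn(nn.Module):
        """Illustrative stand-in for the enorm + eh_proj pair added above."""

        def __init__(self, hidden_size: int, rms_norm_eps: float = 1e-5,
                     bias: bool = False):
            super().__init__()
            # Norm over the embedding features (repo uses its own RMSNorm).
            self.enorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
            # Projects the concatenated [embeds, hidden] back to hidden_size,
            # mirroring nn.Linear(hidden_size * 2, hidden_size) in the diff.
            self.eh_proj = nn.Linear(hidden_size * 2, hidden_size, bias=bias)

        def forward(self, inputs_embeds: torch.Tensor,
                    hidden_states: torch.Tensor) -> torch.Tensor:
            # Normalize embeddings, fuse with target hidden states, then
            # project 2*H -> H so downstream layers see the usual width.
            e = self.enorm(inputs_embeds)
            return self.eh_proj(torch.cat([e, hidden_states], dim=-1))

    # Usage: both inputs are [num_tokens, hidden_size]; output matches.
    m = EhProjBeforeAttn(hidden_size=4096)
    x = torch.randn(8, 4096)
    h = torch.randn(8, 4096)
    out = m(x, h)  # shape [8, 4096]

Because the fusion now happens inside forward() when the flag is set, the diff also relaxes the old comment about always concatenating outside the model, and forces _next_layer_regular on the first layer so its attention consumes the projected hidden states directly.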