diff --git a/src/paddlefleet/transformer/moe/moe_layer.py b/src/paddlefleet/transformer/moe/moe_layer.py index 646041a06..4116ca857 100644 --- a/src/paddlefleet/transformer/moe/moe_layer.py +++ b/src/paddlefleet/transformer/moe/moe_layer.py @@ -645,7 +645,7 @@ def compute_combine(self, hidden_states, async_finish=False): def aux_loss_compute(self, args): hidden_states, aux_loss, residuals = args - if self.training and self.router_aux_loss_coef: + if self.training and self.router_aux_loss_coef > 0.0: aux_loss = aux_loss * self.router_aux_loss_coef output = AddAuxiliaryLoss.apply(hidden_states, aux_loss) @@ -718,7 +718,7 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: reshaped_input, topk_indices, topk_weights ) - if self.training and self.router_aux_loss_coef: + if self.training and self.router_aux_loss_coef > 0.0: aux_loss = aux_loss * self.router_aux_loss_coef output = AddAuxiliaryLoss.apply(output, aux_loss) diff --git a/src/paddlefleet/transformer/moe/moe_router.py b/src/paddlefleet/transformer/moe/moe_router.py index 211496c72..a6ed204e9 100644 --- a/src/paddlefleet/transformer/moe/moe_router.py +++ b/src/paddlefleet/transformer/moe/moe_router.py @@ -603,7 +603,7 @@ def forward(self, input): self.expert_usage += exp_counts # aux_loss - if self.config.router_aux_loss_coef: + if self.config.router_aux_loss_coef > 0.0: if self.routing_type == "seq_aux_loss": l_aux = self._cal_seq_aux_loss( gates_ori,