def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score,
                    action_mask):
    kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)
    rewards = kl_divergence_estimate
    start = prompts.shape[1] - 1
    ends = start + action_mask[:, start:].sum(1) + 1
    reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
                              self.clip_reward_value)
    batch_size = log_probs.shape[0]
    for j in range(batch_size):
        rewards[j, start:ends[j]][-1] += reward_clip[j]
    return rewards
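To make the indexing concrete, here is a minimal, self-contained sketch (not the original trainer code) that replays the indexing above on the example given further down in this issue; the KL term is replaced by zeros and the clipped reward by a hard-coded stand-in:

import torch

# stand-alone reproduction of the indexing in compute_rewards
prompts = torch.tensor([[101, 102]])
seq = torch.tensor([[101, 102, 103, 104, 105, 0, 0, 0]])
attention_mask = (seq != 0).long()
action_mask = attention_mask[:, 1:]                     # shape (1, 7), aligned with log_probs

rewards = torch.zeros(action_mask.shape, dtype=torch.float)  # stand-in for the KL penalty term
reward_clip = torch.tensor([2.5])                       # stand-in for the clipped reward_score

start = prompts.shape[1] - 1                            # 1
ends = start + action_mask[:, start:].sum(1) + 1        # tensor([5])
for j in range(seq.shape[0]):
    rewards[j, start:ends[j]][-1] += reward_clip[j]

print(rewards)
# tensor([[0., 0., 0., 0., 2.5000, 0., 0.]])  -> the reward lands at index 4,
# the step whose "action" is predicting the first padding token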
rewards[:, ends[j]-1] (and therefore advantages[:, ends[j]-1]) will be masked out in actor_loss_fn:
def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
    ## policy gradient loss
    log_ratio = (logprobs - old_logprobs) * mask
    ratio = torch.exp(log_ratio)
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                         1.0 + self.cliprange)
    pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
    return pg_loss
This is because mask is action_mask[:, start:] = attention_mask[:, 1:][:, start:], and rewards[:, ends[j]-1] is the reward for the last non-padding token's action, which is predicting a padding token, so its action_mask entry is 0. The reward_score should instead be given to the penultimate non-padding token, rewarding its action of predicting the last non-padding token (see the sketch after the example below).
For example:
prompts:
tensor([[101, 102]])
start = 1, ends = 5
seq:
tensor([[101, 102, 103, 104, 105, 0, 0, 0]])
attention_mask:
tensor([[1, 1, 1, 1, 1, 0, 0, 0]])
action_mask:
tensor([[1, 1, 1, 1, 0, 0, 0]])
mask = action_mask[:,start:] = tensor([[1, 1, 1, 0, 0, 0]])
reward_score = tensor([[2.5]])
old_rewards:
tensor([[ 8.1432e-03, 7.7722e-04, -4.7493e-05, 3.8694e-03, 2.5037e+00,
0.0000e+00, 0.0000e+00]])
old_values:
tensor([[0.5000, 0.8000, 1.2000, 1.5000, 1.8000, 0.0000, 0.0000]])
advantages:
tensor([[1.4950, 1.1762, 0.9477, 0.7037, 0.0000, 0.0000]])
log_ratio = (logprobs - old_logprobs) * mask:
tensor([[ 0.0078, 0.0045, -0.0020, -0.0000, 0.0000, 0.0000]])
ratio:
tensor([[1.0079, 1.0045, 0.9980, 1.0000, 1.0000, 1.0000]])
pg_loss1 = -advantages * ratio:
tensor([[-1.5068, -1.1815, -0.9458, -0.7037, -0.0000, -0.0000]])
pg_loss2:
tensor([[-1.5068, -1.1815, -0.9458, -0.7037, -0.0000, -0.0000]])
torch.max(pg_loss1, pg_loss2) * mask:
tensor([[-1.5068, -1.1815, -0.9458, -0.0000, -0.0000, -0.0000]])
As you can see, the per-token loss for token 105 is 0, which means the most important reward, 2.5037e+00, is never backpropagated.
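A small reproduction of this masking step, using the tensors from the walk-through above (cliprange = 0.2 is an assumption, not taken from the original config):

import torch

mask = torch.tensor([[1., 1., 1., 0., 0., 0.]])        # action_mask[:, start:]
advantages = torch.tensor([[1.4950, 1.1762, 0.9477, 0.7037, 0.0000, 0.0000]])
ratio = torch.tensor([[1.0079, 1.0045, 0.9980, 1.0000, 1.0000, 1.0000]])
cliprange = 0.2                                        # assumed clip range

pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
per_token_loss = torch.max(pg_loss1, pg_loss2) * mask
print(per_token_loss)
# tensor([[-1.5068, -1.1815, -0.9458, -0.0000, -0.0000, -0.0000]])
# the entry at index 3, the position carrying the 2.5037 reward, is zeroed by the mask
print(torch.sum(per_token_loss) / mask.sum())          # so the 2.5 never reaches pg_loss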
The correct update is rewards[j, start:ends[j] - 1][-1] += reward_clip[j].
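With that one-character change, the same sketch as above places the clipped reward at the step that predicts token 105, a position the mask keeps (again a minimal reproduction, not a patch to the original file):

import torch

prompts = torch.tensor([[101, 102]])
seq = torch.tensor([[101, 102, 103, 104, 105, 0, 0, 0]])
action_mask = (seq != 0).long()[:, 1:]
rewards = torch.zeros(1, 7)
reward_clip = torch.tensor([2.5])

start = prompts.shape[1] - 1
ends = start + action_mask[:, start:].sum(1) + 1
for j in range(seq.shape[0]):
    rewards[j, start:ends[j] - 1][-1] += reward_clip[j]   # note the extra "- 1"

print(rewards)
# tensor([[0., 0., 0., 2.5000, 0., 0., 0.]])  -> the reward now sits at the step that
# predicts token 105, where mask = action_mask[:, start:] is 1, so it is not zeroed out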