Hey!
First, thank you for your contribution to the field as well as open-sourcing your code! Really appreciated!
I hope it's OK that I approach you here. I want to use D3PM for protein sequences (similar to what you did for language modeling), but I'm struggling to understand the following point in the reverse process:
In your paper you've mentioned:
In theory, I agree with this. But when it comes to implementation, it doesn't seem possible to compute the sum in the last line.
While the original D3PM paper uses a mean and log-scale parameterization to predict that distribution, as far as I understand, your code only considers the logits of
More specifically, I looked at:
MLMDiffusionTransformer.forward():
```python
sequence_output = self.encoder(embed, encoder_attention_mask=attn_mask)[0]
prediction_scores = self.cls(sequence_output)
out = {
    "logits": prediction_scores,
    "sequence_output": sequence_output,
    "embeds": token_embed,
}
```
and
MLMDiffusion.forward():
```python
corrupt_ids, corrupt_mask = (
    self.noise_schedule.corrupt(input_ids, t, corrupt_mask)
)
model_output = self.network(
    corrupt_ids,
    t,
    attn_mask,
)
logits = model_output['logits']
hiddens = model_output['sequence_output']
loss_fct = nn.CrossEntropyLoss(reduction='none')  # -100 index = padding token
nll = loss_fct(logits.view(-1, logits.shape[-1]), input_ids.view(-1))
```
Am I missing something? If not, how does it match the paper?
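To spell out what I think the quoted loss computes: `nn.CrossEntropyLoss(reduction='none')` applied to `logits` and `input_ids` gives, at each position, -log softmax(logits)[x_0], i.e. the negative log-likelihood of the clean token under the predicted x̃_0 distribution. In the lines above I see no marginalization over x̃_0 and no q(x_{t-1} | x_t, x̃_0) term. Below is a small pure-Python sketch of that per-token computation. The names `log_softmax`, `per_token_nll` and `IGNORE_INDEX` are my own, not from the repo, and I'm assuming PyTorch's default `ignore_index=-100`, for which `reduction='none'` returns 0 at ignored positions:

```python
import math
from typing import List, Sequence

# Assumption: mirrors PyTorch's default ignore_index for nn.CrossEntropyLoss.
IGNORE_INDEX = -100


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


def per_token_nll(logits: Sequence[Sequence[float]],
                  targets: Sequence[int]) -> List[float]:
    """Per-position -log p(x0 | x_t), like CrossEntropyLoss(reduction='none').

    Each row of `logits` scores the vocabulary at one position; `targets`
    holds the clean token ids x0. Positions whose target equals
    IGNORE_INDEX contribute a loss of 0.0.
    """
    losses = []
    for row, target in zip(logits, targets):
        if target == IGNORE_INDEX:
            losses.append(0.0)
        else:
            # Only the log-probability of the true clean token is used;
            # nothing is summed over alternative x0 values.
            losses.append(-log_softmax(row)[target])
    return losses


if __name__ == "__main__":
    # Two positions, vocabulary of size 3; the second position is padding.
    example_logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
    example_targets = [0, IGNORE_INDEX]
    print(per_token_nll(example_logits, example_targets))
```

This is the same quantity a plain masked-LM objective uses, which is why I can't see where the sum over x̃_0 from the paper enters.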
Thanks a lot!
Sagi
