
Understanding the discrete reverse process #2

@SagiPolaczek

Description


Hey!

First, thank you for your contribution to the field as well as open-sourcing your code! Really appreciated!
I hope it's OK that I approach you here: I want to use D3PM for protein sequences (similar to what you did with an LM), but I'm struggling to understand the following point in the reverse process.

In your paper you've mentioned:

*(Screenshot of the reverse-process equation from the paper, whose last line contains a sum over $\tilde{x}_0$.)*

Theoretically, I agree with this. But when it comes to implementation, the sum in the last line cannot be computed directly.
While the original D3PM paper uses the mean & log scale to predict that distribution, as far as I understand, your code only considers the logits of $p_\theta(x_0 \mid x_t)$.
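
To make the question concrete, here is a minimal sketch (my own, not taken from your repo; the function and argument names are hypothetical) of how I would expect that sum to be computed in closed form from the $p_\theta(\tilde{x}_0 \mid x_t)$ logits, assuming row-stochastic one-step matrices $Q_t$ and cumulative products $\bar{Q}_t$ as in D3PM. Writing the parameterization as $p_\theta(x_{t-1} \mid x_t) \propto \sum_{\tilde{x}_0} q(x_{t-1}, x_t \mid \tilde{x}_0)\, p_\theta(\tilde{x}_0 \mid x_t)$, the sum over $\tilde{x}_0$ collapses into a matrix product:

```python
import torch
import torch.nn.functional as F

def reverse_step_probs(logits_x0, x_t, Q_t, Qbar_tm1):
    """
    logits_x0: [B, L, K] logits of p_theta(x0_tilde | x_t)
    x_t:       [B, L]    current noisy tokens
    Q_t:       [K, K]    one-step matrix, Q_t[i, j] = q(x_t = j | x_{t-1} = i)
    Qbar_tm1:  [K, K]    cumulative matrix Q_1 @ ... @ Q_{t-1}
    returns:   [B, L, K] p_theta(x_{t-1} | x_t)
    """
    probs_x0 = F.softmax(logits_x0, dim=-1)           # p_theta(x0_tilde | x_t)
    # fact1[..., k] = q(x_t | x_{t-1} = k): the x_t-th column of Q_t
    fact1 = Q_t.T[x_t]                                 # [B, L, K]
    # fact2[..., k] = sum over x0_tilde of p_theta(x0_tilde | x_t) * q(x_{t-1} = k | x0_tilde)
    fact2 = probs_x0 @ Qbar_tm1                        # [B, L, K]
    unnorm = fact1 * fact2                             # the sum over x0_tilde, done in closed form
    return unnorm / unnorm.sum(dim=-1, keepdim=True)   # normalize over x_{t-1}
```

I couldn't find an equivalent computation (or the $Q_t$ matrices) in the code, which is what confuses me.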

More specifically, I looked at:
`MLMDiffusionTransformer.forward()`:

        # encoder returns a tuple; [0] is the per-token hidden states
        sequence_output = self.encoder(embed, encoder_attention_mask=attn_mask)[0]
        # LM head: per-token logits over the vocabulary
        prediction_scores = self.cls(sequence_output)

        out = {
            "logits": prediction_scores,
            "sequence_output": sequence_output,
            "embeds": token_embed,
        }

and `MLMDiffusion.forward()`:

        # corrupt the clean tokens x_0 into x_t according to the noise schedule
        corrupt_ids, corrupt_mask = (
            self.noise_schedule.corrupt(input_ids, t, corrupt_mask)
        )

        model_output = self.network(
            corrupt_ids,
            t,
            attn_mask,
        )
        logits = model_output['logits']
        hiddens = model_output['sequence_output']

        # cross-entropy of the p_theta(x_0 | x_t) logits against the clean tokens x_0
        loss_fct = nn.CrossEntropyLoss(reduction='none')  # -100 index = padding token
        nll = loss_fct(logits.view(-1, logits.shape[-1]), input_ids.view(-1))

Am I missing something? If not, how does this match the paper?

Thanks a lot!
Sagi
