Description
import torch
from allennlp.nn.util import tiny_value_of_dtype  # helper that returns a tiny positive constant for the given dtype


def masked_log_softmax(vector: torch.Tensor, mask: torch.BoolTensor, dim: int = -1) -> torch.Tensor:
    if mask is not None:
        while mask.dim() < vector.dim():
            mask = mask.unsqueeze(1)
        # vector + mask.log() is an easy way to zero out masked elements in logspace, but it
        # results in nans when the whole vector is masked. We need a very small value instead of a
        # zero in the mask for these cases.
        vector = vector + (mask + tiny_value_of_dtype(vector.dtype)).log()
    return torch.nn.functional.log_softmax(vector, dim=dim)
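For illustration, a minimal usage sketch (assuming the definition above and that tiny_value_of_dtype is available; the tensor values are made up): masked positions get a very large negative value added to their scores, and a fully masked row still yields finite log-probabilities instead of nan.

import torch

scores = torch.tensor([[2.0, -1.0, 0.5],
                       [1.0,  3.0, 0.0]])
mask = torch.tensor([[True, True, False],
                     [False, False, False]])  # second row is fully masked

log_probs = masked_log_softmax(scores, mask)
print(log_probs)        # masked entries are strongly negative; no nan even in the fully masked row
print(log_probs.exp())  # exponentiating gives probabilities that are near zero at masked positions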
Sorry, I don't quite understand: if the vector (vector = type_linear_output @ span_linear_output) is only masked before being fed into the log_softmax function, how is the numerator exp(sim(s_{i,j}, e_k)) in equation (5) guaranteed to be positive?
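For concreteness, a rough sketch of the computation the question describes, reusing masked_log_softmax from above; the names type_linear_output and span_linear_output come from the question, but the shapes, the orientation of the matrix product, and the all-True mask are hypothetical placeholders rather than the actual model code.

import torch

# Hypothetical sizes: 4 candidate spans, 6 entity types, hidden size 16.
span_linear_output = torch.randn(4, 16)   # stand-in for the s_{i,j} span representations
type_linear_output = torch.randn(6, 16)   # stand-in for the e_k type representations

# sim(s_{i,j}, e_k) as a dot product; these raw scores can be negative.
vector = span_linear_output @ type_linear_output.T   # shape (4, 6)
mask = torch.ones_like(vector, dtype=torch.bool)      # placeholder mask (nothing masked here)

log_probs = masked_log_softmax(vector, mask, dim=-1)
# exp(log_probs) recovers softmax probabilities, which are strictly positive
# even when the raw similarity scores themselves are negative.
print(log_probs.exp())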