Hi Ed,
I saw that your code initializes the weights from a truncated normal distribution. When I ran it, this seemed to feed large values into exp in the medical-code-loss part, which resulted in inf in the loss and NaN gradients. With these initial weights the loss is also high in general, around several hundred, and the L2 loss in particular is in the tens of thousands. I then changed the weight initialization to uniform over a small interval, [-0.1, 0.1], which seems to produce a loss of reasonable magnitude (under 10). Do you remember whether you tried other weight initializations, and how they affected the results?
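To illustrate the kind of overflow I mean, here is a minimal NumPy sketch, not the repository's actual code. The sizes, the plain normal with a large standard deviation (standing in for a wide initialization), and the function names are all my own assumptions. It compares a softmax-style log-probability computed with a direct exp against one that subtracts the row maximum first.

```python
import numpy as np

def naive_log_softmax(scores):
    # Direct exp: overflows to inf once a score exceeds roughly 709 in float64.
    e = np.exp(scores)
    return np.log(e / e.sum(axis=1, keepdims=True))

def stable_log_softmax(scores):
    # Subtract the row max before exp so the largest exponent is 0.
    shifted = scores - scores.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

rng = np.random.default_rng(0)
n_codes, emb_dim = 50, 200  # hypothetical sizes, chosen only for the demo

# Hypothetical wide init vs. the uniform [-0.1, 0.1] init described above.
inits = {
    "wide normal": rng.normal(0.0, 3.0, size=(n_codes, emb_dim)),
    "small uniform": rng.uniform(-0.1, 0.1, size=(n_codes, emb_dim)),
}

results = {}
for name, W in inits.items():
    scores = W @ W.T  # pairwise dot products between code vectors
    with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
        naive_ok = bool(np.isfinite(naive_log_softmax(scores)).all())
    stable_ok = bool(np.isfinite(stable_log_softmax(scores)).all())
    results[name] = (naive_ok, stable_ok)
    print(f"{name}: naive finite={naive_ok}, stable finite={stable_ok}")
```

In this sketch the wide initialization makes the naive version non-finite while the max-subtracted version stays finite, so a smaller initialization and a stabilized exp are two separate ways of avoiding the inf.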
Another question: in the paper, the loss is averaged over T. Is T the number of visits in the batch or the number of visits per patient? In your code, ivec and jvec seem to be generated for the whole batch. So does the medical-code-loss calculation average over all visits in a batch, rather than averaging per patient and then over all patients in the batch?
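To make the difference between the two averaging schemes concrete, here is a small sketch with made-up per-visit losses (the numbers and variable names are hypothetical, not taken from the code or the paper). The two schemes only agree when every patient has the same number of visits.

```python
import numpy as np

# Hypothetical per-visit losses for a batch of three patients
# with different numbers of visits.
visit_losses = [
    np.array([1.0, 3.0]),            # patient 0: 2 visits
    np.array([2.0, 2.0, 2.0, 2.0]),  # patient 1: 4 visits
    np.array([6.0]),                 # patient 2: 1 visit
]

# (a) Pool every visit in the batch and take one mean.
#     Patients with more visits get more weight.
batch_visit_mean = np.concatenate(visit_losses).mean()

# (b) Average within each patient first, then across patients.
#     Every patient gets equal weight.
per_patient_mean = np.mean([losses.mean() for losses in visit_losses])

print(batch_visit_mean)   # 18 / 7, about 2.571
print(per_patient_mean)   # 10 / 3, about 3.333
```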
Thanks!