Problem
The log-probability calculation applies the wrong coefficient to the log-sigma normalization term:
log_prob_matrix = ... - 0.5 * log_sigma.sum(dim=-1)
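Taking log_sigma to be the log standard deviation (as the name suggests), the 0.5 factor is inconsistent with the Gaussian log-density. Per dimension, log N(x; mu, sigma) = -0.5 * ((x - mu) / sigma)^2 - log(sigma) - 0.5 * log(2 * pi), so the log(sigma) term has coefficient 1. A factor of 0.5 would be correct only if the tensor stored the log variance, since log(sigma) = 0.5 * log(sigma^2).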
Solution
Replace with:
log_prob_matrix = ... - log_sigma.sum(dim=-1)
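
The coefficient can be checked numerically. The sketch below is not the project's code: it is a plain-Python stand-in (lists instead of tensors) that assumes log_sigma holds the log standard deviation of a diagonal Gaussian. The names diag_gaussian_log_prob, reference_log_prob, and sigma_coeff are hypothetical and exist only for this comparison. The reference value comes from the standard library's statistics.NormalDist.

```python
import math
from statistics import NormalDist


def diag_gaussian_log_prob(x, mu, log_sigma, sigma_coeff=1.0):
    """Log-density of a diagonal Gaussian, summed over dimensions.

    Assumes log_sigma is the log STANDARD DEVIATION per dimension.
    sigma_coeff is a hypothetical knob: 1.0 is the proposed fix,
    0.5 reproduces the coefficient reported as wrong.
    """
    total = 0.0
    for xi, mi, lsi in zip(x, mu, log_sigma):
        z = (xi - mi) / math.exp(lsi)
        total += -0.5 * z * z - sigma_coeff * lsi - 0.5 * math.log(2 * math.pi)
    return total


def reference_log_prob(x, mu, log_sigma):
    """Same quantity computed from the standard library's normal pdf."""
    return sum(
        math.log(NormalDist(mi, math.exp(lsi)).pdf(xi))
        for xi, mi, lsi in zip(x, mu, log_sigma)
    )


# Arbitrary illustrative values, not taken from the original code.
x = [0.3, -1.2, 2.0]
mu = [0.0, -1.0, 1.5]
log_sigma = [0.4, -0.7, 1.1]

fixed = diag_gaussian_log_prob(x, mu, log_sigma, sigma_coeff=1.0)
buggy = diag_gaussian_log_prob(x, mu, log_sigma, sigma_coeff=0.5)
reference = reference_log_prob(x, mu, log_sigma)
```

Under these assumptions, fixed should agree with reference, while buggy differs from it by 0.5 * sum(log_sigma). If log_sigma in the real code actually stores the log variance, the comparison reverses and the original 0.5 factor would be the correct one.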