Truncated Normal distribution in PyTorch. The module provides:
- TruncatedStandardNormal class - the parent Normal distribution has zero mean and unit variance; parameterized by the cut-off range [a, b] (similar to scipy.stats.truncnorm);
- TruncatedNormal class - a wrapper with extra loc and scale parameters of the parent Normal distribution;
- Differentiability with respect to the parameters of the distribution;
- Batching support.
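The distribution behind both classes follows the standard truncated-Normal formulas. Below is a minimal pure-Python sketch of them, not this module's code. TruncNormalSketch is a hypothetical name, and the sketch uses the Python standard library instead of PyTorch. It assumes a and b are given in the original units and standardized as (a - loc) / scale; scipy.stats.truncnorm, by contrast, takes already-standardized bounds. With alpha and beta the standardized bounds, phi and Phi the standard Normal density and CDF, and Z = Phi(beta) - Phi(alpha), the pdf is phi / (scale * Z) on the cut-off range, and the CDF and inverse CDF rescale Phi onto [Phi(alpha), Phi(beta)].

```python
import math
from statistics import NormalDist

_STD = NormalDist()  # parent standard Normal N(0, 1); used for its inverse CDF


def _phi(x: float) -> float:
    """Standard Normal density."""
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


def _Phi(x: float) -> float:
    """Standard Normal CDF, written with erfc to keep precision in the lower tail."""
    return 0.5 * math.erfc(-x / math.sqrt(2.0))


class TruncNormalSketch:
    """Hypothetical helper: Normal(loc, scale) restricted to [a, b]."""

    def __init__(self, loc: float, scale: float, a: float, b: float):
        if not (scale > 0.0 and a < b):
            raise ValueError("need scale > 0 and a < b")
        self.loc, self.scale = loc, scale
        # Assumption: a and b are in the original units; standardize them here.
        self.alpha = (a - loc) / scale
        self.beta = (b - loc) / scale
        # Probability mass of the parent Normal that survives the truncation.
        self.Z = _Phi(self.beta) - _Phi(self.alpha)

    def _standardize(self, x: float) -> float:
        return (x - self.loc) / self.scale

    def pdf(self, x: float) -> float:
        xi = self._standardize(x)
        if xi < self.alpha or xi > self.beta:
            return 0.0
        return _phi(xi) / (self.scale * self.Z)

    def cdf(self, x: float) -> float:
        # Clamp to [alpha, beta] so the CDF is 0 below a and 1 above b.
        xi = min(max(self._standardize(x), self.alpha), self.beta)
        return (_Phi(xi) - _Phi(self.alpha)) / self.Z

    def icdf(self, p: float) -> float:
        # Inverse transform: rescale p onto [Phi(alpha), Phi(beta)], invert the
        # parent CDF, then undo the standardization.
        q = _Phi(self.alpha) + p * self.Z
        return self.loc + self.scale * _STD.inv_cdf(q)

    def mean(self) -> float:
        return self.loc + self.scale * (_phi(self.alpha) - _phi(self.beta)) / self.Z


d = TruncNormalSketch(loc=1.0, scale=2.0, a=0.0, b=3.0)
print(round(d.cdf(d.icdf(0.3)), 6))  # 0.3
```

A sample is icdf applied to a Uniform(0, 1) draw, which is why rsample inherits the behaviour of icdf. Each step is a composition of differentiable operations (including the inverse CDF, expressible through erfinv). When the same formulas are written with PyTorch tensors, autograd can therefore supply gradients with respect to loc, scale, a and b.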
I just needed differentiation with respect to the parameters of the distribution and found out that the truncated normal
distribution is not bundled in torch.distributions as of 1.6.0.
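icdf is numerically unstable: precision degrades when the underlying Normal CDF values approach 0 or 1, for example when
the cut-off range lies far in a tail of the parent Normal. As a consequence, so is rsample. This issue is also seen in
torch.distributions.normal.Normal, so it is sort of normal (ba-dum-tss).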
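The sketch below shows one generic way to see the instability and a common workaround. It is plain Python (standard library only), not a description of this module's code, and naive_icdf and mirrored_icdf are hypothetical names. The textbook inverse CDF first maps p onto [Phi(alpha), Phi(beta)]. When both bounds sit far in the upper tail, those two values round to exactly 1.0 and the interval collapses. Reflecting the problem to [-beta, -alpha] moves it into the lower tail, where an erfc-based CDF keeps relative precision. PyTorch's default float32 runs out of precision much sooner than the float64 used here.

```python
import math
from statistics import NormalDist, StatisticsError

_STD = NormalDist()  # parent standard Normal N(0, 1)


def Phi(x: float) -> float:
    # Standard Normal CDF via erfc: accurate far into the lower tail, but it
    # rounds to exactly 1.0 far in the upper tail (Phi(10.0) == 1.0 in float64).
    return 0.5 * math.erfc(-x / math.sqrt(2.0))


def naive_icdf(p: float, alpha: float, beta: float) -> float:
    """Textbook inverse CDF of the standard Normal truncated to [alpha, beta]."""
    q = Phi(alpha) + p * (Phi(beta) - Phi(alpha))
    return _STD.inv_cdf(q)  # raises StatisticsError if q rounded to 0.0 or 1.0


def mirrored_icdf(p: float, alpha: float, beta: float) -> float:
    """Same quantity, computed in the lower tail when the interval lies above 0."""
    if alpha > 0:
        # X ~ TN[alpha, beta] implies -X ~ TN[-beta, -alpha], and
        # P(X <= x) = p  <=>  P(-X <= -x) = 1 - p.
        return -mirrored_icdf(1.0 - p, -beta, -alpha)
    return naive_icdf(p, alpha, beta)


alpha, beta = 10.0, 12.0
print(Phi(alpha), Phi(beta))  # 1.0 1.0
try:
    naive_icdf(0.5, alpha, beta)
except StatisticsError as err:
    print("naive icdf failed:", err)
x = mirrored_icdf(0.5, alpha, beta)
print(alpha < x < beta)  # True
```

Reflection only helps when the whole range sits on one side of the mean; the general fix is to avoid forming CDF values that round to 1.0 at all.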
Run the tests:
CUDA_VISIBLE_DEVICES=0 python -m tests.test
Reference: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf