BenZickel/torch_truncnorm

 
 

torch_truncnorm

Truncated Normal distribution in PyTorch. The module provides:

  • TruncatedStandardNormal class - a truncated Normal distribution whose parent Normal has zero mean and unit variance, parameterized by the cut-off range [a, b] (similar to scipy.stats.truncnorm);
  • TruncatedNormal class - a wrapper that adds the loc and scale parameters of the parent Normal distribution;
  • Differentiability with respect to the parameters of the distribution;
  • Batching support.
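
To make the differentiability and batching points concrete, here is a minimal sketch of the log-density of a standard Normal truncated to [a, b], built only on torch.distributions.Normal. The helper name truncated_std_normal_log_prob is hypothetical and is not this module's API; it only illustrates how gradients can flow back to the cut-off parameters.

```python
import math

import torch
from torch.distributions import Normal


def truncated_std_normal_log_prob(x, a, b):
    """Log-density of N(0, 1) restricted to [a, b].

    Hypothetical helper for illustration, not the module's API.
    """
    std_normal = Normal(torch.zeros_like(a), torch.ones_like(a))
    # Normalizing constant Z = Phi(b) - Phi(a), where Phi is the standard Normal CDF.
    log_z = torch.log(std_normal.cdf(b) - std_normal.cdf(a))
    log_p = std_normal.log_prob(x) - log_z
    # The density is zero (log-density -inf) outside the cut-off range.
    inside = (x >= a) & (x <= b)
    return torch.where(inside, log_p, torch.full_like(log_p, -math.inf))


# A batch of two distributions with different cut-off ranges.
a = torch.tensor([-1.0, 0.0], requires_grad=True)
b = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.tensor([0.5, 1.0])

log_p = truncated_std_normal_log_prob(x, a, b)
log_p.sum().backward()  # populates a.grad and b.grad
```

Narrowing the range raises the density inside it, so under these assumptions the gradient with respect to a is positive and the gradient with respect to b is negative.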

Why

I needed to differentiate with respect to the parameters of the distribution and found that the truncated normal distribution is not bundled in torch.distributions as of 1.6.0.

Known issues

icdf is numerically unstable; as a consequence, so is rsample. This issue is also seen in torch.distributions.normal.Normal, so it is sort of normal (ba-dum-tss).
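
The following sketch illustrates one way this kind of instability can show up. It assumes a naive inverse-CDF sampler in float32, which may differ from how this module implements icdf and rsample; naive_truncated_sample is a hypothetical helper written only for this illustration.

```python
import torch
from torch.distributions import Normal


def naive_truncated_sample(a, b, n, dtype=torch.float32):
    """Inverse-CDF sampling from N(0, 1) truncated to [a, b].

    Hypothetical helper for illustration, not the module's API.
    """
    std_normal = Normal(torch.tensor(0.0, dtype=dtype), torch.tensor(1.0, dtype=dtype))
    cdf_a = std_normal.cdf(torch.tensor(a, dtype=dtype))
    cdf_b = std_normal.cdf(torch.tensor(b, dtype=dtype))
    # Map uniform noise into [Phi(a), Phi(b)) and invert the parent CDF.
    u = torch.rand(n, dtype=dtype)
    p = cdf_a + u * (cdf_b - cdf_a)
    return std_normal.icdf(p)


# Around the mode the CDF values are well separated and sampling behaves.
central = naive_truncated_sample(-1.0, 1.0, 1000)

# Far in the tail, Phi(6) and Phi(7) both round to 1.0 in float32,
# so the parent icdf is evaluated at 1.0 and returns inf.
tail = naive_truncated_sample(6.0, 7.0, 1000)

print(torch.isfinite(central).all().item(), torch.isfinite(tail).all().item())
```

The same rounding affects torch.distributions.normal.Normal directly: once a probability is indistinguishable from 0 or 1 at the working precision, its icdf can no longer recover a finite value.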

Tests

CUDA_VISIBLE_DEVICES=0 python -m tests.test

Links

https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
