gluonts.torch.distributions.truncated_normal module

class gluonts.torch.distributions.truncated_normal.TruncatedNormal(loc: torch.Tensor, scale: torch.Tensor, min: Union[torch.Tensor, float] = -1.0, max: Union[torch.Tensor, float] = 1.0, upscale: Union[torch.Tensor, float] = 5.0, tanh_loc: bool = False)

Bases: torch.distributions.distribution.Distribution

Implements a Truncated Normal distribution with location scaling.

Location scaling keeps the location from drifting “too far” from 0. A location far from 0 leads to numerically unstable samples and poor gradient computation (e.g. gradient explosion). In practice, the location is computed as

\[\text{loc} = \tanh\left(\frac{\text{loc}}{\text{upscale}}\right) \cdot \text{upscale}.\]

This behaviour is controlled by the tanh_loc parameter (see below). Since tanh_loc defaults to False, location scaling is disabled unless it is explicitly enabled.
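A minimal sketch of the formula above, written as scalar standard-library Python rather than the class's tensor code (the helper name `scale_location` is hypothetical). It shows that the transform squashes any real location toward the interval (-upscale, upscale) while leaving small locations almost unchanged:

```python
import math


def scale_location(loc: float, upscale: float = 5.0) -> float:
    """Hypothetical helper: loc = tanh(loc / upscale) * upscale."""
    return math.tanh(loc / upscale) * upscale


# Near 0, tanh is close to the identity, so small locations barely move.
print(scale_location(0.1))    # close to 0.1
# Large locations saturate toward +/- upscale.
print(scale_location(20.0))   # a bit below 5.0
print(scale_location(-20.0))  # a bit above -5.0
```

Only locations whose magnitude is comparable to upscale are noticeably altered. In floating point, extremely large inputs make tanh round to exactly ±1, so the result can reach ±upscale but never exceeds it.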

Parameters
  • loc – normal distribution location parameter

  • scale – normal distribution sigma parameter (square root of the variance)

  • min – minimum value of the distribution. Default = -1.0

  • max – maximum value of the distribution. Default = 1.0

  • upscale – scaling factor used in the location-scaling formula above. Default = 5.0

  • tanh_loc – if True, the above formula is used to scale the location; otherwise the raw value is kept. Default = False

References

Notes

This implementation is strongly based on:

arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=1e-06)}
cdf(value)

Returns the cumulative distribution function (CDF) evaluated at value.

Parameters

value (Tensor) –

cdf_truncated_standard_normal(value)
property entropy
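For reference, the standard closed form of the truncated-normal CDF is F(x) = (Phi(xi) - Phi(alpha)) / Z, where xi = (x - loc) / scale, alpha = (min - loc) / scale, beta = (max - loc) / scale, Z = Phi(beta) - Phi(alpha), and Phi is the standard normal CDF. The scalar sketch below evaluates it with Python's `statistics.NormalDist`. The helper name `truncnorm_cdf` is hypothetical, and treating min/max as bounds in the same units as loc is an assumption. The method name cdf_truncated_standard_normal suggests the class evaluates the standardized form first, as this sketch does, but that is not confirmed by this page.

```python
from statistics import NormalDist

_STD = NormalDist()  # standard normal: mu = 0, sigma = 1


def truncnorm_cdf(x: float, loc: float, scale: float, low: float, high: float) -> float:
    """Hypothetical helper: CDF of N(loc, scale**2) truncated to [low, high]."""
    alpha = (low - loc) / scale
    beta = (high - loc) / scale
    z = _STD.cdf(beta) - _STD.cdf(alpha)  # probability mass kept by the truncation
    xi = (x - loc) / scale  # standardize, then use the standard-normal CDF
    # Clamp so points below low map to 0 and points above high map to 1.
    return min(max((_STD.cdf(xi) - _STD.cdf(alpha)) / z, 0.0), 1.0)


print(truncnorm_cdf(0.0, loc=0.0, scale=1.0, low=-1.0, high=1.0))  # ~0.5 by symmetry
```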

Returns the entropy of the distribution, batched over batch_shape.

Returns

Tensor of shape batch_shape.
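The differential entropy of a truncated normal has the standard closed form H = ln(sqrt(2*pi*e) * scale * Z) + (alpha*phi(alpha) - beta*phi(beta)) / (2Z), with alpha = (min - loc) / scale, beta = (max - loc) / scale, Z = Phi(beta) - Phi(alpha), and phi, Phi the standard normal density and CDF. A scalar sketch, assuming min/max are bounds in the same units as loc; the helper name `truncnorm_entropy` is hypothetical:

```python
import math
from statistics import NormalDist

_STD = NormalDist()  # standard normal: mu = 0, sigma = 1


def truncnorm_entropy(loc: float, scale: float, low: float, high: float) -> float:
    """Hypothetical helper: entropy (nats) of N(loc, scale**2) truncated to [low, high]."""
    alpha = (low - loc) / scale
    beta = (high - loc) / scale
    z = _STD.cdf(beta) - _STD.cdf(alpha)
    # Correction term relative to a normal renormalized by the kept mass Z.
    correction = (alpha * _STD.pdf(alpha) - beta * _STD.pdf(beta)) / (2.0 * z)
    return math.log(math.sqrt(2.0 * math.pi * math.e) * scale * z) + correction


print(truncnorm_entropy(0.0, 1.0, -1.0, 1.0))
```

With very wide bounds Z approaches 1 and the correction vanishes, so the result approaches the untruncated normal entropy 0.5 * ln(2*pi*e*scale^2).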

eps = 1e-06
has_rsample = True
icdf(value)

Returns the inverse cumulative distribution function (quantile function) evaluated at value.

Parameters

value (Tensor) –

icdf_truncated_standard_normal(value)
log_prob(value)
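Inverting the truncated-normal CDF gives the quantile function x = loc + scale * Phi^-1(Phi(alpha) + p * Z) for 0 < p < 1, where alpha = (min - loc) / scale, beta = (max - loc) / scale and Z = Phi(beta) - Phi(alpha). The scalar sketch below uses `statistics.NormalDist.inv_cdf`; the helper name `truncnorm_icdf` is hypothetical, and min/max are assumed to be in the units of loc.

```python
from statistics import NormalDist

_STD = NormalDist()  # standard normal: mu = 0, sigma = 1


def truncnorm_icdf(p: float, loc: float, scale: float, low: float, high: float) -> float:
    """Hypothetical helper: quantile of N(loc, scale**2) truncated to [low, high], 0 < p < 1."""
    cdf_alpha = _STD.cdf((low - loc) / scale)
    z = _STD.cdf((high - loc) / scale) - cdf_alpha
    # Map p linearly onto [Phi(alpha), Phi(beta)], invert the standard normal,
    # then undo the standardization.
    return loc + scale * _STD.inv_cdf(cdf_alpha + p * z)


print(truncnorm_icdf(0.5, loc=0.0, scale=1.0, low=-1.0, high=1.0))  # ~0.0 by symmetry
```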

Returns the log of the probability density function evaluated at value.

Parameters

value (Tensor) –

log_prob_truncated_standard_normal(value)
property mean
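Inside [min, max], the log-density is the normal log-density renormalized by the retained mass: log f(x) = log phi(xi) - log(scale) - log(Z), with xi = (x - loc) / scale and Z = Phi((max - loc) / scale) - Phi((min - loc) / scale). Outside the interval, it is -inf. A scalar sketch under the same assumption that min/max are in the units of loc; the helper name `truncnorm_log_prob` is hypothetical:

```python
import math
from statistics import NormalDist

_STD = NormalDist()  # standard normal: mu = 0, sigma = 1


def truncnorm_log_prob(x: float, loc: float, scale: float, low: float, high: float) -> float:
    """Hypothetical helper: log-density of N(loc, scale**2) truncated to [low, high]."""
    if not low <= x <= high:
        return -math.inf  # zero density outside the support
    z = _STD.cdf((high - loc) / scale) - _STD.cdf((low - loc) / scale)
    xi = (x - loc) / scale
    # log phi(xi) - log(scale) - log(Z)
    return -0.5 * xi * xi - 0.5 * math.log(2.0 * math.pi) - math.log(scale) - math.log(z)


print(truncnorm_log_prob(0.0, loc=0.3, scale=0.7, low=-1.0, high=1.0))
```

Numerically integrating exp(log_prob) over [min, max] should give 1, which is a quick check on the normalization.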

Returns the mean of the distribution.
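Note that loc and scale parameterize the underlying normal, not the truncated distribution. The truncated mean has the standard closed form loc + scale * (phi(alpha) - phi(beta)) / Z, where alpha = (min - loc) / scale, beta = (max - loc) / scale and Z = Phi(beta) - Phi(alpha). A scalar sketch, assuming min/max are in the units of loc; the helper name `truncnorm_mean` is hypothetical:

```python
from statistics import NormalDist

_STD = NormalDist()  # standard normal: mu = 0, sigma = 1


def truncnorm_mean(loc: float, scale: float, low: float, high: float) -> float:
    """Hypothetical helper: mean of N(loc, scale**2) truncated to [low, high]."""
    alpha = (low - loc) / scale
    beta = (high - loc) / scale
    z = _STD.cdf(beta) - _STD.cdf(alpha)
    # Truncation shifts the mean by scale * (phi(alpha) - phi(beta)) / Z.
    return loc + scale * (_STD.pdf(alpha) - _STD.pdf(beta)) / z


print(truncnorm_mean(0.3, 0.7, -1.0, 1.0))  # below 0.3
```

With loc = 0.3, scale = 0.7 and bounds [-1, 1], more mass is cut off above loc than below it, so the mean falls below 0.3.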

rsample(sample_shape=None)

Generates a reparameterized sample of shape sample_shape, or a sample_shape-shaped batch of reparameterized samples if the distribution parameters are batched.
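A common way to draw reparameterized samples from a truncated normal is inverse-transform sampling. You draw u ~ Uniform(0, 1) and return the quantile at u. Given u, the sample is then a deterministic function of loc and scale, and gradients can flow through it. This page does not state that rsample works exactly this way. The stdlib sketch below (hypothetical helper `truncnorm_sample`, min/max assumed in the units of loc) shows only the transform, without autograd:

```python
import random
from statistics import NormalDist

_STD = NormalDist()  # standard normal: mu = 0, sigma = 1


def truncnorm_sample(loc: float, scale: float, low: float, high: float,
                     n: int, rng: random.Random) -> list:
    """Hypothetical helper: n inverse-transform samples of N(loc, scale**2) truncated to [low, high]."""
    cdf_alpha = _STD.cdf((low - loc) / scale)
    z = _STD.cdf((high - loc) / scale) - cdf_alpha
    samples = []
    for _ in range(n):
        # Keep u strictly inside (0, 1) so inv_cdf never receives 0 or 1.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        # Given u, x depends smoothly on loc and scale: the reparameterization.
        samples.append(loc + scale * _STD.inv_cdf(cdf_alpha + u * z))
    return samples


xs = truncnorm_sample(0.0, 1.0, -1.0, 1.0, n=5, rng=random.Random(0))
print(xs)
```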

property support

Returns a Constraint object representing this distribution’s support.

property variance

Returns the variance of the distribution.
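The truncated variance has the standard closed form scale^2 * [1 + (alpha*phi(alpha) - beta*phi(beta)) / Z - ((phi(alpha) - phi(beta)) / Z)^2], where alpha = (min - loc) / scale, beta = (max - loc) / scale and Z = Phi(beta) - Phi(alpha). A scalar sketch, assuming min/max are in the units of loc; the helper name `truncnorm_variance` is hypothetical:

```python
from statistics import NormalDist

_STD = NormalDist()  # standard normal: mu = 0, sigma = 1


def truncnorm_variance(loc: float, scale: float, low: float, high: float) -> float:
    """Hypothetical helper: variance of N(loc, scale**2) truncated to [low, high]."""
    alpha = (low - loc) / scale
    beta = (high - loc) / scale
    z = _STD.cdf(beta) - _STD.cdf(alpha)
    pa, pb = _STD.pdf(alpha), _STD.pdf(beta)
    # Both truncation terms shrink the variance relative to scale**2.
    return scale ** 2 * (1.0 + (alpha * pa - beta * pb) / z - ((pa - pb) / z) ** 2)


print(truncnorm_variance(0.0, 1.0, -1.0, 1.0))
```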

class gluonts.torch.distributions.truncated_normal.TruncatedNormalOutput(min: float = -1.0, max: float = 1.0, upscale: float = 5.0, tanh_loc: bool = False)

Bases: gluonts.torch.distributions.distribution_output.DistributionOutput

distr_cls

alias of gluonts.torch.distributions.truncated_normal.TruncatedNormal

distribution(distr_args, loc: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None) → torch.distributions.distribution.Distribution

Construct the associated distribution, given the collection of constructor arguments and, optionally, loc and scale tensors.

Parameters
  • distr_args – Constructor arguments for the underlying Distribution type.

  • loc – Optional tensor, of the same shape as the batch_shape+event_shape of the resulting distribution.

  • scale – Optional tensor, of the same shape as the batch_shape+event_shape of the resulting distribution.

classmethod domain_map(loc: torch.Tensor, scale: torch.Tensor)

Converts arguments to the right shape and domain.

The domain depends on the type of distribution, while the correct shape is obtained by reshaping the trailing axis in such a way that the returned tensors define a distribution of the right event_shape.
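This page does not say which transform domain_map applies. The sketch below is an illustration only. It assumes a softplus for scale, a common way to map an unconstrained network output to a positive value. It then floors the result at 1e-6 to respect the scale constraint listed in arg_constraints, while loc stays unconstrained. The names `softplus`, `map_to_domain` and `MIN_SCALE` are hypothetical, and the trailing-axis reshaping is omitted:

```python
import math
from typing import Tuple

MIN_SCALE = 1e-6  # hypothetical; mirrors GreaterThan(lower_bound=1e-06) on scale


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def map_to_domain(raw_loc: float, raw_scale: float) -> Tuple[float, float]:
    """Hypothetical domain map: loc is kept as-is, scale is made strictly positive."""
    return raw_loc, max(softplus(raw_scale), MIN_SCALE)


print(map_to_domain(0.5, -100.0))  # scale is floored at MIN_SCALE
```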

property event_shape: Tuple

Shape of each individual event compatible with the output object.