gluonts.torch.distributions.discrete_distribution module

class gluonts.torch.distributions.discrete_distribution.DiscreteDistribution(values: torch.Tensor, probs: torch.Tensor, validate_args: Optional[bool] = None)

Bases: torch.distributions.distribution.Distribution

Implements a discrete distribution in which the underlying random variable takes a value from the finite set values with the corresponding probabilities probs.

Note: values can contain duplicates, in which case the probability masses of the duplicates are added up.
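The following plain-Python sketch illustrates how duplicates are treated: the probability of a value is the sum of the probabilities at every position that holds it. It is not the gluonts implementation (which works on batched torch tensors), and the helper name pmf is hypothetical.

```python
from typing import Sequence


def pmf(values: Sequence[float], probs: Sequence[float], x: float) -> float:
    """Probability that the random variable equals x.

    Duplicates in `values` are handled by adding up the probability
    mass of every position whose value equals x.
    """
    return sum(p for v, p in zip(values, probs) if v == x)


# The value 2.0 appears twice, so its masses 0.25 and 0.25 are added up.
values = [1.0, 2.0, 2.0, 3.0]
probs = [0.25, 0.25, 0.25, 0.25]
print(pmf(values, probs, 2.0))  # 0.5
print(pmf(values, probs, 4.0))  # 0 (outside the finite set)
```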

A natural loss function, especially since a new observation need not belong to the finite set values, is the ranked probability score (RPS). For this reason, and to be consistent with the terminology of other models, log_prob is implemented as the negative RPS.
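For reference, the RPS described here can be written as a sum of quantile (pinball) losses. The notation is an assumption for illustration: v_1 < ... < v_K are the distinct sorted values, p_k their merged probabilities, y the observation, and the quantile level paired with v_k is taken to be the cumulative probability tau_k.

```latex
\tau_k = \sum_{j=1}^{k} p_j, \qquad
\rho_\tau(y, q) =
\begin{cases}
  \tau\,(y - q), & y \ge q,\\
  (1 - \tau)\,(q - y), & y < q,
\end{cases}

\mathrm{RPS}(y) = \sum_{k=1}^{K} \rho_{\tau_k}(y, v_k),
\qquad
\texttt{log\_prob}(y) = -\mathrm{RPS}(y).
```

Since every tau_k lies in [0, 1], each pinball loss is non-negative, so log_prob is at most 0 and, as with a log-likelihood, higher values are better.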

static adjust_probs(values_sorted, probs_sorted)

Puts the probability mass of all duplicate values into a single position (the last index of each group of duplicates).

Assumption: values_sorted is sorted, so that duplicates are adjacent.

Parameters
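  • values_sorted – the values, sorted in ascending order.

  • probs_sorted – the probabilities corresponding to values_sorted.

Returns

The adjusted probabilities, in the same order as values_sorted.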
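A plain-Python sketch of the behaviour described above, for a single unbatched list. It is an illustration rather than the gluonts code, and the name adjust_probs_sketch is hypothetical. A single forward pass suffices because each run of duplicates is contiguous in sorted input, so the accumulated mass is carried along the run and lands on its last index.

```python
from typing import List


def adjust_probs_sketch(values_sorted: List[float], probs_sorted: List[float]) -> List[float]:
    """Move the mass of each run of equal values onto the run's last index.

    Assumes values_sorted is sorted, so duplicates are adjacent.
    Returns a new list; the input is not modified.
    """
    adjusted = list(probs_sorted)
    for i in range(len(values_sorted) - 1):
        if values_sorted[i] == values_sorted[i + 1]:
            # Carry the mass accumulated so far forward to the next equal value.
            adjusted[i + 1] += adjusted[i]
            adjusted[i] = 0.0
    return adjusted


print(adjust_probs_sketch([1.0, 2.0, 2.0, 2.0, 3.0], [0.125, 0.25, 0.25, 0.125, 0.25]))
# [0.125, 0.0, 0.0, 0.625, 0.25]
```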

log_prob(obs: torch.Tensor)

Returns the negative ranked probability score (RPS) evaluated at obs; as noted in the class description, this stands in for the log probability.

Parameters

obs (Tensor) – the observations to evaluate.

mean()

Returns the mean of the distribution.
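The mean of a finite discrete distribution is the probability-weighted sum of its values. A minimal plain-Python sketch for a single unbatched distribution (the helper name discrete_mean is hypothetical):

```python
from typing import Sequence


def discrete_mean(values: Sequence[float], probs: Sequence[float]) -> float:
    """Expected value of a finite discrete distribution: sum of value * probability."""
    return sum(v * p for v, p in zip(values, probs))


# Duplicates need no special treatment: their masses contribute additively.
print(discrete_mean([0.0, 1.0, 1.0, 4.0], [0.25, 0.25, 0.25, 0.25]))  # 1.5
```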

quantile_losses(obs: torch.Tensor, quantiles: torch.Tensor, levels: torch.Tensor)

rps(obs: torch.Tensor, check_for_duplicates: bool = True)

Implements the ranked probability score, which is the sum of the quantile losses over all possible quantiles.

Here, the number of quantiles is finite and equals the number of unique entries in (each batch element of) values.

Parameters
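  • obs – the observations to score.

  • check_for_duplicates – whether to first merge the probability mass of duplicate values using adjust_probs.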
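A plain-Python sketch of the score for a single unbatched distribution: quantile losses summed over the distinct values, with log_prob taken as the negative score. The pairing of each sorted unique value with its cumulative probability as the quantile level is an assumption, as are the helper names quantile_loss and rps_sketch. The sketch always merges duplicates, i.e. it behaves like check_for_duplicates=True.

```python
from typing import Sequence


def quantile_loss(obs: float, quantile: float, level: float) -> float:
    """Pinball loss of predicting `quantile` at `level` when `obs` is observed."""
    if obs >= quantile:
        return level * (obs - quantile)
    return (1.0 - level) * (quantile - obs)


def rps_sketch(values: Sequence[float], probs: Sequence[float], obs: float) -> float:
    """Ranked probability score as a sum of quantile losses.

    Assumption: the quantile level of each sorted unique value is the
    cumulative probability up to and including it; duplicates are merged first.
    """
    merged = {}
    for v, p in zip(values, probs):
        merged[v] = merged.get(v, 0.0) + p
    score, level = 0.0, 0.0
    for v in sorted(merged):
        level += merged[v]  # cumulative probability at v
        score += quantile_loss(obs, v, level)
    return score


values = [0.0, 1.0, 1.0, 2.0]
probs = [0.25, 0.25, 0.25, 0.25]
print(rps_sketch(values, probs, obs=1.0))   # 0.25
print(-rps_sketch(values, probs, obs=1.0))  # -0.25, the corresponding "log_prob"
```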

sample(sample_shape=torch.Size([]))

Generates a sample of shape sample_shape, or a sample_shape-shaped batch of samples if the distribution parameters are batched.
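Conceptually, sampling draws independent values from the finite set with the given probabilities. A minimal plain-Python sketch for a single unbatched distribution, using the standard library's random.Random.choices; the helper name sample_discrete and the seed parameter are hypothetical, and the library method itself works on torch tensors rather than lists.

```python
import random
from typing import List, Sequence


def sample_discrete(values: Sequence[float], probs: Sequence[float], n: int, seed: int = 0) -> List[float]:
    """Draw n i.i.d. samples from the finite set `values` with weights `probs`."""
    rng = random.Random(seed)
    # choices samples with replacement according to the weights; duplicate
    # values effectively receive the combined weight of their positions.
    return rng.choices(values, weights=probs, k=n)


draws = sample_discrete([1.0, 2.0, 2.0, 3.0], [0.1, 0.2, 0.3, 0.4], n=5)
print(draws)
```

Seeding a private random.Random instance keeps the sketch reproducible without touching the global random state.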