gluonts.mx.distribution.iresnet module
- class gluonts.mx.distribution.iresnet.InvertibleResnetHybridBlock(event_shape, hidden_units: int = 16, num_hidden_layers: int = 1, num_inv_iters: int = 10, ignore_logdet: bool = False, activation: str = 'lipswish', num_power_iter: int = 1, flatten: bool = False, coeff: float = 0.9, use_caching: bool = True, *args, **kwargs)
Bases: gluonts.mx.distribution.bijection.BijectionHybridBlock
Invertible residual network (iResnet) block based on [BJC19], except that f and f_inv are swapped relative to the paper.
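The constructor arguments coeff and num_power_iter point to the contraction constraint behind [BJC19]. A residual block is invertible when its residual branch has Lipschitz constant below 1. This is commonly enforced by spectrally normalizing each weight matrix, using a few power-iteration steps to estimate its largest singular value, provided the activation is itself at most 1-Lipschitz. The pure-Python sketch below assumes that reading of the two parameters. The helper names (matvec, normalize, spectral_norm_estimate, spectrally_normalize) are hypothetical, and this is not the GluonTS implementation.

```python
import math
import random
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(W: Matrix, v: Vector) -> Vector:
    """Multiply matrix W (list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def transpose(W: Matrix) -> Matrix:
    return [list(col) for col in zip(*W)]


def normalize(v: Vector) -> Vector:
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def spectral_norm_estimate(W: Matrix, num_power_iter: int = 1, seed: int = 0) -> float:
    """Estimate the largest singular value of W by power iteration on W^T W."""
    rng = random.Random(seed)
    u = normalize([rng.gauss(0.0, 1.0) for _ in range(len(W[0]))])
    Wt = transpose(W)
    for _ in range(num_power_iter):
        v = normalize(matvec(W, u))
        u = normalize(matvec(Wt, v))
    # ||W u|| for a unit vector u never exceeds the true spectral norm.
    return math.sqrt(sum(x * x for x in matvec(W, u)))


def spectrally_normalize(W: Matrix, coeff: float = 0.9, num_power_iter: int = 1) -> Matrix:
    """Rescale W only if its estimated spectral norm exceeds coeff."""
    sigma = spectral_norm_estimate(W, num_power_iter)
    scale = min(1.0, coeff / sigma)
    return [[w * scale for w in row] for row in W]
```

With num_power_iter=1, the estimate can undershoot the true norm. Spectral-normalization implementations typically keep the vector u between training steps, so the estimate sharpens over time.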
- property event_dim: int
- property event_shape
- f(x: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) → Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]
Forward transformation of iResnet.
- Parameters
x – observations
- Returns
transformed tensor $\text{iResnet}(x)$
- Return type
Tensor
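One plausible reading of the remark that f and f_inv are swapped is that f_inv evaluates the explicit residual map x = y + g(y), while f inverts it numerically. This is an assumption that this page does not confirm. If the residual branch g is a contraction, the Banach fixed-point iteration y ← x - g(y) converges to the solution, and num_inv_iters would cap the number of steps. The minimal scalar sketch below uses a hypothetical branch g(z) = 0.5·tanh(z) and a hypothetical helper f_fixed_point.

```python
import math


def g(z: float) -> float:
    """Hypothetical residual branch; its Lipschitz constant is 0.5 < 1."""
    return 0.5 * math.tanh(z)


def f_fixed_point(x: float, num_inv_iters: int = 10) -> float:
    """Approximately solve y + g(y) = x for y.

    Iterates y <- x - g(y) starting from y = x; because g is a contraction,
    the error shrinks by at least a factor 0.5 per step.
    """
    y = x
    for _ in range(num_inv_iters):
        y = x - g(y)
    return y


y = f_fixed_point(1.3, num_inv_iters=10)
print(y, y + g(y))  # y + g(y) should be close to 1.3
```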
- f_inv(y: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) → Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]
Inverse transformation of iResnet.
- Parameters
y – input tensor
- Returns
transformed tensor $\text{iResnet}^{-1}(y)$
- Return type
Tensor
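Suppose f_inv is the explicit residual map y + g(y), an assumption rather than documented behaviour. Then f_inv costs a single evaluation of the residual branch, while the fixed-point f only recovers the input approximately. The round-trip error shrinks geometrically with num_inv_iters. The scalar sketch below uses a hypothetical g, f, f_inv and round_trip_error, not the GluonTS code, to measure that error for a few iteration counts.

```python
import math


def g(z: float) -> float:
    """Hypothetical residual branch with Lipschitz constant 0.5."""
    return 0.5 * math.tanh(z)


def f_inv(y: float) -> float:
    """Explicit residual map: one evaluation of g, no iteration."""
    return y + g(y)


def f(x: float, num_inv_iters: int = 10) -> float:
    """Iterative inverse of f_inv via y <- x - g(y)."""
    y = x
    for _ in range(num_inv_iters):
        y = x - g(y)
    return y


def round_trip_error(y: float, num_inv_iters: int) -> float:
    """|f(f_inv(y)) - y|: how far the iterative inverse is from exact."""
    return abs(f(f_inv(y), num_inv_iters) - y)


for k in (1, 5, 10, 20):
    print(k, round_trip_error(0.7, k))
```

With Lipschitz constant L, the error after k steps is at most L^k times the initial error. A well-contracted branch therefore usually needs only a handful of iterations.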
- log_abs_det_jac(x: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol], y: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) → Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]
Logarithm of the absolute value of the Jacobian determinant corresponding to the iResnet transform.
- Parameters
x – input of the forward transform or output of the inverse transform
y – output of the forward transform or input of the inverse transform
- Returns
logarithm of the absolute value of the Jacobian determinant, evaluated at x as input (equivalently, at y as output)
- Return type
Tensor
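Assume f_inv(y) = y + g(y), the same unconfirmed reading as above. The inverse function theorem then gives $\log\lvert\det \partial f/\partial x\rvert = -\log\lvert\det(I + J_g(y))\rvert$ with $y = f(x)$. In one dimension this reduces to $-\log\lvert 1 + g'(y)\rvert$. [BJC19] estimates the multivariate log-determinant with a truncated power series. The log_abs_det helper on this page suggests an exact determinant may be used here instead, but that is a guess. The scalar sketch below, with hypothetical g, g_prime, f and log_abs_det_jac, checks the 1-D formula.

```python
import math


def g(z: float) -> float:
    """Hypothetical residual branch with Lipschitz constant 0.5."""
    return 0.5 * math.tanh(z)


def g_prime(z: float) -> float:
    """Derivative of g."""
    return 0.5 * (1.0 - math.tanh(z) ** 2)


def f(x: float, num_inv_iters: int = 60) -> float:
    """Forward map, assumed to be the fixed-point inverse of y + g(y)."""
    y = x
    for _ in range(num_inv_iters):
        y = x - g(y)
    return y


def log_abs_det_jac(x: float, y: float) -> float:
    """log|df/dx| at x, computed from y = f(x).

    Inverse function theorem: f'(x) = 1 / (1 + g'(y)). In 1-D the value
    depends only on y; x is kept to mirror the documented (x, y) signature.
    """
    return -math.log(abs(1.0 + g_prime(y)))


x = 0.4
print(log_abs_det_jac(x, f(x)))
```

Evaluating the derivative at y avoids differentiating through the fixed-point loop.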
- gluonts.mx.distribution.iresnet.iresnet(num_blocks: int, **block_kwargs) → gluonts.mx.distribution.bijection.ComposedBijectionHybridBlock
Compose num_blocks iResnet blocks into a single bijection.
- Parameters
num_blocks – number of iResnet blocks
block_kwargs – keyword arguments passed to the constructor of each iResnet block
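In a chain of blocks, such as the one built from num_blocks iResnet blocks, each block's output feeds the next. By the chain rule, the per-block log-absolute-Jacobian-determinants add up. The scalar sketch below illustrates this with hypothetical make_block and compose helpers, under the same assumption that each block's forward map is the fixed-point inverse of y + g(y). It is not ComposedBijectionHybridBlock itself.

```python
import math
from typing import Callable, List, Tuple

Block = Tuple[Callable[[float], float], Callable[[float, float], float]]


def make_block(scale: float, num_inv_iters: int = 100) -> Block:
    """Hypothetical 1-D block with residual branch scale * tanh (|scale| < 1)."""

    def g(z: float) -> float:
        return scale * math.tanh(z)

    def g_prime(z: float) -> float:
        return scale * (1.0 - math.tanh(z) ** 2)

    def f(x: float) -> float:
        # Fixed-point solve of y + g(y) = x.
        y = x
        for _ in range(num_inv_iters):
            y = x - g(y)
        return y

    def log_abs_det_jac(x: float, y: float) -> float:
        # Inverse function theorem, as for a single block.
        return -math.log(abs(1.0 + g_prime(y)))

    return f, log_abs_det_jac


def compose(blocks: List[Block]) -> Callable[[float], Tuple[float, float]]:
    """Chain blocks: outputs feed the next block, log-det terms are summed."""

    def forward(x: float) -> Tuple[float, float]:
        total_ladj = 0.0
        for f, log_abs_det_jac in blocks:
            y = f(x)
            total_ladj += log_abs_det_jac(x, y)
            x = y
        return x, total_ladj

    return forward


flow = compose([make_block(s) for s in (0.5, 0.3, 0.4)])
print(flow(0.25))  # (output, summed log|det J|)
```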
- gluonts.mx.distribution.iresnet.log_abs_det(A: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) → Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]
Logarithm of the absolute value of the determinant of matrix A.
- Parameters
A – Tensor matrix whose log absolute determinant is computed
- Return type
Tensor
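One standard way to compute log|det A| for a small square matrix is LU decomposition with partial pivoting. Row swaps only flip the sign of the determinant, so log|det A| equals the sum of log|u_ii| over the diagonal of U. The pure-Python sketch below handles a single unbatched matrix given as nested lists. This page does not say how the GluonTS function computes the value for batched MXNet tensors, so treat this as an illustration of the quantity only.

```python
import math
from typing import List


def log_abs_det(A: List[List[float]]) -> float:
    """log|det A| via Gaussian elimination with partial pivoting.

    Returns -inf for a (numerically) singular matrix.
    """
    n = len(A)
    M = [[float(a) for a in row] for row in A]  # work on a copy
    total = 0.0
    for k in range(n):
        # Pick the row with the largest pivot candidate in column k.
        p = max(range(k, n), key=lambda i: abs(M[i][k]))
        if M[p][k] == 0.0:
            return float("-inf")
        M[k], M[p] = M[p], M[k]  # a swap changes only the sign of det
        total += math.log(abs(M[k][k]))
        for i in range(k + 1, n):
            factor = M[i][k] / M[k][k]
            for j in range(k, n):
                M[i][j] -= factor * M[k][j]
    return total


print(log_abs_det([[1.0, 2.0], [3.0, 4.0]]))  # log|-2| = log 2
```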