gluonts.torch.model.i_transformer.lightning_module module#

class gluonts.torch.model.i_transformer.lightning_module.ITransformerLightningModule(model_kwargs: dict, num_parallel_samples: int = 100, lr: float = 0.001, weight_decay: float = 1e-08)[source]#

Bases: lightning.pytorch.core.module.LightningModule

A pl.LightningModule class that can be used to train an ITransformerModel with PyTorch Lightning.

This is a thin layer around a (wrapped) ITransformerModel object that exposes the methods needed to compute the training and validation losses.

Parameters
  • model_kwargs – Keyword arguments to construct the ITransformerModel to be trained.

  • num_parallel_samples – Number of samples to draw per time series during inference.

  • lr – Learning rate.

  • weight_decay – Weight decay regularization parameter.
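A minimal sketch of how these parameters fit together: the lightning module is constructed from a model_kwargs dict that mirrors the ITransformerModel constructor, plus the optimizer hyperparameters. The specific keys shown (prediction_length, context_length) are assumptions for illustration; consult the ITransformerModel signature for the full set.

```python
# Hypothetical model_kwargs for the wrapped ITransformerModel; the exact
# keys depend on the ITransformerModel constructor.
model_kwargs = {
    "prediction_length": 24,
    "context_length": 48,
}

# Hyperparameters as they would be passed to the lightning module,
# using the documented defaults for lr and weight_decay:
hparams = dict(
    model_kwargs=model_kwargs,
    num_parallel_samples=100,
    lr=1e-3,
    weight_decay=1e-8,
)
# ITransformerLightningModule(**hparams) would then wrap and train the model.
```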

configure_optimizers()[source]#

Returns the optimizer to use.
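As a hedged sketch of what this hook typically does in GluonTS lightning modules: build a torch optimizer over the wrapped model's parameters using the lr and weight_decay hyperparameters. The choice of Adam and the stand-in nn.Linear model below are assumptions for illustration, not the actual implementation.

```python
import torch

lr, weight_decay = 1e-3, 1e-8          # the class defaults
dummy_model = torch.nn.Linear(4, 2)    # stand-in for the wrapped ITransformerModel

# What configure_optimizers() plausibly returns: an optimizer over the
# model's parameters, configured with the module's hyperparameters.
optimizer = torch.optim.Adam(
    dummy_model.parameters(), lr=lr, weight_decay=weight_decay
)
```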

forward(*args, **kwargs)[source]#

Same as torch.nn.Module.forward().

Parameters
  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.

Returns

Your model’s output

training_step(batch, batch_idx: int)[source]#

Execute training step.

validation_step(batch, batch_idx: int)[source]#

Execute validation step.
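The step methods above follow the standard Lightning pattern: take a batch, compute the model's loss, and return it for logging and backpropagation. The toy module below sketches that pattern with a plain torch model; it is a stand-in, not the actual ITransformerLightningModule logic.

```python
import torch

class ToyStepModule(torch.nn.Module):
    """Illustrates the training_step/validation_step contract."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(3, 1)

    def training_step(self, batch, batch_idx: int):
        # Unpack the batch, run the model, and return a scalar loss.
        x, y = batch
        return torch.nn.functional.mse_loss(self.net(x), y)

    def validation_step(self, batch, batch_idx: int):
        # Same computation, but without gradient tracking.
        with torch.no_grad():
            return self.training_step(batch, batch_idx)

module = ToyStepModule()
batch = (torch.randn(8, 3), torch.randn(8, 1))
train_loss = module.training_step(batch, 0)
val_loss = module.validation_step(batch, 0)
```

In the real module, a Trainer calls these hooks for every batch and handles optimization and logging.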