gluonts.torch.model.estimator module#

class gluonts.torch.model.estimator.PyTorchLightningEstimator(trainer_kwargs: Dict[str, Any], lead_time: int = 0)[source]#

Bases: gluonts.model.estimator.Estimator

An Estimator type with utilities for creating PyTorch-Lightning-based models.

To extend this class, one needs to implement five methods: create_transformation, create_lightning_module, create_predictor, create_training_data_loader, and create_validation_data_loader.
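
As a sketch only (the class name, hyperparameters, and method bodies below are placeholders, not part of this module), a subclass typically looks like this:

    from typing import Any, Dict, Iterable

    import lightning.pytorch as pl

    from gluonts.dataset import Dataset
    from gluonts.torch.model.estimator import PyTorchLightningEstimator
    from gluonts.torch.model.predictor import PyTorchPredictor
    from gluonts.transform import Transformation


    class MyEstimator(PyTorchLightningEstimator):
        """Hypothetical estimator showing which methods must be overridden."""

        def __init__(self, prediction_length: int, trainer_kwargs: Dict[str, Any]) -> None:
            super().__init__(trainer_kwargs=trainer_kwargs)
            self.prediction_length = prediction_length

        def create_transformation(self) -> Transformation:
            ...  # build the entry-wise transformation chain

        def create_lightning_module(self) -> pl.LightningModule:
            ...  # build the LightningModule that computes the loss

        def create_training_data_loader(self, data: Dataset, module, **kwargs) -> Iterable:
            ...  # yield training batches consumable by `module`

        def create_validation_data_loader(self, data: Dataset, module, **kwargs) -> Iterable:
            ...  # yield validation batches consumable by `module`

        def create_predictor(self, transformation: Transformation, module) -> PyTorchPredictor:
            ...  # wrap the trained module into a PyTorchPredictor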

create_lightning_module() lightning.pytorch.core.module.LightningModule[source]#

Create and return the network used for training (i.e., computing the loss).

Returns

The network that computes the loss given input data.

Return type

pl.LightningModule

create_predictor(transformation: gluonts.transform._base.Transformation, module) gluonts.torch.model.predictor.PyTorchPredictor[source]#

Create and return a predictor object.

Parameters
  • transformation – Transformation to be applied to data before it goes into the model.

  • module – A trained pl.LightningModule object.

Returns

A predictor wrapping an nn.Module, used for inference.

Return type

Predictor
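
As an illustration, an implementation typically wraps the trained module in a PyTorchPredictor. The constructor arguments shown below are assumptions based on recent GluonTS releases, the input names are made up, and batch_size is a hypothetical estimator attribute:

    from gluonts.torch.model.predictor import PyTorchPredictor

    def create_predictor(self, transformation, module) -> PyTorchPredictor:
        # Illustrative sketch: wrap the trained LightningModule so that it can
        # be fed transformed dataset entries at inference time.
        return PyTorchPredictor(
            input_names=["past_target"],         # names of the network inputs (assumed)
            prediction_net=module,
            batch_size=self.batch_size,          # hypothetical estimator attribute
            prediction_length=self.prediction_length,
            input_transform=transformation,
        )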

create_training_data_loader(data: gluonts.dataset.Dataset, module, **kwargs) Iterable[source]#

Create a data loader for training purposes.

Parameters
  • data – Dataset from which to create the data loader.

  • module – The pl.LightningModule object that will receive the batches from the data loader.

Returns

The data loader, i.e., an iterable over batches of data.

Return type

Iterable
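
Purely as an illustration, a minimal override could batch the transformed entries by hand, assuming the transformation already yields fixed-size arrays (e.g., via an instance splitter) and that batch_size is a hypothetical estimator attribute; real implementations also shuffle and cycle over the dataset:

    from itertools import islice
    from typing import Iterable

    import numpy as np

    def create_training_data_loader(self, data, module, **kwargs) -> Iterable:
        # Apply the entry-wise transformation in training mode, then group the
        # resulting entries into fixed-size batches of stacked arrays.
        transformed = self.create_transformation().apply(data, is_train=True)

        def batches():
            it = iter(transformed)
            while batch := list(islice(it, self.batch_size)):
                yield {
                    key: np.stack([entry[key] for entry in batch])
                    for key in batch[0]
                }

        return batches()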

create_transformation() gluonts.transform._base.Transformation[source]#

Create and return the transformation needed for training and inference.

Returns

The transformation that will be applied entry-wise to datasets, at training and inference time.

Return type

Transformation
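
For instance, an implementation might chain entry-wise transformations from gluonts.transform; the particular chain below is illustrative only:

    from gluonts.dataset.field_names import FieldName
    from gluonts.transform import (
        AddObservedValuesIndicator,
        Chain,
        RemoveFields,
        Transformation,
    )

    def create_transformation(self) -> Transformation:
        # Illustrative chain: drop an unused field and mark missing values.
        return Chain(
            [
                RemoveFields(field_names=[FieldName.FEAT_DYNAMIC_CAT]),
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                ),
            ]
        )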

create_validation_data_loader(data: gluonts.dataset.Dataset, module, **kwargs) Iterable[source]#

Create a data loader for validation purposes.

Parameters
  • data – Dataset from which to create the data loader.

  • module – The pl.LightningModule object that will receive the batches from the data loader.

Returns

The data loader, i.e., an iterable over batches of data.

Return type

Iterable

lead_time: int#

prediction_length: int#

train(training_data: gluonts.dataset.Dataset, validation_data: Optional[gluonts.dataset.Dataset] = None, shuffle_buffer_length: Optional[int] = None, cache_data: bool = False, ckpt_path: Optional[str] = None, **kwargs) gluonts.torch.model.predictor.PyTorchPredictor[source]#

Train the estimator on the given data.

Parameters
  • training_data – Dataset to train the model on.

  • validation_data – Dataset to validate the model on during training.

Returns

The predictor containing the trained model.

Return type

Predictor
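
For example, with the DeepAREstimator subclass shipped in gluonts.torch (the dataset name, repository helper, and hyperparameters below are illustrative):

    from gluonts.dataset.repository import get_dataset
    from gluonts.torch import DeepAREstimator

    dataset = get_dataset("electricity")  # illustrative built-in dataset

    estimator = DeepAREstimator(
        freq=dataset.metadata.freq,
        prediction_length=dataset.metadata.prediction_length,
        trainer_kwargs={"max_epochs": 5},
    )

    # `train` fits the model and returns a predictor ready for inference.
    predictor = estimator.train(training_data=dataset.train)
    forecasts = list(predictor.predict(dataset.test))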

train_from(predictor: gluonts.model.predictor.Predictor, training_data: gluonts.dataset.Dataset, validation_data: Optional[gluonts.dataset.Dataset] = None, shuffle_buffer_length: Optional[int] = None, cache_data: bool = False, ckpt_path: Optional[str] = None) gluonts.torch.model.predictor.PyTorchPredictor[source]#

train_model(training_data: gluonts.dataset.Dataset, validation_data: Optional[gluonts.dataset.Dataset] = None, from_predictor: Optional[gluonts.torch.model.predictor.PyTorchPredictor] = None, shuffle_buffer_length: Optional[int] = None, cache_data: bool = False, ckpt_path: Optional[str] = None, **kwargs) gluonts.torch.model.estimator.TrainOutput[source]#

class gluonts.torch.model.estimator.TrainOutput(transformation, trained_net, trainer, predictor)[source]#

Bases: tuple

predictor: gluonts.torch.model.predictor.PyTorchPredictor#

Alias for field number 3

trained_net: torch.nn.modules.module.Module#

Alias for field number 1

trainer: lightning.pytorch.trainer.trainer.Trainer#

Alias for field number 2

transformation: gluonts.transform._base.Transformation#

Alias for field number 0
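
Since TrainOutput is a named tuple, the value returned by train_model can be unpacked positionally or accessed by field name, as in this brief sketch (reusing the hypothetical estimator and dataset from the training example above):

    train_output = estimator.train_model(training_data=dataset.train)

    transformation = train_output.transformation  # field 0
    trained_net = train_output.trained_net        # field 1
    trainer = train_output.trainer                # field 2
    predictor = train_output.predictor            # field 3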