gluonts.torch.model.tft package

class gluonts.torch.model.tft.TemporalFusionTransformerEstimator(freq: str, prediction_length: int, context_length: Optional[int] = None, quantiles: Optional[List[float]] = None, distr_output: Optional[gluonts.torch.distributions.output.Output] = None, num_heads: int = 4, hidden_dim: int = 32, variable_dim: int = 32, static_dims: Optional[List[int]] = None, dynamic_dims: Optional[List[int]] = None, past_dynamic_dims: Optional[List[int]] = None, static_cardinalities: Optional[List[int]] = None, dynamic_cardinalities: Optional[List[int]] = None, past_dynamic_cardinalities: Optional[List[int]] = None, time_features: Optional[List[Callable[[pandas.core.indexes.period.PeriodIndex], numpy.ndarray]]] = None, lr: float = 0.001, weight_decay: float = 1e-08, dropout_rate: float = 0.1, patience: int = 10, batch_size: int = 32, num_batches_per_epoch: int = 50, trainer_kwargs: Optional[Dict[str, Any]] = None, train_sampler: Optional[gluonts.transform.sampler.InstanceSampler] = None, validation_sampler: Optional[gluonts.transform.sampler.InstanceSampler] = None)

Bases: gluonts.torch.model.estimator.PyTorchLightningEstimator

Estimator class to train a Temporal Fusion Transformer (TFT) model, as described in [LAL+21].

TFT internally performs feature selection when making forecasts. For this reason, the dimensions of real-valued features can be grouped together if they correspond to the same variable (e.g., treat all weather features as one feature and all holiday indicators as another feature).

For example, if the dataset contains the key “feat_static_real” with shape [batch_size, 3], we can, e.g.,
  • set static_dims = [3] to treat all three dimensions as a single feature;
  • set static_dims = [1, 1, 1] to treat each dimension as a separate feature;
  • set static_dims = [2, 1] to treat the first two dimensions as a single feature and the third dimension as a separate feature.
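
The grouping described above can be illustrated with a small pure-Python sketch. The helper group_static_features is hypothetical and is not part of GluonTS; it only shows how a list such as static_dims partitions the columns of one “feat_static_real” row into consecutive groups, under the assumption that the group sizes must sum to the number of columns.

```python
from typing import List, Sequence


def group_static_features(
    row: Sequence[float], static_dims: List[int]
) -> List[List[float]]:
    """Split one feat_static_real row into consecutive groups.

    Hypothetical helper for illustration only; not a GluonTS function.
    """
    if sum(static_dims) != len(row):
        raise ValueError(
            f"static_dims sums to {sum(static_dims)}, "
            f"but the row has {len(row)} columns"
        )
    groups: List[List[float]] = []
    start = 0
    for size in static_dims:
        # Each group covers `size` consecutive columns of the row.
        groups.append(list(row[start:start + size]))
        start += size
    return groups


row = [0.5, 1.5, 2.5]
print(group_static_features(row, [3]))        # [[0.5, 1.5, 2.5]]
print(group_static_features(row, [1, 1, 1]))  # [[0.5], [1.5], [2.5]]
print(group_static_features(row, [2, 1]))     # [[0.5, 1.5], [2.5]]
```

Each inner list corresponds to one variable from the point of view of feature selection; the number of groups, not the number of columns, determines how many static real-valued features the model sees.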

See gluonts.torch.model.tft.TemporalFusionTransformerModel.describe_inputs for more details on how the model configuration corresponds to the expected input shapes.

Parameters
  • freq – Frequency of the data to train on and predict.

  • prediction_length (int) – Length of the prediction horizon.

  • context_length – Number of previous time series values provided as input to the encoder (default: None, in which case context_length = prediction_length).

  • quantiles – List of quantiles that the model will learn to predict (default: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]).

  • distr_output – Distribution output to use (default: QuantileOutput).

  • num_heads – Number of attention heads in the self-attention layer of the decoder (default: 4).

  • hidden_dim – Size of the LSTM & transformer hidden states.

  • variable_dim – Size of the feature embeddings.

  • static_dims – Sizes of the real-valued static features.

  • dynamic_dims – Sizes of the real-valued dynamic features that are known in the future.

  • past_dynamic_dims – Sizes of the real-valued dynamic features that are only known in the past.

  • static_cardinalities – Cardinalities of the categorical static features.

  • dynamic_cardinalities – Cardinalities of the categorical dynamic features that are known in the future.

  • past_dynamic_cardinalities – Cardinalities of the categorical dynamic features that are only known in the past.

  • time_features – List of time features, from gluonts.time_feature, to use as dynamic real features in addition to the provided data (default: None, in which case these are automatically determined based on freq).

  • lr – Learning rate (default: 1e-3).

  • weight_decay – Weight decay (default: 1e-8).

  • dropout_rate – Dropout regularization parameter (default: 0.1).

  • patience – Patience parameter for the learning rate scheduler (default: 10).

  • batch_size – The size of the batches to be used for training (default: 32).

  • num_batches_per_epoch (int) – Number of batches to be processed in each training epoch (default: 50).

  • trainer_kwargs – Additional arguments to provide to pl.Trainer for construction.

  • train_sampler – Controls the sampling of windows during training.

  • validation_sampler – Controls the sampling of windows during validation.
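
As a minimal sketch of how the parameters above fit together, the following collects a configuration and passes it to the estimator. The parameter names come from the signature documented here; the concrete values (daily frequency, a 7-step horizon, three quantiles, max_epochs for the Lightning trainer) are illustrative assumptions, and the helper build_estimator is hypothetical. The import is done lazily inside the helper, so the configuration can be inspected without GluonTS installed.

```python
from typing import Any, Dict

# Keyword arguments named after the documented signature. The values are
# illustrative assumptions for a daily dataset whose "feat_static_real"
# field has three columns (grouped as one 2-column and one 1-column feature).
ESTIMATOR_KWARGS: Dict[str, Any] = {
    "freq": "D",
    "prediction_length": 7,
    "context_length": 28,
    "quantiles": [0.1, 0.5, 0.9],
    "static_dims": [2, 1],
    "lr": 1e-3,
    "batch_size": 32,
    "num_batches_per_epoch": 50,
    "trainer_kwargs": {"max_epochs": 5},
}


def build_estimator(**overrides: Any):
    """Instantiate the estimator (hypothetical convenience wrapper).

    Requires GluonTS with its PyTorch dependencies to be installed.
    """
    # Lazy import: only needed when the estimator is actually created.
    from gluonts.torch.model.tft import TemporalFusionTransformerEstimator

    return TemporalFusionTransformerEstimator(
        **{**ESTIMATOR_KWARGS, **overrides}
    )
```

Calling build_estimator(prediction_length=14), for instance, would override a single setting while keeping the rest of the configuration unchanged.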

create_lightning_module() → gluonts.torch.model.tft.lightning_module.TemporalFusionTransformerLightningModule

Create and return the network used for training (i.e., computing the loss).

Returns

The network that computes the loss given input data.

Return type

pl.LightningModule

create_predictor(transformation: gluonts.transform._base.Transformation, module: gluonts.torch.model.tft.lightning_module.TemporalFusionTransformerLightningModule) → gluonts.torch.model.predictor.PyTorchPredictor

Create and return a predictor object.

Parameters
  • transformation – Transformation to be applied to data before it goes into the model.

  • module – A trained pl.LightningModule object.

Returns

A predictor wrapping an nn.Module used for inference.

Return type

Predictor

create_training_data_loader(data: gluonts.dataset.Dataset, module: gluonts.torch.model.tft.lightning_module.TemporalFusionTransformerLightningModule, shuffle_buffer_length: Optional[int] = None, **kwargs) → Iterable

Create a data loader for training purposes.

Parameters
  • data – Dataset from which to create the data loader.

  • module – The pl.LightningModule object that will receive the batches from the data loader.

Returns

The data loader, i.e., an iterable over batches of data.

Return type

Iterable

create_transformation() → gluonts.transform._base.Transformation

Create and return the transformation needed for training and inference.

Returns

The transformation that will be applied entry-wise to datasets, at training and inference time.

Return type

Transformation

create_validation_data_loader(data: gluonts.dataset.Dataset, module: gluonts.torch.model.tft.lightning_module.TemporalFusionTransformerLightningModule, **kwargs) → Iterable

Create a data loader for validation purposes.

Parameters
  • data – Dataset from which to create the data loader.

  • module – The pl.LightningModule object that will receive the batches from the data loader.

Returns

The data loader, i.e., an iterable over batches of data.

Return type

Iterable
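
The methods above are normally not called directly; they are used by the train method inherited from the estimator base class. The sketch below shows one plausible end-to-end flow under several assumptions: the toy sine-wave data, the helper names make_entries and train_and_forecast, and the choice of ListDataset as the dataset container are illustrative, and a single training epoch is used only to keep the run short. The GluonTS imports are done lazily, so the data-building helper works on its own.

```python
import math
from typing import Any, Dict, List


def make_entries(num_series: int = 3, length: int = 60) -> List[Dict[str, Any]]:
    """Build toy univariate series as dicts with 'start' and 'target' fields."""
    return [
        {
            "start": "2021-01-01",
            # A phase-shifted sine wave per series (illustrative data only).
            "target": [math.sin(0.3 * t + s) for t in range(length)],
        }
        for s in range(num_series)
    ]


def train_and_forecast(
    entries: List[Dict[str, Any]],
    freq: str = "D",
    prediction_length: int = 7,
):
    """Train a TFT estimator on the entries and return a list of forecasts.

    Hypothetical helper; requires GluonTS with its PyTorch dependencies.
    """
    from gluonts.dataset.common import ListDataset
    from gluonts.torch.model.tft import TemporalFusionTransformerEstimator

    dataset = ListDataset(entries, freq=freq)
    estimator = TemporalFusionTransformerEstimator(
        freq=freq,
        prediction_length=prediction_length,
        trainer_kwargs={"max_epochs": 1},  # short run, for illustration
    )
    # train() builds the transformation, data loader, and Lightning module,
    # fits the model, and returns a predictor.
    predictor = estimator.train(dataset)
    return list(predictor.predict(dataset))
```

With the default quantile output, each returned forecast is expected to expose the learned quantiles, e.g. via forecast.quantile(0.5); treat this as an assumption to verify against the installed GluonTS version.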

input_names()
lead_time: int
prediction_length: int
class gluonts.torch.model.tft.TemporalFusionTransformerLightningModule(model_kwargs: dict, lr: float = 0.001, patience: int = 10, weight_decay: float = 0.0)

Bases: lightning.pytorch.core.module.LightningModule

A pl.LightningModule class that can be used to train a TemporalFusionTransformerModel with PyTorch Lightning.

This is a thin layer around a (wrapped) TemporalFusionTransformerModel object, that exposes the methods to evaluate training and validation loss.

Parameters
  • model_kwargs – Keyword arguments to construct the TemporalFusionTransformerModel to be trained.

  • lr – Learning rate.

  • weight_decay – Weight decay regularization parameter.

  • patience – Patience parameter for learning rate scheduler.
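
To make the role of patience concrete, here is a simplified pure-Python simulation of a reduce-on-plateau style learning rate schedule. This is an assumption about the kind of scheduler involved, not a statement of what configure_optimizers actually returns; the function name reduce_on_plateau and the reduction factor of 0.5 are hypothetical choices for illustration.

```python
from typing import List, Sequence


def reduce_on_plateau(
    losses: Sequence[float],
    lr: float = 1e-3,
    patience: int = 10,
    factor: float = 0.5,
) -> List[float]:
    """Return the learning rate in effect after each epoch.

    Simplified sketch: the rate is multiplied by `factor` once the loss
    has failed to improve for more than `patience` consecutive epochs.
    """
    best = float("inf")
    bad_epochs = 0
    history: List[float] = []
    for loss in losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0  # start counting again after a reduction
        history.append(lr)
    return history


# With patience=2, the rate is halved on the third non-improving epoch.
print(reduce_on_plateau([1.0, 0.9, 0.9, 0.9, 0.9, 0.8], lr=1.0, patience=2))
# [1.0, 1.0, 1.0, 1.0, 0.5, 0.5]
```

A larger patience therefore tolerates longer plateaus before the learning rate is lowered, at the cost of reacting more slowly when training has stalled.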

configure_optimizers()

Returns the optimizer to use.

forward(*args, **kwargs)

Same as torch.nn.Module.forward().

Parameters
  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.

Returns

Your model’s output

training_step(batch, batch_idx: int)

Execute training step.

validation_step(batch, batch_idx: int)

Execute validation step.

class gluonts.torch.model.tft.TemporalFusionTransformerModel(context_length: int, prediction_length: int, d_feat_static_real: Optional[List[int]] = None, c_feat_static_cat: Optional[List[int]] = None, d_feat_dynamic_real: Optional[List[int]] = None, c_feat_dynamic_cat: Optional[List[int]] = None, d_past_feat_dynamic_real: Optional[List[int]] = None, c_past_feat_dynamic_cat: Optional[List[int]] = None, num_heads: int = 4, d_hidden: int = 32, d_var: int = 32, dropout_rate: float = 0.1, distr_output: Optional[gluonts.torch.distributions.output.Output] = None)

Bases: torch.nn.modules.module.Module

Temporal Fusion Transformer neural network.

Partially based on the implementation in github.com/kashif/pytorch-transformer-ts.

Inputs feat_static_real, feat_static_cat and feat_dynamic_real are mandatory. Inputs feat_dynamic_cat, past_feat_dynamic_real and past_feat_dynamic_cat are optional.

describe_inputs(batch_size=1) → gluonts.model.inputs.InputSpec
forward(past_target: torch.Tensor, past_observed_values: torch.Tensor, feat_static_real: Optional[torch.Tensor], feat_static_cat: Optional[torch.Tensor], feat_dynamic_real: Optional[torch.Tensor], feat_dynamic_cat: Optional[torch.Tensor] = None, past_feat_dynamic_real: Optional[torch.Tensor] = None, past_feat_dynamic_cat: Optional[torch.Tensor] = None) → Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of calling this function directly, since the former takes care of running the registered hooks while the latter silently ignores them.

input_types() → Dict[str, torch.dtype]
loss(past_target: torch.Tensor, past_observed_values: torch.Tensor, future_target: torch.Tensor, future_observed_values: torch.Tensor, feat_static_real: torch.Tensor, feat_static_cat: torch.Tensor, feat_dynamic_real: torch.Tensor, feat_dynamic_cat: Optional[torch.Tensor] = None, past_feat_dynamic_real: Optional[torch.Tensor] = None, past_feat_dynamic_cat: Optional[torch.Tensor] = None) → torch.Tensor
training: bool