gluonts.torch.model.tft package
- class gluonts.torch.model.tft.TemporalFusionTransformerEstimator(freq: str, prediction_length: int, context_length: Optional[int] = None, quantiles: Optional[List[float]] = None, distr_output: Optional[gluonts.torch.distributions.output.Output] = None, num_heads: int = 4, hidden_dim: int = 32, variable_dim: int = 32, static_dims: Optional[List[int]] = None, dynamic_dims: Optional[List[int]] = None, past_dynamic_dims: Optional[List[int]] = None, static_cardinalities: Optional[List[int]] = None, dynamic_cardinalities: Optional[List[int]] = None, past_dynamic_cardinalities: Optional[List[int]] = None, time_features: Optional[List[Callable[[pandas.core.indexes.period.PeriodIndex], numpy.ndarray]]] = None, lr: float = 0.001, weight_decay: float = 1e-08, dropout_rate: float = 0.1, patience: int = 10, batch_size: int = 32, num_batches_per_epoch: int = 50, trainer_kwargs: Optional[Dict[str, Any]] = None, train_sampler: Optional[gluonts.transform.sampler.InstanceSampler] = None, validation_sampler: Optional[gluonts.transform.sampler.InstanceSampler] = None)
Bases: gluonts.torch.model.estimator.PyTorchLightningEstimator
Estimator class to train a Temporal Fusion Transformer (TFT) model, as described in [LAL+21].
TFT internally performs feature selection when making forecasts. For this reason, the dimensions of real-valued features can be grouped together if they correspond to the same variable (e.g., treat weather features as one feature and holiday indicators as another feature).
For example, if the dataset contains the key “feat_static_real” with shape [batch_size, 3], we can, e.g.,
- set static_dims = [3] to treat all three dimensions as a single feature
- set static_dims = [1, 1, 1] to treat each dimension as a separate feature
- set static_dims = [2, 1] to treat the first two dims as a single feature
See gluonts.torch.model.tft.TemporalFusionTransformerModel.input_shapes for more details on how the model configuration corresponds to the expected input shapes.
- Parameters
freq – Frequency of the data to train on and predict.
prediction_length (int) – Length of the prediction horizon.
context_length – Number of previous time series values provided as input to the encoder. (default: None, in which case context_length = prediction_length).
quantiles – List of quantiles that the model will learn to predict. Defaults to [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
distr_output – Distribution output to use (default: QuantileOutput).
num_heads – Number of attention heads in self-attention layer in the decoder.
hidden_dim – Size of the LSTM & transformer hidden states.
variable_dim – Size of the feature embeddings.
static_dims – Sizes of the real-valued static features.
dynamic_dims – Sizes of the real-valued dynamic features that are known in the future.
past_dynamic_dims – Sizes of the real-valued dynamic features that are only known in the past.
static_cardinalities – Cardinalities of the categorical static features.
dynamic_cardinalities – Cardinalities of the categorical dynamic features that are known in the future.
past_dynamic_cardinalities – Cardinalities of the categorical dynamic features that are only known in the past.
time_features – List of time features, from gluonts.time_feature, to use as dynamic real features in addition to the provided data (default: None, in which case these are automatically determined based on freq).
lr – Learning rate (default: 1e-3).
weight_decay – Weight decay (default: 1e-8).
dropout_rate – Dropout regularization parameter (default: 0.1).
patience – Patience parameter for learning rate scheduler.
batch_size – The size of the batches to be used for training (default: 32).
num_batches_per_epoch (int) – Number of batches to be processed in each training epoch (default: 50).
trainer_kwargs – Additional arguments to provide to pl.Trainer for construction.
train_sampler – Controls the sampling of windows during training.
validation_sampler – Controls the sampling of windows during validation.
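The grouping controlled by static_dims (and likewise dynamic_dims and past_dynamic_dims) can be pictured with a small, library-free sketch. It assumes that groups are taken as consecutive slices along the last axis, which matches the "first two dims" wording above but is not confirmed here against the implementation; the helper name split_by_dims is hypothetical and not part of GluonTS.

```python
from typing import List, Sequence


def split_by_dims(values: Sequence[float], dims: List[int]) -> List[List[float]]:
    """Split one flat feature vector into consecutive groups of sizes `dims`.

    Hypothetical helper for illustration only; assumes groups are
    consecutive slices of the feature axis.
    """
    if sum(dims) != len(values):
        raise ValueError(
            f"dims {dims} sum to {sum(dims)}, but {len(values)} values were given"
        )
    groups, start = [], 0
    for size in dims:
        groups.append(list(values[start:start + size]))
        start += size
    return groups


# One row of "feat_static_real" with three real-valued dimensions.
row = [0.5, 1.2, 3.0]

print(split_by_dims(row, [3]))        # one feature with three dimensions
print(split_by_dims(row, [1, 1, 1]))  # three one-dimensional features
print(split_by_dims(row, [2, 1]))     # first two dims grouped, third separate
```

Under this reading, the entries of static_dims must sum to the width of the feature array, which is why the sketch rejects mismatched sizes.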
- create_lightning_module() -> gluonts.torch.model.tft.lightning_module.TemporalFusionTransformerLightningModule
Create and return the network used for training (i.e., computing the loss).
- Returns
The network that computes the loss given input data.
- Return type
pl.LightningModule
- create_predictor(transformation: gluonts.transform._base.Transformation, module: gluonts.torch.model.tft.lightning_module.TemporalFusionTransformerLightningModule) -> gluonts.torch.model.predictor.PyTorchPredictor
Create and return a predictor object.
- Parameters
transformation – Transformation to be applied to data before it goes into the model.
module – A trained pl.LightningModule object.
- Returns
A predictor wrapping a nn.Module used for inference.
- Return type
PyTorchPredictor
- create_training_data_loader(data: gluonts.dataset.Dataset, module: gluonts.torch.model.tft.lightning_module.TemporalFusionTransformerLightningModule, shuffle_buffer_length: Optional[int] = None, **kwargs) -> Iterable
Create a data loader for training purposes.
- Parameters
data – Dataset from which to create the data loader.
module – The pl.LightningModule object that will receive the batches from the data loader.
- Returns
The data loader, i.e. an iterable over batches of data.
- Return type
Iterable
- create_transformation() -> gluonts.transform._base.Transformation
Create and return the transformation needed for training and inference.
- Returns
The transformation that will be applied entry-wise to datasets, at training and inference time.
- Return type
Transformation
- create_validation_data_loader(data: gluonts.dataset.Dataset, module: gluonts.torch.model.tft.lightning_module.TemporalFusionTransformerLightningModule, **kwargs) -> Iterable
Create a data loader for validation purposes.
- Parameters
data – Dataset from which to create the data loader.
module – The pl.LightningModule object that will receive the batches from the data loader.
- Returns
The data loader, i.e. an iterable over batches of data.
- Return type
Iterable
- lead_time: int
- prediction_length: int
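As a rough sketch of how the constructor arguments documented for TemporalFusionTransformerEstimator fit together, the following collects a configuration as plain keyword arguments. All concrete values (daily frequency, horizon of 7, the feature sizes, max_epochs) are illustrative assumptions rather than recommendations, and build_estimator is a hypothetical convenience wrapper. The import is deferred into the function so that the configuration itself can be inspected without GluonTS installed.

```python
# Keyword arguments named after the constructor parameters documented above.
# The values are illustrative assumptions only.
estimator_kwargs = dict(
    freq="D",                          # daily data
    prediction_length=7,               # forecast one week ahead
    context_length=28,                 # four weeks of history for the encoder
    quantiles=[0.1, 0.5, 0.9],         # subset of the default quantile grid
    static_dims=[2, 1],                # two static real features (sizes 2 and 1)
    static_cardinalities=[10],         # one static categorical with 10 levels
    dynamic_dims=[1],                  # one known-in-future real feature
    batch_size=32,
    num_batches_per_epoch=50,
    trainer_kwargs={"max_epochs": 5},  # forwarded to pl.Trainer
)


def build_estimator(kwargs):
    """Construct the estimator; requires GluonTS with PyTorch support."""
    from gluonts.torch.model.tft import TemporalFusionTransformerEstimator

    return TemporalFusionTransformerEstimator(**kwargs)
```

With GluonTS installed, build_estimator(estimator_kwargs) would return an estimator exposing the methods listed above, such as create_transformation() and create_lightning_module().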
- class gluonts.torch.model.tft.TemporalFusionTransformerLightningModule(model_kwargs: dict, lr: float = 0.001, patience: int = 10, weight_decay: float = 0.0)
Bases: lightning.pytorch.core.module.LightningModule
A pl.LightningModule class that can be used to train a TemporalFusionTransformerModel with PyTorch Lightning.
This is a thin layer around a (wrapped) TemporalFusionTransformerModel object, that exposes the methods to evaluate training and validation loss.
- Parameters
model_kwargs – Keyword arguments to construct the TemporalFusionTransformerModel to be trained.
lr – Learning rate.
weight_decay – Weight decay regularization parameter.
patience – Patience parameter for learning rate scheduler.
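The patience parameter is described only as belonging to the learning rate scheduler; the scheduler itself is not named here. Assuming a plateau-based scheme in the spirit of PyTorch's ReduceLROnPlateau, its effect can be sketched without any dependencies. The function name, the reduction factor of 0.5, and the loss values are hypothetical choices for illustration.

```python
def simulate_plateau_schedule(losses, lr=1e-3, patience=10, factor=0.5):
    """Return the learning rate in effect after each epoch.

    Assumed rule: the rate is multiplied by `factor` once the monitored
    loss has failed to improve for more than `patience` consecutive epochs.
    """
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history


# The loss stalls at 0.9 for three epochs; with patience=2 the rate is
# halved on the third stalled epoch.
history = simulate_plateau_schedule(
    [1.0, 0.9, 0.9, 0.9, 0.9, 0.8], lr=1e-3, patience=2
)
print(history)
```

Under this assumption, a larger patience tolerates longer stretches without improvement before the learning rate is lowered.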
- class gluonts.torch.model.tft.TemporalFusionTransformerModel(context_length: int, prediction_length: int, d_feat_static_real: Optional[List[int]] = None, c_feat_static_cat: Optional[List[int]] = None, d_feat_dynamic_real: Optional[List[int]] = None, c_feat_dynamic_cat: Optional[List[int]] = None, d_past_feat_dynamic_real: Optional[List[int]] = None, c_past_feat_dynamic_cat: Optional[List[int]] = None, num_heads: int = 4, d_hidden: int = 32, d_var: int = 32, dropout_rate: float = 0.1, distr_output: Optional[gluonts.torch.distributions.output.Output] = None)
Bases: torch.nn.modules.module.Module
Temporal Fusion Transformer neural network.
Partially based on the implementation in github.com/kashif/pytorch-transformer-ts.
Inputs feat_static_real, feat_static_cat and feat_dynamic_real are mandatory. Inputs feat_dynamic_cat, past_feat_dynamic_real and past_feat_dynamic_cat are optional.
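The split between mandatory and optional inputs can be made explicit with a small pre-flight check on a batch before it is passed to the model. This is a sketch under assumptions: the batch is taken to be a plain dict keyed by the argument names of forward(), a value of None is treated as "not provided", and prepare_inputs is a hypothetical helper that is not part of GluonTS.

```python
# Argument names taken from the documented forward() signature.
MANDATORY = (
    "past_target",
    "past_observed_values",
    "feat_static_real",
    "feat_static_cat",
    "feat_dynamic_real",
)
OPTIONAL = (
    "feat_dynamic_cat",
    "past_feat_dynamic_real",
    "past_feat_dynamic_cat",
)


def prepare_inputs(batch: dict) -> dict:
    """Check mandatory inputs and fill missing optional ones with None."""
    missing = [name for name in MANDATORY if batch.get(name) is None]
    if missing:
        raise KeyError(f"missing mandatory inputs: {missing}")
    return {name: batch.get(name) for name in MANDATORY + OPTIONAL}


# Placeholder strings stand in for tensors to keep the sketch dependency-free.
batch = {name: f"<tensor {name}>" for name in MANDATORY}
inputs = prepare_inputs(batch)
print(sorted(name for name, value in inputs.items() if value is None))
```

The returned dict has one entry per forward() argument, so it could be unpacked as keyword arguments when calling the model.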
- describe_inputs(batch_size=1) -> gluonts.model.inputs.InputSpec
- forward(past_target: torch.Tensor, past_observed_values: torch.Tensor, feat_static_real: Optional[torch.Tensor], feat_static_cat: Optional[torch.Tensor], feat_dynamic_real: Optional[torch.Tensor], feat_dynamic_cat: Optional[torch.Tensor] = None, past_feat_dynamic_real: Optional[torch.Tensor] = None, past_feat_dynamic_cat: Optional[torch.Tensor] = None) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
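The difference between calling the instance and calling forward() directly can be illustrated with a toy stand-in. ToyModule below is a deliberately simplified, hypothetical class that only mimics the hook-dispatch idea; it is not PyTorch's actual Module implementation.

```python
class ToyModule:
    """Minimal stand-in mimicking how calling an instance dispatches hooks."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, x):
        # The "recipe" for the computation lives here.
        return 2 * x

    def __call__(self, x):
        # Calling the instance runs forward() and then every registered hook.
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, x, out)
        return out


seen = []
module = ToyModule()
module.register_forward_hook(lambda mod, inp, out: seen.append(out))

module(3)          # goes through __call__, so the hook records the output
module.forward(3)  # bypasses __call__, so the hook is silently skipped
print(seen)
```

Both calls return the same value; only the call through the instance triggers the hook.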
- loss(past_target: torch.Tensor, past_observed_values: torch.Tensor, future_target: torch.Tensor, future_observed_values: torch.Tensor, feat_static_real: torch.Tensor, feat_static_cat: torch.Tensor, feat_dynamic_real: torch.Tensor, feat_dynamic_cat: Optional[torch.Tensor] = None, past_feat_dynamic_real: Optional[torch.Tensor] = None, past_feat_dynamic_cat: Optional[torch.Tensor] = None) -> torch.Tensor
- training: bool