Trainer callbacks#

This notebook shows how to control the training procedure of MXNet-based models by passing callbacks to the Trainer class. A callback is an object whose hook methods are called at specific points during training. You can use predefined GluonTS callbacks such as TrainingHistory, ModelAveraging or TerminateOnNaN, or you can implement your own callback.

[1]:
from gluonts.dataset.repository import get_dataset

dataset = get_dataset("m4_hourly")
prediction_length = dataset.metadata.prediction_length
freq = dataset.metadata.freq

Using a single callback#

To use callbacks, simply pass them as a list when constructing the Trainer: in the following example, we are using the TrainingHistory callback to record loss values measured during training.

[2]:
from gluonts.mx import SimpleFeedForwardEstimator, Trainer
from gluonts.mx.trainer.callback import TrainingHistory

# defining a callback, which will log the training loss for each epoch
history = TrainingHistory()

trainer = Trainer(epochs=3, callbacks=[history])
estimator = SimpleFeedForwardEstimator(
    prediction_length=prediction_length, trainer=trainer
)

predictor = estimator.train(dataset.train, num_workers=None)
100%|██████████| 50/50 [00:00<00:00, 126.64it/s, epoch=1/3, avg_epoch_loss=5.55]
100%|██████████| 50/50 [00:00<00:00, 140.81it/s, epoch=2/3, avg_epoch_loss=4.7]
100%|██████████| 50/50 [00:00<00:00, 137.59it/s, epoch=3/3, avg_epoch_loss=4.54]

Print the training loss over the epochs:

[3]:
print(history.loss_history)
[5.546479229927063, 4.702160387039185, 4.540805015563965]
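
The recorded history prints as a plain list of floats, so it can be post-processed directly. As a small illustration, the sketch below computes the relative change in loss from one epoch to the next. The values are copied from the output above so that the snippet runs on its own, and `relative_changes` is a helper name introduced here for illustration, not part of GluonTS.

```python
# Loss values copied from the output above so that the snippet is self-contained.
loss_history = [5.546479229927063, 4.702160387039185, 4.540805015563965]


def relative_changes(losses):
    """Relative change between consecutive epochs (negative: the loss went down)."""
    return [(curr - prev) / prev for prev, curr in zip(losses, losses[1:])]


changes = relative_changes(loss_history)
for epoch, change in enumerate(changes, start=2):
    print(f"epoch {epoch}: {change:+.1%} relative to the previous epoch")
```

A shrinking improvement from epoch to epoch, as seen here, is the kind of signal one might later use to decide how many epochs are worth running.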

Using multiple callbacks#

To continue training from an existing predictor, you can use the WarmStart callback. To use more than one callback, provide a list with multiple callback objects:

[4]:
from gluonts.mx.trainer.callback import WarmStart

warm_start = WarmStart(predictor=predictor)

trainer = Trainer(epochs=3, callbacks=[history, warm_start])

estimator = SimpleFeedForwardEstimator(
    prediction_length=prediction_length, trainer=trainer
)

predictor = estimator.train(dataset.train, num_workers=None)
100%|██████████| 50/50 [00:00<00:00, 133.93it/s, epoch=1/3, avg_epoch_loss=4.44]
100%|██████████| 50/50 [00:00<00:00, 138.14it/s, epoch=2/3, avg_epoch_loss=4.4]
100%|██████████| 50/50 [00:00<00:00, 142.33it/s, epoch=3/3, avg_epoch_loss=4.43]

[5]:
print(
    history.loss_history
)  # The training loss history of all 3+3 epochs we trained the model for
[5.546479229927063, 4.702160387039185, 4.540805015563965, 4.439644269943237, 4.402952268123626, 4.425053224563599]

Default callbacks#

In addition to the callbacks you specify, the Trainer class uses the two default callbacks ModelAveraging and LearningRateReduction. You can turn them off by setting add_default_callbacks=False when initializing the Trainer.

[6]:
trainer = Trainer(
    epochs=20, callbacks=[history]
)  # use the TrainingHistory Callback and the default callbacks.
trainer = Trainer(
    epochs=20, callbacks=[history], add_default_callbacks=False
)  # use only the TrainingHistory Callback
trainer = Trainer(epochs=20, add_default_callbacks=False)  # use no callback at all

Custom callbacks#

To implement your own callback, write a class that inherits from gluonts.mx.trainer.callback.Callback and override one or more of its hooks. The hooks take different arguments, so have a look at the Callback base class to see what each one receives. If a hook method with a boolean return value returns False, training stops.
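
To make the role of the boolean return value concrete, here is a minimal pure-Python sketch of the general pattern. It is not the GluonTS implementation: `ToyCallback`, `StopBelowLoss` and `toy_training_loop` are hypothetical names, and the hook signatures are deliberately simplified compared to the real ones.

```python
class ToyCallback:
    """Stand-in for a callback base class: every hook is a no-op by default."""

    def on_train_start(self, max_epochs: int) -> None:
        pass

    def on_epoch_end(self, epoch_no: int, epoch_loss: float) -> bool:
        return True  # True means "keep training"


class StopBelowLoss(ToyCallback):
    """Overrides a single hook: request a stop once the loss is low enough."""

    def __init__(self, threshold: float):
        self.threshold = threshold

    def on_epoch_end(self, epoch_no: int, epoch_loss: float) -> bool:
        return epoch_loss >= self.threshold


def toy_training_loop(epoch_losses, callbacks):
    """Pretend to train on precomputed losses; return the number of epochs run."""
    for callback in callbacks:
        callback.on_train_start(max_epochs=len(epoch_losses))

    epochs_run = 0
    for epoch_no, epoch_loss in enumerate(epoch_losses):
        epochs_run += 1
        # Call the hook on every callback, then stop if any of them returned False.
        results = [cb.on_epoch_end(epoch_no, epoch_loss) for cb in callbacks]
        if not all(results):
            break
    return epochs_run


epochs_run = toy_training_loop(
    epoch_losses=[5.5, 4.7, 4.5, 4.4],
    callbacks=[ToyCallback(), StopBelowLoss(threshold=4.6)],
)
print(epochs_run)  # 3: the third loss (4.5) falls below the threshold
```

In this sketch a single callback returning False is enough to end training, which mirrors the behaviour described above for hooks with a boolean return value.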

Here is an example of a custom callback that terminates training early based on the value of some metric (such as the RMSE). It implements only the on_epoch_end hook, which is called after all batches of an epoch have been processed.

[7]:
from typing import List

import numpy as np
import mxnet as mx

from gluonts.evaluation import Evaluator
from gluonts.dataset.common import Dataset
from gluonts.mx import copy_parameters, GluonPredictor
from gluonts.mx.trainer.callback import Callback


class MetricInferenceEarlyStopping(Callback):
    """
    Early Stopping mechanism based on the prediction network.
    Can be used to base the Early Stopping directly on a metric of interest, instead of on the training/validation loss.
    In the same way as test datasets are used during model evaluation,
    the time series of the validation_dataset can overlap with the train dataset time series,
    except for a prediction_length part at the end of each time series.

    Parameters
    ----------
    validation_dataset
        An out-of-sample dataset which is used to monitor metrics
    predictor
        A gluon predictor, with a prediction network that matches the training network
    evaluator
        The Evaluator used to calculate the validation metrics.
    metric
        The metric on which to base the early stopping.
    patience
        Number of consecutive epochs without an improvement of more than min_delta after which training is stopped.
    min_delta
        Minimum change in the monitored metric that counts as an improvement.
    verbose
        Controls whether the validation metric is printed after each epoch.
    minimize_metric
        Whether lower values of the metric are better (True) or higher values are better (False).
    restore_best_network
        Controls whether the best network, as assessed by the validation metric, is restored after training.
    num_samples
        The number of samples drawn to calculate the inference metrics.
    """

    def __init__(
        self,
        validation_dataset: Dataset,
        predictor: GluonPredictor,
        evaluator: Evaluator = Evaluator(num_workers=None),
        metric: str = "MSE",
        patience: int = 10,
        min_delta: float = 0.0,
        verbose: bool = True,
        minimize_metric: bool = True,
        restore_best_network: bool = True,
        num_samples: int = 100,
    ):
        assert patience >= 0, "EarlyStopping Callback patience needs to be >= 0"
        assert min_delta >= 0, "EarlyStopping Callback min_delta needs to be >= 0.0"
        assert num_samples >= 1, "EarlyStopping Callback num_samples needs to be >= 1"

        self.validation_dataset = list(validation_dataset)
        self.predictor = predictor
        self.evaluator = evaluator
        self.metric = metric
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.restore_best_network = restore_best_network
        self.num_samples = num_samples

        if minimize_metric:
            self.best_metric_value = np.inf
            self.is_better = np.less
        else:
            self.best_metric_value = -np.inf
            self.is_better = np.greater

        self.validation_metric_history: List[float] = []
        self.best_network = None
        self.n_stale_epochs = 0

    def on_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: mx.gluon.nn.HybridBlock,
        trainer: mx.gluon.Trainer,
        best_epoch_info: dict,
        ctx: mx.Context,
    ) -> bool:
        should_continue = True
        copy_parameters(training_network, self.predictor.prediction_net)

        from gluonts.evaluation.backtest import make_evaluation_predictions

        forecast_it, ts_it = make_evaluation_predictions(
            dataset=self.validation_dataset,
            predictor=self.predictor,
            num_samples=self.num_samples,
        )

        agg_metrics, item_metrics = self.evaluator(ts_it, forecast_it)
        current_metric_value = agg_metrics[self.metric]
        self.validation_metric_history.append(current_metric_value)

        if self.verbose:
            print(
                f"Validation metric {self.metric}: {current_metric_value}, best: {self.best_metric_value}"
            )

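        # An improvement only counts if it exceeds min_delta; the margin is
        # subtracted when minimizing (np.less) and added when maximizing.
        margin = -self.min_delta if self.is_better is np.less else self.min_delta
        if self.is_better(current_metric_value, self.best_metric_value + margin):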
            self.best_metric_value = current_metric_value

            if self.restore_best_network:
                training_network.save_parameters("best_network.params")

            self.n_stale_epochs = 0
        else:
            self.n_stale_epochs += 1
            # Use >= so that patience=0 (allowed by the assert above) also stops.
            if self.n_stale_epochs >= self.patience:
                should_continue = False
                print(
                    f"EarlyStopping callback initiated stop of training at epoch {epoch_no}."
                )

                if self.restore_best_network:
                    print(
                        f"Restoring best network from epoch {epoch_no - self.n_stale_epochs}."
                    )
                    training_network.load_parameters("best_network.params")

        return should_continue

We can now use the custom callback as follows. Note that we run only a very small number of epochs to keep the runtime of the notebook manageable. With the default patience of 10, the callback cannot stop a 5-epoch run early, so increase the number of epochs (or lower the patience) to properly test the effectiveness of the callback.

[8]:
estimator = SimpleFeedForwardEstimator(prediction_length=prediction_length)
training_network = estimator.create_training_network()
transformation = estimator.create_transformation()

predictor = estimator.create_predictor(
    transformation=transformation, trained_network=training_network
)

es_callback = MetricInferenceEarlyStopping(
    validation_dataset=dataset.test, predictor=predictor, metric="MSE"
)

trainer = Trainer(epochs=5, callbacks=[es_callback])

estimator.trainer = trainer

pred = estimator.train(dataset.train)
100%|██████████| 50/50 [00:00<00:00, 136.43it/s, epoch=1/5, avg_epoch_loss=5.55]
Running evaluation: 414it [00:02, 153.83it/s]
Validation metric MSE: 16590203.479222953, best: inf
100%|██████████| 50/50 [00:00<00:00, 137.55it/s, epoch=2/5, avg_epoch_loss=4.69]
Running evaluation: 414it [00:02, 156.87it/s]
Validation metric MSE: 9028248.932885194, best: 16590203.479222953
100%|██████████| 50/50 [00:00<00:00, 139.02it/s, epoch=3/5, avg_epoch_loss=4.79]
Running evaluation: 414it [00:02, 157.63it/s]
Validation metric MSE: 16308248.984650122, best: 9028248.932885194
100%|██████████| 50/50 [00:00<00:00, 134.38it/s, epoch=4/5, avg_epoch_loss=4.62]
Running evaluation: 414it [00:02, 157.42it/s]
Validation metric MSE: 10582128.785360953, best: 9028248.932885194
100%|██████████| 50/50 [00:00<00:00, 138.63it/s, epoch=5/5, avg_epoch_loss=4.3]
Running evaluation: 414it [00:02, 157.09it/s]
Validation metric MSE: 10019828.282515068, best: 9028248.932885194
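
To see what the stopping rule would have done with a smaller patience, the patience logic can be replayed on the validation MSE values printed above without retraining. The sketch below is a simplified, standalone re-implementation for a metric that is minimized; `first_stop_epoch` is a hypothetical helper, and it assumes a 0-based epoch index.

```python
import math

# Validation MSE per epoch, copied from the output above.
validation_mse = [
    16590203.479222953,
    9028248.932885194,
    16308248.984650122,
    10582128.785360953,
    10019828.282515068,
]


def first_stop_epoch(metric_values, patience, min_delta=0.0):
    """Return the 0-based epoch at which a minimizing early stopper stops, or None."""
    best = math.inf
    n_stale_epochs = 0
    for epoch_no, value in enumerate(metric_values):
        if value < best - min_delta:
            # Improvement: remember the new best value and reset the counter.
            best = value
            n_stale_epochs = 0
        else:
            n_stale_epochs += 1
            if n_stale_epochs >= patience:
                return epoch_no
    return None


print(first_stop_epoch(validation_mse, patience=10))  # None
print(first_stop_epoch(validation_mse, patience=2))  # 3
```

With a patience of 2, the replay stops at epoch index 3, two epochs after the best value was observed at epoch index 1. With the default patience of 10, five epochs are not enough to trigger a stop.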