gluonts.mx.trainer.learning_rate_scheduler module

class gluonts.mx.trainer.learning_rate_scheduler.LearningRateReduction(objective: typing_extensions.Literal['min', 'max'], patience: int, base_lr: float = 0.01, decay_factor: float = 0.5, min_lr: float = 0.0)

Bases: gluonts.mx.trainer.callback.Callback

This Callback decreases the learning rate based on the value of a validation metric that is to be optimized (maximized or minimized). The metric value is provided by calling the step method on the scheduler. A patience parameter must be provided; the scheduler reduces the learning rate if the metric shows no improvement within patience consecutive observations.

Examples

patience = 0: learning rate will decrease at every call to step, regardless of the metric value

patience = 1: learning rate is reduced as soon as step is called with a metric value that does not improve on the best value encountered so far

patience = 10: learning rate is reduced if no improvement in the metric is recorded in 10 successive calls to step

Parameters
  • objective – String, can either be “min” or “max”.

  • patience – The patience to observe before reducing the learning rate, nonnegative integer.

  • base_lr – Initial learning rate to be used.

  • decay_factor – Factor (between 0 and 1) by which to decrease the learning rate.

  • min_lr – Lower bound for the learning rate; the learning rate never goes below min_lr.
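
The interplay of base_lr, decay_factor and min_lr can be illustrated with a small standalone sketch. It does not use GluonTS: the helper name decayed_learning_rates is hypothetical, and it assumes that each reduction multiplies the current rate by decay_factor and clips the result from below at min_lr.

```python
from typing import List


def decayed_learning_rates(
    base_lr: float, decay_factor: float, min_lr: float, num_reductions: int
) -> List[float]:
    """Return the learning rate after each of `num_reductions` reductions.

    Assumption: one reduction multiplies the rate by `decay_factor`,
    and the result is clipped from below at `min_lr`.
    """
    assert 0 < decay_factor < 1, "decay_factor must lie strictly between 0 and 1"
    rates = []
    lr = base_lr
    for _ in range(num_reductions):
        lr = max(lr * decay_factor, min_lr)
        rates.append(lr)
    return rates


rates = decayed_learning_rates(
    base_lr=0.01, decay_factor=0.5, min_lr=0.002, num_reductions=4
)
print(rates)
```

With these illustrative values the rate is halved twice and then stays at min_lr for the remaining reductions.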

on_epoch_end(epoch_no: int, epoch_loss: float, training_network: mxnet.gluon.block.HybridBlock, trainer: mxnet.gluon.trainer.Trainer, best_epoch_info: Dict[str, Any], ctx: mxnet.context.Context) → bool

Hook that is called after every epoch. Like on_train_epoch_end and on_validation_epoch_end, it returns a boolean indicating whether training should continue. This hook is always called after on_train_epoch_end and on_validation_epoch_end, regardless of the return values of those hooks.

Parameters
  • epoch_no – The current epoch (the first epoch has epoch_no = 0).

  • epoch_loss – The validation loss that was recorded in the last epoch if validation data was provided. The training loss otherwise.

  • training_network – The network that is being trained.

  • trainer – The trainer which is running the training.

  • best_epoch_info – Aggregate information about the best epoch. Contains keys params_path, epoch_no and score. The score is the best validation loss if validation data is provided or the best training loss otherwise.

  • ctx – The MXNet context used.

Returns

A boolean indicating whether training should continue. Defaults to True.

Return type

bool
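
How a training loop might consume the boolean returned by on_epoch_end can be sketched without MXNet. Everything below is hypothetical and simplified for illustration: the callback StopWhenLossBelow, the function run_epochs, and a reduced hook signature that omits training_network, trainer and ctx. It is not the GluonTS trainer loop.

```python
from typing import Any, Dict, List


class StopWhenLossBelow:
    """Hypothetical callback: asks to stop once the epoch loss is small enough."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold

    def on_epoch_end(
        self, epoch_no: int, epoch_loss: float, best_epoch_info: Dict[str, Any]
    ) -> bool:
        # True means "continue training", False means "stop".
        return epoch_loss >= self.threshold


def run_epochs(losses: List[float], callback: StopWhenLossBelow) -> int:
    """Simulate a training loop and return the number of epochs actually run."""
    best_epoch_info: Dict[str, Any] = {"epoch_no": -1, "score": float("inf")}
    epochs_run = 0
    for epoch_no, epoch_loss in enumerate(losses):  # first epoch has epoch_no = 0
        epochs_run += 1
        if epoch_loss < best_epoch_info["score"]:
            best_epoch_info.update(epoch_no=epoch_no, score=epoch_loss)
        # The hook runs after every epoch; a False return ends training.
        if not callback.on_epoch_end(epoch_no, epoch_loss, best_epoch_info):
            break
    return epochs_run


epochs_run = run_epochs([1.0, 0.6, 0.3, 0.2], StopWhenLossBelow(threshold=0.5))
print(epochs_run)
```

In this toy run the third epoch is the first with a loss below the threshold, so the loop stops after three epochs and the fourth loss is never seen.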

class gluonts.mx.trainer.learning_rate_scheduler.Max(best: float = -inf)

Bases: gluonts.mx.trainer.learning_rate_scheduler.Objective

best: float = -inf
should_update(metric: float) → bool

class gluonts.mx.trainer.learning_rate_scheduler.MetricAttentiveScheduler(patience: gluonts.mx.trainer.learning_rate_scheduler.Patience, learning_rate: float = 0.01, decay_factor: float = 0.5, min_learning_rate: float = 0.0, max_num_decays: Optional[int] = None)

Bases: object

This scheduler decreases the learning rate based on the value of a validation metric that is to be optimized (maximized or minimized). The metric value is provided by calling the step method on the scheduler. A patience parameter must be provided; the scheduler reduces the learning rate if the metric shows no improvement within patience consecutive observations.

Examples

patience = 0: learning rate will decrease at every call to step, regardless of the metric value

patience = 1: learning rate is reduced as soon as step is called with a metric value that does not improve on the best value encountered so far

patience = 10: learning rate is reduced if no improvement in the metric is recorded in 10 successive calls to step

Parameters
  • patience (gluonts.mx.trainer.learning_rate_scheduler.Patience) – Patience tracker that holds the patience to observe before reducing the learning rate (a nonnegative integer) and the Objective (minimize or maximize) for the metric.

  • learning_rate (float) – Initial learning rate to be used.

  • decay_factor (float) – Factor (between 0 and 1) by which to decrease the learning rate.

  • min_learning_rate (float) – Lower bound for the learning rate; the learning rate never goes below min_learning_rate.

decay_factor: float = 0.5
learning_rate: float = 0.01
max_num_decays: Optional[int] = None
min_learning_rate: float = 0.0
patience: gluonts.mx.trainer.learning_rate_scheduler.Patience
step(metric_value: float) → bool

Inform the scheduler of the new value of the metric that is being optimized. This method should be invoked at regular intervals (e.g. at the end of every epoch, after computing a validation score).

Parameters

metric_value – Value of the metric that is being optimized.

Return type

bool, indicating whether training should continue
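
The patience semantics described for this scheduler can be made concrete with a standalone sketch. SketchScheduler is a hypothetical stand-in, not the GluonTS class: it fixes the objective to "min", keeps its own counter instead of a Patience object, and ignores max_num_decays. The rule used for the returned boolean (stop when a reduction is requested while the rate already sits at min_learning_rate) is an assumption made for illustration, not something stated in the documentation above.

```python
import math
from dataclasses import dataclass


@dataclass
class SketchScheduler:
    """Standalone illustration; not the GluonTS MetricAttentiveScheduler."""

    patience: int
    learning_rate: float = 0.01
    decay_factor: float = 0.5
    min_learning_rate: float = 0.0
    best: float = math.inf  # objective fixed to "min" for brevity
    bad_steps: int = 0  # successive calls without improvement

    def step(self, metric_value: float) -> bool:
        if metric_value < self.best:
            self.best = metric_value
            self.bad_steps = 0
        else:
            self.bad_steps += 1

        if self.bad_steps < self.patience:
            return True

        # Patience exhausted (always the case when patience = 0).
        # Assumed stopping rule: stop if there is nothing left to decay.
        should_continue = self.learning_rate > self.min_learning_rate
        self.learning_rate = max(
            self.learning_rate * self.decay_factor, self.min_learning_rate
        )
        self.bad_steps = 0
        return should_continue


scheduler = SketchScheduler(patience=2, learning_rate=0.01, min_learning_rate=0.001)
history = []
for metric in (1.0, 0.9, 0.95, 0.97, 0.8):
    scheduler.step(metric)
    history.append(scheduler.learning_rate)
print(history)
```

In this run the rate is halved once, at the fourth call, because the third and fourth metric values both fail to improve on 0.9. With patience=0 the same sketch halves the rate on every call, matching the first example listed above.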

class gluonts.mx.trainer.learning_rate_scheduler.Min(best: float = inf)

Bases: gluonts.mx.trainer.learning_rate_scheduler.Objective

best: float = inf
should_update(metric: float) → bool

class gluonts.mx.trainer.learning_rate_scheduler.Objective(best: float)

Bases: object

best: float
static from_str(s: typing_extensions.Literal['min', 'max']) → gluonts.mx.trainer.learning_rate_scheduler.Objective
should_update(metric: float) → bool
update(metric: float) → bool
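
The relationship between Objective, Min and Max is not spelled out in the text, so the following is only a plausible standalone sketch. The Sketch* classes and objective_from_str are hypothetical names. It assumes that should_update uses a strict comparison against best, and that update stores an improving metric as the new best and reports whether it did so.

```python
import math
from dataclasses import dataclass


@dataclass
class SketchObjective:
    best: float

    def should_update(self, metric: float) -> bool:
        raise NotImplementedError

    def update(self, metric: float) -> bool:
        # Record the metric as the new best only if it is an improvement.
        if self.should_update(metric):
            self.best = metric
            return True
        return False


@dataclass
class SketchMin(SketchObjective):
    best: float = math.inf  # any finite metric improves on +inf

    def should_update(self, metric: float) -> bool:
        return metric < self.best  # assumed: strict comparison


@dataclass
class SketchMax(SketchObjective):
    best: float = -math.inf  # any finite metric improves on -inf

    def should_update(self, metric: float) -> bool:
        return metric > self.best  # assumed: strict comparison


def objective_from_str(s: str) -> SketchObjective:
    """Hypothetical counterpart of from_str for the sketch classes."""
    return SketchMin() if s == "min" else SketchMax()


objective = objective_from_str("min")
results = [objective.update(m) for m in (3.0, 2.0, 2.5)]
print(results, objective.best)
```

The first two values each improve on the previous best, while 2.5 does not, so best ends at 2.0.
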
class gluonts.mx.trainer.learning_rate_scheduler.Patience(patience: int, objective: gluonts.mx.trainer.learning_rate_scheduler.Objective)

Bases: object

Simple patience tracker.

Given an Objective, it will check whether the metric has improved and update its patience count. A better value sets the patience back to zero.

In addition, reset() must be called explicitly after the patience has been exceeded; otherwise a RuntimeError is raised on the next call to step.

Patience keeps track of the number of invocations to reset, via num_resets.

current_patience: int = 0
exceeded: bool = False
num_resets: int = 0
objective: gluonts.mx.trainer.learning_rate_scheduler.Objective
patience: int
reset() → None
step(metric_value: float) → bool
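
The step/reset contract described above can be sketched in a self-contained way. SketchPatience is a hypothetical stand-in, not the GluonTS class: it inlines a "min" objective instead of taking an Objective, and it assumes that step returns the exceeded flag.

```python
from dataclasses import dataclass


@dataclass
class SketchPatience:
    patience: int
    best: float = float("inf")  # "min" objective inlined for brevity
    current_patience: int = 0
    exceeded: bool = False
    num_resets: int = 0

    def reset(self) -> None:
        self.current_patience = 0
        self.exceeded = False
        self.num_resets += 1

    def step(self, metric_value: float) -> bool:
        if self.exceeded:
            raise RuntimeError("Patience already exceeded; call reset() first.")
        if metric_value < self.best:
            # A better value sets the patience count back to zero.
            self.best = metric_value
            self.current_patience = 0
        else:
            self.current_patience += 1
        self.exceeded = self.current_patience >= self.patience
        return self.exceeded  # assumed return value


tracker = SketchPatience(patience=2)
outcomes = [tracker.step(m) for m in (1.0, 1.2, 1.1)]
print(outcomes)

try:
    tracker.step(0.5)  # not allowed until reset() is called
    raised = False
except RuntimeError:
    raised = True

tracker.reset()
after_reset = tracker.step(1.3)
print(raised, tracker.num_resets, after_reset)
```

Here the patience is exceeded on the third call, the fourth call raises because reset() has not been invoked yet, and after the reset the tracker accepts new metric values again with num_resets equal to 1.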