gluonts.mx.trainer.model_averaging module

class gluonts.mx.trainer.model_averaging.AveragingStrategy(num_models: int = 5, metric: str = 'score', maximize: bool = False)

Bases: object

apply(model_path: str) -> str

Averages the parameters of the serialized models in model_path, selecting and weighting the checkpoints according to the averaging strategy and the chosen metric. IMPORTANT: depending on the metric, the user may want to minimize it (e.g., a loss) or maximize it (e.g., an accuracy). Set the maximize flag accordingly.

Parameters

model_path – Path to the models directory.

Returns

Path to the file containing the averaged model.

average(param_paths: List[str], weights: List[float]) -> Dict

Averages parameters from a list of .params file paths.

Parameters
  • param_paths – List of paths to parameter files.

  • weights – List of weights for the parameter average.

Returns

Dictionary mapping each parameter name to its averaged value.
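The key-wise averaging that average performs can be sketched in plain Python as below. This is a hypothetical stand-in, not the library code: it skips loading the .params files (in MXNet that step would typically use mx.nd.load) and instead takes already-loaded dictionaries that map parameter names to flat lists of floats. The function name average_params and the example parameter names are illustrative only.

```python
from typing import Dict, List, Sequence


def average_params(
    params_list: Sequence[Dict[str, List[float]]], weights: Sequence[float]
) -> Dict[str, List[float]]:
    """Key-wise weighted average of already-loaded parameter dictionaries.

    Hypothetical sketch: plain lists of floats stand in for MXNet NDArrays.
    """
    if not params_list or len(params_list) != len(weights):
        raise ValueError("need exactly one weight per parameter dictionary")
    keys = params_list[0].keys()
    if any(p.keys() != keys for p in params_list):
        raise ValueError("all checkpoints must contain the same parameters")
    total = sum(weights)
    averaged = {}
    for name in keys:
        arrays = [p[name] for p in params_list]
        # Element-wise weighted average of this parameter across checkpoints.
        averaged[name] = [
            sum(w * a[i] for a, w in zip(arrays, weights)) / total
            for i in range(len(arrays[0]))
        ]
    return averaged


ckpt_a = {"dense_weight": [1.0, 1.0], "dense_bias": [0.0]}
ckpt_b = {"dense_weight": [3.0, 5.0], "dense_bias": [2.0]}
print(average_params([ckpt_a, ckpt_b], [0.5, 0.5]))
# {'dense_weight': [2.0, 3.0], 'dense_bias': [1.0]}
```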

static average_arrays(arrays: List[mxnet.ndarray.ndarray.NDArray], weights: List[float]) -> mxnet.ndarray.ndarray.NDArray

Takes a list of arrays of the same shape and computes their element-wise weighted average.

Parameters
  • arrays – List of NDArrays with the same shape that will be averaged.

  • weights – List of weights for the parameter average.

Returns

The weighted average of the NDArrays, in the same context as arrays[0].
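A minimal sketch of the element-wise weighted average, assuming plain Python lists of floats in place of NDArrays so that it runs without MXNet. It divides by the sum of the weights; when the weights already sum to 1, as they do for the strategies in this module, that division is a no-op.

```python
from typing import List, Sequence


def average_arrays(arrays: Sequence[Sequence[float]], weights: Sequence[float]) -> List[float]:
    """Element-wise weighted average of equally shaped flat arrays (sketch)."""
    if not arrays or len(arrays) != len(weights):
        raise ValueError("need exactly one weight per array")
    length = len(arrays[0])
    if any(len(a) != length for a in arrays):
        raise ValueError("all arrays must have the same shape")
    total = sum(weights)
    if total == 0:
        raise ValueError("weights must not sum to zero")
    # For each position i: sum_k w_k * a_k[i], normalized by the total weight.
    return [sum(w * a[i] for a, w in zip(arrays, weights)) / total for i in range(length)]


print(average_arrays([[1.0, 2.0], [3.0, 6.0]], [0.5, 0.5]))  # [2.0, 4.0]
```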

static get_checkpoint_information(model_path: str) -> List[Dict]

Collects the checkpoint information dictionaries stored in the models directory.

Parameters

model_path – Path to the models directory.

Returns

List of checkpoint information dictionaries (metric, epoch_no, checkpoint path).
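The following sketch shows one way the checkpoint information could be collected. The on-disk layout is an assumption, not taken from the library: one JSON file per epoch, named with a hypothetical *-checkpoint.json suffix and holding keys such as params_path, epoch_no, and score.

```python
import glob
import json
import os
import tempfile
from typing import Dict, List


def get_checkpoint_information(model_path: str) -> List[Dict]:
    """Load every per-epoch record found under model_path (sketch).

    Assumption: records live in files matching "*-checkpoint.json".
    """
    files = sorted(glob.glob(os.path.join(model_path, "*-checkpoint.json")))
    if not files:
        raise FileNotFoundError(f"no checkpoint information found in {model_path}")
    records = []
    for path in files:
        with open(path) as f:
            records.append(json.load(f))
    return records


# Usage: write two hypothetical epoch records, then read them back.
with tempfile.TemporaryDirectory() as model_dir:
    for epoch, score in enumerate([0.9, 0.7]):
        prefix = os.path.join(model_dir, f"prefix-{epoch:04d}")
        with open(prefix + "-checkpoint.json", "w") as f:
            json.dump({"params_path": prefix + ".params", "epoch_no": epoch, "score": score}, f)
    print([(r["epoch_no"], r["score"]) for r in get_checkpoint_information(model_dir)])
    # [(0, 0.9), (1, 0.7)]
```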

select_checkpoints(checkpoints: List[Dict]) -> Tuple[List[str], List[float]]

Selects checkpoints and computes weights for the selected checkpoints.

Parameters

checkpoints – List of checkpoint information dictionaries.

Returns

List of selected checkpoint paths and list of corresponding weights.
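Selecting "the best" checkpoints means ranking them by the tracked metric, with the maximize flag deciding the sort direction. A minimal sketch of that ranking step follows; the helper name select_top_n is hypothetical, and the records are assumed to carry params_path and score keys.

```python
from typing import Dict, List


def select_top_n(
    checkpoints: List[Dict], num_models: int = 5, metric: str = "score", maximize: bool = False
) -> List[Dict]:
    """Return the num_models best checkpoint records (hypothetical helper).

    maximize=False treats lower values as better (e.g. a loss);
    maximize=True treats higher values as better (e.g. an accuracy).
    """
    ranked = sorted(checkpoints, key=lambda c: c[metric], reverse=maximize)
    return ranked[:num_models]


ckpts = [
    {"params_path": "e0.params", "score": 0.9},
    {"params_path": "e1.params", "score": 0.4},
    {"params_path": "e2.params", "score": 0.6},
]
print([c["params_path"] for c in select_top_n(ckpts, num_models=2)])
# ['e1.params', 'e2.params']
print([c["params_path"] for c in select_top_n(ckpts, num_models=2, maximize=True)])
# ['e0.params', 'e2.params']
```

Choosing the wrong direction silently keeps the worst checkpoints, which is why the maximize flag has to match the metric.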

class gluonts.mx.trainer.model_averaging.ModelAveraging(avg_strategy: gluonts.mx.trainer.model_averaging.AveragingStrategy)

Bases: gluonts.mx.trainer.callback.Callback

Callback that implements model averaging. It selects the checkpoints with the best values of the tracked metric (e.g., the loss) and computes either the plain model average or a weighted model average, depending on the chosen avg_strategy.

Parameters

avg_strategy – AveragingStrategy, one of SelectNBestSoftmax or SelectNBestMean from gluonts.mx.trainer.model_averaging.

on_train_end(training_network: mxnet.gluon.block.HybridBlock, temporary_dir: str, ctx: Optional[mxnet.context.Context] = None) -> None

Hook that is called after training is finished. This is the last hook to be called.

Parameters
  • training_network – The network that was trained.

  • temporary_dir – The directory where model parameters are logged throughout training.

  • ctx – The MXNet context to use.

class gluonts.mx.trainer.model_averaging.SelectNBestMean(num_models: int = 5, metric: str = 'score', maximize: bool = False)

Bases: gluonts.mx.trainer.model_averaging.AveragingStrategy

select_checkpoints(checkpoints: List[Dict]) -> Tuple[List[str], List[float]]

Selects the checkpoints with the best metric values. All selected checkpoints receive the same weight, i.e., $w_i = 1/N$, where $N$ is the number of selected checkpoints.

Parameters

checkpoints – List of checkpoint information dictionaries.

Returns

List of selected checkpoint paths and list of corresponding weights.
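A plain-Python sketch of the equal-weight selection, assuming checkpoint records with params_path and score keys; the function name is hypothetical. When fewer than num_models checkpoints exist, this sketch simply uses all of them, so N is the number actually selected.

```python
from typing import Dict, List, Tuple


def select_n_best_mean(
    checkpoints: List[Dict],
    num_models: int = 5,
    metric: str = "score",
    maximize: bool = False,
) -> Tuple[List[str], List[float]]:
    """Pick the best checkpoints and give each the same weight 1/N (sketch)."""
    ranked = sorted(checkpoints, key=lambda c: c[metric], reverse=maximize)
    best = ranked[:num_models]
    if not best:
        raise ValueError("no checkpoints to select from")
    n = len(best)  # may be smaller than num_models if few checkpoints exist
    return [c["params_path"] for c in best], [1.0 / n] * n


ckpts = [
    {"params_path": "e0.params", "score": 0.9},
    {"params_path": "e1.params", "score": 0.4},
    {"params_path": "e2.params", "score": 0.6},
]
print(select_n_best_mean(ckpts, num_models=2))
# (['e1.params', 'e2.params'], [0.5, 0.5])
```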

class gluonts.mx.trainer.model_averaging.SelectNBestSoftmax(num_models: int = 5, metric: str = 'score', maximize: bool = False)

Bases: gluonts.mx.trainer.model_averaging.AveragingStrategy

select_checkpoints(checkpoints: List[Dict]) -> Tuple[List[str], List[float]]

Selects the checkpoints with the best metric values. The weights are the softmax of the metric values $v_i$ of the selected checkpoints, i.e., $w_i = \exp(v_i) / \sum_j \exp(v_j)$ if maximize=True, and $w_i = \exp(-v_i) / \sum_j \exp(-v_j)$ if maximize=False.

Parameters

checkpoints – List of checkpoint information dictionaries.

Returns

List of selected checkpoint paths and list of corresponding weights.
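A sketch of the softmax weighting under the same assumptions about the record keys (the function name is hypothetical). It negates the metric when maximize=False, so lower losses get larger weights, and it subtracts the largest logit before exponentiating, which leaves the weights unchanged but avoids overflow for large metric values.

```python
import math
from typing import Dict, List, Tuple


def select_n_best_softmax(
    checkpoints: List[Dict],
    num_models: int = 5,
    metric: str = "score",
    maximize: bool = False,
) -> Tuple[List[str], List[float]]:
    """Pick the best checkpoints and weight them by a softmax of the metric (sketch)."""
    ranked = sorted(checkpoints, key=lambda c: c[metric], reverse=maximize)
    best = ranked[:num_models]
    if not best:
        raise ValueError("no checkpoints to select from")
    # Logits are v_i when maximizing and -v_i when minimizing.
    logits = [c[metric] if maximize else -c[metric] for c in best]
    # Shift by the max for numerical stability; the softmax is shift-invariant.
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [c["params_path"] for c in best], [e / total for e in exps]


ckpts = [
    {"params_path": "e0.params", "score": 0.9},
    {"params_path": "e1.params", "score": 0.4},
    {"params_path": "e2.params", "score": 0.6},
]
paths, weights = select_n_best_softmax(ckpts, num_models=2)
print(paths, [round(w, 4) for w in weights])
# ['e1.params', 'e2.params'] [0.5498, 0.4502]
```

Because the softmax depends only on differences between metric values (here 0.2), losses on a small scale yield weights close to uniform, while metrics on a larger scale concentrate almost all of the weight on the best checkpoint.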

gluonts.mx.trainer.model_averaging.save_epoch_info(tmp_path: str, epoch_info: dict) -> None

Writes the current epoch information to a JSON file in the model path.

Parameters
  • tmp_path – Temporary base path to save the epoch info.

  • epoch_info – Epoch information dictionary containing the parameters path, the epoch number and the tracking metric value.

Return type

None
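A minimal sketch of writing the epoch record, assuming the file is named f"{tmp_path}-checkpoint.json" (a hypothetical naming scheme, not confirmed by this page) and that epoch_info contains only JSON-serializable values.

```python
import json
import os
import tempfile


def save_epoch_info(tmp_path: str, epoch_info: dict) -> None:
    """Write one epoch's record as JSON (sketch).

    Assumption: the target file is "<tmp_path>-checkpoint.json".
    """
    with open(f"{tmp_path}-checkpoint.json", "w") as f:
        json.dump(epoch_info, f)


with tempfile.TemporaryDirectory() as model_dir:
    prefix = os.path.join(model_dir, "prefix-0003")
    save_epoch_info(prefix, {"params_path": prefix + ".params", "epoch_no": 3, "score": 0.42})
    with open(prefix + "-checkpoint.json") as f:
        print(json.load(f)["epoch_no"])  # 3
```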