gluonts.mx.trainer.model_averaging module
- class gluonts.mx.trainer.model_averaging.AveragingStrategy(num_models: int = 5, metric: str = 'score', maximize: bool = False)
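Bases: object
Base class for strategies that select checkpoints by a tracked metric and average their parameters.
- Parameters
num_models – Maximum number of best checkpoints to average.
metric – Key of the metric in the checkpoint information dictionaries.
maximize – Set to True if higher metric values are better; keep False for metrics to be minimized, such as a loss.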
- apply(model_path: str) → str
Averages the parameters of the serialized models in model_path, using the averaging strategy to select and weight the checkpoints by the tracked metric. IMPORTANT: depending on the metric, the user may want to minimize or maximize it. Set the maximize flag accordingly.
- Parameters
model_path – Path to the models directory.
- Returns
Path to the file with the averaged model.
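Getting the maximize flag wrong does not raise an error. It silently ranks the checkpoints in reverse. The following minimal sketch in plain Python illustrates this. The helper rank_checkpoints and the dictionary keys (score, epoch_no, params_path) are hypothetical, chosen for illustration. The sketch does not claim to reproduce the library's internals.

```python
from typing import Dict, List


def rank_checkpoints(
    checkpoints: List[Dict], metric: str = "score", maximize: bool = False
) -> List[Dict]:
    """Hypothetical helper: order checkpoints from best to worst."""
    # maximize=False: smallest metric first (e.g. a loss).
    # maximize=True: largest metric first (e.g. an accuracy-style score).
    return sorted(checkpoints, key=lambda c: c[metric], reverse=maximize)


# Hypothetical checkpoint info where "score" is a validation loss (lower is better).
checkpoints = [
    {"epoch_no": 0, "score": 0.9, "params_path": "epoch0.params"},
    {"epoch_no": 1, "score": 0.4, "params_path": "epoch1.params"},
    {"epoch_no": 2, "score": 0.6, "params_path": "epoch2.params"},
]

print(rank_checkpoints(checkpoints)[0]["epoch_no"])                 # 1: lowest loss
print(rank_checkpoints(checkpoints, maximize=True)[0]["epoch_no"])  # 0: highest loss, wrong for a loss
```

With a loss metric, maximize=True would put the worst epoch first. Every top-N strategy built on this ordering would then average the worst checkpoints.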
- average(param_paths: List[str], weights: List[float]) → Dict
Averages parameters from a list of .params file paths.
- Parameters
param_paths – List of paths to parameter files.
weights – List of weights for the parameter average.
- Returns
Averaged parameter dictionary.
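The method loads each parameter file and averages every named parameter across files. The following sketch shows that logic under one assumption: parameters are stored as NumPy .npz files (a stand-in for MXNet .params files) so the example runs without MXNet. It is an illustration, not the library's implementation.

```python
import os
import tempfile
from typing import Dict, List

import numpy as np


def average(param_paths: List[str], weights: List[float]) -> Dict[str, np.ndarray]:
    """Weighted average of parameter dicts stored in .npz files (stand-in for .params)."""
    if len(param_paths) != len(weights):
        raise ValueError("need exactly one weight per parameter file")
    loaded = []
    for path in param_paths:
        with np.load(path) as data:  # NpzFile: maps parameter name -> array
            loaded.append({name: data[name] for name in data.files})
    names = set(loaded[0])
    if any(set(params) != names for params in loaded):
        raise ValueError("all files must contain the same parameter names")
    # Average each named parameter element-wise across files.
    return {
        name: sum(w * params[name] for params, w in zip(loaded, weights))
        for name in loaded[0]
    }


with tempfile.TemporaryDirectory() as d:
    p1, p2 = os.path.join(d, "a.npz"), os.path.join(d, "b.npz")
    np.savez(p1, weight=np.array([1.0, 2.0]), bias=np.array([0.0]))
    np.savez(p2, weight=np.array([3.0, 4.0]), bias=np.array([1.0]))
    avg = average([p1, p2], [0.5, 0.5])

print(avg["weight"], avg["bias"])  # [2. 3.] [0.5]
```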
- static average_arrays(arrays: List[mxnet.ndarray.ndarray.NDArray], weights: List[float]) → mxnet.ndarray.ndarray.NDArray
Takes a list of arrays of the same shape and computes their element-wise weighted average.
- Parameters
arrays – List of NDArrays with the same shape that will be averaged.
weights – List of weights for the parameter average.
- Returns
The weighted average of the NDArrays, in the same context as arrays[0].
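The element-wise weighted average is sum_i weights[i] * arrays[i]. The following NumPy sketch stands in for the NDArray version. NumPy has no device context, so the "same context as arrays[0]" aspect is not modelled. The weights are used as given and are not renormalized.

```python
from typing import List

import numpy as np


def average_arrays(arrays: List[np.ndarray], weights: List[float]) -> np.ndarray:
    """Element-wise weighted average of same-shaped arrays (NumPy stand-in)."""
    if len(arrays) != len(weights):
        raise ValueError("need exactly one weight per array")
    if any(a.shape != arrays[0].shape for a in arrays):
        raise ValueError("all arrays must have the same shape")
    # Stack to shape (N, *shape), then contract the leading axis with the weights.
    return np.tensordot(np.asarray(weights, dtype=float), np.stack(arrays), axes=1)


a = np.array([[1.0, 2.0], [3.0, 4.0]])
b = np.array([[5.0, 6.0], [7.0, 8.0]])
print(average_arrays([a, b], [0.25, 0.75]))  # 0.25 * a + 0.75 * b = [[4, 5], [6, 7]]
```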
- static get_checkpoint_information(model_path: str) → List[Dict]
Collects the information about the checkpoints saved in model_path.
- Parameters
model_path – Path to the models directory.
- Returns
List of checkpoint information dictionaries (metric, epoch_no, checkpoint path).
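The checkpoint information is stored as JSON files in the model directory. The following minimal sketch collects one dictionary per file. The file-name pattern "*-epoch-info.json" and the keys params_path, epoch_no and score are assumptions made for illustration.

```python
import glob
import json
import os
import tempfile
from typing import Dict, List


def get_checkpoint_information(model_path: str) -> List[Dict]:
    # Assumption: one JSON file per epoch, named "<prefix>-epoch-info.json".
    pattern = os.path.join(model_path, "*-epoch-info.json")
    infos = []
    for path in sorted(glob.glob(pattern)):
        with open(path) as f:
            infos.append(json.load(f))
    if not infos:
        raise FileNotFoundError(f"no epoch info files found in {model_path}")
    return infos


with tempfile.TemporaryDirectory() as model_dir:
    for epoch, loss in enumerate([0.9, 0.4]):
        info = {"params_path": f"epoch-{epoch}.params", "epoch_no": epoch, "score": loss}
        with open(os.path.join(model_dir, f"epoch-{epoch}-epoch-info.json"), "w") as f:
            json.dump(info, f)
    infos = get_checkpoint_information(model_dir)

print([(i["epoch_no"], i["score"]) for i in infos])  # [(0, 0.9), (1, 0.4)]
```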
- select_checkpoints(checkpoints: List[Dict]) → Tuple[List[str], List[float]]
Selects checkpoints and computes weights for the selected checkpoints.
- Parameters
checkpoints – List of checkpoint information dictionaries.
- Returns
List of selected checkpoint paths and list of corresponding weights.
- class gluonts.mx.trainer.model_averaging.ModelAveraging(avg_strategy: gluonts.mx.trainer.model_averaging.AveragingStrategy)
Bases: gluonts.mx.trainer.callback.Callback
Callback that implements model averaging. At the end of training, it selects the checkpoints with the best values of the strategy's metric and computes their plain or weighted parameter average, depending on the chosen avg_strategy.
- Parameters
avg_strategy – AveragingStrategy, one of SelectNBestSoftmax or SelectNBestMean from gluonts.mx.trainer.model_averaging.
- on_train_end(training_network: mxnet.gluon.block.HybridBlock, temporary_dir: str, ctx: Optional[mxnet.context.Context] = None) → None
Hook that is called after training is finished. This is the last hook to be called.
- Parameters
training_network – The network that was trained.
temporary_dir – The directory where model parameters are logged throughout training.
ctx – The MXNet context to use (optional).
- class gluonts.mx.trainer.model_averaging.SelectNBestMean(num_models: int = 5, metric: str = 'score', maximize: bool = False)
Bases: gluonts.mx.trainer.model_averaging.AveragingStrategy
- select_checkpoints(checkpoints: List[Dict]) → Tuple[List[str], List[float]]
Selects the checkpoints with the best metric values. The weights are equal for all selected checkpoints, i.e., w_i = 1/N, where N is the number of selected checkpoints.
- Parameters
checkpoints – List of checkpoint information dictionaries.
- Returns
List of selected checkpoint paths and list of corresponding weights.
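The following minimal sketch shows top-N selection with equal weights. The function name select_n_best_mean and the keys score and params_path are hypothetical. N is taken as the number of checkpoints actually selected, which can be smaller than num_models.

```python
from typing import Dict, List, Tuple


def select_n_best_mean(
    checkpoints: List[Dict],
    num_models: int = 5,
    metric: str = "score",
    maximize: bool = False,
) -> Tuple[List[str], List[float]]:
    # Best first: ascending for maximize=False, descending for maximize=True.
    top = sorted(checkpoints, key=lambda c: c[metric], reverse=maximize)[:num_models]
    if not top:
        raise ValueError("no checkpoints to select from")
    # Equal weights w_i = 1/N over the selected checkpoints.
    weights = [1.0 / len(top)] * len(top)
    return [c["params_path"] for c in top], weights


checkpoints = [
    {"epoch_no": 0, "score": 0.9, "params_path": "epoch0.params"},
    {"epoch_no": 1, "score": 0.4, "params_path": "epoch1.params"},
    {"epoch_no": 2, "score": 0.6, "params_path": "epoch2.params"},
]
paths, weights = select_n_best_mean(checkpoints, num_models=2)
print(paths, weights)  # ['epoch1.params', 'epoch2.params'] [0.5, 0.5]
```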
- class gluonts.mx.trainer.model_averaging.SelectNBestSoftmax(num_models: int = 5, metric: str = 'score', maximize: bool = False)
Bases: gluonts.mx.trainer.model_averaging.AveragingStrategy
- select_checkpoints(checkpoints: List[Dict]) → Tuple[List[str], List[float]]
Selects the checkpoints with the best metric values. The weights are the softmax of the metric values v_i, i.e.,
w_i = exp(v_i) / sum_j exp(v_j) if maximize=True,
w_i = exp(-v_i) / sum_j exp(-v_j) if maximize=False.
- Parameters
checkpoints – List of checkpoint information dictionaries.
- Returns
List of selected checkpoint paths and list of corresponding weights.
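The following sketch shows the softmax weighting over the top-N checkpoints. The function name select_n_best_softmax and the keys score and params_path are hypothetical. Before exponentiating, the sketch subtracts the largest logit for numerical stability. This does not change the softmax result.

```python
import math
from typing import Dict, List, Tuple


def select_n_best_softmax(
    checkpoints: List[Dict],
    num_models: int = 5,
    metric: str = "score",
    maximize: bool = False,
) -> Tuple[List[str], List[float]]:
    top = sorted(checkpoints, key=lambda c: c[metric], reverse=maximize)[:num_models]
    if not top:
        raise ValueError("no checkpoints to select from")
    # Logits are v_i when maximizing and -v_i when minimizing.
    sign = 1.0 if maximize else -1.0
    logits = [sign * c[metric] for c in top]
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # shift by the max; softmax is unchanged
    total = sum(exps)
    return [c["params_path"] for c in top], [e / total for e in exps]


checkpoints = [
    {"epoch_no": 0, "score": 2.0, "params_path": "epoch0.params"},
    {"epoch_no": 1, "score": 1.0, "params_path": "epoch1.params"},
]
paths, weights = select_n_best_softmax(checkpoints, num_models=2)
print(paths, weights)  # ['epoch1.params', 'epoch0.params'], weights ≈ [0.731, 0.269]
```

The weights depend on the scale of the metric. A metric difference of 1 between two checkpoints gives a weight ratio of e.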
- gluonts.mx.trainer.model_averaging.save_epoch_info(tmp_path: str, epoch_info: dict) → None
Writes the current epoch's information to a JSON file in the model path.
- Parameters
tmp_path – Temporary base path to save the epoch info.
epoch_info – Epoch information dictionary containing the parameters path, the epoch number and the tracking metric value.
- Return type
None
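The following minimal sketch writes the epoch information as JSON and reads it back. The output file name (tmp_path followed by "-epoch-info.json") and the dictionary keys params_path, epoch_no and score are assumptions for illustration.

```python
import json
import os
import tempfile


def save_epoch_info(tmp_path: str, epoch_info: dict) -> None:
    # Assumption: the file name is tmp_path + "-epoch-info.json".
    with open(f"{tmp_path}-epoch-info.json", "w") as f:
        json.dump(epoch_info, f)


with tempfile.TemporaryDirectory() as d:
    base = os.path.join(d, "epoch-3")
    save_epoch_info(base, {"params_path": f"{base}.params", "epoch_no": 3, "score": 0.42})
    with open(f"{base}-epoch-info.json") as f:
        restored = json.load(f)

print(restored["epoch_no"], restored["score"])  # 3 0.42
```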