gluonts.torch.model.predictor module#

class gluonts.torch.model.predictor.PyTorchPredictor(input_names: List[str], prediction_net: torch.nn.modules.module.Module, batch_size: int, prediction_length: int, input_transform: gluonts.transform._base.Transformation, forecast_generator: gluonts.model.forecast_generator.ForecastGenerator = gluonts.model.forecast_generator.SampleForecastGenerator(), output_transform: Optional[Callable[[Dict[str, Any], numpy.ndarray], numpy.ndarray]] = None, lead_time: int = 0, device: Union[str, torch.device] = 'auto')[source]#

Bases: gluonts.model.predictor.RepresentablePredictor
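A PyTorchPredictor is typically not constructed by hand; it is returned by the train method of a torch-based estimator. A minimal sketch, assuming DeepAREstimator from gluonts.torch and a small illustrative dataset (values, frequency, and hyperparameters are placeholders):

```python
from gluonts.dataset.common import ListDataset
from gluonts.torch.model.deepar import DeepAREstimator

# Illustrative toy dataset; values and frequency are arbitrary.
train_data = ListDataset(
    [{"start": "2024-01-01", "target": [float(i % 7) for i in range(200)]}],
    freq="D",
)

estimator = DeepAREstimator(
    freq="D",
    prediction_length=7,
    trainer_kwargs={"max_epochs": 1},  # kept tiny for illustration
)

# `train` returns a PyTorchPredictor wrapping the trained network.
predictor = estimator.train(train_data)
```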

classmethod deserialize(path: pathlib.Path, device: Optional[Union[torch.device, str]] = None) gluonts.torch.model.predictor.PyTorchPredictor[source]#

Load a serialized predictor from the given path.

Parameters
  • path – Path to the files of the serialized predictor.

  • device – Optional device on which to load the predictor. If none is passed, the GPU is used if available, and the CPU otherwise.
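
A brief sketch of loading a previously serialized predictor, assuming the directory my_model was written by serialize:

```python
from pathlib import Path

from gluonts.torch.model.predictor import PyTorchPredictor

# Assumes `my_model/` contains the files written by predictor.serialize(...).
predictor = PyTorchPredictor.deserialize(Path("my_model"), device="cpu")
```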

property network: torch.nn.modules.module.Module#
predict(dataset: gluonts.dataset.Dataset, num_samples: Optional[int] = None) Iterator[gluonts.model.forecast.Forecast][source]#

Compute forecasts for the time series in the provided dataset.

Parameters
  • dataset – The dataset containing the time series to predict.

  • num_samples – Number of samples per time series to draw, where supported by the forecast type.

Returns

Iterator over the forecasts, in the same order as provided by the dataset iterable.

Return type

Iterator[Forecast]
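
For illustration, a sketch of producing forecasts with an already trained predictor; the dataset and the predictor variable are assumptions:

```python
from gluonts.dataset.common import ListDataset

# Illustrative dataset to forecast; a real use case would pass held-out data.
test_data = ListDataset(
    [{"start": "2024-01-01", "target": [1.0, 2.0, 3.0, 4.0, 5.0] * 10}],
    freq="D",
)

# `predictor` is assumed to be an existing PyTorchPredictor.
forecasts = predictor.predict(test_data, num_samples=100)
for forecast in forecasts:
    print(forecast.mean)  # per-step mean of the sampled forecast paths
```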

serialize(path: pathlib.Path) None[source]#
to(device: Union[str, torch.device]) gluonts.torch.model.predictor.PyTorchPredictor[source]#
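
A short sketch of persisting a predictor and moving it between devices; the predictor variable and the target directory are assumptions:

```python
from pathlib import Path

# `serialize` expects an existing directory and writes the predictor files into it.
model_dir = Path("my_model")
model_dir.mkdir(parents=True, exist_ok=True)
predictor.serialize(model_dir)

# `to` returns the predictor with its network placed on the given device.
cpu_predictor = predictor.to("cpu")
```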