gluonts.torch.model.predictor module
- class gluonts.torch.model.predictor.PyTorchPredictor(input_names: List[str], prediction_net: torch.nn.modules.module.Module, batch_size: int, prediction_length: int, input_transform: gluonts.transform._base.Transformation, forecast_generator: gluonts.model.forecast_generator.ForecastGenerator = gluonts.model.forecast_generator.SampleForecastGenerator(), output_transform: Optional[Callable[[Dict[str, Any], numpy.ndarray], numpy.ndarray]] = None, lead_time: int = 0, device: Union[str, torch.device] = 'auto')
Bases: gluonts.model.predictor.RepresentablePredictor
- classmethod deserialize(path: pathlib.Path, device: Optional[Union[torch.device, str]] = None) → gluonts.torch.model.predictor.PyTorchPredictor
Load a serialized predictor from the given path.
- Parameters
path – Path to the serialized predictor files.
device – Optional device to be used with the predictor. If nothing is passed, the GPU is used if available and the CPU otherwise.
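A minimal sketch of loading a serialized predictor, assuming the directory was written earlier by serializing a PyTorchPredictor. The helper name `load_predictor` and the directory name in the usage note are hypothetical; only `PyTorchPredictor.deserialize` and its `path` and `device` parameters come from the signature above.

```python
from pathlib import Path
from typing import Optional


def load_predictor(model_dir: str, device: Optional[str] = None):
    """Load a serialized PyTorchPredictor (hypothetical helper).

    Passing device=None leaves device selection to the library:
    the GPU if one is available, the CPU otherwise.
    """
    # Imported inside the function so that defining this helper
    # does not require importing gluonts and torch up front.
    from gluonts.torch.model.predictor import PyTorchPredictor

    # deserialize is annotated to take a pathlib.Path, so convert the string.
    return PyTorchPredictor.deserialize(Path(model_dir), device=device)
```

For example, `load_predictor("my_model_dir", device="cpu")` would force the network onto the CPU regardless of GPU availability, which can be useful when a model trained on a GPU machine is served on a CPU-only host.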
- property network: torch.nn.modules.module.Module
- predict(dataset: gluonts.dataset.Dataset, num_samples: Optional[int] = None) → Iterator[gluonts.model.forecast.Forecast]
Compute forecasts for the time series in the provided dataset.
- Parameters
dataset – The dataset containing the time series to predict.
- Returns
Iterator over the forecasts, in the same order as the dataset iterable was provided.
- Return type
Iterator[Forecast]
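Because forecasts are returned in the same order as the dataset entries, each forecast can be matched to its input by position. The following is a minimal sketch of that pairing; the helper name `pair_with_forecasts` is hypothetical, and `predictor` is assumed to be any object exposing the `predict(dataset, num_samples=...)` method documented above.

```python
from typing import Any, Dict, Iterable, List, Optional, Tuple


def pair_with_forecasts(
    predictor: Any,
    dataset: Iterable[Dict[str, Any]],
    num_samples: Optional[int] = None,
) -> List[Tuple[Dict[str, Any], Any]]:
    """Pair every dataset entry with its forecast (hypothetical helper)."""
    # Materialize the dataset once so that the entries handed to predict
    # and the entries zipped below are the same objects in the same order,
    # even if `dataset` is a one-shot generator.
    entries = list(dataset)

    # predict returns an iterator; forecasts are only computed as the
    # iterator is consumed, which happens inside zip below.
    forecasts = predictor.predict(entries, num_samples=num_samples)

    # Position i in the forecasts corresponds to position i in the entries.
    return list(zip(entries, forecasts))
```

Materializing the dataset trades memory for safety. For very large datasets, iterating the dataset and the forecast iterator side by side avoids holding every entry in memory at once.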
- to(device: Union[str, torch.device]) → gluonts.torch.model.predictor.PyTorchPredictor