diff --git a/.gitignore b/.gitignore index f7c77e4d3e..d083ea1ddc 100644 --- a/.gitignore +++ b/.gitignore @@ -60,6 +60,5 @@ coverage.xml system_tests/local_test_setup # Make sure a generated file isn't accidentally committed. -demo.ipynb pylintrc pylintrc.test diff --git a/bigframes/ml/forecasting.py b/bigframes/ml/forecasting.py index 7aa8ba5a5f..2e93e5485f 100644 --- a/bigframes/ml/forecasting.py +++ b/bigframes/ml/forecasting.py @@ -36,6 +36,8 @@ "holiday_region": "holidayRegion", "clean_spikes_and_dips": "cleanSpikesAndDips", "adjust_step_changes": "adjustStepChanges", + "forecast_limit_upper_bound": "forecastLimitUpperBound", + "forecast_limit_lower_bound": "forecastLimitLowerBound", "time_series_length_fraction": "timeSeriesLengthFraction", "min_time_series_length": "minTimeSeriesLength", "max_time_series_length": "maxTimeSeriesLength", @@ -78,6 +80,17 @@ class ARIMAPlus(base.SupervisedTrainableWithIdColPredictor): adjust_step_changes (bool, default True): Determines whether or not to perform automatic step change detection and adjustment in the model training pipeline. + forecast_limit_upper_bound (float or None, default None): + The upper bound of the forecasting values. When you specify the ``forecast_limit_upper_bound`` option, all of the forecast values must be less than the specified value. + For example, if you set ``forecast_limit_upper_bound`` to 100, then all of the forecast values are less than 100. + Also, all values greater than or equal to the ``forecast_limit_upper_bound`` value are excluded from modelling. + The forecasting limit ensures that forecasts stay within limits. + + forecast_limit_lower_bound (float or None, default None): + The lower bound of the forecasting values where the minimum value allowed is 0. When you specify the ``forecast_limit_lower_bound`` option, all of the forecast values must be greater than the specified value. + For example, if you set ``forecast_limit_lower_bound`` to 0, then all of the forecast values are larger than 0. Also, all values less than or equal to the ``forecast_limit_lower_bound`` value are excluded from modelling. + The forecasting limit ensures that forecasts stay within limits. + time_series_length_fraction (float or None, default None): The fraction of the interpolated length of the time series that's used to model the time series trend component. All of the time points of the time series are used to model the non-trend component. @@ -106,6 +119,8 @@ def __init__( holiday_region: Optional[str] = None, clean_spikes_and_dips: bool = True, adjust_step_changes: bool = True, + forecast_limit_lower_bound: Optional[float] = None, + forecast_limit_upper_bound: Optional[float] = None, time_series_length_fraction: Optional[float] = None, min_time_series_length: Optional[int] = None, max_time_series_length: Optional[int] = None, @@ -121,6 +136,8 @@ def __init__( self.holiday_region = holiday_region self.clean_spikes_and_dips = clean_spikes_and_dips self.adjust_step_changes = adjust_step_changes + self.forecast_limit_upper_bound = forecast_limit_upper_bound + self.forecast_limit_lower_bound = forecast_limit_lower_bound self.time_series_length_fraction = time_series_length_fraction self.min_time_series_length = min_time_series_length self.max_time_series_length = max_time_series_length @@ -175,6 +192,10 @@ def _bqml_options(self) -> dict: if self.include_drift: options["include_drift"] = True + if self.forecast_limit_upper_bound is not None: + options["forecast_limit_upper_bound"] = self.forecast_limit_upper_bound + if self.forecast_limit_lower_bound is not None: + options["forecast_limit_lower_bound"] = self.forecast_limit_lower_bound return options diff --git a/tests/system/large/ml/test_forecasting.py b/tests/system/large/ml/test_forecasting.py index 7c070fd200..56b93e5338 100644 --- a/tests/system/large/ml/test_forecasting.py +++ b/tests/system/large/ml/test_forecasting.py @@ -154,6 +154,7 @@ def test_arima_plus_model_fit_params( holiday_region="US", clean_spikes_and_dips=False, adjust_step_changes=False, + forecast_limit_lower_bound=0.0, time_series_length_fraction=0.5, min_time_series_length=10, trend_smoothing_window_size=5, @@ -183,6 +184,8 @@ def test_arima_plus_model_fit_params( assert reloaded_model.holiday_region == "US" assert reloaded_model.clean_spikes_and_dips is False assert reloaded_model.adjust_step_changes is False + # TODO(b/391399223): API must return forecastLimitLowerBound for the following assertion + # assert reloaded_model.forecast_limit_lower_bound == 0.0 assert reloaded_model.time_series_length_fraction == 0.5 assert reloaded_model.min_time_series_length == 10 assert reloaded_model.trend_smoothing_window_size == 5