From 397a2abdc0de5076d11930525784b97394dc7f9c Mon Sep 17 00:00:00 2001 From: Ashley Xu Date: Wed, 5 Jun 2024 23:35:08 +0000 Subject: [PATCH 1/2] feat: support fit() in GeminiTextGenerator --- bigframes/ml/llm.py | 53 +++++++++++++++++++++++++++++++++++ tests/system/load/test_llm.py | 28 ++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index f62867cdd5..db34281a7b 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -571,6 +571,8 @@ class GeminiTextGenerator(base.BaseEstimator): Connection to connect with remote service. str of the format ... If None, use default connection in session context. BigQuery DataFrame will try to create the connection and attach permission if the connection isn't fully set up. + max_iterations (Optional[int], Default to 300): + The number of steps to run when performing supervised tuning. """ def __init__( @@ -581,9 +583,11 @@ def __init__( ] = "gemini-pro", session: Optional[bigframes.Session] = None, connection_name: Optional[str] = None, + max_iterations: int = 300, ): self.model_name = model_name self.session = session or bpd.get_global_session() + self.max_iterations = max_iterations self._bq_connection_manager = self.session.bqconnectionmanager connection_name = connection_name or self.session._bq_connection @@ -647,6 +651,55 @@ def _from_bq( model._bqml_model = core.BqmlModel(session, bq_model) return model + @property + def _bqml_options(self) -> dict: + """The model options as they will be set for BQML""" + options = { + "max_iterations": self.max_iterations, + "data_split_method": "NO_SPLIT", + } + return options + + def fit( + self, + X: Union[bpd.DataFrame, bpd.Series], + y: Union[bpd.DataFrame, bpd.Series], + ) -> GeminiTextGenerator: + """Fine tune GeminiTextGenerator model. + + .. 
note:: + + This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the + Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is" + and might have limited support. For more information, see the launch stage descriptions + (https://cloud.google.com/products#product-launch-stages). + + Args: + X (bigframes.dataframe.DataFrame or bigframes.series.Series): + DataFrame of shape (n_samples, n_features). Training data. + y (bigframes.dataframe.DataFrame or bigframes.series.Series): + Training labels. + + Returns: + GeminiTextGenerator: Fitted estimator. + """ + if self._bqml_model.model_name.startswith("gemini-1.5"): + raise NotImplementedError("Fit is not supported for gemini-1.5 model.") + + X, y = utils.convert_to_dataframe(X, y) + + options = self._bqml_options + options["endpoint"] = "gemini-1.0-pro-002" + options["prompt_col"] = X.columns.tolist()[0] + + self._bqml_model = self._bqml_model_factory.create_llm_remote_model( + X, + y, + options=options, + connection_name=self.connection_name, + ) + return self + def predict( self, X: Union[bpd.DataFrame, bpd.Series], diff --git a/tests/system/load/test_llm.py b/tests/system/load/test_llm.py index fd13662275..525beaa428 100644 --- a/tests/system/load/test_llm.py +++ b/tests/system/load/test_llm.py @@ -112,3 +112,31 @@ def test_llm_palm_score_params(llm_fine_tune_df_default_index): "evaluation_status", ] assert all(col in score_result_col for col in expected_col) + + +@pytest.mark.flaky(retries=2) +def test_llm_gemini_configure_fit(llm_fine_tune_df_default_index, llm_remote_text_df): + model = bigframes.ml.llm.GeminiTextGenerator( + model_name="gemini-pro", max_iterations=1 + ) + + df = llm_fine_tune_df_default_index.dropna().sample(n=100) + X_train = df[["prompt"]] + y_train = df[["label"]] + model.fit(X_train, y_train) + + assert model is not None + + df = model.predict( + llm_remote_text_df["prompt"], 
temperature=0.5, + max_output_tokens=100, + top_k=20, + top_p=0.5, + ).to_pandas() + assert df.shape == (3, 4) + assert "ml_generate_text_llm_result" in df.columns + series = df["ml_generate_text_llm_result"] + assert all(series.str.len() == 1) + + # TODO(ashleyxu b/335492787): After bqml rolled out version control: save, load, check parameters to ensure configuration was kept From 1ed29976c5e45ff0c1fee071817d6530d8a6ede5 Mon Sep 17 00:00:00 2001 From: Ashley Xu Date: Thu, 6 Jun 2024 18:23:24 +0000 Subject: [PATCH 2/2] address comments --- bigframes/ml/llm.py | 2 +- tests/system/load/test_llm.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index db34281a7b..2517178d89 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -665,7 +665,7 @@ def fit( X: Union[bpd.DataFrame, bpd.Series], y: Union[bpd.DataFrame, bpd.Series], ) -> GeminiTextGenerator: - """Fine tune GeminiTextGenerator model. + """Fine tune GeminiTextGenerator model. Only the "gemini-pro" model is supported for now. .. note:: diff --git a/tests/system/load/test_llm.py b/tests/system/load/test_llm.py index 713f98f588..fd047b3ba6 100644 --- a/tests/system/load/test_llm.py +++ b/tests/system/load/test_llm.py @@ -107,9 +107,8 @@ def test_llm_gemini_configure_fit(llm_fine_tune_df_default_index, llm_remote_tex model_name="gemini-pro", max_iterations=1 ) - df = llm_fine_tune_df_default_index.dropna().sample(n=100) - X_train = df[["prompt"]] - y_train = df[["label"]] + X_train = llm_fine_tune_df_default_index[["prompt"]] + y_train = llm_fine_tune_df_default_index[["label"]] model.fit(X_train, y_train) assert model is not None