diff --git a/tables/automl/automl_tables_predict.py b/tables/automl/automl_tables_predict.py index 4a3423e3d53..a20716ff860 100644 --- a/tables/automl/automl_tables_predict.py +++ b/tables/automl/automl_tables_predict.py @@ -112,6 +112,32 @@ def batch_predict( # [END automl_tables_batch_predict] +def exported_model_predict(): + """Make a prediction for the exported model.""" + # [START automl_tables_exported_model_predict] + import requests + + response = requests.post( + "http://localhost:8080/predict", + json={ + "instances": [ + { + "categorical_col": "mouse", + "num_array_col": [1, 2, 3], + "struct_col": {"foo": "piano", "bar": "2019-05-17T23:56:09.05Z"}, + }, + { + "categorical_col": "dog", + "num_array_col": [5, 6, 7], + "struct_col": {"foo": "guitar", "bar": "2019-06-17T23:56:09.05Z"}, + }, + ] + }, + ) + print(response.json()) + # [END automl_tables_exported_model_predict] + + if __name__ == "__main__": parser = argparse.ArgumentParser( description=__doc__, @@ -130,6 +156,8 @@ def batch_predict( batch_predict_parser.add_argument("--input_path") batch_predict_parser.add_argument("--output_path") + subparsers.add_parser("exported_model_predict", help=predict.__doc__) + project_id = os.environ["PROJECT_ID"] compute_region = os.environ["REGION_NAME"] @@ -148,3 +176,6 @@ def batch_predict( args.input_path, args.output_path, ) + + if args.command == "exported_model_predict": + exported_model_predict()