MultiModalPredictor.fit

MultiModalPredictor.fit(train_data: DataFrame | str, presets: str | None = None, tuning_data: DataFrame | str | None = None, max_num_tuning_data: int | None = None, id_mappings: Dict[str, Dict] | Dict[str, Series] | None = None, time_limit: int | None = None, save_path: str | None = None, hyperparameters: str | Dict | List[str] | None = None, column_types: dict | None = None, holdout_frac: float | None = None, teacher_predictor: str | MultiModalPredictor | None = None, seed: int | None = 0, standalone: bool | None = True, hyperparameter_tune_kwargs: dict | None = None, clean_ckpts: bool | None = True)[source]

Fit models to predict a column of a data table (label) based on the other columns (features).
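As a minimal sketch of a typical call (the CSV file name, label column, and time budget below are hypothetical placeholders):

    import pandas as pd
    from autogluon.multimodal import MultiModalPredictor

    # Hypothetical training table whose "label" column is the prediction target.
    train_data = pd.read_csv("train.csv")

    predictor = MultiModalPredictor(label="label")
    predictor.fit(train_data, time_limit=600)  # train for up to 10 minutes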

Parameters:
  • train_data – A pd.DataFrame containing training data, or a path to a data file that can be loaded into one.

  • presets – Presets regarding model quality, e.g., "best_quality", "high_quality", and "medium_quality". Each quality level has a corresponding HPO preset: "best_quality_hpo", "high_quality_hpo", and "medium_quality_hpo".

  • tuning_data – A pd.DataFrame containing validation data, which should have the same columns as train_data. If tuning_data = None, fit() will automatically hold out some random validation data from train_data.

  • max_num_tuning_data – The maximum number of tuning samples (used for object detection).

  • id_mappings – Id-to-content mappings (used for semantic matching). The contents can be text, image, etc. This is used when the pd.DataFrame contains the query/response identifiers instead of their contents.

  • time_limit – How long fit() should run for (wall clock time in seconds). If not specified, fit() will run until the model has completed training.

  • save_path – Path to directory where models and artifacts should be saved.

  • hyperparameters

    Overrides some default configurations. For example, the text and image backbones can be changed via:

    a string: hyperparameters = "model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224"

    or a list of strings: hyperparameters = ["model.hf_text.checkpoint_name=google/electra-small-discriminator", "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"]

    or a dictionary:

    hyperparameters = {
        "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
        "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
    }

  • column_types

    A dictionary that maps column names to their data types. For example, column_types = {"item_name": "text", "image": "image_path", "product_description": "text", "height": "numerical"} may be used for a table with columns "item_name", "image", "product_description", and "height" (see the sketch after this parameter list). If None, column_types will be automatically inferred from the data. The currently supported types are:

    • "image_path": each row in this column is one image path.

    • "text": each row in this column contains text (sentence, paragraph, etc.).

    • "numerical": each row in this column contains a number.

    • "categorical": each row in this column belongs to one of K categories.

  • holdout_frac – Fraction of train_data to holdout as tuning_data for optimizing hyperparameters or early stopping (ignored unless tuning_data = None). Default value (if None) is selected based on the number of rows in the training data and whether hyperparameter optimization is utilized.

  • teacher_predictor – The pre-trained teacher predictor or its saved path. If provided, fit() can distill its knowledge to a student predictor, i.e., the current predictor.

  • seed – The random seed to be used for training (default 0).

  • standalone – Whether to save the entire model for offline deployment.

  • hyperparameter_tune_kwargs

    Hyperparameter tuning strategy and kwargs (for example, how many HPO trials to run). If None, hyperparameter tuning will not be performed. A usage sketch appears after the return type below.

    num_trials: int

    How many HPO trials to run. Either num_trials here or the time_limit argument to fit() needs to be specified.

    scheduler: Union[str, ray.tune.schedulers.TrialScheduler]

    If a str is passed, AutoGluon will create the scheduler for you with some default parameters. If a ray.tune.schedulers.TrialScheduler object is passed, you are responsible for initializing it.

    scheduler_init_args: Optional[dict] = None

    If a str was provided for scheduler, you can optionally provide custom init_args to the scheduler.

    searcher: Union[str, ray.tune.search.SearchAlgorithm, ray.tune.search.Searcher]

    If a str is passed, AutoGluon will create the searcher for you with some default parameters. If a ray.tune.search.SearchAlgorithm or ray.tune.search.Searcher object is passed, you are responsible for initializing it. You don't need to worry about the metric and mode of the searcher object; AutoGluon will figure them out by itself.

    searcher_init_args: Optional[dict] = None

    If a str was provided for searcher, you can optionally provide custom init_args to the searcher. You don't need to worry about metric and mode; AutoGluon will figure them out by itself.

  • clean_ckpts – Whether to clean the intermediate checkpoints after training.
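
As referenced in the column_types entry above, a minimal sketch of declaring column types explicitly rather than relying on automatic inference (the table layout is the hypothetical example from that entry):

    # Hypothetical table with text, image-path, and numerical columns.
    column_types = {
        "item_name": "text",
        "image": "image_path",
        "product_description": "text",
        "height": "numerical",
    }
    predictor.fit(train_data, column_types=column_types)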

Return type:

A MultiModalPredictor object (itself).
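
As referenced in the hyperparameter_tune_kwargs entry above, a sketch of enabling HPO. The searcher and scheduler names and the trial count are illustrative choices, not defaults; a search space would typically also be supplied, either through hyperparameters (e.g., ray.tune search spaces) or one of the *_hpo presets listed under presets:

    predictor.fit(
        train_data,
        hyperparameter_tune_kwargs={
            "num_trials": 10,     # illustrative number of HPO trials
            "searcher": "bayes",  # str: AutoGluon creates the searcher with defaults
            "scheduler": "ASHA",  # str: AutoGluon creates the scheduler with defaults
        },
    )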