Few Shot Learning with AutoMM


In this tutorial we introduce a simple but effective approach to few shot classification. It leverages the high-quality features of foundation models and uses an SVM for the few shot classification task: we extract sample features with a pretrained model and then train an SVM on those features. We show the effectiveness of this foundation-model-followed-by-SVM approach on a text classification dataset and an image classification dataset.
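
The idea can be sketched without AutoMM: treat the pretrained model as a frozen feature extractor and fit only a small max-margin classifier on its outputs. The snippet below is a minimal illustration, not AutoMM's implementation. The 2-D vectors in `TOY_FEATURES` are made-up stand-ins for real embeddings, the task is binary rather than multiclass, and `train_linear_svm` and `predict` are hypothetical helpers that fit a linear SVM by subgradient descent on the L2-regularized hinge loss.

```python
# Made-up stand-ins for foundation-model embeddings (assumption: in a real
# pipeline these vectors would come from a frozen pretrained encoder).
TOY_FEATURES = [
    ([2.0, 1.5], 1), ([1.8, 2.2], 1), ([2.5, 1.9], 1),
    ([-1.5, -2.0], -1), ([-2.2, -1.7], -1), ([-1.9, -2.4], -1),
]


def train_linear_svm(samples, lam=0.01, lr=0.1, epochs=100):
    """Full-batch subgradient descent on the L2-regularized hinge loss."""
    dim = len(samples[0][0])
    n = len(samples)
    w = [0.0] * dim
    b = 0.0
    for _ in range(epochs):
        grad_w = [lam * wi for wi in w]  # gradient of the regularizer
        grad_b = 0.0
        for x, y in samples:
            margin = y * (sum(wi * xi for wi, xi in zip(w, x)) + b)
            if margin < 1:  # only margin violators contribute to the hinge term
                for i in range(dim):
                    grad_w[i] -= y * x[i] / n
                grad_b -= y / n
        w = [wi - lr * gi for wi, gi in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b


def predict(w, b, x):
    """Return +1 or -1 depending on the side of the separating hyperplane."""
    score = sum(wi * xi for wi, xi in zip(w, x)) + b
    return 1 if score >= 0 else -1


w, b = train_linear_svm(TOY_FEATURES)
print(predict(w, b, [2.1, 1.7]))
print(predict(w, b, [-2.0, -1.8]))
```

Because only the small linear model is trained, a handful of labeled samples per class can be enough when the features already separate the classes well.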

Few Shot Text Classification

Prepare Text Data

We prepare all datasets as pd.DataFrame, as in many of our other tutorials. For this tutorial, we’ll use a small MLDoc dataset for demonstration. It is a text classification dataset with 4 classes, and we downsampled the training data to 10 samples per class, a.k.a. 10 shots. For more details regarding MLDoc, please see this link.

import pandas as pd
import os
from autogluon.core.utils.loaders import load_zip

download_dir = "./ag_automm_tutorial_fs_cls"
zip_file = "https://automl-mm-bench.s3.amazonaws.com/nlp_datasets/MLDoc-10shot-en.zip"
load_zip.unzip(zip_file, unzip_dir=download_dir)
dataset_path = os.path.join(download_dir)
train_df = pd.read_csv(f"{dataset_path}/train.csv", names=["label", "text"])
test_df = pd.read_csv(f"{dataset_path}/test.csv", names=["label", "text"])
print(train_df)
print(test_df)
Downloading ./ag_automm_tutorial_fs_cls/file.zip from https://automl-mm-bench.s3.amazonaws.com/nlp_datasets/MLDoc-10shot-en.zip...
   label                                               text
0   GCAT  b'Secretary-General Kofi Annan expressed conce...
1   CCAT  b'The health of ABB Asea Brown Boveri AG\'s Po...
2   GCAT  b'Nepali Prime Minister Lokendra Bahadur Chand...
3   CCAT  b'Integ Inc said Thursday its net loss widened...
4   GCAT  b'These are the leading stories in the Skopje ...
5   ECAT  b'Fears of a slowdown in India\'s industrial g...
6   MCAT  b'The Australian Treasury will offer a total o...
7   CCAT  b'Malaysia\'s Suria Capital Holdings Bhd and M...
8   MCAT  b'The UK gilt repo market had a quiet session ...
9   CCAT  b"Commonwealth Edison Co's (ComEd) 794 megawat...
10  GCAT  b'Police arrested 47 people on Thursday in a c...
11  GCAT  b"Army troops in the Comoros island of Anjouan...
12  ECAT  b"The House Banking Committee is considering w...
13  GCAT  b'A possible international anti-drug centre in...
14  ECAT  b'Angela Knight, economic secretary to the Bri...
15  GCAT  b'Nearly 300 people were feared dead in floods...
16  MCAT  b'The Oslo stock index fell with other Europea...
17  ECAT  b'Morgan Keegan said it won $18.540 million of...
18  CCAT  b'Britons can bank on the phone, bank on the i...
19  CCAT  b"Standard Chartered Bank and Prudential Secur...
20  CCAT  b"United Water Resources Inc said it and Lyonn...
21  ECAT  b'Tanzania on Thursday unveiled its 1997/98 bu...
22  GCAT  b'U.S. President Bill Clinton will meet Prime ...
23  CCAT  b"Pacific Century Regional Developments Ltd sa...
24  MCAT  b'The Athens bourse ended 0.65 percent lower w...
25  ECAT  b'Sri Lanka broad money supply, or M2, is seen...
26  GCAT  b'Collated results of African Nations Cup prel...
27  GCAT  b'Philippine President Fidel Ramos said on Fri...
28  MCAT  b'Shanghai copper futures ended down on heavy ...
29  CCAT  b"Goldman Sachs & Co said on Monday that David...
30  ECAT  b'Maine\'s revenues were higher than forecast ...
31  CCAT  b'Thai animal feedmillers said on Monday they ...
32  MCAT  b"Worldwide trading volume in emerging markets...
33  ECAT  b'One week ended June 25 daily avgs-millions  ...
34  ECAT  b'Algeria\'s non-energy exports reached $688 m...
35  ECAT  b'U.S. seasonally adjusted retail sales rose 1...
36  MCAT  b'The Indonesian rupiah weakened against the d...
37  MCAT  b'Brazilian stocks ended slightly higher led b...
38  MCAT  b'The price of gold hung around the psychologi...
39  MCAT  b'The won closed stronger versus the dollar on...
     label                                               text
0     CCAT  b'RJR Nabisco Holdings Corp has prevailed over...
1     ECAT  b"Britain's economy grew 0.8 percent in the fo...
2     ECAT  b'Slovenia\'s state Institute of Macroeconomic...
3     CCAT  b"Belgium's second largest bank Credit Communa...
4     GCAT  b'The IRA ordered its guerrillas to observe a ...
...    ...                                                ...
3995  CCAT  b"A consortium comprising Itochu Corp and Hanj...
3996  ECAT  b"The volume of Hong Kong's domestic exports i...
3997  ECAT  b'The Danish finance ministry said on Tuesday ...
3998  GCAT  b'A court is to investigate charges that forme...
3999  MCAT  b"German consumers of feed grains, bread rye a...

[4000 rows x 2 columns]
100%|██████████| 2.59M/2.59M [00:00<00:00, 13.8MiB/s]
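
A k-shot training set like the one above keeps a fixed number of samples per class. The following sketch shows one way such a subset could be drawn. `downsample_k_shot` and the synthetic `full_rows` are hypothetical; this is not how the MLDoc 10-shot file was actually produced.

```python
import random
from collections import Counter, defaultdict


def downsample_k_shot(rows, k, seed=0):
    """Keep at most k randomly chosen (label, text) rows per label."""
    by_label = defaultdict(list)
    for label, text in rows:
        by_label[label].append((label, text))
    rng = random.Random(seed)  # fixed seed so the subset is reproducible
    subset = []
    for label in sorted(by_label):
        group = by_label[label]
        subset.extend(rng.sample(group, min(k, len(group))))
    return subset


# Synthetic stand-in for a full training set: 4 classes x 25 documents each.
full_rows = [
    (label, f"document {i} about {label}")
    for label in ["CCAT", "ECAT", "GCAT", "MCAT"]
    for i in range(25)
]

ten_shot = downsample_k_shot(full_rows, k=10)
counts = Counter(label for label, _ in ten_shot)
print(counts)
```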

Train a Few Shot Classifier

In order to perform few shot classification, we need to use the few_shot_classification problem type.

from autogluon.multimodal import MultiModalPredictor

predictor_fs_text = MultiModalPredictor(
    problem_type="few_shot_classification",
    label="label",  # column name of the label
    eval_metric="acc",
)
predictor_fs_text.fit(train_df)
scores = predictor_fs_text.evaluate(test_df, metrics=["acc", "f1_macro"])
print(scores)
{'acc': 0.83575, 'f1_macro': 0.8344679316932194}
No path specified. Models will be saved in: "AutogluonModels/ag-20240716_224436"
=================== System Info ===================
AutoGluon Version:  1.1.1b20240716
Python Version:     3.10.13
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Fri May 17 18:07:48 UTC 2024
CPU Count:          8
Pytorch Version:    2.3.1+cu121
CUDA Version:       12.1
Memory Avail:       28.68 GB / 30.95 GB (92.7%)
Disk Space Avail:   188.96 GB / 255.99 GB (73.8%)
===================================================

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_224436
    ```

INFO: Seed set to 0
/home/ci/opt/venv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
/home/ci/autogluon/multimodal/src/autogluon/multimodal/data/utils.py:470: UserWarning: provided max length: 512 is smaller than sentence-transformers/all-mpnet-base-v2's default: 514
  warnings.warn(
GPU Count: 1
GPU Count to be Used: 1
GPU 0 Name: Tesla T4
GPU 0 Memory: 0.42GB/15.0GB (Used/Total)

AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_224436")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
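
The two metrics passed to `evaluate` can be computed by hand, which helps when interpreting the scores above. The sketch below uses made-up labels (not predictions from this tutorial) and two hypothetical helpers. `acc` is the fraction of correct predictions, and `f1_macro` is the unweighted mean of the per-class F1 scores, so a class that is never predicted correctly pulls it down regardless of its size.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def f1_macro(y_true, y_pred):
    """Unweighted mean of per-class F1 = 2*TP / (2*TP + FP + FN)."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = []
    for label in labels:
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Made-up example: MCAT is never predicted, so its F1 is 0.
y_true = ["CCAT", "CCAT", "CCAT", "ECAT", "GCAT", "MCAT"]
y_pred = ["CCAT", "CCAT", "ECAT", "ECAT", "GCAT", "GCAT"]

toy_acc = accuracy(y_true, y_pred)
toy_f1 = f1_macro(y_true, y_pred)
print({"acc": toy_acc, "f1_macro": toy_f1})
```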

Compare to the Default Classifier

Let’s use the default classification problem type and compare its performance with the few shot result above.

from autogluon.multimodal import MultiModalPredictor

predictor_default_text = MultiModalPredictor(
    label="label",
    problem_type="classification",
    eval_metric="acc",
)
predictor_default_text.fit(train_data=train_df)
scores = predictor_default_text.evaluate(test_df, metrics=["acc", "f1_macro"])
print(scores)
{'acc': 0.60075, 'f1_macro': 0.5774389970044109}
No path specified. Models will be saved in: "AutogluonModels/ag-20240716_224533"
=================== System Info ===================
AutoGluon Version:  1.1.1b20240716
Python Version:     3.10.13
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Fri May 17 18:07:48 UTC 2024
CPU Count:          8
Pytorch Version:    2.3.1+cu121
CUDA Version:       12.1
Memory Avail:       27.70 GB / 30.95 GB (89.5%)
Disk Space Avail:   188.15 GB / 255.99 GB (73.5%)
===================================================
AutoGluon infers your prediction problem is: 'multiclass' (because dtype of label-column == object).
	4 unique label values:  ['GCAT', 'CCAT', 'ECAT', 'MCAT']
	If 'multiclass' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_224533
    ```

INFO: Seed set to 0
/home/ci/opt/venv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
GPU Count: 1
GPU Count to be Used: 1
GPU 0 Name: Tesla T4
GPU 0 Memory: 0.55GB/15.0GB (Used/Total)

INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name              | Type                         | Params | Mode 
---------------------------------------------------------------------------
0 | model             | HFAutoModelForTextPrediction | 108 M  | train
1 | validation_metric | MulticlassAccuracy           | 0      | train
2 | loss_func         | CrossEntropyLoss             | 0      | train
---------------------------------------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
435.579   Total estimated model params size (MB)
/home/ci/opt/venv/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (4) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
INFO: Epoch 0, global step 1: 'val_acc' reached 0.37500 (best 0.37500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_224533/epoch=0-step=1.ckpt' as top 3
INFO: Epoch 1, global step 2: 'val_acc' reached 0.50000 (best 0.50000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_224533/epoch=1-step=2.ckpt' as top 3
INFO: Epoch 2, global step 3: 'val_acc' reached 0.37500 (best 0.50000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_224533/epoch=2-step=3.ckpt' as top 3
INFO: Epoch 3, global step 4: 'val_acc' reached 0.50000 (best 0.50000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_224533/epoch=3-step=4.ckpt' as top 3
INFO: Epoch 4, global step 5: 'val_acc' reached 0.62500 (best 0.62500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_224533/epoch=4-step=5.ckpt' as top 3
INFO: Epoch 5, global step 6: 'val_acc' reached 0.62500 (best 0.62500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_224533/epoch=5-step=6.ckpt' as top 3
INFO: Epoch 6, global step 7: 'val_acc' reached 0.62500 (best 0.62500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_224533/epoch=6-step=7.ckpt' as top 3
INFO: Epoch 7, global step 8: 'val_acc' was not in top 3
INFO: Epoch 8, global step 9: 'val_acc' was not in top 3
INFO: Epoch 9, global step 10: 'val_acc' was not in top 3
Start to fuse 3 checkpoints via the greedy soup algorithm.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_224533")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).

Few Shot Image Classification

We also provide an example of using MultiModalPredictor on a few-shot image classification task.

Load Dataset

We use the Stanford Cars dataset for demonstration and have downsampled the training set to 8 samples per class. Stanford Cars is an image classification dataset containing 196 classes. For more information regarding the dataset, please see here.

import os
from autogluon.core.utils.loaders import load_zip

download_dir = "./ag_automm_tutorial_fs_cls/stanfordcars/"
zip_file = "https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/stanfordcars.zip"
train_csv = "https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/train_8shot.csv"
test_csv = "https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/test.csv"

load_zip.unzip(zip_file, unzip_dir=download_dir)
dataset_path = os.path.join(download_dir)
Downloading ./ag_automm_tutorial_fs_cls/stanfordcars//file.zip from https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/stanfordcars.zip...
100%|██████████| 1.96G/1.96G [01:09<00:00, 28.1MiB/s]
Unzipping ./ag_automm_tutorial_fs_cls/stanfordcars//file.zip to ./ag_automm_tutorial_fs_cls/stanfordcars/
!wget https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/train_8shot.csv -O ./ag_automm_tutorial_fs_cls/stanfordcars/train.csv
!wget https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/test.csv -O ./ag_automm_tutorial_fs_cls/stanfordcars/test.csv
--2024-07-16 22:49:22--  https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/train_8shot.csv
Resolving automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)... 16.182.42.193, 52.217.225.9, 52.217.95.41, ...
Connecting to automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)|16.182.42.193|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94879 (93K) [text/csv]
Saving to: ‘./ag_automm_tutorial_fs_cls/stanfordcars/train.csv’

./ag_automm_tutoria 100%[===================>]  92.66K  --.-KB/s    in 0.004s  

2024-07-16 22:49:22 (24.6 MB/s) - ‘./ag_automm_tutorial_fs_cls/stanfordcars/train.csv’ saved [94879/94879]

--2024-07-16 22:49:22--  https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/test.csv
Resolving automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)... 3.5.24.127, 3.5.13.29, 52.216.37.201, ...
Connecting to automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)|3.5.24.127|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 34472 (34K) [text/csv]
Saving to: ‘./ag_automm_tutorial_fs_cls/stanfordcars/test.csv’

./ag_automm_tutoria 100%[===================>]  33.66K  --.-KB/s    in 0.001s  

2024-07-16 22:49:22 (51.5 MB/s) - ‘./ag_automm_tutorial_fs_cls/stanfordcars/test.csv’ saved [34472/34472]
import pandas as pd
import os

# Bounding-box and annotation columns that are not needed for classification.
unused_columns = [
    "Source",
    "Confidence",
    "XMin",
    "XMax",
    "YMin",
    "YMax",
    "IsOccluded",
    "IsTruncated",
    "IsGroupOf",
    "IsDepiction",
    "IsInside",
]

train_df_raw = pd.read_csv(os.path.join(download_dir, "train.csv"))
train_df = train_df_raw.drop(columns=unused_columns)
# ImageID holds paths relative to download_dir; prepend it to get usable paths.
train_df["ImageID"] = download_dir + train_df["ImageID"].astype(str)

test_df_raw = pd.read_csv(os.path.join(download_dir, "test.csv"))
test_df = test_df_raw.drop(columns=unused_columns)
test_df["ImageID"] = download_dir + test_df["ImageID"].astype(str)

print(os.path.exists(train_df.iloc[0]["ImageID"]))
print(train_df)
print(os.path.exists(test_df.iloc[0]["ImageID"]))
print(test_df)
True
                                                ImageID  LabelName
0     ./ag_automm_tutorial_fs_cls/stanfordcars/train...        147
1     ./ag_automm_tutorial_fs_cls/stanfordcars/train...        120
2     ./ag_automm_tutorial_fs_cls/stanfordcars/train...        147
3     ./ag_automm_tutorial_fs_cls/stanfordcars/train...        167
4     ./ag_automm_tutorial_fs_cls/stanfordcars/train...         73
...                                                 ...        ...
1563  ./ag_automm_tutorial_fs_cls/stanfordcars/train...        116
1564  ./ag_automm_tutorial_fs_cls/stanfordcars/train...         76
1565  ./ag_automm_tutorial_fs_cls/stanfordcars/train...        148
1566  ./ag_automm_tutorial_fs_cls/stanfordcars/train...        189
1567  ./ag_automm_tutorial_fs_cls/stanfordcars/train...        183

[1568 rows x 2 columns]
True
                                               ImageID  LabelName
0    ./ag_automm_tutorial_fs_cls/stanfordcars/test/...          0
1    ./ag_automm_tutorial_fs_cls/stanfordcars/test/...          0
2    ./ag_automm_tutorial_fs_cls/stanfordcars/test/...          0
3    ./ag_automm_tutorial_fs_cls/stanfordcars/test/...          1
4    ./ag_automm_tutorial_fs_cls/stanfordcars/test/...          1
..                                                 ...        ...
583  ./ag_automm_tutorial_fs_cls/stanfordcars/test/...        194
584  ./ag_automm_tutorial_fs_cls/stanfordcars/test/...        194
585  ./ag_automm_tutorial_fs_cls/stanfordcars/test/...        195
586  ./ag_automm_tutorial_fs_cls/stanfordcars/test/...        195
587  ./ag_automm_tutorial_fs_cls/stanfordcars/test/...        195

[588 rows x 2 columns]
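
The `ImageID` column stores paths relative to the dataset root, and the code above prepends `download_dir` by plain string concatenation, which works because `download_dir` ends with a slash. As a small sketch of a variant that does not depend on the trailing slash, `os.path.join` can build the same paths. The helper `to_full_paths` and the file names below are hypothetical placeholders.

```python
import os

# Hypothetical relative image paths, shaped like the ImageID values.
relative_ids = ["train/example_a.jpg", "test/example_b.jpg"]


def to_full_paths(root, relative_paths):
    """Join each relative path onto the dataset root."""
    return [os.path.join(root, rel) for rel in relative_paths]


with_slash = to_full_paths("./ag_automm_tutorial_fs_cls/stanfordcars/", relative_ids)
without_slash = to_full_paths("./ag_automm_tutorial_fs_cls/stanfordcars", relative_ids)
print(with_slash == without_slash)
```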

Train a Few Shot Classifier

Similarly, we need to initialize MultiModalPredictor with the problem type few_shot_classification.

from autogluon.multimodal import MultiModalPredictor

predictor_fs_image = MultiModalPredictor(
    problem_type="few_shot_classification",
    label="LabelName",  # column name of the label
    eval_metric="acc",
)
predictor_fs_image.fit(train_df)
scores = predictor_fs_image.evaluate(test_df, metrics=["acc", "f1_macro"])
print(scores)
{'acc': 0.7993197278911565, 'f1_macro': 0.7941690962099125}
No path specified. Models will be saved in: "AutogluonModels/ag-20240716_224923"
=================== System Info ===================
AutoGluon Version:  1.1.1b20240716
Python Version:     3.10.13
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Fri May 17 18:07:48 UTC 2024
CPU Count:          8
Pytorch Version:    2.3.1+cu121
CUDA Version:       12.1
Memory Avail:       25.99 GB / 30.95 GB (84.0%)
Disk Space Avail:   183.71 GB / 255.99 GB (71.8%)
===================================================

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_224923
    ```

INFO: Seed set to 0
/home/ci/opt/venv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
GPU Count: 1
GPU Count to be Used: 1
GPU 0 Name: Tesla T4
GPU 0 Memory: 0.56GB/15.0GB (Used/Total)

AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_224923")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).

Compare to the Default Classifier

We can also train a default image classifier and compare it with the few shot classifier.

from autogluon.multimodal import MultiModalPredictor

predictor_default_image = MultiModalPredictor(
    problem_type="classification",
    label="LabelName",  # column name of the label
    eval_metric="acc",
)
predictor_default_image.fit(train_data=train_df)
scores = predictor_default_image.evaluate(test_df, metrics=["acc", "f1_macro"])
print(scores)
{'acc': 0.5629251700680272, 'f1_macro': 0.5488790970933828}
No path specified. Models will be saved in: "AutogluonModels/ag-20240716_225102"
=================== System Info ===================
AutoGluon Version:  1.1.1b20240716
Python Version:     3.10.13
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Fri May 17 18:07:48 UTC 2024
CPU Count:          8
Pytorch Version:    2.3.1+cu121
CUDA Version:       12.1
Memory Avail:       24.31 GB / 30.95 GB (78.6%)
Disk Space Avail:   180.82 GB / 255.99 GB (70.6%)
===================================================
AutoGluon infers your prediction problem is: 'regression' (because dtype of label-column == int and many unique label-values observed).
	Label info (max, min, mean, stddev): (195, 0, 97.5, 56.59764)
	If 'regression' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102
    ```

INFO: Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
GPU 0 Name: Tesla T4
GPU 0 Memory: 0.57GB/15.0GB (Used/Total)

INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name              | Type                            | Params | Mode 
------------------------------------------------------------------------------
0 | model             | TimmAutoModelForImagePrediction | 96.3 M | train
1 | validation_metric | MulticlassAccuracy              | 0      | train
2 | loss_func         | CrossEntropyLoss                | 0      | train
------------------------------------------------------------------------------
96.3 M    Trainable params
0         Non-trainable params
96.3 M    Total params
385.132   Total estimated model params size (MB)
INFO: Epoch 0, global step 4: 'val_acc' reached 0.00000 (best 0.00000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=0-step=4.ckpt' as top 3
INFO: Epoch 0, global step 9: 'val_acc' reached 0.00318 (best 0.00318), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=0-step=9.ckpt' as top 3
INFO: Epoch 1, global step 14: 'val_acc' reached 0.01274 (best 0.01274), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=1-step=14.ckpt' as top 3
INFO: Epoch 1, global step 19: 'val_acc' reached 0.06688 (best 0.06688), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=1-step=19.ckpt' as top 3
INFO: Epoch 2, global step 24: 'val_acc' reached 0.08599 (best 0.08599), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=2-step=24.ckpt' as top 3
INFO: Epoch 2, global step 29: 'val_acc' reached 0.12420 (best 0.12420), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=2-step=29.ckpt' as top 3
INFO: Epoch 3, global step 34: 'val_acc' reached 0.15605 (best 0.15605), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=3-step=34.ckpt' as top 3
INFO: Epoch 3, global step 39: 'val_acc' reached 0.21019 (best 0.21019), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=3-step=39.ckpt' as top 3
INFO: Epoch 4, global step 44: 'val_acc' reached 0.26115 (best 0.26115), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=4-step=44.ckpt' as top 3
INFO: Epoch 4, global step 49: 'val_acc' reached 0.27707 (best 0.27707), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=4-step=49.ckpt' as top 3
INFO: Epoch 5, global step 54: 'val_acc' reached 0.32166 (best 0.32166), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=5-step=54.ckpt' as top 3
INFO: Epoch 5, global step 59: 'val_acc' reached 0.32484 (best 0.32484), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=5-step=59.ckpt' as top 3
INFO: Epoch 6, global step 64: 'val_acc' reached 0.39490 (best 0.39490), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=6-step=64.ckpt' as top 3
INFO: Epoch 6, global step 69: 'val_acc' reached 0.42357 (best 0.42357), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=6-step=69.ckpt' as top 3
INFO: Epoch 7, global step 74: 'val_acc' reached 0.42675 (best 0.42675), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=7-step=74.ckpt' as top 3
INFO: Epoch 7, global step 79: 'val_acc' reached 0.44268 (best 0.44268), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=7-step=79.ckpt' as top 3
INFO: Epoch 8, global step 84: 'val_acc' reached 0.43631 (best 0.44268), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=8-step=84.ckpt' as top 3
INFO: Epoch 8, global step 89: 'val_acc' reached 0.48089 (best 0.48089), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=8-step=89.ckpt' as top 3
INFO: Epoch 9, global step 94: 'val_acc' reached 0.50318 (best 0.50318), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=9-step=94.ckpt' as top 3
INFO: Epoch 9, global step 99: 'val_acc' reached 0.50955 (best 0.50955), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=9-step=99.ckpt' as top 3
INFO: Epoch 10, global step 104: 'val_acc' reached 0.49363 (best 0.50955), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=10-step=104.ckpt' as top 3
INFO: Epoch 10, global step 109: 'val_acc' reached 0.53822 (best 0.53822), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=10-step=109.ckpt' as top 3
INFO: Epoch 11, global step 114: 'val_acc' reached 0.53503 (best 0.53822), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=11-step=114.ckpt' as top 3
INFO: Epoch 11, global step 119: 'val_acc' reached 0.55732 (best 0.55732), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=11-step=119.ckpt' as top 3
INFO: Epoch 12, global step 124: 'val_acc' was not in top 3
INFO: Epoch 12, global step 129: 'val_acc' was not in top 3
INFO: Epoch 13, global step 134: 'val_acc' reached 0.56688 (best 0.56688), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=13-step=134.ckpt' as top 3
INFO: Epoch 13, global step 139: 'val_acc' reached 0.55414 (best 0.56688), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=13-step=139.ckpt' as top 3
INFO: Epoch 14, global step 144: 'val_acc' reached 0.55732 (best 0.56688), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=14-step=144.ckpt' as top 3
INFO: Epoch 14, global step 149: 'val_acc' was not in top 3
INFO: Epoch 15, global step 154: 'val_acc' was not in top 3
INFO: Epoch 15, global step 159: 'val_acc' reached 0.56369 (best 0.56688), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=15-step=159.ckpt' as top 3
INFO: Epoch 16, global step 164: 'val_acc' reached 0.57962 (best 0.57962), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=16-step=164.ckpt' as top 3
INFO: Epoch 16, global step 169: 'val_acc' reached 0.56688 (best 0.57962), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102/epoch=16-step=169.ckpt' as top 3
INFO: Epoch 17, global step 174: 'val_acc' was not in top 3
INFO: Epoch 17, global step 179: 'val_acc' was not in top 3
INFO: Epoch 18, global step 184: 'val_acc' was not in top 3
INFO: Epoch 18, global step 189: 'val_acc' was not in top 3
INFO: Epoch 19, global step 194: 'val_acc' was not in top 3
INFO: Epoch 19, global step 199: 'val_acc' was not in top 3
INFO: `Trainer.fit` stopped: `max_epochs=20` reached.
Start to fuse 3 checkpoints via the greedy soup algorithm.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20240716_225102")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).

As you can see, few_shot_classification also performs much better than the default classification problem type on image classification (about 0.80 vs. 0.56 accuracy).

Customization

To learn how to customize AutoMM, please refer to Customize AutoMM.