Handling Class Imbalance with AutoMM - Focal Loss


In this tutorial, we show how to use focal loss with the AutoMM package for balanced training. Focal loss was first introduced by Lin et al. in "Focal Loss for Dense Object Detection" (see Citations below). It can be used to balance hard and easy samples, as well as an uneven sample distribution across classes.
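For reference, the paper defines focal loss in terms of $p_t$, the model's predicted probability for the true class of a sample. Here $\alpha_t$ is the weight assigned to that true class and $\gamma \ge 0$ is the focusing parameter:

```latex
$$
\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)
$$
```

The factor $(1 - p_t)^{\gamma}$ approaches 0 for well-classified samples ($p_t$ close to 1), so hard samples dominate the loss. With $\gamma = 0$ and $\alpha_t = 1$, the expression reduces to standard cross-entropy.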

Create Dataset

We use the Shopee dataset for demonstration in this tutorial. The Shopee dataset contains 4 classes, with 200 training samples per class.

from autogluon.multimodal.utils.misc import shopee_dataset

download_dir = "./ag_automm_tutorial_imgcls_focalloss"
train_data, test_data = shopee_dataset(download_dir)
Downloading ./ag_automm_tutorial_imgcls_focalloss/file.zip from https://automl-mm-bench.s3.amazonaws.com/vision_datasets/shopee.zip...
100%|██████████| 84.0M/84.0M [00:03<00:00, 23.1MiB/s]

To demonstrate the effectiveness of focal loss on imbalanced training data, we artificially downsample the Shopee training data to form an imbalanced distribution. Each class keeps one third as large a fraction of its samples as the previous class.

import numpy as np
import pandas as pd

ds = 1  # fraction of samples kept for the current class

imbalanced_train_data = []
for lb in range(4):
    class_data = train_data[train_data.label == lb]
    # Randomly keep int(len(class_data) * ds) samples of this class, without replacement.
    sample_index = np.random.choice(np.arange(len(class_data)), size=int(len(class_data) * ds), replace=False)
    ds /= 3  # the next class keeps 1/3 of the current fraction: 1, 1/3, 1/9, 1/27
    imbalanced_train_data.append(class_data.iloc[sample_index])
imbalanced_train_data = pd.concat(imbalanced_train_data)
print(imbalanced_train_data)

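# Focal-loss alpha: the inverse of each class's share of the training data, normalized to sum to 1.
weights = []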
for lb in range(4):
    class_data = imbalanced_train_data[imbalanced_train_data.label == lb]
    weights.append(1 / (class_data.shape[0] / imbalanced_train_data.shape[0]))
    print(f"class {lb}: num samples {len(class_data)}")
weights = list(np.array(weights) / np.sum(weights))
print(weights)
                                                 image  label
18   /home/ci/autogluon/docs/tutorials/multimodal/a...      0
100  /home/ci/autogluon/docs/tutorials/multimodal/a...      0
5    /home/ci/autogluon/docs/tutorials/multimodal/a...      0
17   /home/ci/autogluon/docs/tutorials/multimodal/a...      0
177  /home/ci/autogluon/docs/tutorials/multimodal/a...      0
..                                                 ...    ...
665  /home/ci/autogluon/docs/tutorials/multimodal/a...      3
715  /home/ci/autogluon/docs/tutorials/multimodal/a...      3
657  /home/ci/autogluon/docs/tutorials/multimodal/a...      3
752  /home/ci/autogluon/docs/tutorials/multimodal/a...      3
701  /home/ci/autogluon/docs/tutorials/multimodal/a...      3

[295 rows x 2 columns]
class 0: num samples 200
class 1: num samples 66
class 2: num samples 22
class 3: num samples 7
[0.0239850482815907, 0.07268196448966878, 0.21804589346900635, 0.6852870937597342]
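The inverse-frequency weighting above can also be written as a small reusable function that works for any number of classes. This is a minimal sketch: `inverse_frequency_alpha` is a hypothetical helper (not an AutoGluon API), and it assumes labels are integer class ids `0 .. num_classes - 1`, as in the Shopee data. Because `1 / (count / N) = N / count` and the result is normalized afterwards, dividing by raw counts yields the same weights.

```python
import numpy as np
import pandas as pd


def inverse_frequency_alpha(labels, num_classes):
    """Hypothetical helper: per-class weights proportional to 1 / count, normalized to sum to 1.

    Assumes labels are integer class ids in range(num_classes).
    """
    # Count samples per class, listing every class id even if it is absent.
    counts = pd.Series(labels).value_counts().reindex(range(num_classes), fill_value=0)
    if (counts == 0).any():
        raise ValueError("every class needs at least one training sample")
    inverse = 1.0 / counts.to_numpy(dtype=np.float64)
    # Normalize so the weights sum to 1, matching the tutorial's computation.
    return list(inverse / inverse.sum())
```

For example, `weights = inverse_frequency_alpha(imbalanced_train_data.label, num_classes=4)` should reproduce the weights printed above.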

Create and train MultiModalPredictor

Train with Focal Loss

To train with focal loss, set "optimization.loss_function" to "focal_loss". Three optional parameters further configure the loss:

optimization.focal_loss.alpha - a list of floats giving the per-class loss weights, which can be used to balance an uneven sample distribution across classes. The length of the list must equal the total number of classes in the training dataset. A good choice of alpha for each class is the inverse of its share of the training samples, e.g., the normalized weights computed above.

optimization.focal_loss.gamma - a float that controls how much the loss focuses on hard samples. Larger values put more focus on hard samples.

optimization.focal_loss.reduction - how the per-sample loss values are aggregated. Currently only "mean" and "sum" are supported.
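To make the roles of alpha, gamma, and reduction concrete, here is a minimal NumPy sketch of multiclass focal loss following the formula in Lin et al. (2017). It is an illustration under our own assumptions (raw logits as input, integer labels, alpha indexed by each sample's true class), not AutoMM's internal `FocalLoss` implementation, whose details may differ.

```python
import numpy as np


def focal_loss(logits, labels, alpha=None, gamma=1.0, reduction="mean"):
    """Sketch of multiclass focal loss: FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t).

    logits: (N, C) raw scores; labels: (N,) integer class ids;
    alpha: optional length-C sequence of per-class weights.
    """
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.int64)
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Log-probability and probability of the true class for each sample.
    log_pt = log_probs[np.arange(len(labels)), labels]
    pt = np.exp(log_pt)
    # Modulating factor (1 - p_t)**gamma shrinks the loss of easy, confident samples.
    loss = -((1.0 - pt) ** gamma) * log_pt
    if alpha is not None:
        alpha = np.asarray(alpha, dtype=np.float64)
        if alpha.shape[0] != logits.shape[1]:
            raise ValueError("alpha must contain one weight per class")
        # Weight each sample by the alpha of its true class.
        loss = alpha[labels] * loss
    if reduction == "mean":
        return float(loss.mean())
    if reduction == "sum":
        return float(loss.sum())
    raise ValueError("reduction must be 'mean' or 'sum'")
```

Setting `gamma=0.0` with `alpha=None` recovers plain cross-entropy, which makes a quick sanity check; increasing `gamma` lowers the contribution of samples the model already classifies confidently.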

import uuid
from autogluon.multimodal import MultiModalPredictor

model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee_focal"

predictor = MultiModalPredictor(label="label", problem_type="multiclass", path=model_path)

predictor.fit(
    hyperparameters={
        "model.timm_image.checkpoint_name": "swin_tiny_patch4_window7_224",
        "env.num_gpus": 1,
        "optimization.loss_function": "focal_loss",
        "optimization.focal_loss.alpha": weights,  # one weight per class; the Shopee dataset has 4 classes
        "optimization.focal_loss.gamma": 1.0,
        "optimization.focal_loss.reduction": "sum",
        "optimization.max_epochs": 10,
    },
    train_data=imbalanced_train_data,
) 

predictor.evaluate(test_data, metrics=["acc"])
=================== System Info ===================
AutoGluon Version:  1.1.1b20240716
Python Version:     3.10.13
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Fri May 17 18:07:48 UTC 2024
CPU Count:          8
Pytorch Version:    2.3.1+cu121
CUDA Version:       12.1
Memory Avail:       28.67 GB / 30.95 GB (92.6%)
Disk Space Avail:   179.90 GB / 255.99 GB (70.3%)
===================================================

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/1e52ce01257d4f1094ce2ae022798eb6-automm_shopee_focal
    ```

Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
GPU 0 Name: Tesla T4
GPU 0 Memory: 0.42GB/15.0GB (Used/Total)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                            | Params | Mode 
------------------------------------------------------------------------------
0 | model             | TimmAutoModelForImagePrediction | 95.7 M | train
1 | validation_metric | MulticlassAccuracy              | 0      | train
2 | loss_func         | FocalLoss                       | 0      | train
------------------------------------------------------------------------------
95.7 M    Trainable params
0         Non-trainable params
95.7 M    Total params
382.772   Total estimated model params size (MB)
Epoch 0, global step 2: 'val_accuracy' reached 0.62712 (best 0.62712), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/1e52ce01257d4f1094ce2ae022798eb6-automm_shopee_focal/epoch=0-step=2.ckpt' as top 3
Epoch 1, global step 4: 'val_accuracy' reached 0.96610 (best 0.96610), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/1e52ce01257d4f1094ce2ae022798eb6-automm_shopee_focal/epoch=1-step=4.ckpt' as top 3
Epoch 2, global step 6: 'val_accuracy' reached 0.96610 (best 0.96610), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/1e52ce01257d4f1094ce2ae022798eb6-automm_shopee_focal/epoch=2-step=6.ckpt' as top 3
Epoch 3, global step 8: 'val_accuracy' reached 0.96610 (best 0.96610), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/1e52ce01257d4f1094ce2ae022798eb6-automm_shopee_focal/epoch=3-step=8.ckpt' as top 3
Epoch 4, global step 10: 'val_accuracy' was not in top 3
Epoch 5, global step 12: 'val_accuracy' was not in top 3
Epoch 6, global step 14: 'val_accuracy' was not in top 3
Start to fuse 3 checkpoints via the greedy soup algorithm.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/1e52ce01257d4f1094ce2ae022798eb6-automm_shopee_focal")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
{'acc': 0.925}
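Because the length of alpha must match the number of classes and reduction accepts only "mean" or "sum", it can help to assemble the focal-loss settings in one place and validate them before calling `fit`. This is a minimal sketch: `focal_loss_hyperparameters` is a hypothetical helper, not part of AutoGluon, and it only emits the configuration keys already used in this tutorial.

```python
def focal_loss_hyperparameters(alpha, num_classes, gamma=1.0, reduction="sum"):
    """Hypothetical helper: build and validate the focal-loss entries of an AutoMM hyperparameters dict."""
    alpha = [float(a) for a in alpha]
    # alpha needs exactly one weight per class in the training data.
    if len(alpha) != num_classes:
        raise ValueError(f"alpha has {len(alpha)} entries but there are {num_classes} classes")
    # Only these two aggregation modes are supported for now.
    if reduction not in ("mean", "sum"):
        raise ValueError("reduction must be 'mean' or 'sum'")
    return {
        "optimization.loss_function": "focal_loss",
        "optimization.focal_loss.alpha": alpha,
        "optimization.focal_loss.gamma": float(gamma),
        "optimization.focal_loss.reduction": reduction,
    }
```

The result can be merged with the other settings, for example `hyperparameters={**focal_loss_hyperparameters(weights, num_classes=4), "env.num_gpus": 1, "optimization.max_epochs": 10}`.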

Train without Focal Loss

import uuid
from autogluon.multimodal import MultiModalPredictor

model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee_non_focal"

predictor2 = MultiModalPredictor(label="label", problem_type="multiclass", path=model_path)

predictor2.fit(
    hyperparameters={
        "model.timm_image.checkpoint_name": "swin_tiny_patch4_window7_224",
        "env.num_gpus": 1,
        "optimization.max_epochs": 10,
    },
    train_data=imbalanced_train_data,
)

predictor2.evaluate(test_data, metrics=["acc"])
=================== System Info ===================
AutoGluon Version:  1.1.1b20240716
Python Version:     3.10.13
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Fri May 17 18:07:48 UTC 2024
CPU Count:          8
Pytorch Version:    2.3.1+cu121
CUDA Version:       12.1
Memory Avail:       24.92 GB / 30.95 GB (80.5%)
Disk Space Avail:   179.56 GB / 255.99 GB (70.1%)
===================================================

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/964511d81c0a4025ab5251ae72331139-automm_shopee_non_focal
    ```

Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
GPU 0 Name: Tesla T4
GPU 0 Memory: 0.59GB/15.0GB (Used/Total)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                            | Params | Mode 
------------------------------------------------------------------------------
0 | model             | TimmAutoModelForImagePrediction | 95.7 M | train
1 | validation_metric | MulticlassAccuracy              | 0      | train
2 | loss_func         | CrossEntropyLoss                | 0      | train
------------------------------------------------------------------------------
95.7 M    Trainable params
0         Non-trainable params
95.7 M    Total params
382.772   Total estimated model params size (MB)
Epoch 0, global step 2: 'val_accuracy' reached 0.69492 (best 0.69492), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/964511d81c0a4025ab5251ae72331139-automm_shopee_non_focal/epoch=0-step=2.ckpt' as top 3
Epoch 1, global step 4: 'val_accuracy' reached 0.76271 (best 0.76271), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/964511d81c0a4025ab5251ae72331139-automm_shopee_non_focal/epoch=1-step=4.ckpt' as top 3
Epoch 2, global step 6: 'val_accuracy' reached 0.94915 (best 0.94915), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/964511d81c0a4025ab5251ae72331139-automm_shopee_non_focal/epoch=2-step=6.ckpt' as top 3
Epoch 3, global step 8: 'val_accuracy' reached 0.93220 (best 0.94915), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/964511d81c0a4025ab5251ae72331139-automm_shopee_non_focal/epoch=3-step=8.ckpt' as top 3
Epoch 4, global step 10: 'val_accuracy' reached 0.93220 (best 0.94915), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/964511d81c0a4025ab5251ae72331139-automm_shopee_non_focal/epoch=4-step=10.ckpt' as top 3
Epoch 5, global step 12: 'val_accuracy' was not in top 3
Epoch 6, global step 14: 'val_accuracy' was not in top 3
Epoch 7, global step 16: 'val_accuracy' reached 0.94915 (best 0.94915), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/964511d81c0a4025ab5251ae72331139-automm_shopee_non_focal/epoch=7-step=16.ckpt' as top 3
Start to fuse 3 checkpoints via the greedy soup algorithm.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/964511d81c0a4025ab5251ae72331139-automm_shopee_non_focal")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
{'acc': 0.725}

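As the results show, the model trained with focal loss reaches a test accuracy of 0.925, much higher than the 0.725 of the model trained without it (which uses the default CrossEntropyLoss, as the training log shows). Because the downsampling step uses np.random.choice without a fixed seed, exact numbers may vary between runs. When your data is imbalanced, try focal loss and check whether it improves performance!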
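Overall accuracy can hide how the rare classes fare, so it may be worth checking recall per class when comparing the two predictors. This is a minimal sketch: `per_class_recall` is a hypothetical helper that assumes integer class ids `0 .. num_classes - 1`; scikit-learn's `recall_score(..., average=None)` computes the same quantity if you prefer a library call.

```python
import numpy as np


def per_class_recall(y_true, y_pred, num_classes):
    """Hypothetical helper: fraction of each class's samples that were predicted correctly."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    recalls = []
    for c in range(num_classes):
        mask = y_true == c
        # NaN marks a class that does not appear in y_true.
        recalls.append(float((y_pred[mask] == c).mean()) if mask.any() else float("nan"))
    return recalls
```

For example, `per_class_recall(test_data.label, predictor.predict(test_data), num_classes=4)` and the same call with `predictor2` show whether the gain from focal loss comes mainly from the downsampled classes.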

Citations

@misc{https://doi.org/10.48550/arxiv.1708.02002,
  doi = {10.48550/ARXIV.1708.02002},
  url = {https://arxiv.org/abs/1708.02002},
  author = {Lin, Tsung-Yi and Goyal, Priya and Girshick, Ross and He, Kaiming and Dollár, Piotr},
  keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences},
  title = {Focal Loss for Dense Object Detection},
  publisher = {arXiv},
  year = {2017},
  copyright = {arXiv.org perpetual, non-exclusive license}
}