This repository contains the official PyTorch implementation of our paper "PlainMamba: Improving Non-Hierarchical Mamba in Visual Recognition".
Our classification codebase is built upon an older version of the MMClassification toolkit.
conda create -n plain_mamba python=3.10 -y
source activate plain_mamba
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 -f https://download.pytorch.org/whl/torch_stable.html --no-cache
conda install -c conda-forge cudatoolkit-dev # Optional: only needed if you run into CUDA errors
pip install -U openmim
mim install mmcv-full
pip install mamba-ssm
pip install mlflow fvcore timm lmdb
cd plain_mamba
pip install -e .
cd downstream/mmdetection # set up object detection and instance segmentation
pip install -e .
cd ../mmsegmentation # set up semantic segmentation
pip install -e .
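Once the commands above finish, one way to confirm that the key packages are present is to query their installed versions. The following is a minimal sketch that uses only the standard library's importlib.metadata. The package list is transcribed from the pip/mim commands above, the helper name installed_versions is hypothetical, and the script does not check CUDA compatibility.

```python
"""Report installed versions of the packages the setup commands install.

Hypothetical helper: the distribution names mirror the pip/mim commands in
this README; adjust the list if your environment differs.
"""
from importlib import metadata
from typing import Dict, Iterable, Optional

# Distribution names as published on PyPI (assumed from the install commands).
REQUIRED = ("torch", "torchvision", "mmcv-full", "mamba-ssm",
            "timm", "lmdb", "fvcore", "mlflow")


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(REQUIRED).items():
        print(f"{name:12s} {version or 'MISSING'}")
```

Any package reported as MISSING points back to the corresponding install step.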
For the ImageNet experiments, we convert the dataset to LMDB format for efficient data loading. You can convert the dataset by running:
python tools/dataset_tools/create_lmdb_dataset.py \
--train-img-dir data/imagenet/train \
--train-out data/imagenet/imagenet_lmdb/train \
--val-img-dir data/imagenet/val \
--val-out data/imagenet/imagenet_lmdb/val
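To sanity-check the conversion, it can help to know how many images the source directories hold. Below is a minimal standard-library sketch, assuming the images are ordinary files somewhere under the directories passed to --train-img-dir and --val-img-dir. The extension list and the function name count_images are assumptions for illustration, not part of create_lmdb_dataset.py.

```python
import os
from typing import Iterable

# Assumed image extensions; ImageNet files are normally JPEGs.
IMAGE_EXTS = (".jpeg", ".jpg", ".png")


def count_images(root: str, exts: Iterable[str] = IMAGE_EXTS) -> int:
    """Recursively count files under `root` whose extension is in `exts`.

    Returns 0 if `root` does not exist (os.walk yields nothing).
    """
    wanted = tuple(e.lower() for e in exts)
    total = 0
    for _dirpath, _dirnames, filenames in os.walk(root):
        total += sum(1 for f in filenames if f.lower().endswith(wanted))
    return total


if __name__ == "__main__":
    for split in ("train", "val"):
        path = os.path.join("data", "imagenet", split)
        print(split, count_images(path))
```

For the standard ImageNet-1K (ILSVRC2012) release, the expected counts are 1,281,167 training and 50,000 validation images. How these map to LMDB entries depends on how the conversion script stores them.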
You will also need to download the ImageNet metadata from Link.
For downstream tasks, please follow the MMDetection and MMSegmentation documentation to prepare your datasets.
After setup, the dataset directory structure should look as follows:
PlainMamba
|-- ...
|-- data
|   |__ imagenet
|       |-- imagenet_lmdb
|       |   |-- train
|       |   |   |-- data.mdb
|       |   |   |__ lock.mdb
|       |   |-- val
|       |   |   |-- data.mdb
|       |   |   |__ lock.mdb
|       |__ meta
|           |__ ...
|__ downstream
    |-- mmsegmentation
    |   |-- ...
    |   |__ data
    |       |__ ade
    |           |__ ADEChallengeData2016
    |               |-- annotations
    |               |   |__ ...
    |               |-- images
    |               |   |__ ...
    |               |-- objectInfo150.txt
    |               |__ sceneCategories.txt
    |
    |__ mmdetection
        |-- ...
        |__ data
            |__ coco
                |-- train2017
                |   |__ ...
                |-- val2017
                |   |__ ...
                |__ annotations
                    |-- instances_train2017.json
                    |-- instances_val2017.json
                    |__ ...
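The directory tree above can be checked programmatically before launching any jobs. This is a minimal pathlib sketch, assuming it is run from the PlainMamba root. The expected paths are transcribed from the tree (the "..." placeholders are omitted), and the helper name missing_paths is hypothetical.

```python
from pathlib import Path
from typing import List

# Paths transcribed from the directory tree, relative to the repository root.
EXPECTED = [
    "data/imagenet/imagenet_lmdb/train/data.mdb",
    "data/imagenet/imagenet_lmdb/train/lock.mdb",
    "data/imagenet/imagenet_lmdb/val/data.mdb",
    "data/imagenet/imagenet_lmdb/val/lock.mdb",
    "data/imagenet/meta",
    "downstream/mmsegmentation/data/ade/ADEChallengeData2016/annotations",
    "downstream/mmsegmentation/data/ade/ADEChallengeData2016/images",
    "downstream/mmsegmentation/data/ade/ADEChallengeData2016/objectInfo150.txt",
    "downstream/mmsegmentation/data/ade/ADEChallengeData2016/sceneCategories.txt",
    "downstream/mmdetection/data/coco/train2017",
    "downstream/mmdetection/data/coco/val2017",
    "downstream/mmdetection/data/coco/annotations/instances_train2017.json",
    "downstream/mmdetection/data/coco/annotations/instances_val2017.json",
]


def missing_paths(root: str = ".") -> List[str]:
    """Return the expected paths that do not exist under `root`."""
    base = Path(root)
    return [p for p in EXPECTED if not (base / p).exists()]


if __name__ == "__main__":
    missing = missing_paths()
    if missing:
        print("Missing:")
        for p in missing:
            print("  " + p)
    else:
        print("Dataset layout looks complete.")
```

The script only checks that the paths exist. It does not validate file contents.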
# Example: Training PlainMamba-L1 model
zsh tools/dist_train.sh plain_mamba_configs/plain_mamba_l1_in1k_300e.py 8
# Example: Testing PlainMamba-L1 model
zsh tools/dist_test.sh plain_mamba_configs/plain_mamba_l1_in1k_300e.py work_dirs/plain_mamba_l1_in1k_300e/epoch_300.pth 8 --metrics accuracy
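The test command reads the checkpoint from work_dirs/<config name>/epoch_300.pth. This matches the usual OpenMMLab default of writing to ./work_dirs/<config file name without .py> when no --work-dir is given. If training stopped early, the newest checkpoint may have a smaller epoch number. The sketch below, with hypothetical helper names, derives the default work directory from a config path and picks the highest-numbered epoch_N.pth in it.

```python
import re
from pathlib import Path
from typing import Optional

# Matches checkpoint names such as epoch_300.pth.
_EPOCH_CKPT = re.compile(r"^epoch_(\d+)\.pth$")


def default_work_dir(config: str) -> Path:
    """work_dirs/<config file name without extension>, assumed default work dir."""
    return Path("work_dirs") / Path(config).stem


def latest_epoch_checkpoint(work_dir: Path) -> Optional[Path]:
    """Return the epoch_N.pth with the largest N in work_dir, or None if none exist."""
    if not work_dir.is_dir():
        return None
    best_epoch, best_path = -1, None
    for path in work_dir.iterdir():
        match = _EPOCH_CKPT.match(path.name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), path
    return best_path


if __name__ == "__main__":
    work_dir = default_work_dir("plain_mamba_configs/plain_mamba_l1_in1k_300e.py")
    print(work_dir, latest_epoch_checkpoint(work_dir))
```

Parsing the epoch number avoids plain string sorting, which would place epoch_99.pth after epoch_300.pth.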
For object detection and instance segmentation, run `cd downstream/mmdetection` from the PlainMamba root directory first.
# Example: Training PlainMamba-Adapter-L1 Mask R-CNN with 1x schedule
zsh tools/dist_train.sh plain_mamba_det_configs/maskrcnn/l1_maskrcnn_1x.py 8
# Example: Training PlainMamba-Adapter-L1 RetinaNet with 1x schedule
zsh tools/dist_train.sh plain_mamba_det_configs/retinanet/l1_retinanet_1x.py 8
# Example: Testing PlainMamba-Adapter-L1 Mask R-CNN 1x model
zsh tools/dist_test.sh plain_mamba_det_configs/maskrcnn/l1_maskrcnn_1x.py work_dirs/l1_maskrcnn_1x/epoch_12.pth 8 --eval bbox segm
# Example: Testing PlainMamba-Adapter-L1 RetinaNet 1x model
zsh tools/dist_test.sh plain_mamba_det_configs/retinanet/l1_retinanet_1x.py work_dirs/l1_retinanet_1x/epoch_12.pth 8 --eval bbox
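The two test commands differ only in the checkpoint epoch and the --eval metrics. Mask R-CNN predicts boxes and masks, so it is evaluated with bbox and segm, while RetinaNet only predicts boxes. In MMDetection, a 1x schedule is 12 epochs, which is why the checkpoint is epoch_12.pth. The sketch below rebuilds these commands as argument lists. It assumes config names of the form <size>_<detector>_<schedule>.py and the default work_dirs location, and its helper name is hypothetical.

```python
import shlex
from pathlib import Path
from typing import List

# MMDetection schedule names and their epoch counts.
SCHEDULE_EPOCHS = {"1x": 12, "2x": 24, "3x": 36}

# --eval metrics: Mask R-CNN outputs boxes and masks, RetinaNet only boxes.
EVAL_METRICS = {"maskrcnn": ["bbox", "segm"], "retinanet": ["bbox"]}


def det_test_command(config: str, gpus: int = 8) -> List[str]:
    """Build dist_test.sh arguments for a config named like l1_maskrcnn_1x.py."""
    stem = Path(config).stem                  # e.g. "l1_maskrcnn_1x"
    _size, detector, schedule = stem.split("_")
    ckpt = f"work_dirs/{stem}/epoch_{SCHEDULE_EPOCHS[schedule]}.pth"
    return (["zsh", "tools/dist_test.sh", config, ckpt, str(gpus), "--eval"]
            + EVAL_METRICS[detector])


if __name__ == "__main__":
    for cfg in ("plain_mamba_det_configs/maskrcnn/l1_maskrcnn_1x.py",
                "plain_mamba_det_configs/retinanet/l1_retinanet_1x.py"):
        print(shlex.join(det_test_command(cfg)))
```

Returning a list rather than a single string makes the result safe to pass directly to subprocess.run without shell quoting issues.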
For semantic segmentation, run `cd downstream/mmsegmentation` from the PlainMamba root directory first.
# Example: Training PlainMamba-L1 based UperNet
zsh tools/dist_train.sh plain_mamba_seg_configs/l1_upernet.py 8
# Example: Testing PlainMamba-L1 based UperNet
zsh tools/dist_test.sh plain_mamba_seg_configs/l1_upernet.py work_dirs/l1_upernet/iter_160000.pth 8 --eval mIoU
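The --eval mIoU option reports mean intersection-over-union: for each class, IoU = TP / (TP + FP + FN), averaged over classes. MMSegmentation accumulates intersections and unions over the whole validation set before averaging. This sketch mirrors that by building a single confusion matrix over flattened label maps. It assumes 255 as the ignore label (MMSegmentation's usual default), skips classes that never appear in either ground truth or predictions, and uses hypothetical function names. It is an illustration of the metric, not MMSegmentation's code.

```python
from typing import List, Sequence


def confusion_matrix(gt: Sequence[int], pred: Sequence[int], num_classes: int,
                     ignore_index: int = 255) -> List[List[int]]:
    """Count (ground-truth, prediction) pairs over flattened label maps."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for g, p in zip(gt, pred):
        if g == ignore_index:
            continue  # unlabeled pixels do not contribute
        cm[g][p] += 1
    return cm


def mean_iou(cm: List[List[int]]) -> float:
    """Mean over classes of TP / (TP + FP + FN); classes with empty union are skipped."""
    n = len(cm)
    ious = []
    for c in range(n):
        tp = cm[c][c]
        fn = sum(cm[c]) - tp                        # class-c pixels predicted as other classes
        fp = sum(cm[r][c] for r in range(n)) - tp   # other pixels predicted as class c
        union = tp + fp + fn
        if union > 0:
            ious.append(tp / union)
    return sum(ious) / len(ious) if ious else float("nan")


if __name__ == "__main__":
    cm = confusion_matrix([0, 0, 1, 1, 2, 255], [0, 1, 1, 1, 2, 0], num_classes=3)
    print(f"mIoU: {100 * mean_iou(cm):.1f}")  # reported as a percentage
```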
ImageNet-1K classification:
Model | #Params (M) | Top-1 Acc (%) | Top-5 Acc (%) | Config | Checkpoint |
---|---|---|---|---|---|
PlainMamba-L1 | 7.3 | 77.9 | 94.0 | Link | Link |
PlainMamba-L2 | 25.7 | 81.6 | 95.6 | Link | Link |
PlainMamba-L3 | 50.5 | 82.3 | 95.9 | Link | Link |
COCO object detection and instance segmentation with Mask R-CNN (1x schedule):
Model | #Params (M) | AP Box | AP Mask | Config | Checkpoint |
---|---|---|---|---|---|
PlainMamba-Adapter-L1 | 31 | 44.1 | 39.1 | Link | Link |
PlainMamba-Adapter-L2 | 53 | 46.0 | 40.6 | Link | Link |
PlainMamba-Adapter-L3 | 79 | 46.8 | 41.2 | Link | Link |
COCO object detection with RetinaNet (1x schedule):
Model | #Params (M) | AP Box | Config | Checkpoint |
---|---|---|---|---|
PlainMamba-Adapter-L1 | 19 | 41.7 | Link | Link |
PlainMamba-Adapter-L2 | 40 | 43.9 | Link | Link |
PlainMamba-Adapter-L3 | 67 | 44.8 | Link | Link |
ADE20K semantic segmentation with UperNet (160k iterations):
Model | #Params (M) | mIoU | Config | Checkpoint |
---|---|---|---|---|
PlainMamba-L1 | 35 | 44.1 | Link | Link |
PlainMamba-L2 | 55 | 46.8 | Link | Link |
PlainMamba-L3 | 81 | 49.1 | Link | Link |
@misc{yang2024plainmamba,
title={PlainMamba: Improving Non-Hierarchical Mamba in Visual Recognition},
author={Chenhongyi Yang and Zehui Chen and Miguel Espinosa and Linus Ericsson and Zhenyu Wang and Jiaming Liu and Elliot J. Crowley},
year={2024},
eprint={2403.17695},
archivePrefix={arXiv},
primaryClass={cs.CV}
}