This repo contains the text data, code and pre-trained models for the paper Improving CLIP Training with Language Rewrites. If you find the data, models or code useful, please consider citing our paper:
```
@inproceedings{fan2023improving,
  title={Improving CLIP Training with Language Rewrites},
  author={Fan, Lijie and Krishnan, Dilip and Isola, Phillip and Katabi, Dina and Tian, Yonglong},
  booktitle={NeurIPS},
  year={2023}
}
```
We propose Language augmented CLIP (LaCLIP). LaCLIP enhances CLIP training by rewriting text descriptions with large language models. Key steps:
- Meta-input-output generation: we explored different strategies for generating meta-input-output pairs to serve as examples in the prompt context for LLaMA in-context learning, namely ChatGPT, Bard, MSCOCO and Human. Example of generating such pairs with ChatGPT:
- In-context learning with LLaMA: using the constructed context as a prompt, LLaMA performs text completion to generate rewritten versions of the corresponding text samples. This process is repeated for each text sample in the pre-training image-text dataset. Example of LLaMA rewriting a text sample:
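The in-context rewriting step above can be sketched as follows. The prompt template and the meta-input-output pairs here are illustrative placeholders, not the exact ones used in the paper:

```python
def build_rewrite_prompt(meta_pairs, caption):
    """Assemble a few-shot prompt for LLaMA-style text completion.

    meta_pairs: list of (source, rewrite) meta-input-output examples
    (e.g. produced with ChatGPT); caption: the caption to be rewritten.
    """
    lines = [f"{src} => {dst}" for src, dst in meta_pairs]
    # The model completes the text after "=>", producing the rewrite.
    lines.append(f"{caption} => ")
    return "\n".join(lines)

# Hypothetical meta-input-output pairs for illustration only.
examples = [
    ("a dog in the park", "A happy dog playing on the grass in a sunny park."),
    ("red car on street", "A bright red car parked along a quiet city street."),
]
prompt = build_rewrite_prompt(examples, "two cats on a sofa")
```

The completion generated after the final `=> ` is taken as the rewritten caption.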
Pre-trained checkpoints and their ImageNet zero-shot top-1 accuracy:

Dataset | Method | ImageNet Zero-Shot Top-1 (%) | Checkpoint |
---|---|---|---|
CC3M | CLIP | 15.8 | ViT-B/16 |
CC3M | LaCLIP | 21.5 | ViT-B/16 |
CC12M | CLIP | 40.2 | ViT-B/16 |
CC12M | LaCLIP | 48.4 | ViT-B/16 |
RedCaps | CLIP | 42.9 | ViT-B/16 |
RedCaps | LaCLIP | 46.2 | ViT-B/16 |
LAION-400M | CLIP | 62.0 | ViT-B/32 |
LAION-400M | LaCLIP | 64.4 | ViT-B/32 |
LAION-400M | CLIP | 67.0 | ViT-B/16 |
LAION-400M | LaCLIP | 69.3 | ViT-B/16 |
LAION-400M | CLIP | 71.8 | ViT-L/14 |
LAION-400M | LaCLIP | 74.5 | ViT-L/14 |
In this repo, we provide:
- Code for generating rewrites of text samples
- 4 versions of augmented text on 3 datasets (CC3M, CC12M, RedCaps)
- Pre-trained models with LaCLIP and vanilla CLIP
- Zero-shot evaluation code on ImageNet
- Code for training LaCLIP
Requirements:
- PyTorch 1.11.0
- torchvision 0.12.0
- timm 0.5.4
- open_clip (optional, for LAION-400M models)
- LLaMA (for generating rewrites)
We provide four versions of rewritten text for each dataset:
- Original: the original caption associated with each image.
- ChatGPT/Bard/MSCOCO/Human: text generated by LLaMA in-context learning, using the corresponding ChatGPT/Bard/MSCOCO/Human meta-input-output pairs as in-context examples.
Dataset | Original | ChatGPT | Bard | MSCOCO | Human |
---|---|---|---|---|---|
CC3M | Link | Link | Link | Link | Link |
CC12M | Link | Link | Link | Link | Link |
RedCaps | Link | Link | Link | Link | Link |
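During LaCLIP training, each image's text view is sampled from its original caption and the rewritten versions. A minimal sketch of that sampling step (the example captions are hypothetical):

```python
import random

def sample_caption(caption_versions, rng=random):
    """Pick one text view per image: the original caption or one of its
    LLM rewrites, chosen uniformly at random. This is the text
    augmentation that distinguishes LaCLIP from vanilla CLIP training."""
    return rng.choice(caption_versions)

# One original caption plus rewrites (illustrative examples).
versions = [
    "a dog in the park",                      # original
    "A happy dog playing in a sunny park.",   # ChatGPT-style rewrite
    "Dog enjoying the park on a clear day.",  # Bard-style rewrite
]
caption = sample_caption(versions)
```

A fresh sample is drawn each time an image is seen, so different epochs pair the image with different text views.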
To generate rewrites for other datasets of your own interest, we provide the rewriting code in the `rewrite` folder.
Please refer to the Meta LLaMA page for detailed instructions on model access and environment setup.
The rewritten text can be generated by running the following command:

```
export LLAMA_FOLDER=/PATH/TO/LLAMA/WEIGHTS
export PYTHONPATH=/PATH/TO/LLAMA/
export model='7b'
torchrun --nproc_per_node 1 --master_port 12388 \
    llama_rewrite.py --ckpt_dir ${LLAMA_FOLDER}/${model} --tokenizer_path ${LLAMA_FOLDER}/${model}/tokenizer.model \
    --max_batch_size 100 --max_seq_len 400 --prompt_filename text/source.txt --output_filename text/target.txt --sample_mode chatgpt --temperature 0.9
```
- `--prompt_filename`: text file to be rewritten, one sentence per line
- `--output_filename`: output path
- `--sample_mode`: sample mode for in-context learning (`chatgpt`, `bard`, `mscoco`, or `human`)
- `--temperature`: temperature for sampling; higher temperature leads to more diverse text
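`--prompt_filename` expects one caption per line. A quick way to produce such a file from a training CSV is sketched below; the two-column `path,caption` layout is assumed to match the format used by `--train-data`, and the file names are hypothetical:

```python
import csv

def dump_captions(csv_path, out_path):
    """Write the caption column of a `path,caption` CSV to a text file
    with one sentence per line, ready to pass as --prompt_filename."""
    with open(csv_path, newline="") as f_in, open(out_path, "w") as f_out:
        for row in csv.reader(f_in):
            caption = row[1]  # row[0] is the relative image path
            f_out.write(caption.strip() + "\n")
```

The output of `llama_rewrite.py` then has the same line order as the input, one rewrite per source sentence.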
To perform zero-shot evaluation on ImageNet, use the following commands.

For CC3M, CC12M and RedCaps models:

```
python eval_zeroshot_imagenet.py --imagenet-root [PATH_TO_IMAGENET] --ckpt-path [PATH_TO_CHECKPOINT] --model CLIP_VITB16 --batch-size 128 --workers 8
```

For LAION-400M models:

```
python eval_zeroshot_imagenet_laion.py --imagenet-root [PATH_TO_IMAGENET] --ckpt-path [PATH_TO_CHECKPOINT] --model [ViT-B-32, ViT-B-16 or ViT-L-14] --batch-size 128 --workers 8
```
Add `--quickgelu` for ViT-L-14 models.
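Zero-shot evaluation follows the standard CLIP recipe: embed a text prompt per class, then classify each image by cosine similarity between image and text features. A self-contained sketch, with random vectors standing in for the actual encoder outputs:

```python
import numpy as np

def zero_shot_classify(image_feats, class_text_feats):
    """Standard CLIP zero-shot step: L2-normalize both feature sets and
    assign each image to the class with the highest cosine similarity."""
    img = image_feats / np.linalg.norm(image_feats, axis=1, keepdims=True)
    txt = class_text_feats / np.linalg.norm(class_text_feats, axis=1, keepdims=True)
    logits = img @ txt.T          # (num_images, num_classes)
    return logits.argmax(axis=1)  # predicted class index per image

# Synthetic features: 10 classes, 64-dim; two images near classes 3 and 7.
rng = np.random.default_rng(0)
txt = rng.normal(size=(10, 64))
img = txt[[3, 7]] + 0.01 * rng.normal(size=(2, 64))
preds = zero_shot_classify(img, txt)
```

In the real evaluation scripts, the text features come from prompt templates over the ImageNet class names and the image features from the trained vision encoder.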
To train LaCLIP, use the following command:

```
torchrun --nproc_per_node=GPU_PER_NODE --nnodes=NUM_NODE --node_rank=NODE_RANK \
    --master_addr=MASTER_NODE --master_port=PORT \
    train.py \
    --train-data PATH/TO/TRAINING/CSV \
    --root PATH/TO/TRAINING/IMAGE/ROOT \
    --imagenet-root PATH/TO/IMAGENET/ROOT \
    --aug-text --augmented_caption_filelist PATH/TO/AUGMENTED/CAPTION/FILES \
    --output-dir PATH/TO/OUTPUT \
    --model CLIP_VITB16 \
    --lr 1e-3 --wd 0.5 --warmup-epochs 1 --batch-size 256 --epochs 35
```
- `--train-data`: CSV file for training data; each line is one image-text pair, with the relative image path and the original caption separated by a comma
- `--root`: root dir for images
- `--imagenet-root`: root dir for ImageNet, used for zero-shot evaluation
- `--aug-text`: whether to use augmented text
- `--augmented_caption_filelist`: text files for augmented text, one sentence per line; the order of the sentences should be the same as the order of the images in `--train-data`. Separate multiple augmented text files with a space.
- `--output-dir`: saving dir for logs and checkpoints
- `--model`: CLIP backbone architecture
- Make sure the sample order in `--augmented_caption_filelist` is the same as the order in `--train-data`.
- Please refer to Table A3 in the paper for the hyperparameters used for each dataset.
- To train vanilla CLIP, comment out the `--aug-text` and `--augmented_caption_filelist` arguments.
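The ordering requirement can at least be partially sanity-checked by verifying that each augmented-caption file has exactly one line per row of the training CSV. A small sketch, with hypothetical file names and the `path,caption` CSV layout assumed from the flag descriptions:

```python
import csv

def check_alignment(train_csv, aug_files):
    """Verify each augmented-caption file has exactly one line per row
    of the training CSV, so sample order can line up one-to-one."""
    with open(train_csv, newline="") as f:
        n_rows = sum(1 for _ in csv.reader(f))
    for path in aug_files:
        with open(path) as f:
            n_lines = sum(1 for _ in f)
        assert n_lines == n_rows, f"{path}: {n_lines} lines vs {n_rows} rows"
```

A length match does not prove the rewrites are in the right order, but a mismatch always means the files are misaligned.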