Xuehai He¹ Diji Yang¹ Weixi Feng² Tsu-Jui Fu² Arjun Akula³ Varun Jampani³
Pradyumna Narayana³ Sugato Basu³ William Yang Wang² Xin Eric Wang¹
¹UC Santa Cruz, ²UC Santa Barbara, ³Google
Figure 2: The counterfactual prompt learning framework. We freeze the vision encoder F and the text encoder
G, and only optimize the task-agnostic prompts and the instance-conditioned net M (blue blocks). Please refer to
Section 3.2 for the explanation.
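Figure 2 keeps both encoders frozen, so training only changes the prompt that enters the text encoder. The pure-Python sketch below illustrates the prediction rule of Eq. 1. It assumes Eq. 1 has the standard CLIP/CoOp form: a softmax over the cosine similarities between the image feature and each class's text feature, divided by a temperature τ. The feature lists are toy stand-ins for encoder outputs. The function names and τ values are illustrative and do not come from the paper.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def class_probabilities(image_feat, text_feats, tau=0.01):
    """Assumed form of Eq. 1: softmax over cosine(v, G(t_c)) / tau.

    image_feat stands in for the frozen vision encoder's output; text_feats holds
    one vector per class, standing in for the frozen text encoder G applied to
    the prompt t_c. In prompt tuning only the prompt would change; nothing is
    trained in this toy example.
    """
    logits = [cosine(image_feat, t) / tau for t in text_feats]
    m = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(probs, label):
    """Cross-entropy of the predicted distribution against a one-hot label."""
    return -math.log(probs[label])

# Toy example: two classes; the image feature is closer to class 0's text feature.
v = [1.0, 0.0]
text_feats = [[0.9, 0.1], [0.0, 1.0]]
probs = class_probabilities(v, text_feats, tau=0.1)
loss = cross_entropy(probs, label=0)
```

A smaller τ sharpens the distribution over classes. A larger τ flattens it.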
In Eq. 1, $\tau$ is the temperature parameter and $\langle \cdot, \cdot \rangle$ denotes the cosine similarity. The cross-entropy loss is then minimized, and the gradients are back-propagated through the text encoder $G$ to update the learnable prompt representation $p$. The weights of CLIP remain frozen throughout training. During inference, Eq. 1 is used to compute the probability for each class.

3.2 Method Overview

An overview of the Counterfactual Prompt Learning (CPL) framework is shown in Figure 2. For pre-processing, we construct task-relevant prompts for all training samples. The goal is to optimize the task-agnostic prompt $p$.² During training, given a positive image-prompt pair, we first perform text-based negative sampling to find the most semantically similar negative sample based on text similarity scores. We then adopt a controllable counterfactual generation strategy to construct the counterfactual from the positive and negative samples in the visual feature space. Finally, we perform contrastive learning on both the generated counterfactual image features and the factual image features in a joint optimization framework to fine-tune the task-agnostic prompt $p$. This allows the model to understand non-spurious semantic information and learn generalized prompt representations.

² Together with the instance-conditional net $M$ as introduced in Zhou et al. (2022). For simplicity, we write only $p$ hereafter, as $p$ and $M$ are always optimized together.

3.3 Controllable Counterfactual Generation

If we view the image feature $v$ as a potential cause of the label, a non-spurious feature should be a sufficient cause of the label. We therefore generate counterfactuals by identifying the minimal non-spurious feature change that causes the label to change. The counterfactual construction process is illustrated in Figure 3. Given positive image features $v$ and negative image features $v^{-}$, we generate negative counterfactual image features $v'$ as follows:

$$v' = (1 - u) \circ v + u \circ v^{-}, \tag{2}$$

where $\circ$ is element-wise multiplication and $u$ is the parameter controlling how much of the negative image feature replaces the positive image feature. The negative image features are extracted from images that are semantically similar to the original image, as described in Section 3.4.

To capture non-spuriousness, we construct counterfactuals by replacing only the essential non-spurious features. This can be achieved by minimizing the amount of feature change $u^{*}$ to the original image that can causally incur a label
change:

$$\min_{u} \; \|u^{*}\|_{1} \quad \text{s.t.} \quad u^{*} = \arg\max_{u} D_{c^{-}}(v'). \tag{3}$$

Figure 3: Counterfactual generation process. $v$ and $c$ are the positive image feature and label, while $v^{-}$ and $c^{-}$ are the negative image feature and label; $\circ$ is element-wise multiplication. By mixing $v$ and $v^{-}$, the counterfactual image feature $v'$ is predicted as the negative label $c^{-}$ by the discriminator $D$. $u$ is minimized so that only a minimal change to the positive image feature is made to causally change the label.

Given the factual and counterfactual features $v$ and $v'$, we aim to learn a prompt that helps CLIP better align the visual features $v$ and the textual features $G(t)$ that share the same semantic meaning. This can be achieved by maximizing the mutual information (MI) between $v$ and $G(t)$: minimizing the InfoNCE loss (Hjelm et al., 2018) maximizes a lower bound on $\mathrm{MI}(v, G(t))$. To this end, we define the contrastive objective based on the InfoNCE estimator following Khosla et al. (2020):

$$\mathcal{L}_{CL}(p, u^{*}) = -\log \frac{e^{S(v, G(t))/\tau}}{e^{S(v, G(t))/\tau} + e^{S(v', G(t))/\tau}}, \tag{4}$$

where $S(\cdot, \cdot)$ is typically the cosine similarity function and $\tau$ is the temperature value.

3.4 Text-based Negative Sampling

We now discuss how to perform negative sampling to construct counterfactual features. As suggested by Robinson et al. (2020), good negative samples have different labels and are difficult to distinguish from an anchor point, while their semantic representations are close (Suresh and Ong, 2021). Since not all negative samples can serve as useful negatives (Chuang et al., 2020), indiscriminately using these data may harm model robustness and algorithm efficiency. Therefore, in each training batch, we use only the most semantically similar sample to generate counterfactual image features; the other image samples are filtered out.

Semantic concepts may be highly complex in visual representations, which makes it hard to measure semantic similarity directly in the visual space. Language, by contrast, is more expressive and naturally preserves semantic meaning. We therefore propose a text-based negative sampling method. We first measure the text similarity between prompts with BERTScore (Zhang et al., 2019), which computes pairwise cosine similarity between reference and candidate sentences using BERT contextual embeddings (Devlin et al., 2019). We compute a similarity matrix whose elements are

$$\mathrm{sim}(i, j) = \mathrm{BERTScore}(h_i, h_j). \tag{5}$$

Denote by $B$ the collection of sampled instances. During training, each prompt $h_c \in B$ ($1 \le c \le C$, where $C$ is the number of sampled instances) can be treated as a query. Given a query prompt $h_q$, its most semantically similar prompt $h_k$ (the one with the highest BERTScore) is retrieved from $B$. We then use the CLIP vision encoder to obtain the features $v$ and $v^{-}$ of the corresponding positive and negative images.

3.5 Joint Optimization

In addition to the contrastive learning loss introduced in Eq. 4, we also adopt the standard cross-entropy loss for training:

$$\mathcal{L}_{CE}(p) = -\sum_{c} y^{c} \log p(t_c \mid x), \tag{6}$$

where $y^{c}$ denotes the one-hot ground-truth annotation of the label. We treat all downstream tasks in this work as classification tasks, in which the model predicts whether the image and text prompt pair is matched.

The task-agnostic prompt $p$ is then learned by minimizing a weighted combination of the contrastive learning loss and the cross-entropy loss:

$$\mathcal{L}(p) = \mathcal{L}_{CE}(p) + \lambda \cdot \mathcal{L}_{CL}(p, u^{*}), \tag{7}$$

where $\lambda$ determines the weight of $\mathcal{L}_{CL}$.

We can further combine Eq. 3 and Eq. 7 into a single-stage optimization framework. The intuition is that we generate counterfactual image
features with minimal feature change that maximize the negative prediction probability and, at the same time, use contrastive learning to learn a prompt that guides CLIP to explicitly distinguish between factual and counterfactual images. Putting all the pieces together, we have:

$$\min_{p,\,u} \; \mathcal{L}_{CE}(p) + \lambda \cdot \mathcal{L}_{CL}(p, u^{*}) + \|u^{*}\|_{1} \quad \text{s.t.} \quad u^{*} = \arg\max_{u} D_{c^{-}}(v'), \quad \text{where } v' = (1 - u) \circ v + u \circ v^{-}. \tag{8}$$

In Eq. 8, the gradients can be back-propagated all the way through the text encoder $G$ to the task-agnostic prompt, so the rich knowledge encoded in the pre-trained CLIP model is used to optimize the prompt.

Algorithm 1 presents the learning algorithm of CPL. In summary, given a few input training samples $\{(x_1, y_1), \ldots, (x_n, y_n)\}$, CPL consists of three main steps: (1) compute the similarity matrix between the text prompts within the sampled batch; (2) generate counterfactual image features; and (3) optimize $p$ and $u$ with the contrastive learning loss and the cross-entropy loss.

Algorithm 1 Counterfactual Prompt Learning
  $X$: image space
  $Y$: label space
  $h_c$: task-relevant prompt for the $c$-th class
  $H$: the set of task-relevant prompts
  $p$: the task-agnostic prompt
  $v$: image features
  $v^{-}$: negative image features
  $u$: parameter controlling the generation of counterfactual image features
  function CPL($X$, $Y$)
      $H \leftarrow Y$
      $t_c \leftarrow [p, h_c]$
      for each $i, j$ do
          $\mathrm{sim}(i, j) = \mathrm{BERTScore}(h_i, h_j)$    ▷ Eq. 5
      end for
      for $q$ in the batch do
          $v \leftarrow v_q$
          Find the index $k \neq q$ that maximizes $\mathrm{sim}(q, k)$
          $v^{-} \leftarrow v_k$
          Generate counterfactual image features    ▷ Eq. 2
          $\mathcal{L}_{CE} \leftarrow$ cross-entropy loss    ▷ Eq. 6
          $\mathcal{L}_{CL} \leftarrow$ contrastive loss    ▷ Eq. 4
          Update $p$ and $u$ with the joint optimization loss    ▷ Eq. 7
      end for
  end function

3.6 Task-relevant Prompt Construction

We construct task-relevant prompts $H$ for image classification, image-text retrieval, and visual question answering, respectively. For image classification, the prompts are the class labels of each task. For image-text retrieval, the captions of each image are adopted as prompts. For visual question answering, we first use a pre-trained generative T5 model (Raffel et al., 2019) to convert the question-answer pairs into declarative sentences, following the VQA prompt generation method proposed in Song et al. (2022b). Then, motivated by Wei et al. (2022), we add category information, generated from templates based on the question type, to the prompt to help the model perform intermediate reasoning steps. Specifically, we prepend "The question is asking about others" to the generated declarative sentence for Other questions. Similarly, "The question is asking about yes or no" and "The question is asking about numbers" are prepended for Yes/No and Number questions, respectively.

4 Experiments

4.1 Tasks and Datasets

Image Classification. We employ seven publicly available image classification datasets used in CLIP: SUN397 (Xiao et al., 2010), Caltech101 (Griffin et al., 2007), ImageNet (Deng et al., 2009), OxfordPets (Parkhi et al., 2012), StanfordCars (Krause et al., 2013), Flowers102 (Nilsback and Zisserman, 2008), and Food101 (Bossard et al., 2014). These datasets constitute a comprehensive benchmark covering a diverse set of vision tasks, including generic object classification, fine-grained image recognition, and scene recognition. To evaluate the generalization ability of the methods, we split each dataset into seen and unseen classes, and only images from the seen classes are used for training. The setting follows the few-shot evaluation protocol in CLIP: we use 16 shots for training and the full test sets for testing.

Image-Text Retrieval. We consider two datasets for image-text retrieval: MSCOCO (Lin et al., 2014) and Flickr30K (Plummer et al., 2015). We adopt the widely used Karpathy split (Karpathy and Fei-Fei, 2015) for both datasets. Under this split, MSCOCO contains 113K/5K/5K images for train/validation/test and Flickr30K contains 29K/1K/1K images for train/validation/test. We construct few-shot subsets for both CoCoOp and CPL by taking 0.5%, 1%, and 3% of the training instances. We train the model on these subsets and evaluate its performance on the complete
test set. We use Recall at 1 (R@1) as the default evaluation metric.

Visual Question Answering. VQAv2 (Goyal et al., 2017b) extends the VQA dataset (Antol et al., 2015). The questions are categorized into three types: Number, Yes/No, and Other. We set up the experiments following Anderson et al. (2018), who treat visual question answering as a classification problem. For each question, the model picks the corresponding answer from a predefined set of the most frequent candidate answers and matches it with the image. The questions are first converted into a masked template using the pre-trained T5 model and predefined rules. The infilled template, together with the question, is turned into a prompt that naturally connects the question and the answer. The model then predicts whether a given prompt-image pair is matched. We construct the few-shot setting by taking 0.5%, 1%, and 3% of the instances for training.

Classes | Method | SUN397 | Caltech101 | ImageNet | OxfordPets | StanfordCars | Flowers102 | Food101 | Average
Seen | CLIP | 69.40 | 96.51 | 72.46 | 91.33 | 74.85 | 72.17 | 90.12 | 80.98
Seen | CoCoOp | 79.08 [+13.95] | 97.66 [+1.19] | 76.01 [+4.90] | 95.18 [+4.22] | 70.91 [-5.26] | 94.65 [+31.15] | 90.67 [+0.61] | 86.31 [+6.58]
Seen | CPL (ours) | 81.05 [+16.79] | 97.70 [+1.23] | 78.81 [+8.76] | 96.69 [+5.87] | 75.51 [+0.88] | 93.91 [+30.12] | 93.01 [+3.21] | 88.10 [+8.79]
Unseen | CLIP | 75.40 | 94.10 | 68.09 | 97.04 | 74.95 | 77.87 | 91.30 | 82.68
Unseen | CoCoOp | 76.83 [+1.90] | 93.92 [-0.19] | 70.44 [+3.45] | 97.78 [+0.76] | 73.09 [-2.48] | 69.24 [-11.08] | 91.53 [+0.25] | 81.83 [-1.02]
Unseen | CPL (ours) | 80.19 [+6.35] | 94.94 [+0.89] | 73.17 [+7.46] | 98.81 [+1.82] | 78.90 [+5.27] | 72.30 [-7.15] | 93.44 [+2.34] | 84.54 [+2.25]
Table 1: Result comparison between CPL and CoCoOp (Zhou et al., 2022) on seen and unseen classes across seven image classification datasets in terms of accuracy (%) under the few-shot setting. The relative difference (%) compared with CLIP is reported in brackets.

Training data used | Method | Flickr30k | MSCOCO | Average
0 | CLIP | 83.00 | 53.35 | 68.18
0.5% | CoCoOp | 82.40 [-0.72] | 55.55 [+4.12] | 68.98 [+1.17]
0.5% | CPL (ours) | 85.64 [+3.18] | 57.91 [+8.55] | 71.78 [+5.28]
1% | CoCoOp | 84.80 [+2.17] | 56.62 [+6.13] | 70.71 [+3.71]
1% | CPL (ours) | 86.91 [+4.71] | 58.43 [+9.52] | 72.67 [+6.59]
3% | CoCoOp | 85.90 [+3.49] | 58.08 [+8.87] | 71.99 [+5.59]
3% | CPL (ours) | 87.74 [+5.71] | 59.96 [+12.39] | 73.85 [+8.32]
Table 2: Result comparison between CPL and CoCoOp on two image-text retrieval datasets, Flickr30k (Plummer et al., 2015) and MSCOCO (Lin et al., 2014), on the unseen test sets in terms of Recall@1 (%). The relative difference (%) over CLIP is reported in brackets.

Training data used | Method | VQAv2
0 | CLIP | 11.83
0.5% | CoCoOp | 27.98 [+136.52]
0.5% | CPL w/o Category Information | 31.68 [+167.79]
0.5% | CPL | 33.39 [+182.25]
1% | CoCoOp | 28.51 [+141.00]
1% | CPL w/o Category Information | 34.70 [+193.32]
1% | CPL | 35.66 [+201.44]
3% | CoCoOp | 30.18 [+155.11]
3% | CPL w/o Category Information | 35.41 [+199.32]
3% | CPL | 36.32 [+207.02]
Table 3: Result comparison on the VQAv2 dataset (Goyal et al., 2017a) in terms of accuracy (%). The relative improvements over CLIP are reported in brackets. Incorporating category information into task-relevant prompts further improves performance.

4.2 Implementation Details

Baselines. We mainly compare CPL with CoCoOp (Zhou et al., 2022), one of the earliest prompt tuning methods proposed for vision-and-language pre-trained models. CoCoOp conditions on each input image and injects learnable instance-aware tokens into the context vectors to form the final prompt. For a fair comparison, both CPL and CoCoOp adopt CLIP (Radford et al., 2021) as the pre-trained vision-and-language backbone, and both are compared in terms of their relative improvements over zero-shot CLIP.

Prompt Tuning. The task-agnostic prompt is randomly initialized from a zero-mean Gaussian distribution with standard deviation 0.02, and its length is set to $L = 4$ by default. For vision-and-language tasks, the task-relevant prompts contain more fine-grained details, usually a full sentence. This contrasts with image classification, where an image is labeled with a category. We similarly tokenize the whole sentence using the CLIP word embedding (Radford et al., 2021). The tokenized result, together with the task-agnostic prompt vectors, is fed to the text encoder to generate the language embedding for each prompt. In both image-text retrieval and visual question answering, all data in the test set can be treated as belonging to unseen classes.
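The toy sketch below combines the prompt initialization described above with the per-sample pieces of Sections 3.3 to 3.5 and walks through one CPL loss computation in pure Python. It rests on several assumptions. Features are plain lists rather than CLIP outputs. `token_overlap` is a hypothetical Jaccard stand-in for BERTScore, since real BERTScore needs a BERT model. The mixing weights `u` are fixed rather than optimized against the discriminator D of Eq. 3, and no gradients are computed. All function names, the feature dimension, and τ = 0.07 are illustrative choices, not the authors' implementation.

```python
import math
import random

def init_prompt(length=4, dim=8, std=0.02, seed=0):
    """Task-agnostic prompt: `length` context vectors sampled from N(0, std^2)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, std) for _ in range(dim)] for _ in range(length)]

def token_overlap(a, b):
    """HYPOTHETICAL stand-in for BERTScore: Jaccard overlap of lowercase tokens."""
    sa, sb = set(a.lower().split()), set(b.lower().split())
    union = sa | sb
    return len(sa & sb) / len(union) if union else 0.0

def hardest_negative(prompts, q, sim=token_overlap):
    """Index k != q whose prompt is most similar to prompts[q] (cf. Eq. 5)."""
    candidates = [k for k in range(len(prompts)) if k != q]
    return max(candidates, key=lambda k: sim(prompts[q], prompts[k]))

def counterfactual(v, v_neg, u):
    """Eq. 2: v' = (1 - u) * v + u * v_neg, applied element-wise."""
    return [(1.0 - ui) * vi + ui * ni for vi, ni, ui in zip(v, v_neg, u)]

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def contrastive_loss(v, v_cf, g_t, tau=0.07):
    """Eq. 4 with S = cosine: -log(e^{S(v,g)/tau} / (e^{S(v,g)/tau} + e^{S(v',g)/tau}))."""
    pos, neg = cosine(v, g_t) / tau, cosine(v_cf, g_t) / tau
    m = max(pos, neg)  # shift both logits before exponentiating, for stability
    return -(pos - m - math.log(math.exp(pos - m) + math.exp(neg - m)))

def joint_loss(ce, cl, lam=1.0):
    """Eq. 7: L(p) = L_CE(p) + lambda * L_CL(p, u*)."""
    return ce + lam * cl

# Toy batch of prompts: query 0 should pick "tiger cat" rather than "jeep".
prompts = ["a photo of a tabby cat", "a photo of a tiger cat", "a photo of a jeep"]
k = hardest_negative(prompts, q=0)

v, v_neg = [1.0, 0.0, 0.5], [0.0, 1.0, 0.5]  # toy positive / negative image features
u = [0.5, 0.5, 0.0]                           # fixed mixing weights (not optimized here)
v_cf = counterfactual(v, v_neg, u)            # [0.5, 0.5, 0.5]
g_t = [1.0, 0.0, 0.5]                         # toy text feature G(t), aligned with v
cl = contrastive_loss(v, v_cf, g_t)
total = joint_loss(ce=0.2, cl=cl, lam=1.0)    # ce=0.2 is an arbitrary placeholder
prompt = init_prompt()
```

The contrastive term equals log 2 when the counterfactual is indistinguishable from the factual feature, and it shrinks as $v'$ moves away from the text feature. Approaching the procedure in Algorithm 1 would require swapping `token_overlap` for a real BERTScore scorer and optimizing `u` instead of fixing it.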
4.3 Main Results

Image Classification. The experimental results for image classification are shown in Table 1. With better prompts learned from counterfactual examples, our CPL method achieves clear advantages across almost all datasets. On unseen classes in particular, we obtain an average relative improvement of 3.55% over CoCoOp. Meanwhile, CoCoOp shows its poor generaliza-

[Figure: Positive examples with BERTScore-sampled and randomly sampled negatives. "Tabby cat": BERTScore-sampled "Tiger cat" (BERTScore = 0.9126); randomly sampled "Jeep" (BERTScore = 0.8556). "A big bunch of ripe yellow bananas on display": BERTScore-sampled "Bunches of bananas are neatly arranged on a display" (BERTScore = 0.9313); randomly sampled "The plate is empty on the table" (BERTScore = 0.8908).]

Figure 6: Ablation of four different λ values on the SUN397 dataset in terms of average accuracy (%). The performance of CPL peaks at λ = 1.
