Image-Text Semantic Matching with AutoMM - Zero-Shot


The task of image-text semantic matching refers to measuring the visual-semantic similarity between an image and a sentence. AutoMM supports zero-shot image-text matching by leveraging the powerful CLIP model. Trained with a contrastive loss objective on millions of image-text pairs, CLIP learns good embeddings for both vision and language, as well as the connections between them. Hence, we can use it to extract embeddings for retrieval and matching.

CLIP has a two-tower architecture, which means it has two encoders: one for images, the other for text. An overview of the CLIP model is shown in the diagram below: the left part illustrates its pre-training stage, and the right part illustrates its zero-shot prediction stage. By computing the cosine similarity scores between one image embedding and all the text embeddings, we pick the text with the highest similarity as the prediction.

Given the two encoders, we can extract image embeddings and text embeddings. Most importantly, embedding extraction can be done offline; only the similarity computation needs to be done online, which gives good scalability.

(Figure: overview of the CLIP model; left: pre-training stage, right: zero-shot prediction stage.)
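
To make this concrete, here is a minimal PyTorch sketch (not an AutoMM API) of the zero-shot prediction step: normalize one image embedding and a set of text embeddings, compute cosine similarities, and pick the most similar text. The tensors below are random placeholders standing in for real CLIP embeddings.

import torch
import torch.nn.functional as F

# Random placeholder embeddings for illustration only; real CLIP embeddings
# are extracted with predictor.extract_embedding later in this tutorial.
demo_image_emb = torch.randn(1, 512)
demo_text_embs = torch.randn(9, 512)

# Cosine similarity is the dot product of L2-normalized embeddings.
img = F.normalize(demo_image_emb, dim=-1)
txt = F.normalize(demo_text_embs, dim=-1)
similarities = img @ txt.T  # shape (1, 9)

# The text with the highest similarity is the zero-shot prediction.
predicted_text_index = similarities.argmax(dim=-1).item()
print(predicted_text_index)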

In this tutorial, we will show how AutoMM’s easy-to-use APIs bring the powerful CLIP to you.

Prepare Demo Data

First, let’s get some texts and download some images. These images are from the COCO dataset.

from autogluon.multimodal import download

texts = [
    "A cheetah chases prey on across a field.",
    "A man is eating a piece of bread.",
    "The girl is carrying a baby.",
    "There is an airplane over a car.",
    "A man is riding a horse.",
    "Two men pushed carts through the woods.",
    "There is a carriage in the image.",
    "A man is riding a white horse on an enclosed ground.",
    "A monkey is playing drums.",
]

urls = ['http://farm4.staticflickr.com/3179/2872917634_f41e6987a8_z.jpg',
        'http://farm4.staticflickr.com/3629/3608371042_75f9618851_z.jpg',
        'https://farm4.staticflickr.com/3795/9591251800_9c9727e178_z.jpg',
        'http://farm8.staticflickr.com/7188/6848765123_252bfca33d_z.jpg',
        'https://farm6.staticflickr.com/5251/5548123650_1a69ce1e34_z.jpg']

image_paths = [download(url) for url in urls]
Downloading 2872917634_f41e6987a8_z.jpg from http://farm4.staticflickr.com/3179/2872917634_f41e6987a8_z.jpg...
Downloading 3608371042_75f9618851_z.jpg from http://farm4.staticflickr.com/3629/3608371042_75f9618851_z.jpg...
Downloading 9591251800_9c9727e178_z.jpg from https://farm4.staticflickr.com/3795/9591251800_9c9727e178_z.jpg...
Downloading 6848765123_252bfca33d_z.jpg from http://farm8.staticflickr.com/7188/6848765123_252bfca33d_z.jpg...
Downloading 5548123650_1a69ce1e34_z.jpg from https://farm6.staticflickr.com/5251/5548123650_1a69ce1e34_z.jpg...

Extract Embeddings

We need to use image_text_similarity as the problem type when initializing the predictor.

from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor(problem_type="image_text_similarity")

Let’s extract image and text embeddings separately. The image and text data will go through their corresponding encoders, respectively.

image_embeddings = predictor.extract_embedding(image_paths, as_tensor=True)
print(image_embeddings.shape)
torch.Size([5, 512])
text_embeddings = predictor.extract_embedding(texts, as_tensor=True)
print(text_embeddings.shape)
torch.Size([9, 512])

Then you can use the embeddings for a range of tasks such as image retrieval and text retrieval.

Image Retrieval with Text Query

Suppose we have a large image database (e.g., video footage), and we want to retrieve some images described by a text query. How can we do this?

It is simple. First, extract all the image embeddings offline as shown above. Then, extract the text query’s embedding. Finally, compute the cosine similarities between the text embedding and all the image embeddings and return the top candidates.

Suppose we use the text below as the query.

print(texts[6])
There is a carriage in the image.
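
For illustration, the retrieval recipe above can also be written by hand in a few lines of PyTorch, using the image and text embeddings extracted earlier. This is only a sketch, not AutoMM’s internal implementation.

import torch
import torch.nn.functional as F

# Rank the images by cosine similarity to the text query (texts[6]).
query_emb = F.normalize(text_embeddings[6][None, :], dim=-1)
response_embs = F.normalize(image_embeddings, dim=-1)
scores = (query_emb @ response_embs.T).squeeze(0)
top_scores, top_ids = torch.topk(scores, k=min(5, scores.numel()))
print(top_ids.tolist())
print(top_scores.tolist())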

In practice, you can directly call our util function semantic_search to search for semantically similar images.

from autogluon.multimodal.utils import semantic_search
hits = semantic_search(
        matcher=predictor,
        query_embeddings=text_embeddings[6][None,],
        response_embeddings=image_embeddings,
        top_k=5,
    )
print(hits)
[[{'response_id': 2, 'score': 0.2744503319263458}, {'response_id': 4, 'score': 0.22526244819164276}, {'response_id': 0, 'score': 0.21866320073604584}, {'response_id': 1, 'score': 0.21707728505134583}, {'response_id': 3, 'score': 0.20675909519195557}]]

We can see that we successfully found the image with a carriage in it.

from IPython.display import Image, display
pil_img = Image(filename=image_paths[hits[0][0]["response_id"]])
display(pil_img)
(Displayed: the top retrieved image, which contains a carriage.)
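
If you want to inspect more than the best match, you can loop over the ranked hits. Here is a small sketch that displays every retrieved image in order.

# Display all top-k retrieved images in ranked order.
for hit in hits[0]:
    display(Image(filename=image_paths[hit["response_id"]]))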

Text Retrieval with Image Query

Similarly, given one text database and an image query, we can search texts that match the image. For example, let’s search texts for the following image.

pil_img = Image(filename=image_paths[4])
display(pil_img)
(Displayed: the query image.)

We still use the semantic_search function, but switch the assignments of query_embeddings and response_embeddings.

hits = semantic_search(
        matcher=predictor,
        query_embeddings=image_embeddings[4][None,],
        response_embeddings=text_embeddings,
        top_k=5,
    )
print(hits)
[[{'response_id': 3, 'score': 0.2526739835739136}, {'response_id': 6, 'score': 0.22526244819164276}, {'response_id': 7, 'score': 0.1940707564353943}, {'response_id': 2, 'score': 0.18509851396083832}, {'response_id': 4, 'score': 0.18197666108608246}]]

We can observe that the top-1 text matches the query image.

texts[hits[0][0]["response_id"]]
'There is an airplane over a car.'
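
You can also list all retrieved texts together with their similarity scores. Here is a small sketch using the hits returned above.

# Print the top-k retrieved texts with their similarity scores.
for hit in hits[0]:
    print(f"{hit['score']:.4f}  {texts[hit['response_id']]}")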

Predict Whether Image-Text Pairs Match

In addition to retrieval, we can let the predictor tell us whether image-text pairs match. To do so, we need to initialize the predictor with the additional arguments query and response, which specify the names of the query and response data (one being images, the other being texts).

predictor = MultiModalPredictor(
            query="abc",
            response="xyz",
            problem_type="image_text_similarity",
        )

Given image-text pairs, we can make predictions.

pred = predictor.predict({"abc": [image_paths[4]], "xyz": [texts[3]]})
print(pred)
[1]

Predict Matching Probabilities

It is also easy to predict the matching probabilities. You can make predictions by applying customized thresholds to the probabilities.

proba = predictor.predict_proba({"abc": [image_paths[4]], "xyz": [texts[3]]})
print(proba)
[[0.37367177 0.62632823]]
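
For example, here is a small sketch of turning these probabilities into predictions with a customized threshold. The threshold value is only illustrative; choose one that fits your precision-recall trade-off.

# Apply a customized decision threshold to the matching probability.
# Column index 1 is the probability that the image-text pair matches.
custom_threshold = 0.8  # hypothetical threshold
match_proba = proba[:, 1]
custom_pred = (match_proba >= custom_threshold).astype(int)
print(custom_pred)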

Other Examples

You may go to AutoMM Examples to explore other examples about AutoMM.

Customization

To learn how to customize AutoMM, please refer to Customize AutoMM.