
From Vision Transformers to Masked Autoencoders in 5 Minutes

A Straightforward Guide on How NLP Tasks Generalize to Computer Vision

Essam Wisam
Towards Data Science
7 min read · Jun 28, 2024


Nearly all natural language processing tasks, ranging from language modeling and masked word prediction to translation and question-answering, were revolutionized when the transformer architecture made its debut in 2017. It didn’t take more than 2–3 years for transformers to also excel in computer vision tasks. In this story, we explore two fundamental architectures that enabled transformers to break into the world of computer vision.

Table of Contents

· The Vision Transformer
Key Idea
Operation
Hybrid Architecture
Loss of Structure
Results
Self-supervised Learning by Masking
· Masked Autoencoder Vision Transformer
Key Idea
Architecture
Final Remark and Example

The Vision Transformer

Image from Paper: “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”

Key Idea

The vision transformer is simply meant to generalize the standard transformer architecture to process and learn from image input. There is a key idea about the architecture that the authors were transparent enough to highlight:

“Inspired by the Transformer scaling successes in NLP, we experiment with applying a standard Transformer directly to images, with the fewest possible modifications.”

Operation

It’s valid to take “fewest possible modifications” quite literally because they make virtually no modifications to the architecture itself. What they actually modify is the input structure:

  • In NLP, the transformer encoder takes a sequence of one-hot vectors (or equivalently token indices) that represent the input sentence/paragraph and returns a sequence of contextual embedding vectors that can be used for further tasks (e.g., classification)
  • To generalize to CV, the vision transformer takes a sequence of patch vectors that represent the input image and returns a sequence of contextual embedding vectors that can be used for further tasks (e.g., classification)

In particular, suppose the input images have dimensions (n,n,3). To pass such an image as input to the transformer, the vision transformer does the following:

  • Divides it into k² patches for some k (e.g., k=3) as in the figure above.
  • Each patch now has dimensions (n/k,n/k,3); the next step is to flatten each patch into a vector

The patch vector will have dimensionality 3*(n/k)*(n/k). For example, if the image is (900,900,3) and we use k=3, then a patch vector will have dimensionality 300*300*3 = 270,000, holding the pixel values of the flattened patch. In the paper, the authors fix the patch size rather than k: the name “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale” refers to patches of 16x16 pixels (i.e., n/k=16). Instead of feeding a one-hot vector representing a word, they feed a vector of pixels representing a patch of the image.
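The patch-splitting step can be sketched in a few lines of NumPy. This is a minimal illustration rather than the paper's implementation; the function name `patchify` and the all-zeros dummy image are assumptions made for the example.

```python
import numpy as np

def patchify(image: np.ndarray, k: int) -> np.ndarray:
    """Split an (n, n, c) image into k*k flattened patches of length c*(n/k)**2."""
    n, n2, c = image.shape
    assert n == n2 and n % k == 0, "expects a square image whose side is divisible by k"
    p = n // k  # patch side length in pixels
    # (n, n, c) -> (k, p, k, p, c): split rows and columns into k blocks of p pixels
    blocks = image.reshape(k, p, k, p, c)
    # Bring the two block indices together: (k, k, p, p, c)
    blocks = blocks.transpose(0, 2, 1, 3, 4)
    # One row per patch, in row-major patch order
    return blocks.reshape(k * k, p * p * c)

image = np.zeros((900, 900, 3), dtype=np.float32)  # dummy image (assumption)
patches = patchify(image, k=3)
print(patches.shape)  # (9, 270000)
```

Each row of `patches` is one flattened patch, so the (900,900,3) example yields 9 vectors of length 300*300*3.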

The rest of the operations remains as in the original transformer encoder:

  • These patch vectors pass through a trainable embedding layer
  • Positional embeddings are added to each vector to maintain a sense of spatial information in the image
  • The output is num_patches encoder representations (one for each patch) which can be used for classification at the patch or image level
  • More often (and as in the paper), a CLS token is prepended, and the representation corresponding to it is used to make a prediction over the whole image (similar to BERT)
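The steps above can be sketched as shape bookkeeping in NumPy. The sizes and the random stand-ins for trainable parameters are assumptions chosen for illustration, and the transformer encoder itself is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions, not values from the paper)
num_patches, patch_dim, d_model = 9, 48, 32

patches = rng.normal(size=(num_patches, patch_dim))  # flattened patch vectors

# Random stand-ins for parameters that would be learned during training
W_embed = rng.normal(size=(patch_dim, d_model)) * 0.02          # patch embedding layer
cls_token = rng.normal(size=(1, d_model)) * 0.02                # CLS vector
pos_embed = rng.normal(size=(num_patches + 1, d_model)) * 0.02  # one per position

tokens = patches @ W_embed                    # (num_patches, d_model)
tokens = np.concatenate([cls_token, tokens])  # prepend CLS -> (num_patches + 1, d_model)
encoder_input = tokens + pos_embed            # add positional embeddings

print(encoder_input.shape)  # (10, 32)
```

After the encoder runs on `encoder_input`, the output row at index 0 (the CLS position) would be the one passed to a classification head.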

How about the transformer decoder?

Well, remember that a decoder-only transformer is just like the transformer encoder; the difference is that it uses masked (causal) self-attention instead of full self-attention, while the input signature stays the same. In any case, you should expect to seldom use a decoder-only transformer architecture for vision, because simply predicting the next patch may not be a task of great interest.
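To make the difference concrete, here is a minimal single-head sketch of attention weights with and without the causal mask. It is a simplification: the learned query/key/value projections are left out, and the function name `attention_weights` is an assumption for this example.

```python
import numpy as np

def attention_weights(x: np.ndarray, causal: bool = False) -> np.ndarray:
    """Attention weights for a (seq_len, d) input, using x as both queries and keys."""
    seq_len, d = x.shape
    scores = x @ x.T / np.sqrt(d)
    if causal:
        # Position i may only attend to positions j <= i
        future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

x = np.random.default_rng(0).normal(size=(4, 8))  # 4 patch tokens (dummy values)
print(np.round(attention_weights(x), 2))               # every token sees every token
print(np.round(attention_weights(x, causal=True), 2))  # upper triangle is zero
```

With `causal=True`, each patch token only attends to itself and earlier patches, which is what would make next-patch prediction possible.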

Hybrid Architecture

The authors also mention that it’s possible to start with a CNN feature map instead of the image itself to form a hybrid architecture (a CNN feeding its output to the vision transformer). In this case, we think of the input as a generic (n,n,p) feature map, and a patch vector will have dimensionality (n/k)*(n/k)*p.

Loss of Structure

It may cross your mind that this architecture shouldn’t work so well because it treats the image as a linear structure when it isn’t one. The authors point out that this is intentional by mentioning:

“The two-dimensional neighborhood structure is used very sparingly…position embeddings at initialization time carry no information about the 2D positions of the patches and all spatial relations between the patches have to be learned from scratch”

We will see that the transformer is able to learn this structure, as evidenced by its good performance in their experiments and, more importantly, by the architecture in the next paper.

Results

The main verdict from the results is that vision transformers tend not to outperform CNN-based models on small datasets, but approach or outperform CNN-based models on larger datasets, and either way require significantly less compute:

Table from Paper: “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”.

Here we see that for the JFT-300M dataset (which has 300M images), the ViT models pre-trained on the dataset outperform ResNet-based baselines while taking substantially fewer computational resources to pre-train. As can be seen, the largest vision transformer they used (ViT-Huge, with 632M parameters; ViT-H/14 in the paper’s notation, i.e., 14x14-pixel patches) used about 25% of the compute used for the ResNet-based model and still outperformed it. The performance doesn’t even degrade that much with ViT-Large, which uses less than 6.8% of the compute.

Meanwhile, the authors also report results where the ResNet performed significantly better when trained on ImageNet-1K, which has just 1.3M images.

Self-supervised Learning by Masking

The authors performed a preliminary exploration of masked patch prediction for self-supervision, mimicking the masked language modeling task used in BERT (i.e., masking out patches and attempting to predict them).

“We employ the masked patch prediction objective for preliminary self-supervision experiments. To do so we corrupt 50% of patch embeddings by either replacing their embeddings with a learnable [mask] embedding (80%), a random other patch embedding (10%) or just keeping them as is (10%).”
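The corruption scheme in the quote can be sketched as follows. This is a rough illustration under stated assumptions, not the authors' code: the function name, the zero vector standing in for the learnable [mask] embedding, and the 16-patch example are all hypothetical.

```python
import numpy as np

def corrupt_patch_embeddings(embeddings, mask_embedding, rng, corrupt_frac=0.5):
    """Return (corrupted embeddings, boolean array marking the selected positions)."""
    num_patches = embeddings.shape[0]
    corrupted = embeddings.copy()
    num_selected = int(round(corrupt_frac * num_patches))
    selected_idx = rng.choice(num_patches, size=num_selected, replace=False)
    selected = np.zeros(num_patches, dtype=bool)
    selected[selected_idx] = True

    for i in selected_idx:
        u = rng.random()
        if u < 0.8:    # 80%: replace with the [mask] embedding
            corrupted[i] = mask_embedding
        elif u < 0.9:  # 10%: replace with a random other patch embedding
            others = [j for j in range(num_patches) if j != i]
            corrupted[i] = embeddings[rng.choice(others)]
        # Remaining 10%: keep the embedding as is
    return corrupted, selected

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(16, 8))  # 16 patch embeddings of size 8 (dummy values)
mask_embedding = np.zeros(8)           # would be a learnable vector in practice
corrupted, selected = corrupt_patch_embeddings(embeddings, mask_embedding, rng)
print(selected.sum())  # 8
```

The model would then be trained to predict the original content of the selected patches, while unselected positions pass through unchanged.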

With self-supervised pre-training, their smaller ViT-Base/16 model achieves 79.9% accuracy on ImageNet, a significant improvement of 2% over training from scratch, but still 4% behind supervised pre-training.

Masked Autoencoder Vision Transformer

Image from Paper: Masked Autoencoders Are Scalable Vision Learners

Key Idea

As we have seen from the vision transformer paper, the gains from pretraining by masking patches in input images were not as significant as in ordinary NLP where masked pretraining can lead to state-of-the-art results in some fine-tuning tasks.

This paper proposes a vision transformer architecture involving an encoder and a decoder that when pretrained with masking results in significant improvements over the base vision transformer model (as much as 6% improvement compared to training a base size vision transformer in a supervised fashion).

Image from Paper: Masked Autoencoders Are Scalable Vision Learners

These are some sample triplets of (input, output, ground truth). It’s an autoencoder in the sense that it tries to reconstruct the input while filling in the missing patches.

Architecture

Their encoder is simply the ordinary vision transformer encoder we explained earlier. In training and inference, it takes only the “observed” patches.

Meanwhile, their decoder is also simply the ordinary vision transformer encoder but it takes:

  • Mask token vectors for the missing patches
  • Encoder output vectors for the known patches

So for an image [[A, B, X], [C, X, X], [X, D, E]], where X denotes a missing patch, the decoder will take the sequence of patch vectors [Enc(A), Enc(B), Vec(X), Enc(C), Vec(X), Vec(X), Vec(X), Enc(D), Enc(E)]. Enc returns the encoder output vector given the patch vector, and Vec(X) is a shared, learned vector that represents a missing token.
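For the 3x3 example grid, assembling the decoder input could look like the sketch below. The embedding size, the random stand-ins for the encoder outputs, and the zero mask token are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4  # illustrative embedding size (assumption)

# Row-major patch grid from the example; "X" marks a missing patch
grid = ["A", "B", "X", "C", "X", "X", "X", "D", "E"]
visible_idx = [i for i, name in enumerate(grid) if name != "X"]

# Stand-ins: encoder outputs for the visible patches and one shared mask token
encoder_out = rng.normal(size=(len(visible_idx), d))  # Enc(A), Enc(B), Enc(C), Enc(D), Enc(E)
mask_token = np.zeros(d)                              # Vec(X), learned in practice

# Start with the mask token everywhere, then place encoder outputs at their positions
decoder_in = np.tile(mask_token, (len(grid), 1))
decoder_in[visible_idx] = encoder_out

print(visible_idx)       # [0, 1, 3, 7, 8]
print(decoder_in.shape)  # (9, 4)
```

Positional embeddings would then be added to every row of `decoder_in`, so that the otherwise identical mask tokens carry information about where they sit in the image.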

The last layer in the decoder is a linear layer that maps the contextual embeddings (produced by the vision transformer encoder in the decoder) to a vector whose length equals the number of pixel values in a patch. The loss function is the mean squared error between the original patch vector and the one predicted by this layer. In the loss function, we only look at the decoder predictions for the masked tokens and ignore those corresponding to the visible patches (i.e., Dec(A), Dec(B), Dec(C), etc.).
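A minimal sketch of such a loss, assuming the predictions and targets are stored as one row per patch and a boolean array marks the masked patches (the function name `masked_mse` and the toy values are made up for the example):

```python
import numpy as np

def masked_mse(pred: np.ndarray, target: np.ndarray, mask: np.ndarray) -> float:
    """Mean squared error averaged over masked patches only.

    pred, target: (num_patches, patch_dim); mask: (num_patches,), True = masked.
    """
    per_patch = ((pred - target) ** 2).mean(axis=-1)  # one error value per patch
    return float(per_patch[mask].mean())              # visible patches are ignored

target = np.zeros((4, 3))
pred = target.copy()
pred[0] = 1.0  # error on a visible patch: does not count
pred[2] = 2.0  # error on a masked patch: counts
mask = np.array([False, False, True, True])

print(masked_mse(pred, target, mask))  # 2.0
```

The per-patch errors here are [1, 0, 4, 0]; only the last two patches are masked, so the loss is (4 + 0) / 2 = 2.0.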

Final Remark and Example

It may be surprising that the authors suggest masking about 75% of the patches in the images, whereas BERT masks only about 15% of the words. They justify it as follows:

“Images are natural signals with heavy spatial redundancy: e.g., a missing patch can be recovered from neighboring patches with little high-level understanding of parts, objects, and scenes. To overcome this difference and encourage learning useful features, we mask a very high portion of random patches.”
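Random masking at a high ratio is straightforward to sketch. The helper below is an illustrative assumption rather than the authors' implementation, and the 196-patch example (a 14x14 grid) is chosen only to show the counts.

```python
import numpy as np

def random_masking(num_patches: int, mask_ratio: float, rng):
    """Return (sorted indices of kept patches, boolean mask with True = masked)."""
    num_keep = int(num_patches * (1 - mask_ratio))
    shuffled = rng.permutation(num_patches)  # random order of patch indices
    keep_idx = np.sort(shuffled[:num_keep])  # the first num_keep stay visible
    mask = np.ones(num_patches, dtype=bool)
    mask[keep_idx] = False
    return keep_idx, mask

keep_idx, mask = random_masking(196, 0.75, np.random.default_rng(0))
print(len(keep_idx), mask.sum())  # 49 147
```

Only the 49 kept patches would be passed to the encoder, which is part of why the high masking ratio also makes pre-training cheaper.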

Want to try it out yourself? Check out this demo notebook by NielsRogge.

This is all for this story. We went through a journey to understand how fundamental transformer models generalize to the computer vision world. I hope you have found it clear, insightful and worth your time.

References:

[1] Dosovitskiy, A. et al. (2021) An image is worth 16x16 words: Transformers for image recognition at scale, arXiv.org. Available at: https://arxiv.org/abs/2010.11929 (Accessed: 28 June 2024).

[2] He, K. et al. (2021) Masked autoencoders are scalable vision learners, arXiv.org. Available at: https://arxiv.org/abs/2111.06377 (Accessed: 28 June 2024).
