ViT from scratch¶

by Talgat Daulbaev

Motivation for the seminar¶

  • Probably, „transformer“ is one of the most frequent words in the lectures
  • „Transformers are the best idea in AI“ (Andrej Karpathy)
  • Paperswithcode for image classification
  • „How to achieve success in a machine learning PhD“ by Patrick Kidger:

...Write your own implementation of multihead attention...

Today we are going to write our own implementation of ViT!

One-head scaled dot-product self-attention¶

  • Each sample is a variable length sequence of $D$-dimensional features: $x_1, x_2, \ldots, x_{n_i}$
  • Thus, each sample can be viewed as a matrix $X \in \mathbb{R}^{n_i \times D}$
  • Parameters: $W_Q, W_K, W_V \in \mathbb{R}^{D \times d}$
  • Compute:
    • $Q = X W_Q$ of size $n_i \times d$
    • $K = X W_K$ of size $n_i \times d$
    • $V = X W_V$ of size $n_i \times d$
    • $$\text{Attention}(Q, K, V) = \text{softmax_for_rows}\left(\dfrac{Q K^\top}{\sqrt{d}}\right) V$$
  • You can additionally normalize the $Q$ and $K$ matrices

Attention is a „soft dictionary“¶

  • {key1: value1, key2: value2, ...}
  • The query is the key that you are searching for
  • Imagine that $q_i$ and $k_j$ are normalized: $\|q_i\| = \|k_j\| = 1$
  • Then $\langle q_i, k_j \rangle = \cos(\theta)$, where $\theta$ is the angle between $q_i$ and $k_j$ (cosine similarity)
  • We compute the similarity between $q_i$ and all key vectors and take the corresponding linear combination of the values (see the toy example below):
$$\text{softmax_for_rows}\left(\dfrac{Q K^\top}{\sqrt{d}}\right) V$$
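
As a toy illustration of this „soft dictionary“ view (a minimal sketch with arbitrary shapes, not part of the tasks below): one query vector attends over three key/value pairs and returns a weighted combination of the values.

In [ ]:
import torch

d = 4
keys = torch.randn(3, d)     # "dictionary" keys
values = torch.randn(3, d)   # "dictionary" values
query = torch.randn(1, d)    # what we are searching for

weights = torch.softmax(query @ keys.T / d ** 0.5, dim=-1)  # similarities -> weights that sum to 1
lookup = weights @ values    # soft, differentiable analogue of dict[query]
print(weights.sum().item())  # 1.0
print(lookup.shape)          # torch.Size([1, 4])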

[Task] One-head scaled dot-product self-attention¶

$$\text{Attention}(Q, K, V) = \text{softmax_for_rows}\left(\dfrac{Q K^\top}{\sqrt{d}}\right) V$$

bit.ly/vit_sem¶

In [268]:
import torch
from torch import nn

class Attention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.scale = self.dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)

    def forward(self, x):
        '''
        Args: 
            x: Tensor of shape (batch_size, seq_len, input_dim)
            
        Returns:
            Tensor of shape (batch_size, seq_len, input_dim)
        '''
        # Your code is here
        return x

[Solution] One-head scaled dot-product self-attention¶

In [269]:
class Attention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.scale = self.dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)

    def forward(self, x):
        '''
        Args: 
            x: Tensor of shape (batch_size, seq_len, input_dim)
            
        Returns:
            Tensor of shape (batch_size, seq_len, input_dim)
        '''
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.dim)
        q, k, v = qkv.unbind(2)
        # ...or q, k, v = [qkv[:, :, idx, :] for idx in range(3)]
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        x = attn @ v
        return x
In [270]:
x = torch.ones(11, 12, 8)
assert Attention(8)(x).shape == x.shape
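
As an extra sanity check of the solution (a sketch, assuming PyTorch ≥ 2.0 for torch.nn.functional.scaled_dot_product_attention): the built-in op applies the same scaling and row-wise softmax, so feeding it the module's own q, k, v should reproduce the output.

In [ ]:
import torch.nn.functional as F

torch.manual_seed(0)
attn_layer = Attention(8)
x = torch.randn(11, 12, 8)
q, k, v = attn_layer.qkv(x).reshape(11, 12, 3, 8).unbind(2)  # same split as in forward()
ref = F.scaled_dot_product_attention(q, k, v)  # softmax(q k^T / sqrt(d)) v
assert torch.allclose(attn_layer(x), ref, atol=1e-5)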

Multi-head attention¶

  • Split each vector in the sequence into num_heads parts (requires $d$ mod num_heads = 0)
  • Apply attention to each part (head) independently and concatenate the results:
$$\text{head}_i = \text{Attention}(Q_i, K_i, V_i)$$
$$\textrm{concat}\left( \text{head}_1, \text{head}_2, \ldots, \text{head}_h \right)$$
  • Apply an extra linear layer to mix the independent attention branches
  • How to implement this without loops?

Multi-head attention¶

Q: How to implement without loops?

A: Use a single QKV-linear layer, then reshape and slice (see the shape walkthrough below)
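
A shape walkthrough of this trick (a minimal sketch with made-up sizes), before implementing it inside the module:

In [ ]:
B, N, dim, num_heads = 2, 5, 16, 4
head_dim = dim // num_heads
qkv_out = torch.randn(B, N, 3 * dim)  # output of one Linear(dim, 3 * dim) applied to x
qkv_out = qkv_out.reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv_out.unbind(0)           # each of shape B × num_heads × N × head_dim
print(q.shape)                        # torch.Size([2, 4, 5, 4])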

[Task] Multi-Head Attention¶

In [114]:
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        if dim % num_heads:
            raise ValueError('dim % num_heads != 0')
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        '''
        Args: 
            x: Tensor of shape (batch_size, seq_len, input_dim)
            
        Returns:
            Tensor of shape (batch_size, seq_len, input_dim)
        '''
        # Hint: you might want to use torch.permute function
        return x

[Solution] Multi-Head Attention¶

In [289]:
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        if dim % num_heads:
            raise ValueError('dim % num_heads != 0')
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        '''
        Args: 
            x: Tensor of shape (batch_size, seq_len, input_dim)
            
        Returns:
            Tensor of shape (batch_size, seq_len, input_dim)
        '''
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        # qkv: 3 × B × num_heads × N × head_dim
        q, k, v = qkv.unbind(0)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        x = attn @ v  # attn: B × num_heads × N × N    v: B × num_heads × N × head_dim
        # B × num_heads × N × head_dim
        x = x.transpose(1, 2).reshape(B, N, C) 
        # B × N × (num_heads × head_dim)
        x = self.proj(x)
        return x
In [290]:
MultiHeadAttention(128, 8)(torch.ones(11, 12, 128)).shape
Out[290]:
torch.Size([11, 12, 128])
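
Optional consistency check (a sketch): with a single head and the output projection forced to be the identity, MultiHeadAttention should coincide with the one-head Attention above.

In [ ]:
torch.manual_seed(0)
one_head = Attention(16)
mha = MultiHeadAttention(16, num_heads=1)
mha.qkv.weight.data.copy_(one_head.qkv.weight.data)  # share the QKV weights
nn.init.eye_(mha.proj.weight)                        # output projection = identity
nn.init.zeros_(mha.proj.bias)
x = torch.randn(3, 7, 16)
assert torch.allclose(one_head(x), mha(x), atol=1e-6)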

Transformer Block¶

  • Pre-norm residual structure: x = x + Attention(LayerNorm(x)), then x = x + MLP(LayerNorm(x))
  • The MLP has one hidden layer of size mlp_ratio * dim with a GELU activation

[Task] Transformer Block¶

In [271]:
class Block(nn.Module):

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4,  # ratio between hidden_dim and input_dim in MLP
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm
    ):
        super().__init__()
        # Your code is here

    def forward(self, x):
        # Your code is here
        return x

[Solution] Transformer Block¶

In [272]:
class Block(nn.Module):

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4,  # ratio between hidden_dim and input_dim in MLP
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = MultiHeadAttention(dim, num_heads=num_heads)
        self.norm2 = norm_layer(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim * mlp_ratio), 
                                 act_layer(), 
                                 nn.Linear(dim * mlp_ratio, dim))

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
In [273]:
depth = 12
many_layers = nn.Sequential(*[Block(128, 8) for _ in range(depth)])
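
A quick shape check, in the same spirit as the checks above: a stack of pre-norm blocks preserves the (batch_size, seq_len, dim) shape.

In [ ]:
assert many_layers(torch.ones(11, 12, 128)).shape == (11, 12, 128)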

Einsum¶

Let's learn torch.einsum by example

$$c_{n} = \sum_{i, j} a_{n i j} b_{n j} $$
In [274]:
A = torch.randn(10, 5, 7)
B = torch.randn(10, 7)
C = torch.einsum('nij,nj->n', A, B)
C
Out[274]:
tensor([-0.9909, 15.4934,  3.8673, 10.1980, -8.3178,  2.1868, -3.9169, -3.2118,
        -4.8607,  2.9071])
In [275]:
C_loop = torch.zeros(10)
for n in range(10):
    for i in range(5):
        for j in range(7):
            C_loop[n] += A[n, i, j] * B[n, j]
C_loop
            
Out[275]:
tensor([-0.9909, 15.4933,  3.8673, 10.1980, -8.3178,  2.1868, -3.9169, -3.2118,
        -4.8607,  2.9071])

Einsum examples¶

  • Transposition: torch.einsum('ij->ji', A)
  • Scalar product: torch.einsum('i,i->', A, B) (see the checks below)
  • Matrix product: torch.einsum('ik,kj->ij', A, B)
  • ...
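
Verifying these patterns against the corresponding built-in ops (a small sketch with random tensors; note the scalar product uses an empty output spec so that the index is summed over):

In [ ]:
A = torch.randn(4, 6)
B = torch.randn(6, 3)
u, w = torch.randn(5), torch.randn(5)

assert torch.allclose(torch.einsum('ij->ji', A), A.T)            # transposition
assert torch.allclose(torch.einsum('i,i->', u, w), u @ w)        # scalar (dot) product
assert torch.allclose(torch.einsum('ik,kj->ij', A, B), A @ B)    # matrix product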

Einops.rearrange¶

https://github.com/arogozhnikov/einops

In [276]:
! python3 -m pip install einops -q
from einops import rearrange
In [277]:
# Transposition:
rearrange(torch.arange(1024).reshape(2, 4, 8, 16), 'aa b c d -> d c b aa').shape 
Out[277]:
torch.Size([16, 8, 4, 2])
In [278]:
res = rearrange(torch.arange(30).reshape(5, 6), 'a (b c) -> a b c', b=2, c=3)
res.shape
Out[278]:
torch.Size([5, 2, 3])
In [279]:
torch.allclose(res[0], res[0].flatten().reshape(2, 3))
Out[279]:
True

Average Pooling with einops.rearrange¶

In [280]:
from einops import rearrange
img = torch.randn(3, 32, 32)
blocks = rearrange(img, 'c (h h_patch) (w w_patch) -> (c h w) (h_patch w_patch)', h_patch=2, w_patch=2)
avgpool = torch.mean(blocks, dim=-1).reshape(3, img.shape[1] // 2, img.shape[2] // 2)
In [281]:
torch.allclose(avgpool, torch.nn.functional.avg_pool2d(img, kernel_size=(2, 2)))
Out[281]:
True

This operation can also be done via the einops.reduce function:

In [282]:
import einops
img = torch.randn(3, 32, 32)
avgpool = einops.reduce(img, 'c (h h_patch) (w w_patch) -> c h w', h_patch=2, w_patch=2, reduction='mean')
torch.allclose(avgpool, torch.nn.functional.avg_pool2d(img, kernel_size=(2, 2)))
Out[282]:
True

[Task] Patches crafting¶

In [283]:
! python3 -m pip install einops -q
from einops import rearrange

def img2patches(img, patch_size=8):
    '''
    Args:
        img: (batch_size, c, h, w) Tensor
        
    Returns:
        (batch_size, num_patches, vectorized_patch) Tensor
    '''
    # Your code is here

[Solution] Patches crafting¶

In [284]:
def img2patches(img, patch_size=8):
    '''
    Args:
        img: (batch_size, c, h, w) Tensor
        
    Returns:
        (batch_size, num_patches, vectorized_patch) Tensor
    '''
    return rearrange(img, 'batch_size c (h ph) (w pw) -> batch_size (h w) (c ph pw)', 
                     ph=patch_size, pw=patch_size)
    
    
img2patches(torch.ones(2, 3, 264, 264)).shape
Out[284]:
torch.Size([2, 1089, 192])

  • CLS token: an extra learnable token prepended to the sequence; its final representation is used for classification
  • Position embeddings: x = x + pos_embedding, where pos_embedding contains a trainable vector for every position in the sequence (see the sketch below)
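
A minimal sketch of these two ingredients (hypothetical shapes; the actual sizes come from the patching step above):

In [ ]:
B, N, D = 2, 1089, 192
tokens = torch.randn(B, N, D)                          # projected patch embeddings
pos_embed = nn.Parameter(torch.randn(1, N, D) * 0.02)  # one learnable vector per position
cls_token = nn.Parameter(torch.zeros(1, 1, D))         # one extra learnable token
tokens = tokens + pos_embed                            # add position information
tokens = torch.cat((cls_token.expand(B, -1, -1), tokens), dim=1)  # prepend CLS
print(tokens.shape)  # torch.Size([2, 1090, 192])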

[Task] Build ViT¶

In [286]:
class ViT(nn.Module):
    def __init__(
                    self,
                    img_size=(224, 224),
                    patch_size=16,
                    in_chans=3,
                    num_classes=10,
                    embed_dim=768,
                    depth=12,
                    num_heads=12,
                    mlp_ratio=4,
                    class_token=True,
                    norm_layer=nn.LayerNorm,
                    act_layer=nn.GELU
            ):
        # Your code is here
        
        pass
        
    def forward(self, x):
        '''
        Args: 
            x: (batch_size, in_channels, img_size[0], img_size[1])
            
        Return:
            (batch_size, num_classes) logits
        '''
        pass
    
[Solution] Build ViT¶

In [287]:
class ViT(nn.Module):
    def __init__(
                    self,
                    img_size=(224, 224),
                    patch_size=16,
                    in_chans=3,
                    num_classes=10,
                    embed_dim=768,
                    depth=12,
                    num_heads=12,
                    mlp_ratio=4,
                    norm_layer=nn.LayerNorm,
                    act_layer=nn.GELU
            ):
        # Your code is here
        super().__init__()
        self.patch_size = patch_size
        self.blocks = nn.Sequential(*[
            Block(embed_dim, num_heads, mlp_ratio, act_layer, norm_layer) for _ in range(depth)
        ])
        self.patch_proj = nn.Linear(in_chans * patch_size * patch_size, embed_dim)
        self.embed_len = (img_size[0] * img_size[1]) // (patch_size * patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, self.embed_len, embed_dim) * .02)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.head = nn.Linear(embed_dim, num_classes)
        
    def forward(self, x):
        '''
        Args: 
            x: (batch_size, in_channels, img_size[0], img_size[1])
            
        Return:
            (batch_size, num_classes)
        '''
        x = img2patches(x, patch_size=self.patch_size)
        x = self.patch_proj(x)
        x = x + self.pos_embed
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.blocks(x)
        x = x[:, 0, :]  # take CLS token
        return self.head(x)
In [291]:
ViT()(torch.ones(5, 3, 224, 224)).shape
Out[291]:
torch.Size([5, 10])

Practical Notes¶

  • Almost all popular CV models are implemented in timm
  • torch ships a fast built-in implementation of multi-head attention (see the sketch below)
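
A sketch of the built-in alternative (assuming a reasonably recent PyTorch: nn.MultiheadAttention supports batch_first=True, and F.scaled_dot_product_attention is the lower-level fused op); the timm line is the usual entry point for pretrained ViTs:

In [ ]:
# Built-in multi-head self-attention over (batch, seq, dim) tensors
mha = nn.MultiheadAttention(embed_dim=128, num_heads=8, batch_first=True)
x = torch.randn(11, 12, 128)
out, attn_weights = mha(x, x, x)  # self-attention: query = key = value = x
print(out.shape)                  # torch.Size([11, 12, 128])

# Pretrained ViTs can be loaded from timm, e.g.:
# import timm
# model = timm.create_model('vit_base_patch16_224', pretrained=True)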