by Talgat Daulbaev
...Write your own implementation of multi-head attention...
Today we are going to write our own implementation of the Vision Transformer (ViT)!
import torch
from torch import nn
class Attention(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
self.scale = self.dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3)
def forward(self, x):
'''
Args:
x: Tensor of shape (batch_size, seq_len, input_dim)
Returns:
Tensor of shape (batch_size, seq_len, input_dim)
'''
# Your code is here
return x
class Attention(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
self.scale = self.dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=False)
def forward(self, x):
'''
Args:
x: Tensor of shape (batch_size, seq_len, input_dim)
Returns:
Tensor of shape (batch_size, seq_len, input_dim)
'''
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.dim)
q, k, v = qkv.unbind(2)
# ...or q, k, v = [qkv[:, :, idx, :] for idx in range(3)]
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
x = attn @ v
return x
x = torch.ones(11, 12, 8)
assert Attention(8)(x).shape == x.shape
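As an optional cross-check (a sketch assuming PyTorch >= 2.0, which provides torch.nn.functional.scaled_dot_product_attention), we can compare our layer with the built-in scaled dot-product attention applied to the same Q, K, V projections:
import torch.nn.functional as F
attn_layer = Attention(8)
x = torch.randn(11, 12, 8)
# Reuse the layer's own QKV projection; the built-in default scale is
# 1 / sqrt(last_dim), which matches self.scale above.
q, k, v = attn_layer.qkv(x).reshape(11, 12, 3, 8).unbind(2)
assert torch.allclose(attn_layer(x), F.scaled_dot_product_attention(q, k, v), atol=1e-6)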
Multi-head attention splits the queries, keys, and values into num_heads vectors ($d$ mod num_heads = 0).
Q: How to implement it without loops?
A: Use a single QKV linear layer, then reshape and slice.
class MultiHeadAttention(nn.Module):
def __init__(self, dim, num_heads=8):
super().__init__()
if dim % num_heads:
raise ValueError('dim % num_heads != 0')
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
'''
Args:
x: Tensor of shape (batch_size, seq_len, input_dim)
Returns:
Tensor of shape (batch_size, seq_len, input_dim)
'''
# Hint: you might want to use torch.permute function
return x
class MultiHeadAttention(nn.Module):
def __init__(self, dim, num_heads=8):
super().__init__()
if dim % num_heads:
raise ValueError('dim % num_heads != 0')
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
'''
Args:
x: Tensor of shape (batch_size, seq_len, input_dim)
Returns:
Tensor of shape (batch_size, seq_len, input_dim)
'''
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
# qkv: 3 × B × num_heads × N × head_dim
q, k, v = qkv.unbind(0)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
x = attn @ v # attn: B × num_heads × N × N v: B × num_heads × N × head_dim
# B × num_heads × N × head_dim
x = x.transpose(1, 2).reshape(B, N, C)
# B × N × (num_heads × head_dim)
x = self.proj(x)
return x
MultiHeadAttention(128, 8)(torch.ones(11, 12, 128)).shape
torch.Size([11, 12, 128])
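To double-check the reshape/permute bookkeeping, here is an illustrative loop-based cross-check that runs each head separately on slices of the shared QKV projection (a sanity check only, not how you would run it in practice):
mha = MultiHeadAttention(128, 8)
x = torch.randn(11, 12, 128)
B, N, C = x.shape
qkv = mha.qkv(x).reshape(B, N, 3, mha.num_heads, mha.head_dim)
head_outputs = []
for h in range(mha.num_heads):
    q, k, v = qkv[:, :, 0, h], qkv[:, :, 1, h], qkv[:, :, 2, h]  # each: B × N × head_dim
    attn = ((q * mha.scale) @ k.transpose(-2, -1)).softmax(dim=-1)
    head_outputs.append(attn @ v)
# Concatenating head outputs along the channel dimension reproduces the vectorized result
assert torch.allclose(mha(x), mha.proj(torch.cat(head_outputs, dim=-1)), atol=1e-6)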
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4, # ratio between hidden_dim and input_dim in MLP
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super().__init__()
# Your code is here
def forward(self, x):
# Your code is here
return x
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4, # ratio between hidden_dim and input_dim in MLP
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = MultiHeadAttention(dim, num_heads=num_heads)
self.norm2 = norm_layer(dim)
self.mlp = nn.Sequential(nn.Linear(dim, dim * mlp_ratio),
act_layer(),
nn.Linear(dim * mlp_ratio, dim))
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
depth = 12
many_layers = nn.Sequential(*[Block(128, 8) for _ in range(depth)])
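A quick shape check (illustrative): the stack of residual blocks preserves the input shape.
assert many_layers(torch.randn(4, 12, 128)).shape == (4, 12, 128)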
Let's learn torch.einsum by example.
A = torch.randn(10, 5, 7)
B = torch.randn(10, 7)
C = torch.einsum('nij,nj->n', A, B)
C
tensor([-0.9909, 15.4934, 3.8673, 10.1980, -8.3178, 2.1868, -3.9169, -3.2118, -4.8607, 2.9071])
C_loop = torch.zeros(10)
for n in range(10):
for i in range(5):
for j in range(7):
C_loop[n] += A[n, i, j] * B[n, j]
C_loop
tensor([-0.9909, 15.4933, 3.8673, 10.1980, -8.3178, 2.1868, -3.9169, -3.2118, -4.8607, 2.9071])
# A few more common einsum patterns (shown with generic tensors, not the A and B defined above):
torch.einsum('ij->ji', M)          # matrix transpose
torch.einsum('i,i->i', u, v)       # element-wise product of two vectors
torch.einsum('ik,kj->ij', M1, M2)  # matrix multiplication
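As an illustration tying einsum back to attention (an extra example, not part of the original exercises), the score computation q @ k.transpose(-2, -1) from MultiHeadAttention.forward can be written as a single einsum:
q = torch.randn(2, 8, 12, 16)  # B × num_heads × N × head_dim
k = torch.randn(2, 8, 12, 16)
scores = torch.einsum('bhid,bhjd->bhij', q, k)  # B × num_heads × N × N
assert torch.allclose(scores, q @ k.transpose(-2, -1), atol=1e-5)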
! python3 -m pip install einops -q
from einops import rearrange
# Transposition:
rearrange(torch.arange(1024).reshape(2, 4, 8, 16), 'aa b c d -> d c b aa').shape
torch.Size([16, 8, 4, 2])
res = rearrange(torch.arange(30).reshape(5, 6), 'a (b c) -> a b c', b=2, c=3)
res.shape
torch.Size([5, 2, 3])
torch.allclose(res[0], res[0].flatten().reshape(2, 3))
True
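Merging axes back is just the inverse pattern (a small illustration using the same res tensor):
merged = rearrange(res, 'a b c -> a (b c)')
assert torch.equal(merged, torch.arange(30).reshape(5, 6))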
from einops import rearrange
img = torch.randn(3, 32, 32)
blocks = rearrange(img, 'c (h h_patch) (w w_patch) -> (c h w) (h_patch w_patch)', h_patch=2, w_patch=2)
avgpool = torch.mean(blocks, dim=-1).reshape(3, img.shape[1] // 2, img.shape[2] // 2)
torch.allclose(avgpool, torch.nn.functional.avg_pool2d(img, kernel_size=(2, 2)))
True
This operation can also be done via the einops.reduce function:
import einops
img = torch.randn(3, 32, 32)
avgpool = einops.reduce(img, 'c (h h_patch) (w w_patch) -> c h w', h_patch=2, w_patch=2, reduction='mean')
torch.allclose(avgpool, torch.nn.functional.avg_pool2d(img, kernel_size=(2, 2)))
True
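The same pattern works for other reductions as well; for example (an extra illustration using the img defined above), max pooling:
maxpool = einops.reduce(img, 'c (h h_patch) (w w_patch) -> c h w', h_patch=2, w_patch=2, reduction='max')
assert torch.allclose(maxpool, torch.nn.functional.max_pool2d(img, kernel_size=(2, 2)))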
! python3 -m pip install einops -q
from einops import rearrange
def img2patches(img, patch_size=8):
'''
Args:
img: (batch_size, c, h, w) Tensor
Returns:
(batch_size, num_patches, vectorized_patch) Tensor
'''
# Your code is here
def img2patches(img, patch_size=8):
'''
Args:
img: (batch_size, c, h, w) Tensor
Returns:
(batch_size, num_patches, vectorized_patch) Tensor
'''
return rearrange(img, 'batch_size c (h ph) (w pw) -> batch_size (h w) (c ph pw)',
ph=patch_size, pw=patch_size)
img2patches(torch.ones(2, 3, 264, 264)).shape
torch.Size([2, 1089, 192])
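A sanity check worth doing (illustrative): patching is lossless, since the inverse rearrange pattern reassembles the original image.
img = torch.randn(2, 3, 264, 264)
patches = img2patches(img)
restored = rearrange(patches, 'b (h w) (c ph pw) -> b c (h ph) (w pw)',
                     h=264 // 8, w=264 // 8, ph=8, pw=8)
assert torch.equal(restored, img)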
Before the transformer blocks, positional information is added to the patch embeddings: x = x + pos_embedding, where pos_embedding is a learnable parameter trained for every position in the sequence.
class ViT(nn.Module):
def __init__(
self,
img_size=(224, 224),
patch_size=16,
in_chans=3,
num_classes=10,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
class_token=True,
norm_layer=nn.LayerNorm,
act_layer=nn.GELU
):
# Your code is here
pass
def forward(self, x):
'''
Args:
x: (batch_size, in_channels, img_size[0], img_size[1])
Return:
            (batch_size, num_classes) logits
'''
pass
class ViT(nn.Module):
def __init__(
self,
img_size=(224, 224),
patch_size=16,
in_chans=3,
num_classes=10,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
norm_layer=nn.LayerNorm,
act_layer=nn.GELU
):
super().__init__()
self.patch_size = patch_size
self.blocks = nn.Sequential(*[
Block(embed_dim, num_heads, mlp_ratio, act_layer, norm_layer) for _ in range(depth)
])
        self.patch_proj = nn.Linear(in_chans * patch_size * patch_size, embed_dim)
self.embed_len = (img_size[0] * img_size[1]) // (patch_size * patch_size)
self.pos_embed = nn.Parameter(torch.randn(1, self.embed_len, embed_dim) * .02)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.head = nn.Linear(embed_dim, num_classes)
def forward(self, x):
'''
Args:
x: (batch_size, in_channels, img_size[0], img_size[1])
Return:
(batch_size, num_classes)
'''
x = img2patches(x, patch_size=self.patch_size)
x = self.patch_proj(x)
x = x + self.pos_embed
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = self.blocks(x)
x = x[:, 0, :] # take CLS token
return self.head(x)
ViT()(torch.ones(5, 3, 224, 224)).shape
torch.Size([5, 10])
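As a final illustration (not part of the original assignment), we can count the trainable parameters of the default configuration:
model = ViT()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{num_params / 1e6:.1f}M trainable parameters')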