Semantic Segmentation (solutions)¶


U-Net architecture¶

[Figure: U-Net architecture diagram]

The network is a contracting encoder path that halves the spatial resolution while doubling the channel count at each stage, a bottleneck, and an expanding decoder path that upsamples back while concatenating the matching encoder feature maps (the skip connections).

In [18]:
import torch
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
%matplotlib inline
import gdown
In [2]:
gdown.download(url="https://drive.google.com/file/d/1kiR9pPP3gIqoyT9bWUnA0rgZuRTiUv7q/view?usp=share_link", output="./unet_dict.pt", quiet=False, fuzzy=True)
Downloading...
From: https://drive.google.com/uc?id=1kiR9pPP3gIqoyT9bWUnA0rgZuRTiUv7q
To: /content/unet_dict.pt
100%|██████████| 124M/124M [00:00<00:00, 238MB/s]
Out[2]:
'./unet_dict.pt'
In [5]:
class Encoder_Block(torch.nn.Module):
  def __init__(self,inp_channels,out_channels):
    super().__init__()
    # two 3x3 conv -> BatchNorm -> ReLU blocks
    self.model = torch.nn.Sequential(
        torch.nn.Conv2d(inp_channels,out_channels,kernel_size=3,padding=1),
        torch.nn.BatchNorm2d(out_channels),
        torch.nn.ReLU(),
        torch.nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1),
        torch.nn.BatchNorm2d(out_channels),
        torch.nn.ReLU(),
    )
    self.pooling = torch.nn.MaxPool2d(2)  # halves the spatial resolution
  def forward(self,x):
    int_out = self.model(x)
    # return the pooled tensor (input to the next stage) and the pre-pooling
    # feature map (saved for the decoder's skip connection)
    return self.pooling(int_out), int_out

class Decoder_Block(torch.nn.Module):
  def __init__(self,inp_channels,out_channels):
    super().__init__()
    # transposed convolution doubles the spatial resolution
    self.upsample = torch.nn.ConvTranspose2d(inp_channels,out_channels,kernel_size=2,stride=2)
    self.model = torch.nn.Sequential(
        torch.nn.Conv2d(inp_channels,out_channels,kernel_size=3,padding=1),
        torch.nn.BatchNorm2d(out_channels),
        torch.nn.ReLU(),
        torch.nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1),
        torch.nn.BatchNorm2d(out_channels),
        torch.nn.ReLU(),
    )
  def forward(self,x,enc_x):
    x = self.upsample(x)
    # concatenating the skip connection restores the channel count to inp_channels
    x = torch.cat([x,enc_x],dim=1)
    return self.model(x)

Let's check our implementation¶

In [6]:
enc = Encoder_Block(64,128)
dec = Decoder_Block(256,128)
inp = torch.randn(1,64,64,64)
x,enc_x = enc(inp)

assert x.shape == (1,128,32,32), "check your encoder implementation"

inp = torch.randn(1,256,32,32)
x = dec(inp,enc_x)

assert x.shape == (1,128,64,64), "check your decoder implementation"
In [8]:
class Unet(torch.nn.Module):
  def __init__(self,inc,outc,hidden_size=64):
    super().__init__()
    # four encoder stages: each halves the resolution and doubles the channels
    self.Encoder = torch.nn.ModuleList([
        Encoder_Block(inc,hidden_size),
        Encoder_Block(hidden_size,hidden_size*2),
        Encoder_Block(hidden_size*2,hidden_size*4),
        Encoder_Block(hidden_size*4,hidden_size*8),
    ])
    # bottleneck at the lowest resolution (1x1 convolutions)
    self.bottleneck = torch.nn.Sequential(
        torch.nn.Conv2d(hidden_size*8,hidden_size*16,kernel_size=1),
        torch.nn.BatchNorm2d(hidden_size*16),
        torch.nn.ReLU(),
        torch.nn.Conv2d(hidden_size*16,hidden_size*16,kernel_size=1),
        torch.nn.BatchNorm2d(hidden_size*16),
        torch.nn.ReLU()
    )
    # four decoder stages mirroring the encoder
    self.Decoder = torch.nn.ModuleList([
        Decoder_Block(hidden_size*16,hidden_size*8),
        Decoder_Block(hidden_size*8,hidden_size*4),
        Decoder_Block(hidden_size*4,hidden_size*2),
        Decoder_Block(hidden_size*2,hidden_size*1),
    ])
    # final convolution maps to one channel per output class
    self.last_layer = torch.nn.Conv2d(hidden_size,outc,kernel_size=3,padding="same")
  def forward(self,x):
    enc_xs = []
    for module in self.Encoder:
      x, enc_x = module(x)
      enc_xs.append(enc_x)

    # reverse so the deepest encoder features are consumed first by the decoder
    enc_xs = enc_xs[::-1]
    x = self.bottleneck(x)

    for i,module in enumerate(self.Decoder):
      x = module(x,enc_xs[i])
    return self.last_layer(x)

Check that your implementation works correctly¶

In [9]:
unet = Unet(3,11)

assert unet(torch.randn(1,3,128,128)).shape == (1,11,128,128), "check your implementation"
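
As a quick extra sanity check (not part of the original assignment), we can count the model's parameters; a minimal sketch using the unet instance defined above:

In [ ]:
# Illustrative extra: rough model-size check
n_params = sum(p.numel() for p in unet.parameters())
print(f"{n_params/1e6:.1f}M parameters")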

Let's start working with data¶

In [10]:
transform = transforms.Compose([
    transforms.Resize(128),
    transforms.CenterCrop(128),
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5,std=0.5)
])
def class_split(data,n=3):
  """Split a trimap with values 1..n into n binary one-hot channels."""
  data = np.array(data)
  res = []
  for i in range(1,1+n):
    mask = np.zeros_like(data)
    mask[data==i] = 1.
    res.append(mask[None])  # add a leading channel axis
  return torch.from_numpy(np.concatenate(res,axis=0)).to(torch.float)

target_transform = transforms.Compose([
    # nearest-neighbour interpolation keeps the trimap labels exact; the default
    # bilinear resize would blend label values at region borders
    transforms.Resize(128,interpolation=transforms.InterpolationMode.NEAREST),
    transforms.CenterCrop(128),
    class_split
])
dataset = torchvision.datasets.OxfordIIITPet("./data",split="trainval",target_types="segmentation",download=True,transform=transform,target_transform=target_transform)
Downloading https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz to data/oxford-iiit-pet/images.tar.gz
100%|██████████| 791918971/791918971 [00:26<00:00, 30094727.25it/s]
Extracting data/oxford-iiit-pet/images.tar.gz to data/oxford-iiit-pet
Downloading https://thor.robots.ox.ac.uk/datasets/pets/annotations.tar.gz to data/oxford-iiit-pet/annotations.tar.gz
100%|██████████| 19173078/19173078 [00:01<00:00, 15159384.99it/s]
Extracting data/oxford-iiit-pet/annotations.tar.gz to data/oxford-iiit-pet
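
To make the target encoding concrete, here is a quick illustrative check of class_split on a toy trimap (the Oxford-IIIT Pet trimap uses 1 = pet, 2 = background, 3 = border/undefined):

In [ ]:
# Toy trimap with the three Oxford-IIIT Pet label values
toy = np.array([[1, 2],
                [3, 1]])
masks = class_split(toy)
print(masks.shape)  # torch.Size([3, 2, 2]) -- one binary channel per class
print(masks[0])     # channel 0 marks the pixels of class 1 (pet)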
In [11]:
transforms.ToPILImage()(dataset[0][0]/2+0.5)  # undo Normalize(0.5,0.5): map [-1,1] back to [0,1]
Out[11]:
In [12]:
transforms.ToPILImage()(dataset[0][1])  # the one-hot mask is already in [0,1]; no un-normalization needed
Out[12]:
In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = Unet(3,3).to(device)
dataloader = torch.utils.data.DataLoader(dataset,batch_size=32,shuffle=True)
optimizer = torch.optim.Adam(unet.parameters(),lr=0.001)
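
Alternatively, instead of training from scratch, the checkpoint downloaded at the top of the notebook can presumably be restored here (a sketch assuming unet_dict.pt is a state dict saved from this same Unet(3,3) configuration with hidden_size=64):

In [ ]:
# Optional: restore the pre-trained weights downloaded earlier instead of training
# (assumes unet_dict.pt matches the Unet(3,3) architecture defined above)
unet.load_state_dict(torch.load("./unet_dict.pt", map_location=device))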
In [19]:
def train(model,dataloader,optimizer,loss_func=torch.nn.CrossEntropyLoss(),epochs=5):
  for i in range(epochs):
    for x,y in tqdm(dataloader):
      x = x.to(device)
      y = y.to(device)

      out = model(x)

      loss = loss_func(out,y)

      # clear stale gradients, backpropagate, update the weights
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
    print(i)  # epoch counter
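
A side note on the loss: torch.nn.CrossEntropyLoss also accepts floating-point class-probability targets in addition to integer class indices (PyTorch 1.10+), which is why the one-hot (N, 3, H, W) masks produced by class_split can be passed to it directly.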
In [20]:
train(unet,dataloader,optimizer,epochs=10)
100%|██████████| 115/115 [00:45<00:00,  2.50it/s]
0
100%|██████████| 115/115 [00:46<00:00,  2.47it/s]
1
100%|██████████| 115/115 [00:46<00:00,  2.50it/s]
2
100%|██████████| 115/115 [00:46<00:00,  2.47it/s]
3
100%|██████████| 115/115 [00:46<00:00,  2.47it/s]
4
100%|██████████| 115/115 [00:46<00:00,  2.48it/s]
5
100%|██████████| 115/115 [00:46<00:00,  2.48it/s]
6
100%|██████████| 115/115 [00:46<00:00,  2.48it/s]
7
100%|██████████| 115/115 [00:46<00:00,  2.47it/s]
8
100%|██████████| 115/115 [00:46<00:00,  2.50it/s]
9
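
Let's evaluate with IoU¶

For each batch we binarize the predicted class probabilities at 0.5 and compute the intersection over union, IoU = |prediction ∩ target| / |prediction ∪ target|, then average the per-batch scores.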

In [24]:
iou = []
with torch.no_grad():  # no gradients are needed for evaluation
  for x,y in tqdm(dataloader):
    # threshold the per-class probabilities into hard binary masks
    y_hat = (torch.nn.Softmax(dim=1)(unet(x.to(device))) > 0.5).to(torch.float)

    intersection = y_hat * y.to(device)
    union = (y_hat + y.to(device)).clamp(0,1)
    iou.append(intersection.sum()/union.sum())
100%|██████████| 115/115 [00:48<00:00,  2.35it/s]
In [26]:
iou = [item.item() for item in iou]
In [27]:
round(np.mean(iou)*100,1)
Out[27]:
78.0
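
Note that this mean IoU of roughly 78% is measured on the same trainval data the model was trained on, so it is an optimistic estimate; a held-out split would give a fairer number.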
In [28]:
!wget -nv "https://fikiwiki.com/uploads/posts/2022-02/1644990866_45-fikiwiki-com-p-prikolnie-kartinki-pro-zhivotnikh-47.png"
2023-04-04 12:22:25 URL:https://fikiwiki.com/uploads/posts/2022-02/1644990866_45-fikiwiki-com-p-prikolnie-kartinki-pro-zhivotnikh-47.png [1205459/1205459] -> "1644990866_45-fikiwiki-com-p-prikolnie-kartinki-pro-zhivotnikh-47.png" [1]
In [29]:
from PIL import Image
img = Image.open("/content/1644990866_45-fikiwiki-com-p-prikolnie-kartinki-pro-zhivotnikh-47.png")
In [30]:
img = transform(img)[:3][None].to(device)  # drop the PNG alpha channel, add a batch dimension
mask = unet(img)[0]
In [32]:
transforms.ToPILImage()((torch.nn.Softmax(dim=0)(mask)[0:1].detach().cpu() > 0.5).to(torch.float))
Out[32]:
In [33]:
transforms.ToPILImage()(img[0].detach().cpu()/2+0.5)
Out[33]:
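
As a final illustrative extra (not part of the original solution), the predicted pet mask (channel 0, as above) can be multiplied into the de-normalized image to visualize the segmentation:

In [ ]:
# Keep only the pixels the model classifies as "pet" (class channel 0)
pet_mask = (torch.nn.Softmax(dim=0)(mask)[0:1] > 0.5).to(torch.float).detach().cpu()
transforms.ToPILImage()((img[0].detach().cpu()/2 + 0.5) * pet_mask)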