import os
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler


class Block(nn.Module):
    """Two 3x3 convolutions with a ReLU in between ('same' padding keeps the spatial size)."""

    def __init__(self, in_ch, out_ch, padding='same'):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=padding)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=padding)

    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))


class Encoder(nn.Module):
    """U-Net contracting path: returns every block's pre-pooling feature map for the skip connections."""

    def __init__(self, chs=(3, 32, 64, 128, 256)):
        super().__init__()
        self.enc_blocks = nn.ModuleList(
            [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]
        )
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        ftrs = []
        for block in self.enc_blocks:
            x = block(x)
            ftrs.append(x)          # keep the pre-pooling feature as a skip connection
            x = self.pool(x)
        return ftrs


class Decoder(nn.Module):
    """U-Net expanding path. The first up-convolution expects the bottleneck
    concatenated with `aux_ch` auxiliary channels."""

    def __init__(self, chs=(256, 128, 64, 32), aux_ch=70):
        super().__init__()
        # Widen only the first stage to absorb the auxiliary channels.
        upchs = tuple(chs[i] + aux_ch if i == 0 else chs[i] for i in range(len(chs)))
        self.chs = chs
        self.upchs = upchs
        self.upconvs = nn.ModuleList(
            [nn.ConvTranspose2d(upchs[i], upchs[i + 1], 2, 2) for i in range(len(upchs) - 1)]
        )
        self.dec_blocks = nn.ModuleList(
            [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]
        )

    def forward(self, x, encoder_features):
        for i in range(len(self.chs) - 1):
            x = self.upconvs[i](x)                              # upsample
            enc_ftrs = self.crop(encoder_features[i], x)        # crop skip connection to match
            x = torch.cat([x, enc_ftrs], dim=1)                 # concatenate along channels
            x = self.dec_blocks[i](x)
        return x

    def crop(self, enc_ftrs, x):
        # Center-crop the skip connection to the upsampled spatial size; with 'same'
        # padding and input sizes divisible by 8 this is effectively a no-op.
        _, _, H, W = x.shape
        return torchvision.transforms.CenterCrop([H, W])(enc_ftrs)


class AuxUNet(nn.Module):
    """U-Net that concatenates an auxiliary feature map to the bottleneck."""

    def __init__(self, enc_chs=(3, 32, 64, 128, 256), dec_chs=(256, 128, 64, 32),
                 aux_ch=70, num_class=7, retain_dim=False, out_sz=(224, 224)):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs, aux_ch)
        self.head = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.retain_dim = retain_dim
        self.out_sz = out_sz    # store it; forward() previously referenced an undefined local `out_sz`

    def forward(self, x, aux):
        # aux: auxiliary feature map, concatenated to the deepest (bottleneck) encoder feature
        enc_ftrs = self.encoder(x)
        enc_ftrs[-1] = torch.cat((enc_ftrs[-1], aux), 1)
        out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
        out = self.head(out)
        if self.retain_dim:
            out = F.interpolate(out, self.out_sz)
        return out
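

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). The 224x224
# input size, batch size of 2, and random tensors are assumptions chosen so
# that the auxiliary tensor matches the 28x28 bottleneck (224 halved by the
# three max-pools that precede the deepest encoder block).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = AuxUNet(retain_dim=True, out_sz=(224, 224))
    images = torch.randn(2, 3, 224, 224)    # RGB batch
    aux = torch.randn(2, 70, 28, 28)         # auxiliary feature at the bottleneck resolution
    logits = model(images, aux)
    print(logits.shape)                      # expected: torch.Size([2, 7, 224, 224])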