# Oisin Mac Aodha
# added bat code
# 9ace58a
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from .model_helpers import *
import torchvision
import torch.fft
from torch import nn
class Net2DFast(nn.Module):
    """Encoder-decoder 2D network with a self-attention bottleneck.

    The encoder is three ``ConvBlockDownCoordF`` blocks followed by a 3x3
    convolution. The bottleneck applies a convolution whose kernel spans
    ``ip_height // 8`` rows and a single column, then ``SelfAttention``. The
    decoder is three ``ConvBlockUpF`` blocks, each fed the sum of the previous
    output and the matching encoder activation.

    Args:
        num_filts: Base channel count; layers use between ``num_filts // 4``
            and ``num_filts * 2`` channels.
        num_classes: Number of classes. The class head emits
            ``num_classes + 1`` channels, the extra one being the background
            class.
        emb_dim: Channel count of the optional embedding head; ``0`` disables
            the head.
        ip_height: Input height the layers are sized for.
            NOTE(review): the bottleneck kernel is ``ip_height // 8`` rows
            tall while ``forward`` tiles by ``(ip_height // 32) * 4``; these
            agree only when ``ip_height`` is a multiple of 32 -- confirm that
            callers guarantee this.
        resize_factor: Stored on the module; not used by this class itself.
    """

    def __init__(self, num_filts: int, num_classes: int = 0, emb_dim: int = 0,
                 ip_height: int = 128, resize_factor: float = 0.5):
        super().__init__()
        self.num_classes = num_classes
        self.emb_dim = emb_dim
        self.num_filts = num_filts
        self.resize_factor = resize_factor
        self.ip_height_rs = ip_height
        self.bneck_height = self.ip_height_rs // 32

        # encoder
        self.conv_dn_0 = ConvBlockDownCoordF(
            1, num_filts // 4, self.ip_height_rs,
            k_size=3, pad_size=1, stride=1)
        self.conv_dn_1 = ConvBlockDownCoordF(
            num_filts // 4, num_filts // 2, self.ip_height_rs // 2,
            k_size=3, pad_size=1, stride=1)
        self.conv_dn_2 = ConvBlockDownCoordF(
            num_filts // 2, num_filts, self.ip_height_rs // 4,
            k_size=3, pad_size=1, stride=1)
        self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1)
        self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2)

        # bottleneck: unpadded kernel spanning ip_height // 8 rows, 1 column
        self.conv_1d = nn.Conv2d(
            num_filts * 2, num_filts * 2, (self.ip_height_rs // 8, 1), padding=0)
        self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2)
        self.att = SelfAttention(num_filts * 2, num_filts * 2)

        # decoder
        self.conv_up_2 = ConvBlockUpF(num_filts * 2, num_filts // 2, self.ip_height_rs // 8)
        self.conv_up_3 = ConvBlockUpF(num_filts // 2, num_filts // 4, self.ip_height_rs // 4)
        self.conv_up_4 = ConvBlockUpF(num_filts // 4, num_filts // 4, self.ip_height_rs // 2)

        # output
        # +1 to include background class for class output
        self.conv_op = nn.Conv2d(num_filts // 4, num_filts // 4, kernel_size=3, padding=1)
        self.conv_op_bn = nn.BatchNorm2d(num_filts // 4)
        self.conv_size_op = nn.Conv2d(num_filts // 4, 2, kernel_size=1, padding=0)
        self.conv_classes_op = nn.Conv2d(
            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0)

        if self.emb_dim > 0:
            # forward() applies this head to the conv_op features, which have
            # num_filts // 4 channels. It was previously declared with
            # num_filts input channels, so every forward pass with
            # emb_dim > 0 failed with a channel mismatch.
            self.conv_emb = nn.Conv2d(
                num_filts // 4, self.emb_dim, kernel_size=1, padding=0)

    def forward(self, ip: torch.Tensor, return_feats: bool = False):
        """Run the network on ``ip``.

        Args:
            ip: Input tensor. ``conv_dn_0`` is declared with one input
                channel; presumably shaped (batch, 1, ip_height, width) --
                TODO confirm against callers.
            return_feats: When true, also return the activations that feed
                the output heads.

        Returns:
            A dict with:
              * ``'pred_det'``: sum over dim 1 of every softmax channel except
                the last, kept as a single channel.
              * ``'pred_size'``: ReLU of the 2-channel size head.
              * ``'pred_class'``: softmax over dim 1 of the class head.
              * ``'pred_class_un_norm'``: raw class head output.
              * ``'pred_emb'``: embedding head output (only if ``emb_dim > 0``).
              * ``'features'``: head input features (only if ``return_feats``).
        """
        # encoder
        x1 = self.conv_dn_0(ip)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)

        # bottleneck
        x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
        x = self.att(x)
        # Tile along dim 2 so the bottleneck output can be added to x3.
        x = x.repeat([1, 1, self.bneck_height * 4, 1])

        # decoder with additive skip connections
        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        # output
        x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
        cls = self.conv_classes_op(x)
        comb = torch.softmax(cls, 1)

        op = {}
        # The last class channel is left out of the detection score.
        op['pred_det'] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
        op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True)
        op['pred_class'] = comb
        op['pred_class_un_norm'] = cls
        if self.emb_dim > 0:
            op['pred_emb'] = self.conv_emb(x)
        if return_feats:
            op['features'] = x
        return op
class Net2DFastNoAttn(nn.Module):
    """Encoder-decoder 2D network without a self-attention layer.

    The encoder is three ``ConvBlockDownCoordF`` blocks followed by a 3x3
    convolution. The bottleneck is a single convolution whose kernel spans
    ``ip_height // 8`` rows and one column. The decoder is three
    ``ConvBlockUpF`` blocks, each fed the sum of the previous output and the
    matching encoder activation.

    Args:
        num_filts: Base channel count; layers use between ``num_filts // 4``
            and ``num_filts * 2`` channels.
        num_classes: Number of classes. The class head emits
            ``num_classes + 1`` channels, the extra one being the background
            class.
        emb_dim: Channel count of the optional embedding head; ``0`` disables
            the head.
        ip_height: Input height the layers are sized for.
            NOTE(review): the bottleneck kernel is ``ip_height // 8`` rows
            tall while ``forward`` tiles by ``(ip_height // 32) * 4``; these
            agree only when ``ip_height`` is a multiple of 32 -- confirm that
            callers guarantee this.
        resize_factor: Stored on the module; not used by this class itself.
    """

    def __init__(self, num_filts: int, num_classes: int = 0, emb_dim: int = 0,
                 ip_height: int = 128, resize_factor: float = 0.5):
        super().__init__()
        self.num_classes = num_classes
        self.emb_dim = emb_dim
        self.num_filts = num_filts
        self.resize_factor = resize_factor
        self.ip_height_rs = ip_height
        self.bneck_height = self.ip_height_rs // 32

        # encoder
        self.conv_dn_0 = ConvBlockDownCoordF(
            1, num_filts // 4, self.ip_height_rs,
            k_size=3, pad_size=1, stride=1)
        self.conv_dn_1 = ConvBlockDownCoordF(
            num_filts // 4, num_filts // 2, self.ip_height_rs // 2,
            k_size=3, pad_size=1, stride=1)
        self.conv_dn_2 = ConvBlockDownCoordF(
            num_filts // 2, num_filts, self.ip_height_rs // 4,
            k_size=3, pad_size=1, stride=1)
        self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1)
        self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2)

        # bottleneck: unpadded kernel spanning ip_height // 8 rows, 1 column
        self.conv_1d = nn.Conv2d(
            num_filts * 2, num_filts * 2, (self.ip_height_rs // 8, 1), padding=0)
        self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2)

        # decoder
        self.conv_up_2 = ConvBlockUpF(num_filts * 2, num_filts // 2, self.ip_height_rs // 8)
        self.conv_up_3 = ConvBlockUpF(num_filts // 2, num_filts // 4, self.ip_height_rs // 4)
        self.conv_up_4 = ConvBlockUpF(num_filts // 4, num_filts // 4, self.ip_height_rs // 2)

        # output
        # +1 to include background class for class output
        self.conv_op = nn.Conv2d(num_filts // 4, num_filts // 4, kernel_size=3, padding=1)
        self.conv_op_bn = nn.BatchNorm2d(num_filts // 4)
        self.conv_size_op = nn.Conv2d(num_filts // 4, 2, kernel_size=1, padding=0)
        self.conv_classes_op = nn.Conv2d(
            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0)

        if self.emb_dim > 0:
            # forward() applies this head to the conv_op features, which have
            # num_filts // 4 channels. It was previously declared with
            # num_filts input channels, so every forward pass with
            # emb_dim > 0 failed with a channel mismatch.
            self.conv_emb = nn.Conv2d(
                num_filts // 4, self.emb_dim, kernel_size=1, padding=0)

    def forward(self, ip: torch.Tensor, return_feats: bool = False):
        """Run the network on ``ip``.

        Args:
            ip: Input tensor. ``conv_dn_0`` is declared with one input
                channel; presumably shaped (batch, 1, ip_height, width) --
                TODO confirm against callers.
            return_feats: When true, also return the activations that feed
                the output heads.

        Returns:
            A dict with:
              * ``'pred_det'``: sum over dim 1 of every softmax channel except
                the last, kept as a single channel.
              * ``'pred_size'``: ReLU of the 2-channel size head.
              * ``'pred_class'``: softmax over dim 1 of the class head.
              * ``'pred_class_un_norm'``: raw class head output.
              * ``'pred_emb'``: embedding head output (only if ``emb_dim > 0``).
              * ``'features'``: head input features (only if ``return_feats``).
        """
        # encoder
        x1 = self.conv_dn_0(ip)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)

        # bottleneck (no attention in this variant)
        x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
        # Tile along dim 2 so the bottleneck output can be added to x3.
        x = x.repeat([1, 1, self.bneck_height * 4, 1])

        # decoder with additive skip connections
        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        # output
        x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
        cls = self.conv_classes_op(x)
        comb = torch.softmax(cls, 1)

        op = {}
        # The last class channel is left out of the detection score.
        op['pred_det'] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
        op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True)
        op['pred_class'] = comb
        op['pred_class_un_norm'] = cls
        if self.emb_dim > 0:
            op['pred_emb'] = self.conv_emb(x)
        if return_feats:
            op['features'] = x
        return op
class Net2DFastNoCoordConv(nn.Module):
    """Encoder-decoder 2D network built from the ``*Standard`` conv blocks.

    The encoder is three ``ConvBlockDownStandard`` blocks followed by a 3x3
    convolution. The bottleneck applies a convolution whose kernel spans
    ``ip_height // 8`` rows and a single column, then ``SelfAttention``. The
    decoder is three ``ConvBlockUpStandard`` blocks, each fed the sum of the
    previous output and the matching encoder activation.

    Args:
        num_filts: Base channel count; layers use between ``num_filts // 4``
            and ``num_filts * 2`` channels.
        num_classes: Number of classes. The class head emits
            ``num_classes + 1`` channels, the extra one being the background
            class.
        emb_dim: Channel count of the optional embedding head; ``0`` disables
            the head.
        ip_height: Input height the layers are sized for.
            NOTE(review): the bottleneck kernel is ``ip_height // 8`` rows
            tall while ``forward`` tiles by ``(ip_height // 32) * 4``; these
            agree only when ``ip_height`` is a multiple of 32 -- confirm that
            callers guarantee this.
        resize_factor: Stored on the module; not used by this class itself.
    """

    def __init__(self, num_filts: int, num_classes: int = 0, emb_dim: int = 0,
                 ip_height: int = 128, resize_factor: float = 0.5):
        super().__init__()
        self.num_classes = num_classes
        self.emb_dim = emb_dim
        self.num_filts = num_filts
        self.resize_factor = resize_factor
        self.ip_height_rs = ip_height
        self.bneck_height = self.ip_height_rs // 32

        # encoder
        self.conv_dn_0 = ConvBlockDownStandard(
            1, num_filts // 4, self.ip_height_rs,
            k_size=3, pad_size=1, stride=1)
        self.conv_dn_1 = ConvBlockDownStandard(
            num_filts // 4, num_filts // 2, self.ip_height_rs // 2,
            k_size=3, pad_size=1, stride=1)
        self.conv_dn_2 = ConvBlockDownStandard(
            num_filts // 2, num_filts, self.ip_height_rs // 4,
            k_size=3, pad_size=1, stride=1)
        self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1)
        self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2)

        # bottleneck: unpadded kernel spanning ip_height // 8 rows, 1 column
        self.conv_1d = nn.Conv2d(
            num_filts * 2, num_filts * 2, (self.ip_height_rs // 8, 1), padding=0)
        self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2)
        self.att = SelfAttention(num_filts * 2, num_filts * 2)

        # decoder
        self.conv_up_2 = ConvBlockUpStandard(
            num_filts * 2, num_filts // 2, self.ip_height_rs // 8)
        self.conv_up_3 = ConvBlockUpStandard(
            num_filts // 2, num_filts // 4, self.ip_height_rs // 4)
        self.conv_up_4 = ConvBlockUpStandard(
            num_filts // 4, num_filts // 4, self.ip_height_rs // 2)

        # output
        # +1 to include background class for class output
        self.conv_op = nn.Conv2d(num_filts // 4, num_filts // 4, kernel_size=3, padding=1)
        self.conv_op_bn = nn.BatchNorm2d(num_filts // 4)
        self.conv_size_op = nn.Conv2d(num_filts // 4, 2, kernel_size=1, padding=0)
        self.conv_classes_op = nn.Conv2d(
            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0)

        if self.emb_dim > 0:
            # forward() applies this head to the conv_op features, which have
            # num_filts // 4 channels. It was previously declared with
            # num_filts input channels, so every forward pass with
            # emb_dim > 0 failed with a channel mismatch.
            self.conv_emb = nn.Conv2d(
                num_filts // 4, self.emb_dim, kernel_size=1, padding=0)

    def forward(self, ip: torch.Tensor, return_feats: bool = False):
        """Run the network on ``ip``.

        Args:
            ip: Input tensor. ``conv_dn_0`` is declared with one input
                channel; presumably shaped (batch, 1, ip_height, width) --
                TODO confirm against callers.
            return_feats: When true, also return the activations that feed
                the output heads.

        Returns:
            A dict with:
              * ``'pred_det'``: sum over dim 1 of every softmax channel except
                the last, kept as a single channel.
              * ``'pred_size'``: ReLU of the 2-channel size head.
              * ``'pred_class'``: softmax over dim 1 of the class head.
              * ``'pred_class_un_norm'``: raw class head output.
              * ``'pred_emb'``: embedding head output (only if ``emb_dim > 0``).
              * ``'features'``: head input features (only if ``return_feats``).
        """
        # encoder
        x1 = self.conv_dn_0(ip)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)

        # bottleneck
        x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
        x = self.att(x)
        # Tile along dim 2 so the bottleneck output can be added to x3.
        x = x.repeat([1, 1, self.bneck_height * 4, 1])

        # decoder with additive skip connections
        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        # output
        x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
        cls = self.conv_classes_op(x)
        comb = torch.softmax(cls, 1)

        op = {}
        # The last class channel is left out of the detection score.
        op['pred_det'] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
        op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True)
        op['pred_class'] = comb
        op['pred_class_un_norm'] = cls
        if self.emb_dim > 0:
            op['pred_emb'] = self.conv_emb(x)
        if return_feats:
            op['features'] = x
        return op