import numpy as np
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch import nn

from .model_helpers import *


class _Net2DFastBase(nn.Module):
    """Shared encoder / bottleneck / decoder used by the Net2DFast variants.

    The public classes below differ only in which down-sampling and
    up-sampling blocks they use and whether a ``SelfAttention`` layer is
    applied at the bottleneck. Everything else lives here so the variants
    cannot drift apart.

    Submodule attribute names and registration order match the original
    per-class implementations, so ``state_dict`` keys are unchanged.
    """

    def __init__(self, num_filts, num_classes, emb_dim, ip_height,
                 resize_factor, down_block, up_block, use_attention):
        """Build the network.

        Args:
            num_filts: base channel width; encoder/decoder widths are
                ``num_filts // 4``, ``num_filts // 2``, ``num_filts`` and
                ``num_filts * 2``.
            num_classes: number of foreground classes; the class head emits
                ``num_classes + 1`` channels (the extra one is background).
            emb_dim: if > 0, an extra 1x1 conv head of this many channels is
                built and its output returned as ``'pred_emb'``.
            ip_height: input height the layer sizes are derived from.
            resize_factor: stored on the instance; not used by this class.
            down_block: constructor for the three encoder blocks.
            up_block: constructor for the three decoder blocks.
            use_attention: whether to create ``self.att`` (SelfAttention) at
                the bottleneck.
        """
        super().__init__()
        self.num_classes = num_classes
        self.emb_dim = emb_dim
        self.num_filts = num_filts
        self.resize_factor = resize_factor
        self.ip_height_rs = ip_height
        self.bneck_height = self.ip_height_rs // 32

        # encoder
        self.conv_dn_0 = down_block(1, num_filts // 4, self.ip_height_rs,
                                    k_size=3, pad_size=1, stride=1)
        self.conv_dn_1 = down_block(num_filts // 4, num_filts // 2,
                                    self.ip_height_rs // 2,
                                    k_size=3, pad_size=1, stride=1)
        self.conv_dn_2 = down_block(num_filts // 2, num_filts,
                                    self.ip_height_rs // 4,
                                    k_size=3, pad_size=1, stride=1)
        self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1)
        self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2)

        # bottleneck: kernel spans ip_height // 8 rows with no padding,
        # presumably collapsing x3's height to 1 (confirm x3's height)
        self.conv_1d = nn.Conv2d(num_filts * 2, num_filts * 2,
                                 (self.ip_height_rs // 8, 1), padding=0)
        self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2)
        if use_attention:
            # Only registered when requested, so the no-attention variant
            # has no 'att' attribute (same as the original class).
            self.att = SelfAttention(num_filts * 2, num_filts * 2)

        # decoder
        self.conv_up_2 = up_block(num_filts * 2, num_filts // 2,
                                  self.ip_height_rs // 8)
        self.conv_up_3 = up_block(num_filts // 2, num_filts // 4,
                                  self.ip_height_rs // 4)
        self.conv_up_4 = up_block(num_filts // 4, num_filts // 4,
                                  self.ip_height_rs // 2)

        # output
        self.conv_op = nn.Conv2d(num_filts // 4, num_filts // 4,
                                 kernel_size=3, padding=1)
        self.conv_op_bn = nn.BatchNorm2d(num_filts // 4)
        self.conv_size_op = nn.Conv2d(num_filts // 4, 2,
                                      kernel_size=1, padding=0)
        # +1 to include background class for class output
        self.conv_classes_op = nn.Conv2d(num_filts // 4, self.num_classes + 1,
                                         kernel_size=1, padding=0)
        if self.emb_dim > 0:
            # forward() applies this to the output of conv_op, which has
            # num_filts // 4 channels. It previously declared num_filts
            # input channels, which failed at runtime for any emb_dim > 0.
            self.conv_emb = nn.Conv2d(num_filts // 4, self.emb_dim,
                                      kernel_size=1, padding=0)

    def forward(self, ip, return_feats=False):
        """Run the network.

        Args:
            ip: input tensor. The first encoder block takes 1 input channel,
                so this is presumably (N, 1, ip_height, W); confirm against
                callers.
            return_feats: if True, also return the final feature map.

        Returns:
            dict with (in insertion order):
              'pred_det': sum of the softmaxed class channels excluding the
                  last (background) channel, shape (N, 1, H, W).
              'pred_size': ReLU'd 2-channel output of conv_size_op
                  (presumably box size; confirm).
              'pred_class': softmax over the num_classes + 1 channels.
              'pred_class_un_norm': the class logits before softmax.
              'pred_emb': only if emb_dim > 0.
              'features': only if return_feats.
        """
        # encoder
        x1 = self.conv_dn_0(ip)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)

        # bottleneck
        x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
        if hasattr(self, "att"):
            x = self.att(x)
        # Tile back up so it can be added to x3. bneck_height * 4 equals
        # ip_height // 8 only when ip_height is a multiple of 32.
        # NOTE(review): assumes x3 has height ip_height // 8; confirm.
        x = x.repeat([1, 1, self.bneck_height * 4, 1])

        # decoder with additive skip connections
        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        # output heads
        x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
        cls = self.conv_classes_op(x)
        comb = torch.softmax(cls, 1)

        op = {}
        # last channel is background, so detection = 1 - p(background)
        op['pred_det'] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
        op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True)
        op['pred_class'] = comb
        op['pred_class_un_norm'] = cls
        if self.emb_dim > 0:
            op['pred_emb'] = self.conv_emb(x)
        if return_feats:
            op['features'] = x
        return op


class Net2DFast(_Net2DFastBase):
    """Net2DFast using ConvBlockDownCoordF / ConvBlockUpF blocks plus
    SelfAttention at the bottleneck."""

    def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128,
                 resize_factor=0.5):
        super().__init__(num_filts, num_classes, emb_dim, ip_height,
                         resize_factor,
                         down_block=ConvBlockDownCoordF,
                         up_block=ConvBlockUpF,
                         use_attention=True)


class Net2DFastNoAttn(_Net2DFastBase):
    """Net2DFast using ConvBlockDownCoordF / ConvBlockUpF blocks, without
    the bottleneck SelfAttention layer."""

    def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128,
                 resize_factor=0.5):
        super().__init__(num_filts, num_classes, emb_dim, ip_height,
                         resize_factor,
                         down_block=ConvBlockDownCoordF,
                         up_block=ConvBlockUpF,
                         use_attention=False)


class Net2DFastNoCoordConv(_Net2DFastBase):
    """Net2DFast using ConvBlockDownStandard / ConvBlockUpStandard blocks
    plus SelfAttention at the bottleneck."""

    def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128,
                 resize_factor=0.5):
        super().__init__(num_filts, num_classes, emb_dim, ip_height,
                         resize_factor,
                         down_block=ConvBlockDownStandard,
                         up_block=ConvBlockUpStandard,
                         use_attention=True)