import torch
import torch.nn as nn
import torch.nn.functional as F  # kept for the commented-out softmax calls below
from torchvision import transforms
from torchvision.models import resnet

from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import round_filters, get_same_padding_conv2d
from efficientnet_pytorch.model import MBConvBlock
from pytorchcv.model_provider import get_model

class Head(nn.Module):
    """Classification head: flatten -> BN -> dropout -> linear -> ReLU -> BN -> dropout -> linear."""

    def __init__(self, in_f, out_f):
        super(Head, self).__init__()
        self.f = nn.Flatten()
        self.l = nn.Linear(in_f, 512)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(512, out_f)
        self.b1 = nn.BatchNorm1d(in_f)
        self.b2 = nn.BatchNorm1d(512)
        self.r = nn.ReLU()

    def forward(self, x):
        x = self.f(x)
        x = self.b1(x)
        x = self.d(x)
        x = self.l(x)
        x = self.r(x)
        x = self.b2(x)
        x = self.d(x)
        out = self.o(x)
        return out

class FCN(nn.Module):
    """Backbone + Head wrapper: runs `base`, then classifies its output with a Head."""

    def __init__(self, base, in_f, out_f):
        super(FCN, self).__init__()
        self.base = base
        self.h1 = Head(in_f, out_f)

    def forward(self, x):
        x = self.base(x)
        return self.h1(x)

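# Minimal usage sketch (not part of the original module): wiring an arbitrary
# backbone into FCN via Head. The 2048-feature backbone below is a made-up
# stand-in; any module whose flattened output matches `in_f` works.
def _example_fcn_usage():
    backbone = nn.Sequential(nn.Conv2d(3, 2048, kernel_size=7), nn.AdaptiveAvgPool2d(1))
    model = FCN(backbone, in_f=2048, out_f=2)
    logits = model(torch.randn(8, 3, 64, 64))  # -> shape (8, 2)
    return logits
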
class BaseFCN(nn.Module):
    """Small MLP over a flattened 625-dim input (e.g. a 25x25 co-occurrence matrix)."""

    def __init__(self, n_classes: int):
        super(BaseFCN, self).__init__()
        self.f = nn.Flatten()
        self.l = nn.Linear(625, 256)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(256, n_classes)

    def forward(self, x):
        x = self.f(x)
        x = self.l(x)
        x = self.d(x)
        out = self.o(x)
        return out

    def get_trainable_parameters_cooccur(self):
        return self.parameters()

class BaseFCNHigh(nn.Module):
    def __init__(self, n_classes: int):
        super(BaseFCNHigh, self).__init__()
        self.f = nn.Flatten()
        self.l = nn.Linear(625, 512)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(512, n_classes)

    def forward(self, x):
        x = self.f(x)
        x = self.l(x)
        x = self.d(x)
        out = self.o(x)
        return out

    def get_trainable_parameters_cooccur(self):
        return self.parameters()

class BaseFCN4(nn.Module):
    def __init__(self, n_classes: int):
        super(BaseFCN4, self).__init__()
        self.f = nn.Flatten()
        self.l1 = nn.Linear(625, 512)
        self.l2 = nn.Linear(512, 384)
        self.l3 = nn.Linear(384, 256)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(256, n_classes)

    def forward(self, x):
        x = self.f(x)
        x = self.l1(x)
        x = self.d(x)
        x = self.l2(x)
        x = self.d(x)
        x = self.l3(x)
        x = self.d(x)
        out = self.o(x)
        return out

    def get_trainable_parameters_cooccur(self):
        return self.parameters()

class BaseFCNBnR(nn.Module):
    def __init__(self, n_classes: int):
        super(BaseFCNBnR, self).__init__()
        self.f = nn.Flatten()
        self.b1 = nn.BatchNorm1d(625)
        self.b2 = nn.BatchNorm1d(256)
        self.l = nn.Linear(625, 256)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(256, n_classes)
        self.r = nn.ReLU()

    def forward(self, x):
        x = self.f(x)
        x = self.b1(x)
        x = self.d(x)
        x = self.l(x)
        x = self.r(x)
        x = self.b2(x)
        x = self.d(x)
        out = self.o(x)
        return out

    def get_trainable_parameters_cooccur(self):
        return self.parameters()

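# Minimal usage sketch for the BaseFCN family (assumption: the 625-dim input
# is a flattened 25x25 co-occurrence map; batch size and class count here are
# arbitrary).
def _example_basefcn_usage():
    model = BaseFCNBnR(n_classes=2)
    cooccur = torch.randn(16, 25, 25)  # hypothetical co-occurrence maps
    logits = model(cooccur)            # flattened internally to (16, 625)
    params = model.get_trainable_parameters_cooccur()
    return logits, params
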
def forward_resnet_conv(net, x, upto: int = 4):
    """
    Forward pass through the convolutional part of a ResNet only.
    :param net: torchvision-style ResNet model
    :param x: input tensor
    :param upto: index of the last residual stage to run (1 to 4)
    :return: feature maps from the requested stage
    """
    x = net.conv1(x)  # N / 2
    x = net.bn1(x)
    x = net.relu(x)
    x = net.maxpool(x)  # N / 4
    if upto >= 1:
        x = net.layer1(x)  # N / 4
    if upto >= 2:
        x = net.layer2(x)  # N / 8
    if upto >= 3:
        x = net.layer3(x)  # N / 16
    if upto >= 4:
        x = net.layer4(x)  # N / 32
    return x

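# Minimal sketch of forward_resnet_conv: extracting intermediate ResNet-50
# feature maps. Spatial sizes in the comments assume a 224x224 input.
def _example_forward_resnet_conv():
    net = resnet.resnet50(pretrained=False)
    x = torch.randn(2, 3, 224, 224)
    feats_l2 = forward_resnet_conv(net, x, upto=2)  # (2, 512, 28, 28)
    feats_l4 = forward_resnet_conv(net, x, upto=4)  # (2, 2048, 7, 7)
    return feats_l2, feats_l4
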
class FeatureExtractor(nn.Module):
    """
    Abstract class to be extended when supporting feature extraction.
    It also provides the standard normalizer and the trainable parameters.
    """

    def features(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def get_trainable_parameters(self):
        return self.parameters()

    @staticmethod
    def get_normalizer():
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

class FeatureExtractorGray(nn.Module):
    """
    Grayscale counterpart of FeatureExtractor, with a single-channel normalizer.
    """

    def features(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def get_trainable_parameters(self):
        return self.parameters()

    @staticmethod
    def get_normalizer():
        return transforms.Normalize(mean=[0.479], std=[0.226])

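# Minimal sketch of the normalizers: both return torchvision transforms meant
# to be composed into a preprocessing pipeline (the RGB values are the
# standard ImageNet statistics; the grayscale ones are single-channel
# counterparts).
def _example_normalizer_usage():
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        FeatureExtractor.get_normalizer(),
    ])
    return preprocess
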
class EfficientNetGen(FeatureExtractor):
    def __init__(self, model: str, n_classes: int, pretrained: bool):
        super(EfficientNetGen, self).__init__()
        if pretrained:
            self.efficientnet = EfficientNet.from_pretrained(model)
        else:
            self.efficientnet = EfficientNet.from_name(model)
        self.classifier = nn.Linear(self.efficientnet._conv_head.out_channels, n_classes)
        del self.efficientnet._fc  # replaced by self.classifier

    def features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.efficientnet.extract_features(x)
        x = self.efficientnet._avg_pooling(x)
        x = x.flatten(start_dim=1)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.efficientnet._dropout(x)
        x = self.classifier(x)
        # x = F.softmax(x, dim=-1)
        return x

class EfficientNetB0(EfficientNetGen):
    def __init__(self, n_classes: int, pretrained: bool):
        super(EfficientNetB0, self).__init__(model='efficientnet-b0', n_classes=n_classes, pretrained=pretrained)


class EfficientNetB4(EfficientNetGen):
    def __init__(self, n_classes: int, pretrained: bool):
        super(EfficientNetB4, self).__init__(model='efficientnet-b4', n_classes=n_classes, pretrained=pretrained)

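# Minimal usage sketch for the EfficientNet wrappers. `pretrained=False`
# avoids a weight download; the 224x224 resolution here is arbitrary, since
# the adaptive average pooling accepts variable input sizes.
def _example_efficientnet_usage():
    model = EfficientNetB0(n_classes=2, pretrained=False)
    x = torch.randn(4, 3, 224, 224)
    logits = model(x)          # (4, 2)
    feats = model.features(x)  # (4, 1280) for the B0 backbone
    return logits, feats
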
class EfficientNetGenPostStem(FeatureExtractor):
    def __init__(self, model: str, n_classes: int, pretrained: bool, n_ir_blocks: int):
        super(EfficientNetGenPostStem, self).__init__()
        if pretrained:
            self.efficientnet = EfficientNet.from_pretrained(model)
        else:
            self.efficientnet = EfficientNet.from_name(model)
        self.n_ir_blocks = n_ir_blocks
        self.classifier = nn.Linear(self.efficientnet._conv_head.out_channels, n_classes)

        # modify STEM: stride 2 -> 1, so the extra IR blocks below see full resolution
        in_channels = 3  # rgb
        out_channels = round_filters(32, self.efficientnet._global_params)
        Conv2d = get_same_padding_conv2d(image_size=self.efficientnet._global_params.image_size)
        self.efficientnet._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, bias=False)

        # extra inverted-residual blocks inserted after the stem
        self.init_blocks_args = self.efficientnet._blocks_args[0]
        self.init_blocks_args = self.init_blocks_args._replace(output_filters=32)
        self.init_block = MBConvBlock(self.init_blocks_args, self.efficientnet._global_params)
        # the final extra block restores the stride-2 downsampling removed from the stem
        self.last_block_args = self.efficientnet._blocks_args[0]
        self.last_block_args = self.last_block_args._replace(output_filters=32, stride=2)
        self.last_block = MBConvBlock(self.last_block_args, self.efficientnet._global_params)

        del self.efficientnet._fc  # replaced by self.classifier

    def features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.efficientnet._swish(self.efficientnet._bn0(self.efficientnet._conv_stem(x)))
        # init blocks
        for b in range(self.n_ir_blocks - 1):
            x = self.init_block(x, drop_connect_rate=0)
        # last block
        x = self.last_block(x, drop_connect_rate=0)
        # standard EfficientNet blocks
        for idx, block in enumerate(self.efficientnet._blocks):
            drop_connect_rate = self.efficientnet._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self.efficientnet._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)
        x = self.efficientnet._swish(self.efficientnet._bn1(self.efficientnet._conv_head(x)))
        x = self.efficientnet._avg_pooling(x)
        x = x.flatten(start_dim=1)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.efficientnet._dropout(x)
        x = self.classifier(x)
        # x = F.softmax(x, dim=-1)
        return x

class EfficientNetB0PostStemIR(EfficientNetGenPostStem):
    def __init__(self, n_classes: int, pretrained: bool, n_ir_blocks: int):
        super(EfficientNetB0PostStemIR, self).__init__(model='efficientnet-b0', n_classes=n_classes,
                                                       pretrained=pretrained, n_ir_blocks=n_ir_blocks)

class EfficientNetGenPreStem(FeatureExtractor):
    def __init__(self, model: str, n_classes: int, pretrained: bool, n_ir_blocks: int):
        super(EfficientNetGenPreStem, self).__init__()
        if pretrained:
            self.efficientnet = EfficientNet.from_pretrained(model)
        else:
            self.efficientnet = EfficientNet.from_name(model)
        self.n_ir_blocks = n_ir_blocks
        self.classifier = nn.Linear(self.efficientnet._conv_head.out_channels, n_classes)

        # extra inverted-residual blocks inserted before the stem:
        # the first maps the 3 RGB channels to 32, the others keep 32 channels
        self.init_block_args = self.efficientnet._blocks_args[0]
        self.init_block_args = self.init_block_args._replace(input_filters=3, output_filters=32)
        self.init_block = MBConvBlock(self.init_block_args, self.efficientnet._global_params)
        self.last_blocks_args = self.efficientnet._blocks_args[0]
        self.last_blocks_args = self.last_blocks_args._replace(output_filters=32)
        self.last_block = MBConvBlock(self.last_blocks_args, self.efficientnet._global_params)

        # modify STEM: its input is now the 32-channel output of the extra blocks
        in_channels = 32
        out_channels = round_filters(32, self.efficientnet._global_params)
        Conv2d = get_same_padding_conv2d(image_size=self.efficientnet._global_params.image_size)
        self.efficientnet._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)

        del self.efficientnet._fc  # replaced by self.classifier

    def features(self, x: torch.Tensor) -> torch.Tensor:
        # init block
        x = self.init_block(x, drop_connect_rate=0)
        # other blocks
        for b in range(self.n_ir_blocks - 1):
            x = self.last_block(x, drop_connect_rate=0)
        # standard EfficientNet stem
        x = self.efficientnet._swish(self.efficientnet._bn0(self.efficientnet._conv_stem(x)))
        # standard EfficientNet blocks
        for idx, block in enumerate(self.efficientnet._blocks):
            drop_connect_rate = self.efficientnet._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self.efficientnet._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)
        x = self.efficientnet._swish(self.efficientnet._bn1(self.efficientnet._conv_head(x)))
        x = self.efficientnet._avg_pooling(x)
        x = x.flatten(start_dim=1)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.efficientnet._dropout(x)
        x = self.classifier(x)
        # x = F.softmax(x, dim=-1)
        return x

class EfficientNetB0PreStemIR(EfficientNetGenPreStem):
    def __init__(self, n_classes: int, pretrained: bool, n_ir_blocks: int):
        super(EfficientNetB0PreStemIR, self).__init__(model='efficientnet-b0', n_classes=n_classes,
                                                      pretrained=pretrained, n_ir_blocks=n_ir_blocks)

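# Minimal usage sketch for the modified-stem variants. Both prepend
# `n_ir_blocks` inverted-residual (MBConv) blocks, either after a stride-1
# stem (PostStem) or before the standard stride-2 stem (PreStem); the block
# count and input size here are arbitrary.
def _example_modified_stem_usage():
    post = EfficientNetB0PostStemIR(n_classes=2, pretrained=False, n_ir_blocks=2)
    pre = EfficientNetB0PreStemIR(n_classes=2, pretrained=False, n_ir_blocks=2)
    x = torch.randn(2, 3, 224, 224)
    return post(x), pre(x)
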
class ResNet50(FeatureExtractor):
    def __init__(self, n_classes: int, pretrained: bool):
        super(ResNet50, self).__init__()
        self.resnet = resnet.resnet50(pretrained=pretrained)
        self.fc = nn.Linear(in_features=self.resnet.fc.in_features, out_features=n_classes)
        del self.resnet.fc  # replaced by self.fc

    def features(self, x):
        x = forward_resnet_conv(self.resnet, x)
        x = self.resnet.avgpool(x).flatten(start_dim=1)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.fc(x)
        return x

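# Minimal usage sketch for the ResNet50 wrapper; pretrained weights are
# skipped to keep the example offline.
def _example_resnet50_usage():
    model = ResNet50(n_classes=2, pretrained=False)
    x = torch.randn(2, 3, 224, 224)
    return model(x)  # (2, 2)
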
""" | |
Xception from Kaggle | |
""" | |
class XceptionWeiHao(FeatureExtractor): | |
def __init__(self, n_classes: int, pretrained: bool): | |
super(XceptionWeiHao, self).__init__() | |
self.model = get_model("xception", pretrained=pretrained) | |
self.model = nn.Sequential(*list(self.model.children())[:-1]) # Remove original output layer | |
self.model[0].final_block.pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1))) | |
self.model = FCN(self.model, 2048, n_classes) | |
def features(self, x: torch.Tensor) -> torch.Tensor: | |
return self.model.base(x) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
x = self.features(x) | |
return self.model.h1(x) | |
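# Minimal usage sketch for the Xception wrapper (assumption: pytorchcv's
# "xception" is trained on 299x299 inputs; the adaptive pooling substituted
# above should make other resolutions work as well).
def _example_xception_usage():
    model = XceptionWeiHao(n_classes=2, pretrained=False)
    x = torch.randn(2, 3, 299, 299)
    feats = model.features(x)  # (2, 2048, 1, 1) before the FCN head flattens it
    logits = model(x)          # (2, 2)
    return feats, logits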