import torch import torch.nn as nn from torchvision import transforms class CNNBlock(nn.Module): """Base block in CNN""" def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bn_act=True): super().__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=not bn_act) self.bn = nn.BatchNorm2d(out_channels) self.silu = nn.SiLU() self.use_bn_act = bn_act def forward(self, x): if self.use_bn_act: x = self.silu(self.bn(self.conv(x))) return x else: return self.conv(x) class BottleNeckBlock(nn.Module): def __init__(self, channels, short_cut=True): super().__init__() self.short_cut = short_cut self.Conv = nn.Sequential(CNNBlock(channels, channels//2, 3, 1, 1), CNNBlock(channels//2, channels, 3, 1, 1)) def forward(self, x): if self.short_cut: return self.Conv(x) + x else: return self.Conv(x) class C2FBlock(nn.Module): def __init__(self, in_channels, out_channels, **kwargs): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.Conv = CNNBlock(in_channels, out_channels, kernel_size=1, stride=1, padding=0) self.Conv_end = CNNBlock(int(0.5*(1+2)*out_channels), out_channels, kernel_size=1, stride=1, padding=0) self.BottleNeck = BottleNeckBlock(out_channels//2, **kwargs) def forward(self, x): x = self.Conv(x) x, x1 = torch.split(x, self.out_channels//2, dim=1) x2 = self.BottleNeck(x1) x = torch.cat([x, x1, x2], dim=1) x = self.Conv_end(x) return x class C2F_2_Block(nn.Module): def __init__(self, in_channels, out_channels, **kwargs): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.Conv = CNNBlock(in_channels, out_channels, kernel_size=1, stride=1, padding=0) self.Conv_end = CNNBlock(int(0.5*(2+2)*out_channels), out_channels, kernel_size=1, stride=1, padding=0) self.BottleNeck = BottleNeckBlock(out_channels//2, **kwargs) def forward(self, x): x = self.Conv(x) x, x1 = torch.split(x, self.out_channels//2, dim=1) x2 = self.BottleNeck(x1) x3 = self.BottleNeck(x2) x = torch.cat([x, x1, x2, x3], dim=1) x = self.Conv_end(x) return x class SPPFBlock(nn.Module): def __init__(self, channels): super().__init__() self.Conv = CNNBlock(channels, channels, kernel_size=1, stride=1, padding=0) self.Conv_end = CNNBlock(4*channels, channels, kernel_size=1, stride=1, padding=0) self.MaxPool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) def forward(self, x): x = self.Conv(x) x = torch.cat([x, self.MaxPool(x), self.MaxPool(self.MaxPool(x)), self.MaxPool(self.MaxPool(self.MaxPool(x)))], dim=1) x = self.Conv_end(x) return x class Classifier(nn.Module): def __init__(self, num_classes=500): super().__init__() self.Conv = nn.Sequential(CNNBlock(512, 1280, kernel_size=1, stride=1, padding=0)) self.Flatten = nn.Flatten() self.Linear = nn.Sequential(nn.Linear(62720, num_classes)) def forward(self, x): x = self.Conv(x) x = self.Flatten(x) x = self.Linear(x) return x class Yolov8_cls(nn.Module): """Model architecture based page: https://blog.roboflow.com/whats-new-in-yolov8/ and the ONNX file of yolov8_cls.onnx""" def __init__(self, in_channels, num_classes=500): super().__init__() self.Block1 = nn.Sequential(CNNBlock(in_channels, 32, 3, 2, 1), CNNBlock(32, 64, 3, 2, 1)) self.Block2 = C2FBlock(64, 64) self.Block3 = nn.Sequential(CNNBlock(64, 128, 3, 2, 1), C2F_2_Block(128, 128)) self.Block4 = nn.Sequential(CNNBlock(128, 256, 3, 2, 1), C2F_2_Block(256, 256)) self.Block5 = nn.Sequential(CNNBlock(256, 512, 3, 2, 1), C2F_2_Block(512, 512)) self.Block6 = Classifier(num_classes) def forward(self, x): x = self.Block1(x) x = self.Block2(x) x = self.Block3(x) x = self.Block4(x) x = self.Block5(x) x = self.Block6(x) return x def Load_model(): """Load model and transforms. Returns: model (torch.nn.Module): EffNetB2 feature extractor model. transforms (torchvision.transforms): EffNetB2 image transforms. """ IMAGE_SIZE= 224 model = Yolov8_cls(3) transform = transforms.Compose([transforms.Resize(IMAGE_SIZE), transforms.CenterCrop(IMAGE_SIZE), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]) return model, transform