vvd2003's picture
Upload 15 files
6fa8c33
import torch
import torch.nn as nn
from torchvision import transforms
class CNNBlock(nn.Module):
    """2-D convolution, optionally followed by batch norm and SiLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size, passed to ``nn.Conv2d``.
        stride: Convolution stride, passed to ``nn.Conv2d``.
        padding: Convolution padding, passed to ``nn.Conv2d``.
        bn_act: When True (default) ``forward`` computes
            ``SiLU(BatchNorm2d(conv(x)))`` and the convolution is built
            without a bias term. When False ``forward`` returns the raw
            convolution output and the convolution keeps its bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bn_act=True):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=not bn_act,
        )
        # The norm and activation layers are registered unconditionally, so
        # they appear in the module's state_dict even when bn_act is False.
        self.bn = nn.BatchNorm2d(out_channels)
        self.silu = nn.SiLU()
        self.use_bn_act = bn_act

    def forward(self, x):
        """Apply the convolution and, if enabled, batch norm plus SiLU."""
        features = self.conv(x)
        if not self.use_bn_act:
            return features
        return self.silu(self.bn(features))
class BottleNeckBlock(nn.Module):
    """Two stacked 3x3 ``CNNBlock`` layers with an optional residual add.

    The first layer reduces ``channels`` to ``channels // 2`` and the second
    restores it to ``channels``; both use stride 1 and padding 1.

    Args:
        channels: Channel count of both the input and the output.
        short_cut: When True (default) the input is added to the output of
            the two convolutions; when False only the convolution output is
            returned.
    """

    def __init__(self, channels, short_cut=True):
        super().__init__()
        self.short_cut = short_cut
        hidden = channels // 2
        self.Conv = nn.Sequential(
            CNNBlock(channels, hidden, 3, 1, 1),
            CNNBlock(hidden, channels, 3, 1, 1),
        )

    def forward(self, x):
        """Run the two convolutions, adding the skip connection if enabled."""
        out = self.Conv(x)
        return out + x if self.short_cut else out
class C2FBlock(nn.Module):
    """Split-and-concatenate block with a single bottleneck pass.

    A 1x1 ``CNNBlock`` maps the input to ``out_channels``. The result is
    split along the channel axis into two pieces of ``out_channels // 2``
    channels, the second piece goes through a ``BottleNeckBlock``, and the
    three tensors (both pieces plus the bottleneck output) are concatenated
    and fused back to ``out_channels`` by a final 1x1 ``CNNBlock``.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count of the output tensor. The split in
            ``forward`` unpacks exactly two chunks, so this is expected to
            be even.
        **kwargs: Forwarded to ``BottleNeckBlock`` (e.g. ``short_cut``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.Conv = CNNBlock(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        # Three half-width tensors are concatenated: 1.5 * out_channels.
        self.Conv_end = CNNBlock(
            int(1.5 * out_channels), out_channels, kernel_size=1, stride=1, padding=0
        )
        self.BottleNeck = BottleNeckBlock(out_channels // 2, **kwargs)

    def forward(self, x):
        """Project, split, run the bottleneck on one half, concat and fuse."""
        half = self.out_channels // 2
        kept, routed = torch.split(self.Conv(x), half, dim=1)
        stacked = torch.cat([kept, routed, self.BottleNeck(routed)], dim=1)
        return self.Conv_end(stacked)
class C2F_2_Block(nn.Module):
    """Split-and-concatenate block with two chained bottleneck passes.

    Same layout as ``C2FBlock`` except that the bottleneck is applied twice
    in sequence and four half-width tensors are concatenated before the
    final 1x1 ``CNNBlock``.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count of the output tensor. The split in
            ``forward`` unpacks exactly two chunks, so this is expected to
            be even.
        **kwargs: Forwarded to ``BottleNeckBlock`` (e.g. ``short_cut``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.Conv = CNNBlock(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        # Four half-width tensors are concatenated: 2 * out_channels.
        self.Conv_end = CNNBlock(
            int(2.0 * out_channels), out_channels, kernel_size=1, stride=1, padding=0
        )
        self.BottleNeck = BottleNeckBlock(out_channels // 2, **kwargs)

    def forward(self, x):
        """Project, split, run the bottleneck twice, concat and fuse."""
        half = self.out_channels // 2
        kept, routed = torch.split(self.Conv(x), half, dim=1)
        # NOTE(review): one BottleNeck instance is used for both passes, so
        # the two passes share weights. Kept as-is; switching to separate
        # modules would change the state_dict layout.
        first_pass = self.BottleNeck(routed)
        second_pass = self.BottleNeck(first_pass)
        stacked = torch.cat([kept, routed, first_pass, second_pass], dim=1)
        return self.Conv_end(stacked)
class SPPFBlock(nn.Module):
    """Pooling-pyramid block built from repeated 3x3 max pooling.

    A 1x1 ``CNNBlock`` is applied first. Its output is then max-pooled
    (kernel 3, stride 1, padding 1, so the spatial size is unchanged) one,
    two and three times in a chain. The un-pooled tensor and the three
    pooled tensors are concatenated along the channel axis and fused back
    to ``channels`` by a final 1x1 ``CNNBlock``.

    Args:
        channels: Channel count of both the input and the output.
    """

    def __init__(self, channels):
        super().__init__()
        self.Conv = CNNBlock(channels, channels, kernel_size=1, stride=1, padding=0)
        self.Conv_end = CNNBlock(4 * channels, channels, kernel_size=1, stride=1, padding=0)
        self.MaxPool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return the fused concatenation of x and its 1x/2x/3x pooled forms."""
        x = self.Conv(x)
        # Each pooling level is computed from the previous one, so the pool
        # runs three times in total instead of recomputing the chain for
        # every level; the resulting tensors are the same.
        pooled_once = self.MaxPool(x)
        pooled_twice = self.MaxPool(pooled_once)
        pooled_thrice = self.MaxPool(pooled_twice)
        stacked = torch.cat([x, pooled_once, pooled_twice, pooled_thrice], dim=1)
        return self.Conv_end(stacked)
class Classifier(nn.Module):
    """Classification head: 1x1 conv (512 -> 1280), flatten, linear layer.

    The linear layer takes 62720 input features, which equals 1280 channels
    times a 7x7 spatial map, so the head only accepts inputs whose flattened
    size after the convolution is exactly that.

    Args:
        num_classes: Number of output logits (default 500).
    """

    def __init__(self, num_classes=500):
        super().__init__()
        self.Conv = nn.Sequential(
            CNNBlock(512, 1280, kernel_size=1, stride=1, padding=0),
        )
        self.Flatten = nn.Flatten()
        flat_features = 1280 * 7 * 7  # == 62720
        self.Linear = nn.Sequential(
            nn.Linear(flat_features, num_classes),
        )

    def forward(self, x):
        """Map a 512-channel feature map to ``num_classes`` logits."""
        return self.Linear(self.Flatten(self.Conv(x)))
class Yolov8_cls(nn.Module):
    """YOLOv8-style image classifier.

    The architecture is based on the article
    https://blog.roboflow.com/whats-new-in-yolov8/ and on the ONNX file
    yolov8_cls.onnx.

    The backbone contains five stride-2 ``CNNBlock`` layers (two in Block1
    and one at the start of each of Block3, Block4 and Block5), widening the
    channels 32 -> 64 -> 128 -> 256 -> 512, followed by the ``Classifier``
    head.

    Args:
        in_channels: Channel count of the input image tensor.
        num_classes: Number of output logits (default 500).
    """

    def __init__(self, in_channels, num_classes=500):
        super().__init__()
        self.Block1 = nn.Sequential(
            CNNBlock(in_channels, 32, 3, 2, 1),
            CNNBlock(32, 64, 3, 2, 1),
        )
        self.Block2 = C2FBlock(64, 64)
        self.Block3 = nn.Sequential(
            CNNBlock(64, 128, 3, 2, 1),
            C2F_2_Block(128, 128),
        )
        self.Block4 = nn.Sequential(
            CNNBlock(128, 256, 3, 2, 1),
            C2F_2_Block(256, 256),
        )
        self.Block5 = nn.Sequential(
            CNNBlock(256, 512, 3, 2, 1),
            C2F_2_Block(512, 512),
        )
        self.Block6 = Classifier(num_classes)

    def forward(self, x):
        """Pass ``x`` through Block1..Block6 in order and return the logits."""
        stages = (
            self.Block1,
            self.Block2,
            self.Block3,
            self.Block4,
            self.Block5,
            self.Block6,
        )
        for stage in stages:
            x = stage(x)
        return x
def Load_model():
    """Build the classifier and its matching image preprocessing pipeline.

    The model is a newly constructed ``Yolov8_cls`` for 3-channel input with
    the default number of classes; no checkpoint is read by this function.

    Returns:
        tuple: ``(model, transform)`` where ``model`` is the ``Yolov8_cls``
        instance and ``transform`` is a ``transforms.Compose`` that resizes
        to 224, center-crops to 224, converts to a tensor and normalizes
        each channel with the mean/std constants below.
    """
    image_size = 224
    preprocessing_steps = [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    transform = transforms.Compose(preprocessing_steps)
    model = Yolov8_cls(3)
    return model, transform