import torch import torch.nn as nn from torch.hub import load import torchvision.models as models dino_backbones = { 'dinov2_s':{ 'name':'dinov2_vits14', 'embedding_size':384, 'patch_size':14 }, 'dinov2_b':{ 'name':'dinov2_vitb14', 'embedding_size':768, 'patch_size':14 }, 'dinov2_l':{ 'name':'dinov2_vitl14', 'embedding_size':1024, 'patch_size':14 }, 'dinov2_g':{ 'name':'dinov2_vitg14', 'embedding_size':1536, 'patch_size':14 }, } class linear_head(nn.Module): def __init__(self, embedding_size = 384, num_classes = 5): super(linear_head, self).__init__() self.fc = nn.Linear(embedding_size, num_classes) def forward(self, x): return self.fc(x) class conv_head(nn.Module): def __init__(self, embedding_size = 384, num_classes = 5): super(conv_head, self).__init__() self.segmentation_conv = nn.Sequential( nn.Upsample(scale_factor=2), nn.Conv2d(embedding_size, 64, (3,3), padding=(1,1)), nn.Upsample(scale_factor=2), nn.Conv2d(64, num_classes, (3,3), padding=(1,1)), ) def forward(self, x): x = self.segmentation_conv(x) x = torch.sigmoid(x) return x def threshold_mask(predicted, threshold=0.55): thresholded_mask = (predicted > threshold).float() return thresholded_mask class Segmentor(nn.Module): def __init__(self, device,num_classes, backbone = 'dinov2_s', head = 'conv', backbones = dino_backbones): super(Segmentor, self).__init__() self.heads = { 'conv':conv_head } self.backbones = dino_backbones self.backbone = load('facebookresearch/dinov2', self.backbones[backbone]['name']) self.backbone.eval() self.num_classes = num_classes self.embedding_size = self.backbones[backbone]['embedding_size'] self.patch_size = self.backbones[backbone]['patch_size'] self.head = self.heads[head](self.embedding_size,self.num_classes) self.device=device def forward(self, x): batch_size = x.shape[0] mask_dim = (x.shape[2] / self.patch_size, x.shape[3] / self.patch_size) x = self.backbone.forward_features(x.to(self.device)) x = x['x_norm_patchtokens'] x = x.permute(0,2,1) x = x.reshape(batch_size,self.embedding_size,int(mask_dim[0]),int(mask_dim[1])) x = self.head(x) return x