import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from utils.utils_word_embedding import initialize_wordembedding_matrix


class Backbone(nn.Module):
    """ImageNet-pretrained torchvision ResNet split into five sequential stages.

    ``block0`` is the stem (conv1 -> bn1 -> relu -> maxpool); ``block1`` ..
    ``block4`` are the ResNet's ``layer1`` .. ``layer4``.
    """

    def __init__(self, backbone='resnet18'):
        """
        Args:
            backbone: one of 'resnet18', 'resnet50', 'resnet101'.

        Raises:
            ValueError: if ``backbone`` is not a supported name.
        """
        super(Backbone, self).__init__()
        models = torchvision.models
        # Each architecture must be paired with its own weights enum:
        # torchvision rejects (TypeError) e.g. ResNet18 weights given to resnet50().
        if backbone == 'resnet18':
            resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        elif backbone == 'resnet50':
            resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        elif backbone == 'resnet101':
            resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
        else:
            raise ValueError(
                f"Unsupported backbone {backbone!r}; "
                "expected 'resnet18', 'resnet50' or 'resnet101'."
            )

        self.block0 = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
        )
        self.block1 = resnet.layer1
        self.block2 = resnet.layer2
        self.block3 = resnet.layer3
        self.block4 = resnet.layer4

    def forward(self, x, returned=(4,)):
        """Run all five stages and return the feature maps selected by ``returned``.

        Args:
            x: input image batch.
            returned: sequence of stage indices (0 = stem, 4 = layer4). A
                tuple default avoids a shared mutable default; lists still work.

        Returns:
            list with one feature map per index in ``returned``, in that order.
        """
        feats = [self.block0(x)]
        for stage in (self.block1, self.block2, self.block3, self.block4):
            feats.append(stage(feats[-1]))
        return [feats[i] for i in returned]


class CosineClassifier(nn.Module):
    """Cosine-similarity logits between L2-normalized features, divided by ``temp``."""

    def __init__(self, temp=0.05):
        super(CosineClassifier, self).__init__()
        self.temp = temp

    def forward(self, img, concept, scale=True):
        """
        img: (bs, emb_dim)
        concept: (n_class, emb_dim)

        Returns (bs, n_class) cosine similarities, divided by ``self.temp``
        when ``scale`` is true.
        """
        img_norm = F.normalize(img, dim=-1)
        concept_norm = F.normalize(concept, dim=-1)
        pred = torch.matmul(img_norm, concept_norm.transpose(0, 1))
        if scale:
            pred = pred / self.temp
        return pred


class Embedder(nn.Module):
    """
    Text and Visual Embedding Model.

    Images are ImageNet-normalized, encoded by a ResNet ``Backbone`` (layer4
    output), projected by a 1x1 conv/BN/ReLU (+ optional Dropout2d), globally
    average-pooled and mapped by a linear layer to ``out_dim``. Every class name
    in ``type_name`` gets a word embedding from ``initialize_wordembedding_matrix``
    which a Linear+ReLU MLP maps to ``out_dim``. Images are classified against
    the class embeddings with a ``CosineClassifier``.

    Note: ``feat_dim`` must equal the channel count of the backbone's layer4
    output (512 for resnet18, 2048 for resnet50/resnet101).
    """

    def __init__(self, type_name, feat_dim=512, mid_dim=1024, out_dim=324,
                 drop_rate=0.35, cosine_cls_temp=0.05, wordembs='glove',
                 extractor_name='resnet18'):
        super(Embedder, self).__init__()

        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        self.type_name = type_name
        self.feat_dim = feat_dim
        self.mid_dim = mid_dim
        self.out_dim = out_dim
        self.drop_rate = drop_rate
        self.cosine_cls_temp = cosine_cls_temp
        self.wordembs = wordembs
        self.extractor_name = extractor_name
        self.transform = transforms.Normalize(mean, std)

        # Order matters: it fixes the registration (and hence parameters()) order.
        self._setup_word_embedding()
        self._setup_image_embedding()

    def _setup_image_embedding(self):
        """Build the visual branch and the cosine classifier."""
        self.feat_extractor = Backbone(self.extractor_name)
        img_emb_modules = [
            nn.Conv2d(self.feat_dim, self.mid_dim, kernel_size=1, bias=False),
            nn.BatchNorm2d(self.mid_dim),
            nn.ReLU(),
        ]
        if self.drop_rate > 0:
            img_emb_modules += [nn.Dropout2d(self.drop_rate)]
        self.img_embedder = nn.Sequential(*img_emb_modules)
        self.img_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.img_final = nn.Linear(self.mid_dim, self.out_dim)
        self.classifier = CosineClassifier(temp=self.cosine_cls_temp)

    def _setup_word_embedding(self):
        """Build the class-name vocabulary, its word embeddings and the text MLP."""
        self.type2idx = {name: i for i, name in enumerate(self.type_name)}
        self.num_type = len(self.type_name)
        train_type = [self.type2idx[name] for name in self.type_name]
        # Non-persistent buffer: moves with the module on .to()/.cuda() but is
        # kept out of state_dict, so existing checkpoints still load unchanged.
        self.register_buffer(
            'train_type', torch.tensor(train_type, dtype=torch.long), persistent=False
        )

        wordemb, self.word_dim = \
            initialize_wordembedding_matrix(self.wordembs, self.type_name)
        self.embedder = nn.Embedding(self.num_type, self.word_dim)
        self.embedder.weight.data.copy_(wordemb)

        self.mlp = nn.Sequential(
            nn.Linear(self.word_dim, self.out_dim),
            nn.ReLU(True),
        )

    def _class_weights(self):
        """Return the (num_type, out_dim) embedding of every class name."""
        return self.mlp(self.embedder(self.train_type))

    def _embed_image(self, img):
        """Map an already-normalized image batch to (bs, out_dim) embeddings."""
        feat = self.feat_extractor(img)[0]
        feat = self.img_embedder(feat)
        feat = self.img_avg_pool(feat).flatten(1)  # (bs, mid_dim, 1, 1) -> (bs, mid_dim)
        return self.img_final(feat)

    def train_forward(self, batch):
        """Classification loss and accuracy for a ``(labels, images)`` batch.

        Returns:
            dict with 'loss_total' (cross-entropy) and 'acc_type' (fraction of
            correctly predicted labels).
        """
        scene, img = batch[0], self.transform(batch[1])
        bs = img.shape[0]

        scene_weight = self._class_weights()
        img = self._embed_image(img)

        pred = self.classifier(img, scene_weight)
        label_loss = F.cross_entropy(pred, scene)

        type_pred = self.train_type[pred.argmax(dim=1)]
        correct_type = (type_pred == scene)
        out = {
            'loss_total': label_loss,
            'acc_type': torch.div(correct_type.sum(), float(bs)),
        }
        return out

    def image_encoder_forward(self, batch):
        """Classify images and return the embedding of each predicted class.

        Returns:
            (out_embedding, num_type, text_type): the (bs, out_dim) class
            embeddings, the predicted class indices (long tensor) and the
            corresponding class names.
        """
        img = self.transform(batch)

        scene_weight = self._class_weights()
        img = self._embed_image(img)

        pred = self.classifier(img, scene_weight).argmax(dim=1)
        out_embedding = scene_weight[pred]  # row gather, one per sample
        num_type = self.train_type[pred]
        text_type = [self.type_name[i] for i in num_type.tolist()]
        return out_embedding, num_type, text_type

    def text_encoder_forward(self, text):
        """Look up class embeddings for a list of class names.

        Raises:
            KeyError: if a name is not in ``type_name``.

        Returns:
            (out_embedding, num_type, text): (bs, out_dim) embeddings, the class
            indices as a floating-point tensor (dtype kept for backward
            compatibility) and the input names unchanged.
        """
        scene_weight = self._class_weights()
        idx = torch.tensor(
            [self.type2idx[name] for name in text],
            dtype=torch.long,
            device=scene_weight.device,
        )
        out_embedding = scene_weight[idx]
        num_type = idx.to(torch.get_default_dtype())
        text_type = text
        return out_embedding, num_type, text_type

    def text_idx_encoder_forward(self, idx):
        """Look up class embeddings for a 1-D tensor/array of class indices.

        Non-integer indices are truncated toward zero, as ``int()`` would.
        """
        scene_weight = self._class_weights()
        num_type = torch.as_tensor(idx, device=scene_weight.device).long()
        return scene_weight[num_type]

    def contrast_loss_forward(self, batch):
        """Return (bs, out_dim) visual embeddings for an un-normalized image batch."""
        return self._embed_image(self.transform(batch))

    def forward(self, x, type='image_encoder'):
        """Dispatch to one of the encoder modes.

        ``type`` keeps its original (builtin-shadowing) name because callers
        pass it by keyword.

        Raises:
            ValueError: for an unknown ``type``.
        """
        if type == 'train':
            out = self.train_forward(x)
        elif type == 'image_encoder':
            with torch.no_grad():
                out = self.image_encoder_forward(x)
        elif type == 'text_encoder':
            out = self.text_encoder_forward(x)
        elif type == 'text_idx_encoder':
            out = self.text_idx_encoder_forward(x)
        elif type == 'visual_embed':
            # align_corners=False is the default; stated explicitly.
            x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
            out = self.contrast_loss_forward(x)
        else:
            raise ValueError(f"Unknown forward type {type!r}")
        return out