# OneRestore/model/Embedder.py
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from utils.utils_word_embedding import initialize_wordembedding_matrix

class Backbone(nn.Module):
    """ResNet feature extractor exposing its intermediate blocks."""
    def __init__(self, backbone='resnet18'):
        super(Backbone, self).__init__()
        if backbone == 'resnet18':
            resnet = torchvision.models.resnet.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
        elif backbone == 'resnet50':
            resnet = torchvision.models.resnet.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
        elif backbone == 'resnet101':
            resnet = torchvision.models.resnet.resnet101(weights=torchvision.models.ResNet101_Weights.IMAGENET1K_V1)
        else:
            raise ValueError(f'Unsupported backbone: {backbone}')
self.block0 = nn.Sequential(
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
)
self.block1 = resnet.layer1
self.block2 = resnet.layer2
self.block3 = resnet.layer3
self.block4 = resnet.layer4

    def forward(self, x, returned=(4,)):
        # Run all blocks; return only the requested stages.
blocks = [self.block0(x)]
blocks.append(self.block1(blocks[-1]))
blocks.append(self.block2(blocks[-1]))
blocks.append(self.block3(blocks[-1]))
blocks.append(self.block4(blocks[-1]))
out = [blocks[i] for i in returned]
return out
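
# A quick shape sketch (assuming a 224x224 RGB input and the resnet18
# backbone): block0 -> (bs, 64, 56, 56), block1 -> (bs, 64, 56, 56),
# block2 -> (bs, 128, 28, 28), block3 -> (bs, 256, 14, 14),
# block4 -> (bs, 512, 7, 7). The default returned=(4,) keeps only block4.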

class CosineClassifier(nn.Module):
    """Cosine-similarity classifier with temperature scaling."""
    def __init__(self, temp=0.05):
        super(CosineClassifier, self).__init__()
        self.temp = temp
    def forward(self, img, concept, scale=True):
        """
        img: (bs, emb_dim)
        concept: (n_class, emb_dim)
        returns: (bs, n_class) cosine similarities, divided by self.temp
        when scale is True.
        """
img_norm = F.normalize(img, dim=-1)
concept_norm = F.normalize(concept, dim=-1)
pred = torch.matmul(img_norm, concept_norm.transpose(0, 1))
if scale:
pred = pred / self.temp
return pred
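
# Usage sketch (hypothetical shapes): for img of shape (bs, emb_dim) and
# concept of shape (n_class, emb_dim), the raw similarities lie in [-1, 1];
# dividing by temp=0.05 spreads the logits so cross-entropy gradients do
# not vanish.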

class Embedder(nn.Module):
    """
    Text and Visual Embedding Model: maps scene-type words (via pretrained
    word embeddings) and images (via a ResNet backbone) into a shared
    out_dim-dimensional space compared by a cosine classifier.
    """
    def __init__(self,
                 type_name,
                 feat_dim=512,
                 mid_dim=1024,
                 out_dim=324,
                 drop_rate=0.35,
                 cosine_cls_temp=0.05,
                 wordembs='glove',
                 extractor_name='resnet18'):
super(Embedder, self).__init__()
        # ImageNet normalization statistics.
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
self.type_name = type_name
self.feat_dim = feat_dim
self.mid_dim = mid_dim
self.out_dim = out_dim
self.drop_rate = drop_rate
self.cosine_cls_temp = cosine_cls_temp
self.wordembs = wordembs
self.extractor_name = extractor_name
self.transform = transforms.Normalize(mean, std)
self._setup_word_embedding()
self._setup_image_embedding()

    def _setup_image_embedding(self):
        # Image branch: backbone features -> 1x1 conv -> global pool -> out_dim.
        # feat_dim must match the backbone's last-stage channels
        # (512 for resnet18, 2048 for resnet50/101).
        self.feat_extractor = Backbone(self.extractor_name)
img_emb_modules = [
nn.Conv2d(self.feat_dim, self.mid_dim, kernel_size=1, bias=False),
nn.BatchNorm2d(self.mid_dim),
nn.ReLU()
]
if self.drop_rate > 0:
img_emb_modules += [nn.Dropout2d(self.drop_rate)]
self.img_embedder = nn.Sequential(*img_emb_modules)
self.img_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.img_final = nn.Linear(self.mid_dim, self.out_dim)
self.classifier = CosineClassifier(temp=self.cosine_cls_temp)

    def _setup_word_embedding(self):
        self.type2idx = {name: i for i, name in enumerate(self.type_name)}
        self.num_type = len(self.type_name)
        # Indices of all known scene types, moved to GPU when available.
        train_type = [self.type2idx[type_i] for type_i in self.type_name]
        self.train_type = torch.LongTensor(train_type).to("cuda" if torch.cuda.is_available() else "cpu")
wordemb, self.word_dim = \
initialize_wordembedding_matrix(self.wordembs, self.type_name)
self.embedder = nn.Embedding(self.num_type, self.word_dim)
self.embedder.weight.data.copy_(wordemb)
        # Project word embeddings into the shared out_dim space.
        self.mlp = nn.Sequential(
nn.Linear(self.word_dim, self.out_dim),
nn.ReLU(True)
)
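
    # Expected training batch (inferred from the method below): batch[0] is a
    # LongTensor of scene-type indices and batch[1] a (bs, 3, H, W) image
    # tensor in [0, 1], normalized here with the ImageNet statistics.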
def train_forward(self, batch):
scene, img = batch[0], self.transform(batch[1])
bs = img.shape[0]
# word embedding
scene_emb = self.embedder(self.train_type)
scene_weight = self.mlp(scene_emb)
        # image embedding
img = self.feat_extractor(img)[0]
img = self.img_embedder(img)
img = self.img_avg_pool(img).squeeze(3).squeeze(2)
img = self.img_final(img)
pred = self.classifier(img, scene_weight)
label_loss = F.cross_entropy(pred, scene)
pred = torch.max(pred, dim=1)[1]
type_pred = self.train_type[pred]
correct_type = (type_pred == scene)
out = {
'loss_total': label_loss,
            'acc_type': correct_type.float().mean(),
}
return out

    def image_encoder_forward(self, batch):
img = self.transform(batch)
# word embedding
scene_emb = self.embedder(self.train_type)
scene_weight = self.mlp(scene_emb)
        # image embedding
        img = self.feat_extractor(img)[0]
        bs = img.shape[0]
        img = self.img_embedder(img)
        img = self.img_avg_pool(img).squeeze(3).squeeze(2)
        img = self.img_final(img)
        pred = self.classifier(img, scene_weight)
        pred = torch.max(pred, dim=1)[1]
        # Gather the text embedding of each predicted scene type.
        out_embedding = scene_weight[pred]
        num_type = self.train_type[pred]
        text_type = [self.type_name[int(i)] for i in num_type]
        return out_embedding, num_type, text_type

    def text_encoder_forward(self, text):
        # word embedding
        scene_emb = self.embedder(self.train_type)
        scene_weight = self.mlp(scene_emb)
        # Map scene-type names to indices, then gather their embeddings.
        num_type = torch.LongTensor([self.type2idx[t] for t in text]).to(scene_weight.device)
        out_embedding = scene_weight[num_type]
        text_type = text
        return out_embedding, num_type, text_type

    def text_idx_encoder_forward(self, idx):
        # word embedding
        scene_emb = self.embedder(self.train_type)
        scene_weight = self.mlp(scene_emb)
        # Gather the embedding row for each requested type index.
        out_embedding = scene_weight[idx.long().to(scene_weight.device)]
        return out_embedding

    def contrast_loss_forward(self, batch):
        img = self.transform(batch)
        # image embedding
img = self.feat_extractor(img)[0]
img = self.img_embedder(img)
img = self.img_avg_pool(img).squeeze(3).squeeze(2)
img = self.img_final(img)
return img

    def forward(self, x, type='image_encoder'):
        if type == 'train':
            out = self.train_forward(x)
        elif type == 'image_encoder':
            with torch.no_grad():
                out = self.image_encoder_forward(x)
        elif type == 'text_encoder':
            out = self.text_encoder_forward(x)
        elif type == 'text_idx_encoder':
            out = self.text_idx_encoder_forward(x)
        elif type == 'visual_embed':
            x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
            out = self.contrast_loss_forward(x)
        else:
            raise ValueError(f'Unknown forward type: {type}')
        return out
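
if __name__ == "__main__":
    # Minimal smoke test (a sketch): exercises Backbone and CosineClassifier
    # on random CPU tensors. Embedder itself is not constructed here because
    # it needs the pretrained word vectors loaded by
    # initialize_wordembedding_matrix.
    x = torch.randn(2, 3, 224, 224)
    feats = Backbone('resnet18')(x)[0]
    print('backbone features:', feats.shape)   # (2, 512, 7, 7)
    img_emb, concepts = torch.randn(2, 324), torch.randn(5, 324)
    logits = CosineClassifier(temp=0.05)(img_emb, concepts)
    print('cosine logits:', logits.shape)       # (2, 5)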