Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from einops import rearrange | |
from transformers import BertModel | |
from .layers import Attention | |
class bBbox(torch.nn.Module):
    """Two-stage attention classifier over the ``bbox_vgg`` features.

    Every feature vector is projected from ``img_dim`` (4096) to ``fea_dim``.
    The projected vectors are grouped into sequences of ``inner_len``,
    passed through attention and mean-pooled; the pooled vectors are grouped
    again into sequences of ``outer_len`` for a second attention + mean-pool
    pass. A linear layer maps the result to 2 logits.

    Args:
        fea_dim: Width of the projected features and of the classifier input.
        inner_len: Sequence length used by the first stage (default 45).
        outer_len: Sequence length used by the second stage (default 83).
    """

    def __init__(self, fea_dim, inner_len=45, outer_len=83):
        super().__init__()
        self.img_dim = 4096
        self.fea_dim = fea_dim
        self.inner_len = inner_len
        self.outer_len = outer_len
        # The attention layers consume the output of ``linear_img``, so their
        # width follows ``fea_dim`` instead of a hard-coded 128.
        # NOTE(review): assumes Attention's ``dim`` is the width of its input
        # features -- confirm against .layers.Attention.
        self.attention1 = Attention(dim=fea_dim, heads=4)
        self.attention2 = Attention(dim=fea_dim, heads=4)
        self.linear_img = nn.Sequential(
            nn.Linear(self.img_dim, fea_dim),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(fea_dim, 2)

    def forward(self, **kwargs):
        """Classify the tensor passed as ``bbox_vgg``.

        Args:
            **kwargs: Must contain ``bbox_vgg``, a tensor whose last dimension
                is ``img_dim`` and whose remaining elements can be regrouped
                into ``outer_len`` sequences of ``inner_len`` vectors.
                # presumably (batch, outer_len, inner_len, img_dim) -- verify
                # against the data loader.

        Returns:
            Tuple ``(output, fea_img)``: the 2-way logits and the pooled
            ``fea_dim``-wide feature they were computed from.
        """
        frames = kwargs['bbox_vgg']
        fea_img = self.linear_img(frames)

        # Stage 1: attend within each group of ``inner_len`` vectors, then pool.
        fea_img = torch.reshape(fea_img, (-1, self.inner_len, self.fea_dim))
        fea_img = torch.mean(self.attention1(fea_img), -2)

        # Stage 2: attend across ``outer_len`` pooled vectors, then pool again.
        fea_img = torch.reshape(fea_img, (-1, self.outer_len, self.fea_dim))
        fea_img = torch.mean(self.attention2(fea_img), -2)

        output = self.classifier(fea_img)
        return output, fea_img
class bC3D(torch.nn.Module):
    """Attention classifier over the features passed as ``c3d``.

    Features are projected from ``video_dim`` (2048) to ``fea_dim``, passed
    through attention, mean-pooled over the second-to-last dimension and
    mapped to 2 logits.

    Args:
        fea_dim: Width of the projected features and of the classifier input.
    """

    def __init__(self, fea_dim):
        super().__init__()
        # Input feature width (an earlier revision used 4096).
        self.video_dim = 2048
        # Attention consumes the output of ``linear_video``, so its width
        # follows ``fea_dim`` instead of a hard-coded 128.
        # NOTE(review): assumes Attention's ``dim`` is the width of its input
        # features -- confirm against .layers.Attention.
        self.attention = Attention(dim=fea_dim, heads=4)
        self.linear_video = nn.Sequential(
            nn.Linear(self.video_dim, fea_dim),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(fea_dim, 2)

    def forward(self, **kwargs):
        """Return 2-way logits for the tensor passed as ``c3d``.

        Args:
            **kwargs: Must contain ``c3d``, a tensor whose last dimension is
                ``video_dim``.  # presumably (batch, seq, video_dim) -- verify

        Returns:
            The classifier output computed from the mean-pooled features.
        """
        c3d = kwargs['c3d']
        fea_video = self.linear_video(c3d)
        fea_video = self.attention(fea_video)
        # Pool over the sequence axis (second-to-last dimension).
        fea_video = torch.mean(fea_video, -2)
        return self.classifier(fea_video)
class bVGG(torch.nn.Module):
    """Attention classifier over the features passed as ``frames``.

    Features are projected from ``img_dim`` (2048) to ``fea_dim``, passed
    through attention, mean-pooled over the second-to-last dimension and
    mapped to 2 logits.

    Args:
        fea_dim: Width of the projected features and of the classifier input.
    """

    def __init__(self, fea_dim):
        super().__init__()
        # Input feature width (an earlier revision used 4096).
        self.img_dim = 2048
        # Attention consumes the output of ``linear_img``, so its width
        # follows ``fea_dim`` instead of a hard-coded 128.
        # NOTE(review): assumes Attention's ``dim`` is the width of its input
        # features -- confirm against .layers.Attention.
        self.attention = Attention(dim=fea_dim, heads=4)
        self.linear_img = nn.Sequential(
            nn.Linear(self.img_dim, fea_dim),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(fea_dim, 2)

    def forward(self, **kwargs):
        """Return 2-way logits for the tensor passed as ``frames``.

        Args:
            **kwargs: Must contain ``frames``, a tensor whose last dimension
                is ``img_dim``.  # presumably (batch, seq, img_dim) -- verify

        Returns:
            The classifier output computed from the mean-pooled features.
        """
        frames = kwargs['frames']
        fea_img = self.linear_img(frames)
        fea_img = self.attention(fea_img)
        # Pool over the sequence axis (second-to-last dimension).
        fea_img = torch.mean(fea_img, -2)
        return self.classifier(fea_img)
class bVggish(torch.nn.Module):
    """Attention classifier over audio features from a local VGGish model.

    The VGGish network is loaded through ``torch.hub`` from the local
    ``./torchvggish/`` directory. Only its second-to-last child module is
    applied to the input; the result goes through attention, is mean-pooled
    over the second-to-last dimension and mapped to 2 logits.

    Args:
        fea_dim: Input width of the classifier.
            # NOTE(review): the attention layer is fixed at dim=128, so this
            # presumably has to be 128 as well -- confirm.
    """

    def __init__(self, fea_dim):
        super().__init__()
        self.attention = Attention(dim=128, heads=4)
        self.vggish_layer = torch.hub.load('./torchvggish/', 'vggish', source = 'local')
        net_structure = list(self.vggish_layer.children())
        # Keep only the second-to-last child of the loaded model.
        self.vggish_modified = nn.Sequential(*net_structure[-2:-1])
        self.classifier = nn.Linear(fea_dim, 2)

    def forward(self, **kwargs):
        """Classify the tensor passed as ``audioframes``.

        Args:
            **kwargs: Must contain ``audioframes``, the input accepted by
                ``vggish_modified``.

        Returns:
            Tuple ``(output, fea_audio)``: the 2-way logits and the pooled
            feature they were computed from.
        """
        audioframes = kwargs['audioframes']
        fea_audio = self.vggish_modified(audioframes)
        fea_audio = self.attention(fea_audio)
        # Pool over the sequence axis (second-to-last dimension).
        fea_audio = torch.mean(fea_audio, -2)
        output = self.classifier(fea_audio)
        return output, fea_audio
class bBert(torch.nn.Module):
    """Text classifier built on a frozen pretrained BERT encoder.

    The second element of the BERT output is projected from ``text_dim``
    (768) to ``fea_dim`` and mapped to 2 logits.

    Args:
        bert_model: Name or path handed to ``BertModel.from_pretrained``.
        fea_dim: Width of the projected text feature.
        dropout: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, bert_model, fea_dim, dropout):
        super().__init__()
        self.text_dim = 768
        # Load the encoder and disable gradients for all of its parameters.
        encoder = BertModel.from_pretrained(bert_model)
        self.bert = encoder.requires_grad_(False)
        self.linear_text = nn.Sequential(
            nn.Linear(self.text_dim, fea_dim),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(fea_dim, 2)

    def forward(self, **kwargs):
        """Classify a batch of titles.

        Args:
            **kwargs: Must contain ``title_inputid`` and ``title_mask``, which
                are passed to BERT as the input ids and ``attention_mask``.

        Returns:
            Tuple ``(output, fea_text)``: the 2-way logits and the projected
            text feature they were computed from.
        """
        input_ids = kwargs['title_inputid']
        mask = kwargs['title_mask']
        # Index 1 of the encoder output is the feature used downstream.
        encoded = self.bert(input_ids, attention_mask=mask)
        fea_text = self.linear_text(encoded[1])
        return self.classifier(fea_text), fea_text
class bTextCNN(nn.Module):
    """TextCNN classifier over the embedded title passed as ``title_w2v``.

    Three parallel 2-D convolutions (window sizes 3, 4 and 5, 14 filters each)
    span the full width of the input. Each feature map is max-pooled over its
    whole length; the pooled vectors are concatenated, projected to
    ``fea_dim`` and mapped to 2 logits.

    Args:
        fea_dim: Width of the projected text feature.
        vocab_size: Width of the convolution kernels; the last dimension of
            ``title_w2v`` must equal this value.
    """

    def __init__(self, fea_dim, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.fea_dim = fea_dim
        self.channel_in = 1
        self.filter_num = 14
        self.window_size = [3, 4, 5]
        self.textcnn = nn.ModuleList([
            nn.Conv2d(self.channel_in, self.filter_num, (K, self.vocab_size))
            for K in self.window_size
        ])
        self.linear = nn.Sequential(
            nn.Linear(len(self.window_size) * self.filter_num, self.fea_dim),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(self.fea_dim, 2)

    @staticmethod
    def _conv_and_pool(conv, text):
        """Apply one convolution + ReLU and max-pool over the whole length."""
        # Kernel width equals the input width, so the last axis has size 1.
        fmap = F.relu(conv(text)).squeeze(3)
        # Pool the 3-D map directly. Squeezing axis 2 first would drop the
        # length axis whenever it has size 1 (sequence length == window size)
        # and make the final squeeze fail.
        return F.max_pool1d(fmap, fmap.shape[-1]).squeeze(2)

    def forward(self, **kwargs):
        """Return 2-way logits for the tensor passed as ``title_w2v``.

        Args:
            **kwargs: Must contain ``title_w2v``, a 3-D tensor whose last
                dimension is ``vocab_size`` and whose second dimension is at
                least the largest window size.

        Returns:
            The classifier output.
        """
        title_w2v = kwargs['title_w2v']
        # Add the single input channel expected by Conv2d.
        text = title_w2v.unsqueeze(1)
        pooled = [self._conv_and_pool(conv, text) for conv in self.textcnn]
        fea_text = self.linear(torch.cat(pooled, 1))
        return self.classifier(fea_text)
class bComments(torch.nn.Module):
    """Attention classifier over BERT encodings of comments.

    Each slice along the first dimension of the inputs is encoded by a frozen
    BERT model. The encodings are projected from ``comment_dim`` (768) to
    ``fea_dim``, passed through attention, mean-pooled over the
    second-to-last dimension and mapped to 2 logits.

    Args:
        bert_model: Name or path handed to ``BertModel.from_pretrained``.
        fea_dim: Width of the projected features and of the classifier input.
    """

    def __init__(self, bert_model, fea_dim):
        super().__init__()
        self.comment_dim = 768
        self.bert = BertModel.from_pretrained(bert_model).requires_grad_(False)
        # Attention consumes the output of ``linear_comment``, so its width
        # follows ``fea_dim`` instead of a hard-coded 128.
        # NOTE(review): assumes Attention's ``dim`` is the width of its input
        # features -- confirm against .layers.Attention.
        self.attention = Attention(dim=fea_dim, heads=4)
        self.linear_comment = nn.Sequential(
            nn.Linear(self.comment_dim, fea_dim),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(fea_dim, 2)

    def forward(self, **kwargs):
        """Return 2-way logits for a batch of comment sets.

        Args:
            **kwargs: Must contain ``comments_inputid`` and ``comments_mask``,
                indexed along their first dimension; each slice is passed to
                BERT as input ids and ``attention_mask``.
                # presumably (batch, n_comments, seq_len) -- verify against
                # the data loader.

        Returns:
            The classifier output computed from the mean-pooled features.
        """
        comments_inputid = kwargs['comments_inputid']
        comments_mask = kwargs['comments_mask']
        # Encode one slice at a time; index 1 of the encoder output is the
        # feature used downstream.
        comments_feature = torch.stack([
            self.bert(ids, attention_mask=mask)[1]
            for ids, mask in zip(comments_inputid, comments_mask)
        ])
        fea_comments = self.linear_comment(comments_feature)
        fea_comments = self.attention(fea_comments)
        # Pool over the comment axis (second-to-last dimension).
        fea_comments = torch.mean(fea_comments, -2)
        return self.classifier(fea_comments)