Spaces:
Build error
Build error
import copy | |
import json | |
import os | |
import time | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.transforms as transforms | |
import tqdm | |
from sklearn.metrics import * | |
from tqdm import tqdm | |
from transformers import AutoConfig, BertModel | |
from transformers.models.bert.modeling_bert import BertLayer | |
from zmq import device | |
from .coattention import * | |
from .layers import * | |
from FakeVD.code_test.utils.metrics import * | |
class SVFENDModel(torch.nn.Module):
    """Multimodal classifier fusing title text, audio, image frames and C3D features.

    Each modality is projected to a common width ``fea_dim``. The text is
    co-attended first with the audio frames and then with the image frames,
    every modality is mean-pooled to a single vector, the four vectors are
    passed through one Transformer encoder layer, averaged, and classified
    into two classes.
    """

    def __init__(self, bert_model, fea_dim, dropout):
        """Build all sub-modules.

        Args:
            bert_model: Accepted for interface compatibility only; it is not
                read. The BERT weights are always loaded from the hard-coded
                local directory ``./FakeVD/Models/bert-base-chinese/``.
            fea_dim: Common feature width every modality is projected to.
            dropout: Dropout probability used in the projection heads and
                passed to the co-attention modules.
        """
        super().__init__()

        # Frozen text encoder (requires_grad_(False): never fine-tuned here).
        self.bert = BertModel.from_pretrained("./FakeVD/Models/bert-base-chinese/").requires_grad_(False)

        self.text_dim = 768
        self.comment_dim = 768
        self.img_dim = 4096
        self.video_dim = 4096

        self.num_frames = 83
        self.num_audioframes = 50
        self.num_comments = 23

        self.dim = fea_dim
        self.num_heads = 4

        self.dropout = dropout

        # Only the second-to-last child of the locally loaded VGGish network
        # is applied in forward().
        # NOTE(review): presumably this is the embedding head - confirm
        # against the torchvggish implementation shipped in ./FakeVD/Models.
        self.vggish_layer = torch.hub.load('./FakeVD/Models/torchvggish/', 'vggish', source='local')
        net_structure = list(self.vggish_layer.children())
        self.vggish_modified = nn.Sequential(*net_structure[-2:-1])

        # Text <-> audio and text <-> image-frame co-attention.
        self.co_attention_ta = co_attention(
            d_k=fea_dim, d_v=fea_dim, n_heads=self.num_heads, dropout=self.dropout, d_model=fea_dim,
            visual_len=self.num_audioframes, sen_len=512, fea_v=self.dim, fea_s=self.dim, pos=False)
        self.co_attention_tv = co_attention(
            d_k=fea_dim, d_v=fea_dim, n_heads=self.num_heads, dropout=self.dropout, d_model=fea_dim,
            visual_len=self.num_frames, sen_len=512, fea_v=self.dim, fea_s=self.dim, pos=False)

        # Fuses the per-modality vectors (treated as a short sequence).
        self.trm = nn.TransformerEncoderLayer(d_model=self.dim, nhead=2, batch_first=True)

        # Projection heads. Creation order and attribute names are kept as-is
        # so state_dict keys and seeded initialisation stay unchanged.
        # linear_comment and linear_intro are built but not used in forward().
        self.linear_text = self._projection(self.text_dim, fea_dim, self.dropout)
        self.linear_comment = self._projection(self.comment_dim, fea_dim, self.dropout)
        self.linear_img = self._projection(self.img_dim, fea_dim, self.dropout)
        self.linear_video = self._projection(self.video_dim, fea_dim, self.dropout)
        self.linear_intro = self._projection(self.text_dim, fea_dim, self.dropout)
        # Input width is fea_dim, so the VGGish output width must equal fea_dim.
        self.linear_audio = self._projection(fea_dim, fea_dim, self.dropout)

        self.classifier = nn.Linear(fea_dim, 2)

    @staticmethod
    def _projection(in_dim, out_dim, dropout):
        """Return the Linear -> ReLU -> Dropout head shared by every modality."""
        return nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(), nn.Dropout(p=dropout))

    def forward(self, **kwargs):
        """Classify one batch.

        Required keyword arguments:
            title_inputid: BERT token ids of the title.
            title_mask: Attention mask matching ``title_inputid``.
            audioframes: Audio input for the VGGish sub-network.
            frames: Per-frame image features, last dimension 4096.
            c3d: C3D video features, last dimension 4096.

        Any other keyword arguments (for example ``audioframes_masks``,
        ``frames_masks`` or ``c3d_masks``) are accepted and ignored: padding
        masks are not applied to the mean pooling below.

        Returns:
            ``(output, fea)`` where ``output`` is the 2-way classifier output
            and ``fea`` is the fused feature vector it was computed from.
        """
        ### Title ###
        title_inputid = kwargs['title_inputid']  # (batch, seq)
        title_mask = kwargs['title_mask']  # (batch, seq)
        fea_text = self.bert(title_inputid, attention_mask=title_mask)['last_hidden_state']  # (batch, seq, 768)
        fea_text = self.linear_text(fea_text)

        ### Audio Frames ###
        audioframes = kwargs['audioframes']
        fea_audio = self.vggish_modified(audioframes)
        fea_audio = self.linear_audio(fea_audio)
        fea_audio, fea_text = self.co_attention_ta(
            v=fea_audio, s=fea_text, v_len=fea_audio.shape[1], s_len=fea_text.shape[1])
        fea_audio = torch.mean(fea_audio, -2)

        ### Image Frames ###
        frames = kwargs['frames']  # (batch, n_frames, 4096)
        fea_img = self.linear_img(frames)
        # The text fed in here is the one already updated by the audio co-attention.
        fea_img, fea_text = self.co_attention_tv(
            v=fea_img, s=fea_text, v_len=fea_img.shape[1], s_len=fea_text.shape[1])
        fea_img = torch.mean(fea_img, -2)

        # Pool the text only after both co-attention passes have updated it.
        fea_text = torch.mean(fea_text, -2)

        ### C3D ###
        c3d = kwargs['c3d']  # (batch, n_clips, 4096)
        fea_video = self.linear_video(c3d)
        fea_video = torch.mean(fea_video, -2)

        ### Fusion ###
        # One token per modality: (batch, 4, fea_dim).
        fea = torch.stack((fea_text, fea_audio, fea_video, fea_img), dim=1)
        fea = self.trm(fea)
        fea = torch.mean(fea, -2)

        output = self.classifier(fea)

        return output, fea