ybbwcwaps
some FakeVD
711b041
import copy
import json
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import tqdm
from sklearn.metrics import *
from tqdm import tqdm
from transformers import AutoConfig, BertModel
from transformers.models.bert.modeling_bert import BertLayer
from zmq import device
from .coattention import *
from .layers import *
from FakeVD.code_test.utils.metrics import *
class SVFENDModel(torch.nn.Module):
    """Multimodal classifier fusing title text, audio, image frames and C3D clip features.

    Every modality is projected into a shared ``fea_dim`` space. The title token
    features are co-attended first with the audio sequence and then with the
    image-frame sequence. Each modality is then mean-pooled over its sequence
    axis, the four pooled vectors are fused by one Transformer encoder layer,
    and a linear head produces two logits.

    Args:
        bert_model: Accepted for interface compatibility but not used; the text
            encoder is always loaded from ``./FakeVD/Models/bert-base-chinese/``.
        fea_dim: Width of the shared feature space (projection heads,
            co-attention blocks and the fusion Transformer layer).
        dropout: Dropout probability for the projection heads and the
            co-attention blocks.
    """

    def __init__(self, bert_model, fea_dim, dropout):
        super().__init__()
        # Text encoder with all parameters frozen (requires_grad=False).
        self.bert = BertModel.from_pretrained("./FakeVD/Models/bert-base-chinese/").requires_grad_(False)

        # Fixed input widths / sequence lengths of the pre-extracted features.
        self.text_dim = 768
        self.comment_dim = 768
        self.img_dim = 4096
        self.video_dim = 4096
        self.num_frames = 83
        self.num_audioframes = 50
        self.num_comments = 23
        self.dim = fea_dim
        self.num_heads = 4
        self.dropout = dropout

        self.vggish_layer = torch.hub.load('./FakeVD/Models/torchvggish/', 'vggish', source='local')
        # Reuse only the second-to-last child of the hub VGGish model
        # (presumably its embedding MLP -- TODO confirm against torchvggish).
        # Its output width must equal fea_dim, since linear_audio maps fea_dim -> fea_dim.
        net_structure = list(self.vggish_layer.children())
        self.vggish_modified = nn.Sequential(*net_structure[-2:-1])

        # Keep this construction order: parameter initialization draws from the
        # global RNG in the order modules are built.
        # Text<->audio and text<->image co-attention differ only in visual_len.
        self.co_attention_ta = self._make_co_attention(self.num_audioframes)
        self.co_attention_tv = self._make_co_attention(self.num_frames)
        self.trm = nn.TransformerEncoderLayer(d_model=self.dim, nhead=2, batch_first=True)

        # Per-modality projection heads (Linear -> ReLU -> Dropout).
        self.linear_text = self._make_projection(self.text_dim)
        # linear_comment and linear_intro are not used by forward(); they are kept
        # so the module's parameters (and therefore state_dict keys) stay the same.
        self.linear_comment = self._make_projection(self.comment_dim)
        self.linear_img = self._make_projection(self.img_dim)
        self.linear_video = self._make_projection(self.video_dim)
        self.linear_intro = self._make_projection(self.text_dim)
        self.linear_audio = self._make_projection(fea_dim)
        self.classifier = nn.Linear(fea_dim, 2)

    def _make_projection(self, in_dim):
        """Build the Linear(in_dim -> self.dim) + ReLU + Dropout head used for every modality."""
        return nn.Sequential(nn.Linear(in_dim, self.dim), nn.ReLU(), nn.Dropout(p=self.dropout))

    def _make_co_attention(self, visual_len):
        """Build a text/sequence co-attention block whose non-text side has length ``visual_len``."""
        # sen_len=512 presumably matches the padded title length -- TODO confirm.
        return co_attention(
            d_k=self.dim, d_v=self.dim, n_heads=self.num_heads, dropout=self.dropout,
            d_model=self.dim, visual_len=visual_len, sen_len=512,
            fea_v=self.dim, fea_s=self.dim, pos=False,
        )

    def forward(self, **kwargs):
        """Classify one batch.

        Required keyword tensors (shapes as noted by the original author -- confirm
        against the data loader):
            title_inputid, title_mask: BERT token ids / attention mask, ~(batch, 512).
            audioframes: input to the truncated VGGish stack, ~(batch, 36, 12288).
            frames: per-frame image features of width ``img_dim``, ~(batch, 30, 4096).
            c3d: per-clip C3D features of width ``video_dim``, ~(batch, 36, 4096).

        ``audioframes_masks``, ``frames_masks`` and ``c3d_masks`` may be passed but are
        ignored: every modality is mean-pooled over its full (possibly padded) length.

        Returns:
            ``(logits, fused)``: classifier logits with last dimension 2, and the
            pooled fused feature (last dimension ``fea_dim``) fed to the classifier.
        """
        # --- Title: frozen BERT token states projected to fea_dim ---
        fea_text = self.bert(kwargs['title_inputid'], attention_mask=kwargs['title_mask'])['last_hidden_state']
        fea_text = self.linear_text(fea_text)

        # --- Audio: truncated VGGish -> projection, co-attended with the title ---
        fea_audio = self.linear_audio(self.vggish_modified(kwargs['audioframes']))
        # The returned text is audio-attended; the image co-attention below consumes it.
        fea_audio, fea_text = self.co_attention_ta(
            v=fea_audio, s=fea_text, v_len=fea_audio.shape[1], s_len=fea_text.shape[1])
        fea_audio = fea_audio.mean(dim=-2)

        # --- Image frames: projection, co-attended with the (audio-attended) title ---
        fea_img = self.linear_img(kwargs['frames'])
        fea_img, fea_text = self.co_attention_tv(
            v=fea_img, s=fea_text, v_len=fea_img.shape[1], s_len=fea_text.shape[1])
        fea_img = fea_img.mean(dim=-2)
        fea_text = fea_text.mean(dim=-2)

        # --- C3D clips: projection only, no co-attention ---
        fea_video = self.linear_video(kwargs['c3d']).mean(dim=-2)

        # --- Fusion: the four pooled vectors form a length-4 sequence (batch, 4, fea_dim) ---
        fea = torch.stack((fea_text, fea_audio, fea_video, fea_img), dim=1)
        fea = self.trm(fea)
        fea = fea.mean(dim=-2)
        output = self.classifier(fea)
        return output, fea