import collections
import json
import os
import time
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# from gensim.models import KeyedVectors
from FakeVD.code_test.models.Baselines import *
from FakeVD.code_test.models.FANVM import FANVMModel
from FakeVD.code_test.models.SVFEND import SVFENDModel
from FakeVD.code_test.models.TikTec import TikTecModel
from FakeVD.code_test.utils.dataloader import *
from FakeVD.code_test.models.Trainer import Trainer
from FakeVD.code_test.models.Trainer_3set import Trainer3
def pad_sequence(seq_len, lst, emb):
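    """Truncate or zero-pad each per-video sequence in `lst` to `seq_len` rows.

    `emb` is the embedding width used when padding empty videos. Sequences are
    cast to FloatTensor for 200-d (word-embedding) features and to LongTensor
    otherwise (token ids / attention masks).
    """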
result=[]
for video in lst:
if isinstance(video, list):
video = torch.stack(video)
ori_len=video.shape[0]
if ori_len == 0:
video = torch.zeros([seq_len,emb],dtype=torch.long)
elif ori_len>=seq_len:
if emb == 200:
video=torch.FloatTensor(video[:seq_len])
else:
video=torch.LongTensor(video[:seq_len])
else:
            # pad with zeros of the same dtype as `video` (torch.cat requires matching dtypes)
            video=torch.cat([video,torch.zeros([seq_len-ori_len,video.shape[1]],dtype=video.dtype)],dim=0)
if emb == 200:
video=torch.FloatTensor(video)
else:
video=torch.LongTensor(video)
result.append(video)
return torch.stack(result)
def pad_sequence_bbox(seq_len, lst):
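    """Truncate or zero-pad per-video bbox features to `seq_len` frames of shape (45, 4096)."""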
result=[]
for video in lst:
if isinstance(video, list):
video = torch.stack(video)
ori_len=video.shape[0]
if ori_len == 0:
video = torch.zeros([seq_len,45,4096],dtype=torch.float)
elif ori_len>=seq_len:
video=torch.FloatTensor(video[:seq_len])
else:
video=torch.cat([video,torch.zeros([seq_len-ori_len,45,4096],dtype=torch.float)],dim=0)
result.append(video)
return torch.stack(result)
def pad_frame_sequence(seq_len, lst):
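    """Uniformly subsample frame features down to `seq_len` (or zero-pad up to it)
    and return the padded batch together with the matching attention masks.
    """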
attention_masks = []
result=[]
for video in lst:
video=torch.FloatTensor(video)
ori_len=video.shape[0]
if ori_len>=seq_len:
gap=ori_len//seq_len
video=video[::gap][:seq_len]
mask = np.ones((seq_len))
else:
video=torch.cat((video,torch.zeros([seq_len-ori_len,video.shape[1]],dtype=torch.float)),dim=0)
mask = np.append(np.ones(ori_len), np.zeros(seq_len-ori_len))
result.append(video)
mask = torch.IntTensor(mask)
attention_masks.append(mask)
return torch.stack(result), torch.stack(attention_masks)
def _init_fn(worker_id):
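    # Make DataLoader workers deterministic. Note the seed is hardcoded here
    # rather than taken from config['seed'].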
np.random.seed(2022)
def SVFEND_collate_fn(batch):
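    """Batch SVFEND samples: stack title ids/masks and labels, and pad the frame,
    audio and C3D feature sequences to fixed lengths with attention masks.
    """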
num_frames = 83
num_audioframes = 50
title_inputid = [item['title_inputid'] for item in batch]
title_mask = [item['title_mask'] for item in batch]
frames = [item['frames'] for item in batch]
frames, frames_masks = pad_frame_sequence(num_frames, frames)
audioframes = [item['audioframes'] for item in batch]
audioframes, audioframes_masks = pad_frame_sequence(num_audioframes, audioframes)
c3d = [item['c3d'] for item in batch]
c3d, c3d_masks = pad_frame_sequence(num_frames, c3d)
label = [item['label'] for item in batch]
return {
'label': torch.stack(label),
'title_inputid': torch.stack(title_inputid),
'title_mask': torch.stack(title_mask),
'audioframes': audioframes,
'audioframes_masks': audioframes_masks,
'frames':frames,
'frames_masks': frames_masks,
'c3d': c3d,
'c3d_masks': c3d_masks,
}
def FANVM_collate_fn(batch):
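    """Batch FANVM samples. Comments are sorted by like count (descending) per
    video, then comment ids, masks and like counts are padded/truncated to
    `num_comments`; titles, frames and event labels are batched as in SVFEND.
    """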
num_comments = 23
num_frames = 83
title_inputid = [item['title_inputid'] for item in batch]
title_mask = [item['title_mask'] for item in batch]
comments_like = [item['comments_like'] for item in batch]
comments_inputid = [item['comments_inputid'] for item in batch]
comments_mask = [item['comments_mask'] for item in batch]
comments_inputid_resorted = []
comments_mask_resorted = []
comments_like_resorted = []
for idx in range(len(comments_like)):
comments_like_one = comments_like[idx]
comments_inputid_one = comments_inputid[idx]
comments_mask_one = comments_mask[idx]
if comments_like_one.shape != torch.Size([0]):
comments_inputid_one, comments_mask_one, comments_like_one = (list(t) for t in zip(*sorted(zip(comments_inputid_one, comments_mask_one, comments_like_one), key=lambda s: s[2], reverse=True)))
comments_inputid_resorted.append(comments_inputid_one)
comments_mask_resorted.append(comments_mask_one)
comments_like_resorted.append(comments_like_one)
comments_inputid = pad_sequence(num_comments,comments_inputid_resorted,250)
comments_mask = pad_sequence(num_comments,comments_mask_resorted,250)
comments_like=[]
for idx in range(len(comments_like_resorted)):
comments_like_resorted_one = comments_like_resorted[idx]
if len(comments_like_resorted_one)>=num_comments:
comments_like.append(torch.tensor(comments_like_resorted_one[:num_comments]))
else:
if isinstance(comments_like_resorted_one, list):
comments_like.append(torch.tensor(comments_like_resorted_one+[0]*(num_comments-len(comments_like_resorted_one))))
else:
comments_like.append(torch.tensor(comments_like_resorted_one.tolist()+[0]*(num_comments-len(comments_like_resorted_one))))
frames = [item['frames'] for item in batch]
frames, frames_masks = pad_frame_sequence(num_frames, frames)
frame_thmub = [item['frame_thmub'] for item in batch]
label = [item['label'] for item in batch]
label_event = [item['label_event'] for item in batch]
s = [item['s'] for item in batch]
return {
'label': torch.stack(label),
'title_inputid': torch.stack(title_inputid),
'title_mask': torch.stack(title_mask),
'comments_inputid': comments_inputid,
'comments_mask': comments_mask,
'comments_like': torch.stack(comments_like),
'frames':frames,
'frames_masks': frames_masks,
'frame_thmub': torch.stack(frame_thmub),
's': torch.stack(s),
'label_event':torch.stack(label_event),
}
def bbox_collate_fn(batch):
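    """Batch bbox-feature samples: pad per-video bbox features to `num_frames` frames."""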
num_frames = 83
bbox_vgg = [item['bbox_vgg'] for item in batch]
bbox_vgg = pad_sequence_bbox(num_frames,bbox_vgg)
label = [item['label'] for item in batch]
return {
'label': torch.stack(label),
'bbox_vgg': bbox_vgg,
}
def c3d_collate_fn(batch):
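    """Batch C3D-feature samples: pad the motion features and return them with masks."""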
num_frames = 83
c3d = [item['c3d'] for item in batch]
c3d, c3d_masks = pad_frame_sequence(num_frames, c3d)
label = [item['label'] for item in batch]
return {
'label': torch.stack(label),
'c3d': c3d,
'c3d_masks': c3d_masks,
}
def vgg_collate_fn(batch):
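    """Batch VGG-feature samples: pad the frame features and return them with masks."""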
num_frames = 83
frames = [item['frames'] for item in batch]
frames, frames_masks = pad_frame_sequence(num_frames, frames)
label = [item['label'] for item in batch]
return {
'label': torch.stack(label),
'frames':frames,
'frames_masks': frames_masks,
}
def comments_collate_fn(batch):
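    """Batch comment-only samples; mirrors the comment-processing branch of
    FANVM_collate_fn (sort by likes, then pad/truncate to `num_comments`).
    """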
num_comments = 23
comments_like = [item['comments_like'] for item in batch]
comments_inputid = [item['comments_inputid'] for item in batch]
comments_mask = [item['comments_mask'] for item in batch]
comments_inputid_resorted = []
comments_mask_resorted = []
comments_like_resorted = []
for idx in range(len(comments_like)):
comments_like_one = comments_like[idx]
comments_inputid_one = comments_inputid[idx]
comments_mask_one = comments_mask[idx]
if comments_like_one.shape != torch.Size([0]):
comments_inputid_one, comments_mask_one, comments_like_one = (list(t) for t in zip(*sorted(zip(comments_inputid_one, comments_mask_one, comments_like_one), key=lambda s: s[2], reverse=True)))
comments_inputid_resorted.append(comments_inputid_one)
comments_mask_resorted.append(comments_mask_one)
comments_like_resorted.append(comments_like_one)
comments_inputid = pad_sequence(num_comments,comments_inputid_resorted,250)
comments_mask = pad_sequence(num_comments,comments_mask_resorted,250)
comments_like=[]
for idx in range(len(comments_like_resorted)):
comments_like_resorted_one = comments_like_resorted[idx]
if len(comments_like_resorted_one)>=num_comments:
comments_like.append(torch.tensor(comments_like_resorted_one[:num_comments]))
else:
if isinstance(comments_like_resorted_one, list):
comments_like.append(torch.tensor(comments_like_resorted_one+[0]*(num_comments-len(comments_like_resorted_one))))
else:
comments_like.append(torch.tensor(comments_like_resorted_one.tolist()+[0]*(num_comments-len(comments_like_resorted_one))))
label = [item['label'] for item in batch]
return {
'label': torch.stack(label),
'comments_inputid': comments_inputid,
'comments_mask': comments_mask,
'comments_like': torch.stack(comments_like),
}
def title_w2v_collate_fn(batch):
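    """Batch word2vec title features, padded/truncated to `length_title` tokens."""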
length_title = 128
title_w2v = [item['title_w2v'] for item in batch]
title_w2v = pad_sequence(length_title, title_w2v, 100)
label = [item['label'] for item in batch]
return {
'label': torch.stack(label),
'title_w2v': title_w2v,
}
def tictec_collate_fn(batch):
"""
将一批样本组合成一个批次。
Args:
batch (list of dict): 包含单个样本的列表,每个样本是一个字典,包含 'label'、'caption_feature'、'visual_feature'、'asr_feature'、'mask_K' 和 'mask_N'。
Returns:
dict: 包含批次数据的字典,'labels' 是一个张量,其他特征和掩码也是张量。
"""
num_frames = 83
labels = torch.stack([item['label'] for item in batch])
caption_features = torch.stack([item['caption_feature'] for item in batch])
visual_features = torch.stack([item['visual_feature'] for item in batch])
asr_features = torch.stack([item['asr_feature'] for item in batch])
mask_Ks = torch.stack([item['mask_K'] for item in batch])
mask_Ns = torch.stack([item['mask_N'] for item in batch])
return {
'label': labels,
'caption_feature': caption_features,
'visual_feature': visual_features,
'asr_feature': asr_features,
'mask_K': mask_Ks,
'mask_N': mask_Ns,
}
class Run:
    def __init__(self, config):
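        """Read all hyperparameters and paths from the `config` dict."""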
self.model_name = config['model_name']
self.mode_eval = config['mode_eval']
self.fold = config['fold']
self.data_type = 'SVFEND'
self.epoches = config['epoches']
self.batch_size = config['batch_size']
self.num_workers = config['num_workers']
self.epoch_stop = config['epoch_stop']
self.seed = config['seed']
self.device = config['device']
self.lr = config['lr']
self.lambd=config['lambd']
self.save_param_dir = config['path_param']
self.path_tensorboard = config['path_tensorboard']
self.dropout = config['dropout']
self.weight_decay = config['weight_decay']
self.event_num = 616
self.mode ='normal'
    def get_dataloader(self, data_type, data_fold):
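        """Build train/test DataLoaders for `data_type` on the given fold,
        picking the matching Dataset class and collate_fn.
        """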
collate_fn=None
if data_type=='SVFEND':
            dataset_train = SVFENDDataset(f'vid_fold_no_{data_fold}.txt')
            dataset_test = SVFENDDataset(f'vid_fold_{data_fold}.txt')
collate_fn=SVFEND_collate_fn
elif data_type=='FANVM':
dataset_train = FANVMDataset_train(f'vid_fold_no_{data_fold}.txt')
dataset_test = FANVMDataset_test(path_vid_train=f'vid_fold_no_{data_fold}.txt', path_vid_test=f'vid_fold_{data_fold}.txt')
collate_fn = FANVM_collate_fn
elif data_type=='c3d':
dataset_train = C3DDataset(f'vid_fold_no_{data_fold}.txt')
dataset_test = C3DDataset(f'vid_fold_{data_fold}.txt')
collate_fn = c3d_collate_fn
elif data_type=='vgg':
dataset_train = VGGDataset(f'vid_fold_no_{data_fold}.txt')
dataset_test = VGGDataset(f'vid_fold_{data_fold}.txt')
collate_fn = vgg_collate_fn
elif data_type=='bbox':
            dataset_train = BboxDataset(f'vid_fold_no_{data_fold}.txt')
            dataset_test = BboxDataset(f'vid_fold_{data_fold}.txt')
collate_fn = bbox_collate_fn
elif data_type=='comments':
dataset_train = CommentsDataset(f'vid_fold_no_{data_fold}.txt')
dataset_test = CommentsDataset(f'vid_fold_{data_fold}.txt')
collate_fn = comments_collate_fn
elif data_type=='TikTec':
dataset_train = TikTecDataset(f'vid_fold_no_{data_fold}.txt')
dataset_test = TikTecDataset(f'vid_fold_{data_fold}.txt')
collate_fn = tictec_collate_fn
# elif data_type=='w2v':
# wv_from_text = KeyedVectors.load_word2vec_format("./stores/tencent-ailab-embedding-zh-d100-v0.2.0-s/tencent-ailab-embedding-zh-d100-v0.2.0-s.txt", binary=False)
# dataset_train = Title_W2V_Dataset(f'vid_fold_no{data_fold}.txt', wv_from_text)
# dataset_test = Title_W2V_Dataset(f'vid_fold_{data_fold}.txt', wv_from_text)
# collate_fn = title_w2v_collate_fn
train_dataloader = DataLoader(dataset_train, batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
shuffle=True,
worker_init_fn=_init_fn,
collate_fn=collate_fn)
test_dataloader=DataLoader(dataset_test, batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
shuffle=False,
worker_init_fn=_init_fn,
collate_fn=collate_fn)
        dataloaders = {'train': train_dataloader, 'test': test_dataloader}
return dataloaders
def get_dataloader_temporal(self, data_type):
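        """Build train/val/test DataLoaders for the temporal split (fixed vid_time3_* files)."""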
collate_fn=None
if data_type=='SVFEND':
dataset_train = SVFENDDataset('vid_time3_train.txt')
dataset_val = SVFENDDataset('vid_time3_val.txt')
dataset_test = SVFENDDataset('vid_time3_test.txt')
collate_fn=SVFEND_collate_fn
elif data_type=='FANVM':
dataset_train = FANVMDataset_train('vid_time3_train.txt')
dataset_val = FANVMDataset_test(path_vid_train='vid_time3_train.txt', path_vid_test='vid_time3_valid.txt')
dataset_test = FANVMDataset_test(path_vid_train='vid_time3_train.txt', path_vid_test='vid_time3_test.txt')
collate_fn = FANVM_collate_fn
        else:
            # other temporal splits can be added here
            raise ValueError(f'data_type {data_type} is not supported for temporal evaluation')
train_dataloader = DataLoader(dataset_train, batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
shuffle=True,
worker_init_fn=_init_fn,
collate_fn=collate_fn)
val_dataloader = DataLoader(dataset_val, batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
shuffle=False,
worker_init_fn=_init_fn,
collate_fn=collate_fn)
test_dataloader=DataLoader(dataset_test, batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
shuffle=False,
worker_init_fn=_init_fn,
collate_fn=collate_fn)
        dataloaders = {'train': train_dataloader, 'val': val_dataloader, 'test': test_dataloader}
return dataloaders
def get_model(self):
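        """Instantiate the model named in the config and set the matching data_type/mode."""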
if self.model_name == 'SVFEND':
self.model = SVFENDModel(bert_model='bert-base-chinese', fea_dim=128,dropout=self.dropout)
elif self.model_name == 'FANVM':
self.model = FANVMModel(bert_model='bert-base-chinese', fea_dim=128)
self.data_type = "FANVM"
self.mode = 'eann'
elif self.model_name == 'C3D':
self.model = bC3D(fea_dim=128)
self.data_type = "c3d"
elif self.model_name == 'VGG':
self.model = bVGG(fea_dim=128)
self.data_type = "vgg"
elif self.model_name == 'Bbox':
self.model = bBbox(fea_dim=128)
self.data_type = "bbox"
elif self.model_name == 'Vggish':
self.model = bVggish(fea_dim=128)
elif self.model_name == 'Bert':
self.model = bBert(bert_model='bert-base-chinese', fea_dim=128,dropout=self.dropout)
elif self.model_name == 'TextCNN':
self.model = bTextCNN(fea_dim=128, vocab_size=100)
self.data_type = "w2v"
elif self.model_name == 'Comments':
self.model = bComments(bert_model='bert-base-chinese', fea_dim=128)
self.data_type = "comments"
elif self.model_name == 'TikTec':
self.model = TikTecModel(VCIF_dropout=self.dropout, MLP_dropout=self.dropout)
self.data_type = 'TikTec'
return self.model
def main(self):
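        """Train and evaluate: a single fold ('nocv'), the temporal split ('temporal'),
        or 5-fold cross-validation ('cv').
        """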
if self.mode_eval == "nocv":
self.model = self.get_model()
dataloaders = self.get_dataloader(data_type=self.data_type, data_fold=self.fold)
trainer = Trainer(model=self.model, device = self.device, lr = self.lr, dataloaders = dataloaders, epoches = self.epoches, dropout = self.dropout, weight_decay = self.weight_decay, mode = self.mode, model_name = self.model_name, event_num = self.event_num,
epoch_stop = self.epoch_stop, save_param_path = self.save_param_dir+self.data_type+"/"+self.model_name+"/", writer = SummaryWriter(self.path_tensorboard))
result=trainer.train()
            for metric in ['acc', 'f1', 'precision', 'recall', 'auc']:
                print ('%s : %.4f' % (metric, result[metric]))
            return result
elif self.mode_eval == "temporal":
self.model = self.get_model()
dataloaders = self.get_dataloader_temporal(data_type=self.data_type)
trainer = Trainer3(model=self.model, device = self.device, lr = self.lr, dataloaders = dataloaders, epoches = self.epoches, dropout = self.dropout, weight_decay = self.weight_decay, mode = self.mode, model_name = self.model_name, event_num = self.event_num,
epoch_stop = self.epoch_stop, save_param_path = self.save_param_dir+self.data_type+"/"+self.model_name+"/", writer = SummaryWriter(self.path_tensorboard))
result=trainer.train()
for metric in ['acc', 'f1', 'precision', 'recall', 'auc']:
print ('%s : %.4f' % (metric, result[metric]))
return result
elif self.mode_eval == "cv":
collate_fn=None
# if self.model_name == 'TextCNN':
# wv_from_text = KeyedVectors.load_word2vec_format("./stores/tencent-ailab-embedding-zh-d100-v0.2.0-s/tencent-ailab-embedding-zh-d100-v0.2.0-s.txt", binary=False)
history = collections.defaultdict(list)
for fold in range(1, 6):
print('-' * 50)
print ('fold %d:' % fold)
print('-' * 50)
self.model = self.get_model()
dataloaders = self.get_dataloader(data_type=self.data_type, data_fold=fold)
trainer = Trainer(model = self.model, device = self.device, lr = self.lr, dataloaders = dataloaders, epoches = self.epoches, dropout = self.dropout, weight_decay = self.weight_decay, mode = self.mode, model_name = self.model_name, event_num = self.event_num,
epoch_stop = self.epoch_stop, save_param_path = self.save_param_dir+self.data_type+"/"+self.model_name+"/", writer = SummaryWriter(self.path_tensorboard+"fold_"+str(fold)+"/"))
result = trainer.train()
history['auc'].append(result['auc'])
history['f1'].append(result['f1'])
history['recall'].append(result['recall'])
history['precision'].append(result['precision'])
history['acc'].append(result['acc'])
print ('results on 5-fold cross-validation: ')
for metric in ['acc', 'f1', 'precision', 'recall', 'auc']:
print ('%s : %.4f +/- %.4f' % (metric, np.mean(history[metric]), np.std(history[metric])))
        else:
            raise ValueError(f'mode_eval {self.mode_eval} is not supported')