Spaces:
Build error
Build error
File size: 5,930 Bytes
711b041 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
from FakeVD.code_test.utils.metrics import *
from FakeVD.code_test.models.SVFEND import SVFENDModel
from FakeVD.code_test.utils.dataloader import SVFENDDataset
from FakeVD.code_test.run import _init_fn, SVFEND_collate_fn
# from VGGish_Feature_Extractor.my_vggish_folder_fun import vggish_audio
from FakeVD.code_test.VGGish_Feature_Extractor.my_vggish_fun import vggish_audio, load_model_vggish
from FakeVD.code_test.VGG19_Feature_Extractor.vgg19_feature import process_video as vgg19_frame
from FakeVD.code_test.VGG19_Feature_Extractor.vgg19_feature import load_model_vgg19
from FakeVD.code_test.C3D_Feature_Extractor.feature_extractor_vid import feature_extractor as c3d_video
from FakeVD.code_test.C3D_Feature_Extractor.feature_extractor_vid import load_model_c3d
from FakeVD.code_test.Text_Feature_Extractor.main import video_work as asr_text
from FakeVD.code_test.Text_Feature_Extractor.wav2text import wav2text
def load_model(checkpoint_path):
    """Build the SVFEND detection model and load its weights from a checkpoint.

    Args:
        checkpoint_path: Path of a state-dict file readable by ``torch.load``.

    Returns:
        The ``SVFENDModel`` instance, switched to evaluation mode. This
        function does not call ``.to(device)`` on the model; only the loaded
        tensors are mapped to the selected device.

    Warns:
        UserWarning: If the checkpoint keys do not match the model's keys.
            Loading is intentionally non-strict, so mismatches do not raise,
            but they are reported instead of being silently ignored.
    """
    import warnings

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SVFENDModel(bert_model='bert-base-chinese', fea_dim=128, dropout=0.1)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=device)

    # strict=False keeps the original tolerant behaviour, but the returned
    # key report is inspected so a wrong checkpoint does not go unnoticed.
    load_result = model.load_state_dict(state_dict, strict=False)
    missing = list(getattr(load_result, 'missing_keys', ()))
    unexpected = list(getattr(load_result, 'unexpected_keys', ()))
    if missing or unexpected:
        warnings.warn(
            "Checkpoint {!r} does not fully match the model: "
            "{} missing key(s) {}, {} unexpected key(s) {}".format(
                checkpoint_path,
                len(missing), missing[:5],
                len(unexpected), unexpected[:5],
            )
        )

    model.eval()
    return model
def get_model(checkpoint_path='./FakeVD/code_test/checkpoints/SVFEND/SVFEND/_test_epoch4_0.7943'):
    """Load the detection model together with all feature-extraction models.

    Args:
        checkpoint_path: Location of the detection model checkpoint; it is
            forwarded to ``load_model``.

    Returns:
        A dict with the keys ``'model_main'``, ``'model_vggish'``,
        ``'model_vgg19'``, ``'model_c3d'`` and ``'model_text'``, each mapped
        to the object returned by the corresponding loader.
    """
    # Dict entries are evaluated top to bottom, so the loaders run in the
    # order they are listed here.
    return {
        'model_main': load_model(checkpoint_path),
        'model_vggish': load_model_vggish(),
        'model_vgg19': load_model_vgg19(),
        'model_c3d': load_model_c3d(),
        'model_text': wav2text(),
    }
# Label convention: 0 if item['annotation'] == '真' (real), otherwise 1 (fake).
def test(model, dataloader):
    """Run inference over ``dataloader`` and gather labels, predictions and probabilities.

    Args:
        model: Module invoked as ``model(**batch)``; it must return a pair
            whose first element is the class-score tensor.
        dataloader: Iterable of dict batches. Every value must support
            ``.to(device)`` and each batch must contain a ``'label'`` entry.

    Returns:
        A tuple ``(label, pred, prob)`` of Python lists: the labels read from
        the batches, the index of the highest score per sample, and the
        per-sample softmax probabilities over dimension 1.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(target)
    model.eval()

    labels, preds, probs = [], [], []

    for batch in tqdm(dataloader):
        with torch.no_grad():
            # Move every entry onto the target device, updating the batch
            # dict in place.
            for key in list(batch.keys()):
                batch[key] = batch[key].to(target)

            gold = batch['label']
            # The whole batch (label included) is passed as keyword arguments.
            scores, _features = model(**batch)
            winners = scores.max(1)[1]
            # Softmax probabilities over the class dimension.
            probabilities = F.softmax(scores, dim=1)

            labels.extend(gold.detach().cpu().numpy().tolist())
            preds.extend(winners.detach().cpu().numpy().tolist())
            probs.extend(probabilities.detach().cpu().numpy().tolist())

    return (labels, preds, probs)
def main(models,
         video_file_path,
         preprocessed_flag=False,
         feature_path='./FakeVD/code_test/preprocessed_feature'):
    """Classify a single video and return its 'fake' probability.

    Args:
        models: Dict holding the keys ``'model_main'``, ``'model_vggish'``,
            ``'model_vgg19'``, ``'model_c3d'`` and ``'model_text'``.
        video_file_path: Path of the video file. Its base name without the
            extension is used as the video id.
        preprocessed_flag: When false, the four feature extractors are run on
            the video before prediction; when true, extraction is skipped.
        feature_path: Root directory under which feature files are located.

    Returns:
        The softmax value at index 1 for the first (only) sample, which this
        code treats as the probability of the video being fake ('假').
    """
    # Pull out every model up front.
    detector = models['model_main']
    audio_net = models['model_vggish']
    frame_net = models['model_vgg19']
    clip_net = models['model_c3d']
    speech_net = models['model_text']

    # Split the path into its folder and file name; the file name without
    # its extension serves as the video id.
    video_dir, video_name = os.path.split(video_file_path)
    vid = os.path.splitext(video_name)[0]

    # Where each kind of feature lives under the feature root.
    feature_locations = {
        'VGGish_audio': os.path.join(feature_path, vid + '.pkl'),
        'C3D_video': os.path.join(feature_path, 'C3D/'),
        'VGG19_frame': os.path.join(feature_path, 'VGG19/'),
        'ASR_text': os.path.join(feature_path, 'ASR/' + vid + '.json'),
    }

    # Feature extraction, skipped when the video was already preprocessed.
    if not preprocessed_flag:
        vggish_audio(audio_net, video_file_path,
                     feature_locations['VGGish_audio'])
        vgg19_frame(frame_net, video_name, video_dir,
                    feature_locations['VGG19_frame'])
        c3d_video(clip_net, feature_locations['C3D_video'],
                  video_dir, video_name)
        asr_text(speech_net, audio_net, video_file_path,
                 feature_locations['ASR_text'])

    # Single-video dataset wrapped in a batch-size-1 loader.
    loader = DataLoader(
        SVFENDDataset([vid], feature_locations),
        batch_size=1,
        num_workers=0,
        pin_memory=True,
        shuffle=False,
        worker_init_fn=_init_fn,
        collate_fn=SVFEND_collate_fn,
    )

    # Run the prediction.
    _labels, predicted, probabilities = test(detector, loader)

    if predicted[0] == 0:
        annotation = '真'
    else:
        annotation = '假'

    first_sample = probabilities[0]
    real_prob = first_sample[0]  # probability of 'real'
    fake_prob = first_sample[1]  # probability of 'fake'

    # Print the prediction result.
    print(annotation, real_prob, fake_prob)
    return fake_prob
if __name__ == "__main__":
    # Demo run on one sample video. preprocessed_flag=False means the
    # features are extracted from the raw video before prediction.
    sample_video = "./FakeVD/dataset/videos_1/douyin_6700861687563570439.mp4"
    main(get_model(), sample_video, preprocessed_flag=False)
|