Spaces:
Running
Running
import os | |
import torch | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader | |
import numpy as np | |
from tqdm import tqdm | |
from FakeVD.code_test.utils.metrics import * | |
from FakeVD.code_test.models.SVFEND import SVFENDModel | |
from FakeVD.code_test.utils.dataloader import SVFENDDataset | |
from FakeVD.code_test.run import _init_fn, SVFEND_collate_fn | |
# from VGGish_Feature_Extractor.my_vggish_folder_fun import vggish_audio | |
from FakeVD.code_test.VGGish_Feature_Extractor.my_vggish_fun import vggish_audio, load_model_vggish | |
from FakeVD.code_test.VGG19_Feature_Extractor.vgg19_feature import process_video as vgg19_frame | |
from FakeVD.code_test.VGG19_Feature_Extractor.vgg19_feature import load_model_vgg19 | |
from FakeVD.code_test.C3D_Feature_Extractor.feature_extractor_vid import feature_extractor as c3d_video | |
from FakeVD.code_test.C3D_Feature_Extractor.feature_extractor_vid import load_model_c3d | |
from FakeVD.code_test.Text_Feature_Extractor.main import video_work as asr_text | |
from FakeVD.code_test.Text_Feature_Extractor.wav2text import wav2text | |
def load_model(checkpoint_path):
    """Create the SVFEND detection model and restore its saved weights.

    Args:
        checkpoint_path: Path of the saved state dict handed to ``torch.load``.

    Returns:
        The ``SVFENDModel`` instance, already switched to evaluation mode.
    """
    # Stored tensors are mapped to the GPU when one is available, else the CPU.
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    detector = SVFENDModel(bert_model='bert-base-chinese', fea_dim=128, dropout=0.1)
    weights = torch.load(checkpoint_path, map_location=target_device)
    # strict=False: keys that differ between checkpoint and model are tolerated
    # instead of raising an error.
    detector.load_state_dict(weights, strict=False)
    detector.eval()
    return detector
def get_model(checkpoint_path='./FakeVD/code_test/checkpoints/SVFEND/SVFEND/_test_epoch4_0.7943'):
    """Load the detection model together with every feature-extraction model.

    Args:
        checkpoint_path: Location of the stored detection-model checkpoint.

    Returns:
        A dict holding the loaded models under the keys ``model_main``,
        ``model_vggish``, ``model_vgg19``, ``model_c3d`` and ``model_text``.
    """
    # The entries are evaluated top to bottom, so the models are loaded in the
    # order listed here.
    return {
        'model_main': load_model(checkpoint_path),
        'model_vggish': load_model_vggish(),
        'model_vgg19': load_model_vgg19(),
        'model_c3d': load_model_c3d(),
        'model_text': wav2text(),
    }
# Label convention: 0 when item['annotation'] == '真' (real), otherwise 1 (fake).
def test(model, dataloader):
    """Run ``model`` over every batch and gather labels, predictions and probabilities.

    Args:
        model: Module invoked as ``model(**batch)``; it must return a pair
            whose first element holds the per-class scores.
        dataloader: Iterable of dict batches whose values support ``.to(device)``;
            each batch must contain a ``'label'`` entry.

    Returns:
        A tuple ``(label, pred, prob)`` of lists: the batch labels, the index of
        the highest score per sample, and the softmax probabilities per sample.
    """
    def as_list(tensor):
        # Detach and move to host memory before converting to plain Python values.
        return tensor.detach().cpu().numpy().tolist()

    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(run_device)
    model.eval()

    all_labels, all_preds, all_probs = [], [], []
    for batch in tqdm(dataloader):
        with torch.no_grad():
            # Move every entry of the batch onto the device the model runs on.
            for key, value in batch.items():
                batch[key] = value.to(run_device)
            targets = batch['label']

            # The second element returned by the model is not used here.
            scores, _ = model(**batch)
            predicted = torch.max(scores, 1).indices
            probabilities = F.softmax(scores, dim=1)  # softmax probabilities

            all_labels.extend(as_list(targets))
            all_preds.extend(as_list(predicted))
            all_probs.extend(as_list(probabilities))
    return (all_labels, all_preds, all_probs)
def main(models,
         video_file_path,
         preprocessed_flag=False,
         feature_path='./FakeVD/code_test/preprocessed_feature'):
    """Classify a single video and return the probability that it is fake.

    Args:
        models: Dict with the keys ``model_main``, ``model_vggish``,
            ``model_vgg19``, ``model_c3d`` and ``model_text``.
        video_file_path: Path of the video file; its file name without the
            extension is used as the video ID.
        preprocessed_flag: When true, the feature-extraction calls are skipped.
        feature_path: Directory under which the feature locations are built.

    Returns:
        The softmax probability at class index 1 (fake) for the video.
    """
    # All five models are looked up first, so a missing key fails immediately.
    model_main, model_vggish, model_vgg19, model_c3d, model_text = (
        models[key]
        for key in ('model_main', 'model_vggish', 'model_vgg19',
                    'model_c3d', 'model_text')
    )

    # Split into containing folder and file name; the name without its
    # extension is the video ID.
    video_folder_path, video_file_name = os.path.split(video_file_path)
    vid = os.path.splitext(video_file_name)[0]

    # Feature locations for each modality.
    data_paths = {
        'VGGish_audio': os.path.join(feature_path, f'{vid}.pkl'),
        'C3D_video': os.path.join(feature_path, 'C3D/'),
        'VGG19_frame': os.path.join(feature_path, 'VGG19/'),
        'ASR_text': os.path.join(feature_path, f'ASR/{vid}.json'),
    }

    # Feature extraction, only when the video has not been preprocessed yet.
    if not preprocessed_flag:
        vggish_audio(model_vggish, video_file_path, data_paths['VGGish_audio'])
        vgg19_frame(model_vgg19, video_file_name, video_folder_path,
                    data_paths['VGG19_frame'])
        c3d_video(model_c3d, data_paths['C3D_video'], video_folder_path,
                  video_file_name)
        asr_text(model_text, model_vggish, video_file_path,
                 data_paths['ASR_text'])

    # Dataset and DataLoader containing just this one video ID.
    loader = DataLoader(
        SVFENDDataset([vid], data_paths),
        batch_size=1,
        num_workers=0,
        pin_memory=True,
        shuffle=False,
        worker_init_fn=_init_fn,
        collate_fn=SVFEND_collate_fn,
    )

    # Run the prediction; the collected labels are not needed here.
    _, preds, probs = test(model_main, loader)
    annotation = '真' if preds[0] == 0 else '假'  # '真' = real, '假' = fake
    real_prob = probs[0][0]  # probability of "real"
    fake_prob = probs[0][1]  # probability of "fake"

    # Print the prediction result.
    print(annotation, real_prob, fake_prob)
    return fake_prob
if __name__ == "__main__":
    # Demo run on one sample video, treated as not yet preprocessed so that
    # the feature-extraction step is executed.
    sample_video = "./FakeVD/dataset/videos_1/douyin_6700861687563570439.mp4"
    loaded_models = get_model()
    main(loaded_models, sample_video, preprocessed_flag=False)