File size: 5,930 Bytes
711b041
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np

from tqdm import tqdm
from FakeVD.code_test.utils.metrics import *

from FakeVD.code_test.models.SVFEND import SVFENDModel
from FakeVD.code_test.utils.dataloader import SVFENDDataset
from FakeVD.code_test.run import _init_fn, SVFEND_collate_fn

# from VGGish_Feature_Extractor.my_vggish_folder_fun import vggish_audio
from FakeVD.code_test.VGGish_Feature_Extractor.my_vggish_fun import vggish_audio, load_model_vggish
from FakeVD.code_test.VGG19_Feature_Extractor.vgg19_feature import process_video as vgg19_frame
from FakeVD.code_test.VGG19_Feature_Extractor.vgg19_feature import load_model_vgg19
from FakeVD.code_test.C3D_Feature_Extractor.feature_extractor_vid import feature_extractor as c3d_video
from FakeVD.code_test.C3D_Feature_Extractor.feature_extractor_vid import load_model_c3d
from FakeVD.code_test.Text_Feature_Extractor.main import video_work as asr_text
from FakeVD.code_test.Text_Feature_Extractor.wav2text import wav2text

def load_model(checkpoint_path):
    """Load the main SVFEND detection model from a checkpoint.

    Args:
        checkpoint_path: Path to a saved ``state_dict`` checkpoint file.

    Returns:
        An ``SVFENDModel`` switched to eval mode. Weights are mapped to the
        available device at load time; callers (e.g. ``test()``) still move
        the module itself with ``model.to(device)``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SVFENDModel(bert_model='bert-base-chinese', fea_dim=128, dropout=0.1)
    # strict=False: the checkpoint may not exactly match the current model
    # definition; missing/unexpected keys are tolerated rather than raising.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device),
                          strict=False)
    model.eval()
    return model

def get_model(checkpoint_path='./FakeVD/code_test/checkpoints/SVFEND/SVFEND/_test_epoch4_0.7943'):
    """Build every model required by the detection pipeline.

    Args:
        checkpoint_path: Location of the main SVFEND classifier checkpoint.

    Returns:
        Dict mapping role names ('model_main', 'model_vggish', 'model_vgg19',
        'model_c3d', 'model_text') to the corresponding loaded models.
    """
    return {
        'model_main': load_model(checkpoint_path),
        'model_vggish': load_model_vggish(),
        'model_vgg19': load_model_vgg19(),
        'model_c3d': load_model_c3d(),
        'model_text': wav2text(),
    }



# label = 0 if item['annotation']=='真' else 1
def test(model, dataloader):
    """Run inference over *dataloader* and collect per-sample results.

    Args:
        model: Classifier invoked as ``model(**batch)``; expected to return
            ``(logits, features)``.
        dataloader: Iterable of dict batches of tensors, each containing a
            ``'label'`` entry.

    Returns:
        Tuple ``(label, pred, prob)`` of plain Python lists: ground-truth
        labels, argmax class predictions, and per-class softmax probabilities.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    labels, preds, probs = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Move every tensor in the batch to the target device in place.
            for key, tensor in batch.items():
                batch[key] = tensor.to(device)

            outputs, _ = model(**batch)
            _, batch_preds = torch.max(outputs, 1)
            batch_probs = F.softmax(outputs, dim=1)

            labels.extend(batch['label'].detach().cpu().numpy().tolist())
            preds.extend(batch_preds.detach().cpu().numpy().tolist())
            probs.extend(batch_probs.detach().cpu().numpy().tolist())

    return (labels, preds, probs)

def main(models,
         video_file_path,
         preprocessed_flag=False,
         feature_path='./FakeVD/code_test/preprocessed_feature'):
    """Run the full fake-video detection pipeline on one video file.

    Args:
        models: Dict from ``get_model()`` holding the main classifier and the
            four feature-extraction models.
        video_file_path: Path to the input video.
        preprocessed_flag: When True, skip feature extraction and reuse
            features already stored under *feature_path*.
        feature_path: Root directory for extracted features.

    Returns:
        Softmax probability that the video is fake (class 1).
    """
    model_main = models['model_main']
    model_vggish = models['model_vggish']
    model_vgg19 = models['model_vgg19']
    model_c3d = models['model_c3d']
    model_text = models['model_text']

    # Split the input path into folder, file name, and stem; the stem is
    # used as the video id throughout.
    video_folder_path = os.path.dirname(video_file_path)
    video_file_name = os.path.basename(video_file_path)
    vid = os.path.splitext(video_file_name)[0]
    vids = [vid]

    # Per-extractor storage locations under the feature root.
    # NOTE(review): the VGGish .pkl sits directly in the root while the
    # others use subdirectories — presumably matches the extractors' own
    # path handling; verify against those modules.
    VGGish_audio_feature_path = os.path.join(feature_path, vid + '.pkl')
    C3D_video_feature_path = os.path.join(feature_path, 'C3D/')
    VGG19_frame_feature_path = os.path.join(feature_path, 'VGG19/')
    asr_text_feature_path = os.path.join(feature_path, 'ASR/' + vid + '.json')

    # Extract features unless a previous run already produced them.
    if not preprocessed_flag:
        vggish_audio(model_vggish, video_file_path, VGGish_audio_feature_path)
        vgg19_frame(model_vgg19, video_file_name, video_folder_path, VGG19_frame_feature_path)
        c3d_video(model_c3d, C3D_video_feature_path, video_folder_path, video_file_name)
        asr_text(model_text, model_vggish, video_file_path, asr_text_feature_path)

    data_paths = {
        'VGGish_audio': VGGish_audio_feature_path,
        'C3D_video': C3D_video_feature_path,
        'VGG19_frame': VGG19_frame_feature_path,
        'ASR_text': asr_text_feature_path,
    }

    # Wrap the single video in a Dataset/DataLoader so the shared test()
    # inference loop can be reused unchanged.
    dataset = SVFENDDataset(vids, data_paths)
    dataloader = DataLoader(dataset, batch_size=1,
                            num_workers=0,
                            pin_memory=True,
                            shuffle=False,
                            worker_init_fn=_init_fn,
                            collate_fn=SVFEND_collate_fn)

    predictions = test(model_main, dataloader)
    annotation = '真' if predictions[1][0] == 0 else '假'  # class 0 = real, 1 = fake
    prob_softmax = predictions[2]
    annotation_prob = prob_softmax[0][0]   # probability of "real"
    annotation_prob1 = prob_softmax[0][1]  # probability of "fake"
    print(annotation, annotation_prob, annotation_prob1)

    return annotation_prob1


if __name__ == "__main__":
    # Demo entry point: detect a sample video from scratch (features have
    # not been extracted yet, so preprocessed_flag stays False).
    preprocessed_flag = False
    video_file_path = "./FakeVD/dataset/videos_1/douyin_6700861687563570439.mp4"
    main(get_model(), video_file_path, preprocessed_flag)