Spaces:
Running
on
Zero
Running
on
Zero
wli3221134
commited on
Upload 8 files
Browse files
app.py
CHANGED
@@ -27,46 +27,78 @@ checkpoint_path = load_model()
|
|
27 |
@spaces.GPU
|
28 |
def detect_on_gpu(dataset):
|
29 |
"""在 GPU 上进行音频伪造检测"""
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
32 |
model = Wav2Vec2BERT_Llama().to(device)
|
33 |
|
34 |
-
|
35 |
checkpoint = torch.load(checkpoint_path, map_location=device)
|
36 |
model_state_dict = checkpoint['model_state_dict']
|
37 |
threshold = 0.9996
|
|
|
38 |
|
39 |
# 处理模型状态字典的 key
|
40 |
if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
|
|
|
41 |
model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
|
42 |
elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
|
|
|
43 |
model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
|
44 |
|
45 |
model.load_state_dict(model_state_dict)
|
46 |
model.eval()
|
|
|
47 |
|
|
|
48 |
with torch.no_grad():
|
49 |
-
for batch in dataset:
|
|
|
|
|
|
|
50 |
main_features = {
|
51 |
'input_features': batch['main_features']['input_features'].to(device),
|
52 |
'attention_mask': batch['main_features']['attention_mask'].to(device)
|
53 |
}
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
-
|
60 |
outputs = model({
|
61 |
'main_features': main_features,
|
62 |
'prompt_features': prompt_features,
|
63 |
'prompt_labels': prompt_labels
|
64 |
})
|
65 |
|
|
|
66 |
avg_scores = outputs['avg_logits'].softmax(dim=-1)
|
67 |
deepfake_scores = avg_scores[:, 1].cpu()
|
68 |
-
is_fake = deepfake_scores[0] > threshold
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
70 |
return result
|
71 |
|
72 |
# 修改音频伪造检测主函数
|
@@ -84,7 +116,7 @@ def audio_deepfake_detection(demonstrations, query_audio_path):
|
|
84 |
|
85 |
return {
|
86 |
"Is AI Generated": result["is_fake"],
|
87 |
-
"Confidence": f"{result['confidence']:.2f}%"
|
88 |
}
|
89 |
|
90 |
# Gradio 界面
|
|
|
27 |
@spaces.GPU
|
28 |
def detect_on_gpu(dataset):
|
29 |
"""在 GPU 上进行音频伪造检测"""
|
30 |
+
print("\n=== 开始音频检测 ===")
|
31 |
+
|
32 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
33 |
+
print(f"使用设备: {device}")
|
34 |
+
|
35 |
+
print("正在初始化模型...")
|
36 |
model = Wav2Vec2BERT_Llama().to(device)
|
37 |
|
38 |
+
print(f"正在加载模型权重: {checkpoint_path}")
|
39 |
checkpoint = torch.load(checkpoint_path, map_location=device)
|
40 |
model_state_dict = checkpoint['model_state_dict']
|
41 |
threshold = 0.9996
|
42 |
+
print(f"检测阈值设置为: {threshold}")
|
43 |
|
44 |
# 处理模型状态字典的 key
|
45 |
if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
|
46 |
+
print("添加 'module.' 前缀到状态字典的 key")
|
47 |
model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
|
48 |
elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
|
49 |
+
print("移除状态字典 key 中的 'module.' 前缀")
|
50 |
model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
|
51 |
|
52 |
model.load_state_dict(model_state_dict)
|
53 |
model.eval()
|
54 |
+
print("模型加载完成,进入评估模式")
|
55 |
|
56 |
+
print("\n开始处理音频数据...")
|
57 |
with torch.no_grad():
|
58 |
+
for batch_idx, batch in enumerate(dataset):
|
59 |
+
print(f"\n处理批次 {batch_idx + 1}")
|
60 |
+
|
61 |
+
print("准备主特征...")
|
62 |
main_features = {
|
63 |
'input_features': batch['main_features']['input_features'].to(device),
|
64 |
'attention_mask': batch['main_features']['attention_mask'].to(device)
|
65 |
}
|
66 |
+
print(f"主特征形状: {main_features['input_features'].shape}")
|
67 |
+
|
68 |
+
if len(batch['prompt_features']) > 0:
|
69 |
+
print("\n准备提示特征...")
|
70 |
+
prompt_features = [{
|
71 |
+
'input_features': pf['input_features'].to(device),
|
72 |
+
'attention_mask': pf['attention_mask'].to(device)
|
73 |
+
} for pf in batch['prompt_features']]
|
74 |
+
print(f"提示特征数量: {len(prompt_features)}")
|
75 |
+
print(f"第一个提示特征形状: {prompt_features[0]['input_features'].shape}")
|
76 |
+
|
77 |
+
print("\n准备提示标签...")
|
78 |
+
prompt_labels = batch['prompt_labels'].to(device)
|
79 |
+
print(f"提示标签形状: {prompt_labels.shape}")
|
80 |
+
print(f"提示标签值: {prompt_labels}")
|
81 |
+
else:
|
82 |
+
prompt_features = []
|
83 |
+
prompt_labels = []
|
84 |
|
85 |
+
print("\n执行模型推理...")
|
86 |
outputs = model({
|
87 |
'main_features': main_features,
|
88 |
'prompt_features': prompt_features,
|
89 |
'prompt_labels': prompt_labels
|
90 |
})
|
91 |
|
92 |
+
print("\n处理模型输出...")
|
93 |
avg_scores = outputs['avg_logits'].softmax(dim=-1)
|
94 |
deepfake_scores = avg_scores[:, 1].cpu()
|
95 |
+
is_fake = deepfake_scores[0].item() > threshold
|
96 |
+
|
97 |
+
result = {"is_fake": is_fake, "confidence": deepfake_scores[0] if is_fake else 1-deepfake_scores[0]}
|
98 |
+
|
99 |
+
break
|
100 |
+
|
101 |
+
print("\n=== 检测完成 ===")
|
102 |
return result
|
103 |
|
104 |
# 修改音频伪造检测主函数
|
|
|
116 |
|
117 |
return {
|
118 |
"Is AI Generated": result["is_fake"],
|
119 |
+
"Confidence": f"{100*result['confidence']:.2f}%"
|
120 |
}
|
121 |
|
122 |
# Gradio 界面
|
env.sh
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# Raise error if any command fails
|
7 |
+
set -e
|
8 |
+
|
9 |
+
# Install ffmpeg in Linux
|
10 |
+
conda install -c conda-forge ffmpeg
|
11 |
+
|
12 |
+
# Pip packages
|
13 |
+
pip install setuptools ruamel.yaml tqdm colorama easydict tabulate loguru json5 Cython unidecode inflect argparse g2p_en tgt librosa==0.9.1 matplotlib typeguard einops omegaconf hydra-core humanfriendly pandas munch
|
14 |
+
|
15 |
+
pip install tensorboard tensorboardX torch==2.0.1 torchaudio==2.0.2 torchvision==0.15.2 accelerate==0.24.1 transformers==4.41.2 diffusers praat-parselmouth audiomentations pedalboard ffmpeg-python==0.2.0 pyworld diffsptk==1.0.1 nnAudio unidecode inflect ptwt
|
16 |
+
|
17 |
+
pip install https://github.com/vBaiCai/python-pesq/archive/master.zip
|
18 |
+
|
19 |
+
pip install fairseq
|
20 |
+
|
21 |
+
pip install git+https://github.com/lhotse-speech/lhotse
|
22 |
+
|
23 |
+
pip install black==24.1.1
|
24 |
+
|
25 |
+
# Uninstall nvidia-cublas-cu11 if there exist some bugs about CUDA version
|
26 |
+
# pip uninstall nvidia-cublas-cu11
|