wli3221134 commited on
Commit
0aed679
·
verified ·
1 Parent(s): 24fa17c

Upload 8 files

Browse files
Files changed (2) hide show
  1. app.py +44 -12
  2. env.sh +26 -0
app.py CHANGED
@@ -27,46 +27,78 @@ checkpoint_path = load_model()
27
  @spaces.GPU
28
  def detect_on_gpu(dataset):
29
  """在 GPU 上进行音频伪造检测"""
30
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
- device = 'cpu'
 
 
 
 
32
  model = Wav2Vec2BERT_Llama().to(device)
33
 
34
- # 加载模型权重
35
  checkpoint = torch.load(checkpoint_path, map_location=device)
36
  model_state_dict = checkpoint['model_state_dict']
37
  threshold = 0.9996
 
38
 
39
  # 处理模型状态字典的 key
40
  if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
 
41
  model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
42
  elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
 
43
  model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
44
 
45
  model.load_state_dict(model_state_dict)
46
  model.eval()
 
47
 
 
48
  with torch.no_grad():
49
- for batch in dataset:
 
 
 
50
  main_features = {
51
  'input_features': batch['main_features']['input_features'].to(device),
52
  'attention_mask': batch['main_features']['attention_mask'].to(device)
53
  }
54
- prompt_features = [{
55
- 'input_features': pf['input_features'].to(device),
56
- 'attention_mask': pf['attention_mask'].to(device)
57
- } for pf in batch['prompt_features']]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- prompt_labels = batch['prompt_labels'].to(device)
60
  outputs = model({
61
  'main_features': main_features,
62
  'prompt_features': prompt_features,
63
  'prompt_labels': prompt_labels
64
  })
65
 
 
66
  avg_scores = outputs['avg_logits'].softmax(dim=-1)
67
  deepfake_scores = avg_scores[:, 1].cpu()
68
- is_fake = deepfake_scores[0] > threshold
69
- result = {"is_fake": is_fake, "confidence": deepfake_scores[0]}
 
 
 
 
 
70
  return result
71
 
72
  # 修改音频伪造检测主函数
@@ -84,7 +116,7 @@ def audio_deepfake_detection(demonstrations, query_audio_path):
84
 
85
  return {
86
  "Is AI Generated": result["is_fake"],
87
- "Confidence": f"{result['confidence']:.2f}%"
88
  }
89
 
90
  # Gradio 界面
 
27
  @spaces.GPU
28
  def detect_on_gpu(dataset):
29
  """在 GPU 上进行音频伪造检测"""
30
+ print("\n=== 开始音频检测 ===")
31
+
32
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+ print(f"使用设备: {device}")
34
+
35
+ print("正在初始化模型...")
36
  model = Wav2Vec2BERT_Llama().to(device)
37
 
38
+ print(f"正在加载模型权重: {checkpoint_path}")
39
  checkpoint = torch.load(checkpoint_path, map_location=device)
40
  model_state_dict = checkpoint['model_state_dict']
41
  threshold = 0.9996
42
+ print(f"检测阈值设置为: {threshold}")
43
 
44
  # 处理模型状态字典的 key
45
  if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
46
+ print("添加 'module.' 前缀到状态字典的 key")
47
  model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
48
  elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
49
+ print("移除状态字典 key 中的 'module.' 前缀")
50
  model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
51
 
52
  model.load_state_dict(model_state_dict)
53
  model.eval()
54
+ print("模型加载完成,进入评估模式")
55
 
56
+ print("\n开始处理音频数据...")
57
  with torch.no_grad():
58
+ for batch_idx, batch in enumerate(dataset):
59
+ print(f"\n处理批次 {batch_idx + 1}")
60
+
61
+ print("准备主特征...")
62
  main_features = {
63
  'input_features': batch['main_features']['input_features'].to(device),
64
  'attention_mask': batch['main_features']['attention_mask'].to(device)
65
  }
66
+ print(f"主特征形状: {main_features['input_features'].shape}")
67
+
68
+ if len(batch['prompt_features']) > 0:
69
+ print("\n准备提示特征...")
70
+ prompt_features = [{
71
+ 'input_features': pf['input_features'].to(device),
72
+ 'attention_mask': pf['attention_mask'].to(device)
73
+ } for pf in batch['prompt_features']]
74
+ print(f"提示特征数量: {len(prompt_features)}")
75
+ print(f"第一个提示特征形状: {prompt_features[0]['input_features'].shape}")
76
+
77
+ print("\n准备提示标签...")
78
+ prompt_labels = batch['prompt_labels'].to(device)
79
+ print(f"提示标签形状: {prompt_labels.shape}")
80
+ print(f"提示标签值: {prompt_labels}")
81
+ else:
82
+ prompt_features = []
83
+ prompt_labels = []
84
 
85
+ print("\n执行模型推理...")
86
  outputs = model({
87
  'main_features': main_features,
88
  'prompt_features': prompt_features,
89
  'prompt_labels': prompt_labels
90
  })
91
 
92
+ print("\n处理模型输出...")
93
  avg_scores = outputs['avg_logits'].softmax(dim=-1)
94
  deepfake_scores = avg_scores[:, 1].cpu()
95
+ is_fake = deepfake_scores[0].item() > threshold
96
+
97
+ result = {"is_fake": is_fake, "confidence": deepfake_scores[0] if is_fake else 1-deepfake_scores[0]}
98
+
99
+ break
100
+
101
+ print("\n=== 检测完成 ===")
102
  return result
103
 
104
  # 修改音频伪造检测主函数
 
116
 
117
  return {
118
  "Is AI Generated": result["is_fake"],
119
+ "Confidence": f"{100*result['confidence']:.2f}%"
120
  }
121
 
122
  # Gradio 界面
env.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # Raise error if any command fails
7
+ set -e
8
+
9
+ # Install ffmpeg in Linux
10
+ conda install -c conda-forge ffmpeg
11
+
12
+ # Pip packages
13
+ pip install setuptools ruamel.yaml tqdm colorama easydict tabulate loguru json5 Cython unidecode inflect argparse g2p_en tgt librosa==0.9.1 matplotlib typeguard einops omegaconf hydra-core humanfriendly pandas munch
14
+
15
+ pip install tensorboard tensorboardX torch==2.0.1 torchaudio==2.0.2 torchvision==0.15.2 accelerate==0.24.1 transformers==4.41.2 diffusers praat-parselmouth audiomentations pedalboard ffmpeg-python==0.2.0 pyworld diffsptk==1.0.1 nnAudio unidecode inflect ptwt
16
+
17
+ pip install https://github.com/vBaiCai/python-pesq/archive/master.zip
18
+
19
+ pip install fairseq
20
+
21
+ pip install git+https://github.com/lhotse-speech/lhotse
22
+
23
+ pip install black==24.1.1
24
+
25
+ # Uninstall nvidia-cublas-cu11 if there exist some bugs about CUDA version
26
+ # pip uninstall nvidia-cublas-cu11