# NOTE: `spaces` is deliberately imported first; on Hugging Face ZeroGPU it is
# expected to be imported before torch / any CUDA-related package.
import spaces
import gradio as gr
import os
import torch
from model import Wav2Vec2BERT_Llama  # project-local model module
import dataset  # project-local dataset module
from huggingface_hub import hf_hub_download

# Decision threshold applied to the softmax score of class index 1.
_FAKE_THRESHOLD = 0.9996

# Key prefix that torch.nn.DataParallel-style wrappers add to state-dict keys.
_MODULE_PREFIX = 'module.'


@spaces.GPU
def dummy():
    """Do nothing; exists only as a `spaces.GPU`-decorated placeholder."""
    pass


def load_model():
    """Download the detection checkpoint from the Hugging Face Hub.

    Returns:
        The local filesystem path of the downloaded checkpoint file.

    Raises:
        FileNotFoundError: If the path returned by the Hub client does not
            exist on disk.
    """
    checkpoint_path = hf_hub_download(
        repo_id="amphion/deepfake_detection",
        filename="checkpoints_wav2vec2bert_ft_llama_labels_ASVspoof2019_RandomPrompts_6/model_checkpoint.pth",
        repo_type="model"
    )
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    return checkpoint_path


# Resolved once at import time and read by detect_on_gpu().
checkpoint_path = load_model()


def _align_module_prefix(model, state_dict):
    """Return *state_dict* with its 'module.' key prefix matched to *model*."""
    model_is_wrapped = hasattr(model, 'module')
    keys_are_prefixed = any(key.startswith(_MODULE_PREFIX) for key in state_dict)

    if model_is_wrapped and not keys_are_prefixed:
        return {_MODULE_PREFIX + key: value for key, value in state_dict.items()}

    if not model_is_wrapped and keys_are_prefixed:
        # Strip only the *leading* prefix; a blanket str.replace would also
        # mangle keys that merely contain "module." somewhere in the middle.
        return {
            (key[len(_MODULE_PREFIX):] if key.startswith(_MODULE_PREFIX) else key): value
            for key, value in state_dict.items()
        }

    return state_dict


def _features_to_device(features, device):
    """Move the input features and attention mask of one sample to *device*."""
    return {
        'input_features': features['input_features'].to(device),
        'attention_mask': features['attention_mask'].to(device),
    }


@spaces.GPU
def detect_on_gpu(dataset):
    """Run audio deepfake detection on the batches yielded by *dataset*.

    Note that the parameter name shadows the module-level ``dataset`` import
    inside this function; it is kept for interface compatibility.

    Args:
        dataset: Iterable of batches. Each batch must provide
            ``main_features``, ``prompt_features`` and ``prompt_labels``.

    Returns:
        A dict with ``is_fake`` (bool) and ``confidence`` (float, the softmax
        score of class index 1 for the first item of the batch).

    Raises:
        ValueError: If *dataset* yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Wav2Vec2BERT_Llama().to(device)

    # Load the model weights.
    # NOTE(security): torch.load unpickles the file. The checkpoint comes from
    # a fixed Hub repository; consider `weights_only=True` if the checkpoint
    # format allows it.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_state_dict = _align_module_prefix(model, checkpoint['model_state_dict'])

    model.load_state_dict(model_state_dict)
    model.eval()

    result = None
    with torch.no_grad():
        # Presumably the demo dataset yields a single batch (the query audio);
        # if it yields several, the result of the last one is returned.
        # TODO confirm against dataset.DemoDataset.
        for batch in dataset:
            main_features = _features_to_device(batch['main_features'], device)
            prompt_features = [
                _features_to_device(pf, device) for pf in batch['prompt_features']
            ]
            prompt_labels = batch['prompt_labels'].to(device)

            outputs = model({
                'main_features': main_features,
                'prompt_features': prompt_features,
                'prompt_labels': prompt_labels
            })

            avg_scores = outputs['avg_logits'].softmax(dim=-1)
            deepfake_scores = avg_scores[:, 1].cpu()
            query_score = deepfake_scores[0]

            # Return plain Python scalars so the result can be pickled across
            # process boundaries and serialized as JSON by the UI.
            result = {
                "is_fake": bool((query_score > _FAKE_THRESHOLD).item()),
                "confidence": float(query_score),
            }

    if result is None:
        raise ValueError("dataset yielded no batches; nothing to detect")
    return result


def audio_deepfake_detection(demonstrations, query_audio_path):
    """Classify the query audio using labelled demonstration audios.

    Args:
        demonstrations: Sequence of ``(audio_path, label)`` pairs. Pairs with
            a missing path or label are skipped as a unit so that remaining
            paths and labels stay aligned.
        query_audio_path: File path of the audio to classify.

    Returns:
        A dict with ``"Is AI Generated"`` (bool) and ``"Confidence"`` (the
        deepfake score formatted as a percentage string).

    Raises:
        gr.Error: If no query audio was provided.
    """
    if query_audio_path is None:
        raise gr.Error("Please upload a query audio for detection.")

    # Filter path and label together: filtering them independently would pair
    # a later audio with an earlier slot's label when a slot is left empty.
    complete_pairs = [
        (audio[0], audio[1])
        for audio in demonstrations
        if audio[0] is not None and audio[1] is not None
    ]
    demonstration_paths = [path for path, _ in complete_pairs]
    demonstration_labels = [label for _, label in complete_pairs]

    # Build the dataset.
    audio_dataset = dataset.DemoDataset(demonstration_paths, demonstration_labels, query_audio_path)

    # Run detection on the GPU.
    result = detect_on_gpu(audio_dataset)

    return {
        "Is AI Generated": bool(result["is_fake"]),
        # The score is a probability in [0, 1]; scale it to match the "%" suffix.
        "Confidence": f"{float(result['confidence']) * 100:.2f}%"
    }


def gradio_ui():
    """Build and return the Gradio interface for the detection demo."""

    def detection_wrapper(demonstration_audio1, label1, demonstration_audio2, label2, demonstration_audio3, label3, query_audio):
        # Regroup the flat Gradio inputs into (audio_path, label) pairs.
        demonstrations = [
            (demonstration_audio1, label1),
            (demonstration_audio2, label2),
            (demonstration_audio3, label3),
        ]
        return audio_deepfake_detection(demonstrations, query_audio)

    interface = gr.Interface(
        fn=detection_wrapper,
        inputs=[
            gr.Audio(sources=["upload"], type="filepath", label="Demonstration Audio 1"),
            gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 1"),
            gr.Audio(sources=["upload"], type="filepath", label="Demonstration Audio 2"),
            gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 2"),
            gr.Audio(sources=["upload"], type="filepath", label="Demonstration Audio 3"),
            gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 3"),
            gr.Audio(sources=["upload"], type="filepath", label="Query Audio (Audio for Detection)")
        ],
        outputs=gr.JSON(label="Detection Results"),
        title="Audio Deepfake Detection System",
        description="Upload demonstration audios and a query audio to detect whether the query is AI-generated.",
    )
    return interface


if __name__ == "__main__":
    demo = gradio_ui()
    demo.launch()