import importlib.util
import sys
from typing import Optional

import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import Resample

# Sample rate (Hz) that transcribe() resamples every input to before inference.
TARGET_SAMPLE_RATE = 16000


def import_wav2vec2():
    """Load the local ``wav2vec2.py`` module once and return its model classes.

    The module is loaded from the file ``wav2vec2.py`` (resolved relative to
    the current working directory) and cached in ``sys.modules`` under the
    name ``wav2vec2`` so that repeated calls reuse the same module object
    instead of executing the file again (avoids duplicate registration).

    Returns:
        A ``(Wav2Vec2Model, Wav2Vec2Config)`` tuple of the classes defined in
        that module.

    Raises:
        Whatever executing ``wav2vec2.py`` raises (e.g. ``FileNotFoundError``
        when the file is missing). On failure the partially initialised
        module is removed from ``sys.modules`` so a later retry starts clean.
    """
    if 'wav2vec2' not in sys.modules:
        spec = importlib.util.spec_from_file_location("wav2vec2", "wav2vec2.py")
        wav2vec2 = importlib.util.module_from_spec(spec)
        # Register before exec so imports triggered while the module body runs
        # resolve to this same module object.
        sys.modules['wav2vec2'] = wav2vec2
        try:
            spec.loader.exec_module(wav2vec2)
        except BaseException:
            # Do not leave a half-initialised module cached.
            sys.modules.pop('wav2vec2', None)
            raise
    else:
        wav2vec2 = sys.modules['wav2vec2']

    Wav2Vec2Model = wav2vec2.Wav2Vec2Model
    Wav2Vec2Config = wav2vec2.Wav2Vec2Config
    return Wav2Vec2Model, Wav2Vec2Config


Wav2Vec2Model, Wav2Vec2Config = import_wav2vec2()

# URL of the fine-tuned checkpoint.
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/finetune_large_kespeech.pt"

# Download the checkpoint file.
print("Downloading model file...")
torch.hub.download_url_to_file(model_path, 'large.pt')
print("Model file downloaded.")

# Build the model from a default configuration.
config = Wav2Vec2Config()
model = Wav2Vec2Model.build_model(config)

# Load the checkpoint.
# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code. This is only acceptable because the file comes from the URL above;
# never point model_path at an untrusted source.
print("Loading model checkpoint...")
checkpoint = torch.load('large.pt', map_location=torch.device('cpu'))
print("Checkpoint keys:", checkpoint.keys())

# Print the keys found in the checkpoint and pick the state dict out of it.
if 'model' in checkpoint:
    state_dict = checkpoint['model']
    print("Model state_dict keys:", state_dict.keys())
else:
    print("Key 'model' not found in checkpoint.")
    state_dict = checkpoint

# Load the weights into the model.
# NOTE(review): a failure here is only printed; the app then keeps running
# with whatever weights build_model() produced. Confirm this best-effort
# behaviour is intended.
try:
    model.load_state_dict(state_dict)
    print("Model state_dict loaded successfully.")
except Exception as e:
    print("Error loading model state_dict:", str(e))

model.eval()


def transcribe(audio: Optional[str]) -> str:
    """Transcribe the audio file at path *audio* and return the text.

    Args:
        audio: Path to an audio file, as supplied by the Gradio ``Audio``
            component configured with ``type="filepath"``. ``None`` (nothing
            recorded or uploaded) yields an empty string.

    Returns:
        The decoded transcription string.
    """
    if audio is None:
        # Nothing was submitted; there is no file to load.
        return ""

    print("Transcribing audio...")
    waveform, sample_rate = torchaudio.load(audio)

    if sample_rate != TARGET_SAMPLE_RATE:
        resample = Resample(orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE)
        waveform = resample(waveform)

    # torchaudio.load returns (channels, time). Average the channels so that
    # multi-channel input also ends up 1-D; a plain squeeze() only works for
    # mono and would leave stereo as (2, time).
    waveform = waveform.mean(dim=0)

    # Shape the input the way the model call below expects it.
    input_values = waveform.unsqueeze(0)  # (batch_size, seq_len)

    with torch.no_grad():
        outputs = model.extract_features(input_values, padding_mask=None)
        # NOTE(review): treated as logits here; confirm that "x" really holds
        # per-token scores rather than encoder features.
        logits = outputs["x"]
        predicted_ids = torch.argmax(logits, dim=-1)

    # Decode the predicted ids to characters.
    # NOTE(review): chr() maps ids straight to Unicode code points; this
    # presumably needs the model's own vocabulary instead - verify.
    transcription = ''.join([chr(i) for i in predicted_ids[0].tolist()])
    print("Transcription:", transcription)
    return transcription


# Build the Gradio interface.
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="TeleSpeech ASR",
    description="Upload an audio file or record your voice to transcribe speech to text using the TeleSpeech ASR model."
)

print("Launching Gradio interface...")
iface.launch()