# Hugging Face Spaces status banner captured with this file: Runtime error
import os

import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import Resample
# Hypothetical ASR model architecture (a stand-in, not the real TeleSpeech network).
class ASRModel(torch.nn.Module):
    """Stacked LSTM followed by a per-frame linear classifier.

    The defaults reproduce the original hard-coded sizes, so ``ASRModel()``
    builds exactly the same layers (and state_dict keys) as before.

    Args:
        input_size: Number of features in each input frame.
        hidden_size: Width of every LSTM layer.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output classes scored for each frame.
    """

    def __init__(self, input_size=160, hidden_size=256, num_layers=3, num_classes=29):
        super().__init__()
        # batch_first=True: inputs are laid out as (batch, time, features).
        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        # Projects every LSTM output frame onto the class scores.
        self.linear = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return per-frame class logits for ``x``.

        Args:
            x: Float tensor of shape ``(batch, time, input_size)``.

        Returns:
            Tensor of shape ``(batch, time, num_classes)``.
        """
        x, _ = self.lstm(x)  # the final (h_n, c_n) state is not needed
        x = self.linear(x)
        return x
# Remote location of the checkpoint.
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/base.pt"
# Local copy of the checkpoint. The file name is kept as 'large.pt' for
# compatibility even though the URL above points at base.pt.
checkpoint_file = 'large.pt'

# Download the checkpoint only when it is not already on disk, so a restart
# does not fetch the whole file again.
if os.path.isfile(checkpoint_file):
    print("Model file already present, skipping download.")
else:
    print("Downloading model file...")
    torch.hub.download_url_to_file(model_path, checkpoint_file)
    print("Model file downloaded.")

# Instantiate the model.
model = ASRModel()

# Load the checkpoint.
print("Loading model checkpoint...")
# SECURITY: torch.load unpickles the downloaded file; only point model_path at
# a source you trust.
# NOTE(review): newer PyTorch releases default to weights_only=True, which may
# reject checkpoints that store non-tensor objects -- confirm against the
# pinned torch version.
checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))

# Show what the checkpoint contains; it is not guaranteed to be a dict.
if isinstance(checkpoint, dict):
    print("Checkpoint keys:", checkpoint.keys())
else:
    print("Checkpoint is not a dict:", type(checkpoint).__name__)

# Use the weights under 'model' when present, else treat the whole checkpoint
# as the state dict.
if isinstance(checkpoint, dict) and 'model' in checkpoint:
    state_dict = checkpoint['model']
    print("Model state_dict keys:", state_dict.keys())
else:
    print("Key 'model' not found in checkpoint.")
    state_dict = checkpoint

# Load the weights. This stays best-effort (the app still starts on failure),
# but the failure is now called out explicitly because the model then runs
# with randomly initialised weights.
try:
    model.load_state_dict(state_dict)
except Exception as e:
    print("Error loading model state_dict:", str(e))
    print("WARNING: continuing with randomly initialised weights; transcriptions will be meaningless.")
else:
    print("Model state_dict loaded successfully.")
model.eval()
# Audio-to-text handler used by the Gradio interface, plus its private helpers.

# Sample rate the audio is converted to before it is fed to the model.
_TARGET_SAMPLE_RATE = 16000

# NOTE(review): placeholder label table for the 29-class hypothetical model
# (CTC blank, space, apostrophe, a-z). The original decoded ids with chr(i),
# which yields unprintable control characters for ids 0..28. Replace this with
# the vocabulary the checkpoint was actually trained with.
_CTC_BLANK_ID = 0
_LABELS = ["", " ", "'"] + list("abcdefghijklmnopqrstuvwxyz")


def _load_mono_16k(path):
    """Load ``path``, resample it to 16 kHz and mix it down to a 1-D tensor."""
    waveform, sample_rate = torchaudio.load(path)
    if sample_rate != _TARGET_SAMPLE_RATE:
        resample = Resample(orig_freq=sample_rate, new_freq=_TARGET_SAMPLE_RATE)
        waveform = resample(waveform)
    # torchaudio.load returns (channels, frames) by default; averaging the
    # channels handles stereo input, which a plain squeeze() left as 2-D.
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0)
    return waveform


def _frame_waveform(waveform, frame_size):
    """Zero-pad a 1-D waveform and reshape it to ``(1, num_frames, frame_size)``."""
    remainder = waveform.numel() % frame_size
    if remainder:
        waveform = torch.nn.functional.pad(waveform, (0, frame_size - remainder))
    return waveform.reshape(1, -1, frame_size)


def _greedy_decode(ids):
    """Collapse repeated ids, drop blanks and map the rest through ``_LABELS``."""
    pieces = []
    previous = None
    for idx in ids:
        is_new = idx != previous
        previous = idx
        if is_new and idx != _CTC_BLANK_ID and 0 <= idx < len(_LABELS):
            pieces.append(_LABELS[idx])
    return "".join(pieces)


def transcribe(audio):
    """Transcribe the audio file at ``audio`` with the module-level ``model``.

    Args:
        audio: Path to an audio file, or ``None`` when nothing was submitted.

    Returns:
        The decoded text; an empty string when there is no audio to process.
    """
    print("Transcribing audio...")
    if not audio:
        # Nothing was uploaded or recorded.
        print("Transcription:", "")
        return ""
    waveform = _load_mono_16k(audio)
    if waveform.numel() == 0:
        # An empty recording would give the LSTM a zero-length sequence.
        print("Transcription:", "")
        return ""
    # The LSTM expects (batch, time, input_size); the raw (1, num_samples)
    # waveform the original passed in does not match input_size and raises.
    input_values = _frame_waveform(waveform, model.lstm.input_size)
    with torch.no_grad():
        logits = model(input_values)
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = _greedy_decode(predicted_ids[0].tolist())
    print("Transcription:", transcription)
    return transcription
# Build the Gradio interface: one audio input (handed to transcribe as a file
# path) and one text output.
interface_options = {
    "fn": transcribe,
    "inputs": gr.Audio(type="filepath"),
    "outputs": "text",
    "title": "TeleSpeech ASR",
    "description": "Upload an audio file or record your voice to transcribe speech to text using the TeleSpeech ASR model.",
}
iface = gr.Interface(**interface_options)

# Start serving the app.
print("Launching Gradio interface...")
iface.launch()