asr_arena / app.py
jasspier's picture
Update app.py
ac47f83 verified
raw
history blame
No virus
2.78 kB
import os

import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import Resample
# Hypothetical ASR model structure (a stand-in architecture chosen by the author).
class ASRModel(torch.nn.Module):
    """Stacked LSTM encoder followed by a per-frame linear classifier.

    The original author describes this as a hypothetical architecture.
    NOTE(review): it is not known to match the layout of the checkpoint this
    script downloads — confirm before relying on loaded weights.

    The submodule attribute names ``lstm`` and ``linear`` determine the
    state_dict keys, so they must not be renamed.

    Args:
        input_size: features per time step (default 160).
        hidden_size: LSTM hidden units per layer (default 256).
        num_layers: number of stacked LSTM layers (default 3).
        num_classes: output classes per time step (default 29; the original
            comment describes these as character classes).
    """

    def __init__(self, input_size=160, hidden_size=256, num_layers=3, num_classes=29):
        super().__init__()
        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return per-frame class logits.

        Args:
            x: tensor of shape (batch, seq_len, input_size); batch-first, as
                configured on the LSTM.

        Returns:
            Tensor of shape (batch, seq_len, num_classes).
        """
        x, _ = self.lstm(x)  # discard the final (h, c) states
        return self.linear(x)
# Remote checkpoint location and the local file it is cached in.
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/finetune_large_kespeech.pt"
model_file = 'large.pt'

# Download the checkpoint only when it is not already cached locally, so a
# restart re-uses the existing file instead of fetching it again.
if os.path.exists(model_file):
    print(f"Using cached model file {model_file}.")
else:
    print("Downloading model file...")
    torch.hub.download_url_to_file(model_path, model_file)
    print("Model file downloaded.")

# Initialize the model (weights are random until a state dict is loaded).
model = ASRModel()

# Load the checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code unless weights_only=True is used; only acceptable because the URL above
# is a fixed, trusted source.
print("Loading model checkpoint...")
checkpoint = torch.load(model_file, map_location=torch.device('cpu'))
print("Checkpoint keys:", checkpoint.keys())

# Some checkpoints nest the weights under a 'model' key; otherwise treat the
# whole checkpoint object as the state dict.
if 'model' in checkpoint:
    state_dict = checkpoint['model']
    print("Model state_dict keys:", state_dict.keys())
else:
    print("Key 'model' not found in checkpoint.")
    state_dict = checkpoint

# Best effort: a key/shape mismatch (RuntimeError) or a non-mapping state dict
# (TypeError) keeps the app running, but the degraded state is reported
# explicitly instead of silently serving an untrained model.
try:
    model.load_state_dict(state_dict)
except (RuntimeError, TypeError) as e:
    print("Error loading model state_dict:", str(e))
    print("WARNING: continuing with randomly initialized weights; "
          "transcriptions will not be meaningful.")
else:
    print("Model state_dict loaded successfully.")

# Inference mode (disables training-only behaviour such as dropout).
model.eval()
# Gradio callback: turn an uploaded or recorded audio file into text.
def transcribe(audio):
    """Transcribe the audio file at path *audio* with the module-level ``model``.

    The audio is resampled to 16 kHz and down-mixed to mono. It is then
    zero-padded to a multiple of 160 samples and reshaped into 160-sample
    frames. The result has shape (1, num_frames, 160), matching ASRModel's
    batch-first LSTM with input_size=160.

    Args:
        audio: filesystem path to the audio file (Gradio ``type="filepath"``),
            or None when no audio was provided.

    Returns:
        The decoded string. Returns an empty string when no audio, or audio
        with zero samples, is given.
    """
    frame_size = 160
    target_rate = 16000

    if audio is None:
        # Gradio passes None when the user submits without uploading/recording.
        return ""

    print("Transcribing audio...")
    waveform, sample_rate = torchaudio.load(audio)  # (channels, samples)
    if sample_rate != target_rate:
        waveform = Resample(orig_freq=sample_rate, new_freq=target_rate)(waveform)

    # Average the channels so both mono and multi-channel input become a 1-D
    # signal; the frame logic below indexes samples along dimension 0.
    waveform = waveform.mean(dim=0)

    num_samples = waveform.size(0)
    if num_samples == 0:
        return ""

    # Zero-pad at the end so the length is an exact multiple of frame_size.
    padding = -num_samples % frame_size
    if padding:
        waveform = torch.nn.functional.pad(waveform, (0, padding))

    # (batch_size=1, seq_len, input_size=frame_size)
    input_values = waveform.reshape(1, -1, frame_size)

    with torch.no_grad():
        logits = model(input_values)
    predicted_ids = torch.argmax(logits, dim=-1)

    # NOTE(review): each class index i is rendered as chr(i). For small class
    # counts these are control characters, so a real vocabulary mapping (and
    # likely CTC collapsing) is still needed — TODO.
    transcription = ''.join(chr(i) for i in predicted_ids[0].tolist())
    print("Transcription:", transcription)
    return transcription
# Assemble the Gradio UI: the audio component hands transcribe() a file path,
# and the returned string is shown as text.
audio_input = gr.Audio(type="filepath")
app_description = (
    "Upload an audio file or record your voice to transcribe speech "
    "to text using the TeleSpeech ASR model."
)
iface = gr.Interface(
    fn=transcribe,
    inputs=audio_input,
    outputs="text",
    title="TeleSpeech ASR",
    description=app_description,
)

# Start the web app; this runs at module level whenever the script executes.
print("Launching Gradio interface...")
iface.launch()