Spaces:
Runtime error
Runtime error
File size: 2,775 Bytes
1869d7a 71b0e8e 31564d0 1869d7a 31564d0 ac47f83 31564d0 1869d7a e2bcfc6 71b0e8e 31564d0 eb24c35 04a460e 1869d7a 31564d0 f8ebe93 31564d0 e2bcfc6 71b0e8e 1869d7a f8ebe93 31564d0 1869d7a e2bcfc6 1869d7a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import os

import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import Resample
# Define a hypothetical ASR model structure (a stand-in architecture, not one
# derived from the downloaded checkpoint).
class ASRModel(torch.nn.Module):
    """Frame-level classifier: a stacked LSTM followed by a linear projection.

    All sizes are parameters; the defaults reproduce the original hard-coded
    architecture (160 -> 3x LSTM(256) -> 29), so state-dict keys and shapes
    are unchanged for ``ASRModel()``.

    Args:
        input_size: Number of features per input frame.
        hidden_size: Hidden units in each LSTM layer.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output classes per frame (the original comment
            describes the default 29 as character classes).
    """

    def __init__(self, input_size=160, hidden_size=256, num_layers=3, num_classes=29):
        super().__init__()
        # Attribute names are kept as `lstm` / `linear` so state_dict keys
        # (e.g. "lstm.weight_ih_l0", "linear.weight") stay the same.
        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return per-frame logits.

        Args:
            x: Tensor of shape (batch, seq_len, input_size) (``batch_first=True``).

        Returns:
            Tensor of shape (batch, seq_len, num_classes).
        """
        x, _ = self.lstm(x)  # LSTM also returns (h_n, c_n); only outputs are used
        return self.linear(x)
# Remote checkpoint URL (despite the name, this is a URL rather than a local path).
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/finetune_large_kespeech.pt"
# Local file the checkpoint is cached in.
checkpoint_file = 'large.pt'
# Download the checkpoint, reusing an existing local copy so that every
# restart does not re-fetch the whole file.
if os.path.exists(checkpoint_file):
    print("Model file already present, skipping download.")
else:
    print("Downloading model file...")
    torch.hub.download_url_to_file(model_path, checkpoint_file)
    print("Model file downloaded.")
# Initialize the model
model = ASRModel()
# Load the checkpoint onto the CPU
print("Loading model checkpoint...")
# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source.
# NOTE(review): newer PyTorch releases default torch.load to weights_only=True,
# which rejects checkpoints that pickle arbitrary Python objects -- if loading
# fails here, confirm how this checkpoint was saved before relaxing that.
checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
# Print the checkpoint's keys, and unwrap the state dict from a
# {'model': ...} wrapper when present; otherwise treat the whole checkpoint
# as the state dict. Non-dict checkpoints are guarded so `.keys()` / `in`
# cannot crash here.
is_mapping = isinstance(checkpoint, dict)
if is_mapping:
    print("Checkpoint keys:", checkpoint.keys())
else:
    print("Checkpoint is not a dict; got:", type(checkpoint).__name__)
if is_mapping and 'model' in checkpoint:
    state_dict = checkpoint['model']
    if isinstance(state_dict, dict):
        print("Model state_dict keys:", state_dict.keys())
else:
    print("Key 'model' not found in checkpoint.")
    state_dict = checkpoint
# Load the model state dict. This is best-effort so the UI still starts, but
# ASRModel is a hypothetical architecture and the checkpoint may not match it;
# on failure the model keeps its random initial weights, so say so loudly.
try:
    model.load_state_dict(state_dict)
except Exception as e:
    print("Error loading model state_dict:", str(e))
    print("WARNING: checkpoint weights were NOT loaded; the model is using "
          "randomly initialised weights and its output will be meaningless.")
else:
    print("Model state_dict loaded successfully.")
model.eval()
# Define the processing function
def transcribe(audio):
    """Run the module-level ``model`` on an audio file and decode its output.

    The audio is resampled to 16 kHz, downmixed to mono, zero-padded to a
    multiple of 160 samples, and cut into 160-sample frames that are fed to
    the model as a (1, seq_len, 160) batch.

    Args:
        audio: Path to an audio file (the Gradio input uses
            ``type="filepath"``), or ``None`` when the user submits without
            uploading or recording anything.

    Returns:
        The decoded string; an empty string when there is no audio.
    """
    if audio is None:
        # Gradio passes None when no audio was provided.
        return ""
    print("Transcribing audio...")
    waveform, sample_rate = torchaudio.load(audio)
    waveform = Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # torchaudio.load returns (channels, time). Average the channels to get a
    # 1-D mono signal: the old `.squeeze()` left stereo as (2, time), which
    # made the frame count below read the channel count and interleaved the
    # channels into the frames; it also turned a 1-sample clip into a 0-d
    # tensor.
    waveform = waveform.mean(dim=0)
    if waveform.numel() == 0:
        return ""
    # Right-pad with zeros so the sample count is a multiple of 160.
    padding = -waveform.size(0) % 160
    if padding:
        waveform = torch.nn.functional.pad(waveform, (0, padding))
    # Shape the input as (batch_size=1, seq_len, input_size=160).
    input_values = waveform.view(-1, 160).unsqueeze(0)
    with torch.no_grad():
        logits = model(input_values)
    predicted_ids = torch.argmax(logits, dim=-1)
    # Decode predictions to characters.
    # NOTE(review): each class id is passed straight to chr(); small ids
    # (e.g. 0-28) are ASCII control characters, so this does not produce
    # readable text. It needs the model's real token table -- TODO.
    transcription = ''.join(chr(i) for i in predicted_ids[0].tolist())
    print("Transcription:", transcription)
    return transcription
# Build the Gradio UI: a single audio component (handed to `transcribe` as a
# file path) producing plain text, then start the web server.
iface = gr.Interface(
    title="TeleSpeech ASR",
    description=(
        "Upload an audio file or record your voice to transcribe speech "
        "to text using the TeleSpeech ASR model."
    ),
    fn=transcribe,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
)
print("Launching Gradio interface...")
iface.launch()
|