Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -3,35 +3,55 @@ import torch
|
|
3 |
import torchaudio
|
4 |
from torchaudio.transforms import Resample
|
5 |
|
6 |
-
#
|
7 |
-
class
|
8 |
def __init__(self):
|
9 |
-
super(
|
10 |
-
|
11 |
-
self.
|
|
|
12 |
|
13 |
def forward(self, x):
|
14 |
-
x
|
15 |
-
x = self.
|
16 |
return x
|
17 |
|
18 |
# 定义模型路径
|
19 |
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/large.pt"
|
20 |
|
21 |
# 下载模型文件
|
|
|
22 |
torch.hub.download_url_to_file(model_path, 'large.pt')
|
|
|
23 |
|
24 |
# 初始化模型
|
25 |
-
model =
|
26 |
|
27 |
# 加载模型参数
|
|
|
28 |
checkpoint = torch.load('large.pt', map_location=torch.device('cpu'))
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
model.eval()
|
32 |
|
33 |
# 定义处理函数
|
34 |
def transcribe(audio):
|
|
|
35 |
waveform, sample_rate = torchaudio.load(audio)
|
36 |
resample = Resample(orig_freq=sample_rate, new_freq=16000)
|
37 |
waveform = resample(waveform)
|
@@ -41,6 +61,7 @@ def transcribe(audio):
|
|
41 |
logits = model(input_values)
|
42 |
predicted_ids = torch.argmax(logits, dim=-1)
|
43 |
transcription = ''.join([chr(i) for i in predicted_ids[0].tolist()]) # 解码预测到字符
|
|
|
44 |
return transcription
|
45 |
|
46 |
# 创建 Gradio 界面
|
@@ -52,4 +73,5 @@ iface = gr.Interface(
|
|
52 |
description="Upload an audio file or record your voice to transcribe speech to text using the TeleSpeech ASR model."
|
53 |
)
|
54 |
|
|
|
55 |
iface.launch()
|
|
|
3 |
import torchaudio
|
4 |
from torchaudio.transforms import Resample
|
5 |
|
6 |
+
# 使用一个假设的 Transformer ASR 模型结构
|
7 |
+
class TransformerASRModel(torch.nn.Module):
|
8 |
def __init__(self):
|
9 |
+
super(TransformerASRModel, self).__init__()
|
10 |
+
# 定义模型架构,这里需要根据实际情况进行调整
|
11 |
+
self.encoder = torch.nn.TransformerEncoderLayer(d_model=512, nhead=8)
|
12 |
+
self.decoder = torch.nn.Linear(512, 29) # 假设29个输出类用于字符
|
13 |
|
14 |
def forward(self, x):
|
15 |
+
x = self.encoder(x)
|
16 |
+
x = self.decoder(x)
|
17 |
return x
|
18 |
|
19 |
# 定义模型路径
|
20 |
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/large.pt"
|
21 |
|
22 |
# 下载模型文件
|
23 |
+
print("Downloading model file...")
|
24 |
torch.hub.download_url_to_file(model_path, 'large.pt')
|
25 |
+
print("Model file downloaded.")
|
26 |
|
27 |
# 初始化模型
|
28 |
+
model = TransformerASRModel()
|
29 |
|
30 |
# 加载模型参数
|
31 |
+
print("Loading model checkpoint...")
|
32 |
checkpoint = torch.load('large.pt', map_location=torch.device('cpu'))
|
33 |
+
print("Checkpoint keys:", checkpoint.keys())
|
34 |
+
|
35 |
+
# 打印模型参数中的键
|
36 |
+
if 'model' in checkpoint:
|
37 |
+
state_dict = checkpoint['model']
|
38 |
+
print("Model state_dict keys:", state_dict.keys())
|
39 |
+
else:
|
40 |
+
print("Key 'model' not found in checkpoint.")
|
41 |
+
state_dict = checkpoint
|
42 |
+
|
43 |
+
# 加载模型状态字典
|
44 |
+
try:
|
45 |
+
model.load_state_dict(state_dict)
|
46 |
+
print("Model state_dict loaded successfully.")
|
47 |
+
except Exception as e:
|
48 |
+
print("Error loading model state_dict:", str(e))
|
49 |
+
|
50 |
model.eval()
|
51 |
|
52 |
# 定义处理函数
|
53 |
def transcribe(audio):
|
54 |
+
print("Transcribing audio...")
|
55 |
waveform, sample_rate = torchaudio.load(audio)
|
56 |
resample = Resample(orig_freq=sample_rate, new_freq=16000)
|
57 |
waveform = resample(waveform)
|
|
|
61 |
logits = model(input_values)
|
62 |
predicted_ids = torch.argmax(logits, dim=-1)
|
63 |
transcription = ''.join([chr(i) for i in predicted_ids[0].tolist()]) # 解码预测到字符
|
64 |
+
print("Transcription:", transcription)
|
65 |
return transcription
|
66 |
|
67 |
# 创建 Gradio 界面
|
|
|
73 |
description="Upload an audio file or record your voice to transcribe speech to text using the TeleSpeech ASR model."
|
74 |
)
|
75 |
|
76 |
+
print("Launching Gradio interface...")
|
77 |
iface.launch()
|