Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,22 +1,30 @@
|
|
1 |
import gradio as gr
|
2 |
-
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
3 |
import torch
|
4 |
-
import
|
|
|
5 |
|
6 |
-
#
|
7 |
-
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
# 定义处理函数
|
12 |
def transcribe(audio):
|
13 |
-
waveform,
|
14 |
-
|
|
|
|
|
|
|
15 |
with torch.no_grad():
|
16 |
-
logits = model(input_values)
|
17 |
predicted_ids = torch.argmax(logits, dim=-1)
|
18 |
-
transcription =
|
19 |
-
return transcription
|
20 |
|
21 |
# 创建 Gradio 界面
|
22 |
iface = gr.Interface(
|
|
|
1 |
import gradio as gr
|
|
|
2 |
import torch
|
3 |
+
import torchaudio
|
4 |
+
from torchaudio.transforms import Resample
|
5 |
|
6 |
+
# Remote URL of the model file.
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/large.pt"

# Local file name the download is written to and the model is loaded from.
_MODEL_FILE = 'large.pt'

# Download the model file into the working directory.
torch.hub.download_url_to_file(model_path, _MODEL_FILE)

# Load the model and switch it to evaluation mode.
# NOTE(review): torch.jit.load only accepts archives written by torch.jit.save
# (TorchScript) - confirm that large.pt is such an archive and not a plain
# training checkpoint.
model = torch.jit.load(_MODEL_FILE)
model.eval()
|
15 |
|
16 |
# 定义处理函数
|
17 |
def transcribe(audio):
    """Transcribe one audio file with the module-level ``model``.

    Args:
        audio: Passed unchanged to ``torchaudio.load`` (e.g. a file path).

    Returns:
        Whatever ``tokenizer.decode`` returns for the first sequence of
        arg-max token ids produced by the model.
    """
    # torchaudio.load yields a (channels, frames) tensor plus the sample rate.
    waveform, sample_rate = torchaudio.load(audio)

    # Downmix multi-channel recordings to a single channel so that stereo
    # input reaches the model with the same shape as mono input. Mono files
    # are left untouched.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Bring the audio to 16 kHz before inference.
    resample = Resample(orig_freq=sample_rate, new_freq=16000)
    waveform = resample(waveform)

    # Add a leading batch dimension: (1, frames) -> (1, 1, frames).
    # NOTE(review): confirm the loaded model expects this layout and raw
    # waveform input rather than (batch, frames) or precomputed features.
    input_values = waveform.unsqueeze(0)
    with torch.no_grad():
        logits = model(input_values)
    predicted_ids = torch.argmax(logits, dim=-1)

    # NOTE(review): `tokenizer` is not defined in the visible part of this
    # file; it must be created at module level or this line raises NameError.
    transcription = tokenizer.decode(predicted_ids[0])
    return transcription
|
28 |
|
29 |
# 创建 Gradio 界面
|
30 |
iface = gr.Interface(
|