File size: 2,227 Bytes
1869d7a
 
71b0e8e
31564d0
ca625d0
31564d0
 
ac47f83
31564d0
 
 
 
 
 
ca625d0
 
 
31564d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1869d7a
 
 
e2bcfc6
71b0e8e
31564d0
eb24c35
 
04a460e
ca625d0
04a460e
1869d7a
ca625d0
 
f8ebe93
31564d0
e2bcfc6
71b0e8e
1869d7a
 
 
f8ebe93
31564d0
1869d7a
 
 
 
 
e2bcfc6
1869d7a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os

import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import Resample

from data2vec2 import Data2VecMultiModel, Data2VecMultiConfig, Modality

# Remote location of the fine-tuned TeleSpeech checkpoint.
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/finetune_large_kespeech.pt"

# Local file the checkpoint is stored in.
local_checkpoint = 'large.pt'

# Download the checkpoint only when no local copy exists, so restarting the
# script does not fetch the whole file again.
if os.path.isfile(local_checkpoint):
    print("Model file already present, skipping download.")
else:
    print("Downloading model file...")
    torch.hub.download_url_to_file(model_path, local_checkpoint)
    print("Model file downloaded.")

# Build the model from a default configuration, restricted to the audio modality.
config = Data2VecMultiConfig()
model = Data2VecMultiModel(config, modalities=[Modality.AUDIO])

# Read the checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file, and the file comes from a
# remote URL — confirm the source is trusted before running this.
print("Loading model checkpoint...")
checkpoint = torch.load(local_checkpoint, map_location=torch.device('cpu'))
print("Checkpoint keys:", checkpoint.keys())

# Use the weights stored under 'model' when that key exists; otherwise treat
# the whole checkpoint as the state dict.
if 'model' in checkpoint:
    state_dict = checkpoint['model']
    print("Model state_dict keys:", state_dict.keys())
else:
    print("Key 'model' not found in checkpoint.")
    state_dict = checkpoint

# Best-effort weight loading: a failure is reported but does not stop the
# script, so the model keeps its freshly initialised weights in that case.
try:
    model.load_state_dict(state_dict)
    print("Model state_dict loaded successfully.")
except Exception as e:
    print("Error loading model state_dict:", str(e))

# Switch to inference mode.
model.eval()

# Inference callback used by the Gradio interface.
def transcribe(audio):
    """Run the module-level ``model`` on an audio file and return the decoded text.

    Args:
        audio: Path of the audio file to transcribe. May be ``None`` or an
            empty string when no audio was provided.

    Returns:
        The decoded string, or ``""`` when no audio was provided.
    """
    # Nothing to load: return an empty result instead of failing inside
    # torchaudio.load.
    if not audio:
        return ""

    print("Transcribing audio...")
    waveform, sample_rate = torchaudio.load(audio)

    # torchaudio.load returns (channels, frames) by default. Average the
    # channels rather than calling squeeze(): squeeze() leaves multi-channel
    # audio 2-D, which would then be turned into a 3-D tensor instead of the
    # (batch_size, seq_len) input built below.
    mono = waveform.mean(dim=0, keepdim=True)

    # Resample to 16 kHz. The single remaining channel dimension serves as
    # the batch dimension.
    resample = Resample(orig_freq=sample_rate, new_freq=16000)
    input_values = resample(mono)  # (batch_size, seq_len)

    with torch.no_grad():
        outputs = model.extract_features(input_values, mode='AUDIO')
        logits = outputs["x"]
    predicted_ids = torch.argmax(logits, dim=-1)
    # NOTE(review): argmax indices are mapped straight to Unicode code points
    # with chr(); presumably a vocabulary lookup is intended here — confirm
    # against the model's dictionary.
    transcription = ''.join(chr(i) for i in predicted_ids[0].tolist())
    print("Transcription:", transcription)
    return transcription

# Assemble the Gradio interface: an audio input handed to `transcribe` as a
# file path, and a plain-text output.
audio_input = gr.Audio(type="filepath")
interface_options = {
    "title": "TeleSpeech ASR",
    "description": "Upload an audio file or record your voice to transcribe speech to text using the TeleSpeech ASR model.",
}
iface = gr.Interface(
    fn=transcribe,
    inputs=audio_input,
    outputs="text",
    **interface_options,
)

# Start serving the interface.
print("Launching Gradio interface...")
iface.launch()