File size: 2,291 Bytes
c7e9ab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os

import gradio as gr
import librosa
from transformers import AutoFeatureExtractor, AutoTokenizer, SpeechEncoderDecoderModel

# Hub repo id of the checkpoint. (No trailing comma: a trailing comma would
# make this a 1-tuple, which from_pretrained cannot resolve.)
model_name = "facebook/wav2vec2-xls-r-2b-en-to-15"

# Read the Hugging Face access token from the environment instead of
# hard-coding it: a token committed to source is readable by anyone who can
# see the file, and the literal token that used to be here should be treated
# as leaked and revoked. When HF_TOKEN is unset this is None, which is the
# library's default for use_auth_token.
# NOTE(review): confirm the checkpoint is publicly downloadable, or set
# HF_TOKEN (e.g. as a Space secret) before deploying.
hf_token = os.environ.get("HF_TOKEN")

# Preprocessor for raw audio, tokenizer for decoding generated ids back to
# text, and the speech encoder-decoder model used by transcribe().
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name, use_auth_token=hf_token)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token, use_fast=False)
model = SpeechEncoderDecoderModel.from_pretrained(model_name, use_auth_token=hf_token)

def process_audio_file(file):
    """Load an audio file and convert it into model input features.

    Args:
        file: Path to the audio file. The Gradio Audio input below is
            configured with ``type='filepath'``, so a path string is expected.

    Returns:
        The ``input_values`` produced by ``feature_extractor`` with
        ``return_tensors="pt"``, ready to pass to ``model.generate``.
    """
    # The original pipeline resampled to 16 kHz before feature extraction.
    target_sr = 16000

    # Decode and resample in a single pass. Calling librosa.load without
    # ``sr`` resamples to librosa's default 22,050 Hz, which forced a second
    # resample. The old positional librosa.resample(data, sr, 16000) call is
    # also rejected by librosa >= 0.10 (orig_sr/target_sr are keyword-only).
    data, _ = librosa.load(file, sr=target_sr)

    # Pass the rate explicitly so the extractor can check it matches what it
    # was configured for, instead of silently assuming it.
    return feature_extractor(
        data, sampling_rate=target_sr, return_tensors="pt"
    ).input_values
    
def transcribe(file, target_language):
    """Translate the spoken English in ``file`` into text in ``target_language``.

    Args:
        file: Path to the recorded audio, as supplied by the Gradio Audio input.
        target_language: Dropdown label of the form ``"Name (code)"``, e.g.
            ``"German (de)"``. The code in parentheses selects the output
            language via ``MAPPING``.

    Returns:
        The decoded text of the first generated sequence.

    Raises:
        ValueError: If no audio or no target language was provided, or if the
            label's code has no entry in ``MAPPING``.
    """
    # Gradio hands the function None when the user submits without recording
    # or without choosing a language. Fail with a readable message instead of
    # an AttributeError or a decoder error deep inside librosa.
    if file is None:
        raise ValueError("No audio was recorded; please record a clip first.")
    if not target_language:
        raise ValueError("Please select a target language.")

    # "German (de)" -> "de": the text after the last "(" up to the next ")".
    target_code = target_language.split("(")[-1].split(")")[0]
    try:
        forced_bos_token_id = MAPPING[target_code]
    except KeyError:
        raise ValueError(f"Unsupported target language: {target_language!r}") from None

    input_values = process_audio_file(file)

    # Forcing the first decoded token to the language token makes the decoder
    # emit text in the requested language.
    sequences = model.generate(input_values, forced_bos_token_id=forced_bos_token_id)

    transcription = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    return transcription[0]
    
# Single source of truth for the selectable output languages:
# (display name, language code, id of the token forced as the first decoder
# token to select that language). The dropdown labels and MAPPING are both
# derived from this table so their codes cannot drift apart; transcribe()
# parses the code back out of the "Name (code)" label and looks it up in
# MAPPING.
# NOTE(review): the token ids are copied unchanged from the previous table;
# confirm them against the checkpoint's tokenizer before editing.
# NOTE(review): the app title/description advertise 15 target languages, but
# English is offered as a 16th option. Confirm the checkpoint was actually
# trained to emit English.
_LANGUAGES = (
    ("English", "en", 250004),
    ("German", "de", 250003),
    ("Turkish", "tr", 250023),
    ("Persian", "fa", 250029),
    ("Swedish", "sv", 250042),
    ("Mongolian", "mn", 250037),
    ("Chinese", "zh", 250025),
    ("Welsh", "cy", 250007),
    ("Catalan", "ca", 250005),
    ("Slovenian", "sl", 250052),
    ("Estonian", "et", 250006),
    ("Indonesian", "id", 250032),
    ("Arabic", "ar", 250001),
    ("Tamil", "ta", 250044),
    ("Latvian", "lv", 250017),
    ("Japanese", "ja", 250012),
)

# Dropdown choices such as "German (de)", in table order.
target_language = [f"{name} ({code})" for name, code, _ in _LANGUAGES]

# Language code -> forced BOS token id, passed to model.generate().
MAPPING = {code: token_id for _, code, token_id in _LANGUAGES}
    
# Gradio UI: a microphone recording plus a target-language dropdown in,
# translated text out. Components are passed to transcribe() positionally.
# NOTE(review): gr.inputs.*, layout= and theme="huggingface" are legacy
# Gradio APIs; confirm the pinned Gradio version still supports them.
iface = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.inputs.Audio(source="microphone", type='filepath'),
        gr.inputs.Dropdown(target_language),
    ],
    outputs="text",
    layout="horizontal",
    theme="huggingface",
    title="XLS-R 2B EN-to-15 Speech Translation",
    description="A simple interface to translate from spoken English to 15 written languages.",
)

# Start the (blocking) web server only when run as a script, so importing
# this module (e.g. by Gradio's reload mode or tests) doesn't hang.
if __name__ == "__main__":
    iface.launch()