import os
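# Install the exact Gradio pre-release this app was written against before importing it.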
os.system("pip install gradio==2.8.0b2")
import gradio as gr
import librosa
from transformers import AutoFeatureExtractor, AutoTokenizer, SpeechEncoderDecoderModel
import torch

# Speech-translation checkpoint covering 22 spoken input languages and 16 written target languages.
model_name = "facebook/wav2vec2-xls-r-2b-22-to-16"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = SpeechEncoderDecoderModel.from_pretrained(model_name).to(device)

# Half precision on GPU keeps the 2B-parameter model within memory limits.
if torch.cuda.is_available():
    model.half()

def process_audio_file(file):
    # Load the audio and resample to the 16 kHz rate the model expects.
    data, sr = librosa.load(file)
    if sr != 16000:
        data = librosa.resample(data, orig_sr=sr, target_sr=16000)
    input_values = feature_extractor(data, sampling_rate=16000, return_tensors="pt").input_values.to(device)

    # Cast inputs to float16 to match the half-precision weights on GPU.
    if torch.cuda.is_available():
        input_values = input_values.to(torch.float16)
    return input_values
    
def transcribe(file_mic, file_upload, target_language):
    # Dropdown entries look like "English (en)"; pull out the code in parentheses.
    target_code = target_language.split("(")[-1].split(")")[0]
    forced_bos_token_id = MAPPING[target_code]

    warn_output = ""
    if (file_mic is not None) and (file_upload is not None):
        warn_output = "WARNING: You've uploaded an audio file and used the microphone. The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
        file = file_mic
    elif (file_mic is None) and (file_upload is None):
        return "ERROR: You have to either use the microphone or upload an audio file"
    elif file_mic is not None:
        file = file_mic
    else:
        file = file_upload

    input_values = process_audio_file(file)

    # Forcing the target-language token as the first decoder token selects the output language.
    sequences = model.generate(input_values, forced_bos_token_id=forced_bos_token_id)

    transcription = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    return warn_output + transcription[0]
    
# Target languages offered in the dropdown; the code in parentheses is looked up in MAPPING.
target_language = [
    "English (en)",
    "German (de)",
    "Turkish (tr)",
    "Persian (fa)",
    "Swedish (sv)",
    "Mongolian (mn)",
    "Chinese (zh)",
    "Welsh (cy)",
    "Catalan (ca)",
    "Slovenian (sl)",
    "Estonian (et)",
    "Indonesian (id)",
    "Arabic (ar)",
    "Tamil (ta)",
    "Latvian (lv)",
    "Japanese (ja)",
]

# Decoder token ids for each target language, passed to generate() as forced_bos_token_id.
MAPPING = {
    "en": 250004,
    "de": 250003,
    "tr": 250023,
    "fa": 250029,
    "sv": 250042,
    "mn": 250037,
    "zh": 250025,
    "cy": 250007,
    "ca": 250005,
    "sl": 250052,
    "et": 250006,
    "id": 250032,
    "ar": 250001,
    "ta": 250044,
    "lv": 250017,
    "ja": 250012,
}
    
# Gradio 2.x Interface using the old gr.inputs / source= API that matches the pinned pre-release.
iface = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.inputs.Audio(source="microphone", type='filepath', optional=True),
        gr.inputs.Audio(source="upload", type='filepath', optional=True),
        gr.inputs.Dropdown(target_language),
    ],
    outputs="text",
    layout="horizontal",
    theme="huggingface",
    title="XLS-R 2B 22-to-16 Speech Translation",
    description="A simple interface to translate from 22 input spoken languages to 16 written languages.",
    article="<p style='text-align: center'><a href='https://huggingface.co/facebook/wav2vec2-xls-r-2b-22-to-16' target='_blank'>Click to learn more about XLS-R-2B-22-16 </a> | <a href='https://arxiv.org/abs/2111.09296' target='_blank'> With 🎙️ from Facebook XLS-R </a></p>",
    enable_queue=True,
    allow_flagging=False,
)
iface.launch()
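# launch() starts the Gradio server (port 7860 by default) and serves the demo UI.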