import os

# Install the pinned Gradio pre-release at runtime, before importing it.
os.system("pip install gradio==2.8.0b2")

import gradio as gr
import librosa
import torch
from transformers import AutoFeatureExtractor, AutoTokenizer, SpeechEncoderDecoderModel

model_name = "facebook/wav2vec2-xls-r-2b-22-to-16"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = SpeechEncoderDecoderModel.from_pretrained(model_name).to(device)

# Run the 2B-parameter model in half precision on GPU to reduce memory use.
if torch.cuda.is_available():
    model.half()


def process_audio_file(file):
    # Load the audio and resample it to the 16 kHz rate the model expects.
    data, sr = librosa.load(file)
    if sr != 16000:
        data = librosa.resample(data, orig_sr=sr, target_sr=16000)

    input_values = feature_extractor(
        data, sampling_rate=16000, return_tensors="pt"
    ).input_values.to(device)

    # Match the half-precision weights when running on GPU.
    if torch.cuda.is_available():
        input_values = input_values.to(torch.float16)
    return input_values


def transcribe(file_mic, file_upload, target_language):
    # Extract the language code, e.g. "German (de)" -> "de", and look up
    # the decoder token id that forces generation in that language.
    target_code = target_language.split("(")[-1].split(")")[0]
    forced_bos_token_id = MAPPING[target_code]

    warn_output = ""
    if (file_mic is not None) and (file_upload is not None):
        warn_output = (
            "WARNING: You've uploaded an audio file and used the microphone. "
            "The recorded file from the microphone will be used and the "
            "uploaded audio will be discarded.\n"
        )
        file = file_mic
    elif (file_mic is None) and (file_upload is None):
        return "ERROR: You have to either use the microphone or upload an audio file"
    elif file_mic is not None:
        file = file_mic
    else:
        file = file_upload

    input_values = process_audio_file(file)
    sequences = model.generate(input_values, forced_bos_token_id=forced_bos_token_id)
    transcription = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    return warn_output + transcription[0]


target_language = [
    "English (en)",
    "German (de)",
    "Turkish (tr)",
    "Persian (fa)",
    "Swedish (sv)",
    "Mongolian (mn)",
    "Chinese (zh)",
    "Welsh (cy)",
    "Catalan (ca)",
    "Slovenian (sl)",
    "Estonian (et)",
    "Indonesian (id)",
    "Arabic (ar)",
    "Tamil (ta)",
    "Latvian (lv)",
    "Japanese (ja)",
]

# Decoder token ids for each target language; passing one of these as
# forced_bos_token_id makes the decoder generate in that language.
MAPPING = {
    "en": 250004,
    "de": 250003,
    "tr": 250023,
    "fa": 250029,
    "sv": 250042,
    "mn": 250037,
    "zh": 250025,
    "cy": 250007,
    "ca": 250005,
    "sl": 250052,
    "et": 250006,
    "id": 250032,
    "ar": 250001,
    "ta": 250044,
    "lv": 250017,
    "ja": 250012,
}

iface = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.inputs.Audio(source="microphone", type="filepath", optional=True),
        gr.inputs.Audio(source="upload", type="filepath", optional=True),
        gr.inputs.Dropdown(target_language),
    ],
    outputs="text",
    layout="horizontal",
    theme="huggingface",
    title="XLS-R 2B 22-to-16 Speech Translation",
    description="A simple interface to translate from 22 input spoken languages to 16 written languages.",
    article=(
        "<a href='https://huggingface.co/facebook/wav2vec2-xls-r-2b-22-to-16' target='_blank'>"
        "Click to learn more about XLS-R-2B-22-16</a> | With 🎙️ from Facebook XLS-R"
    ),
    enable_queue=True,
    allow_flagging=False,
)

iface.launch()