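# NOTE: page residue from the Hugging Face Spaces listing this file was copied
# from ("Spaces:", status shown: "Build error"); it is not part of the program.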
"""Gradio demo for facebook/wav2vec2-xls-r-2b-22-to-16 speech translation.

Takes speech (microphone recording or uploaded file) and produces text in a
user-selected target language using a speech-encoder/text-decoder model.
"""
import os
import subprocess
import sys

# Pin the Gradio version this app was written against (it uses the 2.x
# ``gr.inputs`` API). Install into *this* interpreter's environment via
# ``sys.executable`` (no shell involved) and fail loudly if pip fails, instead
# of silently continuing with whatever Gradio happens to be preinstalled.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "gradio==2.8.0b2"],
    check=True,
)

import gradio as gr  # noqa: E402  (must be imported after the pinned install)
import librosa
import torch
from transformers import AutoFeatureExtractor, AutoTokenizer, SpeechEncoderDecoderModel

model_name = "facebook/wav2vec2-xls-r-2b-22-to-16"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = SpeechEncoderDecoderModel.from_pretrained(model_name).to(device)
if torch.cuda.is_available():
    # Run the model in half precision on GPU; the audio preprocessing casts
    # its input tensor to float16 under the same condition.
    model.half()
def process_audio_file(file):
    """Load an audio file and convert it into model-ready input values.

    Args:
        file: Path to the audio file (the Gradio inputs use type='filepath').

    Returns:
        A ``torch.Tensor`` of input values on ``device``. When CUDA is
        available it is cast to float16, matching the half-precision model.
    """
    target_sr = 16000
    # Ask librosa for 16 kHz directly. Otherwise librosa.load resamples to its
    # 22.05 kHz default and a second resample is needed. The old positional
    # librosa.resample(data, sr, 16000) call is also rejected by librosa >= 0.10,
    # where orig_sr/target_sr became keyword-only.
    data, _ = librosa.load(file, sr=target_sr)
    # Passing sampling_rate lets the feature extractor verify the rate instead
    # of assuming it.
    input_values = feature_extractor(
        data, sampling_rate=target_sr, return_tensors="pt"
    ).input_values.to(device)
    if torch.cuda.is_available():
        input_values = input_values.to(torch.float16)
    return input_values
def transcribe(file_mic, file_upload, target_language):
    """Translate recorded or uploaded speech into the selected target language.

    Args:
        file_mic: Filepath of the microphone recording, or None.
        file_upload: Filepath of the uploaded audio file, or None.
        target_language: Dropdown label such as "German (de)". The code inside
            the parentheses is looked up in MAPPING to obtain the
            ``forced_bos_token_id`` passed to ``model.generate``.

    Returns:
        The decoded text. It is prefixed with a warning when both audio inputs
        were supplied (the microphone recording wins). An "ERROR: ..." message
        is returned when no language or no audio was provided, or when the
        language code is unknown.
    """
    # Guard against an unselected dropdown (None/"") instead of crashing in .split().
    if not target_language:
        return "ERROR: You have to select a target language"
    target_code = target_language.split("(")[-1].split(")")[0]
    forced_bos_token_id = MAPPING.get(target_code)
    if forced_bos_token_id is None:
        return f"ERROR: Unsupported target language: {target_language}"

    warn_output = ""
    if (file_mic is not None) and (file_upload is not None):
        warn_output = "WARNING: You've uploaded an audio file and used the microphone. The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
        file = file_mic
    elif (file_mic is None) and (file_upload is None):
        return "ERROR: You have to either use the microphone or upload an audio file"
    elif file_mic is not None:
        file = file_mic
    else:
        file = file_upload

    input_values = process_audio_file(file)
    sequences = model.generate(input_values, forced_bos_token_id=forced_bos_token_id)
    transcription = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    return warn_output + transcription[0]
# Single source of truth for the supported target languages:
# (dropdown label, language code, forced BOS token id).
# The ids are passed to model.generate as forced_bos_token_id. They are
# presumably the decoder tokenizer's language-code token ids.
# NOTE(review): verify them against the tokenizer before editing any value.
_TARGET_LANGUAGES = (
    ("English", "en", 250004),
    ("German", "de", 250003),
    ("Turkish", "tr", 250023),
    ("Persian", "fa", 250029),
    ("Swedish", "sv", 250042),
    ("Mongolian", "mn", 250037),
    ("Chinese", "zh", 250025),
    ("Welsh", "cy", 250007),
    ("Catalan", "ca", 250005),
    ("Slovenian", "sl", 250052),
    ("Estonian", "et", 250006),
    ("Indonesian", "id", 250032),
    ("Arabic", "ar", 250001),
    ("Tamil", "ta", 250044),
    ("Latvian", "lv", 250017),
    ("Japanese", "ja", 250012),
)

# Dropdown choices such as "German (de)". transcribe() parses the code back
# out of the parentheses.
target_language = [f"{name} ({code})" for name, code, _ in _TARGET_LANGUAGES]

# Language code -> forced BOS token id, looked up by transcribe().
MAPPING = {code: token_id for _, code, token_id in _TARGET_LANGUAGES}
# Web UI: microphone recording or file upload as audio input, plus a
# target-language dropdown; transcribe() returns the translated text.
iface = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.inputs.Audio(source="microphone", type="filepath", optional=True),
        gr.inputs.Audio(source="upload", type="filepath", optional=True),
        gr.inputs.Dropdown(target_language),
    ],
    outputs="text",
    layout="horizontal",
    theme="huggingface",
    title="XLS-R 2B 22-to-16 Speech Translation",
    description="A simple interface to translate from 22 input spoken languages to 16 written languages.",
    # The emoji was previously mojibake ("ποΈ"): the UTF-8 bytes of 🎙️ decoded
    # as ISO-8859-7.
    article=(
        "<p style='text-align: center'>"
        "<a href='https://huggingface.co/facebook/wav2vec2-xls-r-2b-22-to-16' target='_blank'>Click to learn more about XLS-R-2B-22-16 </a> | "
        "<a href='https://arxiv.org/abs/2111.09296' target='_blank'> With 🎙️ from Facebook XLS-R </a>"
        "</p>"
    ),
    enable_queue=True,
    allow_flagging=False,
)
iface.launch()