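# app.py - Gradio demo for XLS-R 2B 22-to-16 speech translation.
# File-viewer header captured with this file: last commit "Update app.py"
# (0f0fe01) by patrickvonplaten; raw / history / blame; no virus; 2.58 kB.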
import gradio as gr
import librosa
from transformers import AutoFeatureExtractor, AutoTokenizer, SpeechEncoderDecoderModel
import torch
# Checkpoint identifier shared by the feature extractor, tokenizer and model.
model_name = "facebook/wav2vec2-xls-r-2b-22-to-16"

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Audio pre-processor and (slow, non-"fast") tokenizer for the checkpoint.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# Load the speech encoder-decoder weights, then move them to the chosen device.
model = SpeechEncoderDecoderModel.from_pretrained(model_name)
model = model.to(device)
def process_audio_file(file):
    """Load an audio file and convert it into input values for the model.

    Args:
        file: Path to the audio file to read (the Gradio audio input is
            configured with ``type='filepath'``).

    Returns:
        The ``input_values`` tensor produced by ``feature_extractor`` for the
        audio resampled to 16 kHz, moved to ``device``.
    """
    target_sr = 16000

    # Let librosa resample while loading. This resamples once, straight to the
    # target rate, instead of first to librosa's default rate and then again
    # to 16 kHz. It also avoids the positional
    # ``librosa.resample(y, orig_sr, target_sr)`` call form, which newer
    # librosa releases reject because those arguments are keyword-only.
    data, _ = librosa.load(file, sr=target_sr)

    # Pass the rate explicitly so the extractor knows what rate the samples
    # are at rather than having to assume it.
    features = feature_extractor(
        data, sampling_rate=target_sr, return_tensors="pt"
    )
    return features.input_values.to(device)
def transcribe(file, target_language):
    """Translate the speech in ``file`` into text in the selected language.

    Args:
        file: Path to the recorded audio, forwarded to ``process_audio_file``.
        target_language: Dropdown label such as ``"English (en)"``. The text
            between the last ``"("`` and the next ``")"`` is used as the key
            into ``MAPPING``.

    Returns:
        The first decoded sequence, with special tokens removed.

    Raises:
        KeyError: If the extracted language code is not a key of ``MAPPING``.
    """
    # "English (en)" -> "en"
    after_paren = target_language.rpartition("(")[2]
    language_code = after_paren.partition(")")[0]
    bos_token_id = MAPPING[language_code]

    audio_inputs = process_audio_file(file)

    # The looked-up id is forced as the first generated token.
    generated = model.generate(
        audio_inputs,
        forced_bos_token_id=bos_token_id,
        num_beams=1,
        max_length=30,
    )
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0]
# One row per supported output language: (display name, language code, token id).
# The token id is the value `transcribe` passes to `model.generate` as
# `forced_bos_token_id`.
_LANGUAGES = (
    ("English", "en", 250004),
    ("German", "de", 250003),
    ("Turkish", "tr", 250023),
    ("Persian", "fa", 250029),
    ("Swedish", "sv", 250042),
    ("Mongolian", "mn", 250037),
    ("Chinese", "zh", 250025),
    ("Welsh", "cy", 250007),
    ("Catalan", "ca", 250005),
    ("Slovenian", "sl", 250052),
    ("Estonian", "et", 250006),
    ("Indonesian", "id", 250032),
    ("Arabic", "ar", 250001),
    ("Tamil", "ta", 250044),
    ("Latvian", "lv", 250017),
    ("Japanese", "ja", 250012),
)

# Dropdown labels in the form "Name (code)", in table order.
target_language = [f"{name} ({code})" for name, code, _ in _LANGUAGES]

# Language code -> token id, in table order.
MAPPING = {code: token_id for _, code, token_id in _LANGUAGES}
# Gradio UI: record audio from the microphone, pick a target language from the
# dropdown, and display the text returned by `transcribe`.
# NOTE(review): `gr.inputs.*`, `layout`, `theme="huggingface"`, `enable_queue`
# and a boolean `allow_flagging` look like legacy Gradio API; they are kept
# as-is because the Gradio version this app is pinned to is not visible here.
# Confirm before upgrading.
iface = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.inputs.Audio(source="microphone", type="filepath"),
        gr.inputs.Dropdown(target_language),
    ],
    outputs="text",
    layout="horizontal",
    theme="huggingface",
    title="XLS-R 2B 22-to-16 Speech Translation",
    description="A simple interface to translate from 22 input spoken languages to 16 written languages.",
    # Footer HTML. The microphone emoji replaces a mis-decoded byte sequence
    # that would otherwise be shown to users as garbage characters.
    article=(
        "<p style='text-align: center'>"
        "<a href='https://huggingface.co/facebook/wav2vec2-xls-r-2b-22-to-16' target='_blank'>"
        "Click to learn more about XLS-R-2B-22-16 </a>"
        " | "
        "<a href='https://arxiv.org/abs/2111.09296' target='_blank'>"
        " With 🎙️ from Facebook XLS-R </a></p>"
    ),
    enable_queue=True,
    allow_flagging=False,
)
iface.launch()