import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers.pipelines.audio_utils import ffmpeg_read
import gradio as gr
device = "cuda" if torch.cuda.is_available() else "CPU" | |
model_ckpt = "ivanlau/language-detection-fine-tuned-on-xlm-roberta-base" | |
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt) | |
tokenizer = AutoTokenizer.from_pretrained(model_ckpt) | |
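
# The script imports WhisperForConditionalGeneration/WhisperProcessor and relies on
# `processor` and an ASR model below, but never instantiates them. A minimal,
# assumed setup follows; "openai/whisper-base" is a guess, since the Space does
# not pin a specific Whisper checkpoint.
whisper_ckpt = "openai/whisper-base"
processor = WhisperProcessor.from_pretrained(whisper_ckpt)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_ckpt).to(device)
whisper_model.eval()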

def detect_language(sentence):
    """Classify the language of a text snippet, returning (language, probability)."""
    tokenized_sentence = lang_tokenizer(sentence, return_tensors="pt")
    with torch.no_grad():
        output = lang_model(**tokenized_sentence)
    predictions = F.softmax(output.logits, dim=-1)
    probability, pred_idx = torch.max(predictions, dim=-1)
    # The checkpoint's config carries the id -> label mapping; this replaces the
    # undefined `LANGUANGE_MAP` used in the original.
    language = lang_model.config.id2label[pred_idx.item()]
    return language, probability.item()
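
# Sketch of a standalone check (the printed values are illustrative, not recorded output):
#   >>> detect_language("Bonjour tout le monde")
#   ('French', 0.99...)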

def process_audio_file(file):
    """Read an audio file from disk and decode it to a float waveform with ffmpeg."""
    with open(file, "rb") as f:
        inputs = f.read()
    audio = ffmpeg_read(inputs, sampling_rate)
    return audio
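
# Note: ffmpeg_read shells out to the ffmpeg binary, which must be installed on the
# host (on Spaces, e.g. via packages.txt); it also resamples the decoded audio to
# the requested `sampling_rate`.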

def transcribe(Microphone, File_Upload):
    warn_output = ""
    if (Microphone is not None) and (File_Upload is not None):
        warn_output = "WARNING: You've uploaded an audio file and used the microphone. " \
                      "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
        file = Microphone
    elif (Microphone is None) and (File_Upload is None):
        # Return one value per output component (Textbox, Number).
        return "ERROR: You have to either use the microphone or upload an audio file", 0.0
    elif Microphone is not None:
        file = Microphone
    else:
        file = File_Upload

    audio_data = process_audio_file(file)
    input_features = processor(audio_data, sampling_rate=sampling_rate, return_tensors="pt").input_features

    # Greedy decoding via generate(); the original ran a single forward pass with an
    # undefined `decoder_input_ids`, which cannot produce a full transcription.
    with torch.no_grad():
        pred_ids = whisper_model.generate(input_features.to(device))
    transcription = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]

    language, probability = detect_language(transcription.capitalize())
    return warn_output + language, probability
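
# Quick smoke test (assumes a local "sample1.mp3"; not part of the original app):
#   language, prob = transcribe(None, "sample1.mp3")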

# Each example row must supply one value per input component (microphone, upload).
examples = [[None, "sample1.mp3"], [None, "sample2.mp3"], [None, "sample3.mp3"]]
article = """ | |
Fine-tuned on xlm-roberta-base model.\n | |
Supported languages:\n | |
'Arabic', 'Basque', 'Breton', 'Catalan', 'Chinese_China', 'Chinese_Hongkong', 'Chinese_Taiwan', 'Chuvash', 'Czech', | |
'Dhivehi', 'Dutch', 'English', 'Esperanto', 'Estonian', 'French', 'Frisian', 'Georgian', 'German', 'Greek', 'Hakha_Chin', | |
'Indonesian', 'Interlingua', 'Italian', 'Japanese', 'Kabyle', 'Kinyarwanda', 'Kyrgyz', 'Latvian', 'Maltese', | |
'Mangolian', 'Persian', 'Polish', 'Portuguese', 'Romanian', 'Romansh_Sursilvan', 'Russian', 'Sakha', 'Slovenian', | |
'Spanish', 'Swedish', 'Tamil', 'Tatar', 'Turkish', 'Ukranian', 'Welsh' | |
""" | |

gr.Interface(
    fn=transcribe,
    inputs=[
        gr.inputs.Audio(source="microphone", type="filepath", optional=True),
        gr.inputs.Audio(source="upload", type="filepath", optional=True),
    ],
    outputs=[
        gr.outputs.Textbox(label="Language"),
        gr.Number(label="Probability"),
    ],
    verbose=True,
    examples=examples,
    title="Language Identification from Audio",
    description="Detect the Language from Audio.",
    article=article,
    theme="huggingface",
).launch()