# lang_id_testing / app.py
# Commit 01368ce by barto17: "Update app.py" (6.97 kB)
#References — Code:
#Whisper 1: https://huggingface.co/spaces/chinhon/whisper_transcribe/blob/main/app.py
#Whisper 2: https://huggingface.co/spaces/sanchit-gandhi/whisper-language-id/blob/main/app.py
#Roberta: https://huggingface.co/spaces/ivanlau/language-detection-xlm-roberta-base/blob/main/app.py
import functools
import os
import tempfile

import gradio as gr
import torch
from pydub import AudioSegment
from pytube import YouTube
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.pipelines.audio_utils import ffmpeg_read
#from transformers import WhisperForConditionalGeneration, WhisperProcessor
#from transformers.models.whisper.tokenization_whisper import LANGUAGES
#from transformers.pipelines.audio_utils import ffmpeg_read
# Identifier of the Whisper checkpoint on the Hugging Face Hub.
model_id = "openai/whisper-large-v2"

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Display names of the supported languages. The position of each name is the
# class index that detect_language() looks up, so the order must not change.
_LANGUAGE_NAMES = (
    'Arabic',
    'Basque',
    'Breton',
    'Catalan',
    'Chinese_China',
    'Chinese_Hongkong',
    'Chinese_Taiwan',
    'Chuvash',
    'Czech',
    'Dhivehi',
    'Dutch',
    'English',
    'Esperanto',
    'Estonian',
    'French',
    'Frisian',
    'Georgian',
    'German',
    'Greek',
    'Hakha_Chin',
    'Indonesian',
    'Interlingua',
    'Italian',
    'Japanese',
    'Kabyle',
    'Kinyarwanda',
    'Kyrgyz',
    'Latvian',
    'Maltese',
    'Mongolian',
    'Persian',
    'Polish',
    'Portuguese',
    'Romanian',
    'Romansh_Sursilvan',
    'Russian',
    'Sakha',
    'Slovenian',
    'Spanish',
    'Swedish',
    'Tamil',
    'Tatar',
    'Turkish',
    'Ukrainian',  # was misspelled 'Ukranian'
    'Welsh',
)

# Maps a predicted class index (0..44) to its language name. Built with
# enumerate() so the keys can never drift out of step with the names.
# NOTE: the constant's own (misspelled) name is kept because other code in
# this file refers to it.
LANGUANGE_MAP = dict(enumerate(_LANGUAGE_NAMES))
import whisper
# define function for transcription
@functools.lru_cache(maxsize=None)
def _load_whisper_model():
    """Load the Whisper "base" model once and reuse it for every request."""
    return whisper.load_model("base")


def _transcribe_file(path):
    """Return the Whisper transcript text for the audio file at *path*."""
    # language=None lets Whisper detect the spoken language by itself.
    result = _load_whisper_model().transcribe(path, language=None)
    return result["text"]


def _transcribe_youtube(url, max_ms=30000):
    """Download a YouTube audio stream and transcribe its first *max_ms* ms."""
    stream = YouTube(url).streams.filter(only_audio=True).first()
    if stream is None:
        raise ValueError(f"No audio-only stream is available for {url!r}")
    # A private temporary directory keeps concurrent requests from
    # overwriting each other's files and removes them when we are done.
    with tempfile.TemporaryDirectory() as workdir:
        downloaded = stream.download(output_path=workdir, filename="source_audio")
        # AudioSegment slices are expressed in milliseconds.
        clip = AudioSegment.from_file(downloaded, format="mp4")[:max_ms]
        clip_path = os.path.join(workdir, "clip.mp3")
        clip.export(clip_path, format="mp3")
        # Transcribe before leaving the block: the clip is deleted on exit.
        return _transcribe_file(clip_path)


def transcribe(Microphone, File_Upload, URL):
    """Transcribe one audio source and identify the language of the text.

    Exactly one source is used, chosen in this order of precedence:
    microphone recording, YouTube URL, uploaded file. Any other source that
    was also supplied is ignored.

    Args:
        Microphone: Path of the recorded audio file, or None.
        File_Upload: Path of the uploaded audio file, or None.
        URL: YouTube link; only the first 30 seconds of audio are used.
            Empty, blank or None means "no link".

    Returns:
        The result of detect_language() for the transcript: a tuple
        (transcript, language name, probability).

    Raises:
        ValueError: If no source was supplied, or the YouTube video exposes
            no audio-only stream.
    """
    url = (URL or "").strip()

    if Microphone is not None:
        text = _transcribe_file(Microphone)
    elif url:
        text = _transcribe_youtube(url)
    elif File_Upload is not None:
        text = _transcribe_file(File_Upload)
    else:
        # Fail with a clear message instead of handing None to Whisper.
        raise ValueError(
            "You have to use the microphone, upload an audio file "
            "or paste a YouTube link"
        )

    return detect_language(text)
@functools.lru_cache(maxsize=None)
def _load_language_classifier():
    """Load the language-ID tokenizer and classifier once per process."""
    model_ckpt = "barto17/language-detection-fine-tuned-on-xlm-roberta-base"
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
    model = AutoModelForSequenceClassification.from_pretrained(model_ckpt)
    model.eval()
    return tokenizer, model


def detect_language(sentence):
    """Classify the language of *sentence*.

    Args:
        sentence: Text to classify.

    Returns:
        A tuple (sentence, language, probability): the unchanged input text,
        the LANGUANGE_MAP name of the highest-scoring class, and that class's
        softmax probability as a float.
    """
    tokenizer, model = _load_language_classifier()
    # truncation=True keeps long transcripts within the model's maximum
    # sequence length instead of failing inside the forward pass.
    tokenized_sentence = tokenizer(sentence, return_tensors='pt', truncation=True)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(**tokenized_sentence)
    predictions = torch.nn.functional.softmax(output.logits, dim=-1)
    probability, pred_idx = torch.max(predictions, dim=-1)
    language = LANGUANGE_MAP[pred_idx.item()]
    return sentence, language, probability.item()
"""
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id)
model.eval()
model.to(device)
bos_token_id = processor.tokenizer.all_special_ids[-106]
decoder_input_ids = torch.tensor([bos_token_id]).to(device)
def process_audio_file(file, sampling_rate):
with open(file, "rb") as f:
inputs = f.read()
audio = ffmpeg_read(inputs, sampling_rate)
print(audio)
return audio
def transcribe(Microphone, File_Upload):
warn_output = ""
if (Microphone is not None) and (File_Upload is not None):
warn_output = "WARNING: You've uploaded an audio file and used the microphone. " \
"The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
file = Microphone
elif (Microphone is None) and (File_Upload is None):
return "ERROR: You have to either use the microphone or upload an audio file"
elif Microphone is not None:
file = Microphone
else:
file = File_Upload
sampling_rate = processor.feature_extractor.sampling_rate
audio_data = process_audio_file(file, sampling_rate)
input_features = processor(audio_data, return_tensors="pt").input_features
with torch.no_grad():
logits = model.forward(input_features.to(device), decoder_input_ids=decoder_input_ids).logits
pred_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(pred_ids[0])
language, probability = detect_language(transcription)
return transcription.capitalize(), language, probability
"""
# Sample clips shipped next to this script; every example row is a
# single-element list holding the clip's relative path.
examples = [
    [f"./{clip}"]
    for clip in ('sample1.mp3', 'sample2.mp3', 'sample3.mp3')
]

# Stand-alone label component. The gr.Interface call further down passes its
# own list of output components rather than this one.
outputs = gr.outputs.Label(label="Language detected:")
# Descriptive text handed to gr.Interface as its `article` argument. The
# spellings 'Mongolian' and 'Ukrainian' were corrected (previously 'Mangolian'
# and 'Ukranian').
article = """
Fine-tuned on xlm-roberta-base model.\n
Supported languages:\n
'Arabic', 'Basque', 'Breton', 'Catalan', 'Chinese_China', 'Chinese_Hongkong', 'Chinese_Taiwan', 'Chuvash', 'Czech',
'Dhivehi', 'Dutch', 'English', 'Esperanto', 'Estonian', 'French', 'Frisian', 'Georgian', 'German', 'Greek', 'Hakha_Chin',
'Indonesian', 'Interlingua', 'Italian', 'Japanese', 'Kabyle', 'Kinyarwanda', 'Kyrgyz', 'Latvian', 'Maltese',
'Mongolian', 'Persian', 'Polish', 'Portuguese', 'Romanian', 'Romansh_Sursilvan', 'Russian', 'Sakha', 'Slovenian',
'Spanish', 'Swedish', 'Tamil', 'Tatar', 'Turkish', 'Ukrainian', 'Welsh'
"""
# Assemble the web interface around transcribe() and start serving it as soon
# as this module is executed.
gr.Interface(
    title="Language Identification from Audio",
    description="Detect the Language from Audio.",
    article=article,
    fn=transcribe,
    # One component per transcribe() parameter, in the same order:
    # Microphone, File_Upload, URL.
    inputs=[
        gr.inputs.Audio(source="microphone", type='filepath', optional=True),
        gr.inputs.Audio(source="upload", type='filepath', optional=True),
        gr.Textbox(label="Paste YouTube link here [For computation purposes: only first 30 seconds are evaluated]"),
    ],
    # One component per element of the tuple that transcribe() returns.
    outputs=[
        gr.outputs.Textbox(label="Transcription"),
        gr.outputs.Textbox(label="Language"),
        gr.Number(label="Probability"),
    ],
    examples=examples,
    theme="huggingface",
    verbose=True,
).launch()