# Hugging Face Space status: Runtime error
import html
import subprocess
import sys

import gradio as gr
import torch
from transformers import pipeline

from charts import spider_chart
from dictionaries import calculate_average, transform_dict
from icon import generate_icon
from timestamp import format_timestamp
from youtube import get_youtube_video_id
# Speech-recognition checkpoint and batch size used by transcribe().
MODEL_NAME = "openai/whisper-medium"
BATCH_SIZE = 8

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = "cpu"

# NOTE(review): the UI description advertises MODEL_NAME1
# ("jpdiazpardo/whisper-tiny-metal"), but this pipeline loads MODEL_NAME —
# confirm which checkpoint is intended.
pipe = pipeline(
    "automatic-speech-recognition",
    model=MODEL_NAME,
    device=device,
    chunk_length_s=30,  # process long audio in 30-second chunks
)
# --- Page text and examples shown by the Gradio interface --------------------
# Checkpoint advertised in the description below.
# NOTE(review): the ASR pipeline loads MODEL_NAME, not MODEL_NAME1 — confirm
# which checkpoint the description should name.
MODEL_NAME1 = "jpdiazpardo/whisper-tiny-metal"
description = (
    "Transcribe long-form audio inputs with the click of a button! Demo uses the"
    f" checkpoint [{MODEL_NAME1}](https://huggingface.co/{MODEL_NAME1}) and 🤗 Transformers to transcribe audio files"
    " of arbitrary length. Check some of the 'cool' examples below"
)

# One row per example, one value per entry of the module-level ``inputs``
# list (YouTube link, download button, thumbnail, audio file,
# return timestamps, sentiment analysis).
examples = [
    ["https://www.youtube.com/watch?v=W72Lnz1n-jw&ab_channel=Whitechapel-Topic", None, None,
     "When a Demon Defiles a Witch.wav", True, True],
    ["https://www.youtube.com/watch?v=BnO3Io0KOl4&ab_channel=MotionlessInWhite-Topic", None, None,
     "Immaculate Misconception.wav", True, True],
]

# generate_icon is defined elsewhere; presumably returns an inline HTML/SVG
# snippet for the named brand — confirm.
linkedin = generate_icon("linkedin")
github = generate_icon("github")

# Footer HTML; the outer <div> is explicitly closed so it cannot swallow
# markup Gradio renders after the article.
article = (
    "<div style='text-align: center; max-width:800px; margin:10px auto;'>"
    f"<p>{linkedin} <a href='https://www.linkedin.com/in/juanpablodiazp/' target='_blank'>Juan Pablo Díaz Pardo</a><br>"
    f"{github} <a href='https://github.com/jpdiazpardo' target='_blank'>jpdiazpardo</a></p>"
    "</div>"
)

title = "Scream: Fine-Tuned Whisper model for automatic guttural speech recognition 🤟🤟🤟"
# -----------------------------------------------------------------------------
# Emotion classifier used by transcribe() for sentiment analysis.
# top_k=None makes the pipeline return a score for every label rather than
# only the highest-scoring one.
classifier = pipeline(
    task="text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    top_k=None,
)
def transcribe(*args):
    """Transcribe an audio file and score the transcription's emotions.

    Gradio calls this with one positional argument per component in the
    module-level ``inputs`` list, in order:

        0. YouTube link (unused here)
        1. Download button value (unused here)
        2. Thumbnail HTML (unused here)
        3. Path of the uploaded audio file
        4. "Return timestamps" checkbox value
        5. "Sentiment analysis" checkbox value (treated as True if absent)

    Returns:
        A 4-tuple matching the module-level ``outputs`` list:
        (audio file path, transcription HTML, spider-chart figure or None,
        averaged emotion scores or None). The last two are None when
        sentiment analysis is disabled or the transcript has no text.

    Raises:
        gr.Error: if no audio file was provided.
    """
    file = args[3]
    return_timestamps = args[4]
    run_sentiment = args[5] if len(args) > 5 else True

    if file is None:
        # Nothing to transcribe (e.g. the user only pasted a YouTube link).
        raise gr.Error("Please upload an audio file to transcribe.")

    # Timestamps are always requested so per-chunk text is available for
    # both the HTML view and sentiment analysis.
    outputs = pipe(
        file,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": "transcribe"},
        return_timestamps=True,
    )
    chunks = outputs["chunks"]

    # Plain chunk text, reused for sentiment analysis below.
    chunk_texts = [str(chunk["text"]) for chunk in chunks]
    # The transcript is derived from user-supplied audio, so escape it before
    # embedding it in the HTML output.
    safe_texts = [html.escape(text, quote=False) for text in chunk_texts]

    if return_timestamps:
        # NOTE(review): the last chunk's end timestamp may be None for some
        # Whisper outputs — confirm format_timestamp handles None.
        lines = [
            f"[{format_timestamp(chunk['timestamp'][0])} -> "
            f"{format_timestamp(chunk['timestamp'][1])}] {safe}"
            for chunk, safe in zip(chunks, safe_texts)
        ]
    else:
        lines = safe_texts

    body = "<br>".join(lines)
    text = f"<h4>Transcription</h4><div style='overflow-y: scroll; height: 150px;'>{body}</div>"

    # NOTE(review): audio_out is labelled "Processed Audio" / "Vocals only",
    # but the input file is returned unchanged.
    if not run_sentiment:
        return file, text, None, None

    # Skip whitespace-only chunks; classifying empty text is meaningless, and
    # calculate_average (defined elsewhere) is not called with an empty list.
    sentiment_texts = [t for t in chunk_texts if t.strip()]
    if not sentiment_texts:
        return file, text, None, None

    trans_dict = [transform_dict(classifier.predict(t)[0]) for t in sentiment_texts]
    av_dict = calculate_average(trans_dict)
    fig = spider_chart(av_dict)
    return file, text, fig, av_dict
def filter(choice):
    """Toggle which audio source the user can edit.

    ``choice`` "YouTube" enables the link box and locks the uploader;
    "Upload File" clears and locks the link box and enables the uploader;
    any other value locks both.

    Returns:
        A pair of Gradio updates for (``yt_link``, ``audio_input``).
    """
    # NOTE(review): this name shadows the built-in ``filter`` and no caller is
    # visible in this file; kept unchanged so any external wiring still resolves.
    if choice == "YouTube":
        link_update = yt_link.update(interactive=True)
        audio_update = audio_input.update(interactive=False)
    elif choice == "Upload File":
        link_update = yt_link.update(value=None, interactive=False)
        audio_update = audio_input.update(interactive=True)
    else:
        link_update = yt_link.update(interactive=False)
        audio_update = audio_input.update(interactive=False)
    return link_update, audio_update
# HTML template for the YouTube player; download() replaces YOUTUBE_ID with
# the real video id before showing it in the thumbnail component. The src
# attribute is closed right after the placeholder so the id substitution
# yields a valid iframe tag.
embed_html = (
    '<iframe src="https://www.youtube.com/embed/YOUTUBE_ID" '
    'title="YouTube video player" frameborder="0" allow="accelerometer; '
    'autoplay; clipboard-write; encrypted-media; gyroscope; '
    'picture-in-picture" allowfullscreen></iframe>'
)
def download(link):
    """Fetch a YouTube video's audio and show an embedded player for it.

    Runs the ``youtubetowav.py`` helper script on ``link``, then returns an
    update that makes ``thumbnail`` visible with the video embedded.

    Args:
        link: YouTube URL from the ``yt_link`` textbox.

    Returns:
        ``thumbnail.update(...)`` carrying the iframe HTML for the video.

    Raises:
        gr.Error: if ``link`` is empty, no video id can be extracted from it,
            or the helper script exits with a non-zero status.
    """
    if not link:
        raise gr.Error("Please paste a YouTube link first.")

    # get_youtube_video_id is defined elsewhere; presumably returns a falsy
    # value (e.g. None) for URLs it cannot parse — confirm. Validating first
    # avoids running the download for an unusable link.
    video_id = get_youtube_video_id(link)
    if not video_id:
        raise gr.Error("Could not find a YouTube video id in that link.")

    # sys.executable runs the helper with the same interpreter as the app; the
    # argument list (no shell) keeps the user-supplied URL out of a shell.
    result = subprocess.run([sys.executable, "youtubetowav.py", link], check=False)
    if result.returncode != 0:
        raise gr.Error("Downloading the YouTube audio failed.")

    return thumbnail.update(value=embed_html.replace("YOUTUBE_ID", video_id), visible=True)
def hide_sa(value):
    """Show or hide the sentiment-analysis outputs.

    Wired to ``sa_checkbox.change``; ``value`` is the checkbox state.

    Returns:
        A pair of Gradio updates for (``sa_plot``, ``sa_frequency``) setting
        both components' visibility to ``bool(value)``.
    """
    visible = bool(value)
    return sa_plot.update(visible=visible), sa_frequency.update(visible=visible)
# --- Input components --------------------------------------------------------
# Created outside the Blocks context; the gr.Interface built at the end of the
# file places them on the page in the order of ``inputs``.
yt_link = gr.Textbox(
    value=None,
    label="YouTube link",
    info="Optional: Copy and paste YouTube URL",
)
audio_input = gr.Audio(
    source="upload",
    type="filepath",  # transcribe() receives a path on disk
    label="Upload audio file for transcription",
)
download_button = gr.Button("Download")
thumbnail = gr.HTML(value=embed_html, visible=False)  # revealed by download()
sa_checkbox = gr.Checkbox(value=True, label="Sentiment analysis")
timestamps_checkbox = gr.Checkbox(value=True, label="Return timestamps")

# Positional order matters: transcribe() reads its arguments by index.
inputs = [
    yt_link,              # 0
    download_button,      # 1
    thumbnail,            # 2
    audio_input,          # 3
    timestamps_checkbox,  # 4
    sa_checkbox,          # 5
]
# --- Output components -------------------------------------------------------
# Order matches the 4-tuple returned by transcribe().
audio_out = gr.Audio(label="Processed Audio", type="filepath", info="Vocals only")
sa_plot = gr.Plot(label="Sentiment Analysis")
sa_frequency = gr.Label(label="Frequency")
# gr.outputs.HTML is the deprecated legacy API (removed in Gradio 4). Its first
# positional argument is the label, so gr.HTML(label="text") is the equivalent
# current component.
outputs = [audio_out, gr.HTML(label="text"), sa_plot, sa_frequency]
# --- App layout and launch ---------------------------------------------------
with gr.Blocks() as demo:
    # Extra events for components that the Interface below renders.
    download_button.click(fn=download, inputs=[yt_link], outputs=[thumbnail])
    sa_checkbox.change(fn=hide_sa, inputs=[sa_checkbox], outputs=[sa_plot, sa_frequency])
    with gr.Column():
        gr.Interface(
            fn=transcribe,
            inputs=inputs,
            outputs=outputs,
            title=title,
            description=description,
            article=article,
            examples=examples,
            cache_examples=True,  # Gradio precomputes example outputs via transcribe()
            allow_flagging="never",
        )

demo.queue(concurrency_count=3)
demo.launch(debug=True)