File size: 6,289 Bytes
a2c2b3c
b8b135b
039ef51
c2110e8
a2c2b3c
0adbe8f
a2c2b3c
 
fbf8e92
 
 
 
 
a2c2b3c
e559d03
800c39a
 
e559d03
 
 
a2c2b3c
e559d03
 
 
 
a2c2b3c
e559d03
 
a2c2b3c
febac70
 
e14dab8
febac70
 
 
 
 
 
 
 
 
 
 
 
55a586c
 
 
a2c2b3c
238a413
febac70
03978da
e559d03
77fc3c3
accb4e2
55a586c
03978da
800c39a
77fc3c3
bba23d3
77fc3c3
 
800c39a
c2110e8
accb4e2
febac70
55a586c
800c39a
 
 
 
55a586c
800c39a
14c0ec2
febac70
 
 
 
 
 
0d49d91
 
febac70
f15fab7
febac70
800c39a
 
 
 
 
a2c2b3c
 
 
 
 
febac70
238a413
 
 
0d49d91
 
800c39a
febac70
238a413
 
 
febac70
 
 
800c39a
 
febac70
6a22ca5
febac70
a2c2b3c
 
 
 
febac70
238a413
800c39a
febac70
 
 
 
0d49d91
efd7dc3
2994d17
febac70
 
a68a0c7
a2c2b3c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#Imports-------------------------------------------------------------
import gradio as gr
import subprocess
import torch
from transformers import pipeline
import os

#User defined functions (UDF)
from functions.charts import spider_chart
from functions.dictionaries import calculate_average, transform_dict
from functions.icon import generate_icon
from functions.timestamp import format_timestamp
from functions.youtube import get_youtube_video_id
#---------------------------------------------------------------------

# Hugging Face checkpoint loaded by the speech-recognition pipeline below.
MODEL_NAME = "openai/whisper-medium"
#MODEL_NAME = "jpdiazpardo/whisper-tiny-metal"
# Batch size forwarded to the pipeline on every transcribe() call.
BATCH_SIZE = 8
# CUDA device index 0 when a GPU is available, otherwise the CPU.
device = 0 if torch.cuda.is_available() else "cpu"

#Transformers pipeline
# Built once at module import; audio is processed in 30-second chunks
# (chunk_length_s=30).
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=30,
    device=device
)

#Formatting--------------------------------------------------------------------------------------------
# Title passed to gr.Interface. An earlier placeholder value used to be assigned
# first and was immediately overwritten, so only the effective value is kept.
title = "Scream: Fine-Tuned Whisper model for automatic gutural speech recognition 🤟🤟🤟"

# Description passed to gr.Interface; it links to the checkpoint named by MODEL_NAME.
description = ("Transcribe long-form audio inputs with the click of a button! Demo uses the"
        f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
        " of arbitrary length. Check some of the 'cool' examples below")

# Icon snippets returned by the project helper generate_icon().
linkedin = generate_icon("linkedin")
github = generate_icon("github")
# Author footer passed to gr.Interface as `article`. The wrapping <div> is
# closed explicitly so the emitted HTML is well formed.
article = ("<div style='text-align: center; max-width:800px; margin:10px auto;'>"
            f"<p>{linkedin} <a href='https://www.linkedin.com/in/juanpablodiazp/' target='_blank'>Juan Pablo Díaz Pardo</a><br>"
            f"{github} <a href='https://github.com/jpdiazpardo' target='_blank'>jpdiazpardo</a></p>"
            "</div>")

#-------------------------------------------------------------------------------------------------------------------------------

#Define classifier for sentiment analysis
# Text-classification pipeline used by transcribe() to score each transcript chunk.
# NOTE(review): top_k=None presumably makes the pipeline return a score for every
# label instead of only the best one - confirm against the transformers docs.
classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", top_k=None)

#Functions-----------------------------------------------------------------------------------------------------------------------
def transcribe(file,use_timestamps=True,sentiment_analysis=True):
    '''Transcribe an audio file and optionally score the emotions of the transcript.

    inputs:
        file: audio input handed to the speech-recognition pipeline (the UI
            passes a file path).
        use_timestamps: when true, every transcript line is prefixed with
            "[start -> end]" built from the chunk timestamps.
        sentiment_analysis: when true (default), every chunk is classified and
            the averaged scores are plotted. When false the classifier is
            skipped and None is returned for both the figure and the scores.

    returns:
        (file, html, fig, av_dict) - the unchanged input, the transcript as an
        HTML snippet, the spider chart and the averaged score dictionary.
    '''
    outputs = pipe(file, batch_size=BATCH_SIZE, generate_kwargs={"task": 'transcribe'}, return_timestamps=True)
    chunks = outputs["chunks"]

    # Plain chunk texts feed the classifier; the display lines may additionally
    # carry the timestamp prefix.
    chunk_texts = [f"{chunk['text']}" for chunk in chunks]
    if use_timestamps:
        lines = [
            f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
            for chunk in chunks
        ]
    else:
        lines = chunk_texts

    text = "<br>".join(lines)
    text = f"<h4>Transcription</h4><div style='overflow-y: scroll; height: 150px;'>{text}</div>"

    # Honour the "Sentiment analysis" checkbox: previously the flag was accepted
    # but ignored, so the classifier ran on every request.
    if not sentiment_analysis:
        return file, text, None, None

    # Classify each chunk on its own. An empty transcription still yields a
    # single (empty) sample so calculate_average() never receives an empty list.
    trans_dict = [transform_dict(classifier.predict(t)[0]) for t in (chunk_texts or [""])]
    av_dict = calculate_average(trans_dict)
    fig = spider_chart(av_dict)

    return file, text, fig, av_dict

# HTML template for the embedded YouTube player; "YOUTUBE_ID" is a placeholder
# that download() replaces with the real video id.
# The src attribute is closed with a quote and followed by a space: without
# them `title=` was fused into the URL and the iframe markup was broken.
embed_html = (
    '<iframe src="https://www.youtube.com/embed/YOUTUBE_ID" '
    'title="YouTube video player" frameborder="0" allow="accelerometer;'
    'autoplay; clipboard-write; encrypted-media; gyroscope;'
    'picture-in-picture" allowfullscreen></iframe>'
)

def download(link):
    '''Run youtubetowav.py for a YouTube link and return the embedded player HTML.

    inputs: link from textbox
    returns: a Gradio update whose value is embed_html with the "YOUTUBE_ID"
        placeholder replaced by the id extracted from the link.
    '''
    # Arguments are passed as a list (no shell), so the user-supplied link is
    # never interpreted by a shell.
    subprocess.run(['python3', 'youtubetowav.py', link])
    # gr.update() applies to whichever component is wired as the event output.
    # The previous thumbnail.update() raised NameError because the module-level
    # `thumbnail` component is commented out.
    return gr.update(value=embed_html.replace("YOUTUBE_ID",get_youtube_video_id(link)))

def hide_sentiment(value):
    '''Show or hide both sentiment outputs according to the checkbox value.

    inputs: value of the sentiment checkbox
    returns: update for sentiment_plot and update for sentiment_frequency,
        both visible only when value equals True.
    '''
    show = bool(value == True)
    plot_update = sentiment_plot.update(visible=show)
    frequency_update = sentiment_frequency.update(visible=show)
    return plot_update, frequency_update

#----------------------------------------------------------------------------------------------------------------------------------------------
      
#Components------------------------------------------------------------------------------------------------------------------------------------      
      
#Input components        
#yt_link = gr.Textbox(value=None,label="YouTube link", info = "Optional: Copy and paste YouTube URL") #0
#download_button = gr.Button(value="Download") #1
#thumbnail = gr.HTML(value="", label = "Thumbnail") #2
# Audio upload; type="filepath" means the callback receives a path, not raw samples.
audio_input = gr.Audio(source="upload", type="filepath", label="Upload audio file for transcription")
# Feeds the use_timestamps argument of transcribe().
timestamp_checkbox = gr.Checkbox(value=True, label="Return timestamps")
# Feeds the sentiment_analysis argument of transcribe(); also drives hide_sentiment().
sentiment_checkbox = gr.Checkbox(value=True, label="Sentiment analysis")

# Order must match the positional parameters of transcribe().
inputs = [audio_input, # file
          timestamp_checkbox, # use_timestamps
          sentiment_checkbox] # sentiment_analysis

#Output components
# NOTE(review): transcribe() returns the input file unchanged, so "Vocals only"
# is not produced by any code visible here - confirm the intended behaviour.
audio_out = gr.Audio(label="Processed Audio", type="filepath", info = "Vocals only")
sentiment_plot = gr.Plot(label="Sentiment Analysis")
sentiment_frequency = gr.Label(label="Frequency")

# Order must match the tuple returned by transcribe(): (file, text, fig, av_dict).
outputs = [audio_out, gr.outputs.HTML("text"), sentiment_plot, sentiment_frequency]

#----------------------------------------------------------------------------------------------------------------------------------------------------

#Launch demo-----------------------------------------------------------------------------------------------------------------------------------------

# Top-level app: a Blocks container that wires the checkbox event and hosts the
# transcription Interface.
with gr.Blocks() as demo:
  #download_button.click(download, inputs=[yt_link], outputs=[thumbnail])
  # Whenever the sentiment checkbox changes, hide_sentiment() updates the
  # visibility of the plot and the label outputs.
  sentiment_checkbox.change(hide_sentiment, inputs=[sentiment_checkbox], outputs=[sentiment_plot, sentiment_frequency])


  with gr.Column():
      # Interface built from the module-level inputs/outputs lists and transcribe().
      # NOTE(review): examples='examples' presumably points at a directory of
      # example inputs next to this file - confirm it exists in the deployment.
      gr.Interface(title = title, fn=transcribe, inputs = inputs, outputs = outputs, 
                   description=description, cache_examples=True, allow_flagging="never", article = article,
                   examples='examples')


# Enable the request queue (concurrency_count=3) before starting the server.
demo.queue(concurrency_count=3)
demo.launch(debug = True)
#----------------------------------------------------------------------------------------------------------------------------------------------------