import os
import re
import functools
from functools import partial

import requests
import pandas as pd
import plotly.express as px
import torch
import gradio as gr
from transformers import pipeline, Wav2Vec2ProcessorWithLM
from pyannote.audio import Pipeline
import whisperx

from utils import split, create_fig, color_map, thresholds
from utils import speech_to_text as stt

os.environ["TOKENIZERS_PARALLELISM"] = "false"
# transformers pipelines take a GPU index (-1 means CPU)
device = 0 if torch.cuda.is_available() else -1

# Audio components
whisper_device = "cuda" if torch.cuda.is_available() else "cpu"
whisper = whisperx.load_model("tiny.en", whisper_device)
alignment_model, metadata = whisperx.load_align_model(language_code="en", device=whisper_device)
speaker_segmentation = Pipeline.from_pretrained(
    "pyannote/speaker-diarization@2.1",
    use_auth_token=os.environ["ENO_TOKEN"],
)
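# pyannote/speaker-diarization is a gated model: this assumes the ENO_TOKEN
# environment variable (a Space secret) holds a Hugging Face access token for
# an account that has accepted the model's user conditions.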
# Text components
emotion_pipeline = pipeline(
    "text-classification",
    model="bhadresh-savani/distilbert-base-uncased-emotion",
    device=device,
)
summarization_pipeline = pipeline(
    "summarization",
    model="knkarthick/MEETING_SUMMARY",
    device=device,
)
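# Note: bhadresh-savani/distilbert-base-uncased-emotion predicts six labels
# (sadness, joy, love, anger, fear, surprise); given a list of sentences,
# emotion_pipeline returns one {"label": ..., "score": ...} dict per sentence.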
EXAMPLES = [["Customer_Support_Call.wav"]]

speech_to_text = partial(
    stt,
    speaker_segmentation=speaker_segmentation,
    whisper=whisper,
    alignment_model=alignment_model,
    metadata=metadata,
    whisper_device=whisper_device,
)
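# speech_to_text (defined in utils) is expected to return alternating
# (text, label) tuples for gr.HighlightedText. A hypothetical example,
# with the shape inferred from the docstrings and parsing below:
#   [("Hello, how can I help you today?", "Support"),
#    ("from 0.0-3.5", None),
#    ("My order never arrived.", "Customer"),
#    ("from 3.5-7.2", None)]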
def summarize(diarized, check, summarization_pipeline):
    """
    diarized: a list of tuples. Each tuple has a string to be displayed and a label
    for highlighting. The start/end times are not highlighted:
    [(speaker text, speaker id), (start time/end time, None), ...]
    check: a list of speaker ids whose speech will be summarized
    """
    if len(check) == 0:
        return ""

    text = ""
    for d in diarized:
        # include speaker labels when both speakers are selected
        if len(check) == 2 and d[1] is not None:
            text += f"\n{d[1]}: {d[0]}"
        elif d[1] in check:
            text += f"\n{d[0]}"

    # cache on an inner, string-only function: the outer function cannot be
    # lru_cached because its pipeline argument is unhashable
    @functools.lru_cache(maxsize=128)
    def call_summarize_api(text):
        return summarization_pipeline(text)[0]["summary_text"]

    return call_summarize_api(text)
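# Hypothetical usage, assuming the diarized format sketched above:
#   summarize([("My order never arrived.", "Customer"),
#              ("from 0.0-4.0", None)], ["Customer"], summarization_pipeline)
#   returns the MEETING_SUMMARY model's short summary string.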
def sentiment(diarized, emotion_pipeline):
    """
    diarized: a list of tuples. Each tuple has a string to be displayed and a label
    for highlighting. The start/end times are not highlighted:
    [(speaker text, speaker id), (start time/end time, None), ...]

    This function classifies the customer's sentences and returns a list of
    (text, label) tuples for highlighted text.
    """
    customer_sentiments = []

    # diarized alternates (speech, speaker id) and (times, None) entries,
    # so walk it two entries at a time
    for i in range(0, len(diarized), 2):
        speaker_speech, speaker_id = diarized[i]
        times, _ = diarized[i + 1]
        # drop the 5-character "from " prefix and parse the start/end seconds
        start_time, end_time = times[5:].split("-")
        start_time, end_time = float(start_time), float(end_time)
        sentences = split(speaker_speech)
        # approximate per-sentence timestamps by spacing the sentences evenly
        # across the speaker's segment
        interval_size = (end_time - start_time) / len(sentences)

        if "Customer" in speaker_id:
            outputs = emotion_pipeline(sentences)
            for idx, (o, t) in enumerate(zip(outputs, sentences)):
                # keep only labels whose score clears the per-emotion threshold
                if o["score"] > thresholds[o["label"]]:
                    customer_sentiments.append(
                        (t + f" ({round(idx * interval_size + start_time, 1)} s)", o["label"])
                    )

    return customer_sentiments
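# Hypothetical usage, assuming "sadness" clears its threshold in utils:
#   sentiment([("My order never arrived.", "Customer"),
#              ("from 0.0-4.0", None)], emotion_pipeline)
#   could return [("My order never arrived. (0.0 s)", "sadness")]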
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Audio file", type="filepath")
            btn = gr.Button("Transcribe and Diarize")
            gr.Markdown("**Call Transcript:**")
            diarized = gr.HighlightedText(label="Call Transcript")
            gr.Markdown("Choose speaker to summarize:")
            check = gr.CheckboxGroup(
                choices=["Customer", "Support"], show_label=False, type="value"
            )
            summary = gr.Textbox(lines=4)
            sentiment_btn = gr.Button("Get Customer Sentiment")
            analyzed = gr.HighlightedText(color_map=color_map)
        with gr.Column():
            gr.Markdown("## Example Files")
            gr.Examples(
                examples=EXAMPLES,
                inputs=[audio],
                outputs=[diarized],
                fn=speech_to_text,
                cache_examples=True,
            )
    # when the transcribe button is clicked, convert the audio file to text and diarize
    btn.click(
        fn=speech_to_text,
        inputs=audio,
        outputs=diarized,
    )
    # when the summarize checkboxes change, summarize the selected speakers' speech
    check.change(
        fn=partial(summarize, summarization_pipeline=summarization_pipeline),
        inputs=[diarized, check],
        outputs=summary,
    )
    # when the sentiment button is clicked, display the highlighted sentiment text
    sentiment_btn.click(
        fn=partial(sentiment, emotion_pipeline=emotion_pipeline),
        inputs=diarized,
        outputs=[analyzed],
    )

demo.launch(debug=True)