enoreyes committed on
Commit
bb4dfd1
1 Parent(s): d85e0c1

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +130 -1
utils.py CHANGED
@@ -1,3 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import re
2
  alphabets= "([A-Za-z])"
3
  prefixes = "(Mr|St|Mrs|Ms|Dr)[.]"
@@ -30,4 +45,118 @@ def split_into_sentences(text):
30
  sentences = text.split("<stop>")
31
  sentences = sentences[:-1]
32
  sentences = [s.strip() for s in sentences]
33
- return sentences
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import re
3
+ import functools
4
+
5
+ import requests
6
+ import pandas as pd
7
+ import plotly.express as px
8
+
9
+ import torch
10
+ import gradio as gr
11
+ from transformers import pipeline, Wav2Vec2ProcessorWithLM
12
+ from pyannote.audio import Pipeline
13
+ from librosa import load, resample
14
+ import whisperx
15
+
16
  import re
17
  alphabets= "([A-Za-z])"
18
  prefixes = "(Mr|St|Mrs|Ms|Dr)[.]"
 
45
  sentences = text.split("<stop>")
46
  sentences = sentences[:-1]
47
  sentences = [s.strip() for s in sentences]
48
+ return sentences
49
+
50
+
51
+ def summarize(diarized, check, summarization_pipeline):
52
+ """
53
+ diarized: a list of tuples. Each tuple has a string to be displayed and a label for highlighting.
54
+ The start/end times are not highlighted [(speaker text, speaker id), (start time/end time, None)]
55
+ check is a list of speaker ids whose speech will get summarized
56
+ """
57
+
58
+ if not check:
59
+ return ""
60
+
61
+ # Combine text based on the speaker id
62
+ text_lines = [f"{d[1]}: {d[0]}" if len(check) == 2 and d[1] is not None else d[0] for d in diarized if d[1] in check]
63
+ text = "\n".join(text_lines)
64
+
65
+ # Cache the inner function because the outer function cannot be cached
66
+ @functools.lru_cache(maxsize=128)
67
+ def call_summarize_api(text):
68
+ return summarization_pipeline(text)[0]["summary_text"]
69
+
70
+ return call_summarize_api(text)
71
+
72
+
73
+ # display if the sentiment value is above these thresholds
74
+ thresholds = {
75
+ "joy": 0.99,
76
+ "anger": 0.95,
77
+ "surprise": 0.95,
78
+ "sadness": 0.98,
79
+ "fear": 0.95,
80
+ "love": 0.99,
81
+ }
82
+
83
+ color_map = {
84
+ "joy": "green",
85
+ "anger": "red",
86
+ "surprise": "yellow",
87
+ "sadness": "blue",
88
+ "fear": "orange",
89
+ "love": "purple",
90
+ }
91
+
92
+
93
+ def sentiment(diarized, emotion_pipeline):
94
+ def split_into_intervals(speaker_speech, start_time, end_time):
95
+ sentences = split_into_sentences(speaker_speech)
96
+ interval_size = (end_time - start_time) / len(sentences)
97
+ return sentences, interval_size
98
+
99
+ def process_customer_emotion(outputs, sentences, start_time, interval_size):
100
+ sentiments = []
101
+ for idx, (o, t) in enumerate(zip(outputs, sentences)):
102
+ sent = "neutral"
103
+ if o["score"] > thresholds[o["label"]]:
104
+ sentiments.append((t + f"({round(idx*interval_size+start_time,1)} s)", o["label"]))
105
+ if o["label"] in {"joy", "love", "surprise"}:
106
+ sent = "positive"
107
+ elif o["label"] in {"sadness", "anger", "fear"}:
108
+ sent = "negative"
109
+ if sent != "neutral":
110
+ to_plot.append((start_time + idx * interval_size, sent))
111
+ plot_sentences.append(t)
112
+ return sentiments
113
+
114
+ x_min = 100
115
+ x_max = 0
116
+
117
+ customer_sentiments, to_plot, plot_sentences = [], [], []
118
+
119
+ for i in range(0, len(diarized), 2):
120
+ speaker_speech, speaker_id = diarized[i]
121
+ times, _ = diarized[i + 1]
122
+ start_time, end_time = map(float, times[5:].split("-"))
123
+ x_min, x_max = min(x_min, start_time), max(x_max, end_time)
124
+
125
+ if "Customer" in speaker_id:
126
+ sentences, interval_size = split_into_intervals(speaker_speech, start_time, end_time)
127
+ outputs = emotion_pipeline(sentences)
128
+ customer_sentiments.extend(process_customer_emotion(outputs, sentences, start_time, interval_size))
129
+
130
+ plot_df = pd.DataFrame(data={"x": [x for x, _ in to_plot], "y": [y for _, y in to_plot], "sentence": plot_sentences})
131
+ fig = px.line(plot_df, x="x", y="y", hover_data={"sentence": True, "x": True, "y": False}, labels={"x": "time (seconds)", "y": "sentiment"}, title=f"Customer sentiment over time", markers=True)
132
+ fig.update_yaxes(categoryorder="category ascending")
133
+ fig.update_layout(font=dict(size=18), xaxis_range=[x_min - 5, x_max + 5])
134
+
135
+ return customer_sentiments, fig
136
+
137
+ def speech_to_text(speech_file, speaker_segmentation, whisper, alignment_model, metadata, whisper_device):
138
+
139
+ def process_chunks(turn, chunks):
140
+ diarized = ""
141
+ i = 0
142
+ while i < len(chunks) and chunks[i]["end"] <= turn.end:
143
+ diarized += chunks[i]["text"] + " "
144
+ i += 1
145
+ return diarized, i
146
+
147
+ speaker_output = speaker_segmentation(speech_file)
148
+ result = whisper.transcribe(speech_file)
149
+ chunks = whisperx.align(result["segments"], alignment_model, metadata, speech_file, whisper_device)["word_segments"]
150
+
151
+ diarized_output = []
152
+ i = 0
153
+ speaker_counter = 0
154
+
155
+ for turn, _, _ in speaker_output.itertracks(yield_label=True):
156
+ speaker = "Customer" if speaker_counter % 2 == 0 else "Support"
157
+ diarized, i = process_chunks(turn, chunks[i:])
158
+ if diarized:
159
+ diarized_output.extend([(diarized, speaker), (f"from {turn.start:.2f}-{turn.end:.2f}", None)])
160
+ speaker_counter += 1
161
+
162
+ return diarized_output