enoreyes committed
Commit 6e77739 (parent: bb4dfd1)

Update app.py

Files changed (1):
  1. app.py (+42, -234)
app.py CHANGED
@@ -1,6 +1,7 @@
  import os
  import re
  import functools

  import requests
  import pandas as pd
@@ -8,75 +9,24 @@ import plotly.express as px

  import torch
  import gradio as gr
- from transformers import pipeline, WhisperProcessor
  from pyannote.audio import Pipeline
- from librosa import load, resample
- from rpunct import RestorePuncts

- from utils import split_into_sentences

  os.environ["TOKENIZERS_PARALLELISM"] = "false"
  device = 0 if torch.cuda.is_available() else -1

- # summarization is done over inference API
- headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
- summarization_url = (
-     "https://api-inference.huggingface.co/models/knkarthick/MEETING_SUMMARY"
- )
-
- # There was an error related to Non-english text being detected,
- # so this regular expression gets rid of any weird character.
- # This might be completely unnecessary.
- eng_pattern = r"[^\d\s\w'\.\,\?]"
-
-
- def summarize(diarized, check):
-     """
-     diarized: a list of tuples. Each tuple has a string to be displayed and a label for highlighting.
-     The start/end times are not highlighted [(speaker text, speaker id), (start time/end time, None)]
-     check is a list of speaker ids whose speech will get summarized
-     """
-
-     if len(check) == 0:
-         return ""
-
-     text = ""
-     for d in diarized:
-         if len(check) == 2 and d[1] is not None:
-             text += f"\n{d[1]}: {d[0]}"
-         elif d[1] in check:
-             text += f"\n{d[0]}"
-
-     # inner function cached because outer function cannot be cached
-     @functools.lru_cache(maxsize=128)
-     def call_summarize_api(text):
-         payload = {
-             "inputs": text,
-             "options": {
-                 "use_gpu": False,
-                 "wait_for_model": True,
-             },
-         }
-         response = requests.post(summarization_url, headers=headers, json=payload)
-         return response.json()[0]["summary_text"]
-
-     return call_summarize_api(text)
-

  # Audio components
- asr_model = "openai/whisper-large"
- processor = WhisperProcessor.from_pretrained(asr_model)
- asr = pipeline(
-     "automatic-speech-recognition",
-     model=asr_model,
-     tokenizer=processor.tokenizer,
-     feature_extractor=processor.feature_extractor,
-     device=device,
-     chunk_length_s=6, # 12
-     stride_length_s=(2, 3), # must have with chunk_length_s
- )
- speaker_segmentation = Pipeline.from_pretrained("pyannote/speaker-segmentation", use_auth_token='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')
- rpunct = RestorePuncts()

  # Text components
  emotion_pipeline = pipeline(
@@ -84,168 +34,31 @@ emotion_pipeline = pipeline(
      model="bhadresh-savani/distilbert-base-uncased-emotion",
      device=device,
  )

- EXAMPLES = [["example_audio.wav"], ["Customer_Support_Call.wav"]]
-
- # display if the sentiment value is above these thresholds
- thresholds = {
-     "joy": 0.99,
-     "anger": 0.95,
-     "surprise": 0.95,
-     "sadness": 0.98,
-     "fear": 0.95,
-     "love": 0.99,
- }
-
-
- def speech_to_text(speech):
-     speaker_output = speaker_segmentation(speech)
-     speech, sampling_rate = load(speech)
-     if sampling_rate != 16000:
-         speech = resample(speech, sampling_rate, 16000)
-     text = asr(speech)
-     print(text)
-     chunks = text["chunks"]
-
-     diarized_output = []
-     i = 0
-     speaker_counter = 0
-
-     # New iteration every time the speaker changes
-     for turn, _, _ in speaker_output.itertracks(yield_label=True):
-
-         speaker = "Customer" if speaker_counter % 2 == 0 else "Support"
-         diarized = ""
-         while i < len(chunks) and chunks[i]["timestamp"][1] <= turn.end:
-             diarized += chunks[i]["text"].lower() + " "
-             i += 1
-
-         if diarized != "":
-             diarized = rpunct.punctuate(re.sub(eng_pattern, "", diarized), lang="en")
-
-             diarized_output.extend(
-                 [
-                     (diarized, speaker),
-                     ("from {:.2f}-{:.2f}".format(turn.start, turn.end), None),
-                 ]
-             )
-
-         speaker_counter += 1
-
-     return diarized_output
-
-
- def sentiment(diarized):
-     """
-     diarized: a list of tuples. Each tuple has a string to be displayed and a label for highlighting.
-     The start/end times are not highlighted [(speaker text, speaker id), (start time/end time, None)]
-
-     This function gets the customer's sentiment and returns a list for highlighted text as well
-     as a plot of sentiment over time.
-     """
-
-     customer_sentiments = []
-
-     to_plot = []
-     plot_sentences = []
-
-     # used to set the x range of ticks on the plot
-     x_min = 100
-     x_max = 0
-
-     for i in range(0, len(diarized), 2):
-         speaker_speech, speaker_id = diarized[i]
-         times, _ = diarized[i + 1]
-
-         sentences = split_into_sentences(speaker_speech)
-         start_time, end_time = times[5:].split("-")
-         start_time, end_time = float(start_time), float(end_time)
-         interval_size = (end_time - start_time) / len(sentences)
-
-         if "Customer" in speaker_id:
-
-             outputs = emotion_pipeline(sentences)
-
-             for idx, (o, t) in enumerate(zip(outputs, sentences)):
-                 sent = "neutral"
-                 if o["score"] > thresholds[o["label"]]:
-                     customer_sentiments.append(
-                         (t + f"({round(idx*interval_size+start_time,1)} s)", o["label"])
-                     )
-                     if o["label"] in {"joy", "love", "surprise"}:
-                         sent = "positive"
-                     elif o["label"] in {"sadness", "anger", "fear"}:
-                         sent = "negative"
-                 if sent != "neutral":
-                     to_plot.append((start_time + idx * interval_size, sent))
-                     plot_sentences.append(t)
-
-         if start_time < x_min:
-             x_min = start_time
-         if end_time > x_max:
-             x_max = end_time
-
-     x_min -= 5
-     x_max += 5
-
-     x, y = list(zip(*to_plot))
-
-     plot_df = pd.DataFrame(
-         data={
-             "x": x,
-             "y": y,
-             "sentence": plot_sentences,
-         }
-     )

-     fig = px.line(
-         plot_df,
-         x="x",
-         y="y",
-         hover_data={
-             "sentence": True,
-             "x": True,
-             "y": False,
-         },
-         labels={"x": "time (seconds)", "y": "sentiment"},
-         title=f"Customer sentiment over time",
-     )

-     fig = fig.update_yaxes(categoryorder="category ascending")
-     fig = fig.update_layout(
-         font=dict(
-             size=18,
-         ),
-         xaxis_range=[x_min, x_max],
  )

-     return customer_sentiments, fig
-
-
- demo = gr.Blocks(enable_queue=True)
- demo.encrypt = False

- # for highlighting purposes
- color_map = {
-     "joy": "green",
-     "anger": "red",
-     "surprise": "yellow",
-     "sadness": "blue",
-     "fear": "orange",
-     "love": "purple",
- }
-
- with demo:
      with gr.Row():
          with gr.Column():
              audio = gr.Audio(label="Audio file", type="filepath")
-             with gr.Row():
-                 btn = gr.Button("Transcribe")
-             with gr.Row():
-                 examples = gr.components.Dataset(
-                     components=[audio], samples=EXAMPLES, type="index"
-                 )
-         with gr.Column():
              gr.Markdown("**Call Transcript:**")
              diarized = gr.HighlightedText(label="Call Transcript")
              gr.Markdown("Choose speaker to summarize:")
@@ -257,31 +70,26 @@ with demo:
              analyzed = gr.HighlightedText(color_map=color_map)
              plot = gr.Plot(label="Sentiment over time", type="plotly")

      # when example button is clicked, convert audio file to text and diarize
      btn.click(
-         speech_to_text,
-         audio,
-         [diarized],
-         status_tracker=gr.StatusTracker(cover_container=True),
      )
      # when summarize checkboxes are changed, create summary
-     check.change(summarize, [diarized, check], summary)

      # when sentiment button clicked, display highlighted text and plot
-     sentiment_btn.click(sentiment, [diarized], [analyzed, plot])
-
-
-     def cache_example(example):
-         diarized_output = speech_to_text(example)
-         return audio, diarized_output

-     cache = [cache_example(e[0]) for e in EXAMPLES]
-
-     def load_example(example_id):
-         return cache[example_id]
-
-     examples._click_no_postprocess(
-         load_example, inputs=[examples], outputs=[audio, diarized], queue=False
-     )

- demo.launch(debug=1)
 
app.py (after change):

  import os
  import re
  import functools
+ from functools import partial

  import requests
  import pandas as pd

  import torch
  import gradio as gr
+ from transformers import pipeline, Wav2Vec2ProcessorWithLM
  from pyannote.audio import Pipeline
+ import whisperx

+ from utils import split_into_sentences, summarize, sentiment, color_map
+ from utils import speech_to_text as stt

  os.environ["TOKENIZERS_PARALLELISM"] = "false"
  device = 0 if torch.cuda.is_available() else -1


  # Audio components
+ whisper_device = "cuda" if torch.cuda.is_available() else "cpu"
+ whisper = whisperx.load_model("tiny.en", whisper_device)
+ alignment_model, metadata = whisperx.load_align_model(language_code="en", device=whisper_device)
+ speaker_segmentation = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",
+                                                  use_auth_token=os.environ['HF_TOKEN'])
+

  # Text components
  emotion_pipeline = pipeline(
      model="bhadresh-savani/distilbert-base-uncased-emotion",
      device=device,
  )
+ summarization_pipeline = pipeline(
+     "summarization",
+     model="knkarthick/MEETING_SUMMARY",
+     device=device
+ )

+ EXAMPLES = [["Customer_Support_Call.wav"]]


+ speech_to_text = partial(
+     stt,
+     speaker_segmentation=speaker_segmentation,
+     whisper=whisper,
+     alignment_model=alignment_model,
+     metadata=metadata,
+     whisper_device=whisper_device
  )

+ with gr.Blocks() as demo:

      with gr.Row():
          with gr.Column():
              audio = gr.Audio(label="Audio file", type="filepath")
+             btn = gr.Button("Transcribe and Diarize")
+
              gr.Markdown("**Call Transcript:**")
              diarized = gr.HighlightedText(label="Call Transcript")
              gr.Markdown("Choose speaker to summarize:")

              analyzed = gr.HighlightedText(color_map=color_map)
              plot = gr.Plot(label="Sentiment over time", type="plotly")

+         with gr.Column():
+             gr.Markdown("## Example Files")
+             gr.Examples(
+                 examples=EXAMPLES,
+                 inputs=[audio],
+                 outputs=[diarized],
+                 fn=speech_to_text,
+                 cache_examples=True
+             )
      # when example button is clicked, convert audio file to text and diarize
      btn.click(
+         fn=speech_to_text,
+         inputs=audio,
+         outputs=diarized,
      )
      # when summarize checkboxes are changed, create summary
+     check.change(fn=partial(summarize, summarization_pipeline=summarization_pipeline), inputs=[diarized, check], outputs=summary)

      # when sentiment button clicked, display highlighted text and plot
+     sentiment_btn.click(fn=partial(sentiment, emotion_pipeline=emotion_pipeline), inputs=diarized, outputs=[analyzed, plot])


+ demo.launch(debug=1)
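
A note on the wiring for readers who only see this file: the heavyweight objects (the whisperx model, the alignment model, the pyannote pipeline, and the two transformers pipelines) are now built once at import time and bound to the utils functions with functools.partial, so the callables handed to gr.Examples, btn.click, check.change, and sentiment_btn.click keep the plain event signatures Gradio expects. The sketch below illustrates that pattern with stand-in functions; the real speech_to_text, summarize, and sentiment live in utils.py, which is not part of this commit, so the parameter names used here are assumptions read off the keyword arguments bound above.

    from functools import partial

    # Stand-ins for the functions this commit imports from utils.py. Their real
    # bodies are in utils.py (not shown in this diff); only the keyword parameters
    # bound below are taken from the call sites in app.py, and the return values
    # here are illustrative placeholders.
    def speech_to_text_impl(audio_path, speaker_segmentation=None, whisper=None,
                            alignment_model=None, metadata=None, whisper_device="cpu"):
        # ...transcribe audio_path, align word timestamps, merge with diarization...
        return [("hello, thanks for calling.", "Customer"), ("from 0.00-2.10", None)]

    def summarize_impl(diarized, check, summarization_pipeline=None):
        # ...concatenate the checked speakers' turns and run the summarizer...
        return "summary of the selected speaker(s)"

    # Heavy objects are created once at startup and frozen into the callables,
    # so Gradio events only ever pass the component values.
    speech_to_text = partial(
        speech_to_text_impl,
        speaker_segmentation="<pyannote pipeline>",
        whisper="<whisperx model>",
        alignment_model="<alignment model>",
        metadata="<alignment metadata>",
        whisper_device="cpu",
    )
    summarize = partial(summarize_impl, summarization_pipeline="<summarization pipeline>")

    print(speech_to_text("Customer_Support_Call.wav"))
    print(summarize([("hello, thanks for calling.", "Customer")], ["Customer"]))

The same pattern applies to sentiment, which gets emotion_pipeline bound in exactly the same way before being passed to sentiment_btn.click.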