enoreyes and ktangri committed
Commit 20ec090
0 parent(s)

Duplicate from huggingface/call-sentiment-demo

Co-authored-by: Kunal Tangri <ktangri@users.noreply.huggingface.co>

Files changed (9)
  1. .gitattributes +31 -0
  2. Customer_Support_Call.wav +3 -0
  3. README.md +13 -0
  4. app.py +286 -0
  5. example_audio.wav +3 -0
  6. packages.txt +2 -0
  7. requirements.txt +10 -0
  8. short-take-1.wav +3 -0
  9. utils.py +33 -0
.gitattributes ADDED
@@ -0,0 +1,31 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ example_audio.wav filter=lfs diff=lfs merge=lfs -text
+ short-take-1.wav filter=lfs diff=lfs merge=lfs -text
+ Customer_Support_Call.wav filter=lfs diff=lfs merge=lfs -text
Customer_Support_Call.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:db6489658bb04f84503531d628a67028de9d754ee0b18cf229f39deec7828001
+ size 31497612
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Call Sentiment Blocks 2
+ emoji: 🐠
+ colorFrom: blue
+ colorTo: green
+ sdk: gradio
+ sdk_version: 2.9b23
+ app_file: app.py
+ pinned: false
+ duplicated_from: huggingface/call-sentiment-demo
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py ADDED
@@ -0,0 +1,286 @@
+ import os
+ import re
+ import functools
+
+ import requests
+ import pandas as pd
+ import plotly.express as px
+
+ import torch
+ import gradio as gr
+ from transformers import pipeline, Wav2Vec2ProcessorWithLM
+ from pyannote.audio import Pipeline
+ from librosa import load, resample
+ from rpunct import RestorePuncts
+
+ from utils import split_into_sentences
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ device = 0 if torch.cuda.is_available() else -1
+
+ # summarization is done over inference API
+ headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
+ summarization_url = (
+     "https://api-inference.huggingface.co/models/knkarthick/MEETING_SUMMARY"
+ )
+
+ # There was an error related to Non-english text being detected,
+ # so this regular expression gets rid of any weird character.
+ # This might be completely unnecessary.
+ eng_pattern = r"[^\d\s\w'\.\,\?]"
+
+
+ def summarize(diarized, check):
+     """
+     diarized: a list of tuples. Each tuple has a string to be displayed and a label for highlighting.
+     The start/end times are not highlighted [(speaker text, speaker id), (start time/end time, None)]
+     check is a list of speaker ids whose speech will get summarized
+     """
+
+     if len(check) == 0:
+         return ""
+
+     text = ""
+     for d in diarized:
+         if len(check) == 2 and d[1] is not None:
+             text += f"\n{d[1]}: {d[0]}"
+         elif d[1] in check:
+             text += f"\n{d[0]}"
+
+     # inner function cached because outer function cannot be cached
+     @functools.lru_cache(maxsize=128)
+     def call_summarize_api(text):
+         payload = {
+             "inputs": text,
+             "options": {
+                 "use_gpu": False,
+                 "wait_for_model": True,
+             },
+         }
+         response = requests.post(summarization_url, headers=headers, json=payload)
+         return response.json()[0]["summary_text"]
+
+     return call_summarize_api(text)
+
+
+ # Audio components
+ asr_model = "patrickvonplaten/wav2vec2-base-960h-4-gram"
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model)
+ asr = pipeline(
+     "automatic-speech-recognition",
+     model=asr_model,
+     tokenizer=processor.tokenizer,
+     feature_extractor=processor.feature_extractor,
+     decoder=processor.decoder,
+     device=device,
+ )
+ speaker_segmentation = Pipeline.from_pretrained("pyannote/speaker-segmentation")
+ rpunct = RestorePuncts()
+
+ # Text components
+ emotion_pipeline = pipeline(
+     "text-classification",
+     model="bhadresh-savani/distilbert-base-uncased-emotion",
+     device=device,
+ )
+
+ EXAMPLES = [["example_audio.wav"], ["Customer_Support_Call.wav"]]
+
+ # display if the sentiment value is above these thresholds
+ thresholds = {
+     "joy": 0.99,
+     "anger": 0.95,
+     "surprise": 0.95,
+     "sadness": 0.98,
+     "fear": 0.95,
+     "love": 0.99,
+ }
+
+
+ def speech_to_text(speech):
+     speaker_output = speaker_segmentation(speech)
+     speech, sampling_rate = load(speech)
+     if sampling_rate != 16000:
+         speech = resample(speech, sampling_rate, 16000)
+     text = asr(speech, return_timestamps="word")
+     chunks = text["chunks"]
+
+     diarized_output = []
+     i = 0
+     speaker_counter = 0
+
+     # New iteration every time the speaker changes
+     for turn, _, _ in speaker_output.itertracks(yield_label=True):
+
+         speaker = "Customer" if speaker_counter % 2 == 0 else "Support"
+         diarized = ""
+         while i < len(chunks) and chunks[i]["timestamp"][1] <= turn.end:
+             diarized += chunks[i]["text"].lower() + " "
+             i += 1
+
+         if diarized != "":
+             diarized = rpunct.punctuate(re.sub(eng_pattern, "", diarized), lang="en")
+
+             diarized_output.extend(
+                 [
+                     (diarized, speaker),
+                     ("from {:.2f}-{:.2f}".format(turn.start, turn.end), None),
+                 ]
+             )
+
+         speaker_counter += 1
+
+     return diarized_output
+
+
+ def sentiment(diarized):
+     """
+     diarized: a list of tuples. Each tuple has a string to be displayed and a label for highlighting.
+     The start/end times are not highlighted [(speaker text, speaker id), (start time/end time, None)]
+
+     This function gets the customer's sentiment and returns a list for highlighted text as well
+     as a plot of sentiment over time.
+     """
+
+     customer_sentiments = []
+
+     to_plot = []
+     plot_sentences = []
+
+     # used to set the x range of ticks on the plot
+     x_min = 100
+     x_max = 0
+
+     for i in range(0, len(diarized), 2):
+         speaker_speech, speaker_id = diarized[i]
+         times, _ = diarized[i + 1]
+
+         sentences = split_into_sentences(speaker_speech)
+         start_time, end_time = times[5:].split("-")
+         start_time, end_time = float(start_time), float(end_time)
+         interval_size = (end_time - start_time) / len(sentences)
+
+         if "Customer" in speaker_id:
+
+             outputs = emotion_pipeline(sentences)
+
+             for idx, (o, t) in enumerate(zip(outputs, sentences)):
+                 sent = "neutral"
+                 if o["score"] > thresholds[o["label"]]:
+                     customer_sentiments.append(
+                         (t + f"({round(idx*interval_size+start_time,1)} s)", o["label"])
+                     )
+                     if o["label"] in {"joy", "love", "surprise"}:
+                         sent = "positive"
+                     elif o["label"] in {"sadness", "anger", "fear"}:
+                         sent = "negative"
+                 if sent != "neutral":
+                     to_plot.append((start_time + idx * interval_size, sent))
+                     plot_sentences.append(t)
+
+         if start_time < x_min:
+             x_min = start_time
+         if end_time > x_max:
+             x_max = end_time
+
+     x_min -= 5
+     x_max += 5
+
+     x, y = list(zip(*to_plot))
+
+     plot_df = pd.DataFrame(
+         data={
+             "x": x,
+             "y": y,
+             "sentence": plot_sentences,
+         }
+     )
+
+     fig = px.line(
+         plot_df,
+         x="x",
+         y="y",
+         hover_data={
+             "sentence": True,
+             "x": True,
+             "y": False,
+         },
+         labels={"x": "time (seconds)", "y": "sentiment"},
+         title=f"Customer sentiment over time",
+     )
+
+     fig = fig.update_yaxes(categoryorder="category ascending")
+     fig = fig.update_layout(
+         font=dict(
+             size=18,
+         ),
+         xaxis_range=[x_min, x_max],
+     )
+
+     return customer_sentiments, fig
+
+
+ demo = gr.Blocks(enable_queue=True)
+ demo.encrypt = False
+
+ # for highlighting purposes
+ color_map = {
+     "joy": "green",
+     "anger": "red",
+     "surprise": "yellow",
+     "sadness": "blue",
+     "fear": "orange",
+     "love": "purple",
+ }
+
+ with demo:
+     with gr.Row():
+         with gr.Column():
+             audio = gr.Audio(label="Audio file", type="filepath")
+             with gr.Row():
+                 btn = gr.Button("Transcribe")
+             with gr.Row():
+                 examples = gr.components.Dataset(
+                     components=[audio], samples=EXAMPLES, type="index"
+                 )
+         with gr.Column():
+             gr.Markdown("**Call Transcript:**")
+             diarized = gr.HighlightedText(label="Call Transcript")
+             gr.Markdown("Choose speaker to summarize:")
+             check = gr.CheckboxGroup(
+                 choices=["Customer", "Support"], show_label=False, type="value"
+             )
+             summary = gr.Textbox(lines=4)
+             sentiment_btn = gr.Button("Get Customer Sentiment")
+             analyzed = gr.HighlightedText(color_map=color_map)
+             plot = gr.Plot(label="Sentiment over time", type="plotly")
+
+     # when example button is clicked, convert audio file to text and diarize
+     btn.click(
+         speech_to_text,
+         audio,
+         [diarized],
+         status_tracker=gr.StatusTracker(cover_container=True),
+     )
+     # when summarize checkboxes are changed, create summary
+     check.change(summarize, [diarized, check], summary)
+
+     # when sentiment button clicked, display highlighted text and plot
+     sentiment_btn.click(sentiment, [diarized], [analyzed, plot])
+
+
+     def cache_example(example):
+         processed_examples = audio.preprocess_example(example)
+         diarized_output = speech_to_text(example)
+         return processed_examples, diarized_output
+
+     cache = [cache_example(e[0]) for e in EXAMPLES]
+
+     def load_example(example_id):
+         return cache[example_id]
+
+     examples._click_no_postprocess(
+         load_example, inputs=[examples], outputs=[audio, diarized], queue=False
+     )
+
+ demo.launch(debug=1)
example_audio.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:43166418f743e61807c7681944bf344c4720924adb4e5879dfa954dc7ecc82b2
+ size 3202638
packages.txt ADDED
@@ -0,0 +1,2 @@
+ libsndfile1
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch
+ transformers
+ librosa
+ pyctcdecode
+ pypi-kenlm
+ git+https://github.com/ktangri/rpunct.git
+ https://github.com/pyannote/pyannote-audio/archive/develop.zip
+ requests
+ speechbrain
+ plotly
short-take-1.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bf15193510fc5a5680fdfdffda6c7cc5b8595bdde3d267b9ef5223e62035a952
+ size 20079500
utils.py ADDED
@@ -0,0 +1,33 @@
+ import re
+ alphabets = "([A-Za-z])"
+ prefixes = "(Mr|St|Mrs|Ms|Dr)[.]"
+ suffixes = "(Inc|Ltd|Jr|Sr|Co)"
+ starters = "(Mr|Mrs|Ms|Dr|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
+ acronyms = "([A-Z][.][A-Z][.](?:[A-Z][.])?)"
+ websites = "[.](com|net|org|io|gov)"
+
+ def split_into_sentences(text):
+     text = " " + text + " "
+     text = text.replace("\n"," ")
+     text = re.sub(prefixes,"\\1<prd>",text)
+     text = re.sub(websites,"<prd>\\1",text)
+     if "Ph.D" in text: text = text.replace("Ph.D.","Ph<prd>D<prd>")
+     text = re.sub("\s" + alphabets + "[.] "," \\1<prd> ",text)
+     text = re.sub(acronyms+" "+starters,"\\1<stop> \\2",text)
+     text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]","\\1<prd>\\2<prd>\\3<prd>",text)
+     text = re.sub(alphabets + "[.]" + alphabets + "[.]","\\1<prd>\\2<prd>",text)
+     text = re.sub(" "+suffixes+"[.] "+starters," \\1<stop> \\2",text)
+     text = re.sub(" "+suffixes+"[.]"," \\1<prd>",text)
+     text = re.sub(" " + alphabets + "[.]"," \\1<prd>",text)
+     if "”" in text: text = text.replace(".”","”.")
+     if "\"" in text: text = text.replace(".\"","\".")
+     if "!" in text: text = text.replace("!\"","\"!")
+     if "?" in text: text = text.replace("?\"","\"?")
+     text = text.replace(".",".<stop>")
+     text = text.replace("?","?<stop>")
+     text = text.replace("!","!<stop>")
+     text = text.replace("<prd>",".")
+     sentences = text.split("<stop>")
+     sentences = sentences[:-1]
+     sentences = [s.strip() for s in sentences]
+     return sentences