TroglodyteDerivations committed on
Commit
937e691
1 Parent(s): dbee188

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +339 -0
app.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import string
from dataclasses import dataclass

import IPython
import IPython.display
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torchaudio

# Part A: Report the library versions and the compute device
st.write(torch.__version__)
st.write(torchaudio.__version__)
device = 'cpu'
st.write(device)

# Part B: Load the audio file (once; the same tensor feeds the model and the plots)
SPEECH_FILE = 'abby_cadabby.wav'
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
st.write(SPEECH_FILE)

# Part C: torchaudio.pipelines | bundle.get_model | bundle.get_labels()
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()

# Later code converts frames to seconds with bundle.sample_rate, so bring the
# audio to that rate when the file was recorded at a different one.
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
    sample_rate = bundle.sample_rate
waveform = waveform.to(device)

# Inference mode: only a forward pass is needed, no autograd bookkeeping
with torch.inference_mode():
    # Pass the waveform through the model and normalize to log-probabilities
    emissions, _ = model(waveform)
    emissions = torch.log_softmax(emissions, dim=-1)

# Get the emissions for the first example
emission = emissions[0].cpu().detach()

# Print the labels
st.write('Labels are: ', labels)
st.write('Length of labels are: ', len(labels))

# Part D: Frame-wise class probability plot
def plot():
    """Build the frame-wise class probability heat map.

    Reads the module-level ``emission`` tensor and draws its transpose, so
    the x axis is labelled "Time" and the y axis "Labels".

    Returns:
        The matplotlib figure containing the image and its colorbar.
    """
    fig, ax = plt.subplots()
    img = ax.imshow(emission.T)
    ax.set_title("Frame-wise class probability")
    ax.set_xlabel("Time")
    ax.set_ylabel("Labels")
    fig.colorbar(img, ax=ax, shrink=0.6, location="bottom")
    fig.tight_layout()
    return fig


st.pyplot(plot())

# Part E: Remove punctuation, separate words with |, and convert to UPPERCASE
def remove_punctuation(input_string):
    """Normalize a transcript into upper-case words separated by ``|``.

    Every character in ``string.punctuation`` is removed, the text is split
    on whitespace, and the remaining words are upper-cased and joined with a
    single ``|``. Words that consist only of punctuation are dropped so that
    no empty token (and therefore no doubled separator) appears.

    Args:
        input_string: Raw transcript text.

    Returns:
        The cleaned transcript, e.g. ``"OH|HI|ITS|ME"``; an empty string when
        the input contains no words.
    """
    # One translation table strips all punctuation in a single pass per word
    translator = str.maketrans('', '', string.punctuation)
    cleaned_words = (
        word.translate(translator).upper() for word in input_string.split()
    )
    # Join with exactly one separator between words; skip words that became empty
    return '|'.join(word for word in cleaned_words if word)

# Transcript of the recording, and its normalized form used for alignment
transcript = " Oh hi! It's me, Abby Cadabby. Do you want to watch me practice my magic? I am going to turn this"

clean_transcript = remove_punctuation(transcript)
st.write(clean_transcript)

# Part F: Populate Trellis
def get_trellis(emission, tokens, blank_id=0):
    """Build the alignment trellis for a token sequence.

    ``trellis[t, j]`` accumulates the best score of being at token ``j`` at
    frame ``t``. Scores are summed, so ``emission`` is expected to hold
    log-probabilities.

    Args:
        emission: 2-D tensor indexed as ``[frame, label]``.
        tokens: Sequence of label indices for the transcript.
        blank_id: Label index used as the "stay" (blank) score.

    Returns:
        A ``(num_frame, num_tokens)`` float tensor.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    trellis = torch.zeros((num_frame, num_tokens))
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    trellis[0, 1:] = -float("inf")
    # Block the first token in the last (num_tokens - 1) frames so the
    # remaining tokens still have frames left. With a single token there is
    # nothing to reserve; the unguarded slice ``[-0:]`` would overwrite the
    # whole column with +inf.
    if num_tokens > 1:
        trellis[-num_tokens + 1:, 0] = float("inf")

    # Loop-invariant index of the labels that a "change" step moves into
    next_tokens = torch.as_tensor(tokens[1:], dtype=torch.long)
    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, next_tokens],
        )
    return trellis

# Map every transcript character to its index in the model's label set;
# these token ids are what get_trellis and backtrack index the emission with.
# NOTE(review): a character missing from `labels` raises KeyError here.
dictionary = {c: i for i, c in enumerate(labels)}
tokens = [dictionary[c] for c in clean_transcript]

trellis = get_trellis(emission, tokens)
st.write('Trellis =', trellis)

# Part G: Labels and Time -Inf | +Inf
def n_inf_to_p_inf():
    """Plot the trellis and annotate its "- Inf" and "+ Inf" regions.

    Reads the module-level ``trellis`` tensor; the annotation positions are
    derived from its two dimensions.

    Returns:
        The matplotlib figure containing the image and its colorbar.
    """
    fig, ax = plt.subplots()
    img = ax.imshow(trellis.T, origin="lower")
    ax.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
    # Shift the "+ Inf" annotation to the left by decreasing the x-coordinate value
    ax.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 1.4, trellis.size(1) / 3))
    fig.colorbar(img, ax=ax, shrink=0.25, location="bottom")
    fig.tight_layout()
    return fig


st.pyplot(n_inf_to_p_inf())

# Part H: Backtrack Trellis Emissions Tensor and Tokens
@dataclass
class Point:
    """One step of the alignment path."""

    # Index into the token sequence occupied at this step
    token_index: int
    # Frame index into the emission/trellis matrices
    time_index: int
    # Exponentiated frame-wise emission score of the step that was taken
    score: float


def backtrack(trellis, emission, tokens, blank_id=0):
    """Recover the best alignment path by walking the trellis backwards.

    Starts at the last frame and last token, and at every step decides
    whether the current cell was reached by staying on the same token or by
    changing from the previous one, whichever has the higher score.

    Args:
        trellis: Tensor indexed as ``[frame, token]`` (see ``get_trellis``).
        emission: Tensor indexed as ``[frame, label]``.
        tokens: Sequence of label indices for the transcript.
        blank_id: Label index used as the "stay" score.

    Returns:
        A list of ``Point`` ordered from the first frame to the last.

    Raises:
        ValueError: If the frames run out before the first token is reached.
    """
    t, j = trellis.size(0) - 1, trellis.size(1) - 1

    path = [Point(j, t, emission[t, blank_id].exp().item())]
    while j > 0:
        # Should not happen, but fail loudly instead of indexing frame -1
        # (an ``assert`` would be stripped under ``python -O``).
        if t <= 0:
            raise ValueError("Failed to align: ran out of frames before reaching the first token")

        # 1. Figure out if the current position was stay or change
        # Frame-wise score of stay vs change
        p_stay = emission[t - 1, blank_id]
        p_change = emission[t - 1, tokens[j]]

        # Context-aware score for stay vs change
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change
        took_change = bool(changed > stayed)

        # Update position
        t -= 1
        if took_change:
            j -= 1

        # Store the path with frame-wise probability
        prob = (p_change if took_change else p_stay).exp().item()
        path.append(Point(j, t, prob))

    # Now j == 0, which means, it reached the SOS.
    # Fill up the rest for the sake of visualization
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        path.append(Point(j, t - 1, prob))
        t -= 1
    return path[::-1]

# Recover the alignment path and show every step of it
path = backtrack(trellis, emission, tokens)
for p in path:
    st.write('Token index, Time index and Score:')
    st.write(p)

# Part I: Trellis with Path Visualization
def plot_trellis_with_path(trellis, path):
    """Plot the trellis with the backtracked path highlighted.

    Args:
        trellis: Tensor indexed as ``[frame, token]``; it is cloned, not modified.
        path: Iterable of objects with ``time_index`` and ``token_index``.

    Returns:
        A new matplotlib figure. The figure is created here rather than drawn
        through the global ``plt`` state, so the image cannot land on a
        figure left over from an earlier plot.
    """
    # To plot trellis with path, we take advantage of 'nan' value
    trellis_with_path = trellis.clone()
    for p in path:
        trellis_with_path[p.time_index, p.token_index] = float("nan")

    fig, ax = plt.subplots()
    ax.imshow(trellis_with_path.T, origin="lower")
    ax.set_title("The path found by backtracking")
    fig.tight_layout()
    return fig


st.pyplot(plot_trellis_with_path(trellis, path))

# Part J: Merge Repeats | Segments
# Merge the labels
@dataclass
class Segment:
    """A labelled span of frames, half-open as ``[start, end)``."""

    # Character (or merged word) this span is labelled with
    label: str
    # First frame index of the span
    start: int
    # Frame index one past the last frame of the span
    end: int
    # Average score over the span
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}) : [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        """Number of frames covered by the span."""
        return self.end - self.start


def merge_repeats(path, transcript=None):
    """Collapse consecutive path points on the same token into segments.

    Args:
        path: Sequence of objects with ``token_index``, ``time_index`` and
            ``score`` attributes, ordered by time.
        transcript: String indexed by ``token_index`` to obtain each
            segment's label. Defaults to the module-level
            ``clean_transcript``, i.e. the normalized text the tokens were
            built from, not the raw transcript, whose characters do not
            line up with the token indices.

    Returns:
        A list of ``Segment``, one per run of equal ``token_index``; the
        score of a segment is the mean score of its points.
    """
    if transcript is None:
        transcript = clean_transcript

    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        # Advance i2 to the first point that belongs to a different token
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        run = path[i1:i2]
        score = sum(point.score for point in run) / len(run)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments

# Merge repeated path points into per-character segments and show them
segments = merge_repeats(path)
for seg in segments:
    st.write('Segments:')
    st.write(seg)

# Part K: Trellis with Segments Visualization
def plot_trellis_with_segments(trellis, segments, transcript):
    """Plot the trellis with labelled segments plus per-label probabilities.

    The upper axes show the trellis with every non-separator segment masked
    out and annotated with its label and score. The lower axes show one wide
    bar per segment and one narrow bar per path point.

    Args:
        trellis: Tensor indexed as ``[frame, token]``; it is cloned, not modified.
        segments: Sequence of ``Segment``; segment ``i`` is drawn on trellis row ``i``.
        transcript: String indexed by ``token_index`` to label path points.

    Returns:
        The matplotlib figure.

    Note:
        The per-frame bars read the module-level ``path``; it is not a parameter.
    """
    # To plot trellis with path, we take advantage of 'nan' value
    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start : seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True, figsize=(15, 15))
    ax1.set_title("Path, label and probability for each label")
    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")

    # Adjust the position of the annotations to spread them out
    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.3), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 0.3), size="small")

    ax2.set_title("Label probability with and without repetition")
    # One wide bar per segment (probability with repeats merged)
    xs, hs, ws = [], [], []
    for seg in segments:
        if seg.label != "|":
            xs.append((seg.end + seg.start) / 2 + 0.4)
            hs.append(seg.score)
            ws.append(seg.end - seg.start)
            ax2.annotate(seg.label, (seg.start + 0.8, -0.07), rotation=0)
    ax2.bar(xs, hs, width=ws, color="gray", alpha=0.9, edgecolor="black")

    # One narrow bar per frame (probability without merging)
    xs, hs = [], []
    for p in path:
        label = transcript[p.token_index]
        if label != "|":
            xs.append(p.time_index + 1)
            hs.append(p.score)

    ax2.bar(xs, hs, width=0.9, alpha=0.9)
    ax2.axhline(0, color="black")
    ax2.grid(True, axis="y")
    ax2.set_ylim(-0.1, 1.1)
    fig.tight_layout()
    return fig


# Build the figure once; an extra discarded call would only leak a figure
st.pyplot(plot_trellis_with_segments(trellis, segments, clean_transcript))

# Part L: Merge words | Segments
# Merge words
def merge_words(segments, separator="|"):
    """Merge character segments into word segments.

    Segments are grouped into runs delimited by segments whose label equals
    ``separator``; empty runs (adjacent separators) are skipped.

    Args:
        segments: Sequence of ``Segment`` ordered by time.
        separator: Label that marks a word boundary.

    Returns:
        A list of ``Segment`` whose label is the concatenated characters,
        whose span runs from the first character's start to the last
        character's end, and whose score is the length-weighted mean of the
        character scores.
    """
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            # [i1, i2) is a complete word unless it is empty
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join(seg.label for seg in segs)
                total_length = sum(seg.length for seg in segs)
                score = sum(seg.score * seg.length for seg in segs) / total_length
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            # Skip past the separator and start the next word
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words


# Merge character segments into words and show them
word_segments = merge_words(segments)
for word in word_segments:
    st.write('Word Segments:')
    st.write(word)

# Part M: Alignment Visualizations
def plot_alignments(
    trellis,
    segments,
    word_segments,
    waveform=np.random.randn(1024),
    sample_rate=44100,
):
    """Plot the aligned trellis above a spectrogram with word boundaries.

    Args:
        trellis: Tensor indexed as ``[frame, token]``; it is cloned, not modified.
        segments: Sequence of character ``Segment``; segment ``i`` is drawn
            on trellis row ``i``.
        word_segments: Sequence of word ``Segment`` (see ``merge_words``).
        waveform: Audio samples, either 1-D or ``(channels, time)``; for
            multi-channel input only the first channel is plotted.
            NOTE(review): the default is random noise generated once at
            definition time; callers are expected to pass real audio.
        sample_rate: Sample rate of ``waveform`` in Hz.

    Returns:
        The matplotlib figure.
    """
    # specgram needs a 1-D signal, and the time ratio below needs the number
    # of samples (len() of a (channels, time) array is the channel count).
    samples = np.asarray(waveform)
    if samples.ndim > 1:
        samples = samples[0]

    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start : seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(20, 18))

    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
    ax1.set_facecolor("lightgray")
    ax1.set_xticks([])
    ax1.set_yticks([])

    for word in word_segments:
        ax1.axvspan(word.start - 0.5, word.end - 0.5, edgecolor="white", facecolor="none")

    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")

    # The original waveform
    NFFT = 1024
    # Seconds of audio per trellis frame
    ratio = len(samples) / sample_rate / trellis.size(0)

    ax2.specgram(samples, Fs=sample_rate, NFFT=NFFT)
    for word in word_segments:
        x0 = ratio * word.start
        x1 = ratio * word.end
        ax2.axvspan(x0, x1, facecolor="none", edgecolor="white", hatch="/")
        ax2.annotate(f"{word.score:.2f}", (x0, sample_rate * 0.51), annotation_clip=False)

    for seg in segments:
        if seg.label != "|":
            ax2.annotate(seg.label, (seg.start * ratio, sample_rate * 0.55), annotation_clip=False)
    ax2.set_xlabel("time [second]")
    ax2.set_yticks([])
    fig.tight_layout()
    return fig


# Build the figure once and pass every positional argument, including `segments`
st.pyplot(plot_alignments(trellis, segments, word_segments, waveform, sample_rate))

# Part N: Display Segment
def display_segment(i):
    """Return an audio widget for the ``i``-th aligned word.

    Reads the module-level ``waveform``, ``trellis``, ``word_segments`` and
    ``bundle``. Frame indices are scaled to sample indices by the ratio of
    waveform samples to trellis frames, and the word's label, score and time
    span are printed to stdout.

    Args:
        i: Index into ``word_segments``.

    Returns:
        An ``IPython.display.Audio`` object for the word's slice of the waveform.
    """
    # Number of waveform samples per trellis frame
    ratio = waveform.size(1) / trellis.size(0)
    word = word_segments[i]
    x0 = int(ratio * word.start)
    x1 = int(ratio * word.end)
    print(f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec")
    segment = waveform[:, x0:x1]
    return IPython.display.Audio(segment.numpy(), rate=bundle.sample_rate)

# Part O: Show the transcript and play back the recording
st.write('Abby Cadabby Transcript:')
# Show the transcript text itself rather than the literal word "Transcript"
st.write(transcript)
# st.audio renders a player in the app
st.audio(SPEECH_FILE)