draft streamlit app
Browse files- app.py +262 -0
- requirements.txt +10 -0
@@ -0,0 +1,262 @@
1 |
import os
2 |
import json
3 |
4 |
import numpy as np
5 |
import ffmpeg
6 |
import whisper
7 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
8 |
from sklearn.tree import DecisionTreeRegressor
9 |
import torch
10 |
import youtube_dl
11 |
import pandas as pd
12 |
import streamlit as st
13 |
import altair as alt
14 |
15 |
# Directory holding downloaded audio, cached transcripts and rendered output.
DATA_DIR = "./data"
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(DATA_DIR, exist_ok=True)

# Options handed to youtube_dl.YoutubeDL by download(): fetch the best audio
# stream into DATA_DIR and convert it to a 192 kbps mp3 with ffmpeg.
# NOTE(review): the `YDL_OPTS = {` line and the closing brackets were missing
# from the extracted source; the name is taken from its use in strike().
YDL_OPTS = {
    "download_archive": os.path.join(DATA_DIR, "archive.txt"),
    "format": "bestaudio/best",
    "outtmpl": os.path.join(DATA_DIR, "%(title)s.%(ext)s"),
    "postprocessors": [
        {
            "key": "FFmpegExtractAudio",
            "preferredcodec": "mp3",
            "preferredquality": "192",
        }
    ],
}
31 |
32 |
# Language model and tokenizer used to score transcript segments.
llm = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
device = "cuda" if torch.cuda.is_available() else "cpu"
# compute_info_densities() moves its input ids to `device`, so the model has
# to live on the same device or the forward pass fails on GPU machines.
llm.to(device)
35 |
36 |
37 |
def download(url, ydl_opts):
    """Download `url` with youtube_dl and return the prepared file name.

    Args:
        url: Address of the video to fetch.
        ydl_opts: Options dict passed to ``youtube_dl.YoutubeDL``.

    Returns:
        The value of ``ydl.prepare_filename`` for the extracted info dict.
    """
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(str(url))
        return ydl.prepare_filename(info)
42 |
43 |
44 |
def transcribe(audio_path, transcript_path):
    """Return the Whisper segments for `audio_path`, cached at `transcript_path`.

    If `transcript_path` already exists it is loaded as JSON and the model is
    not run. Otherwise the "base" Whisper model transcribes `audio_path` and
    the full result is written to `transcript_path` as JSON.

    Args:
        audio_path: Path of the audio file to transcribe.
        transcript_path: Path of the JSON transcript cache.

    Returns:
        The ``"segments"`` entry of the transcription result.
    """
    # Cache hit: return early so the transcript is neither recomputed nor
    # rewritten (the original `else:` line was lost in extraction).
    if os.path.exists(transcript_path):
        with open(transcript_path, "r") as f:
            return json.load(f)["segments"]

    whisper_model = whisper.load_model("base")
    result = whisper_model.transcribe(audio_path)
    with open(transcript_path, "w") as f:
        json.dump(result, f)
    return result["segments"]
54 |
55 |
56 |
def compute_seg_durations(segments):
    """Return ``end - start`` for every segment, in order.

    Args:
        segments: Iterable of mappings with ``"start"`` and ``"end"`` keys.

    Returns:
        A list with one duration per segment.
    """
    return [seg["end"] - seg["start"] for seg in segments]
58 |
59 |
60 |
def compute_info_densities(
    segments, seg_durations, llm, tokenizer, device, ctxt_len=512
):
    """Score each segment with `llm` and divide by its duration.

    Every segment's text is tokenized and the token ids are concatenated. For
    each segment the model is run on a window ending at that segment and
    spanning at most `ctxt_len` tokens; all label positions before the
    segment are set to -100. The returned loss is multiplied by the segment's
    token count and divided by the segment's duration.

    Args:
        segments: Sequence of mappings with a ``"text"`` key.
        seg_durations: One duration per segment, in the same order.
        llm: Model called as ``llm(input_ids, labels=...)`` whose result has a
            scalar ``loss`` tensor.
            NOTE(review): presumably the mean negative log-likelihood over the
            non -100 labels — confirm for the model in use.
        tokenizer: Callable returning an object with ``input_ids``.
        device: Torch device for the model and inputs.
        ctxt_len: Maximum number of tokens in one scoring window.

    Returns:
        A NumPy array with one value per segment (empty for no segments).
    """
    if len(segments) == 0:
        return np.array([])

    # Inputs are moved to `device`, so the model must be there as well.
    llm = llm.to(device)

    seg_ids = [
        tokenizer(seg["text"], return_tensors="pt").input_ids.to(device)
        for seg in segments
    ]
    seg_lens = [ids.shape[1] for ids in seg_ids]
    cat_input_ids = torch.cat(seg_ids, dim=1)
    n = cat_input_ids.shape[1]

    end = 0
    seg_nlls = []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for seg_len in seg_lens:
            end = min(n, end + seg_len)
            start = max(0, end - ctxt_len)
            ctxt_ids = cat_input_ids[:, start:end]
            target_ids = ctxt_ids.clone()
            # Only the current segment's tokens contribute to the loss.
            target_ids[:, :-seg_len] = -100
            # .item() works on any device; .numpy() raises on CUDA tensors.
            avg_nll = llm(ctxt_ids, labels=target_ids).loss.item()
            seg_nlls.append(avg_nll * seg_len)

    return np.array(seg_nlls) / np.asarray(seg_durations)
82 |
83 |
84 |
def smooth_info_densities(info_densities, seg_durations, max_leaf_nodes, min_sec_leaf):
    """Fit a decision tree over segment index and return its predictions.

    Args:
        info_densities: One value per segment.
        seg_durations: One duration per segment; only their mean is used, to
            convert `min_sec_leaf` into a number of segments.
        max_leaf_nodes: Passed to ``DecisionTreeRegressor``.
        min_sec_leaf: Minimum leaf size, in the same units as `seg_durations`.

    Returns:
        The tree's prediction for every segment index.
    """
    # Clamp to 1 so a zero `min_sec_leaf` cannot produce an invalid leaf size.
    min_samples_leaf = max(1, int(np.ceil(min_sec_leaf / np.mean(seg_durations))))
    tree = DecisionTreeRegressor(
        max_leaf_nodes=max_leaf_nodes, min_samples_leaf=min_samples_leaf
    )
    # The only feature is the segment's position in the sequence.
    X = np.arange(len(info_densities))[:, np.newaxis]
    tree.fit(X, info_densities)
    return tree.predict(X)
93 |
94 |
95 |
def squash_segs(segments, info_densities):
    """Merge runs of consecutive segments that share the same density.

    A new run starts whenever a segment's density differs (by ``!=``) from the
    previous segment's. strike() passes the output of smooth_info_densities.

    Args:
        segments: Sequence of mappings with ``"start"`` and ``"end"`` keys.
        info_densities: One density per segment.

    Returns:
        ``(seg_times, seg_densities)``: a list of ``(start, end)`` tuples and
        a list with the density of each run. Both are empty for no segments.
    """
    if len(segments) == 0:
        return [], []

    start = segments[0]["start"]
    seg_times = []
    seg_densities = [info_densities[0]]
    for seg, prev_density, curr_density in zip(
        segments[1:], info_densities, info_densities[1:]
    ):
        if curr_density != prev_density:
            # A run ends where the next segment begins.
            seg_times.append((start, seg["start"]))
            seg_densities.append(curr_density)
            start = seg["start"]
    seg_times.append((start, segments[-1]["end"]))
    return seg_times, seg_densities
110 |
111 |
112 |
def compute_speedups(info_densities):
    """Return ``mean(info_densities) / info_densities`` element-wise.

    Args:
        info_densities: Sequence or array of densities.

    Returns:
        A NumPy array of the same length.
    """
    densities = np.asarray(info_densities, dtype=float)
    return densities.mean() / densities
116 |
117 |
118 |
def compute_actual_speedup(durations, speedups, total_duration):
    """Compute the total duration after per-segment speedups are applied.

    Args:
        durations: Duration of each segment.
        speedups: Speedup factor of each segment.
        total_duration: Duration of the original audio.

    Returns:
        ``(spedup_total_duration, actual_speedup_factor)`` where the first is
        ``sum(durations / speedups)`` and the second is `total_duration`
        divided by it.
    """
    # asarray lets plain lists be passed as well as arrays.
    spedup_durations = np.asarray(durations, dtype=float) / np.asarray(
        speedups, dtype=float
    )
    spedup_total_duration = spedup_durations.sum()
    actual_speedup_factor = total_duration / spedup_total_duration
    return spedup_total_duration, actual_speedup_factor
123 |
124 |
125 |
def postprocess_speedups(
    speedups,
    factor,
    min_speedup,
    max_speedup,
    durations,
    total_duration,
    thresh=0.01,
    max_iter=100,
):
    """Rescale `speedups` so the overall speedup is close to `factor`.

    Bisects a multiplier in ``[factor / 10, factor * 10]``. At each step the
    scaled speedups are rounded to two decimals and clipped to
    ``[min_speedup, max_speedup]``, and the resulting overall speedup is
    compared with `factor`.

    Args:
        speedups: Raw per-segment speedups.
        factor: Target overall speedup.
        min_speedup: Lower clip bound for each segment.
        max_speedup: Upper clip bound for each segment.
        durations: Duration of each segment.
        total_duration: Duration of the original audio.
        thresh: Accepted relative error between actual and target speedup.
        max_iter: Upper bound on bisection steps. Rounding and clipping can
            make the target unreachable, in which case the unbounded loop
            would never finish; the last candidate is returned instead.

    Returns:
        The tuned per-segment speedups as a NumPy array.

    Raises:
        AssertionError: If `factor` is outside ``[min_speedup, max_speedup]``.
    """
    assert min_speedup <= factor <= max_speedup
    speedups = np.asarray(speedups, dtype=float)
    lo, hi = factor / 10, factor * 10
    tuned_speedups = None
    for _ in range(max(1, max_iter)):
        mid = (lo + hi) / 2
        tuned_speedups = np.clip(
            np.round(speedups * mid, decimals=2), min_speedup, max_speedup
        )
        _, actual_speedup_factor = compute_actual_speedup(
            durations, tuned_speedups, total_duration
        )
        if abs(actual_speedup_factor - factor) / factor <= thresh:
            break
        # Too slow overall -> raise the multiplier; too fast -> lower it.
        if actual_speedup_factor < factor:
            lo = mid
        else:
            hi = mid
    return tuned_speedups
144 |
145 |
146 |
def cat_clips(seg_times, speedups, audio_path, output_path):
    """Trim, retime and concatenate clips of `audio_path` into `output_path`.

    Each ``(start, end)`` span is cut with ffmpeg's ``atrim`` filter and
    retimed with ``atempo``; the clips are then concatenated as audio only.

    Args:
        seg_times: Iterable of ``(start, end)`` pairs.
        speedups: One tempo factor per pair.
        audio_path: Source audio file.
        output_path: Destination file; nothing is done if it already exists.
    """
    # NOTE(review): an existing output is reused even if the speedup settings
    # changed since it was rendered — confirm that this caching is intended.
    if os.path.exists(output_path):
        return

    in_file = ffmpeg.input(audio_path)
    segs = [
        in_file.filter("atrim", start=start, end=end).filter("atempo", speedup)
        for (start, end), speedup in zip(seg_times, speedups)
    ]
    cat = ffmpeg.concat(*segs, v=0, a=1)
    # NOTE(review): the final statement was missing from the extracted source;
    # writing the concatenated stream to `output_path` is reconstructed.
    cat.output(output_path).run()
156 |
157 |
158 |
def format_duration(duration):
    """Format a duration in seconds as ``HH:MM:SS``.

    Fractional seconds are truncated.

    Args:
        duration: Number of seconds.

    Returns:
        A zero-padded ``"HH:MM:SS"`` string.
    """
    # divmod keeps minutes within 0-59; the previous version printed total
    # minutes, e.g. 3725 s became "01:62:05".
    minutes, seconds = divmod(int(duration), 60)
    hours, minutes = divmod(minutes, 60)
    return "%02d:%02d:%02d" % (hours, minutes, seconds)
163 |
164 |
165 |
def strike(url, speedup_factor, min_speedup, max_speedup, max_num_segments):
    """Download, transcribe and retime the audio at `url`, and plot the result.

    Args:
        url: Video address passed to download().
        speedup_factor: Target overall speedup.
        min_speedup: Lower bound on any segment's speedup (raised to 0.5).
        max_speedup: Upper bound on any segment's speedup.
        max_num_segments: Maximum number of constant-speed sections; also used
            to derive the minimum section length.

    Returns:
        Path of the rendered audio file.
    """
    min_speedup = max(0.5, min_speedup)  # ffmpeg limit

    name = download(url, YDL_OPTS)
    # The downloaded container depends on what "bestaudio/best" selects, so
    # strip whatever extension is present instead of asserting ".m4a".
    name = os.path.splitext(os.path.basename(name))[0]

    audio_path = os.path.join(DATA_DIR, "%s.mp3" % name)
    transcript_path = os.path.join(DATA_DIR, "%s.json" % name)
    output_path = os.path.join(DATA_DIR, "%s_smooth.mp3" % name)

    segments = transcribe(audio_path, transcript_path)

    seg_durations = compute_seg_durations(segments)

    info_densities = compute_info_densities(
        segments, seg_durations, llm, tokenizer, device
    )

    total_duration = segments[-1]["end"] - segments[0]["start"]
    min_sec_leaf = total_duration / max_num_segments
    smoothed_info_densities = smooth_info_densities(
        info_densities, seg_durations, max_num_segments, min_sec_leaf
    )

    squashed_times, squashed_densities = squash_segs(segments, smoothed_info_densities)
    squashed_durations = np.array([end - start for start, end in squashed_times])

    speedups = compute_speedups(squashed_densities)
    # NOTE(review): the argument list was missing from the extracted source;
    # it is reconstructed from postprocess_speedups' signature.
    speedups = postprocess_speedups(
        speedups,
        speedup_factor,
        min_speedup,
        max_speedup,
        squashed_durations,
        total_duration,
    )

    cat_clips(squashed_times, speedups, audio_path, output_path)

    spedup_total_duration, actual_speedup_factor = compute_actual_speedup(
        squashed_durations, speedups, total_duration
    )
    st.write("original duration: %s" % format_duration(total_duration))
    st.write("new duration: %s" % format_duration(spedup_total_duration))
    st.write("speedup: %0.2f" % actual_speedup_factor)

    # Chart 1: per-segment information rate, plotted at segment midpoints.
    times = np.array([(seg["start"] + seg["end"]) / 2 for seg in segments]) / 60
    annotations = [seg["text"] for seg in segments]
    # Dividing by ln(2) converts nats to bits.
    data = [times, info_densities / np.log(2), annotations]
    cols = ["time (minutes)", "bits per second", "transcript"]
    df = pd.DataFrame(list(zip(*data)), columns=cols)

    # NOTE(review): the encode() call of the line layer and the Chart()
    # constructor of the dot layer were missing from the extracted source;
    # both are reconstructed to match the surviving dot-layer encoding.
    lines = (
        alt.Chart(df, title="information rate")
        .mark_line(color="gray", opacity=0.5)
        .encode(x=cols[0], y=cols[1])
    )
    dots = (
        alt.Chart(df)
        .mark_circle(size=50, opacity=1)
        .encode(x=cols[0], y=cols[1], tooltip=["transcript"])
    )
    st.altair_chart((lines + dots).interactive(), use_container_width=True)

    # Chart 2: speedup as a step function; every section contributes its
    # start and end time, each paired with the same speedup value.
    times = np.array(squashed_times, dtype=float).ravel() / 60
    data = [times, np.repeat(speedups, 2)]
    cols = ["time (minutes)", "speedup"]
    df = pd.DataFrame(list(zip(*data)), columns=cols)
    st.line_chart(df, x=cols[0], y=cols[1])

    return output_path
244 |
245 |
246 |
# Input form: collects the URL and speedup settings, then runs strike().
with st.form("my_form"):
    url = st.text_input(
        "youtube url", value="https://www.youtube.com/watch?v=_3MBQm7GFIM"
    )
    speedup_factor = st.slider("speedup", min_value=1.0, max_value=10.0, value=1.5)
    min_speedup = 1
    max_speedup = st.slider("maximum speedup", min_value=1.0, max_value=10.0, value=2.0)
    # The target cannot exceed the per-segment cap; postprocess_speedups
    # asserts factor <= max_speedup.
    speedup_factor = min(speedup_factor, max_speedup)
    max_num_segments = st.slider(
        "variance in speedup over time", min_value=2, max_value=100, value=20
    )
    submitted = st.form_submit_button("submit")
    # NOTE(review): indentation was lost in extraction; placing this branch
    # inside the form is an assumption.
    if submitted:
        output_path = strike(
            url, speedup_factor, min_speedup, max_speedup, max_num_segments
        )
        # NOTE(review): the file's last line was missing from the extracted
        # source; playing back the rendered audio is a reconstruction.
        st.audio(output_path)
@@ -0,0 +1,10 @@
1 |
2 |
3 |
4 |
5 |
6 |
whisper @ git+https://github.com/openai/whisper.git@9f70a352f9f8630ab3aa0d06af5cb9532bd8c21d
7 |
8 |
9 |
10 |