Upload 21 files
- .github/FUNDING.yml +3 -0
- .github/workflows/check_large_file.yml +16 -0
- .github/workflows/sync_with_huggingface.yml +19 -0
- app.py +318 -0
- data/.gitignore +4 -0
- models/.gitignore +4 -0
- raw/.gitignore +4 -0
- requirements.txt +7 -0
- src/classify.py +43 -0
- src/errors.py +24 -0
- src/evaluate.py +408 -0
- src/model.py +235 -0
- src/predict.py +285 -0
- src/preprocess.py +979 -0
- src/segment.py +166 -0
- src/shared.py +406 -0
- src/train.py +191 -0
- src/train_classifier.py +182 -0
- src/utils.py +135 -0
- transcripts/auto/.gitignore +5 -0
- transcripts/manual/.gitignore +5 -0
.github/FUNDING.yml
ADDED
@@ -0,0 +1,3 @@
+github: xenova
+ko_fi: xenova
+custom: https://www.buymeacoffee.com/xenova
.github/workflows/check_large_file.yml
ADDED
@@ -0,0 +1,16 @@
+name: Check file size
+on:               # or directly `on: [push]` to run the action on every push on any branch
+  pull_request:
+    branches: [main]
+
+  # to run this workflow manually from the Actions tab
+  workflow_dispatch:
+
+jobs:
+  sync-to-hub:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Check large files
+        uses: ActionsDesk/lfs-warning@v2.0
+        with:
+          filesizelimit: 10485760 # this is 10MB so we can sync to HF Spaces
.github/workflows/sync_with_huggingface.yml
ADDED
@@ -0,0 +1,19 @@
+name: Sync to Hugging Face hub
+on:
+  push:
+    branches: [main]
+
+  # to run this workflow manually from the Actions tab
+  workflow_dispatch:
+
+jobs:
+  sync-to-hub:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+      - name: Push to hub
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        run: git push https://Xenova:$HF_TOKEN@huggingface.co/spaces/Xenova/sponsorblock-ml main
app.py
ADDED
@@ -0,0 +1,318 @@
+
+from functools import partial
+from math import ceil, floor
+import streamlit.components.v1 as components
+import streamlit as st
+import sys
+import os
+import json
+from urllib.parse import quote
+
+# Allow direct execution
+sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src'))  # noqa
+
+from preprocess import get_words
+from predict import PredictArguments, SegmentationArguments, predict as pred
+from shared import GeneralArguments, seconds_to_time, CATGEGORY_OPTIONS
+from utils import regex_search
+from model import get_model_tokenizer_classifier
+from errors import TranscriptError
+
+st.set_page_config(
+    page_title='SponsorBlock ML',
+    page_icon='🤖',
+    layout='wide',
+    # initial_sidebar_state="expanded",
+    menu_items={
+        # 'Get Help': 'https://github.com/xenova/sponsorblock-ml',
+        # 'Report a bug': 'https://github.com/xenova/sponsorblock-ml/issues/new/choose',
+        # 'About': "# This is a header. This is an *extremely* cool app!"
+    }
+)
+
+
+YT_VIDEO_REGEX = r'''(?x)^
+                (?:
+                    # http(s):// or protocol-independent URL
+                    (?:https?://|//)
+                    (?:(?:(?:(?:\w+\.)?[yY][oO][uU][tT][uU][bB][eE](?:-nocookie|kids)?\.com/|
+                       youtube\.googleapis\.com/)                # the various hostnames, with wildcard subdomains
+                    (?:.*?\#/)?                                  # handle anchor (#/) redirect urls
+                    (?:                                          # the various things that can precede the ID:
+                        # v/ or embed/ or e/
+                        (?:(?:v|embed|e)/(?!videoseries))
+                        |(?:                                     # or the v= param in all its forms
+                            # preceding watch(_popup|.php) or nothing (like /?v=xxxx)
+                            (?:(?:watch|movie)(?:_popup)?(?:\.php)?/?)?
+                            (?:\?|\#!?)                          # the params delimiter ? or # or #!
+                            # any other preceding param (like /?s=tuff&v=xxxx or ?s=tuff&v=V36LpHqtcDY)
+                            (?:.*?[&;])??
+                            v=
+                        )
+                    ))
+                    |(?:
+                       youtu\.be                                 # just youtu.be/xxxx
+                    )/)
+                )?                                               # all until now is optional -> you can pass the naked ID
+                # here is it! the YouTube video ID
+                (?P<id>[0-9A-Za-z_-]{11})'''
+
+# https://github.com/google-research/text-to-text-transfer-transformer#released-model-checkpoints
+# https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#experimental-t5-pre-trained-model-checkpoints
+
+# https://huggingface.co/docs/transformers/model_doc/t5
+# https://huggingface.co/docs/transformers/model_doc/t5v1.1
+
+
+# Faster caching system for predictions (No need to hash)
+@st.cache(persist=True, allow_output_mutation=True)
+def create_prediction_cache():
+    return {}
+
+
+@st.cache(persist=True, allow_output_mutation=True)
+def create_function_cache():
+    return {}
+
+
+prediction_cache = create_prediction_cache()
+prediction_function_cache = create_function_cache()
+
+MODELS = {
+    'Small (293 MB)': {
+        'pretrained': 'google/t5-v1_1-small',
+        'repo_id': 'Xenova/sponsorblock-small',
+        'num_parameters': '77M'
+    },
+    'Base v1 (850 MB)': {
+        'pretrained': 't5-base',
+        'repo_id': 'Xenova/sponsorblock-base-v1',
+        'num_parameters': '220M'
+    },
+
+    'Base v1.1 (944 MB)': {
+        'pretrained': 'google/t5-v1_1-base',
+        'repo_id': 'Xenova/sponsorblock-base-v1.1',
+        'num_parameters': '250M'
+    }
+}
+
+# Create per-model cache
+for m in MODELS:
+    if m not in prediction_cache:
+        prediction_cache[m] = {}
+
+
+CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier-v2'
+
+
+TRANSCRIPT_TYPES = {
+    'AUTO_MANUAL': {
+        'label': 'Auto-generated (fallback to manual)',
+        'type': 'auto',
+        'fallback': 'manual'
+    },
+    'MANUAL_AUTO': {
+        'label': 'Manual (fallback to auto-generated)',
+        'type': 'manual',
+        'fallback': 'auto'
+    },
+    # 'TRANSLATED': 'Translated to English'  # Coming soon
+}
+
+
+def predict_function(model_id, model, tokenizer, segmentation_args, classifier, video_id, words, ts_type_id):
+    cache_id = f'{video_id}_{ts_type_id}'
+
+    if cache_id not in prediction_cache[model_id]:
+        prediction_cache[model_id][cache_id] = pred(
+            video_id, model, tokenizer,
+            segmentation_args=segmentation_args,
+            words=words,
+            classifier=classifier
+        )
+    return prediction_cache[model_id][cache_id]
+
+
+def load_predict(model_id):
+    model_info = MODELS[model_id]
+
+    if model_id not in prediction_function_cache:
+        # Use default segmentation and classification arguments
+        predict_args = PredictArguments(model_name_or_path=model_info['repo_id'])
+        general_args = GeneralArguments()
+        segmentation_args = SegmentationArguments()
+
+        model, tokenizer, classifier = get_model_tokenizer_classifier(predict_args, general_args)
+
+        prediction_function_cache[model_id] = partial(
+            predict_function, model_id, model, tokenizer, segmentation_args, classifier)
+
+    return prediction_function_cache[model_id]
+
+
+def create_button(text, url):
+    return f"""<div class="row-widget stButton" style="text-align: center">
+                    <a href="{url}" target="_blank" rel="noopener noreferrer" class="btn-link">
+                        <button kind="primary" class="btn">{text}</button>
+                    </a>
+                </div>"""
+
+
+def main():
+    st.markdown("""<style>
+        .btn {
+            display: inline-flex;
+            -webkit-box-align: center;
+            align-items: center;
+            -webkit-box-pack: center;
+            justify-content: center;
+            font-weight: 600;
+            padding: 0.25rem 0.75rem;
+            border-radius: 0.25rem;
+            margin: 0px;
+            line-height: 1.5;
+            color: inherit;
+            width: auto;
+            user-select: none;
+            background-color: inherit;
+            border: 1px solid rgba(49, 51, 63, 0.2);
+        }
+        .btn-link {
+            color: inherit;
+            text-decoration: none;
+        }
+    </style>""", unsafe_allow_html=True)
+
+    top = st.container()
+    output = st.empty()
+
+    # Display heading and subheading
+    top.markdown('# PromoDetect')
+    top.markdown(
+        '##### Automatically detect in-video YouTube sponsorships, self/unpaid promotions, and interaction reminders.')
+
+    # Add controls
+
+    col1, col2 = top.columns(2)
+
+    with col1:
+        model_id = st.selectbox(
+            'Select model', MODELS.keys(), index=0, on_change=output.empty)
+
+    with col2:
+        ts_type_id = st.selectbox(
+            'Transcript type', TRANSCRIPT_TYPES.keys(), index=0, format_func=lambda x: TRANSCRIPT_TYPES[x]['label'], on_change=output.empty)
+
+    query_params = st.experimental_get_query_params()
+
+    video_id = None
+
+    if 'v' in query_params:
+        video_id = query_params['v'][0]
+
+    if video_id is None:
+        video_input = top.text_input('Video URL/ID:', on_change=output.empty)
+    else:
+        video_input = top.text_input('Video URL/ID:', on_change=output.empty, value=video_id)
+
+    categories = top.multiselect('Categories:',
+                                 CATGEGORY_OPTIONS.keys(),
+                                 CATGEGORY_OPTIONS.keys(),
+                                 format_func=CATGEGORY_OPTIONS.get, on_change=output.empty
+                                 )
+
+    # Hide segments with a confidence lower than
+    confidence_threshold = top.slider(
+        'Confidence Threshold (%):', min_value=0, value=50, max_value=100, on_change=output.empty)
+
+    if len(video_input) == 0:  # No input, do not continue
+        return
+
+    # Load prediction function
+    with st.spinner('Loading model...'):
+        predict = load_predict(model_id)
+
+    with output.container():  # Place all content in output container
+        video_id = regex_search(video_input, YT_VIDEO_REGEX)
+        if video_id is None:
+            st.exception(ValueError('Invalid YouTube URL/ID'))
+            return
+
+        try:
+            with st.spinner('Downloading transcript...'):
+                words = get_words(video_id,
+                                  transcript_type=TRANSCRIPT_TYPES[ts_type_id]['type'],
+                                  fallback=TRANSCRIPT_TYPES[ts_type_id]['fallback']
+                                  )
+        except TranscriptError:
+            pass
+
+        if not words:
+            st.error('No transcript found!')
+            return
+
+        with st.spinner('Running model...'):
+            predictions = predict(video_id, words, ts_type_id)
+
+        if len(predictions) == 0:
+            st.success('No segments found!')
+            return
+
+        submit_segments = []
+        for index, prediction in enumerate(predictions, start=1):
+            category_key = prediction['category'].upper()
+            if category_key not in categories:
+                continue  # Skip
+
+            confidence = prediction['probability'] * 100
+
+            if confidence < confidence_threshold:
+                continue
+
+            submit_segments.append({
+                'segment': [prediction['start'], prediction['end']],
+                'category': prediction['category'],
+                'actionType': 'skip'
+            })
+            start_time = seconds_to_time(prediction['start'])
+            end_time = seconds_to_time(prediction['end'])
+            with st.expander(
+                f"[{category_key}] Prediction #{index} ({start_time} \u2192 {end_time})"
+            ):
+
+                url = f"https://www.youtube-nocookie.com/embed/{video_id}?&start={floor(prediction['start'])}&end={ceil(prediction['end'])}"
+                # autoplay=1controls=0&&modestbranding=1&fs=0
+
+                # , width=None, height=None, scrolling=False
+                components.iframe(url, width=670, height=376)
+
+                text = ' '.join(w['text'] for w in prediction['words'])
+                st.write(f"**Times:** {start_time} \u2192 {end_time}")
+                st.write(
+                    f"**Category:** {CATGEGORY_OPTIONS[category_key]}")
+                st.write(f"**Confidence:** {confidence:.2f}%")
+                st.write(f'**Text:** "{text}"')
+
+        if not submit_segments:
+            st.success(
+                f'No segments found! ({len(predictions)} ignored due to filters/settings)')
+            return
+
+        num_hidden = len(predictions) - len(submit_segments)
+        if num_hidden > 0:
+            st.info(
+                f'{num_hidden} predictions hidden (adjust the settings and filters to view them all).')
+
+        json_data = quote(json.dumps(submit_segments))
+        link = f'https://www.youtube.com/watch?v={video_id}#segments={json_data}'
+        st.markdown(create_button('Submit Segments', link),
+                    unsafe_allow_html=True)
+
+        # st.markdown(f"""<div style="text-align: center;font-size: 16px;margin-top: 6px">
+        #     <a href="https://wiki.sponsor.ajay.app/w/Automating_Submissions" target="_blank" rel="noopener noreferrer">(Review before submitting!)</a>
+        # </div>""", unsafe_allow_html=True)
+
+
+if __name__ == '__main__':
+    main()
data/.gitignore
ADDED
@@ -0,0 +1,4 @@
+# Ignore everything in this directory
+*
+# Except this file
+!.gitignore
models/.gitignore
ADDED
@@ -0,0 +1,4 @@
+# Ignore everything in this directory
+*
+# Except this file
+!.gitignore
raw/.gitignore
ADDED
@@ -0,0 +1,4 @@
+# Ignore everything in this directory
+*
+# Except this file
+!.gitignore
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+transformers
+datasets
+youtube_transcript_api
+torch
+pandas
+numpy
+sentencepiece
src/classify.py
ADDED
@@ -0,0 +1,43 @@
+from transformers import TextClassificationPipeline
+import preprocess
+import segment
+
+
+class SponsorBlockClassificationPipeline(TextClassificationPipeline):
+    def __init__(self, model, tokenizer):
+        device = next(model.parameters()).device.index
+        if device is None:
+            device = -1
+        super().__init__(model=model, tokenizer=tokenizer,
+                         return_all_scores=True, truncation=True, device=device)
+
+    def preprocess(self, data, **tokenizer_kwargs):
+        # TODO add support for lists
+        texts = []
+
+        if not isinstance(data, list):
+            data = [data]
+
+        for d in data:
+            if isinstance(d, dict):  # Otherwise, get data from transcript
+                words = preprocess.get_words(d['video_id'])
+                segment_words = segment.extract_segment(
+                    words, d['start'], d['end'])
+                text = preprocess.clean_text(
+                    ' '.join(x['text'] for x in segment_words))
+                texts.append(text)
+            elif isinstance(d, str):  # If string, assume this is what user wants to classify
+                texts.append(d)
+            else:
+                raise ValueError(f'Invalid input type: "{type(d)}"')
+
+        return self.tokenizer(
+            texts, return_tensors=self.framework, **tokenizer_kwargs)
+
+
+def main():
+    pass
+
+
+if __name__ == '__main__':
+    main()
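A minimal usage sketch (not part of the upload) for the pipeline defined above. It assumes the default classifier checkpoint named in src/model.py ('EColi/SB_Classifier'); the input string is purely illustrative.

# Sketch: load the classifier with the repo's own helpers and score a snippet.
from model import get_model_tokenizer, ModelArguments
from classify import SponsorBlockClassificationPipeline

model_args = ModelArguments(model_name_or_path='EColi/SB_Classifier')  # default from src/model.py
model, tokenizer = get_model_tokenizer(model_args, model_type='classifier')

pipeline = SponsorBlockClassificationPipeline(model, tokenizer)
# A raw string is classified as-is; a dict like {'video_id': ..., 'start': ..., 'end': ...}
# would instead be resolved to transcript text by preprocess() above.
print(pipeline("huge thanks to today's sponsor for making this video possible"))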
src/errors.py
ADDED
@@ -0,0 +1,24 @@
+
+class SponsorBlockException(Exception):
+    """Base class for all sponsor block exceptions"""
+    pass
+
+
+class InferenceException(SponsorBlockException):
+    """An exception occurred while predicting sponsor segments"""
+    pass
+
+
+class TranscriptError(SponsorBlockException):
+    """An exception occurred while retrieving the video transcript"""
+    pass
+
+
+class ModelError(SponsorBlockException):
+    """Base class for model-related errors"""
+    pass
+
+
+class ModelLoadError(ModelError):
+    """An exception occurred while loading the model"""
+    pass
src/evaluate.py
ADDED
@@ -0,0 +1,408 @@
+
+from model import get_model_tokenizer_classifier, InferenceArguments
+from utils import jaccard, safe_print
+from transformers import HfArgumentParser
+from preprocess import get_words, clean_text
+from shared import GeneralArguments, DatasetArguments
+from predict import predict
+from segment import extract_segment, word_start, word_end, SegmentationArguments, add_labels_to_words
+import pandas as pd
+from dataclasses import dataclass, field
+from typing import Optional
+from tqdm import tqdm
+import json
+import os
+import random
+from shared import seconds_to_time
+from urllib.parse import quote
+import logging
+
+logging.basicConfig()
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class EvaluationArguments(InferenceArguments):
+    """Arguments pertaining to how evaluation will occur."""
+    output_file: Optional[str] = field(
+        default='metrics.csv',
+        metadata={
+            'help': 'Save metrics to output file'
+        }
+    )
+
+    skip_missing: bool = field(
+        default=False,
+        metadata={
+            'help': 'Whether to skip checking for missing segments. If False, predictions will be made.'
+        }
+    )
+    skip_incorrect: bool = field(
+        default=False,
+        metadata={
+            'help': 'Whether to skip checking for incorrect segments. If False, classifications will be made on existing segments.'
+        }
+    )
+
+
+def attach_predictions_to_sponsor_segments(predictions, sponsor_segments):
+    """Attach sponsor segments to closest prediction"""
+    for prediction in predictions:
+        prediction['best_overlap'] = 0
+        prediction['best_sponsorship'] = None
+
+        # Assign predictions to actual (labelled) sponsored segments
+        for sponsor_segment in sponsor_segments:
+            j = jaccard(prediction['start'], prediction['end'],
+                        sponsor_segment['start'], sponsor_segment['end'])
+            if prediction['best_overlap'] < j:
+                prediction['best_overlap'] = j
+                prediction['best_sponsorship'] = sponsor_segment
+
+    return sponsor_segments
+
+
+def calculate_metrics(labelled_words, predictions):
+
+    metrics = {
+        'true_positive': 0,  # Is sponsor, predicted sponsor
+        # Is sponsor, predicted not sponsor (i.e., missed it - bad)
+        'false_negative': 0,
+        # Is not sponsor, predicted sponsor (classified incorrectly, not that bad since we do manual checking afterwards)
+        'false_positive': 0,
+        'true_negative': 0,  # Is not sponsor, predicted not sponsor
+    }
+
+    metrics['video_duration'] = word_end(
+        labelled_words[-1]) - word_start(labelled_words[0])
+
+    for index, word in enumerate(labelled_words):
+        if index >= len(labelled_words) - 1:
+            continue
+
+        duration = word_end(word) - word_start(word)
+
+        predicted_sponsor = False
+        for p in predictions:
+            # Is in some prediction
+            if p['start'] <= word['start'] <= p['end']:
+                predicted_sponsor = True
+                break
+
+        if predicted_sponsor:
+            # total_positive_time += duration
+            if word.get('category') is not None:  # Is actual sponsor
+                metrics['true_positive'] += duration
+            else:
+                metrics['false_positive'] += duration
+        else:
+            # total_negative_time += duration
+            if word.get('category') is not None:  # Is actual sponsor
+                metrics['false_negative'] += duration
+            else:
+                metrics['true_negative'] += duration
+
+    # NOTE In cases where we encounter division by 0, we say that the value is 1
+    # https://stats.stackexchange.com/a/1775
+    # (Precision) TP+FP=0: means that all instances were predicted as negative
+    # (Recall) TP+FN=0: means that there were no positive cases in the input data
+
+    # The fraction of predictions our model got right
+    # Can simplify, but use full formula
+    z = metrics['true_positive'] + metrics['true_negative'] + \
+        metrics['false_positive'] + metrics['false_negative']
+    metrics['accuracy'] = (
+        (metrics['true_positive'] + metrics['true_negative']) / z) if z > 0 else 1
+
+    # What proportion of positive identifications was actually correct?
+    z = metrics['true_positive'] + metrics['false_positive']
+    metrics['precision'] = (metrics['true_positive'] / z) if z > 0 else 1
+
+    # What proportion of actual positives was identified correctly?
+    z = metrics['true_positive'] + metrics['false_negative']
+    metrics['recall'] = (metrics['true_positive'] / z) if z > 0 else 1
+
+    # https://deepai.org/machine-learning-glossary-and-terms/f-score
+
+    s = metrics['precision'] + metrics['recall']
+    metrics['f-score'] = (2 * (metrics['precision'] *
+                          metrics['recall']) / s) if s > 0 else 0
+
+    return metrics
+
+
+def main():
+    logger.setLevel(logging.DEBUG)
+
+    hf_parser = HfArgumentParser((
+        EvaluationArguments,
+        DatasetArguments,
+        SegmentationArguments,
+        GeneralArguments
+    ))
+
+    evaluation_args, dataset_args, segmentation_args, general_args = hf_parser.parse_args_into_dataclasses()
+
+    if evaluation_args.skip_missing and evaluation_args.skip_incorrect:
+        logger.error('ERROR: Nothing to do')
+        return
+
+    # Load labelled data:
+    final_path = os.path.join(
+        dataset_args.data_dir, dataset_args.processed_file)
+
+    if not os.path.exists(final_path):
+        logger.error('ERROR: Processed database not found.\n'
+                     f'Run `python src/preprocess.py --update_database --do_create` to generate "{final_path}".')
+        return
+
+    model, tokenizer, classifier = get_model_tokenizer_classifier(
+        evaluation_args, general_args)
+
+    with open(final_path) as fp:
+        final_data = json.load(fp)
+
+    if evaluation_args.video_ids:  # Use specified
+        video_ids = evaluation_args.video_ids
+
+    else:  # Use items found in preprocessed database
+        video_ids = list(final_data.keys())
+        random.shuffle(video_ids)
+
+        if evaluation_args.start_index is not None:
+            video_ids = video_ids[evaluation_args.start_index:]
+
+        if evaluation_args.max_videos is not None:
+            video_ids = video_ids[:evaluation_args.max_videos]
+
+    out_metrics = []
+
+    all_metrics = {}
+    if not evaluation_args.skip_missing:
+        all_metrics['total_prediction_accuracy'] = 0
+        all_metrics['total_prediction_precision'] = 0
+        all_metrics['total_prediction_recall'] = 0
+        all_metrics['total_prediction_fscore'] = 0
+
+    if not evaluation_args.skip_incorrect:
+        all_metrics['classifier_segment_correct'] = 0
+        all_metrics['classifier_segment_count'] = 0
+
+    metric_count = 0
+
+    postfix_info = {}
+
+    try:
+        with tqdm(video_ids) as progress:
+            for video_index, video_id in enumerate(progress):
+                progress.set_description(f'Processing {video_id}')
+
+                words = get_words(video_id)
+                if not words:
+                    continue
+
+                # Get labels
+                sponsor_segments = final_data.get(video_id)
+
+                # Reset previous
+                missed_segments = []
+                incorrect_segments = []
+
+                current_metrics = {
+                    'video_id': video_id
+                }
+                metric_count += 1
+
+                if not evaluation_args.skip_missing:  # Make predictions
+                    predictions = predict(video_id, model, tokenizer, segmentation_args,
+                                          classifier=classifier,
+                                          min_probability=evaluation_args.min_probability)
+
+                    if sponsor_segments:
+                        labelled_words = add_labels_to_words(
+                            words, sponsor_segments)
+
+                        current_metrics.update(
+                            calculate_metrics(labelled_words, predictions))
+
+                        all_metrics['total_prediction_accuracy'] += current_metrics['accuracy']
+                        all_metrics['total_prediction_precision'] += current_metrics['precision']
+                        all_metrics['total_prediction_recall'] += current_metrics['recall']
+                        all_metrics['total_prediction_fscore'] += current_metrics['f-score']
+
+                        # Just for display purposes
+                        postfix_info.update({
+                            'accuracy': all_metrics['total_prediction_accuracy'] / metric_count,
+                            'precision': all_metrics['total_prediction_precision'] / metric_count,
+                            'recall': all_metrics['total_prediction_recall'] / metric_count,
+                            'f-score': all_metrics['total_prediction_fscore'] / metric_count,
+                        })
+
+                        sponsor_segments = attach_predictions_to_sponsor_segments(
+                            predictions, sponsor_segments)
+
+                        # Identify possible issues:
+                        for prediction in predictions:
+                            if prediction['best_sponsorship'] is not None:
+                                continue
+
+                            prediction_words = prediction.pop('words', [])
+
+                            # Attach original text to missed segments
+                            prediction['text'] = ' '.join(
+                                x['text'] for x in prediction_words)
+                            missed_segments.append(prediction)
+
+                    else:
+                        # Not in database (all segments missed)
+                        missed_segments = predictions
+
+                if not evaluation_args.skip_incorrect and sponsor_segments:
+                    # Check for incorrect segments using the classifier
+
+                    segments_to_check = []
+                    cleaned_texts = []  # Texts to send through tokenizer
+                    for sponsor_segment in sponsor_segments:
+                        segment_words = extract_segment(
+                            words, sponsor_segment['start'], sponsor_segment['end'])
+                        sponsor_segment['text'] = ' '.join(
+                            x['text'] for x in segment_words)
+
+                        duration = sponsor_segment['end'] - \
+                            sponsor_segment['start']
+                        wps = (len(segment_words) /
+                               duration) if duration > 0 else 0
+                        if wps < 1.5:
+                            continue
+
+                        # Do not worry about those that are locked or have enough votes
+                        # or segment['votes'] > 5:
+                        if sponsor_segment['locked']:
+                            continue
+
+                        cleaned_texts.append(
+                            clean_text(sponsor_segment['text']))
+                        segments_to_check.append(sponsor_segment)
+
+                    if segments_to_check:  # Some segments to check
+
+                        segments_scores = classifier(cleaned_texts)
+
+                        num_correct = 0
+                        for segment, scores in zip(segments_to_check, segments_scores):
+
+                            fixed_scores = {
+                                score['label']: score['score']
+                                for score in scores
+                            }
+
+                            all_metrics['classifier_segment_count'] += 1
+
+                            prediction = max(scores, key=lambda x: x['score'])
+                            predicted_category = prediction['label'].lower()
+
+                            if predicted_category == segment['category']:
+                                num_correct += 1
+                                continue  # Ignore correct segments
+
+                            segment.update({
+                                'predicted': predicted_category,
+                                'scores': fixed_scores
+                            })
+
+                            incorrect_segments.append(segment)
+
+                        current_metrics['num_segments'] = len(
+                            segments_to_check)
+                        current_metrics['classified_correct'] = num_correct
+
+                        all_metrics['classifier_segment_correct'] += num_correct
+
+                    if all_metrics['classifier_segment_count'] > 0:
+                        postfix_info['classifier_accuracy'] = all_metrics['classifier_segment_correct'] / \
+                            all_metrics['classifier_segment_count']
+
+                out_metrics.append(current_metrics)
+                progress.set_postfix(postfix_info)
+
+                if missed_segments or incorrect_segments:
+
+                    if evaluation_args.output_as_json:
+                        to_print = {'video_id': video_id}
+
+                        if missed_segments:
+                            to_print['missed'] = missed_segments
+
+                        if incorrect_segments:
+                            to_print['incorrect'] = incorrect_segments
+
+                        safe_print(json.dumps(to_print))
+
+                    else:
+                        safe_print(
+                            f'Issues identified for {video_id} (#{video_index})')
+                        # Potentially missed segments (model predicted, but not in database)
+                        if missed_segments:
+                            safe_print(' - Missed segments:')
+                            segments_to_submit = []
+                            for i, missed_segment in enumerate(missed_segments, start=1):
+                                safe_print(f'\t#{i}:', seconds_to_time(
+                                    missed_segment['start']), '-->', seconds_to_time(missed_segment['end']))
+                                safe_print('\t\tText: "',
+                                           missed_segment['text'], '"', sep='')
+                                safe_print('\t\tCategory:',
+                                           missed_segment.get('category'))
+                                if 'probability' in missed_segment:
+                                    safe_print('\t\tProbability:',
+                                               missed_segment['probability'])
+
+                                segments_to_submit.append({
+                                    'segment': [missed_segment['start'], missed_segment['end']],
+                                    'category': missed_segment['category'].lower(),
+                                    'actionType': 'skip'
+                                })
+
+                            json_data = quote(json.dumps(segments_to_submit))
+                            safe_print(
+                                f'\tSubmit: https://www.youtube.com/watch?v={video_id}#segments={json_data}')
+
+                        # Incorrect segments (in database, but incorrectly classified)
+                        if incorrect_segments:
+                            safe_print(' - Incorrect segments:')
+                            for i, incorrect_segment in enumerate(incorrect_segments, start=1):
+                                safe_print(f'\t#{i}:', seconds_to_time(
+                                    incorrect_segment['start']), '-->', seconds_to_time(incorrect_segment['end']))
+
+                                safe_print(
+                                    '\t\tText: "', incorrect_segment['text'], '"', sep='')
+                                safe_print(
+                                    '\t\tUUID:', incorrect_segment['uuid'])
+                                safe_print(
+                                    '\t\tVotes:', incorrect_segment['votes'])
+                                safe_print(
+                                    '\t\tViews:', incorrect_segment['views'])
+                                safe_print('\t\tLocked:',
+                                           incorrect_segment['locked'])
+
+                                safe_print('\t\tCurrent Category:',
+                                           incorrect_segment['category'])
+                                safe_print('\t\tPredicted Category:',
+                                           incorrect_segment['predicted'])
+                                safe_print('\t\tProbabilities:')
+                                for label, score in incorrect_segment['scores'].items():
+                                    safe_print(
+                                        f"\t\t\t{label}: {score}")
+
+                        safe_print()
+
+    except KeyboardInterrupt:
+        pass
+
+    df = pd.DataFrame(out_metrics)
+
+    df.to_csv(evaluation_args.output_file)
+    logger.info(df.mean())
+
+
+if __name__ == '__main__':
+    main()
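src/utils.py is listed in this upload but its diff is not shown here, so the exact jaccard implementation is unknown. Given how attach_predictions_to_sponsor_segments() calls it, a plausible sketch is the intersection-over-union of two time intervals; this is an assumption, not the repository's actual code.

# Assumed sketch of utils.jaccard for two [start, end] time intervals (seconds).
def jaccard(start1, end1, start2, end2):
    # Overlap length between the two intervals (0 if they do not intersect)
    intersection = max(0, min(end1, end2) - max(start1, start2))
    # Span covered by either interval
    union = max(end1, end2) - min(start1, start2)
    return intersection / union if union > 0 else 0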
src/model.py
ADDED
@@ -0,0 +1,235 @@
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, TrainingArguments
+from shared import CustomTokens, GeneralArguments
+from errors import InferenceException, ModelLoadError
+from dataclasses import dataclass, field
+from typing import Optional, Union
+import itertools
+import torch
+import classify
+import base64
+import re
+import requests
+import json
+import logging
+
+logging.basicConfig()
+logger = logging.getLogger(__name__)
+
+# Public innertube key (b64 encoded so that it is not incorrectly flagged)
+INNERTUBE_KEY = base64.b64decode(
+    b'QUl6YVN5QU9fRkoyU2xxVThRNFNURUhMR0NpbHdfWTlfMTFxY1c4').decode()
+
+YT_CONTEXT = {
+    'client': {
+        'userAgent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36,gzip(gfe)',
+        'clientName': 'WEB',
+        'clientVersion': '2.20211221.00.00',
+    }
+}
+_YT_INITIAL_DATA_RE = r'(?:window\s*\[\s*["\']ytInitialData["\']\s*\]|ytInitialData)\s*=\s*({.+?})\s*;\s*(?:var\s+meta|</script|\n)'
+
+
+def get_all_channel_vids(channel_id):
+    continuation = None
+    while True:
+        if continuation is None:
+            params = {'list': channel_id.replace('UC', 'UU', 1)}
+            response = requests.get(
+                'https://www.youtube.com/playlist', params=params)
+            items = json.loads(re.search(_YT_INITIAL_DATA_RE, response.text).group(1))['contents']['twoColumnBrowseResultsRenderer']['tabs'][0]['tabRenderer']['content'][
+                'sectionListRenderer']['contents'][0]['itemSectionRenderer']['contents'][0]['playlistVideoListRenderer']['contents']
+        else:
+            params = {'key': INNERTUBE_KEY}
+            data = {
+                'context': YT_CONTEXT,
+                'continuation': continuation
+            }
+            response = requests.post(
+                'https://www.youtube.com/youtubei/v1/browse', params=params, json=data)
+            items = response.json()[
+                'onResponseReceivedActions'][0]['appendContinuationItemsAction']['continuationItems']
+
+        new_token = None
+        for vid in items:
+            info = vid.get('playlistVideoRenderer')
+            if info:
+                yield info['videoId']
+                continue
+
+            info = vid.get('continuationItemRenderer')
+            if info:
+                new_token = info['continuationEndpoint']['continuationCommand']['token']
+
+        if new_token is None:
+            break
+        continuation = new_token
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+    """
+
+    model_name_or_path: str = field(
+        default=None,
+        metadata={
+            'help': 'Path to pretrained model or model identifier from huggingface.co/models'
+        }
+    )
+
+    cache_dir: Optional[str] = field(
+        default='models',
+        metadata={
+            'help': 'Where to store the pretrained models downloaded from huggingface.co'
+        },
+    )
+    use_fast_tokenizer: bool = field(
+        default=True,
+        metadata={
+            'help': 'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.'
+        },
+    )
+    model_revision: str = field(
+        default='main',
+        metadata={
+            'help': 'The specific model version to use (can be a branch name, tag name or commit id).'
+        },
+    )
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            'help': 'Will use the token generated when running `transformers-cli login` (necessary to use this script '
+            'with private models).'
+        },
+    )
+
+
+@dataclass
+class InferenceArguments(ModelArguments):
+
+    model_name_or_path: str = field(
+        default='Xenova/sponsorblock-small',
+        metadata={
+            'help': 'Path to pretrained model used for prediction'
+        }
+    )
+    classifier_model_name_or_path: str = field(
+        default='EColi/SB_Classifier',
+        metadata={
+            'help': 'Use a pretrained classifier'
+        }
+    )
+
+    max_videos: Optional[int] = field(
+        default=None,
+        metadata={
+            'help': 'The number of videos to test on'
+        }
+    )
+    start_index: int = field(default=None, metadata={
+        'help': 'Video to start the evaluation at.'})
+    channel_id: Optional[str] = field(
+        default=None,
+        metadata={
+            'help': 'Used to evaluate a channel'
+        }
+    )
+    video_ids: str = field(
+        default_factory=lambda: [],
+        metadata={
+            'nargs': '+'
+        }
+    )
+
+    output_as_json: bool = field(default=False, metadata={
+        'help': 'Output evaluations as JSON'})
+
+    min_probability: float = field(
+        default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
+
+    def __post_init__(self):
+
+        self.video_ids = list(map(str.strip, self.video_ids))
+
+        if any(len(video_id) != 11 for video_id in self.video_ids):
+            raise InferenceException('Invalid video IDs (length not 11)')
+
+        if self.channel_id is not None:
+            start = self.start_index or 0
+            end = None if self.max_videos is None else start + self.max_videos
+
+            channel_video_ids = list(itertools.islice(get_all_channel_vids(
+                self.channel_id), start, end))
+            logger.info(
+                f'Found {len(channel_video_ids)} for channel {self.channel_id}')
+
+            self.video_ids += channel_video_ids
+
+
+def get_model_tokenizer_classifier(inference_args: InferenceArguments, general_args: GeneralArguments):
+
+    original_path = inference_args.model_name_or_path
+
+    # Load main model and tokenizer
+    model, tokenizer = get_model_tokenizer(inference_args, general_args)
+
+    # Load classifier
+    inference_args.model_name_or_path = inference_args.classifier_model_name_or_path
+    classifier_model, classifier_tokenizer = get_model_tokenizer(
+        inference_args, general_args, model_type='classifier')
+
+    classifier = classify.SponsorBlockClassificationPipeline(
+        classifier_model, classifier_tokenizer)
+
+    # Reset to original model_name_or_path
+    inference_args.model_name_or_path = original_path
+
+    return model, tokenizer, classifier
+
+
+def get_model_tokenizer(model_args: ModelArguments, general_args: Union[GeneralArguments, TrainingArguments] = None, config_args=None, model_type='seq2seq'):
+    if model_args.model_name_or_path is None:
+        raise ModelLoadError('Must specify --model_name_or_path')
+
+    if config_args is None:
+        config_args = {}
+
+    use_auth_token = True if model_args.use_auth_token else None
+
+    config = AutoConfig.from_pretrained(
+        model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=use_auth_token,
+        **config_args
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        use_fast=model_args.use_fast_tokenizer,
+        revision=model_args.model_revision,
+        use_auth_token=use_auth_token,
+    )
+
+    model_type = AutoModelForSeq2SeqLM if model_type == 'seq2seq' else AutoModelForSequenceClassification
+    model = model_type.from_pretrained(
+        model_args.model_name_or_path,
+        config=config,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=use_auth_token,
+    )
+
+    # Add custom tokens
+    CustomTokens.add_custom_tokens(tokenizer)
+    model.resize_token_embeddings(len(tokenizer))
+
+    # Potentially move model to gpu
+    if general_args is not None and not general_args.no_cuda:
+        model.to('cuda' if torch.cuda.is_available() else 'cpu')
+
+    return model, tokenizer
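An illustrative sketch of the get_all_channel_vids generator defined above, listing a channel's first few uploads. The channel ID below is a placeholder, not a real channel.

# Sketch: iterate a channel's uploads lazily; islice stops the generator early.
import itertools
from model import get_all_channel_vids

for video_id in itertools.islice(get_all_channel_vids('UCxxxxxxxxxxxxxxxxxxxxxx'), 5):
    print(video_id)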
src/predict.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from transformers import HfArgumentParser
|
3 |
+
from dataclasses import dataclass, field
|
4 |
+
import logging
|
5 |
+
from shared import CustomTokens, extract_sponsor_matches, GeneralArguments, seconds_to_time
|
6 |
+
from segment import (
|
7 |
+
generate_segments,
|
8 |
+
extract_segment,
|
9 |
+
MIN_SAFETY_TOKENS,
|
10 |
+
SAFETY_TOKENS_PERCENTAGE,
|
11 |
+
word_start,
|
12 |
+
word_end,
|
13 |
+
SegmentationArguments
|
14 |
+
)
|
15 |
+
import preprocess
|
16 |
+
from errors import TranscriptError
|
17 |
+
from model import get_model_tokenizer_classifier, InferenceArguments
|
18 |
+
|
19 |
+
logging.basicConfig()
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass
|
24 |
+
class PredictArguments(InferenceArguments):
|
25 |
+
video_id: str = field(
|
26 |
+
default=None,
|
27 |
+
metadata={
|
28 |
+
'help': 'Video to predict segments for'}
|
29 |
+
)
|
30 |
+
|
31 |
+
def __post_init__(self):
|
32 |
+
if self.video_id is not None:
|
33 |
+
self.video_ids.append(self.video_id)
|
34 |
+
|
35 |
+
super().__post_init__()
|
36 |
+
|
37 |
+
|
38 |
+
MATCH_WINDOW = 25 # Increase for accuracy, but takes longer: O(n^3)
|
39 |
+
MERGE_TIME_WITHIN = 8 # Merge predictions if they are within x seconds
|
40 |
+
|
41 |
+
# Any prediction whose start time is <= this will be set to start at 0
|
42 |
+
START_TIME_ZERO_THRESHOLD = 0.08
|
43 |
+
|
44 |
+
|
45 |
+
def filter_and_add_probabilities(predictions, classifier, min_probability):
|
46 |
+
"""Use classifier to filter predictions"""
|
47 |
+
if not predictions:
|
48 |
+
return predictions
|
49 |
+
|
50 |
+
# We update the predicted category from the extractive transformer
|
51 |
+
# if the classifier is confident enough it is another category
|
52 |
+
|
53 |
+
texts = [
|
54 |
+
preprocess.clean_text(' '.join([x['text'] for x in pred['words']]))
|
55 |
+
for pred in predictions
|
56 |
+
]
|
57 |
+
classifications = classifier(texts)
|
58 |
+
|
59 |
+
filtered_predictions = []
|
60 |
+
for prediction, probabilities in zip(predictions, classifications):
|
61 |
+
predicted_probabilities = {
|
62 |
+
p['label'].lower(): p['score'] for p in probabilities}
|
63 |
+
|
64 |
+
# Get best category + probability
|
65 |
+
classifier_category = max(
|
66 |
+
predicted_probabilities, key=predicted_probabilities.get)
|
67 |
+
classifier_probability = predicted_probabilities[classifier_category]
|
68 |
+
|
69 |
+
if (prediction['category'] not in predicted_probabilities) \
|
70 |
+
or (classifier_category != 'none' and classifier_probability > 0.5): # TODO make param
|
71 |
+
# Unknown category or we are confident enough to overrule,
|
72 |
+
# so change category to what was predicted by classifier
|
73 |
+
prediction['category'] = classifier_category
|
74 |
+
|
75 |
+
if prediction['category'] == 'none':
|
76 |
+
continue # Ignore if categorised as nothing
|
77 |
+
|
78 |
+
prediction['probability'] = predicted_probabilities[prediction['category']]
|
79 |
+
|
80 |
+
if min_probability is not None and prediction['probability'] < min_probability:
|
81 |
+
continue # Ignore if below threshold
|
82 |
+
|
83 |
+
# TODO add probabilities, but remove None and normalise rest
|
84 |
+
prediction['probabilities'] = predicted_probabilities
|
85 |
+
|
86 |
+
# if prediction['probability'] < classifier_args.min_probability:
|
87 |
+
# continue
|
88 |
+
|
89 |
+
filtered_predictions.append(prediction)
|
90 |
+
|
91 |
+
return filtered_predictions
|
92 |
+
|
93 |
+
|
94 |
+
def predict(video_id, model, tokenizer, segmentation_args, words=None, classifier=None, min_probability=None):
|
95 |
+
# Allow words to be passed in so that we don't have to get the words if we already have them
|
96 |
+
if words is None:
|
97 |
+
words = preprocess.get_words(video_id)
|
98 |
+
if not words:
|
99 |
+
raise TranscriptError('Unable to retrieve transcript')
|
100 |
+
|
101 |
+
segments = generate_segments(
|
102 |
+
words,
|
103 |
+
tokenizer,
|
104 |
+
segmentation_args
|
105 |
+
)
|
106 |
+
|
107 |
+
predictions = segments_to_predictions(segments, model, tokenizer)
|
108 |
+
# Add words back to time_ranges
|
109 |
+
for prediction in predictions:
|
110 |
+
# Stores words in the range
|
111 |
+
prediction['words'] = extract_segment(
|
112 |
+
words, prediction['start'], prediction['end'])
|
113 |
+
|
114 |
+
if classifier is not None:
|
115 |
+
predictions = filter_and_add_probabilities(
|
116 |
+
predictions, classifier, min_probability)
|
117 |
+
|
118 |
+
return predictions
|
119 |
+
|
120 |
+
|
121 |
+
def greedy_match(list, sublist):
|
122 |
+
# Return index and length of longest matching sublist
|
123 |
+
|
124 |
+
best_i = -1
|
125 |
+
best_j = -1
|
126 |
+
best_k = 0
|
127 |
+
|
128 |
+
for i in range(len(list)): # Start position in main list
|
129 |
+
for j in range(len(sublist)): # Start position in sublist
|
130 |
+
for k in range(len(sublist)-j, 0, -1): # Width of sublist window
|
131 |
+
if k > best_k and list[i:i+k] == sublist[j:j+k]:
|
132 |
+
best_i, best_j, best_k = i, j, k
|
133 |
+
break # Since window size decreases
|
134 |
+
|
135 |
+
return best_i, best_j, best_k
|
136 |
+
|
137 |
+
|
138 |
+
def predict_sponsor_from_texts(texts, model, tokenizer):
|
139 |
+
clean_texts = list(map(preprocess.clean_text, texts))
|
140 |
+
return predict_sponsor_from_cleaned_texts(clean_texts, model, tokenizer)
|
141 |
+
|
142 |
+
|
143 |
+
def predict_sponsor_from_cleaned_texts(cleaned_texts, model, tokenizer):
|
144 |
+
"""Given a body of text, predict the words which are part of the sponsor"""
|
145 |
+
model_device = next(model.parameters()).device
|
146 |
+
|
147 |
+
decoded_outputs = []
|
148 |
+
# Do individually, to avoid running out of memory for long videos
|
149 |
+
for cleaned_words in cleaned_texts:
|
150 |
+
text = CustomTokens.EXTRACT_SEGMENTS_PREFIX.value + \
|
151 |
+
' '.join(cleaned_words)
|
152 |
+
input_ids = tokenizer(text, return_tensors='pt',
|
153 |
+
truncation=True).input_ids.to(model_device)
|
154 |
+
|
155 |
+
# Optimise output length so that we do not generate unnecessarily long texts
|
156 |
+
max_out_len = round(min(
|
157 |
+
max(
|
158 |
+
len(input_ids[0])/SAFETY_TOKENS_PERCENTAGE,
|
159 |
+
len(input_ids[0]) + MIN_SAFETY_TOKENS
|
160 |
+
),
|
161 |
+
model.model_dim)
|
162 |
+
)
|
163 |
+
|
164 |
+
outputs = model.generate(input_ids, max_length=max_out_len)
|
165 |
+
decoded_outputs.append(tokenizer.decode(
|
166 |
+
outputs[0], skip_special_tokens=True))
|
167 |
+
|
168 |
+
return decoded_outputs
|
169 |
+
|
170 |
+
|
171 |
+
def segments_to_predictions(segments, model, tokenizer):
|
172 |
+
predicted_time_ranges = []
|
173 |
+
|
174 |
+
cleaned_texts = [
|
175 |
+
[x['cleaned'] for x in cleaned_segment]
|
176 |
+
for cleaned_segment in segments
|
177 |
+
]
|
178 |
+
|
179 |
+
sponsorship_texts = predict_sponsor_from_cleaned_texts(
|
180 |
+
cleaned_texts, model, tokenizer)
|
181 |
+
|
182 |
+
matches = extract_sponsor_matches(sponsorship_texts)
|
183 |
+
|
184 |
+
for segment_matches, cleaned_batch, segment in zip(matches, cleaned_texts, segments):
|
185 |
+
|
186 |
+
for match in segment_matches: # one segment might contain multiple sponsors/ir/selfpromos
|
187 |
+
|
188 |
+
matched_text = match['text'].split()
|
189 |
+
|
190 |
+
i1, j1, k1 = greedy_match(
|
191 |
+
cleaned_batch, matched_text[:MATCH_WINDOW])
|
192 |
+
i2, j2, k2 = greedy_match(
|
193 |
+
cleaned_batch, matched_text[-MATCH_WINDOW:])
|
194 |
+
|
195 |
+
extracted_words = segment[i1:i2+k2]
|
196 |
+
if not extracted_words:
|
197 |
+
continue
|
198 |
+
|
199 |
+
predicted_time_ranges.append({
|
200 |
+
'start': word_start(extracted_words[0]),
|
201 |
+
'end': word_end(extracted_words[-1]),
|
202 |
+
'category': match['category']
|
203 |
+
})
|
204 |
+
|
205 |
+
# Necessary to sort matches by start time
|
206 |
+
predicted_time_ranges.sort(key=word_start)
|
207 |
+
|
208 |
+
# Merge overlapping predictions and sponsorships that are close together
|
209 |
+
# Caused by model having max input size
|
210 |
+
|
211 |
+
prev_prediction = None
|
212 |
+
|
213 |
+
final_predicted_time_ranges = []
|
214 |
+
for range in predicted_time_ranges:
|
215 |
+
start_time = range['start'] if range['start'] > START_TIME_ZERO_THRESHOLD else 0
|
216 |
+
end_time = range['end']
|
217 |
+
|
218 |
+
if prev_prediction is not None and \
|
219 |
+
(start_time <= prev_prediction['end'] <= end_time or # Merge overlapping segments
|
220 |
+
(range['category'] == prev_prediction['category'] # Merge disconnected segments if same category and within threshold
|
221 |
+
and start_time - prev_prediction['end'] <= MERGE_TIME_WITHIN)):
|
222 |
+
# Extend last prediction range
|
223 |
+
final_predicted_time_ranges[-1]['end'] = end_time
|
224 |
+
|
225 |
+
else: # No overlap, is a new prediction
|
226 |
+
final_predicted_time_ranges.append({
|
227 |
+
'start': start_time,
|
228 |
+
'end': end_time,
|
229 |
+
'category': range['category']
|
230 |
+
})
|
231 |
+
|
232 |
+
prev_prediction = range
|
233 |
+
|
234 |
+
return final_predicted_time_ranges
|
235 |
+
|
236 |
+
|
237 |
+
def main():
|
238 |
+
# Test on unseen data
|
239 |
+
logger.setLevel(logging.DEBUG)
|
240 |
+
|
241 |
+
hf_parser = HfArgumentParser((
|
242 |
+
PredictArguments,
|
243 |
+
SegmentationArguments,
|
244 |
+
GeneralArguments
|
245 |
+
))
|
246 |
+
predict_args, segmentation_args, general_args = hf_parser.parse_args_into_dataclasses()
|
247 |
+
|
248 |
+
if not predict_args.video_ids:
|
249 |
+
logger.error(
|
250 |
+
'No video IDs supplied. Use `--video_id`, `--video_ids`, or `--channel_id`.')
|
251 |
+
return
|
252 |
+
|
253 |
+
model, tokenizer, classifier = get_model_tokenizer_classifier(
|
254 |
+
predict_args, general_args)
|
255 |
+
|
256 |
+
for video_id in predict_args.video_ids:
|
257 |
+
try:
|
258 |
+
predictions = predict(video_id, model, tokenizer, segmentation_args,
|
259 |
+
classifier=classifier,
|
260 |
+
min_probability=predict_args.min_probability)
|
261 |
+
except TranscriptError:
|
262 |
+
logger.warning(f'No transcript available for {video_id}')
|
263 |
+
continue
|
264 |
+
video_url = f'https://www.youtube.com/watch?v={video_id}'
|
265 |
+
if not predictions:
|
266 |
+
logger.info(f'No predictions found for {video_url}')
|
267 |
+
continue
|
268 |
+
|
269 |
+
# TODO use predict_args.output_as_json
|
270 |
+
print(len(predictions), 'predictions found for', video_url)
|
271 |
+
for index, prediction in enumerate(predictions, start=1):
|
272 |
+
print(f'Prediction #{index}:')
|
273 |
+
print('Text: "',
|
274 |
+
' '.join([w['text'] for w in prediction['words']]), '"', sep='')
|
275 |
+
print('Time:', seconds_to_time(
|
276 |
+
prediction['start']), '\u2192', seconds_to_time(prediction['end']))
|
277 |
+
print('Category:', prediction.get('category'))
|
278 |
+
if 'probability' in prediction:
|
279 |
+
print('Probability:', prediction['probability'])
|
280 |
+
print()
|
281 |
+
print()
|
282 |
+
|
283 |
+
|
284 |
+
if __name__ == '__main__':
|
285 |
+
main()
|
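
The tail of src/predict.py above maps matched text back onto word timestamps and then merges overlapping predictions, as well as same-category predictions that fall within MERGE_TIME_WITHIN seconds of each other (such splits are caused by the model's limited input window). Below is a simplified, self-contained sketch of that merging rule, assuming a placeholder threshold value rather than the repository's actual constants:

    # Simplified sketch of the merging rule in src/predict.py above; the real
    # constants (START_TIME_ZERO_THRESHOLD, MERGE_TIME_WITHIN) live in that file.
    MERGE_TIME_WITHIN = 8  # hypothetical value, in seconds

    def merge_predictions(ranges):
        """Merge overlapping ranges, plus same-category ranges that are close together."""
        merged = []
        for r in sorted(ranges, key=lambda x: x['start']):
            if merged and (
                    r['start'] <= merged[-1]['end'] or
                    (r['category'] == merged[-1]['category'] and
                     r['start'] - merged[-1]['end'] <= MERGE_TIME_WITHIN)):
                # Extend the previous prediction instead of starting a new one
                merged[-1]['end'] = max(merged[-1]['end'], r['end'])
            else:
                merged.append(dict(r))
        return merged

    # Two sponsor predictions split by the input window merge into one range:
    print(merge_predictions([
        {'start': 12.0, 'end': 55.0, 'category': 'sponsor'},
        {'start': 58.0, 'end': 91.5, 'category': 'sponsor'},
    ]))
    # -> [{'start': 12.0, 'end': 91.5, 'category': 'sponsor'}]
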
src/preprocess.py
ADDED
@@ -0,0 +1,979 @@
1 |
+
from shared import DatasetArguments
|
2 |
+
from utils import jaccard
|
3 |
+
from functools import lru_cache
|
4 |
+
from datetime import datetime
|
5 |
+
import itertools
|
6 |
+
from typing import Optional
|
7 |
+
import model as model_module
|
8 |
+
import segment
|
9 |
+
from tqdm import tqdm
|
10 |
+
from dataclasses import dataclass, field
|
11 |
+
from transformers import HfArgumentParser
|
12 |
+
from shared import extract_sponsor_matches_from_text, ACTION_OPTIONS, CATEGORIES, CATGEGORY_OPTIONS, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, GeneralArguments, CustomTokens
|
13 |
+
import csv
|
14 |
+
import re
|
15 |
+
import random
|
16 |
+
import logging
|
17 |
+
from youtube_transcript_api import YouTubeTranscriptApi, CouldNotRetrieveTranscript, YouTubeRequestFailed, TooManyRequests
|
18 |
+
import os
|
19 |
+
import json
|
20 |
+
import time
|
21 |
+
import requests
|
22 |
+
|
23 |
+
|
24 |
+
logging.basicConfig()
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
PROFANITY_RAW = '[ __ ]' # How YouTube transcribes profanity
|
29 |
+
PROFANITY_CONVERTED = '*****' # Safer version for tokenizing
|
30 |
+
|
31 |
+
|
32 |
+
NUM_DECIMALS = 3
|
33 |
+
|
34 |
+
# https://www.fincher.org/Utilities/CountryLanguageList.shtml
|
35 |
+
# https://lingohub.com/developers/supported-locales/language-designators-with-regions
|
36 |
+
LANGUAGE_PREFERENCE_LIST = ['en-GB', 'en-US', 'en-CA', 'en-AU', 'en-NZ', 'en-ZA',
|
37 |
+
'en-IE', 'en-IN', 'en-JM', 'en-BZ', 'en-TT', 'en-PH', 'en-ZW',
|
38 |
+
'en']
|
39 |
+
|
40 |
+
|
41 |
+
def parse_transcript_json(json_data, granularity):
|
42 |
+
assert json_data['wireMagic'] == 'pb3'
|
43 |
+
|
44 |
+
assert granularity in ('word', 'chunk')
|
45 |
+
|
46 |
+
# TODO remove bracketed words?
|
47 |
+
# (kiss smacks)
|
48 |
+
# (upbeat music)
|
49 |
+
# [text goes here]
|
50 |
+
|
51 |
+
# Some manual transcripts aren't that well formatted... but do have punctuation
|
52 |
+
# https://www.youtube.com/watch?v=LR9FtWVjk2c
|
53 |
+
|
54 |
+
parsed_transcript = []
|
55 |
+
|
56 |
+
events = json_data['events']
|
57 |
+
|
58 |
+
for event_index, event in enumerate(events):
|
59 |
+
segments = event.get('segs')
|
60 |
+
if not segments:
|
61 |
+
continue
|
62 |
+
|
63 |
+
# This value is known (when phrase appears on screen)
|
64 |
+
start_ms = event['tStartMs']
|
65 |
+
total_characters = 0
|
66 |
+
|
67 |
+
new_segments = []
|
68 |
+
for seg in segments:
|
69 |
+
# Replace \n, \t, etc. with space
|
70 |
+
text = ' '.join(seg['utf8'].split())
|
71 |
+
|
72 |
+
# Remove zero-width spaces and strip trailing and leading whitespace
|
73 |
+
text = text.replace('\u200b', '').replace('\u200c', '').replace(
|
74 |
+
'\u200d', '').replace('\ufeff', '').strip()
|
75 |
+
|
76 |
+
# Alternatively,
|
77 |
+
# text = text.encode('ascii', 'ignore').decode()
|
78 |
+
|
79 |
+
# Needed for auto-generated transcripts
|
80 |
+
text = text.replace(PROFANITY_RAW, PROFANITY_CONVERTED)
|
81 |
+
|
82 |
+
if not text:
|
83 |
+
continue
|
84 |
+
|
85 |
+
offset_ms = seg.get('tOffsetMs', 0)
|
86 |
+
|
87 |
+
new_segments.append({
|
88 |
+
'text': text,
|
89 |
+
'start': round((start_ms + offset_ms)/1000, NUM_DECIMALS)
|
90 |
+
})
|
91 |
+
|
92 |
+
total_characters += len(text)
|
93 |
+
|
94 |
+
if not new_segments:
|
95 |
+
continue
|
96 |
+
|
97 |
+
if event_index < len(events) - 1:
|
98 |
+
next_start_ms = events[event_index + 1]['tStartMs']
|
99 |
+
total_event_duration_ms = min(
|
100 |
+
event.get('dDurationMs', float('inf')), next_start_ms - start_ms)
|
101 |
+
else:
|
102 |
+
total_event_duration_ms = event.get('dDurationMs', 0)
|
103 |
+
|
104 |
+
# Ensure duration is non-negative
|
105 |
+
total_event_duration_ms = max(total_event_duration_ms, 0)
|
106 |
+
|
107 |
+
avg_seconds_per_character = (
|
108 |
+
total_event_duration_ms/total_characters)/1000
|
109 |
+
|
110 |
+
num_char_count = 0
|
111 |
+
for seg_index, seg in enumerate(new_segments):
|
112 |
+
num_char_count += len(seg['text'])
|
113 |
+
|
114 |
+
# Estimate segment end
|
115 |
+
seg_end = seg['start'] + \
|
116 |
+
(num_char_count * avg_seconds_per_character)
|
117 |
+
|
118 |
+
if seg_index < len(new_segments) - 1:
|
119 |
+
# Do not allow longer than next
|
120 |
+
seg_end = min(seg_end, new_segments[seg_index+1]['start'])
|
121 |
+
|
122 |
+
seg['end'] = round(seg_end, NUM_DECIMALS)
|
123 |
+
parsed_transcript.append(seg)
|
124 |
+
|
125 |
+
final_parsed_transcript = []
|
126 |
+
for i in range(len(parsed_transcript)):
|
127 |
+
|
128 |
+
word_level = granularity == 'word'
|
129 |
+
if word_level:
|
130 |
+
split_text = parsed_transcript[i]['text'].split()
|
131 |
+
elif granularity == 'chunk':
|
132 |
+
# Split on space after punctuation
|
133 |
+
split_text = re.split(
|
134 |
+
r'(?<=[.!?,-;])\s+', parsed_transcript[i]['text'])
|
135 |
+
if len(split_text) == 1:
|
136 |
+
split_on_whitespace = parsed_transcript[i]['text'].split()
|
137 |
+
|
138 |
+
if len(split_on_whitespace) >= 8: # Too many words
|
139 |
+
# Rather split on whitespace instead of punctuation
|
140 |
+
split_text = split_on_whitespace
|
141 |
+
else:
|
142 |
+
word_level = True
|
143 |
+
else:
|
144 |
+
raise ValueError('Unknown granularity')
|
145 |
+
|
146 |
+
segment_end = parsed_transcript[i]['end']
|
147 |
+
if i < len(parsed_transcript) - 1:
|
148 |
+
segment_end = min(segment_end, parsed_transcript[i+1]['start'])
|
149 |
+
|
150 |
+
segment_duration = segment_end - parsed_transcript[i]['start']
|
151 |
+
|
152 |
+
num_chars_in_text = sum(map(len, split_text))
|
153 |
+
|
154 |
+
num_char_count = 0
|
155 |
+
current_offset = 0
|
156 |
+
for s in split_text:
|
157 |
+
num_char_count += len(s)
|
158 |
+
|
159 |
+
next_offset = (num_char_count/num_chars_in_text) * segment_duration
|
160 |
+
|
161 |
+
word_start = round(
|
162 |
+
parsed_transcript[i]['start'] + current_offset, NUM_DECIMALS)
|
163 |
+
word_end = round(
|
164 |
+
parsed_transcript[i]['start'] + next_offset, NUM_DECIMALS)
|
165 |
+
|
166 |
+
# Make the reasonable assumption that min wps is 1.5
|
167 |
+
final_parsed_transcript.append({
|
168 |
+
'text': s,
|
169 |
+
'start': word_start,
|
170 |
+
'end': min(word_end, word_start + 1.5) if word_level else word_end
|
171 |
+
})
|
172 |
+
current_offset = next_offset
|
173 |
+
|
174 |
+
return final_parsed_transcript
|
175 |
+
|
176 |
+
|
177 |
+
def list_transcripts(video_id):
|
178 |
+
try:
|
179 |
+
return YouTubeTranscriptApi.list_transcripts(video_id)
|
180 |
+
except json.decoder.JSONDecodeError:
|
181 |
+
return None
|
182 |
+
|
183 |
+
|
184 |
+
WORDS_TO_REMOVE = [
|
185 |
+
CustomTokens.MUSIC.value,
|
186 |
+
CustomTokens.APPLAUSE.value,
|
187 |
+
CustomTokens.LAUGHTER.value
|
188 |
+
]
|
189 |
+
|
190 |
+
|
191 |
+
@lru_cache(maxsize=16)
|
192 |
+
def get_words(video_id, process=True, transcript_type='auto', fallback='manual', filter_words_to_remove=True, download=False, granularity='word'):
|
193 |
+
"""Get parsed video transcript with caching system
|
194 |
+
returns None if not processed yet and process is False
|
195 |
+
"""
|
196 |
+
# NOTE: granularity='chunk' should only be used for generating training data... nowhere else
|
197 |
+
|
198 |
+
transcript_path = os.path.join( # TODO use relative path to this
|
199 |
+
'transcripts', transcript_type, f'{video_id}.json')
|
200 |
+
|
201 |
+
raw_transcript_json = None
|
202 |
+
try:
|
203 |
+
if not download and os.path.exists(transcript_path): # Load from file
|
204 |
+
with open(transcript_path) as fp:
|
205 |
+
raw_transcript_json = json.load(fp) # May be empty
|
206 |
+
|
207 |
+
elif process:
|
208 |
+
transcript_list = list_transcripts(video_id)
|
209 |
+
|
210 |
+
if transcript_list is not None:
|
211 |
+
if transcript_type == 'manual':
|
212 |
+
ts = transcript_list.find_manually_created_transcript(
|
213 |
+
LANGUAGE_PREFERENCE_LIST)
|
214 |
+
else:
|
215 |
+
ts = transcript_list.find_generated_transcript(
|
216 |
+
LANGUAGE_PREFERENCE_LIST)
|
217 |
+
raw_transcript = ts._http_client.get(
|
218 |
+
f'{ts._url}&fmt=json3').content
|
219 |
+
if raw_transcript:
|
220 |
+
raw_transcript_json = json.loads(raw_transcript)
|
221 |
+
|
222 |
+
except (TooManyRequests, YouTubeRequestFailed):
|
223 |
+
raise # Cannot recover from these errors and do not mark as empty transcript
|
224 |
+
|
225 |
+
except requests.exceptions.RequestException: # Can recover
|
226 |
+
time.sleep(10) # Timeout
|
227 |
+
return get_words(video_id, process, transcript_type, fallback, granularity=granularity)
|
228 |
+
|
229 |
+
except CouldNotRetrieveTranscript: # Retrying won't solve
|
230 |
+
pass # Mark as empty transcript
|
231 |
+
|
232 |
+
except json.decoder.JSONDecodeError:
|
233 |
+
logger.warning(f'JSONDecodeError for {video_id}')
|
234 |
+
if os.path.exists(transcript_path):
|
235 |
+
os.remove(transcript_path) # Remove file and try again
|
236 |
+
return get_words(video_id, process, transcript_type, fallback, granularity=granularity)
|
237 |
+
|
238 |
+
# Tried to process it, but it was empty...
|
239 |
+
if download or (process and not os.path.exists(transcript_path)):
|
240 |
+
with open(transcript_path, 'w') as fp:
|
241 |
+
json.dump(raw_transcript_json, fp)
|
242 |
+
|
243 |
+
if not raw_transcript_json and fallback is not None:
|
244 |
+
return get_words(video_id, process, transcript_type=fallback, fallback=None, granularity=granularity)
|
245 |
+
|
246 |
+
if raw_transcript_json:
|
247 |
+
processed_transcript = parse_transcript_json(
|
248 |
+
raw_transcript_json, granularity)
|
249 |
+
if filter_words_to_remove:
|
250 |
+
processed_transcript = list(
|
251 |
+
filter(lambda x: x['text'] not in WORDS_TO_REMOVE, processed_transcript))
|
252 |
+
else:
|
253 |
+
processed_transcript = raw_transcript_json # Either None or []
|
254 |
+
|
255 |
+
return processed_transcript
|
256 |
+
|
257 |
+
|
258 |
+
# TODO make min_sponsor_segment_length param
|
259 |
+
# TODO rename to extract_segments
|
260 |
+
def extract_sponsors(words, min_sponsor_segment_length=3):
|
261 |
+
if not words:
|
262 |
+
return []
|
263 |
+
|
264 |
+
paragraphs = []
|
265 |
+
current = []
|
266 |
+
prev_category = None
|
267 |
+
|
268 |
+
for i in range(len(words) + 1):
|
269 |
+
unimportant = i == len(words) or words[i].get('category') is None
|
270 |
+
|
271 |
+
if unimportant or words[i].get('category') != prev_category:
|
272 |
+
if current: # Save the current batch
|
273 |
+
paragraphs.append({
|
274 |
+
'words': current,
|
275 |
+
'category': current[-1].get('category'),
|
276 |
+
})
|
277 |
+
|
278 |
+
current = []
|
279 |
+
|
280 |
+
if not unimportant: # Some useful information to save
|
281 |
+
current.append(words[i])
|
282 |
+
prev_category = words[i].get('category')
|
283 |
+
|
284 |
+
# Remove all too short:
|
285 |
+
return list(filter(lambda x: len(x['words']) >= min_sponsor_segment_length, paragraphs))
|
286 |
+
|
287 |
+
|
288 |
+
def clean_text(text):
|
289 |
+
|
290 |
+
# Replace impossibly long words with a special token
|
291 |
+
# Usually the result of incorrect labelling
|
292 |
+
text = re.sub(r'\w{64,}', CustomTokens.LONG_WORD.value, text)
|
293 |
+
|
294 |
+
SHORT_HYPHENATED_REGEX = r'\w{1,2}(?:-\w{1,2}){3,}(?:-?\w*)'
|
295 |
+
|
296 |
+
# Replace hyphenated URLs with special token
|
297 |
+
# For some reason, youtube sometimes transcribes urls in this form:
|
298 |
+
# 'b-a-b-b-e-l-dot-com', 'g-e-t-r-o-m-a-n-com'
|
299 |
+
# not 'e-commerce'
|
300 |
+
text = re.sub(f'{SHORT_HYPHENATED_REGEX}(?:com|org|net)',
|
301 |
+
CustomTokens.HYPHENATED_URL.value, text)
|
302 |
+
|
303 |
+
# Replace short+hyphenated text with a special token. Of the form:
|
304 |
+
# 'i-i-i-i-i-i-i-i-i-i-i-i', 'b-u-m-f-u-z-z-l-e', 'v-e-r-i-t-a-s-i-u-m', 'do-do-do-do-do'
|
305 |
+
text = re.sub(SHORT_HYPHENATED_REGEX,
|
306 |
+
CustomTokens.SHORT_HYPHENATED.value, text)
|
307 |
+
|
308 |
+
# Replace URLs with URL_TOKEN
|
309 |
+
URL_REGEX = r'(?:(?:http|https)\:\/\/)?[a-zA-Z0-9\.\/\?\:@\-_=#]+\.(?:[a-zA-Z]){2,6}(?:[a-zA-Z0-9\.\&\/\?\:@\-_=#%])*'
|
310 |
+
text = re.sub(URL_REGEX, CustomTokens.URL.value, text)
|
311 |
+
|
312 |
+
NUM_REGEX = r'(?:\d+,)*(?:\d*[.])?\d+'
|
313 |
+
|
314 |
+
# Encode specific numeric words
|
315 |
+
# Of the form: 12%, 12.34%
|
316 |
+
# Usually included in sponsorships
|
317 |
+
text = re.sub(f'{NUM_REGEX}%',
|
318 |
+
CustomTokens.NUMBER_PERCENTAGE.value, text)
|
319 |
+
|
320 |
+
# Normal numbers, should not have an effect on sponsorship
|
321 |
+
text = re.sub(NUM_REGEX, CustomTokens.NUMBER.value, text)
|
322 |
+
|
323 |
+
# Replace profanity with special token
|
324 |
+
text = text.replace(PROFANITY_RAW, CustomTokens.PROFANITY.value)
|
325 |
+
text = text.replace(PROFANITY_CONVERTED, CustomTokens.PROFANITY.value)
|
326 |
+
|
327 |
+
return text.strip()
|
328 |
+
|
329 |
+
|
330 |
+
def remove_duplicate_segments(segments):
|
331 |
+
# Algorithm based on SponsorBlock algorithm
|
332 |
+
# https://blog.ajay.app/voting-and-pseudo-randomness-or-sponsorblock-or-youtube-sponsorship-segment-blocker
|
333 |
+
# Find sponsors that are overlapping
|
334 |
+
|
335 |
+
best = []
|
336 |
+
for i in segments:
|
337 |
+
similar_segments = []
|
338 |
+
for j in segments:
|
339 |
+
if jaccard(i['start'], i['end'], j['start'], j['end']) > 0.1: # Some overlap
|
340 |
+
similar_segments.append(j)
|
341 |
+
|
342 |
+
if similar_segments:
|
343 |
+
best_similar_seg = max(similar_segments, key=lambda item: (
|
344 |
+
item['locked'],
|
345 |
+
item['votes'],
|
346 |
+
item['views'],
|
347 |
+
item['reputation']
|
348 |
+
))
|
349 |
+
if best_similar_seg not in best:
|
350 |
+
best.append(best_similar_seg)
|
351 |
+
|
352 |
+
if len(segments) != len(best): # Saw some reduction... try again
|
353 |
+
return remove_duplicate_segments(best)
|
354 |
+
|
355 |
+
return best
|
356 |
+
|
357 |
+
|
358 |
+
@dataclass
|
359 |
+
class PreprocessArguments:
|
360 |
+
"""
|
361 |
+
Arguments pertaining to what data we are going to preprocess.
|
362 |
+
"""
|
363 |
+
update_database: bool = field(
|
364 |
+
default=False, metadata={'help': 'Download the raw database.'}
|
365 |
+
)
|
366 |
+
|
367 |
+
do_create: bool = field(
|
368 |
+
default=False, metadata={'help': 'Merge sponsor segments into single file'}
|
369 |
+
)
|
370 |
+
|
371 |
+
min_votes: int = field(
|
372 |
+
default=0, metadata={'help': 'Minimum number of votes'})
|
373 |
+
# Downvotes will make this negative.
|
374 |
+
# 1 = At least one positive vote
|
375 |
+
|
376 |
+
max_segment_duration: float = field(
|
377 |
+
default=180, # 3 minutes
|
378 |
+
# >180 => 2.8%
|
379 |
+
# >200 => 2.1%
|
380 |
+
# >250 => 1.1%
|
381 |
+
# >300 => 0.06%
|
382 |
+
metadata={'help': 'Ignore all segments whose duration in seconds is longer than this value (negative means no limit)'})
|
383 |
+
|
384 |
+
min_views: int = field(
|
385 |
+
default=5, metadata={'help': 'Minimum number of views a segment must have to be considered. 0 = show all'})
|
386 |
+
|
387 |
+
# min_reputation: int = field(
|
388 |
+
# default=0, metadata={'help': 'Minimum reputation a user must have for the segment to be included'})
|
389 |
+
|
390 |
+
min_date: str = field(
|
391 |
+
# default='08/06/2020', # release of v2.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/2.0)
|
392 |
+
# release of v3.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/3.0)
|
393 |
+
default='20/08/2021',
|
394 |
+
# default='01/10/2020', # No more autovote
|
395 |
+
metadata={'help': 'Only use submissions from after this date (inclusive)'})
|
396 |
+
|
397 |
+
max_date: str = field(
|
398 |
+
# default='01/01/9999', # Include all
|
399 |
+
default='15/04/2022',
|
400 |
+
metadata={'help': 'Only use videos that have some segment from before this date (exclusive). This allows for videos to have segments be corrected, but ignores new videos (posted after this date) to enter the pool.'})
|
401 |
+
|
402 |
+
# max_unseen_date: str = field( # TODO
|
403 |
+
# default='02/03/2022',
|
404 |
+
# metadata={'help': 'Generate test and validation data from `max_date` to `max_unseen_date`'})
|
405 |
+
# Specify min/max video id for splitting (seen vs. unseen)
|
406 |
+
|
407 |
+
keep_duplicate_segments: bool = field(
|
408 |
+
default=False, metadata={'help': 'Keep duplicate segments'}
|
409 |
+
)
|
410 |
+
|
411 |
+
do_process_database: bool = field(
|
412 |
+
default=False, metadata={'help': 'Process the raw database'}
|
413 |
+
)
|
414 |
+
do_transcribe: bool = field(
|
415 |
+
default=False, metadata={'help': 'Get transcripts for videos'}
|
416 |
+
)
|
417 |
+
num_jobs: int = field(
|
418 |
+
default=4, metadata={'help': 'Number of transcripts to download in parallel'})
|
419 |
+
|
420 |
+
# overwrite: bool = field(
|
421 |
+
# default=False, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
|
422 |
+
# )
|
423 |
+
|
424 |
+
do_generate: bool = field(
|
425 |
+
default=False, metadata={'help': 'Generate labelled data.'}
|
426 |
+
)
|
427 |
+
|
428 |
+
do_split: bool = field(
|
429 |
+
default=False, metadata={'help': 'Generate training, testing and validation data.'}
|
430 |
+
)
|
431 |
+
|
432 |
+
positive_file: Optional[str] = field(
|
433 |
+
default='sponsor_segments.json', metadata={'help': 'File to output sponsored segments to (a jsonlines file).'}
|
434 |
+
)
|
435 |
+
negative_file: Optional[str] = field(
|
436 |
+
default='normal_segments.json', metadata={'help': 'File to output normal segments to (a jsonlines file).'}
|
437 |
+
)
|
438 |
+
|
439 |
+
percentage_positive: float = field(
|
440 |
+
default=0.5, metadata={'help': 'Ratio of positive (sponsor) segments to include in final output'})
|
441 |
+
|
442 |
+
train_split: float = field(
|
443 |
+
default=0.9, metadata={'help': 'Ratio of training data. Value between 0 and 1.'})
|
444 |
+
|
445 |
+
# TODO play around with ratios? lower test/validation split?
|
446 |
+
test_split: float = field(
|
447 |
+
default=0.05, metadata={'help': 'Ratio of testing data. Value between 0 and 1.'})
|
448 |
+
valid_split: float = field(
|
449 |
+
default=0.05, metadata={'help': 'Ratio of validation data. Value between 0 and 1.'})
|
450 |
+
|
451 |
+
start_index: int = field(default=None, metadata={
|
452 |
+
'help': 'Video to start at.'})
|
453 |
+
|
454 |
+
max_videos: int = field(default=None, metadata={
|
455 |
+
'help': 'Maximum number of videos to preprocess.'})
|
456 |
+
|
457 |
+
max_segments: int = field(default=None, metadata={
|
458 |
+
'help': 'Maximum number of segments to produce to preprocess.'})
|
459 |
+
|
460 |
+
raw_data_dir: Optional[str] = field(
|
461 |
+
default='raw',
|
462 |
+
metadata={
|
463 |
+
'help': 'Raw data directory'
|
464 |
+
},
|
465 |
+
)
|
466 |
+
raw_data_file: Optional[str] = field(
|
467 |
+
default='sponsorTimes.csv',
|
468 |
+
metadata={
|
469 |
+
'help': 'Raw data file'
|
470 |
+
},
|
471 |
+
)
|
472 |
+
|
473 |
+
min_wps: float = field(
|
474 |
+
default=1.5, metadata={'help': 'Ignore videos with not enough words spoken per second. This is usually indicative of videos whose captions aren\'t English.'})
|
475 |
+
# 0.1 ~ 1%
|
476 |
+
# 0.4 ~ 2.5%
|
477 |
+
# 0.9 ~ 5%
|
478 |
+
|
479 |
+
|
480 |
+
# Mirrors for database
|
481 |
+
MIRRORS = [
|
482 |
+
'https://sponsor.ajay.app/database/sponsorTimes.csv', # Latest
|
483 |
+
'https://sb-mirror.mchang.xyz/sponsorTimes.csv', # 5 minute delay
|
484 |
+
'https://sb.ltn.fi/database/sponsorTimes.csv', # 5 minute delay
|
485 |
+
]
|
486 |
+
# TODO only download latest updates/changes
|
487 |
+
|
488 |
+
|
489 |
+
def download_file(url, filename):
|
490 |
+
"""
|
491 |
+
Helper method handling downloading large files from `url` to `filename`.
|
492 |
+
|
493 |
+
Adapted from https://stackoverflow.com/a/42071418
|
494 |
+
"""
|
495 |
+
chunk_size = 1024
|
496 |
+
r = requests.get(url, stream=True)
|
497 |
+
total_bytes = int(r.headers['Content-Length'])
|
498 |
+
with open(filename, 'wb') as f, tqdm(unit='B', total=total_bytes) as progress:
|
499 |
+
for chunk in r.iter_content(chunk_size=chunk_size):
|
500 |
+
if chunk: # filter out keep-alive new chunks
|
501 |
+
progress.update(len(chunk))
|
502 |
+
f.write(chunk)
|
503 |
+
|
504 |
+
return total_bytes == os.path.getsize(filename)
|
505 |
+
|
506 |
+
|
507 |
+
def main():
|
508 |
+
# Responsible for getting transcripts using youtube_transcript_api,
|
509 |
+
# then labelling it according to SponsorBlock's API
|
510 |
+
logger.setLevel(logging.DEBUG)
|
511 |
+
|
512 |
+
# Generate final.json from sponsorTimes.csv
|
513 |
+
hf_parser = HfArgumentParser((
|
514 |
+
PreprocessArguments,
|
515 |
+
DatasetArguments,
|
516 |
+
segment.SegmentationArguments,
|
517 |
+
model_module.ModelArguments,
|
518 |
+
GeneralArguments
|
519 |
+
))
|
520 |
+
preprocess_args, dataset_args, segmentation_args, model_args, general_args = hf_parser.parse_args_into_dataclasses()
|
521 |
+
|
522 |
+
raw_dataset_path = os.path.join(
|
523 |
+
preprocess_args.raw_data_dir, preprocess_args.raw_data_file)
|
524 |
+
|
525 |
+
if preprocess_args.update_database:
|
526 |
+
logger.info('Updating database')
|
527 |
+
for mirror in MIRRORS:
|
528 |
+
logger.info(f'Downloading from {mirror}')
|
529 |
+
if download_file(mirror, raw_dataset_path):
|
530 |
+
break
|
531 |
+
logger.warning('Failed, trying next')
|
532 |
+
|
533 |
+
os.makedirs(dataset_args.data_dir, exist_ok=True)
|
534 |
+
processed_db_path = os.path.join(
|
535 |
+
dataset_args.data_dir, dataset_args.processed_database)
|
536 |
+
|
537 |
+
# TODO process all valid possible items and then do filtering only later
|
538 |
+
@lru_cache(maxsize=1)
|
539 |
+
def read_db():
|
540 |
+
# if not preprocess_args.overwrite and os.path.exists(processed_db_path):
|
541 |
+
# logger.info(
|
542 |
+
# 'Using cached processed database (use `--overwrite` to avoid this behaviour).')
|
543 |
+
# with open(processed_db_path) as fp:
|
544 |
+
# return json.load(fp)
|
545 |
+
logger.info('Processing raw database')
|
546 |
+
db = {}
|
547 |
+
|
548 |
+
allowed_categories = list(map(str.lower, CATGEGORY_OPTIONS))
|
549 |
+
with open(raw_dataset_path, newline='') as csvfile:
|
550 |
+
reader = csv.DictReader(csvfile)
|
551 |
+
|
552 |
+
for line in reader:
|
553 |
+
|
554 |
+
# Never show:
|
555 |
+
if line['service'] != 'YouTube':
|
556 |
+
continue
|
557 |
+
if len(line['videoID']) != 11:
|
558 |
+
continue # Invalid youtube video ID
|
559 |
+
|
560 |
+
if line['category'] not in allowed_categories:
|
561 |
+
continue
|
562 |
+
if line['actionType'] not in ACTION_OPTIONS:
|
563 |
+
continue
|
564 |
+
|
565 |
+
# Ignore hidden items
|
566 |
+
if line['hidden'] == '1' or line['shadowHidden'] == '1':
|
567 |
+
continue
|
568 |
+
|
569 |
+
# Skip those that aren't highly voted
|
570 |
+
votes = int(line['votes'])
|
571 |
+
if votes < preprocess_args.min_votes:
|
572 |
+
continue
|
573 |
+
|
574 |
+
locked = line['locked'] == '1'
|
575 |
+
|
576 |
+
reputation = float(line['reputation'])
|
577 |
+
# if reputation < preprocess_args.min_reputation:
|
578 |
+
# continue # TODO add back?
|
579 |
+
# Problems like mGVn1wCkBrE
|
580 |
+
|
581 |
+
# TODO ignore if over max_duration
|
582 |
+
|
583 |
+
if line['videoID'] not in db:
|
584 |
+
db[line['videoID']] = []
|
585 |
+
|
586 |
+
db[line['videoID']].append({
|
587 |
+
'uuid': line['UUID'],
|
588 |
+
'start': float(line['startTime']),
|
589 |
+
'end': float(line['endTime']),
|
590 |
+
'votes': votes,
|
591 |
+
'locked': locked,
|
592 |
+
'views': int(line['views']),
|
593 |
+
'submission_time': float(line['timeSubmitted'])/1e3,
|
594 |
+
'reputation': reputation,
|
595 |
+
'category': line['category'],
|
596 |
+
'action': line['actionType'],
|
597 |
+
})
|
598 |
+
|
599 |
+
# First, remove videos that contain a full-video label
|
600 |
+
# (may confuse model since disclaimers and such aren't labelled)
|
601 |
+
# Must do it here before removing duplicate segments
|
602 |
+
for key in list(db):
|
603 |
+
if any(x['action'] == 'full' for x in db[key]):
|
604 |
+
del db[key]
|
605 |
+
|
606 |
+
# Remove duplicate sponsor segments by choosing best (most votes)
|
607 |
+
if not preprocess_args.keep_duplicate_segments:
|
608 |
+
logger.info('Remove duplicate segments')
|
609 |
+
for key in db:
|
610 |
+
db[key] = remove_duplicate_segments(db[key])
|
611 |
+
|
612 |
+
# We now remove whole videos from the list
|
613 |
+
# Helps with obtaining "fully-labelled" videos
|
614 |
+
min_date = datetime.strptime(preprocess_args.min_date, '%d/%m/%Y')
|
615 |
+
max_date = datetime.strptime(preprocess_args.max_date, '%d/%m/%Y')
|
616 |
+
for key in list(db):
|
617 |
+
if preprocess_args.max_segment_duration >= 0 and any(x['end'] - x['start'] > preprocess_args.max_segment_duration for x in db[key]):
|
618 |
+
# Remove videos that have at least one segment that is longer than
|
619 |
+
# the maximum allowed segment duration. This avoids introducing
|
620 |
+
# segments into training that might contain ignored context (since
|
621 |
+
# they are too long, so the middle might be normal content)
|
622 |
+
del db[key]
|
623 |
+
elif any(datetime.fromtimestamp(x['submission_time']) < min_date for x in db[key]):
|
624 |
+
# Remove videos where any of its segments were submitted before min_date
|
625 |
+
# (essentially removes videos uploaded before min_date)
|
626 |
+
# Prevents issues where some segments of a video are excluded
|
627 |
+
del db[key]
|
628 |
+
elif all(datetime.fromtimestamp(x['submission_time']) > max_date for x in db[key]):
|
629 |
+
# Remove videos where all of its segments were submitted after max_date
|
630 |
+
# (essentially removes videos uploaded after max_date)
|
631 |
+
# Allows for segments to be corrected for past videos
|
632 |
+
del db[key]
|
633 |
+
elif any(not x['locked'] and x['views'] < preprocess_args.min_views for x in db[key]):
|
634 |
+
# Remove videos where any of its non-locked segments do not have enough views
|
635 |
+
# (essentially skips videos that have not been fully watched/reviewed)
|
636 |
+
# Always include segments locked by VIPs, regardless of view count
|
637 |
+
del db[key]
|
638 |
+
|
639 |
+
logger.info(f'Saved {len(db)} videos')
|
640 |
+
|
641 |
+
with open(processed_db_path, 'w') as fp:
|
642 |
+
json.dump(db, fp)
|
643 |
+
|
644 |
+
return db
|
645 |
+
|
646 |
+
if preprocess_args.do_process_database:
|
647 |
+
read_db()
|
648 |
+
|
649 |
+
# 'videoID', 'startTime', 'endTime', 'votes', 'locked', 'incorrectVotes', 'UUID',
|
650 |
+
# 'userID', 'timeSubmitted', 'views', 'category', 'actionType', 'service', 'videoDuration',
|
651 |
+
# 'hidden', 'reputation', 'shadowHidden', 'hashedVideoID', 'userAgent', 'description'
|
652 |
+
if preprocess_args.do_transcribe:
|
653 |
+
logger.info('Collecting videos')
|
654 |
+
parsed_database = read_db()
|
655 |
+
|
656 |
+
# Remove transcripts already processed
|
657 |
+
finished = set(x.split('.')[0] for x in os.listdir(
|
658 |
+
'transcripts/auto/') + os.listdir('transcripts/manual/'))
|
659 |
+
|
660 |
+
video_ids = list(parsed_database.keys() - finished)
|
661 |
+
|
662 |
+
# https://stackoverflow.com/a/63495323
|
663 |
+
import concurrent.futures
|
664 |
+
POLL_INTERVAL = 0.1
|
665 |
+
|
666 |
+
# Wrap get words function to return video_id after completion
|
667 |
+
def get_words_wrapper(video_id):
|
668 |
+
get_words(video_id)
|
669 |
+
return video_id
|
670 |
+
|
671 |
+
logger.info('Setting up ThreadPoolExecutor')
|
672 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=preprocess_args.num_jobs) as pool, \
|
673 |
+
tqdm(total=len(video_ids)) as progress:
|
674 |
+
|
675 |
+
all_futures = (pool.submit(get_words_wrapper, video_id)
|
676 |
+
for video_id in video_ids)
|
677 |
+
to_process = set(itertools.islice(
|
678 |
+
all_futures, preprocess_args.num_jobs))
|
679 |
+
try:
|
680 |
+
while to_process:
|
681 |
+
just_finished, to_process = concurrent.futures.wait(
|
682 |
+
to_process, timeout=POLL_INTERVAL)
|
683 |
+
to_process |= set(itertools.islice(
|
684 |
+
all_futures, len(just_finished)))
|
685 |
+
|
686 |
+
for d in just_finished:
|
687 |
+
progress.set_description(f'Processed {d.result()}')
|
688 |
+
progress.update()
|
689 |
+
|
690 |
+
except KeyboardInterrupt:
|
691 |
+
logger.info(
|
692 |
+
'Gracefully shutting down: Cancelling unscheduled tasks')
|
693 |
+
|
694 |
+
# only futures that are not done will prevent exiting
|
695 |
+
for future in to_process:
|
696 |
+
future.cancel()
|
697 |
+
|
698 |
+
logger.info('Waiting for in-progress tasks to complete')
|
699 |
+
concurrent.futures.wait(to_process, timeout=None)
|
700 |
+
logger.info('Cancellation successful')
|
701 |
+
|
702 |
+
final_path = os.path.join(
|
703 |
+
dataset_args.data_dir, dataset_args.processed_file)
|
704 |
+
|
705 |
+
if preprocess_args.do_create:
|
706 |
+
logger.info('Create final data')
|
707 |
+
|
708 |
+
final_data = {}
|
709 |
+
|
710 |
+
parsed_database = read_db()
|
711 |
+
|
712 |
+
transcribed = set(x.split('.')[0] for x in os.listdir(
|
713 |
+
'transcripts/auto/') + os.listdir('transcripts/manual/'))
|
714 |
+
|
715 |
+
# Only consider videos that have been transcribed already
|
716 |
+
video_ids = parsed_database.keys() & transcribed
|
717 |
+
|
718 |
+
with tqdm(total=len(video_ids)) as progress:
|
719 |
+
for index, video_id in enumerate(video_ids):
|
720 |
+
if preprocess_args.max_videos is not None and index >= preprocess_args.max_videos:
|
721 |
+
break
|
722 |
+
progress.set_description(f'Processing {video_id}')
|
723 |
+
progress.update()
|
724 |
+
|
725 |
+
video_words = get_words(video_id, process=False)
|
726 |
+
if not video_words:
|
727 |
+
continue
|
728 |
+
|
729 |
+
final_vid_segs = []
|
730 |
+
# Only add segments with high enough wps
|
731 |
+
for seg in parsed_database[video_id]:
|
732 |
+
segment_words = segment.extract_segment(
|
733 |
+
video_words, seg['start'], seg['end'])
|
734 |
+
|
735 |
+
if len(segment_words) <= 1:
|
736 |
+
continue # Useless to add segment since no words
|
737 |
+
|
738 |
+
# duration = segment.word_end(segment_words[-1]) - segment.word_start(segment_words[0])
|
739 |
+
duration = seg['end'] - seg['start']
|
740 |
+
wps = len(segment_words)/duration if duration > 0 else 0
|
741 |
+
|
742 |
+
# print(video_id, wps)
|
743 |
+
if wps < preprocess_args.min_wps:
|
744 |
+
# Skip sponsor segments without many words
|
745 |
+
# e.g. music ads with some words on each side
|
746 |
+
# progress.set_description(f'Skipping bad segment in {video_id} (wps={wps})')
|
747 |
+
continue
|
748 |
+
final_vid_segs.append(seg)
|
749 |
+
|
750 |
+
if final_vid_segs:
|
751 |
+
final_data[video_id] = final_vid_segs
|
752 |
+
|
753 |
+
# Save data
|
754 |
+
with open(final_path, 'w') as fp:
|
755 |
+
json.dump(final_data, fp)
|
756 |
+
|
757 |
+
# final_data = preprocess(
|
758 |
+
# raw_dataset_path, final_path, preprocess_args.min_votes)
|
759 |
+
# # TODO save metadata in final.json?
|
760 |
+
|
761 |
+
elif os.path.exists(final_path):
|
762 |
+
# Already exists
|
763 |
+
logger.info(f'{final_path} exists, opening file')
|
764 |
+
with open(final_path) as fp:
|
765 |
+
final_data = json.load(fp)
|
766 |
+
logger.info(f'Found {len(final_data)} videos')
|
767 |
+
else:
|
768 |
+
return # Do not continue
|
769 |
+
|
770 |
+
# TODO shuffle final_data
|
771 |
+
# if not os.path.exists(excess_path) or preprocess_args.overwrite
|
772 |
+
# TODO use overwrite param
|
773 |
+
|
774 |
+
positive_file = os.path.join(
|
775 |
+
dataset_args.data_dir, preprocess_args.positive_file)
|
776 |
+
negative_file = os.path.join(
|
777 |
+
dataset_args.data_dir, preprocess_args.negative_file)
|
778 |
+
|
779 |
+
if preprocess_args.do_generate:
|
780 |
+
logger.info('Generating')
|
781 |
+
# max_videos=preprocess_args.max_videos,
|
782 |
+
# max_segments=preprocess_args.max_segments,
|
783 |
+
# , max_videos, max_segments
|
784 |
+
|
785 |
+
from model import get_model_tokenizer
|
786 |
+
model, tokenizer = get_model_tokenizer(model_args, general_args)
|
787 |
+
|
788 |
+
# TODO
|
789 |
+
# count_videos = 0
|
790 |
+
# count_segments = 0
|
791 |
+
|
792 |
+
data = final_data.items()
|
793 |
+
|
794 |
+
start_index = preprocess_args.start_index or 0
|
795 |
+
end_index = (preprocess_args.max_videos or len(data)) + start_index
|
796 |
+
|
797 |
+
data = list(itertools.islice(data, start_index, end_index))
|
798 |
+
|
799 |
+
write_mode = 'w' # if preprocess_args.overwrite else 'a'
|
800 |
+
with open(positive_file, write_mode, encoding='utf-8') as positive, \
|
801 |
+
open(negative_file, write_mode, encoding='utf-8') as negative, \
|
802 |
+
tqdm(data) as progress:
|
803 |
+
|
804 |
+
for offset, (video_id, sponsor_segments) in enumerate(data):
|
805 |
+
|
806 |
+
progress.set_description(f'Processing {video_id}')
|
807 |
+
progress.update()
|
808 |
+
|
809 |
+
# Use chunk granularity to improve manual transcripts
|
810 |
+
words = get_words(video_id, process=False, granularity='chunk')
|
811 |
+
if not words:
|
812 |
+
continue
|
813 |
+
|
814 |
+
if len(words) <= 1:
|
815 |
+
continue
|
816 |
+
|
817 |
+
segments = segment.generate_labelled_segments(
|
818 |
+
words, tokenizer, segmentation_args, sponsor_segments)
|
819 |
+
|
820 |
+
if not segments:
|
821 |
+
continue
|
822 |
+
|
823 |
+
for seg in segments:
|
824 |
+
seg_start = segment.word_start(seg[0])
|
825 |
+
seg_end = segment.word_end(seg[-1])
|
826 |
+
duration = seg_end - seg_start
|
827 |
+
wps = len(seg)/duration if duration > 0 else 0
|
828 |
+
|
829 |
+
# Ignore segments with "not enough words" in the transcript
|
830 |
+
# Must do here since this includes non-sponsor segments
|
831 |
+
if wps < preprocess_args.min_wps:
|
832 |
+
continue
|
833 |
+
|
834 |
+
d = {
|
835 |
+
# 'video_index': offset + start_index,
|
836 |
+
'video_id': video_id,
|
837 |
+
# 'uuid': video_id, # TODO add uuid
|
838 |
+
'text': ' '.join(x['cleaned'] for x in seg),
|
839 |
+
'start': seg_start,
|
840 |
+
'end': seg_end,
|
841 |
+
}
|
842 |
+
|
843 |
+
extracted_segments = extract_sponsors(seg)
|
844 |
+
if extracted_segments:
|
845 |
+
extracted_texts = []
|
846 |
+
for s in extracted_segments:
|
847 |
+
w = ' '.join(q['cleaned'] for q in s['words'])
|
848 |
+
category = s['category'].upper()
|
849 |
+
extracted_texts.append(
|
850 |
+
f'{START_SEGMENT_TEMPLATE.format(category)} {w} {END_SEGMENT_TEMPLATE.format(category)}'
|
851 |
+
)
|
852 |
+
|
853 |
+
d['extracted'] = f' {CustomTokens.BETWEEN_SEGMENTS.value} '.join(
|
854 |
+
extracted_texts)
|
855 |
+
print(json.dumps(d), file=positive)
|
856 |
+
|
857 |
+
else:
|
858 |
+
d['extracted'] = CustomTokens.NO_SEGMENT.value
|
859 |
+
print(json.dumps(d), file=negative)
|
860 |
+
|
861 |
+
if preprocess_args.do_split:
|
862 |
+
logger.info('Splitting')
|
863 |
+
logger.info('Read files')
|
864 |
+
|
865 |
+
with open(positive_file, encoding='utf-8') as positive:
|
866 |
+
sponsors = positive.readlines()
|
867 |
+
|
868 |
+
with open(negative_file, encoding='utf-8') as negative:
|
869 |
+
non_sponsors = negative.readlines()
|
870 |
+
|
871 |
+
logger.info('Shuffle')
|
872 |
+
random.shuffle(sponsors)
|
873 |
+
random.shuffle(non_sponsors)
|
874 |
+
|
875 |
+
logger.info('Calculate ratios')
|
876 |
+
# Ensure correct ratio of positive to negative segments
|
877 |
+
percentage_negative = 1 - preprocess_args.percentage_positive
|
878 |
+
|
879 |
+
if preprocess_args.percentage_positive * len(sponsors) > len(non_sponsors):
|
880 |
+
# Negative is limiting
|
881 |
+
z = int(preprocess_args.percentage_positive /
|
882 |
+
percentage_negative * len(non_sponsors))
|
883 |
+
|
884 |
+
# excess = sponsors[z:]
|
885 |
+
sponsors = sponsors[:z]
|
886 |
+
|
887 |
+
else:
|
888 |
+
# Positive is limiting
|
889 |
+
z = int(percentage_negative /
|
890 |
+
preprocess_args.percentage_positive * len(sponsors))
|
891 |
+
|
892 |
+
# excess = non_sponsors[z:]
|
893 |
+
non_sponsors = non_sponsors[:z]
|
894 |
+
|
895 |
+
logger.info('Join')
|
896 |
+
all_labelled_segments = sponsors + non_sponsors
|
897 |
+
|
898 |
+
random.shuffle(all_labelled_segments)
|
899 |
+
|
900 |
+
# TODO split based on video ids
|
901 |
+
logger.info('Split')
|
902 |
+
ratios = [preprocess_args.train_split,
|
903 |
+
preprocess_args.test_split,
|
904 |
+
preprocess_args.valid_split]
|
905 |
+
|
906 |
+
train_data, test_data, valid_data = split(
|
907 |
+
all_labelled_segments, ratios)
|
908 |
+
|
909 |
+
splits = {
|
910 |
+
dataset_args.train_file: train_data,
|
911 |
+
dataset_args.test_file: test_data,
|
912 |
+
dataset_args.validation_file: valid_data
|
913 |
+
}
|
914 |
+
|
915 |
+
# Output training, testing and validation data
|
916 |
+
for name, items in splits.items():
|
917 |
+
outfile = os.path.join(dataset_args.data_dir, name)
|
918 |
+
with open(outfile, 'w', encoding='utf-8') as fp:
|
919 |
+
fp.writelines(items)
|
920 |
+
|
921 |
+
classifier_splits = {
|
922 |
+
dataset_args.c_train_file: train_data,
|
923 |
+
dataset_args.c_test_file: test_data,
|
924 |
+
dataset_args.c_validation_file: valid_data
|
925 |
+
}
|
926 |
+
|
927 |
+
none_category = CATEGORIES.index(None)
|
928 |
+
|
929 |
+
# Output training, testing and validation data
|
930 |
+
for name, items in classifier_splits.items():
|
931 |
+
outfile = os.path.join(dataset_args.data_dir, name)
|
932 |
+
with open(outfile, 'w', encoding='utf-8') as fp:
|
933 |
+
for item in items:
|
934 |
+
parsed_item = json.loads(item) # TODO add uuid
|
935 |
+
|
936 |
+
matches = extract_sponsor_matches_from_text(
|
937 |
+
parsed_item['extracted'])
|
938 |
+
|
939 |
+
if matches:
|
940 |
+
for match in matches:
|
941 |
+
print(json.dumps({
|
942 |
+
'text': match['text'],
|
943 |
+
'label': CATEGORIES.index(match['category'])
|
944 |
+
}), file=fp)
|
945 |
+
else:
|
946 |
+
print(json.dumps({
|
947 |
+
'text': parsed_item['text'],
|
948 |
+
'label': none_category
|
949 |
+
}), file=fp)
|
950 |
+
|
951 |
+
logger.info('Write')
|
952 |
+
# Save excess items
|
953 |
+
# excess_path = os.path.join(
|
954 |
+
# dataset_args.data_dir, dataset_args.excess_file)
|
955 |
+
# if not os.path.exists(excess_path) or preprocess_args.overwrite:
|
956 |
+
# with open(excess_path, 'w', encoding='utf-8') as fp:
|
957 |
+
# fp.writelines(excess)
|
958 |
+
# else:
|
959 |
+
# logger.info(f'Skipping {dataset_args.excess_file}')
|
960 |
+
|
961 |
+
logger.info(
|
962 |
+
f'Finished splitting: {len(sponsors)} sponsors, {len(non_sponsors)} non sponsors')
|
963 |
+
|
964 |
+
|
965 |
+
def split(arr, ratios):
|
966 |
+
"""Split array according to ratios. Sum of ratios should be <= 1"""
|
967 |
+
to_return = []
|
968 |
+
|
969 |
+
cumulative_sum = 0
|
970 |
+
for r in ratios:
|
971 |
+
current = cumulative_sum
|
972 |
+
cumulative_sum += r * len(arr)
|
973 |
+
to_return.append(arr[int(current):int(cumulative_sum)])
|
974 |
+
|
975 |
+
return to_return
|
976 |
+
|
977 |
+
|
978 |
+
if __name__ == '__main__':
|
979 |
+
main()
|
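
One detail of src/preprocess.py above that is easy to miss: YouTube's transcript JSON only records when each caption event appears on screen, so parse_transcript_json estimates per-word timings by splitting an event's duration in proportion to character counts and capping each word at 1.5 seconds. A minimal sketch of that estimate under simplified assumptions (the helper name and signature below are illustrative, not the module's API):

    # Sketch of the character-proportional timing estimate used by
    # parse_transcript_json; the function name and arguments are simplified.
    def estimate_word_times(words, event_start, event_duration, max_word_duration=1.5):
        total_chars = sum(len(w) for w in words)
        timed, chars_so_far, offset = [], 0, 0.0
        for w in words:
            chars_so_far += len(w)
            next_offset = (chars_so_far / total_chars) * event_duration
            start = round(event_start + offset, 3)
            end = round(event_start + next_offset, 3)
            # Cap each word's duration, as the parser above does (1.5 s max)
            timed.append({'text': w, 'start': start, 'end': min(end, start + max_word_duration)})
            offset = next_offset
        return timed

    # A 4-second caption event starting at t=10 s, split across its 18 characters:
    print(estimate_word_times(['check', 'out', 'our', 'sponsor'], 10.0, 4.0))
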
src/segment.py
ADDED
@@ -0,0 +1,166 @@
1 |
+
import preprocess
|
2 |
+
from dataclasses import dataclass, field
|
3 |
+
|
4 |
+
|
5 |
+
@dataclass
|
6 |
+
class SegmentationArguments:
|
7 |
+
pause_threshold: float = field(default=2.5, metadata={
|
8 |
+
'help': 'When the time between words is greater than pause threshold, force into a new segment'})
|
9 |
+
|
10 |
+
|
11 |
+
def get_overlapping_chunks_of_tokens(tokens, size, overlap):
|
12 |
+
for i in range(0, len(tokens), size-overlap+1):
|
13 |
+
yield tokens[i:i+size]
|
14 |
+
|
15 |
+
|
16 |
+
# Generate up to SAFETY_TOKENS_PERCENTAGE*max_tokens tokens
|
17 |
+
MIN_SAFETY_TOKENS = 8
|
18 |
+
SAFETY_TOKENS_PERCENTAGE = 0.9765625
|
19 |
+
# e.g. 512 -> 500, 768 -> 750
|
20 |
+
|
21 |
+
|
22 |
+
# TODO play around with this?
|
23 |
+
OVERLAP_TOKEN_PERCENTAGE = 0.5 # 0.25
|
24 |
+
|
25 |
+
|
26 |
+
def add_labels_to_words(words, sponsor_segments):
|
27 |
+
|
28 |
+
for sponsor_segment in sponsor_segments:
|
29 |
+
for w in extract_segment(words, sponsor_segment['start'], sponsor_segment['end']):
|
30 |
+
w['category'] = sponsor_segment['category']
|
31 |
+
|
32 |
+
return words
|
33 |
+
|
34 |
+
|
35 |
+
def generate_labelled_segments(words, tokenizer, segmentation_args, sponsor_segments):
|
36 |
+
segments = generate_segments(words, tokenizer, segmentation_args)
|
37 |
+
|
38 |
+
labelled_segments = list(
|
39 |
+
map(lambda x: add_labels_to_words(x, sponsor_segments), segments))
|
40 |
+
|
41 |
+
return labelled_segments
|
42 |
+
|
43 |
+
|
44 |
+
def word_start(word):
|
45 |
+
return word['start']
|
46 |
+
|
47 |
+
|
48 |
+
def word_end(word):
|
49 |
+
return word.get('end', word['start'])
|
50 |
+
|
51 |
+
|
52 |
+
def generate_segments(words, tokenizer, segmentation_args):
|
53 |
+
|
54 |
+
cleaned_words_list = []
|
55 |
+
for w in words:
|
56 |
+
w['cleaned'] = preprocess.clean_text(w['text'])
|
57 |
+
cleaned_words_list.append(w['cleaned'])
|
58 |
+
|
59 |
+
# Get lengths of tokenized words
|
60 |
+
num_tokens_list = tokenizer(cleaned_words_list, add_special_tokens=False,
|
61 |
+
truncation=True, return_attention_mask=False, return_length=True).length
|
62 |
+
|
63 |
+
first_pass_segments = []
|
64 |
+
for index, (word, num_tokens) in enumerate(zip(words, num_tokens_list)):
|
65 |
+
word['num_tokens'] = num_tokens
|
66 |
+
|
67 |
+
# Add new segment
|
68 |
+
if index == 0 or word_start(words[index]) - word_end(words[index-1]) >= segmentation_args.pause_threshold:
|
69 |
+
first_pass_segments.append([word])
|
70 |
+
|
71 |
+
else: # Add to current segment
|
72 |
+
first_pass_segments[-1].append(word)
|
73 |
+
|
74 |
+
max_q_size = round(SAFETY_TOKENS_PERCENTAGE * tokenizer.model_max_length)
|
75 |
+
|
76 |
+
buffer_size = OVERLAP_TOKEN_PERCENTAGE*max_q_size # tokenizer.model_max_length
|
77 |
+
|
78 |
+
# In second pass, we split those segments if too big
|
79 |
+
second_pass_segments = []
|
80 |
+
|
81 |
+
for segment in first_pass_segments:
|
82 |
+
current_segment_num_tokens = 0
|
83 |
+
current_segment = []
|
84 |
+
after_split_segments = []
|
85 |
+
for word in segment:
|
86 |
+
new_seg = current_segment_num_tokens + \
|
87 |
+
word['num_tokens'] >= max_q_size
|
88 |
+
if new_seg:
|
89 |
+
# Adding this token would make it have too many tokens
|
90 |
+
# We save this batch and create new
|
91 |
+
after_split_segments.append(current_segment)
|
92 |
+
|
93 |
+
# Add tokens to current segment
|
94 |
+
current_segment.append(word)
|
95 |
+
current_segment_num_tokens += word['num_tokens']
|
96 |
+
|
97 |
+
if not new_seg:
|
98 |
+
continue
|
99 |
+
|
100 |
+
# Just created a new segment, so we remove until we only have buffer_size tokens
|
101 |
+
last_index = 0
|
102 |
+
while current_segment_num_tokens > buffer_size and current_segment:
|
103 |
+
current_segment_num_tokens -= current_segment[last_index]['num_tokens']
|
104 |
+
last_index += 1
|
105 |
+
|
106 |
+
current_segment = current_segment[last_index:]
|
107 |
+
|
108 |
+
if current_segment: # Add remaining segment
|
109 |
+
after_split_segments.append(current_segment)
|
110 |
+
|
111 |
+
# TODO if len(after_split_segments) > 1, a split occurred
|
112 |
+
|
113 |
+
second_pass_segments.extend(after_split_segments)
|
114 |
+
|
115 |
+
# Cleaning up, delete 'num_tokens' from each word
|
116 |
+
for word in words:
|
117 |
+
word.pop('num_tokens', None)
|
118 |
+
|
119 |
+
return second_pass_segments
|
120 |
+
|
121 |
+
|
122 |
+
def extract_segment(words, start, end, map_function=None):
|
123 |
+
"""Extracts all words with time in [start, end]"""
|
124 |
+
if words is None:
|
125 |
+
words = []
|
126 |
+
|
127 |
+
a = max(binary_search_below(words, 0, len(words), start), 0)
|
128 |
+
b = min(binary_search_above(words, -1, len(words) - 1, end) + 1, len(words))
|
129 |
+
|
130 |
+
to_transform = map_function is not None and callable(map_function)
|
131 |
+
|
132 |
+
return [
|
133 |
+
map_function(words[i]) if to_transform else words[i] for i in range(a, b)
|
134 |
+
]
|
135 |
+
|
136 |
+
|
137 |
+
def avg(*items):
|
138 |
+
return sum(items)/len(items)
|
139 |
+
|
140 |
+
|
141 |
+
def binary_search_below(transcript, start_index, end_index, time):
|
142 |
+
if start_index >= end_index:
|
143 |
+
return end_index
|
144 |
+
|
145 |
+
middle_index = (start_index + end_index) // 2
|
146 |
+
middle = transcript[middle_index]
|
147 |
+
middle_time = avg(word_start(middle), word_end(middle))
|
148 |
+
|
149 |
+
if time <= middle_time:
|
150 |
+
return binary_search_below(transcript, start_index, middle_index, time)
|
151 |
+
else:
|
152 |
+
return binary_search_below(transcript, middle_index + 1, end_index, time)
|
153 |
+
|
154 |
+
|
155 |
+
def binary_search_above(transcript, start_index, end_index, time):
|
156 |
+
if start_index >= end_index:
|
157 |
+
return end_index
|
158 |
+
|
159 |
+
middle_index = (start_index + end_index + 1) // 2
|
160 |
+
middle = transcript[middle_index]
|
161 |
+
middle_time = avg(word_start(middle), word_end(middle))
|
162 |
+
|
163 |
+
if time >= middle_time:
|
164 |
+
return binary_search_above(transcript, middle_index, end_index, time)
|
165 |
+
else:
|
166 |
+
return binary_search_above(transcript, start_index, middle_index - 1, time)
|
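
src/segment.py above segments a transcript in two passes: words are first grouped wherever the gap between them exceeds pause_threshold, and any group that would exceed the tokenizer's budget is then cut into overlapping windows, carrying roughly OVERLAP_TOKEN_PERCENTAGE of the previous window's tokens forward as context. A minimal sketch of that windowing idea, with made-up token counts and thresholds and a simplified data layout:

    # Sketch of the overlapping-window pass in generate_segments above;
    # the numbers and the (text, num_tokens) layout are illustrative assumptions.
    def overlapping_windows(words, max_tokens=10, overlap_ratio=0.5):
        """Cut a run of words into windows of at most ~max_tokens tokens,
        keeping roughly overlap_ratio of each window as context for the next."""
        buffer = overlap_ratio * max_tokens
        windows, current, current_tokens = [], [], 0
        for text, n in words:
            if current_tokens + n >= max_tokens and current:
                windows.append(list(current))
                # Drop leading words until only ~buffer tokens of context remain
                while current and current_tokens > buffer:
                    current_tokens -= current[0][1]
                    current.pop(0)
            current.append((text, n))
            current_tokens += n
        if current:
            windows.append(current)
        return windows

    sample = [('welcome', 3), ('back', 2), ('to', 1), ('the', 1), ('channel', 3),
              ('this', 2), ('video', 3), ('is', 1), ('sponsored', 4), ('by', 1)]
    for window in overlapping_windows(sample):
        print([w for w, _ in window])
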
src/shared.py
ADDED
@@ -0,0 +1,406 @@
1 |
+
from transformers.trainer_utils import get_last_checkpoint as glc
|
2 |
+
import os
|
3 |
+
from utils import re_findall
|
4 |
+
import logging
|
5 |
+
import sys
|
6 |
+
from datasets import load_dataset
|
7 |
+
import re
|
8 |
+
import gc
|
9 |
+
from time import time_ns
|
10 |
+
import random
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
from typing import Optional
|
14 |
+
from dataclasses import dataclass, field
|
15 |
+
from enum import Enum
|
16 |
+
|
17 |
+
|
18 |
+
logging.basicConfig()
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
# Setup logging
|
22 |
+
logging.basicConfig(
|
23 |
+
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
24 |
+
datefmt='%m/%d/%Y %H:%M:%S',
|
25 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
26 |
+
)
|
27 |
+
|
28 |
+
CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']
|
29 |
+
|
30 |
+
ACTION_OPTIONS = ['skip', 'mute', 'full']
|
31 |
+
|
32 |
+
CATGEGORY_OPTIONS = {
|
33 |
+
'SPONSOR': 'Sponsor',
|
34 |
+
'SELFPROMO': 'Self/unpaid promo',
|
35 |
+
'INTERACTION': 'Interaction reminder',
|
36 |
+
}
|
37 |
+
|
38 |
+
START_SEGMENT_TEMPLATE = 'START_{}_TOKEN'
|
39 |
+
END_SEGMENT_TEMPLATE = 'END_{}_TOKEN'
|
40 |
+
|
41 |
+
|
42 |
+
class CustomTokens(Enum):
|
43 |
+
EXTRACT_SEGMENTS_PREFIX = 'EXTRACT_SEGMENTS: '
|
44 |
+
|
45 |
+
# Preprocessing tokens
|
46 |
+
URL = 'URL_TOKEN'
|
47 |
+
    HYPHENATED_URL = 'HYPHENATED_URL_TOKEN'
    NUMBER_PERCENTAGE = 'NUMBER_PERCENTAGE_TOKEN'
    NUMBER = 'NUMBER_TOKEN'

    SHORT_HYPHENATED = 'SHORT_HYPHENATED_TOKEN'
    LONG_WORD = 'LONG_WORD_TOKEN'

    # Custom YouTube tokens
    MUSIC = '[Music]'
    APPLAUSE = '[Applause]'
    LAUGHTER = '[Laughter]'

    PROFANITY = 'PROFANITY_TOKEN'

    # Segment tokens
    NO_SEGMENT = 'NO_SEGMENT_TOKEN'

    START_SPONSOR = START_SEGMENT_TEMPLATE.format('SPONSOR')
    END_SPONSOR = END_SEGMENT_TEMPLATE.format('SPONSOR')

    START_SELFPROMO = START_SEGMENT_TEMPLATE.format('SELFPROMO')
    END_SELFPROMO = END_SEGMENT_TEMPLATE.format('SELFPROMO')

    START_INTERACTION = START_SEGMENT_TEMPLATE.format('INTERACTION')
    END_INTERACTION = END_SEGMENT_TEMPLATE.format('INTERACTION')

    BETWEEN_SEGMENTS = 'BETWEEN_SEGMENTS_TOKEN'

    @classmethod
    def custom_tokens(cls):
        return [e.value for e in cls]

    @classmethod
    def add_custom_tokens(cls, tokenizer):
        tokenizer.add_tokens(cls.custom_tokens())


_SEGMENT_START = START_SEGMENT_TEMPLATE.format(r'(?P<category>\w+)')
_SEGMENT_END = END_SEGMENT_TEMPLATE.format(r'\w+')
SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'


def extract_sponsor_matches_from_text(text):
    if CustomTokens.NO_SEGMENT.value in text:
        return []
    else:
        return re_findall(SEGMENT_MATCH_RE, text)


def extract_sponsor_matches(texts):
    return list(map(extract_sponsor_matches_from_text, texts))


@dataclass
class DatasetArguments:
    data_dir: Optional[str] = field(
        default='data',
        metadata={
            'help': 'The directory which stores train, test and/or validation data.'
        },
    )
    processed_file: Optional[str] = field(
        default='segments.json',
        metadata={
            'help': 'Processed data file'
        },
    )
    processed_database: Optional[str] = field(
        default='processed_database.json',
        metadata={
            'help': 'Processed database file'
        },
    )

    overwrite_cache: bool = field(
        default=False, metadata={'help': 'Overwrite the cached training and evaluation sets'}
    )

    dataset_cache_dir: Optional[str] = field(
        default=None,
        metadata={
            'help': 'Where to store the cached datasets'
        },
    )

    train_file: Optional[str] = field(
        default='train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
    )
    validation_file: Optional[str] = field(
        default='valid.json',
        metadata={
            'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
        },
    )
    test_file: Optional[str] = field(
        default='test.json',
        metadata={
            'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
        },
    )

    c_train_file: Optional[str] = field(
        default='c_train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
    )
    c_validation_file: Optional[str] = field(
        default='c_valid.json',
        metadata={
            'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
        },
    )
    c_test_file: Optional[str] = field(
        default='c_test.json',
        metadata={
            'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
        },
    )

    def __post_init__(self):
        if self.train_file is None or self.validation_file is None:
            raise ValueError(
                'Need either a dataset name or a training/validation file.')

        else:
            train_extension = self.train_file.split(".")[-1]
            assert train_extension in [
                "csv", "json"], "`train_file` should be a csv or a json file."
            validation_extension = self.validation_file.split(".")[-1]
            assert (
                validation_extension == train_extension
            ), "`validation_file` should have the same extension (csv or json) as `train_file`."


@dataclass
class OutputArguments:

    output_dir: str = field(
        default='out',
        metadata={
            'help': 'The output directory where the model predictions and checkpoints will be written to and read from.'
        },
    )
    checkpoint: Optional[str] = field(
        default=None,
        metadata={
            'help': 'Choose the checkpoint/model to train from or test with. Defaults to the latest checkpoint found in `output_dir`.'
        },
    )
    models_dir: str = field(
        default='models',
        metadata={
            'help': 'The directory where pretrained and finetuned models are stored and read from.'
        },
    )
    # classifier_dir: str = field(
    #     default='out',
    #     metadata={
    #         'help': 'The output directory where the model predictions and checkpoints will be written to and read from.'
    #     },
    # )


def seed_factory():
    return time_ns() % (2**32 - 1)


@dataclass
class GeneralArguments:
    seed: Optional[int] = field(default_factory=seed_factory, metadata={
        'help': 'Set seed for deterministic training and testing. By default, it uses the current time (results in essentially random results).'
    })
    no_cuda: bool = field(default=False, metadata={
        'help': 'Do not use CUDA even when it is available'})

    def __post_init__(self):
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)


def seconds_to_time(seconds, remove_leading_zeroes=False):
    fractional = round(seconds % 1, 3)
    fractional = '' if fractional == 0 else str(fractional)[1:]
    h, remainder = divmod(abs(int(seconds)), 3600)
    m, s = divmod(remainder, 60)
    hms = f'{h:02}:{m:02}:{s:02}'
    if remove_leading_zeroes:
        hms = re.sub(r'^0(?:0:0?)?', '', hms)
    return f"{'-' if seconds < 0 else ''}{hms}{fractional}"


def reset():
    torch.clear_autocast_cache()
    torch.cuda.empty_cache()
    gc.collect()
    print(torch.cuda.memory_summary(device=None, abbreviated=False))


def load_datasets(dataset_args: DatasetArguments):

    logger.info('Reading datasets')
    data_files = {}

    if dataset_args.train_file is not None:
        data_files['train'] = os.path.join(
            dataset_args.data_dir, dataset_args.train_file)
    if dataset_args.validation_file is not None:
        data_files['validation'] = os.path.join(
            dataset_args.data_dir, dataset_args.validation_file)
    if dataset_args.test_file is not None:
        data_files['test'] = os.path.join(
            dataset_args.data_dir, dataset_args.test_file)

    return load_dataset('json', data_files=data_files, cache_dir=dataset_args.dataset_cache_dir)


@dataclass
class AdditionalTrainingArguments:
    seed: Optional[int] = GeneralArguments.__dataclass_fields__['seed']

    num_train_epochs: float = field(
        default=1, metadata={'help': 'Total number of training epochs to perform.'})

    save_steps: int = field(default=5000, metadata={
        'help': 'Save checkpoint every X updates steps.'})
    eval_steps: int = field(default=25000, metadata={
        'help': 'Run an evaluation every X steps.'})
    logging_steps: int = field(default=5000, metadata={
        'help': 'Log every X updates steps.'})

    # do_eval: bool = field(default=False, metadata={
    #     'help': 'Whether to run eval on the dev set.'})
    # do_predict: bool = field(default=False, metadata={
    #     'help': 'Whether to run predictions on the test set.'})

    per_device_train_batch_size: int = field(
        default=4, metadata={'help': 'Batch size per GPU/TPU core/CPU for training.'}
    )
    per_device_eval_batch_size: int = field(
        default=4, metadata={'help': 'Batch size per GPU/TPU core/CPU for evaluation.'}
    )

    # report_to: Optional[List[str]] = field(
    #     default=None, metadata={"help": "The list of integrations to report the results and logs to."}
    # )
    evaluation_strategy: str = field(
        default='steps',
        metadata={
            'help': 'The evaluation strategy to use.',
            'choices': ['no', 'steps', 'epoch']
        },
    )

    # evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
    #     The evaluation strategy to adopt during training. Possible values are:

    #     * :obj:`"no"`: No evaluation is done during training.
    #     * :obj:`"steps"`: Evaluation is done (and logged) every :obj:`eval_steps`.
    #     * :obj:`"epoch"`: Evaluation is done at the end of each epoch.

    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={'help': 'The number of processes to use for the preprocessing.'},
    )
    max_seq_length: int = field(
        default=512,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )


@dataclass
class CustomTrainingArguments(OutputArguments, AdditionalTrainingArguments):
    pass


def get_last_checkpoint(training_args):
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
        last_checkpoint = glc(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f'Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome.'
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
            )
    return last_checkpoint


def train_from_checkpoint(trainer, last_checkpoint, training_args):
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    train_result = trainer.train(resume_from_checkpoint=checkpoint)

    trainer.save_model()  # Saves the tokenizer too for easy upload

    return train_result


def prepare_datasets(raw_datasets, dataset_args: DatasetArguments, training_args: CustomTrainingArguments, preprocess_function):

    with training_args.main_process_first(desc="dataset map pre-processing"):
        raw_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not dataset_args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )

    if 'train' not in raw_datasets:
        raise ValueError('Train dataset missing')
    train_dataset = raw_datasets['train']
    if training_args.max_train_samples is not None:
        train_dataset = train_dataset.select(
            range(training_args.max_train_samples))

    if 'validation' not in raw_datasets:
        raise ValueError('Validation dataset missing')
    eval_dataset = raw_datasets['validation']
    if training_args.max_eval_samples is not None:
        eval_dataset = eval_dataset.select(
            range(training_args.max_eval_samples))

    if 'test' not in raw_datasets:
        raise ValueError('Test dataset missing')
    predict_dataset = raw_datasets['test']
    if training_args.max_predict_samples is not None:
        predict_dataset = predict_dataset.select(
            range(training_args.max_predict_samples))

    return train_dataset, eval_dataset, predict_dataset
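For reference, a short standalone sketch (not part of the uploaded files) of how SEGMENT_MATCH_RE extracts labelled spans from a model output string. The marker strings below are an assumption for illustration only; in shared.py they come from START_SEGMENT_TEMPLATE and END_SEGMENT_TEMPLATE, which are defined earlier in that file:

# Illustrative only; the START_/END_ marker format is assumed, not taken from the repo.
import re

START_SEGMENT_TEMPLATE = 'START_{}_TOKEN'  # assumption
END_SEGMENT_TEMPLATE = 'END_{}_TOKEN'      # assumption

_SEGMENT_START = START_SEGMENT_TEMPLATE.format(r'(?P<category>\w+)')
_SEGMENT_END = END_SEGMENT_TEMPLATE.format(r'\w+')
SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'

output = 'intro START_SPONSOR_TOKEN this video is sponsored by example END_SPONSOR_TOKEN outro'
print([m.groupdict() for m in re.finditer(SEGMENT_MATCH_RE, output)])
# -> [{'category': 'SPONSOR', 'text': 'this video is sponsored by example'}]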
src/train.py
ADDED
@@ -0,0 +1,191 @@
from shared import (
    CustomTokens,
    DatasetArguments,
    prepare_datasets,
    load_datasets,
    CustomTrainingArguments,
    get_last_checkpoint,
    train_from_checkpoint
)
from model import ModelArguments
import transformers
import logging
import os
import sys
from datasets import utils as d_utils
from transformers import (
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from dataclasses import dataclass

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version('4.17.0')
require_version('datasets>=1.8.0',
                'To fix: pip install -r requirements.txt')

os.environ['WANDB_DISABLED'] = 'true'

logging.basicConfig()
logger = logging.getLogger(__name__)

# Setup logging
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    handlers=[logging.StreamHandler(sys.stdout)],
)


@dataclass
class Seq2SeqTrainingArguments(CustomTrainingArguments, Seq2SeqTrainingArguments):
    pass


def main():

    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    hf_parser = HfArgumentParser((
        ModelArguments,
        DatasetArguments,
        Seq2SeqTrainingArguments
    ))
    model_args, dataset_args, training_args = hf_parser.parse_args_into_dataclasses()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    d_utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Set seed before initializing model.
    # set_seed(training_args.seed)

    # Log on each process the small summary:
    logger.warning(
        f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, '
        + f'distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
    )
    logger.info(f'Training/evaluation parameters {training_args}')

    # FP16 https://github.com/huggingface/transformers/issues/9295

    # Works:
    # https://huggingface.co/docs/transformers/model_doc/t5v1.1
    # google/t5-v1_1-small
    # google/t5-v1_1-base
    # google/t5-v1_1-large
    # google/t5-v1_1-xl
    # google/t5-v1_1-xxl

    # https://huggingface.co/docs/transformers/model_doc/t5
    # t5-small
    # t5-base
    # t5-large
    # t5-3b
    # t5-11b

    # allenai/led-base-16384 - https://github.com/huggingface/transformers/issues/9810

    # Further work:
    # Multilingual- https://huggingface.co/docs/transformers/model_doc/mt5

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    raw_datasets = load_datasets(dataset_args)
    # , cache_dir=model_args.cache_dir

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Detecting last checkpoint.
    last_checkpoint = get_last_checkpoint(training_args)

    from model import get_model_tokenizer
    model, tokenizer = get_model_tokenizer(model_args, training_args)

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.

    prefix = CustomTokens.EXTRACT_SEGMENTS_PREFIX.value

    PAD_TOKEN_REPLACE_ID = -100

    # https://github.com/huggingface/transformers/issues/5204
    def preprocess_function(examples):
        inputs = examples['text']
        targets = examples['extracted']
        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs, truncation=True)

        # Setup the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, truncation=True)

        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100
        # when we want to ignore padding in the loss.

        model_inputs['labels'] = [
            [(l if l != tokenizer.pad_token_id else PAD_TOKEN_REPLACE_ID)
             for l in label]
            for label in labels['input_ids']
        ]

        return model_inputs

    train_dataset, eval_dataset, predict_dataset = prepare_datasets(
        raw_datasets, dataset_args, training_args, preprocess_function)

    # Data collator
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=PAD_TOKEN_REPLACE_ID,
        pad_to_multiple_of=8 if training_args.fp16 else None,
    )

    # Done processing datasets

    # Initialize our Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    train_result = train_from_checkpoint(
        trainer, last_checkpoint, training_args)

    metrics = train_result.metrics
    max_train_samples = training_args.max_train_samples or len(
        train_dataset)
    metrics['train_samples'] = min(max_train_samples, len(train_dataset))

    trainer.log_metrics('train', metrics)
    trainer.save_metrics('train', metrics)
    trainer.save_state()

    kwargs = {'finetuned_from': model_args.model_name_or_path,
              'tasks': 'summarization'}

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


if __name__ == '__main__':
    main()
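A minimal, self-contained sketch (not part of the repository) of why preprocess_function above rewrites padded label ids to PAD_TOKEN_REPLACE_ID = -100, and why DataCollatorForSeq2Seq is given label_pad_token_id=-100: PyTorch's cross-entropy loss skips target positions equal to ignore_index, so label padding does not contribute to the seq2seq loss. The pad id and vocabulary size below are illustrative values, not taken from the repo:

# Illustrative values only; pad_token_id and vocab size are assumptions for the sketch.
import torch
import torch.nn.functional as F

pad_token_id = 0
labels = torch.tensor([[100, 200, 300, pad_token_id, pad_token_id]])
labels = torch.where(labels == pad_token_id, torch.tensor(-100), labels)

logits = torch.randn(1, 5, 32100)  # (batch, sequence length, vocab size)
loss = F.cross_entropy(logits.transpose(1, 2), labels, ignore_index=-100)
print(labels.tolist())  # [[100, 200, 300, -100, -100]] -> the last two positions are ignored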
src/train_classifier.py
ADDED
@@ -0,0 +1,182 @@
""" Finetuning the library models for sequence classification."""

import logging
import os
import sys
from dataclasses import dataclass
from typing import Optional

import datasets
import numpy as np

import transformers
from transformers import (
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from shared import (
    CATEGORIES,
    DatasetArguments,
    prepare_datasets,
    load_datasets,
    CustomTrainingArguments,
    train_from_checkpoint,
    get_last_checkpoint
)
from model import get_model_tokenizer, ModelArguments

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version('4.17.0')
require_version('datasets>=1.8.0', 'To fix: pip install -r requirements.txt')

os.environ['WANDB_DISABLED'] = 'true'

logger = logging.getLogger(__name__)


@dataclass
class ClassifierTrainingArguments(CustomTrainingArguments, TrainingArguments):
    pass


@dataclass
class ClassifierDatasetArguments(DatasetArguments):
    def __post_init__(self):
        self.train_file = self.c_train_file
        self.validation_file = self.c_validation_file
        self.test_file = self.c_test_file


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    hf_parser = HfArgumentParser((
        ModelArguments,
        ClassifierDatasetArguments,
        ClassifierTrainingArguments
    ))
    model_args, dataset_args, training_args = hf_parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, '
        + f'distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
    )
    logger.info(f'Training/evaluation parameters {training_args}')

    # Detecting last checkpoint.
    last_checkpoint = get_last_checkpoint(training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Loading a dataset from your local files.
    # CSV/JSON training and evaluation files are needed.
    raw_datasets = load_datasets(dataset_args)

    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    config_args = {
        'num_labels': len(CATEGORIES),
        'id2label': {k: str(v).upper() for k, v in enumerate(CATEGORIES)},
        'label2id': {str(v).upper(): k for k, v in enumerate(CATEGORIES)}
    }
    model, tokenizer = get_model_tokenizer(
        model_args, training_args, config_args=config_args, model_type='classifier')

    if training_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f'The max_seq_length passed ({training_args.max_seq_length}) is larger than the maximum length for the '
            f'model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}.'
        )
    max_seq_length = min(training_args.max_seq_length,
                         tokenizer.model_max_length)

    def preprocess_function(examples):
        # Tokenize the texts
        result = tokenizer(
            examples['text'], padding='max_length', max_length=max_seq_length, truncation=True)
        result['label'] = examples['label']
        return result

    train_dataset, eval_dataset, predict_dataset = prepare_datasets(
        raw_datasets, dataset_args, training_args, preprocess_function)

    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p: EvalPrediction):
        preds = p.predictions[0] if isinstance(
            p.predictions, tuple) else p.predictions
        preds = np.argmax(preds, axis=1)
        return {'accuracy': (preds == p.label_ids).astype(np.float32).mean().item()}

    # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if
    # we already did the padding.
    if training_args.fp16:
        data_collator = DataCollatorWithPadding(
            tokenizer, pad_to_multiple_of=8)
    else:
        data_collator = None

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    train_result = train_from_checkpoint(
        trainer, last_checkpoint, training_args)

    metrics = train_result.metrics
    max_train_samples = (
        training_args.max_train_samples if training_args.max_train_samples is not None else len(
            train_dataset)
    )
    metrics['train_samples'] = min(max_train_samples, len(train_dataset))

    trainer.save_model()  # Saves the tokenizer too for easy upload

    trainer.log_metrics('train', metrics)
    trainer.save_metrics('train', metrics)
    trainer.save_state()

    kwargs = {'finetuned_from': model_args.model_name_or_path,
              'tasks': 'text-classification'}
    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


if __name__ == '__main__':
    main()
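A small worked example (not part of the repository) of what compute_metrics above returns: the per-example argmax of the logits is compared against the integer class ids, giving plain accuracy:

import numpy as np

predictions = np.array([[2.0, 0.1, -1.0],   # argmax -> class 0
                        [0.2, 1.5, 0.3],    # argmax -> class 1
                        [0.1, 0.2, 0.3]])   # argmax -> class 2
label_ids = np.array([0, 1, 0])

preds = np.argmax(predictions, axis=1)
print({'accuracy': (preds == label_ids).astype(np.float32).mean().item()})
# -> {'accuracy': 0.666...}  (2 of 3 correct)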
src/utils.py
ADDED
@@ -0,0 +1,135 @@
import re
import sys
import locale
import io


def re_findall(pattern, string):
    return [m.groupdict() for m in re.finditer(pattern, string)]


def jaccard(x1, x2, y1, y2):
    # Calculate jaccard index
    intersection = max(0, min(x2, y2)-max(x1, y1))
    filled_union = max(x2, y2) - min(x1, y1)
    return intersection/filled_union if filled_union > 0 else 0


def regex_search(text, pattern, group=1, default=None):
    match = re.search(pattern, text)
    return match.group(group) if match else default


def _windows_write_string(s, out, skip_errors=True):
    """ Returns True if the string was written using special methods,
    False if it has yet to be written out."""
    # Adapted from http://stackoverflow.com/a/3259271/35070

    import ctypes
    import ctypes.wintypes

    WIN_OUTPUT_IDS = {
        1: -11,
        2: -12,
    }

    try:
        fileno = out.fileno()
    except AttributeError:
        # If the output stream doesn't have a fileno, it's virtual
        return False
    except io.UnsupportedOperation:
        # Some strange Windows pseudo files?
        return False
    if fileno not in WIN_OUTPUT_IDS:
        return False

    GetStdHandle = ctypes.WINFUNCTYPE(
        ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD)(
        ('GetStdHandle', ctypes.windll.kernel32))
    h = GetStdHandle(WIN_OUTPUT_IDS[fileno])

    WriteConsoleW = ctypes.WINFUNCTYPE(
        ctypes.wintypes.BOOL, ctypes.wintypes.HANDLE, ctypes.wintypes.LPWSTR,
        ctypes.wintypes.DWORD, ctypes.POINTER(ctypes.wintypes.DWORD),
        ctypes.wintypes.LPVOID)(('WriteConsoleW', ctypes.windll.kernel32))
    written = ctypes.wintypes.DWORD(0)

    GetFileType = ctypes.WINFUNCTYPE(ctypes.wintypes.DWORD, ctypes.wintypes.DWORD)(
        ('GetFileType', ctypes.windll.kernel32))
    FILE_TYPE_CHAR = 0x0002
    FILE_TYPE_REMOTE = 0x8000
    GetConsoleMode = ctypes.WINFUNCTYPE(
        ctypes.wintypes.BOOL, ctypes.wintypes.HANDLE,
        ctypes.POINTER(ctypes.wintypes.DWORD))(
        ('GetConsoleMode', ctypes.windll.kernel32))
    INVALID_HANDLE_VALUE = ctypes.wintypes.DWORD(-1).value

    def not_a_console(handle):
        if handle == INVALID_HANDLE_VALUE or handle is None:
            return True
        return ((GetFileType(handle) & ~FILE_TYPE_REMOTE) != FILE_TYPE_CHAR or GetConsoleMode(handle, ctypes.byref(ctypes.wintypes.DWORD())) == 0)

    if not_a_console(h):
        return False

    def next_nonbmp_pos(s):
        try:
            return next(i for i, c in enumerate(s) if ord(c) > 0xffff)
        except StopIteration:
            return len(s)

    while s:
        count = min(next_nonbmp_pos(s), 1024)

        ret = WriteConsoleW(
            h, s, count if count else 2, ctypes.byref(written), None)
        if ret == 0:
            if skip_errors:
                continue
            else:
                raise OSError('Failed to write string')
        if not count:  # We just wrote a non-BMP character
            assert written.value == 2
            s = s[1:]
        else:
            assert written.value > 0
            s = s[written.value:]
    return True


def preferredencoding():
    """Get preferred encoding.
    Returns the best encoding scheme for the system, based on
    locale.getpreferredencoding() and some further tweaks.
    """
    try:
        pref = locale.getpreferredencoding()
        'TEST'.encode(pref)
    except Exception:
        pref = 'utf-8'

    return pref


def safe_print(*objects, sep=' ', end='\n', out=None, encoding=None, flush=False):
    """
    Ensure printing to standard output can be done safely (especially on Windows).
    There are usually issues with printing emojis and non utf-8 characters.
    """

    output_string = sep.join(map(lambda x: str(x), objects)) + end

    if out is None:
        out = sys.stdout

    if sys.platform == 'win32' and encoding is None and hasattr(out, 'fileno'):
        if _windows_write_string(output_string, out):
            return

    if 'b' in getattr(out, 'mode', '') or not hasattr(out, 'buffer'):
        out.write(output_string)
    else:
        enc = encoding or getattr(out, 'encoding', None) or preferredencoding()
        byt = output_string.encode(enc, 'ignore')
        out.buffer.write(byt)

    if flush and hasattr(out, 'flush'):
        out.flush()
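A quick usage note (not part of the repository): jaccard() above treats its arguments as two 1-D intervals (x1, x2) and (y1, y2), for example predicted versus labelled segment times in seconds, and returns the overlap divided by the overall spanned range:

from utils import jaccard  # assumes src/ is on the Python path

print(jaccard(10, 20, 15, 25))  # 0.333... -> 5 s of overlap over a 15 s spanned range
print(jaccard(10, 20, 30, 40))  # 0 -> disjoint intervals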
transcripts/auto/.gitignore
ADDED
@@ -0,0 +1,5 @@
# Cached transcripts
# Ignore everything in this directory
*
# Except this file
!.gitignore
transcripts/manual/.gitignore
ADDED
@@ -0,0 +1,5 @@
# Cached transcripts
# Ignore everything in this directory
*
# Except this file
!.gitignore