Shad0ws committed on
Commit b7b7347
1 Parent(s): e935a20

Upload 21 files

.github/FUNDING.yml ADDED
@@ -0,0 +1,3 @@
+ github: xenova
+ ko_fi: xenova
+ custom: https://www.buymeacoffee.com/xenova
.github/workflows/check_large_file.yml ADDED
@@ -0,0 +1,16 @@
+ name: Check file size
+ on: # or directly `on: [push]` to run the action on every push on any branch
+   pull_request:
+     branches: [main]
+
+   # to run this workflow manually from the Actions tab
+   workflow_dispatch:
+
+ jobs:
+   sync-to-hub:
+     runs-on: ubuntu-latest
+     steps:
+       - name: Check large files
+         uses: ActionsDesk/lfs-warning@v2.0
+         with:
+           filesizelimit: 10485760 # this is 10MB so we can sync to HF Spaces
.github/workflows/sync_with_huggingface.yml ADDED
@@ -0,0 +1,19 @@
+ name: Sync to Hugging Face hub
+ on:
+   push:
+     branches: [main]
+
+   # to run this workflow manually from the Actions tab
+   workflow_dispatch:
+
+ jobs:
+   sync-to-hub:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v2
+         with:
+           fetch-depth: 0
+       - name: Push to hub
+         env:
+           HF_TOKEN: ${{ secrets.HF_TOKEN }}
+         run: git push https://Xenova:$HF_TOKEN@huggingface.co/spaces/Xenova/sponsorblock-ml main
app.py ADDED
@@ -0,0 +1,318 @@
+
+ from functools import partial
+ from math import ceil, floor
+ import streamlit.components.v1 as components
+ import streamlit as st
+ import sys
+ import os
+ import json
+ from urllib.parse import quote
+
+ # Allow direct execution
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src'))  # noqa
+
+ from preprocess import get_words
+ from predict import PredictArguments, SegmentationArguments, predict as pred
+ from shared import GeneralArguments, seconds_to_time, CATGEGORY_OPTIONS
+ from utils import regex_search
+ from model import get_model_tokenizer_classifier
+ from errors import TranscriptError
+
+ st.set_page_config(
+     page_title='SponsorBlock ML',
+     page_icon='🤖',
+     layout='wide',
+     # initial_sidebar_state="expanded",
+     menu_items={
+         # 'Get Help': 'https://github.com/xenova/sponsorblock-ml',
+         # 'Report a bug': 'https://github.com/xenova/sponsorblock-ml/issues/new/choose',
+         # 'About': "# This is a header. This is an *extremely* cool app!"
+     }
+ )
+
+
+ YT_VIDEO_REGEX = r'''(?x)^
+     (?:
+         # http(s):// or protocol-independent URL
+         (?:https?://|//)
+         (?:(?:(?:(?:\w+\.)?[yY][oO][uU][tT][uU][bB][eE](?:-nocookie|kids)?\.com/|
+             youtube\.googleapis\.com/)  # the various hostnames, with wildcard subdomains
+         (?:.*?\#/)?  # handle anchor (#/) redirect urls
+         (?:  # the various things that can precede the ID:
+             # v/ or embed/ or e/
+             (?:(?:v|embed|e)/(?!videoseries))
+             |(?:  # or the v= param in all its forms
+                 # preceding watch(_popup|.php) or nothing (like /?v=xxxx)
+                 (?:(?:watch|movie)(?:_popup)?(?:\.php)?/?)?
+                 (?:\?|\#!?)  # the params delimiter ? or # or #!
+                 # any other preceding param (like /?s=tuff&v=xxxx or ?s=tuff&v=V36LpHqtcDY)
+                 (?:.*?[&;])??
+                 v=
+             )
+         ))
+         |(?:
+             youtu\.be  # just youtu.be/xxxx
+         )/)
+     )?  # all until now is optional -> you can pass the naked ID
+     # here is it! the YouTube video ID
+     (?P<id>[0-9A-Za-z_-]{11})'''
+
+ # https://github.com/google-research/text-to-text-transfer-transformer#released-model-checkpoints
+ # https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#experimental-t5-pre-trained-model-checkpoints
+
+ # https://huggingface.co/docs/transformers/model_doc/t5
+ # https://huggingface.co/docs/transformers/model_doc/t5v1.1
+
+
+ # Faster caching system for predictions (No need to hash)
+ @st.cache(persist=True, allow_output_mutation=True)
+ def create_prediction_cache():
+     return {}
+
+
+ @st.cache(persist=True, allow_output_mutation=True)
+ def create_function_cache():
+     return {}
+
+
+ prediction_cache = create_prediction_cache()
+ prediction_function_cache = create_function_cache()
+
+ MODELS = {
+     'Small (293 MB)': {
+         'pretrained': 'google/t5-v1_1-small',
+         'repo_id': 'Xenova/sponsorblock-small',
+         'num_parameters': '77M'
+     },
+     'Base v1 (850 MB)': {
+         'pretrained': 't5-base',
+         'repo_id': 'Xenova/sponsorblock-base-v1',
+         'num_parameters': '220M'
+     },
+
+     'Base v1.1 (944 MB)': {
+         'pretrained': 'google/t5-v1_1-base',
+         'repo_id': 'Xenova/sponsorblock-base-v1.1',
+         'num_parameters': '250M'
+     }
+ }
+
+ # Create per-model cache
+ for m in MODELS:
+     if m not in prediction_cache:
+         prediction_cache[m] = {}
+
+
+ CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier-v2'
+
+
+ TRANSCRIPT_TYPES = {
+     'AUTO_MANUAL': {
+         'label': 'Auto-generated (fallback to manual)',
+         'type': 'auto',
+         'fallback': 'manual'
+     },
+     'MANUAL_AUTO': {
+         'label': 'Manual (fallback to auto-generated)',
+         'type': 'manual',
+         'fallback': 'auto'
+     },
+     # 'TRANSLATED': 'Translated to English'  # Coming soon
+ }
+
+
+ def predict_function(model_id, model, tokenizer, segmentation_args, classifier, video_id, words, ts_type_id):
+     cache_id = f'{video_id}_{ts_type_id}'
+
+     if cache_id not in prediction_cache[model_id]:
+         prediction_cache[model_id][cache_id] = pred(
+             video_id, model, tokenizer,
+             segmentation_args=segmentation_args,
+             words=words,
+             classifier=classifier
+         )
+     return prediction_cache[model_id][cache_id]
+
+
+ def load_predict(model_id):
+     model_info = MODELS[model_id]
+
+     if model_id not in prediction_function_cache:
+         # Use default segmentation and classification arguments
+         predict_args = PredictArguments(model_name_or_path=model_info['repo_id'])
+         general_args = GeneralArguments()
+         segmentation_args = SegmentationArguments()
+
+         model, tokenizer, classifier = get_model_tokenizer_classifier(predict_args, general_args)
+
+         prediction_function_cache[model_id] = partial(
+             predict_function, model_id, model, tokenizer, segmentation_args, classifier)
+
+     return prediction_function_cache[model_id]
+
+
+ def create_button(text, url):
+     return f"""<div class="row-widget stButton" style="text-align: center">
+         <a href="{url}" target="_blank" rel="noopener noreferrer" class="btn-link">
+             <button kind="primary" class="btn">{text}</button>
+         </a>
+     </div>"""
+
+
+ def main():
+     st.markdown("""<style>
+     .btn {
+         display: inline-flex;
+         -webkit-box-align: center;
+         align-items: center;
+         -webkit-box-pack: center;
+         justify-content: center;
+         font-weight: 600;
+         padding: 0.25rem 0.75rem;
+         border-radius: 0.25rem;
+         margin: 0px;
+         line-height: 1.5;
+         color: inherit;
+         width: auto;
+         user-select: none;
+         background-color: inherit;
+         border: 1px solid rgba(49, 51, 63, 0.2);
+     }
+     .btn-link {
+         color: inherit;
+         text-decoration: none;
+     }
+     </style>""", unsafe_allow_html=True)
+
+     top = st.container()
+     output = st.empty()
+
+     # Display heading and subheading
+     top.markdown('# PromoDetect')
+     top.markdown(
+         '##### Automatically detect in-video YouTube sponsorships, self/unpaid promotions, and interaction reminders.')
+
+     # Add controls
+
+     col1, col2 = top.columns(2)
+
+     with col1:
+         model_id = st.selectbox(
+             'Select model', MODELS.keys(), index=0, on_change=output.empty)
+
+     with col2:
+         ts_type_id = st.selectbox(
+             'Transcript type', TRANSCRIPT_TYPES.keys(), index=0, format_func=lambda x: TRANSCRIPT_TYPES[x]['label'], on_change=output.empty)
+
+     query_params = st.experimental_get_query_params()
+
+     video_id = None
+
+     if 'v' in query_params:
+         video_id = query_params['v'][0]
+
+     if video_id is None:
+         video_input = top.text_input('Video URL/ID:', on_change=output.empty)
+     else:
+         video_input = top.text_input('Video URL/ID:', on_change=output.empty, value=video_id)
+
+     categories = top.multiselect('Categories:',
+                                  CATGEGORY_OPTIONS.keys(),
+                                  CATGEGORY_OPTIONS.keys(),
+                                  format_func=CATGEGORY_OPTIONS.get, on_change=output.empty
+                                  )
+
+     # Hide segments with a confidence lower than
+     confidence_threshold = top.slider(
+         'Confidence Threshold (%):', min_value=0, value=50, max_value=100, on_change=output.empty)
+
+     if len(video_input) == 0:  # No input, do not continue
+         return
+
+     # Load prediction function
+     with st.spinner('Loading model...'):
+         predict = load_predict(model_id)
+
+     with output.container():  # Place all content in output container
+         video_id = regex_search(video_input, YT_VIDEO_REGEX)
+         if video_id is None:
+             st.exception(ValueError('Invalid YouTube URL/ID'))
+             return
+
+         words = None  # Remains None if the transcript cannot be retrieved
+         try:
+             with st.spinner('Downloading transcript...'):
+                 words = get_words(video_id,
+                                   transcript_type=TRANSCRIPT_TYPES[ts_type_id]['type'],
+                                   fallback=TRANSCRIPT_TYPES[ts_type_id]['fallback']
+                                   )
+         except TranscriptError:
+             pass
+
+         if not words:
+             st.error('No transcript found!')
+             return
+
+         with st.spinner('Running model...'):
+             predictions = predict(video_id, words, ts_type_id)
+
+         if len(predictions) == 0:
+             st.success('No segments found!')
+             return
+
+         submit_segments = []
+         for index, prediction in enumerate(predictions, start=1):
+             category_key = prediction['category'].upper()
+             if category_key not in categories:
+                 continue  # Skip
+
+             confidence = prediction['probability'] * 100
+
+             if confidence < confidence_threshold:
+                 continue
+
+             submit_segments.append({
+                 'segment': [prediction['start'], prediction['end']],
+                 'category': prediction['category'],
+                 'actionType': 'skip'
+             })
+             start_time = seconds_to_time(prediction['start'])
+             end_time = seconds_to_time(prediction['end'])
+             with st.expander(
+                 f"[{category_key}] Prediction #{index} ({start_time} \u2192 {end_time})"
+             ):
+                 url = f"https://www.youtube-nocookie.com/embed/{video_id}?&start={floor(prediction['start'])}&end={ceil(prediction['end'])}"
+                 # autoplay=1&controls=0&modestbranding=1&fs=0
+
+                 # , width=None, height=None, scrolling=False
+                 components.iframe(url, width=670, height=376)
+
+                 text = ' '.join(w['text'] for w in prediction['words'])
+                 st.write(f"**Times:** {start_time} \u2192 {end_time}")
+                 st.write(
+                     f"**Category:** {CATGEGORY_OPTIONS[category_key]}")
+                 st.write(f"**Confidence:** {confidence:.2f}%")
+                 st.write(f'**Text:** "{text}"')
+
+         if not submit_segments:
+             st.success(
+                 f'No segments found! ({len(predictions)} ignored due to filters/settings)')
+             return
+
+         num_hidden = len(predictions) - len(submit_segments)
+         if num_hidden > 0:
+             st.info(
+                 f'{num_hidden} predictions hidden (adjust the settings and filters to view them all).')
+
+         json_data = quote(json.dumps(submit_segments))
+         link = f'https://www.youtube.com/watch?v={video_id}#segments={json_data}'
+         st.markdown(create_button('Submit Segments', link),
+                     unsafe_allow_html=True)
+
+         # st.markdown(f"""<div style="text-align: center;font-size: 16px;margin-top: 6px">
+         #     <a href="https://wiki.sponsor.ajay.app/w/Automating_Submissions" target="_blank" rel="noopener noreferrer">(Review before submitting!)</a>
+         # </div>""", unsafe_allow_html=True)
+
+
+ if __name__ == '__main__':
+     main()
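As a reference for the `#segments=` deep link that `main()` builds for the "Submit Segments" button, here is a minimal standalone sketch of the same payload construction; the video ID and prediction values are hypothetical:

```python
import json
from urllib.parse import quote

# Predictions in the shape app.py produces (hypothetical values)
predictions = [
    {'start': 12.3, 'end': 45.6, 'category': 'sponsor', 'probability': 0.97},
]

# Same structure app.py appends to submit_segments
submit_segments = [
    {'segment': [p['start'], p['end']], 'category': p['category'], 'actionType': 'skip'}
    for p in predictions
]

# URL-encoded JSON becomes the fragment the SponsorBlock extension reads
link = f"https://www.youtube.com/watch?v=VIDEO_ID_HERE#segments={quote(json.dumps(submit_segments))}"
print(link)
```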
data/.gitignore ADDED
@@ -0,0 +1,4 @@
+ # Ignore everything in this directory
+ *
+ # Except this file
+ !.gitignore
models/.gitignore ADDED
@@ -0,0 +1,4 @@
+ # Ignore everything in this directory
+ *
+ # Except this file
+ !.gitignore
raw/.gitignore ADDED
@@ -0,0 +1,4 @@
+ # Ignore everything in this directory
+ *
+ # Except this file
+ !.gitignore
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ transformers
+ datasets
+ youtube_transcript_api
+ torch
+ pandas
+ numpy
+ sentencepiece
src/classify.py ADDED
@@ -0,0 +1,43 @@
+ from transformers import TextClassificationPipeline
+ import preprocess
+ import segment
+
+
+ class SponsorBlockClassificationPipeline(TextClassificationPipeline):
+     def __init__(self, model, tokenizer):
+         device = next(model.parameters()).device.index
+         if device is None:
+             device = -1
+         super().__init__(model=model, tokenizer=tokenizer,
+                          return_all_scores=True, truncation=True, device=device)
+
+     def preprocess(self, data, **tokenizer_kwargs):
+         # TODO add support for lists
+         texts = []
+
+         if not isinstance(data, list):
+             data = [data]
+
+         for d in data:
+             if isinstance(d, dict):  # Otherwise, get data from transcript
+                 words = preprocess.get_words(d['video_id'])
+                 segment_words = segment.extract_segment(
+                     words, d['start'], d['end'])
+                 text = preprocess.clean_text(
+                     ' '.join(x['text'] for x in segment_words))
+                 texts.append(text)
+             elif isinstance(d, str):  # If string, assume this is what user wants to classify
+                 texts.append(d)
+             else:
+                 raise ValueError(f'Invalid input type: "{type(d)}"')
+
+         return self.tokenizer(
+             texts, return_tensors=self.framework, **tokenizer_kwargs)
+
+
+ def main():
+     pass
+
+
+ if __name__ == '__main__':
+     main()
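A rough usage sketch for this pipeline, assuming the loading helpers and defaults defined in src/model.py; the exact output shape of `TextClassificationPipeline` with `return_all_scores=True` can vary between transformers versions:

```python
from classify import SponsorBlockClassificationPipeline
from model import get_model_tokenizer, InferenceArguments

# Point the loader at the classifier checkpoint (default from src/model.py)
args = InferenceArguments()
args.model_name_or_path = args.classifier_model_name_or_path
model, tokenizer = get_model_tokenizer(args, model_type='classifier')

pipeline = SponsorBlockClassificationPipeline(model, tokenizer)

# A raw string is classified directly; a dict such as
# {'video_id': ..., 'start': ..., 'end': ...} is resolved via the transcript
scores = pipeline("thanks to today's sponsor for supporting the channel")
print(scores)  # per-category {label, score} entries
```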
src/errors.py ADDED
@@ -0,0 +1,24 @@
+
+ class SponsorBlockException(Exception):
+     """Base class for all sponsor block exceptions"""
+     pass
+
+
+ class InferenceException(SponsorBlockException):
+     """An exception occurred while predicting sponsor segments"""
+     pass
+
+
+ class TranscriptError(SponsorBlockException):
+     """An exception occurred while retrieving the video transcript"""
+     pass
+
+
+ class ModelError(SponsorBlockException):
+     """Base class for model-related errors"""
+     pass
+
+
+ class ModelLoadError(ModelError):
+     """An exception occurred while loading the model"""
+     pass
src/evaluate.py ADDED
@@ -0,0 +1,408 @@
+
+ from model import get_model_tokenizer_classifier, InferenceArguments
+ from utils import jaccard, safe_print
+ from transformers import HfArgumentParser
+ from preprocess import get_words, clean_text
+ from shared import GeneralArguments, DatasetArguments
+ from predict import predict
+ from segment import extract_segment, word_start, word_end, SegmentationArguments, add_labels_to_words
+ import pandas as pd
+ from dataclasses import dataclass, field
+ from typing import Optional
+ from tqdm import tqdm
+ import json
+ import os
+ import random
+ from shared import seconds_to_time
+ from urllib.parse import quote
+ import logging
+
+ logging.basicConfig()
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class EvaluationArguments(InferenceArguments):
+     """Arguments pertaining to how evaluation will occur."""
+     output_file: Optional[str] = field(
+         default='metrics.csv',
+         metadata={
+             'help': 'Save metrics to output file'
+         }
+     )
+
+     skip_missing: bool = field(
+         default=False,
+         metadata={
+             'help': 'Whether to skip checking for missing segments. If False, predictions will be made.'
+         }
+     )
+     skip_incorrect: bool = field(
+         default=False,
+         metadata={
+             'help': 'Whether to skip checking for incorrect segments. If False, classifications will be made on existing segments.'
+         }
+     )
+
+
+ def attach_predictions_to_sponsor_segments(predictions, sponsor_segments):
+     """Attach sponsor segments to closest prediction"""
+     for prediction in predictions:
+         prediction['best_overlap'] = 0
+         prediction['best_sponsorship'] = None
+
+         # Assign predictions to actual (labelled) sponsored segments
+         for sponsor_segment in sponsor_segments:
+             j = jaccard(prediction['start'], prediction['end'],
+                         sponsor_segment['start'], sponsor_segment['end'])
+             if prediction['best_overlap'] < j:
+                 prediction['best_overlap'] = j
+                 prediction['best_sponsorship'] = sponsor_segment
+
+     return sponsor_segments
+
+
+ def calculate_metrics(labelled_words, predictions):
+
+     metrics = {
+         'true_positive': 0,  # Is sponsor, predicted sponsor
+         # Is sponsor, predicted not sponsor (i.e., missed it - bad)
+         'false_negative': 0,
+         # Is not sponsor, predicted sponsor (classified incorrectly, not that bad since we do manual checking afterwards)
+         'false_positive': 0,
+         'true_negative': 0,  # Is not sponsor, predicted not sponsor
+     }
+
+     metrics['video_duration'] = word_end(
+         labelled_words[-1]) - word_start(labelled_words[0])
+
+     for index, word in enumerate(labelled_words):
+         if index >= len(labelled_words) - 1:
+             continue
+
+         duration = word_end(word) - word_start(word)
+
+         predicted_sponsor = False
+         for p in predictions:
+             # Is in some prediction
+             if p['start'] <= word['start'] <= p['end']:
+                 predicted_sponsor = True
+                 break
+
+         if predicted_sponsor:
+             # total_positive_time += duration
+             if word.get('category') is not None:  # Is actual sponsor
+                 metrics['true_positive'] += duration
+             else:
+                 metrics['false_positive'] += duration
+         else:
+             # total_negative_time += duration
+             if word.get('category') is not None:  # Is actual sponsor
+                 metrics['false_negative'] += duration
+             else:
+                 metrics['true_negative'] += duration
+
+     # NOTE In cases where we encounter division by 0, we say that the value is 1
+     # https://stats.stackexchange.com/a/1775
+     # (Precision) TP+FP=0: means that all instances were predicted as negative
+     # (Recall) TP+FN=0: means that there were no positive cases in the input data
+
+     # The fraction of predictions our model got right
+     # Can simplify, but use full formula
+     z = metrics['true_positive'] + metrics['true_negative'] + \
+         metrics['false_positive'] + metrics['false_negative']
+     metrics['accuracy'] = (
+         (metrics['true_positive'] + metrics['true_negative']) / z) if z > 0 else 1
+
+     # What proportion of positive identifications was actually correct?
+     z = metrics['true_positive'] + metrics['false_positive']
+     metrics['precision'] = (metrics['true_positive'] / z) if z > 0 else 1
+
+     # What proportion of actual positives was identified correctly?
+     z = metrics['true_positive'] + metrics['false_negative']
+     metrics['recall'] = (metrics['true_positive'] / z) if z > 0 else 1
+
+     # https://deepai.org/machine-learning-glossary-and-terms/f-score
+
+     s = metrics['precision'] + metrics['recall']
+     metrics['f-score'] = (2 * (metrics['precision'] *
+                           metrics['recall']) / s) if s > 0 else 0
+
+     return metrics
+
+
+ def main():
+     logger.setLevel(logging.DEBUG)
+
+     hf_parser = HfArgumentParser((
+         EvaluationArguments,
+         DatasetArguments,
+         SegmentationArguments,
+         GeneralArguments
+     ))
+
+     evaluation_args, dataset_args, segmentation_args, general_args = hf_parser.parse_args_into_dataclasses()
+
+     if evaluation_args.skip_missing and evaluation_args.skip_incorrect:
+         logger.error('ERROR: Nothing to do')
+         return
+
+     # Load labelled data:
+     final_path = os.path.join(
+         dataset_args.data_dir, dataset_args.processed_file)
+
+     if not os.path.exists(final_path):
+         logger.error('ERROR: Processed database not found.\n'
+                      f'Run `python src/preprocess.py --update_database --do_create` to generate "{final_path}".')
+         return
+
+     model, tokenizer, classifier = get_model_tokenizer_classifier(
+         evaluation_args, general_args)
+
+     with open(final_path) as fp:
+         final_data = json.load(fp)
+
+     if evaluation_args.video_ids:  # Use specified
+         video_ids = evaluation_args.video_ids
+
+     else:  # Use items found in preprocessed database
+         video_ids = list(final_data.keys())
+         random.shuffle(video_ids)
+
+         if evaluation_args.start_index is not None:
+             video_ids = video_ids[evaluation_args.start_index:]
+
+         if evaluation_args.max_videos is not None:
+             video_ids = video_ids[:evaluation_args.max_videos]
+
+     out_metrics = []
+
+     all_metrics = {}
+     if not evaluation_args.skip_missing:
+         all_metrics['total_prediction_accuracy'] = 0
+         all_metrics['total_prediction_precision'] = 0
+         all_metrics['total_prediction_recall'] = 0
+         all_metrics['total_prediction_fscore'] = 0
+
+     if not evaluation_args.skip_incorrect:
+         all_metrics['classifier_segment_correct'] = 0
+         all_metrics['classifier_segment_count'] = 0
+
+     metric_count = 0
+
+     postfix_info = {}
+
+     try:
+         with tqdm(video_ids) as progress:
+             for video_index, video_id in enumerate(progress):
+                 progress.set_description(f'Processing {video_id}')
+
+                 words = get_words(video_id)
+                 if not words:
+                     continue
+
+                 # Get labels
+                 sponsor_segments = final_data.get(video_id)
+
+                 # Reset previous
+                 missed_segments = []
+                 incorrect_segments = []
+
+                 current_metrics = {
+                     'video_id': video_id
+                 }
+                 metric_count += 1
+
+                 if not evaluation_args.skip_missing:  # Make predictions
+                     predictions = predict(video_id, model, tokenizer, segmentation_args,
+                                           classifier=classifier,
+                                           min_probability=evaluation_args.min_probability)
+
+                     if sponsor_segments:
+                         labelled_words = add_labels_to_words(
+                             words, sponsor_segments)
+
+                         current_metrics.update(
+                             calculate_metrics(labelled_words, predictions))
+
+                         all_metrics['total_prediction_accuracy'] += current_metrics['accuracy']
+                         all_metrics['total_prediction_precision'] += current_metrics['precision']
+                         all_metrics['total_prediction_recall'] += current_metrics['recall']
+                         all_metrics['total_prediction_fscore'] += current_metrics['f-score']
+
+                         # Just for display purposes
+                         postfix_info.update({
+                             'accuracy': all_metrics['total_prediction_accuracy'] / metric_count,
+                             'precision': all_metrics['total_prediction_precision'] / metric_count,
+                             'recall': all_metrics['total_prediction_recall'] / metric_count,
+                             'f-score': all_metrics['total_prediction_fscore'] / metric_count,
+                         })
+
+                         sponsor_segments = attach_predictions_to_sponsor_segments(
+                             predictions, sponsor_segments)
+
+                         # Identify possible issues:
+                         for prediction in predictions:
+                             if prediction['best_sponsorship'] is not None:
+                                 continue
+
+                             prediction_words = prediction.pop('words', [])
+
+                             # Attach original text to missed segments
+                             prediction['text'] = ' '.join(
+                                 x['text'] for x in prediction_words)
+                             missed_segments.append(prediction)
+
+                     else:
+                         # Not in database (all segments missed)
+                         missed_segments = predictions
+
+                 if not evaluation_args.skip_incorrect and sponsor_segments:
+                     # Check for incorrect segments using the classifier
+
+                     segments_to_check = []
+                     cleaned_texts = []  # Texts to send through tokenizer
+                     for sponsor_segment in sponsor_segments:
+                         segment_words = extract_segment(
+                             words, sponsor_segment['start'], sponsor_segment['end'])
+                         sponsor_segment['text'] = ' '.join(
+                             x['text'] for x in segment_words)
+
+                         duration = sponsor_segment['end'] - \
+                             sponsor_segment['start']
+                         wps = (len(segment_words) /
+                                duration) if duration > 0 else 0
+                         if wps < 1.5:
+                             continue
+
+                         # Do not worry about those that are locked or have enough votes
+                         # or segment['votes'] > 5:
+                         if sponsor_segment['locked']:
+                             continue
+
+                         cleaned_texts.append(
+                             clean_text(sponsor_segment['text']))
+                         segments_to_check.append(sponsor_segment)
+
+                     if segments_to_check:  # Some segments to check
+
+                         segments_scores = classifier(cleaned_texts)
+
+                         num_correct = 0
+                         for segment, scores in zip(segments_to_check, segments_scores):
+
+                             fixed_scores = {
+                                 score['label']: score['score']
+                                 for score in scores
+                             }
+
+                             all_metrics['classifier_segment_count'] += 1
+
+                             prediction = max(scores, key=lambda x: x['score'])
+                             predicted_category = prediction['label'].lower()
+
+                             if predicted_category == segment['category']:
+                                 num_correct += 1
+                                 continue  # Ignore correct segments
+
+                             segment.update({
+                                 'predicted': predicted_category,
+                                 'scores': fixed_scores
+                             })
+
+                             incorrect_segments.append(segment)
+
+                         current_metrics['num_segments'] = len(
+                             segments_to_check)
+                         current_metrics['classified_correct'] = num_correct
+
+                         all_metrics['classifier_segment_correct'] += num_correct
+
+                         if all_metrics['classifier_segment_count'] > 0:
+                             postfix_info['classifier_accuracy'] = all_metrics['classifier_segment_correct'] / \
+                                 all_metrics['classifier_segment_count']
+
+                 out_metrics.append(current_metrics)
+                 progress.set_postfix(postfix_info)
+
+                 if missed_segments or incorrect_segments:
+
+                     if evaluation_args.output_as_json:
+                         to_print = {'video_id': video_id}
+
+                         if missed_segments:
+                             to_print['missed'] = missed_segments
+
+                         if incorrect_segments:
+                             to_print['incorrect'] = incorrect_segments
+
+                         safe_print(json.dumps(to_print))
+
+                     else:
+                         safe_print(
+                             f'Issues identified for {video_id} (#{video_index})')
+                         # Potentially missed segments (model predicted, but not in database)
+                         if missed_segments:
+                             safe_print(' - Missed segments:')
+                             segments_to_submit = []
+                             for i, missed_segment in enumerate(missed_segments, start=1):
+                                 safe_print(f'\t#{i}:', seconds_to_time(
+                                     missed_segment['start']), '-->', seconds_to_time(missed_segment['end']))
+                                 safe_print('\t\tText: "',
+                                            missed_segment['text'], '"', sep='')
+                                 safe_print('\t\tCategory:',
+                                            missed_segment.get('category'))
+                                 if 'probability' in missed_segment:
+                                     safe_print('\t\tProbability:',
+                                                missed_segment['probability'])
+
+                                 segments_to_submit.append({
+                                     'segment': [missed_segment['start'], missed_segment['end']],
+                                     'category': missed_segment['category'].lower(),
+                                     'actionType': 'skip'
+                                 })
+
+                             json_data = quote(json.dumps(segments_to_submit))
+                             safe_print(
+                                 f'\tSubmit: https://www.youtube.com/watch?v={video_id}#segments={json_data}')
+
+                         # Incorrect segments (in database, but incorrectly classified)
+                         if incorrect_segments:
+                             safe_print(' - Incorrect segments:')
+                             for i, incorrect_segment in enumerate(incorrect_segments, start=1):
+                                 safe_print(f'\t#{i}:', seconds_to_time(
+                                     incorrect_segment['start']), '-->', seconds_to_time(incorrect_segment['end']))
+
+                                 safe_print(
+                                     '\t\tText: "', incorrect_segment['text'], '"', sep='')
+                                 safe_print(
+                                     '\t\tUUID:', incorrect_segment['uuid'])
+                                 safe_print(
+                                     '\t\tVotes:', incorrect_segment['votes'])
+                                 safe_print(
+                                     '\t\tViews:', incorrect_segment['views'])
+                                 safe_print('\t\tLocked:',
+                                            incorrect_segment['locked'])
+
+                                 safe_print('\t\tCurrent Category:',
+                                            incorrect_segment['category'])
+                                 safe_print('\t\tPredicted Category:',
+                                            incorrect_segment['predicted'])
+                                 safe_print('\t\tProbabilities:')
+                                 for label, score in incorrect_segment['scores'].items():
+                                     safe_print(
+                                         f"\t\t\t{label}: {score}")
+
+                         safe_print()
+
+     except KeyboardInterrupt:
+         pass
+
+     df = pd.DataFrame(out_metrics)
+
+     df.to_csv(evaluation_args.output_file)
+     logger.info(df.mean())
+
+
+ if __name__ == '__main__':
+     main()
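A quick sanity check of the duration-weighted metrics computed by `calculate_metrics`, using hypothetical confusion values in seconds:

```python
# Hypothetical duration-weighted confusion values (seconds of video)
metrics = {'true_positive': 40.0, 'false_positive': 10.0,
           'false_negative': 20.0, 'true_negative': 530.0}

# Same formulas as calculate_metrics (division-by-zero guards omitted)
precision = metrics['true_positive'] / (metrics['true_positive'] + metrics['false_positive'])  # 0.8
recall = metrics['true_positive'] / (metrics['true_positive'] + metrics['false_negative'])     # ~0.667
f_score = 2 * precision * recall / (precision + recall)                                        # ~0.727
print(precision, recall, f_score)
```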
src/model.py ADDED
@@ -0,0 +1,235 @@
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, TrainingArguments
+ from shared import CustomTokens, GeneralArguments
+ from dataclasses import dataclass, field
+ from typing import Optional, Union
+ import torch
+ import classify
+ import base64
+ import re
+ import requests
+ import json
+ import logging
+
+ logging.basicConfig()
+ logger = logging.getLogger(__name__)
+
+ # Public innertube key (b64 encoded so that it is not incorrectly flagged)
+ INNERTUBE_KEY = base64.b64decode(
+     b'QUl6YVN5QU9fRkoyU2xxVThRNFNURUhMR0NpbHdfWTlfMTFxY1c4').decode()
+
+ YT_CONTEXT = {
+     'client': {
+         'userAgent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36,gzip(gfe)',
+         'clientName': 'WEB',
+         'clientVersion': '2.20211221.00.00',
+     }
+ }
+ _YT_INITIAL_DATA_RE = r'(?:window\s*\[\s*["\']ytInitialData["\']\s*\]|ytInitialData)\s*=\s*({.+?})\s*;\s*(?:var\s+meta|</script|\n)'
+
+
+ def get_all_channel_vids(channel_id):
+     continuation = None
+     while True:
+         if continuation is None:
+             params = {'list': channel_id.replace('UC', 'UU', 1)}
+             response = requests.get(
+                 'https://www.youtube.com/playlist', params=params)
+             items = json.loads(re.search(_YT_INITIAL_DATA_RE, response.text).group(1))['contents']['twoColumnBrowseResultsRenderer']['tabs'][0]['tabRenderer']['content'][
+                 'sectionListRenderer']['contents'][0]['itemSectionRenderer']['contents'][0]['playlistVideoListRenderer']['contents']
+         else:
+             params = {'key': INNERTUBE_KEY}
+             data = {
+                 'context': YT_CONTEXT,
+                 'continuation': continuation
+             }
+             response = requests.post(
+                 'https://www.youtube.com/youtubei/v1/browse', params=params, json=data)
+             items = response.json()[
+                 'onResponseReceivedActions'][0]['appendContinuationItemsAction']['continuationItems']
+
+         new_token = None
+         for vid in items:
+             info = vid.get('playlistVideoRenderer')
+             if info:
+                 yield info['videoId']
+                 continue
+
+             info = vid.get('continuationItemRenderer')
+             if info:
+                 new_token = info['continuationEndpoint']['continuationCommand']['token']
+
+         if new_token is None:
+             break
+         continuation = new_token
+
+
+ @dataclass
+ class ModelArguments:
+     """
+     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+     """
+
+     model_name_or_path: str = field(
+         default=None,
+         metadata={
+             'help': 'Path to pretrained model or model identifier from huggingface.co/models'
+         }
+     )
+
+     cache_dir: Optional[str] = field(
+         default='models',
+         metadata={
+             'help': 'Where to store the pretrained models downloaded from huggingface.co'
+         },
+     )
+     use_fast_tokenizer: bool = field(
+         default=True,
+         metadata={
+             'help': 'Whether to use one of the fast tokenizers (backed by the tokenizers library) or not.'
+         },
+     )
+     model_revision: str = field(
+         default='main',
+         metadata={
+             'help': 'The specific model version to use (can be a branch name, tag name or commit id).'
+         },
+     )
+     use_auth_token: bool = field(
+         default=False,
+         metadata={
+             'help': 'Will use the token generated when running `transformers-cli login` (necessary to use this script '
+                     'with private models).'
+         },
+     )
+
+
+ import itertools
+ from errors import InferenceException, ModelLoadError
+
+
+ @dataclass
+ class InferenceArguments(ModelArguments):
+
+     model_name_or_path: str = field(
+         default='Xenova/sponsorblock-small',
+         metadata={
+             'help': 'Path to pretrained model used for prediction'
+         }
+     )
+     classifier_model_name_or_path: str = field(
+         default='EColi/SB_Classifier',
+         metadata={
+             'help': 'Use a pretrained classifier'
+         }
+     )
+
+     max_videos: Optional[int] = field(
+         default=None,
+         metadata={
+             'help': 'The number of videos to test on'
+         }
+     )
+     start_index: int = field(default=None, metadata={
+         'help': 'Video to start the evaluation at.'})
+     channel_id: Optional[str] = field(
+         default=None,
+         metadata={
+             'help': 'Used to evaluate a channel'
+         }
+     )
+     video_ids: str = field(
+         default_factory=lambda: [],
+         metadata={
+             'nargs': '+'
+         }
+     )
+
+     output_as_json: bool = field(default=False, metadata={
+         'help': 'Output evaluations as JSON'})
+
+     min_probability: float = field(
+         default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
+
+     def __post_init__(self):
+         self.video_ids = list(map(str.strip, self.video_ids))
+
+         if any(len(video_id) != 11 for video_id in self.video_ids):
+             raise InferenceException('Invalid video IDs (length not 11)')
+
+         if self.channel_id is not None:
+             start = self.start_index or 0
+             end = None if self.max_videos is None else start + self.max_videos
+
+             channel_video_ids = list(itertools.islice(get_all_channel_vids(
+                 self.channel_id), start, end))
+             logger.info(
+                 f'Found {len(channel_video_ids)} videos for channel {self.channel_id}')
+
+             self.video_ids += channel_video_ids
+
+
+ def get_model_tokenizer_classifier(inference_args: InferenceArguments, general_args: GeneralArguments):
+
+     original_path = inference_args.model_name_or_path
+
+     # Load main model and tokenizer
+     model, tokenizer = get_model_tokenizer(inference_args, general_args)
+
+     # Load classifier
+     inference_args.model_name_or_path = inference_args.classifier_model_name_or_path
+     classifier_model, classifier_tokenizer = get_model_tokenizer(
+         inference_args, general_args, model_type='classifier')
+
+     classifier = classify.SponsorBlockClassificationPipeline(
+         classifier_model, classifier_tokenizer)
+
+     # Reset to original model_name_or_path
+     inference_args.model_name_or_path = original_path
+
+     return model, tokenizer, classifier
+
+
+ def get_model_tokenizer(model_args: ModelArguments, general_args: Union[GeneralArguments, TrainingArguments] = None, config_args=None, model_type='seq2seq'):
+     if model_args.model_name_or_path is None:
+         raise ModelLoadError('Must specify --model_name_or_path')
+
+     if config_args is None:
+         config_args = {}
+
+     use_auth_token = True if model_args.use_auth_token else None
+
+     config = AutoConfig.from_pretrained(
+         model_args.model_name_or_path,
+         cache_dir=model_args.cache_dir,
+         revision=model_args.model_revision,
+         use_auth_token=use_auth_token,
+         **config_args
+     )
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_args.model_name_or_path,
+         cache_dir=model_args.cache_dir,
+         use_fast=model_args.use_fast_tokenizer,
+         revision=model_args.model_revision,
+         use_auth_token=use_auth_token,
+     )
+
+     model_type = AutoModelForSeq2SeqLM if model_type == 'seq2seq' else AutoModelForSequenceClassification
+     model = model_type.from_pretrained(
+         model_args.model_name_or_path,
+         config=config,
+         cache_dir=model_args.cache_dir,
+         revision=model_args.model_revision,
+         use_auth_token=use_auth_token,
+     )
+
+     # Add custom tokens
+     CustomTokens.add_custom_tokens(tokenizer)
+     model.resize_token_embeddings(len(tokenizer))
+
+     # Potentially move model to GPU
+     if general_args is not None and not general_args.no_cuda:
+         model.to('cuda' if torch.cuda.is_available() else 'cpu')
+
+     return model, tokenizer
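A minimal end-to-end loading sketch, mirroring how src/predict.py and src/evaluate.py call into this module; it assumes the default checkpoints above and the `GeneralArguments`/`no_cuda` fields defined in src/shared.py:

```python
from shared import GeneralArguments
from model import get_model_tokenizer_classifier, InferenceArguments

# Defaults resolve to Xenova/sponsorblock-small + EColi/SB_Classifier
inference_args = InferenceArguments()
general_args = GeneralArguments()

model, tokenizer, classifier = get_model_tokenizer_classifier(
    inference_args, general_args)
print(type(model).__name__, type(classifier).__name__)
```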
src/predict.py ADDED
@@ -0,0 +1,285 @@
+
+ from transformers import HfArgumentParser
+ from dataclasses import dataclass, field
+ import logging
+ from shared import CustomTokens, extract_sponsor_matches, GeneralArguments, seconds_to_time
+ from segment import (
+     generate_segments,
+     extract_segment,
+     MIN_SAFETY_TOKENS,
+     SAFETY_TOKENS_PERCENTAGE,
+     word_start,
+     word_end,
+     SegmentationArguments
+ )
+ import preprocess
+ from errors import TranscriptError
+ from model import get_model_tokenizer_classifier, InferenceArguments
+
+ logging.basicConfig()
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class PredictArguments(InferenceArguments):
+     video_id: str = field(
+         default=None,
+         metadata={
+             'help': 'Video to predict segments for'}
+     )
+
+     def __post_init__(self):
+         if self.video_id is not None:
+             self.video_ids.append(self.video_id)
+
+         super().__post_init__()
+
+
+ MATCH_WINDOW = 25  # Increase for accuracy, but takes longer: O(n^3)
+ MERGE_TIME_WITHIN = 8  # Merge predictions if they are within x seconds
+
+ # Any prediction whose start time is <= this will be set to start at 0
+ START_TIME_ZERO_THRESHOLD = 0.08
+
+
+ def filter_and_add_probabilities(predictions, classifier, min_probability):
+     """Use classifier to filter predictions"""
+     if not predictions:
+         return predictions
+
+     # We update the predicted category from the extractive transformer
+     # if the classifier is confident enough it is another category
+
+     texts = [
+         preprocess.clean_text(' '.join([x['text'] for x in pred['words']]))
+         for pred in predictions
+     ]
+     classifications = classifier(texts)
+
+     filtered_predictions = []
+     for prediction, probabilities in zip(predictions, classifications):
+         predicted_probabilities = {
+             p['label'].lower(): p['score'] for p in probabilities}
+
+         # Get best category + probability
+         classifier_category = max(
+             predicted_probabilities, key=predicted_probabilities.get)
+         classifier_probability = predicted_probabilities[classifier_category]
+
+         if (prediction['category'] not in predicted_probabilities) \
+                 or (classifier_category != 'none' and classifier_probability > 0.5):  # TODO make param
+             # Unknown category or we are confident enough to overrule,
+             # so change category to what was predicted by classifier
+             prediction['category'] = classifier_category
+
+         if prediction['category'] == 'none':
+             continue  # Ignore if categorised as nothing
+
+         prediction['probability'] = predicted_probabilities[prediction['category']]
+
+         if min_probability is not None and prediction['probability'] < min_probability:
+             continue  # Ignore if below threshold
+
+         # TODO add probabilities, but remove None and normalise rest
+         prediction['probabilities'] = predicted_probabilities
+
+         # if prediction['probability'] < classifier_args.min_probability:
+         #     continue
+
+         filtered_predictions.append(prediction)
+
+     return filtered_predictions
+
+
+ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifier=None, min_probability=None):
+     # Allow words to be passed in so that we don't have to get the words if we already have them
+     if words is None:
+         words = preprocess.get_words(video_id)
+         if not words:
+             raise TranscriptError('Unable to retrieve transcript')
+
+     segments = generate_segments(
+         words,
+         tokenizer,
+         segmentation_args
+     )
+
+     predictions = segments_to_predictions(segments, model, tokenizer)
+     # Add words back to time_ranges
+     for prediction in predictions:
+         # Stores words in the range
+         prediction['words'] = extract_segment(
+             words, prediction['start'], prediction['end'])
+
+     if classifier is not None:
+         predictions = filter_and_add_probabilities(
+             predictions, classifier, min_probability)
+
+     return predictions
+
+
+ def greedy_match(list, sublist):
+     # Return index and length of longest matching sublist
+
+     best_i = -1
+     best_j = -1
+     best_k = 0
+
+     for i in range(len(list)):  # Start position in main list
+         for j in range(len(sublist)):  # Start position in sublist
+             for k in range(len(sublist)-j, 0, -1):  # Width of sublist window
+                 if k > best_k and list[i:i+k] == sublist[j:j+k]:
+                     best_i, best_j, best_k = i, j, k
+                     break  # Since window size decreases
+
+     return best_i, best_j, best_k
+
+
+ def predict_sponsor_from_texts(texts, model, tokenizer):
+     clean_texts = list(map(preprocess.clean_text, texts))
+     return predict_sponsor_from_cleaned_texts(clean_texts, model, tokenizer)
+
+
+ def predict_sponsor_from_cleaned_texts(cleaned_texts, model, tokenizer):
+     """Given a body of text, predict the words which are part of the sponsor"""
+     model_device = next(model.parameters()).device
+
+     decoded_outputs = []
+     # Do individually, to avoid running out of memory for long videos
+     for cleaned_words in cleaned_texts:
+         text = CustomTokens.EXTRACT_SEGMENTS_PREFIX.value + \
+             ' '.join(cleaned_words)
+         input_ids = tokenizer(text, return_tensors='pt',
+                               truncation=True).input_ids.to(model_device)
+
+         # Optimise output length so that we do not generate unnecessarily long texts
+         max_out_len = round(min(
+             max(
+                 len(input_ids[0])/SAFETY_TOKENS_PERCENTAGE,
+                 len(input_ids[0]) + MIN_SAFETY_TOKENS
+             ),
+             model.model_dim)
+         )
+
+         outputs = model.generate(input_ids, max_length=max_out_len)
+         decoded_outputs.append(tokenizer.decode(
+             outputs[0], skip_special_tokens=True))
+
+     return decoded_outputs
+
+
+ def segments_to_predictions(segments, model, tokenizer):
+     predicted_time_ranges = []
+
+     cleaned_texts = [
+         [x['cleaned'] for x in cleaned_segment]
+         for cleaned_segment in segments
+     ]
+
+     sponsorship_texts = predict_sponsor_from_cleaned_texts(
+         cleaned_texts, model, tokenizer)
+
+     matches = extract_sponsor_matches(sponsorship_texts)
+
+     for segment_matches, cleaned_batch, segment in zip(matches, cleaned_texts, segments):
+
+         for match in segment_matches:  # one segment might contain multiple sponsors/ir/selfpromos
+
+             matched_text = match['text'].split()
+
+             i1, j1, k1 = greedy_match(
+                 cleaned_batch, matched_text[:MATCH_WINDOW])
+             i2, j2, k2 = greedy_match(
+                 cleaned_batch, matched_text[-MATCH_WINDOW:])
+
+             extracted_words = segment[i1:i2+k2]
+             if not extracted_words:
+                 continue
+
+             predicted_time_ranges.append({
+                 'start': word_start(extracted_words[0]),
+                 'end': word_end(extracted_words[-1]),
+                 'category': match['category']
+             })
+
+     # Necessary to sort matches by start time
+     predicted_time_ranges.sort(key=word_start)
+
+     # Merge overlapping predictions and sponsorships that are close together
+     # Caused by model having max input size
+
+     prev_prediction = None
+
+     final_predicted_time_ranges = []
+     for range in predicted_time_ranges:
+         start_time = range['start'] if range['start'] > START_TIME_ZERO_THRESHOLD else 0
+         end_time = range['end']
+
+         if prev_prediction is not None and \
+                 (start_time <= prev_prediction['end'] <= end_time or  # Merge overlapping segments
+                     (range['category'] == prev_prediction['category']  # Merge disconnected segments if same category and within threshold
+                         and start_time - prev_prediction['end'] <= MERGE_TIME_WITHIN)):
+             # Extend last prediction range
+             final_predicted_time_ranges[-1]['end'] = end_time
+
+         else:  # No overlap, is a new prediction
+             final_predicted_time_ranges.append({
+                 'start': start_time,
+                 'end': end_time,
+                 'category': range['category']
+             })
+
+         prev_prediction = range
+
+     return final_predicted_time_ranges
+
+
+ def main():
+     # Test on unseen data
+     logger.setLevel(logging.DEBUG)
+
+     hf_parser = HfArgumentParser((
+         PredictArguments,
+         SegmentationArguments,
+         GeneralArguments
+     ))
+     predict_args, segmentation_args, general_args = hf_parser.parse_args_into_dataclasses()
+
+     if not predict_args.video_ids:
+         logger.error(
+             'No video IDs supplied. Use `--video_id`, `--video_ids`, or `--channel_id`.')
+         return
+
+     model, tokenizer, classifier = get_model_tokenizer_classifier(
+         predict_args, general_args)
+
+     for video_id in predict_args.video_ids:
+         try:
+             predictions = predict(video_id, model, tokenizer, segmentation_args,
+                                   classifier=classifier,
+                                   min_probability=predict_args.min_probability)
+         except TranscriptError:
+             logger.warning(f'No transcript available for {video_id}')
+             continue
+         video_url = f'https://www.youtube.com/watch?v={video_id}'
+         if not predictions:
+             logger.info(f'No predictions found for {video_url}')
+             continue
+
+         # TODO use predict_args.output_as_json
+         print(len(predictions), 'predictions found for', video_url)
+         for index, prediction in enumerate(predictions, start=1):
+             print(f'Prediction #{index}:')
+             print('Text: "',
+                   ' '.join([w['text'] for w in prediction['words']]), '"', sep='')
+             print('Time:', seconds_to_time(
+                 prediction['start']), '\u2192', seconds_to_time(prediction['end']))
+             print('Category:', prediction.get('category'))
+             if 'probability' in prediction:
+                 print('Probability:', prediction['probability'])
+             print()
+         print()
+
+
+ if __name__ == '__main__':
+     main()
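To illustrate the O(n^3) matching that `MATCH_WINDOW` bounds, here is `greedy_match` on toy data (hypothetical word lists; run from the src directory so the import resolves):

```python
from predict import greedy_match

# Find the longest contiguous run of `match` inside `words`;
# returns (start index in words, start index in match, matched length)
words = ['this', 'video', 'is', 'sponsored', 'by', 'nordvpn', 'today']
match = ['is', 'sponsored', 'by', 'nordvpn']

i, j, k = greedy_match(words, match)
print(i, j, k)          # 2 0 4
print(words[i:i+k])     # ['is', 'sponsored', 'by', 'nordvpn']
```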
src/preprocess.py ADDED
@@ -0,0 +1,979 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from shared import DatasetArguments
2
+ from utils import jaccard
3
+ from functools import lru_cache
4
+ from datetime import datetime
5
+ import itertools
6
+ from typing import Optional
7
+ import model as model_module
8
+ import segment
9
+ from tqdm import tqdm
10
+ from dataclasses import dataclass, field
11
+ from transformers import HfArgumentParser
12
+ from shared import extract_sponsor_matches_from_text, ACTION_OPTIONS, CATEGORIES, CATGEGORY_OPTIONS, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, GeneralArguments, CustomTokens
13
+ import csv
14
+ import re
15
+ import random
16
+ import logging
17
+ from youtube_transcript_api import YouTubeTranscriptApi, CouldNotRetrieveTranscript, YouTubeRequestFailed, TooManyRequests
18
+ import os
19
+ import json
20
+ import time
21
+ import requests
22
+
23
+
24
+ logging.basicConfig()
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ PROFANITY_RAW = '[ __ ]' # How YouTube transcribes profanity
29
+ PROFANITY_CONVERTED = '*****' # Safer version for tokenizing
30
+
31
+
32
+ NUM_DECIMALS = 3
33
+
34
+ # https://www.fincher.org/Utilities/CountryLanguageList.shtml
35
+ # https://lingohub.com/developers/supported-locales/language-designators-with-regions
36
+ LANGUAGE_PREFERENCE_LIST = ['en-GB', 'en-US', 'en-CA', 'en-AU', 'en-NZ', 'en-ZA',
37
+ 'en-IE', 'en-IN', 'en-JM', 'en-BZ', 'en-TT', 'en-PH', 'en-ZW',
38
+ 'en']
39
+
40
+
41
+ def parse_transcript_json(json_data, granularity):
42
+ assert json_data['wireMagic'] == 'pb3'
43
+
44
+ assert granularity in ('word', 'chunk')
45
+
46
+ # TODO remove bracketed words?
47
+ # (kiss smacks)
48
+ # (upbeat music)
49
+ # [text goes here]
50
+
51
+ # Some manual transcripts aren't that well formatted... but do have punctuation
52
+ # https://www.youtube.com/watch?v=LR9FtWVjk2c
53
+
54
+ parsed_transcript = []
55
+
56
+ events = json_data['events']
57
+
58
+ for event_index, event in enumerate(events):
59
+ segments = event.get('segs')
60
+ if not segments:
61
+ continue
62
+
63
+ # This value is known (when phrase appears on screen)
64
+ start_ms = event['tStartMs']
65
+ total_characters = 0
66
+
67
+ new_segments = []
68
+ for seg in segments:
69
+ # Replace \n, \t, etc. with space
70
+ text = ' '.join(seg['utf8'].split())
71
+
72
+ # Remove zero-width spaces and strip trailing and leading whitespace
73
+ text = text.replace('\u200b', '').replace('\u200c', '').replace(
74
+ '\u200d', '').replace('\ufeff', '').strip()
75
+
76
+ # Alternatively,
77
+ # text = text.encode('ascii', 'ignore').decode()
78
+
79
+ # Needed for auto-generated transcripts
80
+ text = text.replace(PROFANITY_RAW, PROFANITY_CONVERTED)
81
+
82
+ if not text:
83
+ continue
84
+
85
+ offset_ms = seg.get('tOffsetMs', 0)
86
+
87
+ new_segments.append({
88
+ 'text': text,
89
+ 'start': round((start_ms + offset_ms)/1000, NUM_DECIMALS)
90
+ })
91
+
92
+ total_characters += len(text)
93
+
94
+ if not new_segments:
95
+ continue
96
+
97
+ if event_index < len(events) - 1:
98
+ next_start_ms = events[event_index + 1]['tStartMs']
99
+ total_event_duration_ms = min(
100
+ event.get('dDurationMs', float('inf')), next_start_ms - start_ms)
101
+ else:
102
+ total_event_duration_ms = event.get('dDurationMs', 0)
103
+
104
+ # Ensure duration is non-negative
105
+ total_event_duration_ms = max(total_event_duration_ms, 0)
106
+
107
+ avg_seconds_per_character = (
108
+ total_event_duration_ms/total_characters)/1000
109
+
110
+ num_char_count = 0
111
+ for seg_index, seg in enumerate(new_segments):
112
+ num_char_count += len(seg['text'])
113
+
114
+ # Estimate segment end
115
+ seg_end = seg['start'] + \
116
+ (num_char_count * avg_seconds_per_character)
117
+
118
+ if seg_index < len(new_segments) - 1:
119
+ # Do not allow longer than next
120
+ seg_end = min(seg_end, new_segments[seg_index+1]['start'])
121
+
122
+ seg['end'] = round(seg_end, NUM_DECIMALS)
123
+ parsed_transcript.append(seg)
124
+
125
+ final_parsed_transcript = []
126
+ for i in range(len(parsed_transcript)):
127
+
128
+ word_level = granularity == 'word'
129
+ if word_level:
130
+ split_text = parsed_transcript[i]['text'].split()
131
+ elif granularity == 'chunk':
132
+ # Split on space after punctuation
133
+ split_text = re.split(
134
+ r'(?<=[.!?,-;])\s+', parsed_transcript[i]['text'])
135
+ if len(split_text) == 1:
136
+ split_on_whitespace = parsed_transcript[i]['text'].split()
137
+
138
+ if len(split_on_whitespace) >= 8: # Too many words
139
+ # Rather split on whitespace instead of punctuation
140
+ split_text = split_on_whitespace
141
+ else:
142
+ word_level = True
143
+ else:
144
+ raise ValueError('Unknown granularity')
145
+
146
+ segment_end = parsed_transcript[i]['end']
147
+ if i < len(parsed_transcript) - 1:
148
+ segment_end = min(segment_end, parsed_transcript[i+1]['start'])
149
+
150
+ segment_duration = segment_end - parsed_transcript[i]['start']
151
+
152
+ num_chars_in_text = sum(map(len, split_text))
153
+
154
+ num_char_count = 0
155
+ current_offset = 0
156
+ for s in split_text:
157
+ num_char_count += len(s)
158
+
159
+ next_offset = (num_char_count/num_chars_in_text) * segment_duration
160
+
161
+ word_start = round(
162
+ parsed_transcript[i]['start'] + current_offset, NUM_DECIMALS)
163
+ word_end = round(
164
+ parsed_transcript[i]['start'] + next_offset, NUM_DECIMALS)
165
+
166
+ # Make the reasonable assumption that min wps is 1.5
167
+ final_parsed_transcript.append({
168
+ 'text': s,
169
+ 'start': word_start,
170
+ 'end': min(word_end, word_start + 1.5) if word_level else word_end
171
+ })
172
+ current_offset = next_offset
173
+
174
+ return final_parsed_transcript
175
+
176
+
177
+ def list_transcripts(video_id):
178
+ try:
179
+ return YouTubeTranscriptApi.list_transcripts(video_id)
180
+ except json.decoder.JSONDecodeError:
181
+ return None
182
+
183
+
184
+ WORDS_TO_REMOVE = [
185
+ CustomTokens.MUSIC.value,
186
+ CustomTokens.APPLAUSE.value,
187
+ CustomTokens.LAUGHTER.value
188
+ ]
189
+
190
+
191
+ @lru_cache(maxsize=16)
192
+ def get_words(video_id, process=True, transcript_type='auto', fallback='manual', filter_words_to_remove=True, download=False, granularity='word'):
193
+ """Get parsed video transcript with caching system
194
+ returns None if not processed yet and process is False
195
+ """
196
+ # NOTE: granularity='chunk' should only be used for generating training data... nowhere else
197
+
198
+ transcript_path = os.path.join( # TODO use relative path to this
199
+ 'transcripts', transcript_type, f'{video_id}.json')
200
+
201
+ raw_transcript_json = None
202
+ try:
203
+ if not download and os.path.exists(transcript_path): # Load from file
204
+ with open(transcript_path) as fp:
205
+ raw_transcript_json = json.load(fp) # May be empty
206
+
207
+ elif process:
208
+ transcript_list = list_transcripts(video_id)
209
+
210
+ if transcript_list is not None:
211
+ if transcript_type == 'manual':
212
+ ts = transcript_list.find_manually_created_transcript(
213
+ LANGUAGE_PREFERENCE_LIST)
214
+ else:
215
+ ts = transcript_list.find_generated_transcript(
216
+ LANGUAGE_PREFERENCE_LIST)
217
+ raw_transcript = ts._http_client.get(
218
+ f'{ts._url}&fmt=json3').content
219
+ if raw_transcript:
220
+ raw_transcript_json = json.loads(raw_transcript)
221
+
222
+ except (TooManyRequests, YouTubeRequestFailed):
223
+ raise # Cannot recover from these errors and do not mark as empty transcript
224
+
225
+ except requests.exceptions.RequestException: # Can recover
226
+ time.sleep(10) # Timeout
227
+ return get_words(video_id, process, transcript_type, fallback, filter_words_to_remove, download, granularity)
228
+
229
+ except CouldNotRetrieveTranscript: # Retrying won't solve
230
+ pass # Mark as empty transcript
231
+
232
+ except json.decoder.JSONDecodeError:
233
+ logger.warning(f'JSONDecodeError for {video_id}')
234
+ if os.path.exists(transcript_path):
235
+ os.remove(transcript_path) # Remove file and try again
236
+ return get_words(video_id, process, transcript_type, fallback, filter_words_to_remove, download, granularity)
237
+
238
+ # Cache the result (even if empty) so future calls do not re-fetch it
239
+ if download or (process and not os.path.exists(transcript_path)):
240
+ with open(transcript_path, 'w') as fp:
241
+ json.dump(raw_transcript_json, fp)
242
+
243
+ if not raw_transcript_json and fallback is not None:
244
+ return get_words(video_id, process, transcript_type=fallback, fallback=None, granularity=granularity)
245
+
246
+ if raw_transcript_json:
247
+ processed_transcript = parse_transcript_json(
248
+ raw_transcript_json, granularity)
249
+ if filter_words_to_remove:
250
+ processed_transcript = list(
251
+ filter(lambda x: x['text'] not in WORDS_TO_REMOVE, processed_transcript))
252
+ else:
253
+ processed_transcript = raw_transcript_json # Either None or []
254
+
255
+ return processed_transcript
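+
+ # Editor's note: a minimal usage sketch (hypothetical video ID). On the first
+ # call the transcript is fetched and cached to disk; later calls reuse the cache.
+ # words = get_words('dQw4w9WgXcQ')
+ # if words:
+ #     print(words[0])  # e.g. {'text': 'never', 'start': 18.8, 'end': 19.07}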
256
+
257
+
258
+ # TODO make min_sponsor_segment_length param
259
+ # TODO rename to extract_segments
260
+ def extract_sponsors(words, min_sponsor_segment_length=3):
261
+ if not words:
262
+ return []
263
+
264
+ paragraphs = []
265
+ current = []
266
+ prev_category = None
267
+
268
+ for i in range(len(words) + 1):
269
+ unimportant = i == len(words) or words[i].get('category') is None
270
+
271
+ if unimportant or words[i].get('category') != prev_category:
272
+ if current: # Save the current batch
273
+ paragraphs.append({
274
+ 'words': current,
275
+ 'category': current[-1].get('category'),
276
+ })
277
+
278
+ current = []
279
+
280
+ if not unimportant: # Some useful information to save
281
+ current.append(words[i])
282
+ prev_category = words[i].get('category')
283
+
284
+ # Remove all too short:
285
+ return list(filter(lambda x: len(x['words']) >= min_sponsor_segment_length, paragraphs))
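+
+ # Editor's note: a small worked example. Consecutive words sharing a category
+ # are grouped; groups shorter than min_sponsor_segment_length are dropped.
+ # words = [{'text': 'use', 'category': 'sponsor'},
+ #          {'text': 'code', 'category': 'sponsor'},
+ #          {'text': 'XYZ', 'category': 'sponsor'},
+ #          {'text': 'anyway'}]
+ # extract_sponsors(words)
+ # -> [{'words': [...the three sponsor words...], 'category': 'sponsor'}]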
286
+
287
+
288
+ def clean_text(text):
289
+
290
+ # Replace impossibly long words with a special token
291
+ # Usually the result of incorrect labelling
292
+ text = re.sub(r'\w{64,}', CustomTokens.LONG_WORD.value, text)
293
+
294
+ SHORT_HYPHENATED_REGEX = r'\w{1,2}(?:-\w{1,2}){3,}(?:-?\w*)'
295
+
296
+ # Replace hyphenated URLs with special token
297
+ # For some reason, youtube sometimes transcribes urls in this form:
298
+ # 'b-a-b-b-e-l-dot-com', 'g-e-t-r-o-m-a-n-com'
299
+ # not 'e-commerce'
300
+ text = re.sub(f'{SHORT_HYPHENATED_REGEX}(?:com|org|net)',
301
+ CustomTokens.HYPHENATED_URL.value, text)
302
+
303
+ # Replace short+hyphenated text with a special token. Of the form:
304
+ # 'i-i-i-i-i-i-i-i-i-i-i-i', 'b-u-m-f-u-z-z-l-e', 'v-e-r-i-t-a-s-i-u-m', 'do-do-do-do-do'
305
+ text = re.sub(SHORT_HYPHENATED_REGEX,
306
+ CustomTokens.SHORT_HYPHENATED.value, text)
307
+
308
+ # Replace URLs with URL_TOKEN
309
+ URL_REGEX = r'(?:(?:http|https)\:\/\/)?[a-zA-Z0-9\.\/\?\:@\-_=#]+\.(?:[a-zA-Z]){2,6}(?:[a-zA-Z0-9\.\&\/\?\:@\-_=#%])*'
310
+ text = re.sub(URL_REGEX, CustomTokens.URL.value, text)
311
+
312
+ NUM_REGEX = r'(?:\d+,)*(?:\d*[.])?\d+'
313
+
314
+ # Encode specific numeric words
315
+ # Of the form: 12%, 12.34%
316
+ # Usually included in sponsorships
317
+ text = re.sub(f'{NUM_REGEX}%',
318
+ CustomTokens.NUMBER_PERCENTAGE.value, text)
319
+
320
+ # Normal numbers, should not have an effect on sponsorship
321
+ text = re.sub(NUM_REGEX, CustomTokens.NUMBER.value, text)
322
+
323
+ # Replace profanity with special token
324
+ text = text.replace(PROFANITY_RAW, CustomTokens.PROFANITY.value)
325
+ text = text.replace(PROFANITY_CONVERTED, CustomTokens.PROFANITY.value)
326
+
327
+ return text.strip()
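+
+ # Editor's note: an example of the transformations above (checked against the regexes):
+ # clean_text('visit example.com and save 20% on 3 items')
+ # -> 'visit URL_TOKEN and save NUMBER_PERCENTAGE_TOKEN on NUMBER_TOKEN items'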
328
+
329
+
330
+ def remove_duplicate_segments(segments):
331
+ # Algorithm based on SponsorBlock algorithm
332
+ # https://blog.ajay.app/voting-and-pseudo-randomness-or-sponsorblock-or-youtube-sponsorship-segment-blocker
333
+ # Find sponsors that are overlapping
334
+
335
+ best = []
336
+ for i in segments:
337
+ similar_segments = []
338
+ for j in segments:
339
+ if jaccard(i['start'], i['end'], j['start'], j['end']) > 0.1: # Some overlap
340
+ similar_segments.append(j)
341
+
342
+ if similar_segments:
343
+ best_similar_seg = max(similar_segments, key=lambda item: (
344
+ item['locked'],
345
+ item['votes'],
346
+ item['views'],
347
+ item['reputation']
348
+ ))
349
+ if best_similar_seg not in best:
350
+ best.append(best_similar_seg)
351
+
352
+ if len(segments) != len(best): # Saw some reduction... try again
353
+ return remove_duplicate_segments(best)
354
+
355
+ return best
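+
+ # Editor's note: `jaccard` is imported from elsewhere in the project; a minimal
+ # sketch consistent with how it is called above (interval intersection-over-union):
+ # def jaccard(x1, x2, y1, y2):
+ #     intersection = max(0, min(x2, y2) - max(x1, y1))
+ #     union = max(x2, y2) - min(x1, y1)
+ #     return intersection / union if union > 0 else 0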
356
+
357
+
358
+ @dataclass
359
+ class PreprocessArguments:
360
+ """
361
+ Arguments pertaining to what data we are going to preprocess.
362
+ """
363
+ update_database: bool = field(
364
+ default=False, metadata={'help': 'Download the raw database.'}
365
+ )
366
+
367
+ do_create: bool = field(
368
+ default=False, metadata={'help': 'Merge sponsor segments into single file'}
369
+ )
370
+
371
+ min_votes: int = field(
372
+ default=0, metadata={'help': 'Minimum number of votes'})
373
+ # Downvotes will make this negative.
374
+ # 1 = At least one positive vote
375
+
376
+ max_segment_duration: float = field(
377
+ default=180, # 3 minutes
378
+ # >180 => 2.8%
379
+ # >200 => 2.1%
380
+ # >250 => 1.1%
381
+ # >300 => 0.06%
382
+ metadata={'help': 'Ignore all segments whose duration in seconds is longer than this value (negative means no limit)'})
383
+
384
+ min_views: int = field(
385
+ default=5, metadata={'help': 'Minimum number of views a segment must have to be considered. 0 = show all'})
386
+
387
+ # min_reputation: int = field(
388
+ # default=0, metadata={'help': 'Minimum reputation a user must have for the segment to be included'})
389
+
390
+ min_date: str = field(
391
+ # default='08/06/2020', # release of v2.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/2.0)
392
+ # release of v3.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/3.0)
393
+ default='20/08/2021',
394
+ # default='01/10/2020', # No more autovote
395
+ metadata={'help': 'Only use submissions from after this date (inclusive)'})
396
+
397
+ max_date: str = field(
398
+ # default='01/01/9999', # Include all
399
+ default='15/04/2022',
400
+ metadata={'help': 'Only use videos that have some segment from before this date (exclusive). This allows for videos to have segments be corrected, but ignores new videos (posted after this date) to enter the pool.'})
401
+
402
+ # max_unseen_date: str = field( # TODO
403
+ # default='02/03/2022',
404
+ # metadata={'help': 'Generate test and validation data from `max_date` to `max_unseen_date`'})
405
+ # Specify min/max video id for splitting (seen vs. unseen)
406
+
407
+ keep_duplicate_segments: bool = field(
408
+ default=False, metadata={'help': 'Keep duplicate segments'}
409
+ )
410
+
411
+ do_process_database: bool = field(
412
+ default=False, metadata={'help': 'Process the raw database'}
413
+ )
414
+ do_transcribe: bool = field(
415
+ default=False, metadata={'help': 'Get transcripts for videos'}
416
+ )
417
+ num_jobs: int = field(
418
+ default=4, metadata={'help': 'Number of transcripts to download in parallel'})
419
+
420
+ # overwrite: bool = field(
421
+ # default=False, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
422
+ # )
423
+
424
+ do_generate: bool = field(
425
+ default=False, metadata={'help': 'Generate labelled data.'}
426
+ )
427
+
428
+ do_split: bool = field(
429
+ default=False, metadata={'help': 'Generate training, testing and validation data.'}
430
+ )
431
+
432
+ positive_file: Optional[str] = field(
433
+ default='sponsor_segments.json', metadata={'help': 'File to output sponsored segments to (a jsonlines file).'}
434
+ )
435
+ negative_file: Optional[str] = field(
436
+ default='normal_segments.json', metadata={'help': 'File to output normal segments to (a jsonlines file).'}
437
+ )
438
+
439
+ percentage_positive: float = field(
440
+ default=0.5, metadata={'help': 'Ratio of positive (sponsor) segments to include in final output'})
441
+
442
+ train_split: float = field(
443
+ default=0.9, metadata={'help': 'Ratio of training data. Value between 0 and 1.'})
444
+
445
+ # TODO play around with ratios? lower test/validation split?
446
+ test_split: float = field(
447
+ default=0.05, metadata={'help': 'Ratio of testing data. Value between 0 and 1.'})
448
+ valid_split: float = field(
449
+ default=0.05, metadata={'help': 'Ratio of validation data. Value between 0 and 1.'})
450
+
451
+ start_index: int = field(default=None, metadata={
452
+ 'help': 'Video to start at.'})
453
+
454
+ max_videos: int = field(default=None, metadata={
455
+ 'help': 'Maximum number of videos to preprocess.'})
456
+
457
+ max_segments: int = field(default=None, metadata={
458
+ 'help': 'Maximum number of segments to produce during preprocessing.'})
459
+
460
+ raw_data_dir: Optional[str] = field(
461
+ default='raw',
462
+ metadata={
463
+ 'help': 'Raw data directory'
464
+ },
465
+ )
466
+ raw_data_file: Optional[str] = field(
467
+ default='sponsorTimes.csv',
468
+ metadata={
469
+ 'help': 'Raw data file'
470
+ },
471
+ )
472
+
473
+ min_wps: float = field(
474
+ default=1.5, metadata={'help': 'Ignore videos with not enough words spoken per second. This is usually indicative of videos whose captions aren\'t in English.'})
475
+ # 0.1 ~ 1%
476
+ # 0.4 ~ 2.5%
477
+ # 0.9 ~ 5%
478
+
479
+
480
+ # Mirrors for database
481
+ MIRRORS = [
482
+ 'https://sponsor.ajay.app/database/sponsorTimes.csv', # Latest
483
+ 'https://sb-mirror.mchang.xyz/sponsorTimes.csv', # 5 minute delay
484
+ 'https://sb.ltn.fi/database/sponsorTimes.csv', # 5 minute delay
485
+ ]
486
+ # TODO only download latest updates/changes
487
+
488
+
489
+ def download_file(url, filename):
490
+ """
491
+ Helper method handling downloading large files from `url` to `filename`.
492
+
493
+ Adapted from https://stackoverflow.com/a/42071418
494
+ """
495
+ chunk_size = 1024
496
+ r = requests.get(url, stream=True)
497
+ total_bytes = int(r.headers['Content-Length'])
498
+ with open(filename, 'wb') as f, tqdm(unit='B', total=total_bytes) as progress:
499
+ for chunk in r.iter_content(chunk_size=chunk_size):
500
+ if chunk: # filter out keep-alive new chunks
501
+ progress.update(len(chunk))
502
+ f.write(chunk)
503
+
504
+ return total_bytes == os.path.getsize(filename)
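+
+ # Editor's note: intended usage, e.g.
+ # ok = download_file(MIRRORS[0], 'raw/sponsorTimes.csv')  # True if fully downloaded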
505
+
506
+
507
+ def main():
508
+ # Responsible for getting transcripts using youtube_transcript_api,
509
+ # then labelling it according to SponsorBlock's API
510
+ logger.setLevel(logging.DEBUG)
511
+
512
+ # Generate final.json from sponsorTimes.csv
513
+ hf_parser = HfArgumentParser((
514
+ PreprocessArguments,
515
+ DatasetArguments,
516
+ segment.SegmentationArguments,
517
+ model_module.ModelArguments,
518
+ GeneralArguments
519
+ ))
520
+ preprocess_args, dataset_args, segmentation_args, model_args, general_args = hf_parser.parse_args_into_dataclasses()
521
+
522
+ raw_dataset_path = os.path.join(
523
+ preprocess_args.raw_data_dir, preprocess_args.raw_data_file)
524
+
525
+ if preprocess_args.update_database:
526
+ logger.info('Updating database')
527
+ for mirror in MIRRORS:
528
+ logger.info(f'Downloading from {mirror}')
529
+ if download_file(mirror, raw_dataset_path):
530
+ break
531
+ logger.warning('Failed, trying next')
532
+
533
+ os.makedirs(dataset_args.data_dir, exist_ok=True)
534
+ processed_db_path = os.path.join(
535
+ dataset_args.data_dir, dataset_args.processed_database)
536
+
537
+ # TODO process all valid possible items and then do filtering only later
538
+ @lru_cache(maxsize=1)
539
+ def read_db():
540
+ # if not preprocess_args.overwrite and os.path.exists(processed_db_path):
541
+ # logger.info(
542
+ # 'Using cached processed database (use `--overwrite` to avoid this behaviour).')
543
+ # with open(processed_db_path) as fp:
544
+ # return json.load(fp)
545
+ logger.info('Processing raw database')
546
+ db = {}
547
+
548
+ allowed_categories = list(map(str.lower, CATGEGORY_OPTIONS))
549
+ with open(raw_dataset_path, newline='') as csvfile:
550
+ reader = csv.DictReader(csvfile)
551
+
552
+ for line in reader:
553
+
554
+ # Never show:
555
+ if line['service'] != 'YouTube':
556
+ continue
557
+ if len(line['videoID']) != 11:
558
+ continue # Invalid youtube video ID
559
+
560
+ if line['category'] not in allowed_categories:
561
+ continue
562
+ if line['actionType'] not in ACTION_OPTIONS:
563
+ continue
564
+
565
+ # Ignore hidden items
566
+ if line['hidden'] == '1' or line['shadowHidden'] == '1':
567
+ continue
568
+
569
+ # Skip those that aren't highly voted
570
+ votes = int(line['votes'])
571
+ if votes < preprocess_args.min_votes:
572
+ continue
573
+
574
+ locked = line['locked'] == '1'
575
+
576
+ reputation = float(line['reputation'])
577
+ # if reputation < preprocess_args.min_reputation:
578
+ # continue # TODO add back?
579
+ # Problems like mGVn1wCkBrE
580
+
581
+ # TODO ignore if over max_duration
582
+
583
+ if line['videoID'] not in db:
584
+ db[line['videoID']] = []
585
+
586
+ db[line['videoID']].append({
587
+ 'uuid': line['UUID'],
588
+ 'start': float(line['startTime']),
589
+ 'end': float(line['endTime']),
590
+ 'votes': votes,
591
+ 'locked': locked,
592
+ 'views': int(line['views']),
593
+ 'submission_time': float(line['timeSubmitted'])/1e3,
594
+ 'reputation': reputation,
595
+ 'category': line['category'],
596
+ 'action': line['actionType'],
597
+ })
598
+
599
+ # First, remove videos that contain a full-video label
600
+ # (may confuse model since disclaimers and such aren't labelled)
601
+ # Must do it here before removing duplicate segments
602
+ for key in list(db):
603
+ if any(x['action'] == 'full' for x in db[key]):
604
+ del db[key]
605
+
606
+ # Remove duplicate sponsor segments by choosing best (most votes)
607
+ if not preprocess_args.keep_duplicate_segments:
608
+ logger.info('Remove duplicate segments')
609
+ for key in db:
610
+ db[key] = remove_duplicate_segments(db[key])
611
+
612
+ # We now remove whole videos from the list
613
+ # Helps with obtaining "fully-labelled" videos
614
+ min_date = datetime.strptime(preprocess_args.min_date, '%d/%m/%Y')
615
+ max_date = datetime.strptime(preprocess_args.max_date, '%d/%m/%Y')
616
+ for key in list(db):
617
+ if preprocess_args.max_segment_duration >= 0 and any(x['end'] - x['start'] > preprocess_args.max_segment_duration for x in db[key]):
618
+ # Remove videos that have at least one segment that is longer than
619
+ # the maximum allowed segment duration. This avoids introducing
620
+ # segments into training that might contain ignored context (since
621
+ # they are too long, so the middle might be normal content)
622
+ del db[key]
623
+ elif any(datetime.fromtimestamp(x['submission_time']) < min_date for x in db[key]):
624
+ # Remove videos where any of its segments were submitted before min_date
625
+ # (essentially removes videos uploaded before min_date)
626
+ # Prevents issues where some segments of a video are excluded
627
+ del db[key]
628
+ elif all(datetime.fromtimestamp(x['submission_time']) > max_date for x in db[key]):
629
+ # Remove videos where all of its segments were submitted after max_date
630
+ # (essentially removes videos uploaded after max_date)
631
+ # Allows for segments to be corrected for past videos
632
+ del db[key]
633
+ elif any(not x['locked'] and x['views'] < preprocess_args.min_views for x in db[key]):
634
+ # Remove videos where any of its non-locked segments do not have enough views
635
+ # (essentially skips videos that have not been fully watched/reviewed)
636
+ # Always include segments locked by VIPs, regardless of view count
637
+ del db[key]
638
+
639
+ logger.info(f'Saved {len(db)} videos')
640
+
641
+ with open(processed_db_path, 'w') as fp:
642
+ json.dump(db, fp)
643
+
644
+ return db
645
+
646
+ if preprocess_args.do_process_database:
647
+ read_db()
648
+
649
+ # 'videoID', 'startTime', 'endTime', 'votes', 'locked', 'incorrectVotes', 'UUID',
650
+ # 'userID', 'timeSubmitted', 'views', 'category', 'actionType', 'service', 'videoDuration',
651
+ # 'hidden', 'reputation', 'shadowHidden', 'hashedVideoID', 'userAgent', 'description'
652
+ if preprocess_args.do_transcribe:
653
+ logger.info('Collecting videos')
654
+ parsed_database = read_db()
655
+
656
+ # Remove transcripts already processed
657
+ finished = set(x.split('.')[0] for x in os.listdir(
658
+ 'transcripts/auto/') + os.listdir('transcripts/manual/'))
659
+
660
+ video_ids = list(parsed_database.keys() - finished)
661
+
662
+ # https://stackoverflow.com/a/63495323
663
+ import concurrent.futures
664
+ POLL_INTERVAL = 0.1
665
+
666
+ # Wrap get words function to return video_id after completion
667
+ def get_words_wrapper(video_id):
668
+ get_words(video_id)
669
+ return video_id
670
+
671
+ logger.info('Setting up ThreadPoolExecutor')
672
+ with concurrent.futures.ThreadPoolExecutor(max_workers=preprocess_args.num_jobs) as pool, \
673
+ tqdm(total=len(video_ids)) as progress:
674
+
675
+ all_futures = (pool.submit(get_words_wrapper, video_id)
676
+ for video_id in video_ids)
677
+ to_process = set(itertools.islice(
678
+ all_futures, preprocess_args.num_jobs))
679
+ try:
680
+ while to_process:
681
+ just_finished, to_process = concurrent.futures.wait(
682
+ to_process, timeout=POLL_INTERVAL)
683
+ to_process |= set(itertools.islice(
684
+ all_futures, len(just_finished)))
685
+
686
+ for d in just_finished:
687
+ progress.set_description(f'Processed {d.result()}')
688
+ progress.update()
689
+
690
+ except KeyboardInterrupt:
691
+ logger.info(
692
+ 'Gracefully shutting down: Cancelling unscheduled tasks')
693
+
694
+ # only futures that are not done will prevent exiting
695
+ for future in to_process:
696
+ future.cancel()
697
+
698
+ logger.info('Waiting for in-progress tasks to complete')
699
+ concurrent.futures.wait(to_process, timeout=None)
700
+ logger.info('Cancellation successful')
701
+
702
+ final_path = os.path.join(
703
+ dataset_args.data_dir, dataset_args.processed_file)
704
+
705
+ if preprocess_args.do_create:
706
+ logger.info('Create final data')
707
+
708
+ final_data = {}
709
+
710
+ parsed_database = read_db()
711
+
712
+ transcribed = set(x.split('.')[0] for x in os.listdir(
713
+ 'transcripts/auto/') + os.listdir('transcripts/manual/'))
714
+
715
+ # Only consider videos that have been transcribed already
716
+ video_ids = parsed_database.keys() & transcribed
717
+
718
+ with tqdm(total=len(video_ids)) as progress:
719
+ for index, video_id in enumerate(video_ids):
720
+ if preprocess_args.max_videos is not None and index >= preprocess_args.max_videos:
721
+ break
722
+ progress.set_description(f'Processing {video_id}')
723
+ progress.update()
724
+
725
+ video_words = get_words(video_id, process=False)
726
+ if not video_words:
727
+ continue
728
+
729
+ final_vid_segs = []
730
+ # Only add segments with high enough wps
731
+ for seg in parsed_database[video_id]:
732
+ segment_words = segment.extract_segment(
733
+ video_words, seg['start'], seg['end'])
734
+
735
+ if len(segment_words) <= 1:
736
+ continue # Useless to add segment since no words
737
+
738
+ # duration = segment.word_end(segment_words[-1]) - segment.word_start(segment_words[0])
739
+ duration = seg['end'] - seg['start']
740
+ wps = len(segment_words)/duration if duration > 0 else 0
741
+
742
+ # print(video_id, wps)
743
+ if wps < preprocess_args.min_wps:
744
+ # Skip sponsor segments without many words
745
+ # e.g. music ads with some words on each side
746
+ # progress.set_description(f'Skipping bad segment in {video_id} (wps={wps})')
747
+ continue
748
+ final_vid_segs.append(seg)
749
+
750
+ if final_vid_segs:
751
+ final_data[video_id] = final_vid_segs
752
+
753
+ # Save data
754
+ with open(final_path, 'w') as fp:
755
+ json.dump(final_data, fp)
756
+
757
+ # final_data = preprocess(
758
+ # raw_dataset_path, final_path, preprocess_args.min_votes)
759
+ # # TODO save metadata in final.json?
760
+
761
+ elif os.path.exists(final_path):
762
+ # Already exists
763
+ logger.info(f'{final_path} exists, opening file')
764
+ with open(final_path) as fp:
765
+ final_data = json.load(fp)
766
+ logger.info(f'Found {len(final_data)} videos')
767
+ else:
768
+ return # Do not continue
769
+
770
+ # TODO shuffle final_data
771
+ # if not os.path.exists(excess_path) or preprocess_args.overwrite
772
+ # TODO use overwrite param
773
+
774
+ positive_file = os.path.join(
775
+ dataset_args.data_dir, preprocess_args.positive_file)
776
+ negative_file = os.path.join(
777
+ dataset_args.data_dir, preprocess_args.negative_file)
778
+
779
+ if preprocess_args.do_generate:
780
+ logger.info('Generating')
781
+ # max_videos=preprocess_args.max_videos,
782
+ # max_segments=preprocess_args.max_segments,
783
+ # , max_videos, max_segments
784
+
785
+ from model import get_model_tokenizer
786
+ model, tokenizer = get_model_tokenizer(model_args, general_args)
787
+
788
+ # TODO
789
+ # count_videos = 0
790
+ # count_segments = 0
791
+
792
+ data = final_data.items()
793
+
794
+ start_index = preprocess_args.start_index or 0
795
+ end_index = (preprocess_args.max_videos or len(data)) + start_index
796
+
797
+ data = list(itertools.islice(data, start_index, end_index))
798
+
799
+ write_mode = 'w' # if preprocess_args.overwrite else 'a'
800
+ with open(positive_file, write_mode, encoding='utf-8') as positive, \
801
+ open(negative_file, write_mode, encoding='utf-8') as negative, \
802
+ tqdm(data) as progress:
803
+
804
+ for offset, (video_id, sponsor_segments) in enumerate(data):
805
+
806
+ progress.set_description(f'Processing {video_id}')
807
+ progress.update()
808
+
809
+ # Use chunk granularity to improve manual transcripts
810
+ words = get_words(video_id, process=False, granularity='chunk')
811
+ if not words:
812
+ continue
813
+
814
+ if len(words) <= 1:
815
+ continue
816
+
817
+ segments = segment.generate_labelled_segments(
818
+ words, tokenizer, segmentation_args, sponsor_segments)
819
+
820
+ if not segments:
821
+ continue
822
+
823
+ for seg in segments:
824
+ seg_start = segment.word_start(seg[0])
825
+ seg_end = segment.word_end(seg[-1])
826
+ duration = seg_end - seg_start
827
+ wps = len(seg)/duration if duration > 0 else 0
828
+
829
+ # Ignore segments with "not enough words" in the transcript
830
+ # Must do here since this includes non-sponsor segments
831
+ if wps < preprocess_args.min_wps:
832
+ continue
833
+
834
+ d = {
835
+ # 'video_index': offset + start_index,
836
+ 'video_id': video_id,
837
+ # 'uuid': video_id, # TODO add uuid
838
+ 'text': ' '.join(x['cleaned'] for x in seg),
839
+ 'start': seg_start,
840
+ 'end': seg_end,
841
+ }
842
+
843
+ extracted_segments = extract_sponsors(seg)
844
+ if extracted_segments:
845
+ extracted_texts = []
846
+ for s in extracted_segments:
847
+ w = ' '.join(q['cleaned'] for q in s['words'])
848
+ category = s['category'].upper()
849
+ extracted_texts.append(
850
+ f'{START_SEGMENT_TEMPLATE.format(category)} {w} {END_SEGMENT_TEMPLATE.format(category)}'
851
+ )
852
+
853
+ d['extracted'] = f' {CustomTokens.BETWEEN_SEGMENTS.value} '.join(
854
+ extracted_texts)
855
+ print(json.dumps(d), file=positive)
856
+
857
+ else:
858
+ d['extracted'] = CustomTokens.NO_SEGMENT.value
859
+ print(json.dumps(d), file=negative)
860
+
861
+ if preprocess_args.do_split:
862
+ logger.info('Splitting')
863
+ logger.info('Read files')
864
+
865
+ with open(positive_file, encoding='utf-8') as positive:
866
+ sponsors = positive.readlines()
867
+
868
+ with open(negative_file, encoding='utf-8') as negative:
869
+ non_sponsors = negative.readlines()
870
+
871
+ logger.info('Shuffle')
872
+ random.shuffle(sponsors)
873
+ random.shuffle(non_sponsors)
874
+
875
+ logger.info('Calculate ratios')
876
+ # Ensure correct ratio of positive to negative segments
877
+ percentage_negative = 1 - preprocess_args.percentage_positive
878
+
879
+ if percentage_negative * len(sponsors) > preprocess_args.percentage_positive * len(non_sponsors):
880
+ # Negative is limiting
881
+ z = int(preprocess_args.percentage_positive /
882
+ percentage_negative * len(non_sponsors))
883
+
884
+ # excess = sponsors[z:]
885
+ sponsors = sponsors[:z]
886
+
887
+ else:
888
+ # Positive is limiting
889
+ z = int(percentage_negative /
890
+ preprocess_args.percentage_positive * len(sponsors))
891
+
892
+ # excess = non_sponsors[z:]
893
+ non_sponsors = non_sponsors[:z]
894
+
895
+ logger.info('Join')
896
+ all_labelled_segments = sponsors + non_sponsors
897
+
898
+ random.shuffle(all_labelled_segments)
899
+
900
+ # TODO split based on video ids
901
+ logger.info('Split')
902
+ ratios = [preprocess_args.train_split,
903
+ preprocess_args.test_split,
904
+ preprocess_args.valid_split]
905
+
906
+ train_data, test_data, valid_data = split(
907
+ all_labelled_segments, ratios)
908
+
909
+ splits = {
910
+ dataset_args.train_file: train_data,
911
+ dataset_args.test_file: test_data,
912
+ dataset_args.validation_file: valid_data
913
+ }
914
+
915
+ # Output training, testing and validation data
916
+ for name, items in splits.items():
917
+ outfile = os.path.join(dataset_args.data_dir, name)
918
+ with open(outfile, 'w', encoding='utf-8') as fp:
919
+ fp.writelines(items)
920
+
921
+ classifier_splits = {
922
+ dataset_args.c_train_file: train_data,
923
+ dataset_args.c_test_file: test_data,
924
+ dataset_args.c_validation_file: valid_data
925
+ }
926
+
927
+ none_category = CATEGORIES.index(None)
928
+
929
+ # Output training, testing and validation data
930
+ for name, items in classifier_splits.items():
931
+ outfile = os.path.join(dataset_args.data_dir, name)
932
+ with open(outfile, 'w', encoding='utf-8') as fp:
933
+ for item in items:
934
+ parsed_item = json.loads(item) # TODO add uuid
935
+
936
+ matches = extract_sponsor_matches_from_text(
937
+ parsed_item['extracted'])
938
+
939
+ if matches:
940
+ for match in matches:
941
+ print(json.dumps({
942
+ 'text': match['text'],
943
+ 'label': CATEGORIES.index(match['category'])
944
+ }), file=fp)
945
+ else:
946
+ print(json.dumps({
947
+ 'text': parsed_item['text'],
948
+ 'label': none_category
949
+ }), file=fp)
950
+
951
+ logger.info('Write')
952
+ # Save excess items
953
+ # excess_path = os.path.join(
954
+ # dataset_args.data_dir, dataset_args.excess_file)
955
+ # if not os.path.exists(excess_path) or preprocess_args.overwrite:
956
+ # with open(excess_path, 'w', encoding='utf-8') as fp:
957
+ # fp.writelines(excess)
958
+ # else:
959
+ # logger.info(f'Skipping {dataset_args.excess_file}')
960
+
961
+ logger.info(
962
+ f'Finished splitting: {len(sponsors)} sponsors, {len(non_sponsors)} non sponsors')
963
+
964
+
965
+ def split(arr, ratios):
966
+ """Split array according to ratios. Sum of ratios should be <= 1"""
967
+ to_return = []
968
+
969
+ cumulative_sum = 0
970
+ for r in ratios:
971
+ current = cumulative_sum
972
+ cumulative_sum += r * len(arr)
973
+ to_return.append(arr[int(current):int(cumulative_sum)])
974
+
975
+ return to_return
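+
+ # Editor's note: e.g. split(list(range(10)), [0.6, 0.2, 0.2])
+ # -> [[0, 1, 2, 3, 4, 5], [6, 7], [8, 9]]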
976
+
977
+
978
+ if __name__ == '__main__':
979
+ main()
src/segment.py ADDED
@@ -0,0 +1,166 @@
1
+ import preprocess
2
+ from dataclasses import dataclass, field
3
+
4
+
5
+ @dataclass
6
+ class SegmentationArguments:
7
+ pause_threshold: int = field(default=2.5, metadata={
8
+ 'help': 'When the time between words is greater than pause threshold, force into a new segment'})
9
+
10
+
11
+ def get_overlapping_chunks_of_tokens(tokens, size, overlap):
12
+ for i in range(0, len(tokens), size-overlap+1):
13
+ yield tokens[i:i+size]
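+
+ # Editor's note, for example:
+ # list(get_overlapping_chunks_of_tokens(list(range(10)), 5, 2))
+ # -> [[0, 1, 2, 3, 4], [4, 5, 6, 7, 8], [8, 9]]
+ # i.e. with a step of size-overlap+1, consecutive chunks share overlap-1 tokens.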
14
+
15
+
16
+ # Generate up to SAFETY_TOKENS_PERCENTAGE*max_tokens tokens
17
+ MIN_SAFETY_TOKENS = 8
18
+ SAFETY_TOKENS_PERCENTAGE = 0.9765625
19
+ # e.g. 512 -> 500, 768 -> 750
20
+
21
+
22
+ # TODO play around with this?
23
+ OVERLAP_TOKEN_PERCENTAGE = 0.5 # 0.25
24
+
25
+
26
+ def add_labels_to_words(words, sponsor_segments):
27
+
28
+ for sponsor_segment in sponsor_segments:
29
+ for w in extract_segment(words, sponsor_segment['start'], sponsor_segment['end']):
30
+ w['category'] = sponsor_segment['category']
31
+
32
+ return words
33
+
34
+
35
+ def generate_labelled_segments(words, tokenizer, segmentation_args, sponsor_segments):
36
+ segments = generate_segments(words, tokenizer, segmentation_args)
37
+
38
+ labelled_segments = list(
39
+ map(lambda x: add_labels_to_words(x, sponsor_segments), segments))
40
+
41
+ return labelled_segments
42
+
43
+
44
+ def word_start(word):
45
+ return word['start']
46
+
47
+
48
+ def word_end(word):
49
+ return word.get('end', word['start'])
50
+
51
+
52
+ def generate_segments(words, tokenizer, segmentation_args):
53
+
54
+ cleaned_words_list = []
55
+ for w in words:
56
+ w['cleaned'] = preprocess.clean_text(w['text'])
57
+ cleaned_words_list.append(w['cleaned'])
58
+
59
+ # Get lengths of tokenized words
60
+ num_tokens_list = tokenizer(cleaned_words_list, add_special_tokens=False,
61
+ truncation=True, return_attention_mask=False, return_length=True).length
62
+
63
+ first_pass_segments = []
64
+ for index, (word, num_tokens) in enumerate(zip(words, num_tokens_list)):
65
+ word['num_tokens'] = num_tokens
66
+
67
+ # Add new segment
68
+ if index == 0 or word_start(words[index]) - word_end(words[index-1]) >= segmentation_args.pause_threshold:
69
+ first_pass_segments.append([word])
70
+
71
+ else: # Add to current segment
72
+ first_pass_segments[-1].append(word)
73
+
74
+ max_q_size = round(SAFETY_TOKENS_PERCENTAGE * tokenizer.model_max_length)
75
+
76
+ buffer_size = OVERLAP_TOKEN_PERCENTAGE*max_q_size # tokenizer.model_max_length
77
+
78
+ # In second pass, we split those segments if too big
79
+ second_pass_segments = []
80
+
81
+ for segment in first_pass_segments:
82
+ current_segment_num_tokens = 0
83
+ current_segment = []
84
+ after_split_segments = []
85
+ for word in segment:
86
+ new_seg = current_segment_num_tokens + \
87
+ word['num_tokens'] >= max_q_size
88
+ if new_seg:
89
+ # Adding this token would make it have too many tokens
90
+ # We save this batch and create new
91
+ after_split_segments.append(current_segment)
92
+
93
+ # Add tokens to current segment
94
+ current_segment.append(word)
95
+ current_segment_num_tokens += word['num_tokens']
96
+
97
+ if not new_seg:
98
+ continue
99
+
100
+ # Just created a new segment, so we remove until we only have buffer_size tokens
101
+ last_index = 0
102
+ while current_segment_num_tokens > buffer_size and current_segment:
103
+ current_segment_num_tokens -= current_segment[last_index]['num_tokens']
104
+ last_index += 1
105
+
106
+ current_segment = current_segment[last_index:]
107
+
108
+ if current_segment: # Add remaining segment
109
+ after_split_segments.append(current_segment)
110
+
111
+ # TODO if len(after_split_segments) > 1, a split occurred
112
+
113
+ second_pass_segments.extend(after_split_segments)
114
+
115
+ # Cleaning up, delete 'num_tokens' from each word
116
+ for word in words:
117
+ word.pop('num_tokens', None)
118
+
119
+ return second_pass_segments
120
+
121
+
122
+ def extract_segment(words, start, end, map_function=None):
123
+ """Extracts all words with time in [start, end]"""
124
+ if words is None:
125
+ words = []
126
+
127
+ a = max(binary_search_below(words, 0, len(words), start), 0)
128
+ b = min(binary_search_above(words, -1, len(words) - 1, end) + 1, len(words))
129
+
130
+ to_transform = map_function is not None and callable(map_function)
131
+
132
+ return [
133
+ map_function(words[i]) if to_transform else words[i] for i in range(a, b)
134
+ ]
135
+
136
+
137
+ def avg(*items):
138
+ return sum(items)/len(items)
139
+
140
+
141
+ def binary_search_below(transcript, start_index, end_index, time):
142
+ if start_index >= end_index:
143
+ return end_index
144
+
145
+ middle_index = (start_index + end_index) // 2
146
+ middle = transcript[middle_index]
147
+ middle_time = avg(word_start(middle), word_end(middle))
148
+
149
+ if time <= middle_time:
150
+ return binary_search_below(transcript, start_index, middle_index, time)
151
+ else:
152
+ return binary_search_below(transcript, middle_index + 1, end_index, time)
153
+
154
+
155
+ def binary_search_above(transcript, start_index, end_index, time):
156
+ if start_index >= end_index:
157
+ return end_index
158
+
159
+ middle_index = (start_index + end_index + 1) // 2
160
+ middle = transcript[middle_index]
161
+ middle_time = avg(word_start(middle), word_end(middle))
162
+
163
+ if time >= middle_time:
164
+ return binary_search_above(transcript, middle_index, end_index, time)
165
+ else:
166
+ return binary_search_above(transcript, start_index, middle_index - 1, time)
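+
+ # Editor's note: a small worked example of extract_segment:
+ # words = [{'text': 'hi', 'start': 0.0, 'end': 0.5},
+ #          {'text': 'there', 'start': 0.6, 'end': 1.0},
+ #          {'text': 'friend', 'start': 5.0, 'end': 5.5}]
+ # extract_segment(words, 0, 2)  # -> the first two words only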
src/shared.py ADDED
@@ -0,0 +1,406 @@
1
+ from transformers.trainer_utils import get_last_checkpoint as glc
2
+ import os
3
+ from utils import re_findall
4
+ import logging
5
+ import sys
6
+ from datasets import load_dataset
7
+ import re
8
+ import gc
9
+ from time import time_ns
10
+ import random
11
+ import numpy as np
12
+ import torch
13
+ from typing import Optional
14
+ from dataclasses import dataclass, field
15
+ from enum import Enum
16
+
17
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # Setup logging
22
+ logging.basicConfig(
23
+ format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
24
+ datefmt='%m/%d/%Y %H:%M:%S',
25
+ handlers=[logging.StreamHandler(sys.stdout)],
26
+ )
27
+
28
+ CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']
29
+
30
+ ACTION_OPTIONS = ['skip', 'mute', 'full']
31
+
32
+ CATGEGORY_OPTIONS = {
33
+ 'SPONSOR': 'Sponsor',
34
+ 'SELFPROMO': 'Self/unpaid promo',
35
+ 'INTERACTION': 'Interaction reminder',
36
+ }
37
+
38
+ START_SEGMENT_TEMPLATE = 'START_{}_TOKEN'
39
+ END_SEGMENT_TEMPLATE = 'END_{}_TOKEN'
40
+
41
+
42
+ class CustomTokens(Enum):
43
+ EXTRACT_SEGMENTS_PREFIX = 'EXTRACT_SEGMENTS: '
44
+
45
+ # Preprocessing tokens
46
+ URL = 'URL_TOKEN'
47
+ HYPHENATED_URL = 'HYPHENATED_URL_TOKEN'
48
+ NUMBER_PERCENTAGE = 'NUMBER_PERCENTAGE_TOKEN'
49
+ NUMBER = 'NUMBER_TOKEN'
50
+
51
+ SHORT_HYPHENATED = 'SHORT_HYPHENATED_TOKEN'
52
+ LONG_WORD = 'LONG_WORD_TOKEN'
53
+
54
+ # Custom YouTube tokens
55
+ MUSIC = '[Music]'
56
+ APPLAUSE = '[Applause]'
57
+ LAUGHTER = '[Laughter]'
58
+
59
+ PROFANITY = 'PROFANITY_TOKEN'
60
+
61
+ # Segment tokens
62
+ NO_SEGMENT = 'NO_SEGMENT_TOKEN'
63
+
64
+ START_SPONSOR = START_SEGMENT_TEMPLATE.format('SPONSOR')
65
+ END_SPONSOR = END_SEGMENT_TEMPLATE.format('SPONSOR')
66
+
67
+ START_SELFPROMO = START_SEGMENT_TEMPLATE.format('SELFPROMO')
68
+ END_SELFPROMO = END_SEGMENT_TEMPLATE.format('SELFPROMO')
69
+
70
+ START_INTERACTION = START_SEGMENT_TEMPLATE.format('INTERACTION')
71
+ END_INTERACTION = END_SEGMENT_TEMPLATE.format('INTERACTION')
72
+
73
+ BETWEEN_SEGMENTS = 'BETWEEN_SEGMENTS_TOKEN'
74
+
75
+ @classmethod
76
+ def custom_tokens(cls):
77
+ return [e.value for e in cls]
78
+
79
+ @classmethod
80
+ def add_custom_tokens(cls, tokenizer):
81
+ tokenizer.add_tokens(cls.custom_tokens())
82
+
83
+
84
+ _SEGMENT_START = START_SEGMENT_TEMPLATE.format(r'(?P<category>\w+)')
85
+ _SEGMENT_END = END_SEGMENT_TEMPLATE.format(r'\w+')
86
+ SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'
87
+
88
+
89
+ def extract_sponsor_matches_from_text(text):
90
+ if CustomTokens.NO_SEGMENT.value in text:
91
+ return []
92
+ else:
93
+ return re_findall(SEGMENT_MATCH_RE, text)
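+
+ # Editor's note: assuming re_findall returns one dict of named groups per match,
+ # extract_sponsor_matches_from_text(
+ #     'START_SPONSOR_TOKEN this video is sponsored by X END_SPONSOR_TOKEN')
+ # -> [{'category': 'SPONSOR', 'text': 'this video is sponsored by X'}]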
94
+
95
+
96
+ def extract_sponsor_matches(texts):
97
+ return list(map(extract_sponsor_matches_from_text, texts))
98
+
99
+
100
+ @dataclass
101
+ class DatasetArguments:
102
+ data_dir: Optional[str] = field(
103
+ default='data',
104
+ metadata={
105
+ 'help': 'The directory which stores train, test and/or validation data.'
106
+ },
107
+ )
108
+ processed_file: Optional[str] = field(
109
+ default='segments.json',
110
+ metadata={
111
+ 'help': 'Processed data file'
112
+ },
113
+ )
114
+ processed_database: Optional[str] = field(
115
+ default='processed_database.json',
116
+ metadata={
117
+ 'help': 'Processed database file'
118
+ },
119
+ )
120
+
121
+ overwrite_cache: bool = field(
122
+ default=False, metadata={'help': 'Overwrite the cached training and evaluation sets'}
123
+ )
124
+
125
+ dataset_cache_dir: Optional[str] = field(
126
+ default=None,
127
+ metadata={
128
+ 'help': 'Where to store the cached datasets'
129
+ },
130
+ )
131
+
132
+ train_file: Optional[str] = field(
133
+ default='train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
134
+ )
135
+ validation_file: Optional[str] = field(
136
+ default='valid.json',
137
+ metadata={
138
+ 'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
139
+ },
140
+ )
141
+ test_file: Optional[str] = field(
142
+ default='test.json',
143
+ metadata={
144
+ 'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
145
+ },
146
+ )
147
+
148
+ c_train_file: Optional[str] = field(
149
+ default='c_train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
150
+ )
151
+ c_validation_file: Optional[str] = field(
152
+ default='c_valid.json',
153
+ metadata={
154
+ 'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
155
+ },
156
+ )
157
+ c_test_file: Optional[str] = field(
158
+ default='c_test.json',
159
+ metadata={
160
+ 'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
161
+ },
162
+ )
163
+
164
+ def __post_init__(self):
165
+ if self.train_file is None or self.validation_file is None:
166
+ raise ValueError(
167
+ 'Need either a dataset name or a training/validation file.')
168
+
169
+ else:
170
+ train_extension = self.train_file.split(".")[-1]
171
+ assert train_extension in [
172
+ "csv", "json"], "`train_file` should be a csv or a json file."
173
+ validation_extension = self.validation_file.split(".")[-1]
174
+ assert (
175
+ validation_extension == train_extension
176
+ ), "`validation_file` should have the same extension (csv or json) as `train_file`."
177
+
178
+
179
+ @dataclass
180
+ class OutputArguments:
181
+
182
+ output_dir: str = field(
183
+ default='out',
184
+ metadata={
185
+ 'help': 'The output directory where the model predictions and checkpoints will be written to and read from.'
186
+ },
187
+ )
188
+ checkpoint: Optional[str] = field(
189
+ default=None,
190
+ metadata={
191
+ 'help': 'Choose the checkpoint/model to train from or test with. Defaults to the latest checkpoint found in `output_dir`.'
192
+ },
193
+ )
194
+ models_dir: str = field(
195
+ default='models',
196
+ metadata={
197
+ 'help': 'The directory where models are saved to and loaded from.'
198
+ },
199
+ )
200
+ # classifier_dir: str = field(
201
+ # default='out',
202
+ # metadata={
203
+ # 'help': 'The output directory where the model predictions and checkpoints will be written to and read from.'
204
+ # },
205
+ # )
206
+
207
+
208
+ def seed_factory():
209
+ return time_ns() % (2**32 - 1)
210
+
211
+
212
+ @dataclass
213
+ class GeneralArguments:
214
+ seed: Optional[int] = field(default_factory=seed_factory, metadata={
215
+ 'help': 'Set seed for deterministic training and testing. By default, the current time is used, giving effectively random behaviour.'
216
+ })
217
+ no_cuda: bool = field(default=False, metadata={
218
+ 'help': 'Do not use CUDA even when it is available'})
219
+
220
+ def __post_init__(self):
221
+ random.seed(self.seed)
222
+ np.random.seed(self.seed)
223
+ torch.manual_seed(self.seed)
224
+ torch.cuda.manual_seed_all(self.seed)
225
+
226
+
227
+ def seconds_to_time(seconds, remove_leading_zeroes=False):
228
+ fractional = round(seconds % 1, 3)
229
+ fractional = '' if fractional == 0 else str(fractional)[1:]
230
+ h, remainder = divmod(abs(int(seconds)), 3600)
231
+ m, s = divmod(remainder, 60)
232
+ hms = f'{h:02}:{m:02}:{s:02}'
233
+ if remove_leading_zeroes:
234
+ hms = re.sub(r'^0(?:0:0?)?', '', hms)
235
+ return f"{'-' if seconds < 0 else ''}{hms}{fractional}"
236
+
237
+
238
+ def reset():
239
+ torch.clear_autocast_cache()
240
+ torch.cuda.empty_cache()
241
+ gc.collect()
242
+ print(torch.cuda.memory_summary(device=None, abbreviated=False))
243
+
244
+
245
+ def load_datasets(dataset_args: DatasetArguments):
246
+
247
+ logger.info('Reading datasets')
248
+ data_files = {}
249
+
250
+ if dataset_args.train_file is not None:
251
+ data_files['train'] = os.path.join(
252
+ dataset_args.data_dir, dataset_args.train_file)
253
+ if dataset_args.validation_file is not None:
254
+ data_files['validation'] = os.path.join(
255
+ dataset_args.data_dir, dataset_args.validation_file)
256
+ if dataset_args.test_file is not None:
257
+ data_files['test'] = os.path.join(
258
+ dataset_args.data_dir, dataset_args.test_file)
259
+
260
+ return load_dataset('json', data_files=data_files, cache_dir=dataset_args.dataset_cache_dir)
261
+
262
+
263
+ @dataclass
264
+ class AdditionalTrainingArguments:
265
+ seed: Optional[int] = GeneralArguments.__dataclass_fields__['seed']
266
+
267
+ num_train_epochs: float = field(
268
+ default=1, metadata={'help': 'Total number of training epochs to perform.'})
269
+
270
+ save_steps: int = field(default=5000, metadata={
271
+ 'help': 'Save checkpoint every X updates steps.'})
272
+ eval_steps: int = field(default=25000, metadata={
273
+ 'help': 'Run an evaluation every X steps.'})
274
+ logging_steps: int = field(default=5000, metadata={
275
+ 'help': 'Log every X updates steps.'})
276
+
277
+ # do_eval: bool = field(default=False, metadata={
278
+ # 'help': 'Whether to run eval on the dev set.'})
279
+ # do_predict: bool = field(default=False, metadata={
280
+ # 'help': 'Whether to run predictions on the test set.'})
281
+
282
+ per_device_train_batch_size: int = field(
283
+ default=4, metadata={'help': 'Batch size per GPU/TPU core/CPU for training.'}
284
+ )
285
+ per_device_eval_batch_size: int = field(
286
+ default=4, metadata={'help': 'Batch size per GPU/TPU core/CPU for evaluation.'}
287
+ )
288
+
289
+ # report_to: Optional[List[str]] = field(
290
+ # default=None, metadata={"help": "The list of integrations to report the results and logs to."}
291
+ # )
292
+ evaluation_strategy: str = field(
293
+ default='steps',
294
+ metadata={
295
+ 'help': 'The evaluation strategy to use.',
296
+ 'choices': ['no', 'steps', 'epoch']
297
+ },
298
+ )
299
+
300
+ # evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
301
+ # The evaluation strategy to adopt during training. Possible values are:
302
+
303
+ # * :obj:`"no"`: No evaluation is done during training.
304
+ # * :obj:`"steps"`: Evaluation is done (and logged) every :obj:`eval_steps`.
305
+ # * :obj:`"epoch"`: Evaluation is done at the end of each epoch.
306
+
307
+ preprocessing_num_workers: Optional[int] = field(
308
+ default=None,
309
+ metadata={'help': 'The number of processes to use for the preprocessing.'},
310
+ )
311
+ max_seq_length: int = field(
312
+ default=512,
313
+ metadata={
314
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
315
+ "than this will be truncated, sequences shorter will be padded."
316
+ },
317
+ )
318
+ max_train_samples: Optional[int] = field(
319
+ default=None,
320
+ metadata={
321
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
322
+ "value if set."
323
+ },
324
+ )
325
+ max_eval_samples: Optional[int] = field(
326
+ default=None,
327
+ metadata={
328
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
329
+ "value if set."
330
+ },
331
+ )
332
+ max_predict_samples: Optional[int] = field(
333
+ default=None,
334
+ metadata={
335
+ "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
336
+ "value if set."
337
+ },
338
+ )
339
+
340
+
341
+ @dataclass
342
+ class CustomTrainingArguments(OutputArguments, AdditionalTrainingArguments):
343
+ pass
344
+
345
+
346
+ def get_last_checkpoint(training_args):
347
+ last_checkpoint = None
348
+ if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
349
+ last_checkpoint = glc(training_args.output_dir)
350
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
351
+ raise ValueError(
352
+ f'Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome.'
353
+ )
354
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
355
+ logger.info(
356
+ f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
357
+ )
358
+ return last_checkpoint
359
+
360
+
361
+ def train_from_checkpoint(trainer, last_checkpoint, training_args):
362
+ checkpoint = None
363
+ if training_args.resume_from_checkpoint is not None:
364
+ checkpoint = training_args.resume_from_checkpoint
365
+ elif last_checkpoint is not None:
366
+ checkpoint = last_checkpoint
367
+
368
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
369
+
370
+ trainer.save_model() # Saves the tokenizer too for easy upload
371
+
372
+ return train_result
373
+
374
+
375
+ def prepare_datasets(raw_datasets, dataset_args: DatasetArguments, training_args: CustomTrainingArguments, preprocess_function):
376
+
377
+ with training_args.main_process_first(desc="dataset map pre-processing"):
378
+ raw_datasets = raw_datasets.map(
379
+ preprocess_function,
380
+ batched=True,
381
+ load_from_cache_file=not dataset_args.overwrite_cache,
382
+ desc="Running tokenizer on dataset",
383
+ )
384
+
385
+ if 'train' not in raw_datasets:
386
+ raise ValueError('Train dataset missing')
387
+ train_dataset = raw_datasets['train']
388
+ if training_args.max_train_samples is not None:
389
+ train_dataset = train_dataset.select(
390
+ range(training_args.max_train_samples))
391
+
392
+ if 'validation' not in raw_datasets:
393
+ raise ValueError('Validation dataset missing')
394
+ eval_dataset = raw_datasets['validation']
395
+ if training_args.max_eval_samples is not None:
396
+ eval_dataset = eval_dataset.select(
397
+ range(training_args.max_eval_samples))
398
+
399
+ if 'test' not in raw_datasets:
400
+ raise ValueError('Test dataset missing')
401
+ predict_dataset = raw_datasets['test']
402
+ if training_args.max_predict_samples is not None:
403
+ predict_dataset = predict_dataset.select(
404
+ range(training_args.max_predict_samples))
405
+
406
+ return train_dataset, eval_dataset, predict_dataset
src/train.py ADDED
@@ -0,0 +1,191 @@
1
+ from shared import (
2
+ CustomTokens,
3
+ DatasetArguments,
4
+ prepare_datasets,
5
+ load_datasets,
6
+ CustomTrainingArguments,
7
+ get_last_checkpoint,
8
+ train_from_checkpoint
9
+ )
10
+ from model import ModelArguments
11
+ import transformers
12
+ import logging
13
+ import os
14
+ import sys
15
+ from datasets import utils as d_utils
16
+ from transformers import (
17
+ DataCollatorForSeq2Seq,
18
+ HfArgumentParser,
19
+ Seq2SeqTrainer,
20
+ Seq2SeqTrainingArguments,
21
+ )
22
+
23
+ from transformers.utils import check_min_version
24
+ from transformers.utils.versions import require_version
25
+ from dataclasses import dataclass
26
+
27
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
28
+ check_min_version('4.17.0')
29
+ require_version('datasets>=1.8.0',
30
+ 'To fix: pip install -r requirements.txt')
31
+
32
+ os.environ['WANDB_DISABLED'] = 'true'
33
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+ # Setup logging
38
+ logging.basicConfig(
39
+ format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
40
+ datefmt='%m/%d/%Y %H:%M:%S',
41
+ handlers=[logging.StreamHandler(sys.stdout)],
42
+ )
43
+
44
+
45
+ @dataclass
46
+ class Seq2SeqTrainingArguments(CustomTrainingArguments, Seq2SeqTrainingArguments):
47
+ pass
48
+
49
+
50
+ def main():
51
+
52
+ # See all possible arguments in src/transformers/training_args.py
53
+ # or by passing the --help flag to this script.
54
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
55
+
56
+ hf_parser = HfArgumentParser((
57
+ ModelArguments,
58
+ DatasetArguments,
59
+ Seq2SeqTrainingArguments
60
+ ))
61
+ model_args, dataset_args, training_args = hf_parser.parse_args_into_dataclasses()
62
+
63
+ log_level = training_args.get_process_log_level()
64
+ logger.setLevel(log_level)
65
+ d_utils.logging.set_verbosity(log_level)
66
+ transformers.utils.logging.set_verbosity(log_level)
67
+ transformers.utils.logging.enable_default_handler()
68
+ transformers.utils.logging.enable_explicit_format()
69
+
70
+ # Set seed before initializing model.
71
+ # set_seed(training_args.seed)
72
+
73
+ # Log on each process the small summary:
74
+ logger.warning(
75
+ f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}'
76
+ + f', distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
77
+ )
78
+ logger.info(f'Training/evaluation parameters {training_args}')
79
+
80
+ # FP16 https://github.com/huggingface/transformers/issues/9295
81
+
82
+ # Works:
83
+ # https://huggingface.co/docs/transformers/model_doc/t5v1.1
84
+ # google/t5-v1_1-small
85
+ # google/t5-v1_1-base
86
+ # google/t5-v1_1-large
87
+ # google/t5-v1_1-xl
88
+ # google/t5-v1_1-xxl
89
+
90
+ # https://huggingface.co/docs/transformers/model_doc/t5
91
+ # t5-small
92
+ # t5-base
93
+ # t5-large
94
+ # t5-3b
95
+ # t5-11b
96
+
97
+ # allenai/led-base-16384 - https://github.com/huggingface/transformers/issues/9810
98
+
99
+ # Further work:
100
+ # Multilingual- https://huggingface.co/docs/transformers/model_doc/mt5
101
+
102
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
103
+ # download the dataset.
104
+ raw_datasets = load_datasets(dataset_args)
105
+ # , cache_dir=model_args.cache_dir
106
+
107
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
108
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
109
+
110
+ # Detecting last checkpoint.
111
+ last_checkpoint = get_last_checkpoint(training_args)
112
+
113
+ from model import get_model_tokenizer
114
+ model, tokenizer = get_model_tokenizer(model_args, training_args)
115
+
116
+ # Preprocessing the datasets.
117
+ # We need to tokenize inputs and targets.
118
+
119
+ prefix = CustomTokens.EXTRACT_SEGMENTS_PREFIX.value
120
+
121
+ PAD_TOKEN_REPLACE_ID = -100
122
+
123
+ # https://github.com/huggingface/transformers/issues/5204
124
+ def preprocess_function(examples):
125
+ inputs = examples['text']
126
+ targets = examples['extracted']
127
+ inputs = [prefix + inp for inp in inputs]
128
+ model_inputs = tokenizer(inputs, truncation=True)
129
+
130
+ # Setup the tokenizer for targets
131
+ with tokenizer.as_target_tokenizer():
132
+ labels = tokenizer(targets, truncation=True)
133
+
134
+ # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100
135
+ # when we want to ignore padding in the loss.
136
+
137
+ model_inputs['labels'] = [
138
+ [(l if l != tokenizer.pad_token_id else PAD_TOKEN_REPLACE_ID)
139
+ for l in label]
140
+ for label in labels['input_ids']
141
+ ]
142
+
143
+ return model_inputs
144
+
145
+ train_dataset, eval_dataset, predict_dataset = prepare_datasets(
146
+ raw_datasets, dataset_args, training_args, preprocess_function)
147
+
148
+ # Data collator
149
+ data_collator = DataCollatorForSeq2Seq(
150
+ tokenizer,
151
+ model=model,
152
+ label_pad_token_id=PAD_TOKEN_REPLACE_ID,
153
+ pad_to_multiple_of=8 if training_args.fp16 else None,
154
+ )
155
+
156
+ # Done processing datasets
157
+
158
+ # Initialize our Trainer
159
+ trainer = Seq2SeqTrainer(
160
+ model=model,
161
+ args=training_args,
162
+ train_dataset=train_dataset,
163
+ eval_dataset=eval_dataset,
164
+ tokenizer=tokenizer,
165
+ data_collator=data_collator,
166
+ )
167
+
168
+ # Training
169
+ train_result = train_from_checkpoint(
170
+ trainer, last_checkpoint, training_args)
171
+
172
+ metrics = train_result.metrics
173
+ max_train_samples = training_args.max_train_samples or len(
174
+ train_dataset)
175
+ metrics['train_samples'] = min(max_train_samples, len(train_dataset))
176
+
177
+ trainer.log_metrics('train', metrics)
178
+ trainer.save_metrics('train', metrics)
179
+ trainer.save_state()
180
+
181
+ kwargs = {'finetuned_from': model_args.model_name_or_path,
182
+ 'tasks': 'summarization'}
183
+
184
+ if training_args.push_to_hub:
185
+ trainer.push_to_hub(**kwargs)
186
+ else:
187
+ trainer.create_model_card(**kwargs)
188
+
189
+
190
+ if __name__ == '__main__':
191
+ main()
src/train_classifier.py ADDED
@@ -0,0 +1,182 @@
+
+ """ Finetuning the library models for sequence classification."""
+
+ import logging
+ import os
+ import sys
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import datasets
+ import numpy as np
+
+ import transformers
+ from transformers import (
+     DataCollatorWithPadding,
+     EvalPrediction,
+     HfArgumentParser,
+     Trainer,
+     TrainingArguments,
+     set_seed,
+ )
+ from transformers.utils import check_min_version
+ from transformers.utils.versions import require_version
+ from shared import (
+     CATEGORIES,
+     DatasetArguments,
+     prepare_datasets,
+     load_datasets,
+     CustomTrainingArguments,
+     train_from_checkpoint,
+     get_last_checkpoint
+ )
+ from model import get_model_tokenizer, ModelArguments
+
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
+ check_min_version('4.17.0')
+ require_version('datasets>=1.8.0', 'To fix: pip install -r requirements.txt')
+
+ os.environ['WANDB_DISABLED'] = 'true'
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class ClassifierTrainingArguments(CustomTrainingArguments, TrainingArguments):
+     pass
+
+
+ @dataclass
+ class ClassifierDatasetArguments(DatasetArguments):
+     def __post_init__(self):
+         self.train_file = self.c_train_file
+         self.validation_file = self.c_validation_file
+         self.test_file = self.c_test_file
+
+
+ def main():
+     # See all possible arguments in src/transformers/training_args.py
+     # or by passing the --help flag to this script.
+     # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+     hf_parser = HfArgumentParser((
+         ModelArguments,
+         ClassifierDatasetArguments,
+         ClassifierTrainingArguments
+     ))
+     model_args, dataset_args, training_args = hf_parser.parse_args_into_dataclasses()
+
+     # Setup logging
+     logging.basicConfig(
+         format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+         datefmt='%m/%d/%Y %H:%M:%S',
+         handlers=[logging.StreamHandler(sys.stdout)],
+     )
+
+     log_level = training_args.get_process_log_level()
+     logger.setLevel(log_level)
+     datasets.utils.logging.set_verbosity(log_level)
+     transformers.utils.logging.set_verbosity(log_level)
+     transformers.utils.logging.enable_default_handler()
+     transformers.utils.logging.enable_explicit_format()
+
+     # Log the small summary on each process:
+     logger.warning(
+         f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, '
+         + f'distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
+     )
+     logger.info(f'Training/evaluation parameters {training_args}')
+
+     # Detecting last checkpoint.
+     last_checkpoint = get_last_checkpoint(training_args)
+
+     # Set seed before initializing model.
+     set_seed(training_args.seed)
+
+     # Loading a dataset from your local files.
+     # CSV/JSON training and evaluation files are needed.
+     raw_datasets = load_datasets(dataset_args)
+
+     # See more about loading any type of standard or custom dataset at
+     # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+     config_args = {
+         'num_labels': len(CATEGORIES),
+         'id2label': {k: str(v).upper() for k, v in enumerate(CATEGORIES)},
+         'label2id': {str(v).upper(): k for k, v in enumerate(CATEGORIES)}
+     }
+     model, tokenizer = get_model_tokenizer(
+         model_args, training_args, config_args=config_args, model_type='classifier')
+
+     if training_args.max_seq_length > tokenizer.model_max_length:
+         logger.warning(
+             f'The max_seq_length passed ({training_args.max_seq_length}) is larger than the maximum length for the '
+             f'model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}.'
+         )
+     max_seq_length = min(training_args.max_seq_length,
+                          tokenizer.model_max_length)
+
+     def preprocess_function(examples):
+         # Tokenize the texts
+         result = tokenizer(
+             examples['text'], padding='max_length', max_length=max_seq_length, truncation=True)
+         result['label'] = examples['label']
+         return result
+
+     train_dataset, eval_dataset, predict_dataset = prepare_datasets(
+         raw_datasets, dataset_args, training_args, preprocess_function)
+
+     # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple
+     # with predictions and label_ids fields) and has to return a dictionary mapping strings to floats.
+     def compute_metrics(p: EvalPrediction):
+         preds = p.predictions[0] if isinstance(
+             p.predictions, tuple) else p.predictions
+         preds = np.argmax(preds, axis=1)
+         return {'accuracy': (preds == p.label_ids).astype(np.float32).mean().item()}
+
+     # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer,
+     # so we only change it if we already did the padding.
+     if training_args.fp16:
+         data_collator = DataCollatorWithPadding(
+             tokenizer, pad_to_multiple_of=8)
+     else:
+         data_collator = None
+
+     # Initialize our Trainer
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=eval_dataset,
+         compute_metrics=compute_metrics,
+         tokenizer=tokenizer,
+         data_collator=data_collator,
+     )
+
+     # Training
+     train_result = train_from_checkpoint(
+         trainer, last_checkpoint, training_args)
+
+     metrics = train_result.metrics
+     max_train_samples = (
+         training_args.max_train_samples if training_args.max_train_samples is not None else len(
+             train_dataset)
+     )
+     metrics['train_samples'] = min(max_train_samples, len(train_dataset))
+
+     trainer.save_model()  # Saves the tokenizer too for easy upload
+
+     trainer.log_metrics('train', metrics)
+     trainer.save_metrics('train', metrics)
+     trainer.save_state()
+
+     kwargs = {'finetuned_from': model_args.model_name_or_path,
+               'tasks': 'text-classification'}
+     if training_args.push_to_hub:
+         trainer.push_to_hub(**kwargs)
+     else:
+         trainer.create_model_card(**kwargs)
+
+
+ if __name__ == '__main__':
+     main()
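
To make the label mappings and the accuracy metric above concrete, here is a small sketch with a stand-in CATEGORIES list (the real list lives in src/shared.py); the accuracy computation mirrors compute_metrics:

import numpy as np

CATEGORIES = ['sponsor', 'selfpromo', 'interaction']  # stand-in values
id2label = {k: str(v).upper() for k, v in enumerate(CATEGORIES)}
label2id = {str(v).upper(): k for k, v in enumerate(CATEGORIES)}
# id2label == {0: 'SPONSOR', 1: 'SELFPROMO', 2: 'INTERACTION'}

# Accuracy as in compute_metrics: argmax over logits, then mean match rate.
logits = np.array([[0.1, 0.7, 0.2],
                   [0.8, 0.1, 0.1]])
label_ids = np.array([1, 2])
preds = np.argmax(logits, axis=1)  # -> [1, 0]
accuracy = (preds == label_ids).astype(np.float32).mean().item()  # -> 0.5
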
src/utils.py ADDED
@@ -0,0 +1,135 @@
+ import re
+ import sys
+ import locale
+ import io
+
+
+ def re_findall(pattern, string):
+     return [m.groupdict() for m in re.finditer(pattern, string)]
+
+
+ def jaccard(x1, x2, y1, y2):
+     # Calculate the Jaccard index of the intervals [x1, x2] and [y1, y2],
+     # using the filled union (the whole span from min start to max end).
+     intersection = max(0, min(x2, y2) - max(x1, y1))
+     filled_union = max(x2, y2) - min(x1, y1)
+     return intersection / filled_union if filled_union > 0 else 0
+
+
+ def regex_search(text, pattern, group=1, default=None):
+     match = re.search(pattern, text)
+     return match.group(group) if match else default
+
+
+ def _windows_write_string(s, out, skip_errors=True):
+     """ Returns True if the string was written using special methods,
+     False if it has yet to be written out."""
+     # Adapted from http://stackoverflow.com/a/3259271/35070
+
+     import ctypes
+     import ctypes.wintypes
+
+     WIN_OUTPUT_IDS = {
+         1: -11,  # STD_OUTPUT_HANDLE
+         2: -12,  # STD_ERROR_HANDLE
+     }
+
+     try:
+         fileno = out.fileno()
+     except AttributeError:
+         # If the output stream doesn't have a fileno, it's virtual
+         return False
+     except io.UnsupportedOperation:
+         # Some strange Windows pseudo files?
+         return False
+     if fileno not in WIN_OUTPUT_IDS:
+         return False
+
+     GetStdHandle = ctypes.WINFUNCTYPE(
+         ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD)(
+         ('GetStdHandle', ctypes.windll.kernel32))
+     h = GetStdHandle(WIN_OUTPUT_IDS[fileno])
+
+     WriteConsoleW = ctypes.WINFUNCTYPE(
+         ctypes.wintypes.BOOL, ctypes.wintypes.HANDLE, ctypes.wintypes.LPWSTR,
+         ctypes.wintypes.DWORD, ctypes.POINTER(ctypes.wintypes.DWORD),
+         ctypes.wintypes.LPVOID)(('WriteConsoleW', ctypes.windll.kernel32))
+     written = ctypes.wintypes.DWORD(0)
+
+     GetFileType = ctypes.WINFUNCTYPE(ctypes.wintypes.DWORD, ctypes.wintypes.DWORD)(
+         ('GetFileType', ctypes.windll.kernel32))
+     FILE_TYPE_CHAR = 0x0002
+     FILE_TYPE_REMOTE = 0x8000
+     GetConsoleMode = ctypes.WINFUNCTYPE(
+         ctypes.wintypes.BOOL, ctypes.wintypes.HANDLE,
+         ctypes.POINTER(ctypes.wintypes.DWORD))(
+         ('GetConsoleMode', ctypes.windll.kernel32))
+     INVALID_HANDLE_VALUE = ctypes.wintypes.DWORD(-1).value
+
+     def not_a_console(handle):
+         if handle == INVALID_HANDLE_VALUE or handle is None:
+             return True
+         return ((GetFileType(handle) & ~FILE_TYPE_REMOTE) != FILE_TYPE_CHAR
+                 or GetConsoleMode(handle, ctypes.byref(ctypes.wintypes.DWORD())) == 0)
+
+     if not_a_console(h):
+         return False
+
+     def next_nonbmp_pos(s):
+         try:
+             return next(i for i, c in enumerate(s) if ord(c) > 0xffff)
+         except StopIteration:
+             return len(s)
+
+     while s:
+         count = min(next_nonbmp_pos(s), 1024)
+
+         ret = WriteConsoleW(
+             h, s, count if count else 2, ctypes.byref(written), None)
+         if ret == 0:
+             if skip_errors:
+                 break  # Stop (rather than retry forever) on a failed write
+             else:
+                 raise OSError('Failed to write string')
+         if not count:  # We just wrote a non-BMP character
+             assert written.value == 2
+             s = s[1:]
+         else:
+             assert written.value > 0
+             s = s[written.value:]
+     return True
+
+
+ def preferredencoding():
+     """Get preferred encoding.
+
+     Returns the best encoding scheme for the system, based on
+     locale.getpreferredencoding() and some further tweaks.
+     """
+     try:
+         pref = locale.getpreferredencoding()
+         'TEST'.encode(pref)
+     except Exception:
+         pref = 'utf-8'
+
+     return pref
+
+
+ def safe_print(*objects, sep=' ', end='\n', out=None, encoding=None, flush=False):
+     """
+     Ensure printing to standard output can be done safely (especially on Windows).
+     There are usually issues with printing emojis and non-UTF-8 characters.
+     """
+
+     output_string = sep.join(map(str, objects)) + end
+
+     if out is None:
+         out = sys.stdout
+
+     if sys.platform == 'win32' and encoding is None and hasattr(out, 'fileno'):
+         if _windows_write_string(output_string, out):
+             return
+
+     if 'b' in getattr(out, 'mode', '') or not hasattr(out, 'buffer'):
+         out.write(output_string)
+     else:
+         enc = encoding or getattr(out, 'encoding', None) or preferredencoding()
+         byt = output_string.encode(enc, 'ignore')
+         out.buffer.write(byt)
+
+     if flush and hasattr(out, 'flush'):
+         out.flush()
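
A quick usage sketch for the helpers above, with illustrative values only and assuming src/ is on sys.path:

from utils import jaccard, regex_search, re_findall, safe_print

# Overlap of segments [0s, 10s] and [5s, 15s]: intersection 5 / filled union 15
print(jaccard(0, 10, 5, 15))                             # 0.333...

# First capture group, or the default when nothing matches
print(regex_search('watch?v=dQw4w9WgXcQ', r'v=(\w+)'))   # 'dQw4w9WgXcQ'

# Named groups from every match
print(re_findall(r'(?P<num>\d+)', 'a1 b22'))             # [{'num': '1'}, {'num': '22'}]

# Prints without crashing on consoles that cannot encode emojis
safe_print('robot: \U0001F916')
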
transcripts/auto/.gitignore ADDED
@@ -0,0 +1,5 @@
+ # Cached transcripts
+ # Ignore everything in this directory
+ *
+ # Except this file
+ !.gitignore
transcripts/manual/.gitignore ADDED
@@ -0,0 +1,5 @@
+ # Cached transcripts
+ # Ignore everything in this directory
+ *
+ # Except this file
+ !.gitignore