import itertools
from typing import Optional
from datasets import load_dataset
from model import ModelArguments
import segment
from tqdm import tqdm
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from shared import GeneralArguments, CustomTokens
import csv
import re
import random
import logging
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api._errors import CouldNotRetrieveTranscript, YouTubeRequestFailed
import os
import json
import time
import requests
from utils import InterruptibleThreadPool, Job
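
# Typical invocation (illustrative; assumes this script is saved as preprocess.py and
# that the command-line flags are generated by HfArgumentParser from the dataclasses below):
#   python preprocess.py --update_database --do_transcribe --do_create --do_generate --do_split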


def find(s, ch):
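    """Return the indices of every occurrence of character ch in s, e.g. find('a b c', ' ') == [1, 3]."""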
    return [i for i, ltr in enumerate(s) if ltr == ch]


def wordify(transcript):
    """Approximate per-word timings for a manually created transcript, replicating the format of automatically generated transcripts."""
    words = []
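    # The heuristic below spreads each caption line's duration across its words in
    # proportion to character offsets: in a 2-second line 'hello world there',
    # 'hello' starts at offset 0 with roughly (5/17) * 2 s of duration, 'world'
    # starts roughly (6/17) * 2 s in, and so on.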

    for line_index, line in enumerate(transcript):
        text = line['text'].replace('\n', ' ').strip()
        if not text:
            continue

        start = line['start']
        next_start = transcript[line_index +
                                1]['start'] if line_index < len(transcript) - 1 else float('inf')
        end = min(start + line['duration'], next_start)
        duration = end - start

        indices = find(text, ' ') + [len(text)]
        start_index = 0
        for i in range(len(indices)):
            word = text[start_index:indices[i]].strip()
            if not word:
                continue  # Skip empty words (e.g., \n)
            percentage = start_index/indices[-1]

            w_duration = len(word)/indices[-1] * duration

            w_start = start + percentage * duration

            words.append({
                'start': round(w_start, 5),
                'duration': round(w_duration, 5),
                'end': round(w_start + w_duration, 5),
                'text': word,
            })

            start_index = indices[i] + 1

    return words


def get_manual_words(transcript_list):
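    """Fetch a manually created English transcript (en-GB, en-US or en) and approximate per-word timings via wordify."""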
    transcript = transcript_list.find_manually_created_transcript(
        ['en-GB', 'en-US', 'en']).fetch()
    return wordify(transcript)


def get_auto_words(transcript_list):
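    """Fetch the auto-generated English transcript in YouTube's json3 format (via private
    youtube_transcript_api attributes) and flatten it into {'start', 'text'} word dicts."""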
    words = []
    transcript = transcript_list.find_generated_transcript(['en'])
    url = transcript._url + '&fmt=json3'
    info = transcript._http_client.get(url)

    for event in info.json()['events']:
        start_ms = event.get('tStartMs', 0)

        for word in event.get('segs') or []:
            offset_ms = word.get('tOffsetMs', 0)

            texts = word['utf8'].replace(
                CustomTokens.PROFANITY_RAW.value, CustomTokens.PROFANITY_CONVERTED.value
            ).strip().split()

            for text in texts:
                words.append({
                    'start': (start_ms + offset_ms)/1000,
                    'text': text
                })

    return words


def get_words(video_id, process=True, fallback=False, transcript_type='auto'):
    """Get the parsed word-level transcript for a video, with file-based caching.

    Returns None if the transcript has not yet been processed and process is False.
    """
    get_manual_if_fail = fallback and transcript_type == 'auto'
    transcript_path = os.path.join(
        'transcripts', transcript_type, f'{video_id}.json')
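    # Cached transcripts live at transcripts/<transcript_type>/<video_id>.json; an empty
    # list is cached when a transcript could not be retrieved, so the video is not
    # refetched on later runs.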
    words = []
    try:
        if os.path.exists(transcript_path):
            with open(transcript_path) as fp:
                wds = json.load(fp)

            if not wds and get_manual_if_fail:
                return get_words(video_id, process, fallback, 'manual')
            return wds

        elif not process:
            return None

        transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)

        if transcript_type == 'manual':
            words = get_manual_words(transcript_list)
        else:
            words = get_auto_words(transcript_list)

    except YouTubeRequestFailed as e:
        print(e)
        time.sleep(30)  # Timeout
        return get_words(video_id, process, fallback, transcript_type)

    except CouldNotRetrieveTranscript:
        if get_manual_if_fail:
            print('fallback')
            return get_words(video_id, process, fallback, 'manual')

    except json.decoder.JSONDecodeError:
        # Warning, unable to parse JSON
        pass

    # Cache the result (an empty list marks a transcript that could not be retrieved)
    os.makedirs(os.path.dirname(transcript_path), exist_ok=True)
    with open(transcript_path, 'w') as fp:
        json.dump(words, fp)

    return words


# TODO expose min_sponsor_segment_length as a command-line argument
def extract_sponsors(words, min_sponsor_segment_length=5):
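    """Group consecutive sponsor-labelled words into paragraphs (lists of word texts).

    Runs shorter than min_sponsor_segment_length words are discarded, as are inputs
    with fewer than min_sponsor_segment_length words overall.
    """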
    if len(words) < min_sponsor_segment_length:
        return []  # Force short phrases to not be sponsors

    paragraphs = []
    current = []
    for word in words:
        if not word.get('sponsor') and not current:
            continue

        if word['sponsor']:
            current.append(word['text'])
        else:
            paragraphs.append(current)
            current = []
    if current:
        paragraphs.append(current)

    # Remove all too short:
    paragraphs = list(filter(lambda x: len(
        x) >= min_sponsor_segment_length, paragraphs))

    return paragraphs


def clean_text(text):
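    """Normalise transcript text by replacing noisy patterns with special tokens.

    For example, '20%' becomes the NUMBER_PERCENTAGE token and a plain URL such as
    'example.com' becomes the URL token (the actual token strings come from CustomTokens).
    """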

    # Replace impossibly long words with a special token
    # Usually the result of incorrect labelling
    text = re.sub(r'\w{64,}', CustomTokens.LONG_WORD.value, text)

    SHORT_HYPHENATED_REGEX = r'\w{1,2}(?:-\w{1,2}){3,}(?:-?\w*)'

    # Replace hyphenated URLs with a special token
    # For some reason, YouTube sometimes transcribes URLs in this form:
    # 'b-a-b-b-e-l-dot-com', 'g-e-t-r-o-m-a-n-com'
    # but not 'e-commerce'
    text = re.sub(f'{SHORT_HYPHENATED_REGEX}(?:com|org|net)',
                  CustomTokens.HYPHENATED_URL.value, text)

    # Replace short+hyphenated text with a special token. Of the form:
    # 'i-i-i-i-i-i-i-i-i-i-i-i', 'b-u-m-f-u-z-z-l-e', 'v-e-r-i-t-a-s-i-u-m', 'do-do-do-do-do'
    text = re.sub(SHORT_HYPHENATED_REGEX,
                  CustomTokens.SHORT_HYPHENATED.value, text)

    # Replace URLs with URL_TOKEN
    URL_REGEX = r'(?:(?:http|https)\:\/\/)?[a-zA-Z0-9\.\/\?\:@\-_=#]+\.(?:[a-zA-Z]){2,6}(?:[a-zA-Z0-9\.\&\/\?\:@\-_=#%])*'
    text = re.sub(URL_REGEX, CustomTokens.URL.value, text)

    NUM_REGEX = r'(?:\d+,)*(?:\d*[.])?\d+'

    # Encode specific numeric words
    # Of the form: 12%, 12.34%
    # Usually included in sponsorships
    text = re.sub(f'{NUM_REGEX}%',
                  CustomTokens.NUMBER_PERCENTAGE.value, text)

    # Normal numbers, should not have an effect on sponsorship
    text = re.sub(NUM_REGEX, CustomTokens.NUMBER.value, text)

    # Replace profanity with special token
    text = text.replace(CustomTokens.PROFANITY_RAW.value,
                        CustomTokens.PROFANITY.value)
    text = text.replace(CustomTokens.PROFANITY_CONVERTED.value,
                        CustomTokens.PROFANITY.value)

    return text.strip()


def remove_duplicate_sponsor_segments(sponsor_segments):
    """Choose the best sponsor segment if overlapping with others"""

    # Algorithm based on SponsorBlock algorithm
    # Find sponsors that are overlapping
    similar = []
    for i in sponsor_segments:
        for j in sponsor_segments:
            # Since we do pairwise, we only check one direction
            if (j['start'] >= i['start'] and j['start'] <= i['end']):
                similar.append([i, j])

    # Within each group, choose the segment with the most votes.
    processed = []
    best = []
    for i in similar:
        if i in processed:
            continue
        group = i
        for j in similar:
            if j[0] in group or j[1] in group:  # If either in, append both
                group.append(j[0])
                group.append(j[1])
                processed.append(j)

        best.append(max(group, key=lambda item: (
            item['votes'], item['reputation'], item['views'])))

    return best


@dataclass
class PreprocessArguments:
    """
    Arguments pertaining to what data we are going to preprocess.
    """
    update_database: bool = field(
        default=False, metadata={'help': 'Download the raw database.'}
    )

    do_create: bool = field(
        default=False, metadata={'help': 'Merge sponsor segments into a single file'}
    )
    min_votes: int = field(
        default=0, metadata={'help': 'Minimum number of votes'})
    # Downvotes will make this negative.
    # 1 = At least one positive vote

    do_transcribe: bool = field(
        default=False, metadata={'help': 'Get transcripts for videos'}
    )
    num_jobs: int = field(
        default=4, metadata={'help': 'Number of transcripts to download in parallel'})

    overwrite: bool = field(
        default=False, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
    )

    do_generate: bool = field(
        default=False, metadata={'help': 'Generate labelled data.'}
    )

    do_split: bool = field(
        default=False, metadata={'help': 'Generate training, testing and validation data.'}
    )
    percentage_positive: float = field(
        default=0.5, metadata={'help': 'Ratio of positive (sponsor) segments to include in final output'})

    train_split: float = field(
        default=0.9, metadata={'help': 'Ratio of training data. Value between 0 and 1.'})

    # TODO play around with ratios? lower test/validation split?
    test_split: float = field(
        default=0.05, metadata={'help': 'Ratio of testing data. Value between 0 and 1.'})
    valid_split: float = field(
        default=0.05, metadata={'help': 'Ratio of validation data. Value between 0 and 1.'})

    skip_videos: int = field(default=None, metadata={
        'help': 'Number of videos to skip. Set this to the latest video index to append to the current file'})

    max_videos: int = field(default=None, metadata={
        'help': 'Maximum number of videos to preprocess.'})

    max_segments: int = field(default=None, metadata={
        'help': 'Maximum number of segments to produce during preprocessing.'})

    raw_data_dir: Optional[str] = field(
        default='raw',
        metadata={
            'help': 'Raw data directory'
        },
    )
    raw_data_file: Optional[str] = field(
        default='sponsorTimes.csv',
        metadata={
            'help': 'Raw data file'
        },
    )

    min_wps: float = field(
        default=0.4, metadata={'help': 'Ignore videos without enough words spoken per second. This is usually indicative of videos whose captions aren\'t in English.'})
    # 0.1 ~ 1%
    # 0.4 ~ 2.5%
    # 0.9 ~ 5%


# Mirrors for database
MIRRORS = [
    'https://sponsor.ajay.app/database/sponsorTimes.csv',  # Latest
    'https://sb-mirror.mchang.xyz/sponsorTimes.csv',  # 5 minute delay
    'https://sb.ltn.fi/database/sponsorTimes.csv',  # 5 minute delay
]
# TODO only download latest (updates/changes)


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url` to `filename`.

    Adapted from https://stackoverflow.com/a/42071418
    """
    chunk_size = 1024
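    # Stream the response in 1 KiB chunks; True is returned only when the number of
    # bytes written matches the server-reported Content-Length.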
    r = requests.get(url, stream=True)
    total_bytes = int(r.headers['Content-Length'])
    with open(filename, 'wb') as f, tqdm(unit='B', total=total_bytes) as progress:
        for chunk in r.iter_content(chunk_size=chunk_size):
            if chunk:  # filter out keep-alive new chunks
                progress.update(len(chunk))
                f.write(chunk)

    return total_bytes == os.path.getsize(filename)


@dataclass
class ProcessedArguments:
    processed_dir: Optional[str] = field(
        default='processed',
        metadata={
            'help': 'Processed data directory'
        },
    )
    processed_file: Optional[str] = field(
        default='final.json',
        metadata={
            'help': 'Processed data file'
        },
    )


def load_datasets(dataset_args):
    print('Reading datasets')
    data_files = {}

    if dataset_args.train_file is not None:
        data_files['train'] = os.path.join(
            dataset_args.data_dir, dataset_args.train_file)
    if dataset_args.validation_file is not None:
        data_files['validation'] = os.path.join(
            dataset_args.data_dir, dataset_args.validation_file)
    if dataset_args.test_file is not None:
        data_files['test'] = os.path.join(
            dataset_args.data_dir, dataset_args.test_file)

    return load_dataset('json', data_files=data_files)


@dataclass
class DatasetArguments:
    data_dir: Optional[str] = field(
        default='data',
        metadata={
            'help': 'The directory which stores train, test and/or validation data.'
        },
    )

    train_file: Optional[str] = field(
        default='train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
    )
    validation_file: Optional[str] = field(
        default='valid.json',
        metadata={
            'help': 'An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines file).'
        },
    )
    test_file: Optional[str] = field(
        default='test.json',
        metadata={
            'help': 'An optional input test data file to evaluate the metrics (rouge) on (a jsonlines file).'
        },
    )
    excess_file: Optional[str] = field(
        default='excess.json',
        metadata={
            'help': 'The excess segments left after the split'
        },
    )

    overwrite_cache: bool = field(
        default=False, metadata={'help': 'Overwrite the cached training and evaluation sets'}
    )

    positive_file: Optional[str] = field(
        default='sponsor_segments.json', metadata={'help': 'File to output sponsored segments to (a jsonlines file).'}
    )
    negative_file: Optional[str] = field(
        default='normal_segments.json', metadata={'help': 'File to output normal segments to (a jsonlines file).'}
    )

    def __post_init__(self):
        # TODO check if train/validation datasets exist
        if self.train_file is None and self.validation_file is None:
            raise ValueError(
                'Need either a dataset name or a training/validation file.')


def main():
    # Responsible for downloading transcripts using youtube_transcript_api,
    # then labelling them according to the SponsorBlock database

    logging.getLogger().setLevel(logging.INFO)  # TODO make param

    # Generate final.json from sponsorTimes.csv
    hf_parser = HfArgumentParser((
        PreprocessArguments,
        ProcessedArguments,
        DatasetArguments,
        segment.SegmentationArguments,
        ModelArguments,
        GeneralArguments
    ))
    preprocess_args, processed_args, dataset_args, segmentation_args, model_args, _ = hf_parser.parse_args_into_dataclasses()

    raw_dataset_path = os.path.join(
        preprocess_args.raw_data_dir, preprocess_args.raw_data_file)

    def get_rows():
        with open(raw_dataset_path, newline='') as csvfile:
            reader = csv.DictReader(csvfile)
            for line in reader:
                if line['service'] != 'YouTube':
                    continue

                # TODO add support for other categories and action types?
                if line['category'] != 'sponsor':
                    continue
                if line['actionType'] != 'skip':
                    continue

                # Ignore hidden items
                if line['hidden'] == '1' or line['shadowHidden'] == '1':
                    continue

                if len(line['videoID']) != 11:
                    continue  # Invalid youtube video ID

                # Skip those that aren't highly voted
                line['votes'] = int(line['votes'])
                # incorrect_votes = int(line['incorrectVotes'])

                if line['votes'] < preprocess_args.min_votes:
                    continue

                yield line

    if preprocess_args.update_database:
        print('Updating database')
        for mirror in MIRRORS:
            print('Downloading from', mirror)
            if download_file(mirror, raw_dataset_path):
                break
            print('Failed, trying next')

    # 'videoID', 'startTime', 'endTime', 'votes', 'locked', 'incorrectVotes', 'UUID',
    # 'userID', 'timeSubmitted', 'views', 'category', 'actionType', 'service', 'videoDuration',
    # 'hidden', 'reputation', 'shadowHidden', 'hashedVideoID', 'userAgent', 'description'
    data_rows = None
    if preprocess_args.do_transcribe:
        print('Collecting videos')
        video_ids = set()
        data_rows = get_rows()
        for row in data_rows:
            video_ids.add(row['videoID'])

        print('Start transcribing')
        with tqdm(total=len(video_ids)) as progress:
            def on_job_complete(job):
                progress.set_description(f'Processed {job.video_id}')
                progress.update()

            pool = InterruptibleThreadPool(
                preprocess_args.num_jobs, on_job_complete=on_job_complete)

            print('Adding jobs to pool')
            for video_id in video_ids:
                job = Job(get_words, video_id)
                job.video_id = video_id
                pool.add_job(job)

            print('Start processing')
            pool.run()

        print('Finished transcribing')

    final_path = os.path.join(
        processed_args.processed_dir, processed_args.processed_file)

    if os.path.exists(final_path) and not preprocess_args.do_create:
        logging.info(f'{final_path} exists, opening file')
        with open(final_path) as fp:
            final_data = json.load(fp)
    else:
        print('Create final data')

        final_data = {}

        if data_rows is None:
            data_rows = get_rows()

        # TODO add progress bar
        # TODO parallelise?
        for line in data_rows:
            video_id = line['videoID']

            if video_id not in final_data:
                final_data[video_id] = []

            segment_start = float(line['startTime'])
            segment_end = float(line['endTime'])

            video_words = get_words(video_id, process=True)
            segment_words = segment.extract_segment(
                video_words, segment_start, segment_end)

            if len(segment_words) <= 1:
                continue  # Useless to add segment since no words

            # duration = segment.word_end(segment_words[-1]) - segment.word_start(segment_words[0])
            duration = segment_end - segment_start
            wps = len(segment_words)/duration if duration > 0 else 0

            if wps < preprocess_args.min_wps:
                print('bad segment in', video_id, '| wps =', wps)
                continue

            final_data[video_id].append({
                'start': segment_start,
                'end': segment_end,
                'votes': line['votes'],
                'locked': line['locked'] == '1',
                'views': line['views'],
                'reputation': line['reputation'],
                'category': line['category'],
                'action': line['actionType'],
                'uuid': line['UUID'],
            })

        # Remove duplicate sponsor segments by choosing best (most votes)
        for key in final_data:
            final_data[key] = remove_duplicate_sponsor_segments(
                final_data[key])

        # Save data
        with open(final_path, 'w') as fp:
            json.dump(final_data, fp)

        # final_data = preprocess(
        #     raw_dataset_path, final_path, preprocess_args.min_votes)
        # # TODO save metadata in final.json?

    logging.info(f'Found {len(final_data)} videos')

    # TODO shuffle final_data

    # if not os.path.exists(excess_path) or preprocess_args.overwrite
    # TODO use overwrite param

    os.makedirs(dataset_args.data_dir, exist_ok=True)

    positive_file = os.path.join(
        dataset_args.data_dir, dataset_args.positive_file)
    negative_file = os.path.join(
        dataset_args.data_dir, dataset_args.negative_file)

    if preprocess_args.do_generate:
        print('Generating')
        from model import get_tokenizer

        # max_videos=preprocess_args.max_videos,
        # max_segments=preprocess_args.max_segments,
        # , max_videos, max_segments

        tokenizer = get_tokenizer(model_args)

        count_videos = 0
        count_segments = 0  # TODO

        write_mode = 'w' if preprocess_args.overwrite else 'a'

        get_all = preprocess_args.max_videos is None
        if get_all:
            total = len(final_data)
        else:
            total = preprocess_args.max_videos

        index = 0
        data = final_data.items()
        if preprocess_args.skip_videos is not None:
            print('Skipping first', preprocess_args.skip_videos, 'videos')
            data = itertools.islice(data, preprocess_args.skip_videos, None)
            index = preprocess_args.skip_videos

            if get_all:
                total = max(0, total - preprocess_args.skip_videos)
            else:
                total = min(len(final_data) -
                            preprocess_args.skip_videos, total)

        with open(positive_file, write_mode, encoding='utf-8') as positive, \
                open(negative_file, write_mode, encoding='utf-8') as negative, \
                tqdm(total=total) as progress:
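            # Each segment is written as one JSON object per line with the video index
            # and id, the cleaned transcript text, words_per_second, a boolean 'sponsor'
            # flag and the extracted sponsor text (or the NO_SPONSOR token), going to
            # either the positive or the negative file.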

            for video_id, sponsor_segments in data:
                index += 1  # TODO FIX index + incrementing
                progress.set_description(f'Processing {video_id}')

                if get_all:
                    progress.update()
                elif count_videos >= preprocess_args.max_videos:
                    break

                words = get_words(video_id, False)
                if not words:
                    continue

                num_words = len(words)
                if num_words <= 1:
                    continue

                # TODO only count words that aren't [Music], [Applause], etc.

                segments = segment.generate_labelled_segments(
                    words, tokenizer, segmentation_args, sponsor_segments)

                if not segments:
                    continue

                count_videos += 1
                if not get_all:
                    progress.update()

                for seg in segments:

                    segment_text = ' '.join((x['text'] for x in seg))

                    extracted_text = ''
                    for p in extract_sponsors(seg):
                        p_text = ' '.join(p)
                        extracted_text += f'{CustomTokens.START_SPONSOR.value} {p_text} {CustomTokens.END_SPONSOR.value}. '

                    duration = segment.word_end(
                        seg[-1]) - segment.word_start(seg[0])
                    wps = len(seg)/duration if duration > 0 else 0
                    # Ignore segments with "not enough words" in the transcript
                    if wps < preprocess_args.min_wps:
                        continue

                    d = {
                        'video_index': index,
                        'video_id': video_id,
                        'text': clean_text(segment_text),
                        'words_per_second': wps,
                    }

                    d['sponsor'] = bool(extracted_text)
                    d['extracted'] = clean_text(
                        extracted_text) if d['sponsor'] else CustomTokens.NO_SPONSOR.value

                    print(json.dumps(d), file=(
                        positive if d['sponsor'] else negative))

    if preprocess_args.do_split:
        print('Splitting')
        print('Read files')

        with open(positive_file, encoding='utf-8') as positive:
            sponsors = positive.readlines()

        with open(negative_file, encoding='utf-8') as negative:
            non_sponsors = negative.readlines()

        print('Shuffle')
        random.shuffle(sponsors)
        random.shuffle(non_sponsors)

        print('Calculate ratios')
        # Ensure correct ratio of positive to negative segments
        percentage_negative = 1 - preprocess_args.percentage_positive
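        # e.g. with percentage_positive=0.5, 1000 sponsor and 4000 non-sponsor segments,
        # all 1000 sponsors and 1000 non-sponsors are kept; the remaining 3000
        # non-sponsors go to the excess file.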

        # The negatives are limiting when the negatives required to balance all the
        # sponsors exceed the negatives available
        if percentage_negative * len(sponsors) > preprocess_args.percentage_positive * len(non_sponsors):
            # Negative is limiting
            z = int(preprocess_args.percentage_positive /
                    percentage_negative * len(non_sponsors))

            excess = sponsors[z:]
            sponsors = sponsors[:z]

        else:
            # Positive is limiting
            z = int(percentage_negative /
                    preprocess_args.percentage_positive * len(sponsors))

            excess = non_sponsors[z:]
            non_sponsors = non_sponsors[:z]

        print('Join')
        all_labelled_segments = sponsors + non_sponsors

        random.shuffle(all_labelled_segments)

        print('Split')
        ratios = [preprocess_args.train_split,
                  preprocess_args.test_split,
                  preprocess_args.valid_split]

        train_data, test_data, valid_data = split(
            all_labelled_segments, ratios)

        splits = {
            dataset_args.train_file: train_data,
            dataset_args.test_file: test_data,
            dataset_args.validation_file: valid_data
        }

        # Output training, testing and validation data
        for name, items in splits.items():
            outfile = os.path.join(dataset_args.data_dir, name)
            if not os.path.exists(outfile) or preprocess_args.overwrite:
                with open(outfile, 'w', encoding='utf-8') as fp:
                    fp.writelines(items)
            else:
                print('Skipping', name)

        print('Write')
        # Save excess items
        excess_path = os.path.join(
            dataset_args.data_dir, dataset_args.excess_file)
        if not os.path.exists(excess_path) or preprocess_args.overwrite:
            with open(excess_path, 'w', encoding='utf-8') as fp:
                fp.writelines(excess)
        else:
            print('Skipping', dataset_args.excess_file)

        print('Finished splitting:', len(sponsors),
              'sponsors,', len(non_sponsors), 'non sponsors')


def split(arr, ratios):
    """Split arr into consecutive slices according to ratios. The ratios should sum to at most 1."""

    to_return = []
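    # e.g. split(list(range(20)), [0.9, 0.05, 0.05]) yields the first 18 items,
    # then the next item, then the final item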

    cumulative_sum = 0
    for r in ratios:
        current = cumulative_sum

        cumulative_sum += r * len(arr)
        to_return.append(arr[int(current):int(cumulative_sum)])

    return to_return


if __name__ == '__main__':
    main()