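# NOTE: The hosting page's status banner ("Spaces: Runtime error") was captured
# here during extraction; it is not part of the module.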
from utils import jaccard | |
from shared import START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE | |
from functools import lru_cache | |
from datetime import datetime | |
import itertools | |
from typing import Optional, List | |
from datasets import load_dataset | |
from model import ModelArguments | |
import segment | |
from tqdm import tqdm | |
from dataclasses import dataclass, field | |
from transformers import HfArgumentParser | |
from shared import GeneralArguments, CustomTokens | |
import csv | |
import re | |
import random | |
import logging | |
from youtube_transcript_api import YouTubeTranscriptApi | |
from youtube_transcript_api._errors import CouldNotRetrieveTranscript, YouTubeRequestFailed, TooManyRequests | |
import os | |
import json | |
import time | |
import requests | |
from utils import Task, InterruptibleTaskPool | |
def find(s, ch):
    """Return the positions of every element of `s` that equals `ch`."""
    positions = []
    for index, item in enumerate(s):
        if item == ch:
            positions.append(index)
    return positions
def wordify(transcript, maximum_wps=1):
    """Try to replicate format for automatically generated transcripts.

    The input only carries timing per caption line, so each line's time span
    is spread over its words in proportion to their character offsets and
    lengths within the line.

    Args:
        transcript: Sequence of caption lines, each a mapping with 'text',
            'start' and 'duration' keys.
        maximum_wps: Time allowance per word; a line with `n` words is
            considered on screen for at most `n * maximum_wps`.

    Returns:
        A list with one dict per word, holding 'start', 'duration' and 'end'
        (each rounded to 3 decimal places) plus 'text'.
    """
    words = []
    num_lines = len(transcript)

    for line_index, line in enumerate(transcript):
        text = line['text'].replace('\n', ' ').strip()
        if not text:
            continue

        # Locate each word as (character offset, word). Every space and the
        # end of the line are candidate word boundaries.
        boundaries = [i for i, ch in enumerate(text) if ch == ' ']
        boundaries.append(len(text))
        line_words = []
        start_index = 0
        for boundary in boundaries:
            word = text[start_index:boundary].strip()
            if not word:
                continue  # Skip empty words (e.g., consecutive spaces)
            line_words.append((start_index, word))
            start_index = boundary + 1

        start = line['start']
        next_start = transcript[line_index + 1]['start'] \
            if line_index < num_lines - 1 else float('inf')

        # Do not allow a line to stay on screen for too long. The allowance
        # is based on the number of words (not the number of spaces), so a
        # single-word line still receives a non-zero duration.
        longest_duration = maximum_wps * len(line_words)
        latest_end = start + longest_duration
        end = min(start + line['duration'], next_start, latest_end)

        # Guard against out-of-order lines producing negative durations
        duration = max(end - start, 0)

        text_length = len(text)  # > 0 because empty lines were skipped
        for offset, word in line_words:
            w_start = start + offset / text_length * duration
            w_duration = len(word) / text_length * duration
            words.append({
                'start': round(w_start, 3),
                'duration': round(w_duration, 3),
                'end': round(w_start + w_duration, 3),
                'text': word,
            })

    return words
def get_manual_words(transcript_list):
    """Fetch the manually created English transcript and split it into words.

    Looks up a manually created transcript for the language codes 'en-GB',
    'en-US' and 'en', fetches it and passes the result through `wordify`.
    """
    language_codes = ['en-GB', 'en-US', 'en']
    manual_transcript = transcript_list.find_manually_created_transcript(
        language_codes)
    lines = manual_transcript.fetch()
    return wordify(lines)
# Placeholder markers for censored words in transcripts.
# get_auto_words() rewrites PROFANITY_RAW to PROFANITY_CONVERTED, and
# clean_text() maps both forms onto the profanity token.
PROFANITY_RAW = '[ __ ]'  # How YouTube transcribes profanity
PROFANITY_CONVERTED = '*****'  # Safer version for tokenizing
def get_auto_words(transcript_list):
    """Fetch the auto-generated English transcript as individually timed words.

    Requests the generated transcript's URL with `&fmt=json3` appended and
    walks the 'events' / 'segs' structure of the JSON response. Each
    whitespace-separated token of a seg becomes one entry whose start time is
    the event start plus the seg offset, converted from milliseconds.

    Returns:
        A list of dicts with 'start' and 'text' keys.
    """
    generated = transcript_list.find_generated_transcript(['en'])
    # NOTE(review): relies on private attributes of the transcript object
    response = generated._http_client.get(generated._url + '&fmt=json3')

    timed_words = []
    for event in response.json()['events']:
        event_start_ms = event.get('tStartMs', 0)
        segs = event.get('segs') or []
        for seg in segs:
            seg_offset_ms = seg.get('tOffsetMs', 0)
            normalised = seg['utf8'].replace(
                PROFANITY_RAW, PROFANITY_CONVERTED)
            timed_words.extend(
                {
                    'start': (event_start_ms + seg_offset_ms)/1000,
                    'text': token
                }
                for token in normalised.strip().split()
            )
    return timed_words
def list_transcripts(video_id):
    """Return what `YouTubeTranscriptApi.list_transcripts` yields for `video_id`."""
    available = YouTubeTranscriptApi.list_transcripts(video_id)
    return available
def get_words(video_id, process=True, fallback=True, transcript_type='auto'):
    """Get parsed video transcript with caching system.

    The transcript is looked up in `transcripts/<transcript_type>/<video_id>.json`.
    If it is not cached and `process` is True it is fetched and the result is
    written to the cache, even when empty, so that videos without a transcript
    are not requested again.

    Args:
        video_id: ID of the video.
        process: Whether to fetch the transcript when it is not cached.
        fallback: When the 'auto' transcript is empty or unavailable, try the
            'manual' transcript as well.
        transcript_type: 'manual' for manually created transcripts; any other
            value selects the automatically generated one.

    Returns:
        The list of words. An empty list is returned when no transcript is
        available, or when it has not been processed yet and `process` is
        False.
    """
    get_manual_if_fail = fallback and transcript_type == 'auto'

    transcript_path = os.path.join(  # TODO use relative path to this
        'transcripts', transcript_type, f'{video_id}.json')

    words = []
    fetched = False  # Only freshly fetched results are written to the cache

    # Retry in a loop (rather than recursively) so that a long outage cannot
    # exhaust the recursion limit.
    while True:
        try:
            if os.path.exists(transcript_path):  # Load from file
                with open(transcript_path) as fp:
                    words = json.load(fp)
            elif process:
                transcript_list = list_transcripts(video_id)
                if transcript_type == 'manual':
                    words = get_manual_words(transcript_list)
                else:
                    words = get_auto_words(transcript_list)
                fetched = True

        except (TooManyRequests, YouTubeRequestFailed, requests.exceptions.ConnectionError) as e:  # Can retry
            print(e)
            time.sleep(10)  # Timeout
            continue

        except CouldNotRetrieveTranscript:
            # No transcript of this type: remember that by caching an empty list
            words = []
            fetched = True

        except json.decoder.JSONDecodeError:
            print('JSONDecodeError for', video_id)
            os.remove(transcript_path)  # Remove file and try again
            continue

        break

    if fetched:
        # Even save empty
        os.makedirs(os.path.dirname(transcript_path), exist_ok=True)
        with open(transcript_path, 'w') as fp:
            json.dump(words, fp)

    if not words and get_manual_if_fail:
        return get_words(video_id, process, fallback, 'manual')

    return words
# TODO make min_sponsor_segment_length param | |
def extract_sponsors(words, min_sponsor_segment_length=3):
    """Group consecutive words that share a category into paragraphs.

    Words whose 'category' is None separate paragraphs and are discarded.
    Paragraphs with fewer than `min_sponsor_segment_length` words are dropped.

    Returns:
        A list of dicts with 'words' (the grouped word dicts) and 'category'.
    """
    if not words:
        return []

    paragraphs = []
    runs = itertools.groupby(words, key=lambda word: word['category'])
    for category, run in runs:
        if category is None:
            continue  # Unlabelled words only act as separators
        paragraphs.append({
            'words': list(run),
            'category': category,
        })

    # Remove all too short:
    return [
        paragraph for paragraph in paragraphs
        if len(paragraph['words']) >= min_sponsor_segment_length
    ]
def clean_text(text):
    """Replace noisy parts of a transcript with special tokens.

    The substitutions run in a fixed order (later patterns operate on the
    output of earlier ones):

    1. Impossibly long words (usually the result of incorrect labelling).
    2. Hyphenated URLs, which YouTube sometimes transcribes like
       'b-a-b-b-e-l-dot-com' or 'g-e-t-r-o-m-a-n-com' (but not 'e-commerce').
    3. Other short hyphenated runs such as 'i-i-i-i-i-i',
       'b-u-m-f-u-z-z-l-e', 'v-e-r-i-t-a-s-i-u-m' or 'do-do-do-do-do'.
    4. URLs.
    5. Percentages (12%, 12.34%), which are usually included in sponsorships.
    6. Remaining numbers, which should not have an effect on sponsorship.
    7. Both profanity markers.

    Returns:
        The cleaned text with surrounding whitespace stripped.
    """
    short_hyphenated = r'\w{1,2}(?:-\w{1,2}){3,}(?:-?\w*)'
    url = r'(?:(?:http|https)\:\/\/)?[a-zA-Z0-9\.\/\?\:@\-_=#]+\.(?:[a-zA-Z]){2,6}(?:[a-zA-Z0-9\.\&\/\?\:@\-_=#%])*'
    number = r'(?:\d+,)*(?:\d*[.])?\d+'

    # (pattern, replacement) pairs; order matters
    substitutions = (
        (r'\w{64,}', CustomTokens.LONG_WORD.value),
        (f'{short_hyphenated}(?:com|org|net)', CustomTokens.HYPHENATED_URL.value),
        (short_hyphenated, CustomTokens.SHORT_HYPHENATED.value),
        (url, CustomTokens.URL.value),
        (f'{number}%', CustomTokens.NUMBER_PERCENTAGE.value),
        (number, CustomTokens.NUMBER.value),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)

    # Replace profanity with special token
    for marker in (PROFANITY_RAW, PROFANITY_CONVERTED):
        text = text.replace(marker, CustomTokens.PROFANITY.value)

    return text.strip()
def remove_duplicate_segments(segments):
    """Keep only the best segment out of each group of overlapping segments.

    Algorithm based on SponsorBlock algorithm
    https://blog.ajay.app/voting-and-pseudo-randomness-or-sponsorblock-or-youtube-sponsorship-segment-blocker

    For every segment, all segments whose jaccard score against it exceeds
    0.1 are collected, and the one ranking highest by
    (locked, votes, views, reputation) is kept.

    Returns:
        The kept segments without repeats, in order of first selection.
    """
    def ranking(candidate):
        return (
            candidate['locked'],
            candidate['votes'],
            candidate['views'],
            candidate['reputation']
        )

    kept = []
    for current in segments:
        # Find sponsors that are overlapping
        overlapping = [
            other for other in segments
            if jaccard(current['start'], current['end'],
                       other['start'], other['end']) > 0.1  # Some overlap
        ]
        if not overlapping:
            continue

        winner = max(overlapping, key=ranking)
        if winner not in kept:
            kept.append(winner)

    return kept
@dataclass
class PreprocessArguments:
    """
    Arguments pertaining to what data we are going to preprocess.

    Must be a dataclass: the `field(...)` defaults below only take effect on
    one, and `main` hands this class to `HfArgumentParser`.
    """
    update_database: bool = field(
        default=False, metadata={'help': 'Download the raw database.'}
    )
    do_create: bool = field(
        default=False, metadata={'help': 'Merge sponsor segments into single file'}
    )

    min_votes: int = field(
        default=0, metadata={'help': 'Minimum number of votes'})
    # Downvotes will make this negative.
    # 1 = At least one positive vote

    min_views: int = field(
        default=5, metadata={'help': 'Minimum number of views a segment must have to be considered. 0 = show all'})

    min_date: str = field(
        # release of v2.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/2.0)
        default='08/06/2020',
        # default='20/08/2021', # release of v3.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/3.0)
        # default='01/10/2020', # No more autovote
        metadata={'help': 'Only use submissions from after this date'})

    # TODO move?
    categories: List[str] = field(
        default_factory=lambda: ['sponsor', 'selfpromo', 'interaction'],
        metadata={
            'nargs': '+',
            # Every default category must also be selectable explicitly
            'choices': ['intro', 'sponsor', 'selfpromo', 'interaction']
            # 'outro', 'preview',
            # 'poi_highlight', 'filler', 'music_offtopic',
            # 'moreCategories'
        }
    )

    do_transcribe: bool = field(
        default=False, metadata={'help': 'Get transcripts for videos'}
    )
    num_jobs: int = field(
        default=4, metadata={'help': 'Number of transcripts to download in parallel'})

    # append: bool = field(
    #     default=False, metadata={'help': 'Append to training, testing and validation data, if present.'}
    # )

    do_generate: bool = field(
        default=False, metadata={'help': 'Generate labelled data.'}
    )

    do_split: bool = field(
        default=False, metadata={'help': 'Generate training, testing and validation data.'}
    )
    percentage_positive: float = field(
        default=0.5, metadata={'help': 'Ratio of positive (sponsor) segments to include in final output'})

    train_split: float = field(
        default=0.9, metadata={'help': 'Ratio of training data. Value between 0 and 1.'})

    # TODO play around with ratios? lower test/validation split?
    test_split: float = field(
        default=0.05, metadata={'help': 'Ratio of testing data. Value between 0 and 1.'})
    valid_split: float = field(
        default=0.05, metadata={'help': 'Ratio of validation data. Value between 0 and 1.'})

    start_index: Optional[int] = field(default=None, metadata={
        'help': 'Video to start at.'})

    max_videos: Optional[int] = field(default=None, metadata={
        'help': 'Maximum number of videos to preprocess.'})

    max_segments: Optional[int] = field(default=None, metadata={
        'help': 'Maximum number of segments to produce to preprocess.'})

    raw_data_dir: Optional[str] = field(
        default='raw',
        metadata={
            'help': 'Raw data directory'
        },
    )
    raw_data_file: Optional[str] = field(
        default='sponsorTimes.csv',
        metadata={
            'help': 'Raw data file'
        },
    )

    min_wps: float = field(
        default=1.5, metadata={'help': 'Ignore videos with not enough words spoken per second. This is usually indicitive of video whose captions aren\'t English.'})
    # 0.1 ~ 1%
    # 0.4 ~ 2.5%
    # 0.9 ~ 5%

    # Read by main() when writing the split and excess files. Declared last so
    # that the positional order of the existing fields is unchanged.
    overwrite: bool = field(
        default=False, metadata={'help': 'Overwrite training, testing, validation and excess data, if present.'}
    )
# Mirrors for database
# main() tries these in order and stops at the first successful download.
MIRRORS = [
    'https://sponsor.ajay.app/database/sponsorTimes.csv',  # Latest
    'https://sb-mirror.mchang.xyz/sponsorTimes.csv',  # 5 minute delay
    'https://sb.ltn.fi/database/sponsorTimes.csv',  # 5 minute delay
]
# TODO only download latest updates/changes | |
def download_file(url, filename, chunk_size=1024):
    """
    Helper method handling downloading large files from `url` to `filename`.
    Adapted from https://stackoverflow.com/a/42071418

    Args:
        url: Location of the file to download.
        filename: Path to write the downloaded bytes to. Its directory is
            created if necessary.
        chunk_size: Number of bytes requested per streamed chunk.

    Returns:
        True if the download completed and, when the server announced a
        Content-Length, the written file has exactly that size. False if the
        request failed (connection error or HTTP error status) or the sizes
        differ, so that the caller can move on to another mirror.
    """
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)

    try:
        # The context manager releases the connection even if writing fails
        with requests.get(url, stream=True) as r:
            r.raise_for_status()  # Do not save an error page as the database

            # Content-Length is optional (e.g. chunked transfer encoding)
            content_length = r.headers.get('Content-Length')
            total_bytes = int(content_length) \
                if content_length is not None else None

            with open(filename, 'wb') as f, tqdm(unit='B', total=total_bytes) as progress:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:  # filter out keep-alive new chunks
                        progress.update(len(chunk))
                        f.write(chunk)
    except requests.exceptions.RequestException as e:
        print(e)
        return False

    if total_bytes is None:
        return True  # Nothing to verify the size against

    return total_bytes == os.path.getsize(filename)
@dataclass
class ProcessedArguments:
    """
    Location of the processed data file that main() writes and reads.

    Must be a dataclass: the `field(...)` defaults below only take effect on
    one, and `main` hands this class to `HfArgumentParser`.
    """
    processed_dir: Optional[str] = field(
        default='processed',
        metadata={
            'help': 'Processed data directory'
        },
    )
    processed_file: Optional[str] = field(
        default='final.json',
        metadata={
            'help': 'Processed data file'
        },
    )
def load_datasets(dataset_args):
    """Load the configured train/validation/test files as a 'json' dataset.

    A split is only included when its file attribute on `dataset_args` is not
    None; each file is looked up inside `dataset_args.data_dir`.
    """
    print('Reading datasets')

    # (split name, configured file name), in the order they are registered
    configured_files = (
        ('train', dataset_args.train_file),
        ('validation', dataset_args.validation_file),
        ('test', dataset_args.test_file),
    )
    data_files = {
        split_name: os.path.join(dataset_args.data_dir, file_name)
        for split_name, file_name in configured_files
        if file_name is not None
    }

    return load_dataset('json', data_files=data_files)
@dataclass
class DatasetArguments:
    """
    Locations of the generated dataset files.

    Must be a dataclass: the `field(...)` defaults below only take effect on
    one, `__post_init__` is only invoked for one, and `main` hands this class
    to `HfArgumentParser`.
    """
    data_dir: Optional[str] = field(
        default='data',
        metadata={
            'help': 'The directory which stores train, test and/or validation data.'
        },
    )

    train_file: Optional[str] = field(
        default='train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
    )
    validation_file: Optional[str] = field(
        default='valid.json',
        metadata={
            'help': 'An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines file).'
        },
    )
    test_file: Optional[str] = field(
        default='test.json',
        metadata={
            'help': 'An optional input test data file to evaluate the metrics (rouge) on (a jsonlines file).'
        },
    )
    excess_file: Optional[str] = field(
        default='excess.json',
        metadata={
            'help': 'The excess segments left after the split'
        },
    )

    overwrite_cache: bool = field(
        default=False, metadata={'help': 'Overwrite the cached training and evaluation sets'}
    )

    positive_file: Optional[str] = field(
        default='sponsor_segments.json', metadata={'help': 'File to output sponsored segments to (a jsonlines file).'}
    )
    negative_file: Optional[str] = field(
        default='normal_segments.json', metadata={'help': 'File to output normal segments to (a jsonlines file).'}
    )

    def __post_init__(self):
        """Reject configurations with neither a training nor a validation file.

        Raises:
            ValueError: If both `train_file` and `validation_file` are None.
        """
        # TODO check if train/validation datasets exist
        if self.train_file is None and self.validation_file is None:
            raise ValueError(
                'Need either a dataset name or a training/validation file.')
def main():
    """Run the preprocessing stages selected on the command line.

    Responsible for getting transcripts using youtube_transcript_api,
    then labelling it according to SponsorBlock's API. Depending on the
    flags in `PreprocessArguments`, the stages are:

    * update_database: download the raw database from the first working mirror
    * do_transcribe: fetch and cache transcripts for the videos in the database
    * do_create: write the processed file with the segments that have enough
      words per second (otherwise an existing processed file is loaded, and
      if there is none the function returns)
    * do_generate: append labelled segments to the positive/negative files
    * do_split: balance, shuffle and split them into train/test/validation
    """
    logging.getLogger().setLevel(logging.INFO)  # TODO make param

    # Generate final.json from sponsorTimes.csv
    hf_parser = HfArgumentParser((
        PreprocessArguments,
        ProcessedArguments,
        DatasetArguments,
        segment.SegmentationArguments,
        ModelArguments,
        GeneralArguments
    ))
    preprocess_args, processed_args, dataset_args, segmentation_args, model_args, _ = hf_parser.parse_args_into_dataclasses()

    # Fall back to "do not overwrite" when PreprocessArguments does not
    # declare `overwrite`, instead of failing with an AttributeError after
    # all the splitting work is done.
    overwrite = getattr(preprocess_args, 'overwrite', False)

    raw_dataset_path = os.path.join(
        preprocess_args.raw_data_dir, preprocess_args.raw_data_file)

    if preprocess_args.update_database:
        print('Updating database')
        if preprocess_args.raw_data_dir:
            os.makedirs(preprocess_args.raw_data_dir, exist_ok=True)
        for mirror in MIRRORS:
            print('Downloading from', mirror)
            if download_file(mirror, raw_dataset_path):
                break
            print('Failed, trying next')

    def read_db():  # TODO save as file
        """Parse the raw database into {video_id: [segment, ...]}."""
        print('Parsing raw database')
        db = {}

        latest_time = datetime.strptime(preprocess_args.min_date, '%d/%m/%Y')

        with open(raw_dataset_path, newline='') as csvfile:
            reader = csv.DictReader(csvfile)

            for line in reader:
                submission_time = float(line['timeSubmitted'])/1e3

                if datetime.fromtimestamp(submission_time) < latest_time:
                    continue

                if line['service'] != 'YouTube':
                    continue
                if len(line['videoID']) != 11:
                    continue  # Invalid youtube video ID

                if line['category'] not in preprocess_args.categories:
                    continue
                if line['actionType'] != 'skip':
                    continue

                # Ignore hidden items
                if line['hidden'] == '1' or line['shadowHidden'] == '1':
                    continue

                # Skip those that aren't highly voted
                line['votes'] = int(line['votes'])
                if line['votes'] < preprocess_args.min_votes:
                    continue

                locked = line['locked'] == '1'

                # Skip segments with low views (i.e., not really reviewed)
                # Always include segments locked by VIPs, regardless of view count
                line['views'] = int(line['views'])
                if not locked and line['views'] < preprocess_args.min_views:
                    continue

                if line['videoID'] not in db:
                    db[line['videoID']] = []

                db[line['videoID']].append({
                    'uuid': line['UUID'],
                    'start': float(line['startTime']),
                    'end': float(line['endTime']),
                    'votes': line['votes'],
                    'locked': locked,
                    'views': line['views'],
                    'submission_time': submission_time,
                    # Numeric, so that the tie-break in
                    # remove_duplicate_segments does not compare strings
                    'reputation': float(line['reputation'] or 0),
                    'category': line['category'],
                    'action': line['actionType'],
                })

        num_segments = 0

        # Remove duplicate sponsor segments by choosing best (most votes)
        print('Remove duplicate segments')
        for key in db:
            db[key] = remove_duplicate_segments(db[key])
            num_segments += len(db[key])
        print('Saved', len(db), 'videos and', num_segments, 'segments')

        return db

    # 'videoID', 'startTime', 'endTime', 'votes', 'locked', 'incorrectVotes', 'UUID',
    # 'userID', 'timeSubmitted', 'views', 'category', 'actionType', 'service', 'videoDuration',
    # 'hidden', 'reputation', 'shadowHidden', 'hashedVideoID', 'userAgent', 'description'

    parsed_database = None
    if preprocess_args.do_transcribe:
        print('Collecting videos')
        parsed_database = read_db()

        # Remove transcripts already processed
        finished = set()
        for transcript_type in ('auto', 'manual'):
            transcript_dir = os.path.join('transcripts', transcript_type)
            # The directories do not exist on a first run
            os.makedirs(transcript_dir, exist_ok=True)
            finished.update(
                x.split('.')[0] for x in os.listdir(transcript_dir))

        video_ids = list(parsed_database.keys() - finished)

        # Create tasks generator
        tasks = (
            Task(get_words, video_id)
            for video_id in video_ids
        )

        print('start')
        with tqdm(total=len(video_ids)) as progress:
            def callback(task):
                progress.set_description(f'Processing {task.args[0]}')
                progress.update()

            InterruptibleTaskPool(
                tasks, preprocess_args.num_jobs, callback).start()

    final_path = os.path.join(
        processed_args.processed_dir, processed_args.processed_file)

    if preprocess_args.do_create:
        print('Create final data')

        final_data = {}

        parsed_database = read_db()

        # TODO parallelise?
        with tqdm(total=len(parsed_database)) as progress:
            for index, (video_id, segments) in enumerate(parsed_database.items()):
                if preprocess_args.max_videos is not None and index >= preprocess_args.max_videos:
                    break
                progress.set_description(f'Processing {video_id}')
                progress.update()

                final_data[video_id] = []

                video_words = get_words(video_id, process=False)
                if not video_words:
                    continue

                for seg in segments:  # Only add segments with high enough wps
                    segment_words = segment.extract_segment(
                        video_words, seg['start'], seg['end'])

                    if len(segment_words) <= 1:
                        continue  # Useless to add segment since no words

                    # duration = segment.word_end(segment_words[-1]) - segment.word_start(segment_words[0])
                    duration = seg['end'] - seg['start']
                    wps = len(segment_words)/duration if duration > 0 else 0

                    # print(video_id, wps)
                    if wps < preprocess_args.min_wps:
                        # Skip sponsor segments without many words
                        # e.g. music ads with some words on each side
                        # progress.set_description(f'Skipping bad segment in {video_id} (wps={wps})')
                        continue
                    final_data[video_id].append(seg)

        # Save data
        if processed_args.processed_dir:
            os.makedirs(processed_args.processed_dir, exist_ok=True)
        with open(final_path, 'w') as fp:
            json.dump(final_data, fp)

        # final_data = preprocess(
        #     raw_dataset_path, final_path, preprocess_args.min_votes)
        # # TODO save metadata in final.json?

    elif os.path.exists(final_path):
        # Already exists
        logging.info(f'{final_path} exists, opening file')
        with open(final_path) as fp:
            final_data = json.load(fp)
        logging.info(f'Found {len(final_data)} videos')
    else:
        return  # Do not continue

    # TODO shuffle final_data
    # if not os.path.exists(excess_path) or preprocess_args.overwrite
    # TODO use overwrite param

    os.makedirs(dataset_args.data_dir, exist_ok=True)

    positive_file = os.path.join(
        dataset_args.data_dir, dataset_args.positive_file)
    negative_file = os.path.join(
        dataset_args.data_dir, dataset_args.negative_file)

    if preprocess_args.do_generate:
        print('Generating')
        from model import get_tokenizer

        # max_videos=preprocess_args.max_videos,
        # max_segments=preprocess_args.max_segments,
        # , max_videos, max_segments

        tokenizer = get_tokenizer(model_args)

        # TODO
        # count_videos = 0
        # count_segments = 0

        data = final_data.items()

        start_index = preprocess_args.start_index or 0
        end_index = (preprocess_args.max_videos or len(data)) + start_index

        data = list(itertools.islice(data, start_index, end_index))

        # Opened in append mode: a run adds to what earlier runs produced
        with open(positive_file, 'a', encoding='utf-8') as positive, \
                open(negative_file, 'a', encoding='utf-8') as negative, \
                tqdm(data) as progress:

            for offset, (video_id, sponsor_segments) in enumerate(data):

                progress.set_description(f'Processing {video_id}')
                progress.update()

                words = get_words(video_id, process=False)
                if not words:
                    continue

                num_words = len(words)
                if num_words <= 1:
                    continue

                # TODO only count words that aren't [Music], [Applause], etc.

                segments = segment.generate_labelled_segments(
                    words, tokenizer, segmentation_args, sponsor_segments)

                if not segments:
                    continue

                for seg in segments:
                    duration = segment.word_end(
                        seg[-1]) - segment.word_start(seg[0])
                    wps = len(seg)/duration if duration > 0 else 0

                    # Ignore segments with "not enough words" in the transcript
                    # Must do here since this includes non-sponsor segments
                    if wps < preprocess_args.min_wps:
                        continue

                    d = {
                        'video_index': offset + start_index,
                        'video_id': video_id,
                        'text': clean_text(' '.join(x['text'] for x in seg)),
                        'words_per_second': round(wps, 3),
                    }

                    extracted_segments = extract_sponsors(seg)
                    if extracted_segments:
                        extracted_texts = []
                        for s in extracted_segments:
                            w = ' '.join(q['text'] for q in s['words'])
                            category = s['category'].upper()
                            extracted_texts.append(
                                f'{START_SEGMENT_TEMPLATE.format(category)} {w} {END_SEGMENT_TEMPLATE.format(category)}'
                            )

                        extracted_text = f' {CustomTokens.BETWEEN_SEGMENTS.value} '.join(
                            extracted_texts)

                        d['extracted'] = clean_text(extracted_text)
                        print(json.dumps(d), file=positive)

                    else:
                        d['extracted'] = CustomTokens.NO_SEGMENT.value
                        print(json.dumps(d), file=negative)

    if preprocess_args.do_split:
        print('Splitting')
        print('Read files')

        with open(positive_file, encoding='utf-8') as positive:
            sponsors = positive.readlines()

        with open(negative_file, encoding='utf-8') as negative:
            non_sponsors = negative.readlines()

        print('Shuffle')
        random.shuffle(sponsors)
        random.shuffle(non_sponsors)

        print('Calculate ratios')
        # Ensure correct ratio of positive to negative segments
        percentage_negative = 1 - preprocess_args.percentage_positive

        # Keeping every sponsor segment requires
        # len(sponsors) * percentage_negative / percentage_positive negatives.
        # Compare cross-multiplied to avoid dividing by zero.
        if percentage_negative * len(sponsors) > preprocess_args.percentage_positive * len(non_sponsors):
            # Negative is limiting
            z = int(preprocess_args.percentage_positive /
                    percentage_negative * len(non_sponsors))

            excess = sponsors[z:]
            sponsors = sponsors[:z]

        else:
            # Positive is limiting
            z = int(percentage_negative /
                    preprocess_args.percentage_positive * len(sponsors))

            excess = non_sponsors[z:]
            non_sponsors = non_sponsors[:z]

        print('Join')
        all_labelled_segments = sponsors + non_sponsors

        random.shuffle(all_labelled_segments)

        print('Split')
        ratios = [preprocess_args.train_split,
                  preprocess_args.test_split,
                  preprocess_args.valid_split]

        train_data, test_data, valid_data = split(
            all_labelled_segments, ratios)

        splits = {
            dataset_args.train_file: train_data,
            dataset_args.test_file: test_data,
            dataset_args.validation_file: valid_data
        }

        # Output training, testing and validation data
        for name, items in splits.items():
            outfile = os.path.join(dataset_args.data_dir, name)
            if not os.path.exists(outfile) or overwrite:
                with open(outfile, 'w', encoding='utf-8') as fp:
                    fp.writelines(items)
            else:
                print('Skipping', name)

        print('Write')
        # Save excess items
        excess_path = os.path.join(
            dataset_args.data_dir, dataset_args.excess_file)
        if not os.path.exists(excess_path) or overwrite:
            with open(excess_path, 'w', encoding='utf-8') as fp:
                fp.writelines(excess)
        else:
            print('Skipping', dataset_args.excess_file)

        print('Finished splitting:', len(sponsors),
              'sponsors,', len(non_sponsors), 'non sponsors')
def split(arr, ratios):
    """Cut `arr` into consecutive slices, one per entry of `ratios`.

    Each ratio is the fraction of `len(arr)` its slice should cover; slice
    boundaries are the running totals truncated to integers. The ratios are
    expected to sum to at most 1.
    """
    total = len(arr)
    parts = []
    lower = 0
    for ratio in ratios:
        upper = lower + ratio * total
        parts.append(arr[int(lower):int(upper)])
        lower = upper
    return parts
# Run the preprocessing pipeline only when executed as a script
if __name__ == '__main__':
    main()