hf-deepmoji / scripts /convert_all_datasets.py
cclauss
Fix flake8 issues
2c047e3
raw
history blame contribute delete
No virus
3.39 kB
from __future__ import print_function
import json
import math
import pickle
import sys
from io import open
import numpy as np
from os.path import abspath, dirname
sys.path.insert(0, dirname(dirname(abspath(__file__))))
from torchmoji.word_generator import WordGenerator
from torchmoji.create_vocab import VocabBuilder
from torchmoji.sentence_tokenizer import SentenceTokenizer, extend_vocab, coverage
from torchmoji.tokenizer import tokenize
try:
unicode # Python 2
except NameError:
unicode = str # Python 3
IS_PYTHON2 = int(sys.version[0]) == 2
DATASETS = [
'Olympic',
'PsychExp',
'SCv1',
'SCv2-GEN',
'SE0714',
#'SE1604', # Excluded due to Twitter's ToS
'SS-Twitter',
'SS-Youtube',
]
DIR = '../data'
FILENAME_RAW = 'raw.pickle'
FILENAME_OWN = 'own_vocab.pickle'
FILENAME_OUR = 'twitter_vocab.pickle'
FILENAME_COMBINED = 'combined_vocab.pickle'
def roundup(x):
return int(math.ceil(x / 10.0)) * 10
def format_pickle(dset, train_texts, val_texts, test_texts, train_labels, val_labels, test_labels):
return {'dataset': dset,
'train_texts': train_texts,
'val_texts': val_texts,
'test_texts': test_texts,
'train_labels': train_labels,
'val_labels': val_labels,
'test_labels': test_labels}
def convert_dataset(filepath, extend_with, vocab):
print('-- Generating {} '.format(filepath))
sys.stdout.flush()
st = SentenceTokenizer(vocab, maxlen)
tokenized, dicts, _ = st.split_train_val_test(texts,
labels,
[data['train_ind'],
data['val_ind'],
data['test_ind']],
extend_with=extend_with)
pick = format_pickle(dset, tokenized[0], tokenized[1], tokenized[2],
dicts[0], dicts[1], dicts[2])
with open(filepath, 'w') as f:
pickle.dump(pick, f)
cover = coverage(tokenized[2])
print(' done. Coverage: {}'.format(cover))
with open('../model/vocabulary.json', 'r') as f:
vocab = json.load(f)
for dset in DATASETS:
print('Converting {}'.format(dset))
PATH_RAW = '{}/{}/{}'.format(DIR, dset, FILENAME_RAW)
PATH_OWN = '{}/{}/{}'.format(DIR, dset, FILENAME_OWN)
PATH_OUR = '{}/{}/{}'.format(DIR, dset, FILENAME_OUR)
PATH_COMBINED = '{}/{}/{}'.format(DIR, dset, FILENAME_COMBINED)
with open(PATH_RAW, 'rb') as dataset:
if IS_PYTHON2:
data = pickle.load(dataset)
else:
data = pickle.load(dataset, fix_imports=True)
# Decode data
try:
texts = [unicode(x) for x in data['texts']]
except UnicodeDecodeError:
texts = [x.decode('utf-8') for x in data['texts']]
wg = WordGenerator(texts)
vb = VocabBuilder(wg)
vb.count_all_words()
# Calculate max length of sequences considered
# Adjust batch_size accordingly to prevent GPU overflow
lengths = [len(tokenize(t)) for t in texts]
maxlen = roundup(np.percentile(lengths, 80.0))
# Extract labels
labels = [x['label'] for x in data['info']]
convert_dataset(PATH_OWN, 50000, {})
convert_dataset(PATH_OUR, 0, vocab)
convert_dataset(PATH_COMBINED, 10000, vocab)