|
""" Finetuning example. |
|
""" |
|
from __future__ import print_function |
|
import sys |
|
import numpy as np |
|
from os.path import abspath, dirname |
|
sys.path.insert(0, dirname(dirname(abspath(__file__)))) |
|
|
|
import json |
|
import math |
|
from torchmoji.model_def import torchmoji_transfer |
|
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH |
|
from torchmoji.finetuning import ( |
|
load_benchmark, |
|
finetune) |
|
from torchmoji.class_avg_finetuning import class_avg_finetune |
|
|
|
def roundup(x):
    """Round *x* up to the nearest multiple of 10 and return it as an int."""
    tens = math.ceil(x / 10.0)  # number of full tens needed to cover x
    return int(tens) * 10
|
|
|
|
|
|
|
|
|
|
|
|
|
# Benchmark datasets to finetune on.
# Each entry is a 4-tuple: (dataset name, path to raw pickle,
# number of target classes, whether to evaluate with class-average F1
# instead of accuracy) — see the unpacking into dset/path/nb_classes/
# use_f1_score in the training loop below.
DATASETS = [
    ('SS-Youtube', '../data/SS-Youtube/raw.pickle', 2, False),
]

# Directory where per-run result files are written (must already exist).
RESULTS_DIR = 'results'

# Finetuning strategy passed to finetune()/class_avg_finetune():
# one of 'last', 'new', 'full', 'chain-thaw' (validated in the loop below).
FINETUNE_METHOD = 'last'
VERBOSE = 1

# Expected size of the pretrained vocabulary (asserted against VOCAB_PATH).
nb_tokens = 50000
# NOTE(review): nb_epochs and epoch_size appear unused in this script —
# confirm before relying on them.
nb_epochs = 1000
epoch_size = 1000

# Load the pretrained word-to-index vocabulary shipped with torchMoji.
with open(VOCAB_PATH, 'r') as f:
    vocab = json.load(f)
|
|
|
# Run every benchmark five times so result files capture run-to-run variance.
for rerun_iter in range(5):
    for p in DATASETS:
        # Sanity check: the loaded vocabulary must match the expected size.
        assert len(vocab) == nb_tokens

        # (name, pickle path, number of classes, evaluate-with-F1 flag)
        dset, path, nb_classes, use_f1_score = p

        # 'last' retrains only the final layer, so the embedding matrix is
        # left untouched; the other methods may extend the vocabulary with
        # up to 10000 new words found in the benchmark dataset.
        if FINETUNE_METHOD == 'last':
            extend_with = 0
        elif FINETUNE_METHOD in ['new', 'full', 'chain-thaw']:
            extend_with = 10000
        else:
            raise ValueError('Finetuning method not recognised!')

        # Load the benchmark dataset: tokenised texts/labels for the
        # train/val/test splits plus metadata ('batch_size', 'added').
        data = load_benchmark(path, vocab, extend_with=extend_with)

        # 'new' trains from scratch; all other methods start from the
        # pretrained torchMoji weights.
        weight_path = PRETRAINED_PATH if FINETUNE_METHOD != 'new' else None
        # Class-average F1 finetuning trains binary one-vs-rest models,
        # hence a fixed 2-class output head.
        nb_model_classes = 2 if use_f1_score else nb_classes
        model = torchmoji_transfer(
            nb_model_classes,
            weight_path,
            extend_embedding=data['added'])
        print(model)

        print('Training: {}'.format(path))
        if use_f1_score:
            model, result = class_avg_finetune(model, data['texts'],
                                               data['labels'],
                                               nb_classes, data['batch_size'],
                                               FINETUNE_METHOD,
                                               verbose=VERBOSE)
        else:
            model, result = finetune(model, data['texts'], data['labels'],
                                     nb_classes, data['batch_size'],
                                     FINETUNE_METHOD, metric='acc',
                                     verbose=VERBOSE)

        # Report the score and persist it; the file write was previously
        # duplicated verbatim in both branches — only the label differs.
        if use_f1_score:
            print('Overall F1 score (dset = {}): {}'.format(dset, result))
            metric_label = 'F1'
        else:
            print('Test accuracy (dset = {}): {}'.format(dset, result))
            metric_label = 'Acc'
        results_path = '{}/{}_{}_{}_results.txt'.format(
            RESULTS_DIR, dset, FINETUNE_METHOD, rerun_iter)
        with open(results_path, "w") as f:
            f.write("{}: {}\n".format(metric_label, result))
|
|