| """ Finetuning example. |
| """ |
| from __future__ import print_function |
| import sys |
| import numpy as np |
| from os.path import abspath, dirname |
| sys.path.insert(0, dirname(dirname(abspath(__file__)))) |
|
|
| import json |
| import math |
| from torchmoji.model_def import torchmoji_transfer |
| from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH |
| from torchmoji.finetuning import ( |
| load_benchmark, |
| finetune) |
| from torchmoji.class_avg_finetuning import class_avg_finetune |
|
|
def roundup(x):
    """Round ``x`` up to the nearest multiple of 10 (exact multiples unchanged)."""
    tens = math.ceil(x / 10.0)
    return int(tens) * 10
|
|
|
|
| |
| |
| |
| |
# Benchmark datasets to finetune on.
# Each entry: (dataset name,
#              path to the raw pickle with texts/labels,
#              number of target classes,
#              whether to evaluate with class-averaged F1 instead of accuracy).
DATASETS = [




    ('SS-Youtube', '../data/SS-Youtube/raw.pickle', 2, False),



]


# Directory where per-run result files are written.
# NOTE(review): assumed to already exist -- the open(..., "w") in the loop
# below will raise if it does not; confirm or create it up front.
RESULTS_DIR = 'results'


# Finetuning method; one of 'last', 'new', 'full', 'chain-thaw'
# (validated in the training loop below).
FINETUNE_METHOD = 'last'
# Verbosity level forwarded to the finetuning routines.
VERBOSE = 1

# Expected vocabulary size of the pretrained model (checked against the
# loaded vocab in the loop below).
nb_tokens = 50000
# NOTE(review): nb_epochs and epoch_size are defined but not referenced in
# this script -- presumably the finetuning routines apply their own limits;
# confirm before relying on these values.
nb_epochs = 1000
epoch_size = 1000

# Load the pretrained model's word-to-index vocabulary (JSON mapping).
with open(VOCAB_PATH, 'r') as f:
    vocab = json.load(f)
|
|
# Finetune on every benchmark dataset, repeating each run five times to
# average out training variance. Results are written to one text file per
# (dataset, method, rerun) combination under RESULTS_DIR.
for rerun_iter in range(5):
    for p in DATASETS:
        # The loaded vocabulary must match the pretrained model's token
        # count. Raise explicitly instead of using `assert`, which is
        # stripped when running under `python -O`.
        if len(vocab) != nb_tokens:
            raise ValueError('Vocabulary size mismatch: expected {}, '
                             'got {}'.format(nb_tokens, len(vocab)))

        # Unpack the dataset descriptor in one step instead of indexing.
        dset, path, nb_classes, use_f1_score = p

        # 'last' keeps the fixed vocabulary; the other methods may extend
        # the embedding layer with up to 10000 new words from the dataset.
        if FINETUNE_METHOD == 'last':
            extend_with = 0
        elif FINETUNE_METHOD in ['new', 'full', 'chain-thaw']:
            extend_with = 10000
        else:
            raise ValueError('Finetuning method not recognised!')

        # Load the benchmark dataset (train/val/test splits plus metadata
        # such as the number of added embedding rows and the batch size).
        data = load_benchmark(path, vocab, extend_with=extend_with)

        # Set up the model. Class-averaged-F1 benchmarks are trained as
        # binary classifiers, hence the fixed 2-class head in that case.
        # 'new' trains from scratch, so no pretrained weights are loaded.
        weight_path = PRETRAINED_PATH if FINETUNE_METHOD != 'new' else None
        nb_model_classes = 2 if use_f1_score else nb_classes
        model = torchmoji_transfer(
            nb_model_classes,
            weight_path,
            extend_embedding=data['added'])
        print(model)

        # Training
        print('Training: {}'.format(path))
        if use_f1_score:
            model, result = class_avg_finetune(model, data['texts'],
                                               data['labels'],
                                               nb_classes, data['batch_size'],
                                               FINETUNE_METHOD,
                                               verbose=VERBOSE)
        else:
            model, result = finetune(model, data['texts'], data['labels'],
                                     nb_classes, data['batch_size'],
                                     FINETUNE_METHOD, metric='acc',
                                     verbose=VERBOSE)

        # Write the test performance to a file. The filename pattern is
        # identical for both metrics, so build it once.
        result_path = '{}/{}_{}_{}_results.txt'.format(
            RESULTS_DIR, dset, FINETUNE_METHOD, rerun_iter)
        if use_f1_score:
            print('Overall F1 score (dset = {}): {}'.format(dset, result))
            with open(result_path, "w") as f:
                f.write("F1: {}\n".format(result))
        else:
            print('Test accuracy (dset = {}): {}'.format(dset, result))
            with open(result_path, "w") as f:
                f.write("Acc: {}\n".format(result))
|
|