File size: 3,764 Bytes
e1c08c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
""" Finetuning example.
"""
from __future__ import print_function

import json
import math
import os
import sys
from os.path import abspath, dirname

import numpy as np

# Put the parent of this script's directory at the front of sys.path; this
# must happen before the torchmoji imports below.
sys.path.insert(0, dirname(dirname(abspath(__file__))))

from torchmoji.model_def import torchmoji_transfer
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH
from torchmoji.finetuning import (
     load_benchmark,
     finetune)
from torchmoji.class_avg_finetuning import class_avg_finetune

def roundup(x):
    """Round ``x`` up to the next multiple of ten.

    Values that are already a multiple of ten are returned unchanged.
    The result is always an ``int``.
    """
    tens = math.ceil(x / 10.0)
    return 10 * int(tens)


# Benchmarks to run. Each entry is a 4-tuple:
#     (dataset_name, path_to_dataset, nb_classes, use_f1_score)
# Entries that are commented out are skipped; uncomment one to include it.
DATASETS = [
    # ('SE0714', '../data/SE0714/raw.pickle', 3, True),
    # ('Olympic', '../data/Olympic/raw.pickle', 4, True),
    # ('PsychExp', '../data/PsychExp/raw.pickle', 7, True),
    # ('SS-Twitter', '../data/SS-Twitter/raw.pickle', 2, False),
    ('SS-Youtube', '../data/SS-Youtube/raw.pickle', 2, False),
    # ('SE1604', '../data/SE1604/raw.pickle', 3, False),  # Excluded due to Twitter's ToS
    # ('SCv1', '../data/SCv1/raw.pickle', 2, True),
    # ('SCv2-GEN', '../data/SCv2-GEN/raw.pickle', 2, True),
]

# Directory (relative to the current working directory) that receives one
# result file per dataset and rerun.
RESULTS_DIR = 'results'

# 'new' | 'last' | 'full' | 'chain-thaw'
FINETUNE_METHOD = 'last'
VERBOSE = 1

# Expected number of entries in the vocabulary loaded from VOCAB_PATH.
nb_tokens = 50000
# NOTE: nb_epochs and epoch_size are defined here but are not passed to any
# call in this script.
nb_epochs = 1000
epoch_size = 1000

with open(VOCAB_PATH, 'r') as f:
    vocab = json.load(f)

# The vocabulary extension size depends only on the finetuning method, so it
# is resolved (and the method validated) once, before any data is loaded.
if FINETUNE_METHOD == 'last':
    extend_with = 0
elif FINETUNE_METHOD in ('new', 'full', 'chain-thaw'):
    extend_with = 10000
else:
    raise ValueError('Finetuning method not recognised!')

# Create the output directory up front: otherwise open(..., 'w') below fails
# only after a complete training run, and the result is lost.
# (isdir + makedirs rather than exist_ok=True to stay Python 2 compatible.)
if not os.path.isdir(RESULTS_DIR):
    os.makedirs(RESULTS_DIR)

# Every dataset is finetuned five times; each rerun writes its own file.
for rerun_iter in range(5):
    for dset, path, nb_classes, use_f1_score in DATASETS:

        # Debugging guard, deliberately evaluated on every iteration: the
        # same ``vocab`` object is handed to load_benchmark each time.
        # NOTE(review): presumably this catches the vocabulary being altered
        # by a previous run (e.g. when extend_with > 0) -- confirm against
        # load_benchmark.
        assert len(vocab) == nb_tokens, (
            'Vocabulary has {} entries, expected {}'.format(len(vocab),
                                                            nb_tokens))

        # Load dataset.
        data = load_benchmark(path, vocab, extend_with=extend_with)

        # The 'new' method starts without pretrained weights.
        weight_path = PRETRAINED_PATH if FINETUNE_METHOD != 'new' else None
        # F1-scored datasets get a 2-class model regardless of nb_classes;
        # nb_classes itself is still passed to class_avg_finetune below.
        nb_model_classes = 2 if use_f1_score else nb_classes
        model = torchmoji_transfer(
                    nb_model_classes,
                    weight_path,
                    extend_embedding=data['added'])
        print(model)

        # Training
        print('Training: {}'.format(path))
        if use_f1_score:
            model, result = class_avg_finetune(model, data['texts'],
                                               data['labels'],
                                               nb_classes, data['batch_size'],
                                               FINETUNE_METHOD,
                                               verbose=VERBOSE)
            headline, label = 'Overall F1 score', 'F1'
        else:
            model, result = finetune(model, data['texts'], data['labels'],
                                     nb_classes, data['batch_size'],
                                     FINETUNE_METHOD, metric='acc',
                                     verbose=VERBOSE)
            headline, label = 'Test accuracy', 'Acc'

        # Write results: one code path for both metrics; only the wording
        # of the console line and the file's label differ.
        print('{} (dset = {}): {}'.format(headline, dset, result))
        results_path = os.path.join(
            RESULTS_DIR,
            '{}_{}_{}_results.txt'.format(dset, FINETUNE_METHOD, rerun_iter))
        with open(results_path, 'w') as results_file:
            results_file.write('{}: {}\n'.format(label, result))