""" Class average finetuning functions. Before using any of these finetuning
    functions, ensure that the model is set up with nb_classes=2.
"""
from __future__ import print_function

import uuid
from time import sleep
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from torchmoji.global_variables import (
    FINETUNING_METHODS,
    WEIGHTS_DIR)
from torchmoji.finetuning import (
    freeze_layers,
    get_data_loader,
    fit_model,
    train_by_chain_thaw,
    find_f1_threshold)
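# Note: the branches in class_avg_finetune() below assume that method is one of
# 'last', 'new', 'full' or 'chain-thaw', which is what FINETUNING_METHODS
# (see torchmoji/global_variables.py) is expected to contain.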
def relabel(y, current_label_nr, nb_classes):
    """ Makes a binary classification for a specific class in a
        multi-class dataset.

    # Arguments:
        y: Outputs to be relabelled.
        current_label_nr: Current label number.
        nb_classes: Total number of classes.

    # Returns:
        Outputs of the given multi-class dataset relabelled into a binary
        classification dataset (1 for the chosen class, 0 otherwise).
    """

    # Binary classification is already in the right format
    if nb_classes == 2 and len(y.shape) == 1:
        return y

    y_new = np.zeros(len(y))
    y_cut = y[:, current_label_nr]
    label_pos = np.where(y_cut == 1)[0]
    y_new[label_pos] = 1
    return y_new
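# A minimal sketch of what relabel() does, using a hypothetical one-hot
# label matrix (toy values for illustration only):
#
#     y_multi = np.array([[1, 0, 0],
#                         [0, 1, 0],
#                         [0, 0, 1]])
#     relabel(y_multi, current_label_nr=1, nb_classes=3)
#     # -> array([0., 1., 0.])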
def class_avg_finetune(model, texts, labels, nb_classes, batch_size,
                       method, epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
                       verbose=True):
    """ Compiles and finetunes the given model.

    # Arguments:
        model: Model to be finetuned.
        texts: List of three lists, containing tokenized inputs for training,
            validation and testing (in that order).
        labels: List of three lists, containing labels for training,
            validation and testing (in that order).
        nb_classes: Number of classes in the dataset.
        batch_size: Batch size.
        method: Finetuning method to be used. For available methods, see
            FINETUNING_METHODS in global_variables.py. Note that the model
            should be defined accordingly (see docstring for torchmoji_transfer()).
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs. Doesn't matter much as early stopping is used.
        embed_l2: L2 regularization for the embedding layer.
        verbose: Verbosity flag.

    # Returns:
        Model after finetuning,
        score after finetuning using the class average F1 metric.
    """
    if method not in FINETUNING_METHODS:
        raise ValueError('ERROR (class_avg_finetune): '
                         'Invalid method parameter. '
                         'Available options: {}'.format(FINETUNING_METHODS))

    (X_train, y_train) = (texts[0], labels[0])
    (X_val, y_val) = (texts[1], labels[1])
    (X_test, y_test) = (texts[2], labels[2])

    checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \
                      .format(WEIGHTS_DIR, str(uuid.uuid4()))

    f1_init_path = '{}/torchmoji-f1-init-{}.bin' \
                   .format(WEIGHTS_DIR, str(uuid.uuid4()))

    if method in ['last', 'new']:
        lr = 0.001
    elif method in ['full', 'chain-thaw']:
        lr = 0.0001

    loss_op = nn.BCEWithLogitsLoss()

    # Freeze layers if using the 'last' method
    if method == 'last':
        model = freeze_layers(model, unfrozen_keyword='output_layer')

    # Define the optimizer; for chain-thaw it is defined later (after freezing)
    if method == 'last':
        adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
    elif method in ['full', 'new']:
        # Add L2 regularization on the embeddings only
        special_params = [id(p) for p in model.embed.parameters()]
        base_params = [p for p in model.parameters()
                       if id(p) not in special_params and p.requires_grad]
        embed_parameters = [p for p in model.parameters()
                            if id(p) in special_params and p.requires_grad]
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_parameters, 'weight_decay': embed_l2},
        ], lr=lr)

    # Training
    if verbose:
        print('Method:  {}'.format(method))
        print('Classes: {}'.format(nb_classes))

    if method == 'chain-thaw':
        result = class_avg_chainthaw(model, nb_classes=nb_classes,
                                     loss_op=loss_op,
                                     train=(X_train, y_train),
                                     val=(X_val, y_val),
                                     test=(X_test, y_test),
                                     batch_size=batch_size,
                                     epoch_size=epoch_size,
                                     nb_epochs=nb_epochs,
                                     checkpoint_weight_path=checkpoint_path,
                                     f1_init_weight_path=f1_init_path,
                                     verbose=verbose)
    else:
        result = class_avg_tune_trainable(model, nb_classes=nb_classes,
                                          loss_op=loss_op,
                                          optim_op=adam,
                                          train=(X_train, y_train),
                                          val=(X_val, y_val),
                                          test=(X_test, y_test),
                                          epoch_size=epoch_size,
                                          nb_epochs=nb_epochs,
                                          batch_size=batch_size,
                                          init_weight_path=f1_init_path,
                                          checkpoint_weight_path=checkpoint_path,
                                          verbose=verbose)
    return model, result
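# A minimal usage sketch (hypothetical variable names; assumes a model built
# with torchmoji_transfer(nb_classes=2) and already-tokenized inputs split
# into train/val/test):
#
#     texts = [X_train_tok, X_val_tok, X_test_tok]
#     labels = [y_train, y_val, y_test]
#     model, class_avg_f1 = class_avg_finetune(model, texts, labels,
#                                              nb_classes=3, batch_size=32,
#                                              method='chain-thaw')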
def prepare_labels(y_train, y_val, y_test, iter_i, nb_classes):
    """ Relabels the train/val/test outputs for the class given by iter_i. """
    y_train_new = relabel(y_train, iter_i, nb_classes)
    y_val_new = relabel(y_val, iter_i, nb_classes)
    y_test_new = relabel(y_test, iter_i, nb_classes)
    return y_train_new, y_val_new, y_test_new


def prepare_generators(X_train, y_train_new, X_val, y_val_new, batch_size, epoch_size):
    """ Builds the training data loader and a single resampled validation
        batch of epoch_size examples used for early stopping. """
    train_gen = get_data_loader(X_train, y_train_new, batch_size,
                                extended_batch_sampler=True)
    val_gen = get_data_loader(X_val, y_val_new, epoch_size,
                              extended_batch_sampler=True)
    X_val_resamp, y_val_resamp = next(iter(val_gen))
    return train_gen, X_val_resamp, y_val_resamp


def class_avg_tune_trainable(model, nb_classes, loss_op, optim_op, train, val, test,
                             epoch_size, nb_epochs, batch_size,
                             init_weight_path, checkpoint_weight_path, patience=5,
                             verbose=True):
    """ Finetunes the given model using the F1 measure.

    # Arguments:
        model: Model to be finetuned.
        nb_classes: Number of classes in the given dataset.
        loss_op: Loss function to be used during training.
        optim_op: Optimizer to be used during training.
        train: Training data, given as a tuple of (inputs, outputs)
        val: Validation data, given as a tuple of (inputs, outputs)
        test: Testing data, given as a tuple of (inputs, outputs)
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs.
        batch_size: Batch size.
        init_weight_path: Filepath where weights will be initially saved before
            training each class. This file will be rewritten by the function.
        checkpoint_weight_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        patience: Number of epochs with no improvement before early stopping.
        verbose: Verbosity flag.

    # Returns:
        F1 score of the trained model
    """
    total_f1 = 0
    nb_iter = nb_classes if nb_classes > 2 else 1

    X_train, y_train = train
    X_val, y_val = val
    X_test, y_test = test

    # Save the original weights so that every class starts from the same model
    torch.save(model.state_dict(), init_weight_path)

    for i in range(nb_iter):
        if verbose:
            print('Iteration number {}/{}'.format(i+1, nb_iter))

        # Restore the original weights and relabel for a one-vs-rest task
        model.load_state_dict(torch.load(init_weight_path))
        y_train_new, y_val_new, y_test_new = prepare_labels(y_train, y_val,
                                                            y_test, i, nb_classes)
        train_gen, X_val_resamp, y_val_resamp = \
            prepare_generators(X_train, y_train_new, X_val, y_val_new,
                               batch_size, epoch_size)

        if verbose:
            print("Training..")
        fit_model(model, loss_op, optim_op, train_gen, [(X_val_resamp, y_val_resamp)],
                  nb_epochs, checkpoint_weight_path, patience, verbose=0)

        # Reload the best weights found; wait a moment to allow the
        # checkpoint file to be properly closed
        sleep(1)
        model.load_state_dict(torch.load(checkpoint_weight_path))

        # Evaluate on the validation and test sets without tracking gradients
        with torch.no_grad():
            y_pred_val = model(X_val).cpu().numpy()
            y_pred_test = model(X_test).cpu().numpy()

        f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val,
                                            y_test_new, y_pred_test)
        if verbose:
            print('f1_test: {}'.format(f1_test))
            print('best_t:  {}'.format(best_t))
        total_f1 += f1_test

    return total_f1 / nb_iter
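# The class average F1 returned by class_avg_tune_trainable (and by
# class_avg_chainthaw below) is the unweighted mean of the per-class
# one-vs-rest F1 scores. For example, with hypothetical per-class scores
# [0.71, 0.64, 0.69] on a 3-class dataset:
#
#     class_avg_f1 = (0.71 + 0.64 + 0.69) / 3   # = 0.68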
def class_avg_chainthaw(model, nb_classes, loss_op, train, val, test, batch_size,
                        epoch_size, nb_epochs, checkpoint_weight_path,
                        f1_init_weight_path, patience=5,
                        initial_lr=0.001, next_lr=0.0001, verbose=True):
    """ Finetunes the given model using chain-thaw and evaluates using F1.
        For a dataset with multiple classes, the model is trained once for
        each class, relabeling those classes into a binary classification task.
        The result is an average of all F1 scores for each class.

    # Arguments:
        model: Model to be finetuned.
        nb_classes: Number of classes in the given dataset.
        loss_op: Loss function to be used during training.
        train: Training data, given as a tuple of (inputs, outputs)
        val: Validation data, given as a tuple of (inputs, outputs)
        test: Testing data, given as a tuple of (inputs, outputs)
        batch_size: Batch size.
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs.
        checkpoint_weight_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        f1_init_weight_path: Filepath where weights will be saved to and
            reloaded from before training each class. This ensures that
            each class is trained independently. This file will be rewritten.
        patience: Number of epochs with no improvement before early stopping.
        initial_lr: Initial learning rate. Will only be used for the first
            training step (i.e. the softmax layer).
        next_lr: Learning rate for every subsequent step.
        verbose: Verbosity flag.

    # Returns:
        Averaged F1 score.
    """
    X_train, y_train = train
    X_val, y_val = val
    X_test, y_test = test

    total_f1 = 0
    nb_iter = nb_classes if nb_classes > 2 else 1

    # Save the original weights so that every class starts from the same model
    torch.save(model.state_dict(), f1_init_weight_path)

    for i in range(nb_iter):
        if verbose:
            print('Iteration number {}/{}'.format(i+1, nb_iter))

        # Restore the original weights and relabel for a one-vs-rest task
        model.load_state_dict(torch.load(f1_init_weight_path))
        y_train_new, y_val_new, y_test_new = prepare_labels(y_train, y_val,
                                                            y_test, i, nb_classes)
        train_gen, X_val_resamp, y_val_resamp = \
            prepare_generators(X_train, y_train_new, X_val, y_val_new,
                               batch_size, epoch_size)

        if verbose:
            print("Training..")

        # Train using the chain-thaw schedule
        train_by_chain_thaw(model=model, train_gen=train_gen,
                            val_gen=[(X_val_resamp, y_val_resamp)],
                            loss_op=loss_op, patience=patience,
                            nb_epochs=nb_epochs,
                            checkpoint_path=checkpoint_weight_path,
                            initial_lr=initial_lr, next_lr=next_lr,
                            verbose=verbose)

        # Evaluate on the validation and test sets without tracking gradients
        with torch.no_grad():
            y_pred_val = model(X_val).cpu().numpy()
            y_pred_test = model(X_test).cpu().numpy()

        f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val,
                                            y_test_new, y_pred_test)

        if verbose:
            print('f1_test: {}'.format(f1_test))
            print('best_t:  {}'.format(best_t))
        total_f1 += f1_test

    return total_f1 / nb_iter