|
|
|
""" Finetuning functions for doing transfer learning to new datasets. |
|
""" |
|
from __future__ import print_function |
|
|
|
import sys |
|
import uuid |
|
from time import sleep |
|
from io import open |
|
|
|
import math |
|
import pickle |
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torch.autograd import Variable |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.utils.data.sampler import BatchSampler, SequentialSampler |
|
from torch.nn.utils import clip_grad_norm |
|
|
|
from sklearn.metrics import f1_score |
|
|
|
from torchmoji.global_variables import (FINETUNING_METHODS, |
|
FINETUNING_METRICS, |
|
WEIGHTS_DIR) |
|
from torchmoji.tokenizer import tokenize |
|
from torchmoji.sentence_tokenizer import SentenceTokenizer |
|
|
|
IS_PYTHON2 = int(sys.version[0]) == 2 |
|
unicode_ = unicode if IS_PYTHON2 else str |
|
|
|
def load_benchmark(path, vocab, extend_with=0): |
|
""" Loads the given benchmark dataset. |
|
|
|
Tokenizes the texts using the provided vocabulary, extending it with |
|
words from the training dataset if extend_with > 0. Splits them into |
|
three lists: training, validation and testing (in that order). |
|
|
|
Also calculates the maximum length of the texts and the |
|
suggested batch_size. |
|
|
|
# Arguments: |
|
path: Path to the dataset to be loaded. |
|
vocab: Vocabulary to be used for tokenizing texts. |
|
extend_with: If > 0, the vocabulary will be extended with up to |
|
extend_with tokens from the training set before tokenizing. |
|
|
|
# Returns: |
|
A dictionary with the following fields: |
|
texts: List of three lists, containing tokenized inputs for |
|
training, validation and testing (in that order). |
|
labels: List of three lists, containing labels for training, |
|
validation and testing (in that order). |
|
added: Number of tokens added to the vocabulary. |
|
batch_size: Batch size. |
|
maxlen: Maximum length of an input. |
|
""" |
|
|
|
with open(path, 'rb') as dataset: |
|
if IS_PYTHON2: |
|
data = pickle.load(dataset) |
|
else: |
|
data = pickle.load(dataset, fix_imports=True) |
|
|
|
|
|
try: |
|
texts = [unicode_(x) for x in data['texts']] |
|
except UnicodeDecodeError: |
|
texts = [x.decode('utf-8') for x in data['texts']] |
|
|
|
|
|
labels = [x['label'] for x in data['info']] |
|
|
|
batch_size, maxlen = calculate_batchsize_maxlen(texts) |
|
|
|
st = SentenceTokenizer(vocab, maxlen) |
|
|
|
|
|
|
|
texts, labels, added = st.split_train_val_test(texts, |
|
labels, |
|
[data['train_ind'], |
|
data['val_ind'], |
|
data['test_ind']], |
|
extend_with=extend_with) |
|
return {'texts': texts, |
|
'labels': labels, |
|
'added': added, |
|
'batch_size': batch_size, |
|
'maxlen': maxlen} |
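
# Usage sketch (illustrative, not part of the library API): load the torchMoji
# vocabulary and a benchmark pickle, then inspect the returned fields. The dataset
# path below is an assumption; any pickle with the expected 'texts'/'info'/'*_ind'
# fields works. VOCAB_PATH is assumed to be exported by torchmoji.global_variables
# as in the reference repository.
#
#   import json
#   from torchmoji.global_variables import VOCAB_PATH
#
#   with open(VOCAB_PATH, 'r') as f:
#       vocab = json.load(f)
#   data = load_benchmark('data/SS-Youtube/raw.pickle', vocab, extend_with=10000)
#   print(data['batch_size'], data['maxlen'], data['added'])
#   train_X, val_X, test_X = data['texts']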
|
|
|
|
|
def calculate_batchsize_maxlen(texts): |
|
""" Calculates the maximum length in the provided texts and a suitable |
|
batch size. Rounds up maxlen to the nearest multiple of ten. |
|
|
|
# Arguments: |
|
texts: List of inputs. |
|
|
|
# Returns: |
|
Batch size, |
|
max length |
|
""" |
|
def roundup(x): |
|
return int(math.ceil(x / 10.0)) * 10 |
|
|
|
|
|
|
|
lengths = [len(tokenize(t)) for t in texts] |
|
maxlen = roundup(np.percentile(lengths, 80.0)) |
|
batch_size = 250 if maxlen <= 100 else 50 |
|
return batch_size, maxlen |
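
# Usage sketch (illustrative): for short texts the 80th-percentile token length
# rounds up to 10, so the suggested batch size is 250.
#
#   bs, maxlen = calculate_batchsize_maxlen([u'I love this', u'Not so great'])
#   # bs == 250, maxlen == 10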
|
|
|
|
|
|
|
def freeze_layers(model, unfrozen_types=[], unfrozen_keyword=None): |
|
""" Freezes all layers in the given model, except for ones that are |
|
explicitly specified to not be frozen. |
|
|
|
# Arguments: |
|
model: Model whose layers should be modified. |
|
unfrozen_types: List of layer types which shouldn't be frozen. |
|
unfrozen_keyword: Name keywords of layers that shouldn't be frozen. |
|
|
|
# Returns: |
|
Model with the selected layers frozen. |
|
""" |
|
|
|
trainable_modules = [(n, m) for n, m in model.named_children() if len([id(p) for p in m.parameters()]) != 0] |
|
for name, module in trainable_modules: |
|
trainable = (any(typ in str(module) for typ in unfrozen_types) or |
|
(unfrozen_keyword is not None and unfrozen_keyword.lower() in name.lower())) |
|
change_trainable(module, trainable, verbose=False) |
|
return model |
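
# Usage sketch (assumes torchmoji_transfer from torchmoji.model_def and
# PRETRAINED_PATH from torchmoji.global_variables, as in the example scripts):
# freeze everything except the output layer, mirroring the 'last' finetuning
# method used further below.
#
#   from torchmoji.model_def import torchmoji_transfer
#   from torchmoji.global_variables import PRETRAINED_PATH
#
#   model = torchmoji_transfer(2, PRETRAINED_PATH)
#   model = freeze_layers(model, unfrozen_keyword='output_layer')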
|
|
|
|
|
def change_trainable(module, trainable, verbose=False): |
|
""" Helper method that freezes or unfreezes a given layer. |
|
|
|
# Arguments: |
|
module: Module to be modified. |
|
        trainable: Whether the module's parameters should be trainable (True)
            or frozen (False).
|
verbose: Verbosity flag. |
|
""" |
|
|
|
if verbose: print('Changing MODULE', module, 'to trainable =', trainable) |
|
for name, param in module.named_parameters(): |
|
if verbose: print('Setting weight', name, 'to trainable =', trainable) |
|
param.requires_grad = trainable |
|
|
|
    if verbose:
        action = 'Unfroze' if trainable else 'Froze'
        print("{} {}".format(action, module))
|
|
|
|
|
def find_f1_threshold(model, val_gen, test_gen, average='binary'): |
|
""" Choose a threshold for F1 based on the validation dataset |
|
(see https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4442797/ |
|
for details on why to find another threshold than simply 0.5) |
|
|
|
# Arguments: |
|
model: pyTorch model |
|
val_gen: Validation set dataloader. |
|
        test_gen: Testing set dataloader.
        average: Averaging method for the F1 score (passed to sklearn's f1_score).
|
|
|
# Returns: |
|
F1 score for the given data and |
|
the corresponding F1 threshold |
|
""" |
|
thresholds = np.arange(0.01, 0.5, step=0.01) |
|
f1_scores = [] |
|
|
|
    model.eval()
    val_out = [(y, model(X)) for X, y in val_gen]
    y_val, y_pred_val = (list(t) for t in zip(*val_out))

    test_out = [(y, model(X)) for X, y in test_gen]
    y_test, y_pred_test = (list(t) for t in zip(*test_out))

    # Flatten the per-batch labels and predictions into numpy arrays so that
    # thresholding and f1_score operate on the whole split at once.
    y_val = np.concatenate([y.numpy().ravel() for y in y_val])
    y_pred_val = np.concatenate([p.data.cpu().numpy().ravel() for p in y_pred_val])
    y_test = np.concatenate([y.numpy().ravel() for y in y_test])
    y_pred_test = np.concatenate([p.data.cpu().numpy().ravel() for p in y_pred_test])
|
|
|
for t in thresholds: |
|
y_pred_val_ind = (y_pred_val > t) |
|
f1_val = f1_score(y_val, y_pred_val_ind, average=average) |
|
f1_scores.append(f1_val) |
|
|
|
best_t = thresholds[np.argmax(f1_scores)] |
|
y_pred_ind = (y_pred_test > best_t) |
|
f1_test = f1_score(y_test, y_pred_ind, average=average) |
|
return f1_test, best_t |
|
|
|
|
|
def finetune(model, texts, labels, nb_classes, batch_size, method, |
|
metric='acc', epoch_size=5000, nb_epochs=1000, embed_l2=1E-6, |
|
verbose=1): |
|
""" Compiles and finetunes the given pytorch model. |
|
|
|
# Arguments: |
|
model: Model to be finetuned |
|
texts: List of three lists, containing tokenized inputs for training, |
|
validation and testing (in that order). |
|
labels: List of three lists, containing labels for training, |
|
validation and testing (in that order). |
|
nb_classes: Number of classes in the dataset. |
|
batch_size: Batch size. |
|
method: Finetuning method to be used. For available methods, see |
|
FINETUNING_METHODS in global_variables.py. |
|
metric: Evaluation metric to be used. For available metrics, see |
|
FINETUNING_METRICS in global_variables.py. |
|
epoch_size: Number of samples in an epoch. |
|
nb_epochs: Number of epochs. Doesn't matter much as early stopping is used. |
|
embed_l2: L2 regularization for the embedding layer. |
|
verbose: Verbosity flag. |
|
|
|
# Returns: |
|
Model after finetuning, |
|
score after finetuning using the provided metric. |
|
""" |
|
|
|
if method not in FINETUNING_METHODS: |
|
raise ValueError('ERROR (finetune): Invalid method parameter. ' |
|
'Available options: {}'.format(FINETUNING_METHODS)) |
|
if metric not in FINETUNING_METRICS: |
|
raise ValueError('ERROR (finetune): Invalid metric parameter. ' |
|
'Available options: {}'.format(FINETUNING_METRICS)) |
|
|
|
train_gen = get_data_loader(texts[0], labels[0], batch_size, |
|
extended_batch_sampler=True, epoch_size=epoch_size) |
|
val_gen = get_data_loader(texts[1], labels[1], batch_size, |
|
extended_batch_sampler=False) |
|
test_gen = get_data_loader(texts[2], labels[2], batch_size, |
|
extended_batch_sampler=False) |
|
|
|
checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \ |
|
.format(WEIGHTS_DIR, str(uuid.uuid4())) |
|
|
|
if method in ['last', 'new']: |
|
lr = 0.001 |
|
elif method in ['full', 'chain-thaw']: |
|
lr = 0.0001 |
|
|
|
loss_op = nn.BCEWithLogitsLoss() if nb_classes <= 2 \ |
|
else nn.CrossEntropyLoss() |
|
|
|
|
|
    if method == 'last':
        # For the 'last' method only the output layer is trained.
        model = freeze_layers(model, unfrozen_keyword='output_layer')
|
|
|
|
|
if method == 'last': |
|
adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr) |
|
elif method in ['full', 'new']: |
|
|
|
        # Apply weight decay only to the embeddings and a higher learning rate to
        # the freshly initialized output layer; the rest of the model uses the base lr.
        embed_params_id = [id(p) for p in model.embed.parameters()]
        output_layer_params_id = [id(p) for p in model.output_layer.parameters()]
        base_params = [p for p in model.parameters()
                       if id(p) not in embed_params_id and id(p) not in output_layer_params_id and p.requires_grad]
        embed_params = [p for p in model.parameters() if id(p) in embed_params_id and p.requires_grad]
        output_layer_params = [p for p in model.parameters() if id(p) in output_layer_params_id and p.requires_grad]
|
adam = optim.Adam([ |
|
{'params': base_params}, |
|
{'params': embed_params, 'weight_decay': embed_l2}, |
|
{'params': output_layer_params, 'lr': 0.001}, |
|
], lr=lr) |
|
|
|
|
|
if verbose: |
|
print('Method: {}'.format(method)) |
|
print('Metric: {}'.format(metric)) |
|
print('Classes: {}'.format(nb_classes)) |
|
|
|
if method == 'chain-thaw': |
|
result = chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op, embed_l2=embed_l2, |
|
evaluate=metric, verbose=verbose) |
|
else: |
|
result = tune_trainable(model, loss_op, adam, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, |
|
evaluate=metric, verbose=verbose) |
|
return model, result |
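
# Usage sketch putting the pieces together (paths and nb_classes are illustrative;
# torchmoji_transfer, PRETRAINED_PATH and VOCAB_PATH are assumed to come from the
# torchmoji package as in the bundled example scripts):
#
#   import json
#   from torchmoji.model_def import torchmoji_transfer
#   from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH
#
#   with open(VOCAB_PATH, 'r') as f:
#       vocab = json.load(f)
#   data = load_benchmark('data/SS-Youtube/raw.pickle', vocab)
#   model = torchmoji_transfer(2, PRETRAINED_PATH)
#   model, acc = finetune(model, data['texts'], data['labels'], 2,
#                         data['batch_size'], method='last', metric='acc')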
|
|
|
|
|
def tune_trainable(model, loss_op, optim_op, train_gen, val_gen, test_gen, |
|
nb_epochs, checkpoint_path, patience=5, evaluate='acc', |
|
verbose=2): |
|
""" Finetunes the given model using the accuracy measure. |
|
|
|
# Arguments: |
|
model: Model to be finetuned. |
|
nb_classes: Number of classes in the given dataset. |
|
train: Training data, given as a tuple of (inputs, outputs) |
|
val: Validation data, given as a tuple of (inputs, outputs) |
|
test: Testing data, given as a tuple of (inputs, outputs) |
|
epoch_size: Number of samples in an epoch. |
|
nb_epochs: Number of epochs. |
|
batch_size: Batch size. |
|
checkpoint_weight_path: Filepath where weights will be checkpointed to |
|
during training. This file will be rewritten by the function. |
|
patience: Patience for callback methods. |
|
evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'. |
|
verbose: Verbosity flag. |
|
|
|
# Returns: |
|
Accuracy of the trained model, ONLY if 'evaluate' is set. |
|
""" |
|
if verbose: |
|
print("Trainable weights: {}".format([n for n, p in model.named_parameters() if p.requires_grad])) |
|
print("Training...") |
|
        if evaluate == 'acc':
            print("Evaluation on test set prior to training:", evaluate_using_acc(model, test_gen))
        elif evaluate == 'weighted_f1':
            print("Evaluation on test set prior to training:", evaluate_using_weighted_f1(model, test_gen, val_gen))
|
|
|
fit_model(model, loss_op, optim_op, train_gen, val_gen, nb_epochs, checkpoint_path, patience) |
|
|
|
|
|
|
|
    # Reload the best weights found to avoid overfitting.
    # Wait a little bit to allow dust to settle.
    sleep(1)
    model.load_state_dict(torch.load(checkpoint_path))
|
if verbose >= 2: |
|
print("Loaded weights from {}".format(checkpoint_path)) |
|
|
|
if evaluate == 'acc': |
|
return evaluate_using_acc(model, test_gen) |
|
elif evaluate == 'weighted_f1': |
|
return evaluate_using_weighted_f1(model, test_gen, val_gen) |
|
|
|
|
|
def evaluate_using_weighted_f1(model, test_gen, val_gen): |
|
""" Evaluation function using macro weighted F1 score. |
|
|
|
# Arguments: |
|
model: Model to be evaluated. |
|
X_test: Inputs of the testing set. |
|
y_test: Outputs of the testing set. |
|
X_val: Inputs of the validation set. |
|
y_val: Outputs of the validation set. |
|
batch_size: Batch size. |
|
|
|
# Returns: |
|
Weighted F1 score of the given model. |
|
""" |
|
|
|
    f1_test, _ = find_f1_threshold(model, val_gen, test_gen, average='weighted')
|
return f1_test |
|
|
|
|
|
def evaluate_using_acc(model, test_gen): |
|
""" Evaluation function using accuracy. |
|
|
|
# Arguments: |
|
model: Model to be evaluated. |
|
test_gen: Testing data iterator (DataLoader) |
|
|
|
# Returns: |
|
Accuracy of the given model. |
|
""" |
|
|
|
|
|
model.eval() |
|
correct_count = 0.0 |
|
total_y = sum(len(y) for _, y in test_gen) |
|
for i, data in enumerate(test_gen): |
|
x, y = data |
|
        outs = model(x)
        # Binary accuracy: threshold the logits at 0 (i.e. a probability of 0.5).
        pred = (outs >= 0).long()
|
added_counts = (pred == y).double().sum() |
|
correct_count += added_counts |
|
return correct_count/total_y |
|
|
|
|
|
def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op, |
|
patience=5, initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, evaluate='acc', verbose=1): |
|
""" Finetunes given model using chain-thaw and evaluates using accuracy. |
|
|
|
# Arguments: |
|
model: Model to be finetuned. |
|
train: Training data, given as a tuple of (inputs, outputs) |
|
val: Validation data, given as a tuple of (inputs, outputs) |
|
test: Testing data, given as a tuple of (inputs, outputs) |
|
batch_size: Batch size. |
|
loss: Loss function to be used during training. |
|
epoch_size: Number of samples in an epoch. |
|
nb_epochs: Number of epochs. |
|
checkpoint_weight_path: Filepath where weights will be checkpointed to |
|
during training. This file will be rewritten by the function. |
|
initial_lr: Initial learning rate. Will only be used for the first |
|
training step (i.e. the output_layer layer) |
|
next_lr: Learning rate for every subsequent step. |
|
seed: Random number generator seed. |
|
verbose: Verbosity flag. |
|
evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'. |
|
|
|
# Returns: |
|
Accuracy of the finetuned model. |
|
""" |
|
if verbose: |
|
print('Training..') |
|
|
|
|
|
train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path, |
|
initial_lr, next_lr, embed_l2, verbose) |
|
|
|
if evaluate == 'acc': |
|
return evaluate_using_acc(model, test_gen) |
|
elif evaluate == 'weighted_f1': |
|
return evaluate_using_weighted_f1(model, test_gen, val_gen) |
|
|
|
|
|
def train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path, |
|
initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, verbose=1): |
|
""" Finetunes model using the chain-thaw method. |
|
|
|
This is done as follows: |
|
1) Freeze every layer except the last (output_layer) layer and train it. |
|
2) Freeze every layer except the first layer and train it. |
|
3) Freeze every layer except the second etc., until the second last layer. |
|
4) Unfreeze all layers and train entire model. |
|
|
|
# Arguments: |
|
model: Model to be trained. |
|
train_gen: Training sample generator. |
|
val_data: Validation data. |
|
loss: Loss function to be used. |
|
finetuning_args: Training early stopping and checkpoint saving parameters |
|
epoch_size: Number of samples in an epoch. |
|
nb_epochs: Number of epochs. |
|
checkpoint_weight_path: Where weight checkpoints should be saved. |
|
batch_size: Batch size. |
|
initial_lr: Initial learning rate. Will only be used for the first |
|
training step (i.e. the output_layer layer) |
|
next_lr: Learning rate for every subsequent step. |
|
verbose: Verbosity flag. |
|
""" |
|
|
|
    # Get the child modules that actually have parameters (the trainable layers).
    layers = [m for m in model.children() if len([id(p) for p in m.parameters()]) != 0]

    # Move the last (output) layer to the front so it is finetuned first.
    layers.insert(0, layers.pop(len(layers) - 1))

    # A trailing None signals the final pass, where all layers are unfrozen.
    layers.append(None)
|
|
|
lr = None |
|
|
|
|
|
    for layer in layers:
        # Use initial_lr for the first step (the output layer), next_lr afterwards.
        if lr is None:
            lr = initial_lr
        elif lr == initial_lr:
            lr = next_lr
|
|
|
|
|
        # Unfreeze only the current layer (or every layer on the final pass).
        for _layer in layers:
            if _layer is not None:
                trainable = _layer == layer or layer is None
                change_trainable(_layer, trainable=trainable, verbose=False)
|
|
|
|
|
        # Sanity check: only the intended layer should require gradients.
        for _layer in model.children():
            assert all(p.requires_grad == (_layer == layer) for p in _layer.parameters()) or layer is None
|
|
|
if verbose: |
|
if layer is None: |
|
print('Finetuning all layers') |
|
else: |
|
print('Finetuning {}'.format(layer)) |
|
|
|
        # Apply L2 regularization (weight decay) only to the embedding parameters.
        special_params = [id(p) for p in model.embed.parameters()]
        base_params = [p for p in model.parameters() if id(p) not in special_params and p.requires_grad]
        embed_parameters = [p for p in model.parameters() if id(p) in special_params and p.requires_grad]
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_parameters, 'weight_decay': embed_l2},
        ], lr=lr)
|
|
|
fit_model(model, loss_op, adam, train_gen, val_gen, nb_epochs, |
|
checkpoint_path, patience) |
|
|
|
|
|
|
|
        # Reload the best weights found to avoid overfitting.
        # Wait a little bit to allow dust to settle.
        sleep(1)
        model.load_state_dict(torch.load(checkpoint_path))
|
if verbose >= 2: |
|
print("Loaded weights from {}".format(checkpoint_path)) |
|
|
|
def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs, |
|
checkpoint_path, patience): |
|
""" Analog to Keras fit_generator function. |
|
|
|
# Arguments: |
|
model: Model to be finetuned. |
|
loss_op: loss operation (BCEWithLogitsLoss or CrossEntropy for e.g.) |
|
optim_op: optimization operation (Adam e.g.) |
|
train_gen: Training data iterator (DataLoader) |
|
val_gen: Validation data iterator (DataLoader) |
|
epochs: Number of epochs. |
|
checkpoint_path: Filepath where weights will be checkpointed to |
|
during training. This file will be rewritten by the function. |
|
patience: Patience for callback methods. |
|
verbose: Verbosity flag. |
|
|
|
# Returns: |
|
Accuracy of the trained model, ONLY if 'evaluate' is set. |
|
""" |
|
|
|
    # Save the initial state so a checkpoint to reload always exists.
    torch.save(model.state_dict(), checkpoint_path)

    model.eval()
    best_loss = np.mean([loss_op(model(Variable(xv)).squeeze(), Variable(yv.float()).squeeze()).data.cpu().numpy()[0] for xv, yv in val_gen])
    print("original val loss", best_loss)
|
|
|
epoch_without_impr = 0 |
|
for epoch in range(epochs): |
|
for i, data in enumerate(train_gen): |
|
X_train, y_train = data |
|
X_train = Variable(X_train, requires_grad=False) |
|
y_train = Variable(y_train, requires_grad=False) |
|
model.train() |
|
optim_op.zero_grad() |
|
output = model(X_train) |
|
loss = loss_op(output, y_train.float()) |
|
loss.backward() |
|
clip_grad_norm(model.parameters(), 1) |
|
optim_op.step() |
|
|
|
acc = evaluate_using_acc(model, [(X_train.data, y_train.data)]) |
|
print("== Epoch", epoch, "step", i, "train loss", loss.data.cpu().numpy()[0], "train acc", acc) |
|
|
|
model.eval() |
|
acc = evaluate_using_acc(model, val_gen) |
|
print("val acc", acc) |
|
|
|
val_loss = np.mean([loss_op(model(Variable(xv)).squeeze(), Variable(yv.float()).squeeze()).data.cpu().numpy()[0] for xv, yv in val_gen]) |
|
print("val loss", val_loss) |
|
        if best_loss is not None and val_loss >= best_loss:
            epoch_without_impr += 1
            print('No improvement over previous best loss: ', best_loss)

        # Save the weights whenever the validation loss improves.
        if best_loss is None or val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print('Saving model at', checkpoint_path)

        # Early stopping once patience is exhausted.
        if epoch_without_impr >= patience:
            break
|
|
|
def get_data_loader(X_in, y_in, batch_size, extended_batch_sampler=True, epoch_size=25000, upsample=False, seed=42): |
|
""" Returns a dataloader that enables larger epochs on small datasets and |
|
has upsampling functionality. |
|
|
|
# Arguments: |
|
X_in: Inputs of the given dataset. |
|
y_in: Outputs of the given dataset. |
|
        batch_size: Batch size.
        extended_batch_sampler: Whether to use the DeepMojiBatchSampler, which
            enables fixed-size epochs and upsampling.
        epoch_size: Number of samples in an epoch (only used with the
            extended batch sampler).
        upsample: Whether upsampling should be done. This flag should only be
            set on binary class problems.
        seed: Random number generator seed.
|
|
|
# Returns: |
|
DataLoader. |
|
""" |
|
dataset = DeepMojiDataset(X_in, y_in) |
|
|
|
if extended_batch_sampler: |
|
batch_sampler = DeepMojiBatchSampler(y_in, batch_size, epoch_size=epoch_size, upsample=upsample, seed=seed) |
|
else: |
|
batch_sampler = BatchSampler(SequentialSampler(y_in), batch_size, drop_last=False) |
|
|
|
return DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0) |
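
# Usage sketch (illustrative shapes): wrap padded token ids and binary labels in a
# DataLoader. With extended_batch_sampler=True the epoch has a fixed size that is
# independent of the dataset size.
#
#   X = np.zeros((100, 30), dtype='int64')                # 100 sentences, maxlen 30
#   y = np.random.randint(0, 2, size=100).astype('int64')
#   loader = get_data_loader(X, y, batch_size=16,
#                            extended_batch_sampler=True, epoch_size=320)
#   for X_batch, y_batch in loader:
#       pass  # X_batch: LongTensor (16, 30), y_batch: LongTensor (16,)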
|
|
|
class DeepMojiDataset(Dataset): |
|
""" A simple Dataset class. |
|
|
|
# Arguments: |
|
X_in: Inputs of the given dataset. |
|
y_in: Outputs of the given dataset. |
|
|
|
# __getitem__ output: |
|
(torch.LongTensor, torch.LongTensor) |
|
""" |
|
def __init__(self, X_in, y_in): |
|
|
|
if not isinstance(X_in, torch.LongTensor): |
|
X_in = torch.from_numpy(X_in.astype('int64')).long() |
|
if not isinstance(y_in, torch.LongTensor): |
|
y_in = torch.from_numpy(y_in.astype('int64')).long() |
|
|
|
self.X_in = torch.split(X_in, 1, dim=0) |
|
self.y_in = torch.split(y_in, 1, dim=0) |
|
|
|
def __len__(self): |
|
return len(self.X_in) |
|
|
|
def __getitem__(self, idx): |
|
return self.X_in[idx].squeeze(), self.y_in[idx].squeeze() |
|
|
|
class DeepMojiBatchSampler(object): |
|
"""A Batch sampler that enables larger epochs on small datasets and |
|
has upsampling functionality. |
|
|
|
# Arguments: |
|
y_in: Labels of the dataset. |
|
batch_size: Batch size. |
|
epoch_size: Number of samples in an epoch. |
|
upsample: Whether upsampling should be done. This flag should only be |
|
set on binary class problems. |
|
seed: Random number generator seed. |
|
|
|
# __iter__ output: |
|
iterator of lists (batches) of indices in the dataset |
|
""" |
|
|
|
def __init__(self, y_in, batch_size, epoch_size, upsample, seed): |
|
self.batch_size = batch_size |
|
self.epoch_size = epoch_size |
|
self.upsample = upsample |
|
|
|
np.random.seed(seed) |
|
|
|
        if upsample:
            # Upsampling should only be used on binary-class problems.
            assert len(y_in.shape) == 1
            neg = np.where(y_in.numpy() == 0)[0]
            pos = np.where(y_in.numpy() == 1)[0]
            assert epoch_size % 2 == 0
            samples_pr_class = int(epoch_size / 2)
        else:
            ind = range(len(y_in))

        if not upsample:
            # Randomly sample observations (with replacement) for the epoch.
            self.sample_ind = np.random.choice(ind, epoch_size, replace=True)
        else:
            # Sample an equal number of negative and positive observations.
            sample_neg = np.random.choice(neg, samples_pr_class, replace=True)
            sample_pos = np.random.choice(pos, samples_pr_class, replace=True)
            concat_ind = np.concatenate((sample_neg, sample_pos), axis=0)

            # Shuffle so batches are not all-negative followed by all-positive.
            p = np.random.permutation(len(concat_ind))
            self.sample_ind = concat_ind[p]

            # Sanity check: the upsampled epoch should be roughly class-balanced.
            label_dist = np.mean(y_in.numpy()[self.sample_ind])
            assert(label_dist > 0.45)
            assert(label_dist < 0.55)
|
|
|
def __iter__(self): |
|
|
|
for i in range(int(self.epoch_size/self.batch_size)): |
|
start = i * self.batch_size |
|
end = min(start + self.batch_size, self.epoch_size) |
|
yield self.sample_ind[start:end] |
|
|
|
def __len__(self): |
|
|
|
return (self.epoch_size + self.batch_size - 1) // self.batch_size |
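
# Usage sketch (illustrative): draw a fixed-size, roughly class-balanced epoch from
# a small, imbalanced binary dataset by upsampling the minority class.
#
#   y = torch.LongTensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])
#   sampler = DeepMojiBatchSampler(y, batch_size=4, epoch_size=8,
#                                  upsample=True, seed=42)
#   for batch_indices in sampler:
#       print(batch_indices)  # numpy array of 4 dataset indices per batch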
|
|