#! /usr/bin/env python3
# coding=utf-8

# This code is licensed under a non-commercial license.
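
"""Train a discriminator (a classification head on top of frozen GPT-2
representations) on one of several datasets.

Supported datasets: SST (fine-grained sentiment), clickbait, toxic, or a
generic TSV file with one "class<tab>text" example per line. Run with
--help for the full list of options.
"""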

import argparse
import csv
import json
import math
import time

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from nltk.tokenize.treebank import TreebankWordDetokenizer
from torchtext import data as torchtext_data
from torchtext import datasets
from tqdm import tqdm, trange
from transformers import GPT2Tokenizer, GPT2LMHeadModel

torch.manual_seed(0)
np.random.seed(0)

EPSILON = 1e-10
device = "cpu"
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
max_length_seq = 100


class ClassificationHead(torch.nn.Module):
    """Classification Head for transformer encoders"""

    def __init__(self, class_size, embed_size):
        super(ClassificationHead, self).__init__()
        self.class_size = class_size
        self.embed_size = embed_size
        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
        self.mlp = torch.nn.Linear(embed_size, class_size)

    def forward(self, hidden_state):
        # hidden_state = F.relu(self.mlp1(hidden_state))
        # hidden_state = self.mlp2(hidden_state)
        logits = self.mlp(hidden_state)
        return logits


class Discriminator(torch.nn.Module):
    """Transformer encoder followed by a Classification Head"""

    def __init__(
            self,
            class_size,
            pretrained_model="gpt2-medium",
            cached_mode=False
    ):
        super(Discriminator, self).__init__()
        self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
        self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
        self.embed_size = self.encoder.transformer.config.hidden_size
        self.classifier_head = ClassificationHead(
            class_size=class_size,
            embed_size=self.embed_size
        )
        self.cached_mode = cached_mode

    def get_classifier(self):
        return self.classifier_head

    def train_custom(self):
        # Freeze the pretrained encoder; only the classifier head is trained.
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.classifier_head.train()

    def avg_representation(self, x):
        # Mask out padding positions (token id 0), then average the encoder's
        # hidden states over the sequence dimension.
        mask = x.ne(0).unsqueeze(2).repeat(
            1, 1, self.embed_size
        ).float().to(device).detach()
        hidden, _ = self.encoder.transformer(x)
        masked_hidden = hidden * mask
        avg_hidden = torch.sum(masked_hidden, dim=1) / (
            torch.sum(mask, dim=1).detach() + EPSILON
        )
        return avg_hidden

    def forward(self, x):
        if self.cached_mode:
            # x is already a precomputed average representation.
            avg_hidden = x.to(device)
        else:
            # x is a batch of token ids.
            avg_hidden = self.avg_representation(x.to(device))
        logits = self.classifier_head(avg_hidden)
        probs = F.log_softmax(logits, dim=-1)
        return probs
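

# A minimal usage sketch of the Discriminator on raw text (hypothetical,
# mirroring what predict() below does):
#
#   discriminator = Discriminator(class_size=2).to(device)
#   tokens = discriminator.tokenizer.encode("some input text")
#   input_t = torch.tensor([tokens], dtype=torch.long, device=device)
#   log_probs = discriminator(input_t)  # shape (1, class_size)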


class Dataset(data.Dataset):
    def __init__(self, X, y):
        """Holds input sequences and their labels."""
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        """Returns one data pair (input and label)."""
        data = {}
        data["X"] = self.X[index]
        data["y"] = self.y[index]
        return data


def collate_fn(data):
    def pad_sequences(sequences):
        lengths = [len(seq) for seq in sequences]
        padded_sequences = torch.zeros(
            len(sequences),
            max(lengths)
        ).long()  # padding value = 0
        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_sequences[i, :end] = seq[:end]
        return padded_sequences, lengths

    item_info = {}
    for key in data[0].keys():
        item_info[key] = [d[key] for d in data]

    x_batch, _ = pad_sequences(item_info["X"])
    y_batch = torch.tensor(item_info["y"], dtype=torch.long)

    return x_batch, y_batch


def cached_collate_fn(data):
    item_info = {}
    for key in data[0].keys():
        item_info[key] = [d[key] for d in data]

    x_batch = torch.cat(item_info["X"], 0)
    y_batch = torch.tensor(item_info["y"], dtype=torch.long)

    return x_batch, y_batch
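

# Note: collate_fn pads variable-length token-id sequences with 0 up to the
# batch maximum, while cached_collate_fn just concatenates precomputed
# representation tensors (each of shape (1, embed_size)), so no padding is
# needed there.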


def train_epoch(data_loader, discriminator, optimizer,
                epoch=0, log_interval=10):
    samples_so_far = 0
    discriminator.train_custom()
    for batch_idx, (input_t, target_t) in enumerate(data_loader):
        input_t, target_t = input_t.to(device), target_t.to(device)

        optimizer.zero_grad()

        output_t = discriminator(input_t)
        loss = F.nll_loss(output_t, target_t)
        loss.backward(retain_graph=True)
        optimizer.step()

        samples_so_far += len(input_t)

        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch + 1,
                    samples_so_far, len(data_loader.dataset),
                    100 * samples_so_far / len(data_loader.dataset), loss.item()
                )
            )


def evaluate_performance(data_loader, discriminator):
    discriminator.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for input_t, target_t in data_loader:
            input_t, target_t = input_t.to(device), target_t.to(device)
            output_t = discriminator(input_t)
            # sum up batch loss
            test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
            # get the index of the max log-probability
            pred_t = output_t.argmax(dim=1, keepdim=True)
            correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()

    test_loss /= len(data_loader.dataset)

    print(
        "Performance on test set: "
        "Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
            test_loss, correct, len(data_loader.dataset),
            100. * correct / len(data_loader.dataset)
        )
    )


def predict(input_sentence, model, classes, cached=False):
    input_t = model.tokenizer.encode(input_sentence)
    input_t = torch.tensor([input_t], dtype=torch.long, device=device)
    if cached:
        input_t = model.avg_representation(input_t)

    # The model returns log-probabilities; exponentiate for display.
    log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
    print("Input sentence:", input_sentence)
    print("Predictions:", ", ".join(
        "{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in
        zip(classes, log_probs)
    ))


def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              collate_fn=collate_fn)

    # Precompute the average GPT-2 representation for every example once.
    xs = []
    ys = []
    for batch_idx, (x, y) in enumerate(tqdm(data_loader, ascii=True)):
        with torch.no_grad():
            x = x.to(device)
            avg_rep = discriminator.avg_representation(x).cpu().detach()
            avg_rep_list = torch.unbind(avg_rep.unsqueeze(1))
            xs += avg_rep_list
            ys += y.cpu().numpy().tolist()

    data_loader = torch.utils.data.DataLoader(
        dataset=Dataset(xs, ys),
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=cached_collate_fn)

    return data_loader
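

# Since the encoder is frozen during training (see train_custom above), its
# average representations never change; with --cached they are computed once
# here and reused every epoch. The two paths through Discriminator.forward:
#
#   cached:   avg_hidden = x                       # x: (batch, embed_size)
#   uncached: avg_hidden = avg_representation(x)   # x: (batch, seq_len) token ids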


def train_discriminator(
        dataset, dataset_fp=None, pretrained_model="gpt2-medium",
        epochs=10, batch_size=64, log_interval=10,
        save_model=False, cached=False, no_cuda=False):
    global device
    device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"

    print("Preprocessing {} dataset...".format(dataset))
    start = time.time()

    if dataset == "SST":
        idx2class = ["positive", "negative", "very positive", "very negative",
                     "neutral"]
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class),
            pretrained_model=pretrained_model,
            cached_mode=cached
        ).to(device)

        text = torchtext_data.Field()
        label = torchtext_data.Field(sequential=False)
        train_data, val_data, test_data = datasets.SST.splits(
            text,
            label,
            fine_grained=True,
            train_subtrees=True,
        )

        x = []
        y = []
        for i in trange(len(train_data), ascii=True):
            seq = TreebankWordDetokenizer().detokenize(
                vars(train_data[i])["text"]
            )
            seq = discriminator.tokenizer.encode(seq)
            # Prepend 50256, GPT-2's <|endoftext|> token, as a BOS marker.
            seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
            x.append(seq)
            y.append(class2idx[vars(train_data[i])["label"]])
        train_dataset = Dataset(x, y)

        test_x = []
        test_y = []
        for i in trange(len(test_data), ascii=True):
            seq = TreebankWordDetokenizer().detokenize(
                vars(test_data[i])["text"]
            )
            seq = discriminator.tokenizer.encode(seq)
            seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
            test_x.append(seq)
            test_y.append(class2idx[vars(test_data[i])["label"]])
        test_dataset = Dataset(test_x, test_y)

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 2,
        }
elif dataset == "clickbait": | |
idx2class = ["non_clickbait", "clickbait"] | |
class2idx = {c: i for i, c in enumerate(idx2class)} | |
discriminator = Discriminator( | |
class_size=len(idx2class), | |
pretrained_model=pretrained_model, | |
cached_mode=cached | |
).to(device) | |
with open("datasets/clickbait/clickbait_train_prefix.txt") as f: | |
data = [] | |
for i, line in enumerate(f): | |
try: | |
data.append(eval(line)) | |
except: | |
print("Error evaluating line {}: {}".format( | |
i, line | |
)) | |
continue | |
x = [] | |
y = [] | |
with open("datasets/clickbait/clickbait_train_prefix.txt") as f: | |
for i, line in enumerate(tqdm(f, ascii=True)): | |
try: | |
d = eval(line) | |
seq = discriminator.tokenizer.encode(d["text"]) | |
if len(seq) < max_length_seq: | |
seq = torch.tensor( | |
[50256] + seq, device=device, dtype=torch.long | |
) | |
else: | |
print("Line {} is longer than maximum length {}".format( | |
i, max_length_seq | |
)) | |
continue | |
x.append(seq) | |
y.append(d["label"]) | |
except: | |
print("Error evaluating / tokenizing" | |
" line {}, skipping it".format(i)) | |
pass | |
full_dataset = Dataset(x, y) | |
train_size = int(0.9 * len(full_dataset)) | |
test_size = len(full_dataset) - train_size | |
train_dataset, test_dataset = torch.utils.data.random_split( | |
full_dataset, [train_size, test_size] | |
) | |
discriminator_meta = { | |
"class_size": len(idx2class), | |
"embed_size": discriminator.embed_size, | |
"pretrained_model": pretrained_model, | |
"class_vocab": class2idx, | |
"default_class": 1, | |
} | |
elif dataset == "toxic": | |
idx2class = ["non_toxic", "toxic"] | |
class2idx = {c: i for i, c in enumerate(idx2class)} | |
discriminator = Discriminator( | |
class_size=len(idx2class), | |
pretrained_model=pretrained_model, | |
cached_mode=cached | |
).to(device) | |
x = [] | |
y = [] | |
with open("datasets/toxic/toxic_train.txt") as f: | |
for i, line in enumerate(tqdm(f, ascii=True)): | |
try: | |
d = eval(line) | |
seq = discriminator.tokenizer.encode(d["text"]) | |
if len(seq) < max_length_seq: | |
seq = torch.tensor( | |
[50256] + seq, device=device, dtype=torch.long | |
) | |
else: | |
print("Line {} is longer than maximum length {}".format( | |
i, max_length_seq | |
)) | |
continue | |
x.append(seq) | |
y.append(int(np.sum(d["label"]) > 0)) | |
except: | |
print("Error evaluating / tokenizing" | |
" line {}, skipping it".format(i)) | |
pass | |
full_dataset = Dataset(x, y) | |
train_size = int(0.9 * len(full_dataset)) | |
test_size = len(full_dataset) - train_size | |
train_dataset, test_dataset = torch.utils.data.random_split( | |
full_dataset, [train_size, test_size] | |
) | |
discriminator_meta = { | |
"class_size": len(idx2class), | |
"embed_size": discriminator.embed_size, | |
"pretrained_model": pretrained_model, | |
"class_vocab": class2idx, | |
"default_class": 0, | |
} | |
    else:  # if dataset == "generic":
        # This assumes the input dataset is a TSV with the following
        # structure: class \t text
        if not dataset_fp:
            raise ValueError("When generic dataset is selected, "
                             "dataset_fp needs to be specified as well.")

        classes = set()
        with open(dataset_fp) as f:
            csv_reader = csv.reader(f, delimiter="\t")
            for row in tqdm(csv_reader, ascii=True):
                if row:
                    classes.add(row[0])

        idx2class = sorted(classes)
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class),
            pretrained_model=pretrained_model,
            cached_mode=cached
        ).to(device)

        x = []
        y = []
        with open(dataset_fp) as f:
            csv_reader = csv.reader(f, delimiter="\t")
            for i, row in enumerate(tqdm(csv_reader, ascii=True)):
                if row:
                    label = row[0]
                    text = row[1]

                    try:
                        seq = discriminator.tokenizer.encode(text)
                        if len(seq) < max_length_seq:
                            seq = torch.tensor(
                                [50256] + seq,
                                device=device,
                                dtype=torch.long
                            )
                        else:
                            print(
                                "Line {} is longer than maximum length {}".format(
                                    i, max_length_seq
                                ))
                            continue

                        x.append(seq)
                        y.append(class2idx[label])
                    except Exception:
                        print("Error tokenizing line {}, skipping it".format(i))

        full_dataset = Dataset(x, y)
        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(
            full_dataset,
            [train_size, test_size]
        )

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 0,
        }

    end = time.time()
    print("Preprocessed {} data points".format(
        len(train_dataset) + len(test_dataset))
    )
    print("Data preprocessing took: {:.3f}s".format(end - start))

    if cached:
        print("Building representation cache...")
        start = time.time()

        train_loader = get_cached_data_loader(
            train_dataset, batch_size, discriminator, shuffle=True
        )
        test_loader = get_cached_data_loader(
            test_dataset, batch_size, discriminator
        )

        end = time.time()
        print("Building representation cache took: {:.3f}s".format(end - start))
    else:
        train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   collate_fn=collate_fn)
        test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                  batch_size=batch_size,
                                                  collate_fn=collate_fn)

    if save_model:
        with open("{}_classifier_head_meta.json".format(dataset),
                  "w") as meta_file:
            json.dump(discriminator_meta, meta_file)

    optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)

    for epoch in range(epochs):
        start = time.time()
        print("\nEpoch", epoch + 1)

        train_epoch(
            discriminator=discriminator,
            data_loader=train_loader,
            optimizer=optimizer,
            epoch=epoch,
            log_interval=log_interval
        )
        evaluate_performance(
            data_loader=test_loader,
            discriminator=discriminator
        )

        end = time.time()
        print("Epoch took: {:.3f}s".format(end - start))

        print("\nExample prediction")
        predict(example_sentence, discriminator, idx2class, cached)

        if save_model:
            # torch.save(discriminator.state_dict(),
            #            "{}_discriminator_{}.pt".format(
            #                args.dataset, epoch + 1
            #            ))
            # Only the classifier head is saved; the encoder stays frozen at
            # its pretrained weights.
            torch.save(discriminator.get_classifier().state_dict(),
                       "{}_classifier_head_epoch_{}.pt".format(dataset,
                                                               epoch + 1))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train a discriminator on top of GPT-2 representations")
    parser.add_argument("--dataset", type=str, default="SST",
                        choices=("SST", "clickbait", "toxic", "generic"),
                        help="dataset to train the discriminator on. "
                             "In case of generic, the dataset is expected "
                             "to be a TSV file with structure: class \\t text")
    parser.add_argument("--dataset_fp", type=str, default="",
                        help="File path of the dataset to use. "
                             "Needed only in case of generic dataset")
    parser.add_argument("--pretrained_model", type=str, default="gpt2-medium",
                        help="Pretrained model to use as encoder")
    parser.add_argument("--epochs", type=int, default=10, metavar="N",
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=64, metavar="N",
                        help="input batch size for training (default: 64)")
    parser.add_argument("--log_interval", type=int, default=10, metavar="N",
                        help="how many batches to wait before logging training status")
    parser.add_argument("--save_model", action="store_true",
                        help="whether to save the model")
    parser.add_argument("--cached", action="store_true",
                        help="whether to cache the input representations")
    parser.add_argument("--no_cuda", action="store_true",
                        help="use to turn off cuda")
    args = parser.parse_args()

    train_discriminator(**(vars(args)))
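
    # Example invocations (the filename train_discriminator.py is
    # hypothetical; substitute whatever this script is saved as):
    #
    #   python train_discriminator.py --dataset SST --epochs 3 --cached
    #   python train_discriminator.py --dataset generic \
    #       --dataset_fp my_data.tsv --save_model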