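"""Train a discriminator (a small classification head) on top of frozen GPT-2 representations.

The script loads a local GPT-2 (345M) checkpoint, freezes it, and trains a linear head on
mean-pooled hidden states for one of three datasets: SST-5 sentiment, clickbait detection,
or toxicity detection. A CUDA device is required. Example invocation (a sketch; the module's
file name is an assumption, and the dataset files are expected at the paths hard-coded below):

    python train_discriminator.py --dataset-label SST --epochs 10 --batch-size 64 --save-model
"""
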
import os
import sys
import argparse

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torch.autograd import Variable
from torchtext import data as torchtext_data
from torchtext import datasets
from nltk.tokenize.treebank import TreebankWordDetokenizer

torch.manual_seed(0)
np.random.seed(0)

# Make the repository root importable; the pytorch_pretrained_bert package with GPT-2 support
# is expected to live there.
lab_root = os.path.join(os.path.abspath(os.path.dirname(__file__)), '..', '..')
sys.path.insert(1, lab_root)

from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

# GPT-2 (345M) tokenizer and language model, loaded from a local checkpoint directory.
tokenizer = GPT2Tokenizer.from_pretrained('gpt-2_pt_models/345M/')
model = GPT2LMHeadModel.from_pretrained('gpt-2_pt_models/345M/')


class ClassificationHead(torch.nn.Module):
    """Classification head: a single linear layer on top of the transformer's hidden states."""

    def __init__(self, class_size=5, embed_size=2048):
        super(ClassificationHead, self).__init__()
        self.class_size = class_size
        self.embed_size = embed_size
        self.mlp = torch.nn.Linear(embed_size, class_size)

    def forward(self, hidden_state):
        # Map a (batch, embed_size) representation to unnormalized class logits.
        logits = self.mlp(hidden_state)
        return logits

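
# Hedged usage sketch (hypothetical helper, not called anywhere in this script): the head
# simply maps a pooled hidden state to class logits, which the discriminators below turn
# into log-probabilities with log_softmax.
def _classification_head_example():
    head = ClassificationHead(class_size=2, embed_size=1024)
    pooled = torch.randn(4, 1024)                 # a batch of 4 pooled hidden states
    return F.log_softmax(head(pooled), dim=-1)    # shape (4, 2)
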
class Discriminator(torch.nn.Module):
    """Discriminator that appends a learned separator embedding and classifies from the last position."""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.classifierhead = ClassificationHead(embed_size=1024)
        self.model = model
        # Learnable separator embedding, pre-repeated for a fixed batch size of 10.
        self.spltoken = Variable(torch.randn(1, 1, 1024).type(torch.FloatTensor), requires_grad=True)
        self.spltoken = self.spltoken.repeat(10, 1, 1)
        self.spltoken = self.spltoken.cuda()

    def train(self):
        # Overrides nn.Module.train() with a simpler signature: it only freezes the GPT-2 backbone.
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x):
        # forward_embed / forward_transformer_embed are custom hooks assumed to be provided by the
        # local GPT2LMHeadModel (embedding lookup and a pass through the transformer blocks).
        x = model.forward_embed(x)
        x = torch.cat((x, self.spltoken), dim=1)
        _, x = model.forward_transformer_embed(x, add_one=True)
        x = self.classifierhead(x[-1][:, -1, :])
        x = F.log_softmax(x, dim=-1)
        return x


class Discriminator2(torch.nn.Module):
    """Discriminator that sum-pools the transformer's hidden states before classification."""

    def __init__(self, class_size=5, embed_size=1024):
        super(Discriminator2, self).__init__()
        self.classifierhead = ClassificationHead(class_size=class_size, embed_size=embed_size)
        self.model = model
        self.embed_size = embed_size

    def get_classifier(self):
        return self.classifierhead

    def train_custom(self):
        # Freeze the GPT-2 backbone and train only the classification head.
        for param in self.model.parameters():
            param.requires_grad = False
        self.classifierhead.train()

    def forward(self, x):
        x = model.forward_embed(x)
        hidden, x = model.forward_transformer_embed(x)
        # Sum hidden states over the sequence dimension (pad positions included).
        x = torch.sum(hidden, dim=1)
        x = self.classifierhead(x)
        x = F.log_softmax(x, dim=-1)
        return x


class Discriminator2mean(torch.nn.Module):
    """Discriminator that mean-pools hidden states over non-pad positions before classification."""

    def __init__(self, class_size=5, embed_size=1024):
        super(Discriminator2mean, self).__init__()
        self.classifierhead = ClassificationHead(class_size=class_size, embed_size=embed_size)
        self.model = model
        self.embed_size = embed_size

    def get_classifier(self):
        return self.classifierhead

    def train_custom(self):
        # Freeze the GPT-2 backbone and train only the classification head.
        for param in self.model.parameters():
            param.requires_grad = False
        self.classifierhead.train()

    def forward(self, x):
        # Build a (batch, embed_size, seq_len) mask that is 0 at pad positions (token id 0) and 1 elsewhere.
        mask_src = 1 - x.eq(0).unsqueeze(1).type(torch.FloatTensor).cuda().detach()
        mask_src = mask_src.repeat(1, self.embed_size, 1)

        x = model.forward_embed(x)
        hidden, x = model.forward_transformer_embed(x)

        # Zero out hidden states at pad positions, then average over the non-pad positions only.
        hidden = hidden.permute(0, 2, 1)
        hidden = hidden * mask_src
        hidden = hidden.permute(0, 2, 1)
        x = torch.sum(hidden, dim=1) / (torch.sum(mask_src, dim=-1).detach() + 1e-10)

        x = self.classifierhead(x)
        x = F.log_softmax(x, dim=-1)
        return x
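
# Hedged illustration (hypothetical helper, not used elsewhere): the masked mean above is
# equivalent to zeroing pad positions and dividing by the count of real tokens. Values here
# are made up; a CUDA device is assumed, as it is for the rest of this script.
def _masked_mean_example():
    hidden = torch.arange(12., device='cuda').view(1, 4, 3)    # (batch=1, seq_len=4, embed=3)
    tokens = torch.tensor([[11, 22, 0, 0]], device='cuda')     # last two positions are padding (id 0)
    mask = (1 - tokens.eq(0).float()).unsqueeze(-1)             # (1, 4, 1), 1.0 at real tokens
    return (hidden * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-10)   # mean of the first two time steps
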
class Dataset(data.Dataset):
    def __init__(self, X, y):
        """Wraps pre-tokenized sequences X and their labels y."""
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        """Returns one data pair (sequence and label) as a dict."""
        d = {}
        d['X'] = self.X[index]
        d['y'] = self.y[index]
        return d


def collate_fn(data):
    """Sorts a batch by sequence length, zero-pads it, and returns CUDA tensors."""

    def merge(sequences):
        lengths = [len(seq) for seq in sequences]
        # Zero-pad every sequence up to the length of the longest one in the batch.
        padded_seqs = torch.zeros(len(sequences), max(lengths)).long().cuda()
        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
        return padded_seqs, lengths

    # Sort the batch by sequence length, longest first.
    data.sort(key=lambda x: len(x["X"]), reverse=True)

    item_info = {}
    for key in data[0].keys():
        item_info[key] = [d[key] for d in data]

    x_batch, _ = merge(item_info['X'])
    y_batch = item_info['y']

    return x_batch, torch.tensor(y_batch, device='cuda', dtype=torch.long)
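
# Hedged sketch (hypothetical helper, not used by the training pipeline): shows what
# collate_fn produces for a tiny batch. Values are made up; a CUDA device is required,
# as it is for the rest of this script.
def _collate_fn_example():
    batch = [{'X': torch.tensor([5, 7, 9]), 'y': 1},
             {'X': torch.tensor([2, 4]), 'y': 0}]
    x, y = collate_fn(batch)
    # x -> tensor([[5, 7, 9], [2, 4, 0]]) on the GPU, y -> tensor([1, 0])
    return x, y
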
def train_epoch(data_loader, discriminator, device='cuda', args=None, epoch=1):
    # A fresh Adam optimizer is created each epoch; only the classification head receives
    # gradients, since train_custom() freezes the GPT-2 backbone.
    optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)
    discriminator.train_custom()

    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        output = discriminator(data)
        loss = F.nll_loss(output, target)
        loss.backward(retain_graph=True)
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(data_loader.dataset),
                100. * batch_idx / len(data_loader), loss.item()))


def test_epoch(data_loader, discriminator, device='cuda', args=None):
    discriminator.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = discriminator(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up the batch loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(data_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(data_loader.dataset),
        100. * correct / len(data_loader.dataset)))


def main():
    parser = argparse.ArgumentParser(description='Train a discriminator on top of GPT-2 representations')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of training epochs')
    parser.add_argument('--save-model', action='store_true', help='whether to save the model')
    parser.add_argument('--dataset-label', type=str, default='SST', choices=('SST', 'clickbait', 'toxic'),
                        help='dataset to train the discriminator on')
    args = parser.parse_args()

    batch_size = args.batch_size
    device = 'cuda'

    if args.dataset_label == 'SST':
        text = torchtext_data.Field()
        label = torchtext_data.Field(sequential=False)
        train_data, val_data, test_data = datasets.SST.splits(
            text, label, fine_grained=True, train_subtrees=True)
        x = []
        y = []
        d = {"positive": 0, "negative": 1, "very positive": 2, "very negative": 3, "neutral": 4}

        for i in range(len(train_data)):
            seq = TreebankWordDetokenizer().detokenize(vars(train_data[i])["text"])
            seq = tokenizer.encode(seq)
            seq = torch.tensor(seq, device=device, dtype=torch.long)
            x.append(seq)
            y.append(d[vars(train_data[i])["label"]])

        dataset = Dataset(x, y)

        test_x = []
        test_y = []
        for i in range(len(test_data)):
            seq = TreebankWordDetokenizer().detokenize(vars(test_data[i])["text"])
            seq = tokenizer.encode(seq)
            # 50256 is GPT-2's <|endoftext|> token, prepended as a start-of-sequence marker.
            # Note that the training split above is encoded without this prefix.
            seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
            test_x.append(seq)
            test_y.append(d[vars(test_data[i])["label"]])
        test_dataset = Dataset(test_x, test_y)
        discriminator = Discriminator2mean(class_size=5).to(device)

    elif args.dataset_label == 'clickbait':
        # Each line of the file is expected to be a Python-literal dict with 'text' and 'label' keys.
        with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
            data = []
            for d in f:
                try:
                    data.append(eval(d))
                except:
                    continue
        x = []
        y = []
        for d in data:
            try:
                try:
                    seq = tokenizer.encode(d["text"])
                except:
                    continue
                seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
                x.append(seq)
                y.append(d['label'])
            except:
                pass

        dataset = Dataset(x, y)
        train_size = int(0.9 * len(dataset))
        test_size = len(dataset) - train_size
        dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
        discriminator = Discriminator2mean(class_size=2).to(device)

    elif args.dataset_label == 'toxic':
        # Each line of the file is expected to be a Python-literal dict with 'text' and per-category labels.
        with open("datasets/toxic/toxic_train.txt") as f:
            data = []
            for d in f:
                data.append(eval(d))

        x = []
        y = []
        for d in data:
            try:
                seq = tokenizer.encode(d["text"])
                # Keep only short examples (fewer than 100 tokens) and prepend the <|endoftext|> token.
                if len(seq) < 100:
                    seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
                else:
                    continue
                x.append(seq)
                # Any positive sub-label marks the example as toxic.
                y.append(int(np.sum(d['label']) > 0))
            except:
                pass

        dataset = Dataset(x, y)
        print(len(dataset))
        train_size = int(0.9 * len(dataset))
        test_size = len(dataset) - train_size
        dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
        discriminator = Discriminator2mean(class_size=2).to(device)

    data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size,
                                              shuffle=True, collate_fn=collate_fn)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size,
                                              collate_fn=collate_fn)

    for epoch in range(args.epochs):
        train_epoch(discriminator=discriminator, data_loader=data_loader, args=args, device=device, epoch=epoch)
        test_epoch(data_loader=test_loader, discriminator=discriminator, args=args)

        # Probe the discriminator on a fixed positive-sentiment sentence after every epoch.
        seq = tokenizer.encode("This is incredible! I love it, this is the best chicken I have ever had.")
        seq = torch.tensor([seq], device=device, dtype=torch.long)
        print(discriminator(seq))

        if args.save_model:
            torch.save(discriminator.state_dict(),
                       "discrim_models/{}_mean_lin_discriminator_{}.pt".format(args.dataset_label, epoch))
            torch.save(discriminator.get_classifier().state_dict(),
                       "discrim_models/{}_classifierhead.pt".format(args.dataset_label))

    # Final probe with the same sentence after training finishes.
    seq = tokenizer.encode("This is incredible! I love it, this is the best chicken I have ever had.")
    seq = torch.tensor([seq], device=device, dtype=torch.long)
    print(discriminator(seq))
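
# Hedged sketch of reloading a saved classification head (hypothetical helper; the file name
# follows the save pattern used in main(), and class_size/embed_size must match the
# discriminator that produced it):
def _load_classifier_head_example(path="discrim_models/SST_classifierhead.pt"):
    head = ClassificationHead(class_size=5, embed_size=1024)
    head.load_state_dict(torch.load(path))
    head.eval()
    return head
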
if __name__ == '__main__':
    main()