#! /usr/bin/env python3
# coding=utf-8
# This code is licensed under a non-commercial license.
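
# Trains a discriminator (a small classification head) on top of frozen GPT-2
# representations; run directly, e.g.:
#   python gpt2tunediscrim.py --dataset-label SST --epochs 10 --save-model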
import os
import sys
import argparse

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torch.autograd import Variable
from torchtext import data as torchtext_data
from torchtext import datasets
from nltk.tokenize.treebank import TreebankWordDetokenizer

torch.manual_seed(0)
np.random.seed(0)

# Make the repository root importable so the local pytorch_pretrained_bert
# (expected to provide the forward_embed / forward_transformer_embed methods
# used below) resolves.
lab_root = os.path.join(os.path.abspath(os.path.dirname(__file__)), '..', '..')
sys.path.insert(1, lab_root)

from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
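
# Module-level load of GPT-2 345M from a local directory of weights converted
# to pytorch_pretrained_bert format.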
tokenizer = GPT2Tokenizer.from_pretrained('gpt-2_pt_models/345M/')
model = GPT2LMHeadModel.from_pretrained('gpt-2_pt_models/345M/')
class ClassificationHead(torch.nn.Module):
    """Single-layer classification head mapping a hidden state to class logits."""

    def __init__(self, class_size=5, embed_size=2048):
        super(ClassificationHead, self).__init__()
        self.class_size = class_size
        self.embed_size = embed_size
        self.mlp = torch.nn.Linear(embed_size, class_size)

    def forward(self, hidden_state):
        logits = self.mlp(hidden_state)
        return logits
class Discriminator(torch.nn.Module):
    """Classifies a sequence from the hidden state at an appended special token."""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.classifierhead = ClassificationHead()
        self.model = model
        # Learnable special-token embedding appended to every sequence. Note
        # that after repeat()/cuda() this is no longer a leaf tensor, so
        # gradients do not reach the original random initialization.
        self.spltoken = Variable(torch.randn(1, 1, 1024).type(torch.FloatTensor), requires_grad=True)
        self.spltoken = self.spltoken.repeat(10, 1, 1)
        self.spltoken = self.spltoken.cuda()

    def train_custom(self):
        # Freeze the GPT-2 backbone; only the classifier head is trained.
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.model.forward_embed(x)
        x = torch.cat((x, self.spltoken), dim=1)
        _, x = self.model.forward_transformer_embed(x, add_one=True)
        x = self.classifierhead(x[-1][:, -1, :])
        x = F.log_softmax(x, dim=-1)
        return x
class Discriminator2(torch.nn.Module):
    """Classifies a sequence from the sum of its hidden states."""

    def __init__(self, class_size=5, embed_size=1024):
        super(Discriminator2, self).__init__()
        self.classifierhead = ClassificationHead(class_size=class_size, embed_size=embed_size)
        self.model = model
        self.embed_size = embed_size

    def get_classifier(self):
        return self.classifierhead

    def train_custom(self):
        # Freeze the GPT-2 backbone; only the classifier head is trained.
        for param in self.model.parameters():
            param.requires_grad = False
        self.classifierhead.train()

    def forward(self, x):
        x = self.model.forward_embed(x)
        hidden, x = self.model.forward_transformer_embed(x)
        # Sum hidden states over the sequence dimension (includes padding).
        x = torch.sum(hidden, dim=1)
        x = self.classifierhead(x)
        x = F.log_softmax(x, dim=-1)
        return x
class Discriminator2mean(torch.nn.Module):
    """Classifies a sequence from the padding-masked mean of its hidden states."""

    def __init__(self, class_size=5, embed_size=1024):
        super(Discriminator2mean, self).__init__()
        self.classifierhead = ClassificationHead(class_size=class_size, embed_size=embed_size)
        self.model = model
        self.embed_size = embed_size

    def get_classifier(self):
        return self.classifierhead

    def train_custom(self):
        # Freeze the GPT-2 backbone; only the classifier head is trained.
        for param in self.model.parameters():
            param.requires_grad = False
        self.classifierhead.train()

    def forward(self, x):
        # Mask of non-pad positions; collate_fn pads with token id 0, so id 0
        # is treated as padding here. Shape: batch x embed_size x length.
        mask_src = 1 - x.eq(0).unsqueeze(1).type(torch.FloatTensor).cuda().detach()
        mask_src = mask_src.repeat(1, self.embed_size, 1)

        x = self.model.forward_embed(x)
        hidden, x = self.model.forward_transformer_embed(x)

        # hidden: batch x length x embed_size -> zero out padded positions.
        hidden = hidden.permute(0, 2, 1)
        hidden = hidden * mask_src
        hidden = hidden.permute(0, 2, 1)

        # Mean over non-pad positions (epsilon guards against empty rows).
        x = torch.sum(hidden, dim=1) / (torch.sum(mask_src, dim=-1).detach() + 1e-10)
        x = self.classifierhead(x)
        x = F.log_softmax(x, dim=-1)
        return x
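
# A minimal sketch (not called anywhere in this script) of reloading the
# classifier head that main() saves with --save-model; the default path
# mirrors the torch.save call below for the SST label.
def load_classifier_head(path="discrim_models/SST_classifierhead.pt",
                         class_size=5, embed_size=1024):
    head = ClassificationHead(class_size=class_size, embed_size=embed_size)
    head.load_state_dict(torch.load(path))
    head.eval()
    return head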
class Dataset(data.Dataset):
    def __init__(self, X, y):
        """Wraps in-memory lists of token sequences and labels."""
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        d = {}
        d['X'] = self.X[index]
        d['y'] = self.y[index]
        return d
def collate_fn(batch):
    def merge(sequences):
        lengths = [len(seq) for seq in sequences]
        padded_seqs = torch.zeros(len(sequences), max(lengths)).long().cuda()  # padding index 0
        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
        return padded_seqs, lengths

    batch.sort(key=lambda x: len(x["X"]), reverse=True)  # sort by source sequence length
    item_info = {}
    for key in batch[0].keys():
        item_info[key] = [d[key] for d in batch]

    x_batch, _ = merge(item_info['X'])
    y_batch = item_info['y']
    return x_batch, torch.tensor(y_batch, device='cuda', dtype=torch.long)
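
# e.g. two sequences of token lengths 3 and 5 collate into a (2, 5) LongTensor,
# right-padded with zeros, plus a (2,) label tensor, both on the GPU.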
def train_epoch(data_loader, discriminator, device='cuda', args=None, epoch=1):
    # The optimizer is re-created on every call, so Adam's moment estimates
    # reset at each epoch boundary.
    optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)
    discriminator.train_custom()

    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = discriminator(data)
        loss = F.nll_loss(output, target)
        loss.backward(retain_graph=True)
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(data_loader.dataset),
                100. * batch_idx / len(data_loader), loss.item()))
def test_epoch(data_loader, discriminator, device='cuda', args=None):
    discriminator.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = discriminator(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(data_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(data_loader.dataset),
        100. * correct / len(data_loader.dataset)))
def main():
parser = argparse.ArgumentParser(description='Train a discriminator on top of GPT-2 representations')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='Number of training epochs')
parser.add_argument('--save-model', action='store_true', help='whether to save the model')
    parser.add_argument('--dataset-label', type=str, default='SST', choices=('SST', 'clickbait', 'toxic'))
args = parser.parse_args()
batch_size = args.batch_size
device = 'cuda'
    # Load SST-5 (with subtrees) and fit a 5-way sentiment discriminator.
    if args.dataset_label == 'SST':
        text = torchtext_data.Field()
        label = torchtext_data.Field(sequential=False)
        train_data, val_data, test_data = datasets.SST.splits(text, label, fine_grained=True, train_subtrees=True,
                                                              # filter_pred=lambda ex: ex.label != 'neutral'
                                                              )
        x = []
        y = []
        d = {"positive": 0, "negative": 1, "very positive": 2, "very negative": 3, "neutral": 4}
        for i in range(len(train_data)):
            seq = TreebankWordDetokenizer().detokenize(vars(train_data[i])["text"])
            seq = tokenizer.encode(seq)
            seq = torch.tensor(seq, device=device, dtype=torch.long)
            x.append(seq)
            y.append(d[vars(train_data[i])["label"]])
        dataset = Dataset(x, y)

        test_x = []
        test_y = []
        for i in range(len(test_data)):
            seq = TreebankWordDetokenizer().detokenize(vars(test_data[i])["text"])
            seq = tokenizer.encode(seq)
            # Test sequences get the end-of-text prefix (id 50256); the
            # training sequences above do not.
            seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
            test_x.append(seq)
            test_y.append(d[vars(test_data[i])["label"]])
        test_dataset = Dataset(test_x, test_y)
        discriminator = Discriminator2mean(class_size=5).to(device)
    elif args.dataset_label == 'clickbait':
        with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
            data = []
            for d in f:
                # Each line holds a Python-literal dict with "text" and "label".
                try:
                    data.append(eval(d))
                except:
                    continue

        x = []
        y = []
        for d in data:
            try:
                try:
                    seq = tokenizer.encode(d["text"])
                except:
                    continue
                # Prefix with the GPT-2 end-of-text token (id 50256).
                seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
                x.append(seq)
                y.append(d['label'])
            except:
                pass
        dataset = Dataset(x, y)

        train_size = int(0.9 * len(dataset))
        test_size = len(dataset) - train_size
        dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
        discriminator = Discriminator2mean(class_size=2).to(device)
    elif args.dataset_label == 'toxic':
        with open("datasets/toxic/toxic_train.txt") as f:
            data = []
            for d in f:
                # Each line holds a Python-literal dict with "text" and "label".
                data.append(eval(d))

        x = []
        y = []
        for d in data:
            try:
                seq = tokenizer.encode(d["text"])
                # Keep only short sequences; prefix with end-of-text (id 50256).
                if len(seq) < 100:
                    seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
                else:
                    continue
                x.append(seq)
                # Any positive toxicity sub-label counts as the toxic class.
                y.append(int(np.sum(d['label']) > 0))
            except:
                pass
        dataset = Dataset(x, y)
        print(len(dataset))

        train_size = int(0.9 * len(dataset))
        test_size = len(dataset) - train_size
        dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
        discriminator = Discriminator2mean(class_size=2).to(device)
data_loader = torch.utils.data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size, collate_fn=collate_fn)
for epoch in range(args.epochs):
train_epoch(discriminator=discriminator, data_loader=data_loader, args=args, device=device, epoch=epoch)
test_epoch(data_loader=test_loader, discriminator=discriminator, args=args)
        # Sanity check: score a sample sentence with the current discriminator.
        seq = tokenizer.encode("This is incredible! I love it, this is the best chicken I have ever had.")
        seq = torch.tensor([seq], device=device, dtype=torch.long)
        print(discriminator(seq))

        if args.save_model:
            torch.save(discriminator.state_dict(),
                       "discrim_models/{}_mean_lin_discriminator_{}.pt".format(args.dataset_label, epoch))
            torch.save(discriminator.get_classifier().state_dict(),
                       "discrim_models/{}_classifierhead.pt".format(args.dataset_label))
if __name__ == '__main__':
main()