# gpt2-finetuned / main1.py
from math import inf
# from utils import *
import torch
import torch.nn as nn
import numpy as np
import torch.utils
import torch.utils.data
# from utils import MyDataset, custom_collate
from torch.nn.utils.rnn import pad_sequence,pad_packed_sequence,pack_padded_sequence
import wandb
import torch.nn.functional as F
import einops
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast
np.random.seed(123)
torch.manual_seed(123)
torch.cuda.random.manual_seed(123)
import lightning as L
import utils
from torchmetrics.text.rouge import ROUGEScore
def top_p_sampling(logits, p=0.9, temperature=0.5):
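    """Nucleus (top-p) sampling over the last dimension of `logits`.

    Temperature-scales the logits, keeps the smallest set of tokens whose
    cumulative probability exceeds `p`, and samples one token id per row.
    Expects `logits` of shape (batch, vocab) and returns ids of shape (batch,).
    """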
# Apply temperature scaling
logits = logits / temperature
# Sort logits and get cumulative probabilities
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # Mark tokens whose cumulative probability exceeds the threshold p
    sorted_indices_to_remove = cumulative_probs > p
    # Shift the mask one position to the right so the first token that crosses the threshold is kept
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# Scatter sorted indices to original indices with mask
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = float('-inf') # Set unwanted logits to -inf
# Sample from the remaining logits
probs = F.softmax(logits, dim=-1)
sampled_indices = torch.multinomial(probs, num_samples=1)
sampled_indices = sampled_indices.squeeze(1)
return sampled_indices
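# Illustrative sketch (not part of the original training pipeline): a quick way to
# sanity-check top_p_sampling on random logits; the shapes below are assumptions for the demo.
def _demo_top_p_sampling():
    dummy_logits = torch.randn(4, 50257)  # (batch, vocab) scores, e.g. GPT-2's vocab size
    token_ids = top_p_sampling(dummy_logits, p=0.9, temperature=0.5)
    print(token_ids.shape)  # torch.Size([4]) -- one sampled token id per row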
class PromptTuningModel(nn.Module):
def __init__(self, num_prompts=6):
super().__init__()
self.num_prompts = num_prompts
self.model = AutoModelForCausalLM.from_pretrained("gpt2", )
# self.model.generation_config.cache_implementation = "static"
# self.model.generation_config.max_new_tokens = 256
self.model.requires_grad_(False)
self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
self.tokenizer.add_special_tokens({'cls_token': '[START]'})
self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]
self.model.resize_token_embeddings(new_num_tokens = len(self.tokenizer),)
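        # After adding [PAD] and [START] the vocabulary grows from 50257 to 50259 entries,
        # so id 50257 is [PAD] and 50258 is [START]; these ids are masked out of the logits later.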
tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
token_embedding = self.model.transformer.wte(tmp[0])
self.token_embedding = token_embedding
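        # Tile the 'summarise' token embeddings to fill all soft-prompt rows; the loop assumes
        # the word splits into 3 BPE tokens, i.e. num_prompts should be a multiple of 3.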
for _ in range(num_prompts//3-1):
self.token_embedding = torch.cat([self.token_embedding, token_embedding])
# print(self.token_embedding.shape)
data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
self.learnable_prompt = nn.Parameter(data, requires_grad=True)
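        # Prompt tuning: the base GPT-2 weights stay frozen (requires_grad_(False) above);
        # this soft prompt is the only trainable parameter of the model.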
# self.model.transformer.wte.weight[self.start].requires_grad = True
# @torch.compile
def forward(self, X, y):
self.learnable_prompt = self.learnable_prompt.to(X.device)
embeddings = self.model.transformer.wte(X, ) # b s d
embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)
# mask = torch.cat([torch.ones([X.shape[0],self.num_prompts], dtype=torch.long).to(X), torch.where(X != self.pad.to(X.device), 1, 0)], dim=1)
# print(mask.shape)
# labels = torch.where(y == 50257, -100, y)
# ignore = torch.ones([X.shape[0], self.num_prompts], dtype=torch.long, device=X.device)*-100
# labels = torch.cat([ignore, labels], dim=1)
out = self.model(inputs_embeds = embeddings)
# print("Out.loss:", out.loss)
logits = out.logits[:,self.num_prompts:]
return logits
def generate_new(self, X):
batch_size = X.shape[0]
self.learnable_prompt = self.learnable_prompt.to(X.device)
embeddings = self.model.transformer.wte(X)
embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(batch_size, 1, 1), embeddings], dim=1)
cnt = 0
past_key_values = None
generated_ids = torch.tensor([], dtype=torch.long, device=X.device).view(batch_size, 0) # Store all generated tokens
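        # Incremental decoding: use_cache=True returns past_key_values, so after the first step
        # only the newest token's embedding is fed and the cache supplies all earlier positions.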
while cnt < 196:
out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
past_key_values = out.past_key_values
# print(cnt)
if cnt == 0:
logits = out.logits[:, self.num_prompts:]
else:
logits = out.logits
logits[:, :, 50257:] = -1e4 # Apply after slicing for correct dimensions
next_token_ids = top_p_sampling(logits[:, -1, :])
# next_token_ids will have shape (batch_size,)
            generated_ids = torch.cat([generated_ids, next_token_ids.unsqueeze(-1)], dim=-1)
            embeddings = self.model.transformer.wte(next_token_ids.unsqueeze(-1))  # (batch, 1, d) embedding of the newly sampled token
cnt += 1
            # Stop once every sequence in the batch has produced the <|endoftext|> token
if torch.all((generated_ids == self.eot.item()).any(dim=-1)): # Check each sequence independently
break
return generated_ids
def generate(self, X):
# Only bs = 1
self.learnable_prompt = self.learnable_prompt.to(X.device)
embeddings = self.model.transformer.wte(X, ) # b s d
embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)
cnt = 0
past_key_values = None
final_prediction = torch.tensor([], dtype=torch.long).to(X.device)
while cnt < 196:
out = self.model(inputs_embeds = embeddings, use_cache=True, past_key_values=past_key_values)
# print(cnt, out.logits.shape)
past_key_values = out.past_key_values
if cnt == 0:
logits = out.logits[:,self.num_prompts:]
logits[:,:, 50257:] = -1e4
output = top_p_sampling(logits[:,-1,:], temperature=self.temperature)[:,None]
# print(output.shape)
final_prediction = torch.cat([final_prediction, output], dim=1)
# print(output.shape)
embeddings = self.model.transformer.wte(output)
# print(embeddings.shape)
else:
# print(logits.shape)
logits = out.logits
logits[:, :, 50257:] = -1e4
output = top_p_sampling(logits[:,-1,:], temperature=self.temperature)[:,None]
# print(output)
final_prediction = torch.cat([final_prediction, output], dim=1)
# print(final_prediction.shape, 'final')
embeddings = self.model.transformer.wte(output)
cnt += 1
# print(output.shape, self.eot.shape)
if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
break
return final_prediction
class LMModel(nn.Module):
def __init__(self, num_prompts=6):
super().__init__()
self.num_prompts = num_prompts
self.model = AutoModelForCausalLM.from_pretrained("gpt2", )
# self.model.generation_config.cache_implementation = "static"
# self.model.generation_config.max_new_tokens = 256
self.model.requires_grad_(False)
self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
self.tokenizer.add_special_tokens({'cls_token': '[START]'})
self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]
self.model.resize_token_embeddings(new_num_tokens = len(self.tokenizer),)
self.model.lm_head.requires_grad_(True)
# @torch.compile
def forward(self, X, y):
embeddings = self.model.transformer.wte(X, ) # b s d
logits = self.model(inputs_embeds = embeddings).logits
return logits
def generate(self, X):
# Only bs = 1
# self.learnable_prompt = self.learnable_prompt.to(X.device)
embeddings = self.model.transformer.wte(X, ) # b s d
# embeddings = torch.cat([self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)
cnt = 0
past_key_values = None
final_prediction = torch.tensor([], dtype=torch.long).to(X.device)
while cnt < 196:
out = self.model(inputs_embeds = embeddings, use_cache=True, past_key_values=past_key_values)
# print(cnt, out.logits.shape)
past_key_values = out.past_key_values
if cnt == 0:
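                # Note: LMModel prepends no soft prompt, so this slice (carried over from
                # PromptTuningModel.generate) is harmless: only the last position is sampled below.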
logits = out.logits[:,self.num_prompts:]
logits[:,:, 50257:] = -1e4
output = top_p_sampling(logits[:,-1,:], temperature=self.temperature)[:,None]
# print(output.shape)
final_prediction = torch.cat([final_prediction, output], dim=1)
# print(output.shape)
embeddings = self.model.transformer.wte(output)
# print(embeddings.shape)
else:
# print(logits.shape)
logits = out.logits
logits[:, :, 50257:] = -1e4
output = top_p_sampling(logits[:,-1,:], temperature=self.temperature)[:,None]
# print(output)
final_prediction = torch.cat([final_prediction, output], dim=1)
# print(final_prediction.shape, 'final')
embeddings = self.model.transformer.wte(output)
cnt += 1
# print(output.shape, self.eot.shape)
if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
break
return final_prediction
def zero_after_x(arr, x):
"""
Zeros out all elements in each row of a 2D tensor after the first occurrence of x.
Args:
tensor: The input 2D tensor.
x: The value after which to zero out elements.
Returns:
A new tensor with elements zeroed out after x.
"""
mask = (arr == x).cumsum(dim=1) > 0 # Create a cumulative mask
result = torch.where(mask, x, arr) #zero out where mask is True
return result
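# Illustrative sketch (not part of the original pipeline): zero_after_x replaces everything
# from the first occurrence of x onward, e.g. to mask tokens generated after <|endoftext|>.
def _demo_zero_after_x():
    arr = torch.tensor([[3, 7, 5, 2, 9],
                        [5, 1, 1, 1, 1]])
    x = torch.tensor(5)
    print(zero_after_x(arr, x))  # tensor([[3, 7, 5, 5, 5], [5, 5, 5, 5, 5]])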
class LitModelPromptTuning(L.LightningModule):
def __init__(self, model, temperature, epoch, lr=1e-4, **kwargs):
super().__init__()
self.model = model
self.lr = lr
self.model.temperature = temperature
self.epoch = epoch
self.temperature = temperature
for key, value in kwargs.items():
setattr(self, key, value)
tokenize_to_strings = lambda text: self.model.tokenizer.convert_ids_to_tokens(self.model.tokenizer(text)["input_ids"])
self.rouge = ROUGEScore(tokenizer=tokenize_to_strings)
self.save_hyperparameters(ignore=['model'])
def training_step(self, batch, batch_idx):
X, y = batch
# for i,j in zip(X[1], y[1]):
# print(i.item(),j.item())
logits = self.model(X, y)
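        # Ids >= 50257 are the added [PAD]/[START] tokens: mask them out of the predictions
        # and ignore [PAD] (50257) targets in the cross-entropy below.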
logits[:,:, 50257:] = -1e4
# print(X.shape, y.shape, logits.shape)
loss = F.cross_entropy(logits[:,:-1,:].reshape(-1, logits.shape[-1]), target=y[:,:-1].reshape(-1), ignore_index=50257)
# print(loss)
# prob = logits.softmax(dim=-1)[0,-300:]
# target = y[0, -300:]
# print('logits',logits[0,-300:][torch.arange(target.numel()), target])
# print(prob[torch.arange(target.numel()), target])
# print(prob.argmax(dim=-1), target, X[0, -300:])
# print(self.model.tokenizer.decode(prob.argmax(dim=-1)), 'gap', self.model.tokenizer.decode(target))
# print(self.model.tokenizer.decode(X[0,-300:]))
# x = F.cross_entropy(logits[0,:-1,:].reshape(-1, logits.shape[-1]), target=y[0,:-1].reshape(-1), ignore_index=50257, reduction='none')
# print(x[-20:])
# print(self.model.pad)
# print(X[0, -300:].shape, target.shape, prob.argmax(dim=-1).shape)
# for i,j,k in zip(X[0, -300:], target, prob.argmax(dim=-1)):
# print(self.model.tokenizer.decode(i),'\tx ',self.model.tokenizer.decode(j),'\tx ',self.model.tokenizer.decode(k))
# exit()
self.log('Training loss', loss, on_step=True, on_epoch=True,logger=True, sync_dist=True)
return loss
def validation_step(self, batch, batch_idx):
X, y = batch
logits = self.model(X, y)
logits[:,:, 50257:] = -1e4
# print(X.shape, y.shape, logits.shape)
loss = F.cross_entropy(logits[:,:-1,:].reshape(-1, logits.shape[-1]), target=y[:,:-1].reshape(-1), ignore_index=50257)
self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
return loss
def on_test_epoch_start(self, ):
self.all_text = []
self.predicted_text = []
def test_step(self, batch, batch_idx):
if batch_idx == 0:
return
X, y = batch
# print(self.model.tokenizer.batch_decode(X))
# print(X.shape)
# with torch.no_grad()
out = self.model.generate(X)
# out = zero_after_x(out, self.model.eot.to(X.device))
# print(out.shape, y.shape)
# print(out, y)
pred = self.model.tokenizer.batch_decode(out, skip_special_tokens=False)
gt = self.model.tokenizer.batch_decode(y, skip_special_tokens=False)
print(pred)
print('GAP')
print(gt)
final_score = 0
for p,g in zip(pred, gt):
score = self.rouge(p, g, )
print(score)
# exit()
self.log_dict(score, on_step=True, on_epoch=True, logger=True, sync_dist=True)
def configure_optimizers(self):
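        # All parameters are handed to AdamW, but the frozen GPT-2 weights never receive
        # gradients; only the trainable pieces (soft prompt / lm_head) are actually updated.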
optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
return optimizer
from lightning.pytorch.loggers import WandbLogger
if __name__ == '__main__':
train = False
torch.set_float32_matmul_precision('medium')
dl_train, dl_val, dl_test = utils.import_data(bs=24, fraction=1)
# gpt_model = PromptTuningModel(num_prompts=24)
if train:
gpt_model = LMModel(num_prompts=12)
gpt_model = torch.compile(gpt_model)
else:
gpt_model = torch.load('./model1.bin')
model = LitModelPromptTuning(
model=gpt_model,
lr=1e-4,
temperature=0.9,
epoch = 5,
type_model = 'lm_head'
)
print('Training')
logger = WandbLogger(project='Anlp-3')
trainer = L.Trainer(
accelerator='gpu',
# limit_train_batches=1,
# strategy='auto',
# strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
devices=1,
        default_root_dir='./logs/',  # TensorBoard can be used to visualize these logs
        num_nodes=1,
        num_sanity_val_steps=1, # runs a validation step before starting training
        precision='bf16-mixed', # bfloat16 mixed precision to reduce memory usage
max_epochs=5,
check_val_every_n_epoch=1, # run validation every epoch
log_every_n_steps=20,
logger=logger,
# detect_anomaly=True,
)
if train:
trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
trainer.test(model, dataloaders=dl_test)
torch.save(model.model, './model1.bin')
else:
trainer.test(model, dataloaders=dl_test)