from math import inf

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.utils.data
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence

import einops
import wandb
import lightning as L
from torchmetrics.text.rouge import ROUGEScore
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast

import utils

np.random.seed(123)
torch.manual_seed(123)
torch.cuda.manual_seed(123)


def top_p_sampling(logits, p=0.9, temperature=0.5):
    """Sample one token id per row of `logits` using temperature scaling and nucleus (top-p) filtering."""
    logits = logits / temperature

    # Sort so we can find the smallest set of tokens whose cumulative probability exceeds p.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark tokens outside the nucleus, then shift the mask right so the first token that crosses
    # the threshold is kept; this guarantees at least one candidate always survives.
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Map the mask back from sorted order to the original vocabulary order and mask those logits out.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')

    probs = F.softmax(logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1)
    sampled_indices = sampled_indices.squeeze(1)

    return sampled_indices
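

# Illustrative shape check for top_p_sampling (not part of the pipeline; names are made up):
#
#   _logits = torch.randn(4, 50257)                         # (batch, vocab)
#   _ids = top_p_sampling(_logits, p=0.9, temperature=0.5)
#   _ids.shape                                              # -> torch.Size([4]), one token id per row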


class PromptTuningModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        # Frozen GPT-2 backbone: only the soft prompt defined below is trained.
        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)

        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]

        # Grow the embedding matrix to cover the newly added special tokens.
        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        # Initialise the soft prompt from the embeddings of the word 'summarise', which the code
        # assumes splits into three GPT-2 tokens; num_prompts should therefore be a multiple of 3.
        tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
        token_embedding = self.model.transformer.wte(tmp[0]).detach()
        self.token_embedding = token_embedding
        for _ in range(num_prompts // 3 - 1):
            self.token_embedding = torch.cat([self.token_embedding, token_embedding])

        data = torch.zeros(num_prompts, 768) + self.token_embedding
        self.learnable_prompt = nn.Parameter(data, requires_grad=True)
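
    # What actually trains here (illustrative check, not executed anywhere in this file):
    #
    #   _m = PromptTuningModel(num_prompts=6)
    #   [n for n, p in _m.named_parameters() if p.requires_grad]   # -> ['learnable_prompt']
    #
    # i.e. the GPT-2 weights stay frozen and only the soft prompt receives gradient updates.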

    def forward(self, X, y):
        # Prepend the soft prompt to the token embeddings: (B, T, 768) -> (B, num_prompts + T, 768).
        prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)
        embeddings = torch.cat([prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)

        out = self.model(inputs_embeds=embeddings)

        # Drop the logits at the prompt positions so the output aligns with X again.
        logits = out.logits[:, self.num_prompts:]
        return logits

    def generate_new(self, X):
        # Incremental decoding with the KV cache: the prompt + article are fed once, after which
        # only the newly sampled token is passed in at each step.
        batch_size = X.shape[0]
        prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)
        embeddings = torch.cat([prompt[None, :, :].repeat(batch_size, 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        generated_ids = torch.tensor([], dtype=torch.long, device=X.device).view(batch_size, 0)

        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values

            # On the first step the logits still contain the soft-prompt positions; drop them.
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
            else:
                logits = out.logits

            # Never sample the added special tokens ([PAD]/[START], ids >= 50257).
            logits[:, :, 50257:] = -1e4

            next_token_ids = top_p_sampling(logits[:, -1, :])
            generated_ids = torch.cat([generated_ids, next_token_ids.unsqueeze(-1)], dim=-1)

            # Embed only the new token (keeping a sequence dimension); the cache holds the rest.
            embeddings = self.model.transformer.wte(next_token_ids[:, None])

            cnt += 1

            # Stop once every sequence in the batch has produced <|endoftext|>.
            if torch.all((generated_ids == self.eot.item()).any(dim=-1)):
                break

        return generated_ids

    def generate(self, X):
        # Same KV-cache decoding loop as generate_new, but sampling with the temperature that the
        # Lightning wrapper assigns to self.temperature.
        prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)
        embeddings = torch.cat([prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)

        cnt = 0
        past_key_values = None
        final_prediction = torch.tensor([], dtype=torch.long, device=X.device).view(X.shape[0], 0)

        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values

            # Drop the soft-prompt positions on the first step only.
            logits = out.logits[:, self.num_prompts:] if cnt == 0 else out.logits
            # Never sample the added special tokens (ids >= 50257).
            logits[:, :, 50257:] = -1e4

            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)

            # Feed only the new token; the KV cache covers everything generated so far.
            embeddings = self.model.transformer.wte(output)

            cnt += 1

            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction
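

# Illustrative use of the soft-prompt generation loop (made-up example; the temperature attribute
# is normally set by LitModelPromptTuning):
#
#   _model = PromptTuningModel(num_prompts=6)
#   _model.temperature = 0.9
#   _X = _model.tokenizer("Article to summarise ...", return_tensors="pt").input_ids
#   _summary_ids = _model.generate(_X)                     # (batch, <=196) sampled token ids
#   print(_model.tokenizer.batch_decode(_summary_ids))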


class LMModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts

        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)

        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'cls_token': '[START]'})

        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.start = self.tokenizer("[START]", return_tensors="pt").input_ids[0]

        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        # Only the output head is unfrozen; note that GPT-2 ties lm_head.weight to the input
        # embedding matrix, so the shared token embeddings become trainable as well.
        self.model.lm_head.requires_grad_(True)

    def forward(self, X, y):
        embeddings = self.model.transformer.wte(X)
        logits = self.model(inputs_embeds=embeddings).logits
        return logits

    def generate(self, X):
        # Same KV-cache decoding loop as PromptTuningModel.generate, but with no soft prompt,
        # so there are no prompt positions to strip from the logits.
        embeddings = self.model.transformer.wte(X)

        cnt = 0
        past_key_values = None
        final_prediction = torch.tensor([], dtype=torch.long, device=X.device).view(X.shape[0], 0)

        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values

            logits = out.logits
            # Never sample the added special tokens (ids >= 50257).
            logits[:, :, 50257:] = -1e4

            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)

            embeddings = self.model.transformer.wte(output)

            cnt += 1

            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break

        return final_prediction


def zero_after_x(arr, x):
    """
    Replaces, in each row of a 2D tensor, every element from the first occurrence of `x`
    onwards (including that occurrence) with `x`.

    Args:
        arr: The input 2D tensor.
        x: The value that marks where to start overwriting.

    Returns:
        A new tensor with the overwritten rows.
    """
    # True from the first occurrence of x onwards in each row.
    mask = (arr == x).cumsum(dim=1) > 0
    result = torch.where(mask, x, arr)

    return result
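

# Worked example for zero_after_x (illustrative), with x = 50257 (the [PAD] id added above):
#
#   zero_after_x(torch.tensor([[11, 50257, 42],
#                              [ 7,     9,  3]]), 50257)
#   # -> tensor([[11, 50257, 50257],
#   #            [ 7,     9,     3]])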


class LitModelPromptTuning(L.LightningModule):
    def __init__(self, model, temperature, epoch, lr=1e-4, **kwargs):
        super().__init__()
        self.model = model
        self.lr = lr
        self.model.temperature = temperature   # used by the wrapped model's sampling in generate()
        self.epoch = epoch
        self.temperature = temperature

        for key, value in kwargs.items():
            setattr(self, key, value)

        # ROUGE computed over GPT-2 tokens rather than whitespace-split words.
        tokenize_to_strings = lambda text: self.model.tokenizer.convert_ids_to_tokens(self.model.tokenizer(text)["input_ids"])
        self.rouge = ROUGEScore(tokenizer=tokenize_to_strings)

        self.save_hyperparameters(ignore=['model'])
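
    # ROUGE sketch (illustrative): self.rouge(pred_text, target_text) returns a flat dict of
    # torchmetrics keys such as 'rouge1_fmeasure', 'rouge1_precision', 'rouge1_recall',
    # 'rouge2_*', 'rougeL_*' and 'rougeLsum_*', which test_step() passes straight to log_dict().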

    def training_step(self, batch, batch_idx):
        X, y = batch

        logits = self.model(X, y)

        # Mask the added special tokens (ids >= 50257) so they cannot be predicted.
        logits[:, :, 50257:] = -1e4

        # Token-level cross entropy; padded positions ([PAD] = id 50257) are ignored. The targets y
        # are assumed to be aligned with the logits positions by the dataloader in utils.
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]), target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch

        logits = self.model(X, y)
        logits[:, :, 50257:] = -1e4

        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]), target=y[:, :-1].reshape(-1), ignore_index=50257)

        self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def on_test_epoch_start(self):
        self.all_text = []
        self.predicted_text = []

    def test_step(self, batch, batch_idx):
        # Skip the very first test batch.
        if batch_idx == 0:
            return
        X, y = batch

        # Autoregressively generate a summary for each input, then decode prediction and reference.
        out = self.model.generate(X)
        pred = self.model.tokenizer.batch_decode(out, skip_special_tokens=False)
        gt = self.model.tokenizer.batch_decode(y, skip_special_tokens=False)

        print(pred)
        print('GAP')
        print(gt)

        # Compute and log ROUGE for every (prediction, reference) pair in the batch.
        for p, g in zip(pred, gt):
            score = self.rouge(p, g)
            print(score)
            self.log_dict(score, on_step=True, on_epoch=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        # AdamW is built over all parameters, but only those with requires_grad=True (the soft
        # prompt, or the unfrozen lm_head) ever receive gradients; the frozen GPT-2 weights stay fixed.
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer


from lightning.pytorch.loggers import WandbLogger
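
# The main block below either trains or evaluates: with train=True it fits the Lightning wrapper on
# the dataloaders from utils.import_data, runs the ROUGE test loop, and saves the wrapped model
# (model.model) to ./model1.bin; with train=False it reloads that file and only runs the test loop.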

if __name__ == '__main__':
    train = False

    torch.set_float32_matmul_precision('medium')
    dl_train, dl_val, dl_test = utils.import_data(bs=24, fraction=1)

    if train:
        gpt_model = LMModel(num_prompts=12)
        gpt_model = torch.compile(gpt_model)
    else:
        # Reload the full pickled module saved at the end of a training run; newer PyTorch versions
        # may require torch.load('./model1.bin', weights_only=False) here.
        gpt_model = torch.load('./model1.bin')

    model = LitModelPromptTuning(
        model=gpt_model,
        lr=1e-4,
        temperature=0.9,
        epoch=5,
        type_model='lm_head',
    )
    print('Training')

    logger = WandbLogger(project='Anlp-3')
    trainer = L.Trainer(
        accelerator='gpu',
        devices=1,
        default_root_dir='./logs/',
        num_nodes=1,
        num_sanity_val_steps=1,
        precision='bf16-mixed',
        max_epochs=5,
        check_val_every_n_epoch=1,
        log_every_n_steps=20,
        logger=logger,
    )

    if train:
        trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
        trainer.test(model, dataloaders=dl_test)
        torch.save(model.model, './model1.bin')
    else:
        trainer.test(model, dataloaders=dl_test)