from math import inf
# from utils import *
import torch
import torch.nn as nn
import numpy as np
import torch.utils
import torch.utils.data
# from utils import MyDataset, custom_collate
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
import wandb
import torch.nn.functional as F
import einops
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast

np.random.seed(123)
torch.manual_seed(123)
torch.cuda.random.manual_seed(123)

import lightning as L
import utils
from torchmetrics.text.rouge import ROUGEScore


def top_p_sampling(logits, p=0.9, temperature=0.5):
    # Apply temperature scaling
    logits = logits / temperature

    # Sort logits and get cumulative probabilities
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mask tokens whose cumulative probability exceeds the threshold
    sorted_indices_to_remove = cumulative_probs > p
    # Shift the mask right so the first token above the threshold is kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Scatter the sorted-order mask back to the original vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')  # Set unwanted logits to -inf

    # Sample from the remaining logits
    probs = F.softmax(logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1)
    sampled_indices = sampled_indices.squeeze(1)
    return sampled_indices

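
# Hedged usage sketch (illustrative only, never called by the training code): nucleus
# sampling keeps the smallest set of highest-probability tokens whose cumulative
# probability exceeds p and samples only from that set. The toy 5-token vocabulary
# below is made up for the demonstration.
def _top_p_sampling_demo():
    toy_logits = torch.tensor([[4.0, 3.0, 1.0, 0.5, 0.1]])  # (batch=1, vocab=5)
    draws = torch.tensor([top_p_sampling(toy_logits, p=0.9).item() for _ in range(200)])
    # With p=0.9 and the default temperature of 0.5, only the two most likely token
    # ids (0 and 1) should ever be drawn.
    print(torch.unique(draws))
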

class PromptTuningModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts
        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)  # freeze the backbone; only the soft prompt is trained

        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'additional_special_tokens': ['[START]']})
        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        # Initialise the soft prompt from the embeddings of 'summarise' (3 BPE tokens),
        # tiled to num_prompts rows; num_prompts is assumed to be a multiple of 3.
        tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
        token_embedding = self.model.transformer.wte(tmp[0])
        self.token_embedding = token_embedding
        for _ in range(num_prompts // 3 - 1):
            self.token_embedding = torch.cat([self.token_embedding, token_embedding])
        # print(self.token_embedding.shape)
        data = torch.zeros(num_prompts, 768) + self.token_embedding[:]
        self.learnable_prompt = nn.Parameter(data, requires_grad=True)

    # @torch.compile
    def forward(self, X, y):
        # Work with a device-local view of the prompt instead of reassigning the Parameter.
        prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d
        embeddings = torch.cat([prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)
        # mask = torch.cat([torch.ones([X.shape[0], self.num_prompts], dtype=torch.long).to(X), torch.where(X != self.pad.to(X.device), 1, 0)], dim=1)
        # labels = torch.where(y == 50257, -100, y)
        # ignore = torch.ones([X.shape[0], self.num_prompts], dtype=torch.long, device=X.device) * -100
        # labels = torch.cat([ignore, labels], dim=1)
        out = self.model(inputs_embeds=embeddings)
        # print("Out.loss:", out.loss)
        logits = out.logits[:, self.num_prompts:]  # drop the soft-prompt positions
        return logits

    def generate_new(self, X):
        batch_size = X.shape[0]
        prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)
        embeddings = torch.cat([prompt[None, :, :].repeat(batch_size, 1, 1), embeddings], dim=1)
        cnt = 0
        past_key_values = None
        generated_ids = torch.tensor([], dtype=torch.long, device=X.device).view(batch_size, 0)  # all generated tokens
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
            else:
                logits = out.logits
            logits[:, :, 50257:] = -1e4  # mask the added special tokens (applied after slicing)
            next_token_ids = top_p_sampling(logits[:, -1, :])  # shape (batch_size,)
            generated_ids = torch.cat([generated_ids, next_token_ids.unsqueeze(-1)], dim=-1)
            embeddings = self.model.transformer.wte(next_token_ids[:, None])  # embed the new tokens, shape (b, 1, d)
            cnt += 1
            # Stop once every sequence has produced the end-of-text token
            if torch.all((generated_ids == self.eot.item()).any(dim=-1)):
                break
        return generated_ids

    def generate(self, X):
        # Only bs = 1
        prompt = self.learnable_prompt.to(X.device)
        embeddings = self.model.transformer.wte(X)  # b s d
        embeddings = torch.cat([prompt[None, :, :].repeat(X.shape[0], 1, 1), embeddings], dim=1)
        cnt = 0
        past_key_values = None
        final_prediction = torch.tensor([], dtype=torch.long, device=X.device).view(X.shape[0], 0)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
            else:
                logits = out.logits
            logits[:, :, 50257:] = -1e4  # mask the added special tokens
            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)
            embeddings = self.model.transformer.wte(output)
            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break
        return final_prediction


class LMModel(nn.Module):
    def __init__(self, num_prompts=0):
        super().__init__()
        self.num_prompts = num_prompts
        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)

        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'additional_special_tokens': ['[START]']})
        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        # Resize so the added [PAD]/[START] ids have embedding rows
        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        # Only the LM head is tuned; GPT-2 ties lm_head to the input embeddings,
        # so that shared weight matrix is unfrozen as well.
        self.model.lm_head.requires_grad_(True)

    # @torch.compile
    def forward(self, X, y):
        embeddings = self.model.transformer.wte(X)  # b s d
        logits = self.model(inputs_embeds=embeddings).logits
        return logits

    def generate(self, X):
        # Only bs = 1
        embeddings = self.model.transformer.wte(X)  # b s d
        cnt = 0
        past_key_values = None
        final_prediction = torch.tensor([], dtype=torch.long, device=X.device).view(X.shape[0], 0)
        while cnt < 196:
            out = self.model(inputs_embeds=embeddings, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
            if cnt == 0:
                logits = out.logits[:, self.num_prompts:]
            else:
                logits = out.logits
            logits[:, :, 50257:] = -1e4  # mask the added special tokens
            output = top_p_sampling(logits[:, -1, :], temperature=self.temperature)[:, None]
            final_prediction = torch.cat([final_prediction, output], dim=1)
            embeddings = self.model.transformer.wte(output)
            cnt += 1
            if torch.all((final_prediction == self.eot.item()).any(dim=-1)):
                break
        return final_prediction

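
# Hedged sketch (illustrative helper, not called by the training code): with the backbone
# frozen, only the soft prompt (PromptTuningModel) or the tied LM head / input embeddings
# (LMModel) should remain trainable. This reports the trainable vs. total parameter split
# so that can be checked; the function name is an assumption.
def _count_trainable_parameters(module: nn.Module):
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in module.parameters())
    print(f"trainable: {trainable:,} / total: {total:,}")
    return trainable, total
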

def zero_after_x(tensor, x):
    """
    Overwrites all elements in each row of a 2D tensor after the first occurrence of x
    with x itself (used to truncate generations at the end-of-text id).

    Args:
        tensor: The input 2D tensor.
        x: The value after which elements are overwritten.

    Returns:
        A new tensor with every element after the first x replaced by x.
    """
    mask = (tensor == x).cumsum(dim=1) > 0  # True at and after the first occurrence of x
    result = tensor.where(~mask, torch.ones_like(tensor, dtype=torch.long) * x)
    return result


class LitModelPromptTuning(L.LightningModule):
    def __init__(self, model, lr=1e-4, temperature=0.9):
        super().__init__()
        self.model = model
        self.lr = lr
        self.model.temperature = temperature
        tokenize_to_strings = lambda text: self.model.tokenizer.convert_ids_to_tokens(self.model.tokenizer(text)["input_ids"])
        self.rouge = ROUGEScore(tokenizer=tokenize_to_strings)
        self.save_hyperparameters(ignore=['model'])

    def training_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X, y)
        logits[:, :, 50257:] = -1e4
        # y is assumed to already be aligned with X as next-token targets;
        # positions holding the [PAD] id (50257) are ignored by the loss.
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]),
                               target=y[:, :-1].reshape(-1), ignore_index=50257)
        # print(loss)
        self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X, y)
        logits[:, :, 50257:] = -1e4
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]),
                               target=y[:, :-1].reshape(-1), ignore_index=50257)
        self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def on_test_epoch_start(self):
        self.all_text = []
        self.predicted_text = []

    def test_step(self, batch, batch_idx):
        if batch_idx == 0:  # skip the first test batch
            return
        X, y = batch
        # print(self.model.tokenizer.batch_decode(X))
        out = self.model.generate(X)
        # out = zero_after_x(out, self.model.eot.item())
        pred = self.model.tokenizer.batch_decode(out, skip_special_tokens=True)
        gt = self.model.tokenizer.batch_decode(y, skip_special_tokens=True)
        print(pred)
        print('GAP')
        print(gt)
        for p, g in zip(pred, gt):
            score = self.rouge(p, g)
            print(score)
            self.log_dict(score, on_step=True, on_epoch=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer

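
# Hedged sketch (illustrative only, not called anywhere): shows how ignore_index makes
# padded positions drop out of the cross-entropy used in training_step / validation_step.
# The vocabulary size is GPT-2's 50257 plus the two added special tokens.
def _masked_loss_demo():
    vocab = 50259
    toy_logits = torch.randn(1, 4, vocab)
    toy_targets = torch.tensor([[11, 42, 50257, 50257]])  # last two positions are [PAD]
    loss = F.cross_entropy(toy_logits.reshape(-1, vocab), toy_targets.reshape(-1), ignore_index=50257)
    print(loss)  # only the first two positions contribute to the loss
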

from lightning.pytorch.loggers import WandbLogger

if __name__ == '__main__':
    torch.set_float32_matmul_precision('medium')
    dl_train, dl_val, dl_test = utils.import_data(bs=25, fraction=0.1)

    # gpt_model = PromptTuningModel(num_prompts=12)
    gpt_model = LMModel(num_prompts=0)
    # gpt_model = torch.compile(gpt_model)
    model = LitModelPromptTuning(
        model=gpt_model,
        lr=1e-4,
        temperature=0.9,
    )
    print('Training')

    logger = WandbLogger(project='Anlp-3')
    trainer = L.Trainer(
        accelerator='gpu',
        # limit_train_batches=1,
        # strategy='auto',
        # strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
        devices=[2],
        default_root_dir='./logs/',  # TensorBoard can be pointed here to visualise the logs
        num_nodes=1,
        num_sanity_val_steps=1,  # runs a validation step before starting training
        precision='bf16-mixed',  # half precision to reduce memory usage
        max_epochs=5,
        check_val_every_n_epoch=1,  # run validation every epoch
        log_every_n_steps=20,
        logger=logger,
        # detect_anomaly=True,
    )
    # trainer.test(model, dataloaders=dl_test)
    trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
    trainer.test(model, dataloaders=dl_test)
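
    # Hedged usage sketch (left commented out; the dataloader layout is an assumption):
    # after training, a single summary could be sampled directly from the tuned model, e.g.
    # X, _ = next(iter(dl_test))
    # with torch.no_grad():
    #     ids = gpt_model.generate(X[:1].to(model.device))
    # print(gpt_model.tokenizer.decode(ids[0], skip_special_tokens=True))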