from math import inf
# from utils import *
import torch
import torch.nn as nn
import numpy as np
import torch.utils
import torch.utils.data
# from utils import MyDataset, custom_collate
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
import wandb
import torch.nn.functional as F
import einops
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2TokenizerFast

np.random.seed(123)
torch.manual_seed(123)
torch.cuda.random.manual_seed(123)

import lightning as L
import utils


class PromptTuningModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts
        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)  # freeze the GPT-2 backbone; only the soft prompt is trained
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # '[START]' was previously registered under 'pad_token' as well, which overwrote
        # '[PAD]'; register it as an additional special token instead.
        self.tokenizer.add_special_tokens({'additional_special_tokens': ['[START]']})
        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer), pad_to_multiple_of=128)
        self.model.requires_grad_(False)  # re-freeze in case resizing re-created the embedding weights

        # Initialise the soft prompt from the embedding of the word 'summarise',
        # tiled so it covers all num_prompts positions.
        tmp = self.tokenizer('summarise', return_tensors="pt").input_ids
        token_embedding = self.model.transformer.wte(tmp[0])  # (n_tokens, 768)
        reps = -(-num_prompts // token_embedding.shape[0])    # ceil(num_prompts / n_tokens)
        data = token_embedding.repeat(reps, 1)[:num_prompts].clone()
        self.learnable_prompt = nn.Parameter(data, requires_grad=True)

    # @torch.compile
    def forward(self, X, y):
        # Lightning moves the module (including the soft prompt) to the correct device,
        # so no manual .to(X.device) reassignment of the parameter is needed here.
        embeddings = self.model.transformer.wte(X)  # (b, s, d)
        prompt = self.learnable_prompt[None, :, :].repeat(X.shape[0], 1, 1)
        # Prepend the soft prompt so the -100 labels below and the logits[:, num_prompts:]
        # slice line up with the input tokens.
        embeddings = torch.cat([prompt, embeddings], dim=1)
        # mask = torch.cat([torch.ones([X.shape[0], self.num_prompts], dtype=torch.long).to(X),
        #                   torch.where(X != self.pad.to(X.device), 1, 0)], dim=1)
        # print(mask.shape)
        # Labels with the pad id (50257) and the prompt positions masked out; the loss is
        # computed in the LightningModule, but passing labels=labels to the call below
        # would let Hugging Face compute out.loss as well.
        labels = torch.where(y == 50257, -100, y)
        ignore = torch.ones([X.shape[0], self.num_prompts], dtype=torch.long, device=X.device) * -100
        labels = torch.cat([ignore, labels], dim=1)
        out = self.model(inputs_embeds=embeddings)
        # print("Out.loss:", out.loss)
        logits = out.logits[:, self.num_prompts:]
        return logits


class LMModel(nn.Module):
    def __init__(self, num_prompts=6):
        super().__init__()
        self.num_prompts = num_prompts
        self.model = AutoModelForCausalLM.from_pretrained("gpt2")
        self.model.requires_grad_(False)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.add_special_tokens({'additional_special_tokens': ['[START]']})
        self.eot = self.tokenizer("<|endoftext|>", return_tensors="pt").input_ids[0]
        self.pad = self.tokenizer("[PAD]", return_tensors="pt").input_ids[0]
        # Resize so the added [PAD]/[START] ids are within range of the embedding table.
        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer), pad_to_multiple_of=128)
        # Note: GPT-2 ties lm_head to the input embedding, so this also unfreezes wte.
        self.model.lm_head.requires_grad_(True)

    # @torch.compile
    def forward(self, X, y=None):  # y is accepted for interface compatibility with LitModelPromptTuning
        embeddings = self.model.transformer.wte(X)  # (b, s, d)
        # The slice mirrors PromptTuningModel even though no prompt is prepended here.
        logits = self.model(inputs_embeds=embeddings).logits[:, self.num_prompts:]
        return logits


class LitModelPromptTuning(L.LightningModule):
    def __init__(self, model, lr=1e-4):
        super().__init__()
        self.model = model
        self.lr = lr
        self.save_hyperparameters(ignore=['model'])

    def training_step(self, batch, batch_idx):
        X, y = batch
        # for i, j in zip(X[1], y[1]):
        #     print(i.item(), j.item())
        logits = self.model(X, y)
        # print(X.shape, y.shape, logits.shape)
        # Assumes the dataloader already provides y shifted by one position relative to X,
        # so logits and y are aligned token-for-token.
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]),
                               target=y[:, :-1].reshape(-1),
                               ignore_index=50257)
        # Debugging block (per-token probabilities for the first example), left commented
        # out so training proceeds past the first step instead of hitting exit().
        # print(loss)
        # prob = logits.softmax(dim=-1)[0, -20:]
        # target = y[0, -20:]
        # print(prob[torch.arange(target.numel()), target])
        # print(prob.argmax(dim=-1), target)
        # print(self.model.tokenizer.decode(prob.argmax(dim=-1)), self.model.tokenizer.decode(target))
        # x = F.cross_entropy(logits[0, :-1, :].reshape(-1, logits.shape[-1]),
        #                     target=y[0, :-1].reshape(-1), ignore_index=50257, reduction='none')
        # print(x)
        # exit()
        self.log('Training loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        logits = self.model(X, y)
        # print(X.shape, y.shape, logits.shape)
        loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]),
                               target=y[:, :-1].reshape(-1),
                               ignore_index=50257)
        self.log('Validation loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        # Only optimise parameters that actually require gradients (the soft prompt).
        optimizer = torch.optim.AdamW([p for p in self.parameters() if p.requires_grad], lr=self.lr)
        return optimizer


from lightning.pytorch.loggers import WandbLogger


if __name__ == '__main__':
    torch.set_float32_matmul_precision('medium')
    dl_train, dl_val, dl_test = utils.import_data(bs=5, fraction=0.1)
    gpt_model = PromptTuningModel(num_prompts=12)
    # gpt_model = LMModel(num_prompts=12)
    # gpt_model = torch.compile(gpt_model)
    model = LitModelPromptTuning(model=gpt_model, lr=1e-4)
    print('Training')
    logger = WandbLogger(project='Anlp-3')
    trainer = L.Trainer(
        accelerator='gpu',
        limit_train_batches=1,        # debug setting: only one training batch per epoch
        # strategy='auto',
        # strategy=pl.strategies.DDPStrategy(find_unused_parameters=True),
        devices=[3],
        default_root_dir='./logs/',   # TensorBoard can be used to visualise these logs
        num_nodes=1,
        num_sanity_val_steps=1,       # runs a validation step before starting training
        precision='16-mixed',         # half precision to reduce memory usage
        max_epochs=100,
        check_val_every_n_epoch=50,   # run validation every 50 epochs
        log_every_n_steps=20,
        logger=logger,
        # detect_anomaly=True,
    )
    trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
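
    # Optional sanity check (a sketch, not part of the original script): with the GPT-2
    # backbone frozen, only the soft prompt should show up among the trainable parameters.
    trainable = [name for name, p in gpt_model.named_parameters() if p.requires_grad]
    print(f'{len(trainable)} trainable parameter tensor(s):', trainable)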