import argparse
import datetime
import gc
import random
import string
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
from sklearn.model_selection import train_test_split
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torch.nn.utils.rnn import pack_padded_sequence
from tqdm import tqdm
from transformers import (AdamW, AutoModelForCausalLM, AutoTokenizer, BertConfig,
                          BertTokenizer, EncoderDecoderConfig, EncoderDecoderModel,
                          GPT2LMHeadModel, GPT2Model, GPT2Tokenizer, GPT2TokenizerFast,
                          GPTNeoConfig, GPTNeoForCausalLM, Trainer, TrainingArguments,
                          pipeline)
# from gensim.models import Word2Vec
# from peft import LoraModel, LoraConfig
class Encoder(torch.nn.Module):
    """Bidirectional LSTM encoder with a residual self-attention block."""
    def __init__(self, h=128, n=8, e=64, a=4, o=1280):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(50257, e)  # GPT-2 vocabulary size
        self.lstm = nn.LSTM(input_size=e, hidden_size=h, num_layers=n,
                            batch_first=True, bidirectional=True)
        self.sa = nn.MultiheadAttention(h * 2, a, dropout=0.1, batch_first=True)
        self.op = nn.Sequential(
            nn.Linear(2 * h, h // 2),
            nn.ReLU(),
            nn.Linear(h // 2, o),
        )

    def forward(self, X):
        emb = self.embed(X)                   # (bs, seq, e)
        enc, (hidden, cell) = self.lstm(emb)  # enc: (bs, seq, 2*h); hidden/cell: (2*n, bs, h)
        # batch_first=True, so the attention block consumes (bs, seq, 2*h) directly
        atOp, atW = self.sa(enc, enc, enc)
        # residual connection around the self-attention output
        logits = self.op(atOp + enc)          # (bs, seq, o)
        return logits, hidden, cell
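# Illustrative shape check (an addition, not part of the original script): builds
# a small throwaway Encoder on CPU and prints the tensor shapes its forward pass
# produces. The function name and sizes are examples only.
def _encoder_shape_demo():
    enc_model = Encoder(h=64, n=2, e=64, a=4, o=64)
    tokens = torch.randint(0, 50257, (2, 10))   # (bs=2, seq=10)
    enc, hidden, cell = enc_model(tokens)
    # enc: (2, 10, 64); hidden/cell: (2*num_layers, bs, h) = (4, 2, 64)
    print(enc.shape, hidden.shape, cell.shape)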
class Decoder(torch.nn.Module):
    """Bidirectional LSTM decoder with cross-attention over the encoder outputs."""
    def __init__(self, h=128, n=8, e=64, a=4, o=50257):
        super(Decoder, self).__init__()
        self.h, self.n = h, n
        self.embed = nn.Embedding(50257, e)  # GPT-2 vocabulary size
        self.lstm = nn.LSTM(input_size=e, hidden_size=h, num_layers=n,
                            batch_first=True, bidirectional=True)
        # embed_dim must match the query width e; with the default kdim/vdim the
        # encoder's output width must also equal e (as in the smoke test below)
        self.sa = nn.MultiheadAttention(e, a, dropout=0.1, batch_first=True)
        self.op = nn.Sequential(
            nn.Linear(2 * h + e, h // 2),
            nn.ReLU(),
            nn.Linear(h // 2, o),
        )

    def forward(self, ip, ho, co, enc, mask):
        emb = self.embed(ip)                      # (bs, seq_i, e)
        dec, (ho, co) = self.lstm(emb, (ho, co))  # (bs, seq_i, 2*h)
        # cross-attention: the decoder embeddings query the encoder outputs
        atOp, atW = self.sa(emb, enc, enc, key_padding_mask=mask)  # (bs, seq_i, e)
        # seq_i is 1 during step-by-step decoding, so squeeze it away
        op = torch.cat([dec.squeeze(dim=1), atOp.squeeze(dim=1)], dim=1)  # (bs, 2*h + e)
        logits = self.op(op)                      # (bs, o)
        return logits, ho, co

    def init_state(self, batch_size):
        device = next(self.parameters()).device
        return (torch.zeros(2 * self.n, batch_size, self.h, device=device),
                torch.zeros(2 * self.n, batch_size, self.h, device=device))
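# Illustrative single decoding step (an addition, not part of the original
# script), using the same sizes as the smoke test at the bottom of the file.
def _decoder_step_demo():
    enc_model = Encoder(h=64, n=2, e=64, a=4, o=64)
    dec_model = Decoder(h=64, n=2, e=64, a=4, o=50257)
    tokens = torch.randint(0, 50257, (1, 10))
    enc, hidden, cell = enc_model(tokens)
    mask = torch.zeros(1, 10, dtype=torch.bool)  # False = attend to every position
    bos = torch.tensor([[50256]])                # (bs=1, seq_i=1) start token
    logits, hidden, cell = dec_model(bos, hidden, cell, enc, mask)
    print(logits.shape)                          # (1, 50257)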
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, seq_ip, ip_mask, seq_tg):
        enc, hidden, cell = self.encoder(seq_ip)
        outputs = []
        len_tg = seq_tg.shape[1]
        dec_ip = seq_tg[:, 0].unsqueeze(dim=-1)
        for t in range(1, len_tg):  # teacher forcing: feed the ground-truth token at each step
            op, hidden, cell = self.decoder(dec_ip, hidden, cell, enc, ip_mask)
            outputs.append(op)
            dec_ip = seq_tg[:, t].unsqueeze(dim=-1)
        outputs = torch.stack(outputs, dim=1)  # (bs, len_tg - 1, vocab)
        return outputs
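# Minimal training-step sketch (an assumption, not code from the original run):
# batch_ip/batch_mask/batch_tg are presumed to come from a DataLoader with the
# shapes used above; target padding is ignored here for simplicity.
def train_step(model, optimizer, batch_ip, batch_mask, batch_tg, device):
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer.zero_grad()
    logits = model(batch_ip.to(device), batch_mask.to(device), batch_tg.to(device))
    # logits: (bs, len_tg - 1, vocab); the targets are the same sequence shifted
    # left by one, mirroring the teacher-forcing loop in Seq2Seq.forward
    loss = criterion(logits.reshape(-1, logits.size(-1)),
                     batch_tg[:, 1:].reshape(-1).to(device))
    loss.backward()
    optimizer.step()
    return loss.item()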
def diverse_beam_search(decoder, encoder_output, ip_mask, hidden, cell, device,
                        beam_width=5, diversity_penalty=0.7, max_len=100):
    """Beam search with a rank-based penalty that pushes beams apart."""
    dec_ip = torch.tensor([50256]).type(torch.int64).to(device)  # start token
    beams = [(0.0, [dec_ip.item()], hidden.clone(), cell.clone())]  # (score, sequence, hidden, cell)
    count = 0
    for _ in range(max_len):
        all_candidates = []
        for score, seq, h, c in beams:
            if seq[-1] == 50256 and count > 0:  # EOS reached; carry the beam over unchanged
                all_candidates.append((score, seq, h, c))
                continue
            dec_out, h_new, c_new = decoder(
                torch.tensor([seq[-1]]).unsqueeze(0).to(device), h, c, encoder_output, ip_mask
            )
            log_probs = torch.nn.functional.log_softmax(dec_out, dim=-1)  # (1, vocab_size)
            top_k_log_probs, top_k_tokens = torch.topk(log_probs, beam_width, dim=-1)
            for i in range(beam_width):
                # rank-based diversity penalty: the i-th best expansion is docked i * penalty
                new_score = score + top_k_log_probs[0, i].item() - (diversity_penalty * i)
                new_seq = seq + [top_k_tokens[0, i].item()]
                all_candidates.append((new_score, new_seq, h_new.clone(), c_new.clone()))
        count = 1  # after the first step, a trailing 50256 counts as EOS
        # keep the top beam_width candidates
        beams = sorted(all_candidates, key=lambda x: x[0], reverse=True)[:beam_width]
        if all(seq[-1] == 50256 for _, seq, _, _ in beams):  # all beams ended
            break
    return beams[0][1]  # highest-scoring sequence
def mbr_decoding(decoder, encoder_output, ip_mask, hidden, cell, device,
                 num_candidates=10, max_len=100):
    """Minimum-Bayes-risk decoding: sample candidates, return the most 'central' one."""
    # generate candidate sequences with top-k sampling for diversity
    candidates = []
    for _ in range(num_candidates):
        dec_ip = torch.tensor([50256]).type(torch.int64).to(device)
        seq = [dec_ip.item()]
        h, c = hidden.clone(), cell.clone()
        for _ in range(max_len):
            dec_out, h, c = decoder(dec_ip.unsqueeze(0), h, c, encoder_output, ip_mask)
            dec_ip = top_k_sampling(dec_out, k=5).unsqueeze(dim=0)
            seq.append(dec_ip.item())
            if dec_ip.item() == 50256:  # EOS
                break
        candidates.append(seq)
    # score each candidate by its average positional token overlap with the
    # others, a crude stand-in for a proper utility/similarity metric
    best_seq, best_score = None, float('-inf')
    for cand in candidates:
        score = sum(sum(1 for t1, t2 in zip(cand, other) if t1 == t2)
                    for other in candidates if other != cand) / (len(candidates) - 1)
        if score > best_score:
            best_score, best_seq = score, cand
    return best_seq
def top_k_sampling(logits, k=10, temperature=1.0):
    """Sample one token id from the k most probable tokens."""
    logits = logits / temperature  # temperature scaling for diversity
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top_k_probs, top_k_indices = torch.topk(probs, k, dim=-1)
    # multinomial accepts unnormalized weights, so the top-k slice needs no renormalizing
    sampled_idx = torch.multinomial(top_k_probs, num_samples=1)
    return top_k_indices[0, sampled_idx.item()]
def genOp(encoder, decoder, device, ip, ip_mask, mode='greedy', temperature=1.0, k=13,
          beam_width=5, diversity_penalty=0.7, num_candidates=10, max_len=100):
    """Generate token ids from an encoded input with the selected decoding strategy."""
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        enc, hidden, cell = encoder(ip)
        if mode == 'greedy':
            outputs = []
            dec_ip = torch.tensor([50256]).type(torch.int64).to(device)
            count = 0
            while True:
                dec, hidden, cell = decoder(dec_ip.unsqueeze(dim=0), hidden, cell, enc, ip_mask)
                dec_ip = torch.argmax(dec, dim=-1)  # always take the most probable token
                outputs.append(dec_ip.item())
                count += 1
                if count > max_len:
                    break
                if dec_ip.item() == 50256:  # EOS: sequence self-terminated
                    break
            return outputs
        elif mode == 'sample':
            outputs = []
            dec_ip = torch.tensor([50256]).type(torch.int64).to(device)
            count = 0
            while True:
                dec, hidden, cell = decoder(dec_ip.unsqueeze(dim=0), hidden, cell, enc, ip_mask)
                dec = dec / temperature  # temperature < 1 sharpens, > 1 flattens
                dec = torch.nn.functional.softmax(dec, dim=-1)
                dec_ip = torch.multinomial(input=dec, num_samples=1, replacement=True).squeeze(0)
                outputs.append(dec_ip.item())
                count += 1
                if count > max_len:
                    break
                if dec_ip.item() == 50256:  # EOS: sequence self-terminated
                    break
            return outputs
        elif mode == 'top_k':
            outputs = []
            dec_ip = torch.tensor([50256]).type(torch.int64).to(device)
            count = 0
            while True:
                dec, hidden, cell = decoder(dec_ip.unsqueeze(dim=0), hidden, cell, enc, ip_mask)
                dec = torch.nn.functional.softmax(dec, dim=-1)
                top_k_probs, top_k_indices = torch.topk(dec, k, dim=-1)
                dec_ip = torch.multinomial(input=top_k_probs, num_samples=1, replacement=True).squeeze(0)
                dec_ip = top_k_indices[0, dec_ip.item()].unsqueeze(dim=0)  # map back to vocab ids
                outputs.append(dec_ip.item())
                count += 1
                if count > max_len:
                    break
                if dec_ip.item() == 50256:  # EOS: sequence self-terminated
                    break
            return outputs
        elif mode == 'diverse-beam-search':
            return diverse_beam_search(decoder, enc, ip_mask, hidden, cell, device,
                                       beam_width=beam_width,
                                       diversity_penalty=diversity_penalty, max_len=max_len)
        elif mode == 'min-bayes-risk':
            return mbr_decoding(decoder, enc, ip_mask, hidden, cell, device,
                                num_candidates=num_candidates, max_len=max_len)
        else:
            raise ValueError(f'unknown decoding mode: {mode}')
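# Hedged convenience helper (an addition, not in the original script): decodes
# generated ids back to text with the GPT-2 tokenizer imported above. Fetching
# the tokenizer needs network access the first time it runs.
def ids_to_text(token_ids):
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
    return tokenizer.decode(token_ids, skip_special_tokens=True)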
if __name__ == '__main__':
    # Smoke test with random token ids standing in for a real tokenized example.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    encoder = Encoder(h=64, n=2, e=64, a=4, o=64).to(device)
    decoder = Decoder(h=64, n=2, e=64, a=4, o=50257).to(device)
    model = Seq2Seq(encoder, decoder).to(device)
    # checkpoint = torch.load('./seq2seq_checkpoint.pt', weights_only=True, map_location=device)
    # model.load_state_dict(checkpoint['model_state_dict'])
    ip = torch.randint(0, 50257, (1, 32)).to(device)
    # key_padding_mask semantics: True marks positions to IGNORE, so an
    # all-False mask attends to every input token
    ip_mask = torch.zeros(1, 32, dtype=torch.bool).to(device)
    print(genOp(model.encoder, model.decoder, device, ip, ip_mask,
                mode='greedy', temperature=1.0, k=13, beam_width=5,
                diversity_penalty=0.7, num_candidates=10, max_len=100))