# ./experiments/experiment1/train.py
import logging
import pickle
import sqlite3

import torch
import torchvision
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import transformers
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from transformers import GPT2TokenizerFast
from datasets import load_from_disk

from model import RotationallyInvariantGPT, RotationallyInvariantGPTConfig
from prereqs.nanoGPT.model import GPTConfig, GPT, MLP


def pad_collate(batch):
    """Pad a batch of variable-length sequences and stack them into tensors."""
    # Separate inputs and labels
    inputs = [d['input_ids'] for d in batch]
    labels = [d['labels'] for d in batch]

    # Pad the input sequences
    input_tensor = pad_sequence(inputs, batch_first=True)

    # Pad the label sequences
    label_tensor = pad_sequence(labels, batch_first=True)

    return {'input_ids': input_tensor, 'labels': label_tensor}


class DatabaseInterface(object):
    """Thin wrapper around the SQLite database holding the plain-text splits."""

    def __init__(self, db_file):
        self.db_file = db_file

    def read(self, split):
        conn = sqlite3.connect(self.db_file)
        c = conn.cursor()
        # Parameterized query: safer than interpolating `split` into the SQL string
        c.execute("SELECT * FROM plain_text WHERE split=?", (split,))
        col_names = [desc[0] for desc in c.description]  # get column names
        results = [dict(zip(col_names, row)) for row in c.fetchall()]  # convert tuples to dictionaries
        conn.close()
        return results


class PlainTextDataset(torch.utils.data.Dataset):
    def __init__(self, plain_text_dataset, tokenizer, device):
        self.plain_text_dataset = plain_text_dataset
        self.tokenizer = tokenizer
        self.device = device

    def __len__(self):
        return len(self.plain_text_dataset)

    def __getitem__(self, idx):
        item = self.plain_text_dataset[idx]
        tokens = self.tokenizer.encode_plus(
            item["text"], truncation=True, max_length=512, padding="max_length"
        )
        input_ids = tokens["input_ids"]
        attention_mask = tokens["attention_mask"]
        # Shift by one position so the labels are the next-token targets
        return {
            'input_ids': torch.as_tensor(input_ids[:-1], dtype=torch.long).to(self.device),
            'attention_mask': torch.as_tensor(attention_mask[:-1], dtype=torch.long).to(self.device),
            'labels': torch.as_tensor(input_ids[1:], dtype=torch.long).to(self.device)
        }


def train(model: nn.Module, optimizer: optim.Optimizer, train_loader: DataLoader) -> float:
    model.train()
    running_loss = 0
    for i, batch in enumerate(train_loader):
        inputs, targets = batch['input_ids'].to(device), batch['labels'].to(device)
        optimizer.zero_grad()
        outputs, loss = model(inputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 0:
            logging.info(f"Batch {i}: Loss={loss.item()}")
    return running_loss / len(train_loader)


def evaluate(model: nn.Module, valid_loader: DataLoader) -> float:
    model.eval()
    running_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(valid_loader):
            inputs, targets = batch['input_ids'].to(device), batch['labels'].to(device)
            # The model returns (logits, loss), matching the call signature used in train()
            outputs, loss = model(inputs, targets)
            running_loss += loss.item()
            if i % 100 == 0:
                logging.info(f"Batch {i}: Validation Loss={loss.item()}")
    return running_loss / len(valid_loader)


if __name__ == '__main__':
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO
    )

    logging.info(f"PyTorch version: {torch.__version__}")
    logging.info(f"Torchvision version: {torchvision.__version__}")
    logging.info(f"Transformers version: {transformers.__version__}")
    logging.info(f"CUDA version: {torch.version.cuda}")
    logging.info(f"cuDNN version: {torch.backends.cudnn.version()}")

    logging.info("Clearing cuda cache...")
    torch.cuda.empty_cache()

    logging.info("Setting num_threads to 1...")
    torch.set_num_threads(1)

    # Configs
    d_model = 512
    num_heads = 4
    num_layers = 1
    block_size = 512
    dropout = 0.2
    bias = True
    rotational = True
    batch_size = 32
    eval_batch_size = 64
    epochs = 10
    lr = 0.001
    vocab_size = 50304  # GPT-2 vocab (50257) padded up to a multiple of 64, per nanoGPT

    logging.info(f"Vocab size: {vocab_size}")
    logging.info(f'''
        Config:
        d_model={d_model},
        num_heads={num_heads},
        num_layers={num_layers},
        block_size={block_size},
        dropout={dropout},
        bias={bias}
        ''')
    logging.info(
        f"Training for {epochs} epochs with a learning rate of {lr}..."
    )
    logging.info(f"Batch size: {batch_size}")
    logging.info(f"Eval batch size: {eval_batch_size}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cpu")
    logging.info(f"Device: {device}")

    logging.info("Loading tokenizer")
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
    # GPT-2 has no pad token by default; the added [PAD] id still fits within vocab_size
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    # Query the database for the plain-text data
    logging.info("Querying plain text data...")
    db_file_path = "data/experiment1.db"
    plain_text_train = DatabaseInterface(db_file_path).read("train")
    # logging.debug(f"Plain text train: {plain_text_train[:10]}")
    plain_text_val = DatabaseInterface(db_file_path).read("val")
    # logging.debug(f"Plain text val: {plain_text_val[:10]}")

    # Create train/val dataset objects
    train_dataset = PlainTextDataset(plain_text_train, tokenizer, device)
    valid_dataset = PlainTextDataset(plain_text_val, tokenizer, device)

    # DEBUG: check for zero-length samples
    # for idx, item in enumerate(train_dataset):
    #     input_ids = item["input_ids"]
    #     attention_mask = item["attention_mask"]
    #     if input_ids.size(0) == 0:
    #         print(f"Sample index with 0 length: {idx}")
    #         print(f"Input_ids: {input_ids}")
    #         print(f"Attention_mask: {attention_mask}")

    # Calculate the number of batches
    num_train_batches = len(train_dataset) // batch_size
    num_eval_batches = len(valid_dataset) // eval_batch_size
    logging.info(f"Number of train batches: {num_train_batches}")
    logging.info(f"Number of eval batches: {num_eval_batches}")

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=pad_collate
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        collate_fn=pad_collate
    )

    # gpt_config = GPTConfig(
    #     vocab_size=vocab_size,
    #     n_embd=d_model,
    #     n_head=num_heads,
    #     n_layer=num_layers,
    #     block_size=block_size,
    #     dropout=dropout,
    #     bias=bias
    # )
    rigpt_config = RotationallyInvariantGPTConfig(
        vocab_size=vocab_size,
        n_embd=d_model,
        n_head=num_heads,
        n_layer=num_layers,
        block_size=block_size,
        dropout=dropout,
        bias=bias,
        rotational_invariance=rotational
    )

    logging.info("Creating models...")
    # gpt = GPT(gpt_config).to(device)
    rigpt = RotationallyInvariantGPT(rigpt_config).to(device)

    logging.info("Creating optimizers...")
    # optimizer_gpt = optim.Adam(gpt.parameters(), lr=lr)
    optimizer_rigpt = optim.Adam(rigpt.parameters(), lr=lr)

    logging.info("Training...")
    for model, optimizer, model_name in [
        # (
        #     gpt,
        #     optimizer_gpt,
        #     'GPT'
        # ),
        (
            rigpt,
            optimizer_rigpt,
            'RotationallyInvariantGPT'
        )
    ]:
        print(f"Training {model_name}")
        for epoch in range(1, epochs + 1):
            print(f"Training epoch {epoch}")
            train_loss = train(model, optimizer, train_loader)
            print(f"Validating epoch {epoch}")
            valid_loss = evaluate(model, valid_loader)
            print(
                f'''
                {model_name} -
                Epoch: {epoch},
                Train loss: {train_loss:.3f},
                Validation loss: {valid_loss:.3f}
                '''
            )

    # torch.save(gpt.state_dict(), "gpt.pt")
    torch.save(rigpt.state_dict(), "rigpt.pt")
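
# Usage sketch (commented out, not executed by this script): one way to reload
# the saved checkpoint for later evaluation or inference. It assumes the same
# config values used above and relies only on standard PyTorch calls
# (torch.load / load_state_dict); any generation helper would depend on what
# RotationallyInvariantGPT exposes in model.py.
#
# config = RotationallyInvariantGPTConfig(
#     vocab_size=50304, n_embd=512, n_head=4, n_layer=1,
#     block_size=512, dropout=0.2, bias=True, rotational_invariance=True
# )
# reloaded = RotationallyInvariantGPT(config)
# reloaded.load_state_dict(torch.load("rigpt.pt", map_location="cpu"))
# reloaded.eval()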