# Microformer adapter-training script: trains long-term adapters on the
# corpus, then provides utilities for online (session) adapter updates.
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
import torch
import torch.nn as nn
import torch.optim as optim
import json
from models.model import Microformer
from config import *
# ------------------------
# LOAD DATA AND VOCAB
# ------------------------
# vocab.json carries both directions of the token mapping; JSON serializes
# the itos keys as strings, so convert them back to ints here.
with open("data/vocab.json", "r") as f:
    vocab = json.load(f)
stoi = vocab["stoi"]
itos = {int(k): v for k, v in vocab["itos"].items()}
VOCAB_SIZE = len(stoi)

# Pre-tokenized corpus; assumed to be a flat 1-D LongTensor of token ids
# (TODO confirm against the data-prep script that writes train.pt).
data = torch.load("data/train.pt")

SEQ_LEN = MAX_SEQ_LEN  # MAX_SEQ_LEN comes from the config star-import
BATCH_SIZE = 32

# Drop remainder for clean batch shape: keep only whole
# (SEQ_LEN * BATCH_SIZE)-sized chunks, then fold into BATCH_SIZE rows.
num_batches = len(data) // (SEQ_LEN * BATCH_SIZE)
trimmed_len = num_batches * SEQ_LEN * BATCH_SIZE
data = data[:trimmed_len]
data = data.view(BATCH_SIZE, -1)  # shape: (BATCH_SIZE, num_batches * SEQ_LEN)
def get_batch(start_idx, source=None, seq_len=None):
    """Slice one next-token-prediction batch from a (batch, total_len) tensor.

    Args:
        start_idx: column offset of the window within the corpus tensor.
        source: token tensor of shape (batch, total_len); defaults to the
            module-level ``data`` tensor (backward compatible).
        seq_len: window length; defaults to the module-level ``SEQ_LEN``.

    Returns:
        Tuple ``(x, y)`` where ``y`` is ``x`` shifted one column to the
        right (the next-token targets). The caller must ensure
        ``start_idx + seq_len + 1 <= source.shape[1]``; otherwise ``y``
        comes back one column short.
    """
    src = data if source is None else source
    n = SEQ_LEN if seq_len is None else seq_len
    x = src[:, start_idx:start_idx + n]
    y = src[:, start_idx + 1:start_idx + n + 1]
    return x, y
# ------------------------
# DEVICE SETUP
# ------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# ------------------------
# MODEL INSTANTIATION (with stacked adapters)
# ------------------------
# EMBED_DIM / NUM_HEADS / FF_DIM / NUM_LAYERS / MAX_SEQ_LEN / ADAPTER_DIM
# all come from the `from config import *` at the top of the file.
model = Microformer(
    VOCAB_SIZE,
    EMBED_DIM,
    NUM_HEADS,
    FF_DIM,
    NUM_LAYERS,
    MAX_SEQ_LEN,
    long_term_adapter_dim=ADAPTER_DIM,  # <-- set in config
    session_adapter_dim=ADAPTER_DIM  # <-- set in config
)
model.to(device)

# ------------------------
# TRAIN LONG-TERM ADAPTERS ONLY
# ------------------------
# Freeze the backbone so only the long-term adapters (and output head,
# per include_output=True) receive gradients during corpus training.
model.freeze_except_adapters(session_only=False, include_output=True)
# (Optionally, explicitly freeze session adapters:)
# Belt and braces — make certain session adapters stay frozen here.
for layer in model.layers:
    if getattr(layer, 'session_adapter', None) is not None:
        for param in layer.session_adapter.parameters():
            param.requires_grad = False

criterion = nn.CrossEntropyLoss()
# Optimize only the parameters left trainable by the freezing above.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
# ------------------------
# MAIN BATCH TRAINING LOOP (CORPUS)
# ------------------------
for epoch in range(6):
    # Non-overlapping windows; the stop bound leaves room for the +1 target
    # shift inside get_batch, so y is never short.
    for i in range(0, data.shape[1] - SEQ_LEN, SEQ_LEN):
        inputs, targets = get_batch(i)
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        out = model(inputs)  # assumed (BATCH_SIZE, SEQ_LEN, VOCAB_SIZE) logits — TODO confirm
        loss = criterion(out.reshape(-1, VOCAB_SIZE), targets.reshape(-1))
        loss.backward()
        optimizer.step()
    # NOTE(review): this reports only the final batch's loss of the epoch,
    # not an epoch average.
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

torch.save(model.state_dict(), "microformer.pt")
# ------------------------
# ONLINE (SESSION) LEARNING UTILITY
# ------------------------
def online_unsupervised_update(model, tokenizer, text, optimizer, loss_fn, device, max_len=64):
    """Run one next-token-prediction gradient step on a piece of session text.

    Only the parameters currently left trainable are updated: call
    ``model.freeze_except_adapters(session_only=True, ...)`` beforehand as
    needed (see usage notes at the bottom of this file).

    Args:
        model: language model mapping (1, T) token ids to (1, T, V) logits.
        tokenizer: object exposing ``.encode(text).ids`` and ``.token_to_id()``.
        text: raw session text to learn from.
        optimizer: optimizer built over the trainable (session) parameters.
        loss_fn: cross-entropy-style loss over (N, V) logits vs (N,) targets.
        device: torch device for the input tensors.
        max_len: maximum number of input positions used.

    Returns:
        The scalar loss value, or None when the text yields fewer than two
        tokens (no input/target pair to learn from).
    """
    ids = list(tokenizer.encode(text).ids)
    # token_to_id returns None when the special token is absent from the
    # vocab; appending None would crash torch.tensor below, so guard it.
    eos_id = tokenizer.token_to_id("<EOS>")
    if eos_id is not None:
        ids.append(eos_id)
    if len(ids) < 2:
        return None  # not enough tokens
    ids = ids[:max_len + 1]
    # Batch size is 1, so no padding is required. (The previous version
    # padded with <PAD> and computed the loss over the pad positions too,
    # which trains the model toward emitting <PAD>.)
    input_tensor = torch.tensor([ids[:-1]], dtype=torch.long, device=device)
    target_tensor = torch.tensor([ids[1:]], dtype=torch.long, device=device)

    model.train()
    logits = model(input_tensor)
    loss = loss_fn(logits.view(-1, logits.size(-1)), target_tensor.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.eval()
    return loss.item()
# ------------------------
# SESSION ADAPTER RESET FUNCTION (OPTIONAL)
# ------------------------
def reset_session_adapters(model):
    """Zero every session-adapter parameter, discarding session learning.

    Skips layers whose ``session_adapter`` is absent or None. (The previous
    ``param.data is not None`` guard was dead code: an nn.Parameter's
    ``.data`` is never None.)

    NOTE(review): with every adapter weight at exactly zero, the adapter's
    first (down-projection) layer gets zero gradient until later weights move
    off zero — confirm the adapter can still re-train after a reset, or
    re-randomize the down-projection instead of zeroing it.
    """
    with torch.no_grad():
        for layer in model.layers:
            adapter = getattr(layer, 'session_adapter', None)
            if adapter is not None:
                for param in adapter.parameters():
                    param.zero_()
# ------------------------
# USAGE FOR ONLINE LEARNING (after chat, NOT in main batch loop):
# ------------------------
# from tokenizers import Tokenizer
# tokenizer = Tokenizer.from_file("data/tokenizer.json")
# model.freeze_except_adapters(session_only=True, include_output=True)
# message = "Who is Buck?"
# loss = online_unsupervised_update(model, tokenizer, message, optimizer, criterion, device, max_len=SEQ_LEN)
# print(f"Online update loss: {loss}")