import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))

import torch
import torch.nn as nn
import torch.optim as optim
import json
from models.model import Microformer
from config import *

# ------------------------
# LOAD DATA AND VOCAB
# ------------------------
with open("data/vocab.json", "r") as f:
    vocab = json.load(f)
    stoi = vocab["stoi"]
    itos = {int(k): v for k, v in vocab["itos"].items()}
VOCAB_SIZE = len(stoi)

data = torch.load("data/train.pt")
SEQ_LEN = MAX_SEQ_LEN
BATCH_SIZE = 32

# Drop remainder for clean batch shape
num_batches = len(data) // (SEQ_LEN * BATCH_SIZE)
trimmed_len = num_batches * SEQ_LEN * BATCH_SIZE
data = data[:trimmed_len]
data = data.view(BATCH_SIZE, -1)  # shape: (BATCH_SIZE, num_batches * SEQ_LEN)

def get_batch(start_idx):
    x = data[:, start_idx:start_idx+SEQ_LEN]
    y = data[:, start_idx+1:start_idx+SEQ_LEN+1]
    return x, y
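
# Example (shapes only):
#   x, y = get_batch(0)   # x, y: (BATCH_SIZE, SEQ_LEN); y is x shifted right by one token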

# ------------------------
# DEVICE SETUP
# ------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# ------------------------
# MODEL INSTANTIATION (with stacked adapters)
# ------------------------
model = Microformer(
    VOCAB_SIZE,
    EMBED_DIM,
    NUM_HEADS,
    FF_DIM,
    NUM_LAYERS,
    MAX_SEQ_LEN,
    long_term_adapter_dim=ADAPTER_DIM,     # <-- set in config
    session_adapter_dim=ADAPTER_DIM        # <-- set in config
)
model.to(device)

# ------------------------
# TRAIN LONG-TERM ADAPTERS ONLY
# ------------------------
model.freeze_except_adapters(session_only=False, include_output=True)
# (Optionally, also freeze the session adapters explicitly so only the long-term adapters are updated:)
for layer in model.layers:
    if getattr(layer, 'session_adapter', None) is not None:
        for param in layer.session_adapter.parameters():
            param.requires_grad = False

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
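
# Sanity check: after freezing, only the long-term adapters (and the output layer, if
# include_output=True) should remain trainable, so this count should be small.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} / {total_params:,}")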

# ------------------------
# MAIN BATCH TRAINING LOOP (CORPUS)
# ------------------------
EPOCHS = 6
model.train()
for epoch in range(EPOCHS):
    epoch_loss = 0.0
    num_steps = 0
    for i in range(0, data.shape[1] - SEQ_LEN, SEQ_LEN):
        inputs, targets = get_batch(i)
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        out = model(inputs)
        loss = criterion(out.reshape(-1, VOCAB_SIZE), targets.reshape(-1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_steps += 1

    print(f"Epoch {epoch}, Loss: {epoch_loss / num_steps:.4f}")

torch.save(model.state_dict(), "microformer.pt")
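
# To reload the checkpoint later (a minimal sketch; assumes the model is rebuilt with the
# same config values used above):
# model = Microformer(VOCAB_SIZE, EMBED_DIM, NUM_HEADS, FF_DIM, NUM_LAYERS, MAX_SEQ_LEN,
#                     long_term_adapter_dim=ADAPTER_DIM, session_adapter_dim=ADAPTER_DIM)
# model.load_state_dict(torch.load("microformer.pt", map_location=device))
# model.eval()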

# ------------------------
# ONLINE (SESSION) LEARNING UTILITY
# ------------------------
def online_unsupervised_update(model, tokenizer, text, optimizer, loss_fn, device, max_len=64):
    # Only update session adapters/output layer; call freeze_except_adapters before this as needed.
    # Note: padded positions count toward the loss unless loss_fn is built with
    # ignore_index=tokenizer.token_to_id("<PAD>").
    ids = tokenizer.encode(text).ids + [tokenizer.token_to_id("<EOS>")]
    if len(ids) < 2:
        return None  # not enough tokens

    pad_id = tokenizer.token_to_id("<PAD>")
    ids = ids[:max_len + 1]
    input_ids = ids[:-1]
    target_ids = ids[1:]
    input_ids += [pad_id] * (max_len - len(input_ids))
    target_ids += [pad_id] * (max_len - len(target_ids))
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
    target_tensor = torch.tensor([target_ids], dtype=torch.long, device=device)

    model.train()
    logits = model(input_tensor)
    logits = logits.view(-1, logits.size(-1))
    targets = target_tensor.view(-1)
    loss = loss_fn(logits, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.eval()
    return loss.item()

# ------------------------
# SESSION ADAPTER RESET FUNCTION (OPTIONAL)
# ------------------------
def reset_session_adapters(model):
    # Zero all session-adapter weights, discarding whatever the current session has learned.
    for layer in model.layers:
        if getattr(layer, 'session_adapter', None) is not None:
            for param in layer.session_adapter.parameters():
                nn.init.zeros_(param)

# ------------------------
# USAGE FOR ONLINE LEARNING (after chat, NOT in main batch loop):
# ------------------------
# from tokenizers import Tokenizer
# tokenizer = Tokenizer.from_file("data/tokenizer.json")
# model.freeze_except_adapters(session_only=True, include_output=True)
# # Rebuild the optimizer after re-freezing: the Adam instance above only tracks the
# # long-term adapter parameters, not the session adapters that are trainable now.
# session_optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
# message = "Who is Buck?"
# loss = online_unsupervised_update(model, tokenizer, message, session_optimizer, criterion, device, max_len=SEQ_LEN)
# print(f"Online update loss: {loss}")