"""Offline training script for the Microformer; writes a checkpoint to microformer.pt."""
import json
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim

# Make the project root importable when this script is run directly.
sys.path.append(str(Path(__file__).resolve().parent.parent))

from models.model import Microformer
from config import *  # hyperparameters: MAX_SEQ_LEN, EMBED_DIM, NUM_HEADS, FF_DIM, NUM_LAYERS, ADAPTER_DIM

# Vocabulary built by the data-preparation step: string <-> id lookup tables.
with open("data/vocab.json", "r") as f:
    vocab = json.load(f)
stoi = vocab["stoi"]
itos = {int(k): v for k, v in vocab["itos"].items()}
VOCAB_SIZE = len(stoi)

# Pre-tokenized training data: a flat tensor of token ids.
data = torch.load("data/train.pt")
SEQ_LEN = MAX_SEQ_LEN
BATCH_SIZE = 32

# Trim the token stream so it divides evenly into BATCH_SIZE rows, then fold it
# into shape (BATCH_SIZE, tokens_per_row); batches are read as column slices.
num_batches = len(data) // (SEQ_LEN * BATCH_SIZE)
trimmed_len = num_batches * SEQ_LEN * BATCH_SIZE
data = data[:trimmed_len]
data = data.view(BATCH_SIZE, -1)

def get_batch(start_idx):
    """Slice a (BATCH_SIZE, SEQ_LEN) input block and its next-token targets."""
    x = data[:, start_idx:start_idx + SEQ_LEN]
    y = data[:, start_idx + 1:start_idx + SEQ_LEN + 1]  # targets are inputs shifted by one
    return x, y

device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
|
|
|
|
model = Microformer( |
|
VOCAB_SIZE, |
|
EMBED_DIM, |
|
NUM_HEADS, |
|
FF_DIM, |
|
NUM_LAYERS, |
|
MAX_SEQ_LEN, |
|
long_term_adapter_dim=ADAPTER_DIM, |
|
session_adapter_dim=ADAPTER_DIM |
|
) |
|
model.to(device) |
|
|
|
|
|
|
|
|
|
# Freeze the backbone: freeze_except_adapters leaves only the adapter parameters
# (plus whatever include_output=True keeps) with requires_grad=True.
model.freeze_except_adapters(session_only=False, include_output=True)

# Session adapters are reserved for online, per-session updates, so freeze them
# as well: this offline pass trains only the long-term adapters.
for layer in model.layers:
    if getattr(layer, "session_adapter", None) is not None:
        for param in layer.session_adapter.parameters():
            param.requires_grad = False

criterion = nn.CrossEntropyLoss()
# The optimizer only sees the parameters that are still trainable.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

for epoch in range(6):
    # Step through the token stream in non-overlapping SEQ_LEN windows.
    for i in range(0, data.shape[1] - SEQ_LEN, SEQ_LEN):
        inputs, targets = get_batch(i)
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        out = model(inputs)  # logits over the vocabulary for every position
        loss = criterion(out.reshape(-1, VOCAB_SIZE), targets.reshape(-1))
        loss.backward()
        optimizer.step()

    # Loss of the last batch in the epoch, not an epoch average.
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

torch.save(model.state_dict(), "microformer.pt")

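# Illustrative only (not executed here): a minimal sketch of how the saved
# checkpoint might be reloaded later. It assumes the same config values are
# available and rebuilds the model with the constructor arguments used above;
# the function name is hypothetical.
def _example_reload_checkpoint():
    reloaded = Microformer(
        VOCAB_SIZE,
        EMBED_DIM,
        NUM_HEADS,
        FF_DIM,
        NUM_LAYERS,
        MAX_SEQ_LEN,
        long_term_adapter_dim=ADAPTER_DIM,
        session_adapter_dim=ADAPTER_DIM,
    )
    reloaded.load_state_dict(torch.load("microformer.pt", map_location=device))
    reloaded.to(device)
    reloaded.eval()
    return reloaded
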
def online_unsupervised_update(model, tokenizer, text, optimizer, loss_fn, device, max_len=64):
    """Run one next-token-prediction gradient step on a single piece of text.

    Intended for online adaptation: only the parameters tracked by `optimizer`
    are updated. Returns the loss value, or None if the text is too short.
    """
    ids = tokenizer.encode(text).ids + [tokenizer.token_to_id("<EOS>")]
    if len(ids) < 2:
        return None  # nothing to learn from

    # Build a shifted input/target pair, truncated to max_len tokens.
    ids = ids[:max_len + 1]
    input_ids = ids[:-1]
    target_ids = ids[1:]
    # Pad both sequences to a fixed length. Note that padded positions also
    # contribute to the loss unless loss_fn was constructed with
    # ignore_index=tokenizer.token_to_id("<PAD>").
    input_ids += [tokenizer.token_to_id("<PAD>")] * (max_len - len(input_ids))
    target_ids += [tokenizer.token_to_id("<PAD>")] * (max_len - len(target_ids))
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
    target_tensor = torch.tensor([target_ids], dtype=torch.long, device=device)

    model.train()
    logits = model(input_tensor)
    logits = logits.view(-1, logits.size(-1))
    targets = target_tensor.view(-1)
    loss = loss_fn(logits, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.eval()
    return loss.item()

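# Illustrative only (not executed here): one way to invoke the online update.
# The function name, example text, and learning rate are assumptions; the
# optimizer is simply built over whichever parameters are currently unfrozen,
# and ignore_index keeps <PAD> positions out of the loss.
def _example_online_update(model, tokenizer, device):
    online_optimizer = optim.Adam(
        (p for p in model.parameters() if p.requires_grad), lr=1e-4
    )
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id("<PAD>"))
    return online_unsupervised_update(
        model, tokenizer, "example user input", online_optimizer, loss_fn, device
    )
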
def reset_session_adapters(model):
    """Zero every session-adapter parameter, clearing per-session state.

    Assuming the adapters are applied residually, an all-zero adapter
    contributes nothing, so the model falls back to its long-term behaviour.
    """
    for layer in model.layers:
        if getattr(layer, "session_adapter", None) is not None:
            for param in layer.session_adapter.parameters():
                nn.init.zeros_(param)

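# Illustrative only (not executed here): a sketch of how the two helpers above
# might be combined at a session boundary -- clear the session adapters, then
# adapt on each incoming text. The function name, its arguments, and this flow
# are assumptions for illustration, not part of the training script proper.
def _example_session_loop(model, tokenizer, optimizer, loss_fn, device, incoming_texts):
    reset_session_adapters(model)  # start the session from a clean slate
    for text in incoming_texts:
        online_unsupervised_update(model, tokenizer, text, optimizer, loss_fn, device)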