# vil-encoder / train_vil_encoder_v2.py
# Author: Nine1Eight — "Initial Linux build for VIL encoder" (commit e566f33)
# NOTE(review): these header lines were pasted from a hosting page above the
# shebang; converted to comments so the file parses. Ideally move them below
# the shebang so `#!/usr/bin/env python3` is the first line.
#!/usr/bin/env python3
import json
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# Training data (JSONL, one triplet per line) and checkpoint output path.
TRAIN_PATH = Path("data/train.jsonl")
MODEL_OUT = Path("vil-encoder-v2.pt")
# Fixed encoded-input length (characters) and output embedding width.
SEQ_LEN = 64
EMBED_DIM = 32
# Optimisation hyperparameters.
BATCH_SIZE = 128
EPOCHS = 12
LR = 1e-3
WEIGHT_DECAY = 1e-5
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Seed both torch and numpy so runs are reproducible.
SEED = 918
torch.manual_seed(SEED)
np.random.seed(SEED)
def encode_triplet(visible: str, braille: str, hanzi: str,
                   seq_len: Optional[int] = None) -> np.ndarray:
    """Encode a (visible, braille, hanzi) triplet as a fixed-length vector.

    The three strings are joined with '|' separators, each character is
    mapped to its code point modulo 256, the sequence is truncated or
    zero-padded to ``seq_len`` entries, and scaled into [0, 1].

    Args:
        visible: visible-text form of the sample.
        braille: braille form of the sample.
        hanzi: hanzi form of the sample.
        seq_len: output vector length; defaults to the module-level SEQ_LEN
            (generalized from the previously hard-coded constant).

    Returns:
        np.ndarray of shape (seq_len,), dtype float32, values in [0, 1].
    """
    length = SEQ_LEN if seq_len is None else seq_len
    text = f"{visible}|{braille}|{hanzi}"
    # Truncate the text first: each character maps to exactly one entry.
    codes = [ord(c) % 256 for c in text[:length]]
    arr = np.zeros(length, dtype=np.float32)
    arr[: len(codes)] = codes
    return arr / 255.0
def load_rows(path: Path) -> List[dict]:
    """Read a JSONL file and return one dict per non-blank line.

    Blank (whitespace-only) lines are skipped. Raises RuntimeError when the
    file yields no rows at all, so a missing/empty dataset fails loudly.
    """
    with path.open("r", encoding="utf-8") as handle:
        parsed = [json.loads(s) for s in (raw.strip() for raw in handle) if s]
    if not parsed:
        raise RuntimeError(f"No rows loaded from {path}")
    return parsed
class PairDataset(Dataset):
    """Dataset of (anchor, positive) embedding-input pairs.

    Every row is encoded once up front into a float32 matrix. The positive
    for index i is simply the encoding at index i+1 (wrapping at the end),
    i.e. adjacent rows in the file are treated as positive pairs.
    """

    def __init__(self, rows: List[dict]) -> None:
        self.rows = rows
        encoded = [
            encode_triplet(row["visible"], row["braille"], row["hanzi"])
            for row in rows
        ]
        self.inputs = np.stack(encoded).astype(np.float32)

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        count = len(self.inputs)
        anchor = torch.from_numpy(self.inputs[idx])
        positive = torch.from_numpy(self.inputs[(idx + 1) % count])
        return anchor, positive
class Encoder(nn.Module):
    """Small MLP mapping an input vector to a unit-norm embedding.

    Architecture: input_dim -> 128 -> 64 -> embed_dim with ReLU activations;
    the forward pass L2-normalizes the final layer's output so embeddings
    lie on the unit hypersphere (suited to cosine-similarity losses).
    """

    def __init__(self, input_dim: int = SEQ_LEN, embed_dim: int = EMBED_DIM) -> None:
        super().__init__()
        layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, embed_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raw = self.net(x)
        return nn.functional.normalize(raw, dim=-1)
def cosine_pull_loss(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Loss that pulls paired embeddings together.

    Returns 1 minus the batch-mean cosine similarity: 0 when every pair is
    perfectly aligned, 2 when every pair points in opposite directions.
    """
    similarity = nn.functional.cosine_similarity(a, b)
    mean_sim = similarity.mean()
    return 1.0 - mean_sim
def main() -> None:
    """Train the VIL encoder and checkpoint whenever the epoch loss improves.

    Loads JSONL rows from TRAIN_PATH, trains for EPOCHS epochs with a
    cosine pull loss on (anchor, positive) pairs, prints per-epoch loss,
    and saves the model state + config + loss history to MODEL_OUT each
    time the mean epoch loss reaches a new best.
    """
    rows = load_rows(TRAIN_PATH)
    loader = DataLoader(
        PairDataset(rows),
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=False,
    )
    model = Encoder().to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    best_loss = float("inf")
    history: List[float] = []
    for epoch in range(EPOCHS):
        model.train()
        total = 0.0
        n_batches = 0
        for anchor, positive in loader:
            anchor = anchor.to(DEVICE)
            positive = positive.to(DEVICE)
            loss = cosine_pull_loss(model(anchor), model(positive))
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            total += float(loss.item())
            n_batches += 1

        # max(1, ...) guards against a zero-batch epoch dividing by zero.
        epoch_loss = total / max(1, n_batches)
        history.append(epoch_loss)
        print(f"epoch={epoch:02d} loss={epoch_loss:.6f}")

        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "config": {
                        "input_dim": SEQ_LEN,
                        "embed_dim": EMBED_DIM,
                    },
                    "history": history,
                },
                MODEL_OUT,
            )
            print(f"saved={MODEL_OUT} best_loss={best_loss:.6f}")
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
main()