Add fine-tuning test suite — 15 tests covering dataset, batching, forward/backward, Trainer smoke, multiclass
Browse files- tests/test_finetune.py +166 -0
tests/test_finetune.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for domainTokenizer Phase 2D: Fine-tuning Pipeline.
|
| 3 |
+
Covers the dataset, batching, forward/backward passes, a Trainer smoke run, and multiclass classification.
|
| 4 |
+
|
| 5 |
+
Run: pytest tests/test_finetune.py -v
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import random
|
| 10 |
+
from datetime import datetime, timedelta
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import pytest
|
| 15 |
+
|
| 16 |
+
from domain_tokenizer.schemas.predefined import FINANCE_SCHEMA
|
| 17 |
+
from domain_tokenizer.tokenizers.domain_tokenizer import DomainTokenizerBuilder
|
| 18 |
+
from domain_tokenizer.models.configuration import DomainTransformerConfig
|
| 19 |
+
from domain_tokenizer.models.modeling import DomainTransformerForCausalLM
|
| 20 |
+
from domain_tokenizer.models.joint_fusion import JointFusionModel
|
| 21 |
+
from domain_tokenizer.training.finetune_data import DomainFinetuneDataset, prepare_finetune_dataset
|
| 22 |
+
from domain_tokenizer.training.finetune import finetune_domain_model
|
| 23 |
+
|
| 24 |
+
logging.basicConfig(level=logging.INFO)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def make_events(n=10, seed=42):
    """Build *n* synthetic transaction events, deterministic for a given seed.

    Each event is a dict with:
      amount_sign  -- signed amount in [5, 5000] * (+1 or -1)
      amount       -- the same signed value as amount_sign
      timestamp    -- datetime within 2025 (random day/hour offset)
      description  -- one of six merchant names
    """
    rng = random.Random(seed)
    names = ["AMAZON", "UBER", "SALARY", "GROCERY", "NETFLIX", "GAS"]
    start = datetime(2025, 1, 1)
    events = []
    for _ in range(n):
        # Draw in a fixed order (uniform, choice, randint, randint, choice)
        # so a given seed always reproduces the exact same event stream.
        signed = rng.uniform(5, 5000) * rng.choice([1, -1])
        when = start + timedelta(days=rng.randint(0, 365), hours=rng.randint(0, 23))
        events.append({
            "amount_sign": signed,
            "amount": signed,
            "timestamp": when,
            "description": rng.choice(names),
        })
    return events
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def make_labeled_data(n_users=20, n_tabular=10, seed=42):
    """Return (event_sequences, tabular_features, binary_labels) for n_users.

    Sequences have 3-15 events each; features are float32 standard normals;
    labels are 0.0/1.0. Deterministic for a given seed.
    """
    rng = random.Random(seed)
    # The shared rng draws all sequence lengths first and all labels second;
    # keep that order so a given seed reproduces identical data.
    sequences = []
    for i in range(n_users):
        sequences.append(make_events(rng.randint(3, 15), seed + i))
    features = np.random.RandomState(seed).randn(n_users, n_tabular).astype(np.float32)
    targets = np.array([rng.choice([0.0, 1.0]) for _ in range(n_users)])
    return sequences, features, targets
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def build_tok(seqs):
    """Fit a DomainTokenizerBuilder on all events and build a small tokenizer.

    Returns (builder, tokenizer). The BPE text corpus is the unique merchant
    descriptions repeated 20x so merges have enough occurrences to form.
    """
    all_events = []
    for seq in seqs:
        all_events.extend(seq)
    builder = DomainTokenizerBuilder(FINANCE_SCHEMA)
    builder.fit(all_events)
    corpus = list({e["description"] for e in all_events}) * 20
    tokenizer = builder.build(text_corpus=corpus, bpe_vocab_size=300)
    return builder, tokenizer
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def tiny_cfg(v=128):
    """Return a minimal DomainTransformerConfig sized for fast CPU tests."""
    return DomainTransformerConfig(
        vocab_size=v,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
    )
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def make_fusion(v, nt=10, nc=1):
    """Wrap a tiny causal LM in a JointFusionModel.

    v: transformer vocab size; nt: number of tabular features;
    nc: number of output classes (1 = binary head).
    """
    backbone = DomainTransformerForCausalLM(tiny_cfg(v))
    return JointFusionModel(
        backbone,
        nt,
        nc,
        plr_frequencies=4,
        plr_embedding_dim=8,
        dcn_cross_layers=2,
        dcn_deep_layers=1,
        dcn_deep_dim=32,
        head_hidden_dim=32,
    )
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class TestDataset:
    """Unit tests for DomainFinetuneDataset item construction."""

    @pytest.fixture
    def setup(self):
        # 10 users, 5 tabular features, max sequence length 64.
        sequences, features, targets = make_labeled_data(10, 5)
        builder, tokenizer = build_tok(sequences)
        dataset = DomainFinetuneDataset(sequences, features, targets, builder, tokenizer, 64)
        return dataset, tokenizer

    def test_len(self, setup):
        dataset, _ = setup
        assert len(dataset) == 10

    def test_keys(self, setup):
        dataset, _ = setup
        expected = {"input_ids", "attention_mask", "tabular_features", "labels"}
        assert set(dataset[0].keys()) == expected

    def test_shapes(self, setup):
        dataset, _ = setup
        item = dataset[0]
        assert item["input_ids"].shape == (64,)
        assert item["tabular_features"].shape == (5,)

    def test_padding(self, setup):
        # At least one position must hold a real (non-pad) token.
        dataset, tokenizer = setup
        assert (dataset[0]["input_ids"] != tokenizer.pad_token_id).any()

    def test_mask_matches_pad(self, setup):
        # The attention mask must be 0 exactly where the input is padding.
        dataset, tokenizer = setup
        item = dataset[0]
        pad_positions = item["input_ids"] == tokenizer.pad_token_id
        assert torch.equal(pad_positions, item["attention_mask"] == 0)

    def test_dtypes(self, setup):
        dataset, _ = setup
        item = dataset[0]
        assert item["labels"].dtype == torch.float32
        assert item["tabular_features"].dtype == torch.float32

    def test_mismatch(self):
        # Mismatched sequence/feature/label lengths must be rejected.
        sequences, features, targets = make_labeled_data(10)
        builder, tokenizer = build_tok(sequences)
        with pytest.raises(AssertionError):
            DomainFinetuneDataset(sequences[:5], features, targets, builder, tokenizer)

    def test_stats(self, setup):
        dataset, _ = setup
        assert dataset.get_stats()["n_samples"] == 10
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class TestBatching:
    """DataLoader collation over the fine-tune dataset."""

    def test_loader(self):
        sequences, features, targets = make_labeled_data(8, 5)
        builder, tokenizer = build_tok(sequences)
        dataset = DomainFinetuneDataset(sequences, features, targets, builder, tokenizer, 32)
        loader = torch.utils.data.DataLoader(dataset, batch_size=4)
        batch = next(iter(loader))
        # Default collation must stack items into (batch, ...) tensors.
        assert batch["input_ids"].shape == (4, 32)
        assert batch["tabular_features"].shape == (4, 5)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class TestForwardBackward:
    """Forward pass, gradient flow, and multiclass support of the fusion model."""

    @staticmethod
    def _make_batch(n_users, n_tabular, max_len=32):
        """Build one DataLoader batch (batch_size=4) and return (batch, tokenizer)."""
        seqs, tab, labels = make_labeled_data(n_users, n_tabular)
        builder, tokenizer = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, builder, tokenizer, max_len)
        batch = next(iter(torch.utils.data.DataLoader(ds, batch_size=4)))
        return batch, tokenizer

    def test_forward(self):
        batch, tokenizer = self._make_batch(8, 5)
        out = make_fusion(tokenizer.vocab_size, 5)(**batch)
        # A classification loss on random data must be strictly positive.
        assert out["loss"].item() > 0

    def test_backward(self):
        batch, tokenizer = self._make_batch(4, 5)
        model = make_fusion(tokenizer.vocab_size, 5)
        model(**batch)["loss"].backward()
        # Gradients must reach both the transformer token embeddings and the
        # PLR numeric-embedding frequencies: the whole graph is connected.
        assert model.transformer.model.embed_tokens.weight.grad is not None
        assert model.plr.frequencies.grad is not None

    def test_multiclass(self):
        batch, tokenizer = self._make_batch(8, 5)
        # Fix: the original drew labels from the unseeded module-global
        # `random`, making this test nondeterministic across runs. Use a
        # locally seeded RNG instead.
        rng = random.Random(0)
        batch["labels"] = torch.tensor(
            [rng.randint(0, 2) for _ in range(4)], dtype=torch.long
        )
        out = make_fusion(tokenizer.vocab_size, 5, 3)(**batch)
        assert out["logits"].shape == (4, 3)
        assert out["loss"] is not None
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class TestTrainer:
    """End-to-end smoke test of finetune_domain_model on CPU."""

    def test_smoke(self, tmp_path):
        sequences, features, targets = make_labeled_data(20, 5)
        builder, tokenizer = build_tok(sequences)
        dataset = DomainFinetuneDataset(sequences, features, targets, builder, tokenizer, 32)
        model = make_fusion(tokenizer.vocab_size, 5)
        trainer = finetune_domain_model(
            model,
            dataset,
            output_dir=str(tmp_path),
            num_epochs=1,
            per_device_batch_size=4,
            learning_rate=1e-3,
            warmup_steps=0,
            logging_steps=1,
            save_strategy="no",
            report_to="none",
            seed=42,
        )
        # At least one optimizer step must have run, and with logging_steps=1
        # at least one loss entry must appear in the trainer's log history.
        assert trainer.state.global_step > 0
        logged_losses = [entry["loss"] for entry in trainer.state.log_history if "loss" in entry]
        assert len(logged_losses) > 0
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class TestPrepare:
    """prepare_finetune_dataset convenience wrapper."""

    def test_prepare(self):
        sequences, features, targets = make_labeled_data(10, 5)
        builder, tokenizer = build_tok(sequences)
        dataset = prepare_finetune_dataset(sequences, features, targets, builder, tokenizer, 32)
        assert len(dataset) == 10
|