"""
Tests for domainTokenizer Phase 2D: Fine-tuning Pipeline.
15 tests covering dataset, batching, forward/backward, Trainer smoke, multiclass.
Run: pytest tests/test_finetune.py -v
"""
import logging
import random
from datetime import datetime, timedelta

import numpy as np
import pytest
import torch

from domain_tokenizer.schemas.predefined import FINANCE_SCHEMA
from domain_tokenizer.tokenizers.domain_tokenizer import DomainTokenizerBuilder
from domain_tokenizer.models.configuration import DomainTransformerConfig
from domain_tokenizer.models.modeling import DomainTransformerForCausalLM
from domain_tokenizer.models.joint_fusion import JointFusionModel
from domain_tokenizer.training.finetune_data import DomainFinetuneDataset, prepare_finetune_dataset
from domain_tokenizer.training.finetune import finetune_domain_model

logging.basicConfig(level=logging.INFO)

def make_events(n=10, seed=42):
    """Generate n synthetic transaction events (deterministic per seed)."""
    rng = random.Random(seed)
    merchants = ["AMAZON", "UBER", "SALARY", "GROCERY", "NETFLIX", "GAS"]
    base = datetime(2025, 1, 1)
    events = []
    for _ in range(n):
        # Signed amount; the same value is stored under both schema fields.
        amount = rng.uniform(5, 5000) * rng.choice([1, -1])
        events.append({
            "amount_sign": amount,
            "amount": amount,
            "timestamp": base + timedelta(days=rng.randint(0, 365), hours=rng.randint(0, 23)),
            "description": rng.choice(merchants),
        })
    return events
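
# Illustrative shape of one generated event (values are examples only; as noted
# above, make_events stores the same signed value under both "amount_sign" and
# "amount", and FINANCE_SCHEMA is fitted on exactly these fields below):
#   {"amount_sign": -42.5, "amount": -42.5,
#    "timestamp": datetime(2025, 3, 14, 7, 0), "description": "UBER"}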

def make_labeled_data(n_users=20, n_tabular=10, seed=42):
    """Build per-user event sequences, tabular features, and binary labels."""
    rng = random.Random(seed)
    seqs = [make_events(rng.randint(3, 15), seed + i) for i in range(n_users)]
    tab = np.random.RandomState(seed).randn(n_users, n_tabular).astype(np.float32)
    labels = np.array([rng.choice([0.0, 1.0]) for _ in range(n_users)])
    return seqs, tab, labels

def build_tok(seqs):
    """Fit a DomainTokenizerBuilder on all events and build the tokenizer."""
    flat = [e for s in seqs for e in s]
    b = DomainTokenizerBuilder(FINANCE_SCHEMA)
    b.fit(flat)
    corpus = list(set(e["description"] for e in flat)) * 20
    return b, b.build(text_corpus=corpus, bpe_vocab_size=300)
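
# The description corpus is deduplicated and then repeated 20x, presumably so
# the tiny BPE trainer (bpe_vocab_size=300) sees enough token frequency to
# learn merges; a test-speed convenience, not guidance for real corpora.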

def tiny_cfg(v=128):
    return DomainTransformerConfig(vocab_size=v, hidden_size=64, num_hidden_layers=2,
                                   num_attention_heads=4, intermediate_size=128)

def make_fusion(v, nt=10, nc=1):
    return JointFusionModel(
        DomainTransformerForCausalLM(tiny_cfg(v)), nt, nc,
        plr_frequencies=4, plr_embedding_dim=8, dcn_cross_layers=2,
        dcn_deep_layers=1, dcn_deep_dim=32, head_hidden_dim=32,
    )
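
# The fusion hyperparameters above (PLR frequencies/embedding dim, DCN cross
# and deep layers, head width) are test-sized so every test runs in seconds
# on CPU; they are not meaningful defaults for real training.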

class TestDataset:
    @pytest.fixture
    def setup(self):
        seqs, tab, labels = make_labeled_data(10, 5)
        b, t = build_tok(seqs)
        return DomainFinetuneDataset(seqs, tab, labels, b, t, 64), t

    def test_len(self, setup):
        assert len(setup[0]) == 10

    def test_keys(self, setup):
        assert set(setup[0][0].keys()) == {"input_ids", "attention_mask", "tabular_features", "labels"}

    def test_shapes(self, setup):
        it = setup[0][0]
        assert it["input_ids"].shape == (64,) and it["tabular_features"].shape == (5,)

    def test_padding(self, setup):
        # The padded sequence must still contain at least one real token.
        it = setup[0][0]
        assert (it["input_ids"] != setup[1].pad_token_id).any()

    def test_mask_matches_pad(self, setup):
        # attention_mask is 0 exactly where input_ids is the pad token.
        it = setup[0][0]
        assert torch.equal(it["input_ids"] == setup[1].pad_token_id, it["attention_mask"] == 0)

    def test_dtypes(self, setup):
        it = setup[0][0]
        assert it["labels"].dtype == torch.float32 and it["tabular_features"].dtype == torch.float32

    def test_mismatch(self):
        # Mismatched sequence/tabular/label counts must be rejected.
        seqs, tab, labels = make_labeled_data(10)
        b, t = build_tok(seqs)
        with pytest.raises(AssertionError):
            DomainFinetuneDataset(seqs[:5], tab, labels, b, t)

    def test_stats(self, setup):
        assert setup[0].get_stats()["n_samples"] == 10

class TestBatching:
    def test_loader(self):
        seqs, tab, labels = make_labeled_data(8, 5)
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        batch = next(iter(torch.utils.data.DataLoader(ds, batch_size=4)))
        assert batch["input_ids"].shape == (4, 32) and batch["tabular_features"].shape == (4, 5)

class TestForwardBackward:
    def test_forward(self):
        seqs, tab, labels = make_labeled_data(8, 5)
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        batch = next(iter(torch.utils.data.DataLoader(ds, batch_size=4)))
        out = make_fusion(t.vocab_size, 5)(**batch)
        assert out["loss"].item() > 0

    def test_backward(self):
        seqs, tab, labels = make_labeled_data(4, 5)
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        batch = next(iter(torch.utils.data.DataLoader(ds, batch_size=4)))
        model = make_fusion(t.vocab_size, 5)
        model(**batch)["loss"].backward()
        # Gradients must reach both the transformer and the tabular branch.
        assert model.transformer.model.embed_tokens.weight.grad is not None
        assert model.plr.frequencies.grad is not None

    def test_multiclass(self):
        seqs, tab, _ = make_labeled_data(8, 5)
        # Seeded RNG keeps the 3-class labels deterministic across runs.
        rng = random.Random(0)
        labels = np.array([rng.randint(0, 2) for _ in range(8)])
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        batch = next(iter(torch.utils.data.DataLoader(ds, batch_size=4)))
        batch["labels"] = batch["labels"].long()
        out = make_fusion(t.vocab_size, 5, 3)(**batch)
        assert out["logits"].shape == (4, 3) and out["loss"] is not None

class TestTrainer:
    def test_smoke(self, tmp_path):
        seqs, tab, labels = make_labeled_data(20, 5)
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        trainer = finetune_domain_model(
            make_fusion(t.vocab_size, 5), ds,
            output_dir=str(tmp_path), num_epochs=1, per_device_batch_size=4,
            learning_rate=1e-3, warmup_steps=0, logging_steps=1,
            save_strategy="no", report_to="none", seed=42,
        )
        assert trainer.state.global_step > 0
        losses = [h["loss"] for h in trainer.state.log_history if "loss" in h]
        assert len(losses) > 0
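
# trainer.state.global_step and trainer.state.log_history are standard
# transformers.Trainer attributes (which finetune_domain_model appears to
# return); with logging_steps=1 every step should emit a {"loss": ...}
# entry, which is what the assertions above rely on.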

class TestPrepare:
    def test_prepare(self):
        seqs, tab, labels = make_labeled_data(10, 5)
        b, t = build_tok(seqs)
        ds = prepare_finetune_dataset(seqs, tab, labels, b, t, 32)
        assert len(ds) == 10
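

# Standalone smoke run, sketched as a convenience for manual debugging; it
# only reuses APIs exercised by the tests above, and
# `pytest tests/test_finetune.py -v` remains the canonical entry point.
if __name__ == "__main__":
    import tempfile

    seqs, tab, labels = make_labeled_data(20, 5)
    b, t = build_tok(seqs)
    ds = prepare_finetune_dataset(seqs, tab, labels, b, t, 32)
    with tempfile.TemporaryDirectory() as out:
        trainer = finetune_domain_model(
            make_fusion(t.vocab_size, 5), ds,
            output_dir=out, num_epochs=1, per_device_batch_size=4,
            learning_rate=1e-3, warmup_steps=0, logging_steps=1,
            save_strategy="no", report_to="none", seed=42,
        )
    print(f"finished {trainer.state.global_step} steps")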