rtferraz committed on
Commit
abab711
·
verified ·
1 Parent(s): 7edb04f

Add fine-tuning test suite — 15 tests covering dataset, batching, forward/backward, Trainer smoke, multiclass

Browse files
Files changed (1) hide show
  1. tests/test_finetune.py +166 -0
tests/test_finetune.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for domainTokenizer Phase 2D: Fine-tuning Pipeline.
3
+ 15 tests covering dataset, batching, forward/backward, Trainer smoke, multiclass.
4
+
5
+ Run: pytest tests/test_finetune.py -v
6
+ """
7
+
8
+ import logging
9
+ import random
10
+ from datetime import datetime, timedelta
11
+
12
+ import numpy as np
13
+ import torch
14
+ import pytest
15
+
16
+ from domain_tokenizer.schemas.predefined import FINANCE_SCHEMA
17
+ from domain_tokenizer.tokenizers.domain_tokenizer import DomainTokenizerBuilder
18
+ from domain_tokenizer.models.configuration import DomainTransformerConfig
19
+ from domain_tokenizer.models.modeling import DomainTransformerForCausalLM
20
+ from domain_tokenizer.models.joint_fusion import JointFusionModel
21
+ from domain_tokenizer.training.finetune_data import DomainFinetuneDataset, prepare_finetune_dataset
22
+ from domain_tokenizer.training.finetune import finetune_domain_model
23
+
24
+ logging.basicConfig(level=logging.INFO)
25
+
26
+
27
def make_events(n=10, seed=42):
    """Generate *n* synthetic transaction events for testing.

    Each event is a dict with a signed ``amount``/``amount_sign`` in
    [-5000, -5] or [5, 5000], a ``timestamp`` within one year of
    2025-01-01, and a merchant ``description``. Deterministic for a
    given (n, seed) pair.
    """
    rng = random.Random(seed)
    merchants = ["AMAZON", "UBER", "SALARY", "GROCERY", "NETFLIX", "GAS"]
    base = datetime(2025, 1, 1)
    events = []
    for _ in range(n):
        # NOTE: RNG call order (uniform, choice, randint, randint, choice)
        # matters for reproducibility of seeded fixtures.
        signed = rng.uniform(5, 5000) * rng.choice([1, -1])
        when = base + timedelta(days=rng.randint(0, 365), hours=rng.randint(0, 23))
        events.append(
            {
                "amount_sign": signed,
                "amount": signed,
                "timestamp": when,
                "description": rng.choice(merchants),
            }
        )
    return events
37
+
38
+
39
def make_labeled_data(n_users=20, n_tabular=10, seed=42):
    """Build a labeled toy dataset: per-user event sequences, a float32
    (n_users, n_tabular) feature matrix, and binary float labels.

    Deterministic for a given seed; sequence lengths vary in [3, 15].
    """
    rng = random.Random(seed)
    sequences = []
    # All length draws happen before the label draws — keep this order so
    # seeded output matches existing fixtures.
    for i in range(n_users):
        sequences.append(make_events(rng.randint(3, 15), seed + i))
    features = np.random.RandomState(seed).randn(n_users, n_tabular).astype(np.float32)
    targets = np.array([rng.choice([0.0, 1.0]) for _ in range(n_users)])
    return sequences, features, targets
45
+
46
+
47
def build_tok(seqs):
    """Fit a DomainTokenizerBuilder on all events and build a tokenizer.

    Returns ``(builder, tokenizer)``. The BPE text corpus is the unique
    merchant descriptions repeated 20x so the small vocab sees enough
    occurrences of each merchant string.
    """
    events = [e for s in seqs for e in s]
    builder = DomainTokenizerBuilder(FINANCE_SCHEMA)
    builder.fit(events)
    # Fix: sorted() instead of list(set(...)) — string-set iteration order
    # depends on hash randomization (PYTHONHASHSEED), so the corpus order,
    # and potentially BPE merges/token ids, could vary between test runs.
    corpus = sorted({e["description"] for e in events}) * 20
    return builder, builder.build(text_corpus=corpus, bpe_vocab_size=300)
52
+
53
+
54
def tiny_cfg(v=128):
    """Return a minimal transformer config (64-dim, 2 layers, 4 heads)
    sized for fast CPU-only tests; ``v`` is the vocabulary size."""
    return DomainTransformerConfig(
        vocab_size=v,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
    )
57
+
58
+
59
def make_fusion(v, nt=10, nc=1):
    """Build a tiny JointFusionModel for tests.

    ``v`` is the vocab size, ``nt`` the number of tabular features, and
    ``nc`` the number of output classes (1 = binary head).
    """
    backbone = DomainTransformerForCausalLM(tiny_cfg(v))
    return JointFusionModel(
        backbone,
        nt,
        nc,
        plr_frequencies=4,
        plr_embedding_dim=8,
        dcn_cross_layers=2,
        dcn_deep_layers=1,
        dcn_deep_dim=32,
        head_hidden_dim=32,
    )
65
+
66
+
67
class TestDataset:
    """Item structure, padding, masking and dtype checks for DomainFinetuneDataset."""

    @pytest.fixture
    def setup(self):
        # 10 users, 5 tabular features, max sequence length 64.
        sequences, features, targets = make_labeled_data(10, 5)
        builder, tokenizer = build_tok(sequences)
        dataset = DomainFinetuneDataset(sequences, features, targets, builder, tokenizer, 64)
        return dataset, tokenizer

    def test_len(self, setup):
        dataset, _ = setup
        assert len(dataset) == 10

    def test_keys(self, setup):
        dataset, _ = setup
        expected = {"input_ids", "attention_mask", "tabular_features", "labels"}
        assert set(dataset[0].keys()) == expected

    def test_shapes(self, setup):
        dataset, _ = setup
        item = dataset[0]
        assert item["input_ids"].shape == (64,)
        assert item["tabular_features"].shape == (5,)

    def test_padding(self, setup):
        dataset, tokenizer = setup
        # At least one real (non-pad) token must survive encoding.
        item = dataset[0]
        assert (item["input_ids"] != tokenizer.pad_token_id).any()

    def test_mask_matches_pad(self, setup):
        dataset, tokenizer = setup
        item = dataset[0]
        # attention_mask must be 0 exactly where input_ids are padding.
        pad_positions = item["input_ids"] == tokenizer.pad_token_id
        assert torch.equal(pad_positions, item["attention_mask"] == 0)

    def test_dtypes(self, setup):
        dataset, _ = setup
        item = dataset[0]
        assert item["labels"].dtype == torch.float32
        assert item["tabular_features"].dtype == torch.float32

    def test_mismatch(self):
        sequences, features, targets = make_labeled_data(10)
        builder, tokenizer = build_tok(sequences)
        # 5 sequences vs 10 feature/label rows must be rejected.
        with pytest.raises(AssertionError):
            DomainFinetuneDataset(sequences[:5], features, targets, builder, tokenizer)

    def test_stats(self, setup):
        dataset, _ = setup
        assert dataset.get_stats()["n_samples"] == 10
104
+
105
+
106
class TestBatching:
    """DataLoader collation of dataset items into stacked batches."""

    def test_loader(self):
        sequences, features, targets = make_labeled_data(8, 5)
        builder, tokenizer = build_tok(sequences)
        dataset = DomainFinetuneDataset(sequences, features, targets, builder, tokenizer, 32)
        loader = torch.utils.data.DataLoader(dataset, batch_size=4)
        batch = next(iter(loader))
        # Default collate should stack per-item tensors along a new batch dim.
        assert batch["input_ids"].shape == (4, 32)
        assert batch["tabular_features"].shape == (4, 5)
113
+
114
+
115
class TestForwardBackward:
    """Forward pass, gradient flow, and multiclass head of the fusion model."""

    @staticmethod
    def _build_batch(n_users, labels=None):
        """Build one collated batch of ``n_users`` samples (seq len 32).

        Returns ``(batch, tokenizer)``. ``labels`` overrides the generated
        binary labels when provided (e.g. for the multiclass test).
        """
        sequences, features, default_labels = make_labeled_data(n_users, 5)
        if labels is None:
            labels = default_labels
        builder, tokenizer = build_tok(sequences)
        dataset = DomainFinetuneDataset(sequences, features, labels, builder, tokenizer, 32)
        batch = next(iter(torch.utils.data.DataLoader(dataset, batch_size=4)))
        return batch, tokenizer

    def test_forward(self):
        batch, tokenizer = self._build_batch(8)
        out = make_fusion(tokenizer.vocab_size, 5)(**batch)
        # Loss on random data must be strictly positive.
        assert out["loss"].item() > 0

    def test_backward(self):
        batch, tokenizer = self._build_batch(4)
        model = make_fusion(tokenizer.vocab_size, 5)
        model(**batch)["loss"].backward()
        # Gradients must reach both the transformer and the tabular (PLR) branch.
        assert model.transformer.model.embed_tokens.weight.grad is not None
        assert model.plr.frequencies.grad is not None

    def test_multiclass(self):
        # Fix: the original drew labels from the unseeded module-level
        # `random`, making this test's input non-reproducible across runs.
        rng = random.Random(0)
        labels = np.array([rng.randint(0, 2) for _ in range(8)])
        batch, tokenizer = self._build_batch(8, labels=labels)
        batch["labels"] = batch["labels"].long()  # CE loss expects int class ids
        out = make_fusion(tokenizer.vocab_size, 5, 3)(**batch)
        assert out["logits"].shape == (4, 3)
        assert out["loss"] is not None
143
+
144
+
145
class TestTrainer:
    """End-to-end smoke test of finetune_domain_model on a tiny dataset."""

    def test_smoke(self, tmp_path):
        sequences, features, targets = make_labeled_data(20, 5)
        builder, tokenizer = build_tok(sequences)
        dataset = DomainFinetuneDataset(sequences, features, targets, builder, tokenizer, 32)
        model = make_fusion(tokenizer.vocab_size, 5)
        trainer = finetune_domain_model(
            model,
            dataset,
            output_dir=str(tmp_path),
            num_epochs=1,
            per_device_batch_size=4,
            learning_rate=1e-3,
            warmup_steps=0,
            logging_steps=1,
            save_strategy="no",
            report_to="none",
            seed=42,
        )
        # At least one optimizer step must have run and logged a loss value.
        assert trainer.state.global_step > 0
        logged_losses = [entry["loss"] for entry in trainer.state.log_history if "loss" in entry]
        assert len(logged_losses) > 0
159
+
160
+
161
class TestPrepare:
    """Smoke test for the prepare_finetune_dataset convenience wrapper."""

    def test_prepare(self):
        sequences, features, targets = make_labeled_data(10, 5)
        builder, tokenizer = build_tok(sequences)
        dataset = prepare_finetune_dataset(sequences, features, targets, builder, tokenizer, 32)
        assert len(dataset) == 10