Cludoy committed on
Commit
0a7ac4c
Β·
verified Β·
1 Parent(s): 330841b

Add test_suite.py

Browse files
Files changed (1) hide show
  1. test_suite.py +319 -0
test_suite.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unit Test Suite for TinyBert-CNN Intent Classifier Pipeline.
3
+ Tests: model init, dataset tokenization, forward pass, predict, compound splitter,
4
+ dataset generator output, and auto_trainer state I/O.
5
+ """
6
+
7
import json
import os
import shutil
import sys
import tempfile
import unittest

import pandas as pd
import torch

# Ensure the project directory is on sys.path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from TinyBert import IntentClassifier, IntentDataset, CompoundSentenceSplitter, TinyBertCNN
19
+
20
+
21
+ # ─────────────────────────────────────────────────────────────────────
22
+ # 1. MODEL INITIALIZATION
23
+ # ─────────────────────────────────────────────────────────────────────
24
+
25
class TestModelInit(unittest.TestCase):
    """Sanity-check a freshly constructed IntentClassifier and its sub-modules."""

    @classmethod
    def setUpClass(cls):
        # Build once and share across every test in this class.
        cls.classifier = IntentClassifier(num_classes=5)

    def test_model_instance(self):
        self.assertIsInstance(self.classifier.model, TinyBertCNN)

    def test_num_classes(self):
        self.assertEqual(self.classifier.num_classes, 5)

    def test_device_assigned(self):
        self.assertIsNotNone(self.classifier.device)

    def test_tokenizer_loaded(self):
        self.assertIsNotNone(self.classifier.tokenizer)

    def test_model_has_batchnorm(self):
        """The CNN head should carry one BatchNorm per filter size (3 in total)."""
        model = self.classifier.model
        self.assertTrue(hasattr(model, 'batchnorms'))
        self.assertEqual(len(model.batchnorms), 3)

    def test_model_has_hidden_fc(self):
        """Both the hidden FC layer and its BatchNorm must exist on the model."""
        model = self.classifier.model
        for attr in ('fc_hidden', 'bn_hidden'):
            self.assertTrue(hasattr(model, attr))
53
+
54
+
55
+ # ─────────────────────────────────────────────────────────────────────
56
+ # 2. INTENT DATASET
57
+ # ─────────────────────────────────────────────────────────────────────
58
+
59
class TestIntentDataset(unittest.TestCase):
    """Check tokenization output (length, keys, shapes, dtypes) of IntentDataset."""

    @classmethod
    def setUpClass(cls):
        cls.classifier = IntentClassifier(num_classes=5)
        # Two labeled rows in the (student_input, session_context, label) schema.
        cls.sample_data = [
            {
                'student_input': 'How do I use for loops?',
                'session_context': 'topic:For Loops | prev:If/Else | ability:If/Else:85% | emotion:engaged | pace:normal | slides:14,15,16',
                'label': 0,
            },
            {
                'student_input': "What's the weather?",
                'session_context': 'topic:Variables | prev:None | ability:N/A | emotion:bored | pace:slow | slides:5,6,7',
                'label': 1,
            },
        ]
        cls.dataset = IntentDataset(cls.sample_data, cls.classifier.tokenizer, max_length=128)

    def test_dataset_length(self):
        self.assertEqual(len(self.dataset), 2)

    def test_output_keys(self):
        sample = self.dataset[0]
        for required in ('input_ids', 'attention_mask', 'labels'):
            self.assertIn(required, sample)

    def test_tensor_shapes(self):
        sample = self.dataset[0]
        expected = torch.Size([128])
        self.assertEqual(sample['input_ids'].shape, expected)
        self.assertEqual(sample['attention_mask'].shape, expected)

    def test_label_type(self):
        self.assertEqual(self.dataset[0]['labels'].dtype, torch.long)

    def test_token_type_ids_present(self):
        """If TinyBERT emits token_type_ids for sentence pairs, verify their shape."""
        sample = self.dataset[0]
        if 'token_type_ids' in sample:
            self.assertEqual(sample['token_type_ids'].shape, torch.Size([128]))

    def test_handles_string_labels(self):
        # String class names should map to their integer ids.
        rows = [{'student_input': 'test', 'session_context': 'ctx', 'label': 'Pace-Related'}]
        ds = IntentDataset(rows, self.classifier.tokenizer)
        self.assertEqual(ds[0]['labels'].item(), 3)
104
+
105
+
106
+ # ─────────────────────────────────────────────────────────────────────
107
+ # 3. FORWARD PASS
108
+ # ─────────────────────────────────────────────────────────────────────
109
+
110
class TestForwardPass(unittest.TestCase):
    """Run TinyBertCNN.forward on synthetic batches and verify the logit shapes."""

    @classmethod
    def setUpClass(cls):
        cls.classifier = IntentClassifier(num_classes=5)

    def _random_batch(self, batch_size, seq_len=128):
        """Random token ids plus an all-ones attention mask on the model's device."""
        device = self.classifier.device
        ids = torch.randint(0, 1000, (batch_size, seq_len)).to(device)
        mask = torch.ones(batch_size, seq_len, dtype=torch.long).to(device)
        return ids, mask

    def _logits(self, *args, **kwargs):
        """Forward pass in eval mode with gradients disabled."""
        self.classifier.model.eval()
        with torch.no_grad():
            return self.classifier.model(*args, **kwargs)

    def test_output_shape(self):
        ids, mask = self._random_batch(4)
        self.assertEqual(self._logits(ids, mask).shape, torch.Size([4, 5]))

    def test_output_with_token_type_ids(self):
        ids, mask = self._random_batch(2)
        segments = torch.zeros(2, 128, dtype=torch.long).to(self.classifier.device)
        logits = self._logits(ids, mask, token_type_ids=segments)
        self.assertEqual(logits.shape, torch.Size([2, 5]))

    def test_single_sample(self):
        """A batch of one must not crash (relevant for BatchNorm layers)."""
        ids, mask = self._random_batch(1)
        self.assertEqual(self._logits(ids, mask).shape, torch.Size([1, 5]))
149
+
150
+
151
+ # ─────────────────────────────────────────────────────────────────────
152
+ # 4. PREDICT
153
+ # ─────────────────────────────────────────────────────────────────────
154
+
155
class TestPredict(unittest.TestCase):
    """Drive the predict() API with real text inputs."""

    @classmethod
    def setUpClass(cls):
        cls.classifier = IntentClassifier(num_classes=5)

    def test_predict_with_context(self):
        contexts = ["topic:For Loops | prev:None | ability:N/A | emotion:neutral | pace:normal | slides:10,11,12"]
        preds, probs = self.classifier.predict(["How do loops work?"], contexts)
        self.assertEqual(len(preds), 1)
        # One probability column per intent class.
        self.assertEqual(probs.shape[1], 5)

    def test_predict_without_context(self):
        preds, _ = self.classifier.predict(["I'm feeling frustrated"])
        self.assertEqual(len(preds), 1)

    def test_predict_empty_string(self):
        """An empty string must still yield exactly one prediction, not crash."""
        preds, _ = self.classifier.predict([""])
        self.assertEqual(len(preds), 1)

    def test_predict_multiple(self):
        texts = ["Hello", "Can you repeat?", "Speed up please"]
        preds, _ = self.classifier.predict(texts, ["ctx1", "ctx2", "ctx3"])
        self.assertEqual(len(preds), 3)
185
+
186
+
187
+ # ─────────────────────────────────────────────────────────────────────
188
+ # 5. COMPOUND SENTENCE SPLITTER
189
+ # ─────────────────────────────────────────────────────────────────────
190
+
191
class TestCompoundSplitter(unittest.TestCase):
    """Edge cases for CompoundSentenceSplitter.split_compound_question."""

    @classmethod
    def setUpClass(cls):
        cls.splitter = CompoundSentenceSplitter()

    def _split(self, text):
        """Shorthand for the method under test."""
        return self.splitter.split_compound_question(text)

    def test_compound_question_splits(self):
        # An "and"-joined question should break into at least two parts.
        parts = self._split("What is a variable and how do I use it?")
        self.assertGreaterEqual(len(parts), 2)

    def test_single_question_no_split(self):
        self.assertEqual(len(self._split("How do loops work?")), 1)

    def test_non_question_no_split(self):
        self.assertEqual(len(self._split("I like programming.")), 1)

    def test_multiple_question_marks(self):
        self.assertEqual(len(self._split("What is a loop? How does it work?")), 2)

    def test_empty_string(self):
        # Even empty input comes back as a single segment.
        self.assertEqual(len(self._split("")), 1)
219
+
220
+
221
+ # ─────────────────────────────────────────────────────────────────────
222
+ # 6. DATASET GENERATOR
223
+ # ─────────────────────────────────────────────────────────────────────
224
+
225
class TestDatasetGenerator(unittest.TestCase):
    """Test that the dataset generator produces correct output.

    build_dataset() is run inside a temporary directory so the generated
    CSVs never pollute the project tree; the directory is removed when the
    class is torn down.
    """

    @classmethod
    def setUpClass(cls):
        from dataset_generator import build_dataset
        cls.original_dir = os.getcwd()
        cls.tmp_dir = tempfile.mkdtemp()
        os.chdir(cls.tmp_dir)
        try:
            # Small per-class count keeps the test fast while still
            # producing every split and every class.
            build_dataset(num_samples_per_class=20)
            cls.train_df = pd.read_csv('data/train.csv')
            cls.val_df = pd.read_csv('data/val.csv')
            cls.test_df = pd.read_csv('data/test.csv')
        except Exception:
            # tearDownClass is NOT called when setUpClass raises, so restore
            # the working directory and drop the temp dir here before
            # propagating the error (previously the cwd leaked).
            os.chdir(cls.original_dir)
            shutil.rmtree(cls.tmp_dir, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        os.chdir(cls.original_dir)
        # Remove the generated CSVs; previously the temp dir was leaked.
        shutil.rmtree(cls.tmp_dir, ignore_errors=True)

    def test_columns_exist(self):
        for col in ['student_input', 'session_context', 'label', 'intent_name']:
            self.assertIn(col, self.train_df.columns)

    def test_three_splits_exist(self):
        self.assertGreater(len(self.train_df), 0)
        self.assertGreater(len(self.val_df), 0)
        self.assertGreater(len(self.test_df), 0)

    def test_all_classes_present(self):
        all_labels = set(self.train_df['label'].unique())
        self.assertEqual(all_labels, {0, 1, 2, 3, 4})

    def test_compact_context_format(self):
        ctx = self.train_df.iloc[0]['session_context']
        self.assertIn('topic:', ctx)
        self.assertIn('prev:', ctx)
        self.assertIn('emotion:', ctx)

    def test_no_empty_inputs(self):
        self.assertFalse(self.train_df['student_input'].isna().any())
        self.assertFalse(self.train_df['session_context'].isna().any())
266
+
267
+
268
+ # ─────────────────────────────────────────────────────────────────────
269
+ # 7. AUTO TRAINER STATE
270
+ # ─────────────────────────────────────────────────────────────────────
271
+
272
class TestAutoTrainerState(unittest.TestCase):
    """Round-trip and default-value tests for auto_trainer state persistence.

    Each test backs up any pre-existing state file and restores it (or
    removes the test artifact) afterwards, so the suite leaves no trace.
    """

    def test_state_round_trip(self):
        """save_state() followed by load_state() must return the same values."""
        from auto_trainer import load_state, save_state, STATE_FILE

        # Back up the current state file contents, if the file exists.
        original_content = None
        if os.path.exists(STATE_FILE):
            with open(STATE_FILE, 'r') as f:
                original_content = f.read()

        try:
            test_state = {"sessions_since_last_train": 42, "total_sessions": 100}
            save_state(test_state)
            loaded = load_state()
            self.assertEqual(loaded["sessions_since_last_train"], 42)
            self.assertEqual(loaded["total_sessions"], 100)
        finally:
            # Restore the backup, or remove the file this test created.
            if original_content is not None:
                with open(STATE_FILE, 'w') as f:
                    f.write(original_content)
            elif os.path.exists(STATE_FILE):
                os.remove(STATE_FILE)

    def test_default_state(self):
        """With no state file present, load_state() must return zeroed counters."""
        from auto_trainer import load_state, STATE_FILE

        backup = None
        if os.path.exists(STATE_FILE):
            with open(STATE_FILE, 'r') as f:
                backup = f.read()
            os.remove(STATE_FILE)

        try:
            state = load_state()
            self.assertEqual(state["sessions_since_last_train"], 0)
            self.assertEqual(state["total_sessions"], 0)
        finally:
            # BUG FIX: the previous `if backup:` skipped restoring an
            # existing-but-empty state file ("" is falsy). Also remove any
            # file load_state() may have created when there is nothing to
            # restore, mirroring the round-trip test's cleanup.
            if backup is not None:
                with open(STATE_FILE, 'w') as f:
                    f.write(backup)
            elif os.path.exists(STATE_FILE):
                os.remove(STATE_FILE)
316
+
317
+
318
if __name__ == '__main__':
    # verbosity=2 prints each test name and its result as the suite runs.
    unittest.main(verbosity=2)