Cludoy committed on
Commit
5de5c2b
·
verified ·
1 Parent(s): 6c56621

Add TinyBert.py

Browse files
Files changed (1) hide show
  1. TinyBert.py +620 -0
TinyBert.py ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoModel, AutoTokenizer
4
+ from torch.utils.data import Dataset
5
+ import re
6
+
7
+
8
class IntentDataset(Dataset):
    """
    Torch dataset that turns (student input, session context) records into
    tokenized tensors for 5-class intent categorization.

    Each record is a dict that may carry 'student_input', 'session_context'
    and 'label'. A label may be a class index or one of the names found in
    ``label_map``; a missing label or an unknown name falls back to class 0.
    """

    def __init__(self, data, tokenizer, max_length=128):
        """
        Args:
            data (list): Records (dicts) as described on the class
            tokenizer: Callable tokenizer accepting a text pair and the
                padding / truncation / max_length / return_tensors keywords
            max_length (int): Fixed length every sample is padded/truncated to
        """
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Position in this tuple is the class index.
        intent_names = (
            'On-Topic Question',
            'Off-Topic Question',
            'Emotional-State',
            'Pace-Related',
            'Repeat/clarification',
        )
        self.label_map = {name: index for index, name in enumerate(intent_names)}

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Tokenize record ``idx`` and return its tensors.

        Returns:
            dict: 'input_ids', 'attention_mask', 'labels' and, when the
            tokenizer produces them, 'token_type_ids'. The leading batch
            dimension added by the tokenizer is squeezed away.
        """
        record = self.data[idx]

        # The pair is encoded as (student_input, session_context).
        # NOTE(review): 'longest_first' is understood to trim tokens from
        # whichever segment is currently longer, so a long student input can
        # be shortened too - confirm against the tokenizer documentation.
        encoding = self.tokenizer(
            str(record.get('student_input', '')),
            str(record.get('session_context', '')),
            padding='max_length',
            truncation='longest_first',
            max_length=self.max_length,
            return_tensors='pt',
        )

        # Accept either a numeric class index or a class name.
        label = record.get('label', 0)
        if isinstance(label, str):
            label = self.label_map.get(label, 0)

        sample = {
            key: encoding[key].squeeze(0)
            for key in ('input_ids', 'attention_mask')
        }
        sample['labels'] = torch.tensor(label, dtype=torch.long)
        if 'token_type_ids' in encoding:
            sample['token_type_ids'] = encoding['token_type_ids'].squeeze(0)

        return sample
56
+
57
+
58
class CompoundSentenceSplitter:
    """
    Rule-based splitter for compound sentences containing 2 separate questions.

    The word lists are English. The punctuation helpers additionally
    recognise the Arabic question mark (U+061F) and Arabic semicolon (U+061B).
    Four strategies are tried in order and the first one that yields two
    parts wins: transition phrases, conjunction + question word,
    semicolon / comma patterns, and multiple question marks.
    """

    # Any letter of the basic Arabic alphabet (U+0621..U+064A); used to pick
    # which question mark to append in _clean_questions.
    _ARABIC_LETTER_RE = re.compile('[\u0621-\u064a]')

    def __init__(self):
        # English question words (interrogatives and auxiliaries)
        self.question_words = [
            'what', 'when', 'where', 'which', 'who', 'whom', 'whose', 'why', 'how',
            'is', 'are', 'was', 'were', 'do', 'does', 'did', 'can', 'could',
            'will', 'would', 'should', 'may', 'might', 'must'
        ]

        # English conjunctions
        self.conjunctions = [
            'and', 'or', 'also', 'plus', 'additionally', 'moreover'
        ]

        # English transition phrases
        self.transition_phrases = [
            'and also', 'and what about', 'and how about', 'or what about',
            'or how about', 'also what', 'also how', 'also when', 'also where',
            'also who', 'also why', 'plus what', 'plus how'
        ]

    def split_compound_question(self, text):
        """
        Split a compound sentence into 2 separate questions if applicable.

        Args:
            text (str): Input text that may contain compound questions

        Returns:
            list: List of separated questions. Returns [text] (stripped) if
            the text is not a question or no strategy finds a split.
        """
        text = text.strip()

        # Only questions are candidates for splitting
        if not self._is_question(text):
            return [text]

        # Strategies in priority order; the first that produces >1 part wins
        strategies = (
            self._split_by_transition_phrases,
            self._split_by_conjunction_pattern,
            self._split_by_punctuation_pattern,
            self._split_by_question_marks,
        )
        for strategy in strategies:
            questions = strategy(text)
            if len(questions) > 1:
                return self._clean_questions(questions)

        # No split found, return original
        return [text]

    def _is_question(self, text):
        """
        Check if text is likely a question.

        True when the text contains '?' or the Arabic question mark, or when
        its first whitespace-separated word is one of ``question_words``.
        """
        if '?' in text or '\u061f' in text:
            return True

        words = text.split()
        return bool(words) and words[0].lower() in self.question_words

    def _split_by_transition_phrases(self, text):
        """
        Split at the first transition phrase found (phrases tried in list order).

        Only the leading conjunction words of the phrase are dropped; any
        remaining words ("what", "how about", ...) stay at the start of the
        second part so it remains a well-formed question.

        Returns:
            list: [part1, part2] on success, otherwise [text].
        """
        conjunctions = {c.lower() for c in self.conjunctions}

        for phrase in self.transition_phrases:
            words = phrase.split()
            lead_len = 0
            while lead_len < len(words) and words[lead_len].lower() in conjunctions:
                lead_len += 1

            lead = ' '.join(words[:lead_len])
            tail = ' '.join(words[lead_len:])

            if lead and tail:
                # Consume the conjunction, but only look ahead at the rest
                pattern = (
                    r'\s+' + re.escape(lead) + r'\s+(?=' + re.escape(tail) + r'\s)'
                )
            else:
                # Phrase is all conjunctions (or has none): consume it whole
                pattern = r'\s+' + re.escape(phrase) + r'\s+'

            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                part1 = text[:match.start()]
                part2 = text[match.end():]
                if part1 and part2:
                    return [part1, part2]

        return [text]

    def _split_by_conjunction_pattern(self, text):
        """
        Split where a conjunction is directly followed by a question word.

        The conjunction is dropped; the second part starts at the question
        word. The split is accepted only when the first part is itself a
        question.

        Returns:
            list: [part1, part2] on success, otherwise [text].
        """
        for conj in self.conjunctions:
            for qword in self.question_words:
                pattern = (
                    r'\s+' + re.escape(conj) + r'\s+(' + re.escape(qword) + r')\b'
                )
                match = re.search(pattern, text, re.IGNORECASE)
                if not match:
                    continue

                part1 = text[:match.start()].strip()
                # Group 1 is the question word, so the conjunction is skipped
                part2 = text[match.start(1):].strip()

                if part1 and part2 and self._is_question(part1):
                    return [part1, part2]

        return [text]

    def _split_by_punctuation_pattern(self, text):
        """
        Split at a semicolon, or at a comma that precedes a question word.

        Returns:
            list: [part1, part2] on success, otherwise [text].
        """
        # Semicolon (ASCII or Arabic): both halves must be questions
        if ';' in text or '\u061b' in text:
            parts = [p.strip() for p in re.split(r'[;\u061b]', text, maxsplit=1)]
            if len(parts) == 2 and all(self._is_question(p) for p in parts):
                return parts

        # Comma followed by a whole question word. The trailing \b keeps
        # "may" from matching "maybe", "do" from matching "dog", etc.
        alternatives = '|'.join(re.escape(qw) for qw in self.question_words)
        pattern = r',\s+(?=(?:' + alternatives + r')\b)'
        parts = re.split(pattern, text, maxsplit=1, flags=re.IGNORECASE)

        if len(parts) == 2:
            parts = [p.strip() for p in parts]
            # Only split if second part is clearly a question
            if self._is_question(parts[1]):
                return parts

        return [text]

    def _split_by_question_marks(self, text):
        """
        Split right after the first question mark when at least two exist.

        Both '?' and the Arabic question mark are counted.

        Returns:
            list: [part1, part2] on success, otherwise [text].
        """
        q_marks = text.count('?') + text.count('\u061f')

        if q_marks >= 2:
            match = re.search(r'[?\u061f]', text)
            if match:
                split_pos = match.end()
                part1 = text[:split_pos].strip()
                part2 = text[split_pos:].strip()

                if part2:  # Ensure second part is not empty
                    return [part1, part2]

        return [text]

    def _clean_questions(self, questions):
        """
        Strip the split parts, drop empty ones and add a missing question mark.

        A part that looks like a question but does not end in a question
        mark gets one appended: the Arabic mark if the part contains Arabic
        letters, '?' otherwise.

        Returns:
            list: The cleaned parts when more than one survives; otherwise a
            single-element list with the input parts joined by a space.
        """
        cleaned = []

        for q in questions:
            q = q.strip()
            if not q:
                continue

            if self._is_question(q) and not q.endswith(('?', '\u061f')):
                q += '\u061f' if self._ARABIC_LETTER_RE.search(q) else '?'

            cleaned.append(q)

        return cleaned if len(cleaned) > 1 else [' '.join(questions)]
251
+
252
+
253
class TinyBertCNN(nn.Module):
    """
    TinyBERT-CNN model for intent classification.

    Token representations from a pre-trained encoder are fed through
    parallel 1-D convolutions (one per filter size), each followed by
    BatchNorm, ReLU and max-over-time pooling. The pooled features are
    concatenated and passed through a hidden FC layer and an output layer.
    """

    def __init__(
        self,
        num_classes,
        bert_model_name='huawei-noah/TinyBERT_General_4L_312D',
        num_filters=256,
        filter_sizes=(2, 3, 4),
        dropout=0.5,
        hidden_dim=128,
        freeze_bert=False
    ):
        """
        Args:
            num_classes (int): Number of intent classes
            bert_model_name (str): Pre-trained TinyBERT model name
            num_filters (int): Number of filters for each filter size
            filter_sizes (sequence of int): Kernel sizes for the CNN branches
            dropout (float): Dropout rate
            hidden_dim (int): Hidden FC layer dimension
            freeze_bert (bool): Whether to freeze BERT parameters

        Raises:
            ValueError: If ``filter_sizes`` is empty.
        """
        super().__init__()

        # Copy so the (immutable) default and caller-owned sequences are
        # never shared with the module.
        filter_sizes = list(filter_sizes)
        if not filter_sizes:
            raise ValueError("filter_sizes must contain at least one kernel size")

        # Load TinyBERT model
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.bert_hidden_size = self.bert.config.hidden_size

        # Freeze BERT parameters if specified
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        # One convolution + BatchNorm pair per filter size
        self.convs = nn.ModuleList([
            nn.Conv1d(
                in_channels=self.bert_hidden_size,
                out_channels=num_filters,
                kernel_size=fs
            )
            for fs in filter_sizes
        ])
        self.batchnorms = nn.ModuleList([
            nn.BatchNorm1d(num_filters)
            for _ in filter_sizes
        ])

        # Dropout (shared by the pooled features and the hidden layer)
        self.dropout = nn.Dropout(dropout)

        # Hidden FC layer
        cnn_out_dim = len(filter_sizes) * num_filters
        self.fc_hidden = nn.Linear(cnn_out_dim, hidden_dim)
        self.bn_hidden = nn.BatchNorm1d(hidden_dim)

        # Output layer
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        Forward pass

        Args:
            input_ids: Token IDs (batch_size, seq_len)
            attention_mask: Attention mask (batch_size, seq_len)
            token_type_ids: Token type IDs (batch_size, seq_len), optional

        Returns:
            logits: Classification logits (batch_size, num_classes)
        """
        # token_type_ids is forwarded only when supplied
        bert_kwargs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }
        if token_type_ids is not None:
            bert_kwargs['token_type_ids'] = token_type_ids

        bert_output = self.bert(**bert_kwargs)

        # sequence_output: (batch_size, seq_len, hidden_size)
        sequence_output = bert_output.last_hidden_state

        # Conv1d expects channels first: (batch_size, hidden_size, seq_len)
        sequence_output = sequence_output.transpose(1, 2)

        # Right-pad with zeros if the sequence is shorter than the largest
        # kernel, otherwise that convolution would have no valid position.
        max_kernel = max(conv.kernel_size[0] for conv in self.convs)
        if sequence_output.size(2) < max_kernel:
            pad_size = max_kernel - sequence_output.size(2)
            sequence_output = torch.nn.functional.pad(sequence_output, (0, pad_size))

        # Convolution + batchnorm + ReLU + max-over-time pooling per branch.
        # NOTE(review): attention_mask is not applied here, so positions of
        # padding tokens take part in the pooling - confirm this is intended.
        conv_outputs = []
        for conv, bn in zip(self.convs, self.batchnorms):
            # conv_out: (batch_size, num_filters, seq_len - filter_size + 1)
            conv_out = torch.relu(bn(conv(sequence_output)))
            # pooled: (batch_size, num_filters)
            pooled = torch.max_pool1d(conv_out, conv_out.size(2)).squeeze(2)
            conv_outputs.append(pooled)

        # concatenated: (batch_size, len(filter_sizes) * num_filters)
        concatenated = torch.cat(conv_outputs, dim=1)
        concatenated = self.dropout(concatenated)

        # Hidden FC layer
        hidden = torch.relu(self.bn_hidden(self.fc_hidden(concatenated)))
        hidden = self.dropout(hidden)

        # Final classification
        logits = self.fc(hidden)

        return logits
373
+
374
+
375
class IntentClassifier:
    """
    Wrapper class for training and inference.

    Owns a TinyBertCNN model, the matching tokenizer and a
    CompoundSentenceSplitter, and exposes prediction, a single training
    step, evaluation and checkpoint save/load.
    """

    def __init__(
        self,
        num_classes,
        bert_model_name='huawei-noah/TinyBERT_General_4L_312D',
        num_filters=256,
        filter_sizes=(2, 3, 4),
        dropout=0.5,
        freeze_bert=False,
        device=None,
        hidden_dim=128
    ):
        """
        Args:
            num_classes (int): Number of intent classes
            bert_model_name (str): Pre-trained model / tokenizer name
            num_filters (int): Number of filters for each filter size
            filter_sizes (sequence of int): Kernel sizes for the CNN branches
            dropout (float): Dropout rate
            freeze_bert (bool): Whether to freeze BERT parameters
            device: Torch device; defaults to CUDA when available, else CPU
            hidden_dim (int): Hidden FC layer dimension of the model
        """
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize model
        self.model = TinyBertCNN(
            num_classes=num_classes,
            bert_model_name=bert_model_name,
            num_filters=num_filters,
            filter_sizes=list(filter_sizes),
            dropout=dropout,
            hidden_dim=hidden_dim,
            freeze_bert=freeze_bert
        ).to(self.device)

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)

        # Initialize compound sentence splitter
        self.sentence_splitter = CompoundSentenceSplitter()

        self.num_classes = num_classes

    def preprocess_text(self, text):
        """
        Preprocess text by splitting compound questions if detected

        Args:
            text (str): Input text

        Returns:
            list: List of individual questions
        """
        return self.sentence_splitter.split_compound_question(text)

    def predict(self, student_inputs, session_contexts=None, max_length=128, split_compound=False):
        """
        Predict intents for input texts

        Args:
            student_inputs (list): List of student input texts
            session_contexts (list): List of session context texts, optional
            max_length (int): Maximum sequence length
            split_compound (bool): Whether to split compound questions before prediction

        Returns:
            If split_compound=False:
                predictions: Predicted class indices (numpy array)
                probabilities: Prediction probabilities (numpy array)
            If split_compound=True:
                predictions: List of predictions (may contain multiple per text if split)
                probabilities: List of probabilities
                split_info: Dictionary with information about splits
        """
        # Handle compound questions if requested
        if split_compound:
            return self._predict_with_splitting(student_inputs, session_contexts, max_length)

        self.model.eval()

        # Tokenize either single texts or (input, context) pairs
        if session_contexts is not None:
            text_args = (student_inputs, session_contexts)
        else:
            text_args = (student_inputs,)

        encoded = self.tokenizer(
            *text_args,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        )

        input_ids = encoded['input_ids'].to(self.device)
        attention_mask = encoded['attention_mask'].to(self.device)
        token_type_ids = encoded.get('token_type_ids')
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids, attention_mask, token_type_ids=token_type_ids)
            probabilities = torch.softmax(logits, dim=1)
            predictions = torch.argmax(probabilities, dim=1)

        return predictions.cpu().numpy(), probabilities.cpu().numpy()

    def _predict_with_splitting(self, student_inputs, session_contexts=None, max_length=128):
        """
        Predict intents after splitting compound questions

        Args:
            student_inputs (list): List of input texts
            session_contexts (list): List of session context texts
            max_length (int): Maximum sequence length

        Returns:
            predictions: List with one array per original text (several
                entries in an array when the text was split)
            probabilities: List of probability arrays, same layout
            split_info: Dictionary with 'original_texts', 'split_texts',
                'was_split' and 'split_indices'
        """
        all_predictions = []
        all_probabilities = []
        split_info = {
            'original_texts': student_inputs,
            'split_texts': [],
            'was_split': [],
            'split_indices': []  # Maps split question index to original text index
        }

        # Collect all questions after splitting; each split question reuses
        # the session context of the text it came from.
        all_questions = []
        all_contexts = []
        for i, text in enumerate(student_inputs):
            questions = self.preprocess_text(text)
            split_info['split_texts'].append(questions)
            split_info['was_split'].append(len(questions) > 1)

            for _ in questions:
                split_info['split_indices'].append(i)
                if session_contexts is not None:
                    all_contexts.append(session_contexts[i])

            all_questions.extend(questions)

        # Nothing to classify
        if not all_questions:
            return all_predictions, all_probabilities, split_info

        # Predict for all questions at once
        contexts_to_pass = all_contexts if session_contexts is not None else None
        predictions, probabilities = self.predict(
            all_questions, contexts_to_pass, max_length, split_compound=False
        )

        # Regroup the flat results by original text
        idx = 0
        for questions in split_info['split_texts']:
            num_questions = len(questions)
            all_predictions.append(predictions[idx:idx + num_questions])
            all_probabilities.append(probabilities[idx:idx + num_questions])
            idx += num_questions

        return all_predictions, all_probabilities, split_info

    def _batch_to_device(self, batch):
        """Move a batch's tensors to self.device; token_type_ids may be None."""
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        labels = batch['labels'].to(self.device)
        token_type_ids = batch.get('token_type_ids')
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(self.device)
        return input_ids, attention_mask, token_type_ids, labels

    def train_step(self, batch, optimizer, criterion):
        """
        Single training step

        Args:
            batch: Dictionary with 'input_ids', 'attention_mask', 'labels'
                and optionally 'token_type_ids'
            optimizer: Optimizer
            criterion: Loss function

        Returns:
            loss: Training loss (float)
        """
        self.model.train()

        input_ids, attention_mask, token_type_ids, labels = self._batch_to_device(batch)

        # Forward pass
        logits = self.model(input_ids, attention_mask, token_type_ids=token_type_ids)
        loss = criterion(logits, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    def evaluate(self, dataloader, criterion):
        """
        Evaluate model on validation/test set

        Args:
            dataloader: Iterable of batches for evaluation
            criterion: Loss function

        Returns:
            avg_loss: Average loss, weighted by batch size
            accuracy: Classification accuracy
            Both are 0.0 when the dataloader yields no samples.
        """
        self.model.eval()

        total_loss = 0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for batch in dataloader:
                input_ids, attention_mask, token_type_ids, labels = self._batch_to_device(batch)

                # Forward pass
                logits = self.model(input_ids, attention_mask, token_type_ids=token_type_ids)
                loss = criterion(logits, labels)

                # Accumulate metrics; the batch loss is scaled back to a sum
                predictions = torch.argmax(logits, dim=1)
                total_loss += loss.item() * labels.size(0)
                total_correct += (predictions == labels).sum().item()
                total_samples += labels.size(0)

        # Avoid ZeroDivisionError on an empty dataloader
        if total_samples == 0:
            return 0.0, 0.0

        avg_loss = total_loss / total_samples
        accuracy = total_correct / total_samples

        return avg_loss, accuracy

    def save_model(self, path):
        """Save model checkpoint (state dict and num_classes) to ``path``."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'num_classes': self.num_classes
        }, path)
        print(f"Model saved to {path}")

    def load_model(self, path):
        """Load the model state dict from a checkpoint written by save_model."""
        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from trusted sources (consider weights_only=True where supported).
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Model loaded from {path}")
620
+