File size: 14,041 Bytes
da89f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cf4c1f
f292cd1
da89f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f292cd1
6cf4c1f
f292cd1
 
da89f1c
 
 
 
 
 
 
cb428cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da89f1c
cb428cb
 
 
 
 
 
da89f1c
cb428cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f292cd1
cb428cb
 
 
 
 
 
 
 
 
 
 
 
f292cd1
cb428cb
 
 
 
 
 
 
 
 
 
f292cd1
cb428cb
 
 
f292cd1
cb428cb
 
 
 
f292cd1
cb428cb
 
da89f1c
cb428cb
 
 
 
 
80b78df
 
 
 
 
 
 
 
 
 
cb428cb
 
 
f292cd1
cb428cb
 
 
 
 
 
 
 
 
 
 
 
 
da89f1c
cb428cb
 
 
 
 
 
d26ac21
 
 
 
cb428cb
da89f1c
 
 
 
 
 
 
 
 
 
 
d26ac21
 
 
da89f1c
99575b1
da89f1c
 
 
 
 
bad72d7
 
da89f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb428cb
 
 
 
 
 
 
6cf4c1f
cb428cb
 
 
 
 
 
 
 
 
 
da89f1c
 
 
f292cd1
 
 
 
 
 
 
 
 
 
bad72d7
 
99575b1
f292cd1
bad72d7
f292cd1
 
 
6e2887d
 
da89f1c
80b78df
da89f1c
 
 
bad72d7
 
 
da89f1c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import numpy as np
import time
from tqdm import tqdm
import logging
import os

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class Trainer:
    """
    Trainer for document classification models with BERT-style inputs
    (batches carrying 'input_ids', 'attention_mask', 'token_type_ids',
    'label').

    Two modes are supported:
      * single-label (``num_categories == 1``): the model emits
        ``num_classes`` logits per document;
      * multi-category (``num_categories > 1``): the model emits
        ``num_categories * num_classes`` logits; each contiguous
        ``num_classes``-wide slice is one category, scored with its own
        cross-entropy term and averaged into a single loss.

    Techniques from the Hedwig implementation: AdamW with decoupled
    weight decay (bias/LayerNorm excluded), gradient accumulation,
    gradient clipping, and ReduceLROnPlateau driven by validation F1.
    """
    def __init__(
        self, 
        model, 
        train_loader, 
        val_loader, 
        test_loader=None, 
        lr=2e-5,
        weight_decay=0.01,
        warmup_proportion=0.1,
        gradient_accumulation_steps=1,
        max_grad_norm=1.0,
        num_classes=2,
        num_categories=1,
        device=None
    ):
        """
        Args:
            model: torch.nn.Module producing logits of shape
                (batch, num_categories * num_classes).
            train_loader / val_loader / test_loader: iterables yielding
                dict batches (see class docstring for expected keys).
            lr: AdamW learning rate.
            weight_decay: decoupled weight decay; bias and LayerNorm
                parameters are exempted.
            warmup_proportion: NOTE(review) accepted but currently
                unused — no warmup scheduler is created here.
            gradient_accumulation_steps: number of batches to accumulate
                gradients over before each optimizer step.
            max_grad_norm: gradient-clipping threshold.
            num_classes: number of classes per category.
            num_categories: number of label categories per document.
            device: explicit torch.device; defaults to CUDA if available.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")

        self.model.to(self.device)

        # NOTE(review): kept for interface compatibility, but this is
        # batches * accumulation_steps, which over-counts optimizer steps;
        # nothing in this class reads it.
        self.num_training_steps = len(train_loader) * gradient_accumulation_steps

        # AdamW with decoupled weight decay; biases and LayerNorm weights
        # are conventionally excluded from decay.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': weight_decay},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 
             'weight_decay': 0.0}
        ]

        self.optimizer = optim.AdamW(optimizer_grouped_parameters, lr=lr)

        # Halve the LR when validation F1 plateaus for `patience` epochs.
        # (The deprecated/removed `verbose` kwarg is intentionally omitted.)
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=0.5, patience=2)

        # Plain cross-entropy (no label smoothing is configured).
        self.criterion = nn.CrossEntropyLoss()

        # Training parameters
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm

        # For tracking the best checkpoint across epochs
        self.best_val_f1 = 0.0
        self.best_model_state = None

        self.num_classes = num_classes  # Number of classes for classification
        # Number of independent label categories per document (a document
        # may carry one label per category).
        self.num_categories = num_categories

    def _compute_loss(self, outputs, labels):
        """
        Cross-entropy loss for one batch.

        Single-category mode scores `outputs` directly against `labels`;
        multi-category mode slices `outputs` into `num_categories` groups
        of `num_classes` logits, scores each against its label column and
        returns the average over categories.
        """
        if self.num_categories <= 1:
            return self.criterion(outputs, labels)

        total_loss = 0
        for cat in range(self.num_categories):
            start_idx = cat * self.num_classes
            end_idx = (cat + 1) * self.num_classes
            category_outputs = outputs[:, start_idx:end_idx]  # (batch, num_classes)
            category_labels = labels[:, cat]                  # (batch,)
            # Diagnostic only: labels outside [0, num_classes - 1] would
            # make CrossEntropyLoss fail with an opaque error.
            if category_labels.max() >= self.num_classes or category_labels.min() < 0:
                print(f"ERROR: Category {cat} labels out of range [0, {self.num_classes - 1}]: min={category_labels.min()}, max={category_labels.max()}")

            total_loss += self.criterion(category_outputs, category_labels)

        return total_loss / self.num_categories  # Average loss over categories

    def _predict(self, outputs, threshold=None):
        """
        Convert logits to class predictions.

        Single-category mode returns the argmax over `num_classes`.
        Multi-category mode reshapes to (batch, num_categories,
        classes_per_group) and takes an argmax per category; when
        `threshold` is given, per-category softmax probabilities at or
        below it are zeroed first (so a batch where nothing clears the
        threshold falls back to class 0).
        """
        if self.num_categories <= 1:
            return outputs.argmax(dim=1)

        batch_size, total_classes = outputs.shape
        if total_classes % self.num_categories != 0:
            raise ValueError(f"Error: Number of total classes in the batch must be divisible by {self.num_categories}")

        classes_per_group = total_classes // self.num_categories
        # (batch, num_categories, classes_per_group)
        grouped = outputs.view(batch_size, self.num_categories, classes_per_group)

        if threshold is None:
            return grouped.argmax(dim=-1)

        # Softmax over the classes WITHIN each category (dim=-1); the
        # previous dim=1 normalized across categories, which is wrong.
        probs = torch.softmax(grouped, dim=-1)
        probs = torch.where(probs > threshold, probs, 0.0)
        return probs.argmax(dim=-1)

    def _clip_and_step(self):
        """Clip gradients, take an optimizer step, and reset gradients."""
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()
        self.optimizer.zero_grad()

    def train(self, epochs, save_path='best_model.pth'):
        """
        Run the full training loop.

        Trains for `epochs` epochs, evaluates on the validation set after
        each, checkpoints the best model by validation F1 to `save_path`,
        restores the best weights at the end, and (if a test loader was
        given) reports final test metrics.

        Args:
            epochs: number of epochs to train.
            save_path: file path for the best-model checkpoint.
        """
        logger.info(f"Starting training for {epochs} epochs")

        for epoch in range(epochs):
            start_time = time.time()

            # Training phase
            self.model.train()
            self.optimizer.zero_grad()
            train_loss = 0
            all_predictions = []
            all_labels = []
            pending_grads = False  # True while gradients await an optimizer step

            # Progress bar for training
            train_iterator = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
            # `step` is the batch index; it must NOT be reused inside the
            # loop body (the original inner category loop shadowed it and
            # broke gradient accumulation for num_categories > 1).
            for step, batch in enumerate(train_iterator):
                # Move batch to device
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                token_type_ids = batch['token_type_ids'].to(self.device)
                labels = batch['label'].to(self.device)

                # Forward pass
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids
                )

                loss = self._compute_loss(outputs, labels)

                # Scale loss if using gradient accumulation
                if self.gradient_accumulation_steps > 1:
                    loss = loss / self.gradient_accumulation_steps

                # Backward pass
                loss.backward()
                pending_grads = True

                # Update weights once enough gradients have accumulated
                if (step + 1) % self.gradient_accumulation_steps == 0:
                    self._clip_and_step()
                    pending_grads = False

                # Undo the accumulation scaling for reporting
                train_loss += loss.item() * self.gradient_accumulation_steps

                # Predictions for metrics (no thresholding during training)
                preds = self._predict(outputs)
                all_predictions.extend(preds.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())

                # Update progress bar with current loss
                train_iterator.set_postfix({'loss': f"{loss.item():.4f}"})

            # Flush a trailing partial accumulation window so leftover
            # gradients don't leak into the next epoch.
            if pending_grads:
                self._clip_and_step()

            # Calculate training metrics
            train_loss /= len(self.train_loader)
            if self.num_categories > 1:
                # Flatten per-sample per-category lists into flat arrays
                all_predictions = np.concatenate(all_predictions)
                all_labels = np.concatenate(all_labels)
            train_acc = accuracy_score(all_labels, all_predictions)
            train_f1 = f1_score(all_labels, all_predictions, average='macro')

            # Validation phase
            val_loss, val_acc, val_f1, val_precision, val_recall = self.evaluate(self.val_loader, "Validation")

            # Log validation metrics
            logger.info(f"Validation - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}, "
                        f"Precision: {val_precision:.4f}, Recall: {val_recall:.4f}")

            # Adjust learning rate based on validation performance
            self.scheduler.step(val_f1)

            # Save best model
            if val_f1 > self.best_val_f1:
                self.best_val_f1 = val_f1
                # state_dict() returns tensor REFERENCES; clone them so
                # later training steps don't overwrite this snapshot
                # (a plain dict .copy() kept stale aliases).
                self.best_model_state = {k: v.detach().clone() for k, v in self.model.state_dict().items()}
                torch.save(self.model.state_dict(), save_path)
                logger.info(f"New best model saved with validation F1: {val_f1:.4f}")

            # Print epoch summary
            epoch_time = time.time() - start_time
            logger.info(f"Epoch {epoch+1}/{epochs} - "
                    f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Train F1: {train_f1:.4f}, "
                    f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}, "
                    f"Time: {epoch_time:.2f}s") 
            print(f"Epoch {epoch+1}/{epochs} - ",
                    f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Train F1: {train_f1:.4f}, ",
                    f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}, ",
                    f"Time: {epoch_time:.2f}s")

        # Load best model for final evaluation
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
            logger.info(f"Loaded best model with validation F1: {self.best_val_f1:.4f}")

        # Test evaluation if test loader provided
        if self.test_loader:
            test_loss, test_acc, test_f1, test_precision, test_recall = self.evaluate(self.test_loader, "Test")
            logger.info(f"Final test results - "
                       f"Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, F1: {test_f1:.4f}, "
                       f"Precision: {test_precision:.4f}, Recall: {test_recall:.4f}")
            print(f"Final test results - ",
                       f"Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, F1: {test_f1:.4f}, ",
                       f"Precision: {test_precision:.4f}, Recall: {test_recall:.4f}")

    def evaluate(self, data_loader, phase="Validation", threshold=0.55):
        """
        Evaluate the model on `data_loader`.

        Args:
            data_loader: iterable of batches (same format as training).
            phase: label shown in the progress bar ("Validation"/"Test").
            threshold: multi-category only — softmax probabilities at or
                below this are zeroed before the per-category argmax.

        Returns:
            (loss, accuracy, weighted F1, weighted precision, weighted recall)
        """
        self.model.eval()
        eval_loss = 0
        # Accumulate in Python lists; np.append per batch was O(n^2).
        all_predictions = []
        all_labels = []

        # No gradient computation during evaluation
        with torch.no_grad():
            # Progress bar for evaluation
            iterator = tqdm(data_loader, desc=f"[{phase}]")
            for batch in iterator:
                # Move batch to device
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                token_type_ids = batch['token_type_ids'].to(self.device)
                labels = batch['label'].to(self.device)

                # Forward pass
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids
                )

                loss = self._compute_loss(outputs, labels)
                eval_loss += loss.item()

                # Thresholded predictions (threshold only affects
                # multi-category mode; see _predict).
                preds = self._predict(outputs, threshold=threshold)
                all_predictions.extend(preds.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())

        if self.num_categories > 1:
            # Flatten per-sample per-category lists into flat arrays
            all_predictions = np.concatenate(all_predictions)
            all_labels = np.concatenate(all_labels)

        # Calculate metrics
        eval_loss /= len(data_loader)
        accuracy = accuracy_score(all_labels, all_predictions)
        f1 = f1_score(all_labels, all_predictions, average='weighted')
        precision = precision_score(all_labels, all_predictions, average='weighted', zero_division=0)
        recall = recall_score(all_labels, all_predictions, average='weighted', zero_division=0)

        return eval_loss, accuracy, f1, precision, recall