import os
import time

import pandas as pd
import torch
from torch.nn.utils import clip_grad_norm_


def train(model, train_loader, val_loader, criterion, optimizer, epochs, device, version, max_grad_norm=1.0, early_stopping_patience=5, early_stopping_delta=0.001):
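    """Train `model` with early stopping on validation accuracy.

    Runs up to `epochs` passes over `train_loader`, clipping gradients to
    `max_grad_norm`, then evaluates on `val_loader` after each epoch. The
    best checkpoint (by validation accuracy) is saved under
    ./output/version_{version}/, and training stops early once validation
    accuracy fails to improve by more than `early_stopping_delta` for
    `early_stopping_patience` consecutive epochs.

    Returns the last-epoch model, the best validation accuracy, and the
    epoch at which it was achieved.
    """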
    output_dir = f'./output/version_{version}'
    os.makedirs(output_dir, exist_ok=True)  # create the output directory up front so checkpoint/metric saves don't fail
    best_model_path = f'{output_dir}/best_model_{version}.pth'
    best_accuracy = 0.0
    best_epoch = 0
    early_stopping_counter = 0
    total_batches = len(train_loader)
    metrics = {
        'epoch': [], 'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': []
    }

    for epoch in range(epochs):
        model.train()
        total_loss, train_correct, train_total = 0, 0, 0
        for batch_idx, (titles, texts, labels) in enumerate(train_loader):
            start_time = time.time()  # Start time for the batch

            titles = titles.to(device)
            texts = texts.to(device)
            labels = labels.to(device).float()

            # Forward pass; squeeze only the trailing dim (the model is
            # assumed to output shape (batch, 1)) so a batch of size 1 is
            # not collapsed to a 0-d tensor by a bare .squeeze()
            outputs = model(titles, texts).squeeze(-1)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            if max_grad_norm:
                clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()

            total_loss += loss.item()
            # Thresholding at 0.5 assumes the model emits probabilities
            # (sigmoid applied inside the model); raw logits would need a
            # 0.0 threshold instead
            train_pred = (outputs > 0.5).float()
            train_correct += (train_pred == labels).sum().item()
            train_total += labels.size(0)

            # Calculate and print batch processing time
            batch_time = time.time() - start_time
            print(f'Epoch: {epoch+1}, Batch: {batch_idx+1}/{total_batches}, '
                  f'Batch Processing Time: {batch_time:.4f} seconds')

        train_accuracy = 100 * train_correct / train_total
        metrics['train_loss'].append(total_loss / len(train_loader))
        metrics['train_accuracy'].append(train_accuracy)

        # Validation
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for titles, texts, labels in val_loader:
                titles = titles.to(device)
                texts = texts.to(device)
                labels = labels.to(device).float()
                outputs = model(titles, texts).squeeze(-1)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                predicted = (outputs > 0.5).float()
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_accuracy = 100 * val_correct / val_total
        metrics['val_loss'].append(val_loss / len(val_loader))
        metrics['val_accuracy'].append(val_accuracy)
        metrics['epoch'].append(epoch + 1)

        # Early stopping: reset the counter and save a checkpoint whenever
        # validation accuracy improves by more than early_stopping_delta
        # (both values are on the 0-100 percent scale)
        if val_accuracy > best_accuracy + early_stopping_delta:
            best_accuracy = val_accuracy
            early_stopping_counter = 0
            best_epoch = epoch + 1
            torch.save(model.state_dict(), best_model_path)
        else:
            early_stopping_counter += 1

        if early_stopping_counter >= early_stopping_patience:
            print(f"Early stopping triggered at epoch {epoch + 1}")
            break

        print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader):.4f}, '
              f'Validation Accuracy: {val_accuracy:.2f}%')

    pd.DataFrame(metrics).to_csv(
        f'{output_dir}/training_metrics_{version}.csv', index=False)

    # Note: `model` carries the last-epoch weights; the best checkpoint (by
    # validation accuracy) is on disk at best_model_path
    return model, best_accuracy, best_epoch
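

# --- Minimal usage sketch (not part of the original module) ---
# Shows one way train() might be wired up end to end. DummyModel, the
# synthetic tensors, and every hyperparameter below are placeholder
# assumptions, not the project's real model or data; swap in the actual
# classes and DataLoaders.
if __name__ == '__main__':
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    class DummyModel(nn.Module):
        """Stand-in for the real title+text classifier."""

        def __init__(self, vocab_size=1000, embed_dim=32):
            super().__init__()
            self.embed = nn.EmbeddingBag(vocab_size, embed_dim)
            self.fc = nn.Linear(2 * embed_dim, 1)

        def forward(self, titles, texts):
            # Mean-pool token embeddings of each field, concatenate, and
            # emit one probability per sample (train() thresholds at 0.5)
            feats = torch.cat([self.embed(titles), self.embed(texts)], dim=1)
            return torch.sigmoid(self.fc(feats))

    # Tiny synthetic dataset of (title_ids, text_ids, label) triples
    titles = torch.randint(0, 1000, (64, 8))
    texts = torch.randint(0, 1000, (64, 32))
    labels = torch.randint(0, 2, (64,))
    loader = DataLoader(TensorDataset(titles, texts, labels), batch_size=16)

    model = DummyModel().to(device)
    model, best_acc, best_epoch = train(
        model, loader, loader, nn.BCELoss(),
        torch.optim.Adam(model.parameters(), lr=1e-3),
        epochs=2, device=device, version='demo')
    print(f'Best validation accuracy {best_acc:.2f}% at epoch {best_epoch}')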