import os
import time

import pandas as pd
import torch
from torch.nn.utils import clip_grad_norm_


def train(model, train_loader, val_loader, criterion, optimizer, epochs, device,
          version, max_grad_norm=1.0, early_stopping_patience=5,
          early_stopping_delta=0.001):
    best_accuracy = 0.0
    output_dir = f'./output/version_{version}'
    os.makedirs(output_dir, exist_ok=True)  # make sure the save directory exists
    best_model_path = f'{output_dir}/best_model_{version}.pth'
    best_epoch = 0
    early_stopping_counter = 0
    total_batches = len(train_loader)
    metrics = {
        'epoch': [], 'train_loss': [], 'val_loss': [],
        'train_accuracy': [], 'val_accuracy': []
    }

    for epoch in range(epochs):
        model.train()
        total_loss, train_correct, train_total = 0, 0, 0
        for batch_idx, (titles, texts, labels) in enumerate(train_loader):
            start_time = time.time()

            titles = titles.to(device)
            texts = texts.to(device)
            labels = labels.to(device).float()

            # squeeze(-1) drops only the trailing output dimension, so a final
            # batch of size 1 is not collapsed to a 0-d tensor the way a bare
            # squeeze() would collapse it
            outputs = model(titles, texts).squeeze(-1)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            if max_grad_norm:
                clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()

            total_loss += loss.item()
            # the 0.5 threshold assumes the model outputs probabilities
            # (e.g. a sigmoid head)
            train_pred = (outputs > 0.5).float()
            train_correct += (train_pred == labels).sum().item()
            train_total += labels.size(0)

            batch_time = time.time() - start_time
            print(f'Epoch: {epoch+1}, Batch: {batch_idx+1}/{total_batches}, '
                  f'Batch Processing Time: {batch_time:.4f} seconds')

        train_accuracy = 100 * train_correct / train_total
        metrics['train_loss'].append(total_loss / len(train_loader))
        metrics['train_accuracy'].append(train_accuracy)

        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for titles, texts, labels in val_loader:
                titles = titles.to(device)
                texts = texts.to(device)
                labels = labels.to(device).float()
                outputs = model(titles, texts).squeeze(-1)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                predicted = (outputs > 0.5).float()
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_accuracy = 100 * val_correct / val_total
        metrics['val_loss'].append(val_loss / len(val_loader))
        metrics['val_accuracy'].append(val_accuracy)
        metrics['epoch'].append(epoch + 1)

        # an epoch counts as an improvement only when val accuracy beats the
        # best score by more than early_stopping_delta; improvements reset the
        # patience counter and checkpoint the model
        if val_accuracy > best_accuracy + early_stopping_delta:
            best_accuracy = val_accuracy
            early_stopping_counter = 0
            best_epoch = epoch + 1
            torch.save(model.state_dict(), best_model_path)
        else:
            early_stopping_counter += 1

        # print the summary before the early-stop check so the final epoch's
        # numbers are still reported when training stops
        print(f'Epoch [{epoch+1}/{epochs}], '
              f'Loss: {total_loss/len(train_loader):.4f}, '
              f'Validation Accuracy: {val_accuracy:.2f}%')

        if early_stopping_counter >= early_stopping_patience:
            print(f"Early stopping triggered at epoch {epoch + 1}")
            break

|
|
pd.DataFrame(metrics).to_csv( |
|
f'./output/version_{version}/training_metrics_{version}.csv', index=False) |
|
|
|
return model, best_accuracy, best_epoch |
|
|
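

# A minimal usage sketch, not part of the training utility itself. The model,
# data, and shapes below are hypothetical stand-ins: it assumes batches of
# (titles, texts, labels) index tensors and a model whose forward(titles, texts)
# returns per-sample probabilities (sigmoid output), matching the 0.5 threshold
# used in train(). Swap in the real dataset and model from this repo.
if __name__ == '__main__':
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    # toy stand-in data: 64 samples of pre-encoded titles (len 16) and texts (len 128)
    titles = torch.randint(0, 1000, (64, 16))
    texts = torch.randint(0, 1000, (64, 128))
    labels = torch.randint(0, 2, (64,))
    loader = DataLoader(TensorDataset(titles, texts, labels), batch_size=8)

    class ToyModel(nn.Module):
        """Hypothetical stand-in that embeds both inputs and emits a probability."""

        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(1000, 32)
            self.fc = nn.Linear(64, 1)

        def forward(self, titles, texts):
            t = self.embed(titles).mean(dim=1)   # (batch, 32)
            x = self.embed(texts).mean(dim=1)    # (batch, 32)
            return torch.sigmoid(self.fc(torch.cat([t, x], dim=1)))  # (batch, 1)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ToyModel().to(device)
    train(model, loader, loader, nn.BCELoss(),
          torch.optim.Adam(model.parameters()),
          epochs=2, device=device, version='demo')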