import torch
import torch.nn as nn
import torch.optim as optim
import logging
import argparse
import json
from datetime import datetime
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split, RandomSampler, SequentialSampler
from prepare_data import SpectrogramDataset, collate_fn
from train_model import (
    AudioResNet,
    train_one_epoch,
    validate_one_epoch,
    evaluate_model,
    device,
)
import numpy as np
import os

# Configure logging to both a file and the console.
# (basicConfig is deliberately not used here: it would install its own stream
# handler on the root logger and every message would be printed twice.)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler('finish_training.log')
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
logger.addHandler(fh)
logger.addHandler(ch)

def parse_args():
    parser = argparse.ArgumentParser(description='Train Sample Classifier Model')
    parser.add_argument('--config', type=str, required=True, help='Path to the config file')
    return parser.parse_args()

def load_config(config_path):
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")
    with open(config_path, 'r') as f:
        config = json.load(f)
    return config
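
# Example config.json -- a minimal sketch, not a complete schema. Only
# 'directory', 'batch_size', and 'patience' are read directly in this script;
# SpectrogramDataset may require additional keys, and the values below are
# placeholders:
# {
#     "directory": "data/spectrograms",
#     "batch_size": 32,
#     "patience": 10
# }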

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, patience=10, max_epochs=50):
    best_loss = float('inf')
    patience_counter = 0
    for epoch in range(max_epochs):
        train_loss, train_accuracy = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_accuracy = validate_one_epoch(model, val_loader, criterion, device)
        log_message = (f'Epoch {epoch + 1}:\n'
                       f'Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, '
                       f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}\n')
        logging.info(log_message)
        # Reduce the learning rate once the validation loss stops improving.
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        logging.info(f'Current learning rate: {current_lr}')
        # Early stopping: track the best validation loss and save those weights.
        if val_loss < best_loss:
            best_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logging.info('Early stopping triggered')
                break
        # Periodic checkpoint for resuming long runs.
        if (epoch + 1) % 10 == 0:
            checkpoint_path = f'checkpoint_epoch_{epoch + 1}.pth'
            torch.save(model.state_dict(), checkpoint_path)
            logging.info(f'Model saved to {checkpoint_path}')
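
# Note: train_model saves the lowest-validation-loss weights to 'best_model.pth';
# the periodic 'checkpoint_epoch_N.pth' files exist only for resuming, and the
# epoch-50 checkpoint is never written at all if early stopping fires first.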

def main():
    try:
        args = parse_args()
        config = load_config(args.config)
        dataset = SpectrogramDataset(config, config['directory'], process_new=True)
        if len(dataset) == 0:
            raise ValueError("The dataset is empty. Please check the data loading process.")
        num_classes = len(dataset.label_to_index)
        class_names = list(dataset.label_to_index.keys())
        # 70/15/15 train/validation/test split.
        train_size = int(0.7 * len(dataset))
        val_size = int(0.15 * len(dataset))
        test_size = len(dataset) - train_size - val_size
        train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
        # Balance the training sampler: weight each sample by the inverse of
        # its class frequency so minority classes are oversampled.
        train_labels = [dataset.labels[i] for i in train_dataset.indices]
        class_counts = np.bincount(train_labels)
        class_weights = 1. / class_counts
        sample_weights = class_weights[train_labels]
        sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
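        # Worked example with assumed counts: if class_counts = [100, 25], then
        # class_weights = [0.01, 0.04], so each minority-class sample is drawn
        # four times as often and the expected class mix per epoch is ~50/50.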
        # Validation batches are shuffled; the test set keeps a fixed order so
        # per-sample predictions are reproducible.
        train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], collate_fn=collate_fn, sampler=sampler)
        val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], collate_fn=collate_fn, sampler=RandomSampler(val_dataset))
        test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], collate_fn=collate_fn, sampler=SequentialSampler(test_dataset))
        # Best hyperparameters found in a previous tuning run.
        best_params = {'learning_rate': 0.00014687223021475341, 'weight_decay': 2.970399818935859e-05, 'dropout_rate': 0.36508234143710705}
        model = AudioResNet(num_classes=num_classes, dropout_rate=best_params['dropout_rate']).to(device)
        criterion = nn.NLLLoss()  # expects log-probabilities from the model
        optimizer = optim.Adam(model.parameters(), lr=best_params['learning_rate'], weight_decay=best_params['weight_decay'])
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
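        # The scheduler's patience (3 epochs) is shorter than the early-stopping
        # patience, so the learning rate should drop at least once before early
        # stopping can trigger (assuming config['patience'] > 3, e.g. the
        # train_model default of 10).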
        # Resume from the final checkpoint of the previous run, if one exists.
        if os.path.exists('checkpoint_epoch_50.pth'):
            model.load_state_dict(torch.load('checkpoint_epoch_50.pth', map_location=device))
            logging.info("Loaded checkpoint from previous training.")
        train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, patience=config['patience'], max_epochs=50)
        # Evaluate the weights with the best validation loss, not the resume
        # checkpoint: 'best_model.pth' is what train_model actually saves, and
        # the epoch-50 checkpoint may not exist after early stopping.
        model.load_state_dict(torch.load('best_model.pth', map_location=device))
        evaluate_model(model, test_loader, device, class_names)
    except Exception as e:
        # logging.exception records the traceback as well as the message.
        logging.exception(f"An error occurred: {e}")


if __name__ == '__main__':
    main()
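
# Usage (assuming this file is saved as finish_training.py, matching its log
# file name, and that a config file exists at ./config.json):
#     python finish_training.py --config config.json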