import torch
import torch.nn as nn
import torch.optim as optim
import logging
import argparse
import json
import os

import numpy as np
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split, RandomSampler, SequentialSampler

from prepare_data import SpectrogramDataset, collate_fn
from train_model import (
    AudioResNet, train_one_epoch, validate_one_epoch, evaluate_model, device
)

# Log to both a file and the console. logging.basicConfig is deliberately not
# used here: it would install its own console handler and every message would
# then be printed twice alongside the two handlers below.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler('finish_training.log')
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
logger.addHandler(fh)
logger.addHandler(ch)


def parse_args():
    parser = argparse.ArgumentParser(description='Train Sample Classifier Model')
    parser.add_argument('--config', type=str, required=True, help='Path to the config file')
    return parser.parse_args()


def load_config(config_path):
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")
    with open(config_path, 'r') as f:
        return json.load(f)


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device,
                patience=10, max_epochs=50):
    """Train with early stopping, keeping the weights with the lowest validation loss."""
    best_loss = float('inf')
    patience_counter = 0
    for epoch in range(max_epochs):
        train_loss, train_accuracy = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_accuracy = validate_one_epoch(model, val_loader, criterion, device)
        logger.info(f'Epoch {epoch + 1}: '
                    f'Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, '
                    f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')

        # ReduceLROnPlateau lowers the learning rate when val_loss stops improving.
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        logger.info(f'Current learning rate: {current_lr}')

        # Early stopping: save the best weights; give up after `patience`
        # epochs without improvement.
        if val_loss < best_loss:
            best_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info('Early stopping triggered')
                break

        # Periodic checkpoint so a long run can be resumed.
        if (epoch + 1) % 10 == 0:
            checkpoint_path = f'checkpoint_epoch_{epoch + 1}.pth'
            torch.save(model.state_dict(), checkpoint_path)
            logger.info(f'Model saved to {checkpoint_path}')
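# For reference, a sketch of the contract the call sites above assume for the
# helpers imported from train_model: one epoch of optimisation (or of
# forward-only evaluation) returning average loss and accuracy. The actual
# implementations live in train_model.py and may differ in detail.
#
#   def train_one_epoch(model, loader, criterion, optimizer, device):
#       ...forward/backward pass over every batch in loader...
#       return avg_loss, accuracy
#
#   def validate_one_epoch(model, loader, criterion, device):
#       ...forward-only pass over loader, no gradient updates...
#       return avg_loss, accuracy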
def main():
    try:
        args = parse_args()
        config = load_config(args.config)

        dataset = SpectrogramDataset(config, config['directory'], process_new=True)
        if len(dataset) == 0:
            raise ValueError("The dataset is empty. Please check the data loading process.")

        num_classes = len(dataset.label_to_index)
        class_names = list(dataset.label_to_index.keys())

        # 70/15/15 train/validation/test split.
        train_size = int(0.7 * len(dataset))
        val_size = int(0.15 * len(dataset))
        test_size = len(dataset) - train_size - val_size
        train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

        # Weight each training sample by the inverse frequency of its class so
        # the sampler draws roughly class-balanced batches from an imbalanced dataset.
        train_labels = [dataset.labels[i] for i in train_dataset.indices]
        class_counts = np.bincount(train_labels)
        class_weights = 1. / class_counts
        sample_weights = class_weights[train_labels]
        sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

        train_loader = DataLoader(train_dataset, batch_size=config['batch_size'],
                                  collate_fn=collate_fn, sampler=sampler)
        val_loader = DataLoader(val_dataset, batch_size=config['batch_size'],
                                collate_fn=collate_fn, sampler=RandomSampler(val_dataset))
        test_loader = DataLoader(test_dataset, batch_size=config['batch_size'],
                                 collate_fn=collate_fn, sampler=SequentialSampler(test_dataset))

        # Best hyperparameters found by an earlier search.
        best_params = {
            'learning_rate': 0.00014687223021475341,
            'weight_decay': 2.970399818935859e-05,
            'dropout_rate': 0.36508234143710705,
        }

        model = AudioResNet(num_classes=num_classes, dropout_rate=best_params['dropout_rate']).to(device)
        criterion = nn.NLLLoss()
        optimizer = optim.Adam(model.parameters(), lr=best_params['learning_rate'],
                               weight_decay=best_params['weight_decay'])
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

        # Resume from the final checkpoint of the previous run, if one exists.
        if os.path.exists('checkpoint_epoch_50.pth'):
            model.load_state_dict(torch.load('checkpoint_epoch_50.pth', map_location=device))
            logger.info("Loaded checkpoint from previous training.")

        train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device,
                    patience=config['patience'], max_epochs=50)

        # Evaluate with the best weights saved by train_model (lowest validation
        # loss), not the fixed-epoch checkpoint, which may not exist if early
        # stopping fired first.
        model.load_state_dict(torch.load('best_model.pth', map_location=device))
        evaluate_model(model, test_loader, device, class_names)
    except Exception:
        logger.exception("An error occurred")


if __name__ == '__main__':
    main()
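# Example invocation and config, for illustration only (the script filename and
# the config values below are assumptions, not taken from the project). The
# keys read directly in this file are "directory", "batch_size", and
# "patience"; SpectrogramDataset receives the full config and may read more.
#
#   $ python finish_training.py --config config.json
#
#   config.json:
#   {
#       "directory": "data/spectrograms",
#       "batch_size": 32,
#       "patience": 10
#   }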