"""Train an audio-spectrogram classifier (ResNet-style CNN) tuned with Optuna.

``main`` loads a JSON config, builds a ``SpectrogramDataset``, splits it
70/15/15 into train/validation/test, searches learning rate / weight decay /
dropout with Optuna, retrains with the best parameters, and reports test-set
metrics plus a confusion matrix.
"""

import argparse
import json
import logging
import os
import sys
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import optuna
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data import (
    DataLoader,
    RandomSampler,
    SequentialSampler,
    WeightedRandomSampler,
    random_split,
)

from prepare_data import SpectrogramDataset, collate_fn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Where train_model() stores the weights with the lowest validation loss.
BEST_MODEL_PATH = 'best_model.pth'

# Configure the root logger: INFO and above go to training.log and to the
# console, both with the same format. logging.basicConfig() is deliberately not
# called as well -- it installs its own console handler, which made every
# message appear on the console twice.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler('training.log')
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
logger.addHandler(fh)
logger.addHandler(ch)


class ResidualBlock(nn.Module):
    """Basic residual block: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus skip.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        stride: stride of the first conv; values > 1 downsample spatially.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # The skip path must match the main path's output shape, so project it
        # with a strided 1x1 conv whenever the stride or channel count changes.
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)


class AudioResNet(nn.Module):
    """ResNet-18-style classifier for single-channel 2-D inputs.

    Input is a ``(batch, 1, H, W)`` tensor (``conv1`` takes one channel).
    Output is ``(batch, num_classes)`` log-probabilities, so it pairs with
    ``nn.NLLLoss``.

    Args:
        num_classes: number of output classes.
        dropout_rate: dropout probability applied after the hidden FC layer.
    """

    def __init__(self, num_classes=6, dropout_rate=0.5):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 64, num_blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)
        self.layer4 = self._make_layer(256, 512, num_blocks=2, stride=2)
        self.dropout = nn.Dropout(dropout_rate)
        # Global average pooling makes the head independent of input H/W.
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` residual blocks; only the first changes stride/channels."""
        layers = []
        for i in range(num_blocks):
            layers.append(ResidualBlock(in_channels if i == 0 else out_channels,
                                        out_channels,
                                        stride if i == 0 else 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.gap(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def parse_args():
    """Parse the command line; ``--config`` (path to a JSON file) is required."""
    parser = argparse.ArgumentParser(description='Train Sample Classifier Model')
    parser.add_argument('--config', type=str, required=True, help='Path to the config file')
    return parser.parse_args()


def load_config(config_path):
    """Load and return the JSON config at ``config_path``.

    Raises:
        FileNotFoundError: if the file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples actually drawn this
        epoch (a sampler may draw a different count than ``len(dataset)``).

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train()
    running_loss = 0.0
    total_correct = 0
    total_seen = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        # Insert the channel axis the model's single-channel conv1 expects.
        # Assumes collate_fn yields (batch, H, W) tensors -- TODO confirm.
        outputs = model(inputs.unsqueeze(1))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        _, predicted = torch.max(outputs, 1)
        total_correct += (predicted == labels).sum().item()
        total_seen += batch_size
    if total_seen == 0:
        raise ValueError("train_loader yielded no samples")
    return running_loss / total_seen, total_correct / total_seen


def validate_one_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Returns:
        ``(mean_loss, accuracy)`` over the samples the loader yielded.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    val_loss = 0.0
    val_correct = 0
    total_seen = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs.unsqueeze(1))
            loss = criterion(outputs, labels)

            batch_size = inputs.size(0)
            val_loss += loss.item() * batch_size
            _, predicted = torch.max(outputs, 1)
            val_correct += (predicted == labels).sum().item()
            total_seen += batch_size
    if total_seen == 0:
        raise ValueError("val_loader yielded no samples")
    return val_loss / total_seen, val_correct / total_seen


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device,
                patience=10, max_epochs=50, best_model_path=BEST_MODEL_PATH, checkpoint_every=10):
    """Train with LR-on-plateau scheduling, best-model saving and early stopping.

    Every time the validation loss improves, ``model.state_dict()`` is written
    to ``best_model_path``. Training stops after ``patience`` consecutive
    epochs without improvement. Every ``checkpoint_every`` epochs a checkpoint
    is also written (pass 0 or None to disable).

    Returns:
        The best validation loss seen. It stays ``float('inf')`` when no epoch
        improved (e.g. NaN losses), in which case nothing was saved to
        ``best_model_path``.
    """
    best_loss = float('inf')
    patience_counter = 0
    for epoch in range(1, max_epochs + 1):
        train_loss, train_accuracy = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_accuracy = validate_one_epoch(model, val_loader, criterion, device)
        log_message = (f'Epoch {epoch}:\n'
                       f'Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, '
                       f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}\n')
        logger.info(log_message)

        # ReduceLROnPlateau-style schedulers take the monitored metric.
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        logger.info(f'Current learning rate: {current_lr}')

        if val_loss < best_loss:
            best_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), best_model_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info('Early stopping triggered')
                break

        if checkpoint_every and epoch % checkpoint_every == 0:
            checkpoint_path = f'checkpoint_epoch_{epoch}.pth'
            torch.save(model.state_dict(), checkpoint_path)
            logger.info(f'Model saved to {checkpoint_path}')
    return best_loss


def evaluate_model(model, test_loader, device, class_names, save_path=None):
    """Log a per-class classification report and plot a confusion matrix.

    ``class_names[i]`` must name integer label ``i``. Every class is reported,
    even ones absent from the test split, so a small random split no longer
    makes ``classification_report`` fail on a ``target_names`` length mismatch.

    Args:
        save_path: optional file path for the confusion-matrix image.
    """
    model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs.unsqueeze(1))
            _, preds = torch.max(outputs, 1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    label_ids = list(range(len(class_names)))
    logger.info(classification_report(all_labels, all_preds, labels=label_ids,
                                      target_names=class_names, zero_division=0))
    plot_confusion_matrix(all_labels, all_preds, class_names, save_path=save_path)


def plot_confusion_matrix(labels, preds, class_names, save_path=None):
    """Draw (and optionally save) a confusion-matrix heatmap.

    The matrix always has one row/column per entry of ``class_names``, so it
    lines up with the tick labels even when some classes never occur.
    """
    cm = confusion_matrix(labels, preds, labels=list(range(len(class_names))))
    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.title('Confusion Matrix')
    if save_path:
        plt.savefig(save_path)
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)


def objective(trial, train_loader, val_loader, num_classes, max_epochs=10, patience=3):
    """Optuna objective: train briefly and return the best validation loss seen.

    Samples learning rate, weight decay and dropout from ``trial``. Training
    runs for at most ``max_epochs`` epochs and stops early after ``patience``
    epochs without improvement.

    Returns:
        The lowest validation loss reached. Returning the last epoch's loss
        would report the worse value that triggered early stopping.
    """
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-3, log=True)
    dropout_rate = trial.suggest_float('dropout_rate', 0.2, 0.5)

    model = AudioResNet(num_classes=num_classes, dropout_rate=dropout_rate).to(device)
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

    best_loss = float('inf')
    patience_counter = 0
    for _ in range(max_epochs):
        train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, _ = validate_one_epoch(model, val_loader, criterion, device)
        scheduler.step(val_loss)
        if val_loss < best_loss:
            best_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                break
    return best_loss


def _load_every_item(ds):
    """Index every element of ``ds`` so broken samples surface before training."""
    for idx in range(len(ds)):
        _ = ds[idx]


def verify_dataset_and_loader(dataset, train_loader, val_loader, test_loader):
    """Log split sizes and eagerly load every item of each split.

    Best-effort diagnostic: an IndexError is logged, not raised.
    """
    try:
        logger.info(f"Dataset length: {len(dataset)}")
        logger.info(f"Train dataset length: {len(train_loader.dataset)}")
        logger.info(f"Validation dataset length: {len(val_loader.dataset)}")
        logger.info(f"Test dataset length: {len(test_loader.dataset)}")
        _load_every_item(train_loader.dataset)
        logger.info("Train dataset verification passed")
        _load_every_item(val_loader.dataset)
        logger.info("Validation dataset verification passed")
        _load_every_item(test_loader.dataset)
        logger.info("Test dataset verification passed")
    except IndexError as e:
        logger.error(f"Dataset index error: {e}")


def verify_sampler_indices(loader, name):
    """Check that every index ``loader.sampler`` yields is valid for ``loader.dataset``.

    Note: this iterates the sampler once, so a random sampler draws a fresh
    sample set here.
    """
    indices = list(loader.sampler)
    logger.info(f"{name} sampler indices: {indices[:10]}... (total: {len(indices)})")
    if not indices:
        logger.error(f"{name} sampler produced no indices.")
        return
    max_index = max(indices)
    if max_index >= len(loader.dataset):
        logger.error(f"{name} sampler index out of range: {max_index} >= {len(loader.dataset)}")
    else:
        logger.info(f"{name} sampler indices within range.")


def main():
    """Entry point: config -> data split -> Optuna search -> final training -> test report.

    Required config keys: ``directory`` and ``batch_size``. Optional keys:
    ``patience`` (default 10), ``max_epochs`` (default 50), ``n_trials``
    (default 50) and ``seed`` (makes the random split reproducible).
    On any error the traceback is logged and the process exits with status 1.
    """
    try:
        args = parse_args()
        config = load_config(args.config)

        dataset = SpectrogramDataset(config, config['directory'], process_new=True)
        if len(dataset) == 0:
            raise ValueError("The dataset is empty. Please check the data loading process.")

        num_classes = len(dataset.label_to_index)
        # Order names by integer label so class_names[i] really names class i.
        class_names = sorted(dataset.label_to_index, key=dataset.label_to_index.get)

        train_size = int(0.7 * len(dataset))
        val_size = int(0.15 * len(dataset))
        test_size = len(dataset) - train_size - val_size
        if min(train_size, val_size, test_size) == 0:
            raise ValueError(f"Dataset too small for a 70/15/15 split: "
                             f"train={train_size}, val={val_size}, test={test_size}")

        seed = config.get('seed')
        split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(seed)}
        train_dataset, val_dataset, test_dataset = random_split(
            dataset, [train_size, val_size, test_size], **split_kwargs)

        # Inverse-frequency sampling to balance classes in the training stream.
        # Assumes dataset.labels holds non-negative integer class ids -- TODO confirm.
        train_labels = [dataset.labels[i] for i in train_dataset.indices]
        class_counts = np.bincount(train_labels, minlength=num_classes)
        # np.maximum avoids a divide-by-zero for classes missing from the train
        # split; those weights are never looked up below.
        class_weights = 1.0 / np.maximum(class_counts, 1)
        sample_weights = class_weights[train_labels]
        sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

        batch_size = config['batch_size']
        train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, sampler=sampler)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn,
                                sampler=RandomSampler(val_dataset))
        test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn,
                                 sampler=SequentialSampler(test_dataset))

        verify_dataset_and_loader(dataset, train_loader, val_loader, test_loader)
        verify_sampler_indices(train_loader, "Train")
        verify_sampler_indices(val_loader, "Validation")
        verify_sampler_indices(test_loader, "Test")

        study = optuna.create_study(direction='minimize')
        study.optimize(lambda trial: objective(trial, train_loader, val_loader, num_classes),
                       n_trials=config.get('n_trials', 50))
        best_params = study.best_params
        print('Best hyperparameters: ', best_params)

        model = AudioResNet(num_classes=num_classes, dropout_rate=best_params['dropout_rate']).to(device)
        criterion = nn.NLLLoss()
        optimizer = optim.Adam(model.parameters(), lr=best_params['learning_rate'],
                               weight_decay=best_params['weight_decay'])
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
        best_val_loss = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device,
                                    patience=config.get('patience', 10),
                                    max_epochs=config.get('max_epochs', 50),
                                    best_model_path=BEST_MODEL_PATH)
        if best_val_loss == float('inf'):
            # Nothing was saved this run; loading would pick up a stale or missing file.
            raise RuntimeError("Validation loss never improved; no best model was saved.")

        model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
        evaluate_model(model, test_loader, device, class_names)
    except Exception as e:
        # Keep the traceback and signal failure to the caller / CI.
        logger.exception(f"An error occurred: {e}")
        sys.exit(1)


if __name__ == '__main__':
    main()