# 0Shot1Shot-v0.1 / train_model.py
# Uploaded by CyborgPaloma ("Upload 2 files"), commit 45f95fb (verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split, RandomSampler, SequentialSampler
import logging
import argparse
import json
from datetime import datetime
import optuna
from prepare_data import SpectrogramDataset, collate_fn
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import os
import numpy as np
# Prefer CUDA when PyTorch reports a usable GPU; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using device: {device}')
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block.

    Main path: 3x3 conv (carrying ``stride``) -> BN -> ReLU -> 3x3 conv -> BN.
    The result is added to the skip path and passed through a final ReLU.
    When the stride or the channel count changes, the skip path is a 1x1
    strided conv + BN so its shape matches the main path; otherwise the input
    is added unchanged.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        stride: stride of the first 3x3 conv (and of the projection shortcut).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Keep this construction order: parameter initialisation draws from the
        # RNG in this order, and the attribute names form the state_dict keys.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        main_path = self.bn2(self.conv2(hidden))
        return F.relu(main_path + self.shortcut(x))
class AudioResNet(nn.Module):
    """Residual CNN classifier for single-channel 2-D input.

    In this script it is fed spectrograms from ``SpectrogramDataset``.

    Architecture:
      - Stem: 7x7 stride-2 conv + BN + ReLU, then 3x3 stride-2 max-pool.
      - Four stages of two ``ResidualBlock``s each, widening
        64 -> 128 -> 256 -> 512 channels; stages 2-4 use stride 2.
      - Head: global average pooling, a 512 -> 1024 linear layer with ReLU,
        dropout, and a final linear layer to ``num_classes``.

    ``forward`` returns ``log_softmax`` over dim 1, so the matching loss is
    ``nn.NLLLoss``.

    Args:
        num_classes: number of output classes.
        dropout_rate: dropout probability applied after the hidden linear layer.
    """

    def __init__(self, num_classes=6, dropout_rate=0.5):
        super().__init__()
        # Keep this construction order: parameter initialisation draws from the
        # RNG in this order, and the attribute names form the checkpoint keys.
        # Stem.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Residual stages: (in_channels, out_channels, stride of the first block).
        self.layer1 = self._make_layer(64, 64, num_blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)
        self.layer4 = self._make_layer(256, 512, num_blocks=2, stride=2)
        # Classification head.
        self.dropout = nn.Dropout(dropout_rate)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` residual blocks.

        Only the first block changes the channel count and applies ``stride``.
        """
        blocks = [
            ResidualBlock(
                out_channels if idx else in_channels,
                out_channels,
                1 if idx else stride,
            )
            for idx in range(num_blocks)
        ]
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Map a (batch, 1, H, W) tensor to (batch, num_classes) log-probabilities."""
        features = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = self.gap(features)
        flat = pooled.view(pooled.size(0), -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return F.log_softmax(self.fc2(hidden), dim=1)
# Logging configuration.
# The root logger sends every INFO+ record to 'training.log' and to the console,
# and both handlers share one format. The handlers are attached explicitly.
# Calling logging.basicConfig() as well would install an extra console handler
# on the root logger and print every message twice.
# `device` is defined once, near the top of the module.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler('training.log')
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
logger.addHandler(fh)
logger.addHandler(ch)
def parse_args():
    """Read command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with a required ``config`` attribute holding the path
        to the JSON config file.
    """
    cli = argparse.ArgumentParser(description='Train Sample Classifier Model')
    cli.add_argument('--config', type=str, required=True, help='Path to the config file')
    parsed = cli.parse_args()
    return parsed
def load_config(config_path):
    """Load the training configuration from a JSON file.

    The file is decoded as UTF-8, the encoding JSON mandates, regardless of the
    platform's locale.

    Args:
        config_path: path to the JSON config file.

    Returns:
        The parsed JSON document. ``main`` expects a dict.

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        # Re-raise with a clearer message; the original errno context adds nothing.
        raise FileNotFoundError(f"Config file not found: {config_path}") from None
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Each batch's ``inputs`` gets a channel axis inserted at dim 1
    (``unsqueeze(1)``) before the forward pass. Predictions are the argmax over
    dim 1 of the model output.

    Metrics are averaged over the samples the loader actually yielded, not over
    ``len(train_loader.dataset)``. The two differ when the sampler draws a
    different number of samples than the dataset holds, or when batches are
    dropped.

    Args:
        model: module to train; it is switched to ``train()`` mode.
        train_loader: iterable of ``(inputs, labels)`` tensor batches.
        criterion: loss function. NOTE(review): ``loss.item() * batch`` assumes
            it averages over the batch (``reduction='mean'``).
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss_per_sample, accuracy)`` over the samples seen.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train()
    running_loss = 0.0
    total_correct = 0
    num_samples = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs.unsqueeze(1))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
        num_samples += batch_size
    if num_samples == 0:
        raise ValueError("train_loader yielded no samples")
    return running_loss / num_samples, total_correct / num_samples
def validate_one_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Each batch's ``inputs`` gets a channel axis at dim 1 (``unsqueeze(1)``)
    before the forward pass, matching ``train_one_epoch``. Metrics are averaged
    over the samples the loader actually yielded, not over
    ``len(val_loader.dataset)``.

    Args:
        model: module to evaluate; it is switched to ``eval()`` mode.
        val_loader: iterable of ``(inputs, labels)`` tensor batches.
        criterion: loss function. NOTE(review): assumed to average over the batch.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss_per_sample, accuracy)`` over the samples seen.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    val_loss = 0.0
    val_correct = 0
    num_samples = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs.unsqueeze(1))
            batch_size = inputs.size(0)
            val_loss += criterion(outputs, labels).item() * batch_size
            val_correct += (outputs.argmax(dim=1) == labels).sum().item()
            num_samples += batch_size
    if num_samples == 0:
        raise ValueError("val_loader yielded no samples")
    return val_loss / num_samples, val_correct / num_samples
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device,
                patience=10, max_epochs=50, best_model_path='best_model.pth',
                checkpoint_every=10):
    """Train with early stopping on validation loss and save the best weights.

    After every epoch:
      - The train/val metrics and the current learning rate are logged.
      - ``scheduler.step(val_loss)`` is called, so the scheduler must accept a
        metric (e.g. ``ReduceLROnPlateau``).
      - If the validation loss improved, ``model.state_dict()`` is saved to
        ``best_model_path``.
      - Training stops once ``patience`` consecutive epochs fail to improve.
      - Otherwise, a snapshot is written to ``checkpoint_epoch_<N>.pth`` every
        ``checkpoint_every`` epochs (0/None disables this). No snapshot is
        written on the epoch that triggers early stopping.

    On return the model holds the weights of the LAST epoch run, not the best
    ones. Reload ``best_model_path`` to get the best weights.

    Args:
        patience: epochs without improvement tolerated before stopping.
        max_epochs: upper bound on the number of epochs.
        best_model_path: where the best ``state_dict`` is saved.
        checkpoint_every: period, in epochs, of the numbered snapshots.

    Returns:
        The lowest validation loss seen. It is ``inf`` if the loss never
        improved (e.g. it was always NaN); in that case ``best_model_path`` was
        not written by this call.
    """
    best_loss = float('inf')
    patience_counter = 0
    for epoch in range(1, max_epochs + 1):
        train_loss, train_accuracy = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_accuracy = validate_one_epoch(model, val_loader, criterion, device)
        logging.info(f'Epoch {epoch}:\n'
                     f'Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, '
                     f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}\n')
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        logging.info(f'Current learning rate: {current_lr}')

        # NaN never compares as "<", so a NaN loss counts as no improvement.
        if val_loss < best_loss:
            best_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), best_model_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logging.info('Early stopping triggered')
                break

        if checkpoint_every and epoch % checkpoint_every == 0:
            checkpoint_path = f'checkpoint_epoch_{epoch}.pth'
            torch.save(model.state_dict(), checkpoint_path)
            logging.info(f'Model saved to {checkpoint_path}')

    if best_loss == float('inf'):
        logging.warning(f'Validation loss never improved; {best_model_path} was not written.')
    return best_loss
def evaluate_model(model, test_loader, device, class_names):
    """Predict over ``test_loader``, then report and plot the results.

    A classification report is logged and a confusion matrix is plotted.

    NOTE(review): labels and predictions are assumed to be integer class ids
    ``0..len(class_names)-1``, with ``class_names[i]`` naming id ``i``. Every id
    is passed to the report explicitly. A class that never appears in this
    split then gets a zero-support row. Without this, sklearn raises a
    "number of classes does not match target_names" ValueError.

    Args:
        model: trained classifier; it is switched to ``eval()`` mode.
        test_loader: iterable of ``(inputs, labels)`` tensor batches.
        device: device the inputs are moved to.
        class_names: display names, indexed by class id.

    Returns:
        The classification report string (also logged).
    """
    model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device).unsqueeze(1))
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
    report = classification_report(
        all_labels,
        all_preds,
        labels=list(range(len(class_names))),
        target_names=class_names,
        zero_division=0,  # empty classes score 0 instead of emitting warnings
    )
    logging.info(report)
    plot_confusion_matrix(all_labels, all_preds, class_names)
    return report
def plot_confusion_matrix(labels, preds, class_names, save_path=None, label_ids=None):
    """Draw a confusion-matrix heatmap, optionally save it, show it, then close it.

    Rows are actual classes and columns are predicted classes.

    ``label_ids`` fixes which label values get a row/column and in what order.
    When it is omitted and ``labels`` are integers, it defaults to
    ``range(len(class_names))``. The matrix then always has one row per class
    name, even if a class never occurs in ``labels`` or ``preds``. Without this,
    the matrix would shrink and the tick labels would no longer line up.

    Args:
        labels: ground-truth labels.
        preds: predicted labels.
        class_names: tick labels for both axes.
        save_path: if truthy, the figure is saved there before being shown.
        label_ids: explicit label order (see above).

    Returns:
        The confusion matrix array.
    """
    if (label_ids is None and class_names is not None
            and np.issubdtype(np.asarray(labels).dtype, np.integer)):
        label_ids = list(range(len(class_names)))
    cm = confusion_matrix(labels, preds, labels=label_ids)
    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.title('Confusion Matrix')
    if save_path:
        plt.savefig(save_path)
    plt.show()
    # Release the figure. With a non-interactive backend show() is a no-op,
    # and the figure would otherwise stay alive in pyplot's registry.
    plt.close(fig)
    return cm
def objective(trial, train_loader, val_loader, num_classes, max_epochs=10, early_stop_patience=3):
    """Optuna objective: train a fresh ``AudioResNet`` and return its best validation loss.

    The trial samples three hyperparameters; ``main`` reads them back by name
    from ``study.best_params``:
      - ``learning_rate``: log-uniform in [1e-5, 1e-3]
      - ``weight_decay``: log-uniform in [1e-5, 1e-3]
      - ``dropout_rate``: uniform in [0.2, 0.5]

    Training uses Adam, NLLLoss and ``ReduceLROnPlateau(patience=3)``. It runs
    on the module-level ``device`` for at most ``max_epochs`` epochs, and stops
    after ``early_stop_patience`` epochs without improvement.

    The return value is the LOWEST validation loss seen. Returning the last
    epoch's loss would score every early-stopped trial on a loss that is, by
    construction, worse than its best.

    Returns:
        The best validation loss (``inf`` if no epoch produced a finite improvement).
    """
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-3, log=True)
    dropout_rate = trial.suggest_float('dropout_rate', 0.2, 0.5)
    model = AudioResNet(num_classes=num_classes, dropout_rate=dropout_rate).to(device)
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
    best_loss = float('inf')
    patience_counter = 0
    for _ in range(max_epochs):
        train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, _val_accuracy = validate_one_epoch(model, val_loader, criterion, device)
        scheduler.step(val_loss)
        if val_loss < best_loss:
            best_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stop_patience:
                break
    return best_loss
def verify_dataset_and_loader(dataset, train_loader, val_loader, test_loader):
    """Log dataset/split sizes and index every sample of each split.

    Every split's ``__getitem__`` is called for every index, so this loads the
    whole dataset once and can be slow.

    Each split is checked independently. An ``IndexError`` is logged with the
    split name and failing index, and the remaining splits are still checked.
    Other exception types propagate to the caller.

    Args:
        dataset: the full dataset the splits were drawn from.
        train_loader, val_loader, test_loader: loaders whose ``.dataset`` is
            verified.
    """
    logger.info(f"Dataset length: {len(dataset)}")
    splits = (("Train", train_loader), ("Validation", val_loader), ("Test", test_loader))
    for name, loader in splits:
        logger.info(f"{name} dataset length: {len(loader.dataset)}")

    for name, loader in splits:
        split = loader.dataset
        idx = None
        try:
            for idx in range(len(split)):
                _ = split[idx]
        except IndexError as e:
            logger.error(f"{name} dataset index error at index {idx}: {e}")
        else:
            logger.info(f"{name} dataset verification passed")
def verify_sampler_indices(loader, name):
    """Check that every index drawn from ``loader.sampler`` is valid for ``loader.dataset``.

    Valid means the index lies in ``[0, len(loader.dataset))``. The first ten
    indices and the total count are logged.

    An empty sampler (e.g. an empty split) is reported as a warning instead of
    crashing on ``max([])``.

    NOTE(review): iterating a random sampler draws fresh random indices, so
    calling this advances the RNG used for later sampling.

    Args:
        loader: DataLoader whose sampler and dataset are checked.
        name: label used in the log messages.
    """
    indices = list(loader.sampler)
    logger.info(f"{name} sampler indices: {indices[:10]}... (total: {len(indices)})")
    if not indices:
        logger.warning(f"{name} sampler is empty; nothing to verify.")
        return
    dataset_len = len(loader.dataset)
    lowest, highest = min(indices), max(indices)
    if highest >= dataset_len:
        logger.error(f"{name} sampler index out of range: {highest} >= {dataset_len}")
    elif lowest < 0:
        logger.error(f"{name} sampler index out of range: {lowest} < 0")
    else:
        logger.info(f"{name} sampler indices within range.")
def _make_loaders(dataset, train_dataset, val_dataset, test_dataset, batch_size, num_classes):
    """Build the train/val/test DataLoaders.

    The train loader samples with inverse class-frequency weights to counter
    class imbalance.
    """
    # Assumes dataset.labels holds integer class ids (np.bincount requires it).
    train_labels = np.asarray([dataset.labels[i] for i in train_dataset.indices])
    class_counts = np.bincount(train_labels, minlength=num_classes)
    # Inverse-frequency weights. Classes absent from the training split get
    # weight 0 instead of triggering a divide-by-zero.
    class_weights = np.divide(1.0, class_counts, out=np.zeros(class_counts.shape),
                              where=class_counts > 0)
    sample_weights = class_weights[train_labels]
    sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, sampler=sampler)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn,
                            sampler=RandomSampler(val_dataset))
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn,
                             sampler=SequentialSampler(test_dataset))
    return train_loader, val_loader, test_loader


def main():
    """Script entry point.

    Steps:
      1. Load the config and the dataset.
      2. Split the data 70/15/15.
      3. Tune hyperparameters with Optuna.
      4. Retrain with the best parameters.
      5. Evaluate on the test split.

    Config keys read:
      - required: ``directory``, ``batch_size``, ``patience``
      - optional: ``n_trials`` (default 50)

    Any exception is logged with its full traceback, and the process exits with
    status 1 instead of silently returning.
    """
    try:
        args = parse_args()
        config = load_config(args.config)
        dataset = SpectrogramDataset(config, config['directory'], process_new=True)
        if len(dataset) == 0:
            raise ValueError("The dataset is empty. Please check the data loading process.")
        label_to_index = dataset.label_to_index
        num_classes = len(label_to_index)
        # NOTE(review): assumes label_to_index maps class name -> integer id.
        # Names are ordered by id so that class_names[i] describes label i;
        # dict key order alone does not guarantee that.
        class_names = sorted(label_to_index, key=label_to_index.get)

        train_size = int(0.7 * len(dataset))
        val_size = int(0.15 * len(dataset))
        test_size = len(dataset) - train_size - val_size
        if min(train_size, val_size, test_size) == 0:
            raise ValueError(
                f"Dataset of {len(dataset)} samples is too small for a 70/15/15 split "
                f"(train={train_size}, val={val_size}, test={test_size})."
            )
        train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
        train_loader, val_loader, test_loader = _make_loaders(
            dataset, train_dataset, val_dataset, test_dataset, config['batch_size'], num_classes
        )

        verify_dataset_and_loader(dataset, train_loader, val_loader, test_loader)
        verify_sampler_indices(train_loader, "Train")
        verify_sampler_indices(val_loader, "Validation")
        verify_sampler_indices(test_loader, "Test")

        study = optuna.create_study(direction='minimize')
        study.optimize(lambda trial: objective(trial, train_loader, val_loader, num_classes),
                       n_trials=config.get('n_trials', 50))
        best_params = study.best_params
        print('Best hyperparameters: ', best_params)

        model = AudioResNet(num_classes=num_classes, dropout_rate=best_params['dropout_rate']).to(device)
        criterion = nn.NLLLoss()
        optimizer = optim.Adam(model.parameters(), lr=best_params['learning_rate'],
                               weight_decay=best_params['weight_decay'])
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
        train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device,
                    patience=config['patience'])
        # map_location lets a checkpoint saved from GPU tensors be loaded on a CPU-only host.
        model.load_state_dict(torch.load('best_model.pth', map_location=device))
        evaluate_model(model, test_loader, device, class_names)
    except Exception as e:
        # Top-level boundary: log the full traceback, then signal failure to the shell.
        logging.exception(f"An error occurred: {e}")
        raise SystemExit(1) from e


if __name__ == '__main__':
    main()