import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb
from tqdm import tqdm
from torch.optim.lr_scheduler import OneCycleLR
from torch.cuda.amp import GradScaler, autocast
import os
import sys

# Make the project root importable so the local modules below resolve.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from dataset2 import MedicalImageDatasetBalancedIntensity3D, TransformationMedicalImageDatasetBalancedIntensity3D
from model import Backbone, SingleScanModel, Classifier
from utils import BaseConfig
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score


def calculate_metrics(pred_probs, pred_labels, true_labels):
    """
    Calculate binary classification metrics.

    Args:
        pred_probs (numpy.ndarray): Predicted probabilities
        pred_labels (numpy.ndarray): Predicted labels
        true_labels (numpy.ndarray): Ground truth labels

    Returns:
        dict: Dictionary containing accuracy, precision, recall, F1, and AUC
    """
    accuracy = accuracy_score(true_labels, pred_labels)
    precision = precision_score(true_labels, pred_labels)
    recall = recall_score(true_labels, pred_labels)
    f1 = f1_score(true_labels, pred_labels)
    auc = roc_auc_score(true_labels, pred_probs)

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc
    }
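
# Worked example (hypothetical numbers): for true_labels = [0, 1, 1, 0],
# pred_labels = [0, 1, 0, 0], and pred_probs = [0.2, 0.9, 0.4, 0.1], this
# returns accuracy = 0.75, precision = 1.0, recall = 0.5, f1 ≈ 0.667, and
# auc = 1.0 (both positive-class probabilities exceed both negatives).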


class MCITrainer(BaseConfig):
    """
    Trainer class for MCI classification.
    """

    def __init__(self):
        super().__init__()
        self.setup_wandb()
        self.setup_model()
        self.setup_data()
        self.setup_training()

    def setup_wandb(self):
        config = self.get_config()
        wandb.init(
            project=config['logger']['project_name'],
            name=config['logger']['run_name'],
            config=config
        )

    def setup_model(self):
        self.backbone = Backbone()
        self.classifier = Classifier(d_model=2048, num_classes=1)
        self.model = SingleScanModel(self.backbone, self.classifier)

        config = self.get_config()
        if config["train"]["finetune"] == "yes":
            checkpoint = torch.load(config["train"]["weights"], map_location=self.device)
            state_dict = checkpoint["state_dict"]
            # Remap DataParallel-style "module." prefixes into the backbone
            # namespace; strict=False tolerates any keys that still don't match.
            filtered_state_dict = {}
            for key, value in state_dict.items():
                new_key = key.replace("module.", "backbone.") if key.startswith("module.") else key
                filtered_state_dict[new_key] = value
            self.model.backbone.load_state_dict(filtered_state_dict, strict=False)
            print("Pretrained weights loaded!")

        if config["train"]["freeze"] == "yes":
            for param in self.model.backbone.parameters():
                param.requires_grad = False
            print("Backbone weights frozen!")

        self.model = self.model.to(self.device)

    def setup_data(self):
        config = self.get_config()
        self.train_dataset = TransformationMedicalImageDatasetBalancedIntensity3D(
            csv_path=config['data']['train_csv'],
            root_dir=config["data"]["root_dir"]
        )
        self.val_dataset = MedicalImageDatasetBalancedIntensity3D(
            csv_path=config['data']['val_csv'],
            root_dir=config["data"]["root_dir"]
        )

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=config["data"]["batch_size"],
            shuffle=True,
            collate_fn=self.custom_collate,
            num_workers=config["data"]["num_workers"]
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self.custom_collate,
            num_workers=1
        )
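        # Note: self.custom_collate and self.device are not defined in this
        # file; they are assumed to be provided by BaseConfig (from utils).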

    def setup_training(self):
        """
        Set up the loss function, optimizer, LR scheduler, and AMP gradient scaler.
        """
        config = self.get_config()

        self.criterion = nn.BCEWithLogitsLoss().to(self.device)
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config['optim']['lr'],
            weight_decay=config["optim"]["weight_decay"]
        )
        # OneCycleLR is stepped once per batch in train_epoch, hence steps_per_epoch.
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config['optim']['lr'],
            epochs=config['optim']['max_epochs'],
            steps_per_epoch=len(self.train_loader)
        )
        self.scaler = GradScaler()

    def train(self):
        config = self.get_config()
        max_epochs = config['optim']['max_epochs']
        best_metrics = {
            'val_loss': float('inf'),
            'accuracy': 0,
            'precision': 0,
            'recall': 0,
            'f1': 0,
            'auc': 0
        }

        for epoch in range(max_epochs):
            train_loss = self.train_epoch(epoch, max_epochs)
            val_loss, metrics = self.validate_epoch(epoch, max_epochs)

            # Checkpoint whenever validation AUC improves.
            if metrics['auc'] > best_metrics['auc']:
                print("New best model found!")
                print(f"AUC improved from {best_metrics['auc']:.4f} to {metrics['auc']:.4f}")
                print(f"Val Loss moved from {best_metrics['val_loss']:.4f} to {val_loss:.4f}")
                print(f"F1 moved from {best_metrics['f1']:.4f} to {metrics['f1']:.4f}")
                best_metrics.update(metrics)
                best_metrics['val_loss'] = val_loss
                self.save_checkpoint(epoch, val_loss, metrics)

        wandb.finish()

    def train_epoch(self, epoch, max_epochs):
        self.model.train()
        train_loss = 0.0

        for sample in tqdm(self.train_loader, desc=f"Training Epoch {epoch}/{max_epochs-1}"):
            inputs = sample['image'].to(self.device)
            labels = sample['label'].float().to(self.device)

            self.optimizer.zero_grad(set_to_none=True)
            # Forward pass under autocast for mixed-precision training.
            with autocast():
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels.unsqueeze(1))

            self.scaler.scale(loss).backward()

            # Unscale before clipping so the norm is computed on true gradients.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()

            # Weight by batch size so the epoch loss is a per-sample average.
            train_loss += loss.item() * inputs.size(0)

        train_loss = train_loss / len(self.train_loader.dataset)
        wandb.log({"Train Loss": train_loss})
        return train_loss

    def validate_epoch(self, epoch, max_epochs):
        self.model.eval()
        val_loss = 0.0
        all_labels = []
        all_preds = []
        all_probs = []

        with torch.no_grad():
            for sample in tqdm(self.val_loader, desc=f"Validation Epoch {epoch}/{max_epochs-1}"):
                inputs = sample['image'].to(self.device)
                labels = sample['label'].float().to(self.device)

                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels.unsqueeze(1))

                # Threshold the sigmoid probabilities at 0.5 for hard predictions.
                probs = torch.sigmoid(outputs).cpu().numpy()
                preds = (probs > 0.5).astype(int)

                val_loss += loss.item() * inputs.size(0)
                all_labels.extend(labels.cpu().numpy().flatten())
                all_preds.extend(preds.flatten())
                all_probs.extend(probs.flatten())

        val_loss = val_loss / len(self.val_loader.dataset)
        metrics = calculate_metrics(
            np.array(all_probs),
            np.array(all_preds),
            np.array(all_labels)
        )

        wandb.log({
            "Val Loss": val_loss,
            "Accuracy": metrics['accuracy'],
            "Precision": metrics['precision'],
            "Recall": metrics['recall'],
            "F1 Score": metrics['f1'],
            "AUC": metrics['auc']
        })

        print(f"Epoch {epoch}/{max_epochs-1}")
        print(f"Val Loss: {val_loss:.4f}")
        print(f"Accuracy: {metrics['accuracy']:.4f}")
        print(f"Precision: {metrics['precision']:.4f}")
        print(f"Recall: {metrics['recall']:.4f}")
        print(f"F1 Score: {metrics['f1']:.4f}")
        print(f"AUC: {metrics['auc']:.4f}")

        return val_loss, metrics

    def save_checkpoint(self, epoch, loss, metrics):
        config = self.get_config()
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'metrics': metrics
        }
        save_path = os.path.join(
            config['logger']['save_dir'],
            config['logger']['save_name'].format(epoch=epoch, loss=loss, metric=metrics['f1'])
        )
        torch.save(checkpoint, save_path)
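        # Example (hypothetical template): save_name = "mci_e{epoch}_l{loss:.4f}_f1{metric:.4f}.pth"
        # would produce a file such as "mci_e3_l0.5123_f10.7100.pth" under save_dir.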


if __name__ == "__main__":
    trainer = MCITrainer()
    trainer.train()
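
# Configuration sketch (assumption): BaseConfig.get_config() is expected to
# return a nested dict covering every key consumed above, roughly:
#   logger: project_name, run_name, save_dir, save_name
#   data:   train_csv, val_csv, root_dir, batch_size, num_workers
#   train:  finetune ("yes"/"no"), weights, freeze ("yes"/"no")
#   optim:  lr, weight_decay, max_epochs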