import torch import pandas as pd import os from tqdm import tqdm from torch.utils.data import DataLoader from torch.cuda.amp import autocast from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score import numpy as np import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from dataset2 import MedicalImageDatasetBalancedIntensity3D from model import Backbone, SingleScanModel, Classifier from utils import BaseConfig def calculate_metrics(pred_probs, pred_labels, true_labels): """ Multi-class classification metrics. Args: pred_probs (numpy.ndarray): Predicted probabilities for each class pred_labels (numpy.ndarray): Predicted labels true_labels (numpy.ndarray): Ground truth labels Returns: dict: Dictionary containing accuracy, precision, recall, F1, and AUC metrics """ accuracy = accuracy_score(true_labels, pred_labels) precision = precision_score(true_labels, pred_labels, average='weighted') recall = recall_score(true_labels, pred_labels, average='weighted') f1 = f1_score(true_labels, pred_labels, average='weighted') # For multi-class, we use ROC AUC OVR (One-vs-Rest) auc = roc_auc_score(true_labels, pred_probs, multi_class='ovr') return { 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc } #============================ # INFERENCE CLASS #============================ class SequenceInference(BaseConfig): """ Inference class for sequence classification model. """ def __init__(self): super().__init__() self.setup_model() self.setup_data() def setup_model(self): config = self.get_config() self.backbone = Backbone() self.classifier = Classifier(d_model=2048, num_classes=4) # 4-way classification self.model = SingleScanModel(self.backbone, self.classifier) # Load weights checkpoint = torch.load(config["infer"]["checkpoints"], map_location=self.device, weights_only=False) self.model.load_state_dict(checkpoint["model_state_dict"]) self.model = self.model.to(self.device) self.model.eval() print("Model and checkpoint loaded!") ## spin up data loaders def setup_data(self): config = self.get_config() self.test_dataset = MedicalImageDatasetBalancedIntensity3D( csv_path=config["data"]["test_csv"], root_dir=config["data"]["root_dir"] ) self.test_loader = DataLoader( self.test_dataset, batch_size=1, shuffle=False, collate_fn=self.custom_collate, num_workers=1 ) def infer(self): """ Run inference pass Returns: dict: Dictionary with evaluation metrics """ results_df = pd.DataFrame(columns=['PredictedProbs_Class0', 'PredictedProbs_Class1', 'PredictedProbs_Class2', 'PredictedProbs_Class3', 'PredictedLabel', 'TrueLabel']) all_labels = [] all_predictions = [] all_probs = [] with torch.no_grad(): for sample in tqdm(self.test_loader, desc="Inference", unit="batch"): inputs = sample['image'].to(self.device) labels = sample['label'].to(self.device) with autocast(): outputs = self.model(inputs) # Apply softmax to get probabilities probs = torch.softmax(outputs, dim=1).cpu().numpy() preds = np.argmax(probs, axis=1) all_labels.extend(labels.cpu().numpy()) all_predictions.extend(preds) all_probs.extend(probs) result = pd.DataFrame({ 'PredictedProbs_Class0': probs[:, 0], 'PredictedProbs_Class1': probs[:, 1], 'PredictedProbs_Class2': probs[:, 2], 'PredictedProbs_Class3': probs[:, 3], 'PredictedLabel': preds, 'TrueLabel': labels.cpu().numpy() }) results_df = pd.concat([results_df, result], ignore_index=True) # log metrics """metrics = calculate_metrics( np.array(all_probs), np.array(all_predictions), np.array(all_labels) ) print("\nTest Set Metrics:") print(f"Accuracy: {metrics['accuracy']:.4f}") print(f"Precision: {metrics['precision']:.4f}") print(f"Recall: {metrics['recall']:.4f}") print(f"F1 Score: {metrics['f1']:.4f}") print(f"AUC: {metrics['auc']:.4f}")""" # Save results print("PredictedLabel", preds) results_df.to_csv('./data/output/sequence_classification_predictions.csv', index=False) return None #metrics if __name__ == "__main__": inferencer = SequenceInference() _ = inferencer.infer()