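"""
Run anomaly detection inference over every class of the dataset, dump the
predicted image- and pixel-level scores, and compute per-class evaluation
metrics (AP, AUROC, F1, AUPRO and a combined leaderboard score).
"""
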
import json
import logging
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import average_precision_score, f1_score, roc_auc_score
from torch.utils.data import DataLoader

from datasets.all_classes_dataset import AllClassesDataset, DatasetSplit
from models.anomaly_detector import AnomalyDetector
from utils.dump_scores import DumpScores


def set_seed(seed: int):
    """
    Set the seed for reproducibility across various libraries.

    Args:
        seed (int): The seed value to be set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)


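# NOTE: beyond set_seed above, torch.use_deterministic_algorithms(True) can be
# enabled for stricter determinism, at the cost of errors on ops that lack a
# deterministic implementation.
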
def worker_init_fn(worker_id):
    """
    Initialize the seed for each DataLoader worker to ensure reproducibility.

    Args:
        worker_id (int): The worker ID.
    """
    seed = torch.initial_seed()
    np.random.seed(seed % 2**32)
    random.seed(seed % 2**32)


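# NOTE: worker_init_fn above only takes effect when passed to the DataLoader
# (DataLoader(..., worker_init_fn=worker_init_fn)) with num_workers > 0;
# main() below uses num_workers=0, so loading runs in the main process.
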
def compute_aupro(y_true_pixel, y_scores_pixel, num_thresholds=50):
    """
    Compute the Area Under the Per-Region Overlap curve (AUPRO).

    Note: this is a simplified approximation -- it averages whole-image IoU over
    an evenly spaced threshold grid rather than computing the overlap per
    connected component as in the standard PRO definition.

    Args:
        y_true_pixel (np.ndarray): Ground truth binary masks, shape [N, H, W].
        y_scores_pixel (np.ndarray): Predicted anomaly scores, shape [N, H, W].
        num_thresholds (int): Number of thresholds to evaluate.

    Returns:
        float: AUPRO score.
    """
    thresholds = np.linspace(0, 1, num_thresholds)
    overlaps = []

    for thresh in thresholds:
        # Binarize the score maps at the current threshold.
        y_pred = (y_scores_pixel >= thresh).astype(int)

        # Average IoU between each ground-truth mask and its binarized prediction.
        ious = []
        for gt, pred in zip(y_true_pixel, y_pred):
            intersection = np.logical_and(gt, pred).sum()
            union = np.logical_or(gt, pred).sum()
            if union == 0:
                iou = 1.0
            else:
                iou = intersection / union
            ious.append(iou)

        avg_iou = np.mean(ious)
        overlaps.append(avg_iou)

    # Integrate the overlap curve over the [0, 1] threshold range.
    aupro = np.trapz(overlaps, thresholds) / np.trapz([1] * len(thresholds), thresholds)
    return aupro


def compute_metrics(y_true_image, y_scores_image, y_true_pixel, y_scores_pixel):
    """
    Compute the required metrics based on true labels and predicted scores.

    Args:
        y_true_image (np.ndarray): Ground truth image labels, shape [N].
        y_scores_image (np.ndarray): Predicted image scores, shape [N].
        y_true_pixel (np.ndarray): Ground truth pixel masks, shape [N, H, W].
        y_scores_pixel (np.ndarray): Predicted pixel anomaly scores, shape [N, H, W].

    Returns:
        dict: Dictionary containing the computed metrics.
    """
    if len(y_true_image) != len(y_scores_image):
        raise ValueError(
            f"Image-level y_true and y_scores have different lengths: "
            f"{len(y_true_image)} vs {len(y_scores_image)}"
        )

    if y_true_pixel.shape != y_scores_pixel.shape:
        raise ValueError(
            f"Pixel-level y_true and y_scores have different shapes: "
            f"{y_true_pixel.shape} vs {y_scores_pixel.shape}"
        )

    # Image-level metrics; F1 binarizes the scores at a fixed 0.5 threshold.
    image_ap = average_precision_score(y_true_image, y_scores_image)
    image_auroc = roc_auc_score(y_true_image, y_scores_image)
    y_pred_image = (y_scores_image >= 0.5).astype(int)
    image_f1 = f1_score(y_true_image, y_pred_image)

    # Pixel-level metrics over the flattened masks and score maps.
    pixel_ap = average_precision_score(y_true_pixel.flatten(), y_scores_pixel.flatten())
    pixel_auroc = roc_auc_score(y_true_pixel.flatten(), y_scores_pixel.flatten())
    pixel_aupro = compute_aupro(y_true_pixel, y_scores_pixel)
    y_pred_pixel = (y_scores_pixel >= 0.5).astype(int)
    pixel_f1 = f1_score(y_true_pixel.flatten(), y_pred_pixel.flatten())

    # Leaderboard score: equal-weight average of image/pixel AUROC and F1.
    leaderboard_score = (
        0.25 * image_auroc +
        0.25 * image_f1 +
        0.25 * pixel_auroc +
        0.25 * pixel_f1
    )

    metrics = {
        "image_metrics": {
            "image_ap": round(float(image_ap), 4),
            "image_auroc": round(float(image_auroc), 4),
            "image_f1": round(float(image_f1), 4)
        },
        "pixel_metrics": {
            "pixel_ap": round(float(pixel_ap), 4),
            "pixel_aupro": round(float(pixel_aupro), 4),
            "pixel_auroc": round(float(pixel_auroc), 4),
            "pixel_f1": round(float(pixel_f1), 4)
        },
        "overall_metric": {
            "leaderboard_score": round(float(leaderboard_score), 4)
        }
    }

    return metrics


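# Usage sketch for compute_metrics (hypothetical arrays, illustration only):
#
#   y_true_image = np.array([0, 1, 1, 0])
#   y_scores_image = np.array([0.1, 0.9, 0.7, 0.3])
#   y_true_pixel = np.random.randint(0, 2, size=(4, 32, 32))
#   y_scores_pixel = np.random.rand(4, 32, 32)
#   compute_metrics(y_true_image, y_scores_image, y_true_pixel, y_scores_pixel)
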
def get_class_name(image_path, source_dir):
    """
    Extract the class name from the image path.

    Args:
        image_path (str): Path to the image file.
        source_dir (str): Root source directory.

    Returns:
        str: Class name.
    """
    rel_path = os.path.relpath(image_path, source_dir)
    parts = rel_path.split(os.sep)
    if len(parts) < 2:
        raise ValueError(f"Unexpected image path format: {image_path}")
    class_name = parts[0]
    return class_name


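# Example for get_class_name (hypothetical path): with source_dir="./data" and
# image_path="./data/bottle/test/broken_large/000.png", the first component of
# the relative path is returned, i.e. "bottle".
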
def main():
    SEED = 41
    set_seed(SEED)

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    source_dir = "./data"
    output_scores_dir = "./output_scores"
    split = DatasetSplit.TEST
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    logging.info("Initializing the dataset and dataloader...")

    dataset = AllClassesDataset(
        source=source_dir,
        split=split,
    )
    # The loop below treats each batch as a single image (.item(), [0], squeeze),
    # so the loader must use batch_size=1.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

logging.info("Initializing the anomaly detector...")
|
|
|
|
detector = AnomalyDetector(device=device)
|
|
|
|
|
|
dump_scores = DumpScores(output_dir=output_scores_dir)
|
|
|
|
logging.info("Starting anomaly detection inference...")
|
|
|
|
classes = dataset.get_all_class_names()
|
|
metrics_data = {cls: {
|
|
"y_true_image": [],
|
|
"y_scores_image": [],
|
|
"y_true_pixel": [],
|
|
"y_scores_pixel": []
|
|
} for cls in classes}
|
|
|
|
|
|
    for batch_idx, batch in enumerate(dataloader):
        image = batch['image'].squeeze(0)
        mask = batch['mask'].squeeze(1).numpy()
        image_label = batch['is_anomaly'].item()
        image_path = batch['image_path'][0]

        try:
            class_name = get_class_name(image_path, source_dir)
        except ValueError as e:
            logging.error(f"Error extracting class name: {e}")
            continue

        # Image-level score and raw anomaly map from the detector.
        image_score, anomaly_map = detector.extract_features(image, "all")

        # Pixel-level score map, resized to 224x224 with bilinear interpolation.
        pixel_score = detector.compute_pixel_score(anomaly_map).squeeze()
        pixel_score_tensor = torch.from_numpy(pixel_score).float().unsqueeze(0).unsqueeze(0).to(device)
        pixel_score = F.interpolate(
            pixel_score_tensor,
            size=(224, 224),
            mode='bilinear',
            align_corners=False
        ).squeeze(0).cpu().numpy()

        # Accumulate per-class ground truth and predictions for metric computation.
        metrics_data[class_name]["y_true_image"].append(image_label)
        metrics_data[class_name]["y_scores_image"].append(image_score)
        metrics_data[class_name]["y_true_pixel"].append(mask)
        metrics_data[class_name]["y_scores_pixel"].append(pixel_score)

        dump_scores.save_scores([image_path], [image_score], [pixel_score])

        logging.info(f"[{batch_idx + 1}/{len(dataloader)}] Processed image: {image_path}")
        logging.info(f"Image-level score: {image_score:.4f}")
        logging.info(f"Pixel-level mean score: {pixel_score.mean():.4f}")

logging.info("Anomaly detection inference completed. Computing metrics...")
|
|
|
|
|
|
classes_metrics = {}
|
|
|
|
for cls in classes:
|
|
y_true_image = np.array(metrics_data[cls]["y_true_image"])
|
|
y_scores_image = np.array(metrics_data[cls]["y_scores_image"])
|
|
y_true_pixel = np.array(metrics_data[cls]["y_true_pixel"])
|
|
y_scores_pixel = np.array(metrics_data[cls]["y_scores_pixel"])
|
|
|
|
|
|
if len(y_true_image) == 0:
|
|
logging.warning(f"No samples found for class {cls}. Skipping metric computation.")
|
|
continue
|
|
|
|
try:
|
|
metrics = compute_metrics(y_true_image, y_scores_image, y_true_pixel, y_scores_pixel)
|
|
classes_metrics[cls] = metrics
|
|
logging.info(f"Metrics computed for class: {cls}")
|
|
except Exception as e:
|
|
logging.error(f"Failed to compute metrics for class {cls}: {e}")
|
|
|
|
|
|
os.makedirs(output_scores_dir, exist_ok=True)
|
|
metrics_json_path = os.path.join(output_scores_dir, "metrics.json")
|
|
try:
|
|
with open(metrics_json_path, "w") as f:
|
|
json.dump(classes_metrics, f, indent=4)
|
|
logging.info(f"Metrics successfully saved to {metrics_json_path}")
|
|
except Exception as e:
|
|
logging.error(f"Failed to save metrics to {metrics_json_path}: {e}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|
|
|