# main.py
import json
import logging
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import average_precision_score, roc_auc_score, f1_score
from torch.utils.data import DataLoader

from datasets.all_classes_dataset import AllClassesDataset, DatasetSplit
from models.anomaly_detector import AnomalyDetector
from utils.dump_scores import DumpScores

def set_seed(seed: int):
"""
Set the seed for reproducibility across various libraries.
Args:
seed (int): The seed value to be set.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # For multi-GPU setups
# Ensure deterministic behavior in PyTorch
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# For DataLoader workers
os.environ['PYTHONHASHSEED'] = str(seed)

def worker_init_fn(worker_id):
"""
Initialize the seed for each DataLoader worker to ensure reproducibility.
Args:
worker_id (int): The worker ID.
"""
seed = torch.initial_seed()
np.random.seed(seed % 2**32)
random.seed(seed % 2**32)
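
# Illustrative wiring (a sketch, not used verbatim below): with num_workers > 0,
# each worker process re-seeds itself via worker_init_fn, and a seeded
# torch.Generator keeps any shuffling reproducible. `dataset` is assumed to exist.
#
#   g = torch.Generator()
#   g.manual_seed(41)
#   loader = DataLoader(dataset, batch_size=1, shuffle=True,
#                       num_workers=4, worker_init_fn=worker_init_fn,
#                       generator=g)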

def compute_aupro(y_true_pixel, y_scores_pixel, num_thresholds=50):
    """
    Approximate the Area Under the Per-Region Overlap Curve (AUPRO).

    Note: this is a simplified approximation that averages whole-image IoU
    over a sweep of thresholds; the reference AUPRO integrates per-region
    (connected-component) overlap up to a fixed false-positive rate.

    Args:
        y_true_pixel (np.ndarray): Ground truth binary masks, shape [N, H, W]
        y_scores_pixel (np.ndarray): Predicted anomaly scores, shape [N, H, W]
        num_thresholds (int): Number of thresholds to evaluate.
    Returns:
        float: Approximate AUPRO score.
    """
# Define thresholds
thresholds = np.linspace(0, 1, num_thresholds)
# Initialize list to store overlaps
overlaps = []
for thresh in thresholds:
# Binarize predictions
y_pred = (y_scores_pixel >= thresh).astype(int)
# Compute Intersection over Union (IoU) for each sample
ious = []
for gt, pred in zip(y_true_pixel, y_pred):
intersection = np.logical_and(gt, pred).sum()
union = np.logical_or(gt, pred).sum()
if union == 0:
iou = 1.0 # If both gt and pred are all zeros
else:
iou = intersection / union
ious.append(iou)
# Average IoU over all samples
avg_iou = np.mean(ious)
overlaps.append(avg_iou)
    # Area under the overlap curve, normalized by the length of the threshold
    # range (thresholds span [0, 1], so the normalizer is exactly 1.0)
    aupro = np.trapz(overlaps, thresholds) / (thresholds[-1] - thresholds[0])
return aupro
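
# Quick sanity check (illustrative, synthetic 4x4 masks): a predictor whose
# scores exactly equal the binary ground truth should score close to 1.0
# under this IoU-based approximation.
#
#   gt = np.zeros((1, 4, 4))
#   gt[0, :2, :] = 1.0
#   assert compute_aupro(gt, gt.copy()) > 0.95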

def compute_metrics(y_true_image, y_scores_image, y_true_pixel, y_scores_pixel):
"""
Compute the required metrics based on true labels and predicted scores.
Args:
y_true_image (np.ndarray): Ground truth image labels, shape [N]
y_scores_image (np.ndarray): Predicted image scores, shape [N]
y_true_pixel (np.ndarray): Ground truth pixel masks, shape [N, H, W]
y_scores_pixel (np.ndarray): Predicted pixel anomaly scores, shape [N, H, W]
Returns:
dict: Dictionary containing computed metrics.
"""
# Check image-level consistency
if len(y_true_image) != len(y_scores_image):
raise ValueError(f"Image-level y_true and y_scores have different lengths: {len(y_true_image)} vs {len(y_scores_image)}")
# Check pixel-level consistency
if y_true_pixel.shape != y_scores_pixel.shape:
raise ValueError(f"Pixel-level y_true and y_scores have different shapes: {y_true_pixel.shape} vs {y_scores_pixel.shape}")
# Image-level Metrics
image_ap = average_precision_score(y_true_image, y_scores_image)
image_auroc = roc_auc_score(y_true_image, y_scores_image)
    # Binarize with a fixed 0.5 threshold (assumes scores normalized to [0, 1])
    y_pred_image = (y_scores_image >= 0.5).astype(int)
image_f1 = f1_score(y_true_image, y_pred_image)
# Pixel-level Metrics
pixel_ap = average_precision_score(y_true_pixel.flatten(), y_scores_pixel.flatten())
pixel_auroc = roc_auc_score(y_true_pixel.flatten(), y_scores_pixel.flatten())
pixel_aupro = compute_aupro(y_true_pixel, y_scores_pixel)
    # Same fixed 0.5 threshold at the pixel level
    y_pred_pixel = (y_scores_pixel >= 0.5).astype(int)
pixel_f1 = f1_score(y_true_pixel.flatten(), y_pred_pixel.flatten())
    # Leaderboard score: equally weighted average of the four headline
    # metrics (example weights; adjust to your requirements)
leaderboard_score = (
0.25 * image_auroc +
0.25 * image_f1 +
0.25 * pixel_auroc +
0.25 * pixel_f1
)
metrics = {
"image_metrics": {
"image_ap": round(float(image_ap), 4),
"image_auroc": round(float(image_auroc), 4),
"image_f1": round(float(image_f1), 4)
},
"pixel_metrics": {
"pixel_ap": round(float(pixel_ap), 4),
"pixel_aupro": round(float(pixel_aupro), 4),
"pixel_auroc": round(float(pixel_auroc), 4),
"pixel_f1": round(float(pixel_f1), 4)
},
"overall_metric": {
"leaderboard_score": round(float(leaderboard_score), 4)
}
}
return metrics
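
# Minimal usage sketch (synthetic data; shapes follow the docstring above):
#
#   y_true_img = np.array([0, 1, 0, 1])
#   y_scores_img = np.array([0.1, 0.9, 0.2, 0.8])
#   y_true_px = np.random.randint(0, 2, size=(4, 8, 8))
#   y_scores_px = np.random.rand(4, 8, 8)
#   print(json.dumps(compute_metrics(y_true_img, y_scores_img,
#                                    y_true_px, y_scores_px), indent=2))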

def get_class_name(image_path, source_dir):
"""
Extract the class name from the image path.
Args:
image_path (str): Path to the image file.
source_dir (str): Root source directory.
Returns:
str: Class name.
"""
# Example image_path: "./data/pill/test/broken/image1.png"
rel_path = os.path.relpath(image_path, source_dir) # "pill/test/broken/image1.png"
parts = rel_path.split(os.sep)
if len(parts) < 2:
raise ValueError(f"Unexpected image path format: {image_path}")
class_name = parts[0] # "pill"
return class_name

def main():
SEED = 41 # You can choose any integer value
set_seed(SEED)
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Configuration
source_dir = "./data"
output_scores_dir = "./output_scores"
split = DatasetSplit.TEST # Use the Enum instead of string
device = "cuda:0" if torch.cuda.is_available() else "cpu"
logging.info("Initializing the dataset and dataloader...")
    # Initialize the dataset with AllClassesDataset; masks are assumed to come
    # back at full resolution (224x224), matching the upsampled pixel scores below
dataset = AllClassesDataset(
source=source_dir,
split=split,
        # output_size=17  # Uncomment to get masks at the 17x17 anomaly_map resolution
)
    # batch_size must stay at 1: the loop below calls .item() on per-sample
    # fields and indexes batch['image_path'][0]. worker_init_fn is a no-op at
    # num_workers=0 but keeps any future multi-worker run reproducible.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0,
                            worker_init_fn=worker_init_fn)
logging.info("Initializing the anomaly detector...")
# Initialize anomaly detector
detector = AnomalyDetector(device=device)
# Initialize DumpScores
dump_scores = DumpScores(output_dir=output_scores_dir)
logging.info("Starting anomaly detection inference...")
# Initialize containers for metrics
classes = dataset.get_all_class_names()
metrics_data = {cls: {
"y_true_image": [],
"y_scores_image": [],
"y_true_pixel": [],
"y_scores_pixel": []
} for cls in classes}
# Iterate through the dataset
for batch_idx, batch in enumerate(dataloader):
        image = batch['image'].squeeze(0)  # Shape: [3, H, W]
        # Remove all singleton dimensions to get [H, W] (224x224 here)
        mask = batch['mask'].squeeze().numpy()
        image_label = batch['is_anomaly'].item()  # 1 or 0
        image_path = batch['image_path'][0]  # batch_size is 1
# Extract class name from image_path
try:
class_name = get_class_name(image_path, source_dir)
except ValueError as e:
logging.error(f"Error extracting class name: {e}")
continue # Skip this sample
# Extract features and compute scores using GLASS
image_score, anomaly_map = detector.extract_features(image, "all")
# Compute pixel-level anomaly score (already normalized)
pixel_score = detector.compute_pixel_score(anomaly_map).squeeze()
        pixel_score_tensor = (
            torch.from_numpy(pixel_score).float().unsqueeze(0).unsqueeze(0).to(device)
        )  # Shape: [1, 1, 17, 17]
        # Upsample pixel_score to (224, 224) to match the ground-truth masks
        # Option 1: PyTorch interpolation
        pixel_score = F.interpolate(
            pixel_score_tensor,  # [1, 1, 17, 17]
            size=(224, 224),
            mode='bilinear',
            align_corners=False
        ).squeeze().cpu().numpy()  # Remove all singleton dimensions -> [224, 224]
        # Option 2: OpenCV (uncomment and add `import cv2` at the top if preferred)
        # pixel_score_np = pixel_score_tensor.squeeze().cpu().numpy()  # [17, 17]
        # pixel_score = cv2.resize(
        #     pixel_score_np,
        #     dsize=(224, 224),
        #     interpolation=cv2.INTER_LINEAR
        # )
        # Optional: verify the upsampled pixel_score shape
        # if pixel_score.shape != (224, 224):
        #     logging.warning(
        #         f"Upsampled pixel score shape mismatch for image {image_path}: "
        #         f"expected (224, 224), got {pixel_score.shape}")
        #     continue  # Skip this sample
# Append to metrics_data
metrics_data[class_name]["y_true_image"].append(image_label)
metrics_data[class_name]["y_scores_image"].append(image_score)
metrics_data[class_name]["y_true_pixel"].append(mask)
metrics_data[class_name]["y_scores_pixel"].append(pixel_score)
# Save individual image scores
dump_scores.save_scores([image_path], [image_score], [pixel_score])
logging.info(f"[{batch_idx + 1}/{len(dataloader)}] Processed image: {image_path}")
logging.info(f"Image-level score: {image_score:.4f}")
logging.info(f"Pixel-level mean score: {pixel_score.mean():.4f}")
logging.info("Anomaly detection inference completed. Computing metrics...")
# Initialize dictionary to hold metrics per class
classes_metrics = {}
for cls in classes:
y_true_image = np.array(metrics_data[cls]["y_true_image"])
y_scores_image = np.array(metrics_data[cls]["y_scores_image"])
y_true_pixel = np.array(metrics_data[cls]["y_true_pixel"])
y_scores_pixel = np.array(metrics_data[cls]["y_scores_pixel"])
# Check if there are any samples for the class
if len(y_true_image) == 0:
logging.warning(f"No samples found for class {cls}. Skipping metric computation.")
continue
try:
metrics = compute_metrics(y_true_image, y_scores_image, y_true_pixel, y_scores_pixel)
classes_metrics[cls] = metrics
logging.info(f"Metrics computed for class: {cls}")
except Exception as e:
logging.error(f"Failed to compute metrics for class {cls}: {e}")
# Save metrics to JSON
os.makedirs(output_scores_dir, exist_ok=True)
metrics_json_path = os.path.join(output_scores_dir, "metrics.json")
try:
with open(metrics_json_path, "w") as f:
json.dump(classes_metrics, f, indent=4)
logging.info(f"Metrics successfully saved to {metrics_json_path}")
except Exception as e:
logging.error(f"Failed to save metrics to {metrics_json_path}: {e}")

if __name__ == "__main__":
main()