# SOWA/src/models/anomaly_clip_module.py
from typing import Any, Dict, List, Tuple
import time
import numpy as np
from tabulate import tabulate
from scipy.ndimage import gaussian_filter
import torch
from torch import Tensor
import torch.nn.functional as F
from lightning import LightningModule
from torchmetrics import MaxMetric, MetricCollection
from torchmetrics.classification import Accuracy, MulticlassAUROC
from src.models.components.functional import cosine_similarity_torch
from src.utils.metrics import image_level_metrics, pixel_level_metrics
from src.utils import pylogger
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
class AnomalyCLIPModule(LightningModule):
"""
LightningModule for training an anomaly detection model using features extracted by a CLIP model.
Attributes:
net (nn.Module): The core model which contains the layers for feature extraction and processing.
loss_focal (nn.Module): Focal loss function, helpful for handling class imbalance in datasets.
loss_dice (nn.Module): Dice loss function, typically used for segmentation tasks.
loss_ce (nn.Module): Cross-entropy loss for classification tasks.
"""
def __init__(
self,
net: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler,
        loss: Dict[str, torch.nn.Module],
enable_validation: bool,
compile: bool,
**kwargs,
) -> None:
        """Initialize an `AnomalyCLIPModule`.

        :param net: The model to train.
        :param optimizer: The optimizer to use for training.
        :param scheduler: The learning rate scheduler to use for training.
        :param loss: The loss functions (cross-entropy, focal, dice) to use for training.
        :param enable_validation: Whether to run the validation loop.
        :param compile: Whether to compile the model with `torch.compile`.
        :param kwargs: Additional keyword arguments (e.g. `k_shot`, `filter`), stored in `self.hparams`.
        """
        super().__init__()
self.save_hyperparameters(logger=False, ignore=['net'])
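        # `save_hyperparameters` also captures the extra **kwargs (e.g. `k_shot`, `filter`),
        # which are read later through `self.hparams`.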
self.net = net
self.optimizer = optimizer
self.scheduler = scheduler
        # Loss functions: cross-entropy for image-level classification,
        # focal + dice for pixel-level segmentation of the anomaly maps.
self.loss_ce = loss["cross_entropy"]
self.loss_focal = loss["focal"]
self.loss_dice = loss["dice"]
metrics = MetricCollection(
{
"acc": Accuracy(task="multiclass", num_classes=2),
"auroc": MulticlassAUROC(num_classes=2),
}
)
self.train_metrics = metrics.clone(prefix="train/")
self.val_metrics = metrics.clone(prefix="val/")
self.test_metrics = metrics.clone(prefix="test/")
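        # `MaxMetric` keeps the best value it has seen; used to track the best
        # validation accuracy and AUROC across epochs.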
self.val_acc_best = MaxMetric()
self.val_auroc_best = MaxMetric()
def forward(self, images: Tensor, cls_name: str) -> Tuple:
"""
Perform a forward pass through the model `self.net`.
:param x: A tensor of images.
:return: A tuple containing the model's outputs.
"""
return self.net(images, cls_name)
def on_train_start(self) -> None:
"""Lightning hook that is called when training begins."""
        # By default Lightning runs validation-step sanity checks before training starts,
        # so it's worth making sure validation metrics don't store results from these checks.
self.val_metrics.reset()
self.val_acc_best.reset()
self.val_auroc_best.reset()
def on_train_batch_start(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Hook called on the train batch start event."""
        # Log the learning rate at the start of each training batch.
current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
self.log('learning_rate', current_lr, on_step=True, on_epoch=False, prog_bar=True)
def on_train_epoch_end(self) -> None:
"""Lightning hook that is called when training epoch end."""
# by default lightning executes validation step sanity checks before training end.
self.train_metrics.reset()
def model_step(self, batch: Dict[str, Any]) -> Tuple:
"""
Shared logic for training, validation, and testing using vectorized operations.
:param batch: A dictionary containing the input data.
:return: A tuple containing the model's predictions and targets.
"""
images, masks, cls_name, labels = batch["image"], batch["image_mask"], batch["cls_name"], batch["anomaly"]
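        # `self.net` returns a 6-tuple; only the per-layer anomaly maps and the
        # image-level text probabilities are used here.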
_, _, _, _, anomaly_maps, text_probs = self.forward(images, cls_name)
return text_probs, labels, anomaly_maps, masks
def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
"""
Perform a single training step on a batch of data from the training set.
:param batch: A batch of data (a dictionary) containing the input tensor of images and target labels.
:param batch_idx: The index of the current batch.
:return: A dictionary containing the computed losses.
"""
logits, labels, anomaly_maps, masks = self.model_step(batch)
loss_ce = self.loss_ce(logits, labels)
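        # Binarize the ground-truth masks and drop the channel dimension.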
masks = torch.where(masks > 0.0, 1, 0).squeeze(1)
loss_focal = 0
loss_dice = 0
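        # Accumulate the segmentation losses over the per-layer anomaly maps: focal loss on
        # the two-channel map, dice loss on the anomalous channel (1) against the mask and
        # on the normal channel (0) against the inverted mask.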
for i in range(len(anomaly_maps)):
loss_focal += self.loss_focal(anomaly_maps[i], masks)
loss_dice += self.loss_dice(anomaly_maps[i][:, 1, :, :], masks)
loss_dice += self.loss_dice(anomaly_maps[i][:, 0, :, :], 1 - masks)
loss = loss_ce + loss_focal + loss_dice
self.log("train/loss_ce", loss_ce, on_step=True, on_epoch=True, prog_bar=True)
self.log("train/loss_focal", loss_focal, on_step=True, on_epoch=True, prog_bar=True)
self.log("train/loss_dice", loss_dice, on_step=True, on_epoch=True, prog_bar=True)
self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
self.log_dict(self.train_metrics(logits, labels), on_step=False, on_epoch=True, prog_bar=True)
return {
"loss_ce": loss_ce,
"loss_focal": loss_focal,
"loss_dice": loss_dice,
"loss": loss,
}
def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> None:
"""
Perform a single validation step on a batch of data from the validation set.
:param batch: A batch of data (a dictionary) containing the input tensor of images and target labels.
:param batch_idx: The index of the current batch.
"""
if not self.hparams.enable_validation:
return None
logits, labels, anomaly_maps, masks = self.model_step(batch)
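        # Mirror the loss computation from `training_step`.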
loss_ce = self.loss_ce(logits, labels)
masks = torch.where(masks > 0.0, 1, 0).squeeze(1)
loss_focal = 0
loss_dice = 0
for i in range(len(anomaly_maps)):
loss_focal += self.loss_focal(anomaly_maps[i], masks)
loss_dice += self.loss_dice(anomaly_maps[i][:, 1, :, :], masks)
loss_dice += self.loss_dice(anomaly_maps[i][:, 0, :, :], 1 - masks)
loss = loss_ce + loss_focal + loss_dice
self.log("val/loss_ce", loss_ce, on_step=True, on_epoch=True, prog_bar=True)
self.log("val/loss_focal", loss_focal, on_step=True, on_epoch=True, prog_bar=True)
self.log("val/loss_dice", loss_dice, on_step=True, on_epoch=True, prog_bar=True)
self.log("val/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
self.log_dict(self.val_metrics(logits, labels), on_step=False, on_epoch=True, prog_bar=True)
def on_validation_epoch_end(self) -> None:
"""Lightning hook that is called when a validation epoch ends."""
# Get current validation metrics
if not self.hparams.enable_validation:
return None
current_metrics = self.val_metrics.compute()
current_acc = current_metrics["val/acc"]
current_auroc = current_metrics["val/auroc"]
# update best so far val metrics
self.val_acc_best.update(current_acc)
self.val_auroc_best.update(current_auroc)
        # Log the best scores as values via `.compute()` instead of as metric objects;
        # otherwise Lightning would reset the metrics after each epoch.
        self.log("val/val_acc_best", self.val_acc_best.compute(), on_epoch=True, prog_bar=True)
        self.log("val/val_auroc_best", self.val_auroc_best.compute(), on_epoch=True, prog_bar=True)
# reset
self.val_metrics.reset()
def kshot_step(self, kshot_batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
"""Compute features for the k-shot dataset.
:param kshot_batch: A batch of data from the k-shot dataset.
:param batch_idx: The index of the current batch.
:return: A dictionary containing the computed features.
"""
kshot_images = kshot_batch['image'].squeeze(0) # Shape: [1, k_shot, C, H, W] -> [k_shot, C, H, W]
cls_name = kshot_batch['cls_name'][0] # [str] -> str
k_shot = kshot_images.shape[0]
classnames = kshot_batch['cls_name'] * k_shot # [str] -> [k_shot * str]
_, _, kshot_patches, _, _, _ = self.forward(kshot_images, classnames) # List[(k_shot, L, C)]
return {cls_name: kshot_patches}
def on_test_start(self) -> None:
"""Lightning hook that is called when test begins."""
self.results = {}
self.mem_features = {}
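        # `results` accumulates per-class predictions for the epoch-end metrics;
        # `mem_features` is the per-class memory bank of k-shot patch features.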
device = self.device if hasattr(self, 'device') else torch.device('cpu')
if self.hparams.k_shot:
for batch_idx, kshot_batch in enumerate(self.trainer.datamodule.kshot_dataloader()):
kshot_batch = {k: (v.to(device) if isinstance(v, Tensor) else v) for k, v in kshot_batch.items()}
self.mem_features.update(self.kshot_step(kshot_batch, batch_idx))
def kshot_anomaly(self, patch_tokens: List[Tensor], cls_name: str) -> Tensor:
"""
Compute the k-shot anomaly maps for a given set of patch tokens.
:param patch_tokens: List of tensors, where each tensor represents the patch tokens for a single image.
Each tensor has shape (L, C) where L is the number of patches and C is the feature dimension.
:param cls_name: String representing the class name used to retrieve the corresponding k-shot patches.
:return: A tensor representing the aggregated k-shot anomaly map for the given set of patch tokens.
The returned tensor has shape (1, H, W) where H and W are the dimensions of the resized anomaly map.
"""
kshot_patch = self.mem_features[cls_name]
B, L, C = kshot_patch[0].shape
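        # Assume a square patch grid, so the token count L factors as H * H.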
H = int(L ** 0.5)
anomaly_maps_kshot = []
for idx, patch in enumerate(patch_tokens):
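            # Flatten the k-shot memory to (k_shot * L, C) and score each query patch by its
            # cosine distance to the nearest memory patch (min over the memory bank).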
kshot_patch_expand = kshot_patch[idx].reshape(B * L, C)
cosine_similarity = cosine_similarity_torch(kshot_patch_expand, patch)
cosine_distance = (1. - cosine_similarity).min(dim=0)[0]
anomaly_map_kshot = cosine_distance.reshape(1, 1, H, H)
anomaly_map_kshot = F.interpolate(anomaly_map_kshot,
size=self.net.image_size, mode='bilinear', align_corners=True)
anomaly_maps_kshot.append(anomaly_map_kshot[0])
anomaly_maps_kshot = torch.stack(anomaly_maps_kshot).sum(dim=0)
return anomaly_maps_kshot
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Any]:
"""
Perform a single test step on a batch of data from the test set.
:param batch: A batch of data (a dictionary) containing the input tensor of images and target labels.
:param batch_idx: The index of the current batch.
:return: A dictionary containing the anomaly maps.
"""
images, masks, classnames, labels = batch["image"], batch["image_mask"], batch["cls_name"], batch["anomaly"]
_, _, patches, _, anomaly_maps, text_probs = self.forward(images, classnames)
self.log_dict(self.test_metrics(text_probs, labels), on_step=False, on_epoch=True, prog_bar=True)
masks = torch.where(masks > 0.5, 1, 0).squeeze(1)
batch_size = images.size(0)
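        # Stack the per-layer maps, keep the anomalous channel (index 1), and sum over layers:
        # (n_layers, B, 2, H, W) -> (B, H, W).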
anomaly_maps = torch.stack(anomaly_maps)[:, :, 1, :, :].sum(dim=0)
        if self.hparams.filter:
            # Optionally smooth each anomaly map with a Gaussian filter on the CPU.
            anomaly_maps = torch.stack(
                [
                    torch.from_numpy(gaussian_filter(anomaly_map, sigma=4))
                    for anomaly_map in anomaly_maps.detach().cpu()
                ],
                dim=0,
            )
updated_anomaly_maps = []
for i in range(batch_size):
cls_key = classnames[i]
if cls_key not in self.results:
self.results[cls_key] = {
'imgs_masks': [],
'anomaly_maps': [],
'gt_sp': [],
'pr_sp': []
}
mask = masks[i]
anomaly_map = anomaly_maps[i]
if self.hparams.k_shot:
patch_tokens = [patch[i] for patch in patches]
anomaly_maps_kshot = self.kshot_anomaly(patch_tokens, cls_key)
anomaly_map = anomaly_map.cpu() + anomaly_maps_kshot.cpu()
self.results[cls_key]['imgs_masks'].append(mask) # Store masks for pixel-level metrics
self.results[cls_key]['anomaly_maps'].append(anomaly_map)
self.results[cls_key]['gt_sp'].append(labels[i].cpu().numpy().item()) # Store labels for image-level metrics
self.results[cls_key]['pr_sp'].append(text_probs[i][1].detach().cpu().numpy().item()) # Store predicted probabilities
updated_anomaly_maps.append(anomaly_map)
updated_anomaly_maps = torch.stack(updated_anomaly_maps, dim=0).squeeze(1)
return {
"anomaly_maps": updated_anomaly_maps,
"abnormal": text_probs,
}
def on_test_epoch_end(self) -> None:
"""
Lightning hook that is called when a test epoch ends.
        This method processes the results collected during the test epoch, computes the
        image-level metrics (AUROC, AP, F1-max) and the pixel-level metrics (AUROC, AP,
        F1-max, AUPRO), and logs them. It also formats the metrics into a table and logs
        the table. Finally, it resets the results to ensure a clean state for the next epoch.
Operations performed in this method:
1. Start a timer to measure the duration of result processing.
2. Initialize lists and a table for storing metrics for each class.
3. Loop over the results for each class:
a. Convert the stored masks and anomaly maps to numpy arrays.
        b. Compute the image-level and pixel-level metrics for each class.
c. Append the computed metrics to the corresponding lists and the table.
4. Compute the mean values for each metric across all classes.
5. Log the mean metrics.
6. Format the metrics into a table using the `tabulate` library and log the table.
7. Log the duration of result processing.
8. Reset the results to ensure a clean state for the next epoch.
"""
start_time = time.time()
log.info(f"Processing test results...")
tables = []
image_auroc_list = []
image_ap_list = []
image_f1_list = []
pixel_auroc_list = []
pixel_ap_list = []
pixel_f1_list = []
pixel_aupro_list = []
for cls_key, data in self.results.items():
table = [cls_key]
data['imgs_masks'] = torch.stack(data['imgs_masks']).detach().cpu().numpy()
data['anomaly_maps'] = torch.stack(data['anomaly_maps']).detach().cpu().numpy()
            # Few-shot mode: derive an image-level score from the max of each anomaly map,
            # min-max normalize it over the class, and average it with the text-based score.
if self.hparams.k_shot:
pr_sp_tmp = np.array([np.max(anomaly_map) for anomaly_map in data["anomaly_maps"]])
pr_sp_tmp = (pr_sp_tmp - pr_sp_tmp.min()) / (pr_sp_tmp.max() - pr_sp_tmp.min())
pr_sp = 0.5 * (np.array(data["pr_sp"]) + pr_sp_tmp)
data["pr_sp"] = pr_sp
image_auroc = image_level_metrics(self.results, cls_key, "image-auroc")
image_ap = image_level_metrics(self.results, cls_key, "image-ap")
image_f1 = image_level_metrics(self.results, cls_key, "image-f1-max")
pixel_auroc = pixel_level_metrics(self.results, cls_key, "pixel-auroc")
pixel_ap = pixel_level_metrics(self.results, cls_key, "pixel-ap")
pixel_f1 = pixel_level_metrics(self.results, cls_key, "pixel-f1-max")
pixel_aupro = pixel_level_metrics(self.results, cls_key, "pixel-aupro")
table.append(str(np.round(image_auroc * 100, decimals=1)))
table.append(str(np.round(image_ap * 100, decimals=1)))
table.append(str(np.round(image_f1 * 100, decimals=1)))
table.append(str(np.round(pixel_auroc * 100, decimals=1)))
table.append(str(np.round(pixel_ap * 100, decimals=1)))
table.append(str(np.round(pixel_f1 * 100, decimals=1)))
table.append(str(np.round(pixel_aupro * 100, decimals=1)))
tables.append(table)
image_auroc_list.append(image_auroc)
image_ap_list.append(image_ap)
image_f1_list.append(image_f1)
pixel_auroc_list.append(pixel_auroc)
pixel_ap_list.append(pixel_ap)
pixel_f1_list.append(pixel_f1)
pixel_aupro_list.append(pixel_aupro)
mean_image_auroc = np.mean(image_auroc_list)
mean_image_ap = np.mean(image_ap_list)
mean_image_f1 = np.mean(image_f1_list)
mean_pixel_auroc = np.mean(pixel_auroc_list)
mean_pixel_ap = np.mean(pixel_ap_list)
mean_pixel_f1 = np.mean(pixel_f1_list)
mean_pixel_aupro = np.mean(pixel_aupro_list)
objective = (mean_image_auroc + mean_image_f1 + mean_image_ap + mean_pixel_auroc + mean_pixel_ap + mean_pixel_f1 + mean_pixel_aupro) / 7
tables.append(
[
'mean',
str(np.round(mean_image_auroc * 100, decimals=1)),
str(np.round(mean_image_ap * 100, decimals=1)),
str(np.round(mean_image_f1 * 100, decimals=1)),
str(np.round(mean_pixel_auroc * 100, decimals=1)),
str(np.round(mean_pixel_ap * 100, decimals=1)),
str(np.round(mean_pixel_f1 * 100, decimals=1)),
str(np.round(mean_pixel_aupro * 100, decimals=1)),
]
)
# Log the mean metrics
self.log("test/image_auroc", mean_image_auroc, on_epoch=True, prog_bar=True)
self.log("test/image_ap", mean_image_ap, on_epoch=True, prog_bar=True)
self.log("test/image_f1", mean_image_f1, on_epoch=True, prog_bar=True)
self.log("test/pixel_auroc", mean_pixel_auroc, on_epoch=True, prog_bar=True)
self.log("test/pixel_ap", mean_pixel_ap, on_epoch=True, prog_bar=True)
self.log("test/pixel_f1", mean_pixel_f1, on_epoch=True, prog_bar=True)
self.log("test/pixel_aupro", mean_pixel_aupro, on_epoch=True, prog_bar=True)
self.log("test/objective", objective, on_epoch=True, prog_bar=True)
        metrics = tabulate(
            tables,
            headers=[
                "objects",
                "image_auroc",
                "image_ap",
                "image_f1",
                "pixel_auroc",
                "pixel_ap",
                "pixel_f1",
                "pixel_aupro",
            ],
            tablefmt="pipe",
        )
end_time = time.time()
duration = end_time - start_time
log.info(f"Processed test results in {duration:.2f} seconds.")
log.info(f"\n{metrics}")
# Reset results to clean state after test
self.results = {}
def setup(self, stage: str) -> None:
"""Lightning hook that is called at the beginning of fit (train + validate), validate,
test, or predict.
This is a good hook when you need to build models dynamically or adjust something about
them. This hook is called on every process when using DDP.
:param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
"""
if self.hparams.compile and stage == "fit":
self.net = torch.compile(self.net)
def configure_optimizers(self) -> Dict[str, Any]:
"""Choose what optimizers and learning-rate schedulers to use in your optimization.
Normally you'd need one. But in the case of GANs or similar you might have multiple.
Examples:
https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
:return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
"""
optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
if self.hparams.scheduler is not None:
scheduler = self.hparams.scheduler(optimizer=optimizer)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": "train/loss",
"interval": "epoch",
"frequency": 1,
},
}
return {"optimizer": optimizer}