import time
from typing import List, Any, Dict, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from lightning import LightningModule
from scipy.ndimage import gaussian_filter
from tabulate import tabulate
from torch import Tensor
from torchmetrics import MaxMetric, MetricCollection
from torchmetrics.classification import Accuracy, MulticlassAUROC

from src.models.components.functional import cosine_similarity_torch
from src.utils import pylogger
from src.utils.metrics import image_level_metrics, pixel_level_metrics
# Module-level logger shared by this file.
# NOTE(review): `rank_zero_only=True` presumably limits output to the rank-0
# process in multi-device runs — confirm against `src.utils.pylogger`.
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
class AnomalyCLIPModule(LightningModule):
    """LightningModule that trains and evaluates a CLIP-based anomaly detector.

    The wrapped network is called as ``net(images, cls_name)`` and must return
    a 6-tuple; this module reads element 2 (patch features), element 4
    (anomaly maps, one entry per scale/layer) and element 5 (image-level
    class probabilities/logits over two classes).

    Besides the explicit constructor arguments, the module reads two
    hyperparameters that must be supplied through ``**kwargs``:

    * ``k_shot``: truthy to enable the few-shot memory bank at test time.
    * ``filter``: truthy to Gaussian-smooth the test anomaly maps.

    Attributes:
        net (nn.Module): The core model used for feature extraction and scoring.
        loss_ce (nn.Module): Cross-entropy loss on the image-level prediction.
        loss_focal (nn.Module): Focal loss on the anomaly maps.
        loss_dice (nn.Module): Dice loss on each channel of the anomaly maps.
        results (dict): Per-class test accumulators (created in `on_test_start`).
        mem_features (dict): Per-class k-shot patch features (created in
            `on_test_start`).
    """

    # (table column / log suffix, key passed to the metric helper) pairs.
    # Their order fixes both the table column order and the logging order.
    _IMAGE_METRICS = (
        ("image_auroc", "image-auroc"),
        ("image_ap", "image-ap"),
        ("image_f1", "image-f1-max"),
    )
    _PIXEL_METRICS = (
        ("pixel_auroc", "pixel-auroc"),
        ("pixel_ap", "pixel-ap"),
        ("pixel_f1", "pixel-f1-max"),
        ("pixel_aupro", "pixel-aupro"),
    )

    def __init__(
        self,
        net: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
        loss: Union[torch.nn.Module, Any],
        enable_validation: bool,
        compile: bool,
        **kwargs,
    ) -> None:
        """
        Initialize a `AnomalyCLIPModule`.

        :param net: The model to train.
        :param optimizer: Callable that builds the optimizer; it is invoked as
            ``optimizer(params=...)`` in `configure_optimizers`.
        :param scheduler: Callable that builds the LR scheduler (invoked as
            ``scheduler(optimizer=...)``), or ``None`` for no scheduler.
        :param loss: Mapping with the keys ``"cross_entropy"``, ``"focal"`` and
            ``"dice"``, each holding the corresponding loss callable.
        :param enable_validation: Boolean to enable validation.
        :param compile: Boolean to enable `torch.compile` of `net` during fit.
        :param kwargs: Additional hyperparameters (``k_shot`` and ``filter`` are
            read at test time).
        """
        super().__init__()

        # Makes every init argument (kwargs included) available via
        # `self.hparams`; `net` is excluded and kept as a plain attribute.
        self.save_hyperparameters(logger=False, ignore=['net'])

        self.net = net
        self.optimizer = optimizer
        self.scheduler = scheduler

        self.loss_ce = loss["cross_entropy"]
        self.loss_focal = loss["focal"]
        self.loss_dice = loss["dice"]

        # Image-level normal/anomalous classification metrics; one clone per
        # stage so their running state never mixes.
        metrics = MetricCollection(
            {
                "acc": Accuracy(task="multiclass", num_classes=2),
                "auroc": MulticlassAUROC(num_classes=2),
            }
        )
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        self.test_metrics = metrics.clone(prefix="test/")

        # Track the best validation values seen so far.
        self.val_acc_best = MaxMetric()
        self.val_auroc_best = MaxMetric()

    def forward(self, images: Tensor, cls_name: str) -> Tuple:
        """
        Perform a forward pass through the model `self.net`.

        :param images: A tensor of images.
        :param cls_name: Class name(s) forwarded unchanged to the network
            (callers in this module pass the batch's ``cls_name`` entry).
        :return: A tuple containing the model's outputs.
        """
        return self.net(images, cls_name)

    def on_train_start(self) -> None:
        """Lightning hook that is called when training begins.

        Resets the validation metrics and best-value trackers so they start
        from a clean state when training begins.
        """
        self.val_metrics.reset()
        self.val_acc_best.reset()
        self.val_auroc_best.reset()

    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
        """Hook called on the train batch start event; logs the current LR.

        Only the first parameter group of the first optimizer is reported.
        """
        current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
        self.log('learning_rate', current_lr, on_step=True, on_epoch=False, prog_bar=True)

    def on_train_epoch_end(self) -> None:
        """Lightning hook that is called when training epoch end."""
        self.train_metrics.reset()

    def model_step(self, batch: Dict[str, Any]) -> Tuple:
        """
        Shared forward logic for training and validation.

        :param batch: A dictionary with the keys ``"image"``, ``"image_mask"``,
            ``"cls_name"`` and ``"anomaly"``.
        :return: ``(text_probs, labels, anomaly_maps, masks)``.
        """
        images, masks = batch["image"], batch["image_mask"]
        cls_name, labels = batch["cls_name"], batch["anomaly"]
        _, _, _, _, anomaly_maps, text_probs = self.forward(images, cls_name)
        return text_probs, labels, anomaly_maps, masks

    def _compute_losses(self, batch: Dict[str, Any]) -> Tuple[Any, Any, Dict[str, Any]]:
        """Run the model on `batch` and compute all loss terms.

        :param batch: Same dictionary as accepted by `model_step`.
        :return: ``(logits, labels, losses)`` where ``losses`` maps
            ``"loss_ce"``, ``"loss_focal"``, ``"loss_dice"`` and ``"loss"``
            (their sum) to the computed values, in that order.
        """
        logits, labels, anomaly_maps, masks = self.model_step(batch)
        loss_ce = self.loss_ce(logits, labels)

        # Binarise the ground-truth mask and drop its dimension 1
        # (assumes masks are (B, 1, H, W) — TODO confirm).
        masks = torch.where(masks > 0.0, 1, 0).squeeze(1)

        loss_focal = 0
        loss_dice = 0
        for anomaly_map in anomaly_maps:
            loss_focal += self.loss_focal(anomaly_map, masks)
            # Dice is applied to both channels: channel 1 against the mask,
            # channel 0 against its complement.
            loss_dice += self.loss_dice(anomaly_map[:, 1, :, :], masks)
            loss_dice += self.loss_dice(anomaly_map[:, 0, :, :], 1 - masks)

        losses = {
            "loss_ce": loss_ce,
            "loss_focal": loss_focal,
            "loss_dice": loss_dice,
            "loss": loss_ce + loss_focal + loss_dice,
        }
        return logits, labels, losses

    def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        """
        Perform a single training step on a batch of data from the training set.

        :param batch: A batch of data (a dictionary) containing the input tensor of images and target labels.
        :param batch_idx: The index of the current batch.
        :return: A dictionary containing the computed losses; ``"loss"`` is the
            total that Lightning back-propagates.
        """
        logits, labels, losses = self._compute_losses(batch)
        for name, value in losses.items():
            self.log(f"train/{name}", value, on_step=True, on_epoch=True, prog_bar=True)
        self.log_dict(self.train_metrics(logits, labels), on_step=False, on_epoch=True, prog_bar=True)
        return losses

    def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> None:
        """
        Perform a single validation step on a batch of data from the validation set.

        Does nothing when the ``enable_validation`` hyperparameter is falsy.

        :param batch: A batch of data (a dictionary) containing the input tensor of images and target labels.
        :param batch_idx: The index of the current batch.
        """
        if not self.hparams.enable_validation:
            return None

        logits, labels, losses = self._compute_losses(batch)
        for name, value in losses.items():
            self.log(f"val/{name}", value, on_step=True, on_epoch=True, prog_bar=True)
        self.log_dict(self.val_metrics(logits, labels), on_step=False, on_epoch=True, prog_bar=True)

    def on_epoch_end(self) -> None:
        """Hook called at the end of an epoch.

        NOTE(review): intentionally empty; confirm that the installed Lightning
        version still invokes a hook with this name.
        """
        pass

    def on_validation_epoch_end(self) -> None:
        """Lightning hook that is called when a validation epoch ends.

        Updates and logs the best accuracy/AUROC seen so far, then resets the
        validation metrics. Does nothing when validation is disabled.
        """
        if not self.hparams.enable_validation:
            return None

        current_metrics = self.val_metrics.compute()
        self.val_acc_best.update(current_metrics["val/acc"])
        self.val_auroc_best.update(current_metrics["val/auroc"])

        # Log the computed values rather than the metric objects.
        self.log("val/val_acc_best", self.val_acc_best.compute(), on_epoch=True, prog_bar=True)
        self.log("val/val_roc_best", self.val_auroc_best.compute(), on_epoch=True, prog_bar=True)
        self.val_metrics.reset()

    def kshot_step(self, kshot_batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        """Compute features for the k-shot dataset.

        :param kshot_batch: A batch of data from the k-shot dataset. Its
            ``"image"`` entry has dimension 0 squeezed away before use
            (presumably a (1, K, C, H, W) tensor — TODO confirm) and
            ``"cls_name"`` is a sequence whose first item names the class.
        :param batch_idx: The index of the current batch.
        :return: A dictionary mapping the class name to its patch features.
        """
        kshot_images = kshot_batch['image'].squeeze(0)
        cls_name = kshot_batch['cls_name'][0]
        k_shot = kshot_images.shape[0]
        # Repeat the class-name sequence once per reference image.
        classnames = kshot_batch['cls_name'] * k_shot
        _, _, kshot_patches, _, _, _ = self.forward(kshot_images, classnames)
        return {cls_name: kshot_patches}

    def on_test_start(self) -> None:
        """Lightning hook that is called when test begins.

        Creates the per-class result accumulators and, when ``k_shot`` is
        enabled, fills the memory bank from the datamodule's
        ``kshot_dataloader()``.
        """
        self.results = {}
        self.mem_features = {}

        device = self.device if hasattr(self, 'device') else torch.device('cpu')
        if self.hparams.k_shot:
            for batch_idx, kshot_batch in enumerate(self.trainer.datamodule.kshot_dataloader()):
                # This loader is iterated manually, so tensors are moved to
                # the module's device here; non-tensor entries pass through.
                kshot_batch = {
                    key: (value.to(device) if isinstance(value, Tensor) else value)
                    for key, value in kshot_batch.items()
                }
                self.mem_features.update(self.kshot_step(kshot_batch, batch_idx))

    def kshot_anomaly(self, patch_tokens: List[Tensor], cls_name: str) -> Tensor:
        """
        Compute the k-shot anomaly maps for a given set of patch tokens.

        For every feature level, each query patch is scored by its smallest
        cosine distance to any stored reference patch of the class; the
        per-level maps are resized to ``self.net.image_size`` and summed.

        :param patch_tokens: List of tensors, where each tensor represents the patch tokens for a single image.
                             Each tensor has shape (L, C) where L is the number of patches and C is the feature dimension.
        :param cls_name: String representing the class name used to retrieve the corresponding k-shot patches.
        :return: A tensor representing the aggregated k-shot anomaly map for the given set of patch tokens.
                 The returned tensor has shape (1, H, W) where H and W are the dimensions of the resized anomaly map.
        """
        kshot_patch = self.mem_features[cls_name]
        B, L, C = kshot_patch[0].shape
        # The patch grid is treated as square with side sqrt(L).
        H = int(L ** 0.5)

        anomaly_maps_kshot = []
        for idx, patch in enumerate(patch_tokens):
            # Flatten all reference images into one bank of B * L patches.
            kshot_patch_expand = kshot_patch[idx].reshape(B * L, C)
            cosine_similarity = cosine_similarity_torch(kshot_patch_expand, patch)
            # Minimum over dim 0 (presumably the reference-patch axis —
            # verify against `cosine_similarity_torch`).
            cosine_distance = (1. - cosine_similarity).min(dim=0)[0]
            anomaly_map_kshot = cosine_distance.reshape(1, 1, H, H)
            anomaly_map_kshot = F.interpolate(
                anomaly_map_kshot,
                size=self.net.image_size,
                mode='bilinear',
                align_corners=True,
            )
            anomaly_maps_kshot.append(anomaly_map_kshot[0])

        return torch.stack(anomaly_maps_kshot).sum(dim=0)

    def test_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        """
        Perform a single test step on a batch of data from the test set.

        Per-sample masks, anomaly maps and image-level scores are appended to
        ``self.results`` (keyed by class name) for `on_test_epoch_end`.

        :param batch: A batch of data (a dictionary) containing the input tensor of images and target labels.
        :param batch_idx: The index of the current batch.
        :return: A dictionary with the final ``"anomaly_maps"`` and the
            image-level predictions under ``"abnormal"``.
        """
        images, masks = batch["image"], batch["image_mask"]
        classnames, labels = batch["cls_name"], batch["anomaly"]
        _, _, patches, _, anomaly_maps, text_probs = self.forward(images, classnames)
        self.log_dict(self.test_metrics(text_probs, labels), on_step=False, on_epoch=True, prog_bar=True)

        # NOTE: the test mask threshold is 0.5, whereas training/validation
        # binarise at 0.0.
        masks = torch.where(masks > 0.5, 1, 0).squeeze(1)
        batch_size = images.size(0)

        # Keep channel 1 of every map and sum across the list dimension.
        anomaly_maps = torch.stack(anomaly_maps)[:, :, 1, :, :].sum(dim=0)
        if self.hparams.filter:
            # SciPy works on CPU arrays, so the smoothed maps stay on the CPU.
            anomaly_maps = torch.stack(
                [
                    torch.from_numpy(gaussian_filter(anomaly_map.numpy(), sigma=4))
                    for anomaly_map in anomaly_maps.detach().cpu()
                ],
                dim=0,
            )

        updated_anomaly_maps = []
        for i in range(batch_size):
            cls_key = classnames[i]
            cls_results = self.results.setdefault(
                cls_key,
                {
                    'imgs_masks': [],
                    'anomaly_maps': [],
                    'gt_sp': [],
                    'pr_sp': [],
                },
            )

            anomaly_map = anomaly_maps[i]
            if self.hparams.k_shot:
                patch_tokens = [patch[i] for patch in patches]
                anomaly_maps_kshot = self.kshot_anomaly(patch_tokens, cls_key)
                anomaly_map = anomaly_map.cpu() + anomaly_maps_kshot.cpu()

            # Store CPU copies so accumulated results do not pin device
            # memory for the whole test epoch (the epoch-end hook converts
            # them to NumPy anyway).
            cls_results['imgs_masks'].append(masks[i].detach().cpu())
            cls_results['anomaly_maps'].append(anomaly_map.detach().cpu())
            cls_results['gt_sp'].append(labels[i].item())
            cls_results['pr_sp'].append(text_probs[i][1].item())
            updated_anomaly_maps.append(anomaly_map)

        updated_anomaly_maps = torch.stack(updated_anomaly_maps, dim=0).squeeze(1)
        return {
            "anomaly_maps": updated_anomaly_maps,
            "abnormal": text_probs,
        }

    @staticmethod
    def _as_percent(value: Any) -> str:
        """Format a metric value as a percentage string with one decimal."""
        return str(np.round(value * 100, decimals=1))

    @staticmethod
    def _fuse_kshot_scores(pr_sp: Any, anomaly_maps: np.ndarray) -> np.ndarray:
        """Average image-level scores with min-max normalised map maxima.

        :param pr_sp: Image-level scores, one per test image of the class.
        :param anomaly_maps: Stacked anomaly maps, first axis indexing images.
        :return: ``0.5 * (pr_sp + normalised_peak)`` as a NumPy array.
        """
        peaks = np.array([np.max(anomaly_map) for anomaly_map in anomaly_maps])
        lowest = peaks.min()
        span = peaks.max() - lowest
        if span > 0:
            peaks = (peaks - lowest) / span
        else:
            # A single image (or identical maxima) gives a zero range;
            # dividing would yield NaN scores, so the map contributes 0.
            peaks = np.zeros_like(peaks)
        return 0.5 * (np.array(pr_sp) + peaks)

    def on_test_epoch_end(self) -> None:
        """
        Lightning hook that is called when a test epoch ends.

        Processes the results collected during the test epoch, computes the
        image-level (AUROC, AP, F1-max) and pixel-level (AUROC, AP, F1-max,
        AUPRO) metrics per class, and logs their means across classes plus an
        overall ``test/objective`` (the mean of the seven mean metrics).

        Operations performed in this method:
        1. Start a timer to measure the duration of result processing.
        2. For each class, convert the stored masks and anomaly maps to numpy
           arrays and, in k-shot mode, blend the image-level scores with the
           normalised anomaly-map maxima.
        3. Compute every metric for the class and add a row to the table.
        4. Compute and log the mean of each metric across all classes.
        5. Format the table with `tabulate` and log it with the elapsed time.
        6. Reset the results to ensure a clean state for the next epoch.
        """
        start_time = time.time()
        log.info("Processing test results...")

        columns = [name for name, _ in self._IMAGE_METRICS + self._PIXEL_METRICS]
        per_metric = {name: [] for name in columns}
        tables = []

        for cls_key, data in self.results.items():
            # The metric helpers read these entries back out of
            # `self.results`, so the conversion is done in place.
            data['imgs_masks'] = torch.stack(data['imgs_masks']).detach().cpu().numpy()
            data['anomaly_maps'] = torch.stack(data['anomaly_maps']).detach().cpu().numpy()

            if self.hparams.k_shot:
                data["pr_sp"] = self._fuse_kshot_scores(data["pr_sp"], data["anomaly_maps"])

            scores = {
                name: image_level_metrics(self.results, cls_key, key)
                for name, key in self._IMAGE_METRICS
            }
            scores.update(
                (name, pixel_level_metrics(self.results, cls_key, key))
                for name, key in self._PIXEL_METRICS
            )

            tables.append([cls_key] + [self._as_percent(scores[name]) for name in columns])
            for name in columns:
                per_metric[name].append(scores[name])

        means = {name: np.mean(values) for name, values in per_metric.items()}
        objective = sum(means.values()) / len(means)
        tables.append(['mean'] + [self._as_percent(means[name]) for name in columns])

        for name in columns:
            self.log(f"test/{name}", means[name], on_epoch=True, prog_bar=True)
        self.log("test/objective", objective, on_epoch=True, prog_bar=True)

        metrics = tabulate(tables, headers=["objects", *columns], tablefmt="pipe")

        duration = time.time() - start_time
        log.info(f"Processed test results in {duration:.2f} seconds.")
        log.info(f"\n{metrics}")

        self.results = {}

    def setup(self, stage: str) -> None:
        """Lightning hook that is called at the beginning of fit (train + validate), validate,
        test, or predict.

        This is a good hook when you need to build models dynamically or adjust something about
        them. This hook is called on every process when using DDP.

        Here it wraps `self.net` with `torch.compile` when the ``compile``
        hyperparameter is set and the stage is ``"fit"``.

        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
        """
        if self.hparams.compile and stage == "fit":
            self.net = torch.compile(self.net)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
        """
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        if self.hparams.scheduler is None:
            return {"optimizer": optimizer}

        scheduler = self.hparams.scheduler(optimizer=optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # Schedulers that need a monitored value follow the training loss.
                "monitor": "train/loss",
                "interval": "epoch",
                "frequency": 1,
            },
        }