from typing import List, Any, Dict, Tuple, Union
import time

import numpy as np
from tabulate import tabulate
from scipy.ndimage import gaussian_filter

import torch
from torch import Tensor
import torch.nn.functional as F
from lightning import LightningModule
from torchmetrics import MaxMetric, MetricCollection
from torchmetrics.classification import Accuracy, MulticlassAUROC

from src.models.components.functional import cosine_similarity_torch
from src.utils.metrics import image_level_metrics, pixel_level_metrics
from src.utils import pylogger

log = pylogger.RankedLogger(__name__, rank_zero_only=True)


class AnomalyCLIPModule(LightningModule):
    """
    LightningModule for training an anomaly detection model using features extracted by a CLIP model.

    Attributes:
        net (nn.Module): The core model which contains the layers for feature extraction and processing.
        loss_focal (nn.Module): Focal loss function, helpful for handling class imbalance in datasets.
        loss_dice (nn.Module): Dice loss function, typically used for segmentation tasks.
        loss_ce (nn.Module): Cross-entropy loss for classification tasks.
    """

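    # Illustrative constructor usage (a sketch, not taken from this repo's configs): `optimizer` and
    # `scheduler` are expected to be callables that are only finalized in `configure_optimizers`, and
    # `loss` is a mapping with "cross_entropy", "focal" and "dice" entries. The names `my_net`,
    # `focal_loss` and `dice_loss` below are placeholders.
    #
    #   import functools
    #   module = AnomalyCLIPModule(
    #       net=my_net,  # hypothetical nn.Module wrapping the CLIP backbone
    #       optimizer=functools.partial(torch.optim.AdamW, lr=1e-3),
    #       scheduler=functools.partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=10),
    #       loss={"cross_entropy": torch.nn.CrossEntropyLoss(), "focal": focal_loss, "dice": dice_loss},
    #       enable_validation=True,
    #       compile=False,
    #   )
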
    def __init__(
        self,
        net: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
        loss: Union[torch.nn.Module, Any],
        enable_validation: bool,
        compile: bool,
        **kwargs,
    ) -> None:
        """
        Initialize an `AnomalyCLIPModule`.

        :param net: The model to train.
        :param optimizer: The optimizer to use for training.
        :param scheduler: The learning rate scheduler to use for training.
        :param loss: A mapping with the loss functions used for training
            (expects the keys "cross_entropy", "focal" and "dice").
        :param enable_validation: Whether to run the validation loop.
        :param compile: Whether to compile the model with `torch.compile`.
        :param kwargs: Additional keyword arguments (e.g. `k_shot`, `filter`), stored in `self.hparams`.
        """
        super().__init__()

        # `logger=False` keeps the hyperparameters out of the experiment logger; `net` is excluded
        # because it is already checkpointed as part of the module state.
        self.save_hyperparameters(logger=False, ignore=["net"])

        self.net = net
        self.optimizer = optimizer
        self.scheduler = scheduler

        self.loss_ce = loss["cross_entropy"]
        self.loss_focal = loss["focal"]
        self.loss_dice = loss["dice"]

        metrics = MetricCollection(
            {
                "acc": Accuracy(task="multiclass", num_classes=2),
                "auroc": MulticlassAUROC(num_classes=2),
            }
        )
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        self.test_metrics = metrics.clone(prefix="test/")

        # Track the best validation scores seen so far.
        self.val_acc_best = MaxMetric()
        self.val_auroc_best = MaxMetric()

    def forward(self, images: Tensor, cls_name: str) -> Tuple:
        """
        Perform a forward pass through the model `self.net`.

        :param images: A tensor of images.
        :param cls_name: The class name(s) associated with the images.
        :return: A tuple containing the model's outputs.
        """
        return self.net(images, cls_name)

    def on_train_start(self) -> None:
        """Lightning hook that is called when training begins."""
        # Lightning runs validation sanity checks before training starts, so reset the validation
        # metrics to make sure they do not carry results from those checks.
        self.val_metrics.reset()
        self.val_acc_best.reset()
        self.val_auroc_best.reset()

    def on_train_batch_start(self, batch: Dict[str, Any], batch_idx: int) -> None:
        """Lightning hook called at the start of each training batch; logs the current learning rate."""
        current_lr = self.trainer.optimizers[0].param_groups[0]["lr"]
        self.log("learning_rate", current_lr, on_step=True, on_epoch=False, prog_bar=True)

    def on_train_epoch_end(self) -> None:
        """Lightning hook that is called when a training epoch ends."""
        self.train_metrics.reset()

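    # Assumed layout of the 6-tuple returned by `self.net`, inferred from how it is unpacked in
    # `model_step`, `kshot_step` and `test_step` (the remaining positions are unused in this module):
    # index 2 -> per-layer patch tokens, index 4 -> per-layer anomaly maps, index 5 -> image-level text probabilities.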
    def model_step(self, batch: Dict[str, Any]) -> Tuple:
        """
        Shared logic for training, validation, and testing.

        :param batch: A dictionary containing the input data.
        :return: A tuple of (image-level text probabilities, labels, anomaly maps, ground-truth masks).
        """
        images, masks, cls_name, labels = batch["image"], batch["image_mask"], batch["cls_name"], batch["anomaly"]
        _, _, _, _, anomaly_maps, text_probs = self.forward(images, cls_name)

        return text_probs, labels, anomaly_maps, masks

    def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        """
        Perform a single training step on a batch of data from the training set.

        :param batch: A batch of data (a dictionary) containing the input tensor of images and target labels.
        :param batch_idx: The index of the current batch.
        :return: A dictionary containing the computed losses.
        """
        logits, labels, anomaly_maps, masks = self.model_step(batch)
        loss_ce = self.loss_ce(logits, labels)

        # Binarize the ground-truth masks and drop the channel dimension.
        masks = torch.where(masks > 0.0, 1, 0).squeeze(1)

        # Accumulate the segmentation losses over the anomaly maps predicted at each layer.
        loss_focal = 0
        loss_dice = 0
        for i in range(len(anomaly_maps)):
            loss_focal += self.loss_focal(anomaly_maps[i], masks)
            loss_dice += self.loss_dice(anomaly_maps[i][:, 1, :, :], masks)
            loss_dice += self.loss_dice(anomaly_maps[i][:, 0, :, :], 1 - masks)

        loss = loss_ce + loss_focal + loss_dice

        self.log("train/loss_ce", loss_ce, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train/loss_focal", loss_focal, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train/loss_dice", loss_dice, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)

        self.log_dict(self.train_metrics(logits, labels), on_step=False, on_epoch=True, prog_bar=True)

        return {
            "loss_ce": loss_ce,
            "loss_focal": loss_focal,
            "loss_dice": loss_dice,
            "loss": loss,
        }

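    # The objective optimized in `training_step` above (and mirrored in `validation_step` below) is
    #   loss = CE(text_probs, label)
    #        + sum over layers of [ Focal(map_l, mask) + Dice(map_l[:, 1], mask) + Dice(map_l[:, 0], 1 - mask) ],
    # i.e. cross-entropy on the image-level prediction plus focal and per-class dice losses on every
    # intermediate anomaly map.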
    def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> None:
        """
        Perform a single validation step on a batch of data from the validation set.

        :param batch: A batch of data (a dictionary) containing the input tensor of images and target labels.
        :param batch_idx: The index of the current batch.
        """
        if not self.hparams.enable_validation:
            return

        logits, labels, anomaly_maps, masks = self.model_step(batch)
        loss_ce = self.loss_ce(logits, labels)

        # Binarize the ground-truth masks and drop the channel dimension.
        masks = torch.where(masks > 0.0, 1, 0).squeeze(1)

        # Accumulate the segmentation losses over the anomaly maps predicted at each layer.
        loss_focal = 0
        loss_dice = 0
        for i in range(len(anomaly_maps)):
            loss_focal += self.loss_focal(anomaly_maps[i], masks)
            loss_dice += self.loss_dice(anomaly_maps[i][:, 1, :, :], masks)
            loss_dice += self.loss_dice(anomaly_maps[i][:, 0, :, :], 1 - masks)

        loss = loss_ce + loss_focal + loss_dice

        self.log("val/loss_ce", loss_ce, on_step=True, on_epoch=True, prog_bar=True)
        self.log("val/loss_focal", loss_focal, on_step=True, on_epoch=True, prog_bar=True)
        self.log("val/loss_dice", loss_dice, on_step=True, on_epoch=True, prog_bar=True)
        self.log("val/loss", loss, on_step=True, on_epoch=True, prog_bar=True)

        self.log_dict(self.val_metrics(logits, labels), on_step=False, on_epoch=True, prog_bar=True)

    def on_epoch_end(self) -> None:
        """Unused placeholder: `on_epoch_end` is not a hook invoked by Lightning 2.x, so this has no effect."""
        pass

    def on_validation_epoch_end(self) -> None:
        """Lightning hook that is called when a validation epoch ends."""
        if not self.hparams.enable_validation:
            return

        current_metrics = self.val_metrics.compute()
        current_acc = current_metrics["val/acc"]
        current_auroc = current_metrics["val/auroc"]

        # Update and log the best scores seen so far.
        self.val_acc_best.update(current_acc)
        self.val_auroc_best.update(current_auroc)

        self.log("val/acc_best", self.val_acc_best.compute(), on_epoch=True, prog_bar=True)
        self.log("val/auroc_best", self.val_auroc_best.compute(), on_epoch=True, prog_bar=True)

        self.val_metrics.reset()

    def kshot_step(self, kshot_batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        """Compute reference features for the k-shot dataset.

        :param kshot_batch: A batch of data from the k-shot dataset.
        :param batch_idx: The index of the current batch.
        :return: A dictionary mapping the class name to its k-shot patch features.
        """
        # The k-shot loader yields one class per batch: a stack of k reference images.
        kshot_images = kshot_batch["image"].squeeze(0)
        cls_name = kshot_batch["cls_name"][0]

        # Repeat the class name so that every reference image is paired with it.
        k_shot = kshot_images.shape[0]
        classnames = kshot_batch["cls_name"] * k_shot

        _, _, kshot_patches, _, _, _ = self.forward(kshot_images, classnames)

        return {cls_name: kshot_patches}

    def on_test_start(self) -> None:
        """Lightning hook that is called when testing begins."""
        self.results = {}
        self.mem_features = {}
        device = self.device
        if self.hparams.k_shot:
            # Build the per-class memory bank of reference patch features from the k-shot dataloader.
            for batch_idx, kshot_batch in enumerate(self.trainer.datamodule.kshot_dataloader()):
                kshot_batch = {k: (v.to(device) if isinstance(v, Tensor) else v) for k, v in kshot_batch.items()}
                self.mem_features.update(self.kshot_step(kshot_batch, batch_idx))

    def kshot_anomaly(self, patch_tokens: List[Tensor], cls_name: str) -> Tensor:
        """
        Compute the k-shot anomaly map for a single image from its patch tokens.

        :param patch_tokens: List of tensors, one per layer, each holding the patch tokens of a single image
            with shape (L, C), where L is the number of patches and C is the feature dimension.
        :param cls_name: Class name used to retrieve the corresponding k-shot reference patches.
        :return: The aggregated k-shot anomaly map with shape (1, H, W), where H and W are the dimensions
            of the resized anomaly map.
        """
        kshot_patch = self.mem_features[cls_name]
        B, L, C = kshot_patch[0].shape
        H = int(L ** 0.5)
        anomaly_maps_kshot = []

        for idx, patch in enumerate(patch_tokens):
            # Flatten the k reference images into one pool of patches and score each test patch by its
            # cosine distance to the closest reference patch.
            kshot_patch_expand = kshot_patch[idx].reshape(B * L, C)
            cosine_similarity = cosine_similarity_torch(kshot_patch_expand, patch)
            cosine_distance = (1. - cosine_similarity).min(dim=0)[0]
            # Reshape the per-patch distances into a square map and upsample it to the input resolution.
            anomaly_map_kshot = cosine_distance.reshape(1, 1, H, H)
            anomaly_map_kshot = F.interpolate(anomaly_map_kshot,
                                              size=self.net.image_size, mode='bilinear', align_corners=True)
            anomaly_maps_kshot.append(anomaly_map_kshot[0])

        # Sum the per-layer maps into a single anomaly map.
        anomaly_maps_kshot = torch.stack(anomaly_maps_kshot).sum(dim=0)
        return anomaly_maps_kshot

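    # Sketch of the scoring rule in `kshot_anomaly` above, assuming `cosine_similarity_torch` returns the
    # pairwise similarity matrix between its two arguments:
    #   sim = cosine_similarity_torch(refs, patch)   # refs: (B*L, C), patch: (L, C) -> sim: (B*L, L)
    #   dist = (1.0 - sim).min(dim=0)[0]             # (L,) distance of each test patch to its nearest reference
    # Patches that resemble nothing in the k normal reference shots therefore receive high anomaly scores.
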
    def test_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        """
        Perform a single test step on a batch of data from the test set.

        :param batch: A batch of data (a dictionary) containing the input tensor of images and target labels.
        :param batch_idx: The index of the current batch.
        :return: A dictionary containing the anomaly maps and image-level predictions.
        """
        images, masks, classnames, labels = batch["image"], batch["image_mask"], batch["cls_name"], batch["anomaly"]
        _, _, patches, _, anomaly_maps, text_probs = self.forward(images, classnames)

        self.log_dict(self.test_metrics(text_probs, labels), on_step=False, on_epoch=True, prog_bar=True)

        masks = torch.where(masks > 0.5, 1, 0).squeeze(1)
        batch_size = images.size(0)

        # Sum the anomaly channel of the per-layer maps into a single map per image.
        anomaly_maps = torch.stack(anomaly_maps)[:, :, 1, :, :].sum(dim=0)
        if self.hparams.filter:
            # Optionally smooth each map with a Gaussian filter (runs on CPU via scipy).
            anomaly_maps = torch.stack(
                [torch.from_numpy(gaussian_filter(anomaly_map, sigma=4)) for anomaly_map in anomaly_maps.detach().cpu()],
                dim=0,
            )

        updated_anomaly_maps = []
        for i in range(batch_size):
            cls_key = classnames[i]
            if cls_key not in self.results:
                self.results[cls_key] = {
                    'imgs_masks': [],
                    'anomaly_maps': [],
                    'gt_sp': [],
                    'pr_sp': []
                }

            mask = masks[i]
            anomaly_map = anomaly_maps[i]

            if self.hparams.k_shot:
                # Add the memory-bank (k-shot) anomaly map for this image.
                patch_tokens = [patch[i] for patch in patches]
                anomaly_maps_kshot = self.kshot_anomaly(patch_tokens, cls_key)
                anomaly_map = anomaly_map.cpu() + anomaly_maps_kshot.cpu()

            # Accumulate per-class results for metric computation in `on_test_epoch_end`.
            self.results[cls_key]['imgs_masks'].append(mask)
            self.results[cls_key]['anomaly_maps'].append(anomaly_map)
            self.results[cls_key]['gt_sp'].append(labels[i].cpu().numpy().item())
            self.results[cls_key]['pr_sp'].append(text_probs[i][1].detach().cpu().numpy().item())

            updated_anomaly_maps.append(anomaly_map)

        updated_anomaly_maps = torch.stack(updated_anomaly_maps, dim=0).squeeze(1)

        return {
            "anomaly_maps": updated_anomaly_maps,
            "abnormal": text_probs,
        }

    def on_test_epoch_end(self) -> None:
        """
        Lightning hook that is called when a test epoch ends.

        This method processes the results collected during the test epoch, computes the evaluation metrics
        (image-level AUROC, AP and max-F1; pixel-level AUROC, AP, max-F1 and AUPRO), and logs them.
        It also formats the per-class metrics into a table, logs the table, and finally resets the results
        to ensure a clean state for the next epoch.

        Operations performed in this method:
        1. Start a timer to measure the duration of result processing.
        2. Initialize lists and a table for storing metrics for each class.
        3. Loop over the results for each class:
           a. Convert the stored masks and anomaly maps to numpy arrays.
           b. Compute the image-level and pixel-level metrics for the class.
           c. Append the computed metrics to the corresponding lists and the table.
        4. Compute the mean value of each metric across all classes.
        5. Log the mean metrics.
        6. Format the metrics into a table using the `tabulate` library and log the table.
        7. Log the duration of result processing.
        8. Reset the results to ensure a clean state for the next epoch.
        """

        start_time = time.time()
        log.info("Processing test results...")

        tables = []
        image_auroc_list = []
        image_ap_list = []
        image_f1_list = []
        pixel_auroc_list = []
        pixel_ap_list = []
        pixel_f1_list = []
        pixel_aupro_list = []

        for cls_key, data in self.results.items():
            table = [cls_key]
            data['imgs_masks'] = torch.stack(data['imgs_masks']).detach().cpu().numpy()
            data['anomaly_maps'] = torch.stack(data['anomaly_maps']).detach().cpu().numpy()

            if self.hparams.k_shot:
                # In the k-shot setting, blend the text-based image score with a score derived from the
                # maximum of each (min-max normalized) anomaly map.
                pr_sp_tmp = np.array([np.max(anomaly_map) for anomaly_map in data["anomaly_maps"]])
                pr_sp_tmp = (pr_sp_tmp - pr_sp_tmp.min()) / (pr_sp_tmp.max() - pr_sp_tmp.min())
                pr_sp = 0.5 * (np.array(data["pr_sp"]) + pr_sp_tmp)

                data["pr_sp"] = pr_sp

            image_auroc = image_level_metrics(self.results, cls_key, "image-auroc")
            image_ap = image_level_metrics(self.results, cls_key, "image-ap")
            image_f1 = image_level_metrics(self.results, cls_key, "image-f1-max")
            pixel_auroc = pixel_level_metrics(self.results, cls_key, "pixel-auroc")
            pixel_ap = pixel_level_metrics(self.results, cls_key, "pixel-ap")
            pixel_f1 = pixel_level_metrics(self.results, cls_key, "pixel-f1-max")
            pixel_aupro = pixel_level_metrics(self.results, cls_key, "pixel-aupro")

            table.append(str(np.round(image_auroc * 100, decimals=1)))
            table.append(str(np.round(image_ap * 100, decimals=1)))
            table.append(str(np.round(image_f1 * 100, decimals=1)))
            table.append(str(np.round(pixel_auroc * 100, decimals=1)))
            table.append(str(np.round(pixel_ap * 100, decimals=1)))
            table.append(str(np.round(pixel_f1 * 100, decimals=1)))
            table.append(str(np.round(pixel_aupro * 100, decimals=1)))
            tables.append(table)

            image_auroc_list.append(image_auroc)
            image_ap_list.append(image_ap)
            image_f1_list.append(image_f1)
            pixel_auroc_list.append(pixel_auroc)
            pixel_ap_list.append(pixel_ap)
            pixel_f1_list.append(pixel_f1)
            pixel_aupro_list.append(pixel_aupro)

        mean_image_auroc = np.mean(image_auroc_list)
        mean_image_ap = np.mean(image_ap_list)
        mean_image_f1 = np.mean(image_f1_list)
        mean_pixel_auroc = np.mean(pixel_auroc_list)
        mean_pixel_ap = np.mean(pixel_ap_list)
        mean_pixel_f1 = np.mean(pixel_f1_list)
        mean_pixel_aupro = np.mean(pixel_aupro_list)
        # Single scalar objective: the unweighted mean of the seven metrics above.
        objective = (
            mean_image_auroc + mean_image_f1 + mean_image_ap
            + mean_pixel_auroc + mean_pixel_ap + mean_pixel_f1 + mean_pixel_aupro
        ) / 7

        tables.append(
            [
                'mean',
                str(np.round(mean_image_auroc * 100, decimals=1)),
                str(np.round(mean_image_ap * 100, decimals=1)),
                str(np.round(mean_image_f1 * 100, decimals=1)),
                str(np.round(mean_pixel_auroc * 100, decimals=1)),
                str(np.round(mean_pixel_ap * 100, decimals=1)),
                str(np.round(mean_pixel_f1 * 100, decimals=1)),
                str(np.round(mean_pixel_aupro * 100, decimals=1)),
            ]
        )

        self.log("test/image_auroc", mean_image_auroc, on_epoch=True, prog_bar=True)
        self.log("test/image_ap", mean_image_ap, on_epoch=True, prog_bar=True)
        self.log("test/image_f1", mean_image_f1, on_epoch=True, prog_bar=True)
        self.log("test/pixel_auroc", mean_pixel_auroc, on_epoch=True, prog_bar=True)
        self.log("test/pixel_ap", mean_pixel_ap, on_epoch=True, prog_bar=True)
        self.log("test/pixel_f1", mean_pixel_f1, on_epoch=True, prog_bar=True)
        self.log("test/pixel_aupro", mean_pixel_aupro, on_epoch=True, prog_bar=True)
        self.log("test/objective", objective, on_epoch=True, prog_bar=True)

        metrics = tabulate(
            tables,
            headers=["objects", "image_auroc", "image_ap", "image_f1",
                     "pixel_auroc", "pixel_ap", "pixel_f1", "pixel_aupro"],
            tablefmt="pipe",
        )

        end_time = time.time()
        duration = end_time - start_time
        log.info(f"Processed test results in {duration:.2f} seconds.")
        log.info(f"\n{metrics}")

        # Reset the accumulated results so the next test run starts from a clean state.
        self.results = {}

    def setup(self, stage: str) -> None:
        """Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.

        This is a good hook when you need to build models dynamically or adjust something about them.
        This hook is called on every process when using DDP.

        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
        """
        if self.hparams.compile and stage == "fit":
            self.net = torch.compile(self.net)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Choose what optimizers and learning-rate schedulers to use for optimization.

        Normally you'd need one, but in the case of GANs or similar you might have multiple.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
        """
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "train/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}

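
# Illustrative Hydra-style wiring for this module (an assumption, not taken from this repo's configs;
# the `_target_` path and every key below are hypothetical):
#
#   model:
#     _target_: src.models.anomaly_clip_module.AnomalyCLIPModule
#     optimizer:
#       _target_: torch.optim.AdamW
#       _partial_: true
#       lr: 1e-3
#     scheduler:
#       _target_: torch.optim.lr_scheduler.CosineAnnealingLR
#       _partial_: true
#       T_max: 10
#     enable_validation: true
#     compile: false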
|