| import os |
| import sys |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Dict, Optional, Tuple |
|
|
| import gradio as gr |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from PIL import Image |
| from scipy.ndimage import gaussian_filter |
| from transformers import AutoProcessor, AutoTokenizer, SiglipVisionModel |
|
|
| |
# Absolute directory of this script; the Tipsomaly checkout is looked up
# next to it.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
TIPSOMALY_DIR = os.path.join(ROOT_DIR, "Tipsomaly")
MODEL_DIR = os.path.join(TIPSOMALY_DIR, "model")

# Prepend both directories to the import path, skipping any already present.
# TIPSOMALY_DIR is inserted first and MODEL_DIR second, so MODEL_DIR ends up
# ahead of TIPSOMALY_DIR when both are added.
for _import_dir in (TIPSOMALY_DIR, MODEL_DIR):
    if _import_dir not in sys.path:
        sys.path.insert(0, _import_dir)
del _import_dir
|
|
| from Tipsomaly.model.omaly.text_encoder import text_encoder as TipsomalyTextEncoder |
| from Tipsomaly.model.omaly.vision_encoder import vision_encoder as TipsomalyVisionEncoder |
| from Tipsomaly.model.siglip2.siglip2_prompt_learnable import SiglipTextModelWithPromptLearning |
|
|
|
|
def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean environment variable.

    ``1``, ``true``, ``yes`` and ``on`` (case-insensitive, surrounding
    whitespace ignored) count as true; any other value that is set counts as
    false. Returns *default* when the variable is unset.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


@dataclass
class DemoConfig:
    """Runtime settings for the demo, read from environment variables.

    Field defaults are evaluated once, when the class is defined (at import
    time); changing the environment afterwards does not affect new instances.
    """

    # Hugging Face model id for the tokenizer, processor and SigLIP2 backbones.
    model_id: str = os.getenv("SIGLIP2_MODEL_ID", "google/siglip2-base-patch16-256")
    # Side length the input is resized to and the anomaly map is upsampled to.
    image_size: int = int(os.getenv("IMAGE_SIZE", "256"))
    # Passed to the text encoder as MAX_LEN.
    max_len: int = int(os.getenv("MAX_LEN", "64"))
    # When true, the max per-patch score is added to the image-level score.
    use_local_to_global: bool = _env_flag("USE_LOCAL_TO_GLOBAL", True)
    # Gaussian smoothing sigma applied to the upsampled anomaly map.
    sigma: float = float(os.getenv("ANOMALY_SMOOTH_SIGMA", "4"))
    # Object name fed to the text encoder as the prompt subject.
    object_name: str = os.getenv("OBJECT_NAME", "object")
    # Prompt-learning method used when no learned-token checkpoint is loaded.
    prompt_learn_method: str = os.getenv("PROMPT_LEARN_METHOD", "none")
    # Text-encoder prompt / deep-token hyperparameters (passed through as-is).
    n_prompt: int = int(os.getenv("N_PROMPT", "8"))
    n_deep_tokens: int = int(os.getenv("N_DEEP_TOKENS", "0"))
    d_deep_tokens: int = int(os.getenv("D_DEEP_TOKENS", "0"))
    # Epoch suffix of the bundled ``learnable_params_<epoch>.pth`` file to load.
    checkpoint_epoch: int = int(os.getenv("LEARNABLE_PROMPT_EPOCH", "2"))
|
|
|
|
# Bundled learned-prompt checkpoint directories, keyed by the UI token-source
# name (which matches the ``trained_on_<name>`` workspace). Paths are relative
# to ROOT_DIR; each directory is expected to hold
# ``learnable_params_<epoch>.pth`` files.
CHECKPOINTS: Dict[str, str] = {
    source: f"Tipsomaly/workspaces/trained_on_{source}_default/vegan-arkansas/checkpoints"
    for source in ("mvtec", "visa")
}
|
|
|
|
def calc_sigm_score_hf(
    vis_feat: torch.Tensor,
    txt_feat: torch.Tensor,
    temperature: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Return sigmoid similarity scores between visual and text features.

    Computes ``sigmoid((vis_feat @ txt_featᵀ) * temperature + bias)``.

    Args:
        vis_feat: Visual features. A tensor with fewer than 3 dimensions gets a
            singleton dimension inserted at position 1 before the matmul.
        txt_feat: Text features whose last two dimensions are transposed for
            the matmul. Both a batched 3-D tensor and a single 2-D prompt set
            are accepted.
        temperature: Multiplicative logit scale (broadcast).
        bias: Additive logit bias (broadcast).

    Returns:
        Element-wise sigmoid of the scaled, biased logits.
    """
    if vis_feat.dim() < 3:
        vis_feat = vis_feat.unsqueeze(dim=1)
    # transpose(-2, -1) equals permute(0, 2, 1) for 3-D text features and
    # additionally supports an unbatched 2-D prompt set via broadcasting.
    logits = vis_feat @ txt_feat.transpose(-2, -1) * temperature + bias
    return torch.sigmoid(logits)
|
|
|
|
def regrid_upsample_smooth(flat_scores: torch.Tensor, size: int, sigma: float) -> torch.Tensor:
    """Turn flat per-patch class scores into smoothed square anomaly maps.

    Args:
        flat_scores: Tensor of shape ``(batch, n_patches, ...)`` whose patch
            count is a perfect square (square patch grid). After regridding,
            channel 0 is used as the "normal" score and channel 1 as the
            "anomalous" score.
        size: Output side length; maps are bilinearly upsampled to
            ``(size, size)``.
        sigma: Standard deviation of the Gaussian smoothing applied to each
            map with ``scipy.ndimage.gaussian_filter``.

    Returns:
        CPU tensor of shape ``(batch, size, size)`` holding
        ``(1 - normal + anomalous) / 2`` per pixel, after smoothing.

    Raises:
        ValueError: If the number of patches is not a perfect square.
    """
    batch, n_patches = flat_scores.shape[0], flat_scores.shape[1]
    # round() absorbs floating-point error in the square root; the explicit
    # check stops a non-square patch count from being silently folded into
    # the channel axis by the -1 in reshape.
    side = round(n_patches ** 0.5)
    if side * side != n_patches:
        raise ValueError(f"Expected a square number of patches, got {n_patches}.")

    grid = flat_scores.reshape(batch, side, side, -1).permute(0, 3, 1, 2)
    upsampled = torch.nn.functional.interpolate(
        grid, (size, size), mode="bilinear", align_corners=False
    ).permute(0, 2, 3, 1)
    rough_maps = (1 - upsampled[..., 0] + upsampled[..., 1]) / 2

    # gaussian_filter operates on NumPy arrays, so each map is smoothed on CPU.
    smoothed = [
        torch.from_numpy(gaussian_filter(one_map, sigma=sigma))
        for one_map in rough_maps.detach().cpu().numpy()
    ]
    return torch.stack(smoothed, dim=0)
|
|
|
|
def make_heatmap_rgb(anomaly_map: np.ndarray) -> Image.Image:
    """Render a 2-D score map as a blue-to-green-to-red RGB heatmap.

    The map is min-max normalised to ``[0, 1]``: the lowest score is pure
    blue, mid-range scores peak in green, and the highest score is pure red.

    Args:
        anomaly_map: 2-D array-like of scores.

    Returns:
        An RGB ``PIL.Image.Image`` with the same height and width as the map.
    """
    scores = np.asarray(anomaly_map)
    shifted = scores - scores.min()
    span = shifted.max()
    # Divide by the exact range so the maximum maps to full red (255); a
    # constant map has zero range and is rendered entirely as the low colour.
    if span > 0:
        normalized = shifted / span
    else:
        normalized = np.zeros_like(shifted, dtype=float)

    red = (normalized * 255).astype(np.uint8)
    green = (np.clip(1.0 - np.abs(normalized - 0.5) * 2.0, 0, 1) * 255).astype(np.uint8)
    blue = ((1.0 - normalized) * 255).astype(np.uint8)
    rgb = np.stack([red, green, blue], axis=-1)
    # Pillow infers "RGB" from a uint8 (H, W, 3) array; the ``mode`` argument
    # of Image.fromarray is deprecated as of Pillow 11.3.
    return Image.fromarray(rgb)
|
|
|
|
class TipsomalyDemo:
    """Score single images for anomalies with SigLIP2 backbones and Tipsomaly encoders."""

    def __init__(self, config: DemoConfig) -> None:
        """Load tokenizer, processor, backbones, logit parameters and vision encoder.

        Models are placed on CUDA when available (otherwise CPU) and set to
        eval mode.
        """
        self.config = config
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(config.model_id)
        self.processor = AutoProcessor.from_pretrained(config.model_id)
        self.vision_backbone = SiglipVisionModel.from_pretrained(config.model_id).to(self.device).eval()
        self.text_backbone = (
            SiglipTextModelWithPromptLearning.from_pretrained(config.model_id).to(self.device).eval()
        )

        self.temperature, self.bias = self._load_logit_params()
        self.text_embd_dim = self.text_backbone.text_model.head.out_features
        self.vision_encoder = TipsomalyVisionEncoder(self.vision_backbone, "siglip2-hf").to(self.device).eval()

    def _load_logit_params(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(exp(logit_scale), logit_bias)`` from the full SigLIP2 model.

        The combined model is needed only for these two parameters. It is
        therefore loaded on the CPU and discarded, and only the resulting
        tensors are moved to ``self.device``. This avoids a transient second
        copy of both towers on the accelerator next to the backbones already
        loaded there.
        """
        from transformers import AutoModel

        model = AutoModel.from_pretrained(self.config.model_id).eval()
        with torch.no_grad():
            temperature = model.logit_scale.exp().detach().to(self.device)
            bias = model.logit_bias.detach().to(self.device)
        del model
        return temperature, bias

    def _build_text_encoder(self, domain: str, prompt_learn_method: str) -> TipsomalyTextEncoder:
        """Construct a Tipsomaly text encoder for *domain* on top of the text backbone."""
        encoder = TipsomalyTextEncoder(
            tokenizer=self.tokenizer,
            bb_text_encoder=self.text_backbone,
            bb_type="siglip2-hf",
            text_embd_dim=self.text_embd_dim,
            MAX_LEN=self.config.max_len,
            prompt_learn_method=prompt_learn_method,
            prompt_type=domain,
            n_prompt=self.config.n_prompt,
            n_deep=self.config.n_deep_tokens,
            d_deep=self.config.d_deep_tokens,
        ).to(self.device).eval()
        return encoder

    def _resolve_checkpoint_path(self, token_source: str, custom_checkpoint: str) -> Optional[Path]:
        """Map the UI token-source choice to a checkpoint file.

        Returns None for ``"none"``. For ``"custom"`` the user-supplied path is
        used; for a key of CHECKPOINTS the bundled
        ``learnable_params_<checkpoint_epoch>.pth`` under ROOT_DIR is used.

        Raises:
            gr.Error: If the custom path is empty, the source is unknown, or
                the resolved path is not an existing file.
        """
        if token_source == "none":
            return None
        if token_source == "custom":
            # The textbox may deliver None rather than "" when cleared.
            custom = (custom_checkpoint or "").strip()
            if not custom:
                raise gr.Error("Custom checkpoint selected, but path is empty.")
            path = Path(custom)
        elif token_source in CHECKPOINTS:
            base = Path(ROOT_DIR) / CHECKPOINTS[token_source]
            path = base / f"learnable_params_{self.config.checkpoint_epoch}.pth"
        else:
            raise gr.Error(f"Unknown token source: {token_source}")
        # is_file() rather than exists(): a directory is rejected here with a
        # clear message instead of failing later inside torch.load.
        if not path.is_file():
            raise gr.Error(f"Checkpoint not found: {path}")
        return path

    def _load_learnable_prompts(self, encoder: TipsomalyTextEncoder, checkpoint_path: Optional[Path]) -> bool:
        """Attach learned prompt tokens from *checkpoint_path* to *encoder*.

        The checkpoint may be a dict with a ``"learnable_prompts"`` entry or
        the prompts object itself.

        Returns:
            True if prompts were loaded, False when no checkpoint was given.

        Raises:
            gr.Error: If a dict checkpoint has no ``"learnable_prompts"`` entry.
        """
        if checkpoint_path is None:
            return False
        # NOTE(security): weights_only=False unpickles arbitrary objects, and
        # for token_source == "custom" the path comes straight from a UI
        # textbox. Only expose this app to trusted users, or switch to
        # weights_only=True once the checkpoint format is confirmed to be
        # tensor-only.
        checkpoint = torch.load(str(checkpoint_path), map_location=self.device, weights_only=False)
        if isinstance(checkpoint, dict):
            if "learnable_prompts" not in checkpoint:
                raise gr.Error(f"Checkpoint has no 'learnable_prompts' entry: {checkpoint_path}")
            prompts = checkpoint["learnable_prompts"]
        else:
            prompts = checkpoint
        encoder.learnable_prompts = prompts
        return True

    def _preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """Convert to RGB, resize to ``image_size`` and run the HF processor."""
        image = image.convert("RGB").resize((self.config.image_size, self.config.image_size))
        batch = self.processor(images=image, return_tensors="pt")
        return batch["pixel_values"].to(self.device)

    @torch.inference_mode()
    def infer(
        self,
        image: Image.Image,
        domain: str,
        token_source: str,
        custom_checkpoint: str,
    ) -> Tuple[Image.Image, float]:
        """Score *image* and return ``(heatmap, anomaly_score)``.

        The score is column 1 of the image-level sigmoid scores. With
        ``use_local_to_global`` enabled, the maximum per-patch score is added
        to it, so it can exceed 1.

        Raises:
            gr.Error: If no image is given, or on checkpoint resolution/loading
                problems.
        """
        if image is None:
            raise gr.Error("Please upload an image.")

        # NOTE(review): the text encoder is rebuilt (and any checkpoint
        # reloaded) on every request; caching per (domain, checkpoint) would
        # cut latency.
        checkpoint_path = self._resolve_checkpoint_path(token_source, custom_checkpoint)
        prompt_learn_method = "concat" if checkpoint_path is not None else self.config.prompt_learn_method
        text_encoder = self._build_text_encoder(domain, prompt_learn_method=prompt_learn_method)
        has_learned = self._load_learnable_prompts(text_encoder, checkpoint_path)

        prompt = [self.config.object_name]
        # Non-learned prompt features always drive the image-level score;
        # learned features, when loaded, replace them only for the patch map.
        fixed_text_features = text_encoder(prompt, self.device, learned=False)
        fixed_text_features = fixed_text_features / fixed_text_features.norm(dim=-1, keepdim=True)
        seg_text_features = fixed_text_features
        if has_learned:
            learned_text_features = text_encoder(prompt, self.device, learned=True)
            seg_text_features = learned_text_features / learned_text_features.norm(dim=-1, keepdim=True)

        pixel_values = self._preprocess_image(image)
        vision_features = [feat / feat.norm(dim=-1, keepdim=True) for feat in self.vision_encoder(pixel_values)]

        # Index roles inferred from usage: [1] feeds the image-level score and
        # [2] the per-patch map. Output [0] was previously scored into a value
        # that was never read, so that computation is skipped.
        img_scores = calc_sigm_score_hf(
            vision_features[1], fixed_text_features, self.temperature, self.bias
        ).squeeze(dim=1)
        patch_scores = calc_sigm_score_hf(vision_features[2], seg_text_features, self.temperature, self.bias)
        if self.config.use_local_to_global:
            # Add the per-class maximum over all patches to the image score.
            img_scores = img_scores + patch_scores.max(dim=1).values

        pxl_scr = regrid_upsample_smooth(patch_scores, self.config.image_size, self.config.sigma)
        anomaly_map = pxl_scr[0].cpu().numpy()
        anomaly_score = float(img_scores[0, 1])
        return make_heatmap_rgb(anomaly_map), anomaly_score
|
|
|
|
# Build the configuration and detector once at import time; every Gradio
# request is served by this single instance.
CONFIG: DemoConfig = DemoConfig()
MODEL: TipsomalyDemo = TipsomalyDemo(config=CONFIG)
|
|
|
|
def predict(
    image: Image.Image,
    domain: str,
    token_source: str,
    custom_checkpoint: str,
) -> Tuple[Image.Image, float]:
    """Gradio callback that forwards one request to the shared detector.

    Returns the ``(heatmap, anomaly_score)`` pair produced by ``MODEL.infer``.
    """
    detector = MODEL
    return detector.infer(
        image=image,
        domain=domain,
        token_source=token_source,
        custom_checkpoint=custom_checkpoint,
    )
|
|
|
|
# Gradio UI: image input beside an options column, outputs in a row below,
# and the button wired to ``predict``.
# NOTE(review): nesting reconstructed from a flattened source; confirm the
# button belongs in the options column.
with gr.Blocks(title="Tipsomaly Demo") as demo:
    gr.Markdown(
        "# Tipsomaly Anomaly Detection Demo\n"
        "Upload one image and choose the domain prompt set. "
        "The app returns an anomaly heatmap and image-level anomaly score."
    )

    with gr.Row():
        image_input = gr.Image(label="Input Image", type="pil")
        with gr.Column():
            domain_input = gr.Radio(
                label="Domain",
                choices=["industrial", "medical"],
                value="industrial",
            )
            token_source_input = gr.Radio(
                label="Learnable Tokens",
                info="Use pretrained prompt tokens from workspace checkpoints.",
                choices=["none", "mvtec", "visa", "custom"],
                value="none",
            )
            custom_checkpoint_input = gr.Textbox(
                label="Custom Checkpoint Path",
                placeholder="Optional, used only when Learnable Tokens = custom",
                value="",
            )
            run_btn = gr.Button("Run Detection", variant="primary")

    with gr.Row():
        anomaly_map_output = gr.Image(label="Anomaly Map", type="pil")
        anomaly_score_output = gr.Number(label="Anomaly Score")

    run_btn.click(
        predict,
        inputs=[image_input, domain_input, token_source_input, custom_checkpoint_input],
        outputs=[anomaly_map_output, anomaly_score_output],
    )


if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()
|
|