import argparse
from functools import partial
import gradio as gr
from torch.nn import functional as F
from torch import nn
from dataset import get_data_transforms
from PIL import Image
import os
from utils import get_gaussian_kernel
# Needed for deterministic cuBLAS behaviour
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
import torch
import cv2
import numpy as np
# Model-related modules
from models import vit_encoder
from models.uad import INP_Former
from models.vision_transformer import Mlp, Aggregation_Block, Prototype_Block
# Configurations
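# CUDA_LAUNCH_BLOCKING forces synchronous kernel launches, which makes CUDA errors easier to trace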
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
parser = argparse.ArgumentParser(description='')
# model info
parser.add_argument('--encoder', type=str, default='dinov2reg_vit_base_14')
parser.add_argument('--input_size', type=int, default=448)
parser.add_argument('--crop_size', type=int, default=392)
parser.add_argument('--INP_num', type=int, default=6)
args = parser.parse_args()
############ Init Model
ckt_path1 = 'weights/Real-IAD/model.pth'
ckt_path2 = 'weights/Real-IAD/model.pth'  # NOTE: currently points to the same Real-IAD checkpoint as ckt_path1
data_transform, _ = get_data_transforms(args.input_size, args.crop_size)
# device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Adopting a grouping-based reconstruction strategy similar to Dinomaly
target_layers = [2, 3, 4, 5, 6, 7, 8, 9]
fuse_layer_encoder = [[0, 1, 2, 3], [4, 5, 6, 7]]
fuse_layer_decoder = [[0, 1, 2, 3], [4, 5, 6, 7]]
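# The eight target layers are fused into a shallow group and a deep group on both the encoder and decoder side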
# Encoder info
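# Load the pretrained ViT backbone named by --encoder (default: DINOv2-reg ViT-Base/14)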
encoder = vit_encoder.load(args.encoder)
if 'small' in args.encoder:
    embed_dim, num_heads = 384, 6
elif 'base' in args.encoder:
    embed_dim, num_heads = 768, 12
elif 'large' in args.encoder:
    embed_dim, num_heads = 1024, 16
    # The large encoder has more blocks, so deeper layers are targeted
    target_layers = [4, 6, 8, 10, 12, 14, 16, 18]
else:
    raise ValueError("Architecture not in small, base, large.")
# Model Preparation
Bottleneck = []
INP_Guided_Decoder = []
INP_Extractor = []
# bottleneck
Bottleneck.append(Mlp(embed_dim, embed_dim * 4, embed_dim, drop=0.))
Bottleneck = nn.ModuleList(Bottleneck)
# INP
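# Learnable prototype tokens: a single ParameterList entry holding args.INP_num tokens of dimension embed_dim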
INP = nn.ParameterList(
[nn.Parameter(torch.randn(args.INP_num, embed_dim))
for _ in range(1)])
# INP Extractor
for i in range(1):
    blk = Aggregation_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=4.,
                            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-8))
    INP_Extractor.append(blk)
INP_Extractor = nn.ModuleList(INP_Extractor)
# INP_Guided_Decoder
for i in range(8):
    blk = Prototype_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=4.,
                          qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-8))
    INP_Guided_Decoder.append(blk)
INP_Guided_Decoder = nn.ModuleList(INP_Guided_Decoder)
model = INP_Former(encoder=encoder, bottleneck=Bottleneck, aggregation=INP_Extractor, decoder=INP_Guided_Decoder,
                   target_layers=target_layers, remove_class_token=True, fuse_layer_encoder=fuse_layer_encoder,
                   fuse_layer_decoder=fuse_layer_decoder, prototype_token=INP)
model = model.to(device)
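# Gaussian kernel used to smooth the upsampled anomaly map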
gaussian_kernel = get_gaussian_kernel(kernel_size=5, sigma=4).to(device)
def resize_and_center_crop(image, resize_size=448, crop_size=392):
    # Resize to resize_size x resize_size
    image_resized = cv2.resize(image, (resize_size, resize_size), interpolation=cv2.INTER_LINEAR)
    # Compute crop coordinates
    start = (resize_size - crop_size) // 2
    end = start + crop_size
    # Center crop to crop_size x crop_size
    image_cropped = image_resized[start:end, start:end, :]
    return image_cropped
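# With the defaults, e.g., a 600x600 BGR array is resized to 448x448 and then center-cropped to 392x392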
def process_image(image, options):
    # Load the checkpoint matching the selected option
    if 'Real-IAD' in options:
        model.load_state_dict(torch.load(ckt_path1, map_location=torch.device('cpu')), strict=True)
    elif 'VisA' in options:
        model.load_state_dict(torch.load(ckt_path2, map_location=torch.device('cpu')), strict=True)
    else:
        # Fall back to the Real-IAD checkpoint if no valid option is provided
        model.load_state_dict(torch.load(ckt_path1, map_location=torch.device('cpu')), strict=True)
        print('Invalid option. Defaulting to Real-IAD.')
    # Ensure the image is in RGB mode
    image = image.convert('RGB')
    # Convert the PIL image to a NumPy array
    np_image = np.array(image)
    image_shape = np_image.shape[0]  # original height, used to resize the visualisation back
    # Convert RGB to BGR for OpenCV
    np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
    np_image = resize_and_center_crop(np_image, resize_size=args.input_size, crop_size=args.crop_size)
    # Preprocess the image and run the model
    input_image = data_transform(image)
    input_image = input_image.to(device)
    with torch.no_grad():
        _ = model(input_image.unsqueeze(0))
        # model.distance holds a per-patch anomaly score; reshape it into a square spatial map
        anomaly_map = model.distance
        side = int(model.distance.shape[1] ** 0.5)
        anomaly_map = anomaly_map.reshape([anomaly_map.shape[0], side, side]).contiguous()
        anomaly_map = torch.unsqueeze(anomaly_map, dim=1)
        # Upsample to the input resolution and smooth
        anomaly_map = F.interpolate(anomaly_map, size=input_image.shape[-1], mode='bilinear', align_corners=True)
        anomaly_map = gaussian_kernel(anomaly_map)
    # Process the anomaly map: scale to 0-255 for visualisation (scores are assumed to lie in [0, 1])
    anomaly_map = anomaly_map.squeeze().cpu().numpy()
    anomaly_map = (anomaly_map * 255).astype(np.uint8)
    # Apply a color map and blend with the preprocessed image
    heat_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET)
    vis_map = cv2.addWeighted(heat_map, 0.5, np_image, 0.5, 0)
    # Convert the OpenCV image back to a PIL image for Gradio
    vis_map_pil = Image.fromarray(cv2.resize(cv2.cvtColor(vis_map, cv2.COLOR_BGR2RGB), (image_shape, image_shape)))
    return vis_map_pil
# Define examples
examples = [
    ["assets/img2.png", "Real-IAD"],
    ["assets/img.png", "VisA"]
]
# Gradio interface layout
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Radio(["Real-IAD", "VisA"], label="Pre-trained Datasets")
    ],
    outputs=[
        gr.Image(type="pil", label="Output Image")
    ],
    examples=examples,
    title="INP-Former -- Zero-shot Anomaly Detection",
    description="Upload an image and select a pre-trained dataset to perform zero-shot anomaly detection."
)
# Launch the demo
demo.launch()
# demo.launch(server_name="0.0.0.0", server_port=10002)
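# A temporary public link can also be created via Gradio's built-in share option, e.g.:
# demo.launch(share=True)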