ERA_Session12 / app.py
bala1802's picture
Upload app.py
1f80462
raw
history blame
No virus
5.46 kB
import os
from io import BytesIO
from pathlib import Path
from random import shuffle
import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from mini_resnet import CustomResNet
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision import transforms as T
mean = (0.49139968, 0.48215841, 0.44653091)
std = (0.24703223, 0.24348513, 0.26158784)
transforms = T.Compose([T.ToTensor(), T.Normalize(mean=mean, std=std)])
classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")
softmax = torch.nn.Softmax(dim=0)
model = CustomResNet()
model.load_state_dict(torch.load("model_weights/weights.pt", map_location=torch.device("cpu")))
model.eval()
misclf_path = "images/miss_classified"
mis_classified_imgs = list(Path(misclf_path).glob("*"))
def get_traget_layer(block: str, layer: int):
layer_num = 0 if layer == 0 else -1
if block == "block1":
return model.layer1[layer_num]
if block == "block2":
return model.layer2[layer_num]
if block == "block3":
return model.layer3[layer_num]
default_cam = GradCAM(model=model, target_layers=[get_traget_layer("block3", -1)])
def make_image(p: Path | str, pred: str, label: str):
im = cv2.imread(str(p))
im = cv2.resize(im, (64, 64))
plt.imshow(im)
plt.title(f"{pred} / {label}")
plt.axis("off")
buffer = BytesIO()
plt.savefig(buffer, format="png")
buffer.seek(0)
img_array = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
buffer.close()
# Decode the image array using OpenCV
im = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
return im
@torch.inference_mode()
def predict_img(img: np.ndarray, top_k: int = 10):
preds = model(img)
preds = softmax(preds.flatten())
preds = {classes[i]: float(preds[i]) for i in range(10)}
preds = {
k: v for k, v in sorted(preds.items(), key=lambda item: item[1], reverse=True)[:top_k]
}
return preds
def display_cam(cam: GradCAM, org_img: np.ndarray, img: torch.Tensor, transparency: float):
grayscale_cam = cam(input_tensor=img, targets=None)
grayscale_cam = grayscale_cam[0, :]
visualization = show_cam_on_image(
org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency
)
return visualization
def inference(
org_img: np.ndarray,
top_k: int,
show_cam: str,
num_cam_imgs: int,
cam_block: str,
target_layer_num: int,
transparency: float,
show_misclf: str,
num_misclf: int,
):
input_img = transforms(org_img)
input_img = input_img.unsqueeze(0)
preds = predict_img(input_img, top_k)
org_img = display_cam(default_cam, org_img, input_img, transparency)
shuffle(mis_classified_imgs)
cam_outputs = []
if show_cam:
img_list = []
target_layers = [get_traget_layer(cam_block, target_layer_num)]
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
for p in mis_classified_imgs[:num_cam_imgs]:
im = cv2.imread(str(p))
inp_im = transforms(im)
inp_im = inp_im.unsqueeze(0)
grayscale_cam = cam(input_tensor=inp_im, targets=None)
grayscale_cam = grayscale_cam[0, :]
visualization = show_cam_on_image(
im / 255, grayscale_cam, use_rgb=True, image_weight=transparency
)
cam_outputs.append(visualization)
del cam, img_list
misclf_images_output = []
if show_misclf:
img_list = []
gt = []
for p in mis_classified_imgs[:num_misclf]:
img_list.append(transforms(Image.open(p).convert("RGB")))
gt.append(p.name.split("_")[0])
misclf_out = softmax(model(torch.stack(img_list))).argmax(dim=1).tolist()
del img_list
for imp, pred, label in zip(mis_classified_imgs[:num_misclf], misclf_out, gt):
pred = classes[pred]
misclf_images_output.append(make_image(imp, pred, label))
return org_img, preds, cam_outputs, misclf_images_output
title = "CIFAR10 trained on Custom Model inspired by ResNet with GradCAM"
description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
# examples = [["cat.jpg", 0.5, -1], ["dog.jpg", 0.5, -1]]
demo = gr.Interface(
inference,
inputs=[
gr.Image(shape=(32, 32), label="Input Image"),
gr.Slider(1, 10, value=3, step=1, label="Top K predictions"),
gr.Checkbox(label="Show Grad Cam"),
gr.Slider(1, 20, value=5, step=1, label="Number of images"),
gr.Radio(label="Which Block?", choices=["block1", "block2", "block3"]),
gr.Slider(0, 1, value=1, step=1, label="Which Layer?"),
gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM"),
gr.Checkbox(label="Show Misclassified Images"),
gr.Slider(1, 20, value=5, step=5, label="Number of Misclassification Images"),
],
outputs=[
gr.Image(shape=(32, 32), label="Output", width=128, height=128),
"label",
gr.Gallery(label="GradCAM Output"),
gr.Gallery(
label="Misclassified Images Pred/G.T.",
columns=[2],
rows=[2],
object_fit="contain",
height="auto",
),
],
title=title,
description=description,
# examples=examples,
)
demo.launch()