import os
import cv2
import math
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image
from glob import glob
from utils.plots import Annotator, colors
from utils.augmentations import letterbox
from models.common import DetectMultiBackend
from utils.general import non_max_suppression, scale_boxes
from utils.torch_utils import select_device, smart_inference_mode
from pytorch_grad_cam import EigenCAM
import torchvision.transforms as transforms
from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image

weights = "runs/train/best_striped.pt"
data = "data.yaml"

# Load model
device = select_device('cpu')
model = DetectMultiBackend(weights=weights, device=device, fp16=False, data=data)
# EigenCAM hooks the second-to-last module of the underlying network.
target_layers = [model.model.model[-2]]

# Sorted so the gallery order is stable across platforms and runs (glob order is arbitrary).
false_detection_data = sorted(glob(os.path.join("false_detection", '*.jpg')))
false_detection_data = [x.replace('\\', '/') for x in false_detection_data]


def resize_image_pil(image, new_width, new_height):
    """Scale an image to fit inside (new_width, new_height), then crop/pad to that exact size.

    The aspect ratio is preserved. Because the crop box is anchored at the
    top-left and may exceed the resized image, any uncovered area is filled by
    PIL's crop padding.

    Args:
        image: Anything ``np.array`` can turn into an image array.
        new_width: Target width in pixels.
        new_height: Target height in pixels.

    Returns:
        A PIL image of size (new_width, new_height).
    """
    img = Image.fromarray(np.array(image))
    width, height = img.size
    scale = min(new_width / width, new_height / height)
    resized = img.resize((int(width * scale), int(height * scale)), Image.NEAREST)
    return resized.crop((0, 0, new_width, new_height))


def display_false_detection_data(false_detection_data, number_of_samples):
    """Plot a grid (5 images per row) of sample images the model is known to misdetect.

    Args:
        false_detection_data: List of image file paths.
        number_of_samples: Requested number of images. It is cast to int and
            clamped to ``[0, len(false_detection_data)]``.

    Returns:
        A matplotlib Figure containing the image grid.
    """
    number_of_samples = max(0, min(int(number_of_samples), len(false_detection_data)))
    x_count = 5
    # ceil, not floor: a partial last row still needs its own row of axes,
    # otherwise add_subplot is asked for an index past the grid.
    y_count = max(1, math.ceil(number_of_samples / x_count))

    fig = plt.figure(figsize=(10, 10))
    for idx, path in enumerate(false_detection_data[:number_of_samples]):
        img = cv2.imread(path)
        if img is None:
            # Missing/unreadable file: skip it instead of failing the whole request.
            continue
        ax = fig.add_subplot(y_count, x_count, idx + 1)
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        ax.set_xticks([])
        ax.set_yticks([])
    return fig


def _detect_and_annotate(input_img, conf_thres, iou_thres):
    """Run YOLO detection on an RGB image array and return a copy with labelled boxes."""
    stride, names = model.stride, model.names
    img0 = input_img.copy()

    img = letterbox(img0, 640, stride=stride, auto=True)[0]
    # Gradio already supplies RGB, which is what the model expects, so only
    # HWC -> CHW is needed. YOLO's detect.py also flips channels, but only
    # because it loads images with cv2 (BGR).
    img = np.ascontiguousarray(img.transpose(2, 0, 1))
    img = torch.from_numpy(img).to(device).float() / 255.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        pred = model(img, augment=False, visualize=False)
    pred = non_max_suppression(pred, conf_thres, iou_thres, classes=None, max_det=1000)

    annotator = Annotator(img0, line_width=2, example=str(names))
    for det in pred:  # one entry per image in the batch (batch size is 1 here)
        if not len(det):
            continue
        # Rescale boxes from the letterboxed size back to the original image size.
        det[:, :4] = scale_boxes(img.shape[2:], det[:, :4], img0.shape).round()
        for *xyxy, conf, cls in reversed(det):
            c = int(cls)
            # bgr=False: the canvas is RGB, so request RGB colour tuples.
            annotator.box_label(xyxy, f'{names[c]} {conf:.2f}', color=colors(c, False))
    # result() works whether the Annotator drew with cv2 (in place) or PIL (on a copy).
    return annotator.result()


def _eigen_cam_image(input_img):
    """Return an EigenCAM heat-map overlay for the input image resized to 640x640."""
    img_gc = np.float32(cv2.resize(input_img, (640, 640))) / 255
    tensor = transforms.ToTensor()(img_gc).unsqueeze(0)
    cam = EigenCAM(model, target_layers)
    grayscale_cam = cam(tensor)[0, :, :]
    return show_cam_on_image(img_gc, grayscale_cam, use_rgb=True)


def inference(input_img, conf_thres, iou_thres, is_eigen_cam=True,
              is_false_detection_images=True, num_false_detection_images=10):
    """Gradio callback: detect objects, optionally add EigenCAM and a false-detection gallery.

    Args:
        input_img: RGB image array from ``gr.Image`` (``None`` if nothing was uploaded).
        conf_thres: Confidence threshold passed to NMS.
        iou_thres: IoU threshold passed to NMS.
        is_eigen_cam: Whether to compute the EigenCAM overlay.
        is_false_detection_images: Whether to build the false-detection gallery.
        num_false_detection_images: Number of gallery images to show.

    Returns:
        Tuple ``(annotated_image, cam_image_or_None, gallery_figure_or_None)``.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_img is None:
        raise gr.Error("Please provide an input image.")

    annotated = _detect_and_annotate(input_img, conf_thres, iou_thres)

    if is_false_detection_images:
        misclassified_images = display_false_detection_data(
            false_detection_data, number_of_samples=num_false_detection_images)
    else:
        misclassified_images = None

    cam_image = _eigen_cam_image(input_img) if is_eigen_cam else None

    return annotated, cam_image, misclassified_images


title = "YOLOv9 model to detect shirt/tshirt"
description = "A simple Gradio interface to infer on YOLOv9 model and detect tshirt in image"
examples = [[f"image_{i}.jpg", 0.25, 0.45, True, True, 10] for i in range(1, 11)]

demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(width=320, height=320, label="Input Image"),
        gr.Slider(0, 1, 0.25, label="Confidence Threshold"),
        gr.Slider(0, 1, 0.45, label="IoU Threshold"),
        gr.Checkbox(label="Show Eigen CAM"),
        gr.Checkbox(label="Show False Detection"),
        gr.Slider(5, 35, value=10, step=5, label="Number of False Detection"),
    ],
    outputs=[
        gr.Image(width=640, height=640, label="Output"),
        gr.Image(label="EigenCAM"),
        gr.Plot(label="False Detection"),
    ],
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    demo.launch()