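# Spaces: Sleeping
# (The line above appears to be page-header text captured together with the code; it is not part of the program.)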
import os | |
import cv2 | |
import math | |
import torch | |
import numpy as np | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
from glob import glob | |
from utils.plots import Annotator, colors | |
from utils.augmentations import letterbox | |
from models.common import DetectMultiBackend | |
from utils.general import non_max_suppression, scale_boxes | |
from utils.torch_utils import select_device, smart_inference_mode | |
from pytorch_grad_cam import EigenCAM | |
import torchvision.transforms as transforms | |
from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image | |
# Checkpoint and dataset description handed to the detector backend.
weights = "runs/train/best_striped.pt"
data = "data.yaml"

# Build the detector once at import time, on the CPU with fp16 disabled.
device = select_device("cpu")
model = DetectMultiBackend(weights=weights, data=data, device=device, fp16=False)

# Layer handed to EigenCAM: the second-to-last entry of the wrapped network.
target_layers = [model.model.model[-2]]

# Paths of the stored false-detection samples, normalised to forward slashes.
false_detection_data = [
    path.replace("\\", "/")
    for path in glob(os.path.join("false_detection", "*.jpg"))
]
def resize_image_pil(image, new_width, new_height):
    """Fit *image* inside a ``new_width`` x ``new_height`` box and return a PIL image.

    The image is scaled uniformly with nearest-neighbour resampling by the
    largest factor that keeps both sides within the target box, then cropped
    from the top-left corner to exactly ``(new_width, new_height)``. When the
    scaled image is smaller than the box along one side, the crop extends past
    the image edge on that side.

    Args:
        image: Anything ``np.array`` can turn into an image array.
        new_width: Target width in pixels.
        new_height: Target height in pixels.

    Returns:
        A PIL image of size ``(new_width, new_height)``.
    """
    img = Image.fromarray(np.array(image))
    width, height = img.size

    # Uniform scale: the tighter of the two per-axis ratios.
    scale = min(new_width / width, new_height / height)

    # round() instead of int() so float error (e.g. 99.999...) cannot drop a
    # pixel on the limiting side; the floor of 1 keeps a very elongated input
    # from producing a zero-sized dimension, which Image.resize rejects.
    scaled_size = (
        max(1, round(width * scale)),
        max(1, round(height * scale)),
    )
    resized = img.resize(scaled_size, Image.NEAREST)

    # Crop to the exact requested size, anchored at the top-left corner.
    return resized.crop((0, 0, new_width, new_height))
def display_false_detection_data(false_detection_data, number_of_samples):
    """Plot the first ``number_of_samples`` false-detection images in a 5-column grid.

    Args:
        false_detection_data: Sequence of image file paths to read with OpenCV.
        number_of_samples: How many images to show. Values larger than the
            number of available paths are clamped; non-integer numbers are
            truncated.

    Returns:
        The matplotlib figure holding the grid. The figure is empty when there
        is nothing to show.
    """
    fig = plt.figure(figsize=(10, 10))
    x_count = 5

    # Never request more images than exist.
    sample_count = min(int(number_of_samples), len(false_detection_data))
    if sample_count <= 0:
        return fig

    # Round up so a partially filled last row still gets grid slots
    # (flooring made plt.subplot fail for counts such as 7).
    y_count = math.ceil(sample_count / x_count)

    for i, path in enumerate(false_detection_data[:sample_count]):
        plt.subplot(y_count, x_count, i + 1)
        img = cv2.imread(path)
        # cv2.imread returns None for a missing or unreadable file; leave that
        # cell blank instead of crashing in cvtColor.
        if img is not None:
            plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        plt.xticks([])
        plt.yticks([])
    return fig
def inference(input_img, conf_thres, iou_thres, is_eigen_cam=True, is_false_detection_images=True, num_false_detection_images=10):
    """Run the detector on one image and build the optional visualisations.

    Args:
        input_img: Image array supplied by the Gradio image input.
        conf_thres: Confidence threshold passed to non-max suppression.
        iou_thres: IoU threshold passed to non-max suppression.
        is_eigen_cam: Whether to compute the EigenCAM overlay.
        is_false_detection_images: Whether to build the false-detection figure.
        num_false_detection_images: Number of false-detection samples to plot.

    Returns:
        A tuple ``(annotated_image, cam_image, misclassified_images)``;
        ``cam_image`` and ``misclassified_images`` are ``None`` when the
        corresponding flag is off.

    Raises:
        gr.Error: If no input image was supplied.
    """
    if input_img is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please provide an input image.")

    stride, names = model.stride, model.names

    # Letterbox to the 640 input size and convert HWC -> CHW.
    img0 = input_img.copy()
    img = letterbox(img0, 640, stride=stride, auto=True)[0]
    # NOTE(review): this reverses the channel order before the model sees the
    # image, while the EigenCAM branch below feeds input_img unflipped. If the
    # Gradio input is already RGB this hands the detector BGR data -- confirm
    # against the training pipeline before changing.
    img = img[:, :, ::-1].transpose(2, 0, 1)
    img = np.ascontiguousarray(img)
    img = torch.from_numpy(img).to(device).float()
    img /= 255.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)  # add the batch dimension

    # Gradients are not needed for detection; skip autograd bookkeeping.
    with torch.no_grad():
        pred = model(img, augment=False, visualize=False)
        pred = non_max_suppression(pred, conf_thres, iou_thres, classes=None, max_det=1000)

    # Draw every detection onto the original-resolution image.
    annotator = Annotator(img0, line_width=2, example=str(model.names))
    for det in pred:  # per image
        if len(det):
            # Rescale boxes from the letterboxed size back to img0's size.
            det[:, :4] = scale_boxes(img.shape[2:], det[:, :4], img0.shape).round()
            for *xyxy, conf, cls in reversed(det):
                c = int(cls)  # integer class
                label = f'{names[c]} {conf:.2f}'
                annotator.box_label(xyxy, label, color=colors(c, True))

    if is_false_detection_images:
        # Plot the stored misclassified samples.
        misclassified_images = display_false_detection_data(false_detection_data, number_of_samples=num_false_detection_images)
    else:
        misclassified_images = None

    if is_eigen_cam:
        # EigenCAM runs on a plain 640x640 resize of the input scaled to [0, 1].
        img_GC = cv2.resize(input_img, (640, 640))
        img_GC = np.float32(img_GC) / 255
        transform = transforms.ToTensor()
        tensor = transform(img_GC).unsqueeze(0)
        cam = EigenCAM(model, target_layers)
        grayscale_cam = cam(tensor)[0, :, :]
        cam_image = show_cam_on_image(img_GC, grayscale_cam, use_rgb=True)
    else:
        cam_image = None

    # Take the image from the annotator so the boxes are present whichever
    # drawing backend it used internally.
    return annotator.result(), cam_image, misclassified_images
# Heading and blurb shown by the Gradio interface.
title = "YOLOv9 model to detect shirt/tshirt"
description = "A simple Gradio interface to infer on YOLOv9 model and detect tshirt in image"

# One example row per sample image (image_1.jpg .. image_10.jpg). Every row
# uses the same settings: confidence 0.25, IoU 0.45, both toggles on, and
# 10 false-detection samples.
examples = [
    [f"image_{index}.jpg", 0.25, 0.45, True, True, 10]
    for index in range(1, 11)
]
# Gradio UI. The six input widgets are passed positionally to inference()
# and the three outputs receive its three return values in order.
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(width=320, height=320, label="Input Image"),
        gr.Slider(0, 1, 0.25, label="Confidence Threshold"),
        gr.Slider(0, 1, 0.45, label="IoU Threshold"),
        gr.Checkbox(label="Show Eigen CAM"),
        gr.Checkbox(label="Show False Detection"),
        gr.Slider(5, 35, value=10, step=5, label="Number of False Detection"),
    ],
    outputs=[
        gr.Image(width=640, height=640, label="Output"),
        gr.Image(label="EigenCAM"),
        gr.Plot(label="False Detection"),
    ],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()