# -*- coding: utf-8 -*-
"""
Created on Fri Aug 11 18:08:06 2023
@author: prarthana.ts
"""
import torch
import torch.optim as optim
import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner
# import pytorch_lightning as pl
from tqdm import tqdm
from torch.optim.lr_scheduler import OneCycleLR
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import albumentations as A
import cv2
from pytorch_grad_cam.utils.image import show_cam_on_image
import numpy as np
from albumentations.pytorch import ToTensorV2
from utils_for_app import cells_to_bboxes, non_max_suppression, plot_image, YoloCAM
from yolov3 import YOLOv3
from loss import YoloLoss
from utils import LearningRateFinder
# Create your config module or import it from the existing config.py file.
import config
from main_yolov3_lightening import YOLOv3Lightning
import gradio as gr
import os

model = YOLOv3Lightning()
model.load_state_dict(torch.load("yolov3_model_without_75_mosaic.pth", map_location=torch.device('cpu')), strict=False)
model.setup(stage="test")
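# The checkpoint is loaded onto the CPU via map_location; strict=False tolerates
# state-dict keys that do not match the LightningModule exactly, and
# setup(stage="test") runs the module's test-stage setup hook before inference.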
IMAGE_SIZE = 416
transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=IMAGE_SIZE),
        A.PadIfNeeded(
            min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
        ),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
        ToTensorV2(),
    ],
)
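# The transform above letterboxes the input: the longest side is scaled to 416,
# the shorter side is padded to 416 with a constant border, pixel values are
# rescaled to [0, 1], and the result is converted to a CHW tensor.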
ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
]  # Note these have been rescaled to be between [0, 1]
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
scaled_anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)
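# The anchors are fractions of the image size; multiplying config.ANCHORS by the
# grid sizes in config.S (typically 13, 26 and 52 for a 416x416 input) expresses
# each anchor in grid-cell units, the form passed to cells_to_bboxes below.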
def process_image_and_plot(image, iou_threshold=0.5, threshold=0.4):
    transformed_image = transforms(image=image)["image"].unsqueeze(0)
    output = model(transformed_image)
    bboxes = [[] for _ in range(1)]
    for i in range(3):
        # Decode the predictions of each detection scale into bounding boxes.
        batch_size, num_anchors, grid_size, _, _ = output[i].shape
        anchor = scaled_anchors[i]
        boxes_scale_i = cells_to_bboxes(output[i], anchor, S=grid_size, is_preds=True)
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box
    # Merge overlapping detections from all scales with non-max suppression.
    nms_boxes = non_max_suppression(
        bboxes[0], iou_threshold=iou_threshold, threshold=threshold, box_format="midpoint",
    )
    fig = plot_image(transformed_image[0].permute(1, 2, 0), nms_boxes)
    # Grad-CAM heatmap taken from the second-to-last layer of the detector.
    cam = YoloCAM(model=model, target_layers=[model.model.layers[-2]], use_cuda=False)
    grayscale_cam = cam(transformed_image, scaled_anchors)[0, :, :]
    img = cv2.resize(image, (416, 416))
    img = np.float32(img) / 255
    cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True)
    return fig, cam_image
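# Example usage outside Gradio (any RGB image array works; the path below is one
# of the bundled example images):
#   img = cv2.cvtColor(cv2.imread("images/car.jpg"), cv2.COLOR_BGR2RGB)
#   fig, cam_image = process_image_and_plot(img, iou_threshold=0.5, threshold=0.4)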
examples = [
    ["images/cycle.jpg"],
    ["images/human.jpg"],
    ["images/automobile.jpg"],
    ["images/barn.jpg"],
    ["images/car.jpg"],
    ["images/cars.jpg"],
    ["images/farm.jpg"],
    ["images/farms.jpg"],
    ["images/living.jpg"],
    ["images/livings.jpg"],
]
icon_html = '<i class="fas fa-chart-bar"></i>'
title_with_icon = f"""
<div style="background-color: #f5f1f2; padding: 10px; display: flex; align-items: center;">
{icon_html} <span style="margin-left: 10px;">Object Detection on Pascal VOC Dataset with YOLOv3</span>
</div>
"""
description_with_icon = f"""
<div style="background-color: #f1f1f5; padding: 10px; display: flex; align-items: center;">
{icon_html}
<span style="margin-left: 10px;">
<p><strong>PyTorch Lightning Implementation of YOLOv3 Trained from Scratch</strong></p>
<p><strong>Trained Classes:</strong></p>
<ul>
<li>🚶‍♂️ Person: person</li>
<li>🦮 Animal: bird, cat, cow, dog, horse, sheep</li>
<li>🚗 Vehicle: aeroplane, bicycle, boat, bus, car, motorbike, train</li>
<li>🛋️ Indoor: bottle, chair, dining table, potted plant, sofa, TV/monitor</li>
</ul>
<p>Note: a lower IoU threshold combined with a higher confidence threshold usually gives cleaner detections.</p>
</span>
</div>
"""
demo = gr.Interface(
    process_image_and_plot,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="Intersection over Union (IoU) Threshold",
                  info="How much overlap between two boxes is allowed before the lower-scoring one is discarded"),
        gr.Slider(0, 1, value=0.4, label="Confidence Threshold",
                  info="Boxes with confidence scores below this value are filtered out; a higher value removes weaker detections"),
    ],
    outputs=[
        gr.Plot(label="Output with Classes"),
        gr.Image(shape=(32, 32), label="GradCAM Output"),
    ],
    title=title_with_icon,
    description=description_with_icon,
    examples=examples,
)
demo.launch()
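# launch() starts a local Gradio server; on Hugging Face Spaces this is picked up
# automatically, while locally you can pass share=True for a temporary public URL.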