Aastha
Add application file
94cd336
raw
history blame
No virus
3.91 kB
import torch
from efficientnet_pytorch import EfficientNet
from torchvision import transforms
from PIL import Image
import gradio as gr
from super_gradients.training import models
import cv2
import numpy as np
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the YOLO-NAS (large) detector with COCO-pretrained weights and place it
# on the same device as the classifier, so detection also uses the GPU when
# one is present (`nn.Module.to` returns the module itself).
yolo_nas_l = models.get("yolo_nas_l", pretrained_weights="coco").to(device)
def bounding_boxes_overlap(box1, box2):
    """Tell whether two ``(x_min, y_min, x_max, y_max)`` boxes intersect.

    Only boxes that are strictly separated on some axis are rejected, so
    boxes that merely share an edge or a corner count as overlapping.
    """
    left_a, top_a, right_a, bottom_a = box1
    left_b, top_b, right_b, bottom_b = box2
    horizontally_apart = left_b > right_a or right_b < left_a
    vertically_apart = top_b > bottom_a or bottom_b < top_a
    return not (horizontally_apart or vertically_apart)
def merge_boxes(box1, box2):
    """Return the smallest ``(x_min, y_min, x_max, y_max)`` box enclosing both inputs."""
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2
    # Top-left corner takes the minima, bottom-right corner takes the maxima.
    return (min(ax1, bx1), min(ay1, by1), max(ax2, bx2), max(ay2, by2))
def save_merged_boxes(predictions, image_np):
    """Return the crop of ``image_np`` enclosing the first overlapping pair of detections.

    Despite the name, nothing is written to disk and at most one region is
    produced. Detections are scanned pairwise. The first two distinct boxes
    that overlap or touch are merged into their enclosing box, and that box
    is cut out of ``image_np``.

    Args:
        predictions: Iterable of per-image results, each exposing
            ``.prediction.bboxes_xyxy`` (one ``x1, y1, x2, y2`` row per box).
        image_np: Image array indexed ``[row, col, ...]`` to crop from.

    Returns:
        A view into ``image_np`` covering the merged box. Returns ``None``
        when no pair overlaps, or when every merged box is empty once
        clipped to the image.
    """
    img_h, img_w = image_np.shape[:2]
    for image_prediction in predictions:
        bboxes = image_prediction.prediction.bboxes_xyxy
        # Visit each unordered pair of distinct detections once. Pairing by
        # position (instead of skipping coordinate-equal boxes) still lets two
        # separate detections with identical coordinates be merged.
        for i, box1 in enumerate(bboxes):
            for box2 in bboxes[i + 1:]:
                if not bounding_boxes_overlap(box1, box2):
                    continue
                x1, y1, x2, y2 = merge_boxes(box1, box2)
                # Clip to the image. Otherwise a negative coordinate would be
                # read as an index from the end and crop the wrong region.
                left, right = max(0, int(x1)), min(img_w, int(x2))
                top, bottom = max(0, int(y1)), min(img_h, int(y2))
                if right <= left or bottom <= top:
                    continue  # degenerate box: never return an empty crop
                return image_np[top:bottom, left:right]
    return None
# Load the EfficientNet model
def load_model(model_path):
    """Load a fully pickled PyTorch model and prepare it for inference.

    Args:
        model_path: Path to a file holding the whole model object (not just a
            ``state_dict``), because the result is used directly as a module.

    Returns:
        The model, moved to ``device`` and switched to evaluation mode.
    """
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # Without it, torch.load tries to restore tensors onto CUDA and fails.
    # NOTE(review): torch.load unpickles arbitrary objects, so only load
    # trusted files. On PyTorch >= 2.6 `weights_only` defaults to True, which
    # rejects whole-model pickles; pass `weights_only=False` there if loading
    # fails.
    model = torch.load(model_path, map_location=device)
    model = model.to(device)
    model.eval()  # inference mode for layers such as dropout / batch norm
    return model
# Perform inference on an image
def predict_image(image, model):
    """Classify whether an image shows a vehicle accident.

    YOLO-NAS first detects objects. If two detections overlap, the crop
    covering both is classified; otherwise the whole image is classified.

    Args:
        image: Input ``PIL.Image.Image`` (any mode; it is converted to RGB),
            or ``None`` when the user submitted without uploading.
        model: Classifier producing per-class scores along dim 1, where
            class index 0 means "Accident".

    Returns:
        Tuple ``(roi_image, prediction_text)``: the PIL image that was
        classified and either ``'Accident'`` or ``'No accident'``.
    """
    if image is None:
        return None, "Please upload an image."

    # Work in RGB throughout. convert() also normalises RGBA/greyscale inputs,
    # which would otherwise give NumPy arrays with 4 or 1 channels.
    rgb_np = np.array(image.convert("RGB"))

    # NOTE(review): SuperGradients treats NumPy inputs as RGB (the layout it
    # produces itself from a PIL image). The former RGB->BGR conversion fed
    # the detector channel-swapped colours. Confirm for the pinned version.
    predictions = yolo_nas_l.predict(rgb_np, iou=0.3, conf=0.35)
    roi = save_merged_boxes(predictions, rgb_np)
    if roi is None:
        roi = rgb_np  # no overlapping detections: classify the whole image

    # The crop is already RGB, so PIL can take it without colour conversion.
    # The contiguous copy detaches it from the full frame.
    roi_image = Image.fromarray(np.ascontiguousarray(roi))

    # ImageNet-style preprocessing for the EfficientNet classifier.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    roi_image_tensor = transform(roi_image).unsqueeze(0).to(device)  # add batch dim

    with torch.no_grad():
        outputs = model(roi_image_tensor)
    _, predicted = outputs.max(1)
    prediction_text = 'Accident' if predicted.item() == 0 else 'No accident'
    return roi_image, prediction_text
# Load the EfficientNet classifier once at start-up rather than on every request.
model_path = 'vehicle.pt'
model = load_model(model_path)

# Gradio UI
title = "Vehicle Collision Classification"
description = "Upload an image to determine if it depicts a vehicle accident. Powered by EfficientNet."
# Example inputs offered in the UI (paths relative to the working directory).
examples = [["roi_none.png"], ["test2.jpeg"]]


def classify(img):
    """Gradio callback: run the detection + classification pipeline on *img*."""
    return predict_image(img, model)


# `gr.Image` replaces the `gr.inputs.Image` / `gr.outputs.Image` namespaces,
# which are deprecated in Gradio 3.x and removed in Gradio 4.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil"), "text"],  # classified crop + label
    title=title,
    description=description,
    examples=examples,
)
demo.launch()