# Space status at capture time: Runtime error
import gradio as gr | |
import cv2 | |
import requests | |
import os | |
from ultralytics import YOLO | |
# Directory containing the trained YOLO weight files (*.pt).
model_weights_dir = 'best_weights'

# Discover every "*.pt" file in the weights directory. os.listdir() returns
# entries in arbitrary, filesystem-dependent order, so sort them to keep the
# order of models (and therefore the index each dropdown entry maps to)
# stable across restarts.
model_paths = [
    os.path.join(model_weights_dir, filename)
    for filename in sorted(os.listdir(model_weights_dir))
    if filename.endswith('.pt')
]
# Without at least one model the app cannot do anything useful; fail at
# startup with a clear message instead of at the first prediction request.
if not model_paths:
    raise FileNotFoundError(
        f"No '*.pt' model weights found in {model_weights_dir!r}"
    )

# Load one YOLO model per weight file; models[i] corresponds to model_paths[i].
models = [YOLO(model_path) for model_path in model_paths]

# Display names for the UI: file name without directory or ".pt" extension,
# aligned index-for-index with `models`.
model_names = [os.path.splitext(os.path.basename(model_path))[0] for model_path in model_paths]

# Example rows for the interface; each row supplies only the input image.
examples = [
    ["plot.JPG"],
    ["plot2.JPG"],
]
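# NOTE(review): disabled earlier draft, kept for reference only. As written it
# would fail: boxes.xyxy holds just the 4 box coordinates, so det[4] is out of
# range; per-box class ids are exposed separately (boxes.cls) — confirm before reviving.
# class_colors = [(0, 255, 0), (0, 0, 255), (255, 0, 0), (0, 255, 255), (255, 255, 0), (255, 0, 255)]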
# def show_preds_image(image_path, selection): | |
# image = cv2.imread(image_path) | |
# outputs = models[selection].predict(source=image_path) | |
# results = outputs[0].cpu().numpy() | |
# for i, det in enumerate(results.boxes.xyxy): | |
# class_id = int(det[4]) # Assuming class ID is at index 4 | |
# class_name = model_names[selection] + "_" + str(class_id) # Combine model name and class ID | |
# color = class_colors[class_id % len(class_colors)] # Use modulo to handle more classes than colors | |
# # Draw rectangle and put text on the image | |
# cv2.rectangle(image, | |
# (int(det[0]), int(det[1])), | |
# (int(det[2]), int(det[3])), | |
# color=color, | |
# thickness=2, | |
# lineType=cv2.LINE_AA | |
# ) | |
# font = cv2.FONT_HERSHEY_SIMPLEX | |
# font_scale = 0.5 | |
# cv2.putText(image, class_name, (int(det[0]), int(det[1]) - 5), font, font_scale, color, thickness=1) | |
# return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | |
def _resolve_model_index(selection):
    """Translate a model-dropdown value into a valid index into ``models``.

    Accepts an integer index (what an index-typed dropdown sends), a model
    name from ``model_names``, or a numeric string.

    Raises:
        gr.Error: If nothing is selected or the value does not identify a
            loaded model.
    """
    if selection is None or selection == "":
        raise gr.Error("Please select a model.")
    if isinstance(selection, str) and selection in model_names:
        return model_names.index(selection)
    try:
        index = int(selection)
    except (TypeError, ValueError):
        raise gr.Error(f"Unknown model selection: {selection!r}") from None
    if not 0 <= index < len(models):
        raise gr.Error(f"Unknown model selection: {selection!r}")
    return index


def show_preds_image(image_path, selection):
    """Run the selected YOLO model on an image and draw its detections.

    Args:
        image_path: Path to the input image on disk.
        selection: Which model to use — an index into ``models``, a name from
            ``model_names``, or a numeric string (the form depends on how the
            dropdown is configured and on the Gradio version).

    Returns:
        The image as an RGB numpy array with every predicted bounding box drawn
        as a red rectangle.

    Raises:
        gr.Error: If no image is given, the image cannot be read, or the
            selection does not identify a loaded model. Gradio shows the
            message to the user instead of a generic failure.
    """
    if not image_path:
        raise gr.Error("Please upload an image.")
    model = models[_resolve_model_index(selection)]

    # cv2.imread signals failure by returning None rather than raising.
    image = cv2.imread(image_path)
    if image is None:
        raise gr.Error(f"Could not read image: {image_path}")

    # predict() returns a list of results; with a single image path we use the
    # first. Converting to CPU numpy lets the coordinates go through int().
    outputs = model.predict(source=image_path)
    results = outputs[0].cpu().numpy()
    for det in results.boxes.xyxy:
        x1, y1, x2, y2 = (int(v) for v in det[:4])
        cv2.rectangle(
            image,
            (x1, y1),
            (x2, y2),
            color=(0, 0, 255),  # red in OpenCV's BGR channel order
            thickness=2,
            lineType=cv2.LINE_AA,
        )
    # OpenCV images are BGR; the numpy Image output is displayed as RGB.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Inputs: the image to analyse and which YOLO model to run. The dropdown uses
# plain string choices with type="index", so the callback receives the integer
# position of the chosen model in `models`. It defaults to the first model, so
# an example row (which supplies only an image) can be submitted directly.
inputs = [
    gr.components.Image(type="filepath", label="Input Image"),
    gr.components.Dropdown(
        choices=model_names,
        type="index",
        value=model_names[0] if model_names else None,
        label="Select Model",
    ),
]
outputs_image = [
    gr.components.Image(type="numpy", label="Output Image"),
]

# Create the Gradio interface
interface_image = gr.Interface(
    fn=show_preds_image,
    inputs=inputs,
    outputs=outputs_image,
    title="Paddy Growth Stage Recognition",
    description="Select an image and a YOLO model to detect the growth stage",
    examples=examples,
    cache_examples=False,
)

# Wrap the interface in a tabbed layout and start the server with request
# queuing enabled.
demo = gr.TabbedInterface(
    [interface_image],
    tab_names=['Image inference'],
)
demo.queue().launch()