|
import logging
import os
import tempfile

import gradio as gr
from PIL import Image as PILImg

from iteach_toolkit.DHYOLO import DHYOLODetector
|
|
|
|
|
# Attach a default stderr handler to the root logger at INFO level (a no-op if
# the root logger already has handlers), then grab this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
|
|
|
def detect_objects(selected_model, input_image, conf, imu_threshold, detections):
    """Run DH-YOLO door-handle detection on one image and draw the boxes.

    Click handler for the "Run" button in ``create_interface``.

    Args:
        selected_model: Name of the checkpoint to use. It is looked up in the
            module-level ``model_options`` dict, which ``create_interface``
            populates.
        input_image: PIL image from the input component, or ``None`` when the
            component is empty.
        conf: Confidence threshold, passed through to ``DHYOLODetector.predict``.
        imu_threshold: IoU threshold, passed through to ``DHYOLODetector.predict``.
        detections: Maximum number of detections (from the slider). It is
            passed to ``DHYOLODetector.predict`` as an ``int``.

    Returns:
        ``(input_image, annotated)``, where ``annotated`` is a PIL image built
        from the detector's ``plot_bboxes`` output. Returns ``(None, None)`` on
        any failure; the failure is logged rather than raised, so the UI stays up.
    """
    if input_image is None:
        logger.error("No input image provided")
        return None, None

    input_image_path = None
    try:
        model_path = model_options.get(selected_model)
        if model_path is None:
            logger.error("Unknown model selected: %s", selected_model)
            return None, None

        dhyolo = DHYOLODetector(model_path)

        # predict() takes a file path, so write the image to a *unique* temp
        # file. A fixed name in the CWD gets clobbered by concurrent requests
        # and was never cleaned up.
        fd, input_image_path = tempfile.mkstemp(prefix="dhyolo_input_", suffix=".jpg")
        os.close(fd)
        # JPEG cannot store alpha/palette modes (e.g. an RGBA PNG), so normalise to RGB.
        rgb_image = input_image if input_image.mode == "RGB" else input_image.convert("RGB")
        rgb_image.save(input_image_path)

        # Slider values may arrive as floats; the max-detection count must be integral.
        _, predictions = dhyolo.predict(input_image_path, conf, imu_threshold, int(detections))
        logger.info("Detections: %s", predictions)

        # plot_bboxes() receives no image, so it appears to render from state the
        # detector kept during predict(). The detector is therefore created per call
        # and not shared between requests.
        _, image_with_bboxes = dhyolo.plot_bboxes(attach_watermark=True)
        return input_image, PILImg.fromarray(image_with_bboxes)

    except FileNotFoundError as e:
        logger.error("File not found: %s", e)
        return None, None
    except Exception as e:  # UI boundary: log with traceback instead of crashing the handler
        logger.exception("An error occurred: %s", e)
        return None, None
    finally:
        if input_image_path is not None:
            try:
                os.remove(input_image_path)
            except OSError as e:
                logger.warning("Could not remove temporary file %s: %s", input_image_path, e)
|
|
|
def load_test_images(test_imgs_dir=None):
    """Return the sorted file names of the PNG/JPEG images in a directory.

    Args:
        test_imgs_dir: Directory to scan. Defaults to ``<cwd>/test_imgs``, so
            existing zero-argument callers keep their behavior.

    Returns:
        A list of bare file names (not full paths). The list is sorted so the
        order is stable across platforms and runs. The extension match is
        case-insensitive. Sub-directories are skipped even if their name ends
        in an image extension.

    Raises:
        FileNotFoundError: If the directory does not exist (from ``os.listdir``).
    """
    if test_imgs_dir is None:
        test_imgs_dir = os.path.join(os.getcwd(), "test_imgs")

    logger.info("Loading images from: %s", test_imgs_dir)

    image_exts = (".png", ".jpg", ".jpeg")
    return sorted(
        name
        for name in os.listdir(test_imgs_dir)
        if name.lower().endswith(image_exts)
        and os.path.isfile(os.path.join(test_imgs_dir, name))
    )
|
|
|
|
|
def _build_model_options(base_dir):
    """Map each selectable model name to its checkpoint path under ``base_dir/pretrained_ckpts``."""
    # The exp31 display names ("exp31") differ from their file names ("exp-31"),
    # so the name -> file mapping is spelled out explicitly.
    checkpoint_files = {
        "dh-yolo-v1-pb-ddf-524": "dh-yolo-v1-pb-ddf-524.pt",
        "dh-yolo-exp27-pb-1008": "dh-yolo-exp27-pb-1008.pt",
        "dh-yolo-exp31-pb-1532": "dh-yolo-exp-31-pb-1532.pt",
        "dh-yolo-exp31-pl-1532": "dh-yolo-exp-31-pl-1532.pt",
    }
    ckpt_dir = os.path.join(base_dir, "pretrained_ckpts")
    return {name: os.path.join(ckpt_dir, fname) for name, fname in checkpoint_files.items()}


def _load_default_image(path):
    """Return the image at ``path`` read fully into memory, or ``None`` if it cannot be opened."""
    try:
        # Copy inside the context manager so the file handle is closed now,
        # rather than staying open until something forces PIL's lazy load.
        with PILImg.open(path) as img:
            return img.copy()
    except FileNotFoundError:
        logger.error("Default image not found at: %s", path)
    except OSError as e:
        # e.g. PIL.UnidentifiedImageError for a corrupt/unsupported file. Without
        # this, the error would abort building the whole UI.
        logger.error("Could not read default image at %s: %s", path, e)
    return None


def create_interface():
    """Build (but do not launch) the DH-YOLO Gradio Blocks demo.

    Side effect: (re)binds the module-level ``model_options`` dict (model name
    -> checkpoint path under ``<cwd>/pretrained_ckpts``). ``detect_objects``
    reads this dict when the Run button is clicked.

    Returns:
        The ``gr.Blocks`` app.
    """
    global model_options
    cwd = os.getcwd()
    model_options = _build_model_options(cwd)

    with gr.Blocks() as demo:
        # NOTE(review): the emoji in these headings look mojibake'd (UTF-8 emoji
        # bytes decoded with the wrong codec); restore the intended characters.
        gr.Markdown("<h1 style='text-align: center;'>πͺπ DHYOLO DoorHandle Object Detection</h1>")
        gr.Markdown("<h2 style='text-align: center;'>π iTeach: Interactive Teaching for Robot Perception using Mixed Reality</h2>")
        gr.Markdown("<h2 style='text-align: center;'>π Project Link: <a href='https://irvlutd.github.io/iTeach/' target='_blank'>iTeach Project</a></h2>")

        with gr.Row():
            with gr.Column():
                default_image = _load_default_image(
                    os.path.join(cwd, "test_imgs", "jpad-irvl-test.jpg")
                )
                input_image = gr.Image(type="pil", label="Input Image", value=default_image)

                # The dropdown yields a model *name*; detect_objects maps it to a path.
                model_name = gr.Dropdown(
                    choices=list(model_options.keys()),
                    label="Select Pretrained Model",
                    value="dh-yolo-v1-pb-ddf-524",
                )
                conf = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, step=0.01, value=0.5)
                imu_threshold = gr.Slider(label="IoU Threshold", minimum=0.0, maximum=1.0, step=0.01, value=0.5)
                detections = gr.Slider(label="Max number of Detections", minimum=1, maximum=100, step=1, value=10)

            with gr.Column():
                output_image = gr.Image(label="Output Image with DH-YOLO Detections", type="pil")
                detect_button = gr.Button("Run")

        detect_button.click(
            detect_objects,
            inputs=[model_name, input_image, conf, imu_threshold, detections],
            outputs=[input_image, output_image],
        )

    return demo
|
|
|
if __name__ == "__main__":
    # Script entry point: assemble the Blocks UI, then start Gradio's server
    # with its default launch settings.
    demo = create_interface()
    demo.launch()
|
|