# Hugging Face Spaces status shown when this file was captured: Runtime error
import logging
import os

import gradio as gr
from PIL import Image
from tqdm import tqdm

from yolo import YOLO
# Build the detector once at import time; predict_image and predict_directory
# below both use this shared module-level instance.
yolo = YOLO()
def predict_image(image, crop=False, count=True):
    """Run the YOLO detector on a single image.

    Args:
        image: The image to run detection on, as delivered by the
            ``gr.Image(type="pil")`` input. ``None`` short-circuits to ``None``.
        crop: Forwarded unchanged to ``yolo.detect_image``.
        count: Forwarded unchanged to ``yolo.detect_image``.

    Returns:
        Whatever ``yolo.detect_image`` returns. Returns ``None`` when there is
        no input image or when detection raised an exception.
    """
    if image is None:
        # gr.Image presumably yields None when nothing was uploaded; there is
        # nothing to detect, so skip the model instead of logging a bogus error.
        return None
    try:
        return yolo.detect_image(image, crop=crop, count=count)
    except Exception:
        # UI boundary: keep the app responsive, but record the full traceback.
        # (A bare print of the message lost where the failure happened.)
        logging.getLogger(__name__).exception("YOLO detection failed")
        return None
def predict_directory(input_dir, output_dir, crop=False, count=True): | |
""" | |
Predict images in a directory using YOLO model and save results to another directory | |
""" | |
img_names = os.listdir(input_dir) | |
results = [] | |
for img_name in tqdm(img_names): | |
if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')): | |
image_path = os.path.join(input_dir, img_name) | |
image = Image.open(image_path) | |
r_image = yolo.detect_image(image, crop=crop, count=count) | |
if not os.path.exists(output_dir): | |
os.makedirs(output_dir) | |
output_path = os.path.join(output_dir, img_name.replace(".jpg", ".png")) | |
r_image.save(output_path, quality=95, subsampling=0) | |
results.append((img_name, output_path)) | |
return results | |
def inference(image, mode='predict', crop=False, count=True, input_dir=None, output_dir=None):
    """Route a UI request to single-image or whole-directory prediction.

    Args:
        image: The image passed to ``predict_image`` in ``'predict'`` mode.
        mode: Either ``'predict'`` or ``'dir_predict'``.
        crop: Forwarded to the selected prediction function.
        count: Forwarded to the selected prediction function.
        input_dir: Source directory; required (non-empty) for ``'dir_predict'``.
        output_dir: Destination directory; required (non-empty) for ``'dir_predict'``.

    Returns:
        The result of ``predict_image`` or ``predict_directory``.

    Raises:
        ValueError: If *mode* is unrecognised, or if ``'dir_predict'`` is
            requested without both directories.
    """
    if mode == 'predict':
        return predict_image(image, crop=crop, count=count)

    # Anything that is not a fully specified directory request is rejected.
    if mode != 'dir_predict' or not (input_dir and output_dir):
        raise ValueError("Invalid mode or missing directories for 'dir_predict' mode.")

    return predict_directory(input_dir, output_dir, crop=crop, count=count)
# Page text shown above the interface.
title = "YOLO Image Prediction"
description = "This demo allows you to perform image prediction using a YOLO model. You can either predict a single image or all images in a directory."

# Let images keep their natural size rather than being shrunk to the frame.
css = """
.image-frame img, .image-container img {
    width: auto;
    height: auto;
    max-width: none;
}
"""

# Input widgets, listed in the positional order of inference()'s parameters.
_input_components = [
    gr.Image(type="pil", label="Input Image"),
    gr.Radio(choices=["predict", "dir_predict"], label="Mode", value="predict"),
    gr.Checkbox(value=False, label="Crop"),
    gr.Checkbox(value=True, label="Count"),
    # NOTE(review): both directory boxes are hidden (visible=False), so users
    # cannot fill them in and 'dir_predict' presumably always hits inference's
    # ValueError. Even with directories supplied, predict_directory returns a
    # list of tuples, which likely does not fit the gr.Image output below.
    gr.Textbox(placeholder="Input directory (for 'dir_predict' mode)", label="Input Directory", visible=False),
    gr.Textbox(placeholder="Output directory (for 'dir_predict' mode)", label="Output Directory", visible=False),
]
_output_component = gr.Image(type="pil", label="Output Image")

demo = gr.Interface(
    fn=inference,
    inputs=_input_components,
    outputs=_output_component,
    title=title,
    description=description,
    css=css,
)

if __name__ == "__main__":
    demo.launch()