Spaces:
Runtime error
Runtime error
import tempfile
from pathlib import Path

import gradio as gr
import torch
import torchvision.transforms as T
from PIL import Image
# Build the detector and apply the fine-tuned weights from best_p6.pt.
# NOTE(review): this Space is reported as "Runtime error". Confirm that the
# 'ultralytics/yolov8' repository really exposes a torch.hub entry point named
# 'yolov8n'; if it does not, this call is the failure and the model has to be
# constructed another way.
model = torch.hub.load('ultralytics/yolov8', 'yolov8n')
# map_location="cpu" lets a checkpoint that was saved on a GPU machine be
# deserialized on a CPU-only host; nothing in this script moves tensors to
# another device, so CPU is where the weights are used anyway.
# NOTE(review): load_state_dict() needs a plain state dict. Confirm that
# best_p6.pt is one and not a full training checkpoint wrapping the model.
model.load_state_dict(torch.load("best_p6.pt", map_location="cpu"))
model.eval()  # evaluation mode: no dropout / batch-norm statistic updates

# Preprocessing applied to every uploaded image before it reaches the model.
# Only a tensor conversion is done here; add resizing/normalisation if the
# training pipeline used any (TODO confirm against the training setup).
transform = T.Compose([T.ToTensor()])
def process_image(image: "Image.Image | None") -> "str | None":
    """Run the detector on one image and return the path of the annotated copy.

    Args:
        image: The uploaded picture, or ``None`` when nothing was uploaded.

    Returns:
        Path of a JPEG file holding the image with the detections drawn on
        it, or ``None`` when no image was supplied.
    """
    # The UI hands over None when the form is submitted without a picture;
    # return "no file" instead of failing inside the transform.
    if image is None:
        return None

    image_tensor = transform(image).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(image_tensor)

    # NOTE(review): `.render()` is expected to return a list of annotated
    # image arrays. Confirm that the object `model` returns for a tensor
    # input actually provides it; adjust here if the model's output differs.
    result_image = outputs.render()[0]
    result_pil_image = Image.fromarray(result_image)

    # Each call gets its own directory so that simultaneous requests cannot
    # overwrite one another's result; the file name itself stays
    # "output_image.jpg" so the download keeps its familiar name.
    # The directory is not removed here because the caller still has to read
    # the file after this function returns.
    output_path = Path(tempfile.mkdtemp(prefix="waste_detection_")) / "output_image.jpg"
    result_pil_image.save(output_path)
    return str(output_path)
# Assemble the web UI: one image upload in, one downloadable file out.
_image_input = gr.Image(type="pil")  # hands process_image a PIL image
_file_output = gr.File(label="Download Processed Image")  # annotated result

iface = gr.Interface(
    process_image,
    _image_input,
    _file_output,
    title="Waste Detection",
    description="Upload an image of floating waste, and the model will detect and label the objects in it.",
)

# Start serving the app.
iface.launch()