curiouscurrent's picture
Update app.py
bc99e01 verified
import tempfile

import gradio as gr
import torch
import torchvision.transforms as T
from PIL import Image
# Load the trained detector once, at import time, so every request reuses it.
#
# NOTE(review): loading via torch.hub.load('ultralytics/yolov8', 'yolov8n') followed by
# model.load_state_dict(torch.load("best_p6.pt")) cannot work:
#   - 'ultralytics/yolov8' is not a torch.hub repository;
#   - an Ultralytics training checkpoint (best*.pt) is a pickled dict that contains the
#     whole model (plus optimizer/epoch metadata), not a bare state_dict, so
#     load_state_dict rejects it.
# The inference code uses the YOLOv5 hub results API (`.render()`), so the weights are
# loaded through YOLOv5's 'custom' hub entry point, which rebuilds the model from the
# checkpoint and wraps it for PIL/numpy input.
# Assumes best_p6.pt was trained with YOLOv5 (the "p6" suffix suggests a YOLOv5 P6
# variant) — TODO confirm. YOLOv8 weights would instead need the `ultralytics` package
# (`YOLO("best_p6.pt")`).
model = torch.hub.load("ultralytics/yolov5", "custom", path="best_p6.pt")
model.eval()

# Image -> tensor conversion. Hub models accept PIL images directly, so this is only
# needed by code that wants to build raw tensors itself; kept for compatibility.
transform = T.Compose([T.ToTensor()])
# Define the inference function
def process_image(image):
    """Run waste detection on an uploaded image and return the annotated result file.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None`` if the
            input was left empty.

    Returns:
        Path to a JPEG file with the model's detections drawn on it, or ``None`` when
        no image was provided (an empty ``gr.File`` output).
    """
    if image is None:
        # Nothing uploaded: return an empty output instead of crashing on None.
        return None

    with torch.no_grad():
        # Pass the PIL image itself. The hub model's wrapper handles resizing,
        # normalisation and NMS only for PIL/numpy inputs; a raw tensor bypasses that
        # and yields unprocessed network output, which has no .render().
        outputs = model(image)

    # render() draws the boxes/labels; [0] is the single image we passed in.
    # Presumably a uint8 RGB numpy array (YOLOv5 Detections API) — verify on upgrade.
    result_image = outputs.render()[0]
    result_pil_image = Image.fromarray(result_image)

    # Write to a unique temporary file: a fixed path would be shared by every request,
    # so concurrent users could receive each other's output.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        result_pil_image.save(tmp, format="JPEG")
    return tmp.name
# Gradio UI: a single uploaded image goes in, a downloadable annotated file comes out.
image_input = gr.Image(type="pil")  # user-supplied image, delivered to fn as a PIL image
download_output = gr.File(label="Download Processed Image")  # result offered as a file

iface = gr.Interface(
    fn=process_image,
    inputs=image_input,
    outputs=download_output,
    title="Waste Detection",
    description="Upload an image of floating waste, and the model will detect and label the objects in it.",
)

# Start the Gradio server.
iface.launch()