# Author: SyedNaseem
# Commit 372827b (verified): "Create app.py"
import os

import cv2
import gradio as gr
import numpy as np
from ultralytics import YOLO
# Path to the trained YOLO classification weights. The default is the Colab /
# Google Drive location used during training; set the MODEL_PATH environment
# variable to point at another copy (e.g. when deploying the app elsewhere).
MODEL_PATH = os.environ.get(
    "MODEL_PATH",
    "/content/drive/MyDrive/MS-Thesis/Multi-Class Classification/runs/classify/train4/weights/best.pt",
)
# Load the trained model once at import time so every request reuses it.
model = YOLO(MODEL_PATH)
# Prediction function
def predict_image(image):
    """Classify an uploaded image and return it annotated with the prediction.

    Args:
        image: Image from the Gradio ``gr.Image`` input, expected to be an RGB
            array (Gradio ``type="numpy"``). Grayscale and RGBA arrays are
            converted to RGB first.

    Returns:
        A ``(annotated_image, message)`` tuple. ``annotated_image`` is an RGB
        array with "<class>: <confidence>" drawn in the top-left corner, or
        ``None`` on failure. ``message`` is a prediction summary or an
        ``"Error: ..."`` description.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading an image.
        return None, "Error: no image was provided."
    try:
        image_rgb = np.asarray(image)
        # Normalize to 3-channel RGB so the rest of the pipeline can rely on it.
        if image_rgb.ndim == 2 or image_rgb.shape[2] == 1:
            image_rgb = cv2.cvtColor(image_rgb, cv2.COLOR_GRAY2RGB)
        elif image_rgb.shape[2] == 4:
            image_rgb = cv2.cvtColor(image_rgb, cv2.COLOR_RGBA2RGB)

        # Ultralytics treats numpy inputs as BGR (OpenCV channel order).
        image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
        result = model.predict(image_bgr, verbose=False)[0]

        # Map the top-1 class index to its name; take its probability as a float.
        predicted_class = result.names[result.probs.top1]
        confidence = float(result.probs.top1conf)

        # Draw on a copy of the RGB image: (255, 0, 0) is red in RGB, and the
        # result is already in the channel order Gradio displays, so no
        # further conversion is needed (converting again would swap R and B).
        annotated_image = image_rgb.copy()
        cv2.putText(annotated_image, f"{predicted_class}: {confidence:.2f}",
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)
        return annotated_image, f"Predicted: {predicted_class} with {confidence:.2f} confidence"
    except Exception as e:
        # UI boundary: report the failure in the text output instead of crashing.
        return None, f"Error: {str(e)}"
# Define the Gradio interface. The input is requested explicitly as a numpy
# array, which is the format predict_image converts and feeds to the model.
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="numpy", label="Upload an Image"),
    outputs=[gr.Image(label="Annotated Image"), gr.Text(label="Prediction")],
    title="Fruit Freshness Classifier",
    description="Upload an image of a fruit, and the model will predict whether it is Fresh, Mild, or Rotten, and display the result on the image.",
)

# Launch the web UI only when this file is run as a script, so importing the
# module (e.g. from tests or another app) does not start a server.
if __name__ == "__main__":
    interface.launch()