fish-freshness-classifier / fish_freshness_app.py
roqueselopeta's picture
Update fish_freshness_app.py
873ea67 verified
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import gradio as gr
import numpy as np
import cv2
# Load the model class definition
from models.efficientnet_b0 import EfficientNetB0Classifier
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MODEL_PATH = "efficientnet_best9912.pth"  # state dict loaded by load_model()
CLASS_NAMES = ["Fresh", "Not Fresh"]
INPUT_SIZE = 380  # side length (pixels) of the square model input
MODEL_ACCURACY = "99.12%"  # validation accuracy shown in the UI header

# Per-channel normalization statistics (the widely used ImageNet mean/std).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every PIL image before inference: resize to a
# square INPUT_SIZE x INPUT_SIZE image, convert to a tensor, then normalize
# each channel with the statistics above.
preprocess = transforms.Compose(
    [
        transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
def load_model():
    """Construct the EfficientNet-B0 classifier and load its trained weights.

    The state dict stored at MODEL_PATH is mapped onto the CPU, loaded into a
    freshly built ``EfficientNetB0Classifier``, and the network is switched to
    evaluation mode before being returned.

    Returns:
        The ``EfficientNetB0Classifier`` instance, ready for inference.
    """
    net = EfficientNetB0Classifier(train_base=False)
    cpu_device = torch.device('cpu')
    weights = torch.load(MODEL_PATH, map_location=cpu_device)
    net.load_state_dict(weights)
    net.eval()
    return net


# Built once at import time; analyze_image() reads this module-level instance.
model = load_model()
def process_prediction(confidence_score):
    """Turn the model's scalar score into UI-ready prediction details.

    Args:
        confidence_score: Scalar treated as the probability of "Fresh";
            "Not Fresh" is assigned ``1 - confidence_score``.
            NOTE(review): which class the raw model score actually represents
            depends on how the model was trained — confirm.

    Returns:
        A tuple ``(probabilities, prediction, confidence)``: a dict mapping
        "Fresh" and "Not Fresh" to float probabilities, the winning label, and
        that label's probability. A score of exactly 0.5 resolves to
        "Not Fresh".
    """
    p_fresh = float(confidence_score)
    p_not_fresh = float(1.0 - confidence_score)

    if p_fresh > 0.5:
        winner, winner_prob = "Fresh", p_fresh
    else:
        winner, winner_prob = "Not Fresh", p_not_fresh

    distribution = {"Fresh": p_fresh, "Not Fresh": p_not_fresh}
    return distribution, winner, winner_prob
def _ensure_rgb(image):
    """Return *image* as a 3-channel RGB array.

    Expands 2-D grayscale ``(H, W)`` and single-channel ``(H, W, 1)`` arrays
    to three identical channels and drops the alpha channel of 4-channel
    ``(H, W, 4)`` arrays; any other array is returned unchanged.
    """
    if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
        # PIL's Image.fromarray cannot build an image from an (H, W, 1) array,
        # so single-channel input is expanded here, not just the 2-D case.
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.ndim == 3 and image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image


def analyze_image(image):
    """Classify a fish photo and build the values shown in the UI.

    Args:
        image: NumPy image array from the Gradio input component, or ``None``
            when nothing has been provided.

    Returns:
        A 4-tuple ``(probabilities, message, display_image,
        confidence_percentage)`` matching the outputs wired to the Analyze
        button: the ``{"Fresh": p, "Not Fresh": 1 - p}`` dict from
        process_prediction, a one-line summary string, the RGB image resized
        to ``INPUT_SIZE x INPUT_SIZE``, and the winning confidence formatted as
        a percentage string. All four are ``None`` when *image* is ``None``.
    """
    if image is None:
        return None, None, None, None

    image = _ensure_rgb(image)

    # Model input: PIL image -> preprocess pipeline -> add a batch dimension.
    pil_image = Image.fromarray(image).convert('RGB')
    input_tensor = preprocess(pil_image).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor)
    # .item() requires a single-element output; process_prediction treats the
    # value as the "Fresh" probability.
    # NOTE(review): assumes the model ends in a sigmoid so this lies in [0, 1].
    confidence_score = output.item()

    probabilities, prediction, confidence = process_prediction(confidence_score)

    confidence_percentage = f"{confidence * 100:.2f}%"
    message = f"Prediction: {prediction} (Confidence: {confidence_percentage})"

    # The UI routes this array back into the input image component, so the
    # upload is replaced by this square, resized copy after each analysis.
    display_image = cv2.resize(image, (INPUT_SIZE, INPUT_SIZE))
    return probabilities, message, display_image, confidence_percentage
# Custom CSS passed to gr.Blocks(css=...) for the interface below.
# Of the custom classes defined here, only "result-box" is attached to a
# component in this file (the results gr.Group via elem_classes); ".footer",
# ".confidence" and ".container" are not referenced by any elem_classes here.
# NOTE(review): ".gradio-container" and ".gr-button" target Gradio-generated
# class names; confirm they still match elements in the installed Gradio version.
custom_css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
color: white;
border-radius: 8px;
background: linear-gradient(45deg, #4CAF50, #45a049);
border: none;
font-size: 1.2em;
padding: 10px 20px;
}
.gr-button:hover {
background: linear-gradient(45deg, #45a049, #4CAF50);
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(0,0,0,0.1);
}
.footer {
margin-top: 20px;
text-align: center;
font-size: 0.8em;
}
.confidence {
font-size: 1.2em;
font-weight: bold;
margin-top: 10px;
}
.container {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
}
.result-box {
background: #f8f9fa;
border-radius: 10px;
padding: 20px;
margin-top: 20px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
"""
# ---------------------------------------------------------------------------
# Gradio interface: image input and Analyze button on the left, results on the
# right. Clicking the button runs analyze_image() and distributes its 4-tuple
# across the outputs listed in upload_button.click below.
# ---------------------------------------------------------------------------
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown(
        f"""
# 🐟 Fish Freshness Classifier
Upload a fish image and get instant freshness analysis using our advanced AI model.
### Model Performance
- Architecture: EfficientNet-B0
- Validation Accuracy: {MODEL_ACCURACY}
"""
    )

    with gr.Row():
        # Left column: image source and the trigger button.
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Upload Fish Image",
                type="numpy",
                height=400,
                sources=["upload", "webcam", "clipboard"],
            )
            upload_button = gr.Button("📸 Analyze Freshness", variant="primary", size="lg")

        # Right column: classification results (styled by ".result-box") and tips.
        with gr.Column(scale=1):
            with gr.Group(elem_classes="result-box"):
                output_label = gr.Label(
                    num_top_classes=2,
                    label="Freshness Analysis",
                    show_label=True,
                )
                result_message = gr.Textbox(
                    label="Detailed Result",
                    show_copy_button=True,
                )
                confidence_indicator = gr.Textbox(
                    label="Confidence Level",
                    show_copy_button=True,
                )
            gr.Markdown(
                """
### 📝 Best Practices
- Use clear, well-lit images
- Ensure the fish is clearly visible
- Include key features (eyes, gills, skin)
- Avoid blurry or dark photos
"""
            )

    # Set up the prediction flow. The third output is input_image itself, so
    # the resized image returned by analyze_image replaces the upload.
    upload_button.click(
        fn=analyze_image,
        inputs=input_image,
        outputs=[output_label, result_message, input_image, confidence_indicator],
    )

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local server.
    demo.launch(share=True)