|
import torch |
|
import torch.nn.functional as F |
|
from torchvision import transforms |
|
from PIL import Image |
|
import gradio as gr |
|
import numpy as np |
|
import cv2 |
|
|
|
|
|
from models.efficientnet_b0 import EfficientNetB0Classifier |
|
|
|
|
|
# Trained weights (a state_dict); a relative path, so it is resolved against
# the process's current working directory.
MODEL_PATH = "efficientnet_best9912.pth"
# Human-readable class labels.
CLASS_NAMES = ["Fresh", "Not Fresh"]
# Side length, in pixels, of the square image handed to the network.
INPUT_SIZE = 380
# Display string for the reported validation accuracy.
MODEL_ACCURACY = "99.12%"

# Per-channel normalization statistics (the standard ImageNet mean / std
# used by torchvision's pretrained models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# PIL image -> float tensor (C, INPUT_SIZE, INPUT_SIZE), normalized per channel.
preprocess = transforms.Compose(
    [
        transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
|
|
|
|
|
def load_model(path=None, device="cpu"):
    """Build the classifier, load its trained weights and switch to eval mode.

    Args:
        path: Location of the saved ``state_dict``. ``None`` (the default)
            uses ``MODEL_PATH``, read at call time.
        device: Device the checkpoint tensors are mapped onto while loading
            (defaults to CPU).

    Returns:
        An ``EfficientNetB0Classifier`` in evaluation mode.
    """
    if path is None:
        path = MODEL_PATH

    # NOTE(review): train_base=False presumably freezes the backbone; it has
    # no effect on inference here — confirm against the model definition.
    model = EfficientNetB0Classifier(train_base=False)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot run arbitrary code when loaded.
    state_dict = torch.load(
        path, map_location=torch.device(device), weights_only=True
    )
    model.load_state_dict(state_dict)

    # Evaluation mode: layers such as dropout / batch-norm (if present) use
    # their inference behavior.
    model.eval()
    return model


# Loaded once at import time and shared by every request.
model = load_model()
|
|
|
def process_prediction(confidence_score):
    """Turn the model's scalar score into per-class probabilities and a verdict.

    The score is taken as the probability of "Fresh"; its complement
    (``1.0 - score``) is the probability of "Not Fresh".

    Args:
        confidence_score: Scalar model output (anything ``float()`` accepts).

    Returns:
        A 3-tuple ``(probabilities, prediction, confidence)``:
        ``probabilities`` maps each label to its probability, ``prediction``
        is the winning label, and ``confidence`` is that label's probability.
        A score of exactly 0.5 resolves to "Not Fresh".
    """
    scores = {
        "Fresh": float(confidence_score),
        "Not Fresh": float(1.0 - confidence_score),
    }

    if scores["Fresh"] > 0.5:
        winner = "Fresh"
    else:
        winner = "Not Fresh"

    return scores, winner, scores[winner]
|
|
|
def _to_rgb(image):
    """Return *image* as a 3-channel RGB NumPy array.

    Handles 2-D grayscale, ``(H, W, 1)`` single-channel and 4-channel RGBA
    input; 3-channel input (and any other channel count) is returned as-is.
    """
    if image.ndim == 3 and image.shape[2] == 1:
        # PIL's Image.fromarray has no mode for a trailing singleton channel
        # axis, so collapse it to plain 2-D grayscale first.
        image = image[:, :, 0]
    if image.ndim == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[2] == 4:
        # RGBA -> RGB drops the alpha channel.
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image


def analyze_image(image):
    """Classify an uploaded image and format the result for the UI.

    Args:
        image: Image as a NumPy array (as delivered by a ``gr.Image`` with
            ``type="numpy"``), or ``None`` when nothing was uploaded.
            Grayscale ``(H, W)`` / ``(H, W, 1)`` and 4-channel arrays are
            converted to 3-channel RGB first.

    Returns:
        A 4-tuple ``(probabilities, message, display_image, confidence_text)``:
        the label -> probability dict, a one-line summary string, the RGB
        image resized to ``INPUT_SIZE`` x ``INPUT_SIZE``, and the winning
        label's confidence formatted as a percentage string. All four are
        ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None, None, None

    image = _to_rgb(image)

    pil_image = Image.fromarray(image).convert('RGB')
    # unsqueeze(0) adds the batch axis for this single-image batch.
    input_tensor = preprocess(pil_image).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor)
        # .item() requires exactly one output value for the batch.
        # NOTE(review): no sigmoid/softmax is applied here, so this assumes
        # the model already emits the "Fresh" probability in [0, 1] — confirm.
        confidence_score = output.item()

    probabilities, prediction, confidence = process_prediction(confidence_score)

    confidence_percentage = f"{confidence * 100:.2f}%"
    message = f"Prediction: {prediction} (Confidence: {confidence_percentage})"

    # Square INPUT_SIZE x INPUT_SIZE copy returned for display (aspect ratio
    # is not preserved).
    display_image = cv2.resize(image, (INPUT_SIZE, INPUT_SIZE))

    return probabilities, message, display_image, confidence_percentage
|
|
|
|
|
# Stylesheet passed to gr.Blocks(css=...).
# Of the custom classes below, only "result-box" is attached to a component
# (via elem_classes) in this file; .footer, .confidence and .container are
# not referenced by any component here.
# NOTE(review): ".gr-button" is a class name from older Gradio releases;
# newer versions may not emit it on buttons — confirm these rules still apply.
custom_css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
color: white;
border-radius: 8px;
background: linear-gradient(45deg, #4CAF50, #45a049);
border: none;
font-size: 1.2em;
padding: 10px 20px;
}
.gr-button:hover {
background: linear-gradient(45deg, #45a049, #4CAF50);
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(0,0,0,0.1);
}
.footer {
margin-top: 20px;
text-align: center;
font-size: 0.8em;
}
.confidence {
font-size: 1.2em;
font-weight: bold;
margin-top: 10px;
}
.container {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
}
.result-box {
background: #f8f9fa;
border-radius: 10px;
padding: 20px;
margin-top: 20px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
"""
|
|
|
|
|
with gr.Blocks(css=custom_css) as demo:
    # Page header: title, short description and the reported model accuracy.
    # (Emoji restored from mis-encoded characters in the original strings.)
    gr.Markdown(
        f"""
# 🐟 Fish Freshness Classifier

Upload a fish image and get instant freshness analysis using our advanced AI model.

### Model Performance
- Architecture: EfficientNet-B0
- Validation Accuracy: {MODEL_ACCURACY}
"""
    )

    with gr.Row():
        # Left column: image source and the button that triggers analysis.
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Upload Fish Image",
                type="numpy",
                height=400,
                sources=["upload", "webcam", "clipboard"],
            )
            upload_button = gr.Button(
                "📸 Analyze Freshness", variant="primary", size="lg"
            )

        # Right column: classification results, grouped in a styled box.
        with gr.Column(scale=1):
            with gr.Group(elem_classes="result-box"):
                output_label = gr.Label(
                    num_top_classes=2,
                    label="Freshness Analysis",
                    show_label=True,
                )
                result_message = gr.Textbox(
                    label="Detailed Result",
                    show_copy_button=True,
                )
                confidence_indicator = gr.Textbox(
                    label="Confidence Level",
                    show_copy_button=True,
                )

    gr.Markdown(
        """
### 📋 Best Practices
- Use clear, well-lit images
- Ensure the fish is clearly visible
- Include key features (eyes, gills, skin)
- Avoid blurry or dark photos
"""
    )

    # analyze_image returns four values in this order; the third (the resized
    # image) is written back into the upload component itself.
    upload_button.click(
        fn=analyze_image,
        inputs=input_image,
        outputs=[output_label, result_message, input_image, confidence_indicator],
    )
|
|
|
if __name__ == "__main__":
    # Per Gradio's launch() documentation, share=True additionally requests a
    # temporary public share link alongside the local server.
    launch_options = {"share": True}
    demo.launch(**launch_options)
|
|