GerryRaz's picture
Fix the text in the desc
c9bf638 verified
import gradio as gr
import os
import torch
#from model import create_densenet_model
import model
from timeit import default_timer as timer
import os
from pathlib import Path
from pathlib import Path
# 1. Resolve files relative to this script's folder rather than the current
#    working directory, so the app works no matter where it is launched from.
current_dir = Path(__file__).parent
# 2. Load class names, one per line; line N is the label for model output N.
#    Trailing blank lines (e.g. an extra newline at EOF) are dropped so they
#    cannot add bogus "" classes; interior lines are kept so indices never shift.
with open(current_dir / "class_names.txt", "r", encoding="utf-8") as f:
    class_names = [line.strip() for line in f.read().rstrip().splitlines()]
# Verify the number of classes loaded
print(f"Loaded {len(class_names)} classes.")
# Number of outputs the saved checkpoint was built with (the UI description
# also advertises 120 breeds); the architecture must match it exactly.
NUM_CLASSES = 120
if len(class_names) != NUM_CLASSES:
    # Fail at startup with a clear message instead of an IndexError (or
    # silently missing labels) on the first prediction.
    raise ValueError(
        f"class_names.txt lists {len(class_names)} classes, "
        f"but the model expects {NUM_CLASSES}."
    )
# Build the network together with its matching preprocessing transforms.
model_1, transforms = model.create_model(num_classes=NUM_CLASSES)
# Resolve the checkpoint next to this file (like class_names.txt) so loading
# does not depend on the working directory. The file is a plain state_dict,
# so weights_only=True suffices and avoids unpickling arbitrary objects.
state_dict = torch.load(
    f=current_dir / "30_epoch_model_efficientv2_2_93%_acc_dog_bread_classifier.pth",
    weights_only=True,
    map_location="cpu",
)
model_1.load_state_dict(state_dict)
def predict_img(img):
    """Classify a dog photo and report the five most likely breeds.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when Submit is pressed without an uploaded image.

    Returns:
        A ``(labels_to_probs, seconds)`` tuple: a dict mapping up to five class
        names to their softmax probabilities (highest first), and the
        wall-clock prediction time rounded to 5 decimal places.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        # Show a friendly message in the UI instead of crashing in `transforms`.
        raise gr.Error("Please upload an image of a dog first.")
    start_time = timer()
    # Preprocess with the model's own transforms and add a batch dimension.
    batch = transforms(img).unsqueeze(0)
    model_1.eval()
    with torch.inference_mode():
        # Probabilities for the single image, limited to outputs that have a label.
        pred_probs = torch.softmax(model_1(batch), dim=1)[0, : len(class_names)]
        # topk returns the best scores already sorted in descending order.
        top_probs, top_idxs = torch.topk(pred_probs, k=min(5, pred_probs.numel()))
    pred_labels_and_probs = {
        class_names[idx]: prob
        for prob, idx in zip(top_probs.tolist(), top_idxs.tolist())
    }
    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time
title = "Dog Breed Classifier"
description = (
    "Upload a photo of your dog here to identify its breed! "
    "Our AI analyzes 120 different types to give you the top 5 most likely "
    "matches in seconds. Simply drag and drop your image, click submit, and "
    "see the results. Fast, fun, and accurate dog breed classification at "
    "your fingertips."
)
article = "Created at Mauaque Resettlement Center Gozales Compound"
# Example images live next to this file (not necessarily in the CWD); sort
# them so the gallery order is stable across platforms and restarts.
examples_dir = current_dir / "examples"
example_list = [
    [str(path)] for path in sorted(examples_dir.iterdir()) if path.is_file()
]
demo = gr.Interface(
    fn=predict_img,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=5, label="Predictions"),
        gr.Number(label="Prediction Time"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)
demo.launch(debug=True)