|
import functools
import os

import torch
from torch import nn
from torchvision import models
from torchvision.transforms import v2
from huggingface_hub import hf_hub_download
|
|
|
|
|
# Output classes of the fine-tuned model. The order matters: predictions are
# mapped back to names by position, so output column i corresponds to
# labels[i] and this list must stay in the order used at training time.
labels = [
    'Pastel',
    'Yellow Belly',
    'Enchi',
    'Clown',
    'Leopard',
    'Piebald',
    'Orange Dream',
    'Fire',
    'Mojave',
    'Pinstripe',
    'Banana',
    'Normal',
    'Black Pastel',
    'Lesser',
    'Spotnose',
    'Cinnamon',
    'GHI',
    'Hypo',
    'Spider',
    'Super Pastel',
]

# Width of the classifier head's final Linear layer.
num_labels = len(labels)
|
|
|
IMAGE_SIZE = 512

# Preprocessing: resize to a fixed square, scale to [0, 1] floats, then
# normalize with the ImageNet mean/std the DenseNet backbone expects.
# Built once at import time; it holds no per-request state.
_TRANSFORM = v2.Compose([
    v2.ToImage(),
    v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _build_classifier_head():
    """Return the custom head that replaces DenseNet201's classifier.

    Maps the 1920-dim DenseNet201 feature vector to one logit per entry in
    ``labels``. The layout must match the checkpoint's state dict exactly.
    """
    return nn.Sequential(
        nn.Linear(1920, 1000),
        nn.BatchNorm1d(1000),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(1000, num_labels),
    )


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the fine-tuned model once and return ``(model, device)``.

    Cached so the Hugging Face download, the checkpoint load and model
    construction happen on the first request only, not on every call.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # weights=None: the strict load_state_dict below overwrites every
    # parameter and buffer, so downloading the ImageNet weights is wasted work.
    densenet = models.densenet201(weights=None)
    densenet.classifier = _build_classifier_head()

    # Optional access token from the HF_token env var (None when unset).
    hf_token = os.getenv('HF_token')
    model_path = hf_hub_download(repo_id="fohy24/morphmarket_model", filename="model_v8_epoch9.pt", token=hf_token)

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a repository you trust.
    # Load to CPU first, then move the assembled model, so the device does
    # not briefly hold two copies of the weights.
    checkpoint = torch.load(model_path, map_location="cpu")
    densenet.load_state_dict(checkpoint['model_state_dict'])

    densenet.to(device)
    densenet.eval()  # disables dropout; BatchNorm uses running statistics
    return densenet, device


def predict(img, confidence):
    """Return the morphs whose predicted probability exceeds ``confidence``.

    Args:
        img: Image from the Gradio ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without an image.
        confidence: Threshold from the slider. A label is kept only if its
            sigmoid probability is strictly greater than this value.

    Returns:
        dict mapping label name to probability (empty if nothing passes the
        threshold or no image was given), suitable for ``gr.Label``.
    """
    if img is None:
        return {}

    # Normalize expects exactly 3 channels. Grayscale ("L") or RGBA inputs
    # would otherwise fail inside the transform.
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")

    densenet, device = _load_model()

    # Add a batch dimension of 1 and place the input on the model's device.
    input_img = _TRANSFORM(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = densenet(input_img)

    # Multi-label task: an independent sigmoid per class, not a softmax.
    predicted_probs = torch.sigmoid(output).cpu().flatten().tolist()

    return {
        label: prob
        for label, prob in zip(labels, predicted_probs)
        if prob > confidence
    }
|
|
|
import gradio as gr

# Web UI: an uploaded image and a confidence threshold are passed to
# predict(), and the {label: probability} dict it returns is shown by gr.Label.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(
            0,
            1,
            value=0.5,
            label="Confidence",
            info="Show predictions that are above this confidence level",
        ),
    ],
    outputs=gr.Label(),
    examples=[
        ["pastel_yb.png", 0.5],
        ["piebald.png", 0.5],
        ["leopard_fire.png", 0.5],
    ],
    title='Ball Python Morph Identifier',
    description="Upload or paste an image of your ball python to identify its morphs!",
)
demo.launch()
|
|
|
|