|
import os |
|
import torch |
|
from torch import nn |
|
from torchvision import models |
|
from torchvision.transforms import v2 |
|
from huggingface_hub import hf_hub_download |
|
import gradio as gr |
|
|
|
# Morph names the classifier can emit. The position of each name is used as
# the index into the model's output vector, so the order must not change.
labels = [
    'Pastel', 'Yellow Belly', 'Enchi', 'Clown', 'Leopard',
    'Piebald', 'Orange Dream', 'Fire', 'Mojave', 'Pinstripe',
    'Banana', 'Normal', 'Black Pastel', 'Lesser', 'Spotnose',
    'Cinnamon', 'GHI', 'Hypo', 'Spider', 'Super Pastel',
    'Desert Ghost', 'Black Head', 'Vanilla', 'Red Stripe', 'Asphalt',
    'Gravel', 'Butter', 'Calico', 'Albino', 'Chocolate',
]

# Width of the classifier head's final layer: one output per morph.
num_labels = len(labels)
|
|
|
|
|
# Device used as the map_location when deserialising the checkpoint.
# The model itself is never moved with .to(device) in this section, so its
# parameters stay on the default device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub access token read from the environment; None when HF_token is unset.
hf_token = os.getenv('HF_token')
model_path = hf_hub_download(
    repo_id="samfhy/morphmarket_model",
    filename="model_v13_1_epoch9.pt",
    token=hf_token,
)
# NOTE(review): torch.load unpickles the downloaded file. If the checkpoint
# only holds tensors and plain Python values, consider weights_only=True.
checkpoint = torch.load(model_path, map_location=device)

# Replacement classifier head: hidden layer of 2048 units, then one logit per
# morph. LazyLinear infers its input width, so the backbone's feature size
# does not have to be spelled out here.
new_layers = nn.Sequential(
    nn.LazyLinear(2048),
    nn.BatchNorm1d(2048),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.LazyLinear(num_labels),
)

# Input resolution is stored alongside the weights in the checkpoint.
IMAGE_SIZE = checkpoint['image_size']
# NOTE(review): ToDtype is called without scale=True, so uint8 pixels are
# presumably not rescaled to [0, 1] before Normalize. Confirm this matches the
# training pipeline before changing it.
transform = v2.Compose([
    v2.ToImage(),
    v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    v2.ToDtype(torch.float32),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Build the architecture without pretrained weights: the strict
# load_state_dict below replaces every parameter and buffer, so fetching the
# ImageNet weights first would only add a large, discarded download.
efficientnet = models.efficientnet_v2_l(weights=None)
efficientnet.classifier = new_layers
efficientnet.load_state_dict(checkpoint['model_state_dict'])
# Inference mode: disables dropout and uses BatchNorm running statistics.
efficientnet.eval()
|
|
|
def predict(img, confidence):
    """Return the morphs whose predicted probability exceeds ``confidence``.

    Args:
        img: Image to classify (the UI supplies a PIL image), or ``None``
            when no image has been provided.
        confidence: Threshold; only labels with a sigmoid probability
            strictly greater than this value are returned.

    Returns:
        A dict mapping morph name to probability. Empty when ``img`` is
        ``None`` or when no label clears the threshold.
    """
    # The button can be clicked before an image is loaded; return the same
    # empty result as "nothing above threshold" instead of crashing inside
    # the transform.
    if img is None:
        return {}

    # Run on whatever device the model's parameters live on, so this function
    # stays correct regardless of where the model was placed.
    model_device = next(efficientnet.parameters()).device
    batch = transform(img).unsqueeze(0).to(model_device)

    with torch.no_grad():
        logits = efficientnet(batch)

    # Independent sigmoid per label: several morphs can be present at once.
    probs = torch.sigmoid(logits).to('cpu').flatten().tolist()

    return {
        label: prob
        for label, prob in zip(labels, probs)
        if prob > confidence
    }
|
|
|
|
|
# Gradio front end: image input, example gallery, and confidence slider on the
# left; predicted morphs on the right.
with gr.Blocks(title='Ball Python Morph Identifier') as demo:
    gr.Markdown("# Ball Python Morph Identifier")
    gr.Markdown("Upload or paste an image of your ball python to identify its morphs!")
    gr.Markdown("""
    If you're unfamiliar with snakes, ball pythons come in various patterns and colors,
    called *morphs*, which can be difficult to distinguish without expert knowledge.
    This tool automatically identifies these unique variations, making identification accessible to everyone.
    Try selecting one of the examples and click "Identify Morphs" to see how it works!
    """)

    with gr.Accordion("Click here to show all the morphs that can be predicted", open=False):
        # Derived from `labels` so the displayed list cannot drift from what
        # the model actually predicts.
        gr.Markdown(", ".join(sorted(labels)))

    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(type="pil", label="Upload/Paste Image")
            # NOTE(review): each example row has two entries but only one
            # input component is wired, so the second entry (the morph
            # description) appears to be unused by the UI - confirm intent.
            gr.Examples(
                examples=[
                    ["enchi_albino_clown.png", "Enchi, Albino, Clown"],
                    ["mojave_ghi.png", "Mojave, GHI"],
                    ["hypo_banana_pastel_enchi.png", "Hypo, Banana, Pastel, Enchi"],
                    ["yb_pastel_gravel.png", "Yellow Belly, Pastel, Gravel"],
                    ["ivory.png", "Super Yellow Belly"],
                ],
                inputs=[img_input],
            )
            confidence = gr.Slider(
                0, 1, value=0.5, label="Confidence",
                info="Show predictions that are above this confidence level",
            )
            predict_btn = gr.Button("Identify Morphs", variant="primary")

        with gr.Column(scale=1):
            label_output = gr.Label(label="Predicted Morphs")

    # Run the classifier on click and show the morph -> probability mapping.
    predict_btn.click(fn=predict, inputs=[img_input, confidence], outputs=label_output)

demo.launch()
|
|
|
|