File size: 4,114 Bytes
2ded624 49a7401 d79e14c 688a048 bddc3d6 2ded624 49a7401 aadf03a 2ded624 aadf03a 688a048 164d735 688a048 3802471 688a048 164d735 688a048 49a7401 688a048 49a7401 826a99e 49a7401 bda76e2 30a5c0f 49a7401 bda76e2 49a7401 688a048 f7724ab e6445f0 f7724ab e6445f0 f7724ab e47bd96 f7724ab 49a7401 f7724ab 49a7401 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import os
import torch
from torch import nn
from torchvision import models
from torchvision.transforms import v2
from huggingface_hub import hf_hub_download
import gradio as gr
# Morph names the classifier can predict. The position of each name matters:
# predict() pairs labels[i] with the i-th model output, so do not reorder.
labels = [
    "Pastel", "Yellow Belly", "Enchi", "Clown", "Leopard",
    "Piebald", "Orange Dream", "Fire", "Mojave", "Pinstripe",
    "Banana", "Normal", "Black Pastel", "Lesser", "Spotnose",
    "Cinnamon", "GHI", "Hypo", "Spider", "Super Pastel",
    "Desert Ghost", "Black Head", "Vanilla", "Red Stripe", "Asphalt",
    "Gravel", "Butter", "Calico", "Albino", "Chocolate",
]

# Size of the classifier's output layer (one logit per morph).
num_labels = len(labels)
# Device used as the map_location when deserializing the checkpoint:
# CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face access token read from the environment (None when unset).
hf_token = os.getenv('HF_token')

# Fetch the trained checkpoint from the Hub and get its local path.
model_path = hf_hub_download(
    repo_id="samfhy/morphmarket_model",
    filename="model_v13_1_epoch9.pt",
    token=hf_token,
)

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# repository you trust (consider weights_only=True if the checkpoint allows it).
checkpoint = torch.load(model_path, map_location=device)

# Classification head that replaces EfficientNet's stock classifier. The lazy
# linear layers take their input width from the loaded weights / first input.
new_layers = nn.Sequential(
    nn.LazyLinear(2048),
    nn.BatchNorm1d(2048),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.LazyLinear(num_labels),
)

# Side length (pixels) of the square input, as stored in the checkpoint.
IMAGE_SIZE = checkpoint['image_size']

# Preprocessing applied to every incoming image before inference.
# NOTE(review): ToDtype is called without scale=True, so pixel values are only
# cast to float32, not rescaled, before Normalize — confirm this matches the
# preprocessing used during training before changing it.
transform = v2.Compose([
    v2.ToImage(),
    v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    v2.ToDtype(torch.float32),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Build the architecture only. Every parameter and buffer is overwritten by the
# strict load_state_dict() below, so downloading the pretrained ImageNet
# weights first would just waste bandwidth and startup time.
efficientnet = models.efficientnet_v2_l(weights=None)
efficientnet.classifier = new_layers
efficientnet.load_state_dict(checkpoint['model_state_dict'])

# Inference mode: disables dropout and makes batch norm use running statistics.
efficientnet.eval()
def predict(img, confidence):
    """Classify the morphs visible in an image.

    Args:
        img: Image to classify, in a form accepted by the module-level
            ``transform`` pipeline, or ``None`` when no image was supplied.
        confidence: Threshold; only labels whose probability is strictly
            greater than this value are returned.

    Returns:
        A dict mapping each label name to its sigmoid probability, restricted
        to labels above ``confidence``. Empty when ``img`` is ``None`` or no
        label clears the threshold.
    """
    # The button can be clicked before any image is provided; without this
    # guard the preprocessing below fails on None.
    if img is None:
        return {}

    # Preprocess and add a leading batch dimension of size 1.
    batch = transform(img).unsqueeze(0)
    # Keep the input on whatever device the model's weights live on.
    batch = batch.to(next(efficientnet.parameters()).device)

    with torch.no_grad():
        logits = efficientnet(batch)

    # Sigmoid is applied per output, so each label gets its own independent
    # probability (several labels may pass the threshold at once).
    probs = torch.sigmoid(logits).to('cpu').flatten().tolist()

    return {
        label: prob
        for label, prob in zip(labels, probs)
        if prob > confidence
    }
# Introductory text shown under the page heading.
_INTRO_MD = """
If you're unfamiliar with snakes, ball pythons come in various patterns and colors,
called *morphs*, which can be difficult to distinguish without expert knowledge.
This tool automatically identifies these unique variations, making identification accessible to everyone.
Try selecting one of the examples and click "Identify Morphs" to see how it works!
"""

# Alphabetical list of supported morphs shown inside the collapsible section.
_SUPPORTED_MORPHS_MD = """
Albino, Asphalt, Banana, Black Head, Black Pastel, Butter, Calico, Chocolate, Cinnamon, Clown,
Desert Ghost, Enchi, Fire, GHI, Gravel, Hypo, Leopard, Lesser, Mojave, Normal,
Orange Dream, Pastel, Piebald, Pinstripe, Red Stripe, Spider, Spotnose, Super Pastel, Vanilla, Yellow Belly
"""

# Example rows: image file followed by a description of the morphs it shows.
# NOTE(review): only one input component is registered with gr.Examples, so the
# description strings are presumably not displayed — confirm against the
# installed Gradio version.
_EXAMPLE_ROWS = [
    ["enchi_albino_clown.png", "Enchi, Albino, Clown"],
    ["mojave_ghi.png", "Mojave, GHI"],
    ["hypo_banana_pastel_enchi.png", "Hypo, Banana, Pastel, Enchi"],
    ["yb_pastel_gravel.png", "Yellow Belly, Pastel, Gravel"],
    ["ivory.png", "Super Yellow Belly"],
]

# Page layout: heading and help text on top, then a two-column row with the
# inputs on the left and the predicted labels on the right.
with gr.Blocks(title='Ball Python Morph Identifier') as demo:
    gr.Markdown("# Ball Python Morph Identifier")
    gr.Markdown("Upload or paste an image of your ball python to identify its morphs!")
    gr.Markdown(_INTRO_MD)

    with gr.Accordion("Click here to show all the morphs that can be predicted", open=False):
        gr.Markdown(_SUPPORTED_MORPHS_MD)

    with gr.Row():
        # Left column: image, examples, threshold and the trigger button.
        with gr.Column(scale=1):
            img_input = gr.Image(type="pil", label="Upload/Paste Image")
            gr.Examples(examples=_EXAMPLE_ROWS, inputs=[img_input])
            confidence = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                label="Confidence",
                info="Show predictions that are above this confidence level",
            )
            predict_btn = gr.Button("Identify Morphs", variant="primary")

        # Right column: prediction results.
        with gr.Column(scale=1):
            label_output = gr.Label(label="Predicted Morphs")

    # Run predict() on the current image and threshold when the button is clicked.
    predict_btn.click(
        fn=predict,
        inputs=[img_input, confidence],
        outputs=label_output,
    )

# Start the web server for the app.
demo.launch()
|