File size: 4,114 Bytes
2ded624
49a7401
 
 
 
d79e14c
688a048
bddc3d6
2ded624
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49a7401
 
aadf03a
 
 
 
2ded624
aadf03a
 
688a048
164d735
 
 
 
 
 
688a048
3802471
688a048
164d735
 
 
 
 
688a048
 
 
 
 
49a7401
688a048
49a7401
 
 
 
826a99e
49a7401
bda76e2
30a5c0f
49a7401
bda76e2
49a7401
688a048
f7724ab
 
 
e6445f0
 
 
 
 
 
f7724ab
 
 
e6445f0
 
 
 
f7724ab
 
 
 
 
 
 
 
e47bd96
 
 
 
f7724ab
 
 
 
 
 
 
 
 
 
49a7401
f7724ab
49a7401
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import torch
from torch import nn
from torchvision import models
from torchvision.transforms import v2
from huggingface_hub import hf_hub_download
import gradio as gr

# Morph names in classifier-output order: predict() reports output unit i under
# labels[i], so this order must match the label order the checkpoint was
# trained with. Do not sort or reorder.
labels = [
    'Pastel', 'Yellow Belly', 'Enchi', 'Clown', 'Leopard',
    'Piebald', 'Orange Dream', 'Fire', 'Mojave', 'Pinstripe',
    'Banana', 'Normal', 'Black Pastel', 'Lesser', 'Spotnose',
    'Cinnamon', 'GHI', 'Hypo', 'Spider', 'Super Pastel',
    'Desert Ghost', 'Black Head', 'Vanilla', 'Red Stripe', 'Asphalt',
    'Gravel', 'Butter', 'Calico', 'Albino', 'Chocolate',
]

# Number of outputs of the final classifier layer (one logit per morph).
num_labels = len(labels)

# Run on the GPU when one is available, otherwise on the CPU. Both the model
# and every input batch are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the fine-tuned checkpoint from the Hugging Face Hub. The token is read
# from the `HF_token` environment variable (None when it is not set).
hf_token = os.getenv('HF_token')
model_path = hf_hub_download(repo_id="samfhy/morphmarket_model", filename="model_v13_1_epoch9.pt", token=hf_token)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted repo. weights_only=True would be safer if the installed torch
# supports it and the checkpoint holds only tensors/primitives — TODO confirm.
checkpoint = torch.load(model_path, map_location=device)

# Replacement classification head. The LazyLinear input sizes are materialised
# from the checkpoint's tensor shapes by load_state_dict below; the last layer
# emits one logit per entry in `labels`.
new_layers = nn.Sequential(
    nn.LazyLinear(2048),
    nn.BatchNorm1d(2048),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.LazyLinear(num_labels)
    )

# Preprocessing for every uploaded image; the square input size is stored in
# the checkpoint alongside the weights.
IMAGE_SIZE = checkpoint['image_size']
# NOTE(review): ToDtype without scale=True keeps pixel values in [0, 255],
# while the ImageNet mean/std below are meant for [0, 1] inputs. Left as-is
# because inference must mirror the preprocessing the checkpoint was trained
# with — TODO confirm against the training code before changing.
transform = v2.Compose([
    v2.ToImage(),
    v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    v2.ToDtype(torch.float32),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

# Build the EfficientNetV2-L backbone without pretrained ImageNet weights:
# load_state_dict (strict by default) overwrites every parameter and buffer,
# so downloading the ImageNet weights first would only waste startup time.
efficientnet = models.efficientnet_v2_l(weights=None)
efficientnet.classifier = new_layers
efficientnet.load_state_dict(checkpoint['model_state_dict'])
# Put the model on `device`; predict() sends its inputs to the same device.
efficientnet.to(device)
efficientnet.eval()

def predict(img, confidence):
    """Return the morphs whose predicted probability exceeds ``confidence``.

    Args:
        img: Image from the ``gr.Image(type="pil")`` input (a PIL image), or
            ``None`` when the button is clicked without providing one.
        confidence: Threshold from the slider; only morphs whose sigmoid
            probability is strictly greater than it are returned.

    Returns:
        dict mapping morph name -> probability, the format ``gr.Label`` renders.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of a stack trace.
    """
    if img is None:
        raise gr.Error("Please upload or paste an image first.")

    # Add a batch dimension and move the tensor to wherever the model lives.
    input_img = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = efficientnet(input_img)

    # Multi-label output: each morph gets an independent sigmoid probability,
    # so several morphs can be reported for one snake.
    predicted_probs = torch.sigmoid(output).cpu().flatten().tolist()
    return {label: prob for label, prob in zip(labels, predicted_probs) if prob > confidence}


# Gradio Blocks UI: image input, examples, confidence slider and button in the
# left column; the predicted morphs (a gr.Label) in the right column.
with gr.Blocks(title='Ball Python Morph Identifier') as demo:
    gr.Markdown("# Ball Python Morph Identifier")
    gr.Markdown("Upload or paste an image of your ball python to identify its morphs!")
    gr.Markdown("""
        If you're unfamiliar with snakes, ball pythons come in various patterns and colors,
        called *morphs*, which can be difficult to distinguish without expert knowledge.  
        This tool automatically identifies these unique variations, making identification accessible to everyone.  
        Try selecting one of the examples and click "Identify Morphs" to see how it works!
        """)
    
    # Collapsible list of every morph the model can output. It is a
    # hand-maintained, alphabetised copy of `labels`; keep the two in sync.
    with gr.Accordion("Click here to show all the morphs that can be predicted", open=False):
        gr.Markdown("""
        Albino, Asphalt, Banana, Black Head, Black Pastel, Butter, Calico, Chocolate, Cinnamon, Clown,  
        Desert Ghost, Enchi, Fire, GHI, Gravel, Hypo, Leopard, Lesser, Mojave, Normal,  
        Orange Dream, Pastel, Piebald, Pinstripe, Red Stripe, Spider, Spotnose, Super Pastel, Vanilla, Yellow Belly
        """)
    
    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" makes Gradio hand predict() a PIL image.
            img_input = gr.Image(type="pil", label="Upload/Paste Image")           
            # NOTE(review): each example row carries a second value (the morphs
            # present in the picture), but `inputs` lists only one component.
            # Gradio appears to ignore the extra column, so these captions are
            # likely never displayed. If they are meant to be shown, consider
            # `example_labels=` (newer Gradio versions) — TODO confirm.
            gr.Examples(
                examples=[
                    ["enchi_albino_clown.png", "Enchi, Albino, Clown"],
                    ["mojave_ghi.png", "Mojave, GHI"],
                    ["hypo_banana_pastel_enchi.png", "Hypo, Banana, Pastel, Enchi"],
                    ["yb_pastel_gravel.png", "Yellow Belly, Pastel, Gravel"],
                    ["ivory.png", "Super Yellow Belly"]
                    ],
                inputs=[img_input]
            )
            # Threshold forwarded to predict(); morphs with probability at or
            # below it are omitted from the output.
            confidence = gr.Slider(0, 1, value=0.5, label="Confidence", 
                                info="Show predictions that are above this confidence level")
            predict_btn = gr.Button("Identify Morphs", variant="primary")
            
        with gr.Column(scale=1):
            # Renders the {morph: probability} dict returned by predict().
            label_output = gr.Label(label="Predicted Morphs")
    
    # On click, call predict(image, threshold) and show the result.
    predict_btn.click(fn=predict, inputs=[img_input, confidence], outputs=label_output)

# Starts the web server as soon as this module runs (there is no
# `if __name__ == "__main__"` guard).
demo.launch()