# fohy24
# remove print()
# 57a0cc8
import os
import torch
from torch import nn
from torchvision import models
from torchvision.transforms import v2
from huggingface_hub import hf_hub_download
import gradio as gr
# Morph names the classifier can report.
# Order is significant: `predict` pairs the i-th entry with the i-th model
# output, so entries must not be re-sorted, inserted, or removed.
labels = [
    'Pastel', 'Yellow Belly', 'Enchi', 'Clown', 'Leopard',
    'Piebald', 'Orange Dream', 'Fire', 'Mojave', 'Pinstripe',
    'Banana', 'Normal', 'Black Pastel', 'Lesser', 'Spotnose',
    'Cinnamon', 'GHI', 'Hypo', 'Spider', 'Super Pastel',
    'Desert Ghost', 'Black Head', 'Vanilla', 'Red Stripe', 'Asphalt',
    'Gravel', 'Butter', 'Calico', 'Albino', 'Chocolate',
]

# Number of output units required by the classification head.
num_labels = len(labels)
# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
# NOTE: in this section `device` is only used as the `map_location` for the
# checkpoint; the model itself is never moved off the CPU here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face Hub access token, read from the environment
# (None when the variable is unset).
hf_token = os.getenv('HF_token')

# Fetch the fine-tuned checkpoint from the Hub and get its local file path.
model_path = hf_hub_download(
    repo_id="samfhy/morphmarket_model",
    filename="model_v13_1_epoch9.pt",
    token=hf_token,
)

# NOTE(review): torch.load unpickles the file, so the checkpoint must come
# from a trusted repository.
checkpoint = torch.load(model_path, map_location=device)

# Replacement classification head: one hidden layer of 2048 units followed by
# one output unit per morph label. LazyLinear leaves the input width
# unspecified in code.
new_layers = nn.Sequential(
    nn.LazyLinear(2048),
    nn.BatchNorm1d(2048),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.LazyLinear(num_labels),
)

# Square input resolution stored alongside the weights in the checkpoint.
IMAGE_SIZE = checkpoint['image_size']

# Preprocessing applied to every incoming image before inference.
# NOTE(review): ToDtype is called without scale=True, so pixel values are
# presumably not rescaled to [0, 1] before Normalize. Confirm this matches the
# preprocessing used when the checkpoint was trained before changing it.
transform = v2.Compose([
    v2.ToImage(),
    v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    v2.ToDtype(torch.float32),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Build the bare architecture without pretrained weights: the strict
# load_state_dict below replaces every parameter and buffer from the
# checkpoint, so downloading the ImageNet weights first would be wasted work.
efficientnet = models.efficientnet_v2_l(weights=None)
efficientnet.classifier = new_layers
efficientnet.load_state_dict(checkpoint['model_state_dict'])

# Inference mode: disables dropout and makes batch norm use running statistics.
efficientnet.eval()
def predict(img, confidence):
    """Return the morph labels whose predicted probability exceeds a threshold.

    Args:
        img: The image supplied by the UI image component, or None when the
            user has not provided one.
        confidence: Threshold; only labels with a probability strictly
            greater than this value are returned.

    Returns:
        A dict mapping label name to probability (float). Empty when no
        image was supplied or when no label clears the threshold.
    """
    # The button can be clicked before an image is uploaded; without this
    # guard the transform would raise on None.
    if img is None:
        return {}

    # Run the input on whatever device the model's weights live on so the
    # function keeps working if the model is ever moved to a GPU.
    model_device = next(efficientnet.parameters()).device
    batch = transform(img).unsqueeze(0).to(model_device)

    with torch.no_grad():
        logits = efficientnet(batch)

    # Sigmoid scores each label independently, so several morphs can clear
    # the threshold for the same image.
    probs = torch.sigmoid(logits).to('cpu').flatten().tolist()

    return {
        label: prob
        for label, prob in zip(labels, probs)
        if prob > confidence
    }
# Static Markdown copy for the page, kept apart from the layout code below.
_INTRO_MD = """
If you're unfamiliar with snakes, ball pythons come in various patterns and colors,
called *morphs*, which can be difficult to distinguish without expert knowledge.
This tool automatically identifies these unique variations, making identification accessible to everyone.
Try selecting one of the examples and click "Identify Morphs" to see how it works!
"""

_SUPPORTED_MORPHS_MD = """
Albino, Asphalt, Banana, Black Head, Black Pastel, Butter, Calico, Chocolate, Cinnamon, Clown,
Desert Ghost, Enchi, Fire, GHI, Gravel, Hypo, Leopard, Lesser, Mojave, Normal,
Orange Dream, Pastel, Piebald, Pinstripe, Red Stripe, Spider, Spotnose, Super Pastel, Vanilla, Yellow Belly
"""

# Example rows offered under the image input: [image file, morphs shown].
# NOTE(review): only `img_input` is bound as an input to gr.Examples, so it is
# unclear from this file whether the second element is ever displayed; confirm.
_EXAMPLE_ROWS = [
    ["enchi_albino_clown.png", "Enchi, Albino, Clown"],
    ["mojave_ghi.png", "Mojave, GHI"],
    ["hypo_banana_pastel_enchi.png", "Hypo, Banana, Pastel, Enchi"],
    ["yb_pastel_gravel.png", "Yellow Belly, Pastel, Gravel"],
    ["ivory.png", "Super Yellow Belly"],
]

# Page layout: heading and intro text, a collapsible list of supported morphs,
# then two columns (inputs on the left, predictions on the right).
with gr.Blocks(title='Ball Python Morph Identifier') as demo:
    gr.Markdown("# Ball Python Morph Identifier")
    gr.Markdown("Upload or paste an image of your ball python to identify its morphs!")
    gr.Markdown(_INTRO_MD)

    with gr.Accordion("Click here to show all the morphs that can be predicted", open=False):
        gr.Markdown(_SUPPORTED_MORPHS_MD)

    with gr.Row():
        # Left column: image, examples, threshold and the trigger button.
        with gr.Column(scale=1):
            img_input = gr.Image(type="pil", label="Upload/Paste Image")
            gr.Examples(examples=_EXAMPLE_ROWS, inputs=[img_input])
            confidence = gr.Slider(
                0,
                1,
                value=0.5,
                label="Confidence",
                info="Show predictions that are above this confidence level",
            )
            predict_btn = gr.Button("Identify Morphs", variant="primary")

        # Right column: the labels returned by `predict`.
        with gr.Column(scale=1):
            label_output = gr.Label(label="Predicted Morphs")

    # Clicking the button sends the image and threshold to `predict` and
    # renders the returned dict in the label component.
    predict_btn.click(
        fn=predict,
        inputs=[img_input, confidence],
        outputs=label_output,
    )

# Start the web server for the app.
demo.launch()