"""Gradio app that predicts ball python morphs from a photo.

A fine-tuned EfficientNetV2-L checkpoint is fetched from the Hugging Face Hub
and served behind a small Blocks UI.  The task is multi-label: every morph
gets an independent sigmoid score, and only the scores above the
user-selected confidence threshold are shown.
"""

import os

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from torch import nn
from torchvision import models
from torchvision.transforms import v2

# Order matters: index i of the model output is the score for labels[i].
labels = [
    'Pastel', 'Yellow Belly', 'Enchi', 'Clown', 'Leopard', 'Piebald',
    'Orange Dream', 'Fire', 'Mojave', 'Pinstripe', 'Banana', 'Normal',
    'Black Pastel', 'Lesser', 'Spotnose', 'Cinnamon', 'GHI', 'Hypo',
    'Spider', 'Super Pastel', 'Desert Ghost', 'Black Head', 'Vanilla',
    'Red Stripe', 'Asphalt', 'Gravel', 'Butter', 'Calico', 'Albino',
    'Chocolate',
]
num_labels = len(labels)

# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

hf_token = os.getenv('HF_token')
model_path = hf_hub_download(
    repo_id="samfhy/morphmarket_model",
    filename="model_v13_1_epoch9.pt",
    token=hf_token,
)
# NOTE(review): the checkpoint is only read for 'image_size' and
# 'model_state_dict'; consider passing weights_only=True if the installed
# torch version and the checkpoint contents allow it.
checkpoint = torch.load(model_path, map_location=device)

# Classification head that replaces EfficientNet's stock classifier.  The lazy
# layers get their input sizes from the checkpoint when the state dict loads.
new_layers = nn.Sequential(
    nn.LazyLinear(2048),
    nn.BatchNorm1d(2048),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.LazyLinear(num_labels),
)

IMAGE_SIZE = checkpoint['image_size']

# NOTE(review): ToDtype is used without scale=True, so pixel values are not
# rescaled to [0, 1] before Normalize.  Left untouched because preprocessing
# must match whatever was used at training time -- confirm against the
# training pipeline before changing.
transform = v2.Compose([
    v2.ToImage(),
    v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    v2.ToDtype(torch.float32),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# weights=None: every parameter and buffer is overwritten by the strict
# load_state_dict() below, so downloading the ImageNet weights first would
# only waste start-up time and bandwidth.
efficientnet = models.efficientnet_v2_l(weights=None)
efficientnet.classifier = new_layers
efficientnet.load_state_dict(checkpoint['model_state_dict'])
# Run inference on the selected device (previously the model always stayed on
# the CPU even when CUDA was detected).
efficientnet.to(device)
efficientnet.eval()


def predict(img, confidence):
    """Score an image against every known morph.

    Args:
        img: The image supplied by the ``gr.Image`` component (configured
            with ``type="pil"``), or ``None`` when nothing was uploaded.
        confidence: Threshold from the slider; only morphs whose sigmoid
            score is strictly greater than this value are returned.

    Returns:
        A dict mapping morph name to its score, containing only the morphs
        above ``confidence``.  May be empty.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Surface a readable message in the UI instead of a transform traceback.
        raise gr.Error("Please upload or paste an image first.")

    # Add a batch dimension and put the tensor on the same device as the model.
    batch = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = efficientnet(batch)

    # Independent sigmoid per label (multi-label), not a softmax.
    scores = torch.sigmoid(logits).to('cpu').flatten().tolist()
    return {
        label: score
        for label, score in zip(labels, scores)
        if score > confidence
    }


with gr.Blocks(title='Ball Python Morph Identifier') as demo:
    gr.Markdown("# Ball Python Morph Identifier")
    gr.Markdown("Upload or paste an image of your ball python to identify its morphs!")
    gr.Markdown("""
        If you're unfamiliar with snakes, ball pythons
        come in various patterns and colors, called *morphs*, which can be difficult to distinguish without expert knowledge.
        This tool automatically identifies these unique variations, making identification accessible to everyone.
        Try selecting one of the examples and click "Identify Morphs" to see how it works!
    """)

    with gr.Accordion("Click here to show all the morphs that can be predicted", open=False):
        gr.Markdown("""
            Albino, Asphalt, Banana, Black Head, Black Pastel, Butter, Calico, Chocolate, Cinnamon, Clown,
            Desert Ghost, Enchi, Fire, GHI, Gravel, Hypo, Leopard, Lesser, Mojave, Normal, Orange Dream,
            Pastel, Piebald, Pinstripe, Red Stripe, Spider, Spotnose, Super Pastel, Vanilla, Yellow Belly
        """)

    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(type="pil", label="Upload/Paste Image")
            # Each row is [image path, expected morphs].  Only img_input is
            # bound, so the second column has no component to populate.
            # NOTE(review): confirm the installed Gradio version tolerates
            # the extra column.
            gr.Examples(
                examples=[
                    ["enchi_albino_clown.png", "Enchi, Albino, Clown"],
                    ["mojave_ghi.png", "Mojave, GHI"],
                    ["hypo_banana_pastel_enchi.png", "Hypo, Banana, Pastel, Enchi"],
                    ["yb_pastel_gravel.png", "Yellow Belly, Pastel, Gravel"],
                    ["ivory.png", "Super Yellow Belly"],
                ],
                inputs=[img_input],
            )
            confidence = gr.Slider(
                0, 1,
                value=0.5,
                label="Confidence",
                info="Show predictions that are above this confidence level",
            )
            predict_btn = gr.Button("Identify Morphs", variant="primary")
        with gr.Column(scale=1):
            label_output = gr.Label(label="Predicted Morphs")

    predict_btn.click(fn=predict, inputs=[img_input, confidence], outputs=label_output)

demo.launch()