File size: 2,678 Bytes
49a7401 1a423d1 d79e14c 49a7401 bddc3d6 49a7401 30a5c0f 49a7401 d79e14c 1a423d1 d79e14c 49a7401 bda76e2 30a5c0f 49a7401 bda76e2 49a7401 3f3c078 49a7401 3f3c078 30a5c0f 3f3c078 94d6d22 7d51cfd 49a7401 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import functools
import os

import torch
from huggingface_hub import hf_hub_download
from torch import nn
from torchvision import models
from torchvision.transforms import v2
# Morph names the classifier can predict. Position matters: entry i names
# output unit i of the final Linear layer (its width is num_labels), and the
# prediction step pairs probabilities with names by this index. The order
# presumably mirrors the label order used at training time -- do not reorder.
labels = [
    "Pastel",
    "Yellow Belly",
    "Enchi",
    "Clown",
    "Leopard",
    "Piebald",
    "Orange Dream",
    "Fire",
    "Mojave",
    "Pinstripe",
    "Banana",
    "Normal",
    "Black Pastel",
    "Lesser",
    "Spotnose",
    "Cinnamon",
    "GHI",
    "Hypo",
    "Spider",
    "Super Pastel",
]
# Number of classifier outputs; sizes the last layer of the model head.
num_labels = len(labels)
IMAGE_SIZE = 512

# Per-request preprocessing: image -> float tensor scaled to [0, 1], resized to
# IMAGE_SIZE x IMAGE_SIZE, then normalized with the standard ImageNet channel
# mean/std. Built once at import time instead of on every call.
_TRANSFORM = v2.Compose([
    v2.ToImage(),
    v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the classifier, load the fine-tuned checkpoint, and return ``(model, device)``.

    Cached so that the network is constructed, the checkpoint fetched from the
    Hugging Face Hub, and its weights loaded only once per process rather than
    on every request.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # No pretrained weights: load_state_dict below is strict (the default), so
    # every parameter and buffer is overwritten by the checkpoint anyway and
    # downloading the ImageNet weights would be wasted work.
    densenet = models.densenet201(weights=None)
    densenet.classifier = nn.Sequential(
        nn.Linear(1920, 1000),  # DenseNet-201 emits 1920 features; project to 1000
        nn.BatchNorm1d(1000),  # normalize the hidden activations
        nn.ReLU(),  # non-linearity
        nn.Dropout(0.5),  # regularization; inactive once eval() is called
        nn.Linear(1000, num_labels),  # one logit per morph label
    )

    # Token comes from the HF_token env var (None if unset).
    hf_token = os.getenv('HF_token')
    model_path = hf_hub_download(
        repo_id="fohy24/morphmarket_model",
        filename="model_v8_epoch9.pt",
        token=hf_token,
    )
    checkpoint = torch.load(model_path, map_location=device)
    densenet.load_state_dict(checkpoint['model_state_dict'])
    # load_state_dict copies into the existing (CPU) parameters, so the model
    # must be moved explicitly for the GPU to be used.
    densenet.to(device)
    densenet.eval()
    return densenet, device


def predict(img, confidence):
    """Return the morph labels whose predicted probability exceeds ``confidence``.

    Args:
        img: the input image (a PIL image from the ``gr.Image(type="pil")``
            input), or ``None`` when no image was supplied.
        confidence: threshold from the UI slider (0..1); only labels whose
            probability is strictly greater than it are kept.

    Returns:
        dict mapping label name -> sigmoid probability, as consumed by
        ``gr.Label``. Empty when ``img`` is None or no label clears the threshold.
    """
    if img is None:
        return {}
    # Normalize uses 3 channel statistics, so drop alpha/palette/grayscale modes.
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")

    model, device = _load_model()
    batch = _TRANSFORM(img).unsqueeze(0).to(device)  # add the batch dimension
    with torch.no_grad():
        logits = model(batch)

    # Multi-label output: an independent sigmoid per label, not a softmax.
    probs = torch.sigmoid(logits).cpu().flatten().tolist()
    return {label: p for label, p in zip(labels, probs) if p > confidence}
import gradio as gr

# Web UI: an image and a confidence-threshold slider go in; predict()'s
# {morph: probability} dict is rendered by gr.Label. The example entries are
# relative file paths, presumably shipped alongside this script.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(
            0,
            1,
            value=0.5,
            label="Confidence",
            info="Show predictions that are above this confidence level",
        ),
    ],
    outputs=gr.Label(),
    examples=[
        ["pastel_yb.png", 0.5],
        ["piebald.png", 0.5],
        ["leopard_fire.png", 0.5],
    ],
    title='Ball Python Morph Identifier',
    description="Upload or paste an image of your ball python to identify its morphs!",
)
demo.launch()
|