fohy24
using HF hub to download model
d79e14c
raw
history blame
2.68 kB
import functools
import os

import torch
from huggingface_hub import hf_hub_download
from torch import nn
from torchvision import models
from torchvision.transforms import v2
# Morph names, one per output unit of the classifier head. predict() pairs the
# i-th label with the i-th sigmoid output, so this order must not be changed.
labels = [
    'Pastel',
    'Yellow Belly',
    'Enchi',
    'Clown',
    'Leopard',
    'Piebald',
    'Orange Dream',
    'Fire',
    'Mojave',
    'Pinstripe',
    'Banana',
    'Normal',
    'Black Pastel',
    'Lesser',
    'Spotnose',
    'Cinnamon',
    'GHI',
    'Hypo',
    'Spider',
    'Super Pastel',
]

# Number of outputs of the final Linear layer built in predict().
num_labels = len(labels)
# Square side length (pixels) every input image is resized to before inference.
IMAGE_SIZE = 512


@functools.lru_cache(maxsize=1)
def _get_transform():
    """Build (once) the preprocessing pipeline applied to every input image."""
    return v2.Compose([
        v2.ToImage(),
        v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the classifier, load the fine-tuned checkpoint, and return ``(model, device)``.

    Cached so the Hub download and weight loading happen once per process
    instead of on every prediction request.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # No ImageNet weights: the strict load_state_dict below overwrites every
    # parameter and buffer, so fetching/loading pretrained weights is wasted work.
    model = models.densenet201(weights=None)
    model.classifier = nn.Sequential(
        nn.Linear(1920, 1000),        # DenseNet-201 feature width (1920) -> 1000 hidden units
        nn.BatchNorm1d(1000),         # normalize the hidden activations
        nn.ReLU(),
        nn.Dropout(0.5),              # inactive at inference time (model.eval())
        nn.Linear(1000, num_labels),  # one logit per morph label
    )

    hf_token = os.getenv('HF_token')
    model_path = hf_hub_download(repo_id="fohy24/morphmarket_model",
                                 filename="model_v8_epoch9.pt", token=hf_token)
    # NOTE(review): torch.load unpickles the file; acceptable only because the
    # checkpoint comes from this project's own pinned Hub repo.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return model, device


def predict(img, confidence):
    """Return the morph labels whose predicted probability exceeds ``confidence``.

    Args:
        img: Image to classify, as delivered by the Gradio ``gr.Image(type="pil")``
            input, or None when the user submitted without an image.
        confidence: Threshold; only labels with probability strictly greater
            than this value are returned.

    Returns:
        dict mapping label name -> probability (float) for each label above the
        threshold. Empty when no label qualifies or when ``img`` is None.
    """
    if img is None:
        return {}

    model, device = _load_model()
    input_img = _get_transform()(img).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        output = model(input_img)

    # Multi-label head: each output unit is an independent sigmoid probability,
    # so several morphs can be reported for the same image.
    predicted_probs = torch.sigmoid(output).to('cpu').flatten().tolist()
    return {label: prob for label, prob in zip(labels, predicted_probs) if prob > confidence}
import gradio as gr

# Components are created in the same order as before: image input, then the
# threshold slider, then the output label.
image_input = gr.Image(type="pil")
confidence_slider = gr.Slider(
    0, 1,
    value=0.5,
    label="Confidence",
    info="Show predictions that are above this confidence level",
)
label_output = gr.Label()

# Example rows: [image file, confidence threshold] passed to predict().
example_rows = [["pastel_yb.png", 0.5], ["piebald.png", 0.5], ["leopard_fire.png", 0.5]]

demo = gr.Interface(
    fn=predict,
    inputs=[image_input, confidence_slider],
    outputs=label_output,
    examples=example_rows,
    title='Ball Python Morph Identifier',
    description="Upload or paste an image of your ball python to identify its morphs!",
)
demo.launch()