fohy24
specifying pytorch version to fix model loading failure
bddc3d6
raw
history blame
2.69 kB
import functools
import io
import os

import requests
import torch
from torch import nn
from torchvision import models
from torchvision.transforms import v2
# Morph names, one per model output unit: predict() pairs output i with
# labels[i], so this order must not be changed independently of the model.
labels = [
    'Pastel',
    'Yellow Belly',
    'Enchi',
    'Clown',
    'Leopard',
    'Piebald',
    'Orange Dream',
    'Fire',
    'Mojave',
    'Pinstripe',
    'Banana',
    'Normal',
    'Black Pastel',
    'Lesser',
    'Spotnose',
    'Cinnamon',
    'GHI',
    'Hypo',
    'Spider',
    'Super Pastel',
]

# Width of the classifier's final layer.
num_labels = len(labels)
IMAGE_SIZE = 512

# Inference device. The model and every input tensor are moved here so a GPU,
# when present, is actually used.
_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing, built once: convert to a tensor image, resize to a fixed
# square, scale to float32 in [0, 1], then normalise with the ImageNet
# mean/std constants.
_TRANSFORM = v2.Compose([
    v2.ToImage(),
    v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _build_classifier_head():
    """Return the multi-label head that replaces DenseNet-201's classifier."""
    return nn.Sequential(
        nn.Linear(1920, 1000),        # 1920 backbone features -> 1000 hidden units
        nn.BatchNorm1d(1000),         # normalise the hidden activations
        nn.ReLU(),                    # non-linearity
        nn.Dropout(0.5),              # regularisation; inactive after eval()
        nn.Linear(1000, num_labels),  # one logit per entry in `labels`
    )


def _download_checkpoint(url, timeout=60):
    """Download the checkpoint at *url* and deserialize it onto _DEVICE.

    Raises:
        RuntimeError: if *url* is empty (the ``model_path`` env var is unset).
        requests.HTTPError: if the server answers with an error status.
    """
    if not url:
        raise RuntimeError(
            "Environment variable 'model_path' is not set; "
            "cannot download the model checkpoint."
        )
    # Bounded wait, and refuse to deserialize an HTTP error page.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # NOTE: torch.load defaults to weights_only=True from PyTorch 2.6 on.
    # If the checkpoint holds non-tensor objects, loading can fail there.
    # Do not "fix" that with weights_only=False: the bytes come from a URL.
    return torch.load(io.BytesIO(response.content), map_location=_DEVICE)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build DenseNet-201 with the custom head, load fine-tuned weights, cache it.

    The model is built once per process instead of once per request. A failed
    download raises, so it is not cached and the next call retries.
    """
    # No pretrained weights: load_state_dict below is strict, so every
    # parameter and buffer is overwritten by the checkpoint anyway.
    densenet = models.densenet201(weights=None)
    densenet.classifier = _build_classifier_head()
    checkpoint = _download_checkpoint(os.getenv('model_path'))
    densenet.load_state_dict(checkpoint['model_state_dict'])
    densenet.to(_DEVICE)
    densenet.eval()  # inference mode: BatchNorm uses running stats, Dropout off
    return densenet


def predict(img, confidence):
    """Return the morphs whose predicted probability exceeds *confidence*.

    Args:
        img: input image (the UI passes a PIL image); ``None`` yields ``{}``.
        confidence: threshold in [0, 1]; only strictly greater probabilities
            are reported.

    Returns:
        dict mapping morph name (from ``labels``) to its sigmoid probability.
        Each morph is scored independently (multi-label classification).
    """
    if img is None:
        return {}
    # Force 3 channels: RGBA/greyscale inputs would break Normalize and the model.
    if hasattr(img, "convert"):
        img = img.convert("RGB")

    model = _load_model()
    input_img = _TRANSFORM(img).unsqueeze(0).to(_DEVICE)  # add batch dimension
    with torch.no_grad():
        output = model(input_img)
    predicted_probs = torch.sigmoid(output).to('cpu').flatten().tolist()
    return {
        label: prob
        for label, prob in zip(labels, predicted_probs)
        if prob > confidence
    }
import gradio as gr

# Web UI around predict(): an image plus a confidence slider go in, and a
# label widget listing the morphs above that threshold comes out.
image_input = gr.Image(type="pil")
confidence_input = gr.Slider(
    0,
    1,
    value=0.5,
    label="Confidence",
    info="Show predictions that are above this confidence level",
)
example_rows = [
    ["pastel_yb.png", 0.5],
    ["piebald.png", 0.5],
    ["leopard_fire.png", 0.5],
]

demo = gr.Interface(
    fn=predict,
    inputs=[image_input, confidence_input],
    outputs=gr.Label(),
    examples=example_rows,
    title='Ball Python Morph Identifier',
    description="Upload or paste an image of your ball python to identify its morphs!",
)
demo.launch()