# animals / app.py (Hugging Face Space by sdafd)
# Commit 29ebbeb "Update app.py" (verified) · 1.51 kB · scanned: no virus
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
# Load the trained classifier. The checkpoint holds a whole pickled model
# object (not a state_dict) -- `predict` calls it directly as `model(...)`.
# NOTE(review): weights_only=False unpickles arbitrary Python objects, so only
# ever point this at a trusted checkpoint. It is required because torch>=2.6
# defaults to weights_only=True, which refuses to unpickle a full nn.Module.
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only host;
# the input tensors built by `preprocess` are CPU tensors as well.
model = torch.load('classifier.pt', map_location='cpu', weights_only=False)
# Put dropout / batch-norm layers into inference behaviour. torch.no_grad()
# does not do this (and has no effect on deserialization anyway).
model.eval()
# Define the preprocessing function for the input image
def preprocess(image):
    """Turn an image array from the Gradio input into a 1-image batch tensor.

    Args:
        image: Image array as delivered by ``gr.Image()`` (numpy by default).
            Expected to be (H, W, 3) RGB, but a grayscale (H, W) or RGBA
            (H, W, 4) array is also accepted and converted to RGB.

    Returns:
        A float tensor of shape (1, 3, 224, 224): resized, scaled to [0, 1]
        by ``ToTensor`` and then channel-normalized.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Standard ImageNet channel mean/std; presumably what the model was
        # trained with -- confirm against the training pipeline.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Let Pillow infer the mode from the array shape, then normalize to RGB.
    # Forcing mode='RGB' fails on a 2-D array and scrambles a 4-channel one,
    # and the explicit `mode` argument to fromarray is deprecated in recent Pillow.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    # Add the leading batch dimension the model expects.
    return transform(pil_image).unsqueeze(0)
# Define the predict function
def predict(image):
    """Classify an image and return per-class probabilities for ``gr.Label``.

    Args:
        image: Image array from the ``gr.Image()`` input, or ``None`` when the
            input is empty. With ``live=True`` Gradio invokes this on every
            input change, including when the image is cleared.

    Returns:
        ``None`` when ``image`` is ``None``; otherwise a dict mapping each
        class label to its softmax probability.
    """
    # Nothing to classify (e.g. the image was just cleared). Returning None
    # avoids crashing on `None.astype` inside preprocess.
    if image is None:
        return None

    input_tensor = preprocess(image)
    with torch.no_grad():
        output = model(input_tensor)

    # Hugging Face-style models return an object exposing `.logits`; a plain
    # nn.Module returns the logits tensor itself. Accept either.
    logits = getattr(output, 'logits', output)
    # Softmax over the class dimension, then take the single batch row.
    # Indexing [0] (instead of squeeze()) keeps the result a list even when
    # the model has only one output class.
    probabilities = torch.softmax(logits, dim=1)[0].tolist()

    # Assumed to be in the model's output-index order -- confirm against training.
    class_labels = ["Cat", "Dog", "Horse", "Monkey"]
    return dict(zip(class_labels, probabilities))
# Assemble the web UI: an image input wired to `predict`, and a label output
# that shows up to the four highest-scoring classes. With live=True the
# prediction re-runs automatically whenever the input changes.
_image_input = gr.Image()
_label_output = gr.Label(num_top_classes=4)
iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    live=True,
)

# Start the Gradio server; quiet=True asks Gradio to suppress most console output.
iface.launch(quiet=True)