# hotdog / app.py
# Author: asidfactory
# Last commit: cd1c32c (verified) — "change device to cpu"
import os

# Silence huggingface_hub's cache-symlink warning.
# NOTE(review): huggingface_hub appears to read this variable once, when it is first
# imported (gradio imports it). So it is set here, before the third-party imports;
# setting it after `import gradio` would have no effect. Confirm against the installed
# huggingface_hub version.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

import gradio as gr  # noqa: E402
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
from PIL import Image  # noqa: E402
from torchvision import models, transforms  # noqa: E402
# Model architecture: torchvision ResNet-18 with its final fully connected layer
# replaced by a `num_classes`-way linear head.
class ResNet18(nn.Module):
    """ResNet-18 backbone with a task-specific ``num_classes``-way linear head.

    The torchvision model is stored as ``self.resnet18``, so every state-dict key is
    prefixed with ``resnet18.``. Keep that attribute name to stay compatible with
    previously saved checkpoints.

    Args:
        num_classes: Number of logits produced by the final linear layer.
        weights: Backbone weights forwarded to ``torchvision.models.resnet18``.
            Defaults to ``"ResNet18_Weights.DEFAULT"`` (the pretrained weights).
            Pass ``None`` for a randomly initialised backbone, e.g. when a complete
            state dict is loaded right afterwards and the download is unnecessary.
    """

    def __init__(self, num_classes: int, weights="ResNet18_Weights.DEFAULT"):
        super().__init__()
        self.resnet18 = models.resnet18(weights=weights)
        # Replace the stock classifier with a head sized for this task, keeping the
        # backbone's feature width (fc.in_features).
        self.resnet18.fc = nn.Linear(self.resnet18.fc.in_features, num_classes)

    def forward(self, x):
        """Return the raw (unnormalised) class scores for the input batch ``x``."""
        return self.resnet18(x)
# Load the fine-tuned two-class classifier from resnet_state_dict.pth.
num_classes = 2
# Inference is pinned to the CPU. For GPU use, select the device with
# torch.cuda.is_available() instead.
device = 'cpu'
# NOTE: ResNet18() starts from the "ResNet18_Weights.DEFAULT" backbone weights.
# Those weights are then completely overwritten by the checkpoint below.
model = ResNet18(num_classes=num_classes)
# The checkpoint is a plain state dict:
# - weights_only=True restricts unpickling to tensors and basic containers, which is
#   all a state dict needs.
# - map_location lets a checkpoint saved on another device load onto `device`.
model.load_state_dict(
    torch.load('resnet_state_dict.pth', map_location=device, weights_only=True)
)
model = model.to(device)
# Switch layers such as BatchNorm to inference behaviour.
model.eval()
# Preprocessing applied to every image before it reaches the model:
# 1. Resize to a fixed 256x256 square (the aspect ratio is not preserved).
# 2. Convert to a float tensor.
# 3. Normalise with mean=std=0.5. A single value is given, so it is applied to every
#    channel.
# NOTE(review): this presumably mirrors the preprocessing used at training time.
# Confirm before changing it.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# Display labels indexed by the model's predicted class:
# index 0 is "hotdog", index 1 is "not hotdog".
# NOTE(review): this order presumably matches the class order used during training
# (e.g. alphabetical folders hot_dog / not_hot_dog). Confirm.
class_names = [
    "Yes, it is a hotdog :)",
    "No, it isn't a hotdog! :(",
]
# Prediction function
def predict(image):
    """Classify ``image`` as hotdog / not hotdog and return a label for display.

    Args:
        image: The picture as a ``PIL.Image.Image`` (the Gradio input uses
            ``type="pil"``), or ``None`` when the image input is empty. Because the
            interface is ``live``, clearing the input also calls this function.

    Returns:
        One of ``class_names`` on success. An empty string when there is no image,
        so the output box is simply cleared. Otherwise an error message string, when
        the input is not a PIL image or inference fails.
    """
    import logging

    if image is None:
        return ""
    if not isinstance(image, Image.Image):
        return "Input is not a PIL Image"
    try:
        # Force three channels: uploads may be, e.g., greyscale or RGBA.
        batch = transform(image.convert("RGB")).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(batch)
        return class_names[int(logits.argmax(dim=1).item())]
    except Exception as exc:
        # UI boundary: show the failure in the textbox instead of crashing the app,
        # but keep the full traceback in the server logs.
        logging.getLogger(__name__).exception("Prediction failed")
        return str(exc)
# Example images a user can pick instead of uploading a picture of their own
# (paths are relative to the working directory): five hotdogs, then five non-hotdogs.
preset_images = [
    f"data/test/hot_dog/{image_id}.jpg"
    for image_id in ("133012", "133015", "133245", "135628", "138933")
] + [
    f"data/test/not_hot_dog/{image_id}.jpg"
    for image_id in ("6229", "6261", "6709", "6926", "7056")
]
# Gradio interface: one PIL image in, the predicted label out.
# live=True re-runs predict whenever the input changes.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload your image"),
    theme='gstaff/xkcd',
    outputs=gr.Textbox(label="Is it a hotdog?"),  # Show the predicted class name
    live=True,
    description="Your friendly hotdog/nothotdog classifier",
    # Offer the preset images as clickable examples (one input component, so each
    # example is a one-element list). Missing files are skipped so the app still
    # starts from a checkout without the data/ folder. None means "no examples".
    examples=[[path] for path in preset_images if os.path.exists(path)] or None,
)
# Introductory Markdown shown to users.
# NOTE(review): `header` is never passed to `iface` or rendered inside a gr.Blocks
# context in this file, so it is presumably not displayed. Consider merging this
# text into the Interface's `description`/`article`, or wrapping both in gr.Blocks.
header = gr.Markdown("""
# Welcome to the Hotdog Classifier! 🍔
This app classifies whether an image shows a hotdog or not.
Upload an image or choose from the preset images below.
""")
# Launch the app. `share` is not passed, so Gradio's default applies; use
# iface.launch(share=True) to request a temporary public link.
iface.launch()