Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from torchvision import transforms, models | |
from PIL import Image | |
import torch.nn as nn | |
import os | |
# Silence the Hugging Face Hub warning about symlinks in its local cache
# (set process-wide, before any hub download can happen).
os.environ.update(HF_HUB_DISABLE_SYMLINKS_WARNING="1")
# Use the model architecture | |
class ResNet18(nn.Module): | |
def __init__(self, num_classes): | |
super(ResNet18, self).__init__() | |
self.resnet18 = models.resnet18(weights='ResNet18_Weights.DEFAULT') | |
self.resnet18.fc = nn.Linear(self.resnet18.fc.in_features, num_classes) | |
def forward(self, x): | |
return self.resnet18(x) | |
# ---- Load the fine-tuned classifier -------------------------------------
num_classes = 2  # matches the two entries of class_names
# Inference is pinned to the CPU. To use a GPU when present, switch to:
#   torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
model = ResNet18(num_classes=num_classes)
# Restore the trained parameters from resnet_state_dict.pth; map_location
# places every loaded tensor on `device`.
model.load_state_dict(
    torch.load('resnet_state_dict.pth', map_location=device)
)
# Move to the target device and switch to evaluation mode (both return the
# module itself, so the chain keeps `model` pointing at the same object).
model = model.to(device).eval()
# Preprocessing applied to each input image before it reaches the model.
# NOTE(review): this must match the preprocessing used during training —
# confirm against the training script.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),  # fixed 256x256 input size
        transforms.ToTensor(),  # PIL image -> float tensor scaled to [0, 1]
        # (x - 0.5) / 0.5 on every channel, i.e. [0, 1] -> [-1, 1]
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# Display labels indexed by the model's predicted class index:
# index 0 -> hotdog, index 1 -> not a hotdog.
class_names = [
    "Yes, it is a hotdog :)",
    "No, it isn't a hotdog! :(",
]
# Prediction function | |
def predict(image): | |
try: | |
if isinstance(image, Image.Image): | |
image = image.convert("RGB") | |
else: | |
raise ValueError("Input is not a PIL Image") | |
image = transform(image).unsqueeze(0) | |
image = image.to(device) | |
# Perform inference | |
with torch.no_grad(): | |
output = model(image) | |
_, predicted = torch.max(output, 1) | |
return class_names[predicted.item()] | |
except Exception as e: | |
return str(e) | |
# Use one of the preset images if not for an uploaded hotdog image | |
preset_images = [ | |
'data/test/hot_dog/133012.jpg', | |
'data/test/hot_dog/133015.jpg', | |
'data/test/hot_dog/133245.jpg', | |
'data/test/hot_dog/135628.jpg', | |
'data/test/hot_dog/138933.jpg', | |
'data/test/not_hot_dog/6229.jpg', | |
'data/test/not_hot_dog/6261.jpg', | |
'data/test/not_hot_dog/6709.jpg', | |
'data/test/not_hot_dog/6926.jpg', | |
'data/test/not_hot_dog/7056.jpg'] | |
# Gradio interface | |
iface = gr.Interface( | |
fn=predict, | |
inputs=gr.Image(type="pil", label="Upload your image"), | |
theme='gstaff/xkcd', | |
outputs=gr.Textbox(label="Is it a hotodog?"), # Show the predicted class name | |
live=True, | |
description="Your friendly hotdog/nothotdog classifier" | |
) | |
header = gr.Markdown(""" | |
# Welcome to the Hotdog Classifier! π | |
This app classifies whether an image shows a hotdog or not. | |
Upload an image or choose from the preset images below. | |
""") | |
# Launch the app, currently share set to True | |
iface.launch() |