"""Gradio demo app: classify an uploaded image as hotdog / not hotdog.

A ResNet-18 with a replacement classification head is rebuilt, its fine-tuned
weights are loaded from ``resnet_state_dict.pth``, and a live Gradio interface
is launched around :func:`predict`.
"""
import logging
import os

# huggingface_hub (pulled in by gradio) reads this flag from the environment
# when it is first imported, so it must be set before the imports below.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

import gradio as gr  # noqa: E402
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
from PIL import Image  # noqa: E402
from torchvision import models, transforms  # noqa: E402

logger = logging.getLogger(__name__)


class ResNet18(nn.Module):
    """ResNet-18 backbone whose final fully-connected layer emits ``num_classes`` logits."""

    def __init__(self, num_classes, pretrained=True):
        """Build the network.

        Args:
            num_classes: Number of output logits of the replacement ``fc`` layer.
            pretrained: If True (the original behaviour), initialise the backbone
                with torchvision's default ImageNet weights, which may require a
                download. Pass False when a complete state dict is loaded
                afterwards anyway.
        """
        super().__init__()
        weights = "ResNet18_Weights.DEFAULT" if pretrained else None
        self.resnet18 = models.resnet18(weights=weights)
        self.resnet18.fc = nn.Linear(self.resnet18.fc.in_features, num_classes)

    def forward(self, x):
        """Return the raw class logits for the image batch ``x``."""
        return self.resnet18(x)


# Index i holds the message shown for model output class i.
class_names = ["Yes, it is a hotdog :)", "No, it isn't a hotdog! :("]
num_classes = len(class_names)

# Inference is pinned to CPU.
device = torch.device("cpu")

# No ImageNet download needed: load_state_dict is strict by default, so every
# parameter and buffer is overwritten by the fine-tuned checkpoint below.
model = ResNet18(num_classes=num_classes, pretrained=False)
# A state dict only contains tensors/containers, so the restricted
# (weights_only) unpickler is sufficient and avoids arbitrary code execution.
state_dict = torch.load(
    "resnet_state_dict.pth", map_location=device, weights_only=True
)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

# Preprocessing applied to every input image.
# NOTE(review): presumably mirrors the training-time transform — keep in sync.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # The single mean/std value is broadcast over all three RGB channels.
    transforms.Normalize((0.5,), (0.5,)),
])


def predict(image):
    """Classify ``image`` and return a human-readable verdict.

    Args:
        image: PIL image from the Gradio input, or None when the input is empty
            (e.g. the user cleared it while the interface is live).

    Returns:
        One of ``class_names``; an empty string when there is no image; the
        message "Input is not a PIL Image" for unexpected input types; or the
        error text if preprocessing/inference fails, so the problem is shown
        in the output box instead of breaking the UI.
    """
    if image is None:
        return ""
    if not isinstance(image, Image.Image):
        return "Input is not a PIL Image"
    try:
        batch = transform(image.convert("RGB")).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(batch)
        predicted = int(logits.argmax(dim=1).item())
        return class_names[predicted]
    except Exception as e:  # UI boundary: log it and surface the message
        logger.exception("Prediction failed")
        return str(e)


# Sample images offered as clickable examples. Paths are relative to the
# working directory; files that are missing are skipped.
preset_images = [
    'data/test/hot_dog/133012.jpg',
    'data/test/hot_dog/133015.jpg',
    'data/test/hot_dog/133245.jpg',
    'data/test/hot_dog/135628.jpg',
    'data/test/hot_dog/138933.jpg',
    'data/test/not_hot_dog/6229.jpg',
    'data/test/not_hot_dog/6261.jpg',
    'data/test/not_hot_dog/6709.jpg',
    'data/test/not_hot_dog/6926.jpg',
    'data/test/not_hot_dog/7056.jpg',
]

# Introductory Markdown rendered under the interface title.
header = (
    "Your friendly hotdog/nothotdog classifier.\n\n"
    "This app classifies whether an image shows a hotdog or not. "
    "Upload an image or choose from the preset images below."
)

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload your image"),
    outputs=gr.Textbox(label="Is it a hotdog?"),  # shows the predicted class name
    examples=[[path] for path in preset_images if os.path.exists(path)] or None,
    title="Welcome to the Hotdog Classifier! 🌭",
    description=header,
    theme='gstaff/xkcd',
    live=True,
)

# Launch the local server (pass share=True for a temporary public link).
iface.launch()