File size: 2,667 Bytes
cd1c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9e619e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
import torch
from torchvision import transforms, models
from PIL import Image
import torch.nn as nn
import os
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"


# Use the model architecture
class ResNet18(nn.Module):
    """ResNet-18 backbone whose final layer is resized to ``num_classes`` outputs.

    Args:
        num_classes: Number of logits produced by the replacement ``fc`` layer.
        weights: Forwarded unchanged to ``torchvision.models.resnet18``.
            The default preserves the original behaviour (torchvision's
            default pretrained weights). Pass ``None`` to build the network
            without fetching pretrained weights, e.g. when a complete state
            dict is loaded immediately after construction.
    """

    def __init__(self, num_classes: int, weights='ResNet18_Weights.DEFAULT') -> None:
        super().__init__()
        # The attribute name is part of every parameter key in this module's
        # state dict ("resnet18.<...>"), so it must not be renamed.
        self.resnet18 = models.resnet18(weights=weights)
        # Replace the stock classification head with one sized for this task,
        # keeping the backbone's feature width as the input dimension.
        self.resnet18.fc = nn.Linear(self.resnet18.fc.in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped network and return its output."""
        return self.resnet18(x)

# Load the pretrained classifier
num_classes = 2
# Inference is pinned to the CPU. map_location below remaps every stored
# tensor onto this device, so a checkpoint saved from a GPU still loads.
device = 'cpu'
model = ResNet18(num_classes=num_classes)
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot run arbitrary code while it is being read.
# NOTE(review): the flag needs a PyTorch release that supports it — confirm
# the pinned torch version.
model.load_state_dict(
    torch.load('resnet_state_dict.pth', map_location=device, weights_only=True)
)
model = model.to(device)
# Put the module in evaluation mode before serving predictions.
model.eval()

# Preprocessing applied to each input image before it reaches the model:
# resize to 256x256, convert to a tensor, then normalize with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# User-facing answer for each predicted class index (looked up by predict()).
class_names = [
    "Yes, it is a hotdog :)",
    "No, it isn't a hotdog! :(",
]

# Prediction function
def predict(image):
    """Classify ``image`` and return the matching entry of ``class_names``.

    Args:
        image: Value delivered by the image input: a PIL image, or ``None``
            when no image is currently selected.

    Returns:
        The label for the highest-scoring model output; an empty string when
        ``image`` is ``None``; otherwise the text of the error that occurred,
        so the caller always receives a string.
    """
    # An empty or cleared image widget hands the callback None. That is not a
    # failure, so return a blank result instead of an error message.
    if image is None:
        return ""

    try:
        if not isinstance(image, Image.Image):
            raise ValueError("Input is not a PIL Image")

        # Force three colour channels (drops alpha, expands greyscale), then
        # add the leading batch dimension the model call expects.
        batch = transform(image.convert("RGB")).unsqueeze(0).to(device)

        # Gradients are not needed for inference.
        with torch.no_grad():
            logits = model(batch)

        # Index of the largest output along the class dimension.
        return class_names[logits.argmax(dim=1).item()]
    except Exception as e:
        # Deliberate best effort: show the problem in the output textbox
        # rather than letting the UI callback raise.
        return str(e)

# Sample images from the test folders (five per class), intended as ready-made
# inputs for users who do not upload a picture of their own.
preset_images = [
    f"data/test/{label_dir}/{image_id}.jpg"
    for label_dir, image_ids in (
        ("hot_dog", (133012, 133015, 133245, 135628, 138933)),
        ("not_hot_dog", (6229, 6261, 6709, 6926, 7056)),
    )
    for image_id in image_ids
]

# Gradio interface

# Intro text shown above the inputs. It used to be a gr.Markdown component
# created outside any Blocks context, which is never attached to the page;
# it is now plain Markdown text passed through the interface description.
header = """\
# Welcome to the Hotdog Classifier! 🍔
This app classifies whether an image shows a hotdog or not.
Upload an image or choose from the preset images below.
"""

# Offer only the presets that exist on disk, so a missing sample file cannot
# stop the app from starting.
available_presets = [path for path in preset_images if os.path.exists(path)]

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload your image"),
    theme='gstaff/xkcd',
    outputs=gr.Textbox(label="Is it a hotdog?"),  # Show the predicted class name
    live=True,
    description=header + "\nYour friendly hotdog/nothotdog classifier",
    # Clickable sample inputs; None (no examples section) if none were found.
    examples=available_presets or None,
    # Do not pre-compute example outputs at startup; live=True re-runs the
    # prediction whenever the input image changes.
    cache_examples=False,
)

# Launch the app with Gradio's default launch settings (share is not enabled).
iface.launch()