asidfactory commited on
Commit
f9e619e
·
verified ·
1 Parent(s): 82247e9

added gradioApp.py to main

Browse files
Files changed (1) hide show
  1. gradioApp.py +88 -0
gradioApp.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms, models
4
+ from PIL import Image
5
+ import torch.nn as nn
6
+ import os
7
+ os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
8
+
9
+
10
+ # Use the model architecture
11
+ class ResNet18(nn.Module):
12
+ def __init__(self, num_classes):
13
+ super(ResNet18, self).__init__()
14
+ self.resnet18 = models.resnet18(weights='ResNet18_Weights.DEFAULT')
15
+ self.resnet18.fc = nn.Linear(self.resnet18.fc.in_features, num_classes)
16
+
17
+ def forward(self, x):
18
+ return self.resnet18(x)
19
+
20
+ # Load the pretrained classifier
21
+ num_classes = 2
22
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
+ model = ResNet18(num_classes=num_classes)
24
+ model.load_state_dict(torch.load('resnet_state_dict.pth', map_location=device)) # Load trained state path from resnet_state_dict.pth
25
+ model = model.to(device)
26
+ model.eval()
27
+
28
+ # Transform
29
+ transform = transforms.Compose([
30
+ transforms.Resize((256, 256)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize((0.5,), (0.5,))
33
+ ])
34
+
35
+ # Define classes
36
+ class_names = ["Yes, it is a hotdog :)", "No, it isn't a hotdog! :("]
37
+
38
+ # Prediction function
39
+ def predict(image):
40
+ try:
41
+ if isinstance(image, Image.Image):
42
+ image = image.convert("RGB")
43
+ else:
44
+ raise ValueError("Input is not a PIL Image")
45
+
46
+ image = transform(image).unsqueeze(0)
47
+ image = image.to(device)
48
+
49
+ # Perform inference
50
+ with torch.no_grad():
51
+ output = model(image)
52
+ _, predicted = torch.max(output, 1)
53
+
54
+ return class_names[predicted.item()]
55
+ except Exception as e:
56
+ return str(e)
57
+
58
+ # Use one of the preset images if not for an uploaded hotdog image
59
+ preset_images = [
60
+ 'data/test/hot_dog/133012.jpg',
61
+ 'data/test/hot_dog/133015.jpg',
62
+ 'data/test/hot_dog/133245.jpg',
63
+ 'data/test/hot_dog/135628.jpg',
64
+ 'data/test/hot_dog/138933.jpg',
65
+ 'data/test/not_hot_dog/6229.jpg',
66
+ 'data/test/not_hot_dog/6261.jpg',
67
+ 'data/test/not_hot_dog/6709.jpg',
68
+ 'data/test/not_hot_dog/6926.jpg',
69
+ 'data/test/not_hot_dog/7056.jpg']
70
+
71
+ # Gradio interface
72
+ iface = gr.Interface(
73
+ fn=predict,
74
+ inputs=gr.Image(type="pil", label="Upload your image"),
75
+ theme='gstaff/xkcd',
76
+ outputs=gr.Textbox(label="Is it a hotodog?"), # Show the predicted class name
77
+ live=True,
78
+ description="Your friendly hotdog/nothotdog classifier"
79
+ )
80
+
81
+ header = gr.Markdown("""
82
+ # Welcome to the Hotdog Classifier! 🍔
83
+ This app classifies whether an image shows a hotdog or not.
84
+ Upload an image or choose from the preset images below.
85
+ """)
86
+
87
+ # Launch the app, currently share set to True
88
+ iface.launch()