File size: 4,436 Bytes
37f6bf3
75c78ca
37f6bf3
 
 
 
85265af
37f6bf3
 
9e437f8
85265af
37f6bf3
4c3f6bd
85265af
 
f4dbfb8
9e437f8
75c78ca
37f6bf3
 
f4dbfb8
37f6bf3
 
 
 
 
 
4c3f6bd
75c78ca
 
 
 
 
 
 
 
 
 
 
4c3f6bd
37f6bf3
4c3f6bd
 
37f6bf3
4c3f6bd
37f6bf3
4c3f6bd
 
 
 
 
75c78ca
f4dbfb8
85265af
4c3f6bd
9e437f8
 
4c3f6bd
85265af
9e437f8
85265af
 
37f6bf3
4c3f6bd
 
9e437f8
 
4c3f6bd
 
12efc89
f4dbfb8
1038938
f4dbfb8
 
 
 
 
 
 
70e3c21
f4dbfb8
 
 
 
 
 
 
 
 
 
 
1038938
f4dbfb8
1038938
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
from torchvision import transforms, models
from PIL import Image
import gradio as gr
import os

# Run inference on the CPU.
device = torch.device('cpu')

# ResNet-50 architecture without pretrained weights; the fine-tuned weights
# are loaded from disk below.
model = models.resnet50(weights=None)

# Replace the final fully connected layer so it outputs 37 classes, one per
# entry in `class_names`. Reading `in_features` (2048 for ResNet-50) avoids
# hard-coding the backbone's feature width.
model.fc = torch.nn.Linear(model.fc.in_features, 37)

# Load the fine-tuned weights.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; on torch >= 1.13 consider passing weights_only=True.
model.load_state_dict(torch.load('./resnet50_model_weights.pth', map_location=device))

# Keep the model on the same device the input tensors are moved to, then
# switch to evaluation mode (BatchNorm uses its running statistics).
model.to(device)
model.eval()

# Preprocessing applied to every image before it reaches the model:
#   1. resize to a fixed 224x224,
#   2. convert the PIL image to a float tensor,
#   3. normalise each channel with the mean/std below (the usual ImageNet
#      statistics; presumably the same values used in training — confirm).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Display labels for the 37 pet breeds, written as "English name (Chinese name)".
# Each label's position is the model output index it is looked up by, so the
# order must match the label order used in training (presumably the
# alphabetical Oxford-IIIT Pet class order — confirm against training code).
class_names = [
    "Abyssinian (阿比西尼亞貓)",
    "American Bulldog (美國鬥牛犬)",
    "American Pit Bull Terrier (美國比特鬥牛梗)",
    "Basset Hound (巴吉度獵犬)",
    "Beagle (米格魯)",
    "Bengal (孟加拉貓)",
    "Birman (緬甸貓)",
    "Bombay (孟買貓)",
    "Boxer (拳師犬)",
    "British Shorthair (英國短毛貓)",
    "Chihuahua (吉娃娃)",
    "Egyptian Mau (埃及貓)",
    "English Cocker Spaniel (英國可卡犬)",
    "English Setter (英國設得蘭犬)",
    "German Shorthaired (德國短毛犬)",
    "Great Pyrenees (大白熊犬)",
    "Havanese (哈瓦那犬)",
    "Japanese Chin (日本狆)",
    "Keeshond (荷蘭毛獅犬)",
    "Leonberger (萊昂貝格犬)",
    "Maine Coon (緬因貓)",
    "Miniature Pinscher (迷你品犬)",
    "Newfoundland (紐芬蘭犬)",
    "Persian (波斯貓)",
    "Pomeranian (博美犬)",
    "Pug (哈巴狗)",
    "Ragdoll (布偶貓)",
    "Russian Blue (俄羅斯藍貓)",
    "Saint Bernard (聖伯納犬)",
    "Samoyed (薩摩耶)",
    "Scottish Terrier (蘇格蘭梗)",
    "Shiba Inu (柴犬)",
    "Siamese (暹羅貓)",
    "Sphynx (無毛貓)",
    "Staffordshire Bull Terrier (史塔福郡鬥牛犬)",
    "Wheaten Terrier (小麥色梗)",
    "Yorkshire Terrier (約克夏犬)",
]

def classify_image(image):
    """Classify a pet photo and return its three most likely breeds.

    Args:
        image: A ``PIL.Image.Image`` from the Gradio image input
            (``type="pil"``), or ``None`` when the user submits without
            uploading an image.

    Returns:
        A dict mapping breed label (from ``class_names``) to its softmax
        probability as a plain Python float (not a formatted string), for
        the top-3 predictions. An empty dict is returned when no image was
        provided.
    """
    if image is None:
        return {}

    # The Normalize step has exactly three channel means/stds, so the tensor
    # must have three channels; RGBA, palette or grayscale uploads would
    # otherwise make the preprocessing fail.
    image = image.convert("RGB")

    # Preprocess and add a batch dimension -> shape (1, 3, 224, 224).
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        top_probs, top_indices = torch.topk(probabilities, k=3)

    # Row 0 is the only item in the batch; tolist() yields Python ints/floats.
    return {
        class_names[idx]: prob
        for idx, prob in zip(top_indices[0].tolist(), top_probs[0].tolist())
    }

# Folder containing example images offered beneath the classifier.
examples_path = './examples'

# File extensions treated as example images; anything else in the folder
# (e.g. hidden files such as .DS_Store or .gitkeep) is skipped.
EXAMPLE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp')

if os.path.isdir(examples_path):
    print(f"[INFO] Found examples folder at {examples_path}")
    # One single-element row per image because the interface has one input.
    # Sorting keeps the example order stable across platforms and runs.
    examples = [
        [os.path.join(examples_path, img)]
        for img in sorted(os.listdir(examples_path))
        if img.lower().endswith(EXAMPLE_EXTENSIONS)
    ]
else:
    print(f"[ERROR] Examples folder not found at {examples_path}")
    # Continue without examples instead of letting os.listdir() crash the app.
    examples = []

# Build the full app with `gr.Blocks()` so extra components can sit alongside
# the classifier interface.
with gr.Blocks() as demo_with_dropdown:
    # Page heading
    gr.Markdown("# Oxford Pet 🐕🐈 Recognizable Breeds")

    # Reference-only dropdown listing every breed the model can recognise; it
    # is not connected to the prediction function. It is created inside the
    # Blocks context because a component created outside any Blocks is not
    # rendered, and merely referencing the variable here would be a no-op.
    dropdown = gr.Dropdown(choices=class_names, label="Recognizable Breeds", type="value")

    # Classifier: the image input feeds classify_image, and its
    # breed -> probability dict is displayed by the Label output.
    gr.Interface(
        fn=classify_image,
        inputs=gr.Image(type="pil"),  # Only image input is used for prediction
        outputs=gr.Label(num_top_classes=3, label="Top 3 Predictions"),
        examples=examples,
        title='Oxford Pet 🐈🐕',
        description='A ResNet50-based model for classifying 37 different pet breeds.',
        article='[Oxford Project](https://github.com/Eric-Chung-0511/Learning-Record/tree/main/Data%20Science%20Projects/The%20Oxford-IIIT%20Pet%20Project)'
    )

# Launch Gradio app
demo_with_dropdown.launch()