Spaces:
Running
Running
File size: 3,460 Bytes
37f6bf3 75c78ca 37f6bf3 9c3253b 37f6bf3 9c3253b 85265af 37f6bf3 9c3253b 85265af 9c3253b 9e437f8 75c78ca 9c3253b 37f6bf3 9c3253b 37f6bf3 9c3253b 75c78ca 9c3253b 37f6bf3 9ac86e7 37f6bf3 9ac86e7 75c78ca 9ac86e7 85265af 9ac86e7 9e437f8 85265af 9e437f8 85265af 37f6bf3 9ac86e7 9e437f8 9ac86e7 9c3253b 9ac86e7 9c3253b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import torch
from torchvision import transforms, models
from PIL import Image
import gradio as gr
import os
# 使用 CPU
device = torch.device('cpu')
# 定義 ResNet-50 模型架構(不使用預訓練權重)
model = models.resnet50(weights=None)
# 修改模型的全連接層,輸出 37 個類別
model.fc = torch.nn.Linear(2048, 37)
# 加載模型權重
model.load_state_dict(torch.load('./resnet50_model_weights.pth', map_location=device))
# 設置模型為評估模式
model.eval()
# 定義影像預處理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 定義類別名稱
class_names = ['Abyssinian (阿比西尼亞貓)', 'American Bulldog (美國鬥牛犬)', 'American Pit Bull Terrier (美國比特鬥牛梗)',
'Basset Hound (巴吉度獵犬)', 'Beagle (米格魯)', 'Bengal (孟加拉貓)', 'Birman (緬甸貓)', 'Bombay (孟買貓)',
'Boxer (拳師犬)', 'British Shorthair (英國短毛貓)', 'Chihuahua (吉娃娃)', 'Egyptian Mau (埃及貓)',
'English Cocker Spaniel (英國可卡犬)', 'English Setter (英國設得蘭犬)', 'German Shorthaired (德國短毛犬)',
'Great Pyrenees (大白熊犬)', 'Havanese (哈瓦那犬)', 'Japanese Chin (日本狆)', 'Keeshond (荷蘭毛獅犬)',
'Leonberger (萊昂貝格犬)', 'Maine Coon (緬因貓)', 'Miniature Pinscher (迷你品犬)', 'Newfoundland (紐芬蘭犬)',
'Persian (波斯貓)', 'Pomeranian (博美犬)', 'Pug (哈巴狗)', 'Ragdoll (布偶貓)', 'Russian Blue (俄羅斯藍貓)',
'Saint Bernard (聖伯納犬)', 'Samoyed (薩摩耶)', 'Scottish Terrier (蘇格蘭梗)', 'Shiba Inu (柴犬)',
'Siamese (暹羅貓)', 'Sphynx (無毛貓)', 'Staffordshire Bull Terrier (史塔福郡鬥牛犬)',
'Wheaten Terrier (小麥色梗)', 'Yorkshire Terrier (約克夏犬)']
# 定義預測函數
def classify_image(image):
image = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
outputs = model(image)
probabilities, indices = torch.topk(outputs, k=3)
probabilities = torch.nn.functional.softmax(probabilities, dim=1)
predictions = [(class_names[idx], prob.item()) for idx, prob in zip(indices[0], probabilities[0])]
return {class_name: f"{prob:.2f}" for class_name, prob in predictions}
# 設定 examples 路徑
examples_path = './examples'
if os.path.exists(examples_path):
print(f"[INFO] Found examples folder at {examples_path}")
else:
print(f"[ERROR] Examples folder not found at {examples_path}")
# 設定範例圖片
examples = [[examples_path + "/" + img] for img in os.listdir(examples_path)]
# 新增下拉選單,顯示所有品種
dropdown = gr.Dropdown(choices=class_names, label="Select a breed", type="value")
# Gradio 介面
demo = gr.Interface(
fn=classify_image,
inputs=[gr.Image(type="pil")], # 移除掉 dropdown 作為輸入
outputs=[gr.Label(num_top_classes=3, label="Top 3 Predictions")],
examples=examples,
title='Oxford Pet 🐈🐕',
description='A ResNet50-based model for classifying 37 different pet breeds.',
article='[Oxford Project](https://github.com/Eric-Chung-0511/Learning-Record/tree/main/Data%20Science%20Projects/The%20Oxford-IIIT%20Pet%20Project)'
)
demo.launch() |