File size: 3,875 Bytes
fb98024
 
 
 
 
 
b4dfea2
1db0189
fb98024
eafb4f8
fb98024
5e42a25
fb98024
eafb4f8
1db0189
 
5e42a25
1db0189
 
 
 
 
5e42a25
 
 
 
 
1db0189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e42a25
 
 
1db0189
5e42a25
 
 
 
 
eafb4f8
5e42a25
 
 
b4dfea2
5e42a25
 
 
 
 
b4dfea2
5e42a25
 
b4dfea2
 
 
 
 
fb98024
b15cbdc
5e42a25
eafb4f8
b15cbdc
eafb4f8
 
b4dfea2
b15cbdc
 
 
 
 
 
eafb4f8
b15cbdc
fb98024
5e42a25
fb98024
 
 
5e42a25
fb98024
 
 
 
 
 
 
 
 
 
 
eafb4f8
fb98024
 
 
 
5e42a25
 
 
eafb4f8
 
 
 
b15cbdc
 
 
 
eafb4f8
b15cbdc
5e42a25
fb98024
5e42a25
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
import torch
import numpy as np
from transformers import ViTForImageClassification, ViTModel, ViTImageProcessor
from PIL import Image
from sklearn.preprocessing import LabelEncoder
import pandas as pd
from PIL import Image, ExifTags

# Greeting helper that answers in Chinese.
def greet(name):
    """Return a Chinese greeting string addressed to *name*."""
    greeting = "你好,{}!!".format(name)
    return greeting

# 圖像預處理
from PIL import Image, ExifTags

def preprocess_image(image):
    """Convert an input image into the pixel tensor fed to the ViT model.

    Args:
        image: A ``numpy.ndarray`` (what ``gr.Image(type="numpy")`` delivers),
            a ``PIL.Image.Image``, or anything ``PIL.Image.open`` accepts
            (e.g. a path or a binary file object).

    Returns:
        A ``torch.float32`` tensor built from the first entry of the
        ``pixel_values`` produced by the module-level ``feature_extractor``.

    Images that carry an EXIF Orientation tag (typical of iPhone photos) are
    transposed upright first, including the mirrored orientations. Images
    without orientation data are used as-is.
    """
    # ImageOps is not among the file-level PIL imports, so import it here.
    from PIL import ImageOps

    # Normalise the input to a PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(np.uint8(image))
    elif not isinstance(image, Image.Image):
        image = Image.open(image)

    # Apply the EXIF Orientation tag via Pillow's public helper. It returns an
    # unchanged copy when no orientation is present, so a missing EXIF block
    # is not an error. Malformed EXIF data is still skipped on a best-effort
    # basis, but only for these specific errors. A bare `except:` would also
    # hide KeyboardInterrupt and genuine bugs.
    try:
        image = ImageOps.exif_transpose(image)
    except (AttributeError, KeyError, TypeError, ValueError, OSError):
        pass

    # The processor expects 3-channel input.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    inputs = feature_extractor(images=[image])
    return torch.tensor(inputs['pixel_values'][0], dtype=torch.float32)

# Model inference.
def predict(image_tensor, top_k=5):
    """Rank the pill classes for one preprocessed image.

    Args:
        image_tensor: A single image tensor as returned by ``preprocess_image``,
            without a batch dimension (one is added here).
        top_k: How many of the highest-scoring classes to return. The value is
            coerced to ``int``, because slider values may arrive as floats. It
            is clamped to the range ``[0, number of classes]``.

    Returns:
        A ``pandas.DataFrame`` with columns ``["排名", "藥丸名稱", "機率"]``
        (rank, pill name, probability). It has one row per returned class,
        best first. Probabilities are softmax scores rounded to 4 decimals.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(pixel_values=image_tensor.unsqueeze(0))
        # The column is labelled as a probability, so turn the logits into a
        # distribution. Raw logits are unbounded and can be negative.
        probs = torch.softmax(outputs.logits, dim=-1)[0]

    # Clamp so that asking for more results than classes cannot raise.
    k = max(0, min(int(top_k), probs.shape[-1]))
    top_probs, top_indices = torch.topk(probs, k)
    # Decode all selected indices in a single call.
    names = encoder.inverse_transform(top_indices.cpu().numpy()) if k else []

    data = [
        [rank, name, round(float(p), 4)]
        for rank, (name, p) in enumerate(zip(names, top_probs.tolist()), start=1)
    ]
    return pd.DataFrame(data, columns=["排名", "藥丸名稱", "機率"])

# Main entry point for the UI (returns the preview image plus a results table).
def _message_table(message):
    """Wrap a user-facing status/error message in a one-row results table.

    The result is bound to a ``gr.Dataframe`` output, which expects tabular
    data. A bare ``str`` is interpreted as a CSV path rather than shown as
    text, so messages are returned in the same three columns as real results.
    """
    return pd.DataFrame([["", message, ""]], columns=["排名", "藥丸名稱", "機率"])


def classify_pill(file, top_k: int = 5):
    """Classify an uploaded pill image for the Gradio interface.

    Args:
        file: The uploaded image, or ``None`` when nothing was uploaded. It is
            a ``numpy.ndarray`` from ``gr.Image(type="numpy")``, or any input
            ``preprocess_image`` accepts.
        top_k: Number of predictions to list. It is coerced to ``int``, since
            slider values may arrive as floats.

    Returns:
        ``(preview, table)``. ``preview`` is the image echoed back for
        display, or ``None`` on missing input or failure. ``table`` is a
        ``pandas.DataFrame``: the ranked predictions, or, on missing input or
        failure, a single row carrying the user-facing message.
    """
    import logging

    if file is None:
        return None, _message_table("⚠️ 請上傳一張藥丸圖片!")
    try:
        image_tensor = preprocess_image(file)
        df = predict(image_tensor, int(top_k))
        # Echo the upload back for the preview pane (full size; no resizing).
        if isinstance(file, np.ndarray):
            img_display = Image.fromarray(np.uint8(file))
        else:
            img_display = file
        return img_display, df
    except Exception as e:
        # UI boundary: keep the app alive and show the error, but record the
        # traceback instead of discarding it.
        logging.getLogger(__name__).exception("Pill classification failed")
        return None, _message_table(f"❌ 預測失敗,錯誤訊息:{e}")

# Load the LabelEncoder that maps class indices back to pill names.
# NOTE: allow_pickle=True unpickles the file's contents -- only ever load a
# trusted, locally shipped encoder.npy.
encoder = LabelEncoder()
encoder.classes_ = np.load('encoder.npy', allow_pickle=True)

# Image preprocessing used by preprocess_image().
# NOTE(review): with do_rescale=False, raw 0-255 pixel values are normalised
# with mean/std 0.5, which gives inputs far outside the usual [-1, 1] range
# for ViT. Keep this only if the checkpoint was trained with exactly this
# preprocessing -- TODO confirm against the training code.
feature_extractor = ViTImageProcessor(
    image_size=224,
    do_resize=True,
    do_normalize=True,
    do_rescale=False,
    image_mean=[0.5, 0.5, 0.5],
    image_std=[0.5, 0.5, 0.5],
)

# Load the backbone AND the classification head from the checkpoint.
# Previously the code built ViTForImageClassification(config) from scratch
# and then replaced only `.vit`. That left the classifier layer randomly
# initialised, so it was never loaded. If the checkpoint has no head, or
# one of a different size, from_pretrained initialises the head randomly
# and logs a warning. That is the same outcome the old code always had.
model = ViTForImageClassification.from_pretrained(
    'pillIdentifierAI/pillIdentifier',
    num_labels=len(encoder.classes_),
    ignore_mismatched_sizes=True,
)
model.eval()

# Keep the previous module-level names available for any code that uses them.
pretrained_model = model.vit
config = model.config

# Build the Gradio UI and start serving it.
ui_inputs = [
    gr.Image(type="numpy", label="📸 上傳藥丸圖片"),
    gr.Slider(1, 10, value=5, step=1, label="🔢 顯示前幾個預測結果"),
]
ui_outputs = [
    gr.Image(label="🔍 上傳圖片預覽"),
    gr.Dataframe(label="📝 預測結果(中文表格)", headers=["排名", "藥丸名稱", "機率"]),
]

iface = gr.Interface(
    fn=classify_pill,
    inputs=ui_inputs,
    outputs=ui_outputs,
    title="藥丸辨識器 💊",
    description="上傳藥丸圖片,我們會顯示圖片預覽並以表格形式列出前幾個可能的藥名與機率。",
)

# share=True asks Gradio to create a public share link in addition to the
# local server.
iface.launch(share=True)