Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import numpy as np | |
from transformers import ViTForImageClassification, ViTModel, ViTImageProcessor | |
from PIL import Image | |
from sklearn.preprocessing import LabelEncoder | |
import pandas as pd | |
from PIL import Image, ExifTags | |
# 中文問候函數 | |
def greet(name):
    """Return a Chinese greeting addressed to ``name``."""
    greeting = "你好,{}!!".format(name)
    return greeting
# 圖像預處理 | |
from PIL import Image, ExifTags | |
def preprocess_image(image):
    """Convert an input image into a tensor the model accepts.

    The image is coerced to a PIL image, rotated/flipped according to its
    EXIF orientation tag (typical for phone photos), converted to RGB and
    passed through the module-level ``feature_extractor``.

    Args:
        image: A NumPy array, a ``PIL.Image.Image``, or anything
            ``PIL.Image.open`` accepts (e.g. a path or file object).

    Returns:
        A ``torch.float32`` tensor built from the first entry of the
        feature extractor's ``pixel_values``.
    """
    # Local import: the file-level imports only bring in Image and ExifTags.
    from PIL import ImageOps

    # Coerce the input to a PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(np.uint8(image))
    elif not isinstance(image, Image.Image):
        image = Image.open(image)

    # Apply the EXIF orientation through the public Pillow helper, which
    # covers every orientation value (including the mirrored ones).
    # This step stays best-effort: missing or corrupt EXIF data must not
    # abort the prediction, but unlike a bare ``except`` this no longer
    # swallows KeyboardInterrupt / SystemExit.
    try:
        transposed = ImageOps.exif_transpose(image)
    except Exception:
        transposed = None
    if transposed is not None:
        image = transposed

    # The feature extractor is fed three-channel images.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Run the feature extractor on a single-image batch and unwrap it.
    inputs = feature_extractor(images=[image])
    image_tensor = torch.tensor(inputs['pixel_values'][0], dtype=torch.float32)
    return image_tensor
# 模型預測 | |
def predict(image_tensor, top_k=5):
    """Return the ``top_k`` most likely classes for one preprocessed image.

    Args:
        image_tensor: A single image tensor without a batch dimension; a
            batch dimension of size one is added before the forward pass.
        top_k: Number of predictions to report. It is coerced to ``int``
            and clamped to the number of classes the model outputs.

    Returns:
        A ``pandas.DataFrame`` with the columns rank ("排名"), class name
        ("藥丸名稱") and probability ("機率"), ordered from most to least
        likely. The probabilities are softmax outputs rounded to 4 places.
    """
    columns = ["排名", "藥丸名稱", "機率"]

    model.eval()
    with torch.no_grad():
        outputs = model(pixel_values=image_tensor.unsqueeze(0))

    # The table column is labelled "probability", so turn the raw logits
    # into a distribution instead of reporting unbounded scores.
    probabilities = torch.softmax(outputs.logits, dim=1)[0]

    # UI sliders may deliver floats, and asking for more entries than there
    # are classes would otherwise index past the end.
    k = max(0, min(int(top_k), probabilities.shape[0]))
    if k == 0:
        return pd.DataFrame([], columns=columns)

    # A single top-k selection replaces two independent full sorts.
    top_probs, top_indices = torch.topk(probabilities, k)
    class_names = encoder.inverse_transform(top_indices.cpu().numpy())

    data = [
        [rank, class_name, round(float(probability), 4)]
        for rank, (class_name, probability) in enumerate(
            zip(class_names, top_probs.cpu().tolist()), start=1
        )
    ]
    return pd.DataFrame(data, columns=columns)
# 主函數(回傳圖片 + 表格) | |
def classify_pill(file, top_k: int = 5):
    """Classify an uploaded pill image and return a preview plus a table.

    Args:
        file: The uploaded image (a NumPy array from the Gradio image
            input, or any other object ``preprocess_image`` accepts).
            ``None`` means nothing was uploaded.
        top_k: How many predictions to list; coerced to ``int`` because UI
            sliders may deliver floats.

    Returns:
        A ``(image, table)`` tuple. On success ``image`` is the picture to
        preview and ``table`` is the prediction DataFrame. When no file was
        given or the prediction fails, ``image`` is ``None`` and ``table``
        is a one-row DataFrame carrying the message, so the second value is
        always something a table output can render (a bare string would be
        interpreted by ``gr.Dataframe`` as a CSV path).
    """
    if file is None:
        return None, _message_table("⚠️ 請上傳一張藥丸圖片!")
    try:
        image_tensor = preprocess_image(file)
        df = predict(image_tensor, int(top_k))
        # Hand the uploaded picture back for the preview output.
        if isinstance(file, np.ndarray):
            img_display = Image.fromarray(np.uint8(file))
        else:
            img_display = file
        return img_display, df
    except Exception as e:
        # UI boundary: report the failure in the table instead of crashing.
        return None, _message_table(f"❌ 預測失敗,錯誤訊息:{e}")


def _message_table(message):
    """Wrap a status message in a one-row table using the result columns."""
    return pd.DataFrame(
        [["", message, ""]],
        columns=["排名", "藥丸名稱", "機率"],
    )
# Load the LabelEncoder classes.
encoder = LabelEncoder()
# SECURITY NOTE: allow_pickle=True unpickles arbitrary objects, which can
# execute code. Only load an 'encoder.npy' that ships with this app.
encoder.classes_ = np.load('encoder.npy', allow_pickle=True)

# Load the model.
pretrained_model = ViTModel.from_pretrained('pillIdentifierAI/pillIdentifier')
feature_extractor = ViTImageProcessor(
    # ``size`` is the resize parameter ViTImageProcessor actually reads; the
    # previous ``image_size`` keyword was not one of its parameters.
    size={"height": 224, "width": 224},
    do_resize=True,
    do_normalize=True,
    # NOTE(review): with do_rescale=False the pixel values are not scaled
    # by 1/255 before normalisation - confirm this matches how the
    # checkpoint was trained.
    do_rescale=False,
    image_mean=[0.5, 0.5, 0.5],
    image_std=[0.5, 0.5, 0.5],
)
config = pretrained_model.config
config.num_labels = len(encoder.classes_)
# NOTE(review): the classification head is created from the config here and
# only the ViT backbone is swapped in below, so the head's weights do not
# come from the checkpoint in this code - confirm whether the checkpoint
# provides a trained head that should be loaded instead.
model = ViTForImageClassification(config)
model.vit = pretrained_model
model.eval()
# Build the Gradio interface and start it.
# Inputs: the uploaded picture plus a slider choosing how many predictions
# to list. Outputs: a preview of the picture and the prediction table.
_image_input = gr.Image(type="numpy", label="📸 上傳藥丸圖片")
_top_k_slider = gr.Slider(1, 10, value=5, step=1, label="🔢 顯示前幾個預測結果")
_image_preview = gr.Image(label="🔍 上傳圖片預覽")
_result_table = gr.Dataframe(
    label="📝 預測結果(中文表格)",
    headers=["排名", "藥丸名稱", "機率"],
)

iface = gr.Interface(
    fn=classify_pill,
    inputs=[_image_input, _top_k_slider],
    outputs=[_image_preview, _result_table],
    title="藥丸辨識器 💊",
    description="上傳藥丸圖片,我們會顯示圖片預覽並以表格形式列出前幾個可能的藥名與機率。",
)
iface.launch(share=True)