# pill_identifier / app.py
# shiue2000's picture
# Update app.py
# 1db0189 verified
import gradio as gr
import torch
import numpy as np
from transformers import ViTForImageClassification, ViTModel, ViTImageProcessor
from PIL import Image
from sklearn.preprocessing import LabelEncoder
import pandas as pd
from PIL import Image, ExifTags
# Chinese-language greeting helper
def greet(name: str) -> str:
    """Return a Chinese greeting addressed to *name*.

    Args:
        name: Value interpolated into the greeting text.

    Returns:
        The greeting string with *name* inserted after "你好,".
    """
    return f"你好,{name}!!"
# 圖像預處理
from PIL import Image, ExifTags
def preprocess_image(image):
    """Convert an input image into the float32 tensor fed to the model.

    The image is normalised to a PIL image, rotated according to its EXIF
    orientation tag when one is present (typical for phone photos),
    converted to RGB and passed through the module-level
    ``feature_extractor``.

    Args:
        image: A ``numpy.ndarray``, a ``PIL.Image.Image``, or anything
            ``PIL.Image.open`` accepts (for example a file path).

    Returns:
        A ``torch.float32`` tensor built from the first ``pixel_values``
        entry returned by ``feature_extractor``.
    """
    # Function-scope imports keep this block self-contained; both modules
    # are cached after the first call.
    import logging

    from PIL import ImageOps

    # Normalise every supported input type to a PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(np.uint8(image))
    elif not isinstance(image, Image.Image):
        image = Image.open(image)

    # Apply the EXIF orientation. exif_transpose covers all eight
    # orientation values (including the mirrored ones) through Pillow's
    # public API. Images without EXIF data -- e.g. those built from a numpy
    # array above -- come back unrotated.
    try:
        transposed = ImageOps.exif_transpose(image)
        if transposed is not None:
            image = transposed
    except Exception:
        # Orientation correction is best-effort: damaged metadata must not
        # block classification, but the failure should leave a trace.
        logging.getLogger(__name__).warning(
            "Could not apply EXIF orientation; using the image as-is.",
            exc_info=True,
        )

    # The processor is fed three-channel images only.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Run the image processor on a single-image batch and keep that image.
    inputs = feature_extractor(images=[image])
    return torch.tensor(inputs["pixel_values"][0], dtype=torch.float32)
# Model prediction
def predict(image_tensor, top_k=5):
    """Classify one image tensor and tabulate the most likely classes.

    Args:
        image_tensor: Tensor for a single image; a leading batch axis is
            added with ``unsqueeze(0)`` before it is passed to the model as
            ``pixel_values``.
        top_k: Number of best-scoring classes to report. Coerced to ``int``
            and clamped to the range 1..number of classes.

    Returns:
        A ``pandas.DataFrame`` with the columns 排名 (rank), 藥丸名稱
        (pill name) and 機率 (probability), one row per reported class,
        most probable first.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(pixel_values=image_tensor.unsqueeze(0))

    # Only one image was submitted, so keep the first (and only) row.
    logits = outputs.logits.numpy()[0]

    # The table column is labelled as a probability, so convert the raw
    # logits with a softmax. Subtracting the maximum first keeps np.exp
    # from overflowing and does not change the result.
    exponentials = np.exp(logits - logits.max())
    probabilities = exponentials / exponentials.sum()

    # UI widgets may deliver the count as a float, and a request for more
    # classes than exist would otherwise index past the end.
    k = min(max(int(top_k), 1), probabilities.shape[0])

    # One argsort gives both the class order and, via indexing, the scores.
    top_indices = np.argsort(probabilities)[::-1][:k]
    class_names = encoder.inverse_transform(top_indices)

    rows = [
        [rank, class_name, round(float(probabilities[index]), 4)]
        for rank, (index, class_name) in enumerate(
            zip(top_indices, class_names), start=1
        )
    ]
    return pd.DataFrame(rows, columns=["排名", "藥丸名稱", "機率"])
# Main entry point (returns image + table)
def _message_table(message):
    """Wrap a status message in a one-cell DataFrame for table outputs."""
    return pd.DataFrame({"訊息": [message]})


def classify_pill(file, top_k: int = 5):
    """Classify an uploaded pill image and return a preview plus results.

    Args:
        file: The uploaded image, or ``None`` when nothing was uploaded.
            It is forwarded to ``preprocess_image``.
        top_k: Number of predictions to list; coerced to ``int`` before it
            is forwarded to ``predict``.

    Returns:
        A ``(image, table)`` pair. On success ``image`` is the upload
        (converted to a PIL image when it arrived as a numpy array) and
        ``table`` is the DataFrame produced by ``predict``. When no file was
        given or prediction failed, ``image`` is ``None`` and ``table`` is a
        one-cell DataFrame holding the message, so the second value is
        always something a table component can display.
    """
    if file is None:
        return None, _message_table("⚠️ 請上傳一張藥丸圖片!")

    try:
        image_tensor = preprocess_image(file)
        # A slider may hand the count over as a float.
        df = predict(image_tensor, int(top_k))

        # Return the input image for the preview output.
        if isinstance(file, np.ndarray):
            img_display = Image.fromarray(np.uint8(file))
        else:
            img_display = file
        return img_display, df
    except Exception as e:
        # UI boundary: report the failure in the table instead of crashing.
        return None, _message_table(f"❌ 預測失敗,錯誤訊息:{e}")
# Load the LabelEncoder
encoder = LabelEncoder()
# NOTE(security): allow_pickle=True lets np.load unpickle arbitrary objects.
# That is acceptable only because encoder.npy ships with the app; never point
# this at a file from an untrusted source.
encoder.classes_ = np.load('encoder.npy', allow_pickle=True)

# Image processor used by preprocess_image.
# NOTE(review): ViTImageProcessor documents `size`, not `image_size` --
# confirm this keyword has any effect.
# NOTE(review): do_rescale=False combined with mean/std of 0.5 means 0-255
# pixel values are normalised without first being scaled to 0-1 -- confirm
# this matches how the checkpoint was trained before changing it.
feature_extractor = ViTImageProcessor(
    image_size=224,
    do_resize=True,
    do_normalize=True,
    do_rescale=False,
    image_mean=[0.5, 0.5, 0.5],
    image_std=[0.5, 0.5, 0.5],
)

# Load the model.
# Previously the checkpoint was loaded as a bare ViTModel and dropped into a
# freshly constructed ViTForImageClassification, so the classifier layer
# never received any weights from the checkpoint. Loading the classification
# model directly lets from_pretrained restore that layer when the checkpoint
# contains one (NOTE(review): confirm that it does). With
# ignore_mismatched_sizes=True, a head whose size differs from the label
# count is re-initialised instead of raising, which matches the old result.
model = ViTForImageClassification.from_pretrained(
    'pillIdentifierAI/pillIdentifier',
    num_labels=len(encoder.classes_),
    ignore_mismatched_sizes=True,
)
# Keep the module-level names the previous version of this script exposed.
pretrained_model = model.vit
config = model.config
model.eval()
# Launch Gradio
# Input widgets: the uploaded picture and the number of predictions to list.
_image_input = gr.Image(type="numpy", label="📸 上傳藥丸圖片")
_top_k_slider = gr.Slider(1, 10, value=5, step=1, label="🔢 顯示前幾個預測結果")

# Output widgets: a preview of the upload and the prediction table.
_preview_output = gr.Image(label="🔍 上傳圖片預覽")
_result_table = gr.Dataframe(
    label="📝 預測結果(中文表格)",
    headers=["排名", "藥丸名稱", "機率"],
)

# Wire classify_pill to the widgets above.
iface = gr.Interface(
    fn=classify_pill,
    inputs=[_image_input, _top_k_slider],
    outputs=[_preview_output, _result_table],
    title="藥丸辨識器 💊",
    description="上傳藥丸圖片,我們會顯示圖片預覽並以表格形式列出前幾個可能的藥名與機率。",
)

# Start the web app; share=True asks Gradio for a public share link.
iface.launch(share=True)