Space2 / app.py
Aohanah's picture
Update app.py
bcce238 verified
import functools
import math

import cv2
import gradio as gr
import numpy as np
import torch

# NOTE(review): not referenced directly; presumably needed so torch.load can
# resolve the classes of the pickled model.
import modeling
# Labels of the tabular form fields, in predict()'s parameter order.
_TABULAR_FIELDS = ("Age", "BCVA", "CDR", "IOP")
# Maps the model's output class index to the label shown in the UI.
_INDEX_TO_LABEL = {0: "Early", 1: "Serious"}
# Side length every input image is resized to before stacking.
_IMAGE_SIZE = 512
# Per-channel normalization constants (the common ImageNet mean/std;
# presumably what the model was trained with).
_MEAN = (0.485, 0.456, 0.406)
_STD = (0.229, 0.224, 0.225)


def _is_missing(value):
    """Return True when a numeric form value is empty (None) or NaN."""
    return value is None or (isinstance(value, float) and math.isnan(value))


@functools.lru_cache(maxsize=1)
def _load_model(path="./model_save.pth"):
    """Load the serialized model onto the CPU once and reuse it across calls."""
    # NOTE(review): the file appears to hold a fully pickled model object
    # (the loaded result has .eval()). torch >= 2.6 defaults to
    # weights_only=True, which rejects such files; if loading fails, pass
    # weights_only=False -- only for a trusted file.
    model = torch.load(path, map_location=torch.device("cpu"))
    model.eval()
    return model


def _preprocess_image(img):
    """Normalize a stack of HWC images and convert it to a batched tensor.

    Args:
        img: array-like of shape (N, H, W, 3), expected to hold 0-255 pixels.

    Returns:
        A float32 tensor of shape (1, N, 3, H, W).
    """
    img = np.array(img).astype(np.float32)
    img /= 255.0
    img -= _MEAN
    img /= _STD
    img = np.transpose(img, (0, 3, 1, 2))  # (N, H, W, 3) -> (N, 3, H, W)
    img = np.expand_dims(img, axis=0)  # add a batch dimension of size 1
    return torch.from_numpy(img).float()


def predict(image1, image2, image3, age, BCVA, CDR, IOP):
    """Run the Multi-Glau model on three images plus four tabular values.

    Args:
        image1, image2, image3: images from ``gr.Image`` (numpy arrays,
            presumably H x W x 3); each is resized to 512 x 512.
        age, BCVA, CDR, IOP: numeric inputs. ``age`` and ``IOP`` are divided
            by 100 before being passed to the model; BCVA and CDR are used
            as-is.

    Returns:
        A ``{label: probability}`` dict for ``gr.Label`` on success. If any
        input is missing (None, or NaN for the numeric fields), returns a
        comma-separated string naming the empty inputs instead and does not
        run the model.
    """
    images = (image1, image2, image3)

    # Collect every missing input so the user sees all problems at once.
    problems = [
        f"{name} value is empty"
        for name, value in zip(_TABULAR_FIELDS, (age, BCVA, CDR, IOP))
        if _is_missing(value)
    ]
    problems += [
        f"image{index} is empty"
        for index, image in enumerate(images, start=1)
        if image is None
    ]
    if problems:
        return ", ".join(problems)

    model = _load_model()

    # Resize each image and stack them along a new leading axis: (3, 512, 512, C).
    stacked = np.stack(
        [cv2.resize(image, (_IMAGE_SIZE, _IMAGE_SIZE)) for image in images]
    )
    image_tensor = _preprocess_image(stacked)

    # Age and IOP are scaled by 1/100 (presumably matching training-time
    # scaling); the outer list adds a batch dimension -> shape (1, 4).
    tabular = np.array([[age / 100, BCVA, CDR, IOP / 100]], dtype=np.float32)
    tabular = torch.from_numpy(tabular).float()

    with torch.no_grad():
        output = model(image_tensor, tabular)

    output_probs = torch.softmax(output, dim=1).numpy()[0]
    return {_INDEX_TO_LABEL[i]: float(prob) for i, prob in enumerate(output_probs)}
# Example rows for gr.Examples: three image file paths followed by the
# Age, BCVA, CDR and IOP values, in the same order as the UI inputs.
# NOTE(review): rows 2 and 3 share identical tabular values -- confirm this
# is intentional and not a copy-paste slip.
examples = [
    ["1-3.jpg", "1-6.jpg", "1-9.jpg", 50, 1.5, 0.6, 23],
    ["2-3.jpg", "2-6.jpg", "2-9.jpg", 36, 1.2, 0.4, 12],
    ["3-3.jpg", "3-6.jpg", "3-9.jpg", 36, 1.2, 0.4, 12],
]
# Build the Gradio UI: three image inputs, four numeric inputs, a submit
# button wired to predict(), a label output, and clickable examples.
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Glau")
    with gr.Row():
        image1 = gr.Image(label="image1")
        image2 = gr.Image(label="image2")
        image3 = gr.Image(label="image3")
    with gr.Row():
        # Start empty (None) rather than NaN: an untouched NaN field would
        # slip past an `is None` empty-field check and reach the model.
        Age = gr.Number(label="Age", value=None)
        BCVA = gr.Number(label="BCVA", value=None)
        CDR = gr.Number(label="CDR", value=None)
        IOP = gr.Number(label="IOP", value=None)
    btn = gr.Button("Submit")
    output = gr.Label(label="Output")
    # One shared list keeps the click handler and the examples in the same
    # order as predict()'s parameters.
    inputs = [image1, image2, image3, Age, BCVA, CDR, IOP]
    btn.click(fn=predict, inputs=inputs, outputs=output)
    gr.Examples(examples, inputs=inputs)

if __name__ == "__main__":
    demo.launch()