"""Gradio demo that classifies a case as "Early" or "Serious".

The app takes three images plus four numeric values (Age, BCVA, CDR, IOP),
feeds them to a model loaded from ``./model_save.pth`` and displays the
per-class probabilities.
"""

import functools

import cv2
import gradio as gr
import numpy as np
import torch

# Not referenced directly in this file. Presumably required so that
# torch.load can resolve the pickled model class -- TODO confirm.
import modeling

MODEL_PATH = './model_save.pth'
IMAGE_SIZE = (512, 512)
# Per-channel normalisation constants applied after scaling pixels to [0, 1].
CHANNEL_MEAN = (0.485, 0.456, 0.406)
CHANNEL_STD = (0.229, 0.224, 0.225)
# Maps the model's output index to the label shown in the UI.
INDEX_TO_LABEL = {0: "Early", 1: "Serious"}


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the model from MODEL_PATH once and return it in eval mode.

    The result is cached so that repeated predictions do not re-read the
    checkpoint from disk. Loading is deferred to the first prediction, so a
    missing checkpoint does not prevent the UI from starting.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point MODEL_PATH at a trusted checkpoint.
    # NOTE(review): recent PyTorch releases changed the default of
    # ``weights_only``; a fully pickled model may need ``weights_only=False``
    # there -- confirm against the installed version.
    model = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
    model.eval()
    return model


def _is_missing(value):
    """Return True when a numeric input is None or NaN."""
    if value is None:
        return True
    try:
        return bool(np.isnan(value))
    except TypeError:
        # Not a numeric type np.isnan understands; treat it as present and
        # let the later conversion report any real problem.
        return False


def _preprocess_images(images):
    """Normalise a stacked image array and convert it to a float tensor.

    Args:
        images: Array of shape (3, 512, 512, 3) as built in ``predict``
            (three images, channels last).

    Returns:
        A float32 tensor of shape (1, 3, 3, 512, 512): a leading batch
        dimension of 1, then the three images in channels-first layout.
    """
    batch = np.array(images).astype(np.float32)
    batch /= 255.0
    # Mean/std broadcast over the trailing channel axis.
    batch -= CHANNEL_MEAN
    batch /= CHANNEL_STD
    batch = np.transpose(batch, (0, 3, 1, 2))  # (3, 512, 512, 3) -> (3, 3, 512, 512)
    batch = np.expand_dims(batch, axis=0)  # add a batch dimension of size 1
    return torch.from_numpy(batch).float()


def predict(image1, image2, image3, age, BCVA, CDR, IOP):
    """Run the model on three images and four numeric values.

    Args:
        image1, image2, image3: Images accepted by ``cv2.resize``
            (presumably the numpy arrays produced by ``gr.Image``).
        age: Age value; divided by 100 before being passed to the model.
        BCVA: BCVA value, passed to the model unchanged.
        CDR: CDR value, passed to the model unchanged.
        IOP: IOP value; divided by 100 before being passed to the model.

    Returns:
        A dict mapping each label ("Early", "Serious") to its softmax
        probability, or a string listing the empty inputs when any input
        is missing.
    """
    numeric_fields = [("Age", age), ("BCVA", BCVA), ("CDR", CDR), ("IOP", IOP)]
    image_fields = [("image1", image1), ("image2", image2), ("image3", image3)]

    # Report every missing input at once. Numeric fields come first so the
    # message is unchanged when only numeric values are missing.
    empty_fields = [label for label, value in numeric_fields if _is_missing(value)]
    empty_fields += [label for label, value in image_fields if value is None]
    if empty_fields:
        return ", ".join([f"{field} value is empty" for field in empty_fields])

    model = _load_model()

    resized = []
    for img in (image1, image2, image3):
        pic = cv2.resize(img, IMAGE_SIZE)
        pic = pic[np.newaxis, :]  # (1, 512, 512, 3)
        print("image shape(1, 512, 512, 3):", pic.shape)  # log the image array shape
        resized.append(pic)
    stacked = np.concatenate(resized, axis=0)  # (3, 512, 512, 3)
    print("image shape(3, 512, 512, 3):", stacked.shape)  # log the stacked array shape

    # Scale the numeric inputs the same way on every call; locals only, the
    # caller's values are not modified.
    age = age / 100
    IOP = IOP / 100
    tabular = np.array([age, BCVA, CDR, IOP]).astype(np.float32)
    tabular = np.expand_dims(tabular, axis=0)  # add a batch dimension of size 1
    tabular = torch.from_numpy(tabular).float()

    img = _preprocess_images(stacked)

    # Inference only; no gradients needed.
    with torch.no_grad():
        output = model(img, tabular)
        output_probs = torch.softmax(output, dim=1).numpy()[0]

    # Pair each output index with its label and probability.
    return {INDEX_TO_LABEL[i]: float(prob) for i, prob in enumerate(output_probs)}


examples = [
    ["1-3.jpg", "1-6.jpg", "1-9.jpg", 50, 1.5, 0.6, 23],
    ["2-3.jpg", "2-6.jpg", "2-9.jpg", 36, 1.2, 0.4, 12],
    ["3-3.jpg", "3-6.jpg", "3-9.jpg", 36, 1.2, 0.4, 12],
]

with gr.Blocks() as demo:
    gr.Markdown("# Multi-Glau")
    with gr.Row():
        image1 = gr.Image(label="image1")
        image2 = gr.Image(label="image2")
        image3 = gr.Image(label="image3")
    with gr.Row():
        Age = gr.Number(label="Age", value=np.nan)
        BCVA = gr.Number(label="BCVA", value=np.nan)
        CDR = gr.Number(label="CDR", value=np.nan)
        IOP = gr.Number(label="IOP", value=np.nan)
    btn = gr.Button("Submit")
    output = gr.Label(label="Output")
    btn.click(
        fn=predict,
        inputs=[image1, image2, image3, Age, BCVA, CDR, IOP],
        outputs=output,
    )
    gr.Examples(examples, inputs=[image1, image2, image3, Age, BCVA, CDR, IOP])

demo.launch()