|
import functools
import math

import cv2
import gradio as gr
import numpy as np
import torch

# Not referenced directly; presumably required so torch.load can unpickle the
# model class stored in model_save.pth -- confirm before removing.
import modeling
|
|
|
|
|
def _is_missing(value):
    """Return True if a form value is absent: ``None`` or a float NaN.

    The numeric inputs in this app are created with ``value=np.nan``, so a NaN
    must be treated like an empty field rather than being fed to the model.
    """
    return value is None or (isinstance(value, float) and math.isnan(value))


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the serialized model from ``./model_save.pth`` once and reuse it.

    The checkpoint is unpickled as a whole model object (``eval()`` is called on
    it directly), so its class definitions must be importable.
    """
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which rejects pickled full modules. If loading fails on a newer torch,
    # pass weights_only=False -- only for a trusted checkpoint, since
    # unpickling can execute arbitrary code.
    model = torch.load('./model_save.pth', map_location=torch.device('cpu'))
    model.eval()
    return model


def _images_to_tensor(images):
    """Resize, stack and normalize the images into a single model input tensor.

    Each image is resized to 512x512 and the images are stacked along a new
    leading axis. Values are divided by 255 and normalized with the ImageNet
    per-channel mean/std. The result is moved to channel-first order and given
    a leading batch axis of size 1.

    Assumes each image is an HxWx3 array, as delivered by ``gr.Image`` -- TODO
    confirm it matches the training pipeline. With that assumption the result
    is a float32 tensor of shape (1, len(images), 3, 512, 512).
    """
    resized = []
    for img in images:
        pic = cv2.resize(img, (512, 512))[np.newaxis, :]
        print("image shape(1, 512, 512, 3):", pic.shape)
        resized.append(pic)
    batch = np.concatenate(resized, axis=0)
    print("image shape(3, 512, 512, 3):", batch.shape)

    batch = batch.astype(np.float32) / 255.0
    batch -= (0.485, 0.456, 0.406)
    batch /= (0.229, 0.224, 0.225)
    # (N, H, W, C) -> (N, C, H, W), then add the batch axis -> (1, N, C, H, W).
    batch = np.transpose(batch, (0, 3, 1, 2))[np.newaxis]
    return torch.from_numpy(batch).float()


def predict(image1, image2, image3, age, BCVA, CDR, IOP):
    """Run the model on three images plus four tabular values.

    Args:
        image1, image2, image3: image arrays from the ``gr.Image`` inputs
            (``None`` when nothing was uploaded).
        age, BCVA, CDR, IOP: numeric form values; ``None`` or NaN means empty.

    Returns:
        A dict mapping "Early" and "Serious" to softmax probabilities. If any
        input is missing, it instead returns a string listing every empty
        field, e.g. "Age value is empty, IOP value is empty".
    """
    fields = [
        ("image1", image1), ("image2", image2), ("image3", image3),
        ("Age", age), ("BCVA", BCVA), ("CDR", CDR), ("IOP", IOP),
    ]
    empty_fields = [label for label, value in fields if _is_missing(value)]
    if empty_fields:
        # Validate before touching the model so bad input never pays the load cost.
        return ", ".join(f"{field} value is empty" for field in empty_fields)

    model = _load_model()
    images = _images_to_tensor([image1, image2, image3])

    # Age and IOP are divided by 100, presumably to match the scaling used in
    # training -- TODO confirm. Shape (1, 4): one row of tabular features.
    tabular = np.array([[age / 100, BCVA, CDR, IOP / 100]], dtype=np.float32)
    tabular = torch.from_numpy(tabular)

    with torch.no_grad():
        output = model(images, tabular)

    index_to_label = {0: "Early", 1: "Serious"}
    output_probs = torch.softmax(output, dim=1).numpy()[0]
    return {index_to_label[i]: float(prob) for i, prob in enumerate(output_probs)}
|
|
|
# Sample rows for gr.Examples. Each row holds three image file names followed by
# Age, BCVA, CDR and IOP, in the same order as the inputs passed to gr.Examples.
# NOTE(review): rows 2 and 3 share identical tabular values -- confirm intended.
examples = [
    ["1-3.jpg", "1-6.jpg", "1-9.jpg", 50, 1.5, 0.6, 23],
    ["2-3.jpg", "2-6.jpg", "2-9.jpg", 36, 1.2, 0.4, 12],
    ["3-3.jpg", "3-6.jpg", "3-9.jpg", 36, 1.2, 0.4, 12],
]
|
|
|
|
|
# Gradio UI: a title, three image uploads, four numeric fields, a submit button
# and a label that displays whatever predict() returns.
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Glau")

    with gr.Row():
        image1, image2, image3 = [
            gr.Image(label=name) for name in ("image1", "image2", "image3")
        ]

    # Numeric fields start at NaN (np.nan), presumably so they appear blank --
    # confirm against the Gradio version in use.
    with gr.Row():
        Age, BCVA, CDR, IOP = [
            gr.Number(label=name, value=np.nan)
            for name in ("Age", "BCVA", "CDR", "IOP")
        ]

    btn = gr.Button("Submit")
    output = gr.Label(label="Output")

    # Positional order must match predict(image1, image2, image3, age, BCVA, CDR, IOP)
    # and the column order of each row in `examples`.
    predict_inputs = [image1, image2, image3, Age, BCVA, CDR, IOP]
    btn.click(fn=predict, inputs=predict_inputs, outputs=output)
    gr.Examples(examples, inputs=predict_inputs)

demo.launch()