# Last commit by fadindashafira: "update error handler (#2)" (d1ab033)
import torch
from monai.bundle import ConfigParser
import gradio as gr
from utils import page_utils
# Build the MONAI bundle: read the inference config (which declares the
# inferer, the network definition and the preprocessing pipeline) together
# with the bundle metadata.
parser = ConfigParser()
parser.read_config(f="configs/inference.json")
parser.read_meta(f="configs/metadata.json")

# Instantiate the configured components, in the same order as declared here,
# since parsing a config entry may construct objects.
inference, network, preprocess = (
    parser.get_parsed_content(config_key)
    for config_key in ("inferer", "network_def", "preprocessing")
)

# Restore the trained weights onto the CPU. strict=True turns any mismatch
# between checkpoint keys and the network definition into an error.
# NOTE(review): torch.load unpickles the checkpoint; if the installed torch
# supports it, consider weights_only=True.
state_dict = torch.load("models/model.pt", map_location="cpu")
network.load_state_dict(state_dict, strict=True)
# Output index -> human-readable class name. The index is the position of
# the class's score in the model's probability vector.
class_names = dict(enumerate(("Other", "Inflammatory", "Epithelial", "Spindle-Shaped")))
def classify_image(image_file, label_file):
    """Classify a histology image plus its label image and return class probabilities.

    Args:
        image_file: File path of the histology image (the UI component uses
            ``type="filepath"``), or ``None`` if nothing was provided.
        label_file: File path of the matching label image, or ``None``.

    Returns:
        dict mapping every name in ``class_names`` to its softmax probability
        as a Python float.

    Raises:
        gr.Error: if either input is missing; Gradio surfaces the message to the user.
    """
    if image_file is None:
        raise gr.Error("Need a histology image")
    if label_file is None:
        raise gr.Error("Need a label image")

    # The bundle's preprocessing pipeline consumes a dict keyed by "image" and "label".
    batch = preprocess({"image": image_file, "label": label_file})

    network.eval()  # inference-mode behavior for layers such as dropout / batch norm
    with torch.no_grad():
        # unsqueeze adds a leading batch dimension of size 1.
        # NOTE(review): the model presumably expects 4 input channels
        # (3 RGB + 1 label mask) after preprocessing — confirm against the bundle config.
        pred = inference(batch["image"].unsqueeze(dim=0), network)
        # No .detach() needed: nothing tracks gradients inside no_grad.
        # Assumes pred is (1, num_classes); [0] drops the batch dimension.
        prob = pred.softmax(-1).cpu().numpy()[0]

    return {name: float(prob[idx]) for idx, name in class_names.items()}
# Example inputs for the UI: each entry is an [image path, label path] pair
# where the histology image and its label image share the same file name.
example_files1 = [
    [f"sample_data/Images/{stem}.png", f"sample_data/Labels/{stem}.png"]
    for stem in ("test_11_2_0628", "test_9_4_0149", "test_12_3_0292", "test_9_4_0019")
]
example_files2 = [
    [f"sample_data/Images/{stem}.png", f"sample_data/Labels/{stem}.png"]
    for stem in ("test_14_3_0433", "test_14_4_0544", "train_1_1_0095", "train_1_3_0020")
]
# Static HTML rendered at the top of the page.
with open("index.html", encoding="utf-8") as html_file:
    html_content = html_file.read()

# Default Gradio theme recolored with the Kalbe palette; primary buttons use
# the 600 shade, switching to the 500 shade on hover.
with gr.Blocks(
    theme=gr.themes.Default(
        primary_hue=page_utils.KALBE_THEME_COLOR,
        secondary_hue=page_utils.KALBE_THEME_COLOR,
    ).set(
        button_primary_background_fill="*primary_600",
        button_primary_background_fill_hover="*primary_500",
        button_primary_text_color="white",
    )
) as app:
    gr.HTML(html_content)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                inp_img = gr.Image(type="filepath", image_mode="RGB", label="Histology Image", show_label=True)
                label_img = gr.Image(type="filepath", image_mode="L", label="Label Image", show_label=True)
            with gr.Row():
                clear_btn = gr.Button(value="Clear")
                process_btn = gr.Button(value="Process", variant="primary")
        # Show a score for every class; derived from class_names so the widget
        # stays in sync if classes are added or removed.
        out_txt = gr.Label(label="Probabilities", num_top_classes=len(class_names))

    process_btn.click(fn=classify_image, inputs=[inp_img, label_img], outputs=out_txt)
    # Reset both inputs and the probability output.
    clear_btn.click(
        lambda: (gr.update(value=None), gr.update(value=None), gr.update(value=None)),
        inputs=None,
        outputs=[inp_img, label_img, out_txt],
    )

    gr.Markdown("## Image Examples")
    # One row per example set; each [image, label] pair gets its own Examples
    # widget inside that row.
    for example_set in (example_files1, example_files2):
        with gr.Row():
            for example in example_set:
                gr.Examples([example], inputs=[inp_img, label_img])

app.launch()