"""Gradio demo that classifies a histology nucleus patch with a MONAI bundle.

At import time the script reads the bundle configuration, builds the network,
loads its weights on the CPU, assembles the Gradio UI, and launches it.
"""

import torch
from monai.bundle import ConfigParser
import gradio as gr

from utils import page_utils

# Load the configuration files that specify the parameters of the MONAI
# workflow.
parser = ConfigParser()
parser.read_config(f="configs/inference.json")  # read the config from the JSON file
parser.read_meta(f="configs/metadata.json")  # read the metadata from the JSON file

inference = parser.get_parsed_content("inferer")
network = parser.get_parsed_content("network_def")
preprocess = parser.get_parsed_content("preprocessing")

state_dict = torch.load("models/model.pt", map_location=torch.device('cpu'))
network.load_state_dict(state_dict, strict=True)  # load the model's parameter dictionary
# The app only ever runs inference, so switch to eval mode once here instead
# of on every request.
network.eval()

# Maps the index of a network output to its human-readable class name.
class_names = {
    0: "Other",
    1: "Inflammatory",
    2: "Epithelial",
    3: "Spindle-Shaped",
}


def classify_image(image_file, label_file):
    """Return the per-class probabilities for one image/label pair.

    Args:
        image_file: Value of the "Histology Image" widget (configured with
            ``type="filepath"``), or ``None`` when nothing was provided.
        label_file: Value of the "Label Image" widget (configured with
            ``type="filepath"``), or ``None`` when nothing was provided.

    Returns:
        A dict mapping every name in ``class_names`` to a ``float``
        probability, suitable for a ``gr.Label`` output.

    Raises:
        gr.Error: If either input is missing.
    """
    if image_file is None:
        raise gr.Error("Need a histology image")
    if label_file is None:
        raise gr.Error("Need a label image")

    batch = preprocess({"image": image_file, "label": label_file})

    with torch.no_grad():
        # Original author's note: the network expects a 4-channel input
        # (3 RGB + 1 label mask). Presumably the preprocessing pipeline
        # builds that tensor under the "image" key -- confirm against
        # configs/inference.json.
        pred = inference(batch['image'].unsqueeze(dim=0), network)

    # Softmax over the last dimension, then keep the single item of the batch.
    prob = pred.softmax(-1).detach().cpu().numpy()[0]
    return {name: float(prob[index]) for index, name in class_names.items()}


def _clear_all():
    """Return one value-resetting update per widget wired to the Clear button."""
    return gr.update(value=None), gr.update(value=None), gr.update(value=None)


# Each entry is an [image path, label path] pair offered as a clickable example.
example_files1 = [
    ['sample_data/Images/test_11_2_0628.png', 'sample_data/Labels/test_11_2_0628.png'],
    ['sample_data/Images/test_9_4_0149.png', 'sample_data/Labels/test_9_4_0149.png'],
    ['sample_data/Images/test_12_3_0292.png', 'sample_data/Labels/test_12_3_0292.png'],
    ['sample_data/Images/test_9_4_0019.png', 'sample_data/Labels/test_9_4_0019.png'],
]
example_files2 = [
    ['sample_data/Images/test_14_3_0433.png', 'sample_data/Labels/test_14_3_0433.png'],
    ['sample_data/Images/test_14_4_0544.png', 'sample_data/Labels/test_14_4_0544.png'],
    ['sample_data/Images/train_1_1_0095.png', 'sample_data/Labels/train_1_1_0095.png'],
    ['sample_data/Images/train_1_3_0020.png', 'sample_data/Labels/train_1_3_0020.png'],
]

with open('index.html', encoding='utf-8') as html_file:
    html_content = html_file.read()

theme = gr.themes.Default(
    primary_hue=page_utils.KALBE_THEME_COLOR,
    secondary_hue=page_utils.KALBE_THEME_COLOR,
).set(
    button_primary_background_fill='*primary_600',
    button_primary_background_fill_hover='*primary_500',
    button_primary_text_color='white',
)

# NOTE(review): the original indentation of this file was lost, so the widget
# nesting below (in particular whether the output label sits beside the input
# column) is reconstructed -- confirm against the intended layout.
with gr.Blocks(theme=theme) as app:
    gr.HTML(html_content)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                inp_img = gr.Image(type="filepath", image_mode="RGB", label="Histology Image", show_label=True)
                label_img = gr.Image(type="filepath", image_mode="L", label="Label Image", show_label=True)
            with gr.Row():
                clear_btn = gr.Button(value="Clear")
                process_btn = gr.Button(value="Process", variant="primary")
        out_txt = gr.Label(label="Probabilities", num_top_classes=4)

    process_btn.click(fn=classify_image, inputs=[inp_img, label_img], outputs=out_txt)
    clear_btn.click(_clear_all, inputs=None, outputs=[inp_img, label_img, out_txt])

    gr.Markdown("## Image Examples")
    # One gr.Examples component per pair, laid out as two rows.
    for example_row in (example_files1, example_files2):
        with gr.Row():
            for example_pair in example_row:
                gr.Examples([example_pair], inputs=[inp_img, label_img])

app.launch()