"""Gradio demo for MONAI pathology nuclei classification.

Loads the bundle configuration (``configs/inference.json``,
``configs/metadata.json``) and the network weights (``models/model.pt``), then
serves a web UI. The UI takes an RGB image patch plus a single-channel label
mask and shows the per-class probabilities predicted for the nucleus.
"""
import os

import gradio as gr
import torch
from monai.bundle import ConfigParser

parser = ConfigParser()
# Load the configuration files that specify the parameters for running the MONAI workflow.
parser.read_config(f="configs/inference.json")  # read the config from the JSON file
parser.read_meta(f="configs/metadata.json")  # read the metadata from the JSON file

inference = parser.get_parsed_content("inferer")
network = parser.get_parsed_content("network_def")
preprocess = parser.get_parsed_content("preprocessing")

# map_location="cpu" lets a checkpoint saved on a GPU load on CPU-only hosts;
# load_state_dict then copies the tensors onto whatever device `network` is on.
state_dict = torch.load("models/model.pt", map_location="cpu")
network.load_state_dict(state_dict, strict=True)  # load the model's parameter dictionary
# Nothing in this script trains the network, so switch to eval mode once here
# instead of on every request.
network.eval()

class_names = {
    0: "Other",
    1: "Inflammatory",
    2: "Epithelial",
    3: "Spindle-Shaped",
}


def classify_image(image_file, label_file):
    """Classify the nucleus marked by *label_file* within *image_file*.

    Args:
        image_file: path to the RGB image patch (from ``gr.Image(type="filepath")``).
        label_file: path to the label mask for the same patch.

    Returns:
        dict mapping each class name to its softmax probability, in the format
        expected by ``gr.Label``.

    Raises:
        gr.Error: if either input is missing (shown as a message in the UI).
    """
    if not image_file or not label_file:
        # Give the user a readable message instead of a preprocessing traceback.
        raise gr.Error("Please provide both an image and a label mask.")
    batch = preprocess({"image": image_file, "label": label_file})
    with torch.no_grad():
        # The network expects a 4-channel input (3 RGB + 1 label mask);
        # unsqueeze adds the batch dimension.
        pred = inference(batch["image"].unsqueeze(dim=0), network)
    prob = pred.softmax(-1).cpu().numpy()[0]
    return {name: float(prob[index]) for index, name in class_names.items()}


def _example_pair(name):
    """Return the ``[image_path, label_path]`` pair for sample patch *name*."""
    # os.path.join keeps these paths valid on every OS; the hard-coded
    # Windows backslashes only resolved on Windows hosts.
    return [
        os.path.join("sample_data", "Images", f"{name}.png"),
        os.path.join("sample_data", "Labels", f"{name}.png"),
    ]


example_files1 = [
    _example_pair(name)
    for name in (
        "test_11_2_0628",
        "test_9_4_0149",
        "test_12_3_0292",
        "test_9_4_0019",
    )
]
example_files2 = [
    _example_pair(name)
    for name in (
        "test_14_3_0433",
        "test_14_4_0544",
        "train_1_1_0095",
        "train_1_3_0020",
    )
]

with open("Description.md", "r", encoding="utf-8") as description_file:
    markdown_content = description_file.read()

with gr.Blocks() as app:
    gr.Markdown("# Pathology Nuclei Classification")
    gr.Markdown(markdown_content)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                inp_img = gr.Image(type="filepath", image_mode="RGB")
                label_img = gr.Image(type="filepath", image_mode="L")
            with gr.Row():
                process_btn = gr.Button(value="Process")
                clear_btn = gr.Button(value="Clear")
        out_txt = gr.Label(label="Probabilities", num_top_classes=len(class_names))

    process_btn.click(fn=classify_image, inputs=[inp_img, label_img], outputs=out_txt)
    # Reset both inputs and the probability display.
    clear_btn.click(
        lambda: (
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=None),
        ),
        inputs=None,
        outputs=[inp_img, label_img, out_txt],
    )

    gr.Markdown("## Image Examples")
    with gr.Row():
        for example in example_files1:
            gr.Examples([example], inputs=[inp_img, label_img])
    with gr.Row():
        for example in example_files2:
            gr.Examples([example], inputs=[inp_img, label_img])

if __name__ == "__main__":
    app.launch()