"""Gradio demo: nuclei segmentation + classification with a MONAI bundle.

Loads the bundle's inferer, network, and pre/post-processing pipelines from
``configs/``, restores the weights from ``models/model.pt`` and serves a
Gradio interface that overlays the predicted nuclear type map on the
uploaded image.
"""
from pathlib import Path
import pickle  # NOTE(review): appears unused in this file
import random  # NOTE(review): appears unused in this file

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T  # NOTE(review): appears unused in this file
from monai.bundle import ConfigParser

parser = ConfigParser()
parser.read_config(f="configs/inference.json")
parser.read_meta(f="configs/metadata.json")

inference = parser.get_parsed_content("inferer")
network = parser.get_parsed_content("network_def")
preprocess = parser.get_parsed_content("preprocessing")
postprocess = parser.get_parsed_content("postprocessing")

# This script never moves the network or the inputs to another device, so
# load the checkpoint onto the CPU; this also lets a checkpoint that was saved
# from a GPU load on a CPU-only host.
state_dict = torch.load("models/model.pt", map_location="cpu")
network.load_state_dict(state_dict, strict=True)
# Inference only: put the network in eval mode once instead of per request.
network.eval()

# Colour for each label in the predicted type map (RGB, 0-255).
label2color = {
    0: (0, 0, 0),
    1: (225, 24, 69),  # RED
    2: (135, 233, 17),  # GREEN
    3: (0, 87, 233),  # BLUE
    4: (242, 202, 25),  # YELLOW
    5: (137, 49, 239),  # PURPLE
}

# Sorted so the example gallery order does not depend on the filesystem.
example_files = sorted(str(p) for p in Path("sample_data").glob("*.png"))


def visualize_instance_seg_mask(mask):
    """Render a 2-D label map as an RGB float image with values in [0, 1].

    Each pixel's label is replaced by its colour from ``label2color``.

    Args:
        mask: 2-D array-like of labels (a numpy array or a CPU tensor).

    Returns:
        ``np.ndarray`` of shape ``(H, W, 3)`` and dtype float64, scaled to
        [0, 1].

    Raises:
        KeyError: if ``mask`` contains a label that has no entry in
            ``label2color``.
    """
    mask = np.asarray(mask)
    unknown = [lab for lab in np.unique(mask).tolist() if lab not in label2color]
    if unknown:
        raise KeyError(unknown[0])
    # Lookup table indexed by label: a single vectorized gather replaces the
    # per-pixel Python loop.
    palette = np.zeros((max(label2color) + 1, 3), dtype=np.float64)
    for label, color in label2color.items():
        palette[label] = color
    return palette[mask.astype(np.intp)] / 255


def query_image(img, progress=gr.Progress(track_tqdm=True)):
    """Segment and classify nuclei in one image and return a colour overlay.

    Args:
        img: path of the uploaded image (the input component uses
            ``type="filepath"``); it is handed to the bundle's preprocessing
            pipeline.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm
            progress bars in the UI.

    Returns:
        ``np.ndarray`` of shape ``(H, W, 3)``: the preprocessed image and the
        colour-coded type map blended 50/50, re-oriented for display.
    """
    batch = preprocess({"image": img})

    with torch.no_grad():
        # The inferer expects a batch dimension; add one for the single image.
        pred = inference(batch["image"].unsqueeze(dim=0), network)
    # Drop the batch dimension from every output head before postprocessing.
    batch["pred"] = {k: v.squeeze(dim=0) for k, v in pred.items()}
    batch = postprocess(batch)

    overlay = visualize_instance_seg_mask(batch["type_map"].squeeze())
    image = batch["image"].permute(1, 2, 0).cpu().numpy()
    result = image * 0.5 + overlay * 0.5

    # Re-orient for display: per the original author, the output otherwise
    # appears rotated relative to the upload.
    result = np.fliplr(result)
    return np.rot90(result, k=1)


demo = gr.Interface(
    query_image,
    inputs=[gr.Image(type="filepath")],
    outputs="image",
    title="Medical Image Classification with MONAI - Pathology Nuclei Segmentation Classification",
    description="Please upload an image to see segmentation capabilities of this model",
    # No example gallery when sample_data/ has no PNG files.
    examples=example_files or None,
)

# NOTE(review): `concurrency_count` exists in Gradio 3.x queue() but was
# removed in Gradio 4 (replaced by `default_concurrency_limit`); pin the
# Gradio version accordingly.
demo.queue(concurrency_count=20).launch()