"""Gradio demo that runs a MONAI-bundle network on an image and overlays the predicted label map."""

import json
import os
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from monai.bundle import ConfigParser

from utils import page_utils


def _env_flag(name, default=False):
    """Interpret environment variable *name* as a boolean.

    Unset variables yield *default*.  Empty strings and the usual "off"
    spellings ("0", "false", "no", "off", case-insensitive) are False; any
    other value is True.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in {"", "0", "false", "no", "off"}


with open("configs/inference.json", encoding="utf-8") as f:
    inference_config = json.load(f)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE: the device is chosen here in code; any "device" entry in the config
# file is overwritten and therefore has no effect.
inference_config["device"] = device

parser = ConfigParser()
parser.read_config(f=inference_config)
parser.read_meta(f="configs/metadata.json")

inference = parser.get_parsed_content("inferer")
# loader = parser.get_parsed_content("dataloader")  # not used by this app
network = parser.get_parsed_content("network_def")
preprocess = parser.get_parsed_content("preprocessing")
postprocess = parser.get_parsed_content("postprocessing")

use_fp16 = _env_flag("USE_FP16")
# Half precision is only applied when CUDA is available.
_half_precision = use_fp16 and torch.cuda.is_available()

# map_location lets a checkpoint that was saved from a GPU load on a CPU-only host.
state_dict = torch.load("models/model.pt", map_location=device)
network.load_state_dict(state_dict, strict=True)

network = network.to(device)
network.eval()

if _half_precision:
    network = network.half()

# RGB colour (0-255) used to paint each label value of the predicted map.
label2color = {
    0: (0, 0, 0),
    1: (225, 24, 69),  # RED
    2: (135, 233, 17),  # GREEN
    3: (0, 87, 233),  # BLUE
    4: (242, 202, 25),  # YELLOW
    5: (137, 49, 239),  # PURPLE
}

example_files = list(Path("sample_data").glob("*.png"))


def visualize_instance_seg_mask(mask):
    """Colourise a label mask using ``label2color``.

    Args:
        mask: 2-D array (NumPy array or torch tensor) whose values are keys
            of ``label2color``.

    Returns:
        A float64 array of shape ``mask.shape + (3,)`` holding RGB values
        scaled to the range [0, 1].

    Raises:
        KeyError: if the mask contains a value that has no entry in
            ``label2color``.
    """
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    mask = np.asarray(mask)

    known = np.array(sorted(label2color))
    present = np.unique(mask)
    unknown = present[~np.isin(present, known)]
    if unknown.size:
        raise KeyError(f"mask contains labels without a colour: {unknown.tolist()}")

    # Lookup table indexed by label value; one fancy-indexing step replaces
    # the per-pixel Python loop.
    lut = np.zeros((int(known.max()) + 1, 3), dtype=np.float64)
    for label, rgb in label2color.items():
        lut[label] = rgb

    return lut[mask.astype(np.intp)] / 255


def query_image(img):
    """Run the network on one image and return the prediction blended over it.

    Args:
        img: the input handed over by the Gradio ``Image`` component, which is
            configured with ``type="filepath"``; it is passed to the bundle's
            preprocessing under the ``"image"`` key.

    Returns:
        A NumPy array: the 50/50 blend of the post-processed image and the
        coloured ``"type_map"``, with its first two axes swapped.
    """
    data = {"image": img}
    batch = preprocess(data)
    batch["image"] = batch["image"].to(device)

    if _half_precision:
        batch["image"] = batch["image"].half()

    with torch.no_grad():
        # The inferer is called with a batched tensor, so add a leading batch axis.
        pred = inference(batch["image"].unsqueeze(dim=0), network)

    batch["pred"] = pred
    # Drop the batch axis again from every prediction output.
    for k, v in batch["pred"].items():
        batch["pred"][k] = v.squeeze(dim=0)

    batch = postprocess(batch)

    result = visualize_instance_seg_mask(batch["type_map"].squeeze())

    # Combine image and coloured mask with equal weight.
    # NOTE(review): assumes the post-processed image is channel-first with
    # values comparable to the [0, 1] mask colours - confirm against the bundle.
    result = batch["image"].permute(1, 2, 0).cpu().numpy() * 0.5 + result * 0.5

    # Solve rotating problem: a left-right flip followed by a 90-degree
    # rotation swaps the two spatial axes.
    result = np.fliplr(result)
    result = np.rot90(result, k=1)

    return result


# Load the HTML description shown in the interface.
with open('index.html', encoding='utf-8') as f:
    html_content = f.read()

demo = gr.Interface(
    query_image,
    inputs=[gr.Image(type="filepath")],
    outputs="image",
    theme=gr.themes.Default(
        primary_hue=page_utils.KALBE_THEME_COLOR,
        secondary_hue=page_utils.KALBE_THEME_COLOR,
    ).set(
        button_primary_background_fill="*primary_600",
        button_primary_background_fill_hover="*primary_500",
        button_primary_text_color="white",
    ),
    description=html_content,
    examples=example_files,
)

demo.queue(max_size=10).launch()