File size: 2,986 Bytes
a1f1417
 
54f43fd
a1f1417
54f43fd
 
a1f1417
 
 
15c4bc8
 
a1f1417
 
 
 
 
 
 
 
 
54f43fd
 
a1f1417
54f43fd
 
 
 
 
 
 
 
a1f1417
 
54f43fd
 
 
a1f1417
 
 
 
 
 
54f43fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f845d2
54f43fd
 
a1f1417
 
 
 
4adfcec
54f43fd
 
4adfcec
54f43fd
 
 
4adfcec
54f43fd
4adfcec
54f43fd
4adfcec
54f43fd
 
4adfcec
54f43fd
 
 
4adfcec
54f43fd
ed80dbe
 
15c4bc8
 
ed80dbe
54f43fd
a1f1417
4adfcec
54f43fd
15c4bc8
 
 
 
 
 
4adfcec
54f43fd
 
b88ad11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import json
import os
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from monai.bundle import ConfigParser

from utils import page_utils

# Load the bundle's inference configuration; its "device" entry is overridden below.
with open("configs/inference.json", encoding="utf-8") as f:
    inference_config = json.load(f)

# Use the first CUDA GPU when one is available, otherwise run on CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# * NOTE: device must be hardcoded, config file won't affect the device selection
inference_config["device"] = device

parser = ConfigParser()
parser.read_config(f=inference_config)
parser.read_meta(f="configs/metadata.json")

# Instantiate the pipeline components declared in the bundle config.
inference = parser.get_parsed_content("inferer")
network = parser.get_parsed_content("network_def")
preprocess = parser.get_parsed_content("preprocessing")
postprocess = parser.get_parsed_content("postprocessing")

# USE_FP16 arrives as a string, so a plain truthiness check would treat
# USE_FP16=0 or USE_FP16=false as "enabled". Unset/empty and the usual "off"
# spellings disable half precision; any other value enables it.
_FP16_OFF_VALUES = frozenset({"", "0", "false", "no", "off"})
use_fp16 = os.environ.get('USE_FP16', '').strip().lower() not in _FP16_OFF_VALUES

# map_location="cpu" lets a checkpoint that was saved from a GPU load on
# CPU-only hosts; the network is moved to `device` right after loading.
state_dict = torch.load("models/model.pt", map_location="cpu")
network.load_state_dict(state_dict, strict=True)

network = network.to(device)
network.eval()

# Half precision is only applied when running on CUDA.
if use_fp16 and torch.cuda.is_available():
    network = network.half()

# RGB color (0-255 per channel) used to draw each predicted type label.
label2color = {
    0: (0, 0, 0),       # BLACK
    1: (225, 24, 69),   # RED
    2: (135, 233, 17),  # GREEN
    3: (0, 87, 233),    # BLUE
    4: (242, 202, 25),  # YELLOW
    5: (137, 49, 239),  # PURPLE
}

# Sample images offered as clickable examples in the UI. Path.glob yields
# entries in filesystem order, so sort them to keep the gallery order stable.
example_files = sorted(Path("sample_data").glob("*.png"))

def visualize_instance_seg_mask(mask, palette=None):
    """Colorize a 2-D label mask into an RGB image scaled to [0, 1].

    Args:
        mask: 2-D array-like of labels (anything ``np.asarray`` accepts).
        palette: Optional mapping from label to an ``(R, G, B)`` tuple on a
            0-255 scale. Defaults to the module-level ``label2color``.

    Returns:
        A float64 array of shape ``(H, W, 3)`` holding each pixel's palette
        color divided by 255.

    Raises:
        KeyError: If ``mask`` contains a label that is missing from ``palette``.
    """
    if palette is None:
        palette = label2color

    mask = np.asarray(mask)
    image = np.zeros((mask.shape[0], mask.shape[1], 3))
    # One vectorized assignment per distinct label instead of a Python-level
    # loop over every pixel; palette[label] raises KeyError for unknown labels.
    for label in np.unique(mask):
        image[mask == label] = palette[label]
    return image / 255

def query_image(img):
    """Segment one image and return it blended with its colorized type map.

    Args:
        img: Value from the Gradio ``Image`` input (configured in this file
            with ``type="filepath"``); it is handed to ``preprocess`` under
            the ``"image"`` key.

    Returns:
        A NumPy array that averages the preprocessed image (channels moved
        last) with ``visualize_instance_seg_mask`` of the postprocessed
        ``"type_map"``, with the height and width axes swapped (see the
        orientation note at the end).
    """
    sample = preprocess({"image": img})

    tensor = sample['image'].to(device)
    if use_fp16 and torch.cuda.is_available():
        # Match the half-precision network set up at module level.
        tensor = tensor.half()
    sample['image'] = tensor

    with torch.no_grad():
        raw_pred = inference(sample['image'].unsqueeze(dim=0), network)

    # Strip the batch axis added by unsqueeze from every prediction entry,
    # updating the returned mapping in place.
    for head in raw_pred:
        raw_pred[head] = raw_pred[head].squeeze(dim=0)
    sample["pred"] = raw_pred

    sample = postprocess(sample)

    colored_types = visualize_instance_seg_mask(sample["type_map"].squeeze())
    background = sample["image"].permute(1, 2, 0).cpu().numpy()
    overlay = background * 0.5 + colored_types * 0.5

    # fliplr followed by a 90-degree counter-clockwise rotation amounts to
    # swapping the first two (row/column) axes. Presumably this compensates
    # for an axis swap introduced during preprocessing -- TODO confirm.
    return np.rot90(np.fliplr(overlay), k=1)

# Page description shown by the interface, read from the bundled HTML file.
with open('index.html', encoding='utf-8') as html_file:
    html_content = html_file.read()

# Default Gradio theme recolored with the Kalbe brand hue.
_kalbe_theme = gr.themes.Default(
    primary_hue=page_utils.KALBE_THEME_COLOR,
    secondary_hue=page_utils.KALBE_THEME_COLOR,
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_500",
    button_primary_text_color="white",
)

demo = gr.Interface(
    query_image,
    inputs=[gr.Image(type="filepath")],
    outputs="image",
    theme=_kalbe_theme,
    description=html_content,
    examples=example_files,
)

# Enable request queuing (at most 10 pending requests) and start the server.
demo.queue(max_size=10).launch()