File size: 3,856 Bytes
087921f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from PIL import Image
import numpy as np
from utils import AppUtils
import cv2
from inference import start_inference

class AppInference:
    """Stage one dataset sample, optionally relabel a user-masked region, and
    either run the model on it or render a colored preview of its instance map.

    NOTE(review): staged files are written under a "test_processed" directory
    that replaces the dataset component of each path; presumably this is what
    ``inference.start_inference`` reads — confirm.
    """

    # Path component that replaces the dataset name for staged copies.
    _PROCESSED_DIR = "test_processed"
    # Replacement for the "images" folder when deriving the saved-mask path.
    _MASK_DIR = "predefined_masks/type_0"
    # Edited pixels get semantic label ``inst_id % _NUM_LABELS``.
    # NOTE(review): presumably the number of semantic classes — confirm.
    _NUM_LABELS = 120
    # Fixed seed so preview colors are reproducible between runs.
    _COLOR_SEED = 256

    def __init__(self):
        # Instance id -> RGB color (length-3 int array). Persists across
        # preview calls so an id keeps its color once assigned.
        self.COLOR_MAP = {}

    def inference(self, input_id, img_path, mask, label):
        """Stage the sample, persist any relabel, and run the model.

        Args:
            input_id: Identifier passed to ``AppUtils.get_inst_id`` when relabeling.
            img_path: '/'-separated path of the source .jpg image.
            mask: Image-like object convertible with ``np.array`` to (H, W, C);
                channel 0 is used as the mask.
            label: Label name, or "None"/None for no relabel.

        Returns:
            Whatever ``start_inference()`` returns.
        """
        AppUtils.clear()
        self._input_id = input_id
        self._handle_preprocess(img_path, mask, label)
        return self._handle_model_inference()

    def _handle_preprocess(self, img_path, mask, label):
        """Stage files, write the mask, and save relabeled maps to disk."""
        self._items = self._prepare_items(img_path, mask, label, save=True)

    def _handle_model_inference(self):
        """Run the model on the staged files."""
        return start_inference()

    def preview(self, input_id, img_path, mask, label):
        """Like ``inference`` but without saving the edited maps or running the
        model. Returns a PIL RGB image of the (possibly edited) instance map."""
        AppUtils.clear()
        self._input_id = input_id
        items = self._prepare_items(img_path, mask, label, save=False)
        return self.generate_colored_image(items["inst_map"])

    def _prepare_items(self, img_path, mask, label, save):
        """Shared preprocessing for ``inference`` and ``preview``.

        Edits the in-memory maps when a label is given, and also writes them
        to disk when ``save`` is true.
        """
        items = self._read_files(img_path)
        mask = self._save_mask(items, mask)
        if self._has_label(label):
            self._edit_maps(items, mask, label, save=save)
        return items

    @staticmethod
    def _has_label(label):
        """True unless ``label`` is the "None" sentinel (string or real None)."""
        return label is not None and label != "None"

    def generate_colored_image(self, semantic_map):
        """Render an id map as a PIL RGB image, one random color per id.

        For a 3-D map, channel 1 is used as the id map.
        """
        return Image.fromarray(self._color_array(semantic_map))

    def _color_array(self, semantic_map):
        """Return an (H, W, 3) uint8 color image for ``semantic_map``.

        New ids get colors drawn from ``RandomState(_COLOR_SEED)`` in row-major
        first-appearance order. Ids already in ``COLOR_MAP`` keep their color.
        A local RNG is used so the global NumPy RNG is not reseeded.
        """
        if semantic_map.ndim == 3:
            semantic_map = semantic_map[:, :, 1]
        rng = np.random.RandomState(self._COLOR_SEED)
        ids, first_idx, inverse = np.unique(
            semantic_map, return_index=True, return_inverse=True
        )
        # Visit ids in the order they first appear when scanning rows, so
        # colors are assigned exactly as a pixel-by-pixel scan would assign them.
        for k in np.argsort(first_idx, kind="stable"):
            inst_id = ids[k]
            if self.COLOR_MAP.get(inst_id) is None:
                self.COLOR_MAP[inst_id] = rng.randint(256, size=(3,))
        palette = np.array(
            [self.COLOR_MAP[inst_id] for inst_id in ids], dtype=np.uint8
        ).reshape(-1, 3)
        # The inverse is flat in NumPy 1.x and input-shaped in 2.x; reshape
        # handles both.
        return palette[inverse.reshape(semantic_map.shape)]

    @staticmethod
    def _sibling_path(path, folder):
        """Swap every "images" for ``folder`` and the last "jpg" for "png"."""
        swapped = path.replace("images", folder)
        # Only the last occurrence, so directories containing "jpg" survive.
        return "png".join(swapped.rsplit("jpg", 1))

    @classmethod
    def _staged_path(cls, path, dataset):
        """Replace each '/'-separated component equal to ``dataset`` with the
        processed directory name. Substrings of other components are left alone."""
        return "/".join(
            cls._PROCESSED_DIR if part == dataset else part
            for part in path.split("/")
        )

    def _mask_path(self, items):
        """Path where ``_save_mask`` writes and ``_edit_maps`` reads the mask."""
        staged = self._staged_path(items["img_path"], items["dataset"])
        return self._sibling_path(staged, self._MASK_DIR)

    def _read_files(self, img_path):
        """Stage the image, label and instance-map files, then load them.

        Returns a dict with keys:
        - "img_path", "label_path", "inst_map_path": the source paths.
        - "dataset": the dataset path component.
        - "img": RGB array.
        - "label": grayscale array from cv2.
        - "inst_map": int32 array.

        Raises:
            FileNotFoundError: if the image or label cannot be read.
        """
        # NOTE(review): assumes paths shaped like "./<root>/<dataset>/images/..."
        # so that component 2 is the dataset name — confirm with callers.
        dataset = img_path.split("/")[2]
        items = {
            "img_path": img_path,
            "label_path": self._sibling_path(img_path, "labels"),
            "inst_map_path": self._sibling_path(img_path, "inst_map"),
        }
        for file_path in items.values():
            AppUtils.copy_file(file_path, self._staged_path(file_path, dataset))
        items["dataset"] = dataset
        base_img = cv2.imread(img_path)
        if base_img is None:
            raise FileNotFoundError(f"could not read image: {img_path}")
        base_img = cv2.cvtColor(base_img, cv2.COLOR_BGR2RGB)
        base_lab = cv2.imread(items["label_path"], 0)
        if base_lab is None:
            raise FileNotFoundError(f"could not read label: {items['label_path']}")
        base_inst_map = np.array(Image.open(items["inst_map_path"]), dtype=np.int32)
        items.update(
            {
                "img": base_img,
                "label": base_lab,
                "inst_map": base_inst_map,
            }
        )
        return items

    def _save_mask(self, items, mask):
        """Write channel 0 of ``mask`` (x255) as the predefined-mask PNG.

        Returns channel 0 as a uint8 array.
        """
        mask = np.array(mask)[:, :, 0].astype(np.float32)
        # float32 data is converted to 8-bit by cv2.imwrite when saving PNG.
        cv2.imwrite(self._mask_path(items), mask * 255)
        return mask.astype(np.uint8)

    def _edit_maps(self, items, mask, label, save=False):
        """Assign the label's instance id to masked pixels in the loaded maps.

        ``mask`` is unused. The mask is re-read from the file ``_save_mask``
        wrote, so the edit matches exactly what was saved. A pixel is targeted
        when its saved value is 255.

        Raises:
            FileNotFoundError: if the saved mask cannot be read back.
        """
        mask_path = self._mask_path(items)
        saved_mask = cv2.imread(mask_path, 0)
        if saved_mask is None:
            raise FileNotFoundError(f"could not read saved mask: {mask_path}")
        target_pixels = saved_mask == 255
        target_inst_id = AppUtils.get_inst_id(self._input_id, label)
        items["inst_map"][target_pixels] = target_inst_id
        items["label"][target_pixels] = target_inst_id % self._NUM_LABELS
        if save:
            dataset = items["dataset"]
            Image.fromarray(items["inst_map"]).convert("I").save(
                self._staged_path(items["inst_map_path"], dataset)
            )
            cv2.imwrite(self._staged_path(items["label_path"], dataset), items["label"])