from PIL import Image import numpy as np from utils import AppUtils import cv2 from inference import start_inference class AppInference: def __init__(self): self.COLOR_MAP = {} def inference(self, input_id, img_path, mask, label): AppUtils.clear() self._input_id = input_id self._handle_preprocess(img_path, mask, label) return self._handle_model_inference() def _handle_preprocess(self, img_path, mask, label): items = self._read_files(img_path) self._items = items mask = self._save_mask(items, mask) if label != "None": self._edit_maps(items, mask, label, save=True) def _handle_model_inference(self): return start_inference() def preview(self, input_id, img_path, mask, label): AppUtils.clear() self._input_id = input_id items = self._read_files(img_path) mask = self._save_mask(items, mask) if label != "None": self._edit_maps(items, mask, label) return self.generate_colored_image(items["inst_map"]) def generate_colored_image(self, semantic_map): np.random.seed(256) if len(semantic_map.shape) == 3: semantic_map = semantic_map[:,:,1] color_image = np.zeros((semantic_map.shape[0], semantic_map.shape[1], 3), dtype=np.uint8) for row in range(semantic_map.shape[0]): for col in range(semantic_map.shape[1]): inst_id = semantic_map[row, col] if self.COLOR_MAP.get(inst_id, None) is None: self.COLOR_MAP[inst_id] = np.random.randint(256, size=(3,)) color_image[row, col, :] = self.COLOR_MAP[inst_id] return Image.fromarray(color_image) def _read_files(self, img_path): dataset = img_path.split("/")[2] items = { "img_path": img_path, "label_path": img_path.replace("images", "labels").replace("jpg", "png"), "inst_map_path": img_path.replace("images", "inst_map").replace("jpg", "png"), } for file_path in items.values(): AppUtils.copy_file(file_path, file_path.replace(dataset, "test_processed")) items["dataset"] = dataset base_img = cv2.imread(img_path) base_img = cv2.cvtColor(base_img, cv2.COLOR_BGR2RGB) base_lab = cv2.imread(items["label_path"], 0) base_inst_map = Image.open(items["inst_map_path"]) base_inst_map = np.array(base_inst_map, dtype=np.int32) items.update( { "img": base_img, "label": base_lab, "inst_map": base_inst_map, } ) return items def _save_mask(self, items, mask): mask = np.array(mask)[:,:,0] mask = mask.reshape((1,) + mask.shape).astype(np.float32) save_path = items["img_path"].replace(items["dataset"], "test_processed").replace("images", "predefined_masks/type_0").replace("jpg", "png") cv2.imwrite(save_path, mask[0]* 255) return mask[0].astype(np.uint8) def _edit_maps(self, items, mask, label, save=False): mask_path = items["img_path"].replace(items["dataset"], "test_processed").replace("images", "predefined_masks/type_0").replace("jpg", "png") mask = cv2.imread(mask_path, 0) / 255 target_pixels = mask == 1 target_inst_id = AppUtils.get_inst_id(self._input_id, label) items["inst_map"][target_pixels] = target_inst_id items["label"][target_pixels] = (target_inst_id % 120) im = Image.fromarray(items["inst_map"]).convert("I") if save: im.save(items["inst_map_path"].replace(items["dataset"], "test_processed")) cv2.imwrite(items["label_path"].replace(items["dataset"], "test_processed"), items["label"])