DiverseSemanticImageEditing / app_inference.py
hakansivuk's picture
Final commit
087921f
raw
history blame contribute delete
No virus
3.86 kB
from PIL import Image
import numpy as np
from utils import AppUtils
import cv2
from inference import start_inference
class AppInference:
    """Prepare user edits (mask + target label) on disk and run the model.

    The two public entry points are :meth:`inference`, which writes the edited
    maps under the ``test_processed`` copy of the dataset and then calls
    ``start_inference()``, and :meth:`preview`, which applies the same edit in
    memory only and returns a colourised instance map.

    NOTE(review): every derived path is built with plain ``str.replace`` on
    the substrings ``"images"``, ``"jpg"`` and the dataset name, so *all*
    occurrences in ``img_path`` are rewritten, not only the intended path
    component. This is kept as-is because other parts of the project may
    depend on the resulting names -- confirm before tightening.
    """

    def __init__(self):
        # Instance id -> RGB colour (length-3 integer array). It is filled
        # lazily by _color_array and kept for the lifetime of this object.
        self.COLOR_MAP = {}

    def inference(self, input_id, img_path, mask, label):
        """Preprocess the inputs, then return the result of ``start_inference()``.

        Args:
            input_id: Identifier forwarded to ``AppUtils.get_inst_id``.
            img_path: Path of the source image; the label and instance-map
                paths are derived from it (see ``_read_files``).
            mask: Array-like with at least three dimensions; only channel 0
                is used (see ``_save_mask``).
            label: Target label name; the string ``"None"`` means "no edit".
        """
        AppUtils.clear()
        self._input_id = input_id
        self._handle_preprocess(img_path, mask, label)
        return self._handle_model_inference()

    def _handle_preprocess(self, img_path, mask, label):
        """Load the files, write the mask and, if a label is given, save the edited maps."""
        items = self._read_files(img_path)
        self._items = items
        mask = self._save_mask(items, mask)
        if label != "None":
            self._edit_maps(items, mask, label, save=True)

    def _handle_model_inference(self):
        """Delegate to the project's ``start_inference`` and return its result."""
        return start_inference()

    def preview(self, input_id, img_path, mask, label):
        """Apply the edit in memory and return the colourised instance map.

        Takes the same arguments as :meth:`inference`. The mask is still
        written to disk (``_edit_maps`` reads it back from there), but the
        edited label and instance maps are not saved.

        Returns:
            A PIL image built from ``items["inst_map"]``.
        """
        AppUtils.clear()
        self._input_id = input_id
        items = self._read_files(img_path)
        mask = self._save_mask(items, mask)
        if label != "None":
            self._edit_maps(items, mask, label)
        return self.generate_colored_image(items["inst_map"])

    def generate_colored_image(self, semantic_map):
        """Return a PIL RGB image in which every distinct id has its own colour.

        Args:
            semantic_map: 2-D integer array of ids, or a 3-D array whose
                channel 1 holds the ids.
        """
        return Image.fromarray(self._color_array(semantic_map))

    def _color_array(self, semantic_map):
        """Build the ``(H, W, 3)`` uint8 colour array for ``semantic_map``.

        Colours come from ``self.COLOR_MAP``. Ids not seen before receive a
        fresh random colour, drawn in row-major order of first appearance
        after re-seeding NumPy's global RNG with 256 (which is a side effect
        on the global random state).
        """
        np.random.seed(256)
        if len(semantic_map.shape) == 3:
            semantic_map = semantic_map[:, :, 1]

        # One palette row per distinct id. `first_seen` holds the flat
        # (row-major) index of each id's first occurrence, so visiting ids in
        # that order consumes the random stream exactly as a pixel-by-pixel
        # scan would.
        ids, first_seen, inverse = np.unique(
            semantic_map, return_index=True, return_inverse=True
        )
        palette = np.zeros((len(ids), 3), dtype=np.uint8)
        for idx in np.argsort(first_seen, kind="stable"):
            inst_id = ids[idx]
            if self.COLOR_MAP.get(inst_id, None) is None:
                self.COLOR_MAP[inst_id] = np.random.randint(256, size=(3,))
            palette[idx] = self.COLOR_MAP[inst_id]

        # The shape of `inverse` differs between NumPy versions; normalise it
        # to the map's shape before using it as a palette index.
        return palette[inverse.reshape(semantic_map.shape)]

    def _read_files(self, img_path):
        """Copy the image, label and instance map to ``test_processed`` and load them.

        The dataset name is taken as the third ``/``-separated component of
        ``img_path`` (index 2).

        Returns:
            A dict with keys ``img_path``, ``label_path``, ``inst_map_path``,
            ``dataset``, ``img`` (RGB array), ``label`` (grayscale array as
            returned by ``cv2.imread``, possibly ``None`` if unreadable) and
            ``inst_map`` (int32 array).

        Raises:
            OSError: If ``cv2.imread`` cannot read ``img_path``.
        """
        dataset = img_path.split("/")[2]
        items = {
            "img_path": img_path,
            "label_path": img_path.replace("images", "labels").replace("jpg", "png"),
            "inst_map_path": img_path.replace("images", "inst_map").replace("jpg", "png"),
        }
        for file_path in items.values():
            AppUtils.copy_file(file_path, file_path.replace(dataset, "test_processed"))
        items["dataset"] = dataset

        base_img = cv2.imread(img_path)
        if base_img is None:
            # cv2.imread signals failure by returning None; fail here with the
            # path instead of an opaque error from cvtColor.
            raise OSError(f"cv2.imread could not read image: {img_path}")
        base_img = cv2.cvtColor(base_img, cv2.COLOR_BGR2RGB)
        base_lab = cv2.imread(items["label_path"], 0)
        base_inst_map = np.array(Image.open(items["inst_map_path"]), dtype=np.int32)

        items.update(
            {
                "img": base_img,
                "label": base_lab,
                "inst_map": base_inst_map,
            }
        )
        return items

    def _mask_path(self, items):
        """Return where the user mask for ``items["img_path"]`` is stored."""
        return (
            items["img_path"]
            .replace(items["dataset"], "test_processed")
            .replace("images", "predefined_masks/type_0")
            .replace("jpg", "png")
        )

    def _save_mask(self, items, mask):
        """Write channel 0 of ``mask`` (scaled by 255) to disk.

        Returns:
            Channel 0 of ``mask`` as a uint8 array (unscaled).
        """
        mask = np.array(mask)[:, :, 0].astype(np.float32)
        # NOTE(review): the scaling suggests mask values are expected to be
        # 0/1 -- confirm against the caller.
        cv2.imwrite(self._mask_path(items), mask * 255)
        return mask.astype(np.uint8)

    def _edit_maps(self, items, mask, label, save=False):
        """Paint the masked region of the label and instance maps with ``label``.

        ``items["inst_map"]`` and ``items["label"]`` are modified in place.

        Args:
            items: Dict produced by ``_read_files``.
            mask: Unused; the mask is re-read from the file written by
                ``_save_mask``. Kept for interface compatibility.
            label: Label name forwarded to ``AppUtils.get_inst_id``.
            save: When true, also write both edited maps to their
                ``test_processed`` paths.

        Raises:
            OSError: If the saved mask cannot be read back, or if the label
                map was not loaded by ``_read_files``.
        """
        mask_path = self._mask_path(items)
        saved_mask = cv2.imread(mask_path, 0)
        if saved_mask is None:
            raise OSError(f"cv2.imread could not read mask: {mask_path}")
        if items["label"] is None:
            raise OSError(f"label map was not loaded: {items['label_path']}")

        # Only pixels stored as exactly 255 are treated as selected.
        target_pixels = (saved_mask / 255) == 1
        target_inst_id = AppUtils.get_inst_id(self._input_id, label)
        items["inst_map"][target_pixels] = target_inst_id
        # NOTE(review): the label id is the instance id modulo 120; the
        # meaning of 120 is not visible here -- confirm with AppUtils.
        items["label"][target_pixels] = target_inst_id % 120

        if save:
            dataset = items["dataset"]
            # The PIL image is only needed for writing, so build it here
            # rather than on every preview call.
            im = Image.fromarray(items["inst_map"]).convert("I")
            im.save(items["inst_map_path"].replace(dataset, "test_processed"))
            cv2.imwrite(items["label_path"].replace(dataset, "test_processed"), items["label"])