htr_demo / src /htr_pipeline /utils /process_segmask.py
Gabriel's picture
bad merge quick fix..
5ebeb73
raw
history blame
3.23 kB
import cv2
import numpy as np
import torch
from mmdet.registry import VISUALIZERS
class SegMaskHelper:
    """Post-processing helpers for instance-segmentation results.

    The methods operate on an mmdet-style ``result`` object whose
    ``pred_instances.masks`` is an iterable of per-instance mask tensors.
    """

    def align_masks_with_image(self, result, img):
        """Resize every predicted mask to the resolution of ``img``.

        Original author's note: the predicted masks did not match the image
        size (suspected RTMDet config issue), so each mask is resized with
        nearest-neighbour interpolation to ``img.shape[:2]``.

        Args:
            result: Detection result; ``result.pred_instances.masks`` is read
                and then replaced (the object is modified in place).
            img: Image array; only its height and width are used.

        Returns:
            The same ``result`` object, whose ``pred_instances.masks`` is now a
            CPU ``torch.uint8`` tensor of shape ``(num_masks, H, W)``. With no
            masks the tensor has shape ``(0, H, W)`` instead of raising.
        """
        height, width = img.shape[:2]
        resized_masks = [
            torch.from_numpy(
                cv2.resize(
                    mask.cpu().numpy().astype(np.uint8),
                    (width, height),  # cv2 dsize is (width, height)
                    # nearest-neighbour never invents in-between mask values
                    interpolation=cv2.INTER_NEAREST,
                )
            )
            for mask in result.pred_instances.masks
        ]
        if resized_masks:
            result.pred_instances.masks = torch.stack(resized_masks)
        else:
            # torch.stack raises on an empty list (no detections); keep an
            # empty but correctly shaped tensor so callers can still iterate.
            result.pred_instances.masks = torch.zeros(
                (0, height, width), dtype=torch.uint8
            )
        return result

    def crop_masks(self, result, img):
        """Crop each masked region of ``img`` and place it on a white background.

        For every mask the largest external contour is found and simplified
        into a polygon (tolerance: 0.3% of the contour perimeter). ``img`` is
        cropped to that contour's bounding rectangle, and every pixel of the
        crop where the mask is zero is set to 255.

        Args:
            result: Detection result whose ``pred_instances.masks`` are
                expected to share ``img``'s height and width (e.g. after
                ``align_masks_with_image``).
            img: Image array to crop.

        Returns:
            ``(cropped_imgs, polygons)`` -- one crop (same dtype as ``img``)
            and one polygon (list of ``[x, y]`` points in mask pixel
            coordinates) per mask, in mask order.

        Raises:
            ValueError: If a mask contains no foreground pixels.
        """
        cropped_imgs = []
        polygons = []
        for idx, mask in enumerate(result.pred_instances.masks):
            mask_u8 = mask.cpu().numpy().astype(np.uint8)
            contours, _ = cv2.findContours(
                mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
            )
            if not contours:
                # Same exception type the bare max() below would raise, but
                # with a message that identifies the offending mask.
                raise ValueError(f"mask {idx} contains no foreground pixels")
            # Only the largest external contour defines the polygon and crop.
            largest_contour = max(contours, key=cv2.contourArea)
            epsilon = 0.003 * cv2.arcLength(largest_contour, True)
            approx_poly = cv2.approxPolyDP(largest_contour, epsilon, True)
            # reshape (not squeeze) keeps a single-vertex polygon as [[x, y]]
            # rather than collapsing it into a flat [x, y].
            polygons.append(approx_poly.reshape(-1, 2).tolist())

            x, y, w, h = cv2.boundingRect(largest_contour)
            region = img[y : y + h, x : x + w]
            inside = mask_u8[y : y + h, x : x + w] != 0
            # White canvas; copy image pixels only where the mask is set.
            crop = np.full_like(region, 255)
            crop[inside] = region[inside]
            cropped_imgs.append(crop)
        return cropped_imgs, polygons

    def visualize_result(self, result, img, model_visualizer):
        """Draw the predictions in ``result`` on ``img`` and return the drawing.

        Args:
            result: Data sample to draw; ground truth is not drawn.
            img: Image to draw on.
            model_visualizer: Config from which a visualizer is built through
                mmdet's ``VISUALIZERS`` registry.

        Returns:
            The image produced by the visualizer's ``get_image()``.
        """
        visualizer = VISUALIZERS.build(model_visualizer)
        visualizer.add_datasample("result", img, data_sample=result, draw_gt=False)
        return visualizer.get_image()

    def _translate_line_coords(self, region_mask, line_polygons):
        """Offset polygons by the top-left corner of ``region_mask``'s foreground.

        Args:
            region_mask: Mask tensor; pixels > 0 count as foreground.
            line_polygons: Iterable of polygons, each an iterable of ``[x, y]``
                points.

        Returns:
            A new list of polygons in which every point is shifted by the
            (x, y) origin of the bounding rectangle of the mask's non-zero
            pixels.
        """
        # NOTE(review): crop_masks crops at the bounding box of the *largest*
        # contour, whereas this offset uses the bounding box of *all* non-zero
        # pixels. They differ when a mask has several disconnected parts --
        # confirm which frame the callers' polygons are expressed in.
        foreground = (region_mask.cpu().numpy() > 0).astype(np.uint8)
        offset_x, offset_y, _, _ = cv2.boundingRect(foreground)
        return [
            [[px + offset_x, py + offset_y] for px, py in polygon]
            for polygon in line_polygons
        ]