import math
import os
import random
import re
from typing import List

import cv2
import numpy as np
import scipy.ndimage
import torch
import torch.nn.functional as F
import yaml
from torchvision import transforms as T

from bilateral_solver import bilateral_solver_output

# PyYAML's 1.1 resolver does not recognize floats written in scientific
# notation without a decimal point (e.g. "1e-3" would load as a string);
# extend the SafeLoader resolver so such values parse as floats.
loader = yaml.SafeLoader
loader.add_implicit_resolver(
    u'tag:yaml.org,2002:float',
    re.compile(u'''^(?:
     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
    |[-+]?\\.(?:inf|Inf|INF)
    |\\.(?:nan|NaN|NAN))$''', re.X),
    list(u'-+0123456789.'))
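
# A quick check (sketch): without the resolver above, PyYAML 1.1 parses
# "1e-3" as a string; with it, the value loads as a float.
#   assert isinstance(yaml.load("x: 1e-3", Loader=loader)["x"], float)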


class Struct:
    """Wrap a dict so its entries are accessible as attributes."""

    def __init__(self, **entries):
        self.__dict__.update(entries)


def load_config(config_file):
    """Load a YAML config file and return it as an attribute-style Struct."""
    with open(config_file, errors='ignore') as f:
        conf = yaml.load(f, Loader=loader)
    print('hyperparameters: ' + ', '.join(f'{k}={v}' for k, v in conf.items()))
    return Struct(**conf)
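
# Example usage (sketch; the path and the "lr" entry are hypothetical):
#   conf = load_config("configs/example.yaml")
#   conf.lr * 2  # entries are attributes, and "lr: 1e-3" arrives as a float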


def set_seed(seed: int) -> None:
    """
    Set all seeds to make results reproducible.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def IoU(mask1, mask2):
    """
    Mean intersection-over-union between two batches of binary masks.
    Code adapted from TokenCut: https://github.com/YangtaoWANG95/TokenCut
    """
    mask1, mask2 = (mask1 > 0.5).to(torch.bool), (mask2 > 0.5).to(torch.bool)
    # mask1 * (mask1 == mask2) is True exactly where both masks are True.
    intersection = torch.sum(mask1 * (mask1 == mask2), dim=[-1, -2]).squeeze()
    # For bool tensors, mask1 + mask2 acts as an elementwise logical OR.
    union = torch.sum(mask1 + mask2, dim=[-1, -2]).squeeze()
    return (intersection.to(torch.float) / union).mean().item()
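
# Sanity check (sketch): identical masks give an IoU of 1.0.
#   m = torch.tensor([[1., 0.], [0., 1.]])
#   IoU(m, m)  # -> 1.0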


def batch_apply_bilateral_solver(data,
                                 masks,
                                 get_all_cc=True,
                                 shape=None):
    """
    Apply the bilateral solver to every mask in a batch. Returns the stacked
    refined masks and the number of masks for which the solver output was kept.
    """
    cnt_bs = 0
    masks_bs = []
    inputs, init_imgs, gt_labels, img_path = data

    for i in range(inputs.shape[0]):
        _, bs_mask, use_bs = apply_bilateral_solver(
            mask=masks[i].squeeze().cpu().numpy(),
            img=init_imgs[i],
            img_path=img_path[i],
            im_fullsize=False,
            shape=(gt_labels.shape[-1], gt_labels.shape[-2]),
            get_all_cc=get_all_cc,
        )
        cnt_bs += use_bs

        if use_bs:
            if shape is None:
                shape = masks.shape[-2:]
            # Resize the solver output back to the working mask resolution.
            bs_ds = F.interpolate(
                torch.Tensor(bs_mask).unsqueeze(0).unsqueeze(0),
                shape,
                mode="bilinear",
                align_corners=False,
            )
            masks_bs.append(bs_ds.bool().cuda().squeeze()[None, :, :])
        else:
            # Solver output was rejected; keep the original mask.
            masks_bs.append(masks[i].cuda().squeeze()[None, :, :])

    return torch.cat(masks_bs).squeeze(), cnt_bs


def apply_bilateral_solver(
    mask,
    img,
    img_path,
    shape,
    im_fullsize=False,
    get_all_cc=False,
    bs_iou_threshold: float = 0.5,
    reshape: bool = True,
):
    """
    Refine a soft mask with the bilateral solver. The refined mask is kept
    only if its IoU with the input mask exceeds `bs_iou_threshold`;
    otherwise the (resized) input mask is returned unchanged.
    """
    img_init = None
    if not im_fullsize:
        # Work at the image resolution; img is CHW, so this is (W, H),
        # the order cv2.resize expects.
        shape = (img.shape[-1], img.shape[-2])
        t = T.ToPILImage()
        img_init = t(img)

    if reshape:
        resized_mask = cv2.resize(mask, shape)
        sel_obj_mask = resized_mask
    else:
        resized_mask = mask
        sel_obj_mask = mask

    _, binary_solver = bilateral_solver_output(
        img_path,
        resized_mask,
        img=img_init,
        sigma_spatial=16,
        sigma_luma=16,
        sigma_chroma=8,
        get_all_cc=get_all_cc,
    )

    mask1 = torch.from_numpy(resized_mask).cuda()
    mask2 = torch.from_numpy(binary_solver).cuda().float()

    use_bs = 0
    if IoU(mask1, mask2) > bs_iou_threshold:
        sel_obj_mask = binary_solver.astype(float)
        use_bs = 1

    return resized_mask, sel_obj_mask, use_bs


def get_bbox_from_segmentation_labels(
    segmenter_predictions: torch.Tensor,
    initial_image_size: torch.Size,
    scales: List[int],
) -> np.ndarray:
    """
    Find the largest connected component in the foreground and return its
    bounding box as [xmin, ymin, xmax, ymax] in initial-image coordinates.
    """
    objects, num_objects = scipy.ndimage.label(segmenter_predictions)

    # Pick the connected component with the most foreground pixels.
    all_foreground_labels = objects.flatten()[objects.flatten() != 0]
    most_frequent_label = np.bincount(all_foreground_labels).argmax()
    mask = np.where(objects == most_frequent_label)

    ymin, ymax = min(mask[0]), max(mask[0]) + 1
    xmin, xmax = min(mask[1]), max(mask[1]) + 1

    if initial_image_size == segmenter_predictions.shape:
        pred = [xmin, ymin, xmax, ymax]
    else:
        # Scale the box from prediction-map coordinates back to image space.
        r_xmin, r_xmax = scales[1] * xmin, scales[1] * xmax
        r_ymin, r_ymax = scales[0] * ymin, scales[0] * ymax
        pred = [r_xmin, r_ymin, r_xmax, r_ymax]

    # Clip the box to the image boundaries.
    if initial_image_size:
        pred[2] = min(pred[2], initial_image_size[1])
        pred[3] = min(pred[3], initial_image_size[0])

    return np.asarray(pred)
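
# Example (sketch): a 2x2 blob at the top-left of an 8x8 prediction map,
# with the initial image 16x larger in each dimension:
#   preds = torch.zeros(8, 8)
#   preds[:2, :2] = 1
#   get_bbox_from_segmentation_labels(preds, torch.Size([128, 128]), [16, 16])
#   # -> array([ 0,  0, 32, 32])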


def bbox_iou(
    box1: torch.Tensor,
    box2: torch.Tensor,
    x1y1x2y2: bool = True,
    GIoU: bool = False,
    DIoU: bool = False,
    CIoU: bool = False,
    eps: float = 1e-7,
):
    """
    IoU between box1 and box2 (optionally the GIoU, DIoU, or CIoU variant).
    Boxes are (x1, y1, x2, y2) corners when `x1y1x2y2` is True, otherwise
    (cx, cy, w, h) centers and sizes.
    """
    box2 = box2.T

    # Get the corner coordinates of both boxes.
    if x1y1x2y2:
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # convert from (cx, cy, w, h) to corners
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area.
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * (
        torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)
    ).clamp(0)

    # Union area.
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        # Width and height of the smallest box enclosing both boxes.
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)
        if CIoU or DIoU:
            c2 = cw**2 + ch**2 + eps  # squared diagonal of the enclosing box
            rho2 = (
                (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
                + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
            ) / 4  # squared distance between box centers
            if DIoU:
                return iou - rho2 / c2
            elif CIoU:
                # Aspect-ratio consistency term.
                v = (4 / math.pi**2) * torch.pow(
                    torch.atan(w2 / h2) - torch.atan(w1 / h1), 2
                )
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)
        else:  # GIoU
            c_area = cw * ch + eps  # area of the enclosing box
            return iou - (c_area - union) / c_area
    else:
        return iou
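
# Example (sketch): two 2x2 squares overlapping in a single unit cell.
#   b1 = torch.tensor([0., 0., 2., 2.])
#   b2 = torch.tensor([1., 1., 3., 3.])
#   bbox_iou(b1, b2)  # ~= 1/7 (intersection 1, union 7)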