File size: 837 Bytes
def3395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import numpy as np
import torch
import cv2


def dt(a):
    """Return OpenCV's Euclidean (L2) distance transform of ``a``.

    ``a`` is scaled by 255 and cast to ``uint8`` before being handed to
    ``cv2.distanceTransform``. Presumably ``a`` is a 2-D mask with values in
    [0, 1], so that set pixels become 255 -- confirm against callers.
    """
    # OpenCV wants an 8-bit single-channel image as input.
    mask_u8 = (a * 255).astype(np.uint8)
    # A mask size of 0 is passed through unchanged from the original call.
    return cv2.distanceTransform(mask_u8, cv2.DIST_L2, 0)


def trimap_transform(trimap, L=320):
    """Encode the first two trimap channels as Gaussian-of-distance maps.

    For each of channels 0 and 1 (indexed as ``trimap[:, :, k]``), the distance
    transform of ``1 - channel`` is computed with ``dt``, squared and negated,
    then passed through ``exp(. / (2 * sigma ** 2))`` for three values of
    ``sigma``: ``0.02 * L``, ``0.08 * L`` and ``0.16 * L``.

    Args:
        trimap: array indexable as ``trimap[:, :, k]`` for ``k`` in {0, 1};
            presumably an H x W x 2 mask with values in [0, 1] -- confirm
            against callers.
        L: reference length that scales the three sigmas.

    Returns:
        ``np.ndarray`` stacking the six maps along a new leading axis, ordered
        channel 0 (three sigmas, smallest first) followed by channel 1.
    """
    sigma_fractions = (0.02, 0.08, 0.16)
    encoded = []
    for channel in range(2):
        # Negated squared distance; unary minus binds looser than ``**``.
        neg_sq_dist = -dt(1 - trimap[:, :, channel]) ** 2
        for fraction in sigma_fractions:
            encoded.append(np.exp(neg_sq_dist / (2 * ((fraction * L) ** 2))))
    return np.array(encoded)


# Per-channel normalisation statistics, listed in RGB channel order.
# Both are float32 CPU tensors of shape (1, 3, 1, 1), so they broadcast against
# a 4-D tensor whose second axis holds the three colour channels.
imagenet_norm_std = torch.tensor(
    [0.229, 0.224, 0.225], dtype=torch.float32, device="cpu"
).reshape(1, 3, 1, 1)
imagenet_norm_mean = torch.tensor(
    [0.485, 0.456, 0.406], dtype=torch.float32, device="cpu"
).reshape(1, 3, 1, 1)


def normalise_image(image, mean=imagenet_norm_mean, std=imagenet_norm_std):
    """Standardise ``image`` as ``(image - mean) / std``.

    Args:
        image: value supporting ``-`` and ``/`` with ``mean`` and ``std``;
            presumably an (N, 3, H, W) RGB tensor on the same device as the
            statistics -- confirm against callers.
        mean: offset subtracted first; defaults to ``imagenet_norm_mean``.
        std: divisor applied second; defaults to ``imagenet_norm_std``.

    Returns:
        The result of the subtraction followed by the division.
    """
    centred = image - mean
    return centred / std