import cv2
import numpy as np
import torch
import torch.nn as nn
from pytorch3d.structures import Pointclouds


def set_requires_grad(module: nn.Module, requires_grad: bool):
    """Enable or disable gradient tracking for every parameter of ``module``."""
    for p in module.parameters():
        p.requires_grad_(requires_grad)


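def compute_distance_transform(mask: torch.Tensor):
    """
    Compute a normalized distance transform for each mask channel.

    Parameters
    ----------
    mask : torch.Tensor
        (B, 1, H, W) or (B, 2, H, W) tensor, true (nonzero) for foreground.

    Returns
    -------
    torch.Tensor
        (B, C, H, W) tensor on ``mask.device`` holding, for every pixel, the
        Euclidean distance to the closest foreground pixel (zero inside the
        mask), divided by ``W / 2`` and clipped to [0, 1].

    """
    C = mask.shape[1]
    assert C in [1, 2], f'invalid mask shape {mask.shape} found!'

    # Distances are normalized by half the image width (square images assumed).
    image_size = mask.shape[-1]

    dts = []
    for i in range(C):
        # cv2.distanceTransform measures the distance from every nonzero pixel
        # to the nearest zero pixel, so the mask is inverted (1 - m): background
        # pixels get their distance to the foreground, foreground pixels get 0.
        # DIST_MASK_3 uses OpenCV's fast 3x3 approximation of the L2 distance.
        distance_transform = torch.stack([
            torch.from_numpy(cv2.distanceTransform(
                (1 - m), distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_3
            ) / (image_size / 2))
            for m in mask[:, i:i+1].squeeze(1).detach().cpu().numpy().astype(np.uint8)
        ]).unsqueeze(1).clip(0, 1).to(mask.device)
        dts.append(distance_transform)
    return torch.cat(dts, 1)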
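

# Reference sketch (hypothetical helper, not part of the original module):
# a brute-force, exact-Euclidean version of the per-image computation done in
# compute_distance_transform, intended for checking it on tiny masks. It
# assumes a single (H, W) mask with nonzero = foreground and mirrors the same
# normalization (divide by half the width, clip to [0, 1]). OpenCV's
# DIST_MASK_3 approximates the L2 distance, so cv2 results may differ slightly.
def brute_force_distance_transform(mask, image_size=None):
    import numpy as np  # local import keeps the sketch self-contained

    mask = np.asarray(mask)
    image_size = mask.shape[-1] if image_size is None else image_size
    fg = np.argwhere(mask != 0)  # (F, 2) row/col coordinates of foreground pixels
    if len(fg) == 0:
        # Convention of this sketch: with no foreground, every pixel is treated
        # as maximally far, i.e. the clipped value 1.
        return np.ones(mask.shape, dtype=np.float64)
    rows, cols = np.indices(mask.shape)
    pixels = np.stack([rows, cols], axis=-1).reshape(-1, 1, 2)  # (H*W, 1, 2)
    # Distance from every pixel to every foreground pixel, then keep the nearest.
    dist = np.linalg.norm(pixels - fg[None], axis=-1).min(axis=1)
    return np.clip(dist.reshape(mask.shape) / (image_size / 2), 0, 1)

# Example: a 5x5 mask with only the centre pixel set gives 0 at (2, 2) and
# 2 / 2.5 = 0.8 at (2, 4); the corner (0, 0) is clipped to 1.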


def default(x, d):
    """Return ``x``, or the fallback ``d`` when ``x`` is None."""
    return d if x is None else x


def get_num_points(x: Pointclouds, /):
    """Return the padded number of points per cloud, i.e. the largest cloud size in the batch."""
    return x.points_padded().shape[1]


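def get_custom_betas(beta_start: float, beta_end: float, warmup_frac: float = 0.3, num_train_timesteps: int = 1000):
    """
    Custom beta schedule with a linear warm-up.

    Builds a linear schedule from ``beta_start`` to ``beta_end`` over
    ``num_train_timesteps`` steps, then overwrites its first
    ``int(num_train_timesteps * warmup_frac)`` entries with a steeper linear
    ramp over the same range. The remaining entries keep their values from the
    base linear schedule. Returns a float32 array of length ``num_train_timesteps``.
    """
    betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
    # Length of the warm-up ramp, capped at the schedule length.
    warmup_time = min(int(num_train_timesteps * warmup_frac), num_train_timesteps)
    warmup_steps = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
    betas[:warmup_time] = warmup_steps  # cast to float32 on assignment
    return betas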
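

# Sketch (hypothetical helper, not part of the original module): diffusion
# models typically consume a beta schedule such as the one returned by
# get_custom_betas through the cumulative products
#     alpha_t = 1 - beta_t,    alpha_bar_t = prod_{s <= t} alpha_s,
# where alpha_bar_t is the fraction of signal variance kept at step t of the
# forward process. This assumes the standard DDPM parameterisation.
def cumulative_alphas(betas):
    import numpy as np  # local import keeps the sketch self-contained

    betas = np.asarray(betas, dtype=np.float64)
    if np.any((betas <= 0) | (betas >= 1)):
        raise ValueError('betas must lie in the open interval (0, 1)')
    return np.cumprod(1.0 - betas)

# Example: cumulative_alphas([0.1, 0.5]) gives [0.9, 0.45].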