Spaces:
Sleeping
Sleeping
File size: 1,728 Bytes
2fd6166 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import cv2
import numpy as np
import torch
import torch.nn as nn
from pytorch3d.structures import Pointclouds
def set_requires_grad(module: nn.Module, requires_grad: bool):
    """Enable or disable gradient tracking on every parameter of ``module``.

    The flag is changed in place on each tensor yielded by
    ``module.parameters()``; nothing is returned.
    """
    for param in module.parameters():
        param.requires_grad = requires_grad
def compute_distance_transform(mask: torch.Tensor):
    """
    Compute a normalized per-pixel distance to the nearest foreground pixel.

    Parameters
    ----------
    mask (B, 1, H, W) or (B, 2, H, W) true (nonzero) for foreground. The mask
        is converted with ``astype(np.uint8)`` first, so fractional values
        below 1 count as background.

    Returns
    -------
    float32 tensor with the same shape as ``mask``, on ``mask.device``. Each
    entry is the (approximate, OpenCV ``DIST_L2`` with ``DIST_MASK_3``)
    distance from that pixel to the closest foreground pixel of the same image
    and channel, divided by ``W / 2`` and clipped to [0, 1]. It is zero inside
    the mask. The mask is detached and processed on the CPU, so no gradient
    flows through the result.
    """
    B, C = mask.shape[0], mask.shape[1]
    assert C in [1, 2], f'invalid mask shape {mask.shape} found!'
    # Normalization uses the last dimension (width) only.
    image_size = mask.shape[-1]

    # A single device->host transfer for all channels.
    masks_np = mask.detach().cpu().numpy().astype(np.uint8)
    # cv2.distanceTransform measures distance to the nearest ZERO pixel, so
    # foreground must be encoded as 0 and background as nonzero. Comparing
    # with 0 (instead of `1 - m`) avoids uint8 wraparound for 0/255 masks.
    inverted = (masks_np == 0).astype(np.uint8)

    # Preallocated output also makes an empty batch (B == 0) work.
    out = np.empty(inverted.shape, dtype=np.float32)
    for b in range(B):
        for c in range(C):
            out[b, c] = cv2.distanceTransform(
                inverted[b, c], distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_3
            )
    out /= image_size / 2
    return torch.from_numpy(out).clip(0, 1).to(mask.device)
def default(x, d):
    """Return ``x`` unless it is ``None``, in which case return ``d``.

    Only ``None`` triggers the fallback; other falsy values (0, "", []) are
    returned unchanged.
    """
    if x is not None:
        return x
    return d
def get_num_points(x: Pointclouds, /):
    """Return dimension 1 of ``x.points_padded()``.

    NOTE(review): presumably this is the padded (maximum) number of points per
    cloud, since pytorch3d pads every cloud to the longest one; confirm.
    ``x`` is positional-only.
    """
    padded_points = x.points_padded()
    return padded_points.shape[1]
def get_custom_betas(beta_start: float, beta_end: float, warmup_frac: float = 0.3, num_train_timesteps: int = 1000):
    """Custom beta schedule.

    The schedule starts as a linear ramp
    ``np.linspace(beta_start, beta_end, num_train_timesteps)``. Its first
    ``int(num_train_timesteps * warmup_frac)`` entries are then overwritten
    with a faster linear ramp from ``beta_start`` to ``beta_end`` that spans
    only that warmup window. After the warmup the values continue on the
    original full-length ramp.

    Parameters
    ----------
    beta_start, beta_end : first and last value of both ramps
    warmup_frac : fraction of timesteps covered by the warmup ramp. A value
        <= 0 disables warmup. A value > 1 builds a ramp longer than the
        schedule and keeps only its first ``num_train_timesteps`` entries.
    num_train_timesteps : length of the returned schedule

    Returns
    -------
    np.ndarray of shape (num_train_timesteps,), dtype float32
    """
    betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
    # Honor the caller's warmup_frac (it used to be overwritten with 0.3).
    # Clamp at 0 so negative fractions mean "no warmup" instead of making
    # np.linspace raise on a negative sample count.
    warmup_time = max(int(num_train_timesteps * warmup_frac), 0)
    # Built in float64 and cast to float32 on assignment.
    warmup_steps = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
    warmup_time = min(warmup_time, num_train_timesteps)
    betas[:warmup_time] = warmup_steps[:warmup_time]
    return betas
|