# HDM-interaction-recon / model/model_utils.py
# Author: xiexh20 (commit 2fd6166, "add hdm demo v1")
import cv2
import numpy as np
import torch
import torch.nn as nn
from pytorch3d.structures import Pointclouds
def set_requires_grad(module: nn.Module, requires_grad: bool):
    """Enable or disable gradient tracking for all parameters of ``module``.

    Parameters
    ----------
    module : module whose ``parameters()`` are updated in place
    requires_grad : value assigned to each parameter's ``requires_grad`` flag
    """
    params = module.parameters()
    for param in params:
        # In-place toggle; no new tensors are created.
        param.requires_grad_(requires_grad)
def compute_distance_transform(mask: torch.Tensor):
    """
    Compute, per mask channel, a normalised distance-to-foreground map.

    Parameters
    ----------
    mask (B, 1, H, W) or (B, 2, H, W), non-zero (true) for foreground.
        Values are converted to uint8 first, so fractional values in (0, 1)
        truncate to 0 and count as background.

    Returns
    -------
    (B, C, H, W) tensor on ``mask.device``. Each pixel holds the Euclidean
    distance (cv2 DIST_L2, 3x3 mask) to the closest foreground pixel of the
    same channel. The distance is divided by W / 2 and clipped to [0, 1].
    It is zero inside the mask. This is a scalar distance, not an offset
    vector.
    """
    C = mask.shape[1]
    assert C in [1, 2], f'invalid mask shape {mask.shape} found!'
    # Normalisation uses the width only.
    image_size = mask.shape[-1]
    # cv2 needs host uint8 arrays; convert the whole batch once rather than per channel.
    masks_np = mask.detach().cpu().numpy().astype(np.uint8)
    dts = []
    for i in range(C):
        per_image = [
            # cv2.distanceTransform gives each non-zero pixel its distance to the
            # nearest zero pixel, so the input must be zero on the foreground.
            # `(m == 0)` is used instead of `1 - m`: uint8 subtraction wraps for
            # values > 1 (e.g. 1 - 255 == 2 for 0/255 masks), which would turn
            # the foreground into background.
            torch.from_numpy(cv2.distanceTransform(
                (m == 0).astype(np.uint8), distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_3
            ) / (image_size / 2))
            for m in masks_np[:, i]
        ]
        if not per_image:
            # Empty batch: torch.stack raises on an empty list.
            distance_transform = torch.zeros(
                (0, 1) + tuple(mask.shape[2:]), dtype=torch.float32, device=mask.device
            )
        else:
            distance_transform = torch.stack(per_image).unsqueeze(1).clip(0, 1).to(mask.device)
        dts.append(distance_transform)
    return torch.cat(dts, 1)
def default(x, d):
    """Return ``x`` unless it is ``None``; otherwise return the fallback ``d``.

    Only ``None`` triggers the fallback; other falsy values (0, "", False) are kept.
    """
    if x is not None:
        return x
    return d
def get_num_points(x: Pointclouds, /):
    """Return dimension 1 of ``x.points_padded()``.

    Presumably this is the padded per-cloud point count in pytorch3d's
    (N, P, 3) padded layout; confirm against the pytorch3d version in use.
    """
    padded = x.points_padded()
    return padded.size(1)
def get_custom_betas(beta_start: float, beta_end: float, warmup_frac: float = 0.3, num_train_timesteps: int = 1000):
    """Custom beta schedule: a linear schedule whose first part is a steeper warm-up ramp.

    The base schedule is ``linspace(beta_start, beta_end, num_train_timesteps)``.
    Its first ``int(num_train_timesteps * warmup_frac)`` entries are overwritten
    by a ramp from ``beta_start`` to ``beta_end`` over that warm-up length. The
    overwrite is capped at ``num_train_timesteps`` entries. As a result, the
    schedule reaches ``beta_end`` at the end of the warm-up and then drops back
    onto the base linear schedule for the remaining steps.

    Parameters
    ----------
    beta_start, beta_end : first and last beta values of both the warm-up ramp and the base ramp
    warmup_frac : fraction of timesteps covered by the warm-up ramp; must be >= 0
    num_train_timesteps : length of the returned schedule

    Returns
    -------
    np.ndarray of shape (num_train_timesteps,), dtype float32

    Raises
    ------
    ValueError : if ``warmup_frac`` is negative
    """
    if warmup_frac < 0:
        raise ValueError(f'warmup_frac must be non-negative, got {warmup_frac}')
    betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
    # warmup_frac is honoured here; it used to be overwritten with a hard-coded 0.3.
    warmup_time = int(num_train_timesteps * warmup_frac)
    # The ramp spans the full (uncapped) warm-up length and is then truncated.
    # With warmup_frac > 1, only the beginning of a longer ramp is used.
    warmup_steps = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
    warmup_time = min(warmup_time, num_train_timesteps)
    betas[:warmup_time] = warmup_steps[:warmup_time]  # cast to float32 on assignment
    return betas