import os
from functools import lru_cache

import torch

try:
    import torchvision.transforms as transforms
except ImportError:  # only ``to_tensor_sample`` needs torchvision
    transforms = None


def _meshgrid_ij(ys, xs):
    """Return ``torch.meshgrid(ys, xs)`` with matrix ('ij') indexing.

    Passing ``indexing`` explicitly silences the deprecation warning raised by
    torch >= 1.10; older torch has no such keyword but already defaults to 'ij'.
    """
    try:
        return torch.meshgrid(ys, xs, indexing="ij")
    except TypeError:
        return torch.meshgrid(ys, xs)


@lru_cache(maxsize=None)
def meshgrid(B, H, W, dtype, device, normalized=False):
    """
    Create mesh-grid given batch size, height and width dimensions.
    From https://github.com/TRI-ML/KP2D.

    Results are memoized with ``lru_cache``: repeated calls with the same
    arguments return the *same* tensor objects, so callers must not modify
    them in place (``clone()`` first if needed). All arguments must be
    hashable.

    Parameters
    ----------
    B: int
        Batch size
    H: int
        Grid height
    W: int
        Grid width
    dtype: torch.dtype
        Tensor dtype
    device: str or torch.device
        Tensor device
    normalized: bool
        If True, coordinates are evenly spaced in [-1, 1]; otherwise they are
        integer pixel indices 0..W-1 (x) and 0..H-1 (y).

    Returns
    -------
    xs: torch.Tensor
        Batched mesh-grid x-coordinates (BHW).
    ys: torch.Tensor
        Batched mesh-grid y-coordinates (BHW).
    """
    if normalized:
        xs = torch.linspace(-1, 1, W, device=device, dtype=dtype)
        ys = torch.linspace(-1, 1, H, device=device, dtype=dtype)
    else:
        xs = torch.linspace(0, W - 1, W, device=device, dtype=dtype)
        ys = torch.linspace(0, H - 1, H, device=device, dtype=dtype)
    # 'ij' indexing: outputs are (H, W), rows vary with y and columns with x.
    ys, xs = _meshgrid_ij(ys, xs)
    return xs.repeat([B, 1, 1]), ys.repeat([B, 1, 1])


@lru_cache(maxsize=None)
def image_grid(B, H, W, dtype, device, ones=True, normalized=False):
    """
    Create an image mesh grid with shape B3HW given image shape BHW.
    From https://github.com/TRI-ML/KP2D.

    Results are memoized with ``lru_cache``: repeated calls with the same
    arguments return the *same* tensor, so callers must not modify it in
    place (``clone()`` first if needed).

    Parameters
    ----------
    B: int
        Batch size
    H: int
        Grid height
    W: int
        Grid width
    dtype: torch.dtype
        Tensor dtype
    device: str or torch.device
        Tensor device
    ones : bool
        Use (x, y, 1) coordinates; if False the grid has only (x, y)
        channels (B2HW).
    normalized: bool
        Normalized image coordinates or integer-grid.

    Returns
    -------
    grid: torch.Tensor
        Mesh-grid for the corresponding image shape (B3HW, or B2HW when
        ``ones`` is False)
    """
    xs, ys = meshgrid(B, H, W, dtype, device, normalized=normalized)
    coords = [xs, ys]
    if ones:
        coords.append(torch.ones_like(xs))  # BHW
    grid = torch.stack(coords, dim=1)  # B3HW
    return grid


def to_tensor_sample(sample, tensor_type='torch.FloatTensor'):
    """
    Casts the keys of sample to tensors.
    From https://github.com/TRI-ML/KP2D.

    Only ``sample['image']`` is converted, via
    ``torchvision.transforms.ToTensor`` followed by ``.type(tensor_type)``.
    The dict is modified in place and also returned.

    Parameters
    ----------
    sample : dict
        Input sample; must contain an ``'image'`` entry (presumably a PIL
        image or ndarray accepted by ``ToTensor`` -- confirm with callers).
    tensor_type : str
        Type of tensor we are casting to

    Returns
    -------
    sample : dict
        Sample with keys cast as tensors

    Raises
    ------
    ImportError
        If torchvision is not installed.
    """
    if transforms is None:
        raise ImportError("to_tensor_sample requires torchvision")
    transform = transforms.ToTensor()
    sample['image'] = transform(sample['image']).type(tensor_type)
    return sample


def prepare_dirs(config):
    """
    Create the directories the run needs if they do not exist yet.

    Parameters
    ----------
    config : object
        Configuration exposing a ``ckpt_dir`` attribute (checkpoint path).
    """
    for path in [config.ckpt_dir]:
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(path, exist_ok=True)