Dexter's picture
Upload folder using huggingface_hub
36c95ba verified
from typing import Optional
import torch
def create_meshgrid(
    height: int,
    width: int,
    normalized_coordinates: bool = True,
    device: Optional[torch.device] = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Generate a coordinate grid for an image.

    When the flag ``normalized_coordinates`` is set to True, the grid is
    normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch
    function :py:func:`torch.nn.functional.grid_sample`.

    Args:
        height: the image height (rows).
        width: the image width (cols).
        normalized_coordinates: whether to normalize
          coordinates in the range :math:`[-1,1]` in order to be consistent with the
          PyTorch function :py:func:`torch.nn.functional.grid_sample`.
        device: the device on which the grid will be generated.
        dtype: the data type of the generated grid.

    Return:
        grid tensor with shape :math:`(1, H, W, 2)`. The last dimension holds
        ``(x, y)``: ``x`` varies along the width axis and ``y`` along the
        height axis.

    Note:
        When normalizing, an axis of size 1 maps its single sample to ``-1``
        (the value :py:func:`torch.linspace` returns for ``(-1, 1, steps=1)``)
        instead of producing ``NaN`` from a ``0 / 0`` division.

    Example:
        >>> create_meshgrid(2, 2)
        tensor([[[[-1., -1.],
                  [ 1., -1.]],
        <BLANKLINE>
                 [[-1.,  1.],
                  [ 1.,  1.]]]])
        >>> create_meshgrid(2, 2, normalized_coordinates=False)
        tensor([[[[0., 0.],
                  [1., 0.]],
        <BLANKLINE>
                 [[0., 1.],
                  [1., 1.]]]])
    """
    xs: torch.Tensor = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
    ys: torch.Tensor = torch.linspace(0, height - 1, height, device=device, dtype=dtype)

    if normalized_coordinates:
        # Normalize inline with Python scalars: going through
        # normalize_pixel_coordinates would create new width/height tensors and
        # raise a TracerWarning when the function is traced.
        # max(..., 1) guards the size-1 case, where the denominator would be 0.
        xs = (xs / max(width - 1, 1) - 0.5) * 2
        ys = (ys / max(height - 1, 1) - 0.5) * 2

    # Broadcast each axis to HxW explicitly rather than calling torch.meshgrid,
    # whose implicit indexing mode is deprecated in recent PyTorch releases.
    grid_x: torch.Tensor = xs[None, :].expand(height, width)
    grid_y: torch.Tensor = ys[:, None].expand(height, width)

    # stack copies the expanded views into a fresh HxWx2 tensor
    base_grid: torch.Tensor = torch.stack([grid_x, grid_y], dim=-1)
    return base_grid.unsqueeze(0)  # 1xHxWx2
def create_meshgrid3d(
    depth: int,
    height: int,
    width: int,
    normalized_coordinates: bool = True,
    device: Optional[torch.device] = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Generate a coordinate grid for a volume.

    When the flag ``normalized_coordinates`` is set to True, every coordinate
    is normalized to be in the range :math:`[-1,1]`.

    Args:
        depth: the volume depth (number of slices).
        height: the volume height (rows).
        width: the volume width (cols).
        normalized_coordinates: whether to normalize
          coordinates in the range :math:`[-1,1]`.
        device: the device on which the grid will be generated.
        dtype: the data type of the generated grid.

    Return:
        grid tensor with shape :math:`(1, D, H, W, 3)`. The last dimension
        holds ``(z, x, y)``, in that order.

    Note:
        The ``(z, x, y)`` component order is kept for backward compatibility.
        :py:func:`torch.nn.functional.grid_sample` expects ``(x, y, z)`` for
        5-D inputs, so reorder the last dimension before sampling with it.

        When normalizing, an axis of size 1 maps its single sample to ``-1``
        instead of producing ``NaN`` from a ``0 / 0`` division.
    """
    xs: torch.Tensor = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
    ys: torch.Tensor = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
    zs: torch.Tensor = torch.linspace(0, depth - 1, depth, device=device, dtype=dtype)

    if normalized_coordinates:
        # Python-scalar arithmetic keeps the function free of TracerWarnings;
        # max(..., 1) guards the size-1 case, where the denominator would be 0.
        xs = (xs / max(width - 1, 1) - 0.5) * 2
        ys = (ys / max(height - 1, 1) - 0.5) * 2
        zs = (zs / max(depth - 1, 1) - 0.5) * 2

    # Broadcast each axis to DxHxW explicitly rather than calling
    # torch.meshgrid, whose implicit indexing mode is deprecated in recent
    # PyTorch releases.
    grid_z: torch.Tensor = zs[:, None, None].expand(depth, height, width)
    grid_x: torch.Tensor = xs[None, None, :].expand(depth, height, width)
    grid_y: torch.Tensor = ys[None, :, None].expand(depth, height, width)

    # component order (z, x, y) matches the historical output of this function
    base_grid: torch.Tensor = torch.stack([grid_z, grid_x, grid_y], dim=-1)  # DxHxWx3
    return base_grid.unsqueeze(0)  # 1xDxHxWx3