|
from typing import Optional |
|
|
|
import torch |
|
|
|
|
|
def create_meshgrid(
    height: int,
    width: int,
    normalized_coordinates: bool = True,
    device: Optional[torch.device] = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Generate a coordinate grid for an image.

    When the flag ``normalized_coordinates`` is set to True, the grid is
    normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch
    function :py:func:`torch.nn.functional.grid_sample`. The first row/column
    maps to -1 and the last one to 1.

    Args:
        height: the image height (rows).
        width: the image width (cols).
        normalized_coordinates: whether to normalize
          coordinates in the range :math:`[-1,1]` in order to be consistent with the
          PyTorch function :py:func:`torch.nn.functional.grid_sample`.
        device: the device on which the grid will be generated.
        dtype: the data type of the generated grid.

    Return:
        grid tensor with shape :math:`(1, H, W, 2)`. The last dimension holds
        the ``(x, y)`` coordinate of each pixel, i.e. ``grid[0, r, c]`` is
        ``(xs[c], ys[r])``.

    Note:
        A dimension of size 1 has a single sample and therefore no extent to
        normalize over; its normalized coordinate is defined as -1 instead
        of the NaN that a plain ``0 / (size - 1)`` division would produce.

    Example:
        >>> create_meshgrid(2, 2)
        tensor([[[[-1., -1.],
                  [ 1., -1.]],
        <BLANKLINE>
                 [[-1.,  1.],
                  [ 1.,  1.]]]])

        >>> create_meshgrid(2, 2, normalized_coordinates=False)
        tensor([[[[0., 0.],
                  [1., 0.]],
        <BLANKLINE>
                 [[0., 1.],
                  [1., 1.]]]])
    """
    xs: torch.Tensor = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
    ys: torch.Tensor = torch.linspace(0, height - 1, height, device=device, dtype=dtype)

    if normalized_coordinates:
        # Clamp the denominator to 1 so a size-1 axis yields -1 rather than
        # 0 / 0 = NaN. For sizes >= 2 the arithmetic is unchanged.
        xs = (xs / max(width - 1, 1) - 0.5) * 2
        ys = (ys / max(height - 1, 1) - 0.5) * 2

    # Broadcast the 1-D axes to HxW directly instead of going through
    # torch.meshgrid + permute: same values, no dependency on meshgrid's
    # ``indexing`` default, and the stacked result is contiguous.
    grid_x: torch.Tensor = xs.view(1, width).expand(height, width)
    grid_y: torch.Tensor = ys.view(height, 1).expand(height, width)

    base_grid: torch.Tensor = torch.stack([grid_x, grid_y], dim=-1)  # HxWx2
    return base_grid.unsqueeze(0)  # 1xHxWx2
|
|
|
|
|
def create_meshgrid3d(
    depth: int,
    height: int,
    width: int,
    normalized_coordinates: bool = True,
    device: Optional[torch.device] = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Generate a coordinate grid for a volume.

    When the flag ``normalized_coordinates`` is set to True, the grid is
    normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch
    function :py:func:`torch.nn.functional.grid_sample`. The first slice/row/column
    maps to -1 and the last one to 1.

    Args:
        depth: the image depth (channels).
        height: the image height (rows).
        width: the image width (cols).
        normalized_coordinates: whether to normalize
          coordinates in the range :math:`[-1,1]` in order to be consistent with the
          PyTorch function :py:func:`torch.nn.functional.grid_sample`.
        device: the device on which the grid will be generated.
        dtype: the data type of the generated grid.

    Return:
        grid tensor with shape :math:`(1, D, H, W, 3)`. The last dimension is
        ordered ``(z, x, y)``, i.e. ``grid[0, d, r, c]`` is
        ``(zs[d], xs[c], ys[r])``.

    Note:
        A dimension of size 1 has a single sample and therefore no extent to
        normalize over; its normalized coordinate is defined as -1 instead
        of the NaN that a plain ``0 / (size - 1)`` division would produce.

    Example:
        >>> create_meshgrid3d(2, 3, 4).shape
        torch.Size([1, 2, 3, 4, 3])
    """
    xs: torch.Tensor = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
    ys: torch.Tensor = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
    zs: torch.Tensor = torch.linspace(0, depth - 1, depth, device=device, dtype=dtype)

    if normalized_coordinates:
        # Clamp the denominator to 1 so a size-1 axis yields -1 rather than
        # 0 / 0 = NaN. For sizes >= 2 the arithmetic is unchanged.
        xs = (xs / max(width - 1, 1) - 0.5) * 2
        ys = (ys / max(height - 1, 1) - 0.5) * 2
        zs = (zs / max(depth - 1, 1) - 0.5) * 2

    # Broadcast the 1-D axes to DxHxW directly instead of going through
    # torch.meshgrid + permute: same values, no dependency on meshgrid's
    # ``indexing`` default, and the stacked result is contiguous.
    grid_z: torch.Tensor = zs.view(depth, 1, 1).expand(depth, height, width)
    grid_x: torch.Tensor = xs.view(1, 1, width).expand(depth, height, width)
    grid_y: torch.Tensor = ys.view(1, height, 1).expand(depth, height, width)

    # Channel order (z, x, y) is kept as-is for backward compatibility.
    base_grid: torch.Tensor = torch.stack([grid_z, grid_x, grid_y], dim=-1)  # DxHxWx3
    return base_grid.unsqueeze(0)  # 1xDxHxWx3
|
|