Spaces:
Sleeping
Sleeping
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import List, Optional | |
import torch | |
from packaging import version | |
if version.parse(torch.__version__) >= version.parse("1.7.0"): | |
import torch.fft # type: ignore | |
def fft2c_old(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Compute a centered 2D FFT through the legacy ``torch.fft(...)`` call.

    The input is ifftshift-ed over the two spatial axes, transformed, and
    then fftshift-ed over the same axes.

    Args:
        data: Complex valued input stored as real pairs. It needs at least
            3 dimensions: dimensions -3 & -2 are the spatial dimensions and
            dimension -1 has size 2. All other dimensions are treated as
            batch dimensions.
        norm: Normalization mode, either ``"backward"`` or ``"ortho"``.
            ``"ortho"`` is passed on as ``normalized=True``.

    Returns:
        The centered FFT of ``data``.

    Raises:
        ValueError: If the last dimension does not have size 2, or if
            ``norm`` is neither ``"ortho"`` nor ``"backward"``.
    """
    if data.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")
    if norm != "ortho" and norm != "backward":
        raise ValueError("norm must be 'ortho' or 'backward'.")

    # The legacy API only takes a boolean normalization switch.
    use_ortho = norm == "ortho"

    centered = ifftshift(data, dim=[-3, -2])
    spectrum = torch.fft(centered, 2, normalized=use_ortho)
    return fftshift(spectrum, dim=[-3, -2])
def ifft2c_old(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Compute a centered 2D inverse FFT through the legacy ``torch.ifft(...)``
    call.

    The input is ifftshift-ed over the two spatial axes, inverse
    transformed, and then fftshift-ed over the same axes.

    Args:
        data: Complex valued input stored as real pairs. It needs at least
            3 dimensions: dimensions -3 & -2 are the spatial dimensions and
            dimension -1 has size 2. All other dimensions are treated as
            batch dimensions.
        norm: Normalization mode, either ``"backward"`` or ``"ortho"``.
            ``"ortho"`` is passed on as ``normalized=True``.

    Returns:
        The centered IFFT of ``data``.

    Raises:
        ValueError: If the last dimension does not have size 2, or if
            ``norm`` is neither ``"ortho"`` nor ``"backward"``.
    """
    if data.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")
    if norm != "ortho" and norm != "backward":
        raise ValueError("norm must be 'ortho' or 'backward'.")

    # The legacy API only takes a boolean normalization switch.
    use_ortho = norm == "ortho"

    centered = ifftshift(data, dim=[-3, -2])
    image = torch.ifft(centered, 2, normalized=use_ortho)
    return fftshift(image, dim=[-3, -2])
def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Compute a centered 2D FFT using the ``torch.fft`` module.

    The input is ifftshift-ed over the two spatial axes, viewed as a complex
    tensor, transformed with ``torch.fft.fftn``, viewed back as real pairs,
    and finally fftshift-ed over the same spatial axes.

    Args:
        data: Complex valued input stored as real pairs. It needs at least
            3 dimensions: dimensions -3 & -2 are the spatial dimensions and
            dimension -1 has size 2. All other dimensions are treated as
            batch dimensions.
        norm: Normalization mode, forwarded unchanged to
            ``torch.fft.fftn``.

    Returns:
        The centered FFT of ``data``, again with a trailing real/imaginary
        dimension.

    Raises:
        ValueError: If the last dimension does not have size 2.
    """
    if data.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")

    centered = ifftshift(data, dim=[-3, -2])

    # Once viewed as complex the trailing size-2 axis is gone, so the
    # spatial axes are addressed as (-2, -1) for the transform itself.
    as_complex = torch.view_as_complex(centered)
    spectrum = torch.fft.fftn(as_complex, dim=(-2, -1), norm=norm)  # type: ignore
    as_real = torch.view_as_real(spectrum)

    return fftshift(as_real, dim=[-3, -2])
def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Compute a centered 2D inverse FFT using the ``torch.fft`` module.

    The input is ifftshift-ed over the two spatial axes, viewed as a complex
    tensor, inverse transformed with ``torch.fft.ifftn``, viewed back as
    real pairs, and finally fftshift-ed over the same spatial axes.

    Args:
        data: Complex valued input stored as real pairs. It needs at least
            3 dimensions: dimensions -3 & -2 are the spatial dimensions and
            dimension -1 has size 2. All other dimensions are treated as
            batch dimensions.
        norm: Normalization mode, forwarded unchanged to
            ``torch.fft.ifftn``.

    Returns:
        The centered IFFT of ``data``, again with a trailing real/imaginary
        dimension.

    Raises:
        ValueError: If the last dimension does not have size 2.
    """
    if data.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")

    centered = ifftshift(data, dim=[-3, -2])

    # Once viewed as complex the trailing size-2 axis is gone, so the
    # spatial axes are addressed as (-2, -1) for the transform itself.
    as_complex = torch.view_as_complex(centered)
    image = torch.fft.ifftn(as_complex, dim=(-2, -1), norm=norm)  # type: ignore
    as_real = torch.view_as_real(image)

    return fftshift(as_real, dim=[-3, -2])
# Helper functions
def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    """
    Similar to roll but for only one dim.

    Elements pushed past the end of ``dim`` wrap around to the start.

    Args:
        x: A PyTorch tensor.
        shift: Amount to roll. It is reduced modulo the size of ``dim``, so
            negative and oversized values wrap around.
        dim: Which dimension to roll.

    Returns:
        Rolled version of x. ``x`` itself is returned when the effective
        shift is zero or when ``dim`` has length zero.
    """
    size = x.size(dim)

    # An empty dimension has nothing to roll; returning early also avoids
    # the modulo-by-zero that the shift reduction below would hit.
    if size == 0:
        return x

    shift = shift % size
    if shift == 0:
        return x

    # Split at size - shift and swap the two pieces.
    split = size - shift
    left = x.narrow(dim, 0, split)
    right = x.narrow(dim, split, shift)

    return torch.cat((right, left), dim=dim)
def roll(
    x: torch.Tensor,
    shift: List[int],
    dim: List[int],
) -> torch.Tensor:
    """
    Roll a PyTorch tensor along several dimensions, in the spirit of
    np.roll.

    Each ``(shift[i], dim[i])`` pair is applied in order, one dimension at a
    time, via ``roll_one_dim``.

    Args:
        x: A PyTorch tensor.
        shift: Amount to roll along each entry of ``dim``.
        dim: Which dimensions to roll.

    Returns:
        Rolled version of x.

    Raises:
        ValueError: If ``shift`` and ``dim`` have different lengths.
    """
    if len(shift) != len(dim):
        raise ValueError("len(shift) must match len(dim)")

    rolled = x
    for amount, axis in zip(shift, dim):
        rolled = roll_one_dim(rolled, amount, axis)

    return rolled
def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    Shift a PyTorch tensor in the manner of np.fft.fftshift.

    Every selected dimension is rolled by half its length, rounded down.

    Args:
        x: A PyTorch tensor.
        dim: Which dimensions to fftshift. When ``None``, all dimensions of
            ``x`` are shifted.

    Returns:
        fftshifted version of x.
    """
    if dim is None:
        # Filled by index assignment: the original author notes that this
        # form is needed for torch.jit.script typing.
        ndim = x.dim()
        dim = [0] * ndim
        for axis in range(1, ndim):
            dim[axis] = axis

    # Same pre-allocate-then-assign pattern, also for torch.jit.script.
    shift = [0] * len(dim)
    for pos, axis in enumerate(dim):
        shift[pos] = x.size(axis) // 2

    return roll(x, shift, dim)
def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    Shift a PyTorch tensor in the manner of np.fft.ifftshift.

    Every selected dimension is rolled by half its length, rounded up.

    Args:
        x: A PyTorch tensor.
        dim: Which dimensions to ifftshift. When ``None``, all dimensions
            of ``x`` are shifted.

    Returns:
        ifftshifted version of x.
    """
    if dim is None:
        # Filled by index assignment: the original author notes that this
        # form is needed for torch.jit.script typing.
        ndim = x.dim()
        dim = [0] * ndim
        for axis in range(1, ndim):
            dim[axis] = axis

    # Same pre-allocate-then-assign pattern, also for torch.jit.script.
    shift = [0] * len(dim)
    for pos, axis in enumerate(dim):
        shift[pos] = (x.size(axis) + 1) // 2

    return roll(x, shift, dim)