|
import torch |
|
|
|
|
|
def eye_like(n: int, input: torch.Tensor) -> torch.Tensor:
    r"""Return a batch of identity matrices with the same batch size as the input.

    The result is a 3-D tensor: one :math:`(N, N)` identity matrix per batch item,
    created on the same device and with the same dtype as ``input``.

    Args:
        n: the number of rows and columns :math:`(N)` of each identity matrix.
            Must be strictly positive.
        input: tensor that determines the batch size, device and dtype of the output.
            The expected shape is :math:`(B, *)`.

    Returns:
        The identity matrices with the same batch size as the input :math:`(B, N, N)`.

    Raises:
        AssertionError: if ``n`` is not positive, or if ``input`` has no dimensions
            (so there is no batch axis to read).
    """
    if n <= 0:
        raise AssertionError(type(n), n)
    if len(input.shape) < 1:
        raise AssertionError(input.shape)

    identity = torch.eye(n, device=input.device, dtype=input.dtype)
    # ``repeat`` copies the data for every batch item, so the returned matrices do
    # not share memory and can safely be modified in place by the caller.
    return identity[None].repeat(input.shape[0], 1, 1)
|
|
|
|
|
def vec_like(n: int, tensor: torch.Tensor) -> torch.Tensor:
    r"""Return a batch of zero column vectors with the same batch size as the input.

    The result is a 3-D tensor: one :math:`(N, 1)` column of zeros per batch item,
    created on the same device and with the same dtype as ``tensor``.

    Args:
        n: the number of rows :math:`(N)` of each column vector.
            Must be strictly positive.
        tensor: tensor that determines the batch size, device and dtype of the output.
            The expected shape is :math:`(B, *)`.

    Returns:
        The vector with the same batch size as the input :math:`(B, N, 1)`.

    Raises:
        AssertionError: if ``n`` is not positive, or if ``tensor`` has no dimensions
            (so there is no batch axis to read).
    """
    if n <= 0:
        raise AssertionError(type(n), n)
    if len(tensor.shape) < 1:
        raise AssertionError(tensor.shape)

    vec = torch.zeros(n, 1, device=tensor.device, dtype=tensor.dtype)
    # ``repeat`` copies the data for every batch item, so the returned vectors do
    # not share memory and can safely be modified in place by the caller.
    return vec[None].repeat(tensor.shape[0], 1, 1)
|
|