Dexter's picture
Upload folder using huggingface_hub
36c95ba verified
import torch
def eye_like(n: int, input: torch.Tensor) -> torch.Tensor:
r"""Return a 2-D tensor with ones on the diagonal and zeros elsewhere with the same batch size as the input.
Args:
n: the number of rows :math:`(N)`.
input: image tensor that will determine the batch size of the output matrix.
The expected shape is :math:`(B, *)`.
Returns:
The identity matrix with the same batch size as the input :math:`(B, N, N)`.
"""
if n <= 0:
raise AssertionError(type(n), n)
if len(input.shape) < 1:
raise AssertionError(input.shape)
identity = torch.eye(n, device=input.device, dtype=input.dtype)
return identity[None].repeat(input.shape[0], 1, 1)
def vec_like(n, tensor):
r"""Return a 2-D tensor with a vector containing zeros with the same batch size as the input.
Args:
n: the number of rows :math:`(N)`.
tensor: image tensor that will determine the batch size of the output matrix.
The expected shape is :math:`(B, *)`.
Returns:
The vector with the same batch size as the input :math:`(B, N, 1)`.
"""
if n <= 0:
raise AssertionError(type(n), n)
if len(tensor.shape) < 1:
raise AssertionError(tensor.shape)
vec = torch.zeros(n, 1, device=tensor.device, dtype=tensor.dtype)
return vec[None].repeat(tensor.shape[0], 1, 1)