import torch


def eye_like(n: int, input: torch.Tensor) -> torch.Tensor:
    r"""Return a batch of identity matrices with the same batch size as the input.

    The result has ones on the diagonal and zeros elsewhere, and is created
    on the same device and with the same dtype as ``input``.

    Args:
        n: the number of rows :math:`(N)`. Must be strictly positive.
        input: image tensor that will determine the batch size of the output matrix.
          The expected shape is :math:`(B, *)`.

    Returns:
        The identity matrices with the same batch size as the input :math:`(B, N, N)`.

    Raises:
        AssertionError: if ``n`` is not strictly positive, or if ``input`` has
          no dimensions (and therefore no batch dimension to read).
    """
    if n <= 0:
        raise AssertionError(type(n), n)
    if len(input.shape) < 1:
        raise AssertionError(input.shape)

    identity = torch.eye(n, device=input.device, dtype=input.dtype)
    # ``repeat`` materialises an independent copy per batch element, so callers
    # may modify one matrix in place without affecting the others.
    return identity[None].repeat(input.shape[0], 1, 1)


def vec_like(n: int, tensor: torch.Tensor) -> torch.Tensor:
    r"""Return a batch of zero column vectors with the same batch size as the input.

    The result is created on the same device and with the same dtype as
    ``tensor``.

    Args:
        n: the number of rows :math:`(N)`. Must be strictly positive.
        tensor: image tensor that will determine the batch size of the output matrix.
          The expected shape is :math:`(B, *)`.

    Returns:
        The zero vectors with the same batch size as the input :math:`(B, N, 1)`.

    Raises:
        AssertionError: if ``n`` is not strictly positive, or if ``tensor`` has
          no dimensions (and therefore no batch dimension to read).
    """
    if n <= 0:
        raise AssertionError(type(n), n)
    if len(tensor.shape) < 1:
        raise AssertionError(tensor.shape)

    vec = torch.zeros(n, 1, device=tensor.device, dtype=tensor.dtype)
    # ``repeat`` materialises an independent copy per batch element.
    return vec[None].repeat(tensor.shape[0], 1, 1)