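# NOTE(review): the next three lines look like page-header residue from a web
# export rather than source code; kept as comments so the module still parses.
# peirong26's picture
# Upload 187 files
# 2571f24 verified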
import torch
def fake_decorator(*a, **k):
    """No-op decorator usable both bare and with arguments

    Called with exactly one positional argument and no keyword argument
    (the ``@fake_decorator`` form), it returns that argument unchanged.
    Any other call returns ``fake_decorator`` itself, so that
    ``@fake_decorator(...)`` ends up applying the bare form to the
    decorated object.
    """
    is_bare_call = (not k) and len(a) == 1
    return a[0] if is_bare_call else fake_decorator
def make_list(x, n=None, **kwargs):
    """Ensure that the input is a list (of at least a given size)

    Parameters
    ----------
    x : list or tuple or scalar
        Input object. Anything that is not a list or a tuple is wrapped
        into a one-element list.
    n : int, optional
        Minimum length. Shorter inputs are right-padded up to `n`;
        longer inputs are returned unchanged (they are never cropped).
    default : scalar, optional
        Value to right-pad with. Use last value of the input by default.

    Returns
    -------
    x : list
        A new list; the input object is not modified.

    Raises
    ------
    IndexError
        If padding is required, the input is empty and no `default`
        is provided.
    """
    x = list(x) if isinstance(x, (list, tuple)) else [x]
    if n and len(x) < n:
        # Only fall back on the last element when no explicit default is
        # given: evaluating `x[-1]` eagerly would fail on an empty input
        # even though a padding value was provided.
        default = kwargs['default'] if 'default' in kwargs else x[-1]
        x = x + [default] * (n - len(x))
    return x
def expanded_shape(*shapes, side='left'):
    """Expand input shapes according to broadcasting rules

    Parameters
    ----------
    *shapes : sequence[int]
        Input shapes
    side : {'left', 'right'}, default='left'
        Side to add singleton dimensions.
        Any value other than 'left' pads on the right.

    Returns
    -------
    shape : tuple[int]
        Output shape (empty tuple if no shape is provided)

    Raises
    ------
    ValueError
        If shapes are not compatible for broadcast.

    """
    def error(s0, s1):
        raise ValueError('Incompatible shapes for broadcasting: {} and {}.'
                         .format(s0, s1))

    # 1. nb dimensions
    nb_dim = max((len(shape) for shape in shapes), default=0)

    # 2. merge each input shape into the running output shape
    shape = [1] * nb_dim
    for shape1 in shapes:
        ones = [1] * (nb_dim - len(shape1))
        if side == 'left':
            shape1 = [*ones, *shape1]
        else:
            shape1 = [*shape1, *ones]
        merged = []
        for s0, s1 in zip(shape, shape1):
            # A singleton always yields to the other size. Picking the
            # other size (rather than the largest of the two) keeps
            # zero-sized dimensions: 1 broadcast against 0 gives 0.
            if s0 == 1:
                merged.append(s1)
            elif s1 == 1 or s0 == s1:
                merged.append(s0)
            else:
                error(s0, s1)
        shape = merged
    return tuple(shape)
def matvec(mat, vec, out=None):
    """Matrix-vector product (supports broadcasting)

    Parameters
    ----------
    mat : (..., M, N) tensor
        Input matrix.
    vec : (..., N) tensor
        Input vector.
    out : (..., M) tensor, optional
        Placeholder for the output tensor.

    Returns
    -------
    mv : (..., M) tensor
        Matrix vector product of the inputs

    """
    # Append a trailing singleton so that `vec` is multiplied as a
    # (..., N, 1) column matrix, then drop that dimension from the result.
    column = vec[..., None]
    # Same trailing singleton on `out` so that its shape matches the
    # (..., M, 1) matrix product that is written into it.
    out_column = None if out is None else out[..., None]
    product = torch.matmul(mat, column, out=out_column)
    return product[..., 0]
def _compare_versions(version1, mode, version2):
for v1, v2 in zip(version1, version2):
if mode in ('gt', '>'):
if v1 > v2:
return True
elif v1 < v2:
return False
elif mode in ('ge', '>='):
if v1 > v2:
return True
elif v1 < v2:
return False
elif mode in ('lt', '<'):
if v1 < v2:
return True
elif v1 > v2:
return False
elif mode in ('le', '<='):
if v1 < v2:
return True
elif v1 > v2:
return False
if mode in ('gt', 'lt', '>', '<'):
return False
else:
return True
def torch_version(mode, version):
    """Check torch version

    The installed version is read from `torch.__version__` and reduced
    to its first three dot-separated fields (major, minor, patch).

    Parameters
    ----------
    mode : {'<', '<=', '>', '>='}
        Comparison operator.
    version : tuple[int]
        Version to compare against. A scalar is treated as a
        one-element version.

    Returns
    -------
    True if "torch.version <mode> version"

    """
    # Drop whatever follows '+' (local build tag), if any
    release = torch.__version__.split('+')[0]

    current_version = []
    for field in release.split('.')[:3]:
        # Keep only the leading digits of each field, which strips
        # pre-release tags glued to a number. A field without any
        # leading digit counts as 0 instead of failing on `int('')`.
        nb_digits = len(field) - len(field.lstrip('0123456789'))
        current_version.append(int(field[:nb_digits]) if nb_digits else 0)
    # Versions with fewer than three fields are padded with zeros
    current_version += [0] * (3 - len(current_version))

    version = make_list(version)
    return _compare_versions(tuple(current_version), mode, version)
# Pick the meshgrid wrappers once, at import time, depending on whether
# the installed torch is at least 1.10 (the branch that passes `indexing`).
if torch_version('>=', (1, 10)):
    def meshgrid_ij(*x):
        """Call `torch.meshgrid` with ``indexing='ij'``"""
        return torch.meshgrid(*x, indexing='ij')

    def meshgrid_xy(*x):
        """Call `torch.meshgrid` with ``indexing='xy'``"""
        return torch.meshgrid(*x, indexing='xy')
else:
    def meshgrid_ij(*x):
        """Call `torch.meshgrid` without an `indexing` argument"""
        return torch.meshgrid(*x)

    def meshgrid_xy(*x):
        """Call `torch.meshgrid`, then swap the first two dimensions
        of the first two grids (when there is more than one grid).

        Returns a list, not the raw output of `torch.meshgrid`.
        """
        grids = list(torch.meshgrid(*x))
        if len(grids) > 1:
            grids[:2] = [g.transpose(0, 1) for g in grids[:2]]
        return grids

# Default meshgrid flavour exposed by this module
meshgrid = meshgrid_ij