import torch


def fake_decorator(*a, **k):
    """No-op decorator: returns its argument unchanged.

    Works both bare (`@fake_decorator`) and called (`@fake_decorator(...)`).
    """
    if len(a) == 1 and not k:
        return a[0]
    else:
        return fake_decorator
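
# Illustrative sketch (not from the original module): `fake_decorator`
# can stand in for an optional dependency's decorator, bare or called
# with options.
#
#   >>> @fake_decorator
#   ... def f(x):
#   ...     return x
#   >>> f(1)
#   1
#   >>> @fake_decorator(some_option=True)
#   ... def g(x):
#   ...     return x
#   >>> g(2)
#   2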


def make_list(x, n=None, **kwargs):
    """Ensure that the input is a list (of a given size).

    Parameters
    ----------
    x : list or tuple or scalar
        Input object.
    n : int, optional
        Required length.
    default : scalar, optional
        Value to right-pad with (keyword only). Defaults to the last
        value of the input.

    Returns
    -------
    x : list
    """
    if not isinstance(x, (list, tuple)):
        x = [x]
    x = list(x)
    if n and len(x) < n:
        default = kwargs.get('default', x[-1])
        x = x + [default] * max(0, n - len(x))
    return x
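
# Illustrative sketch (values are my own, not from the original module):
#
#   >>> make_list(3)
#   [3]
#   >>> make_list([1, 2], n=4)
#   [1, 2, 2, 2]
#   >>> make_list([1, 2], n=4, default=0)
#   [1, 2, 0, 0]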


def expanded_shape(*shapes, side='left'):
    """Expand input shapes according to broadcasting rules.

    Parameters
    ----------
    *shapes : sequence[int]
        Input shapes.
    side : {'left', 'right'}, default='left'
        Side on which to add singleton dimensions.

    Returns
    -------
    shape : tuple[int]
        Output shape.

    Raises
    ------
    ValueError
        If shapes are not compatible for broadcasting.
    """
    def error(s0, s1):
        raise ValueError('Incompatible shapes for broadcasting: {} and {}.'
                         .format(s0, s1))

    # Maximum number of dimensions across all input shapes
    nb_dim = 0
    for shape in shapes:
        nb_dim = max(nb_dim, len(shape))

    # Pad each shape with singleton dimensions, then take the
    # dimension-wise maximum (or fail if neither size is 1 and
    # the sizes differ).
    shape = [1] * nb_dim
    for shape1 in shapes:
        pad_size = nb_dim - len(shape1)
        ones = [1] * pad_size
        if side == 'left':
            shape1 = [*ones, *shape1]
        else:
            shape1 = [*shape1, *ones]
        shape = [max(s0, s1) if s0 == 1 or s1 == 1 or s0 == s1
                 else error(s0, s1)
                 for s0, s1 in zip(shape, shape1)]

    return tuple(shape)
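
# Illustrative sketch (shapes are my own, not from the original module):
#
#   >>> expanded_shape((2, 3), (3,))
#   (2, 3)
#   >>> expanded_shape((2, 1, 4), (3, 1))
#   (2, 3, 4)
#   >>> expanded_shape((2,), (3,))    # raises:
#   ValueError: Incompatible shapes for broadcasting: 2 and 3.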


def matvec(mat, vec, out=None):
    """Matrix-vector product (supports broadcasting).

    Parameters
    ----------
    mat : (..., M, N) tensor
        Input matrix.
    vec : (..., N) tensor
        Input vector.
    out : (..., M) tensor, optional
        Placeholder for the output tensor.

    Returns
    -------
    mv : (..., M) tensor
        Matrix-vector product of the inputs.
    """
    # Append a singleton dimension so the vector becomes a (..., N, 1)
    # matrix, compute a broadcast matrix-matrix product, then drop the
    # extra dimension. If `out` is given, its unsqueezed view shares
    # storage with it, so the caller's tensor is filled in place.
    vec = vec[..., None]
    if out is not None:
        out = out[..., None]

    mv = torch.matmul(mat, vec, out=out)
    mv = mv[..., 0]

    return mv
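
# Illustrative sketch (tensors are my own, not from the original module):
#
#   >>> mat = torch.randn(8, 3, 4)   # batch of 8 matrices
#   >>> vec = torch.randn(4)         # one vector, broadcast across the batch
#   >>> matvec(mat, vec).shape
#   torch.Size([8, 3])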


def _compare_versions(version1, mode, version2):
    # Lexicographic comparison of version tuples: components are compared
    # pairwise and the first strict difference decides. If all compared
    # components are equal, strict modes ('gt', 'lt') fail and non-strict
    # modes ('ge', 'le') succeed.
    for v1, v2 in zip(version1, version2):
        if mode in ('gt', '>', 'ge', '>='):
            if v1 > v2:
                return True
            elif v1 < v2:
                return False
        elif mode in ('lt', '<', 'le', '<='):
            if v1 < v2:
                return True
            elif v1 > v2:
                return False
    if mode in ('gt', 'lt', '>', '<'):
        return False
    else:
        return True
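
# Illustrative sketch (version tuples are my own):
#
#   >>> _compare_versions((1, 10, 0), '>=', (1, 10))
#   True
#   >>> _compare_versions((1, 9, 1), '>=', (1, 10))
#   False
#   >>> _compare_versions((1, 10), '>', (1, 10))
#   False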


def torch_version(mode, version):
    """Check the installed torch version.

    Parameters
    ----------
    mode : {'<', '<=', '>', '>='}
    version : tuple[int]

    Returns
    -------
    True if `torch.__version__ <mode> version`
    """
    # Drop the local build tag (e.g. '2.0.1+cu118' -> '2.0.1')
    current_version = torch.__version__.split('+')[0]
    major, minor, patch, *_ = current_version.split('.')

    # Strip any pre-release suffix from the patch number
    # (e.g. '0a0' -> '0', '0rc1' -> '0')
    for i, char in enumerate(patch):
        if not char.isdigit():
            patch = patch[:i]
            break
    current_version = (int(major), int(minor), int(patch))
    version = make_list(version)
    return _compare_versions(current_version, mode, version)
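
# Illustrative sketch (the result depends on the installed torch build):
#
#   >>> torch_version('>=', (1, 10))   # True on torch 1.10 or later
#   True
#   >>> torch_version('<', (1,))       # False on any torch >= 1.0
#   False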


# torch >= 1.10 accepts an `indexing` keyword in `torch.meshgrid`;
# older versions always behave like `indexing='ij'`.
if torch_version('>=', (1, 10)):
    meshgrid_ij = lambda *x: torch.meshgrid(*x, indexing='ij')
    meshgrid_xy = lambda *x: torch.meshgrid(*x, indexing='xy')
else:
    meshgrid_ij = lambda *x: torch.meshgrid(*x)

    def meshgrid_xy(*x):
        # Emulate `indexing='xy'`: swap the first two dimensions of
        # every grid returned with the default 'ij' indexing (all
        # outputs get the swap, not just the first two).
        grid = list(torch.meshgrid(*x))
        if len(grid) > 1:
            grid = [g.transpose(0, 1) for g in grid]
        return grid


meshgrid = meshgrid_ij
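
# Illustrative sketch (sizes are my own, not from the original module):
#
#   >>> gi = meshgrid_ij(torch.arange(2), torch.arange(3))
#   >>> gi[0].shape
#   torch.Size([2, 3])
#   >>> gx = meshgrid_xy(torch.arange(2), torch.arange(3))
#   >>> gx[0].shape
#   torch.Size([3, 2])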