import torch


def fake_decorator(*a, **k):
    """No-op decorator: return the decorated object (or itself) unchanged."""
    if len(a) == 1 and not k:
        return a[0]
    else:
        return fake_decorator
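# Illustrative use of `fake_decorator` (behaviour derived from the code above):
# both the bare form `@fake_decorator` and the parameterized form
# `@fake_decorator(option=1)` leave the decorated function unchanged.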
def make_list(x, n=None, **kwargs):
    """Ensure that the input is a list (of a given size).

    Parameters
    ----------
    x : list or tuple or scalar
        Input object
    n : int, optional
        Required length
    default : scalar, optional
        Value to right-pad with (passed as a keyword argument).
        Use the last value of the input by default.

    Returns
    -------
    x : list
    """
    if not isinstance(x, (list, tuple)):
        x = [x]
    x = list(x)
    if n and len(x) < n:
        default = kwargs.get('default', x[-1])
        x = x + [default] * max(0, n - len(x))
    return x
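# Illustrative behaviour of `make_list` (derived from the code above):
#   make_list(3, n=4)             # -> [3, 3, 3, 3]   (right-padded with the last value)
#   make_list([1, 2], n=4)        # -> [1, 2, 2, 2]
#   make_list((1, 2, 3), n=2)     # -> [1, 2, 3]      (inputs are never truncated)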
def expanded_shape(*shapes, side='left'):
    """Expand input shapes according to broadcasting rules.

    Parameters
    ----------
    *shapes : sequence[int]
        Input shapes
    side : {'left', 'right'}, default='left'
        Side on which to add singleton dimensions.

    Returns
    -------
    shape : tuple[int]
        Output shape

    Raises
    ------
    ValueError
        If shapes are not compatible for broadcasting.
    """
    def error(s0, s1):
        raise ValueError('Incompatible shapes for broadcasting: {} and {}.'
                         .format(s0, s1))

    # 1. maximum number of dimensions across all inputs
    nb_dim = 0
    for shape in shapes:
        nb_dim = max(nb_dim, len(shape))

    # 2. broadcast each input shape against the running result
    shape = [1] * nb_dim
    for shape1 in shapes:
        pad_size = nb_dim - len(shape1)
        ones = [1] * pad_size
        if side == 'left':
            shape1 = [*ones, *shape1]
        else:
            shape1 = [*shape1, *ones]
        shape = [max(s0, s1) if s0 == 1 or s1 == 1 or s0 == s1
                 else error(s0, s1) for s0, s1 in zip(shape, shape1)]
    return tuple(shape)
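# Illustrative behaviour of `expanded_shape` (derived from the code above):
#   expanded_shape((3, 1), (1, 4))               # -> (3, 4)
#   expanded_shape((5,), (2, 1))                 # -> (2, 5)  (missing dims padded on the left)
#   expanded_shape((5,), (1, 4), side='right')   # -> (5, 4)  (missing dims padded on the right)
#   expanded_shape((2, 3), (3, 4))               # raises ValueError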
def matvec(mat, vec, out=None):
    """Matrix-vector product (supports broadcasting).

    Parameters
    ----------
    mat : (..., M, N) tensor
        Input matrix.
    vec : (..., N) tensor
        Input vector.
    out : (..., M) tensor, optional
        Placeholder for the output tensor.

    Returns
    -------
    mv : (..., M) tensor
        Matrix-vector product of the inputs.
    """
    # Append a singleton dimension so that matmul computes (..., M, N) @ (..., N, 1),
    # then drop it again to return a (..., M) tensor.
    vec = vec[..., None]
    if out is not None:
        out = out[..., None]
    mv = torch.matmul(mat, vec, out=out)
    mv = mv[..., 0]
    if out is not None:
        out = out[..., 0]
    return mv
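# Illustrative shapes for `matvec` (derived from the code above):
#   mat = torch.randn(2, 3, 4)    # batch of 2 matrices of shape (3, 4)
#   vec = torch.randn(4)          # a single vector, broadcast across the batch
#   matvec(mat, vec).shape        # -> torch.Size([2, 3])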
def _compare_versions(version1, mode, version2):
    """Compare two version tuples component-wise (lexicographic order)."""
    for v1, v2 in zip(version1, version2):
        if mode in ('gt', '>'):
            if v1 > v2:
                return True
            elif v1 < v2:
                return False
        elif mode in ('ge', '>='):
            if v1 > v2:
                return True
            elif v1 < v2:
                return False
        elif mode in ('lt', '<'):
            if v1 < v2:
                return True
            elif v1 > v2:
                return False
        elif mode in ('le', '<='):
            if v1 < v2:
                return True
            elif v1 > v2:
                return False
    # All compared components are equal: strict comparisons fail, non-strict succeed
    if mode in ('gt', 'lt', '>', '<'):
        return False
    else:
        return True
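# Illustrative behaviour of `_compare_versions` (derived from the code above):
#   _compare_versions((1, 9, 0), '>=', (1, 10))    # -> False (9 < 10 decides)
#   _compare_versions((1, 12, 1), '>=', (1, 10))   # -> True
#   _compare_versions((1, 10), '>', (1, 10))       # -> False (equal, strict comparison)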
def torch_version(mode, version):
    """Check torch version.

    Parameters
    ----------
    mode : {'<', '<=', '>', '>='}
    version : tuple[int]

    Returns
    -------
    True if "torch.version <mode> version"
    """
    current_version, *cuda_variant = torch.__version__.split('+')
    major, minor, patch, *_ = current_version.split('.')
    # strip alpha/beta/rc tags from the patch number (e.g. '0a0' -> '0')
    for x in 'abcdefghijklmnopqrstuvwxyz':
        if x in patch:
            patch = patch[:patch.index(x)]
    current_version = (int(major), int(minor), int(patch))
    version = make_list(version)
    return _compare_versions(current_version, mode, version)
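# Illustrative usage of `torch_version` (the result depends on the installed torch):
#   torch_version('>=', (1, 10))   # True on torch 1.10 and later
#   torch_version('<', (2,))       # True on any torch 1.x release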
# The `indexing` keyword of torch.meshgrid only exists in torch >= 1.10;
# older versions always behave like 'ij', so 'xy' must be emulated there.
if torch_version('>=', (1, 10)):
    meshgrid_ij = lambda *x: torch.meshgrid(*x, indexing='ij')
    meshgrid_xy = lambda *x: torch.meshgrid(*x, indexing='xy')
else:
    meshgrid_ij = lambda *x: torch.meshgrid(*x)

    def meshgrid_xy(*x):
        # emulate 'xy' indexing by swapping the first two axes of every 'ij' grid
        grid = list(torch.meshgrid(*x))
        if len(grid) > 1:
            grid = [g.transpose(0, 1) for g in grid]
        return grid

meshgrid = meshgrid_ij
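# Illustrative shapes for the wrappers above (`meshgrid` is an alias for `meshgrid_ij`):
#   gi, gj = meshgrid_ij(torch.arange(3), torch.arange(4))   # both grids have shape (3, 4)
#   gx, gy = meshgrid_xy(torch.arange(3), torch.arange(4))   # both grids have shape (4, 3)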