File size: 6,911 Bytes
f8519d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
import math
import torch
import torch.nn.functional as F
"""
An implementation of the parallel scan operation in PyTorch (Blelloch version).
Please see docs/pscan.ipynb for a detailed explanation of what happens here.
"""
def npo2(len):
    """
    Returns the smallest power of 2 that is greater than or equal to len
    (len itself when it already is a power of 2).

    Args:
        len : length to round up, must be >= 1. The parameter name shadows
              the builtin; it is kept so that keyword callers keep working.

    Returns:
        int : smallest power of two p such that p >= len

    Raises:
        ValueError : if len < 1
    """
    if len < 1:
        raise ValueError(f"npo2 expects a length >= 1, got {len}")

    # Exact integer arithmetic instead of 2 ** ceil(log2(len)): the float
    # logarithm loses precision for large values and can then yield a power
    # of two that is smaller than len.
    return 1 << (math.ceil(len) - 1).bit_length()
def pad_npo2(X):
    """
    Zero-pads the length dimension of X so that it becomes npo2(L).

    Args:
        X : (B, L, D, N)

    Returns:
        Y : (B, npo2(L), D, N)
    """
    seq_len = X.size(1)
    missing = npo2(seq_len) - seq_len

    # F.pad reads its (left, right) pairs starting from the LAST dimension:
    # (N_left, N_right, D_left, D_right, L_left, L_right)
    return F.pad(X, (0, 0, 0, 0, 0, missing), mode="constant", value=0)
class PScan(torch.autograd.Function):
    """
    Autograd function computing the linear recurrence

        H[t] = A[t] * H[t-1] + X[t]

    along the length dimension with a parallel scan (up sweep + down sweep).
    """

    @staticmethod
    def pscan(A, X):
        """
        In-place parallel scan.

        Args:
            A : (B, D, L, N)
            X : (B, D, L, N)

        X is overwritten with H, where H[t] = A[t] * H[t-1] + X[t] and the
        state before the first element is 0 (so index 0 of X is left as is).
        A is overwritten as well: it stores the partial products of the scan.

        The values are computed in parallel (ideally 2*log2(L) sequential
        steps instead of L sequential steps).
        Only supports L that is a power of two (mainly for clearer code).
        """
        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # ---- up sweep (the last 2 levels are unrolled below)
        a_lvl, x_lvl = A, X
        for _ in range(num_steps - 2):
            half = x_lvl.size(2) // 2

            # group neighbouring nodes two by two
            a_pairs = a_lvl.view(B, D, half, 2, -1)
            x_pairs = x_lvl.view(B, D, half, 2, -1)
            a_left, a_right = a_pairs[:, :, :, 0], a_pairs[:, :, :, 1]
            x_left, x_right = x_pairs[:, :, :, 0], x_pairs[:, :, :, 1]

            # fold the left node of each pair into the right one
            x_right.add_(a_right.mul(x_left))
            a_right.mul_(a_left)

            # the right nodes form the next (coarser) level
            a_lvl, x_lvl = a_right, x_right

        # only 4, 2 or 1 nodes are left at this point
        remaining = x_lvl.size(2)
        if remaining == 2:
            x_lvl[:, :, 1].add_(a_lvl[:, :, 1].mul(x_lvl[:, :, 0]))
            return
        if remaining != 4:
            return

        # these are views: writes through them are visible to each other
        a0, a1, a2, a3 = (a_lvl[:, :, i] for i in range(4))
        x0, x1, x2, x3 = (x_lvl[:, :, i] for i in range(4))
        x1.add_(a1.mul(x0))
        a1.mul_(a0)
        x3.add_(a3.mul(x2 + a2.mul(x1)))

        # ---- down sweep (the first 2 levels were covered by the unrolled part)
        stride = 2 ** (num_steps - 2)
        a_lvl = A[:, :, stride - 1:L:stride]
        x_lvl = X[:, :, stride - 1:L:stride]
        x_lvl[:, :, 2].add_(a_lvl[:, :, 2].mul(x_lvl[:, :, 1]))
        a_lvl[:, :, 2].mul_(a_lvl[:, :, 1])

        for k in reversed(range(num_steps - 2)):
            stride = 2 ** k
            a_lvl = A[:, :, stride - 1:L:stride]
            x_lvl = X[:, :, stride - 1:L:stride]

            half = x_lvl.size(2) // 2
            a_pairs = a_lvl.view(B, D, half, 2, -1)
            x_pairs = x_lvl.view(B, D, half, 2, -1)

            # the left node of each pair (except the first pair) receives
            # the right node of the previous pair
            x_pairs[:, :, 1:, 0].add_(a_pairs[:, :, 1:, 0].mul(x_pairs[:, :, :-1, 1]))
            a_pairs[:, :, 1:, 0].mul_(a_pairs[:, :, :-1, 1])

    @staticmethod
    def pscan_rev(A, X):
        """
        In-place parallel scan running from the end of the sequence.

        Args:
            A : (B, D, L, N)
            X : (B, D, L, N)

        Same as pscan, but in reverse: flipping the inputs along the length
        dimension, calling pscan and flipping the result back gives what this
        function writes into X. A is overwritten as well.
        It is used in the backward pass.
        Only supports L that is a power of two (mainly for clearer code).
        """
        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # ---- up sweep (the last 2 levels are unrolled below)
        a_lvl, x_lvl = A, X
        for _ in range(num_steps - 2):
            half = x_lvl.size(2) // 2

            a_pairs = a_lvl.view(B, D, half, 2, -1)
            x_pairs = x_lvl.view(B, D, half, 2, -1)
            a_left, a_right = a_pairs[:, :, :, 0], a_pairs[:, :, :, 1]
            x_left, x_right = x_pairs[:, :, :, 0], x_pairs[:, :, :, 1]

            # fold the right node of each pair into the left one
            x_left.add_(a_left.mul(x_right))
            a_left.mul_(a_right)

            # the left nodes form the next (coarser) level
            a_lvl, x_lvl = a_left, x_left

        # only 4, 2 or 1 nodes are left at this point
        remaining = x_lvl.size(2)
        if remaining == 2:
            x_lvl[:, :, 0].add_(a_lvl[:, :, 0].mul(x_lvl[:, :, 1]))
            return
        if remaining != 4:
            return

        # these are views: writes through them are visible to each other
        a0, a1, a2, a3 = (a_lvl[:, :, i] for i in range(4))
        x0, x1, x2, x3 = (x_lvl[:, :, i] for i in range(4))
        x2.add_(a2.mul(x3))
        a2.mul_(a3)
        x0.add_(a0.mul(x1.add(a1.mul(x2))))

        # ---- down sweep (the first 2 levels were covered by the unrolled part)
        stride = 2 ** (num_steps - 2)
        a_lvl = A[:, :, :L:stride]
        x_lvl = X[:, :, :L:stride]
        x_lvl[:, :, 1].add_(a_lvl[:, :, 1].mul(x_lvl[:, :, 2]))
        a_lvl[:, :, 1].mul_(a_lvl[:, :, 2])

        for k in reversed(range(num_steps - 2)):
            stride = 2 ** k
            a_lvl = A[:, :, :L:stride]
            x_lvl = X[:, :, :L:stride]

            half = x_lvl.size(2) // 2
            a_pairs = a_lvl.view(B, D, half, 2, -1)
            x_pairs = x_lvl.view(B, D, half, 2, -1)

            # the right node of each pair (except the last pair) receives
            # the left node of the next pair
            x_pairs[:, :, :-1, 1].add_(a_pairs[:, :, :-1, 1].mul(x_pairs[:, :, 1:, 0]))
            a_pairs[:, :, :-1, 1].mul_(a_pairs[:, :, 1:, 0])

    @staticmethod
    def forward(ctx, A_in, X_in):
        """
        Applies the parallel scan operation, as defined in pscan.
        The inputs are left untouched: the scan runs on copies.
        If you can, prefer sequence lengths that are powers of two
        (other lengths are zero-padded up to the next power of two).

        Args:
            A_in : (B, L, D, N)
            X_in : (B, L, D, N)

        Returns:
            H : (B, L, D, N)
        """
        L = X_in.size(1)

        # the scan works in place, so it must run on copies of the inputs
        if L != npo2(L):
            # padding already produces new tensors
            A = pad_npo2(A_in)  # (B, npo2(L), D, N)
            X = pad_npo2(X_in)  # (B, npo2(L), D, N)
        else:
            A = A_in.clone()
            X = X_in.clone()

        # move the length dimension to position 2, as pscan expects
        A = A.transpose(2, 1)  # (B, D, npo2(L), N)
        X = X.transpose(2, 1)  # (B, D, npo2(L), N)

        # parallel scan (X now holds H)
        PScan.pscan(A, X)

        ctx.save_for_backward(A_in, X)

        # back to (B, npo2(L), D, N), then drop the padded positions (if any)
        return X.transpose(2, 1)[:, :L]

    @staticmethod
    def backward(ctx, grad_output_in):
        """
        Flows the gradient from the output to the inputs. Returns two tensors.

        Args:
            ctx : holds A_in : (B, L, D, N) and X : (B, D, npo2(L), N),
                  the scan result saved by forward
            grad_output_in : (B, L, D, N)

        Returns:
            gradA : (B, L, D, N), gradX : (B, L, D, N)
        """
        A_in, X = ctx.saved_tensors
        L = grad_output_in.size(1)

        # the reverse scan works in place, so it must run on copies
        if L != npo2(L):
            grad_h = pad_npo2(grad_output_in)  # (B, npo2(L), D, N)
            A_in = pad_npo2(A_in)  # (B, npo2(L), D, N)
        else:
            grad_h = grad_output_in.clone()
            # A_in needs no clone here: the pad below creates a new tensor

        # move the length dimension to position 2, as pscan_rev expects
        grad_h = grad_h.transpose(2, 1)
        A_in = A_in.transpose(2, 1)  # (B, D, npo2(L), N)

        # shift A by one position to the left along the length dimension,
        # filling the last position with zeros
        a_shifted = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1))  # (B, D, npo2(L), N)

        # reverse parallel scan (grad_h is modified in place)
        PScan.pscan_rev(a_shifted, grad_h)

        # gradient w.r.t. A at position t is H[t-1] * grad_h[t] (0 at t = 0)
        grad_a = torch.zeros_like(X)
        grad_a[:, :, 1:].add_(X[:, :, :-1] * grad_h[:, :, 1:])

        return grad_a.transpose(2, 1)[:, :L], grad_h.transpose(2, 1)[:, :L]
# Functional entry point: pscan(A, X) runs PScan through autograd.
# A, X : (B, L, D, N) -> H : (B, L, D, N) (see PScan.forward)
pscan = PScan.apply
|