|
import math |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
""" |
|
|
|
An implementation of the parallel scan operation in PyTorch (Blelloch version). |
|
Please see docs/pscan.ipynb for a detailed explanation of what happens here. |
|
|
|
""" |
|
|
|
def npo2(len):
    """
    Returns the next power of 2 above (or equal to) len.

    Args:
        len : int, a positive length

    Returns:
        int : the smallest power of two >= len
    """

    # Exact integer computation. The previous form, 2 ** math.ceil(math.log2(len)),
    # round-trips through floating point and can misround for very large lengths;
    # bit_length gives the same result exactly and avoids the float entirely.
    # (The parameter name shadows the builtin `len`; kept unchanged so existing
    # keyword callers are not broken.)
    return 1 << (len - 1).bit_length()
|
|
|
def pad_npo2(X):
    """
    Zero-pads the length dimension (dim 1) of X up to the next power of 2.

    Args:
        X : (B, L, D, N)

    Returns:
        Y : (B, npo2(L), D, N) — a new tensor; trailing positions are zero.
    """

    seq_len = X.size(1)
    # next power of two >= seq_len
    target_len = 2 ** math.ceil(math.log2(seq_len))
    # F.pad reads the pad tuple from the LAST dim backwards:
    # (N_left, N_right, D_left, D_right, L_left, L_right) —
    # only the right side of the L dimension grows.
    pad_tuple = (0, 0, 0, 0, 0, target_len - seq_len)
    return F.pad(X, pad_tuple, "constant", 0)
|
|
|
class PScan(torch.autograd.Function):
    """
    Parallel scan (Blelloch variant) as a differentiable autograd Function.

    Computes, in parallel over the length dimension, the first-order
    recurrence H[t] = A[t] * H[t-1] + X[t] (see docs/pscan.ipynb for the
    detailed derivation — NOTE(review): recurrence inferred from the code
    structure and the module header; confirm against the notebook).
    """

    @staticmethod
    def pscan(A, X):
        # A : (B, D, L, N)
        # X : (B, D, L, N)
        #
        # Modifies X IN PLACE: after the call, X[:, :, t] holds the scan
        # value H[t] combining all positions <= t, computed in O(log2(L))
        # sequential levels instead of L sequential steps.
        # NOTE(review): L is assumed to be a power of two here — `forward`
        # pads with pad_npo2 before calling; confirm there is no other caller.

        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # --- up sweep (reduction); the last 2 levels are unfolded below ---
        Aa = A
        Xa = X
        for _ in range(num_steps-2):
            T = Xa.size(2)
            # view the length dim as (T//2) pairs; each (left, right) pair is
            # combined into the right slot, in place through the view.
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
            Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])

            # keep only the combined (right) nodes for the next level
            Aa = Aa[:, :, :, 1]
            Xa = Xa[:, :, :, 1]

        # at this point only 4, 2 or 1 nodes are left (the unfolded levels)
        if Xa.size(2) == 4:
            Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
            Aa[:, :, 1].mul_(Aa[:, :, 0])

            # final root combine for the 4-node case
            Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
        elif Xa.size(2) == 2:
            # L == 2: a single combine finishes the scan, no down sweep needed
            Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
            return
        else:
            # L == 1: nothing to combine
            return

        # --- down sweep; the first 2 levels are unfolded here ---
        # strided slices re-select the tree nodes of the coarsest level
        Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
        Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
        Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
        Aa[:, :, 2].mul_(Aa[:, :, 1])

        for k in range(num_steps-3, -1, -1):
            # nodes of level k of the scan tree
            Aa = A[:, :, 2**k-1:L:2**k]
            Xa = X[:, :, 2**k-1:L:2**k]

            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            # propagate each completed right-sibling prefix into the next
            # pair's left slot
            Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
            Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])

    @staticmethod
    def pscan_rev(A, X):
        # A : (B, D, L, N)
        # X : (B, D, L, N)
        #
        # Right-to-left (reversed) counterpart of `pscan`, used by the
        # backward pass; modifies X IN PLACE. Mirrors `pscan` with the pair
        # roles swapped (each pair is combined into the LEFT slot).
        # NOTE(review): L assumed to be a power of two, as in `pscan`.

        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # --- up sweep (last 2 levels unfolded below) ---
        Aa = A
        Xa = X
        for _ in range(num_steps-2):
            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            # combine each (left, right) pair into the LEFT slot
            Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
            Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])

            Aa = Aa[:, :, :, 0]
            Xa = Xa[:, :, :, 0]

        # only 4, 2 or 1 nodes left
        if Xa.size(2) == 4:
            Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
            Aa[:, :, 2].mul_(Aa[:, :, 3])

            # final root combine for the 4-node case (mirror of pscan)
            Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2]))))
        elif Xa.size(2) == 2:
            Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
            return
        else:
            return

        # --- down sweep (first 2 levels unfolded) ---
        Aa = A[:, :, 0:L:2**(num_steps-2)]
        Xa = X[:, :, 0:L:2**(num_steps-2)]
        Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
        Aa[:, :, 1].mul_(Aa[:, :, 2])

        for k in range(num_steps-3, -1, -1):
            Aa = A[:, :, 0:L:2**k]
            Xa = X[:, :, 0:L:2**k]

            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            # propagate completed left-sibling suffixes into the right slots
            Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
            Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])

    @staticmethod
    def forward(ctx, A_in, X_in):
        """
        Applies the parallel scan operation, as defined above. Returns a new tensor.
        If you can, privilege sequence lengths that are powers of two.

        Args:
            A_in : (B, L, D, N)
            X_in : (B, L, D, N)

        Returns:
            H : (B, L, D, N)
        """

        L = X_in.size(1)

        # clone (or pad, which also copies) because pscan works in place and
        # must not clobber the caller's tensors
        if L == npo2(L):
            A = A_in.clone()
            X = X_in.clone()
        else:
            # pad the length dim up to a power of two (pad_npo2 returns copies)
            A = pad_npo2(A_in)  # (B, npo2(L), D, N)
            X = pad_npo2(X_in)  # (B, npo2(L), D, N)

        # move the length dim next to the batch dims, as pscan expects
        A = A.transpose(2, 1)  # (B, D, npo2(L), N)
        X = X.transpose(2, 1)  # (B, D, npo2(L), N)

        # parallel scan: fills X in place with the scan values
        PScan.pscan(A, X)

        # save the UNPADDED A and the (padded, transposed) scan output for backward
        ctx.save_for_backward(A_in, X)

        # transpose back and cut off any padding
        return X.transpose(2, 1)[:, :L]

    @staticmethod
    def backward(ctx, grad_output_in):
        """
        Flows the gradient from the output to the input. Returns two new tensors.

        Args:
            ctx : A_in : (B, L, D, N), X : (B, D, L, N)
            grad_output_in : (B, L, D, N)

        Returns:
            gradA : (B, L, D, N), gradX : (B, L, D, N)
        """

        A_in, X = ctx.saved_tensors

        L = grad_output_in.size(1)

        # clone/pad for the same reason as in forward: pscan_rev is in place
        if L == npo2(L):
            grad_output = grad_output_in.clone()
            # A_in is cloned below via pad in the else-branch only; here the
            # shifted copy built by F.pad keeps A_in itself untouched
        else:
            grad_output = pad_npo2(grad_output_in)  # (B, npo2(L), D, N)
            A_in = pad_npo2(A_in)                   # (B, npo2(L), D, N)

        # prepare tensors: (B, D, npo2(L), N) layout for the scan
        grad_output = grad_output.transpose(2, 1)
        A_in = A_in.transpose(2, 1)
        # shift A one step left along L, zero-filling the last position
        # (the gradient w.r.t. H[t] flows through A[t+1])
        A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1))

        # reversed scan accumulates the incoming gradients right-to-left,
        # in place in grad_output (which becomes gradX)
        PScan.pscan_rev(A, grad_output)

        # gradA[t] = H[t-1] * accumulated_grad[t]  (H[-1] = 0, hence the shift)
        Q = torch.zeros_like(X)
        Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])

        # transpose back and cut off any padding
        return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]
|
|
|
# Public entry point: pscan(A_in, X_in) applies the scan with autograd support.
pscan = PScan.apply
|
|