import math

import torch
import torch.nn.functional as F

"""

An implementation of the parallel scan operation in PyTorch (Blelloch version).
Please see docs/pscan.ipynb for a detailed explanation of what happens here.

"""

def npo2(len):
    """
    Returns the next power of 2 above len
    """

    return 2 ** math.ceil(math.log2(len))

def pad_npo2(X):
    """
    Pads the input's length dim to the next power of 2

    Args:
        X : (B, L, D, N)

    Returns:
        Y : (B, npo2(L), D, N)
    """

    len_npo2 = npo2(X.size(1))
    pad_tuple = (0, 0, 0, 0, 0, len_npo2 - X.size(1))
    return F.pad(X, pad_tuple, "constant", 0)

class PScan(torch.autograd.Function):
    @staticmethod
    def pscan(A, X):
        # A : (B, D, L, N)
        # X : (B, D, L, N)

        # modifies X in place by doing a parallel scan.
        # more formally, X will be populated by these values :
        # H[t] = A[t] * H[t-1] + X[t], with the state before t=0 taken to be 0 (so H[0] = X[0])
        # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)

        # only supports L that is a power of two (mainly to keep the code clear)

        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # up sweep (last 2 steps unfolded)
        Aa = A
        Xa = X
        for _ in range(num_steps-2):
            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
            Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])

            Aa = Aa[:, :, :, 1]
            Xa = Xa[:, :, :, 1]

        # we have only 4, 2 or 1 nodes left
        if Xa.size(2) == 4:
            Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
            Aa[:, :, 1].mul_(Aa[:, :, 0])

            Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
        elif Xa.size(2) == 2:
            Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
            return
        else:
            return

        # down sweep (first 2 steps unfolded)
        Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
        Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
        Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
        Aa[:, :, 2].mul_(Aa[:, :, 1])

        for k in range(num_steps-3, -1, -1):
            Aa = A[:, :, 2**k-1:L:2**k]
            Xa = X[:, :, 2**k-1:L:2**k]

            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
            Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])

    @staticmethod
    def pscan_rev(A, X):
        # A : (B, D, L, N)
        # X : (B, D, L, N)

        # the same function as above, but in reverse
        # (if you flip the input, call pscan, then flip the output, you get what this function outputs)
        # it is used in the backward pass

        # only supports L that is a power of two (mainly to keep the code clear)

        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # up sweep (last 2 steps unfolded)
        Aa = A
        Xa = X
        for _ in range(num_steps-2):
            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
            Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])

            Aa = Aa[:, :, :, 0]
            Xa = Xa[:, :, :, 0]

        # we have only 4, 2 or 1 nodes left
        if Xa.size(2) == 4:
            Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
            Aa[:, :, 2].mul_(Aa[:, :, 3])

            Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2]))))
        elif Xa.size(2) == 2:
            Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
            return
        else:
            return

        # down sweep (first 2 steps unfolded)
        Aa = A[:, :, 0:L:2**(num_steps-2)]
        Xa = X[:, :, 0:L:2**(num_steps-2)]
        Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
        Aa[:, :, 1].mul_(Aa[:, :, 2])

        for k in range(num_steps-3, -1, -1):
            Aa = A[:, :, 0:L:2**k]
            Xa = X[:, :, 0:L:2**k]

            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
            Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])
    @staticmethod
    def forward(ctx, A_in, X_in):
        """
        Applies the parallel scan operation, as defined above. Returns a new tensor.
        If you can, prefer sequence lengths that are powers of two.

        Args:
            A_in : (B, L, D, N)
            X_in : (B, L, D, N)

        Returns:
            H : (B, L, D, N)
        """

        L = X_in.size(1)

        # cloning is required because of the in-place ops
        if L == npo2(L):
            A = A_in.clone()
            X = X_in.clone()
        else:
            # pad tensors (padding also copies, so no explicit clone is needed)
            A = pad_npo2(A_in) # (B, npo2(L), D, N)
            X = pad_npo2(X_in) # (B, npo2(L), D, N)

        # prepare tensors
        A = A.transpose(2, 1) # (B, D, npo2(L), N)
        X = X.transpose(2, 1) # (B, D, npo2(L), N)

        # parallel scan (modifies X in-place)
        PScan.pscan(A, X)

        ctx.save_for_backward(A_in, X)

        # slice [:, :L] (cut if there was padding)
        return X.transpose(2, 1)[:, :L]

    @staticmethod
    def backward(ctx, grad_output_in):
        """
        Propagates the gradient from the output to the inputs. Returns two new tensors.

        Args:
            ctx : A_in : (B, L, D, N), X : (B, D, L, N)
            grad_output_in : (B, L, D, N)

        Returns:
            gradA : (B, L, D, N), gradX : (B, L, D, N)
        """

        A_in, X = ctx.saved_tensors

        L = grad_output_in.size(1)

        # cloning is required because of the in-place ops
        if L == npo2(L):
            grad_output = grad_output_in.clone()
            # A_in needs no clone here: the F.pad below already creates a new tensor
        else:
            grad_output = pad_npo2(grad_output_in) # (B, npo2(L), D, N)
            A_in = pad_npo2(A_in) # (B, npo2(L), D, N)

        # prepare tensors
        grad_output = grad_output.transpose(2, 1)
        A_in = A_in.transpose(2, 1) # (B, D, npo2(L), N)
        A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation)

        # reverse parallel scan (modifies grad_output in-place)
        PScan.pscan_rev(A, grad_output)

        Q = torch.zeros_like(X)
        Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])

        return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]

pscan = PScan.apply
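
# --- sanity check (illustrative sketch, not part of the module's API) ---------
# A minimal example of how `pscan` is meant to be called and what it computes:
# it compares the parallel scan against a naive sequential evaluation of
# H[t] = A[t] * H[t-1] + X[t] (state before t=0 taken to be 0). The shapes
# (B, L, D, N), the seed and the non-power-of-two L are illustrative assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)

    B, L, D, N = 2, 7, 3, 4  # L deliberately not a power of two, to exercise the padding path
    A = torch.randn(B, L, D, N)
    X = torch.randn(B, L, D, N)

    # parallel version
    H_parallel = pscan(A, X)  # (B, L, D, N)

    # naive sequential reference: H[t] = A[t] * H[t-1] + X[t]
    H_seq = torch.zeros_like(X)
    h = torch.zeros(B, D, N)
    for t in range(L):
        h = A[:, t] * h + X[:, t]
        H_seq[:, t] = h

    print("forward max abs difference:", (H_parallel - H_seq).abs().max().item())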
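
# --- backward sanity check (also an illustrative sketch, not part of the repo) --
# This compares the gradients produced by the custom backward (built on pscan_rev)
# with the gradients autograd derives from the naive sequential loop. The helper
# `_naive_scan` and the shapes used below are assumptions made for this check only.
def _naive_scan(A, X):
    # differentiable sequential reference: H[t] = A[t] * H[t-1] + X[t]
    hs = []
    h = torch.zeros_like(X[:, 0])
    for t in range(X.size(1)):
        h = A[:, t] * h + X[:, t]
        hs.append(h)
    return torch.stack(hs, dim=1)

if __name__ == "__main__":
    A = torch.randn(2, 8, 3, 4, requires_grad=True)
    X = torch.randn(2, 8, 3, 4, requires_grad=True)

    # gradients through the parallel scan
    pscan(A, X).sum().backward()
    gradA_parallel, gradX_parallel = A.grad.clone(), X.grad.clone()

    # gradients through the sequential reference
    A.grad, X.grad = None, None
    _naive_scan(A, X).sum().backward()

    print("gradA max abs difference:", (gradA_parallel - A.grad).abs().max().item())
    print("gradX max abs difference:", (gradX_parallel - X.grad).abs().max().item())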