import torch


class InPlaceSetSlice(torch.autograd.Function):
    """Writes `x_val` into `full_tensor[x_idx]` in place and returns a view of the
    filled prefix `full_tensor[:x_idx + 1]`, with a custom backward that routes
    gradients to the previous slice and to the newly written entry."""

    @staticmethod
    def forward(ctx, full_tensor, last_slice, x_idx, x_val):
        # `last_slice` is unused here; it is passed only so that autograd links this
        # op to the previous slice and can propagate gradients through it.
        full_tensor[x_idx] = x_val
        ctx.x_idx = x_idx
        # Return a tensor sharing storage with the filled prefix; the dtype must
        # match the accumulator's dtype rather than being hardcoded.
        ret = torch.Tensor().to(device=full_tensor.device, dtype=full_tensor.dtype)
        ret.set_(full_tensor[:x_idx + 1])
        return ret

    @staticmethod
    def backward(ctx, grad_out):
        if ctx.x_idx == 0:
            # No previous slice: the whole gradient goes to the written value.
            return None, None, None, grad_out[ctx.x_idx]
        else:
            # Split the gradient between the previous slice and the new entry.
            return None, grad_out[:ctx.x_idx], None, grad_out[ctx.x_idx]


def apply_inplace_set(x_acc, x_idx, x_val):
    full_tensor, last_slice = x_acc
    new_slice = InPlaceSetSlice.apply(full_tensor, last_slice, x_idx, x_val)
    return full_tensor, new_slice


class DWAModules(torch.nn.Module):
    """Depth-weighted-average modules: after selected blocks, the output is replaced
    by a learned weighted average of the current and previously stored block outputs
    (sub-sampled by `dilation`, applied every `period` blocks)."""

    def __init__(self, n_blocks, dilation=1, period=1):
        super().__init__()
        self.n_blocks = n_blocks
        self.dilation = dilation
        self.period = period
        # One weight vector per DWA step; its entries mix the stored outputs.
        self.alphas = torch.nn.ModuleList([
            torch.nn.Linear((i + 1 + dilation) // dilation, 1, bias=False)
            if (i + 1) % period == 0 else None
            for i in range(n_blocks)
        ])
        self.accumulators = None
        self._init_weights()

    def _init_weights(self):
        # Initialize each DWA to the identity: weight 1 on the most recent output,
        # 0 on all earlier ones.
        for module in self.alphas:
            if module is not None:
                module.weight.data.zero_()
                module.weight.data[0, -1] = 1.

    def init_accumulators(self, x):
        # Allocate one accumulator per dilation group and store the initial
        # representation (e.g. the embedding output) in slot 0 of group 0.
        x_accs = []
        for i in range(self.dilation):
            current_group_size = (self.n_blocks + 1) // self.dilation
            if i < (self.n_blocks + 1) % self.dilation:
                current_group_size += 1
            x_accs.append((torch.zeros((current_group_size, *x.shape),
                                       device=x.device, dtype=x.dtype), None))
        x_accs[0] = apply_inplace_set(x_accs[0], 0, x)
        self.accumulators = x_accs

    def forward(self, x, block_idx):
        assert self.accumulators is not None, "`init_accumulators(x)` needs to be called first"
        # Store the output of block `block_idx` in its dilation group.
        self.accumulators[(block_idx + 1) % self.dilation] = apply_inplace_set(
            self.accumulators[(block_idx + 1) % self.dilation],
            (block_idx + 1) // self.dilation,
            x,
        )
        # Every `period` blocks, replace x by the weighted average of the stored
        # outputs in the current group.
        if (block_idx + 1) % self.period == 0:
            x = torch.tensordot(
                self.alphas[block_idx].weight.view(-1),
                self.accumulators[(block_idx + 1) % self.dilation][1],
                dims=1,
            )
        return x
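

# --- Usage sketch (not part of the module above) ---
# A minimal, hypothetical example of how DWAModules is meant to be driven:
# `init_accumulators` is called once per forward pass on the initial
# representation, and the module is called after every block. The toy Linear
# "blocks" and the tensor shapes below are illustrative assumptions, not the
# real model.
if __name__ == "__main__":
    torch.manual_seed(0)
    n_blocks, batch, seq, dim = 4, 2, 8, 16

    dwa = DWAModules(n_blocks, dilation=1, period=1)
    blocks = torch.nn.ModuleList([torch.nn.Linear(dim, dim) for _ in range(n_blocks)])

    x = torch.randn(batch, seq, dim)
    dwa.init_accumulators(x)      # stores the initial representation in slot 0
    for i, block in enumerate(blocks):
        x = block(x)              # a real transformer block would go here
        x = dwa(x, block_idx=i)   # depth-weighted average of stored outputs
    print(x.shape)                # torch.Size([2, 8, 16])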