import torch


class InPlaceSetSlice(torch.autograd.Function):
  """Writes `x_val` into `full_tensor[x_idx]` in place and returns an
  autograd-aware view of `full_tensor[:x_idx + 1]`.

  `last_slice` (the view returned by the previous call) is only passed so
  that gradients can be routed back through the chain of slices in backward.
  """

  @staticmethod
  def forward(ctx, full_tensor, last_slice, x_idx, x_val):
    full_tensor[x_idx] = x_val
    ctx.x_idx = x_idx
    # Return a view over the filled prefix of the buffer, matching the
    # buffer's device and dtype.
    ret = torch.Tensor().to(device=full_tensor.device, dtype=full_tensor.dtype)
    ret.set_(full_tensor[:x_idx + 1])
    return ret

  @staticmethod
  def backward(ctx, grad_out):
    # No gradient for the buffer or the index; split grad_out between the
    # previously filled prefix (last_slice) and the newly written value (x_val).
    if ctx.x_idx == 0:
      return None, None, None, grad_out[ctx.x_idx]
    else:
      return None, grad_out[:ctx.x_idx], None, grad_out[ctx.x_idx]


def apply_inplace_set(x_acc, x_idx, x_val):
  """Stores `x_val` at slot `x_idx` of the accumulator `x_acc` and returns
  the updated (buffer, growing slice) pair."""
  full_tensor, last_slice = x_acc
  new_slice = InPlaceSetSlice.apply(full_tensor, last_slice, x_idx, x_val)
  return full_tensor, new_slice

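def _inplace_set_example():
  # Illustrative sketch, not part of the original module: builds up a growing
  # autograd-aware view with apply_inplace_set and checks that gradients reach
  # every stored tensor. The buffer shape (3, 2, 4) is an arbitrary assumption.
  buf = torch.zeros((3, 2, 4))
  acc = (buf, None)
  xs = [torch.randn(2, 4, requires_grad=True) for _ in range(3)]
  for i, x in enumerate(xs):
    acc = apply_inplace_set(acc, i, x)
  _, last_slice = acc  # view over buf[:3]
  last_slice.mean().backward()
  assert all(x.grad is not None for x in xs)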

class DWAModules(torch.nn.Module):
  """Depth-weighted-average (DWA) modules: accumulate block outputs and,
  every `period` blocks, mix them with learned per-block weights."""

  def __init__(self, n_blocks, dilation=1, period=1):
    super().__init__()
    self.n_blocks = n_blocks
    self.dilation = dilation
    self.period = period
    # One linear mixer per block where averaging happens ((i+1) % period == 0);
    # block i mixes the (i+1+dilation)//dilation slices stored in its dilation group.
    self.alphas = torch.nn.ModuleList([
        torch.nn.Linear((i+1+dilation)//dilation, 1, bias=False)
        if (i+1) % period == 0 else None
        for i in range(n_blocks)
    ])
    self.accumulators = None
    self._init_weights()

  def _init_weights(self):
    # Initialize each mixer to select only the most recent slice, so the
    # module starts out as an identity on the current block output.
    for module in self.alphas:
      if module is not None:
        module.weight.data.zero_()
        module.weight.data[0, -1] = 1.

  def init_accumulators(self, x):
    # Allocate one buffer per dilation group; together the groups hold the
    # initial embedding plus the n_blocks block outputs, split as evenly as possible.
    x_accs = []
    for i in range(self.dilation):
      current_group_size = (self.n_blocks + 1) // self.dilation
      if i < (self.n_blocks + 1) % self.dilation:
        current_group_size += 1
      x_accs.append((torch.zeros((current_group_size, *x.shape), device=x.device, dtype=x.dtype), None))
    # The initial embedding x goes into slot 0 of the first group.
    x_accs[0] = apply_inplace_set(x_accs[0], 0, x)
    self.accumulators = x_accs

  def forward(self, x, block_idx):
    assert self.accumulators is not None, "`init_accumulators(x)` needs to be called first"
    # Store the output of block `block_idx` in its dilation group.
    self.accumulators[(block_idx+1) % self.dilation] = apply_inplace_set(
        self.accumulators[(block_idx+1) % self.dilation],
        (block_idx+1) // self.dilation,
        x,
    )
    # Every `period` blocks, replace x with the learned weighted average of
    # the slices accumulated in this group so far.
    if (block_idx+1) % self.period == 0:
      x = torch.tensordot(
          self.alphas[block_idx].weight.view(-1),
          self.accumulators[(block_idx+1) % self.dilation][1],
          dims=1,
      )
    return x
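

if __name__ == "__main__":
  # Minimal usage sketch (an assumption, not part of the original file): wires
  # DWAModules into a toy stack of linear "blocks". The dimensions and the
  # block definition are placeholders.
  n_blocks, batch, seq, dim = 4, 2, 8, 16
  blocks = torch.nn.ModuleList([torch.nn.Linear(dim, dim) for _ in range(n_blocks)])
  dwa = DWAModules(n_blocks, dilation=1, period=1)

  x = torch.randn(batch, seq, dim)
  dwa.init_accumulators(x)
  for block_idx, block in enumerate(blocks):
    x = block(x)
    # Store the block output; every `period` blocks x becomes the learned
    # weighted average of all slices accumulated so far.
    x = dwa(x, block_idx)
  print(x.shape)  # torch.Size([2, 8, 16])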