# mistral-denseformer-7b / denseformer.py
# winglian's picture
# Update denseformer.py
# 4d34a45 verified
import torch
class InPlaceSetSlice(torch.autograd.Function):
    """Write one row into a preallocated buffer and return the filled prefix.

    ``forward`` stores ``x_val`` at ``full_tensor[x_idx]`` in place and
    returns a tensor aliasing ``full_tensor[:x_idx + 1]``.  ``backward``
    splits the incoming gradient between the previously returned prefix
    (``last_slice``) and the newly written row (``x_val``).
    """

    @staticmethod
    def forward(ctx, full_tensor, last_slice, x_idx, x_val):
        """Store ``x_val`` at row ``x_idx`` and return rows ``0..x_idx``.

        Args:
            ctx: Autograd context; only ``x_idx`` is stashed on it.
            full_tensor: Preallocated buffer that is modified in place.
            last_slice: Not read here; it is an input only so that
                ``backward`` can return a gradient for it.
            x_idx: Integer row index to write.
            x_val: Value assigned to ``full_tensor[x_idx]``.

        Returns:
            A tensor with the same device and dtype as ``full_tensor`` that
            shares storage with ``full_tensor[:x_idx + 1]``.
        """
        full_tensor[x_idx] = x_val
        ctx.x_idx = x_idx
        # The alias must have the buffer's own dtype: `set_` re-points the
        # tensor at the source's storage, so a hard-coded dtype (previously
        # bfloat16) only works when the buffer happens to use that dtype.
        ret = torch.empty(0, device=full_tensor.device, dtype=full_tensor.dtype)
        # NOTE(review): `set_` is presumably used instead of returning the
        # slice directly so the result is not tracked as a view of
        # `full_tensor` by autograd -- confirm before changing.
        ret.set_(full_tensor[: x_idx + 1])
        return ret

    @staticmethod
    def backward(ctx, grad_out):
        """Route ``grad_out`` to ``last_slice`` and ``x_val``.

        Returns one entry per ``forward`` input
        ``(full_tensor, last_slice, x_idx, x_val)``; ``full_tensor`` and
        ``x_idx`` never receive a gradient.
        """
        if ctx.x_idx == 0:
            # There are no earlier rows, so nothing flows to `last_slice`.
            return None, None, None, grad_out[ctx.x_idx]
        return None, grad_out[: ctx.x_idx], None, grad_out[ctx.x_idx]
def apply_inplace_set(x_acc, x_idx, x_val):
    """Write ``x_val`` into an accumulator pair and return the updated pair.

    Args:
        x_acc: Two-item sequence ``(buffer, previous_slice)``.
        x_idx: Row of ``buffer`` to write.
        x_val: Value to store at that row.

    Returns:
        ``(buffer, new_slice)`` where ``new_slice`` is whatever
        ``InPlaceSetSlice.apply`` returns for these arguments.
    """
    buffer, previous_slice = x_acc
    return buffer, InPlaceSetSlice.apply(buffer, previous_slice, x_idx, x_val)
class DWAModules(torch.nn.Module):
    """Learned per-block mixing of the current and earlier block outputs.

    Every call to ``forward`` records the block output in one of
    ``dilation`` accumulator buffers.  For blocks where
    ``(block_idx + 1) % period == 0`` the recorded rows of that buffer are
    combined with that block's ``alphas`` weights; other blocks return
    their input unchanged.
    """

    def __init__(self, n_blocks, dilation=1, period=1):
        """Create one bias-free mixing layer per block that is mixed.

        Args:
            n_blocks: Number of blocks whose outputs are recorded.
            dilation: Number of accumulator groups; outputs are assigned
                to groups round-robin by ``(block_idx + 1) % dilation``.
            period: Mixing happens only when
                ``(block_idx + 1) % period == 0``.
        """
        super().__init__()
        self.n_blocks = n_blocks
        self.dilation = dilation
        self.period = period

        mixers = []
        for block in range(n_blocks):
            depth = block + 1
            if depth % period == 0:
                # One input weight per row stored in this block's group.
                n_rows = (depth + dilation) // dilation
                mixers.append(torch.nn.Linear(n_rows, 1, bias=False))
            else:
                # Placeholder keeps ModuleList indices aligned with block_idx.
                mixers.append(None)
        self.alphas = torch.nn.ModuleList(mixers)

        self.accumulators = None
        self._init_weights()

    def _init_weights(self):
        """Zero every mixing weight, then set the last weight to 1."""
        for mixer in self.alphas:
            if mixer is None:
                continue
            weights = mixer.weight.data
            weights.zero_()
            weights[0, -1] = 1.

    def _group_size(self, group):
        """Return how many of the ``n_blocks + 1`` rows belong to ``group``."""
        n_rows, leftover = divmod(self.n_blocks + 1, self.dilation)
        # The first `leftover` groups each take one extra row.
        return n_rows + 1 if group < leftover else n_rows

    def init_accumulators(self, x):
        """Allocate zeroed buffers and store ``x`` as row 0 of group 0.

        Each buffer has shape ``(group_size, *x.shape)`` on ``x``'s device
        and dtype.  Must be called before ``forward``.
        """
        buffers = [
            (
                torch.zeros(
                    (self._group_size(group), *x.shape),
                    device=x.device,
                    dtype=x.dtype,
                ),
                None,
            )
            for group in range(self.dilation)
        ]
        buffers[0] = apply_inplace_set(buffers[0], 0, x)
        self.accumulators = buffers

    def forward(self, x, block_idx):
        """Record ``x`` for block ``block_idx`` and mix if it is a mixing block.

        Returns:
            ``x`` unchanged when ``(block_idx + 1) % period != 0``;
            otherwise the ``tensordot`` of this block's flattened weights
            with the rows recorded so far in its accumulator group.
        """
        assert self.accumulators is not None, "`init_accumulators(x)` needs to be called first"
        depth = block_idx + 1
        group = depth % self.dilation
        row = depth // self.dilation
        self.accumulators[group] = apply_inplace_set(self.accumulators[group], row, x)

        if depth % self.period != 0:
            return x

        coefficients = self.alphas[block_idx].weight.view(-1)
        recorded = self.accumulators[group][1]
        return torch.tensordot(coefficients, recorded, dims=1)