import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaptiveMask(nn.Module):
    """Soft masking function for adaptive size.
    It gradually masks out all but roughly the last K values of an
    input along its last dimension. The masking value goes from 1 to 0
    over a ramp, so K can be learned with back-propagation.

    Args:
        max_size: maximum size (i.e. input dimension)
        ramp_size: size of the ramp going from 0 to 1
        init_val: initial size proportion not to be masked out
        shape: learn multiple sizes independent of each other
    """

    def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)):
        nn.Module.__init__(self)
        self._max_size = max_size
        self._ramp_size = ramp_size
        self.current_val = nn.Parameter(torch.zeros(*shape) + init_val)
        mask_template = torch.linspace(1 - max_size, 0, steps=max_size)
        self.register_buffer('mask_template', mask_template)

    def forward(self, x):
        mask = self.mask_template + self.current_val * self._max_size
        mask = mask / self._ramp_size + 1
        mask = mask.clamp(0, 1)
        if x.size(-1) < self._max_size:
            # the input could have been trimmed beforehand to save computation
            mask = mask[:, :, -x.size(-1):]
        x = x * mask
        return x

    def get_current_max_size(self, include_ramp=True):
        current_size = math.ceil(self.current_val.max().item() * self._max_size)
        if include_ramp:
            current_size += self._ramp_size
        current_size = max(0, min(self._max_size, current_size))
        return current_size

    def get_current_avg_size(self, include_ramp=True):
        current_size = math.ceil(self.current_val.mean().item() * self._max_size)
        if include_ramp:
            current_size += self._ramp_size
        current_size = max(0, min(self._max_size, current_size))
        return current_size

    def clamp_param(self):
        """this needs to be called after each parameter update"""
        self.current_val.data.clamp_(0, 1)


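# A minimal sketch of the mask in action, not part of the original module.
# The helper name `_demo_adaptive_mask` and all numbers below are
# illustrative assumptions: with max_size=5, ramp_size=2 and init_val=0.6,
# the last positions stay fully visible while earlier ones ramp toward 0.
def _demo_adaptive_mask():
    mask_fn = AdaptiveMask(max_size=5, ramp_size=2, init_val=0.6)
    x = torch.ones(1, 1, 5)
    return mask_fn(x)  # approximately tensor([[[0.5, 1., 1., 1., 1.]]])

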
class AdaptiveSpan(nn.Module):
    """Adaptive attention span for Transformers.
    This module learns an attention span length from data for each
    self-attention head.

    Args:
        attn_span: maximum attention span
        adapt_span_loss: loss coefficient for the span length
        adapt_span_ramp: length of the masking ramp
        adapt_span_init: initial size ratio
        adapt_span_cache: adapt cache size to reduce memory usage
        nb_heads: number of attention heads
    """
    def __init__(self, attn_span, adapt_span_loss, adapt_span_ramp,
                 adapt_span_init, adapt_span_cache, nb_heads, **kargs):
        nn.Module.__init__(self)
        self._adapt_cache = adapt_span_cache
        self._max_span = attn_span
        self._loss_coeff = adapt_span_loss
        self._nb_heads = nb_heads
        self._mask = AdaptiveMask(max_size=self._max_span,
                                  ramp_size=adapt_span_ramp,
                                  init_val=adapt_span_init,
                                  shape=(nb_heads, 1, 1))

    def forward(self, attn, normalize=True):
        """mask attention with the right span"""
        # batch and head dimensions arrive merged, so separate them first
        B = attn.size(0)  # batch_size * nb_heads
        M = attn.size(1)  # block size
        attn = attn.reshape(B // self._nb_heads, self._nb_heads, M, -1)

        attn = self._mask(attn)
        if normalize:
            # renormalize so the masked weights still sum to 1
            attn = attn / (attn.sum(-1, keepdim=True) + 1e-8)

        attn = attn.view(B, M, -1)
        return attn

    def get_trim_len(self):
        """how much of the memory can be trimmed to reduce computation"""
        L = self._max_span
        trim_len = min(L - 1, L - self._mask.get_current_max_size())
        # trim in chunks of 64 as too fine a granularity hurts memory management
        trim_len = math.floor(trim_len / 64) * 64
        return trim_len

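    # Worked example with illustrative numbers (not from the original code):
    # with _max_span=512 and get_current_max_size() == 130, trim_len is
    # min(511, 512 - 130) = 382, floored to a multiple of 64, i.e. 320,
    # so only 512 - 320 = 192 memory positions need to be kept.
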
    def trim_memory(self, query, key, value, key_pe):
        """trim out unnecessary memory beforehand to reduce computation"""
        trim_len = self.get_trim_len()
        cache_size = key.size(1) - query.size(1)
        trim_len_cache = trim_len - (self._max_span - cache_size)
        if trim_len_cache > 0:
            key = key[:, trim_len_cache:, :]
            value = value[:, trim_len_cache:, :]
        elif trim_len_cache < 0:
            # the cache is shorter than the current span needs,
            # so pad it with zeros on the left
            key = F.pad(key, [0, 0, -trim_len_cache, 0])
            value = F.pad(value, [0, 0, -trim_len_cache, 0])
        if trim_len > 0:
            if key_pe is not None:
                key_pe = key_pe[:, :, trim_len:]
        return key, value, key_pe

    def get_cache_size(self):
        """determine how long the cache should be"""
        if self._adapt_cache:
            trim_len = self.get_trim_len()
            # keep a 64-step buffer since the span might grow
            # in future updates
            return min(self._max_span, self._max_span - trim_len + 64)
        else:
            return self._max_span

    def get_loss(self):
        """a loss term for regularizing the span length"""
        return self._loss_coeff * self._max_span * self._mask.current_val.mean()

    def get_current_max_span(self):
        return self._mask.get_current_max_size()

    def get_current_avg_span(self):
        return self._mask.get_current_avg_size()

    def clamp_param(self):
        self._mask.clamp_param()
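

# A minimal usage sketch, not part of the original module. It assumes that
# attention weights arrive with the batch and head dimensions merged into
# dim 0, as this module's forward expects; all hyper-parameter values below
# are illustrative.
if __name__ == '__main__':
    adaptive_span = AdaptiveSpan(attn_span=8, adapt_span_loss=1e-6,
                                 adapt_span_ramp=4, adapt_span_init=0.5,
                                 adapt_span_cache=True, nb_heads=2)
    B, H, M = 3, 2, 4  # batch size, heads, block size
    attn = torch.softmax(torch.randn(B * H, M, 8), dim=-1)
    attn = adaptive_span(attn)              # mask and renormalize the weights
    span_loss = adaptive_span.get_loss()    # add this to the training loss
    adaptive_span.clamp_param()             # call after each optimizer step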