|
import math |
|
from math import gcd |
|
import functools |
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn, einsum |
|
|
|
from einops import rearrange, reduce, repeat |
|
from einops.layers.torch import Rearrange |
|
from transformers.modeling_utils import PreTrainedModel |
|
|
|
|
|
def exists(val):
    """Tell whether ``val`` holds a value, i.e. is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    if val is None:
        return False
    return True
|
|
|
|
|
def lcm(*numbers):
    """Return the least common multiple of the given integers.

    Called with no arguments it returns 1 (the neutral element of the fold).

    The computation stays in exact integer arithmetic: floor division is
    exact here because ``gcd(a, b)`` always divides ``a * b``, whereas a
    float division would silently lose precision for large operands.

    Raises:
        ZeroDivisionError: if the running result and the next number are both
            zero (e.g. ``lcm(0, 0)``), since ``gcd(0, 0) == 0``.
    """
    result = 1
    for number in numbers:
        result = result * number // gcd(result, number)
    return result
|
|
|
|
|
def masked_mean(tensor, mask, dim = -1):
    """Average ``tensor`` along ``dim`` over the positions where ``mask`` is True.

    Args:
        tensor: values to average.
        mask: boolean tensor covering the leading dimensions of ``tensor``;
            it is right-padded with singleton axes until both have the same
            number of dimensions, then broadcast against ``tensor``.
        dim: dimension to reduce over.

    Returns:
        The masked mean, with ``dim`` removed. Slices that contain no unmasked
        element yield 0 instead of NaN.

    The input ``tensor`` is left untouched: masked positions are zeroed on a
    copy, so callers passing a view (e.g. the result of a ``rearrange``) do
    not see their data modified.
    """
    diff_len = tensor.dim() - mask.dim()
    # append trailing singleton axes so the mask broadcasts over the extra dims
    mask = mask[(..., *((None,) * diff_len))]

    # out-of-place fill: do not clobber the caller's tensor
    zeroed = tensor.masked_fill(~mask, 0.)

    total_el = mask.sum(dim = dim)
    # clamp the denominator so fully masked slices do not divide by zero
    mean = zeroed.sum(dim = dim) / total_el.clamp(min = 1.)
    return mean.masked_fill(total_el == 0, 0.)
|
|
|
|
|
def next_divisible_length(seqlen, multiple):
    """Round ``seqlen`` up to the nearest multiple of ``multiple``.

    Uses integer ceiling division (``-(-a // b)``) rather than a float
    division followed by ``math.ceil``, so the result is exact for
    arbitrarily large integers.

    Raises:
        ZeroDivisionError: if ``multiple`` is 0.
    """
    num_chunks = -((-seqlen) // multiple)
    return num_chunks * multiple
|
|
|
|
|
def pad_to_multiple(tensor, multiple, *, seq_dim, dim = -1, value = 0.):
    """Right-pad ``tensor`` so its length along the sequence axis is a multiple of ``multiple``.

    Args:
        tensor: tensor to pad.
        multiple: the padded length will be the smallest multiple of this
            value that is >= the current length.
        seq_dim: axis whose length is inspected.
        dim: axis that receives the padding. Callers are expected to pass the
            same axis as ``seq_dim``; both negative and non-negative indices
            are accepted.
        value: fill value for the padded positions.

    Returns:
        ``tensor`` itself (not a copy) when no padding is needed, otherwise a
        new padded tensor.
    """
    # number of elements missing to reach the next multiple (0 if already aligned)
    remainder = -tensor.shape[seq_dim] % multiple
    if remainder == 0:
        return tensor

    # F.pad counts axes from the last one, so express ``dim`` as a negative
    # index; a non-negative ``dim`` would otherwise yield an empty offset and
    # silently pad the last axis instead.
    if dim >= 0:
        dim -= tensor.dim()

    # one (left, right) = (0, 0) pair for every axis that comes after ``dim``
    pad_offset = (0,) * (-1 - dim) * 2
    return F.pad(tensor, (*pad_offset, 0, remainder), value = value)
|
|
|
|
|
|
|
class Pad(nn.Module):
    """Module wrapper around ``F.pad`` with a fixed padding spec and fill value.

    Lets a padding step be placed inside an ``nn.Sequential`` pipeline.

    Args:
        padding: padding tuple, forwarded unchanged to ``F.pad``.
        value: fill value used for the padded positions.
    """

    def __init__(self, padding, value = 0.):
        super(Pad, self).__init__()
        self.padding, self.value = padding, value

    def forward(self, x):
        """Return ``x`` padded according to the stored spec and fill value."""
        padded = F.pad(x, self.padding, value = self.value)
        return padded
|
|
|
|
|
class DepthwiseConv1d(nn.Module):
    """Grouped 1-d convolution followed by a kernel-size-1 projection.

    The first convolution uses ``groups = dim_in``; the second is a
    ``Conv1d`` with kernel size 1 mapping ``dim_out`` to ``dim_out``.

    Args:
        dim_in: number of input channels (also the number of groups).
        dim_out: number of output channels.
        kernel_size: kernel width of the grouped convolution.
    """

    def __init__(self, dim_in, dim_out, kernel_size):
        super(DepthwiseConv1d, self).__init__()
        # creation order is kept (conv, then proj_out) so that seeded
        # initialisation and state-dict layout stay the same
        grouped = nn.Conv1d(dim_in, dim_out, kernel_size, groups = dim_in)
        pointwise = nn.Conv1d(dim_out, dim_out, 1)
        self.conv = grouped
        self.proj_out = pointwise

    def forward(self, x):
        """Apply the grouped convolution, then the kernel-size-1 projection."""
        return self.proj_out(self.conv(x))
|
|
|
|
|
|
|
class GBST(PreTrainedModel):
    """Soft block-pooling embedding layer with learned block-size selection.

    (The name presumably stands for "gradient-based subword tokenization" as
    in Charformer -- TODO confirm.)

    Pipeline implemented by ``forward``:

    1. embed the ids and add token-type and position embeddings;
    2. run a depthwise convolution over the sequence (``pos_conv``);
    3. for every configured ``(block_size, offset)`` pair, mean-pool the
       sequence into blocks and broadcast each block mean back to the
       positions it covers;
    4. score every candidate block representation with a linear layer and
       softmax the scores across the candidates, optionally smoothing them
       with a "consensus" attention over positions;
    5. take the score-weighted sum of the candidates;
    6. mean-pool the result in groups of ``downsample_factor``.

    ``config`` must provide ``max_position_embeddings``, ``type_vocab_size``
    and ``initializer_range``.
    """

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        max_block_size = None,
        blocks = None,
        downsample_factor = 4,
        score_consensus_attn = True,
        return_without_downsample = True,
        config = None
    ):
        """Build the embedding tables, the positional convolution and the block scorer.

        Args:
            num_tokens: size of the input vocabulary.
            dim: embedding / hidden dimension.
            max_block_size: if given, candidate block sizes are
                ``1 .. max_block_size`` (all with offset 0). Mutually
                exclusive with ``blocks``.
            blocks: tuple whose items are either a block size or a
                ``(block_size, offset)`` tuple. Mutually exclusive with
                ``max_block_size``.
            downsample_factor: group size of the final mean pooling; must not
                exceed the largest block size.
            score_consensus_attn: whether to smooth the block scores with the
                position-to-position consensus attention.
            return_without_downsample: whether ``forward`` also returns a
                copy of the sequence before the final downsampling.
            config: model configuration, forwarded to ``PreTrainedModel``.
        """
        super().__init__(config=config)
        assert exists(max_block_size) ^ exists(blocks), 'either max_block_size or blocks are given on initialization'
        self.word_embeddings = nn.Embedding(num_tokens, dim)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, dim)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, dim)

        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

        self.return_without_downsample = return_without_downsample

        if exists(blocks):
            assert isinstance(blocks, tuple), 'blocks must be a tuple of block sizes'
            # normalise every entry to a (block_size, offset) pair
            self.blocks = tuple(el if isinstance(el, tuple) else (el, 0) for el in blocks)
            assert all(offset < block_size for block_size, offset in self.blocks), 'offset must be always smaller than the block size'

            max_block_size = max(block_size for block_size, _ in self.blocks)
        else:
            self.blocks = tuple((el, 0) for el in range(1, max_block_size + 1))

        # right-pad the sequence by (max_block_size - 1) so that the unpadded
        # convolution returns as many positions as it was given
        self.pos_conv = nn.Sequential(
            Pad((0, 0, 0, max_block_size - 1)),
            Rearrange('b n d -> b d n'),
            DepthwiseConv1d(dim, dim, kernel_size = max_block_size),
            Rearrange('b d n -> b n d')
        )

        # one scalar score per candidate block representation
        self.score_fn = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... () -> ...')
        )

        self.score_consensus_attn = score_consensus_attn

        assert downsample_factor <= max_block_size, 'final downsample factor should be less than the maximum block size'

        # NOTE(review): the padding multiple only covers the block sizes; if
        # downsample_factor does not divide it, the padded sequence may be
        # shorter than the length forward() truncates to -- confirm.
        self.block_pad_multiple = lcm(*[block_size for block_size, _ in self.blocks])
        self.downsample_factor = downsample_factor

    def _pool_and_score_blocks(self, x, attention_mask):
        """Shared core of ``forward`` and ``block_score``.

        Args:
            x: embedded sequence of shape ``(batch, seq, dim)``.
            attention_mask: optional ``(batch, seq)`` mask; any dtype is
                accepted and converted to bool (non-zero means "keep").

        Returns:
            A tuple ``(block_reprs, scores, attention_mask)`` where, with
            ``padded`` being ``seq`` rounded up to ``block_pad_multiple``:

            * ``block_reprs`` is ``(batch, padded, num_blocks, dim)``;
            * ``scores`` is ``(batch, padded, num_blocks, 1)``;
            * ``attention_mask`` is the boolean mask padded to
              ``(batch, padded)``, or ``None`` if none was given.
        """
        block_mult = self.block_pad_multiple
        has_mask = exists(attention_mask)

        x = self.pos_conv(x)

        # pad sequence and mask to a length divisible by every block size
        x = pad_to_multiple(x, block_mult, seq_dim=1, dim=-2)

        if has_mask:
            # `~mask` below is only a logical NOT for boolean tensors
            attention_mask = attention_mask.bool()
            attention_mask = pad_to_multiple(attention_mask, block_mult, seq_dim=1, dim=-1, value=False)

        block_masks = []
        block_reprs = []

        for block_size, offset in self.blocks:
            # work on copies: the pooling helper may modify its input in place
            block_x = x.clone()

            if has_mask:
                block_mask = attention_mask.clone()

            # shift the block grid when an offset is requested; the total
            # padding is exactly one block, so the length stays divisible
            need_padding = offset > 0

            if need_padding:
                left_offset, right_offset = (block_size - offset), offset
                block_x = F.pad(block_x, (0, 0, left_offset, right_offset), value = 0.)

                if has_mask:
                    block_mask = F.pad(block_mask, (left_offset, right_offset), value = False)

            # group the sequence into blocks
            blocks = rearrange(block_x, 'b (n m) d -> b n m d', m = block_size)

            # mean-pool each block, ignoring masked positions when a mask is given
            if has_mask:
                mask_blocks = rearrange(block_mask, 'b (n m) -> b n m', m = block_size)
                block_repr = masked_mean(blocks, mask_blocks, dim = -2)
            else:
                block_repr = blocks.mean(dim = -2)

            # broadcast each block mean back to every position of the block
            block_repr = repeat(block_repr, 'b n d -> b (n m) d', m = block_size)

            if need_padding:
                block_repr = block_repr[:, left_offset:-right_offset]

            block_reprs.append(block_repr)

            if has_mask:
                # a block is valid as soon as one of its positions is unmasked
                mask_blocks = torch.any(mask_blocks, dim = -1)
                mask_blocks = repeat(mask_blocks, 'b n -> b (n m)', m = block_size)

                if need_padding:
                    mask_blocks = mask_blocks[:, left_offset:-right_offset]

                block_masks.append(mask_blocks)

        # stack the candidates: (batch, padded, num_blocks, dim)
        block_reprs = torch.stack(block_reprs, dim = 2)

        # score the candidates and softmax across the block dimension
        scores = self.score_fn(block_reprs)

        if has_mask:
            block_masks = torch.stack(block_masks, dim = 2)
            max_neg_value = -torch.finfo(scores.dtype).max
            scores = scores.masked_fill(~block_masks, max_neg_value)

        scores = scores.softmax(dim = 2)

        # consensus attention: positions with similar score distributions
        # share their scores
        if self.score_consensus_attn:
            score_sim = einsum('b i d, b j d -> b i j', scores, scores)

            if has_mask:
                cross_mask = rearrange(attention_mask, 'b i -> b i ()') & rearrange(attention_mask, 'b j -> b () j')
                max_neg_value = -torch.finfo(score_sim.dtype).max
                score_sim = score_sim.masked_fill(~cross_mask, max_neg_value)

            score_attn = score_sim.softmax(dim=-1)
            scores = einsum('b i j, b j m -> b i m', score_attn, scores)

        scores = rearrange(scores, 'b n m -> b n m ()')
        return block_reprs, scores, attention_mask

    def forward(self, input_ids, attention_mask=None, position_ids=None, token_type_ids=None, inputs_embeds=None):
        """Embed, soft-pool into blocks and downsample a batch of token ids.

        Args:
            input_ids: ``(batch, seq)`` tensor of token ids.
            attention_mask: optional ``(batch, seq)`` mask; non-zero / True
                marks positions to keep.
            position_ids: optional position indices; defaults to
                ``0 .. seq - 1`` taken from the registered buffer.
            token_type_ids: optional segment ids; defaults to all zeros.
            inputs_embeds: accepted for interface compatibility, not used.

        Returns:
            A tuple ``(hidden, attention_mask, original)``:

            * ``hidden``: the downsampled sequence,
              ``(batch, ceil(seq / downsample_factor), dim)``;
            * ``attention_mask``: boolean mask for ``hidden`` (``None`` when
              no mask was passed);
            * ``original``: copy of the sequence before downsampling if
              ``return_without_downsample`` is set, else ``None``.
        """
        _, n = input_ids.shape
        ds_factor = self.downsample_factor
        m = next_divisible_length(n, ds_factor)

        if not exists(token_type_ids):
            # single-segment default instead of failing on None
            token_type_ids = torch.zeros_like(input_ids)

        if not exists(position_ids):
            position_ids = self.position_ids[:, :n]

        x = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        position_embeddings = self.position_embeddings(position_ids)
        x = x + token_type_embeddings + position_embeddings

        block_reprs, scores, attention_mask = self._pool_and_score_blocks(x, attention_mask)

        # weight the candidate block representations by their scores
        x = (block_reprs * scores).sum(dim=2)

        # truncate to the length divisible by the downsample factor
        x = x[:, :m]

        original = None
        if self.return_without_downsample:
            original = torch.clone(x)

        x, attention_mask = self.down_sample(x, attention_mask, ds_factor)

        return x, attention_mask, original

    @staticmethod
    def down_sample(input_ids, attention_mask, ds_factor):
        """Mean-pool a sequence in consecutive groups of ``ds_factor`` positions.

        Args:
            input_ids: ``(batch, seq, dim)`` tensor; ``seq`` must be
                divisible by ``ds_factor``.
            attention_mask: optional ``(batch, >= seq)`` mask; it is
                truncated to ``seq`` and converted to bool.
            ds_factor: group size.

        Returns:
            ``(pooled, pooled_mask)`` where ``pooled`` is
            ``(batch, seq / ds_factor, dim)`` and ``pooled_mask`` is True for
            every group containing at least one unmasked position (``None``
            when no mask was given).
        """
        n = input_ids.shape[1]
        m = next_divisible_length(n, ds_factor)

        input_ids = rearrange(input_ids, 'b (n m) d -> b n m d', m=ds_factor)

        if not exists(attention_mask):
            return input_ids.mean(dim=-2), attention_mask

        attention_mask = attention_mask[:, :m].bool()
        attention_mask = rearrange(attention_mask, 'b (n m) -> b n m', m=ds_factor)
        input_ids = masked_mean(input_ids, attention_mask, dim=2)
        attention_mask = torch.any(attention_mask, dim=-1)
        return input_ids, attention_mask

    def block_score(self, input_ids, attention_mask=None, position_ids=None, token_type_ids=None, inputs_embeds=None):
        """Return the per-position block-selection scores for ``input_ids``.

        NOTE(review): unlike ``forward``, only the word embeddings are used
        here (no token-type or position embeddings) -- confirm this is
        intentional.

        Args:
            input_ids: ``(batch, seq)`` tensor of token ids.
            attention_mask: optional ``(batch, seq)`` mask; non-zero / True
                marks positions to keep.
            position_ids, token_type_ids, inputs_embeds: accepted for
                signature parity with ``forward``, not used.

        Returns:
            Scores of shape ``(batch, padded, num_blocks, 1)``, where
            ``padded`` is ``seq`` rounded up to ``block_pad_multiple``.
        """
        x = self.word_embeddings(input_ids)
        _, scores, _ = self._pool_and_score_blocks(x, attention_mask)
        return scores
|
|