# charmen-electra / gbst.py
# Tamás Ficsor
# add model
# 11cdb73
import math
from math import gcd
import functools
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from transformers.modeling_utils import PreTrainedModel
def exists(val):
    """Return True when ``val`` is anything other than ``None`` (falsy values count as existing)."""
    return not (val is None)
def lcm(*numbers):
    """Return the least common multiple of ``numbers`` (1 when called with no arguments).

    Uses exact integer floor division: ``gcd(x, y)`` always divides ``x * y``,
    so no rounding through float happens even for values above 2**53.
    """
    return functools.reduce(lambda x, y: x * y // gcd(x, y), numbers, 1)
def masked_mean(tensor, mask, dim=-1):
    """Mean of ``tensor`` along ``dim``, counting only positions where ``mask`` is True.

    Args:
        tensor: values to average.
        mask: bool tensor matching the leading dimensions of ``tensor``; trailing
            singleton dims are appended so it broadcasts over the remaining ones.
        dim: dimension to reduce (must be valid for both ``tensor`` and the
            expanded ``mask``).

    Returns:
        The masked mean; slices with no True entries yield 0.
        ``tensor`` itself is left unmodified.
    """
    diff_len = tensor.dim() - mask.dim()
    # append singleton dims so the mask broadcasts over e.g. the feature dim
    mask = mask[(..., *((None,) * diff_len))]
    # out-of-place fill: do not clobber the caller's tensor
    masked = tensor.masked_fill(~mask, 0.)
    total_el = mask.sum(dim=dim)
    mean = masked.sum(dim=dim) / total_el.clamp(min=1.)
    return mean.masked_fill(total_el == 0, 0.)
def next_divisible_length(seqlen, multiple):
    """Round ``seqlen`` up to the nearest multiple of ``multiple``.

    Uses integer ceil-division (``-(-a // b)``) instead of ``math.ceil(a / b)``,
    which rounds through float and is wrong for lengths above 2**53.
    """
    return -(-seqlen // multiple) * multiple
def pad_to_multiple(tensor, multiple, *, seq_dim, dim=-1, value=0.):
    """Right-pad ``tensor`` along one dimension so its length becomes a multiple of ``multiple``.

    Args:
        tensor: tensor to pad.
        multiple: the padded length is rounded up to a multiple of this.
        seq_dim: dimension whose current length is measured.
        dim: dimension that receives the padding; negative or non-negative
            index. Callers normally pass the same dimension as ``seq_dim``.
        value: constant written into the padded positions.

    Returns:
        ``tensor`` itself when no padding is needed, otherwise a padded copy.
    """
    seqlen = tensor.shape[seq_dim]
    length = next_divisible_length(seqlen, multiple)
    if length == seqlen:
        return tensor
    if dim >= 0:
        # F.pad's pad tuple is ordered from the LAST dimension backwards,
        # so the offset computation below needs a negative index.
        dim -= tensor.dim()
    remainder = length - seqlen
    # (0, 0) for every dimension after `dim`, then (left=0, right=remainder) for `dim`
    pad_offset = (0, 0) * (-1 - dim)
    return F.pad(tensor, (*pad_offset, 0, remainder), value=value)
# helper classes
class Pad(nn.Module):
    """Module wrapper around ``F.pad`` with a fixed padding spec and constant fill value.

    Useful inside ``nn.Sequential`` where a plain function call cannot be placed.
    """

    def __init__(self, padding, value=0.):
        super().__init__()
        self.padding, self.value = padding, value

    def forward(self, x):
        """Return ``x`` padded with the stored ``padding`` tuple, filled with ``value``."""
        pad_spec = self.padding
        fill = self.value
        return F.pad(x, pad_spec, value=fill)
class DepthwiseConv1d(nn.Module):
    """Grouped (``groups=dim_in``) 1-D convolution followed by a 1x1 pointwise projection.

    Input layout is the one ``nn.Conv1d`` expects: (batch, dim_in, length).
    No padding is applied, so the output length is ``length - kernel_size + 1``.
    """

    def __init__(self, dim_in, dim_out, kernel_size):
        super().__init__()
        # one filter group per input channel
        self.conv = nn.Conv1d(dim_in, dim_out, kernel_size, groups=dim_in)
        # kernel-size-1 conv mixes information across channels
        self.proj_out = nn.Conv1d(dim_out, dim_out, 1)

    def forward(self, x):
        return self.proj_out(self.conv(x))
# main class
class GBST(PreTrainedModel):
    """Block-scoring / downsampling layer over character ids.

    Pipeline (see ``forward``):
      1. embed characters (word + token-type + position embeddings) and run
         ``pos_conv`` over the sequence;
      2. for every ``(block_size, offset)`` in ``self.blocks``, mean-pool the
         sequence into blocks and broadcast each block mean back to the
         positions it covers;
      3. score each candidate with ``score_fn``, softmax over candidates and
         optionally apply the score "consensus attention" across positions;
      4. mix the candidates with those scores and mean-pool the result by
         ``downsample_factor``.

    NOTE(review): GBST presumably refers to Charformer's gradient-based subword
    tokenization (the code cites "eq (5) in paper") — confirm.
    """

    def _init_weights(self, module):
        """Initialize the weights of a single submodule in place.

        Linear/Embedding weights ~ N(0, config.initializer_range); Linear biases
        and the Embedding padding row are zeroed; LayerNorm is reset to identity.
        """
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        max_block_size=None,
        blocks=None,
        downsample_factor=4,
        score_consensus_attn=True,
        return_without_downsample=True,
        config=None
    ):
        """Build the layer.

        Args:
            num_tokens: character vocabulary size (rows of ``word_embeddings``).
            dim: embedding / hidden size.
            max_block_size: use every block size ``1..max_block_size`` with
                offset 0. Exactly one of ``max_block_size`` / ``blocks`` is given.
            blocks: tuple of block sizes; each entry is an int or a
                ``(block_size, offset)`` tuple with ``offset < block_size``.
            downsample_factor: final mean-pooling factor; must not exceed the
                largest block size.
            score_consensus_attn: apply the score consensus attention step.
            return_without_downsample: make ``forward`` also return the mixed
                sequence before the final downsampling.
            config: model config; read for ``max_position_embeddings``,
                ``type_vocab_size`` and (in ``_init_weights``) ``initializer_range``.
        """
        super(GBST, self).__init__(config=config)
        assert exists(max_block_size) ^ exists(blocks), 'either max_block_size or blocks are given on initialization'

        self.word_embeddings = nn.Embedding(num_tokens, dim)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, dim)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, dim)
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

        self.return_without_downsample = return_without_downsample

        if exists(blocks):
            assert isinstance(blocks, tuple), 'blocks must be a tuple of block sizes'
            # normalise bare ints to (block_size, offset=0)
            self.blocks = tuple(el if isinstance(el, tuple) else (el, 0) for el in blocks)
            assert all(offset < block_size for block_size, offset in self.blocks), 'offset must be always smaller than the block size'
            max_block_size = max(block_size for block_size, _ in self.blocks)
        else:
            self.blocks = tuple((block_size, 0) for block_size in range(1, max_block_size + 1))

        # Right-pad the sequence by max_block_size - 1 so the unpadded conv
        # (kernel = max_block_size) returns the original sequence length.
        self.pos_conv = nn.Sequential(
            Pad((0, 0, 0, max_block_size - 1)),
            Rearrange('b n d -> b d n'),
            DepthwiseConv1d(dim, dim, kernel_size=max_block_size),
            Rearrange('b d n -> b n d')
        )

        # one scalar score per candidate block representation
        self.score_fn = nn.Sequential(
            nn.Linear(dim, 1),
            Rearrange('... () -> ...')
        )

        self.score_consensus_attn = score_consensus_attn

        assert downsample_factor <= max_block_size, 'final downsample factor should be less than or equal to the maximum block size'

        # lcm of all block sizes: a sequence padded to this multiple splits
        # evenly into blocks of every size
        self.block_pad_multiple = lcm(*[block_size for block_size, _ in self.blocks])
        self.downsample_factor = downsample_factor

    def _embed_characters(self, input_ids, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """Sum word, token-type and position embeddings, then apply ``pos_conv``.

        ``input_ids`` takes precedence over ``inputs_embeds`` when both are given.
        Missing ``token_type_ids`` default to all zeros; missing ``position_ids``
        default to ``0..seq_len-1``.

        Returns:
            Tensor of shape (batch, seq_len, dim).
        """
        if exists(input_ids):
            x = self.word_embeddings(input_ids)
        elif exists(inputs_embeds):
            x = inputs_embeds
        else:
            raise ValueError('either input_ids or inputs_embeds must be given')

        batch, seq_len = x.shape[:2]
        if not exists(token_type_ids):
            token_type_ids = torch.zeros((batch, seq_len), dtype=torch.long, device=x.device)
        if not exists(position_ids):
            position_ids = self.position_ids[:, :seq_len]

        x = x + self.token_type_embeddings(token_type_ids) + self.position_embeddings(position_ids)
        # do a conv to generate the positions for the tokens
        return self.pos_conv(x)

    @staticmethod
    def _pool_block(x, attention_mask, block_size, offset):
        """Mean-pool ``x`` into blocks of ``block_size`` shifted by ``offset``.

        Each block mean is broadcast back to every position the block covers.

        Args:
            x: (batch, L, dim), L divisible by ``block_size``.
            attention_mask: bool (batch, L) or None.

        Returns:
            ``(block_repr, block_mask)`` of shapes (batch, L, dim) and (batch, L);
            ``block_mask`` is None when ``attention_mask`` is None.
        """
        # clone so pooling never modifies the shared, padded sequence
        block_x = x.clone()
        block_mask = attention_mask.clone() if exists(attention_mask) else None

        # pad for offsets, if needed: (block_size - offset) on the left and
        # `offset` on the right keeps the length divisible by block_size
        need_padding = offset > 0
        if need_padding:
            left_offset, right_offset = block_size - offset, offset
            block_x = F.pad(block_x, (0, 0, left_offset, right_offset), value=0.)
            if exists(block_mask):
                block_mask = F.pad(block_mask, (left_offset, right_offset), value=False)

        blocks = rearrange(block_x, 'b (n m) d -> b n m d', m=block_size)
        if exists(block_mask):
            mask_blocks = rearrange(block_mask, 'b (n m) -> b n m', m=block_size)
            block_repr = masked_mean(blocks, mask_blocks, dim=-2)
            # a block is valid if any of its positions is
            pooled_mask = repeat(torch.any(mask_blocks, dim=-1), 'b n -> b (n m)', m=block_size)
        else:
            block_repr = blocks.mean(dim=-2)
            pooled_mask = None

        block_repr = repeat(block_repr, 'b n d -> b (n m) d', m=block_size)
        if need_padding:
            # drop the offset padding again so every candidate has length L
            block_repr = block_repr[:, left_offset:-right_offset]
            if exists(pooled_mask):
                pooled_mask = pooled_mask[:, left_offset:-right_offset]
        return block_repr, pooled_mask

    def _block_scores(self, x, attention_mask):
        """Compute all candidate block representations and their mixing scores.

        Args:
            x: embedded sequence (batch, n, dim).
            attention_mask: bool (batch, n) or None.

        Returns:
            ``(block_reprs, scores, attention_mask)``:
            block_reprs (batch, L, num_blocks, dim), scores (batch, L, num_blocks)
            and the mask padded to (batch, L) (or None), where L is n rounded up
            to ``self.block_pad_multiple``.
        """
        block_mult = self.block_pad_multiple
        # pad both sequence and attention_mask to a length divisible by every block size
        x = pad_to_multiple(x, block_mult, seq_dim=1, dim=-2)
        if exists(attention_mask):
            attention_mask = pad_to_multiple(attention_mask, block_mult, seq_dim=1, dim=-1, value=False)

        block_reprs = []
        block_masks = []
        for block_size, offset in self.blocks:
            block_repr, block_mask = self._pool_block(x, attention_mask, block_size, offset)
            block_reprs.append(block_repr)
            if exists(block_mask):
                block_masks.append(block_mask)

        block_reprs = torch.stack(block_reprs, dim=2)

        # score every candidate, then softmax across the candidate dimension
        scores = self.score_fn(block_reprs)
        if exists(attention_mask):
            block_masks = torch.stack(block_masks, dim=2)
            max_neg_value = -torch.finfo(scores.dtype).max
            scores = scores.masked_fill(~block_masks, max_neg_value)
        scores = scores.softmax(dim=2)

        # do the cheap consensus attention, eq (5) in paper: each position's
        # scores become a similarity-weighted average of all positions' scores
        if self.score_consensus_attn:
            score_sim = einsum('b i d, b j d -> b i j', scores, scores)
            if exists(attention_mask):
                cross_mask = rearrange(attention_mask, 'b i -> b i ()') & rearrange(attention_mask, 'b j -> b () j')
                max_neg_value = -torch.finfo(score_sim.dtype).max
                score_sim = score_sim.masked_fill(~cross_mask, max_neg_value)
            score_attn = score_sim.softmax(dim=-1)
            scores = einsum('b i j, b j m -> b i m', score_attn, scores)

        return block_reprs, scores, attention_mask

    def forward(self, input_ids, attention_mask=None, position_ids=None, token_type_ids=None, inputs_embeds=None):
        """Embed, softly select block representations, and downsample.

        Args:
            input_ids: character ids (batch, n); may be None when
                ``inputs_embeds`` (batch, n, dim) is given instead.
            attention_mask: (batch, n) mask, bool or 0/1 integers; optional.
            position_ids: optional position ids; defaults to ``0..n-1``.
            token_type_ids: optional token-type ids; defaults to zeros.
            inputs_embeds: optional precomputed character embeddings.

        Returns:
            ``(downsampled, downsampled_mask, original)`` where ``downsampled``
            is (batch, m // downsample_factor, dim), m being n rounded up to a
            multiple of ``downsample_factor``; ``downsampled_mask`` is None when
            no mask was given; ``original`` is the (batch, m, dim) sequence
            before downsampling, or None when ``return_without_downsample`` is off.
        """
        ds_factor = self.downsample_factor
        x = self._embed_characters(input_ids, token_type_ids, position_ids, inputs_embeds)
        m = next_divisible_length(x.shape[1], ds_factor)
        if exists(attention_mask):
            # masked_fill / ~ below require a bool mask
            attention_mask = attention_mask.bool()

        block_reprs, scores, attention_mask = self._block_scores(x, attention_mask)

        # multiply the block representations by the position-wise scores
        x = (block_reprs * rearrange(scores, 'b n m -> b n m ()')).sum(dim=2)

        # keep the first m positions (n rounded up to the downsample factor)
        x = x[:, :m]

        original = torch.clone(x) if self.return_without_downsample else None
        x, attention_mask = self.down_sample(x, attention_mask, ds_factor)
        return x, attention_mask, original

    @staticmethod
    def down_sample(input_ids, attention_mask, ds_factor):
        """Mean-pool the sequence dimension in groups of ``ds_factor``.

        Args:
            input_ids: hidden states (batch, n, dim), despite the name.
            attention_mask: mask (batch, n') with n' >= n, or None. Only its first
                ``ceil(n / ds_factor) * ds_factor`` positions are used.
            ds_factor: group size.

        Returns:
            ``(pooled, pooled_mask)``: (batch, ceil(n / ds_factor), dim) and a
            bool (batch, ceil(n / ds_factor)) mask, or None when no mask was given.
        """
        n = input_ids.shape[1]
        m = next_divisible_length(n, ds_factor)
        # right-pad so the sequence splits evenly into groups of ds_factor
        input_ids = pad_to_multiple(input_ids, ds_factor, seq_dim=1, dim=-2)
        if exists(attention_mask):
            attention_mask = attention_mask.bool()[:, :m]
            attention_mask = pad_to_multiple(attention_mask, ds_factor, seq_dim=1, dim=-1, value=False)

        # final mean pooling downsample
        input_ids = rearrange(input_ids, 'b (n m) d -> b n m d', m=ds_factor)
        if exists(attention_mask):
            attention_mask = rearrange(attention_mask, 'b (n m) -> b n m', m=ds_factor)
            input_ids = masked_mean(input_ids, attention_mask, dim=2)
            # a pooled position is valid if any of its inputs was
            attention_mask = torch.any(attention_mask, dim=-1)
        else:
            input_ids = input_ids.mean(dim=-2)
        return input_ids, attention_mask

    def block_score(self, input_ids, attention_mask=None, position_ids=None, token_type_ids=None, inputs_embeds=None):
        """Return the per-position block-candidate scores that ``forward`` uses.

        Accepts the same arguments as ``forward``.

        Returns:
            Tensor (batch, L, num_blocks, 1), where L is the input length padded
            to ``block_pad_multiple`` (not truncated or downsampled).
        """
        x = self._embed_characters(input_ids, token_type_ids, position_ids, inputs_embeds)
        if exists(attention_mask):
            attention_mask = attention_mask.bool()
        _, scores, _ = self._block_scores(x, attention_mask)
        return rearrange(scores, 'b n m -> b n m ()')