# TabPFN/transformer.py
import math
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Module, TransformerEncoder
from layer import TransformerEncoderLayer, _get_activation_fn
from utils import SeqBN, bool_mask_to_att_mask
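
# `layer` and `utils` are local modules of this repository: `layer` provides the custom
# TransformerEncoderLayer used below, and `utils` provides SeqBN and bool_mask_to_att_mask
# (which turns a boolean "may attend" mask into the additive attention mask torch expects).
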
class TransformerModel(nn.Module):
    def __init__(self, encoder, n_out, ninp, nhead, nhid, nlayers, dropout=0.0, style_encoder=None, y_encoder=None,
                 pos_encoder=None, decoder=None, input_normalization=False, init_method=None, pre_norm=False,
                 activation='gelu', recompute_attn=False, num_global_att_tokens=0, full_attention=False,
                 all_layers_same_init=True):
        super().__init__()
        self.model_type = 'Transformer'
        encoder_layer_creator = lambda: TransformerEncoderLayer(ninp, nhead, nhid, dropout, activation=activation,
                                                                pre_norm=pre_norm, recompute_attn=recompute_attn)
        self.transformer_encoder = TransformerEncoder(encoder_layer_creator(), nlayers) \
            if all_layers_same_init else TransformerEncoderDiffInit(encoder_layer_creator, nlayers)
        self.ninp = ninp
        self.encoder = encoder
        self.y_encoder = y_encoder
        self.pos_encoder = pos_encoder
        # Default decoder: a two-layer MLP that maps each output position to n_out values.
        self.decoder = decoder(ninp, nhid, n_out) if decoder is not None else \
            nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, n_out))
        self.input_ln = SeqBN(ninp) if input_normalization else None
        self.style_encoder = style_encoder
        self.init_method = init_method
        if num_global_att_tokens is not None:
            assert not full_attention
        # Learnable global attention tokens that get prepended to the sequence (disabled when 0).
        self.global_att_embeddings = nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None
        self.full_attention = full_attention
        self.n_out = n_out
        self.nhid = nhid

        self.init_weights()

    @staticmethod
    def generate_square_subsequent_mask(sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        return bool_mask_to_att_mask(mask)

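    # Example: generate_square_subsequent_mask(3) lets each position attend to itself and to
    # earlier positions only; before bool_mask_to_att_mask converts it to an additive mask it is
    #   [[True,  False, False],
    #    [True,  True,  False],
    #    [True,  True,  True ]]
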
    @staticmethod
    def generate_D_q_matrix(sz, query_size):
        train_size = sz - query_size
        mask = torch.zeros(sz, sz) == 0
        mask[:, train_size:].zero_()
        mask |= torch.eye(sz) == 1
        return bool_mask_to_att_mask(mask)

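    # Example: generate_D_q_matrix(5, 2) -> the 3 train positions attend to all train positions,
    # while each of the 2 query positions attends to the train positions and to itself:
    #   [[True, True, True, False, False],
    #    [True, True, True, False, False],
    #    [True, True, True, False, False],
    #    [True, True, True, True,  False],
    #    [True, True, True, False, True ]]
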
    @staticmethod
    def generate_global_att_query_matrix(num_global_att_tokens, seq_len, num_query_tokens):
        train_size = seq_len + num_global_att_tokens - num_query_tokens
        sz = seq_len + num_global_att_tokens
        mask = torch.zeros(num_query_tokens, sz) == 0
        mask[:, train_size:].zero_()
        mask[:, train_size:] |= torch.eye(num_query_tokens) == 1
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_global_att_trainset_matrix(num_global_att_tokens, seq_len, num_query_tokens):
        train_size = seq_len + num_global_att_tokens - num_query_tokens
        trainset_size = seq_len - num_query_tokens
        mask = torch.zeros(trainset_size, num_global_att_tokens) == 0
        #mask[:,num_global_att_tokens:].zero_()
        #mask[:,num_global_att_tokens:] |= torch.eye(trainset_size) == 1
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_global_att_globaltokens_matrix(num_global_att_tokens, seq_len, num_query_tokens):
        mask = torch.zeros(num_global_att_tokens, num_global_att_tokens + seq_len - num_query_tokens) == 0
        return bool_mask_to_att_mask(mask)

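    # The three builders above are used together when global attention tokens are enabled (see
    # forward()): global tokens attend to the global and train positions, train positions attend
    # to the global tokens, and query positions attend to the global tokens, the train positions
    # and themselves. forward() hands them to the encoder as a (globaltokens, trainset, query) tuple.
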
    def init_weights(self):
        initrange = 1.
        # if isinstance(self.encoder,EmbeddingEncoder):
        #     self.encoder.weight.data.uniform_(-initrange, initrange)
        # self.decoder.bias.data.zero_()
        # self.decoder.weight.data.uniform_(-initrange, initrange)
        if self.init_method is not None:
            self.apply(self.init_method)
        # Zero the output projections of the attention and feed-forward sublayers so that every
        # residual branch contributes nothing at initialization.
        for layer in self.transformer_encoder.layers:
            nn.init.zeros_(layer.linear2.weight)
            nn.init.zeros_(layer.linear2.bias)
            attns = layer.self_attn if isinstance(layer.self_attn, nn.ModuleList) else [layer.self_attn]
            for attn in attns:
                nn.init.zeros_(attn.out_proj.weight)
                nn.init.zeros_(attn.out_proj.bias)

    def forward(self, src, src_mask=None, single_eval_pos=None):
        """Run the model on a sequence-first batch.

        src is a tuple (x, y) or (style, x, y). Positions before single_eval_pos form the training
        set (their encoded labels are added to their encoded features); the remaining positions are
        the queries, for which predictions are returned.
        """
        assert isinstance(src, tuple), 'fuse_x_y is forbidden, i.e. inputs have to be given as (x, y) or (style, x, y)'

        if len(src) == 2:  # no style input given
            src = (None,) + src

        style_src, style_src_size = src[0], (0 if src[0] is None else 1)

        if src_mask is not None:
            assert self.global_att_embeddings is None or isinstance(src_mask, tuple)
        if src_mask is None:
            x_src = src[1]
            if self.global_att_embeddings is None:
                full_len = len(x_src) + style_src_size
                if self.full_attention:
                    src_mask = bool_mask_to_att_mask(torch.ones((full_len, full_len), dtype=torch.bool)).to(x_src.device)
                else:
                    src_mask = self.generate_D_q_matrix(full_len, full_len - single_eval_pos).to(x_src.device)
            else:
                src_mask_args = (self.global_att_embeddings.num_embeddings,
                                 len(x_src) + style_src_size,
                                 len(x_src) + style_src_size - single_eval_pos)
                src_mask = (self.generate_global_att_globaltokens_matrix(*src_mask_args).to(x_src.device),
                            self.generate_global_att_trainset_matrix(*src_mask_args).to(x_src.device),
                            self.generate_global_att_query_matrix(*src_mask_args).to(x_src.device))

        style_src, x_src, y_src = src
        x_src = self.encoder(x_src)
        y_src = self.y_encoder(y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src)
        style_src = self.style_encoder(style_src).unsqueeze(0) if self.style_encoder else torch.tensor([], device=x_src.device)
        global_src = torch.tensor([], device=x_src.device) if self.global_att_embeddings is None else \
            self.global_att_embeddings.weight.unsqueeze(1).repeat(1, x_src.shape[1], 1)

        # Add label embeddings to the feature embeddings of the training positions only.
        train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
        src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)

        if self.input_ln is not None:
            src = self.input_ln(src)

        if self.pos_encoder is not None:
            src = self.pos_encoder(src)

        # If we have a style input, drop its output.
        output = self.transformer_encoder(src, src_mask)[style_src_size:]
        output = self.decoder(output)
        return output[single_eval_pos + (self.global_att_embeddings.num_embeddings if self.global_att_embeddings else 0):]

    @torch.no_grad()
    def init_from_small_model(self, small_model):
        assert isinstance(self.decoder, nn.Linear) and isinstance(self.encoder, (nn.Linear, nn.Sequential)) \
               and isinstance(self.y_encoder, (nn.Linear, nn.Sequential))

        def set_encoder_weights(my_encoder, small_model_encoder):
            my_encoder_linear, small_encoder_linear = (my_encoder, small_model_encoder) \
                if isinstance(my_encoder, nn.Linear) else (my_encoder[-1], small_model_encoder[-1])
            small_in_dim = small_encoder_linear.out_features
            my_encoder_linear.weight.zero_()
            my_encoder_linear.bias.zero_()
            my_encoder_linear.weight[:small_in_dim] = small_encoder_linear.weight
            my_encoder_linear.bias[:small_in_dim] = small_encoder_linear.bias

        set_encoder_weights(self.encoder, small_model.encoder)
        set_encoder_weights(self.y_encoder, small_model.y_encoder)

        small_in_dim = small_model.decoder.in_features

        self.decoder.weight[:, :small_in_dim] = small_model.decoder.weight
        self.decoder.bias = small_model.decoder.bias

        for my_layer, small_layer in zip(self.transformer_encoder.layers, small_model.transformer_encoder.layers):
            small_hid_dim = small_layer.linear1.out_features
            my_in_dim = my_layer.linear1.in_features

            # packed along q,k,v order in first dim
            my_in_proj_w = my_layer.self_attn.in_proj_weight
            small_in_proj_w = small_layer.self_attn.in_proj_weight
            my_in_proj_w.view(3, my_in_dim, my_in_dim)[:, :small_in_dim, :small_in_dim] = \
                small_in_proj_w.view(3, small_in_dim, small_in_dim)
            my_layer.self_attn.in_proj_bias.view(3, my_in_dim)[:, :small_in_dim] = \
                small_layer.self_attn.in_proj_bias.view(3, small_in_dim)

            my_layer.self_attn.out_proj.weight[:small_in_dim, :small_in_dim] = small_layer.self_attn.out_proj.weight
            my_layer.self_attn.out_proj.bias[:small_in_dim] = small_layer.self_attn.out_proj.bias

            my_layer.linear1.weight[:small_hid_dim, :small_in_dim] = small_layer.linear1.weight
            my_layer.linear1.bias[:small_hid_dim] = small_layer.linear1.bias

            my_layer.linear2.weight[:small_in_dim, :small_hid_dim] = small_layer.linear2.weight
            my_layer.linear2.bias[:small_in_dim] = small_layer.linear2.bias

            my_layer.norm1.weight[:small_in_dim] = math.sqrt(small_in_dim / my_in_dim) * small_layer.norm1.weight
            my_layer.norm2.weight[:small_in_dim] = math.sqrt(small_in_dim / my_in_dim) * small_layer.norm2.weight
            my_layer.norm1.bias[:small_in_dim] = small_layer.norm1.bias
            my_layer.norm2.bias[:small_in_dim] = small_layer.norm2.bias

class TransformerEncoderDiffInit(Module):
r"""TransformerEncoder is a stack of N encoder layers
Args:
encoder_layer_creator: a function generating objects of TransformerEncoderLayer class without args (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
"""
    __constants__ = ['norm']

    def __init__(self, encoder_layer_creator, num_layers, norm=None):
        super().__init__()
        self.layers = nn.ModuleList([encoder_layer_creator() for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = src

        for mod in self.layers:
            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output
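

if __name__ == '__main__':
    # Minimal usage sketch. It assumes plain nn.Linear modules as stand-ins for the feature and
    # label encoders; a real setup would pass the repository's encoder modules instead.
    seq_len, batch_size, num_features, emsize = 10, 4, 5, 64
    model = TransformerModel(encoder=nn.Linear(num_features, emsize),  # embeds x: (seq_len, batch, num_features)
                             n_out=2, ninp=emsize, nhead=4, nhid=128, nlayers=2,
                             y_encoder=nn.Linear(1, emsize))  # embeds y: (seq_len, batch), unsqueezed to (seq_len, batch, 1)
    x = torch.randn(seq_len, batch_size, num_features)
    y = torch.randint(0, 2, (seq_len, batch_size)).float()
    single_eval_pos = 6  # positions < 6 form the training set, positions >= 6 are queried
    out = model((x, y), single_eval_pos=single_eval_pos)
    print(out.shape)  # expected: (seq_len - single_eval_pos, batch_size, n_out) = (4, 4, 2)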