TabPFN committed on
Commit 7396040
1 Parent(s): fb508eb

Upload transformer.py

Files changed (1)
  1. TabPFN/transformer.py +24 -18
TabPFN/transformer.py CHANGED
@@ -15,7 +15,7 @@ class TransformerModel(nn.Module):
     def __init__(self, encoder, n_out, ninp, nhead, nhid, nlayers, dropout=0.0, style_encoder=None, y_encoder=None,
                  pos_encoder=None, decoder=None, input_normalization=False, init_method=None, pre_norm=False,
                  activation='gelu', recompute_attn=False, num_global_att_tokens=0, full_attention=False,
-                 all_layers_same_init=True):
+                 all_layers_same_init=False, efficient_eval_masking=True):
         super().__init__()
         self.model_type = 'Transformer'
         encoder_layer_creator = lambda: TransformerEncoderLayer(ninp, nhead, nhid, dropout, activation=activation,
@@ -34,12 +34,17 @@ class TransformerModel(nn.Module):
             assert not full_attention
         self.global_att_embeddings = nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None
         self.full_attention = full_attention
+        self.efficient_eval_masking = efficient_eval_masking

         self.n_out = n_out
         self.nhid = nhid

         self.init_weights()

+    def __setstate__(self, state):
+        super().__setstate__(state)
+        self.__dict__.setdefault('efficient_eval_masking', False)
+
     @staticmethod
     def generate_square_subsequent_mask(sz):
         mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
@@ -93,35 +98,37 @@ class TransformerModel(nn.Module):
         nn.init.zeros_(attn.out_proj.bias)

     def forward(self, src, src_mask=None, single_eval_pos=None):
-        assert isinstance(src, tuple), 'fuse_x_y is forbidden, that is inputs have to be given as (x,y) or (style,x,y)'
+        assert isinstance(src, tuple), 'inputs (src) have to be given as (x,y) or (style,x,y) tuple'

-        if len(src) == 2:
+        if len(src) == 2: # (x,y) and no style
             src = (None,) + src

-        style_src, style_src_size = (src[0], (0 if (src[0] is None) else 1))
+        style_src, x_src, y_src = src
+        x_src = self.encoder(x_src)
+        y_src = self.y_encoder(y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src)
+        style_src = self.style_encoder(style_src).unsqueeze(0) if self.style_encoder else \
+            torch.tensor([], device=x_src.device)
+        global_src = torch.tensor([], device=x_src.device) if self.global_att_embeddings is None else \
+            self.global_att_embeddings.weight.unsqueeze(1).repeat(1, x_src.shape[1], 1)
+
         if src_mask is not None: assert self.global_att_embeddings is None or isinstance(src_mask, tuple)
         if src_mask is None:
-            x_src = src[1]
             if self.global_att_embeddings is None:
-                full_len = len(x_src) + style_src_size
+                full_len = len(x_src) + len(style_src)
                 if self.full_attention:
                     src_mask = bool_mask_to_att_mask(torch.ones((full_len, full_len), dtype=torch.bool)).to(x_src.device)
+                elif self.efficient_eval_masking:
+                    src_mask = single_eval_pos + len(style_src)
                 else:
-                    src_mask = self.generate_D_q_matrix(len(x_src) + style_src_size, len(x_src) + style_src_size - single_eval_pos).to(x_src.device)
+                    src_mask = self.generate_D_q_matrix(full_len, len(x_src) - single_eval_pos).to(x_src.device)
             else:
                 src_mask_args = (self.global_att_embeddings.num_embeddings,
-                                 len(x_src) + style_src_size,
-                                 len(x_src) + style_src_size - single_eval_pos)
+                                 len(x_src) + len(style_src),
+                                 len(x_src) + len(style_src) - single_eval_pos)
                 src_mask = (self.generate_global_att_globaltokens_matrix(*src_mask_args).to(x_src.device),
                             self.generate_global_att_trainset_matrix(*src_mask_args).to(x_src.device),
                             self.generate_global_att_query_matrix(*src_mask_args).to(x_src.device))

-        style_src, x_src, y_src = src
-        x_src = self.encoder(x_src)
-        y_src = self.y_encoder(y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src)
-        style_src = self.style_encoder(style_src).unsqueeze(0) if self.style_encoder else torch.tensor([], device=x_src.device)
-        global_src = torch.tensor([], device=x_src.device) if self.global_att_embeddings is None else \
-            self.global_att_embeddings.weight.unsqueeze(1).repeat(1, x_src.shape[1], 1)
         train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
         src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)

@@ -131,10 +138,9 @@ class TransformerModel(nn.Module):
         if self.pos_encoder is not None:
             src = self.pos_encoder(src)

-        # If we have style input, drop its output
-        output = self.transformer_encoder(src, src_mask)[style_src_size:]
+        output = self.transformer_encoder(src, src_mask)
         output = self.decoder(output)
-        return output[single_eval_pos+(self.global_att_embeddings.num_embeddings if self.global_att_embeddings else 0):]
+        return output[single_eval_pos+len(style_src)+(self.global_att_embeddings.num_embeddings if self.global_att_embeddings else 0):]

     @torch.no_grad()
     def init_from_small_model(self, small_model):
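
Two changes in this diff stand out. With efficient_eval_masking=True, the default masking path no longer builds a full generate_D_q_matrix attention matrix: src_mask becomes the integer single_eval_pos + len(style_src) (how that integer is consumed lives in the encoder layers, which this diff does not touch). And the new __setstate__ keeps models pickled before this commit loadable by defaulting the then-missing flag to False, preserving the old full-matrix behaviour for deserialised checkpoints. Below is a minimal, self-contained sketch of that unpickling fallback; the Demo class and the pickle round trip are illustrative stand-ins, not code from the repository.

    # Sketch only: backward-compatibility pattern used by the new __setstate__.
    import pickle
    import torch.nn as nn

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            self.efficient_eval_masking = True  # flag introduced in this commit

        def __setstate__(self, state):
            super().__setstate__(state)
            # Pickles created before the flag existed fall back to False,
            # i.e. the previous full attention-matrix masking behaviour.
            self.__dict__.setdefault('efficient_eval_masking', False)

    m = Demo()
    del m.__dict__['efficient_eval_masking']   # emulate a pre-commit checkpoint
    restored = pickle.loads(pickle.dumps(m))
    print(restored.efficient_eval_masking)     # -> False

Freshly constructed models take the True path by default (note all_layers_same_init also flips to False), so only deserialised pre-commit models end up with efficient_eval_masking=False.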