Upload transformer.py
TabPFN/transformer.py  +24 -18  CHANGED
@@ -15,7 +15,7 @@ class TransformerModel(nn.Module):
     def __init__(self, encoder, n_out, ninp, nhead, nhid, nlayers, dropout=0.0, style_encoder=None, y_encoder=None,
                  pos_encoder=None, decoder=None, input_normalization=False, init_method=None, pre_norm=False,
                  activation='gelu', recompute_attn=False, num_global_att_tokens=0, full_attention=False,
-                 all_layers_same_init=True):
+                 all_layers_same_init=False, efficient_eval_masking=True):
         super().__init__()
         self.model_type = 'Transformer'
         encoder_layer_creator = lambda: TransformerEncoderLayer(ninp, nhead, nhid, dropout, activation=activation,
@@ -34,12 +34,17 @@ class TransformerModel(nn.Module):
         assert not full_attention
         self.global_att_embeddings = nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None
         self.full_attention = full_attention
+        self.efficient_eval_masking = efficient_eval_masking
 
         self.n_out = n_out
         self.nhid = nhid
 
         self.init_weights()
 
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        self.__dict__.setdefault('efficient_eval_masking', False)
+
     @staticmethod
     def generate_square_subsequent_mask(sz):
         mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
@@ -93,35 +98,37 @@ class TransformerModel(nn.Module):
             nn.init.zeros_(attn.out_proj.bias)
 
     def forward(self, src, src_mask=None, single_eval_pos=None):
-        assert isinstance(src, tuple), '
+        assert isinstance(src, tuple), 'inputs (src) have to be given as (x,y) or (style,x,y) tuple'
 
-        if len(src) == 2:
+        if len(src) == 2: # (x,y) and no style
            src = (None,) + src
 
-        style_src,
+        style_src, x_src, y_src = src
+        x_src = self.encoder(x_src)
+        y_src = self.y_encoder(y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src)
+        style_src = self.style_encoder(style_src).unsqueeze(0) if self.style_encoder else \
+            torch.tensor([], device=x_src.device)
+        global_src = torch.tensor([], device=x_src.device) if self.global_att_embeddings is None else \
+            self.global_att_embeddings.weight.unsqueeze(1).repeat(1, x_src.shape[1], 1)
+
        if src_mask is not None: assert self.global_att_embeddings is None or isinstance(src_mask, tuple)
        if src_mask is None:
-            x_src = src[1]
            if self.global_att_embeddings is None:
-                full_len = len(x_src) +
+                full_len = len(x_src) + len(style_src)
                if self.full_attention:
                    src_mask = bool_mask_to_att_mask(torch.ones((full_len, full_len), dtype=torch.bool)).to(x_src.device)
+                elif self.efficient_eval_masking:
+                    src_mask = single_eval_pos + len(style_src)
                else:
-                    src_mask = self.generate_D_q_matrix(
+                    src_mask = self.generate_D_q_matrix(full_len, len(x_src) - single_eval_pos).to(x_src.device)
            else:
                src_mask_args = (self.global_att_embeddings.num_embeddings,
-                                 len(x_src) +
-                                 len(x_src) +
+                                 len(x_src) + len(style_src),
+                                 len(x_src) + len(style_src) - single_eval_pos)
                src_mask = (self.generate_global_att_globaltokens_matrix(*src_mask_args).to(x_src.device),
                            self.generate_global_att_trainset_matrix(*src_mask_args).to(x_src.device),
                            self.generate_global_att_query_matrix(*src_mask_args).to(x_src.device))
 
-        style_src, x_src, y_src = src
-        x_src = self.encoder(x_src)
-        y_src = self.y_encoder(y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src)
-        style_src = self.style_encoder(style_src).unsqueeze(0) if self.style_encoder else torch.tensor([], device=x_src.device)
-        global_src = torch.tensor([], device=x_src.device) if self.global_att_embeddings is None else \
-            self.global_att_embeddings.weight.unsqueeze(1).repeat(1, x_src.shape[1], 1)
        train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
        src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)
 
@@ -131,10 +138,9 @@ class TransformerModel(nn.Module):
        if self.pos_encoder is not None:
            src = self.pos_encoder(src)
 
-
-        output = self.transformer_encoder(src, src_mask)[style_src_size:]
+        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
-        return output[single_eval_pos+(self.global_att_embeddings.num_embeddings if self.global_att_embeddings else 0):]
+        return output[single_eval_pos+len(style_src)+(self.global_att_embeddings.num_embeddings if self.global_att_embeddings else 0):]
 
     @torch.no_grad()
     def init_from_small_model(self, small_model):
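
Note on the added __setstate__: it exists so that instances pickled before this commit, whose state dict has no 'efficient_eval_masking' entry, still load with the updated class and fall back to the old behaviour. A minimal sketch of that pattern, using a plain Python class as a stand-in for the nn.Module subclass (the class name and the n_out attribute are illustrative only, not the TabPFN code):

class Model:
    """Illustrative stand-in for the nn.Module subclass touched by this diff."""
    def __init__(self):
        self.n_out = 10
        self.efficient_eval_masking = True  # flag introduced in this commit

    def __setstate__(self, state):
        # Called by pickle / torch.load when restoring an instance.
        # nn.Module's __setstate__ performs this dict update plus hook bookkeeping.
        self.__dict__.update(state)
        # Checkpoints written before the flag existed get a safe default:
        self.__dict__.setdefault('efficient_eval_masking', False)

# Simulate restoring a checkpoint that predates the new attribute.
old_state = {'n_out': 10}            # no 'efficient_eval_masking' key in old pickles
restored = Model.__new__(Model)      # unpickling bypasses __init__
restored.__setstate__(old_state)
assert restored.efficient_eval_masking is False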
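Note on the masking change in forward: when efficient_eval_masking is on, src_mask is passed as a plain integer, single_eval_pos + len(style_src), i.e. the length of the prefix every position may attend to, instead of a materialized full_len x full_len matrix; how the encoder layers interpret that integer is outside this diff. generate_D_q_matrix is not shown here either, so the following is only an assumed sketch of the train/query attention pattern such a mask would encode:

import torch

def d_q_mask_sketch(full_len: int, num_query: int) -> torch.Tensor:
    # Hypothetical reconstruction (not taken from this diff): every position may
    # attend to the training prefix, and each query position may additionally
    # attend to itself, but queries never attend to each other.
    num_train = full_len - num_query
    allowed = torch.zeros(full_len, full_len, dtype=torch.bool)
    allowed[:, :num_train] = True            # everyone sees the training block
    q = torch.arange(num_train, full_len)
    allowed[q, q] = True                     # each query sees itself
    return allowed

# With efficient_eval_masking the same information travels as a single integer,
# the length of the attended prefix, e.g. src_mask = single_eval_pos + len(style_src).
print(d_q_mask_sketch(full_len=5, num_query=2).int())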
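Note on the new return slice: the sequence fed to the encoder is ordered [global tokens | style | train | query], and output[single_eval_pos + len(style_src) + num_global_tokens:] keeps only the query outputs. A quick arithmetic check of that bookkeeping, with made-up sizes:

# Index bookkeeping for the sequence assembled in forward(),
# src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0).
# The concrete numbers below are for illustration only.
num_global = 2        # self.global_att_embeddings.num_embeddings (0 if unused)
num_style = 1         # len(style_src) is 1 with a style encoder, else 0
single_eval_pos = 4   # number of training positions
num_query = 3         # evaluation positions

seq_len = num_global + num_style + single_eval_pos + num_query
start = single_eval_pos + num_style + num_global   # offset used by the new return slice
assert seq_len - start == num_query                # only query outputs are returned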