TabPFN committed on
Commit fb508eb
1 Parent(s): f1c7310

Upload layer.py

Files changed (1)
  1. layer.py +131 -0
layer.py ADDED
@@ -0,0 +1,131 @@
from functools import partial

import torch
from torch import nn
from torch.nn.modules.transformer import *
from torch.nn.modules.transformer import _get_activation_fn

from torch.utils.checkpoint import checkpoint


class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, relu or gelu (default=relu).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    """
    __constants__ = ['batch_first']

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5, batch_first=False, pre_norm=False,
                 device=None, dtype=None, recompute_attn=False) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of the feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        # pre_norm=True applies LayerNorm before attention/feedforward (pre-LN) instead of after (post-LN).
        self.pre_norm = pre_norm
        # recompute_attn=True recomputes attention during the backward pass to save memory.
        self.recompute_attn = recompute_attn

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional). Besides a tensor, this layer
                also accepts a tuple of three masks (global attention setup) or an int giving
                the position that separates train tokens from eval tokens.
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in the Transformer class.
        """
        if self.pre_norm:
            src_ = self.norm1(src)
        else:
            src_ = src
        if isinstance(src_mask, tuple):
            # Global attention setup: the sequence consists of global, train and eval tokens,
            # each group attending through its own mask.
            assert not self.self_attn.batch_first
            assert src_key_padding_mask is None

            global_src_mask, trainset_src_mask, valset_src_mask = src_mask

            num_global_tokens = global_src_mask.shape[0]
            num_train_tokens = trainset_src_mask.shape[0]

            global_tokens_src = src_[:num_global_tokens]
            train_tokens_src = src_[num_global_tokens:num_global_tokens + num_train_tokens]
            global_and_train_tokens_src = src_[:num_global_tokens + num_train_tokens]
            eval_tokens_src = src_[num_global_tokens + num_train_tokens:]

            attn = partial(checkpoint, self.self_attn) if self.recompute_attn else self.self_attn

            # Positional arguments follow MultiheadAttention.forward:
            # (query, key, value, key_padding_mask, need_weights, attn_mask)
            global_tokens_src2 = attn(global_tokens_src, global_and_train_tokens_src, global_and_train_tokens_src, None, True, global_src_mask)[0]
            train_tokens_src2 = attn(train_tokens_src, global_tokens_src, global_tokens_src, None, True, trainset_src_mask)[0]
            eval_tokens_src2 = attn(eval_tokens_src, src_, src_, None, True, valset_src_mask)[0]

            src2 = torch.cat([global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0)

        elif isinstance(src_mask, int):
            # An int mask marks the first eval position: train tokens attend to each other,
            # eval tokens attend only to the train tokens.
            assert src_key_padding_mask is None
            single_eval_position = src_mask
            src_left = self.self_attn(src_[:single_eval_position], src_[:single_eval_position], src_[:single_eval_position])[0]
            src_right = self.self_attn(src_[single_eval_position:], src_[:single_eval_position], src_[:single_eval_position])[0]
            src2 = torch.cat([src_left, src_right], dim=0)
        else:
            # Standard full self-attention with an optional tensor mask.
            if self.recompute_attn:
                src2 = checkpoint(self.self_attn, src_, src_, src_, src_key_padding_mask, True, src_mask)[0]
            else:
                src2 = self.self_attn(src_, src_, src_, attn_mask=src_mask,
                                      key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        if not self.pre_norm:
            src = self.norm1(src)

        if self.pre_norm:
            src_ = self.norm2(src)
        else:
            src_ = src
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src_))))
        src = src + self.dropout2(src2)

        if not self.pre_norm:
            src = self.norm2(src)
        return src
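Below is a minimal usage sketch, not part of the uploaded file, showing how the non-standard src_mask conventions of this layer can be exercised. It assumes the file is importable as `layer` and uses sequence-first tensors, since batch_first defaults to False; the sizes and names are illustrative, while the mask semantics follow the forward pass above.

import torch

from layer import TransformerEncoderLayer  # assumption: this file is importable as `layer`

# Sequence-first layout, since batch_first defaults to False.
seq_len, batch_size, d_model = 10, 32, 512
src = torch.rand(seq_len, batch_size, d_model)

encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=8, pre_norm=True)

# 1) Integer src_mask: positions [0, 6) act as train tokens and attend to each other;
#    positions [6, 10) act as eval tokens and attend only to the train tokens.
out = encoder_layer(src, src_mask=6)
assert out.shape == src.shape

# 2) Plain tensor src_mask: falls through to standard self-attention, here with a
#    causal (upper-triangular) additive mask.
causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
out = encoder_layer(src, src_mask=causal_mask)
assert out.shape == src.shape

The int form of src_mask is what gives the layer its in-context train/eval split: the eval queries read only from the train positions, while the tuple form additionally carves out a block of global tokens with their own masks.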