Abner0803 committed (verified) · Commit e5b68ca · 1 Parent(s): 8160511

Create README.md

Files changed (1): README.md (+129, −0)

README.md ADDED
Transformer_ALiBi shares most of its modules with [Transformer-RPB](https://huggingface.co/Abner0803/Transformer-RPB), except for the modules below.

## TransformerComp

Add `TransformerComp` to your current script:

```python
import math

import torch
import torch.nn as nn

# BaseTransformerComp and PositionalEncoding are shared with Transformer-RPB
# (see the linked repository); the base class also provides _reshape_input,
# _reshape_output and _generate_causal_mask used below.


class TransformerComp(BaseTransformerComp):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_layers: int,
        num_heads: int,
        dropout: float = 0.1,
        mask_type: str = "none",
    ) -> None:
        """
        mask_type: "none", "alibi", "calibi", "causal"
        """
        super().__init__(input_dim, hidden_dim, num_layers, num_heads, dropout)
        self.feature_layer = nn.Linear(input_dim, hidden_dim)
        self.pe = PositionalEncoding(hidden_dim, dropout)
        self.mask_type = mask_type

        if self.mask_type in ["alibi", "calibi"]:
            # One ALiBi slope per head: a geometric sequence computed for the
            # closest power of two, with interleaved extra slopes appended when
            # num_heads is not a power of two.
            closest_power_of_2 = 2 ** int(math.log2(num_heads))
            base_slopes = torch.pow(
                2,
                -torch.arange(1, closest_power_of_2 + 1, dtype=torch.float32)
                * 8
                / closest_power_of_2,
            )

            if closest_power_of_2 != num_heads:
                extra_slopes = torch.pow(
                    2,
                    -torch.arange(
                        1,
                        2 * (num_heads - closest_power_of_2) + 1,
                        2,
                        dtype=torch.float32,
                    )
                    * 8
                    / closest_power_of_2,
                )
                base_slopes = torch.cat([base_slopes, extra_slopes])

            self.register_buffer(
                "slopes", base_slopes.view(-1, 1, 1)
            )  # [n_heads, 1, 1]

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation="relu",
            batch_first=False,
        )
        self.encoder_norm = nn.LayerNorm(hidden_dim)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )

    def _generate_alibi_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Creates a mask that is relative (ALiBi).
        Returns: [num_heads, seq_len, seq_len]
        """
        context_pos = torch.arange(seq_len, device=device).unsqueeze(1)
        memory_pos = torch.arange(seq_len, device=device).unsqueeze(0)
        distance = torch.abs(context_pos - memory_pos)
        alibi_bias = distance * -1.0 * self.slopes
        return alibi_bias

    def _generate_causal_alibi_mask(
        self, seq_len: int, device: torch.device
    ) -> torch.Tensor:
        """
        Creates a mask that is relative (ALiBi) and causal (mask wall).
        """
        context_pos = torch.arange(seq_len, device=device).unsqueeze(1)
        memory_pos = torch.arange(seq_len, device=device).unsqueeze(0)
        distance = torch.abs(context_pos - memory_pos)
        alibi_bias = distance * -1.0 * self.slopes
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1
        )
        alibi_bias.masked_fill_(causal_mask, float("-inf"))

        return alibi_bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x.shape [batch, seq_len, n_stocks, n_feats]"""
        x, batch, n_stocks = self._reshape_input(x)
        seq_len = x.shape[0]
        x = self.encoder_norm(self.pe(self.feature_layer(x)))  # [t, b * s, d_model]

        if self.mask_type == "causal":
            mask = self._generate_causal_mask(seq_len, x.device).permute(1, 0)
        elif self.mask_type == "alibi":
            mask = self._generate_alibi_mask(seq_len, x.device).repeat(
                x.shape[1], 1, 1
            )  # [b * s * n_heads, t, t]
        elif self.mask_type == "calibi":
            mask = self._generate_causal_alibi_mask(seq_len, x.device).repeat(
                x.shape[1], 1, 1
            )
        else:
            mask = None

        x = self.transformer_encoder(x, mask=mask)

        return self._reshape_output(x, batch, n_stocks)
```
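
For a quick sanity check, the following standalone sketch (not part of the repository; it only needs PyTorch) reproduces the ALiBi slope and bias construction used by `TransformerComp` for `num_heads = 4`, including the extra `-inf` wall added by the `"calibi"` variant:

```python
# Standalone sketch of the ALiBi bias used above (assumed values, not repo code).
import math
import torch

num_heads, seq_len = 4, 6

# Geometric head slopes, same formula as in __init__ (num_heads is a power of two here).
closest_power_of_2 = 2 ** int(math.log2(num_heads))
slopes = torch.pow(
    2,
    -torch.arange(1, closest_power_of_2 + 1, dtype=torch.float32) * 8 / closest_power_of_2,
).view(-1, 1, 1)  # [num_heads, 1, 1]

# Symmetric distance matrix and the resulting additive bias per head.
pos = torch.arange(seq_len)
distance = (pos.unsqueeze(1) - pos.unsqueeze(0)).abs()  # [seq_len, seq_len]
alibi_bias = distance * -1.0 * slopes                    # [num_heads, seq_len, seq_len]

# "calibi": additionally forbid attention to future positions.
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
calibi_bias = alibi_bias.masked_fill(causal, float("-inf"))

print(alibi_bias.shape)  # torch.Size([4, 6, 6])
print(calibi_bias[0])    # upper triangle is -inf, lower triangle decays linearly
```

The bias is purely additive: heads with larger slopes penalize distant positions more strongly, while the causal variant blocks future positions entirely.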

## Model Config

```yaml
input_dim: 8
output_dim: 1
hidden_dim: 64
num_layers: 2
num_heads: 4
dropout: 0.0
tfm_type: "base"
mask_type: "alibi"
```
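
The config keys mirror the constructor arguments of `TransformerComp`. As a minimal, hypothetical sketch (the file name `config.yaml` is an assumption, and `output_dim` / `tfm_type` are presumably consumed elsewhere in the shared Transformer-RPB builder), the model could be instantiated like this:

```python
# Hypothetical usage: load the YAML config and forward the matching keys.
import yaml

with open("config.yaml") as f:  # assumed file name
    cfg = yaml.safe_load(f)

model = TransformerComp(
    input_dim=cfg["input_dim"],
    hidden_dim=cfg["hidden_dim"],
    num_layers=cfg["num_layers"],
    num_heads=cfg["num_heads"],
    dropout=cfg["dropout"],
    mask_type=cfg["mask_type"],
)
```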