whenxuan committed on
Commit
7c4d438
·
verified ·
1 Parent(s): ff4cdb7

whenxuan: add the config for model

Browse files
Files changed (1) hide show
  1. model.py +142 -140
model.py CHANGED
@@ -1,140 +1,142 @@
1
- from typing import Tuple
2
-
3
- import torch
4
- import torch.nn as nn
5
- from torch import Tensor
6
- from torch.nn import functional as F
7
- from einops import rearrange, repeat
8
- from transformers.modeling_utils import PreTrainedModel
9
-
10
- from configuration_symtime import SymTimeConfig
11
- from layers import MultiHeadAttention, TSTEncoder, TSTEncoderLayer
12
-
13
-
14
- class SymTimeModel(PreTrainedModel):
15
- """
16
- SymTime Model for Huggingface.
17
-
18
- Parameters
19
- ----------
20
- config: SymTimeConfig
21
- The configuration of the SymTime model.
22
-
23
- Attributes
24
- ----------
25
- config: SymTimeConfig
26
- The configuration of the SymTime model.
27
- encoder: TSTEncoder
28
- The encoder of the SymTime model.
29
-
30
- Methods
31
- -------
32
- forward(x: Tensor) -> Tuple[Tensor, Tensor]:
33
- Forward pass of the SymTime model.
34
-
35
- _init_weights(module: nn.Module) -> None:
36
- Initialize weights for the SymTime encoder stack.
37
- """
38
-
39
- def __init__(self, config: SymTimeConfig):
40
- super().__init__(config)
41
- self.config = config
42
- self.encoder = TSTEncoder(
43
- patch_size=config.patch_size,
44
- num_layers=config.num_layers,
45
- hidden_size=config.d_model,
46
- num_heads=config.num_heads,
47
- d_ff=config.d_ff,
48
- norm=config.norm,
49
- attn_dropout=config.dropout,
50
- dropout=config.dropout,
51
- act=config.act,
52
- pre_norm=config.pre_norm,
53
- )
54
-
55
- # Initialize weights and apply final processing
56
- self.post_init()
57
-
58
- def _init_weights(self, module) -> None:
59
- """Initialize weights for the SymTime encoder stack.
60
-
61
- The model is built on top of Hugging Face `PreTrainedModel`, so this method
62
- is called recursively via `post_init()`. We keep the initialization aligned
63
- with the current backbone structure in `layers.py`:
64
-
65
- - `TSTEncoder.W_P`: patch projection linear layer
66
- - `TSTEncoder.cls_token`: learnable CLS token
67
- - `TSTEncoderLayer.self_attn`: Q/K/V and output projections
68
- - `TSTEncoderLayer.ff`: feed-forward linear layers
69
- - `LayerNorm` / `BatchNorm1d`: normalization layers
70
- """
71
- super()._init_weights(module)
72
-
73
- factor = self.config.initializer_factor
74
- d_model = self.config.d_model
75
- num_heads = self.config.num_heads
76
- d_k = d_model // num_heads
77
- d_v = d_k
78
-
79
- if isinstance(module, nn.Linear):
80
- nn.init.normal_(
81
- module.weight, mean=0.0, std=factor * (module.in_features**-0.5)
82
- )
83
- if module.bias is not None:
84
- nn.init.zeros_(module.bias)
85
-
86
- elif isinstance(module, nn.LayerNorm):
87
- nn.init.ones_(module.weight)
88
- nn.init.zeros_(module.bias)
89
-
90
- elif isinstance(module, nn.BatchNorm1d):
91
- if module.weight is not None:
92
- nn.init.ones_(module.weight)
93
- if module.bias is not None:
94
- nn.init.zeros_(module.bias)
95
-
96
- elif isinstance(module, TSTEncoder):
97
- if hasattr(module, "cls_token") and module.cls_token is not None:
98
- nn.init.normal_(module.cls_token, mean=0.0, std=factor)
99
- if hasattr(module, "W_P") and isinstance(module.W_P, nn.Linear):
100
- nn.init.normal_(
101
- module.W_P.weight,
102
- mean=0.0,
103
- std=factor * (module.W_P.in_features**-0.5),
104
- )
105
- if module.W_P.bias is not None:
106
- nn.init.zeros_(module.W_P.bias)
107
-
108
- elif isinstance(module, MultiHeadAttention):
109
- nn.init.normal_(module.W_Q.weight, mean=0.0, std=factor * (d_model**-0.5))
110
- nn.init.normal_(module.W_K.weight, mean=0.0, std=factor * (d_model**-0.5))
111
- nn.init.normal_(module.W_V.weight, mean=0.0, std=factor * (d_model**-0.5))
112
- if module.W_Q.bias is not None:
113
- nn.init.zeros_(module.W_Q.bias)
114
- if module.W_K.bias is not None:
115
- nn.init.zeros_(module.W_K.bias)
116
- if module.W_V.bias is not None:
117
- nn.init.zeros_(module.W_V.bias)
118
-
119
- out_proj = module.to_out[0]
120
- nn.init.normal_(
121
- out_proj.weight, mean=0.0, std=factor * ((num_heads * d_v) ** -0.5)
122
- )
123
- if out_proj.bias is not None:
124
- nn.init.zeros_(out_proj.bias)
125
-
126
- elif isinstance(module, TSTEncoderLayer):
127
- for submodule in module.ff:
128
- if isinstance(submodule, nn.Linear):
129
- nn.init.normal_(
130
- submodule.weight,
131
- mean=0.0,
132
- std=factor * (submodule.in_features**-0.5),
133
- )
134
- if submodule.bias is not None:
135
- nn.init.zeros_(submodule.bias)
136
-
137
- def forward(
138
- self, x: Tensor, return_cls_token: bool = True
139
- ) -> Tuple[Tensor, Tensor]:
140
- return self.encoder(x, return_cls_token=return_cls_token)
 
 
 
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch import Tensor
6
+ from torch.nn import functional as F
7
+ from einops import rearrange, repeat
8
+ from transformers.modeling_utils import PreTrainedModel
9
+
10
+ from configuration_symtime import SymTimeConfig
11
+ from layers import MultiHeadAttention, TSTEncoder, TSTEncoderLayer
12
+
13
+
14
+ class SymTimeModel(PreTrainedModel):
15
+ """
16
+ SymTime Model for Huggingface.
17
+
18
+ Parameters
19
+ ----------
20
+ config: SymTimeConfig
21
+ The configuration of the SymTime model.
22
+
23
+ Attributes
24
+ ----------
25
+ config: SymTimeConfig
26
+ The configuration of the SymTime model.
27
+ encoder: TSTEncoder
28
+ The encoder of the SymTime model.
29
+
30
+ Methods
31
+ -------
32
+ forward(x: Tensor) -> Tuple[Tensor, Tensor]:
33
+ Forward pass of the SymTime model.
34
+
35
+ _init_weights(module: nn.Module) -> None:
36
+ Initialize weights for the SymTime encoder stack.
37
+ """
38
+
39
+ config_class = SymTimeConfig
40
+
41
+ def __init__(self, config: SymTimeConfig):
42
+ super().__init__(config)
43
+ self.config = config
44
+ self.encoder = TSTEncoder(
45
+ patch_size=config.patch_size,
46
+ num_layers=config.num_layers,
47
+ hidden_size=config.d_model,
48
+ num_heads=config.num_heads,
49
+ d_ff=config.d_ff,
50
+ norm=config.norm,
51
+ attn_dropout=config.dropout,
52
+ dropout=config.dropout,
53
+ act=config.act,
54
+ pre_norm=config.pre_norm,
55
+ )
56
+
57
+ # Initialize weights and apply final processing
58
+ self.post_init()
59
+
60
+ def _init_weights(self, module) -> None:
61
+ """Initialize weights for the SymTime encoder stack.
62
+
63
+ The model is built on top of Hugging Face `PreTrainedModel`, so this method
64
+ is called recursively via `post_init()`. We keep the initialization aligned
65
+ with the current backbone structure in `layers.py`:
66
+
67
+ - `TSTEncoder.W_P`: patch projection linear layer
68
+ - `TSTEncoder.cls_token`: learnable CLS token
69
+ - `TSTEncoderLayer.self_attn`: Q/K/V and output projections
70
+ - `TSTEncoderLayer.ff`: feed-forward linear layers
71
+ - `LayerNorm` / `BatchNorm1d`: normalization layers
72
+ """
73
+ super()._init_weights(module)
74
+
75
+ factor = self.config.initializer_factor
76
+ d_model = self.config.d_model
77
+ num_heads = self.config.num_heads
78
+ d_k = d_model // num_heads
79
+ d_v = d_k
80
+
81
+ if isinstance(module, nn.Linear):
82
+ nn.init.normal_(
83
+ module.weight, mean=0.0, std=factor * (module.in_features**-0.5)
84
+ )
85
+ if module.bias is not None:
86
+ nn.init.zeros_(module.bias)
87
+
88
+ elif isinstance(module, nn.LayerNorm):
89
+ nn.init.ones_(module.weight)
90
+ nn.init.zeros_(module.bias)
91
+
92
+ elif isinstance(module, nn.BatchNorm1d):
93
+ if module.weight is not None:
94
+ nn.init.ones_(module.weight)
95
+ if module.bias is not None:
96
+ nn.init.zeros_(module.bias)
97
+
98
+ elif isinstance(module, TSTEncoder):
99
+ if hasattr(module, "cls_token") and module.cls_token is not None:
100
+ nn.init.normal_(module.cls_token, mean=0.0, std=factor)
101
+ if hasattr(module, "W_P") and isinstance(module.W_P, nn.Linear):
102
+ nn.init.normal_(
103
+ module.W_P.weight,
104
+ mean=0.0,
105
+ std=factor * (module.W_P.in_features**-0.5),
106
+ )
107
+ if module.W_P.bias is not None:
108
+ nn.init.zeros_(module.W_P.bias)
109
+
110
+ elif isinstance(module, MultiHeadAttention):
111
+ nn.init.normal_(module.W_Q.weight, mean=0.0, std=factor * (d_model**-0.5))
112
+ nn.init.normal_(module.W_K.weight, mean=0.0, std=factor * (d_model**-0.5))
113
+ nn.init.normal_(module.W_V.weight, mean=0.0, std=factor * (d_model**-0.5))
114
+ if module.W_Q.bias is not None:
115
+ nn.init.zeros_(module.W_Q.bias)
116
+ if module.W_K.bias is not None:
117
+ nn.init.zeros_(module.W_K.bias)
118
+ if module.W_V.bias is not None:
119
+ nn.init.zeros_(module.W_V.bias)
120
+
121
+ out_proj = module.to_out[0]
122
+ nn.init.normal_(
123
+ out_proj.weight, mean=0.0, std=factor * ((num_heads * d_v) ** -0.5)
124
+ )
125
+ if out_proj.bias is not None:
126
+ nn.init.zeros_(out_proj.bias)
127
+
128
+ elif isinstance(module, TSTEncoderLayer):
129
+ for submodule in module.ff:
130
+ if isinstance(submodule, nn.Linear):
131
+ nn.init.normal_(
132
+ submodule.weight,
133
+ mean=0.0,
134
+ std=factor * (submodule.in_features**-0.5),
135
+ )
136
+ if submodule.bias is not None:
137
+ nn.init.zeros_(submodule.bias)
138
+
139
+ def forward(
140
+ self, x: Tensor, return_cls_token: bool = True
141
+ ) -> Tuple[Tensor, Tensor]:
142
+ return self.encoder(x, return_cls_token=return_cls_token)