HachiML commited on
Commit
7f82313
1 Parent(s): a4b9ee6

Upload model

Browse files
Files changed (5) hide show
  1. README.md +199 -0
  2. config.json +109 -0
  3. configuration_moment.py +103 -0
  4. model.safetensors +3 -0
  5. modeling_moment.py +497 -0
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_positional_embedding": true,
3
+ "architectures": [
4
+ "MomentEmbeddingModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_moment.MomentConfig",
8
+ "AutoModel": "modeling_moment.MomentEmbeddingModel"
9
+ },
10
+ "d_model": 1024,
11
+ "dropout": 0.1,
12
+ "enable_gradient_checkpointing": true,
13
+ "freeze_embedder": true,
14
+ "freeze_encoder": true,
15
+ "freeze_head": false,
16
+ "mask_ratio": 0.0,
17
+ "model_type": "moment",
18
+ "orth_gain": 1.41,
19
+ "patch_len": 8,
20
+ "patch_stride_len": 8,
21
+ "randomly_initialize_backbone": false,
22
+ "revin_affine": false,
23
+ "revin_eps": 1e-05,
24
+ "revin_num_features": 1,
25
+ "seq_len": 512,
26
+ "t5_config": {
27
+ "add_cross_attention": false,
28
+ "attn_implementation": null,
29
+ "bad_words_ids": null,
30
+ "begin_suppress_tokens": null,
31
+ "bos_token_id": null,
32
+ "chunk_size_feed_forward": 0,
33
+ "classifier_dropout": 0.0,
34
+ "cross_attention_hidden_size": null,
35
+ "d_ff": 2816,
36
+ "d_kv": 64,
37
+ "d_model": 1024,
38
+ "decoder_start_token_id": 0,
39
+ "dense_act_fn": "gelu_new",
40
+ "diversity_penalty": 0.0,
41
+ "do_sample": false,
42
+ "dropout_rate": 0.1,
43
+ "early_stopping": false,
44
+ "encoder_no_repeat_ngram_size": 0,
45
+ "eos_token_id": 1,
46
+ "exponential_decay_length_penalty": null,
47
+ "feed_forward_proj": "gated-gelu",
48
+ "finetuning_task": null,
49
+ "forced_bos_token_id": null,
50
+ "forced_eos_token_id": null,
51
+ "id2label": {
52
+ "0": "LABEL_0",
53
+ "1": "LABEL_1"
54
+ },
55
+ "initializer_factor": 1.0,
56
+ "is_decoder": false,
57
+ "is_encoder_decoder": true,
58
+ "is_gated_act": true,
59
+ "label2id": {
60
+ "LABEL_0": 0,
61
+ "LABEL_1": 1
62
+ },
63
+ "layer_norm_epsilon": 1e-06,
64
+ "length_penalty": 1.0,
65
+ "max_length": 20,
66
+ "min_length": 0,
67
+ "n_positions": 512,
68
+ "no_repeat_ngram_size": 0,
69
+ "num_beam_groups": 1,
70
+ "num_beams": 1,
71
+ "num_decoder_layers": 24,
72
+ "num_heads": 16,
73
+ "num_layers": 24,
74
+ "num_return_sequences": 1,
75
+ "output_attentions": false,
76
+ "output_hidden_states": false,
77
+ "output_past": true,
78
+ "output_scores": false,
79
+ "pad_token_id": 0,
80
+ "prefix": null,
81
+ "problem_type": null,
82
+ "pruned_heads": {},
83
+ "relative_attention_max_distance": 128,
84
+ "relative_attention_num_buckets": 32,
85
+ "remove_invalid_values": false,
86
+ "repetition_penalty": 1.0,
87
+ "return_dict": true,
88
+ "return_dict_in_generate": false,
89
+ "sep_token_id": null,
90
+ "suppress_tokens": null,
91
+ "task_specific_params": null,
92
+ "temperature": 1.0,
93
+ "tf_legacy_loss": false,
94
+ "tie_encoder_decoder": false,
95
+ "tie_word_embeddings": false,
96
+ "tokenizer_class": null,
97
+ "top_k": 50,
98
+ "top_p": 1.0,
99
+ "torch_dtype": null,
100
+ "torchscript": false,
101
+ "typical_p": 1.0,
102
+ "use_bfloat16": false,
103
+ "use_cache": true,
104
+ "vocab_size": 32128
105
+ },
106
+ "torch_dtype": "float32",
107
+ "transformers_version": "4.41.2",
108
+ "value_embedding_bias": false
109
+ }
configuration_moment.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Moment model configuration"""
2
+
3
+ from transformers import PretrainedConfig
4
+ from transformers import logging
5
+
6
+
7
+ DEFAULT_T5_CONFIG = {
8
+ # "_name_or_path": "google/flan-t5-large",
9
+ # "architectures": [
10
+ # "T5ForConditionalGeneration"
11
+ # ],
12
+ "classifier_dropout": 0.0,
13
+ "d_ff": 2816,
14
+ "d_kv": 64,
15
+ "d_model": 1024,
16
+ "decoder_start_token_id": 0,
17
+ "dense_act_fn": "gelu_new",
18
+ "dropout_rate": 0.1,
19
+ "eos_token_id": 1,
20
+ "feed_forward_proj": "gated-gelu",
21
+ "initializer_factor": 1.0,
22
+ "is_encoder_decoder": False,
23
+ "is_gated_act": True,
24
+ "layer_norm_epsilon": 1e-06,
25
+ # "model_type": "t5",
26
+ "n_positions": 512,
27
+ "num_decoder_layers": 24,
28
+ "num_heads": 16,
29
+ "num_layers": 24,
30
+ "output_past": True,
31
+ "pad_token_id": 0,
32
+ "relative_attention_max_distance": 128,
33
+ "relative_attention_num_buckets": 32,
34
+ "tie_word_embeddings": False,
35
+ # "transformers_version": "4.33.3",
36
+ "use_cache": False,
37
+ "vocab_size": 32128
38
+ }
39
+
40
+
41
+ class MomentConfig(PretrainedConfig):
42
+ model_type = "moment"
43
+
44
+ def __init__(
45
+ self,
46
+ t5_config: dict = DEFAULT_T5_CONFIG,
47
+ d_model: int = None,
48
+ seq_len: int = 512,
49
+ patch_len: int = 16,
50
+ patch_stride_len: int = 16,
51
+ dropout: float = 0.1,
52
+ revin_num_features: int = 1,
53
+ revin_eps: float = 1e-5,
54
+ revin_affine: bool = True,
55
+ add_positional_embedding: bool = True,
56
+ value_embedding_bias: bool = False,
57
+ orth_gain: float = 1.41,
58
+ mask_ratio: float = 0.15,
59
+ freeze_embedder: bool = True,
60
+ freeze_encoder: bool = True,
61
+ freeze_head: bool = False,
62
+ enable_gradient_checkpointing: bool = True,
63
+ randomly_initialize_backbone: bool = False,
64
+ **kwargs
65
+ ):
66
+ self.t5_config = self._init_t5_config(t5_config)
67
+ self.d_model = d_model
68
+ self.seq_len = seq_len
69
+ self.patch_len = patch_len
70
+ self.patch_stride_len = patch_stride_len
71
+ self.dropout = dropout
72
+ self.revin_num_features = revin_num_features
73
+ self.revin_eps = revin_eps
74
+ self.revin_affine = revin_affine
75
+ self.add_positional_embedding = add_positional_embedding
76
+ self.value_embedding_bias = value_embedding_bias
77
+ self.orth_gain = orth_gain
78
+ self.mask_ratio = mask_ratio
79
+ self.freeze_embedder = freeze_embedder
80
+ self.freeze_encoder = freeze_encoder
81
+ self.freeze_head = freeze_head
82
+ self.enable_gradient_checkpointing = enable_gradient_checkpointing
83
+ self.randomly_initialize_backbone = randomly_initialize_backbone
84
+
85
+ self._validation_config()
86
+
87
+ super().__init__(**kwargs)
88
+
89
+ def _init_t5_config(self, config: dict):
90
+ if config is None:
91
+ return DEFAULT_T5_CONFIG
92
+ else:
93
+ # 与えられたconfigでDEFAULT_T5_CONFIGを更新
94
+ updated_config = DEFAULT_T5_CONFIG.copy()
95
+ updated_config.update(config)
96
+ return updated_config
97
+
98
+ def _validation_config(self):
99
+ """
100
+ Validate configuration.
101
+ """
102
+ if self.d_model is None:
103
+ self.d_model = self.t5_config["d_model"]
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad070de05a7097e3291fcbeac7ca5185bcf4d4f433b5e16810e56ac2c6a8b429
3
+ size 1385468280
modeling_moment.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Auton LabによるMomentライブラリをTransformers向けに書き換えたものです。
2
+ # Embeddingに特化したアーキテクチャとなっています。
3
+ # refers: https://github.com/moment-timeseries-foundation-model/moment
4
+
5
+ from dataclasses import dataclass
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ import math
9
+ import numpy.typing as npt
10
+ import torch
11
+ from torch import nn
12
+
13
+ from transformers import PreTrainedModel
14
+ from transformers import T5Config, T5Model
15
+ from transformers.utils import logging
16
+
17
+ from .configuration_moment import MomentConfig
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+ @dataclass
22
+ class TimeseriesOutputs:
23
+ # forecast: npt.NDArray = None
24
+ # anomaly_scores: npt.NDArray = None
25
+ logits: npt.NDArray = None
26
+ labels: int = None
27
+ input_mask: npt.NDArray = None
28
+ pretrain_mask: npt.NDArray = None
29
+ # reconstruction: npt.NDArray = None
30
+ embeddings: npt.NDArray = None
31
+ metadata: dict = None
32
+ illegal_output: bool = False
33
+ hidden_states: npt.NDArray = None # For Mists model
34
+
35
+
36
+ # refers: https://github.com/moment-timeseries-foundation-model/moment/blob/088b253a1138ac7e48a7efc9bf902336c9eec8d9/momentfm/utils/masking.py#L6C1-L6C2
37
+ class Masking:
38
+ def __init__(
39
+ self, mask_ratio: float = 0.3, patch_len: int = 8, stride: Optional[int] = None
40
+ ):
41
+ """
42
+ Indices with 0 mask are hidden, and with 1 are observed.
43
+ """
44
+ self.mask_ratio = mask_ratio
45
+ self.patch_len = patch_len
46
+ self.stride = patch_len if stride is None else stride
47
+
48
+ @staticmethod
49
+ def convert_seq_to_patch_view(
50
+ mask: torch.Tensor, patch_len: int = 8, stride: Optional[int] = None
51
+ ):
52
+ """
53
+ Input:
54
+ mask : torch.Tensor of shape [batch_size x seq_len]
55
+ Output
56
+ mask : torch.Tensor of shape [batch_size x n_patches]
57
+ """
58
+ stride = patch_len if stride is None else stride
59
+ mask = mask.unfold(dimension=-1, size=patch_len, step=stride)
60
+ # mask : [batch_size x n_patches x patch_len]
61
+ return (mask.sum(dim=-1) == patch_len).long()
62
+
63
+ @staticmethod
64
+ def convert_patch_to_seq_view(
65
+ mask: torch.Tensor,
66
+ patch_len: int = 8,
67
+ ):
68
+ """
69
+ Input:
70
+ mask : torch.Tensor of shape [batch_size x n_patches]
71
+ Output:
72
+ mask : torch.Tensor of shape [batch_size x seq_len]
73
+ """
74
+ return mask.repeat_interleave(patch_len, dim=-1)
75
+
76
+ def generate_mask(self, x: torch.Tensor, input_mask: Optional[torch.Tensor] = None):
77
+ """
78
+ Input:
79
+ x : torch.Tensor of shape
80
+ [batch_size x n_channels x n_patches x patch_len] or
81
+ [batch_size x n_channels x seq_len]
82
+ input_mask: torch.Tensor of shape [batch_size x seq_len] or
83
+ [batch_size x n_patches]
84
+ Output:
85
+ mask : torch.Tensor of shape [batch_size x seq_len]
86
+ """
87
+ if x.ndim == 4:
88
+ return self._mask_patch_view(x, input_mask=input_mask)
89
+ elif x.ndim == 3:
90
+ return self._mask_seq_view(x, input_mask=input_mask)
91
+
92
+ def _mask_patch_view(self, x, input_mask=None):
93
+ """
94
+ Input:
95
+ x : torch.Tensor of shape
96
+ [batch_size x n_channels x n_patches x patch_len]
97
+ input_mask: torch.Tensor of shape [batch_size x seq_len]
98
+ Output:
99
+ mask : torch.Tensor of shape [batch_size x n_patches]
100
+ """
101
+ input_mask = self.convert_seq_to_patch_view(
102
+ input_mask, self.patch_len, self.stride
103
+ )
104
+ n_observed_patches = input_mask.sum(dim=-1, keepdim=True) # batch_size x 1
105
+
106
+ batch_size, _, n_patches, _ = x.shape
107
+ len_keep = torch.ceil(n_observed_patches * (1 - self.mask_ratio)).long()
108
+ noise = torch.rand(
109
+ batch_size, n_patches, device=x.device
110
+ ) # noise in [0, 1], batch_size x n_channels x n_patches
111
+ noise = torch.where(
112
+ input_mask == 1, noise, torch.ones_like(noise)
113
+ ) # only keep the noise of observed patches
114
+
115
+ # Sort noise for each sample
116
+ ids_shuffle = torch.argsort(
117
+ noise, dim=1
118
+ ) # Ascend: small is keep, large is remove
119
+ ids_restore = torch.argsort(
120
+ ids_shuffle, dim=1
121
+ ) # ids_restore: [batch_size x n_patches]
122
+
123
+ # Generate the binary mask: 0 is keep, 1 is remove
124
+ mask = torch.zeros(
125
+ [batch_size, n_patches], device=x.device
126
+ ) # mask: [batch_size x n_patches]
127
+ for i in range(batch_size):
128
+ mask[i, : len_keep[i]] = 1
129
+
130
+ # Unshuffle to get the binary mask
131
+ mask = torch.gather(mask, dim=1, index=ids_restore)
132
+
133
+ return mask.long()
134
+
135
+ def _mask_seq_view(self, x, input_mask=None):
136
+ """
137
+ Input:
138
+ x : torch.Tensor of shape
139
+ [batch_size x n_channels x seq_len]
140
+ input_mask: torch.Tensor of shape [batch_size x seq_len]
141
+ Output:
142
+ mask : torch.Tensor of shape [batch_size x seq_len]
143
+ """
144
+ x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
145
+ mask = self._mask_patch_view(x, input_mask=input_mask)
146
+ return self.convert_patch_to_seq_view(mask, self.patch_len).long()
147
+
148
+
149
+ # refers: https://github.com/moment-timeseries-foundation-model/moment/blob/088b253a1138ac7e48a7efc9bf902336c9eec8d9/momentfm/models/layers/revin.py#L5
150
+ def nanvar(tensor, dim=None, keepdim=False):
151
+ tensor_mean = tensor.nanmean(dim=dim, keepdim=True)
152
+ output = (tensor - tensor_mean).square().nanmean(dim=dim, keepdim=keepdim)
153
+ return output
154
+
155
+ # refers: https://github.com/moment-timeseries-foundation-model/moment/blob/088b253a1138ac7e48a7efc9bf902336c9eec8d9/momentfm/models/layers/revin.py#L11
156
+ def nanstd(tensor, dim=None, keepdim=False):
157
+ output = nanvar(tensor, dim=dim, keepdim=keepdim)
158
+ output = output.sqrt()
159
+ return output
160
+
161
+ # refers: https://github.com/moment-timeseries-foundation-model/moment/blob/088b253a1138ac7e48a7efc9bf902336c9eec8d9/momentfm/models/layers/revin.py#L17
162
+ class RevIN(nn.Module):
163
+ def __init__(self, num_features: int, eps: float = 1e-5, affine: bool = False):
164
+ """
165
+ :param num_features: the number of features or channels
166
+ :param eps: a value added for numerical stability
167
+ :param affine: if True, RevIN has learnable affine parameters
168
+ """
169
+ super(RevIN, self).__init__()
170
+ self.num_features = num_features
171
+ self.eps = eps
172
+ self.affine = affine
173
+
174
+ if self.affine:
175
+ self._init_params()
176
+
177
+ def forward(self, x: torch.Tensor, mode: str = "norm", mask: torch.Tensor = None):
178
+ """
179
+ :param x: input tensor of shape (batch_size, n_channels, seq_len)
180
+ :param mode: 'norm' or 'denorm'
181
+ :param mask: input mask of shape (batch_size, seq_len)
182
+ :return: RevIN transformed tensor
183
+ """
184
+ if mode == "norm":
185
+ self._get_statistics(x, mask=mask)
186
+ x = self._normalize(x)
187
+ elif mode == "denorm":
188
+ x = self._denormalize(x)
189
+ else:
190
+ raise NotImplementedError
191
+ return x
192
+
193
+ def _init_params(self):
194
+ # initialize RevIN params: (C,)
195
+ self.affine_weight = nn.Parameter(torch.ones(1, self.num_features, 1))
196
+ self.affine_bias = nn.Parameter(torch.zeros(1, self.num_features, 1))
197
+
198
+ def _get_statistics(self, x, mask=None):
199
+ """
200
+ x : batch_size x n_channels x seq_len
201
+ mask : batch_size x seq_len
202
+ """
203
+ if mask is None:
204
+ mask = torch.ones((x.shape[0], x.shape[-1]))
205
+ n_channels = x.shape[1]
206
+ mask = mask.unsqueeze(1).repeat(1, n_channels, 1).bool()
207
+ # Set masked positions to NaN, and unmasked positions are taken from x
208
+ masked_x = torch.where(mask, x, torch.nan)
209
+ self.mean = torch.nanmean(masked_x, dim=-1, keepdim=True).detach()
210
+ self.stdev = nanstd(masked_x, dim=-1, keepdim=True).detach() + self.eps
211
+ # self.stdev = torch.sqrt(
212
+ # torch.var(masked_x, dim=-1, keepdim=True) + self.eps).get_data().detach()
213
+ # NOTE: By default not bessel correction
214
+
215
+ def _normalize(self, x):
216
+ x = x - self.mean
217
+ x = x / self.stdev
218
+
219
+ if self.affine:
220
+ x = x * self.affine_weight
221
+ x = x + self.affine_bias
222
+ return x
223
+
224
+ def _denormalize(self, x):
225
+ if self.affine:
226
+ x = x - self.affine_bias
227
+ x = x / (self.affine_weight + self.eps * self.eps)
228
+ x = x * self.stdev
229
+ x = x + self.mean
230
+ return x
231
+
232
+
233
+ # refers: https://github.com/moment-timeseries-foundation-model/moment/blob/088b253a1138ac7e48a7efc9bf902336c9eec8d9/momentfm/models/layers/embed.py#L10
234
+ class PositionalEmbedding(nn.Module):
235
+ def __init__(self, d_model, max_len=5000, model_name="MOMENT"):
236
+ super(PositionalEmbedding, self).__init__()
237
+ self.model_name = model_name
238
+
239
+ # Compute the positional encodings once in log space.
240
+ pe = torch.zeros(max_len, d_model).float()
241
+ pe.require_grad = False
242
+
243
+ position = torch.arange(0, max_len).float().unsqueeze(1)
244
+ div_term = (
245
+ torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
246
+ ).exp()
247
+
248
+ pe[:, 0::2] = torch.sin(position * div_term)
249
+ pe[:, 1::2] = torch.cos(position * div_term)
250
+
251
+ pe = pe.unsqueeze(0)
252
+ self.register_buffer("pe", pe)
253
+
254
+ def forward(self, x):
255
+ if (
256
+ self.model_name == "MOMENT"
257
+ or self.model_name == "TimesNet"
258
+ or self.model_name == "GPT4TS"
259
+ ):
260
+ return self.pe[:, : x.size(2)]
261
+ else:
262
+ return self.pe[:, : x.size(1)]
263
+
264
+
265
+ # refers: https://github.com/moment-timeseries-foundation-model/moment/blob/088b253a1138ac7e48a7efc9bf902336c9eec8d9/momentfm/models/layers/embed.py#L181
266
+ class PatchEmbedding(nn.Module):
267
+ def __init__(
268
+ self,
269
+ d_model: int = 768,
270
+ seq_len: int = 512,
271
+ patch_len: int = 8,
272
+ stride: int = 8,
273
+ dropout: int = 0.1,
274
+ add_positional_embedding: bool = False,
275
+ value_embedding_bias: bool = False,
276
+ orth_gain: float = 1.41,
277
+ ):
278
+ super(PatchEmbedding, self).__init__()
279
+ self.patch_len = patch_len
280
+ self.seq_len = seq_len
281
+ self.stride = stride
282
+ self.d_model = d_model
283
+ self.add_positional_embedding = add_positional_embedding
284
+
285
+ self.value_embedding = nn.Linear(patch_len, d_model, bias=value_embedding_bias)
286
+ self.mask_embedding = nn.Parameter(torch.zeros(d_model))
287
+
288
+ if orth_gain is not None:
289
+ torch.nn.init.orthogonal_(self.value_embedding.weight, gain=orth_gain)
290
+ if value_embedding_bias:
291
+ self.value_embedding.bias.data.zero_()
292
+ # torch.nn.init.orthogonal_(self.mask_embedding, gain=orth_gain) # Fails
293
+
294
+ # Positional embedding
295
+ if self.add_positional_embedding:
296
+ self.position_embedding = PositionalEmbedding(d_model)
297
+
298
+ # Residual dropout
299
+ self.dropout = nn.Dropout(dropout)
300
+
301
+ def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
302
+ mask = Masking.convert_seq_to_patch_view(
303
+ mask, patch_len=self.patch_len
304
+ ).unsqueeze(-1)
305
+ # mask : [batch_size x n_patches x 1]
306
+ n_channels = x.shape[1]
307
+ mask = (
308
+ mask.repeat_interleave(self.d_model, dim=-1)
309
+ .unsqueeze(1)
310
+ .repeat(1, n_channels, 1, 1)
311
+ )
312
+ # mask : [batch_size x n_channels x n_patches x d_model]
313
+
314
+ # Input encoding
315
+ x = mask * self.value_embedding(x) + (1 - mask) * self.mask_embedding
316
+ if self.add_positional_embedding:
317
+ x = x + self.position_embedding(x)
318
+
319
+ return self.dropout(x)
320
+
321
+
322
+ # refers: https://github.com/moment-timeseries-foundation-model/moment/blob/088b253a1138ac7e48a7efc9bf902336c9eec8d9/momentfm/models/layers/embed.py#L237C1-L251C17
323
+ class Patching(nn.Module):
324
+ def __init__(self, patch_len: int, stride: int):
325
+ super().__init__()
326
+ self.patch_len = patch_len
327
+ self.stride = stride
328
+ if self.stride != self.patch_len:
329
+ logger.warning(
330
+ "Stride and patch length are not equal. "
331
+ "This may lead to unexpected behavior."
332
+ )
333
+
334
+ def forward(self, x):
335
+ x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
336
+ # x : [batch_size x n_channels x num_patch x patch_len]
337
+ return x
338
+
339
+
340
+ class MomentPreTrainedModel(PreTrainedModel):
341
+ config_class = MomentConfig
342
+
343
+ base_model_prefix = "model"
344
+ supports_gradient_checkpointing = True
345
+ _no_split_modules = ["T5Block"]
346
+ _skip_keys_device_placement = ""
347
+
348
+ # 本来のT5の_init_weightsはもっと詳細だが、事前学習の予定はないためここでは簡単にしている。
349
+ # refers: https://github.com/huggingface/transformers/blob/517df566f572d90e6301df87870f651f0d1b1110/src/transformers/models/t5/modeling_t5.py#L810
350
+ def _init_weights(self, module):
351
+ std = self.config.t5_config["initializer_factor"]
352
+ if isinstance(module, nn.Linear):
353
+ module.weight.data.normal_(mean=0.0, std=std)
354
+ if module.bias is not None:
355
+ module.bias.data.zero_()
356
+ elif isinstance(module, nn.Embedding):
357
+ module.weight.data.normal_(mean=0.0, std=std)
358
+ if module.padding_idx is not None:
359
+ module.weight.data[module.padding_idx].zero_()
360
+
361
+
362
+ class MomentEmbeddingModel(MomentPreTrainedModel):
363
+ def __init__(self, config):
364
+ super().__init__(config)
365
+ self.config = config
366
+ self.seq_len = config.seq_len
367
+ self.patch_len = config.patch_len
368
+
369
+ # TODO: normalizer, tokenizerはProcessor側に配置するべきか?
370
+ # 現状の考え: 特にMomentから切り離す用途もない。
371
+ #       Processor側では入力の512timestepsへの切り取り等、
372
+ #       input validationとTensorへの切り替えを行うで良さそう。
373
+ self.normalizer = RevIN(
374
+ num_features=getattr(config, "revin_num_features", 1), eps=getattr(config, "revin_eps", 1e-5), affine=getattr(config, "revin_affine", False)
375
+ )
376
+ self.tokenizer = Patching(
377
+ patch_len=config.patch_len, stride=config.patch_stride_len
378
+ )
379
+ # モデル構成
380
+ self.patch_embedding = PatchEmbedding(
381
+ d_model=config.d_model,
382
+ seq_len=config.seq_len,
383
+ patch_len=config.patch_len,
384
+ stride=config.patch_stride_len,
385
+ dropout=getattr(config, "dropout", 0.1),
386
+ add_positional_embedding=getattr(config, "add_positional_embedding", True),
387
+ value_embedding_bias=getattr(config, "value_embedding_bias", False),
388
+ orth_gain=getattr(config, "orth_gain", 1.41),
389
+ )
390
+ self.mask_generator = Masking(mask_ratio=getattr(config, "mask_ratio", 0.0))
391
+ self.encoder = self._get_t5_encoder(config.t5_config, config.enable_gradient_checkpointing)
392
+ self.head = nn.Identity()
393
+
394
+ # Frozen parameters
395
+ self.freeze_embedder = getattr(config, "freeze_embedder", True)
396
+ self.freeze_encoder = getattr(config, "freeze_encoder", True)
397
+ self.freeze_head = getattr(config, "freeze_head", False)
398
+
399
+ if self.freeze_embedder:
400
+ self.patch_embedding = freeze_parameters(self.patch_embedding)
401
+ if self.freeze_encoder:
402
+ self.encoder = freeze_parameters(self.encoder)
403
+ if self.freeze_head:
404
+ self.head = freeze_parameters(self.head)
405
+
406
+ def _get_t5_encoder(self, config: dict, enable_gradient_checkpointing: bool) -> nn.Module:
407
+ # random initialize
408
+ # Momentでは(言語で)事前学習済みのモデルを取得することもできるようになっている
409
+ # refers: https://github.com/moment-timeseries-foundation-model/moment/blob/088b253a1138ac7e48a7efc9bf902336c9eec8d9/momentfm/models/moment.py#L205
410
+ t5_config = T5Config.from_dict(config)
411
+ t5_model = T5Model(t5_config)
412
+ t5_model_encoder = t5_model.get_encoder()
413
+
414
+ if enable_gradient_checkpointing:
415
+ t5_model_encoder.gradient_checkpointing_enable()
416
+ logger.info("Enabling gradient checkpointing.")
417
+
418
+ return t5_model_encoder
419
+
420
+ def embed(
421
+ self,
422
+ x_enc: torch.Tensor,
423
+ input_mask: torch.Tensor = None,
424
+ reduction: str = "mean",
425
+ **kwargs,
426
+ ) -> TimeseriesOutputs:
427
+ batch_size, n_channels, seq_len = x_enc.shape
428
+
429
+ if input_mask is None:
430
+ input_mask = torch.ones((batch_size, seq_len)).to(x_enc.device)
431
+
432
+ x_enc = self.normalizer(x=x_enc, mask=input_mask, mode="norm")
433
+ x_enc = torch.nan_to_num(x_enc, nan=0, posinf=0, neginf=0)
434
+
435
+ input_mask_patch_view = Masking.convert_seq_to_patch_view(
436
+ input_mask, self.patch_len
437
+ )
438
+
439
+ x_enc = self.tokenizer(x=x_enc)
440
+ enc_in = self.patch_embedding(x_enc, mask=input_mask)
441
+
442
+ n_patches = enc_in.shape[2]
443
+ enc_in = enc_in.reshape(
444
+ (batch_size * n_channels, n_patches, self.config.d_model)
445
+ )
446
+
447
+ patch_view_mask = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
448
+ attention_mask = patch_view_mask.repeat_interleave(n_channels, dim=0)
449
+ outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)
450
+ enc_out = outputs.last_hidden_state
451
+
452
+ enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))
453
+ # [batch_size x n_channels x n_patches x d_model]
454
+
455
+ # For Mists model
456
+ # [batch_size, n_channels x n_patches, d_model]
457
+ hidden_states = enc_out.reshape(batch_size, n_channels * n_patches, self.config.d_model)
458
+
459
+ if reduction == "mean":
460
+ enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels
461
+ # [batch_size x n_patches x d_model]
462
+ input_mask_patch_view = input_mask_patch_view.unsqueeze(-1).repeat(
463
+ 1, 1, self.config.d_model
464
+ )
465
+ enc_out = (input_mask_patch_view * enc_out).sum(
466
+ dim=1
467
+ ) / input_mask_patch_view.sum(dim=1)
468
+ else:
469
+ raise NotImplementedError(f"Reduction method {reduction} not implemented.")
470
+
471
+ return TimeseriesOutputs(
472
+ embeddings=enc_out, input_mask=input_mask, metadata=reduction, hidden_states=hidden_states
473
+ )
474
+
475
+ def forward(
476
+ self,
477
+ time_series_values: torch.Tensor,
478
+ # mask: torch.Tensor = None,
479
+ input_mask: torch.Tensor = None,
480
+ **kwargs,
481
+ ) -> TimeseriesOutputs:
482
+ if input_mask is None:
483
+ input_mask = torch.ones_like(time_series_values[:, 0, :])
484
+
485
+ return self.embed(x_enc=time_series_values, input_mask=input_mask, **kwargs)
486
+
487
+
488
+ # refers: https://github.com/moment-timeseries-foundation-model/moment/blob/088b253a1138ac7e48a7efc9bf902336c9eec8d9/momentfm/models/moment.py#L601
489
+ def freeze_parameters(model):
490
+ """
491
+ Freeze parameters of the model
492
+ """
493
+ # Freeze the parameters
494
+ for name, param in model.named_parameters():
495
+ param.requires_grad = False
496
+
497
+ return model