Upload modeling_moment.py
Browse files- modeling_moment.py +15 -0
modeling_moment.py
CHANGED
@@ -503,6 +503,21 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
|
|
503 |
input_mask = torch.ones_like(time_series_values[:, 0, :])
|
504 |
|
505 |
return self.embed(x_enc=time_series_values, input_mask=input_mask, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
506 |
|
507 |
|
508 |
# refers: https://github.com/moment-timeseries-foundation-model/moment/blob/088b253a1138ac7e48a7efc9bf902336c9eec8d9/momentfm/models/moment.py#L601
|
|
|
503 |
input_mask = torch.ones_like(time_series_values[:, 0, :])
|
504 |
|
505 |
return self.embed(x_enc=time_series_values, input_mask=input_mask, **kwargs)
|
506 |
+
|
507 |
+
def calculate_n_patches(self, seq_len: int) -> int:
|
508 |
+
"""
|
509 |
+
時系列の長さ(seq_len)を与えて、モデルのself.patch_lenとself.strideを使ってn_patchesを計算して返します。
|
510 |
+
strideがNoneの場合はpatch_lenを使用します。
|
511 |
+
|
512 |
+
Args:
|
513 |
+
seq_len (int): 時系列の長さ
|
514 |
+
|
515 |
+
Returns:
|
516 |
+
int: 計算されたn_patchesの数
|
517 |
+
"""
|
518 |
+
stride = self.stride if self.stride is not None else self.patch_len
|
519 |
+
n_patches = (seq_len - self.patch_len) // stride + 1
|
520 |
+
return n_patches
|
521 |
|
522 |
|
523 |
# refers: https://github.com/moment-timeseries-foundation-model/moment/blob/088b253a1138ac7e48a7efc9bf902336c9eec8d9/momentfm/models/moment.py#L601
|