HachiML commited on
Commit
d2a0d69
1 Parent(s): d06b53f

Upload modeling_moment.py

Browse files
Files changed (1) hide show
  1. modeling_moment.py +15 -5
modeling_moment.py CHANGED
@@ -432,6 +432,7 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
432
  x_enc = self.normalizer(x=x_enc, mask=input_mask, mode="norm")
433
  x_enc = torch.nan_to_num(x_enc, nan=0, posinf=0, neginf=0)
434
 
 
435
  input_mask_patch_view = Masking.convert_seq_to_patch_view(
436
  input_mask, self.patch_len
437
  )
@@ -453,11 +454,6 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
453
  enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))
454
  # [batch_size x n_channels x n_patches x d_model]
455
 
456
- # For Mists model
457
- # [batch_size, n_channels x n_patches, d_model]
458
- # Ensure hidden_states are consistent for both short and long inputs with input_mask specified
459
- hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
460
-
461
  if reduction == "mean":
462
  enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels
463
  # [batch_size x n_patches x d_model]
@@ -469,6 +465,20 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
469
  ) / input_mask_patch_view.sum(dim=1)
470
  else:
471
  raise NotImplementedError(f"Reduction method {reduction} not implemented.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
 
473
  return TimeseriesOutputs(
474
  embeddings=enc_out, input_mask=input_mask, metadata=reduction, hidden_states=hidden_states
 
432
  x_enc = self.normalizer(x=x_enc, mask=input_mask, mode="norm")
433
  x_enc = torch.nan_to_num(x_enc, nan=0, posinf=0, neginf=0)
434
 
435
+ # [batch_size x n_patches]
436
  input_mask_patch_view = Masking.convert_seq_to_patch_view(
437
  input_mask, self.patch_len
438
  )
 
454
  enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))
455
  # [batch_size x n_channels x n_patches x d_model]
456
 
 
 
 
 
 
457
  if reduction == "mean":
458
  enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels
459
  # [batch_size x n_patches x d_model]
 
465
  ) / input_mask_patch_view.sum(dim=1)
466
  else:
467
  raise NotImplementedError(f"Reduction method {reduction} not implemented.")
468
+
469
+ # For Mists model
470
+ # [batch_size, n_channels x n_patches, d_model]
471
+ # Ensure hidden_states are consistent for both short and long inputs with input_mask specified
472
+ # hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
473
+ # [batch_size x n_channels x n_patches x d_model]
474
+ hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model)
475
+ # [batch_size x n_patches]
476
+ input_mask_patch_view_for_hidden_states = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
477
+ # [batch_size x n_channels x n_patches x d_model]
478
+ input_mask_patch_view_for_hidden_states = input_mask_patch_view_for_hidden_states.unsqueeze(-1).repeat(
479
+ 1, n_channels, 1, self.config.d_model
480
+ )
481
+ hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
482
 
483
  return TimeseriesOutputs(
484
  embeddings=enc_out, input_mask=input_mask, metadata=reduction, hidden_states=hidden_states