HachiML commited on
Commit
39ad17e
1 Parent(s): d2a0d69

Upload modeling_moment.py

Browse files
Files changed (1) hide show
  1. modeling_moment.py +3 -0
modeling_moment.py CHANGED
@@ -478,7 +478,10 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
478
  input_mask_patch_view_for_hidden_states = input_mask_patch_view_for_hidden_states.unsqueeze(-1).repeat(
479
  1, n_channels, 1, self.config.d_model
480
  )
 
481
  hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
 
 
482
 
483
  return TimeseriesOutputs(
484
  embeddings=enc_out, input_mask=input_mask, metadata=reduction, hidden_states=hidden_states
 
478
  input_mask_patch_view_for_hidden_states = input_mask_patch_view_for_hidden_states.unsqueeze(-1).repeat(
479
  1, n_channels, 1, self.config.d_model
480
  )
481
+ # [batch_size x n_channels x n_patches x d_model]
482
  hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
483
+ # [batch_size, n_channels x n_patches, d_model]
484
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
485
 
486
  return TimeseriesOutputs(
487
  embeddings=enc_out, input_mask=input_mask, metadata=reduction, hidden_states=hidden_states