HachiML commited on
Commit
a3a58e4
1 Parent(s): 39ad17e

Upload modeling_moment.py

Browse files
Files changed (1) hide show
  1. modeling_moment.py +2 -3
modeling_moment.py CHANGED
@@ -470,15 +470,14 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
470
  # [batch_size, n_channels x n_patches, d_model]
471
  # Ensure hidden_states are consistent for both short and long inputs with input_mask specified
472
  # hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
473
- # [batch_size x n_channels x n_patches x d_model]
474
- hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model)
475
  # [batch_size x n_patches]
476
  input_mask_patch_view_for_hidden_states = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
477
  # [batch_size x n_channels x n_patches x d_model]
478
  input_mask_patch_view_for_hidden_states = input_mask_patch_view_for_hidden_states.unsqueeze(-1).repeat(
479
  1, n_channels, 1, self.config.d_model
480
  )
481
- # [batch_size x n_channels x n_patches x d_model]
 
482
  hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
483
  # [batch_size, n_channels x n_patches, d_model]
484
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
 
470
  # [batch_size, n_channels x n_patches, d_model]
471
  # Ensure hidden_states are consistent for both short and long inputs with input_mask specified
472
  # hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
 
 
473
  # [batch_size x n_patches]
474
  input_mask_patch_view_for_hidden_states = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
475
  # [batch_size x n_channels x n_patches x d_model]
476
  input_mask_patch_view_for_hidden_states = input_mask_patch_view_for_hidden_states.unsqueeze(-1).repeat(
477
  1, n_channels, 1, self.config.d_model
478
  )
479
+ # [batch_size x n_channels x n_patches x d_model]
480
+ hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model)
481
  hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
482
  # [batch_size, n_channels x n_patches, d_model]
483
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.config.d_model)