HachiML commited on
Commit
a05500d
1 Parent(s): 5f2d447

Upload modeling_moment.py

Browse files
Files changed (1) hide show
  1. modeling_moment.py +1 -0
modeling_moment.py CHANGED
@@ -456,6 +456,7 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
456
  # For Mists model
457
  # [batch_size, n_channels x n_patches, d_model]
458
  # hidden_states = enc_out.reshape(batch_size, n_channels * n_patches, self.config.d_model)
 
459
 
460
  if reduction == "mean":
461
  enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels
 
456
  # For Mists model
457
  # [batch_size, n_channels x n_patches, d_model]
458
  # hidden_states = enc_out.reshape(batch_size, n_channels * n_patches, self.config.d_model)
459
+ hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
460
 
461
  if reduction == "mean":
462
  enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels