HachiML commited on
Commit
15d09b0
1 Parent(s): 1ad822b

Upload modeling_moment.py

Browse files
Files changed (1) hide show
  1. modeling_moment.py +1 -4
modeling_moment.py CHANGED
@@ -449,15 +449,12 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
449
  outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)
450
  enc_out = outputs.last_hidden_state
451
 
452
- # For Mists model
453
- hidden_states = outputs.last_hidden_state
454
-
455
  enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))
456
  # [batch_size x n_channels x n_patches x d_model]
457
 
458
  # For Mists model
459
  # [batch_size, n_channels x n_patches, d_model]
460
- # hidden_states = enc_out.reshape(batch_size, n_channels * n_patches, self.config.d_model)
461
 
462
  if reduction == "mean":
463
  enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels
 
449
  outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)
450
  enc_out = outputs.last_hidden_state
451
 
 
 
 
452
  enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))
453
  # [batch_size x n_channels x n_patches x d_model]
454
 
455
  # For Mists model
456
  # [batch_size, n_channels x n_patches, d_model]
457
+ hidden_states = enc_out.reshape(batch_size, n_channels * n_patches, self.config.d_model)
458
 
459
  if reduction == "mean":
460
  enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels