HachiML commited on
Commit
1ad822b
1 Parent(s): e5119d1

Upload modeling_moment.py

Browse files
Files changed (1) hide show
  1. modeling_moment.py +4 -1
modeling_moment.py CHANGED
@@ -449,12 +449,15 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
449
  outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)
450
  enc_out = outputs.last_hidden_state
451
 
 
 
 
452
  enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))
453
  # [batch_size x n_channels x n_patches x d_model]
454
 
455
  # For Mists model
456
  # [batch_size, n_channels x n_patches, d_model]
457
- hidden_states = enc_out.reshape(batch_size, n_channels * n_patches, self.config.d_model)
458
 
459
  if reduction == "mean":
460
  enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels
 
449
  outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)
450
  enc_out = outputs.last_hidden_state
451
 
452
+ # For Mists model
453
+ hidden_states = outputs.last_hidden_state
454
+
455
  enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))
456
  # [batch_size x n_channels x n_patches x d_model]
457
 
458
  # For Mists model
459
  # [batch_size, n_channels x n_patches, d_model]
460
+ # hidden_states = enc_out.reshape(batch_size, n_channels * n_patches, self.config.d_model)
461
 
462
  if reduction == "mean":
463
  enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels