HachiML commited on
Commit
e41400f
1 Parent(s): fa69247

Upload modeling_moment.py

Browse files
Files changed (1) hide show
  1. modeling_moment.py +2 -1
modeling_moment.py CHANGED
@@ -448,13 +448,14 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
448
  attention_mask = patch_view_mask.repeat_interleave(n_channels, dim=0)
449
  outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)
450
  enc_out = outputs.last_hidden_state
 
451
 
452
  enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))
453
  # [batch_size x n_channels x n_patches x d_model]
454
 
455
  # For Mists model
456
  # [batch_size, n_channels x n_patches, d_model]
457
- hidden_states = enc_out.reshape(batch_size, n_channels * n_patches, self.config.d_model)
458
 
459
  if reduction == "mean":
460
  enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels
 
448
  attention_mask = patch_view_mask.repeat_interleave(n_channels, dim=0)
449
  outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)
450
  enc_out = outputs.last_hidden_state
451
+ hidden_states = outputs.hidden_states # hidden_statesを取得
452
 
453
  enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))
454
  # [batch_size x n_channels x n_patches x d_model]
455
 
456
  # For Mists model
457
  # [batch_size, n_channels x n_patches, d_model]
458
+ # hidden_states = enc_out.reshape(batch_size, n_channels * n_patches, self.config.d_model)
459
 
460
  if reduction == "mean":
461
  enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels