HachiML committed
Commit 33922ac
1 Parent(s): e3c91e6

Upload modeling_moment.py

Files changed (1)
  1. modeling_moment.py +4 -1
modeling_moment.py CHANGED
@@ -430,6 +430,8 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
         if input_mask is None:
             input_mask = torch.ones((batch_size, seq_len)).to(x_enc.device)
 
+        print("*input_mask: ", input_mask.shape)
+
         x_enc = self.normalizer(x=x_enc, mask=input_mask, mode="norm")
         x_enc = torch.nan_to_num(x_enc, nan=0, posinf=0, neginf=0)
 
@@ -472,6 +474,7 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
         # Ensure hidden_states are consistent for both short and long inputs with input_mask specified
         # hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
         # [batch_size x n_patches]
+        print("*input_mask: ", input_mask.shape)
         input_mask_patch_view_for_hidden_states = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
         # [batch_size x n_channels x n_patches x d_model]
         input_mask_patch_view_for_hidden_states = input_mask_patch_view_for_hidden_states.unsqueeze(-1).repeat(
@@ -480,7 +483,7 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
         # [batch_size x n_channels x n_patches x d_model]
         hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model)
         print("*input_mask_patch_view_for_hidden_states: ", input_mask_patch_view_for_hidden_states.shape)
-        print("hidden_states: ", hidden_states.shape)
+        print("hidden_states: ", hidden_states.shape)
         hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
         # [batch_size, n_channels x n_patches, d_model]
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
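For context, the hunks above add shape-debugging prints around a mask-aware zeroing of per-patch hidden states. Below is a minimal, self-contained sketch of that shape flow, assuming patch_len evenly divides seq_len. The convert_seq_to_patch_view stand-in (a patch counts as observed only if every one of its timesteps is observed) is a hypothetical simplification of MOMENT's Masking helper, and the sketch uses broadcasting where the original uses unsqueeze(-1).repeat(...), which is numerically equivalent for this elementwise product.

import torch

# Hypothetical stand-in for Masking.convert_seq_to_patch_view (assumption:
# a patch is marked observed only if every timestep inside it is observed).
def convert_seq_to_patch_view(input_mask, patch_len):
    batch_size, seq_len = input_mask.shape
    patches = input_mask.reshape(batch_size, seq_len // patch_len, patch_len)
    return patches.min(dim=-1).values  # [batch_size x n_patches]

batch_size, n_channels, seq_len = 2, 3, 512
patch_len, d_model = 8, 64
n_patches = seq_len // patch_len  # 64

input_mask = torch.ones(batch_size, seq_len)                   # [batch_size x seq_len]
hidden_states = torch.randn(batch_size, n_channels, n_patches, d_model)

print("*input_mask: ", input_mask.shape)                       # torch.Size([2, 512])

# [batch_size x n_patches] -> [batch_size x 1 x n_patches x 1], which
# broadcasts against hidden_states instead of materializing the repeat.
mask_patch_view = convert_seq_to_patch_view(input_mask, patch_len)
mask_patch_view = mask_patch_view.unsqueeze(1).unsqueeze(-1)

print("*input_mask_patch_view_for_hidden_states: ", mask_patch_view.shape)
print("hidden_states: ", hidden_states.shape)                  # torch.Size([2, 3, 64, 64])

# Zero out hidden states of padded patches, then flatten channels and
# patches into one axis: [batch_size, n_patches * n_channels, d_model].
hidden_states = mask_patch_view * hidden_states
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, d_model)
print("flattened hidden_states: ", hidden_states.shape)        # torch.Size([2, 192, 64])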