Upload modeling_moment.py
Browse files- modeling_moment.py +0 -6
modeling_moment.py
CHANGED
@@ -430,8 +430,6 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
|
|
430 |
if input_mask is None:
|
431 |
input_mask = torch.ones((batch_size, seq_len)).to(x_enc.device)
|
432 |
|
433 |
-
print("*input_mask: ", input_mask.shape)
|
434 |
-
|
435 |
x_enc = self.normalizer(x=x_enc, mask=input_mask, mode="norm")
|
436 |
x_enc = torch.nan_to_num(x_enc, nan=0, posinf=0, neginf=0)
|
437 |
|
@@ -474,17 +472,13 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
|
|
474 |
# Ensure hidden_states are consistent for both short and long inputs with input_mask specified
|
475 |
# hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
|
476 |
# [batch_size x n_patches]
|
477 |
-
print("*input_mask: ", input_mask.shape)
|
478 |
input_mask_patch_view_for_hidden_states = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
|
479 |
# [batch_size x n_channels x n_patches x d_model]
|
480 |
-
print("*input_mask_patch_view_for_hidden_states: ", input_mask_patch_view_for_hidden_states.shape)
|
481 |
input_mask_patch_view_for_hidden_states = input_mask_patch_view_for_hidden_states.unsqueeze(1).unsqueeze(-1).repeat(
|
482 |
1, n_channels, 1, self.config.d_model
|
483 |
)
|
484 |
# [batch_size x n_channels x n_patches x d_model]
|
485 |
hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model)
|
486 |
-
print("*input_mask_patch_view_for_hidden_states: ", input_mask_patch_view_for_hidden_states.shape)
|
487 |
-
print("*hidden_states: ", hidden_states.shape)
|
488 |
hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
|
489 |
# [batch_size, n_channels x n_patches, d_model]
|
490 |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
|
|
|
430 |
if input_mask is None:
|
431 |
input_mask = torch.ones((batch_size, seq_len)).to(x_enc.device)
|
432 |
|
|
|
|
|
433 |
x_enc = self.normalizer(x=x_enc, mask=input_mask, mode="norm")
|
434 |
x_enc = torch.nan_to_num(x_enc, nan=0, posinf=0, neginf=0)
|
435 |
|
|
|
472 |
# Ensure hidden_states are consistent for both short and long inputs with input_mask specified
|
473 |
# hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
|
474 |
# [batch_size x n_patches]
|
|
|
475 |
input_mask_patch_view_for_hidden_states = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
|
476 |
# [batch_size x n_channels x n_patches x d_model]
|
|
|
477 |
input_mask_patch_view_for_hidden_states = input_mask_patch_view_for_hidden_states.unsqueeze(1).unsqueeze(-1).repeat(
|
478 |
1, n_channels, 1, self.config.d_model
|
479 |
)
|
480 |
# [batch_size x n_channels x n_patches x d_model]
|
481 |
hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model)
|
|
|
|
|
482 |
hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
|
483 |
# [batch_size, n_channels x n_patches, d_model]
|
484 |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
|