Maple728 commited on
Commit
105fd0a
·
verified ·
1 Parent(s): ee8156c

Update modeling_time_moe.py

Browse files
Files changed (1) hide show
  1. modeling_time_moe.py +3 -6
modeling_time_moe.py CHANGED
@@ -25,6 +25,7 @@ try:
25
  except:
26
  pass
27
 
 
28
  def _get_unpad_data(attention_mask):
29
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
30
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
@@ -66,7 +67,7 @@ def load_balancing_loss_func(
66
  The auxiliary loss.
67
  """
68
  if gate_logits is None or not isinstance(gate_logits, (tuple, list)) or gate_logits[0] is None:
69
- return None
70
 
71
  compute_device = gate_logits[0].device
72
  concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
@@ -293,7 +294,7 @@ class TimeMoeSparseExpertsLayer(nn.Module):
293
  """ """
294
  batch_size, sequence_length, hidden_dim = hidden_states.shape
295
  hidden_states = hidden_states.view(-1, hidden_dim)
296
- # router_logits: (batch * sequence_length, n_experts)
297
  router_logits = self.gate(hidden_states)
298
 
299
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
@@ -764,8 +765,6 @@ class TimeMoeModel(TimeMoePreTrainedModel):
764
 
765
  def __init__(self, config: TimeMoeConfig):
766
  super().__init__(config)
767
- # self.padding_idx = config.pad_token_id
768
-
769
  self.embed_layer = TimeMoeInputEmbedding(config)
770
  self.layers = nn.ModuleList(
771
  [TimeMoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
@@ -1096,12 +1095,10 @@ class TimeMoeForPrediction(TimeMoePreTrainedModel, TSGenerationMixin):
1096
  shift_labels = labels
1097
 
1098
  # Calculate loss with mask
1099
- # losses = self.loss_function(shift_predictions.to(torch.float32), shift_labels.to(torch.float32))
1100
  losses = self.loss_function(shift_predictions, shift_labels)
1101
 
1102
  if loss_masks is not None:
1103
  losses = losses * loss_masks
1104
-
1105
  loss = losses.sum() / loss_masks.sum()
1106
  else:
1107
  loss = torch.mean(losses)
 
25
  except:
26
  pass
27
 
28
+
29
  def _get_unpad_data(attention_mask):
30
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
31
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
 
67
  The auxiliary loss.
68
  """
69
  if gate_logits is None or not isinstance(gate_logits, (tuple, list)) or gate_logits[0] is None:
70
+ return 0.0
71
 
72
  compute_device = gate_logits[0].device
73
  concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
 
294
  """ """
295
  batch_size, sequence_length, hidden_dim = hidden_states.shape
296
  hidden_states = hidden_states.view(-1, hidden_dim)
297
+ # router_logits -> (batch * sequence_length, n_experts)
298
  router_logits = self.gate(hidden_states)
299
 
300
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
 
765
 
766
  def __init__(self, config: TimeMoeConfig):
767
  super().__init__(config)
 
 
768
  self.embed_layer = TimeMoeInputEmbedding(config)
769
  self.layers = nn.ModuleList(
770
  [TimeMoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
 
1095
  shift_labels = labels
1096
 
1097
  # Calculate loss with mask
 
1098
  losses = self.loss_function(shift_predictions, shift_labels)
1099
 
1100
  if loss_masks is not None:
1101
  losses = losses * loss_masks
 
1102
  loss = losses.sum() / loss_masks.sum()
1103
  else:
1104
  loss = torch.mean(losses)