jonathanjordan21 committed on
Commit
0dd1cff
·
verified ·
1 Parent(s): c2627ca

Update modeling_mos_mamba.py

Browse files
Files changed (1) hide show
  1. modeling_mos_mamba.py +6 -6
modeling_mos_mamba.py CHANGED
@@ -604,12 +604,12 @@ class MoSMambaMixer(nn.Module):
604
  # expert_layer.grad = torch.zeros_like(expert_layer.weight)
605
  # current_hidden_states = expert_layer(current_state)
606
 
607
- current_hidden_states = current_hidden_states.reshape(-1, hidden_dim)
608
- # print(current_hidden_states.shape, final_hidden_states.shape)
609
-
610
- # However `index_add_` only support torch tensors for indexing so we'll use
611
- # the `top_x` tensor here.
612
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
613
  final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
614
 
615
  return final_hidden_states, router_logits
 
604
  # expert_layer.grad = torch.zeros_like(expert_layer.weight)
605
  # current_hidden_states = expert_layer(current_state)
606
 
607
+ current_hidden_states = current_hidden_states.reshape(-1, hidden_dim)
608
+ # print(current_hidden_states.shape, final_hidden_states.shape)
609
+
610
+ # However `index_add_` only support torch tensors for indexing so we'll use
611
+ # the `top_x` tensor here.
612
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
613
  final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
614
 
615
  return final_hidden_states, router_logits