yixinsong committed
Commit 8f55951
1 Parent(s): d7c6bda
Files changed (2):
  1. config.json +1 -1
  2. modeling_supersparsemixtral.py +1 -1
config.json CHANGED
@@ -15,7 +15,7 @@
   "initializer_range": 0.02,
   "intermediate_size": 14336,
   "max_position_embeddings": 32768,
-  "model_type": "mixtral",
+  "model_type": "supersparsemixtral",
   "num_attention_heads": 32,
   "num_experts_per_tok": 2,
   "num_hidden_layers": 32,
modeling_supersparsemixtral.py CHANGED
@@ -1280,7 +1280,7 @@ class SuperSparseMixtralBlockSparseTop2MLP(nn.Module):
 
     def forward(self, hidden_states):
         mask = self.predictor(hidden_states)
-        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.act_fn(self.w3(hidden_states))
         hard_mask = torch.round(mask)
         mask = mask + (hard_mask - mask).detach()
         current_hidden_states = torch.mul(current_hidden_states, mask)
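The unchanged context lines around this fix apply a straight-through estimator: torch.round produces a hard 0/1 mask for the forward pass, while mask + (hard_mask - mask).detach() lets gradients flow through the soft predictor output, since round() has zero gradient almost everywhere. A minimal standalone sketch of that trick, with the predictor output simulated by a sigmoid over random logits (names here are illustrative, not from the repo):

import torch

# Soft mask in (0, 1), standing in for self.predictor(hidden_states).
logits = torch.randn(8, requires_grad=True)
soft_mask = torch.sigmoid(logits)

# Forward: the hard 0/1 mask is used. Backward: the detached term has no
# gradient, so the gradient is that of soft_mask alone.
hard_mask = torch.round(soft_mask)
ste_mask = soft_mask + (hard_mask - soft_mask).detach()

gated = torch.randn(8) * ste_mask  # gate some activations with the hard mask
gated.sum().backward()             # gradients still reach the soft logits
print(logits.grad)                 # non-zero despite the rounding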