razmars commited on
Commit
a63396c
·
verified ·
1 Parent(s): 302f34b

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +4 -3
modeling_super_linear.py CHANGED
@@ -476,13 +476,14 @@ class superLinear(nn.Module):
476
  ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
477
  out = torch.cat(outputs, dim=1)[:, :self.inf_pred_len]
478
 
479
- print(F"out1 :{out.shape}")
480
  if len(x_enc.shape) > 2:
481
- result = out
482
- else:
483
  out = out.reshape(B, V, out.shape[-1])
484
  print(F"out2 :{out.shape}")
485
  result = out.permute(0, 2, 1)
 
 
 
 
486
 
487
  if get_prob:
488
  expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
 
476
  ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
477
  out = torch.cat(outputs, dim=1)[:, :self.inf_pred_len]
478
 
 
479
  if len(x_enc.shape) > 2:
 
 
480
  out = out.reshape(B, V, out.shape[-1])
481
  print(F"out2 :{out.shape}")
482
  result = out.permute(0, 2, 1)
483
+ else:
484
+ print(F"out1 :{out.shape}")
485
+ result = out
486
+
487
 
488
  if get_prob:
489
  expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])