razmars commited on
Commit
aa6df5b
·
verified ·
1 Parent(s): 934cf7d

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +5 -1
modeling_super_linear.py CHANGED
@@ -452,11 +452,13 @@ class superLinear(nn.Module):
452
  x = x_enc.permute(0, 2, 1)
453
  B, V, L = x.shape
454
  else:
455
- x = x_enc
456
  B, L = x.shape
457
  V = 1
458
 
459
  x = x.reshape(B * V, L)
 
 
460
 
461
  expert_probs = None
462
 
@@ -473,7 +475,9 @@ class superLinear(nn.Module):
473
  outputs.append(ar_out)
474
  ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
475
  out = torch.cat(outputs, dim=1)[:, :self.inf_pred_len]
 
476
  out = out.reshape(B, V, out.shape[-1])
 
477
  result = out.permute(0, 2, 1)
478
 
479
  if get_prob:
 
452
  x = x_enc.permute(0, 2, 1)
453
  B, V, L = x.shape
454
  else:
455
+ x = x_enc
456
  B, L = x.shape
457
  V = 1
458
 
459
  x = x.reshape(B * V, L)
460
+ print("RAZ")
461
+ print(x.shape)
462
 
463
  expert_probs = None
464
 
 
475
  outputs.append(ar_out)
476
  ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
477
  out = torch.cat(outputs, dim=1)[:, :self.inf_pred_len]
478
+ print(F"out1 :{out.shape}")
479
  out = out.reshape(B, V, out.shape[-1])
480
+ print(F"out2 :{out.shape}")
481
  result = out.permute(0, 2, 1)
482
 
483
  if get_prob: