Update modeling_super_linear.py
Browse files- modeling_super_linear.py +5 -1
modeling_super_linear.py
CHANGED
|
@@ -452,11 +452,13 @@ class superLinear(nn.Module):
|
|
| 452 |
x = x_enc.permute(0, 2, 1)
|
| 453 |
B, V, L = x.shape
|
| 454 |
else:
|
| 455 |
-
x
|
| 456 |
B, L = x.shape
|
| 457 |
V = 1
|
| 458 |
|
| 459 |
x = x.reshape(B * V, L)
|
|
|
|
|
|
|
| 460 |
|
| 461 |
expert_probs = None
|
| 462 |
|
|
@@ -473,7 +475,9 @@ class superLinear(nn.Module):
|
|
| 473 |
outputs.append(ar_out)
|
| 474 |
ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
|
| 475 |
out = torch.cat(outputs, dim=1)[:, :self.inf_pred_len]
|
|
|
|
| 476 |
out = out.reshape(B, V, out.shape[-1])
|
|
|
|
| 477 |
result = out.permute(0, 2, 1)
|
| 478 |
|
| 479 |
if get_prob:
|
|
|
|
| 452 |
x = x_enc.permute(0, 2, 1)
|
| 453 |
B, V, L = x.shape
|
| 454 |
else:
|
| 455 |
+
x = x_enc
|
| 456 |
B, L = x.shape
|
| 457 |
V = 1
|
| 458 |
|
| 459 |
x = x.reshape(B * V, L)
|
| 460 |
+
print("RAZ")
|
| 461 |
+
print(x.shape)
|
| 462 |
|
| 463 |
expert_probs = None
|
| 464 |
|
|
|
|
| 475 |
outputs.append(ar_out)
|
| 476 |
ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.seq_len:]
|
| 477 |
out = torch.cat(outputs, dim=1)[:, :self.inf_pred_len]
|
| 478 |
+
print(F"out1 :{out.shape}")
|
| 479 |
out = out.reshape(B, V, out.shape[-1])
|
| 480 |
+
print(F"out2 :{out.shape}")
|
| 481 |
result = out.permute(0, 2, 1)
|
| 482 |
|
| 483 |
if get_prob:
|