Update modeling_super_linear.py
Browse files
- modeling_super_linear.py +9 -6
modeling_super_linear.py
CHANGED
|
@@ -525,9 +525,10 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 525 |
|
| 526 |
|
| 527 |
# the backbone keeps its own Config dataclass, so build one on‑the‑fly:
|
| 528 |
-
backbone_cfg
|
| 529 |
-
self.args
|
| 530 |
-
self.backbone
|
|
|
|
| 531 |
self.post_init()
|
| 532 |
|
| 533 |
# ------------------------------------------------------------------
|
|
@@ -589,15 +590,17 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 589 |
|
| 590 |
# backbone expects (B, C, L)
|
| 591 |
x_enc = inputs_embeds
|
|
|
|
|
|
|
| 592 |
|
| 593 |
if x_enc.shape[1] < 512:
|
| 594 |
x_enc = self.fourier_interp_dim1(x_enc)
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
x_enc = (x_enc - mean) / std
|
| 598 |
|
| 599 |
# backbone returns (B, pred_len, C)
|
| 600 |
preds = self.backbone(x_enc)
|
|
|
|
| 601 |
return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)
|
| 602 |
|
| 603 |
|
|
|
|
| 525 |
|
| 526 |
|
| 527 |
# the backbone keeps its own Config dataclass, so build one on‑the‑fly:
|
| 528 |
+
backbone_cfg = type("Cfg", (), config.to_dict())()
|
| 529 |
+
self.args = backbone_cfg
|
| 530 |
+
self.backbone = superLinear(backbone_cfg)
|
| 531 |
+
self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)
|
| 532 |
self.post_init()
|
| 533 |
|
| 534 |
# ------------------------------------------------------------------
|
|
|
|
| 590 |
|
| 591 |
# backbone expects (B, C, L)
|
| 592 |
x_enc = inputs_embeds
|
| 593 |
+
|
| 594 |
+
|
| 595 |
|
| 596 |
if x_enc.shape[1] < 512:
|
| 597 |
x_enc = self.fourier_interp_dim1(x_enc)
|
| 598 |
+
x_enc = self.revin_layer(x_enc, 'norm')
|
| 599 |
+
|
|
|
|
| 600 |
|
| 601 |
# backbone returns (B, pred_len, C)
|
| 602 |
preds = self.backbone(x_enc)
|
| 603 |
+
preds = self.revin_layer(preds, 'denorm')
|
| 604 |
return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)
|
| 605 |
|
| 606 |
|