razmars commited on
Commit
279789a
·
verified ·
1 Parent(s): 692fb2b

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +9 -6
modeling_super_linear.py CHANGED
@@ -525,9 +525,10 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
525
 
526
 
527
  # the backbone keeps its own Config dataclass, so build one on‑the‑fly:
528
- backbone_cfg = type("Cfg", (), config.to_dict())()
529
- self.args = backbone_cfg
530
- self.backbone = superLinear(backbone_cfg)
 
531
  self.post_init()
532
 
533
  # ------------------------------------------------------------------
@@ -589,15 +590,17 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
589
 
590
  # backbone expects (B, C, L)
591
  x_enc = inputs_embeds
 
 
592
 
593
  if x_enc.shape[1] < 512:
594
  x_enc = self.fourier_interp_dim1(x_enc)
595
- mean = x_enc.mean()
596
- std = x_enc.std().clamp_min(1e-6)
597
- x_enc = (x_enc - mean) / std
598
 
599
  # backbone returns (B, pred_len, C)
600
  preds = self.backbone(x_enc)
 
601
  return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)
602
 
603
 
 
525
 
526
 
527
  # the backbone keeps its own Config dataclass, so build one on‑the‑fly:
528
+ backbone_cfg = type("Cfg", (), config.to_dict())()
529
+ self.args = backbone_cfg
530
+ self.backbone = superLinear(backbone_cfg)
531
+ self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)
532
  self.post_init()
533
 
534
  # ------------------------------------------------------------------
 
590
 
591
  # backbone expects (B, C, L)
592
  x_enc = inputs_embeds
593
+
594
+
595
 
596
  if x_enc.shape[1] < 512:
597
  x_enc = self.fourier_interp_dim1(x_enc)
598
+ x_enc = self.revin_layer(x_enc, 'norm')
599
+
 
600
 
601
  # backbone returns (B, pred_len, C)
602
  preds = self.backbone(x_enc)
603
+ preds = self.revin_layer(preds, 'denorm')
604
  return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)
605
 
606