Impulse2000 committed on
Commit
810a5ae
·
verified ·
1 Parent(s): a8f90b8

Upload sentiment-transformer model

Browse files
__pycache__/modeling_sentiment_transformer.cpython-314.pyc CHANGED
Binary files a/__pycache__/modeling_sentiment_transformer.cpython-314.pyc and b/__pycache__/modeling_sentiment_transformer.cpython-314.pyc differ
 
modeling_sentiment_transformer.py CHANGED
@@ -352,11 +352,17 @@ class SentimentTransformerForSequenceClassification(PreTrainedModel):
352
  freqs = torch.outer(t, inv_freq)
353
  module.cos_cached = freqs.cos()
354
  module.sin_cached = freqs.sin()
355
-
356
- def post_init(self) -> None:
357
- """Override HF's post_init to also recompute RoPE buffers."""
358
- super().post_init()
359
- self._recompute_rope_buffers()
 
 
 
 
 
 
360
 
361
  def forward(
362
  self,
@@ -367,6 +373,8 @@ class SentimentTransformerForSequenceClassification(PreTrainedModel):
367
  **_kwargs,
368
  ) -> SequenceClassifierOutput | tuple[torch.Tensor, ...]:
369
  """Run sequence classification and return HF-style outputs."""
 
 
370
  if input_ids is None:
371
  raise ValueError("`input_ids` is required.")
372
  if attention_mask is None:
 
352
  freqs = torch.outer(t, inv_freq)
353
  module.cos_cached = freqs.cos()
354
  module.sin_cached = freqs.sin()
355
+ self._rope_valid = True
356
+
357
+ def _ensure_rope_valid(self) -> None:
358
+ """Lazily recompute RoPE buffers if they were corrupted by HF loading."""
359
+ if not getattr(self, "_rope_valid", False):
360
+ # Check if the backbone's RoPE buffers contain valid data
361
+ rope = self.backbone.rope
362
+ if not rope.cos_cached.isfinite().all():
363
+ self._recompute_rope_buffers()
364
+ else:
365
+ self._rope_valid = True
366
 
367
  def forward(
368
  self,
 
373
  **_kwargs,
374
  ) -> SequenceClassifierOutput | tuple[torch.Tensor, ...]:
375
  """Run sequence classification and return HF-style outputs."""
376
+ self._ensure_rope_valid()
377
+
378
  if input_ids is None:
379
  raise ValueError("`input_ids` is required.")
380
  if attention_mask is None: