Search committed on
Commit
20ccbfa
·
1 Parent(s): 4cbc2ee

auto: sync run_qwen_phase_probe.py

Browse files
scripts/run_qwen_phase_probe.py CHANGED
@@ -60,6 +60,8 @@ from src.utils.anchor_geometry import (
60
  match_anchor_span,
61
  token_has_leading_whitespace,
62
  )
 
 
63
 
64
  # ─────────────────────────────────────────────────────────────────────────────
65
  # Constants
@@ -347,8 +349,9 @@ def generate_base_text(
347
  if tokenizer is None:
348
  raise ValueError("tokenizer is required")
349
  device = next(overlay.parameters()).device
 
350
  encoded = tokenizer(
351
- [prompt],
352
  truncation=True,
353
  max_length=MAX_LENGTH,
354
  return_tensors="pt",
@@ -627,6 +630,7 @@ def run(
627
 
628
 
629
  def main() -> None:
 
630
  parser = argparse.ArgumentParser(description="ABPT Phase Probe — Фаза 1 верификации геометрии")
631
  parser.add_argument("--model", default="Qwen/Qwen3.5-4B",
632
  help="HuggingFace model name")
 
60
  match_anchor_span,
61
  token_has_leading_whitespace,
62
  )
63
+ from src.utils.qwen_prompting import format_generation_prompt
64
+ from src.utils.stdio import configure_utf8_stdio
65
 
66
  # ─────────────────────────────────────────────────────────────────────────────
67
  # Constants
 
349
  if tokenizer is None:
350
  raise ValueError("tokenizer is required")
351
  device = next(overlay.parameters()).device
352
+ generation_prompt = format_generation_prompt(tokenizer, prompt)
353
  encoded = tokenizer(
354
+ [generation_prompt],
355
  truncation=True,
356
  max_length=MAX_LENGTH,
357
  return_tensors="pt",
 
630
 
631
 
632
  def main() -> None:
633
+ configure_utf8_stdio()
634
  parser = argparse.ArgumentParser(description="ABPT Phase Probe — Фаза 1 верификации геометрии")
635
  parser.add_argument("--model", default="Qwen/Qwen3.5-4B",
636
  help="HuggingFace model name")
src/model/qwen_anchor_overlay.py CHANGED
@@ -53,6 +53,7 @@ from src.utils.anchor_geometry import (
53
  match_anchor_span,
54
  select_tail_probe_layers,
55
  )
 
56
  from src.data.qwen_anchor_geometry_cases import QwenAnchorGeometryCase, make_qwen_anchor_geometry_cases
57
 
58
 
@@ -853,8 +854,9 @@ class QwenAnchorOverlay(nn.Module):
853
  if self.tokenizer is None:
854
  raise ValueError("tokenizer is required for generate_with_anchor_bias")
855
 
 
856
  encoded = self.tokenizer(
857
- [prompt],
858
  padding=True,
859
  truncation=True,
860
  max_length=max_length,
@@ -1235,8 +1237,9 @@ class QwenAnchorOverlay(nn.Module):
1235
  ) -> dict[str, Any]:
1236
  if self.tokenizer is None:
1237
  raise ValueError("tokenizer is required for _generate_trust_completion")
 
1238
  encoded = self.tokenizer(
1239
- [prompt],
1240
  padding=True,
1241
  truncation=True,
1242
  max_length=max_length,
 
53
  match_anchor_span,
54
  select_tail_probe_layers,
55
  )
56
+ from src.utils.qwen_prompting import format_generation_prompt
57
  from src.data.qwen_anchor_geometry_cases import QwenAnchorGeometryCase, make_qwen_anchor_geometry_cases
58
 
59
 
 
854
  if self.tokenizer is None:
855
  raise ValueError("tokenizer is required for generate_with_anchor_bias")
856
 
857
+ generation_prompt = format_generation_prompt(self.tokenizer, prompt)
858
  encoded = self.tokenizer(
859
+ [generation_prompt],
860
  padding=True,
861
  truncation=True,
862
  max_length=max_length,
 
1237
  ) -> dict[str, Any]:
1238
  if self.tokenizer is None:
1239
  raise ValueError("tokenizer is required for _generate_trust_completion")
1240
+ generation_prompt = format_generation_prompt(self.tokenizer, prompt)
1241
  encoded = self.tokenizer(
1242
+ [generation_prompt],
1243
  padding=True,
1244
  truncation=True,
1245
  max_length=max_length,
src/utils/qwen_prompting.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+
6
+ def format_generation_prompt(
7
+ tokenizer: Any | None,
8
+ prompt: str,
9
+ *,
10
+ disable_thinking: bool = True,
11
+ ) -> str:
12
+ if tokenizer is None:
13
+ return prompt
14
+ apply_chat_template = getattr(tokenizer, "apply_chat_template", None)
15
+ if not callable(apply_chat_template):
16
+ return prompt
17
+ messages = [{"role": "user", "content": prompt}]
18
+ kwargs: dict[str, Any] = {
19
+ "tokenize": False,
20
+ "add_generation_prompt": True,
21
+ }
22
+ if disable_thinking:
23
+ kwargs["enable_thinking"] = False
24
+ try:
25
+ rendered = apply_chat_template(messages, **kwargs)
26
+ except TypeError:
27
+ kwargs.pop("enable_thinking", None)
28
+ try:
29
+ rendered = apply_chat_template(messages, **kwargs)
30
+ except Exception:
31
+ return prompt
32
+ except Exception:
33
+ return prompt
34
+ return str(rendered)
src/utils/stdio.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+
5
+
6
+ def configure_utf8_stdio() -> None:
7
+ for stream_name in ("stdout", "stderr"):
8
+ stream = getattr(sys, stream_name, None)
9
+ reconfigure = getattr(stream, "reconfigure", None)
10
+ if callable(reconfigure):
11
+ reconfigure(encoding="utf-8", errors="replace")