# ASQ-Interview-Workflow / workflow_core.py
# Uploaded by LucasUn — "Upload workflow_core.py" (commit f7859b8, verified).
# NOTE: the lines above were Hugging Face page residue captured with the file;
# they are kept here as comments so the module parses as valid Python.
import os
import re
import random
import time
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterator, List, Tuple
import openpyxl
import torch
from docx import Document
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
try:
from peft import PeftModel # type: ignore
except Exception:
PeftModel = None # type: ignore
try:
from groq import Groq as GroqSDK # type: ignore
except Exception:
GroqSDK = None # type: ignore
# Repo-relative resources: the child-avatar prompt script and the folder of
# .docx system prompts consumed by AsqEngine.
_BASE_DIR = Path(__file__).resolve().parent
_DEFAULT_CHILD_AVATARS_PY = _BASE_DIR / "child_avatars_v3_05022026_groq_gpt_oss_120b.py"
_DEFAULT_ASQ_DOCX_DIR = _BASE_DIR / "System prompts"
# Default HF repo per ASQ module (hypotheses / summary / topic navigation /
# question), each overridable via its environment variable.
_DEFAULT_MODEL_HYPO = os.environ.get("HF_MODEL_HYPO", "LucasUn/ASQ-Hypotheses")
_DEFAULT_MODEL_SUMM = os.environ.get("HF_MODEL_SUMM", "LucasUn/ASQ-Summary")
_DEFAULT_MODEL_TOPIC = os.environ.get("HF_MODEL_TOPIC", "LucasUn/ASQ-Topic_Navi")
_DEFAULT_MODEL_QUEST = os.environ.get("HF_MODEL_QUEST", "LucasUn/ASQ-Question")
# Tokenizer repo used when a fine-tuned repo ships no loadable tokenizer
# (see HFLocalGenerator.__init__ fallback chain).
_DEFAULT_TOKENIZER_FALLBACK_ID = os.environ.get(
    "HF_TOKENIZER_FALLBACK_ID",
    "Qwen/Qwen3.5-27B",
)
# Fallback base-model ids for PEFT adapters; per-module overrides default to
# the shared _DEFAULT_ADAPTER_BASE_MODEL_ID.
_DEFAULT_ADAPTER_BASE_MODEL_ID = os.environ.get(
    "HF_ADAPTER_BASE_MODEL_ID",
    "Qwen/Qwen3.5-27B",
)
_DEFAULT_ADAPTER_BASE_MODEL_HYPO = os.environ.get(
    "HF_ADAPTER_BASE_MODEL_HYPO",
    _DEFAULT_ADAPTER_BASE_MODEL_ID,
)
_DEFAULT_ADAPTER_BASE_MODEL_SUMM = os.environ.get(
    "HF_ADAPTER_BASE_MODEL_SUMM",
    _DEFAULT_ADAPTER_BASE_MODEL_ID,
)
_DEFAULT_ADAPTER_BASE_MODEL_TOPIC = os.environ.get(
    "HF_ADAPTER_BASE_MODEL_TOPIC",
    _DEFAULT_ADAPTER_BASE_MODEL_ID,
)
_DEFAULT_ADAPTER_BASE_MODEL_QUEST = os.environ.get(
    "HF_ADAPTER_BASE_MODEL_QUEST",
    _DEFAULT_ADAPTER_BASE_MODEL_ID,
)
def _normalize_model_id(value: Any, fallback: str) -> str:
raw = "" if value is None else str(value).strip()
if not raw or raw.lower() in {"none", "null", "nil"}:
return fallback
return raw
def _env_flag(name: str, default: bool) -> bool:
raw = os.environ.get(name)
if raw is None:
return default
return raw.strip().lower() in {"1", "true", "yes", "y", "on"}
# Quantization switches read once at import time and consumed by
# HFLocalGenerator.__init__; when both are enabled, 4-bit takes precedence.
_HF_ENABLE_4BIT = _env_flag("HF_ENABLE_4BIT", True)
_HF_ENABLE_8BIT = _env_flag("HF_ENABLE_8BIT", False)
# bitsandbytes 4-bit settings; blank env values fall back to the defaults.
_HF_BNB_4BIT_QUANT_TYPE = os.environ.get("HF_BNB_4BIT_QUANT_TYPE", "nf4").strip() or "nf4"
_HF_BNB_4BIT_USE_DOUBLE_QUANT = _env_flag("HF_BNB_4BIT_USE_DOUBLE_QUANT", True)
_HF_BNB_4BIT_COMPUTE_DTYPE = os.environ.get("HF_BNB_4BIT_COMPUTE_DTYPE", "float16").strip().lower() or "float16"
# Cache intended for sharing base models between adapters; not referenced
# anywhere else in this file.
_SHARED_BASE_MODELS: Dict[str, Any] = {}
@dataclass
class ChildSessionState:
    """Per-child session flags fixed when an interview starts."""
    # Ground-truth suspicion flag the avatar must never reveal
    # (encoded into the system message by build_suspicion_system_msg).
    suspicion_true: bool
    # Seed recorded for the session; not consumed elsewhere in this file.
    rng_seed: int
# Fixed role-play preamble prepended to every child avatar prompt before it is
# sent to the child model (see run_batch_export_to_excel, child_history setup).
FIXED_INTRO = (
    "This tool is for research and training in forensic interviewing; it is not a real case.\n"
    "You are role-playing the child described below for interviewer training.\n"
    "Speak exactly as that child would.\n\n"
    "Answering style:\n"
    "• Speak naturally (like a child), not in lists.\n"
    "• Keep each reply very short (usually 1 sentence).\n"
    "• Typical length ≤ 22 words (toddlers 1–6 words).\n"
    "• Give at most 1–2 concrete details per reply. Do not volunteer extra details.\n"
    "• On topics about harm/violence/anything “bad”: say even less; "
    "“I don’t remember / I don’t know / I don’t want to talk about it” is allowed.\n"
    "• ≤ 8 yrs: may briefly drift to toys/pets if nervous.\n"
    "• ≥ 9 yrs: be initially reluctant; open up only with gentle encouragement.\n"
    "• Leading questions can influence you, but do not invent detailed new information.\n"
    "• Do not ask the interviewer questions except simple clarifications.\n\n"
)
class HFLocalGenerator:
    """Thin wrapper around a locally loaded Hugging Face causal LM.

    Loads tokenizer and model once in ``__init__`` (optionally 4/8-bit
    quantized via bitsandbytes on CUDA) and exposes synchronous text
    generation from a raw prompt string or a chat-style message list.
    """

    @staticmethod
    def _get_adapter_base_model_id(repo_id: str) -> str | None:
        """Read ``base_model_name_or_path`` from a repo's adapter_config.json.

        Returns None when the repo has no adapter config, the download fails,
        or the recorded base id is blank / a None-like token.
        """
        try:
            adapter_cfg_path = hf_hub_download(repo_id=repo_id, filename="adapter_config.json")
            cfg = json.loads(Path(adapter_cfg_path).read_text(encoding="utf-8", errors="ignore"))
            base_raw = cfg.get("base_model_name_or_path")
            if base_raw is None:
                return None
            base = str(base_raw).strip()
            if not base or base.lower() in {"none", "null", "nil"}:
                return None
            return base
        except Exception:
            # Best-effort probe: any failure means "no adapter base known".
            return None

    def __init__(self, model_id: str, adapter_base_model_id: str | None = None):
        """Load tokenizer and model for ``model_id`` onto GPU (if available) or CPU.

        ``adapter_base_model_id`` is normalized and stored as a fallback hint
        for PEFT adapters; it is not loaded here.

        Raises:
            RuntimeError: if no tokenizer can be loaded (fast, slow, or the
                fallback repo), or if the model weights fail to load.
        """
        self.model_id = model_id
        self.adapter_base_model_id = (
            _normalize_model_id(adapter_base_model_id, _DEFAULT_ADAPTER_BASE_MODEL_ID)
            if adapter_base_model_id is not None
            else None
        )
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Tokenizer loading chain: fast -> slow -> known-compatible fallback repo.
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True,
                use_fast=True,
            )
        except Exception:
            # Some fine-tuned repos only provide slow tokenizer assets.
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_id,
                    trust_remote_code=True,
                    use_fast=False,
                )
            except Exception:
                # Final fallback: use a known compatible base tokenizer.
                try:
                    self.tokenizer = AutoTokenizer.from_pretrained(
                        _DEFAULT_TOKENIZER_FALLBACK_ID,
                        trust_remote_code=True,
                        use_fast=False,
                    )
                except Exception as fallback_err:
                    raise RuntimeError(
                        "Tokenizer 加载失败:fast/slow 两种模式都无法初始化,"
                        f"且回退 tokenizer `{_DEFAULT_TOKENIZER_FALLBACK_ID}` 也加载失败。"
                    ) from fallback_err
        # Build from_pretrained kwargs: device placement plus optional quantization.
        model_kwargs: Dict[str, Any] = {"trust_remote_code": True, "low_cpu_mem_usage": True}
        if self.device == "cuda":
            # Force everything to first GPU to avoid accelerate hanging on a
            # multi-GPU/CPU split.
            model_kwargs["device_map"] = {"": 0}
            quant_enabled = _HF_ENABLE_4BIT or _HF_ENABLE_8BIT
            if quant_enabled:
                dtype_map = {
                    "float16": torch.float16,
                    "bfloat16": torch.bfloat16,
                    "float32": torch.float32,
                }
                compute_dtype = dtype_map.get(_HF_BNB_4BIT_COMPUTE_DTYPE, torch.float16)
                if _HF_ENABLE_4BIT:
                    # 4-bit takes precedence when both quantization flags are set.
                    model_kwargs["quantization_config"] = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_quant_type=_HF_BNB_4BIT_QUANT_TYPE,
                        bnb_4bit_use_double_quant=_HF_BNB_4BIT_USE_DOUBLE_QUANT,
                        bnb_4bit_compute_dtype=compute_dtype,
                    )
                else:
                    model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
            else:
                model_kwargs["torch_dtype"] = torch.float16
        else:
            model_kwargs.update({
                "device_map": "cpu",
                "torch_dtype": torch.float32,
            })
        import gc
        try:
            # Status prints so very long model loads can be monitored in logs.
            print(f"[{model_id}] Starting model load from HF...", flush=True)
            if "device_map" in model_kwargs:
                # Cap accelerate's memory plan: generous RAM limit, but block
                # offloading/swapping beyond it (avoids the OOM killer).
                model_kwargs["max_memory"] = {0: "78GB", "cpu": "80GB"}
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                **model_kwargs
            )
            print(f"[{model_id}] Successfully loaded into memory.", flush=True)
            self.is_peft = False
            # Free transient CPU RAM used while materializing the weights.
            gc.collect()
        except Exception as model_err:
            # Chain the cause so the underlying HF/bitsandbytes error stays visible.
            raise RuntimeError(f"模型加载失败: {model_id}. Error: {model_err}") from model_err
        if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
            # Many causal-LM tokenizers ship without a pad token; reuse EOS.
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def _build_chat_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Render chat messages into a single prompt string.

        Prefers the tokenizer's chat template; otherwise falls back to a
        simple ``[ROLE]\\ncontent`` layout ending with ``[ASSISTANT]``.
        """
        if hasattr(self.tokenizer, "apply_chat_template"):
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        chunks: List[str] = []
        for m in messages:
            role = m.get("role", "user").upper()
            content = m.get("content", "")
            chunks.append(f"[{role}]\n{content}\n")
        chunks.append("[ASSISTANT]\n")
        return "".join(chunks)

    def generate_from_prompt(
        self,
        prompt_text: str,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Generate a continuation for ``prompt_text``.

        ``temperature <= 0`` selects greedy decoding; otherwise nucleus
        sampling (top_p=0.9) at the given temperature. Returns only the newly
        generated text, stripped.
        """
        # is_peft is only ever set False in this class, so this branch is
        # inert unless adapters are attached externally.
        if getattr(self, "is_peft", False) and hasattr(self.model, "set_adapter"):
            self.model.set_adapter(self.active_adapter)
        inputs = self.tokenizer(prompt_text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.model.device)
        do_sample = temperature > 0
        gen_kwargs: Dict[str, Any] = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            gen_kwargs["temperature"] = max(temperature, 1e-5)
            gen_kwargs["top_p"] = 0.9
        with torch.no_grad():
            out_ids = self.model.generate(**gen_kwargs)
        # Release cached GPU memory immediately after generation.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Decode only the continuation, skipping the prompt tokens.
        gen_ids = out_ids[0][input_ids.shape[1] :]
        text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
        return text.strip()

    def generate_from_messages(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Render ``messages`` to a prompt and generate a reply (see generate_from_prompt)."""
        prompt_text = self._build_chat_prompt(messages)
        return self.generate_from_prompt(
            prompt_text=prompt_text,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
def build_suspicion_system_msg(state: ChildSessionState) -> Dict[str, str]:
    """Build the hidden system message encoding the suspicion ground truth.

    The returned chat message pins SUSPICION_STATUS to state.suspicion_true
    and spells out the child avatar's answering and suggestibility rules.
    The child model must never reveal the flag itself.
    """
    content = (
        f"SUSPICION_STATUS = {'TRUE' if state.suspicion_true else 'FALSE'}.\n"
        "Never reveal this flag. Behave consistently with it.\n\n"
        "Global response rules:\n"
        "• Output exactly one short spoken turn (no bullet lists).\n"
        "• Keep it very brief: usually 4–18 words (≤8 yrs), 6–22 words (>8 yrs); toddlers 1–6 words.\n"
        "• Include at most 1–2 concrete details per reply.\n"
        ". If you already gave one detail, stop.\n"
        "• If you are ≤8 yrs, try to adapt your answering style to that of the questioner, at least a bit .\n"
        "• Answer only what was asked; and even then, only minimally, do not volunteer background.\n"
        ". Never divulge specifics that you consider irrelevant or shameful (e.g. number of a particular room or a criminal action etc.) unless asked.\n"
        "• Natural child speech is preferred (simple words; occasional “um”).\n"
        "• If unsure or unwilling, you may say: “I don’t know”, “I don’t remember”, or “I don’t want to talk about it”.\n\n"
        "Suspicion-sensitive topics (harm, violence, sexual, neglect):\n"
        "• Be even more limited: yes/no + one tiny detail, OR refuse/forget.\n"
        "• Do not narrate sequences, motives, frequency, injuries, or graphic details.\n"
        "• “Why/when/how” questions may be answered with “I don’t remember” or “I don’t want to talk about it”.\n\n"
        "Suggestibility / confabulation:\n"
        "• Leading questions can influence you.\n"
        "• ≤ 8 yrs: after 1–2 leading prompts, you may start agreeing with the interviewer’s assumption.\n"
        "• > 8 yrs: agree only after persistent prompting (≥ 3 times).\n"
        "• When you agree, keep it vague: at most one new, non-specific claim per reply.\n"
        "• If SUSPICION_STATUS = FALSE and violence is suggested, you may echo it briefly (e.g., “um… yeah”), "
        "but you do not know details—avoid specifics and prefer “I don’t remember”.\n"
        "• If SUSPICION_STATUS = TRUE, disclose truth slowly: one small, non-graphic detail at a time.\n\n"
        "Consistency:\n"
        "• If SUSPICION_STATUS = TRUE and the interviewer insists it is false, you may doubt and minimise.\n"
        "• If SUSPICION_STATUS = FALSE and the interviewer insists it is true, you may become suggestible, "
        "but stay vague and avoid concrete new details unless asked directly.\n"
    )
    return {"role": "system", "content": content}
def _get_groq_api_key() -> str:
key = os.environ.get("GROQ_API_KEY", "").strip()
if key:
return key
raise RuntimeError("缺少 GROQ_API_KEY,请在 Hugging Face Space Secrets 中配置。")
def interview_child(
    llm: HFLocalGenerator,
    history: List[Dict[str, str]],
    question: str,
) -> str:
    """
    Mutates `history` in-place:
      - append {"role":"user","content": question}
      - append {"role":"assistant","content": reply}

    Retries up to three times on errors, empty output, or a reply that repeats
    the previous assistant turn; on total failure returns a
    "[child error: ...]" string instead of raising.
    """
    history.append({"role": "user", "content": question})
    previous_reply = next(
        (m.get("content") for m in reversed(history) if m.get("role") == "assistant"),
        None,
    )
    retry_hint = {
        "role": "system",
        "content": (
            "Produce exactly one concise conversational response. "
            "Do not list alternatives. Do not repeat your previous wording."
        ),
    }
    refusals = {
        "I’m sorry, but I can’t continue this conversation.",
        "I’m sorry, but I can’t answer that.",
        "I don’t want to talk about that.",
    }
    last_error: Exception | None = None
    for attempt in range(3):
        try:
            msgs = history
            if attempt > 0 and previous_reply is not None:
                # Nudge the model away from repeating itself on retries.
                msgs = history + [retry_hint]
            reply = llm.generate_from_messages(
                messages=msgs,
                temperature=random.choice([0.1, 0.2, 0.4]),
                max_new_tokens=96,
            )
            if reply in refusals:
                # Map canned refusals to an in-world silent turn.
                reply = "[silence]"
            if not reply or not reply.strip():
                continue
            if reply == previous_reply and attempt < 2:
                continue
            history.append({"role": "assistant", "content": reply})
            return reply
        except Exception as err:
            last_error = err
            time.sleep(0.2 * (2**attempt))  # exponential backoff
    if last_error:
        msg = str(last_error).strip().replace("\n", " ")
        if len(msg) > 160:
            msg = msg[:160] + "..."
        return f"[child error: {type(last_error).__name__}: {msg}]"
    return "[child error: empty/invalid reply after retries]"
def interview_child_groq(
    groq_client: Any,
    history: List[Dict[str, str]],
    question: str,
    model: str,
) -> str:
    """Ask the Groq-hosted child avatar one question, appending Q/A to `history`.

    Mirrors interview_child: retries up to three times on errors, empty
    output, or repeated replies, and returns a "[child error: ...]" string
    on total failure instead of raising.
    """
    history.append({"role": "user", "content": question})
    previous_reply = next(
        (m.get("content") for m in reversed(history) if m.get("role") == "assistant"),
        None,
    )
    retry_hint = {
        "role": "system",
        "content": (
            "Produce exactly one concise conversational response. "
            "Do not list alternatives. Do not repeat your previous wording."
        ),
    }
    refusals = {
        "I’m sorry, but I can’t continue this conversation.",
        "I’m sorry, but I can’t answer that.",
        "I don’t want to talk about that.",
    }
    last_error: Exception | None = None
    for attempt in range(3):
        try:
            msgs = history
            if attempt > 0 and previous_reply is not None:
                # Nudge the model away from repeating itself on retries.
                msgs = history + [retry_hint]
            resp = groq_client.chat.completions.create(
                model=model,
                messages=msgs,
                temperature=random.choice([0.1, 0.2, 0.4]),
                max_tokens=300,
            )
            reply = str(getattr(resp.choices[0].message, "content", "") or "").strip()
            if reply in refusals:
                # Map canned refusals to an in-world silent turn.
                reply = "[silence]"
            if not reply:
                continue
            if reply == previous_reply and attempt < 2:
                continue
            history.append({"role": "assistant", "content": reply})
            return reply
        except Exception as err:
            last_error = err
            time.sleep(0.2 * (2**attempt))  # exponential backoff
    if last_error:
        msg = str(last_error).strip().replace("\n", " ")
        if len(msg) > 160:
            msg = msg[:160] + "..."
        return f"[child error: {type(last_error).__name__}: {msg}]"
    return "[child error: empty/invalid reply after retries]"
def strip_interviewer_prefix(question_text: str) -> str:
    """Drop a leading "Interviewer:" tag and surrounding whitespace."""
    if not question_text:
        return ""
    trimmed = question_text.strip()
    # removeprefix is a no-op when the tag is absent, and trimmed is already
    # stripped, so the trailing lstrip only fires after an actual removal.
    return trimmed.removeprefix("Interviewer:").lstrip()
def extract_avatra_prompts_from_file(child_avatars_py: str) -> Dict[str, str]:
    """Parse the AVATAR_PROMPTS dict literal out of the child-avatars script."""
    text = Path(child_avatars_py).read_text(encoding="utf-8", errors="ignore")
    start = text.find("AVATAR_PROMPTS = {")
    if start == -1:
        raise RuntimeError("Could not locate AVATAR_PROMPTS in child file.")
    end = -1
    # Try the known end-of-dict markers in order.
    for marker in (
        "\n}\n\n\"\"\"\n# Virtual Interview System",
        "\n}\n\n\"\"\"\n# Virtual Interview System (do not edit)",
    ):
        end = text.find(marker, start)
        if end != -1:
            break
    if end == -1:
        raise RuntimeError("Could not locate end of AVATAR_PROMPTS dict.")
    # NOTE: exec is run on repo-local, trusted file content only.
    namespace: Dict[str, Any] = {}
    exec(text[start : end + 2], namespace)  # +2 keeps the closing "\n}"
    prompts = namespace.get("AVATAR_PROMPTS")
    if not isinstance(prompts, dict) or not prompts:
        raise RuntimeError("Failed to parse AVATAR_PROMPTS.")
    return prompts
def extract_suspicion_block(avatar_prompt: str) -> str:
    """Return the text following the "suspicion:" heading in *avatar_prompt*."""
    match = re.search(r"(?is)\bsuspicion\s*:\s*", avatar_prompt) or re.search(
        r"(?is)\bThe suspicion\s*:\s*", avatar_prompt
    )
    if match is None:
        raise RuntimeError("Could not find suspicion section in avatar prompt.")
    tail = avatar_prompt[match.end() :].strip()
    if not tail:
        raise RuntimeError("Suspicion block is empty.")
    return tail
class AsqEngine:
    """Drives the four ASQ interview modules: hypotheses, summary, topic, question.

    Each module uses a system prompt loaded from a .docx file in `docx_dir`
    and a dedicated HFLocalGenerator from `llms`, keyed "hypo", "summ",
    "topic" and "quest".
    """

    def __init__(self, docx_dir: Path, llms: Dict[str, HFLocalGenerator]):
        # Load all four system prompts eagerly so missing/empty files fail fast.
        self._SYSTEM_PROMPTS_DIR = docx_dir
        self._llms = llms
        self._SYSTEM_PROMPT_HYPO = self._read_docx_text(self._SYSTEM_PROMPTS_DIR / "Hypotheses prompt.docx")
        self._SYSTEM_PROMPT_SUMM = self._read_docx_text(self._SYSTEM_PROMPTS_DIR / "Summary prompt.docx")
        self._SYSTEM_PROMPT_TOPIC = self._read_docx_text(self._SYSTEM_PROMPTS_DIR / "Topic prompt.docx")
        self._SYSTEM_PROMPT_QUEST = self._read_docx_text(self._SYSTEM_PROMPTS_DIR / "Question prompt.docx")

    def _read_docx_text(self, path: Path) -> str:
        """Join the non-empty paragraphs of a .docx file into one string.

        Raises FileNotFoundError when the file is missing and ValueError when
        it contains no text.
        """
        if not path.exists():
            raise FileNotFoundError(f"Missing prompt docx: {path}")
        doc = Document(str(path))
        parts: List[str] = []
        for para in doc.paragraphs:
            txt = (para.text or "").strip()
            if txt:
                parts.append(txt)
        text = "\n".join(parts).strip()
        if not text:
            raise ValueError(f"Prompt docx is empty: {path}")
        return text

    def render_prompt(self, system_prompt: str, user_content: str) -> str:
        """Format a plain-text prompt with [SYSTEM]/[USER]/[ASSISTANT] sections."""
        parts = []
        if system_prompt.strip():
            parts.append("[SYSTEM]\n" + system_prompt.strip() + "\n")
        parts.append("[USER]\n" + user_content.strip() + "\n")
        parts.append("[ASSISTANT]\n")
        return "".join(parts)

    def generate_response(self, system_prompt: str, user_prompt: str, model_key: str) -> str:
        """Run the generator registered under `model_key` on the rendered prompt.

        Uses greedy decoding (temperature 0.0) with up to 512 new tokens.
        Raises KeyError if `model_key` has no registered generator.
        """
        llm = self._llms[model_key]
        prompt_text = self.render_prompt(system_prompt, user_prompt)
        # Synchronize so prior GPU work has finished before generating.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return llm.generate_from_prompt(
            prompt_text=prompt_text,
            max_new_tokens=512,
            temperature=0.0,
        )

    def _extract_section_content(self, text: str, section_heading: str) -> str:
        """Return text after `section_heading` up to the next "###" heading or end."""
        pattern = re.compile(rf"(?is){re.escape(section_heading)}\s*(.*?)(?=\n###\s+|$)")
        m = pattern.search(text)
        if not m:
            return ""
        return m.group(1).strip()

    def sanitize_summary(self, summary_text: str) -> str:
        """Normalize summary-model output into the two expected sections.

        When neither heading is found, the raw text is returned stripped;
        a missing section is filled with "N/A".
        """
        if not summary_text:
            return ""
        heading_summary = "### Summary of the ongoing interview:"
        heading_missing = "### Missing Information:"
        summary_part = self._extract_section_content(summary_text, heading_summary)
        missing_part = self._extract_section_content(summary_text, heading_missing)
        if not summary_part and not missing_part:
            return summary_text.strip()
        if not summary_part:
            summary_part = "N/A"
        if not missing_part:
            missing_part = "N/A"
        return (
            f"{heading_summary}\n"
            f"{summary_part}\n\n"
            f"{heading_missing}\n"
            f"{missing_part}"
        ).strip()

    def sanitize_topic(self, topic_text: str) -> str:
        """Normalize topic-model output to the five required `Key: value` lines.

        Returns the raw text stripped when the "### Topic navigation:" heading
        is absent; wrapped lines are folded into the preceding key's value and
        missing keys are emitted as "N/A".
        """
        if not topic_text:
            return ""
        heading = "### Topic navigation:"
        required_keys = [
            "Active_Topic",
            "Exploration_Depth",
            "Recommended_Question_Type",
            "Target_Missing_Detail",
            "Reasoning",
        ]
        lines = topic_text.splitlines()
        start_idx = None
        for i, line in enumerate(lines):
            if line.strip().lower() == heading.lower():
                start_idx = i
                break
        if start_idx is None:
            return topic_text.strip()
        parsed: Dict[str, str] = {}
        current_key = ""
        for raw in lines[start_idx + 1 :]:
            line = raw.strip()
            if not line:
                continue
            if line.startswith("### "):
                # A new section begins; stop collecting topic fields.
                break
            matched_key = None
            for key in required_keys:
                prefix = f"{key}:"
                if line.startswith(prefix):
                    matched_key = key
                    parsed[key] = line[len(prefix) :].strip()
                    current_key = key
                    break
            if matched_key is not None:
                continue
            if current_key:
                # Continuation line: append to the most recent key's value.
                parsed[current_key] = (parsed[current_key] + " " + line).strip()
        kept_lines: List[str] = [heading]
        for key in required_keys:
            value = parsed.get(key, "").strip() or "N/A"
            kept_lines.append(f"{key}: {value}")
        return "\n".join(kept_lines).strip()

    def step1_generate_hypo(self, background: str) -> str:
        """Generate interview hypotheses from the case background text."""
        if not background.strip():
            return "請先输入案件背景信息。"
        return self.generate_response(self._SYSTEM_PROMPT_HYPO, background, "hypo")

    def format_context_for_prompt(self, history: List[Dict[str, str]], max_rounds: int = 15) -> str:
        """Render the last `max_rounds` Child/Interviewer exchanges as text.

        Assumes `history` alternates child ("user") then interviewer
        ("assistant") messages; a trailing unpaired message is dropped by the
        pairwise iteration.
        """
        max_msgs = max_rounds * 2
        recent_history = history[-max_msgs:] if len(history) > max_msgs else history
        context_str = ""
        for i in range(0, len(recent_history) - 1, 2):
            msg1 = recent_history[i]
            msg2 = recent_history[i + 1]
            child_msg = msg1.get("content", "") if isinstance(msg1, dict) else ""
            interviewer_msg = msg2.get("content", "") if isinstance(msg2, dict) else ""
            context_str += f"Child: {child_msg}\nInterviewer: {interviewer_msg}\n"
        return context_str.strip()

    def step3_once(
        self,
        user_input: str,
        history: List[Dict[str, str]],
        bg: str,
        hyp: str,
        current_summ: str,
        current_topic: str,
        round_count: int,
        refresh_interval: int,
    ) -> Tuple[List[Dict[str, str]], str, str, int, str, float]:
        """Advance the interview one round and generate the next question.

        Appends the child's utterance plus a "[thinking]" placeholder to
        `history`, refreshes summary/topic on round 1 and every
        `refresh_interval` rounds thereafter, then asks the question model and
        overwrites the placeholder with the first paragraph of its output.

        Returns (history, summary, topic, round_count, interviewer_reply,
        question_latency_seconds). A blank `user_input` is a no-op round.
        """
        if not user_input.strip():
            return history, current_summ, current_topic, round_count, "", 0.0
        round_count += 1
        formatted_user_input = f"Child: {user_input}"
        history.append({"role": "user", "content": user_input})
        history.append({"role": "assistant", "content": "[thinking]"})
        if round_count == 1:
            summ_user_prompt = "No interview at this stage"
            current_summ = self.sanitize_summary(self.generate_response(self._SYSTEM_PROMPT_SUMM, summ_user_prompt, "summ"))
            topic_user_prompt = (
                f"Background:\n{bg}\n\nHypotheses:\n{hyp}\n\nSummary:\n{current_summ}\n\nRecent Context:\nNo context yet."
            )
            current_topic = self.sanitize_topic(self.generate_response(self._SYSTEM_PROMPT_TOPIC, topic_user_prompt, "topic"))
        elif round_count > 1 and (round_count - 1) % refresh_interval == 0:
            # history[:-1] excludes the "[thinking]" placeholder just appended.
            all_context_str = self.format_context_for_prompt(history[:-1]) + f"\n{formatted_user_input}"
            current_summ = self.sanitize_summary(self.generate_response(self._SYSTEM_PROMPT_SUMM, all_context_str, "summ"))
            topic_user_prompt = (
                f"Background:\n{bg}\n\nHypotheses:\n{hyp}\n\nSummary:\n{current_summ}\n\n"
                f"Recent Context:\n{self.format_context_for_prompt(history[:-1], 15)}\n{formatted_user_input}"
            )
            current_topic = self.sanitize_topic(self.generate_response(self._SYSTEM_PROMPT_TOPIC, topic_user_prompt, "topic"))
        quest_user_prompt = (
            f"Background:\n{bg}\n\n"
            f"Hypotheses:\n{hyp}\n\n"
            f"Summary:\n{current_summ}\n\n"
            f"Recent Context:\n{self.format_context_for_prompt(history[:-1], 15)}\n\n"
            f"Topic Navigation:\n{current_topic}\n\n"
            f"Current User Input:\n{formatted_user_input}"
        )
        quest_start_ts = time.perf_counter()
        raw_response = self.generate_response(self._SYSTEM_PROMPT_QUEST, quest_user_prompt, "quest")
        question_latency_seconds = time.perf_counter() - quest_start_ts
        # Keep only the first non-empty paragraph of the question model output.
        paragraphs = [p.strip() for p in raw_response.split("\n") if p.strip()]
        first_paragraph = paragraphs[0] if paragraphs else "Can you tell me more about that?"
        if not first_paragraph.startswith("Interviewer:"):
            model_reply = f"Interviewer: {first_paragraph}"
        else:
            model_reply = first_paragraph
        history[-1]["content"] = model_reply
        return history, current_summ, current_topic, round_count, model_reply, question_latency_seconds
def decide_suspicion_true(force_value: str | None) -> bool:
if not force_value:
return bool(random.getrandbits(1))
s = force_value.strip().lower()
if s in {"true", "1", "yes", "y"}:
return True
if s in {"false", "0", "no", "n"}:
return False
raise ValueError("Invalid FORCE_SUSPICION, expected true/false.")
def safe_excel_sheet_title(title: str) -> str:
    """Make *title* a legal Excel sheet name: non-empty, at most 31 chars,
    with the characters : \\ / ? * [ ] replaced by underscores."""
    cleaned = (title or "").strip() or "child"
    if len(cleaned) > 31:
        cleaned = cleaned[:28] + "..."  # 28 + "..." == the 31-char limit
    return re.sub(r"[:\\/?*\[\]]", "_", cleaned)
def init_child_worksheet(ws, child_name: str, hypotheses: str) -> None:
    """Write the fixed header row into a per-child worksheet.

    Column C holds the generated hypotheses text itself rather than a label.
    """
    del child_name  # accepted for call-site symmetry; not used in the header
    header_cells = {
        "A1": "Summary",
        "B1": "Target Navigation",
        "C1": hypotheses,
        "D1": "Child Answer",
        "E1": "ASQ Question Response Time (s)",
    }
    for cell, value in header_cells.items():
        ws[cell] = value
def run_batch_export_to_excel(
    out_excel_path: Path,
    rounds: int = 20,
    summary_topic_refresh_interval: int = 3,
    hf_model_hypo: str = _DEFAULT_MODEL_HYPO,
    hf_model_summ: str = _DEFAULT_MODEL_SUMM,
    hf_model_topic: str = _DEFAULT_MODEL_TOPIC,
    hf_model_quest: str = _DEFAULT_MODEL_QUEST,
    child_groq_model: str = os.environ.get("GROQ_CHILD_MODEL", "openai/gpt-oss-120b"),
    child_avatars_py: str | Path = _DEFAULT_CHILD_AVATARS_PY,
    asq_docx_dir: Path | None = None,
) -> Iterator[Tuple[str, str | None]]:
    """Run the full four-child batch interview and stream progress updates.

    Yields (status_message, excel_path) tuples; excel_path is None until the
    final yield. Loads the four ASQ module models locally, drives each child
    avatar via the Groq API, and writes one row per round to the workbook,
    saving after every round so partial results survive a crash.
    """
    if summary_topic_refresh_interval < 1:
        raise ValueError("summary_topic_refresh_interval 必须 >= 1。")
    yield ("正在解析四个 child 的头像提示词与 suspicion 字段...", None)
    child_path = str(Path(child_avatars_py))
    avatar_prompts = extract_avatra_prompts_from_file(child_path)
    child_items = list(avatar_prompts.items())
    if len(child_items) < 4:
        raise RuntimeError(f"Expected 4 children in avatar prompts, got {len(child_items)}.")
    child_items = child_items[:4]
    if asq_docx_dir is None:
        asq_docx_dir = _DEFAULT_ASQ_DOCX_DIR
    # Normalize model ids so blank/None-like overrides fall back to defaults.
    selected_models = {
        "hypo": _normalize_model_id(hf_model_hypo, _DEFAULT_MODEL_HYPO),
        "summ": _normalize_model_id(hf_model_summ, _DEFAULT_MODEL_SUMM),
        "topic": _normalize_model_id(hf_model_topic, _DEFAULT_MODEL_TOPIC),
        "quest": _normalize_model_id(hf_model_quest, _DEFAULT_MODEL_QUEST),
    }
    adapter_base_models = {
        "hypo": _normalize_model_id(_DEFAULT_ADAPTER_BASE_MODEL_HYPO, _DEFAULT_ADAPTER_BASE_MODEL_ID),
        "summ": _normalize_model_id(_DEFAULT_ADAPTER_BASE_MODEL_SUMM, _DEFAULT_ADAPTER_BASE_MODEL_ID),
        "topic": _normalize_model_id(_DEFAULT_ADAPTER_BASE_MODEL_TOPIC, _DEFAULT_ADAPTER_BASE_MODEL_ID),
        "quest": _normalize_model_id(_DEFAULT_ADAPTER_BASE_MODEL_QUEST, _DEFAULT_ADAPTER_BASE_MODEL_ID),
    }
    loaded_generators: Dict[str, HFLocalGenerator] = {}
    model_cache_by_id: Dict[Tuple[str, str], HFLocalGenerator] = {}
    # Generators are loaded sequentially and cached by (model_id, base_id) so
    # modules sharing the same repo reuse one instance. AsqEngine expects all
    # four keys to be present in loaded_generators.
    # Pre-download all models first to avoid hitting connection/cache locks during loading
    yield ("正在预下载所有模型文件到本地缓存,这可能需要一段时间...", None)
    from huggingface_hub import snapshot_download
    for key in ["hypo", "summ", "topic", "quest"]:
        model_id = selected_models[key]
        yield (f"正在检查/预下载模型: {model_id}", None)
        try:
            snapshot_download(repo_id=model_id, allow_patterns=["*.json", "*.safetensors", "*.jinja"])
        except Exception as e:
            # Best-effort: loading below may still succeed from cache.
            yield (f"预下载 {model_id} 时出现警告 (将尝试继续): {e}", None)
    yield ("预下载完成,开始顺序加载模型到显存...", None)
    for key in ["hypo", "summ", "topic", "quest"]:
        model_id = selected_models[key]
        base_id = adapter_base_models[key]
        cache_key = (model_id, base_id)
        if cache_key in model_cache_by_id:
            loaded_generators[key] = model_cache_by_id[cache_key]
            continue
        yield (
            f"正在加载 `{key}` 模块模型:`{model_id}`(adapter base fallback: `{base_id}`)...",
            None,
        )
        gen = HFLocalGenerator(model_id=model_id, adapter_base_model_id=base_id)
        model_cache_by_id[cache_key] = gen
        loaded_generators[key] = gen
    groq_key = _get_groq_api_key()
    if GroqSDK is None:
        raise RuntimeError("未安装 groq SDK,请在 requirements.txt 中添加 groq。")
    groq_client = GroqSDK(api_key=groq_key)
    child_model = _normalize_model_id(child_groq_model, "openai/gpt-oss-120b")
    asq = AsqEngine(docx_dir=asq_docx_dir, llms=loaded_generators)
    yield (
        "本地模型已就绪,开始批量访谈...\n"
        f"hypo={selected_models['hypo']}\n"
        f"summ={selected_models['summ']}\n"
        f"topic={selected_models['topic']}\n"
        f"quest={selected_models['quest']}\n"
        f"child_groq={child_model}",
        None,
    )
    force_suspicion = os.environ.get("FORCE_SUSPICION")
    wb = openpyxl.Workbook()
    # Drop the default sheet; one sheet per child is created below.
    wb.remove(wb.active)
    out_excel_path.parent.mkdir(parents=True, exist_ok=True)
    for child_idx, (child_name, avatar_prompt) in enumerate(child_items, start=1):
        yield (f"正在处理 child {child_idx}/4: **{child_name}** ...", None)
        yield (f"正在处理 child {child_idx}/4: **{child_name}** ...\n"
               f"-> 第一步:根据背景故事生成侦讯假设 (Hypotheses)。\n"
               f" (A100 上首次启动推理,或者如果未合并 4bit LoRA,这一步可能长达 5-10 分钟,请耐心保持页面开启!)", None)
        suspicion_true = decide_suspicion_true(force_suspicion)
        suspicion_seed = random.randrange(2**32)
        state = ChildSessionState(suspicion_true=suspicion_true, rng_seed=suspicion_seed)
        # The ASQ pipeline only sees the suspicion block, not the full avatar.
        bg_for_asq = extract_suspicion_block(avatar_prompt)
        hyp_start_ts = time.perf_counter()
        hypotheses = asq.step1_generate_hypo(bg_for_asq)
        hyp_latency = time.perf_counter() - hyp_start_ts
        yield (f"[{child_name}] ✅ 假设 (Hypotheses) 生成完毕!耗时: {hyp_latency:.1f} 秒。\n"
               f"内容预览: {hypotheses[:150]}...\n", None)
        ws = wb.create_sheet(title=safe_excel_sheet_title(child_name))
        init_child_worksheet(ws, child_name=child_name, hypotheses=hypotheses)
        # Child chat starts with the hidden suspicion system message plus the
        # fixed role-play intro prepended to the avatar's own prompt.
        child_history: List[Dict[str, str]] = [
            build_suspicion_system_msg(state),
            {"role": "user", "content": FIXED_INTRO + avatar_prompt},
        ]
        asq_history: List[Dict[str, str]] = []
        current_summ = ""
        current_topic = ""
        round_count = 0
        yield (f"[{child_name}] Round 1/ {rounds}: child 发起首轮对话 Hi ...", None)
        child_reply = interview_child_groq(
            groq_client=groq_client,
            history=child_history,
            question="Hi",
            model=child_model,
        )
        for r in range(1, rounds + 1):
            if r == 1:
                yield (f"[{child_name}] Round {r}/{rounds}: 初次梳理 Summary 和 Topic... 正在思考", None)
            elif (r - 1) % summary_topic_refresh_interval == 0:
                yield (f"[{child_name}] Round {r}/{rounds}: 正在重新刷新 Summary 和 Topic...", None)
            else:
                yield (f"[{child_name}] Round {r}/{rounds}: 正在生成追问 Question...", None)
            asq_history, current_summ, current_topic, round_count, question_text, question_latency_seconds = asq.step3_once(
                user_input=child_reply,
                history=asq_history,
                bg=bg_for_asq,
                hyp=hypotheses,
                current_summ=current_summ,
                current_topic=current_topic,
                round_count=round_count,
                refresh_interval=summary_topic_refresh_interval,
            )
            yield (
                f"[{child_name}] Round {r}/{rounds}: ASQ question response time = {question_latency_seconds:.3f}s",
                None,
            )
            interviewer_q_for_gui = (question_text or "").strip()
            yield (f"[{child_name}] Round {r}/{rounds}: Interviewer -> {interviewer_q_for_gui}", None)
            # The child model gets the question without the "Interviewer:" tag.
            question_for_child = strip_interviewer_prefix(question_text)
            if not question_for_child.strip():
                question_for_child = interviewer_q_for_gui
            yield (f"[{child_name}] Round {r}/{rounds}: 发送 question 给 child ...", None)
            child_reply_next = interview_child_groq(
                groq_client=groq_client,
                history=child_history,
                question=question_for_child,
                model=child_model,
            )
            yield (f"[{child_name}] Round {r}/{rounds}: Child -> {child_reply_next}", None)
            excel_row = r + 1  # row 1 is the header
            ws[f"A{excel_row}"] = current_summ
            ws[f"B{excel_row}"] = current_topic
            ws[f"C{excel_row}"] = question_text
            ws[f"D{excel_row}"] = child_reply_next
            ws[f"E{excel_row}"] = round(question_latency_seconds, 6)
            # Save every round so partial results survive a crash.
            wb.save(out_excel_path)
            yield (f"[{child_name}] Round {r}/{rounds}: 已保存到 Excel", None)
            child_reply = child_reply_next
    yield (f"全部 child 的 {rounds} 轮访谈已完成,Excel 正在输出。", str(out_excel_path))
def run_batch_export_to_excel_sync(
    out_excel_path: Path,
    rounds: int,
    summary_topic_refresh_interval: int,
    on_status=None,
) -> Path:
    """Blocking wrapper around run_batch_export_to_excel.

    Drains the progress generator, forwarding each status message to the
    optional `on_status` callback, and returns the final Excel path.
    Raises RuntimeError if the generator never reported an output path.
    """
    result_path: Path | None = None
    progress = run_batch_export_to_excel(
        out_excel_path=out_excel_path,
        rounds=rounds,
        summary_topic_refresh_interval=summary_topic_refresh_interval,
    )
    for status, excel_path in progress:
        if on_status:
            on_status(status)
        if excel_path:
            result_path = Path(excel_path)
    if not result_path:
        raise RuntimeError("Excel output path not produced.")
    return result_path