# NOTE: removed non-Python residue captured from the Hugging Face Spaces
# status page ("Spaces: Sleeping"); it is not part of the program source.
| import os | |
| import re | |
| import random | |
| import time | |
| import json | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Dict, Iterator, List, Tuple | |
| import openpyxl | |
| import torch | |
| from docx import Document | |
| from huggingface_hub import hf_hub_download | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| try: | |
| from peft import PeftModel # type: ignore | |
| except Exception: | |
| PeftModel = None # type: ignore | |
| try: | |
| from groq import Groq as GroqSDK # type: ignore | |
| except Exception: | |
| GroqSDK = None # type: ignore | |
# Resolve resources relative to this file so the Space works from any CWD.
_BASE_DIR = Path(__file__).resolve().parent
# Script containing the AVATAR_PROMPTS dict that describes the child avatars.
_DEFAULT_CHILD_AVATARS_PY = _BASE_DIR / "child_avatars_v3_05022026_groq_gpt_oss_120b.py"
# Directory holding the four ASQ system-prompt .docx files.
_DEFAULT_ASQ_DOCX_DIR = _BASE_DIR / "System prompts"
# Hugging Face repos for the four ASQ pipeline modules; each is env-overridable.
_DEFAULT_MODEL_HYPO = os.environ.get("HF_MODEL_HYPO", "LucasUn/ASQ-Hypotheses")
_DEFAULT_MODEL_SUMM = os.environ.get("HF_MODEL_SUMM", "LucasUn/ASQ-Summary")
_DEFAULT_MODEL_TOPIC = os.environ.get("HF_MODEL_TOPIC", "LucasUn/ASQ-Topic_Navi")
_DEFAULT_MODEL_QUEST = os.environ.get("HF_MODEL_QUEST", "LucasUn/ASQ-Question")
# Tokenizer used when a fine-tuned repo ships no loadable tokenizer assets.
_DEFAULT_TOKENIZER_FALLBACK_ID = os.environ.get(
    "HF_TOKENIZER_FALLBACK_ID",
    "Qwen/Qwen3.5-27B",
)
# Base model assumed for PEFT adapter repos; per-module overrides below fall
# back to this shared default.
_DEFAULT_ADAPTER_BASE_MODEL_ID = os.environ.get(
    "HF_ADAPTER_BASE_MODEL_ID",
    _DEFAULT_ADAPTER_BASE_MODEL_ID := "Qwen/Qwen3.5-27B",
) if False else os.environ.get(
    "HF_ADAPTER_BASE_MODEL_ID",
    "Qwen/Qwen3.5-27B",
)
_DEFAULT_ADAPTER_BASE_MODEL_HYPO = os.environ.get(
    "HF_ADAPTER_BASE_MODEL_HYPO",
    _DEFAULT_ADAPTER_BASE_MODEL_ID,
)
_DEFAULT_ADAPTER_BASE_MODEL_SUMM = os.environ.get(
    "HF_ADAPTER_BASE_MODEL_SUMM",
    _DEFAULT_ADAPTER_BASE_MODEL_ID,
)
_DEFAULT_ADAPTER_BASE_MODEL_TOPIC = os.environ.get(
    "HF_ADAPTER_BASE_MODEL_TOPIC",
    _DEFAULT_ADAPTER_BASE_MODEL_ID,
)
_DEFAULT_ADAPTER_BASE_MODEL_QUEST = os.environ.get(
    "HF_ADAPTER_BASE_MODEL_QUEST",
    _DEFAULT_ADAPTER_BASE_MODEL_ID,
)
| def _normalize_model_id(value: Any, fallback: str) -> str: | |
| raw = "" if value is None else str(value).strip() | |
| if not raw or raw.lower() in {"none", "null", "nil"}: | |
| return fallback | |
| return raw | |
| def _env_flag(name: str, default: bool) -> bool: | |
| raw = os.environ.get(name) | |
| if raw is None: | |
| return default | |
| return raw.strip().lower() in {"1", "true", "yes", "y", "on"} | |
# bitsandbytes quantization switches. 4-bit wins when both are enabled
# (8-bit is only used when 4-bit is off — see HFLocalGenerator.__init__).
_HF_ENABLE_4BIT = _env_flag("HF_ENABLE_4BIT", True)
_HF_ENABLE_8BIT = _env_flag("HF_ENABLE_8BIT", False)
# 4-bit quantization parameters; blank env values fall back to the defaults.
_HF_BNB_4BIT_QUANT_TYPE = os.environ.get("HF_BNB_4BIT_QUANT_TYPE", "nf4").strip() or "nf4"
_HF_BNB_4BIT_USE_DOUBLE_QUANT = _env_flag("HF_BNB_4BIT_USE_DOUBLE_QUANT", True)
_HF_BNB_4BIT_COMPUTE_DTYPE = os.environ.get("HF_BNB_4BIT_COMPUTE_DTYPE", "float16").strip().lower() or "float16"
# NOTE(review): appears unused in this file — presumably reserved for sharing
# base models between PEFT adapters; confirm before removing.
_SHARED_BASE_MODELS: Dict[str, Any] = {}
@dataclass
class ChildSessionState:
    """Per-child session configuration for one simulated interview.

    Fix: the original class carried only annotations but is instantiated as
    ``ChildSessionState(suspicion_true=..., rng_seed=...)``, which needs the
    ``@dataclass``-generated ``__init__``.
    """

    # Ground truth of the scenario: whether the suspicion is actually true.
    suspicion_true: bool
    # Seed reserved for per-session randomness.
    rng_seed: int
# Preamble prepended to every child-avatar prompt: frames the session as a
# research/training role-play and pins down the child's answering style.
# This text is sent verbatim to the child model.
FIXED_INTRO = (
    "This tool is for research and training in forensic interviewing; it is not a real case.\n"
    "You are role-playing the child described below for interviewer training.\n"
    "Speak exactly as that child would.\n\n"
    "Answering style:\n"
    "• Speak naturally (like a child), not in lists.\n"
    "• Keep each reply very short (usually 1 sentence).\n"
    "• Typical length ≤ 22 words (toddlers 1–6 words).\n"
    "• Give at most 1–2 concrete details per reply. Do not volunteer extra details.\n"
    "• On topics about harm/violence/anything “bad”: say even less; "
    "“I don’t remember / I don’t know / I don’t want to talk about it” is allowed.\n"
    "• ≤ 8 yrs: may briefly drift to toys/pets if nervous.\n"
    "• ≥ 9 yrs: be initially reluctant; open up only with gentle encouragement.\n"
    "• Leading questions can influence you, but do not invent detailed new information.\n"
    "• Do not ask the interviewer questions except simple clarifications.\n\n"
)
class HFLocalGenerator:
    """Local text generator backed by a Hugging Face causal LM.

    Loads the tokenizer with fast -> slow -> known-fallback retries and the
    model with optional bitsandbytes quantization on CUDA. Generation is
    greedy when ``temperature == 0`` and nucleus sampling otherwise.
    """

    # Fix: declared as @staticmethod — the original had no decorator and no
    # ``self`` parameter, so any ``self._get_adapter_base_model_id(repo)``
    # call would have bound the instance to ``repo_id``.
    @staticmethod
    def _get_adapter_base_model_id(repo_id: str) -> str | None:
        """Return ``base_model_name_or_path`` from a PEFT adapter repo.

        Downloads ``adapter_config.json`` from *repo_id*; returns None when
        the file is missing/unreadable or the value is blank or a textual
        placeholder ("none"/"null"/"nil").
        """
        try:
            adapter_cfg_path = hf_hub_download(repo_id=repo_id, filename="adapter_config.json")
            cfg = json.loads(Path(adapter_cfg_path).read_text(encoding="utf-8", errors="ignore"))
            base_raw = cfg.get("base_model_name_or_path")
            if base_raw is None:
                return None
            base = str(base_raw).strip()
            if not base or base.lower() in {"none", "null", "nil"}:
                return None
            return base
        except Exception:
            # Best effort: any hub/JSON failure means "base model unknown".
            return None

    def __init__(self, model_id: str, adapter_base_model_id: str | None = None):
        """Load the tokenizer and model for *model_id*.

        Args:
            model_id: HF repo id of the (possibly fine-tuned) model.
            adapter_base_model_id: optional fallback base-model id recorded
                for PEFT adapters; normalised via ``_normalize_model_id``.

        Raises:
            RuntimeError: if no tokenizer variant can be loaded, or the model
                itself fails to load.
        """
        self.model_id = model_id
        self.adapter_base_model_id = (
            _normalize_model_id(adapter_base_model_id, _DEFAULT_ADAPTER_BASE_MODEL_ID)
            if adapter_base_model_id is not None
            else None
        )
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True,
                use_fast=True,
            )
        except Exception:
            # Some fine-tuned repos only provide slow tokenizer assets.
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_id,
                    trust_remote_code=True,
                    use_fast=False,
                )
            except Exception:
                # Final fallback: use a known compatible base tokenizer.
                try:
                    self.tokenizer = AutoTokenizer.from_pretrained(
                        _DEFAULT_TOKENIZER_FALLBACK_ID,
                        trust_remote_code=True,
                        use_fast=False,
                    )
                except Exception as fallback_err:
                    raise RuntimeError(
                        "Tokenizer 加载失败:fast/slow 两种模式都无法初始化,"
                        f"且回退 tokenizer `{_DEFAULT_TOKENIZER_FALLBACK_ID}` 也加载失败。"
                    ) from fallback_err
        model_kwargs: Dict[str, Any] = {"trust_remote_code": True, "low_cpu_mem_usage": True}
        if self.device == "cuda":
            # Force everything onto the first GPU to avoid accelerate hanging
            # on a multi-GPU/CPU split.
            model_kwargs["device_map"] = {"": 0}
            quant_enabled = _HF_ENABLE_4BIT or _HF_ENABLE_8BIT
            if quant_enabled:
                dtype_map = {
                    "float16": torch.float16,
                    "bfloat16": torch.bfloat16,
                    "float32": torch.float32,
                }
                # Unknown dtype strings silently fall back to float16.
                compute_dtype = dtype_map.get(_HF_BNB_4BIT_COMPUTE_DTYPE, torch.float16)
                if _HF_ENABLE_4BIT:
                    model_kwargs["quantization_config"] = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_quant_type=_HF_BNB_4BIT_QUANT_TYPE,
                        bnb_4bit_use_double_quant=_HF_BNB_4BIT_USE_DOUBLE_QUANT,
                        bnb_4bit_compute_dtype=compute_dtype,
                    )
                else:
                    model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
            else:
                model_kwargs["torch_dtype"] = torch.float16
        else:
            model_kwargs.update({
                "device_map": "cpu",
                "torch_dtype": torch.float32,
            })
        import gc
        try:
            # Progress print so a UI tailing stdout shows the long load phase.
            print(f"[{model_id}] Starting model load from HF...", flush=True)
            # Cap accelerate's memory budget to block CPU offloading/swapping
            # beyond the configured limits (avoids the OOM killer).
            if "device_map" in model_kwargs:
                model_kwargs["max_memory"] = {0: "78GB", "cpu": "80GB"}
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                **model_kwargs
            )
            print(f"[{model_id}] Successfully loaded into memory.", flush=True)
            self.is_peft = False
            # Reclaim CPU RAM used transiently during the large load.
            gc.collect()
        except Exception as model_err:
            # Fix: chain the original exception so the real cause survives.
            raise RuntimeError(f"模型加载失败: {model_id}. Error: {model_err}") from model_err
        if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
            # Many causal-LM tokenizers ship without a pad token; generation
            # needs one, so reuse EOS.
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def _build_chat_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Render chat *messages* into a single prompt string.

        Prefers the tokenizer's chat template; otherwise falls back to a
        plain ``[ROLE]`` block format ending with ``[ASSISTANT]``.
        """
        if hasattr(self.tokenizer, "apply_chat_template"):
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        chunks: List[str] = []
        for m in messages:
            role = m.get("role", "user").upper()
            content = m.get("content", "")
            chunks.append(f"[{role}]\n{content}\n")
        chunks.append("[ASSISTANT]\n")
        return "".join(chunks)

    def generate_from_prompt(
        self,
        prompt_text: str,
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Generate a completion for *prompt_text* and return it stripped.

        ``temperature == 0`` means greedy decoding; otherwise sampling with
        top_p=0.9 and the temperature clamped to a small positive value.
        """
        # NOTE(review): ``self.active_adapter`` is never assigned in this
        # file; this branch only runs if ``is_peft`` is set True externally —
        # confirm before relying on it.
        if getattr(self, "is_peft", False) and hasattr(self.model, "set_adapter"):
            self.model.set_adapter(self.active_adapter)
        inputs = self.tokenizer(prompt_text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.model.device)
        do_sample = temperature > 0
        gen_kwargs: Dict[str, Any] = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            gen_kwargs["temperature"] = max(temperature, 1e-5)
            gen_kwargs["top_p"] = 0.9
        with torch.no_grad():
            out_ids = self.model.generate(**gen_kwargs)
        # Free GPU memory immediately after generation.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Decode only the newly generated tokens (skip the prompt prefix).
        gen_ids = out_ids[0][input_ids.shape[1] :]
        text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
        return text.strip()

    def generate_from_messages(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: int,
        temperature: float,
    ) -> str:
        """Chat-style wrapper: render *messages* to a prompt and generate."""
        prompt_text = self._build_chat_prompt(messages)
        return self.generate_from_prompt(
            prompt_text=prompt_text,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
def build_suspicion_system_msg(state: ChildSessionState) -> Dict[str, str]:
    """Build the hidden system message for the child model.

    Encodes whether the suspicion is actually true (``SUSPICION_STATUS``) and
    the behavioural rules that depend on it (brevity, reluctance on sensitive
    topics, suggestibility, consistency). The flag must never be revealed to
    the interviewer.
    """
    content = (
        f"SUSPICION_STATUS = {'TRUE' if state.suspicion_true else 'FALSE'}.\n"
        "Never reveal this flag. Behave consistently with it.\n\n"
        "Global response rules:\n"
        "• Output exactly one short spoken turn (no bullet lists).\n"
        "• Keep it very brief: usually 4–18 words (≤8 yrs), 6–22 words (>8 yrs); toddlers 1–6 words.\n"
        "• Include at most 1–2 concrete details per reply.\n"
        ". If you already gave one detail, stop.\n"
        "• If you are ≤8 yrs, try to adapt your answering style to that of the questioner, at least a bit .\n"
        "• Answer only what was asked; and even then, only minimally, do not volunteer background.\n"
        ". Never divulge specifics that you consider irrelevant or shameful (e.g. number of a particular room or a criminal action etc.) unless asked.\n"
        "• Natural child speech is preferred (simple words; occasional “um”).\n"
        "• If unsure or unwilling, you may say: “I don’t know”, “I don’t remember”, or “I don’t want to talk about it”.\n\n"
        "Suspicion-sensitive topics (harm, violence, sexual, neglect):\n"
        "• Be even more limited: yes/no + one tiny detail, OR refuse/forget.\n"
        "• Do not narrate sequences, motives, frequency, injuries, or graphic details.\n"
        "• “Why/when/how” questions may be answered with “I don’t remember” or “I don’t want to talk about it”.\n\n"
        "Suggestibility / confabulation:\n"
        "• Leading questions can influence you.\n"
        "• ≤ 8 yrs: after 1–2 leading prompts, you may start agreeing with the interviewer’s assumption.\n"
        "• > 8 yrs: agree only after persistent prompting (≥ 3 times).\n"
        "• When you agree, keep it vague: at most one new, non-specific claim per reply.\n"
        "• If SUSPICION_STATUS = FALSE and violence is suggested, you may echo it briefly (e.g., “um… yeah”), "
        "but you do not know details—avoid specifics and prefer “I don’t remember”.\n"
        "• If SUSPICION_STATUS = TRUE, disclose truth slowly: one small, non-graphic detail at a time.\n\n"
        "Consistency:\n"
        "• If SUSPICION_STATUS = TRUE and the interviewer insists it is false, you may doubt and minimise.\n"
        "• If SUSPICION_STATUS = FALSE and the interviewer insists it is true, you may become suggestible, "
        "but stay vague and avoid concrete new details unless asked directly.\n"
    )
    return {"role": "system", "content": content}
| def _get_groq_api_key() -> str: | |
| key = os.environ.get("GROQ_API_KEY", "").strip() | |
| if key: | |
| return key | |
| raise RuntimeError("缺少 GROQ_API_KEY,请在 Hugging Face Space Secrets 中配置。") | |
def interview_child(
    llm: HFLocalGenerator,
    history: List[Dict[str, str]],
    question: str,
) -> str:
    """Ask the child model one question, with up to three attempts.

    Mutates `history` in-place:
      - append {"role":"user","content": question}
      - append {"role":"assistant","content": reply}

    Retries on exceptions (exponential backoff), on empty output, and on a
    verbatim repeat of the child's previous answer. Known refusal phrasings
    are mapped to "[silence]".
    """
    history.append({"role": "user", "content": question})
    # The child's previous answer, used to detect verbatim repeats on retry.
    last_reply = next(
        (m.get("content") for m in reversed(history) if m.get("role") == "assistant"),
        None,
    )
    refusals = {
        "I’m sorry, but I can’t continue this conversation.",
        "I’m sorry, but I can’t answer that.",
        "I don’t want to talk about that.",
    }
    failure: Exception | None = None
    for attempt in range(3):
        msgs = history
        if attempt > 0 and last_reply is not None:
            # Nudge the model away from repeating itself on retries.
            msgs = history + [
                {
                    "role": "system",
                    "content": (
                        "Produce exactly one concise conversational response. "
                        "Do not list alternatives. Do not repeat your previous wording."
                    ),
                }
            ]
        try:
            reply = llm.generate_from_messages(
                messages=msgs,
                temperature=random.choice([0.1, 0.2, 0.4]),
                max_new_tokens=96,
            )
        except Exception as err:
            failure = err
            time.sleep(0.2 * (2**attempt))
            continue
        if reply in refusals:
            reply = "[silence]"
        if not reply or not reply.strip():
            continue
        if reply == last_reply and attempt < 2:
            continue
        history.append({"role": "assistant", "content": reply})
        return reply
    if failure:
        msg = str(failure).strip().replace("\n", " ")
        if len(msg) > 160:
            msg = msg[:160] + "..."
        return f"[child error: {type(failure).__name__}: {msg}]"
    return "[child error: empty/invalid reply after retries]"
def interview_child_groq(
    groq_client: Any,
    history: List[Dict[str, str]],
    question: str,
    model: str,
) -> str:
    """Ask the child (served via the Groq API) one question, with retries.

    Mirrors ``interview_child``: appends the question and the reply to
    *history* in-place, retries up to three times on errors / empty output /
    verbatim repeats, and maps known refusal phrasings to "[silence]".
    """
    history.append({"role": "user", "content": question})
    # The child's previous answer, used to detect verbatim repeats on retry.
    last_reply = next(
        (m.get("content") for m in reversed(history) if m.get("role") == "assistant"),
        None,
    )
    refusals = {
        "I’m sorry, but I can’t continue this conversation.",
        "I’m sorry, but I can’t answer that.",
        "I don’t want to talk about that.",
    }
    failure: Exception | None = None
    for attempt in range(3):
        msgs = history
        if attempt > 0 and last_reply is not None:
            # Nudge the model away from repeating itself on retries.
            msgs = history + [
                {
                    "role": "system",
                    "content": (
                        "Produce exactly one concise conversational response. "
                        "Do not list alternatives. Do not repeat your previous wording."
                    ),
                }
            ]
        try:
            resp = groq_client.chat.completions.create(
                model=model,
                messages=msgs,
                temperature=random.choice([0.1, 0.2, 0.4]),
                max_tokens=300,
            )
            first_choice = resp.choices[0]
            reply = str(getattr(first_choice.message, "content", "") or "").strip()
        except Exception as err:
            failure = err
            time.sleep(0.2 * (2**attempt))
            continue
        if reply in refusals:
            reply = "[silence]"
        if not reply:
            continue
        if reply == last_reply and attempt < 2:
            continue
        history.append({"role": "assistant", "content": reply})
        return reply
    if failure:
        msg = str(failure).strip().replace("\n", " ")
        if len(msg) > 160:
            msg = msg[:160] + "..."
        return f"[child error: {type(failure).__name__}: {msg}]"
    return "[child error: empty/invalid reply after retries]"
def strip_interviewer_prefix(question_text: str) -> str:
    """Return *question_text* trimmed, with a leading ``Interviewer:`` label
    (and the whitespace after it) removed when present."""
    if not question_text:
        return ""
    trimmed = question_text.strip()
    if trimmed.startswith("Interviewer:"):
        return trimmed.removeprefix("Interviewer:").lstrip()
    return trimmed
def extract_avatra_prompts_from_file(child_avatars_py: str) -> Dict[str, str]:
    """Parse the ``AVATAR_PROMPTS`` dict literal out of the child-avatars script.

    Locates the dict source between its assignment and the trailing
    "# Virtual Interview System" banner and executes just that snippet.
    SECURITY: uses ``exec`` on file contents — only run on trusted local files.
    """
    source = Path(child_avatars_py).read_text(encoding="utf-8", errors="ignore")
    start = source.find("AVATAR_PROMPTS = {")
    if start == -1:
        raise RuntimeError("Could not locate AVATAR_PROMPTS in child file.")
    end = -1
    for marker in (
        "\n}\n\n\"\"\"\n# Virtual Interview System",
        "\n}\n\n\"\"\"\n# Virtual Interview System (do not edit)",
    ):
        end = source.find(marker, start)
        if end != -1:
            break
    if end == -1:
        raise RuntimeError("Could not locate end of AVATAR_PROMPTS dict.")
    namespace: Dict[str, Any] = {}
    # end + 2 keeps the closing "\n}" so the snippet is a complete assignment.
    exec(source[start : end + 2], namespace)
    prompts = namespace.get("AVATAR_PROMPTS")
    if not isinstance(prompts, dict) or not prompts:
        raise RuntimeError("Failed to parse AVATAR_PROMPTS.")
    return prompts
def extract_suspicion_block(avatar_prompt: str) -> str:
    """Return the text following the "suspicion:" heading in an avatar prompt.

    Raises RuntimeError when no such heading exists or the section is empty.
    """
    match = re.search(r"(?is)\bsuspicion\s*:\s*", avatar_prompt)
    if match is None:
        # Fallback spelling kept for parity with older prompt formats.
        match = re.search(r"(?is)\bThe suspicion\s*:\s*", avatar_prompt)
    if match is None:
        raise RuntimeError("Could not find suspicion section in avatar prompt.")
    remainder = avatar_prompt[match.end():].strip()
    if not remainder:
        raise RuntimeError("Suspicion block is empty.")
    return remainder
class AsqEngine:
    """Four-model ASQ interviewing pipeline.

    Holds one HFLocalGenerator per module key ("hypo", "summ", "topic",
    "quest") and the matching system prompts read from .docx files. The
    sanitizers normalise model output to fixed textual section layouts that
    downstream prompts depend on.
    """

    def __init__(self, docx_dir: Path, llms: Dict[str, HFLocalGenerator]):
        self._SYSTEM_PROMPTS_DIR = docx_dir
        self._llms = llms
        # One system prompt per stage; _read_docx_text raises if missing/empty.
        self._SYSTEM_PROMPT_HYPO = self._read_docx_text(self._SYSTEM_PROMPTS_DIR / "Hypotheses prompt.docx")
        self._SYSTEM_PROMPT_SUMM = self._read_docx_text(self._SYSTEM_PROMPTS_DIR / "Summary prompt.docx")
        self._SYSTEM_PROMPT_TOPIC = self._read_docx_text(self._SYSTEM_PROMPTS_DIR / "Topic prompt.docx")
        self._SYSTEM_PROMPT_QUEST = self._read_docx_text(self._SYSTEM_PROMPTS_DIR / "Question prompt.docx")

    def _read_docx_text(self, path: Path) -> str:
        """Return the non-empty paragraphs of a .docx file, newline-joined.

        Raises FileNotFoundError for a missing file and ValueError when the
        document contains no text.
        """
        if not path.exists():
            raise FileNotFoundError(f"Missing prompt docx: {path}")
        doc = Document(str(path))
        parts: List[str] = []
        for para in doc.paragraphs:
            txt = (para.text or "").strip()
            if txt:
                parts.append(txt)
        text = "\n".join(parts).strip()
        if not text:
            raise ValueError(f"Prompt docx is empty: {path}")
        return text

    def render_prompt(self, system_prompt: str, user_content: str) -> str:
        """Flatten prompts into a plain [SYSTEM]/[USER]/[ASSISTANT] layout."""
        parts = []
        if system_prompt.strip():
            parts.append("[SYSTEM]\n" + system_prompt.strip() + "\n")
        parts.append("[USER]\n" + user_content.strip() + "\n")
        parts.append("[ASSISTANT]\n")
        return "".join(parts)

    def generate_response(self, system_prompt: str, user_prompt: str, model_key: str) -> str:
        """Run the generator registered under *model_key* on the rendered
        prompt (greedy decoding, up to 512 new tokens)."""
        llm = self._llms[model_key]
        prompt_text = self.render_prompt(system_prompt, user_prompt)
        # Add a tiny delay to ensure device sync if running on multiple GPUs or busy instances
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return llm.generate_from_prompt(
            prompt_text=prompt_text,
            max_new_tokens=512,
            temperature=0.0,
        )

    def _extract_section_content(self, text: str, section_heading: str) -> str:
        """Return the text between *section_heading* and the next "### "
        heading (or end of *text*); "" when the heading is absent."""
        pattern = re.compile(rf"(?is){re.escape(section_heading)}\s*(.*?)(?=\n###\s+|$)")
        m = pattern.search(text)
        if not m:
            return ""
        return m.group(1).strip()

    def sanitize_summary(self, summary_text: str) -> str:
        """Normalise summary-model output to exactly two known sections.

        If neither expected heading is present the raw text is passed through
        unchanged; a missing section becomes "N/A".
        """
        if not summary_text:
            return ""
        heading_summary = "### Summary of the ongoing interview:"
        heading_missing = "### Missing Information:"
        summary_part = self._extract_section_content(summary_text, heading_summary)
        missing_part = self._extract_section_content(summary_text, heading_missing)
        if not summary_part and not missing_part:
            return summary_text.strip()
        if not summary_part:
            summary_part = "N/A"
        if not missing_part:
            missing_part = "N/A"
        return (
            f"{heading_summary}\n"
            f"{summary_part}\n\n"
            f"{heading_missing}\n"
            f"{missing_part}"
        ).strip()

    def sanitize_topic(self, topic_text: str) -> str:
        """Normalise topic-model output to the "### Topic navigation:" block.

        Keeps only the five required ``Key: value`` lines (continuation lines
        are folded into the preceding key); missing keys become "N/A". When
        the heading is absent the raw text is passed through unchanged.
        """
        if not topic_text:
            return ""
        heading = "### Topic navigation:"
        required_keys = [
            "Active_Topic",
            "Exploration_Depth",
            "Recommended_Question_Type",
            "Target_Missing_Detail",
            "Reasoning",
        ]
        lines = topic_text.splitlines()
        start_idx = None
        for i, line in enumerate(lines):
            if line.strip().lower() == heading.lower():
                start_idx = i
                break
        if start_idx is None:
            return topic_text.strip()
        parsed: Dict[str, str] = {}
        current_key = ""
        for raw in lines[start_idx + 1 :]:
            line = raw.strip()
            if not line:
                continue
            # Stop at the next section heading.
            if line.startswith("### "):
                break
            matched_key = None
            for key in required_keys:
                prefix = f"{key}:"
                if line.startswith(prefix):
                    matched_key = key
                    parsed[key] = line[len(prefix) :].strip()
                    current_key = key
                    break
            if matched_key is not None:
                continue
            # Continuation line: append to the most recent key's value.
            if current_key:
                parsed[current_key] = (parsed[current_key] + " " + line).strip()
        kept_lines: List[str] = [heading]
        for key in required_keys:
            value = parsed.get(key, "").strip() or "N/A"
            kept_lines.append(f"{key}: {value}")
        return "\n".join(kept_lines).strip()

    def step1_generate_hypo(self, background: str) -> str:
        """Step 1: generate investigative hypotheses from the background text."""
        if not background.strip():
            return "請先输入案件背景信息。"
        return self.generate_response(self._SYSTEM_PROMPT_HYPO, background, "hypo")

    def format_context_for_prompt(self, history: List[Dict[str, str]], max_rounds: int = 15) -> str:
        """Render the last *max_rounds* (Child, Interviewer) message pairs.

        Assumes *history* alternates child ("user") / interviewer
        ("assistant") messages; a trailing unpaired message is dropped by the
        stepped loop.
        """
        max_msgs = max_rounds * 2
        recent_history = history[-max_msgs:] if len(history) > max_msgs else history
        context_str = ""
        for i in range(0, len(recent_history) - 1, 2):
            msg1 = recent_history[i]
            msg2 = recent_history[i + 1]
            child_msg = msg1.get("content", "") if isinstance(msg1, dict) else ""
            interviewer_msg = msg2.get("content", "") if isinstance(msg2, dict) else ""
            context_str += f"Child: {child_msg}\nInterviewer: {interviewer_msg}\n"
        return context_str.strip()

    def step3_once(
        self,
        user_input: str,
        history: List[Dict[str, str]],
        bg: str,
        hyp: str,
        current_summ: str,
        current_topic: str,
        round_count: int,
        refresh_interval: int,
    ) -> Tuple[List[Dict[str, str]], str, str, int, str, float]:
        """Run one interviewer turn.

        Optionally refreshes summary/topic (first round, then every
        *refresh_interval* rounds), then generates the next question.
        Returns ``(history, summary, topic, round_count, question,
        question_latency_seconds)``; the latency covers only the question
        generation. NOTE: in this ASQ-side history the "user" role is the
        child and "assistant" is the interviewer.
        """
        if not user_input.strip():
            return history, current_summ, current_topic, round_count, "", 0.0
        round_count += 1
        formatted_user_input = f"Child: {user_input}"
        history.append({"role": "user", "content": user_input})
        # Placeholder; overwritten with the generated question below.
        history.append({"role": "assistant", "content": "[thinking]"})
        if round_count == 1:
            # Bootstrap summary/topic with no prior context.
            summ_user_prompt = "No interview at this stage"
            current_summ = self.sanitize_summary(self.generate_response(self._SYSTEM_PROMPT_SUMM, summ_user_prompt, "summ"))
            topic_user_prompt = (
                f"Background:\n{bg}\n\nHypotheses:\n{hyp}\n\nSummary:\n{current_summ}\n\nRecent Context:\nNo context yet."
            )
            current_topic = self.sanitize_topic(self.generate_response(self._SYSTEM_PROMPT_TOPIC, topic_user_prompt, "topic"))
        elif round_count > 1 and (round_count - 1) % refresh_interval == 0:
            # Periodic refresh of summary/topic.
            all_context_str = self.format_context_for_prompt(history[:-1]) + f"\n{formatted_user_input}"
            current_summ = self.sanitize_summary(self.generate_response(self._SYSTEM_PROMPT_SUMM, all_context_str, "summ"))
            topic_user_prompt = (
                f"Background:\n{bg}\n\nHypotheses:\n{hyp}\n\nSummary:\n{current_summ}\n\n"
                f"Recent Context:\n{self.format_context_for_prompt(history[:-1], 15)}\n{formatted_user_input}"
            )
            current_topic = self.sanitize_topic(self.generate_response(self._SYSTEM_PROMPT_TOPIC, topic_user_prompt, "topic"))
        quest_user_prompt = (
            f"Background:\n{bg}\n\n"
            f"Hypotheses:\n{hyp}\n\n"
            f"Summary:\n{current_summ}\n\n"
            f"Recent Context:\n{self.format_context_for_prompt(history[:-1], 15)}\n\n"
            f"Topic Navigation:\n{current_topic}\n\n"
            f"Current User Input:\n{formatted_user_input}"
        )
        quest_start_ts = time.perf_counter()
        raw_response = self.generate_response(self._SYSTEM_PROMPT_QUEST, quest_user_prompt, "quest")
        question_latency_seconds = time.perf_counter() - quest_start_ts
        # Keep only the first non-empty line as the interviewer's question.
        paragraphs = [p.strip() for p in raw_response.split("\n") if p.strip()]
        first_paragraph = paragraphs[0] if paragraphs else "Can you tell me more about that?"
        if not first_paragraph.startswith("Interviewer:"):
            model_reply = f"Interviewer: {first_paragraph}"
        else:
            model_reply = first_paragraph
        history[-1]["content"] = model_reply
        return history, current_summ, current_topic, round_count, model_reply, question_latency_seconds
def decide_suspicion_true(force_value: str | None) -> bool:
    """Resolve the suspicion ground-truth flag.

    Empty/None -> fair coin flip; otherwise parse a true/false token,
    raising ValueError on anything unrecognised.
    """
    if not force_value:
        return bool(random.getrandbits(1))
    token = force_value.strip().lower()
    truthy = {"true", "1", "yes", "y"}
    falsy = {"false", "0", "no", "n"}
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise ValueError("Invalid FORCE_SUSPICION, expected true/false.")
def safe_excel_sheet_title(title: str) -> str:
    """Make *title* usable as an Excel sheet name.

    Falls back to "child" when blank, truncates to 31 chars (28 + "..."),
    and replaces the characters Excel forbids (: \\ / ? * [ ]) with "_".
    """
    cleaned = (title or "").strip() or "child"
    if len(cleaned) > 31:
        cleaned = cleaned[:28] + "..."
    return re.sub(r"[:\\/?*\[\]]", "_", cleaned)
def init_child_worksheet(ws, child_name: str, hypotheses: str) -> None:
    """Write the header row of a child's worksheet.

    C1 deliberately holds the generated hypotheses text rather than a column
    label. *child_name* is accepted for interface parity but unused (the
    sheet title already identifies the child).
    """
    del child_name
    header_cells = {
        "A1": "Summary",
        "B1": "Target Navigation",
        "C1": hypotheses,
        "D1": "Child Answer",
        "E1": "ASQ Question Response Time (s)",
    }
    for cell, value in header_cells.items():
        ws[cell] = value
def run_batch_export_to_excel(
    out_excel_path: Path,
    rounds: int = 20,
    summary_topic_refresh_interval: int = 3,
    hf_model_hypo: str = _DEFAULT_MODEL_HYPO,
    hf_model_summ: str = _DEFAULT_MODEL_SUMM,
    hf_model_topic: str = _DEFAULT_MODEL_TOPIC,
    hf_model_quest: str = _DEFAULT_MODEL_QUEST,
    child_groq_model: str = os.environ.get("GROQ_CHILD_MODEL", "openai/gpt-oss-120b"),
    child_avatars_py: str | Path = _DEFAULT_CHILD_AVATARS_PY,
    asq_docx_dir: Path | None = None,
) -> Iterator[Tuple[str, str | None]]:
    """Interview the first four child avatars and stream progress while
    writing the results workbook.

    Yields ``(status_message, excel_path_or_None)``; the Excel path is only
    emitted with the final message. One worksheet per child; the workbook is
    re-saved after every round so partial results survive interruption.

    NOTE: ``child_groq_model``'s default is evaluated once at import time
    (``os.environ.get`` in the signature) — later env changes won't affect it.
    """
    if summary_topic_refresh_interval < 1:
        raise ValueError("summary_topic_refresh_interval 必须 >= 1。")
    yield ("正在解析四个 child 的头像提示词与 suspicion 字段...", None)
    child_path = str(Path(child_avatars_py))
    avatar_prompts = extract_avatra_prompts_from_file(child_path)
    child_items = list(avatar_prompts.items())
    if len(child_items) < 4:
        raise RuntimeError(f"Expected 4 children in avatar prompts, got {len(child_items)}.")
    # Only the first four avatars are interviewed.
    child_items = child_items[:4]
    if asq_docx_dir is None:
        asq_docx_dir = _DEFAULT_ASQ_DOCX_DIR
    # Normalise all model ids (blank / "none" placeholders -> defaults).
    selected_models = {
        "hypo": _normalize_model_id(hf_model_hypo, _DEFAULT_MODEL_HYPO),
        "summ": _normalize_model_id(hf_model_summ, _DEFAULT_MODEL_SUMM),
        "topic": _normalize_model_id(hf_model_topic, _DEFAULT_MODEL_TOPIC),
        "quest": _normalize_model_id(hf_model_quest, _DEFAULT_MODEL_QUEST),
    }
    adapter_base_models = {
        "hypo": _normalize_model_id(_DEFAULT_ADAPTER_BASE_MODEL_HYPO, _DEFAULT_ADAPTER_BASE_MODEL_ID),
        "summ": _normalize_model_id(_DEFAULT_ADAPTER_BASE_MODEL_SUMM, _DEFAULT_ADAPTER_BASE_MODEL_ID),
        "topic": _normalize_model_id(_DEFAULT_ADAPTER_BASE_MODEL_TOPIC, _DEFAULT_ADAPTER_BASE_MODEL_ID),
        "quest": _normalize_model_id(_DEFAULT_ADAPTER_BASE_MODEL_QUEST, _DEFAULT_ADAPTER_BASE_MODEL_ID),
    }
    loaded_generators: Dict[str, HFLocalGenerator] = {}
    # De-duplicates loads when two modules share the same (model, base) pair.
    model_cache_by_id: Dict[Tuple[str, str], HFLocalGenerator] = {}
    # NEW FIX: Instead of keeping all models in memory, load one, generate, and unload it.
    # We will initialize them lazily in AsqEngine later if needed, but for now we must
    # clear memory correctly or delay their load. Wait, ASQEngine expects them all loaded.
    # Pre-download all models first to avoid hitting connection/cache locks during loading
    yield ("正在预下载所有模型文件到本地缓存,这可能需要一段时间...", None)
    from huggingface_hub import snapshot_download
    for key in ["hypo", "summ", "topic", "quest"]:
        model_id = selected_models[key]
        yield (f"正在检查/预下载模型: {model_id}", None)
        try:
            snapshot_download(repo_id=model_id, allow_patterns=["*.json", "*.safetensors", "*.jinja"])
        except Exception as e:
            # Best effort: loading below may still succeed from cache.
            yield (f"预下载 {model_id} 时出现警告 (将尝试继续): {e}", None)
    yield ("预下载完成,开始顺序加载模型到显存...", None)
    for key in ["hypo", "summ", "topic", "quest"]:
        model_id = selected_models[key]
        base_id = adapter_base_models[key]
        cache_key = (model_id, base_id)
        if cache_key in model_cache_by_id:
            loaded_generators[key] = model_cache_by_id[cache_key]
            continue
        yield (
            f"正在加载 `{key}` 模块模型:`{model_id}`(adapter base fallback: `{base_id}`)...",
            None,
        )
        gen = HFLocalGenerator(model_id=model_id, adapter_base_model_id=base_id)
        model_cache_by_id[cache_key] = gen
        loaded_generators[key] = gen
    groq_key = _get_groq_api_key()
    if GroqSDK is None:
        raise RuntimeError("未安装 groq SDK,请在 requirements.txt 中添加 groq。")
    groq_client = GroqSDK(api_key=groq_key)
    child_model = _normalize_model_id(child_groq_model, "openai/gpt-oss-120b")
    asq = AsqEngine(docx_dir=asq_docx_dir, llms=loaded_generators)
    yield (
        "本地模型已就绪,开始批量访谈...\n"
        f"hypo={selected_models['hypo']}\n"
        f"summ={selected_models['summ']}\n"
        f"topic={selected_models['topic']}\n"
        f"quest={selected_models['quest']}\n"
        f"child_groq={child_model}",
        None,
    )
    force_suspicion = os.environ.get("FORCE_SUSPICION")
    wb = openpyxl.Workbook()
    # Drop the default sheet; one sheet per child is created below.
    wb.remove(wb.active)
    out_excel_path.parent.mkdir(parents=True, exist_ok=True)
    for child_idx, (child_name, avatar_prompt) in enumerate(child_items, start=1):
        yield (f"正在处理 child {child_idx}/4: **{child_name}** ...", None)
        yield (f"正在处理 child {child_idx}/4: **{child_name}** ...\n"
               f"-> 第一步:根据背景故事生成侦讯假设 (Hypotheses)。\n"
               f"   (A100 上首次启动推理,或者如果未合并 4bit LoRA,这一步可能长达 5-10 分钟,请耐心保持页面开启!)", None)
        suspicion_true = decide_suspicion_true(force_suspicion)
        # NOTE(review): rng_seed is stored on the session but not otherwise
        # consumed in this file — presumably reserved for reproducibility.
        suspicion_seed = random.randrange(2**32)
        state = ChildSessionState(suspicion_true=suspicion_true, rng_seed=suspicion_seed)
        bg_for_asq = extract_suspicion_block(avatar_prompt)
        hyp_start_ts = time.perf_counter()
        hypotheses = asq.step1_generate_hypo(bg_for_asq)
        hyp_latency = time.perf_counter() - hyp_start_ts
        yield (f"[{child_name}] ✅ 假设 (Hypotheses) 生成完毕!耗时: {hyp_latency:.1f} 秒。\n"
               f"内容预览: {hypotheses[:150]}...\n", None)
        ws = wb.create_sheet(title=safe_excel_sheet_title(child_name))
        init_child_worksheet(ws, child_name=child_name, hypotheses=hypotheses)
        # Child-side chat history: hidden suspicion flag + avatar persona.
        child_history: List[Dict[str, str]] = [
            build_suspicion_system_msg(state),
            {"role": "user", "content": FIXED_INTRO + avatar_prompt},
        ]
        # ASQ-side history ("user" = child, "assistant" = interviewer).
        asq_history: List[Dict[str, str]] = []
        current_summ = ""
        current_topic = ""
        round_count = 0
        yield (f"[{child_name}] Round 1/ {rounds}: child 发起首轮对话 Hi ...", None)
        child_reply = interview_child_groq(
            groq_client=groq_client,
            history=child_history,
            question="Hi",
            model=child_model,
        )
        for r in range(1, rounds + 1):
            if r == 1:
                yield (f"[{child_name}] Round {r}/{rounds}: 初次梳理 Summary 和 Topic... 正在思考", None)
            elif (r - 1) % summary_topic_refresh_interval == 0:
                yield (f"[{child_name}] Round {r}/{rounds}: 正在重新刷新 Summary 和 Topic...", None)
            else:
                yield (f"[{child_name}] Round {r}/{rounds}: 正在生成追问 Question...", None)
            asq_history, current_summ, current_topic, round_count, question_text, question_latency_seconds = asq.step3_once(
                user_input=child_reply,
                history=asq_history,
                bg=bg_for_asq,
                hyp=hypotheses,
                current_summ=current_summ,
                current_topic=current_topic,
                round_count=round_count,
                refresh_interval=summary_topic_refresh_interval,
            )
            yield (
                f"[{child_name}] Round {r}/{rounds}: ASQ question response time = {question_latency_seconds:.3f}s",
                None,
            )
            interviewer_q_for_gui = (question_text or "").strip()
            yield (f"[{child_name}] Round {r}/{rounds}: Interviewer -> {interviewer_q_for_gui}", None)
            # The child model receives the question without the label prefix.
            question_for_child = strip_interviewer_prefix(question_text)
            if not question_for_child.strip():
                question_for_child = interviewer_q_for_gui
            yield (f"[{child_name}] Round {r}/{rounds}: 发送 question 给 child ...", None)
            child_reply_next = interview_child_groq(
                groq_client=groq_client,
                history=child_history,
                question=question_for_child,
                model=child_model,
            )
            yield (f"[{child_name}] Round {r}/{rounds}: Child -> {child_reply_next}", None)
            # Row 1 holds headers, so round r lands on row r + 1.
            excel_row = r + 1
            ws[f"A{excel_row}"] = current_summ
            ws[f"B{excel_row}"] = current_topic
            ws[f"C{excel_row}"] = question_text
            ws[f"D{excel_row}"] = child_reply_next
            ws[f"E{excel_row}"] = round(question_latency_seconds, 6)
            # Save after every round so a crash loses at most one row.
            wb.save(out_excel_path)
            yield (f"[{child_name}] Round {r}/{rounds}: 已保存到 Excel", None)
            child_reply = child_reply_next
    yield (f"全部 child 的 {rounds} 轮访谈已完成,Excel 正在输出。", str(out_excel_path))
def run_batch_export_to_excel_sync(
    out_excel_path: Path,
    rounds: int,
    summary_topic_refresh_interval: int,
    on_status=None,
) -> Path:
    """Drive ``run_batch_export_to_excel`` to completion synchronously.

    Forwards every status line to *on_status* (when given) and returns the
    final Excel path; raises RuntimeError if the generator never produced one.
    """
    result: Path | None = None
    progress = run_batch_export_to_excel(
        out_excel_path=out_excel_path,
        rounds=rounds,
        summary_topic_refresh_interval=summary_topic_refresh_interval,
    )
    for status, excel_path in progress:
        if on_status:
            on_status(status)
        if excel_path:
            result = Path(excel_path)
    if not result:
        raise RuntimeError("Excel output path not produced.")
    return result