# plan_extractor.py
"""Draft a scenario-derived data-analysis plan with an LLM.

Backends are tried in order: a Cohere chat client, then an optional
caller-supplied Hugging Face ``(model, tokenizer)`` pair, then a static
minimal plan. This mirrors the Cohere -> HF fallback pattern used in app.py.
"""
from __future__ import annotations

import json
import logging
import os
import re  # noqa: F401
from typing import Any, Dict, List, Optional

# Optional heavy dependency, nominally for the local fallback. The code below
# never uses these names (callers pass a ready-made pair via ``hf_tuple``), so
# a missing install must not make this module unimportable.
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: F401
except ImportError:  # pragma: no cover - depends on the environment
    AutoModelForCausalLM = AutoTokenizer = None  # type: ignore[assignment]

logger = logging.getLogger(__name__)

_COHERE_MODEL = "command-r7b-12-2024"
_OUTPUT_FORMAT = "structured_analysis_v1"
# A parsed JSON object must contain at least one of these to count as a plan.
_PLAN_KEYS = ("goals", "requires", "output_format")
# Character budget for the column-header list embedded in the prompt.
_COLUMN_HINT_CHARS = 2000

_SYSTEM_PROMPT = (
    "You design a data analysis plan from a scenario. "
    "You do NOT assume undocumented numbers. "
    "You only request metrics that plausibly map to the provided column headers.\n"
    "Return STRICT JSON with keys: goals, requires, output_format.\n"
    "Each goal has a type (rank_top_n | summary_table | delta_over_time | capacity_calc | cost_total | custom), "
    "and parameters (e.g., metric names, groupings, n, filters, periods). "
    "Each requires item lists an input name and a description of how it could map to columns.\n"
)


def _cohere_chat_fn():
    """Return a ``cohere.Client`` built from ``COHERE_API_KEY``, or ``None``.

    Best-effort: no key, a missing ``cohere`` package, or a constructor error
    all yield ``None`` so the caller can move on to the next backend.
    """
    key = os.getenv("COHERE_API_KEY")
    if not key:
        return None
    try:
        import cohere

        return cohere.Client(api_key=key)
    except Exception:  # deliberate best-effort; logged so misconfig is visible
        logger.warning("Could not construct a Cohere client", exc_info=True)
        return None


def _column_hint(column_bag, limit: int = _COLUMN_HINT_CHARS) -> str:
    """Join de-duplicated, sorted header names with ", " within ``limit`` chars.

    Truncation happens on whole-name boundaries so the prompt never contains
    a half-cut (i.e. invented) column name. ``None`` entries and blank names
    are skipped; non-string names are stringified.
    """
    names = sorted({str(c).strip() for c in column_bag or () if c is not None})
    parts: List[str] = []
    used = 0
    for name in names:
        if not name:
            continue
        cost = len(name) + (2 if parts else 0)  # 2 == len(", ")
        if used + cost > limit:
            break
        parts.append(name)
        used += cost
    return ", ".join(parts)


def _build_user_prompt(scenario: str, col_hint: str) -> str:
    """Return the user-turn prompt embedding the scenario and column headers."""
    return (
        f"SCENARIO:\n{scenario}\n\n"
        f"AVAILABLE COLUMN HEADERS (from uploaded files, deduped):\n{col_hint}\n\n"
        "Produce the JSON plan now. Do not invent column names; propose inputs using phrases present in the scenario "
        "or clearly mappable to the provided headers."
    )


def _extract_plan(text) -> Optional[Dict[str, Any]]:
    """Return the first JSON object in ``text`` that looks like a plan, else None.

    Models often wrap JSON in markdown fences or surround it with prose that
    may itself contain braces, so every ``{`` is tried as a decode start. An
    object is accepted only if it is a dict holding at least one plan key.
    Nested goal/requires items, or fragments of truncated output, are
    therefore not mistaken for a plan. Missing plan keys are filled with
    empty defaults.
    """
    if not isinstance(text, str):
        return None
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict) and any(k in obj for k in _PLAN_KEYS):
            obj.setdefault("goals", [])
            obj.setdefault("requires", [])
            obj.setdefault("output_format", _OUTPUT_FORMAT)
            return obj
        start = text.find("{", start + 1)
    return None


def _draft_with_cohere(client, message: str, max_tokens: int) -> Optional[Dict[str, Any]]:
    """Ask a Cohere chat client for a plan; return None on any failure."""
    try:
        resp = client.chat(
            model=_COHERE_MODEL,
            message=message,
            temperature=0.2,
            max_tokens=max_tokens,
        )
    except Exception:  # network/auth/quota errors: fall back, but leave a trace
        logger.warning("Cohere plan drafting failed; trying next backend", exc_info=True)
        return None
    # Different SDK versions expose the reply under different attributes.
    txt = getattr(resp, "text", None) or getattr(resp, "reply", None)
    if not txt:
        return None
    return _extract_plan(txt)


def _draft_with_hf(hf_tuple: tuple, prompt: str, max_tokens: int) -> Optional[Dict[str, Any]]:
    """Generate a plan greedily with a local ``(model, tokenizer)`` pair.

    Generation errors propagate to the caller (as before); only an
    unparseable reply yields None.
    """
    model, tok = hf_tuple
    inputs = tok.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    out = model.generate(inputs.to(model.device), max_new_tokens=max_tokens, do_sample=False)
    # Decode only the newly generated tokens, not the echoed prompt.
    generated = tok.decode(out[0, inputs.shape[-1]:], skip_special_tokens=True)
    return _extract_plan(generated)


def draft_plan_from_scenario(
    scenario_text: str,
    column_bag: List[str],
    cohere_client=None,
    hf_tuple: Optional[tuple] = None,
    max_tokens: int = 800,
) -> Dict[str, Any]:
    """Return a JSON-style plan derived solely from the scenario text.

    The plan is a dict with keys ``goals``, ``requires`` and
    ``output_format``. It may reference column candidates by *ideas*, not
    fixed labels.

    Args:
        scenario_text: Free-text scenario; blank/None yields an empty plan
            with a ``notes`` entry explaining why.
        column_bag: Header names from uploaded files (duplicates, ``None``
            and non-string entries tolerated) used to steer the plan.
        cohere_client: Optional pre-built Cohere client. When None, one is
            built from ``COHERE_API_KEY`` if possible.
        hf_tuple: Optional ``(model, tokenizer)`` pair used when Cohere is
            unavailable or returns no usable plan.
        max_tokens: Generation budget passed to whichever backend runs.

    Returns:
        The first usable plan from Cohere, then HF, else a minimal
        ``summary_table`` fallback plan.
    """
    scenario = (scenario_text or "").strip()
    if not scenario:
        return {
            "goals": [],
            "requires": [],
            "output_format": _OUTPUT_FORMAT,
            "notes": "Empty scenario text; no plan.",
        }

    system_prompt = _SYSTEM_PROMPT
    user_prompt = _build_user_prompt(scenario, _column_hint(column_bag))

    client = cohere_client if cohere_client is not None else _cohere_chat_fn()
    if client is not None:
        plan = _draft_with_cohere(client, system_prompt + "\n\n" + user_prompt, max_tokens)
        if plan is not None:
            return plan

    if hf_tuple is not None:
        plan = _draft_with_hf(hf_tuple, system_prompt + "\n\n" + user_prompt + "\n\nJSON:", max_tokens)
        if plan is not None:
            return plan

    # Ultra-conservative fallback
    return {
        "goals": [{"type": "summary_table", "metrics": [], "by": [], "note": "fallback-empty"}],
        "requires": [],
        "output_format": _OUTPUT_FORMAT,
    }