# constraint_parser_llm.py
"""Turn free-text scheduling requests into structured constraints.

The primary path prompts an instruction-tuned causal LM (loaded lazily through
Hugging Face ``transformers``) with a JSON schema plus few-shot examples, then
sanitizes the returned JSON into a :class:`ParsedConstraints`. When no usable
JSON comes back, a small regex parser (``_fallback_light``) is used instead.
"""
from __future__ import annotations

import json
import os
import re
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, List, Optional

# Choose one (Llama3.1 is slightly better at structured JSON)
MODEL_NAME = os.environ.get("PARSER_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
# Alternative:
# MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"

# Populated on first use by _lazy_pipe().
_tokenizer = None
_pipe = None


def _lazy_pipe():
    """Build the text-generation pipeline on first call and return it.

    ``transformers`` is imported here rather than at module level, so importing
    this module (or using only the regex fallback) does not load the library.
    The model is downloaded/loaded on the first call only; later calls reuse
    the cached module-level ``_pipe``.
    """
    global _tokenizer, _pipe
    if _pipe is None:
        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto",  # works CPU/GPU
            # NOTE(review): trust_remote_code executes Python code shipped in
            # the model repo. MODEL_NAME is overridable via $PARSER_MODEL and
            # the default/alternative models do not need it -- consider
            # disabling.
            trust_remote_code=True,
        )
        _pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=_tokenizer,
            max_new_tokens=256,
            # Greedy decoding for deterministic output. No `temperature`:
            # it is ignored when do_sample=False and only triggers a warning.
            do_sample=False,
        )
    return _pipe


DAY_ABBR = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"]
DAY_ALIASES = {
    "sunday": "Sun", "sundays": "Sun", "sun": "Sun",
    "monday": "Mon", "mondays": "Mon", "mon": "Mon",
    "tuesday": "Tue", "tuesdays": "Tue", "tue": "Tue",
    "wednesday": "Wed", "wednesdays": "Wed", "wed": "Wed",
    "thursday": "Thu", "thursdays": "Thu", "thu": "Thu",
    "friday": "Fri", "fridays": "Fri", "fri": "Fri",
    "saturday": "Sat", "saturdays": "Sat", "sat": "Sat",
}


@dataclass
class ParsedConstraints:
    """Structured scheduling constraints extracted from one user request."""

    subject_counts: Dict[str, int]  # subject name -> number of classes (> 0)
    banned_days: List[str]  # entries drawn from DAY_ABBR
    no_before: Optional[str] = None  # "no classes before" time, e.g. "10:00 AM"
    no_after: Optional[str] = None  # "no classes after" time, e.g. "6:00 PM"
    # default_factory gives each instance its own list (the old `= None`
    # contradicted the List[str] annotation).
    keywords: List[str] = field(default_factory=list)
    banned_professors: List[str] = field(default_factory=list)

    def to_json(self) -> str:
        """Serialize every field into a JSON object string."""
        return json.dumps(asdict(self))


# Top-level keys a well-formed model answer may contain.
_SCHEMA_KEYS = frozenset(f.name for f in fields(ParsedConstraints))

# NOTE(review): _fallback_light reads no_before/no_after as "no classes
# before/after this time". Under that reading the "'no mornings' => no_after"
# rule below, and the matching third FEW_SHOTS example, look inverted: they
# would keep only morning classes. Confirm the intended semantics.
SYSTEM_PROMPT = """You convert natural language scheduling requests into STRICT JSON.
Output only JSON with this schema:
{
  "subject_counts": {"<subject>": <count>},  // optional
  "banned_days": ["Sun","Mon","Tue","Wed","Thu","Fri","Sat"],  // optional
  "no_before": "H:MM AM/PM",  // optional
  "no_after": "H:MM AM/PM",   // optional
  "keywords": ["..."],        // optional
  "banned_professors": ["..."] // optional
}
Rules:
- Never include commentary.
- Output must be valid JSON: no // comments, no trailing commas.
- Use only "Sun","Mon","Tue","Wed","Thu","Fri","Sat".
- If the user says 'no weekends', add 'Sat' and 'Sun'.
- If they say 'mornings'/'afternoons'/'evenings', map to times:
  * mornings => no_before >= "10:00 AM" (i.e., prefer starts at/after 10am)
  * afternoons => no_before >= "12:00 PM"
  * evenings => no_before >= "4:00 PM"
  * 'no mornings' => no_after <= "12:00 PM"
- If counts are implied (e.g., "a couple of CS classes"), interpret: couple=2, few=3, several=3.
- If ambiguous, make reasonable assumptions but keep to schema.
"""

FEW_SHOTS = [
    # (user, json)
    ("make me 4 CS classes, no Sundays",
     {"subject_counts": {"Computer Science": 4}, "banned_days": ["Sun"]}),
    ("i want two econ and one psychology after 1pm, avoid friday",
     {"subject_counts": {"Economics": 2, "Psychology": 1}, "no_before": "1:00 PM", "banned_days": ["Fri"]}),
    ("prefer software engineering and algorithms, no class on mon or tue mornings",
     {"keywords": ["software engineering", "algorithms"], "banned_days": ["Mon", "Tue"], "no_after": "12:00 PM"}),
    ("no weekends, evenings only",
     {"banned_days": ["Sat", "Sun"], "no_before": "4:00 PM"}),
    ("skip wed; a couple math; no classes before 10am",
     {"subject_counts": {"Mathematics": 2}, "banned_days": ["Wed"], "no_before": "10:00 AM"}),
]


def _build_prompt(user_text: str) -> str:
    """Assemble system prompt + few-shot pairs + the user's request.

    The prompt ends with ``JSON:`` so the model's continuation is the answer.
    """
    parts = [SYSTEM_PROMPT, "\nExamples:"]
    for u, js in FEW_SHOTS:
        parts.append(f"User: {u}\nJSON: {json.dumps(js)}")
    parts.append(f"\nUser: {user_text}\nJSON:")
    return "\n".join(parts)


_JSON_DECODER = json.JSONDecoder()


def _extract_json(text: str) -> Optional[dict]:
    """Return the first constraint-shaped JSON object embedded in *text*.

    Models may wrap the object in prose or a ```json fence, or keep generating
    further "User: ... JSON: {...}" turns after answering. ``raw_decode`` parses
    exactly one complete value starting at a given '{' and ignores whatever
    follows, so trailing text (even with braces) does not break parsing.
    Decoded objects that share no key with the schema (e.g. the inner
    ``{"Math": 2}`` of a malformed outer object) are skipped; ``{}`` is accepted.

    Returns None when no such object is found.
    """
    if not text:
        return None
    for brace in re.finditer(r"\{", text):
        try:
            obj, _ = _JSON_DECODER.raw_decode(text, brace.start())
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        if isinstance(obj, dict) and (not obj or not _SCHEMA_KEYS.isdisjoint(obj)):
            return obj
    return None


def _normalize_days(days: Optional[List[str]]) -> List[str]:
    """Map day names/abbreviations to DAY_ABBR entries.

    A single string is treated as a one-element list. Non-string and
    unrecognized entries are dropped; duplicates are removed while keeping
    first-seen order.
    """
    if isinstance(days, str):
        days = [days]
    out: List[str] = []
    for d in days or []:
        if not isinstance(d, str):
            continue
        k = d.strip()
        # accept already-correct abbr, else map aliases case-insensitively
        v = k if k in DAY_ABBR else DAY_ALIASES.get(k.lower())
        if v and v not in out:
            out.append(v)
    return out


# "H[:MM] am/pm" (also "a.m.") or 24-hour "H:MM".
_TIME_RE = re.compile(
    r"^\s*([0-9]{1,2})(?::([0-9]{2}))?\s*(?:([ap])\.?\s*m\.?)?\s*$", re.IGNORECASE
)


def _normalize_time(value: Any) -> Optional[str]:
    """Return *value* formatted as "H:MM AM/PM", or None if it is not a time.

    Accepts e.g. "1pm", "10:00 AM", "9:30 a.m." and 24-hour "18:30". A bare
    hour without am/pm ("10") is rejected as ambiguous, as are out-of-range
    hours or minutes.
    """
    if not isinstance(value, str):
        return None
    m = _TIME_RE.match(value)
    if m is None:
        return None
    hour = int(m.group(1))
    minute = int(m.group(2) or 0)
    meridiem = m.group(3)
    if minute > 59:
        return None
    if meridiem:
        if not 1 <= hour <= 12:
            return None
        suffix = "AM" if meridiem.lower() == "a" else "PM"
    else:
        # 24-hour clock; require minutes so a bare number is not guessed at.
        if m.group(2) is None or hour > 23:
            return None
        suffix = "AM" if hour < 12 else "PM"
        hour = hour % 12 or 12
    return f"{hour}:{minute:02d} {suffix}"


def _coerce_subject_counts(raw: Any) -> Dict[str, int]:
    """Keep only non-empty subject names with counts that coerce to int > 0."""
    if not isinstance(raw, dict):
        return {}
    counts: Dict[str, int] = {}
    for subject, value in raw.items():
        if not isinstance(subject, str) or not subject.strip():
            continue
        try:
            n = int(value)
        except (TypeError, ValueError, OverflowError):
            continue
        if n > 0:
            counts[subject.strip()] = n
    return counts


def _list_str(value: Any) -> List[str]:
    """Coerce *value* to a list of non-blank, stripped strings.

    A bare string becomes a one-element list; it is not split into characters.
    """
    if isinstance(value, str):
        value = [value]
    if not isinstance(value, (list, tuple)):
        return []
    return [s.strip() for s in value if isinstance(s, str) and s.strip()]


# Precompiled patterns for the regex fallback.
_NEGATION = r"\b(?:no(?:\s+class(?:es)?)?|avoid|skip)\s+(?:on\s+)?"
_DAY_PATTERNS = [
    (re.compile(_NEGATION + re.escape(alias) + r"\b"), abbr)
    for alias, abbr in DAY_ALIASES.items()
]
_WEEKEND_RE = re.compile(_NEGATION + r"weekends?\b")
_TIME_TEXT = r"([0-9]{1,2}(?::[0-9]{2})?\s*(?:am|pm))"
_BEFORE_RE = re.compile(r"\bno class(?:es)? before " + _TIME_TEXT)
_AFTER_RE = re.compile(r"\bno class(?:es)? after " + _TIME_TEXT)


def _fallback_light(text: str) -> ParsedConstraints:
    """Crude regex parse used when the model gives no usable JSON.

    Recognizes phrases such as "no classes on friday", "avoid mon",
    "skip wed", "no weekends", and "no classes before/after 10am".
    Subject counts, keywords and professors are not extracted.
    """
    t = (text or "").lower()
    found = {abbr for pattern, abbr in _DAY_PATTERNS if pattern.search(t)}
    if _WEEKEND_RE.search(t):
        found.update(("Sat", "Sun"))
    before = _BEFORE_RE.search(t)
    after = _AFTER_RE.search(t)
    return ParsedConstraints(
        subject_counts={},
        banned_days=sorted(found),
        # _normalize_time also rejects out-of-range values like "13pm".
        no_before=_normalize_time(before.group(1)) if before else None,
        no_after=_normalize_time(after.group(1)) if after else None,
        keywords=[],
        banned_professors=[],
    )


def parse_constraints(text: str) -> ParsedConstraints:
    """Parse a natural-language scheduling request into ParsedConstraints.

    Blank input returns empty constraints without touching the model.
    Otherwise the model is prompted (it is loaded on the first call). If its
    output contains no usable JSON object, ``_fallback_light`` is used. Every
    field of the model's JSON is sanitized before being returned. Errors
    raised while loading or running the model propagate to the caller.
    """
    if not text or not text.strip():
        return ParsedConstraints(subject_counts={}, banned_days=[], keywords=[], banned_professors=[])

    prompt = _build_prompt(text.strip())
    # By default the pipeline echoes the prompt in "generated_text"; its schema
    # and few-shot JSON would then be picked up instead of the model's answer.
    result = _lazy_pipe()(prompt, return_full_text=False)
    obj = _extract_json(result[0]["generated_text"])
    if not isinstance(obj, dict):
        # graceful fallback
        return _fallback_light(text)

    return ParsedConstraints(
        subject_counts=_coerce_subject_counts(obj.get("subject_counts")),
        banned_days=_normalize_days(obj.get("banned_days")),
        no_before=_normalize_time(obj.get("no_before")),
        no_after=_normalize_time(obj.get("no_after")),
        keywords=_list_str(obj.get("keywords")),
        banned_professors=_list_str(obj.get("banned_professors")),
    )