# constraint_parser.py
"""Turn free-text course-scheduling requests into ``ParsedConstraints``.

``try_regex_first`` does a cheap, deterministic regex pass over the text;
``parse_constraints`` additionally asks a small seq2seq model (``MODEL_NAME``
through a ``transformers`` pipeline) for JSON following
``SCHEMA_INSTRUCTIONS`` and merges the two results.
"""
from __future__ import annotations

import json
import logging
import re
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Dict, Iterable, List, Optional

logger = logging.getLogger(__name__)

# Small, cheap, works on CPU:
MODEL_NAME = "google/flan-t5-small"

# Populated on first use by _lazy_pipe().
_tokenizer = None
_pipe = None


def _lazy_pipe():
    """Build the text2text-generation pipeline for MODEL_NAME once and return it.

    ``transformers`` is imported here rather than at module import time so the
    regex-only path (``try_regex_first``) can be imported and used without it.
    """
    global _tokenizer, _pipe
    if _pipe is None:
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
        _pipe = pipeline("text2text-generation", model=model, tokenizer=_tokenizer)
    return _pipe


# Lower-case day spellings -> canonical three-letter day name.
DAYS_ALIASES = {
    "sun": "Sun", "sunday": "Sun", "sundays": "Sun",
    "mon": "Mon", "monday": "Mon", "mondays": "Mon",
    "tue": "Tue", "tues": "Tue", "tuesday": "Tue", "tuesdays": "Tue",
    "wed": "Wed", "weds": "Wed", "wednesday": "Wed", "wednesdays": "Wed",
    "thu": "Thu", "thur": "Thu", "thurs": "Thu", "thursday": "Thu", "thursdays": "Thu",
    "fri": "Fri", "friday": "Fri", "fridays": "Fri",
    "sat": "Sat", "saturday": "Sat", "saturdays": "Sat",
}


@dataclass
class ParsedConstraints:
    """Structured scheduling constraints extracted from a user request."""

    subject_counts: Dict[str, int]  # e.g., {"Computer Science": 4}
    banned_days: List[str]  # ["Sun","Fri"]
    no_before: Optional[str] = None  # "10:00 AM"
    no_after: Optional[str] = None  # "6:00 PM"
    keywords: List[str] = field(default_factory=list)  # ["software engineering","algorithms"]
    banned_professors: List[str] = field(default_factory=list)  # optional future use

    def to_json(self) -> str:
        """Serialize every field into a JSON object string."""
        return json.dumps(asdict(self))


SCHEMA_INSTRUCTIONS = """You are a parser. Convert the user request into STRICT JSON with this schema:
{
  "subject_counts": {"<subject name>": <number of classes>, ...},  // optional
  "banned_days": ["Sun","Mon","Tue","Wed","Thu","Fri","Sat"],  // optional
  "no_before": "H:MM AM/PM",  // optional
  "no_after": "H:MM AM/PM",  // optional
  "keywords": ["..."],  // optional: thematic words/phrases to prefer
  "banned_professors": ["..."]  // optional
}
Only output valid JSON. No commentary.
"""

# Whole-word day tokens (longest spellings first) so "mon" never matches "month".
_DAY_RE = re.compile(
    r"\b("
    + "|".join(sorted(map(re.escape, DAYS_ALIASES), key=len, reverse=True))
    + r")\b"
)
# The text following a "no classes on ..." phrase, up to the end of its clause,
# so days mentioned elsewhere ("4 cs classes on Monday") are not banned.
_NO_DAYS_CLAUSE_RE = re.compile(
    r"(?:no classes on|don['’]?t give me classes on)(.*?)(?=[.;!?\n]|\bbut\b|$)"
)
_NO_BEFORE_RE = re.compile(r"no classes before ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)")
_NO_AFTER_RE = re.compile(r"no classes after ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)")
# A bare clock time as a model might emit it: "6 pm", "10:00 AM", "6:30 p.m.".
_CLOCK_RE = re.compile(r"^\s*(\d{1,2})(?::(\d{2}))?\s*([ap])\.?m\.?\s*$", re.IGNORECASE)
# "<count> <subject phrase> class(es)", e.g. "4 cs classes" / "three computer science classes".
_SUBJECT_COUNT_RE = re.compile(
    r"(\b\d+\b|\bone\b|\btwo\b|\bthree\b|\bfour\b|\bfive\b|\bsix\b)"
    r"\s+([a-z ]+?)\s+class(?:es)?\b"
)
_COUNT_WORDS = {"one": 1, "two": 2, "three": 3, "four": 4, "five": 5, "six": 6}
# Cheap subject aliases (lower case) -> canonical subject name.
_SUBJECT_ALIASES = {
    "cs": "Computer Science",
    "computer science": "Computer Science",
    "math": "Mathematics",
    "econ": "Economics",
    "psych": "Psychology",
    "bio": "Biology",
    "chem": "Chemistry",
    "phys": "Physics",
    "art history": "Art History",
    "philosophy": "Philosophy",
    "finance": "Finance",
}
# Aliases must start at a word boundary: a bare substring test made "cs" match
# inside "economics", "physics" and "mathematics".
_SUBJECT_ALIAS_PATTERNS = [
    (re.compile(r"\b" + re.escape(alias)), subject)
    for alias, subject in _SUBJECT_ALIASES.items()
]


def _normalize_days(days: Iterable[Any]) -> List[str]:
    """Map day spellings to canonical names, dropping unknowns and duplicates.

    A bare string is treated as a single day. Non-string items are ignored.
    First-seen order is preserved so the result is deterministic.
    """
    if isinstance(days, str):
        days = [days]
    out: List[str] = []
    for d in days:
        if not isinstance(d, str):
            continue
        std = DAYS_ALIASES.get(d.strip().lower())
        if std is not None and std not in out:
            out.append(std)
    return out


def _format_time(hour: str, minute: Optional[str], meridiem: str) -> Optional[str]:
    """Render a 12-hour time as ``H:MM AM/PM``; return None if it is out of range."""
    h = int(hour)
    m = int(minute) if minute else 0
    if not (1 <= h <= 12 and 0 <= m <= 59):
        return None
    return f"{h}:{m:02d} {meridiem.upper()}"


def _coerce_time(value: Any) -> Optional[str]:
    """Normalize a model-supplied clock string to ``H:MM AM/PM`` or return None.

    Rejects non-strings and anything that is not a concrete 12-hour time,
    including the schema placeholder "H:MM AM/PM" echoed back by the model.
    """
    if not isinstance(value, str):
        return None
    match = _CLOCK_RE.match(value)
    if match is None:
        return None
    return _format_time(match.group(1), match.group(2), match.group(3) + "m")


def _resolve_subject(phrase: str) -> Optional[str]:
    """Return the canonical subject for the first alias found in ``phrase``."""
    for pattern, subject in _SUBJECT_ALIAS_PATTERNS:
        if pattern.search(phrase):
            return subject
    return None


def try_regex_first(text: str) -> ParsedConstraints:
    """Cheap guardrail: catch obvious patterns before LLM.

    Recognizes, case-insensitively:
      * days named in a "no classes on ..." / "don't give me classes on ..."
        clause (up to the next . ; ! ? newline or "but");
      * "no classes before/after H[:MM] am|pm";
      * "<count> <subject> class(es)" with counts as digits or one..six.
    ``keywords`` and ``banned_professors`` are always empty here.
    """
    t = text.lower()

    # days: only those inside a "no classes on" clause are banned
    day_tokens: List[str] = []
    for clause in _NO_DAYS_CLAUSE_RE.finditer(t):
        day_tokens.extend(_DAY_RE.findall(clause.group(1)))
    banned_days = _normalize_days(day_tokens)

    # times
    before = _NO_BEFORE_RE.search(t)
    after = _NO_AFTER_RE.search(t)
    no_before = _format_time(*before.groups()) if before else None
    no_after = _format_time(*after.groups()) if after else None

    # subject counts, summed per canonical subject
    subject_counts: Dict[str, int] = {}
    for num, phrase in _SUBJECT_COUNT_RE.findall(t):
        cnt = int(num) if num.isdigit() else _COUNT_WORDS.get(num, 0)
        if cnt <= 0:
            continue
        subject = _resolve_subject(phrase.strip())
        if subject:
            subject_counts[subject] = subject_counts.get(subject, 0) + cnt

    return ParsedConstraints(
        subject_counts=subject_counts,
        banned_days=banned_days,
        no_before=no_before,
        no_after=no_after,
        keywords=[],
        banned_professors=[],
    )


def _generate_with_pipeline(prompt: str) -> str:
    """Run the default local model with greedy decoding and return its raw text."""
    # do_sample=False is how transformers expresses greedy decoding; the old
    # temperature=0.0 is ignored (with a warning) when sampling is off.
    result = _lazy_pipe()(prompt, max_new_tokens=256, do_sample=False)
    return result[0]["generated_text"]


def _extract_json_object(raw: Any) -> Optional[Dict[str, Any]]:
    """Decode ``raw`` as a JSON object, retrying on its outermost ``{...}`` span.

    Returns the decoded dict, or None when no JSON object can be recovered.
    """
    if not isinstance(raw, str):
        return None
    candidates = [raw]
    start, end = raw.find("{"), raw.rfind("}")
    if start != -1 and end > start:
        candidates.append(raw[start:end + 1])
    for candidate in candidates:
        try:
            obj = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError
            continue
        if isinstance(obj, dict):
            return obj
    return None


def _as_str_list(value: Any) -> List[str]:
    """Coerce a JSON value to a list of non-empty stripped strings (a bare string counts as one)."""
    if isinstance(value, str):
        value = [value]
    if not isinstance(value, list):
        return []
    return [v.strip() for v in value if isinstance(v, str) and v.strip()]


def _merge_llm_result(base: ParsedConstraints, obj: Dict[str, Any]) -> ParsedConstraints:
    """Combine the regex result with the model's JSON object.

    Regex findings win wherever both sources speak, because they are exact
    matches of the user's wording. The model only fills in what the regex pass
    did not find. Malformed model fields are ignored rather than raising.
    """
    subject_counts = dict(base.subject_counts)
    llm_counts = obj.get("subject_counts")
    if isinstance(llm_counts, dict):
        for subject, count in llm_counts.items():
            # bool is an int subclass: JSON `true` must not count as one class.
            if not isinstance(count, int) or isinstance(count, bool) or count <= 0:
                continue
            if not isinstance(subject, str) or not subject.strip():
                continue
            name = subject.strip()
            name = _SUBJECT_ALIASES.get(name.lower(), name)  # "CS" -> "Computer Science"
            # setdefault, not +=: the same mention seen by both passes must not
            # be counted twice.
            subject_counts.setdefault(name, count)

    banned_days = _normalize_days(list(base.banned_days) + _as_str_list(obj.get("banned_days")))
    no_before = base.no_before or _coerce_time(obj.get("no_before"))
    no_after = base.no_after or _coerce_time(obj.get("no_after"))
    keywords = list(dict.fromkeys(_as_str_list(obj.get("keywords")))) or list(base.keywords)
    banned_professors = (
        list(dict.fromkeys(_as_str_list(obj.get("banned_professors"))))
        or list(base.banned_professors)
    )

    return ParsedConstraints(
        subject_counts=subject_counts,
        banned_days=banned_days,
        no_before=no_before,
        no_after=no_after,
        keywords=keywords,
        banned_professors=banned_professors,
    )


def parse_constraints(
    text: str,
    *,
    generate: Optional[Callable[[str], str]] = None,
) -> ParsedConstraints:
    """LLM parse with regex fallback & JSON repair.

    Args:
        text: the user's free-text request. Empty or whitespace-only text
            yields empty constraints without calling the model.
        generate: optional callable mapping the prompt to raw model text. It
            defaults to the local MODEL_NAME pipeline, and lets callers plug
            in another backend.

    Returns:
        The regex pass merged with whatever JSON object the model produced.
        If generation fails or no JSON object is recoverable, the regex-only
        result is returned.
    """
    if not text or not text.strip():
        return ParsedConstraints(subject_counts={}, banned_days=[], keywords=[], banned_professors=[])

    # Start with regex quick pass
    base = try_regex_first(text)

    prompt = SCHEMA_INSTRUCTIONS + "\nUser request: " + text.strip()
    try:
        raw = (generate or _generate_with_pipeline)(prompt)
    except Exception:  # model missing / download or inference failure
        logger.warning("Constraint LLM unavailable; using regex-only result", exc_info=True)
        return base

    obj = _extract_json_object(raw)
    if obj is None:
        # fall back to regex-only result
        return base
    return _merge_llm_result(base, obj)