# Class_Schedule_Generator_AI / constraint_parser_llm.py
# (Hugging Face Space file header: uploaded by cgreszes, commit 4d1e784, verified)
# constraint_parser_llm.py
from __future__ import annotations

import json
import os
import re
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Choose one (Llama3.1 is slightly better at structured JSON)
# Overridable via the PARSER_MODEL env var; the default is gated on the HF Hub.
MODEL_NAME = os.environ.get("PARSER_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
# Alternative:
# MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
# Module-level singletons, populated lazily by _lazy_pipe() on first use.
_tokenizer = None
_pipe = None
def _lazy_pipe():
    """Build (once) and return the shared text-generation pipeline.

    Loads the tokenizer and model for ``MODEL_NAME`` on first call and
    caches the pipeline in the module-level ``_pipe`` singleton; later
    calls return the cached object.

    Returns:
        A transformers ``text-generation`` pipeline configured for greedy,
        completion-only decoding.
    """
    global _tokenizer, _pipe
    if _pipe is None:
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto",  # works CPU/GPU
            trust_remote_code=True,
        )
        _pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=_tokenizer,
            max_new_tokens=256,
            # Greedy decoding. NOTE: `temperature=0.0` was removed — it is
            # ignored (and warned about) when do_sample=False.
            do_sample=False,
            # Return ONLY the completion; otherwise the echoed prompt's
            # few-shot JSON examples confuse downstream JSON extraction.
            return_full_text=False,
        )
    return _pipe
# Canonical day abbreviations, indexed Sunday-first.
DAY_ABBR = ["Sun","Mon","Tue","Wed","Thu","Fri","Sat"]
# Lowercased spellings (full name, plural, 3-letter abbr) -> canonical abbreviation.
DAY_ALIASES = {
    "sunday":"Sun","sundays":"Sun","sun":"Sun",
    "monday":"Mon","mondays":"Mon","mon":"Mon",
    "tuesday":"Tue","tuesdays":"Tue","tue":"Tue",
    "wednesday":"Wed","wednesdays":"Wed","wed":"Wed",
    "thursday":"Thu","thursdays":"Thu","thu":"Thu",
    "friday":"Fri","fridays":"Fri","fri":"Fri",
    "saturday":"Sat","saturdays":"Sat","sat":"Sat",
}
@dataclass
class ParsedConstraints:
    """Normalized scheduling constraints extracted from a natural-language request."""

    # e.g. {"Computer Science": 4} -- requested number of sections per subject
    subject_counts: Dict[str, int]
    # canonical day abbreviations ("Sun".."Sat") on which classes are banned
    banned_days: List[str]
    no_before: Optional[str] = None   # earliest allowed start, e.g. "10:00 AM"
    no_after: Optional[str] = None    # latest allowed end, e.g. "6:00 PM"
    # FIX: the originals were `List[str] = None` (mistyped None defaults);
    # default_factory gives each instance its own fresh list instead.
    keywords: List[str] = field(default_factory=list)
    banned_professors: List[str] = field(default_factory=list)

    def to_json(self) -> str:
        """Serialize this object to a JSON string (field names as keys)."""
        return json.dumps(asdict(self))
SYSTEM_PROMPT = """You convert natural language scheduling requests into STRICT JSON.
Output only JSON with this schema:
{
"subject_counts": {"<SubjectName>": <int>}, // optional
"banned_days": ["Sun","Mon","Tue","Wed","Thu","Fri","Sat"], // optional
"no_before": "H:MM AM/PM", // optional
"no_after": "H:MM AM/PM", // optional
"keywords": ["..."], // optional
"banned_professors": ["..."] // optional
}
Rules:
- Never include commentary.
- Use only "Sun","Mon","Tue","Wed","Thu","Fri","Sat".
- If the user says 'no weekends', add 'Sat' and 'Sun'.
- If they say 'mornings'/'afternoons'/'evenings', map to times:
* mornings => no_before >= "10:00 AM" (i.e., prefer starts at/after 10am)
* afternoons => no_before >= "12:00 PM"
* evenings => no_before >= "4:00 PM"
* 'no mornings' => no_after <= "12:00 PM"
- If counts are implied (e.g., "a couple of CS classes"), interpret:
couple=2, few=3, several=3.
- If ambiguous, make reasonable assumptions but keep to schema.
"""
# Few-shot (user request, expected JSON dict) pairs appended to the prompt by
# _build_prompt(); each dict is serialized with json.dumps at prompt-build time.
FEW_SHOTS = [
    # (user, json)
    ("make me 4 CS classes, no Sundays",
     {"subject_counts":{"Computer Science":4},"banned_days":["Sun"]}),
    ("i want two econ and one psychology after 1pm, avoid friday",
     {"subject_counts":{"Economics":2,"Psychology":1},"no_before":"1:00 PM","banned_days":["Fri"]}),
    ("prefer software engineering and algorithms, no class on mon or tue mornings",
     {"keywords":["software engineering","algorithms"],"banned_days":["Mon","Tue"],"no_after":"12:00 PM"}),
    ("no weekends, evenings only",
     {"banned_days":["Sat","Sun"],"no_before":"4:00 PM"}),
    ("skip wed; a couple math; no classes before 10am",
     {"subject_counts":{"Mathematics":2},"banned_days":["Wed"],"no_before":"10:00 AM"}),
]
def _build_prompt(user_text: str) -> str:
    """Assemble the few-shot prompt: system rules, worked examples, then the new request."""
    shots = "\n".join(
        f"User: {example}\nJSON: {json.dumps(expected)}"
        for example, expected in FEW_SHOTS
    )
    return f"{SYSTEM_PROMPT}\n\nExamples:\n{shots}\n\nUser: {user_text}\nJSON:"
def _extract_json(text: str) -> Optional[dict]:
text = text.strip()
# Some models wrap with text; find first/last braces
start = text.find("{")
end = text.rfind("}")
if start != -1 and end != -1 and end > start:
try:
return json.loads(text[start:end+1])
except Exception:
pass
# Sometimes models produce ```json blocks
m = re.search(r"```json\s*(\{.*?\})\s*```", text, re.S)
if m:
try:
return json.loads(m.group(1))
except Exception:
pass
return None
def _normalize_days(days: List[str]) -> List[str]:
    """Map arbitrary day spellings to canonical abbreviations, deduped, order-preserving."""
    normalized: List[str] = []
    for raw in days or []:
        token = raw.strip()
        if token in DAY_ABBR:
            canon = token  # already canonical ("Mon", "Tue", ...)
        else:
            # alias lookup: "mondays" / "monday" / "mon" -> "Mon"
            canon = DAY_ALIASES.get(token.lower())
        if canon and canon not in normalized:
            normalized.append(canon)
    return normalized
def _fallback_light(text: str) -> ParsedConstraints:
    """Regex-only constraint extraction used when the LLM output has no parsable JSON.

    Catches banned days ("no fridays", "avoid mon", "no weekends") and simple
    time bounds ("no classes before 10am"). Subjects/keywords/professors are
    left empty -- this is a deliberately crude safety net.
    """
    t = (text or "").lower()
    found = set()
    # crude day catch-all: "no <day>", "no classes on <day>", "avoid <day>"
    for alias, abbr in DAY_ALIASES.items():
        pat = re.escape(alias)
        if (re.search(rf"\bno( classes?)? (on )?{pat}\b", t)
                or re.search(rf"\bavoid (on )?{pat}\b", t)
                or re.search(rf"\bno {pat}\b", t)):
            found.add(abbr)
    # FIX: "no weekends" is handled by the LLM prompt rules but was missed
    # here; expand it to Sat+Sun so the fallback agrees with the prompt.
    if re.search(r"\b(no|avoid) weekends?\b", t):
        found.update({"Sat", "Sun"})
    # time bounds; FIX: accept singular "no class before ..." too, matching
    # the optional plural the day patterns already allow.
    before = re.search(r"no class(?:es)? before ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)", t)
    after = re.search(r"no class(?:es)? after ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)", t)
    def fmt(h, m, ap):
        # normalize "10" / "10:30" + am/pm into the "10:00 AM" style
        return f"{int(h)}:{m or '00'} {ap.upper()}"
    return ParsedConstraints(
        subject_counts={},
        banned_days=sorted(found),
        no_before=fmt(*before.groups()) if before else None,
        no_after=fmt(*after.groups()) if after else None,
        keywords=[],
        banned_professors=[],
    )
def parse_constraints(text: str) -> ParsedConstraints:
    """Parse a natural-language scheduling request into ParsedConstraints.

    Runs the LLM with a few-shot prompt, extracts the JSON object from the
    completion, and sanitizes every field (the model may emit wrong types).
    Falls back to the lightweight regex parser when no JSON can be parsed.
    """
    if not text or not text.strip():
        # empty request: nothing to constrain
        return ParsedConstraints(subject_counts={}, banned_days=[], keywords=[], banned_professors=[])
    prompt = _build_prompt(text.strip())
    raw = _lazy_pipe()(prompt)[0]["generated_text"]
    # FIX: HF text-generation pipelines echo the prompt unless configured with
    # return_full_text=False; strip any echo so _extract_json cannot match
    # braces belonging to the schema/few-shot examples instead of the answer.
    out = raw[len(prompt):] if raw.startswith(prompt) else raw
    obj = _extract_json(out)
    if not isinstance(obj, dict):
        # graceful fallback: regex-only parsing of the original user text
        return _fallback_light(text)
    # sanitize / coerce
    subject_counts: Dict[str, int] = {}
    for subject, count in (obj.get("subject_counts") or {}).items():
        try:
            n = int(count)
        except (TypeError, ValueError):
            continue  # skip non-numeric counts the model invented
        if n > 0:
            subject_counts[subject] = n
    banned_days = _normalize_days(obj.get("banned_days") or [])
    def _opt_time(key: str) -> Optional[str]:
        # accept only non-empty strings for the time bounds
        v = obj.get(key)
        return v if isinstance(v, str) and v.strip() else None
    def _str_list(x) -> List[str]:
        return [s for s in (x or []) if isinstance(s, str) and s.strip()]
    return ParsedConstraints(
        subject_counts=subject_counts,
        banned_days=banned_days,
        no_before=_opt_time("no_before"),
        no_after=_opt_time("no_after"),
        keywords=_str_list(obj.get("keywords")),
        banned_professors=_str_list(obj.get("banned_professors")),
    )