File size: 6,886 Bytes
4d1e784 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
# constraint_parser_llm.py
from __future__ import annotations
import json, re, os
from dataclasses import dataclass, asdict
from typing import Dict, List, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Choose one (Llama3.1 is slightly better at structured JSON)
# Model id is overridable via the PARSER_MODEL environment variable.
MODEL_NAME = os.environ.get("PARSER_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
# Alternative:
# MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
# Module-level singletons, populated lazily by _lazy_pipe() on first use.
_tokenizer = None
_pipe = None
def _lazy_pipe():
    """Build (once) and return the shared text-generation pipeline.

    The tokenizer and model are loaded lazily on first call so importing this
    module stays cheap; later calls reuse the cached pipeline.

    Returns:
        The cached transformers text-generation pipeline.

    NOTE(review): not thread-safe — two concurrent first calls could load the
    model twice; acceptable for single-threaded use, guard with a lock otherwise.
    """
    global _tokenizer, _pipe
    if _pipe is None:
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto",  # lets accelerate place weights on CPU/GPU
            trust_remote_code=True,
        )
        _pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=_tokenizer,
            max_new_tokens=256,
            # Greedy decoding for deterministic JSON. Passing temperature=0.0
            # alongside do_sample=False is contradictory (recent transformers
            # versions ignore it and warn), so only do_sample=False is set.
            do_sample=False,
        )
    return _pipe
# Canonical three-letter day codes, Sunday-first.
DAY_ABBR = ["Sun","Mon","Tue","Wed","Thu","Fri","Sat"]
# Lower-cased spellings ("sunday", "sundays", "sun", ...) -> canonical code.
DAY_ALIASES = {
    alias: abbr
    for full, abbr in zip(
        ["sunday", "monday", "tuesday", "wednesday", "thursday", "friday", "saturday"],
        DAY_ABBR,
    )
    for alias in (full, full + "s", abbr.lower())
}
@dataclass
class ParsedConstraints:
    """Normalized scheduling constraints extracted from free-form text.

    Attributes:
        subject_counts: subject name -> requested number of classes.
        banned_days: canonical day abbreviations ("Sun".."Sat") to avoid.
        no_before: earliest acceptable start time, e.g. "10:00 AM".
        no_after: latest acceptable end time, e.g. "6:00 PM".
        keywords: free-form topic keywords to prefer (always a list).
        banned_professors: professor names to avoid (always a list).
    """
    subject_counts: Dict[str, int]
    banned_days: List[str]
    no_before: Optional[str] = None   # "10:00 AM"
    no_after: Optional[str] = None    # "6:00 PM"
    # The list fields default to None (a mutable default like [] would be
    # shared across instances) and are coerced to fresh lists in
    # __post_init__, so the annotations are honest and callers may pass None.
    keywords: Optional[List[str]] = None
    banned_professors: Optional[List[str]] = None

    def __post_init__(self) -> None:
        # Guarantee list-typed attributes are real lists, never None.
        if self.keywords is None:
            self.keywords = []
        if self.banned_professors is None:
            self.banned_professors = []

    def to_json(self) -> str:
        """Serialize to a compact JSON string."""
        return json.dumps(asdict(self))
# Instruction preamble sent to the LLM ahead of the few-shot examples. It pins
# the exact JSON schema plus domain conventions (day abbreviations, vague-time
# and vague-count mappings) so the model's output can be parsed mechanically.
# NOTE: this is a runtime string consumed by the model — edit with care.
SYSTEM_PROMPT = """You convert natural language scheduling requests into STRICT JSON.
Output only JSON with this schema:
{
"subject_counts": {"<SubjectName>": <int>}, // optional
"banned_days": ["Sun","Mon","Tue","Wed","Thu","Fri","Sat"], // optional
"no_before": "H:MM AM/PM", // optional
"no_after": "H:MM AM/PM", // optional
"keywords": ["..."], // optional
"banned_professors": ["..."] // optional
}
Rules:
- Never include commentary.
- Use only "Sun","Mon","Tue","Wed","Thu","Fri","Sat".
- If the user says 'no weekends', add 'Sat' and 'Sun'.
- If they say 'mornings'/'afternoons'/'evenings', map to times:
* mornings => no_before >= "10:00 AM" (i.e., prefer starts at/after 10am)
* afternoons => no_before >= "12:00 PM"
* evenings => no_before >= "4:00 PM"
* 'no mornings' => no_after <= "12:00 PM"
- If counts are implied (e.g., "a couple of CS classes"), interpret:
couple=2, few=3, several=3.
- If ambiguous, make reasonable assumptions but keep to schema.
"""
# Few-shot (user request, expected JSON) pairs appended to SYSTEM_PROMPT by
# _build_prompt(). They demonstrate day banning, subject counting, time
# windows, keyword extraction, and the weekend/evening shorthands.
FEW_SHOTS = [
    # (user, json)
    ("make me 4 CS classes, no Sundays",
     {"subject_counts":{"Computer Science":4},"banned_days":["Sun"]}),
    ("i want two econ and one psychology after 1pm, avoid friday",
     {"subject_counts":{"Economics":2,"Psychology":1},"no_before":"1:00 PM","banned_days":["Fri"]}),
    ("prefer software engineering and algorithms, no class on mon or tue mornings",
     {"keywords":["software engineering","algorithms"],"banned_days":["Mon","Tue"],"no_after":"12:00 PM"}),
    ("no weekends, evenings only",
     {"banned_days":["Sat","Sun"],"no_before":"4:00 PM"}),
    ("skip wed; a couple math; no classes before 10am",
     {"subject_counts":{"Mathematics":2},"banned_days":["Wed"],"no_before":"10:00 AM"}),
]
def _build_prompt(user_text: str) -> str:
    """Assemble the full few-shot prompt for one user request.

    Layout: SYSTEM_PROMPT, an "Examples:" header, each few-shot pair as
    "User: .../JSON: ...", then the real request with a trailing "JSON:" cue.
    """
    sections = [SYSTEM_PROMPT, "\nExamples:"]
    sections.extend(
        f"User: {shot}\nJSON: {json.dumps(expected)}" for shot, expected in FEW_SHOTS
    )
    sections.append(f"\nUser: {user_text}\nJSON:")
    return "\n".join(sections)
def _extract_json(text: str) -> Optional[dict]:
text = text.strip()
# Some models wrap with text; find first/last braces
start = text.find("{")
end = text.rfind("}")
if start != -1 and end != -1 and end > start:
try:
return json.loads(text[start:end+1])
except Exception:
pass
# Sometimes models produce ```json blocks
m = re.search(r"```json\s*(\{.*?\})\s*```", text, re.S)
if m:
try:
return json.loads(m.group(1))
except Exception:
pass
return None
def _normalize_days(days: List[str]) -> List[str]:
    """Map day names/aliases to canonical abbreviations, deduped, order kept.

    Entries already in DAY_ABBR pass through; anything else is looked up
    (case-insensitively) in DAY_ALIASES. Unknown entries are dropped.
    """
    normalized: List[str] = []
    for raw in days or []:
        token = raw.strip()
        # Keep canonical codes as-is; otherwise consult the alias table.
        canon = token if token in DAY_ABBR else DAY_ALIASES.get(token.lower())
        if canon and canon not in normalized:
            normalized.append(canon)
    return normalized
def _fallback_light(text: str) -> ParsedConstraints:
    """Cheap regex-only parser used when the LLM output is unusable.

    Recovers only banned days and simple "no class(es) before/after <time>"
    windows; subject counts and keywords are beyond what regexes can do
    reliably, so those fields come back empty.
    """
    t = (text or "").lower()
    # Banned days: "no <day>", "no classes on <day>", "avoid <day>", etc.
    found = set()
    for alias, abbr in DAY_ALIASES.items():
        day = re.escape(alias)
        if (re.search(rf"\bno( classes?)? (on )?{day}\b", t)
                or re.search(rf"\bavoid (on )?{day}\b", t)
                or re.search(rf"\bno {day}\b", t)):
            found.add(abbr)
    # Time windows: accept both singular and plural ("no class before 10am",
    # "no classes before 10:30 am") — the original required the plural form.
    before = re.search(r"no class(?:es)? before ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)", t)
    after = re.search(r"no class(?:es)? after ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)", t)
    def fmt(h, m, ap):
        # Normalize to "H:MM AM/PM"; missing minutes default to "00".
        return f"{int(h)}:{m or '00'} {ap.upper()}"
    return ParsedConstraints(
        subject_counts={},
        banned_days=sorted(found),
        no_before=fmt(*before.groups()) if before else None,
        no_after=fmt(*after.groups()) if after else None,
        keywords=[],
        banned_professors=[],
    )
def parse_constraints(text: str) -> ParsedConstraints:
    """Parse a natural-language scheduling request into ParsedConstraints.

    Runs the LLM with a few-shot prompt, extracts and sanitizes its JSON
    output, and falls back to the light regex parser when no usable JSON
    comes back. Empty input yields an empty constraint set without touching
    the model.
    """
    if not text or not text.strip():
        return ParsedConstraints(subject_counts={}, banned_days=[], keywords=[], banned_professors=[])
    prompt = _build_prompt(text.strip())
    out = _lazy_pipe()(prompt)[0]["generated_text"]
    # BUG FIX: HF text-generation pipelines echo the prompt in generated_text
    # by default, so the first '{' _extract_json finds belongs to the schema
    # in SYSTEM_PROMPT and parsing always degraded to the fallback. Strip the
    # prompt prefix (conditionally, in case return_full_text is ever disabled).
    if out.startswith(prompt):
        out = out[len(prompt):]
    obj = _extract_json(out)
    if not isinstance(obj, dict):
        # Model produced no parseable JSON: degrade gracefully to regexes.
        return _fallback_light(text)
    # Sanitize / coerce every field; never trust model output blindly.
    subject_counts: Dict[str, int] = {}
    for subject, count in (obj.get("subject_counts") or {}).items():
        try:
            n = int(count)
        except Exception:
            continue  # non-numeric count: drop the entry
        if n > 0:
            subject_counts[subject] = n
    def _list_str(x):
        # Keep only non-empty strings; drop any other junk the model emitted.
        return [s for s in (x or []) if isinstance(s, str) and s.strip()]
    return ParsedConstraints(
        subject_counts=subject_counts,
        banned_days=_normalize_days(obj.get("banned_days") or []),
        no_before=obj.get("no_before") or None,
        no_after=obj.get("no_after") or None,
        keywords=_list_str(obj.get("keywords")),
        banned_professors=_list_str(obj.get("banned_professors")),
    )
|