Create constraint_parser_llm.py
constraint_parser_llm.py (ADDED, +192 -0)

# constraint_parser_llm.py
from __future__ import annotations
import json, re, os
from dataclasses import dataclass, asdict, field
from typing import Dict, List, Optional

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Choose one (Llama 3.1 is slightly better at structured JSON)
MODEL_NAME = os.environ.get("PARSER_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
# Alternative:
# MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"

_tokenizer = None
_pipe = None

def _lazy_pipe():
    global _tokenizer, _pipe
    if _pipe is None:
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto",  # works on CPU or GPU
            torch_dtype="auto",
        )
        _pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=_tokenizer,
            max_new_tokens=256,
            do_sample=False,         # greedy decoding, so no temperature needed
            return_full_text=False,  # return only the completion, not the echoed prompt
        )
    return _pipe
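
# Note: the 8B default needs roughly 16 GB of accelerator memory at fp16 (or a
# long CPU wait). For quick local smoke tests you can point PARSER_MODEL at a
# smaller instruct model; for example (untested suggestion, any small instruct
# model should work):
#   PARSER_MODEL="Qwen/Qwen2.5-0.5B-Instruct" python constraint_parser_llm.py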

DAY_ABBR = ["Sun","Mon","Tue","Wed","Thu","Fri","Sat"]
DAY_ALIASES = {
    "sunday":"Sun","sundays":"Sun","sun":"Sun",
    "monday":"Mon","mondays":"Mon","mon":"Mon",
    "tuesday":"Tue","tuesdays":"Tue","tue":"Tue",
    "wednesday":"Wed","wednesdays":"Wed","wed":"Wed",
    "thursday":"Thu","thursdays":"Thu","thu":"Thu",
    "friday":"Fri","fridays":"Fri","fri":"Fri",
    "saturday":"Sat","saturdays":"Sat","sat":"Sat",
}

@dataclass
class ParsedConstraints:
    subject_counts: Dict[str, int]
    banned_days: List[str]
    no_before: Optional[str] = None  # e.g. "10:00 AM"
    no_after: Optional[str] = None   # e.g. "6:00 PM"
    keywords: List[str] = field(default_factory=list)
    banned_professors: List[str] = field(default_factory=list)

    def to_json(self) -> str:
        return json.dumps(asdict(self))
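
# Illustrative round-trip (output wrapped here for readability):
#   ParsedConstraints(subject_counts={"Mathematics": 2}, banned_days=["Sun"]).to_json()
#   -> '{"subject_counts": {"Mathematics": 2}, "banned_days": ["Sun"],
#       "no_before": null, "no_after": null, "keywords": [], "banned_professors": []}'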

SYSTEM_PROMPT = """You convert natural language scheduling requests into STRICT JSON.
Output only JSON with this schema:

{
  "subject_counts": {"<SubjectName>": <int>},                  // optional
  "banned_days": ["Sun","Mon","Tue","Wed","Thu","Fri","Sat"],  // optional
  "no_before": "H:MM AM/PM",                                   // optional
  "no_after": "H:MM AM/PM",                                    // optional
  "keywords": ["..."],                                         // optional
  "banned_professors": ["..."]                                 // optional
}

Rules:
- Never include commentary.
- Use only "Sun","Mon","Tue","Wed","Thu","Fri","Sat".
- If the user says 'no weekends', add 'Sat' and 'Sun'.
- If they say 'mornings'/'afternoons'/'evenings', map to times:
  * mornings => no_before = "10:00 AM" (prefer starts at/after 10am)
  * afternoons => no_before = "12:00 PM"
  * evenings => no_before = "4:00 PM"
  * 'no mornings' => no_before = "12:00 PM"
- If counts are implied (e.g., "a couple of CS classes"), interpret:
  couple=2, few=3, several=3.
- If ambiguous, make reasonable assumptions but keep to the schema.
"""

FEW_SHOTS = [
    # (user, json)
    ("make me 4 CS classes, no Sundays",
     {"subject_counts":{"Computer Science":4},"banned_days":["Sun"]}),
    ("i want two econ and one psychology after 1pm, avoid friday",
     {"subject_counts":{"Economics":2,"Psychology":1},"no_before":"1:00 PM","banned_days":["Fri"]}),
    ("prefer software engineering and algorithms, no class on mon or tue mornings",
     {"keywords":["software engineering","algorithms"],"banned_days":["Mon","Tue"],"no_before":"12:00 PM"}),
    ("no weekends, evenings only",
     {"banned_days":["Sat","Sun"],"no_before":"4:00 PM"}),
    ("skip wed; a couple math; no classes before 10am",
     {"subject_counts":{"Mathematics":2},"banned_days":["Wed"],"no_before":"10:00 AM"}),
]

def _build_prompt(user_text: str) -> str:
    parts = [SYSTEM_PROMPT, "\nExamples:"]
    for u, js in FEW_SHOTS:
        parts.append(f"User: {u}\nJSON: {json.dumps(js)}")
    parts.append(f"\nUser: {user_text}\nJSON:")
    return "\n".join(parts)

def _extract_json(text: str) -> Optional[dict]:
    text = text.strip()
    # Some models wrap the object in ```json fences; check those first
    m = re.search(r"```json\s*(\{.*?\})\s*```", text, re.S)
    if m:
        try:
            return json.loads(m.group(1))
        except Exception:
            pass
    # Otherwise decode the first JSON object and ignore trailing text
    # (greedy models often continue with "User: ..." after the JSON)
    start = text.find("{")
    if start != -1:
        try:
            obj, _ = json.JSONDecoder().raw_decode(text[start:])
            if isinstance(obj, dict):
                return obj
        except Exception:
            pass
    return None
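
# Illustrative inputs this is meant to tolerate:
#   _extract_json('{"banned_days": ["Sun"]}\nUser: next question...')
#       -> {"banned_days": ["Sun"]}   (trailing chatter ignored via raw_decode)
#   _extract_json('```json\n{"no_before": "10:00 AM"}\n```')
#       -> {"no_before": "10:00 AM"}  (fenced output handled first)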

def _normalize_days(days: List[str]) -> List[str]:
    out = []
    for d in days or []:
        k = d.strip()
        # accept already-correct abbreviations
        if k in DAY_ABBR and k not in out:
            out.append(k)
            continue
        # map full names / plurals to abbreviations
        v = DAY_ALIASES.get(k.lower())
        if v and v not in out:
            out.append(v)
    return out

def _fallback_light(text: str) -> ParsedConstraints:
    """Regex-only fallback used when the model output is not valid JSON."""
    t = (text or "").lower()
    # crude day detection: "no fridays", "avoid friday", "no classes on friday", ...
    found = set()
    for k, v in DAY_ALIASES.items():
        if (re.search(rf"\bno( classes?)? (on )?{re.escape(k)}\b", t)
                or re.search(rf"\bavoid (on )?{re.escape(k)}\b", t)
                or re.search(rf"\bno {re.escape(k)}\b", t)):
            found.add(v)
    # time windows: "no classes before 10am" / "no classes after 6:30 pm"
    before = re.search(r"no classes before ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)", t)
    after = re.search(r"no classes after ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)", t)
    def fmt(h, m, ap):
        mm = m if m else "00"
        return f"{int(h)}:{mm} {ap.upper()}"
    no_before = fmt(*before.groups()) if before else None
    no_after = fmt(*after.groups()) if after else None
    return ParsedConstraints(
        subject_counts={},
        banned_days=sorted(found),
        no_before=no_before,
        no_after=no_after,
        keywords=[],
        banned_professors=[],
    )
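
# Illustrative fallback behavior (regex path only, no model involved):
#   _fallback_light("avoid friday, no classes before 10am")
#       -> banned_days=["Fri"], no_before="10:00 AM", everything else empty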

def parse_constraints(text: str) -> ParsedConstraints:
    if not text or not text.strip():
        return ParsedConstraints(subject_counts={}, banned_days=[], keywords=[], banned_professors=[])

    prompt = _build_prompt(text.strip())
    # return_full_text=False (set in _lazy_pipe) means this is only the completion
    out = _lazy_pipe()(prompt)[0]["generated_text"]

    obj = _extract_json(out)
    if not isinstance(obj, dict):
        # graceful fallback to the regex parser
        return _fallback_light(text)

    # sanitize / coerce model output
    subject_counts = {}
    for k, v in (obj.get("subject_counts") or {}).items():
        try:
            iv = int(v)
            if iv > 0:
                subject_counts[k] = iv
        except Exception:
            continue

    banned_days = _normalize_days(obj.get("banned_days") or [])
    no_before = obj.get("no_before") or None
    no_after = obj.get("no_after") or None
    if not isinstance(no_before, str):
        no_before = None
    if not isinstance(no_after, str):
        no_after = None

    def _list_str(x):
        return [s for s in (x or []) if isinstance(s, str) and s.strip()]
    keywords = _list_str(obj.get("keywords"))
    banned_professors = _list_str(obj.get("banned_professors"))

    return ParsedConstraints(
        subject_counts=subject_counts,
        banned_days=banned_days,
        no_before=no_before,
        no_after=no_after,
        keywords=keywords,
        banned_professors=banned_professors,
    )
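
# Minimal manual check, assuming the model named above is downloadable locally.
# The LLM line will vary with the model; the fallback line is deterministic.
if __name__ == "__main__":
    # deterministic regex fallback (no model download needed):
    print(_fallback_light("no classes before 10am, avoid friday").to_json())
    # full LLM path (downloads/loads the model on first call):
    print(parse_constraints("two cs classes, no weekends, evenings only").to_json())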