cgreszes commited on
Commit
4d1e784
·
verified ·
1 Parent(s): 20ddb69

Create constraint_parser_llm.py

Browse files
Files changed (1) hide show
  1. constraint_parser_llm.py +192 -0
constraint_parser_llm.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# constraint_parser_llm.py
from __future__ import annotations

import json
import os
import re
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
8
+
9
# Model used for constraint parsing; override with the PARSER_MODEL env var.
# Choose one (Llama3.1 is slightly better at structured JSON)
MODEL_NAME = os.environ.get("PARSER_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
# Alternative:
# MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"

# Lazily-initialized singletons; populated on first call to _lazy_pipe() so
# importing this module never triggers a multi-GB model download.
_tokenizer = None
_pipe = None
17
def _lazy_pipe():
    """Build (once) and return the shared text-generation pipeline.

    Loading an 8B model is expensive, so the tokenizer and pipeline are cached
    in module globals and reused across every parse_constraints() call.

    Returns:
        A transformers text-generation pipeline configured for deterministic
        (greedy) decoding.
    """
    global _tokenizer, _pipe
    if _pipe is None:
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto",  # works CPU/GPU
            trust_remote_code=True,
        )
        _pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=_tokenizer,
            max_new_tokens=256,
            # Greedy decoding. NOTE: the original also passed temperature=0.0,
            # which transformers rejects/warns about when do_sample=False
            # (temperature must be strictly positive for sampling), so it is
            # omitted here — greedy decoding ignores temperature anyway.
            do_sample=False,
            # Return only the completion, not prompt+completion. Without this,
            # the JSON extractor downstream can latch onto the schema braces
            # embedded in SYSTEM_PROMPT instead of the model's answer.
            return_full_text=False,
        )
    return _pipe
35
+
36
# Canonical day-of-week abbreviations, Sunday-first.
DAY_ABBR = ["Sun","Mon","Tue","Wed","Thu","Fri","Sat"]

# Lowercase full names in the same Sunday-first order as DAY_ABBR.
_DAY_FULL = ["sunday", "monday", "tuesday", "wednesday", "thursday", "friday", "saturday"]

# Lookup table mapping every accepted lowercase spelling (full name, plural,
# three-letter abbreviation) to its canonical abbreviation.
DAY_ALIASES = {}
for _full, _abbr in zip(_DAY_FULL, DAY_ABBR):
    for _alias in (_full, _full + "s", _abbr.lower()):
        DAY_ALIASES[_alias] = _abbr
del _full, _abbr, _alias
46
+
47
@dataclass
class ParsedConstraints:
    """Normalized scheduling constraints extracted from free-form user text.

    Field names mirror the JSON schema the LLM is prompted to emit.
    """
    # Canonical subject name -> number of requested classes.
    subject_counts: Dict[str, int]
    # Day abbreviations ("Sun".."Sat") the user wants to avoid.
    banned_days: List[str]
    no_before: Optional[str] = None  # earliest allowed start, e.g. "10:00 AM"
    no_after: Optional[str] = None   # latest allowed start, e.g. "6:00 PM"
    # default_factory fixes the original `= None` defaults, which contradicted
    # the List[str] annotations and would crash any caller that iterated a
    # default-constructed instance.
    keywords: List[str] = field(default_factory=list)
    banned_professors: List[str] = field(default_factory=list)

    def to_json(self) -> str:
        """Serialize to a JSON string (round-trips via json.loads)."""
        return json.dumps(asdict(self))
58
+
59
# System prompt sent verbatim to the LLM: demands strict JSON matching the
# ParsedConstraints schema. NOTE: any edit to this string changes model
# behavior — it is runtime data, not documentation.
SYSTEM_PROMPT = """You convert natural language scheduling requests into STRICT JSON.
Output only JSON with this schema:

{
"subject_counts": {"<SubjectName>": <int>}, // optional
"banned_days": ["Sun","Mon","Tue","Wed","Thu","Fri","Sat"], // optional
"no_before": "H:MM AM/PM", // optional
"no_after": "H:MM AM/PM", // optional
"keywords": ["..."], // optional
"banned_professors": ["..."] // optional
}

Rules:
- Never include commentary.
- Use only "Sun","Mon","Tue","Wed","Thu","Fri","Sat".
- If the user says 'no weekends', add 'Sat' and 'Sun'.
- If they say 'mornings'/'afternoons'/'evenings', map to times:
* mornings => no_before >= "10:00 AM" (i.e., prefer starts at/after 10am)
* afternoons => no_before >= "12:00 PM"
* evenings => no_before >= "4:00 PM"
* 'no mornings' => no_after <= "12:00 PM"
- If counts are implied (e.g., "a couple of CS classes"), interpret:
couple=2, few=3, several=3.
- If ambiguous, make reasonable assumptions but keep to schema.
"""
84
+
85
# Few-shot (user utterance, expected JSON) pairs appended after SYSTEM_PROMPT.
# Each pair demonstrates one of the mapping rules (counts, banned days,
# time windows, keywords); dicts are serialized with json.dumps at prompt
# build time.
FEW_SHOTS = [
    # (user, json)
    ("make me 4 CS classes, no Sundays",
     {"subject_counts":{"Computer Science":4},"banned_days":["Sun"]}),
    ("i want two econ and one psychology after 1pm, avoid friday",
     {"subject_counts":{"Economics":2,"Psychology":1},"no_before":"1:00 PM","banned_days":["Fri"]}),
    ("prefer software engineering and algorithms, no class on mon or tue mornings",
     {"keywords":["software engineering","algorithms"],"banned_days":["Mon","Tue"],"no_after":"12:00 PM"}),
    ("no weekends, evenings only",
     {"banned_days":["Sat","Sun"],"no_before":"4:00 PM"}),
    ("skip wed; a couple math; no classes before 10am",
     {"subject_counts":{"Mathematics":2},"banned_days":["Wed"],"no_before":"10:00 AM"}),
]
98
+
99
def _build_prompt(user_text: str) -> str:
    """Assemble the few-shot prompt: system text, worked examples, then the query."""
    example_blocks = [
        f"User: {shot_text}\nJSON: {json.dumps(shot_json)}"
        for shot_text, shot_json in FEW_SHOTS
    ]
    segments = [SYSTEM_PROMPT, "\nExamples:", *example_blocks, f"\nUser: {user_text}\nJSON:"]
    return "\n".join(segments)
105
+
106
+ def _extract_json(text: str) -> Optional[dict]:
107
+ text = text.strip()
108
+ # Some models wrap with text; find first/last braces
109
+ start = text.find("{")
110
+ end = text.rfind("}")
111
+ if start != -1 and end != -1 and end > start:
112
+ try:
113
+ return json.loads(text[start:end+1])
114
+ except Exception:
115
+ pass
116
+ # Sometimes models produce ```json blocks
117
+ m = re.search(r"```json\s*(\{.*?\})\s*```", text, re.S)
118
+ if m:
119
+ try:
120
+ return json.loads(m.group(1))
121
+ except Exception:
122
+ pass
123
+ return None
124
+
125
def _normalize_days(days: List[str]) -> List[str]:
    """Map day names/aliases to canonical abbreviations, deduped, order kept."""
    normalized: List[str] = []
    for raw in days or []:
        token = raw.strip()
        # Already-canonical abbreviations pass through unchanged; everything
        # else goes through the case-insensitive alias table.
        abbr = token if token in DAY_ABBR else DAY_ALIASES.get(token.lower())
        if abbr and abbr not in normalized:
            normalized.append(abbr)
    return normalized
137
+
138
def _fallback_light(text: str) -> ParsedConstraints:
    """Regex-only constraint extraction used when the LLM yields no JSON.

    Recognizes banned days ("no X", "no classes on X", "avoid X") and
    "no classes before/after <time>" phrases; every other field is empty.
    """
    lowered = (text or "").lower()

    # Crude day catch-all: check each alias against three phrasings.
    banned = set()
    for alias, abbr in DAY_ALIASES.items():
        escaped = re.escape(alias)
        hit = (
            re.search(rf"\bno( classes?)? (on )?{escaped}\b", lowered)
            or re.search(rf"\bavoid (on )?{escaped}\b", lowered)
            or re.search(rf"\bno {escaped}\b", lowered)
        )
        if hit:
            banned.add(abbr)

    def to_clock(match):
        # Format (hour, optional minutes, am/pm) groups as "H:MM AM"/"H:MM PM".
        if match is None:
            return None
        hour, minute, meridiem = match.groups()
        return f"{int(hour)}:{minute if minute else '00'} {meridiem.upper()}"

    no_before = to_clock(re.search(r"no classes before ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)", lowered))
    no_after = to_clock(re.search(r"no classes after ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)", lowered))

    return ParsedConstraints(
        subject_counts={},
        banned_days=sorted(banned),
        no_before=no_before,
        no_after=no_after,
        keywords=[],
        banned_professors=[],
    )
154
+
155
def parse_constraints(text: str) -> ParsedConstraints:
    """Parse free-form scheduling text into a ParsedConstraints instance.

    Runs the LLM over a few-shot prompt, extracts the JSON object from the
    completion, and sanitizes every field. Falls back to the lightweight
    regex parser when the model output contains no usable JSON; never raises
    on malformed model output.
    """
    if not text or not text.strip():
        return ParsedConstraints(subject_counts={}, banned_days=[], keywords=[], banned_professors=[])

    prompt = _build_prompt(text.strip())
    out = _lazy_pipe()(prompt)[0]["generated_text"]
    # text-generation pipelines echo prompt+completion by default; strip the
    # prompt so _extract_json cannot latch onto the schema braces inside
    # SYSTEM_PROMPT instead of the model's actual answer. (Harmless no-op if
    # the pipeline was built with return_full_text=False.)
    if out.startswith(prompt):
        out = out[len(prompt):]

    obj = _extract_json(out)
    if not isinstance(obj, dict):
        # graceful fallback
        return _fallback_light(text)

    # Sanitize / coerce: keep only positive integer counts, dropping anything
    # the model emitted malformed (null, "two", floats that don't convert...).
    subject_counts: Dict[str, int] = {}
    for subject, raw_count in (obj.get("subject_counts") or {}).items():
        try:
            count = int(raw_count)
        except (TypeError, ValueError):
            continue
        if count > 0:
            subject_counts[subject] = count

    banned_days = _normalize_days(obj.get("banned_days") or [])
    no_before = obj.get("no_before") or None
    no_after = obj.get("no_after") or None

    def _clean_str_list(values):
        # Keep only non-empty strings; the model sometimes emits nulls/ints.
        return [s for s in (values or []) if isinstance(s, str) and s.strip()]

    return ParsedConstraints(
        subject_counts=subject_counts,
        banned_days=banned_days,
        no_before=no_before,
        no_after=no_after,
        keywords=_clean_str_list(obj.get("keywords")),
        banned_professors=_clean_str_list(obj.get("banned_professors")),
    )