cgreszes commited on
Commit
30bee5d
·
verified ·
1 Parent(s): f630108

Create constraint_parser.py

Browse files
Files changed (1) hide show
  1. constraint_parser.py +179 -0
constraint_parser.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# constraint_parser.py
from __future__ import annotations

import json
import re
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
9
+
10
# Small, cheap, works on CPU:
MODEL_NAME = "google/flan-t5-small"

# Lazily-populated module singletons: both stay None until _lazy_pipe() is
# first called, so importing this module does not download or load the model.
_tokenizer = None
_pipe = None
16
def _lazy_pipe():
    """Return the shared text2text-generation pipeline, building it on first use.

    Caches the tokenizer and pipeline in module globals so the model is
    loaded at most once per process.
    """
    global _tokenizer, _pipe
    if _pipe is not None:
        return _pipe
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    _pipe = pipeline("text2text-generation", model=seq2seq_model, tokenizer=_tokenizer)
    return _pipe
23
+
24
# Canonical three-letter day codes, keyed by every accepted spelling:
# the bare abbreviation ("sun"), the full name ("sunday"), and the plural.
DAYS_ALIASES = {}
for _full_day in ("sunday", "monday", "tuesday", "wednesday",
                  "thursday", "friday", "saturday"):
    _code = _full_day[:3].capitalize()
    for _spelling in (_full_day[:3], _full_day, _full_day + "s"):
        DAYS_ALIASES[_spelling] = _code
del _full_day, _code, _spelling
33
+
34
@dataclass
class ParsedConstraints:
    """Normalized scheduling constraints extracted from a free-text request.

    Produced by try_regex_first / parse_constraints and consumed as plain
    data; to_json() gives a JSON object with one key per field.
    """

    subject_counts: Dict[str, int]    # e.g., {"Computer Science": 4}
    banned_days: List[str]            # canonical codes, e.g. ["Sun","Fri"]
    no_before: Optional[str] = None   # earliest allowed start, "10:00 AM"
    no_after: Optional[str] = None    # latest allowed start, "6:00 PM"
    # default_factory gives each instance its own fresh list, so the declared
    # List[str] type is honest (the previous `= None` defaults contradicted it
    # and forced callers to None-guard) and no mutable default is shared.
    keywords: List[str] = field(default_factory=list)           # thematic phrases to prefer
    banned_professors: List[str] = field(default_factory=list)  # optional future use

    def to_json(self) -> str:
        """Serialize all fields to a JSON object string."""
        return json.dumps(asdict(self))
45
+
46
# Prompt prefix for the seq2seq model (the raw user request is appended after
# it in parse_constraints). The schema mirrors ParsedConstraints' fields; the
# "// optional" annotations are part of the prompt text, not real JSON.
SCHEMA_INSTRUCTIONS = """You are a parser. Convert the user request into STRICT JSON with this schema:
{
"subject_counts": {"<SubjectName>": <int>, ...}, // optional
"banned_days": ["Sun","Mon","Tue","Wed","Thu","Fri","Sat"], // optional
"no_before": "H:MM AM/PM", // optional
"no_after": "H:MM AM/PM", // optional
"keywords": ["..."], // optional: thematic words/phrases to prefer
"banned_professors": ["..."] // optional
}
Only output valid JSON. No commentary.
"""
57
+
58
def _normalize_days(days):
    """Map free-form day tokens to canonical three-letter codes.

    Unknown tokens are dropped; duplicates are removed while preserving the
    order in which canonical codes are first seen.
    """
    canonical = []
    for token in days:
        code = DAYS_ALIASES.get(token.strip().lower())
        if code is not None and code not in canonical:
            canonical.append(code)
    return canonical
67
+
68
def try_regex_first(text: str) -> ParsedConstraints:
    """Cheap guardrail: catch obvious patterns before LLM.

    Scans the lowercased request for day names, "no classes before/after"
    time bounds, and "<count> <subject> classes" phrases, and returns a
    ParsedConstraints holding whatever it found (empty fields otherwise).
    keywords and banned_professors are always returned empty here.
    """
    t = text.lower()

    # days: collect every day word/abbreviation found anywhere in the text.
    days = set([m.group(1) for m in re.finditer(
        r"(sundays?|mondays?|tuesdays?|wednesdays?|thursdays?|fridays?|saturdays?|sun|mon|tue|wed|thu|fri|sat)", t)])
    # Ban them only if an explicit ban phrase appears somewhere in the text.
    # NOTE(review): the trigger is global, so a day mentioned for another
    # reason (e.g. "3 cs classes on monday, no classes on friday") gets
    # banned too — confirm this coarse behavior is acceptable. Also, `days`
    # is a set, so the order of banned_days is unspecified.
    banned_days = _normalize_days(days) if ("no classes on" in t or "don't give me classes on" in t or "dont give me classes on" in t) else []

    # times: "no classes before 10 am" / "no classes after 6:30 pm"
    before = re.search(r"no classes before ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)", t)
    after = re.search(r"no classes after ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)", t)
    # Normalize (hour, minutes, am/pm) match groups to "H:MM AM/PM";
    # minutes default to "00" when the match had no ":MM" part.
    def fmt(h,m,ap):
        mm = m if m else "00"
        return f"{int(h)}:{mm} {ap.upper()}"
    no_before = fmt(*before.groups()) if before else None
    no_after = fmt(*after.groups()) if after else None

    # subject count like "4 cs classes" / "three computer science"
    subject_counts = {}
    counts_map = {"one":1,"two":2,"three":3,"four":4,"five":5,"six":6}
    # cheap mapping aliases: substring keys resolved against the matched
    # subject phrase below (first alias that occurs in the phrase wins).
    subj_alias = {
        "cs":"Computer Science",
        "computer science":"Computer Science",
        "math":"Mathematics",
        "econ":"Economics",
        "psych":"Psychology",
        "bio":"Biology",
        "chem":"Chemistry",
        "phys":"Physics",
        "art history":"Art History",
        "philosophy":"Philosophy",
        "finance":"Finance",
    }
    # Pairs of (count word/digit, subject phrase) directly before "classes".
    m = re.findall(r"(\b\d+\b|\bone\b|\btwo\b|\bthree\b|\bfour\b|\bfive\b|\bsix\b)\s+([a-z ]+?)\s+classes", t)
    for num, subj_phrase in m:
        # Digits parse directly; number words go through counts_map.
        cnt = int(num) if num.isdigit() else counts_map.get(num, None)
        if not cnt: continue
        sp = subj_phrase.strip()
        # resolve subject: first alias whose key is a substring of the phrase.
        # NOTE(review): dict order matters — "cs" is checked before
        # "computer science", and short keys like "bio" can match inside
        # longer words; verify the alias ordering is intentional.
        resolved = None
        for k,v in subj_alias.items():
            if k in sp:
                resolved = v; break
        if resolved:
            # Accumulate (not overwrite) so repeated mentions add up.
            subject_counts[resolved] = subject_counts.get(resolved, 0) + cnt

    return ParsedConstraints(
        subject_counts=subject_counts,
        banned_days=banned_days,
        no_before=no_before,
        no_after=no_after,
        keywords=[],
        banned_professors=[]
    )
124
+
125
def parse_constraints(text: str) -> ParsedConstraints:
    """LLM parse with regex fallback & JSON repair.

    Pipeline: (1) regex quick pass via try_regex_first, (2) prompt the
    seq2seq model for strict JSON, (3) parse (with a brace-substring repair
    attempt), (4) merge LLM output into the regex result. If the model emits
    no usable JSON object, the regex-only result is returned unchanged.
    """
    # Empty/whitespace-only input: return fully-empty constraints without
    # touching the model.
    if not text or not text.strip():
        return ParsedConstraints(subject_counts={}, banned_days=[], keywords=[], banned_professors=[])

    # Start with regex quick pass
    base = try_regex_first(text)

    prompt = SCHEMA_INSTRUCTIONS + "\nUser request: " + text.strip()
    # NOTE(review): temperature=0.0 is passed without do_sample; with HF
    # pipelines greedy decoding is typically the default anyway — confirm
    # the argument has the intended effect for this model.
    out = _lazy_pipe()(prompt, max_new_tokens=256, temperature=0.0)[0]["generated_text"]

    # try straight parse
    obj = None
    try:
        obj = json.loads(out)
    except Exception:
        # attempt to extract JSON substring: take the outermost {...} span,
        # which tolerates prose before/after the JSON in the model output.
        start = out.find("{")
        end = out.rfind("}")
        if start != -1 and end != -1 and end > start:
            try:
                obj = json.loads(out[start:end+1])
            except Exception:
                obj = None

    if not isinstance(obj, dict):
        # fall back to regex-only result
        return base

    # merge with base
    # Treat an explicit JSON null the same as a missing key.
    def get(k, default):
        v = obj.get(k, default)
        return v if v is not None else default

    # Counts are additive: LLM counts stack on top of regex-derived counts.
    # NOTE(review): assumes obj["subject_counts"] is a dict when present;
    # a non-dict value (e.g. a list) would raise here — confirm acceptable.
    subject_counts = base.subject_counts.copy()
    for k,v in get("subject_counts", {}).items():
        if isinstance(v, int) and v > 0:
            subject_counts[k] = subject_counts.get(k, 0) + v

    # Union of both sources; set-based, so ordering is unspecified.
    banned_days = list({*base.banned_days, *(_normalize_days(get("banned_days", [])))})

    # Scalar bounds: LLM value wins when present, else regex value.
    no_before = get("no_before", base.no_before)
    no_after = get("no_after", base.no_after)

    # List fields: prefer non-empty LLM lists, else the regex base lists.
    keywords = get("keywords", []) or base.keywords
    banned_professors = get("banned_professors", []) or base.banned_professors

    return ParsedConstraints(
        subject_counts=subject_counts,
        banned_days=banned_days,
        no_before=no_before,
        no_after=no_after,
        keywords=keywords or [],
        banned_professors=banned_professors or []
    )