# constraint_parser_llm.py
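"""Parse free-form scheduling requests into structured constraints.

A small local instruct model (Llama 3.1 8B or Mistral 7B by default) is prompted
with a strict JSON schema plus few-shot examples; its output is parsed and
sanitized into a ParsedConstraints dataclass. If the model emits no usable JSON,
a lightweight regex fallback extracts banned days and time bounds directly.
"""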
from __future__ import annotations
import json, re, os
from dataclasses import dataclass, field, asdict
from typing import Dict, List, Optional

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Choose one (Llama 3.1 tends to be slightly better at structured JSON output)
MODEL_NAME = os.environ.get("PARSER_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
# Alternative:
# MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"

_tokenizer = None
_pipe = None

def _lazy_pipe():
    global _tokenizer, _pipe
    if _pipe is None:
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto",            # requires `accelerate`; uses GPU(s) if available, else CPU
        )
        _pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=_tokenizer,
            max_new_tokens=256,
            do_sample=False,              # greedy decoding for deterministic JSON
            return_full_text=False,       # return only the completion; otherwise _extract_json
                                          # would grab the first "{" from the echoed prompt
        )
    return _pipe

DAY_ABBR = ["Sun","Mon","Tue","Wed","Thu","Fri","Sat"]
DAY_ALIASES = {
    "sunday":"Sun","sundays":"Sun","sun":"Sun",
    "monday":"Mon","mondays":"Mon","mon":"Mon",
    "tuesday":"Tue","tuesdays":"Tue","tue":"Tue",
    "wednesday":"Wed","wednesdays":"Wed","wed":"Wed",
    "thursday":"Thu","thursdays":"Thu","thu":"Thu",
    "friday":"Fri","fridays":"Fri","fri":"Fri",
    "saturday":"Sat","saturdays":"Sat","sat":"Sat",
}

@dataclass
class ParsedConstraints:
    subject_counts: Dict[str, int]
    banned_days: List[str]
    no_before: Optional[str] = None     # earliest allowed start, e.g. "10:00 AM"
    no_after: Optional[str] = None      # latest allowed start, e.g. "6:00 PM"
    keywords: List[str] = field(default_factory=list)
    banned_professors: List[str] = field(default_factory=list)

    def to_json(self) -> str:
        return json.dumps(asdict(self))

SYSTEM_PROMPT = """You convert natural language scheduling requests into STRICT JSON.
Output only JSON with this schema:

{
  "subject_counts": {"<SubjectName>": <int>},   // optional
  "banned_days": ["Sun","Mon","Tue","Wed","Thu","Fri","Sat"], // optional
  "no_before": "H:MM AM/PM", // optional
  "no_after": "H:MM AM/PM",  // optional
  "keywords": ["..."],       // optional
  "banned_professors": ["..."] // optional
}

Rules:
- Never include commentary.
- Use only "Sun","Mon","Tue","Wed","Thu","Fri","Sat".
- If the user says 'no weekends', add "Sat" and "Sun" to banned_days.
- If they say 'mornings'/'afternoons'/'evenings', map to times
  (no_before = earliest allowed start, no_after = latest allowed start):
  * mornings only => no_after = "12:00 PM"
  * afternoons => no_before = "12:00 PM"
  * evenings => no_before = "4:00 PM"
  * 'no mornings' => no_before = "12:00 PM"
- If counts are implied (e.g., "a couple of CS classes"), interpret:
  couple=2, few=3, several=3.
- If ambiguous, make reasonable assumptions but keep to schema.
"""

FEW_SHOTS = [
    # (user, json)
    ("make me 4 CS classes, no Sundays",
     {"subject_counts":{"Computer Science":4},"banned_days":["Sun"]}),
    ("i want two econ and one psychology after 1pm, avoid friday",
     {"subject_counts":{"Economics":2,"Psychology":1},"no_before":"1:00 PM","banned_days":["Fri"]}),
    ("prefer software engineering and algorithms, no class on mon or tue mornings",
     {"keywords":["software engineering","algorithms"],"banned_days":["Mon","Tue"],"no_after":"12:00 PM"}),
    ("no weekends, evenings only",
     {"banned_days":["Sat","Sun"],"no_before":"4:00 PM"}),
    ("skip wed; a couple math; no classes before 10am",
     {"subject_counts":{"Mathematics":2},"banned_days":["Wed"],"no_before":"10:00 AM"}),
]
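
# The prompt is assembled as plain text: system rules, few-shot pairs, then the
# new request. An alternative (not used here) would be to format the same content
# with _tokenizer.apply_chat_template(...), which instruct checkpoints are tuned
# for and which often improves schema adherence; plain text keeps the code
# model-agnostic.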

def _build_prompt(user_text: str) -> str:
    parts = [SYSTEM_PROMPT, "\nExamples:"]
    for u, js in FEW_SHOTS:
        parts.append(f"User: {u}\nJSON: {json.dumps(js)}")
    parts.append(f"\nUser: {user_text}\nJSON:")
    return "\n".join(parts)

def _extract_json(text: str) -> Optional[dict]:
    text = text.strip()
    # Some models wrap with text; find first/last braces
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        try:
            return json.loads(text[start:end+1])
        except Exception:
            pass
    # Sometimes models produce ```json blocks
    m = re.search(r"```json\s*(\{.*?\})\s*```", text, re.S)
    if m:
        try:
            return json.loads(m.group(1))
        except Exception:
            pass
    return None
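
# e.g. _extract_json('Sure!\n```json\n{"banned_days": ["Fri"]}\n```')
#   -> {"banned_days": ["Fri"]}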

def _normalize_days(days: List[str]) -> List[str]:
    out = []
    for d in days or []:
        k = d.strip()
        # accept already-correct abbr
        if k in DAY_ABBR and k not in out:
            out.append(k)
            continue
        # map aliases
        v = DAY_ALIASES.get(k.lower())
        if v and v not in out:
            out.append(v)
    return out
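
# e.g. _normalize_days(["monday", "Tue", "Funday"]) -> ["Mon", "Tue"]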

def _fallback_light(text: str) -> ParsedConstraints:
    t = (text or "").lower()
    # crude day detection: "no friday", "no classes on friday", "avoid friday", ...
    found = set()
    for k, v in DAY_ALIASES.items():
        day_patterns = (
            rf"\bno( classes?)? (on )?{re.escape(k)}\b",  # optional groups also cover plain "no friday"
            rf"\bavoid (on )?{re.escape(k)}\b",
        )
        if any(re.search(p, t) for p in day_patterns):
            found.add(v)
    # time
    before = re.search(r"no classes before ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)", t)
    after  = re.search(r"no classes after ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)", t)
    def fmt(h,m,ap):
        mm = m if m else "00"
        return f"{int(h)}:{mm} {ap.upper()}"
    no_before = fmt(*before.groups()) if before else None
    no_after  = fmt(*after.groups()) if after else None
    return ParsedConstraints(
        subject_counts={},
        banned_days=sorted(found),
        no_before=no_before,
        no_after=no_after,
        keywords=[],
        banned_professors=[],
    )

def parse_constraints(text: str) -> ParsedConstraints:
    if not text or not text.strip():
        return ParsedConstraints(subject_counts={}, banned_days=[], keywords=[], banned_professors=[])

    prompt = _build_prompt(text.strip())
    out = _lazy_pipe()(prompt)[0]["generated_text"]  # completion only (return_full_text=False)

    obj = _extract_json(out)
    if not isinstance(obj, dict):
        # graceful fallback
        return _fallback_light(text)

    # sanitize / coerce
    subject_counts = {}
    for k,v in (obj.get("subject_counts") or {}).items():
        try:
            iv = int(v)
            if iv > 0:
                subject_counts[k] = iv
        except Exception:
            continue

    banned_days = _normalize_days(obj.get("banned_days") or [])
    no_before = obj.get("no_before") or None
    no_after  = obj.get("no_after") or None

    def _list_str(x): return [s for s in (x or []) if isinstance(s, str) and s.strip()]
    keywords = _list_str(obj.get("keywords"))
    banned_professors = _list_str(obj.get("banned_professors"))

    return ParsedConstraints(
        subject_counts=subject_counts,
        banned_days=banned_days,
        no_before=no_before,
        no_after=no_after,
        keywords=keywords,
        banned_professors=banned_professors
    )
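

if __name__ == "__main__":
    # Minimal smoke test. parse_constraints() downloads/loads the model weights
    # (the Llama 3.1 repo is gated on the Hugging Face Hub); the regex fallback
    # is model-free and runs anywhere.
    sample = "make me 3 cs classes, no fridays, no classes before 10am"
    print(_fallback_light(sample).to_json())
    # print(parse_constraints(sample).to_json())  # full LLM path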