|
|
|
from __future__ import annotations

import json
import os
import re
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
|
|
|
|
# Hugging Face model id used for constraint parsing; override with the
# PARSER_MODEL environment variable.
MODEL_NAME = os.environ.get("PARSER_MODEL", "meta-llama/Llama-3.1-8B-Instruct")

# Module-level caches for the tokenizer and generation pipeline.
# Both stay None until _lazy_pipe() populates them on first use, so merely
# importing this module does not download or load the model.
_tokenizer = None

_pipe = None
|
|
|
def _lazy_pipe():
    """Return the cached text-generation pipeline, building it on first call.

    The tokenizer and model for MODEL_NAME are loaded once and cached in the
    module globals ``_tokenizer``/``_pipe`` so repeated parses reuse them.

    Returns:
        A transformers text-generation pipeline configured for deterministic
        (greedy) decoding that returns only the completion, not the prompt.
    """
    global _tokenizer, _pipe
    if _pipe is None:
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto",
            # NOTE(review): trust_remote_code executes code shipped with the
            # model repository -- keep only if the model source is trusted.
            trust_remote_code=True,
        )
        _pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=_tokenizer,
            max_new_tokens=256,
            # Greedy decoding; a temperature setting would be ignored here
            # (and temperature=0.0 triggers warnings in newer transformers).
            do_sample=False,
            # Bug fix: by default the pipeline echoes the prompt inside
            # "generated_text", so _extract_json would latch onto the schema
            # braces embedded in SYSTEM_PROMPT instead of the model's answer.
            return_full_text=False,
        )
    return _pipe
|
|
|
# Canonical three-letter day abbreviations, Sunday-first; these are the only
# day values emitted by _normalize_days.
DAY_ABBR = ["Sun","Mon","Tue","Wed","Thu","Fri","Sat"]

# Lowercase full names, plurals and abbreviations -> canonical abbreviation.
DAY_ALIASES = {
    "sunday":"Sun","sundays":"Sun","sun":"Sun",
    "monday":"Mon","mondays":"Mon","mon":"Mon",
    "tuesday":"Tue","tuesdays":"Tue","tue":"Tue",
    "wednesday":"Wed","wednesdays":"Wed","wed":"Wed",
    "thursday":"Thu","thursdays":"Thu","thu":"Thu",
    "friday":"Fri","fridays":"Fri","fri":"Fri",
    "saturday":"Sat","saturdays":"Sat","sat":"Sat",
}
|
|
|
@dataclass
class ParsedConstraints:
    """Structured scheduling constraints extracted from a natural-language request."""

    # Desired number of sections per subject, e.g. {"Computer Science": 4}.
    subject_counts: Dict[str, int]
    # Days on which no class may be scheduled, canonical "Sun".."Sat" values.
    banned_days: List[str]
    # Earliest acceptable time as "H:MM AM/PM", or None for no lower bound.
    no_before: Optional[str] = None
    # Latest acceptable time as "H:MM AM/PM", or None for no upper bound.
    no_after: Optional[str] = None
    # Free-text topic keywords to prefer. Bug fix: these two fields were
    # declared List[str] but defaulted to None, forcing callers to None-check;
    # default_factory gives each instance its own empty list instead.
    keywords: List[str] = field(default_factory=list)
    # Professor names to exclude.
    banned_professors: List[str] = field(default_factory=list)

    def to_json(self) -> str:
        """Serialize all fields to a JSON string."""
        return json.dumps(asdict(self))
|
|
|
# System instructions prepended to every prompt by _build_prompt. This is
# runtime prompt text sent to the model, so the literal (including its blank
# lines) is preserved byte-for-byte.
SYSTEM_PROMPT = """You convert natural language scheduling requests into STRICT JSON.

Output only JSON with this schema:



{

"subject_counts": {"<SubjectName>": <int>}, // optional

"banned_days": ["Sun","Mon","Tue","Wed","Thu","Fri","Sat"], // optional

"no_before": "H:MM AM/PM", // optional

"no_after": "H:MM AM/PM", // optional

"keywords": ["..."], // optional

"banned_professors": ["..."] // optional

}



Rules:

- Never include commentary.

- Use only "Sun","Mon","Tue","Wed","Thu","Fri","Sat".

- If the user says 'no weekends', add 'Sat' and 'Sun'.

- If they say 'mornings'/'afternoons'/'evenings', map to times:

* mornings => no_before >= "10:00 AM" (i.e., prefer starts at/after 10am)

* afternoons => no_before >= "12:00 PM"

* evenings => no_before >= "4:00 PM"

* 'no mornings' => no_after <= "12:00 PM"

- If counts are implied (e.g., "a couple of CS classes"), interpret:

couple=2, few=3, several=3.

- If ambiguous, make reasonable assumptions but keep to schema.

"""
|
|
|
# Few-shot (user request, expected JSON dict) pairs; _build_prompt serializes
# each dict with json.dumps and appends the pair as a worked example.
FEW_SHOTS = [
    ("make me 4 CS classes, no Sundays",
     {"subject_counts":{"Computer Science":4},"banned_days":["Sun"]}),
    ("i want two econ and one psychology after 1pm, avoid friday",
     {"subject_counts":{"Economics":2,"Psychology":1},"no_before":"1:00 PM","banned_days":["Fri"]}),
    ("prefer software engineering and algorithms, no class on mon or tue mornings",
     {"keywords":["software engineering","algorithms"],"banned_days":["Mon","Tue"],"no_after":"12:00 PM"}),
    ("no weekends, evenings only",
     {"banned_days":["Sat","Sun"],"no_before":"4:00 PM"}),
    ("skip wed; a couple math; no classes before 10am",
     {"subject_counts":{"Mathematics":2},"banned_days":["Wed"],"no_before":"10:00 AM"}),
]
|
|
|
def _build_prompt(user_text: str) -> str:
    """Assemble the full prompt: system instructions, worked examples, then the user's request."""
    sections = [SYSTEM_PROMPT, "\nExamples:"]
    sections.extend(
        f"User: {shot_text}\nJSON: {json.dumps(shot_json)}"
        for shot_text, shot_json in FEW_SHOTS
    )
    sections.append(f"\nUser: {user_text}\nJSON:")
    return "\n".join(sections)
|
|
|
def _extract_json(text: str) -> Optional[dict]: |
|
text = text.strip() |
|
|
|
start = text.find("{") |
|
end = text.rfind("}") |
|
if start != -1 and end != -1 and end > start: |
|
try: |
|
return json.loads(text[start:end+1]) |
|
except Exception: |
|
pass |
|
|
|
m = re.search(r"```json\s*(\{.*?\})\s*```", text, re.S) |
|
if m: |
|
try: |
|
return json.loads(m.group(1)) |
|
except Exception: |
|
pass |
|
return None |
|
|
|
def _normalize_days(days: List[str]) -> List[str]:
    """Map day names/aliases onto canonical abbreviations, deduplicated, input order kept."""
    normalized: List[str] = []
    for raw in days or []:
        token = raw.strip()

        # Already canonical ("Sun".."Sat")? Keep it (once).
        if token in DAY_ABBR:
            if token not in normalized:
                normalized.append(token)
            continue

        # Otherwise resolve case-insensitively through the alias table;
        # unknown names are dropped.
        canonical = DAY_ALIASES.get(token.lower())
        if canonical is not None and canonical not in normalized:
            normalized.append(canonical)
    return normalized
|
|
|
def _fallback_light(text: str) -> ParsedConstraints:
    """Regex-only fallback used when the model output yields no parseable JSON.

    Recovers banned days ("no/avoid <day>") and "no classes before/after
    <time>" bounds; subject counts, keywords and professors are left empty.
    """
    lowered = (text or "").lower()

    banned = set()
    for alias, abbr in DAY_ALIASES.items():
        day = re.escape(alias)
        matched = (
            re.search(rf"\bno( classes?)? (on )?{day}\b", lowered)
            or re.search(rf"\bavoid (on )?{day}\b", lowered)
            or re.search(rf"\bno {day}\b", lowered)
        )
        if matched:
            banned.add(abbr)

    def _clock(hour, minute, meridiem):
        # Missing minutes default to ":00"; am/pm is upper-cased.
        return f"{int(hour)}:{minute if minute else '00'} {meridiem.upper()}"

    before_m = re.search(r"no classes before ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)", lowered)
    after_m = re.search(r"no classes after ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)", lowered)

    return ParsedConstraints(
        subject_counts={},
        banned_days=sorted(banned),
        no_before=_clock(*before_m.groups()) if before_m else None,
        no_after=_clock(*after_m.groups()) if after_m else None,
        keywords=[],
        banned_professors=[],
    )
|
|
|
def parse_constraints(text: str) -> ParsedConstraints:
    """Parse a natural-language scheduling request into ParsedConstraints.

    Builds the few-shot prompt, runs the LLM pipeline, extracts the JSON
    object from the completion and sanitizes each field. Falls back to the
    regex-only parser when no usable JSON comes back.

    Args:
        text: Raw user request; empty/blank input yields empty constraints.
    """
    if not text or not text.strip():
        return ParsedConstraints(subject_counts={}, banned_days=[], keywords=[], banned_professors=[])

    prompt = _build_prompt(text.strip())
    out = _lazy_pipe()(prompt)[0]["generated_text"]
    # Bug fix: text-generation pipelines echo the prompt by default, and the
    # first "{" in the prompt is the schema example inside SYSTEM_PROMPT --
    # _extract_json would parse garbage and every request degraded to the
    # regex fallback. Strip the echoed prompt so only the completion remains.
    if out.startswith(prompt):
        out = out[len(prompt):]

    obj = _extract_json(out)
    if not isinstance(obj, dict):
        # No usable JSON from the model: degrade to the regex-only parser.
        return _fallback_light(text)

    # Keep only positive integer counts; drop non-numeric values.
    subject_counts = {}
    for subject, count in (obj.get("subject_counts") or {}).items():
        try:
            n = int(count)
        except (TypeError, ValueError):
            continue
        if n > 0:
            subject_counts[subject] = n

    banned_days = _normalize_days(obj.get("banned_days") or [])
    no_before = obj.get("no_before") or None
    no_after = obj.get("no_after") or None

    def _clean_strs(values):
        # Keep only non-blank strings.
        return [s for s in (values or []) if isinstance(s, str) and s.strip()]

    return ParsedConstraints(
        subject_counts=subject_counts,
        banned_days=banned_days,
        no_before=no_before,
        no_after=no_after,
        keywords=_clean_strs(obj.get("keywords")),
        banned_professors=_clean_strs(obj.get("banned_professors")),
    )
|
|