|
import re |
|
import random |
|
from datetime import datetime |
|
from typing import Dict, List, Tuple |
|
|
|
import gradio as gr |
|
import pandas as pd |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from matplotlib.patches import Rectangle |
|
from dataset import DATASET |
|
|
|
from constraint_parser_llm import parse_constraints |
|
from semantic_ranker import score_courses |
|
|
|
|
|
import re |
|
|
|
# Every accepted lowercase spelling of a weekday (full name, plural, and
# three-letter form) mapped to its canonical three-letter abbreviation.
DAY_ALIASES = {
    alias: full_name[:3].title()
    for full_name in (
        "sunday",
        "monday",
        "tuesday",
        "wednesday",
        "thursday",
        "friday",
        "saturday",
    )
    for alias in (full_name, full_name + "s", full_name[:3])
}

# Canonical abbreviations grouped by part of the week.
WEEKEND = {"Sat", "Sun"}
WEEKDAYS = {"Mon", "Tue", "Wed", "Thu", "Fri"}
|
|
|
def _norm_ampm(h: str, m: str|None, ap: str|None) -> str: |
|
h_i = int(h) |
|
m_s = (m or "00") |
|
ap = (ap or "").upper() |
|
if ap not in ("AM", "PM"): |
|
|
|
ap = "AM" if h_i <= 11 else "PM" |
|
return f"{h_i}:{m_s} {ap}" |
|
|
|
def parse_det_constraints(text: str):
    """
    Deterministic, high-recall extraction of scheduling constraints.

    Returns a dict with:
    - banned_days: {"Mon","Tue",...}
    - no_before:   "H:MM AM/PM" or None   (start times >= this)
    - no_after:    "H:MM AM/PM" or None   (end times <= this)

    Recognized wording:
    - "no weekends", "weekdays only", "weekends only", "no weekdays"
    - "<trigger> [classes] [on] <day>", e.g. "avoid fridays", "no class on mon"
    - "after 10", "not before 10", "no earlier than 10", "no classes before 10"
      -> no_before
    - "before 3 pm", "not after 3 pm", "no later than 3 pm",
      "no classes after 3 pm" -> no_after
    - "mornings" / "no mornings", "afternoons" / "no afternoons",
      "evenings" / "no evenings" (used only when no explicit clock time
      already set the same bound)

    The first explicit clock time found for a bound wins.
    """
    res = {"banned_days": set(), "no_before": None, "no_after": None}
    if not text:
        return res
    t = text.lower().strip()

    # Whole-week shortcuts.
    if re.search(r"\bno (?:weekend|weekends)\b", t) or re.search(r"\bweekdays only\b|\bonly on weekdays\b", t):
        res["banned_days"] |= WEEKEND
    if re.search(r"\bweekends only\b|\bonly on weekends\b", t) or re.search(r"\bno weekdays\b", t):
        res["banned_days"] |= WEEKDAYS

    # Individual days introduced by a negative trigger word.
    triggers = r"(?:no|avoid|except|skip|without|not on|exclude|ban|block|never on)"
    for alias, abbr in DAY_ALIASES.items():
        if re.search(rf"\b{triggers}\s+(?:classes?|class|lectures?)?\s*(?:on\s*)?{re.escape(alias)}\b", t):
            res["banned_days"].add(abbr)

    # Explicit clock bounds. An optional leading negation ("no", "not", ...)
    # is captured together with the direction word so that "no classes
    # before 9" is not misread as "before 9".
    clock = r"(\d{1,2})(?::(\d{2}))?\s*(am|pm)?\b"
    bound_re = re.compile(
        r"\b(?:(?P<neg>no|not|nothing|never)\s+(?:(?:classes|class|lectures?)\s+)?)?"
        r"(?P<word>before|after|earlier than|later than)\s+" + clock
    )
    for m in bound_re.finditer(t):
        # "after X" asks for classes later than X; a negation flips that.
        wants_later = m.group("word") in ("after", "later than")
        if m.group("neg"):
            wants_later = not wants_later
        key = "no_before" if wants_later else "no_after"
        if res[key] is None:
            res[key] = _norm_ampm(*m.group(3, 4, 5))

    # Part-of-day keywords: (word, bound when requested, bound when excluded).
    periods = (
        ("morning", ("no_after", "12:00 PM"), ("no_before", "12:00 PM")),
        ("afternoon", ("no_before", "12:00 PM"), ("no_after", "12:00 PM")),
        ("evening", ("no_before", "4:00 PM"), ("no_after", "4:00 PM")),
    )
    for word, requested, excluded in periods:
        # Test the negated form first so "no evenings" never also counts
        # as a request for evenings.
        if re.search(rf"\bno {word}s?\b", t):
            key, value = excluded
        elif re.search(rf"\b{word}s?\b", t):
            key, value = requested
        else:
            continue
        res[key] = res[key] or value

    return res
|
|
|
|
|
# NOTE: DAY_ALIASES is also defined near the top of this file; this statement
# re-binds the name to an equal mapping of lowercase day spellings to
# three-letter abbreviations.
DAY_ALIASES = {
    alias: abbr
    for abbr, aliases in (
        ("Sun", ("sunday", "sundays", "sun")),
        ("Mon", ("monday", "mondays", "mon")),
        ("Tue", ("tuesday", "tuesdays", "tue")),
        ("Wed", ("wednesday", "wednesdays", "wed")),
        ("Thu", ("thursday", "thursdays", "thu")),
        ("Fri", ("friday", "fridays", "fri")),
        ("Sat", ("saturday", "saturdays", "sat")),
    )
    for alias in aliases
}
|
|
|
def extract_banned_days_free_text(text: str):
    """
    Catch broad natural language variations of "not on this day":
    - "no class on monday", "no monday", "avoid mon", "not on mondays",
      "skip monday", "except monday"

    Args:
        text: Free-form user instructions; may be empty or None.

    Returns:
        A set of day abbreviations like {"Mon","Sun"} (empty if nothing matched).
    """
    if not text:
        return set()
    t = text.lower()

    triggers = ["no", "avoid", "except", "skip", "without", "not on", "exclude", "ban", "block"]
    # The trigger alternation is the same for every day, so build it once.
    trig_group = "(?:" + "|".join(map(re.escape, triggers)) + ")"

    # Both the "class(es)" and the "on" parts are optional, so this single
    # pattern also covers the bare "<trigger> <day>" form.
    return {
        abbr
        for key, abbr in DAY_ALIASES.items()
        if re.search(rf"\b{trig_group}\s+(?:class(?:es)?\s*)?(?:on\s*)?{re.escape(key)}\b", t)
    }
|
|
|
|
|
# Module-level table of all classes, built once from the imported DATASET.
df = pd.DataFrame(DATASET)

# Sorted list of the distinct values in the "subject" column; passed to the
# subject dropdown in the UI below.
SUBJECTS = sorted(df["subject"].unique().tolist())
|
# Full weekday names, Sunday first; this is the column order of the timetable.
DAYS_AXIS = ["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"]

# Three-letter abbreviation -> full weekday name.
DAY_MAP = {full_name[:3]: full_name for full_name in DAYS_AXIS}

# strptime/strftime format of clock times such as "09:30 AM".
TIME_FMT = "%I:%M %p"
|
|
|
|
|
def parse_time_range(timestr: str):
    """Parse a ``"START - END"`` string into a ``(start, end)`` pair of datetimes.

    The string is split on ``"-"`` and must yield exactly two parts; each part
    is stripped and parsed with ``TIME_FMT``. ``ValueError`` is raised when
    the split does not give two parts or a part does not match the format.
    """
    start_text, end_text = (piece.strip() for piece in timestr.split("-"))
    start_dt = datetime.strptime(start_text, TIME_FMT)
    end_dt = datetime.strptime(end_text, TIME_FMT)
    return start_dt, end_dt
|
|
|
def to_hours(t: datetime) -> float:
    """Return the time of day of *t* as fractional hours (13:30 -> 13.5)."""
    fraction_of_hour = t.minute / 60.0
    return t.hour + fraction_of_hour
|
|
|
def block_overlaps(a, b) -> bool:
    """Tell whether two ``(start, end)`` intervals share any time.

    Intervals that merely touch (one ends exactly when the other starts)
    do not count as overlapping.
    """
    a_start, a_end = a
    b_start, b_end = b
    if not (a_start < b_end):
        return False
    return b_start < a_end
|
|
|
def class_record_to_blocks(row_dict):
    """Expand one class record into a timetable block per meeting day.

    ``row_dict["times"]`` is parsed with ``parse_time_range`` and
    ``row_dict["days"]`` is read as a comma-separated list of day
    abbreviations. Abbreviations missing from ``DAY_MAP`` are skipped.

    Returns:
        A list of ``(full_day_name, (start_hour, end_hour), row_dict)`` tuples.
    """
    start_dt, end_dt = parse_time_range(row_dict["times"])
    span = (to_hours(start_dt), to_hours(end_dt))
    day_codes = (code.strip() for code in row_dict["days"].split(","))
    return [(DAY_MAP[code], span, row_dict) for code in day_codes if code in DAY_MAP]
|
|
|
def filter_by_constraints(df_in: pd.DataFrame, instructions: str) -> pd.DataFrame:
    """Filter classes by a few fixed instruction phrasings.

    Recognized phrasings (case-insensitive):
    - "no classes on ...", "don't give me classes on ...": every weekday
      named anywhere in the text is banned.
    - "no classes before H[:MM] am|pm": keep classes starting at or after it.
    - "no classes after H[:MM] am|pm": keep classes ending at or before it.

    Args:
        df_in: Frame with a comma-separated "days" column and a
            "START - END" "times" column.
        instructions: Free-form user text; may be empty or None.

    Returns:
        A filtered copy of ``df_in``; the input frame is not modified.
    """
    filtered = df_in.copy()
    text = (instructions or "").lower()

    # Word boundaries keep "mon" in "common" or "sat" in "satisfy" from
    # being read as a day name.
    days_regex = r"\b(sundays?|mondays?|tuesdays?|wednesdays?|thursdays?|fridays?|saturdays?|sun|mon|tue|wed|thu|fri|sat)\b"
    ban_phrases = ("no classes on", "don't give me classes on", "dont give me classes on")
    if any(phrase in text for phrase in ban_phrases):
        mentioned = {m.group(1) for m in re.finditer(days_regex, text)}
        norm_banned = {d[:3].title() for d in mentioned if d[:3].title() in DAY_MAP}
        if norm_banned:
            hits_banned = filtered["days"].apply(
                lambda s: any(b in [x.strip() for x in s.split(",")] for b in norm_banned)
            )
            # Explicit bool cast keeps the mask boolean even for an empty frame.
            filtered = filtered[~hits_banned.astype(bool)]

    before = re.search(r"no classes before ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)", text)
    after = re.search(r"no classes after ([0-9]{1,2})(?::([0-9]{2}))?\s*(am|pm)", text)

    def to_24h(hs, ms, ap):
        """Convert captured hour/minute/am-pm text to fractional 24-hour time."""
        h = int(hs)
        m = int(ms) if ms else 0
        ap = ap.lower()
        if ap == "pm" and h != 12:
            h += 12
        if ap == "am" and h == 12:
            h = 0
        return h + m / 60.0

    min_start = to_24h(*before.groups()) if before else None
    max_end = to_24h(*after.groups()) if after else None

    if min_start is not None:
        starts_ok = filtered["times"].apply(lambda t: to_hours(parse_time_range(t)[0]) >= min_start)
        filtered = filtered[starts_ok.astype(bool)]
    if max_end is not None:
        ends_ok = filtered["times"].apply(lambda t: to_hours(parse_time_range(t)[1]) <= max_end)
        filtered = filtered[ends_ok.astype(bool)]

    return filtered
|
|
|
def pick_schedules(df_pool: pd.DataFrame, demand: Dict[str, int], max_attempts=500) -> List[List[dict]]:
    """Build three conflict-free schedules by randomized greedy selection.

    Each attempt shuffles the candidate classes and the subject order, then
    repeatedly picks, per subject, a not-yet-chosen class that does not
    overlap in time with the classes chosen so far.

    Args:
        df_pool: Candidate classes; the columns "subject", "class_id",
            "days" and "times" are read.
        demand: Number of classes wanted per subject.
        max_attempts: Maximum randomized attempts per schedule.

    Returns:
        A list of three schedules, each a list of class records (dicts).
        When no attempt satisfies the full demand, the most complete
        partial schedule found for that slot is returned instead. The
        random generator is seeded with a constant, so results are
        reproducible for identical inputs.
    """
    rng = random.Random(123)

    def conflict_free(selected_rows: List[dict], candidate_row: pd.Series) -> bool:
        """Return True if the candidate overlaps none of the selected rows."""
        cand_blocks = class_record_to_blocks(candidate_row.to_dict())
        by_day = {}
        for r in selected_rows:
            for d, (s, e), _ in class_record_to_blocks(r):
                by_day.setdefault(d, []).append((s, e))
        for d, (s, e), _ in cand_blocks:
            for (cs, ce) in by_day.get(d, []):
                if block_overlaps((s, e), (cs, ce)):
                    return False
        return True

    schedules = []
    for _ in range(3):
        built = None
        best_partial: List[dict] = []

        for _attempt in range(max_attempts):
            remaining = demand.copy()
            chosen: List[dict] = []
            # Maintained alongside `chosen` so the membership test below
            # does not rebuild a set for every candidate index.
            chosen_ids = set()

            idxs = list(df_pool.index)
            rng.shuffle(idxs)
            subjects_order = list(remaining.keys())
            rng.shuffle(subjects_order)

            # Round-robin over subjects until a full pass adds nothing.
            progress = True
            while progress and any(remaining[s] > 0 for s in subjects_order):
                progress = False
                for sub in subjects_order:
                    if remaining[sub] <= 0:
                        continue
                    sub_idxs = [
                        i for i in idxs
                        if df_pool.at[i, "subject"] == sub
                        and df_pool.at[i, "class_id"] not in chosen_ids
                    ]
                    rng.shuffle(sub_idxs)
                    for i in sub_idxs:
                        row = df_pool.loc[i]
                        if conflict_free(chosen, row):
                            record = row.to_dict()
                            chosen.append(record)
                            chosen_ids.add(record["class_id"])
                            remaining[sub] -= 1
                            progress = True
                            break

            if all(v == 0 for v in remaining.values()):
                built = chosen
                break
            # Remember the fullest failed attempt as the fallback.
            if len(chosen) > len(best_partial):
                best_partial = chosen

        schedules.append(built if built is not None else best_partial)
    return schedules
|
|
|
def draw_timetable(schedule_rows: List[dict], title: str):
    """Render a weekly timetable and return it as an RGB pixel array.

    Args:
        schedule_rows: Class records; each is expanded with
            ``class_record_to_blocks`` and labelled with its "name" and
            "professor" values.
        title: Text shown above the grid.

    Returns:
        A NumPy array holding the first three (RGB) channels of the
        rendered canvas. The figure is closed before returning.
    """
    fig, ax = plt.subplots(figsize=(10, 7), dpi=150)
    ax.set_xlim(0, 7)
    ax.set_ylim(8, 21)

    # One column per weekday, labelled along the top edge.
    ax.set_xticks(range(7))
    ax.set_xticklabels(DAYS_AXIS)
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')
    ax.tick_params(axis='x', which='both', bottom=False, top=True, labelbottom=False, labeltop=True)

    # One tick per hour from 8 to 21.
    ax.set_yticks(range(8, 22, 1))
    ax.set_ylabel("Time")
    ax.set_title(title, pad=20)

    # Grid lines.
    for col_edge in range(8):
        ax.axvline(col_edge, linewidth=0.5)
    for hour_line in range(8, 22):
        ax.axhline(hour_line, linewidth=0.3)

    # Fixed seed: the n-th class always gets the same colour.
    palette = np.random.default_rng(42).random((len(schedule_rows), 3))

    for row_no, record in enumerate(schedule_rows):
        for day, (start_h, end_h), _ in class_record_to_blocks(record):
            if day not in DAYS_AXIS:
                continue
            col = DAYS_AXIS.index(day)
            top = start_h
            height = end_h - start_h
            # Small insets keep adjacent boxes from touching.
            box = Rectangle(
                (col + 0.05, top + 0.02),
                0.9,
                height - 0.04,
                linewidth=0.8,
                edgecolor='black',
                facecolor=palette[row_no],
                alpha=0.8,
            )
            ax.add_patch(box)
            ax.text(
                col + 0.07,
                top + 0.1,
                f"{record['name']}\n{record['professor']}",
                fontsize=7,
                va='top',
                wrap=True,
            )

    # Earlier hours at the top.
    ax.invert_yaxis()
    ax.set_facecolor("white")
    fig.tight_layout()

    # Rasterize and copy the pixels out of the canvas.
    fig.canvas.draw()
    try:
        pixels = np.asarray(fig.canvas.buffer_rgba())
    except AttributeError:
        # Fallback for canvases that expose the buffer only on the renderer.
        pixels = np.asarray(fig.canvas.get_renderer().buffer_rgba())
    img = pixels[..., :3].copy()
    plt.close(fig)
    return img
|
|
|
def schedule_details_table(rows: List[dict]) -> pd.DataFrame:
    """Return the detail columns of the given class records as a DataFrame.

    An empty ``rows`` yields an empty frame that still has the six detail
    columns; otherwise any extra keys in the records are dropped.
    """
    detail_columns = ["class_id", "name", "professor", "days", "times", "subject"]
    if rows:
        return pd.DataFrame(rows)[detail_columns]
    return pd.DataFrame(columns=detail_columns)
|
|
|
|
|
def _ensure_array_table(tbl): |
|
"""Gradio Dataframe can return None or pandas; normalize to list-of-lists.""" |
|
if tbl is None: |
|
return [] |
|
if isinstance(tbl, pd.DataFrame): |
|
return tbl.values.tolist() |
|
return tbl |
|
|
|
def add_subject(tbl, subject):
    """Add one class of ``subject`` to the selection table.

    If the subject is already listed its count is incremented (or reset to
    1 when the stored count is not a number); otherwise a new row with
    count 1 is appended. An empty ``subject`` leaves the table unchanged.

    Args:
        tbl: Current table value (None, DataFrame, or list of
            ``[subject, count]`` rows).
        subject: Subject chosen in the dropdown, or a falsy value.

    Returns:
        ``(table, dropdown_update, total_text)``: the updated table, an
        update that clears the dropdown, and the "Total classes: N" label.
    """
    def _count(value) -> int:
        # Same conversion generate() applies, so the displayed total matches
        # what will be scheduled; number cells may arrive as floats (2.0).
        try:
            return max(0, int(value))
        except (TypeError, ValueError, OverflowError):
            return 0

    tbl = _ensure_array_table(tbl)

    if subject:
        existing = next((row for row in tbl if len(row) > 0 and row[0] == subject), None)
        if existing is None:
            tbl.append([subject, 1])
        else:
            try:
                existing[1] = int(existing[1]) + 1
            except (TypeError, ValueError, OverflowError):
                # Unreadable count (blank cell, text): start over at one.
                existing[1] = 1

    total = sum(_count(row[1]) for row in tbl if len(row) > 1)
    return tbl, gr.update(value=None), f"Total classes: {total}"
|
|
|
def update_total(tbl):
    """Return the "Total classes: N" label for the selection table.

    Counts are converted with ``int()`` and clamped at zero, matching how
    ``generate`` reads them; cells that cannot be converted (blank, text,
    NaN) and rows without a count column contribute zero.
    """
    def _count(value) -> int:
        # Number cells may arrive as floats such as 2.0, which a
        # str.isdigit() check would silently count as zero.
        try:
            return max(0, int(value))
        except (TypeError, ValueError, OverflowError):
            return 0

    tbl = _ensure_array_table(tbl)
    total = sum(_count(row[1]) for row in tbl if len(row) > 1)
    return f"Total classes: {total}"
|
|
|
def generate(tbl, instructions):
    """Build schedules for the selected subjects and render the first one.

    Steps: read the per-subject demand from the table, parse the free-text
    instructions with ``parse_det_constraints``, drop classes that meet on
    a banned day or fall outside the requested time window, then call
    ``pick_schedules`` on what is left.

    Args:
        tbl: Selection table value (None, DataFrame, or list of
            ``[subject, count]`` rows).
        instructions: Free-form constraint text; may be empty or None.

    Returns:
        An 8-tuple in the order expected by the generate button's outputs:
        timetable image, title, list of schedules, current index, and
        visibility updates for the previous button, next button, details
        button and details table.
    """
    def _empty_result():
        # Shown when nothing was requested: blank grid, all controls hidden.
        blank = draw_timetable([], "Schedule option 1 (empty)")
        return (
            blank, "Schedule option 1",
            [], 0,
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )

    tbl = _ensure_array_table(tbl)
    if not tbl:
        return _empty_result()

    # Demand per subject. Rows without a subject or without a count column
    # are ignored rather than breaking the unpacking or creating a demand
    # that no class can satisfy.
    demand = {}
    for row in tbl:
        if len(row) < 2:
            continue
        subject, count = row[0], row[1]
        if subject is None or not str(subject).strip():
            continue
        try:
            c = max(0, int(count))
        except (TypeError, ValueError, OverflowError):
            c = 0
        if c > 0:
            demand[subject] = demand.get(subject, 0) + c

    if not demand:
        return _empty_result()

    det = parse_det_constraints(instructions or "")
    pool = df.copy()

    # Drop every class that meets on at least one banned day.
    banned_days = set(det["banned_days"])
    if banned_days:
        on_banned_day = pool["days"].apply(
            lambda s: any(b in [x.strip() for x in s.split(",")] for b in banned_days)
        )
        # Explicit bool cast keeps the mask boolean even for an empty pool.
        pool = pool[~on_banned_day.astype(bool)]

    def _start_end_hours(ts: str):
        """Return (start, end) of a "START - END" string as fractional hours."""
        start_dt, end_dt = parse_time_range(ts)
        return to_hours(start_dt), to_hours(end_dt)

    def _bound_hours(label: str):
        """Convert an "H:MM AM/PM" bound to fractional hours, or None if invalid."""
        try:
            return to_hours(datetime.strptime(label.strip(), TIME_FMT))
        except ValueError:
            # A bound that cannot be parsed (e.g. "10:75 AM") is skipped so
            # the remaining constraints still apply.
            return None

    if det["no_before"]:
        earliest = _bound_hours(det["no_before"])
        if earliest is not None:
            starts_ok = pool["times"].apply(lambda t: _start_end_hours(t)[0] >= earliest)
            pool = pool[starts_ok.astype(bool)]
    if det["no_after"]:
        latest = _bound_hours(det["no_after"])
        if latest is not None:
            ends_ok = pool["times"].apply(lambda t: _start_end_hours(t)[1] <= latest)
            pool = pool[ends_ok.astype(bool)]

    scheds = pick_schedules(pool, demand)
    idx = 0
    rows = scheds[idx] if scheds else []
    title = f"Schedule option {idx+1}"
    img = draw_timetable(rows, title)

    prev_vis = gr.update(visible=False)
    next_vis = gr.update(visible=len(scheds) > 1)
    details_vis = gr.update(visible=True)
    hide_details_table = gr.update(visible=False)

    return img, title, scheds, idx, prev_vis, next_vis, details_vis, hide_details_table
|
|
|
def step(direction, scheds, idx):
    """Move to the next or previous schedule and redraw it.

    ``direction == "next"`` advances; any other value goes back. The index
    wraps around at both ends.

    Returns:
        ``(image, title, new_index, prev_button_update, next_button_update)``.
        With no schedules the image is left untouched, the title is empty,
        and both buttons are hidden.
    """
    if not scheds:
        return gr.update(), "", idx, gr.update(visible=False), gr.update(visible=False)

    offset = 1 if direction == "next" else -1
    idx = (idx + offset) % len(scheds)

    title = f"Schedule option {idx+1}"
    img = draw_timetable(scheds[idx], title)
    return img, title, idx, gr.update(visible=True), gr.update(visible=True)
|
|
|
def get_details(scheds, idx):
    """Return the details table for schedule ``idx``.

    An empty or missing schedule list, or an index outside it, gives the
    empty details table.
    """
    has_schedule = bool(scheds) and 0 <= idx < len(scheds)
    selected = scheds[idx] if has_schedule else []
    return schedule_details_table(selected)
|
|
|
|
|
# Gradio UI definition. Components are created in layout order inside the
# Blocks context; the event wiring follows once all components exist.
with gr.Blocks(css="""
:root { --radius: 16px; }
#container { max-width: 1400px; margin: 0 auto; }
.card { border: 1px solid #e5e7eb; border-radius: var(--radius); padding: 16px; background: white; box-shadow: 0 6px 24px rgba(0,0,0,0.04); }
.split { display: grid; grid-template-columns: 1fr 1fr; gap: 16px; align-items: start; }
.totals { font-weight: 600; }
.gr-accordion-header { font-weight: 600; }
#custom-instructions-label { position: relative; display: inline-block; cursor: pointer; }
""") as demo:
    gr.Markdown("# Class Schedule Generator")

    with gr.Row(elem_id="container"):

        # Left card: subject selection, free-text instructions, generate button.
        with gr.Column(scale=1, min_width=480, elem_classes=["card"]):
            subject_dropdown = gr.Dropdown(
                SUBJECTS, label="Select your subjects", value=None, allow_custom_value=False
            )
            add_btn = gr.Button("➕ Add subject")

            # Editable [subject, count] table; read by add_subject,
            # update_total and generate.
            subject_table = gr.Dataframe(
                headers=["Subject","Count"],
                datatype=["str","number"],
                type="array",
                value=[],
                row_count=(0,"dynamic"),
                col_count=2,
                interactive=True,
                label="Selected subjects & counts"
            )

            total_text = gr.Markdown("Total classes: 0", elem_classes=["totals"])

            gr.Markdown("### 🤖 Custom instructions")
            custom_instructions = gr.Textbox(
                label="",
                placeholder="Don't give me classes on Sundays",
                lines=2
            )
            # Sample instructions offered as one-click buttons below.
            EXAMPLE_PROMPTS = [
                "No classes on Mondays",
                "Only start after 10 AM",
                "Avoid Fridays, no weekends",
                "Weekdays only",
                "No class on Sundays and only start after 10 AM",
                "Avoid Monday and avoid Friday; end before 3 PM",
                "Weekends only and mornings only",
                "Evenings only",
            ]

            with gr.Accordion("See example prompts ▼", open=False):
                with gr.Row():
                    with gr.Column():
                        # Even-indexed prompts go in the left column.
                        ex_btns_left = [gr.Button(EXAMPLE_PROMPTS[i]) for i in range(0, len(EXAMPLE_PROMPTS), 2)]
                    with gr.Column():
                        # Odd-indexed prompts go in the right column.
                        ex_btns_right = [gr.Button(EXAMPLE_PROMPTS[i]) for i in range(1, len(EXAMPLE_PROMPTS), 2)]

            generate_btn = gr.Button("✨ Generate schedule", variant="primary")

        # Right card: rendered timetable, navigation, class details.
        with gr.Column(scale=1, min_width=480, elem_classes=["card"]):
            schedule_title = gr.Markdown("Schedule option 1")
            # An empty grid is drawn once at build time as the initial image.
            empty_img = draw_timetable([], "Schedule option 1")
            timetable_img = gr.Image(value=empty_img, label=None, interactive=False)

            # These buttons start hidden; generate() and step() toggle them.
            with gr.Row():
                prev_btn = gr.Button("◀ Previous", visible=False)
                next_btn = gr.Button("Next ▶", visible=False)
                details_btn = gr.Button("See full class details", visible=False)

            details_table = gr.Dataframe(
                headers=["class_id","name","professor","days","times","subject"],
                interactive=False,
                visible=False,
                label="Selected classes"
            )

    # State values: the generated schedules and the index of the one shown.
    schedules_state = gr.State([])
    index_state = gr.State(0)

    def _fill_prompt(txt: str):
        """Return an update that puts ``txt`` into the instructions textbox."""
        return gr.update(value=txt)

    for btn in (ex_btns_left + ex_btns_right):
        # The default argument binds this button's text at definition time;
        # a plain closure would see only the last button of the loop.
        btn.click(lambda t=btn.value: _fill_prompt(t), outputs=[custom_instructions])

    # "Add subject": update the table, clear the dropdown, refresh the total.
    add_btn.click(
        add_subject,
        inputs=[subject_table, subject_dropdown],
        outputs=[subject_table, subject_dropdown, total_text]
    )

    # Keep the total in sync when the table changes.
    subject_table.change(update_total, inputs=[subject_table], outputs=[total_text])

    # The outputs list must match the order of generate()'s 8-tuple.
    generate_btn.click(
        generate,
        inputs=[subject_table, custom_instructions],
        outputs=[timetable_img, schedule_title, schedules_state, index_state,
                 prev_btn, next_btn, details_btn, details_table]
    )

    # Navigation between the generated schedules.
    prev_btn.click(
        lambda s,i: step("prev", s, i),
        inputs=[schedules_state, index_state],
        outputs=[timetable_img, schedule_title, index_state, prev_btn, next_btn]
    )
    next_btn.click(
        lambda s,i: step("next", s, i),
        inputs=[schedules_state, index_state],
        outputs=[timetable_img, schedule_title, index_state, prev_btn, next_btn]
    )

    # Reveal the details table filled with the currently shown schedule.
    details_btn.click(
        lambda s,i: gr.update(visible=True, value=get_details(s,i)),
        inputs=[schedules_state, index_state],
        outputs=[details_table]
    )
|
|
|
|
|
# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()