# Author: piekenius123 — commit 3f84696 "Update app.py" (verified)
# app.py
import io
import json
import base64
import random
from typing import Optional, Dict, Any, List, Tuple
import pandas as pd
from PIL import Image
import gradio as gr
from huggingface_hub import HfApi, hf_hub_download
# Hugging Face repository browsed by this viewer.
DATASET_REPO_ID = "piekenius123/Amaze"
REPO_TYPE = "dataset"
# Shape names recognised in repo paths; also the "Shape" dropdown choices.
SHAPES = "circle hexagon square triangle".split()
# Split names offered in the "Split" dropdown.
SPLITS = "train val test".split()
# Inclusive range of square maze sizes offered by the "Maze size" filter.
MAZE_SIZE_MIN, MAZE_SIZE_MAX = 3, 16
MAZE_SIZE_CHOICES = ["All", *(f"{side}×{side}" for side in range(MAZE_SIZE_MIN, MAZE_SIZE_MAX + 1))]
# Image columns requested when a parquet file is loaded.
IMAGE_COLS = [
    "original_img",
    "m_original_img",
    "sol_img",
    "mask_img",
    "cell_map",
]
# -------------------------
# Decode / parse helpers
# -------------------------
def safe_json_loads(s: Any) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Parse a metadata cell as JSON without raising.

    Returns a ``(parsed, error)`` pair:

    * ``(None, None)`` for missing values: ``None``, float NaN, an empty or
      whitespace-only string, or the text ``"null"`` in any case;
    * ``(None, message)`` when *s* is not a string or is not valid JSON;
    * ``(value, None)`` on success. The value is whatever ``json.loads``
      produced, which is not necessarily a dict.
    """
    is_missing = s is None or (isinstance(s, float) and pd.isna(s))
    if is_missing:
        return None, None
    if not isinstance(s, str):
        return None, f"metadata is not a string, got type={type(s)}"
    text = s.strip()
    if text.lower() in ("", "null"):
        return None, None
    try:
        parsed = json.loads(text)
    except Exception as exc:
        return None, str(exc)
    return parsed, None
def decode_base64_image(base64_str: Any) -> Optional[Image.Image]:
    """Decode a base64-encoded image cell into a fully loaded PIL image.

    An optional ``data:...;base64,`` URI prefix is stripped first. Returns
    ``None`` in these cases:

    * the value is not a string (this covers ``None`` and NaN);
    * the string is empty, whitespace-only, or ``"null"``;
    * decoding or opening the image fails for any reason.
    """
    if not isinstance(base64_str, str):
        return None
    payload = base64_str.strip()
    if payload == "" or payload.lower() == "null":
        return None
    try:
        if payload.startswith("data:"):
            # Data URI: keep only the part after the first comma.
            _, payload = payload.split(",", 1)
        buffer = io.BytesIO(base64.b64decode(payload))
        image = Image.open(buffer)
        # Force decoding now so errors surface here rather than at render time.
        image.load()
    except Exception:
        return None
    return image
def infer_shape_from_repo_path(path: str) -> Optional[str]:
    """Return the first entry of ``SHAPES`` that appears as a directory in *path*.

    Matching is case-insensitive and accepts either slash style. A shape
    matches when it is the top-level directory or any intermediate one.
    Returns ``None`` if no shape directory is found.
    """
    # A leading "/" lets a top-level "shape/" match the same "/shape/" test.
    normalized = "/" + path.replace("\\", "/").lower()
    return next((shape for shape in SHAPES if f"/{shape}/" in normalized), None)
def infer_split_from_repo_path(path: str) -> Optional[str]:
    """Classify a parquet repo path as ``"train"``, ``"val"`` or ``"test"``.

    Matching is case-insensitive and accepts either slash style.

    * ``maze_dataset_train.parquet`` is always ``"train"``.
    * ``maze_dataset_test.parquet`` is ``"val"`` under a
      ``/maze-dataset_train/`` directory, otherwise ``"test"`` under a
      ``/maze-dataset/`` directory.
    * Anything else returns ``None``.
    """
    normalized = path.replace("\\", "/").lower()
    filename = normalized.rsplit("/", 1)[-1]
    if filename == "maze_dataset_train.parquet":
        return "train"
    if filename != "maze_dataset_test.parquet":
        return None
    # Order matters: the validation directory is checked first.
    for marker, split in (("/maze-dataset_train/", "val"), ("/maze-dataset/", "test")):
        if marker in normalized:
            return split
    return None
def get_metadata_size(meta_str: Any) -> Optional[Tuple[int, int]]:
    """Extract the maze ``(width, height)`` from a metadata JSON string.

    Lookup order:

    1. ``maze_config.width`` / ``maze_config.height``;
    2. top-level ``width`` / ``height`` as a fallback (some files duplicate
       them there).

    Returns ``None`` when:

    * the metadata is missing, unparseable, or not a JSON object;
    * it is an empty object;
    * neither location holds values convertible with ``int()``.
    """
    d, err = safe_json_loads(meta_str)
    # Guard on dict: a JSON number/bool (e.g. "5") would otherwise make the
    # ``"width" in d`` membership test raise TypeError outside any try block.
    if err or not isinstance(d, dict) or not d:
        return None
    for container in (d.get("maze_config"), d):
        if isinstance(container, dict) and "width" in container and "height" in container:
            try:
                return int(container["width"]), int(container["height"])
            except (TypeError, ValueError, OverflowError):
                # Non-numeric (or infinite) size values: try the next location.
                continue
    return None
def filter_df_by_maze_size(df: pd.DataFrame, size_str: Optional[str]) -> pd.DataFrame:
    """Keep only rows whose metadata size equals the ``"W×H"`` choice *size_str*.

    *df* itself is returned unchanged when:

    * *size_str* is empty or ``"All"``;
    * *size_str* cannot be parsed as ``"<int>×<int>"``;
    * *df* has no ``metadata`` column.

    Otherwise the matching rows are returned with a fresh 0..n-1 index.
    """
    if not size_str or size_str == "All":
        return df
    try:
        width_txt, height_txt = size_str.split("×")
        wanted = (int(width_txt), int(height_txt))
    except Exception:
        return df
    if "metadata" not in df.columns:
        return df

    def _matches(meta: Any) -> bool:
        return get_metadata_size(meta) == wanted

    keep = df["metadata"].apply(_matches)
    return df.loc[keep].reset_index(drop=True)
def summarize_df(df: pd.DataFrame, filtered_len: Optional[int] = None) -> str:
    """Return a one-line ``"<rows> rows · <cols> cols"`` summary of *df*.

    When *filtered_len* is given and differs from the row count, the text
    ``" · filtered: <n>"`` is appended.
    """
    pieces = [f"{len(df)} rows", f"{len(df.columns)} cols"]
    if filtered_len is not None and filtered_len != len(df):
        pieces.append(f"filtered: {filtered_len}")
    return " · ".join(pieces)
def find_index_by_id(df: pd.DataFrame, sample_id: str) -> Optional[int]:
    """Return the 0-based row position of the sample whose id matches *sample_id*.

    Matching order:

    1. an exact match on the ``id`` column;
    2. otherwise the first row whose id contains *sample_id* as a literal
       (non-regex) substring.

    The result is positional (suitable for ``df.iloc``) whatever the frame's
    index labels are. Returns ``None`` when there is no ``id`` column,
    *sample_id* is empty, or nothing matches.
    """
    if "id" not in df.columns or not sample_id:
        return None
    ids = df["id"]
    try:
        # to_numpy + argmax gives a position, not an index label; NA -> False.
        exact = (ids == sample_id).to_numpy(dtype=bool, na_value=False)
        if exact.any():
            return int(exact.argmax())
    except (TypeError, ValueError):
        pass
    try:
        # regex=False: the query is free user text, so characters like "." or
        # "(" must be matched literally rather than compiled as a regex.
        partial = (
            ids.astype(str)
            .str.contains(sample_id, regex=False, na=False)
            .to_numpy(dtype=bool, na_value=False)
        )
        if partial.any():
            return int(partial.argmax())
    except (TypeError, ValueError):
        pass
    return None
# -------------------------
# HF repo index + cache
# -------------------------
def build_repo_index() -> List[Dict[str, str]]:
    """List the dataset repo and classify every parquet file by shape and split.

    Returns one ``{"repo_path", "shape", "split"}`` record per ``.parquet``
    file whose shape and split can both be inferred from its path, sorted by
    ``repo_path``. Files that cannot be classified are skipped. Errors raised
    by ``HfApi.list_repo_files`` propagate to the caller.
    """
    repo_files = HfApi().list_repo_files(repo_id=DATASET_REPO_ID, repo_type=REPO_TYPE)
    candidates = (
        (path, infer_shape_from_repo_path(path), infer_split_from_repo_path(path))
        for path in repo_files
        if path.lower().endswith(".parquet")
    )
    catalog = [
        {"repo_path": path, "shape": shape, "split": split}
        for path, shape, split in candidates
        if shape and split
    ]
    return sorted(catalog, key=lambda rec: rec["repo_path"])
# Loaded DataFrames keyed by repo path (the path inside the dataset repo).
_DF_CACHE: Dict[str, pd.DataFrame] = {}


def download_and_load_df(repo_path: str) -> pd.DataFrame:
    """Download (if needed) and load one parquet file from the dataset repo.

    Only the columns the viewer uses (``id``, ``instruction``, ``metadata`` and
    ``IMAGE_COLS``) are kept. Results are memoised in ``_DF_CACHE`` by
    *repo_path*.

    The cache is consulted *before* ``hf_hub_download``. Every navigation
    callback reloads the frame, and ``hf_hub_download`` may contact the Hub to
    check for a newer revision even when the file is already cached locally.
    """
    cached = _DF_CACHE.get(repo_path)
    if cached is not None:
        return cached
    local_path = hf_hub_download(
        repo_id=DATASET_REPO_ID,
        repo_type=REPO_TYPE,
        filename=repo_path,
    )
    wanted_cols = ["id", "instruction", "metadata"] + IMAGE_COLS
    try:
        df = pd.read_parquet(local_path, columns=wanted_cols)
    except (KeyError, ValueError):
        # Presumably a file lacking one of the requested columns (pyarrow
        # reports this as ArrowInvalid, a ValueError subclass). Load everything
        # and keep the subset that exists; the renderer tolerates missing
        # columns via row.get(...).
        full = pd.read_parquet(local_path)
        df = full[[c for c in wanted_cols if c in full.columns]]
    _DF_CACHE[repo_path] = df
    return df
def get_repo_paths(records: List[Dict[str, str]], shape: str, split: str) -> List[str]:
    """Return the sorted repo paths of records matching *shape* and *split*.

    *records* may be ``None`` or empty, in which case ``[]`` is returned.
    """
    return sorted(
        rec["repo_path"]
        for rec in records or ()
        if (rec["shape"], rec["split"]) == (shape, split)
    )
# -------------------------
# Rendering
# -------------------------
def _cell_text(value: Any, default: str = "") -> str:
    """Render a DataFrame cell as display text, mapping None/NaN to *default*."""
    if value is None or (isinstance(value, float) and pd.isna(value)):
        return default
    return str(value)


def render_sample_view(df_filtered: pd.DataFrame, index: int):
    """Build the six viewer outputs for the row at position *index*.

    The returned tuple matches the callbacks' ``outputs`` order:
    ``(index, status_md, instruction, gallery_items, meta_json, meta_raw)``.

    * *index* is clamped into ``[0, len(df_filtered) - 1]``.
    * Missing or unparseable images are left out of the gallery.
    * Null cells are shown as empty text, or as ``maze_<index>`` for the id,
      rather than the literal strings "None"/"nan".
    * For an empty frame a plain "No samples" status string is returned.
    """
    if len(df_filtered) == 0:
        return (
            0,
            "No samples (after filtering).",
            "",
            [],
            {},
            "",
        )
    index = max(0, min(int(index), len(df_filtered) - 1))
    row = df_filtered.iloc[index]
    sid = _cell_text(row.get("id"), default=f"maze_{index}")
    instruction = _cell_text(row.get("instruction"))
    original = decode_base64_image(row.get("original_img"))
    # Fall back to the unmarked image when the marked variant is unavailable.
    marked = decode_base64_image(row.get("m_original_img")) or original
    cell_map = decode_base64_image(row.get("cell_map"))
    mask = decode_base64_image(row.get("mask_img"))
    sol = decode_base64_image(row.get("sol_img"))
    meta_dict, meta_err = safe_json_loads(row.get("metadata"))
    if meta_err:
        meta_json = {"_parse_error": meta_err}
    else:
        meta_json = meta_dict or {}
    meta_raw = _cell_text(row.get("metadata"))
    gallery_items = [
        (marked, "Marked / Original"),
        (original, "Original"),
        (sol, "Solution"),
        (mask, "Mask"),
        (cell_map, "Cell map"),
    ]
    gallery_items = [(img, cap) for (img, cap) in gallery_items if img is not None]
    status_md = f"**Sample** `{sid}` \n**Index** `{index}` / `{len(df_filtered)-1}`"
    return index, status_md, instruction, gallery_items, meta_json, meta_raw
# -------------------------
# Gradio callbacks
# -------------------------
def init_app():
    """Index the dataset repo on page load.

    Returns ``(records, badge_html)``. On failure the records are ``[]`` and
    the badge shows the error message instead of raising.
    """
    try:
        recs = build_repo_index()
    except Exception as e:
        return [], f"<div id='badges'><span class='badge'>❌ Failed to index: {e}</span></div>"
    info_html = f"<div id='badges'><span class='badge'>✅ Indexed <b>{DATASET_REPO_ID}</b></span><span class='badge'>{len(recs)} parquet files</span></div>"
    return recs, info_html
def on_shape_split_change(records: List[Dict[str, str]], shape: str, split: str):
    """Refresh the parquet dropdown for the chosen shape/split.

    The dropdown is repopulated and its first path is preselected (``None`` if
    there are none). A badge reporting how many files matched is returned too.
    """
    paths = get_repo_paths(records, shape, split)
    first_path = next(iter(paths), None)
    badge = f"<div id='badges'><span class='badge'>Found <b>{len(paths)}</b> parquet file(s) for <b>{shape}</b> / <b>{split}</b></span></div>"
    return gr.Dropdown(choices=paths, value=first_path), badge
def get_filtered_df(repo_path: str, size_str: str) -> Tuple[pd.DataFrame, str]:
    """Load *repo_path* and apply the maze-size filter.

    Returns ``(filtered_df, summary_text)``. The summary describes the full
    frame and notes the filtered row count when it differs.
    """
    full_df = download_and_load_df(repo_path)
    subset = filter_df_by_maze_size(full_df, size_str)
    return subset, summarize_df(full_df, filtered_len=len(subset))
def on_select_parquet(repo_path: str, size_str: str):
    """Update the summary badge and the index slider's range after a file or size change.

    The slider is reset to 0 and its maximum becomes the last filtered row.
    """
    if repo_path:
        filtered, summary = get_filtered_df(repo_path, size_str)
        summary_html = f"<div id='badges'><span class='badge'>{summary}</span></div>"
        last_index = max(0, len(filtered) - 1)
        return gr.update(value=summary_html), gr.update(maximum=last_index, value=0)
    return gr.update(value="<div id='badges'><span class='badge'>No parquet selected</span></div>"), gr.update(maximum=0, value=0)
def on_prev(repo_path: str, index: int, size_str: str):
    """Show the sample before *index*, never going below 0."""
    if repo_path:
        filtered, _summary = get_filtered_df(repo_path, size_str)
        target = int(index) - 1
        return render_sample_view(filtered, target if target > 0 else 0)
    return 0, "No parquet selected.", "", [], {}, ""
def on_next(repo_path: str, index: int, size_str: str):
    """Show the sample after *index*, never going past the last filtered row."""
    if repo_path:
        filtered, _summary = get_filtered_df(repo_path, size_str)
        last = len(filtered) - 1
        target = int(index) + 1
        return render_sample_view(filtered, target if target < last else last)
    return 0, "No parquet selected.", "", [], {}, ""
def on_show(repo_path: str, index: int, size_str: str):
    """Show the sample at *index*; render_sample_view clamps it into range."""
    if repo_path:
        filtered, _summary = get_filtered_df(repo_path, size_str)
        return render_sample_view(filtered, index)
    return 0, "No parquet selected.", "", [], {}, ""
def on_random(repo_path: str, size_str: str):
    """Show a uniformly random sample from the filtered frame.

    The RNG is not consulted when the frame is empty.
    """
    if not repo_path:
        return 0, "No parquet selected.", "", [], {}, ""
    filtered, _summary = get_filtered_df(repo_path, size_str)
    count = len(filtered)
    pick = random.randint(0, count - 1) if count else 0
    return render_sample_view(filtered, pick)
def on_find_id(repo_path: str, query_id: str, size_str: str):
    """Jump to the sample whose id matches *query_id* (exact, then substring).

    When nothing matches, sample 0 is shown and a "not found" warning is
    appended to the status text.
    """
    if not repo_path:
        return 0, "No parquet selected.", "", [], {}, ""
    filtered, _ = get_filtered_df(repo_path, size_str)
    pos = find_index_by_id(filtered, query_id.strip() if isinstance(query_id, str) else "")
    if pos is not None:
        return render_sample_view(filtered, pos)
    out = list(render_sample_view(filtered, 0))
    status = out[1]
    if not isinstance(status, str):
        # For an empty frame render_sample_view may return a gr.update(...)
        # payload (a dict) rather than text; "+ str" on it raises TypeError.
        status = status.get("value") or "" if isinstance(status, dict) else str(status)
    out[1] = status + f" \n⚠️ id search `{query_id}` not found"
    return tuple(out)
# -------------------------
# UI (styled)
# -------------------------
# Custom stylesheet passed to gr.Blocks(css=...) in build_ui(). The CSS comments
# inside the string are in Chinese; in order they mean:
#   - use the system default font;
#   - global: center the page instead of stretching it full width;
#   - top control card: compact, without a large grey empty background;
#   - more compact inputs/dropdowns;
#   - uniform buttons, not a huge menu on the right;
#   - badges;
#   - Index on one row, buttons on their own row with spacing above them;
#   - make the gallery look more like a viewer.
# NOTE(review): no component in build_ui() sets elem_id="toolbar-btns", so the
# #toolbar-btns rules appear to match nothing — confirm before relying on them.
CSS = """
/* 使用系统默认字体 */
.gradio-container { font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif !important; }
/* 全局:页面居中 + 不要铺满 */
.gradio-container { max-width: 1200px !important; margin: 0 auto !important; }
/* 顶部控制卡片:紧凑、没有大灰底空白 */
#topbar {
padding: 12px 14px;
border-radius: 16px;
background: var(--block-background-fill);
border: 1px solid var(--border-color-primary);
}
#topbar .gr-row { flex-wrap: wrap; gap: 10px; }
#topbar .gr-form { margin-bottom: 0 !important; }
/* 输入/下拉更紧凑 */
#topbar input, #topbar textarea, #topbar .wrap { border-radius: 12px !important; }
/* 按钮统一,不要变成右侧巨大菜单 */
#topbar button { height: 42px !important; border-radius: 12px !important; }
/* badges */
#badges { display: flex; gap: 10px; flex-wrap: wrap; align-items: center; }
.badge {
padding: 6px 10px;
border-radius: 999px;
border: 1px solid var(--border-color-primary);
background: var(--background-fill-secondary);
font-size: 13px;
line-height: 1.2;
}
/* Index 一行,按钮单独一行并向下留间距 */
#toolbar .gr-row { align-items: end; }
#toolbar-btns { margin-top: 12px; }
#toolbar-btns .gr-row { align-items: end; }
/* Gallery 更像 viewer */
#viewer { margin-top: 10px; }
"""
# Gradio "Soft" theme with medium text and large corner radii; passed to
# gr.Blocks(theme=...) in build_ui().
THEME = gr.themes.Soft(
    text_size=gr.themes.sizes.text_md,
    radius_size=gr.themes.sizes.radius_lg,
)
def build_ui():
    """Assemble and return the Amaze viewer ``gr.Blocks`` app.

    Layout:

    * an intro Markdown;
    * a top control card with shape / split / maze-size / parquet dropdowns,
      the id search box, the index slider and the navigation buttons;
    * a viewer row with status and gallery on the left, instruction and
      metadata on the right.

    Event wiring:

    * page load: index the repo, fill the parquet dropdown, summarise the
      selected file and show its first sample;
    * shape/split change: refresh the parquet dropdown;
    * parquet/maze-size change: refresh the summary badge and slider range,
      then show sample 0;
    * buttons and slider release: navigate within the filtered frame.

    The callbacks used here already return their own "No parquet selected"
    outputs for an empty selection, so no extra guards are needed.
    """
    with gr.Blocks(title="Amaze Viewer", theme=THEME, css=CSS) as demo:
        gr.Markdown(
            f"""
# Amaze
Dataset: https://huggingface.co/datasets/{DATASET_REPO_ID}
Amaze is a benchmark for the Editing-as-Reasoning (EAR) task. It features four maze shapes: circle, hexagon, square, and triangle. Each sample provides: an unmarked maze image (original_img), a maze image with start and end points marked (m_original_img), a blue solution path image (sol_img), a binary path mask (mask_img), a cell segmentation map (cell_map), and metadata (JSON) for describing the maze structure and difficulty.
The test set covers various sizes from 3×3 to 16×16 (50 samples for each size), while the training set mainly consists of 3×3 mazes (1024 samples), and validation set consists of 3×3 mazes (256 samples).
Browse samples by **shape / split / maze size**, then view images + metadata.
"""
        )
        # Repo index records produced by init_app() on page load.
        records_state = gr.State([])

        # Top control bar (compact card)
        with gr.Column(elem_id="topbar"):
            with gr.Row():
                parquet_tip = gr.HTML(value="<div id='badges'></div>")
                summary_badge = gr.HTML(value="<div id='badges'><span class='badge'>No parquet selected</span></div>")
                scan_info = gr.HTML(value="<div id='badges'><span class='badge'>Indexing dataset repo…</span></div>")
            with gr.Row():
                shape_dd = gr.Dropdown(label="Shape", choices=SHAPES, value="circle", scale=1)
                split_dd = gr.Dropdown(label="Split", choices=SPLITS, value="test", scale=1)
                size_dd = gr.Dropdown(label="Maze size", choices=MAZE_SIZE_CHOICES, value="All", scale=1)
                parquet_dd = gr.Dropdown(label="Parquet", choices=[], value=None, scale=2)
            with gr.Row(elem_id="toolbar"):
                id_query = gr.Textbox(label="Find by id", placeholder="UUID or substring", scale=2)
                idx_slider = gr.Slider(label="Index", minimum=0, maximum=0, value=0, step=1, scale=2)
            with gr.Row():
                prev_btn = gr.Button("⬅ Prev", variant="secondary", scale=1)
                next_btn = gr.Button("Next ➡", variant="secondary", scale=1)
                random_btn = gr.Button("🎲 Random", variant="primary", scale=1)
                find_btn = gr.Button("🔎 Find", variant="secondary", scale=1)
                show_btn = gr.Button("Show", variant="secondary", scale=1)

        # Main viewer layout
        with gr.Row(elem_id="viewer"):
            with gr.Column(scale=3):
                status_md = gr.Markdown(elem_id="status")
                gallery = gr.Gallery(
                    label="Images",
                    columns=2,
                    height=520,
                    object_fit="contain",
                    preview=True,
                )
            with gr.Column(scale=2):
                instruction = gr.Textbox(label="Instruction", lines=6, interactive=False)
                with gr.Accordion("Metadata (parsed JSON)", open=True):
                    meta_json = gr.JSON()
                with gr.Accordion("Metadata (raw)", open=False):
                    meta_raw = gr.Textbox(lines=10, interactive=False)

        # ---- events ----
        # Same order as render_sample_view's return tuple.
        view_outputs = [idx_slider, status_md, instruction, gallery, meta_json, meta_raw]
        select_outputs = [summary_badge, idx_slider]
        nav_inputs = [parquet_dd, idx_slider, size_dd]

        demo.load(
            fn=init_app,
            inputs=None,
            outputs=[records_state, scan_info],
        ).then(
            fn=on_shape_split_change,
            inputs=[records_state, shape_dd, split_dd],
            outputs=[parquet_dd, parquet_tip],
        ).then(
            fn=on_select_parquet,
            inputs=[parquet_dd, size_dd],
            outputs=select_outputs,
        ).then(
            fn=lambda p, s: on_show(p, 0, s),
            inputs=[parquet_dd, size_dd],
            outputs=view_outputs,
        )

        for selector in (shape_dd, split_dd):
            selector.change(
                fn=on_shape_split_change,
                inputs=[records_state, shape_dd, split_dd],
                outputs=[parquet_dd, parquet_tip],
            )

        # Choosing a file or a size filter re-summarises and jumps to sample 0.
        for source in (parquet_dd, size_dd):
            source.change(
                fn=on_select_parquet,
                inputs=[parquet_dd, size_dd],
                outputs=select_outputs,
            ).then(
                fn=lambda p, s: on_show(p, 0, s),
                inputs=[parquet_dd, size_dd],
                outputs=view_outputs,
            )

        show_btn.click(fn=on_show, inputs=nav_inputs, outputs=view_outputs)
        idx_slider.release(fn=on_show, inputs=nav_inputs, outputs=view_outputs)
        prev_btn.click(fn=on_prev, inputs=nav_inputs, outputs=view_outputs)
        next_btn.click(fn=on_next, inputs=nav_inputs, outputs=view_outputs)
        random_btn.click(fn=on_random, inputs=[parquet_dd, size_dd], outputs=view_outputs)
        find_btn.click(fn=on_find_id, inputs=[parquet_dd, id_query, size_dd], outputs=view_outputs)
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app and start the Gradio server.
    # The app is bound to the module-level name `demo` before launching.
    demo = build_ui()
    demo.launch()