Spaces:
Sleeping
Sleeping
| # app.py | |
| import io | |
| import json | |
| import base64 | |
| import random | |
| from typing import Optional, Dict, Any, List, Tuple | |
| import pandas as pd | |
| from PIL import Image | |
| import gradio as gr | |
| from huggingface_hub import HfApi, hf_hub_download | |
# Hugging Face Hub location of the dataset this viewer browses.
DATASET_REPO_ID = "piekenius123/Amaze"
REPO_TYPE = "dataset"

# Values offered by the Shape / Split dropdowns.
SHAPES = ["circle", "hexagon", "square", "triangle"]
SPLITS = ["train", "val", "test"]

# Square maze sizes selectable in the UI; "All" means no size filtering.
MAZE_SIZE_MIN = 3
MAZE_SIZE_MAX = 16
MAZE_SIZE_CHOICES = ["All"]
MAZE_SIZE_CHOICES.extend(
    f"{side}×{side}" for side in range(MAZE_SIZE_MIN, MAZE_SIZE_MAX + 1)
)

# Parquet columns that are decoded as base64 images by the viewer.
IMAGE_COLS = ["original_img", "m_original_img", "sol_img", "mask_img", "cell_map"]
| # ------------------------- | |
| # Decode / parse helpers | |
| # ------------------------- | |
def safe_json_loads(s: Any) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Parse a metadata cell as JSON without raising.

    Returns a ``(value, error)`` pair:

    * ``(None, None)`` for ``None``, a float NaN, a blank string, or the
      literal ``"null"`` (case-insensitive);
    * ``(None, message)`` for non-string input or malformed JSON;
    * ``(parsed, None)`` on success.
    """
    is_missing = s is None or (isinstance(s, float) and pd.isna(s))
    if is_missing:
        return None, None
    if not isinstance(s, str):
        return None, f"metadata is not a string, got type={type(s)}"

    text = s.strip()
    if not text or text.lower() == "null":
        return None, None

    try:
        parsed = json.loads(text)
    except Exception as exc:
        # Report the parser's message instead of propagating the failure.
        return None, str(exc)
    return parsed, None
def decode_base64_image(base64_str: Any) -> Optional[Image.Image]:
    """Decode a base64 string (optionally a ``data:`` URI) into a PIL image.

    Returns ``None`` for non-string input (including ``None`` and NaN), for
    blank or ``"null"`` strings, and for any payload that cannot be decoded
    or opened as an image.
    """
    # None, NaN floats and every other non-string value mean "no image".
    if not isinstance(base64_str, str):
        return None

    payload = base64_str.strip()
    if payload == "" or payload.lower() == "null":
        return None

    try:
        if payload.startswith("data:"):
            # Keep only what follows the first comma of the data URI header.
            payload = payload.split(",", 1)[1]
        raw_bytes = base64.b64decode(payload)
        image = Image.open(io.BytesIO(raw_bytes))
        # Image.open is lazy; load() forces the read so failures are caught here.
        image.load()
        return image
    except Exception:
        return None
def infer_shape_from_repo_path(path: str) -> Optional[str]:
    """Return the first entry of ``SHAPES`` that names a directory in *path*.

    The path is compared case-insensitively with backslashes treated as
    forward slashes. A shape matches when it is the leading directory or
    appears as ``/<shape>/`` anywhere in the path. Returns ``None`` when no
    shape matches.
    """
    normalized = path.replace("\\", "/").lower()
    matches = (
        shape
        for shape in SHAPES
        if normalized.startswith(shape + "/") or ("/" + shape + "/") in normalized
    )
    return next(matches, None)
def infer_split_from_repo_path(path: str) -> Optional[str]:
    """Infer the dataset split (``"train"``, ``"val"`` or ``"test"``) from a repo path.

    Rules (case-insensitive, backslashes treated as ``/``):

    * file ``maze_dataset_train.parquet``            -> ``"train"``
    * file ``maze_dataset_test.parquet`` under a
      ``maze-dataset_train`` directory               -> ``"val"``
    * file ``maze_dataset_test.parquet`` under a
      ``maze-dataset`` directory                     -> ``"test"``

    Directories are matched as whole path components at any depth,
    including the first component, so a root-level ``maze-dataset/...``
    path is recognised the same way as a nested one.

    Returns ``None`` when no rule applies.
    """
    *directories, filename = path.replace("\\", "/").lower().split("/")

    if filename == "maze_dataset_train.parquet":
        return "train"

    if filename == "maze_dataset_test.parquet":
        # "val" is checked first, matching the original precedence.
        if "maze-dataset_train" in directories:
            return "val"
        if "maze-dataset" in directories:
            return "test"

    return None
def get_metadata_size(meta_str: Any) -> Optional[Tuple[int, int]]:
    """Extract the maze ``(width, height)`` from a metadata JSON string.

    ``width`` / ``height`` are read from the nested ``maze_config`` object
    when present, otherwise from the top level of the metadata.

    Returns ``None`` when the metadata is missing, fails to parse, is not a
    JSON object, or holds no pair that converts to integers.
    """
    parsed, err = safe_json_loads(meta_str)
    # json.loads can also produce lists, numbers or strings; only a JSON
    # object can carry the size keys, and probing a scalar with ``in`` raises.
    if err or not isinstance(parsed, dict):
        return None

    # Candidate containers in priority order: nested config, then top level.
    for holder in (parsed.get("maze_config"), parsed):
        if not isinstance(holder, dict):
            continue
        if "width" not in holder or "height" not in holder:
            continue
        try:
            return int(holder["width"]), int(holder["height"])
        except (TypeError, ValueError, OverflowError):
            # Non-numeric values here: fall through to the next candidate.
            continue
    return None
def filter_df_by_maze_size(df: pd.DataFrame, size_str: Optional[str]) -> pd.DataFrame:
    """Keep only the rows whose metadata size equals the ``"W×H"`` selection.

    The input frame is returned unchanged when *size_str* is empty or
    ``"All"``, when it cannot be parsed as ``"<int>×<int>"``, or when the
    frame has no ``metadata`` column. Otherwise a filtered copy with a fresh
    0-based index is returned.
    """
    if not size_str or size_str == "All":
        return df

    try:
        width_text, _, height_text = size_str.partition("×")
        wanted = (int(width_text), int(height_text))
    except Exception:
        # Unparseable selection: behave as if no filter was chosen.
        return df

    if "metadata" not in df.columns:
        return df

    def has_wanted_size(meta: Any) -> bool:
        return get_metadata_size(meta) == wanted

    keep = df["metadata"].apply(has_wanted_size)
    return df.loc[keep].reset_index(drop=True)
def summarize_df(df: pd.DataFrame, filtered_len: Optional[int] = None) -> str:
    """Return a one-line ``"<rows> rows · <cols> cols"`` summary of *df*.

    When *filtered_len* is given and differs from the row count, a
    ``" · filtered: <n>"`` suffix is appended.
    """
    total_rows = len(df)
    parts = [f"{total_rows} rows · {len(df.columns)} cols"]
    if filtered_len is not None and filtered_len != total_rows:
        parts.append(f" · filtered: {filtered_len}")
    return "".join(parts)
def find_index_by_id(df: pd.DataFrame, sample_id: str) -> Optional[int]:
    """Return the row position of the first sample whose ``id`` matches.

    An exact match is preferred. Failing that, the first row whose
    stringified id contains *sample_id* as a literal substring is used.
    The result is a 0-based position suitable for ``DataFrame.iloc``,
    independent of the frame's index labels.

    Returns ``None`` when the frame has no ``id`` column, the query is
    empty, or nothing matches.
    """
    if "id" not in df.columns or not sample_id:
        return None
    ids = df["id"]

    def first_hit(mask: pd.Series) -> Optional[int]:
        # Use the raw boolean array so the answer is a row position rather
        # than an index label (the two differ when the index is not 0..n-1).
        flags = mask.fillna(False).to_numpy(dtype=bool)
        return int(flags.argmax()) if flags.any() else None

    try:
        position = first_hit(ids == sample_id)
    except (TypeError, ValueError):
        position = None
    if position is not None:
        return position

    try:
        # regex=False: the query is text typed by the user, not a pattern,
        # so characters such as "." or "(" must match literally.
        contains = ids.astype(str).str.contains(str(sample_id), na=False, regex=False)
        return first_hit(contains)
    except (TypeError, ValueError):
        return None
| # ------------------------- | |
| # HF repo index + cache | |
| # ------------------------- | |
def build_repo_index() -> List[Dict[str, str]]:
    """List the dataset repo's parquet files labelled with shape and split.

    Queries the Hub for every file in ``DATASET_REPO_ID``, keeps the
    ``.parquet`` files for which both a shape and a split can be inferred
    from the path, and returns them as ``{"repo_path", "shape", "split"}``
    dicts sorted by ``repo_path``.
    """
    repo_files = HfApi().list_repo_files(repo_id=DATASET_REPO_ID, repo_type=REPO_TYPE)
    parquet_paths = [p for p in repo_files if p.lower().endswith(".parquet")]

    labelled = (
        (p, infer_shape_from_repo_path(p), infer_split_from_repo_path(p))
        for p in parquet_paths
    )
    entries: List[Dict[str, str]] = [
        {"repo_path": p, "shape": shape, "split": split}
        for p, shape, split in labelled
        if shape and split
    ]
    return sorted(entries, key=lambda entry: entry["repo_path"])
# In-process cache of loaded parquet files, keyed by their local file path.
_DF_CACHE: Dict[str, pd.DataFrame] = {}


def download_and_load_df(repo_path: str) -> pd.DataFrame:
    """Fetch one parquet file from the dataset repo and load it as a DataFrame.

    ``hf_hub_download`` is called on every invocation to resolve the local
    path; the parsed frame is then served from ``_DF_CACHE`` when that path
    has been loaded before. Only the id, instruction, metadata and image
    columns are read.
    """
    local_path = hf_hub_download(
        filename=repo_path,
        repo_id=DATASET_REPO_ID,
        repo_type=REPO_TYPE,
    )

    cached = _DF_CACHE.get(local_path)
    if cached is not None:
        return cached

    columns = ["id", "instruction", "metadata", *IMAGE_COLS]
    frame = pd.read_parquet(local_path, columns=columns)
    _DF_CACHE[local_path] = frame
    return frame
def get_repo_paths(records: List[Dict[str, str]], shape: str, split: str) -> List[str]:
    """Return the sorted repo paths of the records matching *shape* and *split*.

    A falsy *records* value (``None`` or an empty list) yields an empty list.
    """
    matching = (
        rec["repo_path"]
        for rec in (records or [])
        if rec["shape"] == shape and rec["split"] == split
    )
    return sorted(matching)
| # ------------------------- | |
| # Rendering | |
| # ------------------------- | |
def render_sample_view(df_filtered: pd.DataFrame, index: int):
    """Build the viewer outputs for one row of *df_filtered*.

    Args:
        df_filtered: Frame of samples, already filtered by maze size.
        index: Requested row position; clamped into the valid range.

    Returns:
        A 6-tuple ``(index, status_markdown, instruction, gallery_items,
        metadata_json, metadata_raw)``. ``gallery_items`` is a list of
        ``(image, caption)`` pairs for the images that decoded successfully.
        For an empty frame a placeholder tuple with the same element types
        is returned.
    """
    if len(df_filtered) == 0:
        # The status is a plain string on every path, so callers can append
        # to it without caring whether the frame was empty.
        return 0, "No samples (after filtering).", "", [], {}, ""

    def as_text(value: Any) -> str:
        # Null cells (None / NaN) are shown as empty text, not "None"/"nan".
        if value is None or (isinstance(value, float) and pd.isna(value)):
            return ""
        return value if isinstance(value, str) else str(value)

    last = len(df_filtered) - 1
    index = max(0, min(int(index), last))
    row = df_filtered.iloc[index]

    sid = str(row.get("id", f"maze_{index}"))
    instruction = as_text(row.get("instruction", ""))

    original = decode_base64_image(row.get("original_img"))
    # Fall back to the unmarked image when the marked one is unavailable.
    marked = decode_base64_image(row.get("m_original_img")) or original
    cell_map = decode_base64_image(row.get("cell_map"))
    mask = decode_base64_image(row.get("mask_img"))
    sol = decode_base64_image(row.get("sol_img"))

    meta_dict, meta_err = safe_json_loads(row.get("metadata"))
    meta_json = {"_parse_error": meta_err} if meta_err else (meta_dict or {})
    meta_raw = as_text(row.get("metadata", ""))

    candidates = [
        (marked, "Marked / Original"),
        (original, "Original"),
        (sol, "Solution"),
        (mask, "Mask"),
        (cell_map, "Cell map"),
    ]
    gallery_items = [(img, cap) for img, cap in candidates if img is not None]

    status_md = f"**Sample** `{sid}` \n**Index** `{index}` / `{last}`"
    return index, status_md, instruction, gallery_items, meta_json, meta_raw
| # ------------------------- | |
| # Gradio callbacks | |
| # ------------------------- | |
def init_app():
    """Index the dataset repo and return ``(records, status_html)``.

    On failure the records list is empty and the HTML badge carries the
    exception text.
    """
    try:
        records = build_repo_index()
        badges = (
            f"<span class='badge'>✅ Indexed <b>{DATASET_REPO_ID}</b></span>"
            f"<span class='badge'>{len(records)} parquet files</span>"
        )
        return records, f"<div id='badges'>{badges}</div>"
    except Exception as exc:
        failure = f"<span class='badge'>❌ Failed to index: {exc}</span>"
        return [], f"<div id='badges'>{failure}</div>"
def on_shape_split_change(records: List[Dict[str, str]], shape: str, split: str):
    """Refresh the parquet dropdown after the shape or split selection changes.

    Returns a ``gr.Dropdown`` holding the matching repo paths (first one
    preselected, or ``None`` when there are none) and an HTML badge with
    the match count.
    """
    paths = get_repo_paths(records, shape, split)
    first_path = paths[0] if paths else None
    badge = f"Found <b>{len(paths)}</b> parquet file(s) for <b>{shape}</b> / <b>{split}</b>"
    tip_html = f"<div id='badges'><span class='badge'>{badge}</span></div>"
    return gr.Dropdown(choices=paths, value=first_path), tip_html
def get_filtered_df(repo_path: str, size_str: str) -> Tuple[pd.DataFrame, str]:
    """Load *repo_path*, apply the maze-size filter and summarise the result.

    Returns ``(filtered_frame, summary)``, where the summary describes the
    unfiltered frame plus the filtered row count.
    """
    full_df = download_and_load_df(repo_path)
    subset = filter_df_by_maze_size(full_df, size_str)
    return subset, summarize_df(full_df, filtered_len=len(subset))
def on_select_parquet(repo_path: str, size_str: str):
    """Update the summary badge and the index slider for the selected file.

    Returns two ``gr.update`` payloads: the badge HTML, and the slider
    reset to 0 with its maximum set to the last filtered row position
    (0 when nothing is selected or no rows remain).
    """

    def badge(text: str):
        return gr.update(value=f"<div id='badges'><span class='badge'>{text}</span></div>")

    if not repo_path:
        return badge("No parquet selected"), gr.update(maximum=0, value=0)

    filtered, summary = get_filtered_df(repo_path, size_str)
    last_index = max(0, len(filtered) - 1)
    return badge(summary), gr.update(maximum=last_index, value=0)
def on_prev(repo_path: str, index: int, size_str: str):
    """Render the sample just before *index*, stopping at the first row."""
    if repo_path:
        filtered, _ = get_filtered_df(repo_path, size_str)
        target = max(0, int(index) - 1)
        return render_sample_view(filtered, target)
    return 0, "No parquet selected.", "", [], {}, ""
def on_next(repo_path: str, index: int, size_str: str):
    """Render the sample just after *index*, stopping at the last row."""
    if repo_path:
        filtered, _ = get_filtered_df(repo_path, size_str)
        target = min(len(filtered) - 1, int(index) + 1)
        return render_sample_view(filtered, target)
    return 0, "No parquet selected.", "", [], {}, ""
def on_show(repo_path: str, index: int, size_str: str):
    """Render the sample at *index* of the selected, size-filtered file."""
    if repo_path:
        filtered, _ = get_filtered_df(repo_path, size_str)
        return render_sample_view(filtered, index)
    return 0, "No parquet selected.", "", [], {}, ""
def on_random(repo_path: str, size_str: str):
    """Render a uniformly random sample of the selected, size-filtered file."""
    if not repo_path:
        return 0, "No parquet selected.", "", [], {}, ""
    filtered, _ = get_filtered_df(repo_path, size_str)
    row_count = len(filtered)
    # With no rows there is nothing to draw; the view renders its placeholder.
    choice = random.randint(0, row_count - 1) if row_count else 0
    return render_sample_view(filtered, choice)
def on_find_id(repo_path: str, query_id: str, size_str: str):
    """Render the sample whose id matches *query_id* (exact, then substring).

    When nothing matches, the first sample (or the empty-frame placeholder)
    is rendered with a "not found" warning appended to the status text.
    """
    if not repo_path:
        return 0, "No parquet selected.", "", [], {}, ""

    filtered, _ = get_filtered_df(repo_path, size_str)
    needle = query_id.strip() if isinstance(query_id, str) else ""
    pos = find_index_by_id(filtered, needle)
    if pos is not None:
        return render_sample_view(filtered, pos)

    view = list(render_sample_view(filtered, 0))
    status = view[1]
    if not isinstance(status, str):
        # The empty-frame view may carry its status as a gr.update(...)
        # payload (a dict); unwrap it so the warning can be appended
        # instead of raising TypeError on dict + str.
        status = status.get("value", "") if isinstance(status, dict) else ""
    view[1] = status + f" \n⚠️ id search `{query_id}` not found"
    return tuple(view)
| # ------------------------- | |
| # UI (styled) | |
| # ------------------------- | |
# Custom stylesheet handed to gr.Blocks(css=...); selectors target the
# elem_id values assigned while building the UI.
CSS = """
/* 使用系统默认字体 */
.gradio-container { font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif !important; }
/* 全局:页面居中 + 不要铺满 */
.gradio-container { max-width: 1200px !important; margin: 0 auto !important; }
/* 顶部控制卡片:紧凑、没有大灰底空白 */
#topbar {
padding: 12px 14px;
border-radius: 16px;
background: var(--block-background-fill);
border: 1px solid var(--border-color-primary);
}
#topbar .gr-row { flex-wrap: wrap; gap: 10px; }
#topbar .gr-form { margin-bottom: 0 !important; }
/* 输入/下拉更紧凑 */
#topbar input, #topbar textarea, #topbar .wrap { border-radius: 12px !important; }
/* 按钮统一,不要变成右侧巨大菜单 */
#topbar button { height: 42px !important; border-radius: 12px !important; }
/* badges */
#badges { display: flex; gap: 10px; flex-wrap: wrap; align-items: center; }
.badge {
padding: 6px 10px;
border-radius: 999px;
border: 1px solid var(--border-color-primary);
background: var(--background-fill-secondary);
font-size: 13px;
line-height: 1.2;
}
/* Index 一行,按钮单独一行并向下留间距 */
#toolbar .gr-row { align-items: end; }
#toolbar-btns { margin-top: 12px; }
#toolbar-btns .gr-row { align-items: end; }
/* Gallery 更像 viewer */
#viewer { margin-top: 10px; }
"""
# Gradio "Soft" theme with the large radius preset and the medium text preset.
THEME = gr.themes.Soft(
    text_size=gr.themes.sizes.text_md,
    radius_size=gr.themes.sizes.radius_lg,
)
def build_ui():
    """Assemble the Gradio Blocks app: controls, viewer panes and event wiring.

    Returns:
        The un-launched ``gr.Blocks`` instance.
    """
    # Intro text shown at the top of the page. The dataset link is derived
    # from DATASET_REPO_ID so it cannot drift from the repo being browsed.
    intro_md = "\n".join(
        [
            "# Amaze",
            f"Dataset: https://huggingface.co/datasets/{DATASET_REPO_ID}",
            "Amaze is a benchmark for Editing-as-Reasoning task (EAR). It features four maze shapes: circle, hexagon, square, and triangle. Each sample provides: an unmarked maze image (original_img), a maze image with start and end points marked (m_original_img), a blue solution path image (sol_img), a binary path mask (mask_img), a cell segmentation map (cell_map), and metadata (JSON) for describing the maze structure and difficulty.",
            "The test set covers various sizes from 3×3 to 16×16 (50 samples for each size), while the training set mainly consists of 3×3 mazes (1024 samples), and validation set consists of 3×3 mazes (256 samples).",
            "Browse samples by **shape / split / maze size**, then view images + metadata.",
        ]
    )

    def show_first(repo_path, size_str):
        # After the file or the size filter changes, restart at row 0.
        # on_show already returns the placeholder when nothing is selected.
        return on_show(repo_path, 0, size_str)

    with gr.Blocks(title="Amaze Viewer", theme=THEME, css=CSS) as demo:
        gr.Markdown(intro_md)
        records_state = gr.State([])

        # Top control bar (compact card)
        with gr.Column(elem_id="topbar"):
            with gr.Row():
                parquet_tip = gr.HTML(value="<div id='badges'></div>")
                summary_badge = gr.HTML(
                    value="<div id='badges'><span class='badge'>No parquet selected</span></div>"
                )
                scan_info = gr.HTML(
                    value="<div id='badges'><span class='badge'>Indexing dataset repo…</span></div>"
                )
            with gr.Row():
                shape_dd = gr.Dropdown(label="Shape", choices=SHAPES, value="circle", scale=1)
                split_dd = gr.Dropdown(label="Split", choices=SPLITS, value="test", scale=1)
                size_dd = gr.Dropdown(label="Maze size", choices=MAZE_SIZE_CHOICES, value="All", scale=1)
                parquet_dd = gr.Dropdown(label="Parquet", choices=[], value=None, scale=2)
            with gr.Row(elem_id="toolbar"):
                id_query = gr.Textbox(label="Find by id", placeholder="UUID or substring", scale=2)
                idx_slider = gr.Slider(label="Index", minimum=0, maximum=0, value=0, step=1, scale=2)
            with gr.Row():
                prev_btn = gr.Button("⬅ Prev", variant="secondary", scale=1)
                next_btn = gr.Button("Next ➡", variant="secondary", scale=1)
                random_btn = gr.Button("🎲 Random", variant="primary", scale=1)
                find_btn = gr.Button("🔎 Find", variant="secondary", scale=1)
                show_btn = gr.Button("Show", variant="secondary", scale=1)

        # Main viewer layout
        with gr.Row(elem_id="viewer"):
            with gr.Column(scale=3):
                status_md = gr.Markdown(elem_id="status")
                gallery = gr.Gallery(
                    label="Images",
                    columns=2,
                    height=520,
                    object_fit="contain",
                    preview=True,
                )
            with gr.Column(scale=2):
                instruction = gr.Textbox(label="Instruction", lines=6, interactive=False)
                with gr.Accordion("Metadata (parsed JSON)", open=True):
                    meta_json = gr.JSON()
                with gr.Accordion("Metadata (raw)", open=False):
                    meta_raw = gr.Textbox(lines=10, interactive=False)

        # ---- events ----
        # Component groups shared by several listeners.
        selector_inputs = [records_state, shape_dd, split_dd]
        selector_outputs = [parquet_dd, parquet_tip]
        file_inputs = [parquet_dd, size_dd]
        file_outputs = [summary_badge, idx_slider]
        nav_inputs = [parquet_dd, idx_slider, size_dd]
        view_outputs = [idx_slider, status_md, instruction, gallery, meta_json, meta_raw]

        # Page load: index the repo, fill the parquet dropdown, then show row 0.
        demo.load(
            fn=init_app,
            inputs=None,
            outputs=[records_state, scan_info],
        ).then(
            fn=on_shape_split_change,
            inputs=selector_inputs,
            outputs=selector_outputs,
        ).then(
            fn=on_select_parquet,
            inputs=file_inputs,
            outputs=file_outputs,
        ).then(
            fn=show_first,
            inputs=file_inputs,
            outputs=view_outputs,
        )

        # Shape / split only repopulate the parquet dropdown.
        for selector in (shape_dd, split_dd):
            selector.change(
                fn=on_shape_split_change,
                inputs=selector_inputs,
                outputs=selector_outputs,
            )

        # A new file or size filter resets the slider and shows row 0.
        for selector in (parquet_dd, size_dd):
            selector.change(
                fn=on_select_parquet,
                inputs=file_inputs,
                outputs=file_outputs,
            ).then(
                fn=show_first,
                inputs=file_inputs,
                outputs=view_outputs,
            )

        show_btn.click(fn=on_show, inputs=nav_inputs, outputs=view_outputs)
        idx_slider.release(fn=on_show, inputs=nav_inputs, outputs=view_outputs)
        prev_btn.click(fn=on_prev, inputs=nav_inputs, outputs=view_outputs)
        next_btn.click(fn=on_next, inputs=nav_inputs, outputs=view_outputs)
        random_btn.click(fn=on_random, inputs=file_inputs, outputs=view_outputs)
        find_btn.click(
            fn=on_find_id,
            inputs=[parquet_dd, id_query, size_dd],
            outputs=view_outputs,
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app, then start the Gradio server.
    demo = build_ui()
    demo.launch()