ddebree committed · Commit f9306c2 · verified · 1 Parent(s): 683fd03

Upload folder using huggingface_hub

Dockerfile ADDED
@@ -0,0 +1,24 @@
+ FROM python:3.11-slim
+
+ ENV PYTHONDONTWRITEBYTECODE=1
+ ENV PYTHONUNBUFFERED=1
+ ENV HF_HOME=/data/.cache/huggingface
+ ENV STREAMLIT_SERVER_HEADLESS=true
+ ENV STREAMLIT_BROWSER_GATHER_USAGE_STATS=false
+
+ WORKDIR /app
+
+ RUN apt-get update \
+     && apt-get install -y --no-install-recommends git \
+     && rm -rf /var/lib/apt/lists/*
+
+ COPY pyproject.toml README.md requirements.txt ./
+ COPY src ./src
+ COPY app.py ./
+
+ RUN pip install --no-cache-dir --upgrade pip \
+     && pip install --no-cache-dir -r requirements.txt
+
+ EXPOSE 8501
+
+ CMD ["streamlit", "run", "app.py", "--server.address=0.0.0.0", "--server.port=8501"]
README.md CHANGED
@@ -1,10 +1,144 @@
  ---
- title: Mathvision Jepa Explorer
- emoji: 💻
- colorFrom: gray
- colorTo: red
+ title: MathVision JEPA Explorer
+ emoji: 🔎
+ colorFrom: blue
+ colorTo: green
  sdk: docker
  pinned: false
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # MathVision Explorer
+
+ A typed Python starting point for a **MathVision + JEPA** explorer. The first version works
+ locally with MathVision-style JSONL files and a simple, testable image embedder.
+ After that, you can swap the embedder for V-JEPA features.
+
+ ## Installation
+
+ ```powershell
+ uv sync --dev
+ ```
+
+ Or via Make:
+
+ ```bash
+ make sync
+ ```
+
+ With V-JEPA support:
+
+ ```powershell
+ make sync-jepa
+ ```
+
+ Set up everything, including Streamlit, V-JEPA, demo data, gallery, and checks:
+
+ ```bash
+ make ready
+ ```
+
+ With the Streamlit UI:
+
+ ```bash
+ make sync-app
+ ```
+
+ ## Expected JSONL format
+
+ Each line is one problem:
+
+ ```json
+ {"id":"mv-001","question":"How many cubes are visible?","answer":"7","image":"images/mv-001.png","subject":"geometry","level":2}
+ ```
+
+ Supported fields:
+
+ - `id` or `problem_id`
+ - `question`
+ - `answer`
+ - `image`
+ - `options`
+ - `subject`
+ - `level`
+ - `solution`
+
+ ## Usage
+
+ First create a small demo you can look at:
+
+ ```powershell
+ uv run mathvision demo data/demo
+ uv run mathvision export-html data/demo/demo.jsonl artifacts/demo.html
+ ```
+
+ The same via Make:
+
+ ```bash
+ make demo
+ make gallery
+ make app
+ ```
+
+ Inspect an export:
+
+ ```powershell
+ uv run mathvision inspect data/mathvision-sample.jsonl
+ ```
+
+ Build a local image index:
+
+ ```powershell
+ uv run mathvision index data/mathvision-sample.jsonl artifacts/image-index.tsv
+ ```
+
+ Build a V-JEPA index:
+
+ ```powershell
+ uv run mathvision index data/mathvision-sample.jsonl artifacts/jepa-index.tsv --embedder vjepa
+ ```
+
+ Search for visually similar problems:
+
+ ```powershell
+ uv run mathvision search data/mathvision-sample.jsonl artifacts/image-index.tsv mv-001 --limit 5
+ ```
+
+ Search with the same V-JEPA embedder:
+
+ ```powershell
+ uv run mathvision search data/mathvision-sample.jsonl artifacts/jepa-index.tsv mv-001 --embedder vjepa --limit 5
+ ```
+
+ ## V-JEPA integration
+
+ The `mathvision_explorer.embeddings` module defines an `ImageEmbedder` protocol. The
+ bundled `ColorStatsEmbedder` is deliberately simple so that tests run fast and offline.
+ `VJepaImageEmbedder` uses `facebook/vjepa2-vitl-fpc64-256` by default via Hugging Face
+ Transformers and turns a still image into a short repeated-frame video input for
+ `get_vision_features`.
+
+ ## Quality
+
+ ```powershell
+ uv run pytest
+ uv run mypy
+ uv run ruff check .
+ ```
+
+ Or:
+
+ ```bash
+ make check
+ ```
+
+ ## Hugging Face Spaces
+
+ This repo is set up as a Docker Space. The hosted entrypoint is `app.py`, which creates
+ the demo data and then starts the Streamlit explorer.
+
+ Important for V-JEPA:
+
+ - Docker uses CPU-only PyTorch wheels.
+ - Transformers is installed from the Hugging Face main branch because V-JEPA 2 needs the
+   recent support available there.
+ - The first time `vjepa` is selected, the app downloads the model.
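Reviewer sketch (not part of this commit): a minimal, standard-library-only illustration of reading one record in the JSONL format the README describes. The helper name `problem_id_of` is hypothetical; the fallback from `id` to `problem_id` is an assumption based on the README's "`id` or `problem_id`" field.

```python
import json

# The sample line from the README, split only for readability.
SAMPLE = (
    '{"id":"mv-001","question":"How many cubes are visible?","answer":"7",'
    '"image":"images/mv-001.png","subject":"geometry","level":2}'
)


def problem_id_of(payload: dict) -> str:
    """Return `id`, falling back to `problem_id` (hypothetical helper)."""
    value = payload.get("id")
    if value is None:
        value = payload.get("problem_id")
    if not isinstance(value, str) or not value.strip():
        raise ValueError("record needs a non-empty `id` or `problem_id`")
    return value


record = json.loads(SAMPLE)
print(problem_id_of(record), record["subject"], record["level"])
```

A record that only carries `problem_id` resolves the same way, so exports using either key can share one loader.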
app.py ADDED
@@ -0,0 +1,11 @@
+ """Hosted Streamlit entrypoint for Hugging Face Spaces."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+
+ from mathvision_explorer.demo import create_demo_dataset
+ from mathvision_explorer.streamlit_app import main
+
+ create_demo_dataset(Path("data/demo"))
+ main()
packages.txt ADDED
@@ -0,0 +1 @@
+ git
pyproject.toml ADDED
@@ -0,0 +1,59 @@
+ [project]
+ name = "mathvision-explorer"
+ version = "0.1.0"
+ description = "Typed tools for exploring MathVision-style visual math datasets with JEPA-ready embeddings."
+ readme = "README.md"
+ requires-python = ">=3.11"
+ authors = [{ name = "MathVision Explorer" }]
+ dependencies = [
+     "pillow>=10.0",
+ ]
+
+ [project.optional-dependencies]
+ app = [
+     "streamlit>=1.35",
+ ]
+ jepa = [
+     "numpy>=1.26",
+     "torch>=2.4",
+     "torchvision>=0.19",
+ ]
+
+ [project.scripts]
+ mathvision = "mathvision_explorer.cli:main"
+
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [dependency-groups]
+ dev = [
+     "mypy>=1.10",
+     "pytest>=8.0",
+     "ruff>=0.5",
+ ]
+
+ [tool.pytest.ini_options]
+ addopts = "-q"
+ testpaths = ["tests"]
+
+ [tool.uv.sources]
+ torch = { index = "pytorch-cpu" }
+ torchvision = { index = "pytorch-cpu" }
+
+ [[tool.uv.index]]
+ name = "pytorch-cpu"
+ url = "https://download.pytorch.org/whl/cpu"
+ explicit = true
+
+ [tool.ruff]
+ line-length = 100
+ target-version = "py311"
+
+ [tool.ruff.lint]
+ select = ["E", "F", "I", "UP", "B", "SIM"]
+
+ [tool.mypy]
+ python_version = "3.11"
+ strict = true
+ files = ["src", "tests"]
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ -e .
+ streamlit>=1.35
+ numpy>=1.26
+ torch>=2.4
+ torchvision>=0.19
+ git+https://github.com/huggingface/transformers
src/mathvision_explorer/__init__.py ADDED
@@ -0,0 +1,6 @@
+ """Utilities for exploring visual math problems and JEPA-style embeddings."""
+
+ from mathvision_explorer.dataset import MathVisionRecord, load_jsonl_records
+ from mathvision_explorer.index import Neighbor, VectorIndex
+
+ __all__ = ["MathVisionRecord", "Neighbor", "VectorIndex", "load_jsonl_records"]
src/mathvision_explorer/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (494 Bytes)
src/mathvision_explorer/__pycache__/cli.cpython-313.pyc ADDED
Binary file (7.01 kB)
src/mathvision_explorer/__pycache__/dataset.cpython-313.pyc ADDED
Binary file (7.04 kB)
src/mathvision_explorer/__pycache__/demo.cpython-313.pyc ADDED
Binary file (8.83 kB)
src/mathvision_explorer/__pycache__/embeddings.cpython-313.pyc ADDED
Binary file (8.24 kB)
src/mathvision_explorer/__pycache__/explorer.cpython-313.pyc ADDED
Binary file (2.42 kB)
src/mathvision_explorer/__pycache__/html.cpython-313.pyc ADDED
Binary file (5.29 kB)
src/mathvision_explorer/__pycache__/index.cpython-313.pyc ADDED
Binary file (6.59 kB)
src/mathvision_explorer/__pycache__/similarity.cpython-313.pyc ADDED
Binary file (3.49 kB)
src/mathvision_explorer/cli.py ADDED
@@ -0,0 +1,132 @@
+ """Command-line interface for MathVision Explorer."""
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ from pathlib import Path
+
+ from mathvision_explorer.dataset import load_jsonl_records, summarize_records
+ from mathvision_explorer.demo import create_demo_dataset
+ from mathvision_explorer.embeddings import (
+     ColorStatsEmbedder,
+     ImageEmbedder,
+     VJepaImageEmbedder,
+     embed_record_image,
+ )
+ from mathvision_explorer.explorer import build_image_index
+ from mathvision_explorer.html import export_html
+ from mathvision_explorer.index import VectorIndex
+
+
+ def main() -> None:
+     """Run the MathVision Explorer command-line interface."""
+
+     parser = argparse.ArgumentParser(prog="mathvision")
+     subparsers = parser.add_subparsers(dest="command", required=True)
+
+     inspect_parser = subparsers.add_parser("inspect", help="Inspect a MathVision-style JSONL file.")
+     inspect_parser.add_argument("jsonl", type=Path)
+
+     demo_parser = subparsers.add_parser("demo", help="Create a tiny local demo dataset.")
+     demo_parser.add_argument("output_dir", type=Path)
+
+     html_parser = subparsers.add_parser("export-html", help="Export records to a browser gallery.")
+     html_parser.add_argument("jsonl", type=Path)
+     html_parser.add_argument("output", type=Path)
+
+     index_parser = subparsers.add_parser("index", help="Build a local image-feature index.")
+     index_parser.add_argument("jsonl", type=Path)
+     index_parser.add_argument("output", type=Path)
+     _add_embedder_arguments(index_parser)
+
+     search_parser = subparsers.add_parser("search", help="Search similar indexed records.")
+     search_parser.add_argument("jsonl", type=Path)
+     search_parser.add_argument("index", type=Path)
+     search_parser.add_argument("query_id")
+     search_parser.add_argument("--limit", type=int, default=5)
+     _add_embedder_arguments(search_parser)
+
+     args = parser.parse_args()
+     if args.command == "inspect":
+         _inspect(args.jsonl)
+     elif args.command == "demo":
+         _demo(args.output_dir)
+     elif args.command == "export-html":
+         _export_html(args.jsonl, args.output)
+     elif args.command == "index":
+         _index(args.jsonl, args.output, embedder=_embedder_from_args(args))
+     elif args.command == "search":
+         _search(
+             args.jsonl,
+             args.index,
+             args.query_id,
+             limit=args.limit,
+             embedder=_embedder_from_args(args),
+         )
+
+
+ def _inspect(jsonl: Path) -> None:
+     records = load_jsonl_records(jsonl)
+     print(json.dumps(summarize_records(records), indent=2, sort_keys=True))
+
+
+ def _demo(output_dir: Path) -> None:
+     jsonl_path = create_demo_dataset(output_dir)
+     print(f"Wrote demo dataset to {jsonl_path}")
+
+
+ def _export_html(jsonl: Path, output: Path) -> None:
+     records = load_jsonl_records(jsonl)
+     export_html(records, output)
+     print(f"Wrote gallery to {output}")
+
+
+ def _index(jsonl: Path, output: Path, *, embedder: ImageEmbedder) -> None:
+     records = load_jsonl_records(jsonl)
+     index = build_image_index(records, embedder)
+     index.save_tsv(output)
+     print(f"Wrote {len(index)} vectors to {output}")
+
+
+ def _search(
+     jsonl: Path,
+     index_path: Path,
+     query_id: str,
+     *,
+     limit: int,
+     embedder: ImageEmbedder,
+ ) -> None:
+     records = load_jsonl_records(jsonl)
+     record_by_id = {record.problem_id: record for record in records}
+     query_record = record_by_id.get(query_id)
+     if query_record is None:
+         raise SystemExit(f"Unknown query id: {query_id}")
+
+     query_vector = embed_record_image(query_record.image_path, embedder)
+     index = VectorIndex.load_tsv(index_path)
+     for neighbor in index.search(query_vector, limit=limit, exclude_id=query_id):
+         record = record_by_id.get(neighbor.item_id)
+         label = record.question if record is not None else neighbor.item_id
+         print(f"{neighbor.score:.4f}\t{neighbor.item_id}\t{label}")
+
+
+ def _add_embedder_arguments(parser: argparse.ArgumentParser) -> None:
+     parser.add_argument("--embedder", choices=["color", "vjepa"], default="color")
+     parser.add_argument("--jepa-model", default="facebook/vjepa2-vitl-fpc64-256")
+     parser.add_argument("--jepa-device", default=None)
+     parser.add_argument("--jepa-frames", type=int, default=16)
+
+
+ def _embedder_from_args(args: argparse.Namespace) -> ImageEmbedder:
+     if args.embedder == "vjepa":
+         return VJepaImageEmbedder(
+             model_id=args.jepa_model,
+             device=args.jepa_device,
+             frame_count=args.jepa_frames,
+         )
+     return ColorStatsEmbedder()
+
+
+ if __name__ == "__main__":
+     main()
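Reviewer sketch (not part of this commit): a trimmed-down copy of the `search` subparser, to show how the subcommand layout resolves arguments. The `build_parser` function is a hypothetical name; only `argparse` calls already used in `cli.py` appear here, and the argument list passed to `parse_args` is an illustrative assumption.

```python
import argparse
from pathlib import Path


def build_parser() -> argparse.ArgumentParser:
    """Build a reduced parser with only the `search` subcommand (hypothetical helper)."""
    parser = argparse.ArgumentParser(prog="mathvision")
    subparsers = parser.add_subparsers(dest="command", required=True)
    search = subparsers.add_parser("search")
    search.add_argument("jsonl", type=Path)
    search.add_argument("index", type=Path)
    search.add_argument("query_id")
    search.add_argument("--limit", type=int, default=5)
    search.add_argument("--embedder", choices=["color", "vjepa"], default="color")
    return parser


# Passing an explicit list avoids reading sys.argv, which keeps the parser testable.
args = build_parser().parse_args(
    ["search", "data/demo/demo.jsonl", "artifacts/image-index.tsv", "demo-clock", "--limit", "3"]
)
print(args.command, args.query_id, args.limit, args.embedder)
```

Because `--embedder` defaults to `color`, a search against a V-JEPA index has to pass `--embedder vjepa` explicitly, as the README's last search command does.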
src/mathvision_explorer/dataset.py ADDED
@@ -0,0 +1,143 @@
+ """Dataset loading helpers for MathVision-like JSONL exports."""
+
+ from __future__ import annotations
+
+ import json
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any
+
+
+ @dataclass(frozen=True, slots=True)
+ class MathVisionRecord:
+     """A single visual math problem with optional image and solution metadata."""
+
+     problem_id: str
+     question: str
+     answer: str
+     subject: str | None = None
+     level: int | None = None
+     image_path: Path | None = None
+     options: tuple[str, ...] = ()
+     solution: str | None = None
+
+
+ def load_jsonl_records(path: Path) -> list[MathVisionRecord]:
+     """Load MathVision-like records from a UTF-8 JSONL file."""
+
+     records: list[MathVisionRecord] = []
+     with path.open("r", encoding="utf-8") as jsonl_file:
+         for line_number, line in enumerate(jsonl_file, start=1):
+             stripped = line.strip()
+             if not stripped:
+                 continue
+             payload = json.loads(stripped)
+             if not isinstance(payload, dict):
+                 msg = f"Line {line_number} must contain a JSON object."
+                 raise ValueError(msg)
+             records.append(record_from_mapping(payload, source_dir=path.parent))
+     return records
+
+
+ def record_from_mapping(
+     payload: dict[str, Any], *, source_dir: Path | None = None
+ ) -> MathVisionRecord:
+     """Create a typed record from a raw dictionary."""
+
+     problem_id = _required_string(payload, "id", fallback_key="problem_id")
+     question = _required_string(payload, "question")
+     answer = _required_string(payload, "answer")
+     image_path = _optional_path(payload.get("image"), source_dir=source_dir)
+     options = _options_from_value(payload.get("options"))
+
+     return MathVisionRecord(
+         problem_id=problem_id,
+         question=question,
+         answer=answer,
+         subject=_optional_string(payload.get("subject")),
+         level=_optional_int(payload.get("level")),
+         image_path=image_path,
+         options=options,
+         solution=_optional_string(payload.get("solution")),
+     )
+
+
+ def filter_records(
+     records: list[MathVisionRecord],
+     *,
+     subject: str | None = None,
+     level: int | None = None,
+ ) -> list[MathVisionRecord]:
+     """Return records matching optional subject and level filters."""
+
+     return [
+         record
+         for record in records
+         if (subject is None or record.subject == subject)
+         and (level is None or record.level == level)
+     ]
+
+
+ def summarize_records(records: list[MathVisionRecord]) -> dict[str, object]:
+     """Build a compact summary for CLI output or dashboards."""
+
+     subjects = sorted({record.subject for record in records if record.subject is not None})
+     levels = sorted({record.level for record in records if record.level is not None})
+     image_count = sum(1 for record in records if record.image_path is not None)
+     return {
+         "records": len(records),
+         "images": image_count,
+         "subjects": subjects,
+         "levels": levels,
+     }
+
+
+ def _required_string(
+     payload: dict[str, Any], key: str, *, fallback_key: str | None = None
+ ) -> str:
+     value = payload.get(key)
+     if value is None and fallback_key is not None:
+         value = payload.get(fallback_key)
+     if not isinstance(value, str) or not value.strip():
+         msg = f"Missing required string field: {key}"
+         raise ValueError(msg)
+     return value
+
+
+ def _optional_string(value: object) -> str | None:
+     if value is None:
+         return None
+     if not isinstance(value, str):
+         msg = "Optional text fields must be strings when present."
+         raise ValueError(msg)
+     return value
+
+
+ def _optional_int(value: object) -> int | None:
+     if value is None:
+         return None
+     if isinstance(value, bool) or not isinstance(value, int):
+         msg = "Level must be an integer when present."
+         raise ValueError(msg)
+     return value
+
+
+ def _optional_path(value: object, *, source_dir: Path | None) -> Path | None:
+     if value is None:
+         return None
+     if not isinstance(value, str) or not value:
+         msg = "Image path must be a non-empty string when present."
+         raise ValueError(msg)
+     image_path = Path(value)
+     if source_dir is not None and not image_path.is_absolute():
+         return source_dir / image_path
+     return image_path
+
+
+ def _options_from_value(value: object) -> tuple[str, ...]:
+     if value is None:
+         return ()
+     if not isinstance(value, list) or not all(isinstance(option, str) for option in value):
+         msg = "Options must be a list of strings when present."
+         raise ValueError(msg)
+     return tuple(value)
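Reviewer sketch (not part of this commit): why `_optional_int` checks for `bool` before `int`. In Python, `bool` is a subclass of `int`, so a JSON `true` would otherwise pass as a level. The function below is a standalone copy under the hypothetical name `optional_int`, and the payload is an invented example.

```python
from __future__ import annotations

import json


def optional_int(value: object) -> int | None:
    """Accept real integers only; reject booleans and every other type."""
    if value is None:
        return None
    # The bool check must come first: isinstance(True, int) is True.
    if isinstance(value, bool) or not isinstance(value, int):
        raise ValueError("Level must be an integer when present.")
    return value


payload = json.loads('{"level": true}')  # JSON true decodes to Python True
looks_like_int = isinstance(payload["level"], int)

try:
    optional_int(payload["level"])
    rejected = False
except ValueError:
    rejected = True

print(looks_like_int, rejected)
```

Without the explicit `bool` guard, `{"level": true}` would load as level `True`, which compares equal to `1` and would silently match a `level=1` filter.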
src/mathvision_explorer/demo.py ADDED
@@ -0,0 +1,229 @@
+ """Demo dataset generation for trying the explorer locally."""
+
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+
+ from PIL import Image, ImageDraw
+
+
+ def create_demo_dataset(output_dir: Path) -> Path:
+     """Create a tiny MathVision-like dataset with simple generated images."""
+
+     image_dir = output_dir / "images"
+     image_dir.mkdir(parents=True, exist_ok=True)
+
+     records = [
+         {
+             "id": "demo-red-squares",
+             "question": "How many red squares are visible?",
+             "answer": "4",
+             "image": "images/red-squares.png",
+             "options": ["3", "4", "5"],
+             "subject": "counting",
+             "level": 1,
+             "solution": "Count the four red square tiles.",
+         },
+         {
+             "id": "demo-red-squares-small",
+             "question": "How many small red squares are visible?",
+             "answer": "5",
+             "image": "images/red-squares-small.png",
+             "options": ["4", "5", "6"],
+             "subject": "counting",
+             "level": 2,
+             "solution": "The red tiles form a group of five.",
+         },
+         {
+             "id": "demo-blue-triangles",
+             "question": "How many blue triangles are visible?",
+             "answer": "3",
+             "image": "images/blue-triangles.png",
+             "options": ["2", "3", "4"],
+             "subject": "geometry",
+             "level": 1,
+             "solution": "There are three separate blue triangles.",
+         },
+         {
+             "id": "demo-blue-pyramids",
+             "question": "Which shape appears repeatedly?",
+             "answer": "triangle",
+             "image": "images/blue-pyramids.png",
+             "options": ["circle", "triangle", "square"],
+             "subject": "geometry",
+             "level": 2,
+             "solution": "The repeated blue shapes are triangles.",
+         },
+         {
+             "id": "demo-red-grid",
+             "question": "Which tile color dominates the grid?",
+             "answer": "red",
+             "image": "images/red-grid.png",
+             "options": ["red", "blue", "green"],
+             "subject": "pattern",
+             "level": 2,
+             "solution": "Most grid cells are red.",
+         },
+         {
+             "id": "demo-green-grid",
+             "question": "Which tile color dominates this grid?",
+             "answer": "green",
+             "image": "images/green-grid.png",
+             "options": ["red", "blue", "green"],
+             "subject": "pattern",
+             "level": 2,
+             "solution": "Green appears in most grid cells.",
+         },
+         {
+             "id": "demo-number-line",
+             "question": "Which point is closest to 4?",
+             "answer": "C",
+             "image": "images/number-line.png",
+             "options": ["A", "B", "C"],
+             "subject": "algebra",
+             "level": 1,
+             "solution": "Point C is drawn nearest to the tick labeled 4.",
+         },
+         {
+             "id": "demo-clock",
+             "question": "Which hour does the short hand point to?",
+             "answer": "3",
+             "image": "images/clock.png",
+             "options": ["2", "3", "4"],
+             "subject": "measurement",
+             "level": 1,
+             "solution": "The shorter hand points toward 3.",
+         },
+     ]
+
+     _draw_red_squares(image_dir / "red-squares.png")
+     _draw_red_squares_small(image_dir / "red-squares-small.png")
+     _draw_blue_triangles(image_dir / "blue-triangles.png")
+     _draw_blue_pyramids(image_dir / "blue-pyramids.png")
+     _draw_red_grid(image_dir / "red-grid.png")
+     _draw_green_grid(image_dir / "green-grid.png")
+     _draw_number_line(image_dir / "number-line.png")
+     _draw_clock(image_dir / "clock.png")
+
+     jsonl_path = output_dir / "demo.jsonl"
+     with jsonl_path.open("w", encoding="utf-8") as jsonl_file:
+         for record in records:
+             jsonl_file.write(json.dumps(record, sort_keys=True))
+             jsonl_file.write("\n")
+     return jsonl_path
+
+
+ def _new_canvas() -> Image.Image:
+     return Image.new("RGB", (420, 280), color=(248, 250, 252))
+
+
+ def _draw_red_squares(path: Path) -> None:
+     image = _new_canvas()
+     draw = ImageDraw.Draw(image)
+     for x, y in [(80, 60), (170, 60), (80, 150), (170, 150)]:
+         draw.rectangle((x, y, x + 58, y + 58), fill=(220, 38, 38), outline=(127, 29, 29), width=3)
+     image.save(path)
+
+
+ def _draw_blue_triangles(path: Path) -> None:
+     image = _new_canvas()
+     draw = ImageDraw.Draw(image)
+     triangles = [
+         [(90, 190), (130, 80), (170, 190)],
+         [(190, 190), (230, 80), (270, 190)],
+         [(290, 190), (330, 80), (370, 190)],
+     ]
+     for triangle in triangles:
+         draw.polygon(triangle, fill=(37, 99, 235), outline=(30, 64, 175))
+     image.save(path)
+
+
+ def _draw_red_squares_small(path: Path) -> None:
+     image = _new_canvas()
+     draw = ImageDraw.Draw(image)
+     for x, y in [(78, 54), (148, 54), (218, 54), (112, 134), (184, 134)]:
+         draw.rectangle((x, y, x + 46, y + 46), fill=(239, 68, 68), outline=(127, 29, 29), width=3)
+     image.save(path)
+
+
+ def _draw_blue_pyramids(path: Path) -> None:
+     image = _new_canvas()
+     draw = ImageDraw.Draw(image)
+     for x, y, size in [(82, 178, 52), (162, 178, 68), (262, 178, 82)]:
+         draw.polygon(
+             [(x, y), (x + size // 2, y - size), (x + size, y)],
+             fill=(59, 130, 246),
+             outline=(30, 64, 175),
+         )
+     image.save(path)
+
+
+ def _draw_red_grid(path: Path) -> None:
+     image = _new_canvas()
+     draw = ImageDraw.Draw(image)
+     colors = [
+         (220, 38, 38),
+         (220, 38, 38),
+         (22, 163, 74),
+         (220, 38, 38),
+         (37, 99, 235),
+         (220, 38, 38),
+     ]
+     for index, color in enumerate(colors):
+         row, column = divmod(index, 3)
+         x = 92 + column * 82
+         y = 64 + row * 82
+         draw.rectangle((x, y, x + 64, y + 64), fill=color, outline=(15, 23, 42), width=2)
+     image.save(path)
+
+
+ def _draw_green_grid(path: Path) -> None:
+     image = _new_canvas()
+     draw = ImageDraw.Draw(image)
+     colors = [
+         (22, 163, 74),
+         (22, 163, 74),
+         (220, 38, 38),
+         (22, 163, 74),
+         (37, 99, 235),
+         (22, 163, 74),
+     ]
+     for index, color in enumerate(colors):
+         row, column = divmod(index, 3)
+         x = 92 + column * 82
+         y = 64 + row * 82
+         draw.rectangle((x, y, x + 64, y + 64), fill=color, outline=(15, 23, 42), width=2)
+     image.save(path)
+
+
+ def _draw_number_line(path: Path) -> None:
+     image = _new_canvas()
+     draw = ImageDraw.Draw(image)
+     draw.line((62, 148, 358, 148), fill=(15, 23, 42), width=4)
+     for index in range(6):
+         x = 62 + index * 59
+         draw.line((x, 134, x, 162), fill=(15, 23, 42), width=3)
+         draw.text((x - 5, 170), str(index), fill=(15, 23, 42))
+     points = [
+         ("A", 174, (37, 99, 235)),
+         ("B", 246, (22, 163, 74)),
+         ("C", 296, (220, 38, 38)),
+     ]
+     for label, x, color in points:
+         draw.ellipse((x - 9, 108, x + 9, 126), fill=color)
+         draw.text((x - 5, 86), label, fill=(15, 23, 42))
+     image.save(path)
+
+
+ def _draw_clock(path: Path) -> None:
+     image = _new_canvas()
+     draw = ImageDraw.Draw(image)
+     center = (210, 140)
+     draw.ellipse((100, 30, 320, 250), fill=(255, 255, 255), outline=(15, 23, 42), width=4)
+     for label, xy in [("12", (199, 48)), ("3", (290, 132)), ("6", (205, 220)), ("9", (120, 132))]:
+         draw.text(xy, label, fill=(15, 23, 42))
+     draw.line((center[0], center[1], 282, 140), fill=(220, 38, 38), width=6)
+     draw.line((center[0], center[1], 210, 68), fill=(37, 99, 235), width=4)
+     draw.ellipse((202, 132, 218, 148), fill=(15, 23, 42))
+     image.save(path)
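Reviewer sketch (not part of this commit): the layout arithmetic used by the grid and number-line drawings, pulled out so it can be checked without Pillow. The helper names `tile_origin` and `tick_x` are hypothetical; the constants are copied from `_draw_red_grid` and `_draw_number_line`.

```python
def tile_origin(index: int, *, columns: int = 3) -> tuple[int, int]:
    """Top-left corner of the grid tile at a flat index (row-major order)."""
    row, column = divmod(index, columns)
    return 92 + column * 82, 64 + row * 82


def tick_x(value: int) -> int:
    """Horizontal pixel position of the number-line tick labeled `value`."""
    return 62 + value * 59


# Point positions copied from `_draw_number_line`.
points = {"A": 174, "B": 246, "C": 296}
closest = min(points, key=lambda label: abs(points[label] - tick_x(4)))

print(tile_origin(0), tile_origin(4), tick_x(4), closest)
```

This confirms that the demo record's answer for `demo-number-line` agrees with the drawing: the tick for 4 sits at x = 298 and point C at x = 296 is the nearest of the three.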
src/mathvision_explorer/embeddings.py ADDED
@@ -0,0 +1,133 @@
+ """Embedding helpers for image records.
+
+ The default embedder is intentionally lightweight and deterministic. It gives the
+ project a testable local baseline while leaving room to plug in V-JEPA features later.
+ """
+
+ from __future__ import annotations
+
+ from importlib import import_module
+ from pathlib import Path
+ from typing import Any, Protocol
+
+ from PIL import Image, ImageStat
+
+
+ class ImageEmbedder(Protocol):
+     """Protocol for objects that turn image paths into numeric vectors."""
+
+     def embed_image(self, image_path: Path) -> tuple[float, ...]:
+         """Return an embedding vector for an image file."""
+
+
+ class ColorStatsEmbedder:
+     """Embed images with normalized RGB mean and standard deviation features."""
+
+     def embed_image(self, image_path: Path) -> tuple[float, ...]:
+         """Return six normalized color-statistics features for an image."""
+
+         with Image.open(image_path) as image:
+             rgb_image = image.convert("RGB")
+             stat = ImageStat.Stat(rgb_image)
+         means = tuple(channel / 255.0 for channel in stat.mean)
+         stddevs = tuple(channel / 255.0 for channel in stat.stddev)
+         return means + stddevs
+
+
+ class MissingImageError(RuntimeError):
+     """Raised when a record cannot be embedded because no image path is available."""
+
+
+ class JepaDependencyError(RuntimeError):
+     """Raised when optional V-JEPA dependencies are not installed."""
+
+
+ class VJepaImageEmbedder:
+     """Embed images with a Hugging Face V-JEPA 2 image/video encoder.
+
+     The implementation follows the model-card pattern for image inputs: a still image is
+     processed as video pixels and repeated across frames before `get_vision_features`.
+     """
+
+     def __init__(
+         self,
+         *,
+         model_id: str = "facebook/vjepa2-vitl-fpc64-256",
+         device: str | None = None,
+         frame_count: int = 16,
+     ) -> None:
+         """Load the V-JEPA processor and model when the embedder is constructed."""
+
+         if frame_count < 1:
+             raise ValueError("Frame count must be at least 1.")
+
+         self.model_id = model_id
+         self.frame_count = frame_count
+         self._torch = _import_optional("torch")
+         transformers = _import_optional("transformers")
+         _quiet_transformers_logging(transformers)
+
+         try:
+             self._processor = transformers.AutoVideoProcessor.from_pretrained(model_id)
+             self._model = transformers.AutoModel.from_pretrained(model_id)
+         except ImportError as error:
+             msg = (
+                 "V-JEPA backend dependencies are missing. Install them with "
+                 "`make sync-jepa`."
+             )
+             raise JepaDependencyError(msg) from error
+         self._device = device or ("cuda" if self._torch.cuda.is_available() else "cpu")
+         self._model.to(self._device)
+         self._model.eval()
+
+     def embed_image(self, image_path: Path) -> tuple[float, ...]:
+         """Return a pooled V-JEPA feature vector for an image."""
+
+         with Image.open(image_path) as image:
+             rgb_image = image.convert("RGB")
+
+         encoded = self._processor(rgb_image, return_tensors="pt").to(self._model.device)
+         pixel_values = encoded["pixel_values_videos"]
+         pixel_values = pixel_values.repeat(1, self.frame_count, 1, 1, 1)
+
+         with self._torch.no_grad():
+             features = self._model.get_vision_features(pixel_values)
+             pooled = _mean_pool_features(features)
+
+         return tuple(float(value) for value in pooled.squeeze(0).detach().cpu().tolist())
+
+
+ def _mean_pool_features(features: Any) -> Any:
+     """Pool token/time dimensions while preserving the final feature dimension."""
+
+     if features.ndim <= 2:
+         return features
+     return features.mean(dim=tuple(range(1, features.ndim - 1)))
+
+
+ def embed_record_image(image_path: Path | None, embedder: ImageEmbedder) -> tuple[float, ...]:
+     """Embed a record image or raise a clear error when the path is missing."""
+
+     if image_path is None:
+         raise MissingImageError("Record has no image path to embed.")
+     return embedder.embed_image(image_path)
+
+
+ def _import_optional(module_name: str) -> Any:
+     try:
+         return import_module(module_name)
+     except ImportError as error:
+         msg = (
+             "V-JEPA dependencies are missing. Install them with "
+             "`uv sync --extra jepa --dev`."
+         )
+         raise JepaDependencyError(msg) from error
+
+
+ def _quiet_transformers_logging(transformers: Any) -> None:
+     """Reduce noisy dev-version Transformers compatibility logging."""
+
+     try:
+         transformers.logging.set_verbosity_error()
+     except AttributeError:
+         return
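Reviewer sketch (not part of this commit): how a third embedder could plug into the `ImageEmbedder` protocol through structural typing, without inheriting from it. `FileSizeEmbedder` and `embed_all` are hypothetical names invented for illustration, and the "image" is a placeholder byte string so the example needs only the standard library.

```python
from __future__ import annotations

import tempfile
from pathlib import Path
from typing import Protocol


class ImageEmbedder(Protocol):
    """Local copy of the protocol shape: one method, path in, vector out."""

    def embed_image(self, image_path: Path) -> tuple[float, ...]: ...


class FileSizeEmbedder:
    """Hypothetical toy embedder with a single feature: file size in KiB."""

    def embed_image(self, image_path: Path) -> tuple[float, ...]:
        return (image_path.stat().st_size / 1024.0,)


def embed_all(paths: list[Path], embedder: ImageEmbedder) -> list[tuple[float, ...]]:
    """Embed every path with any object that satisfies the protocol."""
    return [embedder.embed_image(path) for path in paths]


with tempfile.TemporaryDirectory() as tmp:
    fake_image = Path(tmp) / "fake.png"
    fake_image.write_bytes(b"\x00" * 2048)  # placeholder bytes, not a real PNG
    vectors = embed_all([fake_image], FileSizeEmbedder())

print(vectors)
```

The toy class never mentions `ImageEmbedder`, yet a type checker accepts it wherever the protocol is expected. That same mechanism lets `ColorStatsEmbedder` and `VJepaImageEmbedder` be interchangeable in the CLI.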
src/mathvision_explorer/explorer.py ADDED
@@ -0,0 +1,47 @@
+"""High-level workflows for MathVision exploration."""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+from mathvision_explorer.dataset import MathVisionRecord, filter_records, load_jsonl_records
+from mathvision_explorer.embeddings import ImageEmbedder, embed_record_image
+from mathvision_explorer.index import Neighbor, VectorIndex
+
+
+def build_image_index(records: list[MathVisionRecord], embedder: ImageEmbedder) -> VectorIndex:
+    """Build a vector index for all records that have image paths."""
+
+    index = VectorIndex()
+    for record in records:
+        if record.image_path is None:
+            continue
+        index.add(record.problem_id, embed_record_image(record.image_path, embedder))
+    return index
+
+
+def find_similar_records(
+    records: list[MathVisionRecord],
+    index: VectorIndex,
+    query_id: str,
+    query_vector: tuple[float, ...],
+    *,
+    limit: int = 5,
+) -> list[tuple[MathVisionRecord, Neighbor]]:
+    """Find records nearest to a query vector."""
+
+    record_by_id = {record.problem_id: record for record in records}
+    neighbors = index.search(query_vector, limit=limit, exclude_id=query_id)
+    return [
+        (record_by_id[neighbor.item_id], neighbor)
+        for neighbor in neighbors
+        if neighbor.item_id in record_by_id
+    ]
+
+
+def load_filtered_records(
+    path: Path, *, subject: str | None = None, level: int | None = None
+) -> list[MathVisionRecord]:
+    """Load records and apply optional explorer filters."""
+
+    return filter_records(load_jsonl_records(path), subject=subject, level=level)
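Note on `find_similar_records`: the final list comprehension silently drops any neighbor whose id is not present in `records`, so the result can be shorter than `limit`. A minimal standalone sketch of that join follows. The plain dict and `(id, score)` tuples are hypothetical stand-ins for `MathVisionRecord` and `Neighbor`, and `join_neighbors` is an illustrative name.

```python
# Hypothetical stand-ins: id -> question text, and (id, score) neighbor pairs.
records = {"p1": "Find the area of the triangle.", "p2": "How many circles overlap?"}
neighbors = [("p2", 0.91), ("p9", 0.80)]  # "p9" is indexed but missing from records


def join_neighbors(records: dict[str, str], neighbors: list[tuple[str, float]]):
    """Pair each neighbor with its record, skipping ids that have no record."""
    return [(records[item_id], score) for item_id, score in neighbors if item_id in records]


joined = join_neighbors(records, neighbors)
print(joined)  # [('How many circles overlap?', 0.91)]
```

This can happen when an index loaded from disk was built from a larger record set than the one currently loaded.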
src/mathvision_explorer/html.py ADDED
@@ -0,0 +1,141 @@
+"""HTML export for visual inspection of MathVision-like records."""
+
+from __future__ import annotations
+
+from html import escape
+from pathlib import Path
+
+from mathvision_explorer.dataset import MathVisionRecord
+
+
+def export_html(records: list[MathVisionRecord], output: Path) -> None:
+    """Write a standalone HTML gallery for records."""
+
+    output.parent.mkdir(parents=True, exist_ok=True)
+    cards = "\n".join(_render_card(record, output_dir=output.parent) for record in records)
+    html = f"""<!doctype html>
+<html lang="en">
+<head>
+  <meta charset="utf-8">
+  <meta name="viewport" content="width=device-width, initial-scale=1">
+  <title>MathVision Explorer</title>
+  <style>
+    body {{
+      margin: 0;
+      font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
+      color: #172033;
+      background: #f7f8fb;
+    }}
+    header {{
+      padding: 28px 32px 18px;
+      background: #ffffff;
+      border-bottom: 1px solid #dde3ee;
+    }}
+    h1 {{
+      margin: 0 0 6px;
+      font-size: 28px;
+      font-weight: 750;
+    }}
+    main {{
+      display: grid;
+      gap: 18px;
+      grid-template-columns: repeat(auto-fit, minmax(280px, 1fr));
+      padding: 24px 32px 36px;
+    }}
+    article {{
+      overflow: hidden;
+      border: 1px solid #d8deea;
+      border-radius: 8px;
+      background: #ffffff;
+    }}
+    img {{
+      display: block;
+      width: 100%;
+      aspect-ratio: 3 / 2;
+      object-fit: contain;
+      background: #eef2f7;
+    }}
+    .body {{
+      padding: 16px;
+    }}
+    .meta {{
+      display: flex;
+      flex-wrap: wrap;
+      gap: 8px;
+      margin-bottom: 12px;
+      font-size: 13px;
+    }}
+    .tag {{
+      padding: 3px 8px;
+      border: 1px solid #cfd7e6;
+      border-radius: 999px;
+      background: #f4f7fb;
+    }}
+    h2 {{
+      margin: 0 0 12px;
+      font-size: 16px;
+      line-height: 1.35;
+    }}
+    p {{
+      margin: 8px 0 0;
+      line-height: 1.45;
+    }}
+    .answer {{
+      font-weight: 700;
+    }}
+  </style>
+</head>
+<body>
+  <header>
+    <h1>MathVision Explorer</h1>
+    <div>{len(records)} visual math records</div>
+  </header>
+  <main>
+{cards}
+  </main>
+</body>
+</html>
+"""
+    output.write_text(html, encoding="utf-8")
+
+
+def _render_card(record: MathVisionRecord, *, output_dir: Path) -> str:
+    image_html = ""
+    if record.image_path is not None:
+        image_src = _relative_or_absolute_image(record.image_path, output_dir=output_dir)
+        image_html = f'      <img src="{escape(image_src)}" alt="{escape(record.problem_id)}">\n'
+
+    meta = [_tag(record.problem_id)]
+    if record.subject is not None:
+        meta.append(_tag(record.subject))
+    if record.level is not None:
+        meta.append(_tag(f"level {record.level}"))
+
+    options = ""
+    if record.options:
+        options = f"<p>Options: {escape(', '.join(record.options))}</p>"
+
+    solution = ""
+    if record.solution:
+        solution = f"<p>{escape(record.solution)}</p>"
+
+    return f"""    <article>
+{image_html}      <div class="body">
+        <div class="meta">{''.join(meta)}</div>
+        <h2>{escape(record.question)}</h2>
+        <p class="answer">Answer: {escape(record.answer)}</p>
+        {options}
+        {solution}
+      </div>
+    </article>"""
+
+
+def _tag(value: str) -> str:
+    return f'<span class="tag">{escape(value)}</span>'
+
+
+def _relative_or_absolute_image(image_path: Path, *, output_dir: Path) -> str:
+    try:
+        return image_path.resolve().relative_to(output_dir.resolve()).as_posix()
+    except ValueError:
+        return image_path.resolve().as_uri()
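Note on `_relative_or_absolute_image`: images located under the output directory get a relative `src`, which keeps the gallery portable, while anything else falls back to a `file://` URI. The sketch below reproduces that logic in isolation. The name `image_src` and the demo paths are assumptions for illustration, and the paths do not need to exist because `Path.resolve()` is non-strict by default.

```python
import tempfile
from pathlib import Path


def image_src(image_path: Path, output_dir: Path) -> str:
    """Return a relative POSIX path when possible, otherwise a file:// URI."""
    try:
        return image_path.resolve().relative_to(output_dir.resolve()).as_posix()
    except ValueError:
        # relative_to raises ValueError when image_path is not under output_dir.
        return image_path.resolve().as_uri()


base = Path(tempfile.gettempdir()) / "gallery_demo"
print(image_src(base / "images" / "a.png", base))        # images/a.png
print(image_src(base.parent / "other" / "b.png", base))  # a file:// URI
```

A gallery that relies on the `file://` fallback will likely only render on the machine that produced it.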
src/mathvision_explorer/index.py ADDED
@@ -0,0 +1,88 @@
+"""Small vector index for nearest-neighbor exploration."""
+
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+from pathlib import Path
+
+
+@dataclass(frozen=True, slots=True)
+class Neighbor:
+    """A nearest-neighbor result from a vector search."""
+
+    item_id: str
+    score: float
+
+
+class VectorIndex:
+    """In-memory cosine-similarity vector index."""
+
+    def __init__(self) -> None:
+        """Create an empty vector index."""
+
+        self._vectors: dict[str, tuple[float, ...]] = {}
+
+    def add(self, item_id: str, vector: tuple[float, ...]) -> None:
+        """Add or replace an item vector."""
+
+        if not vector:
+            raise ValueError("Vector must contain at least one value.")
+        if not all(math.isfinite(value) for value in vector):
+            raise ValueError("Vector values must be finite.")
+        self._vectors[item_id] = vector
+
+    def search(
+        self, query_vector: tuple[float, ...], *, limit: int = 5, exclude_id: str | None = None
+    ) -> list[Neighbor]:
+        """Return the closest vectors by cosine similarity."""
+
+        if limit < 1:
+            raise ValueError("Limit must be at least 1.")
+        neighbors = [
+            Neighbor(item_id=item_id, score=_cosine_similarity(query_vector, vector))
+            for item_id, vector in self._vectors.items()
+            if item_id != exclude_id
+        ]
+        return sorted(neighbors, key=lambda neighbor: neighbor.score, reverse=True)[:limit]
+
+    def save_tsv(self, path: Path) -> None:
+        """Persist the index as a simple tab-separated text file."""
+
+        path.parent.mkdir(parents=True, exist_ok=True)
+        with path.open("w", encoding="utf-8") as index_file:
+            for item_id, vector in sorted(self._vectors.items()):
+                values = "\t".join(str(value) for value in vector)
+                index_file.write(f"{item_id}\t{values}\n")
+
+    @classmethod
+    def load_tsv(cls, path: Path) -> VectorIndex:
+        """Load an index produced by :meth:`save_tsv`."""
+
+        index = cls()
+        with path.open("r", encoding="utf-8") as index_file:
+            for line_number, line in enumerate(index_file, start=1):
+                fields = line.rstrip("\n").split("\t")
+                if len(fields) < 2:
+                    msg = f"Line {line_number} must contain an id and vector values."
+                    raise ValueError(msg)
+                index.add(fields[0], tuple(float(value) for value in fields[1:]))
+        return index
+
+    def __len__(self) -> int:
+        """Return the number of indexed vectors."""
+
+        return len(self._vectors)
+
+
+def _cosine_similarity(left: tuple[float, ...], right: tuple[float, ...]) -> float:
+    if len(left) != len(right):
+        raise ValueError("Vectors must have the same dimensions.")
+    left_norm = math.sqrt(sum(value * value for value in left))
+    right_norm = math.sqrt(sum(value * value for value in right))
+    if left_norm == 0.0 or right_norm == 0.0:
+        return 0.0
+    dot_product = sum(
+        left_value * right_value for left_value, right_value in zip(left, right, strict=True)
+    )
+    return dot_product / (left_norm * right_norm)
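Note on `VectorIndex.search`: it is a brute-force scan that scores every stored vector against the query, drops `exclude_id`, and keeps the top `limit` by cosine similarity. The sketch below shows the same ranking on a toy index. `cosine` and `rank` are illustrative names, and plain tuples stand in for `Neighbor`.

```python
import math


def cosine(left: tuple[float, ...], right: tuple[float, ...]) -> float:
    """Cosine similarity; zero vectors score 0.0, mirroring the committed helper."""
    if len(left) != len(right):
        raise ValueError("Vectors must have the same dimensions.")
    left_norm = math.sqrt(sum(v * v for v in left))
    right_norm = math.sqrt(sum(v * v for v in right))
    if left_norm == 0.0 or right_norm == 0.0:
        return 0.0
    return sum(a * b for a, b in zip(left, right)) / (left_norm * right_norm)


def rank(vectors, query, *, limit=5, exclude_id=None):
    """Score every vector against the query and return the top `limit` pairs."""
    scored = [
        (item_id, cosine(query, vector))
        for item_id, vector in vectors.items()
        if item_id != exclude_id
    ]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:limit]


vectors = {"a": (1.0, 0.0), "b": (0.0, 1.0), "c": (1.0, 1.0)}
result = rank(vectors, (1.0, 0.0), limit=2, exclude_id="a")
print([item_id for item_id, _ in result])  # ['c', 'b']
```

The scan is linear in the number of stored vectors, which is reasonable for a small demo set. A larger dataset would probably call for a vectorized or approximate index instead.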
src/mathvision_explorer/py.typed ADDED
@@ -0,0 +1 @@
+
src/mathvision_explorer/similarity.py ADDED
@@ -0,0 +1,88 @@
+"""Helpers for interpreting nearest-neighbor similarity scores."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from mathvision_explorer.dataset import MathVisionRecord
+
+
+@dataclass(frozen=True, slots=True)
+class MatchInterpretation:
+    """Human-readable context for a nearest-neighbor match."""
+
+    label: str
+    subject_match: bool | None
+    level_delta: int | None
+    summary: str
+
+
+def similarity_label(score: float) -> str:
+    """Return a compact label for a cosine similarity score."""
+
+    if score >= 0.90:
+        return "Very close visual match"
+    if score >= 0.75:
+        return "Related visual structure"
+    if score >= 0.55:
+        return "Loose visual overlap"
+    return "Weak visual match"
+
+
+def interpret_match(
+    query: MathVisionRecord,
+    match: MathVisionRecord,
+    *,
+    score: float,
+) -> MatchInterpretation:
+    """Explain a match using visual score, subject overlap, and difficulty delta."""
+
+    subject_match = _subject_match(query, match)
+    level_delta = _level_delta(query, match)
+    parts = [similarity_label(score)]
+
+    if subject_match is True:
+        parts.append("same subject")
+    elif subject_match is False:
+        parts.append("different subject")
+
+    if level_delta is not None:
+        if level_delta == 0:
+            parts.append("same level")
+        elif level_delta > 0:
+            parts.append(f"{level_delta} level harder")
+        else:
+            parts.append(f"{abs(level_delta)} level easier")
+
+    return MatchInterpretation(
+        label=similarity_label(score),
+        subject_match=subject_match,
+        level_delta=level_delta,
+        summary="; ".join(parts),
+    )
+
+
+def embedder_description(embedder_name: str) -> str:
+    """Describe what a selected embedder is comparing."""
+
+    if embedder_name == "vjepa":
+        return (
+            "V-JEPA compares learned visual features: layout, shapes, object-like structure, "
+            "and spatial patterns. Read scores relatively, not as percentages."
+        )
+    return (
+        "Color compares only RGB means and spread. It is a fast sanity-check baseline, "
+        "not semantic visual understanding."
+    )
+
+
+def _subject_match(query: MathVisionRecord, match: MathVisionRecord) -> bool | None:
+    if query.subject is None or match.subject is None:
+        return None
+    return query.subject == match.subject
+
+
+def _level_delta(query: MathVisionRecord, match: MathVisionRecord) -> int | None:
+    if query.level is None or match.level is None:
+        return None
+    return match.level - query.level
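Note on `interpret_match`: the summary is the score label followed by optional subject and level clauses, joined with `"; "`. The worked example below traces one case with the same thresholds. `label` and `summarize` are illustrative names, and plain arguments stand in for the `MathVisionRecord` fields.

```python
def label(score: float) -> str:
    """Same cut-offs as similarity_label: 0.90, 0.75, 0.55."""
    if score >= 0.90:
        return "Very close visual match"
    if score >= 0.75:
        return "Related visual structure"
    if score >= 0.55:
        return "Loose visual overlap"
    return "Weak visual match"


def summarize(score, query_subject, match_subject, query_level, match_level) -> str:
    """Assemble the summary string; None inputs suppress the matching clause."""
    parts = [label(score)]
    if query_subject is not None and match_subject is not None:
        parts.append("same subject" if query_subject == match_subject else "different subject")
    if query_level is not None and match_level is not None:
        delta = match_level - query_level  # positive means the match is harder
        if delta == 0:
            parts.append("same level")
        elif delta > 0:
            parts.append(f"{delta} level harder")
        else:
            parts.append(f"{abs(delta)} level easier")
    return "; ".join(parts)


print(summarize(0.8, "geometry", "geometry", 2, 4))
# Related visual structure; same subject; 2 level harder
```

The level delta is computed as match minus query, so "harder" and "easier" describe the match relative to the query record.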
src/mathvision_explorer/streamlit_app.py ADDED
@@ -0,0 +1,156 @@
+"""Streamlit app for browsing MathVision-like records."""
+
+from __future__ import annotations
+
+import argparse
+from importlib import import_module
+from pathlib import Path
+from typing import Any
+
+from mathvision_explorer.dataset import MathVisionRecord, filter_records, load_jsonl_records
+from mathvision_explorer.embeddings import ColorStatsEmbedder, ImageEmbedder, VJepaImageEmbedder
+from mathvision_explorer.explorer import build_image_index, find_similar_records
+from mathvision_explorer.similarity import embedder_description, interpret_match
+
+
+def main() -> None:
+    """Run the Streamlit explorer app."""
+
+    args = _parse_args()
+    st = _load_streamlit()
+    records = load_jsonl_records(args.jsonl)
+
+    st.set_page_config(page_title="MathVision Explorer", layout="wide")
+    st.title("MathVision Explorer")
+
+    subjects = sorted({record.subject for record in records if record.subject is not None})
+    levels = sorted({record.level for record in records if record.level is not None})
+
+    with st.sidebar:
+        st.header("Filters")
+        subject = st.selectbox("Subject", ["all", *subjects])
+        level_label = st.selectbox("Level", ["all", *(str(level) for level in levels)])
+        show_solutions = st.toggle("Show solutions", value=True)
+        st.header("Latent Space")
+        embedder_label = st.selectbox(
+            "Embedder",
+            ["color (fast demo)", "vjepa (requires make sync-jepa)"],
+        )
+        query_id = st.selectbox("Query record", [record.problem_id for record in records])
+        neighbor_count = st.slider("Neighbors", min_value=1, max_value=8, value=3)
+
+    selected_subject = None if subject == "all" else subject
+    selected_level = None if level_label == "all" else int(level_label)
+    filtered = filter_records(records, subject=selected_subject, level=selected_level)
+
+    _render_similarity_panel(
+        st,
+        records,
+        query_id=query_id,
+        embedder_name=_embedder_name_from_label(embedder_label),
+        neighbor_count=neighbor_count,
+    )
+
+    st.caption(f"{len(filtered)} of {len(records)} records")
+    for record in filtered:
+        _render_record(st, record, show_solution=show_solutions)
+
+
+def _render_similarity_panel(
+    st: Any,
+    records: list[MathVisionRecord],
+    *,
+    query_id: str,
+    embedder_name: str,
+    neighbor_count: int,
+) -> None:
+    st.header("Nearest Neighbors")
+    st.caption(embedder_description(embedder_name))
+    record_by_id = {record.problem_id: record for record in records}
+    query = record_by_id[query_id]
+    if query.image_path is None:
+        st.warning("Selected query has no image.")
+        return
+
+    try:
+        embedder = _load_embedder(embedder_name)
+        query_vector = embedder.embed_image(query.image_path)
+        index = build_image_index(records, embedder)
+        matches = find_similar_records(
+            records,
+            index,
+            query.problem_id,
+            query_vector,
+            limit=neighbor_count,
+        )
+    except RuntimeError as error:
+        st.error(str(error))
+        return
+
+    columns = st.columns([1, 2])
+    with columns[0]:
+        st.caption("Query")
+        st.image(str(query.image_path), width="stretch")
+        st.write(query.problem_id)
+    with columns[1]:
+        for record, neighbor in matches:
+            interpretation = interpret_match(query, record, score=neighbor.score)
+            with st.container(border=True):
+                match_columns = st.columns([0.35, 1])
+                with match_columns[0]:
+                    if record.image_path is not None:
+                        st.image(str(record.image_path), width="stretch")
+                with match_columns[1]:
+                    st.write(f"**{record.problem_id}**")
+                    st.caption(f"similarity {neighbor.score:.4f} | {interpretation.label}")
+                    st.write(record.question)
+                    st.write(interpretation.summary)
+
+
+def _render_record(st: Any, record: MathVisionRecord, *, show_solution: bool) -> None:
+    with st.container(border=True):
+        columns = st.columns([1, 1.4])
+        with columns[0]:
+            if record.image_path is not None:
+                st.image(str(record.image_path), width="stretch")
+        with columns[1]:
+            st.subheader(record.question)
+            badges = [record.problem_id]
+            if record.subject is not None:
+                badges.append(record.subject)
+            if record.level is not None:
+                badges.append(f"level {record.level}")
+            st.caption(" | ".join(badges))
+            if record.options:
+                st.write("Options: " + ", ".join(record.options))
+            st.write(f"Answer: **{record.answer}**")
+            if show_solution and record.solution:
+                st.write(record.solution)
+
+
+def _parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--jsonl", type=Path, default=Path("data/demo/demo.jsonl"))
+    return parser.parse_args()
+
+
+def _load_streamlit() -> Any:
+    try:
+        return import_module("streamlit")
+    except ImportError as error:
+        msg = "Streamlit is missing. Install it with `uv sync --extra app --dev`."
+        raise RuntimeError(msg) from error
+
+
+def _load_embedder(embedder_name: str) -> ImageEmbedder:
+    if embedder_name == "vjepa":
+        return VJepaImageEmbedder()
+    return ColorStatsEmbedder()
+
+
+def _embedder_name_from_label(label: str) -> str:
+    return "vjepa" if label.startswith("vjepa") else "color"
+
+
+if __name__ == "__main__":
+    main()