| """High-level workflows for MathVision exploration.""" |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| from mathvision_explorer.dataset import MathVisionRecord, filter_records, load_jsonl_records |
| from mathvision_explorer.embeddings import ImageEmbedder, embed_record_image |
| from mathvision_explorer.index import Neighbor, VectorIndex |
|
|
|
|
def build_image_index(records: list[MathVisionRecord], embedder: ImageEmbedder) -> VectorIndex:
    """Return a ``VectorIndex`` with one image embedding per record that has an image.

    Records whose ``image_path`` is ``None`` are skipped. Every other record is
    embedded with ``embed_record_image(record.image_path, embedder)`` and added
    to the index under its ``problem_id``, in the order the records are given.
    """
    image_index = VectorIndex()
    # Lazy generator: each record is embedded and added before the next one is
    # inspected, matching a plain loop with a `continue` guard.
    records_with_images = (rec for rec in records if rec.image_path is not None)
    for rec in records_with_images:
        vector = embed_record_image(rec.image_path, embedder)
        image_index.add(rec.problem_id, vector)
    return image_index
|
|
|
|
def find_similar_records(
    records: list[MathVisionRecord],
    index: VectorIndex,
    query_id: str,
    query_vector: tuple[float, ...],
    *,
    limit: int = 5,
) -> list[tuple[MathVisionRecord, Neighbor]]:
    """Find the records nearest to a query vector.

    Args:
        records: Candidate records. Only neighbors whose ``item_id`` matches a
            record's ``problem_id`` are returned; the index may contain other ids.
        index: Index queried via ``index.search(query_vector, limit=...,
            exclude_id=query_id)``.
        query_id: Id excluded from the search (typically the query record itself).
        query_vector: Embedding to search around.
        limit: Maximum number of ``(record, neighbor)`` pairs to return.

    Returns:
        Up to ``limit`` ``(record, neighbor)`` pairs, in the order ``search``
        returned the neighbors.

    The index may hold ids that are not in ``records``, for example when
    ``records`` was filtered after the index was built. In that case the search
    is repeated with a doubled limit until ``limit`` matches are found or the
    index has nothing more to return. This assumes ``search`` returns
    nearest-first results and that a larger ``limit`` yields a superset of a
    smaller one. NOTE(review): confirm against ``VectorIndex.search``.
    """
    record_by_id = {record.problem_id: record for record in records}
    if not record_by_id:
        # Nothing can match; avoid scanning the whole index.
        return []

    fetch_limit = limit
    while True:
        neighbors = index.search(query_vector, limit=fetch_limit, exclude_id=query_id)
        matches = [
            (record_by_id[neighbor.item_id], neighbor)
            for neighbor in neighbors
            if neighbor.item_id in record_by_id
        ]
        # Stop when enough matches were found, or when the index returned fewer
        # hits than requested (it is exhausted, so asking for more is pointless).
        # Non-positive limits keep the original single-query behavior.
        if limit <= 0 or len(matches) >= limit or len(neighbors) < fetch_limit:
            break
        fetch_limit *= 2

    return matches if limit <= 0 else matches[:limit]
|
|
|
|
def load_filtered_records(
    path: Path, *, subject: str | None = None, level: int | None = None
) -> list[MathVisionRecord]:
    """Read the JSONL file at *path* and narrow the records with optional filters.

    ``subject`` and ``level`` are forwarded unchanged to ``filter_records``.
    A value of ``None`` is passed through as-is; presumably it means "do not
    filter on this field" (confirm against ``filter_records``).
    """
    loaded = load_jsonl_records(path)
    filtered = filter_records(loaded, subject=subject, level=level)
    return filtered
|
|