| """ |
| Sudoku Video Dataset Generator - Supports flexible solution count expressions per puzzle. |
| With checkpoint/resume support via metadata.json. |
| |
| The *frames* parameter replaces the old max_frames + k pair: |
| - frames=None → 1 content frame per fill step (variable length) |
| - frames=N → exactly N content frames; fills distributed evenly |
| (slow-motion if N > fills, fast-forward if N < fills) |
| """ |
| import json |
| import re |
| import random |
| import argparse |
| from dataclasses import dataclass, asdict |
| from pathlib import Path |
| from typing import List, Tuple, Optional, Dict |
| import numpy as np |
| import cv2 |
| from tqdm import tqdm |
| from sudoku_processor import SudokuProcessor |
|
|
|
|
| |
|
|
@dataclass
class SolRange:
    """Flexible solution-count constraint for puzzle generation.

    Attributes:
        min_sol: Inclusive lower bound on the solution count (always >= 1).
        max_sol: Inclusive upper bound, or None for unbounded.
    """
    min_sol: int
    max_sol: Optional[int]

    @classmethod
    def parse(cls, expr: str) -> "SolRange":
        """Parse a constraint expression such as '3', '2-5', '>=2', '<3', '==1'.

        Raises:
            ValueError: If the expression is malformed, or describes an
                impossible range — solution counts are always >= 1, so
                e.g. '==0', '<=0' and '<1' are rejected.
        """
        expr = expr.strip()
        # Range form: "lo-hi" (inclusive on both ends).
        m = re.fullmatch(r'(\d+)\s*-\s*(\d+)', expr)
        if m:
            lo, hi = int(m.group(1)), int(m.group(2))
            if lo < 1:
                raise ValueError(f"min_sol must be >= 1, got {lo}")
            if hi < lo:
                raise ValueError(f"Invalid range: {lo}-{hi}")
            return cls(min_sol=lo, max_sol=hi)
        # Comparison form: ">=n", ">n", "<=n", "<n", "==n".
        m = re.fullmatch(r'(>=|>|<=|<|==)\s*(\d+)', expr)
        if m:
            op, n = m.group(1), int(m.group(2))
            if op == '>=':
                return cls(min_sol=max(1, n), max_sol=None)
            elif op == '>':
                return cls(min_sol=max(1, n + 1), max_sol=None)
            elif op == '<=':
                # "<=0" would accept nothing: counts are always >= 1.
                if n < 1:
                    raise ValueError(f"'<={n}' is impossible: solution counts are >= 1")
                return cls(min_sol=1, max_sol=n)
            elif op == '<':
                # "<1" would accept nothing (it previously collapsed to ==1).
                if n <= 1:
                    raise ValueError(f"'<{n}' is impossible: solution counts are >= 1")
                return cls(min_sol=1, max_sol=n - 1)
            elif op == '==':
                # Match the bare-integer form, which also rejects 0.
                if n < 1:
                    raise ValueError(f"sol_num must be >= 1, got {n}")
                return cls(min_sol=n, max_sol=n)
        # Bare integer form: exact count.
        m = re.fullmatch(r'(\d+)', expr)
        if m:
            n = int(m.group(1))
            if n < 1:
                raise ValueError(f"sol_num must be >= 1, got {n}")
            return cls(min_sol=n, max_sol=n)
        raise ValueError(f"Invalid sol_num expression: '{expr}'")

    @property
    def is_exact(self) -> bool:
        """True when exactly one count is allowed (min == max)."""
        return self.max_sol is not None and self.min_sol == self.max_sol

    @property
    def is_unique_only(self) -> bool:
        """True when only uniquely-solvable puzzles are accepted."""
        return self.is_exact and self.min_sol == 1

    @property
    def allows_unique(self) -> bool:
        """True when a count of 1 satisfies the lower bound."""
        return self.min_sol <= 1

    @property
    def requires_multi(self) -> bool:
        """True when at least two solutions are required."""
        return self.min_sol > 1

    @property
    def effective_max(self) -> int:
        """Upper bound usable in search; caps unbounded ranges at >= 10."""
        return self.max_sol if self.max_sol is not None else max(self.min_sol, 10)

    def accepts(self, count: int) -> bool:
        """Return True if *count* lies within [min_sol, max_sol]."""
        if count < self.min_sol:
            return False
        if self.max_sol is not None and count > self.max_sol:
            return False
        return True

    def __repr__(self) -> str:
        if self.is_exact:
            return f"SolRange(=={self.min_sol})"
        if self.max_sol is None:
            return f"SolRange(>={self.min_sol})"
        return f"SolRange({self.min_sol}-{self.max_sol})"
|
|
|
|
| |
|
|
@dataclass
class GenerationState:
    """Snapshot of generation progress, persisted for checkpoint/resume."""
    params_hash: str                # fingerprint of the generation parameters
    clue_progress: Dict[int, int]   # clue level -> puzzles generated so far
    seen_grids: List[str]           # encoded puzzles already emitted (dedup set)
    all_samples: List[Dict]         # one metadata record per rendered video
    completed: bool = False         # True once the dataset has been finalized

    def to_dict(self) -> Dict:
        """Serialize to a plain dict suitable for JSON dumping."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict) -> "GenerationState":
        """Rebuild a state object from a previously serialized dict."""
        return cls(**d)
|
|
|
|
def compute_params_hash(params: Dict) -> str:
    """Hash the generation parameters into a short stable identifier.

    'output_dir' is excluded so moving the dataset does not invalidate a
    checkpoint. Returns the first 12 hex chars of an MD5 digest over the
    canonical (sorted-keys) JSON encoding.
    """
    import hashlib
    hashable = dict(params)
    hashable.pop('output_dir', None)
    digest = hashlib.md5(json.dumps(hashable, sort_keys=True).encode())
    return digest.hexdigest()[:12]
|
|
|
|
| def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]: |
| """Load checkpoint if exists and params match.""" |
| meta_path = output_dir / "metadata.json" |
| if not meta_path.exists(): |
| return None |
| with open(meta_path) as f: |
| data = json.load(f) |
| state = GenerationState.from_dict(data["state"]) |
| expected_hash = compute_params_hash(params) |
| if state.params_hash != expected_hash: |
| print(f"⚠️ Parameters changed (hash {state.params_hash} → {expected_hash}), starting fresh") |
| return None |
| if state.completed: |
| print("✓ Generation already completed") |
| return state |
| print(f"✓ Resuming from checkpoint: {sum(state.clue_progress.values())} puzzles generated") |
| return state |
|
|
|
|
def save_checkpoint(output_dir: Path, state: "GenerationState", params: Dict):
    """Atomically write the current generation state to metadata.json.

    Writes to a sibling temp file first, then swaps it into place, so a
    crash mid-write never leaves a truncated metadata.json behind.
    """
    meta_path = output_dir / "metadata.json"
    tmp_path = meta_path.with_suffix('.tmp')
    with open(tmp_path, 'w') as f:
        json.dump({"params": params, "state": state.to_dict()}, f, indent=2)
    # Path.replace (unlike rename) overwrites an existing target on every
    # platform, keeping the swap atomic on Windows as well as POSIX.
    tmp_path.replace(meta_path)
|
|
|
|
| |
|
|
def get_fill_order(puzzle, solution):
    """List the (row, col, value) fills needed, scanning rows left-to-right,
    top-to-bottom. Cells holding 0 in *puzzle* are the ones to fill."""
    fills = []
    for r in range(9):
        for c in range(9):
            if puzzle[r][c] == 0:
                fills.append((r, c, solution[r][c]))
    return fills
|
|
|
|
def create_processor(resolution=None):
    """Build a SudokuProcessor, scaling drawing parameters to *resolution*.

    Args:
        resolution: Optional (width, height); the smaller dimension sets
            the grid size. None uses the processor's defaults.
    """
    if resolution is None:
        return SudokuProcessor()
    # Scale font and line thickness relative to the default 60px cell.
    cell = min(resolution) // 9
    scale = cell / 60
    return SudokuProcessor(
        cell_size=cell,
        font_scale=1.2 * scale,
        thickness=max(1, int(2 * scale)),
    )
|
|
|
|
def generate_video_frames(proc, puzzle, solution, n_start, m_end, frames=None):
    """
    Generate progressive video frames for a Sudoku solve.

    The *frames* parameter controls the number of **content frames**
    (between the opening and closing holds):

    - frames=None → 1 content frame per fill step (n_fills total)
    - frames > fills → multiple frames per fill step (slow-motion)
    - frames < fills → multiple fills per frame (fast-forward)
    - frames == fills → identical to None

    Total output length = n_start + content_frames + m_end.

    Args:
        proc: SudokuProcessor instance.
        puzzle: 9×9 puzzle grid (0 = empty).
        solution: 9×9 solved grid.
        n_start: Hold frames for puzzle at the beginning.
        m_end: Hold frames for completed solution at the end.
        frames: Desired number of content frames (None = one per fill).

    Returns:
        List of numpy arrays (RGB images).
    """
    fills = get_fill_order(puzzle, solution)
    n_fills = len(fills)

    # Degenerate case: nothing to fill. Emit one static content frame plus
    # the opening/closing holds (hence the +1).
    if n_fills == 0:
        img = proc.render(solution, original=puzzle)
        return [img.copy() for _ in range(n_start + m_end + 1)]

    content_frames = frames if frames is not None else n_fills
    content_frames = max(1, content_frames)

    result = []
    # Work on a copy so the caller's puzzle grid is never mutated.
    current = [row[:] for row in puzzle]

    # Opening hold: the untouched puzzle.
    # NOTE(review): this render omits original=puzzle unlike every later
    # call — presumably irrelevant with no filled cells; confirm.
    img = proc.render(current)
    result.extend([img.copy() for _ in range(n_start)])

    if content_frames == n_fills:
        # 1:1 mapping — one frame per fill, highlighting the new cell.
        for r, c, v in fills:
            current[r][c] = v
            result.append(proc.render(current, highlight_new=(r, c), original=puzzle))

    elif content_frames > n_fills:
        # Slow-motion: fill i owns the frame slots [f_lo, f_hi). The integer
        # partition guarantees the per-fill counts sum to content_frames.
        for i, (r, c, v) in enumerate(fills):
            current[r][c] = v
            f_lo = i * content_frames // n_fills
            f_hi = (i + 1) * content_frames // n_fills
            count = f_hi - f_lo

            # First frame of the slot highlights the newly filled cell...
            result.append(proc.render(current, highlight_new=(r, c), original=puzzle))
            # ...remaining frames hold the same grid without highlight.
            if count > 1:
                img = proc.render(current, original=puzzle)
                result.extend([img.copy() for _ in range(count - 1)])

    else:
        # Fast-forward: frame f applies the batch of fills
        # [prev_step, target_step); the integer partition covers all fills.
        for f in range(content_frames):
            prev_step = f * n_fills // content_frames
            target_step = (f + 1) * n_fills // content_frames
            last_r, last_c = None, None
            for idx in range(prev_step, target_step):
                r, c, v = fills[idx]
                current[r][c] = v
                last_r, last_c = r, c
            # Highlight only the last cell filled in this batch.
            if last_r is not None:
                result.append(
                    proc.render(current, highlight_new=(last_r, last_c), original=puzzle)
                )
            else:
                # Defensive: unreachable when content_frames < n_fills,
                # since each batch then contains at least one fill.
                result.append(proc.render(current, original=puzzle))

    # Closing hold: the completed solution.
    img = proc.render(solution, original=puzzle)
    result.extend([img.copy() for _ in range(m_end)])

    return result
|
|
|
|
def save_video(frames, path, fps=10):
    """Save a list of RGB numpy frames to *path* as an mp4.

    Args:
        frames: Non-empty list of HxWx3 RGB arrays, all the same size.
        path: Output file path (str or Path).
        fps: Frames per second of the encoded video.

    Raises:
        ValueError: If *frames* is empty (the frame size cannot be inferred,
            and the old code raised an opaque IndexError here).
    """
    if not frames:
        raise ValueError("save_video: frames list is empty")
    h, w = frames[0].shape[:2]
    writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    try:
        for f in frames:
            # OpenCV expects BGR channel order.
            writer.write(cv2.cvtColor(f, cv2.COLOR_RGB2BGR))
    finally:
        # Always finalize the container, even if a write fails mid-stream.
        writer.release()
|
|
|
|
def normalize_num_per_clue(num_per_clue, clue_levels):
    """Expand a single int to one count per clue level, or validate a list.

    Raises:
        ValueError: If a list is given whose length differs from clue_levels.
    """
    if isinstance(num_per_clue, int):
        return [num_per_clue for _ in clue_levels]
    if len(num_per_clue) == len(clue_levels):
        return num_per_clue
    raise ValueError(
        f"num_per_clue length ({len(num_per_clue)}) != clue_levels ({len(clue_levels)})"
    )
|
|
|
|
| |
|
|
def generate_puzzle_with_range(proc, clue, sol_range, min_hamming):
    """Generate one puzzle respecting sol_range. Returns (puzzle, solutions) or None."""
    # Exactly-one-solution requirement: the dedicated unique generator
    # is the direct route.
    if sol_range.is_unique_only:
        grid, sol = proc.generate(clue, unique=True)
        return grid, [sol]

    # Attempt a multi-solution puzzle. When multiple solutions are mandatory
    # the floor is the range's own minimum; otherwise we opportunistically
    # ask for at least 2 and fall back to a unique puzzle below.
    floor = sol_range.min_sol if sol_range.requires_multi else max(2, sol_range.min_sol)
    try:
        grid, sols = proc.generate_multi_solution(
            clue,
            min_solutions=floor,
            max_solutions=sol_range.effective_max,
            max_attempts=1,
            min_hamming=min_hamming,
        )
        if sol_range.accepts(len(sols)):
            return grid, sols
    except RuntimeError:
        # Generator gave up this attempt; fall through.
        pass

    # Fall back to a unique puzzle only when a count of 1 is acceptable
    # (never true when multiple solutions are required).
    if sol_range.allows_unique:
        grid, sol = proc.generate(clue, unique=True)
        return grid, [sol]
    return None
|
|
|
|
| |
|
|
def generate_dataset(
    output_dir="sudoku", clue_levels=[20, 30, 40, 50, 60, 70],
    num_per_clue=[15000, 10000, 10000, 5000, 2000, 1000],
    sol_num="<=3", min_hamming=10, train_ratio=0.9,
    prompt="Solve this Sudoku puzzle using red font.",
    n_start=2, m_end=3, frames=None, fps=10,
    resolution=None, seed=42, checkpoint_interval=50
):
    """
    Generate Sudoku video dataset with checkpoint/resume support.

    The *frames* parameter controls the number of **content frames** per video:
    - None → one content frame per fill step (variable length per puzzle)
    - N > 0 → exactly N content frames; fills distributed evenly

    Args:
        output_dir: Root directory; videos/, images/, jsonl splits and
            metadata.json are written under it.
        clue_levels: Clue counts (difficulty levels) to generate puzzles for.
        num_per_clue: Target puzzle count per clue level — an int
            (broadcast to all levels) or a list matching clue_levels.
        sol_num: Solution-count constraint expression (see SolRange.parse).
        min_hamming: Minimum Hamming distance between alternative solutions.
        train_ratio: Fraction of videos assigned to the train split.
        prompt: Text prompt stored with every sample.
        n_start: Hold frames showing the puzzle at video start.
        m_end: Hold frames showing the solution at video end.
        frames: Content frames per video (None = one per fill).
        fps: Frames per second of the encoded videos.
        resolution: Optional (width, height) for the rendered grid.
        seed: RNG seed for generation and split shuffling.
        checkpoint_interval: Save checkpoint every N puzzles (default: 50)

    Note:
        The list defaults are shared module-level objects; they are only
        read, never mutated, so the mutable-default pitfall does not bite
        here.
    """
    # These parameters (minus output_dir) are hashed so a resumed run can
    # detect a configuration change and start fresh instead.
    params = {
        "clue_levels": clue_levels, "num_per_clue": num_per_clue,
        "sol_num": sol_num, "min_hamming": min_hamming, "train_ratio": train_ratio,
        "prompt": prompt, "n_start": n_start, "m_end": m_end, "frames": frames,
        "fps": fps, "resolution": resolution, "seed": seed
    }

    output_dir = Path(output_dir)
    video_dir = output_dir / "videos"
    image_dir = output_dir / "images"
    video_dir.mkdir(parents=True, exist_ok=True)
    image_dir.mkdir(parents=True, exist_ok=True)

    # Resume from metadata.json when present and compatible.
    state = load_checkpoint(output_dir, params)
    if state and state.completed:
        return

    sol_range = SolRange.parse(str(sol_num))
    proc = create_processor(resolution)
    actual_size = proc.img_size
    num_per_clue_list = normalize_num_per_clue(num_per_clue, clue_levels)
    # Zero-pad puzzle indices in filenames to the widest target count.
    max_puzzles = max(num_per_clue_list)
    num_width = len(str(max_puzzles))

    if state is None:
        random.seed(seed)
        state = GenerationState(
            params_hash=compute_params_hash(params),
            clue_progress={clue: 0 for clue in clue_levels},
            seen_grids=[],
            all_samples=[]
        )
        print(f"Starting fresh generation with solution range: {sol_range}")
        print(f" frames={'auto (1 per fill)' if frames is None else frames}, "
              f"n_start={n_start}, m_end={m_end}, fps={fps}")
    else:
        # NOTE(review): burning 10 draws per completed puzzle only roughly
        # fast-forwards the RNG — the true draw count per puzzle varies with
        # generation retries, so a resumed stream is deterministic but not
        # identical to an uninterrupted run. Confirm this is acceptable.
        random.seed(seed)
        for _ in range(sum(state.clue_progress.values()) * 10):
            random.random()

    seen_grids = set(state.seen_grids)
    all_samples = state.all_samples.copy()
    # JSON round-trips dict keys as strings; restore int clue keys.
    clue_progress = {int(k): v for k, v in state.clue_progress.items()}

    total_target = sum(num_per_clue_list)
    total_done = sum(clue_progress.values())
    # sol_idx == 0 counts each puzzle once, regardless of how many
    # solution videos it produced.
    stats_unique = sum(1 for s in all_samples if s["total_solutions"] == 1 and s["sol_idx"] == 0)
    stats_multi = sum(1 for s in all_samples if s["total_solutions"] > 1 and s["sol_idx"] == 0)
    puzzles_since_checkpoint = 0

    with tqdm(total=total_target, initial=total_done, desc="Total", unit="puzzle") as pbar_total:
        for clue, target_count in zip(clue_levels, num_per_clue_list):
            generated = clue_progress.get(clue, 0)
            if generated >= target_count:
                continue

            # Retry budget: up to 20 generation attempts per missing puzzle.
            max_attempts = (target_count - generated) * 20

            with tqdm(total=target_count, initial=generated, desc=f"Clue {clue:2d}",
                      unit="puzzle", leave=False) as pbar_clue:
                for _ in range(max_attempts):
                    if generated >= target_count:
                        break

                    result = generate_puzzle_with_range(proc, clue, sol_range, min_hamming)
                    if result is None:
                        continue
                    puzzle, solutions = result

                    # Deduplicate on the encoded-grid fingerprint.
                    fp = proc.encode(puzzle)
                    if fp in seen_grids:
                        continue
                    seen_grids.add(fp)

                    n_sols = len(solutions)
                    if n_sols == 1:
                        stats_unique += 1
                    else:
                        stats_multi += 1

                    # One still image of the unsolved puzzle per puzzle.
                    img_name = f"clue{clue}_{generated:0{num_width}d}.png"
                    puzzle_img = proc.render(puzzle)
                    cv2.imwrite(
                        str(image_dir / img_name),
                        cv2.cvtColor(puzzle_img, cv2.COLOR_RGB2BGR),
                    )

                    # One video (and one sample record) per solution.
                    for si, sol in enumerate(solutions):
                        vid_name = f"clue{clue}_{generated:0{num_width}d}_sol{si}.mp4"
                        vid_frames = generate_video_frames(
                            proc, puzzle, sol, n_start, m_end, frames
                        )
                        save_video(vid_frames, video_dir / vid_name, fps)

                        # Distance to the other solutions of this puzzle.
                        # NOTE(review): relies on the private proc._hamming;
                        # consider a public API on SudokuProcessor.
                        hdists = [
                            proc._hamming(sol, solutions[j])
                            for j in range(n_sols) if j != si
                        ]
                        all_samples.append({
                            "prompt": prompt, "video": vid_name, "image": img_name,
                            "clue": clue, "puzzle": fp, "solution": proc.encode(sol),
                            "sol_idx": si, "total_solutions": n_sols,
                            "frame_count": len(vid_frames),
                            "min_hamming_to_others": min(hdists) if hdists else 0,
                        })

                    generated += 1
                    clue_progress[clue] = generated
                    puzzles_since_checkpoint += 1
                    pbar_clue.update(1)
                    pbar_total.update(1)

                    # Periodic checkpoint so an interrupted run can resume.
                    if puzzles_since_checkpoint >= checkpoint_interval:
                        state.clue_progress = clue_progress
                        state.seen_grids = list(seen_grids)
                        state.all_samples = all_samples
                        save_checkpoint(output_dir, state, params)
                        puzzles_since_checkpoint = 0

            tqdm.write(
                f"Clue {clue}: {generated} puzzles, "
                f"{sum(1 for s in all_samples if s['clue'] == clue)} videos"
            )

    # Train/test split, stratified per clue level. A separate seed keeps the
    # split independent of how much of the generation RNG stream was used.
    random.seed(seed + 1)
    by_clue: Dict[int, List[Dict]] = {}
    for s in all_samples:
        by_clue.setdefault(s["clue"], []).append(s)

    train_samples, test_samples = [], []
    for clue in sorted(by_clue):
        group = by_clue[clue]
        random.shuffle(group)
        cl_split = int(len(group) * train_ratio)
        train_samples.extend(group[:cl_split])
        test_samples.extend(group[cl_split:])

    random.shuffle(train_samples)
    random.shuffle(test_samples)
    split_idx = len(train_samples)

    def write_jsonl(samples, path):
        # One JSON object per line (JSON Lines format).
        with open(path, 'w') as f:
            for s in samples:
                json.dump(s, f)
                f.write('\n')

    write_jsonl(train_samples, output_dir / "train.jsonl")
    write_jsonl(test_samples, output_dir / "test.jsonl")

    # Final checkpoint marks the run complete so reruns exit early.
    state.clue_progress = clue_progress
    state.seen_grids = list(seen_grids)
    state.all_samples = all_samples
    state.completed = True
    save_checkpoint(output_dir, state, params)

    # Summary statistics.
    print(f"\n✓ Dataset complete: {output_dir}/")
    print(f" Resolution: {actual_size}x{actual_size}")
    print(f" Solution range: {sol_range}")
    print(f" Puzzles: {len(seen_grids)} ({stats_unique} unique, {stats_multi} multi-sol)")
    print(f" Videos: {len(all_samples)}")
    print(f" Train: {split_idx}, Test: {len(all_samples) - split_idx}")

    fcounts = [s["frame_count"] for s in all_samples]
    print(f" Frame counts: avg={np.mean(fcounts):.1f}, "
          f"min={min(fcounts)}, max={max(fcounts)}")

    # Diversity among alternative solutions (only meaningful for multi-sol).
    hammings = [s["min_hamming_to_others"] for s in all_samples if s["min_hamming_to_others"] > 0]
    if hammings:
        print(f" Solution diversity: avg={np.mean(hammings):.1f}, "
              f"min={min(hammings)}, max={max(hammings)}")
|
|
|
|
def parse_resolution(s):
    """Parse a 'WIDTHxHEIGHT' string (case-insensitive 'x') into (w, h)."""
    width, height = (int(part) for part in s.lower().split('x'))
    return (width, height)
|
|
|
|
def parse_args():
    """Build the CLI and parse sys.argv into a Namespace."""
    parser = argparse.ArgumentParser(
        description="Generate Sudoku video dataset with resume support"
    )
    # Output location and generation targets.
    parser.add_argument("--output-dir", type=str, default="sudoku")
    parser.add_argument("--clue-levels", type=int, nargs="+",
                        default=[20, 30, 40, 50, 60, 70])
    parser.add_argument("--num-per-clue", type=int, nargs="+",
                        default=[15000, 10000, 10000, 5000, 2000, 1000])
    # Solution-count constraint and solution diversity.
    parser.add_argument("--sol-num", type=str, default="<=3",
                        help="'1', '3', '>=1', '>1', '<=3', '<3', '2-5'")
    parser.add_argument("--min-hamming", type=int, default=10)
    parser.add_argument("--train-ratio", type=float, default=0.9)
    parser.add_argument("--prompt", type=str,
                        default="Solve this Sudoku puzzle using red font.")
    # Video shape: holds, content frames, frame rate, resolution.
    parser.add_argument("--n-start", type=int, default=2,
                        help="Hold frames for puzzle at video start")
    parser.add_argument("--m-end", type=int, default=3,
                        help="Hold frames for completed solution at video end")
    parser.add_argument("--frames", type=int, default=None,
                        help="Content frames per video. None=1 per fill (auto). "
                             "If > fills: slow-motion. If < fills: fast-forward.")
    parser.add_argument("--fps", type=int, default=10)
    parser.add_argument("--resolution", type=str, default="1024x1024")
    # Reproducibility and checkpointing.
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--checkpoint-interval", type=int, default=50,
                        help="Save checkpoint every N puzzles (default: 50)")
    return parser.parse_args()
|
|
|
|
if __name__ == "__main__":
    args = parse_args()
    kwargs = vars(args)
    # argparse returns nargs="+" values as lists; collapse a single-element
    # num-per-clue list to an int so it broadcasts across all clue levels.
    if isinstance(kwargs["num_per_clue"], list) and len(kwargs["num_per_clue"]) == 1:
        kwargs["num_per_clue"] = kwargs["num_per_clue"][0]
    # Convert the "WxH" CLI string into the (width, height) tuple that
    # generate_dataset expects.
    if kwargs["resolution"]:
        kwargs["resolution"] = parse_resolution(kwargs["resolution"])
    generate_dataset(**kwargs)