Jayce-Ping committed on
Commit 7cdb0ca · verified · 1 Parent(s): 6f7040a

Add files using upload-large-folder tool

Files changed (29)
  1. sequence/data_generation.py +336 -0
  2. sequence/test.py +376 -0
  3. sudoku/convert.py +81 -0
  4. sudoku/convert_wan.py +1287 -0
  5. sudoku/generate_dataset.py +424 -0
  6. sudoku/jsonl_to_csv.py +22 -0
  7. sudoku/simplify_dataset.py +19 -0
  8. sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00001-of-00007.safetensors +3 -0
  9. sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00002-of-00007.safetensors +3 -0
  10. sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00003-of-00007.safetensors +3 -0
  11. sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00004-of-00007.safetensors +3 -0
  12. sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00005-of-00007.safetensors +3 -0
  13. sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00006-of-00007.safetensors +3 -0
  14. sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00007-of-00007.safetensors +3 -0
  15. sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0.safetensors +3 -0
  16. sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-0.safetensors +3 -0
  17. sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-1.safetensors +3 -0
  18. sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-2.safetensors +3 -0
  19. sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-3.safetensors +3 -0
  20. sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-4.safetensors +3 -0
  21. sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31/epoch-3.safetensors +3 -0
  22. sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00001-of-00007.safetensors +3 -0
  23. sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00002-of-00007.safetensors +3 -0
  24. sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00003-of-00007.safetensors +3 -0
  25. sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00004-of-00007.safetensors +3 -0
  26. sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00005-of-00007.safetensors +3 -0
  27. sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00006-of-00007.safetensors +3 -0
  28. sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00007-of-00007.safetensors +3 -0
  29. sudoku/sudoku_processor.py +479 -0
sequence/data_generation.py ADDED
@@ -0,0 +1,336 @@
1
+ """
2
+ Sequence Prediction Dataset Generator.
3
+
4
+ Generates image pairs for sequence prediction tasks with various
5
+ mathematical sequences (arithmetic, geometric, fibonacci, etc.)
6
+ """
7
+
8
+ import json
9
+ import random
10
+ from pathlib import Path
11
+ from typing import Callable
12
+
13
+ import matplotlib.pyplot as plt
14
+ import matplotlib.patches as patches
15
+
16
+
17
+ # ============== Sequence Generators ==============
18
+
19
+ def arithmetic_seq(start: int, diff: int, length: int = 4) -> list[int]:
20
+ """Arithmetic sequence: a, a+d, a+2d, ..."""
21
+ return [start + i * diff for i in range(length)]
22
+
23
+
24
+ def geometric_seq(start: int, ratio: int, length: int = 4) -> list[int]:
25
+ """Geometric sequence: a, a*r, a*r^2, ..."""
26
+ return [start * (ratio ** i) for i in range(length)]
27
+
28
+
29
+ def square_seq(start: int, length: int = 4) -> list[int]:
30
+ """Square numbers: n^2, (n+1)^2, ..."""
31
+ return [(start + i) ** 2 for i in range(length)]
32
+
33
+
34
+ def cube_seq(start: int, length: int = 4) -> list[int]:
35
+ """Cube numbers: n^3, (n+1)^3, ..."""
36
+ return [(start + i) ** 3 for i in range(length)]
37
+
38
+
39
+ def triangular_seq(start: int, length: int = 4) -> list[int]:
40
+ """Triangular numbers: n(n+1)/2"""
41
+ return [(start + i) * (start + i + 1) // 2 for i in range(length)]
42
+
43
+
44
+ def fibonacci_like_seq(a: int, b: int, length: int = 4) -> list[int]:
45
+ """Fibonacci-like: a, b, a+b, a+2b, ..."""
46
+ seq = [a, b]
47
+ for _ in range(length - 2):
48
+ seq.append(seq[-1] + seq[-2])
49
+ return seq[:length]
50
+
51
+
52
+ def prime_seq(start_idx: int, length: int = 4) -> list[int]:
53
+ """Prime numbers starting from index."""
54
+ primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47]
55
+ return primes[start_idx:start_idx + length]
56
+
57
+
58
+ def power_of_two_seq(start: int, length: int = 4) -> list[int]:
59
+ """Powers of 2: 2^n, 2^(n+1), ..."""
60
+ return [2 ** (start + i) for i in range(length)]
61
+
62
+
63
+ def factorial_seq(start: int, length: int = 4) -> list[int]:
64
+ """Factorial sequence: n!, (n+1)!, ..."""
65
+ from math import factorial
66
+ return [factorial(start + i) for i in range(length)]
67
+
68
+
69
+ # ============== Sequence Factory ==============
70
+
71
+ SEQUENCE_TYPES = {
72
+ "arithmetic": lambda rng: arithmetic_seq(
73
+ rng.randint(1, 20), rng.randint(1, 10)
74
+ ),
75
+ "arithmetic_neg": lambda rng: arithmetic_seq(
76
+ rng.randint(20, 50), -rng.randint(1, 5)
77
+ ),
78
+ "geometric_2": lambda rng: geometric_seq(
79
+ rng.randint(1, 5), 2
80
+ ),
81
+ "geometric_3": lambda rng: geometric_seq(
82
+ rng.randint(1, 3), 3
83
+ ),
84
+ "square": lambda rng: square_seq(rng.randint(1, 10)),
85
+ "cube": lambda rng: cube_seq(rng.randint(1, 5)),
86
+ "triangular": lambda rng: triangular_seq(rng.randint(1, 10)),
87
+ "fibonacci": lambda rng: fibonacci_like_seq(
88
+ rng.randint(1, 5), rng.randint(1, 5)
89
+ ),
90
+ "prime": lambda rng: prime_seq(rng.randint(0, 10)),
91
+ "power_of_2": lambda rng: power_of_two_seq(rng.randint(0, 6)),
92
+ }
93
+
94
+
95
+ def generate_sequence_pair(seq: list[int]) -> tuple[list, list]:
96
+ """
97
+ Generate a pair of sequences for the task.
98
+
99
+ Returns:
100
+ (partial, complete): partial has last element as "", complete is full.
101
+ """
102
+ partial = seq[:-1] + [""]
103
+ return partial, seq
104
+
105
+
106
+ # ============== Image Generation ==============
107
+
108
+ def round_to_multiple(x: int, multiple: int = 16) -> int:
109
+ """Round x up to nearest multiple."""
110
+ return ((x + multiple - 1) // multiple) * multiple
111
+
112
+
113
+ def create_number_grid(
114
+ numbers: list,
115
+ save_path: str,
116
+ height: int = 224,
117
+ width: int = 896,
118
+ fontsize: int = 48,
119
+ size_multiple: int = 16,
120
+ ) -> None:
121
+ """
122
+ Create a 1xN grid image with numbers in each cell.
123
+
124
+ Args:
125
+ numbers: List of numbers/strings to display.
126
+ save_path: Output file path.
127
+ height: Target height in pixels (will be rounded to size_multiple).
128
+ width: Target width in pixels (will be rounded to size_multiple).
129
+ fontsize: Font size for the numbers.
130
+ size_multiple: Ensure dimensions are multiples of this (default 16).
131
+ """
132
+ from PIL import Image
133
+
134
+ n = len(numbers)
135
+
136
+ # Ensure dimensions are multiples of size_multiple
137
+ width = round_to_multiple(width, size_multiple)
138
+ height = round_to_multiple(height, size_multiple)
139
+
140
+ # Use fixed DPI and calculate figsize
141
+ dpi = 100
142
+ fig_width = width / dpi
143
+ fig_height = height / dpi
144
+
145
+ fig, ax = plt.subplots(figsize=(fig_width, fig_height), dpi=dpi)
146
+ fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
147
+
148
+ for i, num in enumerate(numbers):
149
+ rect = patches.Rectangle(
150
+ (i, 0), 1, 1, linewidth=2,
151
+ edgecolor='black', facecolor='white'
152
+ )
153
+ ax.add_patch(rect)
154
+ ax.text(
155
+ i + 0.5, 0.5, str(num), fontsize=fontsize,
156
+ ha='center', va='center', fontweight='bold'
157
+ )
158
+
159
+ ax.set_xlim(0, n)
160
+ ax.set_ylim(0, 1)
161
+ ax.set_aspect('equal')
162
+ ax.axis('off')
163
+
164
+ # Save with exact pixel dimensions
165
+ fig.savefig(save_path, dpi=dpi, facecolor='white', edgecolor='none')
166
+ plt.close(fig)
167
+
168
+ # Final resize to ensure exact dimensions (16 multiples)
169
+ img = Image.open(save_path)
170
+ if img.size != (width, height):
171
+ img = img.resize((width, height), Image.Resampling.LANCZOS)
172
+ img.save(save_path)
173
+
174
+
175
+ # ============== Dataset Generation ==============
176
+
177
+ class SequenceDatasetGenerator:
178
+ """Generate sequence prediction dataset with train/test splits."""
179
+
180
+ def __init__(
181
+ self,
182
+ output_dir: str,
183
+ seed: int = 42,
184
+ num_pairs: tuple[int, int] = (2, 3),
185
+ seq_types: list[str] | None = None,
186
+ image_height: int = 224,
187
+ image_width: int = 896,
188
+ fontsize: int = 48,
189
+ ):
190
+ """
191
+ Args:
192
+ output_dir: Directory to save the dataset.
193
+ seed: Random seed for reproducibility.
194
+ num_pairs: Range of pairs per sample (min, max inclusive).
195
+ seq_types: List of sequence types to use (None = all).
196
+ image_height: Image height in pixels (rounded to 16).
197
+ image_width: Image width in pixels (rounded to 16).
198
+ fontsize: Font size for numbers.
199
+ """
200
+ self.output_dir = Path(output_dir)
201
+ self.rng = random.Random(seed)
202
+ self.num_pairs = num_pairs
203
+ self.seq_types = seq_types or list(SEQUENCE_TYPES.keys())
204
+ self.image_height = round_to_multiple(image_height, 16)
205
+ self.image_width = round_to_multiple(image_width, 16)
206
+ self.fontsize = fontsize
207
+
208
+ # Create directories
209
+ for split in ["train", "test"]:
210
+ (self.output_dir / split / "images").mkdir(parents=True, exist_ok=True)
211
+
212
+ def _generate_sample(self, sample_id: int) -> dict:
213
+ """Generate a single sample with multiple sequence pairs."""
214
+ num_pairs = self.rng.randint(*self.num_pairs)
215
+ seq_type = self.rng.choice(self.seq_types)
216
+
217
+ # Generate base sequence and subsequent ones
218
+ base_seq = SEQUENCE_TYPES[seq_type](self.rng)
219
+
220
+ pairs = []
221
+ for i in range(num_pairs):
222
+ # Shift sequence for each pair
223
+ if seq_type.startswith("arithmetic"):
224
+ diff = base_seq[1] - base_seq[0]
225
+ seq = [x + i * diff for x in base_seq]
226
+ elif seq_type.startswith("geometric"):
227
+ ratio = base_seq[1] // base_seq[0] if base_seq[0] != 0 else 2
228
+ seq = [x * (ratio ** i) for x in base_seq]
229
+ else:
230
+ # For other types, regenerate with offset
231
+ seq = [x + i for x in base_seq]
232
+
233
+ partial, complete = generate_sequence_pair(seq)
234
+ pairs.append({
235
+ "partial": partial,
236
+ "complete": complete,
237
+ "answer": complete[-1],
238
+ })
239
+
240
+ return {
241
+ "id": sample_id,
242
+ "seq_type": seq_type,
243
+ "num_pairs": num_pairs,
244
+ "pairs": pairs,
245
+ }
246
+
247
+ def _save_sample_images(
248
+ self, sample: dict, split: str, include_last_answer: bool = True
249
+ ) -> dict:
250
+ """Save images for a sample and return metadata."""
251
+ sample_id = sample["id"]
252
+ image_dir = self.output_dir / split / "images"
253
+
254
+ images = []
255
+ img_idx = 0
256
+
257
+ for i, pair in enumerate(sample["pairs"]):
258
+ # Always save partial (query) image
259
+ partial_path = f"{sample_id:05d}_{img_idx}.png"
260
+ create_number_grid(
261
+ pair["partial"], image_dir / partial_path,
262
+ height=self.image_height, width=self.image_width,
263
+ fontsize=self.fontsize,
264
+ )
265
+ images.append(partial_path)
266
+ img_idx += 1
267
+
268
+ # Save complete image based on split logic
269
+ is_last = (i == sample["num_pairs"] - 1)
270
+ if include_last_answer or not is_last:
271
+ complete_path = f"{sample_id:05d}_{img_idx}.png"
272
+ create_number_grid(
273
+ pair["complete"], image_dir / complete_path,
274
+ height=self.image_height, width=self.image_width,
275
+ fontsize=self.fontsize,
276
+ )
277
+ images.append(complete_path)
278
+ img_idx += 1
279
+
280
+ return {
281
+ "id": sample_id,
282
+ "seq_type": sample["seq_type"],
283
+ "num_pairs": sample["num_pairs"],
284
+ "images": images,
285
+ "answer": sample["pairs"][-1]["answer"], # Last image's answer
286
+ "sequences": [p["complete"] for p in sample["pairs"]],
287
+ }
288
+
289
+ def generate(self, num_train: int, num_test: int) -> None:
290
+ """
291
+ Generate the full dataset.
292
+
293
+ Args:
294
+ num_train: Number of training samples.
295
+ num_test: Number of test samples.
296
+ """
297
+ train_meta, test_meta = [], []
298
+
299
+ # Generate training samples (all pairs complete)
300
+ print(f"Generating {num_train} training samples...")
301
+ for i in range(num_train):
302
+ sample = self._generate_sample(i)
303
+ meta = self._save_sample_images(sample, "train", include_last_answer=True)
304
+ train_meta.append(meta)
305
+ if (i + 1) % 50 == 0:
306
+ print(f" Train: {i + 1}/{num_train}")
307
+
308
+ # Generate test samples (last answer hidden)
309
+ print(f"Generating {num_test} test samples...")
310
+ for i in range(num_test):
311
+ sample = self._generate_sample(num_train + i)
312
+ meta = self._save_sample_images(sample, "test", include_last_answer=False)
313
+ test_meta.append(meta)
314
+ if (i + 1) % 50 == 0:
315
+ print(f" Test: {i + 1}/{num_test}")
316
+
317
+ # Save metadata
318
+ with open(self.output_dir / "train.json", "w") as f:
319
+ json.dump(train_meta, f, indent=2)
320
+ with open(self.output_dir / "test.json", "w") as f:
321
+ json.dump(test_meta, f, indent=2)
322
+
323
+ print(f"\nDataset saved to {self.output_dir}")
324
+ print(f" Train: {num_train} samples")
325
+ print(f" Test: {num_test} samples")
326
+ print(f" Image size: {self.image_width}x{self.image_height}")
327
+ print(f" Sequence types: {self.seq_types}")
328
+
329
+
330
+ if __name__ == "__main__":
331
+ generator = SequenceDatasetGenerator(
332
+ output_dir="/home/claude/sequence_dataset",
333
+ seed=42,
334
+ num_pairs=(2, 3),
335
+ )
336
+ generator.generate(num_train=100, num_test=20)
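For reference, a minimal sketch of how the generated metadata could be consumed downstream. Field names follow _save_sample_images above; the dataset root path is an assumption (the default dataset_dir used by the evaluation script below), not something fixed by this commit.

import json
from pathlib import Path

from PIL import Image

dataset_root = Path("sequence_dataset")  # assumed output_dir used by the generator

# train.json is the list of per-sample dicts written by SequenceDatasetGenerator.generate()
with open(dataset_root / "train.json") as f:
    samples = json.load(f)

sample = samples[0]
image_dir = dataset_root / "train" / "images"

# Each sample lists its grid images in order (partial, complete, partial, complete, ...)
images = [Image.open(image_dir / name) for name in sample["images"]]
print(sample["seq_type"], sample["answer"], [img.size for img in images])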
sequence/test.py ADDED
@@ -0,0 +1,376 @@
1
+ """
2
+ Sequence Prediction Evaluation with QwenImageEditPlusPipeline / Flux2KleinPipeline.
3
+
4
+ Evaluates the model's ability to predict the next number in a sequence
5
+ by generating images and extracting answers via OCR.
6
+ """
7
+
8
+ import json
9
+ import re
10
+ from pathlib import Path
11
+ from dataclasses import dataclass, field
12
+ from enum import Enum
13
+
14
+ import numpy as np
15
+ import torch
16
+ from PIL import Image
17
+ from tqdm import tqdm
18
+
19
+
20
+ class ModelType(str, Enum):
21
+ QWEN_IMAGE_EDIT = "qwen"
22
+ FLUX2_KLEIN = "flux2-klein"
23
+
24
+
25
+ @dataclass
26
+ class EvalConfig:
27
+ """Evaluation configuration."""
28
+ dataset_dir: str = "sequence_dataset"
29
+ output_dir: str = "eval_results"
30
+
31
+ # Model selection
32
+ model_type: ModelType = ModelType.QWEN_IMAGE_EDIT
33
+ model_id: str = "" # Auto-set based on model_type if empty
34
+
35
+ # Prompts
36
+ prompt: str = (
37
+ "Based on the number patterns shown in the previous images, "
38
+ "fill in the missing number in the empty cell of the last image."
39
+ )
40
+ negative_prompt: str = ""
41
+
42
+ # Generation params
43
+ num_inference_steps: int = 5
44
+ guidance_scale: float = 1.0
45
+ true_cfg_scale: float = 4.0 # For Qwen
46
+ height: int = 210
47
+ width: int = 750
48
+
49
+ seed: int = 42
50
+ device: str = "cuda"
51
+ dtype: torch.dtype = field(default_factory=lambda: torch.bfloat16)
52
+
53
+ def __post_init__(self):
54
+ """Set default model_id based on model_type."""
55
+ if not self.model_id:
56
+ if self.model_type == ModelType.QWEN_IMAGE_EDIT:
57
+ self.model_id = "Qwen/Qwen-Image-Edit-2509"
58
+ elif self.model_type == ModelType.FLUX2_KLEIN:
59
+ self.model_id = "black-forest-labs/FLUX.2-klein-9B"
60
+
61
+
62
+ class OCRExtractor:
63
+ """Extract numbers from grid images using OCR."""
64
+
65
+ def __init__(self, backend: str = "easyocr"):
66
+ """
67
+ Args:
68
+ backend: OCR backend ("easyocr" or "pytesseract").
69
+ """
70
+ self.backend = backend
71
+ if backend == "easyocr":
72
+ import easyocr
73
+ self.reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
74
+ elif backend == "pytesseract":
75
+ import pytesseract
76
+ self.pytesseract = pytesseract
77
+ else:
78
+ raise ValueError(f"Unknown backend: {backend}")
79
+
80
+ def extract_last_number(self, image: Image.Image) -> int | None:
81
+ """
82
+ Extract the last (rightmost) number from a grid image.
83
+
84
+ Args:
85
+ image: PIL Image of the number grid.
86
+
87
+ Returns:
88
+ Extracted number or None if extraction fails.
89
+ """
90
+ w, h = image.size
91
+ cell_crop = image.crop((w * 3 // 4, 0, w, h))
92
+ cell_array = np.array(cell_crop)
93
+
94
+ if self.backend == "easyocr":
95
+ results = self.reader.readtext(cell_array)
96
+ for _, text, conf in results:
97
+ digits = re.findall(r'-?\d+', text)
98
+ if digits:
99
+ return int(digits[0])
100
+
101
+ elif self.backend == "pytesseract":
102
+ text = self.pytesseract.image_to_string(
103
+ cell_crop, config='--psm 7 -c tessedit_char_whitelist=0123456789-'
104
+ )
105
+ digits = re.findall(r'-?\d+', text)
106
+ if digits:
107
+ return int(digits[0])
108
+
109
+ return None
110
+
111
+ def extract_all_numbers(self, image: Image.Image, num_cells: int = 4) -> list[int | None]:
112
+ """Extract all numbers from a grid image."""
113
+ w, h = image.size
114
+ cell_width = w // num_cells
115
+ numbers = []
116
+
117
+ for i in range(num_cells):
118
+ cell_crop = image.crop((i * cell_width, 0, (i + 1) * cell_width, h))
119
+ cell_array = np.array(cell_crop)
120
+
121
+ if self.backend == "easyocr":
122
+ results = self.reader.readtext(cell_array)
123
+ num = None
124
+ for _, text, conf in results:
125
+ digits = re.findall(r'-?\d+', text)
126
+ if digits:
127
+ num = int(digits[0])
128
+ break
129
+ numbers.append(num)
130
+
131
+ elif self.backend == "pytesseract":
132
+ text = self.pytesseract.image_to_string(
133
+ cell_crop, config='--psm 7 -c tessedit_char_whitelist=0123456789-'
134
+ )
135
+ digits = re.findall(r'-?\d+', text)
136
+ numbers.append(int(digits[0]) if digits else None)
137
+
138
+ return numbers
139
+
140
+
141
+ class SequenceEvaluator:
142
+ """Evaluator for sequence prediction task."""
143
+
144
+ def __init__(self, config: EvalConfig):
145
+ self.config = config
146
+ self.output_dir = Path(config.output_dir)
147
+ self.output_dir.mkdir(parents=True, exist_ok=True)
148
+
149
+ # Load pipeline based on model type
150
+ self.pipeline = self._load_pipeline()
151
+
152
+ # Initialize OCR
153
+ self.ocr = OCRExtractor(backend="easyocr")
154
+
155
+ def _load_pipeline(self):
156
+ """Load pipeline based on model type."""
157
+ if self.config.model_type == ModelType.QWEN_IMAGE_EDIT:
158
+ return self._load_qwen_pipeline()
159
+ elif self.config.model_type == ModelType.FLUX2_KLEIN:
160
+ return self._load_flux2_klein_pipeline()
161
+ else:
162
+ raise ValueError(f"Unknown model type: {self.config.model_type}")
163
+
164
+ def _load_qwen_pipeline(self):
165
+ """Load QwenImageEditPlusPipeline."""
166
+ from diffusers import QwenImageEditPlusPipeline
167
+
168
+ pipeline = QwenImageEditPlusPipeline.from_pretrained(
169
+ self.config.model_id,
170
+ torch_dtype=self.config.dtype,
171
+ )
172
+ pipeline.to(self.config.device)
173
+ pipeline.set_progress_bar_config(disable=True)
174
+ return pipeline
175
+
176
+ def _load_flux2_klein_pipeline(self):
177
+ """Load Flux2KleinPipeline."""
178
+ from diffusers import Flux2KleinPipeline
179
+
180
+ pipeline = Flux2KleinPipeline.from_pretrained(
181
+ self.config.model_id,
182
+ torch_dtype=self.config.dtype,
183
+ )
184
+ pipeline.enable_model_cpu_offload()
185
+ pipeline.set_progress_bar_config(disable=True)
186
+ return pipeline
187
+
188
+ def _load_images(self, image_paths: list[str], image_dir: Path) -> list[Image.Image]:
189
+ """Load images from paths."""
190
+ return [Image.open(image_dir / p).convert("RGB") for p in image_paths]
191
+
192
+ def predict(self, images: list[Image.Image]) -> Image.Image:
193
+ """
194
+ Generate prediction image given input images.
195
+
196
+ Args:
197
+ images: List of input images (context + query).
198
+
199
+ Returns:
200
+ Generated image with predicted number.
201
+ """
202
+ generator = torch.Generator(device=self.config.device).manual_seed(self.config.seed)
203
+
204
+ if self.config.model_type == ModelType.QWEN_IMAGE_EDIT:
205
+ inputs = {
206
+ "image": images,
207
+ "prompt": self.config.prompt,
208
+ "generator": generator,
209
+ "true_cfg_scale": self.config.true_cfg_scale,
210
+ "negative_prompt": self.config.negative_prompt,
211
+ "num_inference_steps": self.config.num_inference_steps,
212
+ }
213
+
214
+ elif self.config.model_type == ModelType.FLUX2_KLEIN:
215
+ # Flux2Klein uses image parameter for multi-image editing
216
+ inputs = {
217
+ "image": images,
218
+ "prompt": self.config.prompt,
219
+ "generator": generator,
220
+ "guidance_scale": self.config.guidance_scale,
221
+ "num_inference_steps": self.config.num_inference_steps,
222
+ "height": self.config.height,
223
+ "width": self.config.width,
224
+ }
225
+
226
+ with torch.inference_mode():
227
+ output = self.pipeline(**inputs)
228
+
229
+ return output.images[0]
230
+
231
+ def evaluate_sample(self, sample: dict, image_dir: Path) -> dict:
232
+ """
233
+ Evaluate a single sample.
234
+
235
+ Args:
236
+ sample: Sample metadata dict.
237
+ image_dir: Directory containing images.
238
+
239
+ Returns:
240
+ Evaluation result dict.
241
+ """
242
+ # Load input images (all available in test set)
243
+ images = self._load_images(sample["images"], image_dir)
244
+
245
+ # Generate prediction
246
+ pred_image = self.predict(images)
247
+
248
+ # Save prediction image
249
+ pred_path = self.output_dir / f"{sample['id']:05d}_pred.png"
250
+ pred_image.save(pred_path)
251
+
252
+ # Extract predicted number via OCR
253
+ pred_number = self.ocr.extract_last_number(pred_image)
254
+
255
+ # Get ground truth
256
+ gt_number = sample["answer"]
257
+
258
+ # Check correctness
259
+ correct = pred_number == gt_number
260
+
261
+ return {
262
+ "id": sample["id"],
263
+ "seq_type": sample["seq_type"],
264
+ "gt_answer": gt_number,
265
+ "pred_answer": pred_number,
266
+ "correct": correct,
267
+ "pred_image": str(pred_path),
268
+ }
269
+
270
+ def evaluate(self, split: str = "test") -> dict:
271
+ """
272
+ Evaluate on entire dataset split.
273
+
274
+ Args:
275
+ split: Dataset split ("train" or "test").
276
+
277
+ Returns:
278
+ Evaluation results summary.
279
+ """
280
+ dataset_dir = Path(self.config.dataset_dir)
281
+
282
+ # Load metadata
283
+ with open(dataset_dir / f"{split}.json") as f:
284
+ samples = json.load(f)
285
+
286
+ image_dir = dataset_dir / split / "images"
287
+
288
+ results = []
289
+ for sample in tqdm(samples, desc=f"Evaluating {split}"):
290
+ result = self.evaluate_sample(sample, image_dir)
291
+ results.append(result)
292
+
293
+ # Compute metrics
294
+ total = len(results)
295
+ correct = sum(r["correct"] for r in results)
296
+ accuracy = correct / total if total > 0 else 0.0
297
+
298
+ # Per-type accuracy
299
+ type_stats = {}
300
+ for r in results:
301
+ seq_type = r["seq_type"]
302
+ if seq_type not in type_stats:
303
+ type_stats[seq_type] = {"correct": 0, "total": 0}
304
+ type_stats[seq_type]["total"] += 1
305
+ if r["correct"]:
306
+ type_stats[seq_type]["correct"] += 1
307
+
308
+ type_accuracy = {
309
+ k: v["correct"] / v["total"] for k, v in type_stats.items()
310
+ }
311
+
312
+ summary = {
313
+ "split": split,
314
+ "model_type": self.config.model_type.value,
315
+ "model_id": self.config.model_id,
316
+ "total": total,
317
+ "correct": correct,
318
+ "accuracy": accuracy,
319
+ "type_accuracy": type_accuracy,
320
+ "results": results,
321
+ }
322
+
323
+ # Save results
324
+ with open(self.output_dir / f"{split}_results.json", "w") as f:
325
+ json.dump(summary, f, indent=2)
326
+
327
+ return summary
328
+
329
+
330
+ def main():
331
+ """Run evaluation."""
332
+ import argparse
333
+
334
+ parser = argparse.ArgumentParser(description="Sequence Prediction Evaluation")
335
+ parser.add_argument("--model", type=str, default="flux2-klein",
336
+ choices=["qwen", "flux2-klein"],
337
+ help="Model type to use")
338
+ parser.add_argument("--model-id", type=str, default="",
339
+ help="Custom model ID (optional)")
340
+ parser.add_argument("--dataset-dir", type=str, default="sequence_dataset",
341
+ help="Dataset directory")
342
+ parser.add_argument("--output-dir", type=str, default="eval_results",
343
+ help="Output directory")
344
+ parser.add_argument("--steps", type=int, default=50,
345
+ help="Number of inference steps")
346
+ parser.add_argument("--seed", type=int, default=42,
347
+ help="Random seed")
348
+ args = parser.parse_args()
349
+
350
+ config = EvalConfig(
351
+ dataset_dir=args.dataset_dir,
352
+ output_dir=args.output_dir,
353
+ model_type=ModelType(args.model),
354
+ model_id=args.model_id,
355
+ num_inference_steps=args.steps,
356
+ seed=args.seed,
357
+ )
358
+
359
+ print(f"Model: {config.model_type.value} ({config.model_id})")
360
+
361
+ evaluator = SequenceEvaluator(config)
362
+ results = evaluator.evaluate("test")
363
+
364
+ print(f"\n{'='*50}")
365
+ print(f"Evaluation Results ({config.model_type.value})")
366
+ print(f"{'='*50}")
367
+ print(f"Total samples: {results['total']}")
368
+ print(f"Correct: {results['correct']}")
369
+ print(f"Accuracy: {results['accuracy']:.2%}")
370
+ print(f"\nPer-type accuracy:")
371
+ for seq_type, acc in sorted(results["type_accuracy"].items()):
372
+ print(f" {seq_type}: {acc:.2%}")
373
+
374
+
375
+ if __name__ == "__main__":
376
+ main()
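Before launching a full evaluation run, the OCR path can be sanity-checked in isolation on a single ground-truth grid image. This is a hedged sketch: the image path is a placeholder, and the import assumes the snippet is run from the sequence/ directory so that test.py resolves locally.

from PIL import Image

from test import OCRExtractor  # OCRExtractor defined in sequence/test.py above

img = Image.open("sequence_dataset/test/images/00100_0.png").convert("RGB")  # placeholder path

ocr = OCRExtractor(backend="easyocr")
print("last cell:", ocr.extract_last_number(img))
print("all cells:", ocr.extract_all_numbers(img, num_cells=4))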
sudoku/convert.py ADDED
@@ -0,0 +1,81 @@
1
+ """
2
+ Convert a single safetensors file to the HuggingFace Diffusers format.
3
+
4
+ Usage:
5
+ python convert.py --ckpt epoch-4.safetensors --model_type Wan-T2V-14B --output_path ./output
6
+ """
7
+
8
+ import argparse
9
+ import torch
10
+ from safetensors.torch import load_file
11
+ from accelerate import init_empty_weights
12
+
13
+ # Imported from the original conversion script (or copy the relevant dicts and functions directly)
14
+ from convert_wan import (
15
+ get_transformer_config,
16
+ update_state_dict_,
17
+ DTYPE_MAPPING,
18
+ )
19
+ from diffusers import WanTransformer3DModel, WanVACETransformer3DModel, WanAnimateTransformer3DModel
20
+
21
+
22
+ def convert_single_checkpoint(ckpt_path: str, model_type: str, dtype: str = "bf16"):
23
+ """
24
+ Convert a single checkpoint file into a Diffusers-format transformer.
25
+
26
+ Args:
27
+ ckpt_path: Path to the safetensors file.
28
+ model_type: Model type, e.g. "Wan-T2V-14B", "Wan-I2V-14B-720p", etc.
29
+ dtype: Output precision.
30
+
31
+ Returns:
32
+ The converted transformer model.
33
+ """
34
+ # 1. Get the config and key-rename rules
35
+ config, rename_dict, special_keys_remap = get_transformer_config(model_type)
36
+ diffusers_config = config["diffusers_config"]
37
+
38
+ # 2. Load the original weights
39
+ state_dict = load_file(ckpt_path)
40
+
41
+ # 3. Rename keys
42
+ for key in list(state_dict.keys()):
43
+ new_key = key
44
+ for old, new in rename_dict.items():
45
+ new_key = new_key.replace(old, new)
46
+ update_state_dict_(state_dict, key, new_key)
47
+
48
+ # 4. Handle special keys
49
+ for key in list(state_dict.keys()):
50
+ for special_key, handler_fn in special_keys_remap.items():
51
+ if special_key in key:
52
+ handler_fn(key, state_dict)
53
+
54
+ # 5. Create the model and load the weights
55
+ with init_empty_weights():
56
+ if "Animate" in model_type:
57
+ transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
58
+ elif "VACE" in model_type:
59
+ transformer = WanVACETransformer3DModel.from_config(diffusers_config)
60
+ else:
61
+ transformer = WanTransformer3DModel.from_config(diffusers_config)
62
+
63
+ transformer.load_state_dict(state_dict, strict=True, assign=True)
64
+
65
+ if dtype != "none":
66
+ transformer = transformer.to(DTYPE_MAPPING[dtype])
67
+
68
+ return transformer
69
+
70
+
71
+ if __name__ == "__main__":
72
+ parser = argparse.ArgumentParser()
73
+ parser.add_argument("--ckpt", type=str, required=True, help="safetensors文件路径")
74
+ parser.add_argument("--model_type", type=str, required=True, help="模型类型")
75
+ parser.add_argument("--output_path", type=str, required=True, help="输出目录")
76
+ parser.add_argument("--dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16", "none"])
77
+ args = parser.parse_args()
78
+
79
+ transformer = convert_single_checkpoint(args.ckpt, args.model_type, args.dtype)
80
+ transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
81
+ print(f"Saved to {args.output_path}")
sudoku/convert_wan.py ADDED
@@ -0,0 +1,1287 @@
1
+ import argparse
2
+ import pathlib
3
+ from typing import Any, Dict, Tuple
4
+
5
+ import torch
6
+ from accelerate import init_empty_weights
7
+ from huggingface_hub import hf_hub_download, snapshot_download
8
+ from safetensors.torch import load_file
9
+ from transformers import (
10
+ AutoProcessor,
11
+ AutoTokenizer,
12
+ CLIPImageProcessor,
13
+ CLIPVisionModel,
14
+ CLIPVisionModelWithProjection,
15
+ UMT5EncoderModel,
16
+ )
17
+
18
+ from diffusers import (
19
+ AutoencoderKLWan,
20
+ UniPCMultistepScheduler,
21
+ WanAnimatePipeline,
22
+ WanAnimateTransformer3DModel,
23
+ WanImageToVideoPipeline,
24
+ WanPipeline,
25
+ WanTransformer3DModel,
26
+ WanVACEPipeline,
27
+ WanVACETransformer3DModel,
28
+ )
29
+
30
+
31
+ TRANSFORMER_KEYS_RENAME_DICT = {
32
+ "time_embedding.0": "condition_embedder.time_embedder.linear_1",
33
+ "time_embedding.2": "condition_embedder.time_embedder.linear_2",
34
+ "text_embedding.0": "condition_embedder.text_embedder.linear_1",
35
+ "text_embedding.2": "condition_embedder.text_embedder.linear_2",
36
+ "time_projection.1": "condition_embedder.time_proj",
37
+ "head.modulation": "scale_shift_table",
38
+ "head.head": "proj_out",
39
+ "modulation": "scale_shift_table",
40
+ "ffn.0": "ffn.net.0.proj",
41
+ "ffn.2": "ffn.net.2",
42
+ # Hack to swap the layer names
43
+ # The original model calls the norms in following order: norm1, norm3, norm2
44
+ # We convert it to: norm1, norm2, norm3
45
+ "norm2": "norm__placeholder",
46
+ "norm3": "norm2",
47
+ "norm__placeholder": "norm3",
48
+ # For the I2V model
49
+ "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
50
+ "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
51
+ "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
52
+ "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
53
+ # for the FLF2V model
54
+ "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
55
+ # Add attention component mappings
56
+ "self_attn.q": "attn1.to_q",
57
+ "self_attn.k": "attn1.to_k",
58
+ "self_attn.v": "attn1.to_v",
59
+ "self_attn.o": "attn1.to_out.0",
60
+ "self_attn.norm_q": "attn1.norm_q",
61
+ "self_attn.norm_k": "attn1.norm_k",
62
+ "cross_attn.q": "attn2.to_q",
63
+ "cross_attn.k": "attn2.to_k",
64
+ "cross_attn.v": "attn2.to_v",
65
+ "cross_attn.o": "attn2.to_out.0",
66
+ "cross_attn.norm_q": "attn2.norm_q",
67
+ "cross_attn.norm_k": "attn2.norm_k",
68
+ "attn2.to_k_img": "attn2.add_k_proj",
69
+ "attn2.to_v_img": "attn2.add_v_proj",
70
+ "attn2.norm_k_img": "attn2.norm_added_k",
71
+ }
72
+
73
+ VACE_TRANSFORMER_KEYS_RENAME_DICT = {
74
+ "time_embedding.0": "condition_embedder.time_embedder.linear_1",
75
+ "time_embedding.2": "condition_embedder.time_embedder.linear_2",
76
+ "text_embedding.0": "condition_embedder.text_embedder.linear_1",
77
+ "text_embedding.2": "condition_embedder.text_embedder.linear_2",
78
+ "time_projection.1": "condition_embedder.time_proj",
79
+ "head.modulation": "scale_shift_table",
80
+ "head.head": "proj_out",
81
+ "modulation": "scale_shift_table",
82
+ "ffn.0": "ffn.net.0.proj",
83
+ "ffn.2": "ffn.net.2",
84
+ # Hack to swap the layer names
85
+ # The original model calls the norms in following order: norm1, norm3, norm2
86
+ # We convert it to: norm1, norm2, norm3
87
+ "norm2": "norm__placeholder",
88
+ "norm3": "norm2",
89
+ "norm__placeholder": "norm3",
90
+ # # For the I2V model
91
+ # "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
92
+ # "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
93
+ # "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
94
+ # "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
95
+ # # for the FLF2V model
96
+ # "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
97
+ # Add attention component mappings
98
+ "self_attn.q": "attn1.to_q",
99
+ "self_attn.k": "attn1.to_k",
100
+ "self_attn.v": "attn1.to_v",
101
+ "self_attn.o": "attn1.to_out.0",
102
+ "self_attn.norm_q": "attn1.norm_q",
103
+ "self_attn.norm_k": "attn1.norm_k",
104
+ "cross_attn.q": "attn2.to_q",
105
+ "cross_attn.k": "attn2.to_k",
106
+ "cross_attn.v": "attn2.to_v",
107
+ "cross_attn.o": "attn2.to_out.0",
108
+ "cross_attn.norm_q": "attn2.norm_q",
109
+ "cross_attn.norm_k": "attn2.norm_k",
110
+ "attn2.to_k_img": "attn2.add_k_proj",
111
+ "attn2.to_v_img": "attn2.add_v_proj",
112
+ "attn2.norm_k_img": "attn2.norm_added_k",
113
+ "before_proj": "proj_in",
114
+ "after_proj": "proj_out",
115
+ }
116
+
117
+ ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = {
118
+ "time_embedding.0": "condition_embedder.time_embedder.linear_1",
119
+ "time_embedding.2": "condition_embedder.time_embedder.linear_2",
120
+ "text_embedding.0": "condition_embedder.text_embedder.linear_1",
121
+ "text_embedding.2": "condition_embedder.text_embedder.linear_2",
122
+ "time_projection.1": "condition_embedder.time_proj",
123
+ "head.modulation": "scale_shift_table",
124
+ "head.head": "proj_out",
125
+ "modulation": "scale_shift_table",
126
+ "ffn.0": "ffn.net.0.proj",
127
+ "ffn.2": "ffn.net.2",
128
+ # Hack to swap the layer names
129
+ # The original model calls the norms in following order: norm1, norm3, norm2
130
+ # We convert it to: norm1, norm2, norm3
131
+ "norm2": "norm__placeholder",
132
+ "norm3": "norm2",
133
+ "norm__placeholder": "norm3",
134
+ "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
135
+ "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
136
+ "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
137
+ "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
138
+ # Add attention component mappings
139
+ "self_attn.q": "attn1.to_q",
140
+ "self_attn.k": "attn1.to_k",
141
+ "self_attn.v": "attn1.to_v",
142
+ "self_attn.o": "attn1.to_out.0",
143
+ "self_attn.norm_q": "attn1.norm_q",
144
+ "self_attn.norm_k": "attn1.norm_k",
145
+ "cross_attn.q": "attn2.to_q",
146
+ "cross_attn.k": "attn2.to_k",
147
+ "cross_attn.v": "attn2.to_v",
148
+ "cross_attn.o": "attn2.to_out.0",
149
+ "cross_attn.norm_q": "attn2.norm_q",
150
+ "cross_attn.norm_k": "attn2.norm_k",
151
+ "cross_attn.k_img": "attn2.to_k_img",
152
+ "cross_attn.v_img": "attn2.to_v_img",
153
+ "cross_attn.norm_k_img": "attn2.norm_k_img",
154
+ # After cross_attn -> attn2 rename, we need to rename the img keys
155
+ "attn2.to_k_img": "attn2.add_k_proj",
156
+ "attn2.to_v_img": "attn2.add_v_proj",
157
+ "attn2.norm_k_img": "attn2.norm_added_k",
158
+ # Wan Animate-specific mappings (motion encoder, face encoder, face adapter)
159
+ # Motion encoder mappings
160
+ # The name mapping is complicated for the convolutional part so we handle that in its own function
161
+ "motion_encoder.enc.fc": "motion_encoder.motion_network",
162
+ "motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
163
+ # Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten
164
+ "face_encoder.conv1_local.conv": "face_encoder.conv1_local",
165
+ "face_encoder.conv2.conv": "face_encoder.conv2",
166
+ "face_encoder.conv3.conv": "face_encoder.conv3",
167
+ # Face adapter mappings are handled in a separate function
168
+ }
169
+
170
+
171
+ # TODO: Verify this and simplify if possible.
172
+ def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None:
173
+ """
174
+ Convert all motion encoder weights for Animate model.
175
+
176
+ In the original model:
177
+ - All Linear layers in fc use EqualLinear
178
+ - All Conv2d layers in convs use EqualConv2d (except blur_conv which is initialized separately)
179
+ - Blur kernels are stored as buffers in Sequential modules
180
+ - ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)]
181
+
182
+ Conversion strategy:
183
+ 1. Drop .kernel buffers (blur kernels)
184
+ 2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu)
185
+ """
186
+ # Skip if not a weight, bias, or kernel
187
+ if ".weight" not in key and ".bias" not in key and ".kernel" not in key:
188
+ return
189
+
190
+ # Handle Blur kernel buffers from original implementation.
191
+ # After renaming, these appear under: motion_encoder.res_blocks.*.conv{2,skip}.blur_kernel
192
+ # Diffusers constructs blur kernels as a non-persistent buffer so we must drop these keys
193
+ if ".kernel" in key and "motion_encoder" in key:
194
+ # Remove unexpected blur kernel buffers to avoid strict load errors
195
+ state_dict.pop(key, None)
196
+ return
197
+
198
+ # Rename Sequential indices to named components in ConvLayer and ResBlock
199
+ if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key):
200
+ parts = key.split(".")
201
+
202
+ # Find the sequential index (digit) after convs or after conv1/conv2/skip
203
+ # Examples:
204
+ # - enc.net_app.convs.0.0.weight -> conv_in.weight (initial conv layer weight)
205
+ # - enc.net_app.convs.0.1.bias -> conv_in.act_fn.bias (initial conv layer bias)
206
+ # - enc.net_app.convs.{n:1-7}.conv1.0.weight -> res_blocks.{(n-1):0-6}.conv1.weight (conv1 weight)
207
+ # - e.g. enc.net_app.convs.1.conv1.0.weight -> res_blocks.0.conv1.weight
208
+ # - enc.net_app.convs.{n:1-7}.conv1.1.bias -> res_blocks.{(n-1):0-6}.conv1.act_fn.bias (conv1 bias)
209
+ # - e.g. enc.net_app.convs.1.conv1.1.bias -> res_blocks.0.conv1.act_fn.bias
210
+ # - enc.net_app.convs.{n:1-7}.conv2.1.weight -> res_blocks.{(n-1):0-6}.conv2.weight (conv2 weight)
211
+ # - enc.net_app.convs.1.conv2.2.bias -> res_blocks.0.conv2.act_fn.bias (conv2 bias)
212
+ # - enc.net_app.convs.{n:1-7}.skip.1.weight -> res_blocks.{(n-1):0-6}.conv_skip.weight (skip conv weight)
213
+ # - enc.net_app.convs.8 -> conv_out (final conv layer)
214
+
215
+ convs_idx = parts.index("convs") if "convs" in parts else -1
216
+ if convs_idx >= 0 and len(parts) - convs_idx >= 2:
217
+ bias = False
218
+ # The nn.Sequential index will always follow convs
219
+ sequential_idx = int(parts[convs_idx + 1])
220
+ if sequential_idx == 0:
221
+ if key.endswith(".weight"):
222
+ new_key = "motion_encoder.conv_in.weight"
223
+ elif key.endswith(".bias"):
224
+ new_key = "motion_encoder.conv_in.act_fn.bias"
225
+ bias = True
226
+ elif sequential_idx == final_conv_idx:
227
+ if key.endswith(".weight"):
228
+ new_key = "motion_encoder.conv_out.weight"
229
+ else:
230
+ # Intermediate .convs. layers, which get mapped to .res_blocks.
231
+ prefix = "motion_encoder.res_blocks."
232
+
233
+ layer_name = parts[convs_idx + 2]
234
+ if layer_name == "skip":
235
+ layer_name = "conv_skip"
236
+
237
+ if key.endswith(".weight"):
238
+ param_name = "weight"
239
+ elif key.endswith(".bias"):
240
+ param_name = "act_fn.bias"
241
+ bias = True
242
+
243
+ suffix_parts = [str(sequential_idx - 1), layer_name, param_name]
244
+ suffix = ".".join(suffix_parts)
245
+ new_key = prefix + suffix
246
+
247
+ param = state_dict.pop(key)
248
+ if bias:
249
+ param = param.squeeze()
250
+ state_dict[new_key] = param
251
+ return
252
+ return
253
+ return
254
+
255
+
256
+ def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None:
257
+ """
258
+ Convert face adapter weights for the Animate model.
259
+
260
+ The original model uses a fused KV projection but the diffusers models uses separate K and V projections.
261
+ """
262
+ # Skip if not a weight or bias
263
+ if ".weight" not in key and ".bias" not in key:
264
+ return
265
+
266
+ prefix = "face_adapter."
267
+ if ".fuser_blocks." in key:
268
+ parts = key.split(".")
269
+
270
+ module_list_idx = parts.index("fuser_blocks") if "fuser_blocks" in parts else -1
271
+ if module_list_idx >= 0 and (len(parts) - 1) - module_list_idx == 3:
272
+ block_idx = parts[module_list_idx + 1]
273
+ layer_name = parts[module_list_idx + 2]
274
+ param_name = parts[module_list_idx + 3]
275
+
276
+ if layer_name == "linear1_kv":
277
+ layer_name_k = "to_k"
278
+ layer_name_v = "to_v"
279
+
280
+ suffix_k = ".".join([block_idx, layer_name_k, param_name])
281
+ suffix_v = ".".join([block_idx, layer_name_v, param_name])
282
+ new_key_k = prefix + suffix_k
283
+ new_key_v = prefix + suffix_v
284
+
285
+ kv_proj = state_dict.pop(key)
286
+ k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0)
287
+ state_dict[new_key_k] = k_proj
288
+ state_dict[new_key_v] = v_proj
289
+ return
290
+ else:
291
+ if layer_name == "q_norm":
292
+ new_layer_name = "norm_q"
293
+ elif layer_name == "k_norm":
294
+ new_layer_name = "norm_k"
295
+ elif layer_name == "linear1_q":
296
+ new_layer_name = "to_q"
297
+ elif layer_name == "linear2":
298
+ new_layer_name = "to_out"
299
+
300
+ suffix_parts = [block_idx, new_layer_name, param_name]
301
+ suffix = ".".join(suffix_parts)
302
+ new_key = prefix + suffix
303
+ state_dict[new_key] = state_dict.pop(key)
304
+ return
305
+ return
306
+
307
+
308
+ TRANSFORMER_SPECIAL_KEYS_REMAP = {}
309
+ VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
310
+ ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
311
+ "motion_encoder": convert_animate_motion_encoder_weights,
312
+ "face_adapter": convert_animate_face_adapter_weights,
313
+ }
314
+
315
+
316
+ def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
317
+ state_dict[new_key] = state_dict.pop(old_key)
318
+
319
+
320
+ def load_sharded_safetensors(dir: pathlib.Path):
321
+ file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
322
+ state_dict = {}
323
+ for path in file_paths:
324
+ state_dict.update(load_file(path))
325
+ return state_dict
326
+
327
+
328
+ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
329
+ if model_type == "Wan-T2V-1.3B":
330
+ config = {
331
+ "model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
332
+ "diffusers_config": {
333
+ "added_kv_proj_dim": None,
334
+ "attention_head_dim": 128,
335
+ "cross_attn_norm": True,
336
+ "eps": 1e-06,
337
+ "ffn_dim": 8960,
338
+ "freq_dim": 256,
339
+ "in_channels": 16,
340
+ "num_attention_heads": 12,
341
+ "num_layers": 30,
342
+ "out_channels": 16,
343
+ "patch_size": [1, 2, 2],
344
+ "qk_norm": "rms_norm_across_heads",
345
+ "text_dim": 4096,
346
+ },
347
+ }
348
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
349
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
350
+ elif model_type == "Wan-T2V-14B":
351
+ config = {
352
+ "model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
353
+ "diffusers_config": {
354
+ "added_kv_proj_dim": None,
355
+ "attention_head_dim": 128,
356
+ "cross_attn_norm": True,
357
+ "eps": 1e-06,
358
+ "ffn_dim": 13824,
359
+ "freq_dim": 256,
360
+ "in_channels": 16,
361
+ "num_attention_heads": 40,
362
+ "num_layers": 40,
363
+ "out_channels": 16,
364
+ "patch_size": [1, 2, 2],
365
+ "qk_norm": "rms_norm_across_heads",
366
+ "text_dim": 4096,
367
+ },
368
+ }
369
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
370
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
371
+ elif model_type == "Wan-I2V-14B-480p":
372
+ config = {
373
+ "model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
374
+ "diffusers_config": {
375
+ "image_dim": 1280,
376
+ "added_kv_proj_dim": 5120,
377
+ "attention_head_dim": 128,
378
+ "cross_attn_norm": True,
379
+ "eps": 1e-06,
380
+ "ffn_dim": 13824,
381
+ "freq_dim": 256,
382
+ "in_channels": 36,
383
+ "num_attention_heads": 40,
384
+ "num_layers": 40,
385
+ "out_channels": 16,
386
+ "patch_size": [1, 2, 2],
387
+ "qk_norm": "rms_norm_across_heads",
388
+ "text_dim": 4096,
389
+ },
390
+ }
391
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
392
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
393
+ elif model_type == "Wan-I2V-14B-720p":
394
+ config = {
395
+ "model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
396
+ "diffusers_config": {
397
+ "image_dim": 1280,
398
+ "added_kv_proj_dim": 5120,
399
+ "attention_head_dim": 128,
400
+ "cross_attn_norm": True,
401
+ "eps": 1e-06,
402
+ "ffn_dim": 13824,
403
+ "freq_dim": 256,
404
+ "in_channels": 36,
405
+ "num_attention_heads": 40,
406
+ "num_layers": 40,
407
+ "out_channels": 16,
408
+ "patch_size": [1, 2, 2],
409
+ "qk_norm": "rms_norm_across_heads",
410
+ "text_dim": 4096,
411
+ },
412
+ }
413
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
414
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
415
+ elif model_type == "Wan-FLF2V-14B-720P":
416
+ config = {
417
+ "model_id": "ypyp/Wan2.1-FLF2V-14B-720P", # This is just a placeholder
418
+ "diffusers_config": {
419
+ "image_dim": 1280,
420
+ "added_kv_proj_dim": 5120,
421
+ "attention_head_dim": 128,
422
+ "cross_attn_norm": True,
423
+ "eps": 1e-06,
424
+ "ffn_dim": 13824,
425
+ "freq_dim": 256,
426
+ "in_channels": 36,
427
+ "num_attention_heads": 40,
428
+ "num_layers": 40,
429
+ "out_channels": 16,
430
+ "patch_size": [1, 2, 2],
431
+ "qk_norm": "rms_norm_across_heads",
432
+ "text_dim": 4096,
433
+ "rope_max_seq_len": 1024,
434
+ "pos_embed_seq_len": 257 * 2,
435
+ },
436
+ }
437
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
438
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
439
+ elif model_type == "Wan-VACE-1.3B":
440
+ config = {
441
+ "model_id": "Wan-AI/Wan2.1-VACE-1.3B",
442
+ "diffusers_config": {
443
+ "added_kv_proj_dim": None,
444
+ "attention_head_dim": 128,
445
+ "cross_attn_norm": True,
446
+ "eps": 1e-06,
447
+ "ffn_dim": 8960,
448
+ "freq_dim": 256,
449
+ "in_channels": 16,
450
+ "num_attention_heads": 12,
451
+ "num_layers": 30,
452
+ "out_channels": 16,
453
+ "patch_size": [1, 2, 2],
454
+ "qk_norm": "rms_norm_across_heads",
455
+ "text_dim": 4096,
456
+ "vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
457
+ "vace_in_channels": 96,
458
+ },
459
+ }
460
+ RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
461
+ SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
462
+ elif model_type == "Wan-VACE-14B":
463
+ config = {
464
+ "model_id": "Wan-AI/Wan2.1-VACE-14B",
465
+ "diffusers_config": {
466
+ "added_kv_proj_dim": None,
467
+ "attention_head_dim": 128,
468
+ "cross_attn_norm": True,
469
+ "eps": 1e-06,
470
+ "ffn_dim": 13824,
471
+ "freq_dim": 256,
472
+ "in_channels": 16,
473
+ "num_attention_heads": 40,
474
+ "num_layers": 40,
475
+ "out_channels": 16,
476
+ "patch_size": [1, 2, 2],
477
+ "qk_norm": "rms_norm_across_heads",
478
+ "text_dim": 4096,
479
+ "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
480
+ "vace_in_channels": 96,
481
+ },
482
+ }
483
+ RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
484
+ SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
485
+ elif model_type == "Wan2.2-VACE-Fun-14B":
486
+ config = {
487
+ "model_id": "alibaba-pai/Wan2.2-VACE-Fun-A14B",
488
+ "diffusers_config": {
489
+ "added_kv_proj_dim": None,
490
+ "attention_head_dim": 128,
491
+ "cross_attn_norm": True,
492
+ "eps": 1e-06,
493
+ "ffn_dim": 13824,
494
+ "freq_dim": 256,
495
+ "in_channels": 16,
496
+ "num_attention_heads": 40,
497
+ "num_layers": 40,
498
+ "out_channels": 16,
499
+ "patch_size": [1, 2, 2],
500
+ "qk_norm": "rms_norm_across_heads",
501
+ "text_dim": 4096,
502
+ "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
503
+ "vace_in_channels": 96,
504
+ },
505
+ }
506
+ RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
507
+ SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
508
+ elif model_type == "Wan2.2-I2V-14B-720p":
509
+ config = {
510
+ "model_id": "Wan-AI/Wan2.2-I2V-A14B",
511
+ "diffusers_config": {
512
+ "added_kv_proj_dim": None,
513
+ "attention_head_dim": 128,
514
+ "cross_attn_norm": True,
515
+ "eps": 1e-06,
516
+ "ffn_dim": 13824,
517
+ "freq_dim": 256,
518
+ "in_channels": 36,
519
+ "num_attention_heads": 40,
520
+ "num_layers": 40,
521
+ "out_channels": 16,
522
+ "patch_size": [1, 2, 2],
523
+ "qk_norm": "rms_norm_across_heads",
524
+ "text_dim": 4096,
525
+ },
526
+ }
527
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
528
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
529
+ elif model_type == "Wan2.2-T2V-A14B":
530
+ config = {
531
+ "model_id": "Wan-AI/Wan2.2-T2V-A14B",
532
+ "diffusers_config": {
533
+ "added_kv_proj_dim": None,
534
+ "attention_head_dim": 128,
535
+ "cross_attn_norm": True,
536
+ "eps": 1e-06,
537
+ "ffn_dim": 13824,
538
+ "freq_dim": 256,
539
+ "in_channels": 16,
540
+ "num_attention_heads": 40,
541
+ "num_layers": 40,
542
+ "out_channels": 16,
543
+ "patch_size": [1, 2, 2],
544
+ "qk_norm": "rms_norm_across_heads",
545
+ "text_dim": 4096,
546
+ },
547
+ }
548
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
549
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
550
+ elif model_type == "Wan2.2-TI2V-5B":
551
+ config = {
552
+ "model_id": "Wan-AI/Wan2.2-TI2V-5B",
553
+ "diffusers_config": {
554
+ "added_kv_proj_dim": None,
555
+ "attention_head_dim": 128,
556
+ "cross_attn_norm": True,
557
+ "eps": 1e-06,
558
+ "ffn_dim": 14336,
559
+ "freq_dim": 256,
560
+ "in_channels": 48,
561
+ "num_attention_heads": 24,
562
+ "num_layers": 30,
563
+ "out_channels": 48,
564
+ "patch_size": [1, 2, 2],
565
+ "qk_norm": "rms_norm_across_heads",
566
+ "text_dim": 4096,
567
+ },
568
+ }
569
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
570
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
571
+ elif model_type == "Wan2.2-Animate-14B":
572
+ config = {
573
+ "model_id": "Wan-AI/Wan2.2-Animate-14B",
574
+ "diffusers_config": {
575
+ "image_dim": 1280,
576
+ "added_kv_proj_dim": 5120,
577
+ "attention_head_dim": 128,
578
+ "cross_attn_norm": True,
579
+ "eps": 1e-06,
580
+ "ffn_dim": 13824,
581
+ "freq_dim": 256,
582
+ "in_channels": 36,
583
+ "num_attention_heads": 40,
584
+ "num_layers": 40,
585
+ "out_channels": 16,
586
+ "patch_size": (1, 2, 2),
587
+ "qk_norm": "rms_norm_across_heads",
588
+ "text_dim": 4096,
589
+ "rope_max_seq_len": 1024,
590
+ "pos_embed_seq_len": None,
591
+ "motion_encoder_size": 512, # Start of Wan Animate-specific configs
592
+ "motion_style_dim": 512,
593
+ "motion_dim": 20,
594
+ "motion_encoder_dim": 512,
595
+ "face_encoder_hidden_dim": 1024,
596
+ "face_encoder_num_heads": 4,
597
+ "inject_face_latents_blocks": 5,
598
+ },
599
+ }
600
+ RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
601
+ SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP
602
+ return config, RENAME_DICT, SPECIAL_KEYS_REMAP
603
+
604
+
605
+ def convert_transformer(model_type: str, stage: str = None):
606
+ config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)
607
+
608
+ diffusers_config = config["diffusers_config"]
609
+ model_id = config["model_id"]
610
+ model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
611
+
612
+ if stage is not None:
613
+ model_dir = model_dir / stage
614
+
615
+ original_state_dict = load_sharded_safetensors(model_dir)
616
+
617
+ with init_empty_weights():
618
+ if "Animate" in model_type:
619
+ transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
620
+ elif "VACE" in model_type:
621
+ transformer = WanVACETransformer3DModel.from_config(diffusers_config)
622
+ else:
623
+ transformer = WanTransformer3DModel.from_config(diffusers_config)
624
+
625
+ for key in list(original_state_dict.keys()):
626
+ new_key = key[:]
627
+ for replace_key, rename_key in RENAME_DICT.items():
628
+ new_key = new_key.replace(replace_key, rename_key)
629
+ update_state_dict_(original_state_dict, key, new_key)
630
+
631
+ for key in list(original_state_dict.keys()):
632
+ for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
633
+ if special_key not in key:
634
+ continue
635
+ handler_fn_inplace(key, original_state_dict)
636
+
637
+ # Load state dict into the meta model, which will materialize the tensors
638
+ transformer.load_state_dict(original_state_dict, strict=True, assign=True)
639
+
640
+ # Move to CPU to ensure all tensors are materialized
641
+ transformer = transformer.to("cpu")
642
+
643
+ return transformer
644
+
645
+
646
+ def convert_vae():
647
+ vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.1-T2V-14B", "Wan2.1_VAE.pth")
648
+ old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
649
+ new_state_dict = {}
650
+
651
+ # Create mappings for specific components
652
+ middle_key_mapping = {
653
+ # Encoder middle block
654
+ "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
655
+ "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
656
+ "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
657
+ "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
658
+ "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
659
+ "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
660
+ "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
661
+ "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
662
+ "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
663
+ "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
664
+ "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
665
+ "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
666
+ # Decoder middle block
667
+ "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
668
+ "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
669
+ "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
670
+ "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
671
+ "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
672
+ "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
673
+ "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
674
+ "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
675
+ "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
676
+ "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
677
+ "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
678
+ "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
679
+ }
680
+
681
+ # Create a mapping for attention blocks
682
+ attention_mapping = {
683
+ # Encoder middle attention
684
+ "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
685
+ "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
686
+ "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
687
+ "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
688
+ "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
689
+ # Decoder middle attention
690
+ "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
691
+ "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
692
+ "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
693
+ "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
694
+ "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
695
+ }
696
+
697
+ # Create a mapping for the head components
698
+ head_mapping = {
699
+ # Encoder head
700
+ "encoder.head.0.gamma": "encoder.norm_out.gamma",
701
+ "encoder.head.2.bias": "encoder.conv_out.bias",
702
+ "encoder.head.2.weight": "encoder.conv_out.weight",
703
+ # Decoder head
704
+ "decoder.head.0.gamma": "decoder.norm_out.gamma",
705
+ "decoder.head.2.bias": "decoder.conv_out.bias",
706
+ "decoder.head.2.weight": "decoder.conv_out.weight",
707
+ }
708
+
709
+ # Create a mapping for the quant components
710
+ quant_mapping = {
711
+ "conv1.weight": "quant_conv.weight",
712
+ "conv1.bias": "quant_conv.bias",
713
+ "conv2.weight": "post_quant_conv.weight",
714
+ "conv2.bias": "post_quant_conv.bias",
715
+ }
716
+
717
+ # Process each key in the state dict
718
+ for key, value in old_state_dict.items():
719
+ # Handle middle block keys using the mapping
720
+ if key in middle_key_mapping:
721
+ new_key = middle_key_mapping[key]
722
+ new_state_dict[new_key] = value
723
+ # Handle attention blocks using the mapping
724
+ elif key in attention_mapping:
725
+ new_key = attention_mapping[key]
726
+ new_state_dict[new_key] = value
727
+ # Handle head keys using the mapping
728
+ elif key in head_mapping:
729
+ new_key = head_mapping[key]
730
+ new_state_dict[new_key] = value
731
+ # Handle quant keys using the mapping
732
+ elif key in quant_mapping:
733
+ new_key = quant_mapping[key]
734
+ new_state_dict[new_key] = value
735
+ # Handle encoder conv1
736
+ elif key == "encoder.conv1.weight":
737
+ new_state_dict["encoder.conv_in.weight"] = value
738
+ elif key == "encoder.conv1.bias":
739
+ new_state_dict["encoder.conv_in.bias"] = value
740
+ # Handle decoder conv1
741
+ elif key == "decoder.conv1.weight":
742
+ new_state_dict["decoder.conv_in.weight"] = value
743
+ elif key == "decoder.conv1.bias":
744
+ new_state_dict["decoder.conv_in.bias"] = value
745
+ # Handle encoder downsamples
746
+ elif key.startswith("encoder.downsamples."):
747
+ # Convert to down_blocks
748
+ new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
749
+
750
+ # Convert residual block naming but keep the original structure
751
+ if ".residual.0.gamma" in new_key:
752
+ new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
753
+ elif ".residual.2.bias" in new_key:
754
+ new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
755
+ elif ".residual.2.weight" in new_key:
756
+ new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
757
+ elif ".residual.3.gamma" in new_key:
758
+ new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
759
+ elif ".residual.6.bias" in new_key:
760
+ new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
761
+ elif ".residual.6.weight" in new_key:
762
+ new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
763
+ elif ".shortcut.bias" in new_key:
764
+ new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
765
+ elif ".shortcut.weight" in new_key:
766
+ new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
767
+
768
+ new_state_dict[new_key] = value
769
+
770
+ # Handle decoder upsamples
771
+ elif key.startswith("decoder.upsamples."):
772
+ # Convert to up_blocks
773
+ parts = key.split(".")
774
+ block_idx = int(parts[2])
775
+
776
+ # Group residual blocks
777
+ if "residual" in key:
778
+ if block_idx in [0, 1, 2]:
779
+ new_block_idx = 0
780
+ resnet_idx = block_idx
781
+ elif block_idx in [4, 5, 6]:
782
+ new_block_idx = 1
783
+ resnet_idx = block_idx - 4
784
+ elif block_idx in [8, 9, 10]:
785
+ new_block_idx = 2
786
+ resnet_idx = block_idx - 8
787
+ elif block_idx in [12, 13, 14]:
788
+ new_block_idx = 3
789
+ resnet_idx = block_idx - 12
790
+ else:
791
+ # Keep as is for other blocks
792
+ new_state_dict[key] = value
793
+ continue
794
+
795
+ # Convert residual block naming
796
+ if ".residual.0.gamma" in key:
797
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
798
+ elif ".residual.2.bias" in key:
799
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
800
+ elif ".residual.2.weight" in key:
801
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
802
+ elif ".residual.3.gamma" in key:
803
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
804
+ elif ".residual.6.bias" in key:
805
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
806
+ elif ".residual.6.weight" in key:
807
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
808
+ else:
809
+ new_key = key
810
+
811
+ new_state_dict[new_key] = value
812
+
813
+ # Handle shortcut connections
814
+ elif ".shortcut." in key:
815
+ if block_idx == 4:
816
+ new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
817
+ new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
818
+ else:
819
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
820
+ new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
821
+
822
+ new_state_dict[new_key] = value
823
+
824
+ # Handle upsamplers
825
+ elif ".resample." in key or ".time_conv." in key:
826
+ if block_idx == 3:
827
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
828
+ elif block_idx == 7:
829
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
830
+ elif block_idx == 11:
831
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
832
+ else:
833
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
834
+
835
+ new_state_dict[new_key] = value
836
+ else:
837
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
838
+ new_state_dict[new_key] = value
839
+ else:
840
+ # Keep other keys unchanged
841
+ new_state_dict[key] = value
842
+
843
+ with init_empty_weights():
844
+ vae = AutoencoderKLWan()
845
+ vae.load_state_dict(new_state_dict, strict=True, assign=True)
846
+ return vae
847
+
848
+
849
+ vae22_diffusers_config = {
850
+ "base_dim": 160,
851
+ "z_dim": 48,
852
+ "is_residual": True,
853
+ "in_channels": 12,
854
+ "out_channels": 12,
855
+ "decoder_base_dim": 256,
856
+ "scale_factor_temporal": 4,
857
+ "scale_factor_spatial": 16,
858
+ "patch_size": 2,
859
+ "latents_mean": [
860
+ -0.2289,
861
+ -0.0052,
862
+ -0.1323,
863
+ -0.2339,
864
+ -0.2799,
865
+ 0.0174,
866
+ 0.1838,
867
+ 0.1557,
868
+ -0.1382,
869
+ 0.0542,
870
+ 0.2813,
871
+ 0.0891,
872
+ 0.1570,
873
+ -0.0098,
874
+ 0.0375,
875
+ -0.1825,
876
+ -0.2246,
877
+ -0.1207,
878
+ -0.0698,
879
+ 0.5109,
880
+ 0.2665,
881
+ -0.2108,
882
+ -0.2158,
883
+ 0.2502,
884
+ -0.2055,
885
+ -0.0322,
886
+ 0.1109,
887
+ 0.1567,
888
+ -0.0729,
889
+ 0.0899,
890
+ -0.2799,
891
+ -0.1230,
892
+ -0.0313,
893
+ -0.1649,
894
+ 0.0117,
895
+ 0.0723,
896
+ -0.2839,
897
+ -0.2083,
898
+ -0.0520,
899
+ 0.3748,
900
+ 0.0152,
901
+ 0.1957,
902
+ 0.1433,
903
+ -0.2944,
904
+ 0.3573,
905
+ -0.0548,
906
+ -0.1681,
907
+ -0.0667,
908
+ ],
909
+ "latents_std": [
910
+ 0.4765,
911
+ 1.0364,
912
+ 0.4514,
913
+ 1.1677,
914
+ 0.5313,
915
+ 0.4990,
916
+ 0.4818,
917
+ 0.5013,
918
+ 0.8158,
919
+ 1.0344,
920
+ 0.5894,
921
+ 1.0901,
922
+ 0.6885,
923
+ 0.6165,
924
+ 0.8454,
925
+ 0.4978,
926
+ 0.5759,
927
+ 0.3523,
928
+ 0.7135,
929
+ 0.6804,
930
+ 0.5833,
931
+ 1.4146,
932
+ 0.8986,
933
+ 0.5659,
934
+ 0.7069,
935
+ 0.5338,
936
+ 0.4889,
937
+ 0.4917,
938
+ 0.4069,
939
+ 0.4999,
940
+ 0.6866,
941
+ 0.4093,
942
+ 0.5709,
943
+ 0.6065,
944
+ 0.6415,
945
+ 0.4944,
946
+ 0.5726,
947
+ 1.2042,
948
+ 0.5458,
949
+ 1.6887,
950
+ 0.3971,
951
+ 1.0600,
952
+ 0.3943,
953
+ 0.5537,
954
+ 0.5444,
955
+ 0.4089,
956
+ 0.7468,
957
+ 0.7744,
958
+ ],
959
+ "clip_output": False,
960
+ }
961
+
962
+
963
+ def convert_vae_22():
964
+ vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.2-TI2V-5B", "Wan2.2_VAE.pth")
965
+ old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
966
+ new_state_dict = {}
967
+
968
+ # Create mappings for specific components
969
+ middle_key_mapping = {
970
+ # Encoder middle block
971
+ "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
972
+ "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
973
+ "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
974
+ "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
975
+ "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
976
+ "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
977
+ "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
978
+ "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
979
+ "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
980
+ "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
981
+ "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
982
+ "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
983
+ # Decoder middle block
984
+ "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
985
+ "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
986
+ "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
987
+ "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
988
+ "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
989
+ "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
990
+ "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
991
+ "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
992
+ "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
993
+ "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
994
+ "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
995
+ "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
996
+ }
997
+
998
+ # Create a mapping for attention blocks
999
+ attention_mapping = {
1000
+ # Encoder middle attention
1001
+ "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
1002
+ "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
1003
+ "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
1004
+ "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
1005
+ "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
1006
+ # Decoder middle attention
1007
+ "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
1008
+ "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
1009
+ "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
1010
+ "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
1011
+ "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
1012
+ }
1013
+
1014
+ # Create a mapping for the head components
1015
+ head_mapping = {
1016
+ # Encoder head
1017
+ "encoder.head.0.gamma": "encoder.norm_out.gamma",
1018
+ "encoder.head.2.bias": "encoder.conv_out.bias",
1019
+ "encoder.head.2.weight": "encoder.conv_out.weight",
1020
+ # Decoder head
1021
+ "decoder.head.0.gamma": "decoder.norm_out.gamma",
1022
+ "decoder.head.2.bias": "decoder.conv_out.bias",
1023
+ "decoder.head.2.weight": "decoder.conv_out.weight",
1024
+ }
1025
+
1026
+ # Create a mapping for the quant components
1027
+ quant_mapping = {
1028
+ "conv1.weight": "quant_conv.weight",
1029
+ "conv1.bias": "quant_conv.bias",
1030
+ "conv2.weight": "post_quant_conv.weight",
1031
+ "conv2.bias": "post_quant_conv.bias",
1032
+ }
1033
+
1034
+ # Process each key in the state dict
1035
+ for key, value in old_state_dict.items():
1036
+ # Handle middle block keys using the mapping
1037
+ if key in middle_key_mapping:
1038
+ new_key = middle_key_mapping[key]
1039
+ new_state_dict[new_key] = value
1040
+ # Handle attention blocks using the mapping
1041
+ elif key in attention_mapping:
1042
+ new_key = attention_mapping[key]
1043
+ new_state_dict[new_key] = value
1044
+ # Handle head keys using the mapping
1045
+ elif key in head_mapping:
1046
+ new_key = head_mapping[key]
1047
+ new_state_dict[new_key] = value
1048
+ # Handle quant keys using the mapping
1049
+ elif key in quant_mapping:
1050
+ new_key = quant_mapping[key]
1051
+ new_state_dict[new_key] = value
1052
+ # Handle encoder conv1
1053
+ elif key == "encoder.conv1.weight":
1054
+ new_state_dict["encoder.conv_in.weight"] = value
1055
+ elif key == "encoder.conv1.bias":
1056
+ new_state_dict["encoder.conv_in.bias"] = value
1057
+ # Handle decoder conv1
1058
+ elif key == "decoder.conv1.weight":
1059
+ new_state_dict["decoder.conv_in.weight"] = value
1060
+ elif key == "decoder.conv1.bias":
1061
+ new_state_dict["decoder.conv_in.bias"] = value
1062
+ # Handle encoder downsamples
1063
+ elif key.startswith("encoder.downsamples."):
1064
+ # Change encoder.downsamples to encoder.down_blocks
1065
+ new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
1066
+
1067
+ # Handle residual blocks - change downsamples to resnets and rename components
1068
+ if "residual" in new_key or "shortcut" in new_key:
1069
+ # Change the second downsamples to resnets
1070
+ new_key = new_key.replace(".downsamples.", ".resnets.")
1071
+
1072
+ # Rename residual components
1073
+ if ".residual.0.gamma" in new_key:
1074
+ new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
1075
+ elif ".residual.2.weight" in new_key:
1076
+ new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
1077
+ elif ".residual.2.bias" in new_key:
1078
+ new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
1079
+ elif ".residual.3.gamma" in new_key:
1080
+ new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
1081
+ elif ".residual.6.weight" in new_key:
1082
+ new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
1083
+ elif ".residual.6.bias" in new_key:
1084
+ new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
1085
+ elif ".shortcut.weight" in new_key:
1086
+ new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
1087
+ elif ".shortcut.bias" in new_key:
1088
+ new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
1089
+
1090
+ # Handle resample blocks - change downsamples to downsampler and remove index
1091
+ elif "resample" in new_key or "time_conv" in new_key:
1092
+ # Change the second downsamples to downsampler and remove the index
1093
+ parts = new_key.split(".")
1094
+ # Find the pattern: encoder.down_blocks.X.downsamples.Y.resample...
1095
+ # We want to change it to: encoder.down_blocks.X.downsampler.resample...
1096
+ if len(parts) >= 4 and parts[3] == "downsamples":
1097
+ # Remove the index (parts[4]) and change downsamples to downsampler
1098
+ new_parts = parts[:3] + ["downsampler"] + parts[5:]
1099
+ new_key = ".".join(new_parts)
1100
+
1101
+ new_state_dict[new_key] = value
1102
+
1103
+ # Handle decoder upsamples
1104
+ elif key.startswith("decoder.upsamples."):
1105
+ # Change decoder.upsamples to decoder.up_blocks
1106
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
1107
+
1108
+ # Handle residual blocks - change upsamples to resnets and rename components
1109
+ if "residual" in new_key or "shortcut" in new_key:
1110
+ # Change the second upsamples to resnets
1111
+ new_key = new_key.replace(".upsamples.", ".resnets.")
1112
+
1113
+ # Rename residual components
1114
+ if ".residual.0.gamma" in new_key:
1115
+ new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
1116
+ elif ".residual.2.weight" in new_key:
1117
+ new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
1118
+ elif ".residual.2.bias" in new_key:
1119
+ new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
1120
+ elif ".residual.3.gamma" in new_key:
1121
+ new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
1122
+ elif ".residual.6.weight" in new_key:
1123
+ new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
1124
+ elif ".residual.6.bias" in new_key:
1125
+ new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
1126
+ elif ".shortcut.weight" in new_key:
1127
+ new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
1128
+ elif ".shortcut.bias" in new_key:
1129
+ new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
1130
+
1131
+ # Handle resample blocks - change upsamples to upsampler and remove index
1132
+ elif "resample" in new_key or "time_conv" in new_key:
1133
+ # Change the second upsamples to upsampler and remove the index
1134
+ parts = new_key.split(".")
1135
+ # Find the pattern: decoder.up_blocks.X.upsamples.Y.resample...
1136
+ # We want to change it to: decoder.up_blocks.X.upsampler.resample...
1137
+ if len(parts) >= 4 and parts[3] == "upsamples":
1138
+ # Remove the index (parts[4]) and change upsamples to upsampler
1139
+ new_parts = parts[:3] + ["upsampler"] + parts[5:]
1140
+ new_key = ".".join(new_parts)
1141
+
1142
+ new_state_dict[new_key] = value
1143
+ else:
1144
+ # Keep other keys unchanged
1145
+ new_state_dict[key] = value
1146
+
1147
+ with init_empty_weights():
1148
+ vae = AutoencoderKLWan(**vae22_diffusers_config)
1149
+ vae.load_state_dict(new_state_dict, strict=True, assign=True)
1150
+ return vae
1151
+
1152
+
1153
+ def get_args():
1154
+ parser = argparse.ArgumentParser()
1155
+ parser.add_argument("--model_type", type=str, default=None)
1156
+ parser.add_argument("--output_path", type=str, required=True)
1157
+ parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"])
1158
+ return parser.parse_args()
1159
+
1160
+
1161
+ DTYPE_MAPPING = {
1162
+ "fp32": torch.float32,
1163
+ "fp16": torch.float16,
1164
+ "bf16": torch.bfloat16,
1165
+ }
1166
+
1167
+
1168
+ if __name__ == "__main__":
1169
+ args = get_args()
1170
+
1171
+ if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "Animate" not in args.model_type:
1172
+ transformer = convert_transformer(args.model_type, stage="high_noise_model")
1173
+ transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
1174
+ else:
1175
+ transformer = convert_transformer(args.model_type)
1176
+ transformer_2 = None
1177
+
1178
+ if "Wan2.2" in args.model_type and "TI2V" in args.model_type:
1179
+ vae = convert_vae_22()
1180
+ else:
1181
+ vae = convert_vae()
1182
+
1183
+ text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
1184
+ tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
1185
+ if "FLF2V" in args.model_type:
1186
+ flow_shift = 16.0
1187
+ elif "TI2V" in args.model_type or "Animate" in args.model_type:
1188
+ flow_shift = 5.0
1189
+ else:
1190
+ flow_shift = 3.0
1191
+ scheduler = UniPCMultistepScheduler(
1192
+ prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
1193
+ )
1194
+
1195
+ # If user has specified "none", we keep the original dtypes of the state dict without any conversion
1196
+ if args.dtype != "none":
1197
+ dtype = DTYPE_MAPPING[args.dtype]
1198
+ transformer.to(dtype)
1199
+ if transformer_2 is not None:
1200
+ transformer_2.to(dtype)
1201
+
1202
+ if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type:
1203
+ pipe = WanImageToVideoPipeline(
1204
+ transformer=transformer,
1205
+ transformer_2=transformer_2,
1206
+ text_encoder=text_encoder,
1207
+ tokenizer=tokenizer,
1208
+ vae=vae,
1209
+ scheduler=scheduler,
1210
+ boundary_ratio=0.9,
1211
+ )
1212
+ elif "Wan2.2" and "T2V" in args.model_type:
1213
+ pipe = WanPipeline(
1214
+ transformer=transformer,
1215
+ transformer_2=transformer_2,
1216
+ text_encoder=text_encoder,
1217
+ tokenizer=tokenizer,
1218
+ vae=vae,
1219
+ scheduler=scheduler,
1220
+ boundary_ratio=0.875,
1221
+ )
1222
+ elif "Wan2.2" and "TI2V" in args.model_type:
1223
+ pipe = WanPipeline(
1224
+ transformer=transformer,
1225
+ text_encoder=text_encoder,
1226
+ tokenizer=tokenizer,
1227
+ vae=vae,
1228
+ scheduler=scheduler,
1229
+ expand_timesteps=True,
1230
+ )
1231
+ elif "I2V" in args.model_type or "FLF2V" in args.model_type:
1232
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
1233
+ "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
1234
+ )
1235
+ image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
1236
+ pipe = WanImageToVideoPipeline(
1237
+ transformer=transformer,
1238
+ text_encoder=text_encoder,
1239
+ tokenizer=tokenizer,
1240
+ vae=vae,
1241
+ scheduler=scheduler,
1242
+ image_encoder=image_encoder,
1243
+ image_processor=image_processor,
1244
+ )
1245
+ elif "Wan2.2-VACE" in args.model_type:
1246
+ pipe = WanVACEPipeline(
1247
+ transformer=transformer,
1248
+ transformer_2=transformer_2,
1249
+ text_encoder=text_encoder,
1250
+ tokenizer=tokenizer,
1251
+ vae=vae,
1252
+ scheduler=scheduler,
1253
+ boundary_ratio=0.875,
1254
+ )
1255
+ elif "Wan-VACE" in args.model_type:
1256
+ pipe = WanVACEPipeline(
1257
+ transformer=transformer,
1258
+ text_encoder=text_encoder,
1259
+ tokenizer=tokenizer,
1260
+ vae=vae,
1261
+ scheduler=scheduler,
1262
+ )
1263
+ elif "Animate" in args.model_type:
1264
+ image_encoder = CLIPVisionModel.from_pretrained(
1265
+ "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
1266
+ )
1267
+ image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
1268
+
1269
+ pipe = WanAnimatePipeline(
1270
+ transformer=transformer,
1271
+ text_encoder=text_encoder,
1272
+ tokenizer=tokenizer,
1273
+ vae=vae,
1274
+ scheduler=scheduler,
1275
+ image_encoder=image_encoder,
1276
+ image_processor=image_processor,
1277
+ )
1278
+ else:
1279
+ pipe = WanPipeline(
1280
+ transformer=transformer,
1281
+ text_encoder=text_encoder,
1282
+ tokenizer=tokenizer,
1283
+ vae=vae,
1284
+ scheduler=scheduler,
1285
+ )
1286
+
1287
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
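Once the conversion finishes, the saved directory can be loaded back through diffusers. A minimal load-back sketch, assuming the script was run with --model_type Wan2.2-TI2V-5B and --output_path ./Wan2.2-TI2V-5B-Diffusers (the path, dtype, and device below are illustrative, not part of the script):

import torch
from diffusers import WanPipeline

# Load the pipeline saved by pipe.save_pretrained(...) above.
pipe = WanPipeline.from_pretrained("./Wan2.2-TI2V-5B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")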
sudoku/generate_dataset.py ADDED
@@ -0,0 +1,424 @@
1
+ """
2
+ Sudoku Video Dataset Generator - Supports flexible solution count expressions per puzzle.
3
+ With checkpoint/resume support via metadata.json.
4
+ """
5
+ import json
6
+ import re
7
+ import random
8
+ import argparse
9
+ from dataclasses import dataclass, asdict
10
+ from pathlib import Path
11
+ from typing import List, Tuple, Optional, Union, Dict, Any
12
+ import numpy as np
13
+ import cv2
14
+ from tqdm import tqdm
15
+ from sudoku_processor import SudokuProcessor
16
+
17
+
18
+ # ==================== Solution Range ====================
19
+
20
+ @dataclass
21
+ class SolRange:
22
+ """Flexible solution count constraint for puzzle generation."""
23
+ min_sol: int
24
+ max_sol: Optional[int]
25
+
26
+ @classmethod
27
+ def parse(cls, expr: str) -> "SolRange":
28
+ expr = expr.strip()
29
+ m = re.fullmatch(r'(\d+)\s*-\s*(\d+)', expr)
30
+ if m:
31
+ lo, hi = int(m.group(1)), int(m.group(2))
32
+ if lo < 1: raise ValueError(f"min_sol must be >= 1, got {lo}")
33
+ if hi < lo: raise ValueError(f"Invalid range: {lo}-{hi}")
34
+ return cls(min_sol=lo, max_sol=hi)
35
+ m = re.fullmatch(r'(>=|>|<=|<|==)\s*(\d+)', expr)
36
+ if m:
37
+ op, n = m.group(1), int(m.group(2))
38
+ if op == '>=': return cls(min_sol=max(1, n), max_sol=None)
39
+ elif op == '>': return cls(min_sol=max(1, n + 1), max_sol=None)
40
+ elif op == '<=': return cls(min_sol=1, max_sol=n)
41
+ elif op == '<': return cls(min_sol=1, max_sol=max(1, n - 1))
42
+ elif op == '==': return cls(min_sol=n, max_sol=n)
43
+ m = re.fullmatch(r'(\d+)', expr)
44
+ if m:
45
+ n = int(m.group(1))
46
+ if n < 1: raise ValueError(f"sol_num must be >= 1, got {n}")
47
+ return cls(min_sol=n, max_sol=n)
48
+ raise ValueError(f"Invalid sol_num expression: '{expr}'")
49
+
50
+ @property
51
+ def is_exact(self): return self.max_sol is not None and self.min_sol == self.max_sol
52
+ @property
53
+ def is_unique_only(self): return self.is_exact and self.min_sol == 1
54
+ @property
55
+ def allows_unique(self): return self.min_sol <= 1
56
+ @property
57
+ def requires_multi(self): return self.min_sol > 1
58
+ @property
59
+ def effective_max(self): return self.max_sol if self.max_sol is not None else max(self.min_sol, 10)
60
+ def accepts(self, count):
61
+ if count < self.min_sol: return False
62
+ if self.max_sol is not None and count > self.max_sol: return False
63
+ return True
64
+ def __repr__(self):
65
+ if self.is_exact: return f"SolRange(=={self.min_sol})"
66
+ if self.max_sol is None: return f"SolRange(>={self.min_sol})"
67
+ return f"SolRange({self.min_sol}-{self.max_sol})"
68
+
69
+
70
+ # ==================== Checkpoint Management ====================
71
+
72
+ @dataclass
73
+ class GenerationState:
74
+ """Tracks generation progress for checkpoint/resume."""
75
+ params_hash: str
76
+ clue_progress: Dict[int, int] # clue_level -> generated_count
77
+ seen_grids: List[str]
78
+ all_samples: List[Dict]
79
+ completed: bool = False
80
+
81
+ def to_dict(self) -> Dict:
82
+ return asdict(self)
83
+
84
+ @classmethod
85
+ def from_dict(cls, d: Dict) -> "GenerationState":
86
+ return cls(**d)
87
+
88
+
89
+ def compute_params_hash(params: Dict) -> str:
90
+ """Compute hash of generation parameters for consistency check."""
91
+ import hashlib
92
+ # Only hash parameters that affect generation logic
93
+ key_params = {k: v for k, v in params.items()
94
+ if k not in ['output_dir']} # output_dir can differ
95
+ return hashlib.md5(json.dumps(key_params, sort_keys=True).encode()).hexdigest()[:12]
96
+
97
+
98
+ def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]:
99
+ """Load checkpoint if exists and params match."""
100
+ meta_path = output_dir / "metadata.json"
101
+ if not meta_path.exists():
102
+ return None
103
+
104
+ with open(meta_path) as f:
105
+ data = json.load(f)
106
+
107
+ state = GenerationState.from_dict(data["state"])
108
+ expected_hash = compute_params_hash(params)
109
+
110
+ if state.params_hash != expected_hash:
111
+ print(f"⚠️ Parameters changed (hash {state.params_hash} → {expected_hash}), starting fresh")
112
+ return None
113
+
114
+ if state.completed:
115
+ print("✓ Generation already completed")
116
+ return state
117
+
118
+ print(f"✓ Resuming from checkpoint: {sum(state.clue_progress.values())} puzzles generated")
119
+ return state
120
+
121
+
122
+ def save_checkpoint(output_dir: Path, state: GenerationState, params: Dict):
123
+ """Save current generation state to metadata.json."""
124
+ meta_path = output_dir / "metadata.json"
125
+ data = {
126
+ "params": params,
127
+ "state": state.to_dict()
128
+ }
129
+ # Atomic write
130
+ tmp_path = meta_path.with_suffix('.tmp')
131
+ with open(tmp_path, 'w') as f:
132
+ json.dump(data, f, indent=2)
133
+ tmp_path.rename(meta_path)
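For orientation, the metadata.json written here has roughly the following shape; the keys come from the params dict and GenerationState, while the concrete values below are made up:

# Illustrative checkpoint layout (values are hypothetical).
example_metadata = {
    "params": {"clue_levels": [30, 40], "num_per_clue": 50, "sol_num": "<=3", "seed": 42},
    "state": {
        "params_hash": "a1b2c3d4e5f6",        # compute_params_hash(params)
        "clue_progress": {"30": 12, "40": 7},  # clue level -> puzzles generated so far (JSON stringifies the int keys)
        "seen_grids": [],                      # encoded puzzle fingerprints used for deduplication
        "all_samples": [],                     # one record per rendered video
        "completed": False,
    },
}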
134
+
135
+
136
+ # ==================== Core Functions ====================
137
+
138
+ def get_fill_order(puzzle, solution):
139
+ return [(i, j, solution[i][j]) for i in range(9) for j in range(9) if puzzle[i][j] == 0]
140
+
141
+ def create_processor(resolution=None):
142
+ if resolution is None: return SudokuProcessor()
143
+ target_size = min(resolution)
144
+ cell_size = target_size // 9
145
+ sf = cell_size / 60
146
+ return SudokuProcessor(cell_size=cell_size, font_scale=1.2*sf, thickness=max(1, int(2*sf)))
147
+
148
+ def generate_video_frames(proc, puzzle, solution, n_start, m_end, k=1, max_frames=None):
149
+ fills = get_fill_order(puzzle, solution)
150
+ n_fills = len(fills)
151
+ effective_k = k
152
+ if max_frames is not None and n_start + n_fills * k + m_end > max_frames:
153
+ avail = max_frames - n_start - m_end
154
+ effective_k = max(1, avail // n_fills) if avail > 0 and n_fills > 0 else 1
155
+
156
+ frames = []
157
+ current = [row[:] for row in puzzle]
158
+ img = proc.render(current)
159
+ frames.extend([img.copy() for _ in range(n_start)])
160
+
161
+ for r, c, v in fills:
162
+ current[r][c] = v
163
+ frames.append(proc.render(current, highlight_new=(r, c), original=puzzle))
164
+ if effective_k > 1:
165
+ img = proc.render(current, original=puzzle)
166
+ frames.extend([img.copy() for _ in range(effective_k - 1)])
167
+
168
+ img = proc.render(solution, original=puzzle)
169
+ frames.extend([img.copy() for _ in range(m_end)])
170
+ if max_frames is not None and len(frames) > max_frames:
171
+ frames = frames[:max_frames]
172
+ return frames
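As a concrete check of the frame budget above: the clip length before truncation is n_start + n_fills * k + m_end, and when that exceeds max_frames the repeat factor k is shrunk first and the tail truncated afterwards. A small worked example with arbitrary numbers:

# 51-clue puzzle -> 30 empty cells; argparse defaults n_start=2, k=1, m_end=3.
n_start, m_end, k, n_fills = 2, 3, 1, 30
assert n_start + n_fills * k + m_end == 35   # frames rendered without a cap
# With max_frames=20: avail = 20 - 2 - 3 = 15, so effective_k = max(1, 15 // 30) = 1
# and the 35 rendered frames are simply truncated to the first 20.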
173
+
174
+ def save_video(frames, path, fps=10):
175
+ h, w = frames[0].shape[:2]
176
+ writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
177
+ for f in frames: writer.write(cv2.cvtColor(f, cv2.COLOR_RGB2BGR))
178
+ writer.release()
179
+
180
+ def normalize_num_per_clue(num_per_clue, clue_levels):
181
+ if isinstance(num_per_clue, int): return [num_per_clue] * len(clue_levels)
182
+ if len(num_per_clue) != len(clue_levels):
183
+ raise ValueError(f"num_per_clue length ({len(num_per_clue)}) != clue_levels ({len(clue_levels)})")
184
+ return num_per_clue
185
+
186
+
187
+ # ==================== Puzzle Generation with SolRange ====================
188
+
189
+ def generate_puzzle_with_range(proc, clue, sol_range, min_hamming):
190
+ """Generate one puzzle respecting sol_range. Returns (puzzle, solutions) or None."""
191
+ if sol_range.is_unique_only:
192
+ puzzle, solution = proc.generate(clue, unique=True)
193
+ return puzzle, [solution]
194
+
195
+ if sol_range.requires_multi:
196
+ try:
197
+ puzzle, solutions = proc.generate_multi_solution(
198
+ clue, min_solutions=sol_range.min_sol,
199
+ max_solutions=sol_range.effective_max,
200
+ max_attempts=1, min_hamming=min_hamming
201
+ )
202
+ if sol_range.accepts(len(solutions)):
203
+ return puzzle, solutions
204
+ except RuntimeError:
205
+ pass
206
+ return None
207
+
208
+ try:
209
+ puzzle, solutions = proc.generate_multi_solution(
210
+ clue, min_solutions=max(2, sol_range.min_sol),
211
+ max_solutions=sol_range.effective_max,
212
+ max_attempts=1, min_hamming=min_hamming
213
+ )
214
+ if sol_range.accepts(len(solutions)):
215
+ return puzzle, solutions
216
+ except RuntimeError:
217
+ pass
218
+
219
+ if sol_range.allows_unique:
220
+ puzzle, solution = proc.generate(clue, unique=True)
221
+ return puzzle, [solution]
222
+ return None
223
+
224
+
225
+ # ==================== Dataset Generation ====================
226
+
227
+ def generate_dataset(
228
+ output_dir="sudoku_video", clue_levels=[30,40,50,60], num_per_clue=50,
229
+ sol_num="1", min_hamming=10, train_ratio=0.8,
230
+ prompt="Solve this Sudoku puzzle using red font.",
231
+ n_start=10, m_end=10, k=1, max_frames=None, fps=10,
232
+ resolution=None, seed=42, checkpoint_interval=50
233
+ ):
234
+ """
235
+ Generate Sudoku video dataset with checkpoint/resume support.
236
+
237
+ Args:
238
+ checkpoint_interval: Save checkpoint every N puzzles (default: 50)
239
+ """
240
+ # Prepare params dict for hashing
241
+ params = {
242
+ "clue_levels": clue_levels, "num_per_clue": num_per_clue,
243
+ "sol_num": sol_num, "min_hamming": min_hamming, "train_ratio": train_ratio,
244
+ "prompt": prompt, "n_start": n_start, "m_end": m_end, "k": k,
245
+ "max_frames": max_frames, "fps": fps, "resolution": resolution, "seed": seed
246
+ }
247
+
248
+ output_dir = Path(output_dir)
249
+ video_dir = output_dir / "videos"
250
+ image_dir = output_dir / "images"
251
+ video_dir.mkdir(parents=True, exist_ok=True)
252
+ image_dir.mkdir(parents=True, exist_ok=True)
253
+
254
+ # Try to resume from checkpoint
255
+ state = load_checkpoint(output_dir, params)
256
+
257
+ if state and state.completed:
258
+ return # Already done
259
+
260
+ sol_range = SolRange.parse(str(sol_num))
261
+ proc = create_processor(resolution)
262
+ actual_size = proc.img_size
263
+ num_per_clue_list = normalize_num_per_clue(num_per_clue, clue_levels)
264
+ max_puzzles = max(num_per_clue_list)
265
+ num_width = len(str(max_puzzles))
266
+
267
+ # Initialize or restore state
268
+ if state is None:
269
+ random.seed(seed)
270
+ state = GenerationState(
271
+ params_hash=compute_params_hash(params),
272
+ clue_progress={clue: 0 for clue in clue_levels},
273
+ seen_grids=[],
274
+ all_samples=[]
275
+ )
276
+ print(f"Starting fresh generation with solution range: {sol_range}")
277
+ else:
278
+ # Restore RNG state approximately by fast-forwarding
279
+ random.seed(seed)
280
+ for _ in range(sum(state.clue_progress.values()) * 10):
281
+ random.random()
282
+
283
+ seen_grids = set(state.seen_grids)
284
+ all_samples = state.all_samples.copy()
285
+ clue_progress = {int(k): v for k, v in state.clue_progress.items()}
286
+
287
+ total_target = sum(num_per_clue_list)
288
+ total_done = sum(clue_progress.values())
289
+ stats_unique = sum(1 for s in all_samples if s["total_solutions"] == 1 and s["sol_idx"] == 0)
290
+ stats_multi = sum(1 for s in all_samples if s["total_solutions"] > 1 and s["sol_idx"] == 0)
291
+ puzzles_since_checkpoint = 0
292
+
293
+ with tqdm(total=total_target, initial=total_done, desc="Total", unit="puzzle") as pbar_total:
294
+ for clue, target_count in zip(clue_levels, num_per_clue_list):
295
+ generated = clue_progress.get(clue, 0)
296
+ if generated >= target_count:
297
+ continue # This clue level is done
298
+
299
+ max_attempts = (target_count - generated) * 20
300
+
301
+ with tqdm(total=target_count, initial=generated, desc=f"Clue {clue:2d}",
302
+ unit="puzzle", leave=False) as pbar_clue:
303
+ for _ in range(max_attempts):
304
+ if generated >= target_count:
305
+ break
306
+
307
+ result = generate_puzzle_with_range(proc, clue, sol_range, min_hamming)
308
+ if result is None:
309
+ continue
310
+ puzzle, solutions = result
311
+
312
+ fp = proc.encode(puzzle)
313
+ if fp in seen_grids:
314
+ continue
315
+ seen_grids.add(fp)
316
+
317
+ n_sols = len(solutions)
318
+ if n_sols == 1:
319
+ stats_unique += 1
320
+ else:
321
+ stats_multi += 1
322
+
323
+ img_name = f"clue{clue}_{generated:0{num_width}d}.png"
324
+ puzzle_img = proc.render(puzzle)
325
+ cv2.imwrite(str(image_dir / img_name), cv2.cvtColor(puzzle_img, cv2.COLOR_RGB2BGR))
326
+
327
+ for si, sol in enumerate(solutions):
328
+ vid_name = f"clue{clue}_{generated:0{num_width}d}_sol{si}.mp4"
329
+ frames = generate_video_frames(proc, puzzle, sol, n_start, m_end, k, max_frames)
330
+ save_video(frames, video_dir / vid_name, fps)
331
+
332
+ hdists = [proc._hamming(sol, solutions[j]) for j in range(n_sols) if j != si]
333
+ all_samples.append({
334
+ "prompt": prompt, "video": vid_name, "image": img_name,
335
+ "clue": clue, "puzzle": fp, "solution": proc.encode(sol),
336
+ "sol_idx": si, "total_solutions": n_sols,
337
+ "frame_count": len(frames),
338
+ "min_hamming_to_others": min(hdists) if hdists else 0
339
+ })
340
+
341
+ generated += 1
342
+ clue_progress[clue] = generated
343
+ puzzles_since_checkpoint += 1
344
+ pbar_clue.update(1)
345
+ pbar_total.update(1)
346
+
347
+ # Periodic checkpoint
348
+ if puzzles_since_checkpoint >= checkpoint_interval:
349
+ state.clue_progress = clue_progress
350
+ state.seen_grids = list(seen_grids)
351
+ state.all_samples = all_samples
352
+ save_checkpoint(output_dir, state, params)
353
+ puzzles_since_checkpoint = 0
354
+
355
+ tqdm.write(f"Clue {clue}: {generated} puzzles, "
356
+ f"{sum(1 for s in all_samples if s['clue'] == clue)} videos")
357
+
358
+ # Final output
359
+ random.seed(seed + 1) # Deterministic shuffle
360
+ random.shuffle(all_samples)
361
+ split_idx = int(len(all_samples) * train_ratio)
362
+
363
+ def write_jsonl(samples, path):
364
+ with open(path, 'w') as f:
365
+ for s in samples:
366
+ json.dump(s, f)
367
+ f.write('\n')
368
+
369
+ write_jsonl(all_samples[:split_idx], output_dir / "train.jsonl")
370
+ write_jsonl(all_samples[split_idx:], output_dir / "test.jsonl")
371
+
372
+ # Mark as completed
373
+ state.clue_progress = clue_progress
374
+ state.seen_grids = list(seen_grids)
375
+ state.all_samples = all_samples
376
+ state.completed = True
377
+ save_checkpoint(output_dir, state, params)
378
+
379
+ print(f"\n✓ Dataset complete: {output_dir}/")
380
+ print(f" Resolution: {actual_size}x{actual_size}")
381
+ print(f" Solution range: {sol_range}")
382
+ print(f" Puzzles: {len(seen_grids)} ({stats_unique} unique, {stats_multi} multi-sol)")
383
+ print(f" Videos: {len(all_samples)}")
384
+ print(f" Train: {split_idx}, Test: {len(all_samples) - split_idx}")
385
+
386
+ hammings = [s["min_hamming_to_others"] for s in all_samples if s["min_hamming_to_others"] > 0]
387
+ if hammings:
388
+ print(f" Solution diversity: avg={np.mean(hammings):.1f}, min={min(hammings)}, max={max(hammings)}")
389
+
390
+
391
+ def parse_resolution(s):
392
+ w, h = map(int, s.lower().split('x'))
393
+ return (w, h)
394
+
395
+ def parse_args():
396
+ p = argparse.ArgumentParser(description="Generate Sudoku video dataset with resume support")
397
+ p.add_argument("--output-dir", type=str, default="sudoku")
398
+ p.add_argument("--clue-levels", type=int, nargs="+", default=[20,30,40,50,60,70])
399
+ p.add_argument("--num-per-clue", type=int, nargs="+", default=[15000,10000,10000,5000,2000,1000])
400
+ p.add_argument("--sol-num", type=str, default="<=3",
401
+ help="'1', '3', '>=1', '>1', '<=3', '<3', '2-5'")
402
+ p.add_argument("--min-hamming", type=int, default=10)
403
+ p.add_argument("--train-ratio", type=float, default=0.9)
404
+ p.add_argument("--prompt", type=str, default="Solve this Sudoku puzzle using red font.")
405
+ p.add_argument("--n-start", type=int, default=2)
406
+ p.add_argument("--m-end", type=int, default=3)
407
+ p.add_argument("--k", type=int, default=1)
408
+ p.add_argument("--max-frames", type=int, default=None)
409
+ p.add_argument("--fps", type=int, default=10)
410
+ p.add_argument("--resolution", type=str, default="1024x1024")
411
+ p.add_argument("--seed", type=int, default=42)
412
+ p.add_argument("--checkpoint-interval", type=int, default=50,
413
+ help="Save checkpoint every N puzzles (default: 50)")
414
+ return p.parse_args()
415
+
416
+
417
+ if __name__ == "__main__":
418
+ args = parse_args()
419
+ kwargs = vars(args)
420
+ if isinstance(kwargs["num_per_clue"], list) and len(kwargs["num_per_clue"]) == 1:
421
+ kwargs["num_per_clue"] = kwargs["num_per_clue"][0]
422
+ if kwargs["resolution"]:
423
+ kwargs["resolution"] = parse_resolution(kwargs["resolution"])
424
+ generate_dataset(**kwargs)
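A small smoke-test call of the entry point above (argument values are illustrative; repeating the same call later resumes from sudoku_demo/metadata.json as long as the generation parameters are unchanged):

generate_dataset(
    output_dir="sudoku_demo",
    clue_levels=[30, 40],
    num_per_clue=5,
    sol_num="<=3",
    resolution=(512, 512),
    checkpoint_interval=2,
)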
sudoku/jsonl_to_csv.py ADDED
@@ -0,0 +1,22 @@
1
+ import json
2
+ import csv
3
+ from pathlib import Path
4
+
5
+ dataset='sudoku'
6
+ split='train'
7
+
8
+ # Load test data
9
+ with open(f'{dataset}/{split}_info.jsonl', 'r') as f:
10
+ data = [json.loads(line) for line in f]
11
+
12
+ # Write to CSV
13
+ with open(f'{dataset}/{split}.csv', 'w', newline='', encoding='utf-8') as f:
14
+ writer = csv.writer(f)
15
+ writer.writerow(['input_image', 'video', 'prompt'])
16
+
17
+ for item in data:
18
+ writer.writerow([
19
+ 'images/' + item['image'],
20
+ 'videos/' + item['video'],
21
+ item['prompt'],
22
+ ])
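The resulting CSV keeps the prompt and prefixes the media paths with images/ and videos/. A sketch of the header row plus one data row of sudoku/train.csv (the filenames follow the clue{level}_{index} pattern used by generate_dataset.py; the concrete values are hypothetical):

# input_image,video,prompt
# images/clue30_0001.png,videos/clue30_0001_sol0.mp4,Solve this Sudoku puzzle using red font.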
sudoku/simplify_dataset.py ADDED
@@ -0,0 +1,19 @@
1
+ import json
2
+
3
+ dataset = 'sudoku_600'
4
+ split = 'test'
5
+
6
+ # Read original data
7
+ with open(f'{dataset}/{split}.jsonl', 'r') as f:
8
+ data = [json.loads(line) for line in f]
9
+
10
+ # Transform to simplified format
11
+ new_data = [{'prompt': d['prompt'], 'image': d['image']} for d in data]
12
+
13
+ # Save simplified data to {split}.jsonl
14
+ with open(f'{dataset}/{split}.jsonl', 'w') as f:
15
+ f.writelines(json.dumps(item) + '\n' for item in new_data)
16
+
17
+ # Save original data to {split}_info.jsonl
18
+ with open(f'{dataset}/{split}_info.jsonl', 'w') as f:
19
+ f.writelines(json.dumps(item) + '\n' for item in data)
sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00001-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a287551bf47373c4e66324c27a84ce2daa48c89acdde4eb8d89178d8ad09da9
3
+ size 4992484608
sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00002-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:35a65b38950cf1b3d01460bec6b03e5efdc66854a678f5089a668a2664a91f4c
3
+ size 4898551584
sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00003-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:498c9dc1c9d6edabf5514ff30440c752507c07afba637ea0793b611dcd4fd4ea
3
+ size 4987667104
sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00004-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e33036a5d81e8a72a5ac48657af77bcd751a7dce06e6e883643f1a7e377e69c6
3
+ size 4987711216
sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00005-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f9f17087d5bd04bb83dfe3958bf03a54def4b6029f9677b19b9c98484a9b742
3
+ size 4950959936
sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00006-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29381ca0d3790260c577b3849efd7ecb87d1550dfc2ca52ae1ab7052c615bf32
3
+ size 4950980632
sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00007-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2214191998778c951068e0b15363ddab49e089270b58ef62935aa4a523a5358a
3
+ size 3021537400
sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:504d9c47c4e559e8df2d2b93a69e094ac2b9d1e65f8124849c7e719b47380952
3
+ size 32789894056
sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-0.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae539b75e0dc5a5809b3b87dd5b02abc1fa9cda2e79b0d9ca63d5954795eaeec
3
+ size 9999659704
sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-1.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b303e8e24c73ca768be637d2e787f97b95d8c8ae36ae56e8986842549bd69a0b
3
+ size 9999659704
sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-2.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c1e05b6a600d53a11306fddc78fce8660cc13716ad330f7c51adb8258118673
3
+ size 9999659704
sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-3.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5eae0f23c6c4786bcebd2289bc7e06a912f726c94d2e77f20d0b543630751ca
3
+ size 9999659704
sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-4.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d8ab68aeec9fb7b79dac05b95eb4d2d5b9e3cbe20eaeef37adbca4c6c5565fd
3
+ size 9999659704
sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31/epoch-3.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f2d8a61d1723198c9cc194edb7e562e5aa84e9e7dd358dcffa492d98b5c8c1c
3
+ size 32789894056
sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00001-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e869c2d79321f9c73a4443150d45b2b7455a92c9cb3a3fb1321d698a882fbf3e
3
+ size 4992484608
sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00002-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c24ed79955530349d2f53c16cd038e17281f3728f129b950088812f9d4393a5
3
+ size 4898551584
sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00003-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8721792f07a2411d5d07639bac2a52585aeb20727de09a835b0f6251385a4ee7
3
+ size 4987667104
sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00004-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:342d017549cd13b866272c0c7ae43f9464ce07936a0c77e3f6c6c36b8a6c8aeb
3
+ size 4987711216
sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00005-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1389e9cf029b963914a7c725fb6ab5ced0576b364865e7f9cbb84d4f803b9b74
3
+ size 4950959936
sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00006-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be51484e98b724b87e673cc0eeee8aeac66392e838f57879d5fdc34aef9d4022
3
+ size 4950980632
sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00007-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffbdf62dff812b47dbf7e4a96e950bcc396a73922301310b9e84bc008d54e34e
3
+ size 3021537400
sudoku/sudoku_processor.py ADDED
@@ -0,0 +1,479 @@
1
+ """
2
+ SudokuProcessor - Sudoku puzzle generation, solving, and rendering using SAT solver.
3
+ Supports efficient diverse multi-solution generation.
4
+ """
5
+ import random
6
+ from typing import List, Tuple, Optional
7
+ import numpy as np
8
+ import cv2
9
+
10
+ try:
11
+ from pysat.solvers import Solver
12
+ HAS_PYSAT = True
13
+ except ImportError:
14
+ HAS_PYSAT = False
15
+ print("Warning: pysat not found, install with: pip install python-sat")
16
+
17
+
18
+ class SudokuProcessor:
19
+ """Handles Sudoku puzzle generation, solving, and image rendering."""
20
+
21
+ def __init__(self, cell_size: int = 60, font_scale: float = 1.2, thickness: int = 2):
22
+ self.cell_size = cell_size
23
+ self.font_scale = font_scale
24
+ self.thickness = thickness
25
+ self.img_size = cell_size * 9
26
+
27
+ # Colors (RGB)
28
+ self.bg_color = (255, 255, 255)
29
+ self.line_color = (0, 0, 0)
30
+ self.original_color = (0, 0, 0)
31
+ self.filled_color = (200, 0, 0)
32
+ self.highlight_color = (255, 255, 200)
33
+
34
+ self._base_clauses_cache = None
35
+
36
+ # ==================== SAT Encoding ====================
37
+
38
+ def _var(self, r: int, c: int, n: int) -> int:
39
+ """Map (row, col, num) to SAT variable (1-indexed)."""
40
+ return r * 81 + c * 9 + n + 1
41
+
42
+ def _decode_var(self, v: int) -> Tuple[int, int, int]:
43
+ v -= 1
44
+ return v // 81, (v % 81) // 9, v % 9
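A quick sanity check of the (row, col, num) <-> variable mapping (illustrative, not part of the module):

p = SudokuProcessor()
assert p._var(0, 0, 0) == 1      # cell (0, 0) holding digit 1 maps to variable 1
assert p._var(8, 8, 8) == 729    # cell (8, 8) holding digit 9 maps to the last variable
assert p._decode_var(p._var(3, 5, 6)) == (3, 5, 6)   # _decode_var inverts _var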
45
+
46
+ def _base_clauses(self) -> List[List[int]]:
47
+ """Generate base Sudoku constraint clauses (cached)."""
48
+ if self._base_clauses_cache is not None:
49
+ return self._base_clauses_cache
50
+
51
+ clauses = []
52
+ for i in range(9):
53
+ for j in range(9):
54
+ clauses.append([self._var(i, j, n) for n in range(9)])
55
+ for n1 in range(9):
56
+ for n2 in range(n1 + 1, 9):
57
+ clauses.append([-self._var(i, j, n1), -self._var(i, j, n2)])
58
+
59
+ for n in range(9):
60
+ for i in range(9):
61
+ clauses.append([self._var(i, j, n) for j in range(9)])
62
+ for j1 in range(9):
63
+ for j2 in range(j1 + 1, 9):
64
+ clauses.append([-self._var(i, j1, n), -self._var(i, j2, n)])
65
+ clauses.append([self._var(j, i, n) for j in range(9)])
66
+ for j1 in range(9):
67
+ for j2 in range(j1 + 1, 9):
68
+ clauses.append([-self._var(j1, i, n), -self._var(j2, i, n)])
69
+ for br in range(3):
70
+ for bc in range(3):
71
+ box = [self._var(br*3+di, bc*3+dj, n) for di in range(3) for dj in range(3)]
72
+ clauses.append(box)
73
+ for i1 in range(9):
74
+ for i2 in range(i1 + 1, 9):
75
+ clauses.append([-box[i1], -box[i2]])
76
+
77
+ self._base_clauses_cache = clauses
78
+ return clauses
79
+
80
+ def _grid_clauses(self, grid: List[List[int]]) -> List[List[int]]:
81
+ return [[self._var(i, j, grid[i][j] - 1)]
82
+ for i in range(9) for j in range(9) if grid[i][j] != 0]
83
+
84
+ def _model_to_grid(self, model: List[int]) -> List[List[int]]:
85
+ grid = [[0] * 9 for _ in range(9)]
86
+ for v in model:
87
+ if 0 < v <= 729:
88
+ r, c, n = self._decode_var(v)
89
+ grid[r][c] = n + 1
90
+ return grid
91
+
92
+ # ==================== Solving ====================
93
+
94
+ def solve(self, grid: List[List[int]]) -> Optional[List[List[int]]]:
95
+ if HAS_PYSAT:
96
+ with Solver(name='g3') as s:
97
+ for c in self._base_clauses() + self._grid_clauses(grid):
98
+ s.add_clause(c)
99
+ return self._model_to_grid(s.get_model()) if s.solve() else None
100
+ return self._solve_backtrack(grid)
101
+
102
+ def _solve_backtrack(self, grid: List[List[int]]) -> Optional[List[List[int]]]:
103
+ board = [row[:] for row in grid]
104
+ return board if self._backtrack(board) else None
105
+
106
+ def _backtrack(self, board: List[List[int]]) -> bool:
107
+ empty = self._find_empty(board)
108
+ if not empty:
109
+ return True
110
+ r, c = empty
111
+ for num in range(1, 10):
112
+ if self._is_valid(board, r, c, num):
113
+ board[r][c] = num
114
+ if self._backtrack(board):
115
+ return True
116
+ board[r][c] = 0
117
+ return False
118
+
119
+ def _find_empty(self, board: List[List[int]]) -> Optional[Tuple[int, int]]:
120
+ for i in range(9):
121
+ for j in range(9):
122
+ if board[i][j] == 0:
123
+ return (i, j)
124
+ return None
125
+
126
+ def _is_valid(self, board: List[List[int]], row: int, col: int, num: int) -> bool:
127
+ if num in board[row]:
128
+ return False
129
+ if any(board[i][col] == num for i in range(9)):
130
+ return False
131
+ br, bc = 3 * (row // 3), 3 * (col // 3)
132
+ return all(board[i][j] != num for i in range(br, br+3) for j in range(bc, bc+3))
133
+
134
+ def count_solutions(self, grid: List[List[int]], limit: int = 2) -> int:
135
+ if HAS_PYSAT:
136
+ count = 0
137
+ with Solver(name='g3') as s:
138
+ for c in self._base_clauses() + self._grid_clauses(grid):
139
+ s.add_clause(c)
140
+ while count < limit and s.solve():
141
+ count += 1
142
+ s.add_clause([-v for v in s.get_model() if 0 < v <= 729])
143
+ return count
144
+ return self._count_backtrack(grid, limit)
145
+
146
+ def _count_backtrack(self, grid: List[List[int]], limit: int) -> int:
147
+ board = [row[:] for row in grid]
148
+ self._sol_count, self._sol_limit = 0, limit
149
+ self._count_helper(board)
150
+ return self._sol_count
151
+
152
+ def _count_helper(self, board: List[List[int]]) -> bool:
153
+ if self._sol_count >= self._sol_limit:
154
+ return True
155
+ empty = self._find_empty(board)
156
+ if not empty:
157
+ self._sol_count += 1
158
+ return self._sol_count >= self._sol_limit
159
+ r, c = empty
160
+ for num in range(1, 10):
161
+ if self._is_valid(board, r, c, num):
162
+ board[r][c] = num
163
+ if self._count_helper(board):
164
+ return True
165
+ board[r][c] = 0
166
+ return False
167
+
+    def find_solutions(self, grid: List[List[int]], limit: int = 10) -> List[List[List[int]]]:
+        if HAS_PYSAT:
+            solutions = []
+            with Solver(name='g3') as s:
+                for c in self._base_clauses() + self._grid_clauses(grid):
+                    s.add_clause(c)
+                while len(solutions) < limit and s.solve():
+                    model = s.get_model()
+                    solutions.append(self._model_to_grid(model))
+                    s.add_clause([-v for v in model if 0 < v <= 729])
+            return solutions
+        return self._find_backtrack(grid, limit)
+
+    def _find_backtrack(self, grid: List[List[int]], limit: int) -> List[List[List[int]]]:
+        board, solutions = [row[:] for row in grid], []
+        self._find_helper(board, solutions, limit)
+        return solutions
+
+    def _find_helper(self, board: List[List[int]], solutions: List, limit: int) -> bool:
+        if len(solutions) >= limit:
+            return True
+        empty = self._find_empty(board)
+        if not empty:
+            solutions.append([row[:] for row in board])
+            return len(solutions) >= limit
+        r, c = empty
+        for num in range(1, 10):
+            if self._is_valid(board, r, c, num):
+                board[r][c] = num
+                if self._find_helper(board, solutions, limit):
+                    return True
+                board[r][c] = 0
+        return False
+
+    # ==================== Generation ====================
+
+    def generate(self, clues: int = 30, unique: bool = True) -> Tuple[List[List[int]], List[List[int]]]:
+        """Generate a Sudoku puzzle with the specified number of clues."""
+        solution = self._generate_full_grid()
+        puzzle = [row[:] for row in solution]
+
+        cells = [(i, j) for i in range(9) for j in range(9)]
+        random.shuffle(cells)
+
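+        # Dig cells one at a time; with unique=True a removal is only kept if the
+        # puzzle still has exactly one solution (count_solutions with limit=2 is
+        # enough to distinguish "one" from "more than one").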
+        removed, target = 0, 81 - clues
+        for r, c in cells:
+            if removed >= target:
+                break
+            backup = puzzle[r][c]
+            puzzle[r][c] = 0
+            if unique and self.count_solutions(puzzle, 2) != 1:
+                puzzle[r][c] = backup
+            else:
+                removed += 1
+
+        return puzzle, solution
+
+    def _generate_full_grid(self) -> List[List[int]]:
+        if HAS_PYSAT:
+            with Solver(name='g3') as s:
+                for c in self._base_clauses():
+                    s.add_clause(c)
+                cells = [(i, j) for i in range(9) for j in range(9)]
+                random.shuffle(cells)
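+                # Seed the solver with a few random but mutually consistent cell
+                # assumptions so repeated calls produce different grids; the final
+                # solve() then completes the remaining cells.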
+                assumptions = []
+                for r, c in cells[:11]:
+                    nums = list(range(9))
+                    random.shuffle(nums)
+                    for n in nums:
+                        if s.solve(assumptions=assumptions + [self._var(r, c, n)]):
+                            assumptions.append(self._var(r, c, n))
+                            break
+                s.solve(assumptions=assumptions)
+                return self._model_to_grid(s.get_model())
+
+        board = [[0] * 9 for _ in range(9)]
+        self._fill_grid(board)
+        return board
+
+    def _fill_grid(self, board: List[List[int]]) -> bool:
+        empty = self._find_empty(board)
+        if not empty:
+            return True
+        r, c = empty
+        nums = list(range(1, 10))
+        random.shuffle(nums)
+        for num in nums:
+            if self._is_valid(board, r, c, num):
+                board[r][c] = num
+                if self._fill_grid(board):
+                    return True
+                board[r][c] = 0
+        return False
+
+    # ==================== Diverse Multi-Solution Generation ====================
+
+    @staticmethod
+    def _hamming(sol1: List[List[int]], sol2: List[List[int]]) -> int:
+        """Count differing cells between two complete grids."""
+        return sum(sol1[i][j] != sol2[i][j] for i in range(9) for j in range(9))
+
+    @staticmethod
+    def _greedy_diverse_select(
+        candidates: List[List[List[int]]],
+        target_count: int,
+        min_hamming: int,
+        _hamming_fn=None,
+    ) -> List[List[List[int]]]:
+        """
+        Greedily select diverse solutions using farthest-point sampling.
+
+        1. Start with a random candidate.
+        2. Repeatedly add the candidate with maximum min-distance to the selected set.
+        3. Stop when enough are selected or no candidate meets min_hamming.
+        """
+        if _hamming_fn is None:
+            _hamming_fn = SudokuProcessor._hamming
+
+        if len(candidates) <= 1:
+            return list(candidates)
+
+        n = len(candidates)
+
+        # Pre-compute pairwise distances
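+        # (O(n^2) grid comparisons; n is at most the number of enumerated
+        # candidates, a few hundred in generate_multi_solution.)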
+        dist = [[0] * n for _ in range(n)]
+        for i in range(n):
+            for j in range(i + 1, n):
+                d = _hamming_fn(candidates[i], candidates[j])
+                dist[i][j] = d
+                dist[j][i] = d
+
+        # Farthest-point sampling
+        selected = [random.randint(0, n - 1)]
+        remaining = set(range(n)) - {selected[0]}
+
+        while len(selected) < target_count and remaining:
+            best_idx = -1
+            best_min_dist = -1
+
+            for r in remaining:
+                min_d = min(dist[r][s] for s in selected)
+                if min_d > best_min_dist:
+                    best_min_dist = min_d
+                    best_idx = r
+
+            if best_min_dist < min_hamming:
+                break
+
+            selected.append(best_idx)
+            remaining.discard(best_idx)
+
+        return [candidates[i] for i in selected]
+
+    def generate_multi_solution(
+        self,
+        clues: int,
+        min_solutions: int = 2,
+        max_solutions: int = 5,
+        max_attempts: int = 100,
+        min_hamming: int = 10
+    ) -> Tuple[List[List[int]], List[List[List[int]]]]:
+        """
+        Generate a puzzle with multiple diverse solutions.
+
+        Puzzle-first strategy:
+        1. Generate a full grid, then randomly remove (81 - clues) cells WITHOUT
+           a uniqueness check → guaranteed to have ≥1 solution, likely many.
+        2. Enumerate candidate solutions of this puzzle via SAT.
+        3. Greedily select diverse solutions (farthest-point sampling).
+        4. If there are not enough diverse solutions, retry with a new puzzle.
+
+        This is correct because all returned solutions are guaranteed to be valid
+        completions of the returned puzzle.
+
+        Args:
+            clues: Number of given cells.
+            min_solutions: Minimum diverse solutions required.
+            max_solutions: Maximum to return.
+            max_attempts: Outer retry budget.
+            min_hamming: Minimum pairwise Hamming distance.
+        Returns:
+            (puzzle, solutions) — all solutions are valid and pairwise diverse.
+        Raises:
+            RuntimeError: If unable to find a qualifying puzzle.
+        """
+        # Adaptive hamming threshold
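+        # With more clues there are fewer empty cells, so two solutions can differ
+        # in at most (81 - clues) positions; the required pairwise distance is
+        # therefore relaxed as the clue count grows.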
+        adaptive_hamming = min_hamming
+        if min_hamming == 10:  # default → auto-adapt
+            if clues >= 55:
+                adaptive_hamming = 3
+            elif clues >= 45:
+                adaptive_hamming = 5
+            elif clues >= 35:
+                adaptive_hamming = 8
+            else:
+                adaptive_hamming = 12
+
+        # Adaptive search depth: more empty cells → more solutions likely exist
+        empty_cells = 81 - clues
+        if empty_cells <= 15:
+            max_search = 30
+        elif empty_cells <= 25:
+            max_search = 80
+        elif empty_cells <= 40:
+            max_search = 150
+        else:
+            max_search = 300
+
+        for _ in range(max_attempts):
+            # Phase 1: Generate puzzle (random removal, no uniqueness check)
+            solution = self._generate_full_grid()
+            puzzle = [row[:] for row in solution]
+
+            cells = [(i, j) for i in range(9) for j in range(9)]
+            random.shuffle(cells)
+            for r, c in cells[:81 - clues]:
+                puzzle[r][c] = 0
+
+            # # Phase 2: Quick feasibility — need at least min_solutions solutions
+            # quick_count = self.count_solutions(puzzle, min_solutions + 1)
+            # if quick_count < min_solutions:
+            #     continue
+
+            # Phase 3: Enumerate candidates
+            candidates = self.find_solutions(puzzle, max_search)
+            if len(candidates) < min_solutions:
+                continue
+
+            # Phase 4: Greedy diverse selection
+            diverse = self._greedy_diverse_select(
+                candidates, max_solutions, adaptive_hamming
+            )
+
+            if len(diverse) >= min_solutions:
+                return puzzle, diverse[:max_solutions]
+
+        raise RuntimeError(
+            f"Failed to generate puzzle with {min_solutions}-{max_solutions} "
+            f"diverse solutions (hamming>={adaptive_hamming}) after {max_attempts} attempts"
+        )
+
+    # ==================== Encoding ====================
+
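+    # encode()/decode() use the conventional 81-character row-major digit string,
+    # with '0' marking an empty cell (row 1 of the grid gives the first nine characters).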
+    def encode(self, grid: List[List[int]]) -> str:
+        return ''.join(str(grid[i][j]) for i in range(9) for j in range(9))
+
+    def decode(self, s: str) -> List[List[int]]:
+        return [[int(s[i * 9 + j]) for j in range(9)] for i in range(9)]
+
+    # ==================== Rendering ====================
+
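+    # render() returns an (img_size x img_size x 3) uint8 array using the color
+    # attributes on the instance; the demo below converts RGB -> BGR before
+    # handing the image to cv2.imwrite.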
+    def render(
+        self,
+        grid: List[List[int]],
+        highlight_new: Optional[Tuple[int, int]] = None,
+        original: Optional[List[List[int]]] = None
+    ) -> np.ndarray:
+        img = np.full((self.img_size, self.img_size, 3), self.bg_color, dtype=np.uint8)
+        cs = self.cell_size
+
+        if highlight_new:
+            r, c = highlight_new
+            cv2.rectangle(img, (c * cs, r * cs), ((c+1) * cs, (r+1) * cs), self.highlight_color, -1)
+
+        for i in range(10):
+            thick = 3 if i % 3 == 0 else 1
+            pos = i * cs
+            cv2.line(img, (pos, 0), (pos, self.img_size), self.line_color, thick)
+            cv2.line(img, (0, pos), (self.img_size, pos), self.line_color, thick)
+
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        for i in range(9):
+            for j in range(9):
+                if grid[i][j] == 0:
+                    continue
+                is_original = original is None or original[i][j] != 0
+                color = self.original_color if is_original else self.filled_color
+                text = str(grid[i][j])
+                (tw, th), _ = cv2.getTextSize(text, font, self.font_scale, self.thickness)
+                cv2.putText(img, text, (j*cs + (cs-tw)//2, i*cs + (cs+th)//2),
+                            font, self.font_scale, color, self.thickness)
+
+        return img
+
+
+if __name__ == "__main__":
+    proc = SudokuProcessor()
+    print(f"Using {'SAT solver' if HAS_PYSAT else 'backtracking'}...")
+
+    # Test unique puzzle
+    puzzle, solution = proc.generate(clues=25, unique=True)
+    print("Puzzle:")
+    for row in puzzle:
+        print(row)
+    print(f"Clues: {sum(c != 0 for row in puzzle for c in row)}")
+
+    cv2.imwrite("test_puzzle.png", cv2.cvtColor(proc.render(puzzle), cv2.COLOR_RGB2BGR))
+    cv2.imwrite("test_solution.png", cv2.cvtColor(proc.render(solution, original=puzzle), cv2.COLOR_RGB2BGR))
+    print("Saved test images.")
+
+    # Test diverse multi-solution at various clue levels
+    print("\n=== Testing diverse multi-solution generation ===")
+    for clues in [25, 35, 45, 55]:
+        print(f"\nClues: {clues}")
+        try:
+            puzzle_m, solutions_m = proc.generate_multi_solution(
+                clues=clues, min_solutions=3, max_solutions=3, min_hamming=10
+            )
+            print(f"  Generated puzzle with {len(solutions_m)} diverse solutions")
+            for i in range(len(solutions_m)):
+                for j in range(i + 1, len(solutions_m)):
+                    print(f"    Hamming(sol{i}, sol{j}) = {proc._hamming(solutions_m[i], solutions_m[j])}")
+        except RuntimeError as e:
+            print(f"    {e}")
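+
+    # Round-trip sanity check for the string encoding defined above
+    assert proc.decode(proc.encode(puzzle)) == puzzle
+    print("Encode/decode round-trip OK.")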