Add files using upload-large-folder tool
- sequence/data_generation.py +336 -0
- sequence/test.py +376 -0
- sudoku/convert.py +81 -0
- sudoku/convert_wan.py +1287 -0
- sudoku/generate_dataset.py +424 -0
- sudoku/jsonl_to_csv.py +22 -0
- sudoku/simplify_dataset.py +19 -0
- sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00001-of-00007.safetensors +3 -0
- sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00002-of-00007.safetensors +3 -0
- sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00003-of-00007.safetensors +3 -0
- sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00004-of-00007.safetensors +3 -0
- sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00005-of-00007.safetensors +3 -0
- sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00006-of-00007.safetensors +3 -0
- sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00007-of-00007.safetensors +3 -0
- sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0.safetensors +3 -0
- sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-0.safetensors +3 -0
- sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-1.safetensors +3 -0
- sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-2.safetensors +3 -0
- sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-3.safetensors +3 -0
- sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-4.safetensors +3 -0
- sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31/epoch-3.safetensors +3 -0
- sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00001-of-00007.safetensors +3 -0
- sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00002-of-00007.safetensors +3 -0
- sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00003-of-00007.safetensors +3 -0
- sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00004-of-00007.safetensors +3 -0
- sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00005-of-00007.safetensors +3 -0
- sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00006-of-00007.safetensors +3 -0
- sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00007-of-00007.safetensors +3 -0
- sudoku/sudoku_processor.py +479 -0
sequence/data_generation.py
ADDED
@@ -0,0 +1,336 @@
"""
Sequence Prediction Dataset Generator.

Generates image pairs for sequence prediction tasks with various
mathematical sequences (arithmetic, geometric, fibonacci, etc.)
"""

import json
import random
from pathlib import Path
from typing import Callable

import matplotlib.pyplot as plt
import matplotlib.patches as patches


# ============== Sequence Generators ==============

def arithmetic_seq(start: int, diff: int, length: int = 4) -> list[int]:
    """Arithmetic sequence: a, a+d, a+2d, ..."""
    return [start + i * diff for i in range(length)]


def geometric_seq(start: int, ratio: int, length: int = 4) -> list[int]:
    """Geometric sequence: a, a*r, a*r^2, ..."""
    return [start * (ratio ** i) for i in range(length)]


def square_seq(start: int, length: int = 4) -> list[int]:
    """Square numbers: n^2, (n+1)^2, ..."""
    return [(start + i) ** 2 for i in range(length)]


def cube_seq(start: int, length: int = 4) -> list[int]:
    """Cube numbers: n^3, (n+1)^3, ..."""
    return [(start + i) ** 3 for i in range(length)]


def triangular_seq(start: int, length: int = 4) -> list[int]:
    """Triangular numbers: n(n+1)/2"""
    return [(start + i) * (start + i + 1) // 2 for i in range(length)]


def fibonacci_like_seq(a: int, b: int, length: int = 4) -> list[int]:
    """Fibonacci-like: a, b, a+b, a+2b, ..."""
    seq = [a, b]
    for _ in range(length - 2):
        seq.append(seq[-1] + seq[-2])
    return seq[:length]


def prime_seq(start_idx: int, length: int = 4) -> list[int]:
    """Prime numbers starting from index."""
    primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47]
    return primes[start_idx:start_idx + length]


def power_of_two_seq(start: int, length: int = 4) -> list[int]:
    """Powers of 2: 2^n, 2^(n+1), ..."""
    return [2 ** (start + i) for i in range(length)]


def factorial_seq(start: int, length: int = 4) -> list[int]:
    """Factorial sequence: n!, (n+1)!, ..."""
    from math import factorial
    return [factorial(start + i) for i in range(length)]


# ============== Sequence Factory ==============

SEQUENCE_TYPES = {
    "arithmetic": lambda rng: arithmetic_seq(
        rng.randint(1, 20), rng.randint(1, 10)
    ),
    "arithmetic_neg": lambda rng: arithmetic_seq(
        rng.randint(20, 50), -rng.randint(1, 5)
    ),
    "geometric_2": lambda rng: geometric_seq(
        rng.randint(1, 5), 2
    ),
    "geometric_3": lambda rng: geometric_seq(
        rng.randint(1, 3), 3
    ),
    "square": lambda rng: square_seq(rng.randint(1, 10)),
    "cube": lambda rng: cube_seq(rng.randint(1, 5)),
    "triangular": lambda rng: triangular_seq(rng.randint(1, 10)),
    "fibonacci": lambda rng: fibonacci_like_seq(
        rng.randint(1, 5), rng.randint(1, 5)
    ),
    "prime": lambda rng: prime_seq(rng.randint(0, 10)),
    "power_of_2": lambda rng: power_of_two_seq(rng.randint(0, 6)),
}


def generate_sequence_pair(seq: list[int]) -> tuple[list, list]:
    """
    Generate a pair of sequences for the task.

    Returns:
        (partial, complete): partial has last element as "", complete is full.
    """
    partial = seq[:-1] + [""]
    return partial, seq


# ============== Image Generation ==============

def round_to_multiple(x: int, multiple: int = 16) -> int:
    """Round x up to the nearest multiple."""
    return ((x + multiple - 1) // multiple) * multiple


def create_number_grid(
    numbers: list,
    save_path: str,
    height: int = 224,
    width: int = 896,
    fontsize: int = 48,
    size_multiple: int = 16,
) -> None:
    """
    Create a 1xN grid image with numbers in each cell.

    Args:
        numbers: List of numbers/strings to display.
        save_path: Output file path.
        height: Target height in pixels (will be rounded to size_multiple).
        width: Target width in pixels (will be rounded to size_multiple).
        fontsize: Font size for the numbers.
        size_multiple: Ensure dimensions are multiples of this (default 16).
    """
    from PIL import Image

    n = len(numbers)

    # Ensure dimensions are multiples of size_multiple
    width = round_to_multiple(width, size_multiple)
    height = round_to_multiple(height, size_multiple)

    # Use fixed DPI and calculate figsize
    dpi = 100
    fig_width = width / dpi
    fig_height = height / dpi

    fig, ax = plt.subplots(figsize=(fig_width, fig_height), dpi=dpi)
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

    for i, num in enumerate(numbers):
        rect = patches.Rectangle(
            (i, 0), 1, 1, linewidth=2,
            edgecolor='black', facecolor='white'
        )
        ax.add_patch(rect)
        ax.text(
            i + 0.5, 0.5, str(num), fontsize=fontsize,
            ha='center', va='center', fontweight='bold'
        )

    ax.set_xlim(0, n)
    ax.set_ylim(0, 1)
    ax.set_aspect('equal')
    ax.axis('off')

    # Save with exact pixel dimensions
    fig.savefig(save_path, dpi=dpi, facecolor='white', edgecolor='none')
    plt.close(fig)

    # Final resize to ensure exact dimensions (16 multiples)
    img = Image.open(save_path)
    if img.size != (width, height):
        img = img.resize((width, height), Image.Resampling.LANCZOS)
        img.save(save_path)


# ============== Dataset Generation ==============

class SequenceDatasetGenerator:
    """Generate sequence prediction dataset with train/test splits."""

    def __init__(
        self,
        output_dir: str,
        seed: int = 42,
        num_pairs: tuple[int, int] = (2, 3),
        seq_types: list[str] | None = None,
        image_height: int = 224,
        image_width: int = 896,
        fontsize: int = 48,
    ):
        """
        Args:
            output_dir: Directory to save the dataset.
            seed: Random seed for reproducibility.
            num_pairs: Range of pairs per sample (min, max inclusive).
            seq_types: List of sequence types to use (None = all).
            image_height: Image height in pixels (rounded to 16).
            image_width: Image width in pixels (rounded to 16).
            fontsize: Font size for numbers.
        """
        self.output_dir = Path(output_dir)
        self.rng = random.Random(seed)
        self.num_pairs = num_pairs
        self.seq_types = seq_types or list(SEQUENCE_TYPES.keys())
        self.image_height = round_to_multiple(image_height, 16)
        self.image_width = round_to_multiple(image_width, 16)
        self.fontsize = fontsize

        # Create directories
        for split in ["train", "test"]:
            (self.output_dir / split / "images").mkdir(parents=True, exist_ok=True)

    def _generate_sample(self, sample_id: int) -> dict:
        """Generate a single sample with multiple sequence pairs."""
        num_pairs = self.rng.randint(*self.num_pairs)
        seq_type = self.rng.choice(self.seq_types)

        # Generate base sequence and subsequent ones
        base_seq = SEQUENCE_TYPES[seq_type](self.rng)

        pairs = []
        for i in range(num_pairs):
            # Shift sequence for each pair
            if seq_type.startswith("arithmetic"):
                diff = base_seq[1] - base_seq[0]
                seq = [x + i * diff for x in base_seq]
            elif seq_type.startswith("geometric"):
                ratio = base_seq[1] // base_seq[0] if base_seq[0] != 0 else 2
                seq = [x * (ratio ** i) for x in base_seq]
            else:
                # For other types, shift all values by the pair index
                seq = [x + i for x in base_seq]

            partial, complete = generate_sequence_pair(seq)
            pairs.append({
                "partial": partial,
                "complete": complete,
                "answer": complete[-1],
            })

        return {
            "id": sample_id,
            "seq_type": seq_type,
            "num_pairs": num_pairs,
            "pairs": pairs,
        }

    def _save_sample_images(
        self, sample: dict, split: str, include_last_answer: bool = True
    ) -> dict:
        """Save images for a sample and return metadata."""
        sample_id = sample["id"]
        image_dir = self.output_dir / split / "images"

        images = []
        img_idx = 0

        for i, pair in enumerate(sample["pairs"]):
            # Always save partial (query) image
            partial_path = f"{sample_id:05d}_{img_idx}.png"
            create_number_grid(
                pair["partial"], image_dir / partial_path,
                height=self.image_height, width=self.image_width,
                fontsize=self.fontsize,
            )
            images.append(partial_path)
            img_idx += 1

            # Save complete image based on split logic
            is_last = (i == sample["num_pairs"] - 1)
            if include_last_answer or not is_last:
                complete_path = f"{sample_id:05d}_{img_idx}.png"
                create_number_grid(
                    pair["complete"], image_dir / complete_path,
                    height=self.image_height, width=self.image_width,
                    fontsize=self.fontsize,
                )
                images.append(complete_path)
                img_idx += 1

        return {
            "id": sample_id,
            "seq_type": sample["seq_type"],
            "num_pairs": sample["num_pairs"],
            "images": images,
            "answer": sample["pairs"][-1]["answer"],  # Last image's answer
            "sequences": [p["complete"] for p in sample["pairs"]],
        }

    def generate(self, num_train: int, num_test: int) -> None:
        """
        Generate the full dataset.

        Args:
            num_train: Number of training samples.
            num_test: Number of test samples.
        """
        train_meta, test_meta = [], []

        # Generate training samples (all pairs complete)
        print(f"Generating {num_train} training samples...")
        for i in range(num_train):
            sample = self._generate_sample(i)
            meta = self._save_sample_images(sample, "train", include_last_answer=True)
            train_meta.append(meta)
            if (i + 1) % 50 == 0:
                print(f"  Train: {i + 1}/{num_train}")

        # Generate test samples (last answer hidden)
        print(f"Generating {num_test} test samples...")
        for i in range(num_test):
            sample = self._generate_sample(num_train + i)
            meta = self._save_sample_images(sample, "test", include_last_answer=False)
            test_meta.append(meta)
            if (i + 1) % 50 == 0:
                print(f"  Test: {i + 1}/{num_test}")

        # Save metadata
        with open(self.output_dir / "train.json", "w") as f:
            json.dump(train_meta, f, indent=2)
        with open(self.output_dir / "test.json", "w") as f:
            json.dump(test_meta, f, indent=2)

        print(f"\nDataset saved to {self.output_dir}")
        print(f"  Train: {num_train} samples")
        print(f"  Test: {num_test} samples")
        print(f"  Image size: {self.image_width}x{self.image_height}")
        print(f"  Sequence types: {self.seq_types}")


if __name__ == "__main__":
    generator = SequenceDatasetGenerator(
        output_dir="/home/claude/sequence_dataset",
        seed=42,
        num_pairs=(2, 3),
    )
    generator.generate(num_train=100, num_test=20)
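For orientation, here is a minimal sketch of consuming the metadata this script writes; the dataset path is the example default from the `__main__` block above, and the printed fields follow the dicts built in `_save_sample_images`.

```python
import json
from pathlib import Path

# Assumes the example defaults from the __main__ block above.
dataset_dir = Path("/home/claude/sequence_dataset")

with open(dataset_dir / "train.json") as f:
    train_meta = json.load(f)

sample = train_meta[0]
# Each sample lists its grid images in save order: partial_0, complete_0, partial_1, ...
print(sample["seq_type"], sample["num_pairs"])
print(sample["images"])   # e.g. ["00000_0.png", "00000_1.png", ...]
print(sample["answer"])   # ground-truth value of the last pair's hidden cell
```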
sequence/test.py
ADDED
@@ -0,0 +1,376 @@
"""
Sequence Prediction Evaluation with QwenImageEditPlusPipeline / Flux2KleinPipeline.

Evaluates the model's ability to predict the next number in a sequence
by generating images and extracting answers via OCR.
"""

import json
import re
from pathlib import Path
from dataclasses import dataclass, field
from enum import Enum

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm


class ModelType(str, Enum):
    QWEN_IMAGE_EDIT = "qwen"
    FLUX2_KLEIN = "flux2-klein"


@dataclass
class EvalConfig:
    """Evaluation configuration."""
    dataset_dir: str = "sequence_dataset"
    output_dir: str = "eval_results"

    # Model selection
    model_type: ModelType = ModelType.QWEN_IMAGE_EDIT
    model_id: str = ""  # Auto-set based on model_type if empty

    # Prompts
    prompt: str = (
        "Based on the number patterns shown in the previous images, "
        "fill in the missing number in the empty cell of the last image."
    )
    negative_prompt: str = ""

    # Generation params
    num_inference_steps: int = 5
    guidance_scale: float = 1.0
    true_cfg_scale: float = 4.0  # For Qwen
    height: int = 210
    width: int = 750

    seed: int = 42
    device: str = "cuda"
    dtype: torch.dtype = field(default_factory=lambda: torch.bfloat16)

    def __post_init__(self):
        """Set default model_id based on model_type."""
        if not self.model_id:
            if self.model_type == ModelType.QWEN_IMAGE_EDIT:
                self.model_id = "Qwen/Qwen-Image-Edit-2509"
            elif self.model_type == ModelType.FLUX2_KLEIN:
                self.model_id = "black-forest-labs/FLUX.2-klein-9B"


class OCRExtractor:
    """Extract numbers from grid images using OCR."""

    def __init__(self, backend: str = "easyocr"):
        """
        Args:
            backend: OCR backend ("easyocr" or "pytesseract").
        """
        self.backend = backend
        if backend == "easyocr":
            import easyocr
            self.reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
        elif backend == "pytesseract":
            import pytesseract
            self.pytesseract = pytesseract
        else:
            raise ValueError(f"Unknown backend: {backend}")

    def extract_last_number(self, image: Image.Image) -> int | None:
        """
        Extract the last (rightmost) number from a grid image.

        Args:
            image: PIL Image of the number grid.

        Returns:
            Extracted number or None if extraction fails.
        """
        w, h = image.size
        cell_crop = image.crop((w * 3 // 4, 0, w, h))
        cell_array = np.array(cell_crop)

        if self.backend == "easyocr":
            results = self.reader.readtext(cell_array)
            for _, text, conf in results:
                digits = re.findall(r'-?\d+', text)
                if digits:
                    return int(digits[0])

        elif self.backend == "pytesseract":
            text = self.pytesseract.image_to_string(
                cell_crop, config='--psm 7 -c tessedit_char_whitelist=0123456789-'
            )
            digits = re.findall(r'-?\d+', text)
            if digits:
                return int(digits[0])

        return None

    def extract_all_numbers(self, image: Image.Image, num_cells: int = 4) -> list[int | None]:
        """Extract all numbers from a grid image."""
        w, h = image.size
        cell_width = w // num_cells
        numbers = []

        for i in range(num_cells):
            cell_crop = image.crop((i * cell_width, 0, (i + 1) * cell_width, h))
            cell_array = np.array(cell_crop)

            if self.backend == "easyocr":
                results = self.reader.readtext(cell_array)
                num = None
                for _, text, conf in results:
                    digits = re.findall(r'-?\d+', text)
                    if digits:
                        num = int(digits[0])
                        break
                numbers.append(num)

            elif self.backend == "pytesseract":
                text = self.pytesseract.image_to_string(
                    cell_crop, config='--psm 7 -c tessedit_char_whitelist=0123456789-'
                )
                digits = re.findall(r'-?\d+', text)
                numbers.append(int(digits[0]) if digits else None)

        return numbers


class SequenceEvaluator:
    """Evaluator for sequence prediction task."""

    def __init__(self, config: EvalConfig):
        self.config = config
        self.output_dir = Path(config.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Load pipeline based on model type
        self.pipeline = self._load_pipeline()

        # Initialize OCR
        self.ocr = OCRExtractor(backend="easyocr")

    def _load_pipeline(self):
        """Load pipeline based on model type."""
        if self.config.model_type == ModelType.QWEN_IMAGE_EDIT:
            return self._load_qwen_pipeline()
        elif self.config.model_type == ModelType.FLUX2_KLEIN:
            return self._load_flux2_klein_pipeline()
        else:
            raise ValueError(f"Unknown model type: {self.config.model_type}")

    def _load_qwen_pipeline(self):
        """Load QwenImageEditPlusPipeline."""
        from diffusers import QwenImageEditPlusPipeline

        pipeline = QwenImageEditPlusPipeline.from_pretrained(
            self.config.model_id,
            torch_dtype=self.config.dtype,
        )
        pipeline.to(self.config.device)
        pipeline.set_progress_bar_config(disable=True)
        return pipeline

    def _load_flux2_klein_pipeline(self):
        """Load Flux2KleinPipeline."""
        from diffusers import Flux2KleinPipeline

        pipeline = Flux2KleinPipeline.from_pretrained(
            self.config.model_id,
            torch_dtype=self.config.dtype,
        )
        pipeline.enable_model_cpu_offload()
        pipeline.set_progress_bar_config(disable=True)
        return pipeline

    def _load_images(self, image_paths: list[str], image_dir: Path) -> list[Image.Image]:
        """Load images from paths."""
        return [Image.open(image_dir / p).convert("RGB") for p in image_paths]

    def predict(self, images: list[Image.Image]) -> Image.Image:
        """
        Generate prediction image given input images.

        Args:
            images: List of input images (context + query).

        Returns:
            Generated image with predicted number.
        """
        generator = torch.Generator(device=self.config.device).manual_seed(self.config.seed)

        if self.config.model_type == ModelType.QWEN_IMAGE_EDIT:
            inputs = {
                "image": images,
                "prompt": self.config.prompt,
                "generator": generator,
                "true_cfg_scale": self.config.true_cfg_scale,
                "negative_prompt": self.config.negative_prompt,
                "num_inference_steps": self.config.num_inference_steps,
            }

        elif self.config.model_type == ModelType.FLUX2_KLEIN:
            # Flux2Klein uses image parameter for multi-image editing
            inputs = {
                "image": images,
                "prompt": self.config.prompt,
                "generator": generator,
                "guidance_scale": self.config.guidance_scale,
                "num_inference_steps": self.config.num_inference_steps,
                "height": self.config.height,
                "width": self.config.width,
            }

        with torch.inference_mode():
            output = self.pipeline(**inputs)

        return output.images[0]

    def evaluate_sample(self, sample: dict, image_dir: Path) -> dict:
        """
        Evaluate a single sample.

        Args:
            sample: Sample metadata dict.
            image_dir: Directory containing images.

        Returns:
            Evaluation result dict.
        """
        # Load input images (all available in test set)
        images = self._load_images(sample["images"], image_dir)

        # Generate prediction
        pred_image = self.predict(images)

        # Save prediction image
        pred_path = self.output_dir / f"{sample['id']:05d}_pred.png"
        pred_image.save(pred_path)

        # Extract predicted number via OCR
        pred_number = self.ocr.extract_last_number(pred_image)

        # Get ground truth
        gt_number = sample["answer"]

        # Check correctness
        correct = pred_number == gt_number

        return {
            "id": sample["id"],
            "seq_type": sample["seq_type"],
            "gt_answer": gt_number,
            "pred_answer": pred_number,
            "correct": correct,
            "pred_image": str(pred_path),
        }

    def evaluate(self, split: str = "test") -> dict:
        """
        Evaluate on entire dataset split.

        Args:
            split: Dataset split ("train" or "test").

        Returns:
            Evaluation results summary.
        """
        dataset_dir = Path(self.config.dataset_dir)

        # Load metadata
        with open(dataset_dir / f"{split}.json") as f:
            samples = json.load(f)

        image_dir = dataset_dir / split / "images"

        results = []
        for sample in tqdm(samples, desc=f"Evaluating {split}"):
            result = self.evaluate_sample(sample, image_dir)
            results.append(result)

        # Compute metrics
        total = len(results)
        correct = sum(r["correct"] for r in results)
        accuracy = correct / total if total > 0 else 0.0

        # Per-type accuracy
        type_stats = {}
        for r in results:
            seq_type = r["seq_type"]
            if seq_type not in type_stats:
                type_stats[seq_type] = {"correct": 0, "total": 0}
            type_stats[seq_type]["total"] += 1
            if r["correct"]:
                type_stats[seq_type]["correct"] += 1

        type_accuracy = {
            k: v["correct"] / v["total"] for k, v in type_stats.items()
        }

        summary = {
            "split": split,
            "model_type": self.config.model_type.value,
            "model_id": self.config.model_id,
            "total": total,
            "correct": correct,
            "accuracy": accuracy,
            "type_accuracy": type_accuracy,
            "results": results,
        }

        # Save results
        with open(self.output_dir / f"{split}_results.json", "w") as f:
            json.dump(summary, f, indent=2)

        return summary


def main():
    """Run evaluation."""
    import argparse

    parser = argparse.ArgumentParser(description="Sequence Prediction Evaluation")
    parser.add_argument("--model", type=str, default="flux2-klein",
                        choices=["qwen", "flux2-klein"],
                        help="Model type to use")
    parser.add_argument("--model-id", type=str, default="",
                        help="Custom model ID (optional)")
    parser.add_argument("--dataset-dir", type=str, default="sequence_dataset",
                        help="Dataset directory")
    parser.add_argument("--output-dir", type=str, default="eval_results",
                        help="Output directory")
    parser.add_argument("--steps", type=int, default=50,
                        help="Number of inference steps")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    args = parser.parse_args()

    config = EvalConfig(
        dataset_dir=args.dataset_dir,
        output_dir=args.output_dir,
        model_type=ModelType(args.model),
        model_id=args.model_id,
        num_inference_steps=args.steps,
        seed=args.seed,
    )

    print(f"Model: {config.model_type.value} ({config.model_id})")

    evaluator = SequenceEvaluator(config)
    results = evaluator.evaluate("test")

    print(f"\n{'='*50}")
    print(f"Evaluation Results ({config.model_type.value})")
    print(f"{'='*50}")
    print(f"Total samples: {results['total']}")
    print(f"Correct: {results['correct']}")
    print(f"Accuracy: {results['accuracy']:.2%}")
    print(f"\nPer-type accuracy:")
    for seq_type, acc in sorted(results["type_accuracy"].items()):
        print(f"  {seq_type}: {acc:.2%}")


if __name__ == "__main__":
    main()
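For completeness, a minimal programmatic sketch of the same evaluation that `main()` drives from the CLI; the config values shown are illustrative, and the import path assumes the script is run from its own directory.

```python
# Assumes this is run from the sequence/ directory so the local test.py
# shadows Python's stdlib `test` package on sys.path.
from test import EvalConfig, ModelType, SequenceEvaluator

config = EvalConfig(
    dataset_dir="sequence_dataset",
    output_dir="eval_results",
    model_type=ModelType.QWEN_IMAGE_EDIT,  # or ModelType.FLUX2_KLEIN
    num_inference_steps=50,
)
evaluator = SequenceEvaluator(config)
summary = evaluator.evaluate("test")
print(f"accuracy: {summary['accuracy']:.2%}")
```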
sudoku/convert.py
ADDED
@@ -0,0 +1,81 @@
"""
Convert a single safetensors file to HuggingFace Diffusers format.

Usage:
    python convert.py --ckpt epoch-4.safetensors --model_type Wan-T2V-14B --output_path ./output
"""

import argparse
import torch
from safetensors.torch import load_file
from accelerate import init_empty_weights

# Imported from the original script (or copy the relevant dicts and functions directly)
from convert_wan import (
    get_transformer_config,
    update_state_dict_,
    DTYPE_MAPPING,
)
from diffusers import WanTransformer3DModel, WanVACETransformer3DModel, WanAnimateTransformer3DModel


def convert_single_checkpoint(ckpt_path: str, model_type: str, dtype: str = "bf16"):
    """
    Convert a single checkpoint file into a Diffusers-format transformer.

    Args:
        ckpt_path: Path to the safetensors file.
        model_type: Model type, e.g. "Wan-T2V-14B", "Wan-I2V-14B-720p".
        dtype: Output precision.

    Returns:
        The converted transformer model.
    """
    # 1. Get the config and key-rename rules
    config, rename_dict, special_keys_remap = get_transformer_config(model_type)
    diffusers_config = config["diffusers_config"]

    # 2. Load the original weights
    state_dict = load_file(ckpt_path)

    # 3. Rename keys
    for key in list(state_dict.keys()):
        new_key = key
        for old, new in rename_dict.items():
            new_key = new_key.replace(old, new)
        update_state_dict_(state_dict, key, new_key)

    # 4. Handle special keys
    for key in list(state_dict.keys()):
        for special_key, handler_fn in special_keys_remap.items():
            if special_key in key:
                handler_fn(key, state_dict)

    # 5. Create the model and load the weights
    with init_empty_weights():
        if "Animate" in model_type:
            transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
        elif "VACE" in model_type:
            transformer = WanVACETransformer3DModel.from_config(diffusers_config)
        else:
            transformer = WanTransformer3DModel.from_config(diffusers_config)

    transformer.load_state_dict(state_dict, strict=True, assign=True)

    if dtype != "none":
        transformer = transformer.to(DTYPE_MAPPING[dtype])

    return transformer


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, required=True, help="Path to the safetensors file")
    parser.add_argument("--model_type", type=str, required=True, help="Model type")
    parser.add_argument("--output_path", type=str, required=True, help="Output directory")
    parser.add_argument("--dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16", "none"])
    args = parser.parse_args()

    transformer = convert_single_checkpoint(args.ckpt, args.model_type, args.dtype)
    transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
    print(f"Saved to {args.output_path}")
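A hedged sketch of one way the converted transformer could be used afterwards: swapping it into a Wan pipeline. The output directory and base repo ID are assumptions for illustration, not something this script prescribes.

```python
import torch
from diffusers import WanImageToVideoPipeline, WanTransformer3DModel

# Hypothetical paths: the --output_path written by convert.py above, and a
# base Diffusers repo supplying the VAE, text encoder, and scheduler.
transformer = WanTransformer3DModel.from_pretrained(
    "./output",
    torch_dtype=torch.bfloat16,
)
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers",  # assumed base checkpoint
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
```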
sudoku/convert_wan.py
ADDED
@@ -0,0 +1,1287 @@
| 1 |
+
import argparse
|
| 2 |
+
import pathlib
|
| 3 |
+
from typing import Any, Dict, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from accelerate import init_empty_weights
|
| 7 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 8 |
+
from safetensors.torch import load_file
|
| 9 |
+
from transformers import (
|
| 10 |
+
AutoProcessor,
|
| 11 |
+
AutoTokenizer,
|
| 12 |
+
CLIPImageProcessor,
|
| 13 |
+
CLIPVisionModel,
|
| 14 |
+
CLIPVisionModelWithProjection,
|
| 15 |
+
UMT5EncoderModel,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
from diffusers import (
|
| 19 |
+
AutoencoderKLWan,
|
| 20 |
+
UniPCMultistepScheduler,
|
| 21 |
+
WanAnimatePipeline,
|
| 22 |
+
WanAnimateTransformer3DModel,
|
| 23 |
+
WanImageToVideoPipeline,
|
| 24 |
+
WanPipeline,
|
| 25 |
+
WanTransformer3DModel,
|
| 26 |
+
WanVACEPipeline,
|
| 27 |
+
WanVACETransformer3DModel,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
TRANSFORMER_KEYS_RENAME_DICT = {
|
| 32 |
+
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
|
| 33 |
+
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
|
| 34 |
+
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
|
| 35 |
+
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
|
| 36 |
+
"time_projection.1": "condition_embedder.time_proj",
|
| 37 |
+
"head.modulation": "scale_shift_table",
|
| 38 |
+
"head.head": "proj_out",
|
| 39 |
+
"modulation": "scale_shift_table",
|
| 40 |
+
"ffn.0": "ffn.net.0.proj",
|
| 41 |
+
"ffn.2": "ffn.net.2",
|
| 42 |
+
# Hack to swap the layer names
|
| 43 |
+
# The original model calls the norms in following order: norm1, norm3, norm2
|
| 44 |
+
# We convert it to: norm1, norm2, norm3
|
| 45 |
+
"norm2": "norm__placeholder",
|
| 46 |
+
"norm3": "norm2",
|
| 47 |
+
"norm__placeholder": "norm3",
|
| 48 |
+
# For the I2V model
|
| 49 |
+
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
|
| 50 |
+
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
|
| 51 |
+
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
|
| 52 |
+
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
|
| 53 |
+
# for the FLF2V model
|
| 54 |
+
"img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
|
| 55 |
+
# Add attention component mappings
|
| 56 |
+
"self_attn.q": "attn1.to_q",
|
| 57 |
+
"self_attn.k": "attn1.to_k",
|
| 58 |
+
"self_attn.v": "attn1.to_v",
|
| 59 |
+
"self_attn.o": "attn1.to_out.0",
|
| 60 |
+
"self_attn.norm_q": "attn1.norm_q",
|
| 61 |
+
"self_attn.norm_k": "attn1.norm_k",
|
| 62 |
+
"cross_attn.q": "attn2.to_q",
|
| 63 |
+
"cross_attn.k": "attn2.to_k",
|
| 64 |
+
"cross_attn.v": "attn2.to_v",
|
| 65 |
+
"cross_attn.o": "attn2.to_out.0",
|
| 66 |
+
"cross_attn.norm_q": "attn2.norm_q",
|
| 67 |
+
"cross_attn.norm_k": "attn2.norm_k",
|
| 68 |
+
"attn2.to_k_img": "attn2.add_k_proj",
|
| 69 |
+
"attn2.to_v_img": "attn2.add_v_proj",
|
| 70 |
+
"attn2.norm_k_img": "attn2.norm_added_k",
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
VACE_TRANSFORMER_KEYS_RENAME_DICT = {
|
| 74 |
+
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
|
| 75 |
+
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
|
| 76 |
+
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
|
| 77 |
+
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
|
| 78 |
+
"time_projection.1": "condition_embedder.time_proj",
|
| 79 |
+
"head.modulation": "scale_shift_table",
|
| 80 |
+
"head.head": "proj_out",
|
| 81 |
+
"modulation": "scale_shift_table",
|
| 82 |
+
"ffn.0": "ffn.net.0.proj",
|
| 83 |
+
"ffn.2": "ffn.net.2",
|
| 84 |
+
# Hack to swap the layer names
|
| 85 |
+
# The original model calls the norms in following order: norm1, norm3, norm2
|
| 86 |
+
# We convert it to: norm1, norm2, norm3
|
| 87 |
+
"norm2": "norm__placeholder",
|
| 88 |
+
"norm3": "norm2",
|
| 89 |
+
"norm__placeholder": "norm3",
|
| 90 |
+
# # For the I2V model
|
| 91 |
+
# "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
|
| 92 |
+
# "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
|
| 93 |
+
# "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
|
| 94 |
+
# "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
|
| 95 |
+
# # for the FLF2V model
|
| 96 |
+
# "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
|
| 97 |
+
# Add attention component mappings
|
| 98 |
+
"self_attn.q": "attn1.to_q",
|
| 99 |
+
"self_attn.k": "attn1.to_k",
|
| 100 |
+
"self_attn.v": "attn1.to_v",
|
| 101 |
+
"self_attn.o": "attn1.to_out.0",
|
| 102 |
+
"self_attn.norm_q": "attn1.norm_q",
|
| 103 |
+
"self_attn.norm_k": "attn1.norm_k",
|
| 104 |
+
"cross_attn.q": "attn2.to_q",
|
| 105 |
+
"cross_attn.k": "attn2.to_k",
|
| 106 |
+
"cross_attn.v": "attn2.to_v",
|
| 107 |
+
"cross_attn.o": "attn2.to_out.0",
|
| 108 |
+
"cross_attn.norm_q": "attn2.norm_q",
|
| 109 |
+
"cross_attn.norm_k": "attn2.norm_k",
|
| 110 |
+
"attn2.to_k_img": "attn2.add_k_proj",
|
| 111 |
+
"attn2.to_v_img": "attn2.add_v_proj",
|
| 112 |
+
"attn2.norm_k_img": "attn2.norm_added_k",
|
| 113 |
+
"before_proj": "proj_in",
|
| 114 |
+
"after_proj": "proj_out",
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = {
|
| 118 |
+
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
|
| 119 |
+
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
|
| 120 |
+
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
|
| 121 |
+
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
|
| 122 |
+
"time_projection.1": "condition_embedder.time_proj",
|
| 123 |
+
"head.modulation": "scale_shift_table",
|
| 124 |
+
"head.head": "proj_out",
|
| 125 |
+
"modulation": "scale_shift_table",
|
| 126 |
+
"ffn.0": "ffn.net.0.proj",
|
| 127 |
+
"ffn.2": "ffn.net.2",
|
| 128 |
+
# Hack to swap the layer names
|
| 129 |
+
# The original model calls the norms in following order: norm1, norm3, norm2
|
| 130 |
+
# We convert it to: norm1, norm2, norm3
|
| 131 |
+
"norm2": "norm__placeholder",
|
| 132 |
+
"norm3": "norm2",
|
| 133 |
+
"norm__placeholder": "norm3",
|
| 134 |
+
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
|
| 135 |
+
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
|
| 136 |
+
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
|
| 137 |
+
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
|
| 138 |
+
# Add attention component mappings
|
| 139 |
+
"self_attn.q": "attn1.to_q",
|
| 140 |
+
"self_attn.k": "attn1.to_k",
|
| 141 |
+
"self_attn.v": "attn1.to_v",
|
| 142 |
+
"self_attn.o": "attn1.to_out.0",
|
| 143 |
+
"self_attn.norm_q": "attn1.norm_q",
|
| 144 |
+
"self_attn.norm_k": "attn1.norm_k",
|
| 145 |
+
"cross_attn.q": "attn2.to_q",
|
| 146 |
+
"cross_attn.k": "attn2.to_k",
|
| 147 |
+
"cross_attn.v": "attn2.to_v",
|
| 148 |
+
"cross_attn.o": "attn2.to_out.0",
|
| 149 |
+
"cross_attn.norm_q": "attn2.norm_q",
|
| 150 |
+
"cross_attn.norm_k": "attn2.norm_k",
|
| 151 |
+
"cross_attn.k_img": "attn2.to_k_img",
|
| 152 |
+
"cross_attn.v_img": "attn2.to_v_img",
|
| 153 |
+
"cross_attn.norm_k_img": "attn2.norm_k_img",
|
| 154 |
+
# After cross_attn -> attn2 rename, we need to rename the img keys
|
| 155 |
+
"attn2.to_k_img": "attn2.add_k_proj",
|
| 156 |
+
"attn2.to_v_img": "attn2.add_v_proj",
|
| 157 |
+
"attn2.norm_k_img": "attn2.norm_added_k",
|
| 158 |
+
# Wan Animate-specific mappings (motion encoder, face encoder, face adapter)
|
| 159 |
+
# Motion encoder mappings
|
| 160 |
+
# The name mapping is complicated for the convolutional part so we handle that in its own function
|
| 161 |
+
"motion_encoder.enc.fc": "motion_encoder.motion_network",
|
| 162 |
+
"motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
|
| 163 |
+
# Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten
|
| 164 |
+
"face_encoder.conv1_local.conv": "face_encoder.conv1_local",
|
| 165 |
+
"face_encoder.conv2.conv": "face_encoder.conv2",
|
| 166 |
+
"face_encoder.conv3.conv": "face_encoder.conv3",
|
| 167 |
+
# Face adapter mappings are handled in a separate function
|
| 168 |
+
}


# TODO: Verify this and simplify if possible.
def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None:
    """
    Convert all motion encoder weights for the Animate model.

    In the original model:
    - All Linear layers in fc use EqualLinear
    - All Conv2d layers in convs use EqualConv2d (except blur_conv which is initialized separately)
    - Blur kernels are stored as buffers in Sequential modules
    - ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)]

    Conversion strategy:
    1. Drop .kernel buffers (blur kernels)
    2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu)
    """
    # Skip if not a weight, bias, or kernel
    if ".weight" not in key and ".bias" not in key and ".kernel" not in key:
        return

    # Handle Blur kernel buffers from the original implementation.
    # After renaming, these appear under: motion_encoder.res_blocks.*.conv{2,skip}.blur_kernel
    # Diffusers constructs blur kernels as a non-persistent buffer, so we must drop these keys
    if ".kernel" in key and "motion_encoder" in key:
        # Remove unexpected blur kernel buffers to avoid strict load errors
        state_dict.pop(key, None)
        return

    # Rename Sequential indices to named components in ConvLayer and ResBlock
    if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key):
        parts = key.split(".")

        # Find the sequential index (digit) after convs or after conv1/conv2/skip
        # Examples:
        # - enc.net_app.convs.0.0.weight -> conv_in.weight (initial conv layer weight)
        # - enc.net_app.convs.0.1.bias -> conv_in.act_fn.bias (initial conv layer bias)
        # - enc.net_app.convs.{n:1-7}.conv1.0.weight -> res_blocks.{(n-1):0-6}.conv1.weight (conv1 weight)
        #   - e.g. enc.net_app.convs.1.conv1.0.weight -> res_blocks.0.conv1.weight
        # - enc.net_app.convs.{n:1-7}.conv1.1.bias -> res_blocks.{(n-1):0-6}.conv1.act_fn.bias (conv1 bias)
        #   - e.g. enc.net_app.convs.1.conv1.1.bias -> res_blocks.0.conv1.act_fn.bias
        # - enc.net_app.convs.{n:1-7}.conv2.1.weight -> res_blocks.{(n-1):0-6}.conv2.weight (conv2 weight)
        #   - e.g. enc.net_app.convs.1.conv2.2.bias -> res_blocks.0.conv2.act_fn.bias (conv2 bias)
        # - enc.net_app.convs.{n:1-7}.skip.1.weight -> res_blocks.{(n-1):0-6}.conv_skip.weight (skip conv weight)
        # - enc.net_app.convs.8 -> conv_out (final conv layer)

        convs_idx = parts.index("convs") if "convs" in parts else -1
        if convs_idx >= 0 and len(parts) - convs_idx >= 2:
            bias = False
            # The nn.Sequential index will always follow convs
            sequential_idx = int(parts[convs_idx + 1])
            if sequential_idx == 0:
                if key.endswith(".weight"):
                    new_key = "motion_encoder.conv_in.weight"
                elif key.endswith(".bias"):
                    new_key = "motion_encoder.conv_in.act_fn.bias"
                    bias = True
            elif sequential_idx == final_conv_idx:
                if key.endswith(".weight"):
                    new_key = "motion_encoder.conv_out.weight"
            else:
                # Intermediate .convs. layers, which get mapped to .res_blocks.
                prefix = "motion_encoder.res_blocks."

                layer_name = parts[convs_idx + 2]
                if layer_name == "skip":
                    layer_name = "conv_skip"

                if key.endswith(".weight"):
                    param_name = "weight"
                elif key.endswith(".bias"):
                    param_name = "act_fn.bias"
                    bias = True

                suffix_parts = [str(sequential_idx - 1), layer_name, param_name]
                suffix = ".".join(suffix_parts)
                new_key = prefix + suffix

            param = state_dict.pop(key)
            if bias:
                param = param.squeeze()
            state_dict[new_key] = param
            return
        return
    return
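
# Illustrative sketch of the renaming above (not executed; the tensor shape is
# hypothetical): an intermediate conv weight at sequential index 1 lands in res_blocks[0].
#   sd = {"motion_encoder.enc.net_app.convs.1.conv1.0.weight": torch.randn(128, 64, 3, 3)}
#   convert_animate_motion_encoder_weights("motion_encoder.enc.net_app.convs.1.conv1.0.weight", sd)
#   # sd now holds "motion_encoder.res_blocks.0.conv1.weight"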


def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None:
    """
    Convert face adapter weights for the Animate model.

    The original model uses a fused KV projection, but the diffusers model uses separate K and V projections.
    """
    # Skip if not a weight or bias
    if ".weight" not in key and ".bias" not in key:
        return

    prefix = "face_adapter."
    if ".fuser_blocks." in key:
        parts = key.split(".")

        module_list_idx = parts.index("fuser_blocks") if "fuser_blocks" in parts else -1
        if module_list_idx >= 0 and (len(parts) - 1) - module_list_idx == 3:
            block_idx = parts[module_list_idx + 1]
            layer_name = parts[module_list_idx + 2]
            param_name = parts[module_list_idx + 3]

            if layer_name == "linear1_kv":
                layer_name_k = "to_k"
                layer_name_v = "to_v"

                suffix_k = ".".join([block_idx, layer_name_k, param_name])
                suffix_v = ".".join([block_idx, layer_name_v, param_name])
                new_key_k = prefix + suffix_k
                new_key_v = prefix + suffix_v

                kv_proj = state_dict.pop(key)
                k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0)
                state_dict[new_key_k] = k_proj
                state_dict[new_key_v] = v_proj
                return
            else:
                if layer_name == "q_norm":
                    new_layer_name = "norm_q"
                elif layer_name == "k_norm":
                    new_layer_name = "norm_k"
                elif layer_name == "linear1_q":
                    new_layer_name = "to_q"
                elif layer_name == "linear2":
                    new_layer_name = "to_out"

                suffix_parts = [block_idx, new_layer_name, param_name]
                suffix = ".".join(suffix_parts)
                new_key = prefix + suffix
                state_dict[new_key] = state_dict.pop(key)
                return
    return
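
# Illustrative sketch of the KV split above (not executed; the dimensions are
# hypothetical): the fused projection has shape (2 * dim, in_dim) and is chunked
# into equal K/V halves along dim 0.
#   sd = {"face_adapter.fuser_blocks.0.linear1_kv.weight": torch.randn(2 * 5120, 5120)}
#   convert_animate_face_adapter_weights("face_adapter.fuser_blocks.0.linear1_kv.weight", sd)
#   # sd["face_adapter.0.to_k.weight"].shape == torch.Size([5120, 5120])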


TRANSFORMER_SPECIAL_KEYS_REMAP = {}
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "motion_encoder": convert_animate_motion_encoder_weights,
    "face_adapter": convert_animate_face_adapter_weights,
}


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    # Renames a key in place; the trailing underscore signals in-place mutation.
    state_dict[new_key] = state_dict.pop(old_key)


def load_sharded_safetensors(dir: pathlib.Path):
    file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
    state_dict = {}
    for path in file_paths:
        state_dict.update(load_file(path))
    return state_dict


def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
    if model_type == "Wan-T2V-1.3B":
        config = {
            "model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
            "diffusers_config": {
                "added_kv_proj_dim": None, "attention_head_dim": 128, "cross_attn_norm": True,
                "eps": 1e-06, "ffn_dim": 8960, "freq_dim": 256, "in_channels": 16,
                "num_attention_heads": 12, "num_layers": 30, "out_channels": 16,
                "patch_size": [1, 2, 2], "qk_norm": "rms_norm_across_heads", "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-T2V-14B":
        config = {
            "model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
            "diffusers_config": {
                "added_kv_proj_dim": None, "attention_head_dim": 128, "cross_attn_norm": True,
                "eps": 1e-06, "ffn_dim": 13824, "freq_dim": 256, "in_channels": 16,
                "num_attention_heads": 40, "num_layers": 40, "out_channels": 16,
                "patch_size": [1, 2, 2], "qk_norm": "rms_norm_across_heads", "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-I2V-14B-480p":
        config = {
            "model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
            "diffusers_config": {
                "image_dim": 1280, "added_kv_proj_dim": 5120, "attention_head_dim": 128,
                "cross_attn_norm": True, "eps": 1e-06, "ffn_dim": 13824, "freq_dim": 256,
                "in_channels": 36, "num_attention_heads": 40, "num_layers": 40, "out_channels": 16,
                "patch_size": [1, 2, 2], "qk_norm": "rms_norm_across_heads", "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-I2V-14B-720p":
        config = {
            "model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
            "diffusers_config": {
                "image_dim": 1280, "added_kv_proj_dim": 5120, "attention_head_dim": 128,
                "cross_attn_norm": True, "eps": 1e-06, "ffn_dim": 13824, "freq_dim": 256,
                "in_channels": 36, "num_attention_heads": 40, "num_layers": 40, "out_channels": 16,
                "patch_size": [1, 2, 2], "qk_norm": "rms_norm_across_heads", "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-FLF2V-14B-720P":
        config = {
            "model_id": "ypyp/Wan2.1-FLF2V-14B-720P",  # This is just a placeholder
            "diffusers_config": {
                "image_dim": 1280, "added_kv_proj_dim": 5120, "attention_head_dim": 128,
                "cross_attn_norm": True, "eps": 1e-06, "ffn_dim": 13824, "freq_dim": 256,
                "in_channels": 36, "num_attention_heads": 40, "num_layers": 40, "out_channels": 16,
                "patch_size": [1, 2, 2], "qk_norm": "rms_norm_across_heads", "text_dim": 4096,
                "rope_max_seq_len": 1024, "pos_embed_seq_len": 257 * 2,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-VACE-1.3B":
        config = {
            "model_id": "Wan-AI/Wan2.1-VACE-1.3B",
            "diffusers_config": {
                "added_kv_proj_dim": None, "attention_head_dim": 128, "cross_attn_norm": True,
                "eps": 1e-06, "ffn_dim": 8960, "freq_dim": 256, "in_channels": 16,
                "num_attention_heads": 12, "num_layers": 30, "out_channels": 16,
                "patch_size": [1, 2, 2], "qk_norm": "rms_norm_across_heads", "text_dim": 4096,
                "vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
                "vace_in_channels": 96,
            },
        }
        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-VACE-14B":
        config = {
            "model_id": "Wan-AI/Wan2.1-VACE-14B",
            "diffusers_config": {
                "added_kv_proj_dim": None, "attention_head_dim": 128, "cross_attn_norm": True,
                "eps": 1e-06, "ffn_dim": 13824, "freq_dim": 256, "in_channels": 16,
                "num_attention_heads": 40, "num_layers": 40, "out_channels": 16,
                "patch_size": [1, 2, 2], "qk_norm": "rms_norm_across_heads", "text_dim": 4096,
                "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35], "vace_in_channels": 96,
            },
        }
        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-VACE-Fun-14B":
        config = {
            "model_id": "alibaba-pai/Wan2.2-VACE-Fun-A14B",
            "diffusers_config": {
                "added_kv_proj_dim": None, "attention_head_dim": 128, "cross_attn_norm": True,
                "eps": 1e-06, "ffn_dim": 13824, "freq_dim": 256, "in_channels": 16,
                "num_attention_heads": 40, "num_layers": 40, "out_channels": 16,
                "patch_size": [1, 2, 2], "qk_norm": "rms_norm_across_heads", "text_dim": 4096,
                "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35], "vace_in_channels": 96,
            },
        }
        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-I2V-14B-720p":
        config = {
            "model_id": "Wan-AI/Wan2.2-I2V-A14B",
            "diffusers_config": {
                "added_kv_proj_dim": None, "attention_head_dim": 128, "cross_attn_norm": True,
                "eps": 1e-06, "ffn_dim": 13824, "freq_dim": 256, "in_channels": 36,
                "num_attention_heads": 40, "num_layers": 40, "out_channels": 16,
                "patch_size": [1, 2, 2], "qk_norm": "rms_norm_across_heads", "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-T2V-A14B":
        config = {
            "model_id": "Wan-AI/Wan2.2-T2V-A14B",
            "diffusers_config": {
                "added_kv_proj_dim": None, "attention_head_dim": 128, "cross_attn_norm": True,
                "eps": 1e-06, "ffn_dim": 13824, "freq_dim": 256, "in_channels": 16,
                "num_attention_heads": 40, "num_layers": 40, "out_channels": 16,
                "patch_size": [1, 2, 2], "qk_norm": "rms_norm_across_heads", "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-TI2V-5B":
        config = {
            "model_id": "Wan-AI/Wan2.2-TI2V-5B",
            "diffusers_config": {
                "added_kv_proj_dim": None, "attention_head_dim": 128, "cross_attn_norm": True,
                "eps": 1e-06, "ffn_dim": 14336, "freq_dim": 256, "in_channels": 48,
                "num_attention_heads": 24, "num_layers": 30, "out_channels": 48,
                "patch_size": [1, 2, 2], "qk_norm": "rms_norm_across_heads", "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-Animate-14B":
        config = {
            "model_id": "Wan-AI/Wan2.2-Animate-14B",
            "diffusers_config": {
                "image_dim": 1280, "added_kv_proj_dim": 5120, "attention_head_dim": 128,
                "cross_attn_norm": True, "eps": 1e-06, "ffn_dim": 13824, "freq_dim": 256,
                "in_channels": 36, "num_attention_heads": 40, "num_layers": 40, "out_channels": 16,
                "patch_size": (1, 2, 2), "qk_norm": "rms_norm_across_heads", "text_dim": 4096,
                "rope_max_seq_len": 1024, "pos_embed_seq_len": None,
                # Start of Wan Animate-specific configs
                "motion_encoder_size": 512, "motion_style_dim": 512, "motion_dim": 20,
                "motion_encoder_dim": 512, "face_encoder_hidden_dim": 1024,
                "face_encoder_num_heads": 4, "inject_face_latents_blocks": 5,
            },
        }
        RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP
    else:
        # Without this guard, an unknown model_type would raise an opaque NameError below.
        raise ValueError(f"Unsupported model_type: {model_type}")
    return config, RENAME_DICT, SPECIAL_KEYS_REMAP
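
# Usage sketch: the selector above returns (config, rename dict, special-keys remap), e.g.
#   config, rename_dict, special_remap = get_transformer_config("Wan2.2-TI2V-5B")
#   config["diffusers_config"]["num_layers"]  # 30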


def convert_transformer(model_type: str, stage: str = None):
    config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)

    diffusers_config = config["diffusers_config"]
    model_id = config["model_id"]
    model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))

    if stage is not None:
        model_dir = model_dir / stage

    original_state_dict = load_sharded_safetensors(model_dir)

    with init_empty_weights():
        if "Animate" in model_type:
            transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
        elif "VACE" in model_type:
            transformer = WanVACETransformer3DModel.from_config(diffusers_config)
        else:
            transformer = WanTransformer3DModel.from_config(diffusers_config)

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    # Load state dict into the meta model, which will materialize the tensors
    transformer.load_state_dict(original_state_dict, strict=True, assign=True)

    # Move to CPU to ensure all tensors are materialized
    transformer = transformer.to("cpu")

    return transformer
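
# Sketch of the rename pass above: each (old, new) pair is applied as a plain substring
# replacement, so a single key may be rewritten by several pairs in dict order.
# The pair below is hypothetical; the real pairs live in the *_KEYS_RENAME_DICT tables.
#   key = "blocks.0.self_attn.q.weight"
#   key = key.replace("self_attn.q", "attn1.to_q")  # -> "blocks.0.attn1.to_q.weight"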


def convert_vae():
    vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.1-T2V-14B", "Wan2.1_VAE.pth")
    old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
    new_state_dict = {}

    # Create mappings for specific components
    middle_key_mapping = {
        # Encoder middle block
        "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
        "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
        "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
        "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
        "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
        "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
        "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
        "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
        "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
        "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
        "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
        "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
        # Decoder middle block
        "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
        "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
        "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
        "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
        "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
        "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
        "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
        "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
        "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
        "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
        "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
        "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
    }

    # Create a mapping for attention blocks
    attention_mapping = {
        # Encoder middle attention
        "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
        "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
        "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
        "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
        "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
        # Decoder middle attention
        "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
        "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
        "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
        "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
        "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
    }

    # Create a mapping for the head components
    head_mapping = {
        # Encoder head
        "encoder.head.0.gamma": "encoder.norm_out.gamma",
        "encoder.head.2.bias": "encoder.conv_out.bias",
        "encoder.head.2.weight": "encoder.conv_out.weight",
        # Decoder head
        "decoder.head.0.gamma": "decoder.norm_out.gamma",
        "decoder.head.2.bias": "decoder.conv_out.bias",
        "decoder.head.2.weight": "decoder.conv_out.weight",
    }

    # Create a mapping for the quant components
    quant_mapping = {
        "conv1.weight": "quant_conv.weight",
        "conv1.bias": "quant_conv.bias",
        "conv2.weight": "post_quant_conv.weight",
        "conv2.bias": "post_quant_conv.bias",
    }

    # Process each key in the state dict
    for key, value in old_state_dict.items():
        # Handle middle block keys using the mapping
        if key in middle_key_mapping:
            new_key = middle_key_mapping[key]
            new_state_dict[new_key] = value
        # Handle attention blocks using the mapping
        elif key in attention_mapping:
            new_key = attention_mapping[key]
            new_state_dict[new_key] = value
        # Handle head keys using the mapping
        elif key in head_mapping:
            new_key = head_mapping[key]
            new_state_dict[new_key] = value
        # Handle quant keys using the mapping
        elif key in quant_mapping:
            new_key = quant_mapping[key]
            new_state_dict[new_key] = value
        # Handle encoder conv1
        elif key == "encoder.conv1.weight":
            new_state_dict["encoder.conv_in.weight"] = value
        elif key == "encoder.conv1.bias":
            new_state_dict["encoder.conv_in.bias"] = value
        # Handle decoder conv1
        elif key == "decoder.conv1.weight":
            new_state_dict["decoder.conv_in.weight"] = value
        elif key == "decoder.conv1.bias":
            new_state_dict["decoder.conv_in.bias"] = value
        # Handle encoder downsamples
        elif key.startswith("encoder.downsamples."):
            # Convert to down_blocks
            new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")

            # Convert residual block naming but keep the original structure
            if ".residual.0.gamma" in new_key:
                new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
            elif ".residual.2.bias" in new_key:
                new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
            elif ".residual.2.weight" in new_key:
                new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
            elif ".residual.3.gamma" in new_key:
                new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
            elif ".residual.6.bias" in new_key:
                new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
            elif ".residual.6.weight" in new_key:
                new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
            elif ".shortcut.bias" in new_key:
                new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
            elif ".shortcut.weight" in new_key:
                new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")

            new_state_dict[new_key] = value

        # Handle decoder upsamples
        elif key.startswith("decoder.upsamples."):
            # Convert to up_blocks
            parts = key.split(".")
            block_idx = int(parts[2])

            # Group residual blocks
            if "residual" in key:
                if block_idx in [0, 1, 2]:
                    new_block_idx = 0
                    resnet_idx = block_idx
                elif block_idx in [4, 5, 6]:
                    new_block_idx = 1
                    resnet_idx = block_idx - 4
                elif block_idx in [8, 9, 10]:
                    new_block_idx = 2
                    resnet_idx = block_idx - 8
                elif block_idx in [12, 13, 14]:
                    new_block_idx = 3
                    resnet_idx = block_idx - 12
                else:
                    # Keep as is for other blocks
                    new_state_dict[key] = value
                    continue

                # Convert residual block naming
                if ".residual.0.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
                elif ".residual.2.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
                elif ".residual.2.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
                elif ".residual.3.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
                elif ".residual.6.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
                elif ".residual.6.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
                else:
                    new_key = key

                new_state_dict[new_key] = value

            # Handle shortcut connections
            elif ".shortcut." in key:
                if block_idx == 4:
                    new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
                    new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                    new_key = new_key.replace(".shortcut.", ".conv_shortcut.")

                new_state_dict[new_key] = value

            # Handle upsamplers
            elif ".resample." in key or ".time_conv." in key:
                if block_idx == 3:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
                elif block_idx == 7:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
                elif block_idx == 11:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")

                new_state_dict[new_key] = value
            else:
                new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                new_state_dict[new_key] = value
        else:
            # Keep other keys unchanged
            new_state_dict[key] = value

    with init_empty_weights():
        vae = AutoencoderKLWan()
    vae.load_state_dict(new_state_dict, strict=True, assign=True)
    return vae
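
# Index-arithmetic sketch for the decoder regrouping above: original upsample indices
# 0-2, 4-6, 8-10, 12-14 hold resnets and 3, 7, 11 hold upsamplers, e.g.
#   decoder.upsamples.9.residual.2.weight -> decoder.up_blocks.2.resnets.1.conv1.weight
#   decoder.upsamples.7.resample.1.weight -> decoder.up_blocks.1.upsamplers.0.resample.1.weight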


vae22_diffusers_config = {
    "base_dim": 160,
    "z_dim": 48,
    "is_residual": True,
    "in_channels": 12,
    "out_channels": 12,
    "decoder_base_dim": 256,
    "scale_factor_temporal": 4,
    "scale_factor_spatial": 16,
    "patch_size": 2,
    "latents_mean": [
        -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
        -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
        -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
        -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
        -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
        0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
    ],
    "latents_std": [
        0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
        0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
        0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
        0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
        0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
        0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
    ],
    "clip_output": False,
}
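
# Usage sketch (an assumption about downstream code, not something this script does):
# latents_mean/latents_std are per-channel statistics over the 48 latent channels, and
# pipelines can normalize latents before the transformer and invert afterwards, e.g.
#   mean = torch.tensor(vae22_diffusers_config["latents_mean"]).view(1, 48, 1, 1, 1)
#   std = torch.tensor(vae22_diffusers_config["latents_std"]).view(1, 48, 1, 1, 1)
#   normalized = (latents - mean) / std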


def convert_vae_22():
    vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.2-TI2V-5B", "Wan2.2_VAE.pth")
    old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
    new_state_dict = {}

    # Create mappings for specific components
    middle_key_mapping = {
        # Encoder middle block
        "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
        "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
        "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
        "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
        "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
        "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
        "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
        "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
        "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
        "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
        "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
        "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
        # Decoder middle block
        "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
        "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
        "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
        "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
        "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
        "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
        "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
        "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
        "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
        "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
        "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
        "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
    }

    # Create a mapping for attention blocks
    attention_mapping = {
        # Encoder middle attention
        "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
        "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
        "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
        "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
        "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
        # Decoder middle attention
        "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
        "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
        "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
        "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
        "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
    }

    # Create a mapping for the head components
    head_mapping = {
        # Encoder head
        "encoder.head.0.gamma": "encoder.norm_out.gamma",
        "encoder.head.2.bias": "encoder.conv_out.bias",
        "encoder.head.2.weight": "encoder.conv_out.weight",
        # Decoder head
        "decoder.head.0.gamma": "decoder.norm_out.gamma",
        "decoder.head.2.bias": "decoder.conv_out.bias",
        "decoder.head.2.weight": "decoder.conv_out.weight",
    }

    # Create a mapping for the quant components
    quant_mapping = {
        "conv1.weight": "quant_conv.weight",
        "conv1.bias": "quant_conv.bias",
        "conv2.weight": "post_quant_conv.weight",
        "conv2.bias": "post_quant_conv.bias",
    }

    # Process each key in the state dict
    for key, value in old_state_dict.items():
        # Handle middle block keys using the mapping
        if key in middle_key_mapping:
            new_key = middle_key_mapping[key]
            new_state_dict[new_key] = value
        # Handle attention blocks using the mapping
        elif key in attention_mapping:
            new_key = attention_mapping[key]
            new_state_dict[new_key] = value
        # Handle head keys using the mapping
        elif key in head_mapping:
            new_key = head_mapping[key]
            new_state_dict[new_key] = value
        # Handle quant keys using the mapping
        elif key in quant_mapping:
            new_key = quant_mapping[key]
            new_state_dict[new_key] = value
        # Handle encoder conv1
        elif key == "encoder.conv1.weight":
            new_state_dict["encoder.conv_in.weight"] = value
        elif key == "encoder.conv1.bias":
            new_state_dict["encoder.conv_in.bias"] = value
        # Handle decoder conv1
        elif key == "decoder.conv1.weight":
            new_state_dict["decoder.conv_in.weight"] = value
        elif key == "decoder.conv1.bias":
            new_state_dict["decoder.conv_in.bias"] = value
        # Handle encoder downsamples
        elif key.startswith("encoder.downsamples."):
            # Change encoder.downsamples to encoder.down_blocks
            new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")

            # Handle residual blocks - change downsamples to resnets and rename components
            if "residual" in new_key or "shortcut" in new_key:
                # Change the second downsamples to resnets
                new_key = new_key.replace(".downsamples.", ".resnets.")

                # Rename residual components
                if ".residual.0.gamma" in new_key:
                    new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
                elif ".residual.2.weight" in new_key:
                    new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
                elif ".residual.2.bias" in new_key:
                    new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
                elif ".residual.3.gamma" in new_key:
                    new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
                elif ".residual.6.weight" in new_key:
                    new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
                elif ".residual.6.bias" in new_key:
                    new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
                elif ".shortcut.weight" in new_key:
                    new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
                elif ".shortcut.bias" in new_key:
                    new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")

            # Handle resample blocks - change downsamples to downsampler and remove index
            elif "resample" in new_key or "time_conv" in new_key:
                # Change the second downsamples to downsampler and remove the index
                parts = new_key.split(".")
                # Find the pattern: encoder.down_blocks.X.downsamples.Y.resample...
                # We want to change it to: encoder.down_blocks.X.downsampler.resample...
                if len(parts) >= 4 and parts[3] == "downsamples":
                    # Remove the index (parts[4]) and change downsamples to downsampler
                    new_parts = parts[:3] + ["downsampler"] + parts[5:]
                    new_key = ".".join(new_parts)

            new_state_dict[new_key] = value

        # Handle decoder upsamples
        elif key.startswith("decoder.upsamples."):
            # Change decoder.upsamples to decoder.up_blocks
            new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")

            # Handle residual blocks - change upsamples to resnets and rename components
            if "residual" in new_key or "shortcut" in new_key:
                # Change the second upsamples to resnets
                new_key = new_key.replace(".upsamples.", ".resnets.")

                # Rename residual components
                if ".residual.0.gamma" in new_key:
                    new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
                elif ".residual.2.weight" in new_key:
                    new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
                elif ".residual.2.bias" in new_key:
                    new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
                elif ".residual.3.gamma" in new_key:
                    new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
                elif ".residual.6.weight" in new_key:
                    new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
                elif ".residual.6.bias" in new_key:
                    new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
                elif ".shortcut.weight" in new_key:
                    new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
                elif ".shortcut.bias" in new_key:
                    new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")

            # Handle resample blocks - change upsamples to upsampler and remove index
            elif "resample" in new_key or "time_conv" in new_key:
                # Change the second upsamples to upsampler and remove the index
                parts = new_key.split(".")
                # Find the pattern: decoder.up_blocks.X.upsamples.Y.resample...
                # We want to change it to: decoder.up_blocks.X.upsampler.resample...
                if len(parts) >= 4 and parts[3] == "upsamples":
                    # Remove the index (parts[4]) and change upsamples to upsampler
                    new_parts = parts[:3] + ["upsampler"] + parts[5:]
                    new_key = ".".join(new_parts)

            new_state_dict[new_key] = value
        else:
            # Keep other keys unchanged
            new_state_dict[key] = value

    with init_empty_weights():
        vae = AutoencoderKLWan(**vae22_diffusers_config)
    vae.load_state_dict(new_state_dict, strict=True, assign=True)
    return vae
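
# Key-flow sketch for the resample handling above (the key is hypothetical): the
# per-block index is dropped and the module renamed to a single downsampler/upsampler.
#   "encoder.downsamples.1.downsamples.0.resample.1.weight"
#   -> "encoder.down_blocks.1.downsamples.0.resample.1.weight"   (prefix replace)
#   -> "encoder.down_blocks.1.downsampler.resample.1.weight"     (parts[:3] + ["downsampler"] + parts[5:])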


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default=None)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"])
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


if __name__ == "__main__":
    args = get_args()

    if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "Animate" not in args.model_type:
        transformer = convert_transformer(args.model_type, stage="high_noise_model")
        transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
    else:
        transformer = convert_transformer(args.model_type)
        transformer_2 = None

    if "Wan2.2" in args.model_type and "TI2V" in args.model_type:
        vae = convert_vae_22()
    else:
        vae = convert_vae()

    text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
    if "FLF2V" in args.model_type:
        flow_shift = 16.0
    elif "TI2V" in args.model_type or "Animate" in args.model_type:
        flow_shift = 5.0
    else:
        flow_shift = 3.0
    scheduler = UniPCMultistepScheduler(
        prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
    )

    # If the user has specified "none", we keep the original dtypes of the state dict without any conversion
    if args.dtype != "none":
        dtype = DTYPE_MAPPING[args.dtype]
        transformer.to(dtype)
        if transformer_2 is not None:
            transformer_2.to(dtype)

    if "Wan2.2" in args.model_type and "I2V" in args.model_type and "TI2V" not in args.model_type:
        pipe = WanImageToVideoPipeline(
            transformer=transformer,
            transformer_2=transformer_2,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            boundary_ratio=0.9,
        )
    elif "Wan2.2" in args.model_type and "T2V" in args.model_type:
        pipe = WanPipeline(
            transformer=transformer,
            transformer_2=transformer_2,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            boundary_ratio=0.875,
        )
    elif "Wan2.2" in args.model_type and "TI2V" in args.model_type:
        pipe = WanPipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            expand_timesteps=True,
        )
    elif "I2V" in args.model_type or "FLF2V" in args.model_type:
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
        )
        image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
        pipe = WanImageToVideoPipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            image_encoder=image_encoder,
            image_processor=image_processor,
        )
    elif "Wan2.2-VACE" in args.model_type:
        pipe = WanVACEPipeline(
            transformer=transformer,
            transformer_2=transformer_2,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            boundary_ratio=0.875,
        )
    elif "Wan-VACE" in args.model_type:
        pipe = WanVACEPipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
        )
    elif "Animate" in args.model_type:
        image_encoder = CLIPVisionModel.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
        )
        image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")

        pipe = WanAnimatePipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            image_encoder=image_encoder,
            image_processor=image_processor,
        )
    else:
        pipe = WanPipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
        )

    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

sudoku/generate_dataset.py
ADDED
@@ -0,0 +1,424 @@
"""
Sudoku Video Dataset Generator - Supports flexible solution count expressions per puzzle.
With checkpoint/resume support via metadata.json.
"""
import json
import re
import random
import argparse
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import List, Tuple, Optional, Union, Dict, Any
import numpy as np
import cv2
from tqdm import tqdm
from sudoku_processor import SudokuProcessor


# ==================== Solution Range ====================

@dataclass
class SolRange:
    """Flexible solution count constraint for puzzle generation."""
    min_sol: int
    max_sol: Optional[int]

    @classmethod
    def parse(cls, expr: str) -> "SolRange":
        expr = expr.strip()
        m = re.fullmatch(r'(\d+)\s*-\s*(\d+)', expr)
        if m:
            lo, hi = int(m.group(1)), int(m.group(2))
            if lo < 1: raise ValueError(f"min_sol must be >= 1, got {lo}")
            if hi < lo: raise ValueError(f"Invalid range: {lo}-{hi}")
            return cls(min_sol=lo, max_sol=hi)
        m = re.fullmatch(r'(>=|>|<=|<|==)\s*(\d+)', expr)
        if m:
            op, n = m.group(1), int(m.group(2))
            if op == '>=': return cls(min_sol=max(1, n), max_sol=None)
            elif op == '>': return cls(min_sol=max(1, n + 1), max_sol=None)
            elif op == '<=': return cls(min_sol=1, max_sol=n)
            elif op == '<': return cls(min_sol=1, max_sol=max(1, n - 1))
            elif op == '==': return cls(min_sol=n, max_sol=n)
        m = re.fullmatch(r'(\d+)', expr)
        if m:
            n = int(m.group(1))
            if n < 1: raise ValueError(f"sol_num must be >= 1, got {n}")
            return cls(min_sol=n, max_sol=n)
        raise ValueError(f"Invalid sol_num expression: '{expr}'")

    @property
    def is_exact(self): return self.max_sol is not None and self.min_sol == self.max_sol
    @property
    def is_unique_only(self): return self.is_exact and self.min_sol == 1
    @property
    def allows_unique(self): return self.min_sol <= 1
    @property
    def requires_multi(self): return self.min_sol > 1
    @property
    def effective_max(self): return self.max_sol if self.max_sol is not None else max(self.min_sol, 10)
    def accepts(self, count):
        if count < self.min_sol: return False
        if self.max_sol is not None and count > self.max_sol: return False
        return True
    def __repr__(self):
        if self.is_exact: return f"SolRange(=={self.min_sol})"
        if self.max_sol is None: return f"SolRange(>={self.min_sol})"
        return f"SolRange({self.min_sol}-{self.max_sol})"


# ==================== Checkpoint Management ====================

@dataclass
class GenerationState:
    """Tracks generation progress for checkpoint/resume."""
    params_hash: str
    clue_progress: Dict[int, int]  # clue_level -> generated_count
    seen_grids: List[str]
    all_samples: List[Dict]
    completed: bool = False

    def to_dict(self) -> Dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict) -> "GenerationState":
        return cls(**d)


def compute_params_hash(params: Dict) -> str:
    """Compute hash of generation parameters for consistency check."""
    import hashlib
    # Only hash parameters that affect generation logic
    key_params = {k: v for k, v in params.items()
                  if k not in ['output_dir']}  # output_dir can differ
    return hashlib.md5(json.dumps(key_params, sort_keys=True).encode()).hexdigest()[:12]
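
# Hash sketch: output_dir is excluded, so two runs that differ only in output_dir share
# a digest and can reuse each other's checkpoints.
#   compute_params_hash({"seed": 42, "output_dir": "a"}) \
#       == compute_params_hash({"seed": 42, "output_dir": "b"})  # True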


def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]:
    """Load checkpoint if exists and params match."""
    meta_path = output_dir / "metadata.json"
    if not meta_path.exists():
        return None

    with open(meta_path) as f:
        data = json.load(f)

    state = GenerationState.from_dict(data["state"])
    expected_hash = compute_params_hash(params)

    if state.params_hash != expected_hash:
        print(f"⚠️ Parameters changed (hash {state.params_hash} → {expected_hash}), starting fresh")
        return None

    if state.completed:
        print("✓ Generation already completed")
        return state

    print(f"✓ Resuming from checkpoint: {sum(state.clue_progress.values())} puzzles generated")
    return state


def save_checkpoint(output_dir: Path, state: GenerationState, params: Dict):
    """Save current generation state to metadata.json."""
    meta_path = output_dir / "metadata.json"
    data = {
        "params": params,
        "state": state.to_dict()
    }
    # Atomic write
    tmp_path = meta_path.with_suffix('.tmp')
    with open(tmp_path, 'w') as f:
        json.dump(data, f, indent=2)
    tmp_path.rename(meta_path)
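
# The write-then-rename above is atomic on POSIX filesystems, so an interrupted save
# leaves either the previous metadata.json or the new one, never a truncated file.
# (On Windows, Path.rename raises if the target exists; Path.replace would be needed there.)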


# ==================== Core Functions ====================

def get_fill_order(puzzle, solution):
    return [(i, j, solution[i][j]) for i in range(9) for j in range(9) if puzzle[i][j] == 0]

def create_processor(resolution=None):
    if resolution is None: return SudokuProcessor()
    target_size = min(resolution)
    cell_size = target_size // 9
    sf = cell_size / 60
    return SudokuProcessor(cell_size=cell_size, font_scale=1.2*sf, thickness=max(1, int(2*sf)))

def generate_video_frames(proc, puzzle, solution, n_start, m_end, k=1, max_frames=None):
    fills = get_fill_order(puzzle, solution)
    n_fills = len(fills)
    effective_k = k
    if max_frames is not None and n_start + n_fills * k + m_end > max_frames:
        avail = max_frames - n_start - m_end
        effective_k = max(1, avail // n_fills) if avail > 0 and n_fills > 0 else 1

    frames = []
    current = [row[:] for row in puzzle]
    img = proc.render(current)
    frames.extend([img.copy() for _ in range(n_start)])

    for r, c, v in fills:
        current[r][c] = v
        frames.append(proc.render(current, highlight_new=(r, c), original=puzzle))
        if effective_k > 1:
            img = proc.render(current, original=puzzle)
            frames.extend([img.copy() for _ in range(effective_k - 1)])

    img = proc.render(solution, original=puzzle)
    frames.extend([img.copy() for _ in range(m_end)])
    if max_frames is not None and len(frames) > max_frames:
        frames = frames[:max_frames]
    return frames
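
# Frame-budget sketch: len(frames) = n_start + n_fills * k + m_end before truncation.
# E.g. 51 empty cells with n_start=10, k=1, m_end=10 gives 71 frames; with
# max_frames=49, effective_k stays 1 and the tail is simply cut to 49 frames.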

def save_video(frames, path, fps=10):
    h, w = frames[0].shape[:2]
    writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    for f in frames: writer.write(cv2.cvtColor(f, cv2.COLOR_RGB2BGR))
    writer.release()

def normalize_num_per_clue(num_per_clue, clue_levels):
    if isinstance(num_per_clue, int): return [num_per_clue] * len(clue_levels)
    if len(num_per_clue) != len(clue_levels):
        raise ValueError(f"num_per_clue length ({len(num_per_clue)}) != clue_levels ({len(clue_levels)})")
    return num_per_clue
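
# E.g. normalize_num_per_clue(50, [30, 40]) -> [50, 50], while a per-level list such as
# normalize_num_per_clue([10, 20], [30, 40]) -> [10, 20] is passed through unchanged.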


# ==================== Puzzle Generation with SolRange ====================

def generate_puzzle_with_range(proc, clue, sol_range, min_hamming):
    """Generate one puzzle respecting sol_range. Returns (puzzle, solutions) or None."""
    if sol_range.is_unique_only:
        puzzle, solution = proc.generate(clue, unique=True)
        return puzzle, [solution]

    if sol_range.requires_multi:
        try:
            puzzle, solutions = proc.generate_multi_solution(
                clue, min_solutions=sol_range.min_sol,
                max_solutions=sol_range.effective_max,
                max_attempts=1, min_hamming=min_hamming
            )
            if sol_range.accepts(len(solutions)):
                return puzzle, solutions
        except RuntimeError:
            pass
        return None

    try:
        puzzle, solutions = proc.generate_multi_solution(
            clue, min_solutions=max(2, sol_range.min_sol),
            max_solutions=sol_range.effective_max,
            max_attempts=1, min_hamming=min_hamming
        )
        if sol_range.accepts(len(solutions)):
            return puzzle, solutions
    except RuntimeError:
        pass

    if sol_range.allows_unique:
        puzzle, solution = proc.generate(clue, unique=True)
        return puzzle, [solution]
    return None
|
| 223 |
+
|
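# Behaviour sketch (illustrative; SolRange is assumed to be defined earlier in
# this file and is not shown in this hunk tail): with --sol-num "<=3" the range
# accepts 1-3 solutions, so each call first attempts a multi-solution puzzle
# (2-3 diverse completions) and, if that single attempt fails, falls back to a
# unique-solution puzzle. With a multi-only range such as ">=2" the fallback
# branch is never reached and a failed attempt simply returns None, letting the
# caller retry.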

# ==================== Dataset Generation ====================

def generate_dataset(
    output_dir="sudoku_video", clue_levels=[30,40,50,60], num_per_clue=50,
    sol_num="1", min_hamming=10, train_ratio=0.8,
    prompt="Solve this Sudoku puzzle using red font.",
    n_start=10, m_end=10, k=1, max_frames=None, fps=10,
    resolution=None, seed=42, checkpoint_interval=50
):
    """
    Generate Sudoku video dataset with checkpoint/resume support.

    Args:
        checkpoint_interval: Save checkpoint every N puzzles (default: 50)
    """
    # Prepare params dict for hashing
    params = {
        "clue_levels": clue_levels, "num_per_clue": num_per_clue,
        "sol_num": sol_num, "min_hamming": min_hamming, "train_ratio": train_ratio,
        "prompt": prompt, "n_start": n_start, "m_end": m_end, "k": k,
        "max_frames": max_frames, "fps": fps, "resolution": resolution, "seed": seed
    }

    output_dir = Path(output_dir)
    video_dir = output_dir / "videos"
    image_dir = output_dir / "images"
    video_dir.mkdir(parents=True, exist_ok=True)
    image_dir.mkdir(parents=True, exist_ok=True)

    # Try to resume from checkpoint
    state = load_checkpoint(output_dir, params)

    if state and state.completed:
        return  # Already done

    sol_range = SolRange.parse(str(sol_num))
    proc = create_processor(resolution)
    actual_size = proc.img_size
    num_per_clue_list = normalize_num_per_clue(num_per_clue, clue_levels)
    max_puzzles = max(num_per_clue_list)
    num_width = len(str(max_puzzles))

    # Initialize or restore state
    if state is None:
        random.seed(seed)
        state = GenerationState(
            params_hash=compute_params_hash(params),
            clue_progress={clue: 0 for clue in clue_levels},
            seen_grids=[],
            all_samples=[]
        )
        print(f"Starting fresh generation with solution range: {sol_range}")
    else:
        # Restore RNG state approximately by fast-forwarding
        random.seed(seed)
        for _ in range(sum(state.clue_progress.values()) * 10):
            random.random()

    seen_grids = set(state.seen_grids)
    all_samples = state.all_samples.copy()
    # JSON round-trips dict keys as strings, so cast clue levels back to int
    clue_progress = {int(k): v for k, v in state.clue_progress.items()}

    total_target = sum(num_per_clue_list)
    total_done = sum(clue_progress.values())
    stats_unique = sum(1 for s in all_samples if s["total_solutions"] == 1 and s["sol_idx"] == 0)
    stats_multi = sum(1 for s in all_samples if s["total_solutions"] > 1 and s["sol_idx"] == 0)
    puzzles_since_checkpoint = 0

    with tqdm(total=total_target, initial=total_done, desc="Total", unit="puzzle") as pbar_total:
        for clue, target_count in zip(clue_levels, num_per_clue_list):
            generated = clue_progress.get(clue, 0)
            if generated >= target_count:
                continue  # This clue level is done

            max_attempts = (target_count - generated) * 20

            with tqdm(total=target_count, initial=generated, desc=f"Clue {clue:2d}",
                      unit="puzzle", leave=False) as pbar_clue:
                for _ in range(max_attempts):
                    if generated >= target_count:
                        break

                    result = generate_puzzle_with_range(proc, clue, sol_range, min_hamming)
                    if result is None:
                        continue
                    puzzle, solutions = result

                    fp = proc.encode(puzzle)
                    if fp in seen_grids:
                        continue
                    seen_grids.add(fp)

                    n_sols = len(solutions)
                    if n_sols == 1:
                        stats_unique += 1
                    else:
                        stats_multi += 1

                    img_name = f"clue{clue}_{generated:0{num_width}d}.png"
                    puzzle_img = proc.render(puzzle)
                    cv2.imwrite(str(image_dir / img_name), cv2.cvtColor(puzzle_img, cv2.COLOR_RGB2BGR))

                    for si, sol in enumerate(solutions):
                        vid_name = f"clue{clue}_{generated:0{num_width}d}_sol{si}.mp4"
                        frames = generate_video_frames(proc, puzzle, sol, n_start, m_end, k, max_frames)
                        save_video(frames, video_dir / vid_name, fps)

                        hdists = [proc._hamming(sol, solutions[j]) for j in range(n_sols) if j != si]
                        all_samples.append({
                            "prompt": prompt, "video": vid_name, "image": img_name,
                            "clue": clue, "puzzle": fp, "solution": proc.encode(sol),
                            "sol_idx": si, "total_solutions": n_sols,
                            "frame_count": len(frames),
                            "min_hamming_to_others": min(hdists) if hdists else 0
                        })

                    generated += 1
                    clue_progress[clue] = generated
                    puzzles_since_checkpoint += 1
                    pbar_clue.update(1)
                    pbar_total.update(1)

                    # Periodic checkpoint
                    if puzzles_since_checkpoint >= checkpoint_interval:
                        state.clue_progress = clue_progress
                        state.seen_grids = list(seen_grids)
                        state.all_samples = all_samples
                        save_checkpoint(output_dir, state, params)
                        puzzles_since_checkpoint = 0

            tqdm.write(f"Clue {clue}: {generated} puzzles, "
                       f"{sum(1 for s in all_samples if s['clue'] == clue)} videos")

    # Final output
    random.seed(seed + 1)  # Deterministic shuffle
    random.shuffle(all_samples)
    split_idx = int(len(all_samples) * train_ratio)

    def write_jsonl(samples, path):
        with open(path, 'w') as f:
            for s in samples:
                json.dump(s, f)
                f.write('\n')

    write_jsonl(all_samples[:split_idx], output_dir / "train.jsonl")
    write_jsonl(all_samples[split_idx:], output_dir / "test.jsonl")

    # Mark as completed
    state.clue_progress = clue_progress
    state.seen_grids = list(seen_grids)
    state.all_samples = all_samples
    state.completed = True
    save_checkpoint(output_dir, state, params)

    print(f"\n✓ Dataset complete: {output_dir}/")
    print(f"  Resolution: {actual_size}x{actual_size}")
    print(f"  Solution range: {sol_range}")
    print(f"  Puzzles: {len(seen_grids)} ({stats_unique} unique, {stats_multi} multi-sol)")
    print(f"  Videos: {len(all_samples)}")
    print(f"  Train: {split_idx}, Test: {len(all_samples) - split_idx}")

    hammings = [s["min_hamming_to_others"] for s in all_samples if s["min_hamming_to_others"] > 0]
    if hammings:
        print(f"  Solution diversity: avg={np.mean(hammings):.1f}, min={min(hammings)}, max={max(hammings)}")


def parse_resolution(s):
    w, h = map(int, s.lower().split('x'))
    return (w, h)

def parse_args():
    p = argparse.ArgumentParser(description="Generate Sudoku video dataset with resume support")
    p.add_argument("--output-dir", type=str, default="sudoku")
    p.add_argument("--clue-levels", type=int, nargs="+", default=[20,30,40,50,60,70])
    p.add_argument("--num-per-clue", type=int, nargs="+", default=[15000,10000,10000,5000,2000,1000])
    p.add_argument("--sol-num", type=str, default="<=3",
                   help="'1', '3', '>=1', '>1', '<=3', '<3', '2-5'")
    p.add_argument("--min-hamming", type=int, default=10)
    p.add_argument("--train-ratio", type=float, default=0.9)
    p.add_argument("--prompt", type=str, default="Solve this Sudoku puzzle using red font.")
    p.add_argument("--n-start", type=int, default=2)
    p.add_argument("--m-end", type=int, default=3)
    p.add_argument("--k", type=int, default=1)
    p.add_argument("--max-frames", type=int, default=None)
    p.add_argument("--fps", type=int, default=10)
    p.add_argument("--resolution", type=str, default="1024x1024")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--checkpoint-interval", type=int, default=50,
                   help="Save checkpoint every N puzzles (default: 50)")
    return p.parse_args()


if __name__ == "__main__":
    args = parse_args()
    kwargs = vars(args)
    if isinstance(kwargs["num_per_clue"], list) and len(kwargs["num_per_clue"]) == 1:
        kwargs["num_per_clue"] = kwargs["num_per_clue"][0]
    if kwargs["resolution"]:
        kwargs["resolution"] = parse_resolution(kwargs["resolution"])
    generate_dataset(**kwargs)
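# Example invocation (illustrative; the arguments are arbitrary):
#   python generate_dataset.py --output-dir sudoku --clue-levels 30 40 \
#       --num-per-clue 100 --sol-num "<=3" --resolution 512x512
# Re-running the identical command resumes from <output-dir>/metadata.json;
# changing any parameter changes the params hash and restarts from scratch.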
sudoku/jsonl_to_csv.py
ADDED
@@ -0,0 +1,22 @@
import json
import csv
from pathlib import Path

dataset = 'sudoku'
split = 'train'

# Load the selected split's full-metadata records
with open(f'{dataset}/{split}_info.jsonl', 'r') as f:
    data = [json.loads(line) for line in f]

# Write to CSV
with open(f'{dataset}/{split}.csv', 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f)
    writer.writerow(['input_image', 'video', 'prompt'])

    for idx, item in enumerate(data):
        writer.writerow([
            'images/' + item['image'],
            'videos/' + item['video'],
            item['prompt'],
        ])
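# Resulting CSV shape (header is exact; the row is illustrative, since file
# names depend on the run):
#   input_image,video,prompt
#   images/clue30_00017.png,videos/clue30_00017_sol0.mp4,Solve this Sudoku puzzle using red font.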
sudoku/simplify_dataset.py
ADDED
@@ -0,0 +1,19 @@
import json

dataset = 'sudoku_600'
split = 'test'

# Read original data
with open(f'{dataset}/{split}.jsonl', 'r') as f:
    data = [json.loads(line) for line in f]

# Transform to simplified format
new_data = [{'prompt': d['prompt'], 'image': d['image']} for d in data]

# Save simplified data to {split}.jsonl (overwrites the original file)
with open(f'{dataset}/{split}.jsonl', 'w') as f:
    f.writelines(json.dumps(item) + '\n' for item in new_data)

# Save original data to {split}_info.jsonl
with open(f'{dataset}/{split}_info.jsonl', 'w') as f:
    f.writelines(json.dumps(item) + '\n' for item in data)
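# Effect (illustrative): a full record with prompt, video, image, clue, the
# 81-digit puzzle/solution strings, etc. is reduced in test.jsonl to just
# {"prompt": ..., "image": ...}, while the untouched record is preserved in
# test_info.jsonl (which jsonl_to_csv.py above then consumes). Note the script
# is not idempotent: a second run would overwrite test_info.jsonl with the
# already-simplified records.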
sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00001-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3a287551bf47373c4e66324c27a84ce2daa48c89acdde4eb8d89178d8ad09da9
size 4992484608

sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00002-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:35a65b38950cf1b3d01460bec6b03e5efdc66854a678f5089a668a2664a91f4c
size 4898551584

sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00003-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:498c9dc1c9d6edabf5514ff30440c752507c07afba637ea0793b611dcd4fd4ea
size 4987667104

sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00004-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e33036a5d81e8a72a5ac48657af77bcd751a7dce06e6e883643f1a7e377e69c6
size 4987711216

sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00005-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9f9f17087d5bd04bb83dfe3958bf03a54def4b6029f9677b19b9c98484a9b742
size 4950959936

sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00006-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:29381ca0d3790260c577b3849efd7ecb87d1550dfc2ca52ae1ab7052c615bf32
size 4950980632

sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0-diffusers/diffusion_pytorch_model-00007-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2214191998778c951068e0b15363ddab49e089270b58ef62935aa4a523a5358a
size 3021537400

sudoku/sudoku/checkpoints/Wan2.1-I2V-14B-720P_full_0206/epoch-0.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:504d9c47c4e559e8df2d2b93a69e094ac2b9d1e65f8124849c7e719b47380952
size 32789894056

sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-0.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ae539b75e0dc5a5809b3b87dd5b02abc1fa9cda2e79b0d9ca63d5954795eaeec
size 9999659704

sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-1.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b303e8e24c73ca768be637d2e787f97b95d8c8ae36ae56e8986842549bd69a0b
size 9999659704

sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-2.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0c1e05b6a600d53a11306fddc78fce8660cc13716ad330f7c51adb8258118673
size 9999659704

sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-3.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b5eae0f23c6c4786bcebd2289bc7e06a912f726c94d2e77f20d0b543630751ca
size 9999659704

sudoku/sudoku/checkpoints/Wan2.2-TI2V-5B_full/epoch-4.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4d8ab68aeec9fb7b79dac05b95eb4d2d5b9e3cbe20eaeef37adbca4c6c5565fd
size 9999659704

sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31/epoch-3.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2f2d8a61d1723198c9cc194edb7e562e5aa84e9e7dd358dcffa492d98b5c8c1c
size 32789894056

sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00001-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e869c2d79321f9c73a4443150d45b2b7455a92c9cb3a3fb1321d698a882fbf3e
size 4992484608

sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00002-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1c24ed79955530349d2f53c16cd038e17281f3728f129b950088812f9d4393a5
size 4898551584

sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00003-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8721792f07a2411d5d07639bac2a52585aeb20727de09a835b0f6251385a4ee7
size 4987667104

sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00004-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:342d017549cd13b866272c0c7ae43f9464ce07936a0c77e3f6c6c36b8a6c8aeb
size 4987711216

sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00005-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1389e9cf029b963914a7c725fb6ab5ced0576b364865e7f9cbb84d4f803b9b74
size 4950959936

sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00006-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:be51484e98b724b87e673cc0eeee8aeac66392e838f57879d5fdc34aef9d4022
size 4950980632

sudoku/sudoku_600/checkpoints/Wan2.1-I2V-14B-720P_full_1_31_diffusers/diffusion_pytorch_model-00007-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ffbdf62dff812b47dbf7e4a96e950bcc396a73922301310b9e84bc008d54e34e
size 3021537400
sudoku/sudoku_processor.py
ADDED
@@ -0,0 +1,479 @@
"""
SudokuProcessor - Sudoku puzzle generation, solving, and rendering using a SAT solver.
Supports efficient diverse multi-solution generation.
"""
import random
from typing import List, Tuple, Optional
import numpy as np
import cv2

try:
    from pysat.solvers import Solver
    HAS_PYSAT = True
except ImportError:
    HAS_PYSAT = False
    print("Warning: pysat not found, install with: pip install python-sat")


class SudokuProcessor:
    """Handles Sudoku puzzle generation, solving, and image rendering."""

    def __init__(self, cell_size: int = 60, font_scale: float = 1.2, thickness: int = 2):
        self.cell_size = cell_size
        self.font_scale = font_scale
        self.thickness = thickness
        self.img_size = cell_size * 9

        # Colors (RGB)
        self.bg_color = (255, 255, 255)
        self.line_color = (0, 0, 0)
        self.original_color = (0, 0, 0)
        self.filled_color = (200, 0, 0)
        self.highlight_color = (255, 255, 200)

        self._base_clauses_cache = None

    # ==================== SAT Encoding ====================

    def _var(self, r: int, c: int, n: int) -> int:
        """Map (row, col, num) to SAT variable (1-indexed)."""
        return r * 81 + c * 9 + n + 1

    def _decode_var(self, v: int) -> Tuple[int, int, int]:
        v -= 1
        return v // 81, (v % 81) // 9, v % 9

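    # Worked example of the variable encoding (illustrative): cell (r=2, c=3)
    # holding digit 5 (n=4) maps to 2*81 + 3*9 + 4 + 1 = 194, and
    # _decode_var(194) recovers (2, 3, 4). Variables therefore range over
    # 1..729 (9 rows x 9 cols x 9 digits).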
    def _base_clauses(self) -> List[List[int]]:
        """Generate base Sudoku constraint clauses (cached)."""
        if self._base_clauses_cache is not None:
            return self._base_clauses_cache

        clauses = []
        # Each cell holds at least one digit, and at most one digit
        for i in range(9):
            for j in range(9):
                clauses.append([self._var(i, j, n) for n in range(9)])
                for n1 in range(9):
                    for n2 in range(n1 + 1, 9):
                        clauses.append([-self._var(i, j, n1), -self._var(i, j, n2)])

        # Each digit appears exactly once per row, column, and 3x3 box
        for n in range(9):
            for i in range(9):
                clauses.append([self._var(i, j, n) for j in range(9)])
                for j1 in range(9):
                    for j2 in range(j1 + 1, 9):
                        clauses.append([-self._var(i, j1, n), -self._var(i, j2, n)])
                clauses.append([self._var(j, i, n) for j in range(9)])
                for j1 in range(9):
                    for j2 in range(j1 + 1, 9):
                        clauses.append([-self._var(j1, i, n), -self._var(j2, i, n)])
            for br in range(3):
                for bc in range(3):
                    box = [self._var(br*3+di, bc*3+dj, n) for di in range(3) for dj in range(3)]
                    clauses.append(box)
                    for i1 in range(9):
                        for i2 in range(i1 + 1, 9):
                            clauses.append([-box[i1], -box[i2]])

        self._base_clauses_cache = clauses
        return clauses

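    # Clause-count arithmetic (illustrative): each exactly-one group over 9
    # literals costs 1 at-least-one clause + C(9,2) = 36 at-most-one clauses,
    # i.e. 37 clauses. There are 81 cells, 81 (digit, row) pairs,
    # 81 (digit, column) pairs, and 81 (digit, box) pairs, so the base
    # encoding has 4 * 81 * 37 = 11988 clauses over 729 variables.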
    def _grid_clauses(self, grid: List[List[int]]) -> List[List[int]]:
        return [[self._var(i, j, grid[i][j] - 1)]
                for i in range(9) for j in range(9) if grid[i][j] != 0]

    def _model_to_grid(self, model: List[int]) -> List[List[int]]:
        grid = [[0] * 9 for _ in range(9)]
        for v in model:
            if 0 < v <= 729:
                r, c, n = self._decode_var(v)
                grid[r][c] = n + 1
        return grid

    # ==================== Solving ====================

    def solve(self, grid: List[List[int]]) -> Optional[List[List[int]]]:
        if HAS_PYSAT:
            with Solver(name='g3') as s:
                for c in self._base_clauses() + self._grid_clauses(grid):
                    s.add_clause(c)
                return self._model_to_grid(s.get_model()) if s.solve() else None
        return self._solve_backtrack(grid)

    def _solve_backtrack(self, grid: List[List[int]]) -> Optional[List[List[int]]]:
        board = [row[:] for row in grid]
        return board if self._backtrack(board) else None

    def _backtrack(self, board: List[List[int]]) -> bool:
        empty = self._find_empty(board)
        if not empty:
            return True
        r, c = empty
        for num in range(1, 10):
            if self._is_valid(board, r, c, num):
                board[r][c] = num
                if self._backtrack(board):
                    return True
                board[r][c] = 0
        return False

    def _find_empty(self, board: List[List[int]]) -> Optional[Tuple[int, int]]:
        for i in range(9):
            for j in range(9):
                if board[i][j] == 0:
                    return (i, j)
        return None

    def _is_valid(self, board: List[List[int]], row: int, col: int, num: int) -> bool:
        if num in board[row]:
            return False
        if any(board[i][col] == num for i in range(9)):
            return False
        br, bc = 3 * (row // 3), 3 * (col // 3)
        return all(board[i][j] != num for i in range(br, br+3) for j in range(bc, bc+3))

    def count_solutions(self, grid: List[List[int]], limit: int = 2) -> int:
        if HAS_PYSAT:
            count = 0
            with Solver(name='g3') as s:
                for c in self._base_clauses() + self._grid_clauses(grid):
                    s.add_clause(c)
                while count < limit and s.solve():
                    count += 1
                    s.add_clause([-v for v in s.get_model() if 0 < v <= 729])
            return count
        return self._count_backtrack(grid, limit)

    def _count_backtrack(self, grid: List[List[int]], limit: int) -> int:
        board = [row[:] for row in grid]
        self._sol_count, self._sol_limit = 0, limit
        self._count_helper(board)
        return self._sol_count

    def _count_helper(self, board: List[List[int]]) -> bool:
        if self._sol_count >= self._sol_limit:
            return True
        empty = self._find_empty(board)
        if not empty:
            self._sol_count += 1
            return self._sol_count >= self._sol_limit
        r, c = empty
        for num in range(1, 10):
            if self._is_valid(board, r, c, num):
                board[r][c] = num
                if self._count_helper(board):
                    return True
                board[r][c] = 0
        return False

    def find_solutions(self, grid: List[List[int]], limit: int = 10) -> List[List[List[int]]]:
        if HAS_PYSAT:
            solutions = []
            with Solver(name='g3') as s:
                for c in self._base_clauses() + self._grid_clauses(grid):
                    s.add_clause(c)
                while len(solutions) < limit and s.solve():
                    model = s.get_model()
                    solutions.append(self._model_to_grid(model))
                    s.add_clause([-v for v in model if 0 < v <= 729])
            return solutions
        return self._find_backtrack(grid, limit)

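    # Note on the enumeration loops above: after each SAT call the clause
    # [-v for v in model if 0 < v <= 729] bans exactly that assignment of the
    # 729 decision variables (a standard "blocking clause"), so the next
    # solve() is forced to produce a different completed grid. Illustrative
    # trace: solve -> grid A, block A, solve -> grid B, block B, ... until
    # UNSAT or the limit is reached.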
    def _find_backtrack(self, grid: List[List[int]], limit: int) -> List[List[List[int]]]:
        board, solutions = [row[:] for row in grid], []
        self._find_helper(board, solutions, limit)
        return solutions

    def _find_helper(self, board: List[List[int]], solutions: List, limit: int) -> bool:
        if len(solutions) >= limit:
            return True
        empty = self._find_empty(board)
        if not empty:
            solutions.append([row[:] for row in board])
            return len(solutions) >= limit
        r, c = empty
        for num in range(1, 10):
            if self._is_valid(board, r, c, num):
                board[r][c] = num
                if self._find_helper(board, solutions, limit):
                    return True
                board[r][c] = 0
        return False

    # ==================== Generation ====================

    def generate(self, clues: int = 30, unique: bool = True) -> Tuple[List[List[int]], List[List[int]]]:
        """Generate a Sudoku puzzle with specified number of clues."""
        solution = self._generate_full_grid()
        puzzle = [row[:] for row in solution]

        cells = [(i, j) for i in range(9) for j in range(9)]
        random.shuffle(cells)

        removed, target = 0, 81 - clues
        for r, c in cells:
            if removed >= target:
                break
            backup = puzzle[r][c]
            puzzle[r][c] = 0
            if unique and self.count_solutions(puzzle, 2) != 1:
                puzzle[r][c] = backup
            else:
                removed += 1

        return puzzle, solution

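    # Caveat on generate() (inferred from the loop above): removals that break
    # uniqueness are rolled back, so if every remaining cell is essential the
    # loop exhausts all 81 candidates and the returned puzzle may contain more
    # than `clues` givens. For very low clue counts (e.g. clues=20, below the
    # known minimum of 17 only in spirit) this is a common outcome rather than
    # an exception.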
    def _generate_full_grid(self) -> List[List[int]]:
        if HAS_PYSAT:
            with Solver(name='g3') as s:
                for c in self._base_clauses():
                    s.add_clause(c)
                # Randomize the solution by pinning 11 random cells via assumptions
                cells = [(i, j) for i in range(9) for j in range(9)]
                random.shuffle(cells)
                assumptions = []
                for r, c in cells[:11]:
                    nums = list(range(9))
                    random.shuffle(nums)
                    for n in nums:
                        if s.solve(assumptions=assumptions + [self._var(r, c, n)]):
                            assumptions.append(self._var(r, c, n))
                            break
                s.solve(assumptions=assumptions)
                return self._model_to_grid(s.get_model())

        board = [[0] * 9 for _ in range(9)]
        self._fill_grid(board)
        return board

    def _fill_grid(self, board: List[List[int]]) -> bool:
        empty = self._find_empty(board)
        if not empty:
            return True
        r, c = empty
        nums = list(range(1, 10))
        random.shuffle(nums)
        for num in nums:
            if self._is_valid(board, r, c, num):
                board[r][c] = num
                if self._fill_grid(board):
                    return True
                board[r][c] = 0
        return False

    # ==================== Diverse Multi-Solution Generation ====================

    @staticmethod
    def _hamming(sol1: List[List[int]], sol2: List[List[int]]) -> int:
        """Count differing cells between two complete grids."""
        return sum(sol1[i][j] != sol2[i][j] for i in range(9) for j in range(9))

    @staticmethod
    def _greedy_diverse_select(
        candidates: List[List[List[int]]],
        target_count: int,
        min_hamming: int,
        _hamming_fn=None,
    ) -> List[List[List[int]]]:
        """
        Greedily select diverse solutions using farthest-point sampling.

        1. Start with a random candidate.
        2. Repeatedly add the candidate with maximum min-distance to the selected set.
        3. Stop when enough are selected or no candidate meets min_hamming.
        """
        if _hamming_fn is None:
            _hamming_fn = SudokuProcessor._hamming

        if len(candidates) <= 1:
            return list(candidates)

        n = len(candidates)

        # Pre-compute pairwise distances
        dist = [[0] * n for _ in range(n)]
        for i in range(n):
            for j in range(i + 1, n):
                d = _hamming_fn(candidates[i], candidates[j])
                dist[i][j] = d
                dist[j][i] = d

        # Farthest-point sampling
        selected = [random.randint(0, n - 1)]
        remaining = set(range(n)) - {selected[0]}

        while len(selected) < target_count and remaining:
            best_idx = -1
            best_min_dist = -1

            for r in remaining:
                min_d = min(dist[r][s] for s in selected)
                if min_d > best_min_dist:
                    best_min_dist = min_d
                    best_idx = r

            if best_min_dist < min_hamming:
                break

            selected.append(best_idx)
            remaining.discard(best_idx)

        return [candidates[i] for i in selected]

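    # Tiny numeric example (illustrative): candidates A, B, C with pairwise
    # distances d(A,B)=4, d(A,C)=12, d(B,C)=11; start={A}, min_hamming=10.
    # Round 1 picks C (its min-distance 12 beats B's 4); round 2 evaluates B
    # with min(d(B,A), d(B,C)) = 4 < 10, so selection stops at {A, C} even if
    # target_count is 3.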
    def generate_multi_solution(
        self,
        clues: int,
        min_solutions: int = 2,
        max_solutions: int = 5,
        max_attempts: int = 100,
        min_hamming: int = 10
    ) -> Tuple[List[List[int]], List[List[List[int]]]]:
        """
        Generate a puzzle with multiple diverse solutions.

        Puzzle-first strategy:
        1. Generate a full grid, randomly remove (81 - clues) cells WITHOUT a
           uniqueness check → guaranteed to have ≥1 solution, likely many.
        2. Enumerate candidate solutions of this puzzle via SAT.
        3. Greedily select diverse solutions (farthest-point sampling).
        4. If not enough diverse solutions, retry with a new puzzle.

        This is correct because all returned solutions are guaranteed valid
        completions of the returned puzzle.

        Args:
            clues: Number of given cells.
            min_solutions: Minimum diverse solutions required.
            max_solutions: Maximum to return.
            max_attempts: Outer retry budget.
            min_hamming: Minimum pairwise Hamming distance.
        Returns:
            (puzzle, solutions) — all solutions are valid and pairwise diverse.
        Raises:
            RuntimeError: If unable to find a qualifying puzzle.
        """
        # Adaptive hamming threshold
        adaptive_hamming = min_hamming
        if min_hamming == 10:  # default → auto-adapt
            if clues >= 55:
                adaptive_hamming = 3
            elif clues >= 45:
                adaptive_hamming = 5
            elif clues >= 35:
                adaptive_hamming = 8
            else:
                adaptive_hamming = 12

        # Adaptive search depth: more empty cells → more solutions likely exist
        empty_cells = 81 - clues
        if empty_cells <= 15:
            max_search = 30
        elif empty_cells <= 25:
            max_search = 80
        elif empty_cells <= 40:
            max_search = 150
        else:
            max_search = 300

        for _ in range(max_attempts):
            # Phase 1: Generate puzzle (random removal, no uniqueness check)
            solution = self._generate_full_grid()
            puzzle = [row[:] for row in solution]

            cells = [(i, j) for i in range(9) for j in range(9)]
            random.shuffle(cells)
            for r, c in cells[:81 - clues]:
                puzzle[r][c] = 0

            # Phase 2 (disabled): quick feasibility check for min_solutions solutions
            # quick_count = self.count_solutions(puzzle, min_solutions + 1)
            # if quick_count < min_solutions:
            #     continue

            # Phase 3: Enumerate candidates
            candidates = self.find_solutions(puzzle, max_search)
            if len(candidates) < min_solutions:
                continue

            # Phase 4: Greedy diverse selection
            diverse = self._greedy_diverse_select(
                candidates, max_solutions, adaptive_hamming
            )

            if len(diverse) >= min_solutions:
                return puzzle, diverse[:max_solutions]

        raise RuntimeError(
            f"Failed to generate puzzle with {min_solutions}-{max_solutions} "
            f"diverse solutions (hamming>={adaptive_hamming}) after {max_attempts} attempts"
        )

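    # Usage sketch (illustrative):
    #   proc = SudokuProcessor()
    #   puzzle, sols = proc.generate_multi_solution(clues=40, min_solutions=2,
    #                                               max_solutions=3)
    #   # Every grid in `sols` is a valid completion of `puzzle`, and any two
    #   # differ in >= 8 cells (the adaptive threshold for 35 <= clues < 45).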
    # ==================== Encoding ====================

    def encode(self, grid: List[List[int]]) -> str:
        return ''.join(str(grid[i][j]) for i in range(9) for j in range(9))

    def decode(self, s: str) -> List[List[int]]:
        return [[int(s[i * 9 + j]) for j in range(9)] for i in range(9)]

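    # Round-trip example (illustrative): encode() flattens a grid row-major
    # into an 81-character digit string with '0' for empty cells, e.g. a
    # puzzle whose first row is [5, 3, 0, 0, 7, 0, 0, 0, 0] starts "530070000";
    # decode(encode(grid)) == grid for any 9x9 grid of digits 0-9.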
    # ==================== Rendering ====================

    def render(
        self,
        grid: List[List[int]],
        highlight_new: Optional[Tuple[int, int]] = None,
        original: Optional[List[List[int]]] = None
    ) -> np.ndarray:
        img = np.full((self.img_size, self.img_size, 3), self.bg_color, dtype=np.uint8)
        cs = self.cell_size

        if highlight_new:
            r, c = highlight_new
            cv2.rectangle(img, (c * cs, r * cs), ((c+1) * cs, (r+1) * cs), self.highlight_color, -1)

        for i in range(10):
            thick = 3 if i % 3 == 0 else 1
            pos = i * cs
            cv2.line(img, (pos, 0), (pos, self.img_size), self.line_color, thick)
            cv2.line(img, (0, pos), (self.img_size, pos), self.line_color, thick)

        font = cv2.FONT_HERSHEY_SIMPLEX
        for i in range(9):
            for j in range(9):
                if grid[i][j] == 0:
                    continue
                is_original = original is None or original[i][j] != 0
                color = self.original_color if is_original else self.filled_color
                text = str(grid[i][j])
                (tw, th), _ = cv2.getTextSize(text, font, self.font_scale, self.thickness)
                cv2.putText(img, text, (j*cs + (cs-tw)//2, i*cs + (cs+th)//2),
                            font, self.font_scale, color, self.thickness)

        return img


if __name__ == "__main__":
    proc = SudokuProcessor()
    print(f"Using {'SAT solver' if HAS_PYSAT else 'backtracking'}...")

    # Test unique puzzle
    puzzle, solution = proc.generate(clues=25, unique=True)
    print("Puzzle:")
    for row in puzzle:
        print(row)
    print(f"Clues: {sum(c != 0 for row in puzzle for c in row)}")

    cv2.imwrite("test_puzzle.png", cv2.cvtColor(proc.render(puzzle), cv2.COLOR_RGB2BGR))
    cv2.imwrite("test_solution.png", cv2.cvtColor(proc.render(solution, original=puzzle), cv2.COLOR_RGB2BGR))
    print("Saved test images.")

    # Test diverse multi-solution at various clue levels
    print("\n=== Testing diverse multi-solution generation ===")
    for clues in [25, 35, 45, 55]:
        print(f"\nClue {clues}:")
        try:
            puzzle_m, solutions_m = proc.generate_multi_solution(
                clues=clues, min_solutions=3, max_solutions=3, min_hamming=10
            )
            print(f"  Generated puzzle with {len(solutions_m)} diverse solutions")
            for i in range(len(solutions_m)):
                for j in range(i + 1, len(solutions_m)):
                    print(f"  Hamming(sol{i}, sol{j}) = {proc._hamming(solutions_m[i], solutions_m[j])}")
        except RuntimeError as e:
            print(f"  {e}")