Upload generate_hf.py
Browse files- generate_hf.py +1193 -0
generate_hf.py
ADDED
|
@@ -0,0 +1,1193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Aniimage Generator — Generate anime images from text prompts.
|
| 3 |
+
https://huggingface.co/8BitStudio/Aniimage-1
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
pip install torch torchvision diffusers transformers safetensors pillow huggingface_hub
|
| 7 |
+
python generate_hf.py
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
import numpy as np
|
| 15 |
+
import tkinter as tk
|
| 16 |
+
from tkinter import ttk, simpledialog
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from PIL import Image, ImageTk, ImageEnhance, ImageFilter
|
| 19 |
+
from threading import Thread
|
| 20 |
+
|
| 21 |
+
# ── Paths ──────────────────────────────────────────────────────────────────────
# All paths are resolved relative to this script so the tool works from any CWD.
SCRIPT_DIR = Path(__file__).resolve().parent
MODEL_DIR = SCRIPT_DIR / "models"        # local cache of downloaded UNet weights
OUTPUT_DIR = SCRIPT_DIR / "generated"    # where generated images are saved

# ── HuggingFace repo ──────────────────────────────────────────────────────────
HF_REPO_ID = "8BitStudio/Aniimage-1"

# ── UNet config (must match training) ─────────────────────────────────────────
# Passed to diffusers.UNet2DConditionModel; `sample_size` is overridden per
# resolution in Generator.load_model (32 latent px → 256 output px).
UNET_CONFIG = dict(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    block_out_channels=(256, 512, 768, 1024),
    layers_per_block=2,
    cross_attention_dim=768,   # matches CLIP ViT-L/14 hidden size
    attention_head_dim=8,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D",
                      "CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D",
                    "CrossAttnUpBlock2D", "UpBlock2D"),
)

# Frozen pretrained components shared by every checkpoint.
VAE_ID = "stabilityai/sd-vae-ft-mse"
CLIP_ID = "openai/clip-vit-large-patch14"

# Scheduler names offered in the UI; mapped to diffusers schedulers in
# Generator._make_scheduler.
SCHEDULER_LIST = [
    "DPM++ 2M Karras",
    "DPM++ SDE Karras",
    "Euler a",
    "Euler",
    "DDIM",
]

# Default negative prompt applied when the user leaves the field empty.
DEFAULT_NEGATIVE = (
    "low quality, ugly, blurry, distorted, deformed, bad anatomy, "
    "bad proportions, extra limbs, missing limbs, watermark, text, "
    "signature, washed out, flat colors, manga panel, disfigured, "
    "poorly drawn, jpeg artifacts, cropped, out of frame"
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ββ Model discovery βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 64 |
+
|
| 65 |
+
def download_from_hf():
    """Ensure the Aniimage-1 UNet weights exist locally, downloading if needed.

    Returns:
        Path to the local model directory, or None when huggingface_hub
        is not installed.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        print("Install huggingface_hub: pip install huggingface_hub")
        return None

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    target_dir = MODEL_DIR / "Aniimage-1"
    local_weights = target_dir / "diffusion_pytorch_model.safetensors"

    # Already cached — nothing to do.
    if local_weights.exists():
        print("Aniimage-1 weights already downloaded.")
        return target_dir

    print(f"Downloading Aniimage-1 from {HF_REPO_ID}...")
    target_dir.mkdir(parents=True, exist_ok=True)

    import shutil
    cached_weights = hf_hub_download(repo_id=HF_REPO_ID,
                                     filename="diffusion_pytorch_model.safetensors")
    shutil.copy2(cached_weights, local_weights)

    # config.json is nice-to-have; ignore failures fetching it.
    try:
        cached_config = hf_hub_download(repo_id=HF_REPO_ID, filename="config.json")
        shutil.copy2(cached_config, target_dir / "config.json")
    except Exception:
        pass

    print("Download complete!")
    return target_dir
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def find_models():
    """Scan MODEL_DIR for usable model folders.

    Returns:
        List of (kind, name, path, resolution_label) tuples, where kind is
        "safetensors" or "checkpoint".
    """
    found = []
    if not MODEL_DIR.exists():
        return found
    for entry in sorted(MODEL_DIR.iterdir()):
        if not entry.is_dir():
            continue
        has_safetensors = (entry / "diffusion_pytorch_model.safetensors").exists()
        has_checkpoint = ((entry / "ema_unet.pt").exists()
                          or (entry / "unet.pt").exists())
        if has_safetensors:
            found.append(("safetensors", entry.name, entry, "256"))
        elif has_checkpoint:
            found.append(("checkpoint", entry.name, entry, "256"))
    return found
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# ββ Theme βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 116 |
+
|
| 117 |
+
# Dark theme palette used by every tk/ttk widget in the app.
C = {
    "bg": "#111119",        # window background
    "panel": "#1b1b2f",     # header / side-panel background
    "card": "#24243e",      # card and button background
    "card_sel": "#3a3a6e",  # selected/hovered card
    "border": "#2e2e52",    # borders and disabled accents
    "accent": "#6c5ce7",    # primary accent (purple)
    "accent_h": "#8577ed",  # accent hover state
    "red": "#e74c3c",       # stop/cancel actions
    "green": "#2ecc71",     # success indicators
    "text": "#eaeaea",      # primary text
    "text2": "#a0a0b8",     # secondary text
    "text3": "#60607a",     # muted text
    "input": "#16162a",     # entry/combobox field background
    "input_fg": "#dcdcf0",  # entry/combobox field text
}
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class Generator:
    """Text-to-image diffusion pipeline: CLIP text encoder + custom UNet + SD VAE.

    Models are lazily loaded on first use. `cancelled` may be flipped from
    another thread to abort an in-flight sampling loop.
    """

    def __init__(self, device="cuda"):
        # Fall back to CPU when CUDA was requested but is unavailable.
        self.device = device if device == "cuda" and torch.cuda.is_available() else "cpu"
        self.vae = None                 # AutoencoderKL, loaded in load_shared()
        self.text_encoder = None        # CLIPTextModel
        self.tokenizer = None           # CLIPTokenizer
        self.unet = None                # UNet2DConditionModel, loaded in load_model()
        self.scheduler = None
        self.loaded_checkpoint = None   # str(path) of the currently-loaded UNet
        self.latent_size = 32           # latent resolution (VAE downsamples 8x)
        self.output_size = 256          # pixel resolution of decoded images
        self.cancelled = False          # set externally to abort generation

    def switch_device(self, new_device):
        """Move all loaded models to a new device ("cuda" or "cpu")."""
        new_device = new_device if new_device == "cuda" and torch.cuda.is_available() else "cpu"
        if new_device == self.device:
            return
        self.device = new_device
        if self.vae is not None:
            self.vae = self.vae.to(self.device)
        if self.text_encoder is not None:
            self.text_encoder = self.text_encoder.to(self.device)
        if self.unet is not None:
            self.unet = self.unet.to(self.device)
        self.loaded_checkpoint = None  # force reload on next generate
        print(f"Switched to {self.device.upper()}")

    def load_shared(self):
        """Load the VAE, tokenizer, and text encoder once (idempotent)."""
        if self.vae is not None:
            return
        from diffusers import AutoencoderKL
        from transformers import CLIPTextModel, CLIPTokenizer

        print("Loading VAE...")
        self.vae = AutoencoderKL.from_pretrained(VAE_ID).to(self.device)
        self.vae.eval()

        print("Loading CLIP text encoder...")
        self.tokenizer = CLIPTokenizer.from_pretrained(CLIP_ID)
        self.text_encoder = CLIPTextModel.from_pretrained(CLIP_ID).to(self.device)
        self.text_encoder.eval()

        self.scheduler = self._make_scheduler("DPM++ 2M Karras")
        self.scheduler_name = "DPM++ 2M Karras"
        print("Shared models loaded.")

    def _make_scheduler(self, name="DPM++ 2M Karras"):
        """Build a fresh diffusers scheduler matching the given UI name.

        All schedulers share the same training noise schedule so they are
        interchangeable at inference time.
        """
        from diffusers import (DDIMScheduler, DPMSolverMultistepScheduler,
                               EulerAncestralDiscreteScheduler,
                               EulerDiscreteScheduler)
        base = dict(num_train_timesteps=1000, beta_schedule="scaled_linear",
                    prediction_type="epsilon")
        if name == "DPM++ 2M Karras":
            return DPMSolverMultistepScheduler(
                **base, algorithm_type="dpmsolver++",
                solver_order=2, use_karras_sigmas=True)
        elif name == "DPM++ SDE Karras":
            return DPMSolverMultistepScheduler(
                **base, algorithm_type="sde-dpmsolver++",
                use_karras_sigmas=True)
        elif name == "Euler a":
            return EulerAncestralDiscreteScheduler(**base)
        elif name == "Euler":
            return EulerDiscreteScheduler(**base)
        else:
            # Fallback (covers "DDIM" and any unknown name).
            return DDIMScheduler(**base, clip_sample=False,
                                 set_alpha_to_one=False)

    def set_scheduler(self, name):
        """Select the scheduler used by subsequent generate()/refine() calls."""
        self.scheduler = self._make_scheduler(name)
        self.scheduler_name = name

    def load_model(self, model_path: Path, res_label: str = "256"):
        """Load UNet weights from a model directory (no-op if already loaded).

        Accepts safetensors, EMA checkpoints (with optional "shadow_params"
        layout), or plain `unet.pt` state dicts.

        Raises:
            FileNotFoundError: if no recognized weight file exists.
        """
        if str(model_path) == self.loaded_checkpoint:
            return
        from diffusers import UNet2DConditionModel

        self.load_shared()

        # Resolution label decides both latent and output size (VAE is 8x).
        if res_label == "512":
            self.latent_size = 64
            self.output_size = 512
        else:
            self.latent_size = 32
            self.output_size = 256

        unet_cfg = dict(UNET_CONFIG)
        unet_cfg["sample_size"] = self.latent_size

        print(f"Loading UNet from {model_path.name} ({res_label}px)...")
        self.unet = UNet2DConditionModel(**unet_cfg).to(self.device)

        safetensors_path = model_path / "diffusion_pytorch_model.safetensors"
        ema_path = model_path / "ema_unet.pt"
        unet_path = model_path / "unet.pt"

        if safetensors_path.exists():
            from safetensors.torch import load_file
            state = load_file(str(safetensors_path), device=str(self.device))
            self.unet.load_state_dict(state)
            print("Loaded safetensors weights.")
        elif ema_path.exists():
            state = torch.load(ema_path, map_location=self.device, weights_only=True)
            if "shadow_params" in state:
                # EMA shadow params are stored positionally; copy them into the
                # UNet's parameters in declaration order.
                params = dict(self.unet.named_parameters())
                keys = list(params.keys())
                for i, sp in enumerate(state["shadow_params"]):
                    params[keys[i]].data.copy_(sp)
            else:
                self.unet.load_state_dict(state)
            print("Loaded EMA weights.")
        elif unet_path.exists():
            self.unet.load_state_dict(
                torch.load(unet_path, map_location=self.device, weights_only=True))
            print("Loaded UNet weights.")
        else:
            raise FileNotFoundError(f"No weights found in {model_path}")

        self.unet.eval()
        self.loaded_checkpoint = str(model_path)
        print(f"Ready to generate at {self.output_size}x{self.output_size}!")

    def _decode_latents(self, latents, post_process=False):
        """Decode a latent batch through the VAE and return a PIL image.

        Only the first image in the batch is returned.
        """
        # Undo the scaling applied when latents were encoded.
        scaled = latents / self.vae.config.scaling_factor
        with torch.no_grad():
            image = self.vae.decode(scaled.float()).sample
        # VAE output is in [-1, 1]; map to [0, 1] then to uint8 HWC.
        image = (image.float() / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
        image = (image * 255).round().astype("uint8")
        img = Image.fromarray(image)
        if post_process:
            img = self._post_process(img)
        return img

    def _sharpen_latents(self, latents, amount=0.08):
        """Mild unsharp-mask in latent space (subtract a 3x3 box blur)."""
        blurred = F.avg_pool2d(latents, kernel_size=3, stride=1, padding=1)
        return latents + amount * (latents - blurred)

    def _post_process(self, img):
        """Final image polish: unsharp mask, slight contrast and saturation boost."""
        img = img.filter(ImageFilter.UnsharpMask(radius=1.5, percent=40, threshold=2))
        img = ImageEnhance.Contrast(img).enhance(1.06)
        img = ImageEnhance.Color(img).enhance(1.10)
        return img

    def _image_quality_score(self, img: Image.Image) -> float:
        """Heuristic 0-100 quality score from sharpness and colorfulness.

        Sharpness is the variance of a Laplacian approximation (note: np.roll
        wraps at image edges, which slightly biases the border response);
        colorfulness is the mean per-channel variance.
        """
        arr = np.array(img.convert("L"), dtype=np.float32)
        lap = (np.roll(arr, 1, 0) + np.roll(arr, -1, 0)
               + np.roll(arr, 1, 1) + np.roll(arr, -1, 1) - 4.0 * arr)
        sharpness = float(np.var(lap))
        arr_rgb = np.array(img, dtype=np.float32)
        color_var = float(np.mean(np.var(arr_rgb, axis=(0, 1))))
        score = (sharpness * 0.6 + color_var * 0.4)
        return min(100.0, score / 10.0)

    @torch.no_grad()
    def generate(self, prompt: str, negative_prompt: str = "",
                 steps: int = 25, guidance_scale: float = 7.5,
                 seed: int = -1, preview_callback=None,
                 preview_every: int = 5) -> tuple:
        """Sample one image from the prompt with classifier-free guidance.

        Args:
            prompt: positive text prompt.
            negative_prompt: text to steer away from ("" disables).
            steps: number of scheduler steps.
            guidance_scale: CFG weight.
            seed: RNG seed; negative picks a random one.
            preview_callback: optional fn(PIL.Image, step, total) called with
                intermediate decodes every `preview_every` steps.

        Returns:
            (PIL.Image, seed) — image is None when `self.cancelled` aborted.
        """
        if seed < 0:
            seed = torch.randint(0, 2**32, (1,)).item()
        gen = torch.Generator(device=self.device).manual_seed(seed)

        # Encode positive prompt.
        tok = self.tokenizer(prompt, padding="max_length",
                             max_length=self.tokenizer.model_max_length,
                             truncation=True, return_tensors="pt")
        text_emb = self.text_encoder(tok.input_ids.to(self.device))[0]

        # Encode negative prompt (empty string acts as the unconditional path).
        tok_neg = self.tokenizer(negative_prompt if negative_prompt else "",
                                 padding="max_length",
                                 max_length=self.tokenizer.model_max_length,
                                 truncation=True, return_tensors="pt")
        neg_emb = self.text_encoder(tok_neg.input_ids.to(self.device))[0]

        # Batch [negative, positive] so one UNet pass serves both CFG branches.
        text_emb_combined = torch.cat([neg_emb, text_emb])

        # Fresh scheduler per call: schedulers carry per-run step state.
        scheduler = self._make_scheduler(self.scheduler_name)
        scheduler.set_timesteps(steps, device=self.device)

        latents = torch.randn(1, 4, self.latent_size, self.latent_size,
                              generator=gen, device=self.device)
        latents = latents * scheduler.init_noise_sigma

        timesteps = scheduler.timesteps
        total_steps = len(timesteps)

        for step_i, t in enumerate(timesteps):
            if self.cancelled:
                return None, seed

            # Duplicate latents for the [negative, positive] batch.
            latent_input = torch.cat([latents] * 2)
            latent_input = scheduler.scale_model_input(latent_input, t)

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
                                enabled=(self.device == "cuda")):
                pred = self.unet(latent_input, t,
                                 encoder_hidden_states=text_emb_combined).sample

            # Classifier-free guidance combine.
            pred_neg, pred_text = pred.chunk(2)
            pred = pred_neg + guidance_scale * (pred_text - pred_neg)

            latents = scheduler.step(pred, t, latents).prev_sample

            # Periodic preview decode (skipped on first and final steps).
            if (preview_callback and step_i > 0
                    and step_i % preview_every == 0
                    and step_i < total_steps - 1):
                preview = self._decode_latents(latents, post_process=False)
                preview_callback(preview, step_i + 1, total_steps)

        latents = self._sharpen_latents(latents)
        final = self._decode_latents(latents, post_process=True)
        return final, seed

    @torch.no_grad()
    def generate_adaptive(self, prompt: str, negative_prompt: str = "",
                          base_steps: int = 25, max_steps: int = 85,
                          guidance_scale: float = 7.5,
                          quality_threshold: float = 45.0,
                          preview_callback=None, preview_every: int = 5,
                          status_callback=None) -> tuple:
        """Generate, then keep refining (img2img) until the heuristic quality
        score clears `quality_threshold` or the step budget is exhausted.

        Each refinement round spends 20 extra steps; max_rounds is derived
        from the remaining (max_steps - base_steps) budget.

        Returns:
            (PIL.Image, seed) — image is None when cancelled before first result.
        """
        result = self.generate(
            prompt=prompt, negative_prompt=negative_prompt,
            steps=base_steps, guidance_scale=guidance_scale,
            preview_callback=preview_callback, preview_every=preview_every)

        if result[0] is None:
            return result

        image, seed = result
        quality = self._image_quality_score(image)

        if status_callback:
            status_callback(f"Quality: {quality:.1f}/100")

        if quality >= quality_threshold:
            return image, seed

        rounds = 0
        max_rounds = (max_steps - base_steps) // 20

        while quality < quality_threshold and rounds < max_rounds:
            if self.cancelled:
                return image, seed
            rounds += 1
            if status_callback:
                status_callback(f"Refining +20 steps (round {rounds})...")

            refined = self.refine(
                source_image=image, prompt=prompt,
                negative_prompt=negative_prompt,
                extra_steps=20, strength=0.3,
                guidance_scale=guidance_scale,
                preview_callback=preview_callback, preview_every=5)

            # refine() returns None when cancelled mid-round; keep the best so far.
            if refined is None:
                return image, seed
            image = refined
            quality = self._image_quality_score(image)

            if status_callback:
                status_callback(f"Quality after round {rounds}: {quality:.1f}/100")

        return image, seed

    @torch.no_grad()
    def refine(self, source_image: Image.Image, prompt: str,
               negative_prompt: str = "", extra_steps: int = 20,
               strength: float = 0.35, guidance_scale: float = 7.5,
               preview_callback=None, preview_every: int = 5) -> Image.Image:
        """Img2img pass: re-noise `source_image` partially and denoise it again.

        `strength` in (0, 1] controls how much of the schedule is re-run —
        higher means more change. Returns None when cancelled.
        """
        # Encode the source image to latents ([-1, 1] normalization for the VAE).
        img = source_image.resize((self.output_size, self.output_size), Image.LANCZOS)
        img_tensor = torch.from_numpy(np.array(img)).float().div(127.5).sub(1.0)
        img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(self.device)

        with torch.no_grad():
            latents = self.vae.encode(img_tensor.float()).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        tok = self.tokenizer(prompt, padding="max_length",
                             max_length=self.tokenizer.model_max_length,
                             truncation=True, return_tensors="pt")
        text_emb = self.text_encoder(tok.input_ids.to(self.device))[0]

        tok_neg = self.tokenizer(negative_prompt if negative_prompt else "",
                                 padding="max_length",
                                 max_length=self.tokenizer.model_max_length,
                                 truncation=True, return_tensors="pt")
        neg_emb = self.text_encoder(tok_neg.input_ids.to(self.device))[0]
        text_emb_combined = torch.cat([neg_emb, text_emb])

        # Skip the earliest (noisiest) part of the schedule per `strength`.
        scheduler = self._make_scheduler(self.scheduler_name)
        scheduler.set_timesteps(extra_steps, device=self.device)
        start_step = max(0, int(len(scheduler.timesteps) * (1 - strength)))
        timesteps = scheduler.timesteps[start_step:]

        # Noise the clean latents to the level of the first remaining timestep.
        noise = torch.randn_like(latents)
        latents = scheduler.add_noise(latents, noise, timesteps[:1])

        total_steps = len(timesteps)
        for step_i, t in enumerate(timesteps):
            if self.cancelled:
                return None
            latent_input = torch.cat([latents] * 2)
            latent_input = scheduler.scale_model_input(latent_input, t)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
                                enabled=(self.device == "cuda")):
                pred = self.unet(latent_input, t,
                                 encoder_hidden_states=text_emb_combined).sample
            pred_neg, pred_text = pred.chunk(2)
            pred = pred_neg + guidance_scale * (pred_text - pred_neg)
            latents = scheduler.step(pred, t, latents).prev_sample

            if (preview_callback and step_i > 0
                    and step_i % preview_every == 0
                    and step_i < total_steps - 1):
                preview = self._decode_latents(latents, post_process=False)
                preview_callback(preview, step_i + 1, total_steps)

        latents = self._sharpen_latents(latents)
        return self._decode_latents(latents, post_process=True)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
# ββ GUI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 461 |
+
|
| 462 |
+
class App:
|
| 463 |
+
    def __init__(self):
        """Set up the generator backend, application state, and the root window."""
        self.gen = Generator()
        self.models = find_models()        # available model dirs at startup
        self.generated_images = []         # PIL images produced this session
        self.generated_seeds = []          # seed used for each image (parallel list)
        self.photo_refs = []               # ImageTk refs kept alive for tkinter
        self.generating = False            # True while a generation thread runs
        self.selected_index = None         # index of the selected grid image

        self.root = tk.Tk()
        self.root.title("Aniimage")
        self.root.configure(bg=C["bg"])
        self.root.resizable(True, True)
        self.root.geometry("900x780")
        self.root.minsize(640, 500)

        self._setup_styles()
        self._build_ui()
|
| 481 |
+
|
| 482 |
+
    def _setup_styles(self):
        """Configure ttk styles and option-database entries for the dark theme.

        Must run before _build_ui so widgets pick up the styles at creation.
        """
        s = ttk.Style()
        s.theme_use("clam")

        # Base defaults inherited by all ttk widgets.
        s.configure(".", background=C["bg"], foreground=C["text"], font=("Segoe UI", 10))
        s.configure("TFrame", background=C["bg"])
        s.configure("TLabel", background=C["bg"], foreground=C["text"])
        s.configure("TCheckbutton", background=C["bg"], foreground=C["text"])

        # Combobox — readable text in both normal and readonly states.
        s.configure("TCombobox", fieldbackground=C["input"], foreground=C["input_fg"],
                    selectbackground=C["accent"], selectforeground="#ffffff",
                    arrowcolor=C["text2"], padding=4)
        s.map("TCombobox",
              fieldbackground=[("readonly", C["input"])],
              foreground=[("readonly", C["input_fg"])],
              selectbackground=[("readonly", C["accent"])],
              selectforeground=[("readonly", "#ffffff")])
        # The dropdown list is a plain tk Listbox; style it via the option DB.
        self.root.option_add("*TCombobox*Listbox.background", C["input"])
        self.root.option_add("*TCombobox*Listbox.foreground", C["input_fg"])
        self.root.option_add("*TCombobox*Listbox.selectBackground", C["accent"])
        self.root.option_add("*TCombobox*Listbox.selectForeground", "#ffffff")
        self.root.option_add("*TCombobox*Listbox.font", ("Segoe UI", 10))

        # Spinbox
        s.configure("TSpinbox", fieldbackground=C["input"], foreground=C["input_fg"],
                    arrowcolor=C["text2"], padding=3)

        # Buttons — default, primary ("Go"), and destructive ("Stop") variants.
        s.configure("TButton", font=("Segoe UI", 10), padding=(14, 7),
                    background=C["card"], foreground=C["text"])
        s.map("TButton", background=[("active", C["card_sel"]), ("disabled", C["bg"])],
              foreground=[("disabled", C["text3"])])

        s.configure("Go.TButton", font=("Segoe UI", 11, "bold"), padding=(20, 9),
                    background=C["accent"], foreground="#ffffff")
        s.map("Go.TButton", background=[("active", C["accent_h"]),
                                        ("disabled", C["border"])])

        s.configure("Stop.TButton", font=("Segoe UI", 10, "bold"), padding=(14, 7),
                    background=C["red"], foreground="#ffffff")
        s.map("Stop.TButton", background=[("active", "#c0392b"),
                                          ("disabled", C["border"])])

        # Labelframe
        s.configure("TLabelframe", background=C["bg"], foreground=C["text2"])
        s.configure("TLabelframe.Label", background=C["bg"],
                    foreground=C["text2"], font=("Segoe UI", 9, "bold"))

        # Scrollbar
        s.configure("Vertical.TScrollbar", background=C["card"],
                    troughcolor=C["bg"], arrowcolor=C["text3"])
|
| 536 |
+
|
| 537 |
+
def _make_entry(self, parent, font_size=11, dim=False):
|
| 538 |
+
"""Create a styled tk.Entry with readable text."""
|
| 539 |
+
return tk.Entry(parent, font=("Segoe UI", font_size),
|
| 540 |
+
bg=C["input"], fg=C["input_fg"] if not dim else C["text2"],
|
| 541 |
+
insertbackground=C["input_fg"],
|
| 542 |
+
relief="flat", bd=6,
|
| 543 |
+
selectbackground=C["accent"], selectforeground="#ffffff",
|
| 544 |
+
highlightthickness=1, highlightcolor=C["accent"],
|
| 545 |
+
highlightbackground=C["border"])
|
| 546 |
+
|
| 547 |
+
    def _build_ui(self):
        """Assemble the window: header bar on top, then a two-column body
        (controls on the left, scrollable image grid on the right)."""
        # ── Header ────────────────────────────────────────────────────────
        header = tk.Frame(self.root, bg=C["panel"], padx=20, pady=12)
        header.pack(fill=tk.X)

        # App title and byline.
        tk.Label(header, text="Aniimage", bg=C["panel"], fg=C["accent"],
                 font=("Segoe UI", 20, "bold")).pack(side=tk.LEFT)
        tk.Label(header, text="by 8BitStudio", bg=C["panel"], fg=C["text3"],
                 font=("Segoe UI", 10)).pack(side=tk.LEFT, padx=(10, 0), pady=(6, 0))

        # Device switch — right side of header
        device_frame = tk.Frame(header, bg=C["panel"])
        device_frame.pack(side=tk.RIGHT)

        tk.Label(device_frame, text="Device:", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9)).pack(side=tk.LEFT, padx=(0, 5))

        # Default to whatever device the generator already chose; only offer
        # "GPU" when CUDA is actually available on this machine.
        self.device_var = tk.StringVar(value="GPU" if self.gen.device == "cuda" else "CPU")
        devices = ["GPU", "CPU"] if torch.cuda.is_available() else ["CPU"]
        device_combo = ttk.Combobox(device_frame, textvariable=self.device_var,
                                    values=devices, state="readonly", width=5)
        device_combo.pack(side=tk.LEFT)
        device_combo.bind("<<ComboboxSelected>>", self._on_device_change)

        # ── Main content — two-column: controls left, images right ────────
        main = tk.Frame(self.root, bg=C["bg"])
        main.pack(fill=tk.BOTH, expand=True, padx=12, pady=(8, 12))

        # Left panel (controls): fixed width — pack_propagate(False) stops
        # the children from shrinking/growing the frame.
        left = tk.Frame(main, bg=C["panel"], width=340, padx=16, pady=12)
        left.pack(side=tk.LEFT, fill=tk.Y, padx=(0, 8))
        left.pack_propagate(False)

        # Right panel (image grid) takes all remaining width.
        right = tk.Frame(main, bg=C["bg"])
        right.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        self._build_controls(left)
        self._build_grid(right)
|
| 586 |
+
|
| 587 |
+
    def _build_controls(self, parent):
        """Populate the left panel: model picker, prompt fields, settings
        grid, action buttons, prompt queue, and the status bar."""
        # ── Model ─────────────────────────────────────────────────────────
        tk.Label(parent, text="Model", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9, "bold")).pack(anchor=tk.W)

        self.model_var = tk.StringVar()
        # self.models entries are tuples; index 1 is the display name.
        model_names = [m[1] for m in self.models] or ["No models found"]
        self.model_combo = ttk.Combobox(parent, textvariable=self.model_var,
                                        values=model_names, state="readonly", width=32)
        self.model_combo.pack(fill=tk.X, pady=(3, 12))
        # Pre-select the last model in the list.
        self.model_combo.current(len(model_names) - 1)

        # ── Prompt ────────────────────────────────────────────────────────
        tk.Label(parent, text="Prompt", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9, "bold")).pack(anchor=tk.W)
        self.prompt_entry = self._make_entry(parent)
        self.prompt_entry.pack(fill=tk.X, pady=(3, 8))
        self.prompt_entry.insert(0, "a smiling anime girl with long blue hair")
        # Pressing Enter in the prompt field starts generation.
        self.prompt_entry.bind("<Return>", lambda e: self.on_generate())

        # ── Negative prompt ───────────────────────────────────────────────
        tk.Label(parent, text="Negative prompt", bg=C["panel"], fg=C["text3"],
                 font=("Segoe UI", 9)).pack(anchor=tk.W)
        self.neg_entry = self._make_entry(parent, font_size=9, dim=True)
        self.neg_entry.pack(fill=tk.X, pady=(3, 12))
        self.neg_entry.insert(0, DEFAULT_NEGATIVE)

        # ── Settings grid ─────────────────────────────────────────────────
        grid = tk.Frame(parent, bg=C["panel"])
        grid.pack(fill=tk.X, pady=(0, 8))

        # Row 1: Scheduler
        tk.Label(grid, text="Scheduler", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9)).grid(row=0, column=0, sticky="w", pady=(0, 6))
        self.scheduler_var = tk.StringVar(value="DPM++ 2M Karras")
        sched_combo = ttk.Combobox(grid, textvariable=self.scheduler_var,
                                   values=SCHEDULER_LIST, state="readonly", width=18)
        sched_combo.grid(row=0, column=1, columnspan=3, sticky="ew", padx=(8, 0), pady=(0, 6))
        sched_combo.bind("<<ComboboxSelected>>", self._on_scheduler_change)

        # Row 2: Steps, CFG, Count
        tk.Label(grid, text="Steps", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9)).grid(row=1, column=0, sticky="w", pady=(0, 6))
        self.steps_var = tk.StringVar(value="25")
        tk.Entry(grid, textvariable=self.steps_var, width=5, font=("Segoe UI", 10),
                 bg=C["input"], fg=C["input_fg"], insertbackground=C["input_fg"],
                 relief="flat", bd=4).grid(row=1, column=1, sticky="w", padx=(8, 12), pady=(0, 6))

        tk.Label(grid, text="CFG", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9)).grid(row=1, column=2, sticky="w", pady=(0, 6))
        self.cfg_var = tk.StringVar(value="7.5")
        tk.Entry(grid, textvariable=self.cfg_var, width=5, font=("Segoe UI", 10),
                 bg=C["input"], fg=C["input_fg"], insertbackground=C["input_fg"],
                 relief="flat", bd=4).grid(row=1, column=3, sticky="w", padx=(8, 0), pady=(0, 6))

        # Row 3: Count, Live preview
        tk.Label(grid, text="Count", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9)).grid(row=2, column=0, sticky="w", pady=(0, 6))
        self.count_var = tk.StringVar(value="4")
        ttk.Spinbox(grid, from_=1, to=12, textvariable=self.count_var, width=4,
                    font=("Segoe UI", 10)).grid(row=2, column=1, sticky="w", padx=(8, 12), pady=(0, 6))

        self.live_preview_var = tk.BooleanVar(value=False)
        ttk.Checkbutton(grid, text="Live preview",
                        variable=self.live_preview_var).grid(
            row=2, column=2, columnspan=2, sticky="w", pady=(0, 6))

        grid.columnconfigure(1, weight=1)
        grid.columnconfigure(3, weight=1)

        # ── Auto quality ──────────────────────────────────────────────────
        self.auto_quality_var = tk.BooleanVar(value=False)
        ttk.Checkbutton(parent, text="Auto quality (refine if undercooked)",
                        variable=self.auto_quality_var).pack(anchor=tk.W, pady=(0, 12))

        # ── Buttons ───────────────────────────────────────────────────────
        btn_frame = tk.Frame(parent, bg=C["panel"])
        btn_frame.pack(fill=tk.X, pady=(0, 10))

        self.gen_btn = ttk.Button(btn_frame, text="Generate", command=self.on_generate,
                                  style="Go.TButton")
        self.gen_btn.pack(fill=tk.X, pady=(0, 5))

        btn_row = tk.Frame(btn_frame, bg=C["panel"])
        btn_row.pack(fill=tk.X)

        # Stop/Save buttons start disabled; they are enabled once there is a
        # running job / generated output to act on.
        self.stop_btn = ttk.Button(btn_row, text="Stop", command=self.on_stop,
                                   state=tk.DISABLED, style="Stop.TButton")
        self.stop_btn.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 3))

        self.save_btn = ttk.Button(btn_row, text="Save Selected", command=self.on_save,
                                   state=tk.DISABLED)
        self.save_btn.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(3, 3))

        self.save_all_btn = ttk.Button(btn_row, text="Save All", command=self.on_save_all,
                                       state=tk.DISABLED)
        self.save_all_btn.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(3, 0))

        # ── Prompt queue ──────────────────────────────────────────────────
        sep = tk.Frame(parent, height=1, bg=C["border"])
        sep.pack(fill=tk.X, pady=(8, 10))

        tk.Label(parent, text="Prompt Queue", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9, "bold")).pack(anchor=tk.W)

        queue_input = tk.Frame(parent, bg=C["panel"])
        queue_input.pack(fill=tk.X, pady=(4, 0))

        self.queue_entry = self._make_entry(queue_input, font_size=9)
        self.queue_entry.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 4))
        # Enter in the queue entry appends it to the queue.
        self.queue_entry.bind("<Return>", lambda e: self._queue_add())

        ttk.Button(queue_input, text="Add", width=4,
                   command=self._queue_add).pack(side=tk.LEFT)

        self.queue_listbox = tk.Listbox(
            parent, height=4, bg=C["input"], fg=C["input_fg"],
            selectbackground=C["accent"], selectforeground="#fff",
            font=("Segoe UI", 9), activestyle="none",
            relief="flat", bd=4, highlightthickness=0)
        self.queue_listbox.pack(fill=tk.X, pady=(5, 0))

        queue_btns = tk.Frame(parent, bg=C["panel"])
        queue_btns.pack(fill=tk.X, pady=(4, 0))

        self.queue_run_btn = ttk.Button(queue_btns, text="Run Queue",
                                        command=self.on_run_queue, style="Go.TButton")
        self.queue_run_btn.pack(side=tk.LEFT, padx=(0, 4))

        # Small queue-management buttons share one construction loop.
        for txt, cmd in [("Remove", self._queue_remove), ("Clear", self._queue_clear),
                         ("Up", self._queue_move_up), ("Down", self._queue_move_down),
                         ("+ Current", self._queue_add_current)]:
            ttk.Button(queue_btns, text=txt, command=cmd).pack(side=tk.LEFT, padx=2)

        # ── Status bar ────────────────────────────────────────────────────
        status_frame = tk.Frame(parent, bg=C["bg"], padx=8, pady=6)
        status_frame.pack(fill=tk.X, side=tk.BOTTOM)

        self.status_var = tk.StringVar(value="Ready")
        tk.Label(status_frame, textvariable=self.status_var,
                 bg=C["bg"], fg=C["green"], font=("Segoe UI", 9),
                 anchor="w").pack(fill=tk.X)
|
| 729 |
+
|
| 730 |
+
    def _build_grid(self, parent):
        """Create the scrollable canvas that hosts the generated-image grid."""
        self.canvas = tk.Canvas(parent, bg=C["bg"], highlightthickness=0)
        scrollbar = ttk.Scrollbar(parent, orient=tk.VERTICAL, command=self.canvas.yview)
        self.grid_frame = tk.Frame(self.canvas, bg=C["bg"])

        # Grow the scrollregion whenever the inner frame changes size.
        self.grid_frame.bind("<Configure>",
                             lambda e: self.canvas.configure(
                                 scrollregion=self.canvas.bbox("all")))
        # Keep the window item id so resize events can stretch the frame.
        self.canvas_window = self.canvas.create_window((0, 0), window=self.grid_frame,
                                                       anchor="nw")
        self.canvas.configure(yscrollcommand=scrollbar.set)

        self.canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)

        self.canvas.bind("<Configure>", self._on_canvas_resize)
        # NOTE(review): <MouseWheel> with e.delta/120 is the Windows/macOS
        # convention; X11 delivers <Button-4>/<Button-5> instead — confirm
        # wheel scrolling works on Linux.
        self.canvas.bind_all("<MouseWheel>",
                             lambda e: self.canvas.yview_scroll(
                                 int(-1 * (e.delta / 120)), "units"))

        # Placeholder label shown until the first generation destroys it.
        self.placeholder = tk.Label(self.grid_frame,
                                    text="Generated images\nwill appear here",
                                    bg=C["bg"], fg=C["text3"],
                                    font=("Segoe UI", 13), justify="center")
        self.placeholder.grid(row=0, column=0, pady=80)
|
| 755 |
+
|
| 756 |
+
# ββ Event handlers ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 757 |
+
|
| 758 |
+
def _on_device_change(self, event=None):
|
| 759 |
+
choice = self.device_var.get()
|
| 760 |
+
new_dev = "cuda" if choice == "GPU" else "cpu"
|
| 761 |
+
self.status_var.set(f"Switching to {choice}...")
|
| 762 |
+
self.root.update()
|
| 763 |
+
self.gen.switch_device(new_dev)
|
| 764 |
+
self.status_var.set(f"Now using {choice}")
|
| 765 |
+
|
| 766 |
+
def _on_scheduler_change(self, event=None):
|
| 767 |
+
name = self.scheduler_var.get()
|
| 768 |
+
self.gen.set_scheduler(name)
|
| 769 |
+
self.status_var.set(f"Scheduler: {name}")
|
| 770 |
+
|
| 771 |
+
def _on_canvas_resize(self, event):
|
| 772 |
+
self.canvas.itemconfig(self.canvas_window, width=event.width)
|
| 773 |
+
if self.generated_images:
|
| 774 |
+
self._layout_grid()
|
| 775 |
+
|
| 776 |
+
def _get_grid_cols(self):
|
| 777 |
+
canvas_w = self.canvas.winfo_width()
|
| 778 |
+
if canvas_w < 50:
|
| 779 |
+
canvas_w = 560
|
| 780 |
+
tile_size = self._get_tile_size()
|
| 781 |
+
return max(1, canvas_w // (tile_size + 16))
|
| 782 |
+
|
| 783 |
+
def _get_tile_size(self):
|
| 784 |
+
n = len(self.generated_images)
|
| 785 |
+
if n <= 2: return 260
|
| 786 |
+
elif n <= 4: return 220
|
| 787 |
+
elif n <= 6: return 180
|
| 788 |
+
else: return 160
|
| 789 |
+
|
| 790 |
+
    def _layout_grid(self):
        """Rebuild the image grid from scratch for the current images and
        selection state."""
        # Destroy the old tiles and drop PhotoImage refs so Tk can free them.
        for w in self.grid_frame.winfo_children():
            w.destroy()
        self.photo_refs.clear()

        if not self.generated_images:
            return

        tile_size = self._get_tile_size()
        cols = self._get_grid_cols()

        for i, (img, seed) in enumerate(zip(self.generated_images, self.generated_seeds)):
            row, col = divmod(i, cols)
            is_selected = (i == self.selected_index)

            # The selected tile gets an accent-coloured frame via the card bg.
            card_bg = C["accent"] if is_selected else C["card"]
            card = tk.Frame(self.grid_frame, bg=card_bg, padx=3, pady=3)
            card.grid(row=row, column=col, padx=5, pady=5, sticky="nsew")

            display = img.resize((tile_size, tile_size), Image.LANCZOS)
            photo = ImageTk.PhotoImage(display)
            # Keep a reference — Tk does not hold on to PhotoImages itself.
            self.photo_refs.append(photo)

            img_label = tk.Label(card, image=photo, bg=card_bg, bd=0)
            img_label.pack()
            # idx=i default-binds the loop variable (late-binding closure fix).
            # Left click selects; right click opens the refine menu.
            img_label.bind("<Button-1>", lambda e, idx=i: self._select_image(idx))
            img_label.bind("<Button-3>", lambda e, idx=i: self._show_refine_menu(e, idx))

            tk.Label(card, text=f"seed: {seed}", bg=card_bg,
                     fg=C["text3"], font=("Segoe UI", 8)).pack()

        # Let every column share extra horizontal space equally.
        for c in range(cols):
            self.grid_frame.columnconfigure(c, weight=1)
|
| 823 |
+
|
| 824 |
+
def _select_image(self, idx):
|
| 825 |
+
if idx >= len(self.generated_images):
|
| 826 |
+
return
|
| 827 |
+
self.selected_index = idx
|
| 828 |
+
self.save_btn.configure(state=tk.NORMAL)
|
| 829 |
+
self.status_var.set(f"Selected image {idx + 1} (seed: {self.generated_seeds[idx]})")
|
| 830 |
+
self._layout_grid()
|
| 831 |
+
|
| 832 |
+
    def _show_refine_menu(self, event, idx):
        """Right-click handler: pop a context menu offering to refine image idx."""
        if self.generating:
            return  # no refining while a job is already running
        menu = tk.Menu(self.root, tearoff=0, bg=C["card"], fg=C["text"],
                       activebackground=C["accent"], activeforeground="#fff",
                       font=("Segoe UI", 10), bd=0)
        menu.add_command(label=" Refine (more steps)... ",
                         command=lambda: self._ask_refine(idx))
        # Show the menu at the click's screen coordinates.
        menu.tk_popup(event.x_root, event.y_root)
|
| 841 |
+
|
| 842 |
+
    def _ask_refine(self, idx):
        """Prompt for extra steps, then refine image idx on a worker thread."""
        extra = simpledialog.askinteger(
            "Refine Image", "Extra denoising steps:",
            initialvalue=20, minvalue=5, maxvalue=200, parent=self.root)
        if extra is None:
            return  # dialog cancelled
        # Select the tile being refined so the user sees which one is active.
        self._select_image(idx)
        # Enter "busy" UI state before spawning the worker.
        self.generating = True
        self.gen.cancelled = False
        self.gen_btn.configure(state=tk.DISABLED)
        self.stop_btn.configure(state=tk.NORMAL)
        self.status_var.set(f"Refining image {idx + 1}...")
        self.root.update()
        Thread(target=self._refine_thread, args=(idx, extra), daemon=True).start()
|
| 856 |
+
|
| 857 |
+
    def _refine_thread(self, idx, extra_steps):
        """Worker thread: img2img-refine image idx with extra_steps more steps.

        NOTE(review): this thread updates Tk widgets/StringVars directly;
        Tkinter is generally not thread-safe — consider marshalling UI
        updates through root.after. Confirm before relying on this.
        """
        try:
            source = self.generated_images[idx]
            prompt = self.prompt_entry.get().strip()
            neg = self.neg_entry.get().strip()
            cfg = float(self.cfg_var.get())
            # Live preview is optional; None disables the callback entirely.
            callback = self._show_preview if self.live_preview_var.get() else None

            refined = self.gen.refine(
                source_image=source, prompt=prompt, negative_prompt=neg,
                extra_steps=extra_steps, guidance_scale=cfg,
                preview_callback=callback, preview_every=5)

            # refine() returning None here is treated as a user stop.
            if refined is not None:
                self.generated_images[idx] = refined
                # Tag the seed label so the tile shows it has been refined.
                self.generated_seeds[idx] = f"{self.generated_seeds[idx]}+R{extra_steps}"
                self._layout_grid()
                self.status_var.set(f"Refined image {idx + 1}")
            else:
                self.status_var.set("Refine stopped.")
            self.root.update()
        except Exception as e:
            self.status_var.set(f"Refine error: {e}")
            import traceback; traceback.print_exc()
        finally:
            # Always restore the idle UI state, success or failure.
            self.generating = False
            self.gen.cancelled = False
            self.gen_btn.configure(state=tk.NORMAL)
            self.stop_btn.configure(state=tk.DISABLED)
|
| 886 |
+
|
| 887 |
+
# ββ Queue βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 888 |
+
|
| 889 |
+
def _queue_add(self):
|
| 890 |
+
text = self.queue_entry.get().strip()
|
| 891 |
+
if text:
|
| 892 |
+
self.queue_listbox.insert(tk.END, text)
|
| 893 |
+
self.queue_entry.delete(0, tk.END)
|
| 894 |
+
|
| 895 |
+
def _queue_add_current(self):
|
| 896 |
+
text = self.prompt_entry.get().strip()
|
| 897 |
+
if text:
|
| 898 |
+
self.queue_listbox.insert(tk.END, text)
|
| 899 |
+
|
| 900 |
+
def _queue_remove(self):
|
| 901 |
+
sel = self.queue_listbox.curselection()
|
| 902 |
+
if sel:
|
| 903 |
+
self.queue_listbox.delete(sel[0])
|
| 904 |
+
|
| 905 |
+
def _queue_clear(self):
|
| 906 |
+
self.queue_listbox.delete(0, tk.END)
|
| 907 |
+
|
| 908 |
+
def _queue_move_up(self):
|
| 909 |
+
sel = self.queue_listbox.curselection()
|
| 910 |
+
if sel and sel[0] > 0:
|
| 911 |
+
idx = sel[0]
|
| 912 |
+
text = self.queue_listbox.get(idx)
|
| 913 |
+
self.queue_listbox.delete(idx)
|
| 914 |
+
self.queue_listbox.insert(idx - 1, text)
|
| 915 |
+
self.queue_listbox.selection_set(idx - 1)
|
| 916 |
+
|
| 917 |
+
def _queue_move_down(self):
|
| 918 |
+
sel = self.queue_listbox.curselection()
|
| 919 |
+
if sel and sel[0] < self.queue_listbox.size() - 1:
|
| 920 |
+
idx = sel[0]
|
| 921 |
+
text = self.queue_listbox.get(idx)
|
| 922 |
+
self.queue_listbox.delete(idx)
|
| 923 |
+
self.queue_listbox.insert(idx + 1, text)
|
| 924 |
+
self.queue_listbox.selection_set(idx + 1)
|
| 925 |
+
|
| 926 |
+
def on_run_queue(self):
|
| 927 |
+
if self.generating or not self.models:
|
| 928 |
+
return
|
| 929 |
+
prompts = list(self.queue_listbox.get(0, tk.END))
|
| 930 |
+
if not prompts:
|
| 931 |
+
self.status_var.set("Queue is empty")
|
| 932 |
+
return
|
| 933 |
+
self.generating = True
|
| 934 |
+
self.gen.cancelled = False
|
| 935 |
+
self.gen_btn.configure(state=tk.DISABLED)
|
| 936 |
+
self.queue_run_btn.configure(state=tk.DISABLED)
|
| 937 |
+
self.stop_btn.configure(state=tk.NORMAL)
|
| 938 |
+
Thread(target=self._queue_thread, args=(prompts,), daemon=True).start()
|
| 939 |
+
|
| 940 |
+
    def _queue_thread(self, prompts):
        """Worker thread: generate the configured image count for each queued
        prompt, auto-saving every image to disk as it completes.

        NOTE(review): this largely duplicates _generate_thread and updates Tk
        widgets from a non-GUI thread (Tkinter is not thread-safe) — consider
        extracting a shared helper and using root.after for UI updates.
        """
        try:
            # Load the model currently selected in the combobox.
            # self.models entries are tuples; [1] is the display name,
            # [2]/[3] are load_model arguments.
            idx = self.model_combo.current()
            mdl = self.models[idx]
            self.status_var.set(f"Loading {mdl[1]}...")
            self.root.update()
            self.gen.load_model(mdl[2], mdl[3])

            # Snapshot generation settings once, up front; may raise
            # ValueError for non-numeric fields (caught below).
            neg = self.neg_entry.get().strip()
            steps = int(self.steps_var.get())
            cfg = float(self.cfg_var.get())
            num_images = max(1, min(12, int(self.count_var.get())))
            live_preview = self.live_preview_var.get()
            auto_quality = self.auto_quality_var.get()

            # Reset grid state; destroy the placeholder on the first run.
            self.generated_images.clear()
            self.generated_seeds.clear()
            self.selected_index = None
            if self.placeholder:
                self.placeholder.destroy()
                self.placeholder = None

            for p_idx, prompt in enumerate(prompts):
                if self.gen.cancelled:
                    break
                # Highlight the prompt currently being processed.
                self.queue_listbox.selection_clear(0, tk.END)
                self.queue_listbox.selection_set(p_idx)
                self.queue_listbox.see(p_idx)

                for img_i in range(num_images):
                    if self.gen.cancelled:
                        break
                    self.status_var.set(
                        f"[{p_idx + 1}/{len(prompts)}] image {img_i + 1}/{num_images}")
                    self.root.update()

                    callback = None
                    if live_preview:
                        self._setup_preview_card()
                        callback = self._show_preview

                    if auto_quality:
                        # Adaptive mode may spend up to 60 extra steps if the
                        # result is judged undercooked.
                        image, used_seed = self.gen.generate_adaptive(
                            prompt=prompt, negative_prompt=neg,
                            base_steps=steps, max_steps=steps + 60,
                            guidance_scale=cfg,
                            preview_callback=callback, preview_every=5,
                            status_callback=lambda m: (
                                self.status_var.set(m), self.root.update()))
                    else:
                        image, used_seed = self.gen.generate(
                            prompt=prompt, negative_prompt=neg,
                            steps=steps, guidance_scale=cfg,
                            preview_callback=callback, preview_every=5)

                    # A None image means the run was cancelled mid-generation.
                    if image is None:
                        break
                    self.generated_images.append(image)
                    self.generated_seeds.append(used_seed)
                    # Queue runs auto-save each image immediately.
                    save_path = self._next_save_path(prompt)
                    image.save(save_path)
                    self._layout_grid()
                    self.root.update()

                if self.gen.cancelled:
                    break

            done = len(self.generated_images)
            self.status_var.set(
                f"Queue {'stopped' if self.gen.cancelled else 'done'}! {done} images saved.")
            if done > 0:
                self.save_all_btn.configure(state=tk.NORMAL)

        except Exception as e:
            self.status_var.set(f"Queue error: {e}")
            import traceback; traceback.print_exc()
        finally:
            # Always restore the idle UI state.
            self.generating = False
            self.gen.cancelled = False
            self.gen_btn.configure(state=tk.NORMAL)
            self.queue_run_btn.configure(state=tk.NORMAL)
            self.stop_btn.configure(state=tk.DISABLED)
|
| 1022 |
+
|
| 1023 |
+
# ββ Generation ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1024 |
+
|
| 1025 |
+
def on_stop(self):
|
| 1026 |
+
if self.generating:
|
| 1027 |
+
self.gen.cancelled = True
|
| 1028 |
+
self.status_var.set("Stopping...")
|
| 1029 |
+
self.root.update()
|
| 1030 |
+
|
| 1031 |
+
def on_generate(self):
|
| 1032 |
+
if self.generating or not self.models:
|
| 1033 |
+
return
|
| 1034 |
+
self.generating = True
|
| 1035 |
+
self.gen.cancelled = False
|
| 1036 |
+
self.gen_btn.configure(state=tk.DISABLED)
|
| 1037 |
+
self.stop_btn.configure(state=tk.NORMAL)
|
| 1038 |
+
self.status_var.set("Loading model...")
|
| 1039 |
+
self.root.update()
|
| 1040 |
+
Thread(target=self._generate_thread, daemon=True).start()
|
| 1041 |
+
|
| 1042 |
+
    def _setup_preview_card(self):
        """Add an empty card at the next free grid slot to host live-preview
        frames for the image currently being generated."""
        tile_size = self._get_tile_size()
        cols = self._get_grid_cols()
        # Next slot comes right after the images already in the grid.
        row, col = divmod(len(self.generated_images), cols)
        card = tk.Frame(self.grid_frame, bg=C["card"], padx=3, pady=3)
        card.grid(row=row, column=col, padx=5, pady=5, sticky="nsew")
        # NOTE(review): for a Label without an image, width/height are in
        # text units rather than pixels — the card is only pixel-sized once
        # the first preview image is assigned. Confirm this is intended.
        self._preview_label = tk.Label(card, bg=C["card"],
                                       width=tile_size, height=tile_size)
        self._preview_label.pack()
        self.root.update()
|
| 1052 |
+
|
| 1053 |
+
def _show_preview(self, preview_img, step, total):
|
| 1054 |
+
tile_size = self._get_tile_size()
|
| 1055 |
+
display = preview_img.resize((tile_size, tile_size), Image.LANCZOS)
|
| 1056 |
+
photo = ImageTk.PhotoImage(display)
|
| 1057 |
+
self._preview_photo = photo
|
| 1058 |
+
if hasattr(self, '_preview_label') and self._preview_label.winfo_exists():
|
| 1059 |
+
self._preview_label.configure(image=photo)
|
| 1060 |
+
self.status_var.set(f"Step {step}/{total}")
|
| 1061 |
+
self.root.update()
|
| 1062 |
+
|
| 1063 |
+
    def _generate_thread(self):
        """Worker thread: generate the configured number of images for the
        current prompt and show them in the grid (no auto-save here, unlike
        the queue worker).

        NOTE(review): updates Tk widgets from a non-GUI thread and largely
        duplicates _queue_thread — confirm thread-safety assumptions before
        refactoring; consider root.after for UI updates.
        """
        try:
            # Load the model currently selected in the combobox; tuple
            # fields [2]/[3] are the load_model arguments.
            idx = self.model_combo.current()
            mdl = self.models[idx]
            self.status_var.set(f"Loading {mdl[1]}...")
            self.root.update()
            self.gen.load_model(mdl[2], mdl[3])

            # Snapshot settings once; int()/float() may raise ValueError on
            # non-numeric fields (caught by the handler below).
            prompt = self.prompt_entry.get().strip()
            neg = self.neg_entry.get().strip()
            steps = int(self.steps_var.get())
            cfg = float(self.cfg_var.get())
            num_images = max(1, min(12, int(self.count_var.get())))
            live_preview = self.live_preview_var.get()
            auto_quality = self.auto_quality_var.get()

            # Clear previous results; destroy the placeholder on first run.
            self.generated_images.clear()
            self.generated_seeds.clear()
            self.selected_index = None
            if self.placeholder:
                self.placeholder.destroy()
                self.placeholder = None

            for i in range(num_images):
                if self.gen.cancelled:
                    break
                self.status_var.set(f"Generating {i + 1}/{num_images}...")
                self.root.update()

                callback = None
                if live_preview:
                    self._setup_preview_card()
                    callback = self._show_preview

                if auto_quality:
                    # Adaptive mode may spend up to 60 extra steps if the
                    # result is judged undercooked.
                    image, used_seed = self.gen.generate_adaptive(
                        prompt=prompt, negative_prompt=neg,
                        base_steps=steps, max_steps=steps + 60,
                        guidance_scale=cfg,
                        preview_callback=callback, preview_every=5,
                        status_callback=lambda m: (
                            self.status_var.set(m), self.root.update()))
                else:
                    image, used_seed = self.gen.generate(
                        prompt=prompt, negative_prompt=neg,
                        steps=steps, guidance_scale=cfg,
                        preview_callback=callback, preview_every=5)

                # A None image means the run was cancelled mid-generation.
                if image is None:
                    break
                self.generated_images.append(image)
                self.generated_seeds.append(used_seed)
                self._layout_grid()
                self.root.update()

            done = len(self.generated_images)
            if self.gen.cancelled:
                self.status_var.set(f"Stopped. {done} image(s) kept.")
            else:
                self.status_var.set(f"Done! {done} images. Click to select.")
            if done > 0:
                self.save_all_btn.configure(state=tk.NORMAL)
            # Nothing is selected yet, so "Save Selected" stays disabled.
            self.save_btn.configure(state=tk.DISABLED)

        except Exception as e:
            self.status_var.set(f"Error: {e}")
            import traceback; traceback.print_exc()
        finally:
            # Always restore the idle UI state.
            self.generating = False
            self.gen.cancelled = False
            self.gen_btn.configure(state=tk.NORMAL)
            self.stop_btn.configure(state=tk.DISABLED)
|
| 1135 |
+
|
| 1136 |
+
# ββ Save ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1137 |
+
|
| 1138 |
+
def _next_save_path(self, prompt_text):
|
| 1139 |
+
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 1140 |
+
slug = prompt_text.strip()[:50] if prompt_text.strip() else "untitled"
|
| 1141 |
+
base = OUTPUT_DIR / f"{slug}.png"
|
| 1142 |
+
if not base.exists():
|
| 1143 |
+
return base
|
| 1144 |
+
n = 1
|
| 1145 |
+
while True:
|
| 1146 |
+
path = OUTPUT_DIR / f"{slug} {n}.png"
|
| 1147 |
+
if not path.exists():
|
| 1148 |
+
return path
|
| 1149 |
+
n += 1
|
| 1150 |
+
|
| 1151 |
+
def on_save(self):
|
| 1152 |
+
if self.selected_index is None or not self.generated_images:
|
| 1153 |
+
return
|
| 1154 |
+
img = self.generated_images[self.selected_index]
|
| 1155 |
+
path = self._next_save_path(self.prompt_entry.get().strip())
|
| 1156 |
+
img.save(path)
|
| 1157 |
+
self.status_var.set(f"Saved: {path.name}")
|
| 1158 |
+
|
| 1159 |
+
def on_save_all(self):
|
| 1160 |
+
if not self.generated_images:
|
| 1161 |
+
return
|
| 1162 |
+
prompt_text = self.prompt_entry.get().strip()
|
| 1163 |
+
for img in self.generated_images:
|
| 1164 |
+
path = self._next_save_path(prompt_text)
|
| 1165 |
+
img.save(path)
|
| 1166 |
+
self.status_var.set(f"Saved {len(self.generated_images)} images")
|
| 1167 |
+
|
| 1168 |
+
    def run(self):
        """Enter the Tk main loop; blocks until the window is closed."""
        self.root.mainloop()
|
| 1170 |
+
|
| 1171 |
+
|
| 1172 |
+
# ββ Entry point βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1173 |
+
|
| 1174 |
+
if __name__ == "__main__":
    # Look for local model weights; fall back to downloading them once.
    models = find_models()
    if not models:
        print("No models found locally. Downloading from HuggingFace...")
        if download_from_hf():
            models = find_models()

    # Still nothing usable — tell the user where to put weights and bail out.
    if not models:
        print("No models found!")
        print(f"Place model weights in: {MODEL_DIR}/YourModelName/")
        print("Expected files: diffusion_pytorch_model.safetensors or ema_unet.pt")
        sys.exit(1)

    print(f"Found {len(models)} model(s): {', '.join(m[1] for m in models)}")
    print(f"Device: {'CUDA (GPU)' if torch.cuda.is_available() else 'CPU'}")
    print("Starting Aniimage...")

    app = App()
    app.run()
|