8BitStudio committed on
Commit
5dbc62b
Β·
verified Β·
1 Parent(s): ce1093f

Upload generate_hf.py

Browse files
Files changed (1) hide show
  1. generate_hf.py +1193 -0
generate_hf.py ADDED
@@ -0,0 +1,1193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Aniimage Generator β€” Generate anime images from text prompts.
3
+ https://huggingface.co/8BitStudio/Aniimage-1
4
+
5
+ Usage:
6
+ pip install torch torchvision diffusers transformers safetensors pillow huggingface_hub
7
+ python generate_hf.py
8
+ """
9
+
10
+ import os
11
+ import sys
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import numpy as np
15
+ import tkinter as tk
16
+ from tkinter import ttk, simpledialog
17
+ from pathlib import Path
18
+ from PIL import Image, ImageTk, ImageEnhance, ImageFilter
19
+ from threading import Thread
20
+
21
# ── Paths ─────────────────────────────────────────────────────────────────────
SCRIPT_DIR = Path(__file__).resolve().parent
MODEL_DIR = SCRIPT_DIR / "models"       # local model folders are discovered here
OUTPUT_DIR = SCRIPT_DIR / "generated"   # presumably where saved images go — used by save handlers (not visible here)

# ── HuggingFace repo ─────────────────────────────────────────────────────────
HF_REPO_ID = "8BitStudio/Aniimage-1"

# ── UNet config (must match training) ─────────────────────────────────────────
# Passed to diffusers.UNet2DConditionModel; sample_size is overridden in
# Generator.load_model() to match the chosen output resolution.
UNET_CONFIG = dict(
    sample_size=32,
    in_channels=4,   # SD-style 4-channel VAE latents
    out_channels=4,
    block_out_channels=(256, 512, 768, 1024),
    layers_per_block=2,
    cross_attention_dim=768,  # matches CLIP ViT-L/14 hidden size
    attention_head_dim=8,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D",
                      "CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D",
                    "CrossAttnUpBlock2D", "UpBlock2D"),
)

# Pretrained components pulled from the HuggingFace hub.
VAE_ID = "stabilityai/sd-vae-ft-mse"
CLIP_ID = "openai/clip-vit-large-patch14"

# Display names offered in the UI; mapped to diffusers schedulers in
# Generator._make_scheduler().
SCHEDULER_LIST = [
    "DPM++ 2M Karras",
    "DPM++ SDE Karras",
    "Euler a",
    "Euler",
    "DDIM",
]

# Default negative prompt pre-filled in the UI.
DEFAULT_NEGATIVE = (
    "low quality, ugly, blurry, distorted, deformed, bad anatomy, "
    "bad proportions, extra limbs, missing limbs, watermark, text, "
    "signature, washed out, flat colors, manga panel, disfigured, "
    "poorly drawn, jpeg artifacts, cropped, out of frame"
)
61
+
62
+
63
+ # ── Model discovery ───────────────────────────────────────────────────────────
64
+
65
def download_from_hf():
    """Download Aniimage-1 weights from HuggingFace into MODEL_DIR.

    Returns:
        Path to the local model directory, or None if ``huggingface_hub``
        is not installed. If the weights file already exists the download
        is skipped entirely.

    Raises:
        Whatever ``hf_hub_download`` raises if the required weights file
        cannot be fetched (the optional config.json is best-effort only).
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        print("Install huggingface_hub: pip install huggingface_hub")
        return None

    import shutil

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    aniimage_dir = MODEL_DIR / "Aniimage-1"
    weights_path = aniimage_dir / "diffusion_pytorch_model.safetensors"

    if weights_path.exists():
        print("Aniimage-1 weights already downloaded.")
        return aniimage_dir

    print(f"Downloading Aniimage-1 from {HF_REPO_ID}...")
    aniimage_dir.mkdir(parents=True, exist_ok=True)

    # Weights are required; config.json is optional (UNET_CONFIG is the
    # in-code fallback). hf_hub_download returns a path into the HF cache,
    # so copy the file next to the other local models.
    for filename, required in (("diffusion_pytorch_model.safetensors", True),
                               ("config.json", False)):
        try:
            cached = hf_hub_download(repo_id=HF_REPO_ID, filename=filename)
            shutil.copy2(cached, aniimage_dir / filename)
        except Exception as exc:
            if required:
                raise
            # Previously swallowed silently; log so a missing config is visible.
            print(f"Optional file {filename} not downloaded: {exc}")

    print("Download complete!")
    return aniimage_dir
97
+
98
+
99
def find_models(model_dir=None):
    """Find all available model folders.

    Args:
        model_dir: Directory to scan; defaults to the module-level
            MODEL_DIR. Accepts a str or Path. (Backward-compatible
            generalization — existing callers pass no argument.)

    Returns:
        A list of ``(kind, name, path, resolution)`` tuples, sorted by
        folder name, where *kind* is ``"safetensors"`` (diffusers weights)
        or ``"checkpoint"`` (``ema_unet.pt`` / ``unet.pt``) and
        *resolution* is currently always ``"256"``.
    """
    root = Path(model_dir) if model_dir is not None else MODEL_DIR
    options = []
    if root.exists():
        for d in sorted(root.iterdir()):
            if not d.is_dir():
                continue
            if (d / "diffusion_pytorch_model.safetensors").exists():
                options.append(("safetensors", d.name, d, "256"))
            elif (d / "ema_unet.pt").exists() or (d / "unet.pt").exists():
                options.append(("checkpoint", d.name, d, "256"))
    return options
113
+
114
+
115
+ # ── Theme ─────────────────────────────────────────────────────────────────────
116
+
117
C = {
    # Dark UI palette; keys are referenced throughout the tk/ttk setup.
    "bg": "#111119",        # main window background
    "panel": "#1b1b2f",     # header / left control panel background
    "card": "#24243e",      # default button background
    "card_sel": "#3a3a6e",  # active/hover button background
    "border": "#2e2e52",    # separators and disabled accents
    "accent": "#6c5ce7",    # primary accent (title, Generate button, selections)
    "accent_h": "#8577ed",  # accent hover state
    "red": "#e74c3c",       # Stop button
    "green": "#2ecc71",     # status-bar text
    "text": "#eaeaea",      # primary text
    "text2": "#a0a0b8",     # secondary labels
    "text3": "#60607a",     # dim/hint text
    "input": "#16162a",     # entry/listbox background
    "input_fg": "#dcdcf0",  # entry/listbox foreground
}
133
+
134
+
135
class Generator:
    """Text-to-image diffusion backend: custom UNet + SD VAE + CLIP encoder.

    Heavy models load lazily: ``load_shared()`` pulls the VAE and CLIP
    components from the hub, ``load_model()`` builds the UNet and loads
    local weights. ``generate()`` and ``refine()`` run classifier-free
    guided sampling; ``generate_adaptive()`` loops refinement rounds until
    a heuristic quality score passes a threshold.
    """

    def __init__(self, device="cuda"):
        # Fall back to CPU when CUDA was requested but is unavailable.
        self.device = device if device == "cuda" and torch.cuda.is_available() else "cpu"
        self.vae = None                # diffusers AutoencoderKL (lazy)
        self.text_encoder = None       # transformers CLIPTextModel (lazy)
        self.tokenizer = None          # transformers CLIPTokenizer (lazy)
        self.unet = None               # diffusers UNet2DConditionModel (lazy)
        self.scheduler = None          # current scheduler instance
        self.loaded_checkpoint = None  # str(path) of the loaded model dir, or None
        self.latent_size = 32          # latent H/W (32 -> 256px, 64 -> 512px output)
        self.output_size = 256         # decoded image edge length in pixels
        self.cancelled = False         # set True (e.g. from the UI) to abort sampling

    def switch_device(self, new_device):
        """Move all loaded models to a new device ("cuda" or "cpu")."""
        new_device = new_device if new_device == "cuda" and torch.cuda.is_available() else "cpu"
        if new_device == self.device:
            return
        self.device = new_device
        if self.vae is not None:
            self.vae = self.vae.to(self.device)
        if self.text_encoder is not None:
            self.text_encoder = self.text_encoder.to(self.device)
        if self.unet is not None:
            self.unet = self.unet.to(self.device)
        self.loaded_checkpoint = None  # force reload on next generate
        print(f"Switched to {self.device.upper()}")

    def load_shared(self):
        """Load the VAE, CLIP tokenizer/encoder and default scheduler once."""
        if self.vae is not None:
            return  # already loaded
        from diffusers import AutoencoderKL
        from transformers import CLIPTextModel, CLIPTokenizer

        print("Loading VAE...")
        self.vae = AutoencoderKL.from_pretrained(VAE_ID).to(self.device)
        self.vae.eval()

        print("Loading CLIP text encoder...")
        self.tokenizer = CLIPTokenizer.from_pretrained(CLIP_ID)
        self.text_encoder = CLIPTextModel.from_pretrained(CLIP_ID).to(self.device)
        self.text_encoder.eval()

        self.scheduler = self._make_scheduler("DPM++ 2M Karras")
        self.scheduler_name = "DPM++ 2M Karras"
        print("Shared models loaded.")

    def _make_scheduler(self, name="DPM++ 2M Karras"):
        """Build a fresh diffusers scheduler for the given display name."""
        from diffusers import (DDIMScheduler, DPMSolverMultistepScheduler,
                               EulerAncestralDiscreteScheduler,
                               EulerDiscreteScheduler)
        # Shared noise-schedule settings; these must agree with how the
        # UNet was trained (epsilon prediction, scaled-linear betas).
        base = dict(num_train_timesteps=1000, beta_schedule="scaled_linear",
                    prediction_type="epsilon")
        if name == "DPM++ 2M Karras":
            return DPMSolverMultistepScheduler(
                **base, algorithm_type="dpmsolver++",
                solver_order=2, use_karras_sigmas=True)
        elif name == "DPM++ SDE Karras":
            return DPMSolverMultistepScheduler(
                **base, algorithm_type="sde-dpmsolver++",
                use_karras_sigmas=True)
        elif name == "Euler a":
            return EulerAncestralDiscreteScheduler(**base)
        elif name == "Euler":
            return EulerDiscreteScheduler(**base)
        else:
            # Any unrecognized name falls back to DDIM.
            return DDIMScheduler(**base, clip_sample=False,
                                 set_alpha_to_one=False)

    def set_scheduler(self, name):
        """Switch the active scheduler by display name (see SCHEDULER_LIST)."""
        self.scheduler = self._make_scheduler(name)
        self.scheduler_name = name

    def load_model(self, model_path: Path, res_label: str = "256"):
        """Build the UNet and load weights from *model_path*.

        Accepts diffusers safetensors weights, an EMA checkpoint
        (``ema_unet.pt``), or a raw state dict (``unet.pt``), checked in
        that order. ``res_label`` "512" doubles the latent size; anything
        else keeps 256px output. No-op if this path is already loaded.

        Raises:
            FileNotFoundError: if no recognized weights file exists.
        """
        if str(model_path) == self.loaded_checkpoint:
            return  # already loaded
        from diffusers import UNet2DConditionModel

        self.load_shared()

        if res_label == "512":
            self.latent_size = 64
            self.output_size = 512
        else:
            self.latent_size = 32
            self.output_size = 256

        unet_cfg = dict(UNET_CONFIG)
        unet_cfg["sample_size"] = self.latent_size

        print(f"Loading UNet from {model_path.name} ({res_label}px)...")
        self.unet = UNet2DConditionModel(**unet_cfg).to(self.device)

        safetensors_path = model_path / "diffusion_pytorch_model.safetensors"
        ema_path = model_path / "ema_unet.pt"
        unet_path = model_path / "unet.pt"

        if safetensors_path.exists():
            from safetensors.torch import load_file
            state = load_file(str(safetensors_path), device=str(self.device))
            self.unet.load_state_dict(state)
            print("Loaded safetensors weights.")
        elif ema_path.exists():
            state = torch.load(ema_path, map_location=self.device, weights_only=True)
            if "shadow_params" in state:
                # EMA checkpoints store a flat list of shadow tensors;
                # assumes the list order matches named_parameters() order —
                # copied in positionally.
                params = dict(self.unet.named_parameters())
                keys = list(params.keys())
                for i, sp in enumerate(state["shadow_params"]):
                    params[keys[i]].data.copy_(sp)
            else:
                self.unet.load_state_dict(state)
            print("Loaded EMA weights.")
        elif unet_path.exists():
            self.unet.load_state_dict(
                torch.load(unet_path, map_location=self.device, weights_only=True))
            print("Loaded UNet weights.")
        else:
            raise FileNotFoundError(f"No weights found in {model_path}")

        self.unet.eval()
        self.loaded_checkpoint = str(model_path)
        print(f"Ready to generate at {self.output_size}x{self.output_size}!")

    def _decode_latents(self, latents, post_process=False):
        """Decode VAE latents (1,4,H,W) to a PIL image; optionally polish it."""
        scaled = latents / self.vae.config.scaling_factor
        with torch.no_grad():
            image = self.vae.decode(scaled.float()).sample
        # VAE output is roughly in [-1, 1]; map to [0, 1] then to uint8 RGB.
        image = (image.float() / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
        image = (image * 255).round().astype("uint8")
        img = Image.fromarray(image)
        if post_process:
            img = self._post_process(img)
        return img

    def _sharpen_latents(self, latents, amount=0.08):
        """Mild unsharp mask in latent space (original minus a 3x3 box blur)."""
        blurred = F.avg_pool2d(latents, kernel_size=3, stride=1, padding=1)
        return latents + amount * (latents - blurred)

    def _post_process(self, img):
        """Final polish: unsharp mask plus slight contrast/saturation boost."""
        img = img.filter(ImageFilter.UnsharpMask(radius=1.5, percent=40, threshold=2))
        img = ImageEnhance.Contrast(img).enhance(1.06)
        img = ImageEnhance.Color(img).enhance(1.10)
        return img

    def _image_quality_score(self, img: Image.Image) -> float:
        """Heuristic quality score in [0, 100] from sharpness + color variance."""
        arr = np.array(img.convert("L"), dtype=np.float32)
        # 4-neighbour Laplacian computed via rolls; its variance tracks sharpness.
        lap = (np.roll(arr, 1, 0) + np.roll(arr, -1, 0)
               + np.roll(arr, 1, 1) + np.roll(arr, -1, 1) - 4.0 * arr)
        sharpness = float(np.var(lap))
        arr_rgb = np.array(img, dtype=np.float32)
        color_var = float(np.mean(np.var(arr_rgb, axis=(0, 1))))
        # Weights and the /10 scale look empirical — TODO confirm tuning.
        score = (sharpness * 0.6 + color_var * 0.4)
        return min(100.0, score / 10.0)

    @torch.no_grad()
    def generate(self, prompt: str, negative_prompt: str = "",
                 steps: int = 25, guidance_scale: float = 7.5,
                 seed: int = -1, preview_callback=None,
                 preview_every: int = 5) -> tuple:
        """Sample one image with classifier-free guidance.

        Args:
            prompt / negative_prompt: text conditioning.
            steps: scheduler inference steps.
            guidance_scale: CFG weight.
            seed: RNG seed; negative picks a random one.
            preview_callback: optional fn(img, step, total) called with an
                intermediate decode every *preview_every* steps.

        Returns:
            (PIL image, seed) — image is None if self.cancelled was set.
        """
        if seed < 0:
            seed = torch.randint(0, 2**32, (1,)).item()
        gen = torch.Generator(device=self.device).manual_seed(seed)

        # Encode the positive prompt with CLIP.
        tok = self.tokenizer(prompt, padding="max_length",
                             max_length=self.tokenizer.model_max_length,
                             truncation=True, return_tensors="pt")
        text_emb = self.text_encoder(tok.input_ids.to(self.device))[0]

        # Encode the negative prompt (empty string when none given).
        tok_neg = self.tokenizer(negative_prompt if negative_prompt else "",
                                 padding="max_length",
                                 max_length=self.tokenizer.model_max_length,
                                 truncation=True, return_tensors="pt")
        neg_emb = self.text_encoder(tok_neg.input_ids.to(self.device))[0]

        # Batch [negative, positive] so one UNet pass serves both CFG branches.
        text_emb_combined = torch.cat([neg_emb, text_emb])

        # Fresh scheduler per call: schedulers are stateful across steps.
        scheduler = self._make_scheduler(self.scheduler_name)
        scheduler.set_timesteps(steps, device=self.device)

        latents = torch.randn(1, 4, self.latent_size, self.latent_size,
                              generator=gen, device=self.device)
        latents = latents * scheduler.init_noise_sigma

        timesteps = scheduler.timesteps
        total_steps = len(timesteps)

        for step_i, t in enumerate(timesteps):
            if self.cancelled:
                return None, seed

            # Duplicate latents to match the [negative, positive] batch.
            latent_input = torch.cat([latents] * 2)
            latent_input = scheduler.scale_model_input(latent_input, t)

            # bf16 autocast only when actually on CUDA.
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
                                enabled=(self.device == "cuda")):
                pred = self.unet(latent_input, t,
                                 encoder_hidden_states=text_emb_combined).sample

            # Classifier-free guidance combine.
            pred_neg, pred_text = pred.chunk(2)
            pred = pred_neg + guidance_scale * (pred_text - pred_neg)

            latents = scheduler.step(pred, t, latents).prev_sample

            # Emit an intermediate decode, skipping the first and last steps.
            if (preview_callback and step_i > 0
                    and step_i % preview_every == 0
                    and step_i < total_steps - 1):
                preview = self._decode_latents(latents, post_process=False)
                preview_callback(preview, step_i + 1, total_steps)

        latents = self._sharpen_latents(latents)
        final = self._decode_latents(latents, post_process=True)
        return final, seed

    @torch.no_grad()
    def generate_adaptive(self, prompt: str, negative_prompt: str = "",
                          base_steps: int = 25, max_steps: int = 85,
                          guidance_scale: float = 7.5,
                          quality_threshold: float = 45.0,
                          preview_callback=None, preview_every: int = 5,
                          status_callback=None) -> tuple:
        """Generate once, then refine in +20-step rounds while the heuristic
        quality score stays below *quality_threshold*.

        The round budget is ``(max_steps - base_steps) // 20``. Returns
        (PIL image, seed); returns early with the best image so far when
        cancelled or when a refine round is aborted.
        """
        result = self.generate(
            prompt=prompt, negative_prompt=negative_prompt,
            steps=base_steps, guidance_scale=guidance_scale,
            preview_callback=preview_callback, preview_every=preview_every)

        if result[0] is None:
            return result  # cancelled during the base pass

        image, seed = result
        quality = self._image_quality_score(image)

        if status_callback:
            status_callback(f"Quality: {quality:.1f}/100")

        if quality >= quality_threshold:
            return image, seed

        rounds = 0
        max_rounds = (max_steps - base_steps) // 20

        while quality < quality_threshold and rounds < max_rounds:
            if self.cancelled:
                return image, seed
            rounds += 1
            if status_callback:
                status_callback(f"Refining +20 steps (round {rounds})...")

            refined = self.refine(
                source_image=image, prompt=prompt,
                negative_prompt=negative_prompt,
                extra_steps=20, strength=0.3,
                guidance_scale=guidance_scale,
                preview_callback=preview_callback, preview_every=5)

            if refined is None:
                return image, seed  # cancelled mid-refine; keep last good image
            image = refined
            quality = self._image_quality_score(image)

            if status_callback:
                status_callback(f"Quality after round {rounds}: {quality:.1f}/100")

        return image, seed

    @torch.no_grad()
    def refine(self, source_image: Image.Image, prompt: str,
               negative_prompt: str = "", extra_steps: int = 20,
               strength: float = 0.35, guidance_scale: float = 7.5,
               preview_callback=None, preview_every: int = 5) -> Image.Image:
        """img2img pass: re-noise *source_image* and denoise the schedule tail.

        *strength* in (0, 1] controls how much of the schedule is re-run
        (and therefore how much noise is added). Returns the refined PIL
        image, or None if self.cancelled was set mid-loop.
        """
        # Encode the source image into VAE latents.
        img = source_image.resize((self.output_size, self.output_size), Image.LANCZOS)
        img_tensor = torch.from_numpy(np.array(img)).float().div(127.5).sub(1.0)
        img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(self.device)

        with torch.no_grad():
            latents = self.vae.encode(img_tensor.float()).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        # Same dual-prompt CLIP encoding as generate().
        tok = self.tokenizer(prompt, padding="max_length",
                             max_length=self.tokenizer.model_max_length,
                             truncation=True, return_tensors="pt")
        text_emb = self.text_encoder(tok.input_ids.to(self.device))[0]

        tok_neg = self.tokenizer(negative_prompt if negative_prompt else "",
                                 padding="max_length",
                                 max_length=self.tokenizer.model_max_length,
                                 truncation=True, return_tensors="pt")
        neg_emb = self.text_encoder(tok_neg.input_ids.to(self.device))[0]
        text_emb_combined = torch.cat([neg_emb, text_emb])

        scheduler = self._make_scheduler(self.scheduler_name)
        scheduler.set_timesteps(extra_steps, device=self.device)
        # Skip the earliest (noisiest) part of the schedule; higher strength
        # starts earlier and changes the image more.
        start_step = max(0, int(len(scheduler.timesteps) * (1 - strength)))
        timesteps = scheduler.timesteps[start_step:]

        # Noise the clean latents to the level of the first remaining timestep.
        noise = torch.randn_like(latents)
        latents = scheduler.add_noise(latents, noise, timesteps[:1])

        total_steps = len(timesteps)
        for step_i, t in enumerate(timesteps):
            if self.cancelled:
                return None
            latent_input = torch.cat([latents] * 2)
            latent_input = scheduler.scale_model_input(latent_input, t)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
                                enabled=(self.device == "cuda")):
                pred = self.unet(latent_input, t,
                                 encoder_hidden_states=text_emb_combined).sample
            pred_neg, pred_text = pred.chunk(2)
            pred = pred_neg + guidance_scale * (pred_text - pred_neg)
            latents = scheduler.step(pred, t, latents).prev_sample

            if (preview_callback and step_i > 0
                    and step_i % preview_every == 0
                    and step_i < total_steps - 1):
                preview = self._decode_latents(latents, post_process=False)
                preview_callback(preview, step_i + 1, total_steps)

        latents = self._sharpen_latents(latents)
        return self._decode_latents(latents, post_process=True)
458
+
459
+
460
+ # ── GUI ───────────────────────────────────────────────────────────────────────
461
+
462
class App:
    """Tkinter front-end around the Generator pipeline."""

    def __init__(self):
        self.gen = Generator()        # diffusion backend (lazy model loading)
        self.models = find_models()   # (kind, name, path, resolution) tuples
        self.generated_images = []    # PIL images from completed generations
        self.generated_seeds = []     # seed used for each generated image
        self.photo_refs = []          # keep PhotoImage refs alive (tk GC quirk)
        self.generating = False       # True while a generation worker runs
        self.selected_index = None    # index of the currently selected tile

        self.root = tk.Tk()
        self.root.title("Aniimage")
        self.root.configure(bg=C["bg"])
        self.root.resizable(True, True)
        self.root.geometry("900x780")
        self.root.minsize(640, 500)

        self._setup_styles()
        self._build_ui()
481
+
482
    def _setup_styles(self):
        """Configure ttk styles and combobox dropdown colors for the dark theme."""
        s = ttk.Style()
        s.theme_use("clam")

        # Base
        s.configure(".", background=C["bg"], foreground=C["text"], font=("Segoe UI", 10))
        s.configure("TFrame", background=C["bg"])
        s.configure("TLabel", background=C["bg"], foreground=C["text"])
        s.configure("TCheckbutton", background=C["bg"], foreground=C["text"])

        # Combobox — readable text
        s.configure("TCombobox", fieldbackground=C["input"], foreground=C["input_fg"],
                    selectbackground=C["accent"], selectforeground="#ffffff",
                    arrowcolor=C["text2"], padding=4)
        # The "readonly" state needs explicit maps or clam reverts to defaults.
        s.map("TCombobox",
              fieldbackground=[("readonly", C["input"])],
              foreground=[("readonly", C["input_fg"])],
              selectbackground=[("readonly", C["accent"])],
              selectforeground=[("readonly", "#ffffff")])
        # Combobox dropdown list colors (the popdown is a plain tk Listbox,
        # so it is themed via option_add rather than ttk styles).
        self.root.option_add("*TCombobox*Listbox.background", C["input"])
        self.root.option_add("*TCombobox*Listbox.foreground", C["input_fg"])
        self.root.option_add("*TCombobox*Listbox.selectBackground", C["accent"])
        self.root.option_add("*TCombobox*Listbox.selectForeground", "#ffffff")
        self.root.option_add("*TCombobox*Listbox.font", ("Segoe UI", 10))

        # Spinbox
        s.configure("TSpinbox", fieldbackground=C["input"], foreground=C["input_fg"],
                    arrowcolor=C["text2"], padding=3)

        # Buttons
        s.configure("TButton", font=("Segoe UI", 10), padding=(14, 7),
                    background=C["card"], foreground=C["text"])
        s.map("TButton", background=[("active", C["card_sel"]), ("disabled", C["bg"])],
              foreground=[("disabled", C["text3"])])

        # Primary action button (Generate / Run Queue).
        s.configure("Go.TButton", font=("Segoe UI", 11, "bold"), padding=(20, 9),
                    background=C["accent"], foreground="#ffffff")
        s.map("Go.TButton", background=[("active", C["accent_h"]),
                                        ("disabled", C["border"])])

        # Stop button.
        s.configure("Stop.TButton", font=("Segoe UI", 10, "bold"), padding=(14, 7),
                    background=C["red"], foreground="#ffffff")
        s.map("Stop.TButton", background=[("active", "#c0392b"),
                                          ("disabled", C["border"])])

        # Labelframe
        s.configure("TLabelframe", background=C["bg"], foreground=C["text2"])
        s.configure("TLabelframe.Label", background=C["bg"],
                    foreground=C["text2"], font=("Segoe UI", 9, "bold"))

        # Scrollbar
        s.configure("Vertical.TScrollbar", background=C["card"],
                    troughcolor=C["bg"], arrowcolor=C["text3"])
536
+
537
+ def _make_entry(self, parent, font_size=11, dim=False):
538
+ """Create a styled tk.Entry with readable text."""
539
+ return tk.Entry(parent, font=("Segoe UI", font_size),
540
+ bg=C["input"], fg=C["input_fg"] if not dim else C["text2"],
541
+ insertbackground=C["input_fg"],
542
+ relief="flat", bd=6,
543
+ selectbackground=C["accent"], selectforeground="#ffffff",
544
+ highlightthickness=1, highlightcolor=C["accent"],
545
+ highlightbackground=C["border"])
546
+
547
    def _build_ui(self):
        """Lay out the header bar and the two-column main area."""
        # ── Header ────────────────────────────────────────────────────────
        header = tk.Frame(self.root, bg=C["panel"], padx=20, pady=12)
        header.pack(fill=tk.X)

        tk.Label(header, text="Aniimage", bg=C["panel"], fg=C["accent"],
                 font=("Segoe UI", 20, "bold")).pack(side=tk.LEFT)
        tk.Label(header, text="by 8BitStudio", bg=C["panel"], fg=C["text3"],
                 font=("Segoe UI", 10)).pack(side=tk.LEFT, padx=(10, 0), pady=(6, 0))

        # Device switch — right side of header
        device_frame = tk.Frame(header, bg=C["panel"])
        device_frame.pack(side=tk.RIGHT)

        tk.Label(device_frame, text="Device:", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9)).pack(side=tk.LEFT, padx=(0, 5))

        # Only offer "GPU" when CUDA is actually available.
        self.device_var = tk.StringVar(value="GPU" if self.gen.device == "cuda" else "CPU")
        devices = ["GPU", "CPU"] if torch.cuda.is_available() else ["CPU"]
        device_combo = ttk.Combobox(device_frame, textvariable=self.device_var,
                                    values=devices, state="readonly", width=5)
        device_combo.pack(side=tk.LEFT)
        device_combo.bind("<<ComboboxSelected>>", self._on_device_change)

        # ── Main content — two-column: controls left, images right ────────
        main = tk.Frame(self.root, bg=C["bg"])
        main.pack(fill=tk.BOTH, expand=True, padx=12, pady=(8, 12))

        # Left panel (controls); pack_propagate(False) keeps the fixed width.
        left = tk.Frame(main, bg=C["panel"], width=340, padx=16, pady=12)
        left.pack(side=tk.LEFT, fill=tk.Y, padx=(0, 8))
        left.pack_propagate(False)

        # Right panel (image grid)
        right = tk.Frame(main, bg=C["bg"])
        right.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        self._build_controls(left)
        self._build_grid(right)
586
+
587
    def _build_controls(self, parent):
        """Build the left control panel: model picker, prompts, settings,
        action buttons, prompt queue and status bar."""
        # ── Model ─────────────────────────────────────────────────────────
        tk.Label(parent, text="Model", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9, "bold")).pack(anchor=tk.W)

        self.model_var = tk.StringVar()
        model_names = [m[1] for m in self.models] or ["No models found"]
        self.model_combo = ttk.Combobox(parent, textvariable=self.model_var,
                                        values=model_names, state="readonly", width=32)
        self.model_combo.pack(fill=tk.X, pady=(3, 12))
        # Pre-select the last entry (folders are sorted, so the newest name).
        self.model_combo.current(len(model_names) - 1)

        # ── Prompt ────────────────────────────────────────────────────────
        tk.Label(parent, text="Prompt", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9, "bold")).pack(anchor=tk.W)
        self.prompt_entry = self._make_entry(parent)
        self.prompt_entry.pack(fill=tk.X, pady=(3, 8))
        self.prompt_entry.insert(0, "a smiling anime girl with long blue hair")
        self.prompt_entry.bind("<Return>", lambda e: self.on_generate())

        # ── Negative prompt ───────────────────────────────────────────────
        tk.Label(parent, text="Negative prompt", bg=C["panel"], fg=C["text3"],
                 font=("Segoe UI", 9)).pack(anchor=tk.W)
        self.neg_entry = self._make_entry(parent, font_size=9, dim=True)
        self.neg_entry.pack(fill=tk.X, pady=(3, 12))
        self.neg_entry.insert(0, DEFAULT_NEGATIVE)

        # ── Settings grid ─────────────────────────────────────────────────
        grid = tk.Frame(parent, bg=C["panel"])
        grid.pack(fill=tk.X, pady=(0, 8))

        # Row 1: Scheduler
        tk.Label(grid, text="Scheduler", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9)).grid(row=0, column=0, sticky="w", pady=(0, 6))
        self.scheduler_var = tk.StringVar(value="DPM++ 2M Karras")
        sched_combo = ttk.Combobox(grid, textvariable=self.scheduler_var,
                                   values=SCHEDULER_LIST, state="readonly", width=18)
        sched_combo.grid(row=0, column=1, columnspan=3, sticky="ew", padx=(8, 0), pady=(0, 6))
        sched_combo.bind("<<ComboboxSelected>>", self._on_scheduler_change)

        # Row 2: Steps, CFG, Count
        tk.Label(grid, text="Steps", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9)).grid(row=1, column=0, sticky="w", pady=(0, 6))
        self.steps_var = tk.StringVar(value="25")
        tk.Entry(grid, textvariable=self.steps_var, width=5, font=("Segoe UI", 10),
                 bg=C["input"], fg=C["input_fg"], insertbackground=C["input_fg"],
                 relief="flat", bd=4).grid(row=1, column=1, sticky="w", padx=(8, 12), pady=(0, 6))

        tk.Label(grid, text="CFG", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9)).grid(row=1, column=2, sticky="w", pady=(0, 6))
        self.cfg_var = tk.StringVar(value="7.5")
        tk.Entry(grid, textvariable=self.cfg_var, width=5, font=("Segoe UI", 10),
                 bg=C["input"], fg=C["input_fg"], insertbackground=C["input_fg"],
                 relief="flat", bd=4).grid(row=1, column=3, sticky="w", padx=(8, 0), pady=(0, 6))

        # Row 3: Count, Live preview
        tk.Label(grid, text="Count", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9)).grid(row=2, column=0, sticky="w", pady=(0, 6))
        self.count_var = tk.StringVar(value="4")
        ttk.Spinbox(grid, from_=1, to=12, textvariable=self.count_var, width=4,
                    font=("Segoe UI", 10)).grid(row=2, column=1, sticky="w", padx=(8, 12), pady=(0, 6))

        self.live_preview_var = tk.BooleanVar(value=False)
        ttk.Checkbutton(grid, text="Live preview",
                        variable=self.live_preview_var).grid(
            row=2, column=2, columnspan=2, sticky="w", pady=(0, 6))

        grid.columnconfigure(1, weight=1)
        grid.columnconfigure(3, weight=1)

        # ── Auto quality ──────────────────────────────────────────────────
        # Toggles Generator.generate_adaptive's refine-until-threshold loop.
        self.auto_quality_var = tk.BooleanVar(value=False)
        ttk.Checkbutton(parent, text="Auto quality (refine if undercooked)",
                        variable=self.auto_quality_var).pack(anchor=tk.W, pady=(0, 12))

        # ── Buttons ───────────────────────────────────────────────────────
        btn_frame = tk.Frame(parent, bg=C["panel"])
        btn_frame.pack(fill=tk.X, pady=(0, 10))

        self.gen_btn = ttk.Button(btn_frame, text="Generate", command=self.on_generate,
                                  style="Go.TButton")
        self.gen_btn.pack(fill=tk.X, pady=(0, 5))

        btn_row = tk.Frame(btn_frame, bg=C["panel"])
        btn_row.pack(fill=tk.X)

        self.stop_btn = ttk.Button(btn_row, text="Stop", command=self.on_stop,
                                   state=tk.DISABLED, style="Stop.TButton")
        self.stop_btn.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 3))

        self.save_btn = ttk.Button(btn_row, text="Save Selected", command=self.on_save,
                                   state=tk.DISABLED)
        self.save_btn.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(3, 3))

        self.save_all_btn = ttk.Button(btn_row, text="Save All", command=self.on_save_all,
                                       state=tk.DISABLED)
        self.save_all_btn.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(3, 0))

        # ── Prompt queue ─────────────────────────────────────────────────
        sep = tk.Frame(parent, height=1, bg=C["border"])
        sep.pack(fill=tk.X, pady=(8, 10))

        tk.Label(parent, text="Prompt Queue", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9, "bold")).pack(anchor=tk.W)

        queue_input = tk.Frame(parent, bg=C["panel"])
        queue_input.pack(fill=tk.X, pady=(4, 0))

        self.queue_entry = self._make_entry(queue_input, font_size=9)
        self.queue_entry.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 4))
        self.queue_entry.bind("<Return>", lambda e: self._queue_add())

        ttk.Button(queue_input, text="Add", width=4,
                   command=self._queue_add).pack(side=tk.LEFT)

        self.queue_listbox = tk.Listbox(
            parent, height=4, bg=C["input"], fg=C["input_fg"],
            selectbackground=C["accent"], selectforeground="#fff",
            font=("Segoe UI", 9), activestyle="none",
            relief="flat", bd=4, highlightthickness=0)
        self.queue_listbox.pack(fill=tk.X, pady=(5, 0))

        queue_btns = tk.Frame(parent, bg=C["panel"])
        queue_btns.pack(fill=tk.X, pady=(4, 0))

        self.queue_run_btn = ttk.Button(queue_btns, text="Run Queue",
                                        command=self.on_run_queue, style="Go.TButton")
        self.queue_run_btn.pack(side=tk.LEFT, padx=(0, 4))

        # Queue manipulation buttons share one small factory loop.
        for txt, cmd in [("Remove", self._queue_remove), ("Clear", self._queue_clear),
                         ("Up", self._queue_move_up), ("Down", self._queue_move_down),
                         ("+ Current", self._queue_add_current)]:
            ttk.Button(queue_btns, text=txt, command=cmd).pack(side=tk.LEFT, padx=2)

        # ── Status bar ────────────────────────────────────────────────────
        status_frame = tk.Frame(parent, bg=C["bg"], padx=8, pady=6)
        status_frame.pack(fill=tk.X, side=tk.BOTTOM)

        self.status_var = tk.StringVar(value="Ready")
        tk.Label(status_frame, textvariable=self.status_var,
                 bg=C["bg"], fg=C["green"], font=("Segoe UI", 9),
                 anchor="w").pack(fill=tk.X)
729
+
730
    def _build_grid(self, parent):
        """Build the scrollable canvas that hosts the generated-image grid."""
        self.canvas = tk.Canvas(parent, bg=C["bg"], highlightthickness=0)
        scrollbar = ttk.Scrollbar(parent, orient=tk.VERTICAL, command=self.canvas.yview)
        self.grid_frame = tk.Frame(self.canvas, bg=C["bg"])

        # Keep the scrollregion in sync with the inner frame's size.
        self.grid_frame.bind("<Configure>",
                             lambda e: self.canvas.configure(
                                 scrollregion=self.canvas.bbox("all")))
        self.canvas_window = self.canvas.create_window((0, 0), window=self.grid_frame,
                                                       anchor="nw")
        self.canvas.configure(yscrollcommand=scrollbar.set)

        self.canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)

        self.canvas.bind("<Configure>", self._on_canvas_resize)
        # NOTE(review): delta/120 is the Windows wheel convention — confirm
        # behavior on X11/macOS where <MouseWheel> events differ.
        self.canvas.bind_all("<MouseWheel>",
                             lambda e: self.canvas.yview_scroll(
                                 int(-1 * (e.delta / 120)), "units"))

        # Shown until the first generation replaces it.
        self.placeholder = tk.Label(self.grid_frame,
                                    text="Generated images\nwill appear here",
                                    bg=C["bg"], fg=C["text3"],
                                    font=("Segoe UI", 13), justify="center")
        self.placeholder.grid(row=0, column=0, pady=80)
755
+
756
+ # ── Event handlers ────────────────────────────────────────────────────
757
+
758
def _on_device_change(self, event=None):
    """Combobox handler: move the generation pipeline between CPU and GPU."""
    selection = self.device_var.get()
    if selection == "GPU":
        target = "cuda"
    else:
        target = "cpu"
    self.status_var.set(f"Switching to {selection}...")
    self.root.update()
    self.gen.switch_device(target)
    self.status_var.set(f"Now using {selection}")
765
+
766
+ def _on_scheduler_change(self, event=None):
767
+ name = self.scheduler_var.get()
768
+ self.gen.set_scheduler(name)
769
+ self.status_var.set(f"Scheduler: {name}")
770
+
771
def _on_canvas_resize(self, event):
    """Stretch the embedded grid frame to the new canvas width and re-flow."""
    self.canvas.itemconfig(self.canvas_window, width=event.width)
    if not self.generated_images:
        return
    self._layout_grid()
775
+
776
+ def _get_grid_cols(self):
777
+ canvas_w = self.canvas.winfo_width()
778
+ if canvas_w < 50:
779
+ canvas_w = 560
780
+ tile_size = self._get_tile_size()
781
+ return max(1, canvas_w // (tile_size + 16))
782
+
783
+ def _get_tile_size(self):
784
+ n = len(self.generated_images)
785
+ if n <= 2: return 260
786
+ elif n <= 4: return 220
787
+ elif n <= 6: return 180
788
+ else: return 160
789
+
790
def _layout_grid(self):
    """Rebuild the thumbnail grid from the current generated_images list."""
    # Drop the old cards and the PhotoImage references that kept them alive.
    for w in self.grid_frame.winfo_children():
        w.destroy()
    self.photo_refs.clear()

    if not self.generated_images:
        return

    tile_size = self._get_tile_size()
    cols = self._get_grid_cols()

    for i, (img, seed) in enumerate(zip(self.generated_images, self.generated_seeds)):
        row, col = divmod(i, cols)
        is_selected = (i == self.selected_index)

        # The selected card gets an accent-coloured border frame.
        card_bg = C["accent"] if is_selected else C["card"]
        card = tk.Frame(self.grid_frame, bg=card_bg, padx=3, pady=3)
        card.grid(row=row, column=col, padx=5, pady=5, sticky="nsew")

        display = img.resize((tile_size, tile_size), Image.LANCZOS)
        photo = ImageTk.PhotoImage(display)
        # Tk does not hold a reference to PhotoImages; keep one ourselves
        # or the thumbnail is garbage-collected and renders blank.
        self.photo_refs.append(photo)

        img_label = tk.Label(card, image=photo, bg=card_bg, bd=0)
        img_label.pack()
        # idx=i binds the current index as a default argument so each
        # lambda captures its own value, not the loop variable's last one.
        img_label.bind("<Button-1>", lambda e, idx=i: self._select_image(idx))
        img_label.bind("<Button-3>", lambda e, idx=i: self._show_refine_menu(e, idx))

        tk.Label(card, text=f"seed: {seed}", bg=card_bg,
                 fg=C["text3"], font=("Segoe UI", 8)).pack()

    # Equal-weight columns so cards spread across the full width.
    for c in range(cols):
        self.grid_frame.columnconfigure(c, weight=1)
823
+
824
def _select_image(self, idx):
    """Mark thumbnail *idx* as the current selection and redraw the grid."""
    if idx >= len(self.generated_images):
        return  # stale index, e.g. the grid was rebuilt under a click
    self.selected_index = idx
    self.status_var.set(f"Selected image {idx + 1} (seed: {self.generated_seeds[idx]})")
    self.save_btn.configure(state=tk.NORMAL)
    self._layout_grid()
831
+
832
def _show_refine_menu(self, event, idx):
    """Pop up the right-click context menu for thumbnail *idx*.

    Ignored while a generation is running so a refine cannot start on
    top of an in-flight job.
    """
    if self.generating:
        return
    menu = tk.Menu(self.root, tearoff=0, bg=C["card"], fg=C["text"],
                   activebackground=C["accent"], activeforeground="#fff",
                   font=("Segoe UI", 10), bd=0)
    menu.add_command(label=" Refine (more steps)... ",
                     command=lambda: self._ask_refine(idx))
    try:
        menu.tk_popup(event.x_root, event.y_root)
    finally:
        # tk_popup grabs all input; release it explicitly (recommended
        # tkinter idiom) so the UI cannot end up stuck holding the grab.
        menu.grab_release()
841
+
842
def _ask_refine(self, idx):
    """Prompt for an extra step count, then start a refine worker thread."""
    extra = simpledialog.askinteger(
        "Refine Image", "Extra denoising steps:",
        initialvalue=20, minvalue=5, maxvalue=200, parent=self.root)
    if extra is None:
        return  # user cancelled the dialog
    self._select_image(idx)
    # Flip the UI into "busy" state before handing off to the worker;
    # _refine_thread's finally-block restores it.
    self.generating = True
    self.gen.cancelled = False
    self.gen_btn.configure(state=tk.DISABLED)
    self.stop_btn.configure(state=tk.NORMAL)
    self.status_var.set(f"Refining image {idx + 1}...")
    self.root.update()
    Thread(target=self._refine_thread, args=(idx, extra), daemon=True).start()
856
+
857
def _refine_thread(self, idx, extra_steps):
    """Worker thread: run extra denoising steps on generated image *idx*.

    NOTE(review): this thread touches Tk widgets (status_var, buttons,
    _layout_grid) directly; Tkinter makes no thread-safety guarantee, so
    consider marshalling UI updates through root.after() — confirm.
    """
    try:
        source = self.generated_images[idx]
        prompt = self.prompt_entry.get().strip()
        neg = self.neg_entry.get().strip()
        cfg = float(self.cfg_var.get())
        callback = self._show_preview if self.live_preview_var.get() else None

        refined = self.gen.refine(
            source_image=source, prompt=prompt, negative_prompt=neg,
            extra_steps=extra_steps, guidance_scale=cfg,
            preview_callback=callback, preview_every=5)

        # refine() returning None is treated as "user pressed Stop".
        if refined is not None:
            self.generated_images[idx] = refined
            # Tag the seed label so the card shows it was refined.
            self.generated_seeds[idx] = f"{self.generated_seeds[idx]}+R{extra_steps}"
            self._layout_grid()
            self.status_var.set(f"Refined image {idx + 1}")
        else:
            self.status_var.set("Refine stopped.")
        self.root.update()
    except Exception as e:
        self.status_var.set(f"Refine error: {e}")
        import traceback; traceback.print_exc()
    finally:
        # Always restore the idle UI state, even after an error.
        self.generating = False
        self.gen.cancelled = False
        self.gen_btn.configure(state=tk.NORMAL)
        self.stop_btn.configure(state=tk.DISABLED)
886
+
887
+ # ── Queue ─────────────────────────────────────────────────────────────
888
+
889
+ def _queue_add(self):
890
+ text = self.queue_entry.get().strip()
891
+ if text:
892
+ self.queue_listbox.insert(tk.END, text)
893
+ self.queue_entry.delete(0, tk.END)
894
+
895
def _queue_add_current(self):
    """Copy the main prompt field into the queue without clearing it."""
    current = self.prompt_entry.get().strip()
    if not current:
        return
    self.queue_listbox.insert(tk.END, current)
899
+
900
def _queue_remove(self):
    """Delete the currently selected queue row, if any."""
    selection = self.queue_listbox.curselection()
    if not selection:
        return
    self.queue_listbox.delete(selection[0])
904
+
905
def _queue_clear(self):
    """Remove every prompt from the queue listbox."""
    self.queue_listbox.delete(0, tk.END)
907
+
908
+ def _queue_move_up(self):
909
+ sel = self.queue_listbox.curselection()
910
+ if sel and sel[0] > 0:
911
+ idx = sel[0]
912
+ text = self.queue_listbox.get(idx)
913
+ self.queue_listbox.delete(idx)
914
+ self.queue_listbox.insert(idx - 1, text)
915
+ self.queue_listbox.selection_set(idx - 1)
916
+
917
+ def _queue_move_down(self):
918
+ sel = self.queue_listbox.curselection()
919
+ if sel and sel[0] < self.queue_listbox.size() - 1:
920
+ idx = sel[0]
921
+ text = self.queue_listbox.get(idx)
922
+ self.queue_listbox.delete(idx)
923
+ self.queue_listbox.insert(idx + 1, text)
924
+ self.queue_listbox.selection_set(idx + 1)
925
+
926
def on_run_queue(self):
    """Start a background worker that generates images for every queued prompt."""
    # Ignore clicks while a run is active or before any model was found.
    if self.generating or not self.models:
        return
    prompts = list(self.queue_listbox.get(0, tk.END))
    if not prompts:
        self.status_var.set("Queue is empty")
        return
    # Flip the UI into "busy" state before handing off to the worker;
    # _queue_thread's finally-block restores it.
    self.generating = True
    self.gen.cancelled = False
    self.gen_btn.configure(state=tk.DISABLED)
    self.queue_run_btn.configure(state=tk.DISABLED)
    self.stop_btn.configure(state=tk.NORMAL)
    Thread(target=self._queue_thread, args=(prompts,), daemon=True).start()
939
+
940
def _queue_thread(self, prompts):
    """Worker thread: generate images for each queued prompt, auto-saving each.

    NOTE(review): like the other workers, this updates Tk widgets directly
    from a non-main thread; Tkinter makes no thread-safety guarantee —
    consider marshalling UI updates through root.after().
    """
    try:
        # Load the model chosen in the combobox.  NOTE(review): models
        # entries appear to be tuples where [1] is a display name and
        # [2]/[3] are load_model() arguments — confirm against find_models().
        idx = self.model_combo.current()
        mdl = self.models[idx]
        self.status_var.set(f"Loading {mdl[1]}...")
        self.root.update()
        self.gen.load_model(mdl[2], mdl[3])

        # Snapshot all generation settings once, up front.
        neg = self.neg_entry.get().strip()
        steps = int(self.steps_var.get())
        cfg = float(self.cfg_var.get())
        num_images = max(1, min(12, int(self.count_var.get())))  # clamp to 1..12
        live_preview = self.live_preview_var.get()
        auto_quality = self.auto_quality_var.get()

        # A queue run replaces any previous results in the grid.
        self.generated_images.clear()
        self.generated_seeds.clear()
        self.selected_index = None
        if self.placeholder:
            self.placeholder.destroy()
            self.placeholder = None

        for p_idx, prompt in enumerate(prompts):
            if self.gen.cancelled:
                break
            # Highlight the queue row currently being processed.
            self.queue_listbox.selection_clear(0, tk.END)
            self.queue_listbox.selection_set(p_idx)
            self.queue_listbox.see(p_idx)

            for img_i in range(num_images):
                if self.gen.cancelled:
                    break
                self.status_var.set(
                    f"[{p_idx + 1}/{len(prompts)}] image {img_i + 1}/{num_images}")
                self.root.update()

                callback = None
                if live_preview:
                    self._setup_preview_card()
                    callback = self._show_preview

                if auto_quality:
                    # Adaptive mode may raise the step count up to +60.
                    image, used_seed = self.gen.generate_adaptive(
                        prompt=prompt, negative_prompt=neg,
                        base_steps=steps, max_steps=steps + 60,
                        guidance_scale=cfg,
                        preview_callback=callback, preview_every=5,
                        status_callback=lambda m: (
                            self.status_var.set(m), self.root.update()))
                else:
                    image, used_seed = self.gen.generate(
                        prompt=prompt, negative_prompt=neg,
                        steps=steps, guidance_scale=cfg,
                        preview_callback=callback, preview_every=5)

                # A None image means the run was cancelled mid-generation.
                if image is None:
                    break
                self.generated_images.append(image)
                self.generated_seeds.append(used_seed)
                # Unlike single runs, queue runs save each image immediately.
                save_path = self._next_save_path(prompt)
                image.save(save_path)
                self._layout_grid()
                self.root.update()

            if self.gen.cancelled:
                break

        done = len(self.generated_images)
        self.status_var.set(
            f"Queue {'stopped' if self.gen.cancelled else 'done'}! {done} images saved.")
        if done > 0:
            self.save_all_btn.configure(state=tk.NORMAL)

    except Exception as e:
        self.status_var.set(f"Queue error: {e}")
        import traceback; traceback.print_exc()
    finally:
        # Restore the idle UI state no matter how the run ended.
        self.generating = False
        self.gen.cancelled = False
        self.gen_btn.configure(state=tk.NORMAL)
        self.queue_run_btn.configure(state=tk.NORMAL)
        self.stop_btn.configure(state=tk.DISABLED)
1022
+
1023
+ # ── Generation ────────────────────────────────────────────────────────
1024
+
1025
def on_stop(self):
    """Request cancellation; the worker checks gen.cancelled between steps."""
    if not self.generating:
        return
    self.gen.cancelled = True
    self.status_var.set("Stopping...")
    self.root.update()
1030
+
1031
def on_generate(self):
    """Begin a generation run on a background worker thread."""
    if self.generating:
        return
    if not self.models:
        return
    self.generating = True
    self.gen.cancelled = False
    self.stop_btn.configure(state=tk.NORMAL)
    self.gen_btn.configure(state=tk.DISABLED)
    self.status_var.set("Loading model...")
    self.root.update()
    worker = Thread(target=self._generate_thread, daemon=True)
    worker.start()
1041
+
1042
def _setup_preview_card(self):
    """Insert an empty card in the next grid slot for the live preview.

    Fix: tk.Label's ``width``/``height`` options are measured in *text*
    units (characters/lines) when the label holds no image, so the
    original ``width=tile_size`` produced a card hundreds of characters
    wide.  A blank tk.PhotoImage of tile_size pixels sizes the label
    correctly and _show_preview() then simply swaps in each new frame.
    """
    tile_size = self._get_tile_size()
    cols = self._get_grid_cols()
    row, col = divmod(len(self.generated_images), cols)
    card = tk.Frame(self.grid_frame, bg=C["card"], padx=3, pady=3)
    card.grid(row=row, column=col, padx=5, pady=5, sticky="nsew")
    # Keep a reference on self so Tk doesn't garbage-collect the image.
    self._preview_photo = tk.PhotoImage(width=tile_size, height=tile_size)
    self._preview_label = tk.Label(card, image=self._preview_photo,
                                   bg=C["card"])
    self._preview_label.pack()
    self.root.update()
1052
+
1053
def _show_preview(self, preview_img, step, total):
    """Progress callback: paint the latest denoising preview into the card.

    NOTE(review): invoked from the worker thread; updating Tk widgets off
    the main thread is not officially supported — consider root.after().
    """
    tile_size = self._get_tile_size()
    display = preview_img.resize((tile_size, tile_size), Image.LANCZOS)
    photo = ImageTk.PhotoImage(display)
    # Keep a reference so the PhotoImage isn't garbage-collected mid-display.
    self._preview_photo = photo
    # The preview card may have been destroyed by a grid re-layout.
    if hasattr(self, '_preview_label') and self._preview_label.winfo_exists():
        self._preview_label.configure(image=photo)
    self.status_var.set(f"Step {step}/{total}")
    self.root.update()
1062
+
1063
def _generate_thread(self):
    """Worker thread: generate the requested number of images for the prompt.

    NOTE(review): updates Tk widgets directly from a non-main thread;
    Tkinter makes no thread-safety guarantee — consider root.after().
    """
    try:
        # Load the model chosen in the combobox.  NOTE(review): models
        # entries appear to be tuples where [1] is a display name and
        # [2]/[3] are load_model() arguments — confirm against find_models().
        idx = self.model_combo.current()
        mdl = self.models[idx]
        self.status_var.set(f"Loading {mdl[1]}...")
        self.root.update()
        self.gen.load_model(mdl[2], mdl[3])

        # Snapshot all generation settings once, up front.
        prompt = self.prompt_entry.get().strip()
        neg = self.neg_entry.get().strip()
        steps = int(self.steps_var.get())
        cfg = float(self.cfg_var.get())
        num_images = max(1, min(12, int(self.count_var.get())))  # clamp to 1..12
        live_preview = self.live_preview_var.get()
        auto_quality = self.auto_quality_var.get()

        # A new run replaces any previous results in the grid.
        self.generated_images.clear()
        self.generated_seeds.clear()
        self.selected_index = None
        if self.placeholder:
            self.placeholder.destroy()
            self.placeholder = None

        for i in range(num_images):
            if self.gen.cancelled:
                break
            self.status_var.set(f"Generating {i + 1}/{num_images}...")
            self.root.update()

            callback = None
            if live_preview:
                self._setup_preview_card()
                callback = self._show_preview

            if auto_quality:
                # Adaptive mode may raise the step count up to +60.
                image, used_seed = self.gen.generate_adaptive(
                    prompt=prompt, negative_prompt=neg,
                    base_steps=steps, max_steps=steps + 60,
                    guidance_scale=cfg,
                    preview_callback=callback, preview_every=5,
                    status_callback=lambda m: (
                        self.status_var.set(m), self.root.update()))
            else:
                image, used_seed = self.gen.generate(
                    prompt=prompt, negative_prompt=neg,
                    steps=steps, guidance_scale=cfg,
                    preview_callback=callback, preview_every=5)

            # A None image means the run was cancelled mid-generation.
            if image is None:
                break
            self.generated_images.append(image)
            self.generated_seeds.append(used_seed)
            self._layout_grid()
            self.root.update()

        done = len(self.generated_images)
        if self.gen.cancelled:
            self.status_var.set(f"Stopped. {done} image(s) kept.")
        else:
            self.status_var.set(f"Done! {done} images. Click to select.")
        if done > 0:
            self.save_all_btn.configure(state=tk.NORMAL)
        # Nothing is selected yet, so the per-image Save stays disabled.
        self.save_btn.configure(state=tk.DISABLED)

    except Exception as e:
        self.status_var.set(f"Error: {e}")
        import traceback; traceback.print_exc()
    finally:
        # Restore the idle UI state no matter how the run ended.
        self.generating = False
        self.gen.cancelled = False
        self.gen_btn.configure(state=tk.NORMAL)
        self.stop_btn.configure(state=tk.DISABLED)
1135
+
1136
+ # ── Save ──────────────────────────────────────────────────────────────
1137
+
1138
def _next_save_path(self, prompt_text):
    """Return a unique .png path under OUTPUT_DIR derived from *prompt_text*.

    Fix: the prompt used to be used verbatim as a filename, so prompts
    containing characters illegal on common filesystems (``/ \\ : * ? " < > |``,
    control characters) or trailing dots/spaces (invalid on Windows) made
    image.save() fail.  Those characters are now replaced with "_" and the
    result trimmed.  As before, the prompt is truncated to 50 characters
    and " 1", " 2", ... is appended until an unused name is found.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    raw = prompt_text.strip()[:50] or "untitled"
    bad = set('<>:"/\\|?*')
    slug = "".join("_" if (ch in bad or ord(ch) < 32) else ch for ch in raw)
    slug = slug.rstrip(" .") or "untitled"
    base = OUTPUT_DIR / f"{slug}.png"
    if not base.exists():
        return base
    n = 1
    while True:
        path = OUTPUT_DIR / f"{slug} {n}.png"
        if not path.exists():
            return path
        n += 1
1150
+
1151
def on_save(self):
    """Save the currently selected image to the output folder."""
    if self.selected_index is None:
        return
    if not self.generated_images:
        return
    selected = self.generated_images[self.selected_index]
    target = self._next_save_path(self.prompt_entry.get().strip())
    selected.save(target)
    self.status_var.set(f"Saved: {target.name}")
1158
+
1159
def on_save_all(self):
    """Write every generated image to its own uniquely-named file."""
    if not self.generated_images:
        return
    prompt_text = self.prompt_entry.get().strip()
    for image in self.generated_images:
        image.save(self._next_save_path(prompt_text))
    self.status_var.set(f"Saved {len(self.generated_images)} images")
1167
+
1168
def run(self):
    """Enter the Tk main loop; blocks until the window is closed."""
    self.root.mainloop()
1170
+
1171
+
1172
+ # ── Entry point ───────────────────────────────────────────────────────────────
1173
+
1174
if __name__ == "__main__":
    # Look for model weights locally first; fall back to a one-time
    # HuggingFace download, then re-scan the model directory.
    models = find_models()
    if not models:
        print("No models found locally. Downloading from HuggingFace...")
        result = download_from_hf()
        if result:
            models = find_models()

    # Still nothing: tell the user where weights belong and bail out.
    if not models:
        print("No models found!")
        print(f"Place model weights in: {MODEL_DIR}/YourModelName/")
        print("Expected files: diffusion_pytorch_model.safetensors or ema_unet.pt")
        sys.exit(1)

    # NOTE(review): m[1] appears to be each model's display name —
    # confirm against find_models().
    print(f"Found {len(models)} model(s): {', '.join(m[1] for m in models)}")
    print(f"Device: {'CUDA (GPU)' if torch.cuda.is_available() else 'CPU'}")
    print("Starting Aniimage...")

    app = App()
    app.run()