# SyncAI — src/image_generator_hf.py
"""Generate images using SDXL + Hyper-SD 8-step + style LoRA from registry.
Reads segments.json (with prompts from prompt_generator) and generates
one 768x1344 (9:16 vertical) image per segment.
Pipeline: SDXL base → Hyper-SD 8-step CFG LoRA (speed) → style LoRA (aesthetics)
"""
import json
from pathlib import Path
from typing import Optional
import torch
from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline
from huggingface_hub import hf_hub_download
from src.styles import get_style
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"  # SDXL base checkpoint
VAE_MODEL = "madebyollin/sdxl-vae-fp16-fix"  # VAE patched to avoid fp16 NaN artifacts
HYPER_SD_REPO = "ByteDance/Hyper-SD"  # Hub repo hosting the Hyper-SD LoRAs
HYPER_SD_FILE = "Hyper-SDXL-8steps-CFG-lora.safetensors"  # 8-step CFG-compatible variant
WIDTH = 768  # 768x1344 = 9:16 vertical output (see module docstring)
HEIGHT = 1344
NUM_STEPS = 8  # must match the 8-step Hyper-SD LoRA above
GUIDANCE_SCALE = 5.0  # CFG scale; the CFG LoRA variant tolerates guidance > 1
HYPER_SD_WEIGHT = 0.125 # official recommendation
def _get_device_and_dtype():
"""Detect best available device and matching dtype."""
if torch.cuda.is_available():
return "cuda", torch.float16
if torch.backends.mps.is_available():
return "mps", torch.float32 # float32 required for MPS reliability
return "cpu", torch.float32
def load_pipeline(style_name: str = "Warm Sunset"):
    """Build the SDXL pipeline: fp16-safe VAE, Hyper-SD speed LoRA,
    optional style LoRA, and a trailing-spacing DDIM scheduler.

    Args:
        style_name: Key in STYLES registry. Use "None" for no style LoRA.

    Returns:
        Configured DiffusionPipeline ready for inference.
    """
    style = get_style(style_name)
    device, dtype = _get_device_and_dtype()
    print(f"Loading SDXL pipeline on {device} ({dtype})...")

    fixed_vae = AutoencoderKL.from_pretrained(VAE_MODEL, torch_dtype=dtype)
    from_pretrained_kwargs = {
        "torch_dtype": dtype,
        "vae": fixed_vae,
        "use_safetensors": True,
    }
    if dtype == torch.float16:
        # The fp16 variant halves download size and load time.
        from_pretrained_kwargs["variant"] = "fp16"
    pipe = DiffusionPipeline.from_pretrained(BASE_MODEL, **from_pretrained_kwargs)

    # The Hyper-SD 8-step CFG LoRA is always attached (speed).
    pipe.load_lora_weights(
        hf_hub_download(HYPER_SD_REPO, HYPER_SD_FILE), adapter_name="hyper-sd"
    )
    # Layer the registry style LoRA on top (aesthetics).
    _apply_style(pipe, style)

    # Hyper-SD requires DDIM with trailing timestep spacing.
    pipe.scheduler = DDIMScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )
    pipe.to(device)
    if device == "mps":
        # Slicing trades a little speed for much lower peak memory on Apple GPUs.
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()
    print("Pipeline ready.")
    return pipe
def _apply_style(pipe, style: dict):
    """Attach the style LoRA described by *style* and set adapter weights."""
    source = style["source"]
    if source is None:
        # Nothing to layer on top — keep only the Hyper-SD speed adapter.
        pipe.set_adapters(["hyper-sd"], adapter_weights=[HYPER_SD_WEIGHT])
        print("No style LoRA — using base SDXL + Hyper-SD.")
        return

    lora_kwargs = {"adapter_name": "style"}
    root = Path(__file__).resolve().parent.parent
    candidate = (root / source).resolve()
    if candidate.is_file():
        # Local safetensors file: diffusers wants the directory plus weight_name.
        lora_kwargs["weight_name"] = candidate.name
        pipe.load_lora_weights(str(candidate.parent), **lora_kwargs)
    else:
        # Otherwise treat the source as a Hugging Face Hub repo ID.
        if style["weight_name"]:
            lora_kwargs["weight_name"] = style["weight_name"]
        pipe.load_lora_weights(source, **lora_kwargs)

    pipe.set_adapters(
        ["hyper-sd", "style"],
        adapter_weights=[HYPER_SD_WEIGHT, style["weight"]],
    )
    print(f"Loaded style LoRA: {source}")
def switch_style(pipe, style_name: str):
    """Swap the active style LoRA at runtime.

    All LoRAs are unloaded first, then Hyper-SD and the new style are
    re-attached, since adapters cannot be replaced in place.
    """
    new_style = get_style(style_name)
    pipe.unload_lora_weights()
    # Hyper-SD must come back too — unload_lora_weights drops every adapter.
    pipe.load_lora_weights(
        hf_hub_download(HYPER_SD_REPO, HYPER_SD_FILE), adapter_name="hyper-sd"
    )
    _apply_style(pipe, new_style)
    print(f"Switched to style: {style_name}")
def generate_image(
    pipe,
    prompt: str,
    negative_prompt: str = "",
    seed: Optional[int] = None,
) -> "PIL.Image.Image":
    """Generate one 768x1344 (9:16 vertical) image and return it.

    A CPU generator is used so a given seed reproduces across devices.
    """
    gen = (
        torch.Generator(device="cpu").manual_seed(seed)
        if seed is not None
        else None
    )
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=NUM_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        height=HEIGHT,
        width=WIDTH,
        generator=gen,
    )
    return result.images[0]
def generate_all(
segments: list[dict],
pipe,
output_dir: str | Path,
trigger_word: str = "",
seed: int = 42,
progress_callback=None,
) -> list[Path]:
"""Generate images for all segments.
Args:
segments: List of segment dicts (with 'prompt' and 'negative_prompt').
pipe: Loaded DiffusionPipeline.
output_dir: Directory to save images.
trigger_word: LoRA trigger word appended to prompts.
seed: Base seed (incremented per segment for variety).
Returns:
List of saved image paths.
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
paths = []
for seg in segments:
idx = seg["segment"]
path = output_dir / f"segment_{idx:03d}.png"
if path.exists():
print(f" Segment {idx}/{len(segments)}: already exists, skipping")
paths.append(path)
continue
prompt = seg["prompt"]
if trigger_word:
prompt = f"{trigger_word} style, {prompt}"
neg = seg.get("negative_prompt", "")
print(f" Segment {idx}/{len(segments)}: generating...")
image = generate_image(pipe, prompt, neg, seed=seed + idx)
path = output_dir / f"segment_{idx:03d}.png"
image.save(path)
paths.append(path)
print(f" Saved {path.name}")
if progress_callback:
progress_callback(idx, len(segments))
return paths
def run(
    data_dir: str | Path,
    style_name: str = "Warm Sunset",
    seed: int = 42,
    progress_callback=None,
) -> list[Path]:
    """End-to-end generation: load model, read segments.json, render, save.

    Args:
        data_dir: Run directory containing segments.json (e.g. data/Gone/run_001/).
        style_name: Style from the registry (see src/styles.py).
        seed: Base random seed.

    Returns:
        List of saved image paths.
    """
    run_dir = Path(data_dir)
    style = get_style(style_name)
    segments = json.loads((run_dir / "segments.json").read_text())
    pipe = load_pipeline(style_name)
    images_dir = run_dir / "images"
    paths = generate_all(
        segments, pipe, images_dir, style["trigger"], seed, progress_callback
    )
    print(f"\nGenerated {len(paths)} images in {run_dir / 'images'}")
    return paths
if __name__ == "__main__":
    import sys

    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python -m src.image_generator_hf <data_dir> [style_name]")
        print(' e.g. python -m src.image_generator_hf data/Gone/run_001 "Warm Sunset"')
        sys.exit(1)
    chosen_style = cli_args[1] if len(cli_args) > 1 else "Warm Sunset"
    run(cli_args[0], style_name=chosen_style)