!pip install -q diffusers transformers accelerate safetensors

import torch
import gc
from huggingface_hub import hf_hub_download
from diffusers import UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from safetensors.torch import load_file
from PIL import Image
import numpy as np
import json

torch.cuda.empty_cache()
gc.collect()

DEVICE = "cuda"
DTYPE = torch.float16

LUNE_REPO = "AbstractPhil/sd15-flow-lune-flux"
LUNE_WEIGHTS = "flux_t2_6_pose_t4_6_port_t1_4/checkpoint-00018765/unet/diffusion_pytorch_model.safetensors"
LUNE_CONFIG = "flux_t2_6_pose_t4_6_port_t1_4/checkpoint-00018765/unet/config.json"
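
# Optional (a sketch, not part of the original recipe): if the checkpoint path
# above ever goes stale, the repo contents can be listed to find the current one.
# from huggingface_hub import list_repo_files
# print([f for f in list_repo_files(LUNE_REPO) if f.endswith(".safetensors")])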

print("Loading CLIP...")
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval()

print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="vae",
    torch_dtype=DTYPE,
).to(DEVICE).eval()
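
# Note (a suggestion, not in the original): the latent scaling factor hard-coded
# as 0.18215 at decode time below can also be read from vae.config.scaling_factor.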

print("\nLoading Lune...")
config_path = hf_hub_download(repo_id=LUNE_REPO, filename=LUNE_CONFIG)
with open(config_path, "r") as f:
    lune_config = json.load(f)

print(f"  prediction_type: {lune_config.get('prediction_type', 'NOT SET')}")

unet = UNet2DConditionModel.from_config(lune_config).to(DEVICE).to(DTYPE).eval()

weights_path = hf_hub_download(repo_id=LUNE_REPO, filename=LUNE_WEIGHTS)
state_dict = load_file(weights_path)
unet.load_state_dict(state_dict, strict=False)
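
# Optional sanity check (a sketch, not part of the original recipe): strict=False
# silently ignores key mismatches, so compare the checkpoint keys against the
# model's own state_dict before the checkpoint is freed below.
missing = set(unet.state_dict().keys()) - set(state_dict.keys())
unexpected = set(state_dict.keys()) - set(unet.state_dict().keys())
if missing or unexpected:
    print(f"  WARNING: {len(missing)} missing / {len(unexpected)} unexpected keys")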

del state_dict
gc.collect()

for p in unet.parameters():
    p.requires_grad = False

print("✓ Lune ready!")

def shift_sigma(sigma: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
    """
    Apply timestep shift (same as the trainer):
        sigma_shifted = shift * sigma / (1 + (shift - 1) * sigma)
    """
    return (shift * sigma) / (1 + (shift - 1) * sigma)
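
# Quick illustration of the shift (assuming shift=3.0, the default above):
# intermediate sigmas are pushed toward 1, so more of the schedule is spent in
# the high-noise regime, while the endpoints 0 and 1 are unchanged.
#   shift_sigma(torch.tensor(0.25)) -> 0.50
#   shift_sigma(torch.tensor(0.50)) -> 0.75
#   shift_sigma(torch.tensor(1.00)) -> 1.00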

@torch.inference_mode()
def encode_prompt(prompt):
    inputs = clip_tok(prompt, return_tensors="pt", padding="max_length",
                      max_length=77, truncation=True).to(DEVICE)
    return clip_enc(**inputs).last_hidden_state.to(DTYPE)
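# Note: prompts longer than 77 CLIP tokens are truncated by the tokenizer above.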

@torch.inference_mode()
def generate_lune(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 42,
    steps: int = 30,
    cfg: float = 7.5,
    shift: float = 3.0,
):
    """
    Correct Lune sampler matching the trainer's flow convention.

    Trainer:
        x_t = sigma * noise + (1 - sigma) * data
        target = noise - data

    Sampling:
        - Start at sigma=1 (pure noise)
        - End at sigma=0 (clean data)
        - x_{sigma - dt} = x_sigma - v * dt  (SUBTRACT because v points toward noise)
    """
    torch.manual_seed(seed)

    cond = encode_prompt(prompt)
    uncond = encode_prompt(negative_prompt)  # empty string -> unconditional embedding

    # Start from pure noise in latent space (1x4x64x64 latents -> 512x512 image)
    x = torch.randn(1, 4, 64, 64, device=DEVICE, dtype=DTYPE)

    # Shifted sigma schedule from 1 (pure noise) down to 0 (clean data)
    sigmas_linear = torch.linspace(1, 0, steps + 1, device=DEVICE)
    sigmas = shift_sigma(sigmas_linear, shift=shift)

    print(f"Lune: '{prompt[:30]}' | {steps} steps, cfg={cfg}, shift={shift}")
    print(f"  sigma range: {sigmas[0].item():.3f} → {sigmas[-1].item():.3f}")

    for i in range(steps):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        dt = sigma - sigma_next

        # Scale sigma into the UNet's 0-1000 timestep range
        timestep = sigma * 1000
        t_input = timestep.view(1).to(DEVICE)

        # Classifier-free guidance: combine conditional and unconditional predictions
        v_cond = unet(x, t_input, encoder_hidden_states=cond).sample
        v_uncond = unet(x, t_input, encoder_hidden_states=uncond).sample
        v = v_uncond + cfg * (v_cond - v_uncond)

        # Euler step toward data: subtract v because v = noise - data
        x = x - v * dt

        # Progress log (max(1, ...) guards against division by zero for steps < 5)
        if (i + 1) % max(1, steps // 5) == 0:
            print(f"  Step {i+1}/{steps}, sigma={sigma.item():.3f} → {sigma_next.item():.3f}")

    # Decode latents (0.18215 is the SD1.5 VAE scaling factor)
    x = x / 0.18215
    img = vae.decode(x).sample
    img = (img / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0).cpu().float().numpy()
    return Image.fromarray((img * 255).astype(np.uint8))
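
# Sanity check of the flow convention (a sketch, independent of the model):
# with the true velocity v = noise - data, a single Euler step from sigma down
# to 0 recovers the data exactly, since
#   x_sigma - v * sigma = (sigma*noise + (1-sigma)*data) - sigma*(noise - data) = data
_data = torch.randn(2, 3)
_noise = torch.randn(2, 3)
_sigma = 0.7
_x = _sigma * _noise + (1 - _sigma) * _data
assert torch.allclose(_x - (_noise - _data) * _sigma, _data, atol=1e-5)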

print("\n" + "=" * 60)
print("Testing Lune with CORRECT flow convention")
print("  x_t = sigma*noise + (1-sigma)*data")
print("  v = noise - data")
print("  Sample by SUBTRACTING v")
print("=" * 60)

from IPython.display import display

prompt = "a castle at sunset"

print("\n--- shift=3.0 (function default) ---")
img = generate_lune(prompt, seed=42, steps=30, cfg=7.5, shift=3.0)
display(img)

print("\n--- shift=2.5 (trainer default) ---")
img2 = generate_lune(prompt, seed=42, steps=30, cfg=7.5, shift=2.5)
display(img2)

print("\n--- shift=1.0 (no shift) ---")
img3 = generate_lune(prompt, seed=42, steps=30, cfg=7.5, shift=1.0)
display(img3)

# Side-by-side comparison of the three shift settings
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (s, im) in zip(axes, [(3.0, img), (2.5, img2), (1.0, img3)]):
    ax.imshow(im)
    ax.set_title(f"shift={s}")
    ax.axis('off')
plt.tight_layout()
plt.show()
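
# Optional (assumption: the filename is arbitrary): keep the comparison grid.
# fig.savefig("lune_shift_comparison.png", dpi=150)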

print("\n✓ If these render as coherent images rather than noise, the sampler's flow convention matches the trainer's.")