|
|
""" |
|
|
Lyra/Lune Flow-Matching Inference Space |
|
|
Author: AbstractPhil |
|
|
License: MIT |
|
|
|
|
|
SD1.5-based flow matching with geometric crystalline architectures. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import torch |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from typing import Any, Dict, Optional
|
|
import spaces |
|
|
|
|
|
from diffusers import ( |
|
|
UNet2DConditionModel, |
|
|
AutoencoderKL, |
|
|
DPMSolverMultistepScheduler, |
|
|
EulerDiscreteScheduler |
|
|
) |
|
|
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
try: |
|
|
from geovocab2.train.model.vae.vae_lyra import MultiModalVAE, MultiModalVAEConfig |
|
|
LYRA_AVAILABLE = True |
|
|
except ImportError: |
|
|
print("⚠️ Lyra VAE not available - install geovocab2") |
|
|
LYRA_AVAILABLE = False


class FlowMatchingPipeline: |
|
|
"""Custom pipeline for flow-matching inference.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
vae: AutoencoderKL, |
|
|
text_encoder: CLIPTextModel, |
|
|
tokenizer: CLIPTokenizer, |
|
|
unet: UNet2DConditionModel, |
|
|
scheduler, |
|
|
device: str = "cuda", |
|
|
t5_encoder: Optional[T5EncoderModel] = None, |
|
|
t5_tokenizer: Optional[T5Tokenizer] = None, |
|
|
lyra_model: Optional[Any] = None
|
|
): |
|
|
self.vae = vae |
|
|
self.text_encoder = text_encoder |
|
|
self.tokenizer = tokenizer |
|
|
self.unet = unet |
|
|
self.scheduler = scheduler |
|
|
self.device = device |
|
|
|
|
|
|
|
|
self.t5_encoder = t5_encoder |
|
|
self.t5_tokenizer = t5_tokenizer |
|
|
self.lyra_model = lyra_model |
|
|
|
|
|
|
|
|
self.vae_scale_factor = 0.18215 |
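# 0.18215 is the standard SD1.5 VAE latent scaling factor; latents are divided
# by it in __call__ before being handed to the VAE decoder.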
|
|
|
|
|
def encode_prompt(self, prompt: str, negative_prompt: str = ""): |
|
|
"""Encode text prompts to embeddings.""" |
|
|
|
|
|
text_inputs = self.tokenizer( |
|
|
prompt, |
|
|
padding="max_length", |
|
|
max_length=self.tokenizer.model_max_length, |
|
|
truncation=True, |
|
|
return_tensors="pt", |
|
|
) |
|
|
text_input_ids = text_inputs.input_ids.to(self.device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
prompt_embeds = self.text_encoder(text_input_ids)[0] |
|
|
|
|
|
|
|
|
if negative_prompt: |
|
|
uncond_inputs = self.tokenizer( |
|
|
negative_prompt, |
|
|
padding="max_length", |
|
|
max_length=self.tokenizer.model_max_length, |
|
|
truncation=True, |
|
|
return_tensors="pt", |
|
|
) |
|
|
uncond_input_ids = uncond_inputs.input_ids.to(self.device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
negative_prompt_embeds = self.text_encoder(uncond_input_ids)[0] |
|
|
else: |
|
|
negative_prompt_embeds = torch.zeros_like(prompt_embeds) |
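# With no negative prompt, a zero tensor stands in for the unconditional
# embeddings (rather than encoding an empty string as the stock diffusers
# pipelines do).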
|
|
|
|
|
return prompt_embeds, negative_prompt_embeds |
|
|
|
|
|
def encode_prompt_lyra(self, prompt: str, negative_prompt: str = ""): |
|
|
"""Encode text prompts using Lyra VAE (CLIP + T5 fusion).""" |
|
|
if self.lyra_model is None or self.t5_encoder is None: |
|
|
raise ValueError("Lyra VAE components not initialized") |
|
|
|
|
|
|
|
|
text_inputs = self.tokenizer( |
|
|
prompt, |
|
|
padding="max_length", |
|
|
max_length=self.tokenizer.model_max_length, |
|
|
truncation=True, |
|
|
return_tensors="pt", |
|
|
) |
|
|
text_input_ids = text_inputs.input_ids.to(self.device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
clip_embeds = self.text_encoder(text_input_ids)[0] |
|
|
|
|
|
|
|
|
t5_inputs = self.t5_tokenizer( |
|
|
prompt, |
|
|
max_length=77, |
|
|
padding='max_length', |
|
|
truncation=True, |
|
|
return_tensors='pt' |
|
|
).to(self.device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state |
|
|
|
|
|
|
|
|
modality_inputs = { |
|
|
'clip': clip_embeds, |
|
|
't5': t5_embeds |
|
|
} |
|
|
|
|
|
with torch.no_grad(): |
|
|
reconstructions, mu, logvar = self.lyra_model( |
|
|
modality_inputs, |
|
|
target_modalities=['clip'] |
|
|
) |
|
|
prompt_embeds = reconstructions['clip'] |
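# The Lyra VAE fuses the CLIP and T5 sequences into a shared latent and decodes
# back into CLIP space; that reconstruction replaces the plain CLIP embeddings
# as UNet conditioning.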
|
|
|
|
|
|
|
|
if negative_prompt: |
|
|
uncond_inputs = self.tokenizer( |
|
|
negative_prompt, |
|
|
padding="max_length", |
|
|
max_length=self.tokenizer.model_max_length, |
|
|
truncation=True, |
|
|
return_tensors="pt", |
|
|
) |
|
|
uncond_input_ids = uncond_inputs.input_ids.to(self.device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
clip_embeds_uncond = self.text_encoder(uncond_input_ids)[0] |
|
|
|
|
|
t5_inputs_uncond = self.t5_tokenizer( |
|
|
negative_prompt, |
|
|
max_length=77, |
|
|
padding='max_length', |
|
|
truncation=True, |
|
|
return_tensors='pt' |
|
|
).to(self.device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
t5_embeds_uncond = self.t5_encoder(**t5_inputs_uncond).last_hidden_state |
|
|
|
|
|
modality_inputs_uncond = { |
|
|
'clip': clip_embeds_uncond, |
|
|
't5': t5_embeds_uncond |
|
|
} |
|
|
|
|
|
with torch.no_grad(): |
|
|
reconstructions_uncond, _, _ = self.lyra_model( |
|
|
modality_inputs_uncond, |
|
|
target_modalities=['clip'] |
|
|
) |
|
|
negative_prompt_embeds = reconstructions_uncond['clip'] |
|
|
else: |
|
|
negative_prompt_embeds = torch.zeros_like(prompt_embeds) |
|
|
|
|
|
return prompt_embeds, negative_prompt_embeds |
|
|
|
|
|
@torch.no_grad() |
|
|
def __call__( |
|
|
self, |
|
|
prompt: str, |
|
|
negative_prompt: str = "", |
|
|
height: int = 512, |
|
|
width: int = 512, |
|
|
num_inference_steps: int = 20, |
|
|
guidance_scale: float = 7.5, |
|
|
shift: float = 2.5, |
|
|
use_flow_matching: bool = True, |
|
|
prediction_type: str = "epsilon", |
|
|
seed: Optional[int] = None, |
|
|
use_lyra: bool = False, |
|
|
progress_callback=None |
|
|
): |
|
|
"""Generate image using flow matching or standard diffusion.""" |
|
|
|
|
|
|
|
|
if seed is not None: |
|
|
generator = torch.Generator(device=self.device).manual_seed(seed) |
|
|
else: |
|
|
generator = None |
|
|
|
|
|
|
|
|
if use_lyra and self.lyra_model is not None: |
|
|
prompt_embeds, negative_prompt_embeds = self.encode_prompt_lyra( |
|
|
prompt, negative_prompt |
|
|
) |
|
|
else: |
|
|
prompt_embeds, negative_prompt_embeds = self.encode_prompt( |
|
|
prompt, negative_prompt |
|
|
) |
|
|
|
|
|
|
|
|
latent_channels = 4 |
|
|
latent_height = height // 8 |
|
|
latent_width = width // 8 |
|
|
|
|
|
latents = torch.randn( |
|
|
(1, latent_channels, latent_height, latent_width), |
|
|
generator=generator, |
|
|
device=self.device, |
|
|
dtype=torch.float32 |
|
|
) |
|
|
|
|
|
|
|
|
self.scheduler.set_timesteps(num_inference_steps, device=self.device) |
|
|
timesteps = self.scheduler.timesteps |
|
|
|
|
|
|
|
|
|
|
|
if not use_flow_matching: |
|
|
latents = latents * self.scheduler.init_noise_sigma |
|
|
|
|
|
|
|
|
for i, t in enumerate(timesteps): |
|
|
if progress_callback: |
|
|
progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}") |
|
|
|
|
|
|
|
|
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents |
|
|
|
|
|
|
|
|
|
|
|
if use_flow_matching and shift > 0: |
|
|
|
|
|
sigma = t.float() / 1000.0 |
|
|
sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma) |
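# Timestep shifting: sigma in [0, 1] is remapped via
# sigma' = shift * sigma / (1 + (shift - 1) * sigma), which pushes more of the
# trajectory toward higher noise levels when shift > 1.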
|
|
|
|
|
|
|
|
scaling = torch.sqrt(1 + sigma_shifted ** 2) |
|
|
latent_model_input = latent_model_input / scaling |
|
|
else: |
|
|
|
|
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
|
|
|
|
|
|
|
|
timestep = t.expand(latent_model_input.shape[0]) |
|
|
|
|
|
|
|
|
text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if guidance_scale > 1.0 else prompt_embeds |
|
|
|
|
|
noise_pred = self.unet( |
|
|
latent_model_input, |
|
|
timestep, |
|
|
encoder_hidden_states=text_embeds, |
|
|
return_dict=False |
|
|
)[0] |
|
|
|
|
|
|
|
|
if guidance_scale > 1.0: |
|
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
|
|
|
|
|
|
|
if use_flow_matching: |
|
|
|
|
|
sigma = t.float() / 1000.0 |
|
|
sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma) |
|
|
|
|
|
if prediction_type == "v_prediction": |
|
|
|
|
|
v_pred = noise_pred |
|
|
alpha_t = torch.sqrt(1 - sigma_shifted ** 2) |
|
|
sigma_t = sigma_shifted |
|
|
noise_pred = alpha_t * v_pred + sigma_t * latents |
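# Converts the v-parameterized output back to an epsilon-style estimate
# (eps = alpha_t * v + sigma_t * x_t) before the Euler update below.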
|
|
|
|
|
|
|
|
dt = -1.0 / num_inference_steps |
|
|
latents = latents + dt * noise_pred |
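# Euler ODE step: time runs from 1 -> 0 over num_inference_steps uniform steps,
# so dt = -1/N and the model output is treated as the flow velocity field.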
|
|
else: |
|
|
|
|
|
latents = self.scheduler.step( |
|
|
noise_pred, t, latents, return_dict=False |
|
|
)[0] |
|
|
|
|
|
|
|
|
latents = latents / self.vae_scale_factor |
|
|
|
|
|
|
|
|
|
|
|
if hasattr(self, 'is_lune_model') and self.is_lune_model: |
|
|
latents = latents * 5.52 |
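# Lune operates in a latent space scaled by ~5.52x (see the UI notes), so its
# latents are scaled back before VAE decoding.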
|
|
|
|
|
with torch.no_grad(): |
|
|
image = self.vae.decode(latents).sample |
|
|
|
|
|
|
|
|
image = (image / 2 + 0.5).clamp(0, 1) |
|
|
image = image.cpu().permute(0, 2, 3, 1).float().numpy() |
|
|
image = (image * 255).round().astype("uint8") |
|
|
image = Image.fromarray(image[0]) |
|
|
|
|
|
return image |
|
|
|
|
|
|
|
|
def load_lune_checkpoint(repo_id: str, filename: str, device: str = "cuda"): |
|
|
"""Load Lune checkpoint from .pt file.""" |
|
|
print(f"📥 Downloading checkpoint: {repo_id}/{filename}") |
|
|
|
|
|
checkpoint_path = hf_hub_download( |
|
|
repo_id=repo_id, |
|
|
filename=filename, |
|
|
repo_type="model" |
|
|
) |
|
|
|
|
|
print(f"✓ Downloaded to: {checkpoint_path}") |
|
|
print(f"📦 Loading checkpoint...") |
|
|
|
|
|
checkpoint = torch.load(checkpoint_path, map_location="cpu") |
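# Note: this is a full training checkpoint (student weights plus metadata such
# as "gstep"); on PyTorch >= 2.6, torch.load defaults to weights_only=True and
# may need weights_only=False here.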
|
|
|
|
|
|
|
|
print(f"🏗️ Initializing SD1.5 UNet...") |
|
|
unet = UNet2DConditionModel.from_pretrained( |
|
|
"runwayml/stable-diffusion-v1-5", |
|
|
subfolder="unet", |
|
|
torch_dtype=torch.float32 |
|
|
) |
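# Note: the runwayml/stable-diffusion-v1-5 repo has since been removed from the
# Hub; if these from_pretrained calls fail, the community mirror
# stable-diffusion-v1-5/stable-diffusion-v1-5 is a commonly used drop-in replacement.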
|
|
|
|
|
|
|
|
student_state_dict = checkpoint["student"] |
|
|
|
|
|
|
|
|
cleaned_dict = {} |
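# The student (distillation) checkpoint prefixes UNet weights with "unet.";
# strip that prefix so the keys line up with diffusers' UNet2DConditionModel
# state dict.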
|
|
for key, value in student_state_dict.items(): |
|
|
if key.startswith("unet."): |
|
|
cleaned_dict[key[5:]] = value |
|
|
else: |
|
|
cleaned_dict[key] = value |
|
|
|
|
|
|
|
|
unet.load_state_dict(cleaned_dict, strict=False) |
|
|
|
|
|
step = checkpoint.get("gstep", "unknown") |
|
|
print(f"✅ Loaded checkpoint from step {step}") |
|
|
|
|
|
return unet.to(device) |
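# Usage sketch (repo and filename mirror the defaults used in
# initialize_pipeline below; adjust for other checkpoints):
#   unet = load_lune_checkpoint("AbstractPhil/sd15-flow-lune",
#                               "sd15_flow_lune_e34_s34000.pt", device="cuda")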
|
|
|
|
|
|
|
|
def load_lyra_vae(repo_id: str = "AbstractPhil/vae-lyra", device: str = "cuda"): |
|
|
"""Load Lyra VAE from HuggingFace.""" |
|
|
if not LYRA_AVAILABLE: |
|
|
print("⚠️ Lyra VAE not available - geovocab2 not installed") |
|
|
return None |
|
|
|
|
|
print(f"🎵 Loading Lyra VAE from {repo_id}...") |
|
|
|
|
|
try: |
|
|
|
|
|
checkpoint_path = hf_hub_download( |
|
|
repo_id=repo_id, |
|
|
filename="best_model.pt", |
|
|
repo_type="model" |
|
|
) |
|
|
|
|
|
print(f"✓ Downloaded checkpoint: {checkpoint_path}") |
|
|
|
|
|
|
|
|
checkpoint = torch.load(checkpoint_path, map_location="cpu") |
|
|
|
|
|
|
|
|
if 'config' in checkpoint: |
|
|
config_dict = checkpoint['config'] |
|
|
else: |
|
|
|
|
|
config_dict = { |
|
|
'modality_dims': {"clip": 768, "t5": 768}, |
|
|
'latent_dim': 768, |
|
|
'seq_len': 77, |
|
|
'encoder_layers': 3, |
|
|
'decoder_layers': 3, |
|
|
'hidden_dim': 1024, |
|
|
'dropout': 0.1, |
|
|
'fusion_strategy': 'cantor', |
|
|
'fusion_heads': 8, |
|
|
'fusion_dropout': 0.1 |
|
|
} |
|
|
|
|
|
|
|
|
vae_config = MultiModalVAEConfig( |
|
|
modality_dims=config_dict.get('modality_dims', {"clip": 768, "t5": 768}), |
|
|
latent_dim=config_dict.get('latent_dim', 768), |
|
|
seq_len=config_dict.get('seq_len', 77), |
|
|
encoder_layers=config_dict.get('encoder_layers', 3), |
|
|
decoder_layers=config_dict.get('decoder_layers', 3), |
|
|
hidden_dim=config_dict.get('hidden_dim', 1024), |
|
|
dropout=config_dict.get('dropout', 0.1), |
|
|
fusion_strategy=config_dict.get('fusion_strategy', 'cantor'), |
|
|
fusion_heads=config_dict.get('fusion_heads', 8), |
|
|
fusion_dropout=config_dict.get('fusion_dropout', 0.1) |
|
|
) |
|
|
|
|
|
|
|
|
lyra_model = MultiModalVAE(vae_config) |
|
|
|
|
|
|
|
|
if 'model_state_dict' in checkpoint: |
|
|
lyra_model.load_state_dict(checkpoint['model_state_dict']) |
|
|
else: |
|
|
lyra_model.load_state_dict(checkpoint) |
|
|
|
|
|
lyra_model.to(device) |
|
|
lyra_model.eval() |
|
|
|
|
|
|
|
|
print(f"✅ Lyra VAE loaded successfully") |
|
|
if 'global_step' in checkpoint: |
|
|
print(f" Training step: {checkpoint['global_step']:,}") |
|
|
if 'best_loss' in checkpoint: |
|
|
print(f" Best loss: {checkpoint['best_loss']:.4f}") |
|
|
print(f" Fusion strategy: {vae_config.fusion_strategy}") |
|
|
print(f" Latent dim: {vae_config.latent_dim}") |
|
|
|
|
|
return lyra_model |
|
|
|
|
|
except Exception as e: |
|
|
print(f"❌ Failed to load Lyra VAE: {e}") |
|
|
return None |
|
|
|
|
|
|
|
|
def initialize_pipeline(model_choice: str, clip_model: str = "openai/clip-vit-large-patch14", device: str = "cuda"): |
|
|
"""Initialize the complete pipeline.""" |
|
|
|
|
|
print(f"🚀 Initializing {model_choice} pipeline...") |
|
|
print(f" CLIP model: {clip_model}") |
|
|
|
|
|
is_lune = "Lune" in model_choice |
|
|
|
|
|
|
|
|
print("Loading VAE...") |
|
|
vae = AutoencoderKL.from_pretrained( |
|
|
"runwayml/stable-diffusion-v1-5", |
|
|
subfolder="vae", |
|
|
torch_dtype=torch.float32 |
|
|
).to(device) |
|
|
|
|
|
print(f"Loading CLIP text encoder: {clip_model}...") |
|
|
text_encoder = CLIPTextModel.from_pretrained( |
|
|
clip_model, |
|
|
torch_dtype=torch.float32 |
|
|
).to(device) |
|
|
|
|
|
tokenizer = CLIPTokenizer.from_pretrained( |
|
|
clip_model |
|
|
) |
|
|
|
|
|
|
|
|
print("Loading T5-base encoder...") |
|
|
t5_tokenizer = T5Tokenizer.from_pretrained("t5-base") |
|
|
t5_encoder = T5EncoderModel.from_pretrained( |
|
|
"t5-base", |
|
|
torch_dtype=torch.float32 |
|
|
).to(device) |
|
|
t5_encoder.eval() |
|
|
print("✓ T5 loaded") |
|
|
|
|
|
print("Loading Lyra VAE...") |
|
|
lyra_model = load_lyra_vae(device=device) |
|
|
if lyra_model is None: |
|
|
print("⚠️ Lyra VAE not available - fusion disabled") |
|
|
|
|
|
|
|
|
if is_lune: |
|
|
|
|
|
repo_id = "AbstractPhil/sd15-flow-lune" |
|
|
filename = "sd15_flow_lune_e34_s34000.pt" |
|
|
unet = load_lune_checkpoint(repo_id, filename, device) |
|
|
|
|
|
elif model_choice == "SD1.5 Base": |
|
|
|
|
|
print("Loading SD1.5 base UNet...") |
|
|
unet = UNet2DConditionModel.from_pretrained( |
|
|
"runwayml/stable-diffusion-v1-5", |
|
|
subfolder="unet", |
|
|
torch_dtype=torch.float32 |
|
|
).to(device) |
|
|
|
|
|
else: |
|
|
raise ValueError(f"Unknown model: {model_choice}") |
|
|
|
|
|
|
|
|
scheduler = EulerDiscreteScheduler.from_pretrained( |
|
|
"runwayml/stable-diffusion-v1-5", |
|
|
subfolder="scheduler" |
|
|
) |
|
|
|
|
|
print("✅ Pipeline initialized!") |
|
|
|
|
|
pipeline = FlowMatchingPipeline( |
|
|
vae=vae, |
|
|
text_encoder=text_encoder, |
|
|
tokenizer=tokenizer, |
|
|
unet=unet, |
|
|
scheduler=scheduler, |
|
|
device=device, |
|
|
t5_encoder=t5_encoder, |
|
|
t5_tokenizer=t5_tokenizer, |
|
|
lyra_model=lyra_model |
|
|
) |
|
|
|
|
|
|
|
|
pipeline.is_lune_model = is_lune |
|
|
|
|
|
return pipeline


CURRENT_PIPELINE = None |
|
|
CURRENT_MODEL = None |
|
|
CURRENT_CLIP_MODEL = None |
|
|
|
|
|
|
|
|
def get_pipeline(model_choice: str, clip_model: str): |
|
|
"""Get or create pipeline for selected model and CLIP variant.""" |
|
|
global CURRENT_PIPELINE, CURRENT_MODEL, CURRENT_CLIP_MODEL |
|
|
|
|
|
if CURRENT_PIPELINE is None or CURRENT_MODEL != model_choice or CURRENT_CLIP_MODEL != clip_model: |
|
|
CURRENT_PIPELINE = initialize_pipeline(model_choice, clip_model, device="cuda") |
|
|
CURRENT_MODEL = model_choice |
|
|
CURRENT_CLIP_MODEL = clip_model |
|
|
|
|
|
return CURRENT_PIPELINE |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def estimate_duration(num_steps: int, width: int, height: int, use_lyra: bool = False) -> int: |
|
|
"""Estimate GPU duration based on generation parameters.""" |
|
|
|
|
|
base_time_per_step = 0.3 |
|
|
|
|
|
|
|
|
resolution_factor = (width * height) / (512 * 512) |
|
|
|
|
|
|
|
|
estimated = num_steps * base_time_per_step * resolution_factor |
|
|
|
|
|
|
|
|
if use_lyra: |
|
|
estimated *= 2 |
|
|
estimated += 2 |
|
|
|
|
|
|
|
|
return int(estimated + 15) |
|
|
|
|
|
|
|
|
@spaces.GPU(duration=lambda *args: estimate_duration(args[4], args[6], args[7], args[11])) |
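# Positional-argument map for the duration lambda above: args[4] = num_steps,
# args[6] = width, args[7] = height, args[11] = use_lyra. Keep these indices in
# sync with the generate_image signature below.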
|
|
def generate_image( |
|
|
prompt: str, |
|
|
negative_prompt: str, |
|
|
model_choice: str, |
|
|
clip_model: str, |
|
|
num_steps: int, |
|
|
cfg_scale: float, |
|
|
width: int, |
|
|
height: int, |
|
|
shift: float, |
|
|
use_flow_matching: bool, |
|
|
prediction_type: str, |
|
|
use_lyra: bool, |
|
|
seed: int, |
|
|
randomize_seed: bool, |
|
|
progress=gr.Progress() |
|
|
): |
|
|
"""Generate image with ZeroGPU support. Returns (standard_img, lyra_img, seed) or (img, None, seed).""" |
|
|
|
|
|
|
|
|
if randomize_seed: |
|
|
seed = np.random.randint(0, 2**32 - 1) |
|
|
|
|
|
|
|
|
def progress_callback(step, total, desc): |
|
|
progress((step + 1) / total, desc=desc) |
|
|
|
|
|
try: |
|
|
|
|
|
pipeline = get_pipeline(model_choice, clip_model) |
|
|
|
|
|
if not use_lyra or pipeline.lyra_model is None: |
|
|
|
|
|
progress(0.05, desc="Generating (standard)...") |
|
|
|
|
|
image = pipeline( |
|
|
prompt=prompt, |
|
|
negative_prompt=negative_prompt, |
|
|
height=height, |
|
|
width=width, |
|
|
num_inference_steps=num_steps, |
|
|
guidance_scale=cfg_scale, |
|
|
shift=shift, |
|
|
use_flow_matching=use_flow_matching, |
|
|
prediction_type=prediction_type, |
|
|
seed=seed, |
|
|
use_lyra=False, |
|
|
progress_callback=progress_callback |
|
|
) |
|
|
|
|
|
progress(1.0, desc="Complete!") |
|
|
return image, None, seed |
|
|
|
|
|
else: |
|
|
|
|
|
progress(0.05, desc="Generating standard version...") |
|
|
|
|
|
image_standard = pipeline( |
|
|
prompt=prompt, |
|
|
negative_prompt=negative_prompt, |
|
|
height=height, |
|
|
width=width, |
|
|
num_inference_steps=num_steps, |
|
|
guidance_scale=cfg_scale, |
|
|
shift=shift, |
|
|
use_flow_matching=use_flow_matching, |
|
|
prediction_type=prediction_type, |
|
|
seed=seed, |
|
|
use_lyra=False, |
|
|
progress_callback=lambda s, t, d: progress(0.05 + (s/t) * 0.45, desc=d) |
|
|
) |
|
|
|
|
|
progress(0.5, desc="Generating Lyra fusion version...") |
|
|
|
|
|
image_lyra = pipeline( |
|
|
prompt=prompt, |
|
|
negative_prompt=negative_prompt, |
|
|
height=height, |
|
|
width=width, |
|
|
num_inference_steps=num_steps, |
|
|
guidance_scale=cfg_scale, |
|
|
shift=shift, |
|
|
use_flow_matching=use_flow_matching, |
|
|
prediction_type=prediction_type, |
|
|
seed=seed, |
|
|
use_lyra=True, |
|
|
progress_callback=lambda s, t, d: progress(0.5 + (s/t) * 0.45, desc=d) |
|
|
) |
|
|
|
|
|
progress(1.0, desc="Complete!") |
|
|
return image_standard, image_lyra, seed |
|
|
|
|
|
except Exception as e: |
|
|
print(f"❌ Generation failed: {e}") |
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_demo(): |
|
|
"""Create Gradio interface.""" |
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft()) as demo: |
|
|
gr.Markdown(""" |
|
|
# 🌙 Lyra/Lune Flow-Matching Image Generation |
|
|
|
|
|
**Geometric crystalline diffusion with flow matching** by [AbstractPhil](https://huggingface.co/AbstractPhil) |
|
|
|
|
|
Generate images using SD1.5-based models with geometric deep learning: |
|
|
- **Flow-Lune**: Flow matching with pentachoron geometric structures (15-25 steps) |
|
|
- **SD1.5 Base**: Standard Stable Diffusion 1.5 baseline |
|
|
- **Lyra VAE Toggle**: Add CLIP+T5 fusion for side-by-side comparison |
|
|
- **CLIP Variants**: Different text encoders for varied semantic understanding |
|
|
|
|
|
Enable Lyra to see both standard CLIP and geometric CLIP+T5 fusion results! |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
|
|
|
prompt = gr.TextArea( |
|
|
label="Prompt", |
|
|
value="A serene mountain landscape at golden hour, crystal clear lake reflecting snow-capped peaks, photorealistic, 8k", |
|
|
lines=3 |
|
|
) |
|
|
|
|
|
negative_prompt = gr.TextArea( |
|
|
label="Negative Prompt", |
|
|
placeholder="blurry, low quality, distorted...", |
|
|
value="blurry, low quality", |
|
|
lines=2 |
|
|
) |
|
|
|
|
|
|
|
|
model_choice = gr.Dropdown( |
|
|
label="Base Model", |
|
|
choices=[ |
|
|
"Flow-Lune (Latest)", |
|
|
"SD1.5 Base" |
|
|
], |
|
|
value="Flow-Lune (Latest)" |
|
|
) |
|
|
|
|
|
|
|
|
clip_model_choice = gr.Dropdown( |
|
|
label="CLIP Model", |
|
|
choices=[ |
|
|
"openai/clip-vit-large-patch14", |
|
|
|
|
|
|
|
|
|
|
|
], |
|
|
value="openai/clip-vit-large-patch14", |
|
|
info="Text encoder variant" |
|
|
) |
|
|
|
|
|
|
|
|
use_lyra = gr.Checkbox( |
|
|
label="Enable Lyra VAE (CLIP+T5 Fusion)", |
|
|
value=True, |
|
|
info="Generate side-by-side comparison with geometric fusion" |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Accordion("Flow Matching Settings", open=True): |
|
|
use_flow_matching = gr.Checkbox( |
|
|
label="Enable Flow Matching", |
|
|
value=True, |
|
|
info="Use flow matching ODE integration" |
|
|
) |
|
|
|
|
|
shift = gr.Slider( |
|
|
label="Shift", |
|
|
minimum=0.0, |
|
|
maximum=5.0, |
|
|
value=2.5, |
|
|
step=0.1, |
|
|
info="Flow matching shift parameter (0=disabled, 1-3 typical)" |
|
|
) |
|
|
|
|
|
prediction_type = gr.Radio( |
|
|
label="Prediction Type", |
|
|
choices=["epsilon", "v_prediction"], |
|
|
value="v_prediction", |
|
|
info="Type of model prediction" |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Accordion("Generation Settings", open=True): |
|
|
num_steps = gr.Slider( |
|
|
label="Steps", |
|
|
minimum=1, |
|
|
maximum=50, |
|
|
value=20, |
|
|
step=1, |
|
|
info="Flow matching typically needs fewer steps (15-25)" |
|
|
) |
|
|
|
|
|
cfg_scale = gr.Slider( |
|
|
label="CFG Scale", |
|
|
minimum=1.0, |
|
|
maximum=20.0, |
|
|
value=7.5, |
|
|
step=0.5 |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
width = gr.Slider( |
|
|
label="Width", |
|
|
minimum=256, |
|
|
maximum=1024, |
|
|
value=512, |
|
|
step=64 |
|
|
) |
|
|
|
|
|
height = gr.Slider( |
|
|
label="Height", |
|
|
minimum=256, |
|
|
maximum=1024, |
|
|
value=512, |
|
|
step=64 |
|
|
) |
|
|
|
|
|
seed = gr.Slider( |
|
|
label="Seed", |
|
|
minimum=0, |
|
|
maximum=2**32 - 1, |
|
|
value=42, |
|
|
step=1 |
|
|
) |
|
|
|
|
|
randomize_seed = gr.Checkbox( |
|
|
label="Randomize Seed", |
|
|
value=True |
|
|
) |
|
|
|
|
|
generate_btn = gr.Button("🎨 Generate", variant="primary", size="lg") |
|
|
|
|
|
with gr.Column(scale=1): |
|
|
with gr.Row(): |
|
|
output_image_standard = gr.Image( |
|
|
label="Standard Generation", |
|
|
type="pil", |
|
|
visible=True |
|
|
) |
|
|
|
|
|
output_image_lyra = gr.Image( |
|
|
label="Lyra Fusion 🎵", |
|
|
type="pil", |
|
|
visible=True |
|
|
) |
|
|
|
|
|
output_seed = gr.Number( |
|
|
label="Used Seed", |
|
|
precision=0 |
|
|
) |
|
|
|
|
|
gr.Markdown(""" |
|
|
### Tips: |
|
|
- **Flow matching** works best with 15-25 steps (vs 50+ for standard diffusion) |
|
|
- **Shift** controls the flow trajectory (2.0-2.5 recommended for Lune) |
|
|
- Lower shift keeps the schedule close to uniform (a more direct path); higher shift pushes more steps toward high noise levels (more exploration of global structure)
|
|
- **Lune** uses v_prediction by default for optimal results |
|
|
- **Lyra toggle** generates side-by-side comparison (CLIP vs CLIP+T5 fusion) |
|
|
- **CLIP variants** may give different semantic interpretations |
|
|
- **SD1.5 Base** uses epsilon (standard diffusion) |
|
|
- Lune operates in a scaled latent space (5.52x) for geometric efficiency |
|
|
|
|
|
### Model Info: |
|
|
- **Flow-Lune**: Trained with flow matching on 500k SD1.5 distillation pairs |
|
|
- **Lyra VAE**: Multi-modal fusion (CLIP+T5) via Cantor geometric attention |
|
|
- **SD1.5 Base**: Standard Stable Diffusion 1.5 for comparison |
|
|
|
|
|
### CLIP Models: |
|
|
- **openai/clip-vit-large-patch14**: Standard CLIP-L (default) |
|
|
- **openai/clip-vit-large-patch14-336**: Higher resolution CLIP-L |
|
|
- **laion/CLIP-ViT-L-14**: LAION-trained CLIP-L variant |
|
|
- **laion/CLIP-ViT-bigG-14**: Larger CLIP-G model |
|
|
|
|
|
[📚 Learn more about geometric deep learning](https://github.com/AbstractEyes/lattice_vocabulary) |
|
|
""") |
|
|
|
|
|
|
|
|
gr.Examples( |
|
|
examples=[ |
|
|
[ |
|
|
"A serene mountain landscape at golden hour, crystal clear lake reflecting snow-capped peaks, photorealistic, 8k", |
|
|
"blurry, low quality", |
|
|
"Flow-Lune (Latest)", |
|
|
"openai/clip-vit-large-patch14", |
|
|
20, |
|
|
7.5, |
|
|
512, |
|
|
512, |
|
|
2.5, |
|
|
True, |
|
|
"v_prediction", |
|
|
False, |
|
|
42, |
|
|
False |
|
|
], |
|
|
[ |
|
|
"A futuristic cyberpunk city at night, neon lights, rain-slicked streets, highly detailed", |
|
|
"low quality, blurry", |
|
|
"Flow-Lune (Latest)", |
|
|
"openai/clip-vit-large-patch14", |
|
|
20, |
|
|
7.5, |
|
|
512, |
|
|
512, |
|
|
2.5, |
|
|
True, |
|
|
"v_prediction", |
|
|
True, |
|
|
123, |
|
|
False |
|
|
], |
|
|
[ |
|
|
"Portrait of a majestic lion, golden mane, dramatic lighting, wildlife photography", |
|
|
"cartoon, painting", |
|
|
"SD1.5 Base", |
|
|
"openai/clip-vit-large-patch14", |
|
|
30, |
|
|
7.5, |
|
|
512, |
|
|
512, |
|
|
0.0, |
|
|
False, |
|
|
"epsilon", |
|
|
True, |
|
|
456, |
|
|
False |
|
|
] |
|
|
], |
|
|
inputs=[ |
|
|
prompt, negative_prompt, model_choice, clip_model_choice, num_steps, cfg_scale, |
|
|
width, height, shift, use_flow_matching, prediction_type, use_lyra, |
|
|
seed, randomize_seed |
|
|
], |
|
|
outputs=[output_image_standard, output_image_lyra, output_seed], |
|
|
fn=generate_image, |
|
|
cache_examples=False |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def on_model_change(model_name): |
|
|
"""Update default settings based on model selection.""" |
|
|
if model_name == "SD1.5 Base": |
|
|
|
|
|
return { |
|
|
use_flow_matching: gr.update(value=False), |
|
|
prediction_type: gr.update(value="epsilon") |
|
|
} |
|
|
else: |
|
|
|
|
|
return { |
|
|
use_flow_matching: gr.update(value=True), |
|
|
prediction_type: gr.update(value="v_prediction") |
|
|
} |
|
|
|
|
|
|
|
|
def on_lyra_toggle(lyra_enabled): |
|
|
"""Show/hide Lyra comparison image.""" |
|
|
if lyra_enabled: |
|
|
return { |
|
|
output_image_standard: gr.update(visible=True, label="Standard CLIP"), |
|
|
output_image_lyra: gr.update(visible=True, label="Lyra Fusion (CLIP+T5) 🎵") |
|
|
} |
|
|
else: |
|
|
return { |
|
|
output_image_standard: gr.update(visible=True, label="Generated Image"), |
|
|
output_image_lyra: gr.update(visible=False) |
|
|
} |
|
|
|
|
|
model_choice.change( |
|
|
fn=on_model_change, |
|
|
inputs=[model_choice], |
|
|
outputs=[use_flow_matching, prediction_type] |
|
|
) |
|
|
|
|
|
use_lyra.change( |
|
|
fn=on_lyra_toggle, |
|
|
inputs=[use_lyra], |
|
|
outputs=[output_image_standard, output_image_lyra] |
|
|
) |
|
|
on_lyra_toggle(True) |
|
|
|
|
|
generate_btn.click( |
|
|
fn=generate_image, |
|
|
inputs=[ |
|
|
prompt, negative_prompt, model_choice, clip_model_choice, num_steps, cfg_scale, |
|
|
width, height, shift, use_flow_matching, prediction_type, use_lyra, |
|
|
seed, randomize_seed |
|
|
], |
|
|
outputs=[output_image_standard, output_image_lyra, output_seed] |
|
|
) |
|
|
|
|
|
return demo


if __name__ == "__main__": |
|
|
demo = create_demo() |
|
|
demo.queue(max_size=20) |
|
|
demo.launch(show_api=False) |
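# On HuggingFace ZeroGPU Spaces the @spaces.GPU decorator handles GPU
# allocation; for a local run (a CUDA device is assumed, since
# initialize_pipeline defaults to device="cuda"), execute this script directly.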