# Import spaces first to avoid CUDA initialization conflicts
import spaces
import gradio as gr
import numpy as np
import random
import traceback
import torch
from PIL import Image
from torchvision import transforms
from diffusers import DiffusionPipeline, AutoencoderKL

# Define constants
flux_dtype = torch.bfloat16
vae_dtype = torch.float32
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
# Default `strength` forwarded to the img2img pipeline: higher values let the
# result move further away from the uploaded image.
DEFAULT_IMG2IMG_STRENGTH = 0.6

# Select the device only after `spaces` has been imported (see the note above).
device = "cuda" if torch.cuda.is_available() else "cpu"


def load_models():
    """Load the FLUX.1 [schnell] text-to-image pipeline and place it on ``device``.

    Returns:
        The loaded diffusers pipeline (weights in ``flux_dtype``).
    """
    pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=flux_dtype)
    # Decode in slices/tiles to cap VAE memory use on large outputs.
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()
    # Put the whole pipeline on one device. Model CPU offload is deliberately NOT
    # combined with `.to(device)`: moving an offloaded pipeline afterwards defeats
    # the offload hooks, and offloading needs an accelerator the CPU fallback lacks.
    pipe.to(device)
    return pipe


# Model loading is deferred until first use (see ensure_models_loaded / get_img2img_pipe).
pipe = None
img2img_pipe = None


def ensure_models_loaded():
    """Populate the module-level ``pipe`` on the first call; later calls are no-ops."""
    # NOTE(review): this runs inside a @spaces.GPU function; confirm that globals
    # assigned there persist between calls on ZeroGPU, otherwise the model would be
    # reloaded per request and should be loaded at import time instead.
    global pipe
    if pipe is None:
        pipe = load_models()


def get_img2img_pipe():
    """Return an image-to-image FLUX pipeline that shares all weights with ``pipe``.

    Built once with ``from_pipe`` so no additional weights are loaded.

    Raises:
        ImportError: if the installed diffusers has no ``FluxImg2ImgPipeline``.
    """
    global img2img_pipe
    ensure_models_loaded()
    if img2img_pipe is None:
        # Imported lazily so that a diffusers build without this class only breaks
        # the img2img path (infer catches the error); text-to-image keeps working.
        from diffusers import FluxImg2ImgPipeline
        img2img_pipe = FluxImg2ImgPipeline.from_pipe(pipe)
    return img2img_pipe


def preprocess_image(image, image_size):
    """Resize a PIL image to a square tensor normalised to [-1, 1].

    Not used by ``infer``; kept as a standalone utility.

    Returns:
        A ``(1, C, image_size, image_size)`` tensor on ``device`` with dtype ``vae_dtype``.
    """
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        # Maps ToTensor's [0, 1] range to [-1, 1].
        transforms.Normalize([0.5], [0.5]),
    ])
    image = preprocess(image).unsqueeze(0).to(device, dtype=vae_dtype)
    print("Image processed successfully.")
    return image


def encode_image(image, vae):
    """Encode an image tensor into scaled latents with the given ``AutoencoderKL``.

    Not used by ``infer``; kept as a standalone utility.

    Raises:
        RuntimeError: re-raised after logging if the VAE fails to encode.
    """
    try:
        with torch.no_grad():
            # Prefer the VAE's own configured latent scale; 0.18215 (the value
            # previously hard-coded, used by SD v1.x VAEs) is the fallback.
            scaling_factor = getattr(vae.config, "scaling_factor", 0.18215)
            latents = vae.encode(image).latent_dist.sample() * scaling_factor
        print("Image encoded successfully.")
        return latents
    except RuntimeError as e:
        print(f"Error during image encoding: {e}")
        raise


@spaces.GPU()
def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024,
          num_inference_steps=4, progress=gr.Progress(track_tqdm=True),
          strength=DEFAULT_IMG2IMG_STRENGTH):
    """Generate an image with FLUX.1 [schnell], optionally starting from an input image.

    Args:
        prompt: Text prompt.
        init_image: Optional PIL image; when given, image-to-image is run on it.
        seed: RNG seed; replaced by a random one when ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Denoising steps (the model is distilled for 4).
        progress: Gradio progress tracker, injected by Gradio.
        strength: Image-to-image only; how far the result may depart from ``init_image``.

    Returns:
        ``(image, seed)`` — the generated image (or a solid red placeholder of the
        requested size if generation failed) and the seed that was used.
    """
    ensure_models_loaded()
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders may deliver floats; torch.Generator and PIL need ints.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    generator = torch.Generator(device=device).manual_seed(seed)
    fallback_image = Image.new("RGB", (width, height), (255, 0, 0))  # Red image as a fallback

    try:
        if init_image is None:
            # text2img case
            result = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0,
                max_sequence_length=256,
            )
        else:
            # img2img case: the dedicated pipeline encodes the image with FLUX's own
            # VAE, packs the latents and noises them according to `strength`.
            print("Initial image provided, starting img2img...")
            result = get_img2img_pipe()(
                prompt=prompt,
                image=init_image.convert("RGB"),
                height=height,
                width=width,
                strength=strength,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0,
                max_sequence_length=256,
            )
        image = result.images[0]
        return image, seed
    except Exception as e:
        # Best-effort UI: keep returning the placeholder, but log the full traceback.
        print(f"Error during inference: {e}")
        traceback.print_exc()
        return fallback_image, seed
# Example prompts shown below the interface (prompt only; other inputs use infer's defaults).
examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

# CSS styling for the Japanese-inspired interface.
# NOTE(review): the `.gr-*` selectors look like Gradio 3 class names; confirm they
# still match elements in the installed Gradio version, or attach `elem_classes`.
css = """
body {
    background-color: #fff;
    font-family: 'Noto Sans JP', sans-serif;
    color: #333;
}
#col-container {
    margin: 0 auto;
    max-width: 520px;
    border: 2px solid #000;
    padding: 20px;
    background-color: #f7f7f7;
    border-radius: 10px;
}
.gr-button {
    background-color: #e60012;
    color: #fff;
    border: 2px solid #000;
}
.gr-button:hover {
    background-color: #c20010;
}
.gr-slider, .gr-checkbox, .gr-textbox {
    border: 2px solid #000;
}
.gr-accordion {
    border: 2px solid #000;
    background-color: #fff;
}
.gr-image {
    border: 2px solid #000;
}
"""

# Create the Gradio interface
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# FLUX.1 [schnell]
12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation
[[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]
""")

        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        with gr.Row():
            init_image = gr.Image(label="Initial Image (optional)", type="pil")
            result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=4,
                )

        # Examples are generated on first click and then cached ("lazy").
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples="lazy",
        )

    # Run on the button AND on Enter in the single-line prompt box; previously
    # pressing Enter did nothing because only the button was wired.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, seed],
    )

demo.launch()