import os
import torch
import gradio as gr
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms as tfms
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler  # DPM-Solver scheduler

# 1. Device and dtype: use float16 on CUDA, float32 on CPU.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch_device == "cuda" else torch.float32
print(f"Using device: {torch_device}, dtype: {torch_dtype}")  # Helpful for debugging

# 2. Model path and loading: use a more efficient scheduler and reduce memory usage.
model_path = "CompVis/stable-diffusion-v1-4"

# Use DPMSolverMultistepScheduler for faster, higher-quality sampling at low step counts.
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler")

sd_pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    torch_dtype=torch_dtype,
    scheduler=scheduler,  # use the DPM-Solver scheduler
    # low_cpu_mem_usage lowers peak RAM while the weights are loaded; it matters most on CPU.
    low_cpu_mem_usage=True if torch_device == "cpu" else False,
    safety_checker=None,  # remove the safety checker to avoid false positives blocking image generation
    requires_safety_checker=False,
).to(torch_device)

# Optimize attention for memory efficiency (if using CUDA).
if torch_device == "cuda":
    try:
        sd_pipeline.enable_xformers_memory_efficient_attention()  # use xformers if it is installed
    except Exception:
        # Fallback: attention slicing is built in; it trades a little speed for a large VRAM saving.
        sd_pipeline.enable_attention_slicing()

# 3. Textual inversion loading: load *only* the necessary concepts, one by one.
# This is much more memory efficient than loading everything at once.
# The placeholder tokens below are registered explicitly via `token=` in load_inversion(),
# so they only need to be unique strings that match what gets appended to the prompt.
style_token_dict = {
    "Illustration Style": "<illustration-style>",
    "Line Art": "<line-art>",
    "Hitokomoru Style": "<hitokomoru-style-nao>",
    "Marc Allante": "<style-of-marc-allante>",
    "Midjourney": "<midjourney-style>",
    "Hanfu Anime": "<hanfu-anime-style>",
    "Birb Style": "<birb-style>",
}

# Load inversions individually. This is crucial for managing memory.
def load_inversion(concept_name, token):
    try:
        sd_pipeline.load_textual_inversion(f"sd-concepts-library/{concept_name}", token=token)
        print(f"Loaded textual inversion: {concept_name}")
    except Exception as e:
        print(f"Error loading {concept_name}: {e}")

# Load each style individually.
load_inversion("illustration-style", style_token_dict["Illustration Style"])
load_inversion("line-art", style_token_dict["Line Art"])
load_inversion("hitokomoru-style-nao", style_token_dict["Hitokomoru Style"])
load_inversion("style-of-marc-allante", style_token_dict["Marc Allante"])
load_inversion("midjourney-style", style_token_dict["Midjourney"])
load_inversion("hanfu-anime-style", style_token_dict["Hanfu Anime"])
load_inversion("birb-style", style_token_dict["Birb Style"])
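# Optional sanity check (a small sketch, not required for the app): verify that each
# placeholder token in style_token_dict was registered with the pipeline's tokenizer by
# load_textual_inversion. Tokens that failed to load resolve to the tokenizer's unk id.
for style_name, style_token in style_token_dict.items():
    token_id = sd_pipeline.tokenizer.convert_tokens_to_ids(style_token)
    if token_id == sd_pipeline.tokenizer.unk_token_id:
        print(f"Warning: token for '{style_name}' ({style_token}) was not registered")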
# 4. Guidance Function: Optimized for speed and clarity.
def apply_guidance(image, guidance_method, loss_scale):
    img_tensor = tfms.ToTensor()(image).unsqueeze(0).to(torch_device)
    loss_scale = loss_scale / 10000.0  # normalize the slider range (100-10000) to roughly 0.01-1.0

    if guidance_method == 'Grayscale':
        gray = tfms.Grayscale(num_output_channels=3)(img_tensor)  # keep 3 channels
        guided = img_tensor + (gray - img_tensor) * loss_scale
    elif guidance_method == 'Bright':
        guided = torch.clamp(img_tensor * (1 + loss_scale), 0, 1)  # direct brightness adjustment
    elif guidance_method == 'Contrast':
        mean = img_tensor.mean()
        guided = torch.clamp((img_tensor - mean) * (1 + loss_scale) + mean, 0, 1)  # contrast adjustment
    elif guidance_method == 'Symmetry':
        flipped = torch.flip(img_tensor, [3])  # flip along the width dimension
        guided = img_tensor + (flipped - img_tensor) * loss_scale
    elif guidance_method == 'Saturation':
        # Use torchvision's functional API for efficiency.
        guided = tfms.functional.adjust_saturation(img_tensor, 1 + loss_scale)
        guided = torch.clamp(guided, 0, 1)
    else:
        return image

    # Convert back to a PIL image.
    return tfms.ToPILImage()(guided.squeeze(0).cpu())

# 5. Inference Function: Use the pipeline efficiently.
def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale, image_size):
    prompt = f"{text} {style_token_dict[style]}"
    width, height = map(int, image_size.split('x'))
    generator = torch.Generator(device=torch_device).manual_seed(int(seed))

    # Generate the raw image with the pipeline.
    image_pipeline = sd_pipeline(
        prompt,
        num_inference_steps=int(inference_step),
        guidance_scale=guidance_scale,
        generator=generator,
        height=height,
        width=width,
    ).images[0]

    image_guide = apply_guidance(image_pipeline, guidance_method, loss_scale)
    return image_pipeline, image_guide
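# Optional smoke test (a hedged sketch; the RUN_SMOKE_TEST flag, prompt, and file names are
# illustrative and not part of the app): calls inference() directly to verify the pipeline
# and guidance work before the Gradio UI is wired up.
RUN_SMOKE_TEST = False
if RUN_SMOKE_TEST:
    raw_img, guided_img = inference(
        "A dreamy castle in the clouds", "Illustration Style",
        inference_step=20, guidance_scale=7.5, seed=42,
        guidance_method="Grayscale", loss_scale=200, image_size="256x256",
    )
    raw_img.save("smoke_raw.png")
    guided_img.save("smoke_guided.png")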

# 6. Gradio Interface (CSS and HTML remain largely the same, but I've included minor improvements)
css_and_html = """
<div style="text-align: center; padding: 1.5em 0;">
    <h1 style="margin-bottom: 0.2em;">Dreamscape Creator</h1>
    <p style="font-size: 1.1em;">Unleash your imagination with AI-powered generative art</p>
    <div style="display: flex; justify-content: center; gap: 1.5em; margin-top: 1em;">
        <span>🎨 Illustration Style</span>
        <span>✏️ Line Art</span>
        <span>🌌 Midjourney Style</span>
        <span>👘 Hanfu Anime</span>
    </div>
</div>
"""

with gr.Blocks() as demo:
    gr.HTML(css_and_html)  # header markup; its styling is inline, so no separate css argument is needed

    with gr.Row():
        text = gr.Textbox(label="Prompt", placeholder="Describe your dreamscape...")
        style = gr.Dropdown(label="Style", choices=list(style_token_dict.keys()), value="Illustration Style")

    with gr.Row():
        inference_step = gr.Slider(1, 50, 20, step=1, label="Inference steps")
        guidance_scale = gr.Slider(1, 10, 7.5, step=0.1, label="Guidance scale")
        seed = gr.Slider(0, 10000, 42, step=1, label="Seed", randomize=True)  # pick a random seed on load

    with gr.Row():
        guidance_method = gr.Dropdown(label="Guidance method",
                                      choices=['Grayscale', 'Bright', 'Contrast', 'Symmetry', 'Saturation'],
                                      value="Grayscale")
        loss_scale = gr.Slider(100, 10000, 200, step=100, label="Loss scale")

    with gr.Row():
        image_size = gr.Radio(["256x256", "512x512"], label="Image Size", value="256x256")

    with gr.Row():
        generate_button = gr.Button("Create Dreamscape", variant="primary")

    with gr.Row():
        output_image = gr.Image(label="Your Dreamscape", interactive=False)  # display only
        output_image_guided = gr.Image(label="Guided Dreamscape", interactive=False)  # display only

    generate_button.click(
        inference,
        inputs=[text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale, image_size],
        outputs=[output_image, output_image_guided],
    )

    gr.Examples(
        examples=[
            ["Magical Forest with Glowing Trees", 'Birb Style', 40, 7.5, 42, 'Grayscale', 200, "256x256"],
            ["Ancient Temple Ruins at Sunset", 'Midjourney', 30, 8.0, 123, 'Bright', 5678, "256x256"],
            ["Japanese garden with cherry blossoms", 'Hitokomoru Style', 40, 7.0, 789, 'Contrast', 250, "256x256"],
        ],
        inputs=[text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale, image_size],
        outputs=[output_image, output_image_guided],
        fn=inference,
        # Caching examples can be problematic on Spaces, especially with limited RAM, so it is disabled.
        cache_examples=False,
        examples_per_page=5,
    )

if __name__ == "__main__":
    demo.launch()