#!/usr/bin/env python3
"""
Gradio Application for Stable Diffusion
Author: Shilpaj Bhalerao
Date: Feb 26, 2025
"""
import gc
import os
import torch
import gradio as gr
# import spaces
from tqdm.auto import tqdm
from PIL import Image
from utils import (
load_models, clear_gpu_memory, set_timesteps, latents_to_pil,
vignette_loss, get_concept_embedding, image_grid
)
# Remove this import to avoid the cached_download error
# from diffusers import StableDiffusionPipeline
def generate_latents(prompt, seed, num_inference_steps, guidance_scale, vignette_loss_scale, concept, concept_strength, height, width):
"""
    Function to generate latents from the UNet
    :param prompt: Text prompt
    :param seed: Seed for the random number generator
    :param num_inference_steps: Number of denoising steps
    :param guidance_scale: Classifier-free guidance scale
    :param vignette_loss_scale: Scale of the vignette guidance loss (0 effectively disables it)
    :param concept: Style concept to influence generation (optional)
    :param concept_strength: How strongly to apply the concept (0.0-1.0)
    :param height: Output image height in pixels
    :param width: Output image width in pixels
    :return: Latents from the UNet. These will be passed to the VAE to generate the image
"""
global art_concepts
# Batch size
batch_size = 1
# Set the seed
generator = torch.manual_seed(seed)
# Prep text
text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
# Get the concept embedding
concept_embedding = art_concepts[concept]
# Apply concept embedding influence if provided
if concept_embedding is not None and concept_strength > 0:
# Fix the dimension mismatch by adding a batch dimension to concept_embedding if needed
if len(concept_embedding.shape) == 2 and len(text_embeddings.shape) == 3:
# Add batch dimension to concept_embedding to match text_embeddings
concept_embedding = concept_embedding.unsqueeze(0)
# Create weighted blend between original text embedding and concept
if text_embeddings.shape == concept_embedding.shape:
# Interpolate between text embeddings and concept
text_embeddings = (1 - concept_strength) * text_embeddings + concept_strength * concept_embedding
print(f"Successfully applied concept with strength {concept_strength}")
else:
print(f"Warning: Shapes still incompatible after adjustment. Concept: {concept_embedding.shape}, Text: {text_embeddings.shape}")
    # Build the unconditional (empty-prompt) embeddings for classifier-free guidance
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
with torch.no_grad():
uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
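    # The unconditional and conditional embeddings are batched together so the UNet
    # only needs a single forward pass per step for classifier-free guidance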
# Prep Scheduler
set_timesteps(scheduler, num_inference_steps)
# Prep latents
latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
generator=generator,
)
latents = latents.to(device)
latents = latents * scheduler.init_noise_sigma
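    # Scaling by init_noise_sigma puts the random latents at the noise level the
    # scheduler expects at the first timestep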
# Loop
for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latent_model_input = torch.cat([latents] * 2)
sigma = scheduler.sigmas[i]
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
with torch.no_grad():
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
# perform CFG
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
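        # Classifier-free guidance: push the prediction away from the unconditional
        # output and towards the prompt-conditioned one, uncond + scale * (cond - uncond)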
#### ADDITIONAL GUIDANCE ###
        # Apply the extra vignette guidance every 5 steps; skip it entirely when the scale is 0
        if vignette_loss_scale > 0 and i % 5 == 0:
# Requires grad on the latents
latents = latents.detach().requires_grad_()
# Get the predicted x0:
latents_x0 = latents - sigma * noise_pred
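            # Estimate the fully denoised latents from the current noisy latents; the
            # x0 = x_t - sigma * eps form assumes a sigma-based (k-diffusion style) scheduler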
# latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample
# Decode to image space
denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
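            # 0.18215 is the Stable Diffusion VAE latent scaling factor; dividing by it
            # before decoding and mapping sample / 2 + 0.5 brings the output into roughly (0, 1)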
# Calculate loss
loss = vignette_loss(denoised_images) * vignette_loss_scale
# Occasionally print it out
            if i % 10 == 0:
print(i, 'loss:', loss.item())
# Get gradient
cond_grad = torch.autograd.grad(loss, latents)[0]
# Modify the latents based on this gradient
latents = latents.detach() - cond_grad * sigma**2
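            # Nudge the latents down the gradient of the vignette loss; scaling the step
            # by sigma**2 keeps the adjustment proportional to the current noise level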
# Now step with scheduler
latents = scheduler.step(noise_pred, t, latents).prev_sample
return latents
def generate_image(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
vignette_loss_scale=0.0, concept="none", concept_strength=0.5, height=512, width=512):
"""
    Generate a single image for the given prompt and return it as a 1x1 image grid
"""
global vae
latents = generate_latents(prompt, seed, num_inference_steps, guidance_scale, vignette_loss_scale, concept, concept_strength, height, width)
generated_image = latents_to_pil(latents, vae)
return image_grid(generated_image, 1, 1, None)
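# Example usage (assumes the models and art_concepts set up in the __main__ block are loaded):
#   image = generate_image("A cat sitting on a chair", seed=1000,
#                          concept="watercolor", concept_strength=0.5)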
def generate_style_images(prompt, num_inference_steps=30, guidance_scale=7.5,
vignette_loss_scale=0.0, concept_strength=0.5, height=512, width=512):
"""
    Generate one image per art style concept, each with its own fixed seed, and return them as a labeled grid
"""
global art_concepts, vae
seed_list = [2000, 1000, 500, 600, 100]
latents_collect = []
concept_labels = []
    # Collect the style concept names, excluding "none"
concepts_list = list(art_concepts.keys())
concepts_list.remove("none")
for seed_no, concept in zip(seed_list, concepts_list):
        # Free cached GPU memory between generations
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
print(f"Generating image with concept '{concept}' at strength {concept_strength}")
# Generate latents using the concept embedding
latents = generate_latents(prompt, seed_no, num_inference_steps, guidance_scale, vignette_loss_scale, concept, concept_strength, height, width)
latents_collect.append(latents)
concept_labels.append(f"{concept} ({concept_strength})")
# Show results
latents_collect = torch.vstack(latents_collect)
images = latents_to_pil(latents_collect, vae)
return image_grid(images, 1, len(seed_list), concept_labels)
# Define Gradio interface
# @spaces.GPU(enable_queue=False)
def create_demo():
with gr.Blocks(title="Guided Stable Diffusion with Styles") as demo:
gr.Markdown("# Guided Stable Diffusion with Styles")
with gr.Tab("Single Image Generation"):
with gr.Row():
with gr.Column():
                    # Put "none" first, then the style concepts (avoids listing "none" twice)
                    all_styles = ["none"] + [name for name in art_concepts if name != "none"]
prompt = gr.Textbox(label="Prompt", placeholder="A cat sitting on a chair")
seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=1000)
concept_style = gr.Dropdown(choices=all_styles, label="Style Concept", value="none")
concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)
num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
                    # Step of 8 keeps the dimensions compatible with the VAE's 8x downsampling
                    height = gr.Slider(minimum=256, maximum=1024, step=8, label="Height", value=512)
                    width = gr.Slider(minimum=256, maximum=1024, step=8, label="Width", value=512)
guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=8.0)
vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=70.0)
generate_btn = gr.Button("Generate Image")
with gr.Column():
output_image = gr.Image(label="Generated Image", type="pil")
with gr.Tab("Style Grid"):
with gr.Row():
with gr.Column():
grid_prompt = gr.Textbox(label="Prompt", placeholder="A dog running in the park")
grid_num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
grid_guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=8.0)
grid_vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=70.0)
grid_concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)
grid_generate_btn = gr.Button("Generate Style Grid")
with gr.Column():
output_grid = gr.Image(label="Style Grid", type="pil")
# Set up event handlers
generate_btn.click(
generate_image,
inputs=[prompt, seed, num_inference_steps, guidance_scale,
vignette_loss_scale, concept_style, concept_strength, height, width],
outputs=output_image
)
grid_generate_btn.click(
generate_style_images,
inputs=[grid_prompt, grid_num_inference_steps,
grid_guidance_scale, grid_vignette_loss_scale, grid_concept_strength],
outputs=output_grid
)
return demo
# Launch the app
if __name__ == "__main__":
# Set device
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
if device == "mps":
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
# Load models
vae, tokenizer, text_encoder, unet, scheduler, pipe = load_models(device=device)
# Define art style concepts
art_concepts = {
"sketch_painting": get_concept_embedding("a sketch painting, pencil drawing, hand-drawn illustration", tokenizer, text_encoder, device),
"oil_painting": get_concept_embedding("an oil painting, textured canvas, painterly technique", tokenizer, text_encoder, device),
"watercolor": get_concept_embedding("a watercolor painting, fluid, soft edges", tokenizer, text_encoder, device),
"digital_art": get_concept_embedding("digital art, computer generated, precise details", tokenizer, text_encoder, device),
"comic_book": get_concept_embedding("comic book style, ink outlines, cel shading", tokenizer, text_encoder, device),
"none": None
}
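    # Each style maps to an embedding returned by get_concept_embedding in utils (presumably
    # the text-encoder output for the style prompt); "none" maps to None so no blending is applied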
demo = create_demo()
demo.launch(debug=True)