#!/usr/bin/env python3
"""
Gradio Application for Stable Diffusion
Author: Shilpaj Bhalerao
Date: Feb 26, 2025
"""
import gc
import os

import torch
import gradio as gr
# import spaces
from tqdm.auto import tqdm
from PIL import Image

from utils import (
    load_models, clear_gpu_memory, set_timesteps, latents_to_pil,
    vignette_loss, get_concept_embedding, image_grid
)

# Remove this import to avoid the cached_download error
# from diffusers import StableDiffusionPipeline
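
# NOTE: the helpers above (load_models, set_timesteps, latents_to_pil, vignette_loss,
# get_concept_embedding, image_grid) live in utils.py and are not shown in this file.
# For reference, get_concept_embedding is assumed to encode a style prompt with the
# CLIP text encoder, roughly like this (a sketch, not the actual utils implementation):
#
#   def get_concept_embedding(prompt, tokenizer, text_encoder, device):
#       tokens = tokenizer([prompt], padding="max_length",
#                          max_length=tokenizer.model_max_length,
#                          truncation=True, return_tensors="pt")
#       with torch.no_grad():
#           return text_encoder(tokens.input_ids.to(device))[0]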

def generate_latents(prompt, seed, num_inference_steps, guidance_scale, vignette_loss_scale,
                     concept, concept_strength, height, width):
    """
    Generate latents from the UNet denoising loop.

    :param prompt: Text prompt
    :param seed: Random seed used to initialise the latents
    :param num_inference_steps: Number of denoising steps
    :param guidance_scale: Classifier-free guidance scale
    :param vignette_loss_scale: Weight applied to the vignette guidance loss
    :param concept: Key into art_concepts selecting a style embedding (optional)
    :param concept_strength: How strongly to apply the concept (0.0-1.0)
    :param height: Output image height in pixels
    :param width: Output image width in pixels
    :return: Denoised latents; these are decoded by the VAE to produce the image
    """
    global art_concepts

    # Batch size
    batch_size = 1

    # Set the seed
    generator = torch.manual_seed(seed)

    # Prep text
    text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    # Get the concept embedding
    concept_embedding = art_concepts[concept]

    # Apply concept embedding influence if provided
    if concept_embedding is not None and concept_strength > 0:
        # Fix the dimension mismatch by adding a batch dimension to concept_embedding if needed
        if len(concept_embedding.shape) == 2 and len(text_embeddings.shape) == 3:
            # Add batch dimension to concept_embedding to match text_embeddings
            concept_embedding = concept_embedding.unsqueeze(0)

        # Create weighted blend between original text embedding and concept
        if text_embeddings.shape == concept_embedding.shape:
            # Interpolate between text embeddings and concept
            text_embeddings = (1 - concept_strength) * text_embeddings + concept_strength * concept_embedding
            print(f"Successfully applied concept with strength {concept_strength}")
        else:
            print(f"Warning: Shapes still incompatible after adjustment. Concept: {concept_embedding.shape}, Text: {text_embeddings.shape}")

    # And the uncond. input as before:
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
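
    # Note: the unconditional embeddings come first so that noise_pred.chunk(2) in the
    # denoising loop below yields (unconditional, text-conditioned) predictions in that order.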
    # Prep Scheduler
    set_timesteps(scheduler, num_inference_steps)

    # Prep latents
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(device)
    latents = latents * scheduler.init_noise_sigma

    # Denoising loop
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        # Expand the latents if we are doing classifier-free guidance to avoid doing two forward passes
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # Perform classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        #### ADDITIONAL GUIDANCE ###
        if i % 5 == 0:
            # Requires grad on the latents
            latents = latents.detach().requires_grad_()

            # Get the predicted x0:
            latents_x0 = latents - sigma * noise_pred
            # latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample

            # Decode to image space
            denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5  # range (0, 1)

            # Calculate loss
            loss = vignette_loss(denoised_images) * vignette_loss_scale

            # Occasionally print it out
            if i % 10 == 0:
                print(i, 'loss:', loss.item())

            # Get gradient
            cond_grad = torch.autograd.grad(loss, latents)[0]

            # Modify the latents based on this gradient
            latents = latents.detach() - cond_grad * sigma**2

        # Now step with scheduler
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents
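
# NOTE: vignette_loss is provided by utils and its exact implementation is not shown here.
# A minimal sketch of the kind of loss assumed (penalising bright pixels far from the
# image centre so that gradient guidance darkens the corners), for illustration only:
#
#   def vignette_loss(images):
#       # images: (batch, 3, H, W) in range (0, 1)
#       _, _, h, w = images.shape
#       ys = torch.linspace(-1, 1, h).view(1, 1, h, 1)
#       xs = torch.linspace(-1, 1, w).view(1, 1, 1, w)
#       radius = (xs ** 2 + ys ** 2).sqrt().to(images.device)
#       return (images * radius).mean()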

def generate_image(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
                   vignette_loss_scale=0.0, concept="none", concept_strength=0.5, height=512, width=512):
    """
    Generate a single image
    """
    global vae
    latents = generate_latents(prompt, seed, num_inference_steps, guidance_scale,
                               vignette_loss_scale, concept, concept_strength, height, width)
    generated_image = latents_to_pil(latents, vae)
    return image_grid(generated_image, 1, 1, None)
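
# NOTE: latents_to_pil comes from utils. It is assumed to follow the usual Stable
# Diffusion decoding recipe, roughly (a sketch, not the actual utils implementation):
#
#   def latents_to_pil(latents, vae):
#       latents = (1 / 0.18215) * latents
#       with torch.no_grad():
#           images = vae.decode(latents).sample
#       images = (images / 2 + 0.5).clamp(0, 1)
#       images = (images.permute(0, 2, 3, 1).cpu().numpy() * 255).round().astype("uint8")
#       return [Image.fromarray(img) for img in images]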

def generate_style_images(prompt, num_inference_steps=30, guidance_scale=7.5,
                          vignette_loss_scale=0.0, concept_strength=0.5, height=512, width=512):
    """
    Generate one image per style concept (each with its own seed) and return them as a grid
    """
    global art_concepts, vae
    seed_list = [2000, 1000, 500, 600, 100]
    latents_collect = []
    concept_labels = []

    # Load the concept names and remove the "none" entry
    concepts_list = list(art_concepts.keys())
    concepts_list.remove("none")

    for seed_no, concept in zip(seed_list, concepts_list):
        # Clear the CUDA cache
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.empty_cache()

        print(f"Generating image with concept '{concept}' at strength {concept_strength}")

        # Generate latents using the concept embedding
        latents = generate_latents(prompt, seed_no, num_inference_steps, guidance_scale,
                                   vignette_loss_scale, concept, concept_strength, height, width)
        latents_collect.append(latents)
        concept_labels.append(f"{concept} ({concept_strength})")

    # Show results
    latents_collect = torch.vstack(latents_collect)
    images = latents_to_pil(latents_collect, vae)
    return image_grid(images, 1, len(seed_list), concept_labels)
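
# NOTE: image_grid is also defined in utils; it is assumed to tile the list of PIL
# images into a rows x cols grid and, when labels are given, draw one label per cell.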

# Define Gradio interface
# @spaces.GPU(enable_queue=False)
def create_demo():
    with gr.Blocks(title="Guided Stable Diffusion with Styles") as demo:
        gr.Markdown("# Guided Stable Diffusion with Styles")

        with gr.Tab("Single Image Generation"):
            with gr.Row():
                with gr.Column():
                    # Put "none" first and keep each remaining style exactly once
                    all_styles = ["none"] + [style for style in art_concepts.keys() if style != "none"]
                    prompt = gr.Textbox(label="Prompt", placeholder="A cat sitting on a chair")
                    seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=1000)
                    concept_style = gr.Dropdown(choices=all_styles, label="Style Concept", value="none")
                    concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)
                    num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
                    height = gr.Slider(minimum=256, maximum=1024, step=1, label="Height", value=512)
                    width = gr.Slider(minimum=256, maximum=1024, step=1, label="Width", value=512)
                    guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=8.0)
                    vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=70.0)
                    generate_btn = gr.Button("Generate Image")
                with gr.Column():
                    output_image = gr.Image(label="Generated Image", type="pil")

        with gr.Tab("Style Grid"):
            with gr.Row():
                with gr.Column():
                    grid_prompt = gr.Textbox(label="Prompt", placeholder="A dog running in the park")
                    grid_num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
                    grid_guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=8.0)
                    grid_vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=70.0)
                    grid_concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)
                    grid_generate_btn = gr.Button("Generate Style Grid")
                with gr.Column():
                    output_grid = gr.Image(label="Style Grid", type="pil")

        # Set up event handlers
        generate_btn.click(
            generate_image,
            inputs=[prompt, seed, num_inference_steps, guidance_scale,
                    vignette_loss_scale, concept_style, concept_strength, height, width],
            outputs=output_image
        )
        grid_generate_btn.click(
            generate_style_images,
            inputs=[grid_prompt, grid_num_inference_steps,
                    grid_guidance_scale, grid_vignette_loss_scale, grid_concept_strength],
            outputs=output_grid
        )

    return demo

# Launch the app
if __name__ == "__main__":
    # Set device
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    if device == "mps":
        os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"

    # Load models
    vae, tokenizer, text_encoder, unet, scheduler, pipe = load_models(device=device)

    # Define art style concepts
    art_concepts = {
        "sketch_painting": get_concept_embedding("a sketch painting, pencil drawing, hand-drawn illustration", tokenizer, text_encoder, device),
        "oil_painting": get_concept_embedding("an oil painting, textured canvas, painterly technique", tokenizer, text_encoder, device),
        "watercolor": get_concept_embedding("a watercolor painting, fluid, soft edges", tokenizer, text_encoder, device),
        "digital_art": get_concept_embedding("digital art, computer generated, precise details", tokenizer, text_encoder, device),
        "comic_book": get_concept_embedding("comic book style, ink outlines, cel shading", tokenizer, text_encoder, device),
        "none": None
    }
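
    # Additional styles can be registered the same way, e.g. (hypothetical entry):
    # art_concepts["pixel_art"] = get_concept_embedding(
    #     "pixel art, 8-bit, retro video game style", tokenizer, text_encoder, device)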

    demo = create_demo()
    demo.launch(debug=True)