import os
import shutil
import warnings

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

warnings.filterwarnings("ignore")

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load CLIP model for semantic guidance
print("Loading CLIP model for semantic guidance...")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
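
# Note: the SD pipelines below use float16 on CUDA, but the CLIP scorer stays
# in float32. An optional memory saving (an assumption, not in the original
# app) would be:
#
#   if device == "cuda":
#       clip_model = clip_model.half()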

# Dictionary of available concepts from the SD Concepts Library
CONCEPTS = {
    "canna-lily-flowers102": {
        "repo_id": "sd-concepts-library/canna-lily-flowers102",
        "type": "object",
        "description": "Canna lily flower style",
    },
    "samurai-jack": {
        "repo_id": "sd-concepts-library/samurai-jack",
        "type": "style",
        "description": "Samurai Jack animation style",
    },
    "babies-poster": {
        "repo_id": "sd-concepts-library/babies-poster",
        "type": "style",
        "description": "Babies poster art style",
    },
    "animal-toy": {
        "repo_id": "sd-concepts-library/animal-toy",
        "type": "object",
        "description": "Animal toy style",
    },
    "sword-lily-flowers102": {
        "repo_id": "sd-concepts-library/sword-lily-flowers102",
        "type": "object",
        "description": "Sword lily flower style",
    },
}
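
# To expose another concept from the SD Concepts Library, add an entry like
# the hypothetical one below. The dictionary key is used as the "<key>"
# placeholder when building prompts, so it should match the token stored in
# the repo's learned_embeds.bin:
#
#   CONCEPTS["my-concept"] = {
#       "repo_id": "sd-concepts-library/my-concept",  # hypothetical repo
#       "type": "style",  # "style" prefixes the prompt, "object" is inserted
#       "description": "My concept style",
#   }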

def car_loss(image):
    """CLIP-based loss that is lower when the image looks more car-like.

    Scores are computed under torch.no_grad(), so this value cannot be
    backpropagated; it only serves as a scalar quality signal.
    """
    # Normalize the input to a PIL image for the CLIP processor
    if isinstance(image, Image.Image):
        pil_image = image
    else:
        pil_image = Image.fromarray(
            torch.as_tensor(image).detach().cpu().numpy().astype(np.uint8)
        )
    with torch.no_grad():
        inputs = clip_processor(
            text=["a photo of a car", "a photo without cars"],
            images=pil_image,
            return_tensors="pt",
            padding=True,
        ).to(device)
        # One similarity logit per caption: [car, no car]
        logits_per_image = clip_model(**inputs).logits_per_image
        car_score = logits_per_image[0][0]
        no_car_score = logits_per_image[0][1]
    # We want to maximize the car/no-car margin, so its negation is the loss
    return -(car_score - no_car_score)
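
# Minimal usage sketch for car_loss (the file path is an assumption):
#
#   loss = car_loss(Image.open("some_image.png").convert("RGB"))
#   print(f"car loss: {loss.item():.3f}")  # lower = more car-like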

def generate_image(pipe, prompt, seed, guidance_scale=7.5, num_inference_steps=30, use_car_guidance=False):
    """Generate an image, optionally refined with CLIP-based car guidance."""
    generator = torch.Generator(device=device).manual_seed(seed)
    if not use_car_guidance:
        return pipe(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]
    try:
        # Start with a standard generation at half the step budget
        init_image = pipe(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps // 2,
            generator=generator,
        ).images[0]
        # Reuse the already-loaded components for an img2img refinement pipeline
        img2img_pipe = StableDiffusionImg2ImgPipeline(
            vae=pipe.vae,
            text_encoder=pipe.text_encoder,
            tokenizer=pipe.tokenizer,
            unet=pipe.unet,
            scheduler=pipe.scheduler,
            safety_checker=None,
            feature_extractor=None,
        ).to(device)
        # Iteratively refine with a car-augmented prompt at decreasing strength.
        # car_loss is not differentiable (see above), so it is logged as a
        # progress signal rather than steering the sampler directly.
        strength = 0.75
        current_image = init_image
        for i in range(5):
            current_loss = car_loss(current_image)
            print(f"Refinement {i + 1}/5, car loss: {current_loss.item():.3f}")
            current_image = img2img_pipe(
                prompt=prompt + ", with beautiful cars",
                image=current_image,
                strength=strength,
                guidance_scale=guidance_scale,
                generator=generator,
            ).images[0]
            strength *= 0.8
        return current_image
    except Exception as e:
        # Fall back to a plain generation if the guided path fails
        print(f"Error in car-guided generation: {e}")
        return pipe(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]
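
# Standalone usage sketch for generate_image (assumes a pipeline loaded via
# get_model_with_concept below; prompt and filename are illustrative):
#
#   pipe = get_model_with_concept("samurai-jack")
#   img = generate_image(pipe, "<samurai-jack> a castle at dusk", seed=42)
#   img.save("castle.png")
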
# Cache for loaded models and concepts
loaded_models = {}

def get_model_with_concept(concept_name):
    """Get or load a Stable Diffusion pipeline with the specified concept."""
    if concept_name not in loaded_models:
        concept_info = CONCEPTS[concept_name]
        # Download the concept's textual-inversion embedding if not cached
        concept_path = f"concepts/{concept_name}.bin"
        os.makedirs("concepts", exist_ok=True)
        if not os.path.exists(concept_path):
            file = hf_hub_download(
                repo_id=concept_info["repo_id"],
                filename="learned_embeds.bin",
                repo_type="model",
            )
            shutil.copy(file, concept_path)
        # Load the base model and attach the concept embedding
        pipe = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2",
            torch_dtype=torch.float32 if device == "cpu" else torch.float16,
            safety_checker=None,
        ).to(device)
        pipe.load_textual_inversion(concept_path)
        loaded_models[concept_name] = pipe
    return loaded_models[concept_name]
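
# Each cached pipeline keeps a full SD2 checkpoint in memory. On smaller GPUs
# the cache can be evicted before loading another concept (an optional tweak,
# not part of the original app):
#
#   loaded_models.clear()
#   if device == "cuda":
#       torch.cuda.empty_cache()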

def generate_images(concept_name, base_prompt, seed, use_car_guidance):
    """Generate an image using the selected concept."""
    try:
        pipe = get_model_with_concept(concept_name)
        # Object concepts are placed in the scene; style concepts prefix the prompt
        if CONCEPTS[concept_name]["type"] == "object":
            prompt = f"A {base_prompt} with a <{concept_name}>"
        else:
            prompt = f"<{concept_name}> {base_prompt}"
        image = generate_image(
            pipe=pipe,
            prompt=prompt,
            seed=int(seed),
            use_car_guidance=use_car_guidance,
        )
        return image
    except Exception as e:
        raise gr.Error(f"Error generating image: {e}")
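
# For reference, the prompts built above take these shapes (the angle-bracket
# token is the textual-inversion placeholder loaded for the concept):
#   object: "A {base_prompt} with a <canna-lily-flowers102>"
#   style:  "<samurai-jack> {base_prompt}"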

# Create Gradio interface
with gr.Blocks(title="Stable Diffusion Style Explorer") as demo:
    gr.Markdown("""
    # Stable Diffusion Style Explorer

    Generate images using various concepts from the SD Concepts Library, with optional car guidance.

    ## How to use:
    1. Select a concept from the dropdown
    2. Enter a base prompt (or use the default)
    3. Set a seed for reproducibility
    4. Choose whether to use car guidance
    5. Click Generate!

    Check out the examples below to see different combinations of concepts and prompts!
    """)
    with gr.Row():
        with gr.Column():
            concept = gr.Dropdown(
                choices=list(CONCEPTS.keys()),
                value="samurai-jack",
                label="Select Concept",
            )
            # Defined in the layout so the change handler below can update it
            concept_info = gr.Markdown(
                f"Selected concept: {CONCEPTS['samurai-jack']['description']} ({CONCEPTS['samurai-jack']['type']})"
            )
            prompt = gr.Textbox(
                value="A serene landscape with mountains and a lake at sunset",
                label="Base Prompt",
            )
            seed = gr.Number(
                value=42,
                label="Seed",
                precision=0,
            )
            car_guidance = gr.Checkbox(
                value=False,
                label="Use Car Guidance",
            )
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            output_image = gr.Image(label="Generated Image")

    # Update the description when the dropdown changes. The handler returns a
    # string to the Markdown component defined above; components created
    # inside an event handler are never rendered.
    concept.change(
        fn=lambda x: f"Selected concept: {CONCEPTS[x]['description']} ({CONCEPTS[x]['type']})",
        inputs=[concept],
        outputs=[concept_info],
    )
    generate_btn.click(
        fn=generate_images,
        inputs=[concept, prompt, seed, car_guidance],
        outputs=[output_image],
    )

    # Gallery of pre-generated examples
    gr.Markdown("### 🖼️ Pre-generated Examples")
    example_concepts = [
        ("Samurai Jack Style", "samurai-jack"),
        ("Canna Lily Object", "canna-lily-flowers102"),
        ("Babies Poster Style", "babies-poster"),
        ("Animal Toy Object", "animal-toy"),
        ("Sword Lily Object", "sword-lily-flowers102"),
    ]
    for title, name in example_concepts:
        with gr.Row():
            with gr.Column():
                gr.Markdown(f"**{title}**")
                gr.Image(f"Assignment17/Assignment17/outputs/{name}_normal.png",
                         label="Without Car Guidance")
                gr.Image(f"Assignment17/Assignment17/outputs/{name}_car.png",
                         label="With Car Guidance")

if __name__ == "__main__":
    demo.launch()