import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
import numpy as np
import os
from huggingface_hub import hf_hub_download
import warnings
from transformers import CLIPProcessor, CLIPModel
warnings.filterwarnings("ignore")
# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load CLIP model for semantic guidance
print("Loading CLIP model for semantic guidance...")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Dictionary of available concepts
CONCEPTS = {
"canna-lily-flowers102": {
"repo_id": "sd-concepts-library/canna-lily-flowers102",
"type": "object",
"description": "Canna lily flower style"
},
"samurai-jack": {
"repo_id": "sd-concepts-library/samurai-jack",
"type": "style",
"description": "Samurai Jack animation style"
},
"babies-poster": {
"repo_id": "sd-concepts-library/babies-poster",
"type": "style",
"description": "Babies poster art style"
},
"animal-toy": {
"repo_id": "sd-concepts-library/animal-toy",
"type": "object",
"description": "Animal toy style"
},
"sword-lily-flowers102": {
"repo_id": "sd-concepts-library/sword-lily-flowers102",
"type": "object",
"description": "Sword lily flower style"
}
}
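# Each entry above points at a textual-inversion embedding published in the
# SD Concepts Library. The dictionary key matches the concept's placeholder
# token, so prompts below can reference it as "<concept-name>".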
def car_loss(image):
    """CLIP-based score that is lower when the image is more likely to contain cars."""
    # Accept either a PIL image or a numpy array
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype(np.uint8))

    with torch.no_grad():
        # Compare the image against two contrasting text prompts
        inputs = clip_processor(
            text=["a photo of a car", "a photo without cars"],
            images=image,
            return_tensors="pt",
            padding=True
        ).to(device)

        outputs = clip_model(**inputs)
        logits_per_image = outputs.logits_per_image

        # Higher similarity to "a photo of a car" is better
        car_score = logits_per_image[0][0]
        no_car_score = logits_per_image[0][1]

        # Minimising the loss maximises the margin between the two scores
        loss = -(car_score - no_car_score)

    return loss
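# Rough usage sketch (hypothetical image `img`): the returned value is a scalar
# tensor; more negative means CLIP is more confident a car is present.
#   score = car_loss(img)
#   print(score.item())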
def generate_image(pipe, prompt, seed, guidance_scale=7.5, num_inference_steps=30, use_car_guidance=False):
    """Generate an image, optionally refined with CLIP-based car guidance."""
    generator = torch.Generator(device).manual_seed(seed)
    custom_loss = car_loss if use_car_guidance else None

    if custom_loss:
        try:
            # Stage 1: a standard text-to-image pass at half the step budget
            init_images = pipe(
                prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps // 2,
                generator=generator
            ).images
            init_image = init_images[0]

            # Stage 2: iterative img2img refinement, reusing the components of the
            # text-to-image pipeline so no extra weights are downloaded
            from diffusers import StableDiffusionImg2ImgPipeline
            img2img_pipe = StableDiffusionImg2ImgPipeline(
                vae=pipe.vae,
                text_encoder=pipe.text_encoder,
                tokenizer=pipe.tokenizer,
                unet=pipe.unet,
                scheduler=pipe.scheduler,
                safety_checker=None,
                feature_extractor=None,
                requires_safety_checker=False,
            ).to(device)

            strength = 0.75
            current_image = init_image
            for i in range(5):
                # The CLIP loss is used to monitor progress between refinement steps
                current_loss = custom_loss(current_image)
                print(f"Refinement step {i}: car loss = {current_loss.item():.4f}")

                refined_images = img2img_pipe(
                    prompt=prompt + ", with beautiful cars",
                    image=current_image,
                    strength=strength,
                    guidance_scale=guidance_scale,
                    generator=generator,
                ).images
                current_image = refined_images[0]

                # Gradually reduce the strength so later passes make smaller changes
                strength *= 0.8

            return current_image
        except Exception as e:
            # Fall back to a plain generation if the guided path fails
            print(f"Error in car-guided generation: {e}")
            return pipe(
                prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator
            ).images[0]
    else:
        return pipe(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator
        ).images[0]
# Cache for loaded models and concepts
loaded_models = {}
def get_model_with_concept(concept_name):
    """Get or load a pipeline with the specified textual-inversion concept."""
    if concept_name not in loaded_models:
        concept_info = CONCEPTS[concept_name]

        # Download the learned embedding for this concept if it is not cached locally
        concept_path = f"concepts/{concept_name}.bin"
        os.makedirs("concepts", exist_ok=True)
        if not os.path.exists(concept_path):
            file = hf_hub_download(
                repo_id=concept_info["repo_id"],
                filename="learned_embeds.bin",
                repo_type="model"
            )
            import shutil
            shutil.copy(file, concept_path)

        # Load the base model and attach the concept embedding.
        # SD Concepts Library embeddings were trained against the SD 1.x text
        # encoder, so a 1.x checkpoint is used here; SD 2 uses a different,
        # incompatible text encoder.
        pipe = StableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            torch_dtype=torch.float32 if device == "cpu" else torch.float16,
            safety_checker=None
        ).to(device)
        pipe.load_textual_inversion(concept_path)
        loaded_models[concept_name] = pipe
    return loaded_models[concept_name]
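# Note: each cached pipeline keeps its own copy of the diffusion weights in
# memory, so loading many concepts in one session can be memory-hungry.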
def generate_images(concept_name, base_prompt, seed, use_car_guidance):
"""Generate images using the selected concept"""
try:
# Get model with concept
pipe = get_model_with_concept(concept_name)
        # Construct the prompt based on the concept type: objects are appended to
        # the scene, while styles prefix the whole prompt
        if CONCEPTS[concept_name]["type"] == "object":
            prompt = f"{base_prompt} with a <{concept_name}>"
        else:
            prompt = f"<{concept_name}> {base_prompt}"
# Generate image
image = generate_image(
pipe=pipe,
prompt=prompt,
seed=int(seed),
use_car_guidance=use_car_guidance
)
return image
except Exception as e:
raise gr.Error(f"Error generating image: {str(e)}")
# Create Gradio interface
with gr.Blocks(title="Stable Diffusion Style Explorer") as demo:
gr.Markdown("""
# Stable Diffusion Style Explorer
Generate images using various concepts from the SD Concepts Library, with optional car guidance.
## How to use:
1. Select a concept from the dropdown
2. Enter a base prompt (or use the default)
3. Set a seed for reproducibility
4. Choose whether to use car guidance
5. Click Generate!
Check out the examples below to see different combinations of concepts and prompts!
""")
with gr.Row():
with gr.Column():
concept = gr.Dropdown(
choices=list(CONCEPTS.keys()),
value="samurai-jack",
label="Select Concept"
)
prompt = gr.Textbox(
value="A serene landscape with mountains and a lake at sunset",
label="Base Prompt"
)
seed = gr.Number(
value=42,
label="Seed",
precision=0
)
car_guidance = gr.Checkbox(
value=False,
label="Use Car Guidance"
)
generate_btn = gr.Button("Generate Image")
with gr.Column():
output_image = gr.Image(label="Generated Image")
    # The description component must be created inside the Blocks layout,
    # otherwise updates from the change event would not be displayed
    concept_description = gr.Markdown(
        value=f"Selected concept: {CONCEPTS['samurai-jack']['description']} ({CONCEPTS['samurai-jack']['type']})"
    )

    concept.change(
        fn=lambda x: f"Selected concept: {CONCEPTS[x]['description']} ({CONCEPTS[x]['type']})",
        inputs=[concept],
        outputs=[concept_description]
    )
generate_btn.click(
fn=generate_images,
inputs=[concept, prompt, seed, car_guidance],
outputs=[output_image]
)
# Gallery of pre-generated examples
gr.Markdown("### 🖼️ Pre-generated Examples")
with gr.Row():
# Samurai Jack examples
with gr.Column():
gr.Markdown("**Samurai Jack Style**")
gr.Image("Assignment17/Assignment17/outputs/samurai-jack_normal.png",
label="Without Car Guidance")
gr.Image("Assignment17/Assignment17/outputs/samurai-jack_car.png",
label="With Car Guidance")
with gr.Row():
# Canna Lily examples
with gr.Column():
gr.Markdown("**Canna Lily Object**")
gr.Image("Assignment17/Assignment17/outputs/canna-lily-flowers102_normal.png",
label="Without Car Guidance")
gr.Image("Assignment17/Assignment17/outputs/canna-lily-flowers102_car.png",
label="With Car Guidance")
with gr.Row():
# Babies Poster examples
with gr.Column():
gr.Markdown("**Babies Poster Style**")
gr.Image("Assignment17/Assignment17/outputs/babies-poster_normal.png",
label="Without Car Guidance")
gr.Image("Assignment17/Assignment17/outputs/babies-poster_car.png",
label="With Car Guidance")
with gr.Row():
# Animal Toy examples
with gr.Column():
gr.Markdown("**Animal Toy Object**")
gr.Image("Assignment17/Assignment17/outputs/animal-toy_normal.png",
label="Without Car Guidance")
gr.Image("Assignment17/Assignment17/outputs/animal-toy_car.png",
label="With Car Guidance")
with gr.Row():
# Sword Lily examples
with gr.Column():
gr.Markdown("**Sword Lily Object**")
gr.Image("Assignment17/Assignment17/outputs/sword-lily-flowers102_normal.png",
label="Without Car Guidance")
gr.Image("Assignment17/Assignment17/outputs/sword-lily-flowers102_car.png",
label="With Car Guidance")
demo.launch()