jatingocodeo committed
Commit 6fca3f3 (verified) · 1 Parent(s): e1378aa

Update app.py

Files changed (1): app.py +311 -0
app.py CHANGED
@@ -0,0 +1,311 @@
+ import gradio as gr
+ import torch
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
+ from PIL import Image
+ import numpy as np
+ import os
+ import shutil
+ from huggingface_hub import hf_hub_download
+ import warnings
+ from transformers import CLIPProcessor, CLIPModel
+ warnings.filterwarnings("ignore")
+
+ # Check if CUDA is available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ # Load CLIP model for semantic guidance
+ print("Loading CLIP model for semantic guidance...")
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
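+ # How the CLIP guidance signal works (an illustrative sketch; the shapes are
+ # standard CLIP zero-shot behaviour, nothing app-specific):
+ #   inputs = clip_processor(text=["a photo of a car", "a photo without cars"],
+ #                           images=pil_image, return_tensors="pt", padding=True)
+ #   logits = clip_model(**inputs).logits_per_image   # shape (1, 2)
+ # The two logits score image-text similarity; car_loss() below turns their
+ # difference into a single scalar to minimise.
+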
+ # Dictionary of available concepts
+ CONCEPTS = {
+     "canna-lily-flowers102": {
+         "repo_id": "sd-concepts-library/canna-lily-flowers102",
+         "type": "object",
+         "description": "Canna lily flower style"
+     },
+     "samurai-jack": {
+         "repo_id": "sd-concepts-library/samurai-jack",
+         "type": "style",
+         "description": "Samurai Jack animation style"
+     },
+     "babies-poster": {
+         "repo_id": "sd-concepts-library/babies-poster",
+         "type": "style",
+         "description": "Babies poster art style"
+     },
+     "animal-toy": {
+         "repo_id": "sd-concepts-library/animal-toy",
+         "type": "object",
+         "description": "Animal toy style"
+     },
+     "sword-lily-flowers102": {
+         "repo_id": "sd-concepts-library/sword-lily-flowers102",
+         "type": "object",
+         "description": "Sword lily flower style"
+     }
+ }
+
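+ # The "type" field drives prompt construction in generate_images() below:
+ # "object" concepts are placed into the scene ("A ... with a <concept>"),
+ # while "style" concepts prefix the whole prompt ("<concept> ...").
+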
+ def car_loss(image):
+     """CLIP-based loss encouraging the presence of cars (lower = more car-like)"""
+     # CLIPProcessor accepts PIL images directly; convert arrays/tensors once if needed
+     if not isinstance(image, Image.Image):
+         image = Image.fromarray(np.asarray(image).astype(np.uint8))
+
+     with torch.no_grad():
+         # Score the image against a "car" prompt and a "no cars" prompt
+         inputs = clip_processor(
+             text=["a photo of a car", "a photo without cars"],
+             images=image,
+             return_tensors="pt",
+             padding=True
+         ).to(device)
+
+         # logits_per_image holds one image-text similarity score per prompt
+         outputs = clip_model(**inputs)
+         logits_per_image = outputs.logits_per_image
+
+         car_score = logits_per_image[0][0]
+         no_car_score = logits_per_image[0][1]
+
+     # We want to maximize car_score and minimize no_car_score,
+     # so the loss is the negated margin between them
+     loss = -(car_score - no_car_score)
+
+     return loss
+
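+ # Hypothetical sanity check (not executed by the app; the filename is a placeholder):
+ #   img = Image.open("car_photo.png")
+ #   print(car_loss(img).item())   # more negative = CLIP sees the image as more car-like
+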
+ def generate_image(pipe, prompt, seed, guidance_scale=7.5, num_inference_steps=30, use_car_guidance=False):
+     """Generate an image with optional car guidance"""
+     generator = torch.Generator(device=device).manual_seed(seed)
+     custom_loss = car_loss if use_car_guidance else None
+
+     if custom_loss:
+         try:
+             # Start with a standard generation, spending half the step budget;
+             # the rest goes to the refinement passes below
+             init_images = pipe(
+                 prompt,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_inference_steps // 2,
+                 generator=generator
+             ).images
+
+             init_image = init_images[0]
+
+             # Refine using car guidance: reuse the already-loaded components
+             # in an img2img pipeline instead of downloading a second model
+             img2img_pipe = StableDiffusionImg2ImgPipeline(
+                 vae=pipe.vae,
+                 text_encoder=pipe.text_encoder,
+                 tokenizer=pipe.tokenizer,
+                 unet=pipe.unet,
+                 scheduler=pipe.scheduler,
+                 safety_checker=None,
+                 feature_extractor=None,
+                 requires_safety_checker=False,
+             ).to(device)
+
+             strength = 0.75
+             current_image = init_image
+
+             for i in range(5):
+                 # Log the CLIP loss so refinement progress is visible
+                 current_loss = custom_loss(current_image)
+                 print(f"Refinement step {i + 1}: car loss = {current_loss.item():.4f}")
+
+                 refined_images = img2img_pipe(
+                     prompt=prompt + ", with beautiful cars",
+                     image=current_image,
+                     strength=strength,
+                     guidance_scale=guidance_scale,
+                     generator=generator,
+                 ).images
+
+                 current_image = refined_images[0]
+                 # Decay the denoising strength so later passes make smaller edits
+                 strength *= 0.8
+
+             return current_image
+
+         except Exception as e:
+             print(f"Error in car-guided generation: {e}")
+             return pipe(
+                 prompt,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_inference_steps,
+                 generator=generator
+             ).images[0]
+     else:
+         return pipe(
+             prompt,
+             guidance_scale=guidance_scale,
+             num_inference_steps=num_inference_steps,
+             generator=generator
+         ).images[0]
+
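+ # Strength schedule over the five refinement passes, from the 0.8 decay above:
+ # 0.75, 0.60, 0.48, 0.38, 0.31. Each img2img call re-noises less of the image,
+ # so early passes can introduce cars while later passes mostly polish details.
+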
+ # Cache for loaded models and concepts
+ loaded_models = {}
+
+ def get_model_with_concept(concept_name):
+     """Get or load a pipeline with the specified concept embedding applied"""
+     if concept_name not in loaded_models:
+         concept_info = CONCEPTS[concept_name]
+
+         # Download the concept embedding once and cache it locally
+         concept_path = f"concepts/{concept_name}.bin"
+         os.makedirs("concepts", exist_ok=True)
+
+         if not os.path.exists(concept_path):
+             file = hf_hub_download(
+                 repo_id=concept_info["repo_id"],
+                 filename="learned_embeds.bin",
+                 repo_type="model"
+             )
+             shutil.copy(file, concept_path)
+
+         # Load the base model. SD Concepts Library embeddings were trained
+         # against the SD 1.x text encoder (768-dim) and are incompatible with
+         # the OpenCLIP encoder in stable-diffusion-2, so use a v1.x checkpoint
+         pipe = StableDiffusionPipeline.from_pretrained(
+             "runwayml/stable-diffusion-v1-5",
+             torch_dtype=torch.float32 if device == "cpu" else torch.float16,
+             safety_checker=None
+         ).to(device)
+
+         # Register the embedding under the same placeholder token the prompts use
+         pipe.load_textual_inversion(concept_path, token=f"<{concept_name}>")
+         loaded_models[concept_name] = pipe
+
+     return loaded_models[concept_name]
+
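+ # What load_textual_inversion consumes here (illustrative; the dict layout is
+ # the usual SD 1.x textual-inversion format used by sd-concepts-library repos):
+ #   embeds = torch.load("concepts/samurai-jack.bin")
+ #   # -> {"<samurai-jack>": tensor of shape (768,)}
+ # The token is added to the tokenizer and its vector appended to the text
+ # encoder's embedding table, so prompts can reference the learned concept.
+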
+ def generate_images(concept_name, base_prompt, seed, use_car_guidance):
+     """Generate an image using the selected concept"""
+     try:
+         # Get the (cached) pipeline with the concept loaded
+         pipe = get_model_with_concept(concept_name)
+
+         # Construct prompt based on concept type
+         if CONCEPTS[concept_name]["type"] == "object":
+             prompt = f"A {base_prompt} with a <{concept_name}>"
+         else:
+             prompt = f"<{concept_name}> {base_prompt}"
+
+         # Generate image
+         image = generate_image(
+             pipe=pipe,
+             prompt=prompt,
+             seed=int(seed),
+             use_car_guidance=use_car_guidance
+         )
+
+         return image
+     except Exception as e:
+         raise gr.Error(f"Error generating image: {e}")
+
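+ # Resulting prompts for the two concept types (tracing the branches above):
+ #   generate_images("canna-lily-flowers102", "garden at dawn", 42, False)
+ #       -> "A garden at dawn with a <canna-lily-flowers102>"
+ #   generate_images("samurai-jack", "a lone warrior on a cliff", 42, False)
+ #       -> "<samurai-jack> a lone warrior on a cliff"
+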
+ # Create Gradio interface
+ with gr.Blocks(title="Stable Diffusion Style Explorer") as demo:
+     gr.Markdown("""
+     # Stable Diffusion Style Explorer
+
+     Generate images using various concepts from the SD Concepts Library, with optional car guidance.
+
+     ## How to use:
+     1. Select a concept from the dropdown
+     2. Enter a base prompt (or use the default)
+     3. Set a seed for reproducibility
+     4. Choose whether to use car guidance
+     5. Click Generate!
+
+     Check out the examples below to see different combinations of concepts and prompts!
+     """)
+
+     with gr.Row():
+         with gr.Column():
+             concept = gr.Dropdown(
+                 choices=list(CONCEPTS.keys()),
+                 value="samurai-jack",
+                 label="Select Concept"
+             )
+
+             prompt = gr.Textbox(
+                 value="A serene landscape with mountains and a lake at sunset",
+                 label="Base Prompt"
+             )
+
+             seed = gr.Number(
+                 value=42,
+                 label="Seed",
+                 precision=0
+             )
+
+             car_guidance = gr.Checkbox(
+                 value=False,
+                 label="Use Car Guidance"
+             )
+
+             generate_btn = gr.Button("Generate Image")
+
+         with gr.Column():
+             output_image = gr.Image(label="Generated Image")
+
+     # Create the info panel up front so the change handler updates a fixed
+     # component instead of instantiating a new gr.Markdown() on every event
+     concept_info = gr.Markdown()
+
+     concept.change(
+         fn=lambda x: f"Selected concept: {CONCEPTS[x]['description']} ({CONCEPTS[x]['type']})",
+         inputs=[concept],
+         outputs=[concept_info]
+     )
+
+     generate_btn.click(
+         fn=generate_images,
+         inputs=[concept, prompt, seed, car_guidance],
+         outputs=[output_image]
+     )
+
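+     # Gradio passes component values positionally, so the click handler above
+     # invokes generate_images(concept, prompt, seed, car_guidance) with the
+     # current widget values.
+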
+     # Gallery of pre-generated examples
+     gr.Markdown("### 🖼️ Pre-generated Examples")
+
+     with gr.Row():
+         # Samurai Jack examples
+         with gr.Column():
+             gr.Markdown("**Samurai Jack Style**")
+             gr.Image("Assignment17/Assignment17/outputs/samurai-jack_normal.png",
+                      label="Without Car Guidance")
+             gr.Image("Assignment17/Assignment17/outputs/samurai-jack_car.png",
+                      label="With Car Guidance")
+
+     with gr.Row():
+         # Canna Lily examples
+         with gr.Column():
+             gr.Markdown("**Canna Lily Object**")
+             gr.Image("Assignment17/Assignment17/outputs/canna-lily-flowers102_normal.png",
+                      label="Without Car Guidance")
+             gr.Image("Assignment17/Assignment17/outputs/canna-lily-flowers102_car.png",
+                      label="With Car Guidance")
+
+     with gr.Row():
+         # Babies Poster examples
+         with gr.Column():
+             gr.Markdown("**Babies Poster Style**")
+             gr.Image("Assignment17/Assignment17/outputs/babies-poster_normal.png",
+                      label="Without Car Guidance")
+             gr.Image("Assignment17/Assignment17/outputs/babies-poster_car.png",
+                      label="With Car Guidance")
+
+     with gr.Row():
+         # Animal Toy examples
+         with gr.Column():
+             gr.Markdown("**Animal Toy Object**")
+             gr.Image("Assignment17/Assignment17/outputs/animal-toy_normal.png",
+                      label="Without Car Guidance")
+             gr.Image("Assignment17/Assignment17/outputs/animal-toy_car.png",
+                      label="With Car Guidance")
+
+     with gr.Row():
+         # Sword Lily examples
+         with gr.Column():
+             gr.Markdown("**Sword Lily Object**")
+             gr.Image("Assignment17/Assignment17/outputs/sword-lily-flowers102_normal.png",
+                      label="Without Car Guidance")
+             gr.Image("Assignment17/Assignment17/outputs/sword-lily-flowers102_car.png",
+                      label="With Car Guidance")
+
+ demo.launch()