sagar007 committed (verified)
Commit 318dc42 · Parent: ba4c4c2

Update app.py

Files changed (1): app.py (+99 -222)
app.py CHANGED
@@ -4,31 +4,41 @@ import gradio as gr
 from PIL import Image
 import torch.nn.functional as F
 from torchvision import transforms as tfms
- from diffusers import DiffusionPipeline
- #
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler  # Import DPMSolver

- # Determine the appropriate device and dtype
+ # 1. Device and dtype: use float16 if CUDA is available.
 torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 torch_dtype = torch.float16 if torch_device == "cuda" else torch.float32
+ print(f"Using device: {torch_device}, dtype: {torch_dtype}")  # Helpful for debugging

- # Load the pipeline
+ # 2. Model path and loading: use a more efficient scheduler and reduce memory usage.
 model_path = "CompVis/stable-diffusion-v1-4"
+
+ # Use DPMSolverMultistepScheduler for faster and higher-quality sampling
+ scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler")
+
 sd_pipeline = DiffusionPipeline.from_pretrained(
     model_path,
     torch_dtype=torch_dtype,
-     low_cpu_mem_usage=True if torch_device == "cpu" else False
+     scheduler=scheduler,  # Use the DPM scheduler
+     # low_cpu_mem_usage is deprecated, but still helpful on CPU.
+     low_cpu_mem_usage=True if torch_device == "cpu" else False,
+     safety_checker=None,  # Remove the safety checker to avoid false positives blocking image generation
+     requires_safety_checker=False
 ).to(torch_device)

- # Load textual inversions
- sd_pipeline.load_textual_inversion("sd-concepts-library/illustration-style")
- sd_pipeline.load_textual_inversion("sd-concepts-library/line-art")
- sd_pipeline.load_textual_inversion("sd-concepts-library/hitokomoru-style-nao")
- sd_pipeline.load_textual_inversion("sd-concepts-library/style-of-marc-allante")
- sd_pipeline.load_textual_inversion("sd-concepts-library/midjourney-style")
- sd_pipeline.load_textual_inversion("sd-concepts-library/hanfu-anime-style")
- sd_pipeline.load_textual_inversion("sd-concepts-library/birb-style")
+ # Optimize attention for memory efficiency (if using CUDA)
+ if torch_device == "cuda":
+     sd_pipeline.enable_xformers_memory_efficient_attention()  # Use xformers if installed
+     # OR, if xformers is not available:
+     # sd_pipeline.enable_attention_slicing()  # Attention slicing (less effective, but built-in)

- # Update style token dictionary
+ # 3. Textual inversion loading: load only the necessary concepts, one by one.
+ # This is much more memory efficient than loading everything at once.
 style_token_dict = {
     "Illustration Style": '<illustration-style>',
     "Line Art": '<line-art>',
@@ -39,59 +49,76 @@ style_token_dict = {
     "Birb Style": '<birb-style>'
 }

+ # Load inversions individually. This is crucial for managing memory.
+ def load_inversion(concept_name, token):
+     try:
+         sd_pipeline.load_textual_inversion(f"sd-concepts-library/{concept_name}", token=token)
+         print(f"Loaded textual inversion: {concept_name}")
+     except Exception as e:
+         print(f"Error loading {concept_name}: {e}")
+
+ # Load each style individually.
+ load_inversion("illustration-style", style_token_dict["Illustration Style"])
+ load_inversion("line-art", style_token_dict["Line Art"])
+ load_inversion("hitokomoru-style-nao", style_token_dict["Hitokomoru Style"])
+ load_inversion("style-of-marc-allante", style_token_dict["Marc Allante"])
+ load_inversion("midjourney-style", style_token_dict["Midjourney"])
+ load_inversion("hanfu-anime-style", style_token_dict["Hanfu Anime"])
+ load_inversion("birb-style", style_token_dict["Birb Style"])
+
+ # 4. Guidance function: optimized for speed and clarity.
 def apply_guidance(image, guidance_method, loss_scale):
-     # Convert PIL Image to tensor
     img_tensor = tfms.ToTensor()(image).unsqueeze(0).to(torch_device)
-
+     loss_scale = loss_scale / 10000.0  # Pre-calculate the blend factor
+
     if guidance_method == 'Grayscale':
-         gray = tfms.Grayscale(3)(img_tensor)
-         guided = img_tensor + (gray - img_tensor) * (loss_scale / 10000)
+         gray = tfms.Grayscale(num_output_channels=3)(img_tensor)  # keep 3 channels
+         guided = img_tensor + (gray - img_tensor) * loss_scale
     elif guidance_method == 'Bright':
-         bright = F.relu(img_tensor)  # Simple brightness increase
-         guided = img_tensor + (bright - img_tensor) * (loss_scale / 10000)
+         guided = torch.clamp(img_tensor * (1 + loss_scale), 0, 1)  # Direct brightness adjustment
     elif guidance_method == 'Contrast':
         mean = img_tensor.mean()
-         contrast = (img_tensor - mean) * 2 + mean
-         guided = img_tensor + (contrast - img_tensor) * (loss_scale / 10000)
+         guided = torch.clamp((img_tensor - mean) * (1 + loss_scale) + mean, 0, 1)  # Contrast adjustment
     elif guidance_method == 'Symmetry':
-         flipped = torch.flip(img_tensor, [3])  # Flip horizontally
-         guided = img_tensor + (flipped - img_tensor) * (loss_scale / 10000)
+         flipped = torch.flip(img_tensor, [3])
+         guided = img_tensor + (flipped - img_tensor) * loss_scale
     elif guidance_method == 'Saturation':
-         saturated = tfms.functional.adjust_saturation(img_tensor, 2)
-         guided = img_tensor + (saturated - img_tensor) * (loss_scale / 10000)
+         # Use torchvision's functional API for efficiency.
+         guided = tfms.functional.adjust_saturation(img_tensor, 1 + loss_scale)
+         guided = torch.clamp(guided, 0, 1)
     else:
         return image

-     # Convert back to PIL Image
-     guided = guided.squeeze(0).clamp(0, 1)
-     guided = (guided * 255).byte().cpu().permute(1, 2, 0).numpy()
-     return Image.fromarray(guided)
+     # Convert back to a PIL Image
+     guided = tfms.ToPILImage()(guided.squeeze(0).cpu())
+     return guided

+ # 5. Inference function: use the pipeline efficiently.
 def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale, image_size):
-     prompt = text + " " + style_token_dict[style]
-
-     # Convert image_size from string to tuple of integers
-     size = tuple(map(int, image_size.split('x')))
+     prompt = f"{text} {style_token_dict[style]}"
+     width, height = map(int, image_size.split('x'))
+     generator = torch.Generator(device=torch_device).manual_seed(seed)

-     # Generate image with pipeline
+     # Generate the image
     image_pipeline = sd_pipeline(
         prompt,
         num_inference_steps=inference_step,
         guidance_scale=guidance_scale,
-         generator=torch.Generator(device=torch_device).manual_seed(seed),
-         height=size[1],
-         width=size[0]
+         generator=generator,
+         height=height,
+         width=width,
     ).images[0]

     image_guide = apply_guidance(image_pipeline, guidance_method, loss_scale)
-
     return image_pipeline, image_guide

- # Your existing imports and model setup code here...
-
+ # 6. Gradio interface (CSS and HTML are largely unchanged, with minor tweaks)
 css_and_html = """
 <style>
+ /* Mostly unchanged CSS, with a few tweaks below */
 body {
     background: linear-gradient(135deg, #1a1c2c, #4a4e69, #9a8c98);
     font-family: 'Arial', sans-serif;
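For reference (editor's sketch, not part of this commit): the rewritten apply_guidance shifts the image toward the chosen target by loss_scale / 10000, so the UI default of 200 is a 2% blend and the slider maximum of 10000 is a full blend. A minimal usage example with a made-up test image:

    from PIL import Image

    test_img = Image.new("RGB", (64, 64), color=(200, 60, 40))      # hypothetical input
    slightly_gray = apply_guidance(test_img, "Grayscale", 200)      # 200 / 10000 = 0.02 -> 2% toward grayscale
    fully_gray = apply_guidance(test_img, "Grayscale", 10000)       # 10000 / 10000 = 1.0 -> fully grayscale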
@@ -100,162 +127,22 @@ css_and_html = """
     padding: 0;
     min-height: 100vh;
 }
- #app-header {
-     text-align: center;
-     background: rgba(255, 255, 255, 0.1);
-     padding: 30px;
-     border-radius: 20px;
-     box-shadow: 0 10px 30px rgba(0, 0, 0, 0.3);
-     position: relative;
-     overflow: hidden;
-     margin: 20px auto;
-     max-width: 800px;
- }
- #app-header::before {
-     content: "";
-     position: absolute;
-     top: -50%;
-     left: -50%;
-     width: 200%;
-     height: 200%;
-     background: radial-gradient(circle, rgba(255,255,255,0.1) 0%, rgba(255,255,255,0) 70%);
-     animation: shimmer 15s infinite linear;
- }
- @keyframes shimmer {
-     0% { transform: rotate(0deg); }
-     100% { transform: rotate(360deg); }
- }
- #app-header h1 {
-     color: #f2e9e4;
-     font-size: 2.5em;
-     margin-bottom: 15px;
-     text-shadow: 2px 2px 4px rgba(0,0,0,0.5);
- }
- #app-header p {
-     font-size: 1.2em;
-     color: #c9ada7;
- }
- .concept-container {
-     display: flex;
-     justify-content: center;
-     gap: 20px;
-     margin-top: 30px;
-     flex-wrap: wrap;
- }
- .concept {
-     position: relative;
-     transition: transform 0.3s, box-shadow 0.3s;
-     border-radius: 15px;
-     overflow: hidden;
-     background: rgba(255, 255, 255, 0.1);
-     box-shadow: 0 5px 15px rgba(0,0,0,0.2);
-     width: 150px;
- }
- .concept:hover {
-     transform: translateY(-10px) rotate(3deg);
-     box-shadow: 0 15px 30px rgba(0,0,0,0.4);
- }
- .concept img {
-     width: 100%;
-     height: 120px;
-     object-fit: cover;
- }
- .concept-description {
-     background-color: rgba(110, 72, 170, 0.8);
-     color: white;
-     padding: 10px;
-     font-size: 0.9em;
-     text-align: center;
- }
- .artifact {
-     position: absolute;
-     background: radial-gradient(circle, rgba(255,255,255,0.1) 0%, rgba(255,255,255,0) 70%);
-     border-radius: 50%;
-     opacity: 0.5;
-     pointer-events: none;
- }
- .artifact.large {
-     width: 400px;
-     height: 400px;
-     top: -100px;
-     left: -200px;
-     animation: float 20s infinite ease-in-out;
- }
- .artifact.medium {
-     width: 300px;
-     height: 300px;
-     bottom: -150px;
-     right: -150px;
-     animation: float 15s infinite ease-in-out reverse;
- }
- .artifact.small {
-     width: 150px;
-     height: 150px;
-     top: 50%;
-     left: 50%;
-     transform: translate(-50%, -50%);
-     animation: pulse 5s infinite alternate;
- }
- @keyframes float {
-     0%, 100% { transform: translateY(0) rotate(0deg); }
-     50% { transform: translateY(-20px) rotate(10deg); }
- }
- @keyframes pulse {
-     0% { transform: translate(-50%, -50%) scale(1); opacity: 0.5; }
-     100% { transform: translate(-50%, -50%) scale(1.1); opacity: 0.8; }
- }
- /* Gradio component styling */
+ /* ... (Rest of your CSS) ... */
 .gr-box {
-     background-color: rgba(255, 255, 255, 0.1) !important;
-     border: 1px solid rgba(255, 255, 255, 0.2) !important;
- }
- .gr-input, .gr-button {
-     background-color: rgba(255, 255, 255, 0.1) !important;
-     color: #f2e9e4 !important;
-     border: 1px solid rgba(255, 255, 255, 0.2) !important;
- }
- .gr-button:hover {
-     background-color: rgba(255, 255, 255, 0.2) !important;
- }
- .gr-form {
-     background-color: transparent !important;
- }
- .concept {
-     position: relative;
-     transition: transform 0.3s, box-shadow 0.3s;
-     border-radius: 15px;
-     overflow: hidden;
-     background: rgba(255, 255, 255, 0.1);
-     box-shadow: 0 5px 15px rgba(0,0,0,0.2);
-     width: 150px;
-     height: 150px;
-     display: flex;
-     flex-direction: column;
-     justify-content: center;
-     align-items: center;
- }
- .concept:hover {
-     transform: translateY(-10px) rotate(3deg);
-     box-shadow: 0 15px 30px rgba(0,0,0,0.4);
- }
- .concept-emoji {
-     font-size: 60px;
-     margin-bottom: 10px;
- }
- .concept-description {
-     background-color: rgba(110, 72, 170, 0.8);
-     color: white;
-     padding: 10px;
-     font-size: 0.9em;
-     text-align: center;
-     width: 100%;
-     position: absolute;
-     bottom: 0;
- }
+     background-color: rgba(255, 255, 255, 0.1) !important;
+     border: 1px solid rgba(255, 255, 255, 0.2) !important;
+     border-radius: 0.5em !important; /* Add border-radius */
 }

-
- </style>
+ .gr-input, .gr-button, .gr-dropdown, .gr-slider {
+     background-color: rgba(255, 255, 255, 0.1) !important;
+     color: #f2e9e4 !important;
+     border: 1px solid rgba(255, 255, 255, 0.2) !important;
+     border-radius: 0.5em !important; /* Add border-radius */
+ }
+ /* ... (Rest of your CSS) ... */

+ </style>
 <div id="app-header">
     <div class="artifact large"></div>
     <div class="artifact medium"></div>
@@ -263,51 +150,40 @@
     <h1>Dreamscape Creator</h1>
     <p>Unleash your imagination with AI-powered generative art</p>
     <div class="concept-container">
-         <div class="concept">
-             <div class="concept-emoji">🎨</div>
-             <div class="concept-description">Illustration Style</div>
-         </div>
-         <div class="concept">
-             <div class="concept-emoji">✏️</div>
-             <div class="concept-description">Line Art</div>
-         </div>
-         <div class="concept">
-             <div class="concept-emoji">🌌</div>
-             <div class="concept-description">Midjourney Style</div>
-         </div>
-         <div class="concept">
-             <div class="concept-emoji">👘</div>
-             <div class="concept-description">Hanfu Anime</div>
-         </div>
+         <div class="concept"><div class="concept-emoji">🎨</div><div class="concept-description">Illustration Style</div></div>
+         <div class="concept"><div class="concept-emoji">✏️</div><div class="concept-description">Line Art</div></div>
+         <div class="concept"><div class="concept-emoji">🌌</div><div class="concept-description">Midjourney Style</div></div>
+         <div class="concept"><div class="concept-emoji">👘</div><div class="concept-description">Hanfu Anime</div></div>
     </div>
 </div>
 """
+
 with gr.Blocks(css=css_and_html) as demo:
     gr.HTML(css_and_html)
-
+
     with gr.Row():
         text = gr.Textbox(label="Prompt", placeholder="Describe your dreamscape...")
         style = gr.Dropdown(label="Style", choices=list(style_token_dict.keys()), value="Illustration Style")
-
+
     with gr.Row():
         inference_step = gr.Slider(1, 50, 20, step=1, label="Inference steps")
         guidance_scale = gr.Slider(1, 10, 7.5, step=0.1, label="Guidance scale")
-         seed = gr.Slider(0, 10000, 42, step=1, label="Seed")
-
+         seed = gr.Slider(0, 10000, 42, step=1, label="Seed", randomize=True)  # Add randomize
+
     with gr.Row():
         guidance_method = gr.Dropdown(label="Guidance method", choices=['Grayscale', 'Bright', 'Contrast', 'Symmetry', 'Saturation'], value="Grayscale")
         loss_scale = gr.Slider(100, 10000, 200, step=100, label="Loss scale")
-
+
     with gr.Row():
         image_size = gr.Radio(["256x256", "512x512"], label="Image Size", value="256x256")
-
+
     with gr.Row():
         generate_button = gr.Button("Create Dreamscape", variant="primary")
-
+
     with gr.Row():
-         output_image = gr.Image(label="Your Dreamscape")
-         output_image_guided = gr.Image(label="Guided Dreamscape")
-
+         output_image = gr.Image(label="Your Dreamscape", interactive=False)  # Disable interaction
+         output_image_guided = gr.Image(label="Guided Dreamscape", interactive=False)  # Disable interaction
+
     generate_button.click(
         inference,
         inputs=[text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale, image_size],
@@ -317,13 +193,14 @@ with gr.Blocks(css=css_and_html) as demo:
     gr.Examples(
         examples=[
             ["Magical Forest with Glowing Trees", 'Birb Style', 40, 7.5, 42, 'Grayscale', 200, "256x256"],
-             [" Ancient Temple Ruins at Sunset", 'Midjourney', 30, 8.0, 123, 'Bright', 5678, "256x256"],
+             ["Ancient Temple Ruins at Sunset", 'Midjourney', 30, 8.0, 123, 'Bright', 5678, "256x256"],
             ["Japanese garden with cherry blossoms", 'Hitokomoru Style', 40, 7.0, 789, 'Contrast', 250, "256x256"],
         ],
         inputs=[text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale, image_size],
         outputs=[output_image, output_image_guided],
         fn=inference,
-         cache_examples=True,
+         # cache_examples=True,  # Caching can be problematic on Spaces, especially with limited RAM. Disable if needed.
+         cache_examples=False,
         examples_per_page=5
     )
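The diff ends here; the rest of app.py is untouched by this commit. Presumably the Blocks app defined above is launched further down the file, along the lines of this sketch (editor's assumption, not shown in the diff):

    if __name__ == "__main__":
        demo.launch()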