Surn committed on
Commit a85cd89 · 1 Parent(s): bb3e4e2

Attempt to integrate negative prompts

Files changed (1)
  1. app.py +64 -16
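In broad strokes, this change emulates negative prompting in the FLUX pipeline by encoding the positive and negative prompts separately, running the transformer once per prompt at each denoising step, and blending the two noise predictions with a classifier-free-guidance style formula (the noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg) line in the diff). A minimal, self-contained sketch of that blending step, with illustrative names and toy tensors standing in for the transformer outputs:

    import torch

    def combine_noise_predictions(noise_pred_pos: torch.Tensor,
                                  noise_pred_neg: torch.Tensor,
                                  guidance_scale: float) -> torch.Tensor:
        # CFG-style blend used in the denoising loop below: start from the
        # negative prediction and push it toward the positive prediction.
        return noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg)

    # Toy usage: random tensors stand in for the two transformer outputs.
    pos = torch.randn(1, 4096, 64)
    neg = torch.randn(1, 4096, 64)
    noise_pred = combine_noise_predictions(pos, neg, guidance_scale=3.5)

Note that this approach runs the transformer twice per step (once per prompt), roughly doubling the per-step cost compared to the previous single-prompt loop.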
app.py CHANGED
@@ -157,7 +157,6 @@ def retrieve_timesteps(
157
  timesteps = scheduler.timesteps
158
  return timesteps, num_inference_steps
159
 
160
- # FLUX pipeline
161
  @torch.inference_mode()
162
  def flux_pipe_call_that_returns_an_iterable_of_images(
163
  self,
@@ -180,9 +179,11 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
180
  max_sequence_length: int = 512,
181
  good_vae: Optional[Any] = None,
182
  ):
183
  height = height or self.default_sample_size * self.vae_scale_factor
184
  width = width or self.default_sample_size * self.vae_scale_factor
185
 
186
  self.check_inputs(
187
  prompt,
188
  prompt_2,
@@ -201,7 +202,9 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
201
  device = self._execution_device
202
 
203
  lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
204
- prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
205
  prompt=prompt,
206
  prompt_2=prompt_2,
207
  prompt_embeds=prompt_embeds,
@@ -212,18 +215,38 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
212
  lora_scale=lora_scale,
213
  )
214
 
215
  num_channels_latents = self.transformer.config.in_channels // 4
216
  latents, latent_image_ids = self.prepare_latents(
217
  batch_size * num_images_per_prompt,
218
  num_channels_latents,
219
  height,
220
  width,
221
- prompt_embeds.dtype,
222
  device,
223
  generator,
224
  latents,
225
  )
226
 
227
  sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
228
  image_seq_len = latents.shape[1]
229
  mu = calculate_shift(
@@ -243,41 +266,66 @@ def flux_pipe_call_that_returns_an_iterable_of_images(
243
  )
244
  self._num_timesteps = len(timesteps)
245
 
246
- guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
247
 
248
  for i, t in enumerate(timesteps):
249
- if self.interrupt:
250
  continue
251
 
252
  timestep = t.expand(latents.shape[0]).to(latents.dtype)
253
  print(f"Step {i + 1}/{num_inference_steps} - Timestep: {timestep.item()}\n")
254
 
255
- noise_pred = self.transformer(
256
  hidden_states=latents,
257
  timestep=timestep / 1000,
258
  guidance=guidance,
259
- pooled_projections=pooled_prompt_embeds,
260
- encoder_hidden_states=prompt_embeds,
261
- txt_ids=text_ids,
262
  img_ids=latent_image_ids,
263
  joint_attention_kwargs=self.joint_attention_kwargs,
264
  return_dict=False,
265
  )[0]
266
 
267
  latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
268
  latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
269
  image = self.vae.decode(latents_for_image, return_dict=False)[0]
270
  yield self.image_processor.postprocess(image, output_type=output_type)[0]
271
  latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
272
  torch.cuda.empty_cache()
273
 
274
  latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
275
  latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
276
  image = good_vae.decode(latents, return_dict=False)[0]
277
  self.maybe_free_model_hooks()
278
  torch.cuda.empty_cache()
279
  yield self.image_processor.postprocess(image, output_type=output_type)[0]
280
-
281
  #--------------------------------------------------Model Initialization-----------------------------------------------------------------------------------------#
282
 
283
  dtype = torch.bfloat16
@@ -343,7 +391,7 @@ def update_selection(evt: gr.SelectData, width, height, aspect_ratio):
343
  )
344
 
345
  @spaces.GPU(duration=120,progress=gr.Progress(track_tqdm=True))
346
- def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
347
  pipe.to("cuda")
348
  generator = torch.Generator(device="cuda").manual_seed(seed)
349
  flash_attention_enabled = torch.backends.cuda.flash_sdp_enabled()
@@ -384,7 +432,7 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
384
  ):
385
  yield img
386
 
387
- def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed, progress):
388
  generator = torch.Generator(device="cuda").manual_seed(seed)
389
  pipe_i2i.to("cuda")
390
  flash_attention_enabled = torch.backends.cuda.flash_sdp_enabled()
@@ -447,7 +495,7 @@ def run_lora(prompt, map_option, image_input, image_strength, cfg_scale, steps,
447
  print(f"Conditioned Image: {image_input.size}.. converted to RGB and resized\n")
448
  if map_option != "Prompt":
449
  prompt = PROMPTS[map_option]
450
- # negative_prompt = NEGATIVE_PROMPTS.get(map_option, "")
451
 
452
  selected_lora = loras[selected_index]
453
  lora_path = selected_lora["repo"]
@@ -484,7 +532,7 @@ def run_lora(prompt, map_option, image_input, image_strength, cfg_scale, steps,
484
 
485
  if(image_input is not None):
486
  print(f"\nGenerating image to image with seed: {seed}\n")
487
- generated_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed, progress)
488
 
489
  if enlarge:
490
  upscaled_image = upscale_image(generated_image, max(1.0,min((TARGET_SIZE[0]/width),(TARGET_SIZE[1]/height))))
@@ -498,7 +546,7 @@ def run_lora(prompt, map_option, image_input, image_strength, cfg_scale, steps,
498
  final_image = tmp_upscaled.name
499
  yield final_image, seed, gr.update(visible=False)
500
  else:
501
- image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
502
 
503
  final_image = None
504
  step_counter = 0
@@ -816,7 +864,7 @@ with gr.Blocks(css_paths="style_20250314.css", title=title, theme='Surn/beeuty',
816
  label="Prompt",
817
  visible=False,
818
  elem_classes="solid",
819
- value="top-down, (rectangular tabletop_map) alien planet map, Battletech_boardgame scifi world with forests, lakes, oceans, continents and snow at the top and bottom, (middle is dark, no_reflections, no_shadows), from directly above. From 100,000 feet looking straight down. 10000 foot-view",
820
  lines=4
821
  )
822
  negative_prompt_textbox = gr.Textbox(
 
157
  timesteps = scheduler.timesteps
158
  return timesteps, num_inference_steps
159
 
160
  @torch.inference_mode()
161
  def flux_pipe_call_that_returns_an_iterable_of_images(
162
  self,
 
179
  max_sequence_length: int = 512,
180
  good_vae: Optional[Any] = None,
181
  ):
182
+ # Set default height and width
183
  height = height or self.default_sample_size * self.vae_scale_factor
184
  width = width or self.default_sample_size * self.vae_scale_factor
185
 
186
+ # Validate inputs
187
  self.check_inputs(
188
  prompt,
189
  prompt_2,
 
202
  device = self._execution_device
203
 
204
  lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
205
+
206
+ # Encode the positive prompt
207
+ prompt_embeds_pos, pooled_prompt_embeds_pos, text_ids_pos = self.encode_prompt(
208
  prompt=prompt,
209
  prompt_2=prompt_2,
210
  prompt_embeds=prompt_embeds,
 
215
  lora_scale=lora_scale,
216
  )
217
 
218
+ # Encode the negative prompt if provided
219
+ if negative_prompt is not None:
220
+ prompt_embeds_neg, pooled_prompt_embeds_neg, text_ids_neg = self.encode_prompt(
221
+ prompt=negative_prompt,
222
+ prompt_2=None, # Assuming no secondary prompt for negative
223
+ prompt_embeds=None,
224
+ pooled_prompt_embeds=None,
225
+ device=device,
226
+ num_images_per_prompt=num_images_per_prompt,
227
+ max_sequence_length=max_sequence_length,
228
+ lora_scale=lora_scale,
229
+ )
230
+ else:
231
+ # Fallback to positive embeddings if no negative prompt is provided
232
+ prompt_embeds_neg = prompt_embeds_pos
233
+ pooled_prompt_embeds_neg = pooled_prompt_embeds_pos
234
+ text_ids_neg = text_ids_pos
235
+
236
+ # Prepare latents
237
  num_channels_latents = self.transformer.config.in_channels // 4
238
  latents, latent_image_ids = self.prepare_latents(
239
  batch_size * num_images_per_prompt,
240
  num_channels_latents,
241
  height,
242
  width,
243
+ prompt_embeds_pos.dtype,
244
  device,
245
  generator,
246
  latents,
247
  )
248
 
249
+ # Set up timesteps
250
  sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
251
  image_seq_len = latents.shape[1]
252
  mu = calculate_shift(
 
266
  )
267
  self._num_timesteps = len(timesteps)
268
 
269
+ guidance = (
270
+ torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0])
271
+ if self.transformer.config.guidance_embeds
272
+ else None
273
+ )
274
 
275
+ # Denoising loop
276
  for i, t in enumerate(timesteps):
277
+ if self._interrupt:
278
  continue
279
 
280
  timestep = t.expand(latents.shape[0]).to(latents.dtype)
281
  print(f"Step {i + 1}/{num_inference_steps} - Timestep: {timestep.item()}\n")
282
 
283
+ # Compute noise prediction for positive prompt
284
+ noise_pred_pos = self.transformer(
285
+ hidden_states=latents,
286
+ timestep=timestep / 1000,
287
+ guidance=guidance,
288
+ pooled_projections=pooled_prompt_embeds_pos,
289
+ encoder_hidden_states=prompt_embeds_pos,
290
+ txt_ids=text_ids_pos,
291
+ img_ids=latent_image_ids,
292
+ joint_attention_kwargs=self.joint_attention_kwargs,
293
+ return_dict=False,
294
+ )[0]
295
+
296
+ # Compute noise prediction for negative prompt
297
+ noise_pred_neg = self.transformer(
298
  hidden_states=latents,
299
  timestep=timestep / 1000,
300
  guidance=guidance,
301
+ pooled_projections=pooled_prompt_embeds_neg,
302
+ encoder_hidden_states=prompt_embeds_neg,
303
+ txt_ids=text_ids_neg,
304
  img_ids=latent_image_ids,
305
  joint_attention_kwargs=self.joint_attention_kwargs,
306
  return_dict=False,
307
  )[0]
308
 
309
+ # Combine noise predictions using guidance scale
310
+ noise_pred = noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg)
311
+
312
+ # Generate intermediate image
313
  latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
314
  latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
315
  image = self.vae.decode(latents_for_image, return_dict=False)[0]
316
  yield self.image_processor.postprocess(image, output_type=output_type)[0]
317
+
318
+ # Update latents with combined noise prediction
319
  latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
320
  torch.cuda.empty_cache()
321
 
322
+ # Final image generation
323
  latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
324
  latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
325
  image = good_vae.decode(latents, return_dict=False)[0]
326
  self.maybe_free_model_hooks()
327
  torch.cuda.empty_cache()
328
  yield self.image_processor.postprocess(image, output_type=output_type)[0]
329
  #--------------------------------------------------Model Initialization-----------------------------------------------------------------------------------------#
330
 
331
  dtype = torch.bfloat16
 
391
  )
392
 
393
  @spaces.GPU(duration=120,progress=gr.Progress(track_tqdm=True))
394
+ def generate_image(prompt_mash, negative_prompt, steps, seed, cfg_scale, width, height, lora_scale, progress):
395
  pipe.to("cuda")
396
  generator = torch.Generator(device="cuda").manual_seed(seed)
397
  flash_attention_enabled = torch.backends.cuda.flash_sdp_enabled()
 
432
  ):
433
  yield img
434
 
435
+ def generate_image_to_image(prompt_mash, negative_prompt, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed, progress):
436
  generator = torch.Generator(device="cuda").manual_seed(seed)
437
  pipe_i2i.to("cuda")
438
  flash_attention_enabled = torch.backends.cuda.flash_sdp_enabled()
 
495
  print(f"Conditioned Image: {image_input.size}.. converted to RGB and resized\n")
496
  if map_option != "Prompt":
497
  prompt = PROMPTS[map_option]
498
+ negative_prompt = NEGATIVE_PROMPTS.get(map_option, "")
499
 
500
  selected_lora = loras[selected_index]
501
  lora_path = selected_lora["repo"]
 
532
 
533
  if(image_input is not None):
534
  print(f"\nGenerating image to image with seed: {seed}\n")
535
+ generated_image = generate_image_to_image(prompt_mash, negative_prompt, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed, progress)
536
 
537
  if enlarge:
538
  upscaled_image = upscale_image(generated_image, max(1.0,min((TARGET_SIZE[0]/width),(TARGET_SIZE[1]/height))))
 
546
  final_image = tmp_upscaled.name
547
  yield final_image, seed, gr.update(visible=False)
548
  else:
549
+ image_generator = generate_image(prompt_mash, negative_prompt, steps, seed, cfg_scale, width, height, lora_scale, progress)
550
 
551
  final_image = None
552
  step_counter = 0
 
864
  label="Prompt",
865
  visible=False,
866
  elem_classes="solid",
867
+ value="Planetary overhead view, directly from above, centered on the planet’s surface, (rectangular tabletop_map) alien planet map, Battletech_boardgame scifi world with forests, lakes, oceans, continents and snow at the top and bottom, (middle is dark, no_reflections, no_shadows), looking straight down.",
868
  lines=4
869
  )
870
  negative_prompt_textbox = gr.Textbox(
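The run_lora hunk reads a NEGATIVE_PROMPTS mapping (negative_prompt = NEGATIVE_PROMPTS.get(map_option, "")) that is defined elsewhere in app.py and is not part of this diff; presumably it mirrors the existing PROMPTS dictionary and is keyed by the same map options. A purely hypothetical sketch of its shape, with placeholder keys and values:

    # Hypothetical illustration only; the real keys and values live elsewhere
    # in app.py and are not part of this commit.
    NEGATIVE_PROMPTS = {
        "Prompt": "",
        "Alien Landscape": "blurry, low detail, text, watermark, people, buildings",
    }

    map_option = "Alien Landscape"  # placeholder for the UI selection
    negative_prompt = NEGATIVE_PROMPTS.get(map_option, "")  # same lookup as in run_lora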