fffiloni committed (verified)
Commit f3011b8 · 1 Parent(s): cbf01ca

Update app.py

Files changed (1)
  1. app.py +42 -23
app.py CHANGED
@@ -272,48 +272,68 @@ def infer(ref_style_file, style_description, caption):
         # Reset the state after inference, regardless of success or failure
         reset_inference_state()
 
+def reset_compo_inference_state():
+    global models_rbm, models_b, extras, extras_b, device, core, core_b
+
+    # Reset sampling configurations
+    extras.sampling_configs['cfg'] = 4
+    extras.sampling_configs['shift'] = 2
+    extras.sampling_configs['timesteps'] = 20
+    extras.sampling_configs['t_start'] = 1.0
+
+    extras_b.sampling_configs['cfg'] = 1.1
+    extras_b.sampling_configs['shift'] = 1
+    extras_b.sampling_configs['timesteps'] = 10
+    extras_b.sampling_configs['t_start'] = 1.0
+
+    # Move models to CPU to free up GPU memory
+    models_to(models_rbm, device="cpu")
+    models_b.generator.to("cpu")
+
+    # Clear CUDA cache
+    torch.cuda.empty_cache()
+    gc.collect()
+
+    # Ensure all models are in eval mode and don't require gradients
+    for model in [models_rbm.generator, models_b.generator]:
+        model.eval()
+        for param in model.parameters():
+            param.requires_grad = False
+
+    # Clear CUDA cache again
+    torch.cuda.empty_cache()
+    gc.collect()
+
 def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
-    global models_rbm, models_b
+    global models_rbm, models_b, device
     try:
         caption = f"{caption} in {style_description}"
         sam_prompt = f"{caption}"
         use_sam_mask = False
 
-        if low_vram:
-            # Revert the devices of the modules back to their original state
-            models_to(models_rbm, device)
+        # Ensure all models are on the correct device
+        models_to(models_rbm, device)
+        models_b.generator.to(device)
 
         batch_size = 1
         height, width = 1024, 1024
         stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
 
-        extras.sampling_configs['cfg'] = 4
-        extras.sampling_configs['shift'] = 2
-        extras.sampling_configs['timesteps'] = 20
-        extras.sampling_configs['t_start'] = 1.0
-        extras_b.sampling_configs['cfg'] = 1.1
-        extras_b.sampling_configs['shift'] = 1
-        extras_b.sampling_configs['timesteps'] = 10
-        extras_b.sampling_configs['t_start'] = 1.0
-
         ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
         ref_images = resize_image(PIL.Image.open(ref_sub_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
 
-        batch = {'captions': [caption] * batch_size}
-        batch['style'] = ref_style
-        batch['images'] = ref_images
+        batch = {'captions': [caption] * batch_size, 'style': ref_style, 'images': ref_images}
 
-        x0_forward = models_rbm.effnet(extras.effnet_preprocess(ref_images.to(device)))
-        x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
+        x0_forward = models_rbm.effnet(extras.effnet_preprocess(ref_images))
+        x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
 
         ## SAM Mask for sub
         use_sam_mask = False
         x0_preview = models_rbm.previewer(x0_forward)
         sam_model = LangSAM()
+        sam_model.to(device)
 
-        # Convert tensor to PIL Image before passing to sam_model.predict
-        x0_preview_pil = T.ToPILImage()(x0_preview[0])
-        x0_preview_tensor = T.ToTensor()(x0_preview_pil)  # Convert PIL Image back to tensor
+        x0_preview_pil = T.ToPILImage()(x0_preview[0].cpu())
         sam_mask, boxes, phrases, logits = sam_model.predict(x0_preview_pil, sam_prompt)
         sam_mask = sam_mask.detach().unsqueeze(dim=0).to(device)
 
@@ -323,7 +343,6 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
         unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
 
         if low_vram:
-            # The sampling process uses more vram, so we offload everything except two modules to the cpu.
            models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
            models_to(sam_model, device="cpu")
            models_to(sam_model.sam, device="cpu")
@@ -381,7 +400,7 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file):
 
     finally:
         # Reset the state after inference, regardless of success or failure
-        reset_inference_state()
+        reset_compo_inference_state()
 
 import gradio as gr
 
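For context: the diff above relies on a models_to helper defined elsewhere in app.py to shuttle model components between GPU and CPU. A minimal sketch of what such a helper might look like is shown below; the attribute-iteration approach, the excepts handling, and the cache clearing are assumptions for illustration, not the repository's exact implementation.

# Hypothetical sketch of a models_to-style helper (assumed behavior, not the
# repository's exact code): move every nn.Module held by a models container to
# the target device, skipping any component names listed in `excepts`.
import gc
import torch
import torch.nn as nn

def models_to(models, device="cpu", excepts=None):
    excepts = excepts or []
    for name, module in vars(models).items():  # assumes a plain attribute container
        if name in excepts or not isinstance(module, nn.Module):
            continue
        module.to(device)
    if device == "cpu" and torch.cuda.is_available():
        gc.collect()
        torch.cuda.empty_cache()  # release GPU memory freed by the offload
    return models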