fffiloni commited on
Commit
627fc63
1 Parent(s): eb756c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -126,10 +126,14 @@ models_rbm = core.Models(
126
  )
127
  models_rbm.generator.eval().requires_grad_(False)
128
 
129
- sam_model = LangSAM()
130
 
131
  def infer(ref_style_file, style_description, caption, progress):
132
  global models_rbm, models_b, device
 
 
 
 
133
 
134
  if low_vram:
135
  models_to(models_rbm, device=device, excepts=["generator", "previewer"])
@@ -237,8 +241,8 @@ def infer(ref_style_file, style_description, caption, progress):
237
  gc.collect()
238
 
239
  def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
240
- global models_rbm, models_b, device, sam_model
241
-
242
  if low_vram:
243
  models_to(models_rbm, device=device, excepts=["generator", "previewer"])
244
  models_to(sam_model, device=device)
@@ -276,7 +280,7 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progre
276
  ## SAM Mask for sub
277
  use_sam_mask = False
278
  x0_preview = models_rbm.previewer(x0_forward)
279
-
280
  x0_preview_pil = T.ToPILImage()(x0_preview[0].cpu())
281
  sam_mask, boxes, phrases, logits = sam_model.predict(x0_preview_pil, sam_prompt)
282
  # sam_mask, boxes, phrases, logits = sam_model.predict(transform(x0_preview[0]), sam_prompt)
 
126
  )
127
  models_rbm.generator.eval().requires_grad_(False)
128
 
129
+
130
 
131
  def infer(ref_style_file, style_description, caption, progress):
132
  global models_rbm, models_b, device
133
+
134
+ if sam_model:
135
+ models_to(sam_model, device="cpu")
136
+ models_to(sam_model.sam, device="cpu")
137
 
138
  if low_vram:
139
  models_to(models_rbm, device=device, excepts=["generator", "previewer"])
 
241
  gc.collect()
242
 
243
  def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
244
+ global models_rbm, models_b, device
245
+ sam_model = LangSAM()
246
  if low_vram:
247
  models_to(models_rbm, device=device, excepts=["generator", "previewer"])
248
  models_to(sam_model, device=device)
 
280
  ## SAM Mask for sub
281
  use_sam_mask = False
282
  x0_preview = models_rbm.previewer(x0_forward)
283
+
284
  x0_preview_pil = T.ToPILImage()(x0_preview[0].cpu())
285
  sam_mask, boxes, phrases, logits = sam_model.predict(x0_preview_pil, sam_prompt)
286
  # sam_mask, boxes, phrases, logits = sam_model.predict(transform(x0_preview[0]), sam_prompt)