liuyizhang commited on
Commit
9003ca5
1 Parent(s): f0d812d

update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -10
app.py CHANGED
@@ -250,11 +250,13 @@ def set_device():
250
 
251
  def load_groundingdino_model():
252
  # initialize groundingdino model
 
253
  logger.info(f"initialize groundingdino model...")
254
  groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
255
 
256
  def load_sam_model():
257
  # initialize SAM
 
258
  logger.info(f"initialize SAM model...")
259
  sam_device = device
260
  sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
@@ -263,6 +265,7 @@ def load_sam_model():
263
 
264
  def load_sd_model():
265
  # initialize stable-diffusion-inpainting
 
266
  logger.info(f"initialize stable-diffusion-inpainting...")
267
  sd_pipe = None
268
  if os.environ.get('IS_MY_DEBUG') is None:
@@ -276,6 +279,7 @@ def load_sd_model():
276
 
277
  def load_lama_cleaner_model():
278
  # initialize lama_cleaner
 
279
  logger.info(f"initialize lama_cleaner...")
280
  from lama_cleaner.helper import (
281
  load_img,
@@ -359,6 +363,7 @@ class Ram_Predictor(RamPredictor):
359
 
360
  def load_ram_model():
361
  # load ram model
 
362
  model_path = "./checkpoints/ram_epoch12.pth"
363
  ram_config = dict(
364
  model=dict(
@@ -674,23 +679,22 @@ def change_radio_display(task_type, mask_source_radio):
674
  num_relation_visible = True
675
  return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible)
676
 
677
- set_device()
678
- load_groundingdino_model()
679
- load_sam_model()
680
- load_sd_model()
681
- load_lama_cleaner_model()
682
- load_ram_model()
683
-
684
- os.system("pip list")
685
-
686
  if __name__ == "__main__":
687
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
688
  parser.add_argument("--debug", action="store_true", help="using debug mode")
689
  parser.add_argument("--share", action="store_true", help="share the app")
690
  args = parser.parse_args()
691
-
692
  print(f'args = {args}')
693
 
 
 
 
 
 
 
 
 
 
694
  block = gr.Blocks().queue()
695
  with block:
696
  with gr.Row():
 
250
 
251
  def load_groundingdino_model():
252
  # initialize groundingdino model
253
+ global groundingdino_model
254
  logger.info(f"initialize groundingdino model...")
255
  groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
256
 
257
  def load_sam_model():
258
  # initialize SAM
259
+ global sam_model, sam_predictor, sam_mask_generator
260
  logger.info(f"initialize SAM model...")
261
  sam_device = device
262
  sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
 
265
 
266
  def load_sd_model():
267
  # initialize stable-diffusion-inpainting
268
+ global sd_pipe
269
  logger.info(f"initialize stable-diffusion-inpainting...")
270
  sd_pipe = None
271
  if os.environ.get('IS_MY_DEBUG') is None:
 
279
 
280
  def load_lama_cleaner_model():
281
  # initialize lama_cleaner
282
+ global lama_cleaner_model
283
  logger.info(f"initialize lama_cleaner...")
284
  from lama_cleaner.helper import (
285
  load_img,
 
363
 
364
  def load_ram_model():
365
  # load ram model
366
+ global ram_model
367
  model_path = "./checkpoints/ram_epoch12.pth"
368
  ram_config = dict(
369
  model=dict(
 
679
  num_relation_visible = True
680
  return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible)
681
 
 
 
 
 
 
 
 
 
 
682
  if __name__ == "__main__":
683
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
684
  parser.add_argument("--debug", action="store_true", help="using debug mode")
685
  parser.add_argument("--share", action="store_true", help="share the app")
686
  args = parser.parse_args()
 
687
  print(f'args = {args}')
688
 
689
+ set_device()
690
+ load_groundingdino_model()
691
+ load_sam_model()
692
+ load_sd_model()
693
+ load_lama_cleaner_model()
694
+ load_ram_model()
695
+
696
+ os.system("pip list")
697
+
698
  block = gr.Blocks().queue()
699
  with block:
700
  with gr.Row():