liuyizhang commited on
Commit
dfba81f
1 Parent(s): c4d99b7

update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -24
app.py CHANGED
@@ -38,6 +38,7 @@ import cv2
38
  import numpy as np
39
  import matplotlib.pyplot as plt
40
 
 
41
  sam_enable = True
42
  inpainting_enable = True
43
  ram_enable = True
@@ -103,16 +104,10 @@ sam_predictor = None
103
  sam_mask_generator = None
104
  sd_model = None
105
  lama_cleaner_model= None
106
- lama_cleaner_model_device = device
107
  ram_model = None
108
  kosmos_model = None
109
  kosmos_processor = None
110
 
111
- def get_sam_vit_h_4b8939():
112
- if not os.path.exists('./sam_vit_h_4b8939.pth'):
113
- logger.info(f"get sam_vit_h_4b8939.pth...")
114
- result = subprocess.run(['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'], check=True)
115
- print(f'wget sam_vit_h_4b8939.pth result = {result}')
116
 
117
  def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
118
  args = SLConfig.fromfile(model_config_path)
@@ -282,24 +277,31 @@ def set_device():
282
  device = 'cpu'
283
  print(f'device={device}')
284
 
285
- def load_groundingdino_model():
286
  # initialize groundingdino model
287
- global groundingdino_model
288
  logger.info(f"initialize groundingdino model...")
289
- groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae, device='cpu')
 
 
 
 
 
 
 
290
 
291
- def load_sam_model():
292
  # initialize SAM
293
- global sam_model, sam_predictor, sam_mask_generator, sam_device, device
 
294
  logger.info(f"initialize SAM model...")
295
  sam_device = device
296
  sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
297
  sam_predictor = SamPredictor(sam_model)
298
  sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
299
 
300
- def load_sd_model():
301
  # initialize stable-diffusion-inpainting
302
- global sd_model, device
303
  logger.info(f"initialize stable-diffusion-inpainting...")
304
  sd_model = None
305
  if os.environ.get('IS_MY_DEBUG') is None:
@@ -311,14 +313,14 @@ def load_sd_model():
311
  )
312
  sd_model = sd_model.to(device)
313
 
314
- def load_lama_cleaner_model():
315
  # initialize lama_cleaner
316
- global lama_cleaner_model, device
317
  logger.info(f"initialize lama_cleaner...")
318
 
319
  lama_cleaner_model = ModelManager(
320
  name='lama',
321
- device=lama_cleaner_model_device,
322
  )
323
 
324
  def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
@@ -390,7 +392,7 @@ class Ram_Predictor(RamPredictor):
390
  self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
391
  self.model.train()
392
 
393
- def load_ram_model():
394
  # load ram model
395
  global ram_model
396
  if os.environ.get('IS_MY_DEBUG') is not None:
@@ -830,20 +832,20 @@ if __name__ == "__main__":
830
  if kosmos_enable:
831
  kosmos_model, kosmos_processor = load_kosmos_model(device)
832
 
833
- load_groundingdino_model()
 
834
 
835
  if sam_enable:
836
- get_sam_vit_h_4b8939()
837
- load_sam_model()
838
 
839
  if inpainting_enable:
840
- load_sd_model()
841
 
842
  if lama_cleaner_enable:
843
- load_lama_cleaner_model()
844
 
845
  if ram_enable:
846
- load_ram_model()
847
 
848
  if os.environ.get('IS_MY_DEBUG') is None:
849
  os.system("pip list")
@@ -865,7 +867,7 @@ if __name__ == "__main__":
865
  mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
866
  value=mask_source_segment, label="Mask from",
867
  visible=False)
868
- text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
869
  inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
870
  num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
871
 
@@ -946,6 +948,7 @@ if __name__ == "__main__":
946
  <a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
947
  gr.Markdown(DESCRIPTION)
948
 
 
949
  computer_info()
950
  block.launch(server_name='0.0.0.0', debug=args.debug, share=args.share)
951
 
 
38
  import numpy as np
39
  import matplotlib.pyplot as plt
40
 
41
+ groundingdino_enable = True
42
  sam_enable = True
43
  inpainting_enable = True
44
  ram_enable = True
 
104
  sam_mask_generator = None
105
  sd_model = None
106
  lama_cleaner_model= None
 
107
  ram_model = None
108
  kosmos_model = None
109
  kosmos_processor = None
110
 
 
 
 
 
 
111
 
112
  def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
113
  args = SLConfig.fromfile(model_config_path)
 
277
  device = 'cpu'
278
  print(f'device={device}')
279
 
280
+ def load_groundingdino_model(device):
281
  # initialize groundingdino model
 
282
  logger.info(f"initialize groundingdino model...")
283
+ groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae, device=device) #'cpu')
284
+ return groundingdino_model
285
+
286
+ def get_sam_vit_h_4b8939():
287
+ if not os.path.exists('./sam_vit_h_4b8939.pth'):
288
+ logger.info(f"get sam_vit_h_4b8939.pth...")
289
+ result = subprocess.run(['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'], check=True)
290
+ print(f'wget sam_vit_h_4b8939.pth result = {result}')
291
 
292
+ def load_sam_model(device):
293
  # initialize SAM
294
+ global sam_model, sam_predictor, sam_mask_generator, sam_device
295
+ get_sam_vit_h_4b8939()
296
  logger.info(f"initialize SAM model...")
297
  sam_device = device
298
  sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
299
  sam_predictor = SamPredictor(sam_model)
300
  sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
301
 
302
+ def load_sd_model(device):
303
  # initialize stable-diffusion-inpainting
304
+ global sd_model
305
  logger.info(f"initialize stable-diffusion-inpainting...")
306
  sd_model = None
307
  if os.environ.get('IS_MY_DEBUG') is None:
 
313
  )
314
  sd_model = sd_model.to(device)
315
 
316
+ def load_lama_cleaner_model(device):
317
  # initialize lama_cleaner
318
+ global lama_cleaner_model
319
  logger.info(f"initialize lama_cleaner...")
320
 
321
  lama_cleaner_model = ModelManager(
322
  name='lama',
323
+ device=device,
324
  )
325
 
326
  def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
 
392
  self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
393
  self.model.train()
394
 
395
+ def load_ram_model(device):
396
  # load ram model
397
  global ram_model
398
  if os.environ.get('IS_MY_DEBUG') is not None:
 
832
  if kosmos_enable:
833
  kosmos_model, kosmos_processor = load_kosmos_model(device)
834
 
835
+ if groundingdino_enable:
836
+ groundingdino_model = load_groundingdino_model('cpu')
837
 
838
  if sam_enable:
839
+ load_sam_model(device)
 
840
 
841
  if inpainting_enable:
842
+ load_sd_model(device)
843
 
844
  if lama_cleaner_enable:
845
+ load_lama_cleaner_model(device)
846
 
847
  if ram_enable:
848
+ load_ram_model(device)
849
 
850
  if os.environ.get('IS_MY_DEBUG') is None:
851
  os.system("pip list")
 
867
  mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
868
  value=mask_source_segment, label="Mask from",
869
  visible=False)
870
+ text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
871
  inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
872
  num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
873
 
 
948
  <a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
949
  gr.Markdown(DESCRIPTION)
950
 
951
+ print(f'device={device}')
952
  computer_info()
953
  block.launch(server_name='0.0.0.0', debug=args.debug, share=args.share)
954