yizhangliu committed
Commit 0cc37e5 • 1 Parent(s): f47bc1e

Update app.py

Files changed (1)
  1. app.py +31 -30
app.py CHANGED
@@ -8,9 +8,6 @@ import gradio as gr
 
 from loguru import logger
 
-# os.system("pip install diffuser==0.6.0")
-# os.system("pip install transformers==4.29.1")
-
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 
 if os.environ.get('IS_MY_DEBUG') is None:
@@ -69,7 +66,10 @@ ckpt_repo_id = "ShilongLiu/GroundingDINO"
 ckpt_filenmae = "groundingdino_swint_ogc.pth"
 sam_checkpoint = './sam_vit_h_4b8939.pth'
 output_dir = "outputs"
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
+if os.environ.get('IS_MY_DEBUG') is None:
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+else:
+    device = 'cpu'
 
 os.makedirs(output_dir, exist_ok=True)
 groundingdino_model = None
@@ -77,8 +77,9 @@ sam_device = None
 sam_model = None
 sam_predictor = None
 sam_mask_generator = None
-sd_pipe = None
+sd_model = None
 lama_cleaner_model= None
+lama_cleaner_model_device = device
 ram_model = None
 
 def get_sam_vit_h_4b8939():
@@ -165,16 +166,6 @@ def load_image(image_path):
     image, _ = transform(image_pil, None)  # 3, h, w
     return image_pil, image
 
-def load_model(model_config_path, model_checkpoint_path, device):
-    args = SLConfig.fromfile(model_config_path)
-    args.device = device
-    model = build_model(args)
-    checkpoint = torch.load(model_checkpoint_path, map_location=device) #"cpu")
-    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
-    print(load_res)
-    _ = model.eval()
-    return model
-
 def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
     caption = caption.lower()
     caption = caption.strip()
@@ -258,18 +249,21 @@ def mix_masks(imgs):
     return Image.fromarray(np.uint8(255*re_img))
 
 def set_device():
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    print(f'device={device}')
+    global device
+    if os.environ.get('IS_MY_DEBUG') is None:
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    else:
+        device = 'cpu'
 
 def load_groundingdino_model():
     # initialize groundingdino model
     global groundingdino_model
     logger.info(f"initialize groundingdino model...")
-    groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
+    groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae, device='cpu')
 
 def load_sam_model():
     # initialize SAM
-    global sam_model, sam_predictor, sam_mask_generator, sam_device
+    global sam_model, sam_predictor, sam_mask_generator, sam_device, device
     logger.info(f"initialize SAM model...")
     sam_device = device
     sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
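
Note: the IS_MY_DEBUG switch above now controls device selection in one place. The snippet below is a minimal standalone sketch of that pattern, assuming only torch and the IS_MY_DEBUG environment variable already used in app.py; the helper name pick_device is illustrative and not part of the file.

import os
import torch

def pick_device():
    # Hypothetical helper mirroring app.py: force CPU when IS_MY_DEBUG is set,
    # otherwise use CUDA if it is available.
    if os.environ.get('IS_MY_DEBUG') is None:
        return 'cuda' if torch.cuda.is_available() else 'cpu'
    return 'cpu'

device = pick_device()
print(f'device={device}')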
@@ -278,26 +272,26 @@ def load_sam_model():
 
 def load_sd_model():
     # initialize stable-diffusion-inpainting
-    global sd_pipe
+    global sd_model, device
     logger.info(f"initialize stable-diffusion-inpainting...")
-    sd_pipe = None
+    sd_model = None
     if os.environ.get('IS_MY_DEBUG') is None:
-        sd_pipe = StableDiffusionInpaintPipeline.from_pretrained(
+        sd_model = StableDiffusionInpaintPipeline.from_pretrained(
             "runwayml/stable-diffusion-inpainting",
             revision="fp16",
             # "stabilityai/stable-diffusion-2-inpainting",
             torch_dtype=torch.float16,
         )
-        sd_pipe = sd_pipe.to(device)
+        sd_model = sd_model.to(device)
 
 def load_lama_cleaner_model():
     # initialize lama_cleaner
-    global lama_cleaner_model
+    global lama_cleaner_model, device
     logger.info(f"initialize lama_cleaner...")
 
     lama_cleaner_model = ModelManager(
         name='lama',
-        device='cpu', # device,
+        device=lama_cleaner_model_device,
     )
 
 def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
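
For reference, sd_model (renamed from sd_pipe) is a standard diffusers inpainting pipeline; the sketch below shows the typical load-and-call sequence matching the calls in this diff. The file names and prompt are placeholders.

import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

# Load the fp16 inpainting pipeline and move it to the target device,
# as load_sd_model() does above.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

# run_anything_task() resizes both the source image and the mask to 512x512
# before invoking the pipeline; the result is a PIL image.
image = Image.open("source.png").resize((512, 512))
mask = Image.open("mask.png").resize((512, 512))
result = pipe(prompt="a teddy bear on a bench", image=image, mask_image=mask).images[0]
result.save("inpainted.png")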
@@ -517,6 +511,7 @@ mask_source_segment = "type what to detect below"
 
 def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
                       iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, cleaner_size_limit=1080):
+
     if (task_type == 'relate anything'):
         output_images = relate_anything(input_image['image'], num_relation)
         return output_images, gr.Gallery.update(label='relate images')
@@ -566,7 +561,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
             groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
         )
     if boxes_filt.size(0) == 0:
-        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1_[No objects detected, please try others.]_')
+        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
         return [], gr.Gallery.update(label='No objects detected, please try others.😂😂😂😂')
     boxes_filt_ori = copy.deepcopy(boxes_filt)
 
@@ -640,7 +635,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
             # inpainting pipeline
             image_source_for_inpaint = image_pil.resize((512, 512))
             image_mask_for_inpaint = mask_pil.resize((512, 512))
-            image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
+            image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
         else:
             # remove from mask
             logger.info(f'run_anything_task_[{file_temp}]_{task_type}_5_')
@@ -707,6 +702,8 @@ def change_radio_display(task_type, mask_source_radio):
 
 def get_model_device(module):
     try:
+        if module is None:
+            return 'None'
         if isinstance(module, torch.nn.DataParallel):
             module = module.module
         for submodule in module.children():
@@ -714,8 +711,9 @@ def get_model_device(module):
             parameters = submodule._parameters
             if "weight" in parameters:
                 return parameters["weight"].device
+        return 'UnKnown'
     except Exception as e:
-        return 'ohoh'
+        return 'Error'
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
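
The reworked get_model_device() now returns sentinel strings instead of the old 'ohoh' placeholder. Below is a small self-contained sketch of the same convention, with a toy nn.Sequential standing in for the real models, useful for checking where each model landed after loading.

import torch

def get_model_device(module):
    # Mirrors app.py's helper: report the device of the first weighted
    # submodule, or a sentinel string ('None', 'UnKnown', 'Error') otherwise.
    try:
        if module is None:
            return 'None'
        if isinstance(module, torch.nn.DataParallel):
            module = module.module
        for submodule in module.children():
            if "weight" in submodule._parameters:
                return submodule._parameters["weight"].device
        return 'UnKnown'
    except Exception:
        return 'Error'

model = torch.nn.Sequential(torch.nn.Linear(4, 2))
print(get_model_device(model))  # prints the parameter device, e.g. cpu
print(get_model_device(None))   # prints 'None'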
@@ -732,10 +730,12 @@ if __name__ == "__main__":
     load_lama_cleaner_model()
     load_ram_model()
 
-    os.system("pip list")
+    if os.environ.get('IS_MY_DEBUG') is None:
+        os.system("pip list")
+
     print(f'groundingdino_model__{get_model_device(groundingdino_model)}')
     print(f'sam_model__{get_model_device(sam_model)}')
-    print(f'sd_model__{get_model_device(sd_pipe)}')
+    print(f'sd_model__{get_model_device(sd_model)}')
     print(f'lama_cleaner_model__{get_model_device(lama_cleaner_model)}')
     print(f'ram_model__{get_model_device(ram_model)}')
 
@@ -790,3 +790,4 @@
 
     computer_info()
     block.launch(server_name='0.0.0.0', debug=args.debug, share=args.share)
+