jeffaudi commited on
Commit
9f2a44b
1 Parent(s): 1377fc6

Testing device=cuda

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -43,7 +43,7 @@ def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
43
  args.device = device
44
 
45
  cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
46
- checkpoint = torch.load(cache_file, map_location='cpu')
47
  log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
48
  print("Model loaded from {} \n => {}".format(cache_file, log))
49
  _ = model.eval()
@@ -65,7 +65,7 @@ def image_transform_grounding_for_vis(init_image):
65
  image, _ = transform(init_image, None) # 3, h, w
66
  return image
67
 
68
- model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
69
 
70
  def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
71
  init_image = input_image.convert("RGB")
 
43
  args.device = device
44
 
45
  cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
46
+ checkpoint = torch.load(cache_file, map_location=device)
47
  log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
48
  print("Model loaded from {} \n => {}".format(cache_file, log))
49
  _ = model.eval()
 
65
  image, _ = transform(init_image, None) # 3, h, w
66
  return image
67
 
68
+ model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae, device=device)
69
 
70
  def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
71
  init_image = input_image.convert("RGB")