haodongli commited on
Commit
7fc5c17
·
1 Parent(s): dd7650e

add device

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -65,11 +65,12 @@ def ply2glb(ply_path, glb_path):
65
  cloud.export(glb_path)
66
  os.remove(ply_path)
67
 
68
- @spaces.GPU
69
  def fn(image_path, mask_path):
 
70
  name_base, _ = os.path.splitext(os.path.basename(image_path))
71
  config, accelerator = prepare_to_run_demo()
72
  model = load_model(config, accelerator)
 
73
  image, cv2_image, mask = load_infer_data_demo(image_path, mask_path,
74
  model_dtype=config['spherevit']['dtype'], device=accelerator.device)
75
  if torch.backends.mps.is_available():
 
65
  cloud.export(glb_path)
66
  os.remove(ply_path)
67
 
 
68
  def fn(image_path, mask_path):
69
+ device = "cuda" if torch.cuda.is_available() else "cpu"
70
  name_base, _ = os.path.splitext(os.path.basename(image_path))
71
  config, accelerator = prepare_to_run_demo()
72
  model = load_model(config, accelerator)
73
+ model = model.to(device)
74
  image, cv2_image, mask = load_infer_data_demo(image_path, mask_path,
75
  model_dtype=config['spherevit']['dtype'], device=accelerator.device)
76
  if torch.backends.mps.is_available():