NCJ commited on
Commit
849873b
1 Parent(s): e8a673b

workaround for torch.jit

Browse files
Files changed (1) hide show
  1. demo/mesh_recon.py +6 -1
demo/mesh_recon.py CHANGED
@@ -4,10 +4,14 @@ import numpy as np
4
  import torch
5
  import trimesh
6
 
 
7
 
8
- device = torch.device('cpu')
 
9
 
10
  # use torch hub
 
 
11
  model = torch.hub.load("isl-org/ZoeDepth", "ZoeD_NK", pretrained=True).to(device).eval()
12
 
13
 
@@ -97,6 +101,7 @@ def depth_edges_mask(depth):
97
  return mask
98
 
99
 
 
100
  def mesh_reconstruction(
101
  masked_image: np.ndarray,
102
  mask: np.ndarray,
 
4
  import torch
5
  import trimesh
6
 
7
+ import spaces
8
 
9
+
10
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
 
12
  # use torch hub
13
+ # zeroGPU hack from https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/9
14
+ torch.jit.script = lambda f: f
15
  model = torch.hub.load("isl-org/ZoeDepth", "ZoeD_NK", pretrained=True).to(device).eval()
16
 
17
 
 
101
  return mask
102
 
103
 
104
+ @spaces.GPU
105
  def mesh_reconstruction(
106
  masked_image: np.ndarray,
107
  mask: np.ndarray,