LiheYoung commited on
Commit
19c4b4d
1 Parent(s): 2b429ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -1,11 +1,12 @@
 
1
  import gradio as gr
 
 
 
2
  import torch
 
3
  from torchvision.transforms import Compose
4
  import tempfile
5
- from PIL import Image
6
- import numpy as np
7
- import cv2
8
- import torch.nn.functional as F
9
 
10
  from depth_anything.dpt import DPT_DINOv2
11
  from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
@@ -45,6 +46,11 @@ transform = Compose([
45
  PrepareForNet(),
46
  ])
47
 
 
 
 
 
 
48
  with gr.Blocks(css=css) as demo:
49
  gr.Markdown(title)
50
  gr.Markdown(description)
@@ -63,9 +69,8 @@ with gr.Blocks(css=css) as demo:
63
  image = transform({'image': image})['image']
64
  image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)
65
 
66
- with torch.no_grad():
67
- depth = model(image)
68
- depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]
69
 
70
  raw_depth = Image.fromarray(depth.cpu().numpy().astype('uint16'))
71
  tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
 
1
+ import spaces
2
  import gradio as gr
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
  import torch
7
+ import torch.nn.functional as F
8
  from torchvision.transforms import Compose
9
  import tempfile
 
 
 
 
10
 
11
  from depth_anything.dpt import DPT_DINOv2
12
  from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
 
46
  PrepareForNet(),
47
  ])
48
 
49
+ @spaces.GPU
50
+ def predict_depth(model, image):
51
+ with torch.no_grad():
52
+ return model(image)
53
+
54
  with gr.Blocks(css=css) as demo:
55
  gr.Markdown(title)
56
  gr.Markdown(description)
 
69
  image = transform({'image': image})['image']
70
  image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)
71
 
72
+ depth = predict_depth(model, image)
73
+ depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]
 
74
 
75
  raw_depth = Image.fromarray(depth.cpu().numpy().astype('uint16'))
76
  tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)