JohanDL commited on
Commit
123829d
1 Parent(s): ee7a6d5
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
 
4
  import os
5
  import torch
6
  import torch.nn.functional as F
@@ -12,10 +13,10 @@ from depth_anything.dpt import DepthAnything
12
  from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
13
 
14
  @torch.no_grad()
15
- @spaces.GPU
16
  def predict_depth(model, image):
17
- return model(image)
18
 
 
19
  def make_video(video_path, outdir='./vis_video_depth',encoder='vitl'):
20
  if encoder not in ["vitl","vitb","vits"]:
21
  encoder = "vits"
@@ -28,7 +29,8 @@ def make_video(video_path, outdir='./vis_video_depth',encoder='vitl'):
28
 
29
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
30
  DEVICE = "cuda"
31
- depth_anything = DepthAnything.from_pretrained('LiheYoung/depth_anything_{}14'.format(encoder)).to(DEVICE).eval()
 
32
 
33
  total_params = sum(param.numel() for param in depth_anything.parameters())
34
  print('Total parameters: {:.2f}M'.format(total_params / 1e6))
 
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
+ from transformers import pipeline
5
  import os
6
  import torch
7
  import torch.nn.functional as F
 
13
  from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
14
 
15
  @torch.no_grad()
 
16
  def predict_depth(model, image):
17
+ return model(image)["depth"]
18
 
19
+ @spaces.GPU
20
  def make_video(video_path, outdir='./vis_video_depth',encoder='vitl'):
21
  if encoder not in ["vitl","vitb","vits"]:
22
  encoder = "vits"
 
29
 
30
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
31
  DEVICE = "cuda"
32
+ # depth_anything = DepthAnything.from_pretrained('LiheYoung/depth_anything_{}14'.format(encoder)).to(DEVICE).eval()
33
+ depth_anything = pipeline(task = "depth-estimation", model="nielsr/depth-anything-small", device=0)
34
 
35
  total_params = sum(param.numel() for param in depth_anything.parameters())
36
  print('Total parameters: {:.2f}M'.format(total_params / 1e6))