zixinz commited on
Commit
5458ff3
·
1 Parent(s): 69b2678

depth estimatro

Browse files
Files changed (1) hide show
  1. code_depth/depth_infer.py +58 -72
code_depth/depth_infer.py CHANGED
@@ -1,87 +1,73 @@
1
- # app.py
2
  import os
3
- import pathlib
4
- import subprocess
5
- import gradio as gr
6
- import spaces
7
  import torch
 
 
8
  from PIL import Image
9
 
10
- BASE_DIR = pathlib.Path(__file__).resolve().parent
11
- SCRIPT_DIR = BASE_DIR / "code_depth"
12
- GET_WEIGHTS_SH = SCRIPT_DIR / "get_weights.sh"
13
-
14
- # 让我们能 import 到 code_depth/depth_infer.py
15
  import sys
16
- if str(SCRIPT_DIR) not in sys.path:
17
- sys.path.append(str(SCRIPT_DIR))
18
 
19
- from depth_infer import DepthModel # noqa
20
 
21
- def _ensure_executable(p: pathlib.Path):
22
- if not p.exists():
23
- raise FileNotFoundError(f"Not found: {p}")
24
- os.chmod(p, os.stat(p).st_mode | 0o111)
25
 
26
- def ensure_weights():
27
- """在 code_depth 目录下运行你的 get_weights.sh。"""
28
- _ensure_executable(GET_WEIGHTS_SH)
29
- subprocess.run(
30
- ["bash", str(GET_WEIGHTS_SH)],
31
- check=True,
32
- cwd=str(SCRIPT_DIR),
33
- env={**os.environ, "HF_HUB_DISABLE_TELEMETRY": "1"},
34
- )
35
- ckpt_dir = SCRIPT_DIR / "checkpoints"
36
- if not ckpt_dir.exists():
37
- raise RuntimeError("weights download script ran but checkpoints/ not found")
38
- return str(ckpt_dir)
39
 
40
- # 启动时下载权重(不开持久化时,若环境重建会再次下载)
41
- try:
42
- CKPT_DIR = ensure_weights()
43
- print(f"✅ Weights ready in: {CKPT_DIR}")
44
- except Exception as e:
45
- print(f"⚠️ Failed to prepare weights: {e}")
46
 
47
- # 模型缓存(按 encoder 复用)
48
- _MODELS: dict[str, DepthModel] = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- def get_model(encoder: str) -> DepthModel:
51
- if encoder not in _MODELS:
52
- _MODELS[encoder] = DepthModel(BASE_DIR, encoder=encoder)
53
- return _MODELS[encoder]
54
 
55
- @spaces.GPU
56
- def infer_depth(
57
- image: Image.Image,
58
- encoder: str = "vitl",
59
- max_res: int = 1280,
60
- input_size: int = 518,
61
- fp32: bool = False,
62
- grayscale: bool = False,
63
- ) -> Image.Image:
64
- # 这里才真正触发 CUDA 设备占用
65
- device = "cuda" if torch.cuda.is_available() else "cpu"
66
- print(f"[infer] device={device}, encoder={encoder}, max_res={max_res}, input_size={input_size}, fp32={fp32}, gray={grayscale}")
67
- model = get_model(encoder)
68
- return model.infer(image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=grayscale)
69
 
70
- with gr.Blocks() as demo:
71
- gr.Markdown("## GeoRemover · Depth Preview (Video-Depth-Anything)")
72
- with gr.Row():
73
- with gr.Column():
74
- inp = gr.Image(label="Upload image", type="pil")
75
- encoder = gr.Dropdown(["vits", "vitl"], value="vitl", label="Encoder")
76
- max_res = gr.Slider(512, 2048, value=1280, step=64, label="Max resolution")
77
- input_size = gr.Slider(256, 1024, value=518, step=2, label="Model input_size")
78
- fp32 = gr.Checkbox(False, label="Use FP32 (default FP16)")
79
- gray = gr.Checkbox(False, label="Grayscale depth")
80
- btn = gr.Button("Run")
81
- with gr.Column():
82
- out = gr.Image(label="Depth visualization")
83
 
84
- btn.click(fn=infer_depth, inputs=[inp, encoder, max_res, input_size, fp32, gray], outputs=[out])
 
85
 
86
- if __name__ == "__main__":
87
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
+ # code_depth/depth_infer.py
2
  import os
3
+ from pathlib import Path
4
+ import numpy as np
 
 
5
  import torch
6
+ import cv2
7
+ import matplotlib.cm as cm
8
  from PIL import Image
9
 
10
+ # `from video_depth_anything.video_depth import VideoDepthAnything` 能被找到
11
+ HERE = Path(__file__).resolve().parent
 
 
 
12
  import sys
13
+ if str(HERE) not in sys.path:
14
+ sys.path.append(str(HERE))
15
 
16
+ from video_depth_anything.video_depth import VideoDepthAnything # noqa
17
 
18
+ _MODEL_CFGS = {
19
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
20
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
21
+ }
22
 
23
+ class DepthModel:
24
+ def __init__(self, repo_root: Path, encoder: str = "vitl", device: str | None = None):
25
+ self.encoder = encoder
26
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
27
+ self.model = VideoDepthAnything(**_MODEL_CFGS[encoder]).to(self.device).eval()
 
 
 
 
 
 
 
 
28
 
29
+ ckpt = repo_root / "code_depth" / "checkpoints" / f"video_depth_anything_{encoder}.pth"
30
+ if not ckpt.is_file():
31
+ raise FileNotFoundError(f"Checkpoint not found: {ckpt}")
32
+ state = torch.load(str(ckpt), map_location="cpu")
33
+ self.model.load_state_dict(state, strict=True)
 
34
 
35
+ @torch.inference_mode()
36
+ def infer(
37
+ self,
38
+ image: Image.Image | np.ndarray,
39
+ max_res: int = 1280,
40
+ input_size: int = 518,
41
+ fp32: bool = False,
42
+ grayscale: bool = False,
43
+ ) -> Image.Image:
44
+ """返回一张深度可视化图(PIL.Image)。"""
45
+ if isinstance(image, Image.Image):
46
+ rgb = np.array(image.convert("RGB"))
47
+ else:
48
+ # 假设是 numpy 的 RGB/HWC
49
+ assert image.ndim == 3 and image.shape[2] in (3, 4), "Expect HxWxC image"
50
+ rgb = image[..., :3].copy()
51
 
52
+ h, w = rgb.shape[:2]
53
+ if max(h, w) > max_res:
54
+ scale = max_res / max(h, w)
55
+ rgb = cv2.resize(rgb, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
56
 
57
+ # 模型接口是“视频深度”,单帧就堆一维
58
+ frame_tensor = np.stack([rgb], axis=0)
59
+ depths, _ = self.model.infer_video_depth(
60
+ frame_tensor, 32, input_size=input_size, device=self.device, fp32=fp32
61
+ )
62
+ depth = depths[0]
 
 
 
 
 
 
 
 
63
 
64
+ # 可视化
65
+ d_min, d_max = depth.min(), depth.max()
66
+ depth_norm = ((depth - d_min) / (d_max - d_min + 1e-6) * 255).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
67
 
68
+ if grayscale:
69
+ return Image.fromarray(depth_norm, mode="L")
70
 
71
+ cmap = np.array(cm.get_cmap("inferno").colors)
72
+ depth_vis = (cmap[depth_norm] * 255).astype(np.uint8)
73
+ return Image.fromarray(depth_vis)