Stable-X commited on
Commit
fd89d5f
1 Parent(s): e4bf056

fix: Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -9
app.py CHANGED
@@ -3,6 +3,7 @@ import time
3
  import torch
4
  import numpy as np
5
  import gradio as gr
 
6
  import tempfile
7
  import subprocess
8
  from dust3r.losses import L21
@@ -11,6 +12,7 @@ from spann3r.datasets import Demo
11
  from torch.utils.data import DataLoader
12
  import trimesh
13
  from scipy.spatial.transform import Rotation
 
14
 
15
  # Default values
16
  DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
@@ -45,10 +47,22 @@ def cat_meshes(meshes):
45
  faces = np.concatenate(faces)
46
  return dict(vertices=vertices, face_colors=colors, faces=faces)
47
 
 
 
 
 
 
 
 
 
 
 
 
48
  def load_model(ckpt_path, device):
49
  model = Spann3R(dus3r_name=DEFAULT_DUST3R_PATH,
50
  use_feat=False).to(device)
51
- model.load_state_dict(torch.load(ckpt_path)['model'])
 
52
  model.eval()
53
  return model
54
 
@@ -91,14 +105,14 @@ def pts3d_to_trimesh(img, pts3d, valid=None):
91
  assert len(faces) == len(face_colors)
92
  return dict(vertices=vertices, face_colors=face_colors, faces=faces)
93
 
 
 
 
94
  @torch.no_grad()
95
- def reconstruct(video_path, conf_thresh, kf_every, voxel_size=0.05, as_pointcloud=False):
96
  # Extract frames from video
97
  demo_path = extract_frames(video_path)
98
 
99
- # Load model
100
- model = load_model(DEFAULT_CKPT_PATH, DEFAULT_DEVICE)
101
-
102
  # Load dataset
103
  dataset = Demo(ROOT=demo_path, resolution=224, full_video=True, kf_every=kf_every)
104
  dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
@@ -168,16 +182,15 @@ iface = gr.Interface(
168
  inputs=[
169
  gr.Video(label="Input Video"),
170
  gr.Slider(0, 1, value=1e-3, label="Confidence Threshold"),
171
- gr.Slider(1, 30, step=1, value=10, label="Keyframe Interval"),
172
- gr.Slider(0.001, 0.01, value=0.005, step=0.001, label="Voxel Size for Downsampling"),
173
  gr.Checkbox(label="As Pointcloud", value=False)
174
  ],
175
  outputs=[
176
  gr.Model3D(label="3D Model (GLB)", display_mode="solid"),
177
  gr.Textbox(label="Status")
178
  ],
179
- title="3D Reconstruction from Video",
180
  )
181
 
182
  if __name__ == "__main__":
183
- iface.launch(share=True)
 
3
  import torch
4
  import numpy as np
5
  import gradio as gr
6
+ import urllib.parse
7
  import tempfile
8
  import subprocess
9
  from dust3r.losses import L21
 
12
  from torch.utils.data import DataLoader
13
  import trimesh
14
  from scipy.spatial.transform import Rotation
15
+ import spaces
16
 
17
  # Default values
18
  DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
 
47
  faces = np.concatenate(faces)
48
  return dict(vertices=vertices, face_colors=colors, faces=faces)
49
 
50
+ def load_ckpt(model_path_or_url, verbose=True):
51
+ if verbose:
52
+ print('... loading model from', model_path_or_url)
53
+ is_url = urllib.parse.urlparse(model_path_or_url).scheme in ('http', 'https')
54
+
55
+ if is_url:
56
+ ckpt = torch.hub.load_state_dict_from_url(model_path_or_url, map_location='cpu', progress=verbose)
57
+ else:
58
+ ckpt = torch.load(model_path_or_url, map_location='cpu')
59
+ return ckpt
60
+
61
  def load_model(ckpt_path, device):
62
  model = Spann3R(dus3r_name=DEFAULT_DUST3R_PATH,
63
  use_feat=False).to(device)
64
+
65
+ model.load_state_dict(load_ckpt(ckpt_path)['model'])
66
  model.eval()
67
  return model
68
 
 
105
  assert len(faces) == len(face_colors)
106
  return dict(vertices=vertices, face_colors=face_colors, faces=faces)
107
 
108
+ model = load_model(DEFAULT_CKPT_PATH, DEFAULT_DEVICE)
109
+
110
+ @spaces.GPU
111
  @torch.no_grad()
112
+ def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False):
113
  # Extract frames from video
114
  demo_path = extract_frames(video_path)
115
 
 
 
 
116
  # Load dataset
117
  dataset = Demo(ROOT=demo_path, resolution=224, full_video=True, kf_every=kf_every)
118
  dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
 
182
  inputs=[
183
  gr.Video(label="Input Video"),
184
  gr.Slider(0, 1, value=1e-3, label="Confidence Threshold"),
185
+ gr.Slider(1, 30, step=1, value=5, label="Keyframe Interval"),
 
186
  gr.Checkbox(label="As Pointcloud", value=False)
187
  ],
188
  outputs=[
189
  gr.Model3D(label="3D Model (GLB)", display_mode="solid"),
190
  gr.Textbox(label="Status")
191
  ],
192
+ title="3D Reconstruction with Spatial Memory",
193
  )
194
 
195
  if __name__ == "__main__":
196
+ iface.launch()