freealise committed on
Commit 722a74e
1 Parent(s): 047eeea

Update app.py

Files changed (1)
  1. app.py +78 -3
app.py CHANGED
@@ -8,6 +8,8 @@ import torch
 import torch.nn.functional as F
 from torchvision import transforms
 from torchvision.transforms import Compose
+import trimesh
+from geometry import create_triangles
 import tempfile
 import spaces
 from zipfile import ZipFile
 
@@ -159,7 +161,73 @@ def make_video(video_path, outdir='./vis_video_depth', encoder='vits'):
     # out.release()
     cv2.destroyAllWindows()
 
-    return final_vid, final_zip #output_path
+    return final_vid, final_zip, orig_frames[0], depth_frames[0] #output_path
+
+def depth_edges_mask(depth):
+    """Returns a mask of edges in the depth map.
+    Args:
+    depth: 2D numpy array of shape (H, W) with dtype float32.
+    Returns:
+    mask: 2D numpy array of shape (H, W) with dtype bool.
+    """
+    # Compute the x and y gradients of the depth map.
+    depth_dx, depth_dy = np.gradient(depth)
+    # Compute the gradient magnitude.
+    depth_grad = np.sqrt(depth_dx ** 2 + depth_dy ** 2)
+    # Compute the edge mask.
+    mask = depth_grad > 0.05
+    return mask
+
+def pano_depth_to_world_points(depth):
+    """
+    360 depth to world points
+    given 2D depth is an equirectangular projection of a spherical image
+    Treat depth as radius
+    longitude : -pi to pi
+    latitude : -pi/2 to pi/2
+    """
+
+    # Convert depth to radius
+    radius = depth.flatten()
+
+    lon = np.linspace(-np.pi, np.pi, depth.shape[1])
+    lat = np.linspace(-np.pi/2, np.pi/2, depth.shape[0])
+
+    lon, lat = np.meshgrid(lon, lat)
+    lon = lon.flatten()
+    lat = lat.flatten()
+
+    # Convert to cartesian coordinates
+    x = radius * np.cos(lat) * np.cos(lon)
+    y = radius * np.cos(lat) * np.sin(lon)
+    z = radius * np.sin(lat)
+
+    pts3d = np.stack([x, y, z], axis=1)
+
+    return pts3d
+
+def get_mesh(image, depth, keep_edges=False):
+    image.thumbnail((1024,1024))  # limit the size of the image
+    pts3d = pano_depth_to_world_points(depth)
+
+    # Create a trimesh mesh from the points
+    # Each pixel is connected to its 4 neighbors
+    # colors are the RGB values of the image
+
+    verts = pts3d.reshape(-1, 3)
+    image = np.array(image)
+    if keep_edges:
+        triangles = create_triangles(image.shape[0], image.shape[1])
+    else:
+        triangles = create_triangles(image.shape[0], image.shape[1], mask=~depth_edges_mask(depth))
+    colors = image.reshape(-1, 3)
+    mesh = trimesh.Trimesh(vertices=verts, faces=triangles, vertex_colors=colors)
+
+    # Save as glb
+    glb_file = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
+    glb_path = glb_file.name
+    mesh.export(glb_path)
+    return glb_path
 
 def loadurl(url):
     return url
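Note: pano_depth_to_world_points() treats the depth map as an equirectangular panorama, so each pixel's depth value becomes a radius at that pixel's (longitude, latitude). The standalone NumPy sketch below reproduces that mapping; the constant-radius synthetic depth map is illustrative and not part of this repo.

    import numpy as np

    # Synthetic (H, W) equirectangular depth map standing in for a model output.
    H, W = 4, 8
    depth = np.full((H, W), 2.0, dtype=np.float32)  # constant radius, purely illustrative

    # Same grid as pano_depth_to_world_points: lon in [-pi, pi], lat in [-pi/2, pi/2].
    lon, lat = np.meshgrid(np.linspace(-np.pi, np.pi, W),
                           np.linspace(-np.pi / 2, np.pi / 2, H))
    radius = depth.flatten()
    x = radius * np.cos(lat.flatten()) * np.cos(lon.flatten())
    y = radius * np.cos(lat.flatten()) * np.sin(lon.flatten())
    z = radius * np.sin(lat.flatten())
    pts3d = np.stack([x, y, z], axis=1)  # (H*W, 3), one vertex per pixel

    # The norm of each point recovers the input depth, confirming depth is used as a radius.
    assert np.allclose(np.linalg.norm(pts3d, axis=1), radius, atol=1e-5)

When "Keep occlusion edges" is unchecked, get_mesh() passes mask=~depth_edges_mask(depth) to create_triangles(), so pixels whose depth gradient magnitude exceeds 0.05 are excluded and the mesh is split along occlusion boundaries instead of bridging them.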
 
@@ -212,10 +280,16 @@ with gr.Blocks(css=css) as demo:
             input_video = gr.Video(label="Input Video", format="mp4")
             input_url.change(fn=loadurl, inputs=[input_url], outputs=[input_video])
             submit = gr.Button("Submit")
+            render = gr.Button("Render")
         with gr.Column():
             model_type = gr.Dropdown([("small", "vits"), ("base", "vitb"), ("large", "vitl")], type="value", value="vits", label='Model Type')
+            checkbox = gr.Checkbox(label="Keep occlusion edges", value=True)
             processed_video = gr.Video(label="Output Video", format="mp4")
             processed_zip = gr.File(label="Output Archive")
+            output_frame = gr.Image(label="Frame", type='pil')
+            output_depth = gr.Image(label="Depth", type='pil')
+            result = gr.Model3D(label="3D Mesh", clear_color=[
+                1.0, 1.0, 1.0, 1.0])
 
     def on_submit(uploaded_video,model_type):
 
 
@@ -224,12 +298,13 @@ with gr.Blocks(css=css) as demo:
 
         return output_video_path
 
-    submit.click(on_submit, inputs=[input_video, model_type], outputs=[processed_video, processed_zip])
+    submit.click(on_submit, inputs=[input_video, model_type], outputs=[processed_video, processed_zip, output_frame, output_depth])
+    render.click(get_mesh, inputs=[output_frame, output_depth, checkbox], outputs=[result])
 
     example_files = os.listdir('examples')
     example_files.sort()
     example_files = [os.path.join('examples', filename) for filename in example_files]
-    examples = gr.Examples(examples=example_files, inputs=[input_video], outputs=[processed_video, processed_zip], fn=on_submit, cache_examples=True)
+    examples = gr.Examples(examples=example_files, inputs=[input_video], outputs=[processed_video, processed_zip, output_frame, output_depth], fn=on_submit, cache_examples=True)
 
 
 if __name__ == '__main__':
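With this wiring, the Render button forwards the cached frame and depth images (gr.Image components with type='pil') into get_mesh() and shows the returned GLB file in the Model3D viewer. Below is a hedged sketch of the equivalent call outside the UI: the file names are placeholders, importing app builds the Gradio interface as a side effect, and the explicit float32 conversion of the depth image is an assumption, since pano_depth_to_world_points() indexes depth as a NumPy array.

    import numpy as np
    from PIL import Image
    from app import get_mesh  # note: importing app also constructs the Gradio Blocks

    # Placeholder inputs; frame and depth should share the same resolution here.
    frame = Image.open("example_frame.png").convert("RGB")     # RGB video frame
    depth_img = Image.open("example_depth.png").convert("L")   # grayscale depth frame
    depth = np.asarray(depth_img, dtype=np.float32) / 255.0    # assumed normalization to [0, 1]

    # keep_edges=True mirrors the "Keep occlusion edges" checkbox default in the UI.
    glb_path = get_mesh(frame, depth, keep_edges=True)
    print("GLB mesh written to", glb_path)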