Spaces:
Running
on
Zero
Running
on
Zero
pablovela5620
commited on
Upload gradio_app.py with huggingface_hub
Browse files- gradio_app.py +331 -0
gradio_app.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL
|
2 |
+
import PIL.Image
|
3 |
+
from PIL.Image import Image
|
4 |
+
|
5 |
+
from src.rr_logging_utils import (
|
6 |
+
log_camera,
|
7 |
+
create_svd_blueprint,
|
8 |
+
)
|
9 |
+
|
10 |
+
from src.pose_utils import generate_camera_parameters
|
11 |
+
from src.camera_parameters import PinholeParameters
|
12 |
+
from src.depth_utils import image_to_depth
|
13 |
+
from src.image_warping import image_depth_warping
|
14 |
+
from src.sigma_utils import load_lambda_ts
|
15 |
+
from src.nerfstudio_data import frames_to_nerfstudio
|
16 |
+
|
17 |
+
|
18 |
+
import gradio as gr
|
19 |
+
from gradio_rerun import Rerun
|
20 |
+
|
21 |
+
import rerun as rr
|
22 |
+
import rerun.blueprint as rrb
|
23 |
+
|
24 |
+
import numpy as np
|
25 |
+
import PIL
|
26 |
+
import torch
|
27 |
+
from pathlib import Path
|
28 |
+
import threading
|
29 |
+
from queue import SimpleQueue
|
30 |
+
import trimesh
|
31 |
+
import subprocess
|
32 |
+
|
33 |
+
import mmcv
|
34 |
+
from uuid import uuid4
|
35 |
+
|
36 |
+
from typing import Final, Literal
|
37 |
+
|
38 |
+
from jaxtyping import Float64, Float32, UInt8
|
39 |
+
|
40 |
+
from monopriors.relative_depth_models import (
|
41 |
+
get_relative_predictor,
|
42 |
+
)
|
43 |
+
|
44 |
+
from src.custom_diffusers_pipeline.svd import StableVideoDiffusionPipeline
|
45 |
+
from src.custom_diffusers_pipeline.scheduler import EulerDiscreteScheduler
|
46 |
+
|
47 |
+
try:
|
48 |
+
import spaces # type: ignore
|
49 |
+
|
50 |
+
IN_SPACES = True
|
51 |
+
except ImportError:
|
52 |
+
print("Not running on Zero")
|
53 |
+
IN_SPACES = False
|
54 |
+
|
55 |
+
|
56 |
+
SVD_HEIGHT: Final[int] = 576
|
57 |
+
SVD_WIDTH: Final[int] = 1024
|
58 |
+
NEAR: Final[float] = 0.0001
|
59 |
+
FAR: Final[float] = 500.0
|
60 |
+
|
61 |
+
if gr.NO_RELOAD:
|
62 |
+
DepthAnythingV2Predictor = get_relative_predictor("DepthAnythingV2Predictor")(
|
63 |
+
device="cuda"
|
64 |
+
)
|
65 |
+
SVD_PIPE = StableVideoDiffusionPipeline.from_pretrained(
|
66 |
+
"stabilityai/stable-video-diffusion-img2vid-xt",
|
67 |
+
torch_dtype=torch.float16,
|
68 |
+
variant="fp16",
|
69 |
+
)
|
70 |
+
SVD_PIPE.to("cuda")
|
71 |
+
scheduler = EulerDiscreteScheduler.from_config(SVD_PIPE.scheduler.config)
|
72 |
+
SVD_PIPE.scheduler = scheduler
|
73 |
+
|
74 |
+
|
75 |
+
def svd_render_threaded(
|
76 |
+
image_o: PIL.Image.Image,
|
77 |
+
masks: Float64[torch.Tensor, "b 72 128"],
|
78 |
+
cond_image: PIL.Image.Image,
|
79 |
+
lambda_ts: Float64[torch.Tensor, "n b"],
|
80 |
+
num_denoise_iters: Literal[2, 25, 50, 100],
|
81 |
+
weight_clamp: float,
|
82 |
+
log_queue: SimpleQueue | None = None,
|
83 |
+
):
|
84 |
+
frames: list[PIL.Image.Image] = SVD_PIPE(
|
85 |
+
[image_o],
|
86 |
+
log_queue=log_queue,
|
87 |
+
temp_cond=cond_image,
|
88 |
+
mask=masks,
|
89 |
+
lambda_ts=lambda_ts,
|
90 |
+
weight_clamp=weight_clamp,
|
91 |
+
num_frames=25,
|
92 |
+
decode_chunk_size=8,
|
93 |
+
num_inference_steps=num_denoise_iters,
|
94 |
+
).frames[0]
|
95 |
+
|
96 |
+
log_queue.put(frames)
|
97 |
+
|
98 |
+
|
99 |
+
if IN_SPACES:
|
100 |
+
svd_render_threaded = spaces.GPU(svd_render_threaded)
|
101 |
+
|
102 |
+
|
103 |
+
@rr.thread_local_stream("warped_image")
|
104 |
+
def gradio_warped_image(
|
105 |
+
image_path: str,
|
106 |
+
num_denoise_iters: Literal[2, 25, 50, 100],
|
107 |
+
direction: Literal["left", "right"],
|
108 |
+
degrees_per_frame: int | float,
|
109 |
+
major_radius: float = 60.0,
|
110 |
+
minor_radius: float = 70.0,
|
111 |
+
num_frames: int = 25, # StableDiffusion Video generates 25 frames
|
112 |
+
progress=gr.Progress(track_tqdm=True),
|
113 |
+
):
|
114 |
+
# ensure that the degrees per frame is a float
|
115 |
+
degrees_per_frame = float(degrees_per_frame)
|
116 |
+
|
117 |
+
image_path: Path = Path(image_path) if isinstance(image_path, str) else image_path
|
118 |
+
assert image_path.exists(), f"Image file not found: {image_path}"
|
119 |
+
save_path: Path = image_path.parent / f"{image_path.stem}_{uuid4()}"
|
120 |
+
|
121 |
+
# setup rerun logging
|
122 |
+
stream = rr.binary_stream()
|
123 |
+
parent_log_path = Path("world")
|
124 |
+
rr.log(f"{parent_log_path}", rr.ViewCoordinates.LDB, static=True)
|
125 |
+
blueprint: rrb.Blueprint = create_svd_blueprint(parent_log_path)
|
126 |
+
rr.send_blueprint(blueprint)
|
127 |
+
|
128 |
+
# Load image and resize to SVD dimensions
|
129 |
+
rgb_original: Image = PIL.Image.open(image_path)
|
130 |
+
rgb_resized: Image = rgb_original.resize(
|
131 |
+
(SVD_WIDTH, SVD_HEIGHT), PIL.Image.Resampling.NEAREST
|
132 |
+
)
|
133 |
+
rgb_np_original: UInt8[np.ndarray, "h w 3"] = np.array(rgb_original)
|
134 |
+
rgb_np_hw3: UInt8[np.ndarray, "h w 3"] = np.array(rgb_resized)
|
135 |
+
|
136 |
+
# generate initial camera parameters for video trajectory
|
137 |
+
camera_list: list[PinholeParameters] = generate_camera_parameters(
|
138 |
+
num_frames=num_frames,
|
139 |
+
image_width=SVD_WIDTH,
|
140 |
+
image_height=SVD_HEIGHT,
|
141 |
+
degrees_per_frame=degrees_per_frame,
|
142 |
+
major_radius=major_radius,
|
143 |
+
minor_radius=minor_radius,
|
144 |
+
direction=direction,
|
145 |
+
)
|
146 |
+
|
147 |
+
assert len(camera_list) == num_frames, "Number of camera parameters mismatch"
|
148 |
+
|
149 |
+
# Estimate depth map and pointcloud for the input image
|
150 |
+
depth: Float32[np.ndarray, "h w"]
|
151 |
+
trimesh_pc: trimesh.PointCloud
|
152 |
+
depth_original: Float32[np.ndarray, "original_h original_w"]
|
153 |
+
trimesh_pc_original: trimesh.PointCloud
|
154 |
+
|
155 |
+
depth, trimesh_pc, depth_original, trimesh_pc_original = image_to_depth(
|
156 |
+
rgb_np_original=rgb_np_original,
|
157 |
+
rgb_np_hw3=rgb_np_hw3,
|
158 |
+
cam_params=camera_list[0],
|
159 |
+
near=NEAR,
|
160 |
+
far=FAR,
|
161 |
+
depth_predictor=DepthAnythingV2Predictor,
|
162 |
+
)
|
163 |
+
|
164 |
+
rr.log(
|
165 |
+
f"{parent_log_path}/point_cloud",
|
166 |
+
rr.Points3D(
|
167 |
+
positions=trimesh_pc.vertices,
|
168 |
+
colors=trimesh_pc.colors,
|
169 |
+
),
|
170 |
+
static=True,
|
171 |
+
)
|
172 |
+
|
173 |
+
start_cam: PinholeParameters = camera_list[0]
|
174 |
+
cond_image: list[PIL.Image.Image] = []
|
175 |
+
masks: list[Float64[torch.Tensor, "1 72 128"]] = []
|
176 |
+
|
177 |
+
# Perform image depth warping to generated camera parameters
|
178 |
+
current_cam: PinholeParameters
|
179 |
+
for frame_id, current_cam in enumerate(camera_list):
|
180 |
+
rr.set_time_sequence("frame_id", frame_id)
|
181 |
+
if frame_id == 0:
|
182 |
+
cam_log_path: Path = parent_log_path / "warped_camera"
|
183 |
+
log_camera(cam_log_path, current_cam, rgb_np_hw3, depth)
|
184 |
+
else:
|
185 |
+
# clear logged depth from the previous frame
|
186 |
+
rr.log(f"{cam_log_path}/pinhole/depth", rr.Clear(recursive=False))
|
187 |
+
cam_log_path: Path = parent_log_path / "warped_camera"
|
188 |
+
# do image warping
|
189 |
+
warped_frame2, mask_erosion_tensor = image_depth_warping(
|
190 |
+
image=rgb_np_hw3,
|
191 |
+
depth=depth,
|
192 |
+
cam_T_world_44_s=start_cam.extrinsics.cam_T_world,
|
193 |
+
cam_T_world_44_t=current_cam.extrinsics.cam_T_world,
|
194 |
+
K=current_cam.intrinsics.k_matrix,
|
195 |
+
)
|
196 |
+
cond_image.append(warped_frame2)
|
197 |
+
masks.append(mask_erosion_tensor)
|
198 |
+
|
199 |
+
log_camera(cam_log_path, current_cam, np.asarray(warped_frame2))
|
200 |
+
yield stream.read(), None, [], ""
|
201 |
+
|
202 |
+
masks: Float64[torch.Tensor, "b 72 128"] = torch.cat(masks)
|
203 |
+
# load sigmas to optimize for timestep
|
204 |
+
progress(0.1, desc="Optimizing timesteps for diffusion")
|
205 |
+
lambda_ts: Float64[torch.Tensor, "n b"] = load_lambda_ts(num_denoise_iters)
|
206 |
+
progress(0.15, desc="Starting diffusion")
|
207 |
+
|
208 |
+
# to allow logging from a separate thread
|
209 |
+
log_queue: SimpleQueue = SimpleQueue()
|
210 |
+
handle = threading.Thread(
|
211 |
+
target=svd_render_threaded,
|
212 |
+
kwargs={
|
213 |
+
"image_o": rgb_resized,
|
214 |
+
"masks": masks,
|
215 |
+
"cond_image": cond_image,
|
216 |
+
"lambda_ts": lambda_ts,
|
217 |
+
"num_denoise_iters": num_denoise_iters,
|
218 |
+
"weight_clamp": 0.2,
|
219 |
+
"log_queue": log_queue,
|
220 |
+
},
|
221 |
+
)
|
222 |
+
|
223 |
+
handle.start()
|
224 |
+
i = 0
|
225 |
+
while True:
|
226 |
+
msg = log_queue.get()
|
227 |
+
match msg:
|
228 |
+
case frames if all(isinstance(frame, PIL.Image.Image) for frame in frames):
|
229 |
+
break
|
230 |
+
case entity_path, entity, times:
|
231 |
+
i += 1
|
232 |
+
rr.reset_time()
|
233 |
+
for timeline, time in times:
|
234 |
+
if isinstance(time, int):
|
235 |
+
rr.set_time_sequence(timeline, time)
|
236 |
+
else:
|
237 |
+
rr.set_time_seconds(timeline, time)
|
238 |
+
static = False
|
239 |
+
if entity_path == "diffusion_step":
|
240 |
+
static = True
|
241 |
+
rr.log(entity_path, entity, static=static)
|
242 |
+
yield stream.read(), None, [], f"{i} out of {num_denoise_iters}"
|
243 |
+
case _:
|
244 |
+
assert False
|
245 |
+
handle.join()
|
246 |
+
|
247 |
+
# all frames but the first one
|
248 |
+
frame: np.ndarray
|
249 |
+
for frame_id, (frame, cam_pararms) in enumerate(zip(frames, camera_list)):
|
250 |
+
# add one since the first frame is the original image
|
251 |
+
rr.set_time_sequence("frame_id", frame_id)
|
252 |
+
cam_log_path = parent_log_path / "generated_camera"
|
253 |
+
generated_rgb_np: UInt8[np.ndarray, "h w 3"] = np.array(frame)
|
254 |
+
log_camera(cam_log_path, cam_pararms, generated_rgb_np, depth=None)
|
255 |
+
yield stream.read(), None, [], "finished"
|
256 |
+
|
257 |
+
frames_to_nerfstudio(
|
258 |
+
rgb_np_original, frames, trimesh_pc_original, camera_list, save_path
|
259 |
+
)
|
260 |
+
# zip up nerfstudio data
|
261 |
+
zip_file_path = save_path / "nerfstudio.zip"
|
262 |
+
progress(0.95, desc="Zipping up camera data in nerfstudio format")
|
263 |
+
# Run the zip command
|
264 |
+
subprocess.run(["zip", "-r", str(zip_file_path), str(save_path)], check=True)
|
265 |
+
video_file_path = save_path / "output.mp4"
|
266 |
+
mmcv.frames2video(str(save_path), str(video_file_path), fps=7)
|
267 |
+
print(f"Video saved to {video_file_path}")
|
268 |
+
yield stream.read(), video_file_path, [str(zip_file_path)], "finished"
|
269 |
+
|
270 |
+
|
271 |
+
with gr.Blocks() as demo:
|
272 |
+
with gr.Tab("Streaming"):
|
273 |
+
with gr.Row():
|
274 |
+
img = gr.Image(interactive=True, label="Image", type="filepath")
|
275 |
+
with gr.Tab(label="Settings"):
|
276 |
+
with gr.Column():
|
277 |
+
warp_img_btn = gr.Button("Warp Images")
|
278 |
+
num_iters = gr.Radio(
|
279 |
+
choices=[2, 25, 50, 100],
|
280 |
+
value=2,
|
281 |
+
label="Number of iterations",
|
282 |
+
type="value",
|
283 |
+
)
|
284 |
+
cam_direction = gr.Radio(
|
285 |
+
choices=["left", "right"],
|
286 |
+
value="left",
|
287 |
+
label="Camera direction",
|
288 |
+
type="value",
|
289 |
+
)
|
290 |
+
degrees_per_frame = gr.Slider(
|
291 |
+
minimum=0.25,
|
292 |
+
maximum=1.0,
|
293 |
+
step=0.05,
|
294 |
+
value=0.3,
|
295 |
+
label="Degrees per frame",
|
296 |
+
)
|
297 |
+
iteration_num = gr.Textbox(
|
298 |
+
value="",
|
299 |
+
label="Current Diffusion Step",
|
300 |
+
)
|
301 |
+
with gr.Tab(label="Outputs"):
|
302 |
+
video_output = gr.Video(interactive=False)
|
303 |
+
image_files_output = gr.File(interactive=False, file_count="multiple")
|
304 |
+
|
305 |
+
# Rerun 0.16 has issues when embedded in a Gradio tab, so we share a viewer between all the tabs.
|
306 |
+
# In 0.17 we can instead scope each viewer to its own tab to clean up these examples further.
|
307 |
+
with gr.Row():
|
308 |
+
viewer = Rerun(
|
309 |
+
streaming=True,
|
310 |
+
)
|
311 |
+
|
312 |
+
warp_img_btn.click(
|
313 |
+
gradio_warped_image,
|
314 |
+
inputs=[img, num_iters, cam_direction, degrees_per_frame],
|
315 |
+
outputs=[viewer, video_output, image_files_output, iteration_num],
|
316 |
+
)
|
317 |
+
|
318 |
+
gr.Examples(
|
319 |
+
[
|
320 |
+
[
|
321 |
+
"/home/pablo/0Dev/docker/.per/repos/NVS_Solver/example_imgs/single/000001.jpg",
|
322 |
+
],
|
323 |
+
],
|
324 |
+
fn=warp_img_btn,
|
325 |
+
inputs=[img, num_iters, cam_direction, degrees_per_frame],
|
326 |
+
outputs=[viewer, video_output, image_files_output],
|
327 |
+
)
|
328 |
+
|
329 |
+
|
330 |
+
if __name__ == "__main__":
|
331 |
+
demo.queue().launch()
|