wzhouxiff committed · Commit c305f12 · 1 parent: 979cf8b

merge all the run function to app.py

Files changed:
- app.py +344 -7
- objctrl_2_5d/utils/ui_utils.py +0 -26
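
This commit moves the GPU-bound handlers (previously built by factory functions in objctrl_2_5d/utils/ui_utils.py) into app.py itself, where they can carry the ZeroGPU @spaces.GPU decorator and be passed straight to the Gradio event bindings. A minimal sketch of that pattern, assuming a ZeroGPU Space with the spaces package installed; the handler names follow the diff, while the widget labels and the placeholder body are illustrative only:

import gradio as gr
import spaces

@spaces.GPU(duration=50)            # ZeroGPU: request a GPU for up to 50 s per call
def get_depth(image, points):
    # placeholder body; the real handler runs d_model_NK.infer_pil(image)
    # and draws the clicked points, as in the diff below
    return None, image, image

with gr.Blocks() as demo:
    original_image = gr.State(None)
    selected_points = gr.State([])
    depth = gr.State(None)
    depth_image = gr.Image(label="depth with points")
    org_depth_image = gr.Image(label="depth")
    depth_button = gr.Button("Estimate depth")

    # the module-level handler is wired directly; no run_depth(d_model_NK) closure factory
    depth_button.click(get_depth,
                       [original_image, selected_points],
                       [depth, depth_image, org_depth_image])
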
app.py
CHANGED
@@ -8,14 +8,23 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
 from omegaconf import OmegaConf
 from PIL import Image
 import numpy as np
+from copy import deepcopy
+import cv2

+import torch.nn.functional as F
+import torchvision
+from einops import rearrange
+import tempfile

+from objctrl_2_5d.utils.ui_utils import process_image, get_camera_pose, get_subject_points, get_points, undo_points, mask_image
+from ZoeDepth.zoedepth.utils.misc import colorize

 from cameractrl.inference import get_pipeline
-from objctrl_2_5d.objctrl_2_5d import run
 from objctrl_2_5d.utils.examples import examples, sync_points

+from objctrl_2_5d.utils.objmask_util import RT2Plucker, Unprojected, roll_with_ignore_multidim, dilate_mask_pytorch
+from objctrl_2_5d.utils.filter_utils import get_freq_filter, freq_mix_3d
+

 ### Title and Description ###
 #### Description ####
@@ -118,7 +127,6 @@ pipeline = get_pipeline(model_id, "unet", model_config['down_block_types'], mode

 ### run the demo ##
 @spaces.GPU(duration=50)
-# def run_segment(segmentor):
 def segment(canvas, image, logits):
     if logits is not None:
         logits *= 32.0
@@ -159,8 +167,338 @@ def segment(canvas, image, logits):

     return mask[0], masked_img, masked_img, logits / 32.0

+@spaces.GPU(duration=50)
+def get_depth(image, points):
+
+    depth = d_model_NK.infer_pil(image)
+    colored_depth = colorize(depth, cmap='gray_r') # [h, w, 4] 0-255
+
+    depth_img = deepcopy(colored_depth[:, :, :3])
+    if len(points) > 0:
+        for idx, point in enumerate(points):
+            if idx % 2 == 0:
+                cv2.circle(depth_img, tuple(point), 10, (255, 0, 0), -1)
+            else:
+                cv2.circle(depth_img, tuple(point), 10, (0, 0, 255), -1)
+            if idx > 0:
+                cv2.arrowedLine(depth_img, points[idx-1], points[idx], (255, 255, 255), 4, tipLength=0.5)
+
+    return depth, depth_img, colored_depth[:, :, :3]
+
+
+@spaces.GPU(duration=50)
+def run_objctrl_2_5d(condition_image,
+                     mask,
+                     depth,
+                     RTs,
+                     bg_mode,
+                     shared_wapring_latents,
+                     scale_wise_masks,
+                     rescale,
+                     seed,
+                     ds, dt,
+                     num_inference_steps=25):
+
+    DEBUG = False
+
+    if DEBUG:
+        cur_OUTPUT_PATH = 'outputs/tmp'
+        os.makedirs(cur_OUTPUT_PATH, exist_ok=True)
+
+    # num_inference_steps=25
+    min_guidance_scale = 1.0
+    max_guidance_scale = 3.0
+
+    area_ratio = 0.3
+    depth_scale_ = 5.2
+    center_margin = 10
+
+    height, width = 320, 576
+    num_frames = 14
+
+    intrinsics = np.array([[float(width), float(width), float(width) / 2, float(height) / 2]])
+    intrinsics = np.repeat(intrinsics, num_frames, axis=0) # [n_frame, 4]
+    fx = intrinsics[0, 0] / width
+    fy = intrinsics[0, 1] / height
+    cx = intrinsics[0, 2] / width
+    cy = intrinsics[0, 3] / height
+
+    down_scale = 8
+    H, W = height // down_scale, width // down_scale
+    K = np.array([[width / down_scale, 0, W / 2], [0, width / down_scale, H / 2], [0, 0, 1]])
+
+    seed = int(seed)
+
+    center_h_margin, center_w_margin = center_margin, center_margin
+    depth_center = np.mean(depth[height//2-center_h_margin:height//2+center_h_margin, width//2-center_w_margin:width//2+center_w_margin])
+
+    if rescale > 0:
+        depth_rescale = round(depth_scale_ * rescale / depth_center, 2)
+    else:
+        depth_rescale = 1.0
+
+    depth = depth * depth_rescale
+
+    depth_down = F.interpolate(torch.tensor(depth).unsqueeze(0).unsqueeze(0),
+                               (H, W), mode='bilinear', align_corners=False).squeeze().numpy() # [H, W]
+
+    ## latent
+    generator = torch.Generator()
+    generator.manual_seed(seed)
+
+    latents_org = pipeline.prepare_latents(
+        1,
+        14,
+        8,
+        height,
+        width,
+        pipeline.dtype,
+        device,
+        generator,
+        None,
+    )
+    latents_org = latents_org / pipeline.scheduler.init_noise_sigma
+
+    cur_plucker_embedding, _, _ = RT2Plucker(RTs, RTs.shape[0], (height, width), fx, fy, cx, cy) # 6, V, H, W
+    cur_plucker_embedding = cur_plucker_embedding.to(device)
+    cur_plucker_embedding = cur_plucker_embedding[None, ...] # b 6 f h w
+    cur_plucker_embedding = cur_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
+    cur_plucker_embedding = cur_plucker_embedding[:, :num_frames, ...]
+    cur_pose_features = pipeline.pose_encoder(cur_plucker_embedding)
+
+    # bg_mode = ["Fixed", "Reverse", "Free"]
+    if bg_mode == "Fixed":
+        fix_RTs = np.repeat(RTs[0][None, ...], num_frames, axis=0) # [n_frame, 4, 3]
+        fix_plucker_embedding, _, _ = RT2Plucker(fix_RTs, num_frames, (height, width), fx, fy, cx, cy) # 6, V, H, W
+        fix_plucker_embedding = fix_plucker_embedding.to(device)
+        fix_plucker_embedding = fix_plucker_embedding[None, ...] # b 6 f h w
+        fix_plucker_embedding = fix_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
+        fix_plucker_embedding = fix_plucker_embedding[:, :num_frames, ...]
+        fix_pose_features = pipeline.pose_encoder(fix_plucker_embedding)
+
+    elif bg_mode == "Reverse":
+        bg_plucker_embedding, _, _ = RT2Plucker(RTs[::-1], RTs.shape[0], (height, width), fx, fy, cx, cy) # 6, V, H, W
+        bg_plucker_embedding = bg_plucker_embedding.to(device)
+        bg_plucker_embedding = bg_plucker_embedding[None, ...] # b 6 f h w
+        bg_plucker_embedding = bg_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
+        bg_plucker_embedding = bg_plucker_embedding[:, :num_frames, ...]
+        fix_pose_features = pipeline.pose_encoder(bg_plucker_embedding)
+
+    else:
+        fix_pose_features = None
+
+    #### preparing mask
+
+    mask = Image.fromarray(mask)
+    mask = mask.resize((W, H))
+    mask = np.array(mask).astype(np.float32)
+    mask = np.expand_dims(mask, axis=-1)
+
+    # visulize mask
+    if DEBUG:
+        mask_sum_vis = mask[..., 0]
+        mask_sum_vis = (mask_sum_vis * 255.0).astype(np.uint8)
+        mask_sum_vis = Image.fromarray(mask_sum_vis)
+
+        mask_sum_vis.save(f'{cur_OUTPUT_PATH}/org_mask.png')
+
+    try:
+        warped_masks = Unprojected(mask, depth_down, RTs, H=H, W=W, K=K)
+
+        warped_masks.insert(0, mask)
+
+    except:
+        # mask to bbox
+        print(f'!!! Mask is too small to warp; mask to bbox')
+        mask = mask[:, :, 0]
+        coords = cv2.findNonZero(mask)
+        x, y, w, h = cv2.boundingRect(coords)
+        # mask[y:y+h, x:x+w] = 1.0
+
+        center_x, center_y = x + w // 2, y + h // 2
+        center_z = depth_down[center_y, center_x]
+
+        # RTs [n_frame, 3, 4] to [n_frame, 4, 4] , add [0, 0, 0, 1]
+        RTs = np.concatenate([RTs, np.array([[[0, 0, 0, 1]]] * num_frames)], axis=1)
+
+        # RTs: world to camera
+        P0 = np.array([center_x, center_y, 1])
+        Pc0 = np.linalg.inv(K) @ P0 * center_z
+        pw = np.linalg.inv(RTs[0]) @ np.array([Pc0[0], Pc0[1], center_z, 1]) # [4]
+
+        P = [np.array([center_x, center_y])]
+        for i in range(1, num_frames):
+            Pci = RTs[i] @ pw
+            Pi = K @ Pci[:3] / Pci[2]
+            P.append(Pi[:2])
+
+        warped_masks = [mask]
+        for i in range(1, num_frames):
+            shift_x = int(round(P[i][0] - P[0][0]))
+            shift_y = int(round(P[i][1] - P[0][1]))
+
+            cur_mask = roll_with_ignore_multidim(mask, [shift_y, shift_x])
+            warped_masks.append(cur_mask)
+
+
+        warped_masks = [v[..., None] for v in warped_masks]
+
+    warped_masks = np.stack(warped_masks, axis=0) # [f, h, w]
+    warped_masks = np.repeat(warped_masks, 3, axis=-1) # [f, h, w, 3]
+
+    mask_sum = np.sum(warped_masks, axis=0, keepdims=True) # [1, H, W, 3]
+    mask_sum[mask_sum > 1.0] = 1.0
+    mask_sum = mask_sum[0,:,:, 0]
+
+    if DEBUG:
+        ## visulize warp mask
+        warp_masks_vis = torch.tensor(warped_masks)
+        warp_masks_vis = (warp_masks_vis * 255.0).to(torch.uint8)
+        torchvision.io.write_video(f'{cur_OUTPUT_PATH}/warped_masks.mp4', warp_masks_vis, fps=10, video_codec='h264', options={'crf': '10'})
+
+        # visulize mask
+        mask_sum_vis = mask_sum
+        mask_sum_vis = (mask_sum_vis * 255.0).astype(np.uint8)
+        mask_sum_vis = Image.fromarray(mask_sum_vis)
+
+        mask_sum_vis.save(f'{cur_OUTPUT_PATH}/merged_mask.png')
+
+    if scale_wise_masks:
+        min_area = H * W * area_ratio # cal in downscale
+        non_zero_len = mask_sum.sum()
+
+        print(f'non_zero_len: {non_zero_len}, min_area: {min_area}')
+
+        if non_zero_len > min_area:
+            kernel_sizes = [1, 1, 1, 3]
+        elif non_zero_len > min_area * 0.5:
+            kernel_sizes = [3, 1, 1, 5]
+        else:
+            kernel_sizes = [5, 3, 3, 7]
+    else:
+        kernel_sizes = [1, 1, 1, 1]
+
+    mask = torch.from_numpy(mask_sum) # [h, w]
+    mask = mask[None, None, ...] # [1, 1, h, w]
+    mask = F.interpolate(mask, (height, width), mode='bilinear', align_corners=False) # [1, 1, H, W]
+    # mask = mask.repeat(1, num_frames, 1, 1) # [1, f, H, W]
+    mask = mask.to(pipeline.dtype).to(device)
+
+    ##### Mask End ######
+
+    ### Got blending pose features Start ###
+
+    pose_features = []
+    for i in range(0, len(cur_pose_features)):
+        kernel_size = kernel_sizes[i]
+        h, w = cur_pose_features[i].shape[-2:]
+
+        if fix_pose_features is None:
+            pose_features.append(torch.zeros_like(cur_pose_features[i]))
+        else:
+            pose_features.append(fix_pose_features[i])
+
+        cur_mask = F.interpolate(mask, (h, w), mode='bilinear', align_corners=False)
+        cur_mask = dilate_mask_pytorch(cur_mask, kernel_size=kernel_size) # [1, 1, H, W]
+        cur_mask = cur_mask.repeat(1, num_frames, 1, 1) # [1, f, H, W]
+
+        if DEBUG:
+            # visulize mask
+            mask_vis = cur_mask[0, 0].cpu().numpy() * 255.0
+            mask_vis = Image.fromarray(mask_vis.astype(np.uint8))
+            mask_vis.save(f'{cur_OUTPUT_PATH}/mask_k{kernel_size}_scale{i}.png')
+
+        cur_mask = cur_mask[None, ...] # [1, 1, f, H, W]
+        pose_features[-1] = cur_pose_features[i] * cur_mask + pose_features[-1] * (1 - cur_mask)
+
+    ### Got blending pose features End ###
+
+    ##### Warp Noise Start ######
+
+    if shared_wapring_latents:
+        noise = latents_org[0, 0].data.cpu().numpy().copy() #[14, 4, 40, 72]
+        noise = np.transpose(noise, (1, 2, 0)) # [40, 72, 4]
+
+        try:
+            warp_noise = Unprojected(noise, depth_down, RTs, H=H, W=W, K=K)
+            warp_noise.insert(0, noise)
+        except:
+            print(f'!!! Noise is too small to warp; mask to bbox')
+
+            warp_noise = [noise]
+            for i in range(1, num_frames):
+                shift_x = int(round(P[i][0] - P[0][0]))
+                shift_y = int(round(P[i][1] - P[0][1]))
+
+                cur_noise= roll_with_ignore_multidim(noise, [shift_y, shift_x])
+                warp_noise.append(cur_noise)
+
+        warp_noise = np.stack(warp_noise, axis=0) # [f, h, w, 4]
+
+        if DEBUG:
+            ## visulize warp noise
+            warp_noise_vis = torch.tensor(warp_noise)[..., :3] * torch.tensor(warped_masks)
+            warp_noise_vis = (warp_noise_vis - warp_noise_vis.min()) / (warp_noise_vis.max() - warp_noise_vis.min())
+            warp_noise_vis = (warp_noise_vis * 255.0).to(torch.uint8)
+
+            torchvision.io.write_video(f'{cur_OUTPUT_PATH}/warp_noise.mp4', warp_noise_vis, fps=10, video_codec='h264', options={'crf': '10'})
+
+
+        warp_latents = torch.tensor(warp_noise).permute(0, 3, 1, 2).to(latents_org.device).to(latents_org.dtype) # [frame, 4, H, W]
+        warp_latents = warp_latents.unsqueeze(0) # [1, frame, 4, H, W]
+
+        warped_masks = torch.tensor(warped_masks).permute(0, 3, 1, 2).unsqueeze(0) # [1, frame, 3, H, W]
+        mask_extend = torch.concat([warped_masks, warped_masks[:,:,0:1]], dim=2) # [1, frame, 4, H, W]
+        mask_extend = mask_extend.to(latents_org.device).to(latents_org.dtype)
+
+        warp_latents = warp_latents * mask_extend + latents_org * (1 - mask_extend)
+        warp_latents = warp_latents.permute(0, 2, 1, 3, 4)
+        random_noise = latents_org.clone().permute(0, 2, 1, 3, 4)
+
+        filter_shape = warp_latents.shape
+
+        freq_filter = get_freq_filter(
+            filter_shape,
+            device = device,
+            filter_type='butterworth',
+            n=4,
+            d_s=ds,
+            d_t=dt
+        )
+
+        warp_latents = freq_mix_3d(warp_latents, random_noise, freq_filter)
+        warp_latents = warp_latents.permute(0, 2, 1, 3, 4)
+
+    else:
+        warp_latents = latents_org.clone()
+
+    generator.manual_seed(42)
+
+    with torch.no_grad():
+        result = pipeline(
+            image=condition_image,
+            pose_embedding=cur_plucker_embedding,
+            height=height,
+            width=width,
+            num_frames=num_frames,
+            num_inference_steps=num_inference_steps,
+            min_guidance_scale=min_guidance_scale,
+            max_guidance_scale=max_guidance_scale,
+            do_image_process=True,
+            generator=generator,
+            output_type='pt',
+            pose_features= pose_features,
+            latents = warp_latents
+        ).frames[0].cpu() #[f, c, h, w]
+
+
+    result = rearrange(result, 'f c h w -> f h w c')
+    result = (result * 255.0).to(torch.uint8)

+    video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
+    torchvision.io.write_video(video_path, result, fps=10, video_codec='h264', options={'crf': '8'})
+
+    return video_path

 # -------------- UI definition --------------
 with gr.Blocks() as demo:
@@ -317,14 +655,13 @@ with gr.Blocks() as demo:
     )

     select_button.click(
-        # run_segment(segmentor),
         segment,
         [canvas, original_image, mask_logits],
         [mask, mask_output, masked_original_image, mask_logits]
     )

     depth_button.click(
+        get_depth,
         [original_image, selected_points],
         [depth, depth_image, org_depth_image]
     )
@@ -347,7 +684,7 @@ with gr.Blocks() as demo:
     )

     generated_button.click(
+        run_objctrl_2_5d,
        [
            original_image,
            mask,
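
When the object mask is too small for Unprojected to warp, run_objctrl_2_5d above falls back to a bounding-box shift: it lifts the mask centre into world coordinates with frame 0's world-to-camera pose, reprojects it with each later pose, and rolls the mask by the resulting pixel offset. A toy sketch of that unproject/reproject step, with made-up poses, depth, and pixel values purely for illustration:

import numpy as np

# downscaled intrinsics as built in the diff: W = 576 // 8 = 72, H = 320 // 8 = 40
K = np.array([[72.0, 0.0, 36.0],
              [0.0, 72.0, 20.0],
              [0.0,  0.0,  1.0]])

center_x, center_y, center_z = 36, 20, 2.0   # mask-centre pixel and its depth (made-up values)

RT0 = np.eye(4)                              # frame 0: world frame coincides with the camera
RT1 = np.eye(4); RT1[0, 3] = 0.1             # frame 1: small sideways translation (made up)

# unproject: pixel -> camera ray scaled by depth -> homogeneous world point
Pc0 = np.linalg.inv(K) @ np.array([center_x, center_y, 1.0]) * center_z
pw = np.linalg.inv(RT0) @ np.array([Pc0[0], Pc0[1], center_z, 1.0])

# reproject into frame 1 and convert back to pixel coordinates
Pc1 = RT1 @ pw
p1 = K @ Pc1[:3] / Pc1[2]

print(p1[:2] - np.array([center_x, center_y]))   # per-frame shift used to roll the mask
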
objctrl_2_5d/utils/ui_utils.py
CHANGED
@@ -1,14 +1,9 @@
-import spaces
-
 import gradio as gr
 from PIL import Image
 import numpy as np

 from copy import deepcopy
 import cv2
-import torch
-
-from ZoeDepth.zoedepth.utils.misc import colorize

 from objctrl_2_5d.utils.vis_camera import vis_camera_rescale
 from objctrl_2_5d.utils.objmask_util import trajectory_to_camera_poses_v1
@@ -102,27 +97,6 @@ def get_points(img,
 def undo_points(original_image):
     return original_image, []

-@spaces.GPU(duration=50)
-def run_depth(d_model_NK):
-    def get_depth(image, points):
-
-        depth = d_model_NK.infer_pil(image)
-        colored_depth = colorize(depth, cmap='gray_r') # [h, w, 4] 0-255
-
-        depth_img = deepcopy(colored_depth[:, :, :3])
-        if len(points) > 0:
-            for idx, point in enumerate(points):
-                if idx % 2 == 0:
-                    cv2.circle(depth_img, tuple(point), 10, (255, 0, 0), -1)
-                else:
-                    cv2.circle(depth_img, tuple(point), 10, (0, 0, 255), -1)
-                if idx > 0:
-                    cv2.arrowedLine(depth_img, points[idx-1], points[idx], (255, 255, 255), 4, tipLength=0.5)
-
-        return depth, depth_img, colored_depth[:, :, :3]
-
-    return get_depth
-

 def interpolate_points(points, num_points):
     x = points[:, 0]