Drexubery committed
Commit 600dfac
1 Parent(s): ec0c93a
Files changed (2):
  1. utils/pvd_utils.py +28 -0
  2. viewcrafter.py +1 -5
utils/pvd_utils.py CHANGED
@@ -32,6 +32,7 @@ sys.path.append('./extern/dust3r')
 from dust3r.utils.device import to_numpy
 import matplotlib.pyplot as plt
 import matplotlib.colors as mcolors
+from torchvision.transforms import CenterCrop, Compose, Resize
 
 def save_video(data,images_path,folder=None):
     if isinstance(data, np.ndarray):
@@ -521,3 +522,30 @@ def visualizer_frame(camera_poses, highlight_index):
     return img
 
 
+def center_crop_image(input_image):
+
+    height = 576
+    width = 1024
+    _,_,h,w = input_image.shape
+    h_ratio = h / height
+    w_ratio = w / width
+
+    if h_ratio > w_ratio:
+        h = int(h / w_ratio)
+        if h < height:
+            h = height
+        input_image = Resize((h, width))(input_image)
+
+    else:
+        w = int(w / h_ratio)
+        if w < width:
+            w = width
+        input_image = Resize((height, w))(input_image)
+
+    transformer = Compose([
+        # Resize(width),
+        CenterCrop((height, width)),
+    ])
+
+    input_image = transformer(input_image)
+    return input_image
viewcrafter.py CHANGED
@@ -401,14 +401,10 @@ class ViewCrafter:
         self.opts.center_scale = float(i2v_center_scale)
         i2v_d_phi,i2v_d_theta,i2v_d_r = [i for i in i2v_pose.split(';')]
         self.gradio_traj = [float(i) for i in i2v_d_phi.split()],[float(i) for i in i2v_d_theta.split()],[float(i) for i in i2v_d_r.split()]
-        transform = transforms.Compose([
-            transforms.Resize((576,1024)),
-            # transforms.CenterCrop((576,1024)),
-        ])
         torch.cuda.empty_cache()
         img_tensor = torch.from_numpy(i2v_input_image).permute(2, 0, 1).unsqueeze(0).float().to(self.device)
         img_tensor = (img_tensor / 255. - 0.5) * 2
-        image_tensor_resized = transform(img_tensor) #1,3,h,w
+        image_tensor_resized = center_crop_image(img_tensor) #1,3,h,w
         images = get_input_dict(image_tensor_resized,idx = 0,dtype = torch.float32)
         images = [images, copy.deepcopy(images)]
         images[1]['idx'] = 1
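The call-site change swaps the old fixed transforms.Resize((576,1024)), which stretches any input whose aspect ratio is not 16:9, for the aspect-preserving resize-then-crop helper added above. A small before/after sketch contrasting the two transforms, assuming the same illustrative 800x600 input (this snippet is not part of the commit, and it assumes viewcrafter.py can already see center_crop_image from utils.pvd_utils, since the diff shows no import change):

import torch
from torchvision import transforms
from utils.pvd_utils import center_crop_image

x = torch.rand(1, 3, 600, 800)  # hypothetical 4:3 input

# Old path: direct resize to 576x1024 distorts 4:3 content.
old = transforms.Resize((576, 1024))(x)

# New path: resize preserving aspect ratio, then center-crop to 576x1024.
new = center_crop_image(x)

# Both end up torch.Size([1, 3, 576, 1024]), but only `new` keeps the
# original aspect ratio, discarding the cropped borders instead of stretching.
print(old.shape, new.shape)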