import sys sys.path.append('./extern/dust3r') from dust3r.inference import inference, load_model from dust3r.utils.image import load_images from dust3r.image_pairs import make_pairs from dust3r.cloud_opt import global_aligner, GlobalAlignerMode from dust3r.utils.device import to_numpy import trimesh import torch import numpy as np import torchvision import os import copy import cv2 from PIL import Image import pytorch3d print(pytorch3d.__version__) from pytorch3d.structures import Pointclouds from torchvision.utils import save_image import torch.nn.functional as F import torchvision.transforms as transforms from PIL import Image from utils.pvd_utils import * from omegaconf import OmegaConf from pytorch_lightning import seed_everything from utils.diffusion_utils import instantiate_from_config,load_model_checkpoint,image_guided_synthesis from pathlib import Path from torchvision.utils import save_image class ViewCrafter: def __init__(self, opts, gradio = False): self.opts = opts self.device = opts.device self.setup_dust3r() self.setup_diffusion() #fixme # initialize ref images, pcd if not gradio: self.images, self.img_ori = self.load_initial_images(image_dir=self.opts.image_dir) self.run_dust3r(input_images=self.images) def run_dust3r(self, input_images,clean_pc = False): pairs = make_pairs(input_images, scene_graph='complete', prefilter=None, symmetrize=True) output = inference(pairs, self.dust3r, self.device, batch_size=self.opts.batch_size) mode = GlobalAlignerMode.PointCloudOptimizer #if len(self.images) > 2 else GlobalAlignerMode.PairViewer scene = global_aligner(output, device=self.device, mode=mode) if mode == GlobalAlignerMode.PointCloudOptimizer: loss = scene.compute_global_alignment(init='mst', niter=self.opts.niter, schedule=self.opts.schedule, lr=self.opts.lr) if clean_pc: self.scene = scene.clean_pointcloud() else: self.scene = scene def render_pcd(self,pts3d,imgs,masks,views,renderer,device): imgs = to_numpy(imgs) pts3d = to_numpy(pts3d) if masks == None: pts = torch.from_numpy(np.concatenate([p for p in pts3d])).view(-1, 3).to(device) col = torch.from_numpy(np.concatenate([p for p in imgs])).view(-1, 3).to(device) else: # masks = to_numpy(masks) pts = torch.from_numpy(np.concatenate([p[m] for p, m in zip(pts3d, masks)])).to(device) col = torch.from_numpy(np.concatenate([p[m] for p, m in zip(imgs, masks)])).to(device) color_mask = torch.ones(col.shape).to(device) # point_cloud_mask = Pointclouds(points=[pts],features=[color_mask]).extend(views) point_cloud = Pointclouds(points=[pts], features=[col]).extend(views) images = renderer(point_cloud) # view_masks = renderer(point_cloud_mask) return images, None def run_render(self, pcd, imgs,masks, H, W, camera_traj,num_views,use_cpu=False): if use_cpu: device = torch.device("cpu") else: device = self.device render_setup = setup_renderer(camera_traj, image_size=(H,W)) renderer = render_setup['renderer'] render_results, viewmask = self.render_pcd(pcd, imgs, masks, num_views,renderer,device) return render_results, viewmask def run_diffusion(self, renderings): prompts = [self.opts.prompt] videos = (renderings * 2. - 1.).permute(3,0,1,2).unsqueeze(0).to(self.device) condition_index = [0] with torch.no_grad(), torch.cuda.amp.autocast(): # [1,1,c,t,h,w] batch_samples = image_guided_synthesis(self.diffusion, prompts, videos, self.noise_shape, self.opts.n_samples, self.opts.ddim_steps, self.opts.ddim_eta, \ self.opts.unconditional_guidance_scale, self.opts.cfg_img, self.opts.frame_stride, self.opts.text_input, self.opts.multiple_cond_cfg, self.opts.timestep_spacing, self.opts.guidance_rescale, condition_index) # save_results_seperate(batch_samples[0], self.opts.save_dir, fps=8) # torch.Size([1, 3, 25, 576, 1024]) [-1,1] return torch.clamp(batch_samples[0][0].permute(1,2,3,0), -1., 1.) def nvs_single_view(self, gradio=False): # 最后一个view为 0 pose c2ws = self.scene.get_im_poses().detach()[1:] principal_points = self.scene.get_principal_points().detach()[1:] #cx cy focals = self.scene.get_focals().detach()[1:] shape = self.images[0]['true_shape'] H, W = int(shape[0][0]), int(shape[0][1]) pcd = [i.detach() for i in self.scene.get_pts3d(clip_thred=self.opts.dpt_trd)] # a list of points of size whc depth = [i.detach() for i in self.scene.get_depthmaps()] depth_avg = depth[-1][H//2,W//2] #以图像中心处的depth(z)为球心旋转 radius = depth_avg*self.opts.center_scale #缩放调整 ## change coordinate c2ws,pcd = world_point_to_obj(poses=c2ws, points=torch.stack(pcd), k=-1, r=radius, elevation=self.opts.elevation, device=self.device) imgs = np.array(self.scene.imgs) masks = None if self.opts.mode == 'single_view_nbv': ## 输入candidate->渲染mask->最大mask对应的pose作为nbv ## nbv模式下self.opts.d_theta[0], self.opts.d_phi[0]代表search space中的网格theta, phi之间的间距; self.opts.d_phi[0]的符号代表方向,分为左右两个方向 ## FIXME hard coded candidate view数量, 以left为例,第一次迭代从[左,左上]中选取, 从第二次开始可以从[左,左上,左下]中选取 num_candidates = 2 candidate_poses,thetas,phis = generate_candidate_poses(c2ws, H, W, focals, principal_points, self.opts.d_theta[0], self.opts.d_phi[0],num_candidates, self.device) _, viewmask = self.run_render([pcd[-1]], [imgs[-1]],masks, H, W, candidate_poses,num_candidates,use_cpu=False) nbv_id = torch.argmin(viewmask.sum(dim=[1,2,3])).item() save_image( viewmask.permute(0,3,1,2), os.path.join(self.opts.save_dir,f"candidate_mask0_nbv{nbv_id}.png"), normalize=True, value_range=(0, 1)) theta_nbv = thetas[nbv_id] phi_nbv = phis[nbv_id] # generate camera trajectory from T_curr to T_nbv camera_traj,num_views = generate_traj_specified(c2ws, H, W, focals, principal_points, theta_nbv, phi_nbv, self.opts.d_r[0],self.opts.video_length, self.device) # 重置elevation self.opts.elevation -= theta_nbv elif self.opts.mode == 'single_view_target': camera_traj,num_views = generate_traj_specified(c2ws, H, W, focals, principal_points, self.opts.d_theta[0], self.opts.d_phi[0], self.opts.d_r[0],self.opts.video_length, self.device) elif self.opts.mode == 'single_view_txt': if not gradio: with open(self.opts.traj_txt, 'r') as file: lines = file.readlines() phi = [float(i) for i in lines[0].split()] theta = [float(i) for i in lines[1].split()] r = [float(i) for i in lines[2].split()] else: phi, theta, r = self.gradio_traj # device = torch.device("cpu") device = self.device camera_traj,num_views = generate_traj_txt(c2ws, H, W, focals, principal_points, phi, theta, r,self.opts.video_length, device,viz_traj=True, save_dir = self.opts.save_dir) # camera_traj,num_views = generate_traj_txt(c2ws, H, W, focals, principal_points, phi, theta, r,self.opts.video_length, self.device,viz_traj=True, save_dir = self.opts.save_dir) else: raise KeyError(f"Invalid Mode: {self.opts.mode}") render_results, viewmask = self.run_render([pcd[-1]], [imgs[-1]],masks, H, W, camera_traj,num_views,use_cpu=False) render_results = render_results.to(self.device) render_results = F.interpolate(render_results.permute(0,3,1,2), size=(576, 1024), mode='bilinear', align_corners=False).permute(0,2,3,1) render_results[0] = self.img_ori if self.opts.mode == 'single_view_txt': if phi[-1]==0. and theta[-1]==0. and r[-1]==0.: render_results[-1] = self.img_ori # torch.Size([25, 576, 1024, 3]), [0,1] # save_pointcloud_with_normals([imgs[-1]], [pcd[-1]], msk=None, save_path=os.path.join(self.opts.save_dir,'pcd0.ply') , mask_pc=False, reduce_pc=False) diffusion_results = self.run_diffusion(render_results) save_video((diffusion_results + 1.0) / 2.0, os.path.join(self.opts.save_dir, 'diffusion0.mp4')) return render_results def nvs_sparse_view(self,iter): c2ws = self.scene.get_im_poses().detach() principal_points = self.scene.get_principal_points().detach() focals = self.scene.get_focals().detach() shape = self.images[0]['true_shape'] H, W = int(shape[0][0]), int(shape[0][1]) pcd = [i.detach() for i in self.scene.get_pts3d(clip_thred=self.opts.dpt_trd)] # a list of points of size whc depth = [i.detach() for i in self.scene.get_depthmaps()] depth_avg = depth[0][H//2,W//2] #以ref图像中心处的depth(z)为球心旋转 radius = depth_avg*self.opts.center_scale #缩放调整 ## masks for cleaner point cloud self.scene.min_conf_thr = float(self.scene.conf_trf(torch.tensor(self.opts.min_conf_thr))) masks = self.scene.get_masks() depth = self.scene.get_depthmaps() bgs_mask = [dpt > self.opts.bg_trd*(torch.max(dpt[40:-40,:])+torch.min(dpt[40:-40,:])) for dpt in depth] masks_new = [m+mb for m, mb in zip(masks,bgs_mask)] masks = to_numpy(masks_new) ## render, 从c2ws[0]即ref image对应的相机开始 imgs = np.array(self.scene.imgs) if self.opts.mode == 'single_view_ref_iterative': c2ws,pcd = world_point_to_obj(poses=c2ws, points=torch.stack(pcd), k=0, r=radius, elevation=self.opts.elevation, device=self.device) camera_traj,num_views = generate_traj_specified(c2ws[0:1], H, W, focals[0:1], principal_points[0:1], self.opts.d_theta[iter], self.opts.d_phi[iter], self.opts.d_r[iter],self.opts.video_length, self.device) render_results, viewmask = self.run_render(pcd, imgs,masks, H, W, camera_traj,num_views) render_results = F.interpolate(render_results.permute(0,3,1,2), size=(576, 1024), mode='bilinear', align_corners=False).permute(0,2,3,1) render_results[0] = self.img_ori elif self.opts.mode == 'single_view_1drc_iterative': self.opts.elevation -= self.opts.d_theta[iter-1] c2ws,pcd = world_point_to_obj(poses=c2ws, points=torch.stack(pcd), k=-1, r=radius, elevation=self.opts.elevation, device=self.device) camera_traj,num_views = generate_traj_specified(c2ws[-1:], H, W, focals[-1:], principal_points[-1:], self.opts.d_theta[iter], self.opts.d_phi[iter], self.opts.d_r[iter],self.opts.video_length, self.device) render_results, viewmask = self.run_render(pcd, imgs,masks, H, W, camera_traj,num_views) render_results = F.interpolate(render_results.permute(0,3,1,2), size=(576, 1024), mode='bilinear', align_corners=False).permute(0,2,3,1) render_results[0] = (self.images[-1]['img_ori'].squeeze(0).permute(1,2,0)+1.)/2. elif self.opts.mode == 'single_view_nbv': c2ws,pcd = world_point_to_obj(poses=c2ws, points=torch.stack(pcd), k=-1, r=radius, elevation=self.opts.elevation, device=self.device) ## 输入candidate->渲染mask->最大mask对应的pose作为nbv ## nbv模式下self.opts.d_theta[0], self.opts.d_phi[0]代表search space中的网格theta, phi之间的间距; self.opts.d_phi[0]的符号代表方向,分为左右两个方向 ## FIXME hard coded candidate view数量, 以left为例,第一次迭代从[左,左上]中选取, 从第二次开始可以从[左,左上,左下]中选取 num_candidates = 3 candidate_poses,thetas,phis = generate_candidate_poses(c2ws[-1:], H, W, focals[-1:], principal_points[-1:], self.opts.d_theta[0], self.opts.d_phi[0], num_candidates, self.device) _, viewmask = self.run_render(pcd, imgs,masks, H, W, candidate_poses,num_candidates) nbv_id = torch.argmin(viewmask.sum(dim=[1,2,3])).item() save_image(viewmask.permute(0,3,1,2), os.path.join(self.opts.save_dir,f"candidate_mask{iter}_nbv{nbv_id}.png"), normalize=True, value_range=(0, 1)) theta_nbv = thetas[nbv_id] phi_nbv = phis[nbv_id] # generate camera trajectory from T_curr to T_nbv camera_traj,num_views = generate_traj_specified(c2ws[-1:], H, W, focals[-1:], principal_points[-1:], theta_nbv, phi_nbv, self.opts.d_r[0],self.opts.video_length, self.device) # 重置elevation self.opts.elevation -= theta_nbv render_results, viewmask = self.run_render(pcd, imgs,masks, H, W, camera_traj,num_views) render_results = F.interpolate(render_results.permute(0,3,1,2), size=(576, 1024), mode='bilinear', align_corners=False).permute(0,2,3,1) render_results[0] = (self.images[-1]['img_ori'].squeeze(0).permute(1,2,0)+1.)/2. else: raise KeyError(f"Invalid Mode: {self.opts.mode}") save_video(render_results, os.path.join(self.opts.save_dir, f'render{iter}.mp4')) save_pointcloud_with_normals(imgs, pcd, msk=masks, save_path=os.path.join(self.opts.save_dir, f'pcd{iter}.ply') , mask_pc=True, reduce_pc=False) diffusion_results = self.run_diffusion(render_results) save_video((diffusion_results + 1.0) / 2.0, os.path.join(self.opts.save_dir, f'diffusion{iter}.mp4')) # torch.Size([25, 576, 1024, 3]) return diffusion_results def nvs_single_view_ref_iterative(self): all_results = [] sample_rate = 6 idx = 1 #初始包含1张ref image for itr in range(0, len(self.opts.d_phi)): if itr == 0: self.images = [self.images[0]] #去掉后一份copy diffusion_results_itr = self.nvs_single_view() # diffusion_results_itr = torch.randn([25, 576, 1024, 3]).to(self.device) diffusion_results_itr = diffusion_results_itr.permute(0,3,1,2) all_results.append(diffusion_results_itr) else: for i in range(0+sample_rate, diffusion_results_itr.shape[0], sample_rate): self.images.append(get_input_dict(diffusion_results_itr[i:i+1,...],idx,dtype = torch.float32)) idx += 1 self.run_dust3r(input_images=self.images, clean_pc=True) diffusion_results_itr = self.nvs_sparse_view(itr) # diffusion_results_itr = torch.randn([25, 576, 1024, 3]).to(self.device) diffusion_results_itr = diffusion_results_itr.permute(0,3,1,2) all_results.append(diffusion_results_itr) return all_results def nvs_single_view_1drc_iterative(self): all_results = [] sample_rate = 6 idx = 1 #初始包含1张ref image for itr in range(0, len(self.opts.d_phi)): if itr == 0: self.images = [self.images[0]] #去掉后一份copy diffusion_results_itr = self.nvs_single_view() # diffusion_results_itr = torch.randn([25, 576, 1024, 3]).to(self.device) diffusion_results_itr = diffusion_results_itr.permute(0,3,1,2) all_results.append(diffusion_results_itr) else: for i in range(0+sample_rate, diffusion_results_itr.shape[0], sample_rate): self.images.append(get_input_dict(diffusion_results_itr[i:i+1,...],idx,dtype = torch.float32)) idx += 1 self.run_dust3r(input_images=self.images, clean_pc=True) diffusion_results_itr = self.nvs_sparse_view(itr) # diffusion_results_itr = torch.randn([25, 576, 1024, 3]).to(self.device) diffusion_results_itr = diffusion_results_itr.permute(0,3,1,2) all_results.append(diffusion_results_itr) return all_results def nvs_single_view_nbv(self): # lef and right # d_theta and a_phi 是搜索空间的顶点间隔 all_results = [] ## FIXME: hard coded sample_rate = 6 max_itr = 3 idx = 1 #初始包含1张ref image for itr in range(0, max_itr): if itr == 0: self.images = [self.images[0]] #去掉后一份copy diffusion_results_itr = self.nvs_single_view() # diffusion_results_itr = torch.randn([25, 576, 1024, 3]).to(self.device) diffusion_results_itr = diffusion_results_itr.permute(0,3,1,2) all_results.append(diffusion_results_itr) else: for i in range(0+sample_rate, diffusion_results_itr.shape[0], sample_rate): self.images.append(get_input_dict(diffusion_results_itr[i:i+1,...],idx,dtype = torch.float32)) idx += 1 self.run_dust3r(input_images=self.images, clean_pc=True) diffusion_results_itr = self.nvs_sparse_view(itr) # diffusion_results_itr = torch.randn([25, 576, 1024, 3]).to(self.device) diffusion_results_itr = diffusion_results_itr.permute(0,3,1,2) all_results.append(diffusion_results_itr) return all_results def setup_diffusion(self): seed_everything(self.opts.seed) config = OmegaConf.load(self.opts.config) model_config = config.pop("model", OmegaConf.create()) ## set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set" model_config['params']['unet_config']['params']['use_checkpoint'] = False model = instantiate_from_config(model_config) model = model.to(self.device) model.cond_stage_model.device = self.device model.perframe_ae = self.opts.perframe_ae assert os.path.exists(self.opts.ckpt_path), "Error: checkpoint Not Found!" model = load_model_checkpoint(model, self.opts.ckpt_path) model.eval() self.diffusion = model h, w = self.opts.height // 8, self.opts.width // 8 channels = model.model.diffusion_model.out_channels n_frames = self.opts.video_length self.noise_shape = [self.opts.bs, channels, n_frames, h, w] def setup_dust3r(self): self.dust3r = load_model(self.opts.model_path, self.device) def load_initial_images(self, image_dir): ## load images ## dict_keys(['img', 'true_shape', 'idx', 'instance', 'img_ori']),张量形式 images = load_images([image_dir], size=512,force_1024 = True) img_ori = (images[0]['img_ori'].squeeze(0).permute(1,2,0)+1.)/2. # [576,1024,3] [0,1] # img_ori = Image.open(image_dir).convert('RGB') # transform = transforms.Compose([ # transforms.Resize((576, 1024)), # transforms.ToTensor(), # transforms.Normalize((0., 0., 0.), (1., 1., 1.)) # 归一化到[-1,1],如果要归一化到[0,1],请使用transforms.Normalize((0., 0., 0.), (1., 1., 1.)) # ]) # img_ori = transform(img_ori).permute(1,2,0).to(self.device) if len(images) == 1: images = [images[0], copy.deepcopy(images[0])] images[1]['idx'] = 1 return images, img_ori def run_traj(self,i2v_input_image, i2v_elevation, i2v_center_scale, i2v_pose): self.opts.elevation = float(i2v_elevation) self.opts.center_scale = float(i2v_center_scale) i2v_d_phi,i2v_d_theta,i2v_d_r = [i for i in i2v_pose.split(';')] self.gradio_traj = [float(i) for i in i2v_d_phi.split()],[float(i) for i in i2v_d_theta.split()],[float(i) for i in i2v_d_r.split()] transform = transforms.Compose([ transforms.Resize((576,1024)), # transforms.CenterCrop((576,1024)), ]) torch.cuda.empty_cache() img_tensor = torch.from_numpy(i2v_input_image).permute(2, 0, 1).unsqueeze(0).float().to(self.device) img_tensor = (img_tensor / 255. - 0.5) * 2 image_tensor_resized = transform(img_tensor) #1,3,h,w images = get_input_dict(image_tensor_resized,idx = 0,dtype = torch.float32) images = [images, copy.deepcopy(images)] images[1]['idx'] = 1 self.images = images self.img_ori = (image_tensor_resized.squeeze(0).permute(1,2,0) + 1.)/2. # self.images: torch.Size([1, 3, 288, 512]), [-1,1] # self.img_ori: torch.Size([576, 1024, 3]), [0,1] # self.images, self.img_ori = self.load_initial_images(image_dir=i2v_input_image) self.run_dust3r(input_images=self.images) render_results = self.nvs_single_view(gradio=True) save_video(render_results, os.path.join(self.opts.save_dir, 'render0.mp4')) traj_dir = os.path.join(self.opts.save_dir, "viz_traj.mp4") return traj_dir def run_gen(self,i2v_steps, i2v_seed): self.opts.ddim_steps = i2v_steps seed_everything(i2v_seed) render_dir = os.path.join(self.opts.save_dir, 'render0.mp4') video = imageio.get_reader(render_dir, 'ffmpeg') frames = [] for frame in video: frame = frame / 255.0 frames.append(frame) frames = np.array(frames) ##torch.Size([25, 576, 1024, 3]) render_results = torch.from_numpy(frames).to(self.device).half() gen_dir = os.path.join(self.opts.save_dir, "diffusion0.mp4") diffusion_results = self.run_diffusion(render_results) save_video((diffusion_results + 1.0) / 2.0, os.path.join(self.opts.save_dir, 'diffusion0.mp4')) return gen_dir def run_both(self,i2v_input_image, i2v_elevation, i2v_center_scale, i2v_pose,i2v_steps, i2v_seed): self.opts.ddim_steps = i2v_steps seed_everything(i2v_seed) self.opts.elevation = float(i2v_elevation) self.opts.center_scale = float(i2v_center_scale) i2v_d_phi,i2v_d_theta,i2v_d_r = [i for i in i2v_pose.split(';')] self.gradio_traj = [float(i) for i in i2v_d_phi.split()],[float(i) for i in i2v_d_theta.split()],[float(i) for i in i2v_d_r.split()] transform = transforms.Compose([ transforms.Resize((576,1024)), # transforms.CenterCrop((576,1024)), ]) torch.cuda.empty_cache() img_tensor = torch.from_numpy(i2v_input_image).permute(2, 0, 1).unsqueeze(0).float().to(self.device) img_tensor = (img_tensor / 255. - 0.5) * 2 image_tensor_resized = transform(img_tensor) #1,3,h,w images = get_input_dict(image_tensor_resized,idx = 0,dtype = torch.float32) images = [images, copy.deepcopy(images)] images[1]['idx'] = 1 self.images = images self.img_ori = (image_tensor_resized.squeeze(0).permute(1,2,0) + 1.)/2. # self.images: torch.Size([1, 3, 288, 512]), [-1,1] # self.img_ori: torch.Size([576, 1024, 3]), [0,1] # self.images, self.img_ori = self.load_initial_images(image_dir=i2v_input_image) self.run_dust3r(input_images=self.images) self.nvs_single_view(gradio=True) traj_dir = os.path.join(self.opts.save_dir, "viz_traj.mp4") gen_dir = os.path.join(self.opts.save_dir, "diffusion0.mp4") return gen_dir,traj_dir,