|
import torch |
|
from tqdm import tqdm |
|
import pytorch3d.transforms as tra3d |
|
|
|
from StructDiffusion.diffusion.noise_schedule import extract |
|
from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses |
|
from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs, move_pc_and_create_scene_new |
|
|
|
class Sampler:
    """Generate pose samples by running a trained diffusion model in reverse.

    The model is restored from a checkpoint. Its ``model`` attribute (the
    network called to predict noise) is moved to ``device`` and switched to
    eval mode.
    """

    def __init__(self, model_class, checkpoint_path, device, debug=False):
        """Restore the diffusion model and prepare its backbone for inference.

        Args:
            model_class: Class exposing ``load_from_checkpoint``; the loaded
                object must provide ``model`` and ``noise_schedule``.
            checkpoint_path: Checkpoint handed to ``load_from_checkpoint``.
            device: Device the backbone is moved to and on which samples are
                created.
            debug: Stored on the instance as ``self.debug``.
        """
        self.debug = debug
        self.device = device

        self.model = model_class.load_from_checkpoint(checkpoint_path)
        self.backbone = self.model.model
        self.backbone.to(device)
        self.backbone.eval()

    def sample(self, batch, num_poses):
        """Run the full reverse-diffusion chain starting from Gaussian noise.

        Args:
            batch: Mapping with the keys ``"pcs"``, ``"sentence"``,
                ``"type_index"``, ``"position_index"`` and ``"pad_mask"``.
                The batch size is read from ``batch["pcs"].shape[0]``.
                NOTE(review): the values are passed to the backbone as-is,
                so they are presumably already on ``self.device`` - confirm
                against the caller.
            num_poses: Number of pose tokens per batch element.

        Returns:
            A list with one ``(B, num_poses, 9)`` tensor per time step.
            Index 0 holds the output of the last step (``t_index == 0``);
            the final entry holds the output of the first denoising step.
        """
        schedule = self.model.noise_schedule
        num_steps = schedule.timesteps
        batch_size = batch["pcs"].shape[0]

        # Start the chain from pure Gaussian noise.
        x_t = torch.randn((batch_size, num_poses, 9), device=self.device)

        trajectory = []
        progress = tqdm(reversed(range(0, num_steps)),
                        desc='sampling loop time step', total=num_steps)
        for step in progress:
            t = torch.full((batch_size,), step, device=self.device, dtype=torch.long)

            # Per-step schedule coefficients, shaped by `extract` for x_t.
            beta = extract(schedule.betas, t, x_t.shape)
            sqrt_one_minus_alpha_bar = extract(schedule.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
            sqrt_recip_alpha = extract(schedule.sqrt_recip_alphas, t, x_t.shape)

            with torch.no_grad():
                eps = self.backbone.forward(
                    t,
                    batch["pcs"],
                    batch["sentence"],
                    x_t,
                    batch["type_index"],
                    batch["position_index"],
                    batch["pad_mask"],
                )

            # Mean of the reverse step computed from the predicted noise.
            mean = sqrt_recip_alpha * (x_t - beta * eps / sqrt_one_minus_alpha_bar)

            if step != 0:
                variance = extract(schedule.posterior_variance, t, x_t.shape)
                fresh_noise = torch.randn_like(x_t)
                x_t = mean + torch.sqrt(variance) * fresh_noise
            else:
                # The last step adds no noise.
                x_t = mean

            trajectory.append(x_t)

        # Flip so that index 0 is the final, fully denoised sample.
        return trajectory[::-1]
|
|
|
class SamplerV2:
    """Diffusion pose sampler that ranks its samples with a collision model.

    Poses are drawn by running the diffusion model in reverse. Each sample is
    then scored by a pairwise collision model, and only the highest-scoring
    (least colliding) samples are returned.
    """

    def __init__(self, diffusion_model_class, diffusion_checkpoint_path,
                 collision_model_class, collision_checkpoint_path,
                 device, debug=False):
        """Restore both models and prepare their backbones for inference.

        Args:
            diffusion_model_class: Class exposing ``load_from_checkpoint``;
                the loaded object must provide ``model`` and
                ``noise_schedule``.
            diffusion_checkpoint_path: Checkpoint for the diffusion model.
            collision_model_class: Class exposing ``load_from_checkpoint``;
                the loaded object must provide ``model`` and ``data_cfg``.
            collision_checkpoint_path: Checkpoint for the collision model.
            device: Device both backbones are moved to.
            debug: Stored on the instance as ``self.debug``.
        """
        self.debug = debug
        self.device = device

        self.diffusion_model = diffusion_model_class.load_from_checkpoint(diffusion_checkpoint_path)
        self.diffusion_backbone = self.diffusion_model.model
        self.diffusion_backbone.to(device)
        self.diffusion_backbone.eval()

        self.collision_model = collision_model_class.load_from_checkpoint(collision_checkpoint_path)
        self.collision_backbone = self.collision_model.model
        self.collision_backbone.to(device)
        self.collision_backbone.eval()

    def _run_reverse_diffusion(self, batch, num_poses):
        """Run the reverse-diffusion chain; index 0 of the result is the final sample."""
        noise_schedule = self.diffusion_model.noise_schedule

        B = batch["pcs"].shape[0]

        # Start the chain from pure Gaussian noise.
        x_noisy = torch.randn((B, num_poses, 9), device=self.device)

        xs = []
        for t_index in tqdm(reversed(range(0, noise_schedule.timesteps)),
                            desc='sampling loop time step', total=noise_schedule.timesteps):

            t = torch.full((B,), t_index, device=self.device, dtype=torch.long)

            # Per-step schedule coefficients, shaped by `extract` for x_noisy.
            betas_t = extract(noise_schedule.betas, t, x_noisy.shape)
            sqrt_one_minus_alphas_cumprod_t = extract(noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_noisy.shape)
            sqrt_recip_alphas_t = extract(noise_schedule.sqrt_recip_alphas, t, x_noisy.shape)

            pcs = batch["pcs"]
            sentence = batch["sentence"]
            type_index = batch["type_index"]
            position_index = batch["position_index"]
            pad_mask = batch["pad_mask"]

            with torch.no_grad():
                predicted_noise = self.diffusion_backbone.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask)

            # Mean of the reverse step computed from the predicted noise.
            model_mean = sqrt_recip_alphas_t * (x_noisy - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)
            if t_index == 0:
                # The last step adds no noise.
                x_noisy = model_mean
            else:
                posterior_variance_t = extract(noise_schedule.posterior_variance, t, x_noisy.shape)
                noise = torch.randn_like(x_noisy)
                x_noisy = model_mean + torch.sqrt(posterior_variance_t) * noise

            xs.append(x_noisy)

        # Flip so that xs[0] is the final, fully denoised sample.
        return list(reversed(xs))

    def sample(self, batch, num_poses, num_elite=10):
        """Sample poses and keep those with the least predicted collision.

        Args:
            batch: Mapping with the keys ``"pcs"``, ``"sentence"``,
                ``"type_index"``, ``"position_index"`` and ``"pad_mask"``.
                Only the point clouds and pad mask of batch element 0 are
                used for collision scoring.
                NOTE(review): this presumably relies on every batch element
                describing the same scene - confirm against the caller.
            num_poses: Number of pose tokens per batch element.
            num_elite: Maximum number of top-scoring samples to return.
                It is clamped to the batch size, so small batches no longer
                fail when reshaping the elite tensors.

        Returns:
            Tuple ``(struct_pose, pc_poses_in_struct)`` with shapes
            ``(E, 1, 4, 4)`` and ``(E, N, 4, 4)``, where
            ``E = min(num_elite, batch size)`` and ``N`` is the number of
            object point clouds. Samples are ordered from highest to lowest
            no-intersection score.

        Raises:
            ValueError: If ``num_elite`` is smaller than 1.
        """
        if num_elite < 1:
            raise ValueError("num_elite must be >= 1, got {}".format(num_elite))

        xs = self._run_reverse_diffusion(batch, num_poses)

        B = batch["pcs"].shape[0]

        # assumes struct_pose is (B, 1, 4, 4) and pc_poses_in_struct is
        # (B, N, 4, 4), as the reshapes below imply - TODO confirm against
        # get_struct_objs_poses.
        struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])

        # One candidate arrangement per batch element.
        S = B
        # Fix: never ask for more elites than there are samples; otherwise
        # the reshape(num_elite * N, ...) calls below raise.
        num_elite = min(num_elite, S)

        # xyz coordinates of the first batch element's object point clouds.
        obj_xyzs = batch["pcs"][0][:, :, :3]
        print("obj_xyzs shape", obj_xyzs.shape)

        # When the backbone uses a virtual structure frame, one pose token
        # belongs to that frame rather than to an object.
        num_target_objs = num_poses
        if self.diffusion_backbone.use_virtual_structure_frame:
            num_target_objs -= 1
        object_pad_mask = batch["pad_mask"][0][-num_target_objs:].unsqueeze(0)
        # Inverted pad mask: 1 where the pad mask is 0.
        target_object_inds = 1 - object_pad_mask
        print("target_object_inds shape", target_object_inds.shape)
        print("target_object_inds", target_object_inds)

        N, P, _ = obj_xyzs.shape
        print("S, N, P: {}, {}, {}".format(S, N, P))

        # Replicate the structure pose once per object and flatten.
        struct_pose = struct_pose.repeat(1, N, 1, 1)
        struct_pose = struct_pose.reshape(S * N, 4, 4)

        # Current pose of every object: identity rotation, translation at the
        # centroid of its point cloud.
        new_obj_xyzs = obj_xyzs.repeat(S, 1, 1, 1)
        current_pc_pose = torch.eye(4).repeat(S, N, 1, 1).to(self.device)
        current_pc_pose[:, :, :3, 3] = torch.mean(new_obj_xyzs, dim=2)
        current_pc_pose = current_pc_pose.reshape(S * N, 4, 4)

        # Pack each sampled object pose as (x, y, z, euler XYZ angles).
        obj_params = torch.zeros((S, N, 6)).to(self.device)
        obj_params[:, :, :3] = pc_poses_in_struct[:, :, :3, 3]
        obj_params[:, :, 3:] = tra3d.matrix_to_euler_angles(pc_poses_in_struct[:, :, :3, :3], "XYZ")

        collision_cfg = self.collision_model.data_cfg
        no_intersection_scores = torch.zeros(S).to(self.device)
        # Score the candidates in chunks of at most B samples.
        for cur_batch_idxs_start in range(0, S, B):
            cur_batch_idxs_end = min(cur_batch_idxs_start + B, S)
            cur_batch_size = cur_batch_idxs_end - cur_batch_idxs_start

            batch_obj_params = obj_params[cur_batch_idxs_start: cur_batch_idxs_end]
            batch_struct_pose = struct_pose[cur_batch_idxs_start * N: cur_batch_idxs_end * N]
            batch_current_pc_pose = current_pc_pose[cur_batch_idxs_start * N:cur_batch_idxs_end * N]

            # Only the pairwise point clouds (last return value) are needed.
            _, _, _, _, obj_pair_xyzs = \
                move_pc_and_create_scene_new(obj_xyzs, batch_obj_params, batch_struct_pose, batch_current_pc_pose,
                                             target_object_inds, self.device,
                                             return_scene_pts=False,
                                             return_scene_pts_and_pc_idxs=False,
                                             num_scene_pts=False,
                                             normalize_pc=False,
                                             return_pair_pc=True,
                                             num_pair_pc_pts=collision_cfg.num_scene_pts,
                                             normalize_pair_pc=collision_cfg.normalize_pc)

            with torch.no_grad():
                _, num_comb, num_pair_pc_pts, _ = obj_pair_xyzs.shape

                collision_logits = self.collision_backbone.forward(obj_pair_xyzs.reshape(cur_batch_size * num_comb, num_pair_pc_pts, -1))
                collision_scores = self.collision_backbone.convert_logits(collision_logits).reshape(cur_batch_size, num_comb)

            # Higher score means less predicted collision across object pairs.
            no_intersection_scores[cur_batch_idxs_start:cur_batch_idxs_end] = 1 - torch.mean(collision_scores, dim=1)
            print("no intersection scores", no_intersection_scores)

        scores = no_intersection_scores
        # Indices of the best samples, highest score first.
        sort_idx = torch.argsort(scores).flip(dims=[0])[:num_elite]
        elite_obj_params = obj_params[sort_idx]
        elite_struct_poses = struct_pose.reshape(S, N, 4, 4)[sort_idx]
        elite_struct_poses = elite_struct_poses.reshape(num_elite * N, 4, 4)
        elite_scores = scores[sort_idx]
        print("elite scores:", elite_scores)

        # Convert the elite (xyz, euler) parameters back to 4x4 transforms.
        elite_obj_params = elite_obj_params.reshape(num_elite * N, -1)
        pc_poses_in_struct = torch.eye(4).repeat(num_elite * N, 1, 1).to(self.device)
        pc_poses_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(elite_obj_params[:, 3:], "XYZ")
        pc_poses_in_struct[:, :3, 3] = elite_obj_params[:, :3]
        pc_poses_in_struct = pc_poses_in_struct.reshape(num_elite, N, 4, 4)

        # The structure pose was replicated per object; keep a single copy.
        struct_pose = elite_struct_poses.reshape(num_elite, N, 4, 4)[:, 0].unsqueeze(1)

        return struct_pose, pc_poses_in_struct