"""Diffusion sampler and transformer denoiser for scene/action-conditioned sequences.

``Sampler`` holds a DDPM-style noise schedule and implements the training loss
(``p_losses``) and the reverse sampling loop (``p_sample_loop``).  ``Unet`` is,
despite its name, a transformer encoder that predicts the noise added to a
sequence of frames whose channels pack xyz joint coordinates
(channel ``joint * 3 + axis``), conditioned on a diffusion timestep, a voxel
occupancy grid and an optional action label.
"""
import math
import pdb

import torch
from torch import nn
import torch.nn.functional as F
from vit_pytorch import ViT
from tqdm import tqdm

from utils import *


class Sampler:
    """Noise schedule plus training-loss and reverse-sampling helpers.

    ``set_dataset_and_model`` must be called before ``p_losses`` or
    ``p_sample_loop``; both read ``self.dataset`` and ``self.model``.
    """

    def __init__(self, device, mask_ind, emb_f, batch_size, seq_len, channel, fix_mode, timesteps, fixed_frame,
                 **kwargs):
        """Store the sampling configuration and precompute the noise schedule.

        Args:
            device: Device the occupancy query grid is moved to.
            mask_ind: Goal joint index used by ``set_fixed_points``; for occupancy
                queries ``-1`` falls back to joint 0.
            emb_f: Frame index whose target-joint position centres the occupancy grid.
            batch_size: Batch size of sampled tensors and of the query grid.
            seq_len: Number of frames in a sampled tensor.
            channel: Channels per frame in a sampled tensor.
            fix_mode: If true, ``p_sample_loop`` re-imposes the known frames and the
                goal on the sample after every step; otherwise ``p_sample`` imposes
                only the known frames on the predicted mean.
            timesteps: Number of diffusion steps.
            fixed_frame: Stored; not used by this class.
            **kwargs: Ignored.
        """
        self.device = device
        self.mask_ind = mask_ind
        self.emb_f = emb_f
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.channel = channel
        self.fix_mode = fix_mode
        self.timesteps = timesteps
        self.fixed_frame = fixed_frame
        self.get_scheduler()

    def set_dataset_and_model(self, dataset, model):
        """Attach the dataset (normalization / occupancy helpers) and the noise-prediction model.

        When ``dataset.load_scene`` is true, a query grid from
        ``dataset.create_meshgrid`` is built once and moved to ``self.device``.
        """
        self.dataset = dataset
        if dataset.load_scene:
            self.grid = dataset.create_meshgrid(batch_size=self.batch_size).to(self.device)
        self.model = model

    def get_scheduler(self):
        """Precompute the per-timestep coefficients used by ``q_sample`` and ``p_sample``."""
        # NOTE(review): linear_beta_schedule comes from ``utils`` (star import); assumed to
        # return a 1-D tensor of length ``self.timesteps``.
        betas = linear_beta_schedule(timesteps=self.timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # \bar{alpha}_{t-1}, with \bar{alpha}_{-1} := 1 so that the posterior variance at t=0 is 0.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
        # Coefficients of the forward process q(x_t | x_0).
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
        # Variance of the posterior q(x_{t-1} | x_t, x_0).
        self.posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.betas = betas

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to step ``t``: sqrt(abar_t) * x_start + sqrt(1 - abar_t) * noise.

        Args:
            x_start: Clean (normalized) sample.
            t: Per-sample timestep indices.
            noise: Noise to add; fresh standard-normal noise when ``None``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        # NOTE(review): ``extract`` comes from ``utils``; presumably gathers coefficient[t]
        # reshaped to broadcast against ``x_start.shape``.
        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def _query_occupancy(self, points, mat, obj_points, scene_flag):
        """Return the voxel occupancy around the target joint of normalized ``points`` as a float tensor.

        The result is ``dataset.get_occ_for_points`` reshaped to
        ``[-1, nb_voxels, nb_voxels, nb_voxels]`` and permuted ``(0, 2, 1, 3)``.
        """
        x_orig = transform_points(self.dataset.denormalize_torch(points), mat)
        mat_for_query = mat.clone()
        target_ind = self.mask_ind if self.mask_ind != -1 else 0
        # Centre the query grid on the target joint at frame ``emb_f`` ...
        mat_for_query[:, :3, 3] = x_orig[:, self.emb_f, target_ind * 3: target_ind * 3 + 3]
        # ... but zero the translation on axis 1 (presumably the vertical axis — confirm).
        mat_for_query[:, 1, 3] = 0
        query_points = transform_points(self.grid, mat_for_query)
        occ = self.dataset.get_occ_for_points(query_points, obj_points, scene_flag)
        nb_voxels = self.dataset.nb_voxels
        occ = occ.reshape(-1, nb_voxels, nb_voxels, nb_voxels).float()
        return occ.permute(0, 2, 1, 3)

    def p_losses(self, x_start, obj_points, mat, scene_flag, mask, t, action_label, noise=None, loss_type='huber'):
        """Noise-prediction training loss at timesteps ``t``.

        Args:
            x_start: Clean normalized sample, presumably ``[batch, seq_len, channel]``.
            obj_points: Object geometry forwarded to ``dataset.get_occ_for_points``.
            mat: Per-sample transforms used with ``transform_points``.
            scene_flag: Forwarded to ``dataset.get_occ_for_points``.
            mask: Boolean mask over ``x_start``; masked entries get zero noise and are
                excluded from the loss.
            t: Per-sample timestep indices.
            action_label: Forwarded to the model.
            noise: Noise to use; sampled when ``None``. Modified in place (masked entries zeroed).
            loss_type: ``'l1'``, ``'l2'`` or ``'huber'``.

        Returns:
            Scalar loss between the true and predicted noise over unmasked entries.

        Raises:
            NotImplementedError: For an unknown ``loss_type``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        noise[mask] = 0.
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

        if self.dataset.load_scene:
            # The occupancy grid is a conditioning input only; no gradients flow through it.
            with torch.no_grad():
                occ = self._query_occupancy(x_noisy, mat, obj_points, scene_flag)
        else:
            occ = None

        predicted_noise = self.model(x_noisy, occ, t, action_label, mask)

        mask_inv = torch.logical_not(mask)
        loss_fns = {'l1': F.l1_loss, 'l2': F.mse_loss, 'huber': F.smooth_l1_loss}
        if loss_type not in loss_fns:
            raise NotImplementedError(f"unsupported loss_type: {loss_type!r}")
        return loss_fns[loss_type](noise[mask_inv], predicted_noise[mask_inv])

    @torch.no_grad()
    def p_sample_loop(self, fixed_points, obj_locs, mat, scene, goal, action_label):
        """Run the full reverse diffusion from pure noise.

        The occupancy grid is computed once from the initial noise sample and
        reused for every step.

        Returns:
            ``(imgs, occs)``: ``imgs`` holds, after each of the ``timesteps`` steps,
            ``transform_points(dataset.denormalize_torch(points), mat)``; ``occs`` holds
            ``occ.cpu().numpy()`` per step, or is empty when scenes are not loaded.
        """
        device = next(self.model.parameters()).device
        shape = (self.batch_size, self.seq_len, self.channel)
        points = torch.randn(shape, device=device)
        if self.fix_mode:
            self.set_fixed_points(points, goal, fixed_points, mat, joint_id=self.mask_ind, fix_mode=True,
                                  fix_goal=True)
        imgs = []
        occs = []

        occ = self._query_occupancy(points, mat, obj_locs, scene) if self.dataset.load_scene else None

        for i in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps):
            t = torch.full((self.batch_size,), i, device=device, dtype=torch.long)
            points, occ = self.p_sample(self.model, points, fixed_points, goal, None, mat, occ, t, i,
                                        action_label, self.mask_ind, self.emb_f, self.fix_mode)
            if self.fix_mode:
                self.set_fixed_points(points, goal, fixed_points, mat, joint_id=self.mask_ind, fix_mode=True,
                                      fix_goal=True)
            points_orig = transform_points(self.dataset.denormalize_torch(points), mat)
            imgs.append(points_orig)
            if occ is not None:
                occs.append(occ.cpu().numpy())
        return imgs, occs

    @torch.no_grad()
    def p_sample(self, model, x, fixed_points, goal, obj_points, mat, occ, t, t_index, action_label, mask_ind,
                 emb_f, fix_mode, no_scene=False):
        """One reverse step: x_{t-1} = 1/sqrt(a_t) * (x_t - b_t / sqrt(1 - abar_t) * eps) + sigma_t * z.

        No noise is added when ``t_index == 0``. When ``fix_mode`` is false, the
        known frames (``fixed_points``) are written into the predicted mean.
        ``obj_points``, ``emb_f`` and ``no_scene`` are accepted but unused.

        Returns:
            ``(x_prev, occ)`` with ``occ`` passed through unchanged.
        """
        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)

        # Predicted mean from the model's noise estimate (DDPM eq. 11).
        predicted_noise = model(x, occ, t, action_label, mask=None)
        model_mean = sqrt_recip_alphas_t * (x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t)

        if not fix_mode:
            self.set_fixed_points(model_mean, goal, fixed_points, mat, joint_id=mask_ind, fix_mode=True,
                                  fix_goal=False)

        if t_index == 0:
            return model_mean, occ
        posterior_variance_t = extract(self.posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # DDPM Algorithm 2, line 4.
        return model_mean + torch.sqrt(posterior_variance_t) * noise, occ

    def set_fixed_points(self, img, goal, fixed_points, mat, joint_id=0, fix_mode=False, fix_goal=True):
        """Overwrite known values of ``img`` in place.

        Args:
            img: Normalized sample to modify.
            goal: Goal positions, indexed ``goal[:, k, axis]``; mapped with
                ``inverse(mat)`` and ``dataset.normalize_torch`` before use.
            fixed_points: Known leading frames copied into ``img[:, :k]``, or ``None``.
            mat: Per-sample transforms.
            joint_id: Joint whose channels receive the goal. For joint 0 only axes 0
                and 2 are set (axis 1 is left free). NOTE(review): ``-1`` indexes the
                last joint here, unlike the occupancy query's fallback to joint 0.
            fix_mode: Copy ``fixed_points`` when true.
            fix_goal: Write the goal into the last ``goal.shape[1]`` frames when true.
        """
        goal_len = goal.shape[1]
        goal = self.dataset.normalize_torch(transform_points(goal, torch.inverse(mat)))
        if fix_goal:
            img[:, -goal_len:, joint_id * 3] = goal[:, :, 0]
            if joint_id != 0:
                img[:, -goal_len:, joint_id * 3 + 1] = goal[:, :, 1]
            img[:, -goal_len:, joint_id * 3 + 2] = goal[:, :, 2]
        if fixed_points is not None and fix_mode:
            img[:, :fixed_points.shape[1], :] = fixed_points


class Unet(nn.Module):
    """Transformer encoder that predicts diffusion noise for a frame sequence.

    One conditioning token (timestep + scene embedding, plus the action embedding
    for ``'last_add_first_token'``) is prepended to the embedded frames; with
    ``'last_new_token'`` the action embedding is a second prepended token, and with
    ``'all_add_token'`` it is added to every frame token.
    """

    # Supported values of ``ac_type``.
    _AC_TYPES = ('all_add_token', 'last_add_first_token', 'last_new_token')

    def __init__(
            self,
            dim_model,
            num_heads,
            num_layers,
            dropout_p,
            dim_input,
            dim_output,
            nb_voxels=None,
            free_p=0.1,
            nb_actions=0,
            ac_type='',
            no_scene=False,
            no_action=False,
            **kwargs
    ):
        """Build the network.

        Args:
            dim_model: Token width.
            num_heads: Attention heads in the main encoder (the action encoder uses half).
            num_layers: Layers in the main encoder and in the action encoder.
            dropout_p: Dropout probability.
            dim_input: Channels per input frame.
            dim_output: Channels per output frame.
            nb_voxels: Occupancy grid side; required unless ``no_scene`` (the ViT takes
                ``nb_voxels`` channels of ``nb_voxels x nb_voxels`` images).
            free_p: Per-sample probability of zeroing the conditioning during training.
            nb_actions: Action feature width; 0 disables action conditioning.
            ac_type: One of ``'all_add_token'``, ``'last_add_first_token'``, ``'last_new_token'``.
            no_scene: Disable scene conditioning.
            no_action: Disable action conditioning.
            **kwargs: Ignored.
        """
        super().__init__()
        self.model_type = "Transformer"
        self.dim_model = dim_model
        self.nb_actions = nb_actions
        self.ac_type = ac_type
        self.no_scene = no_scene
        self.no_action = no_action

        # Construction order below is kept fixed so parameter init consumes RNG identically.
        if not no_scene:
            self.scene_embedding = ViT(
                image_size=nb_voxels,
                patch_size=nb_voxels // 4,
                channels=nb_voxels,
                num_classes=dim_model,
                dim=1024,
                depth=6,
                heads=16,
                mlp_dim=2048,
                dropout=0.1,
                emb_dropout=0.1
            )
        self.free_p = free_p
        self.positional_encoder = PositionalEncoding(dim_model=dim_model, dropout_p=dropout_p, max_len=5000)
        self.embedding_input = nn.Linear(dim_input, dim_model)
        # Not used by forward; kept (presumably for checkpoint/state_dict compatibility).
        self.embedding_output = nn.Linear(dim_output, dim_model)

        if not no_action and nb_actions > 0:
            if self.ac_type in ['last_add_first_token', 'last_new_token']:
                # Encodes an action sequence into a single mean-pooled token.
                self.embedding_action = ActionTransformerEncoder(
                    action_number=nb_actions,
                    dim_model=dim_model,
                    nhead=num_heads // 2,
                    num_layers=num_layers,
                    dim_feedforward=dim_model,
                    dropout_p=dropout_p,
                    activation="gelu",
                )
            elif self.ac_type in ['all_add_token']:
                self.embedding_action = nn.Sequential(
                    nn.Linear(nb_actions, dim_model),
                    nn.SiLU(inplace=False),
                    nn.Linear(dim_model, dim_model),
                )

        encoder_layer = nn.TransformerEncoderLayer(d_model=dim_model, nhead=num_heads, dim_feedforward=dim_model,
                                                   dropout=dropout_p, activation="gelu")
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.out = nn.Linear(dim_model, dim_output)
        self.embed_timestep = TimestepEmbedder(self.dim_model, self.positional_encoder)

    def forward(self, x, cond, timesteps, action, mask, no_action=None):
        """Predict noise for ``x``.

        Args:
            x: ``[batch, seq, dim_input]`` frames.
            cond: Occupancy grid for the ViT scene encoder (ignored when ``no_scene``).
            timesteps: ``[batch]`` integer diffusion steps.
            action: Input to the action embedder (ignored when actions are disabled).
            mask: Unused.
            no_action: If truthy, use a zero action embedding for this call (e.g. for an
                unconditional pass); ``None``/false keeps the configured behavior.

        Returns:
            ``[batch, seq, dim_output]`` predicted noise.

        Raises:
            NotImplementedError: If ``ac_type`` is not supported.
        """
        if self.ac_type not in self._AC_TYPES:
            raise NotImplementedError(f"unsupported ac_type: {self.ac_type!r}")

        t_emb = self.embed_timestep(timesteps)  # [b, 1, d]

        if self.no_scene:
            scene_emb = torch.zeros_like(t_emb)
        else:
            scene_emb = self.scene_embedding(cond).reshape(-1, 1, self.dim_model)  # [b, 1, d]

        if self.no_action or self.nb_actions == 0 or no_action:
            action_emb = torch.zeros_like(t_emb)
        else:
            action_emb = self.embedding_action(action)

        t_emb = t_emb.permute(1, 0, 2)  # [1, b, d]

        # Random condition dropout (used to train classifier-free guidance); training only,
        # so inference always sees the full conditioning.
        if self.training:
            free_ind = torch.rand(scene_emb.shape[0]).to(scene_emb.device) < self.free_p
            scene_emb[free_ind] = 0.
            if self.ac_type in ['last_add_first_token', 'last_new_token']:
                action_emb[free_ind] = 0.

        scene_emb = scene_emb.permute(1, 0, 2)
        action_emb = action_emb.permute(1, 0, 2)

        if self.ac_type == 'last_add_first_token':
            emb = t_emb + scene_emb + action_emb
        else:
            emb = t_emb + scene_emb

        x = x.permute(1, 0, 2)  # sequence-first: [seq, b, dim_input]
        x = self.embedding_input(x) * math.sqrt(self.dim_model)

        if self.ac_type == 'last_new_token':
            x = torch.cat((emb, action_emb, x), dim=0)
            n_prefix = 2
        else:
            x = torch.cat((emb, x), dim=0)
            n_prefix = 1
        if self.ac_type == 'all_add_token':
            # Add the action embedding to every frame token (not the conditioning token).
            x[1:] = x[1:] + action_emb

        x = self.positional_encoder(x)
        x = self.transformer(x)
        # Drop the prepended conditioning token(s).
        output = self.out(x)[n_prefix:]
        return output.permute(1, 0, 2)


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to a sequence-first tensor, followed by dropout.

    Adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html.
    The buffer ``pos_encoding`` has shape ``[max_len, 1, dim_model]``.
    """

    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)

        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).reshape(-1, 1)  # 0, 1, 2, ...
        # 1 / 10000^(2i / dim_model) for each even channel index 2i.
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model)
        # PE(pos, 2i) = sin(pos / 10000^(2i / dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i / dim_model)); an odd dim_model has one fewer odd channel.
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term[: dim_model // 2])

        # [max_len, 1, dim_model]: broadcasts over the batch axis of a [seq, batch, dim] input.
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding", pos_encoding)

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        """Add the encoding for the first ``token_embedding.size(0)`` positions, then apply dropout."""
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])


class TimestepEmbedder(nn.Module):
    """Embed integer timesteps: sinusoidal table lookup followed by Linear-SiLU-Linear."""

    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        # Shared module; its ``pos_encoding`` buffer is used as the lookup table.
        self.sequence_pos_encoder = sequence_pos_encoder

        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(inplace=False),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, timesteps):
        """Return ``[b, 1, latent_dim]`` embeddings for ``[b]`` integer ``timesteps``."""
        return self.time_embed(self.sequence_pos_encoder.pos_encoding[timesteps])


class ActionTransformerEncoder(nn.Module):
    """Encode a batch-first action sequence into one mean-pooled ``[b, 1, dim_model]`` token."""

    def __init__(self, action_number, dim_model, nhead, num_layers, dim_feedforward, dropout_p,
                 activation="gelu") -> None:
        super().__init__()
        self.positional_encoder = PositionalEncoding(dim_model=dim_model, dropout_p=dropout_p, max_len=5000)
        self.input_embedder = nn.Linear(action_number, dim_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim_model, nhead=nhead, dim_feedforward=dim_feedforward,
                                                   dropout=dropout_p, activation=activation)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        """Map ``[b, seq, action_number]`` to ``[b, 1, dim_model]`` (mean over the sequence)."""
        x = x.permute(1, 0, 2)  # sequence-first for nn.TransformerEncoder
        x = self.input_embedder(x)
        x = self.positional_encoder(x)
        x = self.transformer_encoder(x)
        x = x.permute(1, 0, 2)
        return torch.mean(x, dim=1, keepdim=True)