import numpy as np
import torch
import torch.nn as nn

from .diff_csdi import diff_CSDI


class CSDI_base(nn.Module):
    def __init__(self, target_dim, config, device):
        super().__init__()
        self.device = device
        self.target_dim = target_dim

        self.emb_time_dim = config["model"]["timeemb"]
        self.emb_feature_dim = config["model"]["featureemb"]
        self.is_unconditional = config["model"]["is_unconditional"]
        self.target_strategy = config["model"]["target_strategy"]

        self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim
        if self.is_unconditional == False:
            self.emb_total_dim += 1  # for conditional mask
        self.embed_layer = nn.Embedding(
            num_embeddings=self.target_dim, embedding_dim=self.emb_feature_dim
        )

        config_diff = config["diffusion"]
        config_diff["side_dim"] = self.emb_total_dim

        input_dim = 1 if self.is_unconditional == True else 2
        self.diffmodel = diff_CSDI(config_diff, input_dim)

        # parameters for diffusion models
        self.num_steps = config_diff["num_steps"]
        if config_diff["schedule"] == "quad":
            self.beta = np.linspace(
                config_diff["beta_start"] ** 0.5, config_diff["beta_end"] ** 0.5, self.num_steps
            ) ** 2
        elif config_diff["schedule"] == "linear":
            self.beta = np.linspace(
                config_diff["beta_start"], config_diff["beta_end"], self.num_steps
            )

        self.alpha_hat = 1 - self.beta
        self.alpha = np.cumprod(self.alpha_hat)
        self.alpha_torch = (
            torch.tensor(self.alpha).float().to(self.device).unsqueeze(1).unsqueeze(1)
        )

    def time_embedding(self, pos, d_model=128):
        # sinusoidal embedding of the observed timestamps
        pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(self.device)
        position = pos.unsqueeze(2)
        div_term = 1 / torch.pow(
            10000.0, torch.arange(0, d_model, 2).to(self.device) / d_model
        )
        pe[:, :, 0::2] = torch.sin(position * div_term)
        pe[:, :, 1::2] = torch.cos(position * div_term)
        return pe

    def get_randmask(self, observed_mask):
        # randomly hide a sampled ratio of the observed entries to use as targets
        rand_for_mask = torch.rand_like(observed_mask) * observed_mask
        rand_for_mask = rand_for_mask.reshape(len(rand_for_mask), -1)
        for i in range(len(observed_mask)):
            sample_ratio = np.random.rand()  # missing ratio
            num_observed = observed_mask[i].sum().item()
            num_masked = round(num_observed * sample_ratio)
            rand_for_mask[i][rand_for_mask[i].topk(num_masked).indices] = -1
        cond_mask = (rand_for_mask > 0).reshape(observed_mask.shape).float()
        return cond_mask

    def get_hist_mask(self, observed_mask, for_pattern_mask=None):
        # reuse the missing pattern of another sample in the batch as the target mask
        if for_pattern_mask is None:
            for_pattern_mask = observed_mask
        if self.target_strategy == "mix":
            rand_mask = self.get_randmask(observed_mask)

        cond_mask = observed_mask.clone()
        for i in range(len(cond_mask)):
            mask_choice = np.random.rand()
            if self.target_strategy == "mix" and mask_choice > 0.5:
                cond_mask[i] = rand_mask[i]
            else:  # draw another sample for histmask (i-1 corresponds to another sample)
                cond_mask[i] = cond_mask[i] * for_pattern_mask[i - 1]
        return cond_mask

    def get_test_pattern_mask(self, observed_mask, test_pattern_mask):
        return observed_mask * test_pattern_mask

    def get_side_info(self, observed_tp, cond_mask):
        B, K, L = cond_mask.shape

        time_embed = self.time_embedding(observed_tp, self.emb_time_dim)  # (B,L,emb)
        time_embed = time_embed.unsqueeze(2).expand(-1, -1, K, -1)
        feature_embed = self.embed_layer(
            torch.arange(self.target_dim).to(self.device)
        )  # (K,emb)
        feature_embed = feature_embed.unsqueeze(0).unsqueeze(0).expand(B, L, -1, -1)

        side_info = torch.cat([time_embed, feature_embed], dim=-1)  # (B,L,K,*)
        side_info = side_info.permute(0, 3, 2, 1)  # (B,*,K,L)

        if self.is_unconditional == False:
            side_mask = cond_mask.unsqueeze(1)  # (B,1,K,L)
            side_info = torch.cat([side_info, side_mask], dim=1)

        return side_info
    def calc_loss_valid(
        self, observed_data, cond_mask, observed_mask, side_info, is_train
    ):
        loss_sum = 0
        for t in range(self.num_steps):  # calculate loss for all t
            loss = self.calc_loss(
                observed_data, cond_mask, observed_mask, side_info, is_train, set_t=t
            )
            loss_sum += loss.detach()
        return loss_sum / self.num_steps

    def calc_loss(
        self, observed_data, cond_mask, observed_mask, side_info, is_train, set_t=-1
    ):
        B, K, L = observed_data.shape
        if is_train != 1:  # for validation
            t = (torch.ones(B) * set_t).long().to(self.device)
        else:
            t = torch.randint(0, self.num_steps, [B]).to(self.device)
        current_alpha = self.alpha_torch[t]  # (B,1,1)
        noise = torch.randn_like(observed_data)
        # closed-form forward diffusion: x_t = sqrt(alpha_t) * x_0 + sqrt(1 - alpha_t) * eps
        noisy_data = (current_alpha ** 0.5) * observed_data + (1.0 - current_alpha) ** 0.5 * noise

        total_input = self.set_input_to_diffmodel(noisy_data, observed_data, cond_mask)

        predicted = self.diffmodel(total_input, side_info, t)  # (B,K,L)

        # evaluate the noise-prediction loss only on the imputation targets
        target_mask = observed_mask - cond_mask
        residual = (noise - predicted) * target_mask
        num_eval = target_mask.sum()
        loss = (residual ** 2).sum() / (num_eval if num_eval > 0 else 1)
        return loss

    def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask):
        if self.is_unconditional == True:
            total_input = noisy_data.unsqueeze(1)  # (B,1,K,L)
        else:
            cond_obs = (cond_mask * observed_data).unsqueeze(1)
            noisy_target = ((1 - cond_mask) * noisy_data).unsqueeze(1)
            total_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)
        return total_input

    def impute(self, observed_data, cond_mask, side_info, n_samples):
        B, K, L = observed_data.shape

        imputed_samples = torch.zeros(B, n_samples, K, L).to(self.device)

        for i in range(n_samples):
            # generate noisy observation for unconditional model
            if self.is_unconditional == True:
                noisy_obs = observed_data
                noisy_cond_history = []
                for t in range(self.num_steps):
                    noise = torch.randn_like(noisy_obs)
                    noisy_obs = (self.alpha_hat[t] ** 0.5) * noisy_obs + self.beta[t] ** 0.5 * noise
                    noisy_cond_history.append(noisy_obs * cond_mask)

            current_sample = torch.randn_like(observed_data)

            for t in range(self.num_steps - 1, -1, -1):
                if self.is_unconditional == True:
                    diff_input = cond_mask * noisy_cond_history[t] + (1.0 - cond_mask) * current_sample
                    diff_input = diff_input.unsqueeze(1)  # (B,1,K,L)
                else:
                    cond_obs = (cond_mask * observed_data).unsqueeze(1)
                    noisy_target = ((1 - cond_mask) * current_sample).unsqueeze(1)
                    diff_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)
                predicted = self.diffmodel(diff_input, side_info, torch.tensor([t]).to(self.device))

                # reverse diffusion step using the predicted noise
                coeff1 = 1 / self.alpha_hat[t] ** 0.5
                coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5
                current_sample = coeff1 * (current_sample - coeff2 * predicted)

                if t > 0:
                    noise = torch.randn_like(current_sample)
                    sigma = (
                        (1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t]
                    ) ** 0.5
                    current_sample += sigma * noise

            imputed_samples[:, i] = current_sample.detach()
        return imputed_samples

    def forward(self, batch, is_train=1):
        (
            observed_data,
            observed_mask,
            observed_tp,
            gt_mask,
            for_pattern_mask,
            _,
        ) = self.process_data(batch)

        if is_train == 0:
            cond_mask = gt_mask
        elif self.target_strategy != "random":
            cond_mask = self.get_hist_mask(
                observed_mask, for_pattern_mask=for_pattern_mask
            )
        else:
            cond_mask = self.get_randmask(observed_mask)

        side_info = self.get_side_info(observed_tp, cond_mask)

        loss_func = self.calc_loss if is_train == 1 else self.calc_loss_valid

        return loss_func(observed_data, cond_mask, observed_mask, side_info, is_train)
    def evaluate(self, batch, n_samples):
        (
            observed_data,
            observed_mask,
            observed_tp,
            gt_mask,
            _,
            cut_length,
        ) = self.process_data(batch)

        with torch.no_grad():
            cond_mask = gt_mask
            target_mask = observed_mask - cond_mask

            side_info = self.get_side_info(observed_tp, cond_mask)

            samples = self.impute(observed_data, cond_mask, side_info, n_samples)

            for i in range(len(cut_length)):  # to avoid double evaluation
                target_mask[i, ..., 0 : cut_length[i].item()] = 0

        return samples, observed_data, target_mask, observed_mask, observed_tp
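

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API).
# CSDI_base calls self.process_data(batch), which is expected to be supplied
# by a dataset-specific subclass and to return (observed_data, observed_mask,
# observed_tp, gt_mask, for_pattern_mask, cut_length), with data and masks
# shaped (B,K,L) and timestamps shaped (B,L). The subclass, batch keys, and
# config values below are assumptions for illustration; diff_CSDI may read
# additional keys from config["diffusion"] beyond the schedule keys used here.
# ---------------------------------------------------------------------------
# class CSDI_Example(CSDI_base):
#     def process_data(self, batch):
#         observed_data = batch["observed_data"].to(self.device).float().permute(0, 2, 1)  # (B,K,L)
#         observed_mask = batch["observed_mask"].to(self.device).float().permute(0, 2, 1)  # (B,K,L)
#         observed_tp = batch["timepoints"].to(self.device).float()                        # (B,L)
#         gt_mask = batch["gt_mask"].to(self.device).float().permute(0, 2, 1)              # (B,K,L)
#         cut_length = torch.zeros(len(observed_data)).long().to(self.device)
#         for_pattern_mask = observed_mask
#         return (
#             observed_data,
#             observed_mask,
#             observed_tp,
#             gt_mask,
#             for_pattern_mask,
#             cut_length,
#         )
#
# config = {
#     "model": {
#         "timeemb": 128,
#         "featureemb": 16,
#         "is_unconditional": False,
#         "target_strategy": "random",
#     },
#     "diffusion": {
#         "num_steps": 50,
#         "schedule": "quad",
#         "beta_start": 0.0001,
#         "beta_end": 0.5,
#         # ... plus whatever diff_CSDI itself reads from this dict
#     },
# }
#
# model = CSDI_Example(target_dim=36, config=config, device="cuda:0")  # 36 = number of features K
# loss = model(batch)                                  # training loss (is_train=1)
# samples, *_ = model.evaluate(batch, n_samples=10)    # samples: (B, n_samples, K, L)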