import numpy as np
import torch
import torch.nn as nn

from .diff_csdi import diff_CSDI


class CSDI_base(nn.Module):
    def __init__(self, target_dim, config, device):
        super().__init__()
        self.device = device
        self.target_dim = target_dim

        self.emb_time_dim = config["model"]["timeemb"]
        self.emb_feature_dim = config["model"]["featureemb"]
        self.is_unconditional = config["model"]["is_unconditional"]
        self.target_strategy = config["model"]["target_strategy"]

        self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim
        if not self.is_unconditional:
            self.emb_total_dim += 1  # for conditional mask
        self.embed_layer = nn.Embedding(
            num_embeddings=self.target_dim, embedding_dim=self.emb_feature_dim
        )

        config_diff = config["diffusion"]
        config_diff["side_dim"] = self.emb_total_dim

        input_dim = 1 if self.is_unconditional else 2
        self.diffmodel = diff_CSDI(config_diff, input_dim)

        # parameters for diffusion models
        self.num_steps = config_diff["num_steps"]
        if config_diff["schedule"] == "quad":
            self.beta = np.linspace(
                config_diff["beta_start"] ** 0.5, config_diff["beta_end"] ** 0.5, self.num_steps
            ) ** 2
        elif config_diff["schedule"] == "linear":
            self.beta = np.linspace(
                config_diff["beta_start"], config_diff["beta_end"], self.num_steps
            )

        self.alpha_hat = 1 - self.beta
        self.alpha = np.cumprod(self.alpha_hat)
        self.alpha_torch = torch.tensor(self.alpha).float().to(self.device).unsqueeze(1).unsqueeze(1)
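
    # Naming note: alpha_hat is the per-step 1 - beta_t, while alpha is its
    # cumulative product -- the quantity usually written alpha-bar in the DDPM
    # literature. Illustrative values (assumptions, not repo defaults):
    #   beta = np.linspace(1e-4 ** 0.5, 0.5 ** 0.5, 50) ** 2  # "quad" schedule
    # i.e. the "quad" schedule interpolates linearly in sqrt-space between
    # beta_start and beta_end.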

    def time_embedding(self, pos, d_model=128):
        pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(self.device)
        position = pos.unsqueeze(2)
        div_term = 1 / torch.pow(
            10000.0, torch.arange(0, d_model, 2).to(self.device) / d_model
        )
        pe[:, :, 0::2] = torch.sin(position * div_term)
        pe[:, :, 1::2] = torch.cos(position * div_term)
        return pe
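
    # Shape sketch for time_embedding (B and L here are made-up examples):
    #   tp = torch.arange(50).float().unsqueeze(0).expand(2, -1)  # (B=2, L=50)
    #   pe = self.time_embedding(tp, d_model=128)                 # (2, 50, 128)
    # Even channels hold sin(pos / 10000 ** (2k / d_model)), odd channels the
    # matching cosine.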

    def get_randmask(self, observed_mask):
        rand_for_mask = torch.rand_like(observed_mask) * observed_mask
        rand_for_mask = rand_for_mask.reshape(len(rand_for_mask), -1)
        for i in range(len(observed_mask)):
            sample_ratio = np.random.rand()  # missing ratio
            num_observed = observed_mask[i].sum().item()
            num_masked = round(num_observed * sample_ratio)
            rand_for_mask[i][rand_for_mask[i].topk(num_masked).indices] = -1
        cond_mask = (rand_for_mask > 0).reshape(observed_mask.shape).float()
        return cond_mask
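
    # What get_randmask does, step by step: each sample draws a ratio
    # r ~ U(0, 1), then the r-fraction of its observed entries with the
    # largest random scores is hidden, so cond_mask keeps roughly a
    # (1 - r)-fraction of the observations as conditioning input.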

    def get_hist_mask(self, observed_mask, for_pattern_mask=None):
        if for_pattern_mask is None:
            for_pattern_mask = observed_mask
        if self.target_strategy == "mix":
            rand_mask = self.get_randmask(observed_mask)

        cond_mask = observed_mask.clone()
        for i in range(len(cond_mask)):
            mask_choice = np.random.rand()
            if self.target_strategy == "mix" and mask_choice > 0.5:
                cond_mask[i] = rand_mask[i]
            else:  # draw another sample for histmask (i-1 corresponds to another sample)
                cond_mask[i] = cond_mask[i] * for_pattern_mask[i - 1]
        return cond_mask
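
    # Under the "mix" strategy each sample flips a fair coin between the
    # random mask above and the observation pattern of the previous sample in
    # the batch (for_pattern_mask[i - 1]); other non-"random" strategies
    # always reuse the historical pattern.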

    def get_test_pattern_mask(self, observed_mask, test_pattern_mask):
        return observed_mask * test_pattern_mask

    def get_side_info(self, observed_tp, cond_mask):
        B, K, L = cond_mask.shape

        time_embed = self.time_embedding(observed_tp, self.emb_time_dim)  # (B,L,emb)
        time_embed = time_embed.unsqueeze(2).expand(-1, -1, K, -1)
        feature_embed = self.embed_layer(
            torch.arange(self.target_dim).to(self.device)
        )  # (K,emb)
        feature_embed = feature_embed.unsqueeze(0).unsqueeze(0).expand(B, L, -1, -1)

        side_info = torch.cat([time_embed, feature_embed], dim=-1)  # (B,L,K,*)
        side_info = side_info.permute(0, 3, 2, 1)  # (B,*,K,L)

        if not self.is_unconditional:
            side_mask = cond_mask.unsqueeze(1)  # (B,1,K,L)
            side_info = torch.cat([side_info, side_mask], dim=1)

        return side_info
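
    # side_info layout, read off the code above:
    #   time embedding    (B, L, K, emb_time_dim)
    #   feature embedding (B, L, K, emb_feature_dim)
    #   concat + permute  -> (B, emb_total_dim, K, L)
    # and for conditional models cond_mask is appended as one more channel,
    # matching the "+ 1" added to emb_total_dim in __init__.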

    def calc_loss_valid(
        self, observed_data, cond_mask, observed_mask, side_info, is_train
    ):
        loss_sum = 0
        for t in range(self.num_steps):  # calculate loss for all t
            loss = self.calc_loss(
                observed_data, cond_mask, observed_mask, side_info, is_train, set_t=t
            )
            loss_sum += loss.detach()
        return loss_sum / self.num_steps

    def calc_loss(
        self, observed_data, cond_mask, observed_mask, side_info, is_train, set_t=-1
    ):
        B, K, L = observed_data.shape
        if is_train != 1:  # for validation: evaluate at the given step set_t
            t = (torch.ones(B) * set_t).long().to(self.device)
        else:
            t = torch.randint(0, self.num_steps, [B]).to(self.device)
        current_alpha = self.alpha_torch[t]  # (B,1,1)
        noise = torch.randn_like(observed_data)
        noisy_data = (current_alpha ** 0.5) * observed_data + (1.0 - current_alpha) ** 0.5 * noise

        total_input = self.set_input_to_diffmodel(noisy_data, observed_data, cond_mask)
        predicted = self.diffmodel(total_input, side_info, t)  # (B,K,L)

        target_mask = observed_mask - cond_mask
        residual = (noise - predicted) * target_mask
        num_eval = target_mask.sum()
        loss = (residual ** 2).sum() / (num_eval if num_eval > 0 else 1)
        return loss
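
    # calc_loss is the epsilon-prediction DDPM objective: noisy_data above is
    #   x_t = alpha[t] ** 0.5 * x_0 + (1 - alpha[t]) ** 0.5 * eps
    # (alpha being the cumulative product from __init__), and the squared
    # error between eps and the network output is averaged over imputation
    # targets only, i.e. observed entries excluded from cond_mask.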

    def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask):
        if self.is_unconditional:
            total_input = noisy_data.unsqueeze(1)  # (B,1,K,L)
        else:
            cond_obs = (cond_mask * observed_data).unsqueeze(1)
            noisy_target = ((1 - cond_mask) * noisy_data).unsqueeze(1)
            total_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)
        return total_input
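
    # Channel layout fed to diff_CSDI: the unconditional model sees one noisy
    # channel, the conditional model two channels (conditioning values, noisy
    # targets) -- consistent with input_dim = 1 or 2 chosen in __init__.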

    def impute(self, observed_data, cond_mask, side_info, n_samples):
        B, K, L = observed_data.shape
        imputed_samples = torch.zeros(B, n_samples, K, L).to(self.device)

        for i in range(n_samples):
            # generate noisy observation for unconditional model
            if self.is_unconditional:
                noisy_obs = observed_data
                noisy_cond_history = []
                for t in range(self.num_steps):
                    noise = torch.randn_like(noisy_obs)
                    noisy_obs = (self.alpha_hat[t] ** 0.5) * noisy_obs + self.beta[t] ** 0.5 * noise
                    noisy_cond_history.append(noisy_obs * cond_mask)

            current_sample = torch.randn_like(observed_data)

            for t in range(self.num_steps - 1, -1, -1):
                if self.is_unconditional:
                    diff_input = cond_mask * noisy_cond_history[t] + (1.0 - cond_mask) * current_sample
                    diff_input = diff_input.unsqueeze(1)  # (B,1,K,L)
                else:
                    cond_obs = (cond_mask * observed_data).unsqueeze(1)
                    noisy_target = ((1 - cond_mask) * current_sample).unsqueeze(1)
                    diff_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)
                predicted = self.diffmodel(diff_input, side_info, torch.tensor([t]).to(self.device))

                coeff1 = 1 / self.alpha_hat[t] ** 0.5
                coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5
                current_sample = coeff1 * (current_sample - coeff2 * predicted)

                if t > 0:
                    noise = torch.randn_like(current_sample)
                    sigma = (
                        (1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t]
                    ) ** 0.5
                    current_sample += sigma * noise

            imputed_samples[:, i] = current_sample.detach()
        return imputed_samples
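
    # The reverse update above is the standard DDPM sampling step written in
    # this file's variable names (alpha_hat = 1 - beta, alpha = cumprod):
    #   x_{t-1} = (x_t - coeff2 * eps_theta) / alpha_hat[t] ** 0.5,
    #   coeff2  = (1 - alpha_hat[t]) / (1 - alpha[t]) ** 0.5
    #           = beta_t / sqrt(1 - alpha_bar_t),
    # plus sigma * noise for t > 0, with
    #   sigma ** 2 = beta[t] * (1 - alpha[t - 1]) / (1 - alpha[t]).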

    def forward(self, batch, is_train=1):
        (
            observed_data,
            observed_mask,
            observed_tp,
            gt_mask,
            for_pattern_mask,
            _,
        ) = self.process_data(batch)

        if is_train == 0:
            cond_mask = gt_mask
        elif self.target_strategy != "random":
            cond_mask = self.get_hist_mask(
                observed_mask, for_pattern_mask=for_pattern_mask
            )
        else:
            cond_mask = self.get_randmask(observed_mask)

        side_info = self.get_side_info(observed_tp, cond_mask)
        loss_func = self.calc_loss if is_train == 1 else self.calc_loss_valid
        return loss_func(observed_data, cond_mask, observed_mask, side_info, is_train)

    def evaluate(self, batch, n_samples):
        (
            observed_data,
            observed_mask,
            observed_tp,
            gt_mask,
            _,
            cut_length,
        ) = self.process_data(batch)

        with torch.no_grad():
            cond_mask = gt_mask
            target_mask = observed_mask - cond_mask

            side_info = self.get_side_info(observed_tp, cond_mask)
            samples = self.impute(observed_data, cond_mask, side_info, n_samples)

            for i in range(len(cut_length)):  # to avoid double evaluation
                target_mask[i, ..., 0 : cut_length[i].item()] = 0
        return samples, observed_data, target_mask, observed_mask, observed_tp
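

# A minimal construction sketch (hypothetical values; the field names are the
# ones read by __init__ above, and CSDI_base still needs a subclass that
# implements process_data, as the dataset-specific wrappers in this repo do):
#
#   config = {
#       "model": {
#           "timeemb": 128,
#           "featureemb": 16,
#           "is_unconditional": False,
#           "target_strategy": "random",
#       },
#       "diffusion": {
#           "num_steps": 50,
#           "schedule": "quad",
#           "beta_start": 0.0001,
#           "beta_end": 0.5,
#           # ...plus whatever diff_CSDI expects (channels, layers, ...)
#       },
#   }
#   model = MyCSDISubclass(target_dim=36, config=config, device="cuda:0")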