# Non-code residue (apparently copied from a repository web page header):
# PeterYu's picture
# update
# 2875fe6
import numpy as np
import torch
import torch.nn as nn
from .diff_csdi import diff_CSDI
class CSDI_base(nn.Module):
    """Shared training / imputation logic for a CSDI-style diffusion model.

    This class has no live ``__init__`` (the original constructor is kept
    below as comments), so every attribute the methods read must be set
    elsewhere -- presumably by a subclass constructor; TODO confirm:

    * ``device`` -- target passed to ``.to(...)`` for newly created tensors.
    * ``target_dim`` (int) and ``embed_layer`` -- ``embed_layer`` is called on
      ``arange(target_dim)`` to produce per-feature embeddings.
    * ``emb_time_dim`` (int) -- width of the sinusoidal time embedding.
    * ``is_unconditional`` -- compared with ``True`` / ``False``; selects a
      1-channel (unconditional) or 2-channel (conditional) model input.
    * ``target_strategy`` -- ``"random"``, ``"mix"`` or any other value
      (which selects the history-pattern mask).
    * ``num_steps``, ``beta``, ``alpha_hat``, ``alpha`` (indexable by an int
      step) and ``alpha_torch`` (indexed by a tensor of steps) -- the
      diffusion noise schedule.
    * ``diffmodel`` -- callable ``(input, side_info, t) -> prediction``.
    * ``process_data(batch)`` -- must return the 6-tuple unpacked by
      ``forward`` and ``evaluate``; it is not defined in this class.
    """

    # NOTE: original constructor, disabled in this revision. Kept for
    # reference because it shows how the attributes listed above were derived.
    # def __init__(self, target_dim, config, device):
    #     super().__init__()
    #     self.device = device
    #     self.target_dim = target_dim
    #     self.emb_time_dim = config["model"]["timeemb"]
    #     self.emb_feature_dim = config["model"]["featureemb"]
    #     self.is_unconditional = config["model"]["is_unconditional"]
    #     self.target_strategy = config["model"]["target_strategy"]
    #     self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim
    #     if self.is_unconditional == False:
    #         self.emb_total_dim += 1  # for conditional mask
    #     self.embed_layer = nn.Embedding(
    #         num_embeddings=self.target_dim, embedding_dim=self.emb_feature_dim
    #     )
    #     config_diff = config["diffusion"]
    #     config_diff["side_dim"] = self.emb_total_dim
    #     input_dim = 1 if self.is_unconditional == True else 2
    #     self.diffmodel = diff_CSDI(config_diff, input_dim)
    #     # parameters for diffusion models
    #     self.num_steps = config_diff["num_steps"]
    #     if config_diff["schedule"] == "quad":
    #         self.beta = np.linspace(
    #             config_diff["beta_start"] ** 0.5, config_diff["beta_end"] ** 0.5, self.num_steps
    #         ) ** 2
    #     elif config_diff["schedule"] == "linear":
    #         self.beta = np.linspace(
    #             config_diff["beta_start"], config_diff["beta_end"], self.num_steps
    #         )
    #     self.alpha_hat = 1 - self.beta
    #     self.alpha = np.cumprod(self.alpha_hat)
    #     self.alpha_torch = torch.tensor(self.alpha).float().to(self.device).unsqueeze(1).unsqueeze(1)

    def time_embedding(self, pos, d_model=128):
        """Return sinusoidal embeddings of the time points ``pos``.

        Args:
            pos: 2-D tensor indexed as (B, L); dim 2 is added for broadcasting.
            d_model: embedding width. Even columns hold ``sin`` and odd columns
                hold ``cos`` terms; odd widths are supported.

        Returns:
            Tensor of shape (B, L, d_model) on ``self.device``.
        """
        pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(self.device)
        position = pos.unsqueeze(2)
        div_term = 1 / torch.pow(
            10000.0, torch.arange(0, d_model, 2).to(self.device) / d_model
        )
        pe[:, :, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than frequencies, so
        # only the first d_model // 2 frequencies feed the cos columns.
        pe[:, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe

    def get_randmask(self, observed_mask):
        """Randomly drop a fraction of each sample's observed entries.

        For each sample a ratio is drawn from ``np.random.rand()`` and that
        fraction (rounded) of its observed entries is removed from the
        conditioning set.

        Returns:
            Float mask shaped like ``observed_mask``: 1 = kept as conditioning,
            0 = masked out or never observed.
        """
        # Unobserved entries get score 0, so they can never be selected as
        # conditioning (cond requires score > 0).
        rand_for_mask = torch.rand_like(observed_mask) * observed_mask
        rand_for_mask = rand_for_mask.reshape(len(rand_for_mask), -1)
        for i in range(len(observed_mask)):
            sample_ratio = np.random.rand()  # missing ratio
            num_observed = observed_mask[i].sum().item()
            num_masked = round(num_observed * sample_ratio)
            rand_for_mask[i][rand_for_mask[i].topk(num_masked).indices] = -1
        cond_mask = (rand_for_mask > 0).reshape(observed_mask.shape).float()
        return cond_mask

    def get_hist_mask(self, observed_mask, for_pattern_mask=None):
        """Build training conditioning masks from another sample's pattern.

        With ``target_strategy == "mix"``, each sample uses a random mask from
        ``get_randmask`` when ``np.random.rand() > 0.5``. Otherwise its
        observed mask is intersected with ``for_pattern_mask`` of the previous
        sample in the batch (index ``i - 1``, so sample 0 borrows from the last).

        Args:
            observed_mask: observation mask for the batch.
            for_pattern_mask: masks whose missing patterns are borrowed;
                defaults to ``observed_mask``.
        """
        if for_pattern_mask is None:
            for_pattern_mask = observed_mask
        if self.target_strategy == "mix":
            rand_mask = self.get_randmask(observed_mask)
        cond_mask = observed_mask.clone()
        for i in range(len(cond_mask)):
            mask_choice = np.random.rand()
            if self.target_strategy == "mix" and mask_choice > 0.5:
                cond_mask[i] = rand_mask[i]
            else:  # draw another sample for histmask (i-1 corresponds to another sample)
                cond_mask[i] = cond_mask[i] * for_pattern_mask[i - 1]
        return cond_mask

    def get_test_pattern_mask(self, observed_mask, test_pattern_mask):
        """Return the element-wise product of the observed and test masks."""
        return observed_mask * test_pattern_mask

    def get_side_info(self, observed_tp, cond_mask):
        """Assemble the side information fed to ``diffmodel``.

        Concatenates time embeddings (from ``observed_tp``) and feature
        embeddings (from ``embed_layer``). When the model is conditional, it
        also appends ``cond_mask`` as an extra channel.

        Returns:
            Tensor of shape (B, emb_time_dim + feature_emb [+1], K, L).
            NOTE(review): ``expand`` requires K == ``self.target_dim``.
        """
        B, K, L = cond_mask.shape
        time_embed = self.time_embedding(observed_tp, self.emb_time_dim)  # (B,L,emb)
        time_embed = time_embed.unsqueeze(2).expand(-1, -1, K, -1)
        feature_embed = self.embed_layer(
            torch.arange(self.target_dim).to(self.device)
        )  # (K,emb)
        feature_embed = feature_embed.unsqueeze(0).unsqueeze(0).expand(B, L, -1, -1)
        side_info = torch.cat([time_embed, feature_embed], dim=-1)  # (B,L,K,*)
        side_info = side_info.permute(0, 3, 2, 1)  # (B,*,K,L)
        if self.is_unconditional == False:
            side_mask = cond_mask.unsqueeze(1)  # (B,1,K,L)
            side_info = torch.cat([side_info, side_mask], dim=1)
        return side_info

    def calc_loss_valid(
        self, observed_data, cond_mask, observed_mask, side_info, is_train
    ):
        """Return ``calc_loss`` averaged over every diffusion step (detached)."""
        loss_sum = 0
        for t in range(self.num_steps):  # calculate loss for all t
            loss = self.calc_loss(
                observed_data, cond_mask, observed_mask, side_info, is_train, set_t=t
            )
            loss_sum += loss.detach()
        return loss_sum / self.num_steps

    def calc_loss(
        self, observed_data, cond_mask, observed_mask, side_info, is_train, set_t=-1
    ):
        """Noise-prediction loss on the target (observed, not conditioned) entries.

        When ``is_train == 1``, a random step is drawn per sample. Otherwise
        every sample uses step ``set_t``. The loss is the sum of squared
        errors over target entries, divided by the number of targets (or by 1
        when there are none).
        """
        B, K, L = observed_data.shape
        if is_train != 1:  # for validation
            t = (torch.ones(B) * set_t).long().to(self.device)
        else:
            t = torch.randint(0, self.num_steps, [B]).to(self.device)
        current_alpha = self.alpha_torch[t]  # (B,1,1)
        noise = torch.randn_like(observed_data)
        noisy_data = (current_alpha ** 0.5) * observed_data + (1.0 - current_alpha) ** 0.5 * noise
        total_input = self.set_input_to_diffmodel(noisy_data, observed_data, cond_mask)
        predicted = self.diffmodel(total_input, side_info, t)  # (B,K,L)
        target_mask = observed_mask - cond_mask
        residual = (noise - predicted) * target_mask
        num_eval = target_mask.sum()
        loss = (residual ** 2).sum() / (num_eval if num_eval > 0 else 1)
        return loss

    def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask):
        """Build the ``diffmodel`` input.

        Unconditional models get ``noisy_data`` as a single channel. Conditional
        models get two channels: the conditioning observations and the noisy
        target entries.
        """
        if self.is_unconditional == True:
            total_input = noisy_data.unsqueeze(1)  # (B,1,K,L)
        else:
            cond_obs = (cond_mask * observed_data).unsqueeze(1)
            noisy_target = ((1 - cond_mask) * noisy_data).unsqueeze(1)
            total_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)
        return total_input

    def impute(self, observed_data, cond_mask, side_info, n_samples):
        """Draw ``n_samples`` imputations by running the reverse diffusion chain.

        Per the disabled constructor, ``alpha_hat[t] = 1 - beta[t]`` and
        ``alpha`` is the cumulative product of ``alpha_hat``.

        Returns:
            Tensor of shape (B, n_samples, K, L).
        """
        B, K, L = observed_data.shape
        imputed_samples = torch.zeros(B, n_samples, K, L).to(self.device)
        for i in range(n_samples):
            # generate noisy observation for unconditional model
            if self.is_unconditional == True:
                noisy_obs = observed_data
                noisy_cond_history = []
                for t in range(self.num_steps):
                    noise = torch.randn_like(noisy_obs)
                    noisy_obs = (self.alpha_hat[t] ** 0.5) * noisy_obs + self.beta[t] ** 0.5 * noise
                    noisy_cond_history.append(noisy_obs * cond_mask)
            current_sample = torch.randn_like(observed_data)
            for t in range(self.num_steps - 1, -1, -1):
                if self.is_unconditional == True:
                    diff_input = cond_mask * noisy_cond_history[t] + (1.0 - cond_mask) * current_sample
                    diff_input = diff_input.unsqueeze(1)  # (B,1,K,L)
                else:
                    # Same 2-channel layout the model is trained on in calc_loss;
                    # previously this branch was disabled, so conditional models
                    # hit an undefined noisy_cond_history.
                    diff_input = self.set_input_to_diffmodel(
                        current_sample, observed_data, cond_mask
                    )  # (B,2,K,L)
                predicted = self.diffmodel(diff_input, side_info, torch.tensor([t]).to(self.device))
                # Posterior mean of the previous step given the predicted noise.
                coeff1 = 1 / self.alpha_hat[t] ** 0.5
                coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5
                current_sample = coeff1 * (current_sample - coeff2 * predicted)
                if t > 0:
                    noise = torch.randn_like(current_sample)
                    sigma = (
                        (1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t]
                    ) ** 0.5
                    current_sample += sigma * noise
            imputed_samples[:, i] = current_sample.detach()
        return imputed_samples

    def forward(self, batch, is_train=1):
        """Compute the training (``is_train == 1``) or validation loss for ``batch``.

        ``is_train == 0`` uses ``gt_mask`` as the conditioning mask. Any other
        value builds the mask from ``target_strategy``. Only ``is_train == 1``
        samples a random diffusion step. Every other value averages the loss
        over all steps.
        """
        (
            observed_data,
            observed_mask,
            observed_tp,
            gt_mask,
            for_pattern_mask,
            _,
        ) = self.process_data(batch)
        if is_train == 0:
            cond_mask = gt_mask
        elif self.target_strategy != "random":
            cond_mask = self.get_hist_mask(
                observed_mask, for_pattern_mask=for_pattern_mask
            )
        else:
            cond_mask = self.get_randmask(observed_mask)
        side_info = self.get_side_info(observed_tp, cond_mask)
        loss_func = self.calc_loss if is_train == 1 else self.calc_loss_valid
        return loss_func(observed_data, cond_mask, observed_mask, side_info, is_train)

    def evaluate(self, batch, n_samples):
        """Impute ``batch`` conditioned on ``gt_mask``.

        Returns:
            ``(samples, observed_data, target_mask, observed_mask, observed_tp)``.
            ``target_mask`` is ``observed_mask - gt_mask``, with the first
            ``cut_length[i]`` time indices of sample ``i`` zeroed.
        """
        (
            observed_data,
            observed_mask,
            observed_tp,
            gt_mask,
            _,
            cut_length,
        ) = self.process_data(batch)
        with torch.no_grad():
            cond_mask = gt_mask
            target_mask = observed_mask - cond_mask
            side_info = self.get_side_info(observed_tp, cond_mask)
            samples = self.impute(observed_data, cond_mask, side_info, n_samples)
            for i in range(len(cut_length)):  # to avoid double evaluation
                target_mask[i, ..., 0 : cut_length[i].item()] = 0
        return samples, observed_data, target_mask, observed_mask, observed_tp