# TSEditor / models / CSDI / tiffusion.py
# (source-viewer header: author PeterYu, commit message "update", commit a785d5a)
import math
import torch
import torch.nn.functional as F
from torch import nn
from einops import reduce
from tqdm.auto import tqdm
from functools import partial
from ..model_utils import default, identity, extract
from .control import *
from .diff_csdi import diff_CSDI
from .csdi import CSDI_base
import numpy as np
def linear_beta_schedule(timesteps):
    """Return ``timesteps`` linearly spaced betas as a float64 tensor.

    The endpoints 1e-4 and 2e-2 refer to a 1000-step chain. Both are
    multiplied by ``1000 / timesteps`` before spacing.
    """
    factor = 1000 / timesteps
    first, last = factor * 0.0001, factor * 0.02
    return torch.linspace(first, last, timesteps, dtype=torch.float64)
def cosine_beta_schedule(timesteps, s=0.008):
    """Return the cosine beta schedule (https://openreview.net/forum?id=-NEXDKk8gZ).

    ``alpha_bar`` follows a squared cosine of the normalised step, offset by
    ``s`` and rescaled so that ``alpha_bar[0] == 1``. Each beta is
    ``1 - alpha_bar[t + 1] / alpha_bar[t]``, clipped into ``[0, 0.999]``.
    The result is a float64 tensor of length ``timesteps``.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)
    f = torch.cos(((grid / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alpha_bar = f / f[0]
    step_ratio = alpha_bar[1:] / alpha_bar[:-1]
    return (1 - step_ratio).clamp(0, 0.999)
class Tiffusion(nn.Module):
    """CSDI-backed diffusion model for time-series training and imputation.

    Two noise schedules live side by side:

    * the CSDI schedule (``beta``, ``alpha_hat = 1 - beta``,
      ``alpha = cumprod(alpha_hat)`` and the ``alpha_torch`` buffer, over
      ``num_steps`` steps), used by ``calc_loss`` and ``impute``;
    * a second schedule built from ``beta_schedule``/``timesteps`` and stored in
      the registered float32 buffers (``betas``, ``alphas_cumprod``, ...), used by
      the ``predict_*``/``q_posterior`` helpers and by ``langevin_fn``.

    The CSDI code works on tensors laid out ``(B, K, L)``: batch, features
    (``K == feature_size``) and time. ``get_side_info`` and ``impute`` fix this
    layout.
    """

    def __init__(
        self,
        seq_length,
        feature_size,
        n_layer_enc=3,
        n_layer_dec=6,
        d_model=None,
        timesteps=1000,
        sampling_timesteps=None,
        loss_type="l1",
        beta_schedule="cosine",
        n_heads=4,
        mlp_hidden_times=4,
        eta=0.0,
        attn_pd=0.0,
        resid_pd=0.0,
        kernel_size=None,
        padding_size=None,
        use_ff=True,
        reg_weight=None,
        control_signal=None,
        moving_average=False,
        is_unconditional=False,
        target_strategy="mix",
        **kwargs,
    ):
        """Build the CSDI denoiser and both noise schedules.

        Args:
            seq_length: Series length; used for the ``ff_weight``/``sum_weight`` defaults.
            feature_size: Number of features ``K``; also the size of the feature
                embedding table.
            timesteps: Length of the secondary schedule.
            sampling_timesteps: Defaults to ``timesteps``. ``fast_sampling`` is set
                when it is smaller.
            loss_type: Stored as ``self.loss_type``.
            beta_schedule: "linear" or "cosine" (secondary schedule).
            eta, use_ff, moving_average: Stored on the instance.
            reg_weight: Overrides both ``ff_weight`` and ``sum_weight`` when given.
            control_signal: Stored as ``training_control_signal``. Defaults to a
                fresh empty dict per instance.
            is_unconditional, target_strategy: Currently overridden by the
                hard-coded CSDI config (see the NOTE below).
            n_layer_enc, n_layer_dec, d_model, n_heads, mlp_hidden_times, attn_pd,
            resid_pd, kernel_size, padding_size, **kwargs: Accepted for interface
                compatibility; not used here.

        Raises:
            ValueError: If ``beta_schedule`` is neither "linear" nor "cosine".
        """
        super().__init__()
        # Kept for backward compatibility. Sampling/loss code follows the device of its inputs.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.eta, self.use_ff = eta, use_ff
        self.seq_length = seq_length
        self.feature_size = feature_size
        self.ff_weight = default(reg_weight, math.sqrt(self.seq_length) / 5)
        self.sum_weight = default(reg_weight, math.sqrt(self.seq_length // 10) / 50)
        # A fresh dict per instance; a literal ``{}`` default would be shared by all instances.
        self.training_control_signal = {} if control_signal is None else control_signal
        self.moving_average = moving_average

        config = {
            "model": {
                "timeemb": 128,
                "featureemb": 16,
                "is_unconditional": False,
                "target_strategy": "mix",
            },
            "diffusion": {
                "layers": 3,
                "channels": 64,
                "nheads": 8,
                "diffusion_embedding_dim": 128,
                "is_linear": False,
                "beta_start": 0.0001,
                "beta_end": 0.5,
                "schedule": "quad",
                "num_steps": 50,
            },
        }
        self.emb_time_dim = config["model"]["timeemb"]
        self.emb_feature_dim = config["model"]["featureemb"]
        # NOTE(review): the ``is_unconditional`` and ``target_strategy`` arguments are
        # overridden by the hard-coded config above (False / "mix"), as before.
        # Honouring ``is_unconditional=True`` would change the denoiser's input width,
        # so confirm intent before wiring the arguments through.
        self.is_unconditional = config["model"]["is_unconditional"]
        self.target_strategy = config["model"]["target_strategy"]

        # CSDI noise schedule.
        config_diff = config["diffusion"]
        self.num_steps = config_diff["num_steps"]
        if config_diff["schedule"] == "quad":
            self.beta = np.linspace(
                config_diff["beta_start"] ** 0.5, config_diff["beta_end"] ** 0.5, self.num_steps
            ) ** 2
        elif config_diff["schedule"] == "linear":
            self.beta = np.linspace(
                config_diff["beta_start"], config_diff["beta_end"], self.num_steps
            )
        self.alpha_hat = 1 - self.beta
        self.alpha = np.cumprod(self.alpha_hat)
        # Non-persistent buffer: it moves with ``.to()``/``.cuda()`` but adds no state_dict key.
        self.register_buffer(
            "alpha_torch",
            torch.tensor(self.alpha).float().unsqueeze(1).unsqueeze(1),
            persistent=False,
        )

        # Side information = time embedding + feature embedding (+ 1 mask channel if conditional).
        self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim
        if not self.is_unconditional:
            self.emb_total_dim += 1
        self.target_dim = feature_size
        self.embed_layer = nn.Embedding(
            num_embeddings=self.target_dim, embedding_dim=self.emb_feature_dim
        )
        # Conditional models receive two stacked input channels (observed, noisy target).
        self.diffmodel = diff_CSDI(
            {**config_diff, "side_dim": self.emb_total_dim},
            1 if self.is_unconditional else 2,
        )

        # Secondary schedule, stored as float32 buffers.
        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f"unknown beta schedule {beta_schedule}")
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type
        # Default the number of sampling steps to the number of training steps.
        self.sampling_timesteps = default(sampling_timesteps, timesteps)
        assert self.sampling_timesteps <= timesteps
        self.fast_sampling = self.sampling_timesteps < timesteps

        # Helper: register a buffer after casting float64 -> float32.
        register_buffer = lambda name, val: self.register_buffer(
            name, val.to(torch.float32)
        )
        register_buffer("betas", betas)
        register_buffer("alphas_cumprod", alphas_cumprod)
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
        # Quantities for q(x_t | x_{t-1}) and related calculations.
        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_buffer(
            "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
        )
        register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
        register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
        register_buffer(
            "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
        )
        # Posterior q(x_{t-1} | x_t, x_0).
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        # Equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t).
        register_buffer("posterior_variance", posterior_variance)
        # The log is clipped because the posterior variance is 0 at the start of the chain.
        register_buffer(
            "posterior_log_variance_clipped",
            torch.log(posterior_variance.clamp(min=1e-20)),
        )
        register_buffer(
            "posterior_mean_coef1",
            betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        )
        register_buffer(
            "posterior_mean_coef2",
            (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
        )
        # Per-step loss reweighting.
        register_buffer(
            "loss_weight",
            torch.sqrt(alphas) * torch.sqrt(1.0 - alphas_cumprod) / betas / 100,
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Recover the noise from ``x_t`` and a predicted ``x0`` (secondary schedule)."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0
        ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover ``x0`` from ``x_t`` and predicted noise (secondary schedule)."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Return the mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def output(self, x, t, padding_masks=None, control_signal=None):
        """Run the CSDI denoiser on ``x`` at step ``t`` and return its prediction.

        ``padding_masks`` is used as the conditioning mask for the side
        information. It must be provided, because ``get_side_info`` reads its
        shape. ``control_signal`` is unused.

        NOTE(review): ``x`` is passed to ``diffmodel`` unchanged, whereas
        ``impute``/``calc_loss`` pass a stacked ``(B, C_in, K, L)`` input. The
        time stamps here are taken from ``x.shape[1]``. Confirm the layout
        expected for ``x``.
        """
        if isinstance(t, int):
            t = torch.tensor([t]).to(x.device)
        observed_tp = torch.arange(x.shape[1], device=x.device).float()
        observed_tp = observed_tp.unsqueeze(0).expand(x.shape[0], -1)
        side_info = self.get_side_info(observed_tp, padding_masks)
        predicted, _ = self.diffmodel(x, side_info, t)
        return predicted

    def generate_mts(self, batch_size=16):
        """Unconditionally sample ``batch_size`` series of shape ``(seq_length, feature_size)``.

        NOTE(review): ``fast_sample``/``sample`` are not defined in this class;
        confirm they are provided elsewhere before relying on this method.
        """
        feature_size, seq_length = self.feature_size, self.seq_length
        sample_fn = self.fast_sample if self.fast_sampling else self.sample
        return sample_fn((batch_size, seq_length, feature_size))

    def generate_mts_infill(self, target, partial_mask=None, clip_denoised=True, model_kwargs=None):
        """Impute ``target`` conditioned on the entries selected by ``partial_mask``.

        Args:
            target: Values laid out ``(B, K, L)``, the layout ``impute`` expects.
            partial_mask: Same shape as ``target``. Conditioning uses
                ``partial_mask * target``.
            clip_denoised, model_kwargs: Accepted for interface compatibility; unused.

        Returns:
            One imputed sample per batch element, shape ``(B, K, L)``.

        Raises:
            ValueError: If ``partial_mask`` is not given.
        """
        if partial_mask is None:
            raise ValueError("generate_mts_infill requires a partial_mask")
        with torch.no_grad():
            # One time stamp per position along the time (last) axis; get_side_info
            # needs exactly cond_mask.shape[2] of them.
            observed_tp = torch.arange(target.shape[-1], device=target.device).float()
            observed_tp = observed_tp.unsqueeze(0).expand(target.shape[0], -1)
            side_info = self.get_side_info(observed_tp, partial_mask)
            samples = self.impute(
                observed_data=target,
                cond_mask=partial_mask,
                side_info=side_info,
                n_samples=1,
            )
        return samples.squeeze(1)

    def langevin_fn(
        self,
        coef,
        partial_mask,
        tgt_embs,
        learning_rate,
        sample,
        mean,
        sigma,
        t,
        coef_=0.0,
        gradient_control_signal=None,
        model_control_signal=None,
        side_info=None,
        **kwargs,
    ):
        """Refine ``sample`` with a few guided gradient steps.

        Each step minimises ``logp_term + infill_loss + control_loss`` with a
        fresh Adagrad optimiser, then adds Gaussian noise scaled by
        ``coef_ * sigma.mean()``. Only entries outside ``partial_mask`` are
        written back into ``sample`` (in place); ``sample`` is also returned.

        Args:
            coef: Weight of the squared distance to ``mean``.
            partial_mask: Observed entries. Bool, or float when
                ``enable_float_mask=True`` is passed in ``kwargs``.
            tgt_embs: Target values compared with the denoiser output on observed entries.
            learning_rate: Base Adagrad learning rate, halved or quartered for smaller ``t``.
            sample, mean, sigma: Current sample, predicted mean and noise scale (tensor).
            t: Step tensor. ``t[0]`` selects the number of refinement steps.
            coef_: Scale of the noise injected after each step.
            gradient_control_signal: Optional guidance settings. Read keys are "auc",
                "auc_weight", "segments", "peak_points", "peak_weight",
                "peak_window_size", "peak_alpha_1", "bar_regions", "bar_weight",
                "target_freq", "freq_weight", "gradient_scale" and "max_grad_norm".
            model_control_signal: Unused.
            side_info: CSDI side information passed to the denoiser.
        """
        if gradient_control_signal is None:
            gradient_control_signal = {}
        # More (and larger) updates at large t, fewer as the chain approaches t = 0.
        # NOTE(review): thresholds are relative to ``self.num_timesteps`` while
        # ``impute`` iterates over ``self.num_steps`` CSDI steps; confirm which chain
        # ``t`` comes from.
        step = t[0].item()
        if step < self.num_timesteps * 0.02:
            num_iters = 0
        elif step > self.num_timesteps * 0.9:
            num_iters = 3
        elif step > self.num_timesteps * 0.75:
            num_iters = 2
            learning_rate = learning_rate * 0.5
        else:
            num_iters = 1
            learning_rate = learning_rate * 0.25

        # Optimise a copy. ``nn.Parameter(sample)`` shares storage with ``sample``, so the
        # first optimiser step would otherwise also overwrite its observed entries.
        input_embs_param = torch.nn.Parameter(sample.detach().clone())
        # Time-dependent factor applied to every guidance weight.
        time_weight = get_time_dependent_weights(t[0], self.num_timesteps)
        float_mask = kwargs.get("enable_float_mask", False)
        gradient_scale = gradient_control_signal.get("gradient_scale", 1.0)  # global gradient scale
        max_grad_norm = gradient_control_signal.get("max_grad_norm", 1.0)
        auc_sum = gradient_control_signal.get("auc")
        peak_points = gradient_control_signal.get("peak_points")
        bar_regions = gradient_control_signal.get("bar_regions")
        target_freq = gradient_control_signal.get("target_freq")

        with torch.enable_grad():
            for _ in range(num_iters):
                optimizer = torch.optim.Adagrad([input_embs_param], lr=learning_rate)
                optimizer.zero_grad()

                if self.is_unconditional:
                    diff_input = input_embs_param.unsqueeze(1)
                else:
                    # Cast so ``1 - mask`` also works for bool masks.
                    mask_f = partial_mask.to(input_embs_param.dtype)
                    cond_obs = (mask_f * tgt_embs).unsqueeze(1)
                    noisy_target = ((1 - mask_f) * input_embs_param).unsqueeze(1)
                    diff_input = torch.cat([cond_obs, noisy_target], dim=1)
                x_start, _ = self.diffmodel(diff_input, side_info, t)

                # Pull toward the predicted mean and match the observed entries;
                # both terms are normalised by sigma when it is non-zero.
                sigma_mean = sigma.mean()
                if sigma_mean == 0:
                    logp_term = coef * ((mean - input_embs_param) ** 2 / 1.0).mean(dim=0).sum()
                else:
                    logp_term = coef * ((mean - input_embs_param) ** 2 / sigma).mean(dim=0).sum()
                if float_mask:
                    infill_loss = (x_start * partial_mask - tgt_embs * partial_mask) ** 2
                else:
                    infill_loss = (x_start[partial_mask] - tgt_embs[partial_mask]) ** 2
                if sigma_mean == 0:
                    infill_loss = infill_loss.mean(dim=0).sum()
                else:
                    infill_loss = (infill_loss / sigma_mean).mean(dim=0).sum()

                control_loss = 0
                # 1. Sum (area-under-curve) guidance.
                if auc_sum is not None:
                    sum_weight = gradient_control_signal.get("auc_weight", 1.0) * time_weight
                    control_loss += -sum_weight * sum_guidance(
                        x=input_embs_param,
                        t=t,
                        target_sum=auc_sum,
                        gradient_scale=gradient_scale,
                        segments=gradient_control_signal.get("segments", ()),
                    )
                # 2. Peak guidance.
                if peak_points is not None:
                    peak_weight = gradient_control_signal.get("peak_weight", 1.0) * time_weight
                    control_loss += -peak_weight * peak_guidance(
                        x=input_embs_param,
                        t=t,
                        peak_points=peak_points,
                        window_size=gradient_control_signal.get("peak_window_size", 5),
                        alpha_1=gradient_control_signal.get("peak_alpha_1", 1.2),
                        gradient_scale=gradient_scale,
                    )
                # 3. Interval (bar) guidance.
                if bar_regions is not None:
                    bar_weight = gradient_control_signal.get("bar_weight", 1.0) * time_weight
                    control_loss += -bar_weight * bar_guidance(
                        x=input_embs_param,
                        t=t,
                        bar_regions=bar_regions,
                        gradient_scale=gradient_scale,
                    )
                # 4. Frequency guidance.
                if target_freq is not None:
                    freq_weight = gradient_control_signal.get("freq_weight", 1.0) * time_weight
                    control_loss += -freq_weight * frequency_guidance(
                        x=input_embs_param,
                        t=t,
                        target_freq=target_freq,
                        freq_weight=freq_weight,
                        gradient_scale=gradient_scale,
                    )

                loss = logp_term + infill_loss + control_loss
                loss.backward()
                # Clip before stepping; clipping after ``step()`` cannot affect the update.
                torch.nn.utils.clip_grad_norm_([input_embs_param], max_grad_norm)
                optimizer.step()

                epsilon = torch.randn_like(input_embs_param.data)
                noise_scale = coef_ * sigma.mean().item()
                input_embs_param = torch.nn.Parameter(
                    (input_embs_param.data + noise_scale * epsilon).detach()
                )

        if float_mask:
            sample = sample * partial_mask + input_embs_param.data * (1 - partial_mask)
        else:
            sample[~partial_mask] = input_embs_param.data[~partial_mask]
        return sample

    def predict_weighted_points(
        self,
        observed_points: torch.Tensor,
        observed_mask: torch.Tensor,
        coef=1e-1,
        stepsize=1e-1,
        sampling_steps=50,
        **kargs,
    ):
        """Impute a single 2-D series given a (possibly float) observation mask.

        Values are mapped from [0, 1] to [-1, 1] before sampling and back
        afterwards. Non-zero mask entries are treated as observed when building
        the target.

        Args:
            observed_points: 2-D tensor (batch size 1); laid out as ``fast_sample_infill_float_mask`` expects.
            observed_mask: Same shape; > 0 marks observed entries, values act as weights.
            coef, stepsize, **kargs: Collected into ``model_kwargs``.
            sampling_steps: Must differ from ``self.num_timesteps``.

        Returns:
            The imputed series as a NumPy array with the batch dimension removed.

        Raises:
            NotImplementedError: If ``sampling_steps == self.num_timesteps``.
        """
        model_kwargs = {"coef": coef, "learning_rate": stepsize, **kargs}
        assert len(observed_points.shape) == 2, "observed_points should be 2D, batch size = 1"
        x = observed_points.unsqueeze(0)
        float_mask = observed_mask.unsqueeze(0)  # 1 for observed, 0 for missing, values are weights
        binary_mask = float_mask.clone()
        binary_mask[binary_mask > 0] = 1
        x = x * 2 - 1  # normalize [0, 1] -> [-1, 1]
        self.device = x.device
        x, float_mask, binary_mask = x.to(self.device), float_mask.to(self.device), binary_mask.to(self.device)
        if sampling_steps == self.num_timesteps:
            print("normal sampling")
            raise NotImplementedError(
                "full-length sampling is not implemented; use sampling_steps < num_timesteps"
            )
        print("fast sampling")
        sample = self.fast_sample_infill_float_mask(
            shape=x.shape,
            target=x * binary_mask,  # observed values only
            partial_mask=float_mask,
            model_kwargs=model_kwargs,
            sampling_timesteps=sampling_steps,
        )
        sample = (sample + 1) / 2  # unnormalize
        return sample.squeeze(0).detach().cpu().numpy()

    def forward(self, x, **kwargs):
        """Return the CSDI training (or validation) loss for a batch.

        Args:
            x: Batch laid out ``(B, L, C)``. It is permuted to ``(B, C, L)``, the
                ``(B, K, L)`` layout used by the CSDI code.
            **kwargs: ``observed_mask`` (defaults to all ones, permuted layout),
                ``is_train`` (default 1). For validation: ``gt_mask`` (copied,
                not modified) and ``pred_length`` (the last ``pred_length`` time
                steps are removed from the condition).
        """
        observed_data = x.permute(0, 2, 1)  # (B, L, C) -> (B, C, L)
        observed_mask = kwargs.get("observed_mask", torch.ones_like(observed_data))
        # One time stamp per position on the time axis (last axis after the permute).
        observed_tp = torch.arange(observed_data.shape[2], device=x.device).float()
        observed_tp = observed_tp.unsqueeze(0).expand(x.shape[0], -1)

        is_train = kwargs.get("is_train", 1)
        if is_train:
            cond_mask = self.get_randmask(observed_mask)
        else:
            # Copy so the caller's gt_mask is not modified in place.
            cond_mask = kwargs.get("gt_mask", observed_mask).clone()
            if "pred_length" in kwargs:
                cond_mask[:, :, -kwargs["pred_length"]:] = 0

        side_info = self.get_side_info(observed_tp, cond_mask)
        loss_func = self.calc_loss if is_train else self.calc_loss_valid
        return loss_func(observed_data, cond_mask, observed_mask, side_info, is_train)

    def time_embedding(self, pos, d_model=128):
        """Sinusoidal embedding of positions ``pos`` (B, L) -> (B, L, d_model); d_model must be even."""
        pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(pos.device)
        position = pos.unsqueeze(2)
        div_term = 1 / torch.pow(
            10000.0, torch.arange(0, d_model, 2).to(pos.device) / d_model
        )
        pe[:, :, 0::2] = torch.sin(position * div_term)
        pe[:, :, 1::2] = torch.cos(position * div_term)
        return pe

    def get_randmask(self, observed_mask):
        """Randomly drop a per-sample fraction of observed entries from the condition.

        For each sample a ratio is drawn uniformly from [0, 1). That fraction of
        its observed entries is excluded. Returns a float 0/1 mask shaped like
        ``observed_mask``.
        """
        rand_for_mask = torch.rand_like(observed_mask) * observed_mask
        rand_for_mask = rand_for_mask.reshape(len(rand_for_mask), -1)
        for i in range(len(observed_mask)):
            sample_ratio = np.random.rand()  # missing ratio
            num_observed = observed_mask[i].sum().item()
            num_masked = round(num_observed * sample_ratio)
            rand_for_mask[i][rand_for_mask[i].topk(num_masked).indices] = -1
        cond_mask = (rand_for_mask > 0).reshape(observed_mask.shape).float()
        return cond_mask

    def get_hist_mask(self, observed_mask, for_pattern_mask=None):
        """Build conditioning masks from random masks or another sample's pattern.

        With the "mix" strategy each sample uses a random mask with probability
        0.5. Otherwise, and for other strategies, it intersects its observed mask
        with the pattern of the previous sample (index ``i - 1``, wrapping to the
        last sample for ``i == 0``).
        """
        if for_pattern_mask is None:
            for_pattern_mask = observed_mask
        if self.target_strategy == "mix":
            rand_mask = self.get_randmask(observed_mask)
        cond_mask = observed_mask.clone()
        for i in range(len(cond_mask)):
            mask_choice = np.random.rand()
            if self.target_strategy == "mix" and mask_choice > 0.5:
                cond_mask[i] = rand_mask[i]
            else:
                cond_mask[i] = cond_mask[i] * for_pattern_mask[i - 1]
        return cond_mask

    def get_test_pattern_mask(self, observed_mask, test_pattern_mask):
        """Intersect the observed mask with a fixed test pattern."""
        return observed_mask * test_pattern_mask

    def get_side_info(self, observed_tp, cond_mask):
        """Build CSDI side information.

        Args:
            observed_tp: ``(B, L)`` time stamps, where ``L == cond_mask.shape[2]``.
            cond_mask: ``(B, K, L)`` conditioning mask, where ``K == self.target_dim``.

        Returns:
            ``(B, emb_time_dim + emb_feature_dim [+ 1], K, L)``. The extra channel
            is the mask, appended for conditional models.
        """
        B, K, L = cond_mask.shape
        time_embed = self.time_embedding(observed_tp, self.emb_time_dim)  # (B, L, emb)
        time_embed = time_embed.unsqueeze(2).expand(-1, -1, K, -1)
        feature_embed = self.embed_layer(
            torch.arange(self.target_dim).to(observed_tp.device)
        )  # (K, emb)
        feature_embed = feature_embed.unsqueeze(0).unsqueeze(0).expand(B, L, -1, -1)
        side_info = torch.cat([time_embed, feature_embed], dim=-1)  # (B, L, K, *)
        side_info = side_info.permute(0, 3, 2, 1)  # (B, *, K, L)
        if not self.is_unconditional:
            side_mask = cond_mask.unsqueeze(1)  # (B, 1, K, L)
            side_info = torch.cat([side_info, side_mask], dim=1)
        return side_info

    def calc_loss_valid(
        self, observed_data, cond_mask, observed_mask, side_info, is_train
    ):
        """Average ``calc_loss`` over every CSDI step (no gradients kept)."""
        loss_sum = 0
        for t in range(self.num_steps):
            loss = self.calc_loss(
                observed_data, cond_mask, observed_mask, side_info, is_train, set_t=t
            )
            loss_sum += loss.detach()
        return loss_sum / self.num_steps

    def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask):
        """Stack the denoiser input.

        Unconditional models get ``(B, 1, K, L)`` noisy data. Conditional models
        get ``(B, 2, K, L)``: the masked observations and the noisy data outside
        the mask.
        """
        if self.is_unconditional:
            return noisy_data.unsqueeze(1)
        cond_obs = (cond_mask * observed_data).unsqueeze(1)
        noisy_target = ((1 - cond_mask) * noisy_data).unsqueeze(1)
        return torch.cat([cond_obs, noisy_target], dim=1)

    def calc_loss(
        self, observed_data, cond_mask, observed_mask, side_info, is_train, set_t=-1
    ):
        """Noise-prediction MSE on the observed entries that are not in the condition.

        A random step per sample is used when ``is_train == 1``; otherwise step
        ``set_t`` is used.
        """
        B, K, L = observed_data.shape
        device = observed_data.device
        if is_train != 1:  # validation: fixed step
            t = torch.full((B,), set_t, dtype=torch.long, device=device)
        else:
            t = torch.randint(0, self.num_steps, [B], device=device)
        current_alpha = self.alpha_torch.to(device)[t]  # (B, 1, 1)
        noise = torch.randn_like(observed_data)
        noisy_data = (current_alpha ** 0.5) * observed_data + (1.0 - current_alpha) ** 0.5 * noise
        total_input = self.set_input_to_diffmodel(noisy_data, observed_data, cond_mask)
        predicted, _ = self.diffmodel(total_input, side_info, t)  # (B, K, L)
        target_mask = observed_mask - cond_mask
        residual = (noise - predicted) * target_mask
        num_eval = target_mask.sum()
        return (residual ** 2).sum() / (num_eval if num_eval > 0 else 1)

    def evaluate(self, batch, n_samples):
        """Impute ``batch`` ``n_samples`` times and return tensors for scoring.

        NOTE(review): ``process_data`` is not defined in this class; confirm it is
        provided elsewhere.
        """
        (
            observed_data,  # (B, K, L), as required by impute
            observed_mask,  # 1 for observed, 0 for missing
            observed_tp,
            gt_mask,
            _,
            cut_length,
        ) = self.process_data(batch)
        with torch.no_grad():
            cond_mask = gt_mask
            target_mask = observed_mask - cond_mask  # entries to score
            side_info = self.get_side_info(observed_tp, cond_mask)
            samples = self.impute(observed_data, cond_mask, side_info, n_samples)
            for i in range(len(cut_length)):  # to avoid double evaluation
                target_mask[i, ..., 0 : cut_length[i].item()] = 0
        return samples, observed_data, target_mask, observed_mask, observed_tp

    def impute(self, observed_data, cond_mask, side_info, n_samples):
        """Draw ``n_samples`` imputations with the CSDI ancestral sampler.

        Runs the full reverse chain over ``self.num_steps`` steps. Conditional
        models see ``cond_mask * observed_data`` stacked with the current sample
        outside the mask. Unconditional models see the current sample with the
        masked observations blended in.

        Args:
            observed_data: ``(B, K, L)`` values.
            cond_mask: ``(B, K, L)`` conditioning mask.
            side_info: Output of ``get_side_info`` for the same batch.
            n_samples: Number of independent samples per batch element.

        Returns:
            ``(B, n_samples, K, L)`` tensor on ``observed_data``'s device.
        """
        B, K, L = observed_data.shape
        device = observed_data.device
        imputed_samples = torch.zeros(B, n_samples, K, L, device=device)
        for i in range(n_samples):
            current_sample = torch.randn_like(observed_data)
            for t in range(self.num_steps - 1, -1, -1):
                if self.is_unconditional:
                    diff_input = cond_mask * observed_data + (1.0 - cond_mask) * current_sample
                    diff_input = diff_input.unsqueeze(1)
                else:
                    cond_obs = (cond_mask * observed_data).unsqueeze(1)
                    noisy_target = ((1 - cond_mask) * current_sample).unsqueeze(1)
                    diff_input = torch.cat([cond_obs, noisy_target], dim=1)
                predicted, _ = self.diffmodel(diff_input, side_info, torch.tensor([t]).to(device))
                # DDPM mean from predicted noise (alpha_hat = 1 - beta, alpha = cumprod(alpha_hat)).
                coeff1 = 1 / self.alpha_hat[t] ** 0.5
                coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5
                current_sample = coeff1 * (current_sample - coeff2 * predicted)
                if t > 0:
                    noise = torch.randn_like(current_sample)
                    sigma = (
                        (1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t]
                    ) ** 0.5
                    current_sample += sigma * noise
            imputed_samples[:, i] = current_sample
        return imputed_samples

    def fast_sample_infill_float_mask(
        self,
        shape,
        target: torch.Tensor,
        sampling_timesteps,
        partial_mask: torch.Tensor = None,
        clip_denoised=True,
        model_kwargs=None,
    ):
        """Impute ``target`` (laid out ``(B, L, C)``) with one CSDI sample.

        ``target`` and ``partial_mask`` are permuted to ``(B, C, L)`` for
        ``impute``, and the result is permuted back to ``(B, L, C)``. Runs
        without autograd.

        NOTE(review): ``sampling_timesteps``, ``clip_denoised`` and
        ``model_kwargs`` are currently unused; the full ``num_steps`` chain runs.

        Raises:
            ValueError: If ``partial_mask`` is not given.
        """
        if partial_mask is None:
            raise ValueError("fast_sample_infill_float_mask requires a partial_mask")
        batch = shape[0]
        device = target.device
        with torch.no_grad():
            target = target.permute(0, 2, 1)
            partial_mask = partial_mask.permute(0, 2, 1)
            # shape[1] is the time length L of the (B, L, C) input.
            observed_tp = torch.arange(shape[1], device=device).float()
            observed_tp = observed_tp.unsqueeze(0).expand(batch, -1)
            side_info = self.get_side_info(observed_tp, partial_mask)
            samples = self.impute(
                observed_data=target,
                cond_mask=partial_mask,
                side_info=side_info,
                n_samples=1,
            )
        return samples.squeeze(1).permute(0, 2, 1)