# TSEditor/models/Tiffusion/tiffusion_backup.py
import math
import torch
import torch.nn.functional as F
from torch import nn
from einops import reduce
from tqdm.auto import tqdm
from functools import partial
from .transformer import Transformer
from ..model_utils import default, identity, extract
from .control import *
import mlflow.pyfunc
import mlflow
from mlflow.models import infer_signature
def linear_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
def cosine_beta_schedule(timesteps, s=0.008):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
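# Quick schedule sanity check (an illustrative sketch, not executed by the
# module; the 1000-step setting mirrors the Tiffusion default below):
#
#   betas = cosine_beta_schedule(timesteps=1000)          # float64, shape [1000]
#   alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)    # decays from ~1 toward 0
#   assert float(betas.max()) <= 0.999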
class Tiffusion(nn.Module):
def __init__(
self,
seq_length,
feature_size,
n_layer_enc=3,
n_layer_dec=6,
d_model=None,
timesteps=1000,
sampling_timesteps=None,
loss_type="l1",
beta_schedule="cosine",
n_heads=4,
mlp_hidden_times=4,
eta=0.0,
attn_pd=0.0,
resid_pd=0.0,
kernel_size=None,
padding_size=None,
use_ff=True,
reg_weight=None,
        control_signal=None,
moving_average=False,
**kwargs,
):
super(Tiffusion, self).__init__()
self.eta, self.use_ff = eta, use_ff
self.seq_length = seq_length
self.feature_size = feature_size
self.ff_weight = default(reg_weight, math.sqrt(self.seq_length) / 5)
self.sum_weight = default(reg_weight, math.sqrt(self.seq_length // 10) / 50)
        self.training_control_signal = control_signal or {}  # training control signal
self.moving_average = moving_average
self.model: Transformer = Transformer(
n_feat=feature_size,
n_channel=seq_length,
n_layer_enc=n_layer_enc,
n_layer_dec=n_layer_dec,
n_heads=n_heads,
attn_pdrop=attn_pd,
resid_pdrop=resid_pd,
mlp_hidden_times=mlp_hidden_times,
max_len=seq_length,
n_embd=d_model,
conv_params=[kernel_size, padding_size],
**kwargs,
)
if beta_schedule == "linear":
betas = linear_beta_schedule(timesteps)
elif beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
else:
raise ValueError(f"unknown beta schedule {beta_schedule}")
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
(timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
self.loss_type = loss_type
# sampling related parameters
self.sampling_timesteps = default(
sampling_timesteps, timesteps
) # default num sampling timesteps to number of timesteps at training
assert self.sampling_timesteps <= timesteps
self.fast_sampling = self.sampling_timesteps < timesteps
# helper function to register buffer from float64 to float32
register_buffer = lambda name, val: self.register_buffer(
name, val.to(torch.float32)
)
register_buffer("betas", betas)
register_buffer("alphas_cumprod", alphas_cumprod)
register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
register_buffer(
"sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
)
register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
register_buffer(
"sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
)
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (
betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
register_buffer("posterior_variance", posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
register_buffer(
"posterior_log_variance_clipped",
torch.log(posterior_variance.clamp(min=1e-20)),
)
register_buffer(
"posterior_mean_coef1",
betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
)
register_buffer(
"posterior_mean_coef2",
(1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
)
# calculate reweighting
register_buffer(
"loss_weight",
torch.sqrt(alphas) * torch.sqrt(1.0 - alphas_cumprod) / betas / 100,
)
def predict_noise_from_start(self, x_t, t, x0):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0
) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(
self.posterior_log_variance_clipped, t, x_t.shape
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
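    # For reference: q_posterior implements the standard DDPM posterior
    # q(x_{t-1} | x_t, x_0) = N(mu_t, sigma_t^2 I) with
    #   mu_t  = coef1 * x_0 + coef2 * x_t,
    #   coef1 = beta_t * sqrt(abar_{t-1}) / (1 - abar_t),
    #   coef2 = (1 - abar_{t-1}) * sqrt(alpha_t) / (1 - abar_t),
    # which is exactly what the buffers registered in __init__ encode.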
def output(self, x, t, padding_masks=None, control_signal=None):
# if ss:=control_signal.get("sum") is not None and len(ss.shape) == 1:
# bs = x.shape[0]
# control_signal["sum"] = ss.unsqueeze(0).repeat(bs, 1)
# print("control_signal", control_signal)
trend, season = self.model(
x, t, padding_masks=padding_masks, control_signal=control_signal
)
model_output = trend + season
return model_output
def model_predictions(
self, x, t, clip_x_start=False, padding_masks=None, control_signal=None
):
if padding_masks is None:
padding_masks = torch.ones(
                x.shape[0], self.seq_length, dtype=torch.bool, device=x.device
)
maybe_clip = (
partial(torch.clamp, min=-1.0, max=1.0) if clip_x_start else identity
)
x_start = self.output(x, t, padding_masks, control_signal=control_signal)
x_start = maybe_clip(x_start)
pred_noise = self.predict_noise_from_start(x, t, x_start)
return pred_noise, x_start
def p_mean_variance(self, x, t, clip_denoised=True, control_signal=None):
_, x_start = self.model_predictions(x, t, control_signal=control_signal)
if clip_denoised:
x_start.clamp_(-1.0, 1.0)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
x_start=x_start, x_t=x, t=t
)
return model_mean, posterior_variance, posterior_log_variance, x_start
def p_sample(self, x, t: int, clip_denoised=True, control_signal=None):
batched_times = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long)
model_mean, _, model_log_variance, x_start = self.p_mean_variance(
x=x, t=batched_times, clip_denoised=clip_denoised, control_signal=control_signal
)
noise = torch.randn_like(x) if t > 0 else 0.0 # no noise if t == 0
pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
return pred_img, x_start
@torch.no_grad()
def sample(self, shape, control_signal=None):
device = self.betas.device
img = torch.randn(shape, device=device)
for t in tqdm(
reversed(range(0, self.num_timesteps)),
desc="sampling loop time step",
total=self.num_timesteps,
):
img, _ = self.p_sample(img, t, control_signal=control_signal)
return img
@torch.no_grad()
    def fast_sample(self, shape, clip_denoised=True, model_kwargs=None):
batch, device, total_timesteps, sampling_timesteps, eta = (
shape[0],
self.betas.device,
self.num_timesteps,
self.sampling_timesteps,
self.eta,
)
# [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
times = list(reversed(times.int().tolist()))
time_pairs = list(
zip(times[:-1], times[1:])
) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
img = torch.randn(shape, device=device)
for time, time_next in tqdm(time_pairs, desc="sampling loop time step"):
time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
pred_noise, x_start, *_ = self.model_predictions(
img, time_cond, clip_x_start=clip_denoised,
control_signal=model_kwargs.get("model_control_signal", {}) if model_kwargs else {}
)
if time_next < 0:
img = x_start
continue
alpha = self.alphas_cumprod[time]
alpha_next = self.alphas_cumprod[time_next]
sigma = (
eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
)
c = (1 - alpha_next - sigma**2).sqrt()
noise = torch.randn_like(img)
img = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise
return img
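    # fast_sample follows the DDIM-style update: with x0_hat and eps_hat from
    # model_predictions,
    #   x_{t_next} = sqrt(abar_{t_next}) * x0_hat
    #              + sqrt(1 - abar_{t_next} - sigma^2) * eps_hat + sigma * z,
    #   sigma = eta * sqrt((1 - abar_t / abar_{t_next}) * (1 - abar_{t_next}) / (1 - abar_t)),
    # so eta = 0 (the default) gives deterministic sampling.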
def generate_mts(self, batch_size=16):
feature_size, seq_length = self.feature_size, self.seq_length
sample_fn = self.fast_sample if self.fast_sampling else self.sample
return sample_fn((batch_size, seq_length, feature_size))
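    # Usage sketch (hypothetical sizes, assuming a trained instance):
    #   diffusion = Tiffusion(seq_length=96, feature_size=7, d_model=64)
    #   series = diffusion.generate_mts(batch_size=4)   # -> [4, 96, 7]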
def generate_mts_infill(self, target, partial_mask=None, clip_denoised=True, model_kwargs=None):
sample_fn = self.fast_sample_infill_float_mask # if self.fast_sampling else self.sample_infill
print("model_kwargs", model_kwargs)
print("partial_mask", partial_mask.shape)
print("target", target.shape)
return sample_fn(
shape=target.shape,
target=target,
sampling_timesteps=self.sampling_timesteps,
partial_mask=partial_mask,
clip_denoised=clip_denoised,
model_kwargs=model_kwargs
)
@property
def loss_fn(self):
if self.loss_type == "l1":
return F.l1_loss
elif self.loss_type == "l2":
return F.mse_loss
else:
raise ValueError(f"invalid loss type {self.loss_type}")
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
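    # q_sample draws from the closed-form forward process
    #   q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I),
    # i.e. x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps with eps ~ N(0, I).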
@torch.no_grad()
def calculate_dynamic_window(self, t: torch.Tensor) -> torch.Tensor:
# Batch-wise time point normalization
t_min = 0 # t.min()
t_max = 500 # t.max()
# t_normalized = (t - t_min) / (t_max - t_min)
# Compute window sizes
# windows = ((t_normalized.exp2() - 1) * 15 // 1 + 1).long()
# plt.scatter(t, ( (5 ** ((t - 0) / 1000))) * 15 // 7 + 1)
windows = ((5 ** ( t / 500)) * 15 // 5 - 2).long()
return windows
@torch.no_grad()
def torch_moving_average(self, bs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""
Compute moving average for a time series tensor with dynamically calculated
window sizes for each sample.
Parameters:
-----------
bs : torch.Tensor
Input time series tensor of shape (batch_size, sequence_length, features)
t : torch.Tensor
Time points tensor of shape (batch_size, sequence_length)
Returns:
--------
torch.Tensor
Moving average tensor with the same shape as input
"""
# Get tensor dimensions
batch_size, total_seq_length, num_features = bs.shape
# Calculate dynamic window sizes for each sample
windows = self.calculate_dynamic_window(t)
# Create output tensor initialized with the original values
moving_avg = bs.clone()
# Compute moving average for each sample and time point
        for b in range(batch_size):
            # The window size depends only on the sample's timestep, so fetch
            # it once per sample rather than once per time point
            current_window = windows[b].item()
            for i in range(total_seq_length):
                # Determine the start of the causal window ending at i
                start = max(0, i - current_window + 1)
                window = bs[b:b+1, start:i+1, :]
# Compute average along the time dimension
window_avg = window.mean(dim=1)
# Replace values where we have enough previous steps
if i >= current_window - 1:
moving_avg[b, i, :] = window_avg
return moving_avg
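    # Shape sketch (illustrative values only):
    #   bs = torch.randn(2, 24, 3); t = torch.tensor([10, 900])
    #   smoothed = self.torch_moving_average(bs, t)   # same shape as bs
    # With the window formula above, t=10 gives a window of 1 (no smoothing)
    # while t=900 averages over roughly 50 steps.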
def _train_loss(
self,
x_start,
t,
target=None,
noise=None,
padding_masks=None,
control_signal=None,
):
noise = default(noise, lambda: torch.randn_like(x_start))
if target is None:
target = x_start
x = self.q_sample(x_start=x_start, t=t, noise=noise) # noise sample
# with torch.no_grad():
# if control_signal is None:
# control_signal = {
# "sum": target.mean(1),
# "top-peak-position": target.topk(self.seq_length // 20, dim=1)[1],
# } # .unsqueeze(-1)
# # elif self.control_sum:
# # ss = control_signal.get("sum")
# # if len(ss.shape) == 1:
# # bs = x.shape[0]
# # control_signal["sum"] = ss.unsqueeze(0).repeat(bs, 1)
# # control_signal = control_signal
# else:
# control_signal = {}
model_out = self.output(x, t, padding_masks, control_signal=control_signal)
# moving average according to the timestamp t, t larger means more stable, less noise
if self.moving_average:
target = self.torch_moving_average(target.cpu(), t.cpu()).to(model_out.device)
train_loss = self.loss_fn(model_out, target, reduction="none")
fourier_loss = torch.tensor([0.0])
if self.use_ff:
fft1 = torch.fft.fft(model_out.transpose(1, 2), norm="forward")
fft2 = torch.fft.fft(target.transpose(1, 2), norm="forward")
fft1, fft2 = fft1.transpose(1, 2), fft2.transpose(1, 2)
fourier_loss = self.loss_fn(
torch.real(fft1), torch.real(fft2), reduction="none"
) + self.loss_fn(torch.imag(fft1), torch.imag(fft2), reduction="none")
train_loss += self.ff_weight * fourier_loss
# if self.control_sum:
# train_loss += (
# self.loss_fn(model_out[..., 0].sum(1), target[..., 0].sum(1))
# / self.seq_length
# )
# * self.sum_weight
train_loss = reduce(train_loss, "b ... -> b (...)", "mean")
train_loss = train_loss * extract(self.loss_weight, t, train_loss.shape)
return train_loss.mean()
# fmt: off
def forward(self, x, **kwargs):
        b, c, n, device, feature_size = *x.shape, x.device, self.feature_size
        assert n == feature_size, f'number of variables must be {feature_size}'
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
return self._train_loss(x_start=x, t=t, **kwargs)
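    # Training usage sketch (hypothetical tensors):
    #   x = torch.randn(8, diffusion.seq_length, diffusion.feature_size)
    #   loss = diffusion(x)   # samples t uniformly, returns the weighted loss
    #   loss.backward()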
def return_components(self, x, t: int):
        b, c, n, device, feature_size = *x.shape, x.device, self.feature_size
        assert n == feature_size, f'number of variables must be {feature_size}'
t = torch.tensor([t])
t = t.repeat(b).to(device)
x = self.q_sample(x, t)
trend, season, residual = self.model(x, t, return_res=True)
return trend, season, residual, x
# fmt: on
def fast_sample_infill_float_mask(
self,
shape,
target: torch.Tensor, # target time series # [B, L, C]
sampling_timesteps,
partial_mask: torch.Tensor = None, # float mask between 0 and 1 # [B, L, C]
clip_denoised=True,
model_kwargs=None,
):
batch, device, total_timesteps, eta = (
shape[0],
self.betas.device,
self.num_timesteps,
self.eta,
)
# [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
times = list(reversed(times.int().tolist()))
time_pairs = list(
zip(times[:-1], times[1:])
) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
# Initialize with noise
img = torch.randn(shape, device=device) # [B, L, C]
for time, time_next in tqdm(
time_pairs, desc="conditional sampling loop time step"
):
time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
pred_noise, x_start, *_ = self.model_predictions(
img,
time_cond,
clip_x_start=clip_denoised,
control_signal=model_kwargs.get("model_control_signal", {}),
)
if time_next < 0:
img = x_start
continue
# Compute the predicted mean
alpha = self.alphas_cumprod[time]
alpha_next = self.alphas_cumprod[time_next]
sigma = (
eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
)
c = (1 - alpha_next - sigma**2).sqrt()
noise = torch.randn_like(img)
pred_mean = x_start * alpha_next.sqrt() + c * pred_noise
img = pred_mean + sigma * noise
# # Apply partial mask to the current sample
# if partial_mask is not None:
# target_t = self.q_sample(target, t=time_cond)
# img = img * (1.0 - partial_mask) + target_t * partial_mask
# Langevin Dynamics part for additional gradient updates
img = self.langevin_fn(
sample=img,
mean=pred_mean,
sigma=sigma,
t=time_cond,
tgt_embs=target,
partial_mask=partial_mask,
enable_float_mask=True,
**model_kwargs,
)
            # Blend observed values into the current iterate (soft anchoring
            # at each step; a float mask of 1 pins a point, 0 leaves it free)
            img = img * (1 - partial_mask) + target * partial_mask
        # Final blend after the loop pins observed values in the output
        img = img * (1 - partial_mask) + target * partial_mask
return img
def fast_sample_infill(
self,
shape,
target,
sampling_timesteps,
partial_mask=None,
clip_denoised=True,
model_kwargs=None,
):
batch, device, total_timesteps, eta = (
shape[0],
self.betas.device,
self.num_timesteps,
self.eta,
)
# [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
times = list(reversed(times.int().tolist()))
time_pairs = list(
zip(times[:-1], times[1:])
) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
img = torch.randn(shape, device=device)
for time, time_next in tqdm(
time_pairs, desc="conditional sampling loop time step"
):
time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
pred_noise, x_start, *_ = self.model_predictions(
img,
time_cond,
clip_x_start=clip_denoised,
control_signal=model_kwargs.get("model_control_signal", {}),
)
if time_next < 0:
img = x_start
continue
alpha = self.alphas_cumprod[time]
alpha_next = self.alphas_cumprod[time_next]
sigma = (
eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
)
c = (1 - alpha_next - sigma**2).sqrt()
pred_mean = x_start * alpha_next.sqrt() + c * pred_noise
noise = torch.randn_like(img)
img = pred_mean + sigma * noise
img = self.langevin_fn(
sample=img,
mean=pred_mean,
sigma=sigma,
t=time_cond,
tgt_embs=target,
partial_mask=partial_mask,
# gradient_control_signal=model_kwargs.get("gradient_control_signal", {}),
# model_control_signal=model_kwargs.get("model_control_signal", {}),
**model_kwargs,
)
            # Inside the loop: replace observed entries with the target noised
            # to the current level, keeping img consistent with timestep t
            target_t = self.q_sample(target, t=time_cond)
            img[partial_mask] = target_t[partial_mask]
        # After the loop: pin observed entries to the clean target
        img[partial_mask] = target[partial_mask]
return img
def sample_infill(
self,
shape,
target,
partial_mask=None,
clip_denoised=True,
model_kwargs=None,
):
"""
Generate samples from the model and yield intermediate samples from
each timestep of diffusion.
"""
batch, device = shape[0], self.betas.device
img = torch.randn(shape, device=device)
for t in tqdm(
reversed(range(0, self.num_timesteps)),
desc="conditional sampling loop time step",
total=self.num_timesteps,
):
img = self.p_sample_infill(
x=img,
t=t,
clip_denoised=clip_denoised,
target=target,
partial_mask=partial_mask,
model_kwargs=model_kwargs,
)
img[partial_mask] = target[partial_mask]
return img
def p_sample_infill(
self,
x,
target,
t: int,
partial_mask=None,
clip_denoised=True,
model_kwargs=None,
):
b, *_, device = *x.shape, self.betas.device
batched_times = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long)
model_mean, _, model_log_variance, _ = self.p_mean_variance(
x=x, t=batched_times, clip_denoised=clip_denoised, control_signal=model_kwargs.get("model_control_signal", {})
# don't pass parameters to control signal, for model itself
# Otherwise pass: control_signal=model_kwargs.get("control_signal", {})
)
noise = torch.randn_like(x) if t > 0 else 0.0 # no noise if t == 0
sigma = (0.5 * model_log_variance).exp()
pred_img = model_mean + sigma * noise
pred_img = self.langevin_fn(
sample=pred_img,
mean=model_mean,
sigma=sigma,
t=batched_times,
tgt_embs=target,
partial_mask=partial_mask,
# control_signal=model_kwargs.get("gradient_control_signal", {}),
**model_kwargs,
)
        # Fixed points: observed values that must be kept, noised to level t
target_t = self.q_sample(target, t=batched_times)
pred_img[partial_mask] = target_t[partial_mask]
return pred_img
@staticmethod
def classifier_guidance(
x: torch.Tensor,
t: torch.Tensor,
y: torch.Tensor,
classifier: torch.nn.Module
):
        with torch.enable_grad():
            # Enable gradient computation w.r.t. the input
            x_with_grad = x.detach().requires_grad_(True)
            # Get the log-probability distribution over classes
            logits = classifier(x_with_grad, t)
            log_prob = F.log_softmax(logits, dim=-1)
            # Select the entries corresponding to y
            selected = log_prob[range(len(logits)), y.view(-1)]
            # Return the gradient of log p(y | x_t) w.r.t. x_t
            return torch.autograd.grad(selected.sum(), x_with_grad)[0]
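    # Sketch of how the returned gradient would typically be used (standard
    # classifier guidance; this wiring is an assumption and is not done
    # elsewhere in this file):
    #   grad = Tiffusion.classifier_guidance(x, t, y, classifier)
    #   model_mean = model_mean + guidance_scale * posterior_variance * grad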
@staticmethod
def regression_guidance(
x: torch.Tensor,
t: torch.Tensor,
target_sum: torch.Tensor, # Target sum value
sigma: float = 1.0
):
"""
Compute gradient for guiding the sum of first channel to match target value
Args:
x: Input tensor [batch_size, channels, length] or [batch_size, length, channels]
t: Time steps
target_sum: Target sum value [batch_size]
sigma: Standard deviation for Gaussian likelihood
"""
# with torch.enable_grad():
# x_with_grad = x.detach().requires_grad_(True)
# normalize to 0, 1
# x_with_grad = (x + x.min()) / (x.max() - x.min())
# x_with_grad = x / 2 + 0.5 # [-1,1 to 0,1]
x_with_grad = x
# Calculate sum of first channel/feature
# Assuming x shape is [batch_size, channels, length] or [batch_size, length, channels]
if x_with_grad.dim() == 3:
if x_with_grad.shape[1] < x_with_grad.shape[2]: # [B, C, L]
current_sum = x_with_grad[:1, 0]
current_sum = current_sum / 2 + 0.5 # [-1, 1 to 0, 1]
print("Current Sum: ", current_sum.max().item(), current_sum.min().item())
current_sum = current_sum.sum(dim=1) # Sum over length
else: # [B, L, C]
current_sum = x_with_grad[:1, :, 0]
current_sum = current_sum / 2 + 0.5 # [-1, 1 to 0, 1]
print("Current Sum: ", current_sum.max().item(), current_sum.min().item())
current_sum = current_sum.sum(dim=1) # Sum over length
# Compute log probability under Gaussian distribution
sigma = torch.log(t) / 5
print("sigma", sigma)
if sigma.mean() == 0:
pred_std = torch.ones_like(current_sum)
else:
pred_std = torch.ones_like(current_sum) * sigma
log_prob = -0.5 * torch.log(2 * torch.pi * pred_std**2) - \
(target_sum - current_sum)**2 / (2 * pred_std**2)
# print(target_sum, current_sum)
# print("Current Sum: ", current_sum.mean().item())
# print("Current Diff: ", (target_sum - current_sum).mean().item())
return log_prob.mean()
# return torch.autograd.grad(log_prob.sum(), x_with_grad)[0]
def langevin_fn(
self,
coef,
partial_mask,
tgt_embs,
learning_rate,
sample,
mean,
sigma,
t,
coef_=0.0,
gradient_control_signal={},
model_control_signal={},
**kwargs,
):
        # Run more gradient updates at large diffusion steps t to guide the
        # generation, then reduce the number of updates in stages to
        # accelerate sampling.
        if t[0].item() < self.num_timesteps * 0.02:
            K = 0
        elif t[0].item() > self.num_timesteps * 0.9:
            K = 3
        elif t[0].item() > self.num_timesteps * 0.75:
            K = 2
            learning_rate = learning_rate * 0.5
        else:
            K = 1
            learning_rate = learning_rate * 0.25
input_embs_param = torch.nn.Parameter(sample)
        # Time-dependent weight adjustment factor for the control losses
        time_weight = get_time_dependent_weights(t[0], self.num_timesteps)
with torch.enable_grad():
for iteration in range(K):
# x_i+1 = x_i + noise * grad(logp(x_i)) + sqrt(2*noise) * z_i
optimizer = torch.optim.Adagrad([input_embs_param], lr=learning_rate)
optimizer.zero_grad()
x_start = self.output(
x=input_embs_param,
t=t,
control_signal=model_control_signal,
)
if sigma.mean() == 0:
logp_term = (
coef * ((mean - input_embs_param) ** 2 / 1.0).mean(dim=0).sum()
)
                    # Float masks weight the squared error continuously;
                    # boolean masks select the observed entries directly
                    if kwargs.get("enable_float_mask", False):
                        infill_loss = (x_start * partial_mask - tgt_embs * partial_mask) ** 2
                    else:
                        infill_loss = (x_start[partial_mask] - tgt_embs[partial_mask]) ** 2
infill_loss = infill_loss.mean(dim=0).sum()
else:
logp_term = (
coef
* ((mean - input_embs_param) ** 2 / sigma).mean(dim=0).sum()
)
if kwargs.get("enable_float_mask", False):
infill_loss = (x_start * (partial_mask) - tgt_embs * (partial_mask)) ** 2
else:
infill_loss = (x_start[partial_mask] - tgt_embs[partial_mask]) ** 2
infill_loss = (infill_loss / sigma.mean()).mean(dim=0).sum()
                # Classifier-guidance view: the diffusion model already
                # supplies its own score, so the only extra term needed here is
                # the gradient of the control (classifier-like) objectives below.
                # Reference: https://lichtung612.github.io/posts/3-diffusion-models/
                # Control-signal losses
                gradient_scale = gradient_control_signal.get("gradient_scale", 1.0)  # global gradient scale factor
                control_loss = 0
                # (Optional) regression guidance toward a target sum via
                # self.regression_guidance is disabled in this version; the sum
                # constraint is handled by sum_guidance below instead.
                # Unpack control-signal targets
                auc_sum = gradient_control_signal.get("auc")
                peak_points = gradient_control_signal.get("peak_points")
                bar_regions = gradient_control_signal.get("bar_regions")
                target_freq = gradient_control_signal.get("target_freq")
                # 1. Area-under-curve (sum) control
if auc_sum is not None:
sum_weight = gradient_control_signal.get("auc_weight", 1.0) * time_weight
auc_loss = - sum_weight * sum_guidance(
x=input_embs_param,
t=t,
target_sum=auc_sum,
gradient_scale=gradient_scale,
segments=gradient_control_signal.get("segments", ())
)
control_loss += auc_loss
                # Peak guidance
if peak_points is not None:
peak_weight = gradient_control_signal.get("peak_weight", 1.0) * time_weight
peak_loss = - peak_weight * peak_guidance(
x=input_embs_param,
t=t,
peak_points=peak_points,
window_size=gradient_control_signal.get("peak_window_size", 5),
alpha_1=gradient_control_signal.get("peak_alpha_1", 1.2),
gradient_scale=gradient_scale
)
control_loss += peak_loss
                # Bar (interval) guidance
if bar_regions is not None:
bar_weight = gradient_control_signal.get("bar_weight", 1.0) * time_weight
bar_loss = -bar_weight * bar_guidance(
x=input_embs_param,
t=t,
bar_regions=bar_regions,
gradient_scale=gradient_scale
)
control_loss += bar_loss
                # Frequency guidance
if target_freq is not None:
freq_weight = gradient_control_signal.get("freq_weight", 1.0) * time_weight
freq_loss = -freq_weight * frequency_guidance(
x=input_embs_param,
t=t,
target_freq=target_freq,
freq_weight=freq_weight,
gradient_scale=gradient_scale
)
control_loss += freq_loss
                loss = logp_term + infill_loss + control_loss
                loss.backward()
                # Clip before stepping so Adagrad applies the clipped gradient
                torch.nn.utils.clip_grad_norm_(
                    [input_embs_param],
                    gradient_control_signal.get("max_grad_norm", 1.0),
                )
                optimizer.step()
                # Langevin noise injection
                epsilon = torch.randn_like(input_embs_param.data)
                noise_scale = coef_ * sigma.mean().item()
                # noise_scale = noise_scale * time_weight  # optionally decay noise over time
input_embs_param = torch.nn.Parameter(
(
input_embs_param.data + noise_scale * epsilon
).detach()
)
if kwargs.get("enable_float_mask", False):
sample = sample * partial_mask + input_embs_param.data * (1 - partial_mask)
else:
sample[~partial_mask] = input_embs_param.data[~partial_mask]
return sample
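    # langevin_fn reads its hyperparameters from model_kwargs; a hypothetical
    # configuration (keys per the .get() calls above, values illustrative):
    #   model_kwargs = {
    #       "coef": 1e-1, "learning_rate": 5e-2,
    #       "model_control_signal": {},
    #       "gradient_control_signal": {"auc": 42.0, "auc_weight": 1.0,
    #                                   "max_grad_norm": 1.0},
    #   }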
    def load_weights(self, model_path, map_location="cuda:0"):
        data = torch.load(model_path, map_location=map_location, weights_only=True)
        self.load_state_dict(data["model"])
        print("Model weights loaded successfully")
def predict_weighted_points(
self,
observed_points: torch.Tensor,
observed_mask: torch.Tensor,
coef=1e-1,
stepsize=1e-1,
sampling_steps=50,
**kargs,
):
model_kwargs = {}
model_kwargs["coef"] = coef
model_kwargs["learning_rate"] = stepsize
model_kwargs = {**model_kwargs, **kargs}
assert len(observed_points.shape) == 2, "observed_points should be 2D, batch size = 1"
x = observed_points.unsqueeze(0)
        float_mask = observed_mask.unsqueeze(0)  # in [0, 1]: 1 = fully observed, 0 = missing
binary_mask = float_mask.clone()
binary_mask[binary_mask > 0] = 1
        x = x * 2 - 1  # rescale [0, 1] -> [-1, 1]
self.device = x.device
x, float_mask, binary_mask = x.to(self.device), float_mask.to(self.device), binary_mask.to(self.device)
        if sampling_steps == self.num_timesteps:
            # Full-step infilling would require an EMA wrapper exposing
            # sample_infill_float_mask, which this class does not provide
            raise NotImplementedError("use sampling_steps < num_timesteps (fast sampling)")
else:
print("fast sampling")
sample = self.fast_sample_infill_float_mask(
shape=x.shape,
target=x * binary_mask, # x * t_m, 1 for observed, 0 for missing
partial_mask=float_mask,
model_kwargs=model_kwargs,
sampling_timesteps=sampling_steps,
)
# unnormalize
sample = (sample + 1) / 2
return sample.squeeze(0).detach().cpu().numpy()
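    # Usage sketch (hypothetical data in [0, 1]; inputs are 2D, batch size 1,
    # and `diffusion` is an assumed trained instance):
    #   obs = torch.zeros(96, 1); mask = torch.zeros(96, 1)
    #   obs[::12, 0] = 0.8; mask[::12, 0] = 1.0   # pin every 12th point
    #   out = diffusion.predict_weighted_points(obs, mask, coef=1e-1,
    #                                           stepsize=5e-2, sampling_steps=50)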
def register_model(self, registered_model_name, model_path="tiffusion_model", conda_env=None):
"""Register the model with MLflow model registry.
Args:
registered_model_name: Name to register the model under
model_path: Local path to save model artifacts
conda_env: Custom conda environment for the model
"""
# Create basic conda env if not provided
if conda_env is None:
conda_env = {
'channels': ['defaults', 'conda-forge'],
'dependencies': [
'python>=3.8',
'pytorch',
'einops',
'tqdm'
],
'name': 'tiffusion_env'
}
# Start an MLflow run
with mlflow.start_run() as run:
# Log model parameters
mlflow.log_params({
"seq_length": self.seq_length,
"feature_size": self.feature_size,
"n_layer_enc": self.model.n_layer_enc,
"n_layer_dec": self.model.n_layer_dec,
"n_heads": self.model.n_heads,
"timesteps": self.num_timesteps,
"loss_type": self.loss_type
})
# Create a custom Python model class for MLflow
class TiffusionWrapper(mlflow.pyfunc.PythonModel):
def __init__(self, model):
self.model = model
def predict(self, context, model_input):
# Generate predictions using the model
with torch.no_grad():
result = self.model.generate_mts(batch_size=len(model_input))
                        return result.cpu().numpy()
# Create wrapper instance
wrapped_model = TiffusionWrapper(self)
# Log and register the model
mlflow.pyfunc.log_model(
artifact_path=model_path,
python_model=wrapped_model,
conda_env=conda_env,
registered_model_name=registered_model_name
)
print(f"Model registered as: {registered_model_name}")
print(f"Run ID: {run.info.run_id}")
def load_registered_model(self, registered_model_name, version=None, stage=None):
"""Load a registered model from MLflow model registry.
Args:
registered_model_name: Name of registered model
version: Optional specific version to load
stage: Optional stage to load (e.g. 'Production', 'Staging')
"""
if version:
model_uri = f"models:/{registered_model_name}/{version}"
elif stage:
model_uri = f"models:/{registered_model_name}/{stage}"
else:
model_uri = f"models:/{registered_model_name}/latest"
return mlflow.pyfunc.load_model(model_uri)
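# Minimal smoke test: a sketch assuming the relative imports above resolve and
# that Transformer accepts these constructor arguments; sizes and step counts
# are illustrative, not the settings used for real experiments.
if __name__ == "__main__":
    diffusion = Tiffusion(
        seq_length=24,
        feature_size=2,
        n_layer_enc=1,
        n_layer_dec=1,
        d_model=32,
        timesteps=8,
        sampling_timesteps=4,
    )
    x = torch.rand(4, 24, 2) * 2 - 1                # toy batch scaled to [-1, 1]
    loss = diffusion(x)                             # training loss at random t
    print("train loss:", float(loss))
    series = diffusion.generate_mts(batch_size=2)   # fast sampling path
    print("generated:", tuple(series.shape))        # -> (2, 24, 2)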