# TSEditor/models/CSDI/control.py
# Last commit 2875fe6 by PeterYu: "update"
import torch
def sum_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_sum: torch.Tensor,
    sigma: float = 1.0,
    gradient_scale: float = 1.0,
    segments: tuple = None
):
    """Gaussian log-likelihood that the sum of channel 0 of ``x`` equals ``target_sum``.

    Channel 0 (``x[:, :, 0]``) is rescaled with ``x / 2 + 0.5`` and then summed
    over dim 1. This maps [-1, 1] onto [0, 1]; presumably samples are normalized
    to [-1, 1] — TODO confirm. If ``segments`` is given, only the listed
    ``(start_idx, end_idx)`` slices of dim 1 are summed (half-open, like Python
    slicing), and the per-segment sums are added together.

    Args:
        x: Tensor indexed as ``x[:, :, 0]``; dim 1 is the summed axis
            (assumed (batch, time, channel) — TODO confirm).
        t: Diffusion timestep; unused.
        target_sum: Target value(s), broadcast against the per-sample sums.
        sigma: Standard deviation of the Gaussian; ``0`` is treated as ``1``.
        gradient_scale: Unused; kept for interface compatibility.
        segments: Optional iterable of ``(start_idx, end_idx)`` pairs.

    Returns:
        Scalar tensor: the mean Gaussian log-probability.
    """
    signal = x[:, :, 0] / 2 + 0.5
    if segments:
        # Sum each requested slice of the rescaled signal and add them up.
        # (Slicing the already-reduced 1-D sum, as before, raised IndexError.)
        current_sum = sum(
            signal[:, start_idx:end_idx].sum(dim=1)
            for start_idx, end_idx in segments
        )
    else:
        current_sum = signal.sum(dim=1)

    # sigma == 0 would divide by zero below, so fall back to a unit std.
    if sigma == 0:
        pred_std = torch.ones_like(current_sum)
    else:
        pred_std = torch.ones_like(current_sum) * sigma

    log_prob = -0.5 * torch.log(2 * torch.pi * pred_std**2) - \
        (target_sum - current_sum)**2 / (2 * pred_std**2)
    return log_prob.mean()
def peak_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    peak_points: list,
    window_size: int = 5,
    alpha_1: float = 1.2,
    sigma: float = 1.0,
    gradient_scale: float = 1.0  # newly added gradient-scaling parameter; currently unused
):
    """Gaussian-shaped score encouraging peaks at the given positions.

    Channel 0 is taken as ``x[:, 0]`` when ``x.shape[1] < x.shape[2]``, and as
    ``x[:, :, 0]`` otherwise. It is then rescaled with ``signal / 2 + 0.5``.
    For each index in ``peak_points``, the neighbour mean over a window of
    ``window_size // 2`` positions on each side is computed (clipped to the
    signal, excluding the peak itself). The term is
    ``d = mean_over_batch(alpha_1 * neighbour_mean - signal[peak])``, and each
    peak contributes ``-d**2 / (2 * sigma**2)``. Note that ``d`` is averaged
    over dim 0 *before* squaring.

    Args:
        x: Input tensor of rank >= 3 (layout chosen by the shape check above).
        t: Diffusion timestep; unused.
        peak_points: Integer positions along the signal axis.
        window_size: Total window width around each peak.
        alpha_1: Factor applied to the neighbour mean.
        sigma: Scale of the Gaussian penalty.
        gradient_scale: Unused; kept for interface compatibility.

    Returns:
        Scalar tensor with the summed per-peak log-score. If no peak has any
        neighbours (including an empty ``peak_points``), this is a constant zero
        tensor that does not depend on ``x``.
    """
    if x.shape[1] < x.shape[2]:
        signal = x[:, 0]
    else:
        signal = x[:, :, 0]
    signal = signal / 2 + 0.5

    half_window = window_size // 2
    seq_len = signal.shape[1]
    # Start from a tensor so an empty peak list returns 0 instead of crashing on int.mean().
    log_prob = signal.new_zeros(())
    for x_coord in peak_points:
        start_idx = max(0, x_coord - half_window)
        end_idx = min(seq_len, x_coord + half_window + 1)
        num_neighbours = end_idx - start_idx - 1
        if num_neighbours <= 0:
            # Window holds only the peak itself: there is no neighbour mean to
            # compare against, and dividing by zero would produce NaN.
            continue
        peak_value = signal[:, x_coord]
        local_mean = (signal[:, start_idx:end_idx].sum(dim=1) - peak_value) / num_neighbours
        local_diff = (local_mean * alpha_1 - peak_value).mean()
        log_prob = log_prob - (local_diff**2) / (2 * sigma**2)
    return log_prob.mean()
def bar_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    bar_regions: list,
    sigma: float = 1.0,
    gradient_scale: float = 1.0
):
    """Score how closely region means of channel 0 match target levels.

    Channel 0 is read as ``x[:, 0]`` when ``x.shape[1] < x.shape[2]``, and as
    ``x[:, :, 0]`` otherwise. It is then rescaled with ``signal / 2 + 0.5``.
    For every ``(start_idx, end_idx, target_value)`` in ``bar_regions``, the
    slice mean along dim 1 is compared with ``target_value`` through an
    unnormalized Gaussian ``exp(-0.5 * diff**2 / sigma**2)``, averaged over
    dim 0. The per-region scores are summed.

    Args:
        x: Input tensor of rank >= 3.
        t: Diffusion timestep; unused.
        bar_regions: Iterable of ``(start_idx, end_idx, target_value)`` triples.
        sigma: Gaussian width; non-positive values fall back to ``1``.
        gradient_scale: Unused; kept for interface compatibility.

    Returns:
        Sum of per-region scores (a tensor), or the int ``0`` when
        ``bar_regions`` is empty.
    """
    channels_first = x.shape[1] < x.shape[2]
    signal = (x[:, 0] if channels_first else x[:, :, 0]) / 2 + 0.5

    def _region_score(start_idx, end_idx, target_value):
        # Mean of one slice, scored against its target level.
        region_mean = signal[:, start_idx:end_idx].mean(dim=1)
        width = torch.ones_like(region_mean)
        if sigma > 0:
            width = width * sigma
        return torch.exp(-0.5 * ((region_mean - target_value)**2) / (width**2)).mean()

    return sum(_region_score(lo, hi, level) for lo, hi, level in bar_regions)
def frequency_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_freq: float,
    freq_weight: float = 1.0,
    gradient_scale: float = 1.0
):
    """Score the spectral magnitude of channel 0 near ``target_freq``.

    Channel 0 is read as ``x[:, 0]`` when ``x.shape[1] < x.shape[2]``, and as
    ``x[:, :, 0]`` otherwise. Unlike the other guidance functions it is *not*
    rescaled. Its real FFT along dim 1 is taken, and the magnitudes are weighted
    by a Gaussian window centred at ``target_freq``. The window's std is
    ``0.1 / gradient_scale``, in the units of ``rfftfreq(..., d=1.0)``, i.e.
    cycles per sample.

    Args:
        x: Input tensor of rank >= 3.
        t: Diffusion timestep; unused.
        target_freq: Centre frequency of the window, in cycles per sample.
        freq_weight: Multiplier applied to the mean weighted magnitude.
        gradient_scale: Narrows the window as it grows.

    Returns:
        Scalar tensor ``exp(freq_weight * mean(|FFT| * window))``.
    """
    if x.shape[1] < x.shape[2]:
        series = x[:, 0]
    else:
        series = x[:, :, 0]

    num_samples = series.shape[1]
    spectrum = torch.fft.rfft(series, dim=1)
    bin_freqs = torch.fft.rfftfreq(num_samples, d=1.0)

    bandwidth = 0.1 / gradient_scale
    window = torch.exp(-((bin_freqs - target_freq)**2) / (2 * bandwidth**2))
    # rfftfreq builds the window on the default device; move it next to x.
    window = window.to(x.device)

    weighted = spectrum.abs() * window.unsqueeze(0)
    return torch.exp(freq_weight * weighted.mean())
def get_time_dependent_weights(t, num_timesteps):
    """Return an exponentially decaying control-signal weight for each timestep.

    Computes ``exp(-5 * t / num_timesteps)``. The weight is 1.0 at ``t == 0``,
    shrinks as ``t`` grows, and reaches ``exp(-5)`` at ``t == num_timesteps``.
    The original author's note says "earlier" timesteps should get larger
    weights, which here means smaller values of ``t``.

    Args:
        t: Tensor of timesteps (converted to float).
        num_timesteps: Total number of timesteps used to normalize ``t``.

    Returns:
        Tensor of weights with the same shape as ``t``.
    """
    normalized_t = t.float().div(num_timesteps)
    # Larger weights for smaller t.
    return normalized_t.mul(-5).exp()