Spaces:
Sleeping
Sleeping
import torch | |
def sum_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_sum: torch.Tensor,
    sigma: float = 1.0,
    gradient_scale: float = 1.0,
    segments: tuple = None,
):
    """Gaussian log-likelihood that the (optionally segmented) signal sum matches a target.

    Channel 0 of ``x`` is rescaled with ``v / 2 + 0.5`` (maps -1 -> 0 and 1 -> 1).
    It is then summed along dim 1, either over the whole axis or only over
    the given ``segments``. The result is scored against ``target_sum`` under
    a Gaussian with standard deviation ``sigma``.

    Args:
        x: Input tensor indexed as ``x[:, :, 0]``; dim 1 is treated as the
            axis to sum over. Assumes (batch, length, channels) — TODO confirm.
        t: Timestep; unused.
        target_sum: Desired sum; broadcast against the per-sample sums.
        sigma: Standard deviation of the Gaussian; ``0`` falls back to ``1``.
        gradient_scale: Unused; kept for interface compatibility.
        segments: Optional iterable of ``(start, end)`` index pairs along
            dim 1. When given (and non-empty), only those half-open ranges
            are summed. Overlapping ranges are counted once per segment.

    Returns:
        0-d tensor: the Gaussian log-density averaged over the batch.
    """
    signal = x[:, :, 0] / 2 + 0.5

    if segments:
        # Sum each [start, end) range of the signal and add the partial sums.
        # The previous code sliced the already-reduced (batch,) tensor here,
        # which always raised IndexError.
        current_sum = sum(signal[:, start:end].sum(dim=1) for start, end in segments)
    else:
        current_sum = signal.sum(dim=1)

    # A zero sigma would divide by zero below; treat it as unit variance.
    std = sigma if sigma != 0 else 1.0
    pred_std = torch.ones_like(current_sum) * std

    log_prob = -0.5 * torch.log(2 * torch.pi * pred_std ** 2) - \
        (target_sum - current_sum) ** 2 / (2 * pred_std ** 2)
    return log_prob.mean()
def peak_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    peak_points: list,
    window_size: int = 5,
    alpha_1: float = 1.2,
    sigma: float = 1.0,
    gradient_scale: float = 1.0,  # newly added gradient-scaling parameter (unused here)
):
    """Score how strongly the signal peaks at the given indices.

    For each index in ``peak_points``, the mean of its neighbours inside a
    window is computed, excluding the point itself. The term
    ``local_mean * alpha_1 - signal[:, idx]`` is averaged over the batch
    *before* squaring. The score accumulates ``-diff**2 / (2 * sigma**2)``.

    Args:
        x: Input tensor. Channel 0 is taken from dim 1 when
            ``x.shape[1] < x.shape[2]``, otherwise from dim 2.
        t: Timestep; unused.
        peak_points: Integer indices along the signal's length axis.
        window_size: Window width. The window spans ``idx - window_size // 2``
            to ``idx + window_size // 2`` inclusive, clipped to the signal
            bounds.
        alpha_1: Multiplier applied to the neighbour mean.
        sigma: Scale of the quadratic penalty; ``0`` falls back to ``1``.
        gradient_scale: Unused; kept for interface compatibility.

    Returns:
        0-d tensor; ``0`` when ``peak_points`` is empty.
    """
    if x.shape[1] < x.shape[2]:
        signal = x[:, 0]
    else:
        signal = x[:, :, 0]
    signal = signal / 2 + 0.5

    # A zero sigma would divide by zero; treat it as unit scale.
    if sigma == 0:
        sigma = 1.0

    half_window = window_size // 2
    # Start from a tensor, not the int 0, so an empty peak_points still
    # returns a tensor (the old `0 .mean()` raised AttributeError).
    log_prob = signal.new_zeros(())
    for x_coord in peak_points:
        start_idx = max(0, x_coord - half_window)
        end_idx = min(signal.shape[1], x_coord + half_window + 1)
        # Mean of the window excluding the centre point.
        # NOTE(review): if the clipped window contains only the centre
        # (e.g. window_size < 2), the divisor is 0 and the result is nan/inf.
        local_mean = (signal[:, start_idx:end_idx].sum(dim=1) - signal[:, x_coord]) / (end_idx - start_idx - 1)
        local_diff = (local_mean * alpha_1 - signal[:, x_coord]).mean()
        log_prob = log_prob - (local_diff ** 2) / (2 * sigma ** 2)
    return log_prob.mean()
def bar_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    bar_regions: list,
    sigma: float = 1.0,
    gradient_scale: float = 1.0,
):
    """Score how closely each region's mean level matches its target value.

    For each ``(start, end, target)`` region, the rescaled channel-0 signal
    (``v / 2 + 0.5``) is averaged over ``[start, end)``. The score then adds
    the batch mean of the Gaussian kernel
    ``exp(-0.5 * (region_mean - target)**2 / sigma**2)``.
    The result is therefore a sum of values in (0, 1], one per region. It is
    not a log-probability.

    Args:
        x: Input tensor. Channel 0 is taken from dim 1 when
            ``x.shape[1] < x.shape[2]``, otherwise from dim 2.
        t: Timestep; unused.
        bar_regions: Iterable of ``(start_idx, end_idx, target_value)``.
        sigma: Kernel width; values ``<= 0`` fall back to ``1``.
        gradient_scale: Unused; kept for interface compatibility.

    Returns:
        0-d tensor; ``0`` when ``bar_regions`` is empty.
    """
    if x.shape[1] < x.shape[2]:
        signal = x[:, 0]
    else:
        signal = x[:, :, 0]
    signal = signal / 2 + 0.5

    # Tensor accumulator so an empty bar_regions still returns a tensor
    # (previously the Python int 0 was returned, breaking autograd callers).
    score = signal.new_zeros(())
    for start_idx, end_idx, target_value in bar_regions:
        # NOTE(review): an empty slice (start_idx >= end_idx) yields a nan mean.
        region_mean = signal[:, start_idx:end_idx].mean(dim=1)
        sigma_t = torch.ones_like(region_mean) * sigma if sigma > 0 else torch.ones_like(region_mean)
        score = score + torch.exp(-0.5 * ((region_mean - target_value) ** 2) / (sigma_t ** 2)).mean()
    return score
def frequency_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_freq: float,
    freq_weight: float = 1.0,
    gradient_scale: float = 1.0
):
    """Score the spectral energy of channel 0 near ``target_freq``.

    Channel 0 of ``x`` is used as-is, without the ``/2 + 0.5`` rescaling
    applied by the other guidance functions. A real FFT is applied along
    dim 1. The magnitude spectrum is weighted by a Gaussian window centred
    on ``target_freq``, and the function returns
    ``exp(freq_weight * mean(weighted magnitude))``.

    Args:
        x: Input tensor. Channel 0 is taken from dim 1 when
            ``x.shape[1] < x.shape[2]``, otherwise from dim 2.
        t: Timestep; unused.
        target_freq: Window centre in cycles per sample (``rfftfreq`` is
            called with ``d=1.0``, so bins run from 0 to 0.5).
        freq_weight: Multiplier on the mean weighted magnitude inside ``exp``.
        gradient_scale: Narrows the window; its standard deviation is
            ``0.1 / gradient_scale``.

    Returns:
        0-d tensor.
    """
    # Same layout test as the sibling functions; presumably (batch, channels,
    # length) when true — TODO confirm.
    channels_on_dim1 = x.shape[1] < x.shape[2]
    series = x[:, 0] if channels_on_dim1 else x[:, :, 0]

    spectrum = torch.fft.rfft(series, dim=1)
    bins = torch.fft.rfftfreq(series.shape[1], d=1.0)

    bandwidth = 0.1 / gradient_scale
    window = torch.exp(-((bins - target_freq) ** 2) / (2 * bandwidth ** 2))
    # rfftfreq builds the bins on the default device; move them next to x.
    window = window.to(x.device)

    weighted = torch.abs(spectrum) * window.unsqueeze(0)
    return torch.exp(freq_weight * weighted.mean())
def get_time_dependent_weights(t, num_timesteps):
    """Return an exponentially decaying control-signal weight for timestep ``t``.

    Computes ``exp(-5 * t / num_timesteps)``. The weight is 1 at ``t == 0``
    and ``exp(-5)`` (about 0.0067) at ``t == num_timesteps``, so smaller
    timestep values receive larger weights. The original note described
    these as the "earlier" timesteps.

    Args:
        t: Timestep tensor of any shape; cast to float before dividing.
        num_timesteps: Normalising total number of timesteps.

    Returns:
        Float tensor with the same shape as ``t``.
    """
    # Fraction of the schedule; lies in [0, 1] when 0 <= t <= num_timesteps.
    fraction = t.float() / num_timesteps
    # Larger weight for smaller t.
    return torch.exp(fraction * -5)