Spaces:
Sleeping
Sleeping
File size: 3,514 Bytes
2875fe6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import torch
def sum_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_sum: torch.Tensor,
    sigma: float = 1.0,
    gradient_scale: float = 1.0,
    segments: tuple = None
):
    """Score how closely the summed signal matches ``target_sum``.

    The first channel of ``x`` (``x[:, :, 0]``) is mapped through
    ``v / 2 + 0.5`` and summed along dim 1, either over the whole axis or
    over the given ``segments`` only. The sum is then scored with a Gaussian
    log-density centred on ``target_sum``.

    Args:
        x: Tensor indexed as ``x[:, :, 0]``; presumably laid out as
            (batch, time, channels) with values in [-1, 1] -- TODO confirm
            against the caller.
        t: Not used by this function.
        target_sum: Desired sum; must broadcast against a (batch,) tensor.
        sigma: Standard deviation of the Gaussian. ``0`` is treated as 1.
        gradient_scale: Not used by this function.
        segments: Optional iterable of ``(start_idx, end_idx)`` half-open
            ranges along dim 1. When non-empty, only these ranges contribute
            to the sum; otherwise the whole axis is summed.

    Returns:
        Scalar tensor holding the mean of the element-wise log-density.
    """
    signal = x[:, :, 0] / 2 + 0.5

    if segments:
        # Slice the per-timestep signal, not the already reduced per-batch
        # sum: indexing the 1-D sum with three indices raised IndexError.
        current_sum = signal.new_zeros(signal.shape[0])
        for start_idx, end_idx in segments:
            # Out-of-place add keeps the autograd graph free of in-place ops.
            current_sum = current_sum + signal[:, start_idx:end_idx].sum(dim=1)
    else:
        current_sum = signal.sum(dim=1)

    # sigma == 0 would make the density degenerate; fall back to unit std.
    std = sigma if sigma != 0 else 1.0
    pred_std = torch.ones_like(current_sum) * std

    log_prob = (
        -0.5 * torch.log(2 * torch.pi * pred_std**2)
        - (target_sum - current_sum) ** 2 / (2 * pred_std**2)
    )
    return log_prob.mean()
def peak_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    peak_points: list,
    window_size: int = 5,
    alpha_1: float = 1.2,
    sigma: float = 1.0,
    gradient_scale: float = 1.0  # gradient scaling parameter (not used here)
):
    """Score how well the signal forms peaks at the requested positions.

    For every index in ``peak_points`` the mean of the surrounding window
    (excluding the index itself) is scaled by ``alpha_1`` and compared with
    the value at the index. The squared batch-mean difference is accumulated
    as a negative penalty ``-(diff ** 2) / (2 * sigma ** 2)``.

    Args:
        x: 3-D tensor. If ``x.shape[1] < x.shape[2]`` the signal is taken as
            ``x[:, 0]``, otherwise as ``x[:, :, 0]``; values are mapped
            through ``v / 2 + 0.5``.
        t: Not used by this function.
        peak_points: Indices along the signal axis where a peak is wanted;
            presumably non-negative -- TODO confirm against the caller.
        window_size: Width of the window centred on each index.
        alpha_1: Factor applied to the neighbourhood mean before comparison.
        sigma: Scale of the penalty. ``0`` is treated as 1, matching how the
            other guidance functions in this file treat a zero sigma.
        gradient_scale: Not used by this function.

    Returns:
        Scalar tensor with the accumulated penalty; zero when
        ``peak_points`` is empty.
    """
    if x.shape[1] < x.shape[2]:
        signal = x[:, 0]
    else:
        signal = x[:, :, 0]
    signal = signal / 2 + 0.5

    # A zero sigma would divide by zero below; fall back to unit scale.
    std = sigma if sigma != 0 else 1.0
    half_window = window_size // 2

    # Start from a tensor so an empty ``peak_points`` still returns a tensor
    # (a plain int 0 has no ``.mean()``).
    log_prob = signal.new_zeros(())
    for x_coord in peak_points:
        start_idx = max(0, x_coord - half_window)
        end_idx = min(signal.shape[1], x_coord + half_window + 1)
        # Read the peak first so an out-of-range index still raises.
        peak_value = signal[:, x_coord]
        neighbours = end_idx - start_idx - 1
        if neighbours <= 0:
            # No neighbouring samples: the local mean is undefined and
            # dividing by zero would inject nan/inf into the score.
            continue
        window_sum = signal[:, start_idx:end_idx].sum(dim=1)
        local_mean = (window_sum - peak_value) / neighbours
        local_diff = (local_mean * alpha_1 - peak_value).mean()
        log_prob = log_prob - (local_diff**2) / (2 * std**2)
    return log_prob
def bar_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    bar_regions: list,
    sigma: float = 1.0,
    gradient_scale: float = 1.0
):
    """Score how closely region means match their target values.

    For each ``(start_idx, end_idx, target_value)`` in ``bar_regions`` the
    mean of ``signal[:, start_idx:end_idx]`` is compared with
    ``target_value`` through ``exp(-0.5 * (mean - target)^2 / sigma^2)``,
    averaged over the batch, and the per-region values are added up.

    NOTE(review): the accumulated quantity is a Gaussian kernel value in
    (0, 1] per region, not a log-probability. It is kept as is because
    callers may depend on the current scale.

    Args:
        x: 3-D tensor. If ``x.shape[1] < x.shape[2]`` the signal is taken as
            ``x[:, 0]``, otherwise as ``x[:, :, 0]``; values are mapped
            through ``v / 2 + 0.5``.
        t: Not used by this function.
        bar_regions: Iterable of ``(start_idx, end_idx, target_value)``
            triples; the index pair is a half-open slice on the signal axis.
        sigma: Kernel width. Non-positive values are treated as 1.
        gradient_scale: Not used by this function.

    Returns:
        Scalar tensor with the summed score; zero when ``bar_regions`` is
        empty or every region slice is empty.
    """
    if x.shape[1] < x.shape[2]:
        signal = x[:, 0]
    else:
        signal = x[:, :, 0]
    signal = signal / 2 + 0.5

    # Loop-invariant: non-positive sigma falls back to unit width.
    std = sigma if sigma > 0 else 1.0

    # Start from a tensor so the return type does not depend on whether any
    # region was processed.
    score = signal.new_zeros(())
    for start_idx, end_idx, target_value in bar_regions:
        region = signal[:, start_idx:end_idx]
        if region.shape[1] == 0:
            # The mean of an empty slice is nan and would poison the total.
            continue
        region_mean = region.mean(dim=1)
        sigma_t = torch.ones_like(region_mean) * std
        kernel = torch.exp(-0.5 * ((region_mean - target_value) ** 2) / (sigma_t**2))
        score = score + kernel.mean()
    return score
def frequency_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_freq: float,
    freq_weight: float = 1.0,
    gradient_scale: float = 1.0
):
    """Score the spectral magnitude of the signal near ``target_freq``.

    The real FFT of the signal is taken along dim 1, its magnitude is
    weighted by a Gaussian window centred on ``target_freq`` (width
    ``0.1 / gradient_scale``, in the units of ``torch.fft.rfftfreq`` with
    ``d=1.0``), and ``exp(freq_weight * mean(weighted magnitude))`` is
    returned.

    Args:
        x: 3-D tensor. If ``x.shape[1] < x.shape[2]`` the signal is taken as
            ``x[:, 0]``, otherwise as ``x[:, :, 0]``.
        t: Not used by this function.
        target_freq: Centre of the Gaussian frequency window.
        freq_weight: Multiplier applied to the mean weighted magnitude.
        gradient_scale: Divides the window width; larger values narrow it.

    Returns:
        Scalar tensor with the exponentiated, weighted mean magnitude.
    """
    channel_first = x.shape[1] < x.shape[2]
    series = x[:, 0] if channel_first else x[:, :, 0]

    spectrum = torch.fft.rfft(series, dim=1)
    bins = torch.fft.rfftfreq(series.shape[1], d=1.0)

    # Gaussian window over the frequency bins, moved to the input's device
    # after it has been built.
    bandwidth = 0.1 / gradient_scale
    squared_offset = (bins - target_freq) ** 2
    window = torch.exp(-squared_offset / (2 * bandwidth**2)).to(x.device)

    weighted_magnitude = torch.abs(spectrum) * window[None, :]
    return torch.exp(freq_weight * weighted_magnitude.mean())
def get_time_dependent_weights(t, num_timesteps):
    """
    Compute a timestep-dependent weight for the control signals.

    The weight is ``exp(-5 * t / num_timesteps)``: it equals 1 at ``t = 0``
    and decays to ``exp(-5)`` at ``t = num_timesteps``, so smaller timestep
    values receive larger weights.
    """
    decay_rate = 5
    fraction = t.float() / num_timesteps
    # Smaller timestep values get the larger weights.
    return torch.exp(-decay_rate * fraction)
|