import torch


def _extract_signal(x: torch.Tensor) -> torch.Tensor:
    """Return channel 0 of ``x``, guessing the channel axis from the shape.

    If ``x.shape[1] < x.shape[2]`` the tensor is treated as channel-first and
    ``x[:, 0]`` is returned; otherwise it is treated as channel-last and
    ``x[:, :, 0]`` is returned. For a 3-D input the result is 2-D with the
    sequence along dim 1.

    NOTE(review): this heuristic picks the wrong axis when the channel count is
    >= the sequence length — confirm the expected layout with callers.
    """
    if x.shape[1] < x.shape[2]:
        return x[:, 0]
    return x[:, :, 0]


def sum_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_sum: torch.Tensor,
    sigma: float = 1.0,
    gradient_scale: float = 1.0,
    segments: tuple = None,
):
    """Gaussian log-likelihood that the rescaled signal sums to ``target_sum``.

    Channel 0 (``x[:, :, 0]``) is rescaled with ``v / 2 + 0.5`` and summed along
    dim 1 — over the whole length, or, when ``segments`` is given, over each
    ``[start_idx, end_idx)`` slice (overlapping slices are counted once per
    segment they appear in).

    Args:
        x: Sample tensor; indexed as ``x[:, :, 0]``, so it must be at least 3-D.
        t: Timestep. Unused; kept so all guidance functions share a signature.
        target_sum: Desired sum; must broadcast against the per-sample sum
            (``(batch,)`` for a 3-D ``x``).
        sigma: Standard deviation of the Gaussian; ``0`` falls back to ``1``.
        gradient_scale: Unused; kept for API compatibility.
        segments: Optional iterable of ``(start_idx, end_idx)`` pairs.
            NOTE(review): the original segment code never ran (it always raised
            IndexError); "sum over the segments" is the assumed intent.

    Returns:
        Scalar tensor: the mean Gaussian log-probability.
    """
    signal = x[:, :, 0] / 2 + 0.5

    if segments:
        # Slice the un-reduced signal; the old code sliced the already-summed
        # 1-D tensor with three indices and added the total to itself.
        current_sum = sum(
            signal[:, start_idx:end_idx].sum(dim=1)
            for start_idx, end_idx in segments
        )
    else:
        current_sum = signal.sum(dim=1)

    # sigma == 0 would divide by zero below, so fall back to unit variance.
    pred_std = torch.ones_like(current_sum) * (sigma if sigma != 0 else 1.0)

    log_prob = -0.5 * torch.log(2 * torch.pi * pred_std**2) - \
        (target_sum - current_sum)**2 / (2 * pred_std**2)
    return log_prob.mean()


def peak_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    peak_points: list,
    window_size: int = 5,
    alpha_1: float = 1.2,
    sigma: float = 1.0,
    gradient_scale: float = 1.0,  # gradient scaling factor (currently unused)
):
    """Score how strongly each requested point stands out as a peak.

    For each index in ``peak_points``, the mean of its neighbours inside a
    window of ``window_size`` (clipped at the signal edges, excluding the
    point itself) is scaled by ``alpha_1`` and compared with the point's value.
    Each point contributes ``-diff**2 / (2 * sigma**2)``.

    Args:
        x: Sample tensor; channel 0 is selected by ``_extract_signal`` and
            rescaled with ``v / 2 + 0.5``.
        t: Timestep. Unused.
        peak_points: Indices along the sequence axis.
        window_size: Full window width; ``window_size // 2`` points are taken
            on each side.
        alpha_1: Multiplier applied to the neighbour mean.
        sigma: Standard deviation of the penalty.
        gradient_scale: Unused; kept for API compatibility.

    Returns:
        Scalar tensor (``0`` when there is nothing to score).

    Raises:
        IndexError: if a peak index lies outside ``[0, length)``.
    """
    signal = _extract_signal(x) / 2 + 0.5
    length = signal.shape[1]
    half_window = window_size // 2
    # A tensor accumulator so an empty peak list still returns a tensor.
    log_prob = signal.new_zeros(())

    for x_coord in peak_points:
        if not 0 <= x_coord < length:
            raise IndexError(
                f"peak index {x_coord} out of range for signal length {length}"
            )
        start_idx = max(0, x_coord - half_window)
        end_idx = min(length, x_coord + half_window + 1)
        n_neighbours = end_idx - start_idx - 1
        if n_neighbours <= 0:
            # No neighbours in the window: the local mean would be 0/0 = NaN.
            continue
        peak_value = signal[:, x_coord]
        local_mean = (signal[:, start_idx:end_idx].sum(dim=1) - peak_value) / n_neighbours
        # NOTE(review): this averages over the batch before squaring, so
        # opposite-signed per-sample differences cancel — confirm intended.
        local_diff = (local_mean * alpha_1 - peak_value).mean()
        log_prob = log_prob - (local_diff**2) / (2 * sigma**2)

    return log_prob


def bar_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    bar_regions: list,
    sigma: float = 1.0,
    gradient_scale: float = 1.0
):
    """Reward regions whose mean value is close to a target level.

    For every ``(start_idx, end_idx, target_value)`` in ``bar_regions``, the
    rescaled signal is averaged over ``[start_idx, end_idx)``. The function
    adds the batch mean of the unnormalised Gaussian kernel
    ``exp(-0.5 * (mean - target)**2 / sigma**2)`` for that region.

    Note that, despite the accumulator's name, the result is a sum of kernel
    values, not a log-probability. An empty region slice yields NaN.

    Args:
        x: Sample tensor; channel 0 is selected by ``_extract_signal`` and
            rescaled with ``v / 2 + 0.5``.
        t: Timestep. Unused.
        bar_regions: Iterable of ``(start_idx, end_idx, target_value)`` triples.
        sigma: Kernel width; non-positive values fall back to ``1``.
        gradient_scale: Unused; kept for API compatibility.

    Returns:
        Scalar tensor (``0`` when ``bar_regions`` is empty).
    """
    signal = _extract_signal(x) / 2 + 0.5
    scale = sigma if sigma > 0 else 1.0
    log_prob = signal.new_zeros(())

    for start_idx, end_idx, target_value in bar_regions:
        region_mean = signal[:, start_idx:end_idx].mean(dim=1)
        sigma_t = torch.ones_like(region_mean) * scale
        log_prob = log_prob + torch.exp(
            -0.5 * ((region_mean - target_value)**2) / (sigma_t**2)
        ).mean()

    return log_prob


def frequency_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_freq: float,
    freq_weight: float = 1.0,
    gradient_scale: float = 1.0
):
    """Reward spectral energy near ``target_freq``.

    The real FFT of channel 0 (not rescaled, unlike the other guidance terms)
    is weighted by a Gaussian window centred on ``target_freq``. The window
    has width ``0.1 / gradient_scale`` in cycles per sample. The result is
    ``exp(freq_weight * mean(weighted magnitude))``.

    Args:
        x: Sample tensor; channel 0 is selected by ``_extract_signal``.
        t: Timestep. Unused.
        target_freq: Centre frequency in cycles per sample (``d=1.0``).
        freq_weight: Multiplier inside the exponential.
        gradient_scale: Narrows the window as it grows; must be non-zero.

    Returns:
        Scalar tensor.
    """
    signal = _extract_signal(x)
    fft_signal = torch.fft.rfft(signal, dim=1)
    # Build the frequency grid directly on the target device (no host copy).
    freqs = torch.fft.rfftfreq(signal.shape[1], d=1.0, device=x.device)
    bandwidth = 0.1 / gradient_scale
    freq_window = torch.exp(-((freqs - target_freq)**2) / (2 * bandwidth**2))
    magnitude = torch.abs(fft_signal) * freq_window[None, :]
    return torch.exp(freq_weight * magnitude.mean())


def get_time_dependent_weights(t, num_timesteps):
    """Return guidance weights ``exp(-5 * t / num_timesteps)``.

    Smaller timesteps get larger weights: the weight is ``1`` at ``t == 0`` and
    ``exp(-5)`` at ``t == num_timesteps``.

    Args:
        t: Integer or float tensor of timesteps.
        num_timesteps: Total number of timesteps used for normalisation.

    Returns:
        Float tensor with the same shape as ``t``.
    """
    progress = t.float() / num_timesteps
    return torch.exp(-5 * progress)