File size: 3,514 Bytes
2875fe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch

def sum_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_sum: torch.Tensor,
    sigma: float = 1.0,
    gradient_scale: float = 1.0,
    segments: tuple = None
):
    """Gaussian log-likelihood of a target sum over the first feature of ``x``.

    The first feature ``x[:, :, 0]`` is rescaled with ``v / 2 + 0.5``
    (presumably mapping data from [-1, 1] to [0, 1] -- confirm against the
    caller) and summed along dim 1. The summed value is scored against
    ``target_sum`` under a normal distribution with standard deviation
    ``sigma``.

    Args:
        x: Tensor indexed as ``x[:, :, 0]``, i.e. at least 3-D.
        t: Unused by this function; kept for a uniform guidance signature.
        target_sum: Desired sum; must broadcast against a per-sample sum.
        sigma: Standard deviation of the Gaussian. ``0`` is treated as ``1``.
        gradient_scale: Unused by this function.
        segments: Optional iterable of ``(start_idx, end_idx)`` half-open
            index ranges along dim 1. When given, only those ranges
            contribute to the sum; otherwise the whole of dim 1 is summed.

    Returns:
        Scalar tensor: the log-density averaged over the batch.
    """
    signal = x[:, :, 0] / 2 + 0.5

    if segments:
        # Slice the per-position signal, not the already-reduced sum, so that
        # each segment contributes the sum of its own positions.
        current_sum = sum(
            signal[:, start_idx:end_idx].sum(dim=1)
            for start_idx, end_idx in segments
        )
    else:
        current_sum = signal.sum(dim=1)

    # sigma == 0 would divide by zero below, so fall back to unit std.
    if sigma == 0:
        pred_std = torch.ones_like(current_sum)
    else:
        pred_std = torch.ones_like(current_sum) * sigma

    log_prob = -0.5 * torch.log(2 * torch.pi * pred_std**2) - \
        (target_sum - current_sum)**2 / (2 * pred_std**2)
    return log_prob.mean()


def peak_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    peak_points: list,
    window_size: int = 5,
    alpha_1: float = 1.2,
    sigma: float = 1.0,
    gradient_scale: float = 1.0  # gradient scaling parameter (not used here)
):
    """Score how well the signal matches ``alpha_1`` times its local mean.

    For every index in ``peak_points`` the mean of the neighbouring samples
    inside a window of ``window_size`` (centre excluded) is computed. The
    penalty is the squared, batch-averaged difference between
    ``alpha_1 * neighbour_mean`` and the centre value, scaled by
    ``1 / (2 * sigma**2)``.

    Args:
        x: 3-D tensor. If ``x.shape[1] < x.shape[2]`` the signal is taken as
            ``x[:, 0]``, otherwise as ``x[:, :, 0]``.
        t: Unused by this function.
        peak_points: Indices along the signal axis to constrain.
        window_size: Total window width; ``window_size // 2`` samples are
            taken on each side, clipped to the signal bounds.
        alpha_1: Multiplier applied to the neighbour mean.
        sigma: Scale of the quadratic penalty.
        gradient_scale: Unused by this function.

    Returns:
        0-d tensor with the summed negative penalties. Zero when
        ``peak_points`` is empty.
    """
    if x.shape[1] < x.shape[2]:
        signal = x[:, 0]
    else:
        signal = x[:, :, 0]

    signal = signal / 2 + 0.5

    # Start from a tensor so an empty ``peak_points`` still returns a tensor.
    log_prob = signal.new_zeros(())
    half_window = window_size // 2
    length = signal.shape[1]

    for x_coord in peak_points:
        # Index the centre first so an out-of-range peak still raises.
        center = signal[:, x_coord]
        start_idx = max(0, x_coord - half_window)
        end_idx = min(length, x_coord + half_window + 1)
        neighbour_count = end_idx - start_idx - 1
        if neighbour_count == 0:
            # No neighbours: the local mean would be 0 / 0 (NaN), so skip.
            continue
        local_mean = (signal[:, start_idx:end_idx].sum(dim=1) - center) / neighbour_count
        # The difference is averaged over the batch before being squared.
        local_diff = (local_mean * alpha_1 - center).mean()
        log_prob = log_prob - (local_diff**2) / (2 * sigma**2)
    return log_prob


def bar_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    bar_regions: list,
    sigma: float = 1.0,
    gradient_scale: float = 1.0
):
    """Sum Gaussian-kernel scores between region means and target values.

    Args:
        x: 3-D tensor. If ``x.shape[1] < x.shape[2]`` the signal is taken as
            ``x[:, 0]``, otherwise as ``x[:, :, 0]``. The signal is rescaled
            with ``v / 2 + 0.5`` before use.
        t: Unused by this function.
        bar_regions: Iterable of ``(start_idx, end_idx, target_value)``.
        sigma: Kernel width; values ``<= 0`` are replaced by ``1``.
        gradient_scale: Unused by this function.

    Returns:
        Sum over regions of the batch-averaged
        ``exp(-0.5 * (mean - target)**2 / sigma**2)``. Note this is a kernel
        value, not a logarithm. It is the int ``0`` when ``bar_regions`` is
        empty.
    """
    channel_first = x.shape[1] < x.shape[2]
    series = x[:, 0] if channel_first else x[:, :, 0]
    series = series / 2 + 0.5

    score = 0
    for lo, hi, level in bar_regions:
        mean_in_region = series[:, lo:hi].mean(dim=1)
        if sigma > 0:
            spread = torch.ones_like(mean_in_region) * sigma
        else:
            spread = torch.ones_like(mean_in_region)
        exponent = -0.5 * ((mean_in_region - level)**2) / (spread**2)
        score += torch.exp(exponent).mean()

    return score


def frequency_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_freq: float,
    freq_weight: float = 1.0,
    gradient_scale: float = 1.0
):
    """Score spectral magnitude near ``target_freq``.

    The real FFT of the signal is weighted by a Gaussian window centred on
    ``target_freq`` with standard deviation ``0.1 / gradient_scale``, in the
    frequency units of ``torch.fft.rfftfreq`` with ``d=1.0``.

    Args:
        x: 3-D tensor. If ``x.shape[1] < x.shape[2]`` the signal is taken as
            ``x[:, 0]``, otherwise as ``x[:, :, 0]``.
        t: Unused by this function.
        target_freq: Centre of the Gaussian frequency window.
        freq_weight: Multiplier applied to the mean weighted magnitude.
        gradient_scale: Narrows the window; the width is ``0.1 / gradient_scale``.

    Returns:
        Scalar tensor ``exp(freq_weight * mean(|rfft(signal)| * window))``.
    """
    if x.shape[1] < x.shape[2]:
        series = x[:, 0]
    else:
        series = x[:, :, 0]

    n_samples = series.shape[1]
    spectrum = torch.fft.rfft(series, dim=1)
    bin_freqs = torch.fft.rfftfreq(n_samples, d=1.0)

    # Gaussian window over the FFT bins, moved to the input's device.
    window = torch.exp(-((bin_freqs - target_freq)**2) / (2 * (0.1/gradient_scale)**2))
    window = window.to(x.device)

    weighted = torch.abs(spectrum) * window[None, :]
    score = freq_weight * weighted.mean()
    return torch.exp(score)


def get_time_dependent_weights(t, num_timesteps):
    """
    Compute a timestep-dependent weight for the control signal.

    The weight is ``exp(-5 * t / num_timesteps)``, so smaller timesteps
    receive larger weights (1.0 at ``t == 0``).
    """
    fraction_done = t.float() / num_timesteps
    # Exponential decay: the weight shrinks as t approaches num_timesteps.
    return torch.exp(-5 * fraction_done)