File size: 3,097 Bytes
5488167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71922e7
 
 
 
5488167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71922e7
5488167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch


class MomentumBuffer:
    """Accumulator that folds each new value into a momentum-weighted running value.

    After every call to ``update`` the stored ``running_average`` equals
    ``update_value + momentum * previous_running_average``.
    """

    def __init__(self, momentum: float = -0.75):
        # Starts at the scalar 0, so the first update contributes only the
        # incoming value (the carried-over term is momentum * 0).
        self.running_average = 0
        self.momentum = momentum

    def update(self, update_value: torch.Tensor):
        """Replace ``running_average`` with ``update_value`` plus the scaled old value."""
        carried_over = self.momentum * self.running_average
        self.running_average = update_value + carried_over


def project(
    v0: torch.Tensor,  # [B, C, H, W]
    v1: torch.Tensor,  # [B, C, H, W]
    dims=[-1, -2],
):
    """Split ``v0`` into components parallel and orthogonal to ``v1``.

    ``v1`` is normalized over ``dims`` and ``v0`` is projected onto it, with
    the inner product also reduced over ``dims``.

    Args:
        v0: Tensor to decompose.
        v1: Tensor giving the projection direction.
        dims: Dimensions over which normalization and the inner product run.

    Returns:
        Tuple ``(v0_parallel, v0_orthogonal)``, both cast back to the dtype
        ``v0`` had on entry.
    """
    original_dtype = v0.dtype

    # Work in float32 when v0 sits on an "mps" device and in float64 on any
    # other device (presumably because MPS lacks float64 support -- confirm).
    work_dtype = torch.float32 if v0.device.type == "mps" else torch.float64
    v0 = v0.to(work_dtype)
    v1 = v1.to(work_dtype)

    direction = torch.nn.functional.normalize(v1, dim=dims)
    coefficient = (v0 * direction).sum(dim=dims, keepdim=True)
    parallel = coefficient * direction
    orthogonal = v0 - parallel
    return parallel.to(original_dtype), orthogonal.to(original_dtype)


def apg_forward(
    pred_cond: torch.Tensor,  # [B, C, H, W]
    pred_uncond: torch.Tensor,  # [B, C, H, W]
    guidance_scale: float,
    momentum_buffer: MomentumBuffer = None,
    eta: float = 0.0,
    norm_threshold: float = 2.5,
    dims=[-1, -2],
):
    """Combine conditional and unconditional predictions via a projected update.

    The difference ``pred_cond - pred_uncond`` is optionally passed through
    ``momentum_buffer``, optionally rescaled so its L2 norm over ``dims`` does
    not exceed ``norm_threshold``, and then split into components parallel
    and orthogonal to ``pred_cond``. The orthogonal part plus ``eta`` times
    the parallel part is added to ``pred_cond`` with weight
    ``guidance_scale - 1``.

    Args:
        pred_cond: Conditional prediction.
        pred_uncond: Unconditional prediction.
        guidance_scale: Guidance strength; the update is weighted by
            ``guidance_scale - 1``.
        momentum_buffer: Optional buffer; when given it is updated with the
            difference and its ``running_average`` replaces the difference.
        eta: Weight of the parallel component in the update.
        norm_threshold: Upper bound on the difference's norm; values ``<= 0``
            disable the rescaling.
        dims: Dimensions used for the norm and for the projection.

    Returns:
        The guided prediction tensor.
    """
    guidance = pred_cond - pred_uncond

    if momentum_buffer is not None:
        momentum_buffer.update(guidance)
        guidance = momentum_buffer.running_average

    if norm_threshold > 0:
        guidance_norm = guidance.norm(p=2, dim=dims, keepdim=True)
        # Cap the factor at 1 so the difference is only ever shrunk.
        unit_cap = torch.ones_like(guidance)
        guidance = guidance * torch.minimum(unit_cap, norm_threshold / guidance_norm)

    parallel_part, orthogonal_part = project(guidance, pred_cond, dims)
    step = orthogonal_part + eta * parallel_part
    return pred_cond + (guidance_scale - 1) * step


def cfg_forward(cond_output, uncond_output, cfg_strength):
    """Move from ``uncond_output`` toward ``cond_output`` by ``cfg_strength``.

    Returns ``uncond_output + cfg_strength * (cond_output - uncond_output)``.
    """
    guidance_delta = cond_output - uncond_output
    return uncond_output + cfg_strength * guidance_delta


def cfg_double_condition_forward(
    cond_output,
    uncond_output,
    only_text_cond_output,
    guidance_scale_text,
    guidance_scale_lyric,
):
    """Blend three outputs using two guidance scales.

    The result is the weighted sum

        (1 - guidance_scale_text) * uncond_output
        + (guidance_scale_text - guidance_scale_lyric) * only_text_cond_output
        + guidance_scale_lyric * cond_output
    """
    uncond_weight = 1 - guidance_scale_text
    text_only_weight = guidance_scale_text - guidance_scale_lyric
    blended = uncond_weight * uncond_output + text_only_weight * only_text_cond_output
    return blended + guidance_scale_lyric * cond_output


def optimized_scale(positive_flat, negative_flat):
    """Return the per-row projection coefficient of positive onto negative.

    For each row this is ``<positive, negative> / (||negative||^2 + 1e-8)``,
    reduced over dimension 1 with that dimension kept (size 1) in the result.
    """
    # Row-wise dot product between the two inputs.
    inner = (positive_flat * negative_flat).sum(dim=1, keepdim=True)

    # Row-wise squared norm of the negative input; the epsilon keeps the
    # division finite when a row is all zeros.
    negative_energy = (negative_flat ** 2).sum(dim=1, keepdim=True) + 1e-8

    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    return inner / negative_energy


def cfg_zero_star(noise_pred_with_cond, noise_pred_uncond, guidance_scale, i, zero_steps=1, use_zero_init=True):
    """Guidance with a per-sample rescaled unconditional prediction.

    For steps ``i <= zero_steps`` (when ``use_zero_init`` is true) the result
    is the conditional prediction multiplied by zero. Otherwise a per-sample
    scale ``alpha`` is obtained from ``optimized_scale`` on the flattened
    predictions and the result is
    ``uncond * alpha + guidance_scale * (cond - uncond * alpha)``.

    Args:
        noise_pred_with_cond: Conditional prediction; dimension 0 is treated
            as the batch dimension.
        noise_pred_uncond: Unconditional prediction, same shape.
        guidance_scale: Weight applied to the rescaled difference.
        i: Current step index, compared against ``zero_steps``.
        zero_steps: Steps with ``i <= zero_steps`` return the zeroed output.
        use_zero_init: Enables the zeroed output for the early steps.

    Returns:
        Tensor with the same shape as ``noise_pred_with_cond``.
    """
    if use_zero_init and i <= zero_steps:
        # Multiply instead of allocating zeros so dtype, device and NaN/inf
        # propagation stay exactly as the input dictates. Returning early
        # also avoids computing a scale that would be thrown away.
        return noise_pred_with_cond * 0.

    bsz = noise_pred_with_cond.shape[0]
    # reshape (rather than view) also accepts non-contiguous inputs, e.g.
    # tensors produced by permute/transpose or slicing.
    positive_flat = noise_pred_with_cond.reshape(bsz, -1)
    negative_flat = noise_pred_uncond.reshape(bsz, -1)
    alpha = optimized_scale(positive_flat, negative_flat)

    # One scale per sample, broadcast over every non-batch dimension. Built
    # from the input's rank so it is not limited to 4-D predictions.
    trailing_ones = [1] * (noise_pred_with_cond.dim() - 1)
    alpha = alpha.reshape(bsz, *trailing_ones)

    scaled_uncond = noise_pred_uncond * alpha
    return scaled_uncond + guidance_scale * (noise_pred_with_cond - scaled_uncond)