File size: 2,202 Bytes
6497501
 
 
 
 
 
 
 
 
 
 
 
 
ed25868
 
6497501
 
 
 
 
 
 
 
 
 
 
ed25868
 
6497501
 
 
 
 
 
ed25868
6497501
 
 
 
 
 
 
ed25868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6497501
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from functools import partial

import torch

from ...util import default, instantiate_from_config


class VanillaCFG:
    """
    Implements parallelized classifier-free guidance (CFG).

    ``prepare_inputs`` stacks the unconditional and conditional inputs along
    dim 0 (unconditional half first) so the denoiser runs once on both
    branches. ``__call__`` splits the denoiser output back into those two
    halves and hands them, together with the guidance scale, to
    ``self.dyn_thresh``, which produces the final prediction.
    """

    # Conditioning entries that differ between the conditional and the
    # unconditional branch; these are concatenated along dim 0. Every other
    # entry is shared by both branches and passed through once.
    _BATCHED_KEYS = ("vector", "crossattn", "add_crossattn", "concat")

    def __init__(self, scale, dyn_thresh_config=None):
        """
        Args:
            scale: guidance scale handed to ``dyn_thresh`` on every call; the
                schedule below ignores ``sigma``, so it is constant over steps.
            dyn_thresh_config: optional config for ``instantiate_from_config``;
                when None, ``NoDynamicThresholding`` is instantiated.
        """
        scale_schedule = lambda scale, sigma: scale  # independent of step
        self.scale_schedule = partial(scale_schedule, scale)
        self.dyn_thresh = instantiate_from_config(
            default(
                dyn_thresh_config,
                {
                    "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
                },
            )
        )

    def __call__(self, x, sigma):
        """
        Combine the batched denoiser output into the guided prediction.

        Args:
            x: denoiser output for a batch built by ``prepare_inputs``; the
                first half along dim 0 is the unconditional branch, the second
                half the conditional one.
            sigma: current noise level, forwarded to ``scale_schedule``.

        Returns:
            The result of ``self.dyn_thresh(x_u, x_c, scale)``.
        """
        x_u, x_c = x.chunk(2)
        scale_value = self.scale_schedule(sigma)
        return self.dyn_thresh(x_u, x_c, scale_value)

    def prepare_inputs(self, x, s, c, uc):
        """
        Duplicate the inputs so both CFG branches run in one forward pass.

        Args:
            x: denoiser input; concatenated with itself along dim 0.
            s: noise levels; concatenated with itself along dim 0.
            c: conditional conditioning dict.
            uc: unconditional conditioning dict; must contain every key of ``c``.

        Returns:
            ``(x_in, s_in, c_out)``. Batched keys map to ``cat(uc[k], c[k])``;
            all other keys map to ``c[k]``.

        Raises:
            KeyError: if ``uc`` lacks a key present in ``c``.
            ValueError: if a shared (non-batched) entry differs between
                ``c`` and ``uc``.
        """
        c_out = {}
        for k in c:
            if k in self._BATCHED_KEYS:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                # Shared entries are only passed once, so both branches must
                # agree on them; otherwise the unconditional value is lost.
                if not self._values_equal(c[k], uc[k]):
                    raise ValueError(
                        f"conditioning entry {k!r} differs between c and uc"
                    )
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out

    @staticmethod
    def _values_equal(a, b):
        """Return whether ``a == b``, reducing element-wise tensor results with ``all()``."""
        result = a == b
        if torch.is_tensor(result):
            # bool() on a tensor with != 1 element raises, so reduce first.
            return bool(result.all())
        return bool(result)
    

class DualCFG:
    """
    Parallelized guidance with two unconditional branches.

    ``prepare_inputs`` stacks ``uc_1``, ``uc_2`` and ``c`` inputs along dim 0
    (in that order) so the denoiser runs once. ``__call__`` splits the output
    into three chunks and passes them with ``self.scale`` to a
    ``DualThresholding`` instance; the combination rule lives there.
    """

    # Conditioning entries that differ per branch and are concatenated along
    # dim 0; every other entry is shared and passed through once.
    _BATCHED_KEYS = ("vector", "crossattn", "concat", "add_crossattn")

    def __init__(self, scale):
        """
        Args:
            scale: guidance scale forwarded unchanged to ``dyn_thresh``.
        """
        self.scale = scale
        self.dyn_thresh = instantiate_from_config(
            {
                "target": "sgm.modules.diffusionmodules.sampling_utils.DualThresholding"
            },
        )

    def __call__(self, x, sigma):
        """
        Combine the three-way batched denoiser output.

        Args:
            x: denoiser output for a batch built by ``prepare_inputs``; split
                along dim 0 into (uc_1 branch, uc_2 branch, conditional branch).
            sigma: current noise level; unused here.

        Returns:
            The result of ``self.dyn_thresh(x_u_1, x_u_2, x_c, self.scale)``.
        """
        x_u_1, x_u_2, x_c = x.chunk(3)
        return self.dyn_thresh(x_u_1, x_u_2, x_c, self.scale)

    def prepare_inputs(self, x, s, c, uc_1, uc_2):
        """
        Triplicate the inputs so all three branches run in one forward pass.

        Args:
            x: denoiser input; concatenated three times along dim 0.
            s: noise levels; concatenated three times along dim 0.
            c: conditional conditioning dict.
            uc_1: first unconditional dict; must contain every key of ``c``.
            uc_2: second unconditional dict; must contain every batched key.

        Returns:
            ``(x_in, s_in, c_out)``. Batched keys map to
            ``cat(uc_1[k], uc_2[k], c[k])``; all other keys map to ``c[k]``.

        Raises:
            KeyError: if ``uc_1`` lacks a key of ``c``, or ``uc_2`` lacks a
                batched key.
            ValueError: if a shared (non-batched) entry differs between ``c``
                and ``uc_1``, or between ``c`` and ``uc_2`` when ``uc_2`` has it.
        """
        c_out = {}
        for k in c:
            if k in self._BATCHED_KEYS:
                c_out[k] = torch.cat((uc_1[k], uc_2[k], c[k]), 0)
            else:
                # Shared entries are passed once, so all branches must agree.
                if not self._values_equal(c[k], uc_1[k]):
                    raise ValueError(
                        f"conditioning entry {k!r} differs between c and uc_1"
                    )
                # uc_2 was previously never consulted for shared entries;
                # validate it only when it actually provides the key.
                if k in uc_2 and not self._values_equal(c[k], uc_2[k]):
                    raise ValueError(
                        f"conditioning entry {k!r} differs between c and uc_2"
                    )
                c_out[k] = c[k]
        return torch.cat([x] * 3), torch.cat([s] * 3), c_out

    @staticmethod
    def _values_equal(a, b):
        """Return whether ``a == b``, reducing element-wise tensor results with ``all()``."""
        result = a == b
        if torch.is_tensor(result):
            # bool() on a tensor with != 1 element raises, so reduce first.
            return bool(result.all())
        return bool(result)



class IdentityGuider:
    """
    Pass-through guider: the model sees the inputs unchanged and its output is
    returned unchanged.
    """

    def __call__(self, x, sigma):
        """Return the denoiser output ``x`` untouched; ``sigma`` is ignored."""
        return x

    def prepare_inputs(self, x, s, c, uc):
        """
        Return ``x`` and ``s`` as given, plus a shallow copy of ``c``.

        ``uc`` exists only to match the CFG guiders' signature and is unused.
        """
        conditioning = {key: c[key] for key in c}
        return x, s, conditioning