Spaces:
Paused
Paused
File size: 2,202 Bytes
6497501 ed25868 6497501 ed25868 6497501 ed25868 6497501 ed25868 6497501 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
from functools import partial
import torch
from ...util import default, instantiate_from_config
class VanillaCFG:
    """
    Classifier-free guidance evaluated in one batched forward pass:
    the unconditional and conditional branches share a doubled batch.
    """

    def __init__(self, scale, dyn_thresh_config=None):
        # Constant guidance scale: the schedule ignores sigma entirely.
        def scale_schedule(scale, sigma):
            return scale

        self.scale_schedule = partial(scale_schedule, scale)
        # Defaults to no dynamic thresholding unless a config is supplied.
        self.dyn_thresh = instantiate_from_config(
            default(
                dyn_thresh_config,
                {
                    "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
                },
            )
        )

    def __call__(self, x, sigma):
        # Batch layout from prepare_inputs: first half unconditional,
        # second half conditional.
        x_u, x_c = x.chunk(2)
        return self.dyn_thresh(x_u, x_c, self.scale_schedule(sigma))

    def prepare_inputs(self, x, s, c, uc):
        c_out = {}
        for key, value in c.items():
            if key in ("vector", "crossattn", "add_crossattn", "concat"):
                # Stack unconditional ahead of conditional along the batch dim,
                # mirroring the chunk(2) split in __call__.
                c_out[key] = torch.cat((uc[key], value), 0)
            else:
                assert value == uc[key]
                c_out[key] = value
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
class DualCFG:
    """
    Guidance with two unconditional branches, combined by DualThresholding.
    The batch is tripled: [uncond_1 | uncond_2 | cond].
    """

    def __init__(self, scale):
        self.scale = scale
        self.dyn_thresh = instantiate_from_config(
            {
                "target": "sgm.modules.diffusionmodules.sampling_utils.DualThresholding"
            },
        )

    def __call__(self, x, sigma):
        # Split in the same order prepare_inputs concatenated.
        first_u, second_u, cond = x.chunk(3)
        return self.dyn_thresh(first_u, second_u, cond, self.scale)

    def prepare_inputs(self, x, s, c, uc_1, uc_2):
        c_out = {}
        for key in c:
            if key in ("vector", "crossattn", "concat", "add_crossattn"):
                c_out[key] = torch.cat((uc_1[key], uc_2[key], c[key]), 0)
            else:
                # NOTE(review): only uc_1 is compared against c here; uc_2 is
                # assumed to agree on non-tensor keys — confirm with callers.
                assert c[key] == uc_1[key]
                c_out[key] = c[key]
        return torch.cat([x] * 3), torch.cat([s] * 3), c_out
class IdentityGuider:
    """No-op guider: passes the model output and conditioning through unchanged."""

    def __call__(self, x, sigma):
        # No guidance applied; sigma is accepted for interface parity.
        return x

    def prepare_inputs(self, x, s, c, uc):
        # Shallow-copy the conditioning dict; uc is intentionally ignored.
        return x, s, {key: value for key, value in c.items()}
|