import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from ...util import append_dims, default

logpy = logging.getLogger(__name__)


class Guider(ABC):
    @abstractmethod
    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        pass

    @abstractmethod
    def prepare_inputs(
        self, x: torch.Tensor, s: float, c: Dict, uc: Dict
    ) -> Tuple[torch.Tensor, float, Dict]:
        pass


class VanillaCFG(Guider):
    def __init__(
        self, scale: float, low_sigma: float = 0.0, high_sigma: float = float("inf")
    ):
        self.scale = scale
        self.low_sigma = low_sigma
        self.high_sigma = high_sigma

    def set_scale(self, scale: float):
        self.scale = scale

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        x_pred = x_u + self.scale * (x_c - x_u)
        return x_pred

    def prepare_inputs(self, x, s, c, uc):
        c_out = dict()
        for k in c:
            if k in [
                "vector",
                "crossattn",
                "concat",
                "audio_emb",
                "image_embeds",
                "landmarks",
                "masks",
                "gt",
                "valence",
                "arousal",
            ]:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            elif k == "reference":
                c_out["reference"] = []
                for i in range(len(c[k])):
                    c_out["reference"].append(torch.cat((uc[k][i], c[k][i]), 0))
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
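

# Illustrative usage sketch, assuming a `denoiser` callable and tensors/dicts `x`,
# `sigma`, `c`, `uc` coming from the surrounding sampling loop (none of these names are
# defined in this module):
#
#   guider = VanillaCFG(scale=7.5)
#   x_in, s_in, c_in = guider.prepare_inputs(x, sigma, c, uc)  # batch layout: [uc | c]
#   model_out = denoiser(x_in, s_in, c_in)                     # 2 * B predictions
#   denoised = guider(model_out, sigma)                        # x_u + scale * (x_c - x_u)
#
# Note that x_u + scale * (x_c - x_u) == (1 - scale) * x_u + scale * x_c, the standard
# classifier-free guidance combination.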


class VanillaSTG(Guider):
    def __init__(
        self,
        scale_spatial: float,
        scale_temporal: float,
        low_sigma: float = 0.0,
        high_sigma: float = float("inf"),
        layer_skip: int = 8,
    ):
        self.scale_spatial = scale_spatial
        self.scale_temporal = scale_temporal
        self.low_sigma = low_sigma
        self.high_sigma = high_sigma
        self.layer_skip = layer_skip

    def set_scale(self, scale_spatial: float, scale_temporal: float):
        self.scale_spatial = scale_spatial
        self.scale_temporal = scale_temporal

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        x_c, x_spatial, x_temporal = x.chunk(3)
        x_pred = (
            x_c
            + self.scale_spatial * (x_c - x_spatial)
            + self.scale_temporal * (x_c - x_temporal)
        )
        return x_pred

    def prepare_inputs(self, x, s, c, uc):
        c_out = dict()
        for k in c:
            if k in [
                "vector",
                "crossattn",
                "concat",
                "audio_emb",
                "image_embeds",
                "landmarks",
                "masks",
                "gt",
                "valence",
                "arousal",
            ]:
                c_out[k] = torch.cat((c[k], c[k], c[k]), 0)
            elif k == "reference":
                c_out["reference"] = []
                for i in range(len(c[k])):
                    c_out["reference"].append(torch.cat((c[k][i], c[k][i], c[k][i]), 0))
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        c_out["skip_spatial_attention_at"] = [None, self.layer_skip, None]
        c_out["skip_temporal_attention_at"] = [None, None, self.layer_skip]
        return torch.cat([x] * 3), torch.cat([s] * 3), c_out
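

# Note on the STG batch layout: every branch keeps the full conditioning; the guidance
# signal comes from perturbing the network itself. Branch 0 is the clean conditional
# pass, branch 1 runs with spatial attention skipped at `layer_skip`, and branch 2 with
# temporal attention skipped at `layer_skip` (via the skip_*_attention_at entries,
# assuming the backbone consumes them). __call__ then pushes the prediction away from
# both perturbed branches.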


class MomentumBuffer:
    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average
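

# The update rule above is running_average <- update_value + momentum * running_average,
# an exponentially weighted sum of past guidance differences. With a negative momentum
# (e.g. the -0.75 default used by the APG-style guiders below), consecutive updates
# partially cancel instead of reinforcing each other.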


def project(v0: torch.Tensor, v1: torch.Tensor):
    dtype = v0.dtype
    v0, v1 = v0.double(), v1.double()
    v1 = F.normalize(v1, dim=[-1, -2, -3])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel.to(dtype), v0_orthogonal.to(dtype)


def normalized_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer: MomentumBuffer = None,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
):
    diff = pred_cond - pred_uncond
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average
    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor
    diff_parallel, diff_orthogonal = project(diff, pred_cond)
    normalized_update = diff_orthogonal + eta * diff_parallel
    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
    return pred_guided
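

# Illustrative sketch of normalized_guidance for a single step, assuming predictions of
# shape (B, C, H, W) so that dims (-1, -2, -3) cover one sample:
#
#   buf = MomentumBuffer(momentum=-0.75)
#   pred_guided = normalized_guidance(
#       pred_cond, pred_uncond, guidance_scale=7.5,
#       momentum_buffer=buf, eta=0.0, norm_threshold=2.5,
#   )
#
# The guidance direction is the momentum-averaged, norm-clipped difference, split into
# components parallel and orthogonal to pred_cond; eta < 1 down-weights the parallel
# component, which limits over-saturation at large guidance scales.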


class APGGuider(VanillaCFG):
    def __init__(
        self,
        scale: float,
        momentum: float = -0.75,
        eta: float = 0.0,
        norm_threshold: float = 2.5,
    ):
        super().__init__(scale)
        self.momentum_buffer = MomentumBuffer(momentum)
        self.eta = eta
        self.norm_threshold = norm_threshold

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        return normalized_guidance(
            x_c, x_u, self.scale, self.momentum_buffer, self.eta, self.norm_threshold
        )


class VanillaCFGplusplus(VanillaCFG):
    def __call__(
        self, x: torch.Tensor, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x_u, x_c = x.chunk(2)
        x_pred = x_u + self.scale * (x_c - x_u)
        # Also return the unconditional prediction for CFG++-style samplers.
        return x_pred, x_u


class KarrasGuider(VanillaCFG):
    def prepare_inputs(self, x, s, c, uc):
        c_out = dict()
        for k in c:
            if k in [
                "vector",
                "crossattn",
                "concat",
                "audio_emb",
                "image_embeds",
                "landmarks",
                "valence",
                "arousal",
            ]:
                c_out[k] = torch.cat((c[k], c[k]), 0)
            elif k == "reference":
                c_out["reference"] = []
                for i in range(len(c[k])):
                    c_out["reference"].append(torch.cat((c[k][i], c[k][i]), 0))
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out


class MultipleCondVanilla(Guider):
    def __init__(self, scales, condition_names) -> None:
        assert len(scales) == len(condition_names)
        self.scales = scales
        # self.condition_names = condition_names
        self.n_conditions = len(scales)
        self.map_cond_name = {
            "audio_emb": "audio_emb",
            "cond_frames_without_noise": "crossattn",
            "cond_frames": "concat",
        }
        self.condition_names = [
            self.map_cond_name.get(cond_name, cond_name)
            for cond_name in condition_names
        ]
        print("Condition names: ", self.condition_names)

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        outs = x.chunk(self.n_conditions + 1)
        x_full_cond = outs[0]
        x_pred = (1 + sum(self.scales)) * x_full_cond
        for i, scale in enumerate(self.scales):
            x_pred -= scale * outs[i + 1]
        return x_pred
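
    # With chunks [x_full, x_drop_1, ..., x_drop_n], where x_drop_i is the prediction
    # with condition i replaced by its unconditional counterpart (see prepare_inputs),
    # the combination above expands to
    #   x_pred = (1 + sum_i s_i) * x_full - sum_i s_i * x_drop_i
    #          = x_full + sum_i s_i * (x_full - x_drop_i),
    # i.e. an independent CFG term per named condition.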

    def prepare_inputs(self, x, s, c, uc):
        c_out = dict()
        # The first element is the full condition
        for k in c:
            if k in [
                "vector",
                "crossattn",
                "concat",
                "audio_emb",
                "image_embeds",
                "landmarks",
                "masks",
                "gt",
            ]:
                c_out[k] = c[k]
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        # The rest are the conditions removed from the full condition
        for cond_name in self.condition_names:
            if not isinstance(cond_name, list):
                cond_name = [cond_name]
            for k in c:
                if k in [
                    "vector",
                    "crossattn",
                    "concat",
                    "audio_emb",
                    "image_embeds",
                    "landmarks",
                    "masks",
                    "gt",
                ]:
                    c_out[k] = torch.cat(
                        (c_out[k], uc[k] if k in cond_name else c[k]), 0
                    )
        return (
            torch.cat([x] * (self.n_conditions + 1)),
            torch.cat([s] * (self.n_conditions + 1)),
            c_out,
        )


class AudioRefMultiCondGuider(MultipleCondVanilla):
    def __init__(
        self,
        audio_ratio: float = 5.0,
        ref_ratio: float = 3.0,
        use_normalized: bool = False,
        momentum: float = -0.75,
        eta: float = 0.0,
        norm_threshold: float = 2.5,
    ):
        super().__init__(
            scales=[audio_ratio, ref_ratio], condition_names=["audio_emb", "concat"]
        )
        self.audio_ratio = audio_ratio
        self.ref_ratio = ref_ratio
        self.use_normalized = use_normalized
        print(f"Use normalized: {self.use_normalized}")
        self.momentum_buffer = MomentumBuffer(momentum)
        self.eta = eta
        self.norm_threshold = norm_threshold
        self.momentum_buffer_audio = MomentumBuffer(momentum)
        self.momentum_buffer_ref = MomentumBuffer(momentum)

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        e_uc, e_ref, e_audio_ref = x.chunk(3)
        if self.use_normalized:
            # Normalized guidance version
            # Compute diff for audio guidance
            diff_audio = e_audio_ref - e_uc
            if self.momentum_buffer_audio is not None:
                self.momentum_buffer_audio.update(diff_audio)
                diff_audio = self.momentum_buffer_audio.running_average
            if self.norm_threshold > 0:
                ones = torch.ones_like(diff_audio)
                diff_norm = diff_audio.norm(p=2, dim=[-1, -2, -3], keepdim=True)
                scale_factor = torch.minimum(ones, self.norm_threshold / diff_norm)
                diff_audio = diff_audio * scale_factor
            diff_audio_parallel, diff_audio_orthogonal = project(
                diff_audio, e_audio_ref
            )
            normalized_update_audio = (
                diff_audio_orthogonal + self.eta * diff_audio_parallel
            )
            guidance_audio = (self.audio_ratio - 1) * normalized_update_audio

            # Compute diff for ref guidance
            diff_ref = e_ref - e_uc
            if self.momentum_buffer_ref is not None:
                self.momentum_buffer_ref.update(diff_ref)
                diff_ref = self.momentum_buffer_ref.running_average
            if self.norm_threshold > 0:
                ones = torch.ones_like(diff_ref)
                diff_norm = diff_ref.norm(p=2, dim=[-1, -2, -3], keepdim=True)
                scale_factor = torch.minimum(ones, self.norm_threshold / diff_norm)
                diff_ref = diff_ref * scale_factor
            diff_ref_parallel, diff_ref_orthogonal = project(diff_ref, e_ref)
            normalized_update_ref = diff_ref_orthogonal + self.eta * diff_ref_parallel
            guidance_ref = (self.ref_ratio - 1) * normalized_update_ref

            e_final = e_uc + guidance_audio + guidance_ref
        else:
            # Original version
            e_final = (
                self.audio_ratio * (e_audio_ref - e_ref)
                + self.ref_ratio * (e_ref - e_uc)
                + e_uc
            )
        return e_final

    def set_scale(self, scale: torch.Tensor):
        self.audio_ratio = float(scale[0])
        self.ref_ratio = float(scale[1])
        print(f"Audio ratio: {self.audio_ratio}, Ref ratio: {self.ref_ratio}")

    def prepare_inputs(self, x, s, c, uc):
        c_out = dict()

        # Conditioning for e_uc: no audio, no reference concat
        c_base = {k: v for k, v in c.items()}
        c_base["crossattn"] = uc["crossattn"]
        c_base["concat"] = uc["concat"]  # Remove ref concat

        # Conditioning for e_audio_ref: all conditions (audio + reference concat)
        c_audio_ref = {k: v for k, v in c.items()}

        # Conditioning for e_ref: no audio, with reference concat
        c_ref = {k: v for k, v in c.items()}
        c_ref["crossattn"] = uc["crossattn"]

        # Combine all conditions
        for k in c:
            if k in [
                "vector",
                "crossattn",
                "concat",
                "audio_emb",
                "image_embeds",
                "landmarks",
                "masks",
                "gt",
            ]:
                c_out[k] = torch.cat((c_base[k], c_ref[k], c_audio_ref[k]), 0)
            else:
                c_out[k] = c[k]
        return torch.cat([x] * 3), torch.cat([s] * 3), c_out
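

# Summary of AudioRefMultiCondGuider: prepare_inputs stacks three conditioning variants
# so the batch seen by __call__ is [no audio & no reference | reference only |
# audio + reference]. The default (non-normalized) combination is
#   e = e_uc + ref_ratio * (e_ref - e_uc) + audio_ratio * (e_audio_ref - e_ref),
# i.e. separate guidance strengths for the reference frames and the audio signal.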


class IdentityGuider(Guider):
    def __init__(self, *args, **kwargs):
        # self.num_frames = num_frames
        pass

    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        return x

    def prepare_inputs(
        self, x: torch.Tensor, s: float, c: Dict, uc: Dict
    ) -> Tuple[torch.Tensor, float, Dict]:
        c_out = dict()
        for k in c:
            c_out[k] = c[k]
        return x, s, c_out


class LinearPredictionGuider(Guider):
    def __init__(
        self,
        max_scale: float,
        num_frames: int,
        min_scale: float = 1.0,
        additional_cond_keys: Optional[Union[List[str], str]] = None,
        only_first: bool = False,
    ):
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.num_frames = num_frames
        self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
        self.only_first = only_first
        if only_first:
            self.scale = torch.ones_like(self.scale) * max_scale
            self.scale[:, 0] = min_scale

        additional_cond_keys = default(additional_cond_keys, [])
        if isinstance(additional_cond_keys, str):
            additional_cond_keys = [additional_cond_keys]
        self.additional_cond_keys = additional_cond_keys

    def set_scale(self, scale: torch.Tensor):
        self.min_scale = scale
        self.scale = torch.linspace(
            self.min_scale, self.max_scale, self.num_frames
        ).unsqueeze(0)
        if self.only_first:
            self.scale = torch.ones_like(self.scale) * self.max_scale
            self.scale[:, 0] = self.min_scale
        print(self.scale)

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        x_u, x_c = x.chunk(2)

        x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
        x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
        scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
        scale = append_dims(scale, x_u.ndim).to(x_u.device)
        return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")

    def prepare_inputs(
        self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        c_out = dict()
        for k in c:
            if (
                k
                in ["vector", "crossattn", "concat", "audio_emb", "masks", "gt"]
                + self.additional_cond_keys
            ):
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
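

# LinearPredictionGuider applies a per-frame guidance scale: prepare_inputs doubles the
# batch as [uc | c], and __call__ folds the (b t) batch back into frames and interpolates
# the scale linearly from min_scale (first frame) to max_scale (last frame), so later
# frames receive stronger guidance. With only_first=True, every frame uses max_scale
# except frame 0, which uses min_scale.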


class LinearPredictionGuiderPlus(LinearPredictionGuider):
    def __init__(
        self,
        max_scale: float,
        num_frames: int,
        min_scale: float = 1.0,
        additional_cond_keys: Optional[Union[List[str], str]] = None,
    ):
        super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)

    def __call__(
        self, x: torch.Tensor, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x_u, x_c = x.chunk(2)

        x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
        x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
        scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
        scale = append_dims(scale, x_u.ndim).to(x_u.device)
        # Also return the per-frame unconditional prediction alongside the guided one.
        return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ..."), x_u


class TrianglePredictionGuider(LinearPredictionGuider):
    def __init__(
        self,
        max_scale: float,
        num_frames: int,
        min_scale: float = 1.0,
        period: float | List[float] = 1.0,
        period_fusing: Literal["mean", "multiply", "max"] = "max",
        additional_cond_keys: Optional[Union[List[str], str]] = None,
    ):
        super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
        values = torch.linspace(0, 1, num_frames)
        # Constructs a triangle wave
        if isinstance(period, float):
            period = [period]

        scales = []
        for p in period:
            scales.append(self.triangle_wave(values, p))

        if period_fusing == "mean":
            scale = sum(scales) / len(period)
        elif period_fusing == "multiply":
            scale = torch.prod(torch.stack(scales), dim=0)
        elif period_fusing == "max":
            scale = torch.max(torch.stack(scales), dim=0).values
        self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0)

    def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:
        return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
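

# Worked example of the triangle schedule: with num_frames=5, period=1.0, min_scale=1.0
# and max_scale=3.0, values = [0, 0.25, 0.5, 0.75, 1.0], the triangle wave evaluates to
# [0, 0.5, 1.0, 0.5, 0], and self.scale becomes [1.0, 2.0, 3.0, 2.0, 1.0]: guidance is
# weakest at the clip boundaries and strongest in the middle frames.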