| """Audio reference conditioning item for IC-LoRA voice cloning.""" | |
| import torch | |
| from ltx_core.components.patchifiers import AudioPatchifier | |
| from ltx_core.conditioning.item import ConditioningItem | |
| from ltx_core.tools import AudioLatentTools | |
| from ltx_core.types import AudioLatentShape, LatentState | |

class AudioConditionByReferenceLatent(ConditioningItem):
    """Conditions audio generation on a reference audio latent for voice cloning.

    Mirrors VideoConditionByReferenceLatent but for audio:
    - Patchifies the reference latent [B, C, T, F] -> [B, ref_T, 128]
    - Computes 1D temporal positions via AudioPatchifier
    - Sets denoise_mask = 1.0 - strength (strength=1.0 -> mask=0 -> frozen)
    - Builds an ASYMMETRIC attention mask: target->ref=1 (attend),
      ref->target=0 (read-only)
    - APPENDS ref tokens to the END of the latent sequence (IC-LoRA pattern)
    - Uses OVERLAPPING positions (same coordinate space) so RoPE doesn't
      decay target->ref attention. The asymmetric mask provides the structural
      signal that ref tokens are conditioning, not reconstruction targets.

    Args:
        latent: Reference audio latent [B, C, T, F] (pre-VAE-encoded).
        strength: Conditioning strength. 1.0 = full (ref kept clean),
            0.0 = none (ref fully denoised). Default 1.0.
    """
    def __init__(self, latent: torch.Tensor, strength: float = 1.0):
        self.latent = latent
        self.strength = strength

    def apply_to(
        self,
        latent_state: LatentState,
        latent_tools: AudioLatentTools,
    ) -> LatentState:
        """Append reference audio tokens with positions and attention mask."""
        tokens = latent_tools.patchifier.patchify(self.latent)
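        # Per the class docstring, this turns the [B, C, T, F] reference latent
        # into a token sequence of shape [B, ref_T, 128].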

        # Compute positions for the reference audio with a small offset (0.5s)
        # from the target start, to avoid exact t=0 overlap (which causes ref
        # content to bleed into the target start) while keeping RoPE decay
        # minimal: 0.5s / max_pos(20s) = 0.025 fractional, a negligible decay.
        ref_shape = AudioLatentShape(
            batch=self.latent.shape[0],
            channels=self.latent.shape[1],
            frames=self.latent.shape[2],
            mel_bins=self.latent.shape[3],
        )
        positions = latent_tools.patchifier.get_patch_grid_bounds(
            output_shape=ref_shape,
            device=self.latent.device,
        )
        # Small offset to prevent a t=0 position collision between target and ref.
        positions = positions + 0.5
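        # E.g. (illustrative) a 2 s reference now spans positions [0.5, 2.5),
        # overlapping the target's own [0, T) coordinate range rather than
        # being appended after it.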

        # Denoise mask: 0 for frozen (strength=1.0), 1 for fully denoised (strength=0.0).
        denoise_mask = torch.full(
            size=(*tokens.shape[:2], 1),
            fill_value=1.0 - self.strength,
            device=self.latent.device,
            dtype=torch.float32,
        )
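        # E.g. strength=1.0 -> denoise_mask=0.0 on every ref token (never
        # re-noised, never updated); strength=0.5 would yield 0.5, letting the
        # sampler partially denoise the reference alongside the target.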

        # Build the ASYMMETRIC attention mask manually.
        # Structure:
        #                 target (N)   ref (M)
        #               ┌────────────┬──────────┐
        #    target (N) │    1.0     │   1.0    │  target attends to everything
        #               ├────────────┼──────────┤
        #       ref (M) │    0.0     │   1.0    │  ref only attends to itself
        #               └────────────┴──────────┘
        #
        # This makes reference tokens "read-only conditioning":
        # - Target tokens freely attend to ref (the voice-cloning signal)
        # - Ref tokens don't attend to the noisy target (they stay clean/stable)
        batch_size = tokens.shape[0]
        num_target = latent_state.latent.shape[1]
        num_ref = tokens.shape[1]
        total = num_target + num_ref

        # Use float32 for the [0, 1] mask; _prepare_self_attention_mask converts
        # it to a log-space bias in the model's compute dtype before it reaches
        # attention.
        mask = torch.zeros(
            (batch_size, total, total),
            device=self.latent.device,
            dtype=torch.float32,
        )
        # Incorporate the existing mask if present; otherwise full attention for target.
        if latent_state.attention_mask is not None:
            mask[:, :num_target, :num_target] = latent_state.attention_mask
        else:
            mask[:, :num_target, :num_target] = 1.0
        # Target -> ref: FULL attention (target can read the reference voice).
        mask[:, :num_target, num_target:] = 1.0
        # Ref -> target: BLOCKED (ref is read-only, never sees the noisy target);
        # mask[:, num_target:, :num_target] stays 0.0.
        # Ref -> ref: full self-attention within the reference.
        mask[:, num_target:, num_target:] = 1.0
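        # Concrete example (illustrative, assuming no pre-existing mask): with
        # num_target=2 and num_ref=1, each batch entry of `mask` is
        #     [[1., 1., 1.],
        #      [1., 1., 1.],
        #      [0., 0., 1.]]
        # Both target rows see all three tokens; the ref row sees only itself.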

        return LatentState(
            latent=torch.cat([latent_state.latent, tokens], dim=1),
            denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
            positions=torch.cat([latent_state.positions, positions], dim=2),
            clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
            attention_mask=mask,
        )
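

if __name__ == "__main__":
    # Minimal standalone sanity check of the asymmetric-mask technique above.
    # The shapes, head-free attention, and the log-space conversion here are
    # illustrative assumptions, not part of the ltx_core API: we build a tiny
    # [0, 1] mask the same way apply_to() does, turn it into an additive bias
    # (0.0 where allowed, -inf where blocked), and verify that ref rows place
    # zero weight on target columns while target rows still read the ref.
    torch.manual_seed(0)
    batch, num_target, num_ref, dim = 1, 4, 2, 8
    total = num_target + num_ref

    demo_mask = torch.zeros(batch, total, total)
    demo_mask[:, :num_target, :] = 1.0            # target -> everything
    demo_mask[:, num_target:, num_target:] = 1.0  # ref -> ref only

    # [0, 1] mask -> additive log-space bias, as described for
    # _prepare_self_attention_mask: 0.0 where attending, -inf where blocked.
    bias = torch.zeros_like(demo_mask).masked_fill(demo_mask == 0, float("-inf"))

    q = k = v = torch.randn(batch, total, dim)
    attn = torch.softmax(q @ k.transpose(-2, -1) / dim**0.5 + bias, dim=-1)
    out = attn @ v  # attention output, [batch, total, dim]

    # Ref rows put exactly zero weight on target columns (read-only)...
    assert torch.all(attn[:, num_target:, :num_target] == 0)
    # ...while target rows still attend to ref columns (conditioning flows).
    assert torch.all(attn[:, :num_target, num_target:] > 0)
    print("asymmetric attention mask behaves as expected:", out.shape)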