import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.models.auto import AutoModel
from transformers.modeling_utils import PreTrainedModel
# from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.activations import ACT2FN
from transformers.utils import logging

from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig

logger = logging.get_logger(__name__)
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            output = output * self.weight
        return output

    def extra_repr(self) -> str:
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
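
# Illustrative usage sketch (the shapes here are made up for the example):
#
#   >>> norm = RMSNorm(dim=8)
#   >>> y = norm(torch.randn(2, 8))   # same shape; each vector rescaled to ~unit RMS,
#   ...                               # then scaled by the learned per-dimension weight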
def modulate(x, shift, scale):
    """Apply feature-wise shift and scale (adaLN modulation) to the input tensor."""
    return x * (1 + scale) + shift
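
# With scale == shift == 0 this is the identity (x * 1 + 0), which pairs with the
# zero-initialized adaLN_modulation layers below: each block starts as a no-op.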
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    Args:
        hidden_size (`int`): Size of the output embedding
        frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=False),
            # nn.SiLU(),
            ACT2FN['silu'],
            nn.Linear(hidden_size, hidden_size, bias=False),
        )
        self.frequency_embedding_size = frequency_embedding_size
    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element.
                These may be fractional.
            dim (`int`): The dimension of the output.
            max_period (`int`, optional): Controls the minimum frequency of the embeddings.

        Returns:
            `torch.Tensor`: An [N, D] Tensor of positional embeddings.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding.to(t.dtype)
    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
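
# Illustrative sketch (sizes made up): the embedder maps a batch of scalar
# timesteps to conditioning vectors:
#
#   >>> emb = TimestepEmbedder(hidden_size=64)
#   >>> t = torch.rand(4)      # fractional timesteps are supported
#   >>> emb(t).shape
#   torch.Size([4, 64])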
class FeedForwardNetwork(nn.Module):
    """
    Standard feed-forward network with SwiGLU activation.

    Args:
        embed_dim (`int`): Input dimension
        ffn_dim (`int`): Hidden dimension
    """
    def __init__(
        self,
        embed_dim,
        ffn_dim,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False)
        self.act_fn = ACT2FN['silu']  # Using SiLU as the activation function

    def forward(self, x):
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        # SwiGLU activation
        # gate = F.silu(gate)
        gate = self.act_fn(gate)
        return self.down_proj(gate * up)
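
# The block computes down_proj(SiLU(gate_proj(x)) * up_proj(x)), the SwiGLU form of
# a gated linear unit: for x of shape (B, embed_dim), the gated hidden activation
# has shape (B, ffn_dim) before being projected back to embed_dim.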
class HeadLayer(nn.Module):
    """
    A layer in the diffusion head.

    Args:
        embed_dim (`int`): Input dimension
        ffn_dim (`int`): Hidden dimension
        cond_dim (`int`): Condition embedding dimension
        norm_eps (`float`, optional): Epsilon for normalization
    """
    def __init__(
        self,
        embed_dim,
        ffn_dim,
        cond_dim,
        norm_eps=1e-5,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.cond_dim = cond_dim
        self.ffn_dim = ffn_dim
        self.ffn = FeedForwardNetwork(
            self.embed_dim,
            self.ffn_dim,
        )
        self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
        self.adaLN_modulation = nn.Sequential(
            # nn.SiLU(),
            ACT2FN['silu'],
            nn.Linear(cond_dim, 3 * self.embed_dim, bias=False)
        )

    def forward(self, x, c):
        shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1)
        x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn))
        return x
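
# Each HeadLayer is a conditioned residual block: the condition c yields a
# per-feature shift/scale (applied after RMSNorm) plus a gate on the FFN output,
# so x keeps its shape through the layer.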
class FinalLayer(nn.Module):
    """
    Final layer in the diffusion head.

    Args:
        hidden_size (`int`): Input dimension
        output_size (`int`): Output dimension
        cond_size (`int`): Condition embedding dimension
        norm_eps (`float`, optional): Epsilon for normalization
    """
    def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
        super().__init__()
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, output_size, bias=False)
        self.adaLN_modulation = nn.Sequential(
            # nn.SiLU(),
            ACT2FN['silu'],
            nn.Linear(cond_size, 2 * hidden_size, bias=False)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x
class VibeVoiceDiffusionHead(PreTrainedModel):
    """
    Diffusion head model for VibeVoice.

    Args:
        config (`VibeVoiceDiffusionHeadConfig`): Model configuration; the latent
            size is read from `config.latent_size`.
    """
    config_class = VibeVoiceDiffusionHeadConfig
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(
        self,
        config,
    ):
        super().__init__(config)
        self.config = config
        self.cond_dim = config.hidden_size
        latent_size = config.latent_size

        self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False)
        self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False)
        self.t_embedder = TimestepEmbedder(self.cond_dim)
        ffn_dim = int(config.hidden_size * config.head_ffn_ratio)

        # Create the intermediate layers
        self.layers = nn.ModuleList([
            HeadLayer(
                embed_dim=config.hidden_size,
                ffn_dim=ffn_dim,
                cond_dim=self.cond_dim,
                norm_eps=config.rms_norm_eps
            )
            for _ in range(config.head_layers)
        ])

        # Final layer for output
        self.final_layer = FinalLayer(
            hidden_size=config.hidden_size,
            output_size=latent_size,
            cond_size=self.cond_dim,
            norm_eps=config.rms_norm_eps
        )

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize the weights of the model."""
        # Initialize timestep embedder
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers
        for layer in self.layers:
            nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)

        # Zero-out output layers
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
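
    # With the adaLN weights and the final projection zeroed, every HeadLayer starts
    # as an identity map (gate_ffn == 0) and the head's initial output is zero, the
    # DiT-style "adaLN-Zero" initialization that stabilizes early training.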
    def forward(
        self,
        noisy_images,
        timesteps,
        condition,
    ):
        """
        Forward pass of the prediction head.

        Args:
            noisy_images (`torch.Tensor`): Noisy images/latents to denoise
            timesteps (`torch.Tensor`): Timesteps for diffusion
            condition (`torch.Tensor`): Conditioning information

        Returns:
            `torch.Tensor`: The predicted noise/velocity
        """
        x = self.noisy_images_proj(noisy_images)
        t = self.t_embedder(timesteps)
        condition = self.cond_proj(condition)
        c = condition + t

        for layer in self.layers:
            x = layer(x, c)

        x = self.final_layer(x, c)
        return x
AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead)

__all__ = [
    "VibeVoiceDiffusionHead",
]
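
# Minimal smoke-test sketch. The config field values below are illustrative
# assumptions; the field names (hidden_size, latent_size, head_ffn_ratio,
# rms_norm_eps, head_layers) are the ones this module reads from the config.
# Run via `python -m <package>.<module>` because of the relative import above.
if __name__ == "__main__":
    config = VibeVoiceDiffusionHeadConfig(
        hidden_size=64,
        latent_size=32,
        head_ffn_ratio=4.0,
        rms_norm_eps=1e-5,
        head_layers=2,
    )
    head = VibeVoiceDiffusionHead(config)
    noisy = torch.randn(4, config.latent_size)   # one noisy latent per batch element
    t = torch.rand(4)                            # fractional diffusion timesteps
    cond = torch.randn(4, config.hidden_size)    # conditioning from the language model
    out = head(noisy, t, cond)                   # -> shape (4, latent_size)
    print(out.shape)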