| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from dataclasses import dataclass |
| from math import pi |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.utils.checkpoint |
|
|
| from ...configuration_utils import ConfigMixin, register_to_config |
| from ...models.modeling_utils import ModelMixin |
| from ...utils import BaseOutput, logging |
|
|
|
|
# Module-level logger for this file, created through the package's `logging` utility wrapper.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
|
|
|
class StableAudioPositionalEmbedding(nn.Module):
    """
    Learnable random Fourier-feature embedding for continuous scalar inputs (used for continuous time).

    A vector of `dim // 2` frequencies `w` is initialized from a standard normal distribution and trained as a
    parameter. Each input value `t` is mapped to the concatenation `[t, sin(2π·t·w), cos(2π·t·w)]`, so the last
    dimension of the output is `dim + 1`.

    Args:
        dim (`int`):
            Total number of Fourier features (sines + cosines). Must be even.

    Raises:
        ValueError: If `dim` is odd.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Explicit check rather than `assert`, which is stripped when Python runs with `-O`. An odd `dim` would
        # otherwise silently produce `dim` output features instead of the advertised `dim + 1`.
        if dim % 2 != 0:
            raise ValueError(f"`dim` must be even, got {dim}.")
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, times: torch.Tensor) -> torch.Tensor:
        """
        Embed `times`.

        Args:
            times (`torch.Tensor`):
                Values to embed, a tensor with at least one dimension, shape `(*)`.

        Returns:
            `torch.Tensor` of shape `(*, dim + 1)`: the raw value followed by its sine and cosine features.
        """
        # Add a trailing feature axis so each value broadcasts against every frequency.
        times = times[..., None]
        freqs = times * self.weights[None] * 2 * pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        # Keep the raw input as the first feature (hence `dim + 1` output features).
        fouriered = torch.cat((times, fouriered), dim=-1)
        return fouriered
|
|
|
|
@dataclass
class StableAudioProjectionModelOutput(BaseOutput):
    """
    Class for StableAudio projection layer's outputs.

    Args:
        text_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states obtained by projecting the hidden-states of the text encoder into the shared
            conditioning space.
        seconds_start_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
            Embedding of the audio start time (in seconds) in the shared conditioning space.
        seconds_end_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
            Embedding of the audio end time (in seconds) in the shared conditioning space.
    """

    # Each field is `None` when the corresponding conditioning input was not provided.
    text_hidden_states: Optional[torch.Tensor] = None
    seconds_start_hidden_states: Optional[torch.Tensor] = None
    seconds_end_hidden_states: Optional[torch.Tensor] = None
|
|
|
|
class StableAudioNumberConditioner(nn.Module):
    """
    A simple projection model to map numbers to a latent space.

    Inputs are clamped to `[min_value, max_value]` and rescaled to `[0, 1]`. They are then embedded with a
    `StableAudioPositionalEmbedding` and linearly projected to `number_embedding_dim` features.

    Args:
        number_embedding_dim (`int`):
            Dimensionality of the number embeddings.
        min_value (`float`):
            The minimum value of the seconds number conditioning modules.
        max_value (`float`):
            The maximum value of the seconds number conditioning modules. Must be strictly greater than
            `min_value`.
        internal_dim (`int`, *optional*, defaults to 256):
            Dimensionality of the intermediate number hidden states. Must be even.

    Raises:
        ValueError: If `max_value` is not strictly greater than `min_value`.
    """

    def __init__(
        self,
        number_embedding_dim: int,
        min_value: float,
        max_value: float,
        internal_dim: int = 256,
    ):
        super().__init__()
        # The normalisation in `forward` divides by `max_value - min_value`; an empty or inverted range would
        # silently yield NaN (0 / 0) or meaningless embeddings, so reject it at construction time.
        if max_value <= min_value:
            raise ValueError(
                f"`max_value` ({max_value}) must be strictly greater than `min_value` ({min_value})."
            )

        self.time_positional_embedding = nn.Sequential(
            StableAudioPositionalEmbedding(internal_dim),
            # The positional embedding prepends the raw input to its sin/cos features, hence the `+ 1`.
            nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim),
        )

        self.number_embedding_dim = number_embedding_dim
        self.min_value = min_value
        self.max_value = max_value

    def forward(
        self,
        floats: torch.Tensor,
    ) -> torch.Tensor:
        """
        Embed a tensor of numbers.

        Args:
            floats (`torch.Tensor`):
                Numbers to embed. Values outside `[min_value, max_value]` are clamped first.

        Returns:
            `torch.Tensor` of shape `(floats.numel(), 1, number_embedding_dim)`: one embedding per input value.
        """
        floats = floats.clamp(self.min_value, self.max_value)
        normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value)

        # Match the embedder's parameter dtype (e.g. when the module has been cast to half precision) so the
        # layers inside `time_positional_embedding` do not see a dtype mismatch.
        embedder_dtype = next(self.time_positional_embedding.parameters()).dtype
        normalized_floats = normalized_floats.to(embedder_dtype)

        embedding = self.time_positional_embedding(normalized_floats)
        # Flatten all leading dimensions and insert a singleton sequence axis.
        float_embeds = embedding.view(-1, 1, self.number_embedding_dim)

        return float_embeds
|
|
|
|
class StableAudioProjectionModel(ModelMixin, ConfigMixin):
    """
    Maps the Stable Audio conditioning signals into one shared latent space of width `conditioning_dim`.

    The signals are the text-encoder hidden states, the start time in seconds and the end time in seconds.

    Args:
        text_encoder_dim (`int`):
            Dimensionality of the text embeddings from the text encoder (T5).
        conditioning_dim (`int`):
            Dimensionality of the output conditioning tensors.
        min_value (`int`):
            The minimum value of the seconds number conditioning modules.
        max_value (`int`):
            The maximum value of the seconds number conditioning modules.
    """

    @register_to_config
    def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value):
        super().__init__()
        # When the text encoder already emits `conditioning_dim` features, no projection is needed.
        if text_encoder_dim == conditioning_dim:
            self.text_projection = nn.Identity()
        else:
            self.text_projection = nn.Linear(text_encoder_dim, conditioning_dim)

        # Start and end times get independent (non-shared) conditioners with identical settings.
        conditioner_args = (conditioning_dim, min_value, max_value)
        self.start_number_conditioner = StableAudioNumberConditioner(*conditioner_args)
        self.end_number_conditioner = StableAudioNumberConditioner(*conditioner_args)

    def forward(
        self,
        text_hidden_states: Optional[torch.Tensor] = None,
        start_seconds: Optional[torch.Tensor] = None,
        end_seconds: Optional[torch.Tensor] = None,
    ):
        """
        Project every conditioning input that is provided; inputs left as `None` come back as `None`.

        Args:
            text_hidden_states (`torch.Tensor`, *optional*):
                Hidden states from the text encoder.
            start_seconds (`torch.Tensor`, *optional*):
                Audio start times in seconds.
            end_seconds (`torch.Tensor`, *optional*):
                Audio end times in seconds.

        Returns:
            [`StableAudioProjectionModelOutput`] holding the projected tensors.
        """
        projected_text = None
        if text_hidden_states is not None:
            projected_text = self.text_projection(text_hidden_states)

        projected_start = None
        if start_seconds is not None:
            projected_start = self.start_number_conditioner(start_seconds)

        projected_end = None
        if end_seconds is not None:
            projected_end = self.end_number_conditioner(end_seconds)

        return StableAudioProjectionModelOutput(
            text_hidden_states=projected_text,
            seconds_start_hidden_states=projected_start,
            seconds_end_hidden_states=projected_end,
        )
|
|