Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	| # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from dataclasses import dataclass | |
| from math import pi | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.utils.checkpoint | |
| from ...configuration_utils import ConfigMixin, register_to_config | |
| from ...models.modeling_utils import ModelMixin | |
| from ...utils import BaseOutput, logging | |
# Module-level logger obtained from the package's own `logging` helper, keyed by this module's `__name__`.
| logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
class StableAudioPositionalEmbedding(nn.Module):
    """
    Learned Fourier-feature embedding for continuous scalar values (used for continuous time).

    Every input value `t` is mapped to `[t, sin(2 * pi * t * w), cos(2 * pi * t * w)]`, where `w` is a learnable
    vector of `dim // 2` frequencies initialised from a standard normal distribution.

    Args:
        dim (`int`):
            Number of sine/cosine features to produce. Must be even. The tensor returned by `forward` has
            `dim + 1` features on its last axis because the raw input value is prepended.

    Raises:
        ValueError: If `dim` is odd.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Raise explicitly rather than `assert`: assertions are removed when Python runs with `-O`,
        # which would let an odd `dim` silently produce `dim - 1` sine/cosine features.
        if dim % 2 != 0:
            raise ValueError(f"`dim` must be an even number, but got {dim}.")
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, times: torch.Tensor) -> torch.Tensor:
        """
        Embed `times` with the learned Fourier features.

        Args:
            times (`torch.Tensor`):
                Values to embed. A trailing axis of size 1 is appended before the features are computed.

        Returns:
            `torch.Tensor`: Tensor of shape `(*times.shape, dim + 1)` holding, along the last axis, the raw value
            followed by the sine features and then the cosine features.
        """
        # Add a trailing feature axis so the values broadcast against the frequency vector.
        times = times[..., None]
        freqs = times * self.weights[None] * 2 * pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        # Prepend the raw value so downstream layers also see the un-embedded input.
        fouriered = torch.cat((times, fouriered), dim=-1)
        return fouriered
@dataclass
class StableAudioProjectionModelOutput(BaseOutput):
    """
    Class for StableAudio projection layer's outputs.

    Args:
        text_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states obtained by linearly projecting the hidden-states for the text encoder.
        seconds_start_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
            Sequence of hidden-states obtained by linearly projecting the audio start hidden states.
        seconds_end_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
            Sequence of hidden-states obtained by linearly projecting the audio end hidden states.
    """

    # Each field stays `None` when the corresponding input was not supplied to the projection model.
    text_hidden_states: Optional[torch.Tensor] = None
    seconds_start_hidden_states: Optional[torch.Tensor] = None
    seconds_end_hidden_states: Optional[torch.Tensor] = None
class StableAudioNumberConditioner(nn.Module):
    """
    Embeds scalar numbers into a latent space with a positional embedding followed by a linear projection.

    Inputs are clamped to `[min_value, max_value]`, rescaled by `(x - min_value) / (max_value - min_value)`,
    passed through a `StableAudioPositionalEmbedding` and projected to `number_embedding_dim` features.

    Args:
        number_embedding_dim (`int`):
            Dimensionality of the number embeddings.
        min_value (`int`):
            The minimum value of the seconds number conditioning modules.
        max_value (`int`):
            The maximum value of the seconds number conditioning modules.
        internal_dim (`int`):
            Dimensionality of the intermediate number hidden states.
    """

    def __init__(
        self,
        number_embedding_dim,
        min_value,
        max_value,
        internal_dim: Optional[int] = 256,
    ):
        super().__init__()
        self.number_embedding_dim = number_embedding_dim
        self.min_value = min_value
        self.max_value = max_value

        # The linear layer takes `internal_dim + 1` inputs: the positional embedding's output
        # is one feature wider than `internal_dim`.
        fourier_features = StableAudioPositionalEmbedding(internal_dim)
        projection = nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim)
        self.time_positional_embedding = nn.Sequential(fourier_features, projection)

    def forward(
        self,
        floats: torch.Tensor,
    ):
        """
        Embed a tensor of numbers.

        Args:
            floats (`torch.Tensor`):
                Numbers to embed; values outside `[min_value, max_value]` are clamped to that range.

        Returns:
            `torch.Tensor`: The embeddings viewed as `(-1, 1, number_embedding_dim)`.
        """
        lower, upper = self.min_value, self.max_value
        clamped = floats.clamp(lower, upper)
        scaled = (clamped - lower) / (upper - lower)

        # Run the embedder in the dtype of its own parameters.
        first_param = next(self.time_positional_embedding.parameters())
        scaled = scaled.to(first_param.dtype)

        projected = self.time_positional_embedding(scaled)
        return projected.view(-1, 1, self.number_embedding_dim)
class StableAudioProjectionModel(ModelMixin, ConfigMixin):
    """
    A simple linear projection model to map the conditioning values to a shared latent space.

    Args:
        text_encoder_dim (`int`):
            Dimensionality of the text embeddings from the text encoder (T5).
        conditioning_dim (`int`):
            Dimensionality of the output conditioning tensors.
        min_value (`int`):
            The minimum value of the seconds number conditioning modules.
        max_value (`int`):
            The maximum value of the seconds number conditioning modules.
    """

    # `register_to_config` records the constructor arguments in the model config.
    @register_to_config
    def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value):
        super().__init__()
        # No projection is needed when the text encoder already outputs `conditioning_dim` features.
        self.text_projection = (
            nn.Identity() if conditioning_dim == text_encoder_dim else nn.Linear(text_encoder_dim, conditioning_dim)
        )
        # Start and end times get separate conditioners, each with its own parameters.
        self.start_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)
        self.end_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)

    def forward(
        self,
        text_hidden_states: Optional[torch.Tensor] = None,
        start_seconds: Optional[torch.Tensor] = None,
        end_seconds: Optional[torch.Tensor] = None,
    ):
        """
        Project whichever conditioning inputs are provided; inputs left as `None` stay `None` in the output.

        Args:
            text_hidden_states (`torch.Tensor`, *optional*):
                Hidden states from the text encoder, passed through `text_projection`.
            start_seconds (`torch.Tensor`, *optional*):
                Audio start values, embedded by `start_number_conditioner`.
            end_seconds (`torch.Tensor`, *optional*):
                Audio end values, embedded by `end_number_conditioner`.

        Returns:
            `StableAudioProjectionModelOutput`: The projected text, start and end hidden states.
        """
        if text_hidden_states is not None:
            text_hidden_states = self.text_projection(text_hidden_states)

        seconds_start_hidden_states = (
            None if start_seconds is None else self.start_number_conditioner(start_seconds)
        )
        seconds_end_hidden_states = None if end_seconds is None else self.end_number_conditioner(end_seconds)

        return StableAudioProjectionModelOutput(
            text_hidden_states=text_hidden_states,
            seconds_start_hidden_states=seconds_start_hidden_states,
            seconds_end_hidden_states=seconds_end_hidden_states,
        )
 
			
