# audioEditing/utils.py
import numpy as np
import torch
from typing import Optional, List, Tuple, NamedTuple, Union
from models import PipelineWrapper
import torchaudio
from audioldm.utils import get_duration

# Inputs longer than this (in seconds) are truncated; None disables the cap.
MAX_DURATION = 30


class PromptEmbeddings(NamedTuple):
    """Conditioning tensors returned by PipelineWrapper.encode_text()."""
    embedding_hidden_states: torch.Tensor
    embedding_class_lables: torch.Tensor
    boolean_prompt_mask: torch.Tensor
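
# Illustrative shapes only (hypothetical, not taken from a specific model): for
# a batch of two prompts, a PromptEmbeddings instance might look like
#   PromptEmbeddings(
#       embedding_hidden_states=torch.zeros(2, 77, 1024),            # per-token states
#       embedding_class_lables=torch.zeros(2, 512),                  # pooled/class embedding
#       boolean_prompt_mask=torch.ones(2, 77, dtype=torch.bool),     # attention mask
#   )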


def load_audio(audio_path: Union[str, np.ndarray], fn_STFT, left: int = 0, right: int = 0,
               device: Optional[torch.device] = None, return_wav: bool = False, stft: bool = False,
               model_sr: Optional[int] = None) -> Tuple:
    """Load audio either as a mel spectrogram (stft=True, AudioLDM/TANGO) or as a raw waveform."""
    if stft:  # AudioLDM/TANGO: load to a mel spectrogram
        if isinstance(audio_path, str):
            import audioldm
            import audioldm.audio

            duration = get_duration(audio_path)
            if MAX_DURATION is not None:
                duration = min(duration, MAX_DURATION)
            mel, _, wav = audioldm.audio.wav_to_fbank(
                audio_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT)
            mel = mel.unsqueeze(0)
        else:  # a precomputed spectrogram tensor was passed in directly
            mel = audio_path
            wav = None
            # Approximate: wav_to_fbank produces ~102.4 frames per second, so invert that
            duration = mel.shape[-1] / 102.4

        c, h, w = mel.shape
        left = min(left, w - 1)
        right = min(right, w - left - 1)
        mel = mel[:, :, left:w - right]

        mel = mel.unsqueeze(0).to(device)

        if return_wav:
            return mel, 16000, duration, wav  # AudioLDM operates at 16 kHz
        return mel, model_sr, duration
    else:  # Stable Audio-style loading: raw waveform resampled to the model rate
        waveform, sr = torchaudio.load(audio_path)
        if sr != model_sr:
            waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=model_sr)

        def normalize_wav(waveform):
            # Remove DC offset, then peak-normalize to +-0.5
            waveform = waveform - torch.mean(waveform)
            waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
            return waveform * 0.5

        waveform = normalize_wav(waveform)
        waveform = waveform.float()

        if MAX_DURATION is not None:  # cut the waveform to the duration cap
            duration = min(waveform.shape[-1] / model_sr, MAX_DURATION)
            waveform = waveform[:, :int(duration * model_sr)]
        duration = waveform.shape[-1] / model_sr

        return waveform, model_sr, duration
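
# Usage sketch (file paths and sample rates below are assumptions, not part of this repo):
#   wav, sr, dur = load_audio("input.wav", fn_STFT=None, stft=False, model_sr=44100)
#   mel, sr, dur = load_audio("input.wav", fn_STFT=stft_fn, stft=True, model_sr=16000,
#                             device=torch.device("cuda"))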


def get_height_of_spectrogram(length: Optional[float], ldm_stable: PipelineWrapper) -> int:
    # Seconds of audio covered by one latent time step
    vocoder_upsample_factor = np.prod(ldm_stable.model.vocoder.config.upsample_rates) / \
        ldm_stable.model.vocoder.config.sampling_rate

    if length is None:
        length = ldm_stable.model.unet.config.sample_size * ldm_stable.model.vae_scale_factor * \
            vocoder_upsample_factor

    height = int(length / vocoder_upsample_factor)

    # Round up so the latent height is divisible by the VAE scale factor
    if height % ldm_stable.model.vae_scale_factor != 0:
        height = int(np.ceil(height / ldm_stable.model.vae_scale_factor)) * ldm_stable.model.vae_scale_factor
        print(
            f"Audio length in seconds {length} is increased to {height * vocoder_upsample_factor} "
            f"so that it can be handled by the model. It will be cut to {length} after the "
            f"denoising process."
        )

    return height
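
# Worked example with made-up numbers (not the real vocoder config): if the
# vocoder upsample rates multiply to 256 at a 16 kHz sampling rate, each latent
# frame spans 256 / 16000 = 0.016 s, so length = 10 s gives height = 625; with
# vae_scale_factor = 8 this rounds up to 79 * 8 = 632 latent frames.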


def get_text_embeddings(target_prompt: List[str], target_neg_prompt: List[str], ldm_stable: PipelineWrapper
                        ) -> Tuple[torch.Tensor, PromptEmbeddings, PromptEmbeddings]:
    """Encode the target and negative prompts and wrap the results in PromptEmbeddings."""
    text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = \
        ldm_stable.encode_text(target_prompt)
    uncond_embedding_hidden_states, uncond_embedding_class_lables, uncond_boolean_prompt_mask = \
        ldm_stable.encode_text(target_neg_prompt)

    text_emb = PromptEmbeddings(embedding_hidden_states=text_embeddings_hidden_states,
                                boolean_prompt_mask=text_embeddings_boolean_prompt_mask,
                                embedding_class_lables=text_embeddings_class_labels)
    uncond_emb = PromptEmbeddings(embedding_hidden_states=uncond_embedding_hidden_states,
                                  boolean_prompt_mask=uncond_boolean_prompt_mask,
                                  embedding_class_lables=uncond_embedding_class_lables)

    return text_embeddings_class_labels, text_emb, uncond_emb
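
# Usage sketch (ldm_stable is assumed to be an already-loaded PipelineWrapper;
# the prompts are illustrative):
#   class_labels, text_emb, uncond_emb = get_text_embeddings(
#       ["a dog barking in the rain"], [""], ldm_stable)
#   # text_emb / uncond_emb then feed the conditional and unconditional
#   # branches of classifier-free guidance.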