import math
import typing as tp

import numpy as np
import torch
from torchaudio import transforms as T

from .utils import prepare_audio
from .sampling import sample, sample_k, sample_rf
from ..data.utils import PadCrop
def generate_diffusion_uncond(
        model,
        steps: int = 250,
        batch_size: int = 1,
        sample_size: int = 2097152,
        seed: int = -1,
        device: str = "cuda",
        init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
        init_noise_level: float = 1.0,
        return_latents: bool = False,
        **sampler_kwargs
        ) -> torch.Tensor:
    """
    Generate audio unconditionally with a diffusion model.

    Takes the same arguments as generate_diffusion_cond, minus the
    conditioning and masking parameters.
    """

    # The length of the output audio in samples
    audio_sample_size = sample_size

    # If this is latent diffusion, generate in the smaller latent size instead
    if model.pretransform is not None:
        sample_size = sample_size // model.pretransform.downsampling_ratio

    # Seed the RNG, drawing a random seed if none was provided
    seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
    print(seed)  # log the seed so the result can be reproduced
    torch.manual_seed(seed)

    # The initial noise to be denoised by the sampler
    noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)

    if init_audio is not None:
        # The user supplied initial audio to generate variations from
        in_sr, init_audio = init_audio

        io_channels = model.io_channels

        # For latent models, match the autoencoder's channel count instead
        if model.pretransform is not None:
            io_channels = model.pretransform.io_channels

        # Resample, pad/crop, and move the initial audio to the target device
        init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)

        # For latent models, encode the initial audio into latents
        if model.pretransform is not None:
            init_audio = model.pretransform.encode(init_audio)

        init_audio = init_audio.repeat(batch_size, 1, 1)
    else:
        # No initial audio: generate from scratch, so the noise level is moot
        init_audio = None
        init_noise_level = None

    # Unconditional generation does not use an inpainting mask
    mask = None

    if init_audio is not None:
        # For variations, start sampling from the initial audio at the given noise level
        sampler_kwargs["sigma_max"] = init_noise_level

    diff_objective = model.diffusion_objective

    if diff_objective == "v":
        # v-prediction objective: use a k-diffusion sampler
        sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device)
    elif diff_objective == "rectified_flow":
        sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, device=device)
    else:
        raise ValueError(f"Unknown diffusion objective: {diff_objective}")

    # Decode latents back to audio unless the caller asked for raw latents
    if model.pretransform is not None and not return_latents:
        sampled = model.pretransform.decode(sampled)

    return sampled
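
# A minimal usage sketch (assumes a loaded unconditional diffusion model wrapper
# exposing io_channels, sample_rate, diffusion_objective, and an optional
# pretransform, as used above):
#
#     audio = generate_diffusion_uncond(model, steps=100, batch_size=2,
#                                       sample_size=model.sample_rate * 10)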


def generate_diffusion_cond(
        model,
        steps: int = 250,
        cfg_scale: float = 6.0,
        conditioning: tp.Optional[dict] = None,
        conditioning_tensors: tp.Optional[dict] = None,
        negative_conditioning: tp.Optional[dict] = None,
        negative_conditioning_tensors: tp.Optional[dict] = None,
        batch_size: int = 1,
        sample_size: int = 2097152,
        sample_rate: int = 48000,
        seed: int = -1,
        device: str = "cuda",
        init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
        init_noise_level: float = 1.0,
        mask_args: tp.Optional[dict] = None,
        return_latents: bool = False,
        **sampler_kwargs
        ) -> torch.Tensor:
    """
    Generate audio from a prompt using a diffusion model.

    Args:
        model: The diffusion model to use for generation.
        steps: The number of diffusion steps to use.
        cfg_scale: Classifier-free guidance scale.
        conditioning: A dictionary of conditioning parameters to use for generation.
        conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation.
        negative_conditioning: A dictionary of conditioning parameters to guide away from, for classifier-free guidance.
        negative_conditioning_tensors: A dictionary of precomputed negative conditioning tensors.
        batch_size: The batch size to use for generation.
        sample_size: The length of the audio to generate, in samples.
        sample_rate: The sample rate of the audio to generate (deprecated; now pulled from the model directly).
        seed: The random seed to use for generation, or -1 to use a random seed.
        device: The device to use for generation.
        init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation.
        init_noise_level: The noise level to use when generating from an initial audio sample.
        mask_args: A dictionary of masking parameters for inpainting; see build_mask below.
        return_latents: Whether to return the latents used for generation instead of the decoded audio.
        **sampler_kwargs: Additional keyword arguments to pass to the sampler.

    Returns:
        The generated audio, or the raw latents if return_latents is True.
    """

    # The length of the output audio in samples
    audio_sample_size = sample_size

    # If this is latent diffusion, generate in the smaller latent size instead
    if model.pretransform is not None:
        sample_size = sample_size // model.pretransform.downsampling_ratio

    # Seed the RNG, drawing a random seed if none was provided
    seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
    torch.manual_seed(seed)

    # The initial noise to be denoised by the sampler
    noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)

    # Disable reduced-precision math paths so sampling runs in full precision
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
    torch.backends.cudnn.benchmark = False

    assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors"

    # Compute conditioning tensors from the raw conditioning if not precomputed
    if conditioning_tensors is None:
        conditioning_tensors = model.conditioner(conditioning, device)

    conditioning_inputs = model.get_conditioning_inputs(conditioning_tensors)

    if negative_conditioning is not None or negative_conditioning_tensors is not None:
        # Same for the negative conditioning used by classifier-free guidance
        if negative_conditioning_tensors is None:
            negative_conditioning_tensors = model.conditioner(negative_conditioning, device)

        negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True)
    else:
        negative_conditioning_tensors = {}

    if init_audio is not None:
        # The user supplied initial audio for variation or inpainting
        in_sr, init_audio = init_audio

        io_channels = model.io_channels

        # For latent models, match the autoencoder's channel count instead
        if model.pretransform is not None:
            io_channels = model.pretransform.io_channels

        # Resample, pad/crop, and move the initial audio to the target device
        init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)

        # For latent models, encode the initial audio into latents
        if model.pretransform is not None:
            init_audio = model.pretransform.encode(init_audio)

        init_audio = init_audio.repeat(batch_size, 1, 1)
    else:
        # No initial audio: generate from scratch, so noise level and mask are moot
        init_audio = None
        init_noise_level = None
        mask_args = None

    if init_audio is not None and mask_args is not None:
        # Inpainting: cut the requested region out of the initial audio, paste it
        # at the target position, and build a soft mask over the region to
        # regenerate. Positions are percentages of the (possibly latent) length.
        cropfrom = math.floor(mask_args["cropfrom"] / 100.0 * sample_size)
        pastefrom = math.floor(mask_args["pastefrom"] / 100.0 * sample_size)
        pasteto = math.ceil(mask_args["pasteto"] / 100.0 * sample_size)
        assert pastefrom < pasteto, "Paste From should be less than Paste To"

        # Clamp the crop so it stays inside the sample, shrinking the paste to match
        croplen = pasteto - pastefrom
        if cropfrom + croplen > sample_size:
            croplen = sample_size - cropfrom
        cropto = cropfrom + croplen
        pasteto = pastefrom + croplen

        cutpaste = init_audio.new_zeros(init_audio.shape)
        cutpaste[:, :, pastefrom:pasteto] = init_audio[:, :, cropfrom:cropto]
        init_audio = cutpaste

        mask = build_mask(sample_size, mask_args)
        mask = mask.to(device)
    elif init_audio is not None and mask_args is None:
        # Variation: start sampling from the initial audio at the given noise level
        sampler_kwargs["sigma_max"] = init_noise_level
        mask = None
    else:
        mask = None
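
    # Worked example of the cut-and-paste arithmetic above, with illustrative
    # numbers: for sample_size=1000, cropfrom=10, pastefrom=50, pasteto=80,
    # the indices become cropfrom=100, pastefrom=500, pasteto=800, croplen=300,
    # cropto=400, so init_audio[:, :, 100:400] lands at cutpaste[:, :, 500:800].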

    # Cast the noise and conditioning inputs to the model's dtype
    model_dtype = next(model.model.parameters()).dtype
    noise = noise.type(model_dtype)
    conditioning_inputs = {k: v.type(model_dtype) if v is not None else v for k, v in conditioning_inputs.items()}

    diff_objective = model.diffusion_objective

    if diff_objective == "v":
        # v-prediction objective: use a k-diffusion sampler
        sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
    elif diff_objective == "rectified_flow":
        # Rectified flow does not use these k-diffusion-specific arguments
        if "sigma_min" in sampler_kwargs:
            del sampler_kwargs["sigma_min"]
        if "sampler_type" in sampler_kwargs:
            del sampler_kwargs["sampler_type"]

        sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
    else:
        raise ValueError(f"Unknown diffusion objective: {diff_objective}")

    # Free sampling buffers before decoding to reduce peak memory
    del noise
    del conditioning_tensors
    del conditioning_inputs
    torch.cuda.empty_cache()

    # Decode latents back to audio unless the caller asked for raw latents
    if model.pretransform is not None and not return_latents:
        sampled = sampled.to(next(model.pretransform.parameters()).dtype)
        sampled = model.pretransform.decode(sampled)

    return sampled


def build_mask(sample_size, mask_args):
    # Mask positions and fade widths are given as percentages of the sample length
    maskstart = math.floor(mask_args["maskstart"] / 100.0 * sample_size)
    maskend = math.ceil(mask_args["maskend"] / 100.0 * sample_size)
    softnessL = round(mask_args["softnessL"] / 100.0 * sample_size)
    softnessR = round(mask_args["softnessR"] / 100.0 * sample_size)
    marination = mask_args["marination"]

    # Rising and falling halves of Hann windows for soft mask edges
    hannL = torch.hann_window(softnessL * 2, periodic=False)[:softnessL]
    hannR = torch.hann_window(softnessR * 2, periodic=False)[softnessR:]

    # 1 inside the region to regenerate, 0 outside, with Hann-faded edges
    mask = torch.zeros((sample_size,))
    mask[maskstart:maskend] = 1
    mask[maskstart:maskstart + softnessL] = hannL
    mask[maskend - softnessR:maskend] = hannR

    # Marination scales the whole mask down, retaining some of the original audio
    if marination > 0:
        mask = mask * (1 - marination)

    return mask
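
# A minimal sketch of the mask_args consumed by build_mask (keys taken from the
# function body; the values below are illustrative). Position and softness
# values are percentages of the sample length; marination is a 0-1 factor that
# scales the finished mask by 1 - marination:
#
#     mask_args = {
#         "maskstart": 25, "maskend": 75,  # regenerate the middle half
#         "softnessL": 5, "softnessR": 5,  # 5% Hann fades at each edge
#         "marination": 0.0,               # 0 disables the global scaling
#     }
#     mask = build_mask(sample_size, mask_args)  # 1-D tensor of length sample_size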