# vampnet-opera / vampnet/interface.py
# author: hugo flores garcia
# commit 41b9d24: "recovering from a gittastrophe"
import os
from pathlib import Path
import math
import logging
import torch
import numpy as np
from audiotools import AudioSignal
import tqdm
from .modules.transformer import VampNet
from .beats import WaveBeat
from .mask import *
# from dac.model.dac import DAC
from lac.model.lac import LAC as DAC
def signal_concat(
    audio_signals: list,
):
    """Join a list of AudioSignals end-to-end along the last (time) axis.

    The returned signal takes the sample rate of the first element; the other
    signals are assumed to share it (this is not checked).
    """
    pieces = [sig.audio_data for sig in audio_signals]
    joined = torch.cat(pieces, dim=-1)
    first_rate = audio_signals[0].sample_rate
    return AudioSignal(joined, sample_rate=first_rate)
def _load_model(
    ckpt: str,
    lora_ckpt: str = None,
    device: str = "cpu",
    chunk_size_s: int = 10,
):
    """Load a VampNet checkpoint, optionally overlay LoRA weights, and ready it for inference.

    Args:
        ckpt: path to the base VampNet checkpoint (loaded on CPU first).
        lora_ckpt: optional path to a state dict loaded on top with strict=False.
            If the path does not exist the user is asked interactively whether to
            continue without it.
        device: device the model is moved to after loading.
        chunk_size_s: stored on the model as ``chunk_size_s``.

    Returns:
        The model in eval mode on ``device``.

    Raises:
        Exception: if the LoRA path is missing and the user does not answer "y".
    """
    # strict=False: the base checkpoint may not contain the LoRA parameters
    # that are loaded on top of it below
    net = VampNet.load(location=Path(ckpt), map_location="cpu", strict=False)

    if lora_ckpt is not None:
        if Path(lora_ckpt).exists():
            lora_state = torch.load(lora_ckpt, map_location="cpu")
            net.load_state_dict(lora_state, strict=False)
        else:
            answer = input(
                f"lora checkpoint {lora_ckpt} does not exist. continue? (y/n) "
            )
            if answer != "y":
                raise Exception("aborting")

    net.to(device)
    net.eval()
    net.chunk_size_s = chunk_size_s
    return net
class Interface(torch.nn.Module):
    """Inference wrapper bundling a codec, a coarse VampNet, an optional
    coarse-to-fine (c2f) VampNet and an optional WaveBeat beat tracker.

    Mask convention used by the methods below: 0 marks a token kept as a
    prompt, 1 marks a token to be regenerated (``coarse_to_fine`` zeroes the
    conditioning codebooks; ``coarse_vamp`` zeroes chunk edges to avoid seams).
    """

    def __init__(
        self,
        coarse_ckpt: str = None,
        coarse_lora_ckpt: str = None,
        coarse2fine_ckpt: str = None,
        coarse2fine_lora_ckpt: str = None,
        codec_ckpt: str = None,
        wavebeat_ckpt: str = None,
        device: str = "cpu",
        coarse_chunk_size_s: int = 10,
        coarse2fine_chunk_size_s: int = 3,
        compile=True,
    ):
        """Load every model onto ``device``.

        Args:
            coarse_ckpt: coarse VampNet checkpoint (required).
            coarse_lora_ckpt: optional LoRA weights for the coarse model.
            coarse2fine_ckpt: optional c2f VampNet checkpoint.
            coarse2fine_lora_ckpt: optional LoRA weights for the c2f model.
            codec_ckpt: codec checkpoint (required).
            wavebeat_ckpt: optional WaveBeat checkpoint; enables the beat helpers.
            device: torch device string.
            coarse_chunk_size_s: generation chunk length (s) for the coarse model.
            coarse2fine_chunk_size_s: generation chunk length (s) for the c2f model.
            compile: wrap the coarse, c2f and codec models with ``torch.compile``.
        """
        super().__init__()

        assert codec_ckpt is not None, "must provide a codec checkpoint"
        self.codec = DAC.load(Path(codec_ckpt))
        self.codec.eval()
        self.codec.to(device)
        self.codec_path = Path(codec_ckpt)

        assert coarse_ckpt is not None, "must provide a coarse checkpoint"
        self.coarse = _load_model(
            ckpt=coarse_ckpt,
            lora_ckpt=coarse_lora_ckpt,
            device=device,
            chunk_size_s=coarse_chunk_size_s,
        )
        self.coarse_path = Path(coarse_ckpt)

        # the coarse-to-fine model is optional
        if coarse2fine_ckpt is not None:
            self.c2f_path = Path(coarse2fine_ckpt)
            self.c2f = _load_model(
                ckpt=coarse2fine_ckpt,
                lora_ckpt=coarse2fine_lora_ckpt,
                device=device,
                chunk_size_s=coarse2fine_chunk_size_s,
            )
        else:
            self.c2f_path = None
            self.c2f = None

        if wavebeat_ckpt is not None:
            logging.debug(f"loading wavebeat from {wavebeat_ckpt}")
            self.beat_tracker = WaveBeat(wavebeat_ckpt)
            self.beat_tracker.model.to(device)
        else:
            self.beat_tracker = None

        self.device = device
        # target loudness passed to AudioSignal.normalize in _preprocess
        self.loudness = -24.0

        if compile:
            logging.debug(f"compiling models")
            self.coarse = torch.compile(self.coarse)
            if self.c2f is not None:
                self.c2f = torch.compile(self.c2f)
            self.codec = torch.compile(self.codec)

    @classmethod
    def default(cls):
        """Download (if needed) and load the default codec, coarse and c2f models."""
        from . import download_codec, download_default

        print(f"loading default vampnet")
        codec_path = download_codec()
        coarse_path, c2f_path = download_default()

        # cls (not a hard-coded Interface) so subclasses get their own type
        return cls(
            coarse_ckpt=coarse_path,
            coarse2fine_ckpt=c2f_path,
            codec_ckpt=codec_path,
        )

    @classmethod
    def available_models(cls):
        """Return the names accepted by ``load_finetuned``."""
        from . import list_finetuned
        return list_finetuned()

    def load_finetuned(self, name: str):
        """Download the fine-tuned coarse/c2f pair called ``name`` and swap it in."""
        assert name in self.available_models(), f"{name} is not a valid model name"
        from . import download_finetuned
        coarse_path, c2f_path = download_finetuned(name)
        self.reload(
            coarse_ckpt=coarse_path,
            c2f_ckpt=c2f_path,
        )

    def reload(
        self,
        coarse_ckpt: str = None,
        c2f_ckpt: str = None,
    ):
        """Replace the coarse and/or c2f model, skipping paths already loaded.

        Chunk sizes are carried over from the models being replaced; when no
        c2f model was loaded before, the c2f chunk size falls back to 3 s (the
        ``__init__`` default). Newly loaded models are not passed through
        ``torch.compile`` and no LoRA weights are applied.
        """
        if coarse_ckpt is not None:
            # check if we already loaded, if so, don't reload
            if self.coarse_path == Path(coarse_ckpt):
                logging.debug(f"already loaded {coarse_ckpt}")
            else:
                self.coarse = _load_model(
                    ckpt=coarse_ckpt,
                    device=self.device,
                    chunk_size_s=self.coarse.chunk_size_s,
                )
                self.coarse_path = Path(coarse_ckpt)
                logging.debug(f"loaded {coarse_ckpt}")

        if c2f_ckpt is not None:
            if self.c2f_path == Path(c2f_ckpt):
                logging.debug(f"already loaded {c2f_ckpt}")
            else:
                # self.c2f is None when the interface was built without a c2f
                # checkpoint; reading .chunk_size_s from it used to crash
                c2f_chunk_size_s = (
                    self.c2f.chunk_size_s if self.c2f is not None else 3
                )
                self.c2f = _load_model(
                    ckpt=c2f_ckpt,
                    device=self.device,
                    chunk_size_s=c2f_chunk_size_s,
                )
                self.c2f_path = Path(c2f_ckpt)
                logging.debug(f"loaded {c2f_ckpt}")

    def s2t(self, seconds: float):
        """Seconds to tokens: ceil(seconds * sample_rate / hop_length).

        Works elementwise (returning floats) on numpy arrays and returns an
        int for scalars.
        """
        if isinstance(seconds, np.ndarray):
            return np.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
        else:
            return math.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)

    def s2t2s(self, seconds: float):
        """Round ``seconds`` up to a whole number of tokens, expressed in seconds."""
        return self.t2s(self.s2t(seconds))

    def t2s(self, tokens: int):
        """Tokens to seconds: tokens * hop_length / sample_rate."""
        return tokens * self.codec.hop_length / self.codec.sample_rate

    def to(self, device):
        """Move every loaded model to ``device`` and remember it; returns self."""
        self.device = device
        self.coarse.to(device)
        self.codec.to(device)

        if self.c2f is not None:
            self.c2f.to(device)

        if self.beat_tracker is not None:
            self.beat_tracker.model.to(device)
        return self

    def decode(self, z: torch.Tensor):
        """Decode codes ``z`` to audio via the coarse model's ``decode`` and the codec."""
        return self.coarse.decode(z, self.codec)

    def _preprocess(self, signal: AudioSignal):
        """Return a copy of ``signal`` resampled to the codec rate, downmixed to
        mono, loudness-normalized to ``self.loudness``, peak-limited to 1.0 and
        passed through ``codec.preprocess``. The input is not modified."""
        signal = (
            signal.clone()
            .resample(self.codec.sample_rate)
            .to_mono()
            .normalize(self.loudness)
            .ensure_max_of_audio(1.0)
        )
        logging.debug(f"length before codec preproc: {signal.samples.shape}")
        # the second value returned by codec.preprocess is not needed here
        signal.samples, _ = self.codec.preprocess(signal.samples, signal.sample_rate)
        logging.debug(f"length after codec preproc: {signal.samples.shape}")
        return signal

    @torch.inference_mode()
    def encode(self, signal: AudioSignal):
        """Preprocess ``signal`` on ``self.device`` and return its codec codes."""
        signal = signal.to(self.device)
        signal = self._preprocess(signal)
        z = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
        return z

    def snap_to_beats(
        self,
        signal: AudioSignal
    ):
        """Return a trimmed copy of ``signal`` spanning the first to the last
        detected beat."""
        # beat_tracker is always set (possibly to None) in __init__, so
        # hasattr() could never fail; check the value instead
        assert self.beat_tracker is not None, "No beat tracker loaded"
        beats, _ = self.beat_tracker.extract_beats(signal)

        # trim the signal around the first and last beat times
        samples_begin = int(beats[0] * signal.sample_rate)
        samples_end = int(beats[-1] * signal.sample_rate)
        logging.debug(beats[0])
        signal = signal.clone().trim(samples_begin, signal.length - samples_end)
        return signal

    def _fill_beat_windows(self, mask, beat_indices, n_before, n_after, dropout_p):
        """Write 1s into ``mask`` over [idx - n_before, idx + n_after) for each
        index, each position independently zeroed with probability ``dropout_p``."""
        for beat_idx in beat_indices:
            # clamp at 0: a negative start would wrap around to the end of the
            # sequence and silently drop windows for beats near t=0
            start = max(0, int(beat_idx - n_before))
            end = int(beat_idx + n_after)
            num_steps = mask[start:end].shape[0]
            window = torch.ones(num_steps, device=self.device)
            keep = torch.bernoulli(window * (1 - dropout_p))
            mask[start:end] = window * keep.long()

    def make_beat_mask(self,
        signal: AudioSignal,
        before_beat_s: float = 0.0,
        after_beat_s: float = 0.02,
        mask_downbeats: bool = True,
        mask_upbeats: bool = True,
        downbeat_downsample_factor: int = None,
        beat_downsample_factor: int = None,
        dropout: float = 0.0,
        invert: bool = True,
    ):
        """Make a beat-synced mask.

        Builds a mask over the token sequence of ``signal`` with 1s from
        ``before_beat_s`` before to ``after_beat_s`` after each selected beat
        and 0s elsewhere; positions inside a window are dropped to 0 with
        probability ``dropout``. With ``invert=True`` (default) the mask is
        flipped so the beat windows become 0 and everything else 1.

        Args:
            signal: audio to beat-track; its duration sets the sequence length.
            before_beat_s: window extent before each beat, in seconds.
            after_beat_s: window extent after each beat, in seconds.
            mask_downbeats: mark windows around downbeats.
            mask_upbeats: mark windows around the beats that are not downbeats.
            downbeat_downsample_factor: keep every N-th downbeat (None means 1).
            beat_downsample_factor: keep every N-th non-downbeat beat (None means 1).
            dropout: per-position drop probability inside each window.
            invert: flip 0s and 1s before returning.

        Returns:
            Long tensor of shape (1, n_codebooks, seq_len), where n_codebooks
            comes from the c2f model when loaded, else the coarse model.

        Raises:
            ValueError: if a downsample factor is given and is < 1.
        """
        assert self.beat_tracker is not None, "No beat tracker loaded"

        # get the beat times and convert them to token indices
        beats, downbeats = self.beat_tracker.extract_beats(signal)
        beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)

        # remove downbeats from beats so the two groups are disjoint
        beats_t = torch.tensor(beats_z)
        beats_z = beats_t[~torch.isin(beats_t, torch.tensor(downbeats_z))].tolist()
        downbeats_z = downbeats_z.tolist()

        seq_len = self.s2t(signal.duration)
        mask = torch.zeros(seq_len, device=self.device)

        mask_b4 = self.s2t(before_beat_s)
        mask_after = self.s2t(after_beat_s)

        if beat_downsample_factor is None:
            beat_downsample_factor = 1
        elif beat_downsample_factor < 1:
            raise ValueError("beat_downsample_factor must be >= 1 or None")

        if downbeat_downsample_factor is None:
            downbeat_downsample_factor = 1
        elif downbeat_downsample_factor < 1:
            raise ValueError("downbeat_downsample_factor must be >= 1 or None")

        beats_z = beats_z[::beat_downsample_factor]
        downbeats_z = downbeats_z[::downbeat_downsample_factor]
        logging.debug(f"beats_z: {len(beats_z)}")
        logging.debug(f"downbeats_z: {len(downbeats_z)}")

        # upbeats first, then downbeats: later windows overwrite earlier ones
        if mask_upbeats:
            self._fill_beat_windows(mask, beats_z, mask_b4, mask_after, dropout)
        if mask_downbeats:
            self._fill_beat_windows(mask, downbeats_z, mask_b4, mask_after, dropout)

        mask = mask.clamp(0, 1)
        if invert:
            mask = 1 - mask

        mask = mask[None, None, :].bool().long()
        if self.c2f is not None:
            mask = mask.repeat(1, self.c2f.n_codebooks, 1)
        else:
            mask = mask.repeat(1, self.coarse.n_codebooks, 1)
        return mask

    def set_chunk_size(self, chunk_size_s: float):
        """Set the coarse model's generation chunk length, in seconds."""
        self.coarse.chunk_size_s = chunk_size_s

    @torch.inference_mode()
    def coarse_to_fine(
        self,
        z: torch.Tensor,
        mask: torch.Tensor = None,
        return_mask: bool = False,
        **kwargs
    ):
        """Generate the fine codebooks for ``z`` with the c2f model, chunk by chunk.

        ``z`` (and ``mask``) are right-padded to a whole number of c2f chunks;
        padded mask positions are 1. Missing codebooks are appended as zeros,
        and the mask is forced to 0 on the c2f conditioning codebooks. The
        result is cropped back to the input length.

        Args:
            z: codes, possibly with fewer codebooks than the c2f model uses.
            mask: optional mask aligned with ``z``; not modified.
            return_mask: also return ``fine_z`` with ``mask`` applied.
            **kwargs: forwarded to ``self.c2f.generate``.

        Raises:
            ValueError: if ``return_mask`` is set without a ``mask``.
        """
        assert self.c2f is not None, "No coarse2fine model loaded"
        if return_mask and mask is None:
            raise ValueError("return_mask=True requires a mask")

        length = z.shape[-1]
        chunk_len = self.s2t(self.c2f.chunk_size_s)
        n_chunks = math.ceil(length / chunk_len)

        # zero pad to a multiple of chunk_len
        if length % chunk_len != 0:
            pad_len = chunk_len - (length % chunk_len)
            z = torch.nn.functional.pad(z, (0, pad_len))
            mask = torch.nn.functional.pad(mask, (0, pad_len), value=1) if mask is not None else None

        n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
        if n_codebooks_to_append > 0:
            z = torch.cat([
                z,
                torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
            ], dim=1)
            logging.debug(f"appended {n_codebooks_to_append} codebooks to z")

        # set the mask to 0 for all conditioning codebooks (clone: never touch
        # the caller's tensor)
        if mask is not None:
            mask = mask.clone()
            mask[:, :self.c2f.n_conditioning_codebooks, :] = 0

        fine_z = []
        for i in range(n_chunks):
            window = slice(i * chunk_len, (i + 1) * chunk_len)
            chunk = z[:, :, window]
            mask_chunk = mask[:, :, window] if mask is not None else None

            with torch.autocast("cuda", dtype=torch.bfloat16):
                chunk = self.c2f.generate(
                    codec=self.codec,
                    time_steps=chunk_len,
                    start_tokens=chunk,
                    return_signal=False,
                    mask=mask_chunk,
                    cfg_guidance=None,
                    **kwargs
                )
                fine_z.append(chunk)

        fine_z = torch.cat(fine_z, dim=-1)
        if return_mask:
            masked_z = apply_mask(fine_z, mask, self.c2f.mask_token)[0]
            return fine_z[:, :, :length].clone(), masked_z[:, :, :length].clone()

        return fine_z[:, :, :length].clone()

    @torch.inference_mode()
    def coarse_vamp(
        self,
        z,
        mask,
        return_mask=False,
        gen_fn=None,
        **kwargs
    ):
        """Regenerate the masked coarse tokens of ``z`` chunk by chunk.

        The coarse codebooks are cut into chunks of ``coarse.chunk_size_s``
        seconds (the last may be shorter). In every chunk whose mask already
        contains a 0, the first and last timesteps are forced to 0 so the
        stitched chunks do not show seams. Each chunk is masked with
        ``apply_mask`` and passed to ``gen_fn`` (default
        ``self.coarse.generate``). The fine codebooks of ``z`` are re-attached
        unchanged.

        Returns:
            The vamped codes, or ``(codes, masked coarse codes)`` when
            ``return_mask`` is set. Neither ``z`` nor ``mask`` is modified.
        """
        n_coarse = self.coarse.n_codebooks
        # coarse z
        cz = z[:, :n_coarse, :].clone()
        # clone so the chunk-edge writes below never reach the caller's mask:
        # previously they leaked back to the caller (e.g. into the mask ``vamp``
        # hands to ``coarse_to_fine``) and failed outright on a mask produced
        # by ``expand(batch_size, ...)`` with batch_size > 1
        mask = mask[:, :n_coarse, :].clone()

        gen_fn = gen_fn or self.coarse.generate

        chunk_len = self.s2t(self.coarse.chunk_size_s)
        n_chunks = math.ceil(cz.shape[-1] / chunk_len)

        cz_masked_chunks = []
        cz_vamped_chunks = []
        for i in range(n_chunks):
            window = slice(i * chunk_len, (i + 1) * chunk_len)
            chunk = cz[:, :, window]
            mask_chunk = mask[:, :, window]

            # keep the first and last timestep of the chunk (mask 0) to avoid a
            # discontinuity when the chunks are stitched back together, but
            # only if the chunk already keeps something
            if torch.any(mask_chunk == 0):
                mask_chunk[:, :, 0] = 0
                mask_chunk[:, :, -1] = 0

            cz_masked_chunk, mask_chunk = apply_mask(chunk, mask_chunk, self.coarse.mask_token)
            cz_masked_chunk = cz_masked_chunk[:, :n_coarse, :]
            cz_masked_chunks.append(cz_masked_chunk)

            with torch.autocast("cuda", dtype=torch.bfloat16):
                c_vamp_chunk = gen_fn(
                    codec=self.codec,
                    time_steps=chunk_len,
                    start_tokens=cz_masked_chunk,
                    return_signal=False,
                    mask=mask_chunk,
                    **kwargs
                )
                cz_vamped_chunks.append(c_vamp_chunk)

        # stitch the chunks back together
        cz_masked = torch.cat(cz_masked_chunks, dim=-1)
        c_vamp = torch.cat(cz_vamped_chunks, dim=-1)

        # add the fine codes back in
        c_vamp = torch.cat([c_vamp, z[:, n_coarse:, :]], dim=1)

        if return_mask:
            return c_vamp, cz_masked

        return c_vamp

    def build_mask(self,
        z: torch.Tensor,
        sig: AudioSignal = None,
        rand_mask_intensity: float = 1.0,
        prefix_s: float = 0.0,
        suffix_s: float = 0.0,
        periodic_prompt: int = 7,
        periodic_prompt_width: int = 1,
        onset_mask_width: int = 0,
        _dropout: float = 0.0,
        upper_codebook_mask: int = 3,
        ncc: int = 0,
    ):
        """Compose a mask for ``z`` from the helpers in ``.mask``.

        Combines (with ``mask_and``) a linear random mask, a prefix/suffix
        inpainting mask, a periodic prompt and, when ``onset_mask_width > 0``,
        an onset mask computed from ``sig``; then applies dropout,
        ``codebook_unmask`` with ``ncc`` and ``codebook_mask`` at
        ``upper_codebook_mask``.
        """
        mask = linear_random(z, rand_mask_intensity)
        mask = mask_and(
            mask,
            inpaint(z, self.s2t(prefix_s), self.s2t(suffix_s)),
        )

        pmask = periodic_mask(z, periodic_prompt, periodic_prompt_width, random_roll=True)
        mask = mask_and(mask, pmask)

        if onset_mask_width > 0:
            assert sig is not None, f"must provide a signal to use onset mask"
            mask = mask_and(
                mask, onset_mask(
                    sig, z, self,
                    width=onset_mask_width
                )
            )

        mask = dropout(mask, _dropout)
        mask = codebook_unmask(mask, ncc)

        mask = codebook_mask(mask, int(upper_codebook_mask), None)
        return mask

    def vamp(
        self,
        codes: torch.Tensor,
        mask: torch.Tensor,
        batch_size: int = 1,
        feedback_steps: int = 1,
        time_stretch_factor: int = 1,
        return_mask: bool = False,
        **kwargs,
    ):
        """Full generation pass: coarse vamp followed by coarse-to-fine.

        Args:
            codes: input codes; expanded along dim 0 to ``batch_size``.
            mask: mask aligned with ``codes``; expanded the same way.
            batch_size: number of variations generated in one pass.
            feedback_steps: accepted for API compatibility; not used here.
            time_stretch_factor: repeat every timestep this many times and
                mask (set to 1) all repeats except the first.
            return_mask: also return the masked codes (on CPU).
            **kwargs: forwarded to ``coarse_vamp``.

        Requires a c2f model (``coarse_to_fine`` asserts one is loaded).
        """
        z = codes

        # expand z to batch size
        z = z.expand(batch_size, -1, -1)
        mask = mask.expand(batch_size, -1, -1)

        # stretch mask and z to match the time stretch factor:
        # every timestep is repeated, and all repeats but the first are masked
        if time_stretch_factor > 1:
            z = z.repeat_interleave(time_stretch_factor, dim=-1)
            mask = mask.repeat_interleave(time_stretch_factor, dim=-1)
            added_mask = torch.ones_like(mask)
            added_mask[:, :, ::time_stretch_factor] = 0
            mask = mask.bool() | added_mask.bool()
            mask = mask.long()

        # the forward pass
        logging.debug(z.shape)
        logging.debug("coarse!")
        zv, mask_z = self.coarse_vamp(
            z,
            mask=mask,
            return_mask=True,
            **kwargs
        )

        # add the top codebooks back in if coarse_vamp returned fewer
        if zv.shape[1] < z.shape[1]:
            logging.debug(f"adding {z.shape[1] - zv.shape[1]} codebooks back in")
            zv = torch.cat(
                [zv, z[:, self.coarse.n_codebooks:, :]],
                dim=1
            )

        # now, coarse2fine
        logging.debug(f"coarse2fine!")
        zv, fine_zv_mask = self.coarse_to_fine(
            zv,
            mask=mask,
            typical_filtering=True,
            _sampling_steps=[2],
            return_mask=True
        )
        # masked view: coarse codebooks from the coarse pass, fine ones from c2f
        mask_z = torch.cat(
            [mask_z[:, :self.coarse.n_codebooks, :], fine_zv_mask[:, self.coarse.n_codebooks:, :]],
            dim=1
        )

        z = zv

        if return_mask:
            return z, mask_z.cpu()
        return z

    def visualize_codes(self, z: torch.Tensor):
        """Plot the codes of the first batch item with matplotlib.

        Returns:
            The matplotlib figure (previously it was built and discarded).
        """
        import matplotlib.pyplot as plt

        fig = plt.figure(figsize=(10, 7))
        # top subplot: the codes of z[0]
        fig.add_subplot(2, 1, 1)
        plt.imshow(z[0].cpu().numpy(), aspect='auto', origin='lower', cmap="tab20")
        plt.title("codes")
        plt.ylabel("codebook index")
        # TODO: set the xticks to seconds
        return fig
if __name__ == "__main__":
    # Smoke-test script: encode an example file, build a mask, run coarse
    # generation and coarse-to-fine, then decode the masked prompt and result.
    import audiotools as at

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # was `torch.set_logging.debugoptions`, which does not exist
    torch.set_printoptions(threshold=10000)
    at.util.seed(42)

    interface = Interface(
        coarse_ckpt="./models/vampnet/coarse.pth",
        coarse2fine_ckpt="./models/vampnet/c2f.pth",
        codec_ckpt="./models/vampnet/codec.pth",
        device="cuda" if torch.cuda.is_available() else "cpu",
        wavebeat_ckpt="./models/wavebeat.pth",
    )

    sig = at.AudioSignal('assets/example.wav')

    z = interface.encode(sig)

    # build_mask does not accept `periodic_prompt2` / `upper_codebook_mask_2`;
    # passing them raised TypeError, so they are omitted
    mask = interface.build_mask(
        z=z,
        sig=sig,
        rand_mask_intensity=1.0,
        prefix_s=0.0,
        suffix_s=0.0,
        periodic_prompt=7,
        periodic_prompt_width=1,
        onset_mask_width=5,
        _dropout=0.0,
        upper_codebook_mask=3,
        ncc=0,
    )

    zv, mask_z = interface.coarse_vamp(
        z,
        mask=mask,
        return_mask=True,
        gen_fn=interface.coarse.generate
    )

    use_coarse2fine = True
    if use_coarse2fine:
        zv = interface.coarse_to_fine(zv, mask=mask)

    mask = interface.decode(mask_z).cpu()
    sig = interface.decode(zv).cpu()

    # INFO, not DEBUG: the logger level above is INFO
    logging.info("done")