import os
from pathlib import Path
import math

import torch
import numpy as np
from audiotools import AudioSignal
import tqdm

from .modules.transformer import VampNet
from .beats import WaveBeat
from .mask import *

# from dac.model.dac import DAC
from lac.model.lac import LAC as DAC


def signal_concat(
    audio_signals: list,
):
    """Concatenate AudioSignals end-to-end along the last axis of ``audio_data``.

    The result takes the sample rate of the first signal; the inputs are
    assumed to share that sample rate (this is not checked).

    Raises:
        ValueError: if ``audio_signals`` is empty.
    """
    if not audio_signals:
        raise ValueError("signal_concat needs at least one AudioSignal")

    audio_data = torch.cat([x.audio_data for x in audio_signals], dim=-1)

    return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)


def _load_model(
    ckpt: str,
    lora_ckpt: str = None,
    device: str = "cpu",
    chunk_size_s: int = 10,
):
    """Load a VampNet checkpoint, optionally overlay LoRA weights, and prepare it.

    Args:
        ckpt: path to the base VampNet checkpoint.
        lora_ckpt: optional path to a state dict loaded on top of the base
            weights with ``strict=False``. If the path does not exist, the
            user is asked interactively whether to continue without it.
        device: device the model is moved to after loading.
        chunk_size_s: stored on the model as ``chunk_size_s``; ``Interface``
            reads it to size its generation windows.

    Returns:
        The model, in eval mode, on ``device``.

    Raises:
        Exception: if the LoRA checkpoint is missing and the user does not
            answer "y".
    """
    # strict=False because the checkpoint may lack LoRA weights that are
    # added afterwards
    model = VampNet.load(location=Path(ckpt), map_location="cpu", strict=False)

    if lora_ckpt is not None:
        if Path(lora_ckpt).exists():
            # NOTE(review): torch.load unpickles; only load trusted checkpoints.
            model.load_state_dict(torch.load(lora_ckpt, map_location="cpu"), strict=False)
        else:
            should_cont = input(
                f"lora checkpoint {lora_ckpt} does not exist. continue? (y/n) "
            )
            if should_cont != "y":
                raise Exception("aborting")

    model.to(device)
    model.eval()
    model.chunk_size_s = chunk_size_s
    return model


class Interface(torch.nn.Module):
    """Bundle of codec, coarse VampNet, optional coarse-to-fine VampNet and
    optional WaveBeat tracker, with helpers to encode audio, build masks and
    run generation."""

    def __init__(
        self,
        coarse_ckpt: str = None,
        coarse_lora_ckpt: str = None,
        coarse2fine_ckpt: str = None,
        coarse2fine_lora_ckpt: str = None,
        codec_ckpt: str = None,
        wavebeat_ckpt: str = None,
        device: str = "cpu",
        coarse_chunk_size_s: int = 10,
        coarse2fine_chunk_size_s: int = 3,
    ):
        """Load every component onto ``device``.

        ``codec_ckpt`` and ``coarse_ckpt`` are required. ``coarse2fine_ckpt``
        and ``wavebeat_ckpt`` are optional; when omitted, ``self.c2f`` and
        ``self.beat_tracker`` are set to None.
        """
        super().__init__()
        assert codec_ckpt is not None, "must provide a codec checkpoint"
        self.codec = DAC.load(Path(codec_ckpt))
        self.codec.eval()
        self.codec.to(device)

        assert coarse_ckpt is not None, "must provide a coarse checkpoint"
        self.coarse = _load_model(
            ckpt=coarse_ckpt,
            lora_ckpt=coarse_lora_ckpt,
            device=device,
            chunk_size_s=coarse_chunk_size_s,
        )

        # the coarse-to-fine model is optional
        if coarse2fine_ckpt is not None:
            self.c2f = _load_model(
                ckpt=coarse2fine_ckpt,
                lora_ckpt=coarse2fine_lora_ckpt,
                device=device,
                chunk_size_s=coarse2fine_chunk_size_s,
            )
        else:
            self.c2f = None

        if wavebeat_ckpt is not None:
            print(f"loading wavebeat from {wavebeat_ckpt}")
            self.beat_tracker = WaveBeat(wavebeat_ckpt)
            self.beat_tracker.model.to(device)
        else:
            self.beat_tracker = None

        self.device = device

    def lora_load(
        self,
        coarse_ckpt: str = None,
        c2f_ckpt: str = None,
        full_ckpts: bool = False,
    ):
        """Swap in new weights for the coarse and/or coarse-to-fine model.

        Args:
            coarse_ckpt: checkpoint for the coarse model, or None to keep it.
            c2f_ckpt: checkpoint for the coarse-to-fine model, or None to keep
                it. Requires a c2f model to already be loaded.
            full_ckpts: if True, each checkpoint is a full VampNet checkpoint
                and the model is rebuilt with ``_load_model``, keeping its
                current ``chunk_size_s``. Otherwise, each checkpoint is a state
                dict loaded into the existing model with ``strict=False``.
        """
        if c2f_ckpt is not None:
            # both branches below read from the existing c2f model
            assert self.c2f is not None, "No coarse2fine model loaded"

        if full_ckpts:
            if coarse_ckpt is not None:
                self.coarse = _load_model(
                    ckpt=coarse_ckpt,
                    device=self.device,
                    chunk_size_s=self.coarse.chunk_size_s,
                )
            if c2f_ckpt is not None:
                self.c2f = _load_model(
                    ckpt=c2f_ckpt,
                    device=self.device,
                    chunk_size_s=self.c2f.chunk_size_s,
                )
            return

        for name, model, ckpt in (
            ("coarse", self.coarse, coarse_ckpt),
            ("c2f", self.c2f, c2f_ckpt),
        ):
            if ckpt is None:
                continue
            # load on CPU, then move back to the working device
            model.to("cpu")
            state_dict = torch.load(ckpt, map_location="cpu")
            print(f"loading {name} from {ckpt}")
            model.load_state_dict(state_dict, strict=False)
            model.to(self.device)

    def s2t(self, seconds: float):
        """Convert seconds to a token count, rounded up.

        numpy arrays are converted element-wise (and stay float arrays);
        scalars return an int.
        """
        if isinstance(seconds, np.ndarray):
            return np.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
        else:
            return math.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)

    def s2t2s(self, seconds: float):
        """Round ``seconds`` up to the duration of a whole number of tokens."""
        return self.t2s(self.s2t(seconds))

    def t2s(self, tokens: int):
        """Convert a token count to seconds."""
        return tokens * self.codec.hop_length / self.codec.sample_rate

    def to(self, device):
        """Move every loaded sub-model to ``device``, record it, and return self."""
        self.device = device
        self.coarse.to(device)
        self.codec.to(device)

        if self.c2f is not None:
            self.c2f.to(device)

        if self.beat_tracker is not None:
            self.beat_tracker.model.to(device)
        return self

    def to_signal(self, z: torch.Tensor):
        """Decode tokens ``z`` to audio via the coarse model and the codec."""
        return self.coarse.to_signal(z, self.codec)

    def preprocess(self, signal: AudioSignal):
        """Return a copy of ``signal`` resampled to the codec rate, mono,
        normalized to -24 (presumably LUFS), with peak clipped to at most 1.0."""
        signal = (
            signal.clone()
            .resample(self.codec.sample_rate)
            .to_mono()
            .normalize(-24)
            .ensure_max_of_audio(1.0)
        )
        return signal

    @torch.inference_mode()
    def encode(self, signal: AudioSignal):
        """Preprocess ``signal`` and return the ``"codes"`` produced by the codec."""
        signal = self.preprocess(signal).to(self.device)
        z = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
        return z

    def snap_to_beats(
        self,
        signal: AudioSignal
    ):
        """Trim ``signal`` to the span between its first and last detected beat.

        Beat times from the tracker are multiplied by the sample rate, so they
        are presumably in seconds.
        """
        assert self.beat_tracker is not None, "No beat tracker loaded"
        beats, downbeats = self.beat_tracker.extract_beats(signal)

        # trim the signal to the first..last beat
        samples_begin = int(beats[0] * signal.sample_rate)
        samples_end = int(beats[-1] * signal.sample_rate)
        print(beats[0])
        signal = signal.clone().trim(samples_begin, signal.length - samples_end)

        return signal

    @staticmethod
    def _resolve_downsample_factor(factor, name):
        """Return ``factor`` (1 when None) after checking that it is >= 1."""
        if factor is None:
            return 1
        if factor < 1:
            raise ValueError(f"{name} must be >= 1 or None")
        return factor

    def _fill_beat_windows(self, mask, beat_positions, n_before, n_after, dropout):
        """Write dropout-thinned ones into ``mask`` around each beat, in place."""
        for beat_idx in beat_positions:
            # Clamp the start at 0. A negative start would wrap around to the
            # end of the tensor and skip or misplace the window.
            start = max(int(beat_idx - n_before), 0)
            end = int(beat_idx + n_after)
            num_steps = mask[start:end].shape[0]
            _m = torch.ones(num_steps, device=mask.device)
            _m_mask = torch.bernoulli(_m * (1 - dropout))
            mask[start:end] = _m * _m_mask.long()

    def make_beat_mask(self,
            signal: AudioSignal,
            before_beat_s: float = 0.0,
            after_beat_s: float = 0.02,
            mask_downbeats: bool = True,
            mask_upbeats: bool = True,
            downbeat_downsample_factor: int = None,
            beat_downsample_factor: int = None,
            dropout: float = 0.0,
            invert: bool = True,
    ):
        """Make a beat-synced mask.

        Places 1s in a window of ``before_beat_s`` before to ``after_beat_s``
        after each (down)beat and 0s elsewhere. Each 1 is then independently
        zeroed with probability ``dropout``. With ``invert`` (the default),
        the mask is flipped so the beat windows are 0 and everything else is 1.

        Args:
            signal: audio to beat-track.
            before_beat_s, after_beat_s: window extent around each beat, in seconds.
            mask_downbeats, mask_upbeats: which beat kinds to mark. "Upbeats"
                are beats that are not also downbeats.
            downbeat_downsample_factor, beat_downsample_factor: keep every
                n-th (down)beat; None means 1.
            dropout: probability of dropping each marked step.
            invert: flip the mask after marking.

        Returns:
            A long tensor of shape (1, n_codebooks, seq_len), using the c2f
            model's codebook count when one is loaded, else the coarse model's.

        Raises:
            ValueError: if a downsample factor is < 1.
        """
        assert self.beat_tracker is not None, "No beat tracker loaded"

        beats, downbeats = self.beat_tracker.extract_beats(signal)

        # beat positions in token steps
        beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)

        # remove downbeats from beats
        beats_z = torch.tensor(beats_z)[~torch.isin(torch.tensor(beats_z), torch.tensor(downbeats_z))]
        beats_z = beats_z.tolist()
        downbeats_z = downbeats_z.tolist()

        seq_len = self.s2t(signal.duration)
        mask = torch.zeros(seq_len, device=self.device)

        mask_b4 = self.s2t(before_beat_s)
        mask_after = self.s2t(after_beat_s)

        beat_downsample_factor = self._resolve_downsample_factor(
            beat_downsample_factor, "beat_downsample_factor"
        )
        downbeat_downsample_factor = self._resolve_downsample_factor(
            downbeat_downsample_factor, "downbeat_downsample_factor"
        )

        beats_z = beats_z[::beat_downsample_factor]
        downbeats_z = downbeats_z[::downbeat_downsample_factor]
        print(f"beats_z: {len(beats_z)}")
        print(f"downbeats_z: {len(downbeats_z)}")

        if mask_upbeats:
            self._fill_beat_windows(mask, beats_z, mask_b4, mask_after, dropout)

        if mask_downbeats:
            self._fill_beat_windows(mask, downbeats_z, mask_b4, mask_after, dropout)

        mask = mask.clamp(0, 1)
        if invert:
            mask = 1 - mask

        mask = mask[None, None, :].bool().long()
        n_codebooks = self.c2f.n_codebooks if self.c2f is not None else self.coarse.n_codebooks
        mask = mask.repeat(1, n_codebooks, 1)
        return mask

    def coarse_to_fine(
        self,
        z: torch.Tensor,
        mask: torch.Tensor = None,
        **kwargs
    ):
        """Fill in fine codebooks with the c2f model, chunk by chunk.

        ``z`` (and ``mask``, if given) are zero-padded in time to a multiple
        of the c2f chunk length. ``z`` is zero-extended to the c2f codebook
        count. Each chunk is generated independently via ``self.c2f.generate``,
        which receives ``kwargs``. The output is trimmed back to the original
        length.
        """
        assert self.c2f is not None, "No coarse2fine model loaded"
        length = z.shape[-1]
        chunk_len = self.s2t(self.c2f.chunk_size_s)
        n_chunks = math.ceil(length / chunk_len)

        # zero pad to a multiple of chunk_len
        if length % chunk_len != 0:
            pad_len = chunk_len - (length % chunk_len)
            z = torch.nn.functional.pad(z, (0, pad_len))
            mask = torch.nn.functional.pad(mask, (0, pad_len)) if mask is not None else None

        n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
        if n_codebooks_to_append > 0:
            # create on z's device so the concat cannot hit a device mismatch
            z = torch.cat([
                z,
                torch.zeros(
                    z.shape[0], n_codebooks_to_append, z.shape[-1],
                    dtype=torch.long, device=z.device,
                )
            ], dim=1)

        # set the mask to 0 for all conditioning codebooks (presumably so they
        # are kept rather than regenerated)
        if mask is not None:
            mask = mask.clone()
            mask[:, :self.c2f.n_conditioning_codebooks, :] = 0

        fine_z = []
        for i in range(n_chunks):
            window = slice(i * chunk_len, (i + 1) * chunk_len)
            chunk = z[:, :, window]
            mask_chunk = mask[:, :, window] if mask is not None else None

            chunk = self.c2f.generate(
                codec=self.codec,
                time_steps=chunk_len,
                start_tokens=chunk,
                return_signal=False,
                mask=mask_chunk,
                **kwargs
            )
            fine_z.append(chunk)

        fine_z = torch.cat(fine_z, dim=-1)
        return fine_z[:, :, :length].clone()

    def coarse_vamp(
        self,
        z,
        mask,
        return_mask=False,
        gen_fn=None,
        **kwargs
    ):
        """Regenerate the coarse codebooks of ``z`` under ``mask``.

        Codebooks of ``z`` beyond the coarse model's count are appended back
        unchanged. ``gen_fn`` defaults to ``self.coarse.generate`` and receives
        ``kwargs``. When ``return_mask`` is set, the masked coarse tokens are
        returned as well.
        """
        # coarse z
        cz = z[:, : self.coarse.n_codebooks, :].clone()
        assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"

        mask = mask[:, : self.coarse.n_codebooks, :]

        cz_masked, mask = apply_mask(cz, mask, self.coarse.mask_token)
        cz_masked = cz_masked[:, : self.coarse.n_codebooks, :]

        gen_fn = gen_fn or self.coarse.generate
        c_vamp = gen_fn(
            codec=self.codec,
            time_steps=cz.shape[-1],
            start_tokens=cz,
            mask=mask,
            return_signal=False,
            **kwargs
        )

        # add the fine codes back in
        c_vamp = torch.cat(
            [c_vamp, z[:, self.coarse.n_codebooks :, :]],
            dim=1
        )

        if return_mask:
            return c_vamp, cz_masked

        return c_vamp


if __name__ == "__main__":
    import audiotools as at
    import logging
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    torch.set_printoptions(threshold=10000)
    at.util.seed(42)

    interface = Interface(
        coarse_ckpt="./models/vampnet/coarse.pth",
        coarse2fine_ckpt="./models/vampnet/c2f.pth",
        codec_ckpt="./models/vampnet/codec.pth",
        device="cuda",
        wavebeat_ckpt="./models/wavebeat.pth"
    )

    sig = at.AudioSignal('assets/example.wav')

    z = interface.encode(sig)

    # Alternative masks to try:
    # mask = linear_random(z, 1.0)
    # mask = mask_and(
    #     mask, periodic_mask(
    #         z,
    #         32,
    #         1,
    #         random_roll=True
    #     )
    # )
    # mask = interface.make_beat_mask(
    #     sig, 0.0, 0.075
    # )
    # mask = dropout(mask, 0.0)
    # mask = codebook_unmask(mask, 0)

    mask = inpaint(z, n_prefix=100, n_suffix=100)

    zv, mask_z = interface.coarse_vamp(
        z,
        mask=mask,
        sampling_steps=36,
        temperature=8.0,
        return_mask=True,
        gen_fn=interface.coarse.generate
    )

    use_coarse2fine = True
    if use_coarse2fine:
        zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask)

    mask = interface.to_signal(mask_z).cpu()

    sig = interface.to_signal(zv).cpu()
    print("done")