import logging
import os
from argparse import ArgumentParser
from pathlib import Path

import pandas as pd
import tensordict as td
import torch
import torch.distributed as distributed
import torch.nn.functional as F
from open_clip import create_model_from_pretrained
from torch.utils.data import DataLoader
from tqdm import tqdm

from mmaudio.data.data_setup import error_avoidance_collate
from mmaudio.data.extraction.wav_dataset import WavTextClipsDataset
from mmaudio.ext.autoencoder import AutoEncoderModule
from mmaudio.ext.mel_converter import get_mel_converter

log = logging.getLogger()

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])

# 16k
SAMPLE_RATE = 16_000
NUM_SAMPLES = 16_000 * 8
tod_vae_ckpt = './ext_weights/v1-16.pth'
bigvgan_vocoder_ckpt = './ext_weights/best_netG.pt'
mode = '16k'

# 44k
"""
NOTE: 352800 (8*44100) is not divisible by (STFT hop size * VAE downsampling ratio),
which is 1024. 353280 is the next integer divisible by 1024.
"""
# SAMPLE_RATE = 44100
# NUM_SAMPLES = 353280
# tod_vae_ckpt = './ext_weights/v1-44.pth'
# bigvgan_vocoder_ckpt = None
# mode = '44k'


def distributed_setup():
    distributed.init_process_group(backend="nccl")
    local_rank = distributed.get_rank()
    world_size = distributed.get_world_size()
    print(f'Initialized: local_rank={local_rank}, world_size={world_size}')
    return local_rank, world_size


@torch.inference_mode()
def main():
    distributed_setup()

    parser = ArgumentParser()
    parser.add_argument('--data_dir', type=Path, default='./training/example_audios/')
    parser.add_argument('--captions_tsv', type=Path, default='./training/example_audio.tsv')
    parser.add_argument('--clips_tsv', type=Path, default='./training/example_output/clips.tsv')
    parser.add_argument('--latent_dir', type=Path, default='./training/example_output/audio-latents')
    parser.add_argument('--output_dir', type=Path, default='./training/example_output/memmap/audio-example')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_workers', type=int, default=8)
    args = parser.parse_args()

    data_dir = args.data_dir
    captions_tsv = args.captions_tsv
    clips_tsv = args.clips_tsv
    latent_dir = args.latent_dir
    output_dir = args.output_dir
    batch_size = args.batch_size
    num_workers = args.num_workers

    clip_model = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384',
                                              return_transform=False).eval().cuda()

    # a hack to make it output last hidden states instead of pooled embeddings
    def new_encode_text(self, text, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.to(cast_dtype)
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        return F.normalize(x, dim=-1) if normalize else x

    clip_model.encode_text = new_encode_text.__get__(clip_model)

    # VAE (plus vocoder checkpoint in 16k mode) used to encode mel spectrograms into latents
    tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
                            vocoder_ckpt_path=bigvgan_vocoder_ckpt,
                            mode=mode).eval().cuda()
    mel_converter = get_mel_converter(mode).eval().cuda()

    dataset = WavTextClipsDataset(data_dir,
                                  captions_tsv=captions_tsv,
                                  clips_tsv=clips_tsv,
                                  sample_rate=SAMPLE_RATE,
                                  num_samples=NUM_SAMPLES,
                                  normalize_audio=True,
                                  reject_silent=True)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers,
                            collate_fn=error_avoidance_collate)

    latent_dir.mkdir(exist_ok=True, parents=True)

    # extraction
    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        ids = batch['id']
        waveforms = batch['waveform'].cuda()
        tokens = batch['tokens'].cuda()

        text_features = clip_model.encode_text(tokens, normalize=True)

        mel = mel_converter(waveforms)
        dist = tod.encode(mel)

        a_mean = dist.mean.detach().cpu().transpose(1, 2)
        a_std = dist.std.detach().cpu().transpose(1, 2)
        text_features = text_features.detach().cpu()

        ids = list(ids)
        captions = list(batch['caption'])

        data = {
            'id': ids,
            'caption': captions,
            'mean': a_mean,
            'std': a_std,
            'text_features': text_features,
        }

        # one shard per (rank, batch); rank 0 merges them after the barrier
        torch.save(data, latent_dir / f'r{local_rank}_{i:05d}.pth')

    distributed.barrier()

    # combine the results
    if local_rank == 0:
        print('Extraction done. Combining the results.')
        list_of_ids_and_labels = []
        output_data = {
            'mean': [],
            'std': [],
            'text_features': [],
        }

        latents = sorted(os.listdir(latent_dir))
        latents = [l for l in latents if l.endswith('.pth')]
        for t in tqdm(latents):
            data = torch.load(latent_dir / t, weights_only=True)
            bs = len(data['id'])

            for bi in range(bs):
                this_id = data['id'][bi]
                this_caption = data['caption'][bi]
                list_of_ids_and_labels.append({'id': this_id, 'caption': this_caption})
                output_data['mean'].append(data['mean'][bi])
                output_data['std'].append(data['std'][bi])
                output_data['text_features'].append(data['text_features'][bi])

        output_df = pd.DataFrame(list_of_ids_and_labels)
        output_dir.mkdir(exist_ok=True, parents=True)
        output_name = output_dir.stem
        output_df.to_csv(output_dir.parent / f'{output_name}.tsv', sep='\t', index=False)

        print(f'Output: {len(output_df)}')

        # stack per-clip tensors and write a memory-mapped TensorDict for training
        output_data = {k: torch.stack(v) for k, v in output_data.items()}
        td.TensorDict(output_data).memmap_(output_dir)


if __name__ == '__main__':
    main()
    distributed.destroy_process_group()
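
# Usage sketch (assumption: this file is saved as training/extract_audio_training_latents.py;
# adjust the path and --nproc_per_node to your setup). torchrun sets the LOCAL_RANK and
# WORLD_SIZE environment variables that this script reads at import time:
#
#   torchrun --standalone --nproc_per_node=1 training/extract_audio_training_latents.py \
#       --data_dir ./training/example_audios/ \
#       --captions_tsv ./training/example_audio.tsv \
#       --output_dir ./training/example_output/memmap/audio-example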