diff --git a/app.py b/app.py index a46f1a5ba88875d49487891f9b57f287a6237e85..c75ff99e7b04ee683e184070048451d39f7e88b4 100644 --- a/app.py +++ b/app.py @@ -2,12 +2,12 @@ from huggingface_hub import hf_hub_download import torch import os -os.environ["TOKENIZERS_PARALLELISM"] = "true" - import gradio as gr from audioldm2 import text_to_audio, build_model from share_btn import community_icon_html, loading_icon_html, share_js +os.environ["TOKENIZERS_PARALLELISM"] = "true" + model_id = "haoheliu/audioldm2-full" hf_hub_download(repo_id="haoheliu/audioldm2-full", filename="audioldm2-full.pth") diff --git a/audioldm2/__init__.py b/audioldm2/__init__.py deleted file mode 100755 index 91befda907125b4772601b1df2c9a8a52b733735..0000000000000000000000000000000000000000 --- a/audioldm2/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .utils import seed_everything, save_wave, get_time, get_duration, read_list -from .pipeline import * diff --git a/audioldm2/audiomae_gen/__init__.py b/audioldm2/audiomae_gen/__init__.py deleted file mode 100755 index 7202889ac7aeb7e5b344da994206715cbbb3891e..0000000000000000000000000000000000000000 --- a/audioldm2/audiomae_gen/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .sequence_input import Sequence2AudioMAE diff --git a/audioldm2/audiomae_gen/sequence_input.py b/audioldm2/audiomae_gen/sequence_input.py deleted file mode 100755 index 4d961a0dd7157689fab6291bb3c40d9bd656b5f1..0000000000000000000000000000000000000000 --- a/audioldm2/audiomae_gen/sequence_input.py +++ /dev/null @@ -1,429 +0,0 @@ -import torch -import torch.nn as nn -from audioldm2.latent_diffusion.util import ( - instantiate_from_config, -) - -# from latent_diffusion.modules.encoders.modules import CLAPAudioEmbeddingClassifierFreev2 -from transformers import GPT2Config, GPT2Model -import torch.optim.lr_scheduler as lr_scheduler - -class Sequence2AudioMAE(nn.Module): - def __init__( - self, - base_learning_rate, - sequence_gen_length, - sequence_input_key, - sequence_input_embed_dim, - cond_stage_config, - optimizer_type="AdamW", - use_warmup=True, - use_ar_gen_loss=False, - use_audiomae_linear=False, - target_tokens_mask_ratio=0.0, - random_mask_ratio=False, - **kwargs - ): - super().__init__() - assert use_audiomae_linear == False - self.random_mask_ratio = random_mask_ratio - self.learning_rate = base_learning_rate - self.cond_stage_config = cond_stage_config - self.use_audiomae_linear = use_audiomae_linear - self.optimizer_type = optimizer_type - self.use_warmup = use_warmup - self.use_ar_gen_loss = use_ar_gen_loss - # Even though the LDM can be conditioned on mutliple pooling rate - # Our model always predict the higest pooling rate - - # self.time_pool = max(self.cond_stage_config["crossattn_audiomae_pooled"]["params"]["time_pooling_factors"]) - # self.freq_pool = max(self.cond_stage_config["crossattn_audiomae_pooled"]["params"]["freq_pooling_factors"]) - # self.mae_token_num = int(512/(self.time_pool*self.freq_pool)) - - self.mae_token_num = sequence_gen_length - self.sequence_input_key = sequence_input_key - self.sequence_input_embed_dim = sequence_input_embed_dim - self.target_tokens_mask_ratio = target_tokens_mask_ratio - - self.start_of_sequence_tokens = nn.Embedding(32, 768) - self.end_of_sequence_tokens = nn.Embedding(32, 768) - - self.input_sequence_embed_linear = nn.ModuleList([]) - self.initial_learning_rate = None - - for dim in self.sequence_input_embed_dim: - self.input_sequence_embed_linear.append(nn.Linear(dim, 768)) - - self.cond_stage_models = nn.ModuleList([]) - 
self.instantiate_cond_stage(cond_stage_config) - self.initialize_param_check_toolkit() - - # configuration = GPT2Config(n_layer=1) # TODO - # self.model=GPT2Model(configuration) - ################### - # self.model=nn.Linear(768,768, bias=False) # TODO change the model - # with torch.no_grad(): - # self.model.weight.copy_(torch.eye(768)) - ################### - self.model = GPT2Model(GPT2Config.from_pretrained("gpt2")) - ################### - # self.model = nn.LSTM(input_size=768, hidden_size=768, num_layers=1,bias=False) # TODO - - # self.loss_fn = nn.MSELoss() - self.loss_fn = nn.L1Loss() - - self.logger_save_dir = None - self.logger_exp_name = None - self.logger_exp_group_name = None - self.logger_version = None - - def set_log_dir(self, save_dir, exp_group_name, exp_name): - self.logger_save_dir = save_dir - self.logger_exp_group_name = exp_group_name - self.logger_exp_name = exp_name - - def cfg_uncond(self, batch_size): - unconditional_conditioning = {} - for key in self.cond_stage_model_metadata: - model_idx = self.cond_stage_model_metadata[key]["model_idx"] - unconditional_conditioning[key] = self.cond_stage_models[ - model_idx - ].get_unconditional_condition(batch_size) - assert ( - "crossattn_audiomae_pooled" in unconditional_conditioning.keys() - ), "The module is not initialized with AudioMAE" - unconditional_conditioning[ - "crossattn_clap_to_audiomae_feature" - ] = unconditional_conditioning["crossattn_audiomae_pooled"] - return unconditional_conditioning - - def configure_optimizers(self): - lr = float(self.learning_rate) - # params = list(self.model.parameters()) + list(self.input_sequence_embed_linear.parameters()) - params = list(self.parameters()) - - # opt = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.98), eps=1e-9) - opt = eval(self.optimizer_type)(params, lr=lr) - scheduler = lr_scheduler.StepLR(opt, step_size=10, gamma=0.8) - return [opt], [scheduler] - - def add_sos_eos_tokens(self, _id, sequence, attn_mask): - batchsize = sequence.size(0) - - new_attn_mask_step = torch.ones((batchsize, 1)).to(sequence.device) - key_id = torch.tensor([_id]).to(sequence.device) - - # Add two more steps to attn mask - new_attn_mask = torch.cat( - [new_attn_mask_step, attn_mask, new_attn_mask_step], dim=1 - ) - - # Add two more tokens in the sequence - sos_token = self.start_of_sequence_tokens(key_id).expand(batchsize, 1, -1) - eos_token = self.end_of_sequence_tokens(key_id).expand(batchsize, 1, -1) - new_sequence = torch.cat([sos_token, sequence, eos_token], dim=1) - return new_sequence, new_attn_mask - - def truncate_sequence_and_mask(self, sequence, mask, max_len=512): - if sequence.size(1) > max_len: - print( - "The input sequence length to GPT-2 model is too long:", - sequence.size(1), - ) - return sequence[:, :max_len], mask[:, :max_len] - else: - return sequence, mask - - def get_input_sequence_and_mask(self, cond_dict): - input_embeds = None - input_embeds_attn_mask = None - for _id, sequence_key in enumerate(self.sequence_input_key): - assert sequence_key in cond_dict.keys(), ( - "Invalid sequence key %s" % sequence_key - ) - cond_embed = cond_dict[sequence_key] - if isinstance(cond_embed, list): - assert ( - len(cond_embed) == 2 - ), "The crossattn returned list should have length 2, including embed and attn_mask" - item_input_embeds, item_attn_mask = cond_embed - - item_input_embeds = self.input_sequence_embed_linear[_id]( - item_input_embeds - ) - - item_input_embeds, item_attn_mask = self.add_sos_eos_tokens( - _id, item_input_embeds, item_attn_mask - ) - - if 
input_embeds is None and input_embeds_attn_mask is None: - input_embeds, input_embeds_attn_mask = ( - item_input_embeds, - item_attn_mask, - ) - else: - input_embeds = torch.cat( - [input_embeds, item_input_embeds], dim=1 - ) # The 1-st dimension is time steps - input_embeds_attn_mask = torch.cat( - [input_embeds_attn_mask, item_attn_mask], dim=1 - ) # The 1-st dimension is time steps - else: - assert isinstance(cond_embed, torch.Tensor) - cond_embed = self.input_sequence_embed_linear[_id](cond_embed) - attn_mask = torch.ones((cond_embed.size(0), cond_embed.size(1))).to( - cond_embed.device - ) - - item_input_embeds, item_attn_mask = self.add_sos_eos_tokens( - _id, cond_embed, attn_mask - ) - - if input_embeds is None and input_embeds_attn_mask is None: - input_embeds, input_embeds_attn_mask = ( - item_input_embeds, - item_attn_mask, - ) - else: - input_embeds, input_embeds_attn_mask = torch.cat( - [input_embeds, item_input_embeds], dim=1 - ), torch.cat([input_embeds_attn_mask, item_attn_mask], dim=1) - - assert input_embeds is not None and input_embeds_attn_mask is not None - - input_embeds, input_embeds_attn_mask = self.truncate_sequence_and_mask( - input_embeds, input_embeds_attn_mask, int(1024 - self.mae_token_num) - ) - cond_sequence_end_time_idx = input_embeds.size( - 1 - ) # The index that we start to collect the output embeds - - return input_embeds, input_embeds_attn_mask, cond_sequence_end_time_idx - - def warmup_step(self): - if self.initial_learning_rate is None: - self.initial_learning_rate = float(self.learning_rate) - - # Only the first parameter group - if self.global_step <= 1000: - if self.global_step == 0: - print( - "Warming up learning rate start with %s" - % self.initial_learning_rate - ) - self.trainer.optimizers[0].param_groups[0]["lr"] = ( - self.global_step / 1000 - ) * self.initial_learning_rate - else: - # TODO set learning rate here - self.trainer.optimizers[0].param_groups[0][ - "lr" - ] = self.initial_learning_rate - - def mask_target_sequence(self, target_embeds, target_embeds_attn_mask): - time_seq_mask = None - if self.target_tokens_mask_ratio > 1e-4: - batchsize, time_seq_len, embed_dim = target_embeds.size() - _, time_seq_len = target_embeds_attn_mask.size() - # Generate random mask - if self.random_mask_ratio: - mask_ratio = torch.rand(1).item() * self.target_tokens_mask_ratio - else: - mask_ratio = self.target_tokens_mask_ratio - - time_seq_mask = (torch.rand((batchsize, time_seq_len)) > mask_ratio).to( - target_embeds.device - ) - # Mask the target embedding - target_embeds = target_embeds * time_seq_mask.unsqueeze(-1) - target_embeds_attn_mask = target_embeds_attn_mask * time_seq_mask - return target_embeds, target_embeds_attn_mask, time_seq_mask - - def generate_partial(self, batch, cond_dict=None, no_grad=False): - if cond_dict is None: - cond_dict = self.get_input(batch) - - print("Generate partially prompted audio with in-context learning") - # self.model.train() - # assert self.model.training==True - - target_embeds, target_embeds_attn_mask = ( - cond_dict["crossattn_audiomae_pooled"][0], - cond_dict["crossattn_audiomae_pooled"][1], - ) - - target_time_steps = target_embeds.size(1) - - ( - input_embeds, - input_embeds_attn_mask, - cond_sequence_end_time_idx, - ) = self.get_input_sequence_and_mask(cond_dict) - - model_input = torch.cat( - [input_embeds, target_embeds[:, : target_time_steps // 4, :]], dim=1 - ) - model_input_mask = torch.cat( - [ - input_embeds_attn_mask, - target_embeds_attn_mask[:, : target_time_steps // 4], - ], - dim=1, - ) - 
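- # In-context ("partial") generation: the prompt fed to GPT-2 is the conditioning
- # sequence plus the first quarter of the target AudioMAE tokens; the loop below
- # repeatedly appends the last hidden state as the next token until the remaining
- # three quarters of the mae_token_num target tokens have been generated.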
- steps = self.mae_token_num - - for _ in range(3 * steps // 4): - output = self.model( - inputs_embeds=model_input, attention_mask=model_input_mask - )["last_hidden_state"] - # Update the model input - model_input = torch.cat([model_input, output[:, -1:, :]], dim=1) - # Update the attention mask - attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to( - model_input.device - ) - model_input_mask = torch.cat( - [model_input_mask, attention_mask_new_step], dim=1 - ) - - output = model_input[:, cond_sequence_end_time_idx:] - - return output, cond_dict - - def generate(self, batch, cond_dict=None, no_grad=False): - if cond_dict is None: - cond_dict = self.get_input(batch) - - # self.model.train() - # print("!!!!!!!!!!!!!train") - - ( - input_embeds, - input_embeds_attn_mask, - cond_sequence_end_time_idx, - ) = self.get_input_sequence_and_mask(cond_dict) - model_input = input_embeds - model_input_mask = input_embeds_attn_mask - - steps = self.mae_token_num - - for _ in range(steps): - output = self.model( - inputs_embeds=model_input, attention_mask=model_input_mask - )["last_hidden_state"] - # Update the model input - model_input = torch.cat([model_input, output[:, -1:, :]], dim=1) - # Update the attention mask - attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to( - model_input.device - ) - model_input_mask = torch.cat( - [model_input_mask, attention_mask_new_step], dim=1 - ) - - return model_input[:, cond_sequence_end_time_idx:], cond_dict - - def get_input_item(self, batch, k): - fname, text, waveform, stft, fbank = ( - batch["fname"], - batch["text"], - batch["waveform"], - batch["stft"], - batch["log_mel_spec"], - ) - ret = {} - - ret["fbank"] = ( - fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float() - ) - ret["stft"] = stft.to(memory_format=torch.contiguous_format).float() - # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float() - ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float() - ret["text"] = list(text) - ret["fname"] = fname - - for key in batch.keys(): - if key not in ret.keys(): - ret[key] = batch[key] - - return ret[k] - - def get_input(self, batch): - cond_dict = {} - if len(self.cond_stage_model_metadata.keys()) > 0: - unconditional_cfg = False - - for cond_model_key in self.cond_stage_model_metadata.keys(): - cond_stage_key = self.cond_stage_model_metadata[cond_model_key][ - "cond_stage_key" - ] - - # if(not self.training): - # if(isinstance(self.cond_stage_models[self.cond_stage_model_metadata[cond_model_key]["model_idx"]], CLAPAudioEmbeddingClassifierFreev2)): - # assert cond_stage_key == "text" # CLAP model should use text for evaluation - - # The original data for conditioning - xc = self.get_input_item(batch, cond_stage_key) - if type(xc) == torch.Tensor: - xc = xc.to(self.device) - - c = self.get_learned_conditioning( - xc, key=cond_model_key, unconditional_cfg=unconditional_cfg - ) - cond_dict[cond_model_key] = c - - return cond_dict - - def instantiate_cond_stage(self, config): - self.cond_stage_model_metadata = {} - - for i, cond_model_key in enumerate(config.keys()): - model = instantiate_from_config(config[cond_model_key]) - self.cond_stage_models.append(model) - self.cond_stage_model_metadata[cond_model_key] = { - "model_idx": i, - "cond_stage_key": config[cond_model_key]["cond_stage_key"], - "conditioning_key": config[cond_model_key]["conditioning_key"], - } - - def get_learned_conditioning(self, c, key, unconditional_cfg): - assert key in 
self.cond_stage_model_metadata.keys() - - # Classifier-free guidance - if not unconditional_cfg: - c = self.cond_stage_models[ - self.cond_stage_model_metadata[key]["model_idx"] - ](c) - else: - if isinstance(c, torch.Tensor): - batchsize = c.size(0) - elif isinstance(c, list): - batchsize = len(c) - else: - raise NotImplementedError() - c = self.cond_stage_models[ - self.cond_stage_model_metadata[key]["model_idx"] - ].get_unconditional_condition(batchsize) - - return c - - def initialize_param_check_toolkit(self): - self.tracked_steps = 0 - self.param_dict = {} - - def statistic_require_grad_tensor_number(self, module, name=None): - requires_grad_num = 0 - total_num = 0 - require_grad_tensor = None - for p in module.parameters(): - if p.requires_grad: - requires_grad_num += 1 - if require_grad_tensor is None: - require_grad_tensor = p - total_num += 1 - print( - "Module: [%s] have %s trainable parameters out of %s total parameters (%.2f)" - % (name, requires_grad_num, total_num, requires_grad_num / total_num) - ) - return require_grad_tensor diff --git a/audioldm2/audiomae_gen/utils.py b/audioldm2/audiomae_gen/utils.py deleted file mode 100644 index 841d35adf338647bdf8bd1c31e9f33dee1252b6e..0000000000000000000000000000000000000000 --- a/audioldm2/audiomae_gen/utils.py +++ /dev/null @@ -1,27 +0,0 @@ -import torch.nn as nn - - -class Prenet(nn.Module): - def __init__(self, in_dim, sizes=[256, 128], dropout_rate=0.5): - super(Prenet, self).__init__() - in_sizes = [in_dim] + sizes[:-1] - self.layers = nn.ModuleList( - [ - nn.Linear(in_size, out_size) - for (in_size, out_size) in zip(in_sizes, sizes) - ] - ) - self.relu = nn.ReLU() - self.dropout = nn.Dropout(dropout_rate) - - def forward(self, inputs): - for linear in self.layers: - inputs = self.dropout(self.relu(linear(inputs))) - return inputs - - -if __name__ == "__main__": - model = Prenet(in_dim=128, sizes=[256, 256, 128]) - import ipdb - - ipdb.set_trace() diff --git a/audioldm2/clap/__init__.py b/audioldm2/clap/__init__.py deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/audioldm2/clap/open_clip/__init__.py b/audioldm2/clap/open_clip/__init__.py deleted file mode 100755 index e9f728f2f273be5d5fdbec6c6cc41d737176a8c0..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -from .factory import ( - list_models, - create_model, - create_model_and_transforms, - add_model_config, -) -from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics -from .model import ( - CLAP, - CLAPTextCfg, - CLAPVisionCfg, - CLAPAudioCfp, - convert_weights_to_fp16, - trace_model, -) -from .openai import load_openai_model, list_openai_models -from .pretrained import ( - list_pretrained, - list_pretrained_tag_models, - list_pretrained_model_tags, - get_pretrained_url, - download_pretrained, -) -from .tokenizer import SimpleTokenizer, tokenize -from .transform import image_transform diff --git a/audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz b/audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz deleted file mode 100755 index 36a15856e00a06a9fbed8cdd34d2393fea4a3113..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a -size 1356917 diff --git a/audioldm2/clap/open_clip/factory.py 
b/audioldm2/clap/open_clip/factory.py deleted file mode 100755 index df0f4a194c2e7328f7b7d3fe11fa6801c6cc1a7c..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/factory.py +++ /dev/null @@ -1,276 +0,0 @@ -import json -import logging -import os -import re -from copy import deepcopy -from pathlib import Path - -import torch - -from .model import CLAP, convert_weights_to_fp16 -from .openai import load_openai_model -from .pretrained import get_pretrained_url, download_pretrained -from .transform import image_transform - -_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] -_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs - - -def _natural_key(string_): - return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] - - -def _rescan_model_configs(): - global _MODEL_CONFIGS - - config_ext = (".json",) - config_files = [] - for config_path in _MODEL_CONFIG_PATHS: - if config_path.is_file() and config_path.suffix in config_ext: - config_files.append(config_path) - elif config_path.is_dir(): - for ext in config_ext: - config_files.extend(config_path.glob(f"*{ext}")) - - for cf in config_files: - if os.path.basename(cf)[0] == ".": - continue # Ignore hidden files - - with open(cf, "r") as f: - model_cfg = json.load(f) - if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")): - _MODEL_CONFIGS[cf.stem] = model_cfg - - _MODEL_CONFIGS = { - k: v - for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])) - } - - -_rescan_model_configs() # initial populate of model config registry - - -def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True): - checkpoint = torch.load(checkpoint_path, map_location=map_location) - if isinstance(checkpoint, dict) and "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] - else: - state_dict = checkpoint - if skip_params: - if next(iter(state_dict.items()))[0].startswith("module"): - state_dict = {k[7:]: v for k, v in state_dict.items()} - # for k in state_dict: - # if k.startswith('transformer'): - # v = state_dict.pop(k) - # state_dict['text_branch.' + k[12:]] = v - return state_dict - - -def create_model( - amodel_name: str, - tmodel_name: str, - pretrained: str = "", - precision: str = "fp32", - device: torch.device = torch.device("cpu"), - jit: bool = False, - force_quick_gelu: bool = False, - openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"), - skip_params=True, - pretrained_audio: str = "", - pretrained_text: str = "", - enable_fusion: bool = False, - fusion_type: str = "None" - # pretrained_image: bool = False, -): - amodel_name = amodel_name.replace( - "/", "-" - ) # for callers using old naming with / in ViT names - pretrained_orig = pretrained - pretrained = pretrained.lower() - if pretrained == "openai": - if amodel_name in _MODEL_CONFIGS: - logging.info(f"Loading {amodel_name} model config.") - model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) - else: - logging.error( - f"Model config for {amodel_name} not found; available models {list_models()}." 
- ) - raise RuntimeError(f"Model config for {amodel_name} not found.") - - logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.") - # Hard Code in model name - model_cfg["text_cfg"]["model_type"] = tmodel_name - model = load_openai_model( - "ViT-B-16", - model_cfg, - device=device, - jit=jit, - cache_dir=openai_model_cache_dir, - enable_fusion=enable_fusion, - fusion_type=fusion_type, - ) - # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 - if precision == "amp" or precision == "fp32": - model = model.float() - else: - if amodel_name in _MODEL_CONFIGS: - logging.info(f"Loading {amodel_name} model config.") - model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) - else: - logging.error( - f"Model config for {amodel_name} not found; available models {list_models()}." - ) - raise RuntimeError(f"Model config for {amodel_name} not found.") - - if force_quick_gelu: - # override for use of QuickGELU on non-OpenAI transformer models - model_cfg["quick_gelu"] = True - - # if pretrained_image: - # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}): - # # pretrained weight loading for timm models set via vision_cfg - # model_cfg['vision_cfg']['timm_model_pretrained'] = True - # else: - # assert False, 'pretrained image towers currently only supported for timm models' - model_cfg["text_cfg"]["model_type"] = tmodel_name - model_cfg["enable_fusion"] = enable_fusion - model_cfg["fusion_type"] = fusion_type - model = CLAP(**model_cfg) - - if pretrained: - checkpoint_path = "" - url = get_pretrained_url(amodel_name, pretrained) - if url: - checkpoint_path = download_pretrained(url, root=openai_model_cache_dir) - elif os.path.exists(pretrained_orig): - checkpoint_path = pretrained_orig - if checkpoint_path: - logging.info( - f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})." - ) - ckpt = load_state_dict(checkpoint_path, skip_params=True) - model.load_state_dict(ckpt) - param_names = [n for n, p in model.named_parameters()] - # for n in param_names: - # print(n, "\t", "Loaded" if n in ckpt else "Unloaded") - else: - logging.warning( - f"Pretrained weights ({pretrained}) not found for model {amodel_name}." - ) - raise RuntimeError( - f"Pretrained weights ({pretrained}) not found for model {amodel_name}." - ) - - if pretrained_audio: - if amodel_name.startswith("PANN"): - if "Cnn14_mAP" in pretrained_audio: # official checkpoint - audio_ckpt = torch.load(pretrained_audio, map_location="cpu") - audio_ckpt = audio_ckpt["model"] - keys = list(audio_ckpt.keys()) - for key in keys: - if ( - "spectrogram_extractor" not in key - and "logmel_extractor" not in key - ): - v = audio_ckpt.pop(key) - audio_ckpt["audio_branch." + key] = v - elif os.path.basename(pretrained_audio).startswith( - "PANN" - ): # checkpoint trained via HTSAT codebase - audio_ckpt = torch.load(pretrained_audio, map_location="cpu") - audio_ckpt = audio_ckpt["state_dict"] - keys = list(audio_ckpt.keys()) - for key in keys: - if key.startswith("sed_model"): - v = audio_ckpt.pop(key) - audio_ckpt["audio_branch." 
+ key[10:]] = v - elif os.path.basename(pretrained_audio).startswith( - "finetuned" - ): # checkpoint trained via linear probe codebase - audio_ckpt = torch.load(pretrained_audio, map_location="cpu") - else: - raise ValueError("Unknown audio checkpoint") - elif amodel_name.startswith("HTSAT"): - if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint - audio_ckpt = torch.load(pretrained_audio, map_location="cpu") - audio_ckpt = audio_ckpt["state_dict"] - keys = list(audio_ckpt.keys()) - for key in keys: - if key.startswith("sed_model") and ( - "spectrogram_extractor" not in key - and "logmel_extractor" not in key - ): - v = audio_ckpt.pop(key) - audio_ckpt["audio_branch." + key[10:]] = v - elif os.path.basename(pretrained_audio).startswith( - "HTSAT" - ): # checkpoint trained via HTSAT codebase - audio_ckpt = torch.load(pretrained_audio, map_location="cpu") - audio_ckpt = audio_ckpt["state_dict"] - keys = list(audio_ckpt.keys()) - for key in keys: - if key.startswith("sed_model"): - v = audio_ckpt.pop(key) - audio_ckpt["audio_branch." + key[10:]] = v - elif os.path.basename(pretrained_audio).startswith( - "finetuned" - ): # checkpoint trained via linear probe codebase - audio_ckpt = torch.load(pretrained_audio, map_location="cpu") - else: - raise ValueError("Unknown audio checkpoint") - else: - raise f"this audio encoder pretrained checkpoint is not support" - - model.load_state_dict(audio_ckpt, strict=False) - logging.info( - f"Loading pretrained {amodel_name} weights ({pretrained_audio})." - ) - param_names = [n for n, p in model.named_parameters()] - for n in param_names: - print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded") - - model.to(device=device) - if precision == "fp16": - assert device.type != "cpu" - convert_weights_to_fp16(model) - - if jit: - model = torch.jit.script(model) - - return model, model_cfg - - -def create_model_and_transforms( - model_name: str, - pretrained: str = "", - precision: str = "fp32", - device: torch.device = torch.device("cpu"), - jit: bool = False, - force_quick_gelu: bool = False, - # pretrained_image: bool = False, -): - model = create_model( - model_name, - pretrained, - precision, - device, - jit, - force_quick_gelu=force_quick_gelu, - # pretrained_image=pretrained_image - ) - preprocess_train = image_transform(model.visual.image_size, is_train=True) - preprocess_val = image_transform(model.visual.image_size, is_train=False) - return model, preprocess_train, preprocess_val - - -def list_models(): - """enumerate available model architectures based on config files""" - return list(_MODEL_CONFIGS.keys()) - - -def add_model_config(path): - """add model config path or file and update registry""" - if not isinstance(path, Path): - path = Path(path) - _MODEL_CONFIG_PATHS.append(path) - _rescan_model_configs() diff --git a/audioldm2/clap/open_clip/feature_fusion.py b/audioldm2/clap/open_clip/feature_fusion.py deleted file mode 100755 index dbe4e170e05894c12ebdc36ba1dc1de65e441b89..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/feature_fusion.py +++ /dev/null @@ -1,192 +0,0 @@ -""" -Feature Fusion for Varible-Length Data Processing -AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py -According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021 -""" - -import torch -import torch.nn as nn - - -class DAF(nn.Module): - """ - 直接相加 DirectAddFuse - """ - - def 
__init__(self): - super(DAF, self).__init__() - - def forward(self, x, residual): - return x + residual - - -class iAFF(nn.Module): - """ - 多特征融合 iAFF - """ - - def __init__(self, channels=64, r=4, type="2D"): - super(iAFF, self).__init__() - inter_channels = int(channels // r) - - if type == "1D": - # 本地注意力 - self.local_att = nn.Sequential( - nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(channels), - ) - - # 全局注意力 - self.global_att = nn.Sequential( - nn.AdaptiveAvgPool1d(1), - nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(channels), - ) - - # 第二次本地注意力 - self.local_att2 = nn.Sequential( - nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(channels), - ) - # 第二次全局注意力 - self.global_att2 = nn.Sequential( - nn.AdaptiveAvgPool1d(1), - nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(channels), - ) - elif type == "2D": - # 本地注意力 - self.local_att = nn.Sequential( - nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(channels), - ) - - # 全局注意力 - self.global_att = nn.Sequential( - nn.AdaptiveAvgPool2d(1), - nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(channels), - ) - - # 第二次本地注意力 - self.local_att2 = nn.Sequential( - nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(channels), - ) - # 第二次全局注意力 - self.global_att2 = nn.Sequential( - nn.AdaptiveAvgPool2d(1), - nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(channels), - ) - else: - raise f"the type is not supported" - - self.sigmoid = nn.Sigmoid() - - def forward(self, x, residual): - flag = False - xa = x + residual - if xa.size(0) == 1: - xa = torch.cat([xa, xa], dim=0) - flag = True - xl = self.local_att(xa) - xg = self.global_att(xa) - xlg = xl + xg - wei = self.sigmoid(xlg) - xi = x * wei + residual * (1 - wei) - - xl2 = self.local_att2(xi) - xg2 = self.global_att(xi) - xlg2 = xl2 + xg2 - wei2 = self.sigmoid(xlg2) - xo = x * wei2 + residual * (1 - wei2) - if flag: - xo = xo[0].unsqueeze(0) - return xo - - -class AFF(nn.Module): - """ - 多特征融合 AFF - """ - - def __init__(self, channels=64, r=4, type="2D"): - super(AFF, self).__init__() - inter_channels = int(channels // r) - - if type == "1D": - self.local_att = nn.Sequential( - nn.Conv1d(channels, inter_channels, 
kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(channels), - ) - self.global_att = nn.Sequential( - nn.AdaptiveAvgPool1d(1), - nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(channels), - ) - elif type == "2D": - self.local_att = nn.Sequential( - nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(channels), - ) - self.global_att = nn.Sequential( - nn.AdaptiveAvgPool2d(1), - nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(channels), - ) - else: - raise f"the type is not supported." - - self.sigmoid = nn.Sigmoid() - - def forward(self, x, residual): - flag = False - xa = x + residual - if xa.size(0) == 1: - xa = torch.cat([xa, xa], dim=0) - flag = True - xl = self.local_att(xa) - xg = self.global_att(xa) - xlg = xl + xg - wei = self.sigmoid(xlg) - xo = 2 * x * wei + 2 * residual * (1 - wei) - if flag: - xo = xo[0].unsqueeze(0) - return xo diff --git a/audioldm2/clap/open_clip/htsat.py b/audioldm2/clap/open_clip/htsat.py deleted file mode 100755 index 8bf4fceea2dfef953522c14a3a39a417658f2257..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/htsat.py +++ /dev/null @@ -1,1304 +0,0 @@ -# Ke Chen -# knutchen@ucsd.edu -# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION -# Some layers designed on the model -# below codes are based and referred from https://github.com/microsoft/Swin-Transformer -# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf - -import torch -import torch.nn as nn -from itertools import repeat -import collections.abc -import math -import warnings - -from torch.nn.init import _calculate_fan_in_and_fan_out -import torch.utils.checkpoint as checkpoint - -import random - -from torchlibrosa.stft import Spectrogram, LogmelFilterBank -from torchlibrosa.augmentation import SpecAugmentation - -from itertools import repeat -from .utils import do_mixup, interpolate - -from .feature_fusion import iAFF, AFF, DAF - - -# from PyTorch internals -def _ntuple(n): - def parse(x): - if isinstance(x, collections.abc.Iterable): - return x - return tuple(repeat(x, n)) - - return parse - - -to_1tuple = _ntuple(1) -to_2tuple = _ntuple(2) -to_3tuple = _ntuple(3) -to_4tuple = _ntuple(4) -to_ntuple = _ntuple - - -def drop_path(x, drop_prob: float = 0.0, training: bool = False): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, - the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for - changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use - 'survival rate' as the argument. 
- """ - if drop_prob == 0.0 or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * ( - x.ndim - 1 - ) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) - - -class PatchEmbed(nn.Module): - """2D Image to Patch Embedding""" - - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - embed_dim=768, - norm_layer=None, - flatten=True, - patch_stride=16, - enable_fusion=False, - fusion_type="None", - ): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patch_stride = to_2tuple(patch_stride) - self.img_size = img_size - self.patch_size = patch_size - self.patch_stride = patch_stride - self.grid_size = ( - img_size[0] // patch_stride[0], - img_size[1] // patch_stride[1], - ) - self.num_patches = self.grid_size[0] * self.grid_size[1] - self.flatten = flatten - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.enable_fusion = enable_fusion - self.fusion_type = fusion_type - - padding = ( - (patch_size[0] - patch_stride[0]) // 2, - (patch_size[1] - patch_stride[1]) // 2, - ) - - if (self.enable_fusion) and (self.fusion_type == "channel_map"): - self.proj = nn.Conv2d( - in_chans * 4, - embed_dim, - kernel_size=patch_size, - stride=patch_stride, - padding=padding, - ) - else: - self.proj = nn.Conv2d( - in_chans, - embed_dim, - kernel_size=patch_size, - stride=patch_stride, - padding=padding, - ) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - - if (self.enable_fusion) and ( - self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] - ): - self.mel_conv2d = nn.Conv2d( - in_chans, - embed_dim, - kernel_size=(patch_size[0], patch_size[1] * 3), - stride=(patch_stride[0], patch_stride[1] * 3), - padding=padding, - ) - if self.fusion_type == "daf_2d": - self.fusion_model = DAF() - elif self.fusion_type == "aff_2d": - self.fusion_model = AFF(channels=embed_dim, type="2D") - elif self.fusion_type == "iaff_2d": - self.fusion_model = iAFF(channels=embed_dim, type="2D") - - def forward(self, x, longer_idx=None): - if (self.enable_fusion) and ( - self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] - ): - global_x = x[:, 0:1, :, :] - - # global processing - B, C, H, W = global_x.shape - assert ( - H == self.img_size[0] and W == self.img_size[1] - ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
- global_x = self.proj(global_x) - TW = global_x.size(-1) - if len(longer_idx) > 0: - # local processing - local_x = x[longer_idx, 1:, :, :].contiguous() - B, C, H, W = local_x.shape - local_x = local_x.view(B * C, 1, H, W) - local_x = self.mel_conv2d(local_x) - local_x = local_x.view( - B, C, local_x.size(1), local_x.size(2), local_x.size(3) - ) - local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3) - TB, TC, TH, _ = local_x.size() - if local_x.size(-1) < TW: - local_x = torch.cat( - [ - local_x, - torch.zeros( - (TB, TC, TH, TW - local_x.size(-1)), - device=global_x.device, - ), - ], - dim=-1, - ) - else: - local_x = local_x[:, :, :, :TW] - - global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x) - x = global_x - else: - B, C, H, W = x.shape - assert ( - H == self.img_size[0] and W == self.img_size[1] - ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x) - - if self.flatten: - x = x.flatten(2).transpose(1, 2) # BCHW -> BNC - x = self.norm(x) - return x - - -class Mlp(nn.Module): - """MLP as used in Vision Transformer, MLP-Mixer and related networks""" - - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - drop=0.0, - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def _no_grad_trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - with torch.no_grad(): - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - return tensor - - -def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): - # type: (Tensor, float, float, float, float) -> Tensor - r"""Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \leq \text{mean} \leq b`. 
- Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - Examples: - >>> w = torch.empty(3, 5) - >>> nn.init.trunc_normal_(w) - """ - return _no_grad_trunc_normal_(tensor, mean, std, a, b) - - -def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): - fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) - if mode == "fan_in": - denom = fan_in - elif mode == "fan_out": - denom = fan_out - elif mode == "fan_avg": - denom = (fan_in + fan_out) / 2 - - variance = scale / denom - - if distribution == "truncated_normal": - # constant is stddev of standard normal truncated to (-2, 2) - trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978) - elif distribution == "normal": - tensor.normal_(std=math.sqrt(variance)) - elif distribution == "uniform": - bound = math.sqrt(3 * variance) - tensor.uniform_(-bound, bound) - else: - raise ValueError(f"invalid distribution {distribution}") - - -def lecun_normal_(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = ( - x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - ) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view( - B, H // window_size, W // window_size, window_size, window_size, -1 - ) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - r"""Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 - """ - - def __init__( - self, - dim, - window_size, - num_heads, - qkv_bias=True, - qk_scale=None, - attn_drop=0.0, - proj_drop=0.0, - ): - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim**-0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) - ) # 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = ( - coords_flatten[:, :, None] - coords_flatten[:, None, :] - ) # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute( - 1, 2, 0 - ).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=0.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = ( - self.qkv(x) - .reshape(B_, N, 3, self.num_heads, C // self.num_heads) - .permute(2, 0, 3, 1, 4) - ) - q, k, v = ( - qkv[0], - qkv[1], - qkv[2], - ) # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = q @ k.transpose(-2, -1) - - relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1) - ].view( - self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], - -1, - ) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute( - 2, 0, 1 - ).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( - 1 - ).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x, attn - - def extra_repr(self): - return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}" - - -# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model -class SwinTransformerBlock(nn.Module): - r"""Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__( - self, - dim, - input_resolution, - num_heads, - window_size=7, - shift_size=0, - mlp_ratio=4.0, - qkv_bias=True, - qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - norm_before_mlp="ln", - ): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - self.norm_before_mlp = norm_before_mlp - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert ( - 0 <= self.shift_size < self.window_size - ), "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, - window_size=to_2tuple(self.window_size), - num_heads=num_heads, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop, - ) - - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - if self.norm_before_mlp == "ln": - self.norm2 = nn.LayerNorm(dim) - elif self.norm_before_mlp == "bn": - self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose( - 1, 2 - ) - else: - raise NotImplementedError - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp( - in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, - drop=drop, - ) - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None), - ) - w_slices = ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None), - ) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition( - img_mask, self.window_size - ) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill( - attn_mask != 0, float(-100.0) - ).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def forward(self, x): - # pdb.set_trace() - H, W = self.input_resolution - # print("H: ", H) - # print("W: ", W) - # pdb.set_trace() - B, L, C = x.shape - # assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll( - x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) - ) - else: - shifted_x = x - - # partition windows - x_windows = window_partition( - shifted_x, self.window_size - ) # nW*B, window_size, window_size, C - x_windows = x_windows.view( - -1, 
self.window_size * self.window_size, C - ) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA - attn_windows, attn = self.attn( - x_windows, mask=self.attn_mask - ) # nW*B, window_size*window_size, C - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll( - shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) - ) - else: - x = shifted_x - x = x.view(B, H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x, attn - - def extra_repr(self): - return ( - f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - ) - - -class PatchMerging(nn.Module): - r"""Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self): - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - -class BasicLayer(nn.Module): - """A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
- """ - - def __init__( - self, - dim, - input_resolution, - depth, - num_heads, - window_size, - mlp_ratio=4.0, - qkv_bias=True, - qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, - norm_layer=nn.LayerNorm, - downsample=None, - use_checkpoint=False, - norm_before_mlp="ln", - ): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList( - [ - SwinTransformerBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i] - if isinstance(drop_path, list) - else drop_path, - norm_layer=norm_layer, - norm_before_mlp=norm_before_mlp, - ) - for i in range(depth) - ] - ) - - # patch merging layer - if downsample is not None: - self.downsample = downsample( - input_resolution, dim=dim, norm_layer=norm_layer - ) - else: - self.downsample = None - - def forward(self, x): - attns = [] - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x, attn = blk(x) - if not self.training: - attns.append(attn.unsqueeze(0)) - if self.downsample is not None: - x = self.downsample(x) - if not self.training: - attn = torch.cat(attns, dim=0) - attn = torch.mean(attn, dim=0) - return x, attn - - def extra_repr(self): - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - -# The Core of HTSAT -class HTSAT_Swin_Transformer(nn.Module): - r"""HTSAT based on the Swin Transformer - Args: - spec_size (int | tuple(int)): Input Spectrogram size. Default 256 - patch_size (int | tuple(int)): Patch size. Default: 4 - path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4 - in_chans (int): Number of input image channels. Default: 1 (mono) - num_classes (int): Number of classes for classification head. Default: 527 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 8 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False - config (module): The configuration Module from config.py - """ - - def __init__( - self, - spec_size=256, - patch_size=4, - patch_stride=(4, 4), - in_chans=1, - num_classes=527, - embed_dim=96, - depths=[2, 2, 6, 2], - num_heads=[4, 8, 16, 32], - window_size=8, - mlp_ratio=4.0, - qkv_bias=True, - qk_scale=None, - drop_rate=0.0, - attn_drop_rate=0.0, - drop_path_rate=0.1, - norm_layer=nn.LayerNorm, - ape=False, - patch_norm=True, - use_checkpoint=False, - norm_before_mlp="ln", - config=None, - enable_fusion=False, - fusion_type="None", - **kwargs, - ): - super(HTSAT_Swin_Transformer, self).__init__() - - self.config = config - self.spec_size = spec_size - self.patch_stride = patch_stride - self.patch_size = patch_size - self.window_size = window_size - self.embed_dim = embed_dim - self.depths = depths - self.ape = ape - self.in_chans = in_chans - self.num_classes = num_classes - self.num_heads = num_heads - self.num_layers = len(self.depths) - self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1)) - - self.drop_rate = drop_rate - self.attn_drop_rate = attn_drop_rate - self.drop_path_rate = drop_path_rate - - self.qkv_bias = qkv_bias - self.qk_scale = None - - self.patch_norm = patch_norm - self.norm_layer = norm_layer if self.patch_norm else None - self.norm_before_mlp = norm_before_mlp - self.mlp_ratio = mlp_ratio - - self.use_checkpoint = use_checkpoint - - self.enable_fusion = enable_fusion - self.fusion_type = fusion_type - - # process mel-spec ; used only once - self.freq_ratio = self.spec_size // self.config.mel_bins - window = "hann" - center = True - pad_mode = "reflect" - ref = 1.0 - amin = 1e-10 - top_db = None - self.interpolate_ratio = 32 # Downsampled ratio - # Spectrogram extractor - self.spectrogram_extractor = Spectrogram( - n_fft=config.window_size, - hop_length=config.hop_size, - win_length=config.window_size, - window=window, - center=center, - pad_mode=pad_mode, - freeze_parameters=True, - ) - # Logmel feature extractor - self.logmel_extractor = LogmelFilterBank( - sr=config.sample_rate, - n_fft=config.window_size, - n_mels=config.mel_bins, - fmin=config.fmin, - fmax=config.fmax, - ref=ref, - amin=amin, - top_db=top_db, - freeze_parameters=True, - ) - # Spec augmenter - self.spec_augmenter = SpecAugmentation( - time_drop_width=64, - time_stripes_num=2, - freq_drop_width=8, - freq_stripes_num=2, - ) # 2 2 - self.bn0 = nn.BatchNorm2d(self.config.mel_bins) - - # split spctrogram into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=self.spec_size, - patch_size=self.patch_size, - in_chans=self.in_chans, - embed_dim=self.embed_dim, - norm_layer=self.norm_layer, - patch_stride=patch_stride, - enable_fusion=self.enable_fusion, - fusion_type=self.fusion_type, - ) - - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.grid_size - self.patches_resolution = patches_resolution - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter( - torch.zeros(1, num_patches, self.embed_dim) - ) - trunc_normal_(self.absolute_pos_embed, std=0.02) - - self.pos_drop = nn.Dropout(p=self.drop_rate) - - # stochastic depth - dpr = [ - x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths)) - ] # stochastic depth decay rule - - # build layers - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = BasicLayer( - dim=int(self.embed_dim * 2**i_layer), - input_resolution=( - patches_resolution[0] // (2**i_layer), - patches_resolution[1] // 
(2**i_layer), - ), - depth=self.depths[i_layer], - num_heads=self.num_heads[i_layer], - window_size=self.window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=self.qkv_bias, - qk_scale=self.qk_scale, - drop=self.drop_rate, - attn_drop=self.attn_drop_rate, - drop_path=dpr[ - sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1]) - ], - norm_layer=self.norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, - use_checkpoint=use_checkpoint, - norm_before_mlp=self.norm_before_mlp, - ) - self.layers.append(layer) - - self.norm = self.norm_layer(self.num_features) - self.avgpool = nn.AdaptiveAvgPool1d(1) - self.maxpool = nn.AdaptiveMaxPool1d(1) - - SF = ( - self.spec_size - // (2 ** (len(self.depths) - 1)) - // self.patch_stride[0] - // self.freq_ratio - ) - self.tscam_conv = nn.Conv2d( - in_channels=self.num_features, - out_channels=self.num_classes, - kernel_size=(SF, 3), - padding=(0, 1), - ) - self.head = nn.Linear(num_classes, num_classes) - - if (self.enable_fusion) and ( - self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"] - ): - self.mel_conv1d = nn.Sequential( - nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2), - nn.BatchNorm1d(64), - ) - if self.fusion_type == "daf_1d": - self.fusion_model = DAF() - elif self.fusion_type == "aff_1d": - self.fusion_model = AFF(channels=64, type="1D") - elif self.fusion_type == "iaff_1d": - self.fusion_model = iAFF(channels=64, type="1D") - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {"absolute_pos_embed"} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {"relative_position_bias_table"} - - def forward_features(self, x, longer_idx=None): - # A deprecated optimization for using a hierarchical output from different blocks - - frames_num = x.shape[2] - x = self.patch_embed(x, longer_idx=longer_idx) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - for i, layer in enumerate(self.layers): - x, attn = layer(x) - # for x - x = self.norm(x) - B, N, C = x.shape - SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] - ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1] - x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST) - B, C, F, T = x.shape - # group 2D CNN - c_freq_bin = F // self.freq_ratio - x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T) - x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1) - # get latent_output - fine_grained_latent_output = torch.mean(x, dim=2) - fine_grained_latent_output = interpolate( - fine_grained_latent_output.permute(0, 2, 1).contiguous(), - 8 * self.patch_stride[1], - ) - - latent_output = self.avgpool(torch.flatten(x, 2)) - latent_output = torch.flatten(latent_output, 1) - - # display the attention map, if needed - - x = self.tscam_conv(x) - x = torch.flatten(x, 2) # B, C, T - - fpx = interpolate( - torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1] - ) - - x = self.avgpool(x) - x = torch.flatten(x, 1) - - output_dict = { - "framewise_output": fpx, # already sigmoided - "clipwise_output": torch.sigmoid(x), - "fine_grained_embedding": fine_grained_latent_output, - "embedding": latent_output, - } - - return output_dict - - def 
crop_wav(self, x, crop_size, spe_pos=None):
-        time_steps = x.shape[2]
-        tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
-        for i in range(len(x)):
-            if spe_pos is None:
-                crop_pos = random.randint(0, time_steps - crop_size - 1)
-            else:
-                crop_pos = spe_pos
-            tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :]
-        return tx
-
-    # Reshape the waveform to an image size, if you want to use the pretrained swin transformer model
-    def reshape_wav2img(self, x):
-        B, C, T, F = x.shape
-        target_T = int(self.spec_size * self.freq_ratio)
-        target_F = self.spec_size // self.freq_ratio
-        assert (
-            T <= target_T and F <= target_F
-        ), "the wav size should be less than or equal to the swin input size"
-        # to avoid bicubic zero error
-        if T < target_T:
-            x = nn.functional.interpolate(
-                x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
-            )
-        if F < target_F:
-            x = nn.functional.interpolate(
-                x, (x.shape[2], target_F), mode="bicubic", align_corners=True
-            )
-        x = x.permute(0, 1, 3, 2).contiguous()
-        x = x.reshape(
-            x.shape[0],
-            x.shape[1],
-            x.shape[2],
-            self.freq_ratio,
-            x.shape[3] // self.freq_ratio,
-        )
-        # print(x.shape)
-        x = x.permute(0, 1, 3, 2, 4).contiguous()
-        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
-        return x
-
-    # Repeat the waveform to an image size, if you want to use the pretrained swin transformer model
-    def repeat_wat2img(self, x, cur_pos):
-        B, C, T, F = x.shape
-        target_T = int(self.spec_size * self.freq_ratio)
-        target_F = self.spec_size // self.freq_ratio
-        assert (
-            T <= target_T and F <= target_F
-        ), "the wav size should be less than or equal to the swin input size"
-        # to avoid bicubic zero error
-        if T < target_T:
-            x = nn.functional.interpolate(
-                x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
-            )
-        if F < target_F:
-            x = nn.functional.interpolate(
-                x, (x.shape[2], target_F), mode="bicubic", align_corners=True
-            )
-        x = x.permute(0, 1, 3, 2).contiguous()  # B C F T
-        x = x[:, :, :, cur_pos : cur_pos + self.spec_size]
-        x = x.repeat(repeats=(1, 1, 4, 1))
-        return x
-
-    def forward(
-        self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None
-    ):  # out_feat_keys: List[str] = None):
-        if self.enable_fusion and x["longer"].sum() == 0:
-            # if no audio is longer than 10s, then randomly select one audio to be longer
-            x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True
-
-        if not self.enable_fusion:
-            x = x["waveform"].to(device=device, non_blocking=True)
-            x = self.spectrogram_extractor(x)  # (batch_size, 1, time_steps, freq_bins)
-            x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)
-            x = x.transpose(1, 3)
-            x = self.bn0(x)
-            x = x.transpose(1, 3)
-            if self.training:
-                x = self.spec_augmenter(x)
-
-            if self.training and mixup_lambda is not None:
-                x = do_mixup(x, mixup_lambda)
-
-            x = self.reshape_wav2img(x)
-            output_dict = self.forward_features(x)
-        else:
-            longer_list = x["longer"].to(device=device, non_blocking=True)
-            x = x["mel_fusion"].to(device=device, non_blocking=True)
-            x = x.transpose(1, 3)
-            x = self.bn0(x)
-            x = x.transpose(1, 3)
-            longer_list_idx = torch.where(longer_list)[0]
-            if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
-                new_x = x[:, 0:1, :, :].clone().contiguous()
-                if len(longer_list_idx) > 0:
-                    # local processing
-                    fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
-                    FB, FC, FT, FF = fusion_x_local.size()
-                    fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
-                    fusion_x_local = torch.permute(
-                        fusion_x_local,
(0, 2, 1) - ).contiguous() - fusion_x_local = self.mel_conv1d(fusion_x_local) - fusion_x_local = fusion_x_local.view( - FB, FC, FF, fusion_x_local.size(-1) - ) - fusion_x_local = ( - torch.permute(fusion_x_local, (0, 2, 1, 3)) - .contiguous() - .flatten(2) - ) - if fusion_x_local.size(-1) < FT: - fusion_x_local = torch.cat( - [ - fusion_x_local, - torch.zeros( - (FB, FF, FT - fusion_x_local.size(-1)), - device=device, - ), - ], - dim=-1, - ) - else: - fusion_x_local = fusion_x_local[:, :, :FT] - # 1D fusion - new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous() - new_x[longer_list_idx] = self.fusion_model( - new_x[longer_list_idx], fusion_x_local - ) - x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :] - else: - x = new_x - - elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]: - x = x # no change - - if self.training: - x = self.spec_augmenter(x) - if self.training and mixup_lambda is not None: - x = do_mixup(x, mixup_lambda) - - x = self.reshape_wav2img(x) - output_dict = self.forward_features(x, longer_idx=longer_list_idx) - - # if infer_mode: - # # in infer mode. we need to handle different length audio input - # frame_num = x.shape[2] - # target_T = int(self.spec_size * self.freq_ratio) - # repeat_ratio = math.floor(target_T / frame_num) - # x = x.repeat(repeats=(1,1,repeat_ratio,1)) - # x = self.reshape_wav2img(x) - # output_dict = self.forward_features(x) - # else: - # if x.shape[2] > self.freq_ratio * self.spec_size: - # if self.training: - # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size) - # x = self.reshape_wav2img(x) - # output_dict = self.forward_features(x) - # else: - # # Change: Hard code here - # overlap_size = (x.shape[2] - 1) // 4 - # output_dicts = [] - # crop_size = (x.shape[2] - 1) // 2 - # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size): - # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos) - # tx = self.reshape_wav2img(tx) - # output_dicts.append(self.forward_features(tx)) - # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device) - # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device) - # for d in output_dicts: - # clipwise_output += d["clipwise_output"] - # framewise_output += d["framewise_output"] - # clipwise_output = clipwise_output / len(output_dicts) - # framewise_output = framewise_output / len(output_dicts) - # output_dict = { - # 'framewise_output': framewise_output, - # 'clipwise_output': clipwise_output - # } - # else: # this part is typically used, and most easy one - # x = self.reshape_wav2img(x) - # output_dict = self.forward_features(x) - # x = self.head(x) - - # We process the data in the dataloader part, in that here we only consider the input_T < fixed_T - - return output_dict - - -def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"): - try: - assert audio_cfg.model_name in [ - "tiny", - "base", - "large", - ], "model name for HTS-AT is wrong!" 
- if audio_cfg.model_name == "tiny": - model = HTSAT_Swin_Transformer( - spec_size=256, - patch_size=4, - patch_stride=(4, 4), - num_classes=audio_cfg.class_num, - embed_dim=96, - depths=[2, 2, 6, 2], - num_heads=[4, 8, 16, 32], - window_size=8, - config=audio_cfg, - enable_fusion=enable_fusion, - fusion_type=fusion_type, - ) - elif audio_cfg.model_name == "base": - model = HTSAT_Swin_Transformer( - spec_size=256, - patch_size=4, - patch_stride=(4, 4), - num_classes=audio_cfg.class_num, - embed_dim=128, - depths=[2, 2, 12, 2], - num_heads=[4, 8, 16, 32], - window_size=8, - config=audio_cfg, - enable_fusion=enable_fusion, - fusion_type=fusion_type, - ) - elif audio_cfg.model_name == "large": - model = HTSAT_Swin_Transformer( - spec_size=256, - patch_size=4, - patch_stride=(4, 4), - num_classes=audio_cfg.class_num, - embed_dim=256, - depths=[2, 2, 12, 2], - num_heads=[4, 8, 16, 32], - window_size=8, - config=audio_cfg, - enable_fusion=enable_fusion, - fusion_type=fusion_type, - ) - - return model - except: - raise RuntimeError( - f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough." - ) diff --git a/audioldm2/clap/open_clip/loss.py b/audioldm2/clap/open_clip/loss.py deleted file mode 100755 index 37faba58f3693d0659512ab1d6e19614fbda0675..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/loss.py +++ /dev/null @@ -1,397 +0,0 @@ -import torch -import torch.distributed.nn -from torch import distributed as dist, nn as nn -from torch.nn import functional as F -import numpy as np -from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score - -try: - import horovod.torch as hvd -except ImportError: - hvd = None - - -def gather_features( - audio_features, - text_features, - audio_features_mlp=None, - text_features_mlp=None, - local_loss=False, - gather_with_grad=False, - rank=0, - world_size=1, - use_horovod=False, - mlp_loss=False, -): - if use_horovod: - assert hvd is not None, "Please install horovod" - if gather_with_grad: - all_audio_features = hvd.allgather(audio_features) - all_text_features = hvd.allgather(text_features) - if mlp_loss: - all_audio_features_mlp = hvd.allgather(audio_features_mlp) - all_text_features_mlp = hvd.allgather(text_features_mlp) - else: - with torch.no_grad(): - all_audio_features = hvd.allgather(audio_features) - all_text_features = hvd.allgather(text_features) - if mlp_loss: - all_audio_features_mlp = hvd.allgather(audio_features_mlp) - all_text_features_mlp = hvd.allgather(text_features_mlp) - if not local_loss: - # ensure grads for local rank when all_* features don't have a gradient - gathered_audio_features = list( - all_audio_features.chunk(world_size, dim=0) - ) - gathered_text_features = list( - all_text_features.chunk(world_size, dim=0) - ) - gathered_audio_features[rank] = audio_features - gathered_text_features[rank] = text_features - all_audio_features = torch.cat(gathered_audio_features, dim=0) - all_text_features = torch.cat(gathered_text_features, dim=0) - if mlp_loss: - gathered_audio_features_mlp = list( - all_audio_features_mlp.chunk(world_size, dim=0) - ) - gathered_text_features_mlp = list( - all_text_features_mlp.chunk(world_size, dim=0) - ) - gathered_audio_features_mlp[rank] = audio_features_mlp - gathered_text_features_mlp[rank] = text_features_mlp - all_audio_features_mlp = torch.cat( - gathered_audio_features_mlp, dim=0 - ) - all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0) - else: - # We gather tensors from all gpus - if 
gather_with_grad: - all_audio_features = torch.cat( - torch.distributed.nn.all_gather(audio_features), dim=0 - ) - all_text_features = torch.cat( - torch.distributed.nn.all_gather(text_features), dim=0 - ) - if mlp_loss: - all_audio_features_mlp = torch.cat( - torch.distributed.nn.all_gather(audio_features_mlp), dim=0 - ) - all_text_features_mlp = torch.cat( - torch.distributed.nn.all_gather(text_features_mlp), dim=0 - ) - else: - gathered_audio_features = [ - torch.zeros_like(audio_features) for _ in range(world_size) - ] - gathered_text_features = [ - torch.zeros_like(text_features) for _ in range(world_size) - ] - dist.all_gather(gathered_audio_features, audio_features) - dist.all_gather(gathered_text_features, text_features) - if mlp_loss: - gathered_audio_features_mlp = [ - torch.zeros_like(audio_features_mlp) for _ in range(world_size) - ] - gathered_text_features_mlp = [ - torch.zeros_like(text_features_mlp) for _ in range(world_size) - ] - dist.all_gather(gathered_audio_features_mlp, audio_features_mlp) - dist.all_gather(gathered_text_features_mlp, text_features_mlp) - if not local_loss: - # ensure grads for local rank when all_* features don't have a gradient - gathered_audio_features[rank] = audio_features - gathered_text_features[rank] = text_features - if mlp_loss: - gathered_audio_features_mlp[rank] = audio_features_mlp - gathered_text_features_mlp[rank] = text_features_mlp - - all_audio_features = torch.cat(gathered_audio_features, dim=0) - all_text_features = torch.cat(gathered_text_features, dim=0) - if mlp_loss: - all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0) - all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0) - if mlp_loss: - return ( - all_audio_features, - all_text_features, - all_audio_features_mlp, - all_text_features_mlp, - ) - else: - return all_audio_features, all_text_features - - -class ClipLoss(nn.Module): - def __init__( - self, - local_loss=False, - gather_with_grad=False, - cache_labels=False, - rank=0, - world_size=1, - use_horovod=False, - mlp_loss=False, - weight_loss_kappa=0, - ): - super().__init__() - self.local_loss = local_loss - self.gather_with_grad = gather_with_grad - self.cache_labels = cache_labels - self.rank = rank - self.world_size = world_size - self.use_horovod = use_horovod - self.mlp_loss = mlp_loss - self.weighted_loss = bool(weight_loss_kappa != 0) - self.weight_loss_kappa = weight_loss_kappa - # cache state - self.prev_num_logits = 0 - self.labels = {} - - def forward( - self, - audio_features, - text_features, - logit_scale_a, - logit_scale_t=None, - audio_features_mlp=None, - text_features_mlp=None, - ): - device = audio_features.device - if self.mlp_loss: - if self.world_size > 1: - ( - all_audio_features, - all_text_features, - all_audio_features_mlp, - all_text_features_mlp, - ) = gather_features( - audio_features=audio_features, - text_features=text_features, - audio_features_mlp=audio_features_mlp, - text_features_mlp=text_features_mlp, - local_loss=self.local_loss, - gather_with_grad=self.gather_with_grad, - rank=self.rank, - world_size=self.world_size, - use_horovod=self.use_horovod, - mlp_loss=self.mlp_loss, - ) - if self.local_loss: - a_logits_per_audio = ( - logit_scale_a * audio_features @ all_text_features_mlp.T - ) - a_logits_per_text = ( - logit_scale_a * text_features_mlp @ all_audio_features.T - ) - t_logits_per_audio = ( - logit_scale_t * audio_features_mlp @ all_text_features.T - ) - t_logits_per_text = ( - logit_scale_t * text_features @ all_audio_features_mlp.T - ) - 
else: - a_logits_per_audio = ( - logit_scale_a * all_audio_features @ all_text_features_mlp.T - ) - a_logits_per_text = a_logits_per_audio.T - t_logits_per_audio = ( - logit_scale_t * all_audio_features_mlp @ all_text_features.T - ) - t_logits_per_text = t_logits_per_audio.T - else: - a_logits_per_audio = ( - logit_scale_a * audio_features @ text_features_mlp.T - ) - a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T - t_logits_per_audio = ( - logit_scale_t * audio_features_mlp @ text_features.T - ) - t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T - - # calculated ground-truth and cache if enabled - num_logits = a_logits_per_audio.shape[0] - if self.prev_num_logits != num_logits or device not in self.labels: - labels = torch.arange(num_logits, device=device, dtype=torch.long) - if self.world_size > 1 and self.local_loss: - labels = labels + num_logits * self.rank - if self.cache_labels: - self.labels[device] = labels - self.prev_num_logits = num_logits - else: - labels = self.labels[device] - - if not self.weighted_loss: - total_loss = ( - F.cross_entropy(a_logits_per_audio, labels) - + F.cross_entropy(a_logits_per_text, labels) - + F.cross_entropy(t_logits_per_audio, labels) - + F.cross_entropy(t_logits_per_text, labels) - ) / 4 - else: - audio_weight = (audio_features @ audio_features.T).detach() - audio_weight = ( - torch.exp( - torch.sum(audio_weight, axis=1) - / (self.weight_loss_kappa * len(audio_weight)) - ) - ).detach() - text_weight = (text_features @ text_features.T).detach() - text_weight = ( - torch.exp( - torch.sum(text_weight, axis=1) - / (self.weight_loss_kappa * len(text_features)) - ) - ).detach() - total_loss = ( - F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight) - + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight) - + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight) - + F.cross_entropy(t_logits_per_text, labels, weight=text_weight) - ) / 4 - else: - if self.world_size > 1: - all_audio_features, all_text_features = gather_features( - audio_features=audio_features, - text_features=text_features, - local_loss=self.local_loss, - gather_with_grad=self.gather_with_grad, - rank=self.rank, - world_size=self.world_size, - use_horovod=self.use_horovod, - mlp_loss=self.mlp_loss, - ) - - if self.local_loss: - logits_per_audio = ( - logit_scale_a * audio_features @ all_text_features.T - ) - logits_per_text = ( - logit_scale_a * text_features @ all_audio_features.T - ) - else: - logits_per_audio = ( - logit_scale_a * all_audio_features @ all_text_features.T - ) - logits_per_text = logits_per_audio.T - else: - logits_per_audio = logit_scale_a * audio_features @ text_features.T - logits_per_text = logit_scale_a * text_features @ audio_features.T - - # calculated ground-truth and cache if enabled - num_logits = logits_per_audio.shape[0] - if self.prev_num_logits != num_logits or device not in self.labels: - labels = torch.arange(num_logits, device=device, dtype=torch.long) - if self.world_size > 1 and self.local_loss: - labels = labels + num_logits * self.rank - if self.cache_labels: - self.labels[device] = labels - self.prev_num_logits = num_logits - else: - labels = self.labels[device] - if not self.weighted_loss: - total_loss = ( - F.cross_entropy(logits_per_audio, labels) - + F.cross_entropy(logits_per_text, labels) - ) / 2 - else: - audio_weight = (all_audio_features @ all_audio_features.T).detach() - audio_weight = ( - torch.exp( - torch.sum(audio_weight, axis=1) - / 
(self.weight_loss_kappa * len(all_audio_features))
-                    )
-                ).detach()
-                text_weight = (all_text_features @ all_text_features.T).detach()
-                text_weight = (
-                    torch.exp(
-                        torch.sum(text_weight, axis=1)
-                        / (self.weight_loss_kappa * len(all_text_features))
-                    )
-                ).detach()
-                total_loss = (
-                    F.cross_entropy(logits_per_audio, labels, weight=text_weight)
-                    + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
-                ) / 2
-        return total_loss
-
-
-def lp_gather_features(pred, target, world_size=1, use_horovod=False):
-    if use_horovod:
-        assert hvd is not None, "Please install horovod"
-        with torch.no_grad():
-            all_preds = hvd.allgather(pred)
-            all_targets = hvd.allgather(target)
-    else:
-        gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
-        gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
-
-        dist.all_gather(gathered_preds, pred)
-        dist.all_gather(gathered_targets, target)
-        all_preds = torch.cat(gathered_preds, dim=0)
-        all_targets = torch.cat(gathered_targets, dim=0)
-
-    return all_preds, all_targets
-
-
-def get_map(pred, target):
-    pred = torch.sigmoid(pred).numpy()
-    target = target.numpy()
-    return np.mean(average_precision_score(target, pred, average=None))
-
-
-def get_acc(pred, target):
-    pred = torch.argmax(pred, 1).numpy()
-    target = torch.argmax(target, 1).numpy()
-    return accuracy_score(target, pred)
-
-
-def get_mauc(pred, target):
-    pred = torch.sigmoid(pred).numpy()
-    target = target.numpy()
-    return np.mean(roc_auc_score(target, pred, average=None))
-
-
-class LPMetrics(object):
-    def __init__(self, metric_names=["map", "acc", "mauc"]):
-        self.metrics = []
-        for name in metric_names:
-            self.metrics.append(self.get_metric(name))
-        self.metric_names = metric_names
-
-    def get_metric(self, name):
-        if name == "map":
-            return get_map
-        elif name == "acc":
-            return get_acc
-        elif name == "mauc":
-            return get_mauc
-        else:
-            raise ValueError(f"the metric should be at least one of [map, acc, mauc]")
-
-    def evaluate_mertics(self, pred, target):
-        metric_dict = {}
-        for i in range(len(self.metric_names)):
-            metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
-        return metric_dict
-
-
-def calc_celoss(pred, target):
-    target = torch.argmax(target, 1).long()
-    return nn.CrossEntropyLoss()(pred, target)
-
-
-class LPLoss(nn.Module):
-    def __init__(self, loss_name):
-        super().__init__()
-        if loss_name == "bce":
-            self.loss_func = nn.BCEWithLogitsLoss()
-        elif loss_name == "ce":
-            self.loss_func = calc_celoss
-        elif loss_name == "mse":
-            self.loss_func = nn.MSELoss()
-        else:
-            raise ValueError(f"the loss func should be at least one of [bce, ce, mse]")
-
-    def forward(self, pred, target):
-        loss = self.loss_func(pred, target)
-        return loss
diff --git a/audioldm2/clap/open_clip/model.py b/audioldm2/clap/open_clip/model.py
deleted file mode 100755
index 130fb582d016868d478e2d10e90d7fc0e7999078..0000000000000000000000000000000000000000
--- a/audioldm2/clap/open_clip/model.py
+++ /dev/null
@@ -1,931 +0,0 @@
-""" CLAP Model
-
-Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
-Adapted to the Audio Task.
-""" - -from collections import OrderedDict -from dataclasses import dataclass -from typing import Tuple, Union, Callable, Optional - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn - -import logging -from .utils import freeze_batch_norm_2d - -from .pann_model import create_pann_model -from .htsat import create_htsat_model -from transformers import BertModel, RobertaModel, BartModel, RobertaConfig - - -class MLPLayers(nn.Module): - def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1): - super(MLPLayers, self).__init__() - self.nonlin = nonlin - self.dropout = dropout - - sequence = [] - for u0, u1 in zip(units[:-1], units[1:]): - sequence.append(nn.Linear(u0, u1)) - sequence.append(self.nonlin) - sequence.append(nn.Dropout(self.dropout)) - sequence = sequence[:-2] - - self.sequential = nn.Sequential(*sequence) - - def forward(self, X): - X = self.sequential(X) - return X - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1): - super().__init__() - - # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 - self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - - self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - - self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() - - self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * self.expansion) - - self.relu = nn.ReLU(inplace=True) - self.downsample = None - self.stride = stride - - if stride > 1 or inplanes != planes * Bottleneck.expansion: - # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 - self.downsample = nn.Sequential( - OrderedDict( - [ - ("-1", nn.AvgPool2d(stride)), - ( - "0", - nn.Conv2d( - inplanes, - planes * self.expansion, - 1, - stride=1, - bias=False, - ), - ), - ("1", nn.BatchNorm2d(planes * self.expansion)), - ] - ) - ) - - def forward(self, x: torch.Tensor): - identity = x - - out = self.relu(self.bn1(self.conv1(x))) - out = self.relu(self.bn2(self.conv2(out))) - out = self.avgpool(out) - out = self.bn3(self.conv3(out)) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - return out - - -class AttentionPool2d(nn.Module): - def __init__( - self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None - ): - super().__init__() - self.positional_embedding = nn.Parameter( - torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5 - ) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x): - x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( - 2, 0, 1 - ) # NCHW -> (HW)NC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC - x, _ = F.multi_head_attention_forward( - query=x, - key=x, - value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat( - [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] - ), - bias_k=None, - 
bias_v=None, - add_zero_attn=False, - dropout_p=0, - out_proj_weight=self.c_proj.weight, - out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False, - ) - - return x[0] - - -class ModifiedResNet(nn.Module): - """ - A ResNet class that is similar to torchvision's but contains the following changes: - - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. - - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - - The final pooling layer is a QKV attention instead of an average pool - """ - - def __init__(self, layers, output_dim, heads, image_size=224, width=64): - super().__init__() - self.output_dim = output_dim - self.image_size = image_size - - # the 3-layer stem - self.conv1 = nn.Conv2d( - 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False - ) - self.bn1 = nn.BatchNorm2d(width // 2) - self.conv2 = nn.Conv2d( - width // 2, width // 2, kernel_size=3, padding=1, bias=False - ) - self.bn2 = nn.BatchNorm2d(width // 2) - self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) - self.bn3 = nn.BatchNorm2d(width) - self.avgpool = nn.AvgPool2d(2) - self.relu = nn.ReLU(inplace=True) - - # residual layers - self._inplanes = width # this is a *mutable* variable used during construction - self.layer1 = self._make_layer(width, layers[0]) - self.layer2 = self._make_layer(width * 2, layers[1], stride=2) - self.layer3 = self._make_layer(width * 4, layers[2], stride=2) - self.layer4 = self._make_layer(width * 8, layers[3], stride=2) - - embed_dim = width * 32 # the ResNet feature dimension - self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) - - self.init_parameters() - - def _make_layer(self, planes, blocks, stride=1): - layers = [Bottleneck(self._inplanes, planes, stride)] - - self._inplanes = planes * Bottleneck.expansion - for _ in range(1, blocks): - layers.append(Bottleneck(self._inplanes, planes)) - - return nn.Sequential(*layers) - - def init_parameters(self): - if self.attnpool is not None: - std = self.attnpool.c_proj.in_features**-0.5 - nn.init.normal_(self.attnpool.q_proj.weight, std=std) - nn.init.normal_(self.attnpool.k_proj.weight, std=std) - nn.init.normal_(self.attnpool.v_proj.weight, std=std) - nn.init.normal_(self.attnpool.c_proj.weight, std=std) - - for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: - for name, param in resnet_block.named_parameters(): - if name.endswith("bn3.weight"): - nn.init.zeros_(param) - - def lock(self, unlocked_groups=0, freeze_bn_stats=False): - assert ( - unlocked_groups == 0 - ), "partial locking not currently supported for this model" - for param in self.parameters(): - param.requires_grad = False - if freeze_bn_stats: - freeze_batch_norm_2d(self) - - def stem(self, x): - for conv, bn in [ - (self.conv1, self.bn1), - (self.conv2, self.bn2), - (self.conv3, self.bn3), - ]: - x = self.relu(bn(conv(x))) - x = self.avgpool(x) - return x - - def forward(self, x): - x = self.stem(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - x = self.attnpool(x) - - return x - - -class LayerNorm(nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16.""" - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - return x.to(orig_type) - - -class QuickGELU(nn.Module): - # NOTE This is slower than nn.GELU or 
nn.SiLU and uses more GPU memory - def forward(self, x: torch.Tensor): - return x * torch.sigmoid(1.702 * x) - - -class ResidualAttentionBlock(nn.Module): - def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU): - super().__init__() - - self.attn = nn.MultiheadAttention(d_model, n_head) - self.ln_1 = LayerNorm(d_model) - self.mlp = nn.Sequential( - OrderedDict( - [ - ("c_fc", nn.Linear(d_model, d_model * 4)), - ("gelu", act_layer()), - ("c_proj", nn.Linear(d_model * 4, d_model)), - ] - ) - ) - self.ln_2 = LayerNorm(d_model) - - def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] - - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) - x = x + self.mlp(self.ln_2(x)) - return x - - -class Transformer(nn.Module): - def __init__( - self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU - ): - super().__init__() - self.width = width - self.layers = layers - self.resblocks = nn.ModuleList( - [ - ResidualAttentionBlock(width, heads, act_layer=act_layer) - for _ in range(layers) - ] - ) - - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - for r in self.resblocks: - x = r(x, attn_mask=attn_mask) - return x - - -class VisualTransformer(nn.Module): - def __init__( - self, - image_size: int, - patch_size: int, - width: int, - layers: int, - heads: int, - output_dim: int, - act_layer: Callable = nn.GELU, - ): - super().__init__() - self.image_size = image_size - self.output_dim = output_dim - self.conv1 = nn.Conv2d( - in_channels=3, - out_channels=width, - kernel_size=patch_size, - stride=patch_size, - bias=False, - ) - - scale = width**-0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter( - scale * torch.randn((image_size // patch_size) ** 2 + 1, width) - ) - self.ln_pre = LayerNorm(width) - - self.text_branch = Transformer(width, layers, heads, act_layer=act_layer) - - self.ln_post = LayerNorm(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) - - def lock(self, unlocked_groups=0, freeze_bn_stats=False): - assert ( - unlocked_groups == 0 - ), "partial locking not currently supported for this model" - for param in self.parameters(): - param.requires_grad = False - - def forward(self, x: torch.Tensor): - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat( - [ - self.class_embedding.to(x.dtype) - + torch.zeros( - x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device - ), - x, - ], - dim=1, - ) # shape = [*, grid ** 2 + 1, width] - x = x + self.positional_embedding.to(x.dtype) - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - x = self.text_branch(x) - x = x.permute(1, 0, 2) # LND -> NLD - - x = self.ln_post(x[:, 0, :]) - - if self.proj is not None: - x = x @ self.proj - - return x - - -@dataclass -class CLAPVisionCfg: - layers: Union[Tuple[int, int, int, int], int] = 12 - width: int = 768 - patch_size: int = 16 - image_size: Union[Tuple[int, int], int] = 224 - timm_model_name: str = ( - None # a valid model name overrides layers, width, patch_size - ) - timm_model_pretrained: bool = ( - False # use (imagenet) pretrained weights for named model - ) - timm_pool: str = ( - "avg" # feature pooling for timm 
model ('abs_attn', 'rot_attn', 'avg', '') - ) - timm_proj: str = ( - "linear" # linear projection for timm model output ('linear', 'mlp', '') - ) - - -# Audio Config Class -@dataclass -class CLAPAudioCfp: - model_type: str = "PANN" - model_name: str = "Cnn14" - sample_rate: int = 48000 - # Param - audio_length: int = 1024 - window_size: int = 1024 - hop_size: int = 1024 - fmin: int = 50 - fmax: int = 14000 - class_num: int = 527 - mel_bins: int = 64 - clip_samples: int = 480000 - - -@dataclass -class CLAPTextCfg: - context_length: int - vocab_size: int - width: int - heads: int - layers: int - model_type: str - - -class CLAP(nn.Module): - def __init__( - self, - embed_dim: int, - audio_cfg: CLAPAudioCfp, - text_cfg: CLAPTextCfg, - quick_gelu: bool = False, - enable_fusion: bool = False, - fusion_type: str = "None", - joint_embed_shape: int = 512, - mlp_act: str = "relu", - ): - super().__init__() - if isinstance(audio_cfg, dict): - audio_cfg = CLAPAudioCfp(**audio_cfg) - if isinstance(text_cfg, dict): - text_cfg = CLAPTextCfg(**text_cfg) - - self.audio_cfg = audio_cfg - self.text_cfg = text_cfg - self.enable_fusion = enable_fusion - self.fusion_type = fusion_type - self.joint_embed_shape = joint_embed_shape - self.mlp_act = mlp_act - - self.context_length = text_cfg.context_length - - # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more - # memory efficient in recent PyTorch releases (>= 1.10). - # NOTE: timm models always use native GELU regardless of quick_gelu flag. - act_layer = QuickGELU if quick_gelu else nn.GELU - - if mlp_act == "relu": - mlp_act_layer = nn.ReLU() - elif mlp_act == "gelu": - mlp_act_layer = nn.GELU() - else: - raise NotImplementedError - - # audio branch - # audio branch parameters - if audio_cfg.model_type == "PANN": - self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type) - elif audio_cfg.model_type == "HTSAT": - self.audio_branch = create_htsat_model( - audio_cfg, enable_fusion, fusion_type - ) - else: - logging.error(f"Model config for {audio_cfg.model_type} not found") - raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.") - - # text branch - # text branch parameters - if text_cfg.model_type == "transformer": - self.text_branch = Transformer( - width=text_cfg.width, - layers=text_cfg.layers, - heads=text_cfg.heads, - act_layer=act_layer, - ) - self.vocab_size = text_cfg.vocab_size - self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width) - self.positional_embedding = nn.Parameter( - torch.empty(self.context_length, text_cfg.width) - ) - self.ln_final = LayerNorm(text_cfg.width) - self.text_transform = MLPLayers( - units=[ - self.joint_embed_shape, - self.joint_embed_shape, - self.joint_embed_shape, - ], - dropout=0.1, - ) - self.text_projection = nn.Sequential( - nn.Linear(text_cfg.width, self.joint_embed_shape), - mlp_act_layer, - nn.Linear(self.joint_embed_shape, self.joint_embed_shape), - ) - elif text_cfg.model_type == "bert": - self.text_branch = BertModel.from_pretrained("bert-base-uncased") - self.text_transform = MLPLayers( - units=[ - self.joint_embed_shape, - self.joint_embed_shape, - self.joint_embed_shape, - ], - dropout=0.1, - ) - self.text_projection = nn.Sequential( - nn.Linear(768, self.joint_embed_shape), - mlp_act_layer, - nn.Linear(self.joint_embed_shape, self.joint_embed_shape), - ) - elif text_cfg.model_type == "roberta": - self.text_branch = RobertaModel( - RobertaConfig.from_pretrained("roberta-base") - ) - self.text_transform = MLPLayers( - 
units=[ - self.joint_embed_shape, - self.joint_embed_shape, - self.joint_embed_shape, - ], - dropout=0.1, - ) - self.text_projection = nn.Sequential( - nn.Linear(768, self.joint_embed_shape), - mlp_act_layer, - nn.Linear(self.joint_embed_shape, self.joint_embed_shape), - ) - elif text_cfg.model_type == "bart": - self.text_branch = BartModel.from_pretrained("facebook/bart-base") - self.text_transform = MLPLayers( - units=[ - self.joint_embed_shape, - self.joint_embed_shape, - self.joint_embed_shape, - ], - dropout=0.1, - ) - self.text_projection = nn.Sequential( - nn.Linear(768, self.joint_embed_shape), - mlp_act_layer, - nn.Linear(self.joint_embed_shape, self.joint_embed_shape), - ) - else: - logging.error(f"Model config for {text_cfg.model_type} not found") - raise RuntimeError(f"Model config for {text_cfg.model_type} not found.") - self.text_branch_type = text_cfg.model_type - # text branch parameters - - # audio branch parameters - self.audio_transform = MLPLayers( - units=[ - self.joint_embed_shape, - self.joint_embed_shape, - self.joint_embed_shape, - ], - dropout=0.1, - ) - - # below here is text branch parameters - - # ============================================================================================================ - self.audio_projection = nn.Sequential( - nn.Linear(embed_dim, self.joint_embed_shape), - mlp_act_layer, - nn.Linear(self.joint_embed_shape, self.joint_embed_shape), - ) - - self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False) - - self.init_text_branch_parameters() - - def init_text_branch_parameters(self): - if self.text_branch_type == "transformer": - nn.init.normal_(self.token_embedding.weight, std=0.02) - nn.init.normal_(self.positional_embedding, std=0.01) - proj_std = (self.text_branch.width**-0.5) * ( - (2 * self.text_branch.layers) ** -0.5 - ) - attn_std = self.text_branch.width**-0.5 - fc_std = (2 * self.text_branch.width) ** -0.5 - for block in self.text_branch.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - if self.text_branch_type == "bert" or self.text_branch_type == "roberta": - self.text_branch.embeddings.word_embeddings.weight.shape[-1] - elif self.text_branch_type == "bart": - self.text_branch.shared.weight.shape[-1] - else: - self.text_branch.width - nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07)) - nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07)) - - # deprecated - # if hasattr(self.visual, 'init_parameters'): - # self.visual.init_parameters() - - # if self.text_projection is not None: - # nn.init.normal_(self.text_projection, std=width**-0.5) - - def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.context_length, self.context_length) - mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal - return mask - - def encode_audio(self, audio, device): - return self.audio_branch( - audio, mixup_lambda=None, device=device - ) # mix lambda needs to add - - # def list_of_dict_of_tensor2dict_of_tensor(self, x, device): - # tmp = {} - # for k in x[0].keys(): - # tmp[k] = [] - # for i in range(len(x)): - # 
tmp[k].append(x[i][k][:77]) - # for k in x[0].keys(): - # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True) - # return tmp - - def encode_text(self, text, device): - if self.text_branch_type == "transformer": - text = text.to(device=device, non_blocking=True) - x = self.token_embedding(text) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding - x = x.permute(1, 0, 2) # NLD -> LND - x = self.text_branch(x, attn_mask=self.attn_mask) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x) - - # x.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)]) - elif self.text_branch_type == "bert": - # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device) - # text = BatchEncoding(text) - x = self.text_branch( - input_ids=text["input_ids"].to(device=device, non_blocking=True), - attention_mask=text["attention_mask"].to( - device=device, non_blocking=True - ), - token_type_ids=text["token_type_ids"].to( - device=device, non_blocking=True - ), - )["pooler_output"] - x = self.text_projection(x) - elif self.text_branch_type == "roberta": - x = self.text_branch( - input_ids=text["input_ids"].to(device=device, non_blocking=True), - attention_mask=text["attention_mask"].to( - device=device, non_blocking=True - ), - )["pooler_output"] - x = self.text_projection(x) - elif self.text_branch_type == "bart": - x = torch.mean( - self.text_branch( - input_ids=text["input_ids"].to(device=device, non_blocking=True), - attention_mask=text["attention_mask"].to( - device=device, non_blocking=True - ), - )["encoder_last_hidden_state"], - axis=1, - ) - x = self.text_projection(x) - else: - logging.error(f"Model type {self.text_branch_type} not found") - raise RuntimeError(f"Model type {self.text_branch_type} not found.") - return x - - def forward(self, audio, text, device=None): - """Forward audio and text into the CLAP - - Parameters - ---------- - audio: torch.Tensor (batch_size, audio_length) - the time-domain audio input / the batch of mel_spec and longer list. 
- text: torch.Tensor () // need to add - the text token input - """ - if device is None: - if audio is not None: - device = audio.device - elif text is not None: - device = text.device - if audio is None and text is None: - # a hack to get the logit scale - return self.logit_scale_a.exp(), self.logit_scale_t.exp() - elif audio is None: - return self.encode_text(text, device=device) - elif text is None: - return self.audio_projection( - self.encode_audio(audio, device=device)["embedding"] - ) - audio_features = self.audio_projection( - self.encode_audio(audio, device=device)["embedding"] - ) - audio_features = F.normalize(audio_features, dim=-1) - - text_features = self.encode_text(text, device=device) - # print("text_features", text_features) - # print("text_features.shape", text_features.shape) - # print("text_features.type", type(text_features)) - text_features = F.normalize(text_features, dim=-1) - - audio_features_mlp = self.audio_transform(audio_features) - text_features_mlp = self.text_transform(text_features) - # Four outputs: audio features (basic & MLP), text features (basic & MLP) - return ( - audio_features, - text_features, - audio_features_mlp, - text_features_mlp, - self.logit_scale_a.exp(), - self.logit_scale_t.exp(), - ) - - def get_logit_scale(self): - return self.logit_scale_a.exp(), self.logit_scale_t.exp() - - def get_text_embedding(self, data): - """Get the text embedding from the model - - Parameters - ---------- - data: torch.Tensor - a tensor of text embedding - - Returns - ---------- - text_embed: torch.Tensor - a tensor of text_embeds (N, D) - - """ - device = next(self.parameters()).device - for k in data: - data[k] = data[k].to(device) - text_embeds = self.encode_text(data, device=device) - text_embeds = F.normalize(text_embeds, dim=-1) - - return text_embeds - - def get_audio_embedding(self, data): - """Get the audio embedding from the model - - Parameters - ---------- - data: a list of dict - the audio input dict list from 'get_audio_feature' method - - Returns - ---------- - audio_embed: torch.Tensor - a tensor of audio_embeds (N, D) - - """ - device = next(self.parameters()).device - # input_dict = {} - # keys = data[0].keys() - # for k in keys: - # input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to( - # device - # ) - audio_embeds = self.audio_projection( - self.encode_audio(data, device=device)["embedding"] - ) - audio_embeds = F.normalize(audio_embeds, dim=-1) - - return audio_embeds - - def audio_infer(self, audio, hopsize=None, device=None): - """Forward one audio and produce the audio embedding - - Parameters - ---------- - audio: (audio_length) - the time-domain audio input, notice that it must be only one input - hopsize: int - the overlap hopsize as the sliding window - - Returns - ---------- - output_dict: { - key: [n, (embedding_shape)] if "HTS-AT" - or - key: [(embedding_shape)] if "PANN" - } - the list of key values of the audio branch - - """ - - assert not self.training, "the inference mode must be run at eval stage" - output_dict = {} - # PANN - if self.audio_cfg.model_type == "PANN": - audio_input = audio.unsqueeze(dim=0) - output_dict[key] = self.encode_audio(audio_input, device=device)[ - key - ].squeeze(dim=0) - elif self.audio_cfg.model_type == "HTSAT": - # repeat - audio_len = len(audio) - k = self.audio_cfg.clip_samples // audio_len - if k > 1: - audio = audio.repeat(k) - audio_len = len(audio) - - if hopsize is None: - hopsize = min(hopsize, audio_len) - - if audio_len > self.audio_cfg.clip_samples: - audio_input = 
[ - audio[pos : pos + self.audio_cfg.clip_samples].clone() - for pos in range( - 0, audio_len - self.audio_cfg.clip_samples, hopsize - ) - ] - audio_input.append(audio[-self.audio_cfg.clip_samples :].clone()) - audio_input = torch.stack(audio_input) - output_dict[key] = self.encode_audio(audio_input, device=device)[key] - else: - audio_input = audio.unsqueeze(dim=0) - output_dict[key] = self.encode_audio(audio_input, device=device)[ - key - ].squeeze(dim=0) - - return output_dict - - -def convert_weights_to_fp16(model: nn.Module): - """Convert applicable model parameters to fp16""" - - def _convert_weights_to_fp16(l): - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): - l.weight.data = l.weight.data.half() - if l.bias is not None: - l.bias.data = l.bias.data.half() - - if isinstance(l, nn.MultiheadAttention): - for attr in [ - *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], - "in_proj_bias", - "bias_k", - "bias_v", - ]: - tensor = getattr(l, attr) - if tensor is not None: - tensor.data = tensor.data.half() - - for name in ["text_projection", "proj"]: - if hasattr(l, name): - attr = getattr(l, name) - if attr is not None: - attr.data = attr.data.half() - - model.apply(_convert_weights_to_fp16) - - -# Ignore the state dict of the vision part -def build_model_from_openai_state_dict( - state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None" -): - embed_dim = model_cfg["embed_dim"] - audio_cfg = model_cfg["audio_cfg"] - text_cfg = model_cfg["text_cfg"] - state_dict["positional_embedding"].shape[0] - state_dict["token_embedding.weight"].shape[0] - transformer_width = state_dict["ln_final.weight"].shape[0] - transformer_width // 64 - transformer_layers = len( - set( - k.split(".")[2] - for k in state_dict - if k.startswith(f"transformer.resblocks") - ) - ) - - audio_cfg = CLAPAudioCfp(**audio_cfg) - text_cfg = CLAPTextCfg(**text_cfg) - - model = CLAP( - embed_dim, - audio_cfg=audio_cfg, - text_cfg=text_cfg, - quick_gelu=True, # OpenAI models were trained with QuickGELU - enable_fusion=enable_fusion, - fusion_type=fusion_type, - ) - state_dict["logit_scale_a"] = state_dict["logit_scale"] - state_dict["logit_scale_t"] = state_dict["logit_scale"] - pop_keys = list(state_dict.keys())[::] - # pop the visual branch saved weights - for key in pop_keys: - if key.startswith("visual."): - state_dict.pop(key, None) - - for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]: - state_dict.pop(key, None) - - # not use fp16 - # convert_weights_to_fp16(model) - model.load_state_dict(state_dict, strict=False) - return model.eval() - - -def trace_model(model, batch_size=256, device=torch.device("cpu")): - model.eval() - audio_length = model.audio_cfg.audio_length - example_audio = torch.ones((batch_size, audio_length), device=device) - example_text = torch.zeros( - (batch_size, model.context_length), dtype=torch.int, device=device - ) - model = torch.jit.trace_module( - model, - inputs=dict( - forward=(example_audio, example_text), - encode_text=(example_text,), - encode_image=(example_audio,), - ), - ) - model.audio_cfg.audio_length = audio_length # Question: what does this do? 
- return model diff --git a/audioldm2/clap/open_clip/model_configs/HTSAT-base.json b/audioldm2/clap/open_clip/model_configs/HTSAT-base.json deleted file mode 100755 index 6cef625a89daf4431f1c9f72e10bc9640eef2ba8..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/HTSAT-base.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 1024, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 480, - "fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "HTSAT", - "model_name": "base" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/HTSAT-large.json b/audioldm2/clap/open_clip/model_configs/HTSAT-large.json deleted file mode 100755 index 699cdb1b16855582606551e4196b24aba2ffd871..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/HTSAT-large.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 2048, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 480, - "fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "HTSAT", - "model_name": "large" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json b/audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json deleted file mode 100755 index 73e42990fe8361a0df502e7f93d29f19f58c9ecb..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 768, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1536, - "hop_size": 480, - "fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "HTSAT", - "model_name": "tiny" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json b/audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json deleted file mode 100755 index a6e7821163d9afa81c27345a1e472475b92af169..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 768, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 480, - "fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "HTSAT", - "model_name": "tiny" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/PANN-10.json b/audioldm2/clap/open_clip/model_configs/PANN-10.json deleted file mode 100755 index 954ddf62921aed7dde9c37ffffec98a2e96a4ee7..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/PANN-10.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 1024, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 480, - 
"fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "PANN", - "model_name": "Cnn10" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json b/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json deleted file mode 100755 index b7989bc0cd95d0d39049b7524eba508b3e386439..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 2048, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 480, - "fmin": 50, - "fmax": 18000, - "class_num": 527, - "model_type": "PANN", - "model_name": "Cnn14" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json b/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json deleted file mode 100755 index 56bdb56bedc304ffa52d8bf5988cea2c1d82d14e..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 2048, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 960000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 360, - "fmin": 50, - "fmax": 8000, - "class_num": 527, - "model_type": "PANN", - "model_name": "Cnn14" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json b/audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json deleted file mode 100755 index 5756e3bebc97cc985f512cb081930fee4e49bec1..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 2048, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 480, - "fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "PANN", - "model_name": "Cnn14" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 4 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json b/audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json deleted file mode 100755 index 5a9e7e208b661619d5e26625e849da1adda8a475..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 2048, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1536, - "hop_size": 480, - "fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "PANN", - "model_name": "Cnn14" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14.json b/audioldm2/clap/open_clip/model_configs/PANN-14.json deleted file mode 100755 index 
39a5134cde1d8c50f4758377c952ef22f07bab41..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/PANN-14.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 2048, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 480, - "fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "PANN", - "model_name": "Cnn14" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/PANN-6.json b/audioldm2/clap/open_clip/model_configs/PANN-6.json deleted file mode 100755 index 21ebc344326de260c386ba77e0ad63cf9b04febf..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/PANN-6.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 512, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 480, - "fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "PANN", - "model_name": "Cnn6" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json b/audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json deleted file mode 100755 index d0db2c161d13138788c4609d373b023b8454d624..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "embed_dim": 512, - "quick_gelu": true, - "vision_cfg": { - "image_size": 224, - "layers": [ - 3, - 4, - 23, - 3 - ], - "width": 64, - "patch_size": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/RN101.json b/audioldm2/clap/open_clip/model_configs/RN101.json deleted file mode 100755 index b88b4d3acbaa701c614ab0ea65fc88fcfe289c32..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/RN101.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "image_size": 224, - "layers": [ - 3, - 4, - 23, - 3 - ], - "width": 64, - "patch_size": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json b/audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json deleted file mode 100755 index 8c2f91260cdeb043434dc1e893cce81d4ce7f0d1..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "embed_dim": 1024, - "quick_gelu": true, - "vision_cfg": { - "image_size": 224, - "layers": [ - 3, - 4, - 6, - 3 - ], - "width": 64, - "patch_size": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} diff --git a/audioldm2/clap/open_clip/model_configs/RN50.json b/audioldm2/clap/open_clip/model_configs/RN50.json deleted file mode 100755 index 33aa884d54fee0076c33676831e49d5e1ffcb8f2..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/RN50.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "embed_dim": 1024, - 
"vision_cfg": { - "image_size": 224, - "layers": [ - 3, - 4, - 6, - 3 - ], - "width": 64, - "patch_size": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/RN50x16.json b/audioldm2/clap/open_clip/model_configs/RN50x16.json deleted file mode 100755 index 3161e1a2c9a839161e652a4d729c2cdc971161db..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/RN50x16.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 384, - "layers": [ - 6, - 8, - 18, - 8 - ], - "width": 96, - "patch_size": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/RN50x4.json b/audioldm2/clap/open_clip/model_configs/RN50x4.json deleted file mode 100755 index e155237f8ce1026aaaeecc80751eabe6f329f0bb..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/RN50x4.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "embed_dim": 640, - "vision_cfg": { - "image_size": 288, - "layers": [ - 4, - 6, - 10, - 6 - ], - "width": 80, - "patch_size": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 640, - "heads": 10, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/ViT-B-16.json b/audioldm2/clap/open_clip/model_configs/ViT-B-16.json deleted file mode 100755 index 395eea77ec3907c0611531aba63459b193e67b9c..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/ViT-B-16.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 768, - "patch_size": 16 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json b/audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json deleted file mode 100755 index ce6bd923593293ed50dfcfb28b73ca7403bcf3c5..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "embed_dim": 512, - "quick_gelu": true, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 768, - "patch_size": 32 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/ViT-B-32.json b/audioldm2/clap/open_clip/model_configs/ViT-B-32.json deleted file mode 100755 index 07c8e28eb06fa1813ba932fe4eec668262d1c47f..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/model_configs/ViT-B-32.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 768, - "patch_size": 32 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/model_configs/ViT-L-14.json b/audioldm2/clap/open_clip/model_configs/ViT-L-14.json deleted file mode 100755 index d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241..0000000000000000000000000000000000000000 --- 
a/audioldm2/clap/open_clip/model_configs/ViT-L-14.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 224, - "layers": 24, - "width": 1024, - "patch_size": 14 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm2/clap/open_clip/openai.py b/audioldm2/clap/open_clip/openai.py deleted file mode 100755 index 3f4eb8b55fe960e1792b3da804b60b3d8f70fe26..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/openai.py +++ /dev/null @@ -1,156 +0,0 @@ -""" OpenAI pretrained model functions - -Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. -""" - -import os -import warnings -from typing import Union, List - -import torch - -from .model import build_model_from_openai_state_dict -from .pretrained import ( - get_pretrained_url, - list_pretrained_tag_models, - download_pretrained, -) - -__all__ = ["list_openai_models", "load_openai_model"] - - -def list_openai_models() -> List[str]: - """Returns the names of available CLIP models""" - return list_pretrained_tag_models("openai") - - -def load_openai_model( - name: str, - model_cfg, - device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", - jit=True, - cache_dir=os.path.expanduser("~/.cache/clip"), - enable_fusion: bool = False, - fusion_type: str = "None", -): - """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model - - Parameters - ---------- - name : str - A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict - device : Union[str, torch.device] - The device to put the loaded model - jit : bool - Whether to load the optimized JIT model (default) or more hackable non-JIT model. - - Returns - ------- - model : torch.nn.Module - The CLAP model - preprocess : Callable[[PIL.Image], torch.Tensor] - A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input - """ - if get_pretrained_url(name, "openai"): - model_path = download_pretrained( - get_pretrained_url(name, "openai"), root=cache_dir - ) - elif os.path.isfile(name): - model_path = name - else: - raise RuntimeError( - f"Model {name} not found; available models = {list_openai_models()}" - ) - - try: - # loading JIT archive - model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() - state_dict = None - except RuntimeError: - # loading saved state dict - if jit: - warnings.warn( - f"File {model_path} is not a JIT archive. 
Loading as a state dict instead" - ) - jit = False - state_dict = torch.load(model_path, map_location="cpu") - - if not jit: - try: - model = build_model_from_openai_state_dict( - state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type - ).to(device) - except KeyError: - sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} - model = build_model_from_openai_state_dict( - sd, model_cfg, enable_fusion, fusion_type - ).to(device) - - if str(device) == "cpu": - model.float() - return model - - # patch the device names - device_holder = torch.jit.trace( - lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] - ) - device_node = [ - n - for n in device_holder.graph.findAllNodes("prim::Constant") - if "Device" in repr(n) - ][-1] - - def patch_device(module): - try: - graphs = [module.graph] if hasattr(module, "graph") else [] - except RuntimeError: - graphs = [] - - if hasattr(module, "forward1"): - graphs.append(module.forward1.graph) - - for graph in graphs: - for node in graph.findAllNodes("prim::Constant"): - if "value" in node.attributeNames() and str(node["value"]).startswith( - "cuda" - ): - node.copyAttributes(device_node) - - model.apply(patch_device) - patch_device(model.encode_audio) - patch_device(model.encode_text) - - # patch dtype to float32 on CPU - if str(device) == "cpu": - float_holder = torch.jit.trace( - lambda: torch.ones([]).float(), example_inputs=[] - ) - float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] - float_node = float_input.node() - - def patch_float(module): - try: - graphs = [module.graph] if hasattr(module, "graph") else [] - except RuntimeError: - graphs = [] - - if hasattr(module, "forward1"): - graphs.append(module.forward1.graph) - - for graph in graphs: - for node in graph.findAllNodes("aten::to"): - inputs = list(node.inputs()) - for i in [ - 1, - 2, - ]: # dtype can be the second or third argument to aten::to() - if inputs[i].node()["value"] == 5: - inputs[i].node().copyAttributes(float_node) - - model.apply(patch_float) - patch_float(model.encode_audio) - patch_float(model.encode_text) - model.float() - - model.audio_branch.audio_length = model.audio_cfg.audio_length - return model diff --git a/audioldm2/clap/open_clip/pann_model.py b/audioldm2/clap/open_clip/pann_model.py deleted file mode 100755 index e9fab8e03cdca370c141a9e321e98d256e79fb27..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/pann_model.py +++ /dev/null @@ -1,697 +0,0 @@ -# PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition -# Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn -# Some layers are re-designed for CLAP -import os - -os.environ["NUMBA_CACHE_DIR"] = "/tmp/" - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchlibrosa.stft import Spectrogram, LogmelFilterBank -from torchlibrosa.augmentation import SpecAugmentation - -from .utils import do_mixup, interpolate -from .feature_fusion import iAFF, AFF, DAF - - -def init_layer(layer): - """Initialize a Linear or Convolutional layer.""" - nn.init.xavier_uniform_(layer.weight) - - if hasattr(layer, "bias"): - if layer.bias is not None: - layer.bias.data.fill_(0.0) - - -def init_bn(bn): - """Initialize a Batchnorm layer.""" - bn.bias.data.fill_(0.0) - bn.weight.data.fill_(1.0) - - -class ConvBlock(nn.Module): - def __init__(self, in_channels, out_channels): - super(ConvBlock, self).__init__() - - self.conv1 = nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - 
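[Editor's note, not part of the diff] The load_openai_model logic above first tries to read the checkpoint as a TorchScript archive and, if that fails, falls back to a plain state dict. A minimal standalone sketch of that pattern; load_checkpoint and its path argument are illustrative names, not part of the repository.

# Sketch of the try-JIT-then-state-dict pattern used above (names are illustrative).
import warnings
import torch

def load_checkpoint(path, map_location="cpu"):
    try:
        # TorchScript archive: returns a ScriptModule ready for inference.
        return torch.jit.load(path, map_location=map_location).eval(), None
    except RuntimeError:
        warnings.warn(f"{path} is not a JIT archive; loading as a state dict instead")
        return None, torch.load(path, map_location=map_location)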
kernel_size=(3, 3), - stride=(1, 1), - padding=(1, 1), - bias=False, - ) - - self.conv2 = nn.Conv2d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=(3, 3), - stride=(1, 1), - padding=(1, 1), - bias=False, - ) - - self.bn1 = nn.BatchNorm2d(out_channels) - self.bn2 = nn.BatchNorm2d(out_channels) - - self.init_weight() - - def init_weight(self): - init_layer(self.conv1) - init_layer(self.conv2) - init_bn(self.bn1) - init_bn(self.bn2) - - def forward(self, input, pool_size=(2, 2), pool_type="avg"): - x = input - x = F.relu_(self.bn1(self.conv1(x))) - x = F.relu_(self.bn2(self.conv2(x))) - if pool_type == "max": - x = F.max_pool2d(x, kernel_size=pool_size) - elif pool_type == "avg": - x = F.avg_pool2d(x, kernel_size=pool_size) - elif pool_type == "avg+max": - x1 = F.avg_pool2d(x, kernel_size=pool_size) - x2 = F.max_pool2d(x, kernel_size=pool_size) - x = x1 + x2 - else: - raise Exception("Incorrect argument!") - - return x - - -class ConvBlock5x5(nn.Module): - def __init__(self, in_channels, out_channels): - super(ConvBlock5x5, self).__init__() - - self.conv1 = nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=(5, 5), - stride=(1, 1), - padding=(2, 2), - bias=False, - ) - - self.bn1 = nn.BatchNorm2d(out_channels) - - self.init_weight() - - def init_weight(self): - init_layer(self.conv1) - init_bn(self.bn1) - - def forward(self, input, pool_size=(2, 2), pool_type="avg"): - x = input - x = F.relu_(self.bn1(self.conv1(x))) - if pool_type == "max": - x = F.max_pool2d(x, kernel_size=pool_size) - elif pool_type == "avg": - x = F.avg_pool2d(x, kernel_size=pool_size) - elif pool_type == "avg+max": - x1 = F.avg_pool2d(x, kernel_size=pool_size) - x2 = F.max_pool2d(x, kernel_size=pool_size) - x = x1 + x2 - else: - raise Exception("Incorrect argument!") - - return x - - -class AttBlock(nn.Module): - def __init__(self, n_in, n_out, activation="linear", temperature=1.0): - super(AttBlock, self).__init__() - - self.activation = activation - self.temperature = temperature - self.att = nn.Conv1d( - in_channels=n_in, - out_channels=n_out, - kernel_size=1, - stride=1, - padding=0, - bias=True, - ) - self.cla = nn.Conv1d( - in_channels=n_in, - out_channels=n_out, - kernel_size=1, - stride=1, - padding=0, - bias=True, - ) - - self.bn_att = nn.BatchNorm1d(n_out) - self.init_weights() - - def init_weights(self): - init_layer(self.att) - init_layer(self.cla) - init_bn(self.bn_att) - - def forward(self, x): - # x: (n_samples, n_in, n_time) - norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) - cla = self.nonlinear_transform(self.cla(x)) - x = torch.sum(norm_att * cla, dim=2) - return x, norm_att, cla - - def nonlinear_transform(self, x): - if self.activation == "linear": - return x - elif self.activation == "sigmoid": - return torch.sigmoid(x) - - -class Cnn14(nn.Module): - def __init__( - self, - sample_rate, - window_size, - hop_size, - mel_bins, - fmin, - fmax, - classes_num, - enable_fusion=False, - fusion_type="None", - ): - super(Cnn14, self).__init__() - - window = "hann" - center = True - pad_mode = "reflect" - ref = 1.0 - amin = 1e-10 - top_db = None - - self.enable_fusion = enable_fusion - self.fusion_type = fusion_type - - # Spectrogram extractor - self.spectrogram_extractor = Spectrogram( - n_fft=window_size, - hop_length=hop_size, - win_length=window_size, - window=window, - center=center, - pad_mode=pad_mode, - freeze_parameters=True, - ) - - # Logmel feature extractor - self.logmel_extractor = LogmelFilterBank( - sr=sample_rate, - 
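[Editor's note, not part of the diff] The AttBlock defined above pools a (batch, channels, time) feature map with a softmax attention over time. A standalone shape check with random inputs:

# Shape check for the AttBlock-style attention pooling above (random inputs).
import torch

B, C, T, n_out = 2, 64, 100, 527
att = torch.nn.Conv1d(C, n_out, kernel_size=1)
cla = torch.nn.Conv1d(C, n_out, kernel_size=1)

x = torch.randn(B, C, T)
norm_att = torch.softmax(torch.clamp(att(x), -10, 10), dim=-1)   # (B, n_out, T), sums to 1 over time
clipwise = torch.sum(norm_att * torch.sigmoid(cla(x)), dim=2)    # (B, n_out)
print(clipwise.shape)   # torch.Size([2, 527])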
n_fft=window_size, - n_mels=mel_bins, - fmin=fmin, - fmax=fmax, - ref=ref, - amin=amin, - top_db=top_db, - freeze_parameters=True, - ) - - # Spec augmenter - self.spec_augmenter = SpecAugmentation( - time_drop_width=64, - time_stripes_num=2, - freq_drop_width=8, - freq_stripes_num=2, - ) - - self.bn0 = nn.BatchNorm2d(64) - - if (self.enable_fusion) and (self.fusion_type == "channel_map"): - self.conv_block1 = ConvBlock(in_channels=4, out_channels=64) - else: - self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) - self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) - self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) - self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) - self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) - self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) - - self.fc1 = nn.Linear(2048, 2048, bias=True) - self.fc_audioset = nn.Linear(2048, classes_num, bias=True) - - if (self.enable_fusion) and ( - self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"] - ): - self.mel_conv1d = nn.Sequential( - nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2), - nn.BatchNorm1d(64), # No Relu - ) - if self.fusion_type == "daf_1d": - self.fusion_model = DAF() - elif self.fusion_type == "aff_1d": - self.fusion_model = AFF(channels=64, type="1D") - elif self.fusion_type == "iaff_1d": - self.fusion_model = iAFF(channels=64, type="1D") - - if (self.enable_fusion) and ( - self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] - ): - self.mel_conv2d = nn.Sequential( - nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)), - nn.BatchNorm2d(64), - nn.ReLU(inplace=True), - ) - - if self.fusion_type == "daf_2d": - self.fusion_model = DAF() - elif self.fusion_type == "aff_2d": - self.fusion_model = AFF(channels=64, type="2D") - elif self.fusion_type == "iaff_2d": - self.fusion_model = iAFF(channels=64, type="2D") - self.init_weight() - - def init_weight(self): - init_bn(self.bn0) - init_layer(self.fc1) - init_layer(self.fc_audioset) - - def forward(self, input, mixup_lambda=None, device=None): - """ - Input: (batch_size, data_length)""" - - if self.enable_fusion and input["longer"].sum() == 0: - # if no audio is longer than 10s, then randomly select one audio to be longer - input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True - - if not self.enable_fusion: - x = self.spectrogram_extractor( - input["waveform"].to(device=device, non_blocking=True) - ) # (batch_size, 1, time_steps, freq_bins) - x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) - - x = x.transpose(1, 3) - x = self.bn0(x) - x = x.transpose(1, 3) - else: - longer_list = input["longer"].to(device=device, non_blocking=True) - x = input["mel_fusion"].to(device=device, non_blocking=True) - longer_list_idx = torch.where(longer_list)[0] - x = x.transpose(1, 3) - x = self.bn0(x) - x = x.transpose(1, 3) - if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]: - new_x = x[:, 0:1, :, :].clone().contiguous() - # local processing - if len(longer_list_idx) > 0: - fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous() - FB, FC, FT, FF = fusion_x_local.size() - fusion_x_local = fusion_x_local.view(FB * FC, FT, FF) - fusion_x_local = torch.permute( - fusion_x_local, (0, 2, 1) - ).contiguous() - fusion_x_local = self.mel_conv1d(fusion_x_local) - fusion_x_local = fusion_x_local.view( - FB, FC, FF, fusion_x_local.size(-1) - ) - fusion_x_local = ( - torch.permute(fusion_x_local, (0, 2, 1, 3)) - .contiguous() - .flatten(2) - ) 
- if fusion_x_local.size(-1) < FT: - fusion_x_local = torch.cat( - [ - fusion_x_local, - torch.zeros( - (FB, FF, FT - fusion_x_local.size(-1)), - device=device, - ), - ], - dim=-1, - ) - else: - fusion_x_local = fusion_x_local[:, :, :FT] - # 1D fusion - new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous() - new_x[longer_list_idx] = self.fusion_model( - new_x[longer_list_idx], fusion_x_local - ) - x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :] - else: - x = new_x - elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]: - x = x # no change - - if self.training: - x = self.spec_augmenter(x) - # Mixup on spectrogram - if self.training and mixup_lambda is not None: - x = do_mixup(x, mixup_lambda) - if (self.enable_fusion) and ( - self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] - ): - global_x = x[:, 0:1, :, :] - - # global processing - B, C, H, W = global_x.shape - global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg") - if len(longer_list_idx) > 0: - local_x = x[longer_list_idx, 1:, :, :].contiguous() - TH = global_x.size(-2) - # local processing - B, C, H, W = local_x.shape - local_x = local_x.view(B * C, 1, H, W) - local_x = self.mel_conv2d(local_x) - local_x = local_x.view( - B, C, local_x.size(1), local_x.size(2), local_x.size(3) - ) - local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3) - TB, TC, _, TW = local_x.size() - if local_x.size(-2) < TH: - local_x = torch.cat( - [ - local_x, - torch.zeros( - (TB, TC, TH - local_x.size(-2), TW), - device=global_x.device, - ), - ], - dim=-2, - ) - else: - local_x = local_x[:, :, :TH, :] - - global_x[longer_list_idx] = self.fusion_model( - global_x[longer_list_idx], local_x - ) - x = global_x - else: - x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") - - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = torch.mean(x, dim=3) - - latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) - latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) - latent_x = latent_x1 + latent_x2 - latent_x = latent_x.transpose(1, 2) - latent_x = F.relu_(self.fc1(latent_x)) - latent_output = interpolate(latent_x, 32) - - (x1, _) = torch.max(x, dim=2) - x2 = torch.mean(x, dim=2) - x = x1 + x2 - x = F.dropout(x, p=0.5, training=self.training) - x = F.relu_(self.fc1(x)) - embedding = F.dropout(x, p=0.5, training=self.training) - clipwise_output = torch.sigmoid(self.fc_audioset(x)) - - output_dict = { - "clipwise_output": clipwise_output, - "embedding": embedding, - "fine_grained_embedding": latent_output, - } - return output_dict - - -class Cnn6(nn.Module): - def __init__( - self, - sample_rate, - window_size, - hop_size, - mel_bins, - fmin, - fmax, - classes_num, - enable_fusion=False, - fusion_type="None", - ): - super(Cnn6, self).__init__() - - window = "hann" - center = True - pad_mode = "reflect" - ref = 1.0 - amin = 1e-10 - top_db = None - - self.enable_fusion = enable_fusion - self.fusion_type = fusion_type - - # Spectrogram 
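[Editor's note, not part of the diff] The tail of Cnn14.forward above reduces the final conv feature map to a clip embedding and multi-label probabilities. A standalone shape walk-through with random features (dropout omitted, sizes illustrative):

# Shape walk-through of the PANN clip-level head above (random features, dropout omitted).
import torch
import torch.nn.functional as F

B, C, T, M = 2, 2048, 31, 2          # feature map after the six conv blocks (illustrative)
fc1 = torch.nn.Linear(C, C)
fc_audioset = torch.nn.Linear(C, 527)

x = torch.randn(B, C, T, M)
x = torch.mean(x, dim=3)                              # average over mel bins -> (B, C, T)
x = torch.amax(x, dim=2) + torch.mean(x, dim=2)       # max + mean over time  -> (B, C)
embedding = F.relu_(fc1(x))                           # (B, 2048) clip embedding
clipwise_output = torch.sigmoid(fc_audioset(embedding))  # (B, 527) per-class probabilities
print(embedding.shape, clipwise_output.shape)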
extractor - self.spectrogram_extractor = Spectrogram( - n_fft=window_size, - hop_length=hop_size, - win_length=window_size, - window=window, - center=center, - pad_mode=pad_mode, - freeze_parameters=True, - ) - - # Logmel feature extractor - self.logmel_extractor = LogmelFilterBank( - sr=sample_rate, - n_fft=window_size, - n_mels=mel_bins, - fmin=fmin, - fmax=fmax, - ref=ref, - amin=amin, - top_db=top_db, - freeze_parameters=True, - ) - - # Spec augmenter - self.spec_augmenter = SpecAugmentation( - time_drop_width=64, - time_stripes_num=2, - freq_drop_width=8, - freq_stripes_num=2, - ) - - self.bn0 = nn.BatchNorm2d(64) - - self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64) - self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128) - self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256) - self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512) - - self.fc1 = nn.Linear(512, 512, bias=True) - self.fc_audioset = nn.Linear(512, classes_num, bias=True) - - self.init_weight() - - def init_weight(self): - init_bn(self.bn0) - init_layer(self.fc1) - init_layer(self.fc_audioset) - - def forward(self, input, mixup_lambda=None, device=None): - """ - Input: (batch_size, data_length)""" - - x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) - x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) - - x = x.transpose(1, 3) - x = self.bn0(x) - x = x.transpose(1, 3) - - if self.training: - x = self.spec_augmenter(x) - - # Mixup on spectrogram - if self.training and mixup_lambda is not None: - x = do_mixup(x, mixup_lambda) - - x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = torch.mean(x, dim=3) - - latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) - latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) - latent_x = latent_x1 + latent_x2 - latent_x = latent_x.transpose(1, 2) - latent_x = F.relu_(self.fc1(latent_x)) - latent_output = interpolate(latent_x, 16) - - (x1, _) = torch.max(x, dim=2) - x2 = torch.mean(x, dim=2) - x = x1 + x2 - x = F.dropout(x, p=0.5, training=self.training) - x = F.relu_(self.fc1(x)) - embedding = F.dropout(x, p=0.5, training=self.training) - clipwise_output = torch.sigmoid(self.fc_audioset(x)) - - output_dict = { - "clipwise_output": clipwise_output, - "embedding": embedding, - "fine_grained_embedding": latent_output, - } - - return output_dict - - -class Cnn10(nn.Module): - def __init__( - self, - sample_rate, - window_size, - hop_size, - mel_bins, - fmin, - fmax, - classes_num, - enable_fusion=False, - fusion_type="None", - ): - super(Cnn10, self).__init__() - - window = "hann" - center = True - pad_mode = "reflect" - ref = 1.0 - amin = 1e-10 - top_db = None - - self.enable_fusion = enable_fusion - self.fusion_type = fusion_type - - # Spectrogram extractor - self.spectrogram_extractor = Spectrogram( - n_fft=window_size, - hop_length=hop_size, - win_length=window_size, - window=window, - center=center, - pad_mode=pad_mode, - freeze_parameters=True, - ) - - # Logmel feature extractor - self.logmel_extractor = LogmelFilterBank( - sr=sample_rate, - n_fft=window_size, - 
n_mels=mel_bins, - fmin=fmin, - fmax=fmax, - ref=ref, - amin=amin, - top_db=top_db, - freeze_parameters=True, - ) - - # Spec augmenter - self.spec_augmenter = SpecAugmentation( - time_drop_width=64, - time_stripes_num=2, - freq_drop_width=8, - freq_stripes_num=2, - ) - - self.bn0 = nn.BatchNorm2d(64) - - self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) - self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) - self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) - self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) - self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) - - self.fc1 = nn.Linear(1024, 1024, bias=True) - self.fc_audioset = nn.Linear(1024, classes_num, bias=True) - - self.init_weight() - - def init_weight(self): - init_bn(self.bn0) - init_layer(self.fc1) - init_layer(self.fc_audioset) - - def forward(self, input, mixup_lambda=None, device=None): - """ - Input: (batch_size, data_length)""" - - x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) - x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) - - x = x.transpose(1, 3) - x = self.bn0(x) - x = x.transpose(1, 3) - - if self.training: - x = self.spec_augmenter(x) - - # Mixup on spectrogram - if self.training and mixup_lambda is not None: - x = do_mixup(x, mixup_lambda) - - x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = torch.mean(x, dim=3) - - latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) - latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) - latent_x = latent_x1 + latent_x2 - latent_x = latent_x.transpose(1, 2) - latent_x = F.relu_(self.fc1(latent_x)) - latent_output = interpolate(latent_x, 32) - - (x1, _) = torch.max(x, dim=2) - x2 = torch.mean(x, dim=2) - x = x1 + x2 - x = F.dropout(x, p=0.5, training=self.training) - x = F.relu_(self.fc1(x)) - embedding = F.dropout(x, p=0.5, training=self.training) - clipwise_output = torch.sigmoid(self.fc_audioset(x)) - - output_dict = { - "clipwise_output": clipwise_output, - "embedding": embedding, - "fine_grained_embedding": latent_output, - } - - return output_dict - - -def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"): - try: - ModelProto = eval(audio_cfg.model_name) - model = ModelProto( - sample_rate=audio_cfg.sample_rate, - window_size=audio_cfg.window_size, - hop_size=audio_cfg.hop_size, - mel_bins=audio_cfg.mel_bins, - fmin=audio_cfg.fmin, - fmax=audio_cfg.fmax, - classes_num=audio_cfg.class_num, - enable_fusion=enable_fusion, - fusion_type=fusion_type, - ) - return model - except: - raise RuntimeError( - f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough." 
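[Editor's note, not part of the diff] create_pann_model above resolves the class with eval(audio_cfg.model_name). A dictionary registry is a common safer alternative; the sketch below is only an illustration and assumes the Cnn6/Cnn10/Cnn14 classes from the deleted pann_model.py are in scope.

# Alternative sketch only (not the repository's code): a registry instead of eval().
_PANN_MODELS = {"Cnn6": Cnn6, "Cnn10": Cnn10, "Cnn14": Cnn14}   # classes defined above

def create_pann_model_safe(audio_cfg, enable_fusion=False, fusion_type="None"):
    try:
        ModelProto = _PANN_MODELS[audio_cfg.model_name]
    except KeyError:
        raise RuntimeError(f"Unknown PANN model name: {audio_cfg.model_name}")
    return ModelProto(
        sample_rate=audio_cfg.sample_rate,
        window_size=audio_cfg.window_size,
        hop_size=audio_cfg.hop_size,
        mel_bins=audio_cfg.mel_bins,
        fmin=audio_cfg.fmin,
        fmax=audio_cfg.fmax,
        classes_num=audio_cfg.class_num,
        enable_fusion=enable_fusion,
        fusion_type=fusion_type,
    )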
- ) diff --git a/audioldm2/clap/open_clip/pretrained.py b/audioldm2/clap/open_clip/pretrained.py deleted file mode 100755 index e211d8b5b59320a599e62605f1dee6199f317253..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/pretrained.py +++ /dev/null @@ -1,167 +0,0 @@ -import hashlib -import os -import urllib -import warnings - -from tqdm import tqdm - -_RN50 = dict( - openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", - yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", - cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", -) - -_RN50_quickgelu = dict( - openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", - yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", - cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", -) - -_RN101 = dict( - openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", - yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", -) - -_RN101_quickgelu = dict( - openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", - yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", -) - -_RN50x4 = dict( - openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", -) - -_RN50x16 = dict( - openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", -) - -_RN50x64 = dict( - openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", -) - -_VITB32 = dict( - openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", - laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", - laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", - laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", -) - -_VITB32_quickgelu = dict( - openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", - laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", - laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", - laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", -) - -_VITB16 = dict( - openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", -) - -_VITL14 = dict( - 
openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", -) - -_PRETRAINED = { - "RN50": _RN50, - "RN50-quickgelu": _RN50_quickgelu, - "RN101": _RN101, - "RN101-quickgelu": _RN101_quickgelu, - "RN50x4": _RN50x4, - "RN50x16": _RN50x16, - "ViT-B-32": _VITB32, - "ViT-B-32-quickgelu": _VITB32_quickgelu, - "ViT-B-16": _VITB16, - "ViT-L-14": _VITL14, -} - - -def list_pretrained(as_str: bool = False): - """returns list of pretrained models - Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True - """ - return [ - ":".join([k, t]) if as_str else (k, t) - for k in _PRETRAINED.keys() - for t in _PRETRAINED[k].keys() - ] - - -def list_pretrained_tag_models(tag: str): - """return all models having the specified pretrain tag""" - models = [] - for k in _PRETRAINED.keys(): - if tag in _PRETRAINED[k]: - models.append(k) - return models - - -def list_pretrained_model_tags(model: str): - """return all pretrain tags for the specified model architecture""" - tags = [] - if model in _PRETRAINED: - tags.extend(_PRETRAINED[model].keys()) - return tags - - -def get_pretrained_url(model: str, tag: str): - if model not in _PRETRAINED: - return "" - model_pretrained = _PRETRAINED[model] - if tag not in model_pretrained: - return "" - return model_pretrained[tag] - - -def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")): - os.makedirs(root, exist_ok=True) - filename = os.path.basename(url) - - if "openaipublic" in url: - expected_sha256 = url.split("/")[-2] - else: - expected_sha256 = "" - - download_target = os.path.join(root, filename) - - if os.path.exists(download_target) and not os.path.isfile(download_target): - raise RuntimeError(f"{download_target} exists and is not a regular file") - - if os.path.isfile(download_target): - if expected_sha256: - if ( - hashlib.sha256(open(download_target, "rb").read()).hexdigest() - == expected_sha256 - ): - return download_target - else: - warnings.warn( - f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" - ) - else: - return download_target - - with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: - with tqdm( - total=int(source.info().get("Content-Length")), - ncols=80, - unit="iB", - unit_scale=True, - ) as loop: - while True: - buffer = source.read(8192) - if not buffer: - break - - output.write(buffer) - loop.update(len(buffer)) - - if ( - expected_sha256 - and hashlib.sha256(open(download_target, "rb").read()).hexdigest() - != expected_sha256 - ): - raise RuntimeError( - f"Model has been downloaded but the SHA256 checksum does not not match" - ) - - return download_target diff --git a/audioldm2/clap/open_clip/timm_model.py b/audioldm2/clap/open_clip/timm_model.py deleted file mode 100755 index b8486b9e62580bb65f0f50a0a7000890cb7ee42d..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/timm_model.py +++ /dev/null @@ -1,112 +0,0 @@ -""" timm model adapter - -Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
-""" -from collections import OrderedDict - -import torch.nn as nn - -try: - import timm - from timm.models.layers import Mlp, to_2tuple - from timm.models.layers.attention_pool2d import RotAttentionPool2d - from timm.models.layers.attention_pool2d import ( - AttentionPool2d as AbsAttentionPool2d, - ) -except ImportError: - timm = None - -from .utils import freeze_batch_norm_2d - - -class TimmModel(nn.Module): - """timm model adapter - # FIXME this adapter is a work in progress, may change in ways that break weight compat - """ - - def __init__( - self, - model_name, - embed_dim, - image_size=224, - pool="avg", - proj="linear", - drop=0.0, - pretrained=False, - ): - super().__init__() - if timm is None: - raise RuntimeError("Please `pip install timm` to use timm models.") - - self.image_size = to_2tuple(image_size) - self.trunk = timm.create_model(model_name, pretrained=pretrained) - feat_size = self.trunk.default_cfg.get("pool_size", None) - feature_ndim = 1 if not feat_size else 2 - if pool in ("abs_attn", "rot_attn"): - assert feature_ndim == 2 - # if attn pooling used, remove both classifier and default pool - self.trunk.reset_classifier(0, global_pool="") - else: - # reset global pool if pool config set, otherwise leave as network default - reset_kwargs = dict(global_pool=pool) if pool else {} - self.trunk.reset_classifier(0, **reset_kwargs) - prev_chs = self.trunk.num_features - - head_layers = OrderedDict() - if pool == "abs_attn": - head_layers["pool"] = AbsAttentionPool2d( - prev_chs, feat_size=feat_size, out_features=embed_dim - ) - prev_chs = embed_dim - elif pool == "rot_attn": - head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) - prev_chs = embed_dim - else: - assert proj, "projection layer needed if non-attention pooling is used." 
- - # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used - if proj == "linear": - head_layers["drop"] = nn.Dropout(drop) - head_layers["proj"] = nn.Linear(prev_chs, embed_dim) - elif proj == "mlp": - head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) - - self.head = nn.Sequential(head_layers) - - def lock(self, unlocked_groups=0, freeze_bn_stats=False): - """lock modules - Args: - unlocked_groups (int): leave last n layer groups unlocked (default: 0) - """ - if not unlocked_groups: - # lock full model - for param in self.trunk.parameters(): - param.requires_grad = False - if freeze_bn_stats: - freeze_batch_norm_2d(self.trunk) - else: - # NOTE: partial freeze requires latest timm (master) branch and is subject to change - try: - # FIXME import here until API stable and in an official release - from timm.models.helpers import group_parameters, group_modules - except ImportError: - raise RuntimeError( - "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`" - ) - matcher = self.trunk.group_matcher() - gparams = group_parameters(self.trunk, matcher) - max_layer_id = max(gparams.keys()) - max_layer_id = max_layer_id - unlocked_groups - for group_idx in range(max_layer_id + 1): - group = gparams[group_idx] - for param in group: - self.trunk.get_parameter(param).requires_grad = False - if freeze_bn_stats: - gmodules = group_modules(self.trunk, matcher, reverse=True) - gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} - freeze_batch_norm_2d(self.trunk, gmodules) - - def forward(self, x): - x = self.trunk(x) - x = self.head(x) - return x diff --git a/audioldm2/clap/open_clip/tokenizer.py b/audioldm2/clap/open_clip/tokenizer.py deleted file mode 100755 index ee4d28450ec5dd12a79daf38cf3088e9e73c2cd5..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/tokenizer.py +++ /dev/null @@ -1,197 +0,0 @@ -""" CLIP tokenizer - -Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. -""" -import gzip -import html -import os -from functools import lru_cache -from typing import Union, List - -import ftfy -import regex as re -import torch - - -@lru_cache() -def default_bpe(): - return os.path.join( - os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" - ) - - -@lru_cache() -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. - The reversible bpe codes work on unicode strings. - This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a signficant percentage of your normal, say, 32K bpe vocab. - To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - And avoids mapping to whitespace/control characters the bpe code barfs on. - """ - bs = ( - list(range(ord("!"), ord("~") + 1)) - + list(range(ord("¡"), ord("¬") + 1)) - + list(range(ord("®"), ord("ÿ") + 1)) - ) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8 + n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -def get_pairs(word): - """Return set of symbol pairs in a word. - Word is represented as tuple of symbols (symbols being variable-length strings). 
- """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text - - -class SimpleTokenizer(object): - def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): - self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") - merges = merges[1 : 49152 - 256 - 2 + 1] - merges = [tuple(merge.split()) for merge in merges] - vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v + "" for v in vocab] - for merge in merges: - vocab.append("".join(merge)) - if not special_tokens: - special_tokens = ["", ""] - else: - special_tokens = ["", ""] + special_tokens - vocab.extend(special_tokens) - self.encoder = dict(zip(vocab, range(len(vocab)))) - self.decoder = {v: k for k, v in self.encoder.items()} - self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {t: t for t in special_tokens} - special = "|".join(special_tokens) - self.pat = re.compile( - special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", - re.IGNORECASE, - ) - - self.vocab_size = len(self.encoder) - self.all_special_ids = [self.encoder[t] for t in special_tokens] - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token[:-1]) + (token[-1] + "",) - pairs = get_pairs(word) - - if not pairs: - return token + "" - - while True: - bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word) - 1 and word[i + 1] == second: - new_word.append(first + second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) - word = " ".join(word) - self.cache[token] = word - return word - - def encode(self, text): - bpe_tokens = [] - text = whitespace_clean(basic_clean(text)).lower() - for token in re.findall(self.pat, text): - token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) - bpe_tokens.extend( - self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") - ) - return bpe_tokens - - def decode(self, tokens): - text = "".join([self.decoder[token] for token in tokens]) - text = ( - bytearray([self.byte_decoder[c] for c in text]) - .decode("utf-8", errors="replace") - .replace("", " ") - ) - return text - - -_tokenizer = SimpleTokenizer() - - -def tokenize( - texts: Union[str, List[str]], context_length: int = 77 -) -> torch.LongTensor: - """ - Returns the tokenized representation of given input string(s) - - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - context_length : int - The context length to use; all CLIP models use 77 as the context length - - Returns - ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] - """ - if isinstance(texts, str): - texts = [texts] - - sot_token = _tokenizer.encoder[""] - 
eot_token = _tokenizer.encoder[""] - all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - - for i, tokens in enumerate(all_tokens): - if len(tokens) > context_length: - tokens = tokens[:context_length] # Truncate - result[i, : len(tokens)] = torch.tensor(tokens) - - return result diff --git a/audioldm2/clap/open_clip/transform.py b/audioldm2/clap/open_clip/transform.py deleted file mode 100755 index 77aaa722c4a5544ac50de6df35d3e922f63b111d..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/transform.py +++ /dev/null @@ -1,45 +0,0 @@ -from torchvision.transforms import ( - Normalize, - Compose, - RandomResizedCrop, - InterpolationMode, - ToTensor, - Resize, - CenterCrop, -) - - -def _convert_to_rgb(image): - return image.convert("RGB") - - -def image_transform( - image_size: int, - is_train: bool, - mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711), -): - normalize = Normalize(mean=mean, std=std) - if is_train: - return Compose( - [ - RandomResizedCrop( - image_size, - scale=(0.9, 1.0), - interpolation=InterpolationMode.BICUBIC, - ), - _convert_to_rgb, - ToTensor(), - normalize, - ] - ) - else: - return Compose( - [ - Resize(image_size, interpolation=InterpolationMode.BICUBIC), - CenterCrop(image_size), - _convert_to_rgb, - ToTensor(), - normalize, - ] - ) diff --git a/audioldm2/clap/open_clip/utils.py b/audioldm2/clap/open_clip/utils.py deleted file mode 100755 index 77875569ff4aff81bf9545ce6ec58e0326d49d0c..0000000000000000000000000000000000000000 --- a/audioldm2/clap/open_clip/utils.py +++ /dev/null @@ -1,356 +0,0 @@ -import numpy as np -import torch -from torch import nn as nn -from torchvision.ops.misc import FrozenBatchNorm2d -import logging -import h5py -from tqdm import tqdm -import random -import json -import os -import pathlib - -# TODO: (yusong) this not a good place to store those information and does not scale. Need to be fixed later. -dataset_split = { - "audiocaps": ["train", "valid", "test"], - "audioset": ["balanced_train", "unbalanced_train", "eval"], - "BBCSoundEffects": ["train", "test"], - "Clotho": ["train", "test", "valid"], - "free_to_use_sounds": ["train", "test"], - "paramount_motion": ["train", "test"], - "sonniss_game_effects": ["train", "test"], - "wesoundeffects": ["train", "test"], - "MACS": ["train", "test"], - "freesound": ["train", "test"], - "FSD50K": ["train", "test", "valid"], - "fsd50k_class_label": ["train", "test", "valid"], - "esc50": ["train", "test"], - "audiostock": ["train", "test"], - "freesound_no_overlap_noesc50": ["train", "test"], - "epidemic_sound_effects": ["train", "test"], - "VGGSound": ["train", "test"], - "urbansound8k_class_label": ["train", "test"], - "audioset_t5": ["balanced_train", "unbalanced_train", "eval"], - "epidemic_sound_effects_t5": ["train", "test"], - "WavText5K": ["train", "test"], - "esc50_no_overlap": ["train", "test"], - "usd8k_no_overlap": ["train", "test"], - "fsd50k_200_class_label": ["train", "test", "valid"], -} - - -def freeze_batch_norm_2d(module, module_match={}, name=""): - """ - Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is - itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and - returned. Otherwise, the module is walked recursively and submodules are converted in place. 
- - Args: - module (torch.nn.Module): Any PyTorch module. - module_match (dict): Dictionary of full module names to freeze (all if empty) - name (str): Full module name (prefix) - - Returns: - torch.nn.Module: Resulting module - - Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 - """ - res = module - is_match = True - if module_match: - is_match = name in module_match - if is_match and isinstance( - module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm) - ): - res = FrozenBatchNorm2d(module.num_features) - res.num_features = module.num_features - res.affine = module.affine - if module.affine: - res.weight.data = module.weight.data.clone().detach() - res.bias.data = module.bias.data.clone().detach() - res.running_mean.data = module.running_mean.data - res.running_var.data = module.running_var.data - res.eps = module.eps - else: - for child_name, child in module.named_children(): - full_child_name = ".".join([name, child_name]) if name else child_name - new_child = freeze_batch_norm_2d(child, module_match, full_child_name) - if new_child is not child: - res.add_module(child_name, new_child) - return res - - -def exist(dataset_name, dataset_type): - """ - Check if dataset exists - """ - if dataset_type in dataset_split[dataset_name]: - return True - else: - return False - - -def get_tar_path_from_dataset_name( - dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None -): - """ - Get tar path from dataset name and type - """ - output = [] - for n in dataset_names: - if full_dataset is not None and n in full_dataset: - current_dataset_types = dataset_split[n] - else: - current_dataset_types = dataset_types - for s in current_dataset_types: - tmp = [] - if islocal: - sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json" - if not os.path.exists(sizefilepath_): - sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" - else: - sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" - if not os.path.exists(sizefilepath_): - continue - sizes = json.load(open(sizefilepath_, "r")) - for k in sizes.keys(): - if islocal: - tmp.append(f"{dataset_path}/{n}/{s}/{k}") - else: - tmp.append( - f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -" - ) - if proportion != 1: - tmp = random.sample(tmp, int(proportion * len(tmp))) - output.append(tmp) - return sum(output, []) - - -def get_tar_path_from_txts(txt_path, islocal, proportion=1): - """ - Get tar path from txt path - """ - if isinstance(txt_path, (list, tuple)): - return sum( - [ - get_tar_path_from_txts( - txt_path[i], islocal=islocal, proportion=proportion - ) - for i in range(len(txt_path)) - ], - [], - ) - if isinstance(txt_path, str): - with open(txt_path) as f: - lines = f.readlines() - if islocal: - lines = [ - lines[i] - .split("\n")[0] - .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/") - for i in range(len(lines)) - ] - else: - lines = [ - lines[i].split("\n")[0].replace(".tar", ".tar -") - for i in range(len(lines)) - ] - if proportion != 1: - print("Sampling tars with proportion of {}".format(proportion)) - lines = random.sample(lines, int(proportion * len(lines))) - return lines - - -def get_mix_lambda(mixup_alpha, batch_size): - mixup_lambdas = [ - np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size) - ] - return np.array(mixup_lambdas).astype(np.float32) - - -def do_mixup(x, mixup_lambda): - """ - Args: - x: (batch_size , ...) 
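[Editor's note, not part of the diff] A short usage sketch for the freeze_batch_norm_2d helper above, assuming that helper (from the deleted utils.py) is importable: BatchNorm2d children are swapped in place for FrozenBatchNorm2d copies so their statistics and affine parameters stop updating.

# Usage sketch for freeze_batch_norm_2d above (assumes the helper is importable).
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
net = freeze_batch_norm_2d(net)
print(type(net[1]).__name__)   # FrozenBatchNorm2d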
- mixup_lambda: (batch_size,) - Returns: - out: (batch_size, ...) - """ - out = ( - x.transpose(0, -1) * mixup_lambda - + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda) - ).transpose(0, -1) - return out - - -def interpolate(x, ratio): - """Interpolate data in time domain. This is used to compensate the - resolution reduction in downsampling of a CNN. - - Args: - x: (batch_size, time_steps, classes_num) - ratio: int, ratio to interpolate - Returns: - upsampled: (batch_size, time_steps * ratio, classes_num) - """ - (batch_size, time_steps, classes_num) = x.shape - upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) - upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) - return upsampled - - -def pad_framewise_output(framewise_output, frames_num): - """Pad framewise_output to the same length as input frames. The pad value - is the same as the value of the last frame. - Args: - framewise_output: (batch_size, frames_num, classes_num) - frames_num: int, number of frames to pad - Outputs: - output: (batch_size, frames_num, classes_num) - """ - pad = framewise_output[:, -1:, :].repeat( - 1, frames_num - framewise_output.shape[1], 1 - ) - """tensor for padding""" - - output = torch.cat((framewise_output, pad), dim=1) - """(batch_size, frames_num, classes_num)""" - - -def process_ipc(index_path, classes_num, filename): - # load data - logging.info("Load Data...............") - ipc = [[] for _ in range(classes_num)] - with h5py.File(index_path, "r") as f: - for i in tqdm(range(len(f["target"]))): - t_class = np.where(f["target"][i])[0] - for t in t_class: - ipc[t].append(i) - print(ipc) - np.save(filename, ipc) - logging.info("Load Data Succeed...............") - - -def save_to_dict(s, o_={}): - sp = s.split(": ") - o_.update({sp[0]: float(sp[1])}) - return o_ - - -def get_data_from_log(txt_path): - """ - Output dictionary from out.txt log file - """ - with open(txt_path) as f: - lines = f.readlines() - val_data = {} - train_data = {} - train_losses = [] - train_losses_epoch = [] - for i in range(len(lines)): - if "| INFO |" in lines[i]: - if "Eval Epoch" in lines[i]: - if "val_loss" in lines[i]: - # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", "")) - line = lines[i].split("Eval Epoch: ")[-1] - num_epoch = int(line.split(" ")[0].split(" ")[0]) - d = { - line.split(" ")[0] - .split(" ")[1] - .replace(":", ""): float(line.split(" ")[0].split(" ")[-1]) - } - for i in range(1, len(line.split(" "))): - d = save_to_dict(line.split(" ")[i], d) - val_data[num_epoch] = d - elif "Train Epoch" in lines[i]: - num_epoch = int(lines[i].split("Train Epoch: ")[1][0]) - loss = float(lines[i].split("Loss: ")[-1].split(" (")[0]) - train_losses.append(loss) - train_losses_epoch.append(num_epoch) - for i in range(len(train_losses)): - train_data[i] = { - "num_epoch": train_losses_epoch[i], - "train_loss": train_losses[i], - } - return train_data, val_data - - -def save_p(obj, filename): - import pickle - - try: - from deepdiff import DeepDiff - except: - os.system("pip install deepdiff") - from deepdiff import DeepDiff - with open(filename, "wb") as file: - pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol - with open(filename, "rb") as file: - z = pickle.load(file) - assert ( - DeepDiff(obj, z, ignore_string_case=True) == {} - ), "there is something wrong with the saving process" - return - - -def load_p(filename): - import pickle - - with open(filename, "rb") as file: - z = pickle.load(file) - return z - - -def save_json(data, 
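[Editor's note, not part of the diff] get_mix_lambda and do_mixup above blend each example with the batch-reversed example using a per-example Beta(alpha, alpha) coefficient. A self-contained sketch of the same expression on random data:

# Self-contained sketch of the mixup expression used by do_mixup above (random data).
import numpy as np
import torch

batch = torch.randn(4, 1, 100, 64)                              # (batch, 1, time, mel)
lam = torch.from_numpy(np.random.beta(0.5, 0.5, 4).astype(np.float32))

mixed = (batch.transpose(0, -1) * lam
         + torch.flip(batch, dims=[0]).transpose(0, -1) * (1 - lam)).transpose(0, -1)
print(mixed.shape)   # same shape as the input batch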
name="data.json"): - import json - - with open(name, "w") as fp: - json.dump(data, fp) - return - - -def load_json(name): - import json - - with open(name, "r") as fp: - data = json.load(fp) - return data - - -def load_class_label(path): - # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing - # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array - out = None - if path is not None: - if pathlib.Path(path).suffix in [".pkl", ".pickle"]: - out = load_p(path) - elif pathlib.Path(path).suffix in [".json", ".txt"]: - out = load_json(path) - elif pathlib.Path(path).suffix in [".npy", ".npz"]: - out = np.load(path) - elif pathlib.Path(path).suffix in [".csv"]: - import pandas as pd - - out = pd.read_csv(path) - return out - # if out is None: - # return None - # else: - # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False) - # val = Array('i', out.values(), lock=False) - # return (key, val) - - -from torch import optim - - -def get_optimizer(params, lr, betas, eps, momentum, optimizer_name): - if optimizer_name.lower() == "adamw": - optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps) - elif optimizer_name.lower() == "sgd": - optimizer = optim.SGD(params, lr=lr, momentum=momentum) - elif optimizer_name.lower() == "adam": - optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps) - else: - raise ValueError("optimizer name is not correct") - return optimizer diff --git a/audioldm2/clap/training/__init__.py b/audioldm2/clap/training/__init__.py deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/audioldm2/clap/training/audioset_textmap.npy b/audioldm2/clap/training/audioset_textmap.npy deleted file mode 100755 index 3da4c92d3819aaec11e5f576464a9973a6df811b..0000000000000000000000000000000000000000 --- a/audioldm2/clap/training/audioset_textmap.npy +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:bada103070d92f9eadd33e1b4f45ec8583f59080ef218c966b43294bd4c86d5b -size 84448 diff --git a/audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz b/audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz deleted file mode 100644 index 36a15856e00a06a9fbed8cdd34d2393fea4a3113..0000000000000000000000000000000000000000 --- a/audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a -size 1356917 diff --git a/audioldm2/clap/training/data.py b/audioldm2/clap/training/data.py deleted file mode 100755 index ae01406c63a9b1c678151f67dacd7ea192cb84f2..0000000000000000000000000000000000000000 --- a/audioldm2/clap/training/data.py +++ /dev/null @@ -1,865 +0,0 @@ -import json -import logging -import os -import random -import h5py -from dataclasses import dataclass -import numpy as np -import pandas as pd -import torch -import torchvision.datasets as datasets -from PIL import Image -from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler -from torch.utils.data.distributed import DistributedSampler -import soundfile as sf -import io -from pathlib import Path -# import wget - -from audioldm2.clap.open_clip.utils import get_tar_path_from_dataset_name -from audioldm2.clap.open_clip.utils import load_class_label - -try: - import horovod.torch as hvd -except ImportError: - hvd = None - -try: - import torchaudio -except 
ImportError: - torchaudio = None - -from audioldm2.clap.open_clip import tokenize - - -def tokenizer(text): - return tokenize(text).squeeze(0) - - -from transformers import RobertaTokenizer - -tokenize = RobertaTokenizer.from_pretrained("roberta-base") - - -def tokenizer(text): - result = tokenize( - text, - padding="max_length", - truncation=True, - max_length=77, - return_tensors="pt", - ) - return {k: v.squeeze(0) for k, v in result.items()} - - -# initizlied the audioset map -_AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy") -_AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True) - - -def int16_to_float32(x): - return (x / 32767.0).astype(np.float32) - - -def float32_to_int16(x): - x = np.clip(x, a_min=-1.0, a_max=1.0) - return (x * 32767.0).astype(np.int16) - - -# For Toy Dataset -class ToyDataset(Dataset): - def __init__(self, index_path, ipc, config, eval_mode=False): - """Toy Dataset for testing the audioset input with text labels - Parameters - ---------- - index_path: str - the link to the h5 file of each audio - idc: str - the link to the npy file, the number of samples in each class - config: dict - the audio cfg file - eval_model (bool): to indicate if the dataset is a testing dataset - """ - self.audio_cfg = config["audio_cfg"] - self.text_cfg = config["text_cfg"] - self.fp = h5py.File(index_path, "r") - self.ipc = np.load(ipc, allow_pickle=True) - self.total_size = len(self.fp["audio_name"]) - self.classes_num = self.audio_cfg["class_num"] - self.eval_mode = eval_mode - - if not eval_mode: - self.generate_queue() - else: - self.queue = [] - for i in range(self.total_size): - target = self.fp["target"][i] - if np.sum(target) > 0: - self.queue.append(i) - self.total_size = len(self.queue) - logging.info("total dataset size: %d" % (self.total_size)) - logging.info("class num: %d" % (self.classes_num)) - - def time_shifting(self, x): - frame_num = len(x) - shift_len = random.randint(0, frame_num - 1) - new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0) - return new_sample - - def generate_queue(self): - self.queue = [] - while len(self.queue) < self.total_size: - class_set = [*range(self.classes_num)] - random.shuffle(class_set) - self.queue += [ - self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set - ] - self.queue = self.queue[: self.total_size] - - logging.info("queue regenerated:%s" % (self.queue[-5:])) - - def crop_wav(self, x): - crop_size = self.audio_cfg["crop_size"] - crop_pos = random.randint(0, len(x) - crop_size - 1) - return x[crop_pos : crop_pos + crop_size] - - def prompt_text(self, target): - events = _AUDIOSET_MAP[np.where(target > 0)] - event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1] - text = tokenize(event_text)[0] - return text - - def __getitem__(self, index): - """Load waveform, text, and target of an audio clip - - Parameters - ---------- - index: int - the index number - Return - ------ - output: dict { - "hdf5_path": str, - "index_in_hdf5": int, - "audio_name": str, - "waveform": list (audio_length,), - "target": list (class_num, ), - "text": torch.tensor (context_length,) - } - the output dictionary - """ - s_index = self.queue[index] - - audio_name = self.fp["audio_name"][s_index].decode() - # Hardcode here CHANGE - hdf5_path = ( - self.fp["hdf5_path"][s_index] - .decode() - .replace( - "../workspace", - "/home/la/kechen/Research/ke_zsasp/workspace", - ) - ) - r_idx = self.fp["index_in_hdf5"][s_index] - target = 
self.fp["target"][s_index].astype(np.float32) - text = self.prompt_text(target) - with h5py.File(hdf5_path, "r") as f: - waveform = int16_to_float32(f["waveform"][r_idx])[ - : self.audio_cfg["clip_samples"] - ] - assert ( - len(waveform) == self.audio_cfg["clip_samples"] - ), "The sample length is not match" - # Time shift - # if (self.config.enable_time_shift) and (not self.eval_mode): - # waveform = self.time_shifting(waveform) - # # Label Enhance - # if (self.config.crop_size is not None) and (not self.eval_mode): - # waveform = self.crop_wav(waveform) - # # the label enhance rate is fixed 0.5 - # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5: - # kidx = np.where(target)[0] - # for k in kidx: - # for add_key in self.class_map[k][1]: - # target[add_key] = 1.0 - # if len(self.class_map[k][2]) > 0: - # add_key = random.choice(self.class_map[k][2]) - # target[add_key] = 1.0 - - # missing the text input - mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :] - mel_spec = ( - torch.cat( - [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0 - ) - .cpu() - .numpy() - ) - longer = random.choice([True, False]) - if longer == False: - mel_spec[1:, :, :] = 0.0 - data_dict = { - "hdf5_path": hdf5_path, - "index_in_hdf5": r_idx, - "audio_name": audio_name, - "waveform": waveform, - "class_label": target, - "text": text, - "longer": longer, - "mel_fusion": mel_spec, - } - return data_dict - - def __len__(self): - return self.total_size - - -class CsvDataset(Dataset): - def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"): - logging.debug(f"Loading csv data from {input_filename}.") - df = pd.read_csv(input_filename, sep=sep) - - self.images = df[img_key].tolist() - self.captions = df[caption_key].tolist() - self.transforms = transforms - logging.debug("Done loading data.") - - def __len__(self): - return len(self.captions) - - def __getitem__(self, idx): - images = self.transforms(Image.open(str(self.images[idx]))) - texts = tokenize([str(self.captions[idx])])[0] - return images, texts - - -@dataclass -class DataInfo: - dataloader: DataLoader - sampler: DistributedSampler - - -def preprocess_txt(text): - return tokenize([str(text)])[0] - - -# def get_dataset_size(shards, sizefilepath_=None, is_local=True): -# if isinstance(shards, list): -# size_list = [] -# for s in shards: -# size_list.append( -# get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0] -# ) -# else: -# if not is_local: -# for n in dataset_split.keys(): -# if n in shards.split("/"): -# break -# for s in dataset_split[n]: -# if s in shards.split("/"): -# break -# sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" -# shards_list = list(braceexpand.braceexpand(shards)) -# dir_path = os.path.dirname(shards) -# if sizefilepath_ is not None: -# sizes = json.load(open(sizefilepath_, "r")) -# total_size = sum( -# [ -# int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))]) -# for shard in shards_list -# ] -# ) -# else: -# sizes_filename = os.path.join(dir_path, "sizes.json") -# len_filename = os.path.join(dir_path, "__len__") -# if os.path.exists(sizes_filename): -# sizes = json.load(open(sizes_filename, "r")) -# total_size = sum( -# [int(sizes[os.path.basename(shard)]) for shard in shards_list] -# ) -# elif os.path.exists(len_filename): -# # FIXME this used to be eval(open(...)) but that seemed rather unsafe -# total_size = ast.literal_eval(open(len_filename, "r").read()) -# else: -# raise Exception( -# "Cannot 
find sizes file for dataset. Please specify the path to the file." -# ) -# # total_size = None # num samples undefined -# # some common dataset sizes (at time of authors last download) -# # cc3m-train: 2905954 -# # cc12m: 10968539 -# # LAION-400m: 407332084 -# num_shards = len(shards_list) -# if isinstance(shards, list): -# return sum(size_list), len(shards) -# else: -# return total_size, num_shards - - -def get_imagenet(args, preprocess_fns, split): - assert split in ["train", "val", "v2"] - is_train = split == "train" - preprocess_train, preprocess_val = preprocess_fns - - if split == "v2": - from imagenetv2_pytorch import ImageNetV2Dataset - - dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val) - else: - if is_train: - data_path = args.imagenet_train - preprocess_fn = preprocess_train - else: - data_path = args.imagenet_val - preprocess_fn = preprocess_val - assert data_path - - dataset = datasets.ImageFolder(data_path, transform=preprocess_fn) - - if is_train: - idxs = np.zeros(len(dataset.targets)) - target_array = np.array(dataset.targets) - k = 50 - for c in range(1000): - m = target_array == c - n = len(idxs[m]) - arr = np.zeros(n) - arr[:k] = 1 - np.random.shuffle(arr) - idxs[m] = arr - - idxs = idxs.astype("int") - sampler = SubsetRandomSampler(np.where(idxs)[0]) - else: - sampler = None - - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=args.batch_size, - num_workers=args.workers, - sampler=sampler, - ) - - return DataInfo(dataloader, sampler) - - -def count_samples(dataloader): - os.environ["WDS_EPOCH"] = "0" - n_elements, n_batches = 0, 0 - for images, texts in dataloader: - n_batches += 1 - n_elements += len(images) - assert len(images) == len(texts) - return n_elements, n_batches - - -def filter_no_caption(sample): - return "txt" in sample - - -def log_and_continue(exn): - """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" - logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.") - return True - - -_SHARD_SHUFFLE_SIZE = 2000 -_SHARD_SHUFFLE_INITIAL = 500 -_SAMPLE_SHUFFLE_SIZE = 5000 -_SAMPLE_SHUFFLE_INITIAL = 1000 - - -# def sample_prop(sizefile, inputs, proportion, is_local=True): -# """ -# Sample a proportion of the data. 
-# """ -# file_path_dict = { -# os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0] -# for i in range(len(inputs)) -# } -# sampled_filepath_dict = {} -# sampled_size_dict = {} -# if not is_local: -# if os.path.exists("sizes.json"): -# os.remove("sizes.json") -# wget.download(sizefile, "sizes.json") -# sizefile = "sizes.json" -# with open(sizefile, "r", encoding="UTF-8") as f: -# load_dict = json.load(f) -# L = int(len(file_path_dict) * proportion) -# subkeys = random.sample(file_path_dict.keys(), L) -# for k in subkeys: -# sampled_size_dict[k] = load_dict[k] -# sampled_filepath_dict[k] = file_path_dict[k] -# return ( -# sum(sampled_size_dict.values()), -# L, -# [os.path.join(v, k) for k, v in sampled_filepath_dict.items()], -# sampled_size_dict, -# ) - - -def get_mel(audio_data, audio_cfg): - # mel shape: (n_mels, T) - mel = torchaudio.transforms.MelSpectrogram( - sample_rate=audio_cfg["sample_rate"], - n_fft=audio_cfg["window_size"], - win_length=audio_cfg["window_size"], - hop_length=audio_cfg["hop_size"], - center=True, - pad_mode="reflect", - power=2.0, - norm=None, - onesided=True, - n_mels=64, - f_min=audio_cfg["fmin"], - f_max=audio_cfg["fmax"], - ).to(audio_data.device) - mel = mel(audio_data) - # we use log mel spectrogram as input - mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel) - return mel.T # (T, n_mels) - - -def get_audio_features( - audio_data, mel, max_len, data_truncating, data_filling, audio_cfg -): - """ - Calculate and add audio features to sample. - Sample: a dict containing all the data of current sample. - audio_data: a tensor of shape (T) containing audio data. - max_len: the maximum length of audio data. - data_truncating: the method of truncating data. - data_filling: the method of filling data. - audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg']. - """ - sample = {} - - # assert audio_data.size(-1) <= max_len, str(audio_data.size()) - - # split to three parts - chunk_frames = ( - max_len // audio_cfg["hop_size"] + 1 - ) # the +1 related to how the spectrogram is computed - mel = mel[:chunk_frames] - - audio_data = audio_data[..., :max_len] - sample["mel_fusion"] = mel - longer = torch.tensor([True]) - - sample["longer"] = longer - sample["waveform"] = audio_data - - return sample - - -def preprocess( - sample, - audio_ext, - text_ext, - max_len, - audio_cfg, - class_index_dict=None, - data_filling="pad", - data_truncating="rand_trunc", - text_augment_selection=None, -): - """ - Preprocess a single sample for wdsdataloader. 
- """ - audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) - audio_data = int16_to_float32(float32_to_int16(audio_data)) - audio_data = torch.tensor(audio_data).float() - - # TODO: (yusong) to be include in the future - # # if torchaudio not installed, use soundfile to load audio - # if torchaudio is None: - # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) - # audio_data = torch.tensor(audio_data).float() - # else: - # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py - # with tempfile.TemporaryDirectory() as dirname: - # os.makedirs(dirname, exist_ok=True) - # fname = os.path.join(dirname, f"file.flac") - # with open(fname, "wb") as stream: - # stream.write(sample[audio_ext]) - # audio_data, orig_sr = torchaudio.load(fname) - # audio_data = audio_data[0, :].float() - - sample = get_audio_features( - sample, audio_data, max_len, data_truncating, data_filling, audio_cfg - ) - del sample[audio_ext] - - try: - json_dict_raw = json.loads(sample[text_ext].decode("utf-8")) - except: - print("sample[__url__]:", sample["__url__"]) - - # For selecting augmented text from dataset - if text_augment_selection is None or text_augment_selection == "none": - texts = json_dict_raw["text"] - elif text_augment_selection == "all": - if "text_augment_all" in json_dict_raw.keys(): - texts = json_dict_raw["text_augment_all"] - else: - texts = json_dict_raw["text"] - elif text_augment_selection == "augment_only": - if "text_augment_all" in json_dict_raw.keys(): - if json_dict_raw["text_augment_t5"] is None: - texts = json_dict_raw["text"] - else: - texts = json_dict_raw["text_augment_t5"] - else: - texts = json_dict_raw["text"] - else: - raise NotImplementedError( - f"text_augment_selection {text_augment_selection} not implemented" - ) - sample["full_text"] = texts - - if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1: - texts = random.choice(texts) - sample["raw_text"] = texts - sample["text"] = tokenizer(texts) # text shape: [num_token] - if class_index_dict is not None: - # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing - # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array - # key, val = class_index_dict - # key = key[:].split('\n') - # _dict = {k: v for k, v in zip(key, val)} - sample["class_label"] = np.zeros(len(class_index_dict.keys())) - for x in json_dict_raw["tag"]: - sample["class_label"][class_index_dict[x]] = 1 - sample["class_label"] = torch.tensor(sample["class_label"]).float() - del sample[text_ext] - sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext - sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext - sample["audio_orig_sr"] = orig_sr - return sample - - -def collate_fn(batch): - """ - Collate function for wdsdataloader. - batch: a list of dict, each dict is a sample - """ - # concatenate values in each dictionary. if it is a tensor, concatenate. if it is a list, extend. 
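    # Illustrative sketch of the collation behaviour (hedged, inferred from the code below):
    # for samples of the form
    #   {"waveform": Tensor(T,), "text": {"input_ids": Tensor(77,), "attention_mask": Tensor(77,)}, "audio_name": str}
    # the returned batch_dict would look like
    #   {"waveform": Tensor(B, T),
    #    "text": {"input_ids": Tensor(B, 77), "attention_mask": Tensor(B, 77)},
    #    "audio_name": [str, ...]}
    # i.e. tensors are stacked, tokenizer dicts are vstacked key by key, and anything else becomes a plain list.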
- batch_dict = {} - for k in batch[0].keys(): - if isinstance(batch[0][k], dict): # dealwith bert tokenizer output - batch_dict[k] = {} - for kk in batch[0][k].keys(): - tmp = [] - for i in range(len(batch)): - tmp.append(batch[i][k][kk]) - batch_dict[k][kk] = torch.vstack(tmp) - elif isinstance(batch[0][k], torch.Tensor): - batch_dict[k] = torch.stack([sample[k] for sample in batch]) - elif isinstance(batch[0][k], np.ndarray): - batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch])) - else: - batch_dict[k] = [sample[k] for sample in batch] - return batch_dict - - -# def get_wds_dataset( -# args, -# model_cfg, -# is_train, -# audio_ext="flac", -# text_ext="json", -# max_len=480000, -# proportion=1.0, -# sizefilepath_=None, -# is_local=None, -# ): -# """ -# Get a dataset for wdsdataloader. -# """ -# if is_local is None and (not args.remotedata is None): -# is_local = not args.remotedata - -# input_shards = args.train_data if is_train else args.val_data -# assert input_shards is not None - -# if not sizefilepath_ is None: -# sizefilepath = sizefilepath_ -# else: -# sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json") - -# if proportion != 1.0: -# num_samples, num_shards, input_shards, _ = sample_prop( -# sizefilepath, input_shards, proportion, is_local=is_local -# ) -# else: -# num_samples, num_shards = get_dataset_size( -# input_shards, sizefilepath_=sizefilepath_, is_local=is_local -# ) - -# if not num_samples: -# if is_train: -# num_samples = args.train_num_samples -# if not num_samples: -# raise RuntimeError( -# "Currently, number of dataset samples must be specified for training dataset. " -# "Please specify via `--train-num-samples` if no dataset length info present." -# ) -# else: -# num_samples = ( -# args.val_num_samples or 0 -# ) # eval will just exhaust the iterator if not specified - -# pipeline = [wds.SimpleShardList(input_shards)] -# # at this point we have an iterator over all the shards -# # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node -# if is_train or args.parallel_eval: -# pipeline.extend( -# [ -# wds.detshuffle( -# bufsize=_SHARD_SHUFFLE_SIZE, -# initial=_SHARD_SHUFFLE_INITIAL, -# seed=args.seed, -# ), -# wds.split_by_node, -# wds.split_by_worker, -# # at this point, we have an iterator over the shards assigned to each worker at each node -# wds.tarfile_to_samples(handler=log_and_continue), -# wds.shuffle( -# bufsize=_SAMPLE_SHUFFLE_SIZE, -# initial=_SAMPLE_SHUFFLE_INITIAL, -# rng=random.Random(args.seed), -# ), -# # wds.repeatedly, # FIXME determine if this is beneficial -# ] -# ) -# else: -# pipeline.extend( -# [ -# wds.split_by_worker, -# # at this point, we have an iterator over the shards assigned to each worker -# wds.tarfile_to_samples(handler=log_and_continue), -# ] -# ) -# pipeline.append( -# wds.map( -# partial( -# preprocess, -# audio_ext=audio_ext, -# text_ext=text_ext, -# max_len=max_len, -# audio_cfg=model_cfg["audio_cfg"], -# class_index_dict=copy.deepcopy(args.class_index_dict), -# data_filling=args.data_filling, -# data_truncating=args.data_truncating, -# text_augment_selection=args.text_augment_selection, -# ) -# ), -# ) - -# pipeline.append( -# wds.batched( -# args.batch_size, -# partial=not (is_train or args.parallel_eval), -# collation_fn=collate_fn, -# ) -# ) - -# dataset = wds.DataPipeline(*pipeline) -# if is_train or args.parallel_eval: -# # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples. 
-# # (yusong): See comments below. -# # roll over and repeat a few samples to get same number of full batches on each node -# global_batch_size = args.batch_size * args.world_size -# num_batches = math.ceil(num_samples / global_batch_size) -# num_workers = max(1, args.workers) -# num_worker_batches = math.ceil( -# num_batches / num_workers -# ) # per dataloader worker -# num_batches = num_worker_batches * num_workers -# num_samples = num_batches * global_batch_size -# dataset = dataset.with_epoch( -# num_worker_batches -# ) # each worker is iterating over this -# else: -# # last batches are partial, eval is done on single (master) node -# num_batches = math.ceil(num_samples / args.batch_size) - -# kwargs = {} -# if args.horovod: # multi-node training on summit -# kwargs["multiprocessing_context"] = "forkserver" - -# dataloader = wds.WebLoader( -# dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs -# ) - -# # FIXME not clear which approach is better, with_epoch before vs after dataloader? -# # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 -# # if is_train: -# # # roll over and repeat a few samples to get same number of full batches on each node -# # global_batch_size = args.batch_size * args.world_size -# # num_batches = math.ceil(num_samples / global_batch_size) -# # num_workers = max(1, args.workers) -# # num_batches = math.ceil(num_batches / num_workers) * num_workers -# # num_samples = num_batches * global_batch_size -# # dataloader = dataloader.with_epoch(num_batches) -# # else: -# # # last batches are partial, eval is done on single (master) node -# # num_batches = math.ceil(num_samples / args.batch_size) - -# # add meta-data to dataloader instance for convenience -# dataloader.num_batches = num_batches -# dataloader.num_samples = num_samples - -# return DataInfo(dataloader, None) - - -def wds_batch_list2dict( - batch, - keys=[ - "__url__", - "__key__", - "waveform", - "text", - "raw_text", - "audio_name", - "text_name", - "audio_orig_sr", - ], -): - """ - Return a dictionary of the batch, with keys as the names of the fields. 
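    For example (illustrative), with the default keys a batch list
    [url, key, waveform, text, raw_text, audio_name, text_name, audio_orig_sr]
    is mapped to {"__url__": url, "__key__": key, ..., "audio_orig_sr": audio_orig_sr}.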
- """ - assert len(keys) == len( - batch - ), "batch must have same number of keys as keys argument" - return {keys[i]: batch[i] for i in range(len(batch))} - - -def get_csv_dataset(args, preprocess_fn, is_train): - input_filename = args.train_data if is_train else args.val_data - assert input_filename - dataset = CsvDataset( - input_filename, - preprocess_fn, - img_key=args.csv_img_key, - caption_key=args.csv_caption_key, - sep=args.csv_separator, - ) - num_samples = len(dataset) - sampler = DistributedSampler(dataset) if args.distributed and is_train else None - shuffle = is_train and sampler is None - - dataloader = DataLoader( - dataset, - batch_size=args.batch_size, - shuffle=shuffle, - num_workers=args.workers, - pin_memory=True, - sampler=sampler, - drop_last=is_train, - ) - dataloader.num_samples = num_samples - dataloader.num_batches = len(dataloader) - - return DataInfo(dataloader, sampler) - - -def get_toy_dataset(args, model_cfg, is_train): - index_path = args.train_data if is_train else args.val_data - ipc_path = args.train_ipc if is_train else args.val_ipc - assert index_path and ipc_path - eval_mode = not is_train - dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode) - - num_samples = len(dataset) - sampler = ( - DistributedSampler(dataset, shuffle=False) - if args.distributed and is_train - else None - ) - - dataloader = DataLoader( - dataset, - batch_size=args.batch_size, - shuffle=False, - num_workers=args.workers, - sampler=sampler, - drop_last=is_train, - ) - dataloader.num_samples = num_samples - dataloader.num_batches = len(dataloader) - - return DataInfo(dataloader, sampler) - - -def get_dataset_fn(data_path, dataset_type): - if dataset_type == "webdataset": - return get_wds_dataset - elif dataset_type == "csv": - return get_csv_dataset - elif dataset_type == "auto": - ext = data_path.split(".")[-1] - if ext in ["csv", "tsv"]: - return get_csv_dataset - elif ext in ["tar"]: - return get_wds_dataset - else: - raise ValueError( - f"Tried to figure out dataset type, but failed for extention {ext}." 
- ) - elif dataset_type == "toy": - return get_toy_dataset - else: - raise ValueError(f"Unsupported dataset type: {dataset_type}") - - -def get_data(args, model_cfg): - data = {} - - args.class_index_dict = load_class_label(args.class_label_path) - - if args.datasetinfos is None: - args.datasetinfos = ["train", "unbalanced_train", "balanced_train"] - if args.dataset_type == "webdataset": - args.train_data = get_tar_path_from_dataset_name( - args.datasetnames, - args.datasetinfos, - islocal=not args.remotedata, - proportion=args.dataset_proportion, - dataset_path=args.datasetpath, - full_dataset=args.full_train_dataset, - ) - - if args.full_train_dataset is None: - args.full_train_dataset = [] - if args.exclude_eval_dataset is None: - args.exclude_eval_dataset = [] - excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset - - val_dataset_names = ( - [n for n in args.datasetnames if n not in excluded_eval_datasets] - if excluded_eval_datasets - else args.datasetnames - ) - args.val_dataset_names = val_dataset_names - args.val_data = get_tar_path_from_dataset_name( - val_dataset_names, - ["valid", "test", "eval"], - islocal=not args.remotedata, - proportion=1, - dataset_path=args.datasetpath, - full_dataset=None, - ) - - if args.train_data: - data["train"] = get_dataset_fn(args.train_data, args.dataset_type)( - args, model_cfg, is_train=True - ) - - if args.val_data: - data["val"] = get_dataset_fn(args.val_data, args.dataset_type)( - args, model_cfg, is_train=False - ) - - return data diff --git a/audioldm2/clap/training/params.py b/audioldm2/clap/training/params.py deleted file mode 100755 index 0cc1a0e2d982e900988cf5a4b24b2e59b093537b..0000000000000000000000000000000000000000 --- a/audioldm2/clap/training/params.py +++ /dev/null @@ -1,563 +0,0 @@ -import argparse - - -def get_default_params(model_name): - # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) - model_name = model_name.lower() - if "vit" in model_name: - return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} - else: - return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--train-data", - type=str, - default=None, - help="Path to h5 filewith training data", - ) - parser.add_argument( - "--val-data", - type=str, - default=None, - help="Path to h5 file with validation data", - ) - parser.add_argument( - "--freeze-text", - default=False, - action="store_true", - help="if you need to freeze the text encoder, make this True", - ) - parser.add_argument( - "--freeze-text-after", - type=int, - default=-1, - help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it", - ) - parser.add_argument( - "--train-ipc", - type=str, - default=None, - help="Path to npy file of the number of instance per class in training data", - ) - parser.add_argument( - "--val-ipc", - type=str, - default=None, - help="Path to npy file of the number of instance per class in validation data", - ) - parser.add_argument( - "--train-num-samples", - type=int, - default=None, - help="Number of samples in dataset. Required for webdataset if not available in info file.", - ) - parser.add_argument( - "--val-num-samples", - type=int, - default=None, - help="Number of samples in dataset. 
Useful for webdataset if not available in info file.", - ) - parser.add_argument( - "--dataset-type", - choices=["webdataset", "csv", "auto", "toy"], - default="auto", - help="Which type of dataset to process.", - ) - parser.add_argument( - "--csv-separator", - type=str, - default="\t", - help="For csv-like datasets, which separator to use.", - ) - parser.add_argument( - "--csv-img-key", - type=str, - default="filepath", - help="For csv-like datasets, the name of the key for the image paths.", - ) - parser.add_argument( - "--csv-caption-key", - type=str, - default="title", - help="For csv-like datasets, the name of the key for the captions.", - ) - parser.add_argument( - "--imagenet-val", - type=str, - default=None, - help="Path to imagenet val set for conducting zero shot evaluation.", - ) - parser.add_argument( - "--imagenet-v2", - type=str, - default=None, - help="Path to imagenet v2 for conducting zero shot evaluation.", - ) - parser.add_argument( - "--datasetnames", - nargs="+", - default=None, - help="If loading webdataset, spedify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects", - ) - parser.add_argument( - "--full-train-dataset", - nargs="+", - default=None, - help="Which dataset will be trained with all the subsets. (train+test)", - ) - parser.add_argument( - "--exclude-eval-dataset", - nargs="+", - default=None, - help="Which dataset will be excluded with evaluation", - ) - parser.add_argument( - "--datasetinfos", - nargs="+", - default=None, - help="If loading webdataset, spedify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval", - ) - parser.add_argument( - "--dataset-proportion", - type=float, - default=1.0, - help="How much proportion of dataset we want to train.", - ) - parser.add_argument( - "--remotedata", - default=False, - action="store_true", - help="if the dataset is remote, set this flag", - ) - parser.add_argument( - "--class-label-path", - type=str, - default=None, - help="The path of the class label pickle or csv.", - ) - parser.add_argument( - "--datasetpath", - type=str, - default="/mnt/audio_clip/webdataset_tar", - help="The path to the dataset", - ) - parser.add_argument( - "--logs", - type=str, - default="./logs/", - help="Where to store tensorboard logs. Use None to avoid storing logs.", - ) - parser.add_argument( - "--log-local", - action="store_true", - default=False, - help="log files on local master, otherwise global master only.", - ) - parser.add_argument( - "--name", - type=str, - default=None, - help="Optional identifier for the experiment when storing logs. Otherwise use current time.", - ) - parser.add_argument( - "--workers", type=int, default=1, help="Number of workers per GPU." - ) - parser.add_argument( - "--batch-size", type=int, default=64, help="Batch size per GPU." - ) - parser.add_argument( - "--epochs", type=int, default=32, help="Number of epochs to train for." 
- ) - parser.add_argument("--lr", type=float, default=None, help="Learning rate.") - parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.") - parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") - parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") - parser.add_argument("--momentum", type=float, default=None, help="SGD epsilon.") - parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") - - parser.add_argument( - "--split-opt", - action="store_true", - default=False, - help="Use this flag to skip the learning rate decay.", - ) - parser.add_argument( - "--lr-pretrained", type=float, default=None, help="Learning rate for text." - ) - parser.add_argument( - "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text." - ) - parser.add_argument( - "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text." - ) - parser.add_argument( - "--eps-pretrained", type=float, default=None, help="Adam epsilon for text." - ) - parser.add_argument( - "--wd-pretrained", type=float, default=0.2, help="Weight decay for text." - ) - parser.add_argument( - "--momentum-pretrained", type=float, default=0.9, help="Momentum for text." - ) - parser.add_argument( - "--lr-new", type=float, default=None, help="Learning rate for audio." - ) - parser.add_argument( - "--beta1-new", type=float, default=None, help="Adam beta 1 for audio." - ) - parser.add_argument( - "--beta2-new", type=float, default=None, help="Adam beta 2 for audio." - ) - parser.add_argument( - "--eps-new", type=float, default=None, help="Adam epsilon for audio." - ) - parser.add_argument( - "--wd-new", type=float, default=0.2, help="Weight decay for audio." - ) - parser.add_argument( - "--momentum-new", type=float, default=0.9, help="Momentum for audio." - ) - parser.add_argument( - "--warmup", type=int, default=10000, help="Number of steps to warmup for." - ) - parser.add_argument( - "--use-bn-sync", - default=False, - action="store_true", - help="Whether to use batch norm sync.", - ) - parser.add_argument( - "--skip-scheduler", - action="store_true", - default=False, - help="Use this flag to skip the learning rate decay.", - ) - parser.add_argument( - "--save-frequency", type=int, default=1, help="How often to save checkpoints." - ) - parser.add_argument( - "--save-top-performance", - type=int, - default=0, - help="Save the top x performance weights if the value >0", - ) - parser.add_argument( - "--save-most-recent", - action="store_true", - default=False, - help="Always save the most recent model trained to epoch_latest.pt.", - ) - parser.add_argument( - "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot." - ) - parser.add_argument( - "--val-frequency", - type=int, - default=1, - help="How often to run evaluation with val data.", - ) - parser.add_argument( - "--resume", - default=None, - type=str, - help="path to latest checkpoint (default: none)", - ) - parser.add_argument( - "--precision", - choices=["amp", "fp16", "fp32"], - default="amp", - help="Floating point precision.", - ) - parser.add_argument( - "--amodel", - type=str, - default="RN50", - help="Name of the audio backbone to use.", - ) - parser.add_argument( - "--tmodel", - type=str, - default="transformer", - help="Name of the text backbone to use. 
Can be [transformer, bert, roberta, bart]", - ) - parser.add_argument( - "--pretrained-audio", - default="", - type=str, - help="Use a pretrained audio model weights for the audio encoder of CLAP", - ) - parser.add_argument( - "--pretrained-text", - default="", - type=str, - help="Use a pretrained text model weights for the text encoder of CLAP", - ) - parser.add_argument( - "--pretrained", - default="", - type=str, - help="Use a pretrained CLIP model weights with the specified tag or file path.", - ) - parser.add_argument( - "--pretrained-image", - default=False, - action="store_true", - help="Load imagenet pretrained weights for image tower backbone if available.", - ) - parser.add_argument( - "--lock-image", - default=False, - action="store_true", - help="Lock full image tower by disabling gradients.", - ) - parser.add_argument( - "--lock-image-unlocked-groups", - type=int, - default=0, - help="Leave last n image tower layer groups unlocked.", - ) - parser.add_argument( - "--lock-image-freeze-bn-stats", - default=False, - action="store_true", - help="Freeze BatchNorm running stats in image tower for any locked layers.", - ) - parser.add_argument( - "--local-loss", - default=False, - action="store_true", - help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)", - ) - parser.add_argument( - "--gather-with-grad", - default=False, - action="store_true", - help="enable full distributed gradient for feature gather", - ) - parser.add_argument( - "--force-quick-gelu", - default=False, - action="store_true", - help="Force use of QuickGELU activation for non-OpenAI transformer models.", - ) - parser.add_argument( - "--torchscript", - default=False, - action="store_true", - help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'", - ) - parser.add_argument( - "--trace", - default=False, - action="store_true", - help="torch.jit.trace the model for inference / eval only", - ) - # arguments for distributed training - parser.add_argument( - "--dist-url", - default="env://", - type=str, - help="url used to set up distributed training", - ) - parser.add_argument( - "--dist-backend", default="nccl", type=str, help="distributed backend" - ) - parser.add_argument( - "--report-to", - default="", - type=str, - help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']", - ) - parser.add_argument( - "--wandb-notes", default="", type=str, help="Notes if logging with wandb" - ) - parser.add_argument( - "--C", type=float, default=3.16, help="inverse regularizer for logistic reg." 
- ) - parser.add_argument( - "--debug", - default=False, - action="store_true", - help="If true, more information is logged.", - ) - parser.add_argument( - "--copy-codebase", - default=False, - action="store_true", - help="If true, we copy the entire base on the log diretory, and execute from there.", - ) - parser.add_argument( - "--horovod", - default=False, - action="store_true", - help="Use horovod for distributed training.", - ) - parser.add_argument( - "--ddp-static-graph", - default=False, - action="store_true", - help="Enable static graph optimization for DDP in PyTorch >= 1.11.", - ) - parser.add_argument( - "--no-set-device-rank", - default=False, - action="store_true", - help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", - ) - parser.add_argument("--seed", type=int, default=4242, help="Default random seed.") - - parser.add_argument( - "--top-k-checkpoint-select-dataset", - type=str, - default="all", - help="The dataset of selecting top-k checkpoint.", - ) - - # @R10, @R@5, @R1, mAP@10 - parser.add_argument( - "--top-k-checkpoint-select-metric", - type=str, - default="_R@10", - help="The metric for selecting top-k checkpoint.", - ) - parser.add_argument( - "--openai-model-cache-dir", - type=str, - default="~/.cache/clip", - help="Directory to download OpenAI models.", - ) - parser.add_argument( - "--optimizer", - type=str, - default="adamw", - help="can be AdamW or SGD", - ) - parser.add_argument( - "--parallel-eval", - default=False, - action="store_true", - help="Eval in parallel (multi-GPU, multi-node).", - ) - - parser.add_argument( - "--no-eval", - default=False, - action="store_true", - help="Training without evaluation.", - ) - - parser.add_argument( - "--lp-mlp", - default=False, - action="store_true", - help="Linear Probe using MLP layer or not.", - ) - - parser.add_argument( - "--lp-freeze", - default=False, - action="store_true", - help="Linear Probe using Freeze CLAP or not", - ) - - parser.add_argument( - "--lp-act", - default="None", - type=str, - help="Options are ['relu','elu','prelu','softmax','sigmoid']", - ) - - parser.add_argument( - "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe." - ) - - parser.add_argument( - "--lp-metrics", - type=str, - default="map,mauc,acc", - help="Metrics of Linear Probe.", - ) - - parser.add_argument( - "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe" - ) - parser.add_argument( - "--kappa", - type=float, - default=0, - help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss", - ) - - parser.add_argument( - "--data-filling", - type=str, - default="pad", - help="type of data filling when the audio length is shorter than the max length." - "Can be one of the following: repeat, repeatpad, pad", - ) - parser.add_argument( - "--data-truncating", - type=str, - default="rand_trunc", - help="type of data truncation when the audio length is longer than the max length." 
- "Can be one of the following: rand_trunc, fusion", - ) - - parser.add_argument( - "--clap-mlploss", - default=False, - action="store_true", - help="Using MLP loss for CLAP model or not", - ) - - parser.add_argument( - "--wandb-id", - type=str, - default=None, - help="the id of wandb experiment to restore.", - ) - - parser.add_argument( - "--sleep", type=float, default=0, help="sleep n seconds before start training" - ) - - # variable length processing - parser.add_argument( - "--enable-fusion", - default=False, - action="store_true", - help="Enable feature funsion for variable-length data", - ) - - parser.add_argument( - "--fusion-type", - type=str, - default="None", - help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']", - ) - - parser.add_argument( - "--mixup", - default=False, - action="store_true", - help="Enable mixup in finetuning training.", - ) - parser.add_argument( - "--text-augment-selection", - type=str, - default=None, - help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']", - ) - - args = parser.parse_args() - - # If some params are not passed, we use the default values based on model name. - default_params = get_default_params(args.amodel) - for name, val in default_params.items(): - if getattr(args, name) is None: - setattr(args, name, val) - - return args diff --git a/audioldm2/hifigan/LICENSE b/audioldm2/hifigan/LICENSE deleted file mode 100644 index 5afae394d6b37da0e12ba6b290d2512687f421ac..0000000000000000000000000000000000000000 --- a/audioldm2/hifigan/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2020 Jungil Kong - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
\ No newline at end of file diff --git a/audioldm2/hifigan/__init__.py b/audioldm2/hifigan/__init__.py deleted file mode 100755 index 34e055557bf2ecb457376663b67390543c71fb1f..0000000000000000000000000000000000000000 --- a/audioldm2/hifigan/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from .models_v2 import Generator -from .models import Generator as Generator_old - - -class AttrDict(dict): - def __init__(self, *args, **kwargs): - super(AttrDict, self).__init__(*args, **kwargs) - self.__dict__ = self diff --git a/audioldm2/hifigan/models.py b/audioldm2/hifigan/models.py deleted file mode 100755 index c4382cc39de0463f9b7c0f33f037dbc233e7cb36..0000000000000000000000000000000000000000 --- a/audioldm2/hifigan/models.py +++ /dev/null @@ -1,174 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn import Conv1d, ConvTranspose1d -from torch.nn.utils import weight_norm, remove_weight_norm - -LRELU_SLOPE = 0.1 - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - - -class ResBlock(torch.nn.Module): - def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): - super(ResBlock, self).__init__() - self.h = h - self.convs1 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[2], - padding=get_padding(kernel_size, dilation[2]), - ) - ), - ] - ) - self.convs1.apply(init_weights) - - self.convs2 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - ] - ) - self.convs2.apply(init_weights) - - def forward(self, x): - for c1, c2 in zip(self.convs1, self.convs2): - xt = F.leaky_relu(x, LRELU_SLOPE) - xt = c1(xt) - xt = F.leaky_relu(xt, LRELU_SLOPE) - xt = c2(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs1: - remove_weight_norm(l) - for l in self.convs2: - remove_weight_norm(l) - - -class Generator(torch.nn.Module): - def __init__(self, h): - super(Generator, self).__init__() - self.h = h - self.num_kernels = len(h.resblock_kernel_sizes) - self.num_upsamples = len(h.upsample_rates) - self.conv_pre = weight_norm( - Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) - ) - resblock = ResBlock - - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): - self.ups.append( - weight_norm( - ConvTranspose1d( - h.upsample_initial_channel // (2**i), - h.upsample_initial_channel // (2 ** (i + 1)), - k, - u, - padding=(k - u) // 2, - ) - ) - ) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = h.upsample_initial_channel // (2 ** (i + 1)) - for j, (k, d) in enumerate( - zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) - ): - 
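                # one ResBlock per (kernel_size, dilation) configuration at this
                # upsampling stage; forward() averages their outputs over num_kernels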
self.resblocks.append(resblock(h, ch, k, d)) - - self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) - self.ups.apply(init_weights) - self.conv_post.apply(init_weights) - - def forward(self, x): - x = self.conv_pre(x) - for i in range(self.num_upsamples): - x = F.leaky_relu(x, LRELU_SLOPE) - x = self.ups[i](x) - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i * self.num_kernels + j](x) - else: - xs += self.resblocks[i * self.num_kernels + j](x) - x = xs / self.num_kernels - x = F.leaky_relu(x) - x = self.conv_post(x) - x = torch.tanh(x) - - return x - - def remove_weight_norm(self): - # print("Removing weight norm...") - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - remove_weight_norm(self.conv_pre) - remove_weight_norm(self.conv_post) diff --git a/audioldm2/hifigan/models_v2.py b/audioldm2/hifigan/models_v2.py deleted file mode 100755 index 27a2df6b54bdd3a5b259645442624800ac0e8afe..0000000000000000000000000000000000000000 --- a/audioldm2/hifigan/models_v2.py +++ /dev/null @@ -1,395 +0,0 @@ -import torch -import torch.nn.functional as F -import torch.nn as nn -from torch.nn import Conv1d, ConvTranspose1d -from torch.nn.utils import weight_norm, remove_weight_norm - -LRELU_SLOPE = 0.1 - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - - -class ResBlock1(torch.nn.Module): - def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): - super(ResBlock1, self).__init__() - self.h = h - self.convs1 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[2], - padding=get_padding(kernel_size, dilation[2]), - ) - ), - ] - ) - self.convs1.apply(init_weights) - - self.convs2 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - ] - ) - self.convs2.apply(init_weights) - - def forward(self, x): - for c1, c2 in zip(self.convs1, self.convs2): - xt = F.leaky_relu(x, LRELU_SLOPE) - xt = c1(xt) - xt = F.leaky_relu(xt, LRELU_SLOPE) - xt = c2(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs1: - remove_weight_norm(l) - for l in self.convs2: - remove_weight_norm(l) - - -class ResBlock2(torch.nn.Module): - def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): - super(ResBlock2, self).__init__() - self.h = h - self.convs = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - ] - ) - 
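        # ResBlock2 applies a single dilated conv per residual step (ResBlock1 above
        # uses a pair of convs); the Generator selects it when h.resblock != "1"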
self.convs.apply(init_weights) - - def forward(self, x): - for c in self.convs: - xt = F.leaky_relu(x, LRELU_SLOPE) - xt = c(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs: - remove_weight_norm(l) - - -class Generator(torch.nn.Module): - def __init__(self, h): - super(Generator, self).__init__() - self.h = h - self.num_kernels = len(h.resblock_kernel_sizes) - self.num_upsamples = len(h.upsample_rates) - self.conv_pre = weight_norm( - Conv1d(256, h.upsample_initial_channel, 7, 1, padding=3) - ) - resblock = ResBlock1 if h.resblock == "1" else ResBlock2 - - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): - self.ups.append( - weight_norm( - ConvTranspose1d( - h.upsample_initial_channel // (2**i), - h.upsample_initial_channel // (2 ** (i + 1)), - u * 2, - u, - padding=u // 2 + u % 2, - output_padding=u % 2, - ) - ) - ) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = h.upsample_initial_channel // (2 ** (i + 1)) - for j, (k, d) in enumerate( - zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) - ): - self.resblocks.append(resblock(h, ch, k, d)) - - self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) - self.ups.apply(init_weights) - self.conv_post.apply(init_weights) - - def forward(self, x): - # import ipdb; ipdb.set_trace() - x = self.conv_pre(x) - for i in range(self.num_upsamples): - x = F.leaky_relu(x, LRELU_SLOPE) - x = self.ups[i](x) - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i * self.num_kernels + j](x) - else: - xs += self.resblocks[i * self.num_kernels + j](x) - x = xs / self.num_kernels - x = F.leaky_relu(x) - x = self.conv_post(x) - x = torch.tanh(x) - - return x - - def remove_weight_norm(self): - # print('Removing weight norm...') - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - remove_weight_norm(self.conv_pre) - remove_weight_norm(self.conv_post) - - -################################################################################################## - -# import torch -# import torch.nn as nn -# import torch.nn.functional as F -# from torch.nn import Conv1d, ConvTranspose1d -# from torch.nn.utils import weight_norm, remove_weight_norm - -# LRELU_SLOPE = 0.1 - - -# def init_weights(m, mean=0.0, std=0.01): -# classname = m.__class__.__name__ -# if classname.find("Conv") != -1: -# m.weight.data.normal_(mean, std) - - -# def get_padding(kernel_size, dilation=1): -# return int((kernel_size * dilation - dilation) / 2) - - -# class ResBlock(torch.nn.Module): -# def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): -# super(ResBlock, self).__init__() -# self.h = h -# self.convs1 = nn.ModuleList( -# [ -# weight_norm( -# Conv1d( -# channels, -# channels, -# kernel_size, -# 1, -# dilation=dilation[0], -# padding=get_padding(kernel_size, dilation[0]), -# ) -# ), -# weight_norm( -# Conv1d( -# channels, -# channels, -# kernel_size, -# 1, -# dilation=dilation[1], -# padding=get_padding(kernel_size, dilation[1]), -# ) -# ), -# weight_norm( -# Conv1d( -# channels, -# channels, -# kernel_size, -# 1, -# dilation=dilation[2], -# padding=get_padding(kernel_size, dilation[2]), -# ) -# ), -# ] -# ) -# self.convs1.apply(init_weights) - -# self.convs2 = nn.ModuleList( -# [ -# weight_norm( -# Conv1d( -# channels, -# channels, -# kernel_size, -# 1, -# dilation=1, -# padding=get_padding(kernel_size, 1), -# ) -# ), -# weight_norm( -# Conv1d( -# channels, -# channels, -# 
kernel_size, -# 1, -# dilation=1, -# padding=get_padding(kernel_size, 1), -# ) -# ), -# weight_norm( -# Conv1d( -# channels, -# channels, -# kernel_size, -# 1, -# dilation=1, -# padding=get_padding(kernel_size, 1), -# ) -# ), -# ] -# ) -# self.convs2.apply(init_weights) - -# def forward(self, x): -# for c1, c2 in zip(self.convs1, self.convs2): -# xt = F.leaky_relu(x, LRELU_SLOPE) -# xt = c1(xt) -# xt = F.leaky_relu(xt, LRELU_SLOPE) -# xt = c2(xt) -# x = xt + x -# return x - -# def remove_weight_norm(self): -# for l in self.convs1: -# remove_weight_norm(l) -# for l in self.convs2: -# remove_weight_norm(l) - -# class Generator(torch.nn.Module): -# def __init__(self, h): -# super(Generator, self).__init__() -# self.h = h -# self.num_kernels = len(h.resblock_kernel_sizes) -# self.num_upsamples = len(h.upsample_rates) -# self.conv_pre = weight_norm( -# Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) -# ) -# resblock = ResBlock - -# self.ups = nn.ModuleList() -# for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): -# self.ups.append( -# weight_norm( -# ConvTranspose1d( -# h.upsample_initial_channel // (2**i), -# h.upsample_initial_channel // (2 ** (i + 1)), -# k, -# u, -# padding=(k - u) // 2, -# ) -# ) -# ) - -# self.resblocks = nn.ModuleList() -# for i in range(len(self.ups)): -# ch = h.upsample_initial_channel // (2 ** (i + 1)) -# for j, (k, d) in enumerate( -# zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) -# ): -# self.resblocks.append(resblock(h, ch, k, d)) - -# self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) -# self.ups.apply(init_weights) -# self.conv_post.apply(init_weights) - -# def forward(self, x): -# x = self.conv_pre(x) -# for i in range(self.num_upsamples): -# x = F.leaky_relu(x, LRELU_SLOPE) -# x = self.ups[i](x) -# xs = None -# for j in range(self.num_kernels): -# if xs is None: -# xs = self.resblocks[i * self.num_kernels + j](x) -# else: -# xs += self.resblocks[i * self.num_kernels + j](x) -# x = xs / self.num_kernels -# x = F.leaky_relu(x) -# x = self.conv_post(x) -# x = torch.tanh(x) - -# return x - -# def remove_weight_norm(self): -# print("Removing weight norm...") -# for l in self.ups: -# remove_weight_norm(l) -# for l in self.resblocks: -# l.remove_weight_norm() -# remove_weight_norm(self.conv_pre) -# remove_weight_norm(self.conv_post) diff --git a/audioldm2/latent_diffusion/__init__.py b/audioldm2/latent_diffusion/__init__.py deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/audioldm2/latent_diffusion/models/__init__.py b/audioldm2/latent_diffusion/models/__init__.py deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/audioldm2/latent_diffusion/models/ddim.py b/audioldm2/latent_diffusion/models/ddim.py deleted file mode 100755 index 0c07207af7959847552805f00831122304b4330e..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/models/ddim.py +++ /dev/null @@ -1,487 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch -import numpy as np -from tqdm import tqdm - -from audioldm2.latent_diffusion.modules.diffusionmodules.util import ( - make_ddim_sampling_parameters, - make_ddim_timesteps, - noise_like, - extract_into_tensor, -) - - -class DDIMSampler(object): - def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs): - super().__init__() - self.model = model - self.ddpm_num_timesteps = model.num_timesteps - self.schedule = 
schedule - self.device = device - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != self.device: - attr = attr.to(self.device) - setattr(self, name, attr) - - def make_schedule( - self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True - ): - self.ddim_timesteps = make_ddim_timesteps( - ddim_discr_method=ddim_discretize, - num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps, - verbose=verbose, - ) - alphas_cumprod = self.model.alphas_cumprod - assert ( - alphas_cumprod.shape[0] == self.ddpm_num_timesteps - ), "alphas have to be defined for each timestep" - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - - self.register_buffer("betas", to_torch(self.model.betas)) - self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) - self.register_buffer( - "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) - ) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer( - "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) - ) - self.register_buffer( - "sqrt_one_minus_alphas_cumprod", - to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), - ) - self.register_buffer( - "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) - ) - self.register_buffer( - "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) - ) - self.register_buffer( - "sqrt_recipm1_alphas_cumprod", - to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), - ) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( - alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta, - verbose=verbose, - ) - self.register_buffer("ddim_sigmas", ddim_sigmas) - self.register_buffer("ddim_alphas", ddim_alphas) - self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) - self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) - / (1 - self.alphas_cumprod) - * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) - ) - self.register_buffer( - "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps - ) - - @torch.no_grad() - def sample( - self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0.0, - mask=None, - x0=None, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
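        # note: when unconditional_guidance_scale > 1, p_sample_ddim applies classifier-free
        # guidance, i.e. model_uncond + scale * (model_t - model_uncond)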
- dynamic_threshold=None, - ucg_schedule=None, - **kwargs, - ): - # if conditioning is not None: - # if isinstance(conditioning, dict): - # ctmp = conditioning[list(conditioning.keys())[0]] - # while isinstance(ctmp, list): ctmp = ctmp[0] - # cbs = ctmp.shape[0] - # if cbs != batch_size: - # print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - - # elif isinstance(conditioning, list): - # for ctmp in conditioning: - # if ctmp.shape[0] != batch_size: - # print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - - # else: - # if conditioning.shape[0] != batch_size: - # print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - # print(f'Data shape for DDIM sampling is {size}, eta {eta}') - - samples, intermediates = self.ddim_sampling( - conditioning, - size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, - x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ucg_schedule=ucg_schedule, - ) - return samples, intermediates - - @torch.no_grad() - def ddim_sampling( - self, - cond, - shape, - x_T=None, - ddim_use_original_steps=False, - callback=None, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - log_every_t=100, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - dynamic_threshold=None, - ucg_schedule=None, - ): - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - if timesteps is None: - timesteps = ( - self.ddpm_num_timesteps - if ddim_use_original_steps - else self.ddim_timesteps - ) - elif timesteps is not None and not ddim_use_original_steps: - subset_end = ( - int( - min(timesteps / self.ddim_timesteps.shape[0], 1) - * self.ddim_timesteps.shape[0] - ) - - 1 - ) - timesteps = self.ddim_timesteps[:subset_end] - - intermediates = {"x_inter": [img], "pred_x0": [img]} - time_range = ( - reversed(range(0, timesteps)) - if ddim_use_original_steps - else np.flip(timesteps) - ) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - print(f"Running DDIM Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - - if mask is not None: - assert x0 is not None - img_orig = self.model.q_sample( - x0, ts - ) # TODO: deterministic forward pass? 
- img = img_orig * mask + (1.0 - mask) * img - - if ucg_schedule is not None: - assert len(ucg_schedule) == len(time_range) - unconditional_guidance_scale = ucg_schedule[i] - - outs = self.p_sample_ddim( - img, - cond, - ts, - index=index, - use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, - temperature=temperature, - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ) - img, pred_x0 = outs - if callback: - callback(i) - if img_callback: - img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates["x_inter"].append(img) - intermediates["pred_x0"].append(pred_x0) - - return img, intermediates - - @torch.no_grad() - def p_sample_ddim( - self, - x, - c, - t, - index, - repeat_noise=False, - use_original_steps=False, - quantize_denoised=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - dynamic_threshold=None, - ): - b, *_, device = *x.shape, x.device - - if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: - model_output = self.model.apply_model(x, t, c) - else: - x_in = x - t_in = t - - assert isinstance(c, dict) - assert isinstance(unconditional_conditioning, dict) - - model_uncond = self.model.apply_model( - x_in, t_in, unconditional_conditioning - ) - model_t = self.model.apply_model(x_in, t_in, c) - - model_output = model_uncond + unconditional_guidance_scale * ( - model_t - model_uncond - ) - - if self.model.parameterization == "v": - e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) - else: - e_t = model_output - - if score_corrector is not None: - assert self.model.parameterization == "eps", "not implemented" - e_t = score_corrector.modify_score( - self.model, e_t, x, t, c, **corrector_kwargs - ) - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = ( - self.model.alphas_cumprod_prev - if use_original_steps - else self.ddim_alphas_prev - ) - sqrt_one_minus_alphas = ( - self.model.sqrt_one_minus_alphas_cumprod - if use_original_steps - else self.ddim_sqrt_one_minus_alphas - ) - sigmas = ( - self.model.ddim_sigmas_for_original_num_steps - if use_original_steps - else self.ddim_sigmas - ) - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full( - (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device - ) - - # current prediction for x_0 - if self.model.parameterization != "v": - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - else: - pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) - - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - - if dynamic_threshold is not None: - raise NotImplementedError() - - # direction pointing to x_t - dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.0: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - 
return x_prev, pred_x0 - - @torch.no_grad() - def encode( - self, - x0, - c, - t_enc, - use_original_steps=False, - return_intermediates=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - callback=None, - ): - num_reference_steps = ( - self.ddpm_num_timesteps - if use_original_steps - else self.ddim_timesteps.shape[0] - ) - - assert t_enc <= num_reference_steps - num_steps = t_enc - - if use_original_steps: - alphas_next = self.alphas_cumprod[:num_steps] - alphas = self.alphas_cumprod_prev[:num_steps] - else: - alphas_next = self.ddim_alphas[:num_steps] - alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) - - x_next = x0 - intermediates = [] - inter_steps = [] - for i in tqdm(range(num_steps), desc="Encoding Image"): - t = torch.full( - (x0.shape[0],), i, device=self.model.device, dtype=torch.long - ) - if unconditional_guidance_scale == 1.0: - noise_pred = self.model.apply_model(x_next, t, c) - else: - assert unconditional_conditioning is not None - e_t_uncond, noise_pred = torch.chunk( - self.model.apply_model( - torch.cat((x_next, x_next)), - torch.cat((t, t)), - torch.cat((unconditional_conditioning, c)), - ), - 2, - ) - noise_pred = e_t_uncond + unconditional_guidance_scale * ( - noise_pred - e_t_uncond - ) - - xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next - weighted_noise_pred = ( - alphas_next[i].sqrt() - * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) - * noise_pred - ) - x_next = xt_weighted + weighted_noise_pred - if ( - return_intermediates - and i % (num_steps // return_intermediates) == 0 - and i < num_steps - 1 - ): - intermediates.append(x_next) - inter_steps.append(i) - elif return_intermediates and i >= num_steps - 2: - intermediates.append(x_next) - inter_steps.append(i) - if callback: - callback(i) - - out = {"x_encoded": x_next, "intermediate_steps": inter_steps} - if return_intermediates: - out.update({"intermediates": intermediates}) - return x_next, out - - @torch.no_grad() - def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): - # fast, but does not allow for exact reconstruction - # t serves as an index to gather the correct alphas - if use_original_steps: - sqrt_alphas_cumprod = self.sqrt_alphas_cumprod - sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod - else: - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) - sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas - - if noise is None: - noise = torch.randn_like(x0) - return ( - extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 - + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise - ) - - @torch.no_grad() - def decode( - self, - x_latent, - cond, - t_start, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - use_original_steps=False, - callback=None, - ): - timesteps = ( - np.arange(self.ddpm_num_timesteps) - if use_original_steps - else self.ddim_timesteps - ) - timesteps = timesteps[:t_start] - - time_range = np.flip(timesteps) - total_steps = timesteps.shape[0] - print(f"Running DDIM Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc="Decoding image", total=total_steps) - x_dec = x_latent - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full( - (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long - ) - x_dec, _ = self.p_sample_ddim( - x_dec, - cond, - ts, - index=index, - use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - 
unconditional_conditioning=unconditional_conditioning, - ) - if callback: - callback(i) - return x_dec diff --git a/audioldm2/latent_diffusion/models/ddpm.py b/audioldm2/latent_diffusion/models/ddpm.py deleted file mode 100755 index df3a6c032ba2ec61250212a31d68184e763dcf0e..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/models/ddpm.py +++ /dev/null @@ -1,1840 +0,0 @@ -from multiprocessing.sharedctypes import Value -import os - -import torch -import torch.nn as nn -import numpy as np -from einops import rearrange, repeat -from contextlib import contextmanager -from functools import partial -from tqdm import tqdm -from torchvision.utils import make_grid -from audioldm2.latent_diffusion.modules.encoders.modules import * - -from audioldm2.latent_diffusion.util import ( - exists, - default, - count_params, - instantiate_from_config, -) -from audioldm2.latent_diffusion.modules.ema import LitEma -from audioldm2.latent_diffusion.modules.distributions.distributions import ( - DiagonalGaussianDistribution, -) - -# from latent_encoder.autoencoder import ( -# VQModelInterface, -# IdentityFirstStage, -# AutoencoderKL, -# ) - -from audioldm2.latent_diffusion.modules.diffusionmodules.util import ( - make_beta_schedule, - extract_into_tensor, - noise_like, -) - -from audioldm2.latent_diffusion.models.ddim import DDIMSampler -from audioldm2.latent_diffusion.models.plms import PLMSSampler -import soundfile as sf -import os - -__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} - -CACHE_DIR = os.getenv( - "AUDIOLDM_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache/audioldm2") -) - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -def uniform_on_device(r1, r2, shape, device): - return (r1 - r2) * torch.rand(*shape, device=device) + r2 - - -class DDPM(nn.Module): - # classic DDPM with Gaussian diffusion, in image space - def __init__( - self, - unet_config, - sampling_rate=None, - timesteps=1000, - beta_schedule="linear", - loss_type="l2", - ckpt_path=None, - ignore_keys=[], - load_only_unet=False, - monitor="val/loss", - use_ema=True, - first_stage_key="image", - latent_t_size=256, - latent_f_size=16, - channels=3, - log_every_t=100, - clip_denoised=True, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - given_betas=None, - original_elbo_weight=0.0, - v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta - l_simple_weight=1.0, - conditioning_key=None, - parameterization="eps", # all assuming fixed variance schedules - scheduler_config=None, - use_positional_encodings=False, - learn_logvar=False, - logvar_init=0.0, - evaluator=None, - device=None, - ): - super().__init__() - assert parameterization in [ - "eps", - "x0", - "v", - ], 'currently only supporting "eps" and "x0" and "v"' - self.parameterization = parameterization - self.state = None - self.device = device - # print( - # f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode" - # ) - assert sampling_rate is not None - self.validation_folder_name = "temp_name" - self.clip_denoised = clip_denoised - self.log_every_t = log_every_t - self.first_stage_key = first_stage_key - self.sampling_rate = sampling_rate - - self.clap = CLAPAudioEmbeddingClassifierFreev2( - pretrained_path="", - sampling_rate=self.sampling_rate, - embed_mode="audio", - amodel="HTSAT-base", - ) - - 
self.initialize_param_check_toolkit() - - self.latent_t_size = latent_t_size - self.latent_f_size = latent_f_size - - self.channels = channels - self.use_positional_encodings = use_positional_encodings - self.model = DiffusionWrapper(unet_config, conditioning_key) - count_params(self.model, verbose=True) - self.use_ema = use_ema - if self.use_ema: - self.model_ema = LitEma(self.model) - # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - self.use_scheduler = scheduler_config is not None - if self.use_scheduler: - self.scheduler_config = scheduler_config - - self.v_posterior = v_posterior - self.original_elbo_weight = original_elbo_weight - self.l_simple_weight = l_simple_weight - - if monitor is not None: - self.monitor = monitor - if ckpt_path is not None: - self.init_from_ckpt( - ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet - ) - - self.register_schedule( - given_betas=given_betas, - beta_schedule=beta_schedule, - timesteps=timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s, - ) - - self.loss_type = loss_type - - self.learn_logvar = learn_logvar - self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) - if self.learn_logvar: - self.logvar = nn.Parameter(self.logvar, requires_grad=True) - else: - self.logvar = nn.Parameter(self.logvar, requires_grad=False) - - self.logger_save_dir = None - self.logger_exp_name = None - self.logger_exp_group_name = None - self.logger_version = None - - self.label_indices_total = None - # To avoid the system cannot find metric value for checkpoint - self.metrics_buffer = { - "val/kullback_leibler_divergence_sigmoid": 15.0, - "val/kullback_leibler_divergence_softmax": 10.0, - "val/psnr": 0.0, - "val/ssim": 0.0, - "val/inception_score_mean": 1.0, - "val/inception_score_std": 0.0, - "val/kernel_inception_distance_mean": 0.0, - "val/kernel_inception_distance_std": 0.0, - "val/frechet_inception_distance": 133.0, - "val/frechet_audio_distance": 32.0, - } - self.initial_learning_rate = None - self.test_data_subset_path = None - - def get_log_dir(self): - return os.path.join( - self.logger_save_dir, self.logger_exp_group_name, self.logger_exp_name - ) - - def set_log_dir(self, save_dir, exp_group_name, exp_name): - self.logger_save_dir = save_dir - self.logger_exp_group_name = exp_group_name - self.logger_exp_name = exp_name - - def register_schedule( - self, - given_betas=None, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - ): - if exists(given_betas): - betas = given_betas - else: - betas = make_beta_schedule( - beta_schedule, - timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s, - ) - alphas = 1.0 - betas - alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) - - (timesteps,) = betas.shape - self.num_timesteps = int(timesteps) - self.linear_start = linear_start - self.linear_end = linear_end - assert ( - alphas_cumprod.shape[0] == self.num_timesteps - ), "alphas have to be defined for each timestep" - - to_torch = partial(torch.tensor, dtype=torch.float32) - - self.register_buffer("betas", to_torch(betas)) - self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) - self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer( - 
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) - ) - self.register_buffer( - "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) - ) - self.register_buffer( - "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) - ) - self.register_buffer( - "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) - ) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = (1 - self.v_posterior) * betas * ( - 1.0 - alphas_cumprod_prev - ) / (1.0 - alphas_cumprod) + self.v_posterior * betas - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.register_buffer("posterior_variance", to_torch(posterior_variance)) - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.register_buffer( - "posterior_log_variance_clipped", - to_torch(np.log(np.maximum(posterior_variance, 1e-20))), - ) - self.register_buffer( - "posterior_mean_coef1", - to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), - ) - self.register_buffer( - "posterior_mean_coef2", - to_torch( - (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) - ), - ) - - if self.parameterization == "eps": - lvlb_weights = self.betas**2 / ( - 2 - * self.posterior_variance - * to_torch(alphas) - * (1 - self.alphas_cumprod) - ) - elif self.parameterization == "x0": - lvlb_weights = ( - 0.5 - * np.sqrt(torch.Tensor(alphas_cumprod)) - / (2.0 * 1 - torch.Tensor(alphas_cumprod)) - ) - elif self.parameterization == "v": - lvlb_weights = torch.ones_like( - self.betas**2 - / ( - 2 - * self.posterior_variance - * to_torch(alphas) - * (1 - self.alphas_cumprod) - ) - ) - else: - raise NotImplementedError("mu not supported") - # TODO how to choose this term - lvlb_weights[0] = lvlb_weights[1] - self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) - assert not torch.isnan(self.lvlb_weights).all() - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.model.parameters()) - self.model_ema.copy_to(self.model) - # if context is not None: - # print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.model.parameters()) - # if context is not None: - # print(f"{context}: Restored training weights") - - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): - sd = torch.load(path, map_location="cpu") - if "state_dict" in list(sd.keys()): - sd = sd["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - missing, unexpected = ( - self.load_state_dict(sd, strict=False) - if not only_model - else self.model.load_state_dict(sd, strict=False) - ) - print( - f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" - ) - if len(missing) > 0: - print(f"Missing Keys: {missing}") - if len(unexpected) > 0: - print(f"Unexpected Keys: {unexpected}") - - def q_mean_variance(self, x_start, t): - """ - Get the distribution q(x_t | x_0). - :param x_start: the [N x C x ...] tensor of noiseless inputs. - :param t: the number of diffusion steps (minus 1). Here, 0 means one step. - :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
- """ - mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) - log_variance = extract_into_tensor( - self.log_one_minus_alphas_cumprod, t, x_start.shape - ) - return mean, variance, log_variance - - def predict_start_from_noise(self, x_t, t, noise): - return ( - extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) - * noise - ) - - def q_posterior(self, x_start, x_t, t): - posterior_mean = ( - extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start - + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = extract_into_tensor( - self.posterior_log_variance_clipped, t, x_t.shape - ) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - def p_mean_variance(self, x, t, clip_denoised: bool): - model_out = self.model(x, t) - if self.parameterization == "eps": - x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) - elif self.parameterization == "x0": - x_recon = model_out - if clip_denoised: - x_recon.clamp_(-1.0, 1.0) - - model_mean, posterior_variance, posterior_log_variance = self.q_posterior( - x_start=x_recon, x_t=x, t=t - ) - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): - b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance( - x=x, t=t, clip_denoised=clip_denoised - ) - noise = noise_like(x.shape, device, repeat_noise) - # no noise when t == 0 - nonzero_mask = ( - (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous() - ) - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise - - @torch.no_grad() - def p_sample_loop(self, shape, return_intermediates=False): - device = self.betas.device - b = shape[0] - img = torch.randn(shape, device=device) - intermediates = [img] - for i in tqdm( - reversed(range(0, self.num_timesteps)), - desc="Sampling t", - total=self.num_timesteps, - ): - img = self.p_sample( - img, - torch.full((b,), i, device=device, dtype=torch.long), - clip_denoised=self.clip_denoised, - ) - if i % self.log_every_t == 0 or i == self.num_timesteps - 1: - intermediates.append(img) - if return_intermediates: - return img, intermediates - return img - - @torch.no_grad() - def sample(self, batch_size=16, return_intermediates=False): - shape = (batch_size, channels, self.latent_t_size, self.latent_f_size) - self.channels - return self.p_sample_loop(shape, return_intermediates=return_intermediates) - - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) - * noise - ) - - def get_loss(self, pred, target, mean=True): - if self.loss_type == "l1": - loss = (target - pred).abs() - if mean: - loss = loss.mean() - elif self.loss_type == "l2": - if mean: - loss = torch.nn.functional.mse_loss(target, pred) - else: - loss = torch.nn.functional.mse_loss(target, pred, reduction="none") - else: - raise NotImplementedError("unknown loss type '{loss_type}'") - - return loss - - def predict_start_from_z_and_v(self, x_t, t, 
v): - # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v - ) - - def predict_eps_from_z_and_v(self, x_t, t, v): - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) - * x_t - ) - - def get_v(self, x, noise, t): - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x - ) - - def forward(self, x, *args, **kwargs): - # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size - # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' - t = torch.randint( - 0, self.num_timesteps, (x.shape[0],), device=self.device - ).long() - return self.p_losses(x, t, *args, **kwargs) - - def get_input(self, batch, k): - # fbank, log_magnitudes_stft, label_indices, fname, waveform, clip_label, text = batch - # fbank, stft, label_indices, fname, waveform, text = batch - fname, text, waveform, stft, fbank = ( - batch["fname"], - batch["text"], - batch["waveform"], - batch["stft"], - batch["log_mel_spec"], - ) - # for i in range(fbank.size(0)): - # fb = fbank[i].numpy() - # seg_lb = seg_label[i].numpy() - # logits = np.mean(seg_lb, axis=0) - # index = np.argsort(logits)[::-1][:5] - # plt.imshow(seg_lb[:,index], aspect="auto") - # plt.title(index) - # plt.savefig("%s_label.png" % i) - # plt.close() - # plt.imshow(fb, aspect="auto") - # plt.savefig("%s_fb.png" % i) - # plt.close() - ret = {} - - ret["fbank"] = ( - fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float() - ) - ret["stft"] = stft.to(memory_format=torch.contiguous_format).float() - # ret["clip_label"] = clip_label.to(memory - # _format=torch.contiguous_format).float() - ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float() - ret["text"] = list(text) - ret["fname"] = fname - - for key in batch.keys(): - if key not in ret.keys(): - ret[key] = batch[key] - - return ret[k] - - def _get_rows_from_list(self, samples): - n_imgs_per_row = len(samples) - denoise_grid = rearrange(samples, "n b c h w -> b n c h w") - denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") - denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) - return denoise_grid - - @torch.no_grad() - def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): - log = dict() - x = self.get_input(batch, self.first_stage_key) - N = min(x.shape[0], N) - n_row = min(x.shape[0], n_row) - x = x.to(self.device)[:N] - log["inputs"] = x - - # get diffusion row - diffusion_row = list() - x_start = x[:n_row] - - for t in range(self.num_timesteps): - if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), "1 -> b", b=n_row) - t = t.to(self.device).long() - noise = torch.randn_like(x_start) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - diffusion_row.append(x_noisy) - - log["diffusion_row"] = self._get_rows_from_list(diffusion_row) - - if sample: - # get denoise row - with self.ema_scope("Plotting"): - samples, denoise_row = self.sample( - batch_size=N, return_intermediates=True - ) - - log["samples"] = samples - log["denoise_row"] = 
self._get_rows_from_list(denoise_row) - - if return_keys: - if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: - return log - else: - return {key: log[key] for key in return_keys} - return log - - def configure_optimizers(self): - lr = self.learning_rate - params = list(self.model.parameters()) - if self.learn_logvar: - params = params + [self.logvar] - opt = torch.optim.AdamW(params, lr=lr) - return opt - - def initialize_param_check_toolkit(self): - self.tracked_steps = 0 - self.param_dict = {} - - def statistic_require_grad_tensor_number(self, module, name=None): - requires_grad_num = 0 - total_num = 0 - require_grad_tensor = None - for p in module.parameters(): - if p.requires_grad: - requires_grad_num += 1 - if require_grad_tensor is None: - require_grad_tensor = p - total_num += 1 - print( - "Module: [%s] have %s trainable parameters out of %s total parameters (%.2f)" - % (name, requires_grad_num, total_num, requires_grad_num / total_num) - ) - return require_grad_tensor - - -class LatentDiffusion(DDPM): - """main class""" - - def __init__( - self, - first_stage_config, - cond_stage_config=None, - num_timesteps_cond=None, - cond_stage_key="image", - optimize_ddpm_parameter=True, - unconditional_prob_cfg=0.1, - warmup_steps=10000, - cond_stage_trainable=False, - concat_mode=True, - cond_stage_forward=None, - conditioning_key=None, - scale_factor=1.0, - batchsize=None, - evaluation_params={}, - scale_by_std=False, - base_learning_rate=None, - *args, - **kwargs, - ): - self.learning_rate = base_learning_rate - self.num_timesteps_cond = default(num_timesteps_cond, 1) - self.scale_by_std = scale_by_std - self.warmup_steps = warmup_steps - - if optimize_ddpm_parameter: - if unconditional_prob_cfg == 0.0: - "You choose to optimize DDPM. The classifier free guidance scale should be 0.1" - unconditional_prob_cfg = 0.1 - else: - if unconditional_prob_cfg == 0.1: - "You choose not to optimize DDPM. 
The classifier free guidance scale should be 0.0" - unconditional_prob_cfg = 0.0 - - self.evaluation_params = evaluation_params - assert self.num_timesteps_cond <= kwargs["timesteps"] - - # for backwards compatibility after implementation of DiffusionWrapper - # if conditioning_key is None: - # conditioning_key = "concat" if concat_mode else "crossattn" - # if cond_stage_config == "__is_unconditional__": - # conditioning_key = None - - conditioning_key = list(cond_stage_config.keys()) - - self.conditioning_key = conditioning_key - - ckpt_path = kwargs.pop("ckpt_path", None) - ignore_keys = kwargs.pop("ignore_keys", []) - super().__init__(conditioning_key=conditioning_key, *args, **kwargs) - - self.optimize_ddpm_parameter = optimize_ddpm_parameter - # if(not optimize_ddpm_parameter): - # print("Warning: Close the optimization of the latent diffusion model") - # for p in self.model.parameters(): - # p.requires_grad=False - - self.concat_mode = concat_mode - self.cond_stage_key = cond_stage_key - self.cond_stage_key_orig = cond_stage_key - try: - self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 - except: - self.num_downs = 0 - if not scale_by_std: - self.scale_factor = scale_factor - else: - self.register_buffer("scale_factor", torch.tensor(scale_factor)) - self.instantiate_first_stage(first_stage_config) - self.unconditional_prob_cfg = unconditional_prob_cfg - self.cond_stage_models = nn.ModuleList([]) - self.instantiate_cond_stage(cond_stage_config) - self.cond_stage_forward = cond_stage_forward - self.clip_denoised = False - self.bbox_tokenizer = None - self.conditional_dry_run_finished = False - self.restarted_from_ckpt = False - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys) - self.restarted_from_ckpt = True - - def configure_optimizers(self): - lr = self.learning_rate - params = list(self.model.parameters()) - - for each in self.cond_stage_models: - params = params + list( - each.parameters() - ) # Add the parameter from the conditional stage - - if self.learn_logvar: - print("Diffusion model optimizing logvar") - params.append(self.logvar) - opt = torch.optim.AdamW(params, lr=lr) - # if self.use_scheduler: - # assert "target" in self.scheduler_config - # scheduler = instantiate_from_config(self.scheduler_config) - - # print("Setting up LambdaLR scheduler...") - # scheduler = [ - # { - # "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), - # "interval": "step", - # "frequency": 1, - # } - # ] - # return [opt], scheduler - return opt - - def make_cond_schedule( - self, - ): - self.cond_ids = torch.full( - size=(self.num_timesteps,), - fill_value=self.num_timesteps - 1, - dtype=torch.long, - ) - ids = torch.round( - torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) - ).long() - self.cond_ids[: self.num_timesteps_cond] = ids - - @torch.no_grad() - def on_train_batch_start(self, batch, batch_idx): - # only for very first batch - if ( - self.scale_factor == 1 - and self.scale_by_std - and self.current_epoch == 0 - and self.global_step == 0 - and batch_idx == 0 - and not self.restarted_from_ckpt - ): - # assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' - # set rescale weight to 1./std of encodings - print("### USING STD-RESCALING ###") - x = super().get_input(batch, self.first_stage_key) - x = x.to(self.device) - encoder_posterior = self.encode_first_stage(x) - z = self.get_first_stage_encoding(encoder_posterior).detach() - del self.scale_factor - 
self.register_buffer("scale_factor", 1.0 / z.flatten().std()) - print(f"setting self.scale_factor to {self.scale_factor}") - print("### USING STD-RESCALING ###") - - def register_schedule( - self, - given_betas=None, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - ): - super().register_schedule( - given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s - ) - - self.shorten_cond_schedule = self.num_timesteps_cond > 1 - if self.shorten_cond_schedule: - self.make_cond_schedule() - - def instantiate_first_stage(self, config): - model = instantiate_from_config(config) - self.first_stage_model = model.eval() - self.first_stage_model.train = disabled_train - for param in self.first_stage_model.parameters(): - param.requires_grad = False - - def make_decision(self, probability): - if float(torch.rand(1)) < probability: - return True - else: - return False - - def instantiate_cond_stage(self, config): - self.cond_stage_model_metadata = {} - for i, cond_model_key in enumerate(config.keys()): - model = instantiate_from_config(config[cond_model_key]) - self.cond_stage_models.append(model) - self.cond_stage_model_metadata[cond_model_key] = { - "model_idx": i, - "cond_stage_key": config[cond_model_key]["cond_stage_key"], - "conditioning_key": config[cond_model_key]["conditioning_key"], - } - - def get_first_stage_encoding(self, encoder_posterior): - if isinstance(encoder_posterior, DiagonalGaussianDistribution): - z = encoder_posterior.sample() - elif isinstance(encoder_posterior, torch.Tensor): - z = encoder_posterior - else: - raise NotImplementedError( - f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" - ) - return self.scale_factor * z - - def get_learned_conditioning(self, c, key, unconditional_cfg): - assert key in self.cond_stage_model_metadata.keys() - - # Classifier-free guidance - if not unconditional_cfg: - c = self.cond_stage_models[ - self.cond_stage_model_metadata[key]["model_idx"] - ](c) - else: - # when the cond_stage_key is "all", pick one random element out - if isinstance(c, dict): - c = c[list(c.keys())[0]] - - if isinstance(c, torch.Tensor): - batchsize = c.size(0) - elif isinstance(c, list): - batchsize = len(c) - else: - raise NotImplementedError() - - c = self.cond_stage_models[ - self.cond_stage_model_metadata[key]["model_idx"] - ].get_unconditional_condition(batchsize) - - return c - - def get_input( - self, - batch, - k, - return_first_stage_encode=True, - return_decoding_output=False, - return_encoder_input=False, - return_encoder_output=False, - unconditional_prob_cfg=0.1, - ): - x = super().get_input(batch, k) - - x = x.to(self.device) - - if return_first_stage_encode: - encoder_posterior = self.encode_first_stage(x) - z = self.get_first_stage_encoding(encoder_posterior).detach() - else: - z = None - cond_dict = {} - if len(self.cond_stage_model_metadata.keys()) > 0: - unconditional_cfg = False - if self.conditional_dry_run_finished and self.make_decision( - unconditional_prob_cfg - ): - unconditional_cfg = True - for cond_model_key in self.cond_stage_model_metadata.keys(): - cond_stage_key = self.cond_stage_model_metadata[cond_model_key][ - "cond_stage_key" - ] - - if cond_model_key in cond_dict.keys(): - continue - - if not self.training: - if isinstance( - self.cond_stage_models[ - self.cond_stage_model_metadata[cond_model_key]["model_idx"] - ], - CLAPAudioEmbeddingClassifierFreev2, - ): - print( - "Warning: CLAP model normally should use text for evaluation" - ) - - # 
The original data for conditioning - # If cond_model_key is "all", that means the conditional model need all the information from a batch - - if cond_stage_key != "all": - xc = super().get_input(batch, cond_stage_key) - if type(xc) == torch.Tensor: - xc = xc.to(self.device) - else: - xc = batch - - # if cond_stage_key is "all", xc will be a dictionary containing all keys - # Otherwise xc will be an entry of the dictionary - c = self.get_learned_conditioning( - xc, key=cond_model_key, unconditional_cfg=unconditional_cfg - ) - - # cond_dict will be used to condition the diffusion model - # If one conditional model return multiple conditioning signal - if isinstance(c, dict): - for k in c.keys(): - cond_dict[k] = c[k] - else: - cond_dict[cond_model_key] = c - - # If the key is accidently added to the dictionary and not in the condition list, remove the condition - # for k in list(cond_dict.keys()): - # if(k not in self.cond_stage_model_metadata.keys()): - # del cond_dict[k] - - out = [z, cond_dict] - - if return_decoding_output: - xrec = self.decode_first_stage(z) - out += [xrec] - - if return_encoder_input: - out += [x] - - if return_encoder_output: - out += [encoder_posterior] - - if not self.conditional_dry_run_finished: - self.conditional_dry_run_finished = True - - # Output is a dictionary, where the value could only be tensor or tuple - return out - - def decode_first_stage(self, z): - with torch.no_grad(): - z = 1.0 / self.scale_factor * z - decoding = self.first_stage_model.decode(z) - return decoding - - def mel_spectrogram_to_waveform( - self, mel, savepath=".", bs=None, name="outwav", save=True - ): - # Mel: [bs, 1, t-steps, fbins] - if len(mel.size()) == 4: - mel = mel.squeeze(1) - mel = mel.permute(0, 2, 1) - waveform = self.first_stage_model.vocoder(mel) - waveform = waveform.cpu().detach().numpy() - if save: - self.save_waveform(waveform, savepath, name) - return waveform - - def encode_first_stage(self, x): - with torch.no_grad(): - return self.first_stage_model.encode(x) - - def extract_possible_loss_in_cond_dict(self, cond_dict): - # This function enable the conditional module to return loss function that can optimize them - - assert isinstance(cond_dict, dict) - losses = {} - - for cond_key in cond_dict.keys(): - if "loss" in cond_key and "noncond" in cond_key: - assert cond_key not in losses.keys() - losses[cond_key] = cond_dict[cond_key] - - return losses - - def filter_useful_cond_dict(self, cond_dict): - new_cond_dict = {} - for key in cond_dict.keys(): - if key in self.cond_stage_model_metadata.keys(): - new_cond_dict[key] = cond_dict[key] - - # All the conditional key in the metadata should be used - for key in self.cond_stage_model_metadata.keys(): - assert key in new_cond_dict.keys(), "%s, %s" % ( - key, - str(new_cond_dict.keys()), - ) - - return new_cond_dict - - def shared_step(self, batch, **kwargs): - if self.training: - # Classifier-free guidance - unconditional_prob_cfg = self.unconditional_prob_cfg - else: - unconditional_prob_cfg = 0.0 # TODO possible bug here - - x, c = self.get_input( - batch, self.first_stage_key, unconditional_prob_cfg=unconditional_prob_cfg - ) - - if self.optimize_ddpm_parameter: - loss, loss_dict = self(x, self.filter_useful_cond_dict(c)) - else: - loss_dict = {} - loss = None - - additional_loss_for_cond_modules = self.extract_possible_loss_in_cond_dict(c) - assert isinstance(additional_loss_for_cond_modules, dict) - - loss_dict.update(additional_loss_for_cond_modules) - - if len(additional_loss_for_cond_modules.keys()) > 0: - for k 
in additional_loss_for_cond_modules.keys(): - if loss is None: - loss = additional_loss_for_cond_modules[k] - else: - loss = loss + additional_loss_for_cond_modules[k] - - # for k,v in additional_loss_for_cond_modules.items(): - # self.log( - # "cond_stage/"+k, - # float(v), - # prog_bar=True, - # logger=True, - # on_step=True, - # on_epoch=True, - # ) - if self.training: - assert loss is not None - - return loss, loss_dict - - def forward(self, x, c, *args, **kwargs): - t = torch.randint( - 0, self.num_timesteps, (x.shape[0],), device=self.device - ).long() - - # assert c is not None - # c = self.get_learned_conditioning(c) - - loss, loss_dict = self.p_losses(x, c, t, *args, **kwargs) - return loss, loss_dict - - def reorder_cond_dict(self, cond_dict): - # To make sure the order is correct - new_cond_dict = {} - for key in self.conditioning_key: - new_cond_dict[key] = cond_dict[key] - return new_cond_dict - - def apply_model(self, x_noisy, t, cond, return_ids=False): - cond = self.reorder_cond_dict(cond) - - x_recon = self.model(x_noisy, t, cond_dict=cond) - - if isinstance(x_recon, tuple) and not return_ids: - return x_recon[0] - else: - return x_recon - - def p_losses(self, x_start, cond, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - model_output = self.apply_model(x_noisy, t, cond) - - loss_dict = {} - prefix = "train" if self.training else "val" - - if self.parameterization == "x0": - target = x_start - elif self.parameterization == "eps": - target = noise - elif self.parameterization == "v": - target = self.get_v(x_start, noise, t) - else: - raise NotImplementedError() - # print(model_output.size(), target.size()) - loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) - loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()}) - - logvar_t = self.logvar[t].to(self.device) - loss = loss_simple / torch.exp(logvar_t) + logvar_t - # loss = loss_simple / torch.exp(self.logvar) + self.logvar - if self.learn_logvar: - loss_dict.update({f"{prefix}/loss_gamma": loss.mean()}) - loss_dict.update({"logvar": self.logvar.data.mean()}) - - loss = self.l_simple_weight * loss.mean() - - loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) - loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() - loss_dict.update({f"{prefix}/loss_vlb": loss_vlb}) - loss += self.original_elbo_weight * loss_vlb - loss_dict.update({f"{prefix}/loss": loss}) - - return loss, loss_dict - - def p_mean_variance( - self, - x, - c, - t, - clip_denoised: bool, - return_codebook_ids=False, - quantize_denoised=False, - return_x0=False, - score_corrector=None, - corrector_kwargs=None, - ): - t_in = t - model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) - - if score_corrector is not None: - assert self.parameterization == "eps" - model_out = score_corrector.modify_score( - self, model_out, x, t, c, **corrector_kwargs - ) - - if return_codebook_ids: - model_out, logits = model_out - - if self.parameterization == "eps": - x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) - elif self.parameterization == "x0": - x_recon = model_out - else: - raise NotImplementedError() - - if clip_denoised: - x_recon.clamp_(-1.0, 1.0) - if quantize_denoised: - x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) - model_mean, posterior_variance, posterior_log_variance = self.q_posterior( - x_start=x_recon, x_t=x, t=t - ) - if return_codebook_ids: - 
return model_mean, posterior_variance, posterior_log_variance, logits - elif return_x0: - return model_mean, posterior_variance, posterior_log_variance, x_recon - else: - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample( - self, - x, - c, - t, - clip_denoised=False, - repeat_noise=False, - return_codebook_ids=False, - quantize_denoised=False, - return_x0=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - ): - b, *_, device = *x.shape, x.device - outputs = self.p_mean_variance( - x=x, - c=c, - t=t, - clip_denoised=clip_denoised, - return_codebook_ids=return_codebook_ids, - quantize_denoised=quantize_denoised, - return_x0=return_x0, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - ) - if return_codebook_ids: - raise DeprecationWarning("Support dropped.") - model_mean, _, model_log_variance, logits = outputs - elif return_x0: - model_mean, _, model_log_variance, x0 = outputs - else: - model_mean, _, model_log_variance = outputs - - noise = noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.0: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - # no noise when t == 0 - nonzero_mask = ( - (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous() - ) - - # if return_codebook_ids: - # return model_mean + nonzero_mask * ( - # 0.5 * model_log_variance - # ).exp() * noise, logits.argmax(dim=1) - if return_x0: - return ( - model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, - x0, - ) - else: - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise - - @torch.no_grad() - def progressive_denoising( - self, - cond, - shape, - verbose=True, - callback=None, - quantize_denoised=False, - img_callback=None, - mask=None, - x0=None, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - batch_size=None, - x_T=None, - start_T=None, - log_every_t=None, - ): - if not log_every_t: - log_every_t = self.log_every_t - timesteps = self.num_timesteps - if batch_size is not None: - b = batch_size if batch_size is not None else shape[0] - shape = [batch_size] + list(shape) - else: - b = batch_size = shape[0] - if x_T is None: - img = torch.randn(shape, device=self.device) - else: - img = x_T - intermediates = [] - if cond is not None: - if isinstance(cond, dict): - cond = { - key: cond[key][:batch_size] - if not isinstance(cond[key], list) - else list(map(lambda x: x[:batch_size], cond[key])) - for key in cond - } - else: - cond = ( - [c[:batch_size] for c in cond] - if isinstance(cond, list) - else cond[:batch_size] - ) - - if start_T is not None: - timesteps = min(timesteps, start_T) - iterator = ( - tqdm( - reversed(range(0, timesteps)), - desc="Progressive Generation", - total=timesteps, - ) - if verbose - else reversed(range(0, timesteps)) - ) - if type(temperature) == float: - temperature = [temperature] * timesteps - - for i in iterator: - ts = torch.full((b,), i, device=self.device, dtype=torch.long) - if self.shorten_cond_schedule: - assert self.model.conditioning_key != "hybrid" - tc = self.cond_ids[ts].to(cond.device) - cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - - img, x0_partial = self.p_sample( - img, - cond, - ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, - return_x0=True, - temperature=temperature[i], - noise_dropout=noise_dropout, - score_corrector=score_corrector, - 
corrector_kwargs=corrector_kwargs, - ) - if mask is not None: - assert x0 is not None - img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1.0 - mask) * img - - if i % log_every_t == 0 or i == timesteps - 1: - intermediates.append(x0_partial) - if callback: - callback(i) - if img_callback: - img_callback(img, i) - return img, intermediates - - @torch.no_grad() - def p_sample_loop( - self, - cond, - shape, - return_intermediates=False, - x_T=None, - verbose=True, - callback=None, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - start_T=None, - log_every_t=None, - ): - if not log_every_t: - log_every_t = self.log_every_t - device = self.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - intermediates = [img] - if timesteps is None: - timesteps = self.num_timesteps - - if start_T is not None: - timesteps = min(timesteps, start_T) - iterator = ( - tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) - if verbose - else reversed(range(0, timesteps)) - ) - - if mask is not None: - assert x0 is not None - assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match - - for i in iterator: - ts = torch.full((b,), i, device=device, dtype=torch.long) - - if self.shorten_cond_schedule: - assert self.model.conditioning_key != "hybrid" - tc = self.cond_ids[ts].to(cond.device) - cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - - img = self.p_sample( - img, - cond, - ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, - ) - - if mask is not None: - img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1.0 - mask) * img - - if i % log_every_t == 0 or i == timesteps - 1: - intermediates.append(img) - if callback: - callback(i) - if img_callback: - img_callback(img, i) - - if return_intermediates: - return img, intermediates - return img - - @torch.no_grad() - def sample( - self, - cond, - batch_size=16, - return_intermediates=False, - x_T=None, - verbose=True, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - shape=None, - **kwargs, - ): - if shape is None: - shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size) - if cond is not None: - if isinstance(cond, dict): - cond = { - key: cond[key][:batch_size] - if not isinstance(cond[key], list) - else list(map(lambda x: x[:batch_size], cond[key])) - for key in cond - } - else: - cond = ( - [c[:batch_size] for c in cond] - if isinstance(cond, list) - else cond[:batch_size] - ) - return self.p_sample_loop( - cond, - shape, - return_intermediates=return_intermediates, - x_T=x_T, - verbose=verbose, - timesteps=timesteps, - quantize_denoised=quantize_denoised, - mask=mask, - x0=x0, - **kwargs, - ) - - def save_waveform(self, waveform, savepath, name="outwav"): - for i in range(waveform.shape[0]): - if type(name) is str: - path = os.path.join( - savepath, "%s_%s_%s.wav" % (self.global_step, i, name) - ) - elif type(name) is list: - path = os.path.join( - savepath, - "%s.wav" - % ( - os.path.basename(name[i]) - if (not ".wav" in name[i]) - else os.path.basename(name[i]).split(".")[0] - ), - ) - else: - raise NotImplementedError - todo_waveform = waveform[i, 0] - todo_waveform = ( - todo_waveform / np.max(np.abs(todo_waveform)) - ) * 0.8 # Normalize the energy of the generation output - sf.write(path, todo_waveform, samplerate=self.sampling_rate) - - @torch.no_grad() - def sample_log( - self, - cond, - batch_size, - ddim, - ddim_steps, 
- unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - use_plms=False, - mask=None, - **kwargs, - ): - if mask is not None: - shape = (self.channels, mask.size()[-2], mask.size()[-1]) - else: - shape = (self.channels, self.latent_t_size, self.latent_f_size) - - intermediate = None - if ddim and not use_plms: - ddim_sampler = DDIMSampler(self) - samples, intermediates = ddim_sampler.sample( - ddim_steps, - batch_size, - shape, - cond, - verbose=False, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - mask=mask, - **kwargs, - ) - elif use_plms: - plms_sampler = PLMSSampler(self) - samples, intermediates = plms_sampler.sample( - ddim_steps, - batch_size, - shape, - cond, - verbose=False, - unconditional_guidance_scale=unconditional_guidance_scale, - mask=mask, - unconditional_conditioning=unconditional_conditioning, - **kwargs, - ) - - else: - samples, intermediates = self.sample( - cond=cond, - batch_size=batch_size, - return_intermediates=True, - unconditional_guidance_scale=unconditional_guidance_scale, - mask=mask, - unconditional_conditioning=unconditional_conditioning, - **kwargs, - ) - - return samples, intermediate - - @torch.no_grad() - def generate_batch( - self, - batch, - ddim_steps=200, - ddim_eta=1.0, - x_T=None, - n_gen=1, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - use_plms=False, - **kwargs, - ): - # Generate n_gen times and select the best - # Batch: audio, text, fnames - assert x_T is None - - if use_plms: - assert ddim_steps is not None - - use_ddim = ddim_steps is not None - - # with self.ema_scope("Plotting"): - for i in range(1): - z, c = self.get_input( - batch, - self.first_stage_key, - unconditional_prob_cfg=0.0, # Do not output unconditional information in the c - ) - - c = self.filter_useful_cond_dict(c) - - text = super().get_input(batch, "text") - - # Generate multiple samples - batch_size = z.shape[0] * n_gen - - # Generate multiple samples at a time and filter out the best - # The condition to the diffusion wrapper can have many format - for cond_key in c.keys(): - if isinstance(c[cond_key], list): - for i in range(len(c[cond_key])): - c[cond_key][i] = torch.cat([c[cond_key][i]] * n_gen, dim=0) - elif isinstance(c[cond_key], dict): - for k in c[cond_key].keys(): - c[cond_key][k] = torch.cat([c[cond_key][k]] * n_gen, dim=0) - else: - c[cond_key] = torch.cat([c[cond_key]] * n_gen, dim=0) - - text = text * n_gen - - if unconditional_guidance_scale != 1.0: - unconditional_conditioning = {} - for key in self.cond_stage_model_metadata: - model_idx = self.cond_stage_model_metadata[key]["model_idx"] - unconditional_conditioning[key] = self.cond_stage_models[ - model_idx - ].get_unconditional_condition(batch_size) - - fnames = list(super().get_input(batch, "fname")) - samples, _ = self.sample_log( - cond=c, - batch_size=batch_size, - x_T=x_T, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - use_plms=use_plms, - ) - - mel = self.decode_first_stage(samples) - - waveform = self.mel_spectrogram_to_waveform( - mel, savepath="", bs=None, name=fnames, save=False - ) - - if n_gen > 1: - best_index = [] - similarity = self.clap.cos_similarity( - torch.FloatTensor(waveform).squeeze(1), text - ) - for i in range(z.shape[0]): - candidates = similarity[i :: z.shape[0]] - max_index = torch.argmax(candidates).item() - best_index.append(i + 
max_index * z.shape[0]) - - waveform = waveform[best_index] - - print("Similarity between generated audio and text:") - print(' '.join('{:.2f}'.format(num) for num in similarity.detach().cpu().tolist())) - print("Choose the following indexes as the output:", best_index) - - return waveform - - @torch.no_grad() - def generate_sample( - self, - batchs, - ddim_steps=200, - ddim_eta=1.0, - x_T=None, - n_gen=1, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - name=None, - use_plms=False, - limit_num=None, - **kwargs, - ): - # Generate n_gen times and select the best - # Batch: audio, text, fnames - assert x_T is None - try: - batchs = iter(batchs) - except TypeError: - raise ValueError("The first input argument should be an iterable object") - - if use_plms: - assert ddim_steps is not None - - use_ddim = ddim_steps is not None - if name is None: - name = self.get_validation_folder_name() - - waveform_save_path = os.path.join(self.get_log_dir(), name) - os.makedirs(waveform_save_path, exist_ok=True) - print("Waveform save path: ", waveform_save_path) - - if ( - "audiocaps" in waveform_save_path - and len(os.listdir(waveform_save_path)) >= 964 - ): - print("The evaluation has already been done at %s" % waveform_save_path) - return waveform_save_path - - with self.ema_scope("Plotting"): - for i, batch in enumerate(batchs): - z, c = self.get_input( - batch, - self.first_stage_key, - unconditional_prob_cfg=0.0, # Do not output unconditional information in the c - ) - - if limit_num is not None and i * z.size(0) > limit_num: - break - - c = self.filter_useful_cond_dict(c) - - text = super().get_input(batch, "text") - - # Generate multiple samples - batch_size = z.shape[0] * n_gen - - # Generate multiple samples at a time and filter out the best - # The condition to the diffusion wrapper can have many format - for cond_key in c.keys(): - if isinstance(c[cond_key], list): - for i in range(len(c[cond_key])): - c[cond_key][i] = torch.cat([c[cond_key][i]] * n_gen, dim=0) - elif isinstance(c[cond_key], dict): - for k in c[cond_key].keys(): - c[cond_key][k] = torch.cat([c[cond_key][k]] * n_gen, dim=0) - else: - c[cond_key] = torch.cat([c[cond_key]] * n_gen, dim=0) - - text = text * n_gen - - if unconditional_guidance_scale != 1.0: - unconditional_conditioning = {} - for key in self.cond_stage_model_metadata: - model_idx = self.cond_stage_model_metadata[key]["model_idx"] - unconditional_conditioning[key] = self.cond_stage_models[ - model_idx - ].get_unconditional_condition(batch_size) - - fnames = list(super().get_input(batch, "fname")) - samples, _ = self.sample_log( - cond=c, - batch_size=batch_size, - x_T=x_T, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - use_plms=use_plms, - ) - - mel = self.decode_first_stage(samples) - - waveform = self.mel_spectrogram_to_waveform( - mel, savepath=waveform_save_path, bs=None, name=fnames, save=False - ) - - if n_gen > 1: - try: - best_index = [] - similarity = self.clap.cos_similarity( - torch.FloatTensor(waveform).squeeze(1), text - ) - for i in range(z.shape[0]): - candidates = similarity[i :: z.shape[0]] - max_index = torch.argmax(candidates).item() - best_index.append(i + max_index * z.shape[0]) - - waveform = waveform[best_index] - - print("Similarity between generated audio and text", similarity) - print("Choose the following indexes:", best_index) - except Exception as e: - print("Warning: while calculating 
CLAP score (not fatal), ", e) - self.save_waveform(waveform, waveform_save_path, name=fnames) - return waveform_save_path - - -class DiffusionWrapper(nn.Module): - def __init__(self, diff_model_config, conditioning_key): - super().__init__() - self.diffusion_model = instantiate_from_config(diff_model_config) - - self.conditioning_key = conditioning_key - - for key in self.conditioning_key: - if ( - "concat" in key - or "crossattn" in key - or "hybrid" in key - or "film" in key - or "noncond" in key - ): - continue - else: - raise Value("The conditioning key %s is illegal" % key) - - self.being_verbosed_once = False - - def forward(self, x, t, cond_dict: dict = {}): - x = x.contiguous() - t = t.contiguous() - - # x with condition (or maybe not) - xc = x - - y = None - context_list, attn_mask_list = [], [] - - conditional_keys = cond_dict.keys() - - for key in conditional_keys: - if "concat" in key: - xc = torch.cat([x, cond_dict[key].unsqueeze(1)], dim=1) - elif "film" in key: - if y is None: - y = cond_dict[key].squeeze(1) - else: - y = torch.cat([y, cond_dict[key].squeeze(1)], dim=-1) - elif "crossattn" in key: - # assert context is None, "You can only have one context matrix, got %s" % (cond_dict.keys()) - if isinstance(cond_dict[key], dict): - for k in cond_dict[key].keys(): - if "crossattn" in k: - context, attn_mask = cond_dict[key][ - k - ] # crossattn_audiomae_pooled: torch.Size([12, 128, 768]) - else: - assert len(cond_dict[key]) == 2, ( - "The context condition for %s you returned should have two element, one context one mask" - % (key) - ) - context, attn_mask = cond_dict[key] - - # The input to the UNet model is a list of context matrix - context_list.append(context) - attn_mask_list.append(attn_mask) - - elif ( - "noncond" in key - ): # If you use loss function in the conditional module, include the keyword "noncond" in the return dictionary - continue - else: - raise NotImplementedError() - - # if(not self.being_verbosed_once): - # print("The input shape to the diffusion model is as follows:") - # print("xc", xc.size()) - # print("t", t.size()) - # for i in range(len(context_list)): - # print("context_%s" % i, context_list[i].size(), attn_mask_list[i].size()) - # if(y is not None): - # print("y", y.size()) - # self.being_verbosed_once = True - out = self.diffusion_model( - xc, t, context_list=context_list, y=y, context_attn_mask_list=attn_mask_list - ) - return out - self.warmup_step() - - if ( - self.state is None - and len(self.trainer.optimizers[0].state_dict()["state"].keys()) > 0 - ): - self.state = ( - self.trainer.optimizers[0].state_dict()["state"][0]["exp_avg"].clone() - ) - elif self.state is not None and batch_idx % 1000 == 0: - assert ( - torch.sum( - torch.abs( - self.state - - self.trainer.optimizers[0].state_dict()["state"][0]["exp_avg"] - ) - ) - > 1e-7 - ), "Optimizer is not working" - - if len(self.metrics_buffer.keys()) > 0: - for k in self.metrics_buffer.keys(): - self.log( - k, - self.metrics_buffer[k], - prog_bar=False, - logger=True, - on_step=True, - on_epoch=False, - ) - print(k, self.metrics_buffer[k]) - self.metrics_buffer = {} - - loss, loss_dict = self.shared_step(batch) - - self.log_dict( - {k: float(v) for k, v in loss_dict.items()}, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=True, - ) - - self.log( - "global_step", - float(self.global_step), - prog_bar=True, - logger=True, - on_step=True, - on_epoch=False, - ) - - lr = self.trainer.optimizers[0].param_groups[0]["lr"] - self.log( - "lr_abs", - float(lr), - prog_bar=True, - 
logger=True, - on_step=True, - on_epoch=False, - ) - - -if __name__ == "__main__": - import yaml - - model_config = "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/stable-diffusion/models/ldm/text2img256/config.yaml" - model_config = yaml.load(open(model_config, "r"), Loader=yaml.FullLoader) - - latent_diffusion = LatentDiffusion(**model_config["model"]["params"]) - - import ipdb - - ipdb.set_trace() diff --git a/audioldm2/latent_diffusion/models/plms.py b/audioldm2/latent_diffusion/models/plms.py deleted file mode 100755 index 9c80796442bd653ac3dc1970c12f621068a4d821..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/models/plms.py +++ /dev/null @@ -1,360 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch -import numpy as np -from tqdm import tqdm - -from audioldm2.latent_diffusion.modules.diffusionmodules.util import ( - make_ddim_sampling_parameters, - make_ddim_timesteps, - noise_like, -) - - -class PLMSSampler(object): - def __init__(self, model, schedule="linear", **kwargs): - super().__init__() - self.model = model - self.ddpm_num_timesteps = model.num_timesteps - self.schedule = schedule - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) - setattr(self, name, attr) - - def make_schedule( - self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True - ): - if ddim_eta != 0: - ddim_eta = 0 - # raise ValueError('ddim_eta must be 0 for PLMS') - - self.ddim_timesteps = make_ddim_timesteps( - ddim_discr_method=ddim_discretize, - num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps, - verbose=verbose, - ) - alphas_cumprod = self.model.alphas_cumprod - assert ( - alphas_cumprod.shape[0] == self.ddpm_num_timesteps - ), "alphas have to be defined for each timestep" - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - - self.register_buffer("betas", to_torch(self.model.betas)) - self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) - self.register_buffer( - "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) - ) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer( - "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) - ) - self.register_buffer( - "sqrt_one_minus_alphas_cumprod", - to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), - ) - self.register_buffer( - "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) - ) - self.register_buffer( - "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) - ) - self.register_buffer( - "sqrt_recipm1_alphas_cumprod", - to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), - ) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( - alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta, - verbose=verbose, - ) - self.register_buffer("ddim_sigmas", ddim_sigmas) - self.register_buffer("ddim_alphas", ddim_alphas) - self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) - self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) - / (1 - self.alphas_cumprod) - * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) - ) - self.register_buffer( - "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps - ) - - @torch.no_grad() - 
def sample( - self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0.0, - mask=None, - x0=None, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs, - ): - if conditioning is not None: - if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print( - f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" - ) - else: - if conditioning.shape[0] != batch_size: - print( - f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" - ) - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - print(f"Data shape for PLMS sampling is {size}") - - samples, intermediates = self.plms_sampling( - conditioning, - size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, - x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) - return samples, intermediates - - @torch.no_grad() - def plms_sampling( - self, - cond, - shape, - x_T=None, - ddim_use_original_steps=False, - callback=None, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - log_every_t=100, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - ): - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - if timesteps is None: - timesteps = ( - self.ddpm_num_timesteps - if ddim_use_original_steps - else self.ddim_timesteps - ) - elif timesteps is not None and not ddim_use_original_steps: - subset_end = ( - int( - min(timesteps / self.ddim_timesteps.shape[0], 1) - * self.ddim_timesteps.shape[0] - ) - - 1 - ) - timesteps = self.ddim_timesteps[:subset_end] - - intermediates = {"x_inter": [img], "pred_x0": [img]} - time_range = ( - list(reversed(range(0, timesteps))) - if ddim_use_original_steps - else np.flip(timesteps) - ) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - print(f"Running PLMS Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps) - old_eps = [] - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - ts_next = torch.full( - (b,), - time_range[min(i + 1, len(time_range) - 1)], - device=device, - dtype=torch.long, - ) - - if mask is not None: - assert x0 is not None - img_orig = self.model.q_sample( - x0, ts - ) # TODO: deterministic forward pass? 
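When a mask and a reference latent x0 are supplied, the sampling loop keeps the known region pinned to the reference: x0 is re-noised to the current timestep with q_sample and blended with the evolving sample just below, so only the unmasked region is denoised freely. A minimal standalone sketch of that blend, assuming the standard DDPM forward process for the re-noising step (the function and argument names here are illustrative, not the model's own API):

import torch

def blend_known_region(img, x0, mask, sqrt_ac, sqrt_one_minus_ac, t):
    # Re-noise the reference latent x0 to timestep t (standard DDPM forward process).
    noise = torch.randn_like(x0)
    img_orig = (
        sqrt_ac[t].view(-1, 1, 1, 1) * x0
        + sqrt_one_minus_ac[t].view(-1, 1, 1, 1) * noise
    )
    # mask == 1 marks the region that should follow the reference;
    # the rest keeps the sample currently being denoised.
    return img_orig * mask + (1.0 - mask) * img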
- img = img_orig * mask + (1.0 - mask) * img - - outs = self.p_sample_plms( - img, - cond, - ts, - index=index, - use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, - temperature=temperature, - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, - t_next=ts_next, - ) - img, pred_x0, e_t = outs - old_eps.append(e_t) - if len(old_eps) >= 4: - old_eps.pop(0) - if callback: - callback(i) - if img_callback: - img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates["x_inter"].append(img) - intermediates["pred_x0"].append(pred_x0) - - return img, intermediates - - @torch.no_grad() - def p_sample_plms( - self, - x, - c, - t, - index, - repeat_noise=False, - use_original_steps=False, - quantize_denoised=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - old_eps=None, - t_next=None, - ): - b, *_, device = *x.shape, x.device - - def get_model_output(x, t): - if ( - unconditional_conditioning is None - or unconditional_guidance_scale == 1.0 - ): - e_t = self.model.apply_model(x, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score( - self.model, e_t, x, t, c, **corrector_kwargs - ) - - return e_t - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = ( - self.model.alphas_cumprod_prev - if use_original_steps - else self.ddim_alphas_prev - ) - sqrt_one_minus_alphas = ( - self.model.sqrt_one_minus_alphas_cumprod - if use_original_steps - else self.ddim_sqrt_one_minus_alphas - ) - sigmas = ( - self.model.ddim_sigmas_for_original_num_steps - if use_original_steps - else self.ddim_sigmas - ) - - def get_x_prev_and_pred_x0(e_t, index): - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full( - (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device - ) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - # direction pointing to x_t - dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.0: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - e_t = get_model_output(x, t) - if len(old_eps) == 0: - # Pseudo Improved Euler (2nd order) - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) - e_t_next = get_model_output(x_prev, t_next) - e_t_prime = (e_t + e_t_next) / 2 - elif len(old_eps) == 1: - # 2nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (3 * e_t - old_eps[-1]) / 2 - elif len(old_eps) == 2: - 
# 3nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 - elif len(old_eps) >= 3: - # 4nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = ( - 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3] - ) / 24 - - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) - - return x_prev, pred_x0, e_t diff --git a/audioldm2/latent_diffusion/modules/__init__.py b/audioldm2/latent_diffusion/modules/__init__.py deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/audioldm2/latent_diffusion/modules/attention.py b/audioldm2/latent_diffusion/modules/attention.py deleted file mode 100755 index 6116342da98249c681ddb5f696b48dc0f5ac69f2..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/attention.py +++ /dev/null @@ -1,467 +0,0 @@ -from inspect import isfunction -import math -import torch -import torch.nn.functional as F -from torch import nn, einsum -from einops import rearrange, repeat - -from audioldm2.latent_diffusion.modules.diffusionmodules.util import checkpoint - - -def exists(val): - return val is not None - - -def uniq(arr): - return {el: True for el in arr}.keys() - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def max_neg_value(t): - return -torch.finfo(t.dtype).max - - -def init_(tensor): - dim = tensor.shape[-1] - std = 1 / math.sqrt(dim) - tensor.uniform_(-std, std) - return tensor - - -# feedforward -class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = ( - nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) - if not glu - else GEGLU(dim, inner_dim) - ) - - self.net = nn.Sequential( - project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) - ) - - def forward(self, x): - return self.net(x) - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. 
- """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def Normalize(in_channels): - return torch.nn.GroupNorm( - num_groups=32, num_channels=in_channels, eps=1e-6, affine=True - ) - - -class LinearAttention(nn.Module): - def __init__(self, dim, heads=4, dim_head=32): - super().__init__() - self.heads = heads - hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) - self.to_out = nn.Conv2d(hidden_dim, dim, 1) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x) - q, k, v = rearrange( - qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 - ) - k = k.softmax(dim=-1) - context = torch.einsum("bhdn,bhen->bhde", k, v) - out = torch.einsum("bhde,bhdn->bhen", context, q) - out = rearrange( - out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w - ) - return self.to_out(out) - - -class SpatialSelfAttention(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = rearrange(q, "b c h w -> b (h w) c") - k = rearrange(k, "b c h w -> b c (h w)") - w_ = torch.einsum("bij,bjk->bik", q, k) - - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = rearrange(v, "b c h w -> b c (h w)") - w_ = rearrange(w_, "b i j -> b j i") - h_ = torch.einsum("bij,bjk->bik", v, w_) - h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) - h_ = self.proj_out(h_) - - return x + h_ - - -# class CrossAttention(nn.Module): -# """ -# ### Cross Attention Layer -# This falls-back to self-attention when conditional embeddings are not specified. 
-# """ - -# use_flash_attention: bool = True - -# # use_flash_attention: bool = False -# def __init__( -# self, -# query_dim, -# context_dim=None, -# heads=8, -# dim_head=64, -# dropout=0.0, -# is_inplace: bool = True, -# ): -# # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True): -# """ -# :param d_model: is the input embedding size -# :param n_heads: is the number of attention heads -# :param d_head: is the size of a attention head -# :param d_cond: is the size of the conditional embeddings -# :param is_inplace: specifies whether to perform the attention softmax computation inplace to -# save memory -# """ -# super().__init__() - -# self.is_inplace = is_inplace -# self.n_heads = heads -# self.d_head = dim_head - -# # Attention scaling factor -# self.scale = dim_head**-0.5 - -# # The normal self-attention layer -# if context_dim is None: -# context_dim = query_dim - -# # Query, key and value mappings -# d_attn = dim_head * heads -# self.to_q = nn.Linear(query_dim, d_attn, bias=False) -# self.to_k = nn.Linear(context_dim, d_attn, bias=False) -# self.to_v = nn.Linear(context_dim, d_attn, bias=False) - -# # Final linear layer -# self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout)) - -# # Setup [flash attention](https://github.com/HazyResearch/flash-attention). -# # Flash attention is only used if it's installed -# # and `CrossAttention.use_flash_attention` is set to `True`. -# try: -# # You can install flash attention by cloning their Github repo, -# # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention) -# # and then running `python setup.py install` -# from flash_attn.flash_attention import FlashAttention - -# self.flash = FlashAttention() -# # Set the scale for scaled dot-product attention. 
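The scale mentioned here is the usual 1/sqrt(d_head) factor of scaled dot-product attention, which both this disabled flash-attention path and the active CrossAttention further down compute. A generic reference sketch with illustrative shapes, not the file's own code path:

import torch

def scaled_dot_product_attention(q, k, v):
    # q, k, v: [batch, heads, tokens, d_head]
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum("bhid,bhjd->bhij", q, k) * scale  # scaled similarity scores
    attn = sim.softmax(dim=-1)                           # normalise over the key axis
    return torch.einsum("bhij,bhjd->bhid", attn, v)      # attention-weighted values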
-# self.flash.softmax_scale = self.scale -# # Set to `None` if it's not installed -# except ImportError: -# self.flash = None - -# def forward(self, x, context=None, mask=None): -# """ -# :param x: are the input embeddings of shape `[batch_size, height * width, d_model]` -# :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]` -# """ - -# # If `cond` is `None` we perform self attention -# has_cond = context is not None -# if not has_cond: -# context = x - -# # Get query, key and value vectors -# q = self.to_q(x) -# k = self.to_k(context) -# v = self.to_v(context) - -# # Use flash attention if it's available and the head size is less than or equal to `128` -# if ( -# CrossAttention.use_flash_attention -# and self.flash is not None -# and not has_cond -# and self.d_head <= 128 -# ): -# return self.flash_attention(q, k, v) -# # Otherwise, fallback to normal attention -# else: -# return self.normal_attention(q, k, v) - -# def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): -# """ -# #### Flash Attention -# :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` -# :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` -# :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` -# """ - -# # Get batch size and number of elements along sequence axis (`width * height`) -# batch_size, seq_len, _ = q.shape - -# # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of -# # shape `[batch_size, seq_len, 3, n_heads * d_head]` -# qkv = torch.stack((q, k, v), dim=2) -# # Split the heads -# qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head) - -# # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to -# # fit this size. 
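Independently of whether flash attention is installed, the active CrossAttention implementation later in this file folds the head dimension into the batch dimension with einops before computing attention. A small round-trip sketch of that reshaping (shapes are illustrative):

import torch
from einops import rearrange

heads = 8
x = torch.randn(2, 16, heads * 64)                        # [batch, tokens, heads * d_head]
folded = rearrange(x, "b n (h d) -> (b h) n d", h=heads)  # [batch * heads, tokens, d_head]
restored = rearrange(folded, "(b h) n d -> b n (h d)", h=heads)
assert torch.equal(x, restored)  # a pure reshape/permute, so the round trip is exact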
-# if self.d_head <= 32: -# pad = 32 - self.d_head -# elif self.d_head <= 64: -# pad = 64 - self.d_head -# elif self.d_head <= 128: -# pad = 128 - self.d_head -# else: -# raise ValueError(f"Head size ${self.d_head} too large for Flash Attention") - -# # Pad the heads -# if pad: -# qkv = torch.cat( -# (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1 -# ) - -# # Compute attention -# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ -# # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]` -# # TODO here I add the dtype changing -# out, _ = self.flash(qkv.type(torch.float16)) -# # Truncate the extra head size -# out = out[:, :, :, : self.d_head].float() -# # Reshape to `[batch_size, seq_len, n_heads * d_head]` -# out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head) - -# # Map to `[batch_size, height * width, d_model]` with a linear layer -# return self.to_out(out) - -# def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): -# """ -# #### Normal Attention - -# :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` -# :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` -# :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` -# """ - -# # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]` -# q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32] -# k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32] -# v = v.view(*v.shape[:2], self.n_heads, -1) - -# # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$ -# attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale - -# # Compute softmax -# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$ -# if self.is_inplace: -# half = attn.shape[0] // 2 -# attn[half:] = attn[half:].softmax(dim=-1) -# attn[:half] = attn[:half].softmax(dim=-1) -# else: -# attn = attn.softmax(dim=-1) - -# # Compute attention output -# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ -# # attn: [bs, 20, 64, 1] -# # v: [bs, 1, 20, 32] -# out = torch.einsum("bhij,bjhd->bihd", attn, v) -# # Reshape to `[batch_size, height * width, n_heads * d_head]` -# out = out.reshape(*out.shape[:2], -1) -# # Map to `[batch_size, height * width, d_model]` with a linear layer -# return self.to_out(out) - - -class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): - super().__init__() - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.scale = dim_head**-0.5 - self.heads = heads - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) - ) - - def forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - - k = self.to_k(context) - v = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) - - sim = einsum("b i d, b j d -> b i j", q, k) * self.scale - - if exists(mask): - mask = rearrange(mask, "b ... 
-> b (...)") - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, "b j -> (b h) () j", h=h) - sim.masked_fill_(~(mask == 1), max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - out = einsum("b i j, b j d -> b i d", attn, v) - out = rearrange(out, "(b h) n d -> b n (h d)", h=h) - return self.to_out(out) - - -class BasicTransformerBlock(nn.Module): - def __init__( - self, - dim, - n_heads, - d_head, - dropout=0.0, - context_dim=None, - gated_ff=True, - checkpoint=True, - ): - super().__init__() - self.attn1 = CrossAttention( - query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout - ) # is a self-attention - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention( - query_dim=dim, - context_dim=context_dim, - heads=n_heads, - dim_head=d_head, - dropout=dropout, - ) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) - self.checkpoint = checkpoint - - def forward(self, x, context=None, mask=None): - if context is None: - return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint) - else: - return checkpoint( - self._forward, (x, context, mask), self.parameters(), self.checkpoint - ) - - def _forward(self, x, context=None, mask=None): - x = self.attn1(self.norm1(x)) + x - x = self.attn2(self.norm2(x), context=context, mask=mask) + x - x = self.ff(self.norm3(x)) + x - return x - - -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. - First, project the input (aka embedding) - and reshape to b, t, d. - Then apply standard transformer action. - Finally, reshape to image - """ - - def __init__( - self, - in_channels, - n_heads, - d_head, - depth=1, - dropout=0.0, - context_dim=None, - ): - super().__init__() - - context_dim = context_dim - - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = Normalize(in_channels) - - self.proj_in = nn.Conv2d( - in_channels, inner_dim, kernel_size=1, stride=1, padding=0 - ) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim - ) - for d in range(depth) - ] - ) - - self.proj_out = zero_module( - nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - ) - - def forward(self, x, context=None, mask=None): - # note: if no context is given, cross-attention defaults to self-attention - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - x = self.proj_in(x) - x = rearrange(x, "b c h w -> b (h w) c") - for block in self.transformer_blocks: - x = block(x, context=context, mask=mask) - x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) - x = self.proj_out(x) - return x + x_in diff --git a/audioldm2/latent_diffusion/modules/audiomae/AudioMAE.py b/audioldm2/latent_diffusion/modules/audiomae/AudioMAE.py deleted file mode 100755 index f02fa05e163076641b92bbeabceb5f89edb0f18e..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/audiomae/AudioMAE.py +++ /dev/null @@ -1,149 +0,0 @@ -""" -Reference Repo: https://github.com/facebookresearch/AudioMAE -""" - -import torch -import torch.nn as nn -from timm.models.layers import to_2tuple -import audioldm2.latent_diffusion.modules.audiomae.models_vit as models_vit -import audioldm2.latent_diffusion.modules.audiomae.models_mae as models_mae - -# model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128)) - - -class 
PatchEmbed_new(nn.Module): - """Flexible Image to Patch Embedding""" - - def __init__( - self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10 - ): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - stride = to_2tuple(stride) - - self.img_size = img_size - self.patch_size = patch_size - - self.proj = nn.Conv2d( - in_chans, embed_dim, kernel_size=patch_size, stride=stride - ) # with overlapped patches - # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - - # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) - # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) - _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w - self.patch_hw = (h, w) - self.num_patches = h * w - - def get_output_shape(self, img_size): - # todo: don't be lazy.. - return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - # assert H == self.img_size[0] and W == self.img_size[1], \ - # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x) - x = x.flatten(2).transpose(1, 2) - return x - - -class AudioMAE(nn.Module): - """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)""" - - def __init__( - self, - ): - super().__init__() - model = models_vit.__dict__["vit_base_patch16"]( - num_classes=527, - drop_path_rate=0.1, - global_pool=True, - mask_2d=True, - use_custom_patch=False, - ) - - img_size = (1024, 128) - emb_dim = 768 - - model.patch_embed = PatchEmbed_new( - img_size=img_size, - patch_size=(16, 16), - in_chans=1, - embed_dim=emb_dim, - stride=16, - ) - num_patches = model.patch_embed.num_patches - # num_patches = 512 # assume audioset, 1024//16=64, 128//16=8, 512=64x8 - model.pos_embed = nn.Parameter( - torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False - ) # fixed sin-cos embedding - - # checkpoint_path = '/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth' - # checkpoint = torch.load(checkpoint_path, map_location='cpu') - # msg = model.load_state_dict(checkpoint['model'], strict=False) - # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}') - - self.model = model - - def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0): - """ - x: mel fbank [Batch, 1, T, F] - mask_t_prob: 'T masking ratio (percentage of removed patches).' - mask_f_prob: 'F masking ratio (percentage of removed patches).' - """ - return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob) - - -class Vanilla_AudioMAE(nn.Module): - """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM2)""" - - def __init__( - self, - ): - super().__init__() - model = models_mae.__dict__["mae_vit_base_patch16"]( - in_chans=1, audio_exp=True, img_size=(1024, 128) - ) - - # checkpoint_path = '/mnt/bn/lqhaoheliu/exps/checkpoints/audiomae/pretrained.pth' - # checkpoint = torch.load(checkpoint_path, map_location='cpu') - # msg = model.load_state_dict(checkpoint['model'], strict=False) - - # Skip the missing keys of decoder modules (not required) - # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}') - - self.model = model.eval() - - def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False): - """ - x: mel fbank [Batch, 1, 1024 (T), 128 (F)] - mask_ratio: 'masking ratio (percentage of removed patches).' 
- """ - with torch.no_grad(): - # embed: [B, 513, 768] for mask_ratio=0.0 - if no_mask: - if no_average: - raise RuntimeError("This function is deprecated") - embed = self.model.forward_encoder_no_random_mask_no_average( - x - ) # mask_ratio - else: - embed = self.model.forward_encoder_no_mask(x) # mask_ratio - else: - raise RuntimeError("This function is deprecated") - embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio) - return embed - - -if __name__ == "__main__": - model = Vanilla_AudioMAE().cuda() - input = torch.randn(4, 1, 1024, 128).cuda() - print("The first run") - embed = model(input, mask_ratio=0.0, no_mask=True) - print(embed) - print("The second run") - embed = model(input, mask_ratio=0.0) - print(embed) diff --git a/audioldm2/latent_diffusion/modules/audiomae/__init__.py b/audioldm2/latent_diffusion/modules/audiomae/__init__.py deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/audioldm2/latent_diffusion/modules/audiomae/models_mae.py b/audioldm2/latent_diffusion/modules/audiomae/models_mae.py deleted file mode 100755 index 7ab0076710a08a7451dd4096bd6eb2f8f6e641aa..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/audiomae/models_mae.py +++ /dev/null @@ -1,613 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -------------------------------------------------------- -# References: -# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm -# DeiT: https://github.com/facebookresearch/deit -# -------------------------------------------------------- - -from functools import partial - -import torch -import torch.nn as nn - -from timm.models.vision_transformer import Block -from audioldm2.latent_diffusion.modules.audiomae.util.pos_embed import ( - get_2d_sincos_pos_embed, - get_2d_sincos_pos_embed_flexible, -) -from audioldm2.latent_diffusion.modules.audiomae.util.patch_embed import ( - PatchEmbed_new, - PatchEmbed_org, -) - - -class MaskedAutoencoderViT(nn.Module): - """Masked Autoencoder with VisionTransformer backbone""" - - def __init__( - self, - img_size=224, - patch_size=16, - stride=10, - in_chans=3, - embed_dim=1024, - depth=24, - num_heads=16, - decoder_embed_dim=512, - decoder_depth=8, - decoder_num_heads=16, - mlp_ratio=4.0, - norm_layer=nn.LayerNorm, - norm_pix_loss=False, - audio_exp=False, - alpha=0.0, - temperature=0.2, - mode=0, - contextual_depth=8, - use_custom_patch=False, - split_pos=False, - pos_trainable=False, - use_nce=False, - beta=4.0, - decoder_mode=0, - mask_t_prob=0.6, - mask_f_prob=0.5, - mask_2d=False, - epoch=0, - no_shift=False, - ): - super().__init__() - - self.audio_exp = audio_exp - self.embed_dim = embed_dim - self.decoder_embed_dim = decoder_embed_dim - # -------------------------------------------------------------------------- - # MAE encoder specifics - if use_custom_patch: - print( - f"Use custom patch_emb with patch size: {patch_size}, stride: {stride}" - ) - self.patch_embed = PatchEmbed_new( - img_size=img_size, - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim, - stride=stride, - ) - else: - self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim) - self.use_custom_patch = use_custom_patch - num_patches = self.patch_embed.num_patches - - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 
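The encoder keeps a learnable cls token plus a positional table of length num_patches + 1 that is fixed (sin-cos) unless pos_trainable is set; in the forward pass the patch tokens receive pos_embed[:, 1:, :] and the cls token receives slot 0 before being prepended. A minimal sketch of that step with toy shapes, mirroring the logic of forward_encoder below (names are illustrative):

import torch

def prepend_cls(patch_tokens, cls_token, pos_embed):
    # patch_tokens: [B, num_patches, D]; cls_token: [1, 1, D]; pos_embed: [1, num_patches + 1, D]
    x = patch_tokens + pos_embed[:, 1:, :]   # position-encode the patch tokens
    cls = cls_token + pos_embed[:, :1, :]    # the cls token uses slot 0 of the table
    cls = cls.expand(x.shape[0], -1, -1)     # broadcast over the batch
    return torch.cat((cls, x), dim=1)        # [B, num_patches + 1, D]

B, N, D = 4, 512, 768
out = prepend_cls(torch.randn(B, N, D), torch.zeros(1, 1, D), torch.zeros(1, N + 1, D))
assert out.shape == (B, N + 1, D)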
- - # self.split_pos = split_pos # not useful - self.pos_embed = nn.Parameter( - torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable - ) # fixed sin-cos embedding - - self.encoder_depth = depth - self.contextual_depth = contextual_depth - self.blocks = nn.ModuleList( - [ - Block( - embed_dim, - num_heads, - mlp_ratio, - qkv_bias=True, - norm_layer=norm_layer, - ) # qk_scale=None - for i in range(depth) - ] - ) - self.norm = norm_layer(embed_dim) - - # -------------------------------------------------------------------------- - # MAE decoder specifics - self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) - - self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) - self.decoder_pos_embed = nn.Parameter( - torch.zeros(1, num_patches + 1, decoder_embed_dim), - requires_grad=pos_trainable, - ) # fixed sin-cos embedding - - self.no_shift = no_shift - - self.decoder_mode = decoder_mode - if ( - self.use_custom_patch - ): # overlapped patches as in AST. Similar performance yet compute heavy - window_size = (6, 6) - feat_size = (102, 12) - else: - window_size = (4, 4) - feat_size = (64, 8) - if self.decoder_mode == 1: - decoder_modules = [] - for index in range(16): - if self.no_shift: - shift_size = (0, 0) - else: - if (index % 2) == 0: - shift_size = (0, 0) - else: - shift_size = (2, 0) - # shift_size = tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]) - decoder_modules.append( - SwinTransformerBlock( - dim=decoder_embed_dim, - num_heads=16, - feat_size=feat_size, - window_size=window_size, - shift_size=shift_size, - mlp_ratio=mlp_ratio, - drop=0.0, - drop_attn=0.0, - drop_path=0.0, - extra_norm=False, - sequential_attn=False, - norm_layer=norm_layer, # nn.LayerNorm, - ) - ) - self.decoder_blocks = nn.ModuleList(decoder_modules) - else: - # Transfomer - self.decoder_blocks = nn.ModuleList( - [ - Block( - decoder_embed_dim, - decoder_num_heads, - mlp_ratio, - qkv_bias=True, - norm_layer=norm_layer, - ) # qk_scale=None, - for i in range(decoder_depth) - ] - ) - - self.decoder_norm = norm_layer(decoder_embed_dim) - self.decoder_pred = nn.Linear( - decoder_embed_dim, patch_size**2 * in_chans, bias=True - ) # decoder to patch - - # -------------------------------------------------------------------------- - - self.norm_pix_loss = norm_pix_loss - - self.patch_size = patch_size - self.stride = stride - - # audio exps - self.alpha = alpha - self.T = temperature - self.mode = mode - self.use_nce = use_nce - self.beta = beta - - self.log_softmax = nn.LogSoftmax(dim=-1) - - self.mask_t_prob = mask_t_prob - self.mask_f_prob = mask_f_prob - self.mask_2d = mask_2d - - self.epoch = epoch - - self.initialize_weights() - - def initialize_weights(self): - # initialization - # initialize (and freeze) pos_embed by sin-cos embedding - if self.audio_exp: - pos_embed = get_2d_sincos_pos_embed_flexible( - self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True - ) - else: - pos_embed = get_2d_sincos_pos_embed( - self.pos_embed.shape[-1], - int(self.patch_embed.num_patches**0.5), - cls_token=True, - ) - self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) - - if self.audio_exp: - decoder_pos_embed = get_2d_sincos_pos_embed_flexible( - self.decoder_pos_embed.shape[-1], - self.patch_embed.patch_hw, - cls_token=True, - ) - else: - decoder_pos_embed = get_2d_sincos_pos_embed( - self.decoder_pos_embed.shape[-1], - int(self.patch_embed.num_patches**0.5), - cls_token=True, - ) - self.decoder_pos_embed.data.copy_( - 
torch.from_numpy(decoder_pos_embed).float().unsqueeze(0) - ) - - # initialize patch_embed like nn.Linear (instead of nn.Conv2d) - w = self.patch_embed.proj.weight.data - torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - - # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) - torch.nn.init.normal_(self.cls_token, std=0.02) - torch.nn.init.normal_(self.mask_token, std=0.02) - - # initialize nn.Linear and nn.LayerNorm - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - # we use xavier_uniform following official JAX ViT: - torch.nn.init.xavier_uniform_(m.weight) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def patchify(self, imgs): - """ - imgs: (N, 3, H, W) - x: (N, L, patch_size**2 *3) - L = (H/p)*(W/p) - """ - p = self.patch_embed.patch_size[0] - # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 - - if self.audio_exp: - if self.use_custom_patch: # overlapped patch - h, w = self.patch_embed.patch_hw - # todo: fixed h/w patch size and stride size. Make hw custom in the future - x = imgs.unfold(2, self.patch_size, self.stride).unfold( - 3, self.patch_size, self.stride - ) # n,1,H,W -> n,1,h,w,p,p - x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1)) - # x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p)) - # x = torch.einsum('nchpwq->nhwpqc', x) - # x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1)) - else: - h = imgs.shape[2] // p - w = imgs.shape[3] // p - # h,w = self.patch_embed.patch_hw - x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p)) - x = torch.einsum("nchpwq->nhwpqc", x) - x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1)) - else: - h = w = imgs.shape[2] // p - x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) - x = torch.einsum("nchpwq->nhwpqc", x) - x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) - - return x - - def unpatchify(self, x): - """ - x: (N, L, patch_size**2 *3) - specs: (N, 1, H, W) - """ - p = self.patch_embed.patch_size[0] - h = 1024 // p - w = 128 // p - x = x.reshape(shape=(x.shape[0], h, w, p, p, 1)) - x = torch.einsum("nhwpqc->nchpwq", x) - specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p)) - return specs - - def random_masking(self, x, mask_ratio): - """ - Perform per-sample random masking by per-sample shuffling. - Per-sample shuffling is done by argsort random noise. - x: [N, L, D], sequence - """ - N, L, D = x.shape # batch, length, dim - len_keep = int(L * (1 - mask_ratio)) - - noise = torch.rand(N, L, device=x.device) # noise in [0, 1] - - # sort noise for each sample - ids_shuffle = torch.argsort( - noise, dim=1 - ) # ascend: small is keep, large is remove - ids_restore = torch.argsort(ids_shuffle, dim=1) - - # keep the first subset - ids_keep = ids_shuffle[:, :len_keep] - x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) - - # generate the binary mask: 0 is keep, 1 is remove - mask = torch.ones([N, L], device=x.device) - mask[:, :len_keep] = 0 - # unshuffle to get the binary mask - mask = torch.gather(mask, dim=1, index=ids_restore) - - return x_masked, mask, ids_restore - - def random_masking_2d(self, x, mask_t_prob, mask_f_prob): - """ - 2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob) - Perform per-sample random masking by per-sample shuffling. - Per-sample shuffling is done by argsort random noise. 
- x: [N, L, D], sequence - """ - N, L, D = x.shape # batch, length, dim - if self.use_custom_patch: # overlapped patch - T = 101 - F = 12 - else: - T = 64 - F = 8 - # x = x.reshape(N, T, F, D) - len_keep_t = int(T * (1 - mask_t_prob)) - len_keep_f = int(F * (1 - mask_f_prob)) - - # noise for mask in time - noise_t = torch.rand(N, T, device=x.device) # noise in [0, 1] - # sort noise for each sample aling time - ids_shuffle_t = torch.argsort( - noise_t, dim=1 - ) # ascend: small is keep, large is remove - ids_restore_t = torch.argsort(ids_shuffle_t, dim=1) - ids_keep_t = ids_shuffle_t[:, :len_keep_t] - # noise mask in freq - noise_f = torch.rand(N, F, device=x.device) # noise in [0, 1] - ids_shuffle_f = torch.argsort( - noise_f, dim=1 - ) # ascend: small is keep, large is remove - ids_restore_f = torch.argsort(ids_shuffle_f, dim=1) - ids_keep_f = ids_shuffle_f[:, :len_keep_f] # - - # generate the binary mask: 0 is keep, 1 is remove - # mask in freq - mask_f = torch.ones(N, F, device=x.device) - mask_f[:, :len_keep_f] = 0 - mask_f = ( - torch.gather(mask_f, dim=1, index=ids_restore_f) - .unsqueeze(1) - .repeat(1, T, 1) - ) # N,T,F - # mask in time - mask_t = torch.ones(N, T, device=x.device) - mask_t[:, :len_keep_t] = 0 - mask_t = ( - torch.gather(mask_t, dim=1, index=ids_restore_t) - .unsqueeze(1) - .repeat(1, F, 1) - .permute(0, 2, 1) - ) # N,T,F - mask = 1 - (1 - mask_t) * (1 - mask_f) # N, T, F - - # get masked x - id2res = torch.Tensor(list(range(N * T * F))).reshape(N, T, F).to(x.device) - id2res = id2res + 999 * mask # add a large value for masked elements - id2res2 = torch.argsort(id2res.flatten(start_dim=1)) - ids_keep = id2res2.flatten(start_dim=1)[:, : len_keep_f * len_keep_t] - x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) - - ids_restore = torch.argsort(id2res2.flatten(start_dim=1)) - mask = mask.flatten(start_dim=1) - - return x_masked, mask, ids_restore - - def forward_encoder(self, x, mask_ratio, mask_2d=False): - # embed patches - x = self.patch_embed(x) - # add pos embed w/o cls token - x = x + self.pos_embed[:, 1:, :] - - # masking: length -> length * mask_ratio - if mask_2d: - x, mask, ids_restore = self.random_masking_2d( - x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob - ) - else: - x, mask, ids_restore = self.random_masking(x, mask_ratio) - - # append cls token - cls_token = self.cls_token + self.pos_embed[:, :1, :] - cls_tokens = cls_token.expand(x.shape[0], -1, -1) - x = torch.cat((cls_tokens, x), dim=1) - - # apply Transformer blocks - for blk in self.blocks: - x = blk(x) - x = self.norm(x) - - return x, mask, ids_restore, None - - def forward_encoder_no_random_mask_no_average(self, x): - # embed patches - x = self.patch_embed(x) - # add pos embed w/o cls token - x = x + self.pos_embed[:, 1:, :] - - # masking: length -> length * mask_ratio - # if mask_2d: - # x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob) - # else: - # x, mask, ids_restore = self.random_masking(x, mask_ratio) - - # append cls token - cls_token = self.cls_token + self.pos_embed[:, :1, :] - cls_tokens = cls_token.expand(x.shape[0], -1, -1) - x = torch.cat((cls_tokens, x), dim=1) - - # apply Transformer blocks - for blk in self.blocks: - x = blk(x) - x = self.norm(x) - - return x - - def forward_encoder_no_mask(self, x): - # embed patches - x = self.patch_embed(x) - - # add pos embed w/o cls token - x = x + self.pos_embed[:, 1:, :] - - # masking: length -> length * mask_ratio - # x, mask, 
ids_restore = self.random_masking(x, mask_ratio) - # append cls token - cls_token = self.cls_token + self.pos_embed[:, :1, :] - cls_tokens = cls_token.expand(x.shape[0], -1, -1) - x = torch.cat((cls_tokens, x), dim=1) - - # apply Transformer blocks - contextual_embs = [] - for n, blk in enumerate(self.blocks): - x = blk(x) - if n > self.contextual_depth: - contextual_embs.append(self.norm(x)) - # x = self.norm(x) - contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0) - - return contextual_emb - - def forward_decoder(self, x, ids_restore): - # embed tokens - x = self.decoder_embed(x) - - # append mask tokens to sequence - mask_tokens = self.mask_token.repeat( - x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1 - ) - x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token - x_ = torch.gather( - x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]) - ) # unshuffle - x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token - - # add pos embed - x = x + self.decoder_pos_embed - - if self.decoder_mode != 0: - B, L, D = x.shape - x = x[:, 1:, :] - if self.use_custom_patch: - x = x.reshape(B, 101, 12, D) - x = torch.cat([x, x[:, -1, :].unsqueeze(1)], dim=1) # hack - x = x.reshape(B, 1224, D) - if self.decoder_mode > 3: # mvit - x = self.decoder_blocks(x) - else: - # apply Transformer blocks - for blk in self.decoder_blocks: - x = blk(x) - x = self.decoder_norm(x) - - # predictor projection - pred = self.decoder_pred(x) - - # remove cls token - if self.decoder_mode != 0: - if self.use_custom_patch: - pred = pred.reshape(B, 102, 12, 256) - pred = pred[:, :101, :, :] - pred = pred.reshape(B, 1212, 256) - else: - pred = pred - else: - pred = pred[:, 1:, :] - return pred, None, None # emb, emb_pixel - - def forward_loss(self, imgs, pred, mask, norm_pix_loss=False): - """ - imgs: [N, 3, H, W] - pred: [N, L, p*p*3] - mask: [N, L], 0 is keep, 1 is remove, - """ - target = self.patchify(imgs) - if norm_pix_loss: - mean = target.mean(dim=-1, keepdim=True) - var = target.var(dim=-1, keepdim=True) - target = (target - mean) / (var + 1.0e-6) ** 0.5 - - loss = (pred - target) ** 2 - loss = loss.mean(dim=-1) # [N, L], mean loss per patch - - loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches - return loss - - def forward(self, imgs, mask_ratio=0.8): - emb_enc, mask, ids_restore, _ = self.forward_encoder( - imgs, mask_ratio, mask_2d=self.mask_2d - ) - pred, _, _ = self.forward_decoder(emb_enc, ids_restore) # [N, L, p*p*3] - loss_recon = self.forward_loss( - imgs, pred, mask, norm_pix_loss=self.norm_pix_loss - ) - loss_contrastive = torch.FloatTensor([0.0]).cuda() - return loss_recon, pred, mask, loss_contrastive - - -def mae_vit_small_patch16_dec512d8b(**kwargs): - model = MaskedAutoencoderViT( - patch_size=16, - embed_dim=384, - depth=12, - num_heads=6, - decoder_embed_dim=512, - decoder_num_heads=16, - mlp_ratio=4, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - **kwargs, - ) - return model - - -def mae_vit_base_patch16_dec512d8b(**kwargs): - model = MaskedAutoencoderViT( - patch_size=16, - embed_dim=768, - depth=12, - num_heads=12, - decoder_embed_dim=512, - decoder_num_heads=16, - mlp_ratio=4, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - **kwargs, - ) - return model - - -def mae_vit_large_patch16_dec512d8b(**kwargs): - model = MaskedAutoencoderViT( - patch_size=16, - embed_dim=1024, - depth=24, - num_heads=16, - decoder_embed_dim=512, - decoder_num_heads=16, - mlp_ratio=4, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - **kwargs, - ) - return 
model - - -def mae_vit_huge_patch14_dec512d8b(**kwargs): - model = MaskedAutoencoderViT( - patch_size=14, - embed_dim=1280, - depth=32, - num_heads=16, - decoder_embed_dim=512, - decoder_num_heads=16, - mlp_ratio=4, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - **kwargs, - ) - return model - - -# set recommended archs -mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks -mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks -mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks -mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b # decoder: 512 dim, 8 blocks diff --git a/audioldm2/latent_diffusion/modules/audiomae/models_vit.py b/audioldm2/latent_diffusion/modules/audiomae/models_vit.py deleted file mode 100755 index cb37adbc16cfb9a232493c473c9400f199655b6c..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/audiomae/models_vit.py +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -------------------------------------------------------- -# References: -# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm -# DeiT: https://github.com/facebookresearch/deit -# -------------------------------------------------------- - -from functools import partial - -import torch -import torch.nn as nn -import timm.models.vision_transformer - - -class VisionTransformer(timm.models.vision_transformer.VisionTransformer): - """Vision Transformer with support for global average pooling""" - - def __init__( - self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs - ): - super(VisionTransformer, self).__init__(**kwargs) - - self.global_pool = global_pool - if self.global_pool: - norm_layer = kwargs["norm_layer"] - embed_dim = kwargs["embed_dim"] - self.fc_norm = norm_layer(embed_dim) - del self.norm # remove the original norm - self.mask_2d = mask_2d - self.use_custom_patch = use_custom_patch - - def forward_features(self, x): - B = x.shape[0] - x = self.patch_embed(x) - x = x + self.pos_embed[:, 1:, :] - cls_token = self.cls_token + self.pos_embed[:, :1, :] - cls_tokens = cls_token.expand( - B, -1, -1 - ) # stole cls_tokens impl from Phil Wang, thanks - x = torch.cat((cls_tokens, x), dim=1) - x = self.pos_drop(x) - - for blk in self.blocks: - x = blk(x) - - if self.global_pool: - x = x[:, 1:, :].mean(dim=1) # global pool without cls token - outcome = self.fc_norm(x) - else: - x = self.norm(x) - outcome = x[:, 0] - - return outcome - - def random_masking(self, x, mask_ratio): - """ - Perform per-sample random masking by per-sample shuffling. - Per-sample shuffling is done by argsort random noise. 
- x: [N, L, D], sequence - """ - N, L, D = x.shape # batch, length, dim - len_keep = int(L * (1 - mask_ratio)) - - noise = torch.rand(N, L, device=x.device) # noise in [0, 1] - - # sort noise for each sample - ids_shuffle = torch.argsort( - noise, dim=1 - ) # ascend: small is keep, large is remove - ids_restore = torch.argsort(ids_shuffle, dim=1) - - # keep the first subset - ids_keep = ids_shuffle[:, :len_keep] - x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) - - # generate the binary mask: 0 is keep, 1 is remove - mask = torch.ones([N, L], device=x.device) - mask[:, :len_keep] = 0 - # unshuffle to get the binary mask - mask = torch.gather(mask, dim=1, index=ids_restore) - - return x_masked, mask, ids_restore - - def random_masking_2d(self, x, mask_t_prob, mask_f_prob): - """ - 2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob) - Perform per-sample random masking by per-sample shuffling. - Per-sample shuffling is done by argsort random noise. - x: [N, L, D], sequence - """ - - N, L, D = x.shape # batch, length, dim - if self.use_custom_patch: - # # for AS - T = 101 # 64,101 - F = 12 # 8,12 - # # for ESC - # T=50 - # F=12 - # for SPC - # T=12 - # F=12 - else: - # ## for AS - T = 64 - F = 8 - # ## for ESC - # T=32 - # F=8 - ## for SPC - # T=8 - # F=8 - - # mask T - x = x.reshape(N, T, F, D) - len_keep_T = int(T * (1 - mask_t_prob)) - noise = torch.rand(N, T, device=x.device) # noise in [0, 1] - # sort noise for each sample - ids_shuffle = torch.argsort( - noise, dim=1 - ) # ascend: small is keep, large is remove - ids_keep = ids_shuffle[:, :len_keep_T] - index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D) - # x_masked = torch.gather(x, dim=1, index=index) - # x_masked = x_masked.reshape(N,len_keep_T*F,D) - x = torch.gather(x, dim=1, index=index) # N, len_keep_T(T'), F, D - - # mask F - # x = x.reshape(N, T, F, D) - x = x.permute(0, 2, 1, 3) # N T' F D => N F T' D - len_keep_F = int(F * (1 - mask_f_prob)) - noise = torch.rand(N, F, device=x.device) # noise in [0, 1] - # sort noise for each sample - ids_shuffle = torch.argsort( - noise, dim=1 - ) # ascend: small is keep, large is remove - ids_keep = ids_shuffle[:, :len_keep_F] - # index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, T, D) - index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D) - x_masked = torch.gather(x, dim=1, index=index) - x_masked = x_masked.permute(0, 2, 1, 3) # N F' T' D => N T' F' D - # x_masked = x_masked.reshape(N,len_keep*T,D) - x_masked = x_masked.reshape(N, len_keep_F * len_keep_T, D) - - return x_masked, None, None - - def forward_features_mask(self, x, mask_t_prob, mask_f_prob): - B = x.shape[0] # 4,1,1024,128 - x = self.patch_embed(x) # 4, 512, 768 - - x = x + self.pos_embed[:, 1:, :] - if self.random_masking_2d: - x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob) - else: - x, mask, ids_restore = self.random_masking(x, mask_t_prob) - cls_token = self.cls_token + self.pos_embed[:, :1, :] - cls_tokens = cls_token.expand(B, -1, -1) - x = torch.cat((cls_tokens, x), dim=1) - x = self.pos_drop(x) - - # apply Transformer blocks - for blk in self.blocks: - x = blk(x) - - if self.global_pool: - x = x[:, 1:, :].mean(dim=1) # global pool without cls token - outcome = self.fc_norm(x) - else: - x = self.norm(x) - outcome = x[:, 0] - - return outcome - - # overwrite original timm - def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0): - if mask_t_prob > 0.0 or mask_f_prob > 0.0: - x = 
self.forward_features_mask( - x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob - ) - else: - x = self.forward_features(x) - x = self.head(x) - return x - - -def vit_small_patch16(**kwargs): - model = VisionTransformer( - patch_size=16, - embed_dim=384, - depth=12, - num_heads=6, - mlp_ratio=4, - qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - **kwargs - ) - return model - - -def vit_base_patch16(**kwargs): - model = VisionTransformer( - patch_size=16, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4, - qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - **kwargs - ) - return model - - -def vit_large_patch16(**kwargs): - model = VisionTransformer( - patch_size=16, - embed_dim=1024, - depth=24, - num_heads=16, - mlp_ratio=4, - qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - **kwargs - ) - return model - - -def vit_huge_patch14(**kwargs): - model = VisionTransformer( - patch_size=14, - embed_dim=1280, - depth=32, - num_heads=16, - mlp_ratio=4, - qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - **kwargs - ) - return model diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/crop.py b/audioldm2/latent_diffusion/modules/audiomae/util/crop.py deleted file mode 100755 index 525e3c783c3d348e593dc89c2b5fb8520918e9ea..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/audiomae/util/crop.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import math - -import torch - -from torchvision import transforms -from torchvision.transforms import functional as F - - -class RandomResizedCrop(transforms.RandomResizedCrop): - """ - RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. - This may lead to results different with torchvision's version. - Following BYOL's TF code: - https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 - """ - - @staticmethod - def get_params(img, scale, ratio): - width, height = F._get_image_size(img) - area = height * width - - target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() - log_ratio = torch.log(torch.tensor(ratio)) - aspect_ratio = torch.exp( - torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) - ).item() - - w = int(round(math.sqrt(target_area * aspect_ratio))) - h = int(round(math.sqrt(target_area / aspect_ratio))) - - w = min(w, width) - h = min(h, height) - - i = torch.randint(0, height - h + 1, size=(1,)).item() - j = torch.randint(0, width - w + 1, size=(1,)).item() - - return i, j, h, w diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/datasets.py b/audioldm2/latent_diffusion/modules/audiomae/util/datasets.py deleted file mode 100755 index b90f89a7d5f78c31bc9113dd88b632b0c234f10a..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/audiomae/util/datasets.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
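The random_masking helpers in the MAE and ViT modules above all rely on the same trick: draw per-sample uniform noise, argsort it to obtain a random permutation, keep the first L * (1 - mask_ratio) indices, and argsort again to recover the restore order. A condensed sketch of that mechanism, mirroring the logic of the deleted code rather than importing it:

import torch

def random_masking(x, mask_ratio):
    # x: [N, L, D] token sequence; returns kept tokens, binary mask, and restore indices.
    N, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))
    noise = torch.rand(N, L, device=x.device)        # per-sample noise in [0, 1]
    ids_shuffle = torch.argsort(noise, dim=1)        # random permutation per sample
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation
    ids_keep = ids_shuffle[:, :len_keep]             # indices of the kept tokens
    x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))
    mask = torch.ones(N, L, device=x.device)         # 1 = removed, 0 = kept
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)        # back to the original token order
    return x_masked, mask, ids_restore

x_masked, mask, ids_restore = random_masking(torch.randn(2, 512, 768), mask_ratio=0.75)
assert x_masked.shape == (2, 128, 768) and mask.sum() == 2 * 384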
-# -------------------------------------------------------- -# References: -# DeiT: https://github.com/facebookresearch/deit -# -------------------------------------------------------- - -import os -import PIL - -from torchvision import datasets, transforms - -from timm.data import create_transform -from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD - - -def build_dataset(is_train, args): - transform = build_transform(is_train, args) - - root = os.path.join(args.data_path, "train" if is_train else "val") - dataset = datasets.ImageFolder(root, transform=transform) - - print(dataset) - - return dataset - - -def build_transform(is_train, args): - mean = IMAGENET_DEFAULT_MEAN - std = IMAGENET_DEFAULT_STD - # train transform - if is_train: - # this should always dispatch to transforms_imagenet_train - transform = create_transform( - input_size=args.input_size, - is_training=True, - color_jitter=args.color_jitter, - auto_augment=args.aa, - interpolation="bicubic", - re_prob=args.reprob, - re_mode=args.remode, - re_count=args.recount, - mean=mean, - std=std, - ) - return transform - - # eval transform - t = [] - if args.input_size <= 224: - crop_pct = 224 / 256 - else: - crop_pct = 1.0 - size = int(args.input_size / crop_pct) - t.append( - transforms.Resize( - size, interpolation=PIL.Image.BICUBIC - ), # to maintain same ratio w.r.t. 224 images - ) - t.append(transforms.CenterCrop(args.input_size)) - - t.append(transforms.ToTensor()) - t.append(transforms.Normalize(mean, std)) - return transforms.Compose(t) diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/lars.py b/audioldm2/latent_diffusion/modules/audiomae/util/lars.py deleted file mode 100755 index fc43923d22cf2c9af4ae9166612c3f3477faf254..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/audiomae/util/lars.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -------------------------------------------------------- -# LARS optimizer, implementation from MoCo v3: -# https://github.com/facebookresearch/moco-v3 -# -------------------------------------------------------- - -import torch - - -class LARS(torch.optim.Optimizer): - """ - LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
- """ - - def __init__( - self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001 - ): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - trust_coefficient=trust_coefficient, - ) - super().__init__(params, defaults) - - @torch.no_grad() - def step(self): - for g in self.param_groups: - for p in g["params"]: - dp = p.grad - - if dp is None: - continue - - if p.ndim > 1: # if not normalization gamma/beta or bias - dp = dp.add(p, alpha=g["weight_decay"]) - param_norm = torch.norm(p) - update_norm = torch.norm(dp) - one = torch.ones_like(param_norm) - q = torch.where( - param_norm > 0.0, - torch.where( - update_norm > 0, - (g["trust_coefficient"] * param_norm / update_norm), - one, - ), - one, - ) - dp = dp.mul(q) - - param_state = self.state[p] - if "mu" not in param_state: - param_state["mu"] = torch.zeros_like(p) - mu = param_state["mu"] - mu.mul_(g["momentum"]).add_(dp) - p.add_(mu, alpha=-g["lr"]) diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/lr_decay.py b/audioldm2/latent_diffusion/modules/audiomae/util/lr_decay.py deleted file mode 100755 index e90ed69d7b8d019dbf5d90571541668e2bd8efe8..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/audiomae/util/lr_decay.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -------------------------------------------------------- -# References: -# ELECTRA https://github.com/google-research/electra -# BEiT: https://github.com/microsoft/unilm/tree/master/beit -# -------------------------------------------------------- - - -def param_groups_lrd( - model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75 -): - """ - Parameter groups for layer-wise lr decay - Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 - """ - param_group_names = {} - param_groups = {} - - num_layers = len(model.blocks) + 1 - - layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - - # no decay: all 1D parameters and model specific ones - if p.ndim == 1 or n in no_weight_decay_list: - g_decay = "no_decay" - this_decay = 0.0 - else: - g_decay = "decay" - this_decay = weight_decay - - layer_id = get_layer_id_for_vit(n, num_layers) - group_name = "layer_%d_%s" % (layer_id, g_decay) - - if group_name not in param_group_names: - this_scale = layer_scales[layer_id] - - param_group_names[group_name] = { - "lr_scale": this_scale, - "weight_decay": this_decay, - "params": [], - } - param_groups[group_name] = { - "lr_scale": this_scale, - "weight_decay": this_decay, - "params": [], - } - - param_group_names[group_name]["params"].append(n) - param_groups[group_name]["params"].append(p) - - # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) - - return list(param_groups.values()) - - -def get_layer_id_for_vit(name, num_layers): - """ - Assign a parameter with its layer id - Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 - """ - if name in ["cls_token", "pos_embed"]: - return 0 - elif name.startswith("patch_embed"): - return 0 - elif name.startswith("blocks"): - return int(name.split(".")[1]) + 1 - else: - return num_layers diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/lr_sched.py 
b/audioldm2/latent_diffusion/modules/audiomae/util/lr_sched.py deleted file mode 100755 index efe184d8e3fb63ec6b4f83375b6ea719985900de..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/audiomae/util/lr_sched.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import math - - -def adjust_learning_rate(optimizer, epoch, args): - """Decay the learning rate with half-cycle cosine after warmup""" - if epoch < args.warmup_epochs: - lr = args.lr * epoch / args.warmup_epochs - else: - lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * ( - 1.0 - + math.cos( - math.pi - * (epoch - args.warmup_epochs) - / (args.epochs - args.warmup_epochs) - ) - ) - for param_group in optimizer.param_groups: - if "lr_scale" in param_group: - param_group["lr"] = lr * param_group["lr_scale"] - else: - param_group["lr"] = lr - return lr diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/misc.py b/audioldm2/latent_diffusion/modules/audiomae/util/misc.py deleted file mode 100755 index 74184e09e23e0e174350b894b0cff29600c18b71..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/audiomae/util/misc.py +++ /dev/null @@ -1,453 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -------------------------------------------------------- -# References: -# DeiT: https://github.com/facebookresearch/deit -# BEiT: https://github.com/microsoft/unilm/tree/master/beit -# -------------------------------------------------------- - -import builtins -import datetime -import os -import time -from collections import defaultdict, deque -from pathlib import Path - -import torch -import torch.distributed as dist -from torch._six import inf - - -class SmoothedValue(object): - """Track a series of values and provide access to smoothed values over a - window or the global series average. - """ - - def __init__(self, window_size=20, fmt=None): - if fmt is None: - fmt = "{median:.4f} ({global_avg:.4f})" - self.deque = deque(maxlen=window_size) - self.total = 0.0 - self.count = 0 - self.fmt = fmt - - def update(self, value, n=1): - self.deque.append(value) - self.count += n - self.total += value * n - - def synchronize_between_processes(self): - """ - Warning: does not synchronize the deque! 
- """ - if not is_dist_avail_and_initialized(): - return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") - dist.barrier() - dist.all_reduce(t) - t = t.tolist() - self.count = int(t[0]) - self.total = t[1] - - @property - def median(self): - d = torch.tensor(list(self.deque)) - return d.median().item() - - @property - def avg(self): - d = torch.tensor(list(self.deque), dtype=torch.float32) - return d.mean().item() - - @property - def global_avg(self): - return self.total / self.count - - @property - def max(self): - return max(self.deque) - - @property - def value(self): - return self.deque[-1] - - def __str__(self): - return self.fmt.format( - median=self.median, - avg=self.avg, - global_avg=self.global_avg, - max=self.max, - value=self.value, - ) - - -class MetricLogger(object): - def __init__(self, delimiter="\t"): - self.meters = defaultdict(SmoothedValue) - self.delimiter = delimiter - - def update(self, **kwargs): - for k, v in kwargs.items(): - if v is None: - continue - if isinstance(v, torch.Tensor): - v = v.item() - assert isinstance(v, (float, int)) - self.meters[k].update(v) - - def __getattr__(self, attr): - if attr in self.meters: - return self.meters[attr] - if attr in self.__dict__: - return self.__dict__[attr] - raise AttributeError( - "'{}' object has no attribute '{}'".format(type(self).__name__, attr) - ) - - def __str__(self): - loss_str = [] - for name, meter in self.meters.items(): - loss_str.append("{}: {}".format(name, str(meter))) - return self.delimiter.join(loss_str) - - def synchronize_between_processes(self): - for meter in self.meters.values(): - meter.synchronize_between_processes() - - def add_meter(self, name, meter): - self.meters[name] = meter - - def log_every(self, iterable, print_freq, header=None): - i = 0 - if not header: - header = "" - start_time = time.time() - end = time.time() - iter_time = SmoothedValue(fmt="{avg:.4f}") - data_time = SmoothedValue(fmt="{avg:.4f}") - space_fmt = ":" + str(len(str(len(iterable)))) + "d" - log_msg = [ - header, - "[{0" + space_fmt + "}/{1}]", - "eta: {eta}", - "{meters}", - "time: {time}", - "data: {data}", - ] - if torch.cuda.is_available(): - log_msg.append("max mem: {memory:.0f}") - log_msg = self.delimiter.join(log_msg) - MB = 1024.0 * 1024.0 - for obj in iterable: - data_time.update(time.time() - end) - yield obj - iter_time.update(time.time() - end) - if i % print_freq == 0 or i == len(iterable) - 1: - eta_seconds = iter_time.global_avg * (len(iterable) - i) - eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) - if torch.cuda.is_available(): - print( - log_msg.format( - i, - len(iterable), - eta=eta_string, - meters=str(self), - time=str(iter_time), - data=str(data_time), - memory=torch.cuda.max_memory_allocated() / MB, - ) - ) - else: - print( - log_msg.format( - i, - len(iterable), - eta=eta_string, - meters=str(self), - time=str(iter_time), - data=str(data_time), - ) - ) - i += 1 - end = time.time() - total_time = time.time() - start_time - total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print( - "{} Total time: {} ({:.4f} s / it)".format( - header, total_time_str, total_time / len(iterable) - ) - ) - - -def setup_for_distributed(is_master): - """ - This function disables printing when not in master process - """ - builtin_print = builtins.print - - def print(*args, **kwargs): - force = kwargs.pop("force", False) - force = force or (get_world_size() > 8) - if is_master or force: - now = datetime.datetime.now().time() - builtin_print("[{}] 
".format(now), end="") # print with time stamp - builtin_print(*args, **kwargs) - - builtins.print = print - - -def is_dist_avail_and_initialized(): - if not dist.is_available(): - return False - if not dist.is_initialized(): - return False - return True - - -def get_world_size(): - if not is_dist_avail_and_initialized(): - return 1 - return dist.get_world_size() - - -def get_rank(): - if not is_dist_avail_and_initialized(): - return 0 - return dist.get_rank() - - -def is_main_process(): - return get_rank() == 0 - - -def save_on_master(*args, **kwargs): - if is_main_process(): - torch.save(*args, **kwargs) - - -def init_distributed_mode(args): - if args.dist_on_itp: - args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) - args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) - args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) - args.dist_url = "tcp://%s:%s" % ( - os.environ["MASTER_ADDR"], - os.environ["MASTER_PORT"], - ) - os.environ["LOCAL_RANK"] = str(args.gpu) - os.environ["RANK"] = str(args.rank) - os.environ["WORLD_SIZE"] = str(args.world_size) - # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] - elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: - args.rank = int(os.environ["RANK"]) - args.world_size = int(os.environ["WORLD_SIZE"]) - args.gpu = int(os.environ["LOCAL_RANK"]) - elif "SLURM_PROCID" in os.environ: - args.rank = int(os.environ["SLURM_PROCID"]) - args.gpu = args.rank % torch.cuda.device_count() - else: - print("Not using distributed mode") - setup_for_distributed(is_master=True) # hack - args.distributed = False - return - - args.distributed = True - - torch.cuda.set_device(args.gpu) - args.dist_backend = "nccl" - print( - "| distributed init (rank {}): {}, gpu {}".format( - args.rank, args.dist_url, args.gpu - ), - flush=True, - ) - torch.distributed.init_process_group( - backend=args.dist_backend, - init_method=args.dist_url, - world_size=args.world_size, - rank=args.rank, - ) - torch.distributed.barrier() - setup_for_distributed(args.rank == 0) - - -class NativeScalerWithGradNormCount: - state_dict_key = "amp_scaler" - - def __init__(self): - self._scaler = torch.cuda.amp.GradScaler() - - def __call__( - self, - loss, - optimizer, - clip_grad=None, - parameters=None, - create_graph=False, - update_grad=True, - ): - self._scaler.scale(loss).backward(create_graph=create_graph) - if update_grad: - if clip_grad is not None: - assert parameters is not None - self._scaler.unscale_( - optimizer - ) # unscale the gradients of optimizer's assigned params in-place - norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) - else: - self._scaler.unscale_(optimizer) - norm = get_grad_norm_(parameters) - self._scaler.step(optimizer) - self._scaler.update() - else: - norm = None - return norm - - def state_dict(self): - return self._scaler.state_dict() - - def load_state_dict(self, state_dict): - self._scaler.load_state_dict(state_dict) - - -def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = [p for p in parameters if p.grad is not None] - norm_type = float(norm_type) - if len(parameters) == 0: - return torch.tensor(0.0) - device = parameters[0].grad.device - if norm_type == inf: - total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) - else: - total_norm = torch.norm( - torch.stack( - [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] - ), - norm_type, - ) - return total_norm - - -def 
save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): - output_dir = Path(args.output_dir) - epoch_name = str(epoch) - if loss_scaler is not None: - checkpoint_paths = [output_dir / ("checkpoint-%s.pth" % epoch_name)] - for checkpoint_path in checkpoint_paths: - to_save = { - "model": model_without_ddp.state_dict(), - "optimizer": optimizer.state_dict(), - "epoch": epoch, - "scaler": loss_scaler.state_dict(), - "args": args, - } - - save_on_master(to_save, checkpoint_path) - else: - client_state = {"epoch": epoch} - model.save_checkpoint( - save_dir=args.output_dir, - tag="checkpoint-%s" % epoch_name, - client_state=client_state, - ) - - -def load_model(args, model_without_ddp, optimizer, loss_scaler): - if args.resume: - if args.resume.startswith("https"): - checkpoint = torch.hub.load_state_dict_from_url( - args.resume, map_location="cpu", check_hash=True - ) - else: - checkpoint = torch.load(args.resume, map_location="cpu") - model_without_ddp.load_state_dict(checkpoint["model"]) - print("Resume checkpoint %s" % args.resume) - if ( - "optimizer" in checkpoint - and "epoch" in checkpoint - and not (hasattr(args, "eval") and args.eval) - ): - optimizer.load_state_dict(checkpoint["optimizer"]) - args.start_epoch = checkpoint["epoch"] + 1 - if "scaler" in checkpoint: - loss_scaler.load_state_dict(checkpoint["scaler"]) - print("With optim & sched!") - - -def all_reduce_mean(x): - world_size = get_world_size() - if world_size > 1: - x_reduce = torch.tensor(x).cuda() - dist.all_reduce(x_reduce) - x_reduce /= world_size - return x_reduce.item() - else: - return x - - -# utils -@torch.no_grad() -def concat_all_gather(tensor): - """ - Performs all_gather operation on the provided tensors. - *** Warning ***: torch.distributed.all_gather has no gradient. 
- """ - tensors_gather = [ - torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) - ] - torch.distributed.all_gather(tensors_gather, tensor, async_op=False) - - output = torch.cat(tensors_gather, dim=0) - return output - - -def merge_vmae_to_avmae(avmae_state_dict, vmae_ckpt): - # keys_to_copy=['pos_embed','patch_embed'] - # replaced=0 - - vmae_ckpt["cls_token"] = vmae_ckpt["cls_token_v"] - vmae_ckpt["mask_token"] = vmae_ckpt["mask_token_v"] - - # pos_emb % not trainable, use default - pos_embed_v = vmae_ckpt["pos_embed_v"] # 1,589,768 - pos_embed = pos_embed_v[:, 1:, :] # 1,588,768 - cls_embed = pos_embed_v[:, 0, :].unsqueeze(1) - pos_embed = pos_embed.reshape(1, 2, 14, 14, 768).sum(dim=1) # 1, 14, 14, 768 - print("Position interpolate from 14,14 to 64,8") - pos_embed = pos_embed.permute(0, 3, 1, 2) # 1, 14,14,768 -> 1,768,14,14 - pos_embed = torch.nn.functional.interpolate( - pos_embed, size=(64, 8), mode="bicubic", align_corners=False - ) - pos_embed = pos_embed.permute(0, 2, 3, 1).flatten( - 1, 2 - ) # 1, 14, 14, 768 => 1, 196,768 - pos_embed = torch.cat((cls_embed, pos_embed), dim=1) - assert vmae_ckpt["pos_embed"].shape == pos_embed.shape - vmae_ckpt["pos_embed"] = pos_embed - # patch_emb - # aggregate 3 channels in video-rgb ckpt to 1 channel for audio - v_weight = vmae_ckpt["patch_embed_v.proj.weight"] # 768,3,2,16,16 - new_proj_weight = torch.nn.Parameter(v_weight.sum(dim=2).sum(dim=1).unsqueeze(1)) - assert new_proj_weight.shape == vmae_ckpt["patch_embed.proj.weight"].shape - vmae_ckpt["patch_embed.proj.weight"] = new_proj_weight - vmae_ckpt["patch_embed.proj.bias"] = vmae_ckpt["patch_embed_v.proj.bias"] - - # hack - vmae_ckpt["norm.weight"] = vmae_ckpt["norm_v.weight"] - vmae_ckpt["norm.bias"] = vmae_ckpt["norm_v.bias"] - - # replace transformer encoder - for k, v in vmae_ckpt.items(): - if k.startswith("blocks."): - kk = k.replace("blocks.", "blocks_v.") - vmae_ckpt[k] = vmae_ckpt[kk] - elif k.startswith("blocks_v."): - pass - else: - print(k) - print(k) diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/patch_embed.py b/audioldm2/latent_diffusion/modules/audiomae/util/patch_embed.py deleted file mode 100755 index ac1e4d436c6f79aef9bf1de32cdac5d4f037c775..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/audiomae/util/patch_embed.py +++ /dev/null @@ -1,127 +0,0 @@ -import torch -import torch.nn as nn -from timm.models.layers import to_2tuple - - -class PatchEmbed_org(nn.Module): - """Image to Patch Embedding""" - - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) - self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) - self.img_size = img_size - self.patch_size = patch_size - self.num_patches = num_patches - - self.proj = nn.Conv2d( - in_chans, embed_dim, kernel_size=patch_size, stride=patch_size - ) - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - # assert H == self.img_size[0] and W == self.img_size[1], \ - # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
- x = self.proj(x) - y = x.flatten(2).transpose(1, 2) - return y - - -class PatchEmbed_new(nn.Module): - """Flexible Image to Patch Embedding""" - - def __init__( - self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10 - ): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - stride = to_2tuple(stride) - - self.img_size = img_size - self.patch_size = patch_size - - self.proj = nn.Conv2d( - in_chans, embed_dim, kernel_size=patch_size, stride=stride - ) # with overlapped patches - # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - - # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) - # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) - _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w - self.patch_hw = (h, w) - self.num_patches = h * w - - def get_output_shape(self, img_size): - # todo: don't be lazy.. - return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - # assert H == self.img_size[0] and W == self.img_size[1], \ - # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - # x = self.proj(x).flatten(2).transpose(1, 2) - x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12 - x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212 - x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768 - return x - - -class PatchEmbed3D_new(nn.Module): - """Flexible Image to Patch Embedding""" - - def __init__( - self, - video_size=(16, 224, 224), - patch_size=(2, 16, 16), - in_chans=3, - embed_dim=768, - stride=(2, 16, 16), - ): - super().__init__() - - self.video_size = video_size - self.patch_size = patch_size - self.in_chans = in_chans - - self.proj = nn.Conv3d( - in_chans, embed_dim, kernel_size=patch_size, stride=stride - ) - _, _, t, h, w = self.get_output_shape(video_size) # n, emb_dim, h, w - self.patch_thw = (t, h, w) - self.num_patches = t * h * w - - def get_output_shape(self, video_size): - # todo: don't be lazy.. - return self.proj( - torch.randn(1, self.in_chans, video_size[0], video_size[1], video_size[2]) - ).shape - - def forward(self, x): - B, C, T, H, W = x.shape - x = self.proj(x) # 32, 3, 16, 224, 224 -> 32, 768, 8, 14, 14 - x = x.flatten(2) # 32, 768, 1568 - x = x.transpose(1, 2) # 32, 768, 1568 -> 32, 1568, 768 - return x - - -if __name__ == "__main__": - # patch_emb = PatchEmbed_new(img_size=224, patch_size=16, in_chans=1, embed_dim=64, stride=(16,16)) - # input = torch.rand(8,1,1024,128) - # output = patch_emb(input) - # print(output.shape) # (8,512,64) - - patch_emb = PatchEmbed3D_new( - video_size=(6, 224, 224), - patch_size=(2, 16, 16), - in_chans=3, - embed_dim=768, - stride=(2, 16, 16), - ) - input = torch.rand(8, 3, 6, 224, 224) - output = patch_emb(input) - print(output.shape) # (8,64) diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/pos_embed.py b/audioldm2/latent_diffusion/modules/audiomae/util/pos_embed.py deleted file mode 100755 index 2d9177ed98dffcf35264f38aff94e7f00fb50abf..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/audiomae/util/pos_embed.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
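For orientation, the PatchEmbed modules deleted above reduce the spectrogram-to-token bookkeeping to a single strided convolution; a minimal sketch under the shapes suggested by the comments in these files (a 1 x 1024 x 128 mel input and 16 x 16 patches, giving the 64 x 8 = 512 tokens referenced elsewhere) could look like this, with all names and sizes illustrative:

import torch
import torch.nn as nn

patch_proj = nn.Conv2d(1, 768, kernel_size=16, stride=16)   # plays the role of PatchEmbed_org.proj
mel = torch.randn(4, 1, 1024, 128)                           # batch of mono mel spectrograms
feat = patch_proj(mel)                                       # [4, 768, 64, 8]: 64 time x 8 freq patches
tokens = feat.flatten(2).transpose(1, 2)                     # [4, 512, 768] token sequence
print(feat.shape, tokens.shape)

With the overlapping variant (a stride smaller than the patch size, as in PatchEmbed_new) the token count grows, which is why that class derives num_patches from a dummy forward pass instead of from the image and patch sizes alone.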
-# -------------------------------------------------------- -# Position embedding utils -# -------------------------------------------------------- - -import numpy as np - -import torch - - -# -------------------------------------------------------- -# 2D sine-cosine position embedding -# References: -# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py -# MoCo v3: https://github.com/facebookresearch/moco-v3 -# -------------------------------------------------------- -def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): - """ - grid_size: int of the grid height and width - return: - pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - grid_h = np.arange(grid_size, dtype=np.float32) - grid_w = np.arange(grid_size, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size, grid_size]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False): - """ - grid_size: int of the grid height and width - return: - pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - grid_h = np.arange(grid_size[0], dtype=np.float32) - grid_w = np.arange(grid_size[1], dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - assert embed_dim % 2 == 0 - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position - pos: a list of positions to be encoded: size (M,) - out: (M, D) - """ - assert embed_dim % 2 == 0 - # omega = np.arange(embed_dim // 2, dtype=np.float) - omega = np.arange(embed_dim // 2, dtype=float) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb - - -# -------------------------------------------------------- -# Interpolate position embeddings for high-resolution -# References: -# DeiT: https://github.com/facebookresearch/deit -# -------------------------------------------------------- -def interpolate_pos_embed(model, checkpoint_model): - if "pos_embed" in checkpoint_model: - pos_embed_checkpoint = checkpoint_model["pos_embed"] - embedding_size = pos_embed_checkpoint.shape[-1] - num_patches = model.patch_embed.num_patches - num_extra_tokens = model.pos_embed.shape[-2] - num_patches - # height (== width) for the checkpoint position embedding - orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 
0.5) - # height (== width) for the new position embedding - new_size = int(num_patches**0.5) - # class_token and dist_token are kept unchanged - if orig_size != new_size: - print( - "Position interpolate from %dx%d to %dx%d" - % (orig_size, orig_size, new_size, new_size) - ) - extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] - # only the position tokens are interpolated - pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] - pos_tokens = pos_tokens.reshape( - -1, orig_size, orig_size, embedding_size - ).permute(0, 3, 1, 2) - pos_tokens = torch.nn.functional.interpolate( - pos_tokens, - size=(new_size, new_size), - mode="bicubic", - align_corners=False, - ) - pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) - new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) - checkpoint_model["pos_embed"] = new_pos_embed - - -def interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, new_size): - if "pos_embed" in checkpoint_model: - pos_embed_checkpoint = checkpoint_model["pos_embed"] - embedding_size = pos_embed_checkpoint.shape[-1] - num_patches = model.patch_embed.num_patches - num_extra_tokens = model.pos_embed.shape[-2] - num_patches - # height (== width) for the checkpoint position embedding - # orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) - # height (== width) for the new position embedding - # new_size = int(num_patches ** 0.5) - # class_token and dist_token are kept unchanged - if orig_size != new_size: - print( - "Position interpolate from %dx%d to %dx%d" - % (orig_size[0], orig_size[1], new_size[0], new_size[1]) - ) - extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] - # only the position tokens are interpolated - pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] - pos_tokens = pos_tokens.reshape( - -1, orig_size[0], orig_size[1], embedding_size - ).permute(0, 3, 1, 2) - pos_tokens = torch.nn.functional.interpolate( - pos_tokens, - size=(new_size[0], new_size[1]), - mode="bicubic", - align_corners=False, - ) - pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) - new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) - checkpoint_model["pos_embed"] = new_pos_embed - - -def interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_size): - if "pos_embed" in checkpoint_model: - pos_embed_checkpoint = checkpoint_model["pos_embed"] - embedding_size = pos_embed_checkpoint.shape[-1] - num_patches = model.patch_embed.num_patches - model.pos_embed.shape[-2] - num_patches - if orig_size != new_size: - print( - "Position interpolate from %dx%d to %dx%d" - % (orig_size[0], orig_size[1], new_size[0], new_size[1]) - ) - # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] - # only the position tokens are interpolated - cls_token = pos_embed_checkpoint[:, 0, :].unsqueeze(1) - pos_tokens = pos_embed_checkpoint[:, 1:, :] # remove - pos_tokens = pos_tokens.reshape( - -1, orig_size[0], orig_size[1], embedding_size - ) # .permute(0, 3, 1, 2) - # pos_tokens = torch.nn.functional.interpolate( - # pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False) - - # pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) - pos_tokens = pos_tokens[:, :, : new_size[1], :] # assume only time diff - pos_tokens = pos_tokens.flatten(1, 2) - new_pos_embed = torch.cat((cls_token, pos_tokens), dim=1) - checkpoint_model["pos_embed"] = new_pos_embed - - -def interpolate_patch_embed_audio( - model, - checkpoint_model, - orig_channel, - new_channel=1, - kernel_size=(16, 16), - stride=(16, 
16), - padding=(0, 0), -): - if orig_channel != new_channel: - if "patch_embed.proj.weight" in checkpoint_model: - # aggregate 3 channels in rgb ckpt to 1 channel for audio - new_proj_weight = torch.nn.Parameter( - torch.sum(checkpoint_model["patch_embed.proj.weight"], dim=1).unsqueeze( - 1 - ) - ) - checkpoint_model["patch_embed.proj.weight"] = new_proj_weight diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/stat.py b/audioldm2/latent_diffusion/modules/audiomae/util/stat.py deleted file mode 100755 index 3f8137249503f6eaa25c3170fe5ef6b87f187347..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/audiomae/util/stat.py +++ /dev/null @@ -1,76 +0,0 @@ -import numpy as np -from scipy import stats -from sklearn import metrics -import torch - - -def d_prime(auc): - standard_normal = stats.norm() - d_prime = standard_normal.ppf(auc) * np.sqrt(2.0) - return d_prime - - -@torch.no_grad() -def concat_all_gather(tensor): - """ - Performs all_gather operation on the provided tensors. - *** Warning ***: torch.distributed.all_gather has no gradient. - """ - tensors_gather = [ - torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) - ] - torch.distributed.all_gather(tensors_gather, tensor, async_op=False) - - output = torch.cat(tensors_gather, dim=0) - return output - - -def calculate_stats(output, target): - """Calculate statistics including mAP, AUC, etc. - - Args: - output: 2d array, (samples_num, classes_num) - target: 2d array, (samples_num, classes_num) - - Returns: - stats: list of statistic of each class. - """ - - classes_num = target.shape[-1] - stats = [] - - # Accuracy, only used for single-label classification such as esc-50, not for multiple label one such as AudioSet - acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1)) - - # Class-wise statistics - for k in range(classes_num): - # Average precision - avg_precision = metrics.average_precision_score( - target[:, k], output[:, k], average=None - ) - - # AUC - # auc = metrics.roc_auc_score(target[:, k], output[:, k], average=None) - - # Precisions, recalls - (precisions, recalls, thresholds) = metrics.precision_recall_curve( - target[:, k], output[:, k] - ) - - # FPR, TPR - (fpr, tpr, thresholds) = metrics.roc_curve(target[:, k], output[:, k]) - - save_every_steps = 1000 # Sample statistics to reduce size - dict = { - "precisions": precisions[0::save_every_steps], - "recalls": recalls[0::save_every_steps], - "AP": avg_precision, - "fpr": fpr[0::save_every_steps], - "fnr": 1.0 - tpr[0::save_every_steps], - # 'auc': auc, - # note acc is not class-wise, this is just to keep consistent with other metrics - "acc": acc, - } - stats.append(dict) - - return stats diff --git a/audioldm2/latent_diffusion/modules/diffusionmodules/__init__.py b/audioldm2/latent_diffusion/modules/diffusionmodules/__init__.py deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/audioldm2/latent_diffusion/modules/diffusionmodules/model.py b/audioldm2/latent_diffusion/modules/diffusionmodules/model.py deleted file mode 100755 index 851f8dd28e80046c5e3c9d95bd37726024f1367c..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/diffusionmodules/model.py +++ /dev/null @@ -1,1069 +0,0 @@ -# pytorch_diffusion + derived encoder decoder -import math -import torch -import torch.nn as nn -import numpy as np -from einops import rearrange - -from audioldm2.latent_diffusion.util import 
instantiate_from_config -from audioldm2.latent_diffusion.modules.attention import LinearAttention - - -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: - From Fairseq. - Build sinusoidal embeddings. - This matches the implementation in tensor2tensor, but differs slightly - from the description in Section 3.5 of "Attention Is All You Need". - """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def nonlinearity(x): - # swish - return x * torch.sigmoid(x) - - -def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm( - num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True - ) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class UpsampleTimeStride4(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=5, stride=1, padding=2 - ) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # Do time downsampling here - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=2, padding=0 - ) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class DownsampleTimeStride4(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # Do time downsampling here - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1 - ) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2)) - return x - - -class ResnetBlock(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout, - temb_channels=512, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d( - in_channels, out_channels, 
kernel_size=3, stride=1, padding=1 - ) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - else: - self.nin_shortcut = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h - - -class LinAttnBlock(LinearAttention): - """to match AttnBlock usage""" - - def __init__(self, in_channels): - super().__init__(dim=in_channels, heads=1, dim_head=in_channels) - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h * w).contiguous() - q = q.permute(0, 2, 1).contiguous() # b,hw,c - k = k.reshape(b, c, h * w).contiguous() # b,c,hw - w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h * w).contiguous() - w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm( - v, w_ - ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b, c, h, w).contiguous() - - h_ = self.proj_out(h_) - - return x + h_ - - -def make_attn(in_channels, attn_type="vanilla"): - assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" - # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") - if attn_type == "vanilla": - return AttnBlock(in_channels) - elif attn_type == "none": - return nn.Identity(in_channels) - else: - return LinAttnBlock(in_channels) - - -class Model(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - use_timestep=True, - use_linear_attn=False, - attn_type="vanilla", - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = self.ch * 4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - self.use_timestep = use_timestep - if self.use_timestep: - # timestep embedding - 
self.temb = nn.Module() - self.temb.dense = nn.ModuleList( - [ - torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ] - ) - - # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - skip_in = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - if i_block == self.num_res_blocks: - skip_in = ch * in_ch_mult[i_level] - block.append( - ResnetBlock( - in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x, t=None, context=None): - # assert x.shape[2] == x.shape[3] == self.resolution - if context is not None: - # assume aligned context, cat along channel axis - x = torch.cat((x, context), dim=1) - if self.use_timestep: - # timestep embedding - assert t is not None - temb = get_timestep_embedding(t, self.ch) - temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) - temb = self.temb.dense[1](temb) - else: - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb - ) - if 
len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - def get_last_layer(self): - return self.conv_out.weight - - -class Encoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - use_linear_attn=False, - attn_type="vanilla", - downsample_time_stride4_levels=[], - **ignore_kwargs, - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.downsample_time_stride4_levels = downsample_time_stride4_levels - - if len(self.downsample_time_stride4_levels) > 0: - assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( - "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" - % str(self.num_resolutions) - ) - - # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - if i_level in self.downsample_time_stride4_levels: - down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv) - else: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, - 2 * z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1, - ) - - def forward(self, x): - # timestep embedding - temb = None - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), 
- num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - tanh_out=False, - use_linear_attn=False, - downsample_time_stride4_levels=[], - attn_type="vanilla", - **ignorekwargs, - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - self.tanh_out = tanh_out - self.downsample_time_stride4_levels = downsample_time_stride4_levels - - if len(self.downsample_time_stride4_levels) > 0: - assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( - "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" - % str(self.num_resolutions) - ) - - # compute in_ch_mult, block_in and curr_res at lowest res - (1,) + tuple(ch_mult) - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - # print( - # "Working with z of shape {} = {} dimensions.".format( - # self.z_shape, np.prod(self.z_shape) - # ) - # ) - - # z to block_in - self.conv_in = torch.nn.Conv2d( - z_channels, block_in, kernel_size=3, stride=1, padding=1 - ) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - if i_level - 1 in self.downsample_time_stride4_levels: - up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv) - else: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, z): - # assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - if self.tanh_out: - h = torch.tanh(h) - return h - - -class SimpleDecoder(nn.Module): - def __init__(self, 
in_channels, out_channels, *args, **kwargs): - super().__init__() - self.model = nn.ModuleList( - [ - nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock( - in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, - dropout=0.0, - ), - ResnetBlock( - in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, - dropout=0.0, - ), - ResnetBlock( - in_channels=4 * in_channels, - out_channels=2 * in_channels, - temb_channels=0, - dropout=0.0, - ), - nn.Conv2d(2 * in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True), - ] - ) - # end - self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - for i, layer in enumerate(self.model): - if i in [1, 2, 3]: - x = layer(x, None) - else: - x = layer(x) - - h = self.norm_out(x) - h = nonlinearity(h) - x = self.conv_out(h) - return x - - -class UpsampleDecoder(nn.Module): - def __init__( - self, - in_channels, - out_channels, - ch, - num_res_blocks, - resolution, - ch_mult=(2, 2), - dropout=0.0, - ): - super().__init__() - # upsampling - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - block_in = in_channels - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.res_blocks = nn.ModuleList() - self.upsample_blocks = nn.ModuleList() - for i_level in range(self.num_resolutions): - res_block = [] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - res_block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - self.res_blocks.append(nn.ModuleList(res_block)) - if i_level != self.num_resolutions - 1: - self.upsample_blocks.append(Upsample(block_in, True)) - curr_res = curr_res * 2 - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - # upsampling - h = x - for k, i_level in enumerate(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.res_blocks[i_level][i_block](h, None) - if i_level != self.num_resolutions - 1: - h = self.upsample_blocks[k](h) - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class LatentRescaler(nn.Module): - def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): - super().__init__() - # residual block, interpolate, residual block - self.factor = factor - self.conv_in = nn.Conv2d( - in_channels, mid_channels, kernel_size=3, stride=1, padding=1 - ) - self.res_block1 = nn.ModuleList( - [ - ResnetBlock( - in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0, - ) - for _ in range(depth) - ] - ) - self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList( - [ - ResnetBlock( - in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0, - ) - for _ in range(depth) - ] - ) - - self.conv_out = nn.Conv2d( - mid_channels, - out_channels, - kernel_size=1, - ) - - def forward(self, x): - x = self.conv_in(x) - for block in self.res_block1: - x = block(x, None) - x = torch.nn.functional.interpolate( - x, - size=( - int(round(x.shape[2] * self.factor)), - int(round(x.shape[3] * self.factor)), - ), - ) - x = self.attn(x).contiguous() - for block in self.res_block2: - x = block(x, None) - x = self.conv_out(x) - 
return x - - -class MergedRescaleEncoder(nn.Module): - def __init__( - self, - in_channels, - ch, - resolution, - out_ch, - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - ch_mult=(1, 2, 4, 8), - rescale_factor=1.0, - rescale_module_depth=1, - ): - super().__init__() - intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder( - in_channels=in_channels, - num_res_blocks=num_res_blocks, - ch=ch, - ch_mult=ch_mult, - z_channels=intermediate_chn, - double_z=False, - resolution=resolution, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - out_ch=None, - ) - self.rescaler = LatentRescaler( - factor=rescale_factor, - in_channels=intermediate_chn, - mid_channels=intermediate_chn, - out_channels=out_ch, - depth=rescale_module_depth, - ) - - def forward(self, x): - x = self.encoder(x) - x = self.rescaler(x) - return x - - -class MergedRescaleDecoder(nn.Module): - def __init__( - self, - z_channels, - out_ch, - resolution, - num_res_blocks, - attn_resolutions, - ch, - ch_mult=(1, 2, 4, 8), - dropout=0.0, - resamp_with_conv=True, - rescale_factor=1.0, - rescale_module_depth=1, - ): - super().__init__() - tmp_chn = z_channels * ch_mult[-1] - self.decoder = Decoder( - out_ch=out_ch, - z_channels=tmp_chn, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - in_channels=None, - num_res_blocks=num_res_blocks, - ch_mult=ch_mult, - resolution=resolution, - ch=ch, - ) - self.rescaler = LatentRescaler( - factor=rescale_factor, - in_channels=z_channels, - mid_channels=tmp_chn, - out_channels=tmp_chn, - depth=rescale_module_depth, - ) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Upsampler(nn.Module): - def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): - super().__init__() - assert out_size >= in_size - num_blocks = int(np.log2(out_size // in_size)) + 1 - factor_up = 1.0 + (out_size % in_size) - print( - f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" - ) - self.rescaler = LatentRescaler( - factor=factor_up, - in_channels=in_channels, - mid_channels=2 * in_channels, - out_channels=in_channels, - ) - self.decoder = Decoder( - out_ch=out_channels, - resolution=out_size, - z_channels=in_channels, - num_res_blocks=2, - attn_resolutions=[], - in_channels=None, - ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)], - ) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Resize(nn.Module): - def __init__(self, in_channels=None, learned=False, mode="bilinear"): - super().__init__() - self.with_conv = learned - self.mode = mode - if self.with_conv: - print( - f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" - ) - raise NotImplementedError() - assert in_channels is not None - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=4, stride=2, padding=1 - ) - - def forward(self, x, scale_factor=1.0): - if scale_factor == 1.0: - return x - else: - x = torch.nn.functional.interpolate( - x, mode=self.mode, align_corners=False, scale_factor=scale_factor - ) - return x - - -class FirstStagePostProcessor(nn.Module): - def __init__( - self, - ch_mult: list, - in_channels, - pretrained_model: nn.Module = None, - reshape=False, - n_channels=None, - dropout=0.0, - pretrained_config=None, - ): - 
super().__init__() - if pretrained_config is None: - assert ( - pretrained_model is not None - ), 'Either "pretrained_model" or "pretrained_config" must not be None' - self.pretrained_model = pretrained_model - else: - assert ( - pretrained_config is not None - ), 'Either "pretrained_model" or "pretrained_config" must not be None' - self.instantiate_pretrained(pretrained_config) - - self.do_reshape = reshape - - if n_channels is None: - n_channels = self.pretrained_model.encoder.ch - - self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) - self.proj = nn.Conv2d( - in_channels, n_channels, kernel_size=3, stride=1, padding=1 - ) - - blocks = [] - downs = [] - ch_in = n_channels - for m in ch_mult: - blocks.append( - ResnetBlock( - in_channels=ch_in, out_channels=m * n_channels, dropout=dropout - ) - ) - ch_in = m * n_channels - downs.append(Downsample(ch_in, with_conv=False)) - - self.model = nn.ModuleList(blocks) - self.downsampler = nn.ModuleList(downs) - - def instantiate_pretrained(self, config): - model = instantiate_from_config(config) - self.pretrained_model = model.eval() - # self.pretrained_model.train = False - for param in self.pretrained_model.parameters(): - param.requires_grad = False - - @torch.no_grad() - def encode_with_pretrained(self, x): - c = self.pretrained_model.encode(x) - if isinstance(c, DiagonalGaussianDistribution): - c = c.mode() - return c - - def forward(self, x): - z_fs = self.encode_with_pretrained(x) - z = self.proj_norm(z_fs) - z = self.proj(z) - z = nonlinearity(z) - - for submodel, downmodel in zip(self.model, self.downsampler): - z = submodel(z, temb=None) - z = downmodel(z) - - if self.do_reshape: - z = rearrange(z, "b c h w -> b (h w) c") - return z diff --git a/audioldm2/latent_diffusion/modules/diffusionmodules/openaimodel.py b/audioldm2/latent_diffusion/modules/diffusionmodules/openaimodel.py deleted file mode 100755 index e006e5a332c3cde5f4e221f003b270d86b34e933..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/diffusionmodules/openaimodel.py +++ /dev/null @@ -1,1103 +0,0 @@ -from abc import abstractmethod -import math - -import numpy as np -import torch as th -import torch.nn as nn -import torch.nn.functional as F - -from audioldm2.latent_diffusion.modules.diffusionmodules.util import ( - checkpoint, - conv_nd, - linear, - avg_pool_nd, - zero_module, - normalization, - timestep_embedding, -) -from audioldm2.latent_diffusion.modules.attention import SpatialTransformer - - -# dummy replace -def convert_module_to_f16(x): - pass - - -def convert_module_to_f32(x): - pass - - -## go -class AttentionPool2d(nn.Module): - """ - Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py - """ - - def __init__( - self, - spacial_dim: int, - embed_dim: int, - num_heads_channels: int, - output_dim: int = None, - ): - super().__init__() - self.positional_embedding = nn.Parameter( - th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 - ) - self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) - self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) - self.num_heads = embed_dim // num_heads_channels - self.attention = QKVAttention(self.num_heads) - - def forward(self, x): - b, c, *_spatial = x.shape - x = x.reshape(b, c, -1).contiguous() # NC(HW) - x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) - x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) - x = self.qkv_proj(x) - x = self.attention(x) - x = self.c_proj(x) - return x[:, :, 0] - - -class 
TimestepBlock(nn.Module): - """ - Any module where forward() takes timestep embeddings as a second argument. - """ - - @abstractmethod - def forward(self, x, emb): - """ - Apply the module to `x` given `emb` timestep embeddings. - """ - - -class TimestepEmbedSequential(nn.Sequential, TimestepBlock): - """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. - """ - - def forward(self, x, emb, context_list=None, mask_list=None): - # The first spatial transformer block does not have context - spatial_transformer_id = 0 - context_list = [None] + context_list - mask_list = [None] + mask_list - - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb) - elif isinstance(layer, SpatialTransformer): - if spatial_transformer_id >= len(context_list): - context, mask = None, None - else: - context, mask = ( - context_list[spatial_transformer_id], - mask_list[spatial_transformer_id], - ) - - x = layer(x, context, mask=mask) - spatial_transformer_id += 1 - else: - x = layer(x) - return x - - -class Upsample(nn.Module): - """ - An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd( - dims, self.channels, self.out_channels, 3, padding=padding - ) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) - else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - - -class TransposedUpsample(nn.Module): - "Learned 2x upsampling without padding" - - def __init__(self, channels, out_channels=None, ks=5): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - - self.up = nn.ConvTranspose2d( - self.channels, self.out_channels, kernel_size=ks, stride=2 - ) - - def forward(self, x): - return self.up(x) - - -class Downsample(nn.Module): - """ - A downsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - downsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - stride = 2 if dims != 3 else (1, 2, 2) - if use_conv: - self.op = conv_nd( - dims, - self.channels, - self.out_channels, - 3, - stride=stride, - padding=padding, - ) - else: - assert self.channels == self.out_channels - self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) - - def forward(self, x): - assert x.shape[1] == self.channels - return self.op(x) - - -class ResBlock(TimestepBlock): - """ - A residual block that can optionally change the number of channels. - :param channels: the number of input channels. 
- :param emb_channels: the number of timestep embedding channels. - :param dropout: the rate of dropout. - :param out_channels: if specified, the number of out channels. - :param use_conv: if True and out_channels is specified, use a spatial - convolution instead of a smaller 1x1 convolution to change the - channels in the skip connection. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param use_checkpoint: if True, use gradient checkpointing on this module. - :param up: if True, use this block for upsampling. - :param down: if True, use this block for downsampling. - """ - - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - use_conv=False, - use_scale_shift_norm=False, - dims=2, - use_checkpoint=False, - up=False, - down=False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_checkpoint = use_checkpoint - self.use_scale_shift_norm = use_scale_shift_norm - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 3, padding=1), - ) - - self.updown = up or down - - if up: - self.h_upd = Upsample(channels, False, dims) - self.x_upd = Upsample(channels, False, dims) - elif down: - self.h_upd = Downsample(channels, False, dims) - self.x_upd = Downsample(channels, False, dims) - else: - self.h_upd = self.x_upd = nn.Identity() - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint - ) - - def _forward(self, x, emb): - if self.updown: - in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) - h = self.h_upd(h) - x = self.x_upd(x) - h = in_conv(h) - else: - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = th.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. - Originally ported from here, but adapted to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
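- Illustrative usage (a minimal shape check; the values below are arbitrary, not the
- AudioLDM2 configuration):
- attn = AttentionBlock(channels=64, num_heads=4)
- x = th.randn(2, 64, 16, 16) # [N, C, H, W]
- y = attn(x) # output keeps the input shape: [2, 64, 16, 16]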
- """ - - def __init__( - self, - channels, - num_heads=1, - num_head_channels=-1, - use_checkpoint=False, - use_new_attention_order=False, - ): - super().__init__() - self.channels = channels - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - self.use_checkpoint = use_checkpoint - self.norm = normalization(channels) - self.qkv = conv_nd(1, channels, channels * 3, 1) - if use_new_attention_order: - # split qkv before split heads - self.attention = QKVAttention(self.num_heads) - else: - # split heads before split qkv - self.attention = QKVAttentionLegacy(self.num_heads) - - self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) - - def forward(self, x): - return checkpoint( - self._forward, (x,), self.parameters(), True - ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! - # return pt_checkpoint(self._forward, x) # pytorch - - def _forward(self, x): - b, c, *spatial = x.shape - x = x.reshape(b, c, -1).contiguous() - qkv = self.qkv(self.norm(x)).contiguous() - h = self.attention(qkv).contiguous() - h = self.proj_out(h).contiguous() - return (x + h).reshape(b, c, *spatial).contiguous() - - -def count_flops_attn(model, _x, y): - """ - A counter for the `thop` package to count the operations in an - attention operation. - Meant to be used like: - macs, params = thop.profile( - model, - inputs=(inputs, timestamps), - custom_ops={QKVAttention: QKVAttention.count_flops}, - ) - """ - b, c, *spatial = y[0].shape - num_spatial = int(np.prod(spatial)) - # We perform two matmuls with the same number of ops. - # The first computes the weight matrix, the second computes - # the combination of the value vectors. - matmul_ops = 2 * b * (num_spatial**2) * c - model.total_ops += th.DoubleTensor([matmul_ops]) - - -class QKVAttentionLegacy(nn.Module): - """ - A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. - :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = ( - qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1) - ) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v) - return a.reshape(bs, -1, length).contiguous() - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - -class QKVAttention(nn.Module): - """ - A module which performs QKV attention and splits in a different order. - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. - :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. 
- """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.chunk(3, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", - (q * scale).view(bs * self.n_heads, ch, length), - (k * scale).view(bs * self.n_heads, ch, length), - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum( - "bts,bcs->bct", - weight, - v.reshape(bs * self.n_heads, ch, length).contiguous(), - ) - return a.reshape(bs, -1, length).contiguous() - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - -class UNetModel(nn.Module): - """ - The full UNet model with attention and timestep embedding. - :param in_channels: channels in the input Tensor. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_classes: if specified (as an int), then this model will be - class-conditional with `num_classes` classes. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. 
- """ - - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - extra_sa_layer=True, - num_classes=None, - extra_film_condition_dim=None, - use_checkpoint=False, - use_fp16=False, - num_heads=-1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - use_spatial_transformer=True, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model - legacy=True, - ): - super().__init__() - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - if num_heads == -1: - assert ( - num_head_channels != -1 - ), "Either num_heads or num_head_channels has to be set" - - if num_head_channels == -1: - assert ( - num_heads != -1 - ), "Either num_heads or num_head_channels has to be set" - - self.image_size = image_size - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.num_classes = num_classes - self.extra_film_condition_dim = extra_film_condition_dim - self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - self.predict_codebook_ids = n_embed is not None - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - # assert not ( - # self.num_classes is not None and self.extra_film_condition_dim is not None - # ), "As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim." - - if self.num_classes is not None: - self.label_emb = nn.Embedding(num_classes, time_embed_dim) - - self.use_extra_film_by_concat = self.extra_film_condition_dim is not None - - if self.extra_film_condition_dim is not None: - self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim) - print( - "+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. " - % self.extra_film_condition_dim - ) - - if context_dim is not None and not use_spatial_transformer: - assert ( - use_spatial_transformer - ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." 
- - if context_dim is not None and not isinstance(context_dim, list): - context_dim = [context_dim] - elif context_dim is None: - context_dim = [None] # At least use one spatial transformer - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim - if (not self.use_extra_film_by_concat) - else time_embed_dim * 2, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - dim_head = ( - ch // num_heads - if use_spatial_transformer - else num_head_channels - ) - if extra_sa_layer: - layers.append( - SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=None, - ) - ) - for context_dim_id in range(len(context_dim)): - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim[context_dim_id], - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim - if (not self.use_extra_film_by_concat) - else time_embed_dim * 2, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - middle_layers = [ - ResBlock( - ch, - time_embed_dim - if (not self.use_extra_film_by_concat) - else time_embed_dim * 2, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - if extra_sa_layer: - middle_layers.append( - SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=None - ) - ) - for context_dim_id in range(len(context_dim)): - middle_layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim[context_dim_id], - ) - ) - middle_layers.append( - ResBlock( - ch, - time_embed_dim - if (not self.use_extra_film_by_concat) - else time_embed_dim * 2, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - 
) - self.middle_block = TimestepEmbedSequential(*middle_layers) - - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(num_res_blocks + 1): - ich = input_block_chans.pop() - layers = [ - ResBlock( - ch + ich, - time_embed_dim - if (not self.use_extra_film_by_concat) - else time_embed_dim * 2, - dropout, - out_channels=model_channels * mult, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = model_channels * mult - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = ( - ch // num_heads - if use_spatial_transformer - else num_head_channels - ) - if extra_sa_layer: - layers.append( - SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=None, - ) - ) - for context_dim_id in range(len(context_dim)): - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads_upsample, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim[context_dim_id], - ) - ) - if level and i == num_res_blocks: - out_ch = ch - layers.append( - ResBlock( - ch, - time_embed_dim - if (not self.use_extra_film_by_concat) - else time_embed_dim * 2, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - up=True, - ) - if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) - ) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), - ) - if self.predict_codebook_ids: - self.id_predictor = nn.Sequential( - normalization(ch), - conv_nd(dims, model_channels, n_embed, 1), - # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits - ) - - self.shape_reported = False - - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - self.output_blocks.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - self.output_blocks.apply(convert_module_to_f32) - - def forward( - self, - x, - timesteps=None, - y=None, - context_list=None, - context_attn_mask_list=None, - **kwargs, - ): - """ - Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param context: conditioning plugged in via crossattn - :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional - :return: an [N x C x ...] Tensor of outputs. 
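- Note: conditioning is passed via `context_list` (one tensor of shape
- [N x seq_len x context_dim] per cross-attention branch) together with
- `context_attn_mask_list` (matching [N x seq_len] masks, or None entries).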
- """ - if not self.shape_reported: - # print("The shape of UNet input is", x.size()) - self.shape_reported = True - - assert (y is not None) == ( - self.num_classes is not None or self.extra_film_condition_dim is not None - ), "must specify y if and only if the model is class-conditional or film embedding conditional" - hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) - emb = self.time_embed(t_emb) - - # if self.num_classes is not None: - # assert y.shape == (x.shape[0],) - # emb = emb + self.label_emb(y) - - if self.use_extra_film_by_concat: - emb = th.cat([emb, self.film_emb(y)], dim=-1) - - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb, context_list, context_attn_mask_list) - hs.append(h) - h = self.middle_block(h, emb, context_list, context_attn_mask_list) - for module in self.output_blocks: - concate_tensor = hs.pop() - h = th.cat([h, concate_tensor], dim=1) - h = module(h, emb, context_list, context_attn_mask_list) - h = h.type(x.dtype) - if self.predict_codebook_ids: - return self.id_predictor(h) - else: - return self.out(h) - - -class EncoderUNetModel(nn.Module): - """ - The half UNet model with attention and timestep embedding. - For usage, see UNet. - """ - - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - use_checkpoint=False, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - pool="adaptive", - *args, - **kwargs, - ): - super().__init__() - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, 
- dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - self.pool = pool - if pool == "adaptive": - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - nn.AdaptiveAvgPool2d((1, 1)), - zero_module(conv_nd(dims, ch, out_channels, 1)), - nn.Flatten(), - ) - elif pool == "attention": - assert num_head_channels != -1 - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - AttentionPool2d( - (image_size // ds), ch, num_head_channels, out_channels - ), - ) - elif pool == "spatial": - self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), - nn.ReLU(), - nn.Linear(2048, self.out_channels), - ) - elif pool == "spatial_v2": - self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), - normalization(2048), - nn.SiLU(), - nn.Linear(2048, self.out_channels), - ) - else: - raise NotImplementedError(f"Unexpected {pool} pooling") - - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - - def forward(self, x, timesteps): - """ - Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :return: an [N x K] Tensor of outputs. - """ - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - - results = [] - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb) - if self.pool.startswith("spatial"): - results.append(h.type(x.dtype).mean(dim=(2, 3))) - h = self.middle_block(h, emb) - if self.pool.startswith("spatial"): - results.append(h.type(x.dtype).mean(dim=(2, 3))) - h = th.cat(results, axis=-1) - return self.out(h) - else: - h = h.type(x.dtype) - return self.out(h) diff --git a/audioldm2/latent_diffusion/modules/diffusionmodules/util.py b/audioldm2/latent_diffusion/modules/diffusionmodules/util.py deleted file mode 100755 index 0d486f919a7ccc0586bc40225dac0ffb33aed01c..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/diffusionmodules/util.py +++ /dev/null @@ -1,294 +0,0 @@ -# adopted from -# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py -# and -# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -# and -# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py -# -# thanks! 
- - -import math -import torch -import torch.nn as nn -import numpy as np -from einops import repeat - -from audioldm2.latent_diffusion.util import instantiate_from_config - - -def make_beta_schedule( - schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 -): - if schedule == "linear": - betas = ( - torch.linspace( - linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 - ) - ** 2 - ) - - elif schedule == "cosine": - timesteps = ( - torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s - ) - alphas = timesteps / (1 + cosine_s) * np.pi / 2 - alphas = torch.cos(alphas).pow(2) - alphas = alphas / alphas[0] - betas = 1 - alphas[1:] / alphas[:-1] - betas = np.clip(betas, a_min=0, a_max=0.999) - - elif schedule == "sqrt_linear": - betas = torch.linspace( - linear_start, linear_end, n_timestep, dtype=torch.float64 - ) - elif schedule == "sqrt": - betas = ( - torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) - ** 0.5 - ) - else: - raise ValueError(f"schedule '{schedule}' unknown.") - return betas.numpy() - - -def make_ddim_timesteps( - ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True -): - if ddim_discr_method == "uniform": - c = num_ddpm_timesteps // num_ddim_timesteps - ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) - elif ddim_discr_method == "quad": - ddim_timesteps = ( - (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 - ).astype(int) - else: - raise NotImplementedError( - f'There is no ddim discretization method called "{ddim_discr_method}"' - ) - - # assert ddim_timesteps.shape[0] == num_ddim_timesteps - # add one to get the final alpha values right (the ones from first scale to data during sampling) - steps_out = ddim_timesteps + 1 - if verbose: - print(f"Selected timesteps for ddim sampler: {steps_out}") - return steps_out - - -def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): - # select alphas for computing the variance schedule - alphas = alphacums[ddim_timesteps] - alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) - - # according the the formula provided in https://arxiv.org/abs/2010.02502 - sigmas = eta * np.sqrt( - (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) - ) - if verbose: - print( - f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" - ) - print( - f"For the chosen value of eta, which is {eta}, " - f"this results in the following sigma_t schedule for ddim sampler {sigmas}" - ) - return sigmas, alphas, alphas_prev - - -def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, - which defines the cumulative product of (1-beta) over time from t = [0,1]. - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. - :param max_beta: the maximum beta to use; use values lower than 1 to - prevent singularities. 
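- Illustrative usage (the cosine alpha_bar used by the improved-diffusion codebase this
- file is adapted from):
- betas = betas_for_alpha_bar(
- 1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
- )
- # betas is a numpy array of length 1000 with entries in (0, 0.999]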
- """ - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas) - - -def extract_into_tensor(a, t, x_shape): - b, *_ = t.shape - out = a.gather(-1, t).contiguous() - return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous() - - -def checkpoint(func, inputs, params, flag): - """ - Evaluate a function without caching intermediate activations, allowing for - reduced memory at the expense of extra compute in the backward pass. - :param func: the function to evaluate. - :param inputs: the argument sequence to pass to `func`. - :param params: a sequence of parameters `func` depends on but does not - explicitly take as arguments. - :param flag: if False, disable gradient checkpointing. - """ - if flag: - args = tuple(inputs) + tuple(params) - return CheckpointFunction.apply(func, len(inputs), *args) - else: - return func(*inputs) - - -class CheckpointFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, run_function, length, *args): - ctx.run_function = run_function - ctx.input_tensors = list(args[:length]) - ctx.input_params = list(args[length:]) - - with torch.no_grad(): - output_tensors = ctx.run_function(*ctx.input_tensors) - return output_tensors - - @staticmethod - def backward(ctx, *output_grads): - ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - with torch.enable_grad(): - # Fixes a bug where the first op in run_function modifies the - # Tensor storage in place, which is not allowed for detach()'d - # Tensors. - shallow_copies = [x.view_as(x) for x in ctx.input_tensors] - output_tensors = ctx.run_function(*shallow_copies) - input_grads = torch.autograd.grad( - output_tensors, - ctx.input_tensors + ctx.input_params, - output_grads, - allow_unused=True, - ) - del ctx.input_tensors - del ctx.input_params - del output_tensors - return (None, None) + input_grads - - -def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): - """ - Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - if not repeat_only: - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32) - / half - ).to(device=timesteps.device) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat( - [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 - ) - else: - embedding = repeat(timesteps, "b -> b d", d=dim) - return embedding - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def scale_module(module, scale): - """ - Scale the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().mul_(scale) - return module - - -def mean_flat(tensor): - """ - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def normalization(channels): - """ - Make a standard normalization layer. - :param channels: number of input channels. - :return: an nn.Module for normalization. 
- """ - return GroupNorm32(32, channels) - - -# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. -class SiLU(nn.Module): - def forward(self, x): - return x * torch.sigmoid(x) - - -class GroupNorm32(nn.GroupNorm): - def forward(self, x): - return super().forward(x.float()).type(x.dtype) - - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def linear(*args, **kwargs): - """ - Create a linear module. - """ - return nn.Linear(*args, **kwargs) - - -def avg_pool_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D average pooling module. - """ - if dims == 1: - return nn.AvgPool1d(*args, **kwargs) - elif dims == 2: - return nn.AvgPool2d(*args, **kwargs) - elif dims == 3: - return nn.AvgPool3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -class HybridConditioner(nn.Module): - def __init__(self, c_concat_config, c_crossattn_config): - super().__init__() - self.concat_conditioner = instantiate_from_config(c_concat_config) - self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) - - def forward(self, c_concat, c_crossattn): - c_concat = self.concat_conditioner(c_concat) - c_crossattn = self.crossattn_conditioner(c_crossattn) - return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} - - -def noise_like(shape, device, repeat=False): - repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( - shape[0], *((1,) * (len(shape) - 1)) - ) - noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() diff --git a/audioldm2/latent_diffusion/modules/ema.py b/audioldm2/latent_diffusion/modules/ema.py deleted file mode 100755 index 880ca3d205d9b4d7450e146930a93f2e63c58b70..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/ema.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -from torch import nn - - -class LitEma(nn.Module): - def __init__(self, model, decay=0.9999, use_num_upates=True): - super().__init__() - if decay < 0.0 or decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - - self.m_name2s_name = {} - self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) - self.register_buffer( - "num_updates", - torch.tensor(0, dtype=torch.int) - if use_num_upates - else torch.tensor(-1, dtype=torch.int), - ) - - for name, p in model.named_parameters(): - if p.requires_grad: - # remove as '.'-character is not allowed in buffers - s_name = name.replace(".", "") - self.m_name2s_name.update({name: s_name}) - self.register_buffer(s_name, p.clone().detach().data) - - self.collected_params = [] - - def forward(self, model): - decay = self.decay - - if self.num_updates >= 0: - self.num_updates += 1 - decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) - - one_minus_decay = 1.0 - decay - - with torch.no_grad(): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - - for key in m_param: - if m_param[key].requires_grad: - sname = self.m_name2s_name[key] - shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) - shadow_params[sname].sub_( - one_minus_decay * (shadow_params[sname] - m_param[key]) - ) - else: - assert not key in self.m_name2s_name - - def copy_to(self, model): - m_param = dict(model.named_parameters()) - 
shadow_params = dict(self.named_buffers()) - for key in m_param: - if m_param[key].requires_grad: - m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) - else: - assert not key in self.m_name2s_name - - def store(self, parameters): - """ - Save the current parameters for restoring later. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. - """ - self.collected_params = [param.clone() for param in parameters] - - def restore(self, parameters): - """ - Restore the parameters stored with the `store` method. - Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before the - `copy_to` method. After validation (or model saving), use this to - restore the former parameters. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. - """ - for c_param, param in zip(self.collected_params, parameters): - param.data.copy_(c_param.data) diff --git a/audioldm2/latent_diffusion/modules/encoders/__init__.py b/audioldm2/latent_diffusion/modules/encoders/__init__.py deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/audioldm2/latent_diffusion/modules/encoders/modules.py b/audioldm2/latent_diffusion/modules/encoders/modules.py deleted file mode 100755 index 7a72339840c0c3b667e907ea07ee7cb755eb66fd..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/encoders/modules.py +++ /dev/null @@ -1,736 +0,0 @@ -import torch -import logging -import torch.nn as nn -from audioldm2.clap.open_clip import create_model -from audioldm2.clap.training.data import get_audio_features -import torchaudio -from transformers import RobertaTokenizer, AutoTokenizer, T5EncoderModel -import torch.nn.functional as F -from audioldm2.latent_diffusion.modules.audiomae.AudioMAE import Vanilla_AudioMAE -from audioldm2.latent_diffusion.modules.phoneme_encoder.encoder import TextEncoder - -from transformers import AutoTokenizer, T5Config - -from audioldm2.audiomae_gen.sequence_input import Sequence2AudioMAE -import numpy as np - -""" -The model forward function can return three types of data: -1. tensor: used directly as conditioning signal -2. dict: where there is a main key as condition, there are also other key that you can use to pass loss function and itermediate result. etc. -3. list: the length is 2, in which the first element is tensor, the second element is attntion mask. 
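- For example, in this file: FlanT5HiddenState and PhonemeEncoder return the list form,
- SequenceGenAudioMAECond returns a dict, and CLAPAudioEmbeddingClassifierFreev2 returns
- a plain embedding tensor.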
- -The output shape for the cross attention condition should be: -x,x_mask = [bs, seq_len, emb_dim], [bs, seq_len] - -All the returned data, in which will be used as diffusion input, will need to be in float type -""" - - -class PhonemeEncoder(nn.Module): - def __init__(self, vocabs_size=41, pad_length=250, pad_token_id=None): - super().__init__() - """ - encoder = PhonemeEncoder(40) - data = torch.randint(0, 39, (2, 250)) - output = encoder(data) - import ipdb;ipdb.set_trace() - """ - assert pad_token_id is not None - - self.device = None - self.PAD_LENGTH = int(pad_length) - self.pad_token_id = pad_token_id - self.pad_token_sequence = torch.tensor([self.pad_token_id] * self.PAD_LENGTH) - - self.text_encoder = TextEncoder( - n_vocab=vocabs_size, - out_channels=192, - hidden_channels=192, - filter_channels=768, - n_heads=2, - n_layers=6, - kernel_size=3, - p_dropout=0.1, - ) - - self.learnable_positional_embedding = torch.nn.Parameter( - torch.zeros((1, 192, self.PAD_LENGTH)) - ) # [batchsize, seqlen, padlen] - self.learnable_positional_embedding.requires_grad = True - - # Required - def get_unconditional_condition(self, batchsize): - unconditional_tokens = self.pad_token_sequence.expand( - batchsize, self.PAD_LENGTH - ) - return self(unconditional_tokens) # Need to return float type - - # def get_unconditional_condition(self, batchsize): - - # hidden_state = torch.zeros((batchsize, self.PAD_LENGTH, 192)).to(self.device) - # attention_mask = torch.ones((batchsize, self.PAD_LENGTH)).to(self.device) - # return [hidden_state, attention_mask] # Need to return float type - - def _get_src_mask(self, phoneme): - src_mask = phoneme != self.pad_token_id - return src_mask - - def _get_src_length(self, phoneme): - src_mask = self._get_src_mask(phoneme) - length = torch.sum(src_mask, dim=-1) - return length - - # def make_empty_condition_unconditional(self, src_length, text_emb, attention_mask): - # # src_length: [bs] - # # text_emb: [bs, 192, pad_length] - # # attention_mask: [bs, pad_length] - # mask = src_length[..., None, None] > 1 - # text_emb = text_emb * mask - - # attention_mask[src_length < 1] = attention_mask[src_length < 1] * 0.0 + 1.0 - # return text_emb, attention_mask - - def forward(self, phoneme_idx): - if self.device is None: - self.device = self.learnable_positional_embedding.device - self.pad_token_sequence = self.pad_token_sequence.to(self.device) - - src_length = self._get_src_length(phoneme_idx) - text_emb, m, logs, text_emb_mask = self.text_encoder(phoneme_idx, src_length) - text_emb = text_emb + self.learnable_positional_embedding - - # text_emb, text_emb_mask = self.make_empty_condition_unconditional(src_length, text_emb, text_emb_mask) - - return [ - text_emb.permute(0, 2, 1), - text_emb_mask.squeeze(1), - ] # [2, 250, 192], [2, 250] - - -class FlanT5HiddenState(nn.Module): - """ - llama = FlanT5HiddenState() - data = ["","this is not an empty sentence"] - encoder_hidden_states = llama(data) - import ipdb;ipdb.set_trace() - """ - - def __init__( - self, text_encoder_name="google/flan-t5-large", freeze_text_encoder=True - ): - super().__init__() - self.freeze_text_encoder = freeze_text_encoder - self.tokenizer = AutoTokenizer.from_pretrained(text_encoder_name) - self.model = T5EncoderModel(T5Config.from_pretrained(text_encoder_name)) - if freeze_text_encoder: - self.model.eval() - for p in self.model.parameters(): - p.requires_grad = False - else: - print("=> The text encoder is learnable") - - self.empty_hidden_state_cfg = None - self.device = None - - # Required - def 
get_unconditional_condition(self, batchsize): - param = next(self.model.parameters()) - if self.freeze_text_encoder: - assert param.requires_grad == False - - # device = param.device - if self.empty_hidden_state_cfg is None: - self.empty_hidden_state_cfg, _ = self([""]) - - hidden_state = torch.cat([self.empty_hidden_state_cfg] * batchsize).float() - attention_mask = ( - torch.ones((batchsize, hidden_state.size(1))) - .to(hidden_state.device) - .float() - ) - return [hidden_state, attention_mask] # Need to return float type - - def forward(self, batch): - param = next(self.model.parameters()) - if self.freeze_text_encoder: - assert param.requires_grad == False - - if self.device is None: - self.device = param.device - - # print("Manually change text") - # for i in range(len(batch)): - # batch[i] = "dog barking" - try: - return self.encode_text(batch) - except Exception as e: - print(e, batch) - logging.exception("An error occurred: %s", str(e)) - - def encode_text(self, prompt): - device = self.model.device - batch = self.tokenizer( - prompt, - max_length=128, # self.tokenizer.model_max_length - padding=True, - truncation=True, - return_tensors="pt", - ) - input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to( - device - ) - # Get text encoding - if self.freeze_text_encoder: - with torch.no_grad(): - encoder_hidden_states = self.model( - input_ids=input_ids, attention_mask=attention_mask - )[0] - else: - encoder_hidden_states = self.model( - input_ids=input_ids, attention_mask=attention_mask - )[0] - return [ - encoder_hidden_states.detach(), - attention_mask.float(), - ] # Attention mask == 1 means usable token - - -class SequenceGenAudioMAECond(Sequence2AudioMAE): - def __init__( - self, - cond_stage_config, - base_learning_rate, - sequence_gen_length, - sequence_input_key, - sequence_input_embed_dim, - batchsize, - always_output_audiomae_gt=False, - pretrained_path=None, - force_reload_pretrain_avoid_overwrite=False, - learnable=True, - use_warmup=True, - device=None, - use_gt_mae_output=None, # False: does not use AudioMAE GT, True: Use AudioMAE GT - use_gt_mae_prob=None, - ): # The prob of using AudioMAE GT - if use_warmup: - use_warmup = False - - super().__init__( - base_learning_rate=base_learning_rate, - cond_stage_config=cond_stage_config, - sequence_gen_length=sequence_gen_length, - sequence_input_key=sequence_input_key, - use_warmup=use_warmup, - sequence_input_embed_dim=sequence_input_embed_dim, - batchsize=batchsize, - ) - - assert use_gt_mae_output is not None and use_gt_mae_prob is not None - self.always_output_audiomae_gt = always_output_audiomae_gt - self.force_reload_pretrain_avoid_overwrite = ( - force_reload_pretrain_avoid_overwrite - ) - self.pretrained_path = pretrained_path - self.device = device - if self.force_reload_pretrain_avoid_overwrite: - self.is_reload = False - else: - self.is_reload = True - - self.load_pretrain_model() - - self.use_gt_mae_output = use_gt_mae_output - self.use_gt_mae_prob = use_gt_mae_prob - self.learnable = learnable - - if not learnable: - # Only optimize the GPT2 model - for p in self.model.parameters(): - p.requires_grad = False - self.eval() - - def load_pretrain_model(self): - if self.pretrained_path is not None: - print("Reload SequenceGenAudioMAECond from %s" % self.pretrained_path) - state_dict = torch.load(self.pretrained_path)["state_dict"] - self.load_state_dict(state_dict) - - # Required - def get_unconditional_condition(self, batchsize): - return_dict = self.cfg_uncond(batchsize) - 
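- # Also expose the unconditional AudioMAE features under the "generated" key,
- # with an all-ones mask so every token remains visible to cross-attention.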
return_dict["crossattn_audiomae_generated"] = [ - return_dict["crossattn_audiomae_pooled"][0], - torch.ones_like(return_dict["crossattn_audiomae_pooled"][1]).float(), - ] - return return_dict - - def forward(self, batch): - # The conditional module can return both tensor or dictionaries - # The returned tensor will be corresponding to the cond_stage_key - # The returned dict will have keys that correspond to the cond_stage_key - ret_dict = {} - - if self.force_reload_pretrain_avoid_overwrite and not self.is_reload: - self.load_pretrain_model() - self.is_reload = True - - # if(self.always_output_audiomae_gt or (self.use_gt_mae_output and torch.rand(1).item() < self.use_gt_mae_prob)): - # cond_dict = self.get_input(batch) - # ret_dict["crossattn_audiomae_generated"] = [cond_dict["crossattn_audiomae_pooled"][0], torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float()] # Input sequence and mask - # else: - input_embeds, cond_dict = self.generate(batch) - input_embeds_mask = ( - torch.ones((input_embeds.size(0), input_embeds.size(1))) - .to(input_embeds.device) - .float() - ) - ret_dict["crossattn_audiomae_generated"] = [ - input_embeds, - input_embeds_mask, - ] # Input sequence and mask - - # If the following two keys are not in cond_stage_key, then they will not be used as condition - for key in cond_dict.keys(): - ret_dict[key] = cond_dict[key] - - return ret_dict - - -class AudioMAEConditionCTPoolRandTFSeparated(nn.Module): - """ - audiomae = AudioMAEConditionCTPool2x2() - data = torch.randn((4, 1024, 128)) - output = audiomae(data) - import ipdb;ipdb.set_trace() - exit(0) - """ - - def __init__( - self, - time_pooling_factors=[1, 2, 4, 8], - freq_pooling_factors=[1, 2, 4, 8], - eval_time_pooling=None, - eval_freq_pooling=None, - mask_ratio=0.0, - regularization=False, - no_audiomae_mask=True, - no_audiomae_average=False, - ): - super().__init__() - self.device = None - self.time_pooling_factors = time_pooling_factors - self.freq_pooling_factors = freq_pooling_factors - self.no_audiomae_mask = no_audiomae_mask - self.no_audiomae_average = no_audiomae_average - - self.eval_freq_pooling = eval_freq_pooling - self.eval_time_pooling = eval_time_pooling - self.mask_ratio = mask_ratio - self.use_reg = regularization - - self.audiomae = Vanilla_AudioMAE() - self.audiomae.eval() - for p in self.audiomae.parameters(): - p.requires_grad = False - - # Required - def get_unconditional_condition(self, batchsize): - param = next(self.audiomae.parameters()) - assert param.requires_grad == False - device = param.device - # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors) - time_pool, freq_pool = min(self.eval_time_pooling, 64), min( - self.eval_freq_pooling, 8 - ) - # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))] - # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] - token_num = int(512 / (time_pool * freq_pool)) - return [ - torch.zeros((batchsize, token_num, 768)).to(device).float(), - torch.ones((batchsize, token_num)).to(device).float(), - ] - - def pool(self, representation, time_pool=None, freq_pool=None): - assert representation.size(-1) == 768 - representation = representation[:, 1:, :].transpose(1, 2) - bs, embedding_dim, token_num = representation.size() - representation = representation.reshape(bs, embedding_dim, 64, 8) - - if self.training: - if time_pool is None and freq_pool is None: - time_pool = min( - 64, - self.time_pooling_factors[ - 
np.random.choice(list(range(len(self.time_pooling_factors)))) - ], - ) - freq_pool = min( - 8, - self.freq_pooling_factors[ - np.random.choice(list(range(len(self.freq_pooling_factors)))) - ], - ) - # freq_pool = min(8, time_pool) # TODO here I make some modification. - else: - time_pool, freq_pool = min(self.eval_time_pooling, 64), min( - self.eval_freq_pooling, 8 - ) - - self.avgpooling = nn.AvgPool2d( - kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) - ) - self.maxpooling = nn.MaxPool2d( - kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) - ) - - pooled = ( - self.avgpooling(representation) + self.maxpooling(representation) - ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num] - pooled = pooled.flatten(2).transpose(1, 2) - return pooled # [bs, token_num, embedding_dim] - - def regularization(self, x): - assert x.size(-1) == 768 - x = F.normalize(x, p=2, dim=-1) - return x - - # Required - def forward(self, batch, time_pool=None, freq_pool=None): - assert batch.size(-2) == 1024 and batch.size(-1) == 128 - - if self.device is None: - self.device = batch.device - - batch = batch.unsqueeze(1) - with torch.no_grad(): - representation = self.audiomae( - batch, - mask_ratio=self.mask_ratio, - no_mask=self.no_audiomae_mask, - no_average=self.no_audiomae_average, - ) - representation = self.pool(representation, time_pool, freq_pool) - if self.use_reg: - representation = self.regularization(representation) - return [ - representation, - torch.ones((representation.size(0), representation.size(1))) - .to(representation.device) - .float(), - ] - - -class AudioMAEConditionCTPoolRand(nn.Module): - """ - audiomae = AudioMAEConditionCTPool2x2() - data = torch.randn((4, 1024, 128)) - output = audiomae(data) - import ipdb;ipdb.set_trace() - exit(0) - """ - - def __init__( - self, - time_pooling_factors=[1, 2, 4, 8], - freq_pooling_factors=[1, 2, 4, 8], - eval_time_pooling=None, - eval_freq_pooling=None, - mask_ratio=0.0, - regularization=False, - no_audiomae_mask=True, - no_audiomae_average=False, - ): - super().__init__() - self.device = None - self.time_pooling_factors = time_pooling_factors - self.freq_pooling_factors = freq_pooling_factors - self.no_audiomae_mask = no_audiomae_mask - self.no_audiomae_average = no_audiomae_average - - self.eval_freq_pooling = eval_freq_pooling - self.eval_time_pooling = eval_time_pooling - self.mask_ratio = mask_ratio - self.use_reg = regularization - - self.audiomae = Vanilla_AudioMAE() - self.audiomae.eval() - for p in self.audiomae.parameters(): - p.requires_grad = False - - # Required - def get_unconditional_condition(self, batchsize): - param = next(self.audiomae.parameters()) - assert param.requires_grad == False - device = param.device - # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors) - time_pool, freq_pool = min(self.eval_time_pooling, 64), min( - self.eval_freq_pooling, 8 - ) - # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))] - # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] - token_num = int(512 / (time_pool * freq_pool)) - return [ - torch.zeros((batchsize, token_num, 768)).to(device).float(), - torch.ones((batchsize, token_num)).to(device).float(), - ] - - def pool(self, representation, time_pool=None, freq_pool=None): - assert representation.size(-1) == 768 - representation = representation[:, 1:, :].transpose(1, 2) - bs, embedding_dim, token_num = 
representation.size() - representation = representation.reshape(bs, embedding_dim, 64, 8) - - if self.training: - if time_pool is None and freq_pool is None: - time_pool = min( - 64, - self.time_pooling_factors[ - np.random.choice(list(range(len(self.time_pooling_factors)))) - ], - ) - # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] - freq_pool = min(8, time_pool) # TODO here I make some modification. - else: - time_pool, freq_pool = min(self.eval_time_pooling, 64), min( - self.eval_freq_pooling, 8 - ) - - self.avgpooling = nn.AvgPool2d( - kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) - ) - self.maxpooling = nn.MaxPool2d( - kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) - ) - - pooled = ( - self.avgpooling(representation) + self.maxpooling(representation) - ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num] - pooled = pooled.flatten(2).transpose(1, 2) - return pooled # [bs, token_num, embedding_dim] - - def regularization(self, x): - assert x.size(-1) == 768 - x = F.normalize(x, p=2, dim=-1) - return x - - # Required - def forward(self, batch, time_pool=None, freq_pool=None): - assert batch.size(-2) == 1024 and batch.size(-1) == 128 - - if self.device is None: - self.device = batch.device - - batch = batch.unsqueeze(1) - with torch.no_grad(): - representation = self.audiomae( - batch, - mask_ratio=self.mask_ratio, - no_mask=self.no_audiomae_mask, - no_average=self.no_audiomae_average, - ) - representation = self.pool(representation, time_pool, freq_pool) - if self.use_reg: - representation = self.regularization(representation) - return [ - representation, - torch.ones((representation.size(0), representation.size(1))) - .to(representation.device) - .float(), - ] - - -class CLAPAudioEmbeddingClassifierFreev2(nn.Module): - def __init__( - self, - pretrained_path="", - sampling_rate=16000, - embed_mode="audio", - amodel="HTSAT-base", - unconditional_prob=0.1, - random_mute=False, - max_random_mute_portion=0.5, - training_mode=True, - ): - super().__init__() - self.device = "cpu" - self.precision = "fp32" - self.amodel = amodel # or 'PANN-14' - self.tmodel = "roberta" # the best text encoder in our training - self.enable_fusion = False # False if you do not want to use the fusion model - self.fusion_type = "aff_2d" - self.pretrained = pretrained_path - self.embed_mode = embed_mode - self.embed_mode_orig = embed_mode - self.sampling_rate = sampling_rate - self.unconditional_prob = unconditional_prob - self.random_mute = random_mute - self.tokenize = RobertaTokenizer.from_pretrained("roberta-base") - self.max_random_mute_portion = max_random_mute_portion - self.training_mode = training_mode - self.model, self.model_cfg = create_model( - self.amodel, - self.tmodel, - self.pretrained, - precision=self.precision, - device=self.device, - enable_fusion=self.enable_fusion, - fusion_type=self.fusion_type, - ) - audio_cfg = self.model_cfg["audio_cfg"] - self.mel_transform = torchaudio.transforms.MelSpectrogram( - sample_rate=audio_cfg["sample_rate"], - n_fft=audio_cfg["window_size"], - win_length=audio_cfg["window_size"], - hop_length=audio_cfg["hop_size"], - center=True, - pad_mode="reflect", - power=2.0, - norm=None, - onesided=True, - n_mels=64, - f_min=audio_cfg["fmin"], - f_max=audio_cfg["fmax"], - ) - for p in self.model.parameters(): - p.requires_grad = False - self.unconditional_token = None - self.model.eval() - - def get_unconditional_condition(self, batchsize): - self.unconditional_token = 
self.model.get_text_embedding( - self.tokenizer(["", ""]) - )[0:1] - return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0) - - def batch_to_list(self, batch): - ret = [] - for i in range(batch.size(0)): - ret.append(batch[i]) - return ret - - def make_decision(self, probability): - if float(torch.rand(1)) < probability: - return True - else: - return False - - def random_uniform(self, start, end): - val = torch.rand(1).item() - return start + (end - start) * val - - def _random_mute(self, waveform): - # waveform: [bs, t-steps] - t_steps = waveform.size(-1) - for i in range(waveform.size(0)): - mute_size = int( - self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion)) - ) - mute_start = int(self.random_uniform(0, t_steps - mute_size)) - waveform[i, mute_start : mute_start + mute_size] = 0 - return waveform - - def cos_similarity(self, waveform, text): - # waveform: [bs, t_steps] - original_embed_mode = self.embed_mode - with torch.no_grad(): - self.embed_mode = "audio" - audio_emb = self(waveform.cuda()) - self.embed_mode = "text" - text_emb = self(text) - similarity = F.cosine_similarity(audio_emb, text_emb, dim=2) - self.embed_mode = original_embed_mode - return similarity.squeeze() - - def build_unconditional_emb(self): - self.unconditional_token = self.model.get_text_embedding( - self.tokenizer(["", ""]) - )[0:1] - - def forward(self, batch): - # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0 - # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0 - if self.model.training == True and not self.training_mode: - print( - "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters." - ) - self.model, self.model_cfg = create_model( - self.amodel, - self.tmodel, - self.pretrained, - precision=self.precision, - device="cuda", - enable_fusion=self.enable_fusion, - fusion_type=self.fusion_type, - ) - for p in self.model.parameters(): - p.requires_grad = False - self.model.eval() - - if self.unconditional_token is None: - self.build_unconditional_emb() - - # if(self.training_mode): - # assert self.model.training == True - # else: - # assert self.model.training == False - - # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode - if self.embed_mode == "audio": - if not self.training: - print("INFO: clap model calculate the audio embedding as condition") - with torch.no_grad(): - # assert ( - # self.sampling_rate == 16000 - # ), "We only support 16000 sampling rate" - - # if self.random_mute: - # batch = self._random_mute(batch) - # batch: [bs, 1, t-samples] - if self.sampling_rate != 48000: - batch = torchaudio.functional.resample( - batch, orig_freq=self.sampling_rate, new_freq=48000 - ) - - audio_data = batch.squeeze(1) - mel = self.mel_transform(audio_data) - audio_dict = get_audio_features( - audio_data, - mel, - 480000, - data_truncating="fusion", - data_filling="repeatpad", - audio_cfg=self.model_cfg["audio_cfg"], - ) - # [bs, 512] - embed = self.model.get_audio_embedding(audio_dict) - elif self.embed_mode == "text": - with torch.no_grad(): - # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode - text_data = self.tokenizer(batch) - - if isinstance(batch, str) or ( - isinstance(batch, list) and len(batch) == 1 - ): - for key in text_data.keys(): - text_data[key] = text_data[key].unsqueeze(0) - - embed = self.model.get_text_embedding(text_data) - - embed = 
embed.unsqueeze(1) - for i in range(embed.size(0)): - if self.make_decision(self.unconditional_prob): - embed[i] = self.unconditional_token - # embed = torch.randn((batch.size(0), 1, 512)).type_as(batch) - return embed.detach() - - def tokenizer(self, text): - result = self.tokenize( - text, - padding="max_length", - truncation=True, - max_length=512, - return_tensors="pt", - ) - return {k: v.squeeze(0) for k, v in result.items()} diff --git a/audioldm2/latent_diffusion/modules/phoneme_encoder/__init__.py b/audioldm2/latent_diffusion/modules/phoneme_encoder/__init__.py deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/audioldm2/latent_diffusion/modules/phoneme_encoder/attentions.py b/audioldm2/latent_diffusion/modules/phoneme_encoder/attentions.py deleted file mode 100755 index 3553a688d41b07a45a7ced25f740a55dbc0b6d94..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/phoneme_encoder/attentions.py +++ /dev/null @@ -1,430 +0,0 @@ -import math -import torch -from torch import nn -from torch.nn import functional as F - -import audioldm2.latent_diffusion.modules.phoneme_encoder.commons as commons - -LRELU_SLOPE = 0.1 - - -class LayerNorm(nn.Module): - def __init__(self, channels, eps=1e-5): - super().__init__() - self.channels = channels - self.eps = eps - - self.gamma = nn.Parameter(torch.ones(channels)) - self.beta = nn.Parameter(torch.zeros(channels)) - - def forward(self, x): - x = x.transpose(1, -1) - x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) - return x.transpose(1, -1) - - -class Encoder(nn.Module): - def __init__( - self, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size=1, - p_dropout=0.0, - window_size=4, - **kwargs - ): - super().__init__() - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.window_size = window_size - - self.drop = nn.Dropout(p_dropout) - self.attn_layers = nn.ModuleList() - self.norm_layers_1 = nn.ModuleList() - self.ffn_layers = nn.ModuleList() - self.norm_layers_2 = nn.ModuleList() - for i in range(self.n_layers): - self.attn_layers.append( - MultiHeadAttention( - hidden_channels, - hidden_channels, - n_heads, - p_dropout=p_dropout, - window_size=window_size, - ) - ) - self.norm_layers_1.append(LayerNorm(hidden_channels)) - self.ffn_layers.append( - FFN( - hidden_channels, - hidden_channels, - filter_channels, - kernel_size, - p_dropout=p_dropout, - ) - ) - self.norm_layers_2.append(LayerNorm(hidden_channels)) - - def forward(self, x, x_mask): - attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) - x = x * x_mask - for i in range(self.n_layers): - y = self.attn_layers[i](x, x, attn_mask) - y = self.drop(y) - x = self.norm_layers_1[i](x + y) - - y = self.ffn_layers[i](x, x_mask) - y = self.drop(y) - x = self.norm_layers_2[i](x + y) - x = x * x_mask - return x - - -class Decoder(nn.Module): - def __init__( - self, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size=1, - p_dropout=0.0, - proximal_bias=False, - proximal_init=True, - **kwargs - ): - super().__init__() - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.proximal_bias = proximal_bias - self.proximal_init = proximal_init - 
- self.drop = nn.Dropout(p_dropout) - self.self_attn_layers = nn.ModuleList() - self.norm_layers_0 = nn.ModuleList() - self.encdec_attn_layers = nn.ModuleList() - self.norm_layers_1 = nn.ModuleList() - self.ffn_layers = nn.ModuleList() - self.norm_layers_2 = nn.ModuleList() - for i in range(self.n_layers): - self.self_attn_layers.append( - MultiHeadAttention( - hidden_channels, - hidden_channels, - n_heads, - p_dropout=p_dropout, - proximal_bias=proximal_bias, - proximal_init=proximal_init, - ) - ) - self.norm_layers_0.append(LayerNorm(hidden_channels)) - self.encdec_attn_layers.append( - MultiHeadAttention( - hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout - ) - ) - self.norm_layers_1.append(LayerNorm(hidden_channels)) - self.ffn_layers.append( - FFN( - hidden_channels, - hidden_channels, - filter_channels, - kernel_size, - p_dropout=p_dropout, - causal=True, - ) - ) - self.norm_layers_2.append(LayerNorm(hidden_channels)) - - def forward(self, x, x_mask, h, h_mask): - """ - x: decoder input - h: encoder output - """ - self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( - device=x.device, dtype=x.dtype - ) - encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) - x = x * x_mask - for i in range(self.n_layers): - y = self.self_attn_layers[i](x, x, self_attn_mask) - y = self.drop(y) - x = self.norm_layers_0[i](x + y) - - y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) - y = self.drop(y) - x = self.norm_layers_1[i](x + y) - - y = self.ffn_layers[i](x, x_mask) - y = self.drop(y) - x = self.norm_layers_2[i](x + y) - x = x * x_mask - return x - - -class MultiHeadAttention(nn.Module): - def __init__( - self, - channels, - out_channels, - n_heads, - p_dropout=0.0, - window_size=None, - heads_share=True, - block_length=None, - proximal_bias=False, - proximal_init=False, - ): - super().__init__() - assert channels % n_heads == 0 - - self.channels = channels - self.out_channels = out_channels - self.n_heads = n_heads - self.p_dropout = p_dropout - self.window_size = window_size - self.heads_share = heads_share - self.block_length = block_length - self.proximal_bias = proximal_bias - self.proximal_init = proximal_init - self.attn = None - - self.k_channels = channels // n_heads - self.conv_q = nn.Conv1d(channels, channels, 1) - self.conv_k = nn.Conv1d(channels, channels, 1) - self.conv_v = nn.Conv1d(channels, channels, 1) - self.conv_o = nn.Conv1d(channels, out_channels, 1) - self.drop = nn.Dropout(p_dropout) - - if window_size is not None: - n_heads_rel = 1 if heads_share else n_heads - rel_stddev = self.k_channels**-0.5 - self.emb_rel_k = nn.Parameter( - torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) - * rel_stddev - ) - self.emb_rel_v = nn.Parameter( - torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) - * rel_stddev - ) - - nn.init.xavier_uniform_(self.conv_q.weight) - nn.init.xavier_uniform_(self.conv_k.weight) - nn.init.xavier_uniform_(self.conv_v.weight) - if proximal_init: - with torch.no_grad(): - self.conv_k.weight.copy_(self.conv_q.weight) - self.conv_k.bias.copy_(self.conv_q.bias) - - def forward(self, x, c, attn_mask=None): - q = self.conv_q(x) - k = self.conv_k(c) - v = self.conv_v(c) - - x, self.attn = self.attention(q, k, v, mask=attn_mask) - - x = self.conv_o(x) - return x - - def attention(self, query, key, value, mask=None): - # reshape [b, d, t] -> [b, n_h, t, d_k] - b, d, t_s, t_t = (*key.size(), query.size(2)) - query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) - key = key.view(b, 
self.n_heads, self.k_channels, t_s).transpose(2, 3) - value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) - - scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) - if self.window_size is not None: - assert ( - t_s == t_t - ), "Relative attention is only available for self-attention." - key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) - rel_logits = self._matmul_with_relative_keys( - query / math.sqrt(self.k_channels), key_relative_embeddings - ) - scores_local = self._relative_position_to_absolute_position(rel_logits) - scores = scores + scores_local - if self.proximal_bias: - assert t_s == t_t, "Proximal bias is only available for self-attention." - scores = scores + self._attention_bias_proximal(t_s).to( - device=scores.device, dtype=scores.dtype - ) - if mask is not None: - scores = scores.masked_fill(mask == 0, -1e4) - if self.block_length is not None: - assert ( - t_s == t_t - ), "Local attention is only available for self-attention." - block_mask = ( - torch.ones_like(scores) - .triu(-self.block_length) - .tril(self.block_length) - ) - scores = scores.masked_fill(block_mask == 0, -1e4) - p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] - p_attn = self.drop(p_attn) - output = torch.matmul(p_attn, value) - if self.window_size is not None: - relative_weights = self._absolute_position_to_relative_position(p_attn) - value_relative_embeddings = self._get_relative_embeddings( - self.emb_rel_v, t_s - ) - output = output + self._matmul_with_relative_values( - relative_weights, value_relative_embeddings - ) - output = ( - output.transpose(2, 3).contiguous().view(b, d, t_t) - ) # [b, n_h, t_t, d_k] -> [b, d, t_t] - return output, p_attn - - def _matmul_with_relative_values(self, x, y): - """ - x: [b, h, l, m] - y: [h or 1, m, d] - ret: [b, h, l, d] - """ - ret = torch.matmul(x, y.unsqueeze(0)) - return ret - - def _matmul_with_relative_keys(self, x, y): - """ - x: [b, h, l, d] - y: [h or 1, m, d] - ret: [b, h, l, m] - """ - ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) - return ret - - def _get_relative_embeddings(self, relative_embeddings, length): - 2 * self.window_size + 1 - # Pad first before slice to avoid using cond ops. - pad_length = max(length - (self.window_size + 1), 0) - slice_start_position = max((self.window_size + 1) - length, 0) - slice_end_position = slice_start_position + 2 * length - 1 - if pad_length > 0: - padded_relative_embeddings = F.pad( - relative_embeddings, - commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), - ) - else: - padded_relative_embeddings = relative_embeddings - used_relative_embeddings = padded_relative_embeddings[ - :, slice_start_position:slice_end_position - ] - return used_relative_embeddings - - def _relative_position_to_absolute_position(self, x): - """ - x: [b, h, l, 2*l-1] - ret: [b, h, l, l] - """ - batch, heads, length, _ = x.size() - # Concat columns of pad to shift from relative to absolute indexing. - x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) - - # Concat extra elements so to add up to shape (len+1, 2*len-1). - x_flat = x.view([batch, heads, length * 2 * length]) - x_flat = F.pad( - x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) - ) - - # Reshape and slice out the padded elements. 
- x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ - :, :, :length, length - 1 : - ] - return x_final - - def _absolute_position_to_relative_position(self, x): - """ - x: [b, h, l, l] - ret: [b, h, l, 2*l-1] - """ - batch, heads, length, _ = x.size() - # padd along column - x = F.pad( - x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) - ) - x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) - # add 0's in the beginning that will skew the elements after reshape - x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) - x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] - return x_final - - def _attention_bias_proximal(self, length): - """Bias for self-attention to encourage attention to close positions. - Args: - length: an integer scalar. - Returns: - a Tensor with shape [1, 1, length, length] - """ - r = torch.arange(length, dtype=torch.float32) - diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) - return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) - - -class FFN(nn.Module): - def __init__( - self, - in_channels, - out_channels, - filter_channels, - kernel_size, - p_dropout=0.0, - activation=None, - causal=False, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.activation = activation - self.causal = causal - - if causal: - self.padding = self._causal_padding - else: - self.padding = self._same_padding - - self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) - self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) - self.drop = nn.Dropout(p_dropout) - - def forward(self, x, x_mask): - x = self.conv_1(self.padding(x * x_mask)) - if self.activation == "gelu": - x = x * torch.sigmoid(1.702 * x) - else: - x = torch.relu(x) - x = self.drop(x) - x = self.conv_2(self.padding(x * x_mask)) - return x * x_mask - - def _causal_padding(self, x): - if self.kernel_size == 1: - return x - pad_l = self.kernel_size - 1 - pad_r = 0 - padding = [[0, 0], [0, 0], [pad_l, pad_r]] - x = F.pad(x, commons.convert_pad_shape(padding)) - return x - - def _same_padding(self, x): - if self.kernel_size == 1: - return x - pad_l = (self.kernel_size - 1) // 2 - pad_r = self.kernel_size // 2 - padding = [[0, 0], [0, 0], [pad_l, pad_r]] - x = F.pad(x, commons.convert_pad_shape(padding)) - return x diff --git a/audioldm2/latent_diffusion/modules/phoneme_encoder/commons.py b/audioldm2/latent_diffusion/modules/phoneme_encoder/commons.py deleted file mode 100755 index 9515724c12ab2f856b9a2ec14e38cc63df9b85d6..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/phoneme_encoder/commons.py +++ /dev/null @@ -1,161 +0,0 @@ -import math -import torch -from torch.nn import functional as F - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - - -def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - - -def intersperse(lst, item): - result = [item] * (len(lst) * 2 + 1) - result[1::2] = lst - return result - - -def kl_divergence(m_p, logs_p, m_q, logs_q): - """KL(P||Q)""" - kl = (logs_q - logs_p) - 0.5 - kl += ( - 0.5 * 
(torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) - ) - return kl - - -def rand_gumbel(shape): - """Sample from the Gumbel distribution, protect from overflows.""" - uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 - return -torch.log(-torch.log(uniform_samples)) - - -def rand_gumbel_like(x): - g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) - return g - - -def slice_segments(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, :, idx_str:idx_end] - return ret - - -def rand_slice_segments(x, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size + 1 - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) - ret = slice_segments(x, ids_str, segment_size) - return ret, ids_str - - -def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): - position = torch.arange(length, dtype=torch.float) - num_timescales = channels // 2 - log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( - num_timescales - 1 - ) - inv_timescales = min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment - ) - scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) - signal = F.pad(signal, [0, 0, 0, channels % 2]) - signal = signal.view(1, channels, length) - return signal - - -def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return x + signal.to(dtype=x.dtype, device=x.device) - - -def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) - - -def subsequent_mask(length): - mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) - return mask - - -@torch.jit.script -def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): - n_channels_int = n_channels[0] - in_act = input_a + input_b - t_act = torch.tanh(in_act[:, :n_channels_int, :]) - s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) - acts = t_act * s_act - return acts - - -def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - - -def shift_1d(x): - x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] - return x - - -def sequence_mask(length, max_length=None): - if max_length is None: - max_length = length.max() - x = torch.arange(max_length, dtype=length.dtype, device=length.device) - return x.unsqueeze(0) < length.unsqueeze(1) - - -def generate_path(duration, mask): - """ - duration: [b, 1, t_x] - mask: [b, 1, t_y, t_x] - """ - duration.device - - b, _, t_y, t_x = mask.shape - cum_duration = torch.cumsum(duration, -1) - - cum_duration_flat = cum_duration.view(b * t_x) - path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) - path = path.view(b, t_x, t_y) - path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] - path = path.unsqueeze(1).transpose(2, 3) * mask - return path - - -def clip_grad_value_(parameters, 
clip_value, norm_type=2): - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = list(filter(lambda p: p.grad is not None, parameters)) - norm_type = float(norm_type) - if clip_value is not None: - clip_value = float(clip_value) - - total_norm = 0 - for p in parameters: - param_norm = p.grad.data.norm(norm_type) - total_norm += param_norm.item() ** norm_type - if clip_value is not None: - p.grad.data.clamp_(min=-clip_value, max=clip_value) - total_norm = total_norm ** (1.0 / norm_type) - return total_norm diff --git a/audioldm2/latent_diffusion/modules/phoneme_encoder/encoder.py b/audioldm2/latent_diffusion/modules/phoneme_encoder/encoder.py deleted file mode 100755 index b39bf583b5ea88a4771181e491c8deb92b2d7559..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/modules/phoneme_encoder/encoder.py +++ /dev/null @@ -1,50 +0,0 @@ -import math -import torch -from torch import nn - -import audioldm2.latent_diffusion.modules.phoneme_encoder.commons as commons -import audioldm2.latent_diffusion.modules.phoneme_encoder.attentions as attentions - - -class TextEncoder(nn.Module): - def __init__( - self, - n_vocab, - out_channels=192, - hidden_channels=192, - filter_channels=768, - n_heads=2, - n_layers=6, - kernel_size=3, - p_dropout=0.1, - ): - super().__init__() - self.n_vocab = n_vocab - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - - self.emb = nn.Embedding(n_vocab, hidden_channels) - nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) - - self.encoder = attentions.Encoder( - hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout - ) - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward(self, x, x_lengths): - x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] - x = torch.transpose(x, 1, -1) # [b, h, t] - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( - x.dtype - ) - - x = self.encoder(x * x_mask, x_mask) - stats = self.proj(x) * x_mask - - m, logs = torch.split(stats, self.out_channels, dim=1) - return x, m, logs, x_mask diff --git a/audioldm2/latent_diffusion/util.py b/audioldm2/latent_diffusion/util.py deleted file mode 100755 index 3dd301b1c0a39a5b905aa23f4b98d224df7d87d9..0000000000000000000000000000000000000000 --- a/audioldm2/latent_diffusion/util.py +++ /dev/null @@ -1,217 +0,0 @@ -import importlib - -import torch -import numpy as np -from collections import abc - -import multiprocessing as mp -from threading import Thread -from queue import Queue - -from inspect import isfunction -from PIL import Image, ImageDraw, ImageFont - - -def log_txt_as_img(wh, xc, size=10): - # wh a tuple of (width, height) - # xc a list of captions to plot - b = len(xc) - txts = list() - for bi in range(b): - txt = Image.new("RGB", wh, color="white") - draw = ImageDraw.Draw(txt) - font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) - nc = int(40 * (wh[0] / 256)) - lines = "\n".join( - xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) - ) - - try: - draw.text((0, 0), lines, fill="black", font=font) - except UnicodeEncodeError: - print("Cant encode string for logging. 
Skipping.") - - txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 - txts.append(txt) - txts = np.stack(txts) - txts = torch.tensor(txts) - return txts - - -def ismap(x): - if not isinstance(x, torch.Tensor): - return False - return (len(x.shape) == 4) and (x.shape[1] > 3) - - -def isimage(x): - if not isinstance(x, torch.Tensor): - return False - return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) - - -def int16_to_float32(x): - return (x / 32767.0).astype(np.float32) - - -def float32_to_int16(x): - x = np.clip(x, a_min=-1.0, a_max=1.0) - return (x * 32767.0).astype(np.int16) - - -def exists(x): - return x is not None - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def mean_flat(tensor): - """ - https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def count_params(model, verbose=False): - total_params = sum(p.numel() for p in model.parameters()) - if verbose: - print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") - return total_params - - -def instantiate_from_config(config): - if not "target" in config: - if config == "__is_first_stage__": - return None - elif config == "__is_unconditional__": - return None - raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict())) - - -def get_obj_from_str(string, reload=False): - module, cls = string.rsplit(".", 1) - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) - - -def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): - # create dummy dataset instance - - # run prefetching - if idx_to_fn: - res = func(data, worker_id=idx) - else: - res = func(data) - Q.put([idx, res]) - Q.put("Done") - - -def parallel_data_prefetch( - func: callable, - data, - n_proc, - target_data_type="ndarray", - cpu_intensive=True, - use_worker_id=False, -): - # if target_data_type not in ["ndarray", "list"]: - # raise ValueError( - # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." - # ) - if isinstance(data, np.ndarray) and target_data_type == "list": - raise ValueError("list expected but function got ndarray.") - elif isinstance(data, abc.Iterable): - if isinstance(data, dict): - print( - f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' - ) - data = list(data.values()) - if target_data_type == "ndarray": - data = np.asarray(data) - else: - data = list(data) - else: - raise TypeError( - f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
- ) - - if cpu_intensive: - Q = mp.Queue(1000) - proc = mp.Process - else: - Q = Queue(1000) - proc = Thread - # spawn processes - if target_data_type == "ndarray": - arguments = [ - [func, Q, part, i, use_worker_id] - for i, part in enumerate(np.array_split(data, n_proc)) - ] - else: - step = ( - int(len(data) / n_proc + 1) - if len(data) % n_proc != 0 - else int(len(data) / n_proc) - ) - arguments = [ - [func, Q, part, i, use_worker_id] - for i, part in enumerate( - [data[i : i + step] for i in range(0, len(data), step)] - ) - ] - processes = [] - for i in range(n_proc): - p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) - processes += [p] - - # start processes - print(f"Start prefetching...") - import time - - start = time.time() - gather_res = [[] for _ in range(n_proc)] - try: - for p in processes: - p.start() - - k = 0 - while k < n_proc: - # get result - res = Q.get() - if res == "Done": - k += 1 - else: - gather_res[res[0]] = res[1] - - except Exception as e: - print("Exception: ", e) - for p in processes: - p.terminate() - - raise e - finally: - for p in processes: - p.join() - print(f"Prefetching complete. [{time.time() - start} sec.]") - - if target_data_type == "ndarray": - if not isinstance(gather_res[0], np.ndarray): - return np.concatenate([np.asarray(r) for r in gather_res], axis=0) - - # order outputs - return np.concatenate(gather_res, axis=0) - elif target_data_type == "list": - out = [] - for r in gather_res: - out.extend(r) - return out - else: - return gather_res diff --git a/audioldm2/latent_encoder/__init__.py b/audioldm2/latent_encoder/__init__.py deleted file mode 100755 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/audioldm2/latent_encoder/autoencoder.py b/audioldm2/latent_encoder/autoencoder.py deleted file mode 100755 index f07075bb76a34edd8568797961752e4957129f92..0000000000000000000000000000000000000000 --- a/audioldm2/latent_encoder/autoencoder.py +++ /dev/null @@ -1,326 +0,0 @@ -import torch -import os - -import torch.nn.functional as F -import numpy as np -from audioldm2.latent_diffusion.modules.ema import * - -from audioldm2.latent_diffusion.modules.diffusionmodules.model import Encoder, Decoder -from audioldm2.latent_diffusion.modules.distributions.distributions import ( - DiagonalGaussianDistribution, -) -import soundfile as sf - -from audioldm2.utilities.model import get_vocoder -from audioldm2.utilities.tools import synth_one_sample - - -class AutoencoderKL(nn.Module): - def __init__( - self, - ddconfig=None, - lossconfig=None, - batchsize=None, - embed_dim=None, - time_shuffle=1, - subband=1, - sampling_rate=16000, - ckpt_path=None, - reload_from_ckpt=None, - ignore_keys=[], - image_key="fbank", - colorize_nlabels=None, - monitor=None, - base_learning_rate=1e-5, - ): - super().__init__() - self.automatic_optimization = False - assert ( - "mel_bins" in ddconfig.keys() - ), "mel_bins is not specified in the Autoencoder config" - num_mel = ddconfig["mel_bins"] - self.image_key = image_key - self.sampling_rate = sampling_rate - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - - self.loss = None - self.subband = int(subband) - - if self.subband > 1: - print("Use subband decomposition %s" % self.subband) - - assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - - if self.image_key == "fbank": - self.vocoder = get_vocoder(None, 
"cpu", num_mel) - self.embed_dim = embed_dim - if colorize_nlabels is not None: - assert type(colorize_nlabels) == int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) - if monitor is not None: - self.monitor = monitor - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - self.learning_rate = float(base_learning_rate) - # print("Initial learning rate %s" % self.learning_rate) - - self.time_shuffle = time_shuffle - self.reload_from_ckpt = reload_from_ckpt - self.reloaded = False - self.mean, self.std = None, None - - self.feature_cache = None - self.flag_first_run = True - self.train_step = 0 - - self.logger_save_dir = None - self.logger_exp_name = None - - def get_log_dir(self): - if self.logger_save_dir is None and self.logger_exp_name is None: - return os.path.join(self.logger.save_dir, self.logger._project) - else: - return os.path.join(self.logger_save_dir, self.logger_exp_name) - - def set_log_dir(self, save_dir, exp_name): - self.logger_save_dir = save_dir - self.logger_exp_name = exp_name - - def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - self.load_state_dict(sd, strict=False) - print(f"Restored from {path}") - - def encode(self, x): - # x = self.time_shuffle_operation(x) - # x = self.freq_split_subband(x) - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior - - def decode(self, z): - z = self.post_quant_conv(z) - dec = self.decoder(z) - # bs, ch, shuffled_timesteps, fbins = dec.size() - # dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins) - # dec = self.freq_merge_subband(dec) - return dec - - def decode_to_waveform(self, dec): - from audioldm2.utilities.model import vocoder_infer - - if self.image_key == "fbank": - dec = dec.squeeze(1).permute(0, 2, 1) - wav_reconstruction = vocoder_infer(dec, self.vocoder) - elif self.image_key == "stft": - dec = dec.squeeze(1).permute(0, 2, 1) - wav_reconstruction = self.wave_decoder(dec) - return wav_reconstruction - - def visualize_latent(self, input): - import matplotlib.pyplot as plt - - # for i in range(10): - # zero_input = torch.zeros_like(input) - 11.59 - # zero_input[:,:,i * 16: i * 16 + 16,:16] += 13.59 - - # posterior = self.encode(zero_input) - # latent = posterior.sample() - # avg_latent = torch.mean(latent, dim=1)[0] - # plt.imshow(avg_latent.cpu().detach().numpy().T) - # plt.savefig("%s.png" % i) - # plt.close() - - np.save("input.npy", input.cpu().detach().numpy()) - # zero_input = torch.zeros_like(input) - 11.59 - time_input = input.clone() - time_input[:, :, :, :32] *= 0 - time_input[:, :, :, :32] -= 11.59 - - np.save("time_input.npy", time_input.cpu().detach().numpy()) - - posterior = self.encode(time_input) - latent = posterior.sample() - np.save("time_latent.npy", latent.cpu().detach().numpy()) - avg_latent = torch.mean(latent, dim=1) - for i in range(avg_latent.size(0)): - plt.imshow(avg_latent[i].cpu().detach().numpy().T) - plt.savefig("freq_%s.png" % i) - plt.close() - - freq_input = input.clone() - freq_input[:, :, :512, :] *= 0 - freq_input[:, :, :512, :] -= 11.59 - - np.save("freq_input.npy", freq_input.cpu().detach().numpy()) - - posterior = self.encode(freq_input) - latent = posterior.sample() - np.save("freq_latent.npy", 
latent.cpu().detach().numpy()) - avg_latent = torch.mean(latent, dim=1) - for i in range(avg_latent.size(0)): - plt.imshow(avg_latent[i].cpu().detach().numpy().T) - plt.savefig("time_%s.png" % i) - plt.close() - - def get_input(self, batch): - fname, text, label_indices, waveform, stft, fbank = ( - batch["fname"], - batch["text"], - batch["label_vector"], - batch["waveform"], - batch["stft"], - batch["log_mel_spec"], - ) - # if(self.time_shuffle != 1): - # if(fbank.size(1) % self.time_shuffle != 0): - # pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle) - # fbank = torch.nn.functional.pad(fbank, (0,0,0,pad_len)) - - ret = {} - - ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = ( - fbank.unsqueeze(1), - stft.unsqueeze(1), - fname, - waveform.unsqueeze(1), - ) - - return ret - - def save_wave(self, batch_wav, fname, save_dir): - os.makedirs(save_dir, exist_ok=True) - - for wav, name in zip(batch_wav, fname): - name = os.path.basename(name) - - sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate) - - def get_last_layer(self): - return self.decoder.conv_out.weight - - @torch.no_grad() - def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs): - log = dict() - x = batch.to(self.device) - if not only_inputs: - xrec, posterior = self(x) - log["samples"] = self.decode(posterior.sample()) - log["reconstructions"] = xrec - - log["inputs"] = x - wavs = self._log_img(log, train=train, index=0, waveform=waveform) - return wavs - - def _log_img(self, log, train=True, index=0, waveform=None): - images_input = self.tensor2numpy(log["inputs"][index, 0]).T - images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T - images_samples = self.tensor2numpy(log["samples"][index, 0]).T - - if train: - name = "train" - else: - name = "val" - - if self.logger is not None: - self.logger.log_image( - "img_%s" % name, - [images_input, images_reconstruct, images_samples], - caption=["input", "reconstruct", "samples"], - ) - - inputs, reconstructions, samples = ( - log["inputs"], - log["reconstructions"], - log["samples"], - ) - - if self.image_key == "fbank": - wav_original, wav_prediction = synth_one_sample( - inputs[index], - reconstructions[index], - labels="validation", - vocoder=self.vocoder, - ) - wav_original, wav_samples = synth_one_sample( - inputs[index], samples[index], labels="validation", vocoder=self.vocoder - ) - wav_original, wav_samples, wav_prediction = ( - wav_original[0], - wav_samples[0], - wav_prediction[0], - ) - elif self.image_key == "stft": - wav_prediction = ( - self.decode_to_waveform(reconstructions)[index, 0] - .cpu() - .detach() - .numpy() - ) - wav_samples = ( - self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy() - ) - wav_original = waveform[index, 0].cpu().detach().numpy() - - if self.logger is not None: - self.logger.experiment.log( - { - "original_%s" - % name: wandb.Audio( - wav_original, caption="original", sample_rate=self.sampling_rate - ), - "reconstruct_%s" - % name: wandb.Audio( - wav_prediction, - caption="reconstruct", - sample_rate=self.sampling_rate, - ), - "samples_%s" - % name: wandb.Audio( - wav_samples, caption="samples", sample_rate=self.sampling_rate - ), - } - ) - - return wav_original, wav_prediction, wav_samples - - def tensor2numpy(self, tensor): - return tensor.cpu().detach().numpy() - - def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 
- x = F.conv2d(x, weight=self.colorize) - x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 - return x - - -class IdentityFirstStage(torch.nn.Module): - def __init__(self, *args, vq_interface=False, **kwargs): - self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff - super().__init__() - - def encode(self, x, *args, **kwargs): - return x - - def decode(self, x, *args, **kwargs): - return x - - def quantize(self, x, *args, **kwargs): - if self.vq_interface: - return x, None, [None, None, None] - return x - - def forward(self, x, *args, **kwargs): - return x diff --git a/audioldm2/pipeline.py b/audioldm2/pipeline.py deleted file mode 100755 index 1eec55b0198049f8baf263c3b80a7a8a0584ebeb..0000000000000000000000000000000000000000 --- a/audioldm2/pipeline.py +++ /dev/null @@ -1,201 +0,0 @@ -import os - -import yaml -import torch -import torchaudio - -from audioldm2.latent_diffusion.models.ddpm import LatentDiffusion -from audioldm2.utils import default_audioldm_config, get_metadata, download_checkpoint -from audioldm2.utilities.audio import read_wav_file -import os - -CACHE_DIR = os.getenv( - "AUDIOLDM_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache/audioldm2") -) - - -def seed_everything(seed): - import random, os - import numpy as np - import torch - - random.seed(seed) - os.environ["PYTHONHASHSEED"] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = True - - -def text_to_filename(text): - return text.replace(" ", "_").replace("'", "_").replace('"', "_") - - -def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec): - norm_mean = -4.2677393 - norm_std = 4.5689974 - - if sampling_rate != 16000: - waveform_16k = torchaudio.functional.resample( - waveform, orig_freq=sampling_rate, new_freq=16000 - ) - else: - waveform_16k = waveform - - waveform_16k = waveform_16k - waveform_16k.mean() - fbank = torchaudio.compliance.kaldi.fbank( - waveform_16k, - htk_compat=True, - sample_frequency=16000, - use_energy=False, - window_type="hanning", - num_mel_bins=128, - dither=0.0, - frame_shift=10, - ) - - TARGET_LEN = log_mel_spec.size(0) - - # cut and pad - n_frames = fbank.shape[0] - p = TARGET_LEN - n_frames - if p > 0: - m = torch.nn.ZeroPad2d((0, 0, 0, p)) - fbank = m(fbank) - elif p < 0: - fbank = fbank[:TARGET_LEN, :] - - fbank = (fbank - norm_mean) / (norm_std * 2) - - return {"ta_kaldi_fbank": fbank} # [1024, 128] - - -def make_batch_for_text_to_audio(text, waveform=None, fbank=None, batchsize=1): - text = [text] * batchsize - if batchsize < 1: - print("Warning: Batchsize must be at least 1. 
Batchsize is set to .") - - if fbank is None: - fbank = torch.zeros( - (batchsize, 1024, 64) - ) # Not used, here to keep the code format - else: - fbank = torch.FloatTensor(fbank) - fbank = fbank.expand(batchsize, 1024, 64) - assert fbank.size(0) == batchsize - - stft = torch.zeros((batchsize, 1024, 512)) # Not used - - if waveform is None: - waveform = torch.zeros((batchsize, 160000)) # Not used - ta_kaldi_fbank = torch.zeros((batchsize, 1024, 128)) - else: - waveform = torch.FloatTensor(waveform) - waveform = waveform.expand(batchsize, -1) - assert waveform.size(0) == batchsize - ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, 16000, fbank) - - batch = { - "text": text, # list - "fname": [text_to_filename(t) for t in text], # list - "waveform": waveform, - "stft": stft, - "log_mel_spec": fbank, - "ta_kaldi_fbank": ta_kaldi_fbank, - } - - return batch - - -def round_up_duration(duration): - return int(round(duration / 2.5) + 1) * 2.5 - - -def split_clap_weight_to_pth(checkpoint): - if os.path.exists(os.path.join(CACHE_DIR, "clap.pth")): - return - print("Constructing the weight for the CLAP model.") - include_keys = "cond_stage_models.0.cond_stage_models.0.model." - new_state_dict = {} - for each in checkpoint["state_dict"].keys(): - if include_keys in each: - new_state_dict[each.replace(include_keys, "module.")] = checkpoint[ - "state_dict" - ][each] - torch.save({"state_dict": new_state_dict}, os.path.join(CACHE_DIR, "clap.pth")) - - -def build_model(ckpt_path=None, config=None, model_name="audioldm2-full"): - print("Loading AudioLDM-2: %s" % model_name) - - if ckpt_path is None: - ckpt_path = get_metadata()[model_name]["path"] - - if not os.path.exists(ckpt_path): - download_checkpoint(model_name) - - if torch.cuda.is_available(): - device = torch.device("cuda:0") - else: - device = torch.device("cpu") - - if config is not None: - assert type(config) is str - config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) - else: - config = default_audioldm_config(model_name) - - # # Use text as condition instead of using waveform during training - config["model"]["params"]["device"] = device - # config["model"]["params"]["cond_stage_key"] = "text" - - # No normalization here - latent_diffusion = LatentDiffusion(**config["model"]["params"]) - - resume_from_checkpoint = ckpt_path - - checkpoint = torch.load(resume_from_checkpoint, map_location=device) - - latent_diffusion.load_state_dict(checkpoint["state_dict"]) - - latent_diffusion.eval() - latent_diffusion = latent_diffusion.to(device) - - return latent_diffusion - -def duration_to_latent_t_size(duration): - return int(duration * 25.6) - -def text_to_audio( - latent_diffusion, - text, - seed=42, - ddim_steps=200, - duration=10, - batchsize=1, - guidance_scale=3.5, - n_candidate_gen_per_text=3, - config=None, -): - assert ( - duration == 10 - ), "Error: Currently we only support 10 seconds of generation. Generating longer files requires some extra coding, which would be a part of the future work." 
- - seed_everything(int(seed)) - waveform = None - - batch = make_batch_for_text_to_audio(text, waveform=waveform, batchsize=batchsize) - - latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) - - with torch.no_grad(): - waveform = latent_diffusion.generate_batch( - batch, - unconditional_guidance_scale=guidance_scale, - ddim_steps=ddim_steps, - n_gen=n_candidate_gen_per_text, - duration=duration, - ) - - return waveform diff --git a/audioldm2/utilities/__init__.py b/audioldm2/utilities/__init__.py deleted file mode 100755 index 495e8fe675337df0afacd3a31d06d0241b6b0e63..0000000000000000000000000000000000000000 --- a/audioldm2/utilities/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .tools import * -from .data import * -from .model import * diff --git a/audioldm2/utilities/audio/__init__.py b/audioldm2/utilities/audio/__init__.py deleted file mode 100755 index c39f9243d2d7b4fc5dea18f56b153b0f5c5bbd4c..0000000000000000000000000000000000000000 --- a/audioldm2/utilities/audio/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .audio_processing import * -from .stft import * -from .tools import * diff --git a/audioldm2/utilities/audio/audio_processing.py b/audioldm2/utilities/audio/audio_processing.py deleted file mode 100755 index 77a4057aa82f226f68474f4c2a19eba84510d663..0000000000000000000000000000000000000000 --- a/audioldm2/utilities/audio/audio_processing.py +++ /dev/null @@ -1,100 +0,0 @@ -import torch -import numpy as np -import librosa.util as librosa_util -from scipy.signal import get_window - - -def window_sumsquare( - window, - n_frames, - hop_length, - win_length, - n_fft, - dtype=np.float32, - norm=None, -): - """ - # from librosa 0.6 - Compute the sum-square envelope of a window function at a given hop length. - - This is used to estimate modulation effects induced by windowing - observations in short-time fourier transforms. - - Parameters - ---------- - window : string, tuple, number, callable, or list-like - Window specification, as in `get_window` - - n_frames : int > 0 - The number of analysis frames - - hop_length : int > 0 - The number of samples to advance between frames - - win_length : [optional] - The length of the window function. By default, this matches `n_fft`. - - n_fft : int > 0 - The length of each analysis frame. 
- - dtype : np.dtype - The data type of the output - - Returns - ------- - wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` - The sum-squared envelope of the window function - """ - if win_length is None: - win_length = n_fft - - n = n_fft + hop_length * (n_frames - 1) - x = np.zeros(n, dtype=dtype) - - # Compute the squared window at the desired length - win_sq = get_window(window, win_length, fftbins=True) - win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 - win_sq = librosa_util.pad_center(win_sq, n_fft) - - # Fill the envelope - for i in range(n_frames): - sample = i * hop_length - x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] - return x - - -def griffin_lim(magnitudes, stft_fn, n_iters=30): - """ - PARAMS - ------ - magnitudes: spectrogram magnitudes - stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods - """ - - angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) - angles = angles.astype(np.float32) - angles = torch.autograd.Variable(torch.from_numpy(angles)) - signal = stft_fn.inverse(magnitudes, angles).squeeze(1) - - for i in range(n_iters): - _, angles = stft_fn.transform(signal) - signal = stft_fn.inverse(magnitudes, angles).squeeze(1) - return signal - - -def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): - """ - PARAMS - ------ - C: compression factor - """ - return normalize_fun(torch.clamp(x, min=clip_val) * C) - - -def dynamic_range_decompression(x, C=1): - """ - PARAMS - ------ - C: compression factor used to compress - """ - return torch.exp(x) / C diff --git a/audioldm2/utilities/audio/stft.py b/audioldm2/utilities/audio/stft.py deleted file mode 100755 index 508f33674e6dd8a5557205c8e77e07955df13a87..0000000000000000000000000000000000000000 --- a/audioldm2/utilities/audio/stft.py +++ /dev/null @@ -1,178 +0,0 @@ -import torch -import torch.nn.functional as F -import numpy as np -from scipy.signal import get_window -from librosa.util import pad_center, tiny -from librosa.filters import mel as librosa_mel_fn - -from audioldm2.utilities.audio.audio_processing import ( - dynamic_range_compression, - dynamic_range_decompression, - window_sumsquare, -) - - -class STFT(torch.nn.Module): - """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" - - def __init__(self, filter_length, hop_length, win_length, window="hann"): - super(STFT, self).__init__() - self.filter_length = filter_length - self.hop_length = hop_length - self.win_length = win_length - self.window = window - self.forward_transform = None - scale = self.filter_length / self.hop_length - fourier_basis = np.fft.fft(np.eye(self.filter_length)) - - cutoff = int((self.filter_length / 2 + 1)) - fourier_basis = np.vstack( - [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] - ) - - forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) - inverse_basis = torch.FloatTensor( - np.linalg.pinv(scale * fourier_basis).T[:, None, :] - ) - - if window is not None: - assert filter_length >= win_length - # get window and zero center pad it to filter_length - fft_window = get_window(window, win_length, fftbins=True) - fft_window = pad_center(fft_window, filter_length) - fft_window = torch.from_numpy(fft_window).float() - - # window the bases - forward_basis *= fft_window - inverse_basis *= fft_window - - self.register_buffer("forward_basis", forward_basis.float()) - self.register_buffer("inverse_basis", inverse_basis.float()) - - def transform(self, input_data): - 
num_batches = input_data.size(0) - num_samples = input_data.size(1) - - self.num_samples = num_samples - - # similar to librosa, reflect-pad the input - input_data = input_data.view(num_batches, 1, num_samples) - input_data = F.pad( - input_data.unsqueeze(1), - (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), - mode="reflect", - ) - input_data = input_data.squeeze(1) - - forward_transform = F.conv1d( - input_data, - torch.autograd.Variable(self.forward_basis, requires_grad=False), - stride=self.hop_length, - padding=0, - ).cpu() - - cutoff = int((self.filter_length / 2) + 1) - real_part = forward_transform[:, :cutoff, :] - imag_part = forward_transform[:, cutoff:, :] - - magnitude = torch.sqrt(real_part**2 + imag_part**2) - phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) - - return magnitude, phase - - def inverse(self, magnitude, phase): - recombine_magnitude_phase = torch.cat( - [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 - ) - - inverse_transform = F.conv_transpose1d( - recombine_magnitude_phase, - torch.autograd.Variable(self.inverse_basis, requires_grad=False), - stride=self.hop_length, - padding=0, - ) - - if self.window is not None: - window_sum = window_sumsquare( - self.window, - magnitude.size(-1), - hop_length=self.hop_length, - win_length=self.win_length, - n_fft=self.filter_length, - dtype=np.float32, - ) - # remove modulation effects - approx_nonzero_indices = torch.from_numpy( - np.where(window_sum > tiny(window_sum))[0] - ) - window_sum = torch.autograd.Variable( - torch.from_numpy(window_sum), requires_grad=False - ) - window_sum = window_sum - inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ - approx_nonzero_indices - ] - - # scale by hop ratio - inverse_transform *= float(self.filter_length) / self.hop_length - - inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] - inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] - - return inverse_transform - - def forward(self, input_data): - self.magnitude, self.phase = self.transform(input_data) - reconstruction = self.inverse(self.magnitude, self.phase) - return reconstruction - - -class TacotronSTFT(torch.nn.Module): - def __init__( - self, - filter_length, - hop_length, - win_length, - n_mel_channels, - sampling_rate, - mel_fmin, - mel_fmax, - ): - super(TacotronSTFT, self).__init__() - self.n_mel_channels = n_mel_channels - self.sampling_rate = sampling_rate - self.stft_fn = STFT(filter_length, hop_length, win_length) - mel_basis = librosa_mel_fn( - sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax - ) - mel_basis = torch.from_numpy(mel_basis).float() - self.register_buffer("mel_basis", mel_basis) - - def spectral_normalize(self, magnitudes, normalize_fun): - output = dynamic_range_compression(magnitudes, normalize_fun) - return output - - def spectral_de_normalize(self, magnitudes): - output = dynamic_range_decompression(magnitudes) - return output - - def mel_spectrogram(self, y, normalize_fun=torch.log): - """Computes mel-spectrograms from a batch of waves - PARAMS - ------ - y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] - - RETURNS - ------- - mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) - """ - assert torch.min(y.data) >= -1, torch.min(y.data) - assert torch.max(y.data) <= 1, torch.max(y.data) - - magnitudes, phases = self.stft_fn.transform(y) - magnitudes = magnitudes.data - mel_output = torch.matmul(self.mel_basis, magnitudes) - 
mel_output = self.spectral_normalize(mel_output, normalize_fun) - energy = torch.norm(magnitudes, dim=1) - - return mel_output, magnitudes, phases, energy diff --git a/audioldm2/utilities/audio/tools.py b/audioldm2/utilities/audio/tools.py deleted file mode 100755 index 8c666a7c67e0ae93edbad666520fd2e98fd29d18..0000000000000000000000000000000000000000 --- a/audioldm2/utilities/audio/tools.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch -import numpy as np -from scipy.io.wavfile import write -import torchaudio - -from audioldm2.utilities.audio.audio_processing import griffin_lim - - -def pad_wav(waveform, segment_length): - waveform_length = waveform.shape[-1] - assert waveform_length > 100, "Waveform is too short, %s" % waveform_length - if segment_length is None or waveform_length == segment_length: - return waveform - elif waveform_length > segment_length: - return waveform[:segment_length] - elif waveform_length < segment_length: - temp_wav = np.zeros((1, segment_length)) - temp_wav[:, :waveform_length] = waveform - return temp_wav - - -def normalize_wav(waveform): - waveform = waveform - np.mean(waveform) - waveform = waveform / (np.max(np.abs(waveform)) + 1e-8) - return waveform * 0.5 - - -def read_wav_file(filename, segment_length): - # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower - waveform, sr = torchaudio.load(filename) # Faster!!! - waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000) - waveform = waveform.numpy()[0, ...] - waveform = normalize_wav(waveform) - waveform = waveform[None, ...] - waveform = pad_wav(waveform, segment_length) - - waveform = waveform / np.max(np.abs(waveform)) - waveform = 0.5 * waveform - - return waveform - - -def get_mel_from_wav(audio, _stft): - audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) - audio = torch.autograd.Variable(audio, requires_grad=False) - melspec, magnitudes, phases, energy = _stft.mel_spectrogram(audio) - melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) - magnitudes = torch.squeeze(magnitudes, 0).numpy().astype(np.float32) - energy = torch.squeeze(energy, 0).numpy().astype(np.float32) - return melspec, magnitudes, energy - - -def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60): - mel = torch.stack([mel]) - mel_decompress = _stft.spectral_de_normalize(mel) - mel_decompress = mel_decompress.transpose(1, 2).data.cpu() - spec_from_mel_scaling = 1000 - spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) - spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) - spec_from_mel = spec_from_mel * spec_from_mel_scaling - - audio = griffin_lim( - torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters - ) - - audio = audio.squeeze() - audio = audio.cpu().numpy() - audio_path = out_filename - write(audio_path, _stft.sampling_rate, audio) diff --git a/audioldm2/utilities/data/__init__.py b/audioldm2/utilities/data/__init__.py deleted file mode 100755 index 13a9804e72b88e3b9078940aee87db73788c1fb5..0000000000000000000000000000000000000000 --- a/audioldm2/utilities/data/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .dataset import Dataset diff --git a/audioldm2/utilities/data/add_on.py b/audioldm2/utilities/data/add_on.py deleted file mode 100755 index 4cfc6297e2f66759077c1540fc04b19560f3659c..0000000000000000000000000000000000000000 --- a/audioldm2/utilities/data/add_on.py +++ /dev/null @@ -1,508 +0,0 @@ -import os -import torch -import numpy as np -import torchaudio -import matplotlib.pyplot as plt - -CACHE = 
{ - "get_vits_phoneme_ids": { - "PAD_LENGTH": 310, - "_pad": "_", - "_punctuation": ';:,.!?¡¿—…"«»“” ', - "_letters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", - "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ", - "_special": "♪☎☒☝⚠", - } -} - -CACHE["get_vits_phoneme_ids"]["symbols"] = ( - [CACHE["get_vits_phoneme_ids"]["_pad"]] - + list(CACHE["get_vits_phoneme_ids"]["_punctuation"]) - + list(CACHE["get_vits_phoneme_ids"]["_letters"]) - + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"]) - + list(CACHE["get_vits_phoneme_ids"]["_special"]) -) -CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = { - s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"]) -} - - -def get_vits_phoneme_ids(config, dl_output, metadata): - pad_token_id = 0 - pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"] - _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] - - assert ( - "phonemes" in metadata.keys() - ), "You must provide vits phonemes on using addon get_vits_phoneme_ids" - clean_text = metadata["phonemes"] - sequence = [] - - for symbol in clean_text: - symbol_id = _symbol_to_id[symbol] - sequence += [symbol_id] - - inserted_zero_sequence = [0] * (len(sequence) * 2) - inserted_zero_sequence[1::2] = sequence - inserted_zero_sequence = inserted_zero_sequence + [0] - - def _pad_phonemes(phonemes_list): - return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list)) - - return {"phoneme_idx": torch.LongTensor(_pad_phonemes(inserted_zero_sequence))} - - -def get_vits_phoneme_ids_no_padding(config, dl_output, metadata): - pad_token_id = 0 - pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"] - _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] - - assert ( - "phonemes" in metadata.keys() - ), "You must provide vits phonemes on using addon get_vits_phoneme_ids" - clean_text = metadata["phonemes"] + "⚠" - sequence = [] - - for symbol in clean_text: - if symbol not in _symbol_to_id.keys(): - print("%s is not in the vocabulary. 
%s" % (symbol, clean_text)) - symbol = "_" - symbol_id = _symbol_to_id[symbol] - sequence += [symbol_id] - - def _pad_phonemes(phonemes_list): - return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list)) - - sequence = sequence[:pad_length] - - return {"phoneme_idx": torch.LongTensor(_pad_phonemes(sequence))} - - -def calculate_relative_bandwidth(config, dl_output, metadata): - assert "stft" in dl_output.keys() - - # The last dimension of the stft feature is the frequency dimension - freq_dimensions = dl_output["stft"].size(-1) - - freq_energy_dist = torch.sum(dl_output["stft"], dim=0) - freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0) - total_energy = freq_energy_dist[-1] - - percentile_5th = total_energy * 0.05 - percentile_95th = total_energy * 0.95 - - lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist)) - higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist)) - - lower_idx = int((lower_idx / freq_dimensions) * 1000) - higher_idx = int((higher_idx / freq_dimensions) * 1000) - - return {"freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx])} - - -def calculate_mel_spec_relative_bandwidth_as_extra_channel(config, dl_output, metadata): - assert "stft" in dl_output.keys() - linear_mel_spec = torch.exp(torch.clip(dl_output["log_mel_spec"], max=10)) - - # The last dimension of the stft feature is the frequency dimension - freq_dimensions = linear_mel_spec.size(-1) - freq_energy_dist = torch.sum(linear_mel_spec, dim=0) - freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0) - total_energy = freq_energy_dist[-1] - - percentile_5th = total_energy * 0.05 - percentile_95th = total_energy * 0.95 - - lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist)) - higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist)) - - latent_t_size = config["model"]["params"]["latent_t_size"] - latent_f_size = config["model"]["params"]["latent_f_size"] - - lower_idx = int(latent_f_size * float((lower_idx / freq_dimensions))) - higher_idx = int(latent_f_size * float((higher_idx / freq_dimensions))) - - bandwidth_condition = torch.zeros((latent_t_size, latent_f_size)) - bandwidth_condition[:, lower_idx:higher_idx] += 1.0 - - return { - "mel_spec_bandwidth_cond_extra_channel": bandwidth_condition, - "freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx]), - } - - -def waveform_rs_48k(config, dl_output, metadata): - waveform = dl_output["waveform"] # [1, samples] - sampling_rate = dl_output["sampling_rate"] - - if sampling_rate != 48000: - waveform_48k = torchaudio.functional.resample( - waveform, orig_freq=sampling_rate, new_freq=48000 - ) - else: - waveform_48k = waveform - - return {"waveform_48k": waveform_48k} - - -def extract_vits_phoneme_and_flant5_text(config, dl_output, metadata): - assert ( - "phoneme" not in metadata.keys() - ), "The metadata of speech you use seems belong to fastspeech. 
Please check dataset_root.json" - - if "phonemes" in metadata.keys(): - new_item = get_vits_phoneme_ids_no_padding(config, dl_output, metadata) - new_item["text"] = "" # We assume TTS data does not have text description - else: - fake_metadata = {"phonemes": ""} # Add empty phoneme sequence - new_item = get_vits_phoneme_ids_no_padding(config, dl_output, fake_metadata) - - return new_item - - -def extract_fs2_phoneme_and_flant5_text(config, dl_output, metadata): - if "phoneme" in metadata.keys(): - new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata) - new_item["text"] = "" - else: - fake_metadata = {"phoneme": []} - new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, fake_metadata) - return new_item - - -def extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata): - PAD_LENGTH = 135 - - phonemes_lookup_dict = { - "K": 0, - "IH2": 1, - "NG": 2, - "OW2": 3, - "AH2": 4, - "F": 5, - "AE0": 6, - "IY0": 7, - "SH": 8, - "G": 9, - "W": 10, - "UW1": 11, - "AO2": 12, - "AW2": 13, - "UW0": 14, - "EY2": 15, - "UW2": 16, - "AE2": 17, - "IH0": 18, - "P": 19, - "D": 20, - "ER1": 21, - "AA1": 22, - "EH0": 23, - "UH1": 24, - "N": 25, - "V": 26, - "AY1": 27, - "EY1": 28, - "UH2": 29, - "EH1": 30, - "L": 31, - "AA2": 32, - "R": 33, - "OY1": 34, - "Y": 35, - "ER2": 36, - "S": 37, - "AE1": 38, - "AH1": 39, - "JH": 40, - "ER0": 41, - "EH2": 42, - "IY2": 43, - "OY2": 44, - "AW1": 45, - "IH1": 46, - "IY1": 47, - "OW0": 48, - "AO0": 49, - "AY0": 50, - "EY0": 51, - "AY2": 52, - "UH0": 53, - "M": 54, - "TH": 55, - "T": 56, - "OY0": 57, - "AW0": 58, - "DH": 59, - "Z": 60, - "spn": 61, - "AH0": 62, - "sp": 63, - "AO1": 64, - "OW1": 65, - "ZH": 66, - "B": 67, - "AA0": 68, - "CH": 69, - "HH": 70, - } - pad_token_id = len(phonemes_lookup_dict.keys()) - - assert ( - "phoneme" in metadata.keys() - ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset" - - phonemes = [ - phonemes_lookup_dict[x] - for x in metadata["phoneme"] - if (x in phonemes_lookup_dict.keys()) - ] - - if (len(phonemes) / PAD_LENGTH) > 5: - print( - "Warning: Phonemes length is too long and is truncated too much! %s" - % metadata - ) - - phonemes = phonemes[:PAD_LENGTH] - - def _pad_phonemes(phonemes_list): - return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list)) - - return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))} - - -def extract_phoneme_g2p_en_feature(config, dl_output, metadata): - PAD_LENGTH = 250 - - phonemes_lookup_dict = { - " ": 0, - "AA": 1, - "AE": 2, - "AH": 3, - "AO": 4, - "AW": 5, - "AY": 6, - "B": 7, - "CH": 8, - "D": 9, - "DH": 10, - "EH": 11, - "ER": 12, - "EY": 13, - "F": 14, - "G": 15, - "HH": 16, - "IH": 17, - "IY": 18, - "JH": 19, - "K": 20, - "L": 21, - "M": 22, - "N": 23, - "NG": 24, - "OW": 25, - "OY": 26, - "P": 27, - "R": 28, - "S": 29, - "SH": 30, - "T": 31, - "TH": 32, - "UH": 33, - "UW": 34, - "V": 35, - "W": 36, - "Y": 37, - "Z": 38, - "ZH": 39, - } - pad_token_id = len(phonemes_lookup_dict.keys()) - - assert ( - "phoneme" in metadata.keys() - ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset" - phonemes = [ - phonemes_lookup_dict[x] - for x in metadata["phoneme"] - if (x in phonemes_lookup_dict.keys()) - ] - - if (len(phonemes) / PAD_LENGTH) > 5: - print( - "Warning: Phonemes length is too long and is truncated too much! 
%s" - % metadata - ) - - phonemes = phonemes[:PAD_LENGTH] - - def _pad_phonemes(phonemes_list): - return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list)) - - return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))} - - -def extract_kaldi_fbank_feature(config, dl_output, metadata): - norm_mean = -4.2677393 - norm_std = 4.5689974 - - waveform = dl_output["waveform"] # [1, samples] - sampling_rate = dl_output["sampling_rate"] - log_mel_spec_hifigan = dl_output["log_mel_spec"] - - if sampling_rate != 16000: - waveform_16k = torchaudio.functional.resample( - waveform, orig_freq=sampling_rate, new_freq=16000 - ) - else: - waveform_16k = waveform - - waveform_16k = waveform_16k - waveform_16k.mean() - fbank = torchaudio.compliance.kaldi.fbank( - waveform_16k, - htk_compat=True, - sample_frequency=16000, - use_energy=False, - window_type="hanning", - num_mel_bins=128, - dither=0.0, - frame_shift=10, - ) - - TARGET_LEN = log_mel_spec_hifigan.size(0) - - # cut and pad - n_frames = fbank.shape[0] - p = TARGET_LEN - n_frames - if p > 0: - m = torch.nn.ZeroPad2d((0, 0, 0, p)) - fbank = m(fbank) - elif p < 0: - fbank = fbank[:TARGET_LEN, :] - - fbank = (fbank - norm_mean) / (norm_std * 2) - - return {"ta_kaldi_fbank": fbank} # [1024, 128] - - -def extract_kaldi_fbank_feature_32k(config, dl_output, metadata): - norm_mean = -4.2677393 - norm_std = 4.5689974 - - waveform = dl_output["waveform"] # [1, samples] - sampling_rate = dl_output["sampling_rate"] - log_mel_spec_hifigan = dl_output["log_mel_spec"] - - if sampling_rate != 32000: - waveform_32k = torchaudio.functional.resample( - waveform, orig_freq=sampling_rate, new_freq=32000 - ) - else: - waveform_32k = waveform - - waveform_32k = waveform_32k - waveform_32k.mean() - fbank = torchaudio.compliance.kaldi.fbank( - waveform_32k, - htk_compat=True, - sample_frequency=32000, - use_energy=False, - window_type="hanning", - num_mel_bins=128, - dither=0.0, - frame_shift=10, - ) - - TARGET_LEN = log_mel_spec_hifigan.size(0) - - # cut and pad - n_frames = fbank.shape[0] - p = TARGET_LEN - n_frames - if p > 0: - m = torch.nn.ZeroPad2d((0, 0, 0, p)) - fbank = m(fbank) - elif p < 0: - fbank = fbank[:TARGET_LEN, :] - - fbank = (fbank - norm_mean) / (norm_std * 2) - - return {"ta_kaldi_fbank": fbank} # [1024, 128] - - -# Use the beat and downbeat information as music conditions -def extract_drum_beat(config, dl_output, metadata): - def visualization(conditional_signal, mel_spectrogram, filename): - import soundfile as sf - - sf.write( - os.path.basename(dl_output["fname"]), - np.array(dl_output["waveform"])[0], - dl_output["sampling_rate"], - ) - plt.figure(figsize=(10, 10)) - - plt.subplot(211) - plt.imshow(np.array(conditional_signal).T, aspect="auto") - plt.title("Conditional Signal") - - plt.subplot(212) - plt.imshow(np.array(mel_spectrogram).T, aspect="auto") - plt.title("Mel Spectrogram") - - plt.savefig(filename) - plt.close() - - assert "sample_rate" in metadata and "beat" in metadata and "downbeat" in metadata - - sampling_rate = metadata["sample_rate"] - duration = dl_output["duration"] - # The dataloader segment length before performing torch resampling - original_segment_length_before_resample = int(sampling_rate * duration) - - random_start_sample = int(dl_output["random_start_sample_in_original_audio_file"]) - - # The sample idx for beat and downbeat, relatively to the segmented audio - beat = [ - x - random_start_sample - for x in metadata["beat"] - if ( - x - random_start_sample >= 0 - and x - random_start_sample <= 
original_segment_length_before_resample - ) - ] - downbeat = [ - x - random_start_sample - for x in metadata["downbeat"] - if ( - x - random_start_sample >= 0 - and x - random_start_sample <= original_segment_length_before_resample - ) - ] - - latent_shape = ( - config["model"]["params"]["latent_t_size"], - config["model"]["params"]["latent_f_size"], - ) - conditional_signal = torch.zeros(latent_shape) - - # beat: -0.5 - # downbeat: +1.0 - # 0: none; -0.5: beat; 1.0: downbeat; 0.5: downbeat+beat - for each in beat: - beat_index = int( - (each / original_segment_length_before_resample) * latent_shape[0] - ) - beat_index = min(beat_index, conditional_signal.size(0) - 1) - - conditional_signal[beat_index, :] -= 0.5 - - for each in downbeat: - beat_index = int( - (each / original_segment_length_before_resample) * latent_shape[0] - ) - beat_index = min(beat_index, conditional_signal.size(0) - 1) - - conditional_signal[beat_index, :] += 1.0 - - # visualization(conditional_signal, dl_output["log_mel_spec"], filename = os.path.basename(dl_output["fname"])+".png") - - return {"cond_beat_downbeat": conditional_signal} diff --git a/audioldm2/utilities/data/dataset.py b/audioldm2/utilities/data/dataset.py deleted file mode 100755 index f0bfbb7388ca6473beb4574ac4e29dcf0b7c0571..0000000000000000000000000000000000000000 --- a/audioldm2/utilities/data/dataset.py +++ /dev/null @@ -1,518 +0,0 @@ -import os -import pandas as pd - -import audioldm2.utilities.audio as Audio -from audioldm2.utilities.tools import load_json - -import random -from torch.utils.data import Dataset -import torch.nn.functional -import torch -import numpy as np -import torchaudio - - -class AudioDataset(Dataset): - def __init__( - self, - config=None, - split="train", - waveform_only=False, - add_ons=[], - dataset_json_path=None, # - ): - """ - Dataset that manages audio recordings - :param audio_conf: Dictionary containing the audio loading and preprocessing settings - :param dataset_json_file - """ - self.config = config - self.split = split - self.pad_wav_start_sample = 0 # If none, random choose - self.trim_wav = False - self.waveform_only = waveform_only - self.add_ons = [eval(x) for x in add_ons] - print("Add-ons:", self.add_ons) - - self.build_setting_parameters() - - # For an external dataset - if dataset_json_path is not None: - assert type(dataset_json_path) == str - print("Load metadata from %s" % dataset_json_path) - self.data = load_json(dataset_json_path)["data"] - self.id2label, self.index_dict, self.num2label = {}, {}, {} - else: - self.metadata_root = load_json(self.config["metadata_root"]) - self.dataset_name = self.config["data"][self.split] - assert split in self.config["data"].keys(), ( - "The dataset split %s you specified is not present in the config. 
You can choose from %s" - % (split, self.config["data"].keys()) - ) - self.build_dataset() - self.build_id_to_label() - - self.build_dsp() - self.label_num = len(self.index_dict) - print("Dataset initialize finished") - - def __getitem__(self, index): - ( - fname, - waveform, - stft, - log_mel_spec, - label_vector, # the one-hot representation of the audio class - # the metadata of the sampled audio file and the mixup audio file (if exist) - (datum, mix_datum), - random_start, - ) = self.feature_extraction(index) - text = self.get_sample_text_caption(datum, mix_datum, label_vector) - - data = { - "text": text, # list - "fname": self.text_to_filename(text) - if (len(fname) == 0) - else fname, # list - # tensor, [batchsize, class_num] - "label_vector": "" if (label_vector is None) else label_vector.float(), - # tensor, [batchsize, 1, samples_num] - "waveform": "" if (waveform is None) else waveform.float(), - # tensor, [batchsize, t-steps, f-bins] - "stft": "" if (stft is None) else stft.float(), - # tensor, [batchsize, t-steps, mel-bins] - "log_mel_spec": "" if (log_mel_spec is None) else log_mel_spec.float(), - "duration": self.duration, - "sampling_rate": self.sampling_rate, - "random_start_sample_in_original_audio_file": random_start, - } - - for add_on in self.add_ons: - data.update(add_on(self.config, data, self.data[index])) - - if data["text"] is None: - print("Warning: The model return None on key text", fname) - data["text"] = "" - - return data - - def text_to_filename(self, text): - return text.replace(" ", "_").replace("'", "_").replace('"', "_") - - def get_dataset_root_path(self, dataset): - assert dataset in self.metadata_root.keys() - return self.metadata_root[dataset] - - def get_dataset_metadata_path(self, dataset, key): - # key: train, test, val, class_label_indices - try: - if dataset in self.metadata_root["metadata"]["path"].keys(): - return self.metadata_root["metadata"]["path"][dataset][key] - except: - raise ValueError( - 'Dataset %s does not metadata "%s" specified' % (dataset, key) - ) - # return None - - def __len__(self): - return len(self.data) - - def feature_extraction(self, index): - if index > len(self.data) - 1: - print( - "The index of the dataloader is out of range: %s/%s" - % (index, len(self.data)) - ) - index = random.randint(0, len(self.data) - 1) - - # Read wave file and extract feature - while True: - try: - label_indices = np.zeros(self.label_num, dtype=np.float32) - datum = self.data[index] - ( - log_mel_spec, - stft, - mix_lambda, - waveform, - random_start, - ) = self.read_audio_file(datum["wav"]) - mix_datum = None - if self.label_num > 0 and "labels" in datum.keys(): - for label_str in datum["labels"].split(","): - label_indices[int(self.index_dict[label_str])] = 1.0 - - # If the key "label" is not in the metadata, return all zero vector - label_indices = torch.FloatTensor(label_indices) - break - except Exception as e: - index = (index + 1) % len(self.data) - print( - "Error encounter during audio feature extraction: ", e, datum["wav"] - ) - continue - - # The filename of the wav file - fname = datum["wav"] - # t_step = log_mel_spec.size(0) - # waveform = torch.FloatTensor(waveform[..., : int(self.hopsize * t_step)]) - waveform = torch.FloatTensor(waveform) - - return ( - fname, - waveform, - stft, - log_mel_spec, - label_indices, - (datum, mix_datum), - random_start, - ) - - # def augmentation(self, log_mel_spec): - # assert torch.min(log_mel_spec) < 0 - # log_mel_spec = log_mel_spec.exp() - - # log_mel_spec = torch.transpose(log_mel_spec, 
0, 1) - # # this is just to satisfy new torchaudio version. - # log_mel_spec = log_mel_spec.unsqueeze(0) - # if self.freqm != 0: - # log_mel_spec = self.frequency_masking(log_mel_spec, self.freqm) - # if self.timem != 0: - # log_mel_spec = self.time_masking( - # log_mel_spec, self.timem) # self.timem=0 - - # log_mel_spec = (log_mel_spec + 1e-7).log() - # # squeeze back - # log_mel_spec = log_mel_spec.squeeze(0) - # log_mel_spec = torch.transpose(log_mel_spec, 0, 1) - # return log_mel_spec - - def build_setting_parameters(self): - # Read from the json config - self.melbins = self.config["preprocessing"]["mel"]["n_mel_channels"] - # self.freqm = self.config["preprocessing"]["mel"]["freqm"] - # self.timem = self.config["preprocessing"]["mel"]["timem"] - self.sampling_rate = self.config["preprocessing"]["audio"]["sampling_rate"] - self.hopsize = self.config["preprocessing"]["stft"]["hop_length"] - self.duration = self.config["preprocessing"]["audio"]["duration"] - self.target_length = int(self.duration * self.sampling_rate / self.hopsize) - - self.mixup = self.config["augmentation"]["mixup"] - - # Calculate parameter derivations - # self.waveform_sample_length = int(self.target_length * self.hopsize) - - # if (self.config["balance_sampling_weight"]): - # self.samples_weight = np.loadtxt( - # self.config["balance_sampling_weight"], delimiter="," - # ) - - if "train" not in self.split: - self.mixup = 0.0 - # self.freqm = 0 - # self.timem = 0 - - def _relative_path_to_absolute_path(self, metadata, dataset_name): - root_path = self.get_dataset_root_path(dataset_name) - for i in range(len(metadata["data"])): - assert "wav" in metadata["data"][i].keys(), metadata["data"][i] - assert metadata["data"][i]["wav"][0] != "/", ( - "The dataset metadata should only contain relative path to the audio file: " - + str(metadata["data"][i]["wav"]) - ) - metadata["data"][i]["wav"] = os.path.join( - root_path, metadata["data"][i]["wav"] - ) - return metadata - - def build_dataset(self): - self.data = [] - print("Build dataset split %s from %s" % (self.split, self.dataset_name)) - if type(self.dataset_name) is str: - data_json = load_json( - self.get_dataset_metadata_path(self.dataset_name, key=self.split) - ) - data_json = self._relative_path_to_absolute_path( - data_json, self.dataset_name - ) - self.data = data_json["data"] - elif type(self.dataset_name) is list: - for dataset_name in self.dataset_name: - data_json = load_json( - self.get_dataset_metadata_path(dataset_name, key=self.split) - ) - data_json = self._relative_path_to_absolute_path( - data_json, dataset_name - ) - self.data += data_json["data"] - else: - raise Exception("Invalid data format") - print("Data size: {}".format(len(self.data))) - - def build_dsp(self): - self.STFT = Audio.stft.TacotronSTFT( - self.config["preprocessing"]["stft"]["filter_length"], - self.config["preprocessing"]["stft"]["hop_length"], - self.config["preprocessing"]["stft"]["win_length"], - self.config["preprocessing"]["mel"]["n_mel_channels"], - self.config["preprocessing"]["audio"]["sampling_rate"], - self.config["preprocessing"]["mel"]["mel_fmin"], - self.config["preprocessing"]["mel"]["mel_fmax"], - ) - # self.stft_transform = torchaudio.transforms.Spectrogram( - # n_fft=1024, hop_length=160 - # ) - # self.melscale_transform = torchaudio.transforms.MelScale( - # sample_rate=16000, n_stft=1024 // 2 + 1, n_mels=64 - # ) - - def build_id_to_label(self): - id2label = {} - id2num = {} - num2label = {} - class_label_indices_path = self.get_dataset_metadata_path( - 
dataset=self.config["data"]["class_label_indices"], - key="class_label_indices", - ) - if class_label_indices_path is not None: - df = pd.read_csv(class_label_indices_path) - for _, row in df.iterrows(): - index, mid, display_name = row["index"], row["mid"], row["display_name"] - id2label[mid] = display_name - id2num[mid] = index - num2label[index] = display_name - self.id2label, self.index_dict, self.num2label = id2label, id2num, num2label - else: - self.id2label, self.index_dict, self.num2label = {}, {}, {} - - def resample(self, waveform, sr): - waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate) - # waveform = librosa.resample(waveform, sr, self.sampling_rate) - return waveform - - # if sr == 16000: - # return waveform - # if sr == 32000 and self.sampling_rate == 16000: - # waveform = waveform[::2] - # return waveform - # if sr == 48000 and self.sampling_rate == 16000: - # waveform = waveform[::3] - # return waveform - # else: - # raise ValueError( - # "We currently only support 16k audio generation. You need to resample you audio file to 16k, 32k, or 48k: %s, %s" - # % (sr, self.sampling_rate) - # ) - - def normalize_wav(self, waveform): - waveform = waveform - np.mean(waveform) - waveform = waveform / (np.max(np.abs(waveform)) + 1e-8) - return waveform * 0.5 # Manually limit the maximum amplitude into 0.5 - - def random_segment_wav(self, waveform, target_length): - waveform_length = waveform.shape[-1] - assert waveform_length > 100, "Waveform is too short, %s" % waveform_length - - # Too short - if (waveform_length - target_length) <= 0: - return waveform, 0 - - random_start = int(self.random_uniform(0, waveform_length - target_length)) - return waveform[:, random_start : random_start + target_length], random_start - - def pad_wav(self, waveform, target_length): - waveform_length = waveform.shape[-1] - assert waveform_length > 100, "Waveform is too short, %s" % waveform_length - - if waveform_length == target_length: - return waveform - - # Pad - temp_wav = np.zeros((1, target_length), dtype=np.float32) - if self.pad_wav_start_sample is None: - rand_start = int(self.random_uniform(0, target_length - waveform_length)) - else: - rand_start = 0 - - temp_wav[:, rand_start : rand_start + waveform_length] = waveform - return temp_wav - - def trim_wav(self, waveform): - if np.max(np.abs(waveform)) < 0.0001: - return waveform - - def detect_leading_silence(waveform, threshold=0.0001): - chunk_size = 1000 - waveform_length = waveform.shape[0] - start = 0 - while start + chunk_size < waveform_length: - if np.max(np.abs(waveform[start : start + chunk_size])) < threshold: - start += chunk_size - else: - break - return start - - def detect_ending_silence(waveform, threshold=0.0001): - chunk_size = 1000 - waveform_length = waveform.shape[0] - start = waveform_length - while start - chunk_size > 0: - if np.max(np.abs(waveform[start - chunk_size : start])) < threshold: - start -= chunk_size - else: - break - if start == waveform_length: - return start - else: - return start + chunk_size - - start = detect_leading_silence(waveform) - end = detect_ending_silence(waveform) - - return waveform[start:end] - - def read_wav_file(self, filename): - # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower - waveform, sr = torchaudio.load(filename) - - waveform, random_start = self.random_segment_wav( - waveform, target_length=int(sr * self.duration) - ) - - waveform = self.resample(waveform, sr) - # random_start = int(random_start * (self.sampling_rate / sr)) - - 
waveform = waveform.numpy()[0, ...] - - waveform = self.normalize_wav(waveform) - - if self.trim_wav: - waveform = self.trim_wav(waveform) - - waveform = waveform[None, ...] - waveform = self.pad_wav( - waveform, target_length=int(self.sampling_rate * self.duration) - ) - return waveform, random_start - - def mix_two_waveforms(self, waveform1, waveform2): - mix_lambda = np.random.beta(5, 5) - mix_waveform = mix_lambda * waveform1 + (1 - mix_lambda) * waveform2 - return self.normalize_wav(mix_waveform), mix_lambda - - def read_audio_file(self, filename, filename2=None): - if os.path.exists(filename): - waveform, random_start = self.read_wav_file(filename) - else: - print( - 'Warning [dataset.py]: The wav path "', - filename, - '" is not find in the metadata. Use empty waveform instead.', - ) - target_length = int(self.sampling_rate * self.duration) - waveform = torch.zeros((1, target_length)) - random_start = 0 - - mix_lambda = 0.0 - # log_mel_spec, stft = self.wav_feature_extraction_torchaudio(waveform) # this line is faster, but this implementation is not aligned with HiFi-GAN - if not self.waveform_only: - log_mel_spec, stft = self.wav_feature_extraction(waveform) - else: - # Load waveform data only - # Use zero array to keep the format unified - log_mel_spec, stft = None, None - - return log_mel_spec, stft, mix_lambda, waveform, random_start - - def get_sample_text_caption(self, datum, mix_datum, label_indices): - text = self.label_indices_to_text(datum, label_indices) - if mix_datum is not None: - text += " " + self.label_indices_to_text(mix_datum, label_indices) - return text - - # This one is significantly slower than "wav_feature_extraction_torchaudio" if num_worker > 1 - def wav_feature_extraction(self, waveform): - waveform = waveform[0, ...] - waveform = torch.FloatTensor(waveform) - - log_mel_spec, stft, energy = Audio.tools.get_mel_from_wav(waveform, self.STFT) - - log_mel_spec = torch.FloatTensor(log_mel_spec.T) - stft = torch.FloatTensor(stft.T) - - log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft) - return log_mel_spec, stft - - # @profile - # def wav_feature_extraction_torchaudio(self, waveform): - # waveform = waveform[0, ...] 
- # waveform = torch.FloatTensor(waveform) - - # stft = self.stft_transform(waveform) - # mel_spec = self.melscale_transform(stft) - # log_mel_spec = torch.log(mel_spec + 1e-7) - - # log_mel_spec = torch.FloatTensor(log_mel_spec.T) - # stft = torch.FloatTensor(stft.T) - - # log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft) - # return log_mel_spec, stft - - def pad_spec(self, log_mel_spec): - n_frames = log_mel_spec.shape[0] - p = self.target_length - n_frames - # cut and pad - if p > 0: - m = torch.nn.ZeroPad2d((0, 0, 0, p)) - log_mel_spec = m(log_mel_spec) - elif p < 0: - log_mel_spec = log_mel_spec[0 : self.target_length, :] - - if log_mel_spec.size(-1) % 2 != 0: - log_mel_spec = log_mel_spec[..., :-1] - - return log_mel_spec - - def _read_datum_caption(self, datum): - caption_keys = [x for x in datum.keys() if ("caption" in x)] - random_index = torch.randint(0, len(caption_keys), (1,))[0].item() - return datum[caption_keys[random_index]] - - def _is_contain_caption(self, datum): - caption_keys = [x for x in datum.keys() if ("caption" in x)] - return len(caption_keys) > 0 - - def label_indices_to_text(self, datum, label_indices): - if self._is_contain_caption(datum): - return self._read_datum_caption(datum) - elif "label" in datum.keys(): - name_indices = torch.where(label_indices > 0.1)[0] - # description_header = "This audio contains the sound of " - description_header = "" - labels = "" - for id, each in enumerate(name_indices): - if id == len(name_indices) - 1: - labels += "%s." % self.num2label[int(each)] - else: - labels += "%s, " % self.num2label[int(each)] - return description_header + labels - else: - return "" # TODO, if both label and caption are not provided, return empty string - - def random_uniform(self, start, end): - val = torch.rand(1).item() - return start + (end - start) * val - - def frequency_masking(self, log_mel_spec, freqm): - bs, freq, tsteps = log_mel_spec.size() - mask_len = int(self.random_uniform(freqm // 8, freqm)) - mask_start = int(self.random_uniform(start=0, end=freq - mask_len)) - log_mel_spec[:, mask_start : mask_start + mask_len, :] *= 0.0 - return log_mel_spec - - def time_masking(self, log_mel_spec, timem): - bs, freq, tsteps = log_mel_spec.size() - mask_len = int(self.random_uniform(timem // 8, timem)) - mask_start = int(self.random_uniform(start=0, end=tsteps - mask_len)) - log_mel_spec[:, :, mask_start : mask_start + mask_len] *= 0.0 - return log_mel_spec diff --git a/audioldm2/utilities/model.py b/audioldm2/utilities/model.py deleted file mode 100755 index ffefac1212b85bfb8c4992371dbdf6d500a969e3..0000000000000000000000000000000000000000 --- a/audioldm2/utilities/model.py +++ /dev/null @@ -1,121 +0,0 @@ -import torch - -import audioldm2.hifigan as hifigan - - -def get_vocoder_config(): - return { - "resblock": "1", - "num_gpus": 6, - "batch_size": 16, - "learning_rate": 0.0002, - "adam_b1": 0.8, - "adam_b2": 0.99, - "lr_decay": 0.999, - "seed": 1234, - "upsample_rates": [5, 4, 2, 2, 2], - "upsample_kernel_sizes": [16, 16, 8, 4, 4], - "upsample_initial_channel": 1024, - "resblock_kernel_sizes": [3, 7, 11], - "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], - "segment_size": 8192, - "num_mels": 64, - "num_freq": 1025, - "n_fft": 1024, - "hop_size": 160, - "win_size": 1024, - "sampling_rate": 16000, - "fmin": 0, - "fmax": 8000, - "fmax_for_loss": None, - "num_workers": 4, - "dist_config": { - "dist_backend": "nccl", - "dist_url": "tcp://localhost:54321", - "world_size": 1, - }, - } - - -def 
get_available_checkpoint_keys(model, ckpt): - state_dict = torch.load(ckpt)["state_dict"] - current_state_dict = model.state_dict() - new_state_dict = {} - for k in state_dict.keys(): - if ( - k in current_state_dict.keys() - and current_state_dict[k].size() == state_dict[k].size() - ): - new_state_dict[k] = state_dict[k] - else: - print("==> WARNING: Skipping %s" % k) - print( - "%s out of %s keys are matched" - % (len(new_state_dict.keys()), len(state_dict.keys())) - ) - return new_state_dict - - -def get_param_num(model): - num_param = sum(param.numel() for param in model.parameters()) - return num_param - - -def torch_version_orig_mod_remove(state_dict): - new_state_dict = {} - new_state_dict["generator"] = {} - for key in state_dict["generator"].keys(): - if "_orig_mod." in key: - new_state_dict["generator"][key.replace("_orig_mod.", "")] = state_dict[ - "generator" - ][key] - else: - new_state_dict["generator"][key] = state_dict["generator"][key] - return new_state_dict - - -def get_vocoder(config, device, mel_bins): - name = "HiFi-GAN" - speaker = "" - if name == "MelGAN": - if speaker == "LJSpeech": - vocoder = torch.hub.load( - "descriptinc/melgan-neurips", "load_melgan", "linda_johnson" - ) - elif speaker == "universal": - vocoder = torch.hub.load( - "descriptinc/melgan-neurips", "load_melgan", "multi_speaker" - ) - vocoder.mel2wav.eval() - vocoder.mel2wav.to(device) - elif name == "HiFi-GAN": - config = get_vocoder_config() - config = hifigan.AttrDict(config) - vocoder = hifigan.Generator_old(config) - # print("Load hifigan/g_01080000") - # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000")) - # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000")) - # ckpt = torch_version_orig_mod_remove(ckpt) - # vocoder.load_state_dict(ckpt["generator"]) - vocoder.eval() - vocoder.remove_weight_norm() - vocoder.to(device) - return vocoder - - -def vocoder_infer(mels, vocoder, lengths=None): - with torch.no_grad(): - wavs = vocoder(mels).squeeze(1) - - wavs = (wavs.cpu().numpy() * 32768).astype("int16") - - if lengths is not None: - wavs = wavs[:, :lengths] - - # wavs = [wav for wav in wavs] - - # for i in range(len(mels)): - # if lengths is not None: - # wavs[i] = wavs[i][: lengths[i]] - - return wavs diff --git a/audioldm2/utilities/sampler.py b/audioldm2/utilities/sampler.py deleted file mode 100755 index cdaf4882715f53f39ead8bf71fb3dccc29cd8b94..0000000000000000000000000000000000000000 --- a/audioldm2/utilities/sampler.py +++ /dev/null @@ -1,588 +0,0 @@ -from typing import Iterator, List, Optional, Union -from collections import Counter -import logging -from operator import itemgetter -import random - -import numpy as np - -from torch.utils.data import DistributedSampler -from torch.utils.data.sampler import Sampler - -LOGGER = logging.getLogger(__name__) - -from torch.utils.data import Dataset, Sampler - - -class DatasetFromSampler(Dataset): - """Dataset to create indexes from `Sampler`. - Args: - sampler: PyTorch sampler - """ - - def __init__(self, sampler: Sampler): - """Initialisation for DatasetFromSampler.""" - self.sampler = sampler - self.sampler_list = None - - def __getitem__(self, index: int): - """Gets element of the dataset. 
- Args: - index: index of the element in the dataset - Returns: - Single element by index - """ - if self.sampler_list is None: - self.sampler_list = list(self.sampler) - return self.sampler_list[index] - - def __len__(self) -> int: - """ - Returns: - int: length of the dataset - """ - return len(self.sampler) - - -class BalanceClassSampler(Sampler): - """Allows you to create stratified sample on unbalanced classes. - - Args: - labels: list of class label for each elem in the dataset - mode: Strategy to balance classes. - Must be one of [downsampling, upsampling] - - Python API examples: - - .. code-block:: python - - import os - from torch import nn, optim - from torch.utils.data import DataLoader - from catalyst import dl - from catalyst.data import ToTensor, BalanceClassSampler - from catalyst.contrib.datasets import MNIST - - train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()) - train_labels = train_data.targets.cpu().numpy().tolist() - train_sampler = BalanceClassSampler(train_labels, mode=5000) - valid_data = MNIST(os.getcwd(), train=False) - - loaders = { - "train": DataLoader(train_data, sampler=train_sampler, batch_size=32), - "valid": DataLoader(valid_data, batch_size=32), - } - - model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) - criterion = nn.CrossEntropyLoss() - optimizer = optim.Adam(model.parameters(), lr=0.02) - - runner = dl.SupervisedRunner() - # model training - runner.train( - model=model, - criterion=criterion, - optimizer=optimizer, - loaders=loaders, - num_epochs=1, - logdir="./logs", - valid_loader="valid", - valid_metric="loss", - minimize_valid_metric=True, - verbose=True, - ) - """ - - def __init__(self, labels: List[int], mode: Union[str, int] = "downsampling"): - """Sampler initialisation.""" - super().__init__(labels) - - labels = np.array(labels) - samples_per_class = {label: (labels == label).sum() for label in set(labels)} - - self.lbl2idx = { - label: np.arange(len(labels))[labels == label].tolist() - for label in set(labels) - } - - if isinstance(mode, str): - assert mode in ["downsampling", "upsampling"] - - if isinstance(mode, int) or mode == "upsampling": - samples_per_class = ( - mode if isinstance(mode, int) else max(samples_per_class.values()) - ) - else: - samples_per_class = min(samples_per_class.values()) - - self.labels = labels - self.samples_per_class = samples_per_class - self.length = self.samples_per_class * len(set(labels)) - - def __iter__(self) -> Iterator[int]: - """ - Returns: - iterator of indices of stratified sample - """ - indices = [] - for key in sorted(self.lbl2idx): - replace_flag = self.samples_per_class > len(self.lbl2idx[key]) - indices += np.random.choice( - self.lbl2idx[key], self.samples_per_class, replace=replace_flag - ).tolist() - assert len(indices) == self.length - np.random.shuffle(indices) - - return iter(indices) - - def __len__(self) -> int: - """ - Returns: - length of result sample - """ - return self.length - - -class BatchBalanceClassSampler(Sampler): - """ - This kind of sampler can be used for both metric learning and classification task. - - BatchSampler with the given strategy for the C unique classes dataset: - - Selection `num_classes` of C classes for each batch - - Selection `num_samples` instances for each class in the batch - The epoch ends after `num_batches`. - So, the batch sise is `num_classes` * `num_samples`. - - One of the purposes of this sampler is to be used for - forming triplets and pos/neg pairs inside the batch. 
-    To guarantee the existence of these pairs in the batch,
-    `num_classes` and `num_samples` should be > 1. (1)
-
-    This type of sampling can be found in the classical paper of Person Re-Id,
-    where P (`num_classes`) equals 32 and K (`num_samples`) equals 4:
-    `In Defense of the Triplet Loss for Person Re-Identification`_.
-
-    Args:
-        labels: list of class labels for each elem in the dataset
-        num_classes: number of classes in a batch, should be > 1
-        num_samples: number of instances of each class in a batch, should be > 1
-        num_batches: number of batches in an epoch
-            (default = len(labels) // (num_classes * num_samples))
-
-    .. _In Defense of the Triplet Loss for Person Re-Identification:
-        https://arxiv.org/abs/1703.07737
-
-    Python API examples:
-
-    .. code-block:: python
-
-        import os
-        from torch import nn, optim
-        from torch.utils.data import DataLoader
-        from catalyst import dl
-        from catalyst.data import ToTensor, BatchBalanceClassSampler
-        from catalyst.contrib.datasets import MNIST
-
-        train_data = MNIST(os.getcwd(), train=True, download=True)
-        train_labels = train_data.targets.cpu().numpy().tolist()
-        train_sampler = BatchBalanceClassSampler(
-            train_labels, num_classes=10, num_samples=4)
-        valid_data = MNIST(os.getcwd(), train=False)
-
-        loaders = {
-            "train": DataLoader(train_data, batch_sampler=train_sampler),
-            "valid": DataLoader(valid_data, batch_size=32),
-        }
-
-        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
-        criterion = nn.CrossEntropyLoss()
-        optimizer = optim.Adam(model.parameters(), lr=0.02)
-
-        runner = dl.SupervisedRunner()
-        # model training
-        runner.train(
-            model=model,
-            criterion=criterion,
-            optimizer=optimizer,
-            loaders=loaders,
-            num_epochs=1,
-            logdir="./logs",
-            valid_loader="valid",
-            valid_metric="loss",
-            minimize_valid_metric=True,
-            verbose=True,
-        )
-    """
-
-    def __init__(
-        self,
-        labels: Union[List[int], np.ndarray],
-        num_classes: int,
-        num_samples: int,
-        num_batches: int = None,
-    ):
-        """Sampler initialisation."""
-        super().__init__(labels)
-        classes = set(labels)
-
-        assert isinstance(num_classes, int) and isinstance(num_samples, int)
-        assert (1 < num_classes <= len(classes)) and (1 < num_samples)
-        assert all(
-            n > 1 for n in Counter(labels).values()
-        ), "Each class should contain at least 2 instances to fit (1)"
-
-        labels = np.array(labels)
-        self._labels = list(set(labels.tolist()))
-        self._num_classes = num_classes
-        self._num_samples = num_samples
-        self._batch_size = self._num_classes * self._num_samples
-        self._num_batches = num_batches or len(labels) // self._batch_size
-        self.lbl2idx = {
-            label: np.arange(len(labels))[labels == label].tolist()
-            for label in set(labels)
-        }
-
-    @property
-    def batch_size(self) -> int:
-        """
-        Returns:
-            this value should be used in DataLoader as batch size
-        """
-        return self._batch_size
-
-    @property
-    def batches_in_epoch(self) -> int:
-        """
-        Returns:
-            number of batches in an epoch
-        """
-        return self._num_batches
-
-    def __len__(self) -> int:
-        """
-        Returns:
-            number of samples in an epoch
-        """
-        return self._num_batches  # * self._batch_size
-
-    def __iter__(self) -> Iterator[int]:
-        """
-        Returns:
-            indices for sampling dataset elements during an epoch
-        """
-        indices = []
-        for _ in range(self._num_batches):
-            batch_indices = []
-            classes_for_batch = random.sample(self._labels, self._num_classes)
-            while self._num_classes != len(set(classes_for_batch)):
-                classes_for_batch = random.sample(self._labels, self._num_classes)
-            for cls_id in classes_for_batch:
-                replace_flag = self._num_samples > len(self.lbl2idx[cls_id])
-                batch_indices += np.random.choice(
-                    self.lbl2idx[cls_id], self._num_samples, replace=replace_flag
-                ).tolist()
-            indices.append(batch_indices)
-        return iter(indices)
-
-
-class DynamicBalanceClassSampler(Sampler):
-    """
-    This kind of sampler can be used for classification tasks with significant
-    class imbalance.
-
-    The idea of this sampler is that we start with the original class distribution
-    and gradually move towards a uniform class distribution, as in downsampling.
-
-    Let's define :math:`D_i = #C_i / #C_min`, where :math:`#C_i` is the size of class
-    i and :math:`#C_min` is the size of the rarest class, so :math:`D_i` defines the
-    class distribution. Also define :math:`g(n_epoch)` as an exponential
-    scheduler. On each epoch the current :math:`D_i` is calculated as
-    :math:`current D_i = D_i ^ g(n_epoch)`;
-    after this, data is sampled according to this distribution.
-
-    Notes:
-        At the end of training, epochs will contain only
-        min_class_size * n_classes examples. So, it will possibly not be
-        necessary to do validation on each epoch. For this reason use
-        ControlFlowCallback.
-
-    Examples:
-
-        >>> import torch
-        >>> import numpy as np
-
-        >>> from catalyst.data import DynamicBalanceClassSampler
-        >>> from torch.utils import data
-
-        >>> features = torch.Tensor(np.random.random((200, 100)))
-        >>> labels = np.random.randint(0, 4, size=(200,))
-        >>> sampler = DynamicBalanceClassSampler(labels)
-        >>> labels = torch.LongTensor(labels)
-        >>> dataset = data.TensorDataset(features, labels)
-        >>> loader = data.dataloader.DataLoader(dataset, batch_size=8)
-
-        >>> for batch in loader:
-        >>>     b_features, b_labels = batch
-
-    Sampler was inspired by https://arxiv.org/abs/1901.06783
-    """
-
-    def __init__(
-        self,
-        labels: List[Union[int, str]],
-        exp_lambda: float = 0.9,
-        start_epoch: int = 0,
-        max_d: Optional[int] = None,
-        mode: Union[str, int] = "downsampling",
-        ignore_warning: bool = False,
-    ):
-        """
-        Args:
-            labels: list of labels for each elem in the dataset
-            exp_lambda: exponent factor for the schedule
-            start_epoch: start epoch number, can be useful for multi-stage
-                experiments
-            max_d: if not None, limit on the difference between the most
-                frequent and the rarest classes, heuristic
-            mode: number of samples per class at the end of training. Must be
-                "downsampling" or a number. Before changing it, make sure that you
-                understand how it works
-            ignore_warning: ignore warning about min class size
-        """
-        assert isinstance(start_epoch, int)
-        assert 0 < exp_lambda < 1, "exp_lambda must be in (0, 1)"
-        super().__init__(labels)
-        self.exp_lambda = exp_lambda
-        if max_d is None:
-            max_d = np.inf
-        self.max_d = max_d
-        self.epoch = start_epoch
-        labels = np.array(labels)
-        samples_per_class = Counter(labels)
-        self.min_class_size = min(samples_per_class.values())
-
-        if self.min_class_size < 100 and not ignore_warning:
-            LOGGER.warning(
-                f"the smallest class contains only"
-                f" {self.min_class_size} examples. 
At the end of" - f" training, epochs will contain only" - f" {self.min_class_size * len(samples_per_class)}" - f" examples" - ) - - self.original_d = { - key: value / self.min_class_size for key, value in samples_per_class.items() - } - self.label2idxes = { - label: np.arange(len(labels))[labels == label].tolist() - for label in set(labels) - } - - if isinstance(mode, int): - self.min_class_size = mode - else: - assert mode == "downsampling" - - self.labels = labels - self._update() - - def _update(self) -> None: - """Update d coefficients.""" - current_d = { - key: min(value ** self._exp_scheduler(), self.max_d) - for key, value in self.original_d.items() - } - samples_per_classes = { - key: int(value * self.min_class_size) for key, value in current_d.items() - } - self.samples_per_classes = samples_per_classes - self.length = np.sum(list(samples_per_classes.values())) - self.epoch += 1 - - def _exp_scheduler(self) -> float: - return self.exp_lambda**self.epoch - - def __iter__(self) -> Iterator[int]: - """ - Returns: - iterator of indices of stratified sample - """ - indices = [] - for key in sorted(self.label2idxes): - samples_per_class = self.samples_per_classes[key] - replace_flag = samples_per_class > len(self.label2idxes[key]) - indices += np.random.choice( - self.label2idxes[key], samples_per_class, replace=replace_flag - ).tolist() - assert len(indices) == self.length - np.random.shuffle(indices) - self._update() - return iter(indices) - - def __len__(self) -> int: - """ - Returns: - length of result sample - """ - return self.length - - -class MiniEpochSampler(Sampler): - """ - Sampler iterates mini epochs from the dataset used by ``mini_epoch_len``. - - Args: - data_len: Size of the dataset - mini_epoch_len: Num samples from the dataset used in one - mini epoch. - drop_last: If ``True``, sampler will drop the last batches - if its size would be less than ``batches_per_epoch`` - shuffle: one of ``"always"``, ``"real_epoch"``, or `None``. - The sampler will shuffle indices - > "per_mini_epoch" - every mini epoch (every ``__iter__`` call) - > "per_epoch" -- every real epoch - > None -- don't shuffle - - Example: - >>> MiniEpochSampler(len(dataset), mini_epoch_len=100) - >>> MiniEpochSampler(len(dataset), mini_epoch_len=100, drop_last=True) - >>> MiniEpochSampler(len(dataset), mini_epoch_len=100, - >>> shuffle="per_epoch") - """ - - def __init__( - self, - data_len: int, - mini_epoch_len: int, - drop_last: bool = False, - shuffle: str = None, - ): - """Sampler initialisation.""" - super().__init__(None) - - self.data_len = int(data_len) - self.mini_epoch_len = int(mini_epoch_len) - - self.steps = int(data_len / self.mini_epoch_len) - self.state_i = 0 - - has_reminder = data_len - self.steps * mini_epoch_len > 0 - if self.steps == 0: - self.divider = 1 - elif has_reminder and not drop_last: - self.divider = self.steps + 1 - else: - self.divider = self.steps - - self._indices = np.arange(self.data_len) - self.indices = self._indices - self.end_pointer = max(self.data_len, self.mini_epoch_len) - - if not (shuffle is None or shuffle in ["per_mini_epoch", "per_epoch"]): - raise ValueError( - "Shuffle must be one of ['per_mini_epoch', 'per_epoch']. 
" - + f"Got {shuffle}" - ) - self.shuffle_type = shuffle - - def shuffle(self) -> None: - """Shuffle sampler indices.""" - if self.shuffle_type == "per_mini_epoch" or ( - self.shuffle_type == "per_epoch" and self.state_i == 0 - ): - if self.data_len >= self.mini_epoch_len: - self.indices = self._indices - np.random.shuffle(self.indices) - else: - self.indices = np.random.choice( - self._indices, self.mini_epoch_len, replace=True - ) - - def __iter__(self) -> Iterator[int]: - """Iterate over sampler. - - Returns: - python iterator - """ - self.state_i = self.state_i % self.divider - self.shuffle() - - start = self.state_i * self.mini_epoch_len - stop = ( - self.end_pointer - if (self.state_i == self.steps) - else (self.state_i + 1) * self.mini_epoch_len - ) - indices = self.indices[start:stop].tolist() - - self.state_i += 1 - return iter(indices) - - def __len__(self) -> int: - """ - Returns: - int: length of the mini-epoch - """ - return self.mini_epoch_len - - -class DistributedSamplerWrapper(DistributedSampler): - """ - Wrapper over `Sampler` for distributed training. - Allows you to use any sampler in distributed mode. - - It is especially useful in conjunction with - `torch.nn.parallel.DistributedDataParallel`. In such case, each - process can pass a DistributedSamplerWrapper instance as a DataLoader - sampler, and load a subset of subsampled data of the original dataset - that is exclusive to it. - - .. note:: - Sampler is assumed to be of constant size. - """ - - def __init__( - self, - sampler, - num_replicas: Optional[int] = None, - rank: Optional[int] = None, - shuffle: bool = True, - ): - """ - - Args: - sampler: Sampler used for subsampling - num_replicas (int, optional): Number of processes participating in - distributed training - rank (int, optional): Rank of the current process - within ``num_replicas`` - shuffle (bool, optional): If true (default), - sampler will shuffle the indices - """ - super(DistributedSamplerWrapper, self).__init__( - DatasetFromSampler(sampler), - num_replicas=num_replicas, - rank=rank, - shuffle=shuffle, - ) - self.sampler = sampler - - def __iter__(self) -> Iterator[int]: - """Iterate over sampler. 
- - Returns: - python iterator - """ - self.dataset = DatasetFromSampler(self.sampler) - indexes_of_indexes = super().__iter__() - subsampler_indexes = self.dataset - return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) - - -__all__ = [ - "BalanceClassSampler", - "BatchBalanceClassSampler", - "DistributedSamplerWrapper", - "DynamicBalanceClassSampler", - "MiniEpochSampler", -] diff --git a/audioldm2/utilities/tools.py b/audioldm2/utilities/tools.py deleted file mode 100755 index a647a272cdf076b2ae9785bc83724ebd7a897642..0000000000000000000000000000000000000000 --- a/audioldm2/utilities/tools.py +++ /dev/null @@ -1,541 +0,0 @@ -# Author: Haohe Liu -# Email: haoheliu@gmail.com -# Date: 11 Feb 2023 - -import os -import json - -import torch -import torch.nn.functional as F -import numpy as np -import matplotlib -from scipy.io import wavfile -from matplotlib import pyplot as plt - - -matplotlib.use("Agg") - -import hashlib -import os - -import requests -from tqdm import tqdm - -URL_MAP = { - "vggishish_lpaps": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt", - "vggishish_mean_std_melspec_10s_22050hz": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt", - "melception": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt", -} - -CKPT_MAP = { - "vggishish_lpaps": "vggishish16.pt", - "vggishish_mean_std_melspec_10s_22050hz": "train_means_stds_melspec_10s_22050hz.txt", - "melception": "melception-21-05-10T09-28-40.pt", -} - -MD5_MAP = { - "vggishish_lpaps": "197040c524a07ccacf7715d7080a80bd", - "vggishish_mean_std_melspec_10s_22050hz": "f449c6fd0e248936c16f6d22492bb625", - "melception": "a71a41041e945b457c7d3d814bbcf72d", -} - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - -def load_json(fname): - with open(fname, "r") as f: - data = json.load(f) - return data - - -def read_json(dataset_json_file): - with open(dataset_json_file, "r") as fp: - data_json = json.load(fp) - return data_json["data"] - - -def copy_test_subset_data(metadata, testset_copy_target_path): - # metadata = read_json(testset_metadata) - os.makedirs(testset_copy_target_path, exist_ok=True) - if len(os.listdir(testset_copy_target_path)) == len(metadata): - return - else: - # delete files in folder testset_copy_target_path - for file in os.listdir(testset_copy_target_path): - try: - os.remove(os.path.join(testset_copy_target_path, file)) - except Exception as e: - print(e) - - print("Copying test subset data to {}".format(testset_copy_target_path)) - for each in tqdm(metadata): - cmd = "cp {} {}".format(each["wav"], os.path.join(testset_copy_target_path)) - os.system(cmd) - - -def listdir_nohidden(path): - for f in os.listdir(path): - if not f.startswith("."): - yield f - - -def get_restore_step(path): - checkpoints = os.listdir(path) - if os.path.exists(os.path.join(path, "final.ckpt")): - return "final.ckpt", 0 - elif not os.path.exists(os.path.join(path, "last.ckpt")): - steps = [int(x.split(".ckpt")[0].split("step=")[1]) for x in checkpoints] - return checkpoints[np.argmax(steps)], np.max(steps) - else: - steps = [] - for x in checkpoints: - if "last" in x: - if "-v" not in x: - fname = "last.ckpt" - else: - this_version = int(x.split(".ckpt")[0].split("-v")[1]) - steps.append(this_version) - if len(steps) == 0 or this_version > np.max(steps): - fname = "last-v%s.ckpt" % this_version - return fname, 0 - - -def 
download(url, local_path, chunk_size=1024): - os.makedirs(os.path.split(local_path)[0], exist_ok=True) - with requests.get(url, stream=True) as r: - total_size = int(r.headers.get("content-length", 0)) - with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: - with open(local_path, "wb") as f: - for data in r.iter_content(chunk_size=chunk_size): - if data: - f.write(data) - pbar.update(chunk_size) - - -def md5_hash(path): - with open(path, "rb") as f: - content = f.read() - return hashlib.md5(content).hexdigest() - - -def get_ckpt_path(name, root, check=False): - assert name in URL_MAP - path = os.path.join(root, CKPT_MAP[name]) - if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): - print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) - download(URL_MAP[name], path) - md5 = md5_hash(path) - assert md5 == MD5_MAP[name], md5 - return path - - -class KeyNotFoundError(Exception): - def __init__(self, cause, keys=None, visited=None): - self.cause = cause - self.keys = keys - self.visited = visited - messages = list() - if keys is not None: - messages.append("Key not found: {}".format(keys)) - if visited is not None: - messages.append("Visited: {}".format(visited)) - messages.append("Cause:\n{}".format(cause)) - message = "\n".join(messages) - super().__init__(message) - - -def retrieve( - list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False -): - """Given a nested list or dict return the desired value at key expanding - callable nodes if necessary and :attr:`expand` is ``True``. The expansion - is done in-place. - - Parameters - ---------- - list_or_dict : list or dict - Possibly nested list or dictionary. - key : str - key/to/value, path like string describing all keys necessary to - consider to get to the desired value. List indices can also be - passed here. - splitval : str - String that defines the delimiter between keys of the - different depth levels in `key`. - default : obj - Value returned if :attr:`key` is not found. - expand : bool - Whether to expand callable nodes on the path or not. - - Returns - ------- - The desired value or if :attr:`default` is not ``None`` and the - :attr:`key` is not found returns ``default``. - - Raises - ------ - Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is - ``None``. - """ - - keys = key.split(splitval) - - success = True - try: - visited = [] - parent = None - last_key = None - for key in keys: - if callable(list_or_dict): - if not expand: - raise KeyNotFoundError( - ValueError( - "Trying to get past callable node with expand=False." 
- ), - keys=keys, - visited=visited, - ) - list_or_dict = list_or_dict() - parent[last_key] = list_or_dict - - last_key = key - parent = list_or_dict - - try: - if isinstance(list_or_dict, dict): - list_or_dict = list_or_dict[key] - else: - list_or_dict = list_or_dict[int(key)] - except (KeyError, IndexError, ValueError) as e: - raise KeyNotFoundError(e, keys=keys, visited=visited) - - visited += [key] - # final expansion of retrieved value - if expand and callable(list_or_dict): - list_or_dict = list_or_dict() - parent[last_key] = list_or_dict - except KeyNotFoundError as e: - if default is None: - raise e - else: - list_or_dict = default - success = False - - if not pass_success: - return list_or_dict - else: - return list_or_dict, success - - -def to_device(data, device): - if len(data) == 12: - ( - ids, - raw_texts, - speakers, - texts, - src_lens, - max_src_len, - mels, - mel_lens, - max_mel_len, - pitches, - energies, - durations, - ) = data - - speakers = torch.from_numpy(speakers).long().to(device) - texts = torch.from_numpy(texts).long().to(device) - src_lens = torch.from_numpy(src_lens).to(device) - mels = torch.from_numpy(mels).float().to(device) - mel_lens = torch.from_numpy(mel_lens).to(device) - pitches = torch.from_numpy(pitches).float().to(device) - energies = torch.from_numpy(energies).to(device) - durations = torch.from_numpy(durations).long().to(device) - - return ( - ids, - raw_texts, - speakers, - texts, - src_lens, - max_src_len, - mels, - mel_lens, - max_mel_len, - pitches, - energies, - durations, - ) - - if len(data) == 6: - (ids, raw_texts, speakers, texts, src_lens, max_src_len) = data - - speakers = torch.from_numpy(speakers).long().to(device) - texts = torch.from_numpy(texts).long().to(device) - src_lens = torch.from_numpy(src_lens).to(device) - - return (ids, raw_texts, speakers, texts, src_lens, max_src_len) - - -def log(logger, step=None, fig=None, audio=None, sampling_rate=22050, tag=""): - # if losses is not None: - # logger.add_scalar("Loss/total_loss", losses[0], step) - # logger.add_scalar("Loss/mel_loss", losses[1], step) - # logger.add_scalar("Loss/mel_postnet_loss", losses[2], step) - # logger.add_scalar("Loss/pitch_loss", losses[3], step) - # logger.add_scalar("Loss/energy_loss", losses[4], step) - # logger.add_scalar("Loss/duration_loss", losses[5], step) - # if(len(losses) > 6): - # logger.add_scalar("Loss/disc_loss", losses[6], step) - # logger.add_scalar("Loss/fmap_loss", losses[7], step) - # logger.add_scalar("Loss/r_loss", losses[8], step) - # logger.add_scalar("Loss/g_loss", losses[9], step) - # logger.add_scalar("Loss/gen_loss", losses[10], step) - # logger.add_scalar("Loss/diff_loss", losses[11], step) - - if fig is not None: - logger.add_figure(tag, fig) - - if audio is not None: - audio = audio / (max(abs(audio)) * 1.1) - logger.add_audio( - tag, - audio, - sample_rate=sampling_rate, - ) - - -def get_mask_from_lengths(lengths, max_len=None): - batch_size = lengths.shape[0] - if max_len is None: - max_len = torch.max(lengths).item() - - ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device) - mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) - - return mask - - -def expand(values, durations): - out = list() - for value, d in zip(values, durations): - out += [value] * max(0, int(d)) - return np.array(out) - - -def synth_one_sample_val( - targets, predictions, vocoder, model_config, preprocess_config -): - index = np.random.choice(list(np.arange(targets[6].size(0)))) - - basename = targets[0][index] - src_len = 
predictions[8][index].item() - mel_len = predictions[9][index].item() - mel_target = targets[6][index, :mel_len].detach().transpose(0, 1) - - mel_prediction = predictions[0][index, :mel_len].detach().transpose(0, 1) - postnet_mel_prediction = predictions[1][index, :mel_len].detach().transpose(0, 1) - duration = targets[11][index, :src_len].detach().cpu().numpy() - - if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": - pitch = predictions[2][index, :src_len].detach().cpu().numpy() - pitch = expand(pitch, duration) - else: - pitch = predictions[2][index, :mel_len].detach().cpu().numpy() - - if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": - energy = predictions[3][index, :src_len].detach().cpu().numpy() - energy = expand(energy, duration) - else: - energy = predictions[3][index, :mel_len].detach().cpu().numpy() - - with open( - os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") - ) as f: - stats = json.load(f) - stats = stats["pitch"] + stats["energy"][:2] - - # from datetime import datetime - # now = datetime.now() - # current_time = now.strftime("%D:%H:%M:%S") - # np.save(("mel_pred_%s.npy" % current_time).replace("/","-"), mel_prediction.cpu().numpy()) - # np.save(("postnet_mel_prediction_%s.npy" % current_time).replace("/","-"), postnet_mel_prediction.cpu().numpy()) - # np.save(("mel_target_%s.npy" % current_time).replace("/","-"), mel_target.cpu().numpy()) - - fig = plot_mel( - [ - (mel_prediction.cpu().numpy(), pitch, energy), - (postnet_mel_prediction.cpu().numpy(), pitch, energy), - (mel_target.cpu().numpy(), pitch, energy), - ], - stats, - [ - "Raw mel spectrogram prediction", - "Postnet mel prediction", - "Ground-Truth Spectrogram", - ], - ) - - if vocoder is not None: - from .model import vocoder_infer - - wav_reconstruction = vocoder_infer( - mel_target.unsqueeze(0), - vocoder, - model_config, - preprocess_config, - )[0] - wav_prediction = vocoder_infer( - postnet_mel_prediction.unsqueeze(0), - vocoder, - model_config, - preprocess_config, - )[0] - else: - wav_reconstruction = wav_prediction = None - - return fig, wav_reconstruction, wav_prediction, basename - - -def synth_one_sample(mel_input, mel_prediction, labels, vocoder): - if vocoder is not None: - from .model import vocoder_infer - - wav_reconstruction = vocoder_infer( - mel_input.permute(0, 2, 1), - vocoder, - ) - wav_prediction = vocoder_infer( - mel_prediction.permute(0, 2, 1), - vocoder, - ) - else: - wav_reconstruction = wav_prediction = None - - return wav_reconstruction, wav_prediction - - -def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path): - # (diff_output, diff_loss, latent_loss) = diffusion - - basenames = targets[0] - - for i in range(len(predictions[1])): - basename = basenames[i] - src_len = predictions[8][i].item() - mel_len = predictions[9][i].item() - mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1) - # diff_output = diff_output[i, :mel_len].detach().transpose(0, 1) - # duration = predictions[5][i, :src_len].detach().cpu().numpy() - if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": - pitch = predictions[2][i, :src_len].detach().cpu().numpy() - # pitch = expand(pitch, duration) - else: - pitch = predictions[2][i, :mel_len].detach().cpu().numpy() - if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": - energy = predictions[3][i, :src_len].detach().cpu().numpy() - # energy = expand(energy, duration) - else: - energy = 
predictions[3][i, :mel_len].detach().cpu().numpy()
-        # import ipdb; ipdb.set_trace()
-        with open(
-            os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
-        ) as f:
-            stats = json.load(f)
-            stats = stats["pitch"] + stats["energy"][:2]
-
-        fig = plot_mel(
-            [
-                (mel_prediction.cpu().numpy(), pitch, energy),
-            ],
-            stats,
-            ["Synthesized Spectrogram by PostNet"],
-        )
-        # np.save("{}_postnet.npy".format(basename), mel_prediction.cpu().numpy())
-        plt.savefig(os.path.join(path, "{}_postnet_2.png".format(basename)))
-        plt.close()
-
-    from .model import vocoder_infer
-
-    mel_predictions = predictions[1].transpose(1, 2)
-    lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
-    wav_predictions = vocoder_infer(
-        mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
-    )
-
-    sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
-    for wav, basename in zip(wav_predictions, basenames):
-        wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav)
-
-
-def plot_mel(data, stats=None, titles=None):
-    # Each entry of `data` is a (mel, pitch, energy) tuple; `stats`, `pitch` and
-    # `energy` are accepted so the call sites above work, but this simplified
-    # version only draws the mel spectrogram.
-    fig, axes = plt.subplots(len(data), 1, squeeze=False)
-    if titles is None:
-        titles = [None for i in range(len(data))]
-
-    for i in range(len(data)):
-        mel, pitch, energy = data[i]
-        axes[i][0].imshow(mel, origin="lower", aspect="auto")
-        axes[i][0].set_aspect(2.5, adjustable="box")
-        axes[i][0].set_ylim(0, mel.shape[0])
-        axes[i][0].set_title(titles[i], fontsize="medium")
-        axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
-        axes[i][0].set_anchor("W")
-
-    return fig
-
-
-def pad_1D(inputs, PAD=0):
-    def pad_data(x, length, PAD):
-        x_padded = np.pad(
-            x, (0, length - x.shape[0]), mode="constant", constant_values=PAD
-        )
-        return x_padded
-
-    max_len = max((len(x) for x in inputs))
-    padded = np.stack([pad_data(x, max_len, PAD) for x in inputs])
-
-    return padded
-
-
-def pad_2D(inputs, maxlen=None):
-    def pad(x, max_len):
-        PAD = 0
-        if np.shape(x)[0] > max_len:
-            raise ValueError("Input length exceeds max_len")
-
-        s = np.shape(x)[1]
-        x_padded = np.pad(
-            x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD
-        )
-        return x_padded[:, :s]
-
-    if maxlen:
-        output = np.stack([pad(x, maxlen) for x in inputs])
-    else:
-        max_len = max(np.shape(x)[0] for x in inputs)
-        output = np.stack([pad(x, max_len) for x in inputs])
-
-    return output
-
-
-def pad(input_ele, mel_max_length=None):
-    if mel_max_length:
-        max_len = mel_max_length
-    else:
-        max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
-
-    out_list = list()
-    for i, batch in enumerate(input_ele):
-        if len(batch.shape) == 1:
-            one_batch_padded = F.pad(
-                batch, (0, max_len - batch.size(0)), "constant", 0.0
-            )
-        elif len(batch.shape) == 2:
-            one_batch_padded = F.pad(
-                batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
-            )
-        out_list.append(one_batch_padded)
-    out_padded = torch.stack(out_list)
-    return out_padded
diff --git a/audioldm2/utils.py b/audioldm2/utils.py
deleted file mode 100755
index c098e25b99cf78b0c0befc71fb7ba7688e79c899..0000000000000000000000000000000000000000
--- a/audioldm2/utils.py
+++ /dev/null
@@ -1,352 +0,0 @@
-import contextlib
-import importlib
-from huggingface_hub import hf_hub_download
-
-from inspect import isfunction
-import os
-import soundfile as sf
-import time
-import wave
-
-import progressbar
-
-CACHE_DIR = os.getenv(
-    "AUDIOLDM_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache/audioldm2")
-)
-
-def read_list(fname):
-    result = []
-    with open(fname, "r", encoding="utf-8") as f:
-
for each in f.readlines(): - each = each.strip('\n') - result.append(each) - return result - -def get_duration(fname): - with contextlib.closing(wave.open(fname, "r")) as f: - frames = f.getnframes() - rate = f.getframerate() - return frames / float(rate) - - -def get_bit_depth(fname): - with contextlib.closing(wave.open(fname, "r")) as f: - bit_depth = f.getsampwidth() * 8 - return bit_depth - - -def get_time(): - t = time.localtime() - return time.strftime("%d_%m_%Y_%H_%M_%S", t) - - -def seed_everything(seed): - import random, os - import numpy as np - import torch - - random.seed(seed) - os.environ["PYTHONHASHSEED"] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = True - - -def save_wave(waveform, savepath, name="outwav"): - if type(name) is not list: - name = [name] * waveform.shape[0] - - for i in range(waveform.shape[0]): - path = os.path.join( - savepath, - "%s_%s.wav" - % ( - os.path.basename(name[i]) - if (not ".wav" in name[i]) - else os.path.basename(name[i]).split(".")[0], - i, - ), - ) - print("Save audio to %s" % path) - sf.write(path, waveform[i, 0], samplerate=16000) - - -def exists(x): - return x is not None - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def count_params(model, verbose=False): - total_params = sum(p.numel() for p in model.parameters()) - if verbose: - print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") - return total_params - - -def get_obj_from_str(string, reload=False): - module, cls = string.rsplit(".", 1) - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) - - -def instantiate_from_config(config): - if not "target" in config: - if config == "__is_first_stage__": - return None - elif config == "__is_unconditional__": - return None - raise KeyError("Expected key `target` to instantiate.") - try: - return get_obj_from_str(config["target"])(**config.get("params", dict())) - except: - import ipdb - - ipdb.set_trace() - - -def default_audioldm_config(model_name="audioldm2-full"): - basic_config = { - "metadata_root": "/mnt/bn/lqhaoheliu/metadata/processed/dataset_root.json", - "log_directory": "./log/audiomae_pred", - "precision": "high", - "data": { - "train": [ - "audiocaps", - "audioset", - "wavcaps", - "audiostock_music_250k", - "free_to_use_sounds", - "epidemic_sound_effects", - "vggsound", - "million_song_dataset", - ], - "val": "audiocaps", - "test": "audiocaps", - "class_label_indices": "audioset", - "dataloader_add_ons": [ - "extract_kaldi_fbank_feature", - "extract_vits_phoneme_and_flant5_text", - "waveform_rs_48k", - ], - }, - "variables": { - "sampling_rate": 16000, - "mel_bins": 64, - "latent_embed_dim": 8, - "latent_t_size": 256, - "latent_f_size": 16, - "in_channels": 8, - "optimize_ddpm_parameter": True, - "warmup_steps": 5000, - }, - "step": { - "validation_every_n_epochs": 1, - "save_checkpoint_every_n_steps": 5000, - "limit_val_batches": 10, - "max_steps": 1500000, - "save_top_k": 2, - }, - "preprocessing": { - "audio": { - "sampling_rate": 16000, - "max_wav_value": 32768, - "duration": 10.24, - }, - "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024}, - "mel": {"n_mel_channels": 64, "mel_fmin": 0, "mel_fmax": 8000}, - }, - "augmentation": {"mixup": 0}, - "model": { - "target": 
"audioldm2.latent_diffusion.models.ddpm.LatentDiffusion", - "params": { - "first_stage_config": { - "base_learning_rate": 0.000008, - "target": "audioldm2.latent_encoder.autoencoder.AutoencoderKL", - "params": { - "sampling_rate": 16000, - "batchsize": 4, - "monitor": "val/rec_loss", - "image_key": "fbank", - "subband": 1, - "embed_dim": 8, - "time_shuffle": 1, - "lossconfig": { - "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator", - "params": { - "disc_start": 50001, - "kl_weight": 1000, - "disc_weight": 0.5, - "disc_in_channels": 1, - }, - }, - "ddconfig": { - "double_z": True, - "mel_bins": 64, - "z_channels": 8, - "resolution": 256, - "downsample_time": False, - "in_channels": 1, - "out_ch": 1, - "ch": 128, - "ch_mult": [1, 2, 4], - "num_res_blocks": 2, - "attn_resolutions": [], - "dropout": 0, - }, - }, - }, - "base_learning_rate": 0.0001, - "warmup_steps": 5000, - "optimize_ddpm_parameter": True, - "sampling_rate": 16000, - "batchsize": 16, - "linear_start": 0.0015, - "linear_end": 0.0195, - "num_timesteps_cond": 1, - "log_every_t": 200, - "timesteps": 1000, - "unconditional_prob_cfg": 0.1, - "parameterization": "eps", - "first_stage_key": "fbank", - "latent_t_size": 256, - "latent_f_size": 16, - "channels": 8, - "monitor": "val/loss_simple_ema", - "scale_by_std": True, - "unet_config": { - "target": "audioldm2.latent_diffusion.modules.diffusionmodules.openaimodel.UNetModel", - "params": { - "image_size": 64, - "context_dim": [768, 1024], - "in_channels": 8, - "out_channels": 8, - "model_channels": 128, - "attention_resolutions": [8, 4, 2], - "num_res_blocks": 2, - "channel_mult": [1, 2, 3, 5], - "num_head_channels": 32, - "use_spatial_transformer": True, - "transformer_depth": 1, - }, - }, - "evaluation_params": { - "unconditional_guidance_scale": 3.5, - "ddim_sampling_steps": 200, - "n_candidates_per_samples": 3, - }, - "cond_stage_config": { - "crossattn_audiomae_generated": { - "cond_stage_key": "all", - "conditioning_key": "crossattn", - "target": "audioldm2.latent_diffusion.modules.encoders.modules.SequenceGenAudioMAECond", - "params": { - "always_output_audiomae_gt": False, - "learnable": True, - "device": "cuda", - "use_gt_mae_output": True, - "use_gt_mae_prob": 0.25, - "base_learning_rate": 0.0002, - "sequence_gen_length": 8, - "use_warmup": True, - "sequence_input_key": [ - "film_clap_cond1", - "crossattn_flan_t5", - ], - "sequence_input_embed_dim": [512, 1024], - "batchsize": 16, - "cond_stage_config": { - "film_clap_cond1": { - "cond_stage_key": "text", - "conditioning_key": "film", - "target": "audioldm2.latent_diffusion.modules.encoders.modules.CLAPAudioEmbeddingClassifierFreev2", - "params": { - "sampling_rate": 48000, - "embed_mode": "text", - "amodel": "HTSAT-base", - }, - }, - "crossattn_flan_t5": { - "cond_stage_key": "text", - "conditioning_key": "crossattn", - "target": "audioldm2.latent_diffusion.modules.encoders.modules.FlanT5HiddenState", - }, - "crossattn_audiomae_pooled": { - "cond_stage_key": "ta_kaldi_fbank", - "conditioning_key": "crossattn", - "target": "audioldm2.latent_diffusion.modules.encoders.modules.AudioMAEConditionCTPoolRand", - "params": { - "regularization": False, - "no_audiomae_mask": True, - "time_pooling_factors": [8], - "freq_pooling_factors": [8], - "eval_time_pooling": 8, - "eval_freq_pooling": 8, - "mask_ratio": 0, - }, - }, - }, - }, - }, - "crossattn_flan_t5": { - "cond_stage_key": "text", - "conditioning_key": "crossattn", - "target": "audioldm2.latent_diffusion.modules.encoders.modules.FlanT5HiddenState", - 
                    },
-                },
-            },
-        },
-    }
-    return basic_config
-
-
-def get_metadata():
-    return {
-        "audioldm2-full": {
-            "path": os.path.join(
-                CACHE_DIR,
-                "audioldm2-full.pth",
-            ),
-            "url": "https://huggingface.co/haoheliu/audioldm2-full/resolve/main/audioldm2-full.pth",
-        },
-    }
-
-
-class MyProgressBar:
-    def __init__(self):
-        self.pbar = None
-
-    def __call__(self, block_num, block_size, total_size):
-        if not self.pbar:
-            self.pbar = progressbar.ProgressBar(maxval=total_size)
-            self.pbar.start()
-
-        downloaded = block_num * block_size
-        if downloaded < total_size:
-            self.pbar.update(downloaded)
-        else:
-            self.pbar.finish()
-
-
-def download_checkpoint(checkpoint_name="audioldm2-full"):
-    meta = get_metadata()
-    if checkpoint_name not in meta.keys():
-        print(
-            "The model name you provided is not supported. Please use one of the following: ",
-            meta.keys(),
-        )
-        return
-
-    model_id = "haoheliu/%s" % checkpoint_name
-    hf_hub_download(
-        repo_id=model_id,
-        filename=os.path.basename(meta[checkpoint_name]["path"]),
-        local_dir=os.path.dirname(meta[checkpoint_name]["path"]),
-    )
diff --git a/batch.lst b/batch.lst
deleted file mode 100644
index c52c7a52523775ad729bfbb350f9cd70ddfbf3e4..0000000000000000000000000000000000000000
--- a/batch.lst
+++ /dev/null
@@ -1,4 +0,0 @@
-A forest of wind chimes singing a soothing melody in the breeze.
-A violin playing a heartfelt melody.
-A saxophone playing a soulful melody.
-Musical constellations twinkling in the night sky, forming a cosmic melody.
\ No newline at end of file
diff --git a/bin/audioldm2 b/bin/audioldm2
deleted file mode 100755
index 2ff95674ad63326a4b0d7f7b633c2f87949502f3..0000000000000000000000000000000000000000
--- a/bin/audioldm2
+++ /dev/null
@@ -1,131 +0,0 @@
-#!/usr/bin/python3
-import os
-import torch
-import logging
-from audioldm2 import text_to_audio, build_model, save_wave, get_time, read_list
-import argparse
-
-os.environ["TOKENIZERS_PARALLELISM"] = "true"
-matplotlib_logger = logging.getLogger('matplotlib')
-matplotlib_logger.setLevel(logging.WARNING)
-
-
-CACHE_DIR = os.getenv(
-    "AUDIOLDM_CACHE_DIR",
-    os.path.join(os.path.expanduser("~"), ".cache/audioldm2"))
-
-parser = argparse.ArgumentParser()
-
-parser.add_argument(
-    "-t",
-    "--text",
-    type=str,
-    required=False,
-    default="",
-    help="Text prompt to the model for audio generation",
-)
-
-parser.add_argument(
-    "-tl",
-    "--text_list",
-    type=str,
-    required=False,
-    default="",
-    help="A file that contains text prompts to the model for audio generation",
-)
-
-parser.add_argument(
-    "-s",
-    "--save_path",
-    type=str,
-    required=False,
-    help="The path to save model output",
-    default="./output",
-)
-
-parser.add_argument(
-    "--model_name",
-    type=str,
-    required=False,
-    help="The checkpoint to use",
-    default="audioldm2-full",
-    choices=["audioldm2-full"]
-)
-
-parser.add_argument(
-    "-b",
-    "--batchsize",
-    type=int,
-    required=False,
-    default=1,
-    help="How many samples to generate at the same time",
-)
-
-parser.add_argument(
-    "--ddim_steps",
-    type=int,
-    required=False,
-    default=200,
-    help="The number of sampling steps for DDIM",
-)
-
-parser.add_argument(
-    "-gs",
-    "--guidance_scale",
-    type=float,
-    required=False,
-    default=3.5,
-    help="Guidance scale (Large => better quality and relevance to text; Small => better diversity)",
-)
-
-parser.add_argument(
-    "-n",
-    "--n_candidate_gen_per_text",
-    type=int,
-    required=False,
-    default=3,
-    help="Automatic quality control. This number controls the number of candidates (e.g., generate three audios and choose the best to show you). A larger value usually leads to better quality with heavier computation",
-)
-
-parser.add_argument(
-    "--seed",
-    type=int,
-    required=False,
-    default=0,
-    help="Changing this value (any integer) will lead to a different generation result.",
-)
-
-args = parser.parse_args()
-
-torch.set_float32_matmul_precision("high")
-
-save_path = os.path.join(args.save_path, get_time())
-
-text = args.text
-random_seed = args.seed
-duration = 10
-guidance_scale = args.guidance_scale
-n_candidate_gen_per_text = args.n_candidate_gen_per_text
-
-os.makedirs(save_path, exist_ok=True)
-audioldm2 = build_model(model_name=args.model_name)
-
-if args.text_list:
-    print("Generate audio based on the text prompts in %s" % args.text_list)
-    prompt_todo = read_list(args.text_list)
-else:
-    prompt_todo = [text]
-
-for text in prompt_todo:
-    waveform = text_to_audio(
-        audioldm2,
-        text,
-        seed=random_seed,
-        duration=duration,
-        guidance_scale=guidance_scale,
-        ddim_steps=args.ddim_steps,
-        n_candidate_gen_per_text=n_candidate_gen_per_text,
-        batchsize=args.batchsize,
-    )
-
-    save_wave(waveform, save_path, name=text)
diff --git a/bin/audioldm2.cmd b/bin/audioldm2.cmd
deleted file mode 100755
index c164fbfb6a194858b6d9019c8e29df3e57b3172a..0000000000000000000000000000000000000000
--- a/bin/audioldm2.cmd
+++ /dev/null
@@ -1,2 +0,0 @@
-@echo OFF
-python -m audioldm2 %*
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 7611c9394ca30415015459fd4dfc8924bd112ea9..0b008319e8d6cc95822b92cf7f03be4da3f40048 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,5 +3,4 @@ transformers==4.30.2
 --extra-index-url https://download.pytorch.org/whl/cu113
 torch >= 2.0
 huggingface_hub
-soundfile
 audioldm2
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 69f32a48bc29a8e293efbbf3fb05080e940de108..0000000000000000000000000000000000000000
--- a/setup.py
+++ /dev/null
@@ -1,158 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-# python3 setup.py sdist bdist_wheel
-"""
-@File    :   setup.py.py
-@Contact :   haoheliu@gmail.com
-@License :   (C)Copyright 2020-2100
-
-@Modify Time      @Author    @Version    @Description
-------------      -------    --------    -----------
-9/6/21 5:16 PM    Haohe Liu      1.0         None
-"""
-
-# !/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-# Note: To use the 'upload' functionality of this file, you must:
-#   $ pipenv install twine --dev
-
-import io
-import os
-import sys
-from shutil import rmtree
-
-from setuptools import find_packages, setup, Command
-
-# Package meta-data.
-NAME = "audioldm2"
-DESCRIPTION = "This package is written for text-to-audio/music generation."
-URL = "https://github.com/haoheliu/audioldm2"
-EMAIL = "haoheliu@gmail.com"
-AUTHOR = "Haohe Liu"
-REQUIRES_PYTHON = ">=3.7.0"
-VERSION = "0.0.2"
-
-# What packages are required for this module to be executed?
-REQUIRED = [
-    "torch>=1.13.0",
-    "torchaudio>=0.13.0",
-    "torchvision>=0.14.0",
-    "tqdm",
-    "gradio",
-    "pyyaml",
-    "einops",
-    "chardet",
-    "numpy<=1.23.5",
-    "soundfile",
-    "librosa==0.9.2",
-    "scipy",
-    "pandas",
-    "torchlibrosa==0.0.9",
-    "transformers",
-    "progressbar",
-    "ftfy",
-]
-
-# What packages are optional?
-EXTRAS = {}
-
-# The rest you shouldn't have to touch too much :)
-# ------------------------------------------------
-# Except, perhaps the License and Trove Classifiers!
-# If you do change the License, remember to change the Trove Classifier for that!
-
-here = os.path.abspath(os.path.dirname(__file__))
-
-# Import the README and use it as the long-description.
-# Note: this will only work if 'README.md' is present in your MANIFEST.in file! -try: - with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f: - long_description = "\n" + f.read() -except FileNotFoundError: - long_description = DESCRIPTION - -# Load the package's __version__.py module as a dictionary. -about = {} -if not VERSION: - project_slug = NAME.lower().replace("-", "_").replace(" ", "_") - with open(os.path.join(here, project_slug, "__version__.py")) as f: - exec(f.read(), about) -else: - about["__version__"] = VERSION - - -class UploadCommand(Command): - """Support setup.py upload.""" - - description = "Build and publish the package." - user_options = [] - - @staticmethod - def status(s): - """Prints things in bold.""" - print("\033[1m{0}\033[0m".format(s)) - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - try: - self.status("Removing previous builds…") - rmtree(os.path.join(here, "dist")) - except OSError: - pass - - self.status("Building Source and Wheel (universal) distribution…") - os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable)) - - self.status("Uploading the package to PyPI via Twine…") - os.system("twine upload dist/*") - - self.status("Pushing git tags…") - os.system("git tag v{0}".format(about["__version__"])) - os.system("git push --tags") - - sys.exit() - - -# Where the magic happens: -setup( - name=NAME, - version=about["__version__"], - description=DESCRIPTION, - long_description=long_description, - long_description_content_type="text/markdown", - author=AUTHOR, - author_email=EMAIL, - python_requires=REQUIRES_PYTHON, - url=URL, - # packages=find_packages(exclude=[]), - # If your package is a single module, use this instead of 'packages': - # entry_points={ - # 'console_scripts': ['mycli=mymodule:cli'], - # }, - install_requires=REQUIRED, - extras_require=EXTRAS, - packages=find_packages(), - include_package_data=True, - license="MIT", - classifiers=[ - # Trove classifiers - # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers - "License :: OSI Approved :: MIT License", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", - ], - # $ setup.py publish support. - cmdclass={ - "upload": UploadCommand, - }, - scripts=["bin/audioldm2.cmd", "bin/audioldm2"], -) diff --git a/tests/code_coverage.py b/tests/code_coverage.py deleted file mode 100644 index deb035e9fedffacd8bf3a9c37d5566fa8fd4e819..0000000000000000000000000000000000000000 --- a/tests/code_coverage.py +++ /dev/null @@ -1,3 +0,0 @@ -import os - -os.system('python3 bin/audioldm2 -t "A toilet flushing and water trickling"') diff --git a/tests/code_coverage.sh b/tests/code_coverage.sh deleted file mode 100644 index 0a5920c645c262c80436ff586e7ad4825e9e5622..0000000000000000000000000000000000000000 --- a/tests/code_coverage.sh +++ /dev/null @@ -1 +0,0 @@ -pytest --cov=src tests/* \ No newline at end of file
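Note on usage after this cleanup: the deletions above remove the vendored audioldm2 package, its CLI entry point (bin/audioldm2) and its packaging files from the Space, while requirements.txt keeps the pip-distributed audioldm2 package as a dependency. The snippet below is a minimal, untested sketch of the equivalent programmatic call. It assumes the pip package exposes the same build_model / text_to_audio / save_wave / get_time helpers as the deleted script; every argument name mirrors the deleted bin/audioldm2, and the prompt is taken from the deleted batch.lst.

# Minimal usage sketch (assumption: the pip-installed `audioldm2` package keeps the
# same API surface as the vendored copy deleted above).
import os

import torch
from audioldm2 import build_model, text_to_audio, save_wave, get_time

torch.set_float32_matmul_precision("high")

# Mirror the deleted CLI: write outputs into a timestamped folder under ./output
save_path = os.path.join("./output", get_time())
os.makedirs(save_path, exist_ok=True)

prompt = "A violin playing a heartfelt melody."  # prompt borrowed from the deleted batch.lst

audioldm2 = build_model(model_name="audioldm2-full")
waveform = text_to_audio(
    audioldm2,
    prompt,
    seed=0,
    duration=10,
    guidance_scale=3.5,
    ddim_steps=200,
    n_candidate_gen_per_text=3,
    batchsize=1,
)
save_wave(waveform, save_path, name=prompt)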