diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..68bc17f9ff2104a9d7b6777058bb4c343ca72609 --- /dev/null +++ b/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..8060e2bf1747e386fb9ef65837631e0a7adcaaab --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ + MIT License + + Copyright (c) Xubo Liu + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE \ No newline at end of file diff --git a/README.md b/README.md index 5635c5cbfd32406e91a587a15f9a9c8e48832d01..9a7c85e03be2a4e982cb29e6f4f3b6cd6f0f5a73 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,4 @@ sdk_version: 3.47.1 app_file: app.py pinned: false license: mit ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +--- \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..2797d006ac6d62dbe8f4ed41d166acbee807a46f --- /dev/null +++ b/app.py @@ -0,0 +1,115 @@ +from pathlib import Path +from threading import Thread + +import gdown +import gradio as gr +import librosa +import numpy as np +import torch + +from pipeline import build_audiosep + +CHECKPOINTS_DIR = Path("checkpoint") + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# The model will be loaded in the future +MODEL_NAME = CHECKPOINTS_DIR / "audiosep_base_4M_steps.ckpt" +MODEL = None + + +description = """ +# AudioSep: Separate Anything You Describe +[[Project Page]](https://audio-agi.github.io/Separate-Anything-You-Describe) [[Paper]](https://audio-agi.github.io/Separate-Anything-You-Describe/AudioSep_arXiv.pdf) [[Code]](https://github.com/Audio-AGI/AudioSep) + +We introduce AudioSep, a foundation model for open-domain sound separation with natural language queries. +AudioSep demonstrates strong separation performance and impressivezero-shot generalization ability on +numerous tasks such as audio event separation, musical instrument separation, and speech enhancement. 
+""" + + +def get_model(): + model = build_audiosep( + config_yaml="config/audiosep_base.yaml", + checkpoint_path=MODEL_NAME, + device=DEVICE, + ) + return model + + +def inference(audio_file_path: str, text: str): + print(f"Separate audio from [{audio_file_path}] with textual query [{text}]") + mixture, _ = librosa.load(audio_file_path, sr=32000, mono=True) + + with torch.no_grad(): + text = [text] + + conditions = MODEL.query_encoder.get_query_embed( + modality="text", text=text, device=DEVICE + ) + + input_dict = { + "mixture": torch.Tensor(mixture)[None, None, :].to(DEVICE), + "condition": conditions, + } + + sep_segment = MODEL.ss_model(input_dict)["waveform"] + + sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy() + + return 32000, np.round(sep_segment * 32767).astype(np.int16) + + +def download_models(): + CHECKPOINTS_DIR.mkdir(exist_ok=True) + success_file = CHECKPOINTS_DIR / "_SUCCESS" + + models = ( + ( + "https://drive.google.com/file/d/1wQuXThdATXrkmkPM2sRGaNapJ4mTqmlY/view?usp=sharing", + MODEL_NAME, + ), + ( + "https://drive.google.com/file/d/11oj8_tPG6SXgw5fIEsZ5HiWZnJOrvdhw/view?usp=sharing", + CHECKPOINTS_DIR / "music_speech_audioset_epoch_15_esc_89.98.pt", + ), + ) + + def download(models): + for model_url, model_path in models: + gdown.download(model_url, str(model_path), quiet=False, fuzzy=True) + + success_file.touch() + + global MODEL + MODEL = get_model() + button.update(value="Separate", interactive=True) + + if not success_file.exists(): + thread = Thread(target=download, args=[models]) + thread.start() + + +with gr.Blocks(title="AudioSep") as demo: + gr.Markdown(description) + with gr.Row(): + with gr.Column(): + input_audio = gr.Audio() + text = gr.Textbox() + with gr.Column(): + with gr.Column(): + output_audio = gr.Audio(scale=10) + button = gr.Button( + "Downloading the models...", + variant="primary", + scale=2, + size="lg", + interactive=False, + ) + button.click( + fn=inference, inputs=[input_audio, text], outputs=[output_audio] + ) + +download_models() + +demo.queue().launch(share=True) diff --git a/assets/results.png b/assets/results.png new file mode 100644 index 0000000000000000000000000000000000000000..5656fca373ed7f049baad5d72da9c9c3017a3b10 Binary files /dev/null and b/assets/results.png differ diff --git a/callbacks/base.py b/callbacks/base.py new file mode 100644 index 0000000000000000000000000000000000000000..9caad4bacba02b46c6668563f58c95b45d370c06 --- /dev/null +++ b/callbacks/base.py @@ -0,0 +1,35 @@ +import os +import lightning.pytorch as pl +from lightning.pytorch.utilities import rank_zero_only + + +class CheckpointEveryNSteps(pl.Callback): + def __init__( + self, + checkpoints_dir, + save_step_frequency, + ) -> None: + r"""Save a checkpoint every N steps. 
+ + Args: + checkpoints_dir (str): directory to save checkpoints + save_step_frequency (int): save checkpoint every N step + """ + + self.checkpoints_dir = checkpoints_dir + self.save_step_frequency = save_step_frequency + + @rank_zero_only + def on_train_batch_end(self, *args, **kwargs) -> None: + r"""Save a checkpoint every N steps.""" + + trainer = args[0] + global_step = trainer.global_step + + if global_step == 1 or global_step % self.save_step_frequency == 0: + + ckpt_path = os.path.join( + self.checkpoints_dir, + "step={}.ckpt".format(global_step)) + trainer.save_checkpoint(ckpt_path) + print("Save checkpoint to {}".format(ckpt_path)) diff --git a/config/audiosep_base.yaml b/config/audiosep_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cb54d16397e523cb67a602a631276aa0770e7ae0 --- /dev/null +++ b/config/audiosep_base.yaml @@ -0,0 +1,41 @@ +--- +task_name: AudioSep + +data: + datafiles: + - 'datafiles/template.json' + + sampling_rate: 32000 + segment_seconds: 5 + loudness_norm: + lower_db: -10 + higher_db: 10 + max_mix_num: 2 + +model: + query_net: CLAP + condition_size: 512 + model_type: ResUNet30 + input_channels: 1 + output_channels: 1 + resume_checkpoint: "" + use_text_ratio: 1.0 + +train: + optimizer: + optimizer_type: AdamW + learning_rate: 1e-3 + warm_up_steps: 10000 + reduce_lr_steps: 1000000 + lr_lambda_type: constant_warm_up + num_nodes: 1 + num_workers: 6 + loss_type: l1_wav + sync_batchnorm: True + batch_size_per_device: 12 + steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`. + evaluate_step_frequency: 10000 # Evaluate every #evaluate_step_frequency steps. + save_step_frequency: 20000 # Save every #save_step_frequency steps. + early_stop_steps: 10000001 + random_seed: 1234 + diff --git a/data/audiotext_dataset.py b/data/audiotext_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1e9cc2c037516f7768ffca8d8083d137a4879dba --- /dev/null +++ b/data/audiotext_dataset.py @@ -0,0 +1,91 @@ +import json +import random +import torch +import torchaudio +from torch.utils.data import Dataset + + +class AudioTextDataset(Dataset): + """Can sample data from audio-text databases + Params: + sampling_rate: audio sampling rate + max_clip_len: max length (seconds) of audio clip to be sampled + """ + def __init__( + self, + datafiles=[''], + sampling_rate=32000, + max_clip_len=5, + ): + all_data_json = [] + for datafile in datafiles: + with open(datafile, 'r') as fp: + data_json = json.load(fp)['data'] + all_data_json.extend(data_json) + self.all_data_json = all_data_json + + self.sampling_rate = sampling_rate + self.max_length = max_clip_len * sampling_rate + + def __len__(self): + return len(self.all_data_json) + + def _cut_or_randomcrop(self, waveform): + # waveform: [1, samples] + # random crop + if waveform.size(1) > self.max_length: + random_idx = random.randint(0, waveform.size(1)-self.max_length) + waveform = waveform[:, random_idx:random_idx+self.max_length] + else: + temp_wav = torch.zeros(1, self.max_length) + temp_wav[:, 0:waveform.size(1)] = waveform + waveform = temp_wav + + assert waveform.size(1) == self.max_length, \ + f"number of audio samples is {waveform.size(1)}" + + return waveform + + def _read_audio(self, index): + try: + audio_path = self.all_data_json[index]['wav'] + audio_data, audio_rate = torchaudio.load(audio_path, channels_first=True) + text = self.all_data_json[index]['caption'] + + # drop short utterance + if audio_data.size(1) < self.sampling_rate * 1: + raise Exception(f'{audio_path} 
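A short sketch of how the `CheckpointEveryNSteps` callback above could be attached to a Lightning `Trainer`. Only the callback wiring is taken from this diff; the commented-out model and datamodule are stand-ins for any `LightningModule`/`LightningDataModule`.

```python
import lightning.pytorch as pl

from callbacks.base import CheckpointEveryNSteps

# save_step_frequency mirrors train.save_step_frequency in config/audiosep_base.yaml.
checkpoint_callback = CheckpointEveryNSteps(
    checkpoints_dir="checkpoints/audiosep",
    save_step_frequency=20000,
)

trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    callbacks=[checkpoint_callback],
)
# trainer.fit(pl_model, datamodule=data_module)  # pl_model / data_module are placeholders
```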
is too short, drop it ...') + + return text, audio_data, audio_rate + + except Exception as e: + print(f'error: {e} occurs, when loading {audio_path}') + random_index = random.randint(0, len(self.all_data_json)-1) + return self._read_audio(index=random_index) + + def __getitem__(self, index): + # create a audio tensor + text, audio_data, audio_rate = self._read_audio(index) + audio_len = audio_data.shape[1] / audio_rate + # convert stero to single channel + if audio_data.shape[0] > 1: + # audio_data: [samples] + audio_data = (audio_data[0] + audio_data[1]) / 2 + else: + audio_data = audio_data.squeeze(0) + + # resample audio clip + if audio_rate != self.sampling_rate: + audio_data = torchaudio.functional.resample(audio_data, orig_freq=audio_rate, new_freq=self.sampling_rate) + + audio_data = audio_data.unsqueeze(0) + + audio_data = self._cut_or_randomcrop(audio_data) + + data_dict = { + 'text': text, + 'waveform': audio_data, + 'modality': 'audio_text' + } + + return data_dict diff --git a/data/datamodules.py b/data/datamodules.py new file mode 100644 index 0000000000000000000000000000000000000000..73136ada66a08103d6bf318ca1154525404faa41 --- /dev/null +++ b/data/datamodules.py @@ -0,0 +1,122 @@ +from typing import Dict, List, Optional, NoReturn +import torch +import lightning.pytorch as pl +from torch.utils.data import DataLoader +from data.audiotext_dataset import AudioTextDataset + + +class DataModule(pl.LightningDataModule): + def __init__( + self, + train_dataset: object, + batch_size: int, + num_workers: int + ): + r"""Data module. To get one batch of data: + + code-block:: python + + data_module.setup() + + for batch_data_dict in data_module.train_dataloader(): + print(batch_data_dict.keys()) + break + + Args: + train_sampler: Sampler object + train_dataset: Dataset object + num_workers: int + distributed: bool + """ + super().__init__() + self._train_dataset = train_dataset + self.num_workers = num_workers + self.batch_size = batch_size + self.collate_fn = collate_fn + + + def prepare_data(self): + # download, split, etc... + # only called on 1 GPU/TPU in distributed + pass + + def setup(self, stage: Optional[str] = None) -> NoReturn: + r"""called on every device.""" + + # make assignments here (val/train/test split) + # called on every process in DDP + + # SegmentSampler is used for selecting segments for training. + # On multiple devices, each SegmentSampler samples a part of mini-batch + # data. + self.train_dataset = self._train_dataset + + + def train_dataloader(self) -> torch.utils.data.DataLoader: + r"""Get train loader.""" + train_loader = DataLoader( + dataset=self.train_dataset, + batch_size=self.batch_size, + collate_fn=self.collate_fn, + num_workers=self.num_workers, + pin_memory=True, + persistent_workers=False, + shuffle=True + ) + + return train_loader + + def val_dataloader(self): + # val_split = Dataset(...) + # return DataLoader(val_split) + pass + + def test_dataloader(self): + # test_split = Dataset(...) + # return DataLoader(test_split) + pass + + def teardown(self): + # clean up after fit or test + # called on every process in DDP + pass + + +def collate_fn(list_data_dict): + r"""Collate mini-batch data to inputs and targets for training. + + Args: + list_data_dict: e.g., [ + { + 'text': 'a sound of dog', + 'waveform': (1, samples), + 'modality': 'audio_text' + } + ... + ] + Returns: + data_dict: e.g. + 'audio_text': { + 'text': ['a sound of dog', ...] 
+ 'waveform': (batch_size, 1, samples) + } + """ + + at_list_data_dict = [data_dict for data_dict in list_data_dict if data_dict['modality']=='audio_text'] + + at_data_dict = {} + + if len(at_list_data_dict) > 0: + for key in at_list_data_dict[0].keys(): + at_data_dict[key] = [at_data_dict[key] for at_data_dict in at_list_data_dict] + if key == 'waveform': + at_data_dict[key] = torch.stack(at_data_dict[key]) + elif key == 'text': + at_data_dict[key] = [text for text in at_data_dict[key]] + + + data_dict = { + 'audio_text': at_data_dict + } + + return data_dict \ No newline at end of file diff --git a/data/waveform_mixers.py b/data/waveform_mixers.py new file mode 100644 index 0000000000000000000000000000000000000000..3b3f4df61c4a97d6450c524315b6ac359d82215f --- /dev/null +++ b/data/waveform_mixers.py @@ -0,0 +1,127 @@ +import random +import sre_compile +import numpy as np +import torch +import torch.nn as nn +import pyloudnorm as pyln + + +class SegmentMixer(nn.Module): + def __init__(self, max_mix_num, lower_db, higher_db): + super(SegmentMixer, self).__init__() + + self.max_mix_num = max_mix_num + self.loudness_param = { + 'lower_db': lower_db, + 'higher_db': higher_db, + } + + def __call__(self, waveforms): + + batch_size = waveforms.shape[0] + + data_dict = { + 'segment': [], + 'mixture': [], + } + + for n in range(0, batch_size): + + segment = waveforms[n].clone() + + # create zero tensors as the background template + noise = torch.zeros_like(segment) + + mix_num = random.randint(2, self.max_mix_num) + assert mix_num >= 2 + + for i in range(1, mix_num): + next_segment = waveforms[(n + i) % batch_size] + rescaled_next_segment = dynamic_loudnorm(audio=next_segment, reference=segment, **self.loudness_param) + noise += rescaled_next_segment + + # randomly normalize background noise + noise = dynamic_loudnorm(audio=noise, reference=segment, **self.loudness_param) + + # create audio mixyure + mixture = segment + noise + + # declipping if need be + max_value = torch.max(torch.abs(mixture)) + if max_value > 1: + segment *= 0.9 / max_value + mixture *= 0.9 / max_value + + data_dict['segment'].append(segment) + data_dict['mixture'].append(mixture) + + for key in data_dict.keys(): + data_dict[key] = torch.stack(data_dict[key], dim=0) + + # return data_dict + return data_dict['mixture'], data_dict['segment'] + + +def rescale_to_match_energy(segment1, segment2): + + ratio = get_energy_ratio(segment1, segment2) + rescaled_segment1 = segment1 / ratio + return rescaled_segment1 + + +def get_energy(x): + return torch.mean(x ** 2) + + +def get_energy_ratio(segment1, segment2): + + energy1 = get_energy(segment1) + energy2 = max(get_energy(segment2), 1e-10) + ratio = (energy1 / energy2) ** 0.5 + ratio = torch.clamp(ratio, 0.02, 50) + return ratio + + +def dynamic_loudnorm(audio, reference, lower_db=-10, higher_db=10): + rescaled_audio = rescale_to_match_energy(audio, reference) + + delta_loudness = random.randint(lower_db, higher_db) + + gain = np.power(10.0, delta_loudness / 20.0) + + return gain * rescaled_audio + + +def torch_to_numpy(tensor): + """Convert a PyTorch tensor to a NumPy array.""" + if isinstance(tensor, torch.Tensor): + return tensor.detach().cpu().numpy() + else: + raise ValueError("Input must be a PyTorch tensor.") + + +def numpy_to_torch(array): + """Convert a NumPy array to a PyTorch tensor.""" + if isinstance(array, np.ndarray): + return torch.from_numpy(array) + else: + raise ValueError("Input must be a NumPy array.") + + +# decayed +def random_loudness_norm(audio, lower_db=-35, 
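Putting the data pieces together, the sketch below shows how a training batch could flow from `AudioTextDataset` through `collate_fn` into `SegmentMixer`; the mixing parameters mirror `config/audiosep_base.yaml`. The datafile path points at the template shipped in this diff, so real audio paths would be needed for it to actually load data, and the loss call in the comment refers to `l1_wav` defined in `losses.py` later in the patch.

```python
import torch
from torch.utils.data import DataLoader

from data.audiotext_dataset import AudioTextDataset
from data.datamodules import collate_fn
from data.waveform_mixers import SegmentMixer

# 5-second clips at 32 kHz, as in config/audiosep_base.yaml.
dataset = AudioTextDataset(
    datafiles=["datafiles/template.json"],  # replace with real datafiles
    sampling_rate=32000,
    max_clip_len=5,
)

loader = DataLoader(dataset, batch_size=12, shuffle=True, collate_fn=collate_fn)

mixer = SegmentMixer(max_mix_num=2, lower_db=-10, higher_db=10)

for batch in loader:
    waveforms = batch["audio_text"]["waveform"]  # (batch_size, 1, samples)
    mixture, segment = mixer(waveforms)          # mixtures and their target sources
    # loss = l1_wav({"segment": model_output}, {"segment": segment})
    break
```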
higher_db=-15, sr=32000): + device = audio.device + audio = torch_to_numpy(audio.squeeze(0)) + # randomly select a norm volume + norm_vol = random.randint(lower_db, higher_db) + + # measure the loudness first + meter = pyln.Meter(sr) # create BS.1770 meter + loudness = meter.integrated_loudness(audio) + # loudness normalize audio + normalized_audio = pyln.normalize.loudness(audio, loudness, norm_vol) + + normalized_audio = numpy_to_torch(normalized_audio).unsqueeze(0) + + return normalized_audio.to(device) + diff --git a/datafiles/template.json b/datafiles/template.json new file mode 100644 index 0000000000000000000000000000000000000000..f0895e0b9060b64465637894bf982fe00e52b954 --- /dev/null +++ b/datafiles/template.json @@ -0,0 +1,8 @@ +{ + "data": [ + { + "wav": "path_to_audio_file", + "caption": "textual_desciptions" + } + ] +} \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..d124e0b8929ff8a809bb736e29d442781c459390 --- /dev/null +++ b/environment.yml @@ -0,0 +1,326 @@ +name: AudioSep +channels: + - pytorch + - nvidia + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - backcall=0.2.0=pyhd3eb1b0_0 + - blas=1.0=mkl + - boltons=23.0.0=py310h06a4308_0 + - brotlipy=0.7.0=py310h7f8727e_1002 + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2023.01.10=h06a4308_0 + - certifi=2022.12.7=py310h06a4308_0 + - cffi=1.15.1=py310h5eee18b_3 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - comm=0.1.2=py310h06a4308_0 + - conda=23.3.1=py310h06a4308_0 + - conda-content-trust=0.1.3=py310h06a4308_0 + - conda-package-handling=2.0.2=py310h06a4308_0 + - conda-package-streaming=0.7.0=py310h06a4308_0 + - cryptography=38.0.4=py310h9ce1e76_0 + - cuda=11.6.1=0 + - cuda-cccl=11.6.55=hf6102b2_0 + - cuda-command-line-tools=11.6.2=0 + - cuda-compiler=11.6.2=0 + - cuda-cudart=11.6.55=he381448_0 + - cuda-cudart-dev=11.6.55=h42ad0f4_0 + - cuda-cuobjdump=11.6.124=h2eeebcb_0 + - cuda-cupti=11.6.124=h86345e5_0 + - cuda-cuxxfilt=11.6.124=hecbf4f6_0 + - cuda-driver-dev=11.6.55=0 + - cuda-gdb=12.1.55=0 + - cuda-libraries=11.6.1=0 + - cuda-libraries-dev=11.6.1=0 + - cuda-memcheck=11.8.86=0 + - cuda-nsight=12.1.55=0 + - cuda-nsight-compute=12.1.0=0 + - cuda-nvcc=11.6.124=hbba6d2d_0 + - cuda-nvdisasm=12.1.55=0 + - cuda-nvml-dev=11.6.55=haa9ef22_0 + - cuda-nvprof=12.1.55=0 + - cuda-nvprune=11.6.124=he22ec0a_0 + - cuda-nvrtc=11.6.124=h020bade_0 + - cuda-nvrtc-dev=11.6.124=h249d397_0 + - cuda-nvtx=11.6.124=h0630a44_0 + - cuda-nvvp=12.1.55=0 + - cuda-runtime=11.6.1=0 + - cuda-samples=11.6.101=h8efea70_0 + - cuda-sanitizer-api=12.1.55=0 + - cuda-toolkit=11.6.1=0 + - cuda-tools=11.6.1=0 + - cuda-visual-tools=11.6.1=0 + - debugpy=1.5.1=py310h295c915_0 + - decorator=5.1.1=pyhd3eb1b0_0 + - flit-core=3.8.0=py310h06a4308_0 + - freetype=2.12.1=h4a9f257_0 + - gds-tools=1.6.0.25=0 + - giflib=5.2.1=h5eee18b_3 + - gmp=6.2.1=h295c915_3 + - gnutls=3.6.15=he1e5248_0 + - idna=3.4=py310h06a4308_0 + - intel-openmp=2021.4.0=h06a4308_3561 + - ipykernel=6.19.2=py310h2f386ee_0 + - ipython=8.12.0=py310h06a4308_0 + - jpeg=9e=h5eee18b_1 + - jsonpatch=1.32=pyhd3eb1b0_0 + - jsonpointer=2.1=pyhd3eb1b0_0 + - jupyter_client=8.1.0=py310h06a4308_0 + - jupyter_core=5.3.0=py310h06a4308_0 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libcublas=11.9.2.110=h5e84587_0 + - libcublas-dev=11.9.2.110=h5c901ab_0 + - libcufft=10.7.1.112=hf425ae0_0 + - libcufft-dev=10.7.1.112=ha5ce4c0_0 + 
- libcufile=1.6.0.25=0 + - libcufile-dev=1.6.0.25=0 + - libcurand=10.3.2.56=0 + - libcurand-dev=10.3.2.56=0 + - libcusolver=11.3.4.124=h33c3c4e_0 + - libcusparse=11.7.2.124=h7538f96_0 + - libcusparse-dev=11.7.2.124=hbbe9722_0 + - libdeflate=1.17=h5eee18b_0 + - libffi=3.4.2=h6a678d5_6 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h7f8727e_2 + - libidn2=2.3.2=h7f8727e_0 + - libnpp=11.6.3.124=hd2722f0_0 + - libnpp-dev=11.6.3.124=h3c42840_0 + - libnvjpeg=11.6.2.124=hd473ad6_0 + - libnvjpeg-dev=11.6.2.124=hb5906b9_0 + - libpng=1.6.39=h5eee18b_0 + - libsodium=1.0.18=h7b6447c_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.19.0=h5eee18b_0 + - libtiff=4.5.0=h6a678d5_2 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=1.41.5=h5eee18b_0 + - libwebp=1.2.4=h11a3e52_1 + - libwebp-base=1.2.4=h5eee18b_1 + - lz4-c=1.9.4=h6a678d5_0 + - matplotlib-inline=0.1.6=py310h06a4308_0 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py310h7f8727e_0 + - mkl_fft=1.3.1=py310hd6ae3a3_0 + - mkl_random=1.2.2=py310h00e6091_0 + - ncurses=6.4=h6a678d5_0 + - nest-asyncio=1.5.6=py310h06a4308_0 + - nettle=3.7.3=hbbd107a_1 + - nsight-compute=2023.1.0.15=0 + - numpy=1.23.5=py310hd5efca6_0 + - numpy-base=1.23.5=py310h8e6c178_0 + - openh264=2.1.1=h4ff587b_0 + - openssl=1.1.1t=h7f8727e_0 + - packaging=23.0=py310h06a4308_0 + - parso=0.8.3=pyhd3eb1b0_0 + - pexpect=4.8.0=pyhd3eb1b0_3 + - pickleshare=0.7.5=pyhd3eb1b0_1003 + - pip=22.3.1=py310h06a4308_0 + - platformdirs=2.5.2=py310h06a4308_0 + - pluggy=1.0.0=py310h06a4308_1 + - psutil=5.9.0=py310h5eee18b_0 + - ptyprocess=0.7.0=pyhd3eb1b0_2 + - pure_eval=0.2.2=pyhd3eb1b0_0 + - pycosat=0.6.4=py310h5eee18b_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pyopenssl=22.0.0=pyhd3eb1b0_0 + - pysocks=1.7.1=py310h06a4308_0 + - python=3.10.9=h7a1cb2a_0 + - python-dateutil=2.8.2=pyhd3eb1b0_0 + - pytorch=1.13.1=py3.10_cuda11.6_cudnn8.3.2_0 + - pytorch-cuda=11.6=h867d48c_1 + - pytorch-mutex=1.0=cuda + - pyzmq=23.2.0=py310h6a678d5_0 + - readline=8.2=h5eee18b_0 + - requests=2.28.1=py310h06a4308_0 + - ruamel.yaml=0.17.21=py310h5eee18b_0 + - ruamel.yaml.clib=0.2.6=py310h5eee18b_1 + - setuptools=65.6.3=py310h06a4308_0 + - six=1.16.0=pyhd3eb1b0_1 + - sqlite=3.40.1=h5082296_0 + - stack_data=0.2.0=pyhd3eb1b0_0 + - tk=8.6.12=h1ccaba5_0 + - toolz=0.12.0=py310h06a4308_0 + - torchaudio=0.13.1=py310_cu116 + - torchvision=0.14.1=py310_cu116 + - tornado=6.2=py310h5eee18b_0 + - tqdm=4.64.1=py310h06a4308_0 + - typing_extensions=4.4.0=py310h06a4308_0 + - tzdata=2022g=h04d1e81_0 + - urllib3=1.26.14=py310h06a4308_0 + - wheel=0.37.1=pyhd3eb1b0_0 + - xz=5.2.10=h5eee18b_1 + - zeromq=4.3.4=h2531618_0 + - zlib=1.2.13=h5eee18b_0 + - zstandard=0.18.0=py310h5eee18b_0 + - zstd=1.5.4=hc292b87_0 + - pip: + - absl-py==1.4.0 + - aiohttp==3.8.4 + - aiosignal==1.3.1 + - anyio==3.6.2 + - appdirs==1.4.4 + - arrow==1.2.3 + - asttokens==2.2.1 + - async-generator==1.10 + - async-timeout==4.0.2 + - attrs==22.2.0 + - audioread==3.0.0 + - av==10.0.0 + - beartype==0.12.0 + - beautifulsoup4==4.12.2 + - blessed==1.20.0 + - braceexpand==0.1.7 + - cachetools==5.3.0 + - click==8.1.3 + - contourpy==1.0.7 + - croniter==1.3.10 + - cycler==0.11.0 + - dataclasses-json==0.5.8 + - dateutils==0.6.12 + - decord==0.6.0 + - deepdiff==6.3.0 + - dtk==0.2 + - exceptiongroup==1.1.1 + - executing==1.2.0 + - fastapi==0.88.0 + - ffmpeg==1.4 + - ffmpeg-python==0.2.0 + - filelock==3.12.0 + - fonttools==4.39.3 + - frozenlist==1.3.3 + - fsspec==2023.4.0 + - ftfy==6.1.1 + - future==0.18.3 + - gammatone==1.0 + - google-auth==2.17.3 + - 
google-auth-oauthlib==1.0.0 + - greenlet==2.0.2 + - grpcio==1.54.0 + - h11==0.14.0 + - h5py==3.8.0 + - hickle==5.0.2 + - huggingface-hub==0.14.1 + - humanize==4.6.0 + - imageio==2.27.0 + - inquirer==3.1.3 + - ipdb==0.13.13 + - itsdangerous==2.1.2 + - jedi==0.18.2 + - jinja2==3.1.2 + - joblib==1.2.0 + - kiwisolver==1.4.4 + - langchain==0.0.216 + - langchainplus-sdk==0.0.17 + - lazy-loader==0.2 + - librosa==0.10.0.post2 + - lightning==2.0.0 + - lightning-cloud==0.5.33 + - lightning-utilities==0.8.0 + - llvmlite==0.39.1 + - markdown==3.4.3 + - markdown-it-py==2.2.0 + - markupsafe==2.1.2 + - marshmallow==3.19.0 + - marshmallow-enum==1.5.1 + - matplotlib==3.7.1 + - mdurl==0.1.2 + - mergedeep==1.3.4 + - mock==5.0.2 + - msgpack==1.0.5 + - msgpack-numpy==0.4.8 + - multidict==6.0.4 + - musdb==0.4.0 + - mypy-extensions==1.0.0 + - networkx==3.1 + - nose==1.3.7 + - numba==0.56.4 + - numexpr==2.8.4 + - oauthlib==3.2.2 + - openai==0.27.8 + - openapi-schema-pydantic==1.2.4 + - opencv-python==4.7.0.72 + - ordered-set==4.1.0 + - outcome==1.2.0 + - pandas==1.5.3 + - panns-inference==0.1.0 + - pesq==0.0.4 + - pillow==9.5.0 + - pooch==1.6.0 + - prompt-toolkit==3.0.38 + - protobuf==4.22.3 + - pyaml==23.5.9 + - pyasn1==0.5.0 + - pyasn1-modules==0.3.0 + - pydantic==1.10.7 + - pygments==2.14.0 + - pyjwt==2.6.0 + - pyloudnorm==0.1.1 + - pyparsing==3.0.9 + - pystoi==0.3.3 + - python-editor==1.0.4 + - python-multipart==0.0.6 + - pytorch-ignite==0.3.0 + - pytorch-lightning==2.0.1.post0 + - pytz==2023.3 + - pywavelets==1.4.1 + - pyyaml==6.0 + - readchar==4.0.5 + - regex==2023.3.23 + - requests-oauthlib==1.3.1 + - resampy==0.4.2 + - rich==13.3.3 + - rsa==4.9 + - scikit-image==0.20.0 + - scikit-learn==1.2.2 + - scipy==1.10.1 + - selenium==4.8.3 + - simplejpeg==1.6.6 + - sniffio==1.3.0 + - sortedcontainers==2.4.0 + - soundfile==0.12.1 + - soupsieve==2.4 + - soxr==0.3.5 + - sqlalchemy==2.0.17 + - stack-data==0.6.2 + - starlette==0.22.0 + - starsessions==1.3.0 + - stempeg==0.2.3 + - tenacity==8.2.2 + - tensorboard==2.12.2 + - tensorboard-data-server==0.7.0 + - tensorboard-plugin-wit==1.8.1 + - termcolor==1.1.0 + - threadpoolctl==3.1.0 + - tifffile==2023.3.21 + - timm==0.3.2 + - tokenizers==0.13.3 + - tomli==2.0.1 + - torchfile==0.1.0 + - torchlibrosa==0.1.0 + - torchmetrics==0.11.4 + - traitlets==5.9.0 + - transformers==4.28.1 + - trio==0.22.0 + - trio-websocket==0.10.2 + - typeguard==3.0.2 + - typing-extensions==4.5.0 + - typing-inspect==0.9.0 + - uvicorn==0.21.1 + - visdom==0.1.8.9 + - wcwidth==0.2.6 + - webdataset==0.2.48 + - websocket-client==1.5.1 + - websockets==11.0.1 + - werkzeug==2.2.3 + - wget==3.2 + - wsproto==1.2.0 + - yarl==1.8.2 + - zenodo-get==1.3.4 + - zsvision==0.7.8 \ No newline at end of file diff --git a/losses.py b/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..0bf599fa6ecb91c086394b06c81ce3dee927a012 --- /dev/null +++ b/losses.py @@ -0,0 +1,17 @@ +import torch + + +def l1(output, target): + return torch.mean(torch.abs(output - target)) + + +def l1_wav(output_dict, target_dict): + return l1(output_dict['segment'], target_dict['segment']) + + +def get_loss_function(loss_type): + if loss_type == "l1_wav": + return l1_wav + + else: + raise NotImplementedError("Error!") diff --git a/models/CLAP/__init__.py b/models/CLAP/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/CLAP/open_clip/__init__.py b/models/CLAP/open_clip/__init__.py new file mode 100755 index 
0000000000000000000000000000000000000000..e9f728f2f273be5d5fdbec6c6cc41d737176a8c0 --- /dev/null +++ b/models/CLAP/open_clip/__init__.py @@ -0,0 +1,25 @@ +from .factory import ( + list_models, + create_model, + create_model_and_transforms, + add_model_config, +) +from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics +from .model import ( + CLAP, + CLAPTextCfg, + CLAPVisionCfg, + CLAPAudioCfp, + convert_weights_to_fp16, + trace_model, +) +from .openai import load_openai_model, list_openai_models +from .pretrained import ( + list_pretrained, + list_pretrained_tag_models, + list_pretrained_model_tags, + get_pretrained_url, + download_pretrained, +) +from .tokenizer import SimpleTokenizer, tokenize +from .transform import image_transform diff --git a/models/CLAP/open_clip/bert.py b/models/CLAP/open_clip/bert.py new file mode 100755 index 0000000000000000000000000000000000000000..a83d96d2a77ed05198efc05837522bc88d2499cc --- /dev/null +++ b/models/CLAP/open_clip/bert.py @@ -0,0 +1,40 @@ +from transformers import BertTokenizer, BertModel + +tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") +model = BertModel.from_pretrained("bert-base-uncased") +text = "Replace me by any text you'd like." + + +def bert_embeddings(text): + # text = "Replace me by any text you'd like." + encoded_input = tokenizer(text, return_tensors="pt") + output = model(**encoded_input) + return output + + +from transformers import RobertaTokenizer, RobertaModel + +tokenizer = RobertaTokenizer.from_pretrained("roberta-base") +model = RobertaModel.from_pretrained("roberta-base") +text = "Replace me by any text you'd like." + + +def Roberta_embeddings(text): + # text = "Replace me by any text you'd like." + encoded_input = tokenizer(text, return_tensors="pt") + output = model(**encoded_input) + return output + + +from transformers import BartTokenizer, BartModel + +tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") +model = BartModel.from_pretrained("facebook/bart-base") +text = "Replace me by any text you'd like." + + +def bart_embeddings(text): + # text = "Replace me by any text you'd like." 
+ encoded_input = tokenizer(text, return_tensors="pt") + output = model(**encoded_input) + return output diff --git a/models/CLAP/open_clip/factory.py b/models/CLAP/open_clip/factory.py new file mode 100755 index 0000000000000000000000000000000000000000..844f9ca0e12a0ff43ba3e042a3e43530ebe91b8c --- /dev/null +++ b/models/CLAP/open_clip/factory.py @@ -0,0 +1,277 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path + +import torch + +from .model import CLAP, convert_weights_to_fp16 +from .openai import load_openai_model +from .pretrained import get_pretrained_url, download_pretrained +from .transform import image_transform + +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = (".json",) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f"*{ext}")) + + for cf in config_files: + if os.path.basename(cf)[0] == ".": + continue # Ignore hidden files + + with open(cf, "r") as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = { + k: v + for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])) + } + + +_rescan_model_configs() # initial populate of model config registry + + +def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True): + checkpoint = torch.load(checkpoint_path, map_location=map_location) + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + if skip_params: + if next(iter(state_dict.items()))[0].startswith("module"): + state_dict = {k[7:]: v for k, v in state_dict.items()} + # for k in state_dict: + # if k.startswith('transformer'): + # v = state_dict.pop(k) + # state_dict['text_branch.' + k[12:]] = v + return state_dict + + +def create_model( + amodel_name: str, + tmodel_name: str, + pretrained: str = "", + precision: str = "fp32", + device: torch.device = torch.device("cpu"), + jit: bool = False, + force_quick_gelu: bool = False, + openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"), + skip_params=True, + pretrained_audio: str = "", + pretrained_text: str = "", + enable_fusion: bool = False, + fusion_type: str = "None" + # pretrained_image: bool = False, +): + amodel_name = amodel_name.replace( + "/", "-" + ) # for callers using old naming with / in ViT names + pretrained_orig = pretrained + pretrained = pretrained.lower() + if pretrained == "openai": + if amodel_name in _MODEL_CONFIGS: + logging.info(f"Loading {amodel_name} model config.") + model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) + else: + logging.error( + f"Model config for {amodel_name} not found; available models {list_models()}." 
+ ) + raise RuntimeError(f"Model config for {amodel_name} not found.") + + logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.") + # Hard Code in model name + model_cfg["text_cfg"]["model_type"] = tmodel_name + model = load_openai_model( + "ViT-B-16", + model_cfg, + device=device, + jit=jit, + cache_dir=openai_model_cache_dir, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 + if precision == "amp" or precision == "fp32": + model = model.float() + else: + if amodel_name in _MODEL_CONFIGS: + logging.info(f"Loading {amodel_name} model config.") + model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) + else: + logging.error( + f"Model config for {amodel_name} not found; available models {list_models()}." + ) + raise RuntimeError(f"Model config for {amodel_name} not found.") + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + # if pretrained_image: + # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}): + # # pretrained weight loading for timm models set via vision_cfg + # model_cfg['vision_cfg']['timm_model_pretrained'] = True + # else: + # assert False, 'pretrained image towers currently only supported for timm models' + model_cfg["text_cfg"]["model_type"] = tmodel_name + model_cfg["enable_fusion"] = enable_fusion + model_cfg["fusion_type"] = fusion_type + model = CLAP(**model_cfg) + + if pretrained: + checkpoint_path = "" + url = get_pretrained_url(amodel_name, pretrained) + if url: + checkpoint_path = download_pretrained(url, root=openai_model_cache_dir) + elif os.path.exists(pretrained_orig): + checkpoint_path = pretrained_orig + if checkpoint_path: + logging.info( + f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})." + ) + ckpt = load_state_dict(checkpoint_path, skip_params=True) + model.load_state_dict(ckpt) + param_names = [n for n, p in model.named_parameters()] + # for n in param_names: + # print(n, "\t", "Loaded" if n in ckpt else "Unloaded") + else: + logging.warning( + f"Pretrained weights ({pretrained}) not found for model {amodel_name}." + ) + raise RuntimeError( + f"Pretrained weights ({pretrained}) not found for model {amodel_name}." + ) + + if pretrained_audio: + if amodel_name.startswith("PANN"): + if "Cnn14_mAP" in pretrained_audio: # official checkpoint + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["model"] + keys = list(audio_ckpt.keys()) + for key in keys: + if ( + "spectrogram_extractor" not in key + and "logmel_extractor" not in key + ): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key] = v + elif os.path.basename(pretrained_audio).startswith( + "PANN" + ): # checkpoint trained via HTSAT codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["state_dict"] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith("sed_model"): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." 
+ key[10:]] = v + elif os.path.basename(pretrained_audio).startswith( + "finetuned" + ): # checkpoint trained via linear probe codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + else: + raise ValueError("Unknown audio checkpoint") + elif amodel_name.startswith("HTSAT"): + if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["state_dict"] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith("sed_model") and ( + "spectrogram_extractor" not in key + and "logmel_extractor" not in key + ): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key[10:]] = v + elif os.path.basename(pretrained_audio).startswith( + "HTSAT" + ): # checkpoint trained via HTSAT codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["state_dict"] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith("sed_model"): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key[10:]] = v + elif os.path.basename(pretrained_audio).startswith( + "finetuned" + ): # checkpoint trained via linear probe codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + else: + raise ValueError("Unknown audio checkpoint") + else: + raise f"this audio encoder pretrained checkpoint is not support" + + model.load_state_dict(audio_ckpt, strict=False) + logging.info( + f"Loading pretrained {amodel_name} weights ({pretrained_audio})." + ) + param_names = [n for n, p in model.named_parameters()] + for n in param_names: + print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded") + + model.to(device=device) + if precision == "fp16": + assert device.type != "cpu" + convert_weights_to_fp16(model) + + if jit: + model = torch.jit.script(model) + + return model, model_cfg + + +def create_model_and_transforms( + model_name: str, + pretrained: str = "", + precision: str = "fp32", + device: torch.device = torch.device("cpu"), + jit: bool = False, + force_quick_gelu: bool = False, + # pretrained_image: bool = False, +): + model = create_model( + model_name, + pretrained, + precision, + device, + jit, + force_quick_gelu=force_quick_gelu, + # pretrained_image=pretrained_image + ) + preprocess_train = image_transform(model.visual.image_size, is_train=True) + preprocess_val = image_transform(model.visual.image_size, is_train=False) + return model, preprocess_train, preprocess_val + + +def list_models(): + """enumerate available model architectures based on config files""" + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """add model config path or file and update registry""" + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() diff --git a/models/CLAP/open_clip/feature_fusion.py b/models/CLAP/open_clip/feature_fusion.py new file mode 100755 index 0000000000000000000000000000000000000000..dbe4e170e05894c12ebdc36ba1dc1de65e441b89 --- /dev/null +++ b/models/CLAP/open_clip/feature_fusion.py @@ -0,0 +1,192 @@ +""" +Feature Fusion for Varible-Length Data Processing +AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py +According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021 +""" + +import torch +import torch.nn as nn + + +class DAF(nn.Module): + """ + 直接相加 DirectAddFuse + """ + + def __init__(self): + 
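As a rough illustration of the factory above, the call below builds an un-pretrained CLAP model from one of the JSON configs under `model_configs/`. The config name `HTSAT-base` and the text model name `roberta` are assumptions about what ships in that directory, not something this diff guarantees.

```python
import torch

from models.CLAP.open_clip import create_model

# Builds CLAP(**model_cfg) from model_configs/HTSAT-base.json (name assumed);
# no pretrained CLAP weights are loaded because `pretrained` is left empty.
model, model_cfg = create_model(
    amodel_name="HTSAT-base",   # audio branch config (assumed to exist)
    tmodel_name="roberta",      # written into model_cfg["text_cfg"]["model_type"]
    pretrained="",
    precision="fp32",
    device=torch.device("cpu"),
    enable_fusion=False,
    fusion_type="None",
)
print(model_cfg["embed_dim"])
```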
super(DAF, self).__init__() + + def forward(self, x, residual): + return x + residual + + +class iAFF(nn.Module): + """ + 多特征融合 iAFF + """ + + def __init__(self, channels=64, r=4, type="2D"): + super(iAFF, self).__init__() + inter_channels = int(channels // r) + + if type == "1D": + # 本地注意力 + self.local_att = nn.Sequential( + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + + # 全局注意力 + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + + # 第二次本地注意力 + self.local_att2 = nn.Sequential( + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + # 第二次全局注意力 + self.global_att2 = nn.Sequential( + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + elif type == "2D": + # 本地注意力 + self.local_att = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + # 全局注意力 + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + # 第二次本地注意力 + self.local_att2 = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + # 第二次全局注意力 + self.global_att2 = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + else: + raise f"the type is not supported" + + self.sigmoid = nn.Sigmoid() + + def forward(self, x, residual): + flag = False + xa = x + residual + if xa.size(0) == 1: + xa = torch.cat([xa, xa], dim=0) + flag = True + xl = self.local_att(xa) + xg = self.global_att(xa) + xlg = xl + xg + wei = self.sigmoid(xlg) + xi = x * wei + residual * (1 - wei) + + xl2 = self.local_att2(xi) + xg2 = self.global_att(xi) + xlg2 = xl2 + xg2 + wei2 = self.sigmoid(xlg2) + xo = x * wei2 + residual * (1 - wei2) + if flag: + xo = xo[0].unsqueeze(0) + return xo + + +class AFF(nn.Module): + """ + 多特征融合 AFF + """ + + def __init__(self, channels=64, r=4, type="2D"): + super(AFF, self).__init__() + inter_channels = int(channels // r) + + if type == "1D": + self.local_att = nn.Sequential( + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, 
padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + elif type == "2D": + self.local_att = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + else: + raise f"the type is not supported." + + self.sigmoid = nn.Sigmoid() + + def forward(self, x, residual): + flag = False + xa = x + residual + if xa.size(0) == 1: + xa = torch.cat([xa, xa], dim=0) + flag = True + xl = self.local_att(xa) + xg = self.global_att(xa) + xlg = xl + xg + wei = self.sigmoid(xlg) + xo = 2 * x * wei + 2 * residual * (1 - wei) + if flag: + xo = xo[0].unsqueeze(0) + return xo diff --git a/models/CLAP/open_clip/htsat.py b/models/CLAP/open_clip/htsat.py new file mode 100755 index 0000000000000000000000000000000000000000..3b856c6a43df162116a941f1b5c76e93713b276a --- /dev/null +++ b/models/CLAP/open_clip/htsat.py @@ -0,0 +1,1308 @@ +# Ke Chen +# knutchen@ucsd.edu +# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION +# Some layers designed on the model +# below codes are based and referred from https://github.com/microsoft/Swin-Transformer +# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf + +import torch +import torch.nn as nn +import torch.nn.functional as F +from itertools import repeat +import collections.abc +import math +import warnings + +from torch.nn.init import _calculate_fan_in_and_fan_out +import torch.utils.checkpoint as checkpoint + +import random + +from torchlibrosa.stft import Spectrogram, LogmelFilterBank +from torchlibrosa.augmentation import SpecAugmentation + +from itertools import repeat +from .utils import do_mixup, interpolate + +from .feature_fusion import iAFF, AFF, DAF + +# from PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. 
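A minimal sketch of exercising the attentional fusion modules above: `DAF`, `AFF`, and `iAFF` each take a feature map and a residual of the same shape and return their (attention-weighted) combination. The tensor sizes here are arbitrary.

```python
import torch

from models.CLAP.open_clip.feature_fusion import AFF, DAF, iAFF

x = torch.randn(4, 64, 8, 8)          # (batch, channels, H, W)
residual = torch.randn(4, 64, 8, 8)

daf_out = DAF()(x, residual)                          # plain addition
aff_out = AFF(channels=64, r=4, type="2D")(x, residual)
iaff_out = iAFF(channels=64, r=4, type="2D")(x, residual)

print(aff_out.shape)  # torch.Size([4, 64, 8, 8])
```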
+ """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + patch_stride=16, + enable_fusion=False, + fusion_type="None", + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patch_stride = to_2tuple(patch_stride) + self.img_size = img_size + self.patch_size = patch_size + self.patch_stride = patch_stride + self.grid_size = ( + img_size[0] // patch_stride[0], + img_size[1] // patch_stride[1], + ) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + padding = ( + (patch_size[0] - patch_stride[0]) // 2, + (patch_size[1] - patch_stride[1]) // 2, + ) + + if (self.enable_fusion) and (self.fusion_type == "channel_map"): + self.proj = nn.Conv2d( + in_chans * 4, + embed_dim, + kernel_size=patch_size, + stride=patch_stride, + padding=padding, + ) + else: + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_stride, + padding=padding, + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + self.mel_conv2d = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=(patch_size[0], patch_size[1] * 3), + stride=(patch_stride[0], patch_stride[1] * 3), + padding=padding, + ) + if self.fusion_type == "daf_2d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_2d": + self.fusion_model = AFF(channels=embed_dim, type="2D") + elif self.fusion_type == "iaff_2d": + self.fusion_model = iAFF(channels=embed_dim, type="2D") + + def forward(self, x, longer_idx=None): + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + global_x = x[:, 0:1, :, :] + + # global processing + B, C, H, W = global_x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ global_x = self.proj(global_x) + TW = global_x.size(-1) + if len(longer_idx) > 0: + # local processing + local_x = x[longer_idx, 1:, :, :].contiguous() + B, C, H, W = local_x.shape + local_x = local_x.view(B * C, 1, H, W) + local_x = self.mel_conv2d(local_x) + local_x = local_x.view( + B, C, local_x.size(1), local_x.size(2), local_x.size(3) + ) + local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3) + TB, TC, TH, _ = local_x.size() + if local_x.size(-1) < TW: + local_x = torch.cat( + [ + local_x, + torch.zeros( + (TB, TC, TH, TW - local_x.size(-1)), + device=global_x.device, + ), + ], + dim=-1, + ) + else: + local_x = local_x[:, :, :, :TW] + + global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x) + x = global_x + else: + B, C, H, W = x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
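The transformer building blocks defined above (`DropPath`, `Mlp`) combine in the usual residual pattern; a small sketch with arbitrary token shapes is below. In eval mode (or with `drop_prob=0`) `DropPath` is the identity, while in training it zeroes roughly `drop_prob` of the samples and rescales the kept ones by `1 / keep_prob`.

```python
import torch

from models.CLAP.open_clip.htsat import DropPath, Mlp

tokens = torch.randn(4, 49, 96)          # (batch, tokens, dim)

mlp = Mlp(in_features=96, hidden_features=384, drop=0.1)
drop = DropPath(drop_prob=0.2)

mlp.eval(); drop.eval()
out = tokens + drop(mlp(tokens))         # residual branch; DropPath is identity in eval mode
print(out.shape)                         # torch.Size([4, 49, 96])
```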
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view( + B, H // window_size, W // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r"""Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
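`window_partition` and `window_reverse` above are exact inverses whenever the height and width are divisible by the window size, which the sketch below checks on a toy tensor.

```python
import torch

from models.CLAP.open_clip.htsat import window_partition, window_reverse

x = torch.randn(2, 8, 8, 32)             # (B, H, W, C)
windows = window_partition(x, window_size=4)
print(windows.shape)                      # torch.Size([8, 4, 4, 32]), num_windows*B first

restored = window_reverse(windows, window_size=4, H=8, W=8)
assert torch.equal(restored, x)           # exact round trip
```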
Default: 0.0 + """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B_, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + def extra_repr(self): + return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}" + + +# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model +class SwinTransformerBlock(nn.Module): + r"""Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + norm_before_mlp="ln", + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.norm_before_mlp = norm_before_mlp + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + if self.norm_before_mlp == "ln": + self.norm2 = nn.LayerNorm(dim) + elif self.norm_before_mlp == "bn": + self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose( + 1, 2 + ) + else: + raise NotImplementedError + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill( + attn_mask != 0, float(-100.0) + ).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + # pdb.set_trace() + H, W = self.input_resolution + # print("H: ", H) + # print("W: ", W) + # pdb.set_trace() + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) + else: + shifted_x = x + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, 
self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows, attn = self.attn( + x_windows, mask=self.attn_mask + ) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x, attn + + def extra_repr(self): + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + ) + + +class PatchMerging(nn.Module): + r"""Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self): + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + norm_before_mlp="ln", + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + norm_layer=norm_layer, + norm_before_mlp=norm_before_mlp, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=norm_layer + ) + else: + self.downsample = None + + def forward(self, x): + attns = [] + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x, attn = blk(x) + if not self.training: + attns.append(attn.unsqueeze(0)) + if self.downsample is not None: + x = self.downsample(x) + if not self.training: + attn = torch.cat(attns, dim=0) + attn = torch.mean(attn, dim=0) + return x, attn + + def extra_repr(self): + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +# The Core of HTSAT +class HTSAT_Swin_Transformer(nn.Module): + r"""HTSAT based on the Swin Transformer + Args: + spec_size (int | tuple(int)): Input Spectrogram size. Default 256 + patch_size (int | tuple(int)): Patch size. Default: 4 + path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4 + in_chans (int): Number of input image channels. Default: 1 (mono) + num_classes (int): Number of classes for classification head. Default: 527 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 8 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False + config (module): The configuration Module from config.py + """ + + def __init__( + self, + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + in_chans=1, + num_classes=527, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + norm_before_mlp="ln", + config=None, + enable_fusion=False, + fusion_type="None", + **kwargs, + ): + super(HTSAT_Swin_Transformer, self).__init__() + + self.config = config + self.spec_size = spec_size + self.patch_stride = patch_stride + self.patch_size = patch_size + self.window_size = window_size + self.embed_dim = embed_dim + self.depths = depths + self.ape = ape + self.in_chans = in_chans + self.num_classes = num_classes + self.num_heads = num_heads + self.num_layers = len(self.depths) + self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1)) + + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + + self.qkv_bias = qkv_bias + self.qk_scale = None + + self.patch_norm = patch_norm + self.norm_layer = norm_layer if self.patch_norm else None + self.norm_before_mlp = norm_before_mlp + self.mlp_ratio = mlp_ratio + + self.use_checkpoint = use_checkpoint + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # process mel-spec ; used only once + self.freq_ratio = self.spec_size // self.config.mel_bins + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + self.interpolate_ratio = 32 # Downsampled ratio + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=config.window_size, + hop_length=config.hop_size, + win_length=config.window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=config.sample_rate, + n_fft=config.window_size, + n_mels=config.mel_bins, + fmin=config.fmin, + fmax=config.fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) # 2 2 + self.bn0 = nn.BatchNorm2d(self.config.mel_bins) + + # split spctrogram into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=self.spec_size, + patch_size=self.patch_size, + in_chans=self.in_chans, + embed_dim=self.embed_dim, + norm_layer=self.norm_layer, + patch_stride=patch_stride, + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.grid_size + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dim) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) + + self.pos_drop = nn.Dropout(p=self.drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(self.embed_dim * 2**i_layer), + input_resolution=( + patches_resolution[0] // (2**i_layer), + patches_resolution[1] // 
(2**i_layer), + ), + depth=self.depths[i_layer], + num_heads=self.num_heads[i_layer], + window_size=self.window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + qk_scale=self.qk_scale, + drop=self.drop_rate, + attn_drop=self.attn_drop_rate, + drop_path=dpr[ + sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1]) + ], + norm_layer=self.norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + norm_before_mlp=self.norm_before_mlp, + ) + self.layers.append(layer) + + self.norm = self.norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.maxpool = nn.AdaptiveMaxPool1d(1) + + SF = ( + self.spec_size + // (2 ** (len(self.depths) - 1)) + // self.patch_stride[0] + // self.freq_ratio + ) + self.tscam_conv = nn.Conv2d( + in_channels=self.num_features, + out_channels=self.num_classes, + kernel_size=(SF, 3), + padding=(0, 1), + ) + self.head = nn.Linear(num_classes, num_classes) + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"] + ): + self.mel_conv1d = nn.Sequential( + nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2), + nn.BatchNorm1d(64), + ) + if self.fusion_type == "daf_1d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_1d": + self.fusion_model = AFF(channels=64, type="1D") + elif self.fusion_type == "iaff_1d": + self.fusion_model = iAFF(channels=64, type="1D") + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"absolute_pos_embed"} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"relative_position_bias_table"} + + def forward_features(self, x, longer_idx=None): + # A deprecated optimization for using a hierarchical output from different blocks + + frames_num = x.shape[2] + x = self.patch_embed(x, longer_idx=longer_idx) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + for i, layer in enumerate(self.layers): + x, attn = layer(x) + # for x + x = self.norm(x) + B, N, C = x.shape + SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] + ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1] + x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST) + B, C, F, T = x.shape + # group 2D CNN + c_freq_bin = F // self.freq_ratio + x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T) + x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1) + # get latent_output + fine_grained_latent_output = torch.mean(x, dim=2) + fine_grained_latent_output = interpolate( + fine_grained_latent_output.permute(0, 2, 1).contiguous(), + 8 * self.patch_stride[1], + ) + + latent_output = self.avgpool(torch.flatten(x, 2)) + latent_output = torch.flatten(latent_output, 1) + + # display the attention map, if needed + + x = self.tscam_conv(x) + x = torch.flatten(x, 2) # B, C, T + + fpx = interpolate( + torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1] + ) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + output_dict = { + "framewise_output": fpx, # already sigmoided + "clipwise_output": torch.sigmoid(x), + "fine_grained_embedding": fine_grained_latent_output, + "embedding": latent_output, + } + + return output_dict + + def 
crop_wav(self, x, crop_size, spe_pos=None): + time_steps = x.shape[2] + tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device) + for i in range(len(x)): + if spe_pos is None: + crop_pos = random.randint(0, time_steps - crop_size - 1) + else: + crop_pos = spe_pos + tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :] + return tx + + # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model + def reshape_wav2img(self, x): + B, C, T, F = x.shape + target_T = int(self.spec_size * self.freq_ratio) + target_F = self.spec_size // self.freq_ratio + assert ( + T <= target_T and F <= target_F + ), "the wav size should less than or equal to the swin input size" + # to avoid bicubic zero error + if T < target_T: + x = nn.functional.interpolate( + x, (target_T, x.shape[3]), mode="bicubic", align_corners=True + ) + if F < target_F: + x = nn.functional.interpolate( + x, (x.shape[2], target_F), mode="bicubic", align_corners=True + ) + x = x.permute(0, 1, 3, 2).contiguous() + x = x.reshape( + x.shape[0], + x.shape[1], + x.shape[2], + self.freq_ratio, + x.shape[3] // self.freq_ratio, + ) + # print(x.shape) + x = x.permute(0, 1, 3, 2, 4).contiguous() + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4]) + return x + + # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model + def repeat_wat2img(self, x, cur_pos): + B, C, T, F = x.shape + target_T = int(self.spec_size * self.freq_ratio) + target_F = self.spec_size // self.freq_ratio + assert ( + T <= target_T and F <= target_F + ), "the wav size should less than or equal to the swin input size" + # to avoid bicubic zero error + if T < target_T: + x = nn.functional.interpolate( + x, (target_T, x.shape[3]), mode="bicubic", align_corners=True + ) + if F < target_F: + x = nn.functional.interpolate( + x, (x.shape[2], target_F), mode="bicubic", align_corners=True + ) + x = x.permute(0, 1, 3, 2).contiguous() # B C F T + x = x[:, :, :, cur_pos : cur_pos + self.spec_size] + x = x.repeat(repeats=(1, 1, 4, 1)) + return x + + def forward( + self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None + ): # out_feat_keys: List[str] = None): + + if self.enable_fusion and x["longer"].sum() == 0: + # if no audio is longer than 10s, then randomly select one audio to be longer + x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True + + if not self.enable_fusion: + x = x["waveform"].to(device=device, non_blocking=True) + x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + if self.training: + x = self.spec_augmenter(x) + + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.reshape_wav2img(x) + output_dict = self.forward_features(x) + else: + longer_list = x["longer"].to(device=device, non_blocking=True) + x = x["mel_fusion"].to(device=device, non_blocking=True) + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + longer_list_idx = torch.where(longer_list)[0] + if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]: + new_x = x[:, 0:1, :, :].clone().contiguous() + if len(longer_list_idx) > 0: + # local processing + fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous() + FB, FC, FT, FF = fusion_x_local.size() + fusion_x_local = fusion_x_local.view(FB * FC, FT, FF) + fusion_x_local = torch.permute( + fusion_x_local, 
(0, 2, 1) + ).contiguous() + fusion_x_local = self.mel_conv1d(fusion_x_local) + fusion_x_local = fusion_x_local.view( + FB, FC, FF, fusion_x_local.size(-1) + ) + fusion_x_local = ( + torch.permute(fusion_x_local, (0, 2, 1, 3)) + .contiguous() + .flatten(2) + ) + if fusion_x_local.size(-1) < FT: + fusion_x_local = torch.cat( + [ + fusion_x_local, + torch.zeros( + (FB, FF, FT - fusion_x_local.size(-1)), + device=device, + ), + ], + dim=-1, + ) + else: + fusion_x_local = fusion_x_local[:, :, :FT] + # 1D fusion + new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous() + new_x[longer_list_idx] = self.fusion_model( + new_x[longer_list_idx], fusion_x_local + ) + x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :] + else: + x = new_x + + elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]: + x = x # no change + + if self.training: + x = self.spec_augmenter(x) + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.reshape_wav2img(x) + output_dict = self.forward_features(x, longer_idx=longer_list_idx) + + # if infer_mode: + # # in infer mode. we need to handle different length audio input + # frame_num = x.shape[2] + # target_T = int(self.spec_size * self.freq_ratio) + # repeat_ratio = math.floor(target_T / frame_num) + # x = x.repeat(repeats=(1,1,repeat_ratio,1)) + # x = self.reshape_wav2img(x) + # output_dict = self.forward_features(x) + # else: + # if x.shape[2] > self.freq_ratio * self.spec_size: + # if self.training: + # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size) + # x = self.reshape_wav2img(x) + # output_dict = self.forward_features(x) + # else: + # # Change: Hard code here + # overlap_size = (x.shape[2] - 1) // 4 + # output_dicts = [] + # crop_size = (x.shape[2] - 1) // 2 + # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size): + # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos) + # tx = self.reshape_wav2img(tx) + # output_dicts.append(self.forward_features(tx)) + # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device) + # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device) + # for d in output_dicts: + # clipwise_output += d["clipwise_output"] + # framewise_output += d["framewise_output"] + # clipwise_output = clipwise_output / len(output_dicts) + # framewise_output = framewise_output / len(output_dicts) + # output_dict = { + # 'framewise_output': framewise_output, + # 'clipwise_output': clipwise_output + # } + # else: # this part is typically used, and most easy one + # x = self.reshape_wav2img(x) + # output_dict = self.forward_features(x) + # x = self.head(x) + + # We process the data in the dataloader part, in that here we only consider the input_T < fixed_T + + return output_dict + + +def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"): + try: + + assert audio_cfg.model_name in [ + "tiny", + "base", + "large", + ], "model name for HTS-AT is wrong!" 
+ if audio_cfg.model_name == "tiny": + model = HTSAT_Swin_Transformer( + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + num_classes=audio_cfg.class_num, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + config=audio_cfg, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + elif audio_cfg.model_name == "base": + model = HTSAT_Swin_Transformer( + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + num_classes=audio_cfg.class_num, + embed_dim=128, + depths=[2, 2, 12, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + config=audio_cfg, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + elif audio_cfg.model_name == "large": + model = HTSAT_Swin_Transformer( + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + num_classes=audio_cfg.class_num, + embed_dim=256, + depths=[2, 2, 12, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + config=audio_cfg, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + + return model + except: + raise RuntimeError( + f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough." + ) diff --git a/models/CLAP/open_clip/linear_probe.py b/models/CLAP/open_clip/linear_probe.py new file mode 100755 index 0000000000000000000000000000000000000000..9d7e23b6b67a53e16d050d675a99d01d7d04d581 --- /dev/null +++ b/models/CLAP/open_clip/linear_probe.py @@ -0,0 +1,66 @@ +import numpy as np +import torch.nn.functional as F +from torch import nn +from .model import MLPLayers + + +class LinearProbe(nn.Module): + def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None): + """ + Args: + model: nn.Module + mlp: bool, if True, then use the MLP layer as the linear probe module + freeze: bool, if Ture, then freeze all the CLAP model's layers when training the linear probe + in_ch: int, the output channel from CLAP model + out_ch: int, the output channel from linear probe (class_num) + act: torch.nn.functional, the activation function before the loss function + """ + super().__init__() + in_ch = 512 + self.clap_model = model + self.clap_model.text_branch = None # to save memory + self.freeze = freeze + if mlp: + self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch]) + else: + self.lp_layer = nn.Linear(in_ch, out_ch) + + if self.freeze: + for param in self.clap_model.parameters(): + param.requires_grad = False + + if act == "None": + self.act = None + elif act == "relu": + self.act = nn.ReLU() + elif act == "elu": + self.act = nn.ELU() + elif act == "prelu": + self.act = nn.PReLU(num_parameters=in_ch) + elif act == "softmax": + self.act = nn.Softmax(dim=-1) + elif act == "sigmoid": + self.act = nn.Sigmoid() + + def forward(self, x, mix_lambda=None, device=None): + """ + Args: + x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list + mix_lambda: torch.tensor [batch], the mixup lambda + Returns: + class_prob: torch.tensor [batch, class_num] + + """ + # batchnorm cancel grandient + if self.freeze: + self.clap_model.eval() + + x = self.clap_model.audio_projection( + self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[ + "embedding" + ] + ) + out = self.lp_layer(x) + if self.act is not None: + out = self.act(out) + return out diff --git a/models/CLAP/open_clip/loss.py b/models/CLAP/open_clip/loss.py new file mode 100755 index 0000000000000000000000000000000000000000..cc66298a14997da4aa2efc71e37c0a6bcda53fd1 --- /dev/null +++ b/models/CLAP/open_clip/loss.py @@ -0,0 +1,398 @@ +from multiprocessing.sharedctypes import 
Value +import torch +import torch.distributed.nn +from torch import distributed as dist, nn as nn +from torch.nn import functional as F +import numpy as np +from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +def gather_features( + audio_features, + text_features, + audio_features_mlp=None, + text_features_mlp=None, + local_loss=False, + gather_with_grad=False, + rank=0, + world_size=1, + use_horovod=False, + mlp_loss=False, +): + if use_horovod: + assert hvd is not None, "Please install horovod" + if gather_with_grad: + all_audio_features = hvd.allgather(audio_features) + all_text_features = hvd.allgather(text_features) + if mlp_loss: + all_audio_features_mlp = hvd.allgather(audio_features_mlp) + all_text_features_mlp = hvd.allgather(text_features_mlp) + else: + with torch.no_grad(): + all_audio_features = hvd.allgather(audio_features) + all_text_features = hvd.allgather(text_features) + if mlp_loss: + all_audio_features_mlp = hvd.allgather(audio_features_mlp) + all_text_features_mlp = hvd.allgather(text_features_mlp) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_audio_features = list( + all_audio_features.chunk(world_size, dim=0) + ) + gathered_text_features = list( + all_text_features.chunk(world_size, dim=0) + ) + gathered_audio_features[rank] = audio_features + gathered_text_features[rank] = text_features + all_audio_features = torch.cat(gathered_audio_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + if mlp_loss: + gathered_audio_features_mlp = list( + all_audio_features_mlp.chunk(world_size, dim=0) + ) + gathered_text_features_mlp = list( + all_text_features_mlp.chunk(world_size, dim=0) + ) + gathered_audio_features_mlp[rank] = audio_features_mlp + gathered_text_features_mlp[rank] = text_features_mlp + all_audio_features_mlp = torch.cat( + gathered_audio_features_mlp, dim=0 + ) + all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0) + else: + # We gather tensors from all gpus + if gather_with_grad: + all_audio_features = torch.cat( + torch.distributed.nn.all_gather(audio_features), dim=0 + ) + all_text_features = torch.cat( + torch.distributed.nn.all_gather(text_features), dim=0 + ) + if mlp_loss: + all_audio_features_mlp = torch.cat( + torch.distributed.nn.all_gather(audio_features_mlp), dim=0 + ) + all_text_features_mlp = torch.cat( + torch.distributed.nn.all_gather(text_features_mlp), dim=0 + ) + else: + gathered_audio_features = [ + torch.zeros_like(audio_features) for _ in range(world_size) + ] + gathered_text_features = [ + torch.zeros_like(text_features) for _ in range(world_size) + ] + dist.all_gather(gathered_audio_features, audio_features) + dist.all_gather(gathered_text_features, text_features) + if mlp_loss: + gathered_audio_features_mlp = [ + torch.zeros_like(audio_features_mlp) for _ in range(world_size) + ] + gathered_text_features_mlp = [ + torch.zeros_like(text_features_mlp) for _ in range(world_size) + ] + dist.all_gather(gathered_audio_features_mlp, audio_features_mlp) + dist.all_gather(gathered_text_features_mlp, text_features_mlp) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_audio_features[rank] = audio_features + gathered_text_features[rank] = text_features + if mlp_loss: + gathered_audio_features_mlp[rank] = audio_features_mlp + gathered_text_features_mlp[rank] = text_features_mlp + + 
all_audio_features = torch.cat(gathered_audio_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + if mlp_loss: + all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0) + all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0) + if mlp_loss: + return ( + all_audio_features, + all_text_features, + all_audio_features_mlp, + all_text_features_mlp, + ) + else: + return all_audio_features, all_text_features + + +class ClipLoss(nn.Module): + def __init__( + self, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + mlp_loss=False, + weight_loss_kappa=0, + ): + super().__init__() + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.use_horovod = use_horovod + self.mlp_loss = mlp_loss + self.weighted_loss = bool(weight_loss_kappa != 0) + self.weight_loss_kappa = weight_loss_kappa + # cache state + self.prev_num_logits = 0 + self.labels = {} + + def forward( + self, + audio_features, + text_features, + logit_scale_a, + logit_scale_t=None, + audio_features_mlp=None, + text_features_mlp=None, + ): + device = audio_features.device + if self.mlp_loss: + if self.world_size > 1: + ( + all_audio_features, + all_text_features, + all_audio_features_mlp, + all_text_features_mlp, + ) = gather_features( + audio_features=audio_features, + text_features=text_features, + audio_features_mlp=audio_features_mlp, + text_features_mlp=text_features_mlp, + local_loss=self.local_loss, + gather_with_grad=self.gather_with_grad, + rank=self.rank, + world_size=self.world_size, + use_horovod=self.use_horovod, + mlp_loss=self.mlp_loss, + ) + if self.local_loss: + a_logits_per_audio = ( + logit_scale_a * audio_features @ all_text_features_mlp.T + ) + a_logits_per_text = ( + logit_scale_a * text_features_mlp @ all_audio_features.T + ) + t_logits_per_audio = ( + logit_scale_t * audio_features_mlp @ all_text_features.T + ) + t_logits_per_text = ( + logit_scale_t * text_features @ all_audio_features_mlp.T + ) + else: + a_logits_per_audio = ( + logit_scale_a * all_audio_features @ all_text_features_mlp.T + ) + a_logits_per_text = a_logits_per_audio.T + t_logits_per_audio = ( + logit_scale_t * all_audio_features_mlp @ all_text_features.T + ) + t_logits_per_text = t_logits_per_audio.T + else: + a_logits_per_audio = ( + logit_scale_a * audio_features @ text_features_mlp.T + ) + a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T + t_logits_per_audio = ( + logit_scale_t * audio_features_mlp @ text_features.T + ) + t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T + + # calculated ground-truth and cache if enabled + num_logits = a_logits_per_audio.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + + if not self.weighted_loss: + total_loss = ( + F.cross_entropy(a_logits_per_audio, labels) + + F.cross_entropy(a_logits_per_text, labels) + + F.cross_entropy(t_logits_per_audio, labels) + + F.cross_entropy(t_logits_per_text, labels) + ) / 4 + else: + audio_weight = (audio_features @ audio_features.T).detach() + audio_weight = ( + torch.exp( + 
torch.sum(audio_weight, axis=1) + / (self.weight_loss_kappa * len(audio_weight)) + ) + ).detach() + text_weight = (text_features @ text_features.T).detach() + text_weight = ( + torch.exp( + torch.sum(text_weight, axis=1) + / (self.weight_loss_kappa * len(text_features)) + ) + ).detach() + total_loss = ( + F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight) + + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight) + + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight) + + F.cross_entropy(t_logits_per_text, labels, weight=text_weight) + ) / 4 + else: + if self.world_size > 1: + all_audio_features, all_text_features = gather_features( + audio_features=audio_features, + text_features=text_features, + local_loss=self.local_loss, + gather_with_grad=self.gather_with_grad, + rank=self.rank, + world_size=self.world_size, + use_horovod=self.use_horovod, + mlp_loss=self.mlp_loss, + ) + + if self.local_loss: + logits_per_audio = ( + logit_scale_a * audio_features @ all_text_features.T + ) + logits_per_text = ( + logit_scale_a * text_features @ all_audio_features.T + ) + else: + logits_per_audio = ( + logit_scale_a * all_audio_features @ all_text_features.T + ) + logits_per_text = logits_per_audio.T + else: + logits_per_audio = logit_scale_a * audio_features @ text_features.T + logits_per_text = logit_scale_a * text_features @ audio_features.T + + # calculated ground-truth and cache if enabled + num_logits = logits_per_audio.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + if not self.weighted_loss: + total_loss = ( + F.cross_entropy(logits_per_audio, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + else: + audio_weight = (all_audio_features @ all_audio_features.T).detach() + audio_weight = ( + torch.exp( + torch.sum(audio_weight, axis=1) + / (self.weight_loss_kappa * len(all_audio_features)) + ) + ).detach() + text_weight = (all_text_features @ all_text_features.T).detach() + text_weight = ( + torch.exp( + torch.sum(text_weight, axis=1) + / (self.weight_loss_kappa * len(all_text_features)) + ) + ).detach() + total_loss = ( + F.cross_entropy(logits_per_audio, labels, weight=text_weight) + + F.cross_entropy(logits_per_text, labels, weight=audio_weight) + ) / 2 + return total_loss + + +def lp_gather_features(pred, target, world_size=1, use_horovod=False): + if use_horovod: + assert hvd is not None, "Please install horovod" + with torch.no_grad(): + all_preds = hvd.allgather(pred) + all_targets = hvd.allgath(target) + else: + gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)] + gathered_targets = [torch.zeros_like(target) for _ in range(world_size)] + + dist.all_gather(gathered_preds, pred) + dist.all_gather(gathered_targets, target) + all_preds = torch.cat(gathered_preds, dim=0) + all_targets = torch.cat(gathered_targets, dim=0) + + return all_preds, all_targets + + +def get_map(pred, target): + pred = torch.sigmoid(pred).numpy() + target = target.numpy() + return np.mean(average_precision_score(target, pred, average=None)) + + +def get_acc(pred, target): + pred = torch.argmax(pred, 1).numpy() + target = torch.argmax(target, 1).numpy() + return accuracy_score(target, pred) + + +def get_mauc(pred, target): + pred = 
torch.sigmoid(pred).numpy() + target = target.numpy() + return np.mean(roc_auc_score(target, pred, average=None)) + + +class LPMetrics(object): + def __init__(self, metric_names=["map", "acc", "mauc"]): + self.metrics = [] + for name in metric_names: + self.metrics.append(self.get_metric(name)) + self.metric_names = metric_names + + def get_metric(self, name): + if name == "map": + return get_map + elif name == "acc": + return get_acc + elif name == "mauc": + return get_mauc + else: + raise ValueError(f"the metric should be at least one of [map, acc, mauc]") + + def evaluate_mertics(self, pred, target): + metric_dict = {} + for i in range(len(self.metric_names)): + metric_dict[self.metric_names[i]] = self.metrics[i](pred, target) + return metric_dict + + +def calc_celoss(pred, target): + target = torch.argmax(target, 1).long() + return nn.CrossEntropyLoss()(pred, target) + + +class LPLoss(nn.Module): + def __init__(self, loss_name): + super().__init__() + if loss_name == "bce": + self.loss_func = nn.BCEWithLogitsLoss() + elif loss_name == "ce": + self.loss_func = calc_celoss + elif loss_name == "mse": + self.loss_func = nn.MSELoss() + else: + raise ValueError(f"the loss func should be at least one of [bce, ce, mse]") + + def forward(self, pred, target): + loss = self.loss_func(pred, target) + return loss diff --git a/models/CLAP/open_clip/model.py b/models/CLAP/open_clip/model.py new file mode 100755 index 0000000000000000000000000000000000000000..5677da7ec2cebaa44c9328ece4873359f459426a --- /dev/null +++ b/models/CLAP/open_clip/model.py @@ -0,0 +1,935 @@ +""" CLAP Model + +Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +Adapted to the Audio Task. +""" + +from collections import OrderedDict +from dataclasses import dataclass +from email.mime import audio +from typing import Tuple, Union, Callable, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from .timm_model import TimmModel +import logging +from .utils import freeze_batch_norm_2d + +from .pann_model import create_pann_model +from .htsat import create_htsat_model +from transformers import BertModel, RobertaModel, BartModel, RobertaConfig +from transformers.tokenization_utils_base import BatchEncoding + + +class MLPLayers(nn.Module): + def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1): + super(MLPLayers, self).__init__() + self.nonlin = nonlin + self.dropout = dropout + + sequence = [] + for u0, u1 in zip(units[:-1], units[1:]): + sequence.append(nn.Linear(u0, u1)) + sequence.append(self.nonlin) + sequence.append(nn.Dropout(self.dropout)) + sequence = sequence[:-2] + + self.sequential = nn.Sequential(*sequence) + + def forward(self, X): + X = self.sequential(X) + return X + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict( + [ + ("-1", nn.AvgPool2d(stride)), + ( + "0", + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False, + ), + ), + ("1", nn.BatchNorm2d(planes * self.expansion)), + ] + ) + ) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__( + self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None + ): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5 + ) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( + 2, 0, 1 + ) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] + ), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert ( + unlocked_groups == 0 + ), "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + def stem(self, x): + for conv, bn in [ + (self.conv1, self.bn1), + (self.conv2, self.bn2), + (self.conv3, self.bn3), + ]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", act_layer()), 
+ ("c_proj", nn.Linear(d_model * 4, d_model)), + ] + ) + ) + self.ln_2 = LayerNorm(d_model) + + def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock(width, heads, act_layer=act_layer) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + x = r(x, attn_mask=attn_mask) + return x + + +class VisualTransformer(nn.Module): + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + act_layer: Callable = nn.GELU, + ): + super().__init__() + self.image_size = image_size + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter( + scale * torch.randn((image_size // patch_size) ** 2 + 1, width) + ) + self.ln_pre = LayerNorm(width) + + self.text_branch = Transformer(width, layers, heads, act_layer=act_layer) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert ( + unlocked_groups == 0 + ), "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [ + self.class_embedding.to(x.dtype) + + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device + ), + x, + ], + dim=1, + ) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_branch(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +@dataclass +class CLAPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + timm_model_name: str = ( + None # a valid model name overrides layers, width, patch_size + ) + timm_model_pretrained: bool = ( + False # use (imagenet) pretrained weights for named model + ) + timm_pool: str = ( + "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + ) + timm_proj: str = ( + "linear" # linear projection for timm model output ('linear', 'mlp', '') + ) + + +# Audio Config Class +@dataclass +class CLAPAudioCfp: + model_type: str = "PANN" + model_name: str = "Cnn14" + sample_rate: int = 48000 + # Param + audio_length: int = 1024 + window_size: int = 1024 + hop_size: int = 1024 + fmin: int = 50 + fmax: int = 14000 + class_num: int = 527 + mel_bins: int = 64 + 
clip_samples: int = 480000 + + +@dataclass +class CLAPTextCfg: + context_length: int + vocab_size: int + width: int + heads: int + layers: int + model_type: str + + +class CLAP(nn.Module): + def __init__( + self, + embed_dim: int, + audio_cfg: CLAPAudioCfp, + text_cfg: CLAPTextCfg, + quick_gelu: bool = False, + enable_fusion: bool = False, + fusion_type: str = "None", + joint_embed_shape: int = 512, + mlp_act: str = "relu", + ): + super().__init__() + if isinstance(audio_cfg, dict): + audio_cfg = CLAPAudioCfp(**audio_cfg) + if isinstance(text_cfg, dict): + text_cfg = CLAPTextCfg(**text_cfg) + + self.audio_cfg = audio_cfg + self.text_cfg = text_cfg + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + self.joint_embed_shape = joint_embed_shape + self.mlp_act = mlp_act + + self.context_length = text_cfg.context_length + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if mlp_act == "relu": + mlp_act_layer = nn.ReLU() + elif mlp_act == "gelu": + mlp_act_layer = nn.GELU() + else: + raise NotImplementedError + + # audio branch + # audio branch parameters + if audio_cfg.model_type == "PANN": + self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type) + elif audio_cfg.model_type == "HTSAT": + self.audio_branch = create_htsat_model( + audio_cfg, enable_fusion, fusion_type + ) + else: + logging.error(f"Model config for {audio_cfg.model_type} not found") + raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.") + + # text branch + # text branch parameters + if text_cfg.model_type == "transformer": + self.text_branch = Transformer( + width=text_cfg.width, + layers=text_cfg.layers, + heads=text_cfg.heads, + act_layer=act_layer, + ) + self.vocab_size = text_cfg.vocab_size + self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, text_cfg.width) + ) + self.ln_final = LayerNorm(text_cfg.width) + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(text_cfg.width, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + elif text_cfg.model_type == "bert": + self.text_branch = BertModel.from_pretrained("bert-base-uncased") + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + elif text_cfg.model_type == "roberta": + self.text_branch = RobertaModel.from_pretrained("roberta-base") + + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + elif text_cfg.model_type == "bart": + self.text_branch = BartModel.from_pretrained("facebook/bart-base") + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, 
+ self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + else: + logging.error(f"Model config for {text_cfg.model_type} not found") + raise RuntimeError(f"Model config for {text_cfg.model_type} not found.") + self.text_branch_type = text_cfg.model_type + # text branch parameters + + # audio branch parameters + self.audio_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + + # below here is text branch parameters + + # ============================================================================================================ + self.audio_projection = nn.Sequential( + nn.Linear(embed_dim, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + + self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False) + + self.init_text_branch_parameters() + + def init_text_branch_parameters(self): + if self.text_branch_type == "transformer": + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + proj_std = (self.text_branch.width**-0.5) * ( + (2 * self.text_branch.layers) ** -0.5 + ) + attn_std = self.text_branch.width**-0.5 + fc_std = (2 * self.text_branch.width) ** -0.5 + for block in self.text_branch.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + if self.text_branch_type == "bert" or self.text_branch_type == "roberta": + width = self.text_branch.embeddings.word_embeddings.weight.shape[-1] + elif self.text_branch_type == "bart": + width = self.text_branch.shared.weight.shape[-1] + else: + width = self.text_branch.width + nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07)) + nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07)) + + # deprecated + # if hasattr(self.visual, 'init_parameters'): + # self.visual.init_parameters() + + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def encode_audio(self, audio, device): + return self.audio_branch( + audio, mixup_lambda=None, device=device + ) # mix lambda needs to add + + # def list_of_dict_of_tensor2dict_of_tensor(self, x, device): + # tmp = {} + # for k in x[0].keys(): + # tmp[k] = [] + # for i in range(len(x)): + # tmp[k].append(x[i][k][:77]) + # for k in x[0].keys(): + # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True) + # return tmp + + def encode_text(self, text, device): + if self.text_branch_type == "transformer": + text = text.to(device=device, non_blocking=True) + x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_branch(x, 
attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)]) + elif self.text_branch_type == "bert": + # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device) + # text = BatchEncoding(text) + x = self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + token_type_ids=text["token_type_ids"].to( + device=device, non_blocking=True + ), + )["pooler_output"] + x = self.text_projection(x) + elif self.text_branch_type == "roberta": + x = self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + )["pooler_output"] + x = self.text_projection(x) + elif self.text_branch_type == "bart": + x = torch.mean( + self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + )["encoder_last_hidden_state"], + axis=1, + ) + x = self.text_projection(x) + else: + logging.error(f"Model type {self.text_branch_type} not found") + raise RuntimeError(f"Model type {self.text_branch_type} not found.") + return x + + def forward(self, audio, text, device=None): + """Forward audio and text into the CLAP + + Parameters + ---------- + audio: torch.Tensor (batch_size, audio_length) + the time-domain audio input / the batch of mel_spec and longer list. + text: torch.Tensor () // need to add + the text token input + """ + if device is None: + if audio is not None: + device = audio.device + elif text is not None: + device = text.device + if audio is None and text is None: + # a hack to get the logit scale + return self.logit_scale_a.exp(), self.logit_scale_t.exp() + elif audio is None: + return self.encode_text(text, device=device) + elif text is None: + return self.audio_projection( + self.encode_audio(audio, device=device)["embedding"] + ) + audio_features = self.audio_projection( + self.encode_audio(audio, device=device)["embedding"] + ) + audio_features = F.normalize(audio_features, dim=-1) + + text_features = self.encode_text(text, device=device) + # print("text_features", text_features) + # print("text_features.shape", text_features.shape) + # print("text_features.type", type(text_features)) + text_features = F.normalize(text_features, dim=-1) + + audio_features_mlp = self.audio_transform(audio_features) + text_features_mlp = self.text_transform(text_features) + # Four outputs: audio features (basic & MLP), text features (basic & MLP) + return ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + self.logit_scale_a.exp(), + self.logit_scale_t.exp(), + ) + + def get_logit_scale(self): + return self.logit_scale_a.exp(), self.logit_scale_t.exp() + + def get_text_embedding(self, data): + """Get the text embedding from the model + + Parameters + ---------- + data: torch.Tensor + a tensor of text embedding + + Returns + ---------- + text_embed: torch.Tensor + a tensor of text_embeds (N, D) + + """ + device = next(self.parameters()).device + for k in data: + data[k] = data[k].to(device) + text_embeds = self.encode_text(data, device=device) + text_embeds = F.normalize(text_embeds, dim=-1) + + return text_embeds + + 
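+    # A minimal usage sketch (illustrative, not from the upstream module) of how the six
+    # outputs of forward() are typically consumed for audio-text retrieval; the variable
+    # names and batch shapes below are assumptions:
+    #     a_emb, t_emb, a_mlp, t_mlp, scale_a, scale_t = model(audio, text)  # embeddings are already L2-normalized
+    #     logits = scale_a * a_emb @ t_emb.t()   # (n_audio, n_text) similarity matrix
+    #     best_caption = logits.argmax(dim=-1)   # index of the closest caption for each clip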
def get_audio_embedding(self, data): + """Get the audio embedding from the model + + Parameters + ---------- + data: a list of dict + the audio input dict list from 'get_audio_feature' method + + Returns + ---------- + audio_embed: torch.Tensor + a tensor of audio_embeds (N, D) + + """ + device = next(self.parameters()).device + input_dict = {} + keys = data[0].keys() + for k in keys: + input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to( + device + ) + + audio_embeds = self.audio_projection( + self.encode_audio(input_dict, device=device)["embedding"] + ) + audio_embeds = F.normalize(audio_embeds, dim=-1) + + return audio_embeds + + def audio_infer(self, audio, hopsize=None, device=None): + """Forward one audio and produce the audio embedding + + Parameters + ---------- + audio: (audio_length) + the time-domain audio input, notice that it must be only one input + hopsize: int + the overlap hopsize as the sliding window + + Returns + ---------- + output_dict: { + key: [n, (embedding_shape)] if "HTS-AT" + or + key: [(embedding_shape)] if "PANN" + } + the list of key values of the audio branch + + """ + + assert not self.training, "the inference mode must be run at eval stage" + output_dict = {} + # PANN + if self.audio_cfg.model_type == "PANN": + audio_input = audio.unsqueeze(dim=0) + output_dict[key] = self.encode_audio(audio_input, device=device)[ + key + ].squeeze(dim=0) + elif self.audio_cfg.model_type == "HTSAT": + # repeat + audio_len = len(audio) + k = self.audio_cfg.clip_samples // audio_len + if k > 1: + audio = audio.repeat(k) + audio_len = len(audio) + + if hopsize is None: + hopsize = min(hopsize, audio_len) + + if audio_len > self.audio_cfg.clip_samples: + audio_input = [ + audio[pos : pos + self.audio_cfg.clip_samples].clone() + for pos in range( + 0, audio_len - self.audio_cfg.clip_samples, hopsize + ) + ] + audio_input.append(audio[-self.audio_cfg.clip_samples :].clone()) + audio_input = torch.stack(audio_input) + output_dict[key] = self.encode_audio(audio_input, device=device)[key] + else: + audio_input = audio.unsqueeze(dim=0) + output_dict[key] = self.encode_audio(audio_input, device=device)[ + key + ].squeeze(dim=0) + + return output_dict + + +def convert_weights_to_fp16(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [ + *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], + "in_proj_bias", + "bias_k", + "bias_v", + ]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +# Ignore the state dict of the vision part +def build_model_from_openai_state_dict( + state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None" +): + + embed_dim = model_cfg["embed_dim"] + audio_cfg = model_cfg["audio_cfg"] + text_cfg = model_cfg["text_cfg"] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split(".")[2] + for k in state_dict + if 
k.startswith(f"transformer.resblocks") + ) + ) + + audio_cfg = CLAPAudioCfp(**audio_cfg) + text_cfg = CLAPTextCfg(**text_cfg) + + model = CLAP( + embed_dim, + audio_cfg=audio_cfg, + text_cfg=text_cfg, + quick_gelu=True, # OpenAI models were trained with QuickGELU + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + state_dict["logit_scale_a"] = state_dict["logit_scale"] + state_dict["logit_scale_t"] = state_dict["logit_scale"] + pop_keys = list(state_dict.keys())[::] + # pop the visual branch saved weights + for key in pop_keys: + if key.startswith("visual."): + state_dict.pop(key, None) + + for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + # not use fp16 + # convert_weights_to_fp16(model) + model.load_state_dict(state_dict, strict=False) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device("cpu")): + model.eval() + audio_length = model.audio_cfg.audio_length + example_audio = torch.ones((batch_size, audio_length), device=device) + example_text = torch.zeros( + (batch_size, model.context_length), dtype=torch.int, device=device + ) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_audio, example_text), + encode_text=(example_text,), + encode_image=(example_audio,), + ), + ) + model.audio_cfg.audio_length = audio_length # Question: what does this do? + return model diff --git a/models/CLAP/open_clip/model_configs/HTSAT-base.json b/models/CLAP/open_clip/model_configs/HTSAT-base.json new file mode 100755 index 0000000000000000000000000000000000000000..6cef625a89daf4431f1c9f72e10bc9640eef2ba8 --- /dev/null +++ b/models/CLAP/open_clip/model_configs/HTSAT-base.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 1024, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "base" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/HTSAT-large.json b/models/CLAP/open_clip/model_configs/HTSAT-large.json new file mode 100755 index 0000000000000000000000000000000000000000..699cdb1b16855582606551e4196b24aba2ffd871 --- /dev/null +++ b/models/CLAP/open_clip/model_configs/HTSAT-large.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "large" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/HTSAT-tiny-win-1536.json b/models/CLAP/open_clip/model_configs/HTSAT-tiny-win-1536.json new file mode 100755 index 0000000000000000000000000000000000000000..73e42990fe8361a0df502e7f93d29f19f58c9ecb --- /dev/null +++ b/models/CLAP/open_clip/model_configs/HTSAT-tiny-win-1536.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 768, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1536, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "tiny" + }, + "text_cfg": { + 
"context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/HTSAT-tiny.json b/models/CLAP/open_clip/model_configs/HTSAT-tiny.json new file mode 100755 index 0000000000000000000000000000000000000000..a6e7821163d9afa81c27345a1e472475b92af169 --- /dev/null +++ b/models/CLAP/open_clip/model_configs/HTSAT-tiny.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 768, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "tiny" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/PANN-10.json b/models/CLAP/open_clip/model_configs/PANN-10.json new file mode 100755 index 0000000000000000000000000000000000000000..954ddf62921aed7dde9c37ffffec98a2e96a4ee7 --- /dev/null +++ b/models/CLAP/open_clip/model_configs/PANN-10.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 1024, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn10" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/PANN-14-fmax-18k.json b/models/CLAP/open_clip/model_configs/PANN-14-fmax-18k.json new file mode 100755 index 0000000000000000000000000000000000000000..b7989bc0cd95d0d39049b7524eba508b3e386439 --- /dev/null +++ b/models/CLAP/open_clip/model_configs/PANN-14-fmax-18k.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 18000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/PANN-14-fmax-8k-20s.json b/models/CLAP/open_clip/model_configs/PANN-14-fmax-8k-20s.json new file mode 100755 index 0000000000000000000000000000000000000000..56bdb56bedc304ffa52d8bf5988cea2c1d82d14e --- /dev/null +++ b/models/CLAP/open_clip/model_configs/PANN-14-fmax-8k-20s.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 960000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 360, + "fmin": 50, + "fmax": 8000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/PANN-14-tiny-transformer.json b/models/CLAP/open_clip/model_configs/PANN-14-tiny-transformer.json new file mode 100755 index 0000000000000000000000000000000000000000..5756e3bebc97cc985f512cb081930fee4e49bec1 --- /dev/null +++ b/models/CLAP/open_clip/model_configs/PANN-14-tiny-transformer.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + 
"clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 4 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/PANN-14-win-1536.json b/models/CLAP/open_clip/model_configs/PANN-14-win-1536.json new file mode 100755 index 0000000000000000000000000000000000000000..5a9e7e208b661619d5e26625e849da1adda8a475 --- /dev/null +++ b/models/CLAP/open_clip/model_configs/PANN-14-win-1536.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1536, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/PANN-14.json b/models/CLAP/open_clip/model_configs/PANN-14.json new file mode 100755 index 0000000000000000000000000000000000000000..39a5134cde1d8c50f4758377c952ef22f07bab41 --- /dev/null +++ b/models/CLAP/open_clip/model_configs/PANN-14.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/PANN-6.json b/models/CLAP/open_clip/model_configs/PANN-6.json new file mode 100755 index 0000000000000000000000000000000000000000..21ebc344326de260c386ba77e0ad63cf9b04febf --- /dev/null +++ b/models/CLAP/open_clip/model_configs/PANN-6.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 512, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn6" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/RN101-quickgelu.json b/models/CLAP/open_clip/model_configs/RN101-quickgelu.json new file mode 100755 index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624 --- /dev/null +++ b/models/CLAP/open_clip/model_configs/RN101-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/RN101.json b/models/CLAP/open_clip/model_configs/RN101.json new file mode 100755 index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32 --- /dev/null +++ b/models/CLAP/open_clip/model_configs/RN101.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + 
"image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/RN50-quickgelu.json b/models/CLAP/open_clip/model_configs/RN50-quickgelu.json new file mode 100755 index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1 --- /dev/null +++ b/models/CLAP/open_clip/model_configs/RN50-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/models/CLAP/open_clip/model_configs/RN50.json b/models/CLAP/open_clip/model_configs/RN50.json new file mode 100755 index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2 --- /dev/null +++ b/models/CLAP/open_clip/model_configs/RN50.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/RN50x16.json b/models/CLAP/open_clip/model_configs/RN50x16.json new file mode 100755 index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db --- /dev/null +++ b/models/CLAP/open_clip/model_configs/RN50x16.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 384, + "layers": [ + 6, + 8, + 18, + 8 + ], + "width": 96, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/RN50x4.json b/models/CLAP/open_clip/model_configs/RN50x4.json new file mode 100755 index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb --- /dev/null +++ b/models/CLAP/open_clip/model_configs/RN50x4.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 288, + "layers": [ + 4, + 6, + 10, + 6 + ], + "width": 80, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/ViT-B-16.json b/models/CLAP/open_clip/model_configs/ViT-B-16.json new file mode 100755 index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c --- /dev/null +++ b/models/CLAP/open_clip/model_configs/ViT-B-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/ViT-B-32-quickgelu.json b/models/CLAP/open_clip/model_configs/ViT-B-32-quickgelu.json new file mode 100755 index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5 --- /dev/null +++ b/models/CLAP/open_clip/model_configs/ViT-B-32-quickgelu.json @@ -0,0 +1,17 @@ +{ + 
"embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/ViT-B-32.json b/models/CLAP/open_clip/model_configs/ViT-B-32.json new file mode 100755 index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f --- /dev/null +++ b/models/CLAP/open_clip/model_configs/ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/model_configs/ViT-L-14.json b/models/CLAP/open_clip/model_configs/ViT-L-14.json new file mode 100755 index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241 --- /dev/null +++ b/models/CLAP/open_clip/model_configs/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/models/CLAP/open_clip/openai.py b/models/CLAP/open_clip/openai.py new file mode 100755 index 0000000000000000000000000000000000000000..3f4eb8b55fe960e1792b3da804b60b3d8f70fe26 --- /dev/null +++ b/models/CLAP/open_clip/openai.py @@ -0,0 +1,156 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import os +import warnings +from typing import Union, List + +import torch + +from .model import build_model_from_openai_state_dict +from .pretrained import ( + get_pretrained_url, + list_pretrained_tag_models, + download_pretrained, +) + +__all__ = ["list_openai_models", "load_openai_model"] + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_tag_models("openai") + + +def load_openai_model( + name: str, + model_cfg, + device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", + jit=True, + cache_dir=os.path.expanduser("~/.cache/clip"), + enable_fusion: bool = False, + fusion_type: str = "None", +): + """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + device : Union[str, torch.device] + The device to put the loaded model + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
+ + Returns + ------- + model : torch.nn.Module + The CLAP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if get_pretrained_url(name, "openai"): + model_path = download_pretrained( + get_pretrained_url(name, "openai"), root=cache_dir + ) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError( + f"Model {name} not found; available models = {list_openai_models()}" + ) + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn( + f"File {model_path} is not a JIT archive. Loading as a state dict instead" + ) + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + try: + model = build_model_from_openai_state_dict( + state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type + ).to(device) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict( + sd, model_cfg, enable_fusion, fusion_type + ).to(device) + + if str(device) == "cpu": + model.float() + return model + + # patch the device names + device_holder = torch.jit.trace( + lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] + ) + device_node = [ + n + for n in device_holder.graph.findAllNodes("prim::Constant") + if "Device" in repr(n) + ][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith( + "cuda" + ): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_audio) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace( + lambda: torch.ones([]).float(), example_inputs=[] + ) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [ + 1, + 2, + ]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_audio) + patch_float(model.encode_text) + model.float() + + model.audio_branch.audio_length = model.audio_cfg.audio_length + return model diff --git a/models/CLAP/open_clip/pann_model.py b/models/CLAP/open_clip/pann_model.py new file mode 100755 index 0000000000000000000000000000000000000000..0d9a8eb0bf897ad6ec04923361b01e5de433b2ef --- /dev/null +++ b/models/CLAP/open_clip/pann_model.py @@ -0,0 +1,704 @@ +# PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition +# Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn +# Some layers are re-designed for CLAP +import os + +os.environ["NUMBA_CACHE_DIR"] = "/tmp/" + +import torch +import torch.nn as nn 
+import torch.nn.functional as F +from torchlibrosa.stft import Spectrogram, LogmelFilterBank +from torchlibrosa.augmentation import SpecAugmentation + +from .utils import do_mixup, interpolate, pad_framewise_output +from .feature_fusion import iAFF, AFF, DAF + + +def init_layer(layer): + """Initialize a Linear or Convolutional layer.""" + nn.init.xavier_uniform_(layer.weight) + + if hasattr(layer, "bias"): + if layer.bias is not None: + layer.bias.data.fill_(0.0) + + +def init_bn(bn): + """Initialize a Batchnorm layer.""" + bn.bias.data.fill_(0.0) + bn.weight.data.fill_(1.0) + + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + + super(ConvBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ) + + self.conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ) + + self.bn1 = nn.BatchNorm2d(out_channels) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_layer(self.conv2) + init_bn(self.bn1) + init_bn(self.bn2) + + def forward(self, input, pool_size=(2, 2), pool_type="avg"): + + x = input + x = F.relu_(self.bn1(self.conv1(x))) + x = F.relu_(self.bn2(self.conv2(x))) + if pool_type == "max": + x = F.max_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg": + x = F.avg_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg+max": + x1 = F.avg_pool2d(x, kernel_size=pool_size) + x2 = F.max_pool2d(x, kernel_size=pool_size) + x = x1 + x2 + else: + raise Exception("Incorrect argument!") + + return x + + +class ConvBlock5x5(nn.Module): + def __init__(self, in_channels, out_channels): + + super(ConvBlock5x5, self).__init__() + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(5, 5), + stride=(1, 1), + padding=(2, 2), + bias=False, + ) + + self.bn1 = nn.BatchNorm2d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_bn(self.bn1) + + def forward(self, input, pool_size=(2, 2), pool_type="avg"): + + x = input + x = F.relu_(self.bn1(self.conv1(x))) + if pool_type == "max": + x = F.max_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg": + x = F.avg_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg+max": + x1 = F.avg_pool2d(x, kernel_size=pool_size) + x2 = F.max_pool2d(x, kernel_size=pool_size) + x = x1 + x2 + else: + raise Exception("Incorrect argument!") + + return x + + +class AttBlock(nn.Module): + def __init__(self, n_in, n_out, activation="linear", temperature=1.0): + super(AttBlock, self).__init__() + + self.activation = activation + self.temperature = temperature + self.att = nn.Conv1d( + in_channels=n_in, + out_channels=n_out, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + self.cla = nn.Conv1d( + in_channels=n_in, + out_channels=n_out, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + + self.bn_att = nn.BatchNorm1d(n_out) + self.init_weights() + + def init_weights(self): + init_layer(self.att) + init_layer(self.cla) + init_bn(self.bn_att) + + def forward(self, x): + # x: (n_samples, n_in, n_time) + norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) + cla = self.nonlinear_transform(self.cla(x)) + x = torch.sum(norm_att * cla, dim=2) + return x, norm_att, cla + + def nonlinear_transform(self, x): + if self.activation == 
"linear": + return x + elif self.activation == "sigmoid": + return torch.sigmoid(x) + + +class Cnn14(nn.Module): + def __init__( + self, + sample_rate, + window_size, + hop_size, + mel_bins, + fmin, + fmax, + classes_num, + enable_fusion=False, + fusion_type="None", + ): + + super(Cnn14, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + if (self.enable_fusion) and (self.fusion_type == "channel_map"): + self.conv_block1 = ConvBlock(in_channels=4, out_channels=64) + else: + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"] + ): + self.mel_conv1d = nn.Sequential( + nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2), + nn.BatchNorm1d(64), # No Relu + ) + if self.fusion_type == "daf_1d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_1d": + self.fusion_model = AFF(channels=64, type="1D") + elif self.fusion_type == "iaff_1d": + self.fusion_model = iAFF(channels=64, type="1D") + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + self.mel_conv2d = nn.Sequential( + nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + ) + + if self.fusion_type == "daf_2d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_2d": + self.fusion_model = AFF(channels=64, type="2D") + elif self.fusion_type == "iaff_2d": + self.fusion_model = iAFF(channels=64, type="2D") + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + if self.enable_fusion and input["longer"].sum() == 0: + # if no audio is longer than 10s, then randomly select one audio to be longer + input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True + + if not self.enable_fusion: + x = self.spectrogram_extractor( + input["waveform"].to(device=device, non_blocking=True) + ) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + else: + longer_list = input["longer"].to(device=device, non_blocking=True) + x = 
input["mel_fusion"].to(device=device, non_blocking=True) + longer_list_idx = torch.where(longer_list)[0] + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]: + new_x = x[:, 0:1, :, :].clone().contiguous() + # local processing + if len(longer_list_idx) > 0: + fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous() + FB, FC, FT, FF = fusion_x_local.size() + fusion_x_local = fusion_x_local.view(FB * FC, FT, FF) + fusion_x_local = torch.permute( + fusion_x_local, (0, 2, 1) + ).contiguous() + fusion_x_local = self.mel_conv1d(fusion_x_local) + fusion_x_local = fusion_x_local.view( + FB, FC, FF, fusion_x_local.size(-1) + ) + fusion_x_local = ( + torch.permute(fusion_x_local, (0, 2, 1, 3)) + .contiguous() + .flatten(2) + ) + if fusion_x_local.size(-1) < FT: + fusion_x_local = torch.cat( + [ + fusion_x_local, + torch.zeros( + (FB, FF, FT - fusion_x_local.size(-1)), + device=device, + ), + ], + dim=-1, + ) + else: + fusion_x_local = fusion_x_local[:, :, :FT] + # 1D fusion + new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous() + new_x[longer_list_idx] = self.fusion_model( + new_x[longer_list_idx], fusion_x_local + ) + x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :] + else: + x = new_x + elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]: + x = x # no change + + if self.training: + x = self.spec_augmenter(x) + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + global_x = x[:, 0:1, :, :] + + # global processing + B, C, H, W = global_x.shape + global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg") + if len(longer_list_idx) > 0: + local_x = x[longer_list_idx, 1:, :, :].contiguous() + TH = global_x.size(-2) + # local processing + B, C, H, W = local_x.shape + local_x = local_x.view(B * C, 1, H, W) + local_x = self.mel_conv2d(local_x) + local_x = local_x.view( + B, C, local_x.size(1), local_x.size(2), local_x.size(3) + ) + local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3) + TB, TC, _, TW = local_x.size() + if local_x.size(-2) < TH: + local_x = torch.cat( + [ + local_x, + torch.zeros( + (TB, TC, TH - local_x.size(-2), TW), + device=global_x.device, + ), + ], + dim=-2, + ) + else: + local_x = local_x[:, :, :TH, :] + + global_x[longer_list_idx] = self.fusion_model( + global_x[longer_list_idx], local_x + ) + x = global_x + else: + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = latent_x.transpose(1, 2) + latent_x = F.relu_(self.fc1(latent_x)) + latent_output = interpolate(latent_x, 32) + + (x1, _) = 
torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = { + "clipwise_output": clipwise_output, + "embedding": embedding, + "fine_grained_embedding": latent_output, + } + return output_dict + + +class Cnn6(nn.Module): + def __init__( + self, + sample_rate, + window_size, + hop_size, + mel_bins, + fmin, + fmax, + classes_num, + enable_fusion=False, + fusion_type="None", + ): + + super(Cnn6, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512) + + self.fc1 = nn.Linear(512, 512, bias=True) + self.fc_audioset = nn.Linear(512, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = latent_x.transpose(1, 2) + latent_x = F.relu_(self.fc1(latent_x)) + latent_output = interpolate(latent_x, 16) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = { + "clipwise_output": clipwise_output, + "embedding": embedding, + 
"fine_grained_embedding": latent_output, + } + + return output_dict + + +class Cnn10(nn.Module): + def __init__( + self, + sample_rate, + window_size, + hop_size, + mel_bins, + fmin, + fmax, + classes_num, + enable_fusion=False, + fusion_type="None", + ): + + super(Cnn10, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + + self.fc1 = nn.Linear(1024, 1024, bias=True) + self.fc_audioset = nn.Linear(1024, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = latent_x.transpose(1, 2) + latent_x = F.relu_(self.fc1(latent_x)) + latent_output = interpolate(latent_x, 32) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = { + "clipwise_output": clipwise_output, + "embedding": embedding, + "fine_grained_embedding": latent_output, + } + + return output_dict + + +def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"): + try: + ModelProto 
= eval(audio_cfg.model_name) + model = ModelProto( + sample_rate=audio_cfg.sample_rate, + window_size=audio_cfg.window_size, + hop_size=audio_cfg.hop_size, + mel_bins=audio_cfg.mel_bins, + fmin=audio_cfg.fmin, + fmax=audio_cfg.fmax, + classes_num=audio_cfg.class_num, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + return model + except: + raise RuntimeError( + f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough." + ) diff --git a/models/CLAP/open_clip/pretrained.py b/models/CLAP/open_clip/pretrained.py new file mode 100755 index 0000000000000000000000000000000000000000..e211d8b5b59320a599e62605f1dee6199f317253 --- /dev/null +++ b/models/CLAP/open_clip/pretrained.py @@ -0,0 +1,167 @@ +import hashlib +import os +import urllib +import warnings + +from tqdm import tqdm + +_RN50 = dict( + openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", + cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", +) + +_RN50_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", + cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", +) + +_RN101 = dict( + openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", +) + +_RN101_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", +) + +_RN50x4 = dict( + openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", +) + +_RN50x16 = dict( + openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", +) + +_RN50x64 = dict( + openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", +) + +_VITB32 = dict( + openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", +) + +_VITB32_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + 
laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", +) + +_VITB16 = dict( + openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", +) + +_VITL14 = dict( + openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", +) + +_PRETRAINED = { + "RN50": _RN50, + "RN50-quickgelu": _RN50_quickgelu, + "RN101": _RN101, + "RN101-quickgelu": _RN101_quickgelu, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "ViT-B-32": _VITB32, + "ViT-B-32-quickgelu": _VITB32_quickgelu, + "ViT-B-16": _VITB16, + "ViT-L-14": _VITL14, +} + + +def list_pretrained(as_str: bool = False): + """returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [ + ":".join([k, t]) if as_str else (k, t) + for k in _PRETRAINED.keys() + for t in _PRETRAINED[k].keys() + ] + + +def list_pretrained_tag_models(tag: str): + """return all models having the specified pretrain tag""" + models = [] + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_model_tags(model: str): + """return all pretrain tags for the specified model architecture""" + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def get_pretrained_url(model: str, tag: str): + if model not in _PRETRAINED: + return "" + model_pretrained = _PRETRAINED[model] + if tag not in model_pretrained: + return "" + return model_pretrained[tag] + + +def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + if "openaipublic" in url: + expected_sha256 = url.split("/")[-2] + else: + expected_sha256 = "" + + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if ( + hashlib.sha256(open(download_target, "rb").read()).hexdigest() + == expected_sha256 + ): + return download_target + else: + warnings.warn( + f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" + ) + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm( + total=int(source.info().get("Content-Length")), + ncols=80, + unit="iB", + unit_scale=True, + ) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if ( + expected_sha256 + and hashlib.sha256(open(download_target, "rb").read()).hexdigest() + != expected_sha256 + ): + raise RuntimeError( + f"Model has been downloaded but the SHA256 checksum does not not match" + ) + + return download_target diff --git a/models/CLAP/open_clip/timm_model.py b/models/CLAP/open_clip/timm_model.py new file mode 100755 index 0000000000000000000000000000000000000000..c9d1ab4666b5bab5038d44b90c9ddca5087de460 --- /dev/null +++ b/models/CLAP/open_clip/timm_model.py @@ -0,0 +1,112 @@ +""" timm model adapter + +Wraps timm 
(https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. +""" +from collections import OrderedDict + +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import ( + AttentionPool2d as AbsAttentionPool2d, + ) +except ImportError as e: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """timm model adapter + # FIXME this adapter is a work in progress, may change in ways that break weight compat + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool="avg", + proj="linear", + drop=0.0, + pretrained=False, + ): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + + self.image_size = to_2tuple(image_size) + self.trunk = timm.create_model(model_name, pretrained=pretrained) + feat_size = self.trunk.default_cfg.get("pool_size", None) + feature_ndim = 1 if not feat_size else 2 + if pool in ("abs_attn", "rot_attn"): + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool="") + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + if pool == "abs_attn": + head_layers["pool"] = AbsAttentionPool2d( + prev_chs, feat_size=feat_size, out_features=embed_dim + ) + prev_chs = embed_dim + elif pool == "rot_attn": + head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + else: + assert proj, "projection layer needed if non-attention pooling is used." 
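+
+        # Hedged sketch (not part of the original adapter): with the defaults pool="avg" and
+        # proj="linear", the trunk keeps timm's global average pooling and the head built below
+        # reduces to Dropout(drop) -> Linear(trunk.num_features, embed_dim). For example,
+        # assuming timm is installed and "resnet50" is an available model name:
+        #     tower = TimmModel("resnet50", embed_dim=1024)   # hypothetical arguments
+        #     feats = tower(images)                           # images: (B, 3, 224, 224) -> feats: (B, 1024)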
+ + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == "linear": + head_layers["drop"] = nn.Dropout(drop) + head_layers["proj"] = nn.Linear(prev_chs, embed_dim) + elif proj == "mlp": + head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`" + ) + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/models/CLAP/open_clip/tokenizer.py b/models/CLAP/open_clip/tokenizer.py new file mode 100755 index 0000000000000000000000000000000000000000..ee4d28450ec5dd12a79daf38cf3088e9e73c2cd5 --- /dev/null +++ b/models/CLAP/open_clip/tokenizer.py @@ -0,0 +1,197 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + + +@lru_cache() +def default_bpe(): + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" + ) + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). 
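+    Illustrative example (added for clarity; not in the original docstring):
+    get_pairs(("h", "e", "ll", "o")) returns {("h", "e"), ("e", "ll"), ("ll", "o")}.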
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + if not special_tokens: + special_tokens = ["", ""] + else: + special_tokens = ["", ""] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t: t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile( + special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text + + +_tokenizer = SimpleTokenizer() + + +def tokenize( + texts: Union[str, List[str]], context_length: int = 77 +) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + 
eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + result[i, : len(tokens)] = torch.tensor(tokens) + + return result diff --git a/models/CLAP/open_clip/transform.py b/models/CLAP/open_clip/transform.py new file mode 100755 index 0000000000000000000000000000000000000000..77aaa722c4a5544ac50de6df35d3e922f63b111d --- /dev/null +++ b/models/CLAP/open_clip/transform.py @@ -0,0 +1,45 @@ +from torchvision.transforms import ( + Normalize, + Compose, + RandomResizedCrop, + InterpolationMode, + ToTensor, + Resize, + CenterCrop, +) + + +def _convert_to_rgb(image): + return image.convert("RGB") + + +def image_transform( + image_size: int, + is_train: bool, + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), +): + normalize = Normalize(mean=mean, std=std) + if is_train: + return Compose( + [ + RandomResizedCrop( + image_size, + scale=(0.9, 1.0), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + ToTensor(), + normalize, + ] + ) + else: + return Compose( + [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + _convert_to_rgb, + ToTensor(), + normalize, + ] + ) diff --git a/models/CLAP/open_clip/utils.py b/models/CLAP/open_clip/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..8d6a6b7ea29d9edfc0a69debbfcd11cc88c98a28 --- /dev/null +++ b/models/CLAP/open_clip/utils.py @@ -0,0 +1,361 @@ +import numpy as np +import torch +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d +import logging +import h5py +from tqdm import tqdm +import random +import json +import os +import pathlib + +# TODO: (yusong) this not a good place to store those information and does not scale. Need to be fixed later. +dataset_split = { + "audiocaps": ["train", "valid", "test"], + "audioset": ["balanced_train", "unbalanced_train", "eval"], + "BBCSoundEffects": ["train", "test"], + "Clotho": ["train", "test", "valid"], + "free_to_use_sounds": ["train", "test"], + "paramount_motion": ["train", "test"], + "sonniss_game_effects": ["train", "test"], + "wesoundeffects": ["train", "test"], + "MACS": ["train", "test"], + "freesound": ["train", "test"], + "FSD50K": ["train", "test", "valid"], + "fsd50k_class_label": ["train", "test", "valid"], + "esc50": ["train", "test"], + "audiostock": ["train", "test"], + "freesound_no_overlap_noesc50": ["train", "test"], + "epidemic_sound_effects": ["train", "test"], + "VGGSound": ["train", "test"], + "urbansound8k_class_label": ["train", "test"], + "audioset_t5": ["balanced_train", "unbalanced_train", "eval"], + "epidemic_sound_effects_t5": ["train", "test"], + "WavText5K": ["train", "test"], + "esc50_no_overlap": ["train", "test"], + "usd8k_no_overlap": ["train", "test"], + "fsd50k_200_class_label": ["train", "test", "valid"], +} + + +def freeze_batch_norm_2d(module, module_match={}, name=""): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. 
+ module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance( + module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm) + ): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = ".".join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +def exist(dataset_name, dataset_type): + """ + Check if dataset exists + """ + if dataset_type in dataset_split[dataset_name]: + return True + else: + return False + + +def get_tar_path_from_dataset_name( + dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None +): + """ + Get tar path from dataset name and type + """ + output = [] + for n in dataset_names: + if full_dataset is not None and n in full_dataset: + current_dataset_types = dataset_split[n] + else: + current_dataset_types = dataset_types + for s in current_dataset_types: + tmp = [] + if islocal: + sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json" + if not os.path.exists(sizefilepath_): + sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" + else: + sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" + if not os.path.exists(sizefilepath_): + continue + sizes = json.load(open(sizefilepath_, "r")) + for k in sizes.keys(): + if islocal: + tmp.append(f"{dataset_path}/{n}/{s}/{k}") + else: + tmp.append( + f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -" + ) + if proportion != 1: + tmp = random.sample(tmp, int(proportion * len(tmp))) + output.append(tmp) + return sum(output, []) + + +def get_tar_path_from_txts(txt_path, islocal, proportion=1): + """ + Get tar path from txt path + """ + if isinstance(txt_path, (list, tuple)): + return sum( + [ + get_tar_path_from_txts( + txt_path[i], islocal=islocal, proportion=proportion + ) + for i in range(len(txt_path)) + ], + [], + ) + if isinstance(txt_path, str): + with open(txt_path) as f: + lines = f.readlines() + if islocal: + lines = [ + lines[i] + .split("\n")[0] + .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/") + for i in range(len(lines)) + ] + else: + lines = [ + lines[i].split("\n")[0].replace(".tar", ".tar -") + for i in range(len(lines)) + ] + if proportion != 1: + print("Sampling tars with proportion of {}".format(proportion)) + lines = random.sample(lines, int(proportion * len(lines))) + return lines + + +def get_mix_lambda(mixup_alpha, batch_size): + mixup_lambdas = [ + np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size) + ] + return np.array(mixup_lambdas).astype(np.float32) + + +def do_mixup(x, mixup_lambda): + """ + Args: + x: (batch_size , ...) + mixup_lambda: (batch_size,) + Returns: + out: (batch_size, ...) 
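+    Illustrative example (added; not part of the original docstring): with
+    batch_size=2 and mixup_lambda=[0.75, 0.5], the flipped batch is [x[1], x[0]],
+    so out[0] = 0.75 * x[0] + 0.25 * x[1] and out[1] = 0.5 * x[1] + 0.5 * x[0].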
+ """ + out = ( + x.transpose(0, -1) * mixup_lambda + + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda) + ).transpose(0, -1) + return out + + +def interpolate(x, ratio): + """Interpolate data in time domain. This is used to compensate the + resolution reduction in downsampling of a CNN. + + Args: + x: (batch_size, time_steps, classes_num) + ratio: int, ratio to interpolate + Returns: + upsampled: (batch_size, time_steps * ratio, classes_num) + """ + (batch_size, time_steps, classes_num) = x.shape + upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) + upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) + return upsampled + + +def pad_framewise_output(framewise_output, frames_num): + """Pad framewise_output to the same length as input frames. The pad value + is the same as the value of the last frame. + Args: + framewise_output: (batch_size, frames_num, classes_num) + frames_num: int, number of frames to pad + Outputs: + output: (batch_size, frames_num, classes_num) + """ + pad = framewise_output[:, -1:, :].repeat( + 1, frames_num - framewise_output.shape[1], 1 + ) + """tensor for padding""" + + output = torch.cat((framewise_output, pad), dim=1) + """(batch_size, frames_num, classes_num)""" + + +def process_ipc(index_path, classes_num, filename): + # load data + logging.info("Load Data...............") + ipc = [[] for _ in range(classes_num)] + with h5py.File(index_path, "r") as f: + for i in tqdm(range(len(f["target"]))): + t_class = np.where(f["target"][i])[0] + for t in t_class: + ipc[t].append(i) + print(ipc) + np.save(filename, ipc) + logging.info("Load Data Succeed...............") + + +def save_to_dict(s, o_={}): + sp = s.split(": ") + o_.update({sp[0]: float(sp[1])}) + return o_ + + +def get_data_from_log(txt_path): + """ + Output dictionary from out.txt log file + """ + with open(txt_path) as f: + lines = f.readlines() + val_data = {} + train_data = {} + train_losses = [] + train_losses_epoch = [] + for i in range(len(lines)): + if "| INFO |" in lines[i]: + if "Eval Epoch" in lines[i]: + if "val_loss" in lines[i]: + # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", "")) + line = lines[i].split("Eval Epoch: ")[-1] + num_epoch = int(line.split(" ")[0].split(" ")[0]) + d = { + line.split(" ")[0] + .split(" ")[1] + .replace(":", ""): float(line.split(" ")[0].split(" ")[-1]) + } + for i in range(1, len(line.split(" "))): + d = save_to_dict(line.split(" ")[i], d) + val_data[num_epoch] = d + elif "Train Epoch" in lines[i]: + num_epoch = int(lines[i].split("Train Epoch: ")[1][0]) + loss = float(lines[i].split("Loss: ")[-1].split(" (")[0]) + train_losses.append(loss) + train_losses_epoch.append(num_epoch) + for i in range(len(train_losses)): + train_data[i] = { + "num_epoch": train_losses_epoch[i], + "train_loss": train_losses[i], + } + return train_data, val_data + + +def save_p(obj, filename): + import pickle + + try: + from deepdiff import DeepDiff + except: + os.system("pip install deepdiff") + from deepdiff import DeepDiff + with open(filename, "wb") as file: + pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol + with open(filename, "rb") as file: + z = pickle.load(file) + assert ( + DeepDiff(obj, z, ignore_string_case=True) == {} + ), "there is something wrong with the saving process" + return + + +def load_p(filename): + import pickle + + with open(filename, "rb") as file: + z = pickle.load(file) + return z + + +def save_json(data, name="data.json"): + import json + + with open(name, "w") as fp: + 
json.dump(data, fp) + return + + +def load_json(name): + import json + + with open(name, "r") as fp: + data = json.load(fp) + return data + + +from multiprocessing import Process, Manager +from multiprocessing import Process, Value, Array +from ctypes import c_wchar + + +def load_class_label(path): + # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing + # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array + out = None + if path is not None: + if pathlib.Path(path).suffix in [".pkl", ".pickle"]: + out = load_p(path) + elif pathlib.Path(path).suffix in [".json", ".txt"]: + out = load_json(path) + elif pathlib.Path(path).suffix in [".npy", ".npz"]: + out = np.load(path) + elif pathlib.Path(path).suffix in [".csv"]: + import pandas as pd + + out = pd.read_csv(path) + return out + # if out is None: + # return None + # else: + # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False) + # val = Array('i', out.values(), lock=False) + # return (key, val) + + +from torch import optim + + +def get_optimizer(params, lr, betas, eps, momentum, optimizer_name): + if optimizer_name.lower() == "adamw": + optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps) + elif optimizer_name.lower() == "sgd": + optimizer = optim.SGD(params, lr=lr, momentum=momentum) + elif optimizer_name.lower() == "adam": + optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps) + else: + raise ValueError("optimizer name is not correct") + return optimizer diff --git a/models/CLAP/open_clip/version.py b/models/CLAP/open_clip/version.py new file mode 100755 index 0000000000000000000000000000000000000000..3ced3581bb601ae91b1e1da4b8f4f520855a065e --- /dev/null +++ b/models/CLAP/open_clip/version.py @@ -0,0 +1 @@ +__version__ = "0.2.1" diff --git a/models/CLAP/training/__init__.py b/models/CLAP/training/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/CLAP/training/data.py b/models/CLAP/training/data.py new file mode 100755 index 0000000000000000000000000000000000000000..c1f1b50166afcaa698690860f6d1b51b6f267b13 --- /dev/null +++ b/models/CLAP/training/data.py @@ -0,0 +1,975 @@ +import ast +import json +import logging +import math +import os +import random +import h5py +from dataclasses import dataclass +from models.CLAP.training.params import parse_args +import braceexpand +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.datasets as datasets +import torchvision.transforms +import webdataset as wds +from PIL import Image +from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler +from torch.utils.data.distributed import DistributedSampler +from functools import partial +import soundfile as sf +import io +from pathlib import Path +import wget + +from models.CLAP.open_clip.utils import get_tar_path_from_dataset_name, dataset_split +from models.CLAP.open_clip.utils import load_p, load_class_label +import tempfile +import copy + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +try: + import torchaudio +except ImportError: + torchaudio = None + +from models.CLAP.open_clip import tokenize + + +def tokenizer(text): + return tokenize(text).squeeze(0) + + +from transformers import RobertaTokenizer + +tokenize = RobertaTokenizer.from_pretrained("roberta-base") + + +def tokenizer(text): + result = 
tokenize( + text, + padding="max_length", + truncation=True, + max_length=77, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} + + +# initizlied the audioset map +_AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy") +_AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True) + + +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def float32_to_int16(x): + x = np.clip(x, a_min=-1.0, a_max=1.0) + return (x * 32767.0).astype(np.int16) + + +# For Toy Dataset +class ToyDataset(Dataset): + def __init__(self, index_path, ipc, config, eval_mode=False): + """Toy Dataset for testing the audioset input with text labels + Parameters + ---------- + index_path: str + the link to the h5 file of each audio + idc: str + the link to the npy file, the number of samples in each class + config: dict + the audio cfg file + eval_model (bool): to indicate if the dataset is a testing dataset + """ + self.audio_cfg = config["audio_cfg"] + self.text_cfg = config["text_cfg"] + self.fp = h5py.File(index_path, "r") + self.ipc = np.load(ipc, allow_pickle=True) + self.total_size = len(self.fp["audio_name"]) + self.classes_num = self.audio_cfg["class_num"] + self.eval_mode = eval_mode + + if not eval_mode: + self.generate_queue() + else: + self.queue = [] + for i in range(self.total_size): + target = self.fp["target"][i] + if np.sum(target) > 0: + self.queue.append(i) + self.total_size = len(self.queue) + logging.info("total dataset size: %d" % (self.total_size)) + logging.info("class num: %d" % (self.classes_num)) + + def time_shifting(self, x): + frame_num = len(x) + shift_len = random.randint(0, frame_num - 1) + new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0) + return new_sample + + def generate_queue(self): + self.queue = [] + while len(self.queue) < self.total_size: + class_set = [*range(self.classes_num)] + random.shuffle(class_set) + self.queue += [ + self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set + ] + self.queue = self.queue[: self.total_size] + + logging.info("queue regenerated:%s" % (self.queue[-5:])) + + def crop_wav(self, x): + crop_size = self.audio_cfg["crop_size"] + crop_pos = random.randint(0, len(x) - crop_size - 1) + return x[crop_pos : crop_pos + crop_size] + + def prompt_text(self, target): + events = _AUDIOSET_MAP[np.where(target > 0)] + event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1] + text = tokenize(event_text)[0] + return text + + def __getitem__(self, index): + """Load waveform, text, and target of an audio clip + + Parameters + ---------- + index: int + the index number + Return + ------ + output: dict { + "hdf5_path": str, + "index_in_hdf5": int, + "audio_name": str, + "waveform": list (audio_length,), + "target": list (class_num, ), + "text": torch.tensor (context_length,) + } + the output dictionary + """ + s_index = self.queue[index] + + audio_name = self.fp["audio_name"][s_index].decode() + # Hardcode here CHANGE + hdf5_path = ( + self.fp["hdf5_path"][s_index] + .decode() + .replace( + "../workspace", + "/home/la/kechen/Research/ke_zsasp/workspace", + ) + ) + r_idx = self.fp["index_in_hdf5"][s_index] + target = self.fp["target"][s_index].astype(np.float32) + text = self.prompt_text(target) + with h5py.File(hdf5_path, "r") as f: + waveform = int16_to_float32(f["waveform"][r_idx])[ + : self.audio_cfg["clip_samples"] + ] + assert ( + len(waveform) == self.audio_cfg["clip_samples"] + ), "The sample length is not match" + # Time shift + 
# if (self.config.enable_time_shift) and (not self.eval_mode): + # waveform = self.time_shifting(waveform) + # # Label Enhance + # if (self.config.crop_size is not None) and (not self.eval_mode): + # waveform = self.crop_wav(waveform) + # # the label enhance rate is fixed 0.5 + # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5: + # kidx = np.where(target)[0] + # for k in kidx: + # for add_key in self.class_map[k][1]: + # target[add_key] = 1.0 + # if len(self.class_map[k][2]) > 0: + # add_key = random.choice(self.class_map[k][2]) + # target[add_key] = 1.0 + + # missing the text input + mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :] + mel_spec = ( + torch.cat( + [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0 + ) + .cpu() + .numpy() + ) + longer = random.choice([True, False]) + if longer == False: + mel_spec[1:, :, :] = 0.0 + data_dict = { + "hdf5_path": hdf5_path, + "index_in_hdf5": r_idx, + "audio_name": audio_name, + "waveform": waveform, + "class_label": target, + "text": text, + "longer": longer, + "mel_fusion": mel_spec, + } + return data_dict + + def __len__(self): + return self.total_size + + +class CsvDataset(Dataset): + def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"): + logging.debug(f"Loading csv data from {input_filename}.") + df = pd.read_csv(input_filename, sep=sep) + + self.images = df[img_key].tolist() + self.captions = df[caption_key].tolist() + self.transforms = transforms + logging.debug("Done loading data.") + + def __len__(self): + return len(self.captions) + + def __getitem__(self, idx): + images = self.transforms(Image.open(str(self.images[idx]))) + texts = tokenize([str(self.captions[idx])])[0] + return images, texts + + +@dataclass +class DataInfo: + dataloader: DataLoader + sampler: DistributedSampler + + +def preprocess_txt(text): + return tokenize([str(text)])[0] + + +def get_dataset_size(shards, sizefilepath_=None, is_local=True): + if isinstance(shards, list): + size_list = [] + for s in shards: + size_list.append( + get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0] + ) + else: + if not is_local: + for n in dataset_split.keys(): + if n in shards.split("/"): + break + for s in dataset_split[n]: + if s in shards.split("/"): + break + sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" + shards_list = list(braceexpand.braceexpand(shards)) + dir_path = os.path.dirname(shards) + if sizefilepath_ is not None: + sizes = json.load(open(sizefilepath_, "r")) + total_size = sum( + [ + int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))]) + for shard in shards_list + ] + ) + else: + sizes_filename = os.path.join(dir_path, "sizes.json") + len_filename = os.path.join(dir_path, "__len__") + if os.path.exists(sizes_filename): + sizes = json.load(open(sizes_filename, "r")) + total_size = sum( + [int(sizes[os.path.basename(shard)]) for shard in shards_list] + ) + elif os.path.exists(len_filename): + # FIXME this used to be eval(open(...)) but that seemed rather unsafe + total_size = ast.literal_eval(open(len_filename, "r").read()) + else: + raise Exception( + "Cannot find sizes file for dataset. Please specify the path to the file." 
+ ) + # total_size = None # num samples undefined + # some common dataset sizes (at time of authors last download) + # cc3m-train: 2905954 + # cc12m: 10968539 + # LAION-400m: 407332084 + num_shards = len(shards_list) + if isinstance(shards, list): + return sum(size_list), len(shards) + else: + return total_size, num_shards + + +def get_imagenet(args, preprocess_fns, split): + assert split in ["train", "val", "v2"] + is_train = split == "train" + preprocess_train, preprocess_val = preprocess_fns + + if split == "v2": + from imagenetv2_pytorch import ImageNetV2Dataset + + dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val) + else: + if is_train: + data_path = args.imagenet_train + preprocess_fn = preprocess_train + else: + data_path = args.imagenet_val + preprocess_fn = preprocess_val + assert data_path + + dataset = datasets.ImageFolder(data_path, transform=preprocess_fn) + + if is_train: + idxs = np.zeros(len(dataset.targets)) + target_array = np.array(dataset.targets) + k = 50 + for c in range(1000): + m = target_array == c + n = len(idxs[m]) + arr = np.zeros(n) + arr[:k] = 1 + np.random.shuffle(arr) + idxs[m] = arr + + idxs = idxs.astype("int") + sampler = SubsetRandomSampler(np.where(idxs)[0]) + else: + sampler = None + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=args.workers, + sampler=sampler, + ) + + return DataInfo(dataloader, sampler) + + +def count_samples(dataloader): + os.environ["WDS_EPOCH"] = "0" + n_elements, n_batches = 0, 0 + for images, texts in dataloader: + n_batches += 1 + n_elements += len(images) + assert len(images) == len(texts) + return n_elements, n_batches + + +def filter_no_caption(sample): + return "txt" in sample + + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" + logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.") + return True + + +_SHARD_SHUFFLE_SIZE = 2000 +_SHARD_SHUFFLE_INITIAL = 500 +_SAMPLE_SHUFFLE_SIZE = 5000 +_SAMPLE_SHUFFLE_INITIAL = 1000 + + +def sample_prop(sizefile, inputs, proportion, is_local=True): + """ + Sample a proportion of the data. 
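+    Returns (documented here for clarity, inferred from the code below): a tuple of
+    (total sample count of the sampled shards, number of sampled shards,
+    full paths of the sampled shards, dict mapping shard basename to its size).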
+ """ + file_path_dict = { + os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0] + for i in range(len(inputs)) + } + sampled_filepath_dict = {} + sampled_size_dict = {} + if not is_local: + if os.path.exists("sizes.json"): + os.remove("sizes.json") + wget.download(sizefile, "sizes.json") + sizefile = "sizes.json" + with open(sizefile, "r", encoding="UTF-8") as f: + load_dict = json.load(f) + L = int(len(file_path_dict) * proportion) + subkeys = random.sample(file_path_dict.keys(), L) + for k in subkeys: + sampled_size_dict[k] = load_dict[k] + sampled_filepath_dict[k] = file_path_dict[k] + return ( + sum(sampled_size_dict.values()), + L, + [os.path.join(v, k) for k, v in sampled_filepath_dict.items()], + sampled_size_dict, + ) + + +def get_mel(audio_data, audio_cfg): + # mel shape: (n_mels, T) + mel = torchaudio.transforms.MelSpectrogram( + sample_rate=audio_cfg["sample_rate"], + n_fft=audio_cfg["window_size"], + win_length=audio_cfg["window_size"], + hop_length=audio_cfg["hop_size"], + center=True, + pad_mode="reflect", + power=2.0, + norm=None, + onesided=True, + n_mels=64, + f_min=audio_cfg["fmin"], + f_max=audio_cfg["fmax"], + ).to(audio_data.device) + mel = mel(audio_data) + # Align to librosa: + # librosa_melspec = librosa.feature.melspectrogram( + # waveform, + # sr=audio_cfg['sample_rate'], + # n_fft=audio_cfg['window_size'], + # hop_length=audio_cfg['hop_size'], + # win_length=audio_cfg['window_size'], + # center=True, + # pad_mode="reflect", + # power=2.0, + # n_mels=64, + # norm=None, + # htk=True, + # f_min=audio_cfg['fmin'], + # f_max=audio_cfg['fmax'] + # ) + # we use log mel spectrogram as input + mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel) + return mel.T # (T, n_mels) + + +def get_audio_features( + sample, audio_data, max_len, data_truncating, data_filling, audio_cfg +): + """ + Calculate and add audio features to sample. + Sample: a dict containing all the data of current sample. + audio_data: a tensor of shape (T) containing audio data. + max_len: the maximum length of audio data. + data_truncating: the method of truncating data. + data_filling: the method of filling data. + audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg']. + """ + with torch.no_grad(): + if len(audio_data) > max_len: + if data_truncating == "rand_trunc": + longer = torch.tensor([True]) + elif data_truncating == "fusion": + # fusion + mel = get_mel(audio_data, audio_cfg) + # split to three parts + chunk_frames = ( + max_len // audio_cfg["hop_size"] + 1 + ) # the +1 related to how the spectrogram is computed + total_frames = mel.shape[0] + if chunk_frames == total_frames: + # there is a corner case where the audio length is + # larger than max_len but smaller than max_len+hop_size. + # In this case, we just use the whole audio. 
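+                    # Worked example (added; hop_size=480 and max_len=480000 are
+                    # illustrative values, not taken from this config): chunk_frames
+                    # = 480000 // 480 + 1 = 1001, and a clip of 480100 samples also
+                    # gives total_frames = 480100 // 480 + 1 = 1001, so the whole
+                    # clip is kept without random chunking.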
+ mel_fusion = torch.stack([mel, mel, mel, mel], dim=0) + sample["mel_fusion"] = mel_fusion + longer = torch.tensor([False]) + else: + ranges = np.array_split( + list(range(0, total_frames - chunk_frames + 1)), 3 + ) + # print('total_frames-chunk_frames:', total_frames-chunk_frames, + # 'len(audio_data):', len(audio_data), + # 'chunk_frames:', chunk_frames, + # 'total_frames:', total_frames) + if len(ranges[1]) == 0: + # if the audio is too short, we just use the first chunk + ranges[1] = [0] + if len(ranges[2]) == 0: + # if the audio is too short, we just use the first chunk + ranges[2] = [0] + # randomly choose index for each part + idx_front = np.random.choice(ranges[0]) + idx_middle = np.random.choice(ranges[1]) + idx_back = np.random.choice(ranges[2]) + # select mel + mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :] + mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :] + mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :] + + # shrink the mel + mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, 64])( + mel[None] + )[0] + # logging.info(f"mel_shrink.shape: {mel_shrink.shape}") + + # stack + mel_fusion = torch.stack( + [mel_chunk_front, mel_chunk_middle, mel_chunk_back, mel_shrink], + dim=0, + ) + sample["mel_fusion"] = mel_fusion + longer = torch.tensor([True]) + else: + raise NotImplementedError( + f"data_truncating {data_truncating} not implemented" + ) + # random crop to max_len (for compatibility) + overflow = len(audio_data) - max_len + idx = np.random.randint(0, overflow + 1) + audio_data = audio_data[idx : idx + max_len] + + else: # padding if too short + if len(audio_data) < max_len: # do nothing if equal + if data_filling == "repeatpad": + n_repeat = int(max_len / len(audio_data)) + audio_data = audio_data.repeat(n_repeat) + # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0) + # audio_data = F.interpolate(audio_data,size=max_len,mode="bicubic")[0,0,0] + audio_data = F.pad( + audio_data, + (0, max_len - len(audio_data)), + mode="constant", + value=0, + ) + elif data_filling == "pad": + audio_data = F.pad( + audio_data, + (0, max_len - len(audio_data)), + mode="constant", + value=0, + ) + elif data_filling == "repeat": + n_repeat = int(max_len / len(audio_data)) + audio_data = audio_data.repeat(n_repeat + 1)[:max_len] + else: + raise NotImplementedError( + f"data_filling {data_filling} not implemented" + ) + if data_truncating == "fusion": + mel = get_mel(audio_data, audio_cfg) + mel_fusion = torch.stack([mel, mel, mel, mel], dim=0) + sample["mel_fusion"] = mel_fusion + longer = torch.tensor([False]) + + sample["longer"] = longer + sample["waveform"] = audio_data + + return sample + + +def preprocess( + sample, + audio_ext, + text_ext, + max_len, + audio_cfg, + class_index_dict=None, + data_filling="pad", + data_truncating="rand_trunc", + text_augment_selection=None, +): + """ + Preprocess a single sample for wdsdataloader. 
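+    Processing flow (summarized here for clarity): the raw audio bytes are decoded
+    with soundfile, converted to float32, and passed through get_audio_features;
+    the caption is selected from the JSON metadata (optionally an augmented one),
+    tokenized, and an optional multi-hot class_label is built from the "tag" field.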
+ """ + audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) + audio_data = int16_to_float32(float32_to_int16(audio_data)) + audio_data = torch.tensor(audio_data).float() + + # TODO: (yusong) to be include in the future + # # if torchaudio not installed, use soundfile to load audio + # if torchaudio is None: + # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) + # audio_data = torch.tensor(audio_data).float() + # else: + # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py + # with tempfile.TemporaryDirectory() as dirname: + # os.makedirs(dirname, exist_ok=True) + # fname = os.path.join(dirname, f"file.flac") + # with open(fname, "wb") as stream: + # stream.write(sample[audio_ext]) + # audio_data, orig_sr = torchaudio.load(fname) + # audio_data = audio_data[0, :].float() + + sample = get_audio_features( + sample, audio_data, max_len, data_truncating, data_filling, audio_cfg + ) + del sample[audio_ext] + + try: + json_dict_raw = json.loads(sample[text_ext].decode("utf-8")) + except: + print("sample[__url__]:", sample["__url__"]) + + # For selecting augmented text from dataset + if text_augment_selection is None or text_augment_selection == "none": + texts = json_dict_raw["text"] + elif text_augment_selection == "all": + if "text_augment_all" in json_dict_raw.keys(): + texts = json_dict_raw["text_augment_all"] + else: + texts = json_dict_raw["text"] + elif text_augment_selection == "augment_only": + if "text_augment_all" in json_dict_raw.keys(): + if json_dict_raw["text_augment_t5"] is None: + texts = json_dict_raw["text"] + else: + texts = json_dict_raw["text_augment_t5"] + else: + texts = json_dict_raw["text"] + else: + raise NotImplementedError( + f"text_augment_selection {text_augment_selection} not implemented" + ) + sample["full_text"] = texts + + if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1: + texts = random.choice(texts) + sample["raw_text"] = texts + sample["text"] = tokenizer(texts) # text shape: [num_token] + if class_index_dict is not None: + # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing + # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array + # key, val = class_index_dict + # key = key[:].split('\n') + # _dict = {k: v for k, v in zip(key, val)} + sample["class_label"] = np.zeros(len(class_index_dict.keys())) + for x in json_dict_raw["tag"]: + sample["class_label"][class_index_dict[x]] = 1 + sample["class_label"] = torch.tensor(sample["class_label"]).float() + del sample[text_ext] + sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext + sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext + sample["audio_orig_sr"] = orig_sr + return sample + + +def collate_fn(batch): + """ + Collate function for wdsdataloader. + batch: a list of dict, each dict is a sample + """ + # concatenate values in each dictionary. if it is a tensor, concatenate. if it is a list, extend. 
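+    # Illustrative sketch (added): for batch = [{"waveform": w1, "text": {"input_ids": i1}},
+    #                                           {"waveform": w2, "text": {"input_ids": i2}}],
+    # this returns {"waveform": torch.stack([w1, w2]), "text": {"input_ids": torch.vstack([i1, i2])}}.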
+ batch_dict = {} + for k in batch[0].keys(): + if isinstance(batch[0][k], dict): # dealwith bert tokenizer output + batch_dict[k] = {} + for kk in batch[0][k].keys(): + tmp = [] + for i in range(len(batch)): + tmp.append(batch[i][k][kk]) + batch_dict[k][kk] = torch.vstack(tmp) + elif isinstance(batch[0][k], torch.Tensor): + batch_dict[k] = torch.stack([sample[k] for sample in batch]) + elif isinstance(batch[0][k], np.ndarray): + batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch])) + else: + batch_dict[k] = [sample[k] for sample in batch] + return batch_dict + + +def get_wds_dataset( + args, + model_cfg, + is_train, + audio_ext="flac", + text_ext="json", + max_len=480000, + proportion=1.0, + sizefilepath_=None, + is_local=None, +): + """ + Get a dataset for wdsdataloader. + """ + if is_local is None and (not args.remotedata is None): + is_local = not args.remotedata + + input_shards = args.train_data if is_train else args.val_data + assert input_shards is not None + + if not sizefilepath_ is None: + sizefilepath = sizefilepath_ + else: + sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json") + + if proportion != 1.0: + num_samples, num_shards, input_shards, _ = sample_prop( + sizefilepath, input_shards, proportion, is_local=is_local + ) + else: + num_samples, num_shards = get_dataset_size( + input_shards, sizefilepath_=sizefilepath_, is_local=is_local + ) + + if not num_samples: + if is_train: + num_samples = args.train_num_samples + if not num_samples: + raise RuntimeError( + "Currently, number of dataset samples must be specified for training dataset. " + "Please specify via `--train-num-samples` if no dataset length info present." + ) + else: + num_samples = ( + args.val_num_samples or 0 + ) # eval will just exhaust the iterator if not specified + + pipeline = [wds.SimpleShardList(input_shards)] + # at this point we have an iterator over all the shards + # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node + if is_train or args.parallel_eval: + pipeline.extend( + [ + wds.detshuffle( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + ), + wds.split_by_node, + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker at each node + wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + rng=random.Random(args.seed), + ), + # wds.repeatedly, # FIXME determine if this is beneficial + ] + ) + else: + pipeline.extend( + [ + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ] + ) + pipeline.append( + wds.map( + partial( + preprocess, + audio_ext=audio_ext, + text_ext=text_ext, + max_len=max_len, + audio_cfg=model_cfg["audio_cfg"], + class_index_dict=copy.deepcopy(args.class_index_dict), + data_filling=args.data_filling, + data_truncating=args.data_truncating, + text_augment_selection=args.text_augment_selection, + ) + ), + ) + + pipeline.append( + wds.batched( + args.batch_size, + partial=not (is_train or args.parallel_eval), + collation_fn=collate_fn, + ) + ) + + dataset = wds.DataPipeline(*pipeline) + if is_train or args.parallel_eval: + # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples. + # (yusong): See comments below. 
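+        # Worked example (added; the numbers are hypothetical): with num_samples=10000,
+        # batch_size=32, world_size=4 and workers=2, global_batch_size = 128,
+        # num_batches = ceil(10000 / 128) = 79, num_worker_batches = ceil(79 / 2) = 40,
+        # so num_batches becomes 80 and num_samples becomes 80 * 128 = 10240.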
+ # roll over and repeat a few samples to get same number of full batches on each node + global_batch_size = args.batch_size * args.world_size + num_batches = math.ceil(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = math.ceil( + num_batches / num_workers + ) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + dataset = dataset.with_epoch( + num_worker_batches + ) # each worker is iterating over this + else: + # last batches are partial, eval is done on single (master) node + num_batches = math.ceil(num_samples / args.batch_size) + + kwargs = {} + if args.horovod: # multi-node training on summit + kwargs["multiprocessing_context"] = "forkserver" + + dataloader = wds.WebLoader( + dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs + ) + + # FIXME not clear which approach is better, with_epoch before vs after dataloader? + # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 + # if is_train: + # # roll over and repeat a few samples to get same number of full batches on each node + # global_batch_size = args.batch_size * args.world_size + # num_batches = math.ceil(num_samples / global_batch_size) + # num_workers = max(1, args.workers) + # num_batches = math.ceil(num_batches / num_workers) * num_workers + # num_samples = num_batches * global_batch_size + # dataloader = dataloader.with_epoch(num_batches) + # else: + # # last batches are partial, eval is done on single (master) node + # num_batches = math.ceil(num_samples / args.batch_size) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader, None) + + +def wds_batch_list2dict( + batch, + keys=[ + "__url__", + "__key__", + "waveform", + "text", + "raw_text", + "audio_name", + "text_name", + "audio_orig_sr", + ], +): + """ + Return a dictionary of the batch, with keys as the names of the fields. 
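+    Example (added for clarity): wds_batch_list2dict([urls, keys, waveforms, texts,
+    raw_texts, audio_names, text_names, orig_srs]) returns {"__url__": urls,
+    "__key__": keys, "waveform": waveforms, ..., "audio_orig_sr": orig_srs}.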
+ """ + assert len(keys) == len( + batch + ), "batch must have same number of keys as keys argument" + return {keys[i]: batch[i] for i in range(len(batch))} + + +def get_csv_dataset(args, preprocess_fn, is_train): + input_filename = args.train_data if is_train else args.val_data + assert input_filename + dataset = CsvDataset( + input_filename, + preprocess_fn, + img_key=args.csv_img_key, + caption_key=args.csv_caption_key, + sep=args.csv_separator, + ) + num_samples = len(dataset) + sampler = DistributedSampler(dataset) if args.distributed and is_train else None + shuffle = is_train and sampler is None + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=shuffle, + num_workers=args.workers, + pin_memory=True, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + + +def get_toy_dataset(args, model_cfg, is_train): + index_path = args.train_data if is_train else args.val_data + ipc_path = args.train_ipc if is_train else args.val_ipc + assert index_path and ipc_path + eval_mode = not is_train + dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode) + + num_samples = len(dataset) + sampler = ( + DistributedSampler(dataset, shuffle=False) + if args.distributed and is_train + else None + ) + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.workers, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + + +def get_dataset_fn(data_path, dataset_type): + if dataset_type == "webdataset": + return get_wds_dataset + elif dataset_type == "csv": + return get_csv_dataset + elif dataset_type == "auto": + ext = data_path.split(".")[-1] + if ext in ["csv", "tsv"]: + return get_csv_dataset + elif ext in ["tar"]: + return get_wds_dataset + else: + raise ValueError( + f"Tried to figure out dataset type, but failed for extention {ext}." 
+ ) + elif dataset_type == "toy": + return get_toy_dataset + else: + raise ValueError(f"Unsupported dataset type: {dataset_type}") + + +def get_data(args, model_cfg): + data = {} + + args.class_index_dict = load_class_label(args.class_label_path) + + if args.datasetinfos is None: + args.datasetinfos = ["train", "unbalanced_train", "balanced_train"] + if args.dataset_type == "webdataset": + args.train_data = get_tar_path_from_dataset_name( + args.datasetnames, + args.datasetinfos, + islocal=not args.remotedata, + proportion=args.dataset_proportion, + dataset_path=args.datasetpath, + full_dataset=args.full_train_dataset, + ) + + if args.full_train_dataset is None: + args.full_train_dataset = [] + if args.exclude_eval_dataset is None: + args.exclude_eval_dataset = [] + excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset + + val_dataset_names = ( + [n for n in args.datasetnames if n not in excluded_eval_datasets] + if excluded_eval_datasets + else args.datasetnames + ) + args.val_dataset_names = val_dataset_names + args.val_data = get_tar_path_from_dataset_name( + val_dataset_names, + ["valid", "test", "eval"], + islocal=not args.remotedata, + proportion=1, + dataset_path=args.datasetpath, + full_dataset=None, + ) + + if args.train_data: + data["train"] = get_dataset_fn(args.train_data, args.dataset_type)( + args, model_cfg, is_train=True + ) + + if args.val_data: + data["val"] = get_dataset_fn(args.val_data, args.dataset_type)( + args, model_cfg, is_train=False + ) + + return data diff --git a/models/CLAP/training/distributed.py b/models/CLAP/training/distributed.py new file mode 100755 index 0000000000000000000000000000000000000000..2fa61f76c5cc3ab9f6a9643042afa8e1f2e1cb7f --- /dev/null +++ b/models/CLAP/training/distributed.py @@ -0,0 +1,150 @@ +import os + +import torch +import socket + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +def is_global_master(args): + return args.rank == 0 + + +def is_local_master(args): + return args.local_rank == 0 + + +def is_master(args, local=False): + return is_local_master(args) if local else is_global_master(args) + + +def is_using_horovod(): + # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set + # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... + ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] + pmi_vars = ["PMI_RANK", "PMI_SIZE"] + if all([var in os.environ for var in ompi_vars]) or all( + [var in os.environ for var in pmi_vars] + ): + return True + else: + return False + + +def is_using_distributed(): + if "WORLD_SIZE" in os.environ: + return int(os.environ["WORLD_SIZE"]) > 1 + if "SLURM_NTASKS" in os.environ: + return int(os.environ["SLURM_NTASKS"]) > 1 + return False + + +def world_info_from_env(): + local_rank = 0 + for v in ( + "SLURM_LOCALID", + "MPI_LOCALRANKID", + "OMPI_COMM_WORLD_LOCAL_RANK", + "LOCAL_RANK", + ): + if v in os.environ: + local_rank = int(os.environ[v]) + break + global_rank = 0 + for v in ("SLURM_PROCID", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "RANK"): + if v in os.environ: + global_rank = int(os.environ[v]) + break + world_size = 1 + for v in ("SLURM_NTASKS", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "WORLD_SIZE"): + if v in os.environ: + world_size = int(os.environ[v]) + break + + return local_rank, global_rank, world_size + + +def init_distributed_device(args): + # Distributed training = training on more than one GPU. + # Works in both single and multi-node scenarios. 
+ args.distributed = False + args.world_size = 1 + args.rank = 0 # global rank + args.local_rank = 0 + if args.horovod: + assert hvd is not None, "Horovod is not installed" + hvd.init() + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + args.local_rank = local_rank + args.rank = world_rank + args.world_size = world_size + # args.local_rank = int(hvd.local_rank()) + # args.rank = hvd.rank() + # args.world_size = hvd.size() + args.distributed = True + os.environ["LOCAL_RANK"] = str(args.local_rank) + os.environ["RANK"] = str(args.rank) + os.environ["WORLD_SIZE"] = str(args.world_size) + print( + f"Distributed training: local_rank={args.local_rank}, " + f"rank={args.rank}, world_size={args.world_size}, " + f"hostname={socket.gethostname()}, pid={os.getpid()}" + ) + elif is_using_distributed(): + if "SLURM_PROCID" in os.environ: + # DDP via SLURM + args.local_rank, args.rank, args.world_size = world_info_from_env() + # SLURM var -> torch.distributed vars in case needed + os.environ["LOCAL_RANK"] = str(args.local_rank) + os.environ["RANK"] = str(args.rank) + os.environ["WORLD_SIZE"] = str(args.world_size) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + elif "OMPI_COMM_WORLD_SIZE" in os.environ: # using Summit cluster + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + args.local_rank = local_rank + args.rank = world_rank + args.world_size = world_size + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + else: + # DDP via torchrun, torch.distributed.launch + args.local_rank, _, _ = world_info_from_env() + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url + ) + args.world_size = torch.distributed.get_world_size() + args.rank = torch.distributed.get_rank() + args.distributed = True + print( + f"Distributed training: local_rank={args.local_rank}, " + f"rank={args.rank}, world_size={args.world_size}, " + f"hostname={socket.gethostname()}, pid={os.getpid()}" + ) + + if torch.cuda.is_available(): + if args.distributed and not args.no_set_device_rank: + device = "cuda:%d" % args.local_rank + else: + device = "cuda:0" + torch.cuda.set_device(device) + else: + device = "cpu" + args.device = device + device = torch.device(device) + return device diff --git a/models/CLAP/training/imagenet_zeroshot_data.py b/models/CLAP/training/imagenet_zeroshot_data.py new file mode 100755 index 0000000000000000000000000000000000000000..d32e55328d6799ccb8d61625f43abb80a33d6c17 --- /dev/null +++ b/models/CLAP/training/imagenet_zeroshot_data.py @@ -0,0 +1,1088 @@ +# NOTE: This script is currently not supported for CLAP. 
+ +imagenet_classnames = [ + "tench", + "goldfish", + "great white shark", + "tiger shark", + "hammerhead shark", + "electric ray", + "stingray", + "rooster", + "hen", + "ostrich", + "brambling", + "goldfinch", + "house finch", + "junco", + "indigo bunting", + "American robin", + "bulbul", + "jay", + "magpie", + "chickadee", + "American dipper", + "kite (bird of prey)", + "bald eagle", + "vulture", + "great grey owl", + "fire salamander", + "smooth newt", + "newt", + "spotted salamander", + "axolotl", + "American bullfrog", + "tree frog", + "tailed frog", + "loggerhead sea turtle", + "leatherback sea turtle", + "mud turtle", + "terrapin", + "box turtle", + "banded gecko", + "green iguana", + "Carolina anole", + "desert grassland whiptail lizard", + "agama", + "frilled-necked lizard", + "alligator lizard", + "Gila monster", + "European green lizard", + "chameleon", + "Komodo dragon", + "Nile crocodile", + "American alligator", + "triceratops", + "worm snake", + "ring-necked snake", + "eastern hog-nosed snake", + "smooth green snake", + "kingsnake", + "garter snake", + "water snake", + "vine snake", + "night snake", + "boa constrictor", + "African rock python", + "Indian cobra", + "green mamba", + "sea snake", + "Saharan horned viper", + "eastern diamondback rattlesnake", + "sidewinder rattlesnake", + "trilobite", + "harvestman", + "scorpion", + "yellow garden spider", + "barn spider", + "European garden spider", + "southern black widow", + "tarantula", + "wolf spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse", + "prairie grouse", + "peafowl", + "quail", + "partridge", + "african grey parrot", + "macaw", + "sulphur-crested cockatoo", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "duck", + "red-breasted merganser", + "goose", + "black swan", + "tusker", + "echidna", + "platypus", + "wallaby", + "koala", + "wombat", + "jellyfish", + "sea anemone", + "brain coral", + "flatworm", + "nematode", + "conch", + "snail", + "slug", + "sea slug", + "chiton", + "chambered nautilus", + "Dungeness crab", + "rock crab", + "fiddler crab", + "red king crab", + "American lobster", + "spiny lobster", + "crayfish", + "hermit crab", + "isopod", + "white stork", + "black stork", + "spoonbill", + "flamingo", + "little blue heron", + "great egret", + "bittern bird", + "crane bird", + "limpkin", + "common gallinule", + "American coot", + "bustard", + "ruddy turnstone", + "dunlin", + "common redshank", + "dowitcher", + "oystercatcher", + "pelican", + "king penguin", + "albatross", + "grey whale", + "killer whale", + "dugong", + "sea lion", + "Chihuahua", + "Japanese Chin", + "Maltese", + "Pekingese", + "Shih Tzu", + "King Charles Spaniel", + "Papillon", + "toy terrier", + "Rhodesian Ridgeback", + "Afghan Hound", + "Basset Hound", + "Beagle", + "Bloodhound", + "Bluetick Coonhound", + "Black and Tan Coonhound", + "Treeing Walker Coonhound", + "English foxhound", + "Redbone Coonhound", + "borzoi", + "Irish Wolfhound", + "Italian Greyhound", + "Whippet", + "Ibizan Hound", + "Norwegian Elkhound", + "Otterhound", + "Saluki", + "Scottish Deerhound", + "Weimaraner", + "Staffordshire Bull Terrier", + "American Staffordshire Terrier", + "Bedlington Terrier", + "Border Terrier", + "Kerry Blue Terrier", + "Irish Terrier", + "Norfolk Terrier", + "Norwich Terrier", + "Yorkshire Terrier", + "Wire Fox Terrier", + "Lakeland Terrier", + "Sealyham Terrier", + "Airedale Terrier", + "Cairn Terrier", + "Australian Terrier", + "Dandie Dinmont Terrier", + 
"Boston Terrier", + "Miniature Schnauzer", + "Giant Schnauzer", + "Standard Schnauzer", + "Scottish Terrier", + "Tibetan Terrier", + "Australian Silky Terrier", + "Soft-coated Wheaten Terrier", + "West Highland White Terrier", + "Lhasa Apso", + "Flat-Coated Retriever", + "Curly-coated Retriever", + "Golden Retriever", + "Labrador Retriever", + "Chesapeake Bay Retriever", + "German Shorthaired Pointer", + "Vizsla", + "English Setter", + "Irish Setter", + "Gordon Setter", + "Brittany dog", + "Clumber Spaniel", + "English Springer Spaniel", + "Welsh Springer Spaniel", + "Cocker Spaniel", + "Sussex Spaniel", + "Irish Water Spaniel", + "Kuvasz", + "Schipperke", + "Groenendael dog", + "Malinois", + "Briard", + "Australian Kelpie", + "Komondor", + "Old English Sheepdog", + "Shetland Sheepdog", + "collie", + "Border Collie", + "Bouvier des Flandres dog", + "Rottweiler", + "German Shepherd Dog", + "Dobermann", + "Miniature Pinscher", + "Greater Swiss Mountain Dog", + "Bernese Mountain Dog", + "Appenzeller Sennenhund", + "Entlebucher Sennenhund", + "Boxer", + "Bullmastiff", + "Tibetan Mastiff", + "French Bulldog", + "Great Dane", + "St. Bernard", + "husky", + "Alaskan Malamute", + "Siberian Husky", + "Dalmatian", + "Affenpinscher", + "Basenji", + "pug", + "Leonberger", + "Newfoundland dog", + "Great Pyrenees dog", + "Samoyed", + "Pomeranian", + "Chow Chow", + "Keeshond", + "brussels griffon", + "Pembroke Welsh Corgi", + "Cardigan Welsh Corgi", + "Toy Poodle", + "Miniature Poodle", + "Standard Poodle", + "Mexican hairless dog (xoloitzcuintli)", + "grey wolf", + "Alaskan tundra wolf", + "red wolf or maned wolf", + "coyote", + "dingo", + "dhole", + "African wild dog", + "hyena", + "red fox", + "kit fox", + "Arctic fox", + "grey fox", + "tabby cat", + "tiger cat", + "Persian cat", + "Siamese cat", + "Egyptian Mau", + "cougar", + "lynx", + "leopard", + "snow leopard", + "jaguar", + "lion", + "tiger", + "cheetah", + "brown bear", + "American black bear", + "polar bear", + "sloth bear", + "mongoose", + "meerkat", + "tiger beetle", + "ladybug", + "ground beetle", + "longhorn beetle", + "leaf beetle", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant", + "grasshopper", + "cricket insect", + "stick insect", + "cockroach", + "praying mantis", + "cicada", + "leafhopper", + "lacewing", + "dragonfly", + "damselfly", + "red admiral butterfly", + "ringlet butterfly", + "monarch butterfly", + "small white butterfly", + "sulphur butterfly", + "gossamer-winged butterfly", + "starfish", + "sea urchin", + "sea cucumber", + "cottontail rabbit", + "hare", + "Angora rabbit", + "hamster", + "porcupine", + "fox squirrel", + "marmot", + "beaver", + "guinea pig", + "common sorrel horse", + "zebra", + "pig", + "wild boar", + "warthog", + "hippopotamus", + "ox", + "water buffalo", + "bison", + "ram (adult male sheep)", + "bighorn sheep", + "Alpine ibex", + "hartebeest", + "impala (antelope)", + "gazelle", + "arabian camel", + "llama", + "weasel", + "mink", + "European polecat", + "black-footed ferret", + "otter", + "skunk", + "badger", + "armadillo", + "three-toed sloth", + "orangutan", + "gorilla", + "chimpanzee", + "gibbon", + "siamang", + "guenon", + "patas monkey", + "baboon", + "macaque", + "langur", + "black-and-white colobus", + "proboscis monkey", + "marmoset", + "white-headed capuchin", + "howler monkey", + "titi monkey", + "Geoffroy's spider monkey", + "common squirrel monkey", + "ring-tailed lemur", + "indri", + "Asian elephant", + "African bush elephant", + "red panda", + "giant panda", + 
"snoek fish", + "eel", + "silver salmon", + "rock beauty fish", + "clownfish", + "sturgeon", + "gar fish", + "lionfish", + "pufferfish", + "abacus", + "abaya", + "academic gown", + "accordion", + "acoustic guitar", + "aircraft carrier", + "airliner", + "airship", + "altar", + "ambulance", + "amphibious vehicle", + "analog clock", + "apiary", + "apron", + "trash can", + "assault rifle", + "backpack", + "bakery", + "balance beam", + "balloon", + "ballpoint pen", + "Band-Aid", + "banjo", + "baluster / handrail", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel", + "wheelbarrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "swimming cap", + "bath towel", + "bathtub", + "station wagon", + "lighthouse", + "beaker", + "military hat (bearskin or shako)", + "beer bottle", + "beer glass", + "bell tower", + "baby bib", + "tandem bicycle", + "bikini", + "ring binder", + "binoculars", + "birdhouse", + "boathouse", + "bobsleigh", + "bolo tie", + "poke bonnet", + "bookcase", + "bookstore", + "bottle cap", + "hunting bow", + "bow tie", + "brass memorial plaque", + "bra", + "breakwater", + "breastplate", + "broom", + "bucket", + "buckle", + "bulletproof vest", + "high-speed train", + "butcher shop", + "taxicab", + "cauldron", + "candle", + "cannon", + "canoe", + "can opener", + "cardigan", + "car mirror", + "carousel", + "tool kit", + "cardboard box / carton", + "car wheel", + "automated teller machine", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello", + "mobile phone", + "chain", + "chain-link fence", + "chain mail", + "chainsaw", + "storage chest", + "chiffonier", + "bell or wind chime", + "china cabinet", + "Christmas stocking", + "church", + "movie theater", + "cleaver", + "cliff dwelling", + "cloak", + "clogs", + "cocktail shaker", + "coffee mug", + "coffeemaker", + "spiral or coil", + "combination lock", + "computer keyboard", + "candy store", + "container ship", + "convertible", + "corkscrew", + "cornet", + "cowboy boot", + "cowboy hat", + "cradle", + "construction crane", + "crash helmet", + "crate", + "infant bed", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam", + "desk", + "desktop computer", + "rotary dial telephone", + "diaper", + "digital clock", + "digital watch", + "dining table", + "dishcloth", + "dishwasher", + "disc brake", + "dock", + "dog sled", + "dome", + "doormat", + "drilling rig", + "drum", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso machine", + "face powder", + "feather boa", + "filing cabinet", + "fireboat", + "fire truck", + "fire screen", + "flagpole", + "flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster bed", + "freight car", + "French horn", + "frying pan", + "fur coat", + "garbage truck", + "gas mask or respirator", + "gas pump", + "goblet", + "go-kart", + "golf ball", + "golf cart", + "gondola", + "gong", + "gown", + "grand piano", + "greenhouse", + "radiator grille", + "grocery store", + "guillotine", + "hair clip", + "hair spray", + "half-track", + "hammer", + "hamper", + "hair dryer", + "hand-held computer", + "handkerchief", + "hard disk drive", + "harmonica", + "harp", + "combine harvester", + "hatchet", + "holster", + "home theater", + "honeycomb", + "hook", + "hoop skirt", + "gymnastic horizontal bar", + "horse-drawn vehicle", + "hourglass", + "iPod", + "clothes iron", + "carved pumpkin", + 
"jeans", + "jeep", + "T-shirt", + "jigsaw puzzle", + "rickshaw", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat", + "ladle", + "lampshade", + "laptop computer", + "lawn mower", + "lens cap", + "letter opener", + "library", + "lifeboat", + "lighter", + "limousine", + "ocean liner", + "lipstick", + "slip-on shoe", + "lotion", + "music speaker", + "loupe magnifying glass", + "sawmill", + "magnetic compass", + "messenger bag", + "mailbox", + "tights", + "one-piece bathing suit", + "manhole cover", + "maraca", + "marimba", + "mask", + "matchstick", + "maypole", + "maze", + "measuring cup", + "medicine cabinet", + "megalith", + "microphone", + "microwave oven", + "military uniform", + "milk can", + "minibus", + "miniskirt", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home", + "ford model t", + "modem", + "monastery", + "monitor", + "moped", + "mortar and pestle", + "graduation cap", + "mosque", + "mosquito net", + "vespa", + "mountain bike", + "tent", + "computer mouse", + "mousetrap", + "moving van", + "muzzle", + "metal nail", + "neck brace", + "necklace", + "baby pacifier", + "notebook computer", + "obelisk", + "oboe", + "ocarina", + "odometer", + "oil filter", + "pipe organ", + "oscilloscope", + "overskirt", + "bullock cart", + "oxygen mask", + "product packet / packaging", + "paddle", + "paddle wheel", + "padlock", + "paintbrush", + "pajamas", + "palace", + "pan flute", + "paper towel", + "parachute", + "parallel bars", + "park bench", + "parking meter", + "railroad car", + "patio", + "payphone", + "pedestal", + "pencil case", + "pencil sharpener", + "perfume", + "Petri dish", + "photocopier", + "plectrum", + "Pickelhaube", + "picket fence", + "pickup truck", + "pier", + "piggy bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate ship", + "drink pitcher", + "block plane", + "planetarium", + "plastic bag", + "plate rack", + "farm plow", + "plunger", + "Polaroid camera", + "pole", + "police van", + "poncho", + "pool table", + "soda bottle", + "plant pot", + "potter's wheel", + "power drill", + "prayer rug", + "printer", + "prison", + "missile", + "projector", + "hockey puck", + "punching bag", + "purse", + "quill", + "quilt", + "race car", + "racket", + "radiator", + "radio", + "radio telescope", + "rain barrel", + "recreational vehicle", + "fishing casting reel", + "reflex camera", + "refrigerator", + "remote control", + "restaurant", + "revolver", + "rifle", + "rocking chair", + "rotisserie", + "eraser", + "rugby ball", + "ruler measuring stick", + "sneaker", + "safe", + "safety pin", + "salt shaker", + "sandal", + "sarong", + "saxophone", + "scabbard", + "weighing scale", + "school bus", + "schooner", + "scoreboard", + "CRT monitor", + "screw", + "screwdriver", + "seat belt", + "sewing machine", + "shield", + "shoe store", + "shoji screen / room divider", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "balaclava ski mask", + "sleeping bag", + "slide rule", + "sliding door", + "slot machine", + "snorkel", + "snowmobile", + "snowplow", + "soap dispenser", + "soccer ball", + "sock", + "solar thermal collector", + "sombrero", + "soup bowl", + "keyboard space bar", + "space heater", + "space shuttle", + "spatula", + "motorboat", + "spider web", + "spindle", + "sports car", + "spotlight", + "stage", + "steam locomotive", + "through arch bridge", + "steel drum", + "stethoscope", + "scarf", + "stone wall", + "stopwatch", + "stove", + "strainer", + "tram", + "stretcher", + "couch", + 
"stupa", + "submarine", + "suit", + "sundial", + "sunglasses", + "sunglasses", + "sunscreen", + "suspension bridge", + "mop", + "sweatshirt", + "swim trunks / shorts", + "swing", + "electrical switch", + "syringe", + "table lamp", + "tank", + "tape player", + "teapot", + "teddy bear", + "television", + "tennis ball", + "thatched roof", + "front curtain", + "thimble", + "threshing machine", + "throne", + "tile roof", + "toaster", + "tobacco shop", + "toilet seat", + "torch", + "totem pole", + "tow truck", + "toy store", + "tractor", + "semi-trailer truck", + "tray", + "trench coat", + "tricycle", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus", + "trombone", + "hot tub", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle", + "upright piano", + "vacuum cleaner", + "vase", + "vaulted or arched ceiling", + "velvet fabric", + "vending machine", + "vestment", + "viaduct", + "violin", + "volleyball", + "waffle iron", + "wall clock", + "wallet", + "wardrobe", + "military aircraft", + "sink", + "washing machine", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "hair wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "airplane wing", + "wok", + "wooden spoon", + "wool", + "split-rail fence", + "shipwreck", + "sailboat", + "yurt", + "website", + "comic book", + "crossword", + "traffic or street sign", + "traffic light", + "dust jacket", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot", + "trifle", + "ice cream", + "popsicle", + "baguette", + "bagel", + "pretzel", + "cheeseburger", + "hot dog", + "mashed potatoes", + "cabbage", + "broccoli", + "cauliflower", + "zucchini", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber", + "artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith apple", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple", + "banana", + "jackfruit", + "cherimoya (custard apple)", + "pomegranate", + "hay", + "carbonara", + "chocolate syrup", + "dough", + "meatloaf", + "pizza", + "pot pie", + "burrito", + "red wine", + "espresso", + "tea cup", + "eggnog", + "mountain", + "bubble", + "cliff", + "coral reef", + "geyser", + "lakeshore", + "promontory", + "sandbar", + "beach", + "valley", + "volcano", + "baseball player", + "bridegroom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper", + "corn", + "acorn", + "rose hip", + "horse chestnut seed", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn mushroom", + "earth star fungus", + "hen of the woods mushroom", + "bolete", + "corn cob", + "toilet paper", +] + + +openai_imagenet_template = [ + lambda c: f"a bad photo of a {c}.", + lambda c: f"a photo of many {c}.", + lambda c: f"a sculpture of a {c}.", + lambda c: f"a photo of the hard to see {c}.", + lambda c: f"a low resolution photo of the {c}.", + lambda c: f"a rendering of a {c}.", + lambda c: f"graffiti of a {c}.", + lambda c: f"a bad photo of the {c}.", + lambda c: f"a cropped photo of the {c}.", + lambda c: f"a tattoo of a {c}.", + lambda c: f"the embroidered {c}.", + lambda c: f"a photo of a hard to see {c}.", + lambda c: f"a bright photo of a {c}.", + lambda c: f"a photo of a clean {c}.", + lambda c: f"a photo of a dirty {c}.", + lambda c: f"a dark photo of the {c}.", + lambda c: f"a drawing of a {c}.", + lambda c: f"a photo of my {c}.", + lambda c: f"the plastic {c}.", + lambda c: f"a photo of the cool {c}.", + lambda c: f"a close-up photo of a {c}.", + lambda c: f"a black and white photo of the {c}.", + 
lambda c: f"a painting of the {c}.", + lambda c: f"a painting of a {c}.", + lambda c: f"a pixelated photo of the {c}.", + lambda c: f"a sculpture of the {c}.", + lambda c: f"a bright photo of the {c}.", + lambda c: f"a cropped photo of a {c}.", + lambda c: f"a plastic {c}.", + lambda c: f"a photo of the dirty {c}.", + lambda c: f"a jpeg corrupted photo of a {c}.", + lambda c: f"a blurry photo of the {c}.", + lambda c: f"a photo of the {c}.", + lambda c: f"a good photo of the {c}.", + lambda c: f"a rendering of the {c}.", + lambda c: f"a {c} in a video game.", + lambda c: f"a photo of one {c}.", + lambda c: f"a doodle of a {c}.", + lambda c: f"a close-up photo of the {c}.", + lambda c: f"a photo of a {c}.", + lambda c: f"the origami {c}.", + lambda c: f"the {c} in a video game.", + lambda c: f"a sketch of a {c}.", + lambda c: f"a doodle of the {c}.", + lambda c: f"a origami {c}.", + lambda c: f"a low resolution photo of a {c}.", + lambda c: f"the toy {c}.", + lambda c: f"a rendition of the {c}.", + lambda c: f"a photo of the clean {c}.", + lambda c: f"a photo of a large {c}.", + lambda c: f"a rendition of a {c}.", + lambda c: f"a photo of a nice {c}.", + lambda c: f"a photo of a weird {c}.", + lambda c: f"a blurry photo of a {c}.", + lambda c: f"a cartoon {c}.", + lambda c: f"art of a {c}.", + lambda c: f"a sketch of the {c}.", + lambda c: f"a embroidered {c}.", + lambda c: f"a pixelated photo of a {c}.", + lambda c: f"itap of the {c}.", + lambda c: f"a jpeg corrupted photo of the {c}.", + lambda c: f"a good photo of a {c}.", + lambda c: f"a plushie {c}.", + lambda c: f"a photo of the nice {c}.", + lambda c: f"a photo of the small {c}.", + lambda c: f"a photo of the weird {c}.", + lambda c: f"the cartoon {c}.", + lambda c: f"art of the {c}.", + lambda c: f"a drawing of the {c}.", + lambda c: f"a photo of the large {c}.", + lambda c: f"a black and white photo of a {c}.", + lambda c: f"the plushie {c}.", + lambda c: f"a dark photo of a {c}.", + lambda c: f"itap of a {c}.", + lambda c: f"graffiti of the {c}.", + lambda c: f"a toy {c}.", + lambda c: f"itap of my {c}.", + lambda c: f"a photo of a cool {c}.", + lambda c: f"a photo of a small {c}.", + lambda c: f"a tattoo of the {c}.", +] diff --git a/models/CLAP/training/infer_demo.py b/models/CLAP/training/infer_demo.py new file mode 100755 index 0000000000000000000000000000000000000000..6a1bcc1fd8cf89ba30773d3479b2a78e8dc06d9f --- /dev/null +++ b/models/CLAP/training/infer_demo.py @@ -0,0 +1,109 @@ +import sys + +sys.path.append( + "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/src" +) + +import os +import torch +import librosa +from open_clip import create_model +from training.data import get_audio_features +from training.data import int16_to_float32, float32_to_int16 +from transformers import RobertaTokenizer + +tokenize = RobertaTokenizer.from_pretrained("roberta-base") + + +def tokenizer(text): + result = tokenize( + text, + padding="max_length", + truncation=True, + max_length=77, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} + + +PRETRAINED_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/checkpoints/epoch_top_0_audioset_no_fusion.pt" +WAVE_48k_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/audio/machine.wav" + + +def infer_text(): + device = "cuda:0" if torch.cuda.is_available() else "cpu" + precision = "fp32" + amodel = "HTSAT-tiny" # or 'PANN-14' + tmodel = "roberta" # the best text encoder in our 
training + enable_fusion = False # False if you do not want to use the fusion model + fusion_type = "aff_2d" + pretrained = PRETRAINED_PATH + + model, model_cfg = create_model( + amodel, + tmodel, + pretrained, + precision=precision, + device=device, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + # load the text, can be a list (i.e. batch size) + text_data = ["I love the contrastive learning", "I love the pretrain model"] + # tokenize for roberta, if you want to tokenize for another text encoder, please refer to data.py#L43-90 + text_data = tokenizer(text_data) + + text_embed = model.get_text_embedding(text_data) + print(text_embed.size()) + + +def infer_audio(): + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + precision = "fp32" + amodel = "HTSAT-tiny" # or 'PANN-14' + tmodel = "roberta" # the best text encoder in our training + enable_fusion = False # False if you do not want to use the fusion model + fusion_type = "aff_2d" + pretrained = PRETRAINED_PATH + + model, model_cfg = create_model( + amodel, + tmodel, + pretrained, + precision=precision, + device=device, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + + # load the waveform of the shape (T,), should resample to 48000 + audio_waveform, sr = librosa.load(WAVE_48k_PATH, sr=48000) + # quantize + audio_waveform = int16_to_float32(float32_to_int16(audio_waveform)) + audio_waveform = torch.from_numpy(audio_waveform).float() + audio_dict = {} + + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + import ipdb + + ipdb.set_trace() + audio_dict = get_audio_features( + audio_dict, + audio_waveform, + 480000, + data_truncating="fusion", + data_filling="repeatpad", + audio_cfg=model_cfg["audio_cfg"], + ) + # can send a list to the model, to process many audio tracks in one time (i.e. 
batch size) + audio_embed = model.get_audio_embedding([audio_dict]) + print(audio_embed.size()) + import ipdb + + ipdb.set_trace() + + +if __name__ == "__main__": + infer_text() + infer_audio() diff --git a/models/CLAP/training/logger.py b/models/CLAP/training/logger.py new file mode 100755 index 0000000000000000000000000000000000000000..ac4634970fae6aacde2b7b808355dbd50c90ce73 --- /dev/null +++ b/models/CLAP/training/logger.py @@ -0,0 +1,30 @@ +import logging + + +def setup_logging(log_file, level, include_host=False): + if include_host: + import socket + + hostname = socket.gethostname() + formatter = logging.Formatter( + f"%(asctime)s | {hostname} | %(levelname)s | %(message)s", + datefmt="%Y-%m-%d,%H:%M:%S", + ) + else: + formatter = logging.Formatter( + "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d,%H:%M:%S" + ) + + logging.root.setLevel(level) + loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] + for logger in loggers: + logger.setLevel(level) + + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(formatter) + logging.root.addHandler(stream_handler) + + if log_file: + file_handler = logging.FileHandler(filename=log_file) + file_handler.setFormatter(formatter) + logging.root.addHandler(file_handler) diff --git a/models/CLAP/training/lp_main.py b/models/CLAP/training/lp_main.py new file mode 100755 index 0000000000000000000000000000000000000000..c2d4e8c85aaa3c8e4221963ef56a815cc14f354f --- /dev/null +++ b/models/CLAP/training/lp_main.py @@ -0,0 +1,670 @@ +from cmath import cos +from inspect import getargs +import logging +import os +import random +from datetime import datetime +import bisect +import copy +from sched import scheduler +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from torch import optim +from torch.cuda.amp import GradScaler +import faulthandler +import pathlib +import argparse +import time + +try: + import wandb +except ImportError: + wandb = None + +try: + import torch.utils.tensorboard as tensorboard +except ImportError: + tensorboard = None + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from open_clip import create_model_and_transforms, trace_model, create_model +from training.data import get_data +from training.params import parse_args +from training.distributed import is_master, init_distributed_device, world_info_from_env +from training.logger import setup_logging +from training.scheduler import cosine_lr +from training.lp_train import train_one_epoch, evaluate +from open_clip.utils import get_tar_path_from_dataset_name, dataset_split, get_optimizer +from open_clip.utils import load_p, load_class_label +from open_clip.linear_probe import LinearProbe + + +def maintain_ckpts(args, startidx, all_idx_len): + for i in reversed(range(startidx, all_idx_len)): + if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")): + os.rename( + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"), + ) + if os.path.exists( + os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt") + ): + os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")) + return + + +def update_top_k_performance( + new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True +): + """ + Record the top-k performance of the current epoch. 
+ current_top_k_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...} + """ + if isinstance(new_metrics_inputs, (list, tuple)): + new_metrics_inputs = np.mean(new_metrics_inputs) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, dict): + new_metrics_inputs = np.mean(list(new_metrics_inputs.values())) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, (float, int)): + update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()} + sorted_keys = sorted(current_top_k_ckpt_metrics.keys()) + sorted_values = sorted( + current_top_k_ckpt_metrics.values(), reverse=bignumbetter + ) + sorted_values_ = copy.deepcopy(sorted_values) + sorted_values.append(new_metrics_inputs) + sorted_values = sorted(sorted_values, reverse=bignumbetter) + sorted_values = sorted_values[:-1] + + if sorted_values == sorted_values_: + return current_top_k_ckpt_metrics, new_metrics_inputs + else: + for i in range(len(sorted_keys)): + if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]: + current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i] + update_flag[sorted_keys[i]] = True + for i in range(len(update_flag)): + if update_flag[i]: + maintain_ckpts(args, i, len(sorted_keys)) + torch.save( + ckpt, + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + ) + break + return current_top_k_ckpt_metrics, new_metrics_inputs + + +# def updateifNone(a, b): +# a = b if None else a +# return a + + +def is_pretrained_params(n): + return ( + n.startswith("clap_model.transformer") + or n in ["clap_model.positional_embedding", "clap_model.text_projection"] + or n.startswith("clap_model.token_embedding") + or n.startswith("clap_model.ln_final") + or n.startswith("clap_model.logit_scale_t") + ) + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) + + +def config_lp_optimizer(model, data, args): + # set wd-related params to 0 if use adam optimizer + if args.optimizer == "adam": + args.wd = 0 + args.wd_pretrained = 0 + args.wd_new = 0 + + in_clap = lambda n, p: n.startswith("clap_model") + + named_parameters = list(model.named_parameters()) + + optimizer = {} + scheduler = {} + + # freeze text encoder + text_freeze_parameters = [ + p + for n, p in named_parameters + if n.startswith("clap_model.transformer") + or n in ["clap_model.positional_embedding", "clap_model.text_projection"] + or n.startswith("clap_model.token_embedding") + or n.startswith("clap_model.ln_final") + ] + + if args.freeze_text: + logging.info("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + + if not args.lp_freeze: + exclude = ( + lambda n, p: p.ndim < 2 + or "bn" in n + or "ln" in n + or "bias" in n + or "logit_scale" in n + ) + include = lambda n, p: not exclude(n, p) + + # (yusong): we do not split the learning rate anymore + # p for n, p in named_parameters if in_clap(n,p) and exclude(n, p) and p.requires_grad + gain_or_bias_params = [ + p for n, p in named_parameters if exclude(n, p) and p.requires_grad + ] + # rest_params = [p for n, p in named_parameters if in_clap(n,p) and include(n, p) and p.requires_grad] + rest_params = [ + p for n, p in named_parameters if include(n, p) and p.requires_grad + ] + + if args.train_data is None: + 
optimizer = None + scheduler = None + else: + total_steps = data["train"].dataloader.num_batches * args.epochs + + if args.split_opt: + for x in ["lr", "beta1", "beta2", "eps", "wd"]: + for y in ["_new", "_pretrained"]: + if getattr(args, x + y) is None: + setattr(args, x + y, getattr(args, x)) + + gain_or_bias_pretrained_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + rest_pretrained_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + gain_or_bias_new_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) + and (not is_pretrained_params(n)) + ] + rest_new_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) + and (not is_pretrained_params(n)) + ] + + pretrained_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0}, + { + "params": rest_pretrained_params, + "weight_decay": args.wd_pretrained, + }, + ], + lr=args.lr_pretrained, + betas=(args.beta1_pretrained, args.beta2_pretrained), + eps=args.eps_pretrained, + momentum=args.momentum_pretrained, + optimizer_name=args.optimizer, + ) + pretrained_params_scheduler = cosine_lr( + pretrained_params_optimizer, + args.lr_pretrained, + args.warmup, + total_steps, + ) + + new_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_new_params, "weight_decay": 0.0}, + {"params": rest_new_params, "weight_decay": args.wd_new}, + ], + lr=args.lr_new, + betas=(args.beta1_new, args.beta2_new), + eps=args.eps_new, + momentum=args.momentum_new, + optimizer_name=args.optimizer, + ) + new_params_scheduler = cosine_lr( + new_params_optimizer, args.lr_new, args.warmup, total_steps + ) + + optimizer["text"] = pretrained_params_optimizer + optimizer["audio"] = new_params_optimizer + scheduler["text"] = pretrained_params_scheduler + scheduler["audio"] = new_params_scheduler + + if args.horovod: + pretrained_params_optimizer = hvd.DistributedOptimizer( + pretrained_params_optimizer, + named_parameters=model.named_parameters(), + ) + new_params_optimizer = hvd.DistributedOptimizer( + new_params_optimizer, named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state( + pretrained_params_optimizer, root_rank=0 + ) + hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0) + else: + + optimizer["clap"] = get_optimizer( + [ + {"params": gain_or_bias_params, "weight_decay": 0.0}, + {"params": rest_params, "weight_decay": args.wd}, + ], + lr=args.lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + momentum=args.momentum, + optimizer_name=args.optimizer, + ) + scheduler["clap"] = cosine_lr( + optimizer["clap"], args.lr, args.warmup, total_steps + ) + + if args.horovod: + optimizer["clap"] = hvd.DistributedOptimizer( + optimizer["clap"], named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer["clap"], root_rank=0) + + # linear probe optimizer + else: + lp_params = [ + p for n, p in named_parameters if (not in_clap(n, p)) and p.requires_grad + ] + lp_optim = get_optimizer( + lp_params, + lr=args.lp_lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + momentum=0.9, + optimizer_name=args.optimizer, + ) + optimizer["lp"] = lp_optim + + return optimizer, scheduler, text_freeze_parameters + + +def main(): + args = parse_args() + + 
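The `cosine_lr` helper that `config_lp_optimizer` above attaches to every optimizer comes from `training.scheduler` and is not included in this diff. As a rough sketch only (warmup-then-cosine is the usual behaviour of such helpers; the real implementation may differ), it returns a per-step closure:

```python
import math

def cosine_lr_sketch(optimizer, base_lr, warmup, total_steps):
    """Hypothetical stand-in for training.scheduler.cosine_lr."""
    def assign(lr):
        for group in optimizer.param_groups:
            group["lr"] = lr

    def schedule(step):
        if step < warmup:
            assign(base_lr * (step + 1) / warmup)                    # linear warmup
        else:
            t = (step - warmup) / max(1, total_steps - warmup)
            assign(0.5 * base_lr * (1.0 + math.cos(math.pi * t)))    # cosine decay

    return schedule
```

`train_one_epoch` in `lp_train.py` then simply calls the returned closure once per batch with the global step (once per entry when `scheduler` is a dict).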
time.sleep(args.sleep) + + # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? + args.amodel = args.amodel.replace("/", "-") + # download sizes.json file + + # (yusong): the below two lines are for debug + # print("setting up faulthandler") + # faulthandler.register(10) + + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + args.class_index_dict = load_class_label(args.class_label_path) + + # get the name of the experiments + if args.name is None: + args.name = "-".join( + [ + datetime.now().strftime("%Y_%m_%d-%H_%M_%S"), + f"linear_probe" f"model_{args.amodel}", + f"lr_{args.lr}", + f"b_{args.batch_size}", + f"j_{args.workers}", + f"p_{args.precision}", + ] + ) + + # discover initial world args early so we can log properly + args.distributed = False + args.local_rank, args.rank, args.world_size = world_info_from_env() + + if args.remotedata and is_master(args): + for dataset_name in args.datasetnames: + for split in dataset_split[dataset_name]: + if not os.path.exists(f"./json_files/{dataset_name}/{split}"): + os.makedirs(f"./json_files/{dataset_name}/{split}") + os.system( + f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" + ) + + args.log_path = None + if is_master(args, local=args.log_local): + log_base_path = os.path.join(args.logs, args.name) + os.makedirs(log_base_path, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path, log_filename) + + # avoid log dir in same name: + postfix = 0 + while os.path.exists(args.log_path): + postfix += 1 + log_base_path_new = log_base_path + "-" + str(postfix) + os.makedirs(log_base_path_new, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path_new, log_filename) + # print( + # "Error. Experiment already exists. Use --name {} to specify a new experiment." + # ) + # return -1 + + # Set logger + args.log_level = logging.DEBUG if args.debug else logging.INFO + setup_logging(args.log_path, args.log_level) + + # fully initialize distributed device environment + device = init_distributed_device(args) + + args.wandb = "wandb" in args.report_to or "all" in args.report_to + args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to + if is_master(args): + args.tensorboard_path = ( + os.path.join(args.logs, args.name, "tensorboard") + if args.tensorboard + else "" + ) + args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") + for dirname in [args.tensorboard_path, args.checkpoint_path]: + if dirname: + os.makedirs(dirname, exist_ok=True) + else: + args.tensorboard_path = "" + args.checkpoint_path = "" + + if args.copy_codebase: + copy_codebase(args) + + assert args.precision in ["amp", "fp16", "fp32"] + if args.precision == "fp16": + logging.warning( + "It is recommended to use AMP mixed-precision instead of FP16. " + "FP16 support needs further verification and tuning, especially for train." + ) + + if args.horovod: + logging.info( + f"Running in horovod mode with multiple processes / nodes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + elif args.distributed: + logging.info( + f"Running in distributed mode with multiple processes. Device: {args.device}." 
+ f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + else: + logging.info(f"Running with a single process. Device {args.device}.") + + logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}") + + # Create CLAP model + clap_model, clap_model_cfg = create_model( + args.amodel, + args.tmodel, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=False, + pretrained_audio=args.pretrained_audio, + pretrained_text=args.pretrained_text, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type, + ) + + args.lp_out_ch = len(list(args.class_index_dict.keys())) + # Linear Probe + logging.info(f"linear probe using mlp: {args.lp_mlp}") + logging.info(f"linear probe using freeze: {args.lp_freeze}") + logging.info(f"linear probe act layer: {args.lp_act}") + logging.info(f"linear probe out ch: {args.lp_out_ch}") + logging.info(f"linear probe learning rate (if applicable): {args.lp_lr}") + logging.info(f"linear probe loss func: {args.lp_loss}") + logging.info(f"linear probe lp_metrics: {args.lp_metrics}") + + model = LinearProbe( + clap_model, + mlp=args.lp_mlp, + freeze=args.lp_freeze, + in_ch=512, + out_ch=args.lp_out_ch, + act=args.lp_act, + ) # in_ch is fixed (i.e., 512) + model = model.to(device) + + if args.horovod: + with torch.no_grad(): + for param in model.parameters(): + param.set_(param.contiguous()) + + if args.trace: + model = trace_model(model, batch_size=args.batch_size, device=device) + + if is_master(args): + logging.info("Linear Probe CLAP Model:") + logging.info(f"{str(clap_model)}") + logging.info("Params:") + params_file = os.path.join(args.logs, args.name, "params.txt") + with open(params_file, "w") as f: + for name in sorted(vars(args)): + val = getattr(args, name) + logging.info(f" {name}: {val}") + f.write(f"{name}: {val}\n") + + if args.distributed and not args.horovod: + if args.use_bn_sync: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + ddp_args = {} + if args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args["static_graph"] = True + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[device], find_unused_parameters=True, **ddp_args + ) + + data = get_data(args, clap_model_cfg) + assert len(data), "At least one train or eval dataset must be specified." 
+ if args.trace: + assert "train" not in data, "Cannot train with traced model" + + optimizer, scheduler, text_freeze_parameters = config_lp_optimizer( + model, data, args + ) + + scaler = GradScaler() if args.precision == "amp" else None + + # optionally resume from a checkpoint + start_epoch = 0 + if args.resume is not None: + if os.path.isfile(args.resume): + checkpoint = torch.load(args.resume, map_location=device) + if "epoch" in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if not args.distributed and next(iter(sd.items()))[0].startswith( + "module" + ): + sd = {k[len("module.") :]: v for k, v in sd.items()} + model.load_state_dict(sd) + if args.split_opt: + if optimizer is not None: + for k, o_ in optimizer.items(): + o_.load_state_dict(checkpoint[k + "_" + "optimizer"]) + if optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + if scaler is not None and "scaler" in checkpoint: + scaler.load_state_dict(checkpoint["scaler"]) + logging.info( + f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})" + ) + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + logging.info( + f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})" + ) + if args.freeze_text: + print("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + else: + logging.info("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + cudnn.deterministic = False + + # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 + args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args) + writer = None + if args.save_logs and args.tensorboard: + assert tensorboard is not None, "Please install tensorboard." + writer = tensorboard.SummaryWriter(args.tensorboard_path) + + if args.wandb and is_master(args): + assert wandb is not None, "Please install wandb." + logging.debug("Starting wandb.") + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + # you will have to configure this for your project! 
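For reference, the resume branch above expects the checkpoint layout written at the end of every epoch further down in this file: `epoch`, `name`, `state_dict`, one `<key>_optimizer` entry per optimizer and, when AMP is active, `scaler`. A minimal sketch of restoring just the model weights from such a file outside of DDP, including the `module.` prefix handling shown above (the path argument and helper name are illustrative):

```python
import torch

def load_resume_checkpoint(model, path, device="cpu"):
    """Sketch of the weight-loading part of the resume logic above."""
    ckpt = torch.load(path, map_location=device)
    state_dict = ckpt.get("state_dict", ckpt)   # bare checkpoints are just a state_dict
    # Checkpoints saved from a DistributedDataParallel model prefix keys with
    # "module."; strip it when loading into a plain (non-DDP) model.
    if next(iter(state_dict)).startswith("module."):
        state_dict = {k[len("module."):]: v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    return ckpt.get("epoch", 0)                 # epoch to resume from (0 for bare checkpoints)
```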
+ wandb.init( + project="clap", + notes=args.wandb_notes, + name=args.wandb_notes, + tags=[], + config=vars(args), + ) + if args.debug: + wandb.watch(model, log="all") + wandb.save(params_file) + logging.debug("Finished loading wandb.") + + if "train" not in data: + evaluate(model, data, start_epoch, args, writer) + return + elif start_epoch == 0 and "val" in data and not args.no_eval: + evaluate(model, data, 0, args, writer) + if args.save_top_performance: + current_top_k_ckpt_metrics = { + i: 0 for i in range(args.save_top_performance) + } # initialize the top-k metric for ckpts to 0 + + for epoch in range(start_epoch, args.epochs): + # freeze the text param after (include) args.freeze_text_after, this is -1 by default + if epoch == args.freeze_text_after: + print("Text pretrained parameters are freezed since this epoch.") + for k in text_freeze_parameters: + k.requires_grad = False + if is_master(args): + logging.info(f"Start epoch {epoch}") + + train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) + completed_epoch = epoch + 1 + + if ( + any(v in data for v in ("val", "imagenet-val", "imagenet-v2")) + and not args.no_eval + ): + metrics = evaluate(model, data, completed_epoch, args, writer) + if args.save_top_performance: + top_k_dataset = args.top_k_checkpoint_select_dataset + top_k_metric = args.top_k_checkpoint_select_metric + filtered_metrics = [ + v + for k, v in metrics.items() + if top_k_metric in k and top_k_dataset in k + ] # check all R@10 metrics (all dataset) and use it to update the ckpt + # Saving checkpoints. + if args.save_logs: + opt_dict = { + k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items() + } + checkpoint_dict = { + "epoch": completed_epoch, + "name": args.name, + "state_dict": model.state_dict(), + } + checkpoint_dict.update(opt_dict) + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), + ) + if args.save_most_recent: + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_latest.pt"), + ) + if args.save_top_performance and not args.no_eval: + update_top_k_performance( + filtered_metrics, + current_top_k_ckpt_metrics, + args, + checkpoint_dict, + bignumbetter=True, + ) + + if args.wandb and is_master(args): + wandb.finish() + + +def copy_codebase(args): + from shutil import copytree, ignore_patterns + + new_code_path = os.path.join(args.logs, args.name, "code") + if os.path.exists(new_code_path): + print( + f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." 
+ ) + return -1 + print(f"Copying codebase to {new_code_path}") + current_code_path = os.path.realpath(__file__) + for _ in range(3): + current_code_path = os.path.dirname(current_code_path) + copytree( + current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb") + ) + print("Done copying code.") + return 1 + + +if __name__ == "__main__": + main() diff --git a/models/CLAP/training/lp_train.py b/models/CLAP/training/lp_train.py new file mode 100755 index 0000000000000000000000000000000000000000..24a19bacd0a4b789415cfccbce1f8bc99bc493ed --- /dev/null +++ b/models/CLAP/training/lp_train.py @@ -0,0 +1,301 @@ +import json +import logging +import math +import os +import time +from contextlib import suppress + +import numpy as np +import torch +import torch.nn.functional as F + +try: + import wandb +except ImportError: + wandb = None + +from open_clip import LPLoss, LPMetrics, lp_gather_features +from open_clip.utils import do_mixup, get_mix_lambda +from .distributed import is_master +from .zero_shot import zero_shot_eval + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def unwrap_model(model): + if hasattr(model, "module"): + return model.module + else: + return model + + +def train_one_epoch( + model, + data, + epoch, + optimizer, + scaler, + scheduler, + args, + tb_writer=None, + extra_suffix="", +): + device = torch.device(args.device) + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + model.train() + loss = LPLoss(args.lp_loss) + + dataloader, sampler = data["train"].dataloader, data["train"].sampler + if args.distributed and sampler is not None: + sampler.set_epoch(epoch) + num_batches_per_epoch = dataloader.num_batches + sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) + + # for toy dataset + if args.dataset_type == "toy": + dataloader.dataset.generate_queue() + + loss_m = AverageMeter() + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for i, batch in enumerate(dataloader): + step = num_batches_per_epoch * epoch + i + + if isinstance(scheduler, dict): + for s in scheduler.values(): + s(step) + else: + scheduler(step) + + audio = batch # contains mel_spec, wavform, and longer list + class_label = batch["class_label"] + # audio = audio.to(device=device, non_blocking=True) + class_label = class_label.to(device=device, non_blocking=True) + + if args.mixup: + # https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/utils.py#L146 + mix_lambda = torch.from_numpy( + get_mix_lambda(0.5, len(audio["waveform"])) + ).to(device) + class_label = do_mixup(class_label, mix_lambda) + else: + mix_lambda = None + + data_time_m.update(time.time() - end) + if isinstance(optimizer, dict): + for o_ in optimizer.values(): + o_.zero_grad() + else: + optimizer.zero_grad() + + with autocast(): + pred = model(audio, mix_lambda=mix_lambda, device=device) + total_loss = loss(pred, class_label) + + if isinstance(optimizer, dict): + if scaler is not None: + scaler.scale(total_loss).backward() + for o_ in optimizer.values(): + if args.horovod: + o_.synchronize() + scaler.unscale_(o_) + with o_.skip_synchronize(): + scaler.step(o_) + else: + scaler.step(o_) + scaler.update() + else: + total_loss.backward() + for o_ in 
optimizer.values(): + o_.step() + else: + if scaler is not None: + scaler.scale(total_loss).backward() + if args.horovod: + optimizer.synchronize() + scaler.unscale_(optimizer) + with optimizer.skip_synchronize(): + scaler.step(optimizer) + else: + scaler.step(optimizer) + scaler.update() + else: + total_loss.backward() + optimizer.step() + + # Note: we clamp to 4.6052 = ln(100), as in the original paper. + with torch.no_grad(): + unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100)) + unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100)) + + batch_time_m.update(time.time() - end) + end = time.time() + batch_count = i + 1 + + if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): + if isinstance(audio, dict): + batch_size = len(audio["waveform"]) + else: + batch_size = len(audio) + num_samples = batch_count * batch_size * args.world_size + samples_per_epoch = dataloader.num_samples + percent_complete = 100.0 * batch_count / num_batches_per_epoch + + # NOTE loss is coarsely sampled, just master node and per log update + loss_m.update(total_loss.item(), batch_size) + if isinstance(optimizer, dict): + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]}" + ) + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], + } + else: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + ) + + # Save train loss / etc. Using non avg meter values as loggers have their own smoothing + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "lr": optimizer.param_groups[0]["lr"], + } + for name, val in log_data.items(): + name = f"train{extra_suffix}/{name}" + if tb_writer is not None: + tb_writer.add_scalar(name, val, step) + if args.wandb: + assert wandb is not None, "Please install wandb." 
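The `get_mix_lambda` / `do_mixup` pair used near the top of `train_one_epoch` comes from `open_clip.utils` and follows the HTS-Audio-Transformer recipe linked in the comment there; neither function appears in this diff. A rough sketch of that style of mixup (one Beta-distributed weight per example, mixing every sample with the batch reversed along the batch dimension; details of the real helpers may differ):

```python
import numpy as np
import torch

def get_mix_lambda_sketch(alpha, batch_size):
    # one Beta(alpha, alpha) mixing weight per example, e.g. alpha = 0.5 as above
    return np.random.beta(alpha, alpha, batch_size).astype(np.float32)

def do_mixup_sketch(x, mix_lambda):
    # mix each example with the batch flipped along dim 0; transposing lets the
    # (batch,) weights broadcast over the trailing feature dimensions
    out = (
        x.transpose(0, -1) * mix_lambda
        + torch.flip(x, dims=[0]).transpose(0, -1) * (1.0 - mix_lambda)
    ).transpose(0, -1)
    return out
```

The same `mix_lambda` tensor is also passed into the model above, so the audio features can be mixed with exactly the weights used for the labels.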
+ wandb.log({name: val, "step": step}) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # end for + + +def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""): + metrics = {} + if not args.parallel_eval: + if not is_master(args): + return metrics + device = torch.device(args.device) + model.eval() + + # CHANGE + # zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + # metrics.update(zero_shot_metrics) + if is_master(args): + print("Evaluating...") + metric_names = args.lp_metrics.split(",") + eval_tool = LPMetrics(metric_names=metric_names) + + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + if "val" in data and ( + args.val_frequency + and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) + ): + if args.parallel_eval: + dataloader, sampler = data["val"].dataloader, data["val"].sampler + if args.distributed and sampler is not None: + sampler.set_epoch(epoch) + samples_per_val = dataloader.num_samples + else: + dataloader = data["val"].dataloader + num_samples = 0 + samples_per_val = dataloader.num_samples + + eval_info = {"pred": [], "target": []} + with torch.no_grad(): + for i, batch in enumerate(dataloader): + audio = batch # contains mel_spec, wavform, and longer list + class_label = batch["class_label"] + + # audio = audio.to(device=device, non_blocking=True) + class_label = class_label.to(device=device, non_blocking=True) + + with autocast(): + pred = model(audio, device=device) + if args.parallel_eval: + pred, class_label = lp_gather_features( + pred, class_label, args.world_size, args.horovod + ) + eval_info["pred"].append(pred) + eval_info["target"].append(class_label) + + num_samples += class_label.shape[0] + + if (i % 100) == 0: # and i != 0: + logging.info( + f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]" + ) + + if is_master(args): + eval_info["pred"] = torch.cat(eval_info["pred"], 0).cpu() + eval_info["target"] = torch.cat(eval_info["target"], 0).cpu() + metric_dict = eval_tool.evaluate_mertics( + eval_info["pred"], eval_info["target"] + ) + metrics.update(metric_dict) + if "epoch" not in metrics.keys(): + metrics.update({"epoch": epoch}) + + if is_master(args): + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\n".join( + ["\t".join([f"{m}: {round(metrics[m], 4):.4f}"]) for m in metrics] + ) + ) + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch) + + with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + assert wandb is not None, "Please install wandb." 
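`LPMetrics` (constructed from `args.lp_metrics` above) lives in `open_clip` and its internals are not part of this diff; `evaluate` only hands it the concatenated `eval_info["pred"]` / `eval_info["target"]` tensors on the CPU. The helper below is merely one illustration of what such a metric computation could look like (accuracy for index targets, mAP for multi-hot targets); it is not the actual `evaluate_mertics` implementation:

```python
import torch

def simple_eval_metrics(pred: torch.Tensor, target: torch.Tensor) -> dict:
    """Toy stand-in for LPMetrics.evaluate_mertics."""
    if target.dim() == 1:                             # integer class indices
        acc = (pred.argmax(dim=-1) == target).float().mean().item()
        return {"acc": acc}
    aps = []                                          # multi-hot targets: per-class AP
    for c in range(pred.shape[1]):
        scores, labels = pred[:, c], target[:, c].float()
        labels = labels[scores.argsort(descending=True)]
        ranks = torch.arange(1, labels.numel() + 1, dtype=torch.float32)
        precision = labels.cumsum(0) / ranks          # precision at every rank
        ap = (precision * labels).sum() / labels.sum().clamp(min=1.0)
        aps.append(ap.item())
    return {"mAP": sum(aps) / len(aps)}
```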
+ for name, val in metrics.items(): + wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch}) + + return metrics + else: + return metrics diff --git a/models/CLAP/training/main.py b/models/CLAP/training/main.py new file mode 100755 index 0000000000000000000000000000000000000000..3b563a5d001be7adfbe779dee7ad8ac49aadc50d --- /dev/null +++ b/models/CLAP/training/main.py @@ -0,0 +1,596 @@ +from inspect import getargs +import logging +import os +import random +from datetime import datetime +import bisect +import copy +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from torch import optim +from torch.cuda.amp import GradScaler +import faulthandler +import pathlib + +try: + import wandb +except ImportError: + wandb = None + +try: + import torch.utils.tensorboard as tensorboard +except ImportError: + tensorboard = None + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from open_clip import create_model_and_transforms, trace_model, create_model +from training.data import get_data +from training.distributed import is_master, init_distributed_device, world_info_from_env +from training.logger import setup_logging +from training.params import parse_args +from training.scheduler import cosine_lr +from training.train import train_one_epoch, evaluate +from open_clip.utils import dataset_split, get_optimizer + + +def maintain_ckpts(args, startidx, all_idx_len): + for i in reversed(range(startidx, all_idx_len)): + if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")): + os.rename( + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"), + ) + if os.path.exists( + os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt") + ): + os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")) + return + + +def update_top_k_performance( + new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True +): + """ + Record the top-k performance of the current epoch. 
+ current_top_k_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...} + """ + if isinstance(new_metrics_inputs, (list, tuple)): + new_metrics_inputs = np.mean(new_metrics_inputs) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, dict): + new_metrics_inputs = np.mean(list(new_metrics_inputs.values())) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, (float, int)): + update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()} + sorted_keys = sorted(current_top_k_ckpt_metrics.keys()) + sorted_values = sorted( + current_top_k_ckpt_metrics.values(), reverse=bignumbetter + ) + sorted_values_ = copy.deepcopy(sorted_values) + sorted_values.append(new_metrics_inputs) + sorted_values = sorted(sorted_values, reverse=bignumbetter) + sorted_values = sorted_values[:-1] + + if sorted_values == sorted_values_: + return current_top_k_ckpt_metrics, new_metrics_inputs + else: + for i in range(len(sorted_keys)): + if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]: + current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i] + update_flag[sorted_keys[i]] = True + for i in range(len(update_flag)): + if update_flag[i]: + maintain_ckpts(args, i, len(sorted_keys)) + torch.save( + ckpt, + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + ) + break + return current_top_k_ckpt_metrics, new_metrics_inputs + + +# def updateifNone(a, b): +# a = b if None else a +# return a + + +def is_pretrained_params(n): + return ( + n.startswith("transformer") + or n in ["positional_embedding", "text_projection"] + or n.startswith("token_embedding") + or n.startswith("ln_final") + or n.startswith("logit_scale_t") + ) + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) + + +def main(): + args = parse_args() + # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? + args.amodel = args.amodel.replace("/", "-") + # download sizes.json file + + # (yusong): the below two lines are for debug + # print("setting up faulthandler") + # faulthandler.register(10) + + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + if args.tmodel == "bert" or args.tmodel == "roberta" or args.tmodel == "bart": + assert ( + args.pretrained == "" or args.pretrained is None + ), "bert/roberta/bart text encoder does not support pretrained models." 
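Taken together, `maintain_ckpts` and `update_top_k_performance` above keep the `args.save_top_performance` best checkpoints on disk as `epoch_top_0.pt` … `epoch_top_{k-1}.pt`: when a better (mean) metric arrives, existing files are shifted one slot down and the new checkpoint is saved at its rank. A small pseudo-usage example of how the training loop below drives this (the metric values are made up; `args` and `checkpoint_dict` are the objects built later in `main()`):

```python
current_top_k = {i: 0 for i in range(3)}   # args.save_top_performance == 3 in this example

# epoch 1: mean of the filtered metrics is 0.31 -> saved as epoch_top_0.pt
current_top_k, _ = update_top_k_performance([0.30, 0.32], current_top_k, args, checkpoint_dict)

# epoch 2: mean 0.35 beats slot 0 -> maintain_ckpts renames the old epoch_top_0.pt
# to epoch_top_1.pt and the new checkpoint is written as epoch_top_0.pt
current_top_k, _ = update_top_k_performance([0.35], current_top_k, args, checkpoint_dict)
```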
+ + # get the name of the experiments + if args.name is None: + args.name = "-".join( + [ + datetime.now().strftime("%Y_%m_%d-%H_%M_%S"), + f"model_{args.amodel}", + f"lr_{args.lr}", + f"b_{args.batch_size}", + f"j_{args.workers}", + f"p_{args.precision}", + ] + ) + + # discover initial world args early so we can log properly + args.distributed = False + args.local_rank, args.rank, args.world_size = world_info_from_env() + + if args.remotedata and is_master(args): + for dataset_name in args.datasetnames: + for split in dataset_split[dataset_name]: + if not os.path.exists(f"./json_files/{dataset_name}/{split}"): + os.makedirs(f"./json_files/{dataset_name}/{split}") + os.system( + f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" + ) + + args.log_path = None + if is_master(args, local=args.log_local): + log_base_path = os.path.join(args.logs, args.name) + os.makedirs(log_base_path, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path, log_filename) + if os.path.exists(args.log_path): + print( + "Error. Experiment already exists. Use --name {} to specify a new experiment." + ) + return -1 + + # Set logger + args.log_level = logging.DEBUG if args.debug else logging.INFO + setup_logging(args.log_path, args.log_level) + + # fully initialize distributed device environment + device = init_distributed_device(args) + + args.wandb = "wandb" in args.report_to or "all" in args.report_to + args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to + if is_master(args): + args.tensorboard_path = ( + os.path.join(args.logs, args.name, "tensorboard") + if args.tensorboard + else "" + ) + args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") + for dirname in [args.tensorboard_path, args.checkpoint_path]: + if dirname: + os.makedirs(dirname, exist_ok=True) + else: + args.tensorboard_path = "" + args.checkpoint_path = "" + + if args.copy_codebase: + copy_codebase(args) + + assert args.precision in ["amp", "fp16", "fp32"] + if args.precision == "fp16": + logging.warning( + "It is recommended to use AMP mixed-precision instead of FP16. " + "FP16 support needs further verification and tuning, especially for train." + ) + + if args.horovod: + logging.info( + f"Running in horovod mode with multiple processes / nodes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + elif args.distributed: + logging.info( + f"Running in distributed mode with multiple processes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + else: + logging.info(f"Running with a single process. 
Device {args.device}.") + + logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}") + + model, model_cfg = create_model( + args.amodel, + args.tmodel, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=True, + pretrained_audio=args.pretrained_audio, + pretrained_text=args.pretrained_text, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type, + ) + + if args.horovod: + with torch.no_grad(): + for param in model.parameters(): + param.set_(param.contiguous()) + + if args.trace: + model = trace_model(model, batch_size=args.batch_size, device=device) + + if is_master(args): + logging.info("Model:") + logging.info(f"{str(model)}") + logging.info("Params:") + params_file = os.path.join(args.logs, args.name, "params.txt") + with open(params_file, "w") as f: + for name in sorted(vars(args)): + val = getattr(args, name) + logging.info(f" {name}: {val}") + f.write(f"{name}: {val}\n") + + if args.distributed and not args.horovod: + if args.use_bn_sync: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + ddp_args = {} + if args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args["static_graph"] = True + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[device], find_unused_parameters=True, **ddp_args + ) + + data = get_data(args, model_cfg) + assert len(data), "At least one train or eval dataset must be specified." + if args.trace: + assert "train" not in data, "Cannot train with traced model" + + exclude = ( + lambda n, p: p.ndim < 2 + or "bn" in n + or "ln" in n + or "bias" in n + or "logit_scale" in n + ) + include = lambda n, p: not exclude(n, p) + + named_parameters = list(model.named_parameters()) + + # freeze text encoder + text_freeze_parameters = [p for n, p in named_parameters if "text_branch" in n] + + if args.freeze_text: + print("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + + gain_or_bias_params = [ + p for n, p in named_parameters if exclude(n, p) and p.requires_grad + ] + rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] + + # set wd-related params to 0 if use adam optimizer + if args.optimizer == "adam": + args.wd = 0 + args.wd_pretrained = 0 + args.wd_new = 0 + + if args.train_data is None: + optimizer = None + scheduler = None + else: + total_steps = data["train"].dataloader.num_batches * args.epochs + + if args.split_opt: + for x in ["lr", "beta1", "beta2", "eps", "wd"]: + for y in ["_new", "_pretrained"]: + if getattr(args, x + y) is None: + setattr(args, x + y, getattr(args, x)) + + gain_or_bias_pretrained_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + rest_pretrained_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + gain_or_bias_new_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) and (not is_pretrained_params(n)) + ] + rest_new_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) and (not is_pretrained_params(n)) + ] + pretrained_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0}, + { + "params": rest_pretrained_params, + "weight_decay": args.wd_pretrained, + }, 
+ ], + lr=args.lr_pretrained, + betas=(args.beta1_pretrained, args.beta2_pretrained), + eps=args.eps_pretrained, + momentum=args.momentum_pretrained, + optimizer_name=args.optimizer, + ) + pretrained_params_scheduler = cosine_lr( + pretrained_params_optimizer, + args.lr_pretrained, + args.warmup, + total_steps, + ) + new_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_new_params, "weight_decay": 0.0}, + {"params": rest_new_params, "weight_decay": args.wd_new}, + ], + lr=args.lr_new, + betas=(args.beta1_new, args.beta2_new), + eps=args.eps_new, + momentum=args.momentum_new, + optimizer_name=args.optimizer, + ) + + new_params_scheduler = cosine_lr( + new_params_optimizer, args.lr_new, args.warmup, total_steps + ) + + optimizer = { + "pretrained": pretrained_params_optimizer, + "new": new_params_optimizer, + } + scheduler = { + "pretrained": pretrained_params_scheduler, + "new": new_params_scheduler, + } + + if args.horovod: + pretrained_params_optimizer = hvd.DistributedOptimizer( + pretrained_params_optimizer, + named_parameters=model.named_parameters(), + ) + new_params_optimizer = hvd.DistributedOptimizer( + new_params_optimizer, named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(pretrained_params_optimizer, root_rank=0) + hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0) + else: + optimizer = get_optimizer( + [ + {"params": gain_or_bias_params, "weight_decay": 0.0}, + {"params": rest_params, "weight_decay": args.wd}, + ], + lr=args.lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + momentum=args.momentum, + optimizer_name=args.optimizer, + ) + + scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps) + + if args.horovod: + optimizer = hvd.DistributedOptimizer( + optimizer, named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer, root_rank=0) + + scaler = GradScaler() if args.precision == "amp" else None + + # optionally resume from a checkpoint + start_epoch = 0 + if args.resume is not None: + if os.path.isfile(args.resume): + checkpoint = torch.load(args.resume, map_location=device) + if "epoch" in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if not args.distributed and next(iter(sd.items()))[0].startswith( + "module" + ): + sd = {k[len("module.") :]: v for k, v in sd.items()} + model.load_state_dict(sd) + if args.split_opt: + if optimizer is not None: + for k, o_ in optimizer.items(): + o_.load_state_dict(checkpoint[k + "_" + "optimizer"]) + if optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + if scaler is not None and "scaler" in checkpoint: + scaler.load_state_dict(checkpoint["scaler"]) + logging.info( + f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})" + ) + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + logging.info( + f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})" + ) + if args.freeze_text: + print("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + else: + logging.info("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + cudnn.deterministic = False + + # determine if this worker should save logs and checkpoints. 
only do so if it is rank == 0 + args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args) + writer = None + if args.save_logs and args.tensorboard: + assert tensorboard is not None, "Please install tensorboard." + writer = tensorboard.SummaryWriter(args.tensorboard_path) + + if args.wandb and is_master(args): + assert wandb is not None, "Please install wandb." + logging.debug("Starting wandb.") + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + # you will have to configure this for your project! + wandb.init( + project="clap", + notes=args.wandb_notes, + name=args.wandb_notes, + tags=[], + config=vars(args), + ) + if args.debug: + wandb.watch(model, log="all") + wandb.save(params_file) + logging.debug("Finished loading wandb.") + + if "train" not in data: + evaluate(model, data, start_epoch, args, writer) + return + elif start_epoch == 0 and "val" in data and not args.no_eval: + evaluate(model, data, 0, args, writer) + # print(f'rank {args.rank}, Start First Evaluation')# (yusong): for debug + if args.save_top_performance: + current_top_k_ckpt_metrics = { + i: 0 for i in range(args.save_top_performance) + } # initialize the top-k metric for ckpts to 0 + + # print(f'rank {args.rank}, Start Training') # (yusong): for debug + for epoch in range(start_epoch, args.epochs): + # freeze the text param after (include) args.freeze_text_after, this is -1 by default + if epoch == args.freeze_text_after: + print("Text pretrained parameters are freezed since this epoch.") + for k in text_freeze_parameters: + k.requires_grad = False + if is_master(args): + logging.info(f"Start epoch {epoch}") + + train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) + completed_epoch = epoch + 1 + + if ( + any(v in data for v in ("val", "imagenet-val", "imagenet-v2")) + and not args.no_eval + ): + metrics = evaluate(model, data, completed_epoch, args, writer) + if args.save_top_performance: + top_k_dataset = args.top_k_checkpoint_select_dataset + top_k_metric = args.top_k_checkpoint_select_metric + filtered_metrics = [ + v + for k, v in metrics.items() + if top_k_metric in k and top_k_dataset in k + ] # check all R@10 metrics (all dataset) and use it to update the ckpt + # Saving checkpoints. + if args.save_logs: + if args.split_opt: + opt_dict = { + k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items() + } + else: + opt_dict = {"optimizer": optimizer.state_dict()} + checkpoint_dict = { + "epoch": completed_epoch, + "name": args.name, + "state_dict": model.state_dict(), + } + checkpoint_dict.update(opt_dict) + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), + ) + if args.save_most_recent: + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_latest.pt"), + ) + if args.save_top_performance and not args.no_eval: + update_top_k_performance( + filtered_metrics, + current_top_k_ckpt_metrics, + args, + checkpoint_dict, + bignumbetter=True, + ) + + if args.wandb and is_master(args): + wandb.finish() + + +def copy_codebase(args): + from shutil import copytree, ignore_patterns + + new_code_path = os.path.join(args.logs, args.name, "code") + if os.path.exists(new_code_path): + print( + f"Error. 
Experiment already exists at {new_code_path}. Use --name to specify a new experiment." + ) + return -1 + print(f"Copying codebase to {new_code_path}") + current_code_path = os.path.realpath(__file__) + for _ in range(3): + current_code_path = os.path.dirname(current_code_path) + copytree( + current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb") + ) + print("Done copying code.") + return 1 + + +if __name__ == "__main__": + main() diff --git a/models/CLAP/training/params.py b/models/CLAP/training/params.py new file mode 100755 index 0000000000000000000000000000000000000000..0cc1a0e2d982e900988cf5a4b24b2e59b093537b --- /dev/null +++ b/models/CLAP/training/params.py @@ -0,0 +1,563 @@ +import argparse + + +def get_default_params(model_name): + # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) + model_name = model_name.lower() + if "vit" in model_name: + return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} + else: + return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--train-data", + type=str, + default=None, + help="Path to h5 filewith training data", + ) + parser.add_argument( + "--val-data", + type=str, + default=None, + help="Path to h5 file with validation data", + ) + parser.add_argument( + "--freeze-text", + default=False, + action="store_true", + help="if you need to freeze the text encoder, make this True", + ) + parser.add_argument( + "--freeze-text-after", + type=int, + default=-1, + help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it", + ) + parser.add_argument( + "--train-ipc", + type=str, + default=None, + help="Path to npy file of the number of instance per class in training data", + ) + parser.add_argument( + "--val-ipc", + type=str, + default=None, + help="Path to npy file of the number of instance per class in validation data", + ) + parser.add_argument( + "--train-num-samples", + type=int, + default=None, + help="Number of samples in dataset. Required for webdataset if not available in info file.", + ) + parser.add_argument( + "--val-num-samples", + type=int, + default=None, + help="Number of samples in dataset. Useful for webdataset if not available in info file.", + ) + parser.add_argument( + "--dataset-type", + choices=["webdataset", "csv", "auto", "toy"], + default="auto", + help="Which type of dataset to process.", + ) + parser.add_argument( + "--csv-separator", + type=str, + default="\t", + help="For csv-like datasets, which separator to use.", + ) + parser.add_argument( + "--csv-img-key", + type=str, + default="filepath", + help="For csv-like datasets, the name of the key for the image paths.", + ) + parser.add_argument( + "--csv-caption-key", + type=str, + default="title", + help="For csv-like datasets, the name of the key for the captions.", + ) + parser.add_argument( + "--imagenet-val", + type=str, + default=None, + help="Path to imagenet val set for conducting zero shot evaluation.", + ) + parser.add_argument( + "--imagenet-v2", + type=str, + default=None, + help="Path to imagenet v2 for conducting zero shot evaluation.", + ) + parser.add_argument( + "--datasetnames", + nargs="+", + default=None, + help="If loading webdataset, spedify the dataset names to load. 
Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects",
+    )
+    parser.add_argument(
+        "--full-train-dataset",
+        nargs="+",
+        default=None,
+        help="Which datasets will be trained with all of their subsets (train+test).",
+    )
+    parser.add_argument(
+        "--exclude-eval-dataset",
+        nargs="+",
+        default=None,
+        help="Which datasets will be excluded from evaluation.",
+    )
+    parser.add_argument(
+        "--datasetinfos",
+        nargs="+",
+        default=None,
+        help="If loading webdataset, specify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval",
+    )
+    parser.add_argument(
+        "--dataset-proportion",
+        type=float,
+        default=1.0,
+        help="The proportion of the dataset to train on.",
+    )
+    parser.add_argument(
+        "--remotedata",
+        default=False,
+        action="store_true",
+        help="Set this flag if the dataset is remote.",
+    )
+    parser.add_argument(
+        "--class-label-path",
+        type=str,
+        default=None,
+        help="The path of the class label pickle or csv.",
+    )
+    parser.add_argument(
+        "--datasetpath",
+        type=str,
+        default="/mnt/audio_clip/webdataset_tar",
+        help="The path to the dataset.",
+    )
+    parser.add_argument(
+        "--logs",
+        type=str,
+        default="./logs/",
+        help="Where to store tensorboard logs. Use None to avoid storing logs.",
+    )
+    parser.add_argument(
+        "--log-local",
+        action="store_true",
+        default=False,
+        help="Log files on local master, otherwise global master only.",
+    )
+    parser.add_argument(
+        "--name",
+        type=str,
+        default=None,
+        help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
+    )
+    parser.add_argument(
+        "--workers", type=int, default=1, help="Number of workers per GPU."
+    )
+    parser.add_argument(
+        "--batch-size", type=int, default=64, help="Batch size per GPU."
+    )
+    parser.add_argument(
+        "--epochs", type=int, default=32, help="Number of epochs to train for."
+    )
+    parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
+    parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
+    parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
+    parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
+    parser.add_argument("--momentum", type=float, default=None, help="SGD momentum.")
+    parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
+
+    parser.add_argument(
+        "--split-opt",
+        action="store_true",
+        default=False,
+        help="Use separate optimizers and schedules for the pretrained and new parameters.",
+    )
+    parser.add_argument(
+        "--lr-pretrained", type=float, default=None, help="Learning rate for text."
+    )
+    parser.add_argument(
+        "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text."
+    )
+    parser.add_argument(
+        "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text."
+    )
+    parser.add_argument(
+        "--eps-pretrained", type=float, default=None, help="Adam epsilon for text."
+    )
+    parser.add_argument(
+        "--wd-pretrained", type=float, default=0.2, help="Weight decay for text."
+    )
+    parser.add_argument(
+        "--momentum-pretrained", type=float, default=0.9, help="Momentum for text."
+    )
+    parser.add_argument(
+        "--lr-new", type=float, default=None, help="Learning rate for audio."
+    )
+    parser.add_argument(
+        "--beta1-new", type=float, default=None, help="Adam beta 1 for audio."
+    )
+    parser.add_argument(
+        "--beta2-new", type=float, default=None, help="Adam beta 2 for audio."
+    )
+    parser.add_argument(
+        "--eps-new", type=float, default=None, help="Adam epsilon for audio."
+ ) + parser.add_argument( + "--wd-new", type=float, default=0.2, help="Weight decay for audio." + ) + parser.add_argument( + "--momentum-new", type=float, default=0.9, help="Momentum for audio." + ) + parser.add_argument( + "--warmup", type=int, default=10000, help="Number of steps to warmup for." + ) + parser.add_argument( + "--use-bn-sync", + default=False, + action="store_true", + help="Whether to use batch norm sync.", + ) + parser.add_argument( + "--skip-scheduler", + action="store_true", + default=False, + help="Use this flag to skip the learning rate decay.", + ) + parser.add_argument( + "--save-frequency", type=int, default=1, help="How often to save checkpoints." + ) + parser.add_argument( + "--save-top-performance", + type=int, + default=0, + help="Save the top x performance weights if the value >0", + ) + parser.add_argument( + "--save-most-recent", + action="store_true", + default=False, + help="Always save the most recent model trained to epoch_latest.pt.", + ) + parser.add_argument( + "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot." + ) + parser.add_argument( + "--val-frequency", + type=int, + default=1, + help="How often to run evaluation with val data.", + ) + parser.add_argument( + "--resume", + default=None, + type=str, + help="path to latest checkpoint (default: none)", + ) + parser.add_argument( + "--precision", + choices=["amp", "fp16", "fp32"], + default="amp", + help="Floating point precision.", + ) + parser.add_argument( + "--amodel", + type=str, + default="RN50", + help="Name of the audio backbone to use.", + ) + parser.add_argument( + "--tmodel", + type=str, + default="transformer", + help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]", + ) + parser.add_argument( + "--pretrained-audio", + default="", + type=str, + help="Use a pretrained audio model weights for the audio encoder of CLAP", + ) + parser.add_argument( + "--pretrained-text", + default="", + type=str, + help="Use a pretrained text model weights for the text encoder of CLAP", + ) + parser.add_argument( + "--pretrained", + default="", + type=str, + help="Use a pretrained CLIP model weights with the specified tag or file path.", + ) + parser.add_argument( + "--pretrained-image", + default=False, + action="store_true", + help="Load imagenet pretrained weights for image tower backbone if available.", + ) + parser.add_argument( + "--lock-image", + default=False, + action="store_true", + help="Lock full image tower by disabling gradients.", + ) + parser.add_argument( + "--lock-image-unlocked-groups", + type=int, + default=0, + help="Leave last n image tower layer groups unlocked.", + ) + parser.add_argument( + "--lock-image-freeze-bn-stats", + default=False, + action="store_true", + help="Freeze BatchNorm running stats in image tower for any locked layers.", + ) + parser.add_argument( + "--local-loss", + default=False, + action="store_true", + help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)", + ) + parser.add_argument( + "--gather-with-grad", + default=False, + action="store_true", + help="enable full distributed gradient for feature gather", + ) + parser.add_argument( + "--force-quick-gelu", + default=False, + action="store_true", + help="Force use of QuickGELU activation for non-OpenAI transformer models.", + ) + parser.add_argument( + "--torchscript", + default=False, + action="store_true", + help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'", + ) + 
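# A minimal sketch of how the per-branch hyperparameters defined above are
# resolved when --split-opt is set: any "_pretrained" / "_new" value left at
# None falls back to the shared value, mirroring the back-fill loop in the
# training entry point. The helper name is hypothetical; args is assumed to
# come from parse_args().
def fill_split_opt_defaults(args):
    for base in ["lr", "beta1", "beta2", "eps", "wd"]:
        for suffix in ["_new", "_pretrained"]:
            if getattr(args, base + suffix) is None:
                setattr(args, base + suffix, getattr(args, base))
    return args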
parser.add_argument( + "--trace", + default=False, + action="store_true", + help="torch.jit.trace the model for inference / eval only", + ) + # arguments for distributed training + parser.add_argument( + "--dist-url", + default="env://", + type=str, + help="url used to set up distributed training", + ) + parser.add_argument( + "--dist-backend", default="nccl", type=str, help="distributed backend" + ) + parser.add_argument( + "--report-to", + default="", + type=str, + help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']", + ) + parser.add_argument( + "--wandb-notes", default="", type=str, help="Notes if logging with wandb" + ) + parser.add_argument( + "--C", type=float, default=3.16, help="inverse regularizer for logistic reg." + ) + parser.add_argument( + "--debug", + default=False, + action="store_true", + help="If true, more information is logged.", + ) + parser.add_argument( + "--copy-codebase", + default=False, + action="store_true", + help="If true, we copy the entire base on the log diretory, and execute from there.", + ) + parser.add_argument( + "--horovod", + default=False, + action="store_true", + help="Use horovod for distributed training.", + ) + parser.add_argument( + "--ddp-static-graph", + default=False, + action="store_true", + help="Enable static graph optimization for DDP in PyTorch >= 1.11.", + ) + parser.add_argument( + "--no-set-device-rank", + default=False, + action="store_true", + help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", + ) + parser.add_argument("--seed", type=int, default=4242, help="Default random seed.") + + parser.add_argument( + "--top-k-checkpoint-select-dataset", + type=str, + default="all", + help="The dataset of selecting top-k checkpoint.", + ) + + # @R10, @R@5, @R1, mAP@10 + parser.add_argument( + "--top-k-checkpoint-select-metric", + type=str, + default="_R@10", + help="The metric for selecting top-k checkpoint.", + ) + parser.add_argument( + "--openai-model-cache-dir", + type=str, + default="~/.cache/clip", + help="Directory to download OpenAI models.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="adamw", + help="can be AdamW or SGD", + ) + parser.add_argument( + "--parallel-eval", + default=False, + action="store_true", + help="Eval in parallel (multi-GPU, multi-node).", + ) + + parser.add_argument( + "--no-eval", + default=False, + action="store_true", + help="Training without evaluation.", + ) + + parser.add_argument( + "--lp-mlp", + default=False, + action="store_true", + help="Linear Probe using MLP layer or not.", + ) + + parser.add_argument( + "--lp-freeze", + default=False, + action="store_true", + help="Linear Probe using Freeze CLAP or not", + ) + + parser.add_argument( + "--lp-act", + default="None", + type=str, + help="Options are ['relu','elu','prelu','softmax','sigmoid']", + ) + + parser.add_argument( + "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe." + ) + + parser.add_argument( + "--lp-metrics", + type=str, + default="map,mauc,acc", + help="Metrics of Linear Probe.", + ) + + parser.add_argument( + "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe" + ) + parser.add_argument( + "--kappa", + type=float, + default=0, + help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss", + ) + + parser.add_argument( + "--data-filling", + type=str, + default="pad", + help="type of data filling when the audio length is shorter than the max length." 
+ "Can be one of the following: repeat, repeatpad, pad", + ) + parser.add_argument( + "--data-truncating", + type=str, + default="rand_trunc", + help="type of data truncation when the audio length is longer than the max length." + "Can be one of the following: rand_trunc, fusion", + ) + + parser.add_argument( + "--clap-mlploss", + default=False, + action="store_true", + help="Using MLP loss for CLAP model or not", + ) + + parser.add_argument( + "--wandb-id", + type=str, + default=None, + help="the id of wandb experiment to restore.", + ) + + parser.add_argument( + "--sleep", type=float, default=0, help="sleep n seconds before start training" + ) + + # variable length processing + parser.add_argument( + "--enable-fusion", + default=False, + action="store_true", + help="Enable feature funsion for variable-length data", + ) + + parser.add_argument( + "--fusion-type", + type=str, + default="None", + help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']", + ) + + parser.add_argument( + "--mixup", + default=False, + action="store_true", + help="Enable mixup in finetuning training.", + ) + parser.add_argument( + "--text-augment-selection", + type=str, + default=None, + help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']", + ) + + args = parser.parse_args() + + # If some params are not passed, we use the default values based on model name. + default_params = get_default_params(args.amodel) + for name, val in default_params.items(): + if getattr(args, name) is None: + setattr(args, name, val) + + return args diff --git a/models/CLAP/training/scheduler.py b/models/CLAP/training/scheduler.py new file mode 100755 index 0000000000000000000000000000000000000000..7151ffbab25a113673b7627027b443b27f22cb0f --- /dev/null +++ b/models/CLAP/training/scheduler.py @@ -0,0 +1,24 @@ +import numpy as np + + +def assign_learning_rate(optimizer, new_lr): + for param_group in optimizer.param_groups: + param_group["lr"] = new_lr + + +def _warmup_lr(base_lr, warmup_length, step): + return base_lr * (step + 1) / warmup_length + + +def cosine_lr(optimizer, base_lr, warmup_length, steps): + def _lr_adjuster(step): + if step < warmup_length: + lr = _warmup_lr(base_lr, warmup_length, step) + else: + e = step - warmup_length + es = steps - warmup_length + lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr + assign_learning_rate(optimizer, lr) + return lr + + return _lr_adjuster diff --git a/models/CLAP/training/train.py b/models/CLAP/training/train.py new file mode 100755 index 0000000000000000000000000000000000000000..f5759c4679d2ee9c0748444adf66b8453cf09728 --- /dev/null +++ b/models/CLAP/training/train.py @@ -0,0 +1,838 @@ +import json +import logging +import math +import os +import time +from contextlib import suppress + +import numpy as np +import torch +import torch.nn.functional as F + +try: + import wandb +except ImportError: + wandb = None + +from open_clip import ClipLoss, gather_features +from .distributed import is_master +from .zero_shot import zero_shot_eval + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def unwrap_model(model): + if hasattr(model, "module"): + return model.module + else: + return model + + +def train_one_epoch( + model, data, 
epoch, optimizer, scaler, scheduler, args, tb_writer=None +): + device = torch.device(args.device) + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + model.train() + loss = ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + mlp_loss=args.clap_mlploss, + weight_loss_kappa=args.kappa, + ) + + dataloader, sampler = data["train"].dataloader, data["train"].sampler + if args.distributed and sampler is not None: + sampler.set_epoch(epoch) + num_batches_per_epoch = dataloader.num_batches + sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) + + # for toy dataset + if args.dataset_type == "toy": + dataloader.dataset.generate_queue() + + loss_m = AverageMeter() + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for i, batch in enumerate(dataloader): + # logging.info(f"batch {i} of {num_batches_per_epoch}") + step = num_batches_per_epoch * epoch + i + if isinstance(scheduler, dict): + for s in scheduler.values(): + s(step) + else: + scheduler(step) + audios = batch # contains mel_spec, wavform, and longer list + texts = batch["text"] + # audios = audios.to(device=device, non_blocking=True) + # texts = texts.to(device=device, non_blocking=True) + + data_time_m.update(time.time() - end) + if isinstance(optimizer, dict): + for o_ in optimizer.values(): + o_.zero_grad() + else: + optimizer.zero_grad() + + with autocast(): + ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + logit_scale_a, + logit_scale_t, + ) = model(audios, texts, device) + + if args.clap_mlploss: + total_loss = loss( + audio_features=audio_features, + text_features=text_features, + logit_scale_a=logit_scale_a, + logit_scale_t=logit_scale_t, + audio_features_mlp=audio_features_mlp, + text_features_mlp=text_features_mlp, + ) + else: + total_loss = loss( + audio_features=audio_features, + text_features=text_features, + logit_scale_a=logit_scale_a, + ) + if isinstance(optimizer, dict): + if scaler is not None: + scaler.scale(total_loss).backward() + for o_ in optimizer.values(): + if args.horovod: + o_.synchronize() + scaler.unscale_(o_) + with o_.skip_synchronize(): + scaler.step(o_) + else: + scaler.step(o_) + scaler.update() + else: + total_loss.backward() + for o_ in optimizer.values(): + o_.step() + else: + if scaler is not None: + scaler.scale(total_loss).backward() + if args.horovod: + optimizer.synchronize() + scaler.unscale_(optimizer) + with optimizer.skip_synchronize(): + scaler.step(optimizer) + else: + scaler.step(optimizer) + scaler.update() + else: + total_loss.backward() + optimizer.step() + + # Note: we clamp to 4.6052 = ln(100), as in the original paper. 
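# Why ln(100): assuming, as in the CLIP objective, the logit scales are kept in
# log space and exponentiated before scaling the similarity logits, clamping the
# raw parameter to [0, ln(100)] bounds the effective temperature scale to [1, 100].
# Self-contained illustration with a made-up starting value:
import math
import torch

raw_scale = torch.tensor(5.3)                  # hypothetical unclamped parameter
clamped = raw_scale.clamp(0.0, math.log(100))  # ~4.6052
assert float(clamped.exp()) <= 100.0 + 1e-4    # effective scale capped at 100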
+ with torch.no_grad(): + unwrap_model(model).logit_scale_a.clamp_(0, math.log(100)) + if args.clap_mlploss: + unwrap_model(model).logit_scale_t.clamp_(0, math.log(100)) + + batch_time_m.update(time.time() - end) + end = time.time() + batch_count = i + 1 + if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): + if isinstance(audios, dict): + batch_size = len(audios["waveform"]) + else: + batch_size = len(audios) + num_samples = batch_count * batch_size * args.world_size + samples_per_epoch = dataloader.num_samples + percent_complete = 100.0 * batch_count / num_batches_per_epoch + + # NOTE loss is coarsely sampled, just master node and per log update + loss_m.update(total_loss.item(), batch_size) + logit_scale_scalar_a = logit_scale_a.item() + logit_scale_scalar_t = logit_scale_t.item() + if isinstance(optimizer, dict): + if args.clap_mlploss: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + f"Logit Scale Text: {logit_scale_scalar_t:.3f}" + ) + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "scale_text": logit_scale_scalar_t, + "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], + } + else: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + ) + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], + } + + else: + if args.clap_mlploss: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + f"Logit Scale Text: {logit_scale_scalar_t:.3f}" + ) + + # Save train loss / etc. Using non avg meter values as loggers have their own smoothing + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "scale_text": logit_scale_scalar_t, + "lr": optimizer.param_groups[0]["lr"], + } + else: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + ) + + # Save train loss / etc. 
Using non avg meter values as loggers have their own smoothing + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "lr": optimizer.param_groups[0]["lr"], + } + for name, val in log_data.items(): + name = "train/" + name + if tb_writer is not None: + tb_writer.add_scalar(name, val, step) + if args.wandb: + assert wandb is not None, "Please install wandb." + wandb.log({name: val, "step": step}) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # end for + + +def evaluate(model, data, epoch, args, tb_writer=None): + metrics = {} + if not args.parallel_eval: + if not is_master(args): + return metrics + device = torch.device(args.device) + model.eval() + + # CHANGE + # zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + # metrics.update(zero_shot_metrics) + if is_master(args): + print("Evaluating...") + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + if args.val_dataset_names == ["Clotho", "audiocaps"]: + # if only clotho and audiocaps are used, then we will use a different evaluation function. + # This is because in the Clotho and audiocaps valid and test set, there are 5 text for 1 audio. + if args.parallel_eval: + # (yusong): just a hack here. Don't use parallel eval when evaluating only clotho and audiocaps. + raise NotImplementedError( + "Parallel evaluation not supported for eval only Clotho and audiocaps." + ) + val_metrics_per_dataset = evaluate_clotho_audiocaps( + model, data, epoch, args, autocast, device, tb_writer + ) + for m in val_metrics_per_dataset.values(): + metrics.update(m) + if "epoch" not in metrics.keys(): + metrics.update({"epoch": epoch}) + metrics = select_top_metric_clotho_audiocaps( + metrics, val_metrics_per_dataset, args + ) + elif "val" in data and ( + args.val_frequency + and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) + ): + dataloader = data["val"].dataloader + num_samples = 0 + samples_per_val = dataloader.num_samples + + # FIXME this does not scale past small eval datasets + # all_audio_features @ all_text_features will blow up memory and compute very quickly + eval_info = {} + if args.clap_mlploss: + eval_info["all"] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + "all_audio_features_mlp": [], + "all_text_features_mlp": [], + } # cumulative_loss = 0.0 + else: + eval_info["all"] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } # cumu + # all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp = [], [], [], [] + with torch.no_grad(): + for i, batch in enumerate(dataloader): + audios = batch # contains mel_spec, wavform, and longer list + texts = batch["text"] + # audios = audios.to(device=device, non_blocking=True) + + all_names = list( + set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]) + ) + for name in all_names: + if name not in eval_info.keys(): + if args.clap_mlploss: + eval_info[name] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + "all_audio_features_mlp": [], + "all_text_features_mlp": [], + } + else: + eval_info[name] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } + with autocast(): + ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + 
logit_scale_a, + logit_scale_t, + ) = model(audios, texts, device) + + if args.parallel_eval: + # multi-GPU eval + if args.clap_mlploss: + ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + ) = gather_features( + audio_features=audio_features, + text_features=text_features, + audio_features_mlp=audio_features_mlp, + text_features_mlp=text_features_mlp, + local_loss=False, + gather_with_grad=False, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + mlp_loss=args.clap_mlploss, + ) + else: + (audio_features, text_features,) = gather_features( + audio_features=audio_features, + text_features=text_features, + local_loss=False, + gather_with_grad=False, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + mlp_loss=args.clap_mlploss, + ) + + if is_master(args): + num_samples += audio_features.shape[0] + for n in [*all_names, "all"]: + if n == "all": + eval_info[n]["all_audio_features"].append( + audio_features.cpu() + ) + eval_info[n]["all_text_features"].append( + text_features.cpu() + ) + if args.clap_mlploss: + eval_info[n]["all_audio_features_mlp"].append( + audio_features_mlp.cpu() + ) + eval_info[n]["all_text_features_mlp"].append( + text_features_mlp.cpu() + ) + else: + idx = np.where( + np.array( + [ + "-".join(b.split("/")[-3:-1]) + for b in batch["__url__"] + ] + ) + == n + )[0] + eval_info[n]["all_audio_features"].append( + audio_features.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + eval_info[n]["all_text_features"].append( + text_features.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + if args.clap_mlploss: + eval_info[n]["all_audio_features_mlp"].append( + audio_features_mlp.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + eval_info[n]["all_text_features_mlp"].append( + text_features_mlp.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + # print(f'eval step {i}') # (yusong): for debug + + # cumulative_loss += total_loss * batch_size + # num_samples += batch_size + if is_master(args) and (i % 100) == 0: # and i != 0: + logging.info( + f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]" + ) + if is_master(args): + val_metrics_per_dataset = {} + for n in eval_info.keys(): + if args.clap_mlploss: + metrics_single_dataset = get_metrics( + audio_features=torch.cat( + eval_info[n]["all_audio_features"] + ), + text_features=torch.cat(eval_info[n]["all_text_features"]), + logit_scale_a=logit_scale_a.cpu(), + audio_features_mlp=torch.cat( + eval_info[n]["all_audio_features_mlp"] + ), + text_features_mlp=torch.cat( + eval_info[n]["all_text_features_mlp"] + ), + logit_scale_t=logit_scale_t.cpu(), + mlp_loss=args.clap_mlploss, + ) + else: + metrics_single_dataset = get_metrics( + audio_features=torch.cat( + eval_info[n]["all_audio_features"] + ), + text_features=torch.cat(eval_info[n]["all_text_features"]), + logit_scale_a=logit_scale_a.cpu(), + mlp_loss=args.clap_mlploss, + ) + val_metrics_per_dataset[n] = { + n + "/" + k: v for k, v in metrics_single_dataset.items() + } + metrics.update(val_metrics_per_dataset[n]) + if "epoch" not in metrics.keys(): + metrics.update({"epoch": epoch}) + if is_master(args): + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\n".join( + [ + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in m.items()]) + for m in val_metrics_per_dataset.values() + ] + ) + ) + + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val/{name}", val, epoch) + + 
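# How the per-dataset keys used in eval_info above are derived: each webdataset
# shard URL is reduced to "<dataset>-<split>" by joining the two path components
# before the shard file name. A small sketch with hypothetical shard paths:
urls = [
    "/data/webdataset_tar/Clotho/test/0.tar",
    "/data/webdataset_tar/audiocaps/test/3.tar",
]
names = ["-".join(u.split("/")[-3:-1]) for u in urls]
assert names == ["Clotho-test", "audiocaps-test"]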
with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + assert wandb is not None, "Please install wandb." + for name, val in metrics.items(): + wandb.log({f"val/{name}": val, "epoch": epoch}) + + return metrics + else: + return metrics + + +def get_metrics( + audio_features, + text_features, + logit_scale_a, + audio_features_mlp=None, + text_features_mlp=None, + logit_scale_t=None, + mlp_loss=False, +): + metrics = {} + if mlp_loss: + # Set up audio to text & text to audio similary matrice + a_logits_per_audio = ( + (logit_scale_a * audio_features @ text_features_mlp.t()).detach().cpu() + ) + a_logits_per_text = a_logits_per_audio.t().detach().cpu() + t_logits_per_audio = ( + (logit_scale_t * audio_features_mlp @ text_features.t()).detach().cpu() + ) + t_logits_per_text = t_logits_per_audio.t().detach().cpu() + + labels = torch.arange(audio_features.shape[0]).long() + # Change the loss from two terms into four terms with 2x2 combined CE loss + total_loss = ( + F.cross_entropy(a_logits_per_audio, labels) + + F.cross_entropy(a_logits_per_text, labels) + + F.cross_entropy(t_logits_per_audio, labels) + + F.cross_entropy(t_logits_per_text, labels) + ) / 4 + + metrics[f"cumulative_loss"] = total_loss.item() + metrics[f"num_samples"] = audio_features.shape[0] + + logits = { + "audio_to_text": (a_logits_per_audio + t_logits_per_audio) / 2, + "text_to_audio": (a_logits_per_text + t_logits_per_text) / 2, + } + ground_truth = torch.arange(len(text_features)).view(-1, 1) + + else: + # print("text_features", text_features) + # print("text_features.shape", text_features.shape) + logits_per_audio = ( + (logit_scale_a * audio_features @ text_features.t()).detach().cpu() + ) + logits_per_text = logits_per_audio.t().detach().cpu() + + labels = torch.arange(audio_features.shape[0]).long() + # Change the loss from two terms into four terms with 2x2 combined CE loss + total_loss = ( + F.cross_entropy(logits_per_audio, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + metrics[f"cumulative_loss"] = total_loss.item() + metrics[f"num_samples"] = audio_features.shape[0] + + logits = {"audio_to_text": logits_per_audio, "text_to_audio": logits_per_text} + + ground_truth = torch.arange(len(text_features)).view(-1, 1) + + for name, logit in logits.items(): + ranking = torch.argsort(logit, descending=True) + preds = torch.where(ranking == ground_truth)[ + 1 + ] # (yusong) this line is slow because it uses single thread + preds = preds.detach().cpu().numpy() + metrics[f"{name}_mean_rank"] = preds.mean() + 1 + metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 + for k in [1, 5, 10]: + metrics[f"{name}_R@{k}"] = np.mean(preds < k) + # map@10 + metrics[f"{name}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0)) + + return metrics + + +def evaluate_clotho_audiocaps( + model, data, epoch, args, autocast, device, tb_writer=None +): + """ + Adapted from https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py. + 1. for text-to-audio retrieval, do 5 times and average the results + 2. for R@1, R@5, R@10 in audio-to-text retrieval, take the best rank among 5 text + 3. for map@10 in audio-to-text retrieval: + 3.1: sort the rank of 5 text + 3.2: exclude the rank >=10 (0-index) + 3.3: compute the map regarding the remaining ranks: np.mean(np.arange(1, len(ranks)+1) / ranks). + (3.3) That is, take the top ranks of 5 text that is < 10, and assign the descending number as ground truth. 
+ (3.3) E.g.: the ground truth of first rank of the 5 text should be 1, the second rank should be 2, etc. + """ + # TODO: (yusong) only support single GPU evaluation and only support non-mlp case for now. + dataloader = data["val"].dataloader + with torch.no_grad(): + eval_info = {} + for i, batch in enumerate(dataloader): + audios = batch # contains mel_spec, wavform, and longer list + + # each item in the list has 5 texts + if args.tmodel == "transformer": + from open_clip import tokenize + + texts = [tokenize(t) for t in batch["full_text"]] + texts = torch.cat(texts) + else: + from .data import tokenizer + + texts = [ + tokenizer(t) for t in batch["full_text"] + ] # 5 texts for each audio + texts = { + k: torch.cat([t[k] for t in texts]) for k in texts[0].keys() + } # 5 x batch + + # audios = audios.to(device=device, non_blocking=True) + + all_names = list( + set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]) + ) + for name in all_names: + if name not in eval_info.keys(): + # we will not use mlp outputs even if args.clap_mlploss=True + eval_info[name] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } + with autocast(): + audio_features = model(audios, None, device) + text_features = model(None, texts, device) + audio_features = F.normalize(audio_features, dim=-1) + text_features = F.normalize(text_features, dim=-1) + + all_names = list( + set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]) + ) + for n in all_names: + idx = np.where( + np.array( + ["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]] + ) + == n + )[0] + eval_info[n]["all_audio_features"].append( + audio_features.cpu().index_select(0, torch.tensor(idx).long()) + ) + # (yusong) please double-check. This is for selecting 5 text features at once. + # because idx is a list of indices in size of num_samples, + # and text_features is a tensor of size (5*num_samples, dim) + # so we need to select 5 consecutive indices at once for a single index in idx. + eval_info[n]["all_text_features"].append( + text_features.cpu() + .reshape([-1, 5, text_features.shape[1]]) + .index_select(0, torch.tensor(idx).long()) + .reshape([-1, text_features.shape[1]]) + ) + + val_metrics_all = {} + + for n in eval_info.keys(): + logit_scale_a, logit_scale_t = model(None, None, device) + logit_scale_a = logit_scale_a.cpu() + + audio_features = torch.cat(eval_info[n]["all_audio_features"], dim=0) + text_features = torch.cat(eval_info[n]["all_text_features"], dim=0) + + logits_per_audio = ( + (logit_scale_a * audio_features @ text_features.t()).detach().cpu() + ) + logits_per_text = logits_per_audio.t().detach().cpu() + + # logits_per_audio shape: [num_samples, num_samples*5] + # logits_per_text shape: [num_samples*5, num_samples] + + logging.info( + f"dataset {n}, logits_per_audio shape: {logits_per_audio.shape}, " + f"logits_per_text shape: {logits_per_text.shape}" + ) + + metrics = {} + num_samples = audio_features.shape[0] + metrics[f"num_samples"] = num_samples + + # (yusong) the following code is very important, please double-check: + # logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d] + # logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :] + # Those two are retrieving one of the 5 text for each audio. 
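# Sanity check of the reshape/indexing pattern described above, assuming the text
# feature rows are ordered as (audio 0: captions 0..4, audio 1: captions 0..4, ...).
# logits_per_audio has shape [N, 5N]; reshape(N, N, 5)[:, :, d] is then the [N, N]
# similarity matrix between every audio and caption d of every audio, with the
# matching pairs on the diagonal. Toy verification with random features:
import torch

N, D = 2, 4
audio = torch.randn(N, D)
text = torch.randn(N * 5, D)          # rows: audio0 cap0..4, audio1 cap0..4
logits_per_audio = audio @ text.t()   # [N, 5N]
d = 3
per_caption = logits_per_audio.reshape(N, N, 5)[:, :, d]
expected = audio @ text[d::5].t()     # similarity against caption d of every audio
assert torch.allclose(per_caption, expected)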
+ labels = torch.arange(audio_features.shape[0]).long() + audio_to_text_loss = [ + F.cross_entropy( + logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d], + labels, + ) + for d in range(5) + ] + text_to_audio_loss = [ + F.cross_entropy( + logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :], + labels, + ) + for d in range(5) + ] + total_loss = (np.mean(audio_to_text_loss) + np.mean(text_to_audio_loss)) / 2 + + metrics[f"cumulative_loss"] = total_loss.item() + + # text to audio: do 5 times + pred_text = [] + for d in range(5): + logit = logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :] + ground_truth = torch.arange(len(logit)).view(-1, 1) + ranking = torch.argsort( + logit, descending=True + ) # [num_samples, num_samples] + preds = torch.where(ranking == ground_truth)[1] + pred_text.append(preds.detach().cpu().numpy()) + pred_text_concat = np.concatenate(pred_text, axis=0) # [5*num_samples] + metrics[f"text_to_audio_mean_rank"] = pred_text_concat.mean() + 1 + metrics[f"text_to_audio_median_rank"] = ( + np.floor(np.median(pred_text_concat)) + 1 + ) + for k in [1, 5, 10]: + metrics[f"text_to_audio_R@{k}"] = np.mean(pred_text_concat < k) + # map@10 + metrics[f"text_to_audio_mAP@10"] = np.mean( + np.where(pred_text_concat < 10, 1 / (pred_text_concat + 1), 0.0) + ) + + # audio to text: take the best result + # for audio to text map 10, sort and assign descending ground truth. + # see https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py#L103 + # map@10 + map_all = [] + pred_audio_all = [] + for d in range(num_samples): + # logits_per_audio: [num_samples, num_samples*5] + logit_single = logits_per_audio[d, :] # [5*num_samples] + # Ground-truth index: [d*5, d*5+1, d*5+2, d*5+3, d*5+4] + ranking = torch.argsort( + logit_single, descending=True + ) # [5*num_samples] + # ranking: the index of first match, second match, ... + ground_truth = torch.arange(d * 5, d * 5 + 5)[None] + all_pred = torch.where( + torch.stack([ranking] * 5) == ground_truth.view(-1, 1) + )[1] + min_pred = torch.min(all_pred) + pred_audio_all.append(min_pred.detach().cpu().numpy()) + all_pred_filter = all_pred[all_pred < 10].detach().cpu().numpy() + # /5 because we have 5 text, so it means for the text rank >=10 we count as 0. + map_single = ( + np.sum( + (np.arange(1, len(all_pred_filter) + 1) / (all_pred_filter + 1)) + ) + / 5 + ) + map_all.append(map_single) + metrics[f"audio_to_text_mAP@10"] = np.mean(map_all) + for k in [1, 5, 10]: + metrics[f"audio_to_text_R@{k}"] = np.mean(np.array(pred_audio_all) < k) + + val_metrics_all[n] = {n + "/" + k: v for k, v in metrics.items()} + return val_metrics_all + + +def calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset): + """ + Calculate performance for Clotho+AudioCaps for model selection. 
+ """ + selection_performance_all = [] + for n in val_metrics_per_dataset.keys(): + selection_performance = ( + val_metrics_per_dataset[n][f"{n}/audio_to_text_mAP@10"] + + val_metrics_per_dataset[n][f"{n}/text_to_audio_mAP@10"] + ) / 2 + selection_performance_all.append(selection_performance) + return np.mean(selection_performance_all) + + +def select_top_metric_clotho_audiocaps(metrics, val_metrics_per_dataset, args): + # val_metrics_per_dataset: dict, key: dataset name, value: dict, key: metric name, value: metric value + # metrics: dict, key: metric name, value: metric value + # Hack: use args to save the top performance + if not hasattr(args, "top_selection_performance"): + selection_performance = calculate_selection_performance_clotho_audiocaps( + val_metrics_per_dataset + ) + # TODO: write the if and else together + metric_update = {} + for n in val_metrics_per_dataset.keys(): + for k in val_metrics_per_dataset[n].keys(): + metric_update[ + k.split("/")[0] + "-top" + "/" + k.split("/")[1] + ] = val_metrics_per_dataset[n][k] + metric_update["top_selection_performance"] = selection_performance + metric_update["top-selection-epoch"] = metrics["epoch"] + metrics.update(metric_update) + args.top_metric = metric_update + args.top_selection_performance = selection_performance + else: + selection_performance_new = calculate_selection_performance_clotho_audiocaps( + val_metrics_per_dataset + ) + selection_performance_old = args.top_selection_performance + if selection_performance_new > selection_performance_old: + metric_update = {} + for n in val_metrics_per_dataset.keys(): + for k in val_metrics_per_dataset[n].keys(): + metric_update[ + k.split("/")[0] + "-top" + "/" + k.split("/")[1] + ] = val_metrics_per_dataset[n][k] + metric_update["top_selection_performance"] = selection_performance_new + metric_update["top-selection-epoch"] = metrics["epoch"] + metrics.update(metric_update) + args.top_metric = metric_update + args.top_selection_performance = selection_performance_new + else: + metrics.update(args.top_metric) + return metrics diff --git a/models/CLAP/training/zero_shot.py b/models/CLAP/training/zero_shot.py new file mode 100755 index 0000000000000000000000000000000000000000..28b8fccc1af17fc69002857a7f529ac041c374f2 --- /dev/null +++ b/models/CLAP/training/zero_shot.py @@ -0,0 +1,95 @@ +# NOTE: This script is currently not supported for CLAP. 
+import logging +from contextlib import suppress + +import torch +import torch.nn.functional as F +from tqdm import tqdm + +from open_clip import tokenize +from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template + + +def zero_shot_classifier(model, classnames, templates, args): + with torch.no_grad(): + zeroshot_weights = [] + for classname in tqdm(classnames): + texts = [template(classname) for template in templates] # format with class + texts = tokenize(texts).to(args.device) # tokenize + if args.distributed and not args.horovod: + class_embeddings = model.module.encode_text(texts) + else: + class_embeddings = model.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) + return zeroshot_weights + + +def accuracy(output, target, topk=(1,)): + pred = output.topk(max(topk), 1, True, True)[1].t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + return [ + float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) + for k in topk + ] + + +def run(model, classifier, dataloader, args): + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + with torch.no_grad(): + top1, top5, n = 0.0, 0.0, 0.0 + for images, target in tqdm(dataloader, unit_scale=args.batch_size): + images = images.to(args.device) + target = target.to(args.device) + + with autocast(): + # predict + if args.distributed and not args.horovod: + image_features = model.module.encode_image(images) + else: + image_features = model.encode_image(images) + image_features = F.normalize(image_features, dim=-1) + logits = 100.0 * image_features @ classifier + + # measure accuracy + acc1, acc5 = accuracy(logits, target, topk=(1, 5)) + top1 += acc1 + top5 += acc5 + n += images.size(0) + + top1 = top1 / n + top5 = top5 / n + return top1, top5 + + +def zero_shot_eval(model, data, epoch, args): + if "imagenet-val" not in data and "imagenet-v2" not in data: + return {} + if args.zeroshot_frequency == 0: + return {} + if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: + return {} + + logging.info("Starting zero-shot imagenet.") + + logging.info("Building zero-shot classifier") + classifier = zero_shot_classifier( + model, imagenet_classnames, openai_imagenet_template, args + ) + + logging.info("Using classifier") + results = {} + if "imagenet-val" in data: + top1, top5 = run(model, classifier, data["imagenet-val"].dataloader, args) + results["imagenet-zeroshot-val-top1"] = top1 + results["imagenet-zeroshot-val-top5"] = top5 + if "imagenet-v2" in data: + top1, top5 = run(model, classifier, data["imagenet-v2"].dataloader, args) + results["imagenetv2-zeroshot-val-top1"] = top1 + results["imagenetv2-zeroshot-val-top5"] = top5 + + logging.info("Finished zero-shot imagenet.") + + return results diff --git a/models/audiosep.py b/models/audiosep.py new file mode 100644 index 0000000000000000000000000000000000000000..57f262b246ddf47cd684e252eba34e06512d30ba --- /dev/null +++ b/models/audiosep.py @@ -0,0 +1,150 @@ +from typing import Any, Callable, Dict +import random +import lightning.pytorch as pl +import torch +import torch.nn as nn +import torch.optim as optim +from torch.optim.lr_scheduler import LambdaLR + + +class AudioSep(pl.LightningModule): + def __init__( + self, + ss_model: nn.Module, + waveform_mixer, + query_encoder, + loss_function, + optimizer_type: str, + 
learning_rate: float, + lr_lambda_func, + use_text_ratio=1.0, + ): + r"""Pytorch Lightning wrapper of PyTorch model, including forward, + optimization of model, etc. + + Args: + ss_model: nn.Module + anchor_segment_detector: nn.Module + loss_function: function or object + learning_rate: float + lr_lambda: function + """ + + super().__init__() + self.ss_model = ss_model + self.waveform_mixer = waveform_mixer + self.query_encoder = query_encoder + self.query_encoder_type = self.query_encoder.encoder_type + self.use_text_ratio = use_text_ratio + self.loss_function = loss_function + self.optimizer_type = optimizer_type + self.learning_rate = learning_rate + self.lr_lambda_func = lr_lambda_func + + + def forward(self, x): + pass + + def training_step(self, batch_data_dict, batch_idx): + r"""Forward a mini-batch data to model, calculate loss function, and + train for one step. A mini-batch data is evenly distributed to multiple + devices (if there are) for parallel training. + + Args: + batch_data_dict: e.g. + 'audio_text': { + 'text': ['a sound of dog', ...] + 'waveform': (batch_size, 1, samples) + } + batch_idx: int + + Returns: + loss: float, loss function of this mini-batch + """ + # [important] fix random seeds across devices + random.seed(batch_idx) + + batch_audio_text_dict = batch_data_dict['audio_text'] + + batch_text = batch_audio_text_dict['text'] + batch_audio = batch_audio_text_dict['waveform'] + device = batch_audio.device + + mixtures, segments = self.waveform_mixer( + waveforms=batch_audio + ) + + # calculate text embed for audio-text data + if self.query_encoder_type == 'CLAP': + conditions = self.query_encoder.get_query_embed( + modality='hybird', + text=batch_text, + audio=segments.squeeze(1), + use_text_ratio=self.use_text_ratio, + ) + + input_dict = { + 'mixture': mixtures[:, None, :].squeeze(1), + 'condition': conditions, + } + + target_dict = { + 'segment': segments.squeeze(1), + } + + self.ss_model.train() + sep_segment = self.ss_model(input_dict)['waveform'] + sep_segment = sep_segment.squeeze() + # (batch_size, 1, segment_samples) + + output_dict = { + 'segment': sep_segment, + } + + # Calculate loss. + loss = self.loss_function(output_dict, target_dict) + + self.log_dict({"train_loss": loss}) + + return loss + + def test_step(self, batch, batch_idx): + pass + + def configure_optimizers(self): + r"""Configure optimizer. + """ + + if self.optimizer_type == "AdamW": + optimizer = optim.AdamW( + params=self.ss_model.parameters(), + lr=self.learning_rate, + betas=(0.9, 0.999), + eps=1e-08, + weight_decay=0.0, + amsgrad=True, + ) + else: + raise NotImplementedError + + scheduler = LambdaLR(optimizer, self.lr_lambda_func) + + output_dict = { + "optimizer": optimizer, + "lr_scheduler": { + 'scheduler': scheduler, + 'interval': 'step', + 'frequency': 1, + } + } + + return output_dict + + +def get_model_class(model_type): + if model_type == 'ResUNet30': + from models.resunet import ResUNet30 + return ResUNet30 + + else: + raise NotImplementedError diff --git a/models/base.py b/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..6b70dd804dcf9b9cf3a9aacd84c707852bab2d7c --- /dev/null +++ b/models/base.py @@ -0,0 +1,152 @@ +import torch.nn as nn +import torch +import numpy as np +import torch.nn.functional as F +import math +from torchlibrosa.stft import magphase + + +def init_layer(layer): + """Initialize a Linear or Convolutional layer. 
""" + nn.init.xavier_uniform_(layer.weight) + + if hasattr(layer, "bias"): + if layer.bias is not None: + layer.bias.data.fill_(0.0) + + +def init_bn(bn): + """Initialize a Batchnorm layer. """ + bn.bias.data.fill_(0.0) + bn.weight.data.fill_(1.0) + + +def init_embedding(layer): + """Initialize a Linear or Convolutional layer. """ + nn.init.uniform_(layer.weight, -1., 1.) + + if hasattr(layer, 'bias'): + if layer.bias is not None: + layer.bias.data.fill_(0.) + + +def init_gru(rnn): + """Initialize a GRU layer. """ + + def _concat_init(tensor, init_funcs): + (length, fan_out) = tensor.shape + fan_in = length // len(init_funcs) + + for (i, init_func) in enumerate(init_funcs): + init_func(tensor[i * fan_in : (i + 1) * fan_in, :]) + + def _inner_uniform(tensor): + fan_in = nn.init._calculate_correct_fan(tensor, "fan_in") + nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in)) + + for i in range(rnn.num_layers): + _concat_init( + getattr(rnn, "weight_ih_l{}".format(i)), + [_inner_uniform, _inner_uniform, _inner_uniform], + ) + torch.nn.init.constant_(getattr(rnn, "bias_ih_l{}".format(i)), 0) + + _concat_init( + getattr(rnn, "weight_hh_l{}".format(i)), + [_inner_uniform, _inner_uniform, nn.init.orthogonal_], + ) + torch.nn.init.constant_(getattr(rnn, "bias_hh_l{}".format(i)), 0) + + +def act(x, activation): + if activation == "relu": + return F.relu_(x) + + elif activation == "leaky_relu": + return F.leaky_relu_(x, negative_slope=0.01) + + elif activation == "swish": + return x * torch.sigmoid(x) + + else: + raise Exception("Incorrect activation!") + + +class Base: + def __init__(self): + pass + + def spectrogram(self, input, eps=0.): + (real, imag) = self.stft(input) + return torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5 + + def spectrogram_phase(self, input, eps=0.): + (real, imag) = self.stft(input) + mag = torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5 + cos = real / mag + sin = imag / mag + return mag, cos, sin + + + def wav_to_spectrogram_phase(self, input, eps=1e-10): + """Waveform to spectrogram. + + Args: + input: (batch_size, segment_samples, channels_num) + + Outputs: + output: (batch_size, channels_num, time_steps, freq_bins) + """ + sp_list = [] + cos_list = [] + sin_list = [] + channels_num = input.shape[1] + for channel in range(channels_num): + mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps) + sp_list.append(mag) + cos_list.append(cos) + sin_list.append(sin) + + sps = torch.cat(sp_list, dim=1) + coss = torch.cat(cos_list, dim=1) + sins = torch.cat(sin_list, dim=1) + return sps, coss, sins + + def wav_to_spectrogram(self, input, eps=0.): + """Waveform to spectrogram. + + Args: + input: (batch_size, segment_samples, channels_num) + + Outputs: + output: (batch_size, channels_num, time_steps, freq_bins) + """ + sp_list = [] + channels_num = input.shape[1] + for channel in range(channels_num): + sp_list.append(self.spectrogram(input[:, channel, :], eps=eps)) + + output = torch.cat(sp_list, dim=1) + return output + + + def spectrogram_to_wav(self, input, spectrogram, length=None): + """Spectrogram to waveform. 
+ + Args: + input: (batch_size, segment_samples, channels_num) + spectrogram: (batch_size, channels_num, time_steps, freq_bins) + + Outputs: + output: (batch_size, segment_samples, channels_num) + """ + channels_num = input.shape[1] + wav_list = [] + for channel in range(channels_num): + (real, imag) = self.stft(input[:, channel, :]) + (_, cos, sin) = magphase(real, imag) + wav_list.append(self.istft(spectrogram[:, channel : channel + 1, :, :] * cos, + spectrogram[:, channel : channel + 1, :, :] * sin, length)) + + output = torch.stack(wav_list, dim=1) + return output diff --git a/models/clap_encoder.py b/models/clap_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1e00c7ed38997fcd971e4755a306a65676a07429 --- /dev/null +++ b/models/clap_encoder.py @@ -0,0 +1,117 @@ +import random +import torch +import torch.nn as nn +import torchaudio +from models.CLAP.open_clip import create_model +from models.CLAP.training.data import get_audio_features +from transformers import RobertaTokenizer +from utils import ignore_warnings; ignore_warnings() + + +class CLAP_Encoder(nn.Module): + def __init__( + self, + pretrained_path='checkpoint/music_speech_audioset_epoch_15_esc_89.98.pt', + sampling_rate=32000, + amodel = "HTSAT-base", + ): + super().__init__() + self.device = "cpu" + self.precision = "fp32" + self.amodel = amodel # or 'PANN-14' + self.tmodel = "roberta" # the best text encoder in our training + self.enable_fusion = False # False if you do not want to use the fusion model + self.fusion_type = "aff_2d" + self.pretrained = pretrained_path + self.sampling_rate = sampling_rate + self.tokenize = RobertaTokenizer.from_pretrained("roberta-base") + + self.model, self.model_cfg = create_model( + self.amodel, + self.tmodel, + self.pretrained, + precision=self.precision, + device=self.device, + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + + for p in self.model.parameters(): + p.requires_grad = False + + self.model.eval() + self.encoder_type = 'CLAP' + + def batch_to_list(self, batch): + ret = [] + for i in range(batch.size(0)): + ret.append(batch[i]) + return ret + + def _get_audio_embed(self, batch): + # batch: [B, samples] + with torch.no_grad(): + audio_dict_list = [] + assert ( + self.sampling_rate == 32000 + ), "We only support 32000 sampling rate" + + # batch: [bs, 1, t-samples] + batch = torchaudio.functional.resample( + batch, orig_freq=self.sampling_rate, new_freq=48000 + ) + for waveform in self.batch_to_list(batch): + audio_dict = {} + audio_dict = get_audio_features( + audio_dict, + waveform, + 480000, + data_truncating="fusion", + data_filling="repeatpad", + audio_cfg=self.model_cfg["audio_cfg"], + ) + audio_dict_list.append(audio_dict) + # [bs, 512] + embed = self.model.get_audio_embedding(audio_dict_list) + + return embed.detach() + + def _get_text_embed(self, batch): + double_batch = False + if len(batch) == 1: + batch = batch * 2 + double_batch = True + with torch.no_grad(): + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + text_data = self.tokenizer(batch) + embed = self.model.get_text_embedding(text_data) + if double_batch: + embed = embed[0].unsqueeze(0) + + return embed.detach() + + + def get_query_embed(self, modality, audio=None, text=None, use_text_ratio=0.5, device=None): + if modality == 'audio': + embed = self._get_audio_embed(audio) + elif modality == 'text': + embed = self._get_text_embed(text) + elif modality == 'hybird': + if random.random() > use_text_ratio: + embed = 
self._get_audio_embed(audio) + else: + embed = self._get_text_embed(text) + else: + raise NotImplementedError("Please check flag 'training_modality'.") + + return embed.float() + + def tokenizer(self, text): + result = self.tokenize( + text, + padding="max_length", + truncation=True, + max_length=512, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} diff --git a/models/resunet.py b/models/resunet.py new file mode 100644 index 0000000000000000000000000000000000000000..06b7f70a06446cca22871a22c5875f6e3365564b --- /dev/null +++ b/models/resunet.py @@ -0,0 +1,655 @@ +import numpy as np +from typing import Dict, List, NoReturn, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchlibrosa.stft import STFT, ISTFT, magphase +from models.base import Base, init_layer, init_bn, act + + +class FiLM(nn.Module): + def __init__(self, film_meta, condition_size): + super(FiLM, self).__init__() + + self.condition_size = condition_size + + self.modules, _ = self.create_film_modules( + film_meta=film_meta, + ancestor_names=[], + ) + + def create_film_modules(self, film_meta, ancestor_names): + + modules = {} + + # Pre-order traversal of modules + for module_name, value in film_meta.items(): + + if isinstance(value, int): + + ancestor_names.append(module_name) + unique_module_name = '->'.join(ancestor_names) + + modules[module_name] = self.add_film_layer_to_module( + num_features=value, + unique_module_name=unique_module_name, + ) + + elif isinstance(value, dict): + + ancestor_names.append(module_name) + + modules[module_name], _ = self.create_film_modules( + film_meta=value, + ancestor_names=ancestor_names, + ) + + ancestor_names.pop() + + return modules, ancestor_names + + def add_film_layer_to_module(self, num_features, unique_module_name): + + layer = nn.Linear(self.condition_size, num_features) + init_layer(layer) + self.add_module(name=unique_module_name, module=layer) + + return layer + + def forward(self, conditions): + + film_dict = self.calculate_film_data( + conditions=conditions, + modules=self.modules, + ) + + return film_dict + + def calculate_film_data(self, conditions, modules): + + film_data = {} + + # Pre-order traversal of modules + for module_name, module in modules.items(): + + if isinstance(module, nn.Module): + film_data[module_name] = module(conditions)[:, :, None, None] + + elif isinstance(module, dict): + film_data[module_name] = self.calculate_film_data(conditions, module) + + return film_data + + +class ConvBlockRes(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple, + momentum: float, + has_film, + ): + r"""Residual block.""" + super(ConvBlockRes, self).__init__() + + padding = [kernel_size[0] // 2, kernel_size[1] // 2] + + self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum) + self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum) + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=(1, 1), + dilation=(1, 1), + padding=padding, + bias=False, + ) + + self.conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=(1, 1), + dilation=(1, 1), + padding=padding, + bias=False, + ) + + if in_channels != out_channels: + self.shortcut = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + ) + self.is_shortcut = True + else: + self.is_shortcut = False + + self.has_film = has_film + 
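+ # When has_film is set, forward() expects a film_dict carrying 'beta1' and
+ # 'beta2' tensors of shape (batch_size, channels, 1, 1), produced by the
+ # FiLM module above; they are added as per-channel offsets to the BatchNorm
+ # outputs before each leaky ReLU.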
+ self.init_weights() + + def init_weights(self) -> NoReturn: + r"""Initialize weights.""" + init_bn(self.bn1) + init_bn(self.bn2) + init_layer(self.conv1) + init_layer(self.conv2) + + if self.is_shortcut: + init_layer(self.shortcut) + + def forward(self, input_tensor: torch.Tensor, film_dict: Dict) -> torch.Tensor: + r"""Forward data into the module. + + Args: + input_tensor: (batch_size, input_feature_maps, time_steps, freq_bins) + + Returns: + output_tensor: (batch_size, output_feature_maps, time_steps, freq_bins) + """ + b1 = film_dict['beta1'] + b2 = film_dict['beta2'] + + x = self.conv1(F.leaky_relu_(self.bn1(input_tensor) + b1, negative_slope=0.01)) + x = self.conv2(F.leaky_relu_(self.bn2(x) + b2, negative_slope=0.01)) + + if self.is_shortcut: + return self.shortcut(input_tensor) + x + else: + return input_tensor + x + + +class EncoderBlockRes1B(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple, + downsample: Tuple, + momentum: float, + has_film, + ): + r"""Encoder block, contains 8 convolutional layers.""" + super(EncoderBlockRes1B, self).__init__() + + self.conv_block1 = ConvBlockRes( + in_channels, out_channels, kernel_size, momentum, has_film, + ) + self.downsample = downsample + + def forward(self, input_tensor: torch.Tensor, film_dict: Dict) -> torch.Tensor: + r"""Forward data into the module. + + Args: + input_tensor: (batch_size, input_feature_maps, time_steps, freq_bins) + + Returns: + encoder_pool: (batch_size, output_feature_maps, downsampled_time_steps, downsampled_freq_bins) + encoder: (batch_size, output_feature_maps, time_steps, freq_bins) + """ + encoder = self.conv_block1(input_tensor, film_dict['conv_block1']) + encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample) + return encoder_pool, encoder + + +class DecoderBlockRes1B(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple, + upsample: Tuple, + momentum: float, + has_film, + ): + r"""Decoder block, contains 1 transposed convolutional and 8 convolutional layers.""" + super(DecoderBlockRes1B, self).__init__() + self.kernel_size = kernel_size + self.stride = upsample + + self.conv1 = torch.nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=self.stride, + stride=self.stride, + padding=(0, 0), + bias=False, + dilation=(1, 1), + ) + + self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum) + self.conv_block2 = ConvBlockRes( + out_channels * 2, out_channels, kernel_size, momentum, has_film, + ) + self.bn2 = nn.BatchNorm2d(in_channels, momentum=momentum) + self.has_film = has_film + + self.init_weights() + + def init_weights(self): + r"""Initialize weights.""" + init_bn(self.bn1) + init_layer(self.conv1) + + def forward( + self, input_tensor: torch.Tensor, concat_tensor: torch.Tensor, film_dict: Dict, + ) -> torch.Tensor: + r"""Forward data into the module. 
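+ Upsamples input_tensor with the transposed convolution, concatenates the
+ result with the encoder skip connection concat_tensor along the channel
+ axis, and refines it with the FiLM-conditioned residual block.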
+ + Args: + input_tensor: (batch_size, input_feature_maps, downsampled_time_steps, downsampled_freq_bins) + concat_tensor: (batch_size, input_feature_maps, time_steps, freq_bins) + + Returns: + output_tensor: (batch_size, output_feature_maps, time_steps, freq_bins) + """ + # b1 = film_dict['beta1'] + + b1 = film_dict['beta1'] + x = self.conv1(F.leaky_relu_(self.bn1(input_tensor) + b1)) + # (batch_size, input_feature_maps, time_steps, freq_bins) + + x = torch.cat((x, concat_tensor), dim=1) + # (batch_size, input_feature_maps * 2, time_steps, freq_bins) + + x = self.conv_block2(x, film_dict['conv_block2']) + # output_tensor: (batch_size, output_feature_maps, time_steps, freq_bins) + + return x + + +class ResUNet30_Base(nn.Module, Base): + def __init__(self, input_channels, output_channels): + super(ResUNet30_Base, self).__init__() + + window_size = 2048 + hop_size = 320 + center = True + pad_mode = "reflect" + window = "hann" + momentum = 0.01 + + self.output_channels = output_channels + self.target_sources_num = 1 + self.K = 3 + + self.time_downsample_ratio = 2 ** 5 # This number equals 2^{#encoder_blcoks} + + self.stft = STFT( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + self.istft = ISTFT( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum) + + self.pre_conv = nn.Conv2d( + in_channels=input_channels, + out_channels=32, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + bias=True, + ) + + self.encoder_block1 = EncoderBlockRes1B( + in_channels=32, + out_channels=32, + kernel_size=(3, 3), + downsample=(2, 2), + momentum=momentum, + has_film=True, + ) + self.encoder_block2 = EncoderBlockRes1B( + in_channels=32, + out_channels=64, + kernel_size=(3, 3), + downsample=(2, 2), + momentum=momentum, + has_film=True, + ) + self.encoder_block3 = EncoderBlockRes1B( + in_channels=64, + out_channels=128, + kernel_size=(3, 3), + downsample=(2, 2), + momentum=momentum, + has_film=True, + ) + self.encoder_block4 = EncoderBlockRes1B( + in_channels=128, + out_channels=256, + kernel_size=(3, 3), + downsample=(2, 2), + momentum=momentum, + has_film=True, + ) + self.encoder_block5 = EncoderBlockRes1B( + in_channels=256, + out_channels=384, + kernel_size=(3, 3), + downsample=(2, 2), + momentum=momentum, + has_film=True, + ) + self.encoder_block6 = EncoderBlockRes1B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 2), + momentum=momentum, + has_film=True, + ) + self.conv_block7a = EncoderBlockRes1B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 1), + momentum=momentum, + has_film=True, + ) + self.decoder_block1 = DecoderBlockRes1B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + upsample=(1, 2), + momentum=momentum, + has_film=True, + ) + self.decoder_block2 = DecoderBlockRes1B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + upsample=(2, 2), + momentum=momentum, + has_film=True, + ) + self.decoder_block3 = DecoderBlockRes1B( + in_channels=384, + out_channels=256, + kernel_size=(3, 3), + upsample=(2, 2), + momentum=momentum, + has_film=True, + ) + self.decoder_block4 = DecoderBlockRes1B( + in_channels=256, + out_channels=128, + kernel_size=(3, 3), + upsample=(2, 2), + momentum=momentum, + has_film=True, + ) + self.decoder_block5 = 
DecoderBlockRes1B( + in_channels=128, + out_channels=64, + kernel_size=(3, 3), + upsample=(2, 2), + momentum=momentum, + has_film=True, + ) + self.decoder_block6 = DecoderBlockRes1B( + in_channels=64, + out_channels=32, + kernel_size=(3, 3), + upsample=(2, 2), + momentum=momentum, + has_film=True, + ) + + self.after_conv = nn.Conv2d( + in_channels=32, + out_channels=output_channels * self.K, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + bias=True, + ) + + self.init_weights() + + def init_weights(self): + init_bn(self.bn0) + init_layer(self.pre_conv) + init_layer(self.after_conv) + + def feature_maps_to_wav( + self, + input_tensor: torch.Tensor, + sp: torch.Tensor, + sin_in: torch.Tensor, + cos_in: torch.Tensor, + audio_length: int, + ) -> torch.Tensor: + r"""Convert feature maps to waveform. + + Args: + input_tensor: (batch_size, target_sources_num * output_channels * self.K, time_steps, freq_bins) + sp: (batch_size, input_channels, time_steps, freq_bins) + sin_in: (batch_size, input_channels, time_steps, freq_bins) + cos_in: (batch_size, input_channels, time_steps, freq_bins) + + (There is input_channels == output_channels for the source separation task.) + + Outputs: + waveform: (batch_size, target_sources_num * output_channels, segment_samples) + """ + batch_size, _, time_steps, freq_bins = input_tensor.shape + + x = input_tensor.reshape( + batch_size, + self.target_sources_num, + self.output_channels, + self.K, + time_steps, + freq_bins, + ) + # x: (batch_size, target_sources_num, output_channels, self.K, time_steps, freq_bins) + + mask_mag = torch.sigmoid(x[:, :, :, 0, :, :]) + _mask_real = torch.tanh(x[:, :, :, 1, :, :]) + _mask_imag = torch.tanh(x[:, :, :, 2, :, :]) + # linear_mag = torch.tanh(x[:, :, :, 3, :, :]) + _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag) + # mask_cos, mask_sin: (batch_size, target_sources_num, output_channels, time_steps, freq_bins) + + # Y = |Y|cos∠Y + j|Y|sin∠Y + # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M) + # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M) + out_cos = ( + cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin + ) + out_sin = ( + sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin + ) + # out_cos: (batch_size, target_sources_num, output_channels, time_steps, freq_bins) + # out_sin: (batch_size, target_sources_num, output_channels, time_steps, freq_bins) + + # Calculate |Y|. + out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag) + # out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag) + # out_mag: (batch_size, target_sources_num, output_channels, time_steps, freq_bins) + + # Calculate Y_{real} and Y_{imag} for ISTFT. + out_real = out_mag * out_cos + out_imag = out_mag * out_sin + # out_real, out_imag: (batch_size, target_sources_num, output_channels, time_steps, freq_bins) + + # Reformat shape to (N, 1, time_steps, freq_bins) for ISTFT where + # N = batch_size * target_sources_num * output_channels + shape = ( + batch_size * self.target_sources_num * self.output_channels, + 1, + time_steps, + freq_bins, + ) + out_real = out_real.reshape(shape) + out_imag = out_imag.reshape(shape) + + # ISTFT. + x = self.istft(out_real, out_imag, audio_length) + # (batch_size * target_sources_num * output_channels, segments_num) + + # Reshape. 
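+ # The ISTFT output is (batch_size * target_sources_num * output_channels,
+ # segment_samples); fold the source and channel axes back out of the batch
+ # dimension for the caller.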
+ waveform = x.reshape( + batch_size, self.target_sources_num * self.output_channels, audio_length + ) + # (batch_size, target_sources_num * output_channels, segments_num) + + return waveform + + + def forward(self, mixtures, film_dict): + """ + Args: + input: (batch_size, segment_samples, channels_num) + + Outputs: + output_dict: { + 'wav': (batch_size, segment_samples, channels_num), + 'sp': (batch_size, channels_num, time_steps, freq_bins)} + """ + + mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures) + x = mag + + # Batch normalization + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + """(batch_size, chanenls, time_steps, freq_bins)""" + + # Pad spectrogram to be evenly divided by downsample ratio. + origin_len = x.shape[2] + pad_len = ( + int(np.ceil(x.shape[2] / self.time_downsample_ratio)) * self.time_downsample_ratio + - origin_len + ) + x = F.pad(x, pad=(0, 0, 0, pad_len)) + """(batch_size, channels, padded_time_steps, freq_bins)""" + + # Let frequency bins be evenly divided by 2, e.g., 513 -> 512 + x = x[..., 0 : x.shape[-1] - 1] # (bs, channels, T, F) + + # UNet + x = self.pre_conv(x) + x1_pool, x1 = self.encoder_block1(x, film_dict['encoder_block1']) # x1_pool: (bs, 32, T / 2, F / 2) + x2_pool, x2 = self.encoder_block2(x1_pool, film_dict['encoder_block2']) # x2_pool: (bs, 64, T / 4, F / 4) + x3_pool, x3 = self.encoder_block3(x2_pool, film_dict['encoder_block3']) # x3_pool: (bs, 128, T / 8, F / 8) + x4_pool, x4 = self.encoder_block4(x3_pool, film_dict['encoder_block4']) # x4_pool: (bs, 256, T / 16, F / 16) + x5_pool, x5 = self.encoder_block5(x4_pool, film_dict['encoder_block5']) # x5_pool: (bs, 384, T / 32, F / 32) + x6_pool, x6 = self.encoder_block6(x5_pool, film_dict['encoder_block6']) # x6_pool: (bs, 384, T / 32, F / 64) + x_center, _ = self.conv_block7a(x6_pool, film_dict['conv_block7a']) # (bs, 384, T / 32, F / 64) + x7 = self.decoder_block1(x_center, x6, film_dict['decoder_block1']) # (bs, 384, T / 32, F / 32) + x8 = self.decoder_block2(x7, x5, film_dict['decoder_block2']) # (bs, 384, T / 16, F / 16) + x9 = self.decoder_block3(x8, x4, film_dict['decoder_block3']) # (bs, 256, T / 8, F / 8) + x10 = self.decoder_block4(x9, x3, film_dict['decoder_block4']) # (bs, 128, T / 4, F / 4) + x11 = self.decoder_block5(x10, x2, film_dict['decoder_block5']) # (bs, 64, T / 2, F / 2) + x12 = self.decoder_block6(x11, x1, film_dict['decoder_block6']) # (bs, 32, T, F) + + x = self.after_conv(x12) + + # Recover shape + x = F.pad(x, pad=(0, 1)) + x = x[:, :, 0:origin_len, :] + + audio_length = mixtures.shape[2] + + # Recover each subband spectrograms to subband waveforms. Then synthesis + # the subband waveforms to a waveform. 
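+ # (feature_maps_to_wav applies the predicted magnitude/phase masks to the
+ # mixture STFT and runs the ISTFT; no separate subband synthesis step is
+ # performed in this model.)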
+ separated_audio = self.feature_maps_to_wav( + input_tensor=x, + # input_tensor: (batch_size, target_sources_num * output_channels * self.K, T, F') + sp=mag, + # sp: (batch_size, input_channels, T, F') + sin_in=sin_in, + # sin_in: (batch_size, input_channels, T, F') + cos_in=cos_in, + # cos_in: (batch_size, input_channels, T, F') + audio_length=audio_length, + ) + # (batch_size, target_sources_num * output_channels, subbands_num, segment_samples) + + output_dict = {'waveform': separated_audio} + + return output_dict + + +def get_film_meta(module): + + film_meta = {} + + if hasattr(module, 'has_film'):\ + + if module.has_film: + film_meta['beta1'] = module.bn1.num_features + film_meta['beta2'] = module.bn2.num_features + else: + film_meta['beta1'] = 0 + film_meta['beta2'] = 0 + + for child_name, child_module in module.named_children(): + + child_meta = get_film_meta(child_module) + + if len(child_meta) > 0: + film_meta[child_name] = child_meta + + return film_meta + + +class ResUNet30(nn.Module): + def __init__(self, input_channels, output_channels, condition_size): + super(ResUNet30, self).__init__() + + self.base = ResUNet30_Base( + input_channels=input_channels, + output_channels=output_channels, + ) + + self.film_meta = get_film_meta( + module=self.base, + ) + + self.film = FiLM( + film_meta=self.film_meta, + condition_size=condition_size + ) + + + def forward(self, input_dict): + mixtures = input_dict['mixture'] + conditions = input_dict['condition'] + + film_dict = self.film( + conditions=conditions, + ) + + output_dict = self.base( + mixtures=mixtures, + film_dict=film_dict, + ) + + return output_dict + + diff --git a/optimizers/lr_schedulers.py b/optimizers/lr_schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..07bdaed801b3c547144530b25f215a680aad6819 --- /dev/null +++ b/optimizers/lr_schedulers.py @@ -0,0 +1,101 @@ +from functools import partial +from typing import Callable + + +def linear_warm_up( + step: int, + warm_up_steps: int, + reduce_lr_steps: int +) -> float: + r"""Get linear warm up scheduler for LambdaLR. + + Args: + step (int): global step + warm_up_steps (int): steps for warm up + reduce_lr_steps (int): reduce learning rate by a factor of 0.9 #reduce_lr_steps step + + .. code-block: python + >>> lr_lambda = partial(linear_warm_up, warm_up_steps=1000, reduce_lr_steps=10000) + >>> from torch.optim.lr_scheduler import LambdaLR + >>> LambdaLR(optimizer, lr_lambda) + + Returns: + lr_scale (float): learning rate scaler + """ + + if step <= warm_up_steps: + lr_scale = step / warm_up_steps + else: + lr_scale = 0.9 ** (step // reduce_lr_steps) + + return lr_scale + + +def constant_warm_up( + step: int, + warm_up_steps: int, + reduce_lr_steps: int +) -> float: + r"""Get constant warm up scheduler for LambdaLR. + + Args: + step (int): global step + warm_up_steps (int): steps for warm up + reduce_lr_steps (int): reduce learning rate by a factor of 0.9 #reduce_lr_steps step + + .. 
code-block: python + >>> lr_lambda = partial(constant_warm_up, warm_up_steps=1000, reduce_lr_steps=10000) + >>> from torch.optim.lr_scheduler import LambdaLR + >>> LambdaLR(optimizer, lr_lambda) + + Returns: + lr_scale (float): learning rate scaler + """ + + if 0 <= step < warm_up_steps: + lr_scale = 0.001 + + elif warm_up_steps <= step < 2 * warm_up_steps: + lr_scale = 0.01 + + elif 2 * warm_up_steps <= step < 3 * warm_up_steps: + lr_scale = 0.1 + + else: + lr_scale = 1 + + return lr_scale + + +def get_lr_lambda( + lr_lambda_type: str, + **kwargs +) -> Callable: + r"""Get learning scheduler. + + Args: + lr_lambda_type (str), e.g., "constant_warm_up" | "linear_warm_up" + + Returns: + lr_lambda_func (Callable) + """ + if lr_lambda_type == "constant_warm_up": + + lr_lambda_func = partial( + constant_warm_up, + warm_up_steps=kwargs["warm_up_steps"], + reduce_lr_steps=kwargs["reduce_lr_steps"], + ) + + elif lr_lambda_type == "linear_warm_up": + + lr_lambda_func = partial( + linear_warm_up, + warm_up_steps=kwargs["warm_up_steps"], + reduce_lr_steps=kwargs["reduce_lr_steps"], + ) + + else: + raise NotImplementedError + + return lr_lambda_func diff --git a/pipeline.py b/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..ca10a2ba413c13a3fb54214d68e11bdf78dffbd2 --- /dev/null +++ b/pipeline.py @@ -0,0 +1,69 @@ +import yaml +from typing import Dict, List +import torch +import torch.nn as nn +import numpy as np +import librosa +from scipy.io.wavfile import write +from utils import ignore_warnings; ignore_warnings() +from utils import parse_yaml, load_ss_model +from models.clap_encoder import CLAP_Encoder + + +def build_audiosep(config_yaml, checkpoint_path, device): + configs = parse_yaml(config_yaml) + + query_encoder = CLAP_Encoder().eval() + model = load_ss_model( + configs=configs, + checkpoint_path=checkpoint_path, + query_encoder=query_encoder + ).eval().to(device) + + print(f'Load AudioSep model from [{checkpoint_path}]') + return model + + +def inference(model, audio_file, text, output_file, device='cuda'): + print(f'Separate audio from [{audio_file}] with textual query [{text}]') + mixture, fs = librosa.load(audio_file, sr=32000, mono=True) + with torch.no_grad(): + text = [text] + + conditions = model.query_encoder.get_query_embed( + modality='text', + text=text, + device=device + ) + + input_dict = { + "mixture": torch.Tensor(mixture)[None, None, :].to(device), + "condition": conditions, + } + + sep_segment = model.ss_model(input_dict)["waveform"] + + sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy() + + write(output_file, 32000, np.round(sep_segment * 32767).astype(np.int16)) + print(f'Write separated audio to [{output_file}]') + + +if __name__ == '__main__': + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = build_audiosep( + config_yaml='config/audiosep_base.yaml', + checkpoint_path='checkpoint/step=3920000.ckpt', + device=device) + + audio_file = '/mnt/bn/data-xubo/project/AudioShop/YT_audios/Y3VHpLxtd498.wav' + text = 'pigeons are cooing in the background' + output_file='separated_audio.wav' + + inference(model, audio_file, text, output_file, device) + + + + + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..231b8465a0841d36fe5ef896cb8cba1b0cfada6e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +gdown +lightning==2.0.0 +transformers==4.28.1 +ftfy==6.1.1 +braceexpand==0.1.7 +webdataset==0.2.48 +soundfile==0.12.1 +wget==3.2 
+h5py==3.8.0 +gradio==3.47.1 +torchlibrosa==0.1.0 \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..acde85b20c7e1abd4b5f8fc732470a80c8428d82 --- /dev/null +++ b/train.py @@ -0,0 +1,307 @@ +import argparse +import logging +import os +import pathlib +from typing import List, NoReturn +import lightning.pytorch as pl +from lightning.pytorch.strategies import DDPStrategy +from torch.utils.tensorboard import SummaryWriter +from data.datamodules import * +from utils import create_logging, parse_yaml +from models.resunet import * +from losses import get_loss_function +from models.audiosep import AudioSep, get_model_class +from data.waveform_mixers import SegmentMixer +from models.clap_encoder import CLAP_Encoder +from callbacks.base import CheckpointEveryNSteps +from optimizers.lr_schedulers import get_lr_lambda + + +def get_dirs( + workspace: str, + filename: str, + config_yaml: str, + devices_num: int +) -> List[str]: + r"""Get directories and paths. + + Args: + workspace (str): directory of workspace + filename (str): filename of current .py file. + config_yaml (str): config yaml path + devices_num (int): 0 for cpu and 8 for training with 8 GPUs + + Returns: + checkpoints_dir (str): directory to save checkpoints + logs_dir (str), directory to save logs + tf_logs_dir (str), directory to save TensorBoard logs + statistics_path (str), directory to save statistics + """ + + os.makedirs(workspace, exist_ok=True) + + yaml_name = pathlib.Path(config_yaml).stem + + # Directory to save checkpoints + checkpoints_dir = os.path.join( + workspace, + "checkpoints", + filename, + "{},devices={}".format(yaml_name, devices_num), + ) + os.makedirs(checkpoints_dir, exist_ok=True) + + # Directory to save logs + logs_dir = os.path.join( + workspace, + "logs", + filename, + "{},devices={}".format(yaml_name, devices_num), + ) + os.makedirs(logs_dir, exist_ok=True) + + # Directory to save TensorBoard logs + create_logging(logs_dir, filemode="w") + logging.info(args) + + tf_logs_dir = os.path.join( + workspace, + "tf_logs", + filename, + "{},devices={}".format(yaml_name, devices_num), + ) + + # Directory to save statistics + statistics_path = os.path.join( + workspace, + "statistics", + filename, + "{},devices={}".format(yaml_name, devices_num), + "statistics.pkl", + ) + os.makedirs(os.path.dirname(statistics_path), exist_ok=True) + + return checkpoints_dir, logs_dir, tf_logs_dir, statistics_path + + +def get_data_module( + config_yaml: str, + num_workers: int, + batch_size: int, +) -> DataModule: + r"""Create data_module. 
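+ Wraps an AudioTextDataset built from the datafiles listed in the config
+ yaml into a DataModule with the given batch_size and num_workers.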
Mini-batch data can be obtained by: + + code-block:: python + + data_module.setup() + + for batch_data_dict in data_module.train_dataloader(): + print(batch_data_dict.keys()) + break + + Args: + workspace: str + config_yaml: str + num_workers: int, e.g., 0 for non-parallel and 8 for using cpu cores + for preparing data in parallel + distributed: bool + + Returns: + data_module: DataModule + """ + + # read configurations + configs = parse_yaml(config_yaml) + sampling_rate = configs['data']['sampling_rate'] + segment_seconds = configs['data']['segment_seconds'] + + # audio-text datasets + datafiles = configs['data']['datafiles'] + + # dataset + dataset = AudioTextDataset( + datafiles=datafiles, + sampling_rate=sampling_rate, + max_clip_len=segment_seconds, + ) + + + # data module + data_module = DataModule( + train_dataset=dataset, + num_workers=num_workers, + batch_size=batch_size + ) + + return data_module + + +def train(args) -> NoReturn: + r"""Train, evaluate, and save checkpoints. + + Args: + workspace: str, directory of workspace + gpus: int, number of GPUs to train + config_yaml: str + """ + + # arguments & parameters + workspace = args.workspace + config_yaml = args.config_yaml + filename = args.filename + + devices_num = torch.cuda.device_count() + # Read config file. + configs = parse_yaml(config_yaml) + + # Configuration of data + max_mix_num = configs['data']['max_mix_num'] + sampling_rate = configs['data']['sampling_rate'] + lower_db = configs['data']['loudness_norm']['lower_db'] + higher_db = configs['data']['loudness_norm']['higher_db'] + + # Configuration of the separation model + query_net = configs['model']['query_net'] + model_type = configs['model']['model_type'] + input_channels = configs['model']['input_channels'] + output_channels = configs['model']['output_channels'] + condition_size = configs['model']['condition_size'] + use_text_ratio = configs['model']['use_text_ratio'] + + # Configuration of the trainer + num_nodes = configs['train']['num_nodes'] + batch_size = configs['train']['batch_size_per_device'] + sync_batchnorm = configs['train']['sync_batchnorm'] + num_workers = configs['train']['num_workers'] + loss_type = configs['train']['loss_type'] + optimizer_type = configs["train"]["optimizer"]["optimizer_type"] + learning_rate = float(configs['train']["optimizer"]['learning_rate']) + lr_lambda_type = configs['train']["optimizer"]['lr_lambda_type'] + warm_up_steps = configs['train']["optimizer"]['warm_up_steps'] + reduce_lr_steps = configs['train']["optimizer"]['reduce_lr_steps'] + save_step_frequency = configs['train']['save_step_frequency'] + resume_checkpoint_path = args.resume_checkpoint_path + if resume_checkpoint_path == "": + resume_checkpoint_path = None + else: + logging.info(f'Finetuning AudioSep with checkpoint [{resume_checkpoint_path}]') + + # Get directories and paths + checkpoints_dir, logs_dir, tf_logs_dir, statistics_path = get_dirs( + workspace, filename, config_yaml, devices_num, + ) + + logging.info(configs) + + # data module + data_module = get_data_module( + config_yaml=config_yaml, + batch_size=batch_size, + num_workers=num_workers, + ) + + # model + Model = get_model_class(model_type=model_type) + + ss_model = Model( + input_channels=input_channels, + output_channels=output_channels, + condition_size=condition_size, + ) + + # loss function + loss_function = get_loss_function(loss_type) + + segment_mixer = SegmentMixer( + max_mix_num=max_mix_num, + lower_db=lower_db, + higher_db=higher_db + ) + + + if query_net == 'CLAP': + query_encoder = 
CLAP_Encoder() + else: + raise NotImplementedError + + lr_lambda_func = get_lr_lambda( + lr_lambda_type=lr_lambda_type, + warm_up_steps=warm_up_steps, + reduce_lr_steps=reduce_lr_steps, + ) + + # pytorch-lightning model + pl_model = AudioSep( + ss_model=ss_model, + waveform_mixer=segment_mixer, + query_encoder=query_encoder, + loss_function=loss_function, + optimizer_type=optimizer_type, + learning_rate=learning_rate, + lr_lambda_func=lr_lambda_func, + use_text_ratio=use_text_ratio + ) + + checkpoint_every_n_steps = CheckpointEveryNSteps( + checkpoints_dir=checkpoints_dir, + save_step_frequency=save_step_frequency, + ) + + summary_writer = SummaryWriter(log_dir=tf_logs_dir) + + callbacks = [checkpoint_every_n_steps] + + trainer = pl.Trainer( + accelerator='auto', + devices='auto', + strategy='ddp_find_unused_parameters_true', + num_nodes=num_nodes, + precision="32-true", + logger=None, + callbacks=callbacks, + fast_dev_run=False, + max_epochs=-1, + log_every_n_steps=50, + use_distributed_sampler=True, + sync_batchnorm=sync_batchnorm, + num_sanity_val_steps=2, + enable_checkpointing=False, + enable_progress_bar=True, + enable_model_summary=True, + ) + + # Fit, evaluate, and save checkpoints. + trainer.fit( + model=pl_model, + train_dataloaders=None, + val_dataloaders=None, + datamodule=data_module, + ckpt_path=resume_checkpoint_path, + ) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "--workspace", type=str, required=True, help="Directory of workspace." + ) + parser.add_argument( + "--config_yaml", + type=str, + required=True, + help="Path of config file for training.", + ) + + parser.add_argument( + "--resume_checkpoint_path", + type=str, + required=True, + default='', + help="Path of pretrained checkpoint for finetuning.", + ) + + args = parser.parse_args() + args.filename = pathlib.Path(__file__).stem + + train(args) \ No newline at end of file diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..abfb28500aa2c7f7cf395a869245d4c2061f9ca5 --- /dev/null +++ b/utils.py @@ -0,0 +1,384 @@ +import os +import datetime +import json +import logging +import librosa +import pickle +from typing import Dict +import numpy as np +import torch +import torch.nn as nn +import yaml +from models.audiosep import AudioSep, get_model_class + + +def ignore_warnings(): + import warnings + # Ignore UserWarning from torch.meshgrid + warnings.filterwarnings('ignore', category=UserWarning, module='torch.functional') + + # Refined regex pattern to capture variations in the warning message + pattern = r"Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: \['lm_head\..*'\].*" + warnings.filterwarnings('ignore', message=pattern) + + + +def create_logging(log_dir, filemode): + os.makedirs(log_dir, exist_ok=True) + i1 = 0 + + while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))): + i1 += 1 + + log_path = os.path.join(log_dir, "{:04d}.log".format(i1)) + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s", + datefmt="%a, %d %b %Y %H:%M:%S", + filename=log_path, + filemode=filemode, + ) + + # Print to console + console = logging.StreamHandler() + console.setLevel(logging.INFO) + formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s") + console.setFormatter(formatter) + logging.getLogger("").addHandler(console) + + return logging + + +def float32_to_int16(x: float) -> int: + x = 
np.clip(x, a_min=-1, a_max=1) + return (x * 32767.0).astype(np.int16) + + +def int16_to_float32(x: int) -> float: + return (x / 32767.0).astype(np.float32) + + +def parse_yaml(config_yaml: str) -> Dict: + r"""Parse yaml file. + + Args: + config_yaml (str): config yaml path + + Returns: + yaml_dict (Dict): parsed yaml file + """ + + with open(config_yaml, "r") as fr: + return yaml.load(fr, Loader=yaml.FullLoader) + + +def get_audioset632_id_to_lb(ontology_path: str) -> Dict: + r"""Get AudioSet 632 classes ID to label mapping.""" + + audioset632_id_to_lb = {} + + with open(ontology_path) as f: + data_list = json.load(f) + + for e in data_list: + audioset632_id_to_lb[e["id"]] = e["name"] + + return audioset632_id_to_lb + + +def load_pretrained_panns( + model_type: str, + checkpoint_path: str, + freeze: bool +) -> nn.Module: + r"""Load pretrained pretrained audio neural networks (PANNs). + + Args: + model_type: str, e.g., "Cnn14" + checkpoint_path, str, e.g., "Cnn14_mAP=0.431.pth" + freeze: bool + + Returns: + model: nn.Module + """ + + if model_type == "Cnn14": + Model = Cnn14 + + elif model_type == "Cnn14_DecisionLevelMax": + Model = Cnn14_DecisionLevelMax + + else: + raise NotImplementedError + + model = Model(sample_rate=32000, window_size=1024, hop_size=320, + mel_bins=64, fmin=50, fmax=14000, classes_num=527) + + if checkpoint_path: + checkpoint = torch.load(checkpoint_path, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + + if freeze: + for param in model.parameters(): + param.requires_grad = False + + return model + + +def energy(x): + return torch.mean(x ** 2) + + +def magnitude_to_db(x): + eps = 1e-10 + return 20. * np.log10(max(x, eps)) + + +def db_to_magnitude(x): + return 10. ** (x / 20) + + +def ids_to_hots(ids, classes_num, device): + hots = torch.zeros(classes_num).to(device) + for id in ids: + hots[id] = 1 + return hots + + +def calculate_sdr( + ref: np.ndarray, + est: np.ndarray, + eps=1e-10 +) -> float: + r"""Calculate SDR between reference and estimation. + + Args: + ref (np.ndarray), reference signal + est (np.ndarray), estimated signal + """ + reference = ref + noise = est - reference + + + numerator = np.clip(a=np.mean(reference ** 2), a_min=eps, a_max=None) + + denominator = np.clip(a=np.mean(noise ** 2), a_min=eps, a_max=None) + + sdr = 10. * np.log10(numerator / denominator) + + return sdr + + +def calculate_sisdr(ref, est): + r"""Calculate SDR between reference and estimation. 
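+ This is the scale-invariant SDR (SI-SDR): the reference is first rescaled
+ by a = (<ref, est> + eps) / (<ref, ref> + eps), and the result is
+ 10 * log10((||a * ref||^2 + eps) / (||est - a * ref||^2 + eps)).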
+ + Args: + ref (np.ndarray), reference signal + est (np.ndarray), estimated signal + """ + + eps = np.finfo(ref.dtype).eps + + reference = ref.copy() + estimate = est.copy() + + reference = reference.reshape(reference.size, 1) + estimate = estimate.reshape(estimate.size, 1) + + Rss = np.dot(reference.T, reference) + # get the scaling factor for clean sources + a = (eps + np.dot(reference.T, estimate)) / (Rss + eps) + + e_true = a * reference + e_res = estimate - e_true + + Sss = (e_true**2).sum() + Snn = (e_res**2).sum() + + sisdr = 10 * np.log10((eps+ Sss)/(eps + Snn)) + + return sisdr + + +class StatisticsContainer(object): + def __init__(self, statistics_path): + self.statistics_path = statistics_path + + self.backup_statistics_path = "{}_{}.pkl".format( + os.path.splitext(self.statistics_path)[0], + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), + ) + + self.statistics_dict = {"balanced_train": [], "test": []} + + def append(self, steps, statistics, split, flush=True): + statistics["steps"] = steps + self.statistics_dict[split].append(statistics) + + if flush: + self.flush() + + def flush(self): + pickle.dump(self.statistics_dict, open(self.statistics_path, "wb")) + pickle.dump(self.statistics_dict, open(self.backup_statistics_path, "wb")) + logging.info(" Dump statistics to {}".format(self.statistics_path)) + logging.info(" Dump statistics to {}".format(self.backup_statistics_path)) + + +def get_mean_sdr_from_dict(sdris_dict): + mean_sdr = np.nanmean(list(sdris_dict.values())) + return mean_sdr + + +def remove_silence(audio: np.ndarray, sample_rate: int) -> np.ndarray: + r"""Remove silent frames.""" + window_size = int(sample_rate * 0.1) + threshold = 0.02 + + frames = librosa.util.frame(x=audio, frame_length=window_size, hop_length=window_size).T + # shape: (frames_num, window_size) + + new_frames = get_active_frames(frames, threshold) + # shape: (new_frames_num, window_size) + + new_audio = new_frames.flatten() + # shape: (new_audio_samples,) + + return new_audio + + +def get_active_frames(frames: np.ndarray, threshold: float) -> np.ndarray: + r"""Get active frames.""" + + energy = np.max(np.abs(frames), axis=-1) + # shape: (frames_num,) + + active_indexes = np.where(energy > threshold)[0] + # shape: (new_frames_num,) + + new_frames = frames[active_indexes] + # shape: (new_frames_num,) + + return new_frames + + +def repeat_to_length(audio: np.ndarray, segment_samples: int) -> np.ndarray: + r"""Repeat audio to length.""" + + repeats_num = (segment_samples // audio.shape[-1]) + 1 + audio = np.tile(audio, repeats_num)[0 : segment_samples] + + return audio + +def calculate_segmentwise_sdr(ref, est, hop_samples, return_sdr_list=False): + min_len = min(ref.shape[-1], est.shape[-1]) + pointer = 0 + sdrs = [] + while pointer + hop_samples < min_len: + sdr = calculate_sdr( + ref=ref[:, pointer : pointer + hop_samples], + est=est[:, pointer : pointer + hop_samples], + ) + sdrs.append(sdr) + pointer += hop_samples + + sdr = np.nanmedian(sdrs) + + if return_sdr_list: + return sdr, sdrs + else: + return sdr + + +def loudness(data, input_loudness, target_loudness): + """ Loudness normalize a signal. + + Normalize an input signal to a user loudness in dB LKFS. + + Params + ------- + data : torch.Tensor + Input multichannel audio data. + input_loudness : float + Loudness of the input in dB LUFS. + target_loudness : float + Target loudness of the output in dB LUFS. + + Returns + ------- + output : torch.Tensor + Loudness normalized output data. 
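+
+ Example
+ -------
+ Raising a signal from -30 dB LUFS to -23 dB LUFS applies a gain of
+ 10 ** ((-23 - (-30)) / 20.0), i.e. about 2.24.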
+ """ + + # calculate the gain needed to scale to the desired loudness level + delta_loudness = target_loudness - input_loudness + gain = torch.pow(10.0, delta_loudness / 20.0) + + output = gain * data + + # check for potentially clipped samples + # if torch.max(torch.abs(output)) >= 1.0: + # warnings.warn("Possible clipped samples in output.") + + return output + + +def load_ss_model( + configs: Dict, + checkpoint_path: str, + query_encoder: nn.Module +) -> nn.Module: + r"""Load trained universal source separation model. + + Args: + configs (Dict) + checkpoint_path (str): path of the checkpoint to load + device (str): e.g., "cpu" | "cuda" + + Returns: + pl_model: pl.LightningModule + """ + + ss_model_type = configs["model"]["model_type"] + input_channels = configs["model"]["input_channels"] + output_channels = configs["model"]["output_channels"] + condition_size = configs["model"]["condition_size"] + + # Initialize separation model + SsModel = get_model_class(model_type=ss_model_type) + + ss_model = SsModel( + input_channels=input_channels, + output_channels=output_channels, + condition_size=condition_size, + ) + + # Load PyTorch Lightning model + pl_model = AudioSep.load_from_checkpoint( + checkpoint_path=checkpoint_path, + strict=False, + ss_model=ss_model, + waveform_mixer=None, + query_encoder=query_encoder, + loss_function=None, + optimizer_type=None, + learning_rate=None, + lr_lambda_func=None, + map_location=torch.device('cpu'), + ) + + return pl_model + + +def parse_yaml(config_yaml: str) -> Dict: + r"""Parse yaml file. + + Args: + config_yaml (str): config yaml path + + Returns: + yaml_dict (Dict): parsed yaml file + """ + + with open(config_yaml, "r") as fr: + return yaml.load(fr, Loader=yaml.FullLoader) \ No newline at end of file