diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..c7d9f3332a950355d5a77d85000f05e6f45435ea --- /dev/null +++ b/.gitattributes @@ -0,0 +1,34 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..cc6825672b8a19e9c5b089612d8c6c6083126df7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.pyc +__pycache__ +test.py +flagged \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a197a70ddf44e8dfe5dbefd2736181ffdf329c41 --- /dev/null +++ b/LICENSE @@ -0,0 +1,13 @@ +“Commons Clause” License Condition v1.0 + +The Software is provided to you by the Licensor under the License, as defined below, subject to the following condition. + +Without limiting other conditions in the License, the grant of rights under the License will not include, and the License does not grant to you, the right to Sell the Software. + +For purposes of the foregoing, “Sell” means practicing any or all of the rights granted to you under the License to provide to third parties, for a fee or other consideration (including without limitation fees for hosting or consulting/ support services related to the Software), a product or service whose value derives, entirely or substantially, from the functionality of the Software. Any license notice or attribution required by the License must also include this Commons Clause License Condition notice. + +Software: AudioLDM (including all related model and software) + +License: Apache 2.0 + +Licensor: Haohe Liu \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c870fd68dc0208bf0edb387bc27d465f1c2fbbb0 --- /dev/null +++ b/README.md @@ -0,0 +1,14 @@ +--- +title: Audioldm Text To Audio Generation +emoji: 🔊 +colorFrom: indigo +colorTo: red +sdk: gradio +sdk_version: 3.16.2 +app_file: app.py +pinned: false +license: bigscience-openrail-m +duplicated_from: dawood/audioldm-text-to-audio-generation +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..7d21ec322d38cd4b957ac958d12215debad34586 --- /dev/null +++ b/app.py @@ -0,0 +1,235 @@ +import gradio as gr +import numpy as np +from audioldm import text_to_audio, build_model +from share_btn import community_icon_html, loading_icon_html, share_js + +model_id="haoheliu/AudioLDM-S-Full" + +audioldm = build_model() +# audioldm=None + +# def predict(input, history=[]): +# # tokenize the new input sentence +# new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt') + +# # append the new user input tokens to the chat history +# bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1) + +# # generate a response +# history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist() + +# # convert the tokens to text, and then split the responses into lines +# response = tokenizer.decode(history[0]).split("<|endoftext|>") +# response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list +# return response, history + +def text2audio(text, duration, guidance_scale, random_seed, n_candidates): + # print(text, length, guidance_scale) + waveform = text_to_audio(audioldm, text, random_seed, duration=duration, guidance_scale=guidance_scale, n_candidate_gen_per_text=int(n_candidates)) # [bs, 1, samples] + waveform = [gr.make_waveform((16000, wave[0])) for wave in waveform] + # waveform = [(16000, np.random.randn(16000)), (16000, np.random.randn(16000))] + if(len(waveform) == 1): + waveform = waveform[0] + return waveform,gr.update(visible=True), gr.update(visible=True), gr.update(visible=True) + +# iface = gr.Interface(fn=text2audio, inputs=[ +# gr.Textbox(value="A man is speaking in a huge room", max_lines=1), +# gr.Slider(2.5, 10, value=5, step=2.5), +# gr.Slider(0, 5, value=2.5, step=0.5), +# gr.Number(value=42) +# ], outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")], +# allow_flagging="never" +# ) +# iface.launch(share=True) + +css = """ + .gradio-container { + font-family: 'IBM Plex Sans', sans-serif; + } + .gr-button { + color: white; + border-color: black; + background: black; + } + input[type='range'] { + accent-color: black; + } + .dark input[type='range'] { + accent-color: #dfdfdf; + } + .container { + max-width: 730px; + margin: auto; + padding-top: 1.5rem; + } + #gallery { + min-height: 22rem; + margin-bottom: 15px; + margin-left: auto; + margin-right: auto; + border-bottom-right-radius: .5rem !important; + border-bottom-left-radius: .5rem !important; + } + #gallery>div>.h-full { + min-height: 20rem; + } + .details:hover { + text-decoration: underline; + } + .gr-button { + white-space: nowrap; + } + .gr-button:focus { + border-color: rgb(147 197 253 / var(--tw-border-opacity)); + outline: none; + box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); + --tw-border-opacity: 1; + --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); + --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color); + --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity)); + --tw-ring-opacity: .5; + } + #advanced-btn { + font-size: .7rem !important; + line-height: 19px; + margin-top: 12px; + margin-bottom: 12px; + padding: 2px 8px; + border-radius: 14px !important; + } + #advanced-options { + display: none; + margin-bottom: 20px; + } + .footer { + margin-bottom: 45px; + margin-top: 35px; + text-align: center; + border-bottom: 1px solid #e5e5e5; + } + .footer>p { + font-size: .8rem; + display: inline-block; + padding: 0 10px; + transform: translateY(10px); + background: white; + } + .dark .footer { + border-color: #303030; + } + .dark .footer>p { + background: #0b0f19; + } + .acknowledgments h4{ + margin: 1.25em 0 .25em 0; + font-weight: bold; + font-size: 115%; + } + .animate-spin { + animation: spin 1s linear infinite; + } + @keyframes spin { + from { + transform: rotate(0deg); + } + to { + transform: rotate(360deg); + } + } + #share-btn-container { + display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem; + margin-top: 10px; + margin-left: auto; + } + #share-btn { + all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0; + } + #share-btn * { + all: unset; + } + #share-btn-container div:nth-child(-n+2){ + width: auto !important; + min-height: 0px !important; + } + #share-btn-container .wrap { + display: none !important; + } + + .gr-form{ + flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0; + } + #prompt-container{ + gap: 0; + } + #prompt-text-input, #negative-prompt-text-input{padding: .45rem 0.625rem} + #component-16{border-top-width: 1px!important;margin-top: 1em} + .image_duplication{position: absolute; width: 100px; left: 50px} +""" +iface = gr.Blocks(css=css) + +with iface: + gr.HTML( + """ +
+
+

+ AudioLDM: Text-to-Audio Generation with Latent Diffusion Models +

+
+

+ [Paper] [Project page] +

+
+ """ + ) + with gr.Group(): + with gr.Box(): + ############# Input + textbox = gr.Textbox(value="A hammer is hitting a wooden surface", max_lines=1, label="Input your text here. Please ensure it is descriptive and of moderate length.") + + with gr.Accordion("Advanced Options", open=False): + seed = gr.Number(value=42, label="Change this value (any integer number) will lead to a different generation result.") + duration = gr.Slider(2.5, 10, value=5, step=2.5, label="Duration (seconds)") + guidance_scale = gr.Slider(0, 5, value=2.5, step=0.5, label="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)") + n_candidates = gr.Slider(1, 5, value=3, step=1, label="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation") + ############# Output + # outputs=gr.Audio(label="Output", type="numpy") + outputs=gr.Video(label="Output") + with gr.Group(elem_id="container-advanced-btns"): + # advanced_button = gr.Button("Advanced options", elem_id="advanced-btn") + with gr.Group(elem_id="share-btn-container"): + community_icon = gr.HTML(community_icon_html, visible=False) + loading_icon = gr.HTML(loading_icon_html, visible=False) + share_button = gr.Button("Share to community", elem_id="share-btn", visible=False) + # outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")] + + btn = gr.Button("Submit").style(full_width=True) + btn.click(text2audio, inputs=[textbox, duration, guidance_scale, seed, n_candidates], outputs=[outputs, community_icon, loading_icon, share_button]) # , share_button, community_icon, loading_icon + share_button.click(None, [], [], _js=share_js) + gr.HTML(''' +
+ + ''') + + with gr.Accordion("Additional information", open=False): + gr.HTML( + """ +
+

We build the model with data from AudioSet, Freesound and BBC Sound Effect library. We share this demo based on the UK copyright exception of data for academic research.

+

This demo is strictly for research demo purpose only. For commercial use please contact us.

+
+ """ + ) + +iface.queue(concurrency_count = 2) +iface.launch(debug=True) +# iface.launch(debug=True, share=True) \ No newline at end of file diff --git a/audioldm/__init__.py b/audioldm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2f93cab80ded8e7239bb96eb6e364c3fd4fb46d9 --- /dev/null +++ b/audioldm/__init__.py @@ -0,0 +1,3 @@ +from .ldm import LatentDiffusion +from .utils import seed_everything +from .pipeline import * \ No newline at end of file diff --git a/audioldm/audio/__init__.py b/audioldm/audio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm/audio/audio_processing.py b/audioldm/audio/audio_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..77a4057aa82f226f68474f4c2a19eba84510d663 --- /dev/null +++ b/audioldm/audio/audio_processing.py @@ -0,0 +1,100 @@ +import torch +import numpy as np +import librosa.util as librosa_util +from scipy.signal import get_window + + +def window_sumsquare( + window, + n_frames, + hop_length, + win_length, + n_fft, + dtype=np.float32, + norm=None, +): + """ + # from librosa 0.6 + Compute the sum-square envelope of a window function at a given hop length. + + This is used to estimate modulation effects induced by windowing + observations in short-time fourier transforms. + + Parameters + ---------- + window : string, tuple, number, callable, or list-like + Window specification, as in `get_window` + + n_frames : int > 0 + The number of analysis frames + + hop_length : int > 0 + The number of samples to advance between frames + + win_length : [optional] + The length of the window function. By default, this matches `n_fft`. + + n_fft : int > 0 + The length of each analysis frame. + + dtype : np.dtype + The data type of the output + + Returns + ------- + wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` + The sum-squared envelope of the window function + """ + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length, fftbins=True) + win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 + win_sq = librosa_util.pad_center(win_sq, n_fft) + + # Fill the envelope + for i in range(n_frames): + sample = i * hop_length + x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] + return x + + +def griffin_lim(magnitudes, stft_fn, n_iters=30): + """ + PARAMS + ------ + magnitudes: spectrogram magnitudes + stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods + """ + + angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) + angles = angles.astype(np.float32) + angles = torch.autograd.Variable(torch.from_numpy(angles)) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + + for i in range(n_iters): + _, angles = stft_fn.transform(signal) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + return signal + + +def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return normalize_fun(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C diff --git a/audioldm/audio/stft.py b/audioldm/audio/stft.py new file mode 100644 index 0000000000000000000000000000000000000000..2aa1ac89277734a6676c20a81bf88e21e8ca7aa9 --- /dev/null +++ b/audioldm/audio/stft.py @@ -0,0 +1,180 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy.signal import get_window +from librosa.util import pad_center, tiny +from librosa.filters import mel as librosa_mel_fn + +from audioldm.audio.audio_processing import ( + dynamic_range_compression, + dynamic_range_decompression, + window_sumsquare, +) + + +class STFT(torch.nn.Module): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + + def __init__(self, filter_length, hop_length, win_length, window="hann"): + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.forward_transform = None + scale = self.filter_length / self.hop_length + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack( + [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] + ) + + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + inverse_basis = torch.FloatTensor( + np.linalg.pinv(scale * fourier_basis).T[:, None, :] + ) + + if window is not None: + assert filter_length >= win_length + # get window and zero center pad it to filter_length + fft_window = get_window(window, win_length, fftbins=True) + fft_window = pad_center(fft_window, filter_length) + fft_window = torch.from_numpy(fft_window).float() + + # window the bases + forward_basis *= fft_window + inverse_basis *= fft_window + + self.register_buffer("forward_basis", forward_basis.float()) + self.register_buffer("inverse_basis", inverse_basis.float()) + + def transform(self, input_data): + num_batches = input_data.size(0) + num_samples = input_data.size(1) + + self.num_samples = num_samples + + # similar to librosa, reflect-pad the input + input_data = input_data.view(num_batches, 1, num_samples) + input_data = F.pad( + input_data.unsqueeze(1), + (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), + mode="reflect", + ) + input_data = input_data.squeeze(1) + + forward_transform = F.conv1d( + input_data, + torch.autograd.Variable(self.forward_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + ).cpu() + + cutoff = int((self.filter_length / 2) + 1) + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + + magnitude = torch.sqrt(real_part**2 + imag_part**2) + phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) + + return magnitude, phase + + def inverse(self, magnitude, phase): + recombine_magnitude_phase = torch.cat( + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 + ) + + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + torch.autograd.Variable(self.inverse_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + ) + + if self.window is not None: + window_sum = window_sumsquare( + self.window, + magnitude.size(-1), + hop_length=self.hop_length, + win_length=self.win_length, + n_fft=self.filter_length, + dtype=np.float32, + ) + # remove modulation effects + approx_nonzero_indices = torch.from_numpy( + np.where(window_sum > tiny(window_sum))[0] + ) + window_sum = torch.autograd.Variable( + torch.from_numpy(window_sum), requires_grad=False + ) + window_sum = window_sum + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ + approx_nonzero_indices + ] + + # scale by hop ratio + inverse_transform *= float(self.filter_length) / self.hop_length + + inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] + inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] + + return inverse_transform + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction + + +class TacotronSTFT(torch.nn.Module): + def __init__( + self, + filter_length, + hop_length, + win_length, + n_mel_channels, + sampling_rate, + mel_fmin, + mel_fmax, + ): + super(TacotronSTFT, self).__init__() + self.n_mel_channels = n_mel_channels + self.sampling_rate = sampling_rate + self.stft_fn = STFT(filter_length, hop_length, win_length) + mel_basis = librosa_mel_fn( + sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax + ) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + + def spectral_normalize(self, magnitudes, normalize_fun): + output = dynamic_range_compression(magnitudes, normalize_fun) + return output + + def spectral_de_normalize(self, magnitudes): + output = dynamic_range_decompression(magnitudes) + return output + + def mel_spectrogram(self, y, normalize_fun=torch.log): + """Computes mel-spectrograms from a batch of waves + PARAMS + ------ + y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] + + RETURNS + ------- + mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) + """ + assert torch.min(y.data) >= -1, torch.min(y.data) + assert torch.max(y.data) <= 1, torch.max(y.data) + + magnitudes, phases = self.stft_fn.transform(y) + magnitudes = magnitudes.data + mel_output = torch.matmul(self.mel_basis, magnitudes) + mel_output = self.spectral_normalize(mel_output, normalize_fun) + energy = torch.norm(magnitudes, dim=1) + + log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun) + + return mel_output, log_magnitudes, energy diff --git a/audioldm/audio/tools.py b/audioldm/audio/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..7aca95cc1f5c120568a210907e9506589899a1c6 --- /dev/null +++ b/audioldm/audio/tools.py @@ -0,0 +1,33 @@ +import torch +import numpy as np + + +def get_mel_from_wav(audio, _stft): + audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) + audio = torch.autograd.Variable(audio, requires_grad=False) + melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio) + melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) + log_magnitudes_stft = ( + torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32) + ) + energy = torch.squeeze(energy, 0).numpy().astype(np.float32) + return melspec, log_magnitudes_stft, energy + + +# def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60): +# mel = torch.stack([mel]) +# mel_decompress = _stft.spectral_de_normalize(mel) +# mel_decompress = mel_decompress.transpose(1, 2).data.cpu() +# spec_from_mel_scaling = 1000 +# spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) +# spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) +# spec_from_mel = spec_from_mel * spec_from_mel_scaling + +# audio = griffin_lim( +# torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters +# ) + +# audio = audio.squeeze() +# audio = audio.cpu().numpy() +# audio_path = out_filename +# write(audio_path, _stft.sampling_rate, audio) diff --git a/audioldm/clap/__init__.py b/audioldm/clap/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm/clap/encoders.py b/audioldm/clap/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..8f42d9832c31fa51cf361c472c639d68869769f1 --- /dev/null +++ b/audioldm/clap/encoders.py @@ -0,0 +1,169 @@ +import torch +import torch.nn as nn +from audioldm.clap.open_clip import create_model +from audioldm.clap.training.data import get_audio_features +import torchaudio +from transformers import RobertaTokenizer +import torch.nn.functional as F + + +class CLAPAudioEmbeddingClassifierFreev2(nn.Module): + def __init__( + self, + pretrained_path="", + key="class", + sampling_rate=16000, + embed_mode="audio", + unconditional_prob=0.1, + random_mute=False, + max_random_mute_portion=0.5, + training_mode=True, + ): + super().__init__() + + self.key = key + self.device = "cpu" + self.precision = "fp32" + self.amodel = "HTSAT-tiny" # or 'PANN-14' + self.tmodel = "roberta" # the best text encoder in our training + self.enable_fusion = False # False if you do not want to use the fusion model + self.fusion_type = "aff_2d" + self.pretrained = pretrained_path + self.embed_mode = embed_mode + self.embed_mode_orig = embed_mode + self.sampling_rate = sampling_rate + self.unconditional_prob = unconditional_prob + self.random_mute = random_mute + self.tokenize = RobertaTokenizer.from_pretrained("roberta-base") + self.max_random_mute_portion = max_random_mute_portion + self.training_mode = training_mode + self.model, self.model_cfg = create_model( + self.amodel, + self.tmodel, + self.pretrained, + precision=self.precision, + device=self.device, + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + for p in self.model.parameters(): + p.requires_grad = False + + self.model.eval() + + def get_unconditional_condition(self, batchsize): + self.unconditional_token = self.model.get_text_embedding( + self.tokenizer(["", ""]) + )[0:1] + return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0) + + def batch_to_list(self, batch): + ret = [] + for i in range(batch.size(0)): + ret.append(batch[i]) + return ret + + def make_decision(self, probability): + if float(torch.rand(1)) < probability: + return True + else: + return False + + def random_uniform(self, start, end): + val = torch.rand(1).item() + return start + (end - start) * val + + def _random_mute(self, waveform): + # waveform: [bs, t-steps] + t_steps = waveform.size(-1) + for i in range(waveform.size(0)): + mute_size = int( + self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion)) + ) + mute_start = int(self.random_uniform(0, t_steps - mute_size)) + waveform[i, mute_start : mute_start + mute_size] = 0 + return waveform + + def cos_similarity(self, waveform, text): + # waveform: [bs, t_steps] + with torch.no_grad(): + self.embed_mode = "audio" + audio_emb = self(waveform.cuda()) + self.embed_mode = "text" + text_emb = self(text) + similarity = F.cosine_similarity(audio_emb, text_emb, dim=2) + return similarity.squeeze() + + def forward(self, batch, key=None): + # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0 + # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0 + if self.model.training == True and not self.training_mode: + print( + "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters." + ) + self.model, self.model_cfg = create_model( + self.amodel, + self.tmodel, + self.pretrained, + precision=self.precision, + device="cuda", + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + for p in self.model.parameters(): + p.requires_grad = False + self.model.eval() + + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + if self.embed_mode == "audio": + with torch.no_grad(): + audio_dict_list = [] + assert ( + self.sampling_rate == 16000 + ), "We only support 16000 sampling rate" + if self.random_mute: + batch = self._random_mute(batch) + # batch: [bs, 1, t-samples] + batch = torchaudio.functional.resample( + batch, orig_freq=self.sampling_rate, new_freq=48000 + ) + for waveform in self.batch_to_list(batch): + audio_dict = {} + audio_dict = get_audio_features( + audio_dict, + waveform, + 480000, + data_truncating="fusion", + data_filling="repeatpad", + audio_cfg=self.model_cfg["audio_cfg"], + ) + audio_dict_list.append(audio_dict) + # [bs, 512] + embed = self.model.get_audio_embedding(audio_dict_list) + elif self.embed_mode == "text": + with torch.no_grad(): + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + text_data = self.tokenizer(batch) + embed = self.model.get_text_embedding(text_data) + + embed = embed.unsqueeze(1) + self.unconditional_token = self.model.get_text_embedding( + self.tokenizer(["", ""]) + )[0:1] + + for i in range(embed.size(0)): + if self.make_decision(self.unconditional_prob): + embed[i] = self.unconditional_token + + # [bs, 1, 512] + return embed.detach() + + def tokenizer(self, text): + result = self.tokenize( + text, + padding="max_length", + truncation=True, + max_length=77, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} diff --git a/audioldm/clap/open_clip/__init__.py b/audioldm/clap/open_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9f728f2f273be5d5fdbec6c6cc41d737176a8c0 --- /dev/null +++ b/audioldm/clap/open_clip/__init__.py @@ -0,0 +1,25 @@ +from .factory import ( + list_models, + create_model, + create_model_and_transforms, + add_model_config, +) +from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics +from .model import ( + CLAP, + CLAPTextCfg, + CLAPVisionCfg, + CLAPAudioCfp, + convert_weights_to_fp16, + trace_model, +) +from .openai import load_openai_model, list_openai_models +from .pretrained import ( + list_pretrained, + list_pretrained_tag_models, + list_pretrained_model_tags, + get_pretrained_url, + download_pretrained, +) +from .tokenizer import SimpleTokenizer, tokenize +from .transform import image_transform diff --git a/audioldm/clap/open_clip/bert.py b/audioldm/clap/open_clip/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..a83d96d2a77ed05198efc05837522bc88d2499cc --- /dev/null +++ b/audioldm/clap/open_clip/bert.py @@ -0,0 +1,40 @@ +from transformers import BertTokenizer, BertModel + +tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") +model = BertModel.from_pretrained("bert-base-uncased") +text = "Replace me by any text you'd like." + + +def bert_embeddings(text): + # text = "Replace me by any text you'd like." + encoded_input = tokenizer(text, return_tensors="pt") + output = model(**encoded_input) + return output + + +from transformers import RobertaTokenizer, RobertaModel + +tokenizer = RobertaTokenizer.from_pretrained("roberta-base") +model = RobertaModel.from_pretrained("roberta-base") +text = "Replace me by any text you'd like." + + +def Roberta_embeddings(text): + # text = "Replace me by any text you'd like." + encoded_input = tokenizer(text, return_tensors="pt") + output = model(**encoded_input) + return output + + +from transformers import BartTokenizer, BartModel + +tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") +model = BartModel.from_pretrained("facebook/bart-base") +text = "Replace me by any text you'd like." + + +def bart_embeddings(text): + # text = "Replace me by any text you'd like." + encoded_input = tokenizer(text, return_tensors="pt") + output = model(**encoded_input) + return output diff --git a/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz b/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/audioldm/clap/open_clip/factory.py b/audioldm/clap/open_clip/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..844f9ca0e12a0ff43ba3e042a3e43530ebe91b8c --- /dev/null +++ b/audioldm/clap/open_clip/factory.py @@ -0,0 +1,277 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path + +import torch + +from .model import CLAP, convert_weights_to_fp16 +from .openai import load_openai_model +from .pretrained import get_pretrained_url, download_pretrained +from .transform import image_transform + +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = (".json",) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f"*{ext}")) + + for cf in config_files: + if os.path.basename(cf)[0] == ".": + continue # Ignore hidden files + + with open(cf, "r") as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = { + k: v + for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])) + } + + +_rescan_model_configs() # initial populate of model config registry + + +def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True): + checkpoint = torch.load(checkpoint_path, map_location=map_location) + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + if skip_params: + if next(iter(state_dict.items()))[0].startswith("module"): + state_dict = {k[7:]: v for k, v in state_dict.items()} + # for k in state_dict: + # if k.startswith('transformer'): + # v = state_dict.pop(k) + # state_dict['text_branch.' + k[12:]] = v + return state_dict + + +def create_model( + amodel_name: str, + tmodel_name: str, + pretrained: str = "", + precision: str = "fp32", + device: torch.device = torch.device("cpu"), + jit: bool = False, + force_quick_gelu: bool = False, + openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"), + skip_params=True, + pretrained_audio: str = "", + pretrained_text: str = "", + enable_fusion: bool = False, + fusion_type: str = "None" + # pretrained_image: bool = False, +): + amodel_name = amodel_name.replace( + "/", "-" + ) # for callers using old naming with / in ViT names + pretrained_orig = pretrained + pretrained = pretrained.lower() + if pretrained == "openai": + if amodel_name in _MODEL_CONFIGS: + logging.info(f"Loading {amodel_name} model config.") + model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) + else: + logging.error( + f"Model config for {amodel_name} not found; available models {list_models()}." + ) + raise RuntimeError(f"Model config for {amodel_name} not found.") + + logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.") + # Hard Code in model name + model_cfg["text_cfg"]["model_type"] = tmodel_name + model = load_openai_model( + "ViT-B-16", + model_cfg, + device=device, + jit=jit, + cache_dir=openai_model_cache_dir, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 + if precision == "amp" or precision == "fp32": + model = model.float() + else: + if amodel_name in _MODEL_CONFIGS: + logging.info(f"Loading {amodel_name} model config.") + model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) + else: + logging.error( + f"Model config for {amodel_name} not found; available models {list_models()}." + ) + raise RuntimeError(f"Model config for {amodel_name} not found.") + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + # if pretrained_image: + # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}): + # # pretrained weight loading for timm models set via vision_cfg + # model_cfg['vision_cfg']['timm_model_pretrained'] = True + # else: + # assert False, 'pretrained image towers currently only supported for timm models' + model_cfg["text_cfg"]["model_type"] = tmodel_name + model_cfg["enable_fusion"] = enable_fusion + model_cfg["fusion_type"] = fusion_type + model = CLAP(**model_cfg) + + if pretrained: + checkpoint_path = "" + url = get_pretrained_url(amodel_name, pretrained) + if url: + checkpoint_path = download_pretrained(url, root=openai_model_cache_dir) + elif os.path.exists(pretrained_orig): + checkpoint_path = pretrained_orig + if checkpoint_path: + logging.info( + f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})." + ) + ckpt = load_state_dict(checkpoint_path, skip_params=True) + model.load_state_dict(ckpt) + param_names = [n for n, p in model.named_parameters()] + # for n in param_names: + # print(n, "\t", "Loaded" if n in ckpt else "Unloaded") + else: + logging.warning( + f"Pretrained weights ({pretrained}) not found for model {amodel_name}." + ) + raise RuntimeError( + f"Pretrained weights ({pretrained}) not found for model {amodel_name}." + ) + + if pretrained_audio: + if amodel_name.startswith("PANN"): + if "Cnn14_mAP" in pretrained_audio: # official checkpoint + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["model"] + keys = list(audio_ckpt.keys()) + for key in keys: + if ( + "spectrogram_extractor" not in key + and "logmel_extractor" not in key + ): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key] = v + elif os.path.basename(pretrained_audio).startswith( + "PANN" + ): # checkpoint trained via HTSAT codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["state_dict"] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith("sed_model"): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key[10:]] = v + elif os.path.basename(pretrained_audio).startswith( + "finetuned" + ): # checkpoint trained via linear probe codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + else: + raise ValueError("Unknown audio checkpoint") + elif amodel_name.startswith("HTSAT"): + if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["state_dict"] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith("sed_model") and ( + "spectrogram_extractor" not in key + and "logmel_extractor" not in key + ): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key[10:]] = v + elif os.path.basename(pretrained_audio).startswith( + "HTSAT" + ): # checkpoint trained via HTSAT codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["state_dict"] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith("sed_model"): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key[10:]] = v + elif os.path.basename(pretrained_audio).startswith( + "finetuned" + ): # checkpoint trained via linear probe codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + else: + raise ValueError("Unknown audio checkpoint") + else: + raise f"this audio encoder pretrained checkpoint is not support" + + model.load_state_dict(audio_ckpt, strict=False) + logging.info( + f"Loading pretrained {amodel_name} weights ({pretrained_audio})." + ) + param_names = [n for n, p in model.named_parameters()] + for n in param_names: + print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded") + + model.to(device=device) + if precision == "fp16": + assert device.type != "cpu" + convert_weights_to_fp16(model) + + if jit: + model = torch.jit.script(model) + + return model, model_cfg + + +def create_model_and_transforms( + model_name: str, + pretrained: str = "", + precision: str = "fp32", + device: torch.device = torch.device("cpu"), + jit: bool = False, + force_quick_gelu: bool = False, + # pretrained_image: bool = False, +): + model = create_model( + model_name, + pretrained, + precision, + device, + jit, + force_quick_gelu=force_quick_gelu, + # pretrained_image=pretrained_image + ) + preprocess_train = image_transform(model.visual.image_size, is_train=True) + preprocess_val = image_transform(model.visual.image_size, is_train=False) + return model, preprocess_train, preprocess_val + + +def list_models(): + """enumerate available model architectures based on config files""" + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """add model config path or file and update registry""" + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() diff --git a/audioldm/clap/open_clip/feature_fusion.py b/audioldm/clap/open_clip/feature_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..dbe4e170e05894c12ebdc36ba1dc1de65e441b89 --- /dev/null +++ b/audioldm/clap/open_clip/feature_fusion.py @@ -0,0 +1,192 @@ +""" +Feature Fusion for Varible-Length Data Processing +AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py +According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021 +""" + +import torch +import torch.nn as nn + + +class DAF(nn.Module): + """ + 直接相加 DirectAddFuse + """ + + def __init__(self): + super(DAF, self).__init__() + + def forward(self, x, residual): + return x + residual + + +class iAFF(nn.Module): + """ + 多特征融合 iAFF + """ + + def __init__(self, channels=64, r=4, type="2D"): + super(iAFF, self).__init__() + inter_channels = int(channels // r) + + if type == "1D": + # 本地注意力 + self.local_att = nn.Sequential( + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + + # 全局注意力 + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + + # 第二次本地注意力 + self.local_att2 = nn.Sequential( + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + # 第二次全局注意力 + self.global_att2 = nn.Sequential( + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + elif type == "2D": + # 本地注意力 + self.local_att = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + # 全局注意力 + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + # 第二次本地注意力 + self.local_att2 = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + # 第二次全局注意力 + self.global_att2 = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + else: + raise f"the type is not supported" + + self.sigmoid = nn.Sigmoid() + + def forward(self, x, residual): + flag = False + xa = x + residual + if xa.size(0) == 1: + xa = torch.cat([xa, xa], dim=0) + flag = True + xl = self.local_att(xa) + xg = self.global_att(xa) + xlg = xl + xg + wei = self.sigmoid(xlg) + xi = x * wei + residual * (1 - wei) + + xl2 = self.local_att2(xi) + xg2 = self.global_att(xi) + xlg2 = xl2 + xg2 + wei2 = self.sigmoid(xlg2) + xo = x * wei2 + residual * (1 - wei2) + if flag: + xo = xo[0].unsqueeze(0) + return xo + + +class AFF(nn.Module): + """ + 多特征融合 AFF + """ + + def __init__(self, channels=64, r=4, type="2D"): + super(AFF, self).__init__() + inter_channels = int(channels // r) + + if type == "1D": + self.local_att = nn.Sequential( + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + elif type == "2D": + self.local_att = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + else: + raise f"the type is not supported." + + self.sigmoid = nn.Sigmoid() + + def forward(self, x, residual): + flag = False + xa = x + residual + if xa.size(0) == 1: + xa = torch.cat([xa, xa], dim=0) + flag = True + xl = self.local_att(xa) + xg = self.global_att(xa) + xlg = xl + xg + wei = self.sigmoid(xlg) + xo = 2 * x * wei + 2 * residual * (1 - wei) + if flag: + xo = xo[0].unsqueeze(0) + return xo diff --git a/audioldm/clap/open_clip/htsat.py b/audioldm/clap/open_clip/htsat.py new file mode 100644 index 0000000000000000000000000000000000000000..3b856c6a43df162116a941f1b5c76e93713b276a --- /dev/null +++ b/audioldm/clap/open_clip/htsat.py @@ -0,0 +1,1308 @@ +# Ke Chen +# knutchen@ucsd.edu +# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION +# Some layers designed on the model +# below codes are based and referred from https://github.com/microsoft/Swin-Transformer +# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf + +import torch +import torch.nn as nn +import torch.nn.functional as F +from itertools import repeat +import collections.abc +import math +import warnings + +from torch.nn.init import _calculate_fan_in_and_fan_out +import torch.utils.checkpoint as checkpoint + +import random + +from torchlibrosa.stft import Spectrogram, LogmelFilterBank +from torchlibrosa.augmentation import SpecAugmentation + +from itertools import repeat +from .utils import do_mixup, interpolate + +from .feature_fusion import iAFF, AFF, DAF + +# from PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + patch_stride=16, + enable_fusion=False, + fusion_type="None", + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patch_stride = to_2tuple(patch_stride) + self.img_size = img_size + self.patch_size = patch_size + self.patch_stride = patch_stride + self.grid_size = ( + img_size[0] // patch_stride[0], + img_size[1] // patch_stride[1], + ) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + padding = ( + (patch_size[0] - patch_stride[0]) // 2, + (patch_size[1] - patch_stride[1]) // 2, + ) + + if (self.enable_fusion) and (self.fusion_type == "channel_map"): + self.proj = nn.Conv2d( + in_chans * 4, + embed_dim, + kernel_size=patch_size, + stride=patch_stride, + padding=padding, + ) + else: + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_stride, + padding=padding, + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + self.mel_conv2d = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=(patch_size[0], patch_size[1] * 3), + stride=(patch_stride[0], patch_stride[1] * 3), + padding=padding, + ) + if self.fusion_type == "daf_2d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_2d": + self.fusion_model = AFF(channels=embed_dim, type="2D") + elif self.fusion_type == "iaff_2d": + self.fusion_model = iAFF(channels=embed_dim, type="2D") + + def forward(self, x, longer_idx=None): + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + global_x = x[:, 0:1, :, :] + + # global processing + B, C, H, W = global_x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + global_x = self.proj(global_x) + TW = global_x.size(-1) + if len(longer_idx) > 0: + # local processing + local_x = x[longer_idx, 1:, :, :].contiguous() + B, C, H, W = local_x.shape + local_x = local_x.view(B * C, 1, H, W) + local_x = self.mel_conv2d(local_x) + local_x = local_x.view( + B, C, local_x.size(1), local_x.size(2), local_x.size(3) + ) + local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3) + TB, TC, TH, _ = local_x.size() + if local_x.size(-1) < TW: + local_x = torch.cat( + [ + local_x, + torch.zeros( + (TB, TC, TH, TW - local_x.size(-1)), + device=global_x.device, + ), + ], + dim=-1, + ) + else: + local_x = local_x[:, :, :, :TW] + + global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x) + x = global_x + else: + B, C, H, W = x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view( + B, H // window_size, W // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r"""Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B_, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + def extra_repr(self): + return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}" + + +# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model +class SwinTransformerBlock(nn.Module): + r"""Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + norm_before_mlp="ln", + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.norm_before_mlp = norm_before_mlp + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + if self.norm_before_mlp == "ln": + self.norm2 = nn.LayerNorm(dim) + elif self.norm_before_mlp == "bn": + self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose( + 1, 2 + ) + else: + raise NotImplementedError + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill( + attn_mask != 0, float(-100.0) + ).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + # pdb.set_trace() + H, W = self.input_resolution + # print("H: ", H) + # print("W: ", W) + # pdb.set_trace() + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) + else: + shifted_x = x + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows, attn = self.attn( + x_windows, mask=self.attn_mask + ) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x, attn + + def extra_repr(self): + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + ) + + +class PatchMerging(nn.Module): + r"""Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self): + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + norm_before_mlp="ln", + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + norm_layer=norm_layer, + norm_before_mlp=norm_before_mlp, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=norm_layer + ) + else: + self.downsample = None + + def forward(self, x): + attns = [] + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x, attn = blk(x) + if not self.training: + attns.append(attn.unsqueeze(0)) + if self.downsample is not None: + x = self.downsample(x) + if not self.training: + attn = torch.cat(attns, dim=0) + attn = torch.mean(attn, dim=0) + return x, attn + + def extra_repr(self): + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +# The Core of HTSAT +class HTSAT_Swin_Transformer(nn.Module): + r"""HTSAT based on the Swin Transformer + Args: + spec_size (int | tuple(int)): Input Spectrogram size. Default 256 + patch_size (int | tuple(int)): Patch size. Default: 4 + path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4 + in_chans (int): Number of input image channels. Default: 1 (mono) + num_classes (int): Number of classes for classification head. Default: 527 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 8 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + config (module): The configuration Module from config.py + """ + + def __init__( + self, + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + in_chans=1, + num_classes=527, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + norm_before_mlp="ln", + config=None, + enable_fusion=False, + fusion_type="None", + **kwargs, + ): + super(HTSAT_Swin_Transformer, self).__init__() + + self.config = config + self.spec_size = spec_size + self.patch_stride = patch_stride + self.patch_size = patch_size + self.window_size = window_size + self.embed_dim = embed_dim + self.depths = depths + self.ape = ape + self.in_chans = in_chans + self.num_classes = num_classes + self.num_heads = num_heads + self.num_layers = len(self.depths) + self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1)) + + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + + self.qkv_bias = qkv_bias + self.qk_scale = None + + self.patch_norm = patch_norm + self.norm_layer = norm_layer if self.patch_norm else None + self.norm_before_mlp = norm_before_mlp + self.mlp_ratio = mlp_ratio + + self.use_checkpoint = use_checkpoint + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # process mel-spec ; used only once + self.freq_ratio = self.spec_size // self.config.mel_bins + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + self.interpolate_ratio = 32 # Downsampled ratio + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=config.window_size, + hop_length=config.hop_size, + win_length=config.window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=config.sample_rate, + n_fft=config.window_size, + n_mels=config.mel_bins, + fmin=config.fmin, + fmax=config.fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) # 2 2 + self.bn0 = nn.BatchNorm2d(self.config.mel_bins) + + # split spctrogram into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=self.spec_size, + patch_size=self.patch_size, + in_chans=self.in_chans, + embed_dim=self.embed_dim, + norm_layer=self.norm_layer, + patch_stride=patch_stride, + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.grid_size + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dim) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) + + self.pos_drop = nn.Dropout(p=self.drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(self.embed_dim * 2**i_layer), + input_resolution=( + patches_resolution[0] // (2**i_layer), + patches_resolution[1] // (2**i_layer), + ), + depth=self.depths[i_layer], + num_heads=self.num_heads[i_layer], + window_size=self.window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + qk_scale=self.qk_scale, + drop=self.drop_rate, + attn_drop=self.attn_drop_rate, + drop_path=dpr[ + sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1]) + ], + norm_layer=self.norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + norm_before_mlp=self.norm_before_mlp, + ) + self.layers.append(layer) + + self.norm = self.norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.maxpool = nn.AdaptiveMaxPool1d(1) + + SF = ( + self.spec_size + // (2 ** (len(self.depths) - 1)) + // self.patch_stride[0] + // self.freq_ratio + ) + self.tscam_conv = nn.Conv2d( + in_channels=self.num_features, + out_channels=self.num_classes, + kernel_size=(SF, 3), + padding=(0, 1), + ) + self.head = nn.Linear(num_classes, num_classes) + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"] + ): + self.mel_conv1d = nn.Sequential( + nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2), + nn.BatchNorm1d(64), + ) + if self.fusion_type == "daf_1d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_1d": + self.fusion_model = AFF(channels=64, type="1D") + elif self.fusion_type == "iaff_1d": + self.fusion_model = iAFF(channels=64, type="1D") + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"absolute_pos_embed"} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"relative_position_bias_table"} + + def forward_features(self, x, longer_idx=None): + # A deprecated optimization for using a hierarchical output from different blocks + + frames_num = x.shape[2] + x = self.patch_embed(x, longer_idx=longer_idx) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + for i, layer in enumerate(self.layers): + x, attn = layer(x) + # for x + x = self.norm(x) + B, N, C = x.shape + SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] + ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1] + x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST) + B, C, F, T = x.shape + # group 2D CNN + c_freq_bin = F // self.freq_ratio + x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T) + x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1) + # get latent_output + fine_grained_latent_output = torch.mean(x, dim=2) + fine_grained_latent_output = interpolate( + fine_grained_latent_output.permute(0, 2, 1).contiguous(), + 8 * self.patch_stride[1], + ) + + latent_output = self.avgpool(torch.flatten(x, 2)) + latent_output = torch.flatten(latent_output, 1) + + # display the attention map, if needed + + x = self.tscam_conv(x) + x = torch.flatten(x, 2) # B, C, T + + fpx = interpolate( + torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1] + ) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + output_dict = { + "framewise_output": fpx, # already sigmoided + "clipwise_output": torch.sigmoid(x), + "fine_grained_embedding": fine_grained_latent_output, + "embedding": latent_output, + } + + return output_dict + + def crop_wav(self, x, crop_size, spe_pos=None): + time_steps = x.shape[2] + tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device) + for i in range(len(x)): + if spe_pos is None: + crop_pos = random.randint(0, time_steps - crop_size - 1) + else: + crop_pos = spe_pos + tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :] + return tx + + # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model + def reshape_wav2img(self, x): + B, C, T, F = x.shape + target_T = int(self.spec_size * self.freq_ratio) + target_F = self.spec_size // self.freq_ratio + assert ( + T <= target_T and F <= target_F + ), "the wav size should less than or equal to the swin input size" + # to avoid bicubic zero error + if T < target_T: + x = nn.functional.interpolate( + x, (target_T, x.shape[3]), mode="bicubic", align_corners=True + ) + if F < target_F: + x = nn.functional.interpolate( + x, (x.shape[2], target_F), mode="bicubic", align_corners=True + ) + x = x.permute(0, 1, 3, 2).contiguous() + x = x.reshape( + x.shape[0], + x.shape[1], + x.shape[2], + self.freq_ratio, + x.shape[3] // self.freq_ratio, + ) + # print(x.shape) + x = x.permute(0, 1, 3, 2, 4).contiguous() + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4]) + return x + + # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model + def repeat_wat2img(self, x, cur_pos): + B, C, T, F = x.shape + target_T = int(self.spec_size * self.freq_ratio) + target_F = self.spec_size // self.freq_ratio + assert ( + T <= target_T and F <= target_F + ), "the wav size should less than or equal to the swin input size" + # to avoid bicubic zero error + if T < target_T: + x = nn.functional.interpolate( + x, (target_T, x.shape[3]), mode="bicubic", align_corners=True + ) + if F < target_F: + x = nn.functional.interpolate( + x, (x.shape[2], target_F), mode="bicubic", align_corners=True + ) + x = x.permute(0, 1, 3, 2).contiguous() # B C F T + x = x[:, :, :, cur_pos : cur_pos + self.spec_size] + x = x.repeat(repeats=(1, 1, 4, 1)) + return x + + def forward( + self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None + ): # out_feat_keys: List[str] = None): + + if self.enable_fusion and x["longer"].sum() == 0: + # if no audio is longer than 10s, then randomly select one audio to be longer + x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True + + if not self.enable_fusion: + x = x["waveform"].to(device=device, non_blocking=True) + x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + if self.training: + x = self.spec_augmenter(x) + + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.reshape_wav2img(x) + output_dict = self.forward_features(x) + else: + longer_list = x["longer"].to(device=device, non_blocking=True) + x = x["mel_fusion"].to(device=device, non_blocking=True) + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + longer_list_idx = torch.where(longer_list)[0] + if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]: + new_x = x[:, 0:1, :, :].clone().contiguous() + if len(longer_list_idx) > 0: + # local processing + fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous() + FB, FC, FT, FF = fusion_x_local.size() + fusion_x_local = fusion_x_local.view(FB * FC, FT, FF) + fusion_x_local = torch.permute( + fusion_x_local, (0, 2, 1) + ).contiguous() + fusion_x_local = self.mel_conv1d(fusion_x_local) + fusion_x_local = fusion_x_local.view( + FB, FC, FF, fusion_x_local.size(-1) + ) + fusion_x_local = ( + torch.permute(fusion_x_local, (0, 2, 1, 3)) + .contiguous() + .flatten(2) + ) + if fusion_x_local.size(-1) < FT: + fusion_x_local = torch.cat( + [ + fusion_x_local, + torch.zeros( + (FB, FF, FT - fusion_x_local.size(-1)), + device=device, + ), + ], + dim=-1, + ) + else: + fusion_x_local = fusion_x_local[:, :, :FT] + # 1D fusion + new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous() + new_x[longer_list_idx] = self.fusion_model( + new_x[longer_list_idx], fusion_x_local + ) + x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :] + else: + x = new_x + + elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]: + x = x # no change + + if self.training: + x = self.spec_augmenter(x) + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.reshape_wav2img(x) + output_dict = self.forward_features(x, longer_idx=longer_list_idx) + + # if infer_mode: + # # in infer mode. we need to handle different length audio input + # frame_num = x.shape[2] + # target_T = int(self.spec_size * self.freq_ratio) + # repeat_ratio = math.floor(target_T / frame_num) + # x = x.repeat(repeats=(1,1,repeat_ratio,1)) + # x = self.reshape_wav2img(x) + # output_dict = self.forward_features(x) + # else: + # if x.shape[2] > self.freq_ratio * self.spec_size: + # if self.training: + # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size) + # x = self.reshape_wav2img(x) + # output_dict = self.forward_features(x) + # else: + # # Change: Hard code here + # overlap_size = (x.shape[2] - 1) // 4 + # output_dicts = [] + # crop_size = (x.shape[2] - 1) // 2 + # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size): + # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos) + # tx = self.reshape_wav2img(tx) + # output_dicts.append(self.forward_features(tx)) + # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device) + # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device) + # for d in output_dicts: + # clipwise_output += d["clipwise_output"] + # framewise_output += d["framewise_output"] + # clipwise_output = clipwise_output / len(output_dicts) + # framewise_output = framewise_output / len(output_dicts) + # output_dict = { + # 'framewise_output': framewise_output, + # 'clipwise_output': clipwise_output + # } + # else: # this part is typically used, and most easy one + # x = self.reshape_wav2img(x) + # output_dict = self.forward_features(x) + # x = self.head(x) + + # We process the data in the dataloader part, in that here we only consider the input_T < fixed_T + + return output_dict + + +def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"): + try: + + assert audio_cfg.model_name in [ + "tiny", + "base", + "large", + ], "model name for HTS-AT is wrong!" + if audio_cfg.model_name == "tiny": + model = HTSAT_Swin_Transformer( + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + num_classes=audio_cfg.class_num, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + config=audio_cfg, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + elif audio_cfg.model_name == "base": + model = HTSAT_Swin_Transformer( + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + num_classes=audio_cfg.class_num, + embed_dim=128, + depths=[2, 2, 12, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + config=audio_cfg, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + elif audio_cfg.model_name == "large": + model = HTSAT_Swin_Transformer( + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + num_classes=audio_cfg.class_num, + embed_dim=256, + depths=[2, 2, 12, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + config=audio_cfg, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + + return model + except: + raise RuntimeError( + f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough." + ) diff --git a/audioldm/clap/open_clip/linear_probe.py b/audioldm/clap/open_clip/linear_probe.py new file mode 100644 index 0000000000000000000000000000000000000000..9d7e23b6b67a53e16d050d675a99d01d7d04d581 --- /dev/null +++ b/audioldm/clap/open_clip/linear_probe.py @@ -0,0 +1,66 @@ +import numpy as np +import torch.nn.functional as F +from torch import nn +from .model import MLPLayers + + +class LinearProbe(nn.Module): + def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None): + """ + Args: + model: nn.Module + mlp: bool, if True, then use the MLP layer as the linear probe module + freeze: bool, if Ture, then freeze all the CLAP model's layers when training the linear probe + in_ch: int, the output channel from CLAP model + out_ch: int, the output channel from linear probe (class_num) + act: torch.nn.functional, the activation function before the loss function + """ + super().__init__() + in_ch = 512 + self.clap_model = model + self.clap_model.text_branch = None # to save memory + self.freeze = freeze + if mlp: + self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch]) + else: + self.lp_layer = nn.Linear(in_ch, out_ch) + + if self.freeze: + for param in self.clap_model.parameters(): + param.requires_grad = False + + if act == "None": + self.act = None + elif act == "relu": + self.act = nn.ReLU() + elif act == "elu": + self.act = nn.ELU() + elif act == "prelu": + self.act = nn.PReLU(num_parameters=in_ch) + elif act == "softmax": + self.act = nn.Softmax(dim=-1) + elif act == "sigmoid": + self.act = nn.Sigmoid() + + def forward(self, x, mix_lambda=None, device=None): + """ + Args: + x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list + mix_lambda: torch.tensor [batch], the mixup lambda + Returns: + class_prob: torch.tensor [batch, class_num] + + """ + # batchnorm cancel grandient + if self.freeze: + self.clap_model.eval() + + x = self.clap_model.audio_projection( + self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[ + "embedding" + ] + ) + out = self.lp_layer(x) + if self.act is not None: + out = self.act(out) + return out diff --git a/audioldm/clap/open_clip/loss.py b/audioldm/clap/open_clip/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..cc66298a14997da4aa2efc71e37c0a6bcda53fd1 --- /dev/null +++ b/audioldm/clap/open_clip/loss.py @@ -0,0 +1,398 @@ +from multiprocessing.sharedctypes import Value +import torch +import torch.distributed.nn +from torch import distributed as dist, nn as nn +from torch.nn import functional as F +import numpy as np +from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +def gather_features( + audio_features, + text_features, + audio_features_mlp=None, + text_features_mlp=None, + local_loss=False, + gather_with_grad=False, + rank=0, + world_size=1, + use_horovod=False, + mlp_loss=False, +): + if use_horovod: + assert hvd is not None, "Please install horovod" + if gather_with_grad: + all_audio_features = hvd.allgather(audio_features) + all_text_features = hvd.allgather(text_features) + if mlp_loss: + all_audio_features_mlp = hvd.allgather(audio_features_mlp) + all_text_features_mlp = hvd.allgather(text_features_mlp) + else: + with torch.no_grad(): + all_audio_features = hvd.allgather(audio_features) + all_text_features = hvd.allgather(text_features) + if mlp_loss: + all_audio_features_mlp = hvd.allgather(audio_features_mlp) + all_text_features_mlp = hvd.allgather(text_features_mlp) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_audio_features = list( + all_audio_features.chunk(world_size, dim=0) + ) + gathered_text_features = list( + all_text_features.chunk(world_size, dim=0) + ) + gathered_audio_features[rank] = audio_features + gathered_text_features[rank] = text_features + all_audio_features = torch.cat(gathered_audio_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + if mlp_loss: + gathered_audio_features_mlp = list( + all_audio_features_mlp.chunk(world_size, dim=0) + ) + gathered_text_features_mlp = list( + all_text_features_mlp.chunk(world_size, dim=0) + ) + gathered_audio_features_mlp[rank] = audio_features_mlp + gathered_text_features_mlp[rank] = text_features_mlp + all_audio_features_mlp = torch.cat( + gathered_audio_features_mlp, dim=0 + ) + all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0) + else: + # We gather tensors from all gpus + if gather_with_grad: + all_audio_features = torch.cat( + torch.distributed.nn.all_gather(audio_features), dim=0 + ) + all_text_features = torch.cat( + torch.distributed.nn.all_gather(text_features), dim=0 + ) + if mlp_loss: + all_audio_features_mlp = torch.cat( + torch.distributed.nn.all_gather(audio_features_mlp), dim=0 + ) + all_text_features_mlp = torch.cat( + torch.distributed.nn.all_gather(text_features_mlp), dim=0 + ) + else: + gathered_audio_features = [ + torch.zeros_like(audio_features) for _ in range(world_size) + ] + gathered_text_features = [ + torch.zeros_like(text_features) for _ in range(world_size) + ] + dist.all_gather(gathered_audio_features, audio_features) + dist.all_gather(gathered_text_features, text_features) + if mlp_loss: + gathered_audio_features_mlp = [ + torch.zeros_like(audio_features_mlp) for _ in range(world_size) + ] + gathered_text_features_mlp = [ + torch.zeros_like(text_features_mlp) for _ in range(world_size) + ] + dist.all_gather(gathered_audio_features_mlp, audio_features_mlp) + dist.all_gather(gathered_text_features_mlp, text_features_mlp) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_audio_features[rank] = audio_features + gathered_text_features[rank] = text_features + if mlp_loss: + gathered_audio_features_mlp[rank] = audio_features_mlp + gathered_text_features_mlp[rank] = text_features_mlp + + all_audio_features = torch.cat(gathered_audio_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + if mlp_loss: + all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0) + all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0) + if mlp_loss: + return ( + all_audio_features, + all_text_features, + all_audio_features_mlp, + all_text_features_mlp, + ) + else: + return all_audio_features, all_text_features + + +class ClipLoss(nn.Module): + def __init__( + self, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + mlp_loss=False, + weight_loss_kappa=0, + ): + super().__init__() + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.use_horovod = use_horovod + self.mlp_loss = mlp_loss + self.weighted_loss = bool(weight_loss_kappa != 0) + self.weight_loss_kappa = weight_loss_kappa + # cache state + self.prev_num_logits = 0 + self.labels = {} + + def forward( + self, + audio_features, + text_features, + logit_scale_a, + logit_scale_t=None, + audio_features_mlp=None, + text_features_mlp=None, + ): + device = audio_features.device + if self.mlp_loss: + if self.world_size > 1: + ( + all_audio_features, + all_text_features, + all_audio_features_mlp, + all_text_features_mlp, + ) = gather_features( + audio_features=audio_features, + text_features=text_features, + audio_features_mlp=audio_features_mlp, + text_features_mlp=text_features_mlp, + local_loss=self.local_loss, + gather_with_grad=self.gather_with_grad, + rank=self.rank, + world_size=self.world_size, + use_horovod=self.use_horovod, + mlp_loss=self.mlp_loss, + ) + if self.local_loss: + a_logits_per_audio = ( + logit_scale_a * audio_features @ all_text_features_mlp.T + ) + a_logits_per_text = ( + logit_scale_a * text_features_mlp @ all_audio_features.T + ) + t_logits_per_audio = ( + logit_scale_t * audio_features_mlp @ all_text_features.T + ) + t_logits_per_text = ( + logit_scale_t * text_features @ all_audio_features_mlp.T + ) + else: + a_logits_per_audio = ( + logit_scale_a * all_audio_features @ all_text_features_mlp.T + ) + a_logits_per_text = a_logits_per_audio.T + t_logits_per_audio = ( + logit_scale_t * all_audio_features_mlp @ all_text_features.T + ) + t_logits_per_text = t_logits_per_audio.T + else: + a_logits_per_audio = ( + logit_scale_a * audio_features @ text_features_mlp.T + ) + a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T + t_logits_per_audio = ( + logit_scale_t * audio_features_mlp @ text_features.T + ) + t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T + + # calculated ground-truth and cache if enabled + num_logits = a_logits_per_audio.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + + if not self.weighted_loss: + total_loss = ( + F.cross_entropy(a_logits_per_audio, labels) + + F.cross_entropy(a_logits_per_text, labels) + + F.cross_entropy(t_logits_per_audio, labels) + + F.cross_entropy(t_logits_per_text, labels) + ) / 4 + else: + audio_weight = (audio_features @ audio_features.T).detach() + audio_weight = ( + torch.exp( + torch.sum(audio_weight, axis=1) + / (self.weight_loss_kappa * len(audio_weight)) + ) + ).detach() + text_weight = (text_features @ text_features.T).detach() + text_weight = ( + torch.exp( + torch.sum(text_weight, axis=1) + / (self.weight_loss_kappa * len(text_features)) + ) + ).detach() + total_loss = ( + F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight) + + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight) + + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight) + + F.cross_entropy(t_logits_per_text, labels, weight=text_weight) + ) / 4 + else: + if self.world_size > 1: + all_audio_features, all_text_features = gather_features( + audio_features=audio_features, + text_features=text_features, + local_loss=self.local_loss, + gather_with_grad=self.gather_with_grad, + rank=self.rank, + world_size=self.world_size, + use_horovod=self.use_horovod, + mlp_loss=self.mlp_loss, + ) + + if self.local_loss: + logits_per_audio = ( + logit_scale_a * audio_features @ all_text_features.T + ) + logits_per_text = ( + logit_scale_a * text_features @ all_audio_features.T + ) + else: + logits_per_audio = ( + logit_scale_a * all_audio_features @ all_text_features.T + ) + logits_per_text = logits_per_audio.T + else: + logits_per_audio = logit_scale_a * audio_features @ text_features.T + logits_per_text = logit_scale_a * text_features @ audio_features.T + + # calculated ground-truth and cache if enabled + num_logits = logits_per_audio.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + if not self.weighted_loss: + total_loss = ( + F.cross_entropy(logits_per_audio, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + else: + audio_weight = (all_audio_features @ all_audio_features.T).detach() + audio_weight = ( + torch.exp( + torch.sum(audio_weight, axis=1) + / (self.weight_loss_kappa * len(all_audio_features)) + ) + ).detach() + text_weight = (all_text_features @ all_text_features.T).detach() + text_weight = ( + torch.exp( + torch.sum(text_weight, axis=1) + / (self.weight_loss_kappa * len(all_text_features)) + ) + ).detach() + total_loss = ( + F.cross_entropy(logits_per_audio, labels, weight=text_weight) + + F.cross_entropy(logits_per_text, labels, weight=audio_weight) + ) / 2 + return total_loss + + +def lp_gather_features(pred, target, world_size=1, use_horovod=False): + if use_horovod: + assert hvd is not None, "Please install horovod" + with torch.no_grad(): + all_preds = hvd.allgather(pred) + all_targets = hvd.allgath(target) + else: + gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)] + gathered_targets = [torch.zeros_like(target) for _ in range(world_size)] + + dist.all_gather(gathered_preds, pred) + dist.all_gather(gathered_targets, target) + all_preds = torch.cat(gathered_preds, dim=0) + all_targets = torch.cat(gathered_targets, dim=0) + + return all_preds, all_targets + + +def get_map(pred, target): + pred = torch.sigmoid(pred).numpy() + target = target.numpy() + return np.mean(average_precision_score(target, pred, average=None)) + + +def get_acc(pred, target): + pred = torch.argmax(pred, 1).numpy() + target = torch.argmax(target, 1).numpy() + return accuracy_score(target, pred) + + +def get_mauc(pred, target): + pred = torch.sigmoid(pred).numpy() + target = target.numpy() + return np.mean(roc_auc_score(target, pred, average=None)) + + +class LPMetrics(object): + def __init__(self, metric_names=["map", "acc", "mauc"]): + self.metrics = [] + for name in metric_names: + self.metrics.append(self.get_metric(name)) + self.metric_names = metric_names + + def get_metric(self, name): + if name == "map": + return get_map + elif name == "acc": + return get_acc + elif name == "mauc": + return get_mauc + else: + raise ValueError(f"the metric should be at least one of [map, acc, mauc]") + + def evaluate_mertics(self, pred, target): + metric_dict = {} + for i in range(len(self.metric_names)): + metric_dict[self.metric_names[i]] = self.metrics[i](pred, target) + return metric_dict + + +def calc_celoss(pred, target): + target = torch.argmax(target, 1).long() + return nn.CrossEntropyLoss()(pred, target) + + +class LPLoss(nn.Module): + def __init__(self, loss_name): + super().__init__() + if loss_name == "bce": + self.loss_func = nn.BCEWithLogitsLoss() + elif loss_name == "ce": + self.loss_func = calc_celoss + elif loss_name == "mse": + self.loss_func = nn.MSELoss() + else: + raise ValueError(f"the loss func should be at least one of [bce, ce, mse]") + + def forward(self, pred, target): + loss = self.loss_func(pred, target) + return loss diff --git a/audioldm/clap/open_clip/model.py b/audioldm/clap/open_clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f41e6d6d0b0bbecacb90744928a516b75d218214 --- /dev/null +++ b/audioldm/clap/open_clip/model.py @@ -0,0 +1,936 @@ +""" CLAP Model + +Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +Adapted to the Audio Task. +""" + +from collections import OrderedDict +from dataclasses import dataclass +from email.mime import audio +from typing import Tuple, Union, Callable, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from .timm_model import TimmModel +import logging +from .utils import freeze_batch_norm_2d + +from .pann_model import create_pann_model +from .htsat import create_htsat_model +from transformers import BertModel, RobertaModel, BartModel +from transformers.tokenization_utils_base import BatchEncoding + + +class MLPLayers(nn.Module): + def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1): + super(MLPLayers, self).__init__() + self.nonlin = nonlin + self.dropout = dropout + + sequence = [] + for u0, u1 in zip(units[:-1], units[1:]): + sequence.append(nn.Linear(u0, u1)) + sequence.append(self.nonlin) + sequence.append(nn.Dropout(self.dropout)) + sequence = sequence[:-2] + + self.sequential = nn.Sequential(*sequence) + + def forward(self, X): + X = self.sequential(X) + return X + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict( + [ + ("-1", nn.AvgPool2d(stride)), + ( + "0", + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False, + ), + ), + ("1", nn.BatchNorm2d(planes * self.expansion)), + ] + ) + ) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__( + self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None + ): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5 + ) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( + 2, 0, 1 + ) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] + ), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert ( + unlocked_groups == 0 + ), "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + def stem(self, x): + for conv, bn in [ + (self.conv1, self.bn1), + (self.conv2, self.bn2), + (self.conv3, self.bn3), + ]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(d_model * 4, d_model)), + ] + ) + ) + self.ln_2 = LayerNorm(d_model) + + def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock(width, heads, act_layer=act_layer) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + x = r(x, attn_mask=attn_mask) + return x + + +class VisualTransformer(nn.Module): + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + act_layer: Callable = nn.GELU, + ): + super().__init__() + self.image_size = image_size + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter( + scale * torch.randn((image_size // patch_size) ** 2 + 1, width) + ) + self.ln_pre = LayerNorm(width) + + self.text_branch = Transformer(width, layers, heads, act_layer=act_layer) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert ( + unlocked_groups == 0 + ), "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [ + self.class_embedding.to(x.dtype) + + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device + ), + x, + ], + dim=1, + ) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_branch(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +@dataclass +class CLAPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + timm_model_name: str = ( + None # a valid model name overrides layers, width, patch_size + ) + timm_model_pretrained: bool = ( + False # use (imagenet) pretrained weights for named model + ) + timm_pool: str = ( + "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + ) + timm_proj: str = ( + "linear" # linear projection for timm model output ('linear', 'mlp', '') + ) + + +# Audio Config Class +@dataclass +class CLAPAudioCfp: + model_type: str = "PANN" + model_name: str = "Cnn14" + sample_rate: int = 48000 + # Param + audio_length: int = 1024 + window_size: int = 1024 + hop_size: int = 1024 + fmin: int = 50 + fmax: int = 14000 + class_num: int = 527 + mel_bins: int = 64 + clip_samples: int = 480000 + + +@dataclass +class CLAPTextCfg: + context_length: int + vocab_size: int + width: int + heads: int + layers: int + model_type: str + + +class CLAP(nn.Module): + def __init__( + self, + embed_dim: int, + audio_cfg: CLAPAudioCfp, + text_cfg: CLAPTextCfg, + quick_gelu: bool = False, + enable_fusion: bool = False, + fusion_type: str = "None", + joint_embed_shape: int = 512, + mlp_act: str = "relu", + ): + super().__init__() + if isinstance(audio_cfg, dict): + audio_cfg = CLAPAudioCfp(**audio_cfg) + if isinstance(text_cfg, dict): + text_cfg = CLAPTextCfg(**text_cfg) + + self.audio_cfg = audio_cfg + self.text_cfg = text_cfg + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + self.joint_embed_shape = joint_embed_shape + self.mlp_act = mlp_act + + self.context_length = text_cfg.context_length + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if mlp_act == "relu": + mlp_act_layer = nn.ReLU() + elif mlp_act == "gelu": + mlp_act_layer = nn.GELU() + else: + raise NotImplementedError + + # audio branch + # audio branch parameters + if audio_cfg.model_type == "PANN": + self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type) + elif audio_cfg.model_type == "HTSAT": + self.audio_branch = create_htsat_model( + audio_cfg, enable_fusion, fusion_type + ) + else: + logging.error(f"Model config for {audio_cfg.model_type} not found") + raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.") + + # text branch + # text branch parameters + if text_cfg.model_type == "transformer": + self.text_branch = Transformer( + width=text_cfg.width, + layers=text_cfg.layers, + heads=text_cfg.heads, + act_layer=act_layer, + ) + self.vocab_size = text_cfg.vocab_size + self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, text_cfg.width) + ) + self.ln_final = LayerNorm(text_cfg.width) + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(text_cfg.width, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + elif text_cfg.model_type == "bert": + self.text_branch = BertModel.from_pretrained("bert-base-uncased") + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + elif text_cfg.model_type == "roberta": + self.text_branch = RobertaModel.from_pretrained("roberta-base") + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + elif text_cfg.model_type == "bart": + self.text_branch = BartModel.from_pretrained("facebook/bart-base") + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + else: + logging.error(f"Model config for {text_cfg.model_type} not found") + raise RuntimeError(f"Model config for {text_cfg.model_type} not found.") + self.text_branch_type = text_cfg.model_type + # text branch parameters + + # audio branch parameters + self.audio_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + + # below here is text branch parameters + + # ============================================================================================================ + self.audio_projection = nn.Sequential( + nn.Linear(embed_dim, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + + self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False) + + self.init_text_branch_parameters() + + def init_text_branch_parameters(self): + if self.text_branch_type == "transformer": + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + proj_std = (self.text_branch.width**-0.5) * ( + (2 * self.text_branch.layers) ** -0.5 + ) + attn_std = self.text_branch.width**-0.5 + fc_std = (2 * self.text_branch.width) ** -0.5 + for block in self.text_branch.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + if self.text_branch_type == "bert" or self.text_branch_type == "roberta": + width = self.text_branch.embeddings.word_embeddings.weight.shape[-1] + elif self.text_branch_type == "bart": + width = self.text_branch.shared.weight.shape[-1] + else: + width = self.text_branch.width + nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07)) + nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07)) + + # deprecated + # if hasattr(self.visual, 'init_parameters'): + # self.visual.init_parameters() + + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def encode_audio(self, audio, device): + return self.audio_branch( + audio, mixup_lambda=None, device=device + ) # mix lambda needs to add + + # def list_of_dict_of_tensor2dict_of_tensor(self, x, device): + # tmp = {} + # for k in x[0].keys(): + # tmp[k] = [] + # for i in range(len(x)): + # tmp[k].append(x[i][k][:77]) + # for k in x[0].keys(): + # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True) + # return tmp + + def encode_text(self, text, device): + if self.text_branch_type == "transformer": + text = text.to(device=device, non_blocking=True) + x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_branch(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)]) + elif self.text_branch_type == "bert": + # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device) + # text = BatchEncoding(text) + x = self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + token_type_ids=text["token_type_ids"].to( + device=device, non_blocking=True + ), + )["pooler_output"] + x = self.text_projection(x) + elif self.text_branch_type == "roberta": + x = self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + )["pooler_output"] + x = self.text_projection(x) + elif self.text_branch_type == "bart": + x = torch.mean( + self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + )["encoder_last_hidden_state"], + axis=1, + ) + x = self.text_projection(x) + else: + logging.error(f"Model type {self.text_branch_type} not found") + raise RuntimeError(f"Model type {self.text_branch_type} not found.") + return x + + def forward(self, audio, text, device=None): + """Forward audio and text into the CLAP + + Parameters + ---------- + audio: torch.Tensor (batch_size, audio_length) + the time-domain audio input / the batch of mel_spec and longer list. + text: torch.Tensor () // need to add + the text token input + """ + if device is None: + if audio is not None: + device = audio.device + elif text is not None: + device = text.device + if audio is None and text is None: + # a hack to get the logit scale + return self.logit_scale_a.exp(), self.logit_scale_t.exp() + elif audio is None: + return self.encode_text(text, device=device) + elif text is None: + return self.audio_projection( + self.encode_audio(audio, device=device)["embedding"] + ) + audio_features = self.audio_projection( + self.encode_audio(audio, device=device)["embedding"] + ) + audio_features = F.normalize(audio_features, dim=-1) + + text_features = self.encode_text(text, device=device) + # print("text_features", text_features) + # print("text_features.shape", text_features.shape) + # print("text_features.type", type(text_features)) + text_features = F.normalize(text_features, dim=-1) + + audio_features_mlp = self.audio_transform(audio_features) + text_features_mlp = self.text_transform(text_features) + # Four outputs: audio features (basic & MLP), text features (basic & MLP) + return ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + self.logit_scale_a.exp(), + self.logit_scale_t.exp(), + ) + + def get_logit_scale(self): + return self.logit_scale_a.exp(), self.logit_scale_t.exp() + + def get_text_embedding(self, data): + """Get the text embedding from the model + + Parameters + ---------- + data: torch.Tensor + a tensor of text embedding + + Returns + ---------- + text_embed: torch.Tensor + a tensor of text_embeds (N, D) + + """ + device = next(self.parameters()).device + for k in data: + data[k] = data[k].to(device) + if(len(data[k].size()) < 2): + data[k] = data[k].unsqueeze(0) + text_embeds = self.encode_text(data, device=device) + text_embeds = F.normalize(text_embeds, dim=-1) + + return text_embeds + + def get_audio_embedding(self, data): + """Get the audio embedding from the model + + Parameters + ---------- + data: a list of dict + the audio input dict list from 'get_audio_feature' method + + Returns + ---------- + audio_embed: torch.Tensor + a tensor of audio_embeds (N, D) + + """ + device = next(self.parameters()).device + input_dict = {} + keys = data[0].keys() + for k in keys: + input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to( + device + ) + + audio_embeds = self.audio_projection( + self.encode_audio(input_dict, device=device)["embedding"] + ) + audio_embeds = F.normalize(audio_embeds, dim=-1) + + return audio_embeds + + def audio_infer(self, audio, hopsize=None, device=None): + """Forward one audio and produce the audio embedding + + Parameters + ---------- + audio: (audio_length) + the time-domain audio input, notice that it must be only one input + hopsize: int + the overlap hopsize as the sliding window + + Returns + ---------- + output_dict: { + key: [n, (embedding_shape)] if "HTS-AT" + or + key: [(embedding_shape)] if "PANN" + } + the list of key values of the audio branch + + """ + + assert not self.training, "the inference mode must be run at eval stage" + output_dict = {} + # PANN + if self.audio_cfg.model_type == "PANN": + audio_input = audio.unsqueeze(dim=0) + output_dict[key] = self.encode_audio(audio_input, device=device)[ + key + ].squeeze(dim=0) + elif self.audio_cfg.model_type == "HTSAT": + # repeat + audio_len = len(audio) + k = self.audio_cfg.clip_samples // audio_len + if k > 1: + audio = audio.repeat(k) + audio_len = len(audio) + + if hopsize is None: + hopsize = min(hopsize, audio_len) + + if audio_len > self.audio_cfg.clip_samples: + audio_input = [ + audio[pos : pos + self.audio_cfg.clip_samples].clone() + for pos in range( + 0, audio_len - self.audio_cfg.clip_samples, hopsize + ) + ] + audio_input.append(audio[-self.audio_cfg.clip_samples :].clone()) + audio_input = torch.stack(audio_input) + output_dict[key] = self.encode_audio(audio_input, device=device)[key] + else: + audio_input = audio.unsqueeze(dim=0) + output_dict[key] = self.encode_audio(audio_input, device=device)[ + key + ].squeeze(dim=0) + + return output_dict + + +def convert_weights_to_fp16(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [ + *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], + "in_proj_bias", + "bias_k", + "bias_v", + ]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +# Ignore the state dict of the vision part +def build_model_from_openai_state_dict( + state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None" +): + + embed_dim = model_cfg["embed_dim"] + audio_cfg = model_cfg["audio_cfg"] + text_cfg = model_cfg["text_cfg"] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split(".")[2] + for k in state_dict + if k.startswith(f"transformer.resblocks") + ) + ) + + audio_cfg = CLAPAudioCfp(**audio_cfg) + text_cfg = CLAPTextCfg(**text_cfg) + + model = CLAP( + embed_dim, + audio_cfg=audio_cfg, + text_cfg=text_cfg, + quick_gelu=True, # OpenAI models were trained with QuickGELU + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + state_dict["logit_scale_a"] = state_dict["logit_scale"] + state_dict["logit_scale_t"] = state_dict["logit_scale"] + pop_keys = list(state_dict.keys())[::] + # pop the visual branch saved weights + for key in pop_keys: + if key.startswith("visual."): + state_dict.pop(key, None) + + for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + # not use fp16 + # convert_weights_to_fp16(model) + model.load_state_dict(state_dict, strict=False) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device("cpu")): + model.eval() + audio_length = model.audio_cfg.audio_length + example_audio = torch.ones((batch_size, audio_length), device=device) + example_text = torch.zeros( + (batch_size, model.context_length), dtype=torch.int, device=device + ) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_audio, example_text), + encode_text=(example_text,), + encode_image=(example_audio,), + ), + ) + model.audio_cfg.audio_length = audio_length # Question: what does this do? + return model diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-base.json b/audioldm/clap/open_clip/model_configs/HTSAT-base.json new file mode 100644 index 0000000000000000000000000000000000000000..6cef625a89daf4431f1c9f72e10bc9640eef2ba8 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/HTSAT-base.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 1024, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "base" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-large.json b/audioldm/clap/open_clip/model_configs/HTSAT-large.json new file mode 100644 index 0000000000000000000000000000000000000000..699cdb1b16855582606551e4196b24aba2ffd871 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/HTSAT-large.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "large" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json b/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json new file mode 100644 index 0000000000000000000000000000000000000000..73e42990fe8361a0df502e7f93d29f19f58c9ecb --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 768, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1536, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "tiny" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json b/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json new file mode 100644 index 0000000000000000000000000000000000000000..a6e7821163d9afa81c27345a1e472475b92af169 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 768, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "tiny" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-10.json b/audioldm/clap/open_clip/model_configs/PANN-10.json new file mode 100644 index 0000000000000000000000000000000000000000..954ddf62921aed7dde9c37ffffec98a2e96a4ee7 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-10.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 1024, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn10" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json new file mode 100644 index 0000000000000000000000000000000000000000..b7989bc0cd95d0d39049b7524eba508b3e386439 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 18000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json new file mode 100644 index 0000000000000000000000000000000000000000..56bdb56bedc304ffa52d8bf5988cea2c1d82d14e --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 960000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 360, + "fmin": 50, + "fmax": 8000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json b/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json new file mode 100644 index 0000000000000000000000000000000000000000..5756e3bebc97cc985f512cb081930fee4e49bec1 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 4 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json b/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json new file mode 100644 index 0000000000000000000000000000000000000000..5a9e7e208b661619d5e26625e849da1adda8a475 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1536, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14.json b/audioldm/clap/open_clip/model_configs/PANN-14.json new file mode 100644 index 0000000000000000000000000000000000000000..39a5134cde1d8c50f4758377c952ef22f07bab41 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-14.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-6.json b/audioldm/clap/open_clip/model_configs/PANN-6.json new file mode 100644 index 0000000000000000000000000000000000000000..21ebc344326de260c386ba77e0ad63cf9b04febf --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-6.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 512, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn6" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json b/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN101.json b/audioldm/clap/open_clip/model_configs/RN101.json new file mode 100644 index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN101.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json b/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/audioldm/clap/open_clip/model_configs/RN50.json b/audioldm/clap/open_clip/model_configs/RN50.json new file mode 100644 index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN50.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN50x16.json b/audioldm/clap/open_clip/model_configs/RN50x16.json new file mode 100644 index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN50x16.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 384, + "layers": [ + 6, + 8, + 18, + 8 + ], + "width": 96, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN50x4.json b/audioldm/clap/open_clip/model_configs/RN50x4.json new file mode 100644 index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN50x4.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 288, + "layers": [ + 4, + 6, + 10, + 6 + ], + "width": 80, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-16.json b/audioldm/clap/open_clip/model_configs/ViT-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/ViT-B-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json b/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-32.json b/audioldm/clap/open_clip/model_configs/ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/ViT-L-14.json b/audioldm/clap/open_clip/model_configs/ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/openai.py b/audioldm/clap/open_clip/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..3f4eb8b55fe960e1792b3da804b60b3d8f70fe26 --- /dev/null +++ b/audioldm/clap/open_clip/openai.py @@ -0,0 +1,156 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import os +import warnings +from typing import Union, List + +import torch + +from .model import build_model_from_openai_state_dict +from .pretrained import ( + get_pretrained_url, + list_pretrained_tag_models, + download_pretrained, +) + +__all__ = ["list_openai_models", "load_openai_model"] + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_tag_models("openai") + + +def load_openai_model( + name: str, + model_cfg, + device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", + jit=True, + cache_dir=os.path.expanduser("~/.cache/clip"), + enable_fusion: bool = False, + fusion_type: str = "None", +): + """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + device : Union[str, torch.device] + The device to put the loaded model + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. + + Returns + ------- + model : torch.nn.Module + The CLAP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if get_pretrained_url(name, "openai"): + model_path = download_pretrained( + get_pretrained_url(name, "openai"), root=cache_dir + ) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError( + f"Model {name} not found; available models = {list_openai_models()}" + ) + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn( + f"File {model_path} is not a JIT archive. Loading as a state dict instead" + ) + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + try: + model = build_model_from_openai_state_dict( + state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type + ).to(device) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict( + sd, model_cfg, enable_fusion, fusion_type + ).to(device) + + if str(device) == "cpu": + model.float() + return model + + # patch the device names + device_holder = torch.jit.trace( + lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] + ) + device_node = [ + n + for n in device_holder.graph.findAllNodes("prim::Constant") + if "Device" in repr(n) + ][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith( + "cuda" + ): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_audio) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace( + lambda: torch.ones([]).float(), example_inputs=[] + ) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [ + 1, + 2, + ]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_audio) + patch_float(model.encode_text) + model.float() + + model.audio_branch.audio_length = model.audio_cfg.audio_length + return model diff --git a/audioldm/clap/open_clip/pann_model.py b/audioldm/clap/open_clip/pann_model.py new file mode 100644 index 0000000000000000000000000000000000000000..874a03fc6eabcfdf3a63c59ca1e05d4f991453c5 --- /dev/null +++ b/audioldm/clap/open_clip/pann_model.py @@ -0,0 +1,703 @@ +# PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition +# Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn +# Some layers are re-designed for CLAP +import os + +os.environ["NUMBA_CACHE_DIR"] = "/tmp/" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchlibrosa.stft import Spectrogram, LogmelFilterBank +from torchlibrosa.augmentation import SpecAugmentation + +from .utils import do_mixup, interpolate, pad_framewise_output +from .feature_fusion import iAFF, AFF, DAF + + +def init_layer(layer): + """Initialize a Linear or Convolutional layer.""" + nn.init.xavier_uniform_(layer.weight) + + if hasattr(layer, "bias"): + if layer.bias is not None: + layer.bias.data.fill_(0.0) + +def init_bn(bn): + """Initialize a Batchnorm layer.""" + bn.bias.data.fill_(0.0) + bn.weight.data.fill_(1.0) + + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + + super(ConvBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ) + + self.conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ) + + self.bn1 = nn.BatchNorm2d(out_channels) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_layer(self.conv2) + init_bn(self.bn1) + init_bn(self.bn2) + + def forward(self, input, pool_size=(2, 2), pool_type="avg"): + + x = input + x = F.relu_(self.bn1(self.conv1(x))) + x = F.relu_(self.bn2(self.conv2(x))) + if pool_type == "max": + x = F.max_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg": + x = F.avg_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg+max": + x1 = F.avg_pool2d(x, kernel_size=pool_size) + x2 = F.max_pool2d(x, kernel_size=pool_size) + x = x1 + x2 + else: + raise Exception("Incorrect argument!") + + return x + + +class ConvBlock5x5(nn.Module): + def __init__(self, in_channels, out_channels): + + super(ConvBlock5x5, self).__init__() + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(5, 5), + stride=(1, 1), + padding=(2, 2), + bias=False, + ) + + self.bn1 = nn.BatchNorm2d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_bn(self.bn1) + + def forward(self, input, pool_size=(2, 2), pool_type="avg"): + + x = input + x = F.relu_(self.bn1(self.conv1(x))) + if pool_type == "max": + x = F.max_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg": + x = F.avg_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg+max": + x1 = F.avg_pool2d(x, kernel_size=pool_size) + x2 = F.max_pool2d(x, kernel_size=pool_size) + x = x1 + x2 + else: + raise Exception("Incorrect argument!") + + return x + + +class AttBlock(nn.Module): + def __init__(self, n_in, n_out, activation="linear", temperature=1.0): + super(AttBlock, self).__init__() + + self.activation = activation + self.temperature = temperature + self.att = nn.Conv1d( + in_channels=n_in, + out_channels=n_out, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + self.cla = nn.Conv1d( + in_channels=n_in, + out_channels=n_out, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + + self.bn_att = nn.BatchNorm1d(n_out) + self.init_weights() + + def init_weights(self): + init_layer(self.att) + init_layer(self.cla) + init_bn(self.bn_att) + + def forward(self, x): + # x: (n_samples, n_in, n_time) + norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) + cla = self.nonlinear_transform(self.cla(x)) + x = torch.sum(norm_att * cla, dim=2) + return x, norm_att, cla + + def nonlinear_transform(self, x): + if self.activation == "linear": + return x + elif self.activation == "sigmoid": + return torch.sigmoid(x) + + +class Cnn14(nn.Module): + def __init__( + self, + sample_rate, + window_size, + hop_size, + mel_bins, + fmin, + fmax, + classes_num, + enable_fusion=False, + fusion_type="None", + ): + + super(Cnn14, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + if (self.enable_fusion) and (self.fusion_type == "channel_map"): + self.conv_block1 = ConvBlock(in_channels=4, out_channels=64) + else: + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"] + ): + self.mel_conv1d = nn.Sequential( + nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2), + nn.BatchNorm1d(64), # No Relu + ) + if self.fusion_type == "daf_1d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_1d": + self.fusion_model = AFF(channels=64, type="1D") + elif self.fusion_type == "iaff_1d": + self.fusion_model = iAFF(channels=64, type="1D") + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + self.mel_conv2d = nn.Sequential( + nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + ) + + if self.fusion_type == "daf_2d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_2d": + self.fusion_model = AFF(channels=64, type="2D") + elif self.fusion_type == "iaff_2d": + self.fusion_model = iAFF(channels=64, type="2D") + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + if self.enable_fusion and input["longer"].sum() == 0: + # if no audio is longer than 10s, then randomly select one audio to be longer + input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True + + if not self.enable_fusion: + x = self.spectrogram_extractor( + input["waveform"].to(device=device, non_blocking=True) + ) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + else: + longer_list = input["longer"].to(device=device, non_blocking=True) + x = input["mel_fusion"].to(device=device, non_blocking=True) + longer_list_idx = torch.where(longer_list)[0] + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]: + new_x = x[:, 0:1, :, :].clone().contiguous() + # local processing + if len(longer_list_idx) > 0: + fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous() + FB, FC, FT, FF = fusion_x_local.size() + fusion_x_local = fusion_x_local.view(FB * FC, FT, FF) + fusion_x_local = torch.permute( + fusion_x_local, (0, 2, 1) + ).contiguous() + fusion_x_local = self.mel_conv1d(fusion_x_local) + fusion_x_local = fusion_x_local.view( + FB, FC, FF, fusion_x_local.size(-1) + ) + fusion_x_local = ( + torch.permute(fusion_x_local, (0, 2, 1, 3)) + .contiguous() + .flatten(2) + ) + if fusion_x_local.size(-1) < FT: + fusion_x_local = torch.cat( + [ + fusion_x_local, + torch.zeros( + (FB, FF, FT - fusion_x_local.size(-1)), + device=device, + ), + ], + dim=-1, + ) + else: + fusion_x_local = fusion_x_local[:, :, :FT] + # 1D fusion + new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous() + new_x[longer_list_idx] = self.fusion_model( + new_x[longer_list_idx], fusion_x_local + ) + x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :] + else: + x = new_x + elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]: + x = x # no change + + if self.training: + x = self.spec_augmenter(x) + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + global_x = x[:, 0:1, :, :] + + # global processing + B, C, H, W = global_x.shape + global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg") + if len(longer_list_idx) > 0: + local_x = x[longer_list_idx, 1:, :, :].contiguous() + TH = global_x.size(-2) + # local processing + B, C, H, W = local_x.shape + local_x = local_x.view(B * C, 1, H, W) + local_x = self.mel_conv2d(local_x) + local_x = local_x.view( + B, C, local_x.size(1), local_x.size(2), local_x.size(3) + ) + local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3) + TB, TC, _, TW = local_x.size() + if local_x.size(-2) < TH: + local_x = torch.cat( + [ + local_x, + torch.zeros( + (TB, TC, TH - local_x.size(-2), TW), + device=global_x.device, + ), + ], + dim=-2, + ) + else: + local_x = local_x[:, :, :TH, :] + + global_x[longer_list_idx] = self.fusion_model( + global_x[longer_list_idx], local_x + ) + x = global_x + else: + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = latent_x.transpose(1, 2) + latent_x = F.relu_(self.fc1(latent_x)) + latent_output = interpolate(latent_x, 32) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = { + "clipwise_output": clipwise_output, + "embedding": embedding, + "fine_grained_embedding": latent_output, + } + return output_dict + + +class Cnn6(nn.Module): + def __init__( + self, + sample_rate, + window_size, + hop_size, + mel_bins, + fmin, + fmax, + classes_num, + enable_fusion=False, + fusion_type="None", + ): + + super(Cnn6, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512) + + self.fc1 = nn.Linear(512, 512, bias=True) + self.fc_audioset = nn.Linear(512, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = latent_x.transpose(1, 2) + latent_x = F.relu_(self.fc1(latent_x)) + latent_output = interpolate(latent_x, 16) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = { + "clipwise_output": clipwise_output, + "embedding": embedding, + "fine_grained_embedding": latent_output, + } + + return output_dict + + +class Cnn10(nn.Module): + def __init__( + self, + sample_rate, + window_size, + hop_size, + mel_bins, + fmin, + fmax, + classes_num, + enable_fusion=False, + fusion_type="None", + ): + + super(Cnn10, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + + self.fc1 = nn.Linear(1024, 1024, bias=True) + self.fc_audioset = nn.Linear(1024, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = latent_x.transpose(1, 2) + latent_x = F.relu_(self.fc1(latent_x)) + latent_output = interpolate(latent_x, 32) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = { + "clipwise_output": clipwise_output, + "embedding": embedding, + "fine_grained_embedding": latent_output, + } + + return output_dict + + +def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"): + try: + ModelProto = eval(audio_cfg.model_name) + model = ModelProto( + sample_rate=audio_cfg.sample_rate, + window_size=audio_cfg.window_size, + hop_size=audio_cfg.hop_size, + mel_bins=audio_cfg.mel_bins, + fmin=audio_cfg.fmin, + fmax=audio_cfg.fmax, + classes_num=audio_cfg.class_num, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + return model + except: + raise RuntimeError( + f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough." + ) diff --git a/audioldm/clap/open_clip/pretrained.py b/audioldm/clap/open_clip/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..e211d8b5b59320a599e62605f1dee6199f317253 --- /dev/null +++ b/audioldm/clap/open_clip/pretrained.py @@ -0,0 +1,167 @@ +import hashlib +import os +import urllib +import warnings + +from tqdm import tqdm + +_RN50 = dict( + openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", + cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", +) + +_RN50_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", + cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", +) + +_RN101 = dict( + openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", +) + +_RN101_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", +) + +_RN50x4 = dict( + openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", +) + +_RN50x16 = dict( + openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", +) + +_RN50x64 = dict( + openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", +) + +_VITB32 = dict( + openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", +) + +_VITB32_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", +) + +_VITB16 = dict( + openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", +) + +_VITL14 = dict( + openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", +) + +_PRETRAINED = { + "RN50": _RN50, + "RN50-quickgelu": _RN50_quickgelu, + "RN101": _RN101, + "RN101-quickgelu": _RN101_quickgelu, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "ViT-B-32": _VITB32, + "ViT-B-32-quickgelu": _VITB32_quickgelu, + "ViT-B-16": _VITB16, + "ViT-L-14": _VITL14, +} + + +def list_pretrained(as_str: bool = False): + """returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [ + ":".join([k, t]) if as_str else (k, t) + for k in _PRETRAINED.keys() + for t in _PRETRAINED[k].keys() + ] + + +def list_pretrained_tag_models(tag: str): + """return all models having the specified pretrain tag""" + models = [] + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_model_tags(model: str): + """return all pretrain tags for the specified model architecture""" + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def get_pretrained_url(model: str, tag: str): + if model not in _PRETRAINED: + return "" + model_pretrained = _PRETRAINED[model] + if tag not in model_pretrained: + return "" + return model_pretrained[tag] + + +def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + if "openaipublic" in url: + expected_sha256 = url.split("/")[-2] + else: + expected_sha256 = "" + + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if ( + hashlib.sha256(open(download_target, "rb").read()).hexdigest() + == expected_sha256 + ): + return download_target + else: + warnings.warn( + f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" + ) + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm( + total=int(source.info().get("Content-Length")), + ncols=80, + unit="iB", + unit_scale=True, + ) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if ( + expected_sha256 + and hashlib.sha256(open(download_target, "rb").read()).hexdigest() + != expected_sha256 + ): + raise RuntimeError( + f"Model has been downloaded but the SHA256 checksum does not not match" + ) + + return download_target diff --git a/audioldm/clap/open_clip/timm_model.py b/audioldm/clap/open_clip/timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c9d1ab4666b5bab5038d44b90c9ddca5087de460 --- /dev/null +++ b/audioldm/clap/open_clip/timm_model.py @@ -0,0 +1,112 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. +""" +from collections import OrderedDict + +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import ( + AttentionPool2d as AbsAttentionPool2d, + ) +except ImportError as e: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """timm model adapter + # FIXME this adapter is a work in progress, may change in ways that break weight compat + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool="avg", + proj="linear", + drop=0.0, + pretrained=False, + ): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + + self.image_size = to_2tuple(image_size) + self.trunk = timm.create_model(model_name, pretrained=pretrained) + feat_size = self.trunk.default_cfg.get("pool_size", None) + feature_ndim = 1 if not feat_size else 2 + if pool in ("abs_attn", "rot_attn"): + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool="") + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + if pool == "abs_attn": + head_layers["pool"] = AbsAttentionPool2d( + prev_chs, feat_size=feat_size, out_features=embed_dim + ) + prev_chs = embed_dim + elif pool == "rot_attn": + head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + else: + assert proj, "projection layer needed if non-attention pooling is used." + + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == "linear": + head_layers["drop"] = nn.Dropout(drop) + head_layers["proj"] = nn.Linear(prev_chs, embed_dim) + elif proj == "mlp": + head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`" + ) + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/audioldm/clap/open_clip/tokenizer.py b/audioldm/clap/open_clip/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ee4d28450ec5dd12a79daf38cf3088e9e73c2cd5 --- /dev/null +++ b/audioldm/clap/open_clip/tokenizer.py @@ -0,0 +1,197 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + + +@lru_cache() +def default_bpe(): + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" + ) + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + if not special_tokens: + special_tokens = ["", ""] + else: + special_tokens = ["", ""] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t: t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile( + special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text + + +_tokenizer = SimpleTokenizer() + + +def tokenize( + texts: Union[str, List[str]], context_length: int = 77 +) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + result[i, : len(tokens)] = torch.tensor(tokens) + + return result diff --git a/audioldm/clap/open_clip/transform.py b/audioldm/clap/open_clip/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..77aaa722c4a5544ac50de6df35d3e922f63b111d --- /dev/null +++ b/audioldm/clap/open_clip/transform.py @@ -0,0 +1,45 @@ +from torchvision.transforms import ( + Normalize, + Compose, + RandomResizedCrop, + InterpolationMode, + ToTensor, + Resize, + CenterCrop, +) + + +def _convert_to_rgb(image): + return image.convert("RGB") + + +def image_transform( + image_size: int, + is_train: bool, + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), +): + normalize = Normalize(mean=mean, std=std) + if is_train: + return Compose( + [ + RandomResizedCrop( + image_size, + scale=(0.9, 1.0), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + ToTensor(), + normalize, + ] + ) + else: + return Compose( + [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + _convert_to_rgb, + ToTensor(), + normalize, + ] + ) diff --git a/audioldm/clap/open_clip/utils.py b/audioldm/clap/open_clip/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..de59fd2746a13742197ecdeac671d61ece3f79ba --- /dev/null +++ b/audioldm/clap/open_clip/utils.py @@ -0,0 +1,361 @@ +import numpy as np +import torch +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d +import logging +# import h5py +from tqdm import tqdm +import random +import json +import os +import pathlib + +# TODO: (yusong) this not a good place to store those information and does not scale. Need to be fixed later. +dataset_split = { + "audiocaps": ["train", "valid", "test"], + "audioset": ["balanced_train", "unbalanced_train", "eval"], + "BBCSoundEffects": ["train", "test"], + "Clotho": ["train", "test", "valid"], + "free_to_use_sounds": ["train", "test"], + "paramount_motion": ["train", "test"], + "sonniss_game_effects": ["train", "test"], + "wesoundeffects": ["train", "test"], + "MACS": ["train", "test"], + "freesound": ["train", "test"], + "FSD50K": ["train", "test", "valid"], + "fsd50k_class_label": ["train", "test", "valid"], + "esc50": ["train", "test"], + "audiostock": ["train", "test"], + "freesound_no_overlap_noesc50": ["train", "test"], + "epidemic_sound_effects": ["train", "test"], + "VGGSound": ["train", "test"], + "urbansound8k_class_label": ["train", "test"], + "audioset_t5": ["balanced_train", "unbalanced_train", "eval"], + "epidemic_sound_effects_t5": ["train", "test"], + "WavText5K": ["train", "test"], + "esc50_no_overlap": ["train", "test"], + "usd8k_no_overlap": ["train", "test"], + "fsd50k_200_class_label": ["train", "test", "valid"], +} + + +def freeze_batch_norm_2d(module, module_match={}, name=""): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance( + module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm) + ): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = ".".join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +def exist(dataset_name, dataset_type): + """ + Check if dataset exists + """ + if dataset_type in dataset_split[dataset_name]: + return True + else: + return False + + +def get_tar_path_from_dataset_name( + dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None +): + """ + Get tar path from dataset name and type + """ + output = [] + for n in dataset_names: + if full_dataset is not None and n in full_dataset: + current_dataset_types = dataset_split[n] + else: + current_dataset_types = dataset_types + for s in current_dataset_types: + tmp = [] + if islocal: + sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json" + if not os.path.exists(sizefilepath_): + sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" + else: + sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" + if not os.path.exists(sizefilepath_): + continue + sizes = json.load(open(sizefilepath_, "r")) + for k in sizes.keys(): + if islocal: + tmp.append(f"{dataset_path}/{n}/{s}/{k}") + else: + tmp.append( + f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -" + ) + if proportion != 1: + tmp = random.sample(tmp, int(proportion * len(tmp))) + output.append(tmp) + return sum(output, []) + + +def get_tar_path_from_txts(txt_path, islocal, proportion=1): + """ + Get tar path from txt path + """ + if isinstance(txt_path, (list, tuple)): + return sum( + [ + get_tar_path_from_txts( + txt_path[i], islocal=islocal, proportion=proportion + ) + for i in range(len(txt_path)) + ], + [], + ) + if isinstance(txt_path, str): + with open(txt_path) as f: + lines = f.readlines() + if islocal: + lines = [ + lines[i] + .split("\n")[0] + .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/") + for i in range(len(lines)) + ] + else: + lines = [ + lines[i].split("\n")[0].replace(".tar", ".tar -") + for i in range(len(lines)) + ] + if proportion != 1: + print("Sampling tars with proportion of {}".format(proportion)) + lines = random.sample(lines, int(proportion * len(lines))) + return lines + + +def get_mix_lambda(mixup_alpha, batch_size): + mixup_lambdas = [ + np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size) + ] + return np.array(mixup_lambdas).astype(np.float32) + + +def do_mixup(x, mixup_lambda): + """ + Args: + x: (batch_size , ...) + mixup_lambda: (batch_size,) + Returns: + out: (batch_size, ...) + """ + out = ( + x.transpose(0, -1) * mixup_lambda + + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda) + ).transpose(0, -1) + return out + + +def interpolate(x, ratio): + """Interpolate data in time domain. This is used to compensate the + resolution reduction in downsampling of a CNN. + + Args: + x: (batch_size, time_steps, classes_num) + ratio: int, ratio to interpolate + Returns: + upsampled: (batch_size, time_steps * ratio, classes_num) + """ + (batch_size, time_steps, classes_num) = x.shape + upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) + upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) + return upsampled + + +def pad_framewise_output(framewise_output, frames_num): + """Pad framewise_output to the same length as input frames. The pad value + is the same as the value of the last frame. + Args: + framewise_output: (batch_size, frames_num, classes_num) + frames_num: int, number of frames to pad + Outputs: + output: (batch_size, frames_num, classes_num) + """ + pad = framewise_output[:, -1:, :].repeat( + 1, frames_num - framewise_output.shape[1], 1 + ) + """tensor for padding""" + + output = torch.cat((framewise_output, pad), dim=1) + """(batch_size, frames_num, classes_num)""" + + +# def process_ipc(index_path, classes_num, filename): +# # load data +# logging.info("Load Data...............") +# ipc = [[] for _ in range(classes_num)] +# with h5py.File(index_path, "r") as f: +# for i in tqdm(range(len(f["target"]))): +# t_class = np.where(f["target"][i])[0] +# for t in t_class: +# ipc[t].append(i) +# print(ipc) +# np.save(filename, ipc) +# logging.info("Load Data Succeed...............") + + +def save_to_dict(s, o_={}): + sp = s.split(": ") + o_.update({sp[0]: float(sp[1])}) + return o_ + + +def get_data_from_log(txt_path): + """ + Output dictionary from out.txt log file + """ + with open(txt_path) as f: + lines = f.readlines() + val_data = {} + train_data = {} + train_losses = [] + train_losses_epoch = [] + for i in range(len(lines)): + if "| INFO |" in lines[i]: + if "Eval Epoch" in lines[i]: + if "val_loss" in lines[i]: + # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", "")) + line = lines[i].split("Eval Epoch: ")[-1] + num_epoch = int(line.split(" ")[0].split(" ")[0]) + d = { + line.split(" ")[0] + .split(" ")[1] + .replace(":", ""): float(line.split(" ")[0].split(" ")[-1]) + } + for i in range(1, len(line.split(" "))): + d = save_to_dict(line.split(" ")[i], d) + val_data[num_epoch] = d + elif "Train Epoch" in lines[i]: + num_epoch = int(lines[i].split("Train Epoch: ")[1][0]) + loss = float(lines[i].split("Loss: ")[-1].split(" (")[0]) + train_losses.append(loss) + train_losses_epoch.append(num_epoch) + for i in range(len(train_losses)): + train_data[i] = { + "num_epoch": train_losses_epoch[i], + "train_loss": train_losses[i], + } + return train_data, val_data + + +def save_p(obj, filename): + import pickle + + try: + from deepdiff import DeepDiff + except: + os.system("pip install deepdiff") + from deepdiff import DeepDiff + with open(filename, "wb") as file: + pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol + with open(filename, "rb") as file: + z = pickle.load(file) + assert ( + DeepDiff(obj, z, ignore_string_case=True) == {} + ), "there is something wrong with the saving process" + return + + +def load_p(filename): + import pickle + + with open(filename, "rb") as file: + z = pickle.load(file) + return z + + +def save_json(data, name="data.json"): + import json + + with open(name, "w") as fp: + json.dump(data, fp) + return + + +def load_json(name): + import json + + with open(name, "r") as fp: + data = json.load(fp) + return data + + +from multiprocessing import Process, Manager +from multiprocessing import Process, Value, Array +from ctypes import c_wchar + + +def load_class_label(path): + # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing + # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array + out = None + if path is not None: + if pathlib.Path(path).suffix in [".pkl", ".pickle"]: + out = load_p(path) + elif pathlib.Path(path).suffix in [".json", ".txt"]: + out = load_json(path) + elif pathlib.Path(path).suffix in [".npy", ".npz"]: + out = np.load(path) + elif pathlib.Path(path).suffix in [".csv"]: + import pandas as pd + + out = pd.read_csv(path) + return out + # if out is None: + # return None + # else: + # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False) + # val = Array('i', out.values(), lock=False) + # return (key, val) + + +from torch import optim + + +def get_optimizer(params, lr, betas, eps, momentum, optimizer_name): + if optimizer_name.lower() == "adamw": + optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps) + elif optimizer_name.lower() == "sgd": + optimizer = optim.SGD(params, lr=lr, momentum=momentum) + elif optimizer_name.lower() == "adam": + optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps) + else: + raise ValueError("optimizer name is not correct") + return optimizer diff --git a/audioldm/clap/open_clip/version.py b/audioldm/clap/open_clip/version.py new file mode 100644 index 0000000000000000000000000000000000000000..3ced3581bb601ae91b1e1da4b8f4f520855a065e --- /dev/null +++ b/audioldm/clap/open_clip/version.py @@ -0,0 +1 @@ +__version__ = "0.2.1" diff --git a/audioldm/clap/training/__init__.py b/audioldm/clap/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm/clap/training/audioset_textmap.npy b/audioldm/clap/training/audioset_textmap.npy new file mode 100644 index 0000000000000000000000000000000000000000..3da4c92d3819aaec11e5f576464a9973a6df811b --- /dev/null +++ b/audioldm/clap/training/audioset_textmap.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bada103070d92f9eadd33e1b4f45ec8583f59080ef218c966b43294bd4c86d5b +size 84448 diff --git a/audioldm/clap/training/data.py b/audioldm/clap/training/data.py new file mode 100644 index 0000000000000000000000000000000000000000..1d80d598be97d4e04f1b7f3e53a877cfe82ce667 --- /dev/null +++ b/audioldm/clap/training/data.py @@ -0,0 +1,977 @@ +import ast +import json +import logging +import math +import os +import random +# import h5py +from dataclasses import dataclass +from audioldm.clap.training.params import parse_args +# import braceexpand +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.datasets as datasets +import torchvision.transforms +# import webdataset as wds +from PIL import Image +from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler +from torch.utils.data.distributed import DistributedSampler +from functools import partial +import soundfile as sf +import io +from pathlib import Path +# import wget + +from audioldm.clap.open_clip.utils import ( + get_tar_path_from_dataset_name, + dataset_split, +) +from audioldm.clap.open_clip.utils import load_p, load_class_label +import copy + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +try: + import torchaudio +except ImportError: + torchaudio = None + +from audioldm.clap.open_clip import tokenize + + +def tokenizer(text): + return tokenize(text).squeeze(0) + + +from transformers import RobertaTokenizer + +tokenize = RobertaTokenizer.from_pretrained("roberta-base") + + +def tokenizer(text): + result = tokenize( + text, + padding="max_length", + truncation=True, + max_length=77, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} + + +# initizlied the audioset map +_AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy") +_AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True) + + +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def float32_to_int16(x): + x = np.clip(x, a_min=-1.0, a_max=1.0) + return (x * 32767.0).astype(np.int16) + + +# For Toy Dataset +# class ToyDataset(Dataset): +# def __init__(self, index_path, ipc, config, eval_mode=False): +# """Toy Dataset for testing the audioset input with text labels +# Parameters +# ---------- +# index_path: str +# the link to the h5 file of each audio +# idc: str +# the link to the npy file, the number of samples in each class +# config: dict +# the audio cfg file +# eval_model (bool): to indicate if the dataset is a testing dataset +# """ +# self.audio_cfg = config["audio_cfg"] +# self.text_cfg = config["text_cfg"] +# self.fp = h5py.File(index_path, "r") +# self.ipc = np.load(ipc, allow_pickle=True) +# self.total_size = len(self.fp["audio_name"]) +# self.classes_num = self.audio_cfg["class_num"] +# self.eval_mode = eval_mode + +# if not eval_mode: +# self.generate_queue() +# else: +# self.queue = [] +# for i in range(self.total_size): +# target = self.fp["target"][i] +# if np.sum(target) > 0: +# self.queue.append(i) +# self.total_size = len(self.queue) +# logging.info("total dataset size: %d" % (self.total_size)) +# logging.info("class num: %d" % (self.classes_num)) + +# def time_shifting(self, x): +# frame_num = len(x) +# shift_len = random.randint(0, frame_num - 1) +# new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0) +# return new_sample + +# def generate_queue(self): +# self.queue = [] +# while len(self.queue) < self.total_size: +# class_set = [*range(self.classes_num)] +# random.shuffle(class_set) +# self.queue += [ +# self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set +# ] +# self.queue = self.queue[: self.total_size] + +# logging.info("queue regenerated:%s" % (self.queue[-5:])) + +# def crop_wav(self, x): +# crop_size = self.audio_cfg["crop_size"] +# crop_pos = random.randint(0, len(x) - crop_size - 1) +# return x[crop_pos : crop_pos + crop_size] + +# def prompt_text(self, target): +# events = _AUDIOSET_MAP[np.where(target > 0)] +# event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1] +# text = tokenize(event_text)[0] +# return text + +# def __getitem__(self, index): +# """Load waveform, text, and target of an audio clip + +# Parameters +# ---------- +# index: int +# the index number +# Return +# ------ +# output: dict { +# "hdf5_path": str, +# "index_in_hdf5": int, +# "audio_name": str, +# "waveform": list (audio_length,), +# "target": list (class_num, ), +# "text": torch.tensor (context_length,) +# } +# the output dictionary +# """ +# s_index = self.queue[index] + +# audio_name = self.fp["audio_name"][s_index].decode() +# # Hardcode here CHANGE +# hdf5_path = ( +# self.fp["hdf5_path"][s_index] +# .decode() +# .replace( +# "../workspace", +# "/home/la/kechen/Research/ke_zsasp/workspace", +# ) +# ) +# r_idx = self.fp["index_in_hdf5"][s_index] +# target = self.fp["target"][s_index].astype(np.float32) +# text = self.prompt_text(target) +# with h5py.File(hdf5_path, "r") as f: +# waveform = int16_to_float32(f["waveform"][r_idx])[ +# : self.audio_cfg["clip_samples"] +# ] +# assert ( +# len(waveform) == self.audio_cfg["clip_samples"] +# ), "The sample length is not match" +# # Time shift +# # if (self.config.enable_time_shift) and (not self.eval_mode): +# # waveform = self.time_shifting(waveform) +# # # Label Enhance +# # if (self.config.crop_size is not None) and (not self.eval_mode): +# # waveform = self.crop_wav(waveform) +# # # the label enhance rate is fixed 0.5 +# # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5: +# # kidx = np.where(target)[0] +# # for k in kidx: +# # for add_key in self.class_map[k][1]: +# # target[add_key] = 1.0 +# # if len(self.class_map[k][2]) > 0: +# # add_key = random.choice(self.class_map[k][2]) +# # target[add_key] = 1.0 + +# # missing the text input +# mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :] +# mel_spec = ( +# torch.cat( +# [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0 +# ) +# .cpu() +# .numpy() +# ) +# longer = random.choice([True, False]) +# if longer == False: +# mel_spec[1:, :, :] = 0.0 +# data_dict = { +# "hdf5_path": hdf5_path, +# "index_in_hdf5": r_idx, +# "audio_name": audio_name, +# "waveform": waveform, +# "class_label": target, +# "text": text, +# "longer": longer, +# "mel_fusion": mel_spec, +# } +# return data_dict + +# def __len__(self): +# return self.total_size + + +class CsvDataset(Dataset): + def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"): + logging.debug(f"Loading csv data from {input_filename}.") + df = pd.read_csv(input_filename, sep=sep) + + self.images = df[img_key].tolist() + self.captions = df[caption_key].tolist() + self.transforms = transforms + logging.debug("Done loading data.") + + def __len__(self): + return len(self.captions) + + def __getitem__(self, idx): + images = self.transforms(Image.open(str(self.images[idx]))) + texts = tokenize([str(self.captions[idx])])[0] + return images, texts + + +@dataclass +class DataInfo: + dataloader: DataLoader + sampler: DistributedSampler + + +def preprocess_txt(text): + return tokenize([str(text)])[0] + + +def get_dataset_size(shards, sizefilepath_=None, is_local=True): + if isinstance(shards, list): + size_list = [] + for s in shards: + size_list.append( + get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0] + ) + else: + if not is_local: + for n in dataset_split.keys(): + if n in shards.split("/"): + break + for s in dataset_split[n]: + if s in shards.split("/"): + break + sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" + shards_list = list(braceexpand.braceexpand(shards)) + dir_path = os.path.dirname(shards) + if sizefilepath_ is not None: + sizes = json.load(open(sizefilepath_, "r")) + total_size = sum( + [ + int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))]) + for shard in shards_list + ] + ) + else: + sizes_filename = os.path.join(dir_path, "sizes.json") + len_filename = os.path.join(dir_path, "__len__") + if os.path.exists(sizes_filename): + sizes = json.load(open(sizes_filename, "r")) + total_size = sum( + [int(sizes[os.path.basename(shard)]) for shard in shards_list] + ) + elif os.path.exists(len_filename): + # FIXME this used to be eval(open(...)) but that seemed rather unsafe + total_size = ast.literal_eval(open(len_filename, "r").read()) + else: + raise Exception( + "Cannot find sizes file for dataset. Please specify the path to the file." + ) + # total_size = None # num samples undefined + # some common dataset sizes (at time of authors last download) + # cc3m-train: 2905954 + # cc12m: 10968539 + # LAION-400m: 407332084 + num_shards = len(shards_list) + if isinstance(shards, list): + return sum(size_list), len(shards) + else: + return total_size, num_shards + + +def get_imagenet(args, preprocess_fns, split): + assert split in ["train", "val", "v2"] + is_train = split == "train" + preprocess_train, preprocess_val = preprocess_fns + + if split == "v2": + from imagenetv2_pytorch import ImageNetV2Dataset + + dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val) + else: + if is_train: + data_path = args.imagenet_train + preprocess_fn = preprocess_train + else: + data_path = args.imagenet_val + preprocess_fn = preprocess_val + assert data_path + + dataset = datasets.ImageFolder(data_path, transform=preprocess_fn) + + if is_train: + idxs = np.zeros(len(dataset.targets)) + target_array = np.array(dataset.targets) + k = 50 + for c in range(1000): + m = target_array == c + n = len(idxs[m]) + arr = np.zeros(n) + arr[:k] = 1 + np.random.shuffle(arr) + idxs[m] = arr + + idxs = idxs.astype("int") + sampler = SubsetRandomSampler(np.where(idxs)[0]) + else: + sampler = None + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=args.workers, + sampler=sampler, + ) + + return DataInfo(dataloader, sampler) + + +def count_samples(dataloader): + os.environ["WDS_EPOCH"] = "0" + n_elements, n_batches = 0, 0 + for images, texts in dataloader: + n_batches += 1 + n_elements += len(images) + assert len(images) == len(texts) + return n_elements, n_batches + + +def filter_no_caption(sample): + return "txt" in sample + + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" + logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.") + return True + + +_SHARD_SHUFFLE_SIZE = 2000 +_SHARD_SHUFFLE_INITIAL = 500 +_SAMPLE_SHUFFLE_SIZE = 5000 +_SAMPLE_SHUFFLE_INITIAL = 1000 + + +def sample_prop(sizefile, inputs, proportion, is_local=True): + """ + Sample a proportion of the data. + """ + file_path_dict = { + os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0] + for i in range(len(inputs)) + } + sampled_filepath_dict = {} + sampled_size_dict = {} + if not is_local: + if os.path.exists("sizes.json"): + os.remove("sizes.json") + wget.download(sizefile, "sizes.json") + sizefile = "sizes.json" + with open(sizefile, "r", encoding="UTF-8") as f: + load_dict = json.load(f) + L = int(len(file_path_dict) * proportion) + subkeys = random.sample(file_path_dict.keys(), L) + for k in subkeys: + sampled_size_dict[k] = load_dict[k] + sampled_filepath_dict[k] = file_path_dict[k] + return ( + sum(sampled_size_dict.values()), + L, + [os.path.join(v, k) for k, v in sampled_filepath_dict.items()], + sampled_size_dict, + ) + + +def get_mel(audio_data, audio_cfg): + # mel shape: (n_mels, T) + mel = torchaudio.transforms.MelSpectrogram( + sample_rate=audio_cfg["sample_rate"], + n_fft=audio_cfg["window_size"], + win_length=audio_cfg["window_size"], + hop_length=audio_cfg["hop_size"], + center=True, + pad_mode="reflect", + power=2.0, + norm=None, + onesided=True, + n_mels=64, + f_min=audio_cfg["fmin"], + f_max=audio_cfg["fmax"], + ).to(audio_data.device) + mel = mel(audio_data) + # Align to librosa: + # librosa_melspec = librosa.feature.melspectrogram( + # waveform, + # sr=audio_cfg['sample_rate'], + # n_fft=audio_cfg['window_size'], + # hop_length=audio_cfg['hop_size'], + # win_length=audio_cfg['window_size'], + # center=True, + # pad_mode="reflect", + # power=2.0, + # n_mels=64, + # norm=None, + # htk=True, + # f_min=audio_cfg['fmin'], + # f_max=audio_cfg['fmax'] + # ) + # we use log mel spectrogram as input + mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel) + return mel.T # (T, n_mels) + + +def get_audio_features( + sample, audio_data, max_len, data_truncating, data_filling, audio_cfg +): + """ + Calculate and add audio features to sample. + Sample: a dict containing all the data of current sample. + audio_data: a tensor of shape (T) containing audio data. + max_len: the maximum length of audio data. + data_truncating: the method of truncating data. + data_filling: the method of filling data. + audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg']. + """ + with torch.no_grad(): + if len(audio_data) > max_len: + if data_truncating == "rand_trunc": + longer = torch.tensor([True]) + elif data_truncating == "fusion": + # fusion + mel = get_mel(audio_data, audio_cfg) + # split to three parts + chunk_frames = ( + max_len // audio_cfg["hop_size"] + 1 + ) # the +1 related to how the spectrogram is computed + total_frames = mel.shape[0] + if chunk_frames == total_frames: + # there is a corner case where the audio length is + # larger than max_len but smaller than max_len+hop_size. + # In this case, we just use the whole audio. + mel_fusion = torch.stack([mel, mel, mel, mel], dim=0) + sample["mel_fusion"] = mel_fusion + longer = torch.tensor([False]) + else: + ranges = np.array_split( + list(range(0, total_frames - chunk_frames + 1)), 3 + ) + # print('total_frames-chunk_frames:', total_frames-chunk_frames, + # 'len(audio_data):', len(audio_data), + # 'chunk_frames:', chunk_frames, + # 'total_frames:', total_frames) + if len(ranges[1]) == 0: + # if the audio is too short, we just use the first chunk + ranges[1] = [0] + if len(ranges[2]) == 0: + # if the audio is too short, we just use the first chunk + ranges[2] = [0] + # randomly choose index for each part + idx_front = np.random.choice(ranges[0]) + idx_middle = np.random.choice(ranges[1]) + idx_back = np.random.choice(ranges[2]) + # select mel + mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :] + mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :] + mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :] + + # shrink the mel + mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, 64])( + mel[None] + )[0] + # logging.info(f"mel_shrink.shape: {mel_shrink.shape}") + + # stack + mel_fusion = torch.stack( + [mel_chunk_front, mel_chunk_middle, mel_chunk_back, mel_shrink], + dim=0, + ) + sample["mel_fusion"] = mel_fusion + longer = torch.tensor([True]) + else: + raise NotImplementedError( + f"data_truncating {data_truncating} not implemented" + ) + # random crop to max_len (for compatibility) + overflow = len(audio_data) - max_len + idx = np.random.randint(0, overflow + 1) + audio_data = audio_data[idx : idx + max_len] + + else: # padding if too short + if len(audio_data) < max_len: # do nothing if equal + if data_filling == "repeatpad": + n_repeat = int(max_len / len(audio_data)) + audio_data = audio_data.repeat(n_repeat) + # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0) + # audio_data = F.interpolate(audio_data,size=max_len,mode="bicubic")[0,0,0] + audio_data = F.pad( + audio_data, + (0, max_len - len(audio_data)), + mode="constant", + value=0, + ) + elif data_filling == "pad": + audio_data = F.pad( + audio_data, + (0, max_len - len(audio_data)), + mode="constant", + value=0, + ) + elif data_filling == "repeat": + n_repeat = int(max_len / len(audio_data)) + audio_data = audio_data.repeat(n_repeat + 1)[:max_len] + else: + raise NotImplementedError( + f"data_filling {data_filling} not implemented" + ) + if data_truncating == "fusion": + mel = get_mel(audio_data, audio_cfg) + mel_fusion = torch.stack([mel, mel, mel, mel], dim=0) + sample["mel_fusion"] = mel_fusion + longer = torch.tensor([False]) + + sample["longer"] = longer + sample["waveform"] = audio_data + + return sample + + +def preprocess( + sample, + audio_ext, + text_ext, + max_len, + audio_cfg, + class_index_dict=None, + data_filling="pad", + data_truncating="rand_trunc", + text_augment_selection=None, +): + """ + Preprocess a single sample for wdsdataloader. + """ + audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) + audio_data = int16_to_float32(float32_to_int16(audio_data)) + audio_data = torch.tensor(audio_data).float() + + # TODO: (yusong) to be include in the future + # # if torchaudio not installed, use soundfile to load audio + # if torchaudio is None: + # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) + # audio_data = torch.tensor(audio_data).float() + # else: + # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py + # with tempfile.TemporaryDirectory() as dirname: + # os.makedirs(dirname, exist_ok=True) + # fname = os.path.join(dirname, f"file.flac") + # with open(fname, "wb") as stream: + # stream.write(sample[audio_ext]) + # audio_data, orig_sr = torchaudio.load(fname) + # audio_data = audio_data[0, :].float() + + sample = get_audio_features( + sample, audio_data, max_len, data_truncating, data_filling, audio_cfg + ) + del sample[audio_ext] + + try: + json_dict_raw = json.loads(sample[text_ext].decode("utf-8")) + except: + print("sample[__url__]:", sample["__url__"]) + + # For selecting augmented text from dataset + if text_augment_selection is None or text_augment_selection == "none": + texts = json_dict_raw["text"] + elif text_augment_selection == "all": + if "text_augment_all" in json_dict_raw.keys(): + texts = json_dict_raw["text_augment_all"] + else: + texts = json_dict_raw["text"] + elif text_augment_selection == "augment_only": + if "text_augment_all" in json_dict_raw.keys(): + if json_dict_raw["text_augment_t5"] is None: + texts = json_dict_raw["text"] + else: + texts = json_dict_raw["text_augment_t5"] + else: + texts = json_dict_raw["text"] + else: + raise NotImplementedError( + f"text_augment_selection {text_augment_selection} not implemented" + ) + sample["full_text"] = texts + + if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1: + texts = random.choice(texts) + sample["raw_text"] = texts + sample["text"] = tokenizer(texts) # text shape: [num_token] + if class_index_dict is not None: + # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing + # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array + # key, val = class_index_dict + # key = key[:].split('\n') + # _dict = {k: v for k, v in zip(key, val)} + sample["class_label"] = np.zeros(len(class_index_dict.keys())) + for x in json_dict_raw["tag"]: + sample["class_label"][class_index_dict[x]] = 1 + sample["class_label"] = torch.tensor(sample["class_label"]).float() + del sample[text_ext] + sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext + sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext + sample["audio_orig_sr"] = orig_sr + return sample + + +def collate_fn(batch): + """ + Collate function for wdsdataloader. + batch: a list of dict, each dict is a sample + """ + # concatenate values in each dictionary. if it is a tensor, concatenate. if it is a list, extend. + batch_dict = {} + for k in batch[0].keys(): + if isinstance(batch[0][k], dict): # dealwith bert tokenizer output + batch_dict[k] = {} + for kk in batch[0][k].keys(): + tmp = [] + for i in range(len(batch)): + tmp.append(batch[i][k][kk]) + batch_dict[k][kk] = torch.vstack(tmp) + elif isinstance(batch[0][k], torch.Tensor): + batch_dict[k] = torch.stack([sample[k] for sample in batch]) + elif isinstance(batch[0][k], np.ndarray): + batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch])) + else: + batch_dict[k] = [sample[k] for sample in batch] + return batch_dict + + +def get_wds_dataset( + args, + model_cfg, + is_train, + audio_ext="flac", + text_ext="json", + max_len=480000, + proportion=1.0, + sizefilepath_=None, + is_local=None, +): + """ + Get a dataset for wdsdataloader. + """ + if is_local is None and (not args.remotedata is None): + is_local = not args.remotedata + + input_shards = args.train_data if is_train else args.val_data + assert input_shards is not None + + if not sizefilepath_ is None: + sizefilepath = sizefilepath_ + else: + sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json") + + if proportion != 1.0: + num_samples, num_shards, input_shards, _ = sample_prop( + sizefilepath, input_shards, proportion, is_local=is_local + ) + else: + num_samples, num_shards = get_dataset_size( + input_shards, sizefilepath_=sizefilepath_, is_local=is_local + ) + + if not num_samples: + if is_train: + num_samples = args.train_num_samples + if not num_samples: + raise RuntimeError( + "Currently, number of dataset samples must be specified for training dataset. " + "Please specify via `--train-num-samples` if no dataset length info present." + ) + else: + num_samples = ( + args.val_num_samples or 0 + ) # eval will just exhaust the iterator if not specified + + pipeline = [wds.SimpleShardList(input_shards)] + # at this point we have an iterator over all the shards + # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node + if is_train or args.parallel_eval: + pipeline.extend( + [ + wds.detshuffle( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + ), + wds.split_by_node, + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker at each node + wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + rng=random.Random(args.seed), + ), + # wds.repeatedly, # FIXME determine if this is beneficial + ] + ) + else: + pipeline.extend( + [ + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ] + ) + pipeline.append( + wds.map( + partial( + preprocess, + audio_ext=audio_ext, + text_ext=text_ext, + max_len=max_len, + audio_cfg=model_cfg["audio_cfg"], + class_index_dict=copy.deepcopy(args.class_index_dict), + data_filling=args.data_filling, + data_truncating=args.data_truncating, + text_augment_selection=args.text_augment_selection, + ) + ), + ) + + pipeline.append( + wds.batched( + args.batch_size, + partial=not (is_train or args.parallel_eval), + collation_fn=collate_fn, + ) + ) + + dataset = wds.DataPipeline(*pipeline) + if is_train or args.parallel_eval: + # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples. + # (yusong): See comments below. + # roll over and repeat a few samples to get same number of full batches on each node + global_batch_size = args.batch_size * args.world_size + num_batches = math.ceil(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = math.ceil( + num_batches / num_workers + ) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + dataset = dataset.with_epoch( + num_worker_batches + ) # each worker is iterating over this + else: + # last batches are partial, eval is done on single (master) node + num_batches = math.ceil(num_samples / args.batch_size) + + kwargs = {} + if args.horovod: # multi-node training on summit + kwargs["multiprocessing_context"] = "forkserver" + + dataloader = wds.WebLoader( + dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs + ) + + # FIXME not clear which approach is better, with_epoch before vs after dataloader? + # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 + # if is_train: + # # roll over and repeat a few samples to get same number of full batches on each node + # global_batch_size = args.batch_size * args.world_size + # num_batches = math.ceil(num_samples / global_batch_size) + # num_workers = max(1, args.workers) + # num_batches = math.ceil(num_batches / num_workers) * num_workers + # num_samples = num_batches * global_batch_size + # dataloader = dataloader.with_epoch(num_batches) + # else: + # # last batches are partial, eval is done on single (master) node + # num_batches = math.ceil(num_samples / args.batch_size) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader, None) + + +def wds_batch_list2dict( + batch, + keys=[ + "__url__", + "__key__", + "waveform", + "text", + "raw_text", + "audio_name", + "text_name", + "audio_orig_sr", + ], +): + """ + Return a dictionary of the batch, with keys as the names of the fields. + """ + assert len(keys) == len( + batch + ), "batch must have same number of keys as keys argument" + return {keys[i]: batch[i] for i in range(len(batch))} + + +def get_csv_dataset(args, preprocess_fn, is_train): + input_filename = args.train_data if is_train else args.val_data + assert input_filename + dataset = CsvDataset( + input_filename, + preprocess_fn, + img_key=args.csv_img_key, + caption_key=args.csv_caption_key, + sep=args.csv_separator, + ) + num_samples = len(dataset) + sampler = DistributedSampler(dataset) if args.distributed and is_train else None + shuffle = is_train and sampler is None + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=shuffle, + num_workers=args.workers, + pin_memory=True, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + + +def get_toy_dataset(args, model_cfg, is_train): + index_path = args.train_data if is_train else args.val_data + ipc_path = args.train_ipc if is_train else args.val_ipc + assert index_path and ipc_path + eval_mode = not is_train + dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode) + + num_samples = len(dataset) + sampler = ( + DistributedSampler(dataset, shuffle=False) + if args.distributed and is_train + else None + ) + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.workers, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + + +def get_dataset_fn(data_path, dataset_type): + if dataset_type == "webdataset": + return get_wds_dataset + elif dataset_type == "csv": + return get_csv_dataset + elif dataset_type == "auto": + ext = data_path.split(".")[-1] + if ext in ["csv", "tsv"]: + return get_csv_dataset + elif ext in ["tar"]: + return get_wds_dataset + else: + raise ValueError( + f"Tried to figure out dataset type, but failed for extention {ext}." + ) + elif dataset_type == "toy": + return get_toy_dataset + else: + raise ValueError(f"Unsupported dataset type: {dataset_type}") + + +def get_data(args, model_cfg): + data = {} + + args.class_index_dict = load_class_label(args.class_label_path) + + if args.datasetinfos is None: + args.datasetinfos = ["train", "unbalanced_train", "balanced_train"] + if args.dataset_type == "webdataset": + args.train_data = get_tar_path_from_dataset_name( + args.datasetnames, + args.datasetinfos, + islocal=not args.remotedata, + proportion=args.dataset_proportion, + dataset_path=args.datasetpath, + full_dataset=args.full_train_dataset, + ) + + if args.full_train_dataset is None: + args.full_train_dataset = [] + if args.exclude_eval_dataset is None: + args.exclude_eval_dataset = [] + excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset + + val_dataset_names = ( + [n for n in args.datasetnames if n not in excluded_eval_datasets] + if excluded_eval_datasets + else args.datasetnames + ) + args.val_dataset_names = val_dataset_names + args.val_data = get_tar_path_from_dataset_name( + val_dataset_names, + ["valid", "test", "eval"], + islocal=not args.remotedata, + proportion=1, + dataset_path=args.datasetpath, + full_dataset=None, + ) + + if args.train_data: + data["train"] = get_dataset_fn(args.train_data, args.dataset_type)( + args, model_cfg, is_train=True + ) + + if args.val_data: + data["val"] = get_dataset_fn(args.val_data, args.dataset_type)( + args, model_cfg, is_train=False + ) + + return data diff --git a/audioldm/clap/training/distributed.py b/audioldm/clap/training/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..2fa61f76c5cc3ab9f6a9643042afa8e1f2e1cb7f --- /dev/null +++ b/audioldm/clap/training/distributed.py @@ -0,0 +1,150 @@ +import os + +import torch +import socket + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +def is_global_master(args): + return args.rank == 0 + + +def is_local_master(args): + return args.local_rank == 0 + + +def is_master(args, local=False): + return is_local_master(args) if local else is_global_master(args) + + +def is_using_horovod(): + # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set + # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... + ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] + pmi_vars = ["PMI_RANK", "PMI_SIZE"] + if all([var in os.environ for var in ompi_vars]) or all( + [var in os.environ for var in pmi_vars] + ): + return True + else: + return False + + +def is_using_distributed(): + if "WORLD_SIZE" in os.environ: + return int(os.environ["WORLD_SIZE"]) > 1 + if "SLURM_NTASKS" in os.environ: + return int(os.environ["SLURM_NTASKS"]) > 1 + return False + + +def world_info_from_env(): + local_rank = 0 + for v in ( + "SLURM_LOCALID", + "MPI_LOCALRANKID", + "OMPI_COMM_WORLD_LOCAL_RANK", + "LOCAL_RANK", + ): + if v in os.environ: + local_rank = int(os.environ[v]) + break + global_rank = 0 + for v in ("SLURM_PROCID", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "RANK"): + if v in os.environ: + global_rank = int(os.environ[v]) + break + world_size = 1 + for v in ("SLURM_NTASKS", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "WORLD_SIZE"): + if v in os.environ: + world_size = int(os.environ[v]) + break + + return local_rank, global_rank, world_size + + +def init_distributed_device(args): + # Distributed training = training on more than one GPU. + # Works in both single and multi-node scenarios. + args.distributed = False + args.world_size = 1 + args.rank = 0 # global rank + args.local_rank = 0 + if args.horovod: + assert hvd is not None, "Horovod is not installed" + hvd.init() + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + args.local_rank = local_rank + args.rank = world_rank + args.world_size = world_size + # args.local_rank = int(hvd.local_rank()) + # args.rank = hvd.rank() + # args.world_size = hvd.size() + args.distributed = True + os.environ["LOCAL_RANK"] = str(args.local_rank) + os.environ["RANK"] = str(args.rank) + os.environ["WORLD_SIZE"] = str(args.world_size) + print( + f"Distributed training: local_rank={args.local_rank}, " + f"rank={args.rank}, world_size={args.world_size}, " + f"hostname={socket.gethostname()}, pid={os.getpid()}" + ) + elif is_using_distributed(): + if "SLURM_PROCID" in os.environ: + # DDP via SLURM + args.local_rank, args.rank, args.world_size = world_info_from_env() + # SLURM var -> torch.distributed vars in case needed + os.environ["LOCAL_RANK"] = str(args.local_rank) + os.environ["RANK"] = str(args.rank) + os.environ["WORLD_SIZE"] = str(args.world_size) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + elif "OMPI_COMM_WORLD_SIZE" in os.environ: # using Summit cluster + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + args.local_rank = local_rank + args.rank = world_rank + args.world_size = world_size + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + else: + # DDP via torchrun, torch.distributed.launch + args.local_rank, _, _ = world_info_from_env() + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url + ) + args.world_size = torch.distributed.get_world_size() + args.rank = torch.distributed.get_rank() + args.distributed = True + print( + f"Distributed training: local_rank={args.local_rank}, " + f"rank={args.rank}, world_size={args.world_size}, " + f"hostname={socket.gethostname()}, pid={os.getpid()}" + ) + + if torch.cuda.is_available(): + if args.distributed and not args.no_set_device_rank: + device = "cuda:%d" % args.local_rank + else: + device = "cuda:0" + torch.cuda.set_device(device) + else: + device = "cpu" + args.device = device + device = torch.device(device) + return device diff --git a/audioldm/clap/training/imagenet_zeroshot_data.py b/audioldm/clap/training/imagenet_zeroshot_data.py new file mode 100644 index 0000000000000000000000000000000000000000..d32e55328d6799ccb8d61625f43abb80a33d6c17 --- /dev/null +++ b/audioldm/clap/training/imagenet_zeroshot_data.py @@ -0,0 +1,1088 @@ +# NOTE: This script is currently not supported for CLAP. + +imagenet_classnames = [ + "tench", + "goldfish", + "great white shark", + "tiger shark", + "hammerhead shark", + "electric ray", + "stingray", + "rooster", + "hen", + "ostrich", + "brambling", + "goldfinch", + "house finch", + "junco", + "indigo bunting", + "American robin", + "bulbul", + "jay", + "magpie", + "chickadee", + "American dipper", + "kite (bird of prey)", + "bald eagle", + "vulture", + "great grey owl", + "fire salamander", + "smooth newt", + "newt", + "spotted salamander", + "axolotl", + "American bullfrog", + "tree frog", + "tailed frog", + "loggerhead sea turtle", + "leatherback sea turtle", + "mud turtle", + "terrapin", + "box turtle", + "banded gecko", + "green iguana", + "Carolina anole", + "desert grassland whiptail lizard", + "agama", + "frilled-necked lizard", + "alligator lizard", + "Gila monster", + "European green lizard", + "chameleon", + "Komodo dragon", + "Nile crocodile", + "American alligator", + "triceratops", + "worm snake", + "ring-necked snake", + "eastern hog-nosed snake", + "smooth green snake", + "kingsnake", + "garter snake", + "water snake", + "vine snake", + "night snake", + "boa constrictor", + "African rock python", + "Indian cobra", + "green mamba", + "sea snake", + "Saharan horned viper", + "eastern diamondback rattlesnake", + "sidewinder rattlesnake", + "trilobite", + "harvestman", + "scorpion", + "yellow garden spider", + "barn spider", + "European garden spider", + "southern black widow", + "tarantula", + "wolf spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse", + "prairie grouse", + "peafowl", + "quail", + "partridge", + "african grey parrot", + "macaw", + "sulphur-crested cockatoo", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "duck", + "red-breasted merganser", + "goose", + "black swan", + "tusker", + "echidna", + "platypus", + "wallaby", + "koala", + "wombat", + "jellyfish", + "sea anemone", + "brain coral", + "flatworm", + "nematode", + "conch", + "snail", + "slug", + "sea slug", + "chiton", + "chambered nautilus", + "Dungeness crab", + "rock crab", + "fiddler crab", + "red king crab", + "American lobster", + "spiny lobster", + "crayfish", + "hermit crab", + "isopod", + "white stork", + "black stork", + "spoonbill", + "flamingo", + "little blue heron", + "great egret", + "bittern bird", + "crane bird", + "limpkin", + "common gallinule", + "American coot", + "bustard", + "ruddy turnstone", + "dunlin", + "common redshank", + "dowitcher", + "oystercatcher", + "pelican", + "king penguin", + "albatross", + "grey whale", + "killer whale", + "dugong", + "sea lion", + "Chihuahua", + "Japanese Chin", + "Maltese", + "Pekingese", + "Shih Tzu", + "King Charles Spaniel", + "Papillon", + "toy terrier", + "Rhodesian Ridgeback", + "Afghan Hound", + "Basset Hound", + "Beagle", + "Bloodhound", + "Bluetick Coonhound", + "Black and Tan Coonhound", + "Treeing Walker Coonhound", + "English foxhound", + "Redbone Coonhound", + "borzoi", + "Irish Wolfhound", + "Italian Greyhound", + "Whippet", + "Ibizan Hound", + "Norwegian Elkhound", + "Otterhound", + "Saluki", + "Scottish Deerhound", + "Weimaraner", + "Staffordshire Bull Terrier", + "American Staffordshire Terrier", + "Bedlington Terrier", + "Border Terrier", + "Kerry Blue Terrier", + "Irish Terrier", + "Norfolk Terrier", + "Norwich Terrier", + "Yorkshire Terrier", + "Wire Fox Terrier", + "Lakeland Terrier", + "Sealyham Terrier", + "Airedale Terrier", + "Cairn Terrier", + "Australian Terrier", + "Dandie Dinmont Terrier", + "Boston Terrier", + "Miniature Schnauzer", + "Giant Schnauzer", + "Standard Schnauzer", + "Scottish Terrier", + "Tibetan Terrier", + "Australian Silky Terrier", + "Soft-coated Wheaten Terrier", + "West Highland White Terrier", + "Lhasa Apso", + "Flat-Coated Retriever", + "Curly-coated Retriever", + "Golden Retriever", + "Labrador Retriever", + "Chesapeake Bay Retriever", + "German Shorthaired Pointer", + "Vizsla", + "English Setter", + "Irish Setter", + "Gordon Setter", + "Brittany dog", + "Clumber Spaniel", + "English Springer Spaniel", + "Welsh Springer Spaniel", + "Cocker Spaniel", + "Sussex Spaniel", + "Irish Water Spaniel", + "Kuvasz", + "Schipperke", + "Groenendael dog", + "Malinois", + "Briard", + "Australian Kelpie", + "Komondor", + "Old English Sheepdog", + "Shetland Sheepdog", + "collie", + "Border Collie", + "Bouvier des Flandres dog", + "Rottweiler", + "German Shepherd Dog", + "Dobermann", + "Miniature Pinscher", + "Greater Swiss Mountain Dog", + "Bernese Mountain Dog", + "Appenzeller Sennenhund", + "Entlebucher Sennenhund", + "Boxer", + "Bullmastiff", + "Tibetan Mastiff", + "French Bulldog", + "Great Dane", + "St. Bernard", + "husky", + "Alaskan Malamute", + "Siberian Husky", + "Dalmatian", + "Affenpinscher", + "Basenji", + "pug", + "Leonberger", + "Newfoundland dog", + "Great Pyrenees dog", + "Samoyed", + "Pomeranian", + "Chow Chow", + "Keeshond", + "brussels griffon", + "Pembroke Welsh Corgi", + "Cardigan Welsh Corgi", + "Toy Poodle", + "Miniature Poodle", + "Standard Poodle", + "Mexican hairless dog (xoloitzcuintli)", + "grey wolf", + "Alaskan tundra wolf", + "red wolf or maned wolf", + "coyote", + "dingo", + "dhole", + "African wild dog", + "hyena", + "red fox", + "kit fox", + "Arctic fox", + "grey fox", + "tabby cat", + "tiger cat", + "Persian cat", + "Siamese cat", + "Egyptian Mau", + "cougar", + "lynx", + "leopard", + "snow leopard", + "jaguar", + "lion", + "tiger", + "cheetah", + "brown bear", + "American black bear", + "polar bear", + "sloth bear", + "mongoose", + "meerkat", + "tiger beetle", + "ladybug", + "ground beetle", + "longhorn beetle", + "leaf beetle", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant", + "grasshopper", + "cricket insect", + "stick insect", + "cockroach", + "praying mantis", + "cicada", + "leafhopper", + "lacewing", + "dragonfly", + "damselfly", + "red admiral butterfly", + "ringlet butterfly", + "monarch butterfly", + "small white butterfly", + "sulphur butterfly", + "gossamer-winged butterfly", + "starfish", + "sea urchin", + "sea cucumber", + "cottontail rabbit", + "hare", + "Angora rabbit", + "hamster", + "porcupine", + "fox squirrel", + "marmot", + "beaver", + "guinea pig", + "common sorrel horse", + "zebra", + "pig", + "wild boar", + "warthog", + "hippopotamus", + "ox", + "water buffalo", + "bison", + "ram (adult male sheep)", + "bighorn sheep", + "Alpine ibex", + "hartebeest", + "impala (antelope)", + "gazelle", + "arabian camel", + "llama", + "weasel", + "mink", + "European polecat", + "black-footed ferret", + "otter", + "skunk", + "badger", + "armadillo", + "three-toed sloth", + "orangutan", + "gorilla", + "chimpanzee", + "gibbon", + "siamang", + "guenon", + "patas monkey", + "baboon", + "macaque", + "langur", + "black-and-white colobus", + "proboscis monkey", + "marmoset", + "white-headed capuchin", + "howler monkey", + "titi monkey", + "Geoffroy's spider monkey", + "common squirrel monkey", + "ring-tailed lemur", + "indri", + "Asian elephant", + "African bush elephant", + "red panda", + "giant panda", + "snoek fish", + "eel", + "silver salmon", + "rock beauty fish", + "clownfish", + "sturgeon", + "gar fish", + "lionfish", + "pufferfish", + "abacus", + "abaya", + "academic gown", + "accordion", + "acoustic guitar", + "aircraft carrier", + "airliner", + "airship", + "altar", + "ambulance", + "amphibious vehicle", + "analog clock", + "apiary", + "apron", + "trash can", + "assault rifle", + "backpack", + "bakery", + "balance beam", + "balloon", + "ballpoint pen", + "Band-Aid", + "banjo", + "baluster / handrail", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel", + "wheelbarrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "swimming cap", + "bath towel", + "bathtub", + "station wagon", + "lighthouse", + "beaker", + "military hat (bearskin or shako)", + "beer bottle", + "beer glass", + "bell tower", + "baby bib", + "tandem bicycle", + "bikini", + "ring binder", + "binoculars", + "birdhouse", + "boathouse", + "bobsleigh", + "bolo tie", + "poke bonnet", + "bookcase", + "bookstore", + "bottle cap", + "hunting bow", + "bow tie", + "brass memorial plaque", + "bra", + "breakwater", + "breastplate", + "broom", + "bucket", + "buckle", + "bulletproof vest", + "high-speed train", + "butcher shop", + "taxicab", + "cauldron", + "candle", + "cannon", + "canoe", + "can opener", + "cardigan", + "car mirror", + "carousel", + "tool kit", + "cardboard box / carton", + "car wheel", + "automated teller machine", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello", + "mobile phone", + "chain", + "chain-link fence", + "chain mail", + "chainsaw", + "storage chest", + "chiffonier", + "bell or wind chime", + "china cabinet", + "Christmas stocking", + "church", + "movie theater", + "cleaver", + "cliff dwelling", + "cloak", + "clogs", + "cocktail shaker", + "coffee mug", + "coffeemaker", + "spiral or coil", + "combination lock", + "computer keyboard", + "candy store", + "container ship", + "convertible", + "corkscrew", + "cornet", + "cowboy boot", + "cowboy hat", + "cradle", + "construction crane", + "crash helmet", + "crate", + "infant bed", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam", + "desk", + "desktop computer", + "rotary dial telephone", + "diaper", + "digital clock", + "digital watch", + "dining table", + "dishcloth", + "dishwasher", + "disc brake", + "dock", + "dog sled", + "dome", + "doormat", + "drilling rig", + "drum", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso machine", + "face powder", + "feather boa", + "filing cabinet", + "fireboat", + "fire truck", + "fire screen", + "flagpole", + "flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster bed", + "freight car", + "French horn", + "frying pan", + "fur coat", + "garbage truck", + "gas mask or respirator", + "gas pump", + "goblet", + "go-kart", + "golf ball", + "golf cart", + "gondola", + "gong", + "gown", + "grand piano", + "greenhouse", + "radiator grille", + "grocery store", + "guillotine", + "hair clip", + "hair spray", + "half-track", + "hammer", + "hamper", + "hair dryer", + "hand-held computer", + "handkerchief", + "hard disk drive", + "harmonica", + "harp", + "combine harvester", + "hatchet", + "holster", + "home theater", + "honeycomb", + "hook", + "hoop skirt", + "gymnastic horizontal bar", + "horse-drawn vehicle", + "hourglass", + "iPod", + "clothes iron", + "carved pumpkin", + "jeans", + "jeep", + "T-shirt", + "jigsaw puzzle", + "rickshaw", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat", + "ladle", + "lampshade", + "laptop computer", + "lawn mower", + "lens cap", + "letter opener", + "library", + "lifeboat", + "lighter", + "limousine", + "ocean liner", + "lipstick", + "slip-on shoe", + "lotion", + "music speaker", + "loupe magnifying glass", + "sawmill", + "magnetic compass", + "messenger bag", + "mailbox", + "tights", + "one-piece bathing suit", + "manhole cover", + "maraca", + "marimba", + "mask", + "matchstick", + "maypole", + "maze", + "measuring cup", + "medicine cabinet", + "megalith", + "microphone", + "microwave oven", + "military uniform", + "milk can", + "minibus", + "miniskirt", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home", + "ford model t", + "modem", + "monastery", + "monitor", + "moped", + "mortar and pestle", + "graduation cap", + "mosque", + "mosquito net", + "vespa", + "mountain bike", + "tent", + "computer mouse", + "mousetrap", + "moving van", + "muzzle", + "metal nail", + "neck brace", + "necklace", + "baby pacifier", + "notebook computer", + "obelisk", + "oboe", + "ocarina", + "odometer", + "oil filter", + "pipe organ", + "oscilloscope", + "overskirt", + "bullock cart", + "oxygen mask", + "product packet / packaging", + "paddle", + "paddle wheel", + "padlock", + "paintbrush", + "pajamas", + "palace", + "pan flute", + "paper towel", + "parachute", + "parallel bars", + "park bench", + "parking meter", + "railroad car", + "patio", + "payphone", + "pedestal", + "pencil case", + "pencil sharpener", + "perfume", + "Petri dish", + "photocopier", + "plectrum", + "Pickelhaube", + "picket fence", + "pickup truck", + "pier", + "piggy bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate ship", + "drink pitcher", + "block plane", + "planetarium", + "plastic bag", + "plate rack", + "farm plow", + "plunger", + "Polaroid camera", + "pole", + "police van", + "poncho", + "pool table", + "soda bottle", + "plant pot", + "potter's wheel", + "power drill", + "prayer rug", + "printer", + "prison", + "missile", + "projector", + "hockey puck", + "punching bag", + "purse", + "quill", + "quilt", + "race car", + "racket", + "radiator", + "radio", + "radio telescope", + "rain barrel", + "recreational vehicle", + "fishing casting reel", + "reflex camera", + "refrigerator", + "remote control", + "restaurant", + "revolver", + "rifle", + "rocking chair", + "rotisserie", + "eraser", + "rugby ball", + "ruler measuring stick", + "sneaker", + "safe", + "safety pin", + "salt shaker", + "sandal", + "sarong", + "saxophone", + "scabbard", + "weighing scale", + "school bus", + "schooner", + "scoreboard", + "CRT monitor", + "screw", + "screwdriver", + "seat belt", + "sewing machine", + "shield", + "shoe store", + "shoji screen / room divider", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "balaclava ski mask", + "sleeping bag", + "slide rule", + "sliding door", + "slot machine", + "snorkel", + "snowmobile", + "snowplow", + "soap dispenser", + "soccer ball", + "sock", + "solar thermal collector", + "sombrero", + "soup bowl", + "keyboard space bar", + "space heater", + "space shuttle", + "spatula", + "motorboat", + "spider web", + "spindle", + "sports car", + "spotlight", + "stage", + "steam locomotive", + "through arch bridge", + "steel drum", + "stethoscope", + "scarf", + "stone wall", + "stopwatch", + "stove", + "strainer", + "tram", + "stretcher", + "couch", + "stupa", + "submarine", + "suit", + "sundial", + "sunglasses", + "sunglasses", + "sunscreen", + "suspension bridge", + "mop", + "sweatshirt", + "swim trunks / shorts", + "swing", + "electrical switch", + "syringe", + "table lamp", + "tank", + "tape player", + "teapot", + "teddy bear", + "television", + "tennis ball", + "thatched roof", + "front curtain", + "thimble", + "threshing machine", + "throne", + "tile roof", + "toaster", + "tobacco shop", + "toilet seat", + "torch", + "totem pole", + "tow truck", + "toy store", + "tractor", + "semi-trailer truck", + "tray", + "trench coat", + "tricycle", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus", + "trombone", + "hot tub", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle", + "upright piano", + "vacuum cleaner", + "vase", + "vaulted or arched ceiling", + "velvet fabric", + "vending machine", + "vestment", + "viaduct", + "violin", + "volleyball", + "waffle iron", + "wall clock", + "wallet", + "wardrobe", + "military aircraft", + "sink", + "washing machine", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "hair wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "airplane wing", + "wok", + "wooden spoon", + "wool", + "split-rail fence", + "shipwreck", + "sailboat", + "yurt", + "website", + "comic book", + "crossword", + "traffic or street sign", + "traffic light", + "dust jacket", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot", + "trifle", + "ice cream", + "popsicle", + "baguette", + "bagel", + "pretzel", + "cheeseburger", + "hot dog", + "mashed potatoes", + "cabbage", + "broccoli", + "cauliflower", + "zucchini", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber", + "artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith apple", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple", + "banana", + "jackfruit", + "cherimoya (custard apple)", + "pomegranate", + "hay", + "carbonara", + "chocolate syrup", + "dough", + "meatloaf", + "pizza", + "pot pie", + "burrito", + "red wine", + "espresso", + "tea cup", + "eggnog", + "mountain", + "bubble", + "cliff", + "coral reef", + "geyser", + "lakeshore", + "promontory", + "sandbar", + "beach", + "valley", + "volcano", + "baseball player", + "bridegroom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper", + "corn", + "acorn", + "rose hip", + "horse chestnut seed", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn mushroom", + "earth star fungus", + "hen of the woods mushroom", + "bolete", + "corn cob", + "toilet paper", +] + + +openai_imagenet_template = [ + lambda c: f"a bad photo of a {c}.", + lambda c: f"a photo of many {c}.", + lambda c: f"a sculpture of a {c}.", + lambda c: f"a photo of the hard to see {c}.", + lambda c: f"a low resolution photo of the {c}.", + lambda c: f"a rendering of a {c}.", + lambda c: f"graffiti of a {c}.", + lambda c: f"a bad photo of the {c}.", + lambda c: f"a cropped photo of the {c}.", + lambda c: f"a tattoo of a {c}.", + lambda c: f"the embroidered {c}.", + lambda c: f"a photo of a hard to see {c}.", + lambda c: f"a bright photo of a {c}.", + lambda c: f"a photo of a clean {c}.", + lambda c: f"a photo of a dirty {c}.", + lambda c: f"a dark photo of the {c}.", + lambda c: f"a drawing of a {c}.", + lambda c: f"a photo of my {c}.", + lambda c: f"the plastic {c}.", + lambda c: f"a photo of the cool {c}.", + lambda c: f"a close-up photo of a {c}.", + lambda c: f"a black and white photo of the {c}.", + lambda c: f"a painting of the {c}.", + lambda c: f"a painting of a {c}.", + lambda c: f"a pixelated photo of the {c}.", + lambda c: f"a sculpture of the {c}.", + lambda c: f"a bright photo of the {c}.", + lambda c: f"a cropped photo of a {c}.", + lambda c: f"a plastic {c}.", + lambda c: f"a photo of the dirty {c}.", + lambda c: f"a jpeg corrupted photo of a {c}.", + lambda c: f"a blurry photo of the {c}.", + lambda c: f"a photo of the {c}.", + lambda c: f"a good photo of the {c}.", + lambda c: f"a rendering of the {c}.", + lambda c: f"a {c} in a video game.", + lambda c: f"a photo of one {c}.", + lambda c: f"a doodle of a {c}.", + lambda c: f"a close-up photo of the {c}.", + lambda c: f"a photo of a {c}.", + lambda c: f"the origami {c}.", + lambda c: f"the {c} in a video game.", + lambda c: f"a sketch of a {c}.", + lambda c: f"a doodle of the {c}.", + lambda c: f"a origami {c}.", + lambda c: f"a low resolution photo of a {c}.", + lambda c: f"the toy {c}.", + lambda c: f"a rendition of the {c}.", + lambda c: f"a photo of the clean {c}.", + lambda c: f"a photo of a large {c}.", + lambda c: f"a rendition of a {c}.", + lambda c: f"a photo of a nice {c}.", + lambda c: f"a photo of a weird {c}.", + lambda c: f"a blurry photo of a {c}.", + lambda c: f"a cartoon {c}.", + lambda c: f"art of a {c}.", + lambda c: f"a sketch of the {c}.", + lambda c: f"a embroidered {c}.", + lambda c: f"a pixelated photo of a {c}.", + lambda c: f"itap of the {c}.", + lambda c: f"a jpeg corrupted photo of the {c}.", + lambda c: f"a good photo of a {c}.", + lambda c: f"a plushie {c}.", + lambda c: f"a photo of the nice {c}.", + lambda c: f"a photo of the small {c}.", + lambda c: f"a photo of the weird {c}.", + lambda c: f"the cartoon {c}.", + lambda c: f"art of the {c}.", + lambda c: f"a drawing of the {c}.", + lambda c: f"a photo of the large {c}.", + lambda c: f"a black and white photo of a {c}.", + lambda c: f"the plushie {c}.", + lambda c: f"a dark photo of a {c}.", + lambda c: f"itap of a {c}.", + lambda c: f"graffiti of the {c}.", + lambda c: f"a toy {c}.", + lambda c: f"itap of my {c}.", + lambda c: f"a photo of a cool {c}.", + lambda c: f"a photo of a small {c}.", + lambda c: f"a tattoo of the {c}.", +] diff --git a/audioldm/clap/training/infer_demo.py b/audioldm/clap/training/infer_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..7d1f4784898dbfeb69affefb6f624711adc8cb42 --- /dev/null +++ b/audioldm/clap/training/infer_demo.py @@ -0,0 +1,105 @@ +import sys + +import os +import torch +import librosa +from open_clip import create_model +from training.data import get_audio_features +from training.data import int16_to_float32, float32_to_int16 +from transformers import RobertaTokenizer + +tokenize = RobertaTokenizer.from_pretrained("roberta-base") + + +def tokenizer(text): + result = tokenize( + text, + padding="max_length", + truncation=True, + max_length=77, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} + + +PRETRAINED_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/checkpoints/epoch_top_0_audioset_no_fusion.pt" +WAVE_48k_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/audio/machine.wav" + + +def infer_text(): + device = "cuda:0" if torch.cuda.is_available() else "cpu" + precision = "fp32" + amodel = "HTSAT-tiny" # or 'PANN-14' + tmodel = "roberta" # the best text encoder in our training + enable_fusion = False # False if you do not want to use the fusion model + fusion_type = "aff_2d" + pretrained = PRETRAINED_PATH + + model, model_cfg = create_model( + amodel, + tmodel, + pretrained, + precision=precision, + device=device, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + # load the text, can be a list (i.e. batch size) + text_data = ["I love the contrastive learning", "I love the pretrain model"] + # tokenize for roberta, if you want to tokenize for another text encoder, please refer to data.py#L43-90 + text_data = tokenizer(text_data) + + text_embed = model.get_text_embedding(text_data) + print(text_embed.size()) + + +def infer_audio(): + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + precision = "fp32" + amodel = "HTSAT-tiny" # or 'PANN-14' + tmodel = "roberta" # the best text encoder in our training + enable_fusion = False # False if you do not want to use the fusion model + fusion_type = "aff_2d" + pretrained = PRETRAINED_PATH + + model, model_cfg = create_model( + amodel, + tmodel, + pretrained, + precision=precision, + device=device, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + + # load the waveform of the shape (T,), should resample to 48000 + audio_waveform, sr = librosa.load(WAVE_48k_PATH, sr=48000) + # quantize + audio_waveform = int16_to_float32(float32_to_int16(audio_waveform)) + audio_waveform = torch.from_numpy(audio_waveform).float() + audio_dict = {} + + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + import ipdb + + ipdb.set_trace() + audio_dict = get_audio_features( + audio_dict, + audio_waveform, + 480000, + data_truncating="fusion", + data_filling="repeatpad", + audio_cfg=model_cfg["audio_cfg"], + ) + # can send a list to the model, to process many audio tracks in one time (i.e. batch size) + audio_embed = model.get_audio_embedding([audio_dict]) + print(audio_embed.size()) + import ipdb + + ipdb.set_trace() + + +if __name__ == "__main__": + infer_text() + infer_audio() diff --git a/audioldm/clap/training/logger.py b/audioldm/clap/training/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..ac4634970fae6aacde2b7b808355dbd50c90ce73 --- /dev/null +++ b/audioldm/clap/training/logger.py @@ -0,0 +1,30 @@ +import logging + + +def setup_logging(log_file, level, include_host=False): + if include_host: + import socket + + hostname = socket.gethostname() + formatter = logging.Formatter( + f"%(asctime)s | {hostname} | %(levelname)s | %(message)s", + datefmt="%Y-%m-%d,%H:%M:%S", + ) + else: + formatter = logging.Formatter( + "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d,%H:%M:%S" + ) + + logging.root.setLevel(level) + loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] + for logger in loggers: + logger.setLevel(level) + + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(formatter) + logging.root.addHandler(stream_handler) + + if log_file: + file_handler = logging.FileHandler(filename=log_file) + file_handler.setFormatter(formatter) + logging.root.addHandler(file_handler) diff --git a/audioldm/clap/training/lp_main.py b/audioldm/clap/training/lp_main.py new file mode 100644 index 0000000000000000000000000000000000000000..c2d4e8c85aaa3c8e4221963ef56a815cc14f354f --- /dev/null +++ b/audioldm/clap/training/lp_main.py @@ -0,0 +1,670 @@ +from cmath import cos +from inspect import getargs +import logging +import os +import random +from datetime import datetime +import bisect +import copy +from sched import scheduler +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from torch import optim +from torch.cuda.amp import GradScaler +import faulthandler +import pathlib +import argparse +import time + +try: + import wandb +except ImportError: + wandb = None + +try: + import torch.utils.tensorboard as tensorboard +except ImportError: + tensorboard = None + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from open_clip import create_model_and_transforms, trace_model, create_model +from training.data import get_data +from training.params import parse_args +from training.distributed import is_master, init_distributed_device, world_info_from_env +from training.logger import setup_logging +from training.scheduler import cosine_lr +from training.lp_train import train_one_epoch, evaluate +from open_clip.utils import get_tar_path_from_dataset_name, dataset_split, get_optimizer +from open_clip.utils import load_p, load_class_label +from open_clip.linear_probe import LinearProbe + + +def maintain_ckpts(args, startidx, all_idx_len): + for i in reversed(range(startidx, all_idx_len)): + if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")): + os.rename( + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"), + ) + if os.path.exists( + os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt") + ): + os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")) + return + + +def update_top_k_performance( + new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True +): + """ + Record the top-k performance of the current epoch. + current_top_k_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...} + """ + if isinstance(new_metrics_inputs, (list, tuple)): + new_metrics_inputs = np.mean(new_metrics_inputs) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, dict): + new_metrics_inputs = np.mean(list(new_metrics_inputs.values())) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, (float, int)): + update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()} + sorted_keys = sorted(current_top_k_ckpt_metrics.keys()) + sorted_values = sorted( + current_top_k_ckpt_metrics.values(), reverse=bignumbetter + ) + sorted_values_ = copy.deepcopy(sorted_values) + sorted_values.append(new_metrics_inputs) + sorted_values = sorted(sorted_values, reverse=bignumbetter) + sorted_values = sorted_values[:-1] + + if sorted_values == sorted_values_: + return current_top_k_ckpt_metrics, new_metrics_inputs + else: + for i in range(len(sorted_keys)): + if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]: + current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i] + update_flag[sorted_keys[i]] = True + for i in range(len(update_flag)): + if update_flag[i]: + maintain_ckpts(args, i, len(sorted_keys)) + torch.save( + ckpt, + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + ) + break + return current_top_k_ckpt_metrics, new_metrics_inputs + + +# def updateifNone(a, b): +# a = b if None else a +# return a + + +def is_pretrained_params(n): + return ( + n.startswith("clap_model.transformer") + or n in ["clap_model.positional_embedding", "clap_model.text_projection"] + or n.startswith("clap_model.token_embedding") + or n.startswith("clap_model.ln_final") + or n.startswith("clap_model.logit_scale_t") + ) + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) + + +def config_lp_optimizer(model, data, args): + # set wd-related params to 0 if use adam optimizer + if args.optimizer == "adam": + args.wd = 0 + args.wd_pretrained = 0 + args.wd_new = 0 + + in_clap = lambda n, p: n.startswith("clap_model") + + named_parameters = list(model.named_parameters()) + + optimizer = {} + scheduler = {} + + # freeze text encoder + text_freeze_parameters = [ + p + for n, p in named_parameters + if n.startswith("clap_model.transformer") + or n in ["clap_model.positional_embedding", "clap_model.text_projection"] + or n.startswith("clap_model.token_embedding") + or n.startswith("clap_model.ln_final") + ] + + if args.freeze_text: + logging.info("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + + if not args.lp_freeze: + exclude = ( + lambda n, p: p.ndim < 2 + or "bn" in n + or "ln" in n + or "bias" in n + or "logit_scale" in n + ) + include = lambda n, p: not exclude(n, p) + + # (yusong): we do not split the learning rate anymore + # p for n, p in named_parameters if in_clap(n,p) and exclude(n, p) and p.requires_grad + gain_or_bias_params = [ + p for n, p in named_parameters if exclude(n, p) and p.requires_grad + ] + # rest_params = [p for n, p in named_parameters if in_clap(n,p) and include(n, p) and p.requires_grad] + rest_params = [ + p for n, p in named_parameters if include(n, p) and p.requires_grad + ] + + if args.train_data is None: + optimizer = None + scheduler = None + else: + total_steps = data["train"].dataloader.num_batches * args.epochs + + if args.split_opt: + for x in ["lr", "beta1", "beta2", "eps", "wd"]: + for y in ["_new", "_pretrained"]: + if getattr(args, x + y) is None: + setattr(args, x + y, getattr(args, x)) + + gain_or_bias_pretrained_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + rest_pretrained_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + gain_or_bias_new_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) + and (not is_pretrained_params(n)) + ] + rest_new_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) + and (not is_pretrained_params(n)) + ] + + pretrained_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0}, + { + "params": rest_pretrained_params, + "weight_decay": args.wd_pretrained, + }, + ], + lr=args.lr_pretrained, + betas=(args.beta1_pretrained, args.beta2_pretrained), + eps=args.eps_pretrained, + momentum=args.momentum_pretrained, + optimizer_name=args.optimizer, + ) + pretrained_params_scheduler = cosine_lr( + pretrained_params_optimizer, + args.lr_pretrained, + args.warmup, + total_steps, + ) + + new_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_new_params, "weight_decay": 0.0}, + {"params": rest_new_params, "weight_decay": args.wd_new}, + ], + lr=args.lr_new, + betas=(args.beta1_new, args.beta2_new), + eps=args.eps_new, + momentum=args.momentum_new, + optimizer_name=args.optimizer, + ) + new_params_scheduler = cosine_lr( + new_params_optimizer, args.lr_new, args.warmup, total_steps + ) + + optimizer["text"] = pretrained_params_optimizer + optimizer["audio"] = new_params_optimizer + scheduler["text"] = pretrained_params_scheduler + scheduler["audio"] = new_params_scheduler + + if args.horovod: + pretrained_params_optimizer = hvd.DistributedOptimizer( + pretrained_params_optimizer, + named_parameters=model.named_parameters(), + ) + new_params_optimizer = hvd.DistributedOptimizer( + new_params_optimizer, named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state( + pretrained_params_optimizer, root_rank=0 + ) + hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0) + else: + + optimizer["clap"] = get_optimizer( + [ + {"params": gain_or_bias_params, "weight_decay": 0.0}, + {"params": rest_params, "weight_decay": args.wd}, + ], + lr=args.lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + momentum=args.momentum, + optimizer_name=args.optimizer, + ) + scheduler["clap"] = cosine_lr( + optimizer["clap"], args.lr, args.warmup, total_steps + ) + + if args.horovod: + optimizer["clap"] = hvd.DistributedOptimizer( + optimizer["clap"], named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer["clap"], root_rank=0) + + # linear probe optimizer + else: + lp_params = [ + p for n, p in named_parameters if (not in_clap(n, p)) and p.requires_grad + ] + lp_optim = get_optimizer( + lp_params, + lr=args.lp_lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + momentum=0.9, + optimizer_name=args.optimizer, + ) + optimizer["lp"] = lp_optim + + return optimizer, scheduler, text_freeze_parameters + + +def main(): + args = parse_args() + + time.sleep(args.sleep) + + # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? + args.amodel = args.amodel.replace("/", "-") + # download sizes.json file + + # (yusong): the below two lines are for debug + # print("setting up faulthandler") + # faulthandler.register(10) + + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + args.class_index_dict = load_class_label(args.class_label_path) + + # get the name of the experiments + if args.name is None: + args.name = "-".join( + [ + datetime.now().strftime("%Y_%m_%d-%H_%M_%S"), + f"linear_probe" f"model_{args.amodel}", + f"lr_{args.lr}", + f"b_{args.batch_size}", + f"j_{args.workers}", + f"p_{args.precision}", + ] + ) + + # discover initial world args early so we can log properly + args.distributed = False + args.local_rank, args.rank, args.world_size = world_info_from_env() + + if args.remotedata and is_master(args): + for dataset_name in args.datasetnames: + for split in dataset_split[dataset_name]: + if not os.path.exists(f"./json_files/{dataset_name}/{split}"): + os.makedirs(f"./json_files/{dataset_name}/{split}") + os.system( + f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" + ) + + args.log_path = None + if is_master(args, local=args.log_local): + log_base_path = os.path.join(args.logs, args.name) + os.makedirs(log_base_path, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path, log_filename) + + # avoid log dir in same name: + postfix = 0 + while os.path.exists(args.log_path): + postfix += 1 + log_base_path_new = log_base_path + "-" + str(postfix) + os.makedirs(log_base_path_new, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path_new, log_filename) + # print( + # "Error. Experiment already exists. Use --name {} to specify a new experiment." + # ) + # return -1 + + # Set logger + args.log_level = logging.DEBUG if args.debug else logging.INFO + setup_logging(args.log_path, args.log_level) + + # fully initialize distributed device environment + device = init_distributed_device(args) + + args.wandb = "wandb" in args.report_to or "all" in args.report_to + args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to + if is_master(args): + args.tensorboard_path = ( + os.path.join(args.logs, args.name, "tensorboard") + if args.tensorboard + else "" + ) + args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") + for dirname in [args.tensorboard_path, args.checkpoint_path]: + if dirname: + os.makedirs(dirname, exist_ok=True) + else: + args.tensorboard_path = "" + args.checkpoint_path = "" + + if args.copy_codebase: + copy_codebase(args) + + assert args.precision in ["amp", "fp16", "fp32"] + if args.precision == "fp16": + logging.warning( + "It is recommended to use AMP mixed-precision instead of FP16. " + "FP16 support needs further verification and tuning, especially for train." + ) + + if args.horovod: + logging.info( + f"Running in horovod mode with multiple processes / nodes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + elif args.distributed: + logging.info( + f"Running in distributed mode with multiple processes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + else: + logging.info(f"Running with a single process. Device {args.device}.") + + logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}") + + # Create CLAP model + clap_model, clap_model_cfg = create_model( + args.amodel, + args.tmodel, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=False, + pretrained_audio=args.pretrained_audio, + pretrained_text=args.pretrained_text, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type, + ) + + args.lp_out_ch = len(list(args.class_index_dict.keys())) + # Linear Probe + logging.info(f"linear probe using mlp: {args.lp_mlp}") + logging.info(f"linear probe using freeze: {args.lp_freeze}") + logging.info(f"linear probe act layer: {args.lp_act}") + logging.info(f"linear probe out ch: {args.lp_out_ch}") + logging.info(f"linear probe learning rate (if applicable): {args.lp_lr}") + logging.info(f"linear probe loss func: {args.lp_loss}") + logging.info(f"linear probe lp_metrics: {args.lp_metrics}") + + model = LinearProbe( + clap_model, + mlp=args.lp_mlp, + freeze=args.lp_freeze, + in_ch=512, + out_ch=args.lp_out_ch, + act=args.lp_act, + ) # in_ch is fixed (i.e., 512) + model = model.to(device) + + if args.horovod: + with torch.no_grad(): + for param in model.parameters(): + param.set_(param.contiguous()) + + if args.trace: + model = trace_model(model, batch_size=args.batch_size, device=device) + + if is_master(args): + logging.info("Linear Probe CLAP Model:") + logging.info(f"{str(clap_model)}") + logging.info("Params:") + params_file = os.path.join(args.logs, args.name, "params.txt") + with open(params_file, "w") as f: + for name in sorted(vars(args)): + val = getattr(args, name) + logging.info(f" {name}: {val}") + f.write(f"{name}: {val}\n") + + if args.distributed and not args.horovod: + if args.use_bn_sync: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + ddp_args = {} + if args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args["static_graph"] = True + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[device], find_unused_parameters=True, **ddp_args + ) + + data = get_data(args, clap_model_cfg) + assert len(data), "At least one train or eval dataset must be specified." + if args.trace: + assert "train" not in data, "Cannot train with traced model" + + optimizer, scheduler, text_freeze_parameters = config_lp_optimizer( + model, data, args + ) + + scaler = GradScaler() if args.precision == "amp" else None + + # optionally resume from a checkpoint + start_epoch = 0 + if args.resume is not None: + if os.path.isfile(args.resume): + checkpoint = torch.load(args.resume, map_location=device) + if "epoch" in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if not args.distributed and next(iter(sd.items()))[0].startswith( + "module" + ): + sd = {k[len("module.") :]: v for k, v in sd.items()} + model.load_state_dict(sd) + if args.split_opt: + if optimizer is not None: + for k, o_ in optimizer.items(): + o_.load_state_dict(checkpoint[k + "_" + "optimizer"]) + if optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + if scaler is not None and "scaler" in checkpoint: + scaler.load_state_dict(checkpoint["scaler"]) + logging.info( + f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})" + ) + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + logging.info( + f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})" + ) + if args.freeze_text: + print("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + else: + logging.info("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + cudnn.deterministic = False + + # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 + args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args) + writer = None + if args.save_logs and args.tensorboard: + assert tensorboard is not None, "Please install tensorboard." + writer = tensorboard.SummaryWriter(args.tensorboard_path) + + if args.wandb and is_master(args): + assert wandb is not None, "Please install wandb." + logging.debug("Starting wandb.") + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + # you will have to configure this for your project! + wandb.init( + project="clap", + notes=args.wandb_notes, + name=args.wandb_notes, + tags=[], + config=vars(args), + ) + if args.debug: + wandb.watch(model, log="all") + wandb.save(params_file) + logging.debug("Finished loading wandb.") + + if "train" not in data: + evaluate(model, data, start_epoch, args, writer) + return + elif start_epoch == 0 and "val" in data and not args.no_eval: + evaluate(model, data, 0, args, writer) + if args.save_top_performance: + current_top_k_ckpt_metrics = { + i: 0 for i in range(args.save_top_performance) + } # initialize the top-k metric for ckpts to 0 + + for epoch in range(start_epoch, args.epochs): + # freeze the text param after (include) args.freeze_text_after, this is -1 by default + if epoch == args.freeze_text_after: + print("Text pretrained parameters are freezed since this epoch.") + for k in text_freeze_parameters: + k.requires_grad = False + if is_master(args): + logging.info(f"Start epoch {epoch}") + + train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) + completed_epoch = epoch + 1 + + if ( + any(v in data for v in ("val", "imagenet-val", "imagenet-v2")) + and not args.no_eval + ): + metrics = evaluate(model, data, completed_epoch, args, writer) + if args.save_top_performance: + top_k_dataset = args.top_k_checkpoint_select_dataset + top_k_metric = args.top_k_checkpoint_select_metric + filtered_metrics = [ + v + for k, v in metrics.items() + if top_k_metric in k and top_k_dataset in k + ] # check all R@10 metrics (all dataset) and use it to update the ckpt + # Saving checkpoints. + if args.save_logs: + opt_dict = { + k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items() + } + checkpoint_dict = { + "epoch": completed_epoch, + "name": args.name, + "state_dict": model.state_dict(), + } + checkpoint_dict.update(opt_dict) + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), + ) + if args.save_most_recent: + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_latest.pt"), + ) + if args.save_top_performance and not args.no_eval: + update_top_k_performance( + filtered_metrics, + current_top_k_ckpt_metrics, + args, + checkpoint_dict, + bignumbetter=True, + ) + + if args.wandb and is_master(args): + wandb.finish() + + +def copy_codebase(args): + from shutil import copytree, ignore_patterns + + new_code_path = os.path.join(args.logs, args.name, "code") + if os.path.exists(new_code_path): + print( + f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." + ) + return -1 + print(f"Copying codebase to {new_code_path}") + current_code_path = os.path.realpath(__file__) + for _ in range(3): + current_code_path = os.path.dirname(current_code_path) + copytree( + current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb") + ) + print("Done copying code.") + return 1 + + +if __name__ == "__main__": + main() diff --git a/audioldm/clap/training/lp_train.py b/audioldm/clap/training/lp_train.py new file mode 100644 index 0000000000000000000000000000000000000000..24a19bacd0a4b789415cfccbce1f8bc99bc493ed --- /dev/null +++ b/audioldm/clap/training/lp_train.py @@ -0,0 +1,301 @@ +import json +import logging +import math +import os +import time +from contextlib import suppress + +import numpy as np +import torch +import torch.nn.functional as F + +try: + import wandb +except ImportError: + wandb = None + +from open_clip import LPLoss, LPMetrics, lp_gather_features +from open_clip.utils import do_mixup, get_mix_lambda +from .distributed import is_master +from .zero_shot import zero_shot_eval + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def unwrap_model(model): + if hasattr(model, "module"): + return model.module + else: + return model + + +def train_one_epoch( + model, + data, + epoch, + optimizer, + scaler, + scheduler, + args, + tb_writer=None, + extra_suffix="", +): + device = torch.device(args.device) + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + model.train() + loss = LPLoss(args.lp_loss) + + dataloader, sampler = data["train"].dataloader, data["train"].sampler + if args.distributed and sampler is not None: + sampler.set_epoch(epoch) + num_batches_per_epoch = dataloader.num_batches + sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) + + # for toy dataset + if args.dataset_type == "toy": + dataloader.dataset.generate_queue() + + loss_m = AverageMeter() + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for i, batch in enumerate(dataloader): + step = num_batches_per_epoch * epoch + i + + if isinstance(scheduler, dict): + for s in scheduler.values(): + s(step) + else: + scheduler(step) + + audio = batch # contains mel_spec, wavform, and longer list + class_label = batch["class_label"] + # audio = audio.to(device=device, non_blocking=True) + class_label = class_label.to(device=device, non_blocking=True) + + if args.mixup: + # https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/utils.py#L146 + mix_lambda = torch.from_numpy( + get_mix_lambda(0.5, len(audio["waveform"])) + ).to(device) + class_label = do_mixup(class_label, mix_lambda) + else: + mix_lambda = None + + data_time_m.update(time.time() - end) + if isinstance(optimizer, dict): + for o_ in optimizer.values(): + o_.zero_grad() + else: + optimizer.zero_grad() + + with autocast(): + pred = model(audio, mix_lambda=mix_lambda, device=device) + total_loss = loss(pred, class_label) + + if isinstance(optimizer, dict): + if scaler is not None: + scaler.scale(total_loss).backward() + for o_ in optimizer.values(): + if args.horovod: + o_.synchronize() + scaler.unscale_(o_) + with o_.skip_synchronize(): + scaler.step(o_) + else: + scaler.step(o_) + scaler.update() + else: + total_loss.backward() + for o_ in optimizer.values(): + o_.step() + else: + if scaler is not None: + scaler.scale(total_loss).backward() + if args.horovod: + optimizer.synchronize() + scaler.unscale_(optimizer) + with optimizer.skip_synchronize(): + scaler.step(optimizer) + else: + scaler.step(optimizer) + scaler.update() + else: + total_loss.backward() + optimizer.step() + + # Note: we clamp to 4.6052 = ln(100), as in the original paper. + with torch.no_grad(): + unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100)) + unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100)) + + batch_time_m.update(time.time() - end) + end = time.time() + batch_count = i + 1 + + if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): + if isinstance(audio, dict): + batch_size = len(audio["waveform"]) + else: + batch_size = len(audio) + num_samples = batch_count * batch_size * args.world_size + samples_per_epoch = dataloader.num_samples + percent_complete = 100.0 * batch_count / num_batches_per_epoch + + # NOTE loss is coarsely sampled, just master node and per log update + loss_m.update(total_loss.item(), batch_size) + if isinstance(optimizer, dict): + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]}" + ) + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], + } + else: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + ) + + # Save train loss / etc. Using non avg meter values as loggers have their own smoothing + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "lr": optimizer.param_groups[0]["lr"], + } + for name, val in log_data.items(): + name = f"train{extra_suffix}/{name}" + if tb_writer is not None: + tb_writer.add_scalar(name, val, step) + if args.wandb: + assert wandb is not None, "Please install wandb." + wandb.log({name: val, "step": step}) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # end for + + +def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""): + metrics = {} + if not args.parallel_eval: + if not is_master(args): + return metrics + device = torch.device(args.device) + model.eval() + + # CHANGE + # zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + # metrics.update(zero_shot_metrics) + if is_master(args): + print("Evaluating...") + metric_names = args.lp_metrics.split(",") + eval_tool = LPMetrics(metric_names=metric_names) + + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + if "val" in data and ( + args.val_frequency + and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) + ): + if args.parallel_eval: + dataloader, sampler = data["val"].dataloader, data["val"].sampler + if args.distributed and sampler is not None: + sampler.set_epoch(epoch) + samples_per_val = dataloader.num_samples + else: + dataloader = data["val"].dataloader + num_samples = 0 + samples_per_val = dataloader.num_samples + + eval_info = {"pred": [], "target": []} + with torch.no_grad(): + for i, batch in enumerate(dataloader): + audio = batch # contains mel_spec, wavform, and longer list + class_label = batch["class_label"] + + # audio = audio.to(device=device, non_blocking=True) + class_label = class_label.to(device=device, non_blocking=True) + + with autocast(): + pred = model(audio, device=device) + if args.parallel_eval: + pred, class_label = lp_gather_features( + pred, class_label, args.world_size, args.horovod + ) + eval_info["pred"].append(pred) + eval_info["target"].append(class_label) + + num_samples += class_label.shape[0] + + if (i % 100) == 0: # and i != 0: + logging.info( + f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]" + ) + + if is_master(args): + eval_info["pred"] = torch.cat(eval_info["pred"], 0).cpu() + eval_info["target"] = torch.cat(eval_info["target"], 0).cpu() + metric_dict = eval_tool.evaluate_mertics( + eval_info["pred"], eval_info["target"] + ) + metrics.update(metric_dict) + if "epoch" not in metrics.keys(): + metrics.update({"epoch": epoch}) + + if is_master(args): + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\n".join( + ["\t".join([f"{m}: {round(metrics[m], 4):.4f}"]) for m in metrics] + ) + ) + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch) + + with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + assert wandb is not None, "Please install wandb." + for name, val in metrics.items(): + wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch}) + + return metrics + else: + return metrics diff --git a/audioldm/clap/training/main.py b/audioldm/clap/training/main.py new file mode 100644 index 0000000000000000000000000000000000000000..3b563a5d001be7adfbe779dee7ad8ac49aadc50d --- /dev/null +++ b/audioldm/clap/training/main.py @@ -0,0 +1,596 @@ +from inspect import getargs +import logging +import os +import random +from datetime import datetime +import bisect +import copy +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from torch import optim +from torch.cuda.amp import GradScaler +import faulthandler +import pathlib + +try: + import wandb +except ImportError: + wandb = None + +try: + import torch.utils.tensorboard as tensorboard +except ImportError: + tensorboard = None + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from open_clip import create_model_and_transforms, trace_model, create_model +from training.data import get_data +from training.distributed import is_master, init_distributed_device, world_info_from_env +from training.logger import setup_logging +from training.params import parse_args +from training.scheduler import cosine_lr +from training.train import train_one_epoch, evaluate +from open_clip.utils import dataset_split, get_optimizer + + +def maintain_ckpts(args, startidx, all_idx_len): + for i in reversed(range(startidx, all_idx_len)): + if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")): + os.rename( + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"), + ) + if os.path.exists( + os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt") + ): + os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")) + return + + +def update_top_k_performance( + new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True +): + """ + Record the top-k performance of the current epoch. + current_top_k_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...} + """ + if isinstance(new_metrics_inputs, (list, tuple)): + new_metrics_inputs = np.mean(new_metrics_inputs) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, dict): + new_metrics_inputs = np.mean(list(new_metrics_inputs.values())) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, (float, int)): + update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()} + sorted_keys = sorted(current_top_k_ckpt_metrics.keys()) + sorted_values = sorted( + current_top_k_ckpt_metrics.values(), reverse=bignumbetter + ) + sorted_values_ = copy.deepcopy(sorted_values) + sorted_values.append(new_metrics_inputs) + sorted_values = sorted(sorted_values, reverse=bignumbetter) + sorted_values = sorted_values[:-1] + + if sorted_values == sorted_values_: + return current_top_k_ckpt_metrics, new_metrics_inputs + else: + for i in range(len(sorted_keys)): + if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]: + current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i] + update_flag[sorted_keys[i]] = True + for i in range(len(update_flag)): + if update_flag[i]: + maintain_ckpts(args, i, len(sorted_keys)) + torch.save( + ckpt, + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + ) + break + return current_top_k_ckpt_metrics, new_metrics_inputs + + +# def updateifNone(a, b): +# a = b if None else a +# return a + + +def is_pretrained_params(n): + return ( + n.startswith("transformer") + or n in ["positional_embedding", "text_projection"] + or n.startswith("token_embedding") + or n.startswith("ln_final") + or n.startswith("logit_scale_t") + ) + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) + + +def main(): + args = parse_args() + # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? + args.amodel = args.amodel.replace("/", "-") + # download sizes.json file + + # (yusong): the below two lines are for debug + # print("setting up faulthandler") + # faulthandler.register(10) + + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + if args.tmodel == "bert" or args.tmodel == "roberta" or args.tmodel == "bart": + assert ( + args.pretrained == "" or args.pretrained is None + ), "bert/roberta/bart text encoder does not support pretrained models." + + # get the name of the experiments + if args.name is None: + args.name = "-".join( + [ + datetime.now().strftime("%Y_%m_%d-%H_%M_%S"), + f"model_{args.amodel}", + f"lr_{args.lr}", + f"b_{args.batch_size}", + f"j_{args.workers}", + f"p_{args.precision}", + ] + ) + + # discover initial world args early so we can log properly + args.distributed = False + args.local_rank, args.rank, args.world_size = world_info_from_env() + + if args.remotedata and is_master(args): + for dataset_name in args.datasetnames: + for split in dataset_split[dataset_name]: + if not os.path.exists(f"./json_files/{dataset_name}/{split}"): + os.makedirs(f"./json_files/{dataset_name}/{split}") + os.system( + f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" + ) + + args.log_path = None + if is_master(args, local=args.log_local): + log_base_path = os.path.join(args.logs, args.name) + os.makedirs(log_base_path, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path, log_filename) + if os.path.exists(args.log_path): + print( + "Error. Experiment already exists. Use --name {} to specify a new experiment." + ) + return -1 + + # Set logger + args.log_level = logging.DEBUG if args.debug else logging.INFO + setup_logging(args.log_path, args.log_level) + + # fully initialize distributed device environment + device = init_distributed_device(args) + + args.wandb = "wandb" in args.report_to or "all" in args.report_to + args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to + if is_master(args): + args.tensorboard_path = ( + os.path.join(args.logs, args.name, "tensorboard") + if args.tensorboard + else "" + ) + args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") + for dirname in [args.tensorboard_path, args.checkpoint_path]: + if dirname: + os.makedirs(dirname, exist_ok=True) + else: + args.tensorboard_path = "" + args.checkpoint_path = "" + + if args.copy_codebase: + copy_codebase(args) + + assert args.precision in ["amp", "fp16", "fp32"] + if args.precision == "fp16": + logging.warning( + "It is recommended to use AMP mixed-precision instead of FP16. " + "FP16 support needs further verification and tuning, especially for train." + ) + + if args.horovod: + logging.info( + f"Running in horovod mode with multiple processes / nodes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + elif args.distributed: + logging.info( + f"Running in distributed mode with multiple processes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + else: + logging.info(f"Running with a single process. Device {args.device}.") + + logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}") + + model, model_cfg = create_model( + args.amodel, + args.tmodel, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=True, + pretrained_audio=args.pretrained_audio, + pretrained_text=args.pretrained_text, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type, + ) + + if args.horovod: + with torch.no_grad(): + for param in model.parameters(): + param.set_(param.contiguous()) + + if args.trace: + model = trace_model(model, batch_size=args.batch_size, device=device) + + if is_master(args): + logging.info("Model:") + logging.info(f"{str(model)}") + logging.info("Params:") + params_file = os.path.join(args.logs, args.name, "params.txt") + with open(params_file, "w") as f: + for name in sorted(vars(args)): + val = getattr(args, name) + logging.info(f" {name}: {val}") + f.write(f"{name}: {val}\n") + + if args.distributed and not args.horovod: + if args.use_bn_sync: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + ddp_args = {} + if args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args["static_graph"] = True + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[device], find_unused_parameters=True, **ddp_args + ) + + data = get_data(args, model_cfg) + assert len(data), "At least one train or eval dataset must be specified." + if args.trace: + assert "train" not in data, "Cannot train with traced model" + + exclude = ( + lambda n, p: p.ndim < 2 + or "bn" in n + or "ln" in n + or "bias" in n + or "logit_scale" in n + ) + include = lambda n, p: not exclude(n, p) + + named_parameters = list(model.named_parameters()) + + # freeze text encoder + text_freeze_parameters = [p for n, p in named_parameters if "text_branch" in n] + + if args.freeze_text: + print("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + + gain_or_bias_params = [ + p for n, p in named_parameters if exclude(n, p) and p.requires_grad + ] + rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] + + # set wd-related params to 0 if use adam optimizer + if args.optimizer == "adam": + args.wd = 0 + args.wd_pretrained = 0 + args.wd_new = 0 + + if args.train_data is None: + optimizer = None + scheduler = None + else: + total_steps = data["train"].dataloader.num_batches * args.epochs + + if args.split_opt: + for x in ["lr", "beta1", "beta2", "eps", "wd"]: + for y in ["_new", "_pretrained"]: + if getattr(args, x + y) is None: + setattr(args, x + y, getattr(args, x)) + + gain_or_bias_pretrained_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + rest_pretrained_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + gain_or_bias_new_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) and (not is_pretrained_params(n)) + ] + rest_new_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) and (not is_pretrained_params(n)) + ] + pretrained_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0}, + { + "params": rest_pretrained_params, + "weight_decay": args.wd_pretrained, + }, + ], + lr=args.lr_pretrained, + betas=(args.beta1_pretrained, args.beta2_pretrained), + eps=args.eps_pretrained, + momentum=args.momentum_pretrained, + optimizer_name=args.optimizer, + ) + pretrained_params_scheduler = cosine_lr( + pretrained_params_optimizer, + args.lr_pretrained, + args.warmup, + total_steps, + ) + new_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_new_params, "weight_decay": 0.0}, + {"params": rest_new_params, "weight_decay": args.wd_new}, + ], + lr=args.lr_new, + betas=(args.beta1_new, args.beta2_new), + eps=args.eps_new, + momentum=args.momentum_new, + optimizer_name=args.optimizer, + ) + + new_params_scheduler = cosine_lr( + new_params_optimizer, args.lr_new, args.warmup, total_steps + ) + + optimizer = { + "pretrained": pretrained_params_optimizer, + "new": new_params_optimizer, + } + scheduler = { + "pretrained": pretrained_params_scheduler, + "new": new_params_scheduler, + } + + if args.horovod: + pretrained_params_optimizer = hvd.DistributedOptimizer( + pretrained_params_optimizer, + named_parameters=model.named_parameters(), + ) + new_params_optimizer = hvd.DistributedOptimizer( + new_params_optimizer, named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(pretrained_params_optimizer, root_rank=0) + hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0) + else: + optimizer = get_optimizer( + [ + {"params": gain_or_bias_params, "weight_decay": 0.0}, + {"params": rest_params, "weight_decay": args.wd}, + ], + lr=args.lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + momentum=args.momentum, + optimizer_name=args.optimizer, + ) + + scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps) + + if args.horovod: + optimizer = hvd.DistributedOptimizer( + optimizer, named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer, root_rank=0) + + scaler = GradScaler() if args.precision == "amp" else None + + # optionally resume from a checkpoint + start_epoch = 0 + if args.resume is not None: + if os.path.isfile(args.resume): + checkpoint = torch.load(args.resume, map_location=device) + if "epoch" in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if not args.distributed and next(iter(sd.items()))[0].startswith( + "module" + ): + sd = {k[len("module.") :]: v for k, v in sd.items()} + model.load_state_dict(sd) + if args.split_opt: + if optimizer is not None: + for k, o_ in optimizer.items(): + o_.load_state_dict(checkpoint[k + "_" + "optimizer"]) + if optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + if scaler is not None and "scaler" in checkpoint: + scaler.load_state_dict(checkpoint["scaler"]) + logging.info( + f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})" + ) + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + logging.info( + f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})" + ) + if args.freeze_text: + print("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + else: + logging.info("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + cudnn.deterministic = False + + # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 + args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args) + writer = None + if args.save_logs and args.tensorboard: + assert tensorboard is not None, "Please install tensorboard." + writer = tensorboard.SummaryWriter(args.tensorboard_path) + + if args.wandb and is_master(args): + assert wandb is not None, "Please install wandb." + logging.debug("Starting wandb.") + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + # you will have to configure this for your project! + wandb.init( + project="clap", + notes=args.wandb_notes, + name=args.wandb_notes, + tags=[], + config=vars(args), + ) + if args.debug: + wandb.watch(model, log="all") + wandb.save(params_file) + logging.debug("Finished loading wandb.") + + if "train" not in data: + evaluate(model, data, start_epoch, args, writer) + return + elif start_epoch == 0 and "val" in data and not args.no_eval: + evaluate(model, data, 0, args, writer) + # print(f'rank {args.rank}, Start First Evaluation')# (yusong): for debug + if args.save_top_performance: + current_top_k_ckpt_metrics = { + i: 0 for i in range(args.save_top_performance) + } # initialize the top-k metric for ckpts to 0 + + # print(f'rank {args.rank}, Start Training') # (yusong): for debug + for epoch in range(start_epoch, args.epochs): + # freeze the text param after (include) args.freeze_text_after, this is -1 by default + if epoch == args.freeze_text_after: + print("Text pretrained parameters are freezed since this epoch.") + for k in text_freeze_parameters: + k.requires_grad = False + if is_master(args): + logging.info(f"Start epoch {epoch}") + + train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) + completed_epoch = epoch + 1 + + if ( + any(v in data for v in ("val", "imagenet-val", "imagenet-v2")) + and not args.no_eval + ): + metrics = evaluate(model, data, completed_epoch, args, writer) + if args.save_top_performance: + top_k_dataset = args.top_k_checkpoint_select_dataset + top_k_metric = args.top_k_checkpoint_select_metric + filtered_metrics = [ + v + for k, v in metrics.items() + if top_k_metric in k and top_k_dataset in k + ] # check all R@10 metrics (all dataset) and use it to update the ckpt + # Saving checkpoints. + if args.save_logs: + if args.split_opt: + opt_dict = { + k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items() + } + else: + opt_dict = {"optimizer": optimizer.state_dict()} + checkpoint_dict = { + "epoch": completed_epoch, + "name": args.name, + "state_dict": model.state_dict(), + } + checkpoint_dict.update(opt_dict) + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), + ) + if args.save_most_recent: + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_latest.pt"), + ) + if args.save_top_performance and not args.no_eval: + update_top_k_performance( + filtered_metrics, + current_top_k_ckpt_metrics, + args, + checkpoint_dict, + bignumbetter=True, + ) + + if args.wandb and is_master(args): + wandb.finish() + + +def copy_codebase(args): + from shutil import copytree, ignore_patterns + + new_code_path = os.path.join(args.logs, args.name, "code") + if os.path.exists(new_code_path): + print( + f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." + ) + return -1 + print(f"Copying codebase to {new_code_path}") + current_code_path = os.path.realpath(__file__) + for _ in range(3): + current_code_path = os.path.dirname(current_code_path) + copytree( + current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb") + ) + print("Done copying code.") + return 1 + + +if __name__ == "__main__": + main() diff --git a/audioldm/clap/training/params.py b/audioldm/clap/training/params.py new file mode 100644 index 0000000000000000000000000000000000000000..0cc1a0e2d982e900988cf5a4b24b2e59b093537b --- /dev/null +++ b/audioldm/clap/training/params.py @@ -0,0 +1,563 @@ +import argparse + + +def get_default_params(model_name): + # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) + model_name = model_name.lower() + if "vit" in model_name: + return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} + else: + return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--train-data", + type=str, + default=None, + help="Path to h5 filewith training data", + ) + parser.add_argument( + "--val-data", + type=str, + default=None, + help="Path to h5 file with validation data", + ) + parser.add_argument( + "--freeze-text", + default=False, + action="store_true", + help="if you need to freeze the text encoder, make this True", + ) + parser.add_argument( + "--freeze-text-after", + type=int, + default=-1, + help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it", + ) + parser.add_argument( + "--train-ipc", + type=str, + default=None, + help="Path to npy file of the number of instance per class in training data", + ) + parser.add_argument( + "--val-ipc", + type=str, + default=None, + help="Path to npy file of the number of instance per class in validation data", + ) + parser.add_argument( + "--train-num-samples", + type=int, + default=None, + help="Number of samples in dataset. Required for webdataset if not available in info file.", + ) + parser.add_argument( + "--val-num-samples", + type=int, + default=None, + help="Number of samples in dataset. Useful for webdataset if not available in info file.", + ) + parser.add_argument( + "--dataset-type", + choices=["webdataset", "csv", "auto", "toy"], + default="auto", + help="Which type of dataset to process.", + ) + parser.add_argument( + "--csv-separator", + type=str, + default="\t", + help="For csv-like datasets, which separator to use.", + ) + parser.add_argument( + "--csv-img-key", + type=str, + default="filepath", + help="For csv-like datasets, the name of the key for the image paths.", + ) + parser.add_argument( + "--csv-caption-key", + type=str, + default="title", + help="For csv-like datasets, the name of the key for the captions.", + ) + parser.add_argument( + "--imagenet-val", + type=str, + default=None, + help="Path to imagenet val set for conducting zero shot evaluation.", + ) + parser.add_argument( + "--imagenet-v2", + type=str, + default=None, + help="Path to imagenet v2 for conducting zero shot evaluation.", + ) + parser.add_argument( + "--datasetnames", + nargs="+", + default=None, + help="If loading webdataset, spedify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects", + ) + parser.add_argument( + "--full-train-dataset", + nargs="+", + default=None, + help="Which dataset will be trained with all the subsets. (train+test)", + ) + parser.add_argument( + "--exclude-eval-dataset", + nargs="+", + default=None, + help="Which dataset will be excluded with evaluation", + ) + parser.add_argument( + "--datasetinfos", + nargs="+", + default=None, + help="If loading webdataset, spedify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval", + ) + parser.add_argument( + "--dataset-proportion", + type=float, + default=1.0, + help="How much proportion of dataset we want to train.", + ) + parser.add_argument( + "--remotedata", + default=False, + action="store_true", + help="if the dataset is remote, set this flag", + ) + parser.add_argument( + "--class-label-path", + type=str, + default=None, + help="The path of the class label pickle or csv.", + ) + parser.add_argument( + "--datasetpath", + type=str, + default="/mnt/audio_clip/webdataset_tar", + help="The path to the dataset", + ) + parser.add_argument( + "--logs", + type=str, + default="./logs/", + help="Where to store tensorboard logs. Use None to avoid storing logs.", + ) + parser.add_argument( + "--log-local", + action="store_true", + default=False, + help="log files on local master, otherwise global master only.", + ) + parser.add_argument( + "--name", + type=str, + default=None, + help="Optional identifier for the experiment when storing logs. Otherwise use current time.", + ) + parser.add_argument( + "--workers", type=int, default=1, help="Number of workers per GPU." + ) + parser.add_argument( + "--batch-size", type=int, default=64, help="Batch size per GPU." + ) + parser.add_argument( + "--epochs", type=int, default=32, help="Number of epochs to train for." + ) + parser.add_argument("--lr", type=float, default=None, help="Learning rate.") + parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.") + parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") + parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") + parser.add_argument("--momentum", type=float, default=None, help="SGD epsilon.") + parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") + + parser.add_argument( + "--split-opt", + action="store_true", + default=False, + help="Use this flag to skip the learning rate decay.", + ) + parser.add_argument( + "--lr-pretrained", type=float, default=None, help="Learning rate for text." + ) + parser.add_argument( + "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text." + ) + parser.add_argument( + "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text." + ) + parser.add_argument( + "--eps-pretrained", type=float, default=None, help="Adam epsilon for text." + ) + parser.add_argument( + "--wd-pretrained", type=float, default=0.2, help="Weight decay for text." + ) + parser.add_argument( + "--momentum-pretrained", type=float, default=0.9, help="Momentum for text." + ) + parser.add_argument( + "--lr-new", type=float, default=None, help="Learning rate for audio." + ) + parser.add_argument( + "--beta1-new", type=float, default=None, help="Adam beta 1 for audio." + ) + parser.add_argument( + "--beta2-new", type=float, default=None, help="Adam beta 2 for audio." + ) + parser.add_argument( + "--eps-new", type=float, default=None, help="Adam epsilon for audio." + ) + parser.add_argument( + "--wd-new", type=float, default=0.2, help="Weight decay for audio." + ) + parser.add_argument( + "--momentum-new", type=float, default=0.9, help="Momentum for audio." + ) + parser.add_argument( + "--warmup", type=int, default=10000, help="Number of steps to warmup for." + ) + parser.add_argument( + "--use-bn-sync", + default=False, + action="store_true", + help="Whether to use batch norm sync.", + ) + parser.add_argument( + "--skip-scheduler", + action="store_true", + default=False, + help="Use this flag to skip the learning rate decay.", + ) + parser.add_argument( + "--save-frequency", type=int, default=1, help="How often to save checkpoints." + ) + parser.add_argument( + "--save-top-performance", + type=int, + default=0, + help="Save the top x performance weights if the value >0", + ) + parser.add_argument( + "--save-most-recent", + action="store_true", + default=False, + help="Always save the most recent model trained to epoch_latest.pt.", + ) + parser.add_argument( + "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot." + ) + parser.add_argument( + "--val-frequency", + type=int, + default=1, + help="How often to run evaluation with val data.", + ) + parser.add_argument( + "--resume", + default=None, + type=str, + help="path to latest checkpoint (default: none)", + ) + parser.add_argument( + "--precision", + choices=["amp", "fp16", "fp32"], + default="amp", + help="Floating point precision.", + ) + parser.add_argument( + "--amodel", + type=str, + default="RN50", + help="Name of the audio backbone to use.", + ) + parser.add_argument( + "--tmodel", + type=str, + default="transformer", + help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]", + ) + parser.add_argument( + "--pretrained-audio", + default="", + type=str, + help="Use a pretrained audio model weights for the audio encoder of CLAP", + ) + parser.add_argument( + "--pretrained-text", + default="", + type=str, + help="Use a pretrained text model weights for the text encoder of CLAP", + ) + parser.add_argument( + "--pretrained", + default="", + type=str, + help="Use a pretrained CLIP model weights with the specified tag or file path.", + ) + parser.add_argument( + "--pretrained-image", + default=False, + action="store_true", + help="Load imagenet pretrained weights for image tower backbone if available.", + ) + parser.add_argument( + "--lock-image", + default=False, + action="store_true", + help="Lock full image tower by disabling gradients.", + ) + parser.add_argument( + "--lock-image-unlocked-groups", + type=int, + default=0, + help="Leave last n image tower layer groups unlocked.", + ) + parser.add_argument( + "--lock-image-freeze-bn-stats", + default=False, + action="store_true", + help="Freeze BatchNorm running stats in image tower for any locked layers.", + ) + parser.add_argument( + "--local-loss", + default=False, + action="store_true", + help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)", + ) + parser.add_argument( + "--gather-with-grad", + default=False, + action="store_true", + help="enable full distributed gradient for feature gather", + ) + parser.add_argument( + "--force-quick-gelu", + default=False, + action="store_true", + help="Force use of QuickGELU activation for non-OpenAI transformer models.", + ) + parser.add_argument( + "--torchscript", + default=False, + action="store_true", + help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'", + ) + parser.add_argument( + "--trace", + default=False, + action="store_true", + help="torch.jit.trace the model for inference / eval only", + ) + # arguments for distributed training + parser.add_argument( + "--dist-url", + default="env://", + type=str, + help="url used to set up distributed training", + ) + parser.add_argument( + "--dist-backend", default="nccl", type=str, help="distributed backend" + ) + parser.add_argument( + "--report-to", + default="", + type=str, + help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']", + ) + parser.add_argument( + "--wandb-notes", default="", type=str, help="Notes if logging with wandb" + ) + parser.add_argument( + "--C", type=float, default=3.16, help="inverse regularizer for logistic reg." + ) + parser.add_argument( + "--debug", + default=False, + action="store_true", + help="If true, more information is logged.", + ) + parser.add_argument( + "--copy-codebase", + default=False, + action="store_true", + help="If true, we copy the entire base on the log diretory, and execute from there.", + ) + parser.add_argument( + "--horovod", + default=False, + action="store_true", + help="Use horovod for distributed training.", + ) + parser.add_argument( + "--ddp-static-graph", + default=False, + action="store_true", + help="Enable static graph optimization for DDP in PyTorch >= 1.11.", + ) + parser.add_argument( + "--no-set-device-rank", + default=False, + action="store_true", + help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", + ) + parser.add_argument("--seed", type=int, default=4242, help="Default random seed.") + + parser.add_argument( + "--top-k-checkpoint-select-dataset", + type=str, + default="all", + help="The dataset of selecting top-k checkpoint.", + ) + + # @R10, @R@5, @R1, mAP@10 + parser.add_argument( + "--top-k-checkpoint-select-metric", + type=str, + default="_R@10", + help="The metric for selecting top-k checkpoint.", + ) + parser.add_argument( + "--openai-model-cache-dir", + type=str, + default="~/.cache/clip", + help="Directory to download OpenAI models.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="adamw", + help="can be AdamW or SGD", + ) + parser.add_argument( + "--parallel-eval", + default=False, + action="store_true", + help="Eval in parallel (multi-GPU, multi-node).", + ) + + parser.add_argument( + "--no-eval", + default=False, + action="store_true", + help="Training without evaluation.", + ) + + parser.add_argument( + "--lp-mlp", + default=False, + action="store_true", + help="Linear Probe using MLP layer or not.", + ) + + parser.add_argument( + "--lp-freeze", + default=False, + action="store_true", + help="Linear Probe using Freeze CLAP or not", + ) + + parser.add_argument( + "--lp-act", + default="None", + type=str, + help="Options are ['relu','elu','prelu','softmax','sigmoid']", + ) + + parser.add_argument( + "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe." + ) + + parser.add_argument( + "--lp-metrics", + type=str, + default="map,mauc,acc", + help="Metrics of Linear Probe.", + ) + + parser.add_argument( + "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe" + ) + parser.add_argument( + "--kappa", + type=float, + default=0, + help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss", + ) + + parser.add_argument( + "--data-filling", + type=str, + default="pad", + help="type of data filling when the audio length is shorter than the max length." + "Can be one of the following: repeat, repeatpad, pad", + ) + parser.add_argument( + "--data-truncating", + type=str, + default="rand_trunc", + help="type of data truncation when the audio length is longer than the max length." + "Can be one of the following: rand_trunc, fusion", + ) + + parser.add_argument( + "--clap-mlploss", + default=False, + action="store_true", + help="Using MLP loss for CLAP model or not", + ) + + parser.add_argument( + "--wandb-id", + type=str, + default=None, + help="the id of wandb experiment to restore.", + ) + + parser.add_argument( + "--sleep", type=float, default=0, help="sleep n seconds before start training" + ) + + # variable length processing + parser.add_argument( + "--enable-fusion", + default=False, + action="store_true", + help="Enable feature funsion for variable-length data", + ) + + parser.add_argument( + "--fusion-type", + type=str, + default="None", + help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']", + ) + + parser.add_argument( + "--mixup", + default=False, + action="store_true", + help="Enable mixup in finetuning training.", + ) + parser.add_argument( + "--text-augment-selection", + type=str, + default=None, + help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']", + ) + + args = parser.parse_args() + + # If some params are not passed, we use the default values based on model name. + default_params = get_default_params(args.amodel) + for name, val in default_params.items(): + if getattr(args, name) is None: + setattr(args, name, val) + + return args diff --git a/audioldm/clap/training/scheduler.py b/audioldm/clap/training/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..7151ffbab25a113673b7627027b443b27f22cb0f --- /dev/null +++ b/audioldm/clap/training/scheduler.py @@ -0,0 +1,24 @@ +import numpy as np + + +def assign_learning_rate(optimizer, new_lr): + for param_group in optimizer.param_groups: + param_group["lr"] = new_lr + + +def _warmup_lr(base_lr, warmup_length, step): + return base_lr * (step + 1) / warmup_length + + +def cosine_lr(optimizer, base_lr, warmup_length, steps): + def _lr_adjuster(step): + if step < warmup_length: + lr = _warmup_lr(base_lr, warmup_length, step) + else: + e = step - warmup_length + es = steps - warmup_length + lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr + assign_learning_rate(optimizer, lr) + return lr + + return _lr_adjuster diff --git a/audioldm/clap/training/train.py b/audioldm/clap/training/train.py new file mode 100644 index 0000000000000000000000000000000000000000..f5759c4679d2ee9c0748444adf66b8453cf09728 --- /dev/null +++ b/audioldm/clap/training/train.py @@ -0,0 +1,838 @@ +import json +import logging +import math +import os +import time +from contextlib import suppress + +import numpy as np +import torch +import torch.nn.functional as F + +try: + import wandb +except ImportError: + wandb = None + +from open_clip import ClipLoss, gather_features +from .distributed import is_master +from .zero_shot import zero_shot_eval + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def unwrap_model(model): + if hasattr(model, "module"): + return model.module + else: + return model + + +def train_one_epoch( + model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None +): + device = torch.device(args.device) + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + model.train() + loss = ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + mlp_loss=args.clap_mlploss, + weight_loss_kappa=args.kappa, + ) + + dataloader, sampler = data["train"].dataloader, data["train"].sampler + if args.distributed and sampler is not None: + sampler.set_epoch(epoch) + num_batches_per_epoch = dataloader.num_batches + sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) + + # for toy dataset + if args.dataset_type == "toy": + dataloader.dataset.generate_queue() + + loss_m = AverageMeter() + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for i, batch in enumerate(dataloader): + # logging.info(f"batch {i} of {num_batches_per_epoch}") + step = num_batches_per_epoch * epoch + i + if isinstance(scheduler, dict): + for s in scheduler.values(): + s(step) + else: + scheduler(step) + audios = batch # contains mel_spec, wavform, and longer list + texts = batch["text"] + # audios = audios.to(device=device, non_blocking=True) + # texts = texts.to(device=device, non_blocking=True) + + data_time_m.update(time.time() - end) + if isinstance(optimizer, dict): + for o_ in optimizer.values(): + o_.zero_grad() + else: + optimizer.zero_grad() + + with autocast(): + ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + logit_scale_a, + logit_scale_t, + ) = model(audios, texts, device) + + if args.clap_mlploss: + total_loss = loss( + audio_features=audio_features, + text_features=text_features, + logit_scale_a=logit_scale_a, + logit_scale_t=logit_scale_t, + audio_features_mlp=audio_features_mlp, + text_features_mlp=text_features_mlp, + ) + else: + total_loss = loss( + audio_features=audio_features, + text_features=text_features, + logit_scale_a=logit_scale_a, + ) + if isinstance(optimizer, dict): + if scaler is not None: + scaler.scale(total_loss).backward() + for o_ in optimizer.values(): + if args.horovod: + o_.synchronize() + scaler.unscale_(o_) + with o_.skip_synchronize(): + scaler.step(o_) + else: + scaler.step(o_) + scaler.update() + else: + total_loss.backward() + for o_ in optimizer.values(): + o_.step() + else: + if scaler is not None: + scaler.scale(total_loss).backward() + if args.horovod: + optimizer.synchronize() + scaler.unscale_(optimizer) + with optimizer.skip_synchronize(): + scaler.step(optimizer) + else: + scaler.step(optimizer) + scaler.update() + else: + total_loss.backward() + optimizer.step() + + # Note: we clamp to 4.6052 = ln(100), as in the original paper. + with torch.no_grad(): + unwrap_model(model).logit_scale_a.clamp_(0, math.log(100)) + if args.clap_mlploss: + unwrap_model(model).logit_scale_t.clamp_(0, math.log(100)) + + batch_time_m.update(time.time() - end) + end = time.time() + batch_count = i + 1 + if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): + if isinstance(audios, dict): + batch_size = len(audios["waveform"]) + else: + batch_size = len(audios) + num_samples = batch_count * batch_size * args.world_size + samples_per_epoch = dataloader.num_samples + percent_complete = 100.0 * batch_count / num_batches_per_epoch + + # NOTE loss is coarsely sampled, just master node and per log update + loss_m.update(total_loss.item(), batch_size) + logit_scale_scalar_a = logit_scale_a.item() + logit_scale_scalar_t = logit_scale_t.item() + if isinstance(optimizer, dict): + if args.clap_mlploss: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + f"Logit Scale Text: {logit_scale_scalar_t:.3f}" + ) + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "scale_text": logit_scale_scalar_t, + "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], + } + else: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + ) + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], + } + + else: + if args.clap_mlploss: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + f"Logit Scale Text: {logit_scale_scalar_t:.3f}" + ) + + # Save train loss / etc. Using non avg meter values as loggers have their own smoothing + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "scale_text": logit_scale_scalar_t, + "lr": optimizer.param_groups[0]["lr"], + } + else: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + ) + + # Save train loss / etc. Using non avg meter values as loggers have their own smoothing + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "lr": optimizer.param_groups[0]["lr"], + } + for name, val in log_data.items(): + name = "train/" + name + if tb_writer is not None: + tb_writer.add_scalar(name, val, step) + if args.wandb: + assert wandb is not None, "Please install wandb." + wandb.log({name: val, "step": step}) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # end for + + +def evaluate(model, data, epoch, args, tb_writer=None): + metrics = {} + if not args.parallel_eval: + if not is_master(args): + return metrics + device = torch.device(args.device) + model.eval() + + # CHANGE + # zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + # metrics.update(zero_shot_metrics) + if is_master(args): + print("Evaluating...") + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + if args.val_dataset_names == ["Clotho", "audiocaps"]: + # if only clotho and audiocaps are used, then we will use a different evaluation function. + # This is because in the Clotho and audiocaps valid and test set, there are 5 text for 1 audio. + if args.parallel_eval: + # (yusong): just a hack here. Don't use parallel eval when evaluating only clotho and audiocaps. + raise NotImplementedError( + "Parallel evaluation not supported for eval only Clotho and audiocaps." + ) + val_metrics_per_dataset = evaluate_clotho_audiocaps( + model, data, epoch, args, autocast, device, tb_writer + ) + for m in val_metrics_per_dataset.values(): + metrics.update(m) + if "epoch" not in metrics.keys(): + metrics.update({"epoch": epoch}) + metrics = select_top_metric_clotho_audiocaps( + metrics, val_metrics_per_dataset, args + ) + elif "val" in data and ( + args.val_frequency + and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) + ): + dataloader = data["val"].dataloader + num_samples = 0 + samples_per_val = dataloader.num_samples + + # FIXME this does not scale past small eval datasets + # all_audio_features @ all_text_features will blow up memory and compute very quickly + eval_info = {} + if args.clap_mlploss: + eval_info["all"] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + "all_audio_features_mlp": [], + "all_text_features_mlp": [], + } # cumulative_loss = 0.0 + else: + eval_info["all"] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } # cumu + # all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp = [], [], [], [] + with torch.no_grad(): + for i, batch in enumerate(dataloader): + audios = batch # contains mel_spec, wavform, and longer list + texts = batch["text"] + # audios = audios.to(device=device, non_blocking=True) + + all_names = list( + set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]) + ) + for name in all_names: + if name not in eval_info.keys(): + if args.clap_mlploss: + eval_info[name] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + "all_audio_features_mlp": [], + "all_text_features_mlp": [], + } + else: + eval_info[name] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } + with autocast(): + ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + logit_scale_a, + logit_scale_t, + ) = model(audios, texts, device) + + if args.parallel_eval: + # multi-GPU eval + if args.clap_mlploss: + ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + ) = gather_features( + audio_features=audio_features, + text_features=text_features, + audio_features_mlp=audio_features_mlp, + text_features_mlp=text_features_mlp, + local_loss=False, + gather_with_grad=False, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + mlp_loss=args.clap_mlploss, + ) + else: + (audio_features, text_features,) = gather_features( + audio_features=audio_features, + text_features=text_features, + local_loss=False, + gather_with_grad=False, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + mlp_loss=args.clap_mlploss, + ) + + if is_master(args): + num_samples += audio_features.shape[0] + for n in [*all_names, "all"]: + if n == "all": + eval_info[n]["all_audio_features"].append( + audio_features.cpu() + ) + eval_info[n]["all_text_features"].append( + text_features.cpu() + ) + if args.clap_mlploss: + eval_info[n]["all_audio_features_mlp"].append( + audio_features_mlp.cpu() + ) + eval_info[n]["all_text_features_mlp"].append( + text_features_mlp.cpu() + ) + else: + idx = np.where( + np.array( + [ + "-".join(b.split("/")[-3:-1]) + for b in batch["__url__"] + ] + ) + == n + )[0] + eval_info[n]["all_audio_features"].append( + audio_features.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + eval_info[n]["all_text_features"].append( + text_features.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + if args.clap_mlploss: + eval_info[n]["all_audio_features_mlp"].append( + audio_features_mlp.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + eval_info[n]["all_text_features_mlp"].append( + text_features_mlp.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + # print(f'eval step {i}') # (yusong): for debug + + # cumulative_loss += total_loss * batch_size + # num_samples += batch_size + if is_master(args) and (i % 100) == 0: # and i != 0: + logging.info( + f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]" + ) + if is_master(args): + val_metrics_per_dataset = {} + for n in eval_info.keys(): + if args.clap_mlploss: + metrics_single_dataset = get_metrics( + audio_features=torch.cat( + eval_info[n]["all_audio_features"] + ), + text_features=torch.cat(eval_info[n]["all_text_features"]), + logit_scale_a=logit_scale_a.cpu(), + audio_features_mlp=torch.cat( + eval_info[n]["all_audio_features_mlp"] + ), + text_features_mlp=torch.cat( + eval_info[n]["all_text_features_mlp"] + ), + logit_scale_t=logit_scale_t.cpu(), + mlp_loss=args.clap_mlploss, + ) + else: + metrics_single_dataset = get_metrics( + audio_features=torch.cat( + eval_info[n]["all_audio_features"] + ), + text_features=torch.cat(eval_info[n]["all_text_features"]), + logit_scale_a=logit_scale_a.cpu(), + mlp_loss=args.clap_mlploss, + ) + val_metrics_per_dataset[n] = { + n + "/" + k: v for k, v in metrics_single_dataset.items() + } + metrics.update(val_metrics_per_dataset[n]) + if "epoch" not in metrics.keys(): + metrics.update({"epoch": epoch}) + if is_master(args): + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\n".join( + [ + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in m.items()]) + for m in val_metrics_per_dataset.values() + ] + ) + ) + + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val/{name}", val, epoch) + + with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + assert wandb is not None, "Please install wandb." + for name, val in metrics.items(): + wandb.log({f"val/{name}": val, "epoch": epoch}) + + return metrics + else: + return metrics + + +def get_metrics( + audio_features, + text_features, + logit_scale_a, + audio_features_mlp=None, + text_features_mlp=None, + logit_scale_t=None, + mlp_loss=False, +): + metrics = {} + if mlp_loss: + # Set up audio to text & text to audio similary matrice + a_logits_per_audio = ( + (logit_scale_a * audio_features @ text_features_mlp.t()).detach().cpu() + ) + a_logits_per_text = a_logits_per_audio.t().detach().cpu() + t_logits_per_audio = ( + (logit_scale_t * audio_features_mlp @ text_features.t()).detach().cpu() + ) + t_logits_per_text = t_logits_per_audio.t().detach().cpu() + + labels = torch.arange(audio_features.shape[0]).long() + # Change the loss from two terms into four terms with 2x2 combined CE loss + total_loss = ( + F.cross_entropy(a_logits_per_audio, labels) + + F.cross_entropy(a_logits_per_text, labels) + + F.cross_entropy(t_logits_per_audio, labels) + + F.cross_entropy(t_logits_per_text, labels) + ) / 4 + + metrics[f"cumulative_loss"] = total_loss.item() + metrics[f"num_samples"] = audio_features.shape[0] + + logits = { + "audio_to_text": (a_logits_per_audio + t_logits_per_audio) / 2, + "text_to_audio": (a_logits_per_text + t_logits_per_text) / 2, + } + ground_truth = torch.arange(len(text_features)).view(-1, 1) + + else: + # print("text_features", text_features) + # print("text_features.shape", text_features.shape) + logits_per_audio = ( + (logit_scale_a * audio_features @ text_features.t()).detach().cpu() + ) + logits_per_text = logits_per_audio.t().detach().cpu() + + labels = torch.arange(audio_features.shape[0]).long() + # Change the loss from two terms into four terms with 2x2 combined CE loss + total_loss = ( + F.cross_entropy(logits_per_audio, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + metrics[f"cumulative_loss"] = total_loss.item() + metrics[f"num_samples"] = audio_features.shape[0] + + logits = {"audio_to_text": logits_per_audio, "text_to_audio": logits_per_text} + + ground_truth = torch.arange(len(text_features)).view(-1, 1) + + for name, logit in logits.items(): + ranking = torch.argsort(logit, descending=True) + preds = torch.where(ranking == ground_truth)[ + 1 + ] # (yusong) this line is slow because it uses single thread + preds = preds.detach().cpu().numpy() + metrics[f"{name}_mean_rank"] = preds.mean() + 1 + metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 + for k in [1, 5, 10]: + metrics[f"{name}_R@{k}"] = np.mean(preds < k) + # map@10 + metrics[f"{name}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0)) + + return metrics + + +def evaluate_clotho_audiocaps( + model, data, epoch, args, autocast, device, tb_writer=None +): + """ + Adapted from https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py. + 1. for text-to-audio retrieval, do 5 times and average the results + 2. for R@1, R@5, R@10 in audio-to-text retrieval, take the best rank among 5 text + 3. for map@10 in audio-to-text retrieval: + 3.1: sort the rank of 5 text + 3.2: exclude the rank >=10 (0-index) + 3.3: compute the map regarding the remaining ranks: np.mean(np.arange(1, len(ranks)+1) / ranks). + (3.3) That is, take the top ranks of 5 text that is < 10, and assign the descending number as ground truth. + (3.3) E.g.: the ground truth of first rank of the 5 text should be 1, the second rank should be 2, etc. + """ + # TODO: (yusong) only support single GPU evaluation and only support non-mlp case for now. + dataloader = data["val"].dataloader + with torch.no_grad(): + eval_info = {} + for i, batch in enumerate(dataloader): + audios = batch # contains mel_spec, wavform, and longer list + + # each item in the list has 5 texts + if args.tmodel == "transformer": + from open_clip import tokenize + + texts = [tokenize(t) for t in batch["full_text"]] + texts = torch.cat(texts) + else: + from .data import tokenizer + + texts = [ + tokenizer(t) for t in batch["full_text"] + ] # 5 texts for each audio + texts = { + k: torch.cat([t[k] for t in texts]) for k in texts[0].keys() + } # 5 x batch + + # audios = audios.to(device=device, non_blocking=True) + + all_names = list( + set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]) + ) + for name in all_names: + if name not in eval_info.keys(): + # we will not use mlp outputs even if args.clap_mlploss=True + eval_info[name] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } + with autocast(): + audio_features = model(audios, None, device) + text_features = model(None, texts, device) + audio_features = F.normalize(audio_features, dim=-1) + text_features = F.normalize(text_features, dim=-1) + + all_names = list( + set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]) + ) + for n in all_names: + idx = np.where( + np.array( + ["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]] + ) + == n + )[0] + eval_info[n]["all_audio_features"].append( + audio_features.cpu().index_select(0, torch.tensor(idx).long()) + ) + # (yusong) please double-check. This is for selecting 5 text features at once. + # because idx is a list of indices in size of num_samples, + # and text_features is a tensor of size (5*num_samples, dim) + # so we need to select 5 consecutive indices at once for a single index in idx. + eval_info[n]["all_text_features"].append( + text_features.cpu() + .reshape([-1, 5, text_features.shape[1]]) + .index_select(0, torch.tensor(idx).long()) + .reshape([-1, text_features.shape[1]]) + ) + + val_metrics_all = {} + + for n in eval_info.keys(): + logit_scale_a, logit_scale_t = model(None, None, device) + logit_scale_a = logit_scale_a.cpu() + + audio_features = torch.cat(eval_info[n]["all_audio_features"], dim=0) + text_features = torch.cat(eval_info[n]["all_text_features"], dim=0) + + logits_per_audio = ( + (logit_scale_a * audio_features @ text_features.t()).detach().cpu() + ) + logits_per_text = logits_per_audio.t().detach().cpu() + + # logits_per_audio shape: [num_samples, num_samples*5] + # logits_per_text shape: [num_samples*5, num_samples] + + logging.info( + f"dataset {n}, logits_per_audio shape: {logits_per_audio.shape}, " + f"logits_per_text shape: {logits_per_text.shape}" + ) + + metrics = {} + num_samples = audio_features.shape[0] + metrics[f"num_samples"] = num_samples + + # (yusong) the following code is very important, please double-check: + # logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d] + # logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :] + # Those two are retrieving one of the 5 text for each audio. + labels = torch.arange(audio_features.shape[0]).long() + audio_to_text_loss = [ + F.cross_entropy( + logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d], + labels, + ) + for d in range(5) + ] + text_to_audio_loss = [ + F.cross_entropy( + logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :], + labels, + ) + for d in range(5) + ] + total_loss = (np.mean(audio_to_text_loss) + np.mean(text_to_audio_loss)) / 2 + + metrics[f"cumulative_loss"] = total_loss.item() + + # text to audio: do 5 times + pred_text = [] + for d in range(5): + logit = logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :] + ground_truth = torch.arange(len(logit)).view(-1, 1) + ranking = torch.argsort( + logit, descending=True + ) # [num_samples, num_samples] + preds = torch.where(ranking == ground_truth)[1] + pred_text.append(preds.detach().cpu().numpy()) + pred_text_concat = np.concatenate(pred_text, axis=0) # [5*num_samples] + metrics[f"text_to_audio_mean_rank"] = pred_text_concat.mean() + 1 + metrics[f"text_to_audio_median_rank"] = ( + np.floor(np.median(pred_text_concat)) + 1 + ) + for k in [1, 5, 10]: + metrics[f"text_to_audio_R@{k}"] = np.mean(pred_text_concat < k) + # map@10 + metrics[f"text_to_audio_mAP@10"] = np.mean( + np.where(pred_text_concat < 10, 1 / (pred_text_concat + 1), 0.0) + ) + + # audio to text: take the best result + # for audio to text map 10, sort and assign descending ground truth. + # see https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py#L103 + # map@10 + map_all = [] + pred_audio_all = [] + for d in range(num_samples): + # logits_per_audio: [num_samples, num_samples*5] + logit_single = logits_per_audio[d, :] # [5*num_samples] + # Ground-truth index: [d*5, d*5+1, d*5+2, d*5+3, d*5+4] + ranking = torch.argsort( + logit_single, descending=True + ) # [5*num_samples] + # ranking: the index of first match, second match, ... + ground_truth = torch.arange(d * 5, d * 5 + 5)[None] + all_pred = torch.where( + torch.stack([ranking] * 5) == ground_truth.view(-1, 1) + )[1] + min_pred = torch.min(all_pred) + pred_audio_all.append(min_pred.detach().cpu().numpy()) + all_pred_filter = all_pred[all_pred < 10].detach().cpu().numpy() + # /5 because we have 5 text, so it means for the text rank >=10 we count as 0. + map_single = ( + np.sum( + (np.arange(1, len(all_pred_filter) + 1) / (all_pred_filter + 1)) + ) + / 5 + ) + map_all.append(map_single) + metrics[f"audio_to_text_mAP@10"] = np.mean(map_all) + for k in [1, 5, 10]: + metrics[f"audio_to_text_R@{k}"] = np.mean(np.array(pred_audio_all) < k) + + val_metrics_all[n] = {n + "/" + k: v for k, v in metrics.items()} + return val_metrics_all + + +def calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset): + """ + Calculate performance for Clotho+AudioCaps for model selection. + """ + selection_performance_all = [] + for n in val_metrics_per_dataset.keys(): + selection_performance = ( + val_metrics_per_dataset[n][f"{n}/audio_to_text_mAP@10"] + + val_metrics_per_dataset[n][f"{n}/text_to_audio_mAP@10"] + ) / 2 + selection_performance_all.append(selection_performance) + return np.mean(selection_performance_all) + + +def select_top_metric_clotho_audiocaps(metrics, val_metrics_per_dataset, args): + # val_metrics_per_dataset: dict, key: dataset name, value: dict, key: metric name, value: metric value + # metrics: dict, key: metric name, value: metric value + # Hack: use args to save the top performance + if not hasattr(args, "top_selection_performance"): + selection_performance = calculate_selection_performance_clotho_audiocaps( + val_metrics_per_dataset + ) + # TODO: write the if and else together + metric_update = {} + for n in val_metrics_per_dataset.keys(): + for k in val_metrics_per_dataset[n].keys(): + metric_update[ + k.split("/")[0] + "-top" + "/" + k.split("/")[1] + ] = val_metrics_per_dataset[n][k] + metric_update["top_selection_performance"] = selection_performance + metric_update["top-selection-epoch"] = metrics["epoch"] + metrics.update(metric_update) + args.top_metric = metric_update + args.top_selection_performance = selection_performance + else: + selection_performance_new = calculate_selection_performance_clotho_audiocaps( + val_metrics_per_dataset + ) + selection_performance_old = args.top_selection_performance + if selection_performance_new > selection_performance_old: + metric_update = {} + for n in val_metrics_per_dataset.keys(): + for k in val_metrics_per_dataset[n].keys(): + metric_update[ + k.split("/")[0] + "-top" + "/" + k.split("/")[1] + ] = val_metrics_per_dataset[n][k] + metric_update["top_selection_performance"] = selection_performance_new + metric_update["top-selection-epoch"] = metrics["epoch"] + metrics.update(metric_update) + args.top_metric = metric_update + args.top_selection_performance = selection_performance_new + else: + metrics.update(args.top_metric) + return metrics diff --git a/audioldm/clap/training/zero_shot.py b/audioldm/clap/training/zero_shot.py new file mode 100644 index 0000000000000000000000000000000000000000..28b8fccc1af17fc69002857a7f529ac041c374f2 --- /dev/null +++ b/audioldm/clap/training/zero_shot.py @@ -0,0 +1,95 @@ +# NOTE: This script is currently not supported for CLAP. +import logging +from contextlib import suppress + +import torch +import torch.nn.functional as F +from tqdm import tqdm + +from open_clip import tokenize +from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template + + +def zero_shot_classifier(model, classnames, templates, args): + with torch.no_grad(): + zeroshot_weights = [] + for classname in tqdm(classnames): + texts = [template(classname) for template in templates] # format with class + texts = tokenize(texts).to(args.device) # tokenize + if args.distributed and not args.horovod: + class_embeddings = model.module.encode_text(texts) + else: + class_embeddings = model.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) + return zeroshot_weights + + +def accuracy(output, target, topk=(1,)): + pred = output.topk(max(topk), 1, True, True)[1].t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + return [ + float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) + for k in topk + ] + + +def run(model, classifier, dataloader, args): + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + with torch.no_grad(): + top1, top5, n = 0.0, 0.0, 0.0 + for images, target in tqdm(dataloader, unit_scale=args.batch_size): + images = images.to(args.device) + target = target.to(args.device) + + with autocast(): + # predict + if args.distributed and not args.horovod: + image_features = model.module.encode_image(images) + else: + image_features = model.encode_image(images) + image_features = F.normalize(image_features, dim=-1) + logits = 100.0 * image_features @ classifier + + # measure accuracy + acc1, acc5 = accuracy(logits, target, topk=(1, 5)) + top1 += acc1 + top5 += acc5 + n += images.size(0) + + top1 = top1 / n + top5 = top5 / n + return top1, top5 + + +def zero_shot_eval(model, data, epoch, args): + if "imagenet-val" not in data and "imagenet-v2" not in data: + return {} + if args.zeroshot_frequency == 0: + return {} + if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: + return {} + + logging.info("Starting zero-shot imagenet.") + + logging.info("Building zero-shot classifier") + classifier = zero_shot_classifier( + model, imagenet_classnames, openai_imagenet_template, args + ) + + logging.info("Using classifier") + results = {} + if "imagenet-val" in data: + top1, top5 = run(model, classifier, data["imagenet-val"].dataloader, args) + results["imagenet-zeroshot-val-top1"] = top1 + results["imagenet-zeroshot-val-top5"] = top5 + if "imagenet-v2" in data: + top1, top5 = run(model, classifier, data["imagenet-v2"].dataloader, args) + results["imagenetv2-zeroshot-val-top1"] = top1 + results["imagenetv2-zeroshot-val-top5"] = top5 + + logging.info("Finished zero-shot imagenet.") + + return results diff --git a/audioldm/hifigan/__init__.py b/audioldm/hifigan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e0ae476fe58c48e998c56234a55b871beba4042d --- /dev/null +++ b/audioldm/hifigan/__init__.py @@ -0,0 +1,7 @@ +from .models import Generator + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self diff --git a/audioldm/hifigan/models.py b/audioldm/hifigan/models.py new file mode 100644 index 0000000000000000000000000000000000000000..c4382cc39de0463f9b7c0f33f037dbc233e7cb36 --- /dev/null +++ b/audioldm/hifigan/models.py @@ -0,0 +1,174 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm + +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock, self).__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm( + Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + # print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) diff --git a/audioldm/hifigan/utilities.py b/audioldm/hifigan/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..47fd39ea0af181772d640feec2413cf631a75702 --- /dev/null +++ b/audioldm/hifigan/utilities.py @@ -0,0 +1,85 @@ +import os +import json + +import torch +import numpy as np + +import audioldm.hifigan as hifigan + +HIFIGAN_16K_64 = { + "resblock": "1", + "num_gpus": 6, + "batch_size": 16, + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], + "upsample_initial_channel": 1024, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "segment_size": 8192, + "num_mels": 64, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 160, + "win_size": 1024, + "sampling_rate": 16000, + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": None, + "num_workers": 4, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1, + }, +} + + +def get_available_checkpoint_keys(model, ckpt): + print("==> Attemp to reload from %s" % ckpt) + state_dict = torch.load(ckpt)["state_dict"] + current_state_dict = model.state_dict() + new_state_dict = {} + for k in state_dict.keys(): + if ( + k in current_state_dict.keys() + and current_state_dict[k].size() == state_dict[k].size() + ): + new_state_dict[k] = state_dict[k] + else: + print("==> WARNING: Skipping %s" % k) + print( + "%s out of %s keys are matched" + % (len(new_state_dict.keys()), len(state_dict.keys())) + ) + return new_state_dict + + +def get_param_num(model): + num_param = sum(param.numel() for param in model.parameters()) + return num_param + + +def get_vocoder(config, device): + config = hifigan.AttrDict(HIFIGAN_16K_64) + vocoder = hifigan.Generator(config) + vocoder.eval() + vocoder.remove_weight_norm() + vocoder.to(device) + return vocoder + + +def vocoder_infer(mels, vocoder, lengths=None): + with torch.no_grad(): + wavs = vocoder(mels).squeeze(1) + + wavs = (wavs.cpu().numpy() * 32768).astype("int16") + + if lengths is not None: + wavs = wavs[:, :lengths] + + return wavs diff --git a/audioldm/latent_diffusion/__init__.py b/audioldm/latent_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm/latent_diffusion/attention.py b/audioldm/latent_diffusion/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..583dd169e7ec9502ee29faeb12689a46494838c0 --- /dev/null +++ b/audioldm/latent_diffusion/attention.py @@ -0,0 +1,468 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn +from einops import rearrange + +from audioldm.latent_diffusion.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + """ + ### Cross Attention Layer + This falls-back to self-attention when conditional embeddings are not specified. + """ + + # use_flash_attention: bool = True + use_flash_attention: bool = False + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + is_inplace: bool = True, + ): + # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True): + """ + :param d_model: is the input embedding size + :param n_heads: is the number of attention heads + :param d_head: is the size of a attention head + :param d_cond: is the size of the conditional embeddings + :param is_inplace: specifies whether to perform the attention softmax computation inplace to + save memory + """ + super().__init__() + + self.is_inplace = is_inplace + self.n_heads = heads + self.d_head = dim_head + + # Attention scaling factor + self.scale = dim_head**-0.5 + + # The normal self-attention layer + if context_dim is None: + context_dim = query_dim + + # Query, key and value mappings + d_attn = dim_head * heads + self.to_q = nn.Linear(query_dim, d_attn, bias=False) + self.to_k = nn.Linear(context_dim, d_attn, bias=False) + self.to_v = nn.Linear(context_dim, d_attn, bias=False) + + # Final linear layer + self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout)) + + # Setup [flash attention](https://github.com/HazyResearch/flash-attention). + # Flash attention is only used if it's installed + # and `CrossAttention.use_flash_attention` is set to `True`. + try: + # You can install flash attention by cloning their Github repo, + # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention) + # and then running `python setup.py install` + from flash_attn.flash_attention import FlashAttention + + self.flash = FlashAttention() + # Set the scale for scaled dot-product attention. + self.flash.softmax_scale = self.scale + # Set to `None` if it's not installed + except ImportError: + self.flash = None + + def forward(self, x, context=None, mask=None): + """ + :param x: are the input embeddings of shape `[batch_size, height * width, d_model]` + :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]` + """ + + # If `cond` is `None` we perform self attention + has_cond = context is not None + if not has_cond: + context = x + + # Get query, key and value vectors + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + + # Use flash attention if it's available and the head size is less than or equal to `128` + if ( + CrossAttention.use_flash_attention + and self.flash is not None + and not has_cond + and self.d_head <= 128 + ): + return self.flash_attention(q, k, v) + # Otherwise, fallback to normal attention + else: + return self.normal_attention(q, k, v) + + def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """ + #### Flash Attention + :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + """ + + # Get batch size and number of elements along sequence axis (`width * height`) + batch_size, seq_len, _ = q.shape + + # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of + # shape `[batch_size, seq_len, 3, n_heads * d_head]` + qkv = torch.stack((q, k, v), dim=2) + # Split the heads + qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head) + + # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to + # fit this size. + if self.d_head <= 32: + pad = 32 - self.d_head + elif self.d_head <= 64: + pad = 64 - self.d_head + elif self.d_head <= 128: + pad = 128 - self.d_head + else: + raise ValueError(f"Head size ${self.d_head} too large for Flash Attention") + + # Pad the heads + if pad: + qkv = torch.cat( + (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1 + ) + + # Compute attention + # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ + # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]` + # TODO here I add the dtype changing + out, _ = self.flash(qkv.type(torch.float16)) + # Truncate the extra head size + out = out[:, :, :, : self.d_head].float() + # Reshape to `[batch_size, seq_len, n_heads * d_head]` + out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head) + + # Map to `[batch_size, height * width, d_model]` with a linear layer + return self.to_out(out) + + def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """ + #### Normal Attention + + :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + """ + + # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]` + q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32] + k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32] + v = v.view(*v.shape[:2], self.n_heads, -1) + + # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$ + attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale + + # Compute softmax + # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$ + if self.is_inplace: + half = attn.shape[0] // 2 + attn[half:] = attn[half:].softmax(dim=-1) + attn[:half] = attn[:half].softmax(dim=-1) + else: + attn = attn.softmax(dim=-1) + + # Compute attention output + # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ + # attn: [bs, 20, 64, 1] + # v: [bs, 1, 20, 32] + out = torch.einsum("bhij,bjhd->bihd", attn, v) + # Reshape to `[batch_size, height * width, n_heads * d_head]` + out = out.reshape(*out.shape[:2], -1) + # Map to `[batch_size, height * width, d_model]` with a linear layer + return self.to_out(out) + + +# class CrossAttention(nn.Module): +# def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): +# super().__init__() +# inner_dim = dim_head * heads +# context_dim = default(context_dim, query_dim) + +# self.scale = dim_head ** -0.5 +# self.heads = heads + +# self.to_q = nn.Linear(query_dim, inner_dim, bias=False) +# self.to_k = nn.Linear(context_dim, inner_dim, bias=False) +# self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + +# self.to_out = nn.Sequential( +# nn.Linear(inner_dim, query_dim), +# nn.Dropout(dropout) +# ) + +# def forward(self, x, context=None, mask=None): +# h = self.heads + +# q = self.to_q(x) +# context = default(context, x) +# k = self.to_k(context) +# v = self.to_v(context) + +# q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + +# sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + +# if exists(mask): +# mask = rearrange(mask, 'b ... -> b (...)') +# max_neg_value = -torch.finfo(sim.dtype).max +# mask = repeat(mask, 'b j -> (b h) () j', h=h) +# sim.masked_fill_(~mask, max_neg_value) + +# # attention, what we cannot get enough of +# attn = sim.softmax(dim=-1) + +# out = einsum('b i j, b j d -> b i d', attn, v) +# out = rearrange(out, '(b h) n d -> b n (h d)', h=h) +# return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + if context is None: + return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint) + else: + return checkpoint( + self._forward, (x, context), self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + no_context=False, + ): + super().__init__() + + if no_context: + context_dim = None + + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim + ) + for d in range(depth) + ] + ) + + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + return x + x_in diff --git a/audioldm/latent_diffusion/ddim.py b/audioldm/latent_diffusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..57ee8d302c77cb09bd73ef803ef9e715098feafc --- /dev/null +++ b/audioldm/latent_diffusion/ddim.py @@ -0,0 +1,377 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm + +from audioldm.latent_diffusion.util import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, + extract_into_tensor, +) +import gradio as gr + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs, + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + samples, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + # iterator = gr.Progress().tqdm(time_range, desc="DDIM Sampler", total=total_steps) + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample( + x0, ts + ) # TODO deterministic forward pass? + img = ( + img_orig * mask + (1.0 - mask) * img + ) # In the first sampling step, img is pure gaussian noise + + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + img, pred_x0 = outs + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) + + @torch.no_grad() + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + ): + + timesteps = ( + np.arange(self.ddpm_num_timesteps) + if use_original_steps + else self.ddim_timesteps + ) + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + # iterator = gr.Progress().tqdm(time_range, desc="Decoding image", total=total_steps) + iterator = tqdm(time_range, desc="Decoding image", total=total_steps) + x_dec = x_latent + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full( + (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long + ) + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return x_dec + + @torch.no_grad() + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + ): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + # When unconditional_guidance_scale == 1: only e_t + # When unconditional_guidance_scale == 0: only unconditional + # When unconditional_guidance_scale > 1: add more unconditional guidance + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score( + self.model, e_t, x, t, c, **corrector_kwargs + ) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise # TODO + return x_prev, pred_x0 diff --git a/audioldm/latent_diffusion/ddpm.py b/audioldm/latent_diffusion/ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..ffca031c27d413698adee5a58547b7d0ea4069c3 --- /dev/null +++ b/audioldm/latent_diffusion/ddpm.py @@ -0,0 +1,441 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" +import sys +import os + +import torch +import torch.nn as nn +import numpy as np +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm + +from audioldm.utils import exists, default, count_params, instantiate_from_config +from audioldm.latent_diffusion.ema import LitEma +from audioldm.latent_diffusion.util import ( + make_beta_schedule, + extract_into_tensor, + noise_like, +) +import soundfile as sf +import os + + +__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DiffusionWrapper(nn.Module): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [ + None, + "concat", + "crossattn", + "hybrid", + "adm", + "film", + ] + + def forward( + self, x, t, c_concat: list = None, c_crossattn: list = None, c_film: list = None + ): + x = x.contiguous() + t = t.contiguous() + + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == "concat": + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == "crossattn": + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == "hybrid": + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif ( + self.conditioning_key == "film" + ): # The condition is assumed to be a global token, which wil pass through a linear layer and added with the time embedding for the FILM + cc = c_film[0].squeeze(1) # only has one token + out = self.diffusion_model(x, t, y=cc) + elif self.conditioning_key == "adm": + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class DDPM(nn.Module): + # classic DDPM with Gaussian diffusion, in image space + def __init__( + self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + latent_t_size=256, + latent_f_size=16, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0.0, + ): + super().__init__() + assert parameterization in [ + "eps", + "x0", + ], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + self.state = None + # print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + + self.latent_t_size = latent_t_size + self.latent_f_size = latent_f_size + + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + else: + self.logvar = nn.Parameter(self.logvar, requires_grad=False) + + self.logger_save_dir = None + self.logger_project = None + self.logger_version = None + self.label_indices_total = None + # To avoid the system cannot find metric value for checkpoint + self.metrics_buffer = { + "val/kullback_leibler_divergence_sigmoid": 15.0, + "val/kullback_leibler_divergence_softmax": 10.0, + "val/psnr": 0.0, + "val/ssim": 0.0, + "val/inception_score_mean": 1.0, + "val/inception_score_std": 0.0, + "val/kernel_inception_distance_mean": 0.0, + "val/kernel_inception_distance_std": 0.0, + "val/frechet_inception_distance": 133.0, + "val/frechet_audio_distance": 32.0, + } + self.initial_learning_rate = None + + def get_log_dir(self): + if ( + self.logger_save_dir is None + and self.logger_project is None + and self.logger_version is None + ): + return os.path.join( + self.logger.save_dir, self.logger._project, self.logger.version + ) + else: + return os.path.join( + self.logger_save_dir, self.logger_project, self.logger_version + ) + + def set_log_dir(self, save_dir, project, version): + self.logger_save_dir = save_dir + self.logger_project = project + self.logger_version = version + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + "posterior_log_variance_clipped", + to_torch(np.log(np.maximum(posterior_variance, 1e-20))), + ) + self.register_buffer( + "posterior_mean_coef1", + to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), + ) + self.register_buffer( + "posterior_mean_coef2", + to_torch( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ), + ) + + if self.parameterization == "eps": + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) + ) + elif self.parameterization == "x0": + lvlb_weights = ( + 0.5 + * np.sqrt(torch.Tensor(alphas_cumprod)) + / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + ) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + # print(f"{context}: Switched to EMA weights") + pass + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + # print(f"{context}: Restored training weights") + pass + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t + ) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised + ) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = ( + (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous() + ) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm( + reversed(range(0, self.num_timesteps)), + desc="Sampling t", + total=self.num_timesteps, + ): + img = self.p_sample( + img, + torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised, + ) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + shape = (batch_size, channels, self.latent_t_size, self.latent_f_size) + channels = self.channels + return self.p_sample_loop(shape, return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def forward(self, x, *args, **kwargs): + t = torch.randint( + 0, self.num_timesteps, (x.shape[0],), device=self.device + ).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + # fbank, log_magnitudes_stft, label_indices, fname, waveform, clip_label, text = batch + fbank, log_magnitudes_stft, label_indices, fname, waveform, text = batch + ret = {} + + ret["fbank"] = ( + fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float() + ) + ret["stft"] = log_magnitudes_stft.to( + memory_format=torch.contiguous_format + ).float() + # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float() + ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float() + ret["text"] = list(text) + ret["fname"] = fname + + return ret[k] diff --git a/audioldm/latent_diffusion/ema.py b/audioldm/latent_diffusion/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..192b012186bab3d8a5380bc9b891da8eef0fd9fa --- /dev/null +++ b/audioldm/latent_diffusion/ema.py @@ -0,0 +1,81 @@ +import torch +from torch import nn + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/audioldm/latent_diffusion/openaimodel.py b/audioldm/latent_diffusion/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..831d7aafb36bba16888e4389153979a6c13639f5 --- /dev/null +++ b/audioldm/latent_diffusion/openaimodel.py @@ -0,0 +1,1069 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from audioldm.latent_diffusion.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from audioldm.latent_diffusion.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1).contiguous() # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d( + self.channels, self.out_channels, kernel_size=ks, stride=2 + ) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1).contiguous() + qkv = self.qkv(self.norm(x)).contiguous() + h = self.attention(qkv).contiguous() + h = self.proj_out(h).contiguous() + return (x + h).reshape(b, c, *spatial).contiguous() + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = ( + qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1) + ) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length).contiguous() + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum( + "bts,bcs->bct", + weight, + v.reshape(bs * self.n_heads, ch, length).contiguous(), + ) + return a.reshape(bs, -1, length).contiguous() + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + extra_film_condition_dim=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + extra_film_use_concat=False, # If true, concatenate extrafilm condition with time embedding, else addition + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.extra_film_condition_dim = extra_film_condition_dim + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + self.extra_film_use_concat = extra_film_use_concat + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + assert not ( + self.num_classes is not None and self.extra_film_condition_dim is not None + ), "As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim." + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.use_extra_film_by_concat = ( + self.extra_film_condition_dim is not None and self.extra_film_use_concat + ) + self.use_extra_film_by_addition = ( + self.extra_film_condition_dim is not None and not self.extra_film_use_concat + ) + + if self.extra_film_condition_dim is not None: + self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim) + # print("+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. " % self.extra_film_condition_dim) + # if(self.use_extra_film_by_concat): + # print("\t By concatenation with time embedding") + # elif(self.use_extra_film_by_concat): + # print("\t By addition with time embedding") + + if use_spatial_transformer and ( + self.use_extra_film_by_concat or self.use_extra_film_by_addition + ): + # print("+ Spatial transformer will only be used as self-attention. Because you have choose to use film as your global condition.") + spatial_transformer_no_context = True + else: + spatial_transformer_no_context = False + + if use_spatial_transformer and not spatial_transformer_no_context: + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." + + if context_dim is not None and not spatial_transformer_no_context: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + from omegaconf.listconfig import ListConfig + + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + no_context=spatial_transformer_no_context, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + no_context=spatial_transformer_no_context, + ), + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + no_context=spatial_transformer_no_context, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + self.shape_reported = False + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional + :return: an [N x C x ...] Tensor of outputs. + """ + if not self.shape_reported: + # print("The shape of UNet input is", x.size()) + self.shape_reported = True + + assert (y is not None) == ( + self.num_classes is not None or self.extra_film_condition_dim is not None + ), "must specify y if and only if the model is class-conditional or film embedding conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + if self.use_extra_film_by_addition: + emb = emb + self.film_emb(y) + elif self.use_extra_film_by_concat: + emb = th.cat([emb, self.film_emb(y)], dim=-1) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) diff --git a/audioldm/latent_diffusion/util.py b/audioldm/latent_diffusion/util.py new file mode 100644 index 0000000000000000000000000000000000000000..8b289f6aa7f22a070870d8a706f944dc8547e936 --- /dev/null +++ b/audioldm/latent_diffusion/util.py @@ -0,0 +1,295 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from audioldm.utils import instantiate_from_config + + +def make_beta_schedule( + schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + if schedule == "linear": + betas = ( + torch.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 + ) + ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64 + ) + elif schedule == "sqrt": + betas = ( + torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + ** 0.5 + ) + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True +): + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) + if verbose: + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t).contiguous() + return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous() + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/audioldm/ldm.py b/audioldm/ldm.py new file mode 100644 index 0000000000000000000000000000000000000000..b0392e28404c315e5d8ca5ede571da386f5d4b42 --- /dev/null +++ b/audioldm/ldm.py @@ -0,0 +1,715 @@ +import os + +import torch +import numpy as np +from tqdm import tqdm +from audioldm.utils import default, instantiate_from_config, save_wave +from audioldm.latent_diffusion.ddpm import DDPM +from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution +from audioldm.latent_diffusion.util import noise_like +from audioldm.latent_diffusion.ddim import DDIMSampler +import os + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + +class LatentDiffusion(DDPM): + """main class""" + + def __init__( + self, + device="cuda", + first_stage_config=None, + cond_stage_config=None, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + base_learning_rate=None, + *args, + **kwargs, + ): + self.device = device + self.learning_rate = base_learning_rate + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs["timesteps"] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = "concat" if concat_mode else "crossattn" + if cond_stage_config == "__is_unconditional__": + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + self.cond_stage_key_orig = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer("scale_factor", torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + + def make_cond_schedule( + self, + ): + self.cond_ids = torch.full( + size=(self.num_timesteps,), + fill_value=self.num_timesteps - 1, + dtype=torch.long, + ) + ids = torch.round( + torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) + ).long() + self.cond_ids[: self.num_timesteps_cond] = ids + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + super().register_schedule( + given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s + ) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != "__is_first_stage__" + assert config != "__is_unconditional__" + model = instantiate_from_config(config) + self.cond_stage_model = model + self.cond_stage_model = self.cond_stage_model.to(self.device) + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, "encode") and callable( + self.cond_stage_model.encode + ): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + if len(c) == 1: + c = self.cond_stage_model([c[0], c[0]]) + c = c[0:1] + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + @torch.no_grad() + def get_input( + self, + batch, + k, + return_first_stage_encode=True, + return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None, + ): + x = super().get_input(batch, k) + + if bs is not None: + x = x[:bs] + + x = x.to(self.device) + + if return_first_stage_encode: + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + else: + z = None + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ["caption", "coordinates_bbox"]: + xc = batch[cond_key] + elif cond_key == "class_label": + xc = batch + else: + # [bs, 1, 527] + xc = super().get_input(batch, cond_key) + if type(xc) == torch.Tensor: + xc = xc.to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + + if bs is not None: + c = c[:bs] + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {"pos_x": pos_x, "pos_y": pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, "b h w c -> b c h w").contiguous() + + z = 1.0 / self.scale_factor * z + return self.first_stage_model.decode(z) + + def mel_spectrogram_to_waveform(self, mel): + # Mel: [bs, 1, t-steps, fbins] + if len(mel.size()) == 4: + mel = mel.squeeze(1) + mel = mel.permute(0, 2, 1) + waveform = self.first_stage_model.vocoder(mel) + waveform = waveform.cpu().detach().numpy() + return waveform + + @torch.no_grad() + def encode_first_stage(self, x): + return self.first_stage_model.encode(x) + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + if self.model.conditioning_key == "concat": + key = "c_concat" + elif self.model.conditioning_key == "crossattn": + key = "c_crossattn" + else: + key = "c_film" + + cond = {key: cond} + + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def p_mean_variance( + self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + ): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score( + self, model_out, x, t, c, **corrector_kwargs + ) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t + ) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample( + self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + ): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance( + x=x, + c=c, + t=t, + clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = ( + (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous() + ) + + if return_codebook_ids: + return model_mean + nonzero_mask * ( + 0.5 * model_log_variance + ).exp() * noise, logits.argmax(dim=1) + if return_x0: + return ( + model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, + x0, + ) + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising( + self, + cond, + shape, + verbose=True, + callback=None, + quantize_denoised=False, + img_callback=None, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + batch_size=None, + x_T=None, + start_T=None, + log_every_t=None, + ): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = ( + [c[:batch_size] for c in cond] + if isinstance(cond, list) + else cond[:batch_size] + ) + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = ( + tqdm( + reversed(range(0, timesteps)), + desc="Progressive Generation", + total=timesteps, + ) + if verbose + else reversed(range(0, timesteps)) + ) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != "hybrid" + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + return_x0=True, + temperature=temperature[i], + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop( + self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None, + ): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != "hybrid" + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + ) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample( + self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs, + ): + if shape is None: + shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size) + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = ( + [c[:batch_size] for c in cond] + if isinstance(cond, list) + else cond[:batch_size] + ) + return self.p_sample_loop( + cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0, + **kwargs, + ) + + @torch.no_grad() + def sample_log( + self, + cond, + batch_size, + ddim, + ddim_steps, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_plms=False, + mask=None, + **kwargs, + ): + + if mask is not None: + shape = (self.channels, mask.size()[-2], mask.size()[-1]) + else: + shape = (self.channels, self.latent_t_size, self.latent_f_size) + + intermediate = None + if ddim and not use_plms: + # print("Use ddim sampler") + + ddim_sampler = DDIMSampler(self) + samples, intermediates = ddim_sampler.sample( + ddim_steps, + batch_size, + shape, + cond, + verbose=False, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + mask=mask, + **kwargs, + ) + + else: + # print("Use DDPM sampler") + samples, intermediates = self.sample( + cond=cond, + batch_size=batch_size, + return_intermediates=True, + unconditional_guidance_scale=unconditional_guidance_scale, + mask=mask, + unconditional_conditioning=unconditional_conditioning, + **kwargs, + ) + + return samples, intermediate + + + @torch.no_grad() + def generate_sample( + self, + batchs, + ddim_steps=200, + ddim_eta=1.0, + x_T=None, + n_candidate_gen_per_text=1, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + name="waveform", + use_plms=False, + save=False, + **kwargs, + ): + # Generate n_candidate_gen_per_text times and select the best + # Batch: audio, text, fnames + assert x_T is None + try: + batchs = iter(batchs) + except TypeError: + raise ValueError("The first input argument should be an iterable object") + + if use_plms: + assert ddim_steps is not None + use_ddim = ddim_steps is not None + # waveform_save_path = os.path.join(self.get_log_dir(), name) + # os.makedirs(waveform_save_path, exist_ok=True) + # print("Waveform save path: ", waveform_save_path) + + with self.ema_scope("Generate"): + for batch in batchs: + z, c = self.get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=False, + force_c_encode=True, + return_original_cond=False, + bs=None, + ) + text = super().get_input(batch, "text") + + # Generate multiple samples + batch_size = z.shape[0] * n_candidate_gen_per_text + c = torch.cat([c] * n_candidate_gen_per_text, dim=0) + text = text * n_candidate_gen_per_text + + if unconditional_guidance_scale != 1.0: + unconditional_conditioning = ( + self.cond_stage_model.get_unconditional_condition(batch_size) + ) + + samples, _ = self.sample_log( + cond=c, + batch_size=batch_size, + x_T=x_T, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + use_plms=use_plms, + ) + + mel = self.decode_first_stage(samples) + + waveform = self.mel_spectrogram_to_waveform(mel) + + if(waveform.shape[0] > 1): + similarity = self.cond_stage_model.cos_similarity( + torch.FloatTensor(waveform).squeeze(1), text + ) + + best_index = [] + for i in range(z.shape[0]): + candidates = similarity[i :: z.shape[0]] + max_index = torch.argmax(candidates).item() + best_index.append(i + max_index * z.shape[0]) + + waveform = waveform[best_index] + # print("Similarity between generated audio and text", similarity) + # print("Choose the following indexes:", best_index) + + return waveform diff --git a/audioldm/pipeline.py b/audioldm/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..d8ecb9ff68eeb871a0bcca3b11d0a7956091acab --- /dev/null +++ b/audioldm/pipeline.py @@ -0,0 +1,78 @@ + + +import os + +import argparse +import yaml +import torch + +from audioldm import LatentDiffusion, seed_everything +from audioldm.utils import default_audioldm_config + + +import time + +def make_batch_for_text_to_audio(text, batchsize=1): + text = [text] * batchsize + if batchsize < 1: + print("Warning: Batchsize must be at least 1. Batchsize is set to .") + fbank = torch.zeros((batchsize, 1024, 64)) # Not used, here to keep the code format + stft = torch.zeros((batchsize, 1024, 512)) # Not used + waveform = torch.zeros((batchsize, 160000)) # Not used + fname = [""] * batchsize # Not used + batch = ( + fbank, + stft, + None, + fname, + waveform, + text, + ) + return batch + +def build_model(config=None): + if(torch.cuda.is_available()): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + + if(config is not None): + assert type(config) is str + config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) + else: + config = default_audioldm_config() + + # Use text as condition instead of using waveform during training + config["model"]["params"]["device"] = device + config["model"]["params"]["cond_stage_key"] = "text" + + # No normalization here + latent_diffusion = LatentDiffusion(**config["model"]["params"]) + + resume_from_checkpoint = "./ckpt/ldm_trimmed.ckpt" + + checkpoint = torch.load(resume_from_checkpoint, map_location=device) + latent_diffusion.load_state_dict(checkpoint["state_dict"]) + + latent_diffusion.eval() + latent_diffusion = latent_diffusion.to(device) + + latent_diffusion.cond_stage_model.embed_mode = "text" + return latent_diffusion + +def duration_to_latent_t_size(duration): + return int(duration * 25.6) + +def text_to_audio(latent_diffusion, text, seed=42, duration=10, batchsize=1, guidance_scale=2.5, n_candidate_gen_per_text=3, config=None): + seed_everything(int(seed)) + batch = make_batch_for_text_to_audio(text, batchsize=batchsize) + + latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) + with torch.no_grad(): + waveform = latent_diffusion.generate_sample( + [batch], + unconditional_guidance_scale=guidance_scale, + n_candidate_gen_per_text=n_candidate_gen_per_text, + duration=duration + ) + return waveform diff --git a/audioldm/utils.py b/audioldm/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2a22213b627ebee77ab3d0bda3a59d1c3ade4040 --- /dev/null +++ b/audioldm/utils.py @@ -0,0 +1,73 @@ +import importlib + +from inspect import isfunction + +import os +import soundfile as sf + +def seed_everything(seed): + import random, os + import numpy as np + import torch + + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + +def save_wave(waveform, savepath, name="outwav"): + if type(name) is not list: + name = [name] * waveform.shape[0] + + for i in range(waveform.shape[0]): + path = os.path.join( + savepath, + "%s_%s.wav" + % ( + os.path.basename(name[i]) + if (not ".wav" in name[i]) + else os.path.basename(name[i]).split(".")[0], + i, + ), + ) + sf.write(path, waveform[i, 0], samplerate=16000) + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + +def default_audioldm_config(): + return {'wave_file_save_path': './output', 'id': {'version': 'v1', 'name': 'default', 'root': '/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/AudioLDM-python/config/default/latent_diffusion.yaml'}, 'model': {'device': 'cuda', 'reload_from_ckpt': '/mnt/fast/nobackup/scratch4weeks/hl01486/exps/audio_generation/stablediffusion/LDM/audioverse/2023_01_14_full_F4_B_spatial_v2_v1/checkpoints/last.ckpt', 'target': 'audioldm.pipline.LatentDiffusion', 'params': {'base_learning_rate': 5e-06, 'linear_start': 0.0015, 'linear_end': 0.0195, 'num_timesteps_cond': 1, 'log_every_t': 200, 'timesteps': 1000, 'first_stage_key': 'fbank', 'cond_stage_key': 'waveform', 'latent_t_size': 256, 'latent_f_size': 16, 'channels': 8, 'cond_stage_trainable': True, 'conditioning_key': 'film', 'monitor': 'val/loss_simple_ema', 'scale_by_std': True, 'unet_config': {'target': 'audioldm.latent_diffusion.openaimodel.UNetModel', 'params': {'image_size': 64, 'extra_film_condition_dim': 512, 'extra_film_use_concat': True, 'in_channels': 8, 'out_channels': 8, 'model_channels': 128, 'attention_resolutions': [8, 4, 2], 'num_res_blocks': 2, 'channel_mult': [1, 2, 3, 5], 'num_head_channels': 32, 'use_spatial_transformer': True}}, 'first_stage_config': {'base_learning_rate': 4.5e-05, 'target': 'audioldm.variational_autoencoder.autoencoder.AutoencoderKL', 'params': {'monitor': 'val/rec_loss', 'image_key': 'fbank', 'subband': 1, 'embed_dim': 8, 'time_shuffle': 1, 'ddconfig': {'double_z': True, 'z_channels': 8, 'resolution': 256, 'downsample_time': False, 'in_channels': 1, 'out_ch': 1, 'ch': 128, 'ch_mult': [1, 2, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}}}, 'cond_stage_config': {'target': 'audioldm.clap.encoders.CLAPAudioEmbeddingClassifierFreev2', 'params': {'key': 'waveform', 'sampling_rate': 16000, 'embed_mode': 'audio', 'unconditional_prob': 0.1}}}}} \ No newline at end of file diff --git a/audioldm/variational_autoencoder/__init__.py b/audioldm/variational_autoencoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm/variational_autoencoder/autoencoder.py b/audioldm/variational_autoencoder/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..cfbdceba0e1171b052f797885530bacd0f3c73d5 --- /dev/null +++ b/audioldm/variational_autoencoder/autoencoder.py @@ -0,0 +1,102 @@ +import torch +from audioldm.latent_diffusion.ema import * +from audioldm.variational_autoencoder.modules import Encoder, Decoder +from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution + +from audioldm.hifigan.utilities import get_vocoder, vocoder_infer + +class AutoencoderKL(nn.Module): + def __init__( + self, + ddconfig=None, + lossconfig=None, + image_key="fbank", + embed_dim=None, + time_shuffle=1, + subband=1, + ckpt_path=None, + reload_from_ckpt=None, + ignore_keys=[], + colorize_nlabels=None, + monitor=None, + base_learning_rate=1e-5, + ): + super().__init__() + + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + + self.subband = int(subband) + + if self.subband > 1: + print("Use subband decomposition %s" % self.subband) + + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + + self.vocoder = get_vocoder(None, "cpu") + self.embed_dim = embed_dim + + if monitor is not None: + self.monitor = monitor + + self.time_shuffle = time_shuffle + self.reload_from_ckpt = reload_from_ckpt + self.reloaded = False + self.mean, self.std = None, None + + def encode(self, x): + # x = self.time_shuffle_operation(x) + x = self.freq_split_subband(x) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + dec = self.freq_merge_subband(dec) + return dec + + def decode_to_waveform(self, dec): + dec = dec.squeeze(1).permute(0, 2, 1) + wav_reconstruction = vocoder_infer(dec, self.vocoder) + return wav_reconstruction + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + + if self.flag_first_run: + print("Latent size: ", z.size()) + self.flag_first_run = False + + dec = self.decode(z) + + return dec, posterior + + def freq_split_subband(self, fbank): + if self.subband == 1 or self.image_key != "stft": + return fbank + + bs, ch, tstep, fbins = fbank.size() + + assert fbank.size(-1) % self.subband == 0 + assert ch == 1 + + return ( + fbank.squeeze(1) + .reshape(bs, tstep, self.subband, fbins // self.subband) + .permute(0, 2, 1, 3) + ) + + def freq_merge_subband(self, subband_fbank): + if self.subband == 1 or self.image_key != "stft": + return subband_fbank + assert subband_fbank.size(1) == self.subband # Channel dimension + bs, sub_ch, tstep, fbins = subband_fbank.size() + return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1) diff --git a/audioldm/variational_autoencoder/distributions.py b/audioldm/variational_autoencoder/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..58eb535e7769f402169ddff77ee45c96ba3650d9 --- /dev/null +++ b/audioldm/variational_autoencoder/distributions.py @@ -0,0 +1,102 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.mean( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/audioldm/variational_autoencoder/modules.py b/audioldm/variational_autoencoder/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..6b2c3dca2d168fb5fbaff5acc4b5a06280a496a7 --- /dev/null +++ b/audioldm/variational_autoencoder/modules.py @@ -0,0 +1,1064 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from audioldm.utils import instantiate_from_config +from audioldm.latent_diffusion.attention import LinearAttention + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class UpsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=1, padding=2 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class DownsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2)) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_ + ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + downsample_time_stride4_levels=[], + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + if i_level in self.downsample_time_stride4_levels: + down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv) + else: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + downsample_time_stride4_levels=[], + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + # print("Working with z of shape {} = {} dimensions.".format( + # self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level - 1 in self.downsample_time_stride4_levels: + up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv) + else: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0, + ): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, stride=1, padding=1 + ) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, + size=( + int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)), + ), + ) + x = self.attn(x).contiguous() + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler( + factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels, + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1 + ) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate( + x, mode=self.mode, align_corners=False, scale_factor=scale_factor + ) + return x + + +class FirstStagePostProcessor(nn.Module): + def __init__( + self, + ch_mult: list, + in_channels, + pretrained_model: nn.Module = None, + reshape=False, + n_channels=None, + dropout=0.0, + pretrained_config=None, + ): + super().__init__() + if pretrained_config is None: + assert ( + pretrained_model is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert ( + pretrained_config is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) + self.proj = nn.Conv2d( + in_channels, n_channels, kernel_size=3, stride=1, padding=1 + ) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append( + ResnetBlock( + in_channels=ch_in, out_channels=m * n_channels, dropout=dropout + ) + ) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def encode_with_pretrained(self, x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self, x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model, self.downsampler): + z = submodel(z, temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z, "b c h w -> b (h w) c") + return z diff --git a/ckpt/ldm_trimmed.ckpt b/ckpt/ldm_trimmed.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..5925f8aca1f0bf7be45276a956f056ac6c2aa84e --- /dev/null +++ b/ckpt/ldm_trimmed.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f8d1923410622be823279b61967d27a2df3fd03ddd764afb298e7c20ef8877d +size 2558947469 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..93174b5dd3fad9851e3963c0aca028c811675d50 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +git+https://github.com/huggingface/diffusers.git +--extra-index-url https://download.pytorch.org/whl/cu113 +torch +scipy +torchaudio>=0.13.0 +torchvision>=0.14.0 +tqdm +pyyaml +einops +numpy<=1.23.5 +soundfile +librosa +pandas +# transformers +torchlibrosa +transformers +ftfy \ No newline at end of file diff --git a/share_btn.py b/share_btn.py new file mode 100644 index 0000000000000000000000000000000000000000..b8c2ed17439625f85fd0e910766c727b29131e3d --- /dev/null +++ b/share_btn.py @@ -0,0 +1,60 @@ +community_icon_html = """""" + +loading_icon_html = """""" + +share_js = """async () => { + async function uploadFile(file){ + const UPLOAD_URL = 'https://huggingface.co/uploads'; + const response = await fetch(UPLOAD_URL, { + method: 'POST', + headers: { + 'Content-Type': file.type, + 'X-Requested-With': 'XMLHttpRequest', + }, + body: file, /// <- File inherits from Blob + }); + const url = await response.text(); + return url; + } + const gradioEl = document.querySelector('body > gradio-app'); + const imgEls = gradioEl.querySelectorAll('#gallery img'); + const promptTxt = gradioEl.querySelector('#prompt-text-input input').value; + const shareBtnEl = gradioEl.querySelector('#share-btn'); + const shareIconEl = gradioEl.querySelector('#share-btn-share-icon'); + const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon'); + if(!imgEls.length){ + return; + }; + shareBtnEl.style.pointerEvents = 'none'; + shareIconEl.style.display = 'none'; + loadingIconEl.style.removeProperty('display'); + const files = await Promise.all( + [...imgEls].map(async (imgEl) => { + const res = await fetch(imgEl.src); + const blob = await res.blob(); + const imgId = Date.now() % 200; + const fileName = `diffuse-the-rest-${{imgId}}.jpg`; + return new File([blob], fileName, { type: 'image/jpeg' }); + }) + ); + const urls = await Promise.all(files.map((f) => uploadFile(f))); + const htmlImgs = urls.map(url => ``); + const descriptionMd = `
+${htmlImgs.join(`\n`)} +
`; + const params = new URLSearchParams({ + title: promptTxt, + description: descriptionMd, + }); + const paramsStr = params.toString(); + window.open(`https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation/discussions/new?${paramsStr}`, '_blank'); + shareBtnEl.style.removeProperty('pointer-events'); + shareIconEl.style.removeProperty('display'); + loadingIconEl.style.display = 'none'; +}""" \ No newline at end of file