diff --git a/README.md b/README.md index 49b3f9d9b337749338e72cf5d6782b9562943a8b..a267d537ae6a23859d5640388bd6ccbf04a480f3 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ sdk_version: 3.27.0 app_file: app.py pinned: false license: bigscience-openrail-m +duplicated_from: haoheliu/audioldm-text-to-audio-generation --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py index 629eab165af167483b3def8286581e7270b8e01c..113d2d66f0ec1a6f0cdd393ea6763f7f71a483e7 100644 --- a/app.py +++ b/app.py @@ -1,197 +1,139 @@ import gradio as gr -import numpy as np -from audioldm import text_to_audio, build_model +import torch +from diffusers import AudioLDMPipeline from share_btn import community_icon_html, loading_icon_html, share_js -model_id="haoheliu/AudioLDM-S-Full" +from transformers import AutoProcessor, ClapModel -audioldm = None -current_model_name = None -# def predict(input, history=[]): -# # tokenize the new input sentence -# new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt') +# make Space compatible with CPU duplicates +if torch.cuda.is_available(): + device = "cuda" + torch_dtype = torch.float16 +else: + device = "cpu" + torch_dtype = torch.float32 -# # append the new user input tokens to the chat history -# bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1) +# load the diffusers pipeline +repo_id = "cvssp/audioldm-m-full" +pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device) +pipe.unet = torch.compile(pipe.unet) -# # generate a response -# history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist() +# CLAP model (only required for automatic scoring) +clap_model = ClapModel.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full").to(device) +processor = AutoProcessor.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full") -# # convert the tokens to text, and then split the responses into lines -# response = tokenizer.decode(history[0]).split("<|endoftext|>") -# response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list -# return response, history - -def text2audio(text, duration, guidance_scale, random_seed, n_candidates, model_name="audioldm-m-text-ft"): - global audioldm, current_model_name - - if audioldm is None or model_name != current_model_name: - audioldm=build_model(model_name=model_name) - current_model_name = model_name - - # print(text, length, guidance_scale) - waveform = text_to_audio( - latent_diffusion=audioldm, - text=text, - seed=random_seed, - duration=duration, +generator = torch.Generator(device) + + +def text2audio(text, negative_prompt, duration, guidance_scale, random_seed, n_candidates): + if text is None: + raise gr.Error("Please provide a text input.") + + waveforms = pipe( + text, + audio_length_in_s=duration, guidance_scale=guidance_scale, - n_candidate_gen_per_text=int(n_candidates), - ) # [bs, 1, samples] - waveform = [ - gr.make_waveform((16000, wave[0]), bg_image="bg.png") for wave in waveform - ] - # waveform = [(16000, np.random.randn(16000)), (16000, np.random.randn(16000))] - if(len(waveform) == 1): - waveform = waveform[0] - return waveform + negative_prompt=negative_prompt, + num_waveforms_per_prompt=n_candidates if n_candidates else 1, + generator=generator.manual_seed(int(random_seed)), + )["audios"] + + if waveforms.shape[0] > 1: + waveform = score_waveforms(text, waveforms) + 
else: + waveform = waveforms[0] + + return gr.make_waveform((16000, waveform), bg_image="bg.png") + -# iface = gr.Interface(fn=text2audio, inputs=[ -# gr.Textbox(value="A man is speaking in a huge room", max_lines=1), -# gr.Slider(2.5, 10, value=5, step=2.5), -# gr.Slider(0, 5, value=2.5, step=0.5), -# gr.Number(value=42) -# ], outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")], -# allow_flagging="never" -# ) -# iface.launch(share=True) +def score_waveforms(text, waveforms): + inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True) + inputs = {key: inputs[key].to(device) for key in inputs} + with torch.no_grad(): + logits_per_text = clap_model(**inputs).logits_per_text # this is the audio-text similarity score + probs = logits_per_text.softmax(dim=-1) # we can take the softmax to get the label probabilities + most_probable = torch.argmax(probs) # and now select the most likely audio waveform + waveform = waveforms[most_probable] + return waveform css = """ a { - color: inherit; - text-decoration: underline; - } - .gradio-container { + color: inherit; text-decoration: underline; + } .gradio-container { font-family: 'IBM Plex Sans', sans-serif; - } - .gr-button { - color: white; - border-color: #000000; - background: #000000; - } - input[type='range'] { + } .gr-button { + color: white; border-color: #000000; background: #000000; + } input[type='range'] { accent-color: #000000; - } - .dark input[type='range'] { + } .dark input[type='range'] { accent-color: #dfdfdf; - } - .container { - max-width: 730px; - margin: auto; - padding-top: 1.5rem; - } - #gallery { - min-height: 22rem; - margin-bottom: 15px; - margin-left: auto; - margin-right: auto; - border-bottom-right-radius: .5rem !important; - border-bottom-left-radius: .5rem !important; - } - #gallery>div>.h-full { + } .container { + max-width: 730px; margin: auto; padding-top: 1.5rem; + } #gallery { + min-height: 22rem; margin-bottom: 15px; margin-left: auto; margin-right: auto; border-bottom-right-radius: + .5rem !important; border-bottom-left-radius: .5rem !important; + } #gallery>div>.h-full { min-height: 20rem; - } - .details:hover { + } .details:hover { text-decoration: underline; - } - .gr-button { + } .gr-button { white-space: nowrap; - } - .gr-button:focus { - border-color: rgb(147 197 253 / var(--tw-border-opacity)); - outline: none; - box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); - --tw-border-opacity: 1; - --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); - --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color); - --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity)); - --tw-ring-opacity: .5; - } - #advanced-btn { - font-size: .7rem !important; - line-height: 19px; - margin-top: 12px; - margin-bottom: 12px; - padding: 2px 8px; + } .gr-button:focus { + border-color: rgb(147 197 253 / var(--tw-border-opacity)); outline: none; box-shadow: + var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); --tw-border-opacity: 1; + --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) + var(--tw-ring-offset-color); --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color); --tw-ring-color: rgb(191 219 254 / + var(--tw-ring-opacity)); --tw-ring-opacity: .5; + } #advanced-btn { + font-size: .7rem !important; line-height: 
19px; margin-top: 12px; margin-bottom: 12px; padding: 2px 8px; border-radius: 14px !important; - } - #advanced-options { + } #advanced-options { margin-bottom: 20px; - } - .footer { - margin-bottom: 45px; - margin-top: 35px; - text-align: center; - border-bottom: 1px solid #e5e5e5; - } - .footer>p { - font-size: .8rem; - display: inline-block; - padding: 0 10px; - transform: translateY(10px); - background: white; - } - .dark .footer { + } .footer { + margin-bottom: 45px; margin-top: 35px; text-align: center; border-bottom: 1px solid #e5e5e5; + } .footer>p { + font-size: .8rem; display: inline-block; padding: 0 10px; transform: translateY(10px); background: white; + } .dark .footer { border-color: #303030; - } - .dark .footer>p { + } .dark .footer>p { background: #0b0f19; - } - .acknowledgments h4{ - margin: 1.25em 0 .25em 0; - font-weight: bold; - font-size: 115%; - } - #container-advanced-btns{ - display: flex; - flex-wrap: wrap; - justify-content: space-between; - align-items: center; - } - .animate-spin { + } .acknowledgments h4{ + margin: 1.25em 0 .25em 0; font-weight: bold; font-size: 115%; + } #container-advanced-btns{ + display: flex; flex-wrap: wrap; justify-content: space-between; align-items: center; + } .animate-spin { animation: spin 1s linear infinite; - } - @keyframes spin { + } @keyframes spin { from { transform: rotate(0deg); - } - to { + } to { transform: rotate(360deg); } - } - #share-btn-container { - display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem; - margin-top: 10px; - margin-left: auto; - } - #share-btn { - all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0; - } - #share-btn * { + } #share-btn-container { + display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: + #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem; + margin-top: 10px; margin-left: auto; + } #share-btn { + all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; + margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem + !important;right:0; + } #share-btn * { all: unset; - } - #share-btn-container div:nth-child(-n+2){ - width: auto !important; - min-height: 0px !important; - } - #share-btn-container .wrap { + } #share-btn-container div:nth-child(-n+2){ + width: auto !important; min-height: 0px !important; + } #share-btn-container .wrap { display: none !important; - } - .gr-form{ + } .gr-form{ flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0; - } - #prompt-container{ + } #prompt-container{ gap: 0; - } - #generated_id{ + } #generated_id{ min-height: 700px - } - #setting_id{ - margin-bottom: 12px; - text-align: center; - font-weight: 900; + } #setting_id{ + margin-bottom: 12px; text-align: center; font-weight: 900; } """ iface = gr.Blocks(css=css) @@ -202,56 +144,72 @@ with iface:
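# --- Editor's note (illustrative sketch, not part of the patch) ------------------
# A minimal standalone sketch of the inference path this diff switches to: load
# AudioLDMPipeline with a dtype matched to the available device, then generate
# several candidate waveforms for one prompt with an optional negative prompt and
# a fixed seed. Model ID and call signature mirror the added app.py code above;
# the prompt text and candidate count are assumptions for illustration only.
import torch
from diffusers import AudioLDMPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm-m-full", torch_dtype=dtype).to(device)
# the Space additionally wraps pipe.unet in torch.compile for faster repeated calls

audios = pipe(
    "A hammer is hitting a wooden surface",           # text prompt
    negative_prompt="low quality, average quality",   # traits to steer away from
    audio_length_in_s=5.0,                            # duration in seconds
    guidance_scale=2.5,                               # classifier-free guidance strength
    num_waveforms_per_prompt=3,                       # candidates for CLAP re-ranking
    generator=torch.Generator(device).manual_seed(45),
).audios                                              # array [n_candidates, samples] at 16 kHz
# ---------------------------------------------------------------------------------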

AudioLDM: Text-to-Audio Generation with Latent Diffusion Models

-
-

- [Paper] [Project page] +

+ [Paper] [Project page] [🧨 Diffusers]

""" ) - gr.HTML(""" -

- AudioLDM: Text-to-Audio Generation with Latent Diffusion Models -

-

For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. -
- - Duplicate Space -

- """) + gr.HTML( + """ +

This is the demo for AudioLDM, powered by 🧨 Diffusers. The demo uses the audioldm-m-full checkpoint. For faster inference without waiting in the queue, you may duplicate the Space and upgrade to a GPU in the settings.
Duplicate Space

+ """ + ) + with gr.Group(): with gr.Box(): - ############# Input - textbox = gr.Textbox(value="A hammer is hitting a wooden surface", max_lines=1, label="Input your text here. Your text is important for the audio quality. Please ensure it is descriptive by using more adjectives.", elem_id="prompt-in") + textbox = gr.Textbox( + value="A hammer is hitting a wooden surface", + max_lines=1, + label="Input text", + info="Your text is important for the audio quality. Please ensure it is descriptive by using more adjectives.", + elem_id="prompt-in", + ) + negative_textbox = gr.Textbox( + value="low quality, average quality", + max_lines=1, + label="Negative prompt", + info="Enter a negative prompt not to guide the audio generation. Selecting appropriate negative prompts can improve the audio quality significantly.", + elem_id="prompt-in", + ) with gr.Accordion("Click to modify detailed configurations", open=False): - seed = gr.Number(value=45, label="Change this value (any integer number) will lead to a different generation result.") - duration = gr.Slider(2.5, 10, value=5, step=2.5, label="Duration (seconds)") - guidance_scale = gr.Slider(0, 4, value=2.5, step=0.5, label="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)") - n_candidates = gr.Slider(1, 3, value=3, step=1, label="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation") - # model_name = gr.Dropdown( - # ["audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full","audioldm-s-full-v2", "audioldm-s-full", "audioldm-l-full"], value="audioldm-m-full", label="Choose the model to use. audioldm-m-text-ft and audioldm-s-text-ft are recommanded. -s- means small, -m- means medium and -l- means large", - # ) - ############# Output - # outputs=gr.Audio(label="Output", type="numpy") - outputs=gr.Video(label="Output", elem_id="output-video") - - # with gr.Group(elem_id="container-advanced-btns"): - # # advanced_button = gr.Button("Advanced options", elem_id="advanced-btn") - # with gr.Group(elem_id="share-btn-container"): - # community_icon = gr.HTML(community_icon_html, visible=False) - # loading_icon = gr.HTML(loading_icon_html, visible=False) - # share_button = gr.Button("Share to community", elem_id="share-btn", visible=False) - # outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")] + seed = gr.Number( + value=45, + label="Seed", + info="Change this value (any integer number) will lead to a different generation result.", + ) + duration = gr.Slider(2.5, 10, value=5, step=2.5, label="Duration (seconds)") + guidance_scale = gr.Slider( + 0, + 4, + value=2.5, + step=0.5, + label="Guidance scale", + info="Large => better quality and relevancy to text; Small => better diversity", + ) + n_candidates = gr.Slider( + 1, + 3, + value=3, + step=1, + label="Number waveforms to generate", + info="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). 
A Larger value usually lead to better quality with heavier computation", + ) + + outputs = gr.Video(label="Output", elem_id="output-video") btn = gr.Button("Submit").style(full_width=True) with gr.Group(elem_id="share-btn-container", visible=False): @@ -259,51 +217,61 @@ with iface: loading_icon = gr.HTML(loading_icon_html) share_button = gr.Button("Share to community", elem_id="share-btn") - # btn.click(text2audio, inputs=[ - # textbox, duration, guidance_scale, seed, n_candidates, model_name], outputs=[outputs]) - btn.click(text2audio, inputs=[ - textbox, duration, guidance_scale, seed, n_candidates], outputs=[outputs]) - + btn.click( + text2audio, + inputs=[textbox, negative_textbox, duration, guidance_scale, seed, n_candidates], + outputs=[outputs], + ) + share_button.click(None, [], [], _js=share_js) - gr.HTML(''' + gr.HTML( + """
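# --- Editor's note (illustrative sketch, not part of the patch) ------------------
# How the "Number waveforms to generate" slider interacts with score_waveforms()
# above, sketched in isolation: CLAP embeds the prompt and each candidate waveform,
# logits_per_text yields one text-audio similarity score per candidate, and the
# highest-scoring waveform is kept. The checkpoint name is the one used in the
# added app.py code; `audios` is assumed to be the array returned by the pipeline.
import torch
from transformers import AutoProcessor, ClapModel

clap = ClapModel.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full")
clap_processor = AutoProcessor.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full")

def pick_best(text, audios):
    # CLAP expects a batch of 1-D waveforms; padding aligns their lengths
    inputs = clap_processor(text=text, audios=list(audios), return_tensors="pt", padding=True)
    with torch.no_grad():
        logits_per_text = clap(**inputs).logits_per_text   # shape [1, n_candidates]
    return audios[int(logits_per_text.softmax(dim=-1).argmax())]
# ---------------------------------------------------------------------------------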

- ''') - gr.Examples([ - ["A hammer is hitting a wooden surface", 5, 2.5, 45, 3, "audioldm-m-full"], - ["Peaceful and calming ambient music with singing bowl and other instruments.", 5, 2.5, 45, 3, "audioldm-m-full"], - ["A man is speaking in a small room.", 5, 2.5, 45, 3, "audioldm-m-full"], - ["A female is speaking followed by footstep sound", 5, 2.5, 45, 3, "audioldm-m-full"], - ["Wooden table tapping sound followed by water pouring sound.", 5, 2.5, 45, 3, "audioldm-m-full"], - ], + """ + ) + gr.Examples( + [ + ["A hammer is hitting a wooden surface", "low quality, average quality", 5, 2.5, 45, 3], + ["Peaceful and calming ambient music with singing bowl and other instruments.", "low quality, average quality", 5, 2.5, 45, 3], + ["A man is speaking in a small room.", "low quality, average quality", 5, 2.5, 45, 3], + ["A female is speaking followed by footstep sound", "low quality, average quality", 5, 2.5, 45, 3], + ["Wooden table tapping sound followed by water pouring sound.", "low quality, average quality", 5, 2.5, 45, 3], + ], fn=text2audio, - # inputs=[textbox, duration, guidance_scale, seed, n_candidates, model_name], - inputs=[textbox, duration, guidance_scale, seed, n_candidates], + inputs=[textbox, negative_textbox, duration, guidance_scale, seed, n_candidates], outputs=[outputs], cache_examples=True, ) - gr.HTML(''' -
-

Essential Tricks for Enhancing the Quality of Your Generated Audio

-

1. Try to use more adjectives to describe your sound. For example: "A man is speaking clearly and slowly in a large room" is better than "A man is speaking". This can make sure AudioLDM understands what you want.

-

2. Try to use different random seeds, which can affect the generation quality significantly sometimes.

-

3. It's better to use general terms like 'man' or 'woman' instead of specific names for individuals or abstract objects that humans may not be familiar with, such as 'mummy'.

-
- ''') + gr.HTML( + """ +

Essential Tricks for Enhancing the Quality of Your Generated Audio

1. Try to use more adjectives to describe your sound. For example: "A man is speaking clearly and slowly in a large room" is better than "A man is speaking". This helps ensure AudioLDM understands what you want.

2. Try different random seeds, which can sometimes affect the generation quality significantly.

3. It's better to use general terms like 'man' or 'woman' instead of specific names for individuals or abstract objects that humans may not be familiar with, such as 'mummy'.

4. Using a negative prompt to steer the diffusion process away from undesirable traits can improve the audio quality significantly. Try negative prompts like 'low quality'.

+ """ + ) with gr.Accordion("Additional information", open=False): gr.HTML( """
-

We build the model with data from AudioSet, Freesound and BBC Sound Effect library. We share this demo based on the UK copyright exception of data for academic research.

+

We build the model with data from AudioSet, Freesound and the BBC Sound Effects library. We share this demo based on the UK copyright exception of data for academic research.

""" ) #

This demo is strictly for research purposes only. For commercial use, please contact us.

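# --- Editor's note (illustrative sketch, not part of the patch) ------------------
# A pared-down view of the Blocks wiring assembled above, with the CSS, HTML and
# share button stripped away: the six inputs feed text2audio(), whose
# gr.make_waveform() result is rendered as a video. Widget defaults mirror the
# diff; everything else here is illustrative only.
import gradio as gr

with gr.Blocks() as demo:
    prompt = gr.Textbox(value="A hammer is hitting a wooden surface", label="Input text")
    negative = gr.Textbox(value="low quality, average quality", label="Negative prompt")
    duration = gr.Slider(2.5, 10, value=5, step=2.5, label="Duration (seconds)")
    guidance = gr.Slider(0, 4, value=2.5, step=0.5, label="Guidance scale")
    seed = gr.Number(value=45, label="Seed")
    n_candidates = gr.Slider(1, 3, value=3, step=1, label="Number waveforms to generate")
    video = gr.Video(label="Output")
    gr.Button("Submit").click(
        text2audio,
        inputs=[prompt, negative, duration, guidance, seed, n_candidates],
        outputs=[video],
    )
# demo would then be queued and launched exactly as in the line below.
# ---------------------------------------------------------------------------------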
iface.queue(max_size=10).launch(debug=True) -# iface.launch(debug=True, share=True) diff --git a/audioldm/__init__.py b/audioldm/__init__.py deleted file mode 100644 index 2f93cab80ded8e7239bb96eb6e364c3fd4fb46d9..0000000000000000000000000000000000000000 --- a/audioldm/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .ldm import LatentDiffusion -from .utils import seed_everything -from .pipeline import * \ No newline at end of file diff --git a/audioldm/audio/__init__.py b/audioldm/audio/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/audioldm/audio/audio_processing.py b/audioldm/audio/audio_processing.py deleted file mode 100644 index 77a4057aa82f226f68474f4c2a19eba84510d663..0000000000000000000000000000000000000000 --- a/audioldm/audio/audio_processing.py +++ /dev/null @@ -1,100 +0,0 @@ -import torch -import numpy as np -import librosa.util as librosa_util -from scipy.signal import get_window - - -def window_sumsquare( - window, - n_frames, - hop_length, - win_length, - n_fft, - dtype=np.float32, - norm=None, -): - """ - # from librosa 0.6 - Compute the sum-square envelope of a window function at a given hop length. - - This is used to estimate modulation effects induced by windowing - observations in short-time fourier transforms. - - Parameters - ---------- - window : string, tuple, number, callable, or list-like - Window specification, as in `get_window` - - n_frames : int > 0 - The number of analysis frames - - hop_length : int > 0 - The number of samples to advance between frames - - win_length : [optional] - The length of the window function. By default, this matches `n_fft`. - - n_fft : int > 0 - The length of each analysis frame. - - dtype : np.dtype - The data type of the output - - Returns - ------- - wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` - The sum-squared envelope of the window function - """ - if win_length is None: - win_length = n_fft - - n = n_fft + hop_length * (n_frames - 1) - x = np.zeros(n, dtype=dtype) - - # Compute the squared window at the desired length - win_sq = get_window(window, win_length, fftbins=True) - win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 - win_sq = librosa_util.pad_center(win_sq, n_fft) - - # Fill the envelope - for i in range(n_frames): - sample = i * hop_length - x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] - return x - - -def griffin_lim(magnitudes, stft_fn, n_iters=30): - """ - PARAMS - ------ - magnitudes: spectrogram magnitudes - stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods - """ - - angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) - angles = angles.astype(np.float32) - angles = torch.autograd.Variable(torch.from_numpy(angles)) - signal = stft_fn.inverse(magnitudes, angles).squeeze(1) - - for i in range(n_iters): - _, angles = stft_fn.transform(signal) - signal = stft_fn.inverse(magnitudes, angles).squeeze(1) - return signal - - -def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): - """ - PARAMS - ------ - C: compression factor - """ - return normalize_fun(torch.clamp(x, min=clip_val) * C) - - -def dynamic_range_decompression(x, C=1): - """ - PARAMS - ------ - C: compression factor used to compress - """ - return torch.exp(x) / C diff --git a/audioldm/audio/stft.py b/audioldm/audio/stft.py deleted file mode 100644 index 2aa1ac89277734a6676c20a81bf88e21e8ca7aa9..0000000000000000000000000000000000000000 --- 
a/audioldm/audio/stft.py +++ /dev/null @@ -1,180 +0,0 @@ -import torch -import torch.nn.functional as F -import numpy as np -from scipy.signal import get_window -from librosa.util import pad_center, tiny -from librosa.filters import mel as librosa_mel_fn - -from audioldm.audio.audio_processing import ( - dynamic_range_compression, - dynamic_range_decompression, - window_sumsquare, -) - - -class STFT(torch.nn.Module): - """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" - - def __init__(self, filter_length, hop_length, win_length, window="hann"): - super(STFT, self).__init__() - self.filter_length = filter_length - self.hop_length = hop_length - self.win_length = win_length - self.window = window - self.forward_transform = None - scale = self.filter_length / self.hop_length - fourier_basis = np.fft.fft(np.eye(self.filter_length)) - - cutoff = int((self.filter_length / 2 + 1)) - fourier_basis = np.vstack( - [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] - ) - - forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) - inverse_basis = torch.FloatTensor( - np.linalg.pinv(scale * fourier_basis).T[:, None, :] - ) - - if window is not None: - assert filter_length >= win_length - # get window and zero center pad it to filter_length - fft_window = get_window(window, win_length, fftbins=True) - fft_window = pad_center(fft_window, filter_length) - fft_window = torch.from_numpy(fft_window).float() - - # window the bases - forward_basis *= fft_window - inverse_basis *= fft_window - - self.register_buffer("forward_basis", forward_basis.float()) - self.register_buffer("inverse_basis", inverse_basis.float()) - - def transform(self, input_data): - num_batches = input_data.size(0) - num_samples = input_data.size(1) - - self.num_samples = num_samples - - # similar to librosa, reflect-pad the input - input_data = input_data.view(num_batches, 1, num_samples) - input_data = F.pad( - input_data.unsqueeze(1), - (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), - mode="reflect", - ) - input_data = input_data.squeeze(1) - - forward_transform = F.conv1d( - input_data, - torch.autograd.Variable(self.forward_basis, requires_grad=False), - stride=self.hop_length, - padding=0, - ).cpu() - - cutoff = int((self.filter_length / 2) + 1) - real_part = forward_transform[:, :cutoff, :] - imag_part = forward_transform[:, cutoff:, :] - - magnitude = torch.sqrt(real_part**2 + imag_part**2) - phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) - - return magnitude, phase - - def inverse(self, magnitude, phase): - recombine_magnitude_phase = torch.cat( - [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 - ) - - inverse_transform = F.conv_transpose1d( - recombine_magnitude_phase, - torch.autograd.Variable(self.inverse_basis, requires_grad=False), - stride=self.hop_length, - padding=0, - ) - - if self.window is not None: - window_sum = window_sumsquare( - self.window, - magnitude.size(-1), - hop_length=self.hop_length, - win_length=self.win_length, - n_fft=self.filter_length, - dtype=np.float32, - ) - # remove modulation effects - approx_nonzero_indices = torch.from_numpy( - np.where(window_sum > tiny(window_sum))[0] - ) - window_sum = torch.autograd.Variable( - torch.from_numpy(window_sum), requires_grad=False - ) - window_sum = window_sum - inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ - approx_nonzero_indices - ] - - # scale by hop ratio - inverse_transform *= float(self.filter_length) / 
self.hop_length - - inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] - inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] - - return inverse_transform - - def forward(self, input_data): - self.magnitude, self.phase = self.transform(input_data) - reconstruction = self.inverse(self.magnitude, self.phase) - return reconstruction - - -class TacotronSTFT(torch.nn.Module): - def __init__( - self, - filter_length, - hop_length, - win_length, - n_mel_channels, - sampling_rate, - mel_fmin, - mel_fmax, - ): - super(TacotronSTFT, self).__init__() - self.n_mel_channels = n_mel_channels - self.sampling_rate = sampling_rate - self.stft_fn = STFT(filter_length, hop_length, win_length) - mel_basis = librosa_mel_fn( - sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax - ) - mel_basis = torch.from_numpy(mel_basis).float() - self.register_buffer("mel_basis", mel_basis) - - def spectral_normalize(self, magnitudes, normalize_fun): - output = dynamic_range_compression(magnitudes, normalize_fun) - return output - - def spectral_de_normalize(self, magnitudes): - output = dynamic_range_decompression(magnitudes) - return output - - def mel_spectrogram(self, y, normalize_fun=torch.log): - """Computes mel-spectrograms from a batch of waves - PARAMS - ------ - y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] - - RETURNS - ------- - mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) - """ - assert torch.min(y.data) >= -1, torch.min(y.data) - assert torch.max(y.data) <= 1, torch.max(y.data) - - magnitudes, phases = self.stft_fn.transform(y) - magnitudes = magnitudes.data - mel_output = torch.matmul(self.mel_basis, magnitudes) - mel_output = self.spectral_normalize(mel_output, normalize_fun) - energy = torch.norm(magnitudes, dim=1) - - log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun) - - return mel_output, log_magnitudes, energy diff --git a/audioldm/audio/tools.py b/audioldm/audio/tools.py deleted file mode 100644 index 7aca95cc1f5c120568a210907e9506589899a1c6..0000000000000000000000000000000000000000 --- a/audioldm/audio/tools.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch -import numpy as np - - -def get_mel_from_wav(audio, _stft): - audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) - audio = torch.autograd.Variable(audio, requires_grad=False) - melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio) - melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) - log_magnitudes_stft = ( - torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32) - ) - energy = torch.squeeze(energy, 0).numpy().astype(np.float32) - return melspec, log_magnitudes_stft, energy - - -# def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60): -# mel = torch.stack([mel]) -# mel_decompress = _stft.spectral_de_normalize(mel) -# mel_decompress = mel_decompress.transpose(1, 2).data.cpu() -# spec_from_mel_scaling = 1000 -# spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) -# spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) -# spec_from_mel = spec_from_mel * spec_from_mel_scaling - -# audio = griffin_lim( -# torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters -# ) - -# audio = audio.squeeze() -# audio = audio.cpu().numpy() -# audio_path = out_filename -# write(audio_path, _stft.sampling_rate, audio) diff --git a/audioldm/clap/__init__.py b/audioldm/clap/__init__.py deleted file mode 100644 index 
e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/audioldm/clap/encoders.py b/audioldm/clap/encoders.py deleted file mode 100644 index 5effd8efd3b933888c586199b5eaa89e632cab03..0000000000000000000000000000000000000000 --- a/audioldm/clap/encoders.py +++ /dev/null @@ -1,170 +0,0 @@ -import torch -import torch.nn as nn -from audioldm.clap.open_clip import create_model -from audioldm.clap.training.data import get_audio_features -import torchaudio -from transformers import RobertaTokenizer -import torch.nn.functional as F - - -class CLAPAudioEmbeddingClassifierFreev2(nn.Module): - def __init__( - self, - pretrained_path="", - key="class", - sampling_rate=16000, - embed_mode="audio", - amodel = "HTSAT-tiny", - unconditional_prob=0.1, - random_mute=False, - max_random_mute_portion=0.5, - training_mode=True, - ): - super().__init__() - - self.key = key - self.device = "cpu" - self.precision = "fp32" - self.amodel = amodel - self.tmodel = "roberta" # the best text encoder in our training - self.enable_fusion = False # False if you do not want to use the fusion model - self.fusion_type = "aff_2d" - self.pretrained = pretrained_path - self.embed_mode = embed_mode - self.embed_mode_orig = embed_mode - self.sampling_rate = sampling_rate - self.unconditional_prob = unconditional_prob - self.random_mute = random_mute - self.tokenize = RobertaTokenizer.from_pretrained("roberta-base") - self.max_random_mute_portion = max_random_mute_portion - self.training_mode = training_mode - self.model, self.model_cfg = create_model( - self.amodel, - self.tmodel, - self.pretrained, - precision=self.precision, - device=self.device, - enable_fusion=self.enable_fusion, - fusion_type=self.fusion_type, - ) - for p in self.model.parameters(): - p.requires_grad = False - - self.model.eval() - - def get_unconditional_condition(self, batchsize): - self.unconditional_token = self.model.get_text_embedding( - self.tokenizer(["", ""]) - )[0:1] - return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0) - - def batch_to_list(self, batch): - ret = [] - for i in range(batch.size(0)): - ret.append(batch[i]) - return ret - - def make_decision(self, probability): - if float(torch.rand(1)) < probability: - return True - else: - return False - - def random_uniform(self, start, end): - val = torch.rand(1).item() - return start + (end - start) * val - - def _random_mute(self, waveform): - # waveform: [bs, t-steps] - t_steps = waveform.size(-1) - for i in range(waveform.size(0)): - mute_size = int( - self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion)) - ) - mute_start = int(self.random_uniform(0, t_steps - mute_size)) - waveform[i, mute_start : mute_start + mute_size] = 0 - return waveform - - def cos_similarity(self, waveform, text): - # waveform: [bs, t_steps] - with torch.no_grad(): - self.embed_mode = "audio" - audio_emb = self(waveform.cuda()) - self.embed_mode = "text" - text_emb = self(text) - similarity = F.cosine_similarity(audio_emb, text_emb, dim=2) - return similarity.squeeze() - - def forward(self, batch, key=None): - # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0 - # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0 - if self.model.training == True and not self.training_mode: - print( - "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters." 
- ) - self.model, self.model_cfg = create_model( - self.amodel, - self.tmodel, - self.pretrained, - precision=self.precision, - device="cuda", - enable_fusion=self.enable_fusion, - fusion_type=self.fusion_type, - ) - for p in self.model.parameters(): - p.requires_grad = False - self.model.eval() - - # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode - if self.embed_mode == "audio": - with torch.no_grad(): - audio_dict_list = [] - assert ( - self.sampling_rate == 16000 - ), "We only support 16000 sampling rate" - if self.random_mute: - batch = self._random_mute(batch) - # batch: [bs, 1, t-samples] - batch = torchaudio.functional.resample( - batch, orig_freq=self.sampling_rate, new_freq=48000 - ) - for waveform in self.batch_to_list(batch): - audio_dict = {} - audio_dict = get_audio_features( - audio_dict, - waveform, - 480000, - data_truncating="fusion", - data_filling="repeatpad", - audio_cfg=self.model_cfg["audio_cfg"], - ) - audio_dict_list.append(audio_dict) - # [bs, 512] - embed = self.model.get_audio_embedding(audio_dict_list) - elif self.embed_mode == "text": - with torch.no_grad(): - # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode - text_data = self.tokenizer(batch) - embed = self.model.get_text_embedding(text_data) - - embed = embed.unsqueeze(1) - self.unconditional_token = self.model.get_text_embedding( - self.tokenizer(["", ""]) - )[0:1] - - for i in range(embed.size(0)): - if self.make_decision(self.unconditional_prob): - embed[i] = self.unconditional_token - - # [bs, 1, 512] - return embed.detach() - - def tokenizer(self, text): - result = self.tokenize( - text, - padding="max_length", - truncation=True, - max_length=512, - return_tensors="pt", - ) - return {k: v.squeeze(0) for k, v in result.items()} diff --git a/audioldm/clap/open_clip/__init__.py b/audioldm/clap/open_clip/__init__.py deleted file mode 100644 index e9f728f2f273be5d5fdbec6c6cc41d737176a8c0..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -from .factory import ( - list_models, - create_model, - create_model_and_transforms, - add_model_config, -) -from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics -from .model import ( - CLAP, - CLAPTextCfg, - CLAPVisionCfg, - CLAPAudioCfp, - convert_weights_to_fp16, - trace_model, -) -from .openai import load_openai_model, list_openai_models -from .pretrained import ( - list_pretrained, - list_pretrained_tag_models, - list_pretrained_model_tags, - get_pretrained_url, - download_pretrained, -) -from .tokenizer import SimpleTokenizer, tokenize -from .transform import image_transform diff --git a/audioldm/clap/open_clip/bert.py b/audioldm/clap/open_clip/bert.py deleted file mode 100644 index a83d96d2a77ed05198efc05837522bc88d2499cc..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/bert.py +++ /dev/null @@ -1,40 +0,0 @@ -from transformers import BertTokenizer, BertModel - -tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") -model = BertModel.from_pretrained("bert-base-uncased") -text = "Replace me by any text you'd like." - - -def bert_embeddings(text): - # text = "Replace me by any text you'd like." 
- encoded_input = tokenizer(text, return_tensors="pt") - output = model(**encoded_input) - return output - - -from transformers import RobertaTokenizer, RobertaModel - -tokenizer = RobertaTokenizer.from_pretrained("roberta-base") -model = RobertaModel.from_pretrained("roberta-base") -text = "Replace me by any text you'd like." - - -def Roberta_embeddings(text): - # text = "Replace me by any text you'd like." - encoded_input = tokenizer(text, return_tensors="pt") - output = model(**encoded_input) - return output - - -from transformers import BartTokenizer, BartModel - -tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") -model = BartModel.from_pretrained("facebook/bart-base") -text = "Replace me by any text you'd like." - - -def bart_embeddings(text): - # text = "Replace me by any text you'd like." - encoded_input = tokenizer(text, return_tensors="pt") - output = model(**encoded_input) - return output diff --git a/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz b/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz deleted file mode 100644 index 36a15856e00a06a9fbed8cdd34d2393fea4a3113..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a -size 1356917 diff --git a/audioldm/clap/open_clip/factory.py b/audioldm/clap/open_clip/factory.py deleted file mode 100644 index 844f9ca0e12a0ff43ba3e042a3e43530ebe91b8c..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/factory.py +++ /dev/null @@ -1,277 +0,0 @@ -import json -import logging -import os -import pathlib -import re -from copy import deepcopy -from pathlib import Path - -import torch - -from .model import CLAP, convert_weights_to_fp16 -from .openai import load_openai_model -from .pretrained import get_pretrained_url, download_pretrained -from .transform import image_transform - -_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] -_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs - - -def _natural_key(string_): - return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] - - -def _rescan_model_configs(): - global _MODEL_CONFIGS - - config_ext = (".json",) - config_files = [] - for config_path in _MODEL_CONFIG_PATHS: - if config_path.is_file() and config_path.suffix in config_ext: - config_files.append(config_path) - elif config_path.is_dir(): - for ext in config_ext: - config_files.extend(config_path.glob(f"*{ext}")) - - for cf in config_files: - if os.path.basename(cf)[0] == ".": - continue # Ignore hidden files - - with open(cf, "r") as f: - model_cfg = json.load(f) - if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")): - _MODEL_CONFIGS[cf.stem] = model_cfg - - _MODEL_CONFIGS = { - k: v - for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])) - } - - -_rescan_model_configs() # initial populate of model config registry - - -def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True): - checkpoint = torch.load(checkpoint_path, map_location=map_location) - if isinstance(checkpoint, dict) and "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] - else: - state_dict = checkpoint - if skip_params: - if next(iter(state_dict.items()))[0].startswith("module"): - state_dict = {k[7:]: v for k, v in state_dict.items()} - # for k in state_dict: - # if 
k.startswith('transformer'): - # v = state_dict.pop(k) - # state_dict['text_branch.' + k[12:]] = v - return state_dict - - -def create_model( - amodel_name: str, - tmodel_name: str, - pretrained: str = "", - precision: str = "fp32", - device: torch.device = torch.device("cpu"), - jit: bool = False, - force_quick_gelu: bool = False, - openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"), - skip_params=True, - pretrained_audio: str = "", - pretrained_text: str = "", - enable_fusion: bool = False, - fusion_type: str = "None" - # pretrained_image: bool = False, -): - amodel_name = amodel_name.replace( - "/", "-" - ) # for callers using old naming with / in ViT names - pretrained_orig = pretrained - pretrained = pretrained.lower() - if pretrained == "openai": - if amodel_name in _MODEL_CONFIGS: - logging.info(f"Loading {amodel_name} model config.") - model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) - else: - logging.error( - f"Model config for {amodel_name} not found; available models {list_models()}." - ) - raise RuntimeError(f"Model config for {amodel_name} not found.") - - logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.") - # Hard Code in model name - model_cfg["text_cfg"]["model_type"] = tmodel_name - model = load_openai_model( - "ViT-B-16", - model_cfg, - device=device, - jit=jit, - cache_dir=openai_model_cache_dir, - enable_fusion=enable_fusion, - fusion_type=fusion_type, - ) - # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 - if precision == "amp" or precision == "fp32": - model = model.float() - else: - if amodel_name in _MODEL_CONFIGS: - logging.info(f"Loading {amodel_name} model config.") - model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) - else: - logging.error( - f"Model config for {amodel_name} not found; available models {list_models()}." - ) - raise RuntimeError(f"Model config for {amodel_name} not found.") - - if force_quick_gelu: - # override for use of QuickGELU on non-OpenAI transformer models - model_cfg["quick_gelu"] = True - - # if pretrained_image: - # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}): - # # pretrained weight loading for timm models set via vision_cfg - # model_cfg['vision_cfg']['timm_model_pretrained'] = True - # else: - # assert False, 'pretrained image towers currently only supported for timm models' - model_cfg["text_cfg"]["model_type"] = tmodel_name - model_cfg["enable_fusion"] = enable_fusion - model_cfg["fusion_type"] = fusion_type - model = CLAP(**model_cfg) - - if pretrained: - checkpoint_path = "" - url = get_pretrained_url(amodel_name, pretrained) - if url: - checkpoint_path = download_pretrained(url, root=openai_model_cache_dir) - elif os.path.exists(pretrained_orig): - checkpoint_path = pretrained_orig - if checkpoint_path: - logging.info( - f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})." - ) - ckpt = load_state_dict(checkpoint_path, skip_params=True) - model.load_state_dict(ckpt) - param_names = [n for n, p in model.named_parameters()] - # for n in param_names: - # print(n, "\t", "Loaded" if n in ckpt else "Unloaded") - else: - logging.warning( - f"Pretrained weights ({pretrained}) not found for model {amodel_name}." - ) - raise RuntimeError( - f"Pretrained weights ({pretrained}) not found for model {amodel_name}." 
- ) - - if pretrained_audio: - if amodel_name.startswith("PANN"): - if "Cnn14_mAP" in pretrained_audio: # official checkpoint - audio_ckpt = torch.load(pretrained_audio, map_location="cpu") - audio_ckpt = audio_ckpt["model"] - keys = list(audio_ckpt.keys()) - for key in keys: - if ( - "spectrogram_extractor" not in key - and "logmel_extractor" not in key - ): - v = audio_ckpt.pop(key) - audio_ckpt["audio_branch." + key] = v - elif os.path.basename(pretrained_audio).startswith( - "PANN" - ): # checkpoint trained via HTSAT codebase - audio_ckpt = torch.load(pretrained_audio, map_location="cpu") - audio_ckpt = audio_ckpt["state_dict"] - keys = list(audio_ckpt.keys()) - for key in keys: - if key.startswith("sed_model"): - v = audio_ckpt.pop(key) - audio_ckpt["audio_branch." + key[10:]] = v - elif os.path.basename(pretrained_audio).startswith( - "finetuned" - ): # checkpoint trained via linear probe codebase - audio_ckpt = torch.load(pretrained_audio, map_location="cpu") - else: - raise ValueError("Unknown audio checkpoint") - elif amodel_name.startswith("HTSAT"): - if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint - audio_ckpt = torch.load(pretrained_audio, map_location="cpu") - audio_ckpt = audio_ckpt["state_dict"] - keys = list(audio_ckpt.keys()) - for key in keys: - if key.startswith("sed_model") and ( - "spectrogram_extractor" not in key - and "logmel_extractor" not in key - ): - v = audio_ckpt.pop(key) - audio_ckpt["audio_branch." + key[10:]] = v - elif os.path.basename(pretrained_audio).startswith( - "HTSAT" - ): # checkpoint trained via HTSAT codebase - audio_ckpt = torch.load(pretrained_audio, map_location="cpu") - audio_ckpt = audio_ckpt["state_dict"] - keys = list(audio_ckpt.keys()) - for key in keys: - if key.startswith("sed_model"): - v = audio_ckpt.pop(key) - audio_ckpt["audio_branch." + key[10:]] = v - elif os.path.basename(pretrained_audio).startswith( - "finetuned" - ): # checkpoint trained via linear probe codebase - audio_ckpt = torch.load(pretrained_audio, map_location="cpu") - else: - raise ValueError("Unknown audio checkpoint") - else: - raise f"this audio encoder pretrained checkpoint is not support" - - model.load_state_dict(audio_ckpt, strict=False) - logging.info( - f"Loading pretrained {amodel_name} weights ({pretrained_audio})." 
- ) - param_names = [n for n, p in model.named_parameters()] - for n in param_names: - print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded") - - model.to(device=device) - if precision == "fp16": - assert device.type != "cpu" - convert_weights_to_fp16(model) - - if jit: - model = torch.jit.script(model) - - return model, model_cfg - - -def create_model_and_transforms( - model_name: str, - pretrained: str = "", - precision: str = "fp32", - device: torch.device = torch.device("cpu"), - jit: bool = False, - force_quick_gelu: bool = False, - # pretrained_image: bool = False, -): - model = create_model( - model_name, - pretrained, - precision, - device, - jit, - force_quick_gelu=force_quick_gelu, - # pretrained_image=pretrained_image - ) - preprocess_train = image_transform(model.visual.image_size, is_train=True) - preprocess_val = image_transform(model.visual.image_size, is_train=False) - return model, preprocess_train, preprocess_val - - -def list_models(): - """enumerate available model architectures based on config files""" - return list(_MODEL_CONFIGS.keys()) - - -def add_model_config(path): - """add model config path or file and update registry""" - if not isinstance(path, Path): - path = Path(path) - _MODEL_CONFIG_PATHS.append(path) - _rescan_model_configs() diff --git a/audioldm/clap/open_clip/feature_fusion.py b/audioldm/clap/open_clip/feature_fusion.py deleted file mode 100644 index dbe4e170e05894c12ebdc36ba1dc1de65e441b89..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/feature_fusion.py +++ /dev/null @@ -1,192 +0,0 @@ -""" -Feature Fusion for Varible-Length Data Processing -AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py -According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021 -""" - -import torch -import torch.nn as nn - - -class DAF(nn.Module): - """ - 直接相加 DirectAddFuse - """ - - def __init__(self): - super(DAF, self).__init__() - - def forward(self, x, residual): - return x + residual - - -class iAFF(nn.Module): - """ - 多特征融合 iAFF - """ - - def __init__(self, channels=64, r=4, type="2D"): - super(iAFF, self).__init__() - inter_channels = int(channels // r) - - if type == "1D": - # 本地注意力 - self.local_att = nn.Sequential( - nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(channels), - ) - - # 全局注意力 - self.global_att = nn.Sequential( - nn.AdaptiveAvgPool1d(1), - nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(channels), - ) - - # 第二次本地注意力 - self.local_att2 = nn.Sequential( - nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(channels), - ) - # 第二次全局注意力 - self.global_att2 = nn.Sequential( - nn.AdaptiveAvgPool1d(1), - nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(channels), - ) - elif type == "2D": - # 本地注意力 - 
self.local_att = nn.Sequential( - nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(channels), - ) - - # 全局注意力 - self.global_att = nn.Sequential( - nn.AdaptiveAvgPool2d(1), - nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(channels), - ) - - # 第二次本地注意力 - self.local_att2 = nn.Sequential( - nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(channels), - ) - # 第二次全局注意力 - self.global_att2 = nn.Sequential( - nn.AdaptiveAvgPool2d(1), - nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(channels), - ) - else: - raise f"the type is not supported" - - self.sigmoid = nn.Sigmoid() - - def forward(self, x, residual): - flag = False - xa = x + residual - if xa.size(0) == 1: - xa = torch.cat([xa, xa], dim=0) - flag = True - xl = self.local_att(xa) - xg = self.global_att(xa) - xlg = xl + xg - wei = self.sigmoid(xlg) - xi = x * wei + residual * (1 - wei) - - xl2 = self.local_att2(xi) - xg2 = self.global_att(xi) - xlg2 = xl2 + xg2 - wei2 = self.sigmoid(xlg2) - xo = x * wei2 + residual * (1 - wei2) - if flag: - xo = xo[0].unsqueeze(0) - return xo - - -class AFF(nn.Module): - """ - 多特征融合 AFF - """ - - def __init__(self, channels=64, r=4, type="2D"): - super(AFF, self).__init__() - inter_channels = int(channels // r) - - if type == "1D": - self.local_att = nn.Sequential( - nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(channels), - ) - self.global_att = nn.Sequential( - nn.AdaptiveAvgPool1d(1), - nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm1d(channels), - ) - elif type == "2D": - self.local_att = nn.Sequential( - nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(channels), - ) - self.global_att = nn.Sequential( - nn.AdaptiveAvgPool2d(1), - nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(inter_channels), - nn.ReLU(inplace=True), - nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(channels), - ) - else: - raise f"the type is not supported." 
- - self.sigmoid = nn.Sigmoid() - - def forward(self, x, residual): - flag = False - xa = x + residual - if xa.size(0) == 1: - xa = torch.cat([xa, xa], dim=0) - flag = True - xl = self.local_att(xa) - xg = self.global_att(xa) - xlg = xl + xg - wei = self.sigmoid(xlg) - xo = 2 * x * wei + 2 * residual * (1 - wei) - if flag: - xo = xo[0].unsqueeze(0) - return xo diff --git a/audioldm/clap/open_clip/htsat.py b/audioldm/clap/open_clip/htsat.py deleted file mode 100644 index 3b856c6a43df162116a941f1b5c76e93713b276a..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/htsat.py +++ /dev/null @@ -1,1308 +0,0 @@ -# Ke Chen -# knutchen@ucsd.edu -# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION -# Some layers designed on the model -# below codes are based and referred from https://github.com/microsoft/Swin-Transformer -# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf - -import torch -import torch.nn as nn -import torch.nn.functional as F -from itertools import repeat -import collections.abc -import math -import warnings - -from torch.nn.init import _calculate_fan_in_and_fan_out -import torch.utils.checkpoint as checkpoint - -import random - -from torchlibrosa.stft import Spectrogram, LogmelFilterBank -from torchlibrosa.augmentation import SpecAugmentation - -from itertools import repeat -from .utils import do_mixup, interpolate - -from .feature_fusion import iAFF, AFF, DAF - -# from PyTorch internals -def _ntuple(n): - def parse(x): - if isinstance(x, collections.abc.Iterable): - return x - return tuple(repeat(x, n)) - - return parse - - -to_1tuple = _ntuple(1) -to_2tuple = _ntuple(2) -to_3tuple = _ntuple(3) -to_4tuple = _ntuple(4) -to_ntuple = _ntuple - - -def drop_path(x, drop_prob: float = 0.0, training: bool = False): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, - the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for - changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use - 'survival rate' as the argument. 
- """ - if drop_prob == 0.0 or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * ( - x.ndim - 1 - ) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) - - -class PatchEmbed(nn.Module): - """2D Image to Patch Embedding""" - - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - embed_dim=768, - norm_layer=None, - flatten=True, - patch_stride=16, - enable_fusion=False, - fusion_type="None", - ): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patch_stride = to_2tuple(patch_stride) - self.img_size = img_size - self.patch_size = patch_size - self.patch_stride = patch_stride - self.grid_size = ( - img_size[0] // patch_stride[0], - img_size[1] // patch_stride[1], - ) - self.num_patches = self.grid_size[0] * self.grid_size[1] - self.flatten = flatten - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.enable_fusion = enable_fusion - self.fusion_type = fusion_type - - padding = ( - (patch_size[0] - patch_stride[0]) // 2, - (patch_size[1] - patch_stride[1]) // 2, - ) - - if (self.enable_fusion) and (self.fusion_type == "channel_map"): - self.proj = nn.Conv2d( - in_chans * 4, - embed_dim, - kernel_size=patch_size, - stride=patch_stride, - padding=padding, - ) - else: - self.proj = nn.Conv2d( - in_chans, - embed_dim, - kernel_size=patch_size, - stride=patch_stride, - padding=padding, - ) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - - if (self.enable_fusion) and ( - self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] - ): - self.mel_conv2d = nn.Conv2d( - in_chans, - embed_dim, - kernel_size=(patch_size[0], patch_size[1] * 3), - stride=(patch_stride[0], patch_stride[1] * 3), - padding=padding, - ) - if self.fusion_type == "daf_2d": - self.fusion_model = DAF() - elif self.fusion_type == "aff_2d": - self.fusion_model = AFF(channels=embed_dim, type="2D") - elif self.fusion_type == "iaff_2d": - self.fusion_model = iAFF(channels=embed_dim, type="2D") - - def forward(self, x, longer_idx=None): - if (self.enable_fusion) and ( - self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] - ): - global_x = x[:, 0:1, :, :] - - # global processing - B, C, H, W = global_x.shape - assert ( - H == self.img_size[0] and W == self.img_size[1] - ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
- global_x = self.proj(global_x) - TW = global_x.size(-1) - if len(longer_idx) > 0: - # local processing - local_x = x[longer_idx, 1:, :, :].contiguous() - B, C, H, W = local_x.shape - local_x = local_x.view(B * C, 1, H, W) - local_x = self.mel_conv2d(local_x) - local_x = local_x.view( - B, C, local_x.size(1), local_x.size(2), local_x.size(3) - ) - local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3) - TB, TC, TH, _ = local_x.size() - if local_x.size(-1) < TW: - local_x = torch.cat( - [ - local_x, - torch.zeros( - (TB, TC, TH, TW - local_x.size(-1)), - device=global_x.device, - ), - ], - dim=-1, - ) - else: - local_x = local_x[:, :, :, :TW] - - global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x) - x = global_x - else: - B, C, H, W = x.shape - assert ( - H == self.img_size[0] and W == self.img_size[1] - ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x) - - if self.flatten: - x = x.flatten(2).transpose(1, 2) # BCHW -> BNC - x = self.norm(x) - return x - - -class Mlp(nn.Module): - """MLP as used in Vision Transformer, MLP-Mixer and related networks""" - - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - drop=0.0, - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def _no_grad_trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - with torch.no_grad(): - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - return tensor - - -def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): - # type: (Tensor, float, float, float, float) -> Tensor - r"""Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \leq \text{mean} \leq b`. 
- Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - Examples: - >>> w = torch.empty(3, 5) - >>> nn.init.trunc_normal_(w) - """ - return _no_grad_trunc_normal_(tensor, mean, std, a, b) - - -def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): - fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) - if mode == "fan_in": - denom = fan_in - elif mode == "fan_out": - denom = fan_out - elif mode == "fan_avg": - denom = (fan_in + fan_out) / 2 - - variance = scale / denom - - if distribution == "truncated_normal": - # constant is stddev of standard normal truncated to (-2, 2) - trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978) - elif distribution == "normal": - tensor.normal_(std=math.sqrt(variance)) - elif distribution == "uniform": - bound = math.sqrt(3 * variance) - tensor.uniform_(-bound, bound) - else: - raise ValueError(f"invalid distribution {distribution}") - - -def lecun_normal_(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = ( - x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - ) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view( - B, H // window_size, W // window_size, window_size, window_size, -1 - ) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - r"""Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 - """ - - def __init__( - self, - dim, - window_size, - num_heads, - qkv_bias=True, - qk_scale=None, - attn_drop=0.0, - proj_drop=0.0, - ): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim**-0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) - ) # 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = ( - coords_flatten[:, :, None] - coords_flatten[:, None, :] - ) # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute( - 1, 2, 0 - ).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=0.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = ( - self.qkv(x) - .reshape(B_, N, 3, self.num_heads, C // self.num_heads) - .permute(2, 0, 3, 1, 4) - ) - q, k, v = ( - qkv[0], - qkv[1], - qkv[2], - ) # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = q @ k.transpose(-2, -1) - - relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1) - ].view( - self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], - -1, - ) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute( - 2, 0, 1 - ).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( - 1 - ).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x, attn - - def extra_repr(self): - return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}" - - -# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model -class SwinTransformerBlock(nn.Module): - r"""Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__( - self, - dim, - input_resolution, - num_heads, - window_size=7, - shift_size=0, - mlp_ratio=4.0, - qkv_bias=True, - qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - norm_before_mlp="ln", - ): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - self.norm_before_mlp = norm_before_mlp - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert ( - 0 <= self.shift_size < self.window_size - ), "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, - window_size=to_2tuple(self.window_size), - num_heads=num_heads, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop, - ) - - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - if self.norm_before_mlp == "ln": - self.norm2 = nn.LayerNorm(dim) - elif self.norm_before_mlp == "bn": - self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose( - 1, 2 - ) - else: - raise NotImplementedError - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp( - in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, - drop=drop, - ) - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None), - ) - w_slices = ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None), - ) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition( - img_mask, self.window_size - ) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill( - attn_mask != 0, float(-100.0) - ).masked_fill(attn_mask == 0, float(0.0)) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def forward(self, x): - # pdb.set_trace() - H, W = self.input_resolution - # print("H: ", H) - # print("W: ", W) - # pdb.set_trace() - B, L, C = x.shape - # assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll( - x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) - ) - else: - shifted_x = x - - # partition windows - x_windows = window_partition( - shifted_x, self.window_size - ) # nW*B, window_size, window_size, C - x_windows = x_windows.view( - -1, 
self.window_size * self.window_size, C - ) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA - attn_windows, attn = self.attn( - x_windows, mask=self.attn_mask - ) # nW*B, window_size*window_size, C - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll( - shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) - ) - else: - x = shifted_x - x = x.view(B, H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x, attn - - def extra_repr(self): - return ( - f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - ) - - -class PatchMerging(nn.Module): - r"""Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self): - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - -class BasicLayer(nn.Module): - """A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
- """ - - def __init__( - self, - dim, - input_resolution, - depth, - num_heads, - window_size, - mlp_ratio=4.0, - qkv_bias=True, - qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, - norm_layer=nn.LayerNorm, - downsample=None, - use_checkpoint=False, - norm_before_mlp="ln", - ): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList( - [ - SwinTransformerBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i] - if isinstance(drop_path, list) - else drop_path, - norm_layer=norm_layer, - norm_before_mlp=norm_before_mlp, - ) - for i in range(depth) - ] - ) - - # patch merging layer - if downsample is not None: - self.downsample = downsample( - input_resolution, dim=dim, norm_layer=norm_layer - ) - else: - self.downsample = None - - def forward(self, x): - attns = [] - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x, attn = blk(x) - if not self.training: - attns.append(attn.unsqueeze(0)) - if self.downsample is not None: - x = self.downsample(x) - if not self.training: - attn = torch.cat(attns, dim=0) - attn = torch.mean(attn, dim=0) - return x, attn - - def extra_repr(self): - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - -# The Core of HTSAT -class HTSAT_Swin_Transformer(nn.Module): - r"""HTSAT based on the Swin Transformer - Args: - spec_size (int | tuple(int)): Input Spectrogram size. Default 256 - patch_size (int | tuple(int)): Patch size. Default: 4 - path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4 - in_chans (int): Number of input image channels. Default: 1 (mono) - num_classes (int): Number of classes for classification head. Default: 527 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 8 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False - config (module): The configuration Module from config.py - """ - - def __init__( - self, - spec_size=256, - patch_size=4, - patch_stride=(4, 4), - in_chans=1, - num_classes=527, - embed_dim=96, - depths=[2, 2, 6, 2], - num_heads=[4, 8, 16, 32], - window_size=8, - mlp_ratio=4.0, - qkv_bias=True, - qk_scale=None, - drop_rate=0.0, - attn_drop_rate=0.0, - drop_path_rate=0.1, - norm_layer=nn.LayerNorm, - ape=False, - patch_norm=True, - use_checkpoint=False, - norm_before_mlp="ln", - config=None, - enable_fusion=False, - fusion_type="None", - **kwargs, - ): - super(HTSAT_Swin_Transformer, self).__init__() - - self.config = config - self.spec_size = spec_size - self.patch_stride = patch_stride - self.patch_size = patch_size - self.window_size = window_size - self.embed_dim = embed_dim - self.depths = depths - self.ape = ape - self.in_chans = in_chans - self.num_classes = num_classes - self.num_heads = num_heads - self.num_layers = len(self.depths) - self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1)) - - self.drop_rate = drop_rate - self.attn_drop_rate = attn_drop_rate - self.drop_path_rate = drop_path_rate - - self.qkv_bias = qkv_bias - self.qk_scale = None - - self.patch_norm = patch_norm - self.norm_layer = norm_layer if self.patch_norm else None - self.norm_before_mlp = norm_before_mlp - self.mlp_ratio = mlp_ratio - - self.use_checkpoint = use_checkpoint - - self.enable_fusion = enable_fusion - self.fusion_type = fusion_type - - # process mel-spec ; used only once - self.freq_ratio = self.spec_size // self.config.mel_bins - window = "hann" - center = True - pad_mode = "reflect" - ref = 1.0 - amin = 1e-10 - top_db = None - self.interpolate_ratio = 32 # Downsampled ratio - # Spectrogram extractor - self.spectrogram_extractor = Spectrogram( - n_fft=config.window_size, - hop_length=config.hop_size, - win_length=config.window_size, - window=window, - center=center, - pad_mode=pad_mode, - freeze_parameters=True, - ) - # Logmel feature extractor - self.logmel_extractor = LogmelFilterBank( - sr=config.sample_rate, - n_fft=config.window_size, - n_mels=config.mel_bins, - fmin=config.fmin, - fmax=config.fmax, - ref=ref, - amin=amin, - top_db=top_db, - freeze_parameters=True, - ) - # Spec augmenter - self.spec_augmenter = SpecAugmentation( - time_drop_width=64, - time_stripes_num=2, - freq_drop_width=8, - freq_stripes_num=2, - ) # 2 2 - self.bn0 = nn.BatchNorm2d(self.config.mel_bins) - - # split spctrogram into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=self.spec_size, - patch_size=self.patch_size, - in_chans=self.in_chans, - embed_dim=self.embed_dim, - norm_layer=self.norm_layer, - patch_stride=patch_stride, - enable_fusion=self.enable_fusion, - fusion_type=self.fusion_type, - ) - - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.grid_size - self.patches_resolution = patches_resolution - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter( - torch.zeros(1, num_patches, self.embed_dim) - ) - trunc_normal_(self.absolute_pos_embed, std=0.02) - - self.pos_drop = nn.Dropout(p=self.drop_rate) - - # stochastic depth - dpr = [ - x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths)) - ] # stochastic depth decay rule - - # build layers - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = BasicLayer( - dim=int(self.embed_dim * 2**i_layer), - input_resolution=( - patches_resolution[0] // (2**i_layer), - patches_resolution[1] // 
(2**i_layer), - ), - depth=self.depths[i_layer], - num_heads=self.num_heads[i_layer], - window_size=self.window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=self.qkv_bias, - qk_scale=self.qk_scale, - drop=self.drop_rate, - attn_drop=self.attn_drop_rate, - drop_path=dpr[ - sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1]) - ], - norm_layer=self.norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, - use_checkpoint=use_checkpoint, - norm_before_mlp=self.norm_before_mlp, - ) - self.layers.append(layer) - - self.norm = self.norm_layer(self.num_features) - self.avgpool = nn.AdaptiveAvgPool1d(1) - self.maxpool = nn.AdaptiveMaxPool1d(1) - - SF = ( - self.spec_size - // (2 ** (len(self.depths) - 1)) - // self.patch_stride[0] - // self.freq_ratio - ) - self.tscam_conv = nn.Conv2d( - in_channels=self.num_features, - out_channels=self.num_classes, - kernel_size=(SF, 3), - padding=(0, 1), - ) - self.head = nn.Linear(num_classes, num_classes) - - if (self.enable_fusion) and ( - self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"] - ): - self.mel_conv1d = nn.Sequential( - nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2), - nn.BatchNorm1d(64), - ) - if self.fusion_type == "daf_1d": - self.fusion_model = DAF() - elif self.fusion_type == "aff_1d": - self.fusion_model = AFF(channels=64, type="1D") - elif self.fusion_type == "iaff_1d": - self.fusion_model = iAFF(channels=64, type="1D") - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {"absolute_pos_embed"} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {"relative_position_bias_table"} - - def forward_features(self, x, longer_idx=None): - # A deprecated optimization for using a hierarchical output from different blocks - - frames_num = x.shape[2] - x = self.patch_embed(x, longer_idx=longer_idx) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - for i, layer in enumerate(self.layers): - x, attn = layer(x) - # for x - x = self.norm(x) - B, N, C = x.shape - SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] - ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1] - x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST) - B, C, F, T = x.shape - # group 2D CNN - c_freq_bin = F // self.freq_ratio - x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T) - x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1) - # get latent_output - fine_grained_latent_output = torch.mean(x, dim=2) - fine_grained_latent_output = interpolate( - fine_grained_latent_output.permute(0, 2, 1).contiguous(), - 8 * self.patch_stride[1], - ) - - latent_output = self.avgpool(torch.flatten(x, 2)) - latent_output = torch.flatten(latent_output, 1) - - # display the attention map, if needed - - x = self.tscam_conv(x) - x = torch.flatten(x, 2) # B, C, T - - fpx = interpolate( - torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1] - ) - - x = self.avgpool(x) - x = torch.flatten(x, 1) - - output_dict = { - "framewise_output": fpx, # already sigmoided - "clipwise_output": torch.sigmoid(x), - "fine_grained_embedding": fine_grained_latent_output, - "embedding": latent_output, - } - - return output_dict - - def 
crop_wav(self, x, crop_size, spe_pos=None): - time_steps = x.shape[2] - tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device) - for i in range(len(x)): - if spe_pos is None: - crop_pos = random.randint(0, time_steps - crop_size - 1) - else: - crop_pos = spe_pos - tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :] - return tx - - # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model - def reshape_wav2img(self, x): - B, C, T, F = x.shape - target_T = int(self.spec_size * self.freq_ratio) - target_F = self.spec_size // self.freq_ratio - assert ( - T <= target_T and F <= target_F - ), "the wav size should less than or equal to the swin input size" - # to avoid bicubic zero error - if T < target_T: - x = nn.functional.interpolate( - x, (target_T, x.shape[3]), mode="bicubic", align_corners=True - ) - if F < target_F: - x = nn.functional.interpolate( - x, (x.shape[2], target_F), mode="bicubic", align_corners=True - ) - x = x.permute(0, 1, 3, 2).contiguous() - x = x.reshape( - x.shape[0], - x.shape[1], - x.shape[2], - self.freq_ratio, - x.shape[3] // self.freq_ratio, - ) - # print(x.shape) - x = x.permute(0, 1, 3, 2, 4).contiguous() - x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4]) - return x - - # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model - def repeat_wat2img(self, x, cur_pos): - B, C, T, F = x.shape - target_T = int(self.spec_size * self.freq_ratio) - target_F = self.spec_size // self.freq_ratio - assert ( - T <= target_T and F <= target_F - ), "the wav size should less than or equal to the swin input size" - # to avoid bicubic zero error - if T < target_T: - x = nn.functional.interpolate( - x, (target_T, x.shape[3]), mode="bicubic", align_corners=True - ) - if F < target_F: - x = nn.functional.interpolate( - x, (x.shape[2], target_F), mode="bicubic", align_corners=True - ) - x = x.permute(0, 1, 3, 2).contiguous() # B C F T - x = x[:, :, :, cur_pos : cur_pos + self.spec_size] - x = x.repeat(repeats=(1, 1, 4, 1)) - return x - - def forward( - self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None - ): # out_feat_keys: List[str] = None): - - if self.enable_fusion and x["longer"].sum() == 0: - # if no audio is longer than 10s, then randomly select one audio to be longer - x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True - - if not self.enable_fusion: - x = x["waveform"].to(device=device, non_blocking=True) - x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins) - x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) - x = x.transpose(1, 3) - x = self.bn0(x) - x = x.transpose(1, 3) - if self.training: - x = self.spec_augmenter(x) - - if self.training and mixup_lambda is not None: - x = do_mixup(x, mixup_lambda) - - x = self.reshape_wav2img(x) - output_dict = self.forward_features(x) - else: - longer_list = x["longer"].to(device=device, non_blocking=True) - x = x["mel_fusion"].to(device=device, non_blocking=True) - x = x.transpose(1, 3) - x = self.bn0(x) - x = x.transpose(1, 3) - longer_list_idx = torch.where(longer_list)[0] - if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]: - new_x = x[:, 0:1, :, :].clone().contiguous() - if len(longer_list_idx) > 0: - # local processing - fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous() - FB, FC, FT, FF = fusion_x_local.size() - fusion_x_local = fusion_x_local.view(FB * FC, FT, FF) - fusion_x_local = torch.permute( - fusion_x_local, 
(0, 2, 1) - ).contiguous() - fusion_x_local = self.mel_conv1d(fusion_x_local) - fusion_x_local = fusion_x_local.view( - FB, FC, FF, fusion_x_local.size(-1) - ) - fusion_x_local = ( - torch.permute(fusion_x_local, (0, 2, 1, 3)) - .contiguous() - .flatten(2) - ) - if fusion_x_local.size(-1) < FT: - fusion_x_local = torch.cat( - [ - fusion_x_local, - torch.zeros( - (FB, FF, FT - fusion_x_local.size(-1)), - device=device, - ), - ], - dim=-1, - ) - else: - fusion_x_local = fusion_x_local[:, :, :FT] - # 1D fusion - new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous() - new_x[longer_list_idx] = self.fusion_model( - new_x[longer_list_idx], fusion_x_local - ) - x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :] - else: - x = new_x - - elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]: - x = x # no change - - if self.training: - x = self.spec_augmenter(x) - if self.training and mixup_lambda is not None: - x = do_mixup(x, mixup_lambda) - - x = self.reshape_wav2img(x) - output_dict = self.forward_features(x, longer_idx=longer_list_idx) - - # if infer_mode: - # # in infer mode. we need to handle different length audio input - # frame_num = x.shape[2] - # target_T = int(self.spec_size * self.freq_ratio) - # repeat_ratio = math.floor(target_T / frame_num) - # x = x.repeat(repeats=(1,1,repeat_ratio,1)) - # x = self.reshape_wav2img(x) - # output_dict = self.forward_features(x) - # else: - # if x.shape[2] > self.freq_ratio * self.spec_size: - # if self.training: - # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size) - # x = self.reshape_wav2img(x) - # output_dict = self.forward_features(x) - # else: - # # Change: Hard code here - # overlap_size = (x.shape[2] - 1) // 4 - # output_dicts = [] - # crop_size = (x.shape[2] - 1) // 2 - # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size): - # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos) - # tx = self.reshape_wav2img(tx) - # output_dicts.append(self.forward_features(tx)) - # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device) - # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device) - # for d in output_dicts: - # clipwise_output += d["clipwise_output"] - # framewise_output += d["framewise_output"] - # clipwise_output = clipwise_output / len(output_dicts) - # framewise_output = framewise_output / len(output_dicts) - # output_dict = { - # 'framewise_output': framewise_output, - # 'clipwise_output': clipwise_output - # } - # else: # this part is typically used, and most easy one - # x = self.reshape_wav2img(x) - # output_dict = self.forward_features(x) - # x = self.head(x) - - # We process the data in the dataloader part, in that here we only consider the input_T < fixed_T - - return output_dict - - -def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"): - try: - - assert audio_cfg.model_name in [ - "tiny", - "base", - "large", - ], "model name for HTS-AT is wrong!" 
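# [Editor's note - illustrative sketch, not part of the deleted file]
# A minimal way to exercise this factory, assuming a plain namespace object as
# a stand-in for the real config and the default values from CLAPAudioCfp
# further down in this diff (sample_rate=48000, window_size=1024,
# hop_size=1024, mel_bins=64, fmin=50, fmax=14000, class_num=527):
from types import SimpleNamespace

audio_cfg = SimpleNamespace(
    model_name="tiny", class_num=527, sample_rate=48000,
    window_size=1024, hop_size=1024, mel_bins=64, fmin=50, fmax=14000,
)
model = create_htsat_model(audio_cfg)  # HTSAT_Swin_Transformer, fusion disabled
# [End editor's note]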
- if audio_cfg.model_name == "tiny": - model = HTSAT_Swin_Transformer( - spec_size=256, - patch_size=4, - patch_stride=(4, 4), - num_classes=audio_cfg.class_num, - embed_dim=96, - depths=[2, 2, 6, 2], - num_heads=[4, 8, 16, 32], - window_size=8, - config=audio_cfg, - enable_fusion=enable_fusion, - fusion_type=fusion_type, - ) - elif audio_cfg.model_name == "base": - model = HTSAT_Swin_Transformer( - spec_size=256, - patch_size=4, - patch_stride=(4, 4), - num_classes=audio_cfg.class_num, - embed_dim=128, - depths=[2, 2, 12, 2], - num_heads=[4, 8, 16, 32], - window_size=8, - config=audio_cfg, - enable_fusion=enable_fusion, - fusion_type=fusion_type, - ) - elif audio_cfg.model_name == "large": - model = HTSAT_Swin_Transformer( - spec_size=256, - patch_size=4, - patch_stride=(4, 4), - num_classes=audio_cfg.class_num, - embed_dim=256, - depths=[2, 2, 12, 2], - num_heads=[4, 8, 16, 32], - window_size=8, - config=audio_cfg, - enable_fusion=enable_fusion, - fusion_type=fusion_type, - ) - - return model - except: - raise RuntimeError( - f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough." - ) diff --git a/audioldm/clap/open_clip/linear_probe.py b/audioldm/clap/open_clip/linear_probe.py deleted file mode 100644 index 9d7e23b6b67a53e16d050d675a99d01d7d04d581..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/linear_probe.py +++ /dev/null @@ -1,66 +0,0 @@ -import numpy as np -import torch.nn.functional as F -from torch import nn -from .model import MLPLayers - - -class LinearProbe(nn.Module): - def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None): - """ - Args: - model: nn.Module - mlp: bool, if True, then use the MLP layer as the linear probe module - freeze: bool, if Ture, then freeze all the CLAP model's layers when training the linear probe - in_ch: int, the output channel from CLAP model - out_ch: int, the output channel from linear probe (class_num) - act: torch.nn.functional, the activation function before the loss function - """ - super().__init__() - in_ch = 512 - self.clap_model = model - self.clap_model.text_branch = None # to save memory - self.freeze = freeze - if mlp: - self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch]) - else: - self.lp_layer = nn.Linear(in_ch, out_ch) - - if self.freeze: - for param in self.clap_model.parameters(): - param.requires_grad = False - - if act == "None": - self.act = None - elif act == "relu": - self.act = nn.ReLU() - elif act == "elu": - self.act = nn.ELU() - elif act == "prelu": - self.act = nn.PReLU(num_parameters=in_ch) - elif act == "softmax": - self.act = nn.Softmax(dim=-1) - elif act == "sigmoid": - self.act = nn.Sigmoid() - - def forward(self, x, mix_lambda=None, device=None): - """ - Args: - x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list - mix_lambda: torch.tensor [batch], the mixup lambda - Returns: - class_prob: torch.tensor [batch, class_num] - - """ - # batchnorm cancel grandient - if self.freeze: - self.clap_model.eval() - - x = self.clap_model.audio_projection( - self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[ - "embedding" - ] - ) - out = self.lp_layer(x) - if self.act is not None: - out = self.act(out) - return out diff --git a/audioldm/clap/open_clip/loss.py b/audioldm/clap/open_clip/loss.py deleted file mode 100644 index cc66298a14997da4aa2efc71e37c0a6bcda53fd1..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/loss.py +++ /dev/null @@ -1,398 +0,0 @@ -from 
multiprocessing.sharedctypes import Value -import torch -import torch.distributed.nn -from torch import distributed as dist, nn as nn -from torch.nn import functional as F -import numpy as np -from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score - -try: - import horovod.torch as hvd -except ImportError: - hvd = None - - -def gather_features( - audio_features, - text_features, - audio_features_mlp=None, - text_features_mlp=None, - local_loss=False, - gather_with_grad=False, - rank=0, - world_size=1, - use_horovod=False, - mlp_loss=False, -): - if use_horovod: - assert hvd is not None, "Please install horovod" - if gather_with_grad: - all_audio_features = hvd.allgather(audio_features) - all_text_features = hvd.allgather(text_features) - if mlp_loss: - all_audio_features_mlp = hvd.allgather(audio_features_mlp) - all_text_features_mlp = hvd.allgather(text_features_mlp) - else: - with torch.no_grad(): - all_audio_features = hvd.allgather(audio_features) - all_text_features = hvd.allgather(text_features) - if mlp_loss: - all_audio_features_mlp = hvd.allgather(audio_features_mlp) - all_text_features_mlp = hvd.allgather(text_features_mlp) - if not local_loss: - # ensure grads for local rank when all_* features don't have a gradient - gathered_audio_features = list( - all_audio_features.chunk(world_size, dim=0) - ) - gathered_text_features = list( - all_text_features.chunk(world_size, dim=0) - ) - gathered_audio_features[rank] = audio_features - gathered_text_features[rank] = text_features - all_audio_features = torch.cat(gathered_audio_features, dim=0) - all_text_features = torch.cat(gathered_text_features, dim=0) - if mlp_loss: - gathered_audio_features_mlp = list( - all_audio_features_mlp.chunk(world_size, dim=0) - ) - gathered_text_features_mlp = list( - all_text_features_mlp.chunk(world_size, dim=0) - ) - gathered_audio_features_mlp[rank] = audio_features_mlp - gathered_text_features_mlp[rank] = text_features_mlp - all_audio_features_mlp = torch.cat( - gathered_audio_features_mlp, dim=0 - ) - all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0) - else: - # We gather tensors from all gpus - if gather_with_grad: - all_audio_features = torch.cat( - torch.distributed.nn.all_gather(audio_features), dim=0 - ) - all_text_features = torch.cat( - torch.distributed.nn.all_gather(text_features), dim=0 - ) - if mlp_loss: - all_audio_features_mlp = torch.cat( - torch.distributed.nn.all_gather(audio_features_mlp), dim=0 - ) - all_text_features_mlp = torch.cat( - torch.distributed.nn.all_gather(text_features_mlp), dim=0 - ) - else: - gathered_audio_features = [ - torch.zeros_like(audio_features) for _ in range(world_size) - ] - gathered_text_features = [ - torch.zeros_like(text_features) for _ in range(world_size) - ] - dist.all_gather(gathered_audio_features, audio_features) - dist.all_gather(gathered_text_features, text_features) - if mlp_loss: - gathered_audio_features_mlp = [ - torch.zeros_like(audio_features_mlp) for _ in range(world_size) - ] - gathered_text_features_mlp = [ - torch.zeros_like(text_features_mlp) for _ in range(world_size) - ] - dist.all_gather(gathered_audio_features_mlp, audio_features_mlp) - dist.all_gather(gathered_text_features_mlp, text_features_mlp) - if not local_loss: - # ensure grads for local rank when all_* features don't have a gradient - gathered_audio_features[rank] = audio_features - gathered_text_features[rank] = text_features - if mlp_loss: - gathered_audio_features_mlp[rank] = audio_features_mlp - 
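# [Editor's note - illustrative sketch, not part of the deleted file]
# The features gathered here feed the ClipLoss defined below. In the
# single-process, non-MLP case that loss reduces to a standard symmetric
# cross-entropy over the audio-text similarity matrix, with matching pairs on
# the diagonal. A minimal stand-alone version (batch size 4, dim 512 assumed):
import torch
import torch.nn.functional as F

audio = F.normalize(torch.randn(4, 512), dim=-1)  # audio embeddings
text = F.normalize(torch.randn(4, 512), dim=-1)   # text embeddings
logit_scale = 1 / 0.07                            # exp of the logit_scale_a init in model.py
logits = logit_scale * audio @ text.T             # (4, 4) similarities
labels = torch.arange(4)                          # i-th audio matches i-th text
loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2
# [End editor's note]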
gathered_text_features_mlp[rank] = text_features_mlp - - all_audio_features = torch.cat(gathered_audio_features, dim=0) - all_text_features = torch.cat(gathered_text_features, dim=0) - if mlp_loss: - all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0) - all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0) - if mlp_loss: - return ( - all_audio_features, - all_text_features, - all_audio_features_mlp, - all_text_features_mlp, - ) - else: - return all_audio_features, all_text_features - - -class ClipLoss(nn.Module): - def __init__( - self, - local_loss=False, - gather_with_grad=False, - cache_labels=False, - rank=0, - world_size=1, - use_horovod=False, - mlp_loss=False, - weight_loss_kappa=0, - ): - super().__init__() - self.local_loss = local_loss - self.gather_with_grad = gather_with_grad - self.cache_labels = cache_labels - self.rank = rank - self.world_size = world_size - self.use_horovod = use_horovod - self.mlp_loss = mlp_loss - self.weighted_loss = bool(weight_loss_kappa != 0) - self.weight_loss_kappa = weight_loss_kappa - # cache state - self.prev_num_logits = 0 - self.labels = {} - - def forward( - self, - audio_features, - text_features, - logit_scale_a, - logit_scale_t=None, - audio_features_mlp=None, - text_features_mlp=None, - ): - device = audio_features.device - if self.mlp_loss: - if self.world_size > 1: - ( - all_audio_features, - all_text_features, - all_audio_features_mlp, - all_text_features_mlp, - ) = gather_features( - audio_features=audio_features, - text_features=text_features, - audio_features_mlp=audio_features_mlp, - text_features_mlp=text_features_mlp, - local_loss=self.local_loss, - gather_with_grad=self.gather_with_grad, - rank=self.rank, - world_size=self.world_size, - use_horovod=self.use_horovod, - mlp_loss=self.mlp_loss, - ) - if self.local_loss: - a_logits_per_audio = ( - logit_scale_a * audio_features @ all_text_features_mlp.T - ) - a_logits_per_text = ( - logit_scale_a * text_features_mlp @ all_audio_features.T - ) - t_logits_per_audio = ( - logit_scale_t * audio_features_mlp @ all_text_features.T - ) - t_logits_per_text = ( - logit_scale_t * text_features @ all_audio_features_mlp.T - ) - else: - a_logits_per_audio = ( - logit_scale_a * all_audio_features @ all_text_features_mlp.T - ) - a_logits_per_text = a_logits_per_audio.T - t_logits_per_audio = ( - logit_scale_t * all_audio_features_mlp @ all_text_features.T - ) - t_logits_per_text = t_logits_per_audio.T - else: - a_logits_per_audio = ( - logit_scale_a * audio_features @ text_features_mlp.T - ) - a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T - t_logits_per_audio = ( - logit_scale_t * audio_features_mlp @ text_features.T - ) - t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T - - # calculated ground-truth and cache if enabled - num_logits = a_logits_per_audio.shape[0] - if self.prev_num_logits != num_logits or device not in self.labels: - labels = torch.arange(num_logits, device=device, dtype=torch.long) - if self.world_size > 1 and self.local_loss: - labels = labels + num_logits * self.rank - if self.cache_labels: - self.labels[device] = labels - self.prev_num_logits = num_logits - else: - labels = self.labels[device] - - if not self.weighted_loss: - total_loss = ( - F.cross_entropy(a_logits_per_audio, labels) - + F.cross_entropy(a_logits_per_text, labels) - + F.cross_entropy(t_logits_per_audio, labels) - + F.cross_entropy(t_logits_per_text, labels) - ) / 4 - else: - audio_weight = (audio_features @ 
audio_features.T).detach() - audio_weight = ( - torch.exp( - torch.sum(audio_weight, axis=1) - / (self.weight_loss_kappa * len(audio_weight)) - ) - ).detach() - text_weight = (text_features @ text_features.T).detach() - text_weight = ( - torch.exp( - torch.sum(text_weight, axis=1) - / (self.weight_loss_kappa * len(text_features)) - ) - ).detach() - total_loss = ( - F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight) - + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight) - + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight) - + F.cross_entropy(t_logits_per_text, labels, weight=text_weight) - ) / 4 - else: - if self.world_size > 1: - all_audio_features, all_text_features = gather_features( - audio_features=audio_features, - text_features=text_features, - local_loss=self.local_loss, - gather_with_grad=self.gather_with_grad, - rank=self.rank, - world_size=self.world_size, - use_horovod=self.use_horovod, - mlp_loss=self.mlp_loss, - ) - - if self.local_loss: - logits_per_audio = ( - logit_scale_a * audio_features @ all_text_features.T - ) - logits_per_text = ( - logit_scale_a * text_features @ all_audio_features.T - ) - else: - logits_per_audio = ( - logit_scale_a * all_audio_features @ all_text_features.T - ) - logits_per_text = logits_per_audio.T - else: - logits_per_audio = logit_scale_a * audio_features @ text_features.T - logits_per_text = logit_scale_a * text_features @ audio_features.T - - # calculated ground-truth and cache if enabled - num_logits = logits_per_audio.shape[0] - if self.prev_num_logits != num_logits or device not in self.labels: - labels = torch.arange(num_logits, device=device, dtype=torch.long) - if self.world_size > 1 and self.local_loss: - labels = labels + num_logits * self.rank - if self.cache_labels: - self.labels[device] = labels - self.prev_num_logits = num_logits - else: - labels = self.labels[device] - if not self.weighted_loss: - total_loss = ( - F.cross_entropy(logits_per_audio, labels) - + F.cross_entropy(logits_per_text, labels) - ) / 2 - else: - audio_weight = (all_audio_features @ all_audio_features.T).detach() - audio_weight = ( - torch.exp( - torch.sum(audio_weight, axis=1) - / (self.weight_loss_kappa * len(all_audio_features)) - ) - ).detach() - text_weight = (all_text_features @ all_text_features.T).detach() - text_weight = ( - torch.exp( - torch.sum(text_weight, axis=1) - / (self.weight_loss_kappa * len(all_text_features)) - ) - ).detach() - total_loss = ( - F.cross_entropy(logits_per_audio, labels, weight=text_weight) - + F.cross_entropy(logits_per_text, labels, weight=audio_weight) - ) / 2 - return total_loss - - -def lp_gather_features(pred, target, world_size=1, use_horovod=False): - if use_horovod: - assert hvd is not None, "Please install horovod" - with torch.no_grad(): - all_preds = hvd.allgather(pred) - all_targets = hvd.allgath(target) - else: - gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)] - gathered_targets = [torch.zeros_like(target) for _ in range(world_size)] - - dist.all_gather(gathered_preds, pred) - dist.all_gather(gathered_targets, target) - all_preds = torch.cat(gathered_preds, dim=0) - all_targets = torch.cat(gathered_targets, dim=0) - - return all_preds, all_targets - - -def get_map(pred, target): - pred = torch.sigmoid(pred).numpy() - target = target.numpy() - return np.mean(average_precision_score(target, pred, average=None)) - - -def get_acc(pred, target): - pred = torch.argmax(pred, 1).numpy() - target = torch.argmax(target, 1).numpy() - return 
accuracy_score(target, pred) - - -def get_mauc(pred, target): - pred = torch.sigmoid(pred).numpy() - target = target.numpy() - return np.mean(roc_auc_score(target, pred, average=None)) - - -class LPMetrics(object): - def __init__(self, metric_names=["map", "acc", "mauc"]): - self.metrics = [] - for name in metric_names: - self.metrics.append(self.get_metric(name)) - self.metric_names = metric_names - - def get_metric(self, name): - if name == "map": - return get_map - elif name == "acc": - return get_acc - elif name == "mauc": - return get_mauc - else: - raise ValueError(f"the metric should be at least one of [map, acc, mauc]") - - def evaluate_mertics(self, pred, target): - metric_dict = {} - for i in range(len(self.metric_names)): - metric_dict[self.metric_names[i]] = self.metrics[i](pred, target) - return metric_dict - - -def calc_celoss(pred, target): - target = torch.argmax(target, 1).long() - return nn.CrossEntropyLoss()(pred, target) - - -class LPLoss(nn.Module): - def __init__(self, loss_name): - super().__init__() - if loss_name == "bce": - self.loss_func = nn.BCEWithLogitsLoss() - elif loss_name == "ce": - self.loss_func = calc_celoss - elif loss_name == "mse": - self.loss_func = nn.MSELoss() - else: - raise ValueError(f"the loss func should be at least one of [bce, ce, mse]") - - def forward(self, pred, target): - loss = self.loss_func(pred, target) - return loss diff --git a/audioldm/clap/open_clip/model.py b/audioldm/clap/open_clip/model.py deleted file mode 100644 index f41e6d6d0b0bbecacb90744928a516b75d218214..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model.py +++ /dev/null @@ -1,936 +0,0 @@ -""" CLAP Model - -Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. -Adapted to the Audio Task. -""" - -from collections import OrderedDict -from dataclasses import dataclass -from email.mime import audio -from typing import Tuple, Union, Callable, Optional - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn - -from .timm_model import TimmModel -import logging -from .utils import freeze_batch_norm_2d - -from .pann_model import create_pann_model -from .htsat import create_htsat_model -from transformers import BertModel, RobertaModel, BartModel -from transformers.tokenization_utils_base import BatchEncoding - - -class MLPLayers(nn.Module): - def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1): - super(MLPLayers, self).__init__() - self.nonlin = nonlin - self.dropout = dropout - - sequence = [] - for u0, u1 in zip(units[:-1], units[1:]): - sequence.append(nn.Linear(u0, u1)) - sequence.append(self.nonlin) - sequence.append(nn.Dropout(self.dropout)) - sequence = sequence[:-2] - - self.sequential = nn.Sequential(*sequence) - - def forward(self, X): - X = self.sequential(X) - return X - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1): - super().__init__() - - # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 - self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - - self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - - self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() - - self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * self.expansion) - - self.relu = nn.ReLU(inplace=True) - self.downsample = None - self.stride = stride - - if stride > 1 or inplanes != planes * Bottleneck.expansion: - # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 - self.downsample = nn.Sequential( - OrderedDict( - [ - ("-1", nn.AvgPool2d(stride)), - ( - "0", - nn.Conv2d( - inplanes, - planes * self.expansion, - 1, - stride=1, - bias=False, - ), - ), - ("1", nn.BatchNorm2d(planes * self.expansion)), - ] - ) - ) - - def forward(self, x: torch.Tensor): - identity = x - - out = self.relu(self.bn1(self.conv1(x))) - out = self.relu(self.bn2(self.conv2(out))) - out = self.avgpool(out) - out = self.bn3(self.conv3(out)) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - return out - - -class AttentionPool2d(nn.Module): - def __init__( - self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None - ): - super().__init__() - self.positional_embedding = nn.Parameter( - torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5 - ) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x): - x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( - 2, 0, 1 - ) # NCHW -> (HW)NC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC - x, _ = F.multi_head_attention_forward( - query=x, - key=x, - value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat( - [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] - ), - bias_k=None, - bias_v=None, - add_zero_attn=False, - dropout_p=0, - out_proj_weight=self.c_proj.weight, - out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False, - ) - - return x[0] - - -class ModifiedResNet(nn.Module): - """ - A ResNet class that is similar to torchvision's but contains the following changes: - - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - - The final pooling layer is a QKV attention instead of an average pool - """ - - def __init__(self, layers, output_dim, heads, image_size=224, width=64): - super().__init__() - self.output_dim = output_dim - self.image_size = image_size - - # the 3-layer stem - self.conv1 = nn.Conv2d( - 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False - ) - self.bn1 = nn.BatchNorm2d(width // 2) - self.conv2 = nn.Conv2d( - width // 2, width // 2, kernel_size=3, padding=1, bias=False - ) - self.bn2 = nn.BatchNorm2d(width // 2) - self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) - self.bn3 = nn.BatchNorm2d(width) - self.avgpool = nn.AvgPool2d(2) - self.relu = nn.ReLU(inplace=True) - - # residual layers - self._inplanes = width # this is a *mutable* variable used during construction - self.layer1 = self._make_layer(width, layers[0]) - self.layer2 = self._make_layer(width * 2, layers[1], stride=2) - self.layer3 = self._make_layer(width * 4, layers[2], stride=2) - self.layer4 = self._make_layer(width * 8, layers[3], stride=2) - - embed_dim = width * 32 # the ResNet feature dimension - self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) - - self.init_parameters() - - def _make_layer(self, planes, blocks, stride=1): - layers = [Bottleneck(self._inplanes, planes, stride)] - - self._inplanes = planes * Bottleneck.expansion - for _ in range(1, blocks): - layers.append(Bottleneck(self._inplanes, planes)) - - return nn.Sequential(*layers) - - def init_parameters(self): - if self.attnpool is not None: - std = self.attnpool.c_proj.in_features**-0.5 - nn.init.normal_(self.attnpool.q_proj.weight, std=std) - nn.init.normal_(self.attnpool.k_proj.weight, std=std) - nn.init.normal_(self.attnpool.v_proj.weight, std=std) - nn.init.normal_(self.attnpool.c_proj.weight, std=std) - - for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: - for name, param in resnet_block.named_parameters(): - if name.endswith("bn3.weight"): - nn.init.zeros_(param) - - def lock(self, unlocked_groups=0, freeze_bn_stats=False): - assert ( - unlocked_groups == 0 - ), "partial locking not currently supported for this model" - for param in self.parameters(): - param.requires_grad = False - if freeze_bn_stats: - freeze_batch_norm_2d(self) - - def stem(self, x): - for conv, bn in [ - (self.conv1, self.bn1), - (self.conv2, self.bn2), - (self.conv3, self.bn3), - ]: - x = self.relu(bn(conv(x))) - x = self.avgpool(x) - return x - - def forward(self, x): - x = self.stem(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - x = self.attnpool(x) - - return x - - -class LayerNorm(nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16.""" - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - return x.to(orig_type) - - -class QuickGELU(nn.Module): - # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory - def forward(self, x: torch.Tensor): - return x * torch.sigmoid(1.702 * x) - - -class ResidualAttentionBlock(nn.Module): - def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU): - super().__init__() - - self.attn = nn.MultiheadAttention(d_model, n_head) - self.ln_1 = LayerNorm(d_model) - self.mlp = nn.Sequential( - OrderedDict( - [ - ("c_fc", nn.Linear(d_model, d_model * 4)), - ("gelu", act_layer()), 
- ("c_proj", nn.Linear(d_model * 4, d_model)), - ] - ) - ) - self.ln_2 = LayerNorm(d_model) - - def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] - - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) - x = x + self.mlp(self.ln_2(x)) - return x - - -class Transformer(nn.Module): - def __init__( - self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU - ): - super().__init__() - self.width = width - self.layers = layers - self.resblocks = nn.ModuleList( - [ - ResidualAttentionBlock(width, heads, act_layer=act_layer) - for _ in range(layers) - ] - ) - - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - for r in self.resblocks: - x = r(x, attn_mask=attn_mask) - return x - - -class VisualTransformer(nn.Module): - def __init__( - self, - image_size: int, - patch_size: int, - width: int, - layers: int, - heads: int, - output_dim: int, - act_layer: Callable = nn.GELU, - ): - super().__init__() - self.image_size = image_size - self.output_dim = output_dim - self.conv1 = nn.Conv2d( - in_channels=3, - out_channels=width, - kernel_size=patch_size, - stride=patch_size, - bias=False, - ) - - scale = width**-0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter( - scale * torch.randn((image_size // patch_size) ** 2 + 1, width) - ) - self.ln_pre = LayerNorm(width) - - self.text_branch = Transformer(width, layers, heads, act_layer=act_layer) - - self.ln_post = LayerNorm(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) - - def lock(self, unlocked_groups=0, freeze_bn_stats=False): - assert ( - unlocked_groups == 0 - ), "partial locking not currently supported for this model" - for param in self.parameters(): - param.requires_grad = False - - def forward(self, x: torch.Tensor): - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat( - [ - self.class_embedding.to(x.dtype) - + torch.zeros( - x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device - ), - x, - ], - dim=1, - ) # shape = [*, grid ** 2 + 1, width] - x = x + self.positional_embedding.to(x.dtype) - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - x = self.text_branch(x) - x = x.permute(1, 0, 2) # LND -> NLD - - x = self.ln_post(x[:, 0, :]) - - if self.proj is not None: - x = x @ self.proj - - return x - - -@dataclass -class CLAPVisionCfg: - layers: Union[Tuple[int, int, int, int], int] = 12 - width: int = 768 - patch_size: int = 16 - image_size: Union[Tuple[int, int], int] = 224 - timm_model_name: str = ( - None # a valid model name overrides layers, width, patch_size - ) - timm_model_pretrained: bool = ( - False # use (imagenet) pretrained weights for named model - ) - timm_pool: str = ( - "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') - ) - timm_proj: str = ( - "linear" # linear projection for timm model output ('linear', 'mlp', '') - ) - - -# Audio Config Class -@dataclass -class CLAPAudioCfp: - model_type: str = "PANN" - model_name: str = "Cnn14" - sample_rate: int = 48000 - # Param - audio_length: int = 1024 - window_size: int = 1024 - hop_size: int = 1024 - fmin: int = 50 - fmax: int = 14000 - class_num: int = 527 - mel_bins: int = 64 - 
clip_samples: int = 480000 - - -@dataclass -class CLAPTextCfg: - context_length: int - vocab_size: int - width: int - heads: int - layers: int - model_type: str - - -class CLAP(nn.Module): - def __init__( - self, - embed_dim: int, - audio_cfg: CLAPAudioCfp, - text_cfg: CLAPTextCfg, - quick_gelu: bool = False, - enable_fusion: bool = False, - fusion_type: str = "None", - joint_embed_shape: int = 512, - mlp_act: str = "relu", - ): - super().__init__() - if isinstance(audio_cfg, dict): - audio_cfg = CLAPAudioCfp(**audio_cfg) - if isinstance(text_cfg, dict): - text_cfg = CLAPTextCfg(**text_cfg) - - self.audio_cfg = audio_cfg - self.text_cfg = text_cfg - self.enable_fusion = enable_fusion - self.fusion_type = fusion_type - self.joint_embed_shape = joint_embed_shape - self.mlp_act = mlp_act - - self.context_length = text_cfg.context_length - - # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more - # memory efficient in recent PyTorch releases (>= 1.10). - # NOTE: timm models always use native GELU regardless of quick_gelu flag. - act_layer = QuickGELU if quick_gelu else nn.GELU - - if mlp_act == "relu": - mlp_act_layer = nn.ReLU() - elif mlp_act == "gelu": - mlp_act_layer = nn.GELU() - else: - raise NotImplementedError - - # audio branch - # audio branch parameters - if audio_cfg.model_type == "PANN": - self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type) - elif audio_cfg.model_type == "HTSAT": - self.audio_branch = create_htsat_model( - audio_cfg, enable_fusion, fusion_type - ) - else: - logging.error(f"Model config for {audio_cfg.model_type} not found") - raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.") - - # text branch - # text branch parameters - if text_cfg.model_type == "transformer": - self.text_branch = Transformer( - width=text_cfg.width, - layers=text_cfg.layers, - heads=text_cfg.heads, - act_layer=act_layer, - ) - self.vocab_size = text_cfg.vocab_size - self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width) - self.positional_embedding = nn.Parameter( - torch.empty(self.context_length, text_cfg.width) - ) - self.ln_final = LayerNorm(text_cfg.width) - self.text_transform = MLPLayers( - units=[ - self.joint_embed_shape, - self.joint_embed_shape, - self.joint_embed_shape, - ], - dropout=0.1, - ) - self.text_projection = nn.Sequential( - nn.Linear(text_cfg.width, self.joint_embed_shape), - mlp_act_layer, - nn.Linear(self.joint_embed_shape, self.joint_embed_shape), - ) - elif text_cfg.model_type == "bert": - self.text_branch = BertModel.from_pretrained("bert-base-uncased") - self.text_transform = MLPLayers( - units=[ - self.joint_embed_shape, - self.joint_embed_shape, - self.joint_embed_shape, - ], - dropout=0.1, - ) - self.text_projection = nn.Sequential( - nn.Linear(768, self.joint_embed_shape), - mlp_act_layer, - nn.Linear(self.joint_embed_shape, self.joint_embed_shape), - ) - elif text_cfg.model_type == "roberta": - self.text_branch = RobertaModel.from_pretrained("roberta-base") - self.text_transform = MLPLayers( - units=[ - self.joint_embed_shape, - self.joint_embed_shape, - self.joint_embed_shape, - ], - dropout=0.1, - ) - self.text_projection = nn.Sequential( - nn.Linear(768, self.joint_embed_shape), - mlp_act_layer, - nn.Linear(self.joint_embed_shape, self.joint_embed_shape), - ) - elif text_cfg.model_type == "bart": - self.text_branch = BartModel.from_pretrained("facebook/bart-base") - self.text_transform = MLPLayers( - units=[ - self.joint_embed_shape, - self.joint_embed_shape, - 
self.joint_embed_shape, - ], - dropout=0.1, - ) - self.text_projection = nn.Sequential( - nn.Linear(768, self.joint_embed_shape), - mlp_act_layer, - nn.Linear(self.joint_embed_shape, self.joint_embed_shape), - ) - else: - logging.error(f"Model config for {text_cfg.model_type} not found") - raise RuntimeError(f"Model config for {text_cfg.model_type} not found.") - self.text_branch_type = text_cfg.model_type - # text branch parameters - - # audio branch parameters - self.audio_transform = MLPLayers( - units=[ - self.joint_embed_shape, - self.joint_embed_shape, - self.joint_embed_shape, - ], - dropout=0.1, - ) - - # below here is text branch parameters - - # ============================================================================================================ - self.audio_projection = nn.Sequential( - nn.Linear(embed_dim, self.joint_embed_shape), - mlp_act_layer, - nn.Linear(self.joint_embed_shape, self.joint_embed_shape), - ) - - self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False) - - self.init_text_branch_parameters() - - def init_text_branch_parameters(self): - if self.text_branch_type == "transformer": - nn.init.normal_(self.token_embedding.weight, std=0.02) - nn.init.normal_(self.positional_embedding, std=0.01) - proj_std = (self.text_branch.width**-0.5) * ( - (2 * self.text_branch.layers) ** -0.5 - ) - attn_std = self.text_branch.width**-0.5 - fc_std = (2 * self.text_branch.width) ** -0.5 - for block in self.text_branch.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - if self.text_branch_type == "bert" or self.text_branch_type == "roberta": - width = self.text_branch.embeddings.word_embeddings.weight.shape[-1] - elif self.text_branch_type == "bart": - width = self.text_branch.shared.weight.shape[-1] - else: - width = self.text_branch.width - nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07)) - nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07)) - - # deprecated - # if hasattr(self.visual, 'init_parameters'): - # self.visual.init_parameters() - - # if self.text_projection is not None: - # nn.init.normal_(self.text_projection, std=width**-0.5) - - def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.context_length, self.context_length) - mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal - return mask - - def encode_audio(self, audio, device): - return self.audio_branch( - audio, mixup_lambda=None, device=device - ) # mix lambda needs to add - - # def list_of_dict_of_tensor2dict_of_tensor(self, x, device): - # tmp = {} - # for k in x[0].keys(): - # tmp[k] = [] - # for i in range(len(x)): - # tmp[k].append(x[i][k][:77]) - # for k in x[0].keys(): - # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True) - # return tmp - - def encode_text(self, text, device): - if self.text_branch_type == "transformer": - text = text.to(device=device, non_blocking=True) - x = self.token_embedding(text) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding - x = x.permute(1, 0, 2) # NLD -> LND - x = self.text_branch(x, 
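# --- Illustrative sketch: the additive causal mask built by build_attention_mask above ---
# -inf strictly above the diagonal blocks attention to future tokens, 0 elsewhere.
# A toy context length is used here; the real model uses text_cfg.context_length (77).
import torch

context_length = 5  # toy value
mask = torch.empty(context_length, context_length)
mask.fill_(float("-inf"))
mask.triu_(1)  # keep -inf strictly above the diagonal, zero the rest
print(mask)
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0.,   0., -inf, -inf, -inf],
#         ...])
# When added to attention scores before softmax, the -inf entries get zero
# probability, so token i can only attend to tokens <= i.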
attn_mask=self.attn_mask) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x) - - # x.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)]) - elif self.text_branch_type == "bert": - # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device) - # text = BatchEncoding(text) - x = self.text_branch( - input_ids=text["input_ids"].to(device=device, non_blocking=True), - attention_mask=text["attention_mask"].to( - device=device, non_blocking=True - ), - token_type_ids=text["token_type_ids"].to( - device=device, non_blocking=True - ), - )["pooler_output"] - x = self.text_projection(x) - elif self.text_branch_type == "roberta": - x = self.text_branch( - input_ids=text["input_ids"].to(device=device, non_blocking=True), - attention_mask=text["attention_mask"].to( - device=device, non_blocking=True - ), - )["pooler_output"] - x = self.text_projection(x) - elif self.text_branch_type == "bart": - x = torch.mean( - self.text_branch( - input_ids=text["input_ids"].to(device=device, non_blocking=True), - attention_mask=text["attention_mask"].to( - device=device, non_blocking=True - ), - )["encoder_last_hidden_state"], - axis=1, - ) - x = self.text_projection(x) - else: - logging.error(f"Model type {self.text_branch_type} not found") - raise RuntimeError(f"Model type {self.text_branch_type} not found.") - return x - - def forward(self, audio, text, device=None): - """Forward audio and text into the CLAP - - Parameters - ---------- - audio: torch.Tensor (batch_size, audio_length) - the time-domain audio input / the batch of mel_spec and longer list. - text: torch.Tensor () // need to add - the text token input - """ - if device is None: - if audio is not None: - device = audio.device - elif text is not None: - device = text.device - if audio is None and text is None: - # a hack to get the logit scale - return self.logit_scale_a.exp(), self.logit_scale_t.exp() - elif audio is None: - return self.encode_text(text, device=device) - elif text is None: - return self.audio_projection( - self.encode_audio(audio, device=device)["embedding"] - ) - audio_features = self.audio_projection( - self.encode_audio(audio, device=device)["embedding"] - ) - audio_features = F.normalize(audio_features, dim=-1) - - text_features = self.encode_text(text, device=device) - # print("text_features", text_features) - # print("text_features.shape", text_features.shape) - # print("text_features.type", type(text_features)) - text_features = F.normalize(text_features, dim=-1) - - audio_features_mlp = self.audio_transform(audio_features) - text_features_mlp = self.text_transform(text_features) - # Four outputs: audio features (basic & MLP), text features (basic & MLP) - return ( - audio_features, - text_features, - audio_features_mlp, - text_features_mlp, - self.logit_scale_a.exp(), - self.logit_scale_t.exp(), - ) - - def get_logit_scale(self): - return self.logit_scale_a.exp(), self.logit_scale_t.exp() - - def get_text_embedding(self, data): - """Get the text embedding from the model - - Parameters - ---------- - data: torch.Tensor - a tensor of text embedding - - Returns - ---------- - text_embed: torch.Tensor - a tensor of text_embeds (N, D) - - """ - device = next(self.parameters()).device - for k in data: - data[k] = data[k].to(device) - if(len(data[k].size()) < 2): - data[k] = data[k].unsqueeze(0) - text_embeds = self.encode_text(data, device=device) - 
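# --- Illustrative sketch: how the six outputs of CLAP.forward are typically consumed ---
# The contrastive loss itself is not part of the files in this diff; this only shows the
# role of the L2-normalized features and the exponentiated logit scales returned above,
# assuming the standard CLIP-style symmetric cross-entropy objective.
import torch
import torch.nn.functional as F

batch, dim = 4, 512
audio_features = F.normalize(torch.randn(batch, dim), dim=-1)
text_features = F.normalize(torch.randn(batch, dim), dim=-1)
logit_scale_a = torch.tensor(1 / 0.07)  # value of logit_scale_a.exp() at the init above

# Cosine-similarity logits between every audio clip and every caption.
logits_per_audio = logit_scale_a * audio_features @ text_features.t()
logits_per_text = logits_per_audio.t()

# Matching pairs sit on the diagonal, so the targets are 0..batch-1.
targets = torch.arange(batch)
loss = (F.cross_entropy(logits_per_audio, targets)
        + F.cross_entropy(logits_per_text, targets)) / 2
print(loss)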
text_embeds = F.normalize(text_embeds, dim=-1) - - return text_embeds - - def get_audio_embedding(self, data): - """Get the audio embedding from the model - - Parameters - ---------- - data: a list of dict - the audio input dict list from 'get_audio_feature' method - - Returns - ---------- - audio_embed: torch.Tensor - a tensor of audio_embeds (N, D) - - """ - device = next(self.parameters()).device - input_dict = {} - keys = data[0].keys() - for k in keys: - input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to( - device - ) - - audio_embeds = self.audio_projection( - self.encode_audio(input_dict, device=device)["embedding"] - ) - audio_embeds = F.normalize(audio_embeds, dim=-1) - - return audio_embeds - - def audio_infer(self, audio, hopsize=None, device=None): - """Forward one audio and produce the audio embedding - - Parameters - ---------- - audio: (audio_length) - the time-domain audio input, notice that it must be only one input - hopsize: int - the overlap hopsize as the sliding window - - Returns - ---------- - output_dict: { - key: [n, (embedding_shape)] if "HTS-AT" - or - key: [(embedding_shape)] if "PANN" - } - the list of key values of the audio branch - - """ - - assert not self.training, "the inference mode must be run at eval stage" - output_dict = {} - # PANN - if self.audio_cfg.model_type == "PANN": - audio_input = audio.unsqueeze(dim=0) - output_dict[key] = self.encode_audio(audio_input, device=device)[ - key - ].squeeze(dim=0) - elif self.audio_cfg.model_type == "HTSAT": - # repeat - audio_len = len(audio) - k = self.audio_cfg.clip_samples // audio_len - if k > 1: - audio = audio.repeat(k) - audio_len = len(audio) - - if hopsize is None: - hopsize = min(hopsize, audio_len) - - if audio_len > self.audio_cfg.clip_samples: - audio_input = [ - audio[pos : pos + self.audio_cfg.clip_samples].clone() - for pos in range( - 0, audio_len - self.audio_cfg.clip_samples, hopsize - ) - ] - audio_input.append(audio[-self.audio_cfg.clip_samples :].clone()) - audio_input = torch.stack(audio_input) - output_dict[key] = self.encode_audio(audio_input, device=device)[key] - else: - audio_input = audio.unsqueeze(dim=0) - output_dict[key] = self.encode_audio(audio_input, device=device)[ - key - ].squeeze(dim=0) - - return output_dict - - -def convert_weights_to_fp16(model: nn.Module): - """Convert applicable model parameters to fp16""" - - def _convert_weights_to_fp16(l): - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): - l.weight.data = l.weight.data.half() - if l.bias is not None: - l.bias.data = l.bias.data.half() - - if isinstance(l, nn.MultiheadAttention): - for attr in [ - *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], - "in_proj_bias", - "bias_k", - "bias_v", - ]: - tensor = getattr(l, attr) - if tensor is not None: - tensor.data = tensor.data.half() - - for name in ["text_projection", "proj"]: - if hasattr(l, name): - attr = getattr(l, name) - if attr is not None: - attr.data = attr.data.half() - - model.apply(_convert_weights_to_fp16) - - -# Ignore the state dict of the vision part -def build_model_from_openai_state_dict( - state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None" -): - - embed_dim = model_cfg["embed_dim"] - audio_cfg = model_cfg["audio_cfg"] - text_cfg = model_cfg["text_cfg"] - context_length = state_dict["positional_embedding"].shape[0] - vocab_size = state_dict["token_embedding.weight"].shape[0] - transformer_width = state_dict["ln_final.weight"].shape[0] - transformer_heads = transformer_width // 64 - 
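# --- Illustrative sketch: the sliding-window chunking used by audio_infer above (HTSAT) ---
# A waveform longer than clip_samples is cut into overlapping windows plus one final tail
# window, and each window is embedded separately. hopsize is passed explicitly here; the
# numbers are toy values, not the real 48 kHz / 480000-sample configuration.
import torch

clip_samples = 10
hopsize = 4
audio = torch.arange(23, dtype=torch.float32)  # pretend 23-sample waveform

windows = [
    audio[pos: pos + clip_samples].clone()
    for pos in range(0, len(audio) - clip_samples, hopsize)
]
windows.append(audio[-clip_samples:].clone())  # tail window covering the end
batch = torch.stack(windows)                   # (num_windows, clip_samples)
print(batch.shape)                             # torch.Size([5, 10])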
transformer_layers = len( - set( - k.split(".")[2] - for k in state_dict - if k.startswith(f"transformer.resblocks") - ) - ) - - audio_cfg = CLAPAudioCfp(**audio_cfg) - text_cfg = CLAPTextCfg(**text_cfg) - - model = CLAP( - embed_dim, - audio_cfg=audio_cfg, - text_cfg=text_cfg, - quick_gelu=True, # OpenAI models were trained with QuickGELU - enable_fusion=enable_fusion, - fusion_type=fusion_type, - ) - state_dict["logit_scale_a"] = state_dict["logit_scale"] - state_dict["logit_scale_t"] = state_dict["logit_scale"] - pop_keys = list(state_dict.keys())[::] - # pop the visual branch saved weights - for key in pop_keys: - if key.startswith("visual."): - state_dict.pop(key, None) - - for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]: - state_dict.pop(key, None) - - # not use fp16 - # convert_weights_to_fp16(model) - model.load_state_dict(state_dict, strict=False) - return model.eval() - - -def trace_model(model, batch_size=256, device=torch.device("cpu")): - model.eval() - audio_length = model.audio_cfg.audio_length - example_audio = torch.ones((batch_size, audio_length), device=device) - example_text = torch.zeros( - (batch_size, model.context_length), dtype=torch.int, device=device - ) - model = torch.jit.trace_module( - model, - inputs=dict( - forward=(example_audio, example_text), - encode_text=(example_text,), - encode_image=(example_audio,), - ), - ) - model.audio_cfg.audio_length = audio_length # Question: what does this do? - return model diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-base.json b/audioldm/clap/open_clip/model_configs/HTSAT-base.json deleted file mode 100644 index 6cef625a89daf4431f1c9f72e10bc9640eef2ba8..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/HTSAT-base.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 1024, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 480, - "fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "HTSAT", - "model_name": "base" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-large.json b/audioldm/clap/open_clip/model_configs/HTSAT-large.json deleted file mode 100644 index 699cdb1b16855582606551e4196b24aba2ffd871..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/HTSAT-large.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 2048, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 480, - "fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "HTSAT", - "model_name": "large" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json b/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json deleted file mode 100644 index 73e42990fe8361a0df502e7f93d29f19f58c9ecb..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 768, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1536, - "hop_size": 480, - "fmin": 50, - "fmax": 
14000, - "class_num": 527, - "model_type": "HTSAT", - "model_name": "tiny" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json b/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json deleted file mode 100644 index a6e7821163d9afa81c27345a1e472475b92af169..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 768, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 480, - "fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "HTSAT", - "model_name": "tiny" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-10.json b/audioldm/clap/open_clip/model_configs/PANN-10.json deleted file mode 100644 index 954ddf62921aed7dde9c37ffffec98a2e96a4ee7..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/PANN-10.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 1024, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 480, - "fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "PANN", - "model_name": "Cnn10" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json deleted file mode 100644 index b7989bc0cd95d0d39049b7524eba508b3e386439..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 2048, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 480, - "fmin": 50, - "fmax": 18000, - "class_num": 527, - "model_type": "PANN", - "model_name": "Cnn14" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json deleted file mode 100644 index 56bdb56bedc304ffa52d8bf5988cea2c1d82d14e..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 2048, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 960000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 360, - "fmin": 50, - "fmax": 8000, - "class_num": 527, - "model_type": "PANN", - "model_name": "Cnn14" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json b/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json deleted file mode 100644 index 5756e3bebc97cc985f512cb081930fee4e49bec1..0000000000000000000000000000000000000000 --- 
a/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 2048, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 480, - "fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "PANN", - "model_name": "Cnn14" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 4 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json b/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json deleted file mode 100644 index 5a9e7e208b661619d5e26625e849da1adda8a475..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 2048, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1536, - "hop_size": 480, - "fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "PANN", - "model_name": "Cnn14" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14.json b/audioldm/clap/open_clip/model_configs/PANN-14.json deleted file mode 100644 index 39a5134cde1d8c50f4758377c952ef22f07bab41..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/PANN-14.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 2048, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 480, - "fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "PANN", - "model_name": "Cnn14" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-6.json b/audioldm/clap/open_clip/model_configs/PANN-6.json deleted file mode 100644 index 21ebc344326de260c386ba77e0ad63cf9b04febf..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/PANN-6.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "embed_dim": 512, - "audio_cfg": { - "audio_length": 1024, - "clip_samples": 480000, - "mel_bins": 64, - "sample_rate": 48000, - "window_size": 1024, - "hop_size": 480, - "fmin": 50, - "fmax": 14000, - "class_num": 527, - "model_type": "PANN", - "model_name": "Cnn6" - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json b/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json deleted file mode 100644 index d0db2c161d13138788c4609d373b023b8454d624..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "embed_dim": 512, - "quick_gelu": true, - "vision_cfg": { - "image_size": 224, - "layers": [ - 3, - 4, - 23, - 3 - ], - "width": 64, - "patch_size": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN101.json b/audioldm/clap/open_clip/model_configs/RN101.json deleted file mode 100644 
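# --- Illustrative sketch: what the audio_cfg numbers in the JSON configs above imply ---
# Values copied from the HTSAT / PANN configs (48 kHz, 480000-sample clips, hop_size 480);
# the frame count is the usual estimate for a centered STFT.
sample_rate = 48000
clip_samples = 480000
hop_size = 480
mel_bins = 64

clip_seconds = clip_samples / sample_rate   # 10.0 s per training clip
frames = clip_samples // hop_size + 1       # ~1001 log-mel frames per clip
print(clip_seconds, frames, mel_bins)       # 10.0 1001 64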
index b88b4d3acbaa701c614ab0ea65fc88fcfe289c32..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/RN101.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "image_size": 224, - "layers": [ - 3, - 4, - 23, - 3 - ], - "width": 64, - "patch_size": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json b/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json deleted file mode 100644 index 8c2f91260cdeb043434dc1e893cce81d4ce7f0d1..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "embed_dim": 1024, - "quick_gelu": true, - "vision_cfg": { - "image_size": 224, - "layers": [ - 3, - 4, - 6, - 3 - ], - "width": 64, - "patch_size": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} diff --git a/audioldm/clap/open_clip/model_configs/RN50.json b/audioldm/clap/open_clip/model_configs/RN50.json deleted file mode 100644 index 33aa884d54fee0076c33676831e49d5e1ffcb8f2..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/RN50.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "embed_dim": 1024, - "vision_cfg": { - "image_size": 224, - "layers": [ - 3, - 4, - 6, - 3 - ], - "width": 64, - "patch_size": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN50x16.json b/audioldm/clap/open_clip/model_configs/RN50x16.json deleted file mode 100644 index 3161e1a2c9a839161e652a4d729c2cdc971161db..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/RN50x16.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 384, - "layers": [ - 6, - 8, - 18, - 8 - ], - "width": 96, - "patch_size": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN50x4.json b/audioldm/clap/open_clip/model_configs/RN50x4.json deleted file mode 100644 index e155237f8ce1026aaaeecc80751eabe6f329f0bb..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/RN50x4.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "embed_dim": 640, - "vision_cfg": { - "image_size": 288, - "layers": [ - 4, - 6, - 10, - 6 - ], - "width": 80, - "patch_size": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 640, - "heads": 10, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-16.json b/audioldm/clap/open_clip/model_configs/ViT-B-16.json deleted file mode 100644 index 395eea77ec3907c0611531aba63459b193e67b9c..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/ViT-B-16.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 768, - "patch_size": 16 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json 
b/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json deleted file mode 100644 index ce6bd923593293ed50dfcfb28b73ca7403bcf3c5..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "embed_dim": 512, - "quick_gelu": true, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 768, - "patch_size": 32 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-32.json b/audioldm/clap/open_clip/model_configs/ViT-B-32.json deleted file mode 100644 index 07c8e28eb06fa1813ba932fe4eec668262d1c47f..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/ViT-B-32.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 768, - "patch_size": 32 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/ViT-L-14.json b/audioldm/clap/open_clip/model_configs/ViT-L-14.json deleted file mode 100644 index d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/model_configs/ViT-L-14.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 224, - "layers": 24, - "width": 1024, - "patch_size": 14 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - } -} \ No newline at end of file diff --git a/audioldm/clap/open_clip/openai.py b/audioldm/clap/open_clip/openai.py deleted file mode 100644 index 3f4eb8b55fe960e1792b3da804b60b3d8f70fe26..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/openai.py +++ /dev/null @@ -1,156 +0,0 @@ -""" OpenAI pretrained model functions - -Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. -""" - -import os -import warnings -from typing import Union, List - -import torch - -from .model import build_model_from_openai_state_dict -from .pretrained import ( - get_pretrained_url, - list_pretrained_tag_models, - download_pretrained, -) - -__all__ = ["list_openai_models", "load_openai_model"] - - -def list_openai_models() -> List[str]: - """Returns the names of available CLIP models""" - return list_pretrained_tag_models("openai") - - -def load_openai_model( - name: str, - model_cfg, - device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", - jit=True, - cache_dir=os.path.expanduser("~/.cache/clip"), - enable_fusion: bool = False, - fusion_type: str = "None", -): - """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model - - Parameters - ---------- - name : str - A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict - device : Union[str, torch.device] - The device to put the loaded model - jit : bool - Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
- - Returns - ------- - model : torch.nn.Module - The CLAP model - preprocess : Callable[[PIL.Image], torch.Tensor] - A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input - """ - if get_pretrained_url(name, "openai"): - model_path = download_pretrained( - get_pretrained_url(name, "openai"), root=cache_dir - ) - elif os.path.isfile(name): - model_path = name - else: - raise RuntimeError( - f"Model {name} not found; available models = {list_openai_models()}" - ) - - try: - # loading JIT archive - model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() - state_dict = None - except RuntimeError: - # loading saved state dict - if jit: - warnings.warn( - f"File {model_path} is not a JIT archive. Loading as a state dict instead" - ) - jit = False - state_dict = torch.load(model_path, map_location="cpu") - - if not jit: - try: - model = build_model_from_openai_state_dict( - state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type - ).to(device) - except KeyError: - sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} - model = build_model_from_openai_state_dict( - sd, model_cfg, enable_fusion, fusion_type - ).to(device) - - if str(device) == "cpu": - model.float() - return model - - # patch the device names - device_holder = torch.jit.trace( - lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] - ) - device_node = [ - n - for n in device_holder.graph.findAllNodes("prim::Constant") - if "Device" in repr(n) - ][-1] - - def patch_device(module): - try: - graphs = [module.graph] if hasattr(module, "graph") else [] - except RuntimeError: - graphs = [] - - if hasattr(module, "forward1"): - graphs.append(module.forward1.graph) - - for graph in graphs: - for node in graph.findAllNodes("prim::Constant"): - if "value" in node.attributeNames() and str(node["value"]).startswith( - "cuda" - ): - node.copyAttributes(device_node) - - model.apply(patch_device) - patch_device(model.encode_audio) - patch_device(model.encode_text) - - # patch dtype to float32 on CPU - if str(device) == "cpu": - float_holder = torch.jit.trace( - lambda: torch.ones([]).float(), example_inputs=[] - ) - float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] - float_node = float_input.node() - - def patch_float(module): - try: - graphs = [module.graph] if hasattr(module, "graph") else [] - except RuntimeError: - graphs = [] - - if hasattr(module, "forward1"): - graphs.append(module.forward1.graph) - - for graph in graphs: - for node in graph.findAllNodes("aten::to"): - inputs = list(node.inputs()) - for i in [ - 1, - 2, - ]: # dtype can be the second or third argument to aten::to() - if inputs[i].node()["value"] == 5: - inputs[i].node().copyAttributes(float_node) - - model.apply(patch_float) - patch_float(model.encode_audio) - patch_float(model.encode_text) - model.float() - - model.audio_branch.audio_length = model.audio_cfg.audio_length - return model diff --git a/audioldm/clap/open_clip/pann_model.py b/audioldm/clap/open_clip/pann_model.py deleted file mode 100644 index 874a03fc6eabcfdf3a63c59ca1e05d4f991453c5..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/pann_model.py +++ /dev/null @@ -1,703 +0,0 @@ -# PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition -# Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn -# Some layers are re-designed for CLAP -import os - -os.environ["NUMBA_CACHE_DIR"] = "/tmp/" - -import torch -import 
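# --- Illustrative sketch: the checkpoint-loading fallback used by load_openai_model above ---
# Try to open the file as a TorchScript archive first, and fall back to a plain state dict
# if that fails; DDP-style checkpoints that nest weights under "state_dict" with a "module."
# prefix are flattened, as in the KeyError branch above. A minimal sketch, not the original
# function.
import warnings
import torch

def load_checkpoint(model_path: str):
    """Return a flat state dict, whether model_path is a JIT archive or a torch.save file."""
    try:
        scripted = torch.jit.load(model_path, map_location="cpu").eval()
        return scripted.state_dict()
    except RuntimeError:
        warnings.warn(f"{model_path} is not a JIT archive; loading as a state dict instead")
        checkpoint = torch.load(model_path, map_location="cpu")
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = {k[len("module."):]: v for k, v in checkpoint["state_dict"].items()}
    return checkpoint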
torch.nn as nn -import torch.nn.functional as F -from torchlibrosa.stft import Spectrogram, LogmelFilterBank -from torchlibrosa.augmentation import SpecAugmentation - -from .utils import do_mixup, interpolate, pad_framewise_output -from .feature_fusion import iAFF, AFF, DAF - - -def init_layer(layer): - """Initialize a Linear or Convolutional layer.""" - nn.init.xavier_uniform_(layer.weight) - - if hasattr(layer, "bias"): - if layer.bias is not None: - layer.bias.data.fill_(0.0) - -def init_bn(bn): - """Initialize a Batchnorm layer.""" - bn.bias.data.fill_(0.0) - bn.weight.data.fill_(1.0) - - -class ConvBlock(nn.Module): - def __init__(self, in_channels, out_channels): - - super(ConvBlock, self).__init__() - - self.conv1 = nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=(3, 3), - stride=(1, 1), - padding=(1, 1), - bias=False, - ) - - self.conv2 = nn.Conv2d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=(3, 3), - stride=(1, 1), - padding=(1, 1), - bias=False, - ) - - self.bn1 = nn.BatchNorm2d(out_channels) - self.bn2 = nn.BatchNorm2d(out_channels) - - self.init_weight() - - def init_weight(self): - init_layer(self.conv1) - init_layer(self.conv2) - init_bn(self.bn1) - init_bn(self.bn2) - - def forward(self, input, pool_size=(2, 2), pool_type="avg"): - - x = input - x = F.relu_(self.bn1(self.conv1(x))) - x = F.relu_(self.bn2(self.conv2(x))) - if pool_type == "max": - x = F.max_pool2d(x, kernel_size=pool_size) - elif pool_type == "avg": - x = F.avg_pool2d(x, kernel_size=pool_size) - elif pool_type == "avg+max": - x1 = F.avg_pool2d(x, kernel_size=pool_size) - x2 = F.max_pool2d(x, kernel_size=pool_size) - x = x1 + x2 - else: - raise Exception("Incorrect argument!") - - return x - - -class ConvBlock5x5(nn.Module): - def __init__(self, in_channels, out_channels): - - super(ConvBlock5x5, self).__init__() - - self.conv1 = nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=(5, 5), - stride=(1, 1), - padding=(2, 2), - bias=False, - ) - - self.bn1 = nn.BatchNorm2d(out_channels) - - self.init_weight() - - def init_weight(self): - init_layer(self.conv1) - init_bn(self.bn1) - - def forward(self, input, pool_size=(2, 2), pool_type="avg"): - - x = input - x = F.relu_(self.bn1(self.conv1(x))) - if pool_type == "max": - x = F.max_pool2d(x, kernel_size=pool_size) - elif pool_type == "avg": - x = F.avg_pool2d(x, kernel_size=pool_size) - elif pool_type == "avg+max": - x1 = F.avg_pool2d(x, kernel_size=pool_size) - x2 = F.max_pool2d(x, kernel_size=pool_size) - x = x1 + x2 - else: - raise Exception("Incorrect argument!") - - return x - - -class AttBlock(nn.Module): - def __init__(self, n_in, n_out, activation="linear", temperature=1.0): - super(AttBlock, self).__init__() - - self.activation = activation - self.temperature = temperature - self.att = nn.Conv1d( - in_channels=n_in, - out_channels=n_out, - kernel_size=1, - stride=1, - padding=0, - bias=True, - ) - self.cla = nn.Conv1d( - in_channels=n_in, - out_channels=n_out, - kernel_size=1, - stride=1, - padding=0, - bias=True, - ) - - self.bn_att = nn.BatchNorm1d(n_out) - self.init_weights() - - def init_weights(self): - init_layer(self.att) - init_layer(self.cla) - init_bn(self.bn_att) - - def forward(self, x): - # x: (n_samples, n_in, n_time) - norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) - cla = self.nonlinear_transform(self.cla(x)) - x = torch.sum(norm_att * cla, dim=2) - return x, norm_att, cla - - def nonlinear_transform(self, x): - if 
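# --- Illustrative sketch: the attention pooling done by AttBlock.forward above ---
# Per-class attention weights over time frames (softmax along the time axis) average the
# per-frame class activations into a single clip-level prediction. Shapes are toy values.
import torch

n_samples, n_out, n_time = 2, 527, 32
att_logits = torch.randn(n_samples, n_out, n_time)                    # "att" Conv1d output
frame_scores = torch.sigmoid(torch.randn(n_samples, n_out, n_time))   # "cla" branch output

norm_att = torch.softmax(torch.clamp(att_logits, -10, 10), dim=-1)
clipwise = torch.sum(norm_att * frame_scores, dim=2)  # (n_samples, n_out)
print(clipwise.shape)                                  # torch.Size([2, 527])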
self.activation == "linear": - return x - elif self.activation == "sigmoid": - return torch.sigmoid(x) - - -class Cnn14(nn.Module): - def __init__( - self, - sample_rate, - window_size, - hop_size, - mel_bins, - fmin, - fmax, - classes_num, - enable_fusion=False, - fusion_type="None", - ): - - super(Cnn14, self).__init__() - - window = "hann" - center = True - pad_mode = "reflect" - ref = 1.0 - amin = 1e-10 - top_db = None - - self.enable_fusion = enable_fusion - self.fusion_type = fusion_type - - # Spectrogram extractor - self.spectrogram_extractor = Spectrogram( - n_fft=window_size, - hop_length=hop_size, - win_length=window_size, - window=window, - center=center, - pad_mode=pad_mode, - freeze_parameters=True, - ) - - # Logmel feature extractor - self.logmel_extractor = LogmelFilterBank( - sr=sample_rate, - n_fft=window_size, - n_mels=mel_bins, - fmin=fmin, - fmax=fmax, - ref=ref, - amin=amin, - top_db=top_db, - freeze_parameters=True, - ) - - # Spec augmenter - self.spec_augmenter = SpecAugmentation( - time_drop_width=64, - time_stripes_num=2, - freq_drop_width=8, - freq_stripes_num=2, - ) - - self.bn0 = nn.BatchNorm2d(64) - - if (self.enable_fusion) and (self.fusion_type == "channel_map"): - self.conv_block1 = ConvBlock(in_channels=4, out_channels=64) - else: - self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) - self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) - self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) - self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) - self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) - self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) - - self.fc1 = nn.Linear(2048, 2048, bias=True) - self.fc_audioset = nn.Linear(2048, classes_num, bias=True) - - if (self.enable_fusion) and ( - self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"] - ): - self.mel_conv1d = nn.Sequential( - nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2), - nn.BatchNorm1d(64), # No Relu - ) - if self.fusion_type == "daf_1d": - self.fusion_model = DAF() - elif self.fusion_type == "aff_1d": - self.fusion_model = AFF(channels=64, type="1D") - elif self.fusion_type == "iaff_1d": - self.fusion_model = iAFF(channels=64, type="1D") - - if (self.enable_fusion) and ( - self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] - ): - self.mel_conv2d = nn.Sequential( - nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)), - nn.BatchNorm2d(64), - nn.ReLU(inplace=True), - ) - - if self.fusion_type == "daf_2d": - self.fusion_model = DAF() - elif self.fusion_type == "aff_2d": - self.fusion_model = AFF(channels=64, type="2D") - elif self.fusion_type == "iaff_2d": - self.fusion_model = iAFF(channels=64, type="2D") - self.init_weight() - - def init_weight(self): - init_bn(self.bn0) - init_layer(self.fc1) - init_layer(self.fc_audioset) - - def forward(self, input, mixup_lambda=None, device=None): - """ - Input: (batch_size, data_length)""" - - if self.enable_fusion and input["longer"].sum() == 0: - # if no audio is longer than 10s, then randomly select one audio to be longer - input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True - - if not self.enable_fusion: - x = self.spectrogram_extractor( - input["waveform"].to(device=device, non_blocking=True) - ) # (batch_size, 1, time_steps, freq_bins) - x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) - - x = x.transpose(1, 3) - x = self.bn0(x) - x = x.transpose(1, 3) - else: - longer_list = input["longer"].to(device=device, 
non_blocking=True) - x = input["mel_fusion"].to(device=device, non_blocking=True) - longer_list_idx = torch.where(longer_list)[0] - x = x.transpose(1, 3) - x = self.bn0(x) - x = x.transpose(1, 3) - if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]: - new_x = x[:, 0:1, :, :].clone().contiguous() - # local processing - if len(longer_list_idx) > 0: - fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous() - FB, FC, FT, FF = fusion_x_local.size() - fusion_x_local = fusion_x_local.view(FB * FC, FT, FF) - fusion_x_local = torch.permute( - fusion_x_local, (0, 2, 1) - ).contiguous() - fusion_x_local = self.mel_conv1d(fusion_x_local) - fusion_x_local = fusion_x_local.view( - FB, FC, FF, fusion_x_local.size(-1) - ) - fusion_x_local = ( - torch.permute(fusion_x_local, (0, 2, 1, 3)) - .contiguous() - .flatten(2) - ) - if fusion_x_local.size(-1) < FT: - fusion_x_local = torch.cat( - [ - fusion_x_local, - torch.zeros( - (FB, FF, FT - fusion_x_local.size(-1)), - device=device, - ), - ], - dim=-1, - ) - else: - fusion_x_local = fusion_x_local[:, :, :FT] - # 1D fusion - new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous() - new_x[longer_list_idx] = self.fusion_model( - new_x[longer_list_idx], fusion_x_local - ) - x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :] - else: - x = new_x - elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]: - x = x # no change - - if self.training: - x = self.spec_augmenter(x) - # Mixup on spectrogram - if self.training and mixup_lambda is not None: - x = do_mixup(x, mixup_lambda) - if (self.enable_fusion) and ( - self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] - ): - global_x = x[:, 0:1, :, :] - - # global processing - B, C, H, W = global_x.shape - global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg") - if len(longer_list_idx) > 0: - local_x = x[longer_list_idx, 1:, :, :].contiguous() - TH = global_x.size(-2) - # local processing - B, C, H, W = local_x.shape - local_x = local_x.view(B * C, 1, H, W) - local_x = self.mel_conv2d(local_x) - local_x = local_x.view( - B, C, local_x.size(1), local_x.size(2), local_x.size(3) - ) - local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3) - TB, TC, _, TW = local_x.size() - if local_x.size(-2) < TH: - local_x = torch.cat( - [ - local_x, - torch.zeros( - (TB, TC, TH - local_x.size(-2), TW), - device=global_x.device, - ), - ], - dim=-2, - ) - else: - local_x = local_x[:, :, :TH, :] - - global_x[longer_list_idx] = self.fusion_model( - global_x[longer_list_idx], local_x - ) - x = global_x - else: - x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") - - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = torch.mean(x, dim=3) - - latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) - latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) - latent_x = latent_x1 + latent_x2 - latent_x = latent_x.transpose(1, 2) - latent_x = F.relu_(self.fc1(latent_x)) - latent_output = interpolate(latent_x, 
32) - - (x1, _) = torch.max(x, dim=2) - x2 = torch.mean(x, dim=2) - x = x1 + x2 - x = F.dropout(x, p=0.5, training=self.training) - x = F.relu_(self.fc1(x)) - embedding = F.dropout(x, p=0.5, training=self.training) - clipwise_output = torch.sigmoid(self.fc_audioset(x)) - - output_dict = { - "clipwise_output": clipwise_output, - "embedding": embedding, - "fine_grained_embedding": latent_output, - } - return output_dict - - -class Cnn6(nn.Module): - def __init__( - self, - sample_rate, - window_size, - hop_size, - mel_bins, - fmin, - fmax, - classes_num, - enable_fusion=False, - fusion_type="None", - ): - - super(Cnn6, self).__init__() - - window = "hann" - center = True - pad_mode = "reflect" - ref = 1.0 - amin = 1e-10 - top_db = None - - self.enable_fusion = enable_fusion - self.fusion_type = fusion_type - - # Spectrogram extractor - self.spectrogram_extractor = Spectrogram( - n_fft=window_size, - hop_length=hop_size, - win_length=window_size, - window=window, - center=center, - pad_mode=pad_mode, - freeze_parameters=True, - ) - - # Logmel feature extractor - self.logmel_extractor = LogmelFilterBank( - sr=sample_rate, - n_fft=window_size, - n_mels=mel_bins, - fmin=fmin, - fmax=fmax, - ref=ref, - amin=amin, - top_db=top_db, - freeze_parameters=True, - ) - - # Spec augmenter - self.spec_augmenter = SpecAugmentation( - time_drop_width=64, - time_stripes_num=2, - freq_drop_width=8, - freq_stripes_num=2, - ) - - self.bn0 = nn.BatchNorm2d(64) - - self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64) - self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128) - self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256) - self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512) - - self.fc1 = nn.Linear(512, 512, bias=True) - self.fc_audioset = nn.Linear(512, classes_num, bias=True) - - self.init_weight() - - def init_weight(self): - init_bn(self.bn0) - init_layer(self.fc1) - init_layer(self.fc_audioset) - - def forward(self, input, mixup_lambda=None, device=None): - """ - Input: (batch_size, data_length)""" - - x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) - x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) - - x = x.transpose(1, 3) - x = self.bn0(x) - x = x.transpose(1, 3) - - if self.training: - x = self.spec_augmenter(x) - - # Mixup on spectrogram - if self.training and mixup_lambda is not None: - x = do_mixup(x, mixup_lambda) - - x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = torch.mean(x, dim=3) - - latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) - latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) - latent_x = latent_x1 + latent_x2 - latent_x = latent_x.transpose(1, 2) - latent_x = F.relu_(self.fc1(latent_x)) - latent_output = interpolate(latent_x, 16) - - (x1, _) = torch.max(x, dim=2) - x2 = torch.mean(x, dim=2) - x = x1 + x2 - x = F.dropout(x, p=0.5, training=self.training) - x = F.relu_(self.fc1(x)) - embedding = F.dropout(x, p=0.5, training=self.training) - clipwise_output = torch.sigmoid(self.fc_audioset(x)) - - output_dict = { - "clipwise_output": clipwise_output, - 
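# --- Illustrative sketch: the spectrogram -> log-mel front end shared by Cnn14/Cnn10/Cnn6 above ---
# Shown only to make the tensor shapes concrete; parameters mirror the 48 kHz configs and the
# snippet requires the torchlibrosa package used by the original code.
import torch
from torchlibrosa.stft import Spectrogram, LogmelFilterBank

spectrogram_extractor = Spectrogram(
    n_fft=1024, hop_length=480, win_length=1024,
    window="hann", center=True, pad_mode="reflect", freeze_parameters=True,
)
logmel_extractor = LogmelFilterBank(
    sr=48000, n_fft=1024, n_mels=64, fmin=50, fmax=14000,
    ref=1.0, amin=1e-10, top_db=None, freeze_parameters=True,
)

waveform = torch.randn(2, 48000)       # (batch_size, data_length), 1 s of audio
x = spectrogram_extractor(waveform)    # (batch_size, 1, time_steps, freq_bins)
x = logmel_extractor(x)                # (batch_size, 1, time_steps, mel_bins=64)
print(x.shape)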
"embedding": embedding, - "fine_grained_embedding": latent_output, - } - - return output_dict - - -class Cnn10(nn.Module): - def __init__( - self, - sample_rate, - window_size, - hop_size, - mel_bins, - fmin, - fmax, - classes_num, - enable_fusion=False, - fusion_type="None", - ): - - super(Cnn10, self).__init__() - - window = "hann" - center = True - pad_mode = "reflect" - ref = 1.0 - amin = 1e-10 - top_db = None - - self.enable_fusion = enable_fusion - self.fusion_type = fusion_type - - # Spectrogram extractor - self.spectrogram_extractor = Spectrogram( - n_fft=window_size, - hop_length=hop_size, - win_length=window_size, - window=window, - center=center, - pad_mode=pad_mode, - freeze_parameters=True, - ) - - # Logmel feature extractor - self.logmel_extractor = LogmelFilterBank( - sr=sample_rate, - n_fft=window_size, - n_mels=mel_bins, - fmin=fmin, - fmax=fmax, - ref=ref, - amin=amin, - top_db=top_db, - freeze_parameters=True, - ) - - # Spec augmenter - self.spec_augmenter = SpecAugmentation( - time_drop_width=64, - time_stripes_num=2, - freq_drop_width=8, - freq_stripes_num=2, - ) - - self.bn0 = nn.BatchNorm2d(64) - - self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) - self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) - self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) - self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) - self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) - - self.fc1 = nn.Linear(1024, 1024, bias=True) - self.fc_audioset = nn.Linear(1024, classes_num, bias=True) - - self.init_weight() - - def init_weight(self): - init_bn(self.bn0) - init_layer(self.fc1) - init_layer(self.fc_audioset) - - def forward(self, input, mixup_lambda=None, device=None): - """ - Input: (batch_size, data_length)""" - - x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) - x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) - - x = x.transpose(1, 3) - x = self.bn0(x) - x = x.transpose(1, 3) - - if self.training: - x = self.spec_augmenter(x) - - # Mixup on spectrogram - if self.training and mixup_lambda is not None: - x = do_mixup(x, mixup_lambda) - - x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") - x = F.dropout(x, p=0.2, training=self.training) - x = torch.mean(x, dim=3) - - latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) - latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) - latent_x = latent_x1 + latent_x2 - latent_x = latent_x.transpose(1, 2) - latent_x = F.relu_(self.fc1(latent_x)) - latent_output = interpolate(latent_x, 32) - - (x1, _) = torch.max(x, dim=2) - x2 = torch.mean(x, dim=2) - x = x1 + x2 - x = F.dropout(x, p=0.5, training=self.training) - x = F.relu_(self.fc1(x)) - embedding = F.dropout(x, p=0.5, training=self.training) - clipwise_output = torch.sigmoid(self.fc_audioset(x)) - - output_dict = { - "clipwise_output": clipwise_output, - "embedding": embedding, - "fine_grained_embedding": latent_output, - } - - return output_dict - - -def create_pann_model(audio_cfg, enable_fusion=False, 
fusion_type="None"): - try: - ModelProto = eval(audio_cfg.model_name) - model = ModelProto( - sample_rate=audio_cfg.sample_rate, - window_size=audio_cfg.window_size, - hop_size=audio_cfg.hop_size, - mel_bins=audio_cfg.mel_bins, - fmin=audio_cfg.fmin, - fmax=audio_cfg.fmax, - classes_num=audio_cfg.class_num, - enable_fusion=enable_fusion, - fusion_type=fusion_type, - ) - return model - except: - raise RuntimeError( - f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough." - ) diff --git a/audioldm/clap/open_clip/pretrained.py b/audioldm/clap/open_clip/pretrained.py deleted file mode 100644 index e211d8b5b59320a599e62605f1dee6199f317253..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/pretrained.py +++ /dev/null @@ -1,167 +0,0 @@ -import hashlib -import os -import urllib -import warnings - -from tqdm import tqdm - -_RN50 = dict( - openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", - yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", - cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", -) - -_RN50_quickgelu = dict( - openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", - yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", - cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", -) - -_RN101 = dict( - openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", - yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", -) - -_RN101_quickgelu = dict( - openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", - yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", -) - -_RN50x4 = dict( - openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", -) - -_RN50x16 = dict( - openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", -) - -_RN50x64 = dict( - openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", -) - -_VITB32 = dict( - openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", - laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", - laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", - laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", -) - -_VITB32_quickgelu = dict( - openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", - 
laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", - laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", - laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", -) - -_VITB16 = dict( - openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", -) - -_VITL14 = dict( - openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", -) - -_PRETRAINED = { - "RN50": _RN50, - "RN50-quickgelu": _RN50_quickgelu, - "RN101": _RN101, - "RN101-quickgelu": _RN101_quickgelu, - "RN50x4": _RN50x4, - "RN50x16": _RN50x16, - "ViT-B-32": _VITB32, - "ViT-B-32-quickgelu": _VITB32_quickgelu, - "ViT-B-16": _VITB16, - "ViT-L-14": _VITL14, -} - - -def list_pretrained(as_str: bool = False): - """returns list of pretrained models - Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True - """ - return [ - ":".join([k, t]) if as_str else (k, t) - for k in _PRETRAINED.keys() - for t in _PRETRAINED[k].keys() - ] - - -def list_pretrained_tag_models(tag: str): - """return all models having the specified pretrain tag""" - models = [] - for k in _PRETRAINED.keys(): - if tag in _PRETRAINED[k]: - models.append(k) - return models - - -def list_pretrained_model_tags(model: str): - """return all pretrain tags for the specified model architecture""" - tags = [] - if model in _PRETRAINED: - tags.extend(_PRETRAINED[model].keys()) - return tags - - -def get_pretrained_url(model: str, tag: str): - if model not in _PRETRAINED: - return "" - model_pretrained = _PRETRAINED[model] - if tag not in model_pretrained: - return "" - return model_pretrained[tag] - - -def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")): - os.makedirs(root, exist_ok=True) - filename = os.path.basename(url) - - if "openaipublic" in url: - expected_sha256 = url.split("/")[-2] - else: - expected_sha256 = "" - - download_target = os.path.join(root, filename) - - if os.path.exists(download_target) and not os.path.isfile(download_target): - raise RuntimeError(f"{download_target} exists and is not a regular file") - - if os.path.isfile(download_target): - if expected_sha256: - if ( - hashlib.sha256(open(download_target, "rb").read()).hexdigest() - == expected_sha256 - ): - return download_target - else: - warnings.warn( - f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" - ) - else: - return download_target - - with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: - with tqdm( - total=int(source.info().get("Content-Length")), - ncols=80, - unit="iB", - unit_scale=True, - ) as loop: - while True: - buffer = source.read(8192) - if not buffer: - break - - output.write(buffer) - loop.update(len(buffer)) - - if ( - expected_sha256 - and hashlib.sha256(open(download_target, "rb").read()).hexdigest() - != expected_sha256 - ): - raise RuntimeError( - f"Model has been downloaded but the SHA256 checksum does not not match" - ) - - return download_target diff --git a/audioldm/clap/open_clip/timm_model.py b/audioldm/clap/open_clip/timm_model.py deleted file mode 100644 index c9d1ab4666b5bab5038d44b90c9ddca5087de460..0000000000000000000000000000000000000000 --- 
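# --- Illustrative sketch: the integrity check performed by download_pretrained above ---
# For OpenAI-hosted checkpoints the expected SHA-256 digest is embedded in the URL path
# (second-to-last segment); the downloaded file is hashed and compared against it.
import hashlib

def sha256_matches(path: str, url: str) -> bool:
    """Compare a local file's SHA-256 digest with the one encoded in an openaipublic URL."""
    expected = url.split("/")[-2] if "openaipublic" in url else ""
    if not expected:
        return True  # nothing to verify for the other mirrors
    with open(path, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest() == expected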
a/audioldm/clap/open_clip/timm_model.py +++ /dev/null @@ -1,112 +0,0 @@ -""" timm model adapter - -Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. -""" -from collections import OrderedDict - -import torch.nn as nn - -try: - import timm - from timm.models.layers import Mlp, to_2tuple - from timm.models.layers.attention_pool2d import RotAttentionPool2d - from timm.models.layers.attention_pool2d import ( - AttentionPool2d as AbsAttentionPool2d, - ) -except ImportError as e: - timm = None - -from .utils import freeze_batch_norm_2d - - -class TimmModel(nn.Module): - """timm model adapter - # FIXME this adapter is a work in progress, may change in ways that break weight compat - """ - - def __init__( - self, - model_name, - embed_dim, - image_size=224, - pool="avg", - proj="linear", - drop=0.0, - pretrained=False, - ): - super().__init__() - if timm is None: - raise RuntimeError("Please `pip install timm` to use timm models.") - - self.image_size = to_2tuple(image_size) - self.trunk = timm.create_model(model_name, pretrained=pretrained) - feat_size = self.trunk.default_cfg.get("pool_size", None) - feature_ndim = 1 if not feat_size else 2 - if pool in ("abs_attn", "rot_attn"): - assert feature_ndim == 2 - # if attn pooling used, remove both classifier and default pool - self.trunk.reset_classifier(0, global_pool="") - else: - # reset global pool if pool config set, otherwise leave as network default - reset_kwargs = dict(global_pool=pool) if pool else {} - self.trunk.reset_classifier(0, **reset_kwargs) - prev_chs = self.trunk.num_features - - head_layers = OrderedDict() - if pool == "abs_attn": - head_layers["pool"] = AbsAttentionPool2d( - prev_chs, feat_size=feat_size, out_features=embed_dim - ) - prev_chs = embed_dim - elif pool == "rot_attn": - head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) - prev_chs = embed_dim - else: - assert proj, "projection layer needed if non-attention pooling is used." 
- - # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used - if proj == "linear": - head_layers["drop"] = nn.Dropout(drop) - head_layers["proj"] = nn.Linear(prev_chs, embed_dim) - elif proj == "mlp": - head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) - - self.head = nn.Sequential(head_layers) - - def lock(self, unlocked_groups=0, freeze_bn_stats=False): - """lock modules - Args: - unlocked_groups (int): leave last n layer groups unlocked (default: 0) - """ - if not unlocked_groups: - # lock full model - for param in self.trunk.parameters(): - param.requires_grad = False - if freeze_bn_stats: - freeze_batch_norm_2d(self.trunk) - else: - # NOTE: partial freeze requires latest timm (master) branch and is subject to change - try: - # FIXME import here until API stable and in an official release - from timm.models.helpers import group_parameters, group_modules - except ImportError: - raise RuntimeError( - "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`" - ) - matcher = self.trunk.group_matcher() - gparams = group_parameters(self.trunk, matcher) - max_layer_id = max(gparams.keys()) - max_layer_id = max_layer_id - unlocked_groups - for group_idx in range(max_layer_id + 1): - group = gparams[group_idx] - for param in group: - self.trunk.get_parameter(param).requires_grad = False - if freeze_bn_stats: - gmodules = group_modules(self.trunk, matcher, reverse=True) - gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} - freeze_batch_norm_2d(self.trunk, gmodules) - - def forward(self, x): - x = self.trunk(x) - x = self.head(x) - return x diff --git a/audioldm/clap/open_clip/tokenizer.py b/audioldm/clap/open_clip/tokenizer.py deleted file mode 100644 index ee4d28450ec5dd12a79daf38cf3088e9e73c2cd5..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/tokenizer.py +++ /dev/null @@ -1,197 +0,0 @@ -""" CLIP tokenizer - -Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. -""" -import gzip -import html -import os -from functools import lru_cache -from typing import Union, List - -import ftfy -import regex as re -import torch - - -@lru_cache() -def default_bpe(): - return os.path.join( - os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" - ) - - -@lru_cache() -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. - The reversible bpe codes work on unicode strings. - This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a signficant percentage of your normal, say, 32K bpe vocab. - To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - And avoids mapping to whitespace/control characters the bpe code barfs on. - """ - bs = ( - list(range(ord("!"), ord("~") + 1)) - + list(range(ord("¡"), ord("¬") + 1)) - + list(range(ord("®"), ord("ÿ") + 1)) - ) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8 + n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -def get_pairs(word): - """Return set of symbol pairs in a word. - Word is represented as tuple of symbols (symbols being variable-length strings). 
- """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text - - -class SimpleTokenizer(object): - def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): - self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") - merges = merges[1 : 49152 - 256 - 2 + 1] - merges = [tuple(merge.split()) for merge in merges] - vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v + "" for v in vocab] - for merge in merges: - vocab.append("".join(merge)) - if not special_tokens: - special_tokens = ["", ""] - else: - special_tokens = ["", ""] + special_tokens - vocab.extend(special_tokens) - self.encoder = dict(zip(vocab, range(len(vocab)))) - self.decoder = {v: k for k, v in self.encoder.items()} - self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {t: t for t in special_tokens} - special = "|".join(special_tokens) - self.pat = re.compile( - special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", - re.IGNORECASE, - ) - - self.vocab_size = len(self.encoder) - self.all_special_ids = [self.encoder[t] for t in special_tokens] - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token[:-1]) + (token[-1] + "",) - pairs = get_pairs(word) - - if not pairs: - return token + "" - - while True: - bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word) - 1 and word[i + 1] == second: - new_word.append(first + second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) - word = " ".join(word) - self.cache[token] = word - return word - - def encode(self, text): - bpe_tokens = [] - text = whitespace_clean(basic_clean(text)).lower() - for token in re.findall(self.pat, text): - token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) - bpe_tokens.extend( - self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") - ) - return bpe_tokens - - def decode(self, tokens): - text = "".join([self.decoder[token] for token in tokens]) - text = ( - bytearray([self.byte_decoder[c] for c in text]) - .decode("utf-8", errors="replace") - .replace("", " ") - ) - return text - - -_tokenizer = SimpleTokenizer() - - -def tokenize( - texts: Union[str, List[str]], context_length: int = 77 -) -> torch.LongTensor: - """ - Returns the tokenized representation of given input string(s) - - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - context_length : int - The context length to use; all CLIP models use 77 as the context length - - Returns - ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] - """ - if isinstance(texts, str): - texts = [texts] - - sot_token = _tokenizer.encoder[""] - 
eot_token = _tokenizer.encoder[""] - all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - - for i, tokens in enumerate(all_tokens): - if len(tokens) > context_length: - tokens = tokens[:context_length] # Truncate - result[i, : len(tokens)] = torch.tensor(tokens) - - return result diff --git a/audioldm/clap/open_clip/transform.py b/audioldm/clap/open_clip/transform.py deleted file mode 100644 index 77aaa722c4a5544ac50de6df35d3e922f63b111d..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/transform.py +++ /dev/null @@ -1,45 +0,0 @@ -from torchvision.transforms import ( - Normalize, - Compose, - RandomResizedCrop, - InterpolationMode, - ToTensor, - Resize, - CenterCrop, -) - - -def _convert_to_rgb(image): - return image.convert("RGB") - - -def image_transform( - image_size: int, - is_train: bool, - mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711), -): - normalize = Normalize(mean=mean, std=std) - if is_train: - return Compose( - [ - RandomResizedCrop( - image_size, - scale=(0.9, 1.0), - interpolation=InterpolationMode.BICUBIC, - ), - _convert_to_rgb, - ToTensor(), - normalize, - ] - ) - else: - return Compose( - [ - Resize(image_size, interpolation=InterpolationMode.BICUBIC), - CenterCrop(image_size), - _convert_to_rgb, - ToTensor(), - normalize, - ] - ) diff --git a/audioldm/clap/open_clip/utils.py b/audioldm/clap/open_clip/utils.py deleted file mode 100644 index de59fd2746a13742197ecdeac671d61ece3f79ba..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/utils.py +++ /dev/null @@ -1,361 +0,0 @@ -import numpy as np -import torch -from torch import nn as nn -from torchvision.ops.misc import FrozenBatchNorm2d -import logging -# import h5py -from tqdm import tqdm -import random -import json -import os -import pathlib - -# TODO: (yusong) this not a good place to store those information and does not scale. Need to be fixed later. -dataset_split = { - "audiocaps": ["train", "valid", "test"], - "audioset": ["balanced_train", "unbalanced_train", "eval"], - "BBCSoundEffects": ["train", "test"], - "Clotho": ["train", "test", "valid"], - "free_to_use_sounds": ["train", "test"], - "paramount_motion": ["train", "test"], - "sonniss_game_effects": ["train", "test"], - "wesoundeffects": ["train", "test"], - "MACS": ["train", "test"], - "freesound": ["train", "test"], - "FSD50K": ["train", "test", "valid"], - "fsd50k_class_label": ["train", "test", "valid"], - "esc50": ["train", "test"], - "audiostock": ["train", "test"], - "freesound_no_overlap_noesc50": ["train", "test"], - "epidemic_sound_effects": ["train", "test"], - "VGGSound": ["train", "test"], - "urbansound8k_class_label": ["train", "test"], - "audioset_t5": ["balanced_train", "unbalanced_train", "eval"], - "epidemic_sound_effects_t5": ["train", "test"], - "WavText5K": ["train", "test"], - "esc50_no_overlap": ["train", "test"], - "usd8k_no_overlap": ["train", "test"], - "fsd50k_200_class_label": ["train", "test", "valid"], -} - - -def freeze_batch_norm_2d(module, module_match={}, name=""): - """ - Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is - itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and - returned. Otherwise, the module is walked recursively and submodules are converted in place. - - Args: - module (torch.nn.Module): Any PyTorch module. 
- module_match (dict): Dictionary of full module names to freeze (all if empty) - name (str): Full module name (prefix) - - Returns: - torch.nn.Module: Resulting module - - Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 - """ - res = module - is_match = True - if module_match: - is_match = name in module_match - if is_match and isinstance( - module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm) - ): - res = FrozenBatchNorm2d(module.num_features) - res.num_features = module.num_features - res.affine = module.affine - if module.affine: - res.weight.data = module.weight.data.clone().detach() - res.bias.data = module.bias.data.clone().detach() - res.running_mean.data = module.running_mean.data - res.running_var.data = module.running_var.data - res.eps = module.eps - else: - for child_name, child in module.named_children(): - full_child_name = ".".join([name, child_name]) if name else child_name - new_child = freeze_batch_norm_2d(child, module_match, full_child_name) - if new_child is not child: - res.add_module(child_name, new_child) - return res - - -def exist(dataset_name, dataset_type): - """ - Check if dataset exists - """ - if dataset_type in dataset_split[dataset_name]: - return True - else: - return False - - -def get_tar_path_from_dataset_name( - dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None -): - """ - Get tar path from dataset name and type - """ - output = [] - for n in dataset_names: - if full_dataset is not None and n in full_dataset: - current_dataset_types = dataset_split[n] - else: - current_dataset_types = dataset_types - for s in current_dataset_types: - tmp = [] - if islocal: - sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json" - if not os.path.exists(sizefilepath_): - sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" - else: - sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" - if not os.path.exists(sizefilepath_): - continue - sizes = json.load(open(sizefilepath_, "r")) - for k in sizes.keys(): - if islocal: - tmp.append(f"{dataset_path}/{n}/{s}/{k}") - else: - tmp.append( - f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -" - ) - if proportion != 1: - tmp = random.sample(tmp, int(proportion * len(tmp))) - output.append(tmp) - return sum(output, []) - - -def get_tar_path_from_txts(txt_path, islocal, proportion=1): - """ - Get tar path from txt path - """ - if isinstance(txt_path, (list, tuple)): - return sum( - [ - get_tar_path_from_txts( - txt_path[i], islocal=islocal, proportion=proportion - ) - for i in range(len(txt_path)) - ], - [], - ) - if isinstance(txt_path, str): - with open(txt_path) as f: - lines = f.readlines() - if islocal: - lines = [ - lines[i] - .split("\n")[0] - .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/") - for i in range(len(lines)) - ] - else: - lines = [ - lines[i].split("\n")[0].replace(".tar", ".tar -") - for i in range(len(lines)) - ] - if proportion != 1: - print("Sampling tars with proportion of {}".format(proportion)) - lines = random.sample(lines, int(proportion * len(lines))) - return lines - - -def get_mix_lambda(mixup_alpha, batch_size): - mixup_lambdas = [ - np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size) - ] - return np.array(mixup_lambdas).astype(np.float32) - - -def do_mixup(x, mixup_lambda): - """ - Args: - x: (batch_size , ...) - mixup_lambda: (batch_size,) - Returns: - out: (batch_size, ...) 
- """ - out = ( - x.transpose(0, -1) * mixup_lambda - + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda) - ).transpose(0, -1) - return out - - -def interpolate(x, ratio): - """Interpolate data in time domain. This is used to compensate the - resolution reduction in downsampling of a CNN. - - Args: - x: (batch_size, time_steps, classes_num) - ratio: int, ratio to interpolate - Returns: - upsampled: (batch_size, time_steps * ratio, classes_num) - """ - (batch_size, time_steps, classes_num) = x.shape - upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) - upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) - return upsampled - - -def pad_framewise_output(framewise_output, frames_num): - """Pad framewise_output to the same length as input frames. The pad value - is the same as the value of the last frame. - Args: - framewise_output: (batch_size, frames_num, classes_num) - frames_num: int, number of frames to pad - Outputs: - output: (batch_size, frames_num, classes_num) - """ - pad = framewise_output[:, -1:, :].repeat( - 1, frames_num - framewise_output.shape[1], 1 - ) - """tensor for padding""" - - output = torch.cat((framewise_output, pad), dim=1) - """(batch_size, frames_num, classes_num)""" - - -# def process_ipc(index_path, classes_num, filename): -# # load data -# logging.info("Load Data...............") -# ipc = [[] for _ in range(classes_num)] -# with h5py.File(index_path, "r") as f: -# for i in tqdm(range(len(f["target"]))): -# t_class = np.where(f["target"][i])[0] -# for t in t_class: -# ipc[t].append(i) -# print(ipc) -# np.save(filename, ipc) -# logging.info("Load Data Succeed...............") - - -def save_to_dict(s, o_={}): - sp = s.split(": ") - o_.update({sp[0]: float(sp[1])}) - return o_ - - -def get_data_from_log(txt_path): - """ - Output dictionary from out.txt log file - """ - with open(txt_path) as f: - lines = f.readlines() - val_data = {} - train_data = {} - train_losses = [] - train_losses_epoch = [] - for i in range(len(lines)): - if "| INFO |" in lines[i]: - if "Eval Epoch" in lines[i]: - if "val_loss" in lines[i]: - # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", "")) - line = lines[i].split("Eval Epoch: ")[-1] - num_epoch = int(line.split(" ")[0].split(" ")[0]) - d = { - line.split(" ")[0] - .split(" ")[1] - .replace(":", ""): float(line.split(" ")[0].split(" ")[-1]) - } - for i in range(1, len(line.split(" "))): - d = save_to_dict(line.split(" ")[i], d) - val_data[num_epoch] = d - elif "Train Epoch" in lines[i]: - num_epoch = int(lines[i].split("Train Epoch: ")[1][0]) - loss = float(lines[i].split("Loss: ")[-1].split(" (")[0]) - train_losses.append(loss) - train_losses_epoch.append(num_epoch) - for i in range(len(train_losses)): - train_data[i] = { - "num_epoch": train_losses_epoch[i], - "train_loss": train_losses[i], - } - return train_data, val_data - - -def save_p(obj, filename): - import pickle - - try: - from deepdiff import DeepDiff - except: - os.system("pip install deepdiff") - from deepdiff import DeepDiff - with open(filename, "wb") as file: - pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol - with open(filename, "rb") as file: - z = pickle.load(file) - assert ( - DeepDiff(obj, z, ignore_string_case=True) == {} - ), "there is something wrong with the saving process" - return - - -def load_p(filename): - import pickle - - with open(filename, "rb") as file: - z = pickle.load(file) - return z - - -def save_json(data, name="data.json"): - import json - - with open(name, "w") as fp: - 
json.dump(data, fp) - return - - -def load_json(name): - import json - - with open(name, "r") as fp: - data = json.load(fp) - return data - - -from multiprocessing import Process, Manager -from multiprocessing import Process, Value, Array -from ctypes import c_wchar - - -def load_class_label(path): - # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing - # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array - out = None - if path is not None: - if pathlib.Path(path).suffix in [".pkl", ".pickle"]: - out = load_p(path) - elif pathlib.Path(path).suffix in [".json", ".txt"]: - out = load_json(path) - elif pathlib.Path(path).suffix in [".npy", ".npz"]: - out = np.load(path) - elif pathlib.Path(path).suffix in [".csv"]: - import pandas as pd - - out = pd.read_csv(path) - return out - # if out is None: - # return None - # else: - # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False) - # val = Array('i', out.values(), lock=False) - # return (key, val) - - -from torch import optim - - -def get_optimizer(params, lr, betas, eps, momentum, optimizer_name): - if optimizer_name.lower() == "adamw": - optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps) - elif optimizer_name.lower() == "sgd": - optimizer = optim.SGD(params, lr=lr, momentum=momentum) - elif optimizer_name.lower() == "adam": - optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps) - else: - raise ValueError("optimizer name is not correct") - return optimizer diff --git a/audioldm/clap/open_clip/version.py b/audioldm/clap/open_clip/version.py deleted file mode 100644 index 3ced3581bb601ae91b1e1da4b8f4f520855a065e..0000000000000000000000000000000000000000 --- a/audioldm/clap/open_clip/version.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "0.2.1" diff --git a/audioldm/clap/training/__init__.py b/audioldm/clap/training/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/audioldm/clap/training/audioset_textmap.npy b/audioldm/clap/training/audioset_textmap.npy deleted file mode 100644 index 3da4c92d3819aaec11e5f576464a9973a6df811b..0000000000000000000000000000000000000000 --- a/audioldm/clap/training/audioset_textmap.npy +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:bada103070d92f9eadd33e1b4f45ec8583f59080ef218c966b43294bd4c86d5b -size 84448 diff --git a/audioldm/clap/training/data.py b/audioldm/clap/training/data.py deleted file mode 100644 index 1d80d598be97d4e04f1b7f3e53a877cfe82ce667..0000000000000000000000000000000000000000 --- a/audioldm/clap/training/data.py +++ /dev/null @@ -1,977 +0,0 @@ -import ast -import json -import logging -import math -import os -import random -# import h5py -from dataclasses import dataclass -from audioldm.clap.training.params import parse_args -# import braceexpand -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision.datasets as datasets -import torchvision.transforms -# import webdataset as wds -from PIL import Image -from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler -from torch.utils.data.distributed import DistributedSampler -from functools import partial -import soundfile as sf -import io -from pathlib import Path -# import wget - -from audioldm.clap.open_clip.utils import ( - get_tar_path_from_dataset_name, - dataset_split, -) -from 
audioldm.clap.open_clip.utils import load_p, load_class_label -import copy - -try: - import horovod.torch as hvd -except ImportError: - hvd = None - -try: - import torchaudio -except ImportError: - torchaudio = None - -from audioldm.clap.open_clip import tokenize - - -def tokenizer(text): - return tokenize(text).squeeze(0) - - -from transformers import RobertaTokenizer - -tokenize = RobertaTokenizer.from_pretrained("roberta-base") - - -def tokenizer(text): - result = tokenize( - text, - padding="max_length", - truncation=True, - max_length=77, - return_tensors="pt", - ) - return {k: v.squeeze(0) for k, v in result.items()} - - -# initizlied the audioset map -_AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy") -_AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True) - - -def int16_to_float32(x): - return (x / 32767.0).astype(np.float32) - - -def float32_to_int16(x): - x = np.clip(x, a_min=-1.0, a_max=1.0) - return (x * 32767.0).astype(np.int16) - - -# For Toy Dataset -# class ToyDataset(Dataset): -# def __init__(self, index_path, ipc, config, eval_mode=False): -# """Toy Dataset for testing the audioset input with text labels -# Parameters -# ---------- -# index_path: str -# the link to the h5 file of each audio -# idc: str -# the link to the npy file, the number of samples in each class -# config: dict -# the audio cfg file -# eval_model (bool): to indicate if the dataset is a testing dataset -# """ -# self.audio_cfg = config["audio_cfg"] -# self.text_cfg = config["text_cfg"] -# self.fp = h5py.File(index_path, "r") -# self.ipc = np.load(ipc, allow_pickle=True) -# self.total_size = len(self.fp["audio_name"]) -# self.classes_num = self.audio_cfg["class_num"] -# self.eval_mode = eval_mode - -# if not eval_mode: -# self.generate_queue() -# else: -# self.queue = [] -# for i in range(self.total_size): -# target = self.fp["target"][i] -# if np.sum(target) > 0: -# self.queue.append(i) -# self.total_size = len(self.queue) -# logging.info("total dataset size: %d" % (self.total_size)) -# logging.info("class num: %d" % (self.classes_num)) - -# def time_shifting(self, x): -# frame_num = len(x) -# shift_len = random.randint(0, frame_num - 1) -# new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0) -# return new_sample - -# def generate_queue(self): -# self.queue = [] -# while len(self.queue) < self.total_size: -# class_set = [*range(self.classes_num)] -# random.shuffle(class_set) -# self.queue += [ -# self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set -# ] -# self.queue = self.queue[: self.total_size] - -# logging.info("queue regenerated:%s" % (self.queue[-5:])) - -# def crop_wav(self, x): -# crop_size = self.audio_cfg["crop_size"] -# crop_pos = random.randint(0, len(x) - crop_size - 1) -# return x[crop_pos : crop_pos + crop_size] - -# def prompt_text(self, target): -# events = _AUDIOSET_MAP[np.where(target > 0)] -# event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1] -# text = tokenize(event_text)[0] -# return text - -# def __getitem__(self, index): -# """Load waveform, text, and target of an audio clip - -# Parameters -# ---------- -# index: int -# the index number -# Return -# ------ -# output: dict { -# "hdf5_path": str, -# "index_in_hdf5": int, -# "audio_name": str, -# "waveform": list (audio_length,), -# "target": list (class_num, ), -# "text": torch.tensor (context_length,) -# } -# the output dictionary -# """ -# s_index = self.queue[index] - -# audio_name = self.fp["audio_name"][s_index].decode() -# # 
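The deleted data.py stores waveforms as int16 inside the webdataset shards and converts them back to float32 at load time with the two helpers above. A small self-contained check of that quantisation round trip (the synthetic tone is only a stand-in for real audio):

import numpy as np


def float32_to_int16(x: np.ndarray) -> np.ndarray:
    x = np.clip(x, a_min=-1.0, a_max=1.0)
    return (x * 32767.0).astype(np.int16)


def int16_to_float32(x: np.ndarray) -> np.ndarray:
    return (x / 32767.0).astype(np.float32)


wave = np.sin(np.linspace(0.0, 2.0 * np.pi * 440.0, 16000)).astype(np.float32)  # ~1 s of a 440 Hz tone
roundtrip = int16_to_float32(float32_to_int16(wave))
# Truncating to int16 loses at most one quantisation step, i.e. about 3e-5 in amplitude.
assert np.max(np.abs(wave - roundtrip)) < 1.0 / 32767.0 + 1e-6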
Hardcode here CHANGE -# hdf5_path = ( -# self.fp["hdf5_path"][s_index] -# .decode() -# .replace( -# "../workspace", -# "/home/la/kechen/Research/ke_zsasp/workspace", -# ) -# ) -# r_idx = self.fp["index_in_hdf5"][s_index] -# target = self.fp["target"][s_index].astype(np.float32) -# text = self.prompt_text(target) -# with h5py.File(hdf5_path, "r") as f: -# waveform = int16_to_float32(f["waveform"][r_idx])[ -# : self.audio_cfg["clip_samples"] -# ] -# assert ( -# len(waveform) == self.audio_cfg["clip_samples"] -# ), "The sample length is not match" -# # Time shift -# # if (self.config.enable_time_shift) and (not self.eval_mode): -# # waveform = self.time_shifting(waveform) -# # # Label Enhance -# # if (self.config.crop_size is not None) and (not self.eval_mode): -# # waveform = self.crop_wav(waveform) -# # # the label enhance rate is fixed 0.5 -# # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5: -# # kidx = np.where(target)[0] -# # for k in kidx: -# # for add_key in self.class_map[k][1]: -# # target[add_key] = 1.0 -# # if len(self.class_map[k][2]) > 0: -# # add_key = random.choice(self.class_map[k][2]) -# # target[add_key] = 1.0 - -# # missing the text input -# mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :] -# mel_spec = ( -# torch.cat( -# [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0 -# ) -# .cpu() -# .numpy() -# ) -# longer = random.choice([True, False]) -# if longer == False: -# mel_spec[1:, :, :] = 0.0 -# data_dict = { -# "hdf5_path": hdf5_path, -# "index_in_hdf5": r_idx, -# "audio_name": audio_name, -# "waveform": waveform, -# "class_label": target, -# "text": text, -# "longer": longer, -# "mel_fusion": mel_spec, -# } -# return data_dict - -# def __len__(self): -# return self.total_size - - -class CsvDataset(Dataset): - def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"): - logging.debug(f"Loading csv data from {input_filename}.") - df = pd.read_csv(input_filename, sep=sep) - - self.images = df[img_key].tolist() - self.captions = df[caption_key].tolist() - self.transforms = transforms - logging.debug("Done loading data.") - - def __len__(self): - return len(self.captions) - - def __getitem__(self, idx): - images = self.transforms(Image.open(str(self.images[idx]))) - texts = tokenize([str(self.captions[idx])])[0] - return images, texts - - -@dataclass -class DataInfo: - dataloader: DataLoader - sampler: DistributedSampler - - -def preprocess_txt(text): - return tokenize([str(text)])[0] - - -def get_dataset_size(shards, sizefilepath_=None, is_local=True): - if isinstance(shards, list): - size_list = [] - for s in shards: - size_list.append( - get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0] - ) - else: - if not is_local: - for n in dataset_split.keys(): - if n in shards.split("/"): - break - for s in dataset_split[n]: - if s in shards.split("/"): - break - sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" - shards_list = list(braceexpand.braceexpand(shards)) - dir_path = os.path.dirname(shards) - if sizefilepath_ is not None: - sizes = json.load(open(sizefilepath_, "r")) - total_size = sum( - [ - int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))]) - for shard in shards_list - ] - ) - else: - sizes_filename = os.path.join(dir_path, "sizes.json") - len_filename = os.path.join(dir_path, "__len__") - if os.path.exists(sizes_filename): - sizes = json.load(open(sizes_filename, "r")) - total_size = sum( - [int(sizes[os.path.basename(shard)]) 
for shard in shards_list] - ) - elif os.path.exists(len_filename): - # FIXME this used to be eval(open(...)) but that seemed rather unsafe - total_size = ast.literal_eval(open(len_filename, "r").read()) - else: - raise Exception( - "Cannot find sizes file for dataset. Please specify the path to the file." - ) - # total_size = None # num samples undefined - # some common dataset sizes (at time of authors last download) - # cc3m-train: 2905954 - # cc12m: 10968539 - # LAION-400m: 407332084 - num_shards = len(shards_list) - if isinstance(shards, list): - return sum(size_list), len(shards) - else: - return total_size, num_shards - - -def get_imagenet(args, preprocess_fns, split): - assert split in ["train", "val", "v2"] - is_train = split == "train" - preprocess_train, preprocess_val = preprocess_fns - - if split == "v2": - from imagenetv2_pytorch import ImageNetV2Dataset - - dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val) - else: - if is_train: - data_path = args.imagenet_train - preprocess_fn = preprocess_train - else: - data_path = args.imagenet_val - preprocess_fn = preprocess_val - assert data_path - - dataset = datasets.ImageFolder(data_path, transform=preprocess_fn) - - if is_train: - idxs = np.zeros(len(dataset.targets)) - target_array = np.array(dataset.targets) - k = 50 - for c in range(1000): - m = target_array == c - n = len(idxs[m]) - arr = np.zeros(n) - arr[:k] = 1 - np.random.shuffle(arr) - idxs[m] = arr - - idxs = idxs.astype("int") - sampler = SubsetRandomSampler(np.where(idxs)[0]) - else: - sampler = None - - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=args.batch_size, - num_workers=args.workers, - sampler=sampler, - ) - - return DataInfo(dataloader, sampler) - - -def count_samples(dataloader): - os.environ["WDS_EPOCH"] = "0" - n_elements, n_batches = 0, 0 - for images, texts in dataloader: - n_batches += 1 - n_elements += len(images) - assert len(images) == len(texts) - return n_elements, n_batches - - -def filter_no_caption(sample): - return "txt" in sample - - -def log_and_continue(exn): - """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" - logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.") - return True - - -_SHARD_SHUFFLE_SIZE = 2000 -_SHARD_SHUFFLE_INITIAL = 500 -_SAMPLE_SHUFFLE_SIZE = 5000 -_SAMPLE_SHUFFLE_INITIAL = 1000 - - -def sample_prop(sizefile, inputs, proportion, is_local=True): - """ - Sample a proportion of the data. 
- """ - file_path_dict = { - os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0] - for i in range(len(inputs)) - } - sampled_filepath_dict = {} - sampled_size_dict = {} - if not is_local: - if os.path.exists("sizes.json"): - os.remove("sizes.json") - wget.download(sizefile, "sizes.json") - sizefile = "sizes.json" - with open(sizefile, "r", encoding="UTF-8") as f: - load_dict = json.load(f) - L = int(len(file_path_dict) * proportion) - subkeys = random.sample(file_path_dict.keys(), L) - for k in subkeys: - sampled_size_dict[k] = load_dict[k] - sampled_filepath_dict[k] = file_path_dict[k] - return ( - sum(sampled_size_dict.values()), - L, - [os.path.join(v, k) for k, v in sampled_filepath_dict.items()], - sampled_size_dict, - ) - - -def get_mel(audio_data, audio_cfg): - # mel shape: (n_mels, T) - mel = torchaudio.transforms.MelSpectrogram( - sample_rate=audio_cfg["sample_rate"], - n_fft=audio_cfg["window_size"], - win_length=audio_cfg["window_size"], - hop_length=audio_cfg["hop_size"], - center=True, - pad_mode="reflect", - power=2.0, - norm=None, - onesided=True, - n_mels=64, - f_min=audio_cfg["fmin"], - f_max=audio_cfg["fmax"], - ).to(audio_data.device) - mel = mel(audio_data) - # Align to librosa: - # librosa_melspec = librosa.feature.melspectrogram( - # waveform, - # sr=audio_cfg['sample_rate'], - # n_fft=audio_cfg['window_size'], - # hop_length=audio_cfg['hop_size'], - # win_length=audio_cfg['window_size'], - # center=True, - # pad_mode="reflect", - # power=2.0, - # n_mels=64, - # norm=None, - # htk=True, - # f_min=audio_cfg['fmin'], - # f_max=audio_cfg['fmax'] - # ) - # we use log mel spectrogram as input - mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel) - return mel.T # (T, n_mels) - - -def get_audio_features( - sample, audio_data, max_len, data_truncating, data_filling, audio_cfg -): - """ - Calculate and add audio features to sample. - Sample: a dict containing all the data of current sample. - audio_data: a tensor of shape (T) containing audio data. - max_len: the maximum length of audio data. - data_truncating: the method of truncating data. - data_filling: the method of filling data. - audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg']. - """ - with torch.no_grad(): - if len(audio_data) > max_len: - if data_truncating == "rand_trunc": - longer = torch.tensor([True]) - elif data_truncating == "fusion": - # fusion - mel = get_mel(audio_data, audio_cfg) - # split to three parts - chunk_frames = ( - max_len // audio_cfg["hop_size"] + 1 - ) # the +1 related to how the spectrogram is computed - total_frames = mel.shape[0] - if chunk_frames == total_frames: - # there is a corner case where the audio length is - # larger than max_len but smaller than max_len+hop_size. - # In this case, we just use the whole audio. 
- mel_fusion = torch.stack([mel, mel, mel, mel], dim=0) - sample["mel_fusion"] = mel_fusion - longer = torch.tensor([False]) - else: - ranges = np.array_split( - list(range(0, total_frames - chunk_frames + 1)), 3 - ) - # print('total_frames-chunk_frames:', total_frames-chunk_frames, - # 'len(audio_data):', len(audio_data), - # 'chunk_frames:', chunk_frames, - # 'total_frames:', total_frames) - if len(ranges[1]) == 0: - # if the audio is too short, we just use the first chunk - ranges[1] = [0] - if len(ranges[2]) == 0: - # if the audio is too short, we just use the first chunk - ranges[2] = [0] - # randomly choose index for each part - idx_front = np.random.choice(ranges[0]) - idx_middle = np.random.choice(ranges[1]) - idx_back = np.random.choice(ranges[2]) - # select mel - mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :] - mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :] - mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :] - - # shrink the mel - mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, 64])( - mel[None] - )[0] - # logging.info(f"mel_shrink.shape: {mel_shrink.shape}") - - # stack - mel_fusion = torch.stack( - [mel_chunk_front, mel_chunk_middle, mel_chunk_back, mel_shrink], - dim=0, - ) - sample["mel_fusion"] = mel_fusion - longer = torch.tensor([True]) - else: - raise NotImplementedError( - f"data_truncating {data_truncating} not implemented" - ) - # random crop to max_len (for compatibility) - overflow = len(audio_data) - max_len - idx = np.random.randint(0, overflow + 1) - audio_data = audio_data[idx : idx + max_len] - - else: # padding if too short - if len(audio_data) < max_len: # do nothing if equal - if data_filling == "repeatpad": - n_repeat = int(max_len / len(audio_data)) - audio_data = audio_data.repeat(n_repeat) - # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0) - # audio_data = F.interpolate(audio_data,size=max_len,mode="bicubic")[0,0,0] - audio_data = F.pad( - audio_data, - (0, max_len - len(audio_data)), - mode="constant", - value=0, - ) - elif data_filling == "pad": - audio_data = F.pad( - audio_data, - (0, max_len - len(audio_data)), - mode="constant", - value=0, - ) - elif data_filling == "repeat": - n_repeat = int(max_len / len(audio_data)) - audio_data = audio_data.repeat(n_repeat + 1)[:max_len] - else: - raise NotImplementedError( - f"data_filling {data_filling} not implemented" - ) - if data_truncating == "fusion": - mel = get_mel(audio_data, audio_cfg) - mel_fusion = torch.stack([mel, mel, mel, mel], dim=0) - sample["mel_fusion"] = mel_fusion - longer = torch.tensor([False]) - - sample["longer"] = longer - sample["waveform"] = audio_data - - return sample - - -def preprocess( - sample, - audio_ext, - text_ext, - max_len, - audio_cfg, - class_index_dict=None, - data_filling="pad", - data_truncating="rand_trunc", - text_augment_selection=None, -): - """ - Preprocess a single sample for wdsdataloader. 
- """ - audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) - audio_data = int16_to_float32(float32_to_int16(audio_data)) - audio_data = torch.tensor(audio_data).float() - - # TODO: (yusong) to be include in the future - # # if torchaudio not installed, use soundfile to load audio - # if torchaudio is None: - # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) - # audio_data = torch.tensor(audio_data).float() - # else: - # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py - # with tempfile.TemporaryDirectory() as dirname: - # os.makedirs(dirname, exist_ok=True) - # fname = os.path.join(dirname, f"file.flac") - # with open(fname, "wb") as stream: - # stream.write(sample[audio_ext]) - # audio_data, orig_sr = torchaudio.load(fname) - # audio_data = audio_data[0, :].float() - - sample = get_audio_features( - sample, audio_data, max_len, data_truncating, data_filling, audio_cfg - ) - del sample[audio_ext] - - try: - json_dict_raw = json.loads(sample[text_ext].decode("utf-8")) - except: - print("sample[__url__]:", sample["__url__"]) - - # For selecting augmented text from dataset - if text_augment_selection is None or text_augment_selection == "none": - texts = json_dict_raw["text"] - elif text_augment_selection == "all": - if "text_augment_all" in json_dict_raw.keys(): - texts = json_dict_raw["text_augment_all"] - else: - texts = json_dict_raw["text"] - elif text_augment_selection == "augment_only": - if "text_augment_all" in json_dict_raw.keys(): - if json_dict_raw["text_augment_t5"] is None: - texts = json_dict_raw["text"] - else: - texts = json_dict_raw["text_augment_t5"] - else: - texts = json_dict_raw["text"] - else: - raise NotImplementedError( - f"text_augment_selection {text_augment_selection} not implemented" - ) - sample["full_text"] = texts - - if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1: - texts = random.choice(texts) - sample["raw_text"] = texts - sample["text"] = tokenizer(texts) # text shape: [num_token] - if class_index_dict is not None: - # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing - # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array - # key, val = class_index_dict - # key = key[:].split('\n') - # _dict = {k: v for k, v in zip(key, val)} - sample["class_label"] = np.zeros(len(class_index_dict.keys())) - for x in json_dict_raw["tag"]: - sample["class_label"][class_index_dict[x]] = 1 - sample["class_label"] = torch.tensor(sample["class_label"]).float() - del sample[text_ext] - sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext - sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext - sample["audio_orig_sr"] = orig_sr - return sample - - -def collate_fn(batch): - """ - Collate function for wdsdataloader. - batch: a list of dict, each dict is a sample - """ - # concatenate values in each dictionary. if it is a tensor, concatenate. if it is a list, extend. 
- batch_dict = {} - for k in batch[0].keys(): - if isinstance(batch[0][k], dict): # dealwith bert tokenizer output - batch_dict[k] = {} - for kk in batch[0][k].keys(): - tmp = [] - for i in range(len(batch)): - tmp.append(batch[i][k][kk]) - batch_dict[k][kk] = torch.vstack(tmp) - elif isinstance(batch[0][k], torch.Tensor): - batch_dict[k] = torch.stack([sample[k] for sample in batch]) - elif isinstance(batch[0][k], np.ndarray): - batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch])) - else: - batch_dict[k] = [sample[k] for sample in batch] - return batch_dict - - -def get_wds_dataset( - args, - model_cfg, - is_train, - audio_ext="flac", - text_ext="json", - max_len=480000, - proportion=1.0, - sizefilepath_=None, - is_local=None, -): - """ - Get a dataset for wdsdataloader. - """ - if is_local is None and (not args.remotedata is None): - is_local = not args.remotedata - - input_shards = args.train_data if is_train else args.val_data - assert input_shards is not None - - if not sizefilepath_ is None: - sizefilepath = sizefilepath_ - else: - sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json") - - if proportion != 1.0: - num_samples, num_shards, input_shards, _ = sample_prop( - sizefilepath, input_shards, proportion, is_local=is_local - ) - else: - num_samples, num_shards = get_dataset_size( - input_shards, sizefilepath_=sizefilepath_, is_local=is_local - ) - - if not num_samples: - if is_train: - num_samples = args.train_num_samples - if not num_samples: - raise RuntimeError( - "Currently, number of dataset samples must be specified for training dataset. " - "Please specify via `--train-num-samples` if no dataset length info present." - ) - else: - num_samples = ( - args.val_num_samples or 0 - ) # eval will just exhaust the iterator if not specified - - pipeline = [wds.SimpleShardList(input_shards)] - # at this point we have an iterator over all the shards - # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node - if is_train or args.parallel_eval: - pipeline.extend( - [ - wds.detshuffle( - bufsize=_SHARD_SHUFFLE_SIZE, - initial=_SHARD_SHUFFLE_INITIAL, - seed=args.seed, - ), - wds.split_by_node, - wds.split_by_worker, - # at this point, we have an iterator over the shards assigned to each worker at each node - wds.tarfile_to_samples(handler=log_and_continue), - wds.shuffle( - bufsize=_SAMPLE_SHUFFLE_SIZE, - initial=_SAMPLE_SHUFFLE_INITIAL, - rng=random.Random(args.seed), - ), - # wds.repeatedly, # FIXME determine if this is beneficial - ] - ) - else: - pipeline.extend( - [ - wds.split_by_worker, - # at this point, we have an iterator over the shards assigned to each worker - wds.tarfile_to_samples(handler=log_and_continue), - ] - ) - pipeline.append( - wds.map( - partial( - preprocess, - audio_ext=audio_ext, - text_ext=text_ext, - max_len=max_len, - audio_cfg=model_cfg["audio_cfg"], - class_index_dict=copy.deepcopy(args.class_index_dict), - data_filling=args.data_filling, - data_truncating=args.data_truncating, - text_augment_selection=args.text_augment_selection, - ) - ), - ) - - pipeline.append( - wds.batched( - args.batch_size, - partial=not (is_train or args.parallel_eval), - collation_fn=collate_fn, - ) - ) - - dataset = wds.DataPipeline(*pipeline) - if is_train or args.parallel_eval: - # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples. - # (yusong): See comments below. 
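Just below, get_wds_dataset rounds the sample count up so that every node and every dataloader worker iterates the same number of full batches. A worked version of that arithmetic with made-up numbers:

import math


def even_batches(num_samples: int, batch_size: int, world_size: int, workers: int):
    """Replicate the rounding used for parallel training/eval: equal full batches per worker."""
    global_batch_size = batch_size * world_size
    num_batches = math.ceil(num_samples / global_batch_size)
    num_workers = max(1, workers)
    num_worker_batches = math.ceil(num_batches / num_workers)  # batches per dataloader worker
    num_batches = num_worker_batches * num_workers
    num_samples = num_batches * global_batch_size
    return num_worker_batches, num_batches, num_samples


# 1000 samples, per-GPU batch 16, 4 GPUs, 2 workers:
# global batch 64 -> 16 global batches -> 8 per worker -> padded to 1024 samples.
print(even_batches(1000, 16, 4, 2))  # (8, 16, 1024)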
- # roll over and repeat a few samples to get same number of full batches on each node - global_batch_size = args.batch_size * args.world_size - num_batches = math.ceil(num_samples / global_batch_size) - num_workers = max(1, args.workers) - num_worker_batches = math.ceil( - num_batches / num_workers - ) # per dataloader worker - num_batches = num_worker_batches * num_workers - num_samples = num_batches * global_batch_size - dataset = dataset.with_epoch( - num_worker_batches - ) # each worker is iterating over this - else: - # last batches are partial, eval is done on single (master) node - num_batches = math.ceil(num_samples / args.batch_size) - - kwargs = {} - if args.horovod: # multi-node training on summit - kwargs["multiprocessing_context"] = "forkserver" - - dataloader = wds.WebLoader( - dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs - ) - - # FIXME not clear which approach is better, with_epoch before vs after dataloader? - # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 - # if is_train: - # # roll over and repeat a few samples to get same number of full batches on each node - # global_batch_size = args.batch_size * args.world_size - # num_batches = math.ceil(num_samples / global_batch_size) - # num_workers = max(1, args.workers) - # num_batches = math.ceil(num_batches / num_workers) * num_workers - # num_samples = num_batches * global_batch_size - # dataloader = dataloader.with_epoch(num_batches) - # else: - # # last batches are partial, eval is done on single (master) node - # num_batches = math.ceil(num_samples / args.batch_size) - - # add meta-data to dataloader instance for convenience - dataloader.num_batches = num_batches - dataloader.num_samples = num_samples - - return DataInfo(dataloader, None) - - -def wds_batch_list2dict( - batch, - keys=[ - "__url__", - "__key__", - "waveform", - "text", - "raw_text", - "audio_name", - "text_name", - "audio_orig_sr", - ], -): - """ - Return a dictionary of the batch, with keys as the names of the fields. 
- """ - assert len(keys) == len( - batch - ), "batch must have same number of keys as keys argument" - return {keys[i]: batch[i] for i in range(len(batch))} - - -def get_csv_dataset(args, preprocess_fn, is_train): - input_filename = args.train_data if is_train else args.val_data - assert input_filename - dataset = CsvDataset( - input_filename, - preprocess_fn, - img_key=args.csv_img_key, - caption_key=args.csv_caption_key, - sep=args.csv_separator, - ) - num_samples = len(dataset) - sampler = DistributedSampler(dataset) if args.distributed and is_train else None - shuffle = is_train and sampler is None - - dataloader = DataLoader( - dataset, - batch_size=args.batch_size, - shuffle=shuffle, - num_workers=args.workers, - pin_memory=True, - sampler=sampler, - drop_last=is_train, - ) - dataloader.num_samples = num_samples - dataloader.num_batches = len(dataloader) - - return DataInfo(dataloader, sampler) - - -def get_toy_dataset(args, model_cfg, is_train): - index_path = args.train_data if is_train else args.val_data - ipc_path = args.train_ipc if is_train else args.val_ipc - assert index_path and ipc_path - eval_mode = not is_train - dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode) - - num_samples = len(dataset) - sampler = ( - DistributedSampler(dataset, shuffle=False) - if args.distributed and is_train - else None - ) - - dataloader = DataLoader( - dataset, - batch_size=args.batch_size, - shuffle=False, - num_workers=args.workers, - sampler=sampler, - drop_last=is_train, - ) - dataloader.num_samples = num_samples - dataloader.num_batches = len(dataloader) - - return DataInfo(dataloader, sampler) - - -def get_dataset_fn(data_path, dataset_type): - if dataset_type == "webdataset": - return get_wds_dataset - elif dataset_type == "csv": - return get_csv_dataset - elif dataset_type == "auto": - ext = data_path.split(".")[-1] - if ext in ["csv", "tsv"]: - return get_csv_dataset - elif ext in ["tar"]: - return get_wds_dataset - else: - raise ValueError( - f"Tried to figure out dataset type, but failed for extention {ext}." 
- ) - elif dataset_type == "toy": - return get_toy_dataset - else: - raise ValueError(f"Unsupported dataset type: {dataset_type}") - - -def get_data(args, model_cfg): - data = {} - - args.class_index_dict = load_class_label(args.class_label_path) - - if args.datasetinfos is None: - args.datasetinfos = ["train", "unbalanced_train", "balanced_train"] - if args.dataset_type == "webdataset": - args.train_data = get_tar_path_from_dataset_name( - args.datasetnames, - args.datasetinfos, - islocal=not args.remotedata, - proportion=args.dataset_proportion, - dataset_path=args.datasetpath, - full_dataset=args.full_train_dataset, - ) - - if args.full_train_dataset is None: - args.full_train_dataset = [] - if args.exclude_eval_dataset is None: - args.exclude_eval_dataset = [] - excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset - - val_dataset_names = ( - [n for n in args.datasetnames if n not in excluded_eval_datasets] - if excluded_eval_datasets - else args.datasetnames - ) - args.val_dataset_names = val_dataset_names - args.val_data = get_tar_path_from_dataset_name( - val_dataset_names, - ["valid", "test", "eval"], - islocal=not args.remotedata, - proportion=1, - dataset_path=args.datasetpath, - full_dataset=None, - ) - - if args.train_data: - data["train"] = get_dataset_fn(args.train_data, args.dataset_type)( - args, model_cfg, is_train=True - ) - - if args.val_data: - data["val"] = get_dataset_fn(args.val_data, args.dataset_type)( - args, model_cfg, is_train=False - ) - - return data diff --git a/audioldm/clap/training/distributed.py b/audioldm/clap/training/distributed.py deleted file mode 100644 index 2fa61f76c5cc3ab9f6a9643042afa8e1f2e1cb7f..0000000000000000000000000000000000000000 --- a/audioldm/clap/training/distributed.py +++ /dev/null @@ -1,150 +0,0 @@ -import os - -import torch -import socket - -try: - import horovod.torch as hvd -except ImportError: - hvd = None - - -def is_global_master(args): - return args.rank == 0 - - -def is_local_master(args): - return args.local_rank == 0 - - -def is_master(args, local=False): - return is_local_master(args) if local else is_global_master(args) - - -def is_using_horovod(): - # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set - # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... - ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] - pmi_vars = ["PMI_RANK", "PMI_SIZE"] - if all([var in os.environ for var in ompi_vars]) or all( - [var in os.environ for var in pmi_vars] - ): - return True - else: - return False - - -def is_using_distributed(): - if "WORLD_SIZE" in os.environ: - return int(os.environ["WORLD_SIZE"]) > 1 - if "SLURM_NTASKS" in os.environ: - return int(os.environ["SLURM_NTASKS"]) > 1 - return False - - -def world_info_from_env(): - local_rank = 0 - for v in ( - "SLURM_LOCALID", - "MPI_LOCALRANKID", - "OMPI_COMM_WORLD_LOCAL_RANK", - "LOCAL_RANK", - ): - if v in os.environ: - local_rank = int(os.environ[v]) - break - global_rank = 0 - for v in ("SLURM_PROCID", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "RANK"): - if v in os.environ: - global_rank = int(os.environ[v]) - break - world_size = 1 - for v in ("SLURM_NTASKS", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "WORLD_SIZE"): - if v in os.environ: - world_size = int(os.environ[v]) - break - - return local_rank, global_rank, world_size - - -def init_distributed_device(args): - # Distributed training = training on more than one GPU. - # Works in both single and multi-node scenarios. 
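world_info_from_env above resolves the process topology from whichever launcher exported it (SLURM, MPI/Horovod, or torchrun). A condensed sketch of the same precedence order that also runs outside any launcher by falling back to a single-process default:

import os


def world_info():
    """Return (local_rank, global_rank, world_size) from common launcher environment variables."""
    def first_int(names, default):
        for name in names:
            if name in os.environ:
                return int(os.environ[name])
        return default

    local_rank = first_int(
        ("SLURM_LOCALID", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "LOCAL_RANK"), 0)
    global_rank = first_int(("SLURM_PROCID", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "RANK"), 0)
    world_size = first_int(("SLURM_NTASKS", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "WORLD_SIZE"), 1)
    return local_rank, global_rank, world_size


print(world_info())  # (0, 0, 1) when started as a plain `python` process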
- args.distributed = False - args.world_size = 1 - args.rank = 0 # global rank - args.local_rank = 0 - if args.horovod: - assert hvd is not None, "Horovod is not installed" - hvd.init() - world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) - world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) - local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) - args.local_rank = local_rank - args.rank = world_rank - args.world_size = world_size - # args.local_rank = int(hvd.local_rank()) - # args.rank = hvd.rank() - # args.world_size = hvd.size() - args.distributed = True - os.environ["LOCAL_RANK"] = str(args.local_rank) - os.environ["RANK"] = str(args.rank) - os.environ["WORLD_SIZE"] = str(args.world_size) - print( - f"Distributed training: local_rank={args.local_rank}, " - f"rank={args.rank}, world_size={args.world_size}, " - f"hostname={socket.gethostname()}, pid={os.getpid()}" - ) - elif is_using_distributed(): - if "SLURM_PROCID" in os.environ: - # DDP via SLURM - args.local_rank, args.rank, args.world_size = world_info_from_env() - # SLURM var -> torch.distributed vars in case needed - os.environ["LOCAL_RANK"] = str(args.local_rank) - os.environ["RANK"] = str(args.rank) - os.environ["WORLD_SIZE"] = str(args.world_size) - torch.distributed.init_process_group( - backend=args.dist_backend, - init_method=args.dist_url, - world_size=args.world_size, - rank=args.rank, - ) - elif "OMPI_COMM_WORLD_SIZE" in os.environ: # using Summit cluster - world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) - world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) - local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) - args.local_rank = local_rank - args.rank = world_rank - args.world_size = world_size - torch.distributed.init_process_group( - backend=args.dist_backend, - init_method=args.dist_url, - world_size=args.world_size, - rank=args.rank, - ) - else: - # DDP via torchrun, torch.distributed.launch - args.local_rank, _, _ = world_info_from_env() - torch.distributed.init_process_group( - backend=args.dist_backend, init_method=args.dist_url - ) - args.world_size = torch.distributed.get_world_size() - args.rank = torch.distributed.get_rank() - args.distributed = True - print( - f"Distributed training: local_rank={args.local_rank}, " - f"rank={args.rank}, world_size={args.world_size}, " - f"hostname={socket.gethostname()}, pid={os.getpid()}" - ) - - if torch.cuda.is_available(): - if args.distributed and not args.no_set_device_rank: - device = "cuda:%d" % args.local_rank - else: - device = "cuda:0" - torch.cuda.set_device(device) - else: - device = "cpu" - args.device = device - device = torch.device(device) - return device diff --git a/audioldm/clap/training/imagenet_zeroshot_data.py b/audioldm/clap/training/imagenet_zeroshot_data.py deleted file mode 100644 index d32e55328d6799ccb8d61625f43abb80a33d6c17..0000000000000000000000000000000000000000 --- a/audioldm/clap/training/imagenet_zeroshot_data.py +++ /dev/null @@ -1,1088 +0,0 @@ -# NOTE: This script is currently not supported for CLAP. 
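The tail of init_distributed_device above pins one CUDA device per local rank when training is distributed, otherwise uses cuda:0, and falls back to CPU. A minimal sketch of that selection logic in isolation (the no_set_device_rank escape hatch from the full function is omitted here):

import torch


def pick_device(local_rank: int = 0, distributed: bool = False) -> torch.device:
    """One GPU per local rank when distributed, cuda:0 otherwise, CPU as the fallback."""
    if torch.cuda.is_available():
        name = f"cuda:{local_rank}" if distributed else "cuda:0"
        torch.cuda.set_device(name)
        return torch.device(name)
    return torch.device("cpu")


print(pick_device())  # cuda:0 on a GPU machine, cpu otherwise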
- -imagenet_classnames = [ - "tench", - "goldfish", - "great white shark", - "tiger shark", - "hammerhead shark", - "electric ray", - "stingray", - "rooster", - "hen", - "ostrich", - "brambling", - "goldfinch", - "house finch", - "junco", - "indigo bunting", - "American robin", - "bulbul", - "jay", - "magpie", - "chickadee", - "American dipper", - "kite (bird of prey)", - "bald eagle", - "vulture", - "great grey owl", - "fire salamander", - "smooth newt", - "newt", - "spotted salamander", - "axolotl", - "American bullfrog", - "tree frog", - "tailed frog", - "loggerhead sea turtle", - "leatherback sea turtle", - "mud turtle", - "terrapin", - "box turtle", - "banded gecko", - "green iguana", - "Carolina anole", - "desert grassland whiptail lizard", - "agama", - "frilled-necked lizard", - "alligator lizard", - "Gila monster", - "European green lizard", - "chameleon", - "Komodo dragon", - "Nile crocodile", - "American alligator", - "triceratops", - "worm snake", - "ring-necked snake", - "eastern hog-nosed snake", - "smooth green snake", - "kingsnake", - "garter snake", - "water snake", - "vine snake", - "night snake", - "boa constrictor", - "African rock python", - "Indian cobra", - "green mamba", - "sea snake", - "Saharan horned viper", - "eastern diamondback rattlesnake", - "sidewinder rattlesnake", - "trilobite", - "harvestman", - "scorpion", - "yellow garden spider", - "barn spider", - "European garden spider", - "southern black widow", - "tarantula", - "wolf spider", - "tick", - "centipede", - "black grouse", - "ptarmigan", - "ruffed grouse", - "prairie grouse", - "peafowl", - "quail", - "partridge", - "african grey parrot", - "macaw", - "sulphur-crested cockatoo", - "lorikeet", - "coucal", - "bee eater", - "hornbill", - "hummingbird", - "jacamar", - "toucan", - "duck", - "red-breasted merganser", - "goose", - "black swan", - "tusker", - "echidna", - "platypus", - "wallaby", - "koala", - "wombat", - "jellyfish", - "sea anemone", - "brain coral", - "flatworm", - "nematode", - "conch", - "snail", - "slug", - "sea slug", - "chiton", - "chambered nautilus", - "Dungeness crab", - "rock crab", - "fiddler crab", - "red king crab", - "American lobster", - "spiny lobster", - "crayfish", - "hermit crab", - "isopod", - "white stork", - "black stork", - "spoonbill", - "flamingo", - "little blue heron", - "great egret", - "bittern bird", - "crane bird", - "limpkin", - "common gallinule", - "American coot", - "bustard", - "ruddy turnstone", - "dunlin", - "common redshank", - "dowitcher", - "oystercatcher", - "pelican", - "king penguin", - "albatross", - "grey whale", - "killer whale", - "dugong", - "sea lion", - "Chihuahua", - "Japanese Chin", - "Maltese", - "Pekingese", - "Shih Tzu", - "King Charles Spaniel", - "Papillon", - "toy terrier", - "Rhodesian Ridgeback", - "Afghan Hound", - "Basset Hound", - "Beagle", - "Bloodhound", - "Bluetick Coonhound", - "Black and Tan Coonhound", - "Treeing Walker Coonhound", - "English foxhound", - "Redbone Coonhound", - "borzoi", - "Irish Wolfhound", - "Italian Greyhound", - "Whippet", - "Ibizan Hound", - "Norwegian Elkhound", - "Otterhound", - "Saluki", - "Scottish Deerhound", - "Weimaraner", - "Staffordshire Bull Terrier", - "American Staffordshire Terrier", - "Bedlington Terrier", - "Border Terrier", - "Kerry Blue Terrier", - "Irish Terrier", - "Norfolk Terrier", - "Norwich Terrier", - "Yorkshire Terrier", - "Wire Fox Terrier", - "Lakeland Terrier", - "Sealyham Terrier", - "Airedale Terrier", - "Cairn Terrier", - "Australian Terrier", - "Dandie Dinmont Terrier", - 
"Boston Terrier", - "Miniature Schnauzer", - "Giant Schnauzer", - "Standard Schnauzer", - "Scottish Terrier", - "Tibetan Terrier", - "Australian Silky Terrier", - "Soft-coated Wheaten Terrier", - "West Highland White Terrier", - "Lhasa Apso", - "Flat-Coated Retriever", - "Curly-coated Retriever", - "Golden Retriever", - "Labrador Retriever", - "Chesapeake Bay Retriever", - "German Shorthaired Pointer", - "Vizsla", - "English Setter", - "Irish Setter", - "Gordon Setter", - "Brittany dog", - "Clumber Spaniel", - "English Springer Spaniel", - "Welsh Springer Spaniel", - "Cocker Spaniel", - "Sussex Spaniel", - "Irish Water Spaniel", - "Kuvasz", - "Schipperke", - "Groenendael dog", - "Malinois", - "Briard", - "Australian Kelpie", - "Komondor", - "Old English Sheepdog", - "Shetland Sheepdog", - "collie", - "Border Collie", - "Bouvier des Flandres dog", - "Rottweiler", - "German Shepherd Dog", - "Dobermann", - "Miniature Pinscher", - "Greater Swiss Mountain Dog", - "Bernese Mountain Dog", - "Appenzeller Sennenhund", - "Entlebucher Sennenhund", - "Boxer", - "Bullmastiff", - "Tibetan Mastiff", - "French Bulldog", - "Great Dane", - "St. Bernard", - "husky", - "Alaskan Malamute", - "Siberian Husky", - "Dalmatian", - "Affenpinscher", - "Basenji", - "pug", - "Leonberger", - "Newfoundland dog", - "Great Pyrenees dog", - "Samoyed", - "Pomeranian", - "Chow Chow", - "Keeshond", - "brussels griffon", - "Pembroke Welsh Corgi", - "Cardigan Welsh Corgi", - "Toy Poodle", - "Miniature Poodle", - "Standard Poodle", - "Mexican hairless dog (xoloitzcuintli)", - "grey wolf", - "Alaskan tundra wolf", - "red wolf or maned wolf", - "coyote", - "dingo", - "dhole", - "African wild dog", - "hyena", - "red fox", - "kit fox", - "Arctic fox", - "grey fox", - "tabby cat", - "tiger cat", - "Persian cat", - "Siamese cat", - "Egyptian Mau", - "cougar", - "lynx", - "leopard", - "snow leopard", - "jaguar", - "lion", - "tiger", - "cheetah", - "brown bear", - "American black bear", - "polar bear", - "sloth bear", - "mongoose", - "meerkat", - "tiger beetle", - "ladybug", - "ground beetle", - "longhorn beetle", - "leaf beetle", - "dung beetle", - "rhinoceros beetle", - "weevil", - "fly", - "bee", - "ant", - "grasshopper", - "cricket insect", - "stick insect", - "cockroach", - "praying mantis", - "cicada", - "leafhopper", - "lacewing", - "dragonfly", - "damselfly", - "red admiral butterfly", - "ringlet butterfly", - "monarch butterfly", - "small white butterfly", - "sulphur butterfly", - "gossamer-winged butterfly", - "starfish", - "sea urchin", - "sea cucumber", - "cottontail rabbit", - "hare", - "Angora rabbit", - "hamster", - "porcupine", - "fox squirrel", - "marmot", - "beaver", - "guinea pig", - "common sorrel horse", - "zebra", - "pig", - "wild boar", - "warthog", - "hippopotamus", - "ox", - "water buffalo", - "bison", - "ram (adult male sheep)", - "bighorn sheep", - "Alpine ibex", - "hartebeest", - "impala (antelope)", - "gazelle", - "arabian camel", - "llama", - "weasel", - "mink", - "European polecat", - "black-footed ferret", - "otter", - "skunk", - "badger", - "armadillo", - "three-toed sloth", - "orangutan", - "gorilla", - "chimpanzee", - "gibbon", - "siamang", - "guenon", - "patas monkey", - "baboon", - "macaque", - "langur", - "black-and-white colobus", - "proboscis monkey", - "marmoset", - "white-headed capuchin", - "howler monkey", - "titi monkey", - "Geoffroy's spider monkey", - "common squirrel monkey", - "ring-tailed lemur", - "indri", - "Asian elephant", - "African bush elephant", - "red panda", - "giant panda", - 
"snoek fish", - "eel", - "silver salmon", - "rock beauty fish", - "clownfish", - "sturgeon", - "gar fish", - "lionfish", - "pufferfish", - "abacus", - "abaya", - "academic gown", - "accordion", - "acoustic guitar", - "aircraft carrier", - "airliner", - "airship", - "altar", - "ambulance", - "amphibious vehicle", - "analog clock", - "apiary", - "apron", - "trash can", - "assault rifle", - "backpack", - "bakery", - "balance beam", - "balloon", - "ballpoint pen", - "Band-Aid", - "banjo", - "baluster / handrail", - "barbell", - "barber chair", - "barbershop", - "barn", - "barometer", - "barrel", - "wheelbarrow", - "baseball", - "basketball", - "bassinet", - "bassoon", - "swimming cap", - "bath towel", - "bathtub", - "station wagon", - "lighthouse", - "beaker", - "military hat (bearskin or shako)", - "beer bottle", - "beer glass", - "bell tower", - "baby bib", - "tandem bicycle", - "bikini", - "ring binder", - "binoculars", - "birdhouse", - "boathouse", - "bobsleigh", - "bolo tie", - "poke bonnet", - "bookcase", - "bookstore", - "bottle cap", - "hunting bow", - "bow tie", - "brass memorial plaque", - "bra", - "breakwater", - "breastplate", - "broom", - "bucket", - "buckle", - "bulletproof vest", - "high-speed train", - "butcher shop", - "taxicab", - "cauldron", - "candle", - "cannon", - "canoe", - "can opener", - "cardigan", - "car mirror", - "carousel", - "tool kit", - "cardboard box / carton", - "car wheel", - "automated teller machine", - "cassette", - "cassette player", - "castle", - "catamaran", - "CD player", - "cello", - "mobile phone", - "chain", - "chain-link fence", - "chain mail", - "chainsaw", - "storage chest", - "chiffonier", - "bell or wind chime", - "china cabinet", - "Christmas stocking", - "church", - "movie theater", - "cleaver", - "cliff dwelling", - "cloak", - "clogs", - "cocktail shaker", - "coffee mug", - "coffeemaker", - "spiral or coil", - "combination lock", - "computer keyboard", - "candy store", - "container ship", - "convertible", - "corkscrew", - "cornet", - "cowboy boot", - "cowboy hat", - "cradle", - "construction crane", - "crash helmet", - "crate", - "infant bed", - "Crock Pot", - "croquet ball", - "crutch", - "cuirass", - "dam", - "desk", - "desktop computer", - "rotary dial telephone", - "diaper", - "digital clock", - "digital watch", - "dining table", - "dishcloth", - "dishwasher", - "disc brake", - "dock", - "dog sled", - "dome", - "doormat", - "drilling rig", - "drum", - "drumstick", - "dumbbell", - "Dutch oven", - "electric fan", - "electric guitar", - "electric locomotive", - "entertainment center", - "envelope", - "espresso machine", - "face powder", - "feather boa", - "filing cabinet", - "fireboat", - "fire truck", - "fire screen", - "flagpole", - "flute", - "folding chair", - "football helmet", - "forklift", - "fountain", - "fountain pen", - "four-poster bed", - "freight car", - "French horn", - "frying pan", - "fur coat", - "garbage truck", - "gas mask or respirator", - "gas pump", - "goblet", - "go-kart", - "golf ball", - "golf cart", - "gondola", - "gong", - "gown", - "grand piano", - "greenhouse", - "radiator grille", - "grocery store", - "guillotine", - "hair clip", - "hair spray", - "half-track", - "hammer", - "hamper", - "hair dryer", - "hand-held computer", - "handkerchief", - "hard disk drive", - "harmonica", - "harp", - "combine harvester", - "hatchet", - "holster", - "home theater", - "honeycomb", - "hook", - "hoop skirt", - "gymnastic horizontal bar", - "horse-drawn vehicle", - "hourglass", - "iPod", - "clothes iron", - "carved pumpkin", - 
"jeans", - "jeep", - "T-shirt", - "jigsaw puzzle", - "rickshaw", - "joystick", - "kimono", - "knee pad", - "knot", - "lab coat", - "ladle", - "lampshade", - "laptop computer", - "lawn mower", - "lens cap", - "letter opener", - "library", - "lifeboat", - "lighter", - "limousine", - "ocean liner", - "lipstick", - "slip-on shoe", - "lotion", - "music speaker", - "loupe magnifying glass", - "sawmill", - "magnetic compass", - "messenger bag", - "mailbox", - "tights", - "one-piece bathing suit", - "manhole cover", - "maraca", - "marimba", - "mask", - "matchstick", - "maypole", - "maze", - "measuring cup", - "medicine cabinet", - "megalith", - "microphone", - "microwave oven", - "military uniform", - "milk can", - "minibus", - "miniskirt", - "minivan", - "missile", - "mitten", - "mixing bowl", - "mobile home", - "ford model t", - "modem", - "monastery", - "monitor", - "moped", - "mortar and pestle", - "graduation cap", - "mosque", - "mosquito net", - "vespa", - "mountain bike", - "tent", - "computer mouse", - "mousetrap", - "moving van", - "muzzle", - "metal nail", - "neck brace", - "necklace", - "baby pacifier", - "notebook computer", - "obelisk", - "oboe", - "ocarina", - "odometer", - "oil filter", - "pipe organ", - "oscilloscope", - "overskirt", - "bullock cart", - "oxygen mask", - "product packet / packaging", - "paddle", - "paddle wheel", - "padlock", - "paintbrush", - "pajamas", - "palace", - "pan flute", - "paper towel", - "parachute", - "parallel bars", - "park bench", - "parking meter", - "railroad car", - "patio", - "payphone", - "pedestal", - "pencil case", - "pencil sharpener", - "perfume", - "Petri dish", - "photocopier", - "plectrum", - "Pickelhaube", - "picket fence", - "pickup truck", - "pier", - "piggy bank", - "pill bottle", - "pillow", - "ping-pong ball", - "pinwheel", - "pirate ship", - "drink pitcher", - "block plane", - "planetarium", - "plastic bag", - "plate rack", - "farm plow", - "plunger", - "Polaroid camera", - "pole", - "police van", - "poncho", - "pool table", - "soda bottle", - "plant pot", - "potter's wheel", - "power drill", - "prayer rug", - "printer", - "prison", - "missile", - "projector", - "hockey puck", - "punching bag", - "purse", - "quill", - "quilt", - "race car", - "racket", - "radiator", - "radio", - "radio telescope", - "rain barrel", - "recreational vehicle", - "fishing casting reel", - "reflex camera", - "refrigerator", - "remote control", - "restaurant", - "revolver", - "rifle", - "rocking chair", - "rotisserie", - "eraser", - "rugby ball", - "ruler measuring stick", - "sneaker", - "safe", - "safety pin", - "salt shaker", - "sandal", - "sarong", - "saxophone", - "scabbard", - "weighing scale", - "school bus", - "schooner", - "scoreboard", - "CRT monitor", - "screw", - "screwdriver", - "seat belt", - "sewing machine", - "shield", - "shoe store", - "shoji screen / room divider", - "shopping basket", - "shopping cart", - "shovel", - "shower cap", - "shower curtain", - "ski", - "balaclava ski mask", - "sleeping bag", - "slide rule", - "sliding door", - "slot machine", - "snorkel", - "snowmobile", - "snowplow", - "soap dispenser", - "soccer ball", - "sock", - "solar thermal collector", - "sombrero", - "soup bowl", - "keyboard space bar", - "space heater", - "space shuttle", - "spatula", - "motorboat", - "spider web", - "spindle", - "sports car", - "spotlight", - "stage", - "steam locomotive", - "through arch bridge", - "steel drum", - "stethoscope", - "scarf", - "stone wall", - "stopwatch", - "stove", - "strainer", - "tram", - "stretcher", - "couch", - 
"stupa", - "submarine", - "suit", - "sundial", - "sunglasses", - "sunglasses", - "sunscreen", - "suspension bridge", - "mop", - "sweatshirt", - "swim trunks / shorts", - "swing", - "electrical switch", - "syringe", - "table lamp", - "tank", - "tape player", - "teapot", - "teddy bear", - "television", - "tennis ball", - "thatched roof", - "front curtain", - "thimble", - "threshing machine", - "throne", - "tile roof", - "toaster", - "tobacco shop", - "toilet seat", - "torch", - "totem pole", - "tow truck", - "toy store", - "tractor", - "semi-trailer truck", - "tray", - "trench coat", - "tricycle", - "trimaran", - "tripod", - "triumphal arch", - "trolleybus", - "trombone", - "hot tub", - "turnstile", - "typewriter keyboard", - "umbrella", - "unicycle", - "upright piano", - "vacuum cleaner", - "vase", - "vaulted or arched ceiling", - "velvet fabric", - "vending machine", - "vestment", - "viaduct", - "violin", - "volleyball", - "waffle iron", - "wall clock", - "wallet", - "wardrobe", - "military aircraft", - "sink", - "washing machine", - "water bottle", - "water jug", - "water tower", - "whiskey jug", - "whistle", - "hair wig", - "window screen", - "window shade", - "Windsor tie", - "wine bottle", - "airplane wing", - "wok", - "wooden spoon", - "wool", - "split-rail fence", - "shipwreck", - "sailboat", - "yurt", - "website", - "comic book", - "crossword", - "traffic or street sign", - "traffic light", - "dust jacket", - "menu", - "plate", - "guacamole", - "consomme", - "hot pot", - "trifle", - "ice cream", - "popsicle", - "baguette", - "bagel", - "pretzel", - "cheeseburger", - "hot dog", - "mashed potatoes", - "cabbage", - "broccoli", - "cauliflower", - "zucchini", - "spaghetti squash", - "acorn squash", - "butternut squash", - "cucumber", - "artichoke", - "bell pepper", - "cardoon", - "mushroom", - "Granny Smith apple", - "strawberry", - "orange", - "lemon", - "fig", - "pineapple", - "banana", - "jackfruit", - "cherimoya (custard apple)", - "pomegranate", - "hay", - "carbonara", - "chocolate syrup", - "dough", - "meatloaf", - "pizza", - "pot pie", - "burrito", - "red wine", - "espresso", - "tea cup", - "eggnog", - "mountain", - "bubble", - "cliff", - "coral reef", - "geyser", - "lakeshore", - "promontory", - "sandbar", - "beach", - "valley", - "volcano", - "baseball player", - "bridegroom", - "scuba diver", - "rapeseed", - "daisy", - "yellow lady's slipper", - "corn", - "acorn", - "rose hip", - "horse chestnut seed", - "coral fungus", - "agaric", - "gyromitra", - "stinkhorn mushroom", - "earth star fungus", - "hen of the woods mushroom", - "bolete", - "corn cob", - "toilet paper", -] - - -openai_imagenet_template = [ - lambda c: f"a bad photo of a {c}.", - lambda c: f"a photo of many {c}.", - lambda c: f"a sculpture of a {c}.", - lambda c: f"a photo of the hard to see {c}.", - lambda c: f"a low resolution photo of the {c}.", - lambda c: f"a rendering of a {c}.", - lambda c: f"graffiti of a {c}.", - lambda c: f"a bad photo of the {c}.", - lambda c: f"a cropped photo of the {c}.", - lambda c: f"a tattoo of a {c}.", - lambda c: f"the embroidered {c}.", - lambda c: f"a photo of a hard to see {c}.", - lambda c: f"a bright photo of a {c}.", - lambda c: f"a photo of a clean {c}.", - lambda c: f"a photo of a dirty {c}.", - lambda c: f"a dark photo of the {c}.", - lambda c: f"a drawing of a {c}.", - lambda c: f"a photo of my {c}.", - lambda c: f"the plastic {c}.", - lambda c: f"a photo of the cool {c}.", - lambda c: f"a close-up photo of a {c}.", - lambda c: f"a black and white photo of the {c}.", - 
lambda c: f"a painting of the {c}.", - lambda c: f"a painting of a {c}.", - lambda c: f"a pixelated photo of the {c}.", - lambda c: f"a sculpture of the {c}.", - lambda c: f"a bright photo of the {c}.", - lambda c: f"a cropped photo of a {c}.", - lambda c: f"a plastic {c}.", - lambda c: f"a photo of the dirty {c}.", - lambda c: f"a jpeg corrupted photo of a {c}.", - lambda c: f"a blurry photo of the {c}.", - lambda c: f"a photo of the {c}.", - lambda c: f"a good photo of the {c}.", - lambda c: f"a rendering of the {c}.", - lambda c: f"a {c} in a video game.", - lambda c: f"a photo of one {c}.", - lambda c: f"a doodle of a {c}.", - lambda c: f"a close-up photo of the {c}.", - lambda c: f"a photo of a {c}.", - lambda c: f"the origami {c}.", - lambda c: f"the {c} in a video game.", - lambda c: f"a sketch of a {c}.", - lambda c: f"a doodle of the {c}.", - lambda c: f"a origami {c}.", - lambda c: f"a low resolution photo of a {c}.", - lambda c: f"the toy {c}.", - lambda c: f"a rendition of the {c}.", - lambda c: f"a photo of the clean {c}.", - lambda c: f"a photo of a large {c}.", - lambda c: f"a rendition of a {c}.", - lambda c: f"a photo of a nice {c}.", - lambda c: f"a photo of a weird {c}.", - lambda c: f"a blurry photo of a {c}.", - lambda c: f"a cartoon {c}.", - lambda c: f"art of a {c}.", - lambda c: f"a sketch of the {c}.", - lambda c: f"a embroidered {c}.", - lambda c: f"a pixelated photo of a {c}.", - lambda c: f"itap of the {c}.", - lambda c: f"a jpeg corrupted photo of the {c}.", - lambda c: f"a good photo of a {c}.", - lambda c: f"a plushie {c}.", - lambda c: f"a photo of the nice {c}.", - lambda c: f"a photo of the small {c}.", - lambda c: f"a photo of the weird {c}.", - lambda c: f"the cartoon {c}.", - lambda c: f"art of the {c}.", - lambda c: f"a drawing of the {c}.", - lambda c: f"a photo of the large {c}.", - lambda c: f"a black and white photo of a {c}.", - lambda c: f"the plushie {c}.", - lambda c: f"a dark photo of a {c}.", - lambda c: f"itap of a {c}.", - lambda c: f"graffiti of the {c}.", - lambda c: f"a toy {c}.", - lambda c: f"itap of my {c}.", - lambda c: f"a photo of a cool {c}.", - lambda c: f"a photo of a small {c}.", - lambda c: f"a tattoo of the {c}.", -] diff --git a/audioldm/clap/training/infer_demo.py b/audioldm/clap/training/infer_demo.py deleted file mode 100644 index 7d1f4784898dbfeb69affefb6f624711adc8cb42..0000000000000000000000000000000000000000 --- a/audioldm/clap/training/infer_demo.py +++ /dev/null @@ -1,105 +0,0 @@ -import sys - -import os -import torch -import librosa -from open_clip import create_model -from training.data import get_audio_features -from training.data import int16_to_float32, float32_to_int16 -from transformers import RobertaTokenizer - -tokenize = RobertaTokenizer.from_pretrained("roberta-base") - - -def tokenizer(text): - result = tokenize( - text, - padding="max_length", - truncation=True, - max_length=77, - return_tensors="pt", - ) - return {k: v.squeeze(0) for k, v in result.items()} - - -PRETRAINED_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/checkpoints/epoch_top_0_audioset_no_fusion.pt" -WAVE_48k_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/audio/machine.wav" - - -def infer_text(): - device = "cuda:0" if torch.cuda.is_available() else "cpu" - precision = "fp32" - amodel = "HTSAT-tiny" # or 'PANN-14' - tmodel = "roberta" # the best text encoder in our training - enable_fusion = False # False if you do not want to use the fusion model - 
fusion_type = "aff_2d" - pretrained = PRETRAINED_PATH - - model, model_cfg = create_model( - amodel, - tmodel, - pretrained, - precision=precision, - device=device, - enable_fusion=enable_fusion, - fusion_type=fusion_type, - ) - # load the text, can be a list (i.e. batch size) - text_data = ["I love the contrastive learning", "I love the pretrain model"] - # tokenize for roberta, if you want to tokenize for another text encoder, please refer to data.py#L43-90 - text_data = tokenizer(text_data) - - text_embed = model.get_text_embedding(text_data) - print(text_embed.size()) - - -def infer_audio(): - - device = "cuda:0" if torch.cuda.is_available() else "cpu" - precision = "fp32" - amodel = "HTSAT-tiny" # or 'PANN-14' - tmodel = "roberta" # the best text encoder in our training - enable_fusion = False # False if you do not want to use the fusion model - fusion_type = "aff_2d" - pretrained = PRETRAINED_PATH - - model, model_cfg = create_model( - amodel, - tmodel, - pretrained, - precision=precision, - device=device, - enable_fusion=enable_fusion, - fusion_type=fusion_type, - ) - - # load the waveform of the shape (T,), should resample to 48000 - audio_waveform, sr = librosa.load(WAVE_48k_PATH, sr=48000) - # quantize - audio_waveform = int16_to_float32(float32_to_int16(audio_waveform)) - audio_waveform = torch.from_numpy(audio_waveform).float() - audio_dict = {} - - # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode - import ipdb - - ipdb.set_trace() - audio_dict = get_audio_features( - audio_dict, - audio_waveform, - 480000, - data_truncating="fusion", - data_filling="repeatpad", - audio_cfg=model_cfg["audio_cfg"], - ) - # can send a list to the model, to process many audio tracks in one time (i.e. batch size) - audio_embed = model.get_audio_embedding([audio_dict]) - print(audio_embed.size()) - import ipdb - - ipdb.set_trace() - - -if __name__ == "__main__": - infer_text() - infer_audio() diff --git a/audioldm/clap/training/logger.py b/audioldm/clap/training/logger.py deleted file mode 100644 index ac4634970fae6aacde2b7b808355dbd50c90ce73..0000000000000000000000000000000000000000 --- a/audioldm/clap/training/logger.py +++ /dev/null @@ -1,30 +0,0 @@ -import logging - - -def setup_logging(log_file, level, include_host=False): - if include_host: - import socket - - hostname = socket.gethostname() - formatter = logging.Formatter( - f"%(asctime)s | {hostname} | %(levelname)s | %(message)s", - datefmt="%Y-%m-%d,%H:%M:%S", - ) - else: - formatter = logging.Formatter( - "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d,%H:%M:%S" - ) - - logging.root.setLevel(level) - loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] - for logger in loggers: - logger.setLevel(level) - - stream_handler = logging.StreamHandler() - stream_handler.setFormatter(formatter) - logging.root.addHandler(stream_handler) - - if log_file: - file_handler = logging.FileHandler(filename=log_file) - file_handler.setFormatter(formatter) - logging.root.addHandler(file_handler) diff --git a/audioldm/clap/training/lp_main.py b/audioldm/clap/training/lp_main.py deleted file mode 100644 index c2d4e8c85aaa3c8e4221963ef56a815cc14f354f..0000000000000000000000000000000000000000 --- a/audioldm/clap/training/lp_main.py +++ /dev/null @@ -1,670 +0,0 @@ -from cmath import cos -from inspect import getargs -import logging -import os -import random -from datetime import datetime -import bisect -import copy -from sched import scheduler -import numpy as np -import torch -import 
torch.backends.cudnn as cudnn -from torch import optim -from torch.cuda.amp import GradScaler -import faulthandler -import pathlib -import argparse -import time - -try: - import wandb -except ImportError: - wandb = None - -try: - import torch.utils.tensorboard as tensorboard -except ImportError: - tensorboard = None - -try: - import horovod.torch as hvd -except ImportError: - hvd = None - -from open_clip import create_model_and_transforms, trace_model, create_model -from training.data import get_data -from training.params import parse_args -from training.distributed import is_master, init_distributed_device, world_info_from_env -from training.logger import setup_logging -from training.scheduler import cosine_lr -from training.lp_train import train_one_epoch, evaluate -from open_clip.utils import get_tar_path_from_dataset_name, dataset_split, get_optimizer -from open_clip.utils import load_p, load_class_label -from open_clip.linear_probe import LinearProbe - - -def maintain_ckpts(args, startidx, all_idx_len): - for i in reversed(range(startidx, all_idx_len)): - if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")): - os.rename( - os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), - os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"), - ) - if os.path.exists( - os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt") - ): - os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")) - return - - -def update_top_k_performance( - new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True -): - """ - Record the top-k performance of the current epoch. - current_top_k_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...} - """ - if isinstance(new_metrics_inputs, (list, tuple)): - new_metrics_inputs = np.mean(new_metrics_inputs) - return update_top_k_performance( - new_metrics_inputs, - current_top_k_ckpt_metrics, - args=args, - ckpt=ckpt, - bignumbetter=bignumbetter, - ) - elif isinstance(new_metrics_inputs, dict): - new_metrics_inputs = np.mean(list(new_metrics_inputs.values())) - return update_top_k_performance( - new_metrics_inputs, - current_top_k_ckpt_metrics, - args=args, - ckpt=ckpt, - bignumbetter=bignumbetter, - ) - elif isinstance(new_metrics_inputs, (float, int)): - update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()} - sorted_keys = sorted(current_top_k_ckpt_metrics.keys()) - sorted_values = sorted( - current_top_k_ckpt_metrics.values(), reverse=bignumbetter - ) - sorted_values_ = copy.deepcopy(sorted_values) - sorted_values.append(new_metrics_inputs) - sorted_values = sorted(sorted_values, reverse=bignumbetter) - sorted_values = sorted_values[:-1] - - if sorted_values == sorted_values_: - return current_top_k_ckpt_metrics, new_metrics_inputs - else: - for i in range(len(sorted_keys)): - if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]: - current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i] - update_flag[sorted_keys[i]] = True - for i in range(len(update_flag)): - if update_flag[i]: - maintain_ckpts(args, i, len(sorted_keys)) - torch.save( - ckpt, - os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), - ) - break - return current_top_k_ckpt_metrics, new_metrics_inputs - - -# def updateifNone(a, b): -# a = b if None else a -# return a - - -def is_pretrained_params(n): - return ( - n.startswith("clap_model.transformer") - or n in ["clap_model.positional_embedding", "clap_model.text_projection"] - or 
n.startswith("clap_model.token_embedding") - or n.startswith("clap_model.ln_final") - or n.startswith("clap_model.logit_scale_t") - ) - - -def random_seed(seed=42, rank=0): - torch.manual_seed(seed + rank) - np.random.seed(seed + rank) - random.seed(seed + rank) - - -def config_lp_optimizer(model, data, args): - # set wd-related params to 0 if use adam optimizer - if args.optimizer == "adam": - args.wd = 0 - args.wd_pretrained = 0 - args.wd_new = 0 - - in_clap = lambda n, p: n.startswith("clap_model") - - named_parameters = list(model.named_parameters()) - - optimizer = {} - scheduler = {} - - # freeze text encoder - text_freeze_parameters = [ - p - for n, p in named_parameters - if n.startswith("clap_model.transformer") - or n in ["clap_model.positional_embedding", "clap_model.text_projection"] - or n.startswith("clap_model.token_embedding") - or n.startswith("clap_model.ln_final") - ] - - if args.freeze_text: - logging.info("Freeze Text!!!!") - for k in text_freeze_parameters: - k.requires_grad = False - - if not args.lp_freeze: - exclude = ( - lambda n, p: p.ndim < 2 - or "bn" in n - or "ln" in n - or "bias" in n - or "logit_scale" in n - ) - include = lambda n, p: not exclude(n, p) - - # (yusong): we do not split the learning rate anymore - # p for n, p in named_parameters if in_clap(n,p) and exclude(n, p) and p.requires_grad - gain_or_bias_params = [ - p for n, p in named_parameters if exclude(n, p) and p.requires_grad - ] - # rest_params = [p for n, p in named_parameters if in_clap(n,p) and include(n, p) and p.requires_grad] - rest_params = [ - p for n, p in named_parameters if include(n, p) and p.requires_grad - ] - - if args.train_data is None: - optimizer = None - scheduler = None - else: - total_steps = data["train"].dataloader.num_batches * args.epochs - - if args.split_opt: - for x in ["lr", "beta1", "beta2", "eps", "wd"]: - for y in ["_new", "_pretrained"]: - if getattr(args, x + y) is None: - setattr(args, x + y, getattr(args, x)) - - gain_or_bias_pretrained_params = [ - p - for n, p in named_parameters - if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n) - ] - rest_pretrained_params = [ - p - for n, p in named_parameters - if (include(n, p) and p.requires_grad) and is_pretrained_params(n) - ] - gain_or_bias_new_params = [ - p - for n, p in named_parameters - if (exclude(n, p) and p.requires_grad) - and (not is_pretrained_params(n)) - ] - rest_new_params = [ - p - for n, p in named_parameters - if (include(n, p) and p.requires_grad) - and (not is_pretrained_params(n)) - ] - - pretrained_params_optimizer = get_optimizer( - [ - {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0}, - { - "params": rest_pretrained_params, - "weight_decay": args.wd_pretrained, - }, - ], - lr=args.lr_pretrained, - betas=(args.beta1_pretrained, args.beta2_pretrained), - eps=args.eps_pretrained, - momentum=args.momentum_pretrained, - optimizer_name=args.optimizer, - ) - pretrained_params_scheduler = cosine_lr( - pretrained_params_optimizer, - args.lr_pretrained, - args.warmup, - total_steps, - ) - - new_params_optimizer = get_optimizer( - [ - {"params": gain_or_bias_new_params, "weight_decay": 0.0}, - {"params": rest_new_params, "weight_decay": args.wd_new}, - ], - lr=args.lr_new, - betas=(args.beta1_new, args.beta2_new), - eps=args.eps_new, - momentum=args.momentum_new, - optimizer_name=args.optimizer, - ) - new_params_scheduler = cosine_lr( - new_params_optimizer, args.lr_new, args.warmup, total_steps - ) - - optimizer["text"] = pretrained_params_optimizer - 
optimizer["audio"] = new_params_optimizer - scheduler["text"] = pretrained_params_scheduler - scheduler["audio"] = new_params_scheduler - - if args.horovod: - pretrained_params_optimizer = hvd.DistributedOptimizer( - pretrained_params_optimizer, - named_parameters=model.named_parameters(), - ) - new_params_optimizer = hvd.DistributedOptimizer( - new_params_optimizer, named_parameters=model.named_parameters() - ) - hvd.broadcast_parameters(model.state_dict(), root_rank=0) - hvd.broadcast_optimizer_state( - pretrained_params_optimizer, root_rank=0 - ) - hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0) - else: - - optimizer["clap"] = get_optimizer( - [ - {"params": gain_or_bias_params, "weight_decay": 0.0}, - {"params": rest_params, "weight_decay": args.wd}, - ], - lr=args.lr, - betas=(args.beta1, args.beta2), - eps=args.eps, - momentum=args.momentum, - optimizer_name=args.optimizer, - ) - scheduler["clap"] = cosine_lr( - optimizer["clap"], args.lr, args.warmup, total_steps - ) - - if args.horovod: - optimizer["clap"] = hvd.DistributedOptimizer( - optimizer["clap"], named_parameters=model.named_parameters() - ) - hvd.broadcast_parameters(model.state_dict(), root_rank=0) - hvd.broadcast_optimizer_state(optimizer["clap"], root_rank=0) - - # linear probe optimizer - else: - lp_params = [ - p for n, p in named_parameters if (not in_clap(n, p)) and p.requires_grad - ] - lp_optim = get_optimizer( - lp_params, - lr=args.lp_lr, - betas=(args.beta1, args.beta2), - eps=args.eps, - momentum=0.9, - optimizer_name=args.optimizer, - ) - optimizer["lp"] = lp_optim - - return optimizer, scheduler, text_freeze_parameters - - -def main(): - args = parse_args() - - time.sleep(args.sleep) - - # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? 
- args.amodel = args.amodel.replace("/", "-") - # download sizes.json file - - # (yusong): the below two lines are for debug - # print("setting up faulthandler") - # faulthandler.register(10) - - random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - np.random.seed(args.seed) - args.class_index_dict = load_class_label(args.class_label_path) - - # get the name of the experiments - if args.name is None: - args.name = "-".join( - [ - datetime.now().strftime("%Y_%m_%d-%H_%M_%S"), - f"linear_probe" f"model_{args.amodel}", - f"lr_{args.lr}", - f"b_{args.batch_size}", - f"j_{args.workers}", - f"p_{args.precision}", - ] - ) - - # discover initial world args early so we can log properly - args.distributed = False - args.local_rank, args.rank, args.world_size = world_info_from_env() - - if args.remotedata and is_master(args): - for dataset_name in args.datasetnames: - for split in dataset_split[dataset_name]: - if not os.path.exists(f"./json_files/{dataset_name}/{split}"): - os.makedirs(f"./json_files/{dataset_name}/{split}") - os.system( - f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" - ) - - args.log_path = None - if is_master(args, local=args.log_local): - log_base_path = os.path.join(args.logs, args.name) - os.makedirs(log_base_path, exist_ok=True) - log_filename = f"out-{args.rank}" if args.log_local else "out.log" - args.log_path = os.path.join(log_base_path, log_filename) - - # avoid log dir in same name: - postfix = 0 - while os.path.exists(args.log_path): - postfix += 1 - log_base_path_new = log_base_path + "-" + str(postfix) - os.makedirs(log_base_path_new, exist_ok=True) - log_filename = f"out-{args.rank}" if args.log_local else "out.log" - args.log_path = os.path.join(log_base_path_new, log_filename) - # print( - # "Error. Experiment already exists. Use --name {} to specify a new experiment." - # ) - # return -1 - - # Set logger - args.log_level = logging.DEBUG if args.debug else logging.INFO - setup_logging(args.log_path, args.log_level) - - # fully initialize distributed device environment - device = init_distributed_device(args) - - args.wandb = "wandb" in args.report_to or "all" in args.report_to - args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to - if is_master(args): - args.tensorboard_path = ( - os.path.join(args.logs, args.name, "tensorboard") - if args.tensorboard - else "" - ) - args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") - for dirname in [args.tensorboard_path, args.checkpoint_path]: - if dirname: - os.makedirs(dirname, exist_ok=True) - else: - args.tensorboard_path = "" - args.checkpoint_path = "" - - if args.copy_codebase: - copy_codebase(args) - - assert args.precision in ["amp", "fp16", "fp32"] - if args.precision == "fp16": - logging.warning( - "It is recommended to use AMP mixed-precision instead of FP16. " - "FP16 support needs further verification and tuning, especially for train." - ) - - if args.horovod: - logging.info( - f"Running in horovod mode with multiple processes / nodes. Device: {args.device}." - f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." - ) - elif args.distributed: - logging.info( - f"Running in distributed mode with multiple processes. Device: {args.device}." - f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." - ) - else: - logging.info(f"Running with a single process. 
Device {args.device}.") - - logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}") - - # Create CLAP model - clap_model, clap_model_cfg = create_model( - args.amodel, - args.tmodel, - args.pretrained, - precision=args.precision, - device=device, - jit=args.torchscript, - force_quick_gelu=args.force_quick_gelu, - openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), - skip_params=False, - pretrained_audio=args.pretrained_audio, - pretrained_text=args.pretrained_text, - enable_fusion=args.enable_fusion, - fusion_type=args.fusion_type, - ) - - args.lp_out_ch = len(list(args.class_index_dict.keys())) - # Linear Probe - logging.info(f"linear probe using mlp: {args.lp_mlp}") - logging.info(f"linear probe using freeze: {args.lp_freeze}") - logging.info(f"linear probe act layer: {args.lp_act}") - logging.info(f"linear probe out ch: {args.lp_out_ch}") - logging.info(f"linear probe learning rate (if applicable): {args.lp_lr}") - logging.info(f"linear probe loss func: {args.lp_loss}") - logging.info(f"linear probe lp_metrics: {args.lp_metrics}") - - model = LinearProbe( - clap_model, - mlp=args.lp_mlp, - freeze=args.lp_freeze, - in_ch=512, - out_ch=args.lp_out_ch, - act=args.lp_act, - ) # in_ch is fixed (i.e., 512) - model = model.to(device) - - if args.horovod: - with torch.no_grad(): - for param in model.parameters(): - param.set_(param.contiguous()) - - if args.trace: - model = trace_model(model, batch_size=args.batch_size, device=device) - - if is_master(args): - logging.info("Linear Probe CLAP Model:") - logging.info(f"{str(clap_model)}") - logging.info("Params:") - params_file = os.path.join(args.logs, args.name, "params.txt") - with open(params_file, "w") as f: - for name in sorted(vars(args)): - val = getattr(args, name) - logging.info(f" {name}: {val}") - f.write(f"{name}: {val}\n") - - if args.distributed and not args.horovod: - if args.use_bn_sync: - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - ddp_args = {} - if args.ddp_static_graph: - # this doesn't exist in older PyTorch, arg only added if enabled - ddp_args["static_graph"] = True - model = torch.nn.parallel.DistributedDataParallel( - model, device_ids=[device], find_unused_parameters=True, **ddp_args - ) - - data = get_data(args, clap_model_cfg) - assert len(data), "At least one train or eval dataset must be specified." 
- if args.trace: - assert "train" not in data, "Cannot train with traced model" - - optimizer, scheduler, text_freeze_parameters = config_lp_optimizer( - model, data, args - ) - - scaler = GradScaler() if args.precision == "amp" else None - - # optionally resume from a checkpoint - start_epoch = 0 - if args.resume is not None: - if os.path.isfile(args.resume): - checkpoint = torch.load(args.resume, map_location=device) - if "epoch" in checkpoint: - # resuming a train checkpoint w/ epoch and optimizer state - start_epoch = checkpoint["epoch"] - sd = checkpoint["state_dict"] - if not args.distributed and next(iter(sd.items()))[0].startswith( - "module" - ): - sd = {k[len("module.") :]: v for k, v in sd.items()} - model.load_state_dict(sd) - if args.split_opt: - if optimizer is not None: - for k, o_ in optimizer.items(): - o_.load_state_dict(checkpoint[k + "_" + "optimizer"]) - if optimizer is not None: - optimizer.load_state_dict(checkpoint["optimizer"]) - if scaler is not None and "scaler" in checkpoint: - scaler.load_state_dict(checkpoint["scaler"]) - logging.info( - f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})" - ) - else: - # loading a bare (model only) checkpoint for fine-tune or evaluation - model.load_state_dict(checkpoint) - logging.info( - f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})" - ) - if args.freeze_text: - print("Freeze Text!!!!") - for k in text_freeze_parameters: - k.requires_grad = False - else: - logging.info("=> no checkpoint found at '{}'".format(args.resume)) - - cudnn.benchmark = True - cudnn.deterministic = False - - # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 - args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args) - writer = None - if args.save_logs and args.tensorboard: - assert tensorboard is not None, "Please install tensorboard." - writer = tensorboard.SummaryWriter(args.tensorboard_path) - - if args.wandb and is_master(args): - assert wandb is not None, "Please install wandb." - logging.debug("Starting wandb.") - args.train_sz = data["train"].dataloader.num_samples - if args.val_data is not None: - args.val_sz = data["val"].dataloader.num_samples - # you will have to configure this for your project! 
- wandb.init( - project="clap", - notes=args.wandb_notes, - name=args.wandb_notes, - tags=[], - config=vars(args), - ) - if args.debug: - wandb.watch(model, log="all") - wandb.save(params_file) - logging.debug("Finished loading wandb.") - - if "train" not in data: - evaluate(model, data, start_epoch, args, writer) - return - elif start_epoch == 0 and "val" in data and not args.no_eval: - evaluate(model, data, 0, args, writer) - if args.save_top_performance: - current_top_k_ckpt_metrics = { - i: 0 for i in range(args.save_top_performance) - } # initialize the top-k metric for ckpts to 0 - - for epoch in range(start_epoch, args.epochs): - # freeze the text param after (include) args.freeze_text_after, this is -1 by default - if epoch == args.freeze_text_after: - print("Text pretrained parameters are freezed since this epoch.") - for k in text_freeze_parameters: - k.requires_grad = False - if is_master(args): - logging.info(f"Start epoch {epoch}") - - train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) - completed_epoch = epoch + 1 - - if ( - any(v in data for v in ("val", "imagenet-val", "imagenet-v2")) - and not args.no_eval - ): - metrics = evaluate(model, data, completed_epoch, args, writer) - if args.save_top_performance: - top_k_dataset = args.top_k_checkpoint_select_dataset - top_k_metric = args.top_k_checkpoint_select_metric - filtered_metrics = [ - v - for k, v in metrics.items() - if top_k_metric in k and top_k_dataset in k - ] # check all R@10 metrics (all dataset) and use it to update the ckpt - # Saving checkpoints. - if args.save_logs: - opt_dict = { - k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items() - } - checkpoint_dict = { - "epoch": completed_epoch, - "name": args.name, - "state_dict": model.state_dict(), - } - checkpoint_dict.update(opt_dict) - if scaler is not None: - checkpoint_dict["scaler"] = scaler.state_dict() - - if completed_epoch == args.epochs or ( - args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 - ): - torch.save( - checkpoint_dict, - os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), - ) - if args.save_most_recent: - torch.save( - checkpoint_dict, - os.path.join(args.checkpoint_path, f"epoch_latest.pt"), - ) - if args.save_top_performance and not args.no_eval: - update_top_k_performance( - filtered_metrics, - current_top_k_ckpt_metrics, - args, - checkpoint_dict, - bignumbetter=True, - ) - - if args.wandb and is_master(args): - wandb.finish() - - -def copy_codebase(args): - from shutil import copytree, ignore_patterns - - new_code_path = os.path.join(args.logs, args.name, "code") - if os.path.exists(new_code_path): - print( - f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." 
- ) - return -1 - print(f"Copying codebase to {new_code_path}") - current_code_path = os.path.realpath(__file__) - for _ in range(3): - current_code_path = os.path.dirname(current_code_path) - copytree( - current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb") - ) - print("Done copying code.") - return 1 - - -if __name__ == "__main__": - main() diff --git a/audioldm/clap/training/lp_train.py b/audioldm/clap/training/lp_train.py deleted file mode 100644 index 24a19bacd0a4b789415cfccbce1f8bc99bc493ed..0000000000000000000000000000000000000000 --- a/audioldm/clap/training/lp_train.py +++ /dev/null @@ -1,301 +0,0 @@ -import json -import logging -import math -import os -import time -from contextlib import suppress - -import numpy as np -import torch -import torch.nn.functional as F - -try: - import wandb -except ImportError: - wandb = None - -from open_clip import LPLoss, LPMetrics, lp_gather_features -from open_clip.utils import do_mixup, get_mix_lambda -from .distributed import is_master -from .zero_shot import zero_shot_eval - - -class AverageMeter(object): - """Computes and stores the average and current value""" - - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - -def unwrap_model(model): - if hasattr(model, "module"): - return model.module - else: - return model - - -def train_one_epoch( - model, - data, - epoch, - optimizer, - scaler, - scheduler, - args, - tb_writer=None, - extra_suffix="", -): - device = torch.device(args.device) - autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress - model.train() - loss = LPLoss(args.lp_loss) - - dataloader, sampler = data["train"].dataloader, data["train"].sampler - if args.distributed and sampler is not None: - sampler.set_epoch(epoch) - num_batches_per_epoch = dataloader.num_batches - sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) - - # for toy dataset - if args.dataset_type == "toy": - dataloader.dataset.generate_queue() - - loss_m = AverageMeter() - batch_time_m = AverageMeter() - data_time_m = AverageMeter() - end = time.time() - - for i, batch in enumerate(dataloader): - step = num_batches_per_epoch * epoch + i - - if isinstance(scheduler, dict): - for s in scheduler.values(): - s(step) - else: - scheduler(step) - - audio = batch # contains mel_spec, wavform, and longer list - class_label = batch["class_label"] - # audio = audio.to(device=device, non_blocking=True) - class_label = class_label.to(device=device, non_blocking=True) - - if args.mixup: - # https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/utils.py#L146 - mix_lambda = torch.from_numpy( - get_mix_lambda(0.5, len(audio["waveform"])) - ).to(device) - class_label = do_mixup(class_label, mix_lambda) - else: - mix_lambda = None - - data_time_m.update(time.time() - end) - if isinstance(optimizer, dict): - for o_ in optimizer.values(): - o_.zero_grad() - else: - optimizer.zero_grad() - - with autocast(): - pred = model(audio, mix_lambda=mix_lambda, device=device) - total_loss = loss(pred, class_label) - - if isinstance(optimizer, dict): - if scaler is not None: - scaler.scale(total_loss).backward() - for o_ in optimizer.values(): - if args.horovod: - o_.synchronize() - scaler.unscale_(o_) - with o_.skip_synchronize(): - scaler.step(o_) - else: - scaler.step(o_) - scaler.update() - else: - total_loss.backward() - for 
o_ in optimizer.values(): - o_.step() - else: - if scaler is not None: - scaler.scale(total_loss).backward() - if args.horovod: - optimizer.synchronize() - scaler.unscale_(optimizer) - with optimizer.skip_synchronize(): - scaler.step(optimizer) - else: - scaler.step(optimizer) - scaler.update() - else: - total_loss.backward() - optimizer.step() - - # Note: we clamp to 4.6052 = ln(100), as in the original paper. - with torch.no_grad(): - unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100)) - unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100)) - - batch_time_m.update(time.time() - end) - end = time.time() - batch_count = i + 1 - - if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): - if isinstance(audio, dict): - batch_size = len(audio["waveform"]) - else: - batch_size = len(audio) - num_samples = batch_count * batch_size * args.world_size - samples_per_epoch = dataloader.num_samples - percent_complete = 100.0 * batch_count / num_batches_per_epoch - - # NOTE loss is coarsely sampled, just master node and per log update - loss_m.update(total_loss.item(), batch_size) - if isinstance(optimizer, dict): - logging.info( - f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " - f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " - f"Data (t): {data_time_m.avg:.3f} " - f"Batch (t): {batch_time_m.avg:.3f} " - f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]}" - ) - log_data = { - "loss": loss_m.val, - "data_time": data_time_m.val, - "batch_time": batch_time_m.val, - "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], - } - else: - logging.info( - f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " - f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " - f"Data (t): {data_time_m.avg:.3f} " - f"Batch (t): {batch_time_m.avg:.3f} " - f"LR: {optimizer.param_groups[0]['lr']:5f} " - ) - - # Save train loss / etc. Using non avg meter values as loggers have their own smoothing - log_data = { - "loss": loss_m.val, - "data_time": data_time_m.val, - "batch_time": batch_time_m.val, - "lr": optimizer.param_groups[0]["lr"], - } - for name, val in log_data.items(): - name = f"train{extra_suffix}/{name}" - if tb_writer is not None: - tb_writer.add_scalar(name, val, step) - if args.wandb: - assert wandb is not None, "Please install wandb." 
- wandb.log({name: val, "step": step}) - - # resetting batch / data time meters per log window - batch_time_m.reset() - data_time_m.reset() - # end for - - -def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""): - metrics = {} - if not args.parallel_eval: - if not is_master(args): - return metrics - device = torch.device(args.device) - model.eval() - - # CHANGE - # zero_shot_metrics = zero_shot_eval(model, data, epoch, args) - # metrics.update(zero_shot_metrics) - if is_master(args): - print("Evaluating...") - metric_names = args.lp_metrics.split(",") - eval_tool = LPMetrics(metric_names=metric_names) - - autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress - if "val" in data and ( - args.val_frequency - and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) - ): - if args.parallel_eval: - dataloader, sampler = data["val"].dataloader, data["val"].sampler - if args.distributed and sampler is not None: - sampler.set_epoch(epoch) - samples_per_val = dataloader.num_samples - else: - dataloader = data["val"].dataloader - num_samples = 0 - samples_per_val = dataloader.num_samples - - eval_info = {"pred": [], "target": []} - with torch.no_grad(): - for i, batch in enumerate(dataloader): - audio = batch # contains mel_spec, wavform, and longer list - class_label = batch["class_label"] - - # audio = audio.to(device=device, non_blocking=True) - class_label = class_label.to(device=device, non_blocking=True) - - with autocast(): - pred = model(audio, device=device) - if args.parallel_eval: - pred, class_label = lp_gather_features( - pred, class_label, args.world_size, args.horovod - ) - eval_info["pred"].append(pred) - eval_info["target"].append(class_label) - - num_samples += class_label.shape[0] - - if (i % 100) == 0: # and i != 0: - logging.info( - f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]" - ) - - if is_master(args): - eval_info["pred"] = torch.cat(eval_info["pred"], 0).cpu() - eval_info["target"] = torch.cat(eval_info["target"], 0).cpu() - metric_dict = eval_tool.evaluate_mertics( - eval_info["pred"], eval_info["target"] - ) - metrics.update(metric_dict) - if "epoch" not in metrics.keys(): - metrics.update({"epoch": epoch}) - - if is_master(args): - if not metrics: - return metrics - - logging.info( - f"Eval Epoch: {epoch} " - + "\n".join( - ["\t".join([f"{m}: {round(metrics[m], 4):.4f}"]) for m in metrics] - ) - ) - if args.save_logs: - for name, val in metrics.items(): - if tb_writer is not None: - tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch) - - with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: - f.write(json.dumps(metrics)) - f.write("\n") - - if args.wandb: - assert wandb is not None, "Please install wandb." 
- for name, val in metrics.items(): - wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch}) - - return metrics - else: - return metrics diff --git a/audioldm/clap/training/main.py b/audioldm/clap/training/main.py deleted file mode 100644 index 3b563a5d001be7adfbe779dee7ad8ac49aadc50d..0000000000000000000000000000000000000000 --- a/audioldm/clap/training/main.py +++ /dev/null @@ -1,596 +0,0 @@ -from inspect import getargs -import logging -import os -import random -from datetime import datetime -import bisect -import copy -import numpy as np -import torch -import torch.backends.cudnn as cudnn -from torch import optim -from torch.cuda.amp import GradScaler -import faulthandler -import pathlib - -try: - import wandb -except ImportError: - wandb = None - -try: - import torch.utils.tensorboard as tensorboard -except ImportError: - tensorboard = None - -try: - import horovod.torch as hvd -except ImportError: - hvd = None - -from open_clip import create_model_and_transforms, trace_model, create_model -from training.data import get_data -from training.distributed import is_master, init_distributed_device, world_info_from_env -from training.logger import setup_logging -from training.params import parse_args -from training.scheduler import cosine_lr -from training.train import train_one_epoch, evaluate -from open_clip.utils import dataset_split, get_optimizer - - -def maintain_ckpts(args, startidx, all_idx_len): - for i in reversed(range(startidx, all_idx_len)): - if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")): - os.rename( - os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), - os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"), - ) - if os.path.exists( - os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt") - ): - os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")) - return - - -def update_top_k_performance( - new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True -): - """ - Record the top-k performance of the current epoch. 
- current_top_k_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...} - """ - if isinstance(new_metrics_inputs, (list, tuple)): - new_metrics_inputs = np.mean(new_metrics_inputs) - return update_top_k_performance( - new_metrics_inputs, - current_top_k_ckpt_metrics, - args=args, - ckpt=ckpt, - bignumbetter=bignumbetter, - ) - elif isinstance(new_metrics_inputs, dict): - new_metrics_inputs = np.mean(list(new_metrics_inputs.values())) - return update_top_k_performance( - new_metrics_inputs, - current_top_k_ckpt_metrics, - args=args, - ckpt=ckpt, - bignumbetter=bignumbetter, - ) - elif isinstance(new_metrics_inputs, (float, int)): - update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()} - sorted_keys = sorted(current_top_k_ckpt_metrics.keys()) - sorted_values = sorted( - current_top_k_ckpt_metrics.values(), reverse=bignumbetter - ) - sorted_values_ = copy.deepcopy(sorted_values) - sorted_values.append(new_metrics_inputs) - sorted_values = sorted(sorted_values, reverse=bignumbetter) - sorted_values = sorted_values[:-1] - - if sorted_values == sorted_values_: - return current_top_k_ckpt_metrics, new_metrics_inputs - else: - for i in range(len(sorted_keys)): - if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]: - current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i] - update_flag[sorted_keys[i]] = True - for i in range(len(update_flag)): - if update_flag[i]: - maintain_ckpts(args, i, len(sorted_keys)) - torch.save( - ckpt, - os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), - ) - break - return current_top_k_ckpt_metrics, new_metrics_inputs - - -# def updateifNone(a, b): -# a = b if None else a -# return a - - -def is_pretrained_params(n): - return ( - n.startswith("transformer") - or n in ["positional_embedding", "text_projection"] - or n.startswith("token_embedding") - or n.startswith("ln_final") - or n.startswith("logit_scale_t") - ) - - -def random_seed(seed=42, rank=0): - torch.manual_seed(seed + rank) - np.random.seed(seed + rank) - random.seed(seed + rank) - - -def main(): - args = parse_args() - # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? - args.amodel = args.amodel.replace("/", "-") - # download sizes.json file - - # (yusong): the below two lines are for debug - # print("setting up faulthandler") - # faulthandler.register(10) - - random.seed(args.seed) - torch.manual_seed(args.seed) - torch.cuda.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - np.random.seed(args.seed) - if args.tmodel == "bert" or args.tmodel == "roberta" or args.tmodel == "bart": - assert ( - args.pretrained == "" or args.pretrained is None - ), "bert/roberta/bart text encoder does not support pretrained models." 
- - # get the name of the experiments - if args.name is None: - args.name = "-".join( - [ - datetime.now().strftime("%Y_%m_%d-%H_%M_%S"), - f"model_{args.amodel}", - f"lr_{args.lr}", - f"b_{args.batch_size}", - f"j_{args.workers}", - f"p_{args.precision}", - ] - ) - - # discover initial world args early so we can log properly - args.distributed = False - args.local_rank, args.rank, args.world_size = world_info_from_env() - - if args.remotedata and is_master(args): - for dataset_name in args.datasetnames: - for split in dataset_split[dataset_name]: - if not os.path.exists(f"./json_files/{dataset_name}/{split}"): - os.makedirs(f"./json_files/{dataset_name}/{split}") - os.system( - f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" - ) - - args.log_path = None - if is_master(args, local=args.log_local): - log_base_path = os.path.join(args.logs, args.name) - os.makedirs(log_base_path, exist_ok=True) - log_filename = f"out-{args.rank}" if args.log_local else "out.log" - args.log_path = os.path.join(log_base_path, log_filename) - if os.path.exists(args.log_path): - print( - "Error. Experiment already exists. Use --name {} to specify a new experiment." - ) - return -1 - - # Set logger - args.log_level = logging.DEBUG if args.debug else logging.INFO - setup_logging(args.log_path, args.log_level) - - # fully initialize distributed device environment - device = init_distributed_device(args) - - args.wandb = "wandb" in args.report_to or "all" in args.report_to - args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to - if is_master(args): - args.tensorboard_path = ( - os.path.join(args.logs, args.name, "tensorboard") - if args.tensorboard - else "" - ) - args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") - for dirname in [args.tensorboard_path, args.checkpoint_path]: - if dirname: - os.makedirs(dirname, exist_ok=True) - else: - args.tensorboard_path = "" - args.checkpoint_path = "" - - if args.copy_codebase: - copy_codebase(args) - - assert args.precision in ["amp", "fp16", "fp32"] - if args.precision == "fp16": - logging.warning( - "It is recommended to use AMP mixed-precision instead of FP16. " - "FP16 support needs further verification and tuning, especially for train." - ) - - if args.horovod: - logging.info( - f"Running in horovod mode with multiple processes / nodes. Device: {args.device}." - f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." - ) - elif args.distributed: - logging.info( - f"Running in distributed mode with multiple processes. Device: {args.device}." - f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." - ) - else: - logging.info(f"Running with a single process. 
Device {args.device}.") - - logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}") - - model, model_cfg = create_model( - args.amodel, - args.tmodel, - args.pretrained, - precision=args.precision, - device=device, - jit=args.torchscript, - force_quick_gelu=args.force_quick_gelu, - openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), - skip_params=True, - pretrained_audio=args.pretrained_audio, - pretrained_text=args.pretrained_text, - enable_fusion=args.enable_fusion, - fusion_type=args.fusion_type, - ) - - if args.horovod: - with torch.no_grad(): - for param in model.parameters(): - param.set_(param.contiguous()) - - if args.trace: - model = trace_model(model, batch_size=args.batch_size, device=device) - - if is_master(args): - logging.info("Model:") - logging.info(f"{str(model)}") - logging.info("Params:") - params_file = os.path.join(args.logs, args.name, "params.txt") - with open(params_file, "w") as f: - for name in sorted(vars(args)): - val = getattr(args, name) - logging.info(f" {name}: {val}") - f.write(f"{name}: {val}\n") - - if args.distributed and not args.horovod: - if args.use_bn_sync: - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - ddp_args = {} - if args.ddp_static_graph: - # this doesn't exist in older PyTorch, arg only added if enabled - ddp_args["static_graph"] = True - model = torch.nn.parallel.DistributedDataParallel( - model, device_ids=[device], find_unused_parameters=True, **ddp_args - ) - - data = get_data(args, model_cfg) - assert len(data), "At least one train or eval dataset must be specified." - if args.trace: - assert "train" not in data, "Cannot train with traced model" - - exclude = ( - lambda n, p: p.ndim < 2 - or "bn" in n - or "ln" in n - or "bias" in n - or "logit_scale" in n - ) - include = lambda n, p: not exclude(n, p) - - named_parameters = list(model.named_parameters()) - - # freeze text encoder - text_freeze_parameters = [p for n, p in named_parameters if "text_branch" in n] - - if args.freeze_text: - print("Freeze Text!!!!") - for k in text_freeze_parameters: - k.requires_grad = False - - gain_or_bias_params = [ - p for n, p in named_parameters if exclude(n, p) and p.requires_grad - ] - rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] - - # set wd-related params to 0 if use adam optimizer - if args.optimizer == "adam": - args.wd = 0 - args.wd_pretrained = 0 - args.wd_new = 0 - - if args.train_data is None: - optimizer = None - scheduler = None - else: - total_steps = data["train"].dataloader.num_batches * args.epochs - - if args.split_opt: - for x in ["lr", "beta1", "beta2", "eps", "wd"]: - for y in ["_new", "_pretrained"]: - if getattr(args, x + y) is None: - setattr(args, x + y, getattr(args, x)) - - gain_or_bias_pretrained_params = [ - p - for n, p in named_parameters - if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n) - ] - rest_pretrained_params = [ - p - for n, p in named_parameters - if (include(n, p) and p.requires_grad) and is_pretrained_params(n) - ] - gain_or_bias_new_params = [ - p - for n, p in named_parameters - if (exclude(n, p) and p.requires_grad) and (not is_pretrained_params(n)) - ] - rest_new_params = [ - p - for n, p in named_parameters - if (include(n, p) and p.requires_grad) and (not is_pretrained_params(n)) - ] - pretrained_params_optimizer = get_optimizer( - [ - {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0}, - { - "params": rest_pretrained_params, - "weight_decay": args.wd_pretrained, - }, 
- ], - lr=args.lr_pretrained, - betas=(args.beta1_pretrained, args.beta2_pretrained), - eps=args.eps_pretrained, - momentum=args.momentum_pretrained, - optimizer_name=args.optimizer, - ) - pretrained_params_scheduler = cosine_lr( - pretrained_params_optimizer, - args.lr_pretrained, - args.warmup, - total_steps, - ) - new_params_optimizer = get_optimizer( - [ - {"params": gain_or_bias_new_params, "weight_decay": 0.0}, - {"params": rest_new_params, "weight_decay": args.wd_new}, - ], - lr=args.lr_new, - betas=(args.beta1_new, args.beta2_new), - eps=args.eps_new, - momentum=args.momentum_new, - optimizer_name=args.optimizer, - ) - - new_params_scheduler = cosine_lr( - new_params_optimizer, args.lr_new, args.warmup, total_steps - ) - - optimizer = { - "pretrained": pretrained_params_optimizer, - "new": new_params_optimizer, - } - scheduler = { - "pretrained": pretrained_params_scheduler, - "new": new_params_scheduler, - } - - if args.horovod: - pretrained_params_optimizer = hvd.DistributedOptimizer( - pretrained_params_optimizer, - named_parameters=model.named_parameters(), - ) - new_params_optimizer = hvd.DistributedOptimizer( - new_params_optimizer, named_parameters=model.named_parameters() - ) - hvd.broadcast_parameters(model.state_dict(), root_rank=0) - hvd.broadcast_optimizer_state(pretrained_params_optimizer, root_rank=0) - hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0) - else: - optimizer = get_optimizer( - [ - {"params": gain_or_bias_params, "weight_decay": 0.0}, - {"params": rest_params, "weight_decay": args.wd}, - ], - lr=args.lr, - betas=(args.beta1, args.beta2), - eps=args.eps, - momentum=args.momentum, - optimizer_name=args.optimizer, - ) - - scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps) - - if args.horovod: - optimizer = hvd.DistributedOptimizer( - optimizer, named_parameters=model.named_parameters() - ) - hvd.broadcast_parameters(model.state_dict(), root_rank=0) - hvd.broadcast_optimizer_state(optimizer, root_rank=0) - - scaler = GradScaler() if args.precision == "amp" else None - - # optionally resume from a checkpoint - start_epoch = 0 - if args.resume is not None: - if os.path.isfile(args.resume): - checkpoint = torch.load(args.resume, map_location=device) - if "epoch" in checkpoint: - # resuming a train checkpoint w/ epoch and optimizer state - start_epoch = checkpoint["epoch"] - sd = checkpoint["state_dict"] - if not args.distributed and next(iter(sd.items()))[0].startswith( - "module" - ): - sd = {k[len("module.") :]: v for k, v in sd.items()} - model.load_state_dict(sd) - if args.split_opt: - if optimizer is not None: - for k, o_ in optimizer.items(): - o_.load_state_dict(checkpoint[k + "_" + "optimizer"]) - if optimizer is not None: - optimizer.load_state_dict(checkpoint["optimizer"]) - if scaler is not None and "scaler" in checkpoint: - scaler.load_state_dict(checkpoint["scaler"]) - logging.info( - f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})" - ) - else: - # loading a bare (model only) checkpoint for fine-tune or evaluation - model.load_state_dict(checkpoint) - logging.info( - f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})" - ) - if args.freeze_text: - print("Freeze Text!!!!") - for k in text_freeze_parameters: - k.requires_grad = False - else: - logging.info("=> no checkpoint found at '{}'".format(args.resume)) - - cudnn.benchmark = True - cudnn.deterministic = False - - # determine if this worker should save logs and checkpoints. 
only do so if it is rank == 0 - args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args) - writer = None - if args.save_logs and args.tensorboard: - assert tensorboard is not None, "Please install tensorboard." - writer = tensorboard.SummaryWriter(args.tensorboard_path) - - if args.wandb and is_master(args): - assert wandb is not None, "Please install wandb." - logging.debug("Starting wandb.") - args.train_sz = data["train"].dataloader.num_samples - if args.val_data is not None: - args.val_sz = data["val"].dataloader.num_samples - # you will have to configure this for your project! - wandb.init( - project="clap", - notes=args.wandb_notes, - name=args.wandb_notes, - tags=[], - config=vars(args), - ) - if args.debug: - wandb.watch(model, log="all") - wandb.save(params_file) - logging.debug("Finished loading wandb.") - - if "train" not in data: - evaluate(model, data, start_epoch, args, writer) - return - elif start_epoch == 0 and "val" in data and not args.no_eval: - evaluate(model, data, 0, args, writer) - # print(f'rank {args.rank}, Start First Evaluation')# (yusong): for debug - if args.save_top_performance: - current_top_k_ckpt_metrics = { - i: 0 for i in range(args.save_top_performance) - } # initialize the top-k metric for ckpts to 0 - - # print(f'rank {args.rank}, Start Training') # (yusong): for debug - for epoch in range(start_epoch, args.epochs): - # freeze the text param after (include) args.freeze_text_after, this is -1 by default - if epoch == args.freeze_text_after: - print("Text pretrained parameters are freezed since this epoch.") - for k in text_freeze_parameters: - k.requires_grad = False - if is_master(args): - logging.info(f"Start epoch {epoch}") - - train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) - completed_epoch = epoch + 1 - - if ( - any(v in data for v in ("val", "imagenet-val", "imagenet-v2")) - and not args.no_eval - ): - metrics = evaluate(model, data, completed_epoch, args, writer) - if args.save_top_performance: - top_k_dataset = args.top_k_checkpoint_select_dataset - top_k_metric = args.top_k_checkpoint_select_metric - filtered_metrics = [ - v - for k, v in metrics.items() - if top_k_metric in k and top_k_dataset in k - ] # check all R@10 metrics (all dataset) and use it to update the ckpt - # Saving checkpoints. - if args.save_logs: - if args.split_opt: - opt_dict = { - k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items() - } - else: - opt_dict = {"optimizer": optimizer.state_dict()} - checkpoint_dict = { - "epoch": completed_epoch, - "name": args.name, - "state_dict": model.state_dict(), - } - checkpoint_dict.update(opt_dict) - if scaler is not None: - checkpoint_dict["scaler"] = scaler.state_dict() - - if completed_epoch == args.epochs or ( - args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 - ): - torch.save( - checkpoint_dict, - os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), - ) - if args.save_most_recent: - torch.save( - checkpoint_dict, - os.path.join(args.checkpoint_path, f"epoch_latest.pt"), - ) - if args.save_top_performance and not args.no_eval: - update_top_k_performance( - filtered_metrics, - current_top_k_ckpt_metrics, - args, - checkpoint_dict, - bignumbetter=True, - ) - - if args.wandb and is_master(args): - wandb.finish() - - -def copy_codebase(args): - from shutil import copytree, ignore_patterns - - new_code_path = os.path.join(args.logs, args.name, "code") - if os.path.exists(new_code_path): - print( - f"Error. 
Experiment already exists at {new_code_path}. Use --name to specify a new experiment." - ) - return -1 - print(f"Copying codebase to {new_code_path}") - current_code_path = os.path.realpath(__file__) - for _ in range(3): - current_code_path = os.path.dirname(current_code_path) - copytree( - current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb") - ) - print("Done copying code.") - return 1 - - -if __name__ == "__main__": - main() diff --git a/audioldm/clap/training/params.py b/audioldm/clap/training/params.py deleted file mode 100644 index 0cc1a0e2d982e900988cf5a4b24b2e59b093537b..0000000000000000000000000000000000000000 --- a/audioldm/clap/training/params.py +++ /dev/null @@ -1,563 +0,0 @@ -import argparse - - -def get_default_params(model_name): - # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) - model_name = model_name.lower() - if "vit" in model_name: - return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} - else: - return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--train-data", - type=str, - default=None, - help="Path to h5 filewith training data", - ) - parser.add_argument( - "--val-data", - type=str, - default=None, - help="Path to h5 file with validation data", - ) - parser.add_argument( - "--freeze-text", - default=False, - action="store_true", - help="if you need to freeze the text encoder, make this True", - ) - parser.add_argument( - "--freeze-text-after", - type=int, - default=-1, - help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it", - ) - parser.add_argument( - "--train-ipc", - type=str, - default=None, - help="Path to npy file of the number of instance per class in training data", - ) - parser.add_argument( - "--val-ipc", - type=str, - default=None, - help="Path to npy file of the number of instance per class in validation data", - ) - parser.add_argument( - "--train-num-samples", - type=int, - default=None, - help="Number of samples in dataset. Required for webdataset if not available in info file.", - ) - parser.add_argument( - "--val-num-samples", - type=int, - default=None, - help="Number of samples in dataset. Useful for webdataset if not available in info file.", - ) - parser.add_argument( - "--dataset-type", - choices=["webdataset", "csv", "auto", "toy"], - default="auto", - help="Which type of dataset to process.", - ) - parser.add_argument( - "--csv-separator", - type=str, - default="\t", - help="For csv-like datasets, which separator to use.", - ) - parser.add_argument( - "--csv-img-key", - type=str, - default="filepath", - help="For csv-like datasets, the name of the key for the image paths.", - ) - parser.add_argument( - "--csv-caption-key", - type=str, - default="title", - help="For csv-like datasets, the name of the key for the captions.", - ) - parser.add_argument( - "--imagenet-val", - type=str, - default=None, - help="Path to imagenet val set for conducting zero shot evaluation.", - ) - parser.add_argument( - "--imagenet-v2", - type=str, - default=None, - help="Path to imagenet v2 for conducting zero shot evaluation.", - ) - parser.add_argument( - "--datasetnames", - nargs="+", - default=None, - help="If loading webdataset, spedify the dataset names to load. 
Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects", - ) - parser.add_argument( - "--full-train-dataset", - nargs="+", - default=None, - help="Which dataset will be trained with all the subsets. (train+test)", - ) - parser.add_argument( - "--exclude-eval-dataset", - nargs="+", - default=None, - help="Which dataset will be excluded with evaluation", - ) - parser.add_argument( - "--datasetinfos", - nargs="+", - default=None, - help="If loading webdataset, spedify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval", - ) - parser.add_argument( - "--dataset-proportion", - type=float, - default=1.0, - help="How much proportion of dataset we want to train.", - ) - parser.add_argument( - "--remotedata", - default=False, - action="store_true", - help="if the dataset is remote, set this flag", - ) - parser.add_argument( - "--class-label-path", - type=str, - default=None, - help="The path of the class label pickle or csv.", - ) - parser.add_argument( - "--datasetpath", - type=str, - default="/mnt/audio_clip/webdataset_tar", - help="The path to the dataset", - ) - parser.add_argument( - "--logs", - type=str, - default="./logs/", - help="Where to store tensorboard logs. Use None to avoid storing logs.", - ) - parser.add_argument( - "--log-local", - action="store_true", - default=False, - help="log files on local master, otherwise global master only.", - ) - parser.add_argument( - "--name", - type=str, - default=None, - help="Optional identifier for the experiment when storing logs. Otherwise use current time.", - ) - parser.add_argument( - "--workers", type=int, default=1, help="Number of workers per GPU." - ) - parser.add_argument( - "--batch-size", type=int, default=64, help="Batch size per GPU." - ) - parser.add_argument( - "--epochs", type=int, default=32, help="Number of epochs to train for." - ) - parser.add_argument("--lr", type=float, default=None, help="Learning rate.") - parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.") - parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") - parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") - parser.add_argument("--momentum", type=float, default=None, help="SGD epsilon.") - parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") - - parser.add_argument( - "--split-opt", - action="store_true", - default=False, - help="Use this flag to skip the learning rate decay.", - ) - parser.add_argument( - "--lr-pretrained", type=float, default=None, help="Learning rate for text." - ) - parser.add_argument( - "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text." - ) - parser.add_argument( - "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text." - ) - parser.add_argument( - "--eps-pretrained", type=float, default=None, help="Adam epsilon for text." - ) - parser.add_argument( - "--wd-pretrained", type=float, default=0.2, help="Weight decay for text." - ) - parser.add_argument( - "--momentum-pretrained", type=float, default=0.9, help="Momentum for text." - ) - parser.add_argument( - "--lr-new", type=float, default=None, help="Learning rate for audio." - ) - parser.add_argument( - "--beta1-new", type=float, default=None, help="Adam beta 1 for audio." - ) - parser.add_argument( - "--beta2-new", type=float, default=None, help="Adam beta 2 for audio." - ) - parser.add_argument( - "--eps-new", type=float, default=None, help="Adam epsilon for audio." 
- ) - parser.add_argument( - "--wd-new", type=float, default=0.2, help="Weight decay for audio." - ) - parser.add_argument( - "--momentum-new", type=float, default=0.9, help="Momentum for audio." - ) - parser.add_argument( - "--warmup", type=int, default=10000, help="Number of steps to warmup for." - ) - parser.add_argument( - "--use-bn-sync", - default=False, - action="store_true", - help="Whether to use batch norm sync.", - ) - parser.add_argument( - "--skip-scheduler", - action="store_true", - default=False, - help="Use this flag to skip the learning rate decay.", - ) - parser.add_argument( - "--save-frequency", type=int, default=1, help="How often to save checkpoints." - ) - parser.add_argument( - "--save-top-performance", - type=int, - default=0, - help="Save the top x performance weights if the value >0", - ) - parser.add_argument( - "--save-most-recent", - action="store_true", - default=False, - help="Always save the most recent model trained to epoch_latest.pt.", - ) - parser.add_argument( - "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot." - ) - parser.add_argument( - "--val-frequency", - type=int, - default=1, - help="How often to run evaluation with val data.", - ) - parser.add_argument( - "--resume", - default=None, - type=str, - help="path to latest checkpoint (default: none)", - ) - parser.add_argument( - "--precision", - choices=["amp", "fp16", "fp32"], - default="amp", - help="Floating point precision.", - ) - parser.add_argument( - "--amodel", - type=str, - default="RN50", - help="Name of the audio backbone to use.", - ) - parser.add_argument( - "--tmodel", - type=str, - default="transformer", - help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]", - ) - parser.add_argument( - "--pretrained-audio", - default="", - type=str, - help="Use a pretrained audio model weights for the audio encoder of CLAP", - ) - parser.add_argument( - "--pretrained-text", - default="", - type=str, - help="Use a pretrained text model weights for the text encoder of CLAP", - ) - parser.add_argument( - "--pretrained", - default="", - type=str, - help="Use a pretrained CLIP model weights with the specified tag or file path.", - ) - parser.add_argument( - "--pretrained-image", - default=False, - action="store_true", - help="Load imagenet pretrained weights for image tower backbone if available.", - ) - parser.add_argument( - "--lock-image", - default=False, - action="store_true", - help="Lock full image tower by disabling gradients.", - ) - parser.add_argument( - "--lock-image-unlocked-groups", - type=int, - default=0, - help="Leave last n image tower layer groups unlocked.", - ) - parser.add_argument( - "--lock-image-freeze-bn-stats", - default=False, - action="store_true", - help="Freeze BatchNorm running stats in image tower for any locked layers.", - ) - parser.add_argument( - "--local-loss", - default=False, - action="store_true", - help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)", - ) - parser.add_argument( - "--gather-with-grad", - default=False, - action="store_true", - help="enable full distributed gradient for feature gather", - ) - parser.add_argument( - "--force-quick-gelu", - default=False, - action="store_true", - help="Force use of QuickGELU activation for non-OpenAI transformer models.", - ) - parser.add_argument( - "--torchscript", - default=False, - action="store_true", - help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'", - ) - 
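Several of the optimizer arguments here (`--lr*`, `--warmup`) feed the warmup-plus-cosine schedule that `main()` builds and that `audioldm/clap/training/scheduler.py` (deleted further down in this diff) implements. A self-contained sketch of that schedule, re-stating the formula rather than reproducing the deleted module:

```python
import math

def cosine_lr_with_warmup(step, base_lr, warmup_steps, total_steps):
    # Linear warmup for the first `warmup_steps`, then cosine decay towards 0 at `total_steps`,
    # mirroring _warmup_lr and cosine_lr in scheduler.py.
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress)) * base_lr

base_lr, warmup, total = 1e-4, 10_000, 100_000
print(cosine_lr_with_warmup(0, base_lr, warmup, total))       # 1e-08, start of warmup
print(cosine_lr_with_warmup(9_999, base_lr, warmup, total))   # 0.0001, end of warmup
print(cosine_lr_with_warmup(55_000, base_lr, warmup, total))  # ≈ 5e-05, halfway through the decay
```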
parser.add_argument( - "--trace", - default=False, - action="store_true", - help="torch.jit.trace the model for inference / eval only", - ) - # arguments for distributed training - parser.add_argument( - "--dist-url", - default="env://", - type=str, - help="url used to set up distributed training", - ) - parser.add_argument( - "--dist-backend", default="nccl", type=str, help="distributed backend" - ) - parser.add_argument( - "--report-to", - default="", - type=str, - help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']", - ) - parser.add_argument( - "--wandb-notes", default="", type=str, help="Notes if logging with wandb" - ) - parser.add_argument( - "--C", type=float, default=3.16, help="inverse regularizer for logistic reg." - ) - parser.add_argument( - "--debug", - default=False, - action="store_true", - help="If true, more information is logged.", - ) - parser.add_argument( - "--copy-codebase", - default=False, - action="store_true", - help="If true, we copy the entire base on the log diretory, and execute from there.", - ) - parser.add_argument( - "--horovod", - default=False, - action="store_true", - help="Use horovod for distributed training.", - ) - parser.add_argument( - "--ddp-static-graph", - default=False, - action="store_true", - help="Enable static graph optimization for DDP in PyTorch >= 1.11.", - ) - parser.add_argument( - "--no-set-device-rank", - default=False, - action="store_true", - help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", - ) - parser.add_argument("--seed", type=int, default=4242, help="Default random seed.") - - parser.add_argument( - "--top-k-checkpoint-select-dataset", - type=str, - default="all", - help="The dataset of selecting top-k checkpoint.", - ) - - # @R10, @R@5, @R1, mAP@10 - parser.add_argument( - "--top-k-checkpoint-select-metric", - type=str, - default="_R@10", - help="The metric for selecting top-k checkpoint.", - ) - parser.add_argument( - "--openai-model-cache-dir", - type=str, - default="~/.cache/clip", - help="Directory to download OpenAI models.", - ) - parser.add_argument( - "--optimizer", - type=str, - default="adamw", - help="can be AdamW or SGD", - ) - parser.add_argument( - "--parallel-eval", - default=False, - action="store_true", - help="Eval in parallel (multi-GPU, multi-node).", - ) - - parser.add_argument( - "--no-eval", - default=False, - action="store_true", - help="Training without evaluation.", - ) - - parser.add_argument( - "--lp-mlp", - default=False, - action="store_true", - help="Linear Probe using MLP layer or not.", - ) - - parser.add_argument( - "--lp-freeze", - default=False, - action="store_true", - help="Linear Probe using Freeze CLAP or not", - ) - - parser.add_argument( - "--lp-act", - default="None", - type=str, - help="Options are ['relu','elu','prelu','softmax','sigmoid']", - ) - - parser.add_argument( - "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe." - ) - - parser.add_argument( - "--lp-metrics", - type=str, - default="map,mauc,acc", - help="Metrics of Linear Probe.", - ) - - parser.add_argument( - "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe" - ) - parser.add_argument( - "--kappa", - type=float, - default=0, - help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss", - ) - - parser.add_argument( - "--data-filling", - type=str, - default="pad", - help="type of data filling when the audio length is shorter than the max length." 
- "Can be one of the following: repeat, repeatpad, pad", - ) - parser.add_argument( - "--data-truncating", - type=str, - default="rand_trunc", - help="type of data truncation when the audio length is longer than the max length." - "Can be one of the following: rand_trunc, fusion", - ) - - parser.add_argument( - "--clap-mlploss", - default=False, - action="store_true", - help="Using MLP loss for CLAP model or not", - ) - - parser.add_argument( - "--wandb-id", - type=str, - default=None, - help="the id of wandb experiment to restore.", - ) - - parser.add_argument( - "--sleep", type=float, default=0, help="sleep n seconds before start training" - ) - - # variable length processing - parser.add_argument( - "--enable-fusion", - default=False, - action="store_true", - help="Enable feature funsion for variable-length data", - ) - - parser.add_argument( - "--fusion-type", - type=str, - default="None", - help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']", - ) - - parser.add_argument( - "--mixup", - default=False, - action="store_true", - help="Enable mixup in finetuning training.", - ) - parser.add_argument( - "--text-augment-selection", - type=str, - default=None, - help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']", - ) - - args = parser.parse_args() - - # If some params are not passed, we use the default values based on model name. - default_params = get_default_params(args.amodel) - for name, val in default_params.items(): - if getattr(args, name) is None: - setattr(args, name, val) - - return args diff --git a/audioldm/clap/training/scheduler.py b/audioldm/clap/training/scheduler.py deleted file mode 100644 index 7151ffbab25a113673b7627027b443b27f22cb0f..0000000000000000000000000000000000000000 --- a/audioldm/clap/training/scheduler.py +++ /dev/null @@ -1,24 +0,0 @@ -import numpy as np - - -def assign_learning_rate(optimizer, new_lr): - for param_group in optimizer.param_groups: - param_group["lr"] = new_lr - - -def _warmup_lr(base_lr, warmup_length, step): - return base_lr * (step + 1) / warmup_length - - -def cosine_lr(optimizer, base_lr, warmup_length, steps): - def _lr_adjuster(step): - if step < warmup_length: - lr = _warmup_lr(base_lr, warmup_length, step) - else: - e = step - warmup_length - es = steps - warmup_length - lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr - assign_learning_rate(optimizer, lr) - return lr - - return _lr_adjuster diff --git a/audioldm/clap/training/train.py b/audioldm/clap/training/train.py deleted file mode 100644 index f5759c4679d2ee9c0748444adf66b8453cf09728..0000000000000000000000000000000000000000 --- a/audioldm/clap/training/train.py +++ /dev/null @@ -1,838 +0,0 @@ -import json -import logging -import math -import os -import time -from contextlib import suppress - -import numpy as np -import torch -import torch.nn.functional as F - -try: - import wandb -except ImportError: - wandb = None - -from open_clip import ClipLoss, gather_features -from .distributed import is_master -from .zero_shot import zero_shot_eval - - -class AverageMeter(object): - """Computes and stores the average and current value""" - - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - -def unwrap_model(model): - if hasattr(model, "module"): - return model.module - else: - return model - - -def 
train_one_epoch( - model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None -): - device = torch.device(args.device) - autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress - model.train() - loss = ClipLoss( - local_loss=args.local_loss, - gather_with_grad=args.gather_with_grad, - cache_labels=True, - rank=args.rank, - world_size=args.world_size, - use_horovod=args.horovod, - mlp_loss=args.clap_mlploss, - weight_loss_kappa=args.kappa, - ) - - dataloader, sampler = data["train"].dataloader, data["train"].sampler - if args.distributed and sampler is not None: - sampler.set_epoch(epoch) - num_batches_per_epoch = dataloader.num_batches - sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) - - # for toy dataset - if args.dataset_type == "toy": - dataloader.dataset.generate_queue() - - loss_m = AverageMeter() - batch_time_m = AverageMeter() - data_time_m = AverageMeter() - end = time.time() - - for i, batch in enumerate(dataloader): - # logging.info(f"batch {i} of {num_batches_per_epoch}") - step = num_batches_per_epoch * epoch + i - if isinstance(scheduler, dict): - for s in scheduler.values(): - s(step) - else: - scheduler(step) - audios = batch # contains mel_spec, wavform, and longer list - texts = batch["text"] - # audios = audios.to(device=device, non_blocking=True) - # texts = texts.to(device=device, non_blocking=True) - - data_time_m.update(time.time() - end) - if isinstance(optimizer, dict): - for o_ in optimizer.values(): - o_.zero_grad() - else: - optimizer.zero_grad() - - with autocast(): - ( - audio_features, - text_features, - audio_features_mlp, - text_features_mlp, - logit_scale_a, - logit_scale_t, - ) = model(audios, texts, device) - - if args.clap_mlploss: - total_loss = loss( - audio_features=audio_features, - text_features=text_features, - logit_scale_a=logit_scale_a, - logit_scale_t=logit_scale_t, - audio_features_mlp=audio_features_mlp, - text_features_mlp=text_features_mlp, - ) - else: - total_loss = loss( - audio_features=audio_features, - text_features=text_features, - logit_scale_a=logit_scale_a, - ) - if isinstance(optimizer, dict): - if scaler is not None: - scaler.scale(total_loss).backward() - for o_ in optimizer.values(): - if args.horovod: - o_.synchronize() - scaler.unscale_(o_) - with o_.skip_synchronize(): - scaler.step(o_) - else: - scaler.step(o_) - scaler.update() - else: - total_loss.backward() - for o_ in optimizer.values(): - o_.step() - else: - if scaler is not None: - scaler.scale(total_loss).backward() - if args.horovod: - optimizer.synchronize() - scaler.unscale_(optimizer) - with optimizer.skip_synchronize(): - scaler.step(optimizer) - else: - scaler.step(optimizer) - scaler.update() - else: - total_loss.backward() - optimizer.step() - - # Note: we clamp to 4.6052 = ln(100), as in the original paper. 
- with torch.no_grad(): - unwrap_model(model).logit_scale_a.clamp_(0, math.log(100)) - if args.clap_mlploss: - unwrap_model(model).logit_scale_t.clamp_(0, math.log(100)) - - batch_time_m.update(time.time() - end) - end = time.time() - batch_count = i + 1 - if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): - if isinstance(audios, dict): - batch_size = len(audios["waveform"]) - else: - batch_size = len(audios) - num_samples = batch_count * batch_size * args.world_size - samples_per_epoch = dataloader.num_samples - percent_complete = 100.0 * batch_count / num_batches_per_epoch - - # NOTE loss is coarsely sampled, just master node and per log update - loss_m.update(total_loss.item(), batch_size) - logit_scale_scalar_a = logit_scale_a.item() - logit_scale_scalar_t = logit_scale_t.item() - if isinstance(optimizer, dict): - if args.clap_mlploss: - logging.info( - f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " - f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " - f"Data (t): {data_time_m.avg:.3f} " - f"Batch (t): {batch_time_m.avg:.3f} " - f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} " - f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" - f"Logit Scale Text: {logit_scale_scalar_t:.3f}" - ) - log_data = { - "loss": loss_m.val, - "data_time": data_time_m.val, - "batch_time": batch_time_m.val, - "scale_audio": logit_scale_scalar_a, - "scale_text": logit_scale_scalar_t, - "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], - } - else: - logging.info( - f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " - f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " - f"Data (t): {data_time_m.avg:.3f} " - f"Batch (t): {batch_time_m.avg:.3f} " - f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} " - f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" - ) - log_data = { - "loss": loss_m.val, - "data_time": data_time_m.val, - "batch_time": batch_time_m.val, - "scale_audio": logit_scale_scalar_a, - "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], - } - - else: - if args.clap_mlploss: - logging.info( - f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " - f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " - f"Data (t): {data_time_m.avg:.3f} " - f"Batch (t): {batch_time_m.avg:.3f} " - f"LR: {optimizer.param_groups[0]['lr']:5f} " - f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" - f"Logit Scale Text: {logit_scale_scalar_t:.3f}" - ) - - # Save train loss / etc. Using non avg meter values as loggers have their own smoothing - log_data = { - "loss": loss_m.val, - "data_time": data_time_m.val, - "batch_time": batch_time_m.val, - "scale_audio": logit_scale_scalar_a, - "scale_text": logit_scale_scalar_t, - "lr": optimizer.param_groups[0]["lr"], - } - else: - logging.info( - f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " - f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " - f"Data (t): {data_time_m.avg:.3f} " - f"Batch (t): {batch_time_m.avg:.3f} " - f"LR: {optimizer.param_groups[0]['lr']:5f} " - f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" - ) - - # Save train loss / etc. 
Using non avg meter values as loggers have their own smoothing - log_data = { - "loss": loss_m.val, - "data_time": data_time_m.val, - "batch_time": batch_time_m.val, - "scale_audio": logit_scale_scalar_a, - "lr": optimizer.param_groups[0]["lr"], - } - for name, val in log_data.items(): - name = "train/" + name - if tb_writer is not None: - tb_writer.add_scalar(name, val, step) - if args.wandb: - assert wandb is not None, "Please install wandb." - wandb.log({name: val, "step": step}) - - # resetting batch / data time meters per log window - batch_time_m.reset() - data_time_m.reset() - # end for - - -def evaluate(model, data, epoch, args, tb_writer=None): - metrics = {} - if not args.parallel_eval: - if not is_master(args): - return metrics - device = torch.device(args.device) - model.eval() - - # CHANGE - # zero_shot_metrics = zero_shot_eval(model, data, epoch, args) - # metrics.update(zero_shot_metrics) - if is_master(args): - print("Evaluating...") - autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress - if args.val_dataset_names == ["Clotho", "audiocaps"]: - # if only clotho and audiocaps are used, then we will use a different evaluation function. - # This is because in the Clotho and audiocaps valid and test set, there are 5 text for 1 audio. - if args.parallel_eval: - # (yusong): just a hack here. Don't use parallel eval when evaluating only clotho and audiocaps. - raise NotImplementedError( - "Parallel evaluation not supported for eval only Clotho and audiocaps." - ) - val_metrics_per_dataset = evaluate_clotho_audiocaps( - model, data, epoch, args, autocast, device, tb_writer - ) - for m in val_metrics_per_dataset.values(): - metrics.update(m) - if "epoch" not in metrics.keys(): - metrics.update({"epoch": epoch}) - metrics = select_top_metric_clotho_audiocaps( - metrics, val_metrics_per_dataset, args - ) - elif "val" in data and ( - args.val_frequency - and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) - ): - dataloader = data["val"].dataloader - num_samples = 0 - samples_per_val = dataloader.num_samples - - # FIXME this does not scale past small eval datasets - # all_audio_features @ all_text_features will blow up memory and compute very quickly - eval_info = {} - if args.clap_mlploss: - eval_info["all"] = { - "cumulative_loss": 0.0, - "num_samples": 0, - "all_audio_features": [], - "all_text_features": [], - "all_audio_features_mlp": [], - "all_text_features_mlp": [], - } # cumulative_loss = 0.0 - else: - eval_info["all"] = { - "cumulative_loss": 0.0, - "num_samples": 0, - "all_audio_features": [], - "all_text_features": [], - } # cumu - # all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp = [], [], [], [] - with torch.no_grad(): - for i, batch in enumerate(dataloader): - audios = batch # contains mel_spec, wavform, and longer list - texts = batch["text"] - # audios = audios.to(device=device, non_blocking=True) - - all_names = list( - set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]) - ) - for name in all_names: - if name not in eval_info.keys(): - if args.clap_mlploss: - eval_info[name] = { - "cumulative_loss": 0.0, - "num_samples": 0, - "all_audio_features": [], - "all_text_features": [], - "all_audio_features_mlp": [], - "all_text_features_mlp": [], - } - else: - eval_info[name] = { - "cumulative_loss": 0.0, - "num_samples": 0, - "all_audio_features": [], - "all_text_features": [], - } - with autocast(): - ( - audio_features, - text_features, - audio_features_mlp, - text_features_mlp, - 
logit_scale_a, - logit_scale_t, - ) = model(audios, texts, device) - - if args.parallel_eval: - # multi-GPU eval - if args.clap_mlploss: - ( - audio_features, - text_features, - audio_features_mlp, - text_features_mlp, - ) = gather_features( - audio_features=audio_features, - text_features=text_features, - audio_features_mlp=audio_features_mlp, - text_features_mlp=text_features_mlp, - local_loss=False, - gather_with_grad=False, - rank=args.rank, - world_size=args.world_size, - use_horovod=args.horovod, - mlp_loss=args.clap_mlploss, - ) - else: - (audio_features, text_features,) = gather_features( - audio_features=audio_features, - text_features=text_features, - local_loss=False, - gather_with_grad=False, - rank=args.rank, - world_size=args.world_size, - use_horovod=args.horovod, - mlp_loss=args.clap_mlploss, - ) - - if is_master(args): - num_samples += audio_features.shape[0] - for n in [*all_names, "all"]: - if n == "all": - eval_info[n]["all_audio_features"].append( - audio_features.cpu() - ) - eval_info[n]["all_text_features"].append( - text_features.cpu() - ) - if args.clap_mlploss: - eval_info[n]["all_audio_features_mlp"].append( - audio_features_mlp.cpu() - ) - eval_info[n]["all_text_features_mlp"].append( - text_features_mlp.cpu() - ) - else: - idx = np.where( - np.array( - [ - "-".join(b.split("/")[-3:-1]) - for b in batch["__url__"] - ] - ) - == n - )[0] - eval_info[n]["all_audio_features"].append( - audio_features.cpu().index_select( - 0, torch.tensor(idx).long() - ) - ) - eval_info[n]["all_text_features"].append( - text_features.cpu().index_select( - 0, torch.tensor(idx).long() - ) - ) - if args.clap_mlploss: - eval_info[n]["all_audio_features_mlp"].append( - audio_features_mlp.cpu().index_select( - 0, torch.tensor(idx).long() - ) - ) - eval_info[n]["all_text_features_mlp"].append( - text_features_mlp.cpu().index_select( - 0, torch.tensor(idx).long() - ) - ) - # print(f'eval step {i}') # (yusong): for debug - - # cumulative_loss += total_loss * batch_size - # num_samples += batch_size - if is_master(args) and (i % 100) == 0: # and i != 0: - logging.info( - f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]" - ) - if is_master(args): - val_metrics_per_dataset = {} - for n in eval_info.keys(): - if args.clap_mlploss: - metrics_single_dataset = get_metrics( - audio_features=torch.cat( - eval_info[n]["all_audio_features"] - ), - text_features=torch.cat(eval_info[n]["all_text_features"]), - logit_scale_a=logit_scale_a.cpu(), - audio_features_mlp=torch.cat( - eval_info[n]["all_audio_features_mlp"] - ), - text_features_mlp=torch.cat( - eval_info[n]["all_text_features_mlp"] - ), - logit_scale_t=logit_scale_t.cpu(), - mlp_loss=args.clap_mlploss, - ) - else: - metrics_single_dataset = get_metrics( - audio_features=torch.cat( - eval_info[n]["all_audio_features"] - ), - text_features=torch.cat(eval_info[n]["all_text_features"]), - logit_scale_a=logit_scale_a.cpu(), - mlp_loss=args.clap_mlploss, - ) - val_metrics_per_dataset[n] = { - n + "/" + k: v for k, v in metrics_single_dataset.items() - } - metrics.update(val_metrics_per_dataset[n]) - if "epoch" not in metrics.keys(): - metrics.update({"epoch": epoch}) - if is_master(args): - if not metrics: - return metrics - - logging.info( - f"Eval Epoch: {epoch} " - + "\n".join( - [ - "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in m.items()]) - for m in val_metrics_per_dataset.values() - ] - ) - ) - - if args.save_logs: - for name, val in metrics.items(): - if tb_writer is not None: - tb_writer.add_scalar(f"val/{name}", val, epoch) - - 
with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: - f.write(json.dumps(metrics)) - f.write("\n") - - if args.wandb: - assert wandb is not None, "Please install wandb." - for name, val in metrics.items(): - wandb.log({f"val/{name}": val, "epoch": epoch}) - - return metrics - else: - return metrics - - -def get_metrics( - audio_features, - text_features, - logit_scale_a, - audio_features_mlp=None, - text_features_mlp=None, - logit_scale_t=None, - mlp_loss=False, -): - metrics = {} - if mlp_loss: - # Set up audio to text & text to audio similary matrice - a_logits_per_audio = ( - (logit_scale_a * audio_features @ text_features_mlp.t()).detach().cpu() - ) - a_logits_per_text = a_logits_per_audio.t().detach().cpu() - t_logits_per_audio = ( - (logit_scale_t * audio_features_mlp @ text_features.t()).detach().cpu() - ) - t_logits_per_text = t_logits_per_audio.t().detach().cpu() - - labels = torch.arange(audio_features.shape[0]).long() - # Change the loss from two terms into four terms with 2x2 combined CE loss - total_loss = ( - F.cross_entropy(a_logits_per_audio, labels) - + F.cross_entropy(a_logits_per_text, labels) - + F.cross_entropy(t_logits_per_audio, labels) - + F.cross_entropy(t_logits_per_text, labels) - ) / 4 - - metrics[f"cumulative_loss"] = total_loss.item() - metrics[f"num_samples"] = audio_features.shape[0] - - logits = { - "audio_to_text": (a_logits_per_audio + t_logits_per_audio) / 2, - "text_to_audio": (a_logits_per_text + t_logits_per_text) / 2, - } - ground_truth = torch.arange(len(text_features)).view(-1, 1) - - else: - # print("text_features", text_features) - # print("text_features.shape", text_features.shape) - logits_per_audio = ( - (logit_scale_a * audio_features @ text_features.t()).detach().cpu() - ) - logits_per_text = logits_per_audio.t().detach().cpu() - - labels = torch.arange(audio_features.shape[0]).long() - # Change the loss from two terms into four terms with 2x2 combined CE loss - total_loss = ( - F.cross_entropy(logits_per_audio, labels) - + F.cross_entropy(logits_per_text, labels) - ) / 2 - - metrics[f"cumulative_loss"] = total_loss.item() - metrics[f"num_samples"] = audio_features.shape[0] - - logits = {"audio_to_text": logits_per_audio, "text_to_audio": logits_per_text} - - ground_truth = torch.arange(len(text_features)).view(-1, 1) - - for name, logit in logits.items(): - ranking = torch.argsort(logit, descending=True) - preds = torch.where(ranking == ground_truth)[ - 1 - ] # (yusong) this line is slow because it uses single thread - preds = preds.detach().cpu().numpy() - metrics[f"{name}_mean_rank"] = preds.mean() + 1 - metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 - for k in [1, 5, 10]: - metrics[f"{name}_R@{k}"] = np.mean(preds < k) - # map@10 - metrics[f"{name}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0)) - - return metrics - - -def evaluate_clotho_audiocaps( - model, data, epoch, args, autocast, device, tb_writer=None -): - """ - Adapted from https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py. - 1. for text-to-audio retrieval, do 5 times and average the results - 2. for R@1, R@5, R@10 in audio-to-text retrieval, take the best rank among 5 text - 3. for map@10 in audio-to-text retrieval: - 3.1: sort the rank of 5 text - 3.2: exclude the rank >=10 (0-index) - 3.3: compute the map regarding the remaining ranks: np.mean(np.arange(1, len(ranks)+1) / ranks). - (3.3) That is, take the top ranks of 5 text that is < 10, and assign the descending number as ground truth. 
- (3.3) E.g.: the ground truth of first rank of the 5 text should be 1, the second rank should be 2, etc. - """ - # TODO: (yusong) only support single GPU evaluation and only support non-mlp case for now. - dataloader = data["val"].dataloader - with torch.no_grad(): - eval_info = {} - for i, batch in enumerate(dataloader): - audios = batch # contains mel_spec, wavform, and longer list - - # each item in the list has 5 texts - if args.tmodel == "transformer": - from open_clip import tokenize - - texts = [tokenize(t) for t in batch["full_text"]] - texts = torch.cat(texts) - else: - from .data import tokenizer - - texts = [ - tokenizer(t) for t in batch["full_text"] - ] # 5 texts for each audio - texts = { - k: torch.cat([t[k] for t in texts]) for k in texts[0].keys() - } # 5 x batch - - # audios = audios.to(device=device, non_blocking=True) - - all_names = list( - set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]) - ) - for name in all_names: - if name not in eval_info.keys(): - # we will not use mlp outputs even if args.clap_mlploss=True - eval_info[name] = { - "cumulative_loss": 0.0, - "num_samples": 0, - "all_audio_features": [], - "all_text_features": [], - } - with autocast(): - audio_features = model(audios, None, device) - text_features = model(None, texts, device) - audio_features = F.normalize(audio_features, dim=-1) - text_features = F.normalize(text_features, dim=-1) - - all_names = list( - set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]) - ) - for n in all_names: - idx = np.where( - np.array( - ["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]] - ) - == n - )[0] - eval_info[n]["all_audio_features"].append( - audio_features.cpu().index_select(0, torch.tensor(idx).long()) - ) - # (yusong) please double-check. This is for selecting 5 text features at once. - # because idx is a list of indices in size of num_samples, - # and text_features is a tensor of size (5*num_samples, dim) - # so we need to select 5 consecutive indices at once for a single index in idx. - eval_info[n]["all_text_features"].append( - text_features.cpu() - .reshape([-1, 5, text_features.shape[1]]) - .index_select(0, torch.tensor(idx).long()) - .reshape([-1, text_features.shape[1]]) - ) - - val_metrics_all = {} - - for n in eval_info.keys(): - logit_scale_a, logit_scale_t = model(None, None, device) - logit_scale_a = logit_scale_a.cpu() - - audio_features = torch.cat(eval_info[n]["all_audio_features"], dim=0) - text_features = torch.cat(eval_info[n]["all_text_features"], dim=0) - - logits_per_audio = ( - (logit_scale_a * audio_features @ text_features.t()).detach().cpu() - ) - logits_per_text = logits_per_audio.t().detach().cpu() - - # logits_per_audio shape: [num_samples, num_samples*5] - # logits_per_text shape: [num_samples*5, num_samples] - - logging.info( - f"dataset {n}, logits_per_audio shape: {logits_per_audio.shape}, " - f"logits_per_text shape: {logits_per_text.shape}" - ) - - metrics = {} - num_samples = audio_features.shape[0] - metrics[f"num_samples"] = num_samples - - # (yusong) the following code is very important, please double-check: - # logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d] - # logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :] - # Those two are retrieving one of the 5 text for each audio. 
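The comment above flags the reshape indexing as easy to get wrong, so here is a small standalone check with toy tensors (illustrative only, not part of the original file) showing that both views compare caption `d` of every audio against every audio:

```python
import torch

num_samples, dim, d = 3, 4, 2                      # 3 audios, 5 captions each, inspect caption index d
audio = torch.randn(num_samples, dim)
text = torch.randn(num_samples * 5, dim)           # caption rows are audio-major: a0c0..a0c4, a1c0..a1c4, ...
logits_per_audio = audio @ text.t()                # [3, 15]
logits_per_text = logits_per_audio.t()             # [15, 3]

per_audio_view = logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d]   # [3, 3]
per_text_view = logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :]     # [3, 3]

# Entry [i, a] of the first view and [a, i] of the second are both audio_i · text_{a, d},
# so the two views are transposes of each other.
assert torch.equal(per_audio_view, per_text_view.t())
```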
- labels = torch.arange(audio_features.shape[0]).long() - audio_to_text_loss = [ - F.cross_entropy( - logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d], - labels, - ) - for d in range(5) - ] - text_to_audio_loss = [ - F.cross_entropy( - logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :], - labels, - ) - for d in range(5) - ] - total_loss = (np.mean(audio_to_text_loss) + np.mean(text_to_audio_loss)) / 2 - - metrics[f"cumulative_loss"] = total_loss.item() - - # text to audio: do 5 times - pred_text = [] - for d in range(5): - logit = logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :] - ground_truth = torch.arange(len(logit)).view(-1, 1) - ranking = torch.argsort( - logit, descending=True - ) # [num_samples, num_samples] - preds = torch.where(ranking == ground_truth)[1] - pred_text.append(preds.detach().cpu().numpy()) - pred_text_concat = np.concatenate(pred_text, axis=0) # [5*num_samples] - metrics[f"text_to_audio_mean_rank"] = pred_text_concat.mean() + 1 - metrics[f"text_to_audio_median_rank"] = ( - np.floor(np.median(pred_text_concat)) + 1 - ) - for k in [1, 5, 10]: - metrics[f"text_to_audio_R@{k}"] = np.mean(pred_text_concat < k) - # map@10 - metrics[f"text_to_audio_mAP@10"] = np.mean( - np.where(pred_text_concat < 10, 1 / (pred_text_concat + 1), 0.0) - ) - - # audio to text: take the best result - # for audio to text map 10, sort and assign descending ground truth. - # see https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py#L103 - # map@10 - map_all = [] - pred_audio_all = [] - for d in range(num_samples): - # logits_per_audio: [num_samples, num_samples*5] - logit_single = logits_per_audio[d, :] # [5*num_samples] - # Ground-truth index: [d*5, d*5+1, d*5+2, d*5+3, d*5+4] - ranking = torch.argsort( - logit_single, descending=True - ) # [5*num_samples] - # ranking: the index of first match, second match, ... - ground_truth = torch.arange(d * 5, d * 5 + 5)[None] - all_pred = torch.where( - torch.stack([ranking] * 5) == ground_truth.view(-1, 1) - )[1] - min_pred = torch.min(all_pred) - pred_audio_all.append(min_pred.detach().cpu().numpy()) - all_pred_filter = all_pred[all_pred < 10].detach().cpu().numpy() - # /5 because we have 5 text, so it means for the text rank >=10 we count as 0. - map_single = ( - np.sum( - (np.arange(1, len(all_pred_filter) + 1) / (all_pred_filter + 1)) - ) - / 5 - ) - map_all.append(map_single) - metrics[f"audio_to_text_mAP@10"] = np.mean(map_all) - for k in [1, 5, 10]: - metrics[f"audio_to_text_R@{k}"] = np.mean(np.array(pred_audio_all) < k) - - val_metrics_all[n] = {n + "/" + k: v for k, v in metrics.items()} - return val_metrics_all - - -def calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset): - """ - Calculate performance for Clotho+AudioCaps for model selection. 
- """ - selection_performance_all = [] - for n in val_metrics_per_dataset.keys(): - selection_performance = ( - val_metrics_per_dataset[n][f"{n}/audio_to_text_mAP@10"] - + val_metrics_per_dataset[n][f"{n}/text_to_audio_mAP@10"] - ) / 2 - selection_performance_all.append(selection_performance) - return np.mean(selection_performance_all) - - -def select_top_metric_clotho_audiocaps(metrics, val_metrics_per_dataset, args): - # val_metrics_per_dataset: dict, key: dataset name, value: dict, key: metric name, value: metric value - # metrics: dict, key: metric name, value: metric value - # Hack: use args to save the top performance - if not hasattr(args, "top_selection_performance"): - selection_performance = calculate_selection_performance_clotho_audiocaps( - val_metrics_per_dataset - ) - # TODO: write the if and else together - metric_update = {} - for n in val_metrics_per_dataset.keys(): - for k in val_metrics_per_dataset[n].keys(): - metric_update[ - k.split("/")[0] + "-top" + "/" + k.split("/")[1] - ] = val_metrics_per_dataset[n][k] - metric_update["top_selection_performance"] = selection_performance - metric_update["top-selection-epoch"] = metrics["epoch"] - metrics.update(metric_update) - args.top_metric = metric_update - args.top_selection_performance = selection_performance - else: - selection_performance_new = calculate_selection_performance_clotho_audiocaps( - val_metrics_per_dataset - ) - selection_performance_old = args.top_selection_performance - if selection_performance_new > selection_performance_old: - metric_update = {} - for n in val_metrics_per_dataset.keys(): - for k in val_metrics_per_dataset[n].keys(): - metric_update[ - k.split("/")[0] + "-top" + "/" + k.split("/")[1] - ] = val_metrics_per_dataset[n][k] - metric_update["top_selection_performance"] = selection_performance_new - metric_update["top-selection-epoch"] = metrics["epoch"] - metrics.update(metric_update) - args.top_metric = metric_update - args.top_selection_performance = selection_performance_new - else: - metrics.update(args.top_metric) - return metrics diff --git a/audioldm/clap/training/zero_shot.py b/audioldm/clap/training/zero_shot.py deleted file mode 100644 index 28b8fccc1af17fc69002857a7f529ac041c374f2..0000000000000000000000000000000000000000 --- a/audioldm/clap/training/zero_shot.py +++ /dev/null @@ -1,95 +0,0 @@ -# NOTE: This script is currently not supported for CLAP. 
-import logging -from contextlib import suppress - -import torch -import torch.nn.functional as F -from tqdm import tqdm - -from open_clip import tokenize -from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template - - -def zero_shot_classifier(model, classnames, templates, args): - with torch.no_grad(): - zeroshot_weights = [] - for classname in tqdm(classnames): - texts = [template(classname) for template in templates] # format with class - texts = tokenize(texts).to(args.device) # tokenize - if args.distributed and not args.horovod: - class_embeddings = model.module.encode_text(texts) - else: - class_embeddings = model.encode_text(texts) - class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) - class_embedding /= class_embedding.norm() - zeroshot_weights.append(class_embedding) - zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) - return zeroshot_weights - - -def accuracy(output, target, topk=(1,)): - pred = output.topk(max(topk), 1, True, True)[1].t() - correct = pred.eq(target.view(1, -1).expand_as(pred)) - return [ - float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) - for k in topk - ] - - -def run(model, classifier, dataloader, args): - autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress - with torch.no_grad(): - top1, top5, n = 0.0, 0.0, 0.0 - for images, target in tqdm(dataloader, unit_scale=args.batch_size): - images = images.to(args.device) - target = target.to(args.device) - - with autocast(): - # predict - if args.distributed and not args.horovod: - image_features = model.module.encode_image(images) - else: - image_features = model.encode_image(images) - image_features = F.normalize(image_features, dim=-1) - logits = 100.0 * image_features @ classifier - - # measure accuracy - acc1, acc5 = accuracy(logits, target, topk=(1, 5)) - top1 += acc1 - top5 += acc5 - n += images.size(0) - - top1 = top1 / n - top5 = top5 / n - return top1, top5 - - -def zero_shot_eval(model, data, epoch, args): - if "imagenet-val" not in data and "imagenet-v2" not in data: - return {} - if args.zeroshot_frequency == 0: - return {} - if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: - return {} - - logging.info("Starting zero-shot imagenet.") - - logging.info("Building zero-shot classifier") - classifier = zero_shot_classifier( - model, imagenet_classnames, openai_imagenet_template, args - ) - - logging.info("Using classifier") - results = {} - if "imagenet-val" in data: - top1, top5 = run(model, classifier, data["imagenet-val"].dataloader, args) - results["imagenet-zeroshot-val-top1"] = top1 - results["imagenet-zeroshot-val-top5"] = top5 - if "imagenet-v2" in data: - top1, top5 = run(model, classifier, data["imagenet-v2"].dataloader, args) - results["imagenetv2-zeroshot-val-top1"] = top1 - results["imagenetv2-zeroshot-val-top5"] = top5 - - logging.info("Finished zero-shot imagenet.") - - return results diff --git a/audioldm/hifigan/__init__.py b/audioldm/hifigan/__init__.py deleted file mode 100644 index e0ae476fe58c48e998c56234a55b871beba4042d..0000000000000000000000000000000000000000 --- a/audioldm/hifigan/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .models import Generator - - -class AttrDict(dict): - def __init__(self, *args, **kwargs): - super(AttrDict, self).__init__(*args, **kwargs) - self.__dict__ = self diff --git a/audioldm/hifigan/models.py b/audioldm/hifigan/models.py deleted file mode 100644 index 
c4382cc39de0463f9b7c0f33f037dbc233e7cb36..0000000000000000000000000000000000000000 --- a/audioldm/hifigan/models.py +++ /dev/null @@ -1,174 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn import Conv1d, ConvTranspose1d -from torch.nn.utils import weight_norm, remove_weight_norm - -LRELU_SLOPE = 0.1 - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - - -class ResBlock(torch.nn.Module): - def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): - super(ResBlock, self).__init__() - self.h = h - self.convs1 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[2], - padding=get_padding(kernel_size, dilation[2]), - ) - ), - ] - ) - self.convs1.apply(init_weights) - - self.convs2 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - ] - ) - self.convs2.apply(init_weights) - - def forward(self, x): - for c1, c2 in zip(self.convs1, self.convs2): - xt = F.leaky_relu(x, LRELU_SLOPE) - xt = c1(xt) - xt = F.leaky_relu(xt, LRELU_SLOPE) - xt = c2(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs1: - remove_weight_norm(l) - for l in self.convs2: - remove_weight_norm(l) - - -class Generator(torch.nn.Module): - def __init__(self, h): - super(Generator, self).__init__() - self.h = h - self.num_kernels = len(h.resblock_kernel_sizes) - self.num_upsamples = len(h.upsample_rates) - self.conv_pre = weight_norm( - Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) - ) - resblock = ResBlock - - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): - self.ups.append( - weight_norm( - ConvTranspose1d( - h.upsample_initial_channel // (2**i), - h.upsample_initial_channel // (2 ** (i + 1)), - k, - u, - padding=(k - u) // 2, - ) - ) - ) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = h.upsample_initial_channel // (2 ** (i + 1)) - for j, (k, d) in enumerate( - zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) - ): - self.resblocks.append(resblock(h, ch, k, d)) - - self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) - self.ups.apply(init_weights) - self.conv_post.apply(init_weights) - - def forward(self, x): - x = self.conv_pre(x) - for i in range(self.num_upsamples): - x = F.leaky_relu(x, LRELU_SLOPE) - x = self.ups[i](x) - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i * self.num_kernels + j](x) - else: - xs += self.resblocks[i * self.num_kernels + j](x) - x = xs / self.num_kernels - x = F.leaky_relu(x) - x = self.conv_post(x) - x = torch.tanh(x) - - return x - - def 
remove_weight_norm(self): - # print("Removing weight norm...") - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - remove_weight_norm(self.conv_pre) - remove_weight_norm(self.conv_post) diff --git a/audioldm/hifigan/utilities.py b/audioldm/hifigan/utilities.py deleted file mode 100644 index 47fd39ea0af181772d640feec2413cf631a75702..0000000000000000000000000000000000000000 --- a/audioldm/hifigan/utilities.py +++ /dev/null @@ -1,85 +0,0 @@ -import os -import json - -import torch -import numpy as np - -import audioldm.hifigan as hifigan - -HIFIGAN_16K_64 = { - "resblock": "1", - "num_gpus": 6, - "batch_size": 16, - "learning_rate": 0.0002, - "adam_b1": 0.8, - "adam_b2": 0.99, - "lr_decay": 0.999, - "seed": 1234, - "upsample_rates": [5, 4, 2, 2, 2], - "upsample_kernel_sizes": [16, 16, 8, 4, 4], - "upsample_initial_channel": 1024, - "resblock_kernel_sizes": [3, 7, 11], - "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], - "segment_size": 8192, - "num_mels": 64, - "num_freq": 1025, - "n_fft": 1024, - "hop_size": 160, - "win_size": 1024, - "sampling_rate": 16000, - "fmin": 0, - "fmax": 8000, - "fmax_for_loss": None, - "num_workers": 4, - "dist_config": { - "dist_backend": "nccl", - "dist_url": "tcp://localhost:54321", - "world_size": 1, - }, -} - - -def get_available_checkpoint_keys(model, ckpt): - print("==> Attemp to reload from %s" % ckpt) - state_dict = torch.load(ckpt)["state_dict"] - current_state_dict = model.state_dict() - new_state_dict = {} - for k in state_dict.keys(): - if ( - k in current_state_dict.keys() - and current_state_dict[k].size() == state_dict[k].size() - ): - new_state_dict[k] = state_dict[k] - else: - print("==> WARNING: Skipping %s" % k) - print( - "%s out of %s keys are matched" - % (len(new_state_dict.keys()), len(state_dict.keys())) - ) - return new_state_dict - - -def get_param_num(model): - num_param = sum(param.numel() for param in model.parameters()) - return num_param - - -def get_vocoder(config, device): - config = hifigan.AttrDict(HIFIGAN_16K_64) - vocoder = hifigan.Generator(config) - vocoder.eval() - vocoder.remove_weight_norm() - vocoder.to(device) - return vocoder - - -def vocoder_infer(mels, vocoder, lengths=None): - with torch.no_grad(): - wavs = vocoder(mels).squeeze(1) - - wavs = (wavs.cpu().numpy() * 32768).astype("int16") - - if lengths is not None: - wavs = wavs[:, :lengths] - - return wavs diff --git a/audioldm/latent_diffusion/__init__.py b/audioldm/latent_diffusion/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/audioldm/latent_diffusion/attention.py b/audioldm/latent_diffusion/attention.py deleted file mode 100644 index 583dd169e7ec9502ee29faeb12689a46494838c0..0000000000000000000000000000000000000000 --- a/audioldm/latent_diffusion/attention.py +++ /dev/null @@ -1,468 +0,0 @@ -from inspect import isfunction -import math -import torch -import torch.nn.functional as F -from torch import nn -from einops import rearrange - -from audioldm.latent_diffusion.util import checkpoint - - -def exists(val): - return val is not None - - -def uniq(arr): - return {el: True for el in arr}.keys() - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def max_neg_value(t): - return -torch.finfo(t.dtype).max - - -def init_(tensor): - dim = tensor.shape[-1] - std = 1 / math.sqrt(dim) - tensor.uniform_(-std, std) - return tensor - - -# feedforward -class GEGLU(nn.Module): - 
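The removed HiFi-GAN utilities hard-code the 16 kHz, 64-bin mel configuration shown above. A quick standalone check, using only values from HIFIGAN_16K_64, that the generator's upsampling stack exactly inverts the mel hop, so T mel frames decode to 160*T waveform samples:

import math

upsample_rates = [5, 4, 2, 2, 2]      # from HIFIGAN_16K_64
hop_size = 160
sampling_rate = 16000

total_upsampling = math.prod(upsample_rates)
assert total_upsampling == hop_size    # 5 * 4 * 2 * 2 * 2 == 160

# A 10-second clip at 16 kHz corresponds to this many mel frames / samples:
n_frames = 10 * sampling_rate // hop_size      # 1000 mel frames
n_samples = n_frames * total_upsampling        # 160000 samples
print(total_upsampling, n_frames, n_samples)   # 160 1000 160000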
def __init__(self, dim_in, dim_out): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = ( - nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) - if not glu - else GEGLU(dim, inner_dim) - ) - - self.net = nn.Sequential( - project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) - ) - - def forward(self, x): - return self.net(x) - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def Normalize(in_channels): - return torch.nn.GroupNorm( - num_groups=32, num_channels=in_channels, eps=1e-6, affine=True - ) - - -class LinearAttention(nn.Module): - def __init__(self, dim, heads=4, dim_head=32): - super().__init__() - self.heads = heads - hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) - self.to_out = nn.Conv2d(hidden_dim, dim, 1) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x) - q, k, v = rearrange( - qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 - ) - k = k.softmax(dim=-1) - context = torch.einsum("bhdn,bhen->bhde", k, v) - out = torch.einsum("bhde,bhdn->bhen", context, q) - out = rearrange( - out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w - ) - return self.to_out(out) - - -class SpatialSelfAttention(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = rearrange(q, "b c h w -> b (h w) c") - k = rearrange(k, "b c h w -> b c (h w)") - w_ = torch.einsum("bij,bjk->bik", q, k) - - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = rearrange(v, "b c h w -> b c (h w)") - w_ = rearrange(w_, "b i j -> b j i") - h_ = torch.einsum("bij,bjk->bik", v, w_) - h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) - h_ = self.proj_out(h_) - - return x + h_ - - -class CrossAttention(nn.Module): - """ - ### Cross Attention Layer - This falls-back to self-attention when conditional embeddings are not specified. 
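The GEGLU feed-forward defined above projects to twice the output width with a single linear layer, then gates one half with GELU of the other. A minimal shape check of that gating (the dimensions are illustrative):

import torch
import torch.nn.functional as F

dim_in, dim_out, batch, seq = 320, 1280, 2, 64
proj = torch.nn.Linear(dim_in, dim_out * 2)

x = torch.randn(batch, seq, dim_in)
value, gate = proj(x).chunk(2, dim=-1)   # each: [batch, seq, dim_out]
out = value * F.gelu(gate)
print(out.shape)                         # torch.Size([2, 64, 1280])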
- """ - - # use_flash_attention: bool = True - use_flash_attention: bool = False - def __init__( - self, - query_dim, - context_dim=None, - heads=8, - dim_head=64, - dropout=0.0, - is_inplace: bool = True, - ): - # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True): - """ - :param d_model: is the input embedding size - :param n_heads: is the number of attention heads - :param d_head: is the size of a attention head - :param d_cond: is the size of the conditional embeddings - :param is_inplace: specifies whether to perform the attention softmax computation inplace to - save memory - """ - super().__init__() - - self.is_inplace = is_inplace - self.n_heads = heads - self.d_head = dim_head - - # Attention scaling factor - self.scale = dim_head**-0.5 - - # The normal self-attention layer - if context_dim is None: - context_dim = query_dim - - # Query, key and value mappings - d_attn = dim_head * heads - self.to_q = nn.Linear(query_dim, d_attn, bias=False) - self.to_k = nn.Linear(context_dim, d_attn, bias=False) - self.to_v = nn.Linear(context_dim, d_attn, bias=False) - - # Final linear layer - self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout)) - - # Setup [flash attention](https://github.com/HazyResearch/flash-attention). - # Flash attention is only used if it's installed - # and `CrossAttention.use_flash_attention` is set to `True`. - try: - # You can install flash attention by cloning their Github repo, - # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention) - # and then running `python setup.py install` - from flash_attn.flash_attention import FlashAttention - - self.flash = FlashAttention() - # Set the scale for scaled dot-product attention. 
- self.flash.softmax_scale = self.scale - # Set to `None` if it's not installed - except ImportError: - self.flash = None - - def forward(self, x, context=None, mask=None): - """ - :param x: are the input embeddings of shape `[batch_size, height * width, d_model]` - :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]` - """ - - # If `cond` is `None` we perform self attention - has_cond = context is not None - if not has_cond: - context = x - - # Get query, key and value vectors - q = self.to_q(x) - k = self.to_k(context) - v = self.to_v(context) - - # Use flash attention if it's available and the head size is less than or equal to `128` - if ( - CrossAttention.use_flash_attention - and self.flash is not None - and not has_cond - and self.d_head <= 128 - ): - return self.flash_attention(q, k, v) - # Otherwise, fallback to normal attention - else: - return self.normal_attention(q, k, v) - - def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): - """ - #### Flash Attention - :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` - :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` - :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` - """ - - # Get batch size and number of elements along sequence axis (`width * height`) - batch_size, seq_len, _ = q.shape - - # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of - # shape `[batch_size, seq_len, 3, n_heads * d_head]` - qkv = torch.stack((q, k, v), dim=2) - # Split the heads - qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head) - - # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to - # fit this size. 
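This vendored FlashAttention path only handles head sizes of 32, 64 or 128, so smaller heads are zero-padded up to the next supported size; the branch that follows implements exactly that rule. A compact equivalent, for reference:

def flash_pad_for_head_dim(d_head: int) -> int:
    # How many zero channels to append so the head size becomes one of the
    # sizes this FlashAttention build supports (32, 64 or 128).
    for supported in (32, 64, 128):
        if d_head <= supported:
            return supported - d_head
    raise ValueError(f"Head size {d_head} too large for Flash Attention")

print(flash_pad_for_head_dim(40))   # 24 -> padded head size 64
print(flash_pad_for_head_dim(64))   # 0  -> already supported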
- if self.d_head <= 32: - pad = 32 - self.d_head - elif self.d_head <= 64: - pad = 64 - self.d_head - elif self.d_head <= 128: - pad = 128 - self.d_head - else: - raise ValueError(f"Head size ${self.d_head} too large for Flash Attention") - - # Pad the heads - if pad: - qkv = torch.cat( - (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1 - ) - - # Compute attention - # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ - # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]` - # TODO here I add the dtype changing - out, _ = self.flash(qkv.type(torch.float16)) - # Truncate the extra head size - out = out[:, :, :, : self.d_head].float() - # Reshape to `[batch_size, seq_len, n_heads * d_head]` - out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head) - - # Map to `[batch_size, height * width, d_model]` with a linear layer - return self.to_out(out) - - def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): - """ - #### Normal Attention - - :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` - :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` - :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` - """ - - # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]` - q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32] - k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32] - v = v.view(*v.shape[:2], self.n_heads, -1) - - # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$ - attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale - - # Compute softmax - # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$ - if self.is_inplace: - half = attn.shape[0] // 2 - attn[half:] = attn[half:].softmax(dim=-1) - attn[:half] = attn[:half].softmax(dim=-1) - else: - attn = attn.softmax(dim=-1) - - # Compute attention output - # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ - # attn: [bs, 20, 64, 1] - # v: [bs, 1, 20, 32] - out = torch.einsum("bhij,bjhd->bihd", attn, v) - # Reshape to `[batch_size, height * width, n_heads * d_head]` - out = out.reshape(*out.shape[:2], -1) - # Map to `[batch_size, height * width, d_model]` with a linear layer - return self.to_out(out) - - -# class CrossAttention(nn.Module): -# def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): -# super().__init__() -# inner_dim = dim_head * heads -# context_dim = default(context_dim, query_dim) - -# self.scale = dim_head ** -0.5 -# self.heads = heads - -# self.to_q = nn.Linear(query_dim, inner_dim, bias=False) -# self.to_k = nn.Linear(context_dim, inner_dim, bias=False) -# self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - -# self.to_out = nn.Sequential( -# nn.Linear(inner_dim, query_dim), -# nn.Dropout(dropout) -# ) - -# def forward(self, x, context=None, mask=None): -# h = self.heads - -# q = self.to_q(x) -# context = default(context, x) -# k = self.to_k(context) -# v = self.to_v(context) - -# q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - -# sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - -# if exists(mask): -# mask = rearrange(mask, 'b ... 
-> b (...)') -# max_neg_value = -torch.finfo(sim.dtype).max -# mask = repeat(mask, 'b j -> (b h) () j', h=h) -# sim.masked_fill_(~mask, max_neg_value) - -# # attention, what we cannot get enough of -# attn = sim.softmax(dim=-1) - -# out = einsum('b i j, b j d -> b i d', attn, v) -# out = rearrange(out, '(b h) n d -> b n (h d)', h=h) -# return self.to_out(out) - - -class BasicTransformerBlock(nn.Module): - def __init__( - self, - dim, - n_heads, - d_head, - dropout=0.0, - context_dim=None, - gated_ff=True, - checkpoint=True, - ): - super().__init__() - self.attn1 = CrossAttention( - query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout - ) # is a self-attention - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention( - query_dim=dim, - context_dim=context_dim, - heads=n_heads, - dim_head=d_head, - dropout=dropout, - ) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) - self.checkpoint = checkpoint - - def forward(self, x, context=None): - if context is None: - return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint) - else: - return checkpoint( - self._forward, (x, context), self.parameters(), self.checkpoint - ) - - def _forward(self, x, context=None): - x = self.attn1(self.norm1(x)) + x - x = self.attn2(self.norm2(x), context=context) + x - x = self.ff(self.norm3(x)) + x - return x - - -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. - First, project the input (aka embedding) - and reshape to b, t, d. - Then apply standard transformer action. - Finally, reshape to image - """ - - def __init__( - self, - in_channels, - n_heads, - d_head, - depth=1, - dropout=0.0, - context_dim=None, - no_context=False, - ): - super().__init__() - - if no_context: - context_dim = None - - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = Normalize(in_channels) - - self.proj_in = nn.Conv2d( - in_channels, inner_dim, kernel_size=1, stride=1, padding=0 - ) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim - ) - for d in range(depth) - ] - ) - - self.proj_out = zero_module( - nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - ) - - def forward(self, x, context=None): - # note: if no context is given, cross-attention defaults to self-attention - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - x = self.proj_in(x) - x = rearrange(x, "b c h w -> b (h w) c") - for block in self.transformer_blocks: - x = block(x, context=context) - x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) - x = self.proj_out(x) - return x + x_in diff --git a/audioldm/latent_diffusion/ddim.py b/audioldm/latent_diffusion/ddim.py deleted file mode 100644 index 57ee8d302c77cb09bd73ef803ef9e715098feafc..0000000000000000000000000000000000000000 --- a/audioldm/latent_diffusion/ddim.py +++ /dev/null @@ -1,377 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch -import numpy as np -from tqdm import tqdm - -from audioldm.latent_diffusion.util import ( - make_ddim_sampling_parameters, - make_ddim_timesteps, - noise_like, - extract_into_tensor, -) -import gradio as gr - -class DDIMSampler(object): - def __init__(self, model, schedule="linear", **kwargs): - super().__init__() - self.model = model - self.ddpm_num_timesteps = model.num_timesteps - self.schedule = schedule - - def register_buffer(self, name, attr): - if type(attr) == 
torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) - setattr(self, name, attr) - - def make_schedule( - self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True - ): - self.ddim_timesteps = make_ddim_timesteps( - ddim_discr_method=ddim_discretize, - num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps, - verbose=verbose, - ) - alphas_cumprod = self.model.alphas_cumprod - assert ( - alphas_cumprod.shape[0] == self.ddpm_num_timesteps - ), "alphas have to be defined for each timestep" - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - - self.register_buffer("betas", to_torch(self.model.betas)) - self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) - self.register_buffer( - "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) - ) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer( - "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) - ) - self.register_buffer( - "sqrt_one_minus_alphas_cumprod", - to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), - ) - self.register_buffer( - "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) - ) - self.register_buffer( - "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) - ) - self.register_buffer( - "sqrt_recipm1_alphas_cumprod", - to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), - ) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( - alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta, - verbose=verbose, - ) - self.register_buffer("ddim_sigmas", ddim_sigmas) - self.register_buffer("ddim_alphas", ddim_alphas) - self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) - self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) - / (1 - self.alphas_cumprod) - * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) - ) - self.register_buffer( - "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps - ) - - @torch.no_grad() - def sample( - self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0.0, - mask=None, - x0=None, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
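The sigmas registered above follow the DDIM schedule sigma_t(eta) = eta * sqrt((1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)) * sqrt(1 - alpha_bar_t / alpha_bar_{t-1}); eta = 0 gives the deterministic DDIM update, while eta = 1 reintroduces DDPM-like stochasticity. A small numerical check on a toy linear beta schedule (the real values come from the model's registered buffers):

import numpy as np

betas = np.linspace(1e-4, 2e-2, 1000)            # toy schedule, for illustration
alphas_cumprod = np.cumprod(1.0 - betas)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

def ddim_sigmas(eta):
    return eta * np.sqrt(
        (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)
        * (1 - alphas_cumprod / alphas_cumprod_prev)
    )

print(ddim_sigmas(0.0).max())                    # 0.0 -> deterministic updates
print(ddim_sigmas(1.0)[1], ddim_sigmas(1.0)[-1]) # noise level grows along the chain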
- **kwargs, - ): - if conditioning is not None: - if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print( - f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" - ) - else: - if conditioning.shape[0] != batch_size: - print( - f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" - ) - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - samples, intermediates = self.ddim_sampling( - conditioning, - size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, - x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) - return samples, intermediates - - @torch.no_grad() - def ddim_sampling( - self, - cond, - shape, - x_T=None, - ddim_use_original_steps=False, - callback=None, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - log_every_t=100, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - ): - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - if timesteps is None: - timesteps = ( - self.ddpm_num_timesteps - if ddim_use_original_steps - else self.ddim_timesteps - ) - elif timesteps is not None and not ddim_use_original_steps: - subset_end = ( - int( - min(timesteps / self.ddim_timesteps.shape[0], 1) - * self.ddim_timesteps.shape[0] - ) - - 1 - ) - timesteps = self.ddim_timesteps[:subset_end] - - intermediates = {"x_inter": [img], "pred_x0": [img]} - time_range = ( - reversed(range(0, timesteps)) - if ddim_use_original_steps - else np.flip(timesteps) - ) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - # print(f"Running DDIM Sampling with {total_steps} timesteps") - - # iterator = gr.Progress().tqdm(time_range, desc="DDIM Sampler", total=total_steps) - iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - if mask is not None: - assert x0 is not None - img_orig = self.model.q_sample( - x0, ts - ) # TODO deterministic forward pass? 
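`p_sample_ddim` below applies classifier-free guidance: the UNet is run on a doubled batch (unconditional and conditional halves) and the final noise estimate is pushed along the conditional-minus-unconditional direction. The combination step in isolation, with dummy tensors standing in for the two model outputs:

import torch

guidance_scale = 2.5

# Stand-ins for the two halves returned by apply_model(...).chunk(2).
e_t_uncond = torch.randn(1, 8, 256, 16)
e_t_cond = torch.randn(1, 8, 256, 16)

# scale == 1 reduces to the conditional prediction; scale > 1 moves further
# away from the unconditional prediction.
e_t = e_t_uncond + guidance_scale * (e_t_cond - e_t_uncond)
print(e_t.shape)   # torch.Size([1, 8, 256, 16])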
- img = ( - img_orig * mask + (1.0 - mask) * img - ) # In the first sampling step, img is pure gaussian noise - - outs = self.p_sample_ddim( - img, - cond, - ts, - index=index, - use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, - temperature=temperature, - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) - img, pred_x0 = outs - if callback: - callback(i) - if img_callback: - img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates["x_inter"].append(img) - intermediates["pred_x0"].append(pred_x0) - - return img, intermediates - - @torch.no_grad() - def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): - # fast, but does not allow for exact reconstruction - # t serves as an index to gather the correct alphas - if use_original_steps: - sqrt_alphas_cumprod = self.sqrt_alphas_cumprod - sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod - else: - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) - sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas - - if noise is None: - noise = torch.randn_like(x0) - - return ( - extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 - + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise - ) - - @torch.no_grad() - def decode( - self, - x_latent, - cond, - t_start, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - use_original_steps=False, - ): - - timesteps = ( - np.arange(self.ddpm_num_timesteps) - if use_original_steps - else self.ddim_timesteps - ) - timesteps = timesteps[:t_start] - - time_range = np.flip(timesteps) - total_steps = timesteps.shape[0] - # print(f"Running DDIM Sampling with {total_steps} timesteps") - - # iterator = gr.Progress().tqdm(time_range, desc="Decoding image", total=total_steps) - iterator = tqdm(time_range, desc="Decoding image", total=total_steps) - x_dec = x_latent - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full( - (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long - ) - x_dec, _ = self.p_sample_ddim( - x_dec, - cond, - ts, - index=index, - use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) - return x_dec - - @torch.no_grad() - def p_sample_ddim( - self, - x, - c, - t, - index, - repeat_noise=False, - use_original_steps=False, - quantize_denoised=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - ): - b, *_, device = *x.shape, x.device - - if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: - e_t = self.model.apply_model(x, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - # When unconditional_guidance_scale == 1: only e_t - # When unconditional_guidance_scale == 0: only unconditional - # When unconditional_guidance_scale > 1: add more unconditional guidance - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = 
score_corrector.modify_score( - self.model, e_t, x, t, c, **corrector_kwargs - ) - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = ( - self.model.alphas_cumprod_prev - if use_original_steps - else self.ddim_alphas_prev - ) - sqrt_one_minus_alphas = ( - self.model.sqrt_one_minus_alphas_cumprod - if use_original_steps - else self.ddim_sqrt_one_minus_alphas - ) - sigmas = ( - self.model.ddim_sigmas_for_original_num_steps - if use_original_steps - else self.ddim_sigmas - ) - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full( - (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device - ) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - # direction pointing to x_t - dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.0: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise # TODO - return x_prev, pred_x0 diff --git a/audioldm/latent_diffusion/ddpm.py b/audioldm/latent_diffusion/ddpm.py deleted file mode 100644 index ffca031c27d413698adee5a58547b7d0ea4069c3..0000000000000000000000000000000000000000 --- a/audioldm/latent_diffusion/ddpm.py +++ /dev/null @@ -1,441 +0,0 @@ -""" -wild mixture of -https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py -https://github.com/CompVis/taming-transformers --- merci -""" -import sys -import os - -import torch -import torch.nn as nn -import numpy as np -from contextlib import contextmanager -from functools import partial -from tqdm import tqdm - -from audioldm.utils import exists, default, count_params, instantiate_from_config -from audioldm.latent_diffusion.ema import LitEma -from audioldm.latent_diffusion.util import ( - make_beta_schedule, - extract_into_tensor, - noise_like, -) -import soundfile as sf -import os - - -__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -def uniform_on_device(r1, r2, shape, device): - return (r1 - r2) * torch.rand(*shape, device=device) + r2 - - -class DiffusionWrapper(nn.Module): - def __init__(self, diff_model_config, conditioning_key): - super().__init__() - self.diffusion_model = instantiate_from_config(diff_model_config) - self.conditioning_key = conditioning_key - assert self.conditioning_key in [ - None, - "concat", - "crossattn", - "hybrid", - "adm", - "film", - ] - - def forward( - self, x, t, c_concat: list = None, c_crossattn: list = None, c_film: list = None - ): - x = x.contiguous() - t = t.contiguous() - - if self.conditioning_key is None: - out = self.diffusion_model(x, t) - elif self.conditioning_key == "concat": - xc = torch.cat([x] + c_concat, dim=1) - out = self.diffusion_model(xc, t) - elif 
self.conditioning_key == "crossattn": - cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(x, t, context=cc) - elif self.conditioning_key == "hybrid": - xc = torch.cat([x] + c_concat, dim=1) - cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc) - elif ( - self.conditioning_key == "film" - ): # The condition is assumed to be a global token, which wil pass through a linear layer and added with the time embedding for the FILM - cc = c_film[0].squeeze(1) # only has one token - out = self.diffusion_model(x, t, y=cc) - elif self.conditioning_key == "adm": - cc = c_crossattn[0] - out = self.diffusion_model(x, t, y=cc) - else: - raise NotImplementedError() - - return out - - -class DDPM(nn.Module): - # classic DDPM with Gaussian diffusion, in image space - def __init__( - self, - unet_config, - timesteps=1000, - beta_schedule="linear", - loss_type="l2", - ckpt_path=None, - ignore_keys=[], - load_only_unet=False, - monitor="val/loss", - use_ema=True, - first_stage_key="image", - latent_t_size=256, - latent_f_size=16, - channels=3, - log_every_t=100, - clip_denoised=True, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - given_betas=None, - original_elbo_weight=0.0, - v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta - l_simple_weight=1.0, - conditioning_key=None, - parameterization="eps", # all assuming fixed variance schedules - scheduler_config=None, - use_positional_encodings=False, - learn_logvar=False, - logvar_init=0.0, - ): - super().__init__() - assert parameterization in [ - "eps", - "x0", - ], 'currently only supporting "eps" and "x0"' - self.parameterization = parameterization - self.state = None - # print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") - self.cond_stage_model = None - self.clip_denoised = clip_denoised - self.log_every_t = log_every_t - self.first_stage_key = first_stage_key - - self.latent_t_size = latent_t_size - self.latent_f_size = latent_f_size - - self.channels = channels - self.use_positional_encodings = use_positional_encodings - self.model = DiffusionWrapper(unet_config, conditioning_key) - count_params(self.model, verbose=True) - self.use_ema = use_ema - if self.use_ema: - self.model_ema = LitEma(self.model) - # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - self.use_scheduler = scheduler_config is not None - if self.use_scheduler: - self.scheduler_config = scheduler_config - - self.v_posterior = v_posterior - self.original_elbo_weight = original_elbo_weight - self.l_simple_weight = l_simple_weight - - if monitor is not None: - self.monitor = monitor - - self.register_schedule( - given_betas=given_betas, - beta_schedule=beta_schedule, - timesteps=timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s, - ) - - self.loss_type = loss_type - - self.learn_logvar = learn_logvar - self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) - if self.learn_logvar: - self.logvar = nn.Parameter(self.logvar, requires_grad=True) - else: - self.logvar = nn.Parameter(self.logvar, requires_grad=False) - - self.logger_save_dir = None - self.logger_project = None - self.logger_version = None - self.label_indices_total = None - # To avoid the system cannot find metric value for checkpoint - self.metrics_buffer = { - "val/kullback_leibler_divergence_sigmoid": 15.0, - "val/kullback_leibler_divergence_softmax": 10.0, - "val/psnr": 0.0, - "val/ssim": 0.0, - "val/inception_score_mean": 
1.0, - "val/inception_score_std": 0.0, - "val/kernel_inception_distance_mean": 0.0, - "val/kernel_inception_distance_std": 0.0, - "val/frechet_inception_distance": 133.0, - "val/frechet_audio_distance": 32.0, - } - self.initial_learning_rate = None - - def get_log_dir(self): - if ( - self.logger_save_dir is None - and self.logger_project is None - and self.logger_version is None - ): - return os.path.join( - self.logger.save_dir, self.logger._project, self.logger.version - ) - else: - return os.path.join( - self.logger_save_dir, self.logger_project, self.logger_version - ) - - def set_log_dir(self, save_dir, project, version): - self.logger_save_dir = save_dir - self.logger_project = project - self.logger_version = version - - def register_schedule( - self, - given_betas=None, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - ): - if exists(given_betas): - betas = given_betas - else: - betas = make_beta_schedule( - beta_schedule, - timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s, - ) - alphas = 1.0 - betas - alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) - - (timesteps,) = betas.shape - self.num_timesteps = int(timesteps) - self.linear_start = linear_start - self.linear_end = linear_end - assert ( - alphas_cumprod.shape[0] == self.num_timesteps - ), "alphas have to be defined for each timestep" - - to_torch = partial(torch.tensor, dtype=torch.float32) - - self.register_buffer("betas", to_torch(betas)) - self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) - self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer( - "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) - ) - self.register_buffer( - "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) - ) - self.register_buffer( - "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) - ) - self.register_buffer( - "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) - ) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = (1 - self.v_posterior) * betas * ( - 1.0 - alphas_cumprod_prev - ) / (1.0 - alphas_cumprod) + self.v_posterior * betas - # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t) - self.register_buffer("posterior_variance", to_torch(posterior_variance)) - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.register_buffer( - "posterior_log_variance_clipped", - to_torch(np.log(np.maximum(posterior_variance, 1e-20))), - ) - self.register_buffer( - "posterior_mean_coef1", - to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), - ) - self.register_buffer( - "posterior_mean_coef2", - to_torch( - (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) - ), - ) - - if self.parameterization == "eps": - lvlb_weights = self.betas**2 / ( - 2 - * self.posterior_variance - * to_torch(alphas) - * (1 - self.alphas_cumprod) - ) - elif self.parameterization == "x0": - lvlb_weights = ( - 0.5 - * np.sqrt(torch.Tensor(alphas_cumprod)) - / (2.0 * 1 - torch.Tensor(alphas_cumprod)) - ) - else: - raise NotImplementedError("mu not supported") - # TODO how to choose this term - lvlb_weights[0] = lvlb_weights[1] - self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) - assert not torch.isnan(self.lvlb_weights).all() - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.model.parameters()) - self.model_ema.copy_to(self.model) - if context is not None: - # print(f"{context}: Switched to EMA weights") - pass - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.model.parameters()) - if context is not None: - # print(f"{context}: Restored training weights") - pass - - def q_mean_variance(self, x_start, t): - """ - Get the distribution q(x_t | x_0). - :param x_start: the [N x C x ...] tensor of noiseless inputs. - :param t: the number of diffusion steps (minus 1). Here, 0 means one step. - :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
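`q_mean_variance` here (and `q_sample` further down) implement the closed-form forward process q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I), i.e. x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A self-contained numerical illustration on a toy schedule:

import numpy as np

rng = np.random.default_rng(0)

betas = np.linspace(1e-4, 2e-2, 1000)           # toy linear schedule
alphas_cumprod = np.cumprod(1.0 - betas)

x0 = rng.standard_normal(16000)                  # stand-in for a clean latent
t = 500
noise = rng.standard_normal(x0.shape)

xt = np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1.0 - alphas_cumprod[t]) * noise

# Signal-to-noise ratio decays as t grows: alpha_bar_t / (1 - alpha_bar_t).
snr = alphas_cumprod[t] / (1.0 - alphas_cumprod[t])
print(round(float(snr), 4), xt.shape)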
- """ - mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) - log_variance = extract_into_tensor( - self.log_one_minus_alphas_cumprod, t, x_start.shape - ) - return mean, variance, log_variance - - def predict_start_from_noise(self, x_t, t, noise): - return ( - extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) - * noise - ) - - def q_posterior(self, x_start, x_t, t): - posterior_mean = ( - extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start - + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = extract_into_tensor( - self.posterior_log_variance_clipped, t, x_t.shape - ) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - def p_mean_variance(self, x, t, clip_denoised: bool): - model_out = self.model(x, t) - if self.parameterization == "eps": - x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) - elif self.parameterization == "x0": - x_recon = model_out - if clip_denoised: - x_recon.clamp_(-1.0, 1.0) - - model_mean, posterior_variance, posterior_log_variance = self.q_posterior( - x_start=x_recon, x_t=x, t=t - ) - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): - b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance( - x=x, t=t, clip_denoised=clip_denoised - ) - noise = noise_like(x.shape, device, repeat_noise) - # no noise when t == 0 - nonzero_mask = ( - (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous() - ) - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise - - @torch.no_grad() - def p_sample_loop(self, shape, return_intermediates=False): - device = self.betas.device - b = shape[0] - img = torch.randn(shape, device=device) - intermediates = [img] - for i in tqdm( - reversed(range(0, self.num_timesteps)), - desc="Sampling t", - total=self.num_timesteps, - ): - img = self.p_sample( - img, - torch.full((b,), i, device=device, dtype=torch.long), - clip_denoised=self.clip_denoised, - ) - if i % self.log_every_t == 0 or i == self.num_timesteps - 1: - intermediates.append(img) - if return_intermediates: - return img, intermediates - return img - - @torch.no_grad() - def sample(self, batch_size=16, return_intermediates=False): - shape = (batch_size, channels, self.latent_t_size, self.latent_f_size) - channels = self.channels - return self.p_sample_loop(shape, return_intermediates=return_intermediates) - - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) - * noise - ) - - def forward(self, x, *args, **kwargs): - t = torch.randint( - 0, self.num_timesteps, (x.shape[0],), device=self.device - ).long() - return self.p_losses(x, t, *args, **kwargs) - - def get_input(self, batch, k): - # fbank, log_magnitudes_stft, label_indices, fname, waveform, clip_label, text = batch - fbank, log_magnitudes_stft, label_indices, fname, waveform, text = batch - ret = {} - - ret["fbank"] = ( - 
fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float() - ) - ret["stft"] = log_magnitudes_stft.to( - memory_format=torch.contiguous_format - ).float() - # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float() - ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float() - ret["text"] = list(text) - ret["fname"] = fname - - return ret[k] diff --git a/audioldm/latent_diffusion/ema.py b/audioldm/latent_diffusion/ema.py deleted file mode 100644 index 192b012186bab3d8a5380bc9b891da8eef0fd9fa..0000000000000000000000000000000000000000 --- a/audioldm/latent_diffusion/ema.py +++ /dev/null @@ -1,81 +0,0 @@ -import torch -from torch import nn - -class LitEma(nn.Module): - def __init__(self, model, decay=0.9999, use_num_upates=True): - super().__init__() - if decay < 0.0 or decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - - self.m_name2s_name = {} - self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) - self.register_buffer( - "num_updates", - torch.tensor(0, dtype=torch.int) - if use_num_upates - else torch.tensor(-1, dtype=torch.int), - ) - - for name, p in model.named_parameters(): - if p.requires_grad: - # remove as '.'-character is not allowed in buffers - s_name = name.replace(".", "") - self.m_name2s_name.update({name: s_name}) - self.register_buffer(s_name, p.clone().detach().data) - - self.collected_params = [] - - def forward(self, model): - decay = self.decay - - if self.num_updates >= 0: - self.num_updates += 1 - decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) - - one_minus_decay = 1.0 - decay - - with torch.no_grad(): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - - for key in m_param: - if m_param[key].requires_grad: - sname = self.m_name2s_name[key] - shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) - shadow_params[sname].sub_( - one_minus_decay * (shadow_params[sname] - m_param[key]) - ) - else: - assert not key in self.m_name2s_name - - def copy_to(self, model): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - for key in m_param: - if m_param[key].requires_grad: - m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) - else: - assert not key in self.m_name2s_name - - def store(self, parameters): - """ - Save the current parameters for restoring later. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. - """ - self.collected_params = [param.clone() for param in parameters] - - def restore(self, parameters): - """ - Restore the parameters stored with the `store` method. - Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before the - `copy_to` method. After validation (or model saving), use this to - restore the former parameters. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. 
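`LitEma` above keeps exponential-moving-average shadow copies of every trainable parameter, with a warm-up so the effective decay ramps from roughly 0.18 toward its maximum as updates accumulate. The update rule on a single scalar parameter, purely for illustration:

decay_max = 0.9999
shadow, num_updates = 0.0, 0

for step, param in enumerate([1.0, 1.2, 0.9, 1.1], start=1):
    num_updates += 1
    # Warm-up: (1 + n) / (10 + n) starts near 0.18 and approaches 1, capped at decay_max.
    decay = min(decay_max, (1 + num_updates) / (10 + num_updates))
    shadow = shadow - (1.0 - decay) * (shadow - param)
    print(step, round(decay, 4), round(shadow, 4))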
- """ - for c_param, param in zip(self.collected_params, parameters): - param.data.copy_(c_param.data) diff --git a/audioldm/latent_diffusion/openaimodel.py b/audioldm/latent_diffusion/openaimodel.py deleted file mode 100644 index 831d7aafb36bba16888e4389153979a6c13639f5..0000000000000000000000000000000000000000 --- a/audioldm/latent_diffusion/openaimodel.py +++ /dev/null @@ -1,1069 +0,0 @@ -from abc import abstractmethod -import math - -import numpy as np -import torch as th -import torch.nn as nn -import torch.nn.functional as F - -from audioldm.latent_diffusion.util import ( - checkpoint, - conv_nd, - linear, - avg_pool_nd, - zero_module, - normalization, - timestep_embedding, -) -from audioldm.latent_diffusion.attention import SpatialTransformer - - -# dummy replace -def convert_module_to_f16(x): - pass - - -def convert_module_to_f32(x): - pass - - -## go -class AttentionPool2d(nn.Module): - """ - Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py - """ - - def __init__( - self, - spacial_dim: int, - embed_dim: int, - num_heads_channels: int, - output_dim: int = None, - ): - super().__init__() - self.positional_embedding = nn.Parameter( - th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 - ) - self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) - self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) - self.num_heads = embed_dim // num_heads_channels - self.attention = QKVAttention(self.num_heads) - - def forward(self, x): - b, c, *_spatial = x.shape - x = x.reshape(b, c, -1).contiguous() # NC(HW) - x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) - x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) - x = self.qkv_proj(x) - x = self.attention(x) - x = self.c_proj(x) - return x[:, :, 0] - - -class TimestepBlock(nn.Module): - """ - Any module where forward() takes timestep embeddings as a second argument. - """ - - @abstractmethod - def forward(self, x, emb): - """ - Apply the module to `x` given `emb` timestep embeddings. - """ - - -class TimestepEmbedSequential(nn.Sequential, TimestepBlock): - """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. - """ - - def forward(self, x, emb, context=None): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb) - elif isinstance(layer, SpatialTransformer): - x = layer(x, context) - else: - x = layer(x) - return x - - -class Upsample(nn.Module): - """ - An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. 
- """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd( - dims, self.channels, self.out_channels, 3, padding=padding - ) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) - else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - - -class TransposedUpsample(nn.Module): - "Learned 2x upsampling without padding" - - def __init__(self, channels, out_channels=None, ks=5): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - - self.up = nn.ConvTranspose2d( - self.channels, self.out_channels, kernel_size=ks, stride=2 - ) - - def forward(self, x): - return self.up(x) - - -class Downsample(nn.Module): - """ - A downsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - downsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - stride = 2 if dims != 3 else (1, 2, 2) - if use_conv: - self.op = conv_nd( - dims, - self.channels, - self.out_channels, - 3, - stride=stride, - padding=padding, - ) - else: - assert self.channels == self.out_channels - self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) - - def forward(self, x): - assert x.shape[1] == self.channels - return self.op(x) - - -class ResBlock(TimestepBlock): - """ - A residual block that can optionally change the number of channels. - :param channels: the number of input channels. - :param emb_channels: the number of timestep embedding channels. - :param dropout: the rate of dropout. - :param out_channels: if specified, the number of out channels. - :param use_conv: if True and out_channels is specified, use a spatial - convolution instead of a smaller 1x1 convolution to change the - channels in the skip connection. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param use_checkpoint: if True, use gradient checkpointing on this module. - :param up: if True, use this block for upsampling. - :param down: if True, use this block for downsampling. 
- """ - - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - use_conv=False, - use_scale_shift_norm=False, - dims=2, - use_checkpoint=False, - up=False, - down=False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_checkpoint = use_checkpoint - self.use_scale_shift_norm = use_scale_shift_norm - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 3, padding=1), - ) - - self.updown = up or down - - if up: - self.h_upd = Upsample(channels, False, dims) - self.x_upd = Upsample(channels, False, dims) - elif down: - self.h_upd = Downsample(channels, False, dims) - self.x_upd = Downsample(channels, False, dims) - else: - self.h_upd = self.x_upd = nn.Identity() - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint - ) - - def _forward(self, x, emb): - if self.updown: - in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) - h = self.h_upd(h) - x = self.x_upd(x) - h = in_conv(h) - else: - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = th.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. - Originally ported from here, but adapted to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
- """ - - def __init__( - self, - channels, - num_heads=1, - num_head_channels=-1, - use_checkpoint=False, - use_new_attention_order=False, - ): - super().__init__() - self.channels = channels - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - self.use_checkpoint = use_checkpoint - self.norm = normalization(channels) - self.qkv = conv_nd(1, channels, channels * 3, 1) - if use_new_attention_order: - # split qkv before split heads - self.attention = QKVAttention(self.num_heads) - else: - # split heads before split qkv - self.attention = QKVAttentionLegacy(self.num_heads) - - self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) - - def forward(self, x): - return checkpoint( - self._forward, (x,), self.parameters(), True - ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! - # return pt_checkpoint(self._forward, x) # pytorch - - def _forward(self, x): - b, c, *spatial = x.shape - x = x.reshape(b, c, -1).contiguous() - qkv = self.qkv(self.norm(x)).contiguous() - h = self.attention(qkv).contiguous() - h = self.proj_out(h).contiguous() - return (x + h).reshape(b, c, *spatial).contiguous() - - -def count_flops_attn(model, _x, y): - """ - A counter for the `thop` package to count the operations in an - attention operation. - Meant to be used like: - macs, params = thop.profile( - model, - inputs=(inputs, timestamps), - custom_ops={QKVAttention: QKVAttention.count_flops}, - ) - """ - b, c, *spatial = y[0].shape - num_spatial = int(np.prod(spatial)) - # We perform two matmuls with the same number of ops. - # The first computes the weight matrix, the second computes - # the combination of the value vectors. - matmul_ops = 2 * b * (num_spatial**2) * c - model.total_ops += th.DoubleTensor([matmul_ops]) - - -class QKVAttentionLegacy(nn.Module): - """ - A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. - :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = ( - qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1) - ) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v) - return a.reshape(bs, -1, length).contiguous() - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - -class QKVAttention(nn.Module): - """ - A module which performs QKV attention and splits in a different order. - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. - :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. 
- """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.chunk(3, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", - (q * scale).view(bs * self.n_heads, ch, length), - (k * scale).view(bs * self.n_heads, ch, length), - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum( - "bts,bcs->bct", - weight, - v.reshape(bs * self.n_heads, ch, length).contiguous(), - ) - return a.reshape(bs, -1, length).contiguous() - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - -class UNetModel(nn.Module): - """ - The full UNet model with attention and timestep embedding. - :param in_channels: channels in the input Tensor. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_classes: if specified (as an int), then this model will be - class-conditional with `num_classes` classes. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. 
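The two QKV attention modules earlier in this file differ only in whether the heads or the q/k/v triplet are split first. Scaling `q` and `k` each by `ch**-0.25` applies the usual `1/sqrt(ch)` attention scaling before the matmul, which behaves better in float16 than dividing the logits afterwards. A standalone sketch of the legacy ordering (hypothetical helper name, not repository code):

```python
# Sketch of the "legacy" QKV attention: heads are split before the q/k/v triplet.
import math
import torch


def qkv_attention_legacy(qkv: torch.Tensor, n_heads: int) -> torch.Tensor:
    bs, width, length = qkv.shape            # qkv: [N, n_heads * 3 * ch, T]
    assert width % (3 * n_heads) == 0
    ch = width // (3 * n_heads)
    q, k, v = qkv.reshape(bs * n_heads, 3 * ch, length).split(ch, dim=1)
    scale = 1 / math.sqrt(math.sqrt(ch))     # ch**-0.25 on q and k each
    weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
    weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
    out = torch.einsum("bts,bcs->bct", weight, v)
    return out.reshape(bs, -1, length)       # [N, n_heads * ch, T]


# qkv = torch.randn(2, 4 * 3 * 32, 100)
# qkv_attention_legacy(qkv, n_heads=4).shape == torch.Size([2, 128, 100])
```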
- """ - - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - num_classes=None, - extra_film_condition_dim=None, - use_checkpoint=False, - use_fp16=False, - num_heads=-1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - extra_film_use_concat=False, # If true, concatenate extrafilm condition with time embedding, else addition - resblock_updown=False, - use_new_attention_order=False, - use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model - legacy=True, - ): - super().__init__() - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - if num_heads == -1: - assert ( - num_head_channels != -1 - ), "Either num_heads or num_head_channels has to be set" - - if num_head_channels == -1: - assert ( - num_heads != -1 - ), "Either num_heads or num_head_channels has to be set" - - self.image_size = image_size - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.num_classes = num_classes - self.extra_film_condition_dim = extra_film_condition_dim - self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - self.predict_codebook_ids = n_embed is not None - self.extra_film_use_concat = extra_film_use_concat - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - assert not ( - self.num_classes is not None and self.extra_film_condition_dim is not None - ), "As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim." - - if self.num_classes is not None: - self.label_emb = nn.Embedding(num_classes, time_embed_dim) - - self.use_extra_film_by_concat = ( - self.extra_film_condition_dim is not None and self.extra_film_use_concat - ) - self.use_extra_film_by_addition = ( - self.extra_film_condition_dim is not None and not self.extra_film_use_concat - ) - - if self.extra_film_condition_dim is not None: - self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim) - # print("+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. " % self.extra_film_condition_dim) - # if(self.use_extra_film_by_concat): - # print("\t By concatenation with time embedding") - # elif(self.use_extra_film_by_concat): - # print("\t By addition with time embedding") - - if use_spatial_transformer and ( - self.use_extra_film_by_concat or self.use_extra_film_by_addition - ): - # print("+ Spatial transformer will only be used as self-attention. 
Because you have choose to use film as your global condition.") - spatial_transformer_no_context = True - else: - spatial_transformer_no_context = False - - if use_spatial_transformer and not spatial_transformer_no_context: - assert ( - context_dim is not None - ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." - - if context_dim is not None and not spatial_transformer_no_context: - assert ( - use_spatial_transformer - ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." - from omegaconf.listconfig import ListConfig - - if type(context_dim) == ListConfig: - context_dim = list(context_dim) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim - if (not self.use_extra_film_by_concat) - else time_embed_dim * 2, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - dim_head = ( - ch // num_heads - if use_spatial_transformer - else num_head_channels - ) - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, - no_context=spatial_transformer_no_context, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim - if (not self.use_extra_film_by_concat) - else time_embed_dim * 2, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim - if (not self.use_extra_film_by_concat) - else time_embed_dim * 2, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, - no_context=spatial_transformer_no_context, - ), - ResBlock( - ch, - time_embed_dim - if (not self.use_extra_film_by_concat) - 
else time_embed_dim * 2, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(num_res_blocks + 1): - ich = input_block_chans.pop() - layers = [ - ResBlock( - ch + ich, - time_embed_dim - if (not self.use_extra_film_by_concat) - else time_embed_dim * 2, - dropout, - out_channels=model_channels * mult, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = model_channels * mult - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - # num_heads = 1 - dim_head = ( - ch // num_heads - if use_spatial_transformer - else num_head_channels - ) - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads_upsample, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, - no_context=spatial_transformer_no_context, - ) - ) - if level and i == num_res_blocks: - out_ch = ch - layers.append( - ResBlock( - ch, - time_embed_dim - if (not self.use_extra_film_by_concat) - else time_embed_dim * 2, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - up=True, - ) - if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) - ) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), - ) - if self.predict_codebook_ids: - self.id_predictor = nn.Sequential( - normalization(ch), - conv_nd(dims, model_channels, n_embed, 1), - # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits - ) - - self.shape_reported = False - - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - self.output_blocks.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - self.output_blocks.apply(convert_module_to_f32) - - def forward(self, x, timesteps=None, context=None, y=None, **kwargs): - """ - Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param context: conditioning plugged in via crossattn - :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional - :return: an [N x C x ...] Tensor of outputs. 
- """ - if not self.shape_reported: - # print("The shape of UNet input is", x.size()) - self.shape_reported = True - - assert (y is not None) == ( - self.num_classes is not None or self.extra_film_condition_dim is not None - ), "must specify y if and only if the model is class-conditional or film embedding conditional" - hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) - emb = self.time_embed(t_emb) - - if self.num_classes is not None: - assert y.shape == (x.shape[0],) - emb = emb + self.label_emb(y) - - if self.use_extra_film_by_addition: - emb = emb + self.film_emb(y) - elif self.use_extra_film_by_concat: - emb = th.cat([emb, self.film_emb(y)], dim=-1) - - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb, context) - hs.append(h) - h = self.middle_block(h, emb, context) - for module in self.output_blocks: - h = th.cat([h, hs.pop()], dim=1) - h = module(h, emb, context) - h = h.type(x.dtype) - if self.predict_codebook_ids: - return self.id_predictor(h) - else: - return self.out(h) - - -class EncoderUNetModel(nn.Module): - """ - The half UNet model with attention and timestep embedding. - For usage, see UNet. - """ - - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - use_checkpoint=False, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - pool="adaptive", - *args, - **kwargs, - ): - super().__init__() - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - 
use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - self.pool = pool - if pool == "adaptive": - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - nn.AdaptiveAvgPool2d((1, 1)), - zero_module(conv_nd(dims, ch, out_channels, 1)), - nn.Flatten(), - ) - elif pool == "attention": - assert num_head_channels != -1 - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - AttentionPool2d( - (image_size // ds), ch, num_head_channels, out_channels - ), - ) - elif pool == "spatial": - self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), - nn.ReLU(), - nn.Linear(2048, self.out_channels), - ) - elif pool == "spatial_v2": - self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), - normalization(2048), - nn.SiLU(), - nn.Linear(2048, self.out_channels), - ) - else: - raise NotImplementedError(f"Unexpected {pool} pooling") - - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - - def forward(self, x, timesteps): - """ - Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :return: an [N x K] Tensor of outputs. - """ - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - - results = [] - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb) - if self.pool.startswith("spatial"): - results.append(h.type(x.dtype).mean(dim=(2, 3))) - h = self.middle_block(h, emb) - if self.pool.startswith("spatial"): - results.append(h.type(x.dtype).mean(dim=(2, 3))) - h = th.cat(results, axis=-1) - return self.out(h) - else: - h = h.type(x.dtype) - return self.out(h) diff --git a/audioldm/latent_diffusion/util.py b/audioldm/latent_diffusion/util.py deleted file mode 100644 index 8b289f6aa7f22a070870d8a706f944dc8547e936..0000000000000000000000000000000000000000 --- a/audioldm/latent_diffusion/util.py +++ /dev/null @@ -1,295 +0,0 @@ -# adopted from -# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py -# and -# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -# and -# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py -# -# thanks! 
- - -import os -import math -import torch -import torch.nn as nn -import numpy as np -from einops import repeat - -from audioldm.utils import instantiate_from_config - - -def make_beta_schedule( - schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 -): - if schedule == "linear": - betas = ( - torch.linspace( - linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 - ) - ** 2 - ) - - elif schedule == "cosine": - timesteps = ( - torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s - ) - alphas = timesteps / (1 + cosine_s) * np.pi / 2 - alphas = torch.cos(alphas).pow(2) - alphas = alphas / alphas[0] - betas = 1 - alphas[1:] / alphas[:-1] - betas = np.clip(betas, a_min=0, a_max=0.999) - - elif schedule == "sqrt_linear": - betas = torch.linspace( - linear_start, linear_end, n_timestep, dtype=torch.float64 - ) - elif schedule == "sqrt": - betas = ( - torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) - ** 0.5 - ) - else: - raise ValueError(f"schedule '{schedule}' unknown.") - return betas.numpy() - - -def make_ddim_timesteps( - ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True -): - if ddim_discr_method == "uniform": - c = num_ddpm_timesteps // num_ddim_timesteps - ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) - elif ddim_discr_method == "quad": - ddim_timesteps = ( - (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 - ).astype(int) - else: - raise NotImplementedError( - f'There is no ddim discretization method called "{ddim_discr_method}"' - ) - - # assert ddim_timesteps.shape[0] == num_ddim_timesteps - # add one to get the final alpha values right (the ones from first scale to data during sampling) - steps_out = ddim_timesteps + 1 - if verbose: - print(f"Selected timesteps for ddim sampler: {steps_out}") - return steps_out - - -def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): - # select alphas for computing the variance schedule - alphas = alphacums[ddim_timesteps] - alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) - - # according the the formula provided in https://arxiv.org/abs/2010.02502 - sigmas = eta * np.sqrt( - (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) - ) - if verbose: - print( - f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" - ) - print( - f"For the chosen value of eta, which is {eta}, " - f"this results in the following sigma_t schedule for ddim sampler {sigmas}" - ) - return sigmas, alphas, alphas_prev - - -def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, - which defines the cumulative product of (1-beta) over time from t = [0,1]. - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. - :param max_beta: the maximum beta to use; use values lower than 1 to - prevent singularities. 
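As a concrete check of the "uniform" discretization in `make_ddim_timesteps` above: thinning 1000 DDPM steps to 200 DDIM steps gives a stride of 5, and the `+1` shifts every index so the final alpha values are included. A tiny worked example (restating the function above, no new behaviour):

```python
# Uniform DDIM discretization: 1000 DDPM steps thinned to 200 DDIM steps.
import numpy as np

num_ddpm_timesteps, num_ddim_timesteps = 1000, 200
c = num_ddpm_timesteps // num_ddim_timesteps                          # stride of 5
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))    # 0, 5, ..., 995
steps_out = ddim_timesteps + 1                                        # 1, 6, ..., 996
assert len(steps_out) == 200 and steps_out[0] == 1 and steps_out[-1] == 996
```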
- """ - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas) - - -def extract_into_tensor(a, t, x_shape): - b, *_ = t.shape - out = a.gather(-1, t).contiguous() - return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous() - - -def checkpoint(func, inputs, params, flag): - """ - Evaluate a function without caching intermediate activations, allowing for - reduced memory at the expense of extra compute in the backward pass. - :param func: the function to evaluate. - :param inputs: the argument sequence to pass to `func`. - :param params: a sequence of parameters `func` depends on but does not - explicitly take as arguments. - :param flag: if False, disable gradient checkpointing. - """ - if flag: - args = tuple(inputs) + tuple(params) - return CheckpointFunction.apply(func, len(inputs), *args) - else: - return func(*inputs) - - -class CheckpointFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, run_function, length, *args): - ctx.run_function = run_function - ctx.input_tensors = list(args[:length]) - ctx.input_params = list(args[length:]) - - with torch.no_grad(): - output_tensors = ctx.run_function(*ctx.input_tensors) - return output_tensors - - @staticmethod - def backward(ctx, *output_grads): - ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - with torch.enable_grad(): - # Fixes a bug where the first op in run_function modifies the - # Tensor storage in place, which is not allowed for detach()'d - # Tensors. - shallow_copies = [x.view_as(x) for x in ctx.input_tensors] - output_tensors = ctx.run_function(*shallow_copies) - input_grads = torch.autograd.grad( - output_tensors, - ctx.input_tensors + ctx.input_params, - output_grads, - allow_unused=True, - ) - del ctx.input_tensors - del ctx.input_params - del output_tensors - return (None, None) + input_grads - - -def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): - """ - Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - if not repeat_only: - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32) - / half - ).to(device=timesteps.device) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat( - [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 - ) - else: - embedding = repeat(timesteps, "b -> b d", d=dim) - return embedding - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def scale_module(module, scale): - """ - Scale the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().mul_(scale) - return module - - -def mean_flat(tensor): - """ - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def normalization(channels): - """ - Make a standard normalization layer. - :param channels: number of input channels. - :return: an nn.Module for normalization. 
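The sinusoidal `timestep_embedding` above is the standard transformer-style positional encoding applied to diffusion timesteps. A self-contained usage sketch (restated so it runs without the package; an even `dim` is assumed, so the zero-padding branch is omitted):

```python
# Sinusoidal timestep embedding, restated standalone, plus a shape check.
import math
import torch


def sinusoidal_embedding(timesteps: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


emb = sinusoidal_embedding(torch.tensor([0, 10, 100, 999]), dim=128)
assert emb.shape == (4, 128)
```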
- """ - return GroupNorm32(32, channels) - - -# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. -class SiLU(nn.Module): - def forward(self, x): - return x * torch.sigmoid(x) - - -class GroupNorm32(nn.GroupNorm): - def forward(self, x): - return super().forward(x.float()).type(x.dtype) - - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def linear(*args, **kwargs): - """ - Create a linear module. - """ - return nn.Linear(*args, **kwargs) - - -def avg_pool_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D average pooling module. - """ - if dims == 1: - return nn.AvgPool1d(*args, **kwargs) - elif dims == 2: - return nn.AvgPool2d(*args, **kwargs) - elif dims == 3: - return nn.AvgPool3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -class HybridConditioner(nn.Module): - def __init__(self, c_concat_config, c_crossattn_config): - super().__init__() - self.concat_conditioner = instantiate_from_config(c_concat_config) - self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) - - def forward(self, c_concat, c_crossattn): - c_concat = self.concat_conditioner(c_concat) - c_crossattn = self.crossattn_conditioner(c_crossattn) - return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} - - -def noise_like(shape, device, repeat=False): - repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( - shape[0], *((1,) * (len(shape) - 1)) - ) - noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() diff --git a/audioldm/ldm.py b/audioldm/ldm.py deleted file mode 100644 index b0392e28404c315e5d8ca5ede571da386f5d4b42..0000000000000000000000000000000000000000 --- a/audioldm/ldm.py +++ /dev/null @@ -1,715 +0,0 @@ -import os - -import torch -import numpy as np -from tqdm import tqdm -from audioldm.utils import default, instantiate_from_config, save_wave -from audioldm.latent_diffusion.ddpm import DDPM -from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution -from audioldm.latent_diffusion.util import noise_like -from audioldm.latent_diffusion.ddim import DDIMSampler -import os - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - -class LatentDiffusion(DDPM): - """main class""" - - def __init__( - self, - device="cuda", - first_stage_config=None, - cond_stage_config=None, - num_timesteps_cond=None, - cond_stage_key="image", - cond_stage_trainable=False, - concat_mode=True, - cond_stage_forward=None, - conditioning_key=None, - scale_factor=1.0, - scale_by_std=False, - base_learning_rate=None, - *args, - **kwargs, - ): - self.device = device - self.learning_rate = base_learning_rate - self.num_timesteps_cond = default(num_timesteps_cond, 1) - self.scale_by_std = scale_by_std - assert self.num_timesteps_cond <= kwargs["timesteps"] - # for backwards compatibility after implementation of DiffusionWrapper - if conditioning_key is None: - conditioning_key = "concat" if concat_mode else "crossattn" - if cond_stage_config == "__is_unconditional__": - conditioning_key = None - ckpt_path = kwargs.pop("ckpt_path", None) - ignore_keys = kwargs.pop("ignore_keys", []) - 
super().__init__(conditioning_key=conditioning_key, *args, **kwargs) - self.concat_mode = concat_mode - self.cond_stage_trainable = cond_stage_trainable - self.cond_stage_key = cond_stage_key - self.cond_stage_key_orig = cond_stage_key - try: - self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 - except: - self.num_downs = 0 - if not scale_by_std: - self.scale_factor = scale_factor - else: - self.register_buffer("scale_factor", torch.tensor(scale_factor)) - self.instantiate_first_stage(first_stage_config) - self.instantiate_cond_stage(cond_stage_config) - self.cond_stage_forward = cond_stage_forward - self.clip_denoised = False - - def make_cond_schedule( - self, - ): - self.cond_ids = torch.full( - size=(self.num_timesteps,), - fill_value=self.num_timesteps - 1, - dtype=torch.long, - ) - ids = torch.round( - torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) - ).long() - self.cond_ids[: self.num_timesteps_cond] = ids - - def register_schedule( - self, - given_betas=None, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - ): - super().register_schedule( - given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s - ) - - self.shorten_cond_schedule = self.num_timesteps_cond > 1 - if self.shorten_cond_schedule: - self.make_cond_schedule() - - def instantiate_first_stage(self, config): - model = instantiate_from_config(config) - self.first_stage_model = model.eval() - self.first_stage_model.train = disabled_train - for param in self.first_stage_model.parameters(): - param.requires_grad = False - - def instantiate_cond_stage(self, config): - if not self.cond_stage_trainable: - if config == "__is_first_stage__": - print("Using first stage also as cond stage.") - self.cond_stage_model = self.first_stage_model - elif config == "__is_unconditional__": - print(f"Training {self.__class__.__name__} as an unconditional model.") - self.cond_stage_model = None - # self.be_unconditional = True - else: - model = instantiate_from_config(config) - self.cond_stage_model = model.eval() - self.cond_stage_model.train = disabled_train - for param in self.cond_stage_model.parameters(): - param.requires_grad = False - else: - assert config != "__is_first_stage__" - assert config != "__is_unconditional__" - model = instantiate_from_config(config) - self.cond_stage_model = model - self.cond_stage_model = self.cond_stage_model.to(self.device) - - def get_first_stage_encoding(self, encoder_posterior): - if isinstance(encoder_posterior, DiagonalGaussianDistribution): - z = encoder_posterior.sample() - elif isinstance(encoder_posterior, torch.Tensor): - z = encoder_posterior - else: - raise NotImplementedError( - f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" - ) - return self.scale_factor * z - - def get_learned_conditioning(self, c): - if self.cond_stage_forward is None: - if hasattr(self.cond_stage_model, "encode") and callable( - self.cond_stage_model.encode - ): - c = self.cond_stage_model.encode(c) - if isinstance(c, DiagonalGaussianDistribution): - c = c.mode() - else: - if len(c) == 1: - c = self.cond_stage_model([c[0], c[0]]) - c = c[0:1] - else: - c = self.cond_stage_model(c) - else: - assert hasattr(self.cond_stage_model, self.cond_stage_forward) - c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) - return c - - @torch.no_grad() - def get_input( - self, - batch, - k, - return_first_stage_encode=True, - return_first_stage_outputs=False, - force_c_encode=False, - 
cond_key=None, - return_original_cond=False, - bs=None, - ): - x = super().get_input(batch, k) - - if bs is not None: - x = x[:bs] - - x = x.to(self.device) - - if return_first_stage_encode: - encoder_posterior = self.encode_first_stage(x) - z = self.get_first_stage_encoding(encoder_posterior).detach() - else: - z = None - - if self.model.conditioning_key is not None: - if cond_key is None: - cond_key = self.cond_stage_key - if cond_key != self.first_stage_key: - if cond_key in ["caption", "coordinates_bbox"]: - xc = batch[cond_key] - elif cond_key == "class_label": - xc = batch - else: - # [bs, 1, 527] - xc = super().get_input(batch, cond_key) - if type(xc) == torch.Tensor: - xc = xc.to(self.device) - else: - xc = x - if not self.cond_stage_trainable or force_c_encode: - if isinstance(xc, dict) or isinstance(xc, list): - c = self.get_learned_conditioning(xc) - else: - c = self.get_learned_conditioning(xc.to(self.device)) - else: - c = xc - - if bs is not None: - c = c[:bs] - - else: - c = None - xc = None - if self.use_positional_encodings: - pos_x, pos_y = self.compute_latent_shifts(batch) - c = {"pos_x": pos_x, "pos_y": pos_y} - out = [z, c] - if return_first_stage_outputs: - xrec = self.decode_first_stage(z) - out.extend([x, xrec]) - if return_original_cond: - out.append(xc) - return out - - @torch.no_grad() - def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): - if predict_cids: - if z.dim() == 4: - z = torch.argmax(z.exp(), dim=1).long() - z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) - z = rearrange(z, "b h w c -> b c h w").contiguous() - - z = 1.0 / self.scale_factor * z - return self.first_stage_model.decode(z) - - def mel_spectrogram_to_waveform(self, mel): - # Mel: [bs, 1, t-steps, fbins] - if len(mel.size()) == 4: - mel = mel.squeeze(1) - mel = mel.permute(0, 2, 1) - waveform = self.first_stage_model.vocoder(mel) - waveform = waveform.cpu().detach().numpy() - return waveform - - @torch.no_grad() - def encode_first_stage(self, x): - return self.first_stage_model.encode(x) - - def apply_model(self, x_noisy, t, cond, return_ids=False): - - if isinstance(cond, dict): - # hybrid case, cond is exptected to be a dict - pass - else: - if not isinstance(cond, list): - cond = [cond] - if self.model.conditioning_key == "concat": - key = "c_concat" - elif self.model.conditioning_key == "crossattn": - key = "c_crossattn" - else: - key = "c_film" - - cond = {key: cond} - - x_recon = self.model(x_noisy, t, **cond) - - if isinstance(x_recon, tuple) and not return_ids: - return x_recon[0] - else: - return x_recon - - def p_mean_variance( - self, - x, - c, - t, - clip_denoised: bool, - return_codebook_ids=False, - quantize_denoised=False, - return_x0=False, - score_corrector=None, - corrector_kwargs=None, - ): - t_in = t - model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) - - if score_corrector is not None: - assert self.parameterization == "eps" - model_out = score_corrector.modify_score( - self, model_out, x, t, c, **corrector_kwargs - ) - - if return_codebook_ids: - model_out, logits = model_out - - if self.parameterization == "eps": - x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) - elif self.parameterization == "x0": - x_recon = model_out - else: - raise NotImplementedError() - - if clip_denoised: - x_recon.clamp_(-1.0, 1.0) - if quantize_denoised: - x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) - model_mean, posterior_variance, posterior_log_variance = self.q_posterior( - 
x_start=x_recon, x_t=x, t=t - ) - if return_codebook_ids: - return model_mean, posterior_variance, posterior_log_variance, logits - elif return_x0: - return model_mean, posterior_variance, posterior_log_variance, x_recon - else: - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample( - self, - x, - c, - t, - clip_denoised=False, - repeat_noise=False, - return_codebook_ids=False, - quantize_denoised=False, - return_x0=False, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - ): - b, *_, device = *x.shape, x.device - outputs = self.p_mean_variance( - x=x, - c=c, - t=t, - clip_denoised=clip_denoised, - return_codebook_ids=return_codebook_ids, - quantize_denoised=quantize_denoised, - return_x0=return_x0, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - ) - if return_codebook_ids: - raise DeprecationWarning("Support dropped.") - model_mean, _, model_log_variance, logits = outputs - elif return_x0: - model_mean, _, model_log_variance, x0 = outputs - else: - model_mean, _, model_log_variance = outputs - - noise = noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.0: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - # no noise when t == 0 - nonzero_mask = ( - (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous() - ) - - if return_codebook_ids: - return model_mean + nonzero_mask * ( - 0.5 * model_log_variance - ).exp() * noise, logits.argmax(dim=1) - if return_x0: - return ( - model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, - x0, - ) - else: - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise - - @torch.no_grad() - def progressive_denoising( - self, - cond, - shape, - verbose=True, - callback=None, - quantize_denoised=False, - img_callback=None, - mask=None, - x0=None, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, - batch_size=None, - x_T=None, - start_T=None, - log_every_t=None, - ): - if not log_every_t: - log_every_t = self.log_every_t - timesteps = self.num_timesteps - if batch_size is not None: - b = batch_size if batch_size is not None else shape[0] - shape = [batch_size] + list(shape) - else: - b = batch_size = shape[0] - if x_T is None: - img = torch.randn(shape, device=self.device) - else: - img = x_T - intermediates = [] - if cond is not None: - if isinstance(cond, dict): - cond = { - key: cond[key][:batch_size] - if not isinstance(cond[key], list) - else list(map(lambda x: x[:batch_size], cond[key])) - for key in cond - } - else: - cond = ( - [c[:batch_size] for c in cond] - if isinstance(cond, list) - else cond[:batch_size] - ) - - if start_T is not None: - timesteps = min(timesteps, start_T) - iterator = ( - tqdm( - reversed(range(0, timesteps)), - desc="Progressive Generation", - total=timesteps, - ) - if verbose - else reversed(range(0, timesteps)) - ) - if type(temperature) == float: - temperature = [temperature] * timesteps - - for i in iterator: - ts = torch.full((b,), i, device=self.device, dtype=torch.long) - if self.shorten_cond_schedule: - assert self.model.conditioning_key != "hybrid" - tc = self.cond_ids[ts].to(cond.device) - cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - - img, x0_partial = self.p_sample( - img, - cond, - ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, - return_x0=True, - temperature=temperature[i], - noise_dropout=noise_dropout, - 
score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - ) - if mask is not None: - assert x0 is not None - img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1.0 - mask) * img - - if i % log_every_t == 0 or i == timesteps - 1: - intermediates.append(x0_partial) - if callback: - callback(i) - if img_callback: - img_callback(img, i) - return img, intermediates - - @torch.no_grad() - def p_sample_loop( - self, - cond, - shape, - return_intermediates=False, - x_T=None, - verbose=True, - callback=None, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - start_T=None, - log_every_t=None, - ): - - if not log_every_t: - log_every_t = self.log_every_t - device = self.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - intermediates = [img] - if timesteps is None: - timesteps = self.num_timesteps - - if start_T is not None: - timesteps = min(timesteps, start_T) - iterator = ( - tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) - if verbose - else reversed(range(0, timesteps)) - ) - - if mask is not None: - assert x0 is not None - assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match - - for i in iterator: - ts = torch.full((b,), i, device=device, dtype=torch.long) - if self.shorten_cond_schedule: - assert self.model.conditioning_key != "hybrid" - tc = self.cond_ids[ts].to(cond.device) - cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - - img = self.p_sample( - img, - cond, - ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, - ) - if mask is not None: - img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1.0 - mask) * img - - if i % log_every_t == 0 or i == timesteps - 1: - intermediates.append(img) - if callback: - callback(i) - if img_callback: - img_callback(img, i) - - if return_intermediates: - return img, intermediates - return img - - @torch.no_grad() - def sample( - self, - cond, - batch_size=16, - return_intermediates=False, - x_T=None, - verbose=True, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - shape=None, - **kwargs, - ): - if shape is None: - shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size) - if cond is not None: - if isinstance(cond, dict): - cond = { - key: cond[key][:batch_size] - if not isinstance(cond[key], list) - else list(map(lambda x: x[:batch_size], cond[key])) - for key in cond - } - else: - cond = ( - [c[:batch_size] for c in cond] - if isinstance(cond, list) - else cond[:batch_size] - ) - return self.p_sample_loop( - cond, - shape, - return_intermediates=return_intermediates, - x_T=x_T, - verbose=verbose, - timesteps=timesteps, - quantize_denoised=quantize_denoised, - mask=mask, - x0=x0, - **kwargs, - ) - - @torch.no_grad() - def sample_log( - self, - cond, - batch_size, - ddim, - ddim_steps, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - use_plms=False, - mask=None, - **kwargs, - ): - - if mask is not None: - shape = (self.channels, mask.size()[-2], mask.size()[-1]) - else: - shape = (self.channels, self.latent_t_size, self.latent_f_size) - - intermediate = None - if ddim and not use_plms: - # print("Use ddim sampler") - - ddim_sampler = DDIMSampler(self) - samples, intermediates = ddim_sampler.sample( - ddim_steps, - batch_size, - shape, - cond, - verbose=False, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, 
- mask=mask, - **kwargs, - ) - - else: - # print("Use DDPM sampler") - samples, intermediates = self.sample( - cond=cond, - batch_size=batch_size, - return_intermediates=True, - unconditional_guidance_scale=unconditional_guidance_scale, - mask=mask, - unconditional_conditioning=unconditional_conditioning, - **kwargs, - ) - - return samples, intermediate - - - @torch.no_grad() - def generate_sample( - self, - batchs, - ddim_steps=200, - ddim_eta=1.0, - x_T=None, - n_candidate_gen_per_text=1, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - name="waveform", - use_plms=False, - save=False, - **kwargs, - ): - # Generate n_candidate_gen_per_text times and select the best - # Batch: audio, text, fnames - assert x_T is None - try: - batchs = iter(batchs) - except TypeError: - raise ValueError("The first input argument should be an iterable object") - - if use_plms: - assert ddim_steps is not None - use_ddim = ddim_steps is not None - # waveform_save_path = os.path.join(self.get_log_dir(), name) - # os.makedirs(waveform_save_path, exist_ok=True) - # print("Waveform save path: ", waveform_save_path) - - with self.ema_scope("Generate"): - for batch in batchs: - z, c = self.get_input( - batch, - self.first_stage_key, - return_first_stage_outputs=False, - force_c_encode=True, - return_original_cond=False, - bs=None, - ) - text = super().get_input(batch, "text") - - # Generate multiple samples - batch_size = z.shape[0] * n_candidate_gen_per_text - c = torch.cat([c] * n_candidate_gen_per_text, dim=0) - text = text * n_candidate_gen_per_text - - if unconditional_guidance_scale != 1.0: - unconditional_conditioning = ( - self.cond_stage_model.get_unconditional_condition(batch_size) - ) - - samples, _ = self.sample_log( - cond=c, - batch_size=batch_size, - x_T=x_T, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - use_plms=use_plms, - ) - - mel = self.decode_first_stage(samples) - - waveform = self.mel_spectrogram_to_waveform(mel) - - if(waveform.shape[0] > 1): - similarity = self.cond_stage_model.cos_similarity( - torch.FloatTensor(waveform).squeeze(1), text - ) - - best_index = [] - for i in range(z.shape[0]): - candidates = similarity[i :: z.shape[0]] - max_index = torch.argmax(candidates).item() - best_index.append(i + max_index * z.shape[0]) - - waveform = waveform[best_index] - # print("Similarity between generated audio and text", similarity) - # print("Choose the following indexes:", best_index) - - return waveform diff --git a/audioldm/pipeline.py b/audioldm/pipeline.py deleted file mode 100644 index a45db86865400e28b006dd3eebd873126a856fa0..0000000000000000000000000000000000000000 --- a/audioldm/pipeline.py +++ /dev/null @@ -1,92 +0,0 @@ - - -import os - -import argparse -import yaml -import torch - -from audioldm import LatentDiffusion, seed_everything -from audioldm.utils import default_audioldm_config - - -import time - -def make_batch_for_text_to_audio(text, batchsize=1): - text = [text] * batchsize - if batchsize < 1: - print("Warning: Batchsize must be at least 1. 
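The tail of `generate_sample` above (end of ldm.py) implements best-of-n selection: the batch is repeated `n_candidate_gen_per_text` times, every candidate waveform is scored against its prompt with the CLAP-based `cos_similarity`, and the strided indexing keeps the highest-scoring candidate per prompt. A standalone sketch of just that indexing (hypothetical helper, illustrative scores):

```python
# Candidate selection only. `similarity` is assumed to be a flat
# [batch_size * n_candidates] tensor laid out as n_candidates repetitions of the
# prompt batch, matching torch.cat([c] * n_candidate_gen_per_text, dim=0) above.
import torch


def pick_best_per_prompt(similarity: torch.Tensor, batch_size: int) -> list:
    best_index = []
    for i in range(batch_size):
        candidates = similarity[i::batch_size]                    # scores for prompt i
        best_index.append(i + torch.argmax(candidates).item() * batch_size)
    return best_index


# similarity = torch.tensor([0.1, 0.2, 0.9, 0.3, 0.4, 0.05])      # 2 prompts x 3 candidates
# pick_best_per_prompt(similarity, batch_size=2) == [2, 3]
```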
Batchsize is set to .") - fbank = torch.zeros((batchsize, 1024, 64)) # Not used, here to keep the code format - stft = torch.zeros((batchsize, 1024, 512)) # Not used - waveform = torch.zeros((batchsize, 160000)) # Not used - fname = [""] * batchsize # Not used - batch = ( - fbank, - stft, - None, - fname, - waveform, - text, - ) - return batch - - - -def build_model( - ckpt_path=None, - config=None, - model_name="audioldm-s-full" -): - print("Load AudioLDM: %s" % model_name) - - resume_from_checkpoint = "ckpt/%s.ckpt" % model_name - - # if(ckpt_path is None): - # ckpt_path = get_metadata()[model_name]["path"] - - # if(not os.path.exists(ckpt_path)): - # download_checkpoint(model_name) - - if(torch.cuda.is_available()): - device = torch.device("cuda:0") - else: - device = torch.device("cpu") - - if(config is not None): - assert type(config) is str - config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) - else: - config = default_audioldm_config(model_name) - - # Use text as condition instead of using waveform during training - config["model"]["params"]["device"] = device - config["model"]["params"]["cond_stage_key"] = "text" - - # No normalization here - latent_diffusion = LatentDiffusion(**config["model"]["params"]) - - checkpoint = torch.load(resume_from_checkpoint, map_location=device) - latent_diffusion.load_state_dict(checkpoint["state_dict"]) - - latent_diffusion.eval() - latent_diffusion = latent_diffusion.to(device) - - latent_diffusion.cond_stage_model.embed_mode = "text" - return latent_diffusion - -def duration_to_latent_t_size(duration): - return int(duration * 25.6) - -def text_to_audio(latent_diffusion, text, seed=42, duration=10, batchsize=1, guidance_scale=2.5, n_candidate_gen_per_text=3, config=None): - seed_everything(int(seed)) - batch = make_batch_for_text_to_audio(text, batchsize=batchsize) - - latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) - with torch.no_grad(): - waveform = latent_diffusion.generate_sample( - [batch], - unconditional_guidance_scale=guidance_scale, - n_candidate_gen_per_text=n_candidate_gen_per_text, - duration=duration - ) - return waveform diff --git a/audioldm/utils.py b/audioldm/utils.py deleted file mode 100644 index 0468a2973b705af739483c196a91185500d6a8da..0000000000000000000000000000000000000000 --- a/audioldm/utils.py +++ /dev/null @@ -1,174 +0,0 @@ -import importlib - -from inspect import isfunction - -import os -import soundfile as sf - -def seed_everything(seed): - import random, os - import numpy as np - import torch - - random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = True - -def save_wave(waveform, savepath, name="outwav"): - if type(name) is not list: - name = [name] * waveform.shape[0] - - for i in range(waveform.shape[0]): - path = os.path.join( - savepath, - "%s_%s.wav" - % ( - os.path.basename(name[i]) - if (not ".wav" in name[i]) - else os.path.basename(name[i]).split(".")[0], - i, - ), - ) - sf.write(path, waveform[i, 0], samplerate=16000) - -def exists(x): - return x is not None - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def count_params(model, verbose=False): - total_params = sum(p.numel() for p in model.parameters()) - if verbose: - print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") - return total_params - - -def get_obj_from_str(string, 
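For reference, the high-level API removed by this diff was driven through `build_model` and `text_to_audio` above. A usage sketch based only on those signatures (it assumes the original `audioldm` package is importable and that the checkpoint exists at `ckpt/audioldm-s-full.ckpt`, which is where `build_model` looks for it):

```python
# Usage sketch of the removed pipeline entry points.
import soundfile as sf
from audioldm.pipeline import build_model, text_to_audio

latent_diffusion = build_model(model_name="audioldm-s-full")
waveform = text_to_audio(
    latent_diffusion,
    "a dog barking in the distance",
    seed=42,
    duration=10,               # duration_to_latent_t_size(10) -> latent_t_size = 256
    guidance_scale=2.5,
    n_candidate_gen_per_text=3,
)
# waveform is a [batch, 1, samples] array at 16 kHz, mirroring save_wave above
sf.write("output.wav", waveform[0, 0], samplerate=16000)
```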
reload=False): - module, cls = string.rsplit(".", 1) - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) - - -def instantiate_from_config(config): - if not "target" in config: - if config == "__is_first_stage__": - return None - elif config == "__is_unconditional__": - return None - raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict())) - -def default_audioldm_config(model_name="audioldm-s-full"): - basic_config = { - "wave_file_save_path": "./output", - "id": { - "version": "v1", - "name": "default", - "root": "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/AudioLDM-python/config/default/latent_diffusion.yaml", - }, - "preprocessing": { - "audio": {"sampling_rate": 16000, "max_wav_value": 32768}, - "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024}, - "mel": { - "n_mel_channels": 64, - "mel_fmin": 0, - "mel_fmax": 8000, - "freqm": 0, - "timem": 0, - "blur": False, - "mean": -4.63, - "std": 2.74, - "target_length": 1024, - }, - }, - "model": { - "device": "cuda", - "target": "audioldm.pipline.LatentDiffusion", - "params": { - "base_learning_rate": 5e-06, - "linear_start": 0.0015, - "linear_end": 0.0195, - "num_timesteps_cond": 1, - "log_every_t": 200, - "timesteps": 1000, - "first_stage_key": "fbank", - "cond_stage_key": "waveform", - "latent_t_size": 256, - "latent_f_size": 16, - "channels": 8, - "cond_stage_trainable": True, - "conditioning_key": "film", - "monitor": "val/loss_simple_ema", - "scale_by_std": True, - "unet_config": { - "target": "audioldm.latent_diffusion.openaimodel.UNetModel", - "params": { - "image_size": 64, - "extra_film_condition_dim": 512, - "extra_film_use_concat": True, - "in_channels": 8, - "out_channels": 8, - "model_channels": 128, - "attention_resolutions": [8, 4, 2], - "num_res_blocks": 2, - "channel_mult": [1, 2, 3, 5], - "num_head_channels": 32, - "use_spatial_transformer": True, - }, - }, - "first_stage_config": { - "base_learning_rate": 4.5e-05, - "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL", - "params": { - "monitor": "val/rec_loss", - "image_key": "fbank", - "subband": 1, - "embed_dim": 8, - "time_shuffle": 1, - "ddconfig": { - "double_z": True, - "z_channels": 8, - "resolution": 256, - "downsample_time": False, - "in_channels": 1, - "out_ch": 1, - "ch": 128, - "ch_mult": [1, 2, 4], - "num_res_blocks": 2, - "attn_resolutions": [], - "dropout": 0.0, - }, - }, - }, - "cond_stage_config": { - "target": "audioldm.clap.encoders.CLAPAudioEmbeddingClassifierFreev2", - "params": { - "key": "waveform", - "sampling_rate": 16000, - "embed_mode": "audio", - "unconditional_prob": 0.1, - }, - }, - }, - }, - } - - if("-l-" in model_name): - basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 256 - basic_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = 64 - elif("-m-" in model_name): - basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 192 - basic_config["model"]["params"]["cond_stage_config"]["params"]["amodel"] = "HTSAT-base" # This model use a larger HTAST - - return basic_config \ No newline at end of file diff --git a/audioldm/variational_autoencoder/__init__.py b/audioldm/variational_autoencoder/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git 
a/audioldm/variational_autoencoder/autoencoder.py b/audioldm/variational_autoencoder/autoencoder.py deleted file mode 100644 index cfbdceba0e1171b052f797885530bacd0f3c73d5..0000000000000000000000000000000000000000 --- a/audioldm/variational_autoencoder/autoencoder.py +++ /dev/null @@ -1,102 +0,0 @@ -import torch -from audioldm.latent_diffusion.ema import * -from audioldm.variational_autoencoder.modules import Encoder, Decoder -from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution - -from audioldm.hifigan.utilities import get_vocoder, vocoder_infer - -class AutoencoderKL(nn.Module): - def __init__( - self, - ddconfig=None, - lossconfig=None, - image_key="fbank", - embed_dim=None, - time_shuffle=1, - subband=1, - ckpt_path=None, - reload_from_ckpt=None, - ignore_keys=[], - colorize_nlabels=None, - monitor=None, - base_learning_rate=1e-5, - ): - super().__init__() - - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - - self.subband = int(subband) - - if self.subband > 1: - print("Use subband decomposition %s" % self.subband) - - self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - - self.vocoder = get_vocoder(None, "cpu") - self.embed_dim = embed_dim - - if monitor is not None: - self.monitor = monitor - - self.time_shuffle = time_shuffle - self.reload_from_ckpt = reload_from_ckpt - self.reloaded = False - self.mean, self.std = None, None - - def encode(self, x): - # x = self.time_shuffle_operation(x) - x = self.freq_split_subband(x) - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior - - def decode(self, z): - z = self.post_quant_conv(z) - dec = self.decoder(z) - dec = self.freq_merge_subband(dec) - return dec - - def decode_to_waveform(self, dec): - dec = dec.squeeze(1).permute(0, 2, 1) - wav_reconstruction = vocoder_infer(dec, self.vocoder) - return wav_reconstruction - - def forward(self, input, sample_posterior=True): - posterior = self.encode(input) - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() - - if self.flag_first_run: - print("Latent size: ", z.size()) - self.flag_first_run = False - - dec = self.decode(z) - - return dec, posterior - - def freq_split_subband(self, fbank): - if self.subband == 1 or self.image_key != "stft": - return fbank - - bs, ch, tstep, fbins = fbank.size() - - assert fbank.size(-1) % self.subband == 0 - assert ch == 1 - - return ( - fbank.squeeze(1) - .reshape(bs, tstep, self.subband, fbins // self.subband) - .permute(0, 2, 1, 3) - ) - - def freq_merge_subband(self, subband_fbank): - if self.subband == 1 or self.image_key != "stft": - return subband_fbank - assert subband_fbank.size(1) == self.subband # Channel dimension - bs, sub_ch, tstep, fbins = subband_fbank.size() - return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1) diff --git a/audioldm/variational_autoencoder/distributions.py b/audioldm/variational_autoencoder/distributions.py deleted file mode 100644 index 58eb535e7769f402169ddff77ee45c96ba3650d9..0000000000000000000000000000000000000000 --- a/audioldm/variational_autoencoder/distributions.py +++ /dev/null @@ -1,102 +0,0 @@ -import torch -import numpy as np - - -class AbstractDistribution: - def sample(self): - raise NotImplementedError() - - def mode(self): - raise NotImplementedError() - - -class DiracDistribution(AbstractDistribution): - def __init__(self, 
value): - self.value = value - - def sample(self): - return self.value - - def mode(self): - return self.value - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to( - device=self.parameters.device - ) - - def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to( - device=self.parameters.device - ) - return x - - def kl(self, other=None): - if self.deterministic: - return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.mean( - torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, - dim=[1, 2, 3], - ) - else: - return 0.5 * torch.mean( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) - - def nll(self, sample, dims=[1, 2, 3]): - if self.deterministic: - return torch.Tensor([0.0]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims, - ) - - def mode(self): - return self.mean - - -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 - Compute the KL divergence between two gaussians. - Shapes are automatically broadcasted, so batches can be compared to - scalars, among other use cases. - """ - tensor = None - for obj in (mean1, logvar1, mean2, logvar2): - if isinstance(obj, torch.Tensor): - tensor = obj - break - assert tensor is not None, "at least one argument must be a Tensor" - - # Force variances to be Tensors. Broadcasting helps convert scalars to - # Tensors, but it does not work for torch.exp(). - logvar1, logvar2 = [ - x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) - for x in (logvar1, logvar2) - ] - - return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + torch.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) - ) diff --git a/audioldm/variational_autoencoder/modules.py b/audioldm/variational_autoencoder/modules.py deleted file mode 100644 index 6b2c3dca2d168fb5fbaff5acc4b5a06280a496a7..0000000000000000000000000000000000000000 --- a/audioldm/variational_autoencoder/modules.py +++ /dev/null @@ -1,1064 +0,0 @@ -# pytorch_diffusion + derived encoder decoder -import math -import torch -import torch.nn as nn -import numpy as np -from einops import rearrange - -from audioldm.utils import instantiate_from_config -from audioldm.latent_diffusion.attention import LinearAttention - -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: - From Fairseq. - Build sinusoidal embeddings. - This matches the implementation in tensor2tensor, but differs slightly - from the description in Section 3.5 of "Attention Is All You Need". 
- """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - -def nonlinearity(x): - # swish - return x * torch.sigmoid(x) - - -def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm( - num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True - ) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class UpsampleTimeStride4(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=5, stride=1, padding=2 - ) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # Do time downsampling here - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=2, padding=0 - ) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class DownsampleTimeStride4(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # Do time downsampling here - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1 - ) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2)) - return x - - -class ResnetBlock(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout, - temb_channels=512, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d( - in_channels, 
out_channels, kernel_size=3, stride=1, padding=1 - ) - else: - self.nin_shortcut = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h - - -class LinAttnBlock(LinearAttention): - """to match AttnBlock usage""" - - def __init__(self, in_channels): - super().__init__(dim=in_channels, heads=1, dim_head=in_channels) - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h * w).contiguous() - q = q.permute(0, 2, 1).contiguous() # b,hw,c - k = k.reshape(b, c, h * w).contiguous() # b,c,hw - w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h * w).contiguous() - w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm( - v, w_ - ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b, c, h, w).contiguous() - - h_ = self.proj_out(h_) - - return x + h_ - - -def make_attn(in_channels, attn_type="vanilla"): - assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" - # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") - if attn_type == "vanilla": - return AttnBlock(in_channels) - elif attn_type == "none": - return nn.Identity(in_channels) - else: - return LinAttnBlock(in_channels) - - -class Model(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - use_timestep=True, - use_linear_attn=False, - attn_type="vanilla", - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = self.ch * 4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - self.use_timestep = use_timestep - if self.use_timestep: - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList( - [ - torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ] - ) - - # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() 
- attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - skip_in = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - if i_block == self.num_res_blocks: - skip_in = ch * in_ch_mult[i_level] - block.append( - ResnetBlock( - in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x, t=None, context=None): - # assert x.shape[2] == x.shape[3] == self.resolution - if context is not None: - # assume aligned context, cat along channel axis - x = torch.cat((x, context), dim=1) - if self.use_timestep: - # timestep embedding - assert t is not None - temb = get_timestep_embedding(t, self.ch) - temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) - temb = self.temb.dense[1](temb) - else: - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb - ) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - def get_last_layer(self): - return self.conv_out.weight - - -class Encoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - 
in_channels, - resolution, - z_channels, - double_z=True, - use_linear_attn=False, - attn_type="vanilla", - downsample_time_stride4_levels=[], - **ignore_kwargs, - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.downsample_time_stride4_levels = downsample_time_stride4_levels - - if len(self.downsample_time_stride4_levels) > 0: - assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( - "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" - % str(self.num_resolutions) - ) - - # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - if i_level in self.downsample_time_stride4_levels: - down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv) - else: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, - 2 * z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1, - ) - - def forward(self, x): - # timestep embedding - temb = None - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - tanh_out=False, - use_linear_attn=False, - downsample_time_stride4_levels=[], - attn_type="vanilla", - **ignorekwargs, - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - 
self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - self.tanh_out = tanh_out - self.downsample_time_stride4_levels = downsample_time_stride4_levels - - if len(self.downsample_time_stride4_levels) > 0: - assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( - "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" - % str(self.num_resolutions) - ) - - # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - # print("Working with z of shape {} = {} dimensions.".format( - # self.z_shape, np.prod(self.z_shape))) - - # z to block_in - self.conv_in = torch.nn.Conv2d( - z_channels, block_in, kernel_size=3, stride=1, padding=1 - ) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - if i_level - 1 in self.downsample_time_stride4_levels: - up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv) - else: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, z): - # assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - if self.tanh_out: - h = torch.tanh(h) - return h - - -class SimpleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, *args, **kwargs): - super().__init__() - self.model = nn.ModuleList( - [ - nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock( - in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, - dropout=0.0, - ), - ResnetBlock( - in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, - dropout=0.0, - ), - ResnetBlock( - in_channels=4 * in_channels, - out_channels=2 * in_channels, - 
temb_channels=0, - dropout=0.0, - ), - nn.Conv2d(2 * in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True), - ] - ) - # end - self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - for i, layer in enumerate(self.model): - if i in [1, 2, 3]: - x = layer(x, None) - else: - x = layer(x) - - h = self.norm_out(x) - h = nonlinearity(h) - x = self.conv_out(h) - return x - - -class UpsampleDecoder(nn.Module): - def __init__( - self, - in_channels, - out_channels, - ch, - num_res_blocks, - resolution, - ch_mult=(2, 2), - dropout=0.0, - ): - super().__init__() - # upsampling - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - block_in = in_channels - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.res_blocks = nn.ModuleList() - self.upsample_blocks = nn.ModuleList() - for i_level in range(self.num_resolutions): - res_block = [] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - res_block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - self.res_blocks.append(nn.ModuleList(res_block)) - if i_level != self.num_resolutions - 1: - self.upsample_blocks.append(Upsample(block_in, True)) - curr_res = curr_res * 2 - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - # upsampling - h = x - for k, i_level in enumerate(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.res_blocks[i_level][i_block](h, None) - if i_level != self.num_resolutions - 1: - h = self.upsample_blocks[k](h) - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class LatentRescaler(nn.Module): - def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): - super().__init__() - # residual block, interpolate, residual block - self.factor = factor - self.conv_in = nn.Conv2d( - in_channels, mid_channels, kernel_size=3, stride=1, padding=1 - ) - self.res_block1 = nn.ModuleList( - [ - ResnetBlock( - in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0, - ) - for _ in range(depth) - ] - ) - self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList( - [ - ResnetBlock( - in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0, - ) - for _ in range(depth) - ] - ) - - self.conv_out = nn.Conv2d( - mid_channels, - out_channels, - kernel_size=1, - ) - - def forward(self, x): - x = self.conv_in(x) - for block in self.res_block1: - x = block(x, None) - x = torch.nn.functional.interpolate( - x, - size=( - int(round(x.shape[2] * self.factor)), - int(round(x.shape[3] * self.factor)), - ), - ) - x = self.attn(x).contiguous() - for block in self.res_block2: - x = block(x, None) - x = self.conv_out(x) - return x - - -class MergedRescaleEncoder(nn.Module): - def __init__( - self, - in_channels, - ch, - resolution, - out_ch, - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - ch_mult=(1, 2, 4, 8), - rescale_factor=1.0, - rescale_module_depth=1, - ): - super().__init__() - intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder( - in_channels=in_channels, - num_res_blocks=num_res_blocks, - ch=ch, - ch_mult=ch_mult, - 
z_channels=intermediate_chn, - double_z=False, - resolution=resolution, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - out_ch=None, - ) - self.rescaler = LatentRescaler( - factor=rescale_factor, - in_channels=intermediate_chn, - mid_channels=intermediate_chn, - out_channels=out_ch, - depth=rescale_module_depth, - ) - - def forward(self, x): - x = self.encoder(x) - x = self.rescaler(x) - return x - - -class MergedRescaleDecoder(nn.Module): - def __init__( - self, - z_channels, - out_ch, - resolution, - num_res_blocks, - attn_resolutions, - ch, - ch_mult=(1, 2, 4, 8), - dropout=0.0, - resamp_with_conv=True, - rescale_factor=1.0, - rescale_module_depth=1, - ): - super().__init__() - tmp_chn = z_channels * ch_mult[-1] - self.decoder = Decoder( - out_ch=out_ch, - z_channels=tmp_chn, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - in_channels=None, - num_res_blocks=num_res_blocks, - ch_mult=ch_mult, - resolution=resolution, - ch=ch, - ) - self.rescaler = LatentRescaler( - factor=rescale_factor, - in_channels=z_channels, - mid_channels=tmp_chn, - out_channels=tmp_chn, - depth=rescale_module_depth, - ) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Upsampler(nn.Module): - def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): - super().__init__() - assert out_size >= in_size - num_blocks = int(np.log2(out_size // in_size)) + 1 - factor_up = 1.0 + (out_size % in_size) - print( - f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" - ) - self.rescaler = LatentRescaler( - factor=factor_up, - in_channels=in_channels, - mid_channels=2 * in_channels, - out_channels=in_channels, - ) - self.decoder = Decoder( - out_ch=out_channels, - resolution=out_size, - z_channels=in_channels, - num_res_blocks=2, - attn_resolutions=[], - in_channels=None, - ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)], - ) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Resize(nn.Module): - def __init__(self, in_channels=None, learned=False, mode="bilinear"): - super().__init__() - self.with_conv = learned - self.mode = mode - if self.with_conv: - print( - f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" - ) - raise NotImplementedError() - assert in_channels is not None - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=4, stride=2, padding=1 - ) - - def forward(self, x, scale_factor=1.0): - if scale_factor == 1.0: - return x - else: - x = torch.nn.functional.interpolate( - x, mode=self.mode, align_corners=False, scale_factor=scale_factor - ) - return x - - -class FirstStagePostProcessor(nn.Module): - def __init__( - self, - ch_mult: list, - in_channels, - pretrained_model: nn.Module = None, - reshape=False, - n_channels=None, - dropout=0.0, - pretrained_config=None, - ): - super().__init__() - if pretrained_config is None: - assert ( - pretrained_model is not None - ), 'Either "pretrained_model" or "pretrained_config" must not be None' - self.pretrained_model = pretrained_model - else: - assert ( - pretrained_config is not None - ), 'Either "pretrained_model" or "pretrained_config" must not be None' - self.instantiate_pretrained(pretrained_config) - - self.do_reshape = reshape - - if n_channels is None: - n_channels = 
self.pretrained_model.encoder.ch - - self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) - self.proj = nn.Conv2d( - in_channels, n_channels, kernel_size=3, stride=1, padding=1 - ) - - blocks = [] - downs = [] - ch_in = n_channels - for m in ch_mult: - blocks.append( - ResnetBlock( - in_channels=ch_in, out_channels=m * n_channels, dropout=dropout - ) - ) - ch_in = m * n_channels - downs.append(Downsample(ch_in, with_conv=False)) - - self.model = nn.ModuleList(blocks) - self.downsampler = nn.ModuleList(downs) - - def instantiate_pretrained(self, config): - model = instantiate_from_config(config) - self.pretrained_model = model.eval() - # self.pretrained_model.train = False - for param in self.pretrained_model.parameters(): - param.requires_grad = False - - @torch.no_grad() - def encode_with_pretrained(self, x): - c = self.pretrained_model.encode(x) - if isinstance(c, DiagonalGaussianDistribution): - c = c.mode() - return c - - def forward(self, x): - z_fs = self.encode_with_pretrained(x) - z = self.proj_norm(z_fs) - z = self.proj(z) - z = nonlinearity(z) - - for submodel, downmodel in zip(self.model, self.downsampler): - z = submodel(z, temb=None) - z = downmodel(z) - - if self.do_reshape: - z = rearrange(z, "b c h w -> b (h w) c") - return z diff --git a/ckpt/audioldm-m-full.ckpt b/ckpt/audioldm-m-full.ckpt deleted file mode 100644 index a9cd836fd57622ef7d8779f3e1fcd7d0b0bd0a29..0000000000000000000000000000000000000000 --- a/ckpt/audioldm-m-full.ckpt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:936914a388905e1fc179c148a41a2b1552dba322ce474160b1cfa0f01ac26f8f -size 4571683377 diff --git a/ckpt/audioldm-m-text-ft.ckpt b/ckpt/audioldm-m-text-ft.ckpt deleted file mode 100644 index f97680915ef7ca756042d736938975dc23e69e39..0000000000000000000000000000000000000000 --- a/ckpt/audioldm-m-text-ft.ckpt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d77d5a61785af82012edb8a72158d52592ac7c76d7f6ed51a048ec2dec8d5eca -size 4571676474 diff --git a/ckpt/audioldm-s-text-ft.ckpt b/ckpt/audioldm-s-text-ft.ckpt deleted file mode 100644 index 171cbed5e1a076e809b0a525e8c88651164e46c8..0000000000000000000000000000000000000000 --- a/ckpt/audioldm-s-text-ft.ckpt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:62075af09973ab50f158b213acc60347a330737c1d827c5dead61af60bfea706 -size 2558980807 diff --git a/ckpt/ldm_trimmed.ckpt b/ckpt/ldm_trimmed.ckpt deleted file mode 100644 index 5925f8aca1f0bf7be45276a956f056ac6c2aa84e..0000000000000000000000000000000000000000 --- a/ckpt/ldm_trimmed.ckpt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8f8d1923410622be823279b61967d27a2df3fd03ddd764afb298e7c20ef8877d -size 2558947469 diff --git a/requirements.txt b/requirements.txt index 93174b5dd3fad9851e3963c0aca028c811675d50..2637b318ea649c1c45532155c716e2449264d81e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,17 +1,4 @@ git+https://github.com/huggingface/diffusers.git +git+https://github.com/huggingface/transformers.git --extra-index-url https://download.pytorch.org/whl/cu113 -torch -scipy -torchaudio>=0.13.0 -torchvision>=0.14.0 -tqdm -pyyaml -einops -numpy<=1.23.5 -soundfile -librosa -pandas -# transformers -torchlibrosa -transformers -ftfy \ No newline at end of file +torch >= 2.0 \ No newline at end of file
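
The entire in-repo VAE stack deleted above (`AutoencoderKL`, `DiagonalGaussianDistribution`, and the `Encoder`/`Decoder` modules), along with the multi-gigabyte LFS checkpoints under `ckpt/`, is superseded by equivalent components that `diffusers` downloads from the Hugging Face Hub at runtime, which is why the trimmed `requirements.txt` only keeps `diffusers`, `transformers`, and `torch >= 2.0`. The sketch below is an illustration only, not part of the Space's code: it assumes the `cvssp/audioldm-m-full` Hub repository (the diffusers-format counterpart of the removed `ckpt/audioldm-m-full.ckpt`) exposes its first-stage model under the usual `vae` subfolder, and it borrows the tensor shapes from the deleted default config (64 mel bins, 8 latent channels, 4x spatial downsampling).

```python
# Sketch only: the deleted encode -> sample posterior -> decode round-trip,
# expressed with the hub-hosted diffusers VAE instead of the removed in-repo code.
# Repo id, subfolder name, and shapes are assumptions, not taken from the Space itself.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("cvssp/audioldm-m-full", subfolder="vae")
vae.eval()

mel = torch.randn(1, 1, 1024, 64)            # (batch, 1, time frames, mel bins)

with torch.no_grad():
    posterior = vae.encode(mel).latent_dist   # a diagonal Gaussian posterior, as before
    z = posterior.sample()                    # or posterior.mode() for the deterministic latent
    recon = vae.decode(z).sample              # back to a mel spectrogram; vocoding is a separate step

print(z.shape, recon.shape)                   # expected roughly (1, 8, 256, 16) and (1, 1, 1024, 64)
```

None of this needs to be done by hand for the end-to-end text-to-audio path; it is included only to show which of the removed pieces the hub-hosted weights now stand in for.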