diff --git a/README.md b/README.md
index 49b3f9d9b337749338e72cf5d6782b9562943a8b..a267d537ae6a23859d5640388bd6ccbf04a480f3 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,7 @@ sdk_version: 3.27.0
app_file: app.py
pinned: false
license: bigscience-openrail-m
+duplicated_from: haoheliu/audioldm-text-to-audio-generation
---
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
index 629eab165af167483b3def8286581e7270b8e01c..113d2d66f0ec1a6f0cdd393ea6763f7f71a483e7 100644
--- a/app.py
+++ b/app.py
@@ -1,197 +1,139 @@
import gradio as gr
-import numpy as np
-from audioldm import text_to_audio, build_model
+import torch
+from diffusers import AudioLDMPipeline
from share_btn import community_icon_html, loading_icon_html, share_js
-model_id="haoheliu/AudioLDM-S-Full"
+from transformers import AutoProcessor, ClapModel
-audioldm = None
-current_model_name = None
-# def predict(input, history=[]):
-# # tokenize the new input sentence
-# new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
+# make Space compatible with CPU duplicates
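+# use fp16 on GPU for speed and lower memory; fall back to fp32 on CPU, where half precision is poorly supported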
+if torch.cuda.is_available():
+ device = "cuda"
+ torch_dtype = torch.float16
+else:
+ device = "cpu"
+ torch_dtype = torch.float32
-# # append the new user input tokens to the chat history
-# bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
+# load the diffusers pipeline
+repo_id = "cvssp/audioldm-m-full"
+pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device)
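+# torch.compile (PyTorch 2.0+) optimizes the UNet's forward pass; the first generation incurs a one-off compilation warm-up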
+pipe.unet = torch.compile(pipe.unet)
-# # generate a response
-# history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
+# CLAP model (only required for automatic scoring)
+clap_model = ClapModel.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full").to(device)
+processor = AutoProcessor.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full")
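+# when several candidates are generated, CLAP scores each waveform against the prompt so the best match is returned (see score_waveforms)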
-# # convert the tokens to text, and then split the responses into lines
-# response = tokenizer.decode(history[0]).split("<|endoftext|>")
-# response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list
-# return response, history
-
-def text2audio(text, duration, guidance_scale, random_seed, n_candidates, model_name="audioldm-m-text-ft"):
- global audioldm, current_model_name
-
- if audioldm is None or model_name != current_model_name:
- audioldm=build_model(model_name=model_name)
- current_model_name = model_name
-
- # print(text, length, guidance_scale)
- waveform = text_to_audio(
- latent_diffusion=audioldm,
- text=text,
- seed=random_seed,
- duration=duration,
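+# a single generator is shared across requests and re-seeded with the user-provided seed in text2audio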
+generator = torch.Generator(device)
+
+
+def text2audio(text, negative_prompt, duration, guidance_scale, random_seed, n_candidates):
+ if text is None:
+ raise gr.Error("Please provide a text input.")
+
+ waveforms = pipe(
+ text,
+ audio_length_in_s=duration,
guidance_scale=guidance_scale,
- n_candidate_gen_per_text=int(n_candidates),
- ) # [bs, 1, samples]
- waveform = [
- gr.make_waveform((16000, wave[0]), bg_image="bg.png") for wave in waveform
- ]
- # waveform = [(16000, np.random.randn(16000)), (16000, np.random.randn(16000))]
- if(len(waveform) == 1):
- waveform = waveform[0]
- return waveform
+ negative_prompt=negative_prompt,
+ num_waveforms_per_prompt=n_candidates if n_candidates else 1,
+ generator=generator.manual_seed(int(random_seed)),
+ )["audios"]
+
+ if waveforms.shape[0] > 1:
+ waveform = score_waveforms(text, waveforms)
+ else:
+ waveform = waveforms[0]
+
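+ # AudioLDM outputs audio at 16 kHz; gr.make_waveform renders it as a waveform video over bg.png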
+ return gr.make_waveform((16000, waveform), bg_image="bg.png")
+
-# iface = gr.Interface(fn=text2audio, inputs=[
-# gr.Textbox(value="A man is speaking in a huge room", max_lines=1),
-# gr.Slider(2.5, 10, value=5, step=2.5),
-# gr.Slider(0, 5, value=2.5, step=0.5),
-# gr.Number(value=42)
-# ], outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")],
-# allow_flagging="never"
-# )
-# iface.launch(share=True)
+def score_waveforms(text, waveforms):
+ inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True)
+ inputs = {key: inputs[key].to(device) for key in inputs}
+ with torch.no_grad():
+ logits_per_text = clap_model(**inputs).logits_per_text # this is the audio-text similarity score
+ probs = logits_per_text.softmax(dim=-1) # softmax over the candidates gives a probability per waveform
+ most_probable = torch.argmax(probs) # and now select the most likely audio waveform
+ waveform = waveforms[most_probable]
+ return waveform
css = """
a {
- color: inherit;
- text-decoration: underline;
- }
- .gradio-container {
+ color: inherit; text-decoration: underline;
+ } .gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
- }
- .gr-button {
- color: white;
- border-color: #000000;
- background: #000000;
- }
- input[type='range'] {
+ } .gr-button {
+ color: white; border-color: #000000; background: #000000;
+ } input[type='range'] {
accent-color: #000000;
- }
- .dark input[type='range'] {
+ } .dark input[type='range'] {
accent-color: #dfdfdf;
- }
- .container {
- max-width: 730px;
- margin: auto;
- padding-top: 1.5rem;
- }
- #gallery {
- min-height: 22rem;
- margin-bottom: 15px;
- margin-left: auto;
- margin-right: auto;
- border-bottom-right-radius: .5rem !important;
- border-bottom-left-radius: .5rem !important;
- }
- #gallery>div>.h-full {
+ } .container {
+ max-width: 730px; margin: auto; padding-top: 1.5rem;
+ } #gallery {
+ min-height: 22rem; margin-bottom: 15px; margin-left: auto; margin-right: auto; border-bottom-right-radius:
+ .5rem !important; border-bottom-left-radius: .5rem !important;
+ } #gallery>div>.h-full {
min-height: 20rem;
- }
- .details:hover {
+ } .details:hover {
text-decoration: underline;
- }
- .gr-button {
+ } .gr-button {
white-space: nowrap;
- }
- .gr-button:focus {
- border-color: rgb(147 197 253 / var(--tw-border-opacity));
- outline: none;
- box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
- --tw-border-opacity: 1;
- --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
- --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
- --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
- --tw-ring-opacity: .5;
- }
- #advanced-btn {
- font-size: .7rem !important;
- line-height: 19px;
- margin-top: 12px;
- margin-bottom: 12px;
- padding: 2px 8px;
+ } .gr-button:focus {
+ border-color: rgb(147 197 253 / var(--tw-border-opacity)); outline: none; box-shadow:
+ var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); --tw-border-opacity: 1;
+ --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width)
+ var(--tw-ring-offset-color); --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px
+ var(--tw-ring-offset-width)) var(--tw-ring-color); --tw-ring-color: rgb(191 219 254 /
+ var(--tw-ring-opacity)); --tw-ring-opacity: .5;
+ } #advanced-btn {
+ font-size: .7rem !important; line-height: 19px; margin-top: 12px; margin-bottom: 12px; padding: 2px 8px;
border-radius: 14px !important;
- }
- #advanced-options {
+ } #advanced-options {
margin-bottom: 20px;
- }
- .footer {
- margin-bottom: 45px;
- margin-top: 35px;
- text-align: center;
- border-bottom: 1px solid #e5e5e5;
- }
- .footer>p {
- font-size: .8rem;
- display: inline-block;
- padding: 0 10px;
- transform: translateY(10px);
- background: white;
- }
- .dark .footer {
+ } .footer {
+ margin-bottom: 45px; margin-top: 35px; text-align: center; border-bottom: 1px solid #e5e5e5;
+ } .footer>p {
+ font-size: .8rem; display: inline-block; padding: 0 10px; transform: translateY(10px); background: white;
+ } .dark .footer {
border-color: #303030;
- }
- .dark .footer>p {
+ } .dark .footer>p {
background: #0b0f19;
- }
- .acknowledgments h4{
- margin: 1.25em 0 .25em 0;
- font-weight: bold;
- font-size: 115%;
- }
- #container-advanced-btns{
- display: flex;
- flex-wrap: wrap;
- justify-content: space-between;
- align-items: center;
- }
- .animate-spin {
+ } .acknowledgments h4{
+ margin: 1.25em 0 .25em 0; font-weight: bold; font-size: 115%;
+ } #container-advanced-btns{
+ display: flex; flex-wrap: wrap; justify-content: space-between; align-items: center;
+ } .animate-spin {
animation: spin 1s linear infinite;
- }
- @keyframes spin {
+ } @keyframes spin {
from {
transform: rotate(0deg);
- }
- to {
+ } to {
transform: rotate(360deg);
}
- }
- #share-btn-container {
- display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
- margin-top: 10px;
- margin-left: auto;
- }
- #share-btn {
- all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
- }
- #share-btn * {
+ } #share-btn-container {
+ display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color:
+ #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
+ margin-top: 10px; margin-left: auto;
+ } #share-btn {
+ all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif;
+ margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem
+ !important;right:0;
+ } #share-btn * {
all: unset;
- }
- #share-btn-container div:nth-child(-n+2){
- width: auto !important;
- min-height: 0px !important;
- }
- #share-btn-container .wrap {
+ } #share-btn-container div:nth-child(-n+2){
+ width: auto !important; min-height: 0px !important;
+ } #share-btn-container .wrap {
display: none !important;
- }
- .gr-form{
+ } .gr-form{
flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
- }
- #prompt-container{
+ } #prompt-container{
gap: 0;
- }
- #generated_id{
+ } #generated_id{
min-height: 700px
- }
- #setting_id{
- margin-bottom: 12px;
- text-align: center;
- font-weight: 900;
+ } #setting_id{
+ margin-bottom: 12px; text-align: center; font-weight: 900;
}
"""
iface = gr.Blocks(css=css)
@@ -202,56 +144,72 @@ with iface:
+ [Paper] [Project
+ page] [🧨
+ Diffusers]
"""
)
- gr.HTML("""
-
- AudioLDM: Text-to-Audio Generation with Latent Diffusion Models
-
- For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
-
-
-
-
- """)
+ gr.HTML(
+ """
+ This is the demo for AudioLDM, powered by 🧨 Diffusers. The demo uses the checkpoint audioldm-m-full. For faster inference
+ without waiting in the queue, you may duplicate the Space and upgrade to a GPU in the settings.
+ """
+ )
+
with gr.Group():
with gr.Box():
- ############# Input
- textbox = gr.Textbox(value="A hammer is hitting a wooden surface", max_lines=1, label="Input your text here. Your text is important for the audio quality. Please ensure it is descriptive by using more adjectives.", elem_id="prompt-in")
+ textbox = gr.Textbox(
+ value="A hammer is hitting a wooden surface",
+ max_lines=1,
+ label="Input text",
+ info="Your text is important for the audio quality. Please ensure it is descriptive by using more adjectives.",
+ elem_id="prompt-in",
+ )
+ negative_textbox = gr.Textbox(
+ value="low quality, average quality",
+ max_lines=1,
+ label="Negative prompt",
+ info="Enter a negative prompt not to guide the audio generation. Selecting appropriate negative prompts can improve the audio quality significantly.",
+ elem_id="prompt-in",
+ )
with gr.Accordion("Click to modify detailed configurations", open=False):
- seed = gr.Number(value=45, label="Change this value (any integer number) will lead to a different generation result.")
- duration = gr.Slider(2.5, 10, value=5, step=2.5, label="Duration (seconds)")
- guidance_scale = gr.Slider(0, 4, value=2.5, step=0.5, label="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)")
- n_candidates = gr.Slider(1, 3, value=3, step=1, label="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation")
- # model_name = gr.Dropdown(
- # ["audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full","audioldm-s-full-v2", "audioldm-s-full", "audioldm-l-full"], value="audioldm-m-full", label="Choose the model to use. audioldm-m-text-ft and audioldm-s-text-ft are recommanded. -s- means small, -m- means medium and -l- means large",
- # )
- ############# Output
- # outputs=gr.Audio(label="Output", type="numpy")
- outputs=gr.Video(label="Output", elem_id="output-video")
-
- # with gr.Group(elem_id="container-advanced-btns"):
- # # advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
- # with gr.Group(elem_id="share-btn-container"):
- # community_icon = gr.HTML(community_icon_html, visible=False)
- # loading_icon = gr.HTML(loading_icon_html, visible=False)
- # share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
- # outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")]
+ seed = gr.Number(
+ value=45,
+ label="Seed",
+ info="Change this value (any integer number) will lead to a different generation result.",
+ )
+ duration = gr.Slider(2.5, 10, value=5, step=2.5, label="Duration (seconds)")
+ guidance_scale = gr.Slider(
+ 0,
+ 4,
+ value=2.5,
+ step=0.5,
+ label="Guidance scale",
+ info="Large => better quality and relevancy to text; Small => better diversity",
+ )
+ n_candidates = gr.Slider(
+ 1,
+ 3,
+ value=3,
+ step=1,
+ label="Number waveforms to generate",
+ info="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
+ )
+
+ outputs = gr.Video(label="Output", elem_id="output-video")
btn = gr.Button("Submit").style(full_width=True)
with gr.Group(elem_id="share-btn-container", visible=False):
@@ -259,51 +217,61 @@ with iface:
loading_icon = gr.HTML(loading_icon_html)
share_button = gr.Button("Share to community", elem_id="share-btn")
- # btn.click(text2audio, inputs=[
- # textbox, duration, guidance_scale, seed, n_candidates, model_name], outputs=[outputs])
- btn.click(text2audio, inputs=[
- textbox, duration, guidance_scale, seed, n_candidates], outputs=[outputs])
-
+ btn.click(
+ text2audio,
+ inputs=[textbox, negative_textbox, duration, guidance_scale, seed, n_candidates],
+ outputs=[outputs],
+ )
+
share_button.click(None, [], [], _js=share_js)
- gr.HTML('''
+ gr.HTML(
+ """
- ''')
- gr.Examples([
- ["A hammer is hitting a wooden surface", 5, 2.5, 45, 3, "audioldm-m-full"],
- ["Peaceful and calming ambient music with singing bowl and other instruments.", 5, 2.5, 45, 3, "audioldm-m-full"],
- ["A man is speaking in a small room.", 5, 2.5, 45, 3, "audioldm-m-full"],
- ["A female is speaking followed by footstep sound", 5, 2.5, 45, 3, "audioldm-m-full"],
- ["Wooden table tapping sound followed by water pouring sound.", 5, 2.5, 45, 3, "audioldm-m-full"],
- ],
+ """
+ )
+ gr.Examples(
+ [
+ ["A hammer is hitting a wooden surface", "low quality, average quality", 5, 2.5, 45, 3],
+ ["Peaceful and calming ambient music with singing bowl and other instruments.", "low quality, average quality", 5, 2.5, 45, 3],
+ ["A man is speaking in a small room.", "low quality, average quality", 5, 2.5, 45, 3],
+ ["A female is speaking followed by footstep sound", "low quality, average quality", 5, 2.5, 45, 3],
+ ["Wooden table tapping sound followed by water pouring sound.", "low quality, average quality", 5, 2.5, 45, 3],
+ ],
fn=text2audio,
- # inputs=[textbox, duration, guidance_scale, seed, n_candidates, model_name],
- inputs=[textbox, duration, guidance_scale, seed, n_candidates],
+ inputs=[textbox, negative_textbox, duration, guidance_scale, seed, n_candidates],
outputs=[outputs],
cache_examples=True,
)
- gr.HTML('''
-
-
- Essential Tricks for Enhancing the Quality of Your Generated Audio
-
- 1. Try to use more adjectives to describe your sound. For example: "A man is speaking clearly and slowly in a large room" is better than "A man is speaking". This can make sure AudioLDM understands what you want.
-
- 2. Try to use different random seeds, which can affect the generation quality significantly sometimes.
-
- 3. It's better to use general terms like 'man' or 'woman' instead of specific names for individuals or abstract objects that humans may not be familiar with, such as 'mummy'.
-
- ''')
+ gr.HTML(
+ """
+ Essential Tricks for Enhancing the Quality of Your Generated Audio
+ 1. Try to use more adjectives to describe your sound. For example: "A man is speaking clearly and slowly
+ in a large room" is better than "A man is speaking". This can make sure AudioLDM understands what you want.
+ 2. Try to use different random seeds, which can sometimes affect the generation quality significantly.
+ 3. It's better to use general terms like 'man' or 'woman' instead of specific names for individuals
+ or abstract objects that humans may not be familiar with, such as 'mummy'.
+ 4. Using a negative prompt to tell the diffusion process what to avoid can improve the
+ audio quality significantly. Try negative prompts like 'low quality'.
+ """
+ )
with gr.Accordion("Additional information", open=False):
gr.HTML(
"""
"""
)
# This demo is strictly for research demo purpose only. For commercial use please contact us.
iface.queue(max_size=10).launch(debug=True)
-# iface.launch(debug=True, share=True)
diff --git a/audioldm/__init__.py b/audioldm/__init__.py
deleted file mode 100644
index 2f93cab80ded8e7239bb96eb6e364c3fd4fb46d9..0000000000000000000000000000000000000000
--- a/audioldm/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .ldm import LatentDiffusion
-from .utils import seed_everything
-from .pipeline import *
\ No newline at end of file
diff --git a/audioldm/audio/__init__.py b/audioldm/audio/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/audioldm/audio/audio_processing.py b/audioldm/audio/audio_processing.py
deleted file mode 100644
index 77a4057aa82f226f68474f4c2a19eba84510d663..0000000000000000000000000000000000000000
--- a/audioldm/audio/audio_processing.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import torch
-import numpy as np
-import librosa.util as librosa_util
-from scipy.signal import get_window
-
-
-def window_sumsquare(
- window,
- n_frames,
- hop_length,
- win_length,
- n_fft,
- dtype=np.float32,
- norm=None,
-):
- """
- # from librosa 0.6
- Compute the sum-square envelope of a window function at a given hop length.
-
- This is used to estimate modulation effects induced by windowing
- observations in short-time fourier transforms.
-
- Parameters
- ----------
- window : string, tuple, number, callable, or list-like
- Window specification, as in `get_window`
-
- n_frames : int > 0
- The number of analysis frames
-
- hop_length : int > 0
- The number of samples to advance between frames
-
- win_length : [optional]
- The length of the window function. By default, this matches `n_fft`.
-
- n_fft : int > 0
- The length of each analysis frame.
-
- dtype : np.dtype
- The data type of the output
-
- Returns
- -------
- wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
- The sum-squared envelope of the window function
- """
- if win_length is None:
- win_length = n_fft
-
- n = n_fft + hop_length * (n_frames - 1)
- x = np.zeros(n, dtype=dtype)
-
- # Compute the squared window at the desired length
- win_sq = get_window(window, win_length, fftbins=True)
- win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
- win_sq = librosa_util.pad_center(win_sq, n_fft)
-
- # Fill the envelope
- for i in range(n_frames):
- sample = i * hop_length
- x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
- return x
-
-
-def griffin_lim(magnitudes, stft_fn, n_iters=30):
- """
- PARAMS
- ------
- magnitudes: spectrogram magnitudes
- stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
- """
-
- angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
- angles = angles.astype(np.float32)
- angles = torch.autograd.Variable(torch.from_numpy(angles))
- signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
-
- for i in range(n_iters):
- _, angles = stft_fn.transform(signal)
- signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
- return signal
-
-
-def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
- """
- PARAMS
- ------
- C: compression factor
- """
- return normalize_fun(torch.clamp(x, min=clip_val) * C)
-
-
-def dynamic_range_decompression(x, C=1):
- """
- PARAMS
- ------
- C: compression factor used to compress
- """
- return torch.exp(x) / C
diff --git a/audioldm/audio/stft.py b/audioldm/audio/stft.py
deleted file mode 100644
index 2aa1ac89277734a6676c20a81bf88e21e8ca7aa9..0000000000000000000000000000000000000000
--- a/audioldm/audio/stft.py
+++ /dev/null
@@ -1,180 +0,0 @@
-import torch
-import torch.nn.functional as F
-import numpy as np
-from scipy.signal import get_window
-from librosa.util import pad_center, tiny
-from librosa.filters import mel as librosa_mel_fn
-
-from audioldm.audio.audio_processing import (
- dynamic_range_compression,
- dynamic_range_decompression,
- window_sumsquare,
-)
-
-
-class STFT(torch.nn.Module):
- """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
-
- def __init__(self, filter_length, hop_length, win_length, window="hann"):
- super(STFT, self).__init__()
- self.filter_length = filter_length
- self.hop_length = hop_length
- self.win_length = win_length
- self.window = window
- self.forward_transform = None
- scale = self.filter_length / self.hop_length
- fourier_basis = np.fft.fft(np.eye(self.filter_length))
-
- cutoff = int((self.filter_length / 2 + 1))
- fourier_basis = np.vstack(
- [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
- )
-
- forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
- inverse_basis = torch.FloatTensor(
- np.linalg.pinv(scale * fourier_basis).T[:, None, :]
- )
-
- if window is not None:
- assert filter_length >= win_length
- # get window and zero center pad it to filter_length
- fft_window = get_window(window, win_length, fftbins=True)
- fft_window = pad_center(fft_window, filter_length)
- fft_window = torch.from_numpy(fft_window).float()
-
- # window the bases
- forward_basis *= fft_window
- inverse_basis *= fft_window
-
- self.register_buffer("forward_basis", forward_basis.float())
- self.register_buffer("inverse_basis", inverse_basis.float())
-
- def transform(self, input_data):
- num_batches = input_data.size(0)
- num_samples = input_data.size(1)
-
- self.num_samples = num_samples
-
- # similar to librosa, reflect-pad the input
- input_data = input_data.view(num_batches, 1, num_samples)
- input_data = F.pad(
- input_data.unsqueeze(1),
- (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
- mode="reflect",
- )
- input_data = input_data.squeeze(1)
-
- forward_transform = F.conv1d(
- input_data,
- torch.autograd.Variable(self.forward_basis, requires_grad=False),
- stride=self.hop_length,
- padding=0,
- ).cpu()
-
- cutoff = int((self.filter_length / 2) + 1)
- real_part = forward_transform[:, :cutoff, :]
- imag_part = forward_transform[:, cutoff:, :]
-
- magnitude = torch.sqrt(real_part**2 + imag_part**2)
- phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
-
- return magnitude, phase
-
- def inverse(self, magnitude, phase):
- recombine_magnitude_phase = torch.cat(
- [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
- )
-
- inverse_transform = F.conv_transpose1d(
- recombine_magnitude_phase,
- torch.autograd.Variable(self.inverse_basis, requires_grad=False),
- stride=self.hop_length,
- padding=0,
- )
-
- if self.window is not None:
- window_sum = window_sumsquare(
- self.window,
- magnitude.size(-1),
- hop_length=self.hop_length,
- win_length=self.win_length,
- n_fft=self.filter_length,
- dtype=np.float32,
- )
- # remove modulation effects
- approx_nonzero_indices = torch.from_numpy(
- np.where(window_sum > tiny(window_sum))[0]
- )
- window_sum = torch.autograd.Variable(
- torch.from_numpy(window_sum), requires_grad=False
- )
- window_sum = window_sum
- inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
- approx_nonzero_indices
- ]
-
- # scale by hop ratio
- inverse_transform *= float(self.filter_length) / self.hop_length
-
- inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
- inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
-
- return inverse_transform
-
- def forward(self, input_data):
- self.magnitude, self.phase = self.transform(input_data)
- reconstruction = self.inverse(self.magnitude, self.phase)
- return reconstruction
-
-
-class TacotronSTFT(torch.nn.Module):
- def __init__(
- self,
- filter_length,
- hop_length,
- win_length,
- n_mel_channels,
- sampling_rate,
- mel_fmin,
- mel_fmax,
- ):
- super(TacotronSTFT, self).__init__()
- self.n_mel_channels = n_mel_channels
- self.sampling_rate = sampling_rate
- self.stft_fn = STFT(filter_length, hop_length, win_length)
- mel_basis = librosa_mel_fn(
- sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
- )
- mel_basis = torch.from_numpy(mel_basis).float()
- self.register_buffer("mel_basis", mel_basis)
-
- def spectral_normalize(self, magnitudes, normalize_fun):
- output = dynamic_range_compression(magnitudes, normalize_fun)
- return output
-
- def spectral_de_normalize(self, magnitudes):
- output = dynamic_range_decompression(magnitudes)
- return output
-
- def mel_spectrogram(self, y, normalize_fun=torch.log):
- """Computes mel-spectrograms from a batch of waves
- PARAMS
- ------
- y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
-
- RETURNS
- -------
- mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
- """
- assert torch.min(y.data) >= -1, torch.min(y.data)
- assert torch.max(y.data) <= 1, torch.max(y.data)
-
- magnitudes, phases = self.stft_fn.transform(y)
- magnitudes = magnitudes.data
- mel_output = torch.matmul(self.mel_basis, magnitudes)
- mel_output = self.spectral_normalize(mel_output, normalize_fun)
- energy = torch.norm(magnitudes, dim=1)
-
- log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
-
- return mel_output, log_magnitudes, energy
diff --git a/audioldm/audio/tools.py b/audioldm/audio/tools.py
deleted file mode 100644
index 7aca95cc1f5c120568a210907e9506589899a1c6..0000000000000000000000000000000000000000
--- a/audioldm/audio/tools.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import torch
-import numpy as np
-
-
-def get_mel_from_wav(audio, _stft):
- audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
- audio = torch.autograd.Variable(audio, requires_grad=False)
- melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
- melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
- log_magnitudes_stft = (
- torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32)
- )
- energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
- return melspec, log_magnitudes_stft, energy
-
-
-# def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
-# mel = torch.stack([mel])
-# mel_decompress = _stft.spectral_de_normalize(mel)
-# mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
-# spec_from_mel_scaling = 1000
-# spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
-# spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
-# spec_from_mel = spec_from_mel * spec_from_mel_scaling
-
-# audio = griffin_lim(
-# torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters
-# )
-
-# audio = audio.squeeze()
-# audio = audio.cpu().numpy()
-# audio_path = out_filename
-# write(audio_path, _stft.sampling_rate, audio)
diff --git a/audioldm/clap/__init__.py b/audioldm/clap/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/audioldm/clap/encoders.py b/audioldm/clap/encoders.py
deleted file mode 100644
index 5effd8efd3b933888c586199b5eaa89e632cab03..0000000000000000000000000000000000000000
--- a/audioldm/clap/encoders.py
+++ /dev/null
@@ -1,170 +0,0 @@
-import torch
-import torch.nn as nn
-from audioldm.clap.open_clip import create_model
-from audioldm.clap.training.data import get_audio_features
-import torchaudio
-from transformers import RobertaTokenizer
-import torch.nn.functional as F
-
-
-class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
- def __init__(
- self,
- pretrained_path="",
- key="class",
- sampling_rate=16000,
- embed_mode="audio",
- amodel = "HTSAT-tiny",
- unconditional_prob=0.1,
- random_mute=False,
- max_random_mute_portion=0.5,
- training_mode=True,
- ):
- super().__init__()
-
- self.key = key
- self.device = "cpu"
- self.precision = "fp32"
- self.amodel = amodel
- self.tmodel = "roberta" # the best text encoder in our training
- self.enable_fusion = False # False if you do not want to use the fusion model
- self.fusion_type = "aff_2d"
- self.pretrained = pretrained_path
- self.embed_mode = embed_mode
- self.embed_mode_orig = embed_mode
- self.sampling_rate = sampling_rate
- self.unconditional_prob = unconditional_prob
- self.random_mute = random_mute
- self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
- self.max_random_mute_portion = max_random_mute_portion
- self.training_mode = training_mode
- self.model, self.model_cfg = create_model(
- self.amodel,
- self.tmodel,
- self.pretrained,
- precision=self.precision,
- device=self.device,
- enable_fusion=self.enable_fusion,
- fusion_type=self.fusion_type,
- )
- for p in self.model.parameters():
- p.requires_grad = False
-
- self.model.eval()
-
- def get_unconditional_condition(self, batchsize):
- self.unconditional_token = self.model.get_text_embedding(
- self.tokenizer(["", ""])
- )[0:1]
- return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)
-
- def batch_to_list(self, batch):
- ret = []
- for i in range(batch.size(0)):
- ret.append(batch[i])
- return ret
-
- def make_decision(self, probability):
- if float(torch.rand(1)) < probability:
- return True
- else:
- return False
-
- def random_uniform(self, start, end):
- val = torch.rand(1).item()
- return start + (end - start) * val
-
- def _random_mute(self, waveform):
- # waveform: [bs, t-steps]
- t_steps = waveform.size(-1)
- for i in range(waveform.size(0)):
- mute_size = int(
- self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
- )
- mute_start = int(self.random_uniform(0, t_steps - mute_size))
- waveform[i, mute_start : mute_start + mute_size] = 0
- return waveform
-
- def cos_similarity(self, waveform, text):
- # waveform: [bs, t_steps]
- with torch.no_grad():
- self.embed_mode = "audio"
- audio_emb = self(waveform.cuda())
- self.embed_mode = "text"
- text_emb = self(text)
- similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
- return similarity.squeeze()
-
- def forward(self, batch, key=None):
- # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0
- # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0
- if self.model.training == True and not self.training_mode:
- print(
- "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters."
- )
- self.model, self.model_cfg = create_model(
- self.amodel,
- self.tmodel,
- self.pretrained,
- precision=self.precision,
- device="cuda",
- enable_fusion=self.enable_fusion,
- fusion_type=self.fusion_type,
- )
- for p in self.model.parameters():
- p.requires_grad = False
- self.model.eval()
-
- # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
- if self.embed_mode == "audio":
- with torch.no_grad():
- audio_dict_list = []
- assert (
- self.sampling_rate == 16000
- ), "We only support 16000 sampling rate"
- if self.random_mute:
- batch = self._random_mute(batch)
- # batch: [bs, 1, t-samples]
- batch = torchaudio.functional.resample(
- batch, orig_freq=self.sampling_rate, new_freq=48000
- )
- for waveform in self.batch_to_list(batch):
- audio_dict = {}
- audio_dict = get_audio_features(
- audio_dict,
- waveform,
- 480000,
- data_truncating="fusion",
- data_filling="repeatpad",
- audio_cfg=self.model_cfg["audio_cfg"],
- )
- audio_dict_list.append(audio_dict)
- # [bs, 512]
- embed = self.model.get_audio_embedding(audio_dict_list)
- elif self.embed_mode == "text":
- with torch.no_grad():
- # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
- text_data = self.tokenizer(batch)
- embed = self.model.get_text_embedding(text_data)
-
- embed = embed.unsqueeze(1)
- self.unconditional_token = self.model.get_text_embedding(
- self.tokenizer(["", ""])
- )[0:1]
-
- for i in range(embed.size(0)):
- if self.make_decision(self.unconditional_prob):
- embed[i] = self.unconditional_token
-
- # [bs, 1, 512]
- return embed.detach()
-
- def tokenizer(self, text):
- result = self.tokenize(
- text,
- padding="max_length",
- truncation=True,
- max_length=512,
- return_tensors="pt",
- )
- return {k: v.squeeze(0) for k, v in result.items()}
diff --git a/audioldm/clap/open_clip/__init__.py b/audioldm/clap/open_clip/__init__.py
deleted file mode 100644
index e9f728f2f273be5d5fdbec6c6cc41d737176a8c0..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/__init__.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from .factory import (
- list_models,
- create_model,
- create_model_and_transforms,
- add_model_config,
-)
-from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
-from .model import (
- CLAP,
- CLAPTextCfg,
- CLAPVisionCfg,
- CLAPAudioCfp,
- convert_weights_to_fp16,
- trace_model,
-)
-from .openai import load_openai_model, list_openai_models
-from .pretrained import (
- list_pretrained,
- list_pretrained_tag_models,
- list_pretrained_model_tags,
- get_pretrained_url,
- download_pretrained,
-)
-from .tokenizer import SimpleTokenizer, tokenize
-from .transform import image_transform
diff --git a/audioldm/clap/open_clip/bert.py b/audioldm/clap/open_clip/bert.py
deleted file mode 100644
index a83d96d2a77ed05198efc05837522bc88d2499cc..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/bert.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from transformers import BertTokenizer, BertModel
-
-tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
-model = BertModel.from_pretrained("bert-base-uncased")
-text = "Replace me by any text you'd like."
-
-
-def bert_embeddings(text):
- # text = "Replace me by any text you'd like."
- encoded_input = tokenizer(text, return_tensors="pt")
- output = model(**encoded_input)
- return output
-
-
-from transformers import RobertaTokenizer, RobertaModel
-
-tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
-model = RobertaModel.from_pretrained("roberta-base")
-text = "Replace me by any text you'd like."
-
-
-def Roberta_embeddings(text):
- # text = "Replace me by any text you'd like."
- encoded_input = tokenizer(text, return_tensors="pt")
- output = model(**encoded_input)
- return output
-
-
-from transformers import BartTokenizer, BartModel
-
-tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
-model = BartModel.from_pretrained("facebook/bart-base")
-text = "Replace me by any text you'd like."
-
-
-def bart_embeddings(text):
- # text = "Replace me by any text you'd like."
- encoded_input = tokenizer(text, return_tensors="pt")
- output = model(**encoded_input)
- return output
diff --git a/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz b/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz
deleted file mode 100644
index 36a15856e00a06a9fbed8cdd34d2393fea4a3113..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
-size 1356917
diff --git a/audioldm/clap/open_clip/factory.py b/audioldm/clap/open_clip/factory.py
deleted file mode 100644
index 844f9ca0e12a0ff43ba3e042a3e43530ebe91b8c..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/factory.py
+++ /dev/null
@@ -1,277 +0,0 @@
-import json
-import logging
-import os
-import pathlib
-import re
-from copy import deepcopy
-from pathlib import Path
-
-import torch
-
-from .model import CLAP, convert_weights_to_fp16
-from .openai import load_openai_model
-from .pretrained import get_pretrained_url, download_pretrained
-from .transform import image_transform
-
-_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
-_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
-
-
-def _natural_key(string_):
- return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
-
-
-def _rescan_model_configs():
- global _MODEL_CONFIGS
-
- config_ext = (".json",)
- config_files = []
- for config_path in _MODEL_CONFIG_PATHS:
- if config_path.is_file() and config_path.suffix in config_ext:
- config_files.append(config_path)
- elif config_path.is_dir():
- for ext in config_ext:
- config_files.extend(config_path.glob(f"*{ext}"))
-
- for cf in config_files:
- if os.path.basename(cf)[0] == ".":
- continue # Ignore hidden files
-
- with open(cf, "r") as f:
- model_cfg = json.load(f)
- if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
- _MODEL_CONFIGS[cf.stem] = model_cfg
-
- _MODEL_CONFIGS = {
- k: v
- for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
- }
-
-
-_rescan_model_configs() # initial populate of model config registry
-
-
-def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
- checkpoint = torch.load(checkpoint_path, map_location=map_location)
- if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
- state_dict = checkpoint["state_dict"]
- else:
- state_dict = checkpoint
- if skip_params:
- if next(iter(state_dict.items()))[0].startswith("module"):
- state_dict = {k[7:]: v for k, v in state_dict.items()}
- # for k in state_dict:
- # if k.startswith('transformer'):
- # v = state_dict.pop(k)
- # state_dict['text_branch.' + k[12:]] = v
- return state_dict
-
-
-def create_model(
- amodel_name: str,
- tmodel_name: str,
- pretrained: str = "",
- precision: str = "fp32",
- device: torch.device = torch.device("cpu"),
- jit: bool = False,
- force_quick_gelu: bool = False,
- openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
- skip_params=True,
- pretrained_audio: str = "",
- pretrained_text: str = "",
- enable_fusion: bool = False,
- fusion_type: str = "None"
- # pretrained_image: bool = False,
-):
- amodel_name = amodel_name.replace(
- "/", "-"
- ) # for callers using old naming with / in ViT names
- pretrained_orig = pretrained
- pretrained = pretrained.lower()
- if pretrained == "openai":
- if amodel_name in _MODEL_CONFIGS:
- logging.info(f"Loading {amodel_name} model config.")
- model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
- else:
- logging.error(
- f"Model config for {amodel_name} not found; available models {list_models()}."
- )
- raise RuntimeError(f"Model config for {amodel_name} not found.")
-
- logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
- # Hard Code in model name
- model_cfg["text_cfg"]["model_type"] = tmodel_name
- model = load_openai_model(
- "ViT-B-16",
- model_cfg,
- device=device,
- jit=jit,
- cache_dir=openai_model_cache_dir,
- enable_fusion=enable_fusion,
- fusion_type=fusion_type,
- )
- # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
- if precision == "amp" or precision == "fp32":
- model = model.float()
- else:
- if amodel_name in _MODEL_CONFIGS:
- logging.info(f"Loading {amodel_name} model config.")
- model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
- else:
- logging.error(
- f"Model config for {amodel_name} not found; available models {list_models()}."
- )
- raise RuntimeError(f"Model config for {amodel_name} not found.")
-
- if force_quick_gelu:
- # override for use of QuickGELU on non-OpenAI transformer models
- model_cfg["quick_gelu"] = True
-
- # if pretrained_image:
- # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
- # # pretrained weight loading for timm models set via vision_cfg
- # model_cfg['vision_cfg']['timm_model_pretrained'] = True
- # else:
- # assert False, 'pretrained image towers currently only supported for timm models'
- model_cfg["text_cfg"]["model_type"] = tmodel_name
- model_cfg["enable_fusion"] = enable_fusion
- model_cfg["fusion_type"] = fusion_type
- model = CLAP(**model_cfg)
-
- if pretrained:
- checkpoint_path = ""
- url = get_pretrained_url(amodel_name, pretrained)
- if url:
- checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
- elif os.path.exists(pretrained_orig):
- checkpoint_path = pretrained_orig
- if checkpoint_path:
- logging.info(
- f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
- )
- ckpt = load_state_dict(checkpoint_path, skip_params=True)
- model.load_state_dict(ckpt)
- param_names = [n for n, p in model.named_parameters()]
- # for n in param_names:
- # print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
- else:
- logging.warning(
- f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
- )
- raise RuntimeError(
- f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
- )
-
- if pretrained_audio:
- if amodel_name.startswith("PANN"):
- if "Cnn14_mAP" in pretrained_audio: # official checkpoint
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
- audio_ckpt = audio_ckpt["model"]
- keys = list(audio_ckpt.keys())
- for key in keys:
- if (
- "spectrogram_extractor" not in key
- and "logmel_extractor" not in key
- ):
- v = audio_ckpt.pop(key)
- audio_ckpt["audio_branch." + key] = v
- elif os.path.basename(pretrained_audio).startswith(
- "PANN"
- ): # checkpoint trained via HTSAT codebase
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
- audio_ckpt = audio_ckpt["state_dict"]
- keys = list(audio_ckpt.keys())
- for key in keys:
- if key.startswith("sed_model"):
- v = audio_ckpt.pop(key)
- audio_ckpt["audio_branch." + key[10:]] = v
- elif os.path.basename(pretrained_audio).startswith(
- "finetuned"
- ): # checkpoint trained via linear probe codebase
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
- else:
- raise ValueError("Unknown audio checkpoint")
- elif amodel_name.startswith("HTSAT"):
- if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
- audio_ckpt = audio_ckpt["state_dict"]
- keys = list(audio_ckpt.keys())
- for key in keys:
- if key.startswith("sed_model") and (
- "spectrogram_extractor" not in key
- and "logmel_extractor" not in key
- ):
- v = audio_ckpt.pop(key)
- audio_ckpt["audio_branch." + key[10:]] = v
- elif os.path.basename(pretrained_audio).startswith(
- "HTSAT"
- ): # checkpoint trained via HTSAT codebase
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
- audio_ckpt = audio_ckpt["state_dict"]
- keys = list(audio_ckpt.keys())
- for key in keys:
- if key.startswith("sed_model"):
- v = audio_ckpt.pop(key)
- audio_ckpt["audio_branch." + key[10:]] = v
- elif os.path.basename(pretrained_audio).startswith(
- "finetuned"
- ): # checkpoint trained via linear probe codebase
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
- else:
- raise ValueError("Unknown audio checkpoint")
- else:
- raise f"this audio encoder pretrained checkpoint is not support"
-
- model.load_state_dict(audio_ckpt, strict=False)
- logging.info(
- f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
- )
- param_names = [n for n, p in model.named_parameters()]
- for n in param_names:
- print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
-
- model.to(device=device)
- if precision == "fp16":
- assert device.type != "cpu"
- convert_weights_to_fp16(model)
-
- if jit:
- model = torch.jit.script(model)
-
- return model, model_cfg
-
-
-def create_model_and_transforms(
- model_name: str,
- pretrained: str = "",
- precision: str = "fp32",
- device: torch.device = torch.device("cpu"),
- jit: bool = False,
- force_quick_gelu: bool = False,
- # pretrained_image: bool = False,
-):
- model = create_model(
- model_name,
- pretrained,
- precision,
- device,
- jit,
- force_quick_gelu=force_quick_gelu,
- # pretrained_image=pretrained_image
- )
- preprocess_train = image_transform(model.visual.image_size, is_train=True)
- preprocess_val = image_transform(model.visual.image_size, is_train=False)
- return model, preprocess_train, preprocess_val
-
-
-def list_models():
- """enumerate available model architectures based on config files"""
- return list(_MODEL_CONFIGS.keys())
-
-
-def add_model_config(path):
- """add model config path or file and update registry"""
- if not isinstance(path, Path):
- path = Path(path)
- _MODEL_CONFIG_PATHS.append(path)
- _rescan_model_configs()
diff --git a/audioldm/clap/open_clip/feature_fusion.py b/audioldm/clap/open_clip/feature_fusion.py
deleted file mode 100644
index dbe4e170e05894c12ebdc36ba1dc1de65e441b89..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/feature_fusion.py
+++ /dev/null
@@ -1,192 +0,0 @@
-"""
-Feature Fusion for Varible-Length Data Processing
-AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
-According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
-"""
-
-import torch
-import torch.nn as nn
-
-
-class DAF(nn.Module):
- """
- Direct addition (DirectAddFuse)
- """
-
- def __init__(self):
- super(DAF, self).__init__()
-
- def forward(self, x, residual):
- return x + residual
-
-
-class iAFF(nn.Module):
- """
- Multi-feature fusion (iAFF)
- """
-
- def __init__(self, channels=64, r=4, type="2D"):
- super(iAFF, self).__init__()
- inter_channels = int(channels // r)
-
- if type == "1D":
- # local attention
- self.local_att = nn.Sequential(
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm1d(inter_channels),
- nn.ReLU(inplace=True),
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm1d(channels),
- )
-
- # global attention
- self.global_att = nn.Sequential(
- nn.AdaptiveAvgPool1d(1),
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm1d(inter_channels),
- nn.ReLU(inplace=True),
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm1d(channels),
- )
-
- # second local attention
- self.local_att2 = nn.Sequential(
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm1d(inter_channels),
- nn.ReLU(inplace=True),
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm1d(channels),
- )
- # second global attention
- self.global_att2 = nn.Sequential(
- nn.AdaptiveAvgPool1d(1),
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm1d(inter_channels),
- nn.ReLU(inplace=True),
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm1d(channels),
- )
- elif type == "2D":
- # local attention
- self.local_att = nn.Sequential(
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm2d(inter_channels),
- nn.ReLU(inplace=True),
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm2d(channels),
- )
-
- # global attention
- self.global_att = nn.Sequential(
- nn.AdaptiveAvgPool2d(1),
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm2d(inter_channels),
- nn.ReLU(inplace=True),
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm2d(channels),
- )
-
- # second local attention
- self.local_att2 = nn.Sequential(
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm2d(inter_channels),
- nn.ReLU(inplace=True),
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm2d(channels),
- )
- # second global attention
- self.global_att2 = nn.Sequential(
- nn.AdaptiveAvgPool2d(1),
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm2d(inter_channels),
- nn.ReLU(inplace=True),
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm2d(channels),
- )
- else:
- raise f"the type is not supported"
-
- self.sigmoid = nn.Sigmoid()
-
- def forward(self, x, residual):
- flag = False
- xa = x + residual
- if xa.size(0) == 1:
- xa = torch.cat([xa, xa], dim=0)
- flag = True
- xl = self.local_att(xa)
- xg = self.global_att(xa)
- xlg = xl + xg
- wei = self.sigmoid(xlg)
- xi = x * wei + residual * (1 - wei)
-
- xl2 = self.local_att2(xi)
- xg2 = self.global_att(xi)
- xlg2 = xl2 + xg2
- wei2 = self.sigmoid(xlg2)
- xo = x * wei2 + residual * (1 - wei2)
- if flag:
- xo = xo[0].unsqueeze(0)
- return xo
-
-
-class AFF(nn.Module):
- """
- Multi-feature fusion (AFF)
- """
-
- def __init__(self, channels=64, r=4, type="2D"):
- super(AFF, self).__init__()
- inter_channels = int(channels // r)
-
- if type == "1D":
- self.local_att = nn.Sequential(
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm1d(inter_channels),
- nn.ReLU(inplace=True),
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm1d(channels),
- )
- self.global_att = nn.Sequential(
- nn.AdaptiveAvgPool1d(1),
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm1d(inter_channels),
- nn.ReLU(inplace=True),
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm1d(channels),
- )
- elif type == "2D":
- self.local_att = nn.Sequential(
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm2d(inter_channels),
- nn.ReLU(inplace=True),
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm2d(channels),
- )
- self.global_att = nn.Sequential(
- nn.AdaptiveAvgPool2d(1),
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm2d(inter_channels),
- nn.ReLU(inplace=True),
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
- nn.BatchNorm2d(channels),
- )
- else:
- raise f"the type is not supported."
-
- self.sigmoid = nn.Sigmoid()
-
- def forward(self, x, residual):
- flag = False
- xa = x + residual
- if xa.size(0) == 1:
- xa = torch.cat([xa, xa], dim=0)
- flag = True
- xl = self.local_att(xa)
- xg = self.global_att(xa)
- xlg = xl + xg
- wei = self.sigmoid(xlg)
- xo = 2 * x * wei + 2 * residual * (1 - wei)
- if flag:
- xo = xo[0].unsqueeze(0)
- return xo
diff --git a/audioldm/clap/open_clip/htsat.py b/audioldm/clap/open_clip/htsat.py
deleted file mode 100644
index 3b856c6a43df162116a941f1b5c76e93713b276a..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/htsat.py
+++ /dev/null
@@ -1,1308 +0,0 @@
-# Ke Chen
-# knutchen@ucsd.edu
-# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
-# Some layers designed on the model
-# below codes are based and referred from https://github.com/microsoft/Swin-Transformer
-# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from itertools import repeat
-import collections.abc
-import math
-import warnings
-
-from torch.nn.init import _calculate_fan_in_and_fan_out
-import torch.utils.checkpoint as checkpoint
-
-import random
-
-from torchlibrosa.stft import Spectrogram, LogmelFilterBank
-from torchlibrosa.augmentation import SpecAugmentation
-
-from itertools import repeat
-from .utils import do_mixup, interpolate
-
-from .feature_fusion import iAFF, AFF, DAF
-
-# from PyTorch internals
-def _ntuple(n):
- def parse(x):
- if isinstance(x, collections.abc.Iterable):
- return x
- return tuple(repeat(x, n))
-
- return parse
-
-
-to_1tuple = _ntuple(1)
-to_2tuple = _ntuple(2)
-to_3tuple = _ntuple(3)
-to_4tuple = _ntuple(4)
-to_ntuple = _ntuple
-
-
-def drop_path(x, drop_prob: float = 0.0, training: bool = False):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
- the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
- See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
- changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
- 'survival rate' as the argument.
- """
- if drop_prob == 0.0 or not training:
- return x
- keep_prob = 1 - drop_prob
- shape = (x.shape[0],) + (1,) * (
- x.ndim - 1
- ) # work with diff dim tensors, not just 2D ConvNets
- random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
- random_tensor.floor_() # binarize
- output = x.div(keep_prob) * random_tensor
- return output
-
-
-class DropPath(nn.Module):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
-
- def __init__(self, drop_prob=None):
- super(DropPath, self).__init__()
- self.drop_prob = drop_prob
-
- def forward(self, x):
- return drop_path(x, self.drop_prob, self.training)
-
-
-class PatchEmbed(nn.Module):
- """2D Image to Patch Embedding"""
-
- def __init__(
- self,
- img_size=224,
- patch_size=16,
- in_chans=3,
- embed_dim=768,
- norm_layer=None,
- flatten=True,
- patch_stride=16,
- enable_fusion=False,
- fusion_type="None",
- ):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patch_stride = to_2tuple(patch_stride)
- self.img_size = img_size
- self.patch_size = patch_size
- self.patch_stride = patch_stride
- self.grid_size = (
- img_size[0] // patch_stride[0],
- img_size[1] // patch_stride[1],
- )
- self.num_patches = self.grid_size[0] * self.grid_size[1]
- self.flatten = flatten
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- self.enable_fusion = enable_fusion
- self.fusion_type = fusion_type
-
- padding = (
- (patch_size[0] - patch_stride[0]) // 2,
- (patch_size[1] - patch_stride[1]) // 2,
- )
-
- if (self.enable_fusion) and (self.fusion_type == "channel_map"):
- self.proj = nn.Conv2d(
- in_chans * 4,
- embed_dim,
- kernel_size=patch_size,
- stride=patch_stride,
- padding=padding,
- )
- else:
- self.proj = nn.Conv2d(
- in_chans,
- embed_dim,
- kernel_size=patch_size,
- stride=patch_stride,
- padding=padding,
- )
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
-
- if (self.enable_fusion) and (
- self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
- ):
- self.mel_conv2d = nn.Conv2d(
- in_chans,
- embed_dim,
- kernel_size=(patch_size[0], patch_size[1] * 3),
- stride=(patch_stride[0], patch_stride[1] * 3),
- padding=padding,
- )
- if self.fusion_type == "daf_2d":
- self.fusion_model = DAF()
- elif self.fusion_type == "aff_2d":
- self.fusion_model = AFF(channels=embed_dim, type="2D")
- elif self.fusion_type == "iaff_2d":
- self.fusion_model = iAFF(channels=embed_dim, type="2D")
-
- def forward(self, x, longer_idx=None):
- if (self.enable_fusion) and (
- self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
- ):
- global_x = x[:, 0:1, :, :]
-
- # global processing
- B, C, H, W = global_x.shape
- assert (
- H == self.img_size[0] and W == self.img_size[1]
- ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
- global_x = self.proj(global_x)
- TW = global_x.size(-1)
- if len(longer_idx) > 0:
- # local processing
- local_x = x[longer_idx, 1:, :, :].contiguous()
- B, C, H, W = local_x.shape
- local_x = local_x.view(B * C, 1, H, W)
- local_x = self.mel_conv2d(local_x)
- local_x = local_x.view(
- B, C, local_x.size(1), local_x.size(2), local_x.size(3)
- )
- local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3)
- TB, TC, TH, _ = local_x.size()
- if local_x.size(-1) < TW:
- local_x = torch.cat(
- [
- local_x,
- torch.zeros(
- (TB, TC, TH, TW - local_x.size(-1)),
- device=global_x.device,
- ),
- ],
- dim=-1,
- )
- else:
- local_x = local_x[:, :, :, :TW]
-
- global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x)
- x = global_x
- else:
- B, C, H, W = x.shape
- assert (
- H == self.img_size[0] and W == self.img_size[1]
- ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
- x = self.proj(x)
-
- if self.flatten:
- x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
- x = self.norm(x)
- return x
-
-
-class Mlp(nn.Module):
- """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
-
- def __init__(
- self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- drop=0.0,
- ):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def _no_grad_trunc_normal_(tensor, mean, std, a, b):
- # Cut & paste from PyTorch official master until it's in a few official releases - RW
- # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
- def norm_cdf(x):
- # Computes standard normal cumulative distribution function
- return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
-
- if (mean < a - 2 * std) or (mean > b + 2 * std):
- warnings.warn(
- "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
- "The distribution of values may be incorrect.",
- stacklevel=2,
- )
-
- with torch.no_grad():
- # Values are generated by using a truncated uniform distribution and
- # then using the inverse CDF for the normal distribution.
- # Get upper and lower cdf values
- l = norm_cdf((a - mean) / std)
- u = norm_cdf((b - mean) / std)
-
- # Uniformly fill tensor with values from [l, u], then translate to
- # [2l-1, 2u-1].
- tensor.uniform_(2 * l - 1, 2 * u - 1)
-
- # Use inverse cdf transform for normal distribution to get truncated
- # standard normal
- tensor.erfinv_()
-
- # Transform to proper mean, std
- tensor.mul_(std * math.sqrt(2.0))
- tensor.add_(mean)
-
- # Clamp to ensure it's in the proper range
- tensor.clamp_(min=a, max=b)
- return tensor
-
-
-def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
- # type: (Tensor, float, float, float, float) -> Tensor
- r"""Fills the input Tensor with values drawn from a truncated
- normal distribution. The values are effectively drawn from the
- normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
- with values outside :math:`[a, b]` redrawn until they are within
- the bounds. The method used for generating the random values works
- best when :math:`a \leq \text{mean} \leq b`.
- Args:
- tensor: an n-dimensional `torch.Tensor`
- mean: the mean of the normal distribution
- std: the standard deviation of the normal distribution
- a: the minimum cutoff value
- b: the maximum cutoff value
- Examples:
- >>> w = torch.empty(3, 5)
- >>> nn.init.trunc_normal_(w)
- """
- return _no_grad_trunc_normal_(tensor, mean, std, a, b)
-
-
-def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
- fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
- if mode == "fan_in":
- denom = fan_in
- elif mode == "fan_out":
- denom = fan_out
- elif mode == "fan_avg":
- denom = (fan_in + fan_out) / 2
-
- variance = scale / denom
-
- if distribution == "truncated_normal":
- # constant is stddev of standard normal truncated to (-2, 2)
- trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
- elif distribution == "normal":
- tensor.normal_(std=math.sqrt(variance))
- elif distribution == "uniform":
- bound = math.sqrt(3 * variance)
- tensor.uniform_(-bound, bound)
- else:
- raise ValueError(f"invalid distribution {distribution}")
-
-
-def lecun_normal_(tensor):
- variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = (
- x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- )
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(
- B, H // window_size, W // window_size, window_size, window_size, -1
- )
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
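-# Quick round-trip check for the two helpers above (the tensor sizes are just
-# example values): window_reverse undoes window_partition whenever H and W are
-# divisible by the window size.
-# >>> x = torch.randn(2, 8, 8, 96)                      # (B, H, W, C)
-# >>> windows = window_partition(x, 4)                  # (B * num_windows, 4, 4, 96) = (8, 4, 4, 96)
-# >>> torch.equal(window_reverse(windows, 4, 8, 8), x)
-# True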
-
-class WindowAttention(nn.Module):
- r"""Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both of shifted and non-shifted window.
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- """
-
- def __init__(
- self,
- dim,
- window_size,
- num_heads,
- qkv_bias=True,
- qk_scale=None,
- attn_drop=0.0,
- proj_drop=0.0,
- ):
-
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = qk_scale or head_dim**-0.5
-
- # define a parameter table of relative position bias
- self.relative_position_bias_table = nn.Parameter(
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
- ) # 2*Wh-1 * 2*Ww-1, nH
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = (
- coords_flatten[:, :, None] - coords_flatten[:, None, :]
- ) # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(
- 1, 2, 0
- ).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- trunc_normal_(self.relative_position_bias_table, std=0.02)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- qkv = (
- self.qkv(x)
- .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
- .permute(2, 0, 3, 1, 4)
- )
- q, k, v = (
- qkv[0],
- qkv[1],
- qkv[2],
- ) # make torchscript happy (cannot use tensor as tuple)
-
- q = q * self.scale
- attn = q @ k.transpose(-2, -1)
-
- relative_position_bias = self.relative_position_bias_table[
- self.relative_position_index.view(-1)
- ].view(
- self.window_size[0] * self.window_size[1],
- self.window_size[0] * self.window_size[1],
- -1,
- ) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(
- 2, 0, 1
- ).contiguous() # nH, Wh*Ww, Wh*Ww
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
- 1
- ).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x, attn
-
- def extra_repr(self):
- return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
-
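-# Sizing sketch for the relative position bias above (example hyper-parameters):
-# an 8x8 window yields a (2*8 - 1) * (2*8 - 1) = 225-row bias table per head and a
-# 64x64 index used to gather a per-head (64, 64) bias.
-# >>> wa = WindowAttention(dim=96, window_size=(8, 8), num_heads=4)
-# >>> wa.relative_position_bias_table.shape
-# torch.Size([225, 4])
-# >>> wa.relative_position_index.shape
-# torch.Size([64, 64])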
-
-# The model is built from Swin Transformer blocks, so the pretrained Swin Transformer weights can be reused
-class SwinTransformerBlock(nn.Module):
- r"""Swin Transformer Block.
- Args:
- dim (int): Number of input channels.
-        input_resolution (tuple[int]): Input resolution.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(
- self,
- dim,
- input_resolution,
- num_heads,
- window_size=7,
- shift_size=0,
- mlp_ratio=4.0,
- qkv_bias=True,
- qk_scale=None,
- drop=0.0,
- attn_drop=0.0,
- drop_path=0.0,
- act_layer=nn.GELU,
- norm_layer=nn.LayerNorm,
- norm_before_mlp="ln",
- ):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- self.norm_before_mlp = norm_before_mlp
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
-        assert (
-            0 <= self.shift_size < self.window_size
-        ), "shift_size must be in [0, window_size)"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim,
- window_size=to_2tuple(self.window_size),
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- attn_drop=attn_drop,
- proj_drop=drop,
- )
-
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
- if self.norm_before_mlp == "ln":
- self.norm2 = nn.LayerNorm(dim)
- elif self.norm_before_mlp == "bn":
- self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(
- 1, 2
- )
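-            # NOTE: this lambda builds a new, untrained BatchNorm1d on every call, so its
-            # parameters are never registered or learned; the default "ln" branch above
-            # does not have this issue.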
- else:
- raise NotImplementedError
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(
- in_features=dim,
- hidden_features=mlp_hidden_dim,
- act_layer=act_layer,
- drop=drop,
- )
-
- if self.shift_size > 0:
- # calculate attention mask for SW-MSA
- H, W = self.input_resolution
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
- h_slices = (
- slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None),
- )
- w_slices = (
- slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None),
- )
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(
- img_mask, self.window_size
- ) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(
- attn_mask != 0, float(-100.0)
- ).masked_fill(attn_mask == 0, float(0.0))
- else:
- attn_mask = None
-
- self.register_buffer("attn_mask", attn_mask)
-
- def forward(self, x):
- # pdb.set_trace()
- H, W = self.input_resolution
- # print("H: ", H)
- # print("W: ", W)
- # pdb.set_trace()
- B, L, C = x.shape
- # assert L == H * W, "input feature has wrong size"
-
- shortcut = x
- x = self.norm1(x)
- x = x.view(B, H, W, C)
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(
- x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
- )
- else:
- shifted_x = x
-
- # partition windows
- x_windows = window_partition(
- shifted_x, self.window_size
- ) # nW*B, window_size, window_size, C
- x_windows = x_windows.view(
- -1, self.window_size * self.window_size, C
- ) # nW*B, window_size*window_size, C
-
- # W-MSA/SW-MSA
- attn_windows, attn = self.attn(
- x_windows, mask=self.attn_mask
- ) # nW*B, window_size*window_size, C
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(
- shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
- )
- else:
- x = shifted_x
- x = x.view(B, H * W, C)
-
- # FFN
- x = shortcut + self.drop_path(x)
- x = x + self.drop_path(self.mlp(self.norm2(x)))
-
- return x, attn
-
- def extra_repr(self):
- return (
- f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
- )
-
-
-class PatchMerging(nn.Module):
- r"""Patch Merging Layer.
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(4 * dim)
-
- def forward(self, x):
- """
- x: B, H*W, C
- """
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
-        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
-
- x = x.view(B, H, W, C)
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.norm(x)
- x = self.reduction(x)
-
- return x
-
- def extra_repr(self):
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
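-# Shape sketch for PatchMerging (example sizes): merging halves each spatial side
-# and doubles the channel dimension.
-# >>> pm = PatchMerging(input_resolution=(8, 8), dim=96)
-# >>> pm(torch.randn(2, 64, 96)).shape   # (B, H/2 * W/2, 2 * C)
-# torch.Size([2, 16, 192])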
-
-class BasicLayer(nn.Module):
- """A basic Swin Transformer layer for one stage.
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- """
-
- def __init__(
- self,
- dim,
- input_resolution,
- depth,
- num_heads,
- window_size,
- mlp_ratio=4.0,
- qkv_bias=True,
- qk_scale=None,
- drop=0.0,
- attn_drop=0.0,
- drop_path=0.0,
- norm_layer=nn.LayerNorm,
- downsample=None,
- use_checkpoint=False,
- norm_before_mlp="ln",
- ):
-
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList(
- [
- SwinTransformerBlock(
- dim=dim,
- input_resolution=input_resolution,
- num_heads=num_heads,
- window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop,
- attn_drop=attn_drop,
- drop_path=drop_path[i]
- if isinstance(drop_path, list)
- else drop_path,
- norm_layer=norm_layer,
- norm_before_mlp=norm_before_mlp,
- )
- for i in range(depth)
- ]
- )
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(
- input_resolution, dim=dim, norm_layer=norm_layer
- )
- else:
- self.downsample = None
-
- def forward(self, x):
- attns = []
- for blk in self.blocks:
- if self.use_checkpoint:
-                x, attn = checkpoint.checkpoint(blk, x)
- else:
- x, attn = blk(x)
- if not self.training:
- attns.append(attn.unsqueeze(0))
- if self.downsample is not None:
- x = self.downsample(x)
- if not self.training:
- attn = torch.cat(attns, dim=0)
- attn = torch.mean(attn, dim=0)
- return x, attn
-
- def extra_repr(self):
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
-
-# The Core of HTSAT
-class HTSAT_Swin_Transformer(nn.Module):
- r"""HTSAT based on the Swin Transformer
- Args:
- spec_size (int | tuple(int)): Input Spectrogram size. Default 256
- patch_size (int | tuple(int)): Patch size. Default: 4
-        patch_stride (int | tuple(int)): Patch stride for the frequency and time axes. Default: 4
- in_chans (int): Number of input image channels. Default: 1 (mono)
- num_classes (int): Number of classes for classification head. Default: 527
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 8
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
- config (module): The configuration Module from config.py
- """
-
- def __init__(
- self,
- spec_size=256,
- patch_size=4,
- patch_stride=(4, 4),
- in_chans=1,
- num_classes=527,
- embed_dim=96,
- depths=[2, 2, 6, 2],
- num_heads=[4, 8, 16, 32],
- window_size=8,
- mlp_ratio=4.0,
- qkv_bias=True,
- qk_scale=None,
- drop_rate=0.0,
- attn_drop_rate=0.0,
- drop_path_rate=0.1,
- norm_layer=nn.LayerNorm,
- ape=False,
- patch_norm=True,
- use_checkpoint=False,
- norm_before_mlp="ln",
- config=None,
- enable_fusion=False,
- fusion_type="None",
- **kwargs,
- ):
- super(HTSAT_Swin_Transformer, self).__init__()
-
- self.config = config
- self.spec_size = spec_size
- self.patch_stride = patch_stride
- self.patch_size = patch_size
- self.window_size = window_size
- self.embed_dim = embed_dim
- self.depths = depths
- self.ape = ape
- self.in_chans = in_chans
- self.num_classes = num_classes
- self.num_heads = num_heads
- self.num_layers = len(self.depths)
- self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
-
- self.drop_rate = drop_rate
- self.attn_drop_rate = attn_drop_rate
- self.drop_path_rate = drop_path_rate
-
- self.qkv_bias = qkv_bias
- self.qk_scale = None
-
- self.patch_norm = patch_norm
- self.norm_layer = norm_layer if self.patch_norm else None
- self.norm_before_mlp = norm_before_mlp
- self.mlp_ratio = mlp_ratio
-
- self.use_checkpoint = use_checkpoint
-
- self.enable_fusion = enable_fusion
- self.fusion_type = fusion_type
-
-        # process the mel spectrogram; used only once
- self.freq_ratio = self.spec_size // self.config.mel_bins
- window = "hann"
- center = True
- pad_mode = "reflect"
- ref = 1.0
- amin = 1e-10
- top_db = None
- self.interpolate_ratio = 32 # Downsampled ratio
- # Spectrogram extractor
- self.spectrogram_extractor = Spectrogram(
- n_fft=config.window_size,
- hop_length=config.hop_size,
- win_length=config.window_size,
- window=window,
- center=center,
- pad_mode=pad_mode,
- freeze_parameters=True,
- )
- # Logmel feature extractor
- self.logmel_extractor = LogmelFilterBank(
- sr=config.sample_rate,
- n_fft=config.window_size,
- n_mels=config.mel_bins,
- fmin=config.fmin,
- fmax=config.fmax,
- ref=ref,
- amin=amin,
- top_db=top_db,
- freeze_parameters=True,
- )
- # Spec augmenter
- self.spec_augmenter = SpecAugmentation(
- time_drop_width=64,
- time_stripes_num=2,
- freq_drop_width=8,
- freq_stripes_num=2,
- ) # 2 2
- self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
-
-        # split the spectrogram into non-overlapping patches
- self.patch_embed = PatchEmbed(
- img_size=self.spec_size,
- patch_size=self.patch_size,
- in_chans=self.in_chans,
- embed_dim=self.embed_dim,
- norm_layer=self.norm_layer,
- patch_stride=patch_stride,
- enable_fusion=self.enable_fusion,
- fusion_type=self.fusion_type,
- )
-
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.grid_size
- self.patches_resolution = patches_resolution
-
- # absolute position embedding
- if self.ape:
- self.absolute_pos_embed = nn.Parameter(
- torch.zeros(1, num_patches, self.embed_dim)
- )
- trunc_normal_(self.absolute_pos_embed, std=0.02)
-
- self.pos_drop = nn.Dropout(p=self.drop_rate)
-
- # stochastic depth
- dpr = [
- x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))
- ] # stochastic depth decay rule
-
- # build layers
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = BasicLayer(
- dim=int(self.embed_dim * 2**i_layer),
- input_resolution=(
- patches_resolution[0] // (2**i_layer),
- patches_resolution[1] // (2**i_layer),
- ),
- depth=self.depths[i_layer],
- num_heads=self.num_heads[i_layer],
- window_size=self.window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=self.qkv_bias,
- qk_scale=self.qk_scale,
- drop=self.drop_rate,
- attn_drop=self.attn_drop_rate,
- drop_path=dpr[
- sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1])
- ],
- norm_layer=self.norm_layer,
- downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
- use_checkpoint=use_checkpoint,
- norm_before_mlp=self.norm_before_mlp,
- )
- self.layers.append(layer)
-
- self.norm = self.norm_layer(self.num_features)
- self.avgpool = nn.AdaptiveAvgPool1d(1)
- self.maxpool = nn.AdaptiveMaxPool1d(1)
-
- SF = (
- self.spec_size
- // (2 ** (len(self.depths) - 1))
- // self.patch_stride[0]
- // self.freq_ratio
- )
- self.tscam_conv = nn.Conv2d(
- in_channels=self.num_features,
- out_channels=self.num_classes,
- kernel_size=(SF, 3),
- padding=(0, 1),
- )
- self.head = nn.Linear(num_classes, num_classes)
-
- if (self.enable_fusion) and (
- self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
- ):
- self.mel_conv1d = nn.Sequential(
- nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
- nn.BatchNorm1d(64),
- )
- if self.fusion_type == "daf_1d":
- self.fusion_model = DAF()
- elif self.fusion_type == "aff_1d":
- self.fusion_model = AFF(channels=64, type="1D")
- elif self.fusion_type == "iaff_1d":
- self.fusion_model = iAFF(channels=64, type="1D")
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=0.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore
- def no_weight_decay(self):
- return {"absolute_pos_embed"}
-
- @torch.jit.ignore
- def no_weight_decay_keywords(self):
- return {"relative_position_bias_table"}
-
- def forward_features(self, x, longer_idx=None):
- # A deprecated optimization for using a hierarchical output from different blocks
-
- frames_num = x.shape[2]
- x = self.patch_embed(x, longer_idx=longer_idx)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
- for i, layer in enumerate(self.layers):
- x, attn = layer(x)
- # for x
- x = self.norm(x)
- B, N, C = x.shape
- SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
- ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
- x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST)
- B, C, F, T = x.shape
- # group 2D CNN
- c_freq_bin = F // self.freq_ratio
- x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
- x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1)
- # get latent_output
- fine_grained_latent_output = torch.mean(x, dim=2)
- fine_grained_latent_output = interpolate(
- fine_grained_latent_output.permute(0, 2, 1).contiguous(),
- 8 * self.patch_stride[1],
- )
-
- latent_output = self.avgpool(torch.flatten(x, 2))
- latent_output = torch.flatten(latent_output, 1)
-
- # display the attention map, if needed
-
- x = self.tscam_conv(x)
- x = torch.flatten(x, 2) # B, C, T
-
- fpx = interpolate(
- torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1]
- )
-
- x = self.avgpool(x)
- x = torch.flatten(x, 1)
-
- output_dict = {
- "framewise_output": fpx, # already sigmoided
- "clipwise_output": torch.sigmoid(x),
- "fine_grained_embedding": fine_grained_latent_output,
- "embedding": latent_output,
- }
-
- return output_dict
-
- def crop_wav(self, x, crop_size, spe_pos=None):
- time_steps = x.shape[2]
- tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
- for i in range(len(x)):
- if spe_pos is None:
- crop_pos = random.randint(0, time_steps - crop_size - 1)
- else:
- crop_pos = spe_pos
- tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :]
- return tx
-
-    # Reshape the waveform's spectrogram to an image-like size so the pretrained Swin Transformer weights can be used
- def reshape_wav2img(self, x):
- B, C, T, F = x.shape
- target_T = int(self.spec_size * self.freq_ratio)
- target_F = self.spec_size // self.freq_ratio
-        assert (
-            T <= target_T and F <= target_F
-        ), "the wav size should be less than or equal to the swin input size"
- # to avoid bicubic zero error
- if T < target_T:
- x = nn.functional.interpolate(
- x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
- )
- if F < target_F:
- x = nn.functional.interpolate(
- x, (x.shape[2], target_F), mode="bicubic", align_corners=True
- )
- x = x.permute(0, 1, 3, 2).contiguous()
- x = x.reshape(
- x.shape[0],
- x.shape[1],
- x.shape[2],
- self.freq_ratio,
- x.shape[3] // self.freq_ratio,
- )
- # print(x.shape)
- x = x.permute(0, 1, 3, 2, 4).contiguous()
- x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
- return x
-
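-    # Worked shape example for reshape_wav2img above (illustrative values:
-    # spec_size=256 with mel_bins=64, so freq_ratio=4): an input of shape
-    # (B, 1, T<=1024, 64) is interpolated up to (B, 1, 1024, 64) if needed, the
-    # time axis is split into 4 chunks of 256 frames, and the chunks are stacked
-    # along the frequency axis, giving a square (B, 1, 256, 256) "image".
-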
-    # Repeat the waveform's spectrogram to an image-like size so the pretrained Swin Transformer weights can be used
- def repeat_wat2img(self, x, cur_pos):
- B, C, T, F = x.shape
- target_T = int(self.spec_size * self.freq_ratio)
- target_F = self.spec_size // self.freq_ratio
-        assert (
-            T <= target_T and F <= target_F
-        ), "the wav size should be less than or equal to the swin input size"
- # to avoid bicubic zero error
- if T < target_T:
- x = nn.functional.interpolate(
- x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
- )
- if F < target_F:
- x = nn.functional.interpolate(
- x, (x.shape[2], target_F), mode="bicubic", align_corners=True
- )
- x = x.permute(0, 1, 3, 2).contiguous() # B C F T
- x = x[:, :, :, cur_pos : cur_pos + self.spec_size]
- x = x.repeat(repeats=(1, 1, 4, 1))
- return x
-
- def forward(
- self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None
- ): # out_feat_keys: List[str] = None):
-
- if self.enable_fusion and x["longer"].sum() == 0:
- # if no audio is longer than 10s, then randomly select one audio to be longer
- x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True
-
- if not self.enable_fusion:
- x = x["waveform"].to(device=device, non_blocking=True)
- x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
- x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
- x = x.transpose(1, 3)
- x = self.bn0(x)
- x = x.transpose(1, 3)
- if self.training:
- x = self.spec_augmenter(x)
-
- if self.training and mixup_lambda is not None:
- x = do_mixup(x, mixup_lambda)
-
- x = self.reshape_wav2img(x)
- output_dict = self.forward_features(x)
- else:
- longer_list = x["longer"].to(device=device, non_blocking=True)
- x = x["mel_fusion"].to(device=device, non_blocking=True)
- x = x.transpose(1, 3)
- x = self.bn0(x)
- x = x.transpose(1, 3)
- longer_list_idx = torch.where(longer_list)[0]
- if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
- new_x = x[:, 0:1, :, :].clone().contiguous()
- if len(longer_list_idx) > 0:
- # local processing
- fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
- FB, FC, FT, FF = fusion_x_local.size()
- fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
- fusion_x_local = torch.permute(
- fusion_x_local, (0, 2, 1)
- ).contiguous()
- fusion_x_local = self.mel_conv1d(fusion_x_local)
- fusion_x_local = fusion_x_local.view(
- FB, FC, FF, fusion_x_local.size(-1)
- )
- fusion_x_local = (
- torch.permute(fusion_x_local, (0, 2, 1, 3))
- .contiguous()
- .flatten(2)
- )
- if fusion_x_local.size(-1) < FT:
- fusion_x_local = torch.cat(
- [
- fusion_x_local,
- torch.zeros(
- (FB, FF, FT - fusion_x_local.size(-1)),
- device=device,
- ),
- ],
- dim=-1,
- )
- else:
- fusion_x_local = fusion_x_local[:, :, :FT]
- # 1D fusion
- new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
- new_x[longer_list_idx] = self.fusion_model(
- new_x[longer_list_idx], fusion_x_local
- )
- x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
- else:
- x = new_x
-
- elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
- x = x # no change
-
- if self.training:
- x = self.spec_augmenter(x)
- if self.training and mixup_lambda is not None:
- x = do_mixup(x, mixup_lambda)
-
- x = self.reshape_wav2img(x)
- output_dict = self.forward_features(x, longer_idx=longer_list_idx)
-
- # if infer_mode:
- # # in infer mode. we need to handle different length audio input
- # frame_num = x.shape[2]
- # target_T = int(self.spec_size * self.freq_ratio)
- # repeat_ratio = math.floor(target_T / frame_num)
- # x = x.repeat(repeats=(1,1,repeat_ratio,1))
- # x = self.reshape_wav2img(x)
- # output_dict = self.forward_features(x)
- # else:
- # if x.shape[2] > self.freq_ratio * self.spec_size:
- # if self.training:
- # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
- # x = self.reshape_wav2img(x)
- # output_dict = self.forward_features(x)
- # else:
- # # Change: Hard code here
- # overlap_size = (x.shape[2] - 1) // 4
- # output_dicts = []
- # crop_size = (x.shape[2] - 1) // 2
- # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
- # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
- # tx = self.reshape_wav2img(tx)
- # output_dicts.append(self.forward_features(tx))
- # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
- # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
- # for d in output_dicts:
- # clipwise_output += d["clipwise_output"]
- # framewise_output += d["framewise_output"]
- # clipwise_output = clipwise_output / len(output_dicts)
- # framewise_output = framewise_output / len(output_dicts)
- # output_dict = {
- # 'framewise_output': framewise_output,
- # 'clipwise_output': clipwise_output
- # }
- # else: # this part is typically used, and most easy one
- # x = self.reshape_wav2img(x)
- # output_dict = self.forward_features(x)
- # x = self.head(x)
-
-        # The data is already processed in the dataloader, so here we only consider the case input_T < fixed_T
-
- return output_dict
-
-
-def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"):
- try:
-
- assert audio_cfg.model_name in [
- "tiny",
- "base",
- "large",
-        ], "model name for HTS-AT must be one of 'tiny', 'base', 'large'"
- if audio_cfg.model_name == "tiny":
- model = HTSAT_Swin_Transformer(
- spec_size=256,
- patch_size=4,
- patch_stride=(4, 4),
- num_classes=audio_cfg.class_num,
- embed_dim=96,
- depths=[2, 2, 6, 2],
- num_heads=[4, 8, 16, 32],
- window_size=8,
- config=audio_cfg,
- enable_fusion=enable_fusion,
- fusion_type=fusion_type,
- )
- elif audio_cfg.model_name == "base":
- model = HTSAT_Swin_Transformer(
- spec_size=256,
- patch_size=4,
- patch_stride=(4, 4),
- num_classes=audio_cfg.class_num,
- embed_dim=128,
- depths=[2, 2, 12, 2],
- num_heads=[4, 8, 16, 32],
- window_size=8,
- config=audio_cfg,
- enable_fusion=enable_fusion,
- fusion_type=fusion_type,
- )
- elif audio_cfg.model_name == "large":
- model = HTSAT_Swin_Transformer(
- spec_size=256,
- patch_size=4,
- patch_stride=(4, 4),
- num_classes=audio_cfg.class_num,
- embed_dim=256,
- depths=[2, 2, 12, 2],
- num_heads=[4, 8, 16, 32],
- window_size=8,
- config=audio_cfg,
- enable_fusion=enable_fusion,
- fusion_type=fusion_type,
- )
-
- return model
-    except Exception as e:
-        raise RuntimeError(
-            f"Model for {audio_cfg.model_name} could not be created, or the audio cfg parameters are incomplete."
-        ) from e
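-
-# Minimal usage sketch for create_htsat_model; the config object and its field values
-# below are illustrative assumptions (they mirror the CLAPAudioCfp defaults), and any
-# namespace exposing these attributes works.
-# >>> from types import SimpleNamespace
-# >>> cfg = SimpleNamespace(model_name="tiny", sample_rate=48000, window_size=1024,
-# ...                       hop_size=1024, mel_bins=64, fmin=50, fmax=14000, class_num=527)
-# >>> model = create_htsat_model(cfg, enable_fusion=False)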
diff --git a/audioldm/clap/open_clip/linear_probe.py b/audioldm/clap/open_clip/linear_probe.py
deleted file mode 100644
index 9d7e23b6b67a53e16d050d675a99d01d7d04d581..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/linear_probe.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import numpy as np
-import torch.nn.functional as F
-from torch import nn
-from .model import MLPLayers
-
-
-class LinearProbe(nn.Module):
- def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
- """
- Args:
- model: nn.Module
- mlp: bool, if True, then use the MLP layer as the linear probe module
-            freeze: bool, if True, then freeze all the CLAP model's layers when training the linear probe
- in_ch: int, the output channel from CLAP model
- out_ch: int, the output channel from linear probe (class_num)
- act: torch.nn.functional, the activation function before the loss function
- """
- super().__init__()
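-        # NOTE: in_ch is hard-coded to 512 (the CLAP embedding size) on the next line,
-        # overriding whatever value was passed in.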
- in_ch = 512
- self.clap_model = model
- self.clap_model.text_branch = None # to save memory
- self.freeze = freeze
- if mlp:
- self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
- else:
- self.lp_layer = nn.Linear(in_ch, out_ch)
-
- if self.freeze:
- for param in self.clap_model.parameters():
- param.requires_grad = False
-
- if act == "None":
- self.act = None
- elif act == "relu":
- self.act = nn.ReLU()
- elif act == "elu":
- self.act = nn.ELU()
- elif act == "prelu":
- self.act = nn.PReLU(num_parameters=in_ch)
- elif act == "softmax":
- self.act = nn.Softmax(dim=-1)
- elif act == "sigmoid":
- self.act = nn.Sigmoid()
-
- def forward(self, x, mix_lambda=None, device=None):
- """
- Args:
- x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list
- mix_lambda: torch.tensor [batch], the mixup lambda
- Returns:
- class_prob: torch.tensor [batch, class_num]
-
- """
-        # put the frozen CLAP backbone in eval mode so that BatchNorm running statistics are not updated
- if self.freeze:
- self.clap_model.eval()
-
- x = self.clap_model.audio_projection(
- self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[
- "embedding"
- ]
- )
- out = self.lp_layer(x)
- if self.act is not None:
- out = self.act(out)
- return out
diff --git a/audioldm/clap/open_clip/loss.py b/audioldm/clap/open_clip/loss.py
deleted file mode 100644
index cc66298a14997da4aa2efc71e37c0a6bcda53fd1..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/loss.py
+++ /dev/null
@@ -1,398 +0,0 @@
-import torch
-import torch.distributed.nn
-from torch import distributed as dist, nn as nn
-from torch.nn import functional as F
-import numpy as np
-from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
-
-try:
- import horovod.torch as hvd
-except ImportError:
- hvd = None
-
-
-def gather_features(
- audio_features,
- text_features,
- audio_features_mlp=None,
- text_features_mlp=None,
- local_loss=False,
- gather_with_grad=False,
- rank=0,
- world_size=1,
- use_horovod=False,
- mlp_loss=False,
-):
- if use_horovod:
- assert hvd is not None, "Please install horovod"
- if gather_with_grad:
- all_audio_features = hvd.allgather(audio_features)
- all_text_features = hvd.allgather(text_features)
- if mlp_loss:
- all_audio_features_mlp = hvd.allgather(audio_features_mlp)
- all_text_features_mlp = hvd.allgather(text_features_mlp)
- else:
- with torch.no_grad():
- all_audio_features = hvd.allgather(audio_features)
- all_text_features = hvd.allgather(text_features)
- if mlp_loss:
- all_audio_features_mlp = hvd.allgather(audio_features_mlp)
- all_text_features_mlp = hvd.allgather(text_features_mlp)
- if not local_loss:
- # ensure grads for local rank when all_* features don't have a gradient
- gathered_audio_features = list(
- all_audio_features.chunk(world_size, dim=0)
- )
- gathered_text_features = list(
- all_text_features.chunk(world_size, dim=0)
- )
- gathered_audio_features[rank] = audio_features
- gathered_text_features[rank] = text_features
- all_audio_features = torch.cat(gathered_audio_features, dim=0)
- all_text_features = torch.cat(gathered_text_features, dim=0)
- if mlp_loss:
- gathered_audio_features_mlp = list(
- all_audio_features_mlp.chunk(world_size, dim=0)
- )
- gathered_text_features_mlp = list(
- all_text_features_mlp.chunk(world_size, dim=0)
- )
- gathered_audio_features_mlp[rank] = audio_features_mlp
- gathered_text_features_mlp[rank] = text_features_mlp
- all_audio_features_mlp = torch.cat(
- gathered_audio_features_mlp, dim=0
- )
- all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
- else:
- # We gather tensors from all gpus
- if gather_with_grad:
- all_audio_features = torch.cat(
- torch.distributed.nn.all_gather(audio_features), dim=0
- )
- all_text_features = torch.cat(
- torch.distributed.nn.all_gather(text_features), dim=0
- )
- if mlp_loss:
- all_audio_features_mlp = torch.cat(
- torch.distributed.nn.all_gather(audio_features_mlp), dim=0
- )
- all_text_features_mlp = torch.cat(
- torch.distributed.nn.all_gather(text_features_mlp), dim=0
- )
- else:
- gathered_audio_features = [
- torch.zeros_like(audio_features) for _ in range(world_size)
- ]
- gathered_text_features = [
- torch.zeros_like(text_features) for _ in range(world_size)
- ]
- dist.all_gather(gathered_audio_features, audio_features)
- dist.all_gather(gathered_text_features, text_features)
- if mlp_loss:
- gathered_audio_features_mlp = [
- torch.zeros_like(audio_features_mlp) for _ in range(world_size)
- ]
- gathered_text_features_mlp = [
- torch.zeros_like(text_features_mlp) for _ in range(world_size)
- ]
- dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
- dist.all_gather(gathered_text_features_mlp, text_features_mlp)
- if not local_loss:
- # ensure grads for local rank when all_* features don't have a gradient
- gathered_audio_features[rank] = audio_features
- gathered_text_features[rank] = text_features
- if mlp_loss:
- gathered_audio_features_mlp[rank] = audio_features_mlp
- gathered_text_features_mlp[rank] = text_features_mlp
-
- all_audio_features = torch.cat(gathered_audio_features, dim=0)
- all_text_features = torch.cat(gathered_text_features, dim=0)
- if mlp_loss:
- all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
- all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
- if mlp_loss:
- return (
- all_audio_features,
- all_text_features,
- all_audio_features_mlp,
- all_text_features_mlp,
- )
- else:
- return all_audio_features, all_text_features
-
-
-class ClipLoss(nn.Module):
- def __init__(
- self,
- local_loss=False,
- gather_with_grad=False,
- cache_labels=False,
- rank=0,
- world_size=1,
- use_horovod=False,
- mlp_loss=False,
- weight_loss_kappa=0,
- ):
- super().__init__()
- self.local_loss = local_loss
- self.gather_with_grad = gather_with_grad
- self.cache_labels = cache_labels
- self.rank = rank
- self.world_size = world_size
- self.use_horovod = use_horovod
- self.mlp_loss = mlp_loss
- self.weighted_loss = bool(weight_loss_kappa != 0)
- self.weight_loss_kappa = weight_loss_kappa
- # cache state
- self.prev_num_logits = 0
- self.labels = {}
-
- def forward(
- self,
- audio_features,
- text_features,
- logit_scale_a,
- logit_scale_t=None,
- audio_features_mlp=None,
- text_features_mlp=None,
- ):
- device = audio_features.device
- if self.mlp_loss:
- if self.world_size > 1:
- (
- all_audio_features,
- all_text_features,
- all_audio_features_mlp,
- all_text_features_mlp,
- ) = gather_features(
- audio_features=audio_features,
- text_features=text_features,
- audio_features_mlp=audio_features_mlp,
- text_features_mlp=text_features_mlp,
- local_loss=self.local_loss,
- gather_with_grad=self.gather_with_grad,
- rank=self.rank,
- world_size=self.world_size,
- use_horovod=self.use_horovod,
- mlp_loss=self.mlp_loss,
- )
- if self.local_loss:
- a_logits_per_audio = (
- logit_scale_a * audio_features @ all_text_features_mlp.T
- )
- a_logits_per_text = (
- logit_scale_a * text_features_mlp @ all_audio_features.T
- )
- t_logits_per_audio = (
- logit_scale_t * audio_features_mlp @ all_text_features.T
- )
- t_logits_per_text = (
- logit_scale_t * text_features @ all_audio_features_mlp.T
- )
- else:
- a_logits_per_audio = (
- logit_scale_a * all_audio_features @ all_text_features_mlp.T
- )
- a_logits_per_text = a_logits_per_audio.T
- t_logits_per_audio = (
- logit_scale_t * all_audio_features_mlp @ all_text_features.T
- )
- t_logits_per_text = t_logits_per_audio.T
- else:
- a_logits_per_audio = (
- logit_scale_a * audio_features @ text_features_mlp.T
- )
- a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
- t_logits_per_audio = (
- logit_scale_t * audio_features_mlp @ text_features.T
- )
- t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T
-
-            # calculate the ground-truth labels and cache them if enabled
- num_logits = a_logits_per_audio.shape[0]
- if self.prev_num_logits != num_logits or device not in self.labels:
- labels = torch.arange(num_logits, device=device, dtype=torch.long)
- if self.world_size > 1 and self.local_loss:
- labels = labels + num_logits * self.rank
- if self.cache_labels:
- self.labels[device] = labels
- self.prev_num_logits = num_logits
- else:
- labels = self.labels[device]
-
- if not self.weighted_loss:
- total_loss = (
- F.cross_entropy(a_logits_per_audio, labels)
- + F.cross_entropy(a_logits_per_text, labels)
- + F.cross_entropy(t_logits_per_audio, labels)
- + F.cross_entropy(t_logits_per_text, labels)
- ) / 4
- else:
- audio_weight = (audio_features @ audio_features.T).detach()
- audio_weight = (
- torch.exp(
- torch.sum(audio_weight, axis=1)
- / (self.weight_loss_kappa * len(audio_weight))
- )
- ).detach()
- text_weight = (text_features @ text_features.T).detach()
- text_weight = (
- torch.exp(
- torch.sum(text_weight, axis=1)
- / (self.weight_loss_kappa * len(text_features))
- )
- ).detach()
- total_loss = (
- F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
- + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
- + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
- + F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
- ) / 4
- else:
- if self.world_size > 1:
- all_audio_features, all_text_features = gather_features(
- audio_features=audio_features,
- text_features=text_features,
- local_loss=self.local_loss,
- gather_with_grad=self.gather_with_grad,
- rank=self.rank,
- world_size=self.world_size,
- use_horovod=self.use_horovod,
- mlp_loss=self.mlp_loss,
- )
-
- if self.local_loss:
- logits_per_audio = (
- logit_scale_a * audio_features @ all_text_features.T
- )
- logits_per_text = (
- logit_scale_a * text_features @ all_audio_features.T
- )
- else:
- logits_per_audio = (
- logit_scale_a * all_audio_features @ all_text_features.T
- )
- logits_per_text = logits_per_audio.T
- else:
- logits_per_audio = logit_scale_a * audio_features @ text_features.T
- logits_per_text = logit_scale_a * text_features @ audio_features.T
-
-            # calculate the ground-truth labels and cache them if enabled
- num_logits = logits_per_audio.shape[0]
- if self.prev_num_logits != num_logits or device not in self.labels:
- labels = torch.arange(num_logits, device=device, dtype=torch.long)
- if self.world_size > 1 and self.local_loss:
- labels = labels + num_logits * self.rank
- if self.cache_labels:
- self.labels[device] = labels
- self.prev_num_logits = num_logits
- else:
- labels = self.labels[device]
- if not self.weighted_loss:
- total_loss = (
- F.cross_entropy(logits_per_audio, labels)
- + F.cross_entropy(logits_per_text, labels)
- ) / 2
- else:
- audio_weight = (all_audio_features @ all_audio_features.T).detach()
- audio_weight = (
- torch.exp(
- torch.sum(audio_weight, axis=1)
- / (self.weight_loss_kappa * len(all_audio_features))
- )
- ).detach()
- text_weight = (all_text_features @ all_text_features.T).detach()
- text_weight = (
- torch.exp(
- torch.sum(text_weight, axis=1)
- / (self.weight_loss_kappa * len(all_text_features))
- )
- ).detach()
- total_loss = (
- F.cross_entropy(logits_per_audio, labels, weight=text_weight)
- + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
- ) / 2
- return total_loss
-
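-# Illustrative single-process usage of ClipLoss above (tensor sizes and the logit
-# scale are example assumptions): the targets are simply arange(batch_size) because
-# the i-th audio and i-th text in a batch form the only positive pair.
-# >>> loss_fn = ClipLoss()
-# >>> audio = F.normalize(torch.randn(8, 512), dim=-1)
-# >>> text = F.normalize(torch.randn(8, 512), dim=-1)
-# >>> loss = loss_fn(audio, text, logit_scale_a=torch.tensor(100.0))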
-
-def lp_gather_features(pred, target, world_size=1, use_horovod=False):
- if use_horovod:
- assert hvd is not None, "Please install horovod"
- with torch.no_grad():
- all_preds = hvd.allgather(pred)
-            all_targets = hvd.allgather(target)
- else:
- gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
- gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
-
- dist.all_gather(gathered_preds, pred)
- dist.all_gather(gathered_targets, target)
- all_preds = torch.cat(gathered_preds, dim=0)
- all_targets = torch.cat(gathered_targets, dim=0)
-
- return all_preds, all_targets
-
-
-def get_map(pred, target):
- pred = torch.sigmoid(pred).numpy()
- target = target.numpy()
- return np.mean(average_precision_score(target, pred, average=None))
-
-
-def get_acc(pred, target):
- pred = torch.argmax(pred, 1).numpy()
- target = torch.argmax(target, 1).numpy()
- return accuracy_score(target, pred)
-
-
-def get_mauc(pred, target):
- pred = torch.sigmoid(pred).numpy()
- target = target.numpy()
- return np.mean(roc_auc_score(target, pred, average=None))
-
-
-class LPMetrics(object):
- def __init__(self, metric_names=["map", "acc", "mauc"]):
- self.metrics = []
- for name in metric_names:
- self.metrics.append(self.get_metric(name))
- self.metric_names = metric_names
-
- def get_metric(self, name):
- if name == "map":
- return get_map
- elif name == "acc":
- return get_acc
- elif name == "mauc":
- return get_mauc
- else:
-            raise ValueError("the metric should be one of [map, acc, mauc]")
-
- def evaluate_mertics(self, pred, target):
- metric_dict = {}
- for i in range(len(self.metric_names)):
- metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
- return metric_dict
-
-
-def calc_celoss(pred, target):
- target = torch.argmax(target, 1).long()
- return nn.CrossEntropyLoss()(pred, target)
-
-
-class LPLoss(nn.Module):
- def __init__(self, loss_name):
- super().__init__()
- if loss_name == "bce":
- self.loss_func = nn.BCEWithLogitsLoss()
- elif loss_name == "ce":
- self.loss_func = calc_celoss
- elif loss_name == "mse":
- self.loss_func = nn.MSELoss()
- else:
-            raise ValueError("the loss func should be one of [bce, ce, mse]")
-
- def forward(self, pred, target):
- loss = self.loss_func(pred, target)
- return loss
diff --git a/audioldm/clap/open_clip/model.py b/audioldm/clap/open_clip/model.py
deleted file mode 100644
index f41e6d6d0b0bbecacb90744928a516b75d218214..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model.py
+++ /dev/null
@@ -1,936 +0,0 @@
-""" CLAP Model
-
-Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
-Adapted to the Audio Task.
-"""
-
-from collections import OrderedDict
-from dataclasses import dataclass
-from typing import Tuple, Union, Callable, Optional
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-from .timm_model import TimmModel
-import logging
-from .utils import freeze_batch_norm_2d
-
-from .pann_model import create_pann_model
-from .htsat import create_htsat_model
-from transformers import BertModel, RobertaModel, BartModel
-from transformers.tokenization_utils_base import BatchEncoding
-
-
-class MLPLayers(nn.Module):
- def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
- super(MLPLayers, self).__init__()
- self.nonlin = nonlin
- self.dropout = dropout
-
- sequence = []
- for u0, u1 in zip(units[:-1], units[1:]):
- sequence.append(nn.Linear(u0, u1))
- sequence.append(self.nonlin)
- sequence.append(nn.Dropout(self.dropout))
- sequence = sequence[:-2]
-
- self.sequential = nn.Sequential(*sequence)
-
- def forward(self, X):
- X = self.sequential(X)
- return X
-
-
-class Bottleneck(nn.Module):
- expansion = 4
-
- def __init__(self, inplanes, planes, stride=1):
- super().__init__()
-
- # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
- self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
- self.bn1 = nn.BatchNorm2d(planes)
-
- self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
- self.bn2 = nn.BatchNorm2d(planes)
-
- self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
-
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
-
- self.relu = nn.ReLU(inplace=True)
- self.downsample = None
- self.stride = stride
-
- if stride > 1 or inplanes != planes * Bottleneck.expansion:
- # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
- self.downsample = nn.Sequential(
- OrderedDict(
- [
- ("-1", nn.AvgPool2d(stride)),
- (
- "0",
- nn.Conv2d(
- inplanes,
- planes * self.expansion,
- 1,
- stride=1,
- bias=False,
- ),
- ),
- ("1", nn.BatchNorm2d(planes * self.expansion)),
- ]
- )
- )
-
- def forward(self, x: torch.Tensor):
- identity = x
-
- out = self.relu(self.bn1(self.conv1(x)))
- out = self.relu(self.bn2(self.conv2(out)))
- out = self.avgpool(out)
- out = self.bn3(self.conv3(out))
-
- if self.downsample is not None:
- identity = self.downsample(x)
-
- out += identity
- out = self.relu(out)
- return out
-
-
-class AttentionPool2d(nn.Module):
- def __init__(
- self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
- ):
- super().__init__()
- self.positional_embedding = nn.Parameter(
- torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
- )
- self.k_proj = nn.Linear(embed_dim, embed_dim)
- self.q_proj = nn.Linear(embed_dim, embed_dim)
- self.v_proj = nn.Linear(embed_dim, embed_dim)
- self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
- self.num_heads = num_heads
-
- def forward(self, x):
- x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
- 2, 0, 1
- ) # NCHW -> (HW)NC
- x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
- x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
- x, _ = F.multi_head_attention_forward(
- query=x,
- key=x,
- value=x,
- embed_dim_to_check=x.shape[-1],
- num_heads=self.num_heads,
- q_proj_weight=self.q_proj.weight,
- k_proj_weight=self.k_proj.weight,
- v_proj_weight=self.v_proj.weight,
- in_proj_weight=None,
- in_proj_bias=torch.cat(
- [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
- ),
- bias_k=None,
- bias_v=None,
- add_zero_attn=False,
- dropout_p=0,
- out_proj_weight=self.c_proj.weight,
- out_proj_bias=self.c_proj.bias,
- use_separate_proj_weight=True,
- training=self.training,
- need_weights=False,
- )
-
- return x[0]
-
-
-class ModifiedResNet(nn.Module):
- """
- A ResNet class that is similar to torchvision's but contains the following changes:
- - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- - The final pooling layer is a QKV attention instead of an average pool
- """
-
- def __init__(self, layers, output_dim, heads, image_size=224, width=64):
- super().__init__()
- self.output_dim = output_dim
- self.image_size = image_size
-
- # the 3-layer stem
- self.conv1 = nn.Conv2d(
- 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
- )
- self.bn1 = nn.BatchNorm2d(width // 2)
- self.conv2 = nn.Conv2d(
- width // 2, width // 2, kernel_size=3, padding=1, bias=False
- )
- self.bn2 = nn.BatchNorm2d(width // 2)
- self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
- self.bn3 = nn.BatchNorm2d(width)
- self.avgpool = nn.AvgPool2d(2)
- self.relu = nn.ReLU(inplace=True)
-
- # residual layers
- self._inplanes = width # this is a *mutable* variable used during construction
- self.layer1 = self._make_layer(width, layers[0])
- self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
- self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
- self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
-
- embed_dim = width * 32 # the ResNet feature dimension
- self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
-
- self.init_parameters()
-
- def _make_layer(self, planes, blocks, stride=1):
- layers = [Bottleneck(self._inplanes, planes, stride)]
-
- self._inplanes = planes * Bottleneck.expansion
- for _ in range(1, blocks):
- layers.append(Bottleneck(self._inplanes, planes))
-
- return nn.Sequential(*layers)
-
- def init_parameters(self):
- if self.attnpool is not None:
- std = self.attnpool.c_proj.in_features**-0.5
- nn.init.normal_(self.attnpool.q_proj.weight, std=std)
- nn.init.normal_(self.attnpool.k_proj.weight, std=std)
- nn.init.normal_(self.attnpool.v_proj.weight, std=std)
- nn.init.normal_(self.attnpool.c_proj.weight, std=std)
-
- for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
- for name, param in resnet_block.named_parameters():
- if name.endswith("bn3.weight"):
- nn.init.zeros_(param)
-
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
- assert (
- unlocked_groups == 0
- ), "partial locking not currently supported for this model"
- for param in self.parameters():
- param.requires_grad = False
- if freeze_bn_stats:
- freeze_batch_norm_2d(self)
-
- def stem(self, x):
- for conv, bn in [
- (self.conv1, self.bn1),
- (self.conv2, self.bn2),
- (self.conv3, self.bn3),
- ]:
- x = self.relu(bn(conv(x)))
- x = self.avgpool(x)
- return x
-
- def forward(self, x):
- x = self.stem(x)
- x = self.layer1(x)
- x = self.layer2(x)
- x = self.layer3(x)
- x = self.layer4(x)
- x = self.attnpool(x)
-
- return x
-
-
-class LayerNorm(nn.LayerNorm):
- """Subclass torch's LayerNorm to handle fp16."""
-
- def forward(self, x: torch.Tensor):
- orig_type = x.dtype
- x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
- return x.to(orig_type)
-
-
-class QuickGELU(nn.Module):
- # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
- def forward(self, x: torch.Tensor):
- return x * torch.sigmoid(1.702 * x)
-
-
-class ResidualAttentionBlock(nn.Module):
- def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
- super().__init__()
-
- self.attn = nn.MultiheadAttention(d_model, n_head)
- self.ln_1 = LayerNorm(d_model)
- self.mlp = nn.Sequential(
- OrderedDict(
- [
- ("c_fc", nn.Linear(d_model, d_model * 4)),
- ("gelu", act_layer()),
- ("c_proj", nn.Linear(d_model * 4, d_model)),
- ]
- )
- )
- self.ln_2 = LayerNorm(d_model)
-
- def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
- return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
-
- def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
- x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
- x = x + self.mlp(self.ln_2(x))
- return x
-
-
-class Transformer(nn.Module):
- def __init__(
- self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
- ):
- super().__init__()
- self.width = width
- self.layers = layers
- self.resblocks = nn.ModuleList(
- [
- ResidualAttentionBlock(width, heads, act_layer=act_layer)
- for _ in range(layers)
- ]
- )
-
- def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
- for r in self.resblocks:
- x = r(x, attn_mask=attn_mask)
- return x
-
-
-class VisualTransformer(nn.Module):
- def __init__(
- self,
- image_size: int,
- patch_size: int,
- width: int,
- layers: int,
- heads: int,
- output_dim: int,
- act_layer: Callable = nn.GELU,
- ):
- super().__init__()
- self.image_size = image_size
- self.output_dim = output_dim
- self.conv1 = nn.Conv2d(
- in_channels=3,
- out_channels=width,
- kernel_size=patch_size,
- stride=patch_size,
- bias=False,
- )
-
- scale = width**-0.5
- self.class_embedding = nn.Parameter(scale * torch.randn(width))
- self.positional_embedding = nn.Parameter(
- scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
- )
- self.ln_pre = LayerNorm(width)
-
- # note: despite the name, this attribute holds the visual transformer stack
- self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
-
- self.ln_post = LayerNorm(width)
- self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
-
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
- assert (
- unlocked_groups == 0
- ), "partial locking not currently supported for this model"
- for param in self.parameters():
- param.requires_grad = False
-
- def forward(self, x: torch.Tensor):
- x = self.conv1(x) # shape = [*, width, grid, grid]
- x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
- x = torch.cat(
- [
- self.class_embedding.to(x.dtype)
- + torch.zeros(
- x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
- ),
- x,
- ],
- dim=1,
- ) # shape = [*, grid ** 2 + 1, width]
- x = x + self.positional_embedding.to(x.dtype)
- x = self.ln_pre(x)
-
- x = x.permute(1, 0, 2) # NLD -> LND
- x = self.text_branch(x)
- x = x.permute(1, 0, 2) # LND -> NLD
-
- x = self.ln_post(x[:, 0, :])
-
- if self.proj is not None:
- x = x @ self.proj
-
- return x
-
-
-@dataclass
-class CLAPVisionCfg:
- layers: Union[Tuple[int, int, int, int], int] = 12
- width: int = 768
- patch_size: int = 16
- image_size: Union[Tuple[int, int], int] = 224
- timm_model_name: str = (
- None # a valid model name overrides layers, width, patch_size
- )
- timm_model_pretrained: bool = (
- False # use (imagenet) pretrained weights for named model
- )
- timm_pool: str = (
- "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
- )
- timm_proj: str = (
- "linear" # linear projection for timm model output ('linear', 'mlp', '')
- )
-
-
-# Audio Config Class
-@dataclass
-class CLAPAudioCfp:
- model_type: str = "PANN"
- model_name: str = "Cnn14"
- sample_rate: int = 48000
- # Param
- audio_length: int = 1024
- window_size: int = 1024
- hop_size: int = 1024
- fmin: int = 50
- fmax: int = 14000
- class_num: int = 527
- mel_bins: int = 64
- clip_samples: int = 480000
-
-
-@dataclass
-class CLAPTextCfg:
- context_length: int
- vocab_size: int
- width: int
- heads: int
- layers: int
- model_type: str
-
-
-class CLAP(nn.Module):
- def __init__(
- self,
- embed_dim: int,
- audio_cfg: CLAPAudioCfp,
- text_cfg: CLAPTextCfg,
- quick_gelu: bool = False,
- enable_fusion: bool = False,
- fusion_type: str = "None",
- joint_embed_shape: int = 512,
- mlp_act: str = "relu",
- ):
- super().__init__()
- if isinstance(audio_cfg, dict):
- audio_cfg = CLAPAudioCfp(**audio_cfg)
- if isinstance(text_cfg, dict):
- text_cfg = CLAPTextCfg(**text_cfg)
-
- self.audio_cfg = audio_cfg
- self.text_cfg = text_cfg
- self.enable_fusion = enable_fusion
- self.fusion_type = fusion_type
- self.joint_embed_shape = joint_embed_shape
- self.mlp_act = mlp_act
-
- self.context_length = text_cfg.context_length
-
- # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
- # memory efficient in recent PyTorch releases (>= 1.10).
- # NOTE: timm models always use native GELU regardless of quick_gelu flag.
- act_layer = QuickGELU if quick_gelu else nn.GELU
-
- if mlp_act == "relu":
- mlp_act_layer = nn.ReLU()
- elif mlp_act == "gelu":
- mlp_act_layer = nn.GELU()
- else:
- raise NotImplementedError
-
- # audio branch
- # audio branch parameters
- if audio_cfg.model_type == "PANN":
- self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
- elif audio_cfg.model_type == "HTSAT":
- self.audio_branch = create_htsat_model(
- audio_cfg, enable_fusion, fusion_type
- )
- else:
- logging.error(f"Model config for {audio_cfg.model_type} not found")
- raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
-
- # text branch
- # text branch parameters
- if text_cfg.model_type == "transformer":
- self.text_branch = Transformer(
- width=text_cfg.width,
- layers=text_cfg.layers,
- heads=text_cfg.heads,
- act_layer=act_layer,
- )
- self.vocab_size = text_cfg.vocab_size
- self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
- self.positional_embedding = nn.Parameter(
- torch.empty(self.context_length, text_cfg.width)
- )
- self.ln_final = LayerNorm(text_cfg.width)
- self.text_transform = MLPLayers(
- units=[
- self.joint_embed_shape,
- self.joint_embed_shape,
- self.joint_embed_shape,
- ],
- dropout=0.1,
- )
- self.text_projection = nn.Sequential(
- nn.Linear(text_cfg.width, self.joint_embed_shape),
- mlp_act_layer,
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
- )
- elif text_cfg.model_type == "bert":
- self.text_branch = BertModel.from_pretrained("bert-base-uncased")
- self.text_transform = MLPLayers(
- units=[
- self.joint_embed_shape,
- self.joint_embed_shape,
- self.joint_embed_shape,
- ],
- dropout=0.1,
- )
- self.text_projection = nn.Sequential(
- nn.Linear(768, self.joint_embed_shape),
- mlp_act_layer,
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
- )
- elif text_cfg.model_type == "roberta":
- self.text_branch = RobertaModel.from_pretrained("roberta-base")
- self.text_transform = MLPLayers(
- units=[
- self.joint_embed_shape,
- self.joint_embed_shape,
- self.joint_embed_shape,
- ],
- dropout=0.1,
- )
- self.text_projection = nn.Sequential(
- nn.Linear(768, self.joint_embed_shape),
- mlp_act_layer,
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
- )
- elif text_cfg.model_type == "bart":
- self.text_branch = BartModel.from_pretrained("facebook/bart-base")
- self.text_transform = MLPLayers(
- units=[
- self.joint_embed_shape,
- self.joint_embed_shape,
- self.joint_embed_shape,
- ],
- dropout=0.1,
- )
- self.text_projection = nn.Sequential(
- nn.Linear(768, self.joint_embed_shape),
- mlp_act_layer,
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
- )
- else:
- logging.error(f"Model config for {text_cfg.model_type} not found")
- raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
- self.text_branch_type = text_cfg.model_type
- # text branch parameters
-
- # audio branch parameters
- self.audio_transform = MLPLayers(
- units=[
- self.joint_embed_shape,
- self.joint_embed_shape,
- self.joint_embed_shape,
- ],
- dropout=0.1,
- )
-
- # below here are the audio projection head, the logit scales and the attention-mask buffer
-
- # ============================================================================================================
- self.audio_projection = nn.Sequential(
- nn.Linear(embed_dim, self.joint_embed_shape),
- mlp_act_layer,
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
- )
-
- self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
- self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
- self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
-
- self.init_text_branch_parameters()
-
- def init_text_branch_parameters(self):
- if self.text_branch_type == "transformer":
- nn.init.normal_(self.token_embedding.weight, std=0.02)
- nn.init.normal_(self.positional_embedding, std=0.01)
- proj_std = (self.text_branch.width**-0.5) * (
- (2 * self.text_branch.layers) ** -0.5
- )
- attn_std = self.text_branch.width**-0.5
- fc_std = (2 * self.text_branch.width) ** -0.5
- for block in self.text_branch.resblocks:
- nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
- nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
- nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
- nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
- if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
- width = self.text_branch.embeddings.word_embeddings.weight.shape[-1]
- elif self.text_branch_type == "bart":
- width = self.text_branch.shared.weight.shape[-1]
- else:
- width = self.text_branch.width
- nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
- nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
-
- # deprecated
- # if hasattr(self.visual, 'init_parameters'):
- # self.visual.init_parameters()
-
- # if self.text_projection is not None:
- # nn.init.normal_(self.text_projection, std=width**-0.5)
-
- def build_attention_mask(self):
- # causal attention mask for the text transformer
- # pytorch uses an additive attention mask; fill with -inf and keep only the upper triangle
- mask = torch.empty(self.context_length, self.context_length)
- mask.fill_(float("-inf"))
- mask.triu_(1) # zero out the diagonal and the lower triangle
- return mask
-
- def encode_audio(self, audio, device):
- return self.audio_branch(
- audio, mixup_lambda=None, device=device
- ) # mixup_lambda is left as None here, i.e. no mixup is applied
-
- # def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
- # tmp = {}
- # for k in x[0].keys():
- # tmp[k] = []
- # for i in range(len(x)):
- # tmp[k].append(x[i][k][:77])
- # for k in x[0].keys():
- # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
- # return tmp
-
- def encode_text(self, text, device):
- if self.text_branch_type == "transformer":
- text = text.to(device=device, non_blocking=True)
- x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
-
- x = x + self.positional_embedding
- x = x.permute(1, 0, 2) # NLD -> LND
- x = self.text_branch(x, attn_mask=self.attn_mask)
- x = x.permute(1, 0, 2) # LND -> NLD
- x = self.ln_final(x)
-
- # x.shape = [batch_size, n_ctx, transformer.width]
- # take features from the eot embedding (eot_token is the highest number in each sequence)
- x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
- elif self.text_branch_type == "bert":
- # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
- # text = BatchEncoding(text)
- x = self.text_branch(
- input_ids=text["input_ids"].to(device=device, non_blocking=True),
- attention_mask=text["attention_mask"].to(
- device=device, non_blocking=True
- ),
- token_type_ids=text["token_type_ids"].to(
- device=device, non_blocking=True
- ),
- )["pooler_output"]
- x = self.text_projection(x)
- elif self.text_branch_type == "roberta":
- x = self.text_branch(
- input_ids=text["input_ids"].to(device=device, non_blocking=True),
- attention_mask=text["attention_mask"].to(
- device=device, non_blocking=True
- ),
- )["pooler_output"]
- x = self.text_projection(x)
- elif self.text_branch_type == "bart":
- x = torch.mean(
- self.text_branch(
- input_ids=text["input_ids"].to(device=device, non_blocking=True),
- attention_mask=text["attention_mask"].to(
- device=device, non_blocking=True
- ),
- )["encoder_last_hidden_state"],
- dim=1,
- )
- x = self.text_projection(x)
- else:
- logging.error(f"Model type {self.text_branch_type} not found")
- raise RuntimeError(f"Model type {self.text_branch_type} not found.")
- return x
-
- def forward(self, audio, text, device=None):
- """Forward audio and text into the CLAP
-
- Parameters
- ----------
- audio: torch.Tensor (batch_size, audio_length) or dict
- the time-domain audio input, or the batched dict of mel_spec features and "longer" flags when fusion is enabled
- text: torch.Tensor or dict
- the tokenized text input; the exact format depends on the text branch type
- """
- if device is None:
- if audio is not None:
- device = audio.device
- elif text is not None:
- device = text.device
- if audio is None and text is None:
- # a hack to get the logit scale
- return self.logit_scale_a.exp(), self.logit_scale_t.exp()
- elif audio is None:
- return self.encode_text(text, device=device)
- elif text is None:
- return self.audio_projection(
- self.encode_audio(audio, device=device)["embedding"]
- )
- audio_features = self.audio_projection(
- self.encode_audio(audio, device=device)["embedding"]
- )
- audio_features = F.normalize(audio_features, dim=-1)
-
- text_features = self.encode_text(text, device=device)
- # print("text_features", text_features)
- # print("text_features.shape", text_features.shape)
- # print("text_features.type", type(text_features))
- text_features = F.normalize(text_features, dim=-1)
-
- audio_features_mlp = self.audio_transform(audio_features)
- text_features_mlp = self.text_transform(text_features)
- # Four outputs: audio features (basic & MLP), text features (basic & MLP)
- return (
- audio_features,
- text_features,
- audio_features_mlp,
- text_features_mlp,
- self.logit_scale_a.exp(),
- self.logit_scale_t.exp(),
- )
-
- def get_logit_scale(self):
- return self.logit_scale_a.exp(), self.logit_scale_t.exp()
-
- def get_text_embedding(self, data):
- """Get the text embedding from the model
-
- Parameters
- ----------
- data: torch.Tensor
- a tensor of text embedding
-
- Returns
- ----------
- text_embed: torch.Tensor
- a tensor of text_embeds (N, D)
-
- """
- device = next(self.parameters()).device
- for k in data:
- data[k] = data[k].to(device)
- if(len(data[k].size()) < 2):
- data[k] = data[k].unsqueeze(0)
- text_embeds = self.encode_text(data, device=device)
- text_embeds = F.normalize(text_embeds, dim=-1)
-
- return text_embeds
-
- def get_audio_embedding(self, data):
- """Get the audio embedding from the model
-
- Parameters
- ----------
- data: a list of dict
- the audio input dict list from 'get_audio_feature' method
-
- Returns
- ----------
- audio_embed: torch.Tensor
- a tensor of audio_embeds (N, D)
-
- """
- device = next(self.parameters()).device
- input_dict = {}
- keys = data[0].keys()
- for k in keys:
- input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(
- device
- )
-
- audio_embeds = self.audio_projection(
- self.encode_audio(input_dict, device=device)["embedding"]
- )
- audio_embeds = F.normalize(audio_embeds, dim=-1)
-
- return audio_embeds
-
- def audio_infer(self, audio, hopsize=None, device=None):
- """Forward one audio and produce the audio embedding
-
- Parameters
- ----------
- audio: (audio_length)
- the time-domain audio input; only a single, unbatched clip is supported
- hopsize: int
- the hop size, in samples, of the sliding window used when the audio is longer than one clip
-
- Returns
- ----------
- output_dict: {
- key: [n, (embedding_shape)] if "HTSAT"
- or
- key: [(embedding_shape)] if "PANN"
- }
- the selected outputs of the audio branch
-
- """
-
- assert not self.training, "audio_infer must be run in eval mode"
- # which audio-branch output to return; "embedding" is assumed here because the
- # original code referenced `key` without ever defining it (a latent NameError)
- key = "embedding"
- output_dict = {}
- # PANN
- if self.audio_cfg.model_type == "PANN":
- audio_input = audio.unsqueeze(dim=0)
- output_dict[key] = self.encode_audio(audio_input, device=device)[
- key
- ].squeeze(dim=0)
- elif self.audio_cfg.model_type == "HTSAT":
- # repeat
- audio_len = len(audio)
- k = self.audio_cfg.clip_samples // audio_len
- if k > 1:
- audio = audio.repeat(k)
- audio_len = len(audio)
-
- if hopsize is None:
- # default to a full-clip hop; the original `min(hopsize, audio_len)` would fail on None
- hopsize = audio_len
- hopsize = min(hopsize, audio_len)
-
- if audio_len > self.audio_cfg.clip_samples:
- audio_input = [
- audio[pos : pos + self.audio_cfg.clip_samples].clone()
- for pos in range(
- 0, audio_len - self.audio_cfg.clip_samples, hopsize
- )
- ]
- audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
- audio_input = torch.stack(audio_input)
- output_dict[key] = self.encode_audio(audio_input, device=device)[key]
- else:
- audio_input = audio.unsqueeze(dim=0)
- output_dict[key] = self.encode_audio(audio_input, device=device)[
- key
- ].squeeze(dim=0)
-
- return output_dict
-
-
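# A short sketch of how the six outputs returned by CLAP.forward() are commonly
# combined downstream: L2-normalized embeddings and the learned logit scale give
# audio-to-text similarity logits. The tensors below are random stand-ins.
import torch
import torch.nn.functional as F

audio_features = F.normalize(torch.randn(4, 512), dim=-1)
text_features = F.normalize(torch.randn(4, 512), dim=-1)
logit_scale_a = torch.tensor(1 / 0.07)  # what logit_scale_a.exp() equals at init

logits_per_audio = logit_scale_a * audio_features @ text_features.t()  # (4, 4)
print(logits_per_audio.softmax(dim=-1).shape)  # per-audio ranking over the 4 texts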
-def convert_weights_to_fp16(model: nn.Module):
- """Convert applicable model parameters to fp16"""
-
- def _convert_weights_to_fp16(l):
- if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
- l.weight.data = l.weight.data.half()
- if l.bias is not None:
- l.bias.data = l.bias.data.half()
-
- if isinstance(l, nn.MultiheadAttention):
- for attr in [
- *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
- "in_proj_bias",
- "bias_k",
- "bias_v",
- ]:
- tensor = getattr(l, attr)
- if tensor is not None:
- tensor.data = tensor.data.half()
-
- for name in ["text_projection", "proj"]:
- if hasattr(l, name):
- attr = getattr(l, name)
- if attr is not None:
- attr.data = attr.data.half()
-
- model.apply(_convert_weights_to_fp16)
-
-
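# A minimal sketch of the model.apply(...) pattern used by convert_weights_to_fp16:
# walk every submodule and cast Linear/Conv weights (and biases) to fp16 in place.
import torch.nn as nn

toy = nn.Sequential(nn.Conv1d(1, 4, 3), nn.ReLU(), nn.Linear(4, 2))

def to_half(module):
    if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
        module.weight.data = module.weight.data.half()
        if module.bias is not None:
            module.bias.data = module.bias.data.half()

toy.apply(to_half)
print({name: p.dtype for name, p in toy.named_parameters()})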
-# Ignore the state dict of the vision part
-def build_model_from_openai_state_dict(
- state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None"
-):
-
- embed_dim = model_cfg["embed_dim"]
- audio_cfg = model_cfg["audio_cfg"]
- text_cfg = model_cfg["text_cfg"]
- # the values below are recovered from the checkpoint for reference only; the
- # actual configuration comes from model_cfg
- context_length = state_dict["positional_embedding"].shape[0]
- vocab_size = state_dict["token_embedding.weight"].shape[0]
- transformer_width = state_dict["ln_final.weight"].shape[0]
- transformer_heads = transformer_width // 64
- transformer_layers = len(
- set(
- k.split(".")[2]
- for k in state_dict
- if k.startswith("transformer.resblocks")
- )
- )
-
- audio_cfg = CLAPAudioCfp(**audio_cfg)
- text_cfg = CLAPTextCfg(**text_cfg)
-
- model = CLAP(
- embed_dim,
- audio_cfg=audio_cfg,
- text_cfg=text_cfg,
- quick_gelu=True, # OpenAI models were trained with QuickGELU
- enable_fusion=enable_fusion,
- fusion_type=fusion_type,
- )
- state_dict["logit_scale_a"] = state_dict["logit_scale"]
- state_dict["logit_scale_t"] = state_dict["logit_scale"]
- pop_keys = list(state_dict.keys())[::]
- # pop the visual branch saved weights
- for key in pop_keys:
- if key.startswith("visual."):
- state_dict.pop(key, None)
-
- for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
- state_dict.pop(key, None)
-
- # fp16 weight conversion is intentionally not applied here
- # convert_weights_to_fp16(model)
- model.load_state_dict(state_dict, strict=False)
- return model.eval()
-
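# A small sketch of the key-inspection trick in build_model_from_openai_state_dict:
# the transformer depth is recovered by counting distinct "transformer.resblocks.<i>."
# prefixes in the checkpoint. The dummy state-dict keys below are illustrative.
state_dict = {
    "transformer.resblocks.0.attn.in_proj_weight": None,
    "transformer.resblocks.0.mlp.c_fc.weight": None,
    "transformer.resblocks.1.attn.in_proj_weight": None,
}
layers = len({k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")})
print(layers)  # 2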
-
-def trace_model(model, batch_size=256, device=torch.device("cpu")):
- model.eval()
- audio_length = model.audio_cfg.audio_length
- example_audio = torch.ones((batch_size, audio_length), device=device)
- example_text = torch.zeros(
- (batch_size, model.context_length), dtype=torch.int, device=device
- )
- model = torch.jit.trace_module(
- model,
- inputs=dict(
- forward=(example_audio, example_text),
- encode_text=(example_text,),
- encode_image=(example_audio,),
- ),
- )
- model.audio_cfg.audio_length = audio_length # keep audio_length accessible on the traced module for downstream code
- return model
diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-base.json b/audioldm/clap/open_clip/model_configs/HTSAT-base.json
deleted file mode 100644
index 6cef625a89daf4431f1c9f72e10bc9640eef2ba8..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/HTSAT-base.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
- "embed_dim": 1024,
- "audio_cfg": {
- "audio_length": 1024,
- "clip_samples": 480000,
- "mel_bins": 64,
- "sample_rate": 48000,
- "window_size": 1024,
- "hop_size": 480,
- "fmin": 50,
- "fmax": 14000,
- "class_num": 527,
- "model_type": "HTSAT",
- "model_name": "base"
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
\ No newline at end of file
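# A hedged sketch of how these deleted JSON configs were consumed: CLAP.__init__
# above accepts plain dicts for audio_cfg/text_cfg and wraps them in the dataclasses.
# The file path here is illustrative.
import json

with open("HTSAT-base.json") as f:
    cfg = json.load(f)

embed_dim = cfg["embed_dim"]   # 1024 for HTSAT-base
audio_cfg = cfg["audio_cfg"]   # dict -> CLAPAudioCfp(**audio_cfg)
text_cfg = cfg["text_cfg"]     # dict -> CLAPTextCfg(**text_cfg)
print(embed_dim, audio_cfg["model_type"], text_cfg["context_length"])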
diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-large.json b/audioldm/clap/open_clip/model_configs/HTSAT-large.json
deleted file mode 100644
index 699cdb1b16855582606551e4196b24aba2ffd871..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/HTSAT-large.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
- "embed_dim": 2048,
- "audio_cfg": {
- "audio_length": 1024,
- "clip_samples": 480000,
- "mel_bins": 64,
- "sample_rate": 48000,
- "window_size": 1024,
- "hop_size": 480,
- "fmin": 50,
- "fmax": 14000,
- "class_num": 527,
- "model_type": "HTSAT",
- "model_name": "large"
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json b/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json
deleted file mode 100644
index 73e42990fe8361a0df502e7f93d29f19f58c9ecb..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
- "embed_dim": 768,
- "audio_cfg": {
- "audio_length": 1024,
- "clip_samples": 480000,
- "mel_bins": 64,
- "sample_rate": 48000,
- "window_size": 1536,
- "hop_size": 480,
- "fmin": 50,
- "fmax": 14000,
- "class_num": 527,
- "model_type": "HTSAT",
- "model_name": "tiny"
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json b/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json
deleted file mode 100644
index a6e7821163d9afa81c27345a1e472475b92af169..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
- "embed_dim": 768,
- "audio_cfg": {
- "audio_length": 1024,
- "clip_samples": 480000,
- "mel_bins": 64,
- "sample_rate": 48000,
- "window_size": 1024,
- "hop_size": 480,
- "fmin": 50,
- "fmax": 14000,
- "class_num": 527,
- "model_type": "HTSAT",
- "model_name": "tiny"
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/PANN-10.json b/audioldm/clap/open_clip/model_configs/PANN-10.json
deleted file mode 100644
index 954ddf62921aed7dde9c37ffffec98a2e96a4ee7..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/PANN-10.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
- "embed_dim": 1024,
- "audio_cfg": {
- "audio_length": 1024,
- "clip_samples": 480000,
- "mel_bins": 64,
- "sample_rate": 48000,
- "window_size": 1024,
- "hop_size": 480,
- "fmin": 50,
- "fmax": 14000,
- "class_num": 527,
- "model_type": "PANN",
- "model_name": "Cnn10"
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json
deleted file mode 100644
index b7989bc0cd95d0d39049b7524eba508b3e386439..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
- "embed_dim": 2048,
- "audio_cfg": {
- "audio_length": 1024,
- "clip_samples": 480000,
- "mel_bins": 64,
- "sample_rate": 48000,
- "window_size": 1024,
- "hop_size": 480,
- "fmin": 50,
- "fmax": 18000,
- "class_num": 527,
- "model_type": "PANN",
- "model_name": "Cnn14"
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json
deleted file mode 100644
index 56bdb56bedc304ffa52d8bf5988cea2c1d82d14e..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
- "embed_dim": 2048,
- "audio_cfg": {
- "audio_length": 1024,
- "clip_samples": 960000,
- "mel_bins": 64,
- "sample_rate": 48000,
- "window_size": 1024,
- "hop_size": 360,
- "fmin": 50,
- "fmax": 8000,
- "class_num": 527,
- "model_type": "PANN",
- "model_name": "Cnn14"
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json b/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json
deleted file mode 100644
index 5756e3bebc97cc985f512cb081930fee4e49bec1..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
- "embed_dim": 2048,
- "audio_cfg": {
- "audio_length": 1024,
- "clip_samples": 480000,
- "mel_bins": 64,
- "sample_rate": 48000,
- "window_size": 1024,
- "hop_size": 480,
- "fmin": 50,
- "fmax": 14000,
- "class_num": 527,
- "model_type": "PANN",
- "model_name": "Cnn14"
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 4
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json b/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json
deleted file mode 100644
index 5a9e7e208b661619d5e26625e849da1adda8a475..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
- "embed_dim": 2048,
- "audio_cfg": {
- "audio_length": 1024,
- "clip_samples": 480000,
- "mel_bins": 64,
- "sample_rate": 48000,
- "window_size": 1536,
- "hop_size": 480,
- "fmin": 50,
- "fmax": 14000,
- "class_num": 527,
- "model_type": "PANN",
- "model_name": "Cnn14"
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/PANN-14.json b/audioldm/clap/open_clip/model_configs/PANN-14.json
deleted file mode 100644
index 39a5134cde1d8c50f4758377c952ef22f07bab41..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/PANN-14.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
- "embed_dim": 2048,
- "audio_cfg": {
- "audio_length": 1024,
- "clip_samples": 480000,
- "mel_bins": 64,
- "sample_rate": 48000,
- "window_size": 1024,
- "hop_size": 480,
- "fmin": 50,
- "fmax": 14000,
- "class_num": 527,
- "model_type": "PANN",
- "model_name": "Cnn14"
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/PANN-6.json b/audioldm/clap/open_clip/model_configs/PANN-6.json
deleted file mode 100644
index 21ebc344326de260c386ba77e0ad63cf9b04febf..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/PANN-6.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
- "embed_dim": 512,
- "audio_cfg": {
- "audio_length": 1024,
- "clip_samples": 480000,
- "mel_bins": 64,
- "sample_rate": 48000,
- "window_size": 1024,
- "hop_size": 480,
- "fmin": 50,
- "fmax": 14000,
- "class_num": 527,
- "model_type": "PANN",
- "model_name": "Cnn6"
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json b/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json
deleted file mode 100644
index d0db2c161d13138788c4609d373b023b8454d624..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json
+++ /dev/null
@@ -1,22 +0,0 @@
-{
- "embed_dim": 512,
- "quick_gelu": true,
- "vision_cfg": {
- "image_size": 224,
- "layers": [
- 3,
- 4,
- 23,
- 3
- ],
- "width": 64,
- "patch_size": null
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/RN101.json b/audioldm/clap/open_clip/model_configs/RN101.json
deleted file mode 100644
index b88b4d3acbaa701c614ab0ea65fc88fcfe289c32..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/RN101.json
+++ /dev/null
@@ -1,21 +0,0 @@
-{
- "embed_dim": 512,
- "vision_cfg": {
- "image_size": 224,
- "layers": [
- 3,
- 4,
- 23,
- 3
- ],
- "width": 64,
- "patch_size": null
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json b/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json
deleted file mode 100644
index 8c2f91260cdeb043434dc1e893cce81d4ce7f0d1..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json
+++ /dev/null
@@ -1,22 +0,0 @@
-{
- "embed_dim": 1024,
- "quick_gelu": true,
- "vision_cfg": {
- "image_size": 224,
- "layers": [
- 3,
- 4,
- 6,
- 3
- ],
- "width": 64,
- "patch_size": null
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
diff --git a/audioldm/clap/open_clip/model_configs/RN50.json b/audioldm/clap/open_clip/model_configs/RN50.json
deleted file mode 100644
index 33aa884d54fee0076c33676831e49d5e1ffcb8f2..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/RN50.json
+++ /dev/null
@@ -1,21 +0,0 @@
-{
- "embed_dim": 1024,
- "vision_cfg": {
- "image_size": 224,
- "layers": [
- 3,
- 4,
- 6,
- 3
- ],
- "width": 64,
- "patch_size": null
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/RN50x16.json b/audioldm/clap/open_clip/model_configs/RN50x16.json
deleted file mode 100644
index 3161e1a2c9a839161e652a4d729c2cdc971161db..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/RN50x16.json
+++ /dev/null
@@ -1,21 +0,0 @@
-{
- "embed_dim": 768,
- "vision_cfg": {
- "image_size": 384,
- "layers": [
- 6,
- 8,
- 18,
- 8
- ],
- "width": 96,
- "patch_size": null
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 768,
- "heads": 12,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/RN50x4.json b/audioldm/clap/open_clip/model_configs/RN50x4.json
deleted file mode 100644
index e155237f8ce1026aaaeecc80751eabe6f329f0bb..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/RN50x4.json
+++ /dev/null
@@ -1,21 +0,0 @@
-{
- "embed_dim": 640,
- "vision_cfg": {
- "image_size": 288,
- "layers": [
- 4,
- 6,
- 10,
- 6
- ],
- "width": 80,
- "patch_size": null
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 640,
- "heads": 10,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-16.json b/audioldm/clap/open_clip/model_configs/ViT-B-16.json
deleted file mode 100644
index 395eea77ec3907c0611531aba63459b193e67b9c..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/ViT-B-16.json
+++ /dev/null
@@ -1,16 +0,0 @@
-{
- "embed_dim": 512,
- "vision_cfg": {
- "image_size": 224,
- "layers": 12,
- "width": 768,
- "patch_size": 16
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json b/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json
deleted file mode 100644
index ce6bd923593293ed50dfcfb28b73ca7403bcf3c5..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json
+++ /dev/null
@@ -1,17 +0,0 @@
-{
- "embed_dim": 512,
- "quick_gelu": true,
- "vision_cfg": {
- "image_size": 224,
- "layers": 12,
- "width": 768,
- "patch_size": 32
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-32.json b/audioldm/clap/open_clip/model_configs/ViT-B-32.json
deleted file mode 100644
index 07c8e28eb06fa1813ba932fe4eec668262d1c47f..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/ViT-B-32.json
+++ /dev/null
@@ -1,16 +0,0 @@
-{
- "embed_dim": 512,
- "vision_cfg": {
- "image_size": 224,
- "layers": 12,
- "width": 768,
- "patch_size": 32
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 512,
- "heads": 8,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/ViT-L-14.json b/audioldm/clap/open_clip/model_configs/ViT-L-14.json
deleted file mode 100644
index d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/model_configs/ViT-L-14.json
+++ /dev/null
@@ -1,16 +0,0 @@
-{
- "embed_dim": 768,
- "vision_cfg": {
- "image_size": 224,
- "layers": 24,
- "width": 1024,
- "patch_size": 14
- },
- "text_cfg": {
- "context_length": 77,
- "vocab_size": 49408,
- "width": 768,
- "heads": 12,
- "layers": 12
- }
-}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/openai.py b/audioldm/clap/open_clip/openai.py
deleted file mode 100644
index 3f4eb8b55fe960e1792b3da804b60b3d8f70fe26..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/openai.py
+++ /dev/null
@@ -1,156 +0,0 @@
-""" OpenAI pretrained model functions
-
-Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
-"""
-
-import os
-import warnings
-from typing import Union, List
-
-import torch
-
-from .model import build_model_from_openai_state_dict
-from .pretrained import (
- get_pretrained_url,
- list_pretrained_tag_models,
- download_pretrained,
-)
-
-__all__ = ["list_openai_models", "load_openai_model"]
-
-
-def list_openai_models() -> List[str]:
- """Returns the names of available CLIP models"""
- return list_pretrained_tag_models("openai")
-
-
-def load_openai_model(
- name: str,
- model_cfg,
- device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
- jit=True,
- cache_dir=os.path.expanduser("~/.cache/clip"),
- enable_fusion: bool = False,
- fusion_type: str = "None",
-):
- """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model
-
- Parameters
- ----------
- name : str
- A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
- device : Union[str, torch.device]
- The device to put the loaded model
- jit : bool
- Whether to load the optimized JIT model (default) or more hackable non-JIT model.
-
- Returns
- -------
- model : torch.nn.Module
- The CLAP model with the pretrained CLIP text branch loaded
- """
- if get_pretrained_url(name, "openai"):
- model_path = download_pretrained(
- get_pretrained_url(name, "openai"), root=cache_dir
- )
- elif os.path.isfile(name):
- model_path = name
- else:
- raise RuntimeError(
- f"Model {name} not found; available models = {list_openai_models()}"
- )
-
- try:
- # loading JIT archive
- model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
- state_dict = None
- except RuntimeError:
- # loading saved state dict
- if jit:
- warnings.warn(
- f"File {model_path} is not a JIT archive. Loading as a state dict instead"
- )
- jit = False
- state_dict = torch.load(model_path, map_location="cpu")
-
- if not jit:
- try:
- model = build_model_from_openai_state_dict(
- state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type
- ).to(device)
- except KeyError:
- sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
- model = build_model_from_openai_state_dict(
- sd, model_cfg, enable_fusion, fusion_type
- ).to(device)
-
- if str(device) == "cpu":
- model.float()
- return model
-
- # patch the device names
- device_holder = torch.jit.trace(
- lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
- )
- device_node = [
- n
- for n in device_holder.graph.findAllNodes("prim::Constant")
- if "Device" in repr(n)
- ][-1]
-
- def patch_device(module):
- try:
- graphs = [module.graph] if hasattr(module, "graph") else []
- except RuntimeError:
- graphs = []
-
- if hasattr(module, "forward1"):
- graphs.append(module.forward1.graph)
-
- for graph in graphs:
- for node in graph.findAllNodes("prim::Constant"):
- if "value" in node.attributeNames() and str(node["value"]).startswith(
- "cuda"
- ):
- node.copyAttributes(device_node)
-
- model.apply(patch_device)
- patch_device(model.encode_audio)
- patch_device(model.encode_text)
-
- # patch dtype to float32 on CPU
- if str(device) == "cpu":
- float_holder = torch.jit.trace(
- lambda: torch.ones([]).float(), example_inputs=[]
- )
- float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
- float_node = float_input.node()
-
- def patch_float(module):
- try:
- graphs = [module.graph] if hasattr(module, "graph") else []
- except RuntimeError:
- graphs = []
-
- if hasattr(module, "forward1"):
- graphs.append(module.forward1.graph)
-
- for graph in graphs:
- for node in graph.findAllNodes("aten::to"):
- inputs = list(node.inputs())
- for i in [
- 1,
- 2,
- ]: # dtype can be the second or third argument to aten::to()
- if inputs[i].node()["value"] == 5:
- inputs[i].node().copyAttributes(float_node)
-
- model.apply(patch_float)
- patch_float(model.encode_audio)
- patch_float(model.encode_text)
- model.float()
-
- model.audio_branch.audio_length = model.audio_cfg.audio_length
- return model
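# A small sketch of the checkpoint-key fallback used in load_openai_model: the
# k[7:] slice above strips a 7-character prefix, typically the "module." that
# DistributedDataParallel adds to every parameter name when saving.
state_dict = {"module.token_embedding.weight": 0, "module.ln_final.weight": 1}
sd = {k[len("module."):]: v for k, v in state_dict.items()}
print(list(sd))  # ['token_embedding.weight', 'ln_final.weight']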
diff --git a/audioldm/clap/open_clip/pann_model.py b/audioldm/clap/open_clip/pann_model.py
deleted file mode 100644
index 874a03fc6eabcfdf3a63c59ca1e05d4f991453c5..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/pann_model.py
+++ /dev/null
@@ -1,703 +0,0 @@
-# PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
-# Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn
-# Some layers are re-designed for CLAP
-import os
-
-os.environ["NUMBA_CACHE_DIR"] = "/tmp/"
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torchlibrosa.stft import Spectrogram, LogmelFilterBank
-from torchlibrosa.augmentation import SpecAugmentation
-
-from .utils import do_mixup, interpolate, pad_framewise_output
-from .feature_fusion import iAFF, AFF, DAF
-
-
-def init_layer(layer):
- """Initialize a Linear or Convolutional layer."""
- nn.init.xavier_uniform_(layer.weight)
-
- if hasattr(layer, "bias"):
- if layer.bias is not None:
- layer.bias.data.fill_(0.0)
-
-def init_bn(bn):
- """Initialize a Batchnorm layer."""
- bn.bias.data.fill_(0.0)
- bn.weight.data.fill_(1.0)
-
-
-class ConvBlock(nn.Module):
- def __init__(self, in_channels, out_channels):
-
- super(ConvBlock, self).__init__()
-
- self.conv1 = nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=(3, 3),
- stride=(1, 1),
- padding=(1, 1),
- bias=False,
- )
-
- self.conv2 = nn.Conv2d(
- in_channels=out_channels,
- out_channels=out_channels,
- kernel_size=(3, 3),
- stride=(1, 1),
- padding=(1, 1),
- bias=False,
- )
-
- self.bn1 = nn.BatchNorm2d(out_channels)
- self.bn2 = nn.BatchNorm2d(out_channels)
-
- self.init_weight()
-
- def init_weight(self):
- init_layer(self.conv1)
- init_layer(self.conv2)
- init_bn(self.bn1)
- init_bn(self.bn2)
-
- def forward(self, input, pool_size=(2, 2), pool_type="avg"):
-
- x = input
- x = F.relu_(self.bn1(self.conv1(x)))
- x = F.relu_(self.bn2(self.conv2(x)))
- if pool_type == "max":
- x = F.max_pool2d(x, kernel_size=pool_size)
- elif pool_type == "avg":
- x = F.avg_pool2d(x, kernel_size=pool_size)
- elif pool_type == "avg+max":
- x1 = F.avg_pool2d(x, kernel_size=pool_size)
- x2 = F.max_pool2d(x, kernel_size=pool_size)
- x = x1 + x2
- else:
- raise ValueError(f"Unsupported pool_type: {pool_type}")
-
- return x
-
-
-class ConvBlock5x5(nn.Module):
- def __init__(self, in_channels, out_channels):
-
- super(ConvBlock5x5, self).__init__()
-
- self.conv1 = nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=(5, 5),
- stride=(1, 1),
- padding=(2, 2),
- bias=False,
- )
-
- self.bn1 = nn.BatchNorm2d(out_channels)
-
- self.init_weight()
-
- def init_weight(self):
- init_layer(self.conv1)
- init_bn(self.bn1)
-
- def forward(self, input, pool_size=(2, 2), pool_type="avg"):
-
- x = input
- x = F.relu_(self.bn1(self.conv1(x)))
- if pool_type == "max":
- x = F.max_pool2d(x, kernel_size=pool_size)
- elif pool_type == "avg":
- x = F.avg_pool2d(x, kernel_size=pool_size)
- elif pool_type == "avg+max":
- x1 = F.avg_pool2d(x, kernel_size=pool_size)
- x2 = F.max_pool2d(x, kernel_size=pool_size)
- x = x1 + x2
- else:
- raise ValueError(f"Unsupported pool_type: {pool_type}")
-
- return x
-
-
-class AttBlock(nn.Module):
- def __init__(self, n_in, n_out, activation="linear", temperature=1.0):
- super(AttBlock, self).__init__()
-
- self.activation = activation
- self.temperature = temperature
- self.att = nn.Conv1d(
- in_channels=n_in,
- out_channels=n_out,
- kernel_size=1,
- stride=1,
- padding=0,
- bias=True,
- )
- self.cla = nn.Conv1d(
- in_channels=n_in,
- out_channels=n_out,
- kernel_size=1,
- stride=1,
- padding=0,
- bias=True,
- )
-
- self.bn_att = nn.BatchNorm1d(n_out)
- self.init_weights()
-
- def init_weights(self):
- init_layer(self.att)
- init_layer(self.cla)
- init_bn(self.bn_att)
-
- def forward(self, x):
- # x: (n_samples, n_in, n_time)
- norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
- cla = self.nonlinear_transform(self.cla(x))
- x = torch.sum(norm_att * cla, dim=2)
- return x, norm_att, cla
-
- def nonlinear_transform(self, x):
- if self.activation == "linear":
- return x
- elif self.activation == "sigmoid":
- return torch.sigmoid(x)
-
-
-class Cnn14(nn.Module):
- def __init__(
- self,
- sample_rate,
- window_size,
- hop_size,
- mel_bins,
- fmin,
- fmax,
- classes_num,
- enable_fusion=False,
- fusion_type="None",
- ):
-
- super(Cnn14, self).__init__()
-
- window = "hann"
- center = True
- pad_mode = "reflect"
- ref = 1.0
- amin = 1e-10
- top_db = None
-
- self.enable_fusion = enable_fusion
- self.fusion_type = fusion_type
-
- # Spectrogram extractor
- self.spectrogram_extractor = Spectrogram(
- n_fft=window_size,
- hop_length=hop_size,
- win_length=window_size,
- window=window,
- center=center,
- pad_mode=pad_mode,
- freeze_parameters=True,
- )
-
- # Logmel feature extractor
- self.logmel_extractor = LogmelFilterBank(
- sr=sample_rate,
- n_fft=window_size,
- n_mels=mel_bins,
- fmin=fmin,
- fmax=fmax,
- ref=ref,
- amin=amin,
- top_db=top_db,
- freeze_parameters=True,
- )
-
- # Spec augmenter
- self.spec_augmenter = SpecAugmentation(
- time_drop_width=64,
- time_stripes_num=2,
- freq_drop_width=8,
- freq_stripes_num=2,
- )
-
- self.bn0 = nn.BatchNorm2d(64)
-
- if (self.enable_fusion) and (self.fusion_type == "channel_map"):
- self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
- else:
- self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
- self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
- self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
- self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
- self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
- self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
-
- self.fc1 = nn.Linear(2048, 2048, bias=True)
- self.fc_audioset = nn.Linear(2048, classes_num, bias=True)
-
- if (self.enable_fusion) and (
- self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
- ):
- self.mel_conv1d = nn.Sequential(
- nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
- nn.BatchNorm1d(64), # No Relu
- )
- if self.fusion_type == "daf_1d":
- self.fusion_model = DAF()
- elif self.fusion_type == "aff_1d":
- self.fusion_model = AFF(channels=64, type="1D")
- elif self.fusion_type == "iaff_1d":
- self.fusion_model = iAFF(channels=64, type="1D")
-
- if (self.enable_fusion) and (
- self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
- ):
- self.mel_conv2d = nn.Sequential(
- nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)),
- nn.BatchNorm2d(64),
- nn.ReLU(inplace=True),
- )
-
- if self.fusion_type == "daf_2d":
- self.fusion_model = DAF()
- elif self.fusion_type == "aff_2d":
- self.fusion_model = AFF(channels=64, type="2D")
- elif self.fusion_type == "iaff_2d":
- self.fusion_model = iAFF(channels=64, type="2D")
- self.init_weight()
-
- def init_weight(self):
- init_bn(self.bn0)
- init_layer(self.fc1)
- init_layer(self.fc_audioset)
-
- def forward(self, input, mixup_lambda=None, device=None):
- """
- Input: (batch_size, data_length)"""
-
- if self.enable_fusion and input["longer"].sum() == 0:
- # if no audio is longer than 10s, then randomly select one audio to be longer
- input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True
-
- if not self.enable_fusion:
- x = self.spectrogram_extractor(
- input["waveform"].to(device=device, non_blocking=True)
- ) # (batch_size, 1, time_steps, freq_bins)
- x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
-
- x = x.transpose(1, 3)
- x = self.bn0(x)
- x = x.transpose(1, 3)
- else:
- longer_list = input["longer"].to(device=device, non_blocking=True)
- x = input["mel_fusion"].to(device=device, non_blocking=True)
- longer_list_idx = torch.where(longer_list)[0]
- x = x.transpose(1, 3)
- x = self.bn0(x)
- x = x.transpose(1, 3)
- if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
- new_x = x[:, 0:1, :, :].clone().contiguous()
- # local processing
- if len(longer_list_idx) > 0:
- fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
- FB, FC, FT, FF = fusion_x_local.size()
- fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
- fusion_x_local = torch.permute(
- fusion_x_local, (0, 2, 1)
- ).contiguous()
- fusion_x_local = self.mel_conv1d(fusion_x_local)
- fusion_x_local = fusion_x_local.view(
- FB, FC, FF, fusion_x_local.size(-1)
- )
- fusion_x_local = (
- torch.permute(fusion_x_local, (0, 2, 1, 3))
- .contiguous()
- .flatten(2)
- )
- if fusion_x_local.size(-1) < FT:
- fusion_x_local = torch.cat(
- [
- fusion_x_local,
- torch.zeros(
- (FB, FF, FT - fusion_x_local.size(-1)),
- device=device,
- ),
- ],
- dim=-1,
- )
- else:
- fusion_x_local = fusion_x_local[:, :, :FT]
- # 1D fusion
- new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
- new_x[longer_list_idx] = self.fusion_model(
- new_x[longer_list_idx], fusion_x_local
- )
- x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
- else:
- x = new_x
- elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
- x = x # no change
-
- if self.training:
- x = self.spec_augmenter(x)
- # Mixup on spectrogram
- if self.training and mixup_lambda is not None:
- x = do_mixup(x, mixup_lambda)
- if (self.enable_fusion) and (
- self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
- ):
- global_x = x[:, 0:1, :, :]
-
- # global processing
- B, C, H, W = global_x.shape
- global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg")
- if len(longer_list_idx) > 0:
- local_x = x[longer_list_idx, 1:, :, :].contiguous()
- TH = global_x.size(-2)
- # local processing
- B, C, H, W = local_x.shape
- local_x = local_x.view(B * C, 1, H, W)
- local_x = self.mel_conv2d(local_x)
- local_x = local_x.view(
- B, C, local_x.size(1), local_x.size(2), local_x.size(3)
- )
- local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3)
- TB, TC, _, TW = local_x.size()
- if local_x.size(-2) < TH:
- local_x = torch.cat(
- [
- local_x,
- torch.zeros(
- (TB, TC, TH - local_x.size(-2), TW),
- device=global_x.device,
- ),
- ],
- dim=-2,
- )
- else:
- local_x = local_x[:, :, :TH, :]
-
- global_x[longer_list_idx] = self.fusion_model(
- global_x[longer_list_idx], local_x
- )
- x = global_x
- else:
- x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
-
- x = F.dropout(x, p=0.2, training=self.training)
- x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
- x = F.dropout(x, p=0.2, training=self.training)
- x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
- x = F.dropout(x, p=0.2, training=self.training)
- x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
- x = F.dropout(x, p=0.2, training=self.training)
- x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
- x = F.dropout(x, p=0.2, training=self.training)
- x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
- x = F.dropout(x, p=0.2, training=self.training)
- x = torch.mean(x, dim=3)
-
- latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
- latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
- latent_x = latent_x1 + latent_x2
- latent_x = latent_x.transpose(1, 2)
- latent_x = F.relu_(self.fc1(latent_x))
- latent_output = interpolate(latent_x, 32)
-
- (x1, _) = torch.max(x, dim=2)
- x2 = torch.mean(x, dim=2)
- x = x1 + x2
- x = F.dropout(x, p=0.5, training=self.training)
- x = F.relu_(self.fc1(x))
- embedding = F.dropout(x, p=0.5, training=self.training)
- clipwise_output = torch.sigmoid(self.fc_audioset(x))
-
- output_dict = {
- "clipwise_output": clipwise_output,
- "embedding": embedding,
- "fine_grained_embedding": latent_output,
- }
- return output_dict
-
-
-class Cnn6(nn.Module):
- def __init__(
- self,
- sample_rate,
- window_size,
- hop_size,
- mel_bins,
- fmin,
- fmax,
- classes_num,
- enable_fusion=False,
- fusion_type="None",
- ):
-
- super(Cnn6, self).__init__()
-
- window = "hann"
- center = True
- pad_mode = "reflect"
- ref = 1.0
- amin = 1e-10
- top_db = None
-
- self.enable_fusion = enable_fusion
- self.fusion_type = fusion_type
-
- # Spectrogram extractor
- self.spectrogram_extractor = Spectrogram(
- n_fft=window_size,
- hop_length=hop_size,
- win_length=window_size,
- window=window,
- center=center,
- pad_mode=pad_mode,
- freeze_parameters=True,
- )
-
- # Logmel feature extractor
- self.logmel_extractor = LogmelFilterBank(
- sr=sample_rate,
- n_fft=window_size,
- n_mels=mel_bins,
- fmin=fmin,
- fmax=fmax,
- ref=ref,
- amin=amin,
- top_db=top_db,
- freeze_parameters=True,
- )
-
- # Spec augmenter
- self.spec_augmenter = SpecAugmentation(
- time_drop_width=64,
- time_stripes_num=2,
- freq_drop_width=8,
- freq_stripes_num=2,
- )
-
- self.bn0 = nn.BatchNorm2d(64)
-
- self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
- self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
- self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
- self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
-
- self.fc1 = nn.Linear(512, 512, bias=True)
- self.fc_audioset = nn.Linear(512, classes_num, bias=True)
-
- self.init_weight()
-
- def init_weight(self):
- init_bn(self.bn0)
- init_layer(self.fc1)
- init_layer(self.fc_audioset)
-
- def forward(self, input, mixup_lambda=None, device=None):
- """
- Input: (batch_size, data_length)"""
-
- x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
- x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
-
- x = x.transpose(1, 3)
- x = self.bn0(x)
- x = x.transpose(1, 3)
-
- if self.training:
- x = self.spec_augmenter(x)
-
- # Mixup on spectrogram
- if self.training and mixup_lambda is not None:
- x = do_mixup(x, mixup_lambda)
-
- x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
- x = F.dropout(x, p=0.2, training=self.training)
- x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
- x = F.dropout(x, p=0.2, training=self.training)
- x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
- x = F.dropout(x, p=0.2, training=self.training)
- x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
- x = F.dropout(x, p=0.2, training=self.training)
- x = torch.mean(x, dim=3)
-
- latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
- latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
- latent_x = latent_x1 + latent_x2
- latent_x = latent_x.transpose(1, 2)
- latent_x = F.relu_(self.fc1(latent_x))
- latent_output = interpolate(latent_x, 16)
-
- (x1, _) = torch.max(x, dim=2)
- x2 = torch.mean(x, dim=2)
- x = x1 + x2
- x = F.dropout(x, p=0.5, training=self.training)
- x = F.relu_(self.fc1(x))
- embedding = F.dropout(x, p=0.5, training=self.training)
- clipwise_output = torch.sigmoid(self.fc_audioset(x))
-
- output_dict = {
- "clipwise_output": clipwise_output,
- "embedding": embedding,
- "fine_grained_embedding": latent_output,
- }
-
- return output_dict
-
-
-class Cnn10(nn.Module):
- def __init__(
- self,
- sample_rate,
- window_size,
- hop_size,
- mel_bins,
- fmin,
- fmax,
- classes_num,
- enable_fusion=False,
- fusion_type="None",
- ):
-
- super(Cnn10, self).__init__()
-
- window = "hann"
- center = True
- pad_mode = "reflect"
- ref = 1.0
- amin = 1e-10
- top_db = None
-
- self.enable_fusion = enable_fusion
- self.fusion_type = fusion_type
-
- # Spectrogram extractor
- self.spectrogram_extractor = Spectrogram(
- n_fft=window_size,
- hop_length=hop_size,
- win_length=window_size,
- window=window,
- center=center,
- pad_mode=pad_mode,
- freeze_parameters=True,
- )
-
- # Logmel feature extractor
- self.logmel_extractor = LogmelFilterBank(
- sr=sample_rate,
- n_fft=window_size,
- n_mels=mel_bins,
- fmin=fmin,
- fmax=fmax,
- ref=ref,
- amin=amin,
- top_db=top_db,
- freeze_parameters=True,
- )
-
- # Spec augmenter
- self.spec_augmenter = SpecAugmentation(
- time_drop_width=64,
- time_stripes_num=2,
- freq_drop_width=8,
- freq_stripes_num=2,
- )
-
- self.bn0 = nn.BatchNorm2d(64)
-
- self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
- self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
- self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
- self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
- self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
-
- self.fc1 = nn.Linear(1024, 1024, bias=True)
- self.fc_audioset = nn.Linear(1024, classes_num, bias=True)
-
- self.init_weight()
-
- def init_weight(self):
- init_bn(self.bn0)
- init_layer(self.fc1)
- init_layer(self.fc_audioset)
-
- def forward(self, input, mixup_lambda=None, device=None):
- """
- Input: (batch_size, data_length)"""
-
- x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
- x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
-
- x = x.transpose(1, 3)
- x = self.bn0(x)
- x = x.transpose(1, 3)
-
- if self.training:
- x = self.spec_augmenter(x)
-
- # Mixup on spectrogram
- if self.training and mixup_lambda is not None:
- x = do_mixup(x, mixup_lambda)
-
- x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
- x = F.dropout(x, p=0.2, training=self.training)
- x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
- x = F.dropout(x, p=0.2, training=self.training)
- x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
- x = F.dropout(x, p=0.2, training=self.training)
- x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
- x = F.dropout(x, p=0.2, training=self.training)
- x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
- x = F.dropout(x, p=0.2, training=self.training)
- x = torch.mean(x, dim=3)
-
- latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
- latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
- latent_x = latent_x1 + latent_x2
- latent_x = latent_x.transpose(1, 2)
- latent_x = F.relu_(self.fc1(latent_x))
- latent_output = interpolate(latent_x, 32)
-
- (x1, _) = torch.max(x, dim=2)
- x2 = torch.mean(x, dim=2)
- x = x1 + x2
- x = F.dropout(x, p=0.5, training=self.training)
- x = F.relu_(self.fc1(x))
- embedding = F.dropout(x, p=0.5, training=self.training)
- clipwise_output = torch.sigmoid(self.fc_audioset(x))
-
- output_dict = {
- "clipwise_output": clipwise_output,
- "embedding": embedding,
- "fine_grained_embedding": latent_output,
- }
-
- return output_dict
-
-
-def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"):
- try:
- # resolve the model class named in the config (Cnn6, Cnn10 or Cnn14, defined above)
- ModelProto = eval(audio_cfg.model_name)
- model = ModelProto(
- sample_rate=audio_cfg.sample_rate,
- window_size=audio_cfg.window_size,
- hop_size=audio_cfg.hop_size,
- mel_bins=audio_cfg.mel_bins,
- fmin=audio_cfg.fmin,
- fmax=audio_cfg.fmax,
- classes_num=audio_cfg.class_num,
- enable_fusion=enable_fusion,
- fusion_type=fusion_type,
- )
- return model
- except Exception as e:
- raise RuntimeError(
- f"Model {audio_cfg.model_name} not found, or the audio cfg is missing required parameters."
- ) from e
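# A standalone sketch of the log-mel front end shared by Cnn6/Cnn10/Cnn14 above,
# using the same torchlibrosa modules and the 48 kHz parameters from the deleted
# model configs (window 1024, hop 480, 64 mel bins, 50-14000 Hz).
import torch
from torchlibrosa.stft import Spectrogram, LogmelFilterBank

spectrogram = Spectrogram(n_fft=1024, hop_length=480, win_length=1024, window="hann",
                          center=True, pad_mode="reflect", freeze_parameters=True)
logmel = LogmelFilterBank(sr=48000, n_fft=1024, n_mels=64, fmin=50, fmax=14000,
                          ref=1.0, amin=1e-10, top_db=None, freeze_parameters=True)

waveform = torch.randn(2, 48000)      # (batch, samples): one second of audio at 48 kHz
x = logmel(spectrogram(waveform))     # (batch, 1, time_steps, mel_bins)
print(x.shape)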
diff --git a/audioldm/clap/open_clip/pretrained.py b/audioldm/clap/open_clip/pretrained.py
deleted file mode 100644
index e211d8b5b59320a599e62605f1dee6199f317253..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/pretrained.py
+++ /dev/null
@@ -1,167 +0,0 @@
-import hashlib
-import os
-import urllib
-import warnings
-
-from tqdm import tqdm
-
-_RN50 = dict(
- openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
- yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
- cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
-)
-
-_RN50_quickgelu = dict(
- openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
- yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
- cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
-)
-
-_RN101 = dict(
- openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
- yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
-)
-
-_RN101_quickgelu = dict(
- openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
- yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
-)
-
-_RN50x4 = dict(
- openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
-)
-
-_RN50x16 = dict(
- openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
-)
-
-_RN50x64 = dict(
- openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
-)
-
-_VITB32 = dict(
- openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
- laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
- laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
- laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
-)
-
-_VITB32_quickgelu = dict(
- openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
- laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
- laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
- laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
-)
-
-_VITB16 = dict(
- openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
-)
-
-_VITL14 = dict(
- openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
-)
-
-_PRETRAINED = {
- "RN50": _RN50,
- "RN50-quickgelu": _RN50_quickgelu,
- "RN101": _RN101,
- "RN101-quickgelu": _RN101_quickgelu,
- "RN50x4": _RN50x4,
- "RN50x16": _RN50x16,
- "ViT-B-32": _VITB32,
- "ViT-B-32-quickgelu": _VITB32_quickgelu,
- "ViT-B-16": _VITB16,
- "ViT-L-14": _VITL14,
-}
-
-
-def list_pretrained(as_str: bool = False):
- """returns list of pretrained models
- Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
- """
- return [
- ":".join([k, t]) if as_str else (k, t)
- for k in _PRETRAINED.keys()
- for t in _PRETRAINED[k].keys()
- ]
-
-
-def list_pretrained_tag_models(tag: str):
- """return all models having the specified pretrain tag"""
- models = []
- for k in _PRETRAINED.keys():
- if tag in _PRETRAINED[k]:
- models.append(k)
- return models
-
-
-def list_pretrained_model_tags(model: str):
- """return all pretrain tags for the specified model architecture"""
- tags = []
- if model in _PRETRAINED:
- tags.extend(_PRETRAINED[model].keys())
- return tags
-
-
-def get_pretrained_url(model: str, tag: str):
- if model not in _PRETRAINED:
- return ""
- model_pretrained = _PRETRAINED[model]
- if tag not in model_pretrained:
- return ""
- return model_pretrained[tag]
-
-
-def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")):
- os.makedirs(root, exist_ok=True)
- filename = os.path.basename(url)
-
- if "openaipublic" in url:
- expected_sha256 = url.split("/")[-2]
- else:
- expected_sha256 = ""
-
- download_target = os.path.join(root, filename)
-
- if os.path.exists(download_target) and not os.path.isfile(download_target):
- raise RuntimeError(f"{download_target} exists and is not a regular file")
-
- if os.path.isfile(download_target):
- if expected_sha256:
- if (
- hashlib.sha256(open(download_target, "rb").read()).hexdigest()
- == expected_sha256
- ):
- return download_target
- else:
- warnings.warn(
- f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
- )
- else:
- return download_target
-
- with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
- with tqdm(
- total=int(source.info().get("Content-Length")),
- ncols=80,
- unit="iB",
- unit_scale=True,
- ) as loop:
- while True:
- buffer = source.read(8192)
- if not buffer:
- break
-
- output.write(buffer)
- loop.update(len(buffer))
-
- if (
- expected_sha256
- and hashlib.sha256(open(download_target, "rb").read()).hexdigest()
- != expected_sha256
- ):
- raise RuntimeError(
-            f"Model has been downloaded but the SHA256 checksum does not match"
- )
-
- return download_target
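
A minimal usage sketch for the checkpoint registry removed above: resolve the URL for a `(model, tag)` pair, then download it into the default `~/.cache/clip` directory, with SHA256 verification when the URL carries a hash.

```python
# Sketch only; relies solely on the helpers defined in the deleted pretrained.py.
for name, tag in list_pretrained():           # e.g. ("RN50", "openai")
    print(f"{name}:{tag}")

url = get_pretrained_url("ViT-B-32", "laion400m_e32")
if url:
    ckpt_path = download_pretrained(url)      # checksum-verified for openaipublic URLs
    print("checkpoint saved to", ckpt_path)
```
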
diff --git a/audioldm/clap/open_clip/timm_model.py b/audioldm/clap/open_clip/timm_model.py
deleted file mode 100644
index c9d1ab4666b5bab5038d44b90c9ddca5087de460..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/timm_model.py
+++ /dev/null
@@ -1,112 +0,0 @@
-""" timm model adapter
-
-Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
-"""
-from collections import OrderedDict
-
-import torch.nn as nn
-
-try:
- import timm
- from timm.models.layers import Mlp, to_2tuple
- from timm.models.layers.attention_pool2d import RotAttentionPool2d
- from timm.models.layers.attention_pool2d import (
- AttentionPool2d as AbsAttentionPool2d,
- )
-except ImportError as e:
- timm = None
-
-from .utils import freeze_batch_norm_2d
-
-
-class TimmModel(nn.Module):
- """timm model adapter
- # FIXME this adapter is a work in progress, may change in ways that break weight compat
- """
-
- def __init__(
- self,
- model_name,
- embed_dim,
- image_size=224,
- pool="avg",
- proj="linear",
- drop=0.0,
- pretrained=False,
- ):
- super().__init__()
- if timm is None:
- raise RuntimeError("Please `pip install timm` to use timm models.")
-
- self.image_size = to_2tuple(image_size)
- self.trunk = timm.create_model(model_name, pretrained=pretrained)
- feat_size = self.trunk.default_cfg.get("pool_size", None)
- feature_ndim = 1 if not feat_size else 2
- if pool in ("abs_attn", "rot_attn"):
- assert feature_ndim == 2
- # if attn pooling used, remove both classifier and default pool
- self.trunk.reset_classifier(0, global_pool="")
- else:
- # reset global pool if pool config set, otherwise leave as network default
- reset_kwargs = dict(global_pool=pool) if pool else {}
- self.trunk.reset_classifier(0, **reset_kwargs)
- prev_chs = self.trunk.num_features
-
- head_layers = OrderedDict()
- if pool == "abs_attn":
- head_layers["pool"] = AbsAttentionPool2d(
- prev_chs, feat_size=feat_size, out_features=embed_dim
- )
- prev_chs = embed_dim
- elif pool == "rot_attn":
- head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
- prev_chs = embed_dim
- else:
- assert proj, "projection layer needed if non-attention pooling is used."
-
- # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
- if proj == "linear":
- head_layers["drop"] = nn.Dropout(drop)
- head_layers["proj"] = nn.Linear(prev_chs, embed_dim)
- elif proj == "mlp":
- head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
-
- self.head = nn.Sequential(head_layers)
-
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
- """lock modules
- Args:
- unlocked_groups (int): leave last n layer groups unlocked (default: 0)
- """
- if not unlocked_groups:
- # lock full model
- for param in self.trunk.parameters():
- param.requires_grad = False
- if freeze_bn_stats:
- freeze_batch_norm_2d(self.trunk)
- else:
- # NOTE: partial freeze requires latest timm (master) branch and is subject to change
- try:
- # FIXME import here until API stable and in an official release
- from timm.models.helpers import group_parameters, group_modules
- except ImportError:
- raise RuntimeError(
- "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`"
- )
- matcher = self.trunk.group_matcher()
- gparams = group_parameters(self.trunk, matcher)
- max_layer_id = max(gparams.keys())
- max_layer_id = max_layer_id - unlocked_groups
- for group_idx in range(max_layer_id + 1):
- group = gparams[group_idx]
- for param in group:
- self.trunk.get_parameter(param).requires_grad = False
- if freeze_bn_stats:
- gmodules = group_modules(self.trunk, matcher, reverse=True)
- gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
- freeze_batch_norm_2d(self.trunk, gmodules)
-
- def forward(self, x):
- x = self.trunk(x)
- x = self.head(x)
- return x
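
For context, a hedged sketch of how the deleted `TimmModel` adapter wraps a timm backbone as a vision tower. It assumes `timm` is installed and that `vit_base_patch16_224` exists in the installed timm version; any other architecture name would do.

```python
import torch

# Wrap a timm backbone and project its pooled features to embed_dim.
vision_tower = TimmModel(
    model_name="vit_base_patch16_224",  # assumed to be available in timm
    embed_dim=512,
    image_size=224,
    pool="avg",
    proj="linear",
    pretrained=False,
)
vision_tower.lock(unlocked_groups=0, freeze_bn_stats=True)  # freeze the trunk
features = vision_tower(torch.randn(1, 3, 224, 224))        # -> shape (1, 512)
```
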
diff --git a/audioldm/clap/open_clip/tokenizer.py b/audioldm/clap/open_clip/tokenizer.py
deleted file mode 100644
index ee4d28450ec5dd12a79daf38cf3088e9e73c2cd5..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/tokenizer.py
+++ /dev/null
@@ -1,197 +0,0 @@
-""" CLIP tokenizer
-
-Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
-"""
-import gzip
-import html
-import os
-from functools import lru_cache
-from typing import Union, List
-
-import ftfy
-import regex as re
-import torch
-
-
-@lru_cache()
-def default_bpe():
- return os.path.join(
- os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
- )
-
-
-@lru_cache()
-def bytes_to_unicode():
- """
- Returns list of utf-8 byte and a corresponding list of unicode strings.
- The reversible bpe codes work on unicode strings.
- This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
- When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
-    This is a significant percentage of your normal, say, 32K bpe vocab.
-    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
-    This also avoids mapping to whitespace/control characters that the bpe code barfs on.
- """
- bs = (
- list(range(ord("!"), ord("~") + 1))
- + list(range(ord("¡"), ord("¬") + 1))
- + list(range(ord("®"), ord("ÿ") + 1))
- )
- cs = bs[:]
- n = 0
- for b in range(2**8):
- if b not in bs:
- bs.append(b)
- cs.append(2**8 + n)
- n += 1
- cs = [chr(n) for n in cs]
- return dict(zip(bs, cs))
-
-
-def get_pairs(word):
- """Return set of symbol pairs in a word.
- Word is represented as tuple of symbols (symbols being variable-length strings).
- """
- pairs = set()
- prev_char = word[0]
- for char in word[1:]:
- pairs.add((prev_char, char))
- prev_char = char
- return pairs
-
-
-def basic_clean(text):
- text = ftfy.fix_text(text)
- text = html.unescape(html.unescape(text))
- return text.strip()
-
-
-def whitespace_clean(text):
- text = re.sub(r"\s+", " ", text)
- text = text.strip()
- return text
-
-
-class SimpleTokenizer(object):
- def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
- self.byte_encoder = bytes_to_unicode()
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
- merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
- merges = merges[1 : 49152 - 256 - 2 + 1]
- merges = [tuple(merge.split()) for merge in merges]
- vocab = list(bytes_to_unicode().values())
-        vocab = vocab + [v + "</w>" for v in vocab]
- for merge in merges:
- vocab.append("".join(merge))
- if not special_tokens:
-            special_tokens = ["<start_of_text>", "<end_of_text>"]
-        else:
-            special_tokens = ["<start_of_text>", "<end_of_text>"] + special_tokens
- vocab.extend(special_tokens)
- self.encoder = dict(zip(vocab, range(len(vocab))))
- self.decoder = {v: k for k, v in self.encoder.items()}
- self.bpe_ranks = dict(zip(merges, range(len(merges))))
- self.cache = {t: t for t in special_tokens}
- special = "|".join(special_tokens)
- self.pat = re.compile(
- special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
- re.IGNORECASE,
- )
-
- self.vocab_size = len(self.encoder)
- self.all_special_ids = [self.encoder[t] for t in special_tokens]
-
- def bpe(self, token):
- if token in self.cache:
- return self.cache[token]
-        word = tuple(token[:-1]) + (token[-1] + "</w>",)
- pairs = get_pairs(word)
-
- if not pairs:
-            return token + "</w>"
-
- while True:
- bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
- if bigram not in self.bpe_ranks:
- break
- first, second = bigram
- new_word = []
- i = 0
- while i < len(word):
- try:
- j = word.index(first, i)
- new_word.extend(word[i:j])
- i = j
- except:
- new_word.extend(word[i:])
- break
-
- if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
- new_word.append(first + second)
- i += 2
- else:
- new_word.append(word[i])
- i += 1
- new_word = tuple(new_word)
- word = new_word
- if len(word) == 1:
- break
- else:
- pairs = get_pairs(word)
- word = " ".join(word)
- self.cache[token] = word
- return word
-
- def encode(self, text):
- bpe_tokens = []
- text = whitespace_clean(basic_clean(text)).lower()
- for token in re.findall(self.pat, text):
- token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
- bpe_tokens.extend(
- self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
- )
- return bpe_tokens
-
- def decode(self, tokens):
- text = "".join([self.decoder[token] for token in tokens])
- text = (
- bytearray([self.byte_decoder[c] for c in text])
- .decode("utf-8", errors="replace")
-            .replace("</w>", " ")
- )
- return text
-
-
-_tokenizer = SimpleTokenizer()
-
-
-def tokenize(
- texts: Union[str, List[str]], context_length: int = 77
-) -> torch.LongTensor:
- """
- Returns the tokenized representation of given input string(s)
-
- Parameters
- ----------
- texts : Union[str, List[str]]
- An input string or a list of input strings to tokenize
- context_length : int
- The context length to use; all CLIP models use 77 as the context length
-
- Returns
- -------
- A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
- """
- if isinstance(texts, str):
- texts = [texts]
-
-    sot_token = _tokenizer.encoder["<start_of_text>"]
-    eot_token = _tokenizer.encoder["<end_of_text>"]
- all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
- result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
-
- for i, tokens in enumerate(all_tokens):
- if len(tokens) > context_length:
- tokens = tokens[:context_length] # Truncate
- result[i, : len(tokens)] = torch.tensor(tokens)
-
- return result
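
A short usage sketch for the CLIP-style tokenizer removed above; it assumes the bundled `bpe_simple_vocab_16e6.txt.gz` sits next to the module, as `default_bpe()` expects.

```python
# Batch-tokenize captions to the fixed 77-token CLIP context length.
tokens = tokenize(["a dog barking", "rain on a tin roof"])
print(tokens.shape)                  # torch.Size([2, 77])

ids = _tokenizer.encode("a dog barking")
print(_tokenizer.decode(ids))        # roughly "a dog barking"
```
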
diff --git a/audioldm/clap/open_clip/transform.py b/audioldm/clap/open_clip/transform.py
deleted file mode 100644
index 77aaa722c4a5544ac50de6df35d3e922f63b111d..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/transform.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from torchvision.transforms import (
- Normalize,
- Compose,
- RandomResizedCrop,
- InterpolationMode,
- ToTensor,
- Resize,
- CenterCrop,
-)
-
-
-def _convert_to_rgb(image):
- return image.convert("RGB")
-
-
-def image_transform(
- image_size: int,
- is_train: bool,
- mean=(0.48145466, 0.4578275, 0.40821073),
- std=(0.26862954, 0.26130258, 0.27577711),
-):
- normalize = Normalize(mean=mean, std=std)
- if is_train:
- return Compose(
- [
- RandomResizedCrop(
- image_size,
- scale=(0.9, 1.0),
- interpolation=InterpolationMode.BICUBIC,
- ),
- _convert_to_rgb,
- ToTensor(),
- normalize,
- ]
- )
- else:
- return Compose(
- [
- Resize(image_size, interpolation=InterpolationMode.BICUBIC),
- CenterCrop(image_size),
- _convert_to_rgb,
- ToTensor(),
- normalize,
- ]
- )
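
Likewise, a small sketch of the deleted `image_transform` helper, which builds the usual CLIP preprocessing pipelines; the blank PIL image is only a stand-in for real data.

```python
from PIL import Image

train_tf = image_transform(image_size=224, is_train=True)   # random resized crop
eval_tf = image_transform(image_size=224, is_train=False)   # resize + center crop

img = Image.new("RGB", (320, 240))   # placeholder image
x = eval_tf(img)                     # normalized tensor of shape (3, 224, 224)
```
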
diff --git a/audioldm/clap/open_clip/utils.py b/audioldm/clap/open_clip/utils.py
deleted file mode 100644
index de59fd2746a13742197ecdeac671d61ece3f79ba..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/utils.py
+++ /dev/null
@@ -1,361 +0,0 @@
-import numpy as np
-import torch
-from torch import nn as nn
-from torchvision.ops.misc import FrozenBatchNorm2d
-import logging
-# import h5py
-from tqdm import tqdm
-import random
-import json
-import os
-import pathlib
-
-# TODO: (yusong) this is not a good place to store this information and it does not scale. Needs to be fixed later.
-dataset_split = {
- "audiocaps": ["train", "valid", "test"],
- "audioset": ["balanced_train", "unbalanced_train", "eval"],
- "BBCSoundEffects": ["train", "test"],
- "Clotho": ["train", "test", "valid"],
- "free_to_use_sounds": ["train", "test"],
- "paramount_motion": ["train", "test"],
- "sonniss_game_effects": ["train", "test"],
- "wesoundeffects": ["train", "test"],
- "MACS": ["train", "test"],
- "freesound": ["train", "test"],
- "FSD50K": ["train", "test", "valid"],
- "fsd50k_class_label": ["train", "test", "valid"],
- "esc50": ["train", "test"],
- "audiostock": ["train", "test"],
- "freesound_no_overlap_noesc50": ["train", "test"],
- "epidemic_sound_effects": ["train", "test"],
- "VGGSound": ["train", "test"],
- "urbansound8k_class_label": ["train", "test"],
- "audioset_t5": ["balanced_train", "unbalanced_train", "eval"],
- "epidemic_sound_effects_t5": ["train", "test"],
- "WavText5K": ["train", "test"],
- "esc50_no_overlap": ["train", "test"],
- "usd8k_no_overlap": ["train", "test"],
- "fsd50k_200_class_label": ["train", "test", "valid"],
-}
-
-
-def freeze_batch_norm_2d(module, module_match={}, name=""):
- """
- Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
- itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
- returned. Otherwise, the module is walked recursively and submodules are converted in place.
-
- Args:
- module (torch.nn.Module): Any PyTorch module.
- module_match (dict): Dictionary of full module names to freeze (all if empty)
- name (str): Full module name (prefix)
-
- Returns:
- torch.nn.Module: Resulting module
-
- Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
- """
- res = module
- is_match = True
- if module_match:
- is_match = name in module_match
- if is_match and isinstance(
- module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)
- ):
- res = FrozenBatchNorm2d(module.num_features)
- res.num_features = module.num_features
- res.affine = module.affine
- if module.affine:
- res.weight.data = module.weight.data.clone().detach()
- res.bias.data = module.bias.data.clone().detach()
- res.running_mean.data = module.running_mean.data
- res.running_var.data = module.running_var.data
- res.eps = module.eps
- else:
- for child_name, child in module.named_children():
- full_child_name = ".".join([name, child_name]) if name else child_name
- new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
- if new_child is not child:
- res.add_module(child_name, new_child)
- return res
-
-
-def exist(dataset_name, dataset_type):
- """
- Check if dataset exists
- """
- if dataset_type in dataset_split[dataset_name]:
- return True
- else:
- return False
-
-
-def get_tar_path_from_dataset_name(
- dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None
-):
- """
- Get tar path from dataset name and type
- """
- output = []
- for n in dataset_names:
- if full_dataset is not None and n in full_dataset:
- current_dataset_types = dataset_split[n]
- else:
- current_dataset_types = dataset_types
- for s in current_dataset_types:
- tmp = []
- if islocal:
- sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json"
- if not os.path.exists(sizefilepath_):
- sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
- else:
- sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
- if not os.path.exists(sizefilepath_):
- continue
- sizes = json.load(open(sizefilepath_, "r"))
- for k in sizes.keys():
- if islocal:
- tmp.append(f"{dataset_path}/{n}/{s}/{k}")
- else:
- tmp.append(
- f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -"
- )
- if proportion != 1:
- tmp = random.sample(tmp, int(proportion * len(tmp)))
- output.append(tmp)
- return sum(output, [])
-
-
-def get_tar_path_from_txts(txt_path, islocal, proportion=1):
- """
- Get tar path from txt path
- """
- if isinstance(txt_path, (list, tuple)):
- return sum(
- [
- get_tar_path_from_txts(
- txt_path[i], islocal=islocal, proportion=proportion
- )
- for i in range(len(txt_path))
- ],
- [],
- )
- if isinstance(txt_path, str):
- with open(txt_path) as f:
- lines = f.readlines()
- if islocal:
- lines = [
- lines[i]
- .split("\n")[0]
- .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/")
- for i in range(len(lines))
- ]
- else:
- lines = [
- lines[i].split("\n")[0].replace(".tar", ".tar -")
- for i in range(len(lines))
- ]
- if proportion != 1:
- print("Sampling tars with proportion of {}".format(proportion))
- lines = random.sample(lines, int(proportion * len(lines)))
- return lines
-
-
-def get_mix_lambda(mixup_alpha, batch_size):
- mixup_lambdas = [
- np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size)
- ]
- return np.array(mixup_lambdas).astype(np.float32)
-
-
-def do_mixup(x, mixup_lambda):
- """
- Args:
- x: (batch_size , ...)
- mixup_lambda: (batch_size,)
- Returns:
- out: (batch_size, ...)
- """
- out = (
- x.transpose(0, -1) * mixup_lambda
- + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)
- ).transpose(0, -1)
- return out
-
-
-def interpolate(x, ratio):
-    """Interpolate data in the time domain. This is used to compensate for the
-    resolution reduction caused by downsampling in a CNN.
-
- Args:
- x: (batch_size, time_steps, classes_num)
- ratio: int, ratio to interpolate
- Returns:
- upsampled: (batch_size, time_steps * ratio, classes_num)
- """
- (batch_size, time_steps, classes_num) = x.shape
- upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
- upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
- return upsampled
-
-
-def pad_framewise_output(framewise_output, frames_num):
- """Pad framewise_output to the same length as input frames. The pad value
- is the same as the value of the last frame.
- Args:
- framewise_output: (batch_size, frames_num, classes_num)
- frames_num: int, number of frames to pad
- Outputs:
- output: (batch_size, frames_num, classes_num)
- """
- pad = framewise_output[:, -1:, :].repeat(
- 1, frames_num - framewise_output.shape[1], 1
- )
- """tensor for padding"""
-
-    output = torch.cat((framewise_output, pad), dim=1)
-    """(batch_size, frames_num, classes_num)"""
-
-    return output
-
-
-# def process_ipc(index_path, classes_num, filename):
-# # load data
-# logging.info("Load Data...............")
-# ipc = [[] for _ in range(classes_num)]
-# with h5py.File(index_path, "r") as f:
-# for i in tqdm(range(len(f["target"]))):
-# t_class = np.where(f["target"][i])[0]
-# for t in t_class:
-# ipc[t].append(i)
-# print(ipc)
-# np.save(filename, ipc)
-# logging.info("Load Data Succeed...............")
-
-
-def save_to_dict(s, o_=None):
-    # avoid a shared mutable default argument
-    if o_ is None:
-        o_ = {}
-    sp = s.split(": ")
-    o_.update({sp[0]: float(sp[1])})
-    return o_
-
-
-def get_data_from_log(txt_path):
- """
- Output dictionary from out.txt log file
- """
- with open(txt_path) as f:
- lines = f.readlines()
- val_data = {}
- train_data = {}
- train_losses = []
- train_losses_epoch = []
- for i in range(len(lines)):
- if "| INFO |" in lines[i]:
- if "Eval Epoch" in lines[i]:
- if "val_loss" in lines[i]:
- # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", ""))
- line = lines[i].split("Eval Epoch: ")[-1]
- num_epoch = int(line.split(" ")[0].split(" ")[0])
- d = {
- line.split(" ")[0]
- .split(" ")[1]
- .replace(":", ""): float(line.split(" ")[0].split(" ")[-1])
- }
- for i in range(1, len(line.split(" "))):
- d = save_to_dict(line.split(" ")[i], d)
- val_data[num_epoch] = d
- elif "Train Epoch" in lines[i]:
- num_epoch = int(lines[i].split("Train Epoch: ")[1][0])
- loss = float(lines[i].split("Loss: ")[-1].split(" (")[0])
- train_losses.append(loss)
- train_losses_epoch.append(num_epoch)
- for i in range(len(train_losses)):
- train_data[i] = {
- "num_epoch": train_losses_epoch[i],
- "train_loss": train_losses[i],
- }
- return train_data, val_data
-
-
-def save_p(obj, filename):
- import pickle
-
- try:
- from deepdiff import DeepDiff
- except:
- os.system("pip install deepdiff")
- from deepdiff import DeepDiff
- with open(filename, "wb") as file:
- pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol
- with open(filename, "rb") as file:
- z = pickle.load(file)
- assert (
- DeepDiff(obj, z, ignore_string_case=True) == {}
- ), "there is something wrong with the saving process"
- return
-
-
-def load_p(filename):
- import pickle
-
- with open(filename, "rb") as file:
- z = pickle.load(file)
- return z
-
-
-def save_json(data, name="data.json"):
- import json
-
- with open(name, "w") as fp:
- json.dump(data, fp)
- return
-
-
-def load_json(name):
- import json
-
- with open(name, "r") as fp:
- data = json.load(fp)
- return data
-
-
-from multiprocessing import Process, Manager
-from multiprocessing import Process, Value, Array
-from ctypes import c_wchar
-
-
-def load_class_label(path):
- # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
- # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
- out = None
- if path is not None:
- if pathlib.Path(path).suffix in [".pkl", ".pickle"]:
- out = load_p(path)
- elif pathlib.Path(path).suffix in [".json", ".txt"]:
- out = load_json(path)
- elif pathlib.Path(path).suffix in [".npy", ".npz"]:
- out = np.load(path)
- elif pathlib.Path(path).suffix in [".csv"]:
- import pandas as pd
-
- out = pd.read_csv(path)
- return out
- # if out is None:
- # return None
- # else:
- # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False)
- # val = Array('i', out.values(), lock=False)
- # return (key, val)
-
-
-from torch import optim
-
-
-def get_optimizer(params, lr, betas, eps, momentum, optimizer_name):
- if optimizer_name.lower() == "adamw":
- optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps)
- elif optimizer_name.lower() == "sgd":
- optimizer = optim.SGD(params, lr=lr, momentum=momentum)
- elif optimizer_name.lower() == "adam":
- optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps)
- else:
- raise ValueError("optimizer name is not correct")
- return optimizer
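
For reference, a hedged sketch exercising a few of the tensor helpers from the deleted `utils.py` on dummy data; shapes follow the docstrings above and the hyperparameters are placeholders.

```python
import torch

x = torch.randn(4, 64, 128)                       # (batch, time, features)
lam = torch.from_numpy(get_mix_lambda(0.5, 4))    # one mixup lambda per sample
mixed = do_mixup(x, lam)                          # same shape as x

frames = torch.randn(4, 32, 527)                  # (batch, time_steps, classes_num)
upsampled = interpolate(frames, ratio=32)         # -> (4, 1024, 527)

head = torch.nn.Linear(8, 2)
opt = get_optimizer(head.parameters(), lr=1e-4, betas=(0.9, 0.999),
                    eps=1e-8, momentum=0.9, optimizer_name="adamw")
```
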
diff --git a/audioldm/clap/open_clip/version.py b/audioldm/clap/open_clip/version.py
deleted file mode 100644
index 3ced3581bb601ae91b1e1da4b8f4f520855a065e..0000000000000000000000000000000000000000
--- a/audioldm/clap/open_clip/version.py
+++ /dev/null
@@ -1 +0,0 @@
-__version__ = "0.2.1"
diff --git a/audioldm/clap/training/__init__.py b/audioldm/clap/training/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/audioldm/clap/training/audioset_textmap.npy b/audioldm/clap/training/audioset_textmap.npy
deleted file mode 100644
index 3da4c92d3819aaec11e5f576464a9973a6df811b..0000000000000000000000000000000000000000
--- a/audioldm/clap/training/audioset_textmap.npy
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:bada103070d92f9eadd33e1b4f45ec8583f59080ef218c966b43294bd4c86d5b
-size 84448
diff --git a/audioldm/clap/training/data.py b/audioldm/clap/training/data.py
deleted file mode 100644
index 1d80d598be97d4e04f1b7f3e53a877cfe82ce667..0000000000000000000000000000000000000000
--- a/audioldm/clap/training/data.py
+++ /dev/null
@@ -1,977 +0,0 @@
-import ast
-import json
-import logging
-import math
-import os
-import random
-# import h5py
-from dataclasses import dataclass
-from audioldm.clap.training.params import parse_args
-# import braceexpand
-import numpy as np
-import pandas as pd
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision.datasets as datasets
-import torchvision.transforms
-# import webdataset as wds
-from PIL import Image
-from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
-from torch.utils.data.distributed import DistributedSampler
-from functools import partial
-import soundfile as sf
-import io
-from pathlib import Path
-# import wget
-
-from audioldm.clap.open_clip.utils import (
- get_tar_path_from_dataset_name,
- dataset_split,
-)
-from audioldm.clap.open_clip.utils import load_p, load_class_label
-import copy
-
-try:
- import horovod.torch as hvd
-except ImportError:
- hvd = None
-
-try:
- import torchaudio
-except ImportError:
- torchaudio = None
-
-from audioldm.clap.open_clip import tokenize
-
-
-def tokenizer(text):
- return tokenize(text).squeeze(0)
-
-
-from transformers import RobertaTokenizer
-
-tokenize = RobertaTokenizer.from_pretrained("roberta-base")
-
-
-def tokenizer(text):
- result = tokenize(
- text,
- padding="max_length",
- truncation=True,
- max_length=77,
- return_tensors="pt",
- )
- return {k: v.squeeze(0) for k, v in result.items()}
-
-
-# initialize the audioset map
-_AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy")
-_AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True)
-
-
-def int16_to_float32(x):
- return (x / 32767.0).astype(np.float32)
-
-
-def float32_to_int16(x):
- x = np.clip(x, a_min=-1.0, a_max=1.0)
- return (x * 32767.0).astype(np.int16)
-
-
-# For Toy Dataset
-# class ToyDataset(Dataset):
-# def __init__(self, index_path, ipc, config, eval_mode=False):
-# """Toy Dataset for testing the audioset input with text labels
-# Parameters
-# ----------
-# index_path: str
-# the link to the h5 file of each audio
-# idc: str
-# the link to the npy file, the number of samples in each class
-# config: dict
-# the audio cfg file
-# eval_model (bool): to indicate if the dataset is a testing dataset
-# """
-# self.audio_cfg = config["audio_cfg"]
-# self.text_cfg = config["text_cfg"]
-# self.fp = h5py.File(index_path, "r")
-# self.ipc = np.load(ipc, allow_pickle=True)
-# self.total_size = len(self.fp["audio_name"])
-# self.classes_num = self.audio_cfg["class_num"]
-# self.eval_mode = eval_mode
-
-# if not eval_mode:
-# self.generate_queue()
-# else:
-# self.queue = []
-# for i in range(self.total_size):
-# target = self.fp["target"][i]
-# if np.sum(target) > 0:
-# self.queue.append(i)
-# self.total_size = len(self.queue)
-# logging.info("total dataset size: %d" % (self.total_size))
-# logging.info("class num: %d" % (self.classes_num))
-
-# def time_shifting(self, x):
-# frame_num = len(x)
-# shift_len = random.randint(0, frame_num - 1)
-# new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0)
-# return new_sample
-
-# def generate_queue(self):
-# self.queue = []
-# while len(self.queue) < self.total_size:
-# class_set = [*range(self.classes_num)]
-# random.shuffle(class_set)
-# self.queue += [
-# self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set
-# ]
-# self.queue = self.queue[: self.total_size]
-
-# logging.info("queue regenerated:%s" % (self.queue[-5:]))
-
-# def crop_wav(self, x):
-# crop_size = self.audio_cfg["crop_size"]
-# crop_pos = random.randint(0, len(x) - crop_size - 1)
-# return x[crop_pos : crop_pos + crop_size]
-
-# def prompt_text(self, target):
-# events = _AUDIOSET_MAP[np.where(target > 0)]
-# event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1]
-# text = tokenize(event_text)[0]
-# return text
-
-# def __getitem__(self, index):
-# """Load waveform, text, and target of an audio clip
-
-# Parameters
-# ----------
-# index: int
-# the index number
-# Return
-# ------
-# output: dict {
-# "hdf5_path": str,
-# "index_in_hdf5": int,
-# "audio_name": str,
-# "waveform": list (audio_length,),
-# "target": list (class_num, ),
-# "text": torch.tensor (context_length,)
-# }
-# the output dictionary
-# """
-# s_index = self.queue[index]
-
-# audio_name = self.fp["audio_name"][s_index].decode()
-# # Hardcode here CHANGE
-# hdf5_path = (
-# self.fp["hdf5_path"][s_index]
-# .decode()
-# .replace(
-# "../workspace",
-# "/home/la/kechen/Research/ke_zsasp/workspace",
-# )
-# )
-# r_idx = self.fp["index_in_hdf5"][s_index]
-# target = self.fp["target"][s_index].astype(np.float32)
-# text = self.prompt_text(target)
-# with h5py.File(hdf5_path, "r") as f:
-# waveform = int16_to_float32(f["waveform"][r_idx])[
-# : self.audio_cfg["clip_samples"]
-# ]
-# assert (
-# len(waveform) == self.audio_cfg["clip_samples"]
-# ), "The sample length is not match"
-# # Time shift
-# # if (self.config.enable_time_shift) and (not self.eval_mode):
-# # waveform = self.time_shifting(waveform)
-# # # Label Enhance
-# # if (self.config.crop_size is not None) and (not self.eval_mode):
-# # waveform = self.crop_wav(waveform)
-# # # the label enhance rate is fixed 0.5
-# # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5:
-# # kidx = np.where(target)[0]
-# # for k in kidx:
-# # for add_key in self.class_map[k][1]:
-# # target[add_key] = 1.0
-# # if len(self.class_map[k][2]) > 0:
-# # add_key = random.choice(self.class_map[k][2])
-# # target[add_key] = 1.0
-
-# # missing the text input
-# mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :]
-# mel_spec = (
-# torch.cat(
-# [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0
-# )
-# .cpu()
-# .numpy()
-# )
-# longer = random.choice([True, False])
-# if longer == False:
-# mel_spec[1:, :, :] = 0.0
-# data_dict = {
-# "hdf5_path": hdf5_path,
-# "index_in_hdf5": r_idx,
-# "audio_name": audio_name,
-# "waveform": waveform,
-# "class_label": target,
-# "text": text,
-# "longer": longer,
-# "mel_fusion": mel_spec,
-# }
-# return data_dict
-
-# def __len__(self):
-# return self.total_size
-
-
-class CsvDataset(Dataset):
- def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
- logging.debug(f"Loading csv data from {input_filename}.")
- df = pd.read_csv(input_filename, sep=sep)
-
- self.images = df[img_key].tolist()
- self.captions = df[caption_key].tolist()
- self.transforms = transforms
- logging.debug("Done loading data.")
-
- def __len__(self):
- return len(self.captions)
-
- def __getitem__(self, idx):
- images = self.transforms(Image.open(str(self.images[idx])))
- texts = tokenize([str(self.captions[idx])])[0]
- return images, texts
-
-
-@dataclass
-class DataInfo:
- dataloader: DataLoader
- sampler: DistributedSampler
-
-
-def preprocess_txt(text):
- return tokenize([str(text)])[0]
-
-
-def get_dataset_size(shards, sizefilepath_=None, is_local=True):
- if isinstance(shards, list):
- size_list = []
- for s in shards:
- size_list.append(
- get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0]
- )
- else:
- if not is_local:
- for n in dataset_split.keys():
- if n in shards.split("/"):
- break
- for s in dataset_split[n]:
- if s in shards.split("/"):
- break
- sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
- shards_list = list(braceexpand.braceexpand(shards))
- dir_path = os.path.dirname(shards)
- if sizefilepath_ is not None:
- sizes = json.load(open(sizefilepath_, "r"))
- total_size = sum(
- [
- int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))])
- for shard in shards_list
- ]
- )
- else:
- sizes_filename = os.path.join(dir_path, "sizes.json")
- len_filename = os.path.join(dir_path, "__len__")
- if os.path.exists(sizes_filename):
- sizes = json.load(open(sizes_filename, "r"))
- total_size = sum(
- [int(sizes[os.path.basename(shard)]) for shard in shards_list]
- )
- elif os.path.exists(len_filename):
- # FIXME this used to be eval(open(...)) but that seemed rather unsafe
- total_size = ast.literal_eval(open(len_filename, "r").read())
- else:
- raise Exception(
- "Cannot find sizes file for dataset. Please specify the path to the file."
- )
- # total_size = None # num samples undefined
- # some common dataset sizes (at time of authors last download)
- # cc3m-train: 2905954
- # cc12m: 10968539
- # LAION-400m: 407332084
- num_shards = len(shards_list)
- if isinstance(shards, list):
- return sum(size_list), len(shards)
- else:
- return total_size, num_shards
-
-
-def get_imagenet(args, preprocess_fns, split):
- assert split in ["train", "val", "v2"]
- is_train = split == "train"
- preprocess_train, preprocess_val = preprocess_fns
-
- if split == "v2":
- from imagenetv2_pytorch import ImageNetV2Dataset
-
- dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
- else:
- if is_train:
- data_path = args.imagenet_train
- preprocess_fn = preprocess_train
- else:
- data_path = args.imagenet_val
- preprocess_fn = preprocess_val
- assert data_path
-
- dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
-
- if is_train:
- idxs = np.zeros(len(dataset.targets))
- target_array = np.array(dataset.targets)
- k = 50
- for c in range(1000):
- m = target_array == c
- n = len(idxs[m])
- arr = np.zeros(n)
- arr[:k] = 1
- np.random.shuffle(arr)
- idxs[m] = arr
-
- idxs = idxs.astype("int")
- sampler = SubsetRandomSampler(np.where(idxs)[0])
- else:
- sampler = None
-
- dataloader = torch.utils.data.DataLoader(
- dataset,
- batch_size=args.batch_size,
- num_workers=args.workers,
- sampler=sampler,
- )
-
- return DataInfo(dataloader, sampler)
-
-
-def count_samples(dataloader):
- os.environ["WDS_EPOCH"] = "0"
- n_elements, n_batches = 0, 0
- for images, texts in dataloader:
- n_batches += 1
- n_elements += len(images)
- assert len(images) == len(texts)
- return n_elements, n_batches
-
-
-def filter_no_caption(sample):
- return "txt" in sample
-
-
-def log_and_continue(exn):
-    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
- logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
- return True
-
-
-_SHARD_SHUFFLE_SIZE = 2000
-_SHARD_SHUFFLE_INITIAL = 500
-_SAMPLE_SHUFFLE_SIZE = 5000
-_SAMPLE_SHUFFLE_INITIAL = 1000
-
-
-def sample_prop(sizefile, inputs, proportion, is_local=True):
- """
- Sample a proportion of the data.
- """
- file_path_dict = {
- os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0]
- for i in range(len(inputs))
- }
- sampled_filepath_dict = {}
- sampled_size_dict = {}
- if not is_local:
- if os.path.exists("sizes.json"):
- os.remove("sizes.json")
- wget.download(sizefile, "sizes.json")
- sizefile = "sizes.json"
- with open(sizefile, "r", encoding="UTF-8") as f:
- load_dict = json.load(f)
- L = int(len(file_path_dict) * proportion)
- subkeys = random.sample(file_path_dict.keys(), L)
- for k in subkeys:
- sampled_size_dict[k] = load_dict[k]
- sampled_filepath_dict[k] = file_path_dict[k]
- return (
- sum(sampled_size_dict.values()),
- L,
- [os.path.join(v, k) for k, v in sampled_filepath_dict.items()],
- sampled_size_dict,
- )
-
-
-def get_mel(audio_data, audio_cfg):
- # mel shape: (n_mels, T)
- mel = torchaudio.transforms.MelSpectrogram(
- sample_rate=audio_cfg["sample_rate"],
- n_fft=audio_cfg["window_size"],
- win_length=audio_cfg["window_size"],
- hop_length=audio_cfg["hop_size"],
- center=True,
- pad_mode="reflect",
- power=2.0,
- norm=None,
- onesided=True,
- n_mels=64,
- f_min=audio_cfg["fmin"],
- f_max=audio_cfg["fmax"],
- ).to(audio_data.device)
- mel = mel(audio_data)
- # Align to librosa:
- # librosa_melspec = librosa.feature.melspectrogram(
- # waveform,
- # sr=audio_cfg['sample_rate'],
- # n_fft=audio_cfg['window_size'],
- # hop_length=audio_cfg['hop_size'],
- # win_length=audio_cfg['window_size'],
- # center=True,
- # pad_mode="reflect",
- # power=2.0,
- # n_mels=64,
- # norm=None,
- # htk=True,
- # f_min=audio_cfg['fmin'],
- # f_max=audio_cfg['fmax']
- # )
- # we use log mel spectrogram as input
- mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
- return mel.T # (T, n_mels)
-
-
-def get_audio_features(
- sample, audio_data, max_len, data_truncating, data_filling, audio_cfg
-):
- """
- Calculate and add audio features to sample.
-    sample: a dict containing all the data of the current sample.
- audio_data: a tensor of shape (T) containing audio data.
- max_len: the maximum length of audio data.
- data_truncating: the method of truncating data.
- data_filling: the method of filling data.
- audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg'].
- """
- with torch.no_grad():
- if len(audio_data) > max_len:
- if data_truncating == "rand_trunc":
- longer = torch.tensor([True])
- elif data_truncating == "fusion":
- # fusion
- mel = get_mel(audio_data, audio_cfg)
- # split to three parts
- chunk_frames = (
- max_len // audio_cfg["hop_size"] + 1
-                ) # the +1 is related to how the spectrogram is computed
- total_frames = mel.shape[0]
- if chunk_frames == total_frames:
- # there is a corner case where the audio length is
- # larger than max_len but smaller than max_len+hop_size.
- # In this case, we just use the whole audio.
- mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
- sample["mel_fusion"] = mel_fusion
- longer = torch.tensor([False])
- else:
- ranges = np.array_split(
- list(range(0, total_frames - chunk_frames + 1)), 3
- )
- # print('total_frames-chunk_frames:', total_frames-chunk_frames,
- # 'len(audio_data):', len(audio_data),
- # 'chunk_frames:', chunk_frames,
- # 'total_frames:', total_frames)
- if len(ranges[1]) == 0:
- # if the audio is too short, we just use the first chunk
- ranges[1] = [0]
- if len(ranges[2]) == 0:
- # if the audio is too short, we just use the first chunk
- ranges[2] = [0]
- # randomly choose index for each part
- idx_front = np.random.choice(ranges[0])
- idx_middle = np.random.choice(ranges[1])
- idx_back = np.random.choice(ranges[2])
- # select mel
- mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :]
- mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :]
- mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :]
-
- # shrink the mel
- mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, 64])(
- mel[None]
- )[0]
- # logging.info(f"mel_shrink.shape: {mel_shrink.shape}")
-
- # stack
- mel_fusion = torch.stack(
- [mel_chunk_front, mel_chunk_middle, mel_chunk_back, mel_shrink],
- dim=0,
- )
- sample["mel_fusion"] = mel_fusion
- longer = torch.tensor([True])
- else:
- raise NotImplementedError(
- f"data_truncating {data_truncating} not implemented"
- )
- # random crop to max_len (for compatibility)
- overflow = len(audio_data) - max_len
- idx = np.random.randint(0, overflow + 1)
- audio_data = audio_data[idx : idx + max_len]
-
- else: # padding if too short
- if len(audio_data) < max_len: # do nothing if equal
- if data_filling == "repeatpad":
- n_repeat = int(max_len / len(audio_data))
- audio_data = audio_data.repeat(n_repeat)
- # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0)
- # audio_data = F.interpolate(audio_data,size=max_len,mode="bicubic")[0,0,0]
- audio_data = F.pad(
- audio_data,
- (0, max_len - len(audio_data)),
- mode="constant",
- value=0,
- )
- elif data_filling == "pad":
- audio_data = F.pad(
- audio_data,
- (0, max_len - len(audio_data)),
- mode="constant",
- value=0,
- )
- elif data_filling == "repeat":
- n_repeat = int(max_len / len(audio_data))
- audio_data = audio_data.repeat(n_repeat + 1)[:max_len]
- else:
- raise NotImplementedError(
- f"data_filling {data_filling} not implemented"
- )
- if data_truncating == "fusion":
- mel = get_mel(audio_data, audio_cfg)
- mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
- sample["mel_fusion"] = mel_fusion
- longer = torch.tensor([False])
-
- sample["longer"] = longer
- sample["waveform"] = audio_data
-
- return sample
-
-
-def preprocess(
- sample,
- audio_ext,
- text_ext,
- max_len,
- audio_cfg,
- class_index_dict=None,
- data_filling="pad",
- data_truncating="rand_trunc",
- text_augment_selection=None,
-):
- """
- Preprocess a single sample for wdsdataloader.
- """
- audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
- audio_data = int16_to_float32(float32_to_int16(audio_data))
- audio_data = torch.tensor(audio_data).float()
-
- # TODO: (yusong) to be include in the future
- # # if torchaudio not installed, use soundfile to load audio
- # if torchaudio is None:
- # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
- # audio_data = torch.tensor(audio_data).float()
- # else:
- # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py
- # with tempfile.TemporaryDirectory() as dirname:
- # os.makedirs(dirname, exist_ok=True)
- # fname = os.path.join(dirname, f"file.flac")
- # with open(fname, "wb") as stream:
- # stream.write(sample[audio_ext])
- # audio_data, orig_sr = torchaudio.load(fname)
- # audio_data = audio_data[0, :].float()
-
- sample = get_audio_features(
- sample, audio_data, max_len, data_truncating, data_filling, audio_cfg
- )
- del sample[audio_ext]
-
- try:
- json_dict_raw = json.loads(sample[text_ext].decode("utf-8"))
- except:
- print("sample[__url__]:", sample["__url__"])
-
- # For selecting augmented text from dataset
- if text_augment_selection is None or text_augment_selection == "none":
- texts = json_dict_raw["text"]
- elif text_augment_selection == "all":
- if "text_augment_all" in json_dict_raw.keys():
- texts = json_dict_raw["text_augment_all"]
- else:
- texts = json_dict_raw["text"]
- elif text_augment_selection == "augment_only":
- if "text_augment_all" in json_dict_raw.keys():
- if json_dict_raw["text_augment_t5"] is None:
- texts = json_dict_raw["text"]
- else:
- texts = json_dict_raw["text_augment_t5"]
- else:
- texts = json_dict_raw["text"]
- else:
- raise NotImplementedError(
- f"text_augment_selection {text_augment_selection} not implemented"
- )
- sample["full_text"] = texts
-
- if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1:
- texts = random.choice(texts)
- sample["raw_text"] = texts
- sample["text"] = tokenizer(texts) # text shape: [num_token]
- if class_index_dict is not None:
- # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
- # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
- # key, val = class_index_dict
- # key = key[:].split('\n')
- # _dict = {k: v for k, v in zip(key, val)}
- sample["class_label"] = np.zeros(len(class_index_dict.keys()))
- for x in json_dict_raw["tag"]:
- sample["class_label"][class_index_dict[x]] = 1
- sample["class_label"] = torch.tensor(sample["class_label"]).float()
- del sample[text_ext]
- sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext
- sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext
- sample["audio_orig_sr"] = orig_sr
- return sample
-
-
-def collate_fn(batch):
- """
- Collate function for wdsdataloader.
- batch: a list of dict, each dict is a sample
- """
-    # merge each field across the batch: stack tensors, convert numpy arrays to stacked tensors, vstack tokenizer dicts, and collect everything else into lists
- batch_dict = {}
- for k in batch[0].keys():
-        if isinstance(batch[0][k], dict):  # deal with bert tokenizer output
- batch_dict[k] = {}
- for kk in batch[0][k].keys():
- tmp = []
- for i in range(len(batch)):
- tmp.append(batch[i][k][kk])
- batch_dict[k][kk] = torch.vstack(tmp)
- elif isinstance(batch[0][k], torch.Tensor):
- batch_dict[k] = torch.stack([sample[k] for sample in batch])
- elif isinstance(batch[0][k], np.ndarray):
- batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch]))
- else:
- batch_dict[k] = [sample[k] for sample in batch]
- return batch_dict
-
-
-def get_wds_dataset(
- args,
- model_cfg,
- is_train,
- audio_ext="flac",
- text_ext="json",
- max_len=480000,
- proportion=1.0,
- sizefilepath_=None,
- is_local=None,
-):
- """
- Get a dataset for wdsdataloader.
- """
-    if is_local is None and args.remotedata is not None:
- is_local = not args.remotedata
-
- input_shards = args.train_data if is_train else args.val_data
- assert input_shards is not None
-
-    if sizefilepath_ is not None:
- sizefilepath = sizefilepath_
- else:
- sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json")
-
- if proportion != 1.0:
- num_samples, num_shards, input_shards, _ = sample_prop(
- sizefilepath, input_shards, proportion, is_local=is_local
- )
- else:
- num_samples, num_shards = get_dataset_size(
- input_shards, sizefilepath_=sizefilepath_, is_local=is_local
- )
-
- if not num_samples:
- if is_train:
- num_samples = args.train_num_samples
- if not num_samples:
- raise RuntimeError(
- "Currently, number of dataset samples must be specified for training dataset. "
- "Please specify via `--train-num-samples` if no dataset length info present."
- )
- else:
- num_samples = (
- args.val_num_samples or 0
- ) # eval will just exhaust the iterator if not specified
-
- pipeline = [wds.SimpleShardList(input_shards)]
- # at this point we have an iterator over all the shards
-    # TODO: (yusong): add an if statement for distributed mode; if not distributed, we don't need split_by_node
- if is_train or args.parallel_eval:
- pipeline.extend(
- [
- wds.detshuffle(
- bufsize=_SHARD_SHUFFLE_SIZE,
- initial=_SHARD_SHUFFLE_INITIAL,
- seed=args.seed,
- ),
- wds.split_by_node,
- wds.split_by_worker,
- # at this point, we have an iterator over the shards assigned to each worker at each node
- wds.tarfile_to_samples(handler=log_and_continue),
- wds.shuffle(
- bufsize=_SAMPLE_SHUFFLE_SIZE,
- initial=_SAMPLE_SHUFFLE_INITIAL,
- rng=random.Random(args.seed),
- ),
- # wds.repeatedly, # FIXME determine if this is beneficial
- ]
- )
- else:
- pipeline.extend(
- [
- wds.split_by_worker,
- # at this point, we have an iterator over the shards assigned to each worker
- wds.tarfile_to_samples(handler=log_and_continue),
- ]
- )
- pipeline.append(
- wds.map(
- partial(
- preprocess,
- audio_ext=audio_ext,
- text_ext=text_ext,
- max_len=max_len,
- audio_cfg=model_cfg["audio_cfg"],
- class_index_dict=copy.deepcopy(args.class_index_dict),
- data_filling=args.data_filling,
- data_truncating=args.data_truncating,
- text_augment_selection=args.text_augment_selection,
- )
- ),
- )
-
- pipeline.append(
- wds.batched(
- args.batch_size,
- partial=not (is_train or args.parallel_eval),
- collation_fn=collate_fn,
- )
- )
-
- dataset = wds.DataPipeline(*pipeline)
- if is_train or args.parallel_eval:
-        # (yusong): Currently parallel evaluation is not precise, as we repeat the last few samples.
- # (yusong): See comments below.
- # roll over and repeat a few samples to get same number of full batches on each node
- global_batch_size = args.batch_size * args.world_size
- num_batches = math.ceil(num_samples / global_batch_size)
- num_workers = max(1, args.workers)
- num_worker_batches = math.ceil(
- num_batches / num_workers
- ) # per dataloader worker
- num_batches = num_worker_batches * num_workers
- num_samples = num_batches * global_batch_size
- dataset = dataset.with_epoch(
- num_worker_batches
- ) # each worker is iterating over this
- else:
- # last batches are partial, eval is done on single (master) node
- num_batches = math.ceil(num_samples / args.batch_size)
-
- kwargs = {}
- if args.horovod: # multi-node training on summit
- kwargs["multiprocessing_context"] = "forkserver"
-
- dataloader = wds.WebLoader(
- dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs
- )
-
- # FIXME not clear which approach is better, with_epoch before vs after dataloader?
- # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
- # if is_train:
- # # roll over and repeat a few samples to get same number of full batches on each node
- # global_batch_size = args.batch_size * args.world_size
- # num_batches = math.ceil(num_samples / global_batch_size)
- # num_workers = max(1, args.workers)
- # num_batches = math.ceil(num_batches / num_workers) * num_workers
- # num_samples = num_batches * global_batch_size
- # dataloader = dataloader.with_epoch(num_batches)
- # else:
- # # last batches are partial, eval is done on single (master) node
- # num_batches = math.ceil(num_samples / args.batch_size)
-
- # add meta-data to dataloader instance for convenience
- dataloader.num_batches = num_batches
- dataloader.num_samples = num_samples
-
- return DataInfo(dataloader, None)
-
-
-def wds_batch_list2dict(
- batch,
- keys=[
- "__url__",
- "__key__",
- "waveform",
- "text",
- "raw_text",
- "audio_name",
- "text_name",
- "audio_orig_sr",
- ],
-):
- """
- Return a dictionary of the batch, with keys as the names of the fields.
- """
- assert len(keys) == len(
- batch
- ), "batch must have same number of keys as keys argument"
- return {keys[i]: batch[i] for i in range(len(batch))}
-
-
-def get_csv_dataset(args, preprocess_fn, is_train):
- input_filename = args.train_data if is_train else args.val_data
- assert input_filename
- dataset = CsvDataset(
- input_filename,
- preprocess_fn,
- img_key=args.csv_img_key,
- caption_key=args.csv_caption_key,
- sep=args.csv_separator,
- )
- num_samples = len(dataset)
- sampler = DistributedSampler(dataset) if args.distributed and is_train else None
- shuffle = is_train and sampler is None
-
- dataloader = DataLoader(
- dataset,
- batch_size=args.batch_size,
- shuffle=shuffle,
- num_workers=args.workers,
- pin_memory=True,
- sampler=sampler,
- drop_last=is_train,
- )
- dataloader.num_samples = num_samples
- dataloader.num_batches = len(dataloader)
-
- return DataInfo(dataloader, sampler)
-
-
-def get_toy_dataset(args, model_cfg, is_train):
- index_path = args.train_data if is_train else args.val_data
- ipc_path = args.train_ipc if is_train else args.val_ipc
- assert index_path and ipc_path
- eval_mode = not is_train
- dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode)
-
- num_samples = len(dataset)
- sampler = (
- DistributedSampler(dataset, shuffle=False)
- if args.distributed and is_train
- else None
- )
-
- dataloader = DataLoader(
- dataset,
- batch_size=args.batch_size,
- shuffle=False,
- num_workers=args.workers,
- sampler=sampler,
- drop_last=is_train,
- )
- dataloader.num_samples = num_samples
- dataloader.num_batches = len(dataloader)
-
- return DataInfo(dataloader, sampler)
-
-
-def get_dataset_fn(data_path, dataset_type):
- if dataset_type == "webdataset":
- return get_wds_dataset
- elif dataset_type == "csv":
- return get_csv_dataset
- elif dataset_type == "auto":
- ext = data_path.split(".")[-1]
- if ext in ["csv", "tsv"]:
- return get_csv_dataset
- elif ext in ["tar"]:
- return get_wds_dataset
- else:
- raise ValueError(
-                f"Tried to figure out dataset type, but failed for extension {ext}."
- )
- elif dataset_type == "toy":
- return get_toy_dataset
- else:
- raise ValueError(f"Unsupported dataset type: {dataset_type}")
-
-
-def get_data(args, model_cfg):
- data = {}
-
- args.class_index_dict = load_class_label(args.class_label_path)
-
- if args.datasetinfos is None:
- args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]
- if args.dataset_type == "webdataset":
- args.train_data = get_tar_path_from_dataset_name(
- args.datasetnames,
- args.datasetinfos,
- islocal=not args.remotedata,
- proportion=args.dataset_proportion,
- dataset_path=args.datasetpath,
- full_dataset=args.full_train_dataset,
- )
-
- if args.full_train_dataset is None:
- args.full_train_dataset = []
- if args.exclude_eval_dataset is None:
- args.exclude_eval_dataset = []
- excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset
-
- val_dataset_names = (
- [n for n in args.datasetnames if n not in excluded_eval_datasets]
- if excluded_eval_datasets
- else args.datasetnames
- )
- args.val_dataset_names = val_dataset_names
- args.val_data = get_tar_path_from_dataset_name(
- val_dataset_names,
- ["valid", "test", "eval"],
- islocal=not args.remotedata,
- proportion=1,
- dataset_path=args.datasetpath,
- full_dataset=None,
- )
-
- if args.train_data:
- data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
- args, model_cfg, is_train=True
- )
-
- if args.val_data:
- data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
- args, model_cfg, is_train=False
- )
-
- return data
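
Finally, a small illustrative sketch of the pieces of the deleted `data.py` that do not require webdataset: the int16 round-trip used to normalize audio, the RoBERTa-based `tokenizer` (which downloads `roberta-base` on first use), and `collate_fn` on a toy batch.

```python
import numpy as np
import torch

wav = np.random.uniform(-1.2, 1.2, size=16000).astype(np.float32)
wav_int16 = float32_to_int16(wav)        # clip to [-1, 1] and quantize to int16
wav_back = int16_to_float32(wav_int16)   # back to float32 in [-1, 1]

text = tokenizer("a dog barking next to a waterfall")
print(text["input_ids"].shape)           # torch.Size([77]) after padding/truncation

# collate_fn stacks tensors and vstacks tokenizer dicts across the batch
batch = collate_fn([{"waveform": torch.from_numpy(wav_back), "text": text}])
print(batch["waveform"].shape)           # torch.Size([1, 16000])
```
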
diff --git a/audioldm/clap/training/distributed.py b/audioldm/clap/training/distributed.py
deleted file mode 100644
index 2fa61f76c5cc3ab9f6a9643042afa8e1f2e1cb7f..0000000000000000000000000000000000000000
--- a/audioldm/clap/training/distributed.py
+++ /dev/null
@@ -1,150 +0,0 @@
-import os
-
-import torch
-import socket
-
-try:
- import horovod.torch as hvd
-except ImportError:
- hvd = None
-
-
-def is_global_master(args):
- return args.rank == 0
-
-
-def is_local_master(args):
- return args.local_rank == 0
-
-
-def is_master(args, local=False):
- return is_local_master(args) if local else is_global_master(args)
-
-
-def is_using_horovod():
- # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
- # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
- ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
- pmi_vars = ["PMI_RANK", "PMI_SIZE"]
- if all([var in os.environ for var in ompi_vars]) or all(
- [var in os.environ for var in pmi_vars]
- ):
- return True
- else:
- return False
-
-
-def is_using_distributed():
- if "WORLD_SIZE" in os.environ:
- return int(os.environ["WORLD_SIZE"]) > 1
- if "SLURM_NTASKS" in os.environ:
- return int(os.environ["SLURM_NTASKS"]) > 1
- return False
-
-
-def world_info_from_env():
- local_rank = 0
- for v in (
- "SLURM_LOCALID",
- "MPI_LOCALRANKID",
- "OMPI_COMM_WORLD_LOCAL_RANK",
- "LOCAL_RANK",
- ):
- if v in os.environ:
- local_rank = int(os.environ[v])
- break
- global_rank = 0
- for v in ("SLURM_PROCID", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "RANK"):
- if v in os.environ:
- global_rank = int(os.environ[v])
- break
- world_size = 1
- for v in ("SLURM_NTASKS", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "WORLD_SIZE"):
- if v in os.environ:
- world_size = int(os.environ[v])
- break
-
- return local_rank, global_rank, world_size
-
-
-def init_distributed_device(args):
- # Distributed training = training on more than one GPU.
- # Works in both single and multi-node scenarios.
- args.distributed = False
- args.world_size = 1
- args.rank = 0 # global rank
- args.local_rank = 0
- if args.horovod:
- assert hvd is not None, "Horovod is not installed"
- hvd.init()
- world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
- world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
- local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
- args.local_rank = local_rank
- args.rank = world_rank
- args.world_size = world_size
- # args.local_rank = int(hvd.local_rank())
- # args.rank = hvd.rank()
- # args.world_size = hvd.size()
- args.distributed = True
- os.environ["LOCAL_RANK"] = str(args.local_rank)
- os.environ["RANK"] = str(args.rank)
- os.environ["WORLD_SIZE"] = str(args.world_size)
- print(
- f"Distributed training: local_rank={args.local_rank}, "
- f"rank={args.rank}, world_size={args.world_size}, "
- f"hostname={socket.gethostname()}, pid={os.getpid()}"
- )
- elif is_using_distributed():
- if "SLURM_PROCID" in os.environ:
- # DDP via SLURM
- args.local_rank, args.rank, args.world_size = world_info_from_env()
- # SLURM var -> torch.distributed vars in case needed
- os.environ["LOCAL_RANK"] = str(args.local_rank)
- os.environ["RANK"] = str(args.rank)
- os.environ["WORLD_SIZE"] = str(args.world_size)
- torch.distributed.init_process_group(
- backend=args.dist_backend,
- init_method=args.dist_url,
- world_size=args.world_size,
- rank=args.rank,
- )
- elif "OMPI_COMM_WORLD_SIZE" in os.environ: # using Summit cluster
- world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
- world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
- local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
- args.local_rank = local_rank
- args.rank = world_rank
- args.world_size = world_size
- torch.distributed.init_process_group(
- backend=args.dist_backend,
- init_method=args.dist_url,
- world_size=args.world_size,
- rank=args.rank,
- )
- else:
- # DDP via torchrun, torch.distributed.launch
- args.local_rank, _, _ = world_info_from_env()
- torch.distributed.init_process_group(
- backend=args.dist_backend, init_method=args.dist_url
- )
- args.world_size = torch.distributed.get_world_size()
- args.rank = torch.distributed.get_rank()
- args.distributed = True
- print(
- f"Distributed training: local_rank={args.local_rank}, "
- f"rank={args.rank}, world_size={args.world_size}, "
- f"hostname={socket.gethostname()}, pid={os.getpid()}"
- )
-
- if torch.cuda.is_available():
- if args.distributed and not args.no_set_device_rank:
- device = "cuda:%d" % args.local_rank
- else:
- device = "cuda:0"
- torch.cuda.set_device(device)
- else:
- device = "cpu"
- args.device = device
- device = torch.device(device)
- return device
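The bootstrap above covers Horovod, SLURM, and torchrun. Reduced to the torchrun case only, the logic amounts to the sketch below (assumes the launcher exports `RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `MASTER_ADDR`, and `MASTER_PORT`; illustrative, not the removed module):

```python
# Minimal torchrun-style device / process-group setup, analogous to init_distributed_device.
import os

import torch
import torch.distributed as dist


def init_from_env() -> torch.device:
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    if world_size > 1:
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        # init_method="env://" reads MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE from the environment.
        dist.init_process_group(backend=backend, init_method="env://")

    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        return torch.device(f"cuda:{local_rank}")
    return torch.device("cpu")
```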
diff --git a/audioldm/clap/training/imagenet_zeroshot_data.py b/audioldm/clap/training/imagenet_zeroshot_data.py
deleted file mode 100644
index d32e55328d6799ccb8d61625f43abb80a33d6c17..0000000000000000000000000000000000000000
--- a/audioldm/clap/training/imagenet_zeroshot_data.py
+++ /dev/null
@@ -1,1088 +0,0 @@
-# NOTE: This script is currently not supported for CLAP.
-
-imagenet_classnames = [
- "tench",
- "goldfish",
- "great white shark",
- "tiger shark",
- "hammerhead shark",
- "electric ray",
- "stingray",
- "rooster",
- "hen",
- "ostrich",
- "brambling",
- "goldfinch",
- "house finch",
- "junco",
- "indigo bunting",
- "American robin",
- "bulbul",
- "jay",
- "magpie",
- "chickadee",
- "American dipper",
- "kite (bird of prey)",
- "bald eagle",
- "vulture",
- "great grey owl",
- "fire salamander",
- "smooth newt",
- "newt",
- "spotted salamander",
- "axolotl",
- "American bullfrog",
- "tree frog",
- "tailed frog",
- "loggerhead sea turtle",
- "leatherback sea turtle",
- "mud turtle",
- "terrapin",
- "box turtle",
- "banded gecko",
- "green iguana",
- "Carolina anole",
- "desert grassland whiptail lizard",
- "agama",
- "frilled-necked lizard",
- "alligator lizard",
- "Gila monster",
- "European green lizard",
- "chameleon",
- "Komodo dragon",
- "Nile crocodile",
- "American alligator",
- "triceratops",
- "worm snake",
- "ring-necked snake",
- "eastern hog-nosed snake",
- "smooth green snake",
- "kingsnake",
- "garter snake",
- "water snake",
- "vine snake",
- "night snake",
- "boa constrictor",
- "African rock python",
- "Indian cobra",
- "green mamba",
- "sea snake",
- "Saharan horned viper",
- "eastern diamondback rattlesnake",
- "sidewinder rattlesnake",
- "trilobite",
- "harvestman",
- "scorpion",
- "yellow garden spider",
- "barn spider",
- "European garden spider",
- "southern black widow",
- "tarantula",
- "wolf spider",
- "tick",
- "centipede",
- "black grouse",
- "ptarmigan",
- "ruffed grouse",
- "prairie grouse",
- "peafowl",
- "quail",
- "partridge",
- "african grey parrot",
- "macaw",
- "sulphur-crested cockatoo",
- "lorikeet",
- "coucal",
- "bee eater",
- "hornbill",
- "hummingbird",
- "jacamar",
- "toucan",
- "duck",
- "red-breasted merganser",
- "goose",
- "black swan",
- "tusker",
- "echidna",
- "platypus",
- "wallaby",
- "koala",
- "wombat",
- "jellyfish",
- "sea anemone",
- "brain coral",
- "flatworm",
- "nematode",
- "conch",
- "snail",
- "slug",
- "sea slug",
- "chiton",
- "chambered nautilus",
- "Dungeness crab",
- "rock crab",
- "fiddler crab",
- "red king crab",
- "American lobster",
- "spiny lobster",
- "crayfish",
- "hermit crab",
- "isopod",
- "white stork",
- "black stork",
- "spoonbill",
- "flamingo",
- "little blue heron",
- "great egret",
- "bittern bird",
- "crane bird",
- "limpkin",
- "common gallinule",
- "American coot",
- "bustard",
- "ruddy turnstone",
- "dunlin",
- "common redshank",
- "dowitcher",
- "oystercatcher",
- "pelican",
- "king penguin",
- "albatross",
- "grey whale",
- "killer whale",
- "dugong",
- "sea lion",
- "Chihuahua",
- "Japanese Chin",
- "Maltese",
- "Pekingese",
- "Shih Tzu",
- "King Charles Spaniel",
- "Papillon",
- "toy terrier",
- "Rhodesian Ridgeback",
- "Afghan Hound",
- "Basset Hound",
- "Beagle",
- "Bloodhound",
- "Bluetick Coonhound",
- "Black and Tan Coonhound",
- "Treeing Walker Coonhound",
- "English foxhound",
- "Redbone Coonhound",
- "borzoi",
- "Irish Wolfhound",
- "Italian Greyhound",
- "Whippet",
- "Ibizan Hound",
- "Norwegian Elkhound",
- "Otterhound",
- "Saluki",
- "Scottish Deerhound",
- "Weimaraner",
- "Staffordshire Bull Terrier",
- "American Staffordshire Terrier",
- "Bedlington Terrier",
- "Border Terrier",
- "Kerry Blue Terrier",
- "Irish Terrier",
- "Norfolk Terrier",
- "Norwich Terrier",
- "Yorkshire Terrier",
- "Wire Fox Terrier",
- "Lakeland Terrier",
- "Sealyham Terrier",
- "Airedale Terrier",
- "Cairn Terrier",
- "Australian Terrier",
- "Dandie Dinmont Terrier",
- "Boston Terrier",
- "Miniature Schnauzer",
- "Giant Schnauzer",
- "Standard Schnauzer",
- "Scottish Terrier",
- "Tibetan Terrier",
- "Australian Silky Terrier",
- "Soft-coated Wheaten Terrier",
- "West Highland White Terrier",
- "Lhasa Apso",
- "Flat-Coated Retriever",
- "Curly-coated Retriever",
- "Golden Retriever",
- "Labrador Retriever",
- "Chesapeake Bay Retriever",
- "German Shorthaired Pointer",
- "Vizsla",
- "English Setter",
- "Irish Setter",
- "Gordon Setter",
- "Brittany dog",
- "Clumber Spaniel",
- "English Springer Spaniel",
- "Welsh Springer Spaniel",
- "Cocker Spaniel",
- "Sussex Spaniel",
- "Irish Water Spaniel",
- "Kuvasz",
- "Schipperke",
- "Groenendael dog",
- "Malinois",
- "Briard",
- "Australian Kelpie",
- "Komondor",
- "Old English Sheepdog",
- "Shetland Sheepdog",
- "collie",
- "Border Collie",
- "Bouvier des Flandres dog",
- "Rottweiler",
- "German Shepherd Dog",
- "Dobermann",
- "Miniature Pinscher",
- "Greater Swiss Mountain Dog",
- "Bernese Mountain Dog",
- "Appenzeller Sennenhund",
- "Entlebucher Sennenhund",
- "Boxer",
- "Bullmastiff",
- "Tibetan Mastiff",
- "French Bulldog",
- "Great Dane",
- "St. Bernard",
- "husky",
- "Alaskan Malamute",
- "Siberian Husky",
- "Dalmatian",
- "Affenpinscher",
- "Basenji",
- "pug",
- "Leonberger",
- "Newfoundland dog",
- "Great Pyrenees dog",
- "Samoyed",
- "Pomeranian",
- "Chow Chow",
- "Keeshond",
- "brussels griffon",
- "Pembroke Welsh Corgi",
- "Cardigan Welsh Corgi",
- "Toy Poodle",
- "Miniature Poodle",
- "Standard Poodle",
- "Mexican hairless dog (xoloitzcuintli)",
- "grey wolf",
- "Alaskan tundra wolf",
- "red wolf or maned wolf",
- "coyote",
- "dingo",
- "dhole",
- "African wild dog",
- "hyena",
- "red fox",
- "kit fox",
- "Arctic fox",
- "grey fox",
- "tabby cat",
- "tiger cat",
- "Persian cat",
- "Siamese cat",
- "Egyptian Mau",
- "cougar",
- "lynx",
- "leopard",
- "snow leopard",
- "jaguar",
- "lion",
- "tiger",
- "cheetah",
- "brown bear",
- "American black bear",
- "polar bear",
- "sloth bear",
- "mongoose",
- "meerkat",
- "tiger beetle",
- "ladybug",
- "ground beetle",
- "longhorn beetle",
- "leaf beetle",
- "dung beetle",
- "rhinoceros beetle",
- "weevil",
- "fly",
- "bee",
- "ant",
- "grasshopper",
- "cricket insect",
- "stick insect",
- "cockroach",
- "praying mantis",
- "cicada",
- "leafhopper",
- "lacewing",
- "dragonfly",
- "damselfly",
- "red admiral butterfly",
- "ringlet butterfly",
- "monarch butterfly",
- "small white butterfly",
- "sulphur butterfly",
- "gossamer-winged butterfly",
- "starfish",
- "sea urchin",
- "sea cucumber",
- "cottontail rabbit",
- "hare",
- "Angora rabbit",
- "hamster",
- "porcupine",
- "fox squirrel",
- "marmot",
- "beaver",
- "guinea pig",
- "common sorrel horse",
- "zebra",
- "pig",
- "wild boar",
- "warthog",
- "hippopotamus",
- "ox",
- "water buffalo",
- "bison",
- "ram (adult male sheep)",
- "bighorn sheep",
- "Alpine ibex",
- "hartebeest",
- "impala (antelope)",
- "gazelle",
- "arabian camel",
- "llama",
- "weasel",
- "mink",
- "European polecat",
- "black-footed ferret",
- "otter",
- "skunk",
- "badger",
- "armadillo",
- "three-toed sloth",
- "orangutan",
- "gorilla",
- "chimpanzee",
- "gibbon",
- "siamang",
- "guenon",
- "patas monkey",
- "baboon",
- "macaque",
- "langur",
- "black-and-white colobus",
- "proboscis monkey",
- "marmoset",
- "white-headed capuchin",
- "howler monkey",
- "titi monkey",
- "Geoffroy's spider monkey",
- "common squirrel monkey",
- "ring-tailed lemur",
- "indri",
- "Asian elephant",
- "African bush elephant",
- "red panda",
- "giant panda",
- "snoek fish",
- "eel",
- "silver salmon",
- "rock beauty fish",
- "clownfish",
- "sturgeon",
- "gar fish",
- "lionfish",
- "pufferfish",
- "abacus",
- "abaya",
- "academic gown",
- "accordion",
- "acoustic guitar",
- "aircraft carrier",
- "airliner",
- "airship",
- "altar",
- "ambulance",
- "amphibious vehicle",
- "analog clock",
- "apiary",
- "apron",
- "trash can",
- "assault rifle",
- "backpack",
- "bakery",
- "balance beam",
- "balloon",
- "ballpoint pen",
- "Band-Aid",
- "banjo",
- "baluster / handrail",
- "barbell",
- "barber chair",
- "barbershop",
- "barn",
- "barometer",
- "barrel",
- "wheelbarrow",
- "baseball",
- "basketball",
- "bassinet",
- "bassoon",
- "swimming cap",
- "bath towel",
- "bathtub",
- "station wagon",
- "lighthouse",
- "beaker",
- "military hat (bearskin or shako)",
- "beer bottle",
- "beer glass",
- "bell tower",
- "baby bib",
- "tandem bicycle",
- "bikini",
- "ring binder",
- "binoculars",
- "birdhouse",
- "boathouse",
- "bobsleigh",
- "bolo tie",
- "poke bonnet",
- "bookcase",
- "bookstore",
- "bottle cap",
- "hunting bow",
- "bow tie",
- "brass memorial plaque",
- "bra",
- "breakwater",
- "breastplate",
- "broom",
- "bucket",
- "buckle",
- "bulletproof vest",
- "high-speed train",
- "butcher shop",
- "taxicab",
- "cauldron",
- "candle",
- "cannon",
- "canoe",
- "can opener",
- "cardigan",
- "car mirror",
- "carousel",
- "tool kit",
- "cardboard box / carton",
- "car wheel",
- "automated teller machine",
- "cassette",
- "cassette player",
- "castle",
- "catamaran",
- "CD player",
- "cello",
- "mobile phone",
- "chain",
- "chain-link fence",
- "chain mail",
- "chainsaw",
- "storage chest",
- "chiffonier",
- "bell or wind chime",
- "china cabinet",
- "Christmas stocking",
- "church",
- "movie theater",
- "cleaver",
- "cliff dwelling",
- "cloak",
- "clogs",
- "cocktail shaker",
- "coffee mug",
- "coffeemaker",
- "spiral or coil",
- "combination lock",
- "computer keyboard",
- "candy store",
- "container ship",
- "convertible",
- "corkscrew",
- "cornet",
- "cowboy boot",
- "cowboy hat",
- "cradle",
- "construction crane",
- "crash helmet",
- "crate",
- "infant bed",
- "Crock Pot",
- "croquet ball",
- "crutch",
- "cuirass",
- "dam",
- "desk",
- "desktop computer",
- "rotary dial telephone",
- "diaper",
- "digital clock",
- "digital watch",
- "dining table",
- "dishcloth",
- "dishwasher",
- "disc brake",
- "dock",
- "dog sled",
- "dome",
- "doormat",
- "drilling rig",
- "drum",
- "drumstick",
- "dumbbell",
- "Dutch oven",
- "electric fan",
- "electric guitar",
- "electric locomotive",
- "entertainment center",
- "envelope",
- "espresso machine",
- "face powder",
- "feather boa",
- "filing cabinet",
- "fireboat",
- "fire truck",
- "fire screen",
- "flagpole",
- "flute",
- "folding chair",
- "football helmet",
- "forklift",
- "fountain",
- "fountain pen",
- "four-poster bed",
- "freight car",
- "French horn",
- "frying pan",
- "fur coat",
- "garbage truck",
- "gas mask or respirator",
- "gas pump",
- "goblet",
- "go-kart",
- "golf ball",
- "golf cart",
- "gondola",
- "gong",
- "gown",
- "grand piano",
- "greenhouse",
- "radiator grille",
- "grocery store",
- "guillotine",
- "hair clip",
- "hair spray",
- "half-track",
- "hammer",
- "hamper",
- "hair dryer",
- "hand-held computer",
- "handkerchief",
- "hard disk drive",
- "harmonica",
- "harp",
- "combine harvester",
- "hatchet",
- "holster",
- "home theater",
- "honeycomb",
- "hook",
- "hoop skirt",
- "gymnastic horizontal bar",
- "horse-drawn vehicle",
- "hourglass",
- "iPod",
- "clothes iron",
- "carved pumpkin",
- "jeans",
- "jeep",
- "T-shirt",
- "jigsaw puzzle",
- "rickshaw",
- "joystick",
- "kimono",
- "knee pad",
- "knot",
- "lab coat",
- "ladle",
- "lampshade",
- "laptop computer",
- "lawn mower",
- "lens cap",
- "letter opener",
- "library",
- "lifeboat",
- "lighter",
- "limousine",
- "ocean liner",
- "lipstick",
- "slip-on shoe",
- "lotion",
- "music speaker",
- "loupe magnifying glass",
- "sawmill",
- "magnetic compass",
- "messenger bag",
- "mailbox",
- "tights",
- "one-piece bathing suit",
- "manhole cover",
- "maraca",
- "marimba",
- "mask",
- "matchstick",
- "maypole",
- "maze",
- "measuring cup",
- "medicine cabinet",
- "megalith",
- "microphone",
- "microwave oven",
- "military uniform",
- "milk can",
- "minibus",
- "miniskirt",
- "minivan",
- "missile",
- "mitten",
- "mixing bowl",
- "mobile home",
- "ford model t",
- "modem",
- "monastery",
- "monitor",
- "moped",
- "mortar and pestle",
- "graduation cap",
- "mosque",
- "mosquito net",
- "vespa",
- "mountain bike",
- "tent",
- "computer mouse",
- "mousetrap",
- "moving van",
- "muzzle",
- "metal nail",
- "neck brace",
- "necklace",
- "baby pacifier",
- "notebook computer",
- "obelisk",
- "oboe",
- "ocarina",
- "odometer",
- "oil filter",
- "pipe organ",
- "oscilloscope",
- "overskirt",
- "bullock cart",
- "oxygen mask",
- "product packet / packaging",
- "paddle",
- "paddle wheel",
- "padlock",
- "paintbrush",
- "pajamas",
- "palace",
- "pan flute",
- "paper towel",
- "parachute",
- "parallel bars",
- "park bench",
- "parking meter",
- "railroad car",
- "patio",
- "payphone",
- "pedestal",
- "pencil case",
- "pencil sharpener",
- "perfume",
- "Petri dish",
- "photocopier",
- "plectrum",
- "Pickelhaube",
- "picket fence",
- "pickup truck",
- "pier",
- "piggy bank",
- "pill bottle",
- "pillow",
- "ping-pong ball",
- "pinwheel",
- "pirate ship",
- "drink pitcher",
- "block plane",
- "planetarium",
- "plastic bag",
- "plate rack",
- "farm plow",
- "plunger",
- "Polaroid camera",
- "pole",
- "police van",
- "poncho",
- "pool table",
- "soda bottle",
- "plant pot",
- "potter's wheel",
- "power drill",
- "prayer rug",
- "printer",
- "prison",
- "missile",
- "projector",
- "hockey puck",
- "punching bag",
- "purse",
- "quill",
- "quilt",
- "race car",
- "racket",
- "radiator",
- "radio",
- "radio telescope",
- "rain barrel",
- "recreational vehicle",
- "fishing casting reel",
- "reflex camera",
- "refrigerator",
- "remote control",
- "restaurant",
- "revolver",
- "rifle",
- "rocking chair",
- "rotisserie",
- "eraser",
- "rugby ball",
- "ruler measuring stick",
- "sneaker",
- "safe",
- "safety pin",
- "salt shaker",
- "sandal",
- "sarong",
- "saxophone",
- "scabbard",
- "weighing scale",
- "school bus",
- "schooner",
- "scoreboard",
- "CRT monitor",
- "screw",
- "screwdriver",
- "seat belt",
- "sewing machine",
- "shield",
- "shoe store",
- "shoji screen / room divider",
- "shopping basket",
- "shopping cart",
- "shovel",
- "shower cap",
- "shower curtain",
- "ski",
- "balaclava ski mask",
- "sleeping bag",
- "slide rule",
- "sliding door",
- "slot machine",
- "snorkel",
- "snowmobile",
- "snowplow",
- "soap dispenser",
- "soccer ball",
- "sock",
- "solar thermal collector",
- "sombrero",
- "soup bowl",
- "keyboard space bar",
- "space heater",
- "space shuttle",
- "spatula",
- "motorboat",
- "spider web",
- "spindle",
- "sports car",
- "spotlight",
- "stage",
- "steam locomotive",
- "through arch bridge",
- "steel drum",
- "stethoscope",
- "scarf",
- "stone wall",
- "stopwatch",
- "stove",
- "strainer",
- "tram",
- "stretcher",
- "couch",
- "stupa",
- "submarine",
- "suit",
- "sundial",
- "sunglasses",
- "sunglasses",
- "sunscreen",
- "suspension bridge",
- "mop",
- "sweatshirt",
- "swim trunks / shorts",
- "swing",
- "electrical switch",
- "syringe",
- "table lamp",
- "tank",
- "tape player",
- "teapot",
- "teddy bear",
- "television",
- "tennis ball",
- "thatched roof",
- "front curtain",
- "thimble",
- "threshing machine",
- "throne",
- "tile roof",
- "toaster",
- "tobacco shop",
- "toilet seat",
- "torch",
- "totem pole",
- "tow truck",
- "toy store",
- "tractor",
- "semi-trailer truck",
- "tray",
- "trench coat",
- "tricycle",
- "trimaran",
- "tripod",
- "triumphal arch",
- "trolleybus",
- "trombone",
- "hot tub",
- "turnstile",
- "typewriter keyboard",
- "umbrella",
- "unicycle",
- "upright piano",
- "vacuum cleaner",
- "vase",
- "vaulted or arched ceiling",
- "velvet fabric",
- "vending machine",
- "vestment",
- "viaduct",
- "violin",
- "volleyball",
- "waffle iron",
- "wall clock",
- "wallet",
- "wardrobe",
- "military aircraft",
- "sink",
- "washing machine",
- "water bottle",
- "water jug",
- "water tower",
- "whiskey jug",
- "whistle",
- "hair wig",
- "window screen",
- "window shade",
- "Windsor tie",
- "wine bottle",
- "airplane wing",
- "wok",
- "wooden spoon",
- "wool",
- "split-rail fence",
- "shipwreck",
- "sailboat",
- "yurt",
- "website",
- "comic book",
- "crossword",
- "traffic or street sign",
- "traffic light",
- "dust jacket",
- "menu",
- "plate",
- "guacamole",
- "consomme",
- "hot pot",
- "trifle",
- "ice cream",
- "popsicle",
- "baguette",
- "bagel",
- "pretzel",
- "cheeseburger",
- "hot dog",
- "mashed potatoes",
- "cabbage",
- "broccoli",
- "cauliflower",
- "zucchini",
- "spaghetti squash",
- "acorn squash",
- "butternut squash",
- "cucumber",
- "artichoke",
- "bell pepper",
- "cardoon",
- "mushroom",
- "Granny Smith apple",
- "strawberry",
- "orange",
- "lemon",
- "fig",
- "pineapple",
- "banana",
- "jackfruit",
- "cherimoya (custard apple)",
- "pomegranate",
- "hay",
- "carbonara",
- "chocolate syrup",
- "dough",
- "meatloaf",
- "pizza",
- "pot pie",
- "burrito",
- "red wine",
- "espresso",
- "tea cup",
- "eggnog",
- "mountain",
- "bubble",
- "cliff",
- "coral reef",
- "geyser",
- "lakeshore",
- "promontory",
- "sandbar",
- "beach",
- "valley",
- "volcano",
- "baseball player",
- "bridegroom",
- "scuba diver",
- "rapeseed",
- "daisy",
- "yellow lady's slipper",
- "corn",
- "acorn",
- "rose hip",
- "horse chestnut seed",
- "coral fungus",
- "agaric",
- "gyromitra",
- "stinkhorn mushroom",
- "earth star fungus",
- "hen of the woods mushroom",
- "bolete",
- "corn cob",
- "toilet paper",
-]
-
-
-openai_imagenet_template = [
- lambda c: f"a bad photo of a {c}.",
- lambda c: f"a photo of many {c}.",
- lambda c: f"a sculpture of a {c}.",
- lambda c: f"a photo of the hard to see {c}.",
- lambda c: f"a low resolution photo of the {c}.",
- lambda c: f"a rendering of a {c}.",
- lambda c: f"graffiti of a {c}.",
- lambda c: f"a bad photo of the {c}.",
- lambda c: f"a cropped photo of the {c}.",
- lambda c: f"a tattoo of a {c}.",
- lambda c: f"the embroidered {c}.",
- lambda c: f"a photo of a hard to see {c}.",
- lambda c: f"a bright photo of a {c}.",
- lambda c: f"a photo of a clean {c}.",
- lambda c: f"a photo of a dirty {c}.",
- lambda c: f"a dark photo of the {c}.",
- lambda c: f"a drawing of a {c}.",
- lambda c: f"a photo of my {c}.",
- lambda c: f"the plastic {c}.",
- lambda c: f"a photo of the cool {c}.",
- lambda c: f"a close-up photo of a {c}.",
- lambda c: f"a black and white photo of the {c}.",
- lambda c: f"a painting of the {c}.",
- lambda c: f"a painting of a {c}.",
- lambda c: f"a pixelated photo of the {c}.",
- lambda c: f"a sculpture of the {c}.",
- lambda c: f"a bright photo of the {c}.",
- lambda c: f"a cropped photo of a {c}.",
- lambda c: f"a plastic {c}.",
- lambda c: f"a photo of the dirty {c}.",
- lambda c: f"a jpeg corrupted photo of a {c}.",
- lambda c: f"a blurry photo of the {c}.",
- lambda c: f"a photo of the {c}.",
- lambda c: f"a good photo of the {c}.",
- lambda c: f"a rendering of the {c}.",
- lambda c: f"a {c} in a video game.",
- lambda c: f"a photo of one {c}.",
- lambda c: f"a doodle of a {c}.",
- lambda c: f"a close-up photo of the {c}.",
- lambda c: f"a photo of a {c}.",
- lambda c: f"the origami {c}.",
- lambda c: f"the {c} in a video game.",
- lambda c: f"a sketch of a {c}.",
- lambda c: f"a doodle of the {c}.",
- lambda c: f"a origami {c}.",
- lambda c: f"a low resolution photo of a {c}.",
- lambda c: f"the toy {c}.",
- lambda c: f"a rendition of the {c}.",
- lambda c: f"a photo of the clean {c}.",
- lambda c: f"a photo of a large {c}.",
- lambda c: f"a rendition of a {c}.",
- lambda c: f"a photo of a nice {c}.",
- lambda c: f"a photo of a weird {c}.",
- lambda c: f"a blurry photo of a {c}.",
- lambda c: f"a cartoon {c}.",
- lambda c: f"art of a {c}.",
- lambda c: f"a sketch of the {c}.",
- lambda c: f"a embroidered {c}.",
- lambda c: f"a pixelated photo of a {c}.",
- lambda c: f"itap of the {c}.",
- lambda c: f"a jpeg corrupted photo of the {c}.",
- lambda c: f"a good photo of a {c}.",
- lambda c: f"a plushie {c}.",
- lambda c: f"a photo of the nice {c}.",
- lambda c: f"a photo of the small {c}.",
- lambda c: f"a photo of the weird {c}.",
- lambda c: f"the cartoon {c}.",
- lambda c: f"art of the {c}.",
- lambda c: f"a drawing of the {c}.",
- lambda c: f"a photo of the large {c}.",
- lambda c: f"a black and white photo of a {c}.",
- lambda c: f"the plushie {c}.",
- lambda c: f"a dark photo of a {c}.",
- lambda c: f"itap of a {c}.",
- lambda c: f"graffiti of the {c}.",
- lambda c: f"a toy {c}.",
- lambda c: f"itap of my {c}.",
- lambda c: f"a photo of a cool {c}.",
- lambda c: f"a photo of a small {c}.",
- lambda c: f"a tattoo of the {c}.",
-]
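The class names and prompt templates above feed a zero-shot classifier builder that expands every class with every template before text encoding. A small sketch of just that expansion step (the encoding side is omitted; this is not the removed evaluation code):

```python
# Sketch: expand class names with prompt templates, as a zero-shot classifier builder would.
from typing import Callable, Dict, List


def build_prompts(classnames: List[str],
                  templates: List[Callable[[str], str]]) -> Dict[str, List[str]]:
    return {name: [template(name) for template in templates] for name in classnames}


prompts = build_prompts(
    ["goldfish", "tench"],
    [lambda c: f"a photo of a {c}.", lambda c: f"a blurry photo of a {c}."],
)
assert prompts["goldfish"][0] == "a photo of a goldfish."
```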
diff --git a/audioldm/clap/training/infer_demo.py b/audioldm/clap/training/infer_demo.py
deleted file mode 100644
index 7d1f4784898dbfeb69affefb6f624711adc8cb42..0000000000000000000000000000000000000000
--- a/audioldm/clap/training/infer_demo.py
+++ /dev/null
@@ -1,105 +0,0 @@
-import sys
-
-import os
-import torch
-import librosa
-from open_clip import create_model
-from training.data import get_audio_features
-from training.data import int16_to_float32, float32_to_int16
-from transformers import RobertaTokenizer
-
-tokenize = RobertaTokenizer.from_pretrained("roberta-base")
-
-
-def tokenizer(text):
- result = tokenize(
- text,
- padding="max_length",
- truncation=True,
- max_length=77,
- return_tensors="pt",
- )
- return {k: v.squeeze(0) for k, v in result.items()}
-
-
-PRETRAINED_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/checkpoints/epoch_top_0_audioset_no_fusion.pt"
-WAVE_48k_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/audio/machine.wav"
-
-
-def infer_text():
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
- precision = "fp32"
- amodel = "HTSAT-tiny" # or 'PANN-14'
- tmodel = "roberta" # the best text encoder in our training
- enable_fusion = False # False if you do not want to use the fusion model
- fusion_type = "aff_2d"
- pretrained = PRETRAINED_PATH
-
- model, model_cfg = create_model(
- amodel,
- tmodel,
- pretrained,
- precision=precision,
- device=device,
- enable_fusion=enable_fusion,
- fusion_type=fusion_type,
- )
-    # load the text; it can be a list (i.e. a batch of prompts)
- text_data = ["I love the contrastive learning", "I love the pretrain model"]
- # tokenize for roberta, if you want to tokenize for another text encoder, please refer to data.py#L43-90
- text_data = tokenizer(text_data)
-
- text_embed = model.get_text_embedding(text_data)
- print(text_embed.size())
-
-
-def infer_audio():
-
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
- precision = "fp32"
- amodel = "HTSAT-tiny" # or 'PANN-14'
- tmodel = "roberta" # the best text encoder in our training
- enable_fusion = False # False if you do not want to use the fusion model
- fusion_type = "aff_2d"
- pretrained = PRETRAINED_PATH
-
- model, model_cfg = create_model(
- amodel,
- tmodel,
- pretrained,
- precision=precision,
- device=device,
- enable_fusion=enable_fusion,
- fusion_type=fusion_type,
- )
-
-    # load the waveform of shape (T,); it should be resampled to 48000 Hz
- audio_waveform, sr = librosa.load(WAVE_48k_PATH, sr=48000)
- # quantize
- audio_waveform = int16_to_float32(float32_to_int16(audio_waveform))
- audio_waveform = torch.from_numpy(audio_waveform).float()
- audio_dict = {}
-
-    # the 'fusion' truncate mode can be changed to 'rand_trunc' when running without the fusion model
- audio_dict = get_audio_features(
- audio_dict,
- audio_waveform,
- 480000,
- data_truncating="fusion",
- data_filling="repeatpad",
- audio_cfg=model_cfg["audio_cfg"],
- )
-    # a list can be sent to the model to process many audio tracks at once (i.e. a batch)
- audio_embed = model.get_audio_embedding([audio_dict])
- print(audio_embed.size())
-
-
-if __name__ == "__main__":
- infer_text()
- infer_audio()
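The quantise-then-dequantise step in `infer_audio` (`int16_to_float32(float32_to_int16(...))`) simulates 16-bit audio precision on a float waveform. The two helpers imported from `training.data` behave roughly like the following sketch (an approximation, not the original source):

```python
import numpy as np


def float32_to_int16(x: np.ndarray) -> np.ndarray:
    # Clip to the valid [-1, 1] range and scale to signed 16-bit amplitude.
    x = np.clip(x, -1.0, 1.0)
    return (x * 32767.0).astype(np.int16)


def int16_to_float32(x: np.ndarray) -> np.ndarray:
    # Scale back to [-1, 1] float32; the round trip emulates 16-bit quantisation.
    return (x / 32767.0).astype(np.float32)


wave = np.random.uniform(-1.0, 1.0, 48000).astype(np.float32)
wave_quantised = int16_to_float32(float32_to_int16(wave))  # same shape, quantised values
```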
diff --git a/audioldm/clap/training/logger.py b/audioldm/clap/training/logger.py
deleted file mode 100644
index ac4634970fae6aacde2b7b808355dbd50c90ce73..0000000000000000000000000000000000000000
--- a/audioldm/clap/training/logger.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import logging
-
-
-def setup_logging(log_file, level, include_host=False):
- if include_host:
- import socket
-
- hostname = socket.gethostname()
- formatter = logging.Formatter(
- f"%(asctime)s | {hostname} | %(levelname)s | %(message)s",
- datefmt="%Y-%m-%d,%H:%M:%S",
- )
- else:
- formatter = logging.Formatter(
- "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d,%H:%M:%S"
- )
-
- logging.root.setLevel(level)
- loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
- for logger in loggers:
- logger.setLevel(level)
-
- stream_handler = logging.StreamHandler()
- stream_handler.setFormatter(formatter)
- logging.root.addHandler(stream_handler)
-
- if log_file:
- file_handler = logging.FileHandler(filename=log_file)
- file_handler.setFormatter(formatter)
- logging.root.addHandler(file_handler)
diff --git a/audioldm/clap/training/lp_main.py b/audioldm/clap/training/lp_main.py
deleted file mode 100644
index c2d4e8c85aaa3c8e4221963ef56a815cc14f354f..0000000000000000000000000000000000000000
--- a/audioldm/clap/training/lp_main.py
+++ /dev/null
@@ -1,670 +0,0 @@
-from cmath import cos
-from inspect import getargs
-import logging
-import os
-import random
-from datetime import datetime
-import bisect
-import copy
-from sched import scheduler
-import numpy as np
-import torch
-import torch.backends.cudnn as cudnn
-from torch import optim
-from torch.cuda.amp import GradScaler
-import faulthandler
-import pathlib
-import argparse
-import time
-
-try:
- import wandb
-except ImportError:
- wandb = None
-
-try:
- import torch.utils.tensorboard as tensorboard
-except ImportError:
- tensorboard = None
-
-try:
- import horovod.torch as hvd
-except ImportError:
- hvd = None
-
-from open_clip import create_model_and_transforms, trace_model, create_model
-from training.data import get_data
-from training.params import parse_args
-from training.distributed import is_master, init_distributed_device, world_info_from_env
-from training.logger import setup_logging
-from training.scheduler import cosine_lr
-from training.lp_train import train_one_epoch, evaluate
-from open_clip.utils import get_tar_path_from_dataset_name, dataset_split, get_optimizer
-from open_clip.utils import load_p, load_class_label
-from open_clip.linear_probe import LinearProbe
-
-
-def maintain_ckpts(args, startidx, all_idx_len):
- for i in reversed(range(startidx, all_idx_len)):
- if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")):
- os.rename(
- os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"),
- os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"),
- )
- if os.path.exists(
- os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")
- ):
- os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt"))
- return
-
-
-def update_top_k_performance(
- new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True
-):
- """
- Record the top-k performance of the current epoch.
- current_top_k_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...}
- """
- if isinstance(new_metrics_inputs, (list, tuple)):
- new_metrics_inputs = np.mean(new_metrics_inputs)
- return update_top_k_performance(
- new_metrics_inputs,
- current_top_k_ckpt_metrics,
- args=args,
- ckpt=ckpt,
- bignumbetter=bignumbetter,
- )
- elif isinstance(new_metrics_inputs, dict):
- new_metrics_inputs = np.mean(list(new_metrics_inputs.values()))
- return update_top_k_performance(
- new_metrics_inputs,
- current_top_k_ckpt_metrics,
- args=args,
- ckpt=ckpt,
- bignumbetter=bignumbetter,
- )
- elif isinstance(new_metrics_inputs, (float, int)):
- update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()}
- sorted_keys = sorted(current_top_k_ckpt_metrics.keys())
- sorted_values = sorted(
- current_top_k_ckpt_metrics.values(), reverse=bignumbetter
- )
- sorted_values_ = copy.deepcopy(sorted_values)
- sorted_values.append(new_metrics_inputs)
- sorted_values = sorted(sorted_values, reverse=bignumbetter)
- sorted_values = sorted_values[:-1]
-
- if sorted_values == sorted_values_:
- return current_top_k_ckpt_metrics, new_metrics_inputs
- else:
- for i in range(len(sorted_keys)):
- if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]:
- current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i]
- update_flag[sorted_keys[i]] = True
- for i in range(len(update_flag)):
- if update_flag[i]:
- maintain_ckpts(args, i, len(sorted_keys))
- torch.save(
- ckpt,
- os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"),
- )
- break
- return current_top_k_ckpt_metrics, new_metrics_inputs
-
-
-# def updateifNone(a, b):
-# a = b if None else a
-# return a
-
-
-def is_pretrained_params(n):
- return (
- n.startswith("clap_model.transformer")
- or n in ["clap_model.positional_embedding", "clap_model.text_projection"]
- or n.startswith("clap_model.token_embedding")
- or n.startswith("clap_model.ln_final")
- or n.startswith("clap_model.logit_scale_t")
- )
-
-
-def random_seed(seed=42, rank=0):
- torch.manual_seed(seed + rank)
- np.random.seed(seed + rank)
- random.seed(seed + rank)
-
-
-def config_lp_optimizer(model, data, args):
-    # set weight-decay-related params to 0 when using the adam optimizer
- if args.optimizer == "adam":
- args.wd = 0
- args.wd_pretrained = 0
- args.wd_new = 0
-
- in_clap = lambda n, p: n.startswith("clap_model")
-
- named_parameters = list(model.named_parameters())
-
- optimizer = {}
- scheduler = {}
-
- # freeze text encoder
- text_freeze_parameters = [
- p
- for n, p in named_parameters
- if n.startswith("clap_model.transformer")
- or n in ["clap_model.positional_embedding", "clap_model.text_projection"]
- or n.startswith("clap_model.token_embedding")
- or n.startswith("clap_model.ln_final")
- ]
-
- if args.freeze_text:
- logging.info("Freeze Text!!!!")
- for k in text_freeze_parameters:
- k.requires_grad = False
-
- if not args.lp_freeze:
- exclude = (
- lambda n, p: p.ndim < 2
- or "bn" in n
- or "ln" in n
- or "bias" in n
- or "logit_scale" in n
- )
- include = lambda n, p: not exclude(n, p)
-
- # (yusong): we do not split the learning rate anymore
- # p for n, p in named_parameters if in_clap(n,p) and exclude(n, p) and p.requires_grad
- gain_or_bias_params = [
- p for n, p in named_parameters if exclude(n, p) and p.requires_grad
- ]
- # rest_params = [p for n, p in named_parameters if in_clap(n,p) and include(n, p) and p.requires_grad]
- rest_params = [
- p for n, p in named_parameters if include(n, p) and p.requires_grad
- ]
-
- if args.train_data is None:
- optimizer = None
- scheduler = None
- else:
- total_steps = data["train"].dataloader.num_batches * args.epochs
-
- if args.split_opt:
- for x in ["lr", "beta1", "beta2", "eps", "wd"]:
- for y in ["_new", "_pretrained"]:
- if getattr(args, x + y) is None:
- setattr(args, x + y, getattr(args, x))
-
- gain_or_bias_pretrained_params = [
- p
- for n, p in named_parameters
- if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n)
- ]
- rest_pretrained_params = [
- p
- for n, p in named_parameters
- if (include(n, p) and p.requires_grad) and is_pretrained_params(n)
- ]
- gain_or_bias_new_params = [
- p
- for n, p in named_parameters
- if (exclude(n, p) and p.requires_grad)
- and (not is_pretrained_params(n))
- ]
- rest_new_params = [
- p
- for n, p in named_parameters
- if (include(n, p) and p.requires_grad)
- and (not is_pretrained_params(n))
- ]
-
- pretrained_params_optimizer = get_optimizer(
- [
- {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0},
- {
- "params": rest_pretrained_params,
- "weight_decay": args.wd_pretrained,
- },
- ],
- lr=args.lr_pretrained,
- betas=(args.beta1_pretrained, args.beta2_pretrained),
- eps=args.eps_pretrained,
- momentum=args.momentum_pretrained,
- optimizer_name=args.optimizer,
- )
- pretrained_params_scheduler = cosine_lr(
- pretrained_params_optimizer,
- args.lr_pretrained,
- args.warmup,
- total_steps,
- )
-
- new_params_optimizer = get_optimizer(
- [
- {"params": gain_or_bias_new_params, "weight_decay": 0.0},
- {"params": rest_new_params, "weight_decay": args.wd_new},
- ],
- lr=args.lr_new,
- betas=(args.beta1_new, args.beta2_new),
- eps=args.eps_new,
- momentum=args.momentum_new,
- optimizer_name=args.optimizer,
- )
- new_params_scheduler = cosine_lr(
- new_params_optimizer, args.lr_new, args.warmup, total_steps
- )
-
- optimizer["text"] = pretrained_params_optimizer
- optimizer["audio"] = new_params_optimizer
- scheduler["text"] = pretrained_params_scheduler
- scheduler["audio"] = new_params_scheduler
-
- if args.horovod:
- pretrained_params_optimizer = hvd.DistributedOptimizer(
- pretrained_params_optimizer,
- named_parameters=model.named_parameters(),
- )
- new_params_optimizer = hvd.DistributedOptimizer(
- new_params_optimizer, named_parameters=model.named_parameters()
- )
- hvd.broadcast_parameters(model.state_dict(), root_rank=0)
- hvd.broadcast_optimizer_state(
- pretrained_params_optimizer, root_rank=0
- )
- hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0)
- else:
-
- optimizer["clap"] = get_optimizer(
- [
- {"params": gain_or_bias_params, "weight_decay": 0.0},
- {"params": rest_params, "weight_decay": args.wd},
- ],
- lr=args.lr,
- betas=(args.beta1, args.beta2),
- eps=args.eps,
- momentum=args.momentum,
- optimizer_name=args.optimizer,
- )
- scheduler["clap"] = cosine_lr(
- optimizer["clap"], args.lr, args.warmup, total_steps
- )
-
- if args.horovod:
- optimizer["clap"] = hvd.DistributedOptimizer(
- optimizer["clap"], named_parameters=model.named_parameters()
- )
- hvd.broadcast_parameters(model.state_dict(), root_rank=0)
- hvd.broadcast_optimizer_state(optimizer["clap"], root_rank=0)
-
- # linear probe optimizer
- else:
- lp_params = [
- p for n, p in named_parameters if (not in_clap(n, p)) and p.requires_grad
- ]
- lp_optim = get_optimizer(
- lp_params,
- lr=args.lp_lr,
- betas=(args.beta1, args.beta2),
- eps=args.eps,
- momentum=0.9,
- optimizer_name=args.optimizer,
- )
- optimizer["lp"] = lp_optim
-
- return optimizer, scheduler, text_freeze_parameters
-
-
-def main():
- args = parse_args()
-
- time.sleep(args.sleep)
-
- # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
- args.amodel = args.amodel.replace("/", "-")
- # download sizes.json file
-
- # (yusong): the below two lines are for debug
- # print("setting up faulthandler")
- # faulthandler.register(10)
-
- random.seed(args.seed)
- torch.manual_seed(args.seed)
- torch.cuda.manual_seed(args.seed)
- torch.cuda.manual_seed_all(args.seed)
- np.random.seed(args.seed)
- args.class_index_dict = load_class_label(args.class_label_path)
-
- # get the name of the experiments
- if args.name is None:
- args.name = "-".join(
- [
- datetime.now().strftime("%Y_%m_%d-%H_%M_%S"),
- f"linear_probe" f"model_{args.amodel}",
- f"lr_{args.lr}",
- f"b_{args.batch_size}",
- f"j_{args.workers}",
- f"p_{args.precision}",
- ]
- )
-
- # discover initial world args early so we can log properly
- args.distributed = False
- args.local_rank, args.rank, args.world_size = world_info_from_env()
-
- if args.remotedata and is_master(args):
- for dataset_name in args.datasetnames:
- for split in dataset_split[dataset_name]:
- if not os.path.exists(f"./json_files/{dataset_name}/{split}"):
- os.makedirs(f"./json_files/{dataset_name}/{split}")
- os.system(
- f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json"
- )
-
- args.log_path = None
- if is_master(args, local=args.log_local):
- log_base_path = os.path.join(args.logs, args.name)
- os.makedirs(log_base_path, exist_ok=True)
- log_filename = f"out-{args.rank}" if args.log_local else "out.log"
- args.log_path = os.path.join(log_base_path, log_filename)
-
-        # avoid a log dir with the same name:
- postfix = 0
- while os.path.exists(args.log_path):
- postfix += 1
- log_base_path_new = log_base_path + "-" + str(postfix)
- os.makedirs(log_base_path_new, exist_ok=True)
- log_filename = f"out-{args.rank}" if args.log_local else "out.log"
- args.log_path = os.path.join(log_base_path_new, log_filename)
- # print(
- # "Error. Experiment already exists. Use --name {} to specify a new experiment."
- # )
- # return -1
-
- # Set logger
- args.log_level = logging.DEBUG if args.debug else logging.INFO
- setup_logging(args.log_path, args.log_level)
-
- # fully initialize distributed device environment
- device = init_distributed_device(args)
-
- args.wandb = "wandb" in args.report_to or "all" in args.report_to
- args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to
- if is_master(args):
- args.tensorboard_path = (
- os.path.join(args.logs, args.name, "tensorboard")
- if args.tensorboard
- else ""
- )
- args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
- for dirname in [args.tensorboard_path, args.checkpoint_path]:
- if dirname:
- os.makedirs(dirname, exist_ok=True)
- else:
- args.tensorboard_path = ""
- args.checkpoint_path = ""
-
- if args.copy_codebase:
- copy_codebase(args)
-
- assert args.precision in ["amp", "fp16", "fp32"]
- if args.precision == "fp16":
- logging.warning(
- "It is recommended to use AMP mixed-precision instead of FP16. "
- "FP16 support needs further verification and tuning, especially for train."
- )
-
- if args.horovod:
- logging.info(
- f"Running in horovod mode with multiple processes / nodes. Device: {args.device}."
- f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
- )
- elif args.distributed:
- logging.info(
- f"Running in distributed mode with multiple processes. Device: {args.device}."
- f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
- )
- else:
- logging.info(f"Running with a single process. Device {args.device}.")
-
- logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}")
-
- # Create CLAP model
- clap_model, clap_model_cfg = create_model(
- args.amodel,
- args.tmodel,
- args.pretrained,
- precision=args.precision,
- device=device,
- jit=args.torchscript,
- force_quick_gelu=args.force_quick_gelu,
- openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
- skip_params=False,
- pretrained_audio=args.pretrained_audio,
- pretrained_text=args.pretrained_text,
- enable_fusion=args.enable_fusion,
- fusion_type=args.fusion_type,
- )
-
- args.lp_out_ch = len(list(args.class_index_dict.keys()))
- # Linear Probe
- logging.info(f"linear probe using mlp: {args.lp_mlp}")
- logging.info(f"linear probe using freeze: {args.lp_freeze}")
- logging.info(f"linear probe act layer: {args.lp_act}")
- logging.info(f"linear probe out ch: {args.lp_out_ch}")
- logging.info(f"linear probe learning rate (if applicable): {args.lp_lr}")
- logging.info(f"linear probe loss func: {args.lp_loss}")
- logging.info(f"linear probe lp_metrics: {args.lp_metrics}")
-
- model = LinearProbe(
- clap_model,
- mlp=args.lp_mlp,
- freeze=args.lp_freeze,
- in_ch=512,
- out_ch=args.lp_out_ch,
- act=args.lp_act,
- ) # in_ch is fixed (i.e., 512)
- model = model.to(device)
-
- if args.horovod:
- with torch.no_grad():
- for param in model.parameters():
- param.set_(param.contiguous())
-
- if args.trace:
- model = trace_model(model, batch_size=args.batch_size, device=device)
-
- if is_master(args):
- logging.info("Linear Probe CLAP Model:")
- logging.info(f"{str(clap_model)}")
- logging.info("Params:")
- params_file = os.path.join(args.logs, args.name, "params.txt")
- with open(params_file, "w") as f:
- for name in sorted(vars(args)):
- val = getattr(args, name)
- logging.info(f" {name}: {val}")
- f.write(f"{name}: {val}\n")
-
- if args.distributed and not args.horovod:
- if args.use_bn_sync:
- model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
- ddp_args = {}
- if args.ddp_static_graph:
- # this doesn't exist in older PyTorch, arg only added if enabled
- ddp_args["static_graph"] = True
- model = torch.nn.parallel.DistributedDataParallel(
- model, device_ids=[device], find_unused_parameters=True, **ddp_args
- )
-
- data = get_data(args, clap_model_cfg)
- assert len(data), "At least one train or eval dataset must be specified."
- if args.trace:
- assert "train" not in data, "Cannot train with traced model"
-
- optimizer, scheduler, text_freeze_parameters = config_lp_optimizer(
- model, data, args
- )
-
- scaler = GradScaler() if args.precision == "amp" else None
-
- # optionally resume from a checkpoint
- start_epoch = 0
- if args.resume is not None:
- if os.path.isfile(args.resume):
- checkpoint = torch.load(args.resume, map_location=device)
- if "epoch" in checkpoint:
- # resuming a train checkpoint w/ epoch and optimizer state
- start_epoch = checkpoint["epoch"]
- sd = checkpoint["state_dict"]
- if not args.distributed and next(iter(sd.items()))[0].startswith(
- "module"
- ):
- sd = {k[len("module.") :]: v for k, v in sd.items()}
- model.load_state_dict(sd)
- if args.split_opt:
- if optimizer is not None:
- for k, o_ in optimizer.items():
- o_.load_state_dict(checkpoint[k + "_" + "optimizer"])
- if optimizer is not None:
- optimizer.load_state_dict(checkpoint["optimizer"])
- if scaler is not None and "scaler" in checkpoint:
- scaler.load_state_dict(checkpoint["scaler"])
- logging.info(
- f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})"
- )
- else:
- # loading a bare (model only) checkpoint for fine-tune or evaluation
- model.load_state_dict(checkpoint)
- logging.info(
- f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})"
- )
- if args.freeze_text:
- print("Freeze Text!!!!")
- for k in text_freeze_parameters:
- k.requires_grad = False
- else:
- logging.info("=> no checkpoint found at '{}'".format(args.resume))
-
- cudnn.benchmark = True
- cudnn.deterministic = False
-
- # determine if this worker should save logs and checkpoints. only do so if it is rank == 0
- args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args)
- writer = None
- if args.save_logs and args.tensorboard:
- assert tensorboard is not None, "Please install tensorboard."
- writer = tensorboard.SummaryWriter(args.tensorboard_path)
-
- if args.wandb and is_master(args):
- assert wandb is not None, "Please install wandb."
- logging.debug("Starting wandb.")
- args.train_sz = data["train"].dataloader.num_samples
- if args.val_data is not None:
- args.val_sz = data["val"].dataloader.num_samples
- # you will have to configure this for your project!
- wandb.init(
- project="clap",
- notes=args.wandb_notes,
- name=args.wandb_notes,
- tags=[],
- config=vars(args),
- )
- if args.debug:
- wandb.watch(model, log="all")
- wandb.save(params_file)
- logging.debug("Finished loading wandb.")
-
- if "train" not in data:
- evaluate(model, data, start_epoch, args, writer)
- return
- elif start_epoch == 0 and "val" in data and not args.no_eval:
- evaluate(model, data, 0, args, writer)
- if args.save_top_performance:
- current_top_k_ckpt_metrics = {
- i: 0 for i in range(args.save_top_performance)
- } # initialize the top-k metric for ckpts to 0
-
- for epoch in range(start_epoch, args.epochs):
-        # freeze the text params from epoch args.freeze_text_after (inclusive); this is -1 by default
-        if epoch == args.freeze_text_after:
-            print("Text pretrained parameters are frozen from this epoch onwards.")
- for k in text_freeze_parameters:
- k.requires_grad = False
- if is_master(args):
- logging.info(f"Start epoch {epoch}")
-
- train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer)
- completed_epoch = epoch + 1
-
- if (
- any(v in data for v in ("val", "imagenet-val", "imagenet-v2"))
- and not args.no_eval
- ):
- metrics = evaluate(model, data, completed_epoch, args, writer)
- if args.save_top_performance:
- top_k_dataset = args.top_k_checkpoint_select_dataset
- top_k_metric = args.top_k_checkpoint_select_metric
- filtered_metrics = [
- v
- for k, v in metrics.items()
- if top_k_metric in k and top_k_dataset in k
-                ]  # check all R@10 metrics (across all datasets) and use them to update the ckpt
- # Saving checkpoints.
- if args.save_logs:
- opt_dict = {
- k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items()
- }
- checkpoint_dict = {
- "epoch": completed_epoch,
- "name": args.name,
- "state_dict": model.state_dict(),
- }
- checkpoint_dict.update(opt_dict)
- if scaler is not None:
- checkpoint_dict["scaler"] = scaler.state_dict()
-
- if completed_epoch == args.epochs or (
- args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
- ):
- torch.save(
- checkpoint_dict,
- os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
- )
- if args.save_most_recent:
- torch.save(
- checkpoint_dict,
- os.path.join(args.checkpoint_path, f"epoch_latest.pt"),
- )
- if args.save_top_performance and not args.no_eval:
- update_top_k_performance(
- filtered_metrics,
- current_top_k_ckpt_metrics,
- args,
- checkpoint_dict,
- bignumbetter=True,
- )
-
- if args.wandb and is_master(args):
- wandb.finish()
-
-
-def copy_codebase(args):
- from shutil import copytree, ignore_patterns
-
- new_code_path = os.path.join(args.logs, args.name, "code")
- if os.path.exists(new_code_path):
- print(
- f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
- )
- return -1
- print(f"Copying codebase to {new_code_path}")
- current_code_path = os.path.realpath(__file__)
- for _ in range(3):
- current_code_path = os.path.dirname(current_code_path)
- copytree(
- current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb")
- )
- print("Done copying code.")
- return 1
-
-
-if __name__ == "__main__":
- main()
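The top-k checkpoint bookkeeping in `update_top_k_performance` above boils down to keeping the k best metric values seen so far and rewriting a checkpoint slot whenever a new value displaces one. A simplified, self-contained sketch of that bookkeeping (illustrative only; the real function also rotates the `epoch_top_{i}.pt` files on disk):

```python
# Simplified top-k metric bookkeeping, mirroring update_top_k_performance without file I/O.
from typing import Dict, Tuple


def update_top_k(new_value: float, top_k: Dict[int, float],
                 bigger_is_better: bool = True) -> Tuple[Dict[int, float], bool]:
    old = sorted(top_k.values(), reverse=bigger_is_better)
    new = sorted(old + [new_value], reverse=bigger_is_better)[: len(old)]
    changed = new != old
    if changed:
        top_k = dict(zip(sorted(top_k.keys()), new))
    return top_k, changed


top_k = {i: 0.0 for i in range(3)}        # initialised as in main()
top_k, saved = update_top_k(0.42, top_k)  # saved == True -> a checkpoint slot would be written
```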
diff --git a/audioldm/clap/training/lp_train.py b/audioldm/clap/training/lp_train.py
deleted file mode 100644
index 24a19bacd0a4b789415cfccbce1f8bc99bc493ed..0000000000000000000000000000000000000000
--- a/audioldm/clap/training/lp_train.py
+++ /dev/null
@@ -1,301 +0,0 @@
-import json
-import logging
-import math
-import os
-import time
-from contextlib import suppress
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-try:
- import wandb
-except ImportError:
- wandb = None
-
-from open_clip import LPLoss, LPMetrics, lp_gather_features
-from open_clip.utils import do_mixup, get_mix_lambda
-from .distributed import is_master
-from .zero_shot import zero_shot_eval
-
-
-class AverageMeter(object):
- """Computes and stores the average and current value"""
-
- def __init__(self):
- self.reset()
-
- def reset(self):
- self.val = 0
- self.avg = 0
- self.sum = 0
- self.count = 0
-
- def update(self, val, n=1):
- self.val = val
- self.sum += val * n
- self.count += n
- self.avg = self.sum / self.count
-
-
-def unwrap_model(model):
- if hasattr(model, "module"):
- return model.module
- else:
- return model
-
-
-def train_one_epoch(
- model,
- data,
- epoch,
- optimizer,
- scaler,
- scheduler,
- args,
- tb_writer=None,
- extra_suffix="",
-):
- device = torch.device(args.device)
- autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
- model.train()
- loss = LPLoss(args.lp_loss)
-
- dataloader, sampler = data["train"].dataloader, data["train"].sampler
- if args.distributed and sampler is not None:
- sampler.set_epoch(epoch)
- num_batches_per_epoch = dataloader.num_batches
- sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))
-
- # for toy dataset
- if args.dataset_type == "toy":
- dataloader.dataset.generate_queue()
-
- loss_m = AverageMeter()
- batch_time_m = AverageMeter()
- data_time_m = AverageMeter()
- end = time.time()
-
- for i, batch in enumerate(dataloader):
- step = num_batches_per_epoch * epoch + i
-
- if isinstance(scheduler, dict):
- for s in scheduler.values():
- s(step)
- else:
- scheduler(step)
-
-        audio = batch  # contains mel_spec, waveform, and longer list
- class_label = batch["class_label"]
- # audio = audio.to(device=device, non_blocking=True)
- class_label = class_label.to(device=device, non_blocking=True)
-
- if args.mixup:
- # https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/utils.py#L146
- mix_lambda = torch.from_numpy(
- get_mix_lambda(0.5, len(audio["waveform"]))
- ).to(device)
- class_label = do_mixup(class_label, mix_lambda)
- else:
- mix_lambda = None
-
- data_time_m.update(time.time() - end)
- if isinstance(optimizer, dict):
- for o_ in optimizer.values():
- o_.zero_grad()
- else:
- optimizer.zero_grad()
-
- with autocast():
- pred = model(audio, mix_lambda=mix_lambda, device=device)
- total_loss = loss(pred, class_label)
-
- if isinstance(optimizer, dict):
- if scaler is not None:
- scaler.scale(total_loss).backward()
- for o_ in optimizer.values():
- if args.horovod:
- o_.synchronize()
- scaler.unscale_(o_)
- with o_.skip_synchronize():
- scaler.step(o_)
- else:
- scaler.step(o_)
- scaler.update()
- else:
- total_loss.backward()
- for o_ in optimizer.values():
- o_.step()
- else:
- if scaler is not None:
- scaler.scale(total_loss).backward()
- if args.horovod:
- optimizer.synchronize()
- scaler.unscale_(optimizer)
- with optimizer.skip_synchronize():
- scaler.step(optimizer)
- else:
- scaler.step(optimizer)
- scaler.update()
- else:
- total_loss.backward()
- optimizer.step()
-
- # Note: we clamp to 4.6052 = ln(100), as in the original paper.
- with torch.no_grad():
- unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100))
- unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100))
-
- batch_time_m.update(time.time() - end)
- end = time.time()
- batch_count = i + 1
-
- if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
- if isinstance(audio, dict):
- batch_size = len(audio["waveform"])
- else:
- batch_size = len(audio)
- num_samples = batch_count * batch_size * args.world_size
- samples_per_epoch = dataloader.num_samples
- percent_complete = 100.0 * batch_count / num_batches_per_epoch
-
- # NOTE loss is coarsely sampled, just master node and per log update
- loss_m.update(total_loss.item(), batch_size)
- if isinstance(optimizer, dict):
- logging.info(
- f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
- f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
- f"Data (t): {data_time_m.avg:.3f} "
- f"Batch (t): {batch_time_m.avg:.3f} "
- f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]}"
- )
- log_data = {
- "loss": loss_m.val,
- "data_time": data_time_m.val,
- "batch_time": batch_time_m.val,
- "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()],
- }
- else:
- logging.info(
- f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
- f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
- f"Data (t): {data_time_m.avg:.3f} "
- f"Batch (t): {batch_time_m.avg:.3f} "
- f"LR: {optimizer.param_groups[0]['lr']:5f} "
- )
-
- # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
- log_data = {
- "loss": loss_m.val,
- "data_time": data_time_m.val,
- "batch_time": batch_time_m.val,
- "lr": optimizer.param_groups[0]["lr"],
- }
- for name, val in log_data.items():
- name = f"train{extra_suffix}/{name}"
- if tb_writer is not None:
- tb_writer.add_scalar(name, val, step)
- if args.wandb:
- assert wandb is not None, "Please install wandb."
- wandb.log({name: val, "step": step})
-
- # resetting batch / data time meters per log window
- batch_time_m.reset()
- data_time_m.reset()
- # end for
-
-
-def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""):
- metrics = {}
- if not args.parallel_eval:
- if not is_master(args):
- return metrics
- device = torch.device(args.device)
- model.eval()
-
- # CHANGE
- # zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
- # metrics.update(zero_shot_metrics)
- if is_master(args):
- print("Evaluating...")
- metric_names = args.lp_metrics.split(",")
- eval_tool = LPMetrics(metric_names=metric_names)
-
- autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
- if "val" in data and (
- args.val_frequency
- and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
- ):
- if args.parallel_eval:
- dataloader, sampler = data["val"].dataloader, data["val"].sampler
- if args.distributed and sampler is not None:
- sampler.set_epoch(epoch)
- samples_per_val = dataloader.num_samples
- else:
- dataloader = data["val"].dataloader
- num_samples = 0
- samples_per_val = dataloader.num_samples
-
- eval_info = {"pred": [], "target": []}
- with torch.no_grad():
- for i, batch in enumerate(dataloader):
-                audio = batch  # contains mel_spec, waveform, and longer list
- class_label = batch["class_label"]
-
- # audio = audio.to(device=device, non_blocking=True)
- class_label = class_label.to(device=device, non_blocking=True)
-
- with autocast():
- pred = model(audio, device=device)
- if args.parallel_eval:
- pred, class_label = lp_gather_features(
- pred, class_label, args.world_size, args.horovod
- )
- eval_info["pred"].append(pred)
- eval_info["target"].append(class_label)
-
- num_samples += class_label.shape[0]
-
- if (i % 100) == 0: # and i != 0:
- logging.info(
- f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
- )
-
- if is_master(args):
- eval_info["pred"] = torch.cat(eval_info["pred"], 0).cpu()
- eval_info["target"] = torch.cat(eval_info["target"], 0).cpu()
- metric_dict = eval_tool.evaluate_mertics(
- eval_info["pred"], eval_info["target"]
- )
- metrics.update(metric_dict)
- if "epoch" not in metrics.keys():
- metrics.update({"epoch": epoch})
-
- if is_master(args):
- if not metrics:
- return metrics
-
- logging.info(
- f"Eval Epoch: {epoch} "
- + "\n".join(
-                [f"{m}: {round(metrics[m], 4):.4f}" for m in metrics]
- )
- )
- if args.save_logs:
- for name, val in metrics.items():
- if tb_writer is not None:
- tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch)
-
- with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
- f.write(json.dumps(metrics))
- f.write("\n")
-
- if args.wandb:
- assert wandb is not None, "Please install wandb."
- for name, val in metrics.items():
- wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch})
-
- return metrics
- else:
- return metrics
diff --git a/audioldm/clap/training/main.py b/audioldm/clap/training/main.py
deleted file mode 100644
index 3b563a5d001be7adfbe779dee7ad8ac49aadc50d..0000000000000000000000000000000000000000
--- a/audioldm/clap/training/main.py
+++ /dev/null
@@ -1,596 +0,0 @@
-from inspect import getargs
-import logging
-import os
-import random
-from datetime import datetime
-import bisect
-import copy
-import numpy as np
-import torch
-import torch.backends.cudnn as cudnn
-from torch import optim
-from torch.cuda.amp import GradScaler
-import faulthandler
-import pathlib
-
-try:
- import wandb
-except ImportError:
- wandb = None
-
-try:
- import torch.utils.tensorboard as tensorboard
-except ImportError:
- tensorboard = None
-
-try:
- import horovod.torch as hvd
-except ImportError:
- hvd = None
-
-from open_clip import create_model_and_transforms, trace_model, create_model
-from training.data import get_data
-from training.distributed import is_master, init_distributed_device, world_info_from_env
-from training.logger import setup_logging
-from training.params import parse_args
-from training.scheduler import cosine_lr
-from training.train import train_one_epoch, evaluate
-from open_clip.utils import dataset_split, get_optimizer
-
-
-def maintain_ckpts(args, startidx, all_idx_len):
- for i in reversed(range(startidx, all_idx_len)):
- if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")):
- os.rename(
- os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"),
- os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"),
- )
- if os.path.exists(
- os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")
- ):
- os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt"))
- return
-
-
-def update_top_k_performance(
- new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True
-):
- """
- Record the top-k performance of the current epoch.
- current_top_k_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...}
- """
- if isinstance(new_metrics_inputs, (list, tuple)):
- new_metrics_inputs = np.mean(new_metrics_inputs)
- return update_top_k_performance(
- new_metrics_inputs,
- current_top_k_ckpt_metrics,
- args=args,
- ckpt=ckpt,
- bignumbetter=bignumbetter,
- )
- elif isinstance(new_metrics_inputs, dict):
- new_metrics_inputs = np.mean(list(new_metrics_inputs.values()))
- return update_top_k_performance(
- new_metrics_inputs,
- current_top_k_ckpt_metrics,
- args=args,
- ckpt=ckpt,
- bignumbetter=bignumbetter,
- )
- elif isinstance(new_metrics_inputs, (float, int)):
- update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()}
- sorted_keys = sorted(current_top_k_ckpt_metrics.keys())
- sorted_values = sorted(
- current_top_k_ckpt_metrics.values(), reverse=bignumbetter
- )
- sorted_values_ = copy.deepcopy(sorted_values)
- sorted_values.append(new_metrics_inputs)
- sorted_values = sorted(sorted_values, reverse=bignumbetter)
- sorted_values = sorted_values[:-1]
-
- if sorted_values == sorted_values_:
- return current_top_k_ckpt_metrics, new_metrics_inputs
- else:
- for i in range(len(sorted_keys)):
- if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]:
- current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i]
- update_flag[sorted_keys[i]] = True
- for i in range(len(update_flag)):
- if update_flag[i]:
- maintain_ckpts(args, i, len(sorted_keys))
- torch.save(
- ckpt,
- os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"),
- )
- break
- return current_top_k_ckpt_metrics, new_metrics_inputs
-
-
-# def updateifNone(a, b):
-# a = b if None else a
-# return a
-
-
-def is_pretrained_params(n):
- return (
- n.startswith("transformer")
- or n in ["positional_embedding", "text_projection"]
- or n.startswith("token_embedding")
- or n.startswith("ln_final")
- or n.startswith("logit_scale_t")
- )
-
-
-def random_seed(seed=42, rank=0):
- torch.manual_seed(seed + rank)
- np.random.seed(seed + rank)
- random.seed(seed + rank)
-
-
-def main():
- args = parse_args()
- # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
- args.amodel = args.amodel.replace("/", "-")
- # download sizes.json file
-
- # (yusong): the below two lines are for debug
- # print("setting up faulthandler")
- # faulthandler.register(10)
-
- random.seed(args.seed)
- torch.manual_seed(args.seed)
- torch.cuda.manual_seed(args.seed)
- torch.cuda.manual_seed_all(args.seed)
- np.random.seed(args.seed)
- if args.tmodel == "bert" or args.tmodel == "roberta" or args.tmodel == "bart":
- assert (
- args.pretrained == "" or args.pretrained is None
- ), "bert/roberta/bart text encoder does not support pretrained models."
-
- # get the name of the experiments
- if args.name is None:
- args.name = "-".join(
- [
- datetime.now().strftime("%Y_%m_%d-%H_%M_%S"),
- f"model_{args.amodel}",
- f"lr_{args.lr}",
- f"b_{args.batch_size}",
- f"j_{args.workers}",
- f"p_{args.precision}",
- ]
- )
-
- # discover initial world args early so we can log properly
- args.distributed = False
- args.local_rank, args.rank, args.world_size = world_info_from_env()
-
- if args.remotedata and is_master(args):
- for dataset_name in args.datasetnames:
- for split in dataset_split[dataset_name]:
- if not os.path.exists(f"./json_files/{dataset_name}/{split}"):
- os.makedirs(f"./json_files/{dataset_name}/{split}")
- os.system(
- f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json"
- )
-
- args.log_path = None
- if is_master(args, local=args.log_local):
- log_base_path = os.path.join(args.logs, args.name)
- os.makedirs(log_base_path, exist_ok=True)
- log_filename = f"out-{args.rank}" if args.log_local else "out.log"
- args.log_path = os.path.join(log_base_path, log_filename)
- if os.path.exists(args.log_path):
- print(
-                f"Error. Experiment {args.name} already exists. Use --name to specify a new experiment."
- )
- return -1
-
- # Set logger
- args.log_level = logging.DEBUG if args.debug else logging.INFO
- setup_logging(args.log_path, args.log_level)
-
- # fully initialize distributed device environment
- device = init_distributed_device(args)
-
- args.wandb = "wandb" in args.report_to or "all" in args.report_to
- args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to
- if is_master(args):
- args.tensorboard_path = (
- os.path.join(args.logs, args.name, "tensorboard")
- if args.tensorboard
- else ""
- )
- args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
- for dirname in [args.tensorboard_path, args.checkpoint_path]:
- if dirname:
- os.makedirs(dirname, exist_ok=True)
- else:
- args.tensorboard_path = ""
- args.checkpoint_path = ""
-
- if args.copy_codebase:
- copy_codebase(args)
-
- assert args.precision in ["amp", "fp16", "fp32"]
- if args.precision == "fp16":
- logging.warning(
- "It is recommended to use AMP mixed-precision instead of FP16. "
- "FP16 support needs further verification and tuning, especially for train."
- )
-
- if args.horovod:
- logging.info(
- f"Running in horovod mode with multiple processes / nodes. Device: {args.device}."
- f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
- )
- elif args.distributed:
- logging.info(
- f"Running in distributed mode with multiple processes. Device: {args.device}."
- f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
- )
- else:
- logging.info(f"Running with a single process. Device {args.device}.")
-
- logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}")
-
- model, model_cfg = create_model(
- args.amodel,
- args.tmodel,
- args.pretrained,
- precision=args.precision,
- device=device,
- jit=args.torchscript,
- force_quick_gelu=args.force_quick_gelu,
- openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
- skip_params=True,
- pretrained_audio=args.pretrained_audio,
- pretrained_text=args.pretrained_text,
- enable_fusion=args.enable_fusion,
- fusion_type=args.fusion_type,
- )
-
- if args.horovod:
- with torch.no_grad():
- for param in model.parameters():
- param.set_(param.contiguous())
-
- if args.trace:
- model = trace_model(model, batch_size=args.batch_size, device=device)
-
- if is_master(args):
- logging.info("Model:")
- logging.info(f"{str(model)}")
- logging.info("Params:")
- params_file = os.path.join(args.logs, args.name, "params.txt")
- with open(params_file, "w") as f:
- for name in sorted(vars(args)):
- val = getattr(args, name)
- logging.info(f" {name}: {val}")
- f.write(f"{name}: {val}\n")
-
- if args.distributed and not args.horovod:
- if args.use_bn_sync:
- model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
- ddp_args = {}
- if args.ddp_static_graph:
- # this doesn't exist in older PyTorch, arg only added if enabled
- ddp_args["static_graph"] = True
- model = torch.nn.parallel.DistributedDataParallel(
- model, device_ids=[device], find_unused_parameters=True, **ddp_args
- )
-
- data = get_data(args, model_cfg)
- assert len(data), "At least one train or eval dataset must be specified."
- if args.trace:
- assert "train" not in data, "Cannot train with traced model"
-
- exclude = (
- lambda n, p: p.ndim < 2
- or "bn" in n
- or "ln" in n
- or "bias" in n
- or "logit_scale" in n
- )
- include = lambda n, p: not exclude(n, p)
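-    # Parameters matched by `exclude` (1-D tensors, batch/layer norms, biases, logit scales) are later
-    # placed in optimizer groups with weight_decay=0.0; the remaining parameters use the configured wd.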
-
- named_parameters = list(model.named_parameters())
-
- # freeze text encoder
- text_freeze_parameters = [p for n, p in named_parameters if "text_branch" in n]
-
- if args.freeze_text:
- print("Freeze Text!!!!")
- for k in text_freeze_parameters:
- k.requires_grad = False
-
- gain_or_bias_params = [
- p for n, p in named_parameters if exclude(n, p) and p.requires_grad
- ]
- rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
-
- # set wd-related params to 0 if use adam optimizer
- if args.optimizer == "adam":
- args.wd = 0
- args.wd_pretrained = 0
- args.wd_new = 0
-
- if args.train_data is None:
- optimizer = None
- scheduler = None
- else:
- total_steps = data["train"].dataloader.num_batches * args.epochs
-
- if args.split_opt:
- for x in ["lr", "beta1", "beta2", "eps", "wd"]:
- for y in ["_new", "_pretrained"]:
- if getattr(args, x + y) is None:
- setattr(args, x + y, getattr(args, x))
-
- gain_or_bias_pretrained_params = [
- p
- for n, p in named_parameters
- if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n)
- ]
- rest_pretrained_params = [
- p
- for n, p in named_parameters
- if (include(n, p) and p.requires_grad) and is_pretrained_params(n)
- ]
- gain_or_bias_new_params = [
- p
- for n, p in named_parameters
- if (exclude(n, p) and p.requires_grad) and (not is_pretrained_params(n))
- ]
- rest_new_params = [
- p
- for n, p in named_parameters
- if (include(n, p) and p.requires_grad) and (not is_pretrained_params(n))
- ]
- pretrained_params_optimizer = get_optimizer(
- [
- {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0},
- {
- "params": rest_pretrained_params,
- "weight_decay": args.wd_pretrained,
- },
- ],
- lr=args.lr_pretrained,
- betas=(args.beta1_pretrained, args.beta2_pretrained),
- eps=args.eps_pretrained,
- momentum=args.momentum_pretrained,
- optimizer_name=args.optimizer,
- )
- pretrained_params_scheduler = cosine_lr(
- pretrained_params_optimizer,
- args.lr_pretrained,
- args.warmup,
- total_steps,
- )
- new_params_optimizer = get_optimizer(
- [
- {"params": gain_or_bias_new_params, "weight_decay": 0.0},
- {"params": rest_new_params, "weight_decay": args.wd_new},
- ],
- lr=args.lr_new,
- betas=(args.beta1_new, args.beta2_new),
- eps=args.eps_new,
- momentum=args.momentum_new,
- optimizer_name=args.optimizer,
- )
-
- new_params_scheduler = cosine_lr(
- new_params_optimizer, args.lr_new, args.warmup, total_steps
- )
-
- optimizer = {
- "pretrained": pretrained_params_optimizer,
- "new": new_params_optimizer,
- }
- scheduler = {
- "pretrained": pretrained_params_scheduler,
- "new": new_params_scheduler,
- }
-
- if args.horovod:
- pretrained_params_optimizer = hvd.DistributedOptimizer(
- pretrained_params_optimizer,
- named_parameters=model.named_parameters(),
- )
- new_params_optimizer = hvd.DistributedOptimizer(
- new_params_optimizer, named_parameters=model.named_parameters()
- )
- hvd.broadcast_parameters(model.state_dict(), root_rank=0)
- hvd.broadcast_optimizer_state(pretrained_params_optimizer, root_rank=0)
- hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0)
- else:
- optimizer = get_optimizer(
- [
- {"params": gain_or_bias_params, "weight_decay": 0.0},
- {"params": rest_params, "weight_decay": args.wd},
- ],
- lr=args.lr,
- betas=(args.beta1, args.beta2),
- eps=args.eps,
- momentum=args.momentum,
- optimizer_name=args.optimizer,
- )
-
- scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)
-
- if args.horovod:
- optimizer = hvd.DistributedOptimizer(
- optimizer, named_parameters=model.named_parameters()
- )
- hvd.broadcast_parameters(model.state_dict(), root_rank=0)
- hvd.broadcast_optimizer_state(optimizer, root_rank=0)
-
- scaler = GradScaler() if args.precision == "amp" else None
-
- # optionally resume from a checkpoint
- start_epoch = 0
- if args.resume is not None:
- if os.path.isfile(args.resume):
- checkpoint = torch.load(args.resume, map_location=device)
- if "epoch" in checkpoint:
- # resuming a train checkpoint w/ epoch and optimizer state
- start_epoch = checkpoint["epoch"]
- sd = checkpoint["state_dict"]
- if not args.distributed and next(iter(sd.items()))[0].startswith(
- "module"
- ):
- sd = {k[len("module.") :]: v for k, v in sd.items()}
- model.load_state_dict(sd)
- if args.split_opt:
- if optimizer is not None:
- for k, o_ in optimizer.items():
- o_.load_state_dict(checkpoint[k + "_" + "optimizer"])
- if optimizer is not None:
- optimizer.load_state_dict(checkpoint["optimizer"])
- if scaler is not None and "scaler" in checkpoint:
- scaler.load_state_dict(checkpoint["scaler"])
- logging.info(
- f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})"
- )
- else:
- # loading a bare (model only) checkpoint for fine-tune or evaluation
- model.load_state_dict(checkpoint)
- logging.info(
- f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})"
- )
- if args.freeze_text:
- print("Freeze Text!!!!")
- for k in text_freeze_parameters:
- k.requires_grad = False
- else:
- logging.info("=> no checkpoint found at '{}'".format(args.resume))
-
- cudnn.benchmark = True
- cudnn.deterministic = False
-
- # determine if this worker should save logs and checkpoints. only do so if it is rank == 0
- args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args)
- writer = None
- if args.save_logs and args.tensorboard:
- assert tensorboard is not None, "Please install tensorboard."
- writer = tensorboard.SummaryWriter(args.tensorboard_path)
-
- if args.wandb and is_master(args):
- assert wandb is not None, "Please install wandb."
- logging.debug("Starting wandb.")
- args.train_sz = data["train"].dataloader.num_samples
- if args.val_data is not None:
- args.val_sz = data["val"].dataloader.num_samples
- # you will have to configure this for your project!
- wandb.init(
- project="clap",
- notes=args.wandb_notes,
- name=args.wandb_notes,
- tags=[],
- config=vars(args),
- )
- if args.debug:
- wandb.watch(model, log="all")
- wandb.save(params_file)
- logging.debug("Finished loading wandb.")
-
- if "train" not in data:
- evaluate(model, data, start_epoch, args, writer)
- return
- elif start_epoch == 0 and "val" in data and not args.no_eval:
- evaluate(model, data, 0, args, writer)
- # print(f'rank {args.rank}, Start First Evaluation')# (yusong): for debug
- if args.save_top_performance:
- current_top_k_ckpt_metrics = {
- i: 0 for i in range(args.save_top_performance)
- } # initialize the top-k metric for ckpts to 0
-
- # print(f'rank {args.rank}, Start Training') # (yusong): for debug
- for epoch in range(start_epoch, args.epochs):
- # freeze the text param after (include) args.freeze_text_after, this is -1 by default
- if epoch == args.freeze_text_after:
- print("Text pretrained parameters are freezed since this epoch.")
- for k in text_freeze_parameters:
- k.requires_grad = False
- if is_master(args):
- logging.info(f"Start epoch {epoch}")
-
- train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer)
- completed_epoch = epoch + 1
-
- if (
- any(v in data for v in ("val", "imagenet-val", "imagenet-v2"))
- and not args.no_eval
- ):
- metrics = evaluate(model, data, completed_epoch, args, writer)
- if args.save_top_performance:
- top_k_dataset = args.top_k_checkpoint_select_dataset
- top_k_metric = args.top_k_checkpoint_select_metric
- filtered_metrics = [
- v
- for k, v in metrics.items()
- if top_k_metric in k and top_k_dataset in k
- ] # check all R@10 metrics (all dataset) and use it to update the ckpt
- # Saving checkpoints.
- if args.save_logs:
- if args.split_opt:
- opt_dict = {
- k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items()
- }
- else:
- opt_dict = {"optimizer": optimizer.state_dict()}
- checkpoint_dict = {
- "epoch": completed_epoch,
- "name": args.name,
- "state_dict": model.state_dict(),
- }
- checkpoint_dict.update(opt_dict)
- if scaler is not None:
- checkpoint_dict["scaler"] = scaler.state_dict()
-
- if completed_epoch == args.epochs or (
- args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
- ):
- torch.save(
- checkpoint_dict,
- os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
- )
- if args.save_most_recent:
- torch.save(
- checkpoint_dict,
- os.path.join(args.checkpoint_path, f"epoch_latest.pt"),
- )
- if args.save_top_performance and not args.no_eval:
- update_top_k_performance(
- filtered_metrics,
- current_top_k_ckpt_metrics,
- args,
- checkpoint_dict,
- bignumbetter=True,
- )
-
- if args.wandb and is_master(args):
- wandb.finish()
-
-
-def copy_codebase(args):
- from shutil import copytree, ignore_patterns
-
- new_code_path = os.path.join(args.logs, args.name, "code")
- if os.path.exists(new_code_path):
- print(
- f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
- )
- return -1
- print(f"Copying codebase to {new_code_path}")
- current_code_path = os.path.realpath(__file__)
- for _ in range(3):
- current_code_path = os.path.dirname(current_code_path)
- copytree(
- current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb")
- )
- print("Done copying code.")
- return 1
-
-
-if __name__ == "__main__":
- main()
diff --git a/audioldm/clap/training/params.py b/audioldm/clap/training/params.py
deleted file mode 100644
index 0cc1a0e2d982e900988cf5a4b24b2e59b093537b..0000000000000000000000000000000000000000
--- a/audioldm/clap/training/params.py
+++ /dev/null
@@ -1,563 +0,0 @@
-import argparse
-
-
-def get_default_params(model_name):
- # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
- model_name = model_name.lower()
- if "vit" in model_name:
- return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
- else:
- return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--train-data",
- type=str,
- default=None,
- help="Path to h5 filewith training data",
- )
- parser.add_argument(
- "--val-data",
- type=str,
- default=None,
- help="Path to h5 file with validation data",
- )
- parser.add_argument(
- "--freeze-text",
- default=False,
- action="store_true",
- help="if you need to freeze the text encoder, make this True",
- )
- parser.add_argument(
- "--freeze-text-after",
- type=int,
- default=-1,
- help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it",
- )
- parser.add_argument(
- "--train-ipc",
- type=str,
- default=None,
- help="Path to npy file of the number of instance per class in training data",
- )
- parser.add_argument(
- "--val-ipc",
- type=str,
- default=None,
- help="Path to npy file of the number of instance per class in validation data",
- )
- parser.add_argument(
- "--train-num-samples",
- type=int,
- default=None,
- help="Number of samples in dataset. Required for webdataset if not available in info file.",
- )
- parser.add_argument(
- "--val-num-samples",
- type=int,
- default=None,
- help="Number of samples in dataset. Useful for webdataset if not available in info file.",
- )
- parser.add_argument(
- "--dataset-type",
- choices=["webdataset", "csv", "auto", "toy"],
- default="auto",
- help="Which type of dataset to process.",
- )
- parser.add_argument(
- "--csv-separator",
- type=str,
- default="\t",
- help="For csv-like datasets, which separator to use.",
- )
- parser.add_argument(
- "--csv-img-key",
- type=str,
- default="filepath",
- help="For csv-like datasets, the name of the key for the image paths.",
- )
- parser.add_argument(
- "--csv-caption-key",
- type=str,
- default="title",
- help="For csv-like datasets, the name of the key for the captions.",
- )
- parser.add_argument(
- "--imagenet-val",
- type=str,
- default=None,
- help="Path to imagenet val set for conducting zero shot evaluation.",
- )
- parser.add_argument(
- "--imagenet-v2",
- type=str,
- default=None,
- help="Path to imagenet v2 for conducting zero shot evaluation.",
- )
- parser.add_argument(
- "--datasetnames",
- nargs="+",
- default=None,
- help="If loading webdataset, spedify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects",
- )
- parser.add_argument(
- "--full-train-dataset",
- nargs="+",
- default=None,
- help="Which dataset will be trained with all the subsets. (train+test)",
- )
- parser.add_argument(
- "--exclude-eval-dataset",
- nargs="+",
- default=None,
- help="Which dataset will be excluded with evaluation",
- )
- parser.add_argument(
- "--datasetinfos",
- nargs="+",
- default=None,
- help="If loading webdataset, spedify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval",
- )
- parser.add_argument(
- "--dataset-proportion",
- type=float,
- default=1.0,
- help="How much proportion of dataset we want to train.",
- )
- parser.add_argument(
- "--remotedata",
- default=False,
- action="store_true",
- help="if the dataset is remote, set this flag",
- )
- parser.add_argument(
- "--class-label-path",
- type=str,
- default=None,
- help="The path of the class label pickle or csv.",
- )
- parser.add_argument(
- "--datasetpath",
- type=str,
- default="/mnt/audio_clip/webdataset_tar",
- help="The path to the dataset",
- )
- parser.add_argument(
- "--logs",
- type=str,
- default="./logs/",
- help="Where to store tensorboard logs. Use None to avoid storing logs.",
- )
- parser.add_argument(
- "--log-local",
- action="store_true",
- default=False,
- help="log files on local master, otherwise global master only.",
- )
- parser.add_argument(
- "--name",
- type=str,
- default=None,
- help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
- )
- parser.add_argument(
- "--workers", type=int, default=1, help="Number of workers per GPU."
- )
- parser.add_argument(
- "--batch-size", type=int, default=64, help="Batch size per GPU."
- )
- parser.add_argument(
- "--epochs", type=int, default=32, help="Number of epochs to train for."
- )
- parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
- parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
- parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
- parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
- parser.add_argument("--momentum", type=float, default=None, help="SGD epsilon.")
- parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
-
- parser.add_argument(
- "--split-opt",
- action="store_true",
- default=False,
- help="Use this flag to skip the learning rate decay.",
- )
- parser.add_argument(
- "--lr-pretrained", type=float, default=None, help="Learning rate for text."
- )
- parser.add_argument(
- "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text."
- )
- parser.add_argument(
- "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text."
- )
- parser.add_argument(
- "--eps-pretrained", type=float, default=None, help="Adam epsilon for text."
- )
- parser.add_argument(
- "--wd-pretrained", type=float, default=0.2, help="Weight decay for text."
- )
- parser.add_argument(
- "--momentum-pretrained", type=float, default=0.9, help="Momentum for text."
- )
- parser.add_argument(
- "--lr-new", type=float, default=None, help="Learning rate for audio."
- )
- parser.add_argument(
- "--beta1-new", type=float, default=None, help="Adam beta 1 for audio."
- )
- parser.add_argument(
- "--beta2-new", type=float, default=None, help="Adam beta 2 for audio."
- )
- parser.add_argument(
- "--eps-new", type=float, default=None, help="Adam epsilon for audio."
- )
- parser.add_argument(
- "--wd-new", type=float, default=0.2, help="Weight decay for audio."
- )
- parser.add_argument(
- "--momentum-new", type=float, default=0.9, help="Momentum for audio."
- )
- parser.add_argument(
- "--warmup", type=int, default=10000, help="Number of steps to warmup for."
- )
- parser.add_argument(
- "--use-bn-sync",
- default=False,
- action="store_true",
- help="Whether to use batch norm sync.",
- )
- parser.add_argument(
- "--skip-scheduler",
- action="store_true",
- default=False,
- help="Use this flag to skip the learning rate decay.",
- )
- parser.add_argument(
- "--save-frequency", type=int, default=1, help="How often to save checkpoints."
- )
- parser.add_argument(
- "--save-top-performance",
- type=int,
- default=0,
- help="Save the top x performance weights if the value >0",
- )
- parser.add_argument(
- "--save-most-recent",
- action="store_true",
- default=False,
- help="Always save the most recent model trained to epoch_latest.pt.",
- )
- parser.add_argument(
- "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
- )
- parser.add_argument(
- "--val-frequency",
- type=int,
- default=1,
- help="How often to run evaluation with val data.",
- )
- parser.add_argument(
- "--resume",
- default=None,
- type=str,
- help="path to latest checkpoint (default: none)",
- )
- parser.add_argument(
- "--precision",
- choices=["amp", "fp16", "fp32"],
- default="amp",
- help="Floating point precision.",
- )
- parser.add_argument(
- "--amodel",
- type=str,
- default="RN50",
- help="Name of the audio backbone to use.",
- )
- parser.add_argument(
- "--tmodel",
- type=str,
- default="transformer",
- help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]",
- )
- parser.add_argument(
- "--pretrained-audio",
- default="",
- type=str,
- help="Use a pretrained audio model weights for the audio encoder of CLAP",
- )
- parser.add_argument(
- "--pretrained-text",
- default="",
- type=str,
- help="Use a pretrained text model weights for the text encoder of CLAP",
- )
- parser.add_argument(
- "--pretrained",
- default="",
- type=str,
- help="Use a pretrained CLIP model weights with the specified tag or file path.",
- )
- parser.add_argument(
- "--pretrained-image",
- default=False,
- action="store_true",
- help="Load imagenet pretrained weights for image tower backbone if available.",
- )
- parser.add_argument(
- "--lock-image",
- default=False,
- action="store_true",
- help="Lock full image tower by disabling gradients.",
- )
- parser.add_argument(
- "--lock-image-unlocked-groups",
- type=int,
- default=0,
- help="Leave last n image tower layer groups unlocked.",
- )
- parser.add_argument(
- "--lock-image-freeze-bn-stats",
- default=False,
- action="store_true",
- help="Freeze BatchNorm running stats in image tower for any locked layers.",
- )
- parser.add_argument(
- "--local-loss",
- default=False,
- action="store_true",
- help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)",
- )
- parser.add_argument(
- "--gather-with-grad",
- default=False,
- action="store_true",
- help="enable full distributed gradient for feature gather",
- )
- parser.add_argument(
- "--force-quick-gelu",
- default=False,
- action="store_true",
- help="Force use of QuickGELU activation for non-OpenAI transformer models.",
- )
- parser.add_argument(
- "--torchscript",
- default=False,
- action="store_true",
- help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
- )
- parser.add_argument(
- "--trace",
- default=False,
- action="store_true",
- help="torch.jit.trace the model for inference / eval only",
- )
- # arguments for distributed training
- parser.add_argument(
- "--dist-url",
- default="env://",
- type=str,
- help="url used to set up distributed training",
- )
- parser.add_argument(
- "--dist-backend", default="nccl", type=str, help="distributed backend"
- )
- parser.add_argument(
- "--report-to",
- default="",
- type=str,
- help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']",
- )
- parser.add_argument(
- "--wandb-notes", default="", type=str, help="Notes if logging with wandb"
- )
- parser.add_argument(
- "--C", type=float, default=3.16, help="inverse regularizer for logistic reg."
- )
- parser.add_argument(
- "--debug",
- default=False,
- action="store_true",
- help="If true, more information is logged.",
- )
- parser.add_argument(
- "--copy-codebase",
- default=False,
- action="store_true",
- help="If true, we copy the entire base on the log diretory, and execute from there.",
- )
- parser.add_argument(
- "--horovod",
- default=False,
- action="store_true",
- help="Use horovod for distributed training.",
- )
- parser.add_argument(
- "--ddp-static-graph",
- default=False,
- action="store_true",
- help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
- )
- parser.add_argument(
- "--no-set-device-rank",
- default=False,
- action="store_true",
- help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
- )
- parser.add_argument("--seed", type=int, default=4242, help="Default random seed.")
-
- parser.add_argument(
- "--top-k-checkpoint-select-dataset",
- type=str,
- default="all",
- help="The dataset of selecting top-k checkpoint.",
- )
-
- # @R10, @R@5, @R1, mAP@10
- parser.add_argument(
- "--top-k-checkpoint-select-metric",
- type=str,
- default="_R@10",
- help="The metric for selecting top-k checkpoint.",
- )
- parser.add_argument(
- "--openai-model-cache-dir",
- type=str,
- default="~/.cache/clip",
- help="Directory to download OpenAI models.",
- )
- parser.add_argument(
- "--optimizer",
- type=str,
- default="adamw",
- help="can be AdamW or SGD",
- )
- parser.add_argument(
- "--parallel-eval",
- default=False,
- action="store_true",
- help="Eval in parallel (multi-GPU, multi-node).",
- )
-
- parser.add_argument(
- "--no-eval",
- default=False,
- action="store_true",
- help="Training without evaluation.",
- )
-
- parser.add_argument(
- "--lp-mlp",
- default=False,
- action="store_true",
- help="Linear Probe using MLP layer or not.",
- )
-
- parser.add_argument(
- "--lp-freeze",
- default=False,
- action="store_true",
- help="Linear Probe using Freeze CLAP or not",
- )
-
- parser.add_argument(
- "--lp-act",
- default="None",
- type=str,
- help="Options are ['relu','elu','prelu','softmax','sigmoid']",
- )
-
- parser.add_argument(
- "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe."
- )
-
- parser.add_argument(
- "--lp-metrics",
- type=str,
- default="map,mauc,acc",
- help="Metrics of Linear Probe.",
- )
-
- parser.add_argument(
- "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe"
- )
- parser.add_argument(
- "--kappa",
- type=float,
- default=0,
- help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss",
- )
-
- parser.add_argument(
- "--data-filling",
- type=str,
- default="pad",
- help="type of data filling when the audio length is shorter than the max length."
- "Can be one of the following: repeat, repeatpad, pad",
- )
- parser.add_argument(
- "--data-truncating",
- type=str,
- default="rand_trunc",
- help="type of data truncation when the audio length is longer than the max length."
- "Can be one of the following: rand_trunc, fusion",
- )
-
- parser.add_argument(
- "--clap-mlploss",
- default=False,
- action="store_true",
- help="Using MLP loss for CLAP model or not",
- )
-
- parser.add_argument(
- "--wandb-id",
- type=str,
- default=None,
- help="the id of wandb experiment to restore.",
- )
-
- parser.add_argument(
- "--sleep", type=float, default=0, help="sleep n seconds before start training"
- )
-
- # variable length processing
- parser.add_argument(
- "--enable-fusion",
- default=False,
- action="store_true",
- help="Enable feature funsion for variable-length data",
- )
-
- parser.add_argument(
- "--fusion-type",
- type=str,
- default="None",
- help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']",
- )
-
- parser.add_argument(
- "--mixup",
- default=False,
- action="store_true",
- help="Enable mixup in finetuning training.",
- )
- parser.add_argument(
- "--text-augment-selection",
- type=str,
- default=None,
- help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']",
- )
-
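-    # A hypothetical single-process invocation (paths and backbone names below are placeholders,
-    # not defaults from this file):
-    #   python -m training.main --train-data /data/train.h5 --val-data /data/val.h5 \
-    #       --amodel HTSAT-tiny --tmodel roberta --batch-size 64 --lr 1e-4 --report-to tensorboard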
- args = parser.parse_args()
-
- # If some params are not passed, we use the default values based on model name.
- default_params = get_default_params(args.amodel)
- for name, val in default_params.items():
- if getattr(args, name) is None:
- setattr(args, name, val)
-
- return args
diff --git a/audioldm/clap/training/scheduler.py b/audioldm/clap/training/scheduler.py
deleted file mode 100644
index 7151ffbab25a113673b7627027b443b27f22cb0f..0000000000000000000000000000000000000000
--- a/audioldm/clap/training/scheduler.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import numpy as np
-
-
-def assign_learning_rate(optimizer, new_lr):
- for param_group in optimizer.param_groups:
- param_group["lr"] = new_lr
-
-
-def _warmup_lr(base_lr, warmup_length, step):
- return base_lr * (step + 1) / warmup_length
-
-
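-# cosine_lr returns a per-step adjuster: linear warmup towards base_lr over `warmup_length` steps,
-# then cosine decay from base_lr towards 0 over the remaining `steps - warmup_length` steps.
-# E.g. with base_lr=1e-4, warmup_length=1000, steps=10000: step 499 gives 5e-5 (warmup),
-# step 5500 gives 0.5 * (1 + cos(pi/2)) * 1e-4 = 5e-5, and the final steps approach 0.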
-def cosine_lr(optimizer, base_lr, warmup_length, steps):
- def _lr_adjuster(step):
- if step < warmup_length:
- lr = _warmup_lr(base_lr, warmup_length, step)
- else:
- e = step - warmup_length
- es = steps - warmup_length
- lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
- assign_learning_rate(optimizer, lr)
- return lr
-
- return _lr_adjuster
diff --git a/audioldm/clap/training/train.py b/audioldm/clap/training/train.py
deleted file mode 100644
index f5759c4679d2ee9c0748444adf66b8453cf09728..0000000000000000000000000000000000000000
--- a/audioldm/clap/training/train.py
+++ /dev/null
@@ -1,838 +0,0 @@
-import json
-import logging
-import math
-import os
-import time
-from contextlib import suppress
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-try:
- import wandb
-except ImportError:
- wandb = None
-
-from open_clip import ClipLoss, gather_features
-from .distributed import is_master
-from .zero_shot import zero_shot_eval
-
-
-class AverageMeter(object):
- """Computes and stores the average and current value"""
-
- def __init__(self):
- self.reset()
-
- def reset(self):
- self.val = 0
- self.avg = 0
- self.sum = 0
- self.count = 0
-
- def update(self, val, n=1):
- self.val = val
- self.sum += val * n
- self.count += n
- self.avg = self.sum / self.count
-
-
-def unwrap_model(model):
- if hasattr(model, "module"):
- return model.module
- else:
- return model
-
-
-def train_one_epoch(
- model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None
-):
- device = torch.device(args.device)
- autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
- model.train()
- loss = ClipLoss(
- local_loss=args.local_loss,
- gather_with_grad=args.gather_with_grad,
- cache_labels=True,
- rank=args.rank,
- world_size=args.world_size,
- use_horovod=args.horovod,
- mlp_loss=args.clap_mlploss,
- weight_loss_kappa=args.kappa,
- )
-
- dataloader, sampler = data["train"].dataloader, data["train"].sampler
- if args.distributed and sampler is not None:
- sampler.set_epoch(epoch)
- num_batches_per_epoch = dataloader.num_batches
- sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))
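-    # sample_digits = number of digits needed to right-align the sample counter in the log lines
-    # below, e.g. 100000 training samples -> 6 digits.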
-
- # for toy dataset
- if args.dataset_type == "toy":
- dataloader.dataset.generate_queue()
-
- loss_m = AverageMeter()
- batch_time_m = AverageMeter()
- data_time_m = AverageMeter()
- end = time.time()
-
- for i, batch in enumerate(dataloader):
- # logging.info(f"batch {i} of {num_batches_per_epoch}")
- step = num_batches_per_epoch * epoch + i
- if isinstance(scheduler, dict):
- for s in scheduler.values():
- s(step)
- else:
- scheduler(step)
-        audios = batch # contains mel_spec, waveform, and longer list
- texts = batch["text"]
- # audios = audios.to(device=device, non_blocking=True)
- # texts = texts.to(device=device, non_blocking=True)
-
- data_time_m.update(time.time() - end)
- if isinstance(optimizer, dict):
- for o_ in optimizer.values():
- o_.zero_grad()
- else:
- optimizer.zero_grad()
-
- with autocast():
- (
- audio_features,
- text_features,
- audio_features_mlp,
- text_features_mlp,
- logit_scale_a,
- logit_scale_t,
- ) = model(audios, texts, device)
-
- if args.clap_mlploss:
- total_loss = loss(
- audio_features=audio_features,
- text_features=text_features,
- logit_scale_a=logit_scale_a,
- logit_scale_t=logit_scale_t,
- audio_features_mlp=audio_features_mlp,
- text_features_mlp=text_features_mlp,
- )
- else:
- total_loss = loss(
- audio_features=audio_features,
- text_features=text_features,
- logit_scale_a=logit_scale_a,
- )
- if isinstance(optimizer, dict):
- if scaler is not None:
- scaler.scale(total_loss).backward()
- for o_ in optimizer.values():
- if args.horovod:
- o_.synchronize()
- scaler.unscale_(o_)
- with o_.skip_synchronize():
- scaler.step(o_)
- else:
- scaler.step(o_)
- scaler.update()
- else:
- total_loss.backward()
- for o_ in optimizer.values():
- o_.step()
- else:
- if scaler is not None:
- scaler.scale(total_loss).backward()
- if args.horovod:
- optimizer.synchronize()
- scaler.unscale_(optimizer)
- with optimizer.skip_synchronize():
- scaler.step(optimizer)
- else:
- scaler.step(optimizer)
- scaler.update()
- else:
- total_loss.backward()
- optimizer.step()
-
- # Note: we clamp to 4.6052 = ln(100), as in the original paper.
- with torch.no_grad():
- unwrap_model(model).logit_scale_a.clamp_(0, math.log(100))
- if args.clap_mlploss:
- unwrap_model(model).logit_scale_t.clamp_(0, math.log(100))
-
- batch_time_m.update(time.time() - end)
- end = time.time()
- batch_count = i + 1
- if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
- if isinstance(audios, dict):
- batch_size = len(audios["waveform"])
- else:
- batch_size = len(audios)
- num_samples = batch_count * batch_size * args.world_size
- samples_per_epoch = dataloader.num_samples
- percent_complete = 100.0 * batch_count / num_batches_per_epoch
-
- # NOTE loss is coarsely sampled, just master node and per log update
- loss_m.update(total_loss.item(), batch_size)
- logit_scale_scalar_a = logit_scale_a.item()
- logit_scale_scalar_t = logit_scale_t.item()
- if isinstance(optimizer, dict):
- if args.clap_mlploss:
- logging.info(
- f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
- f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
- f"Data (t): {data_time_m.avg:.3f} "
- f"Batch (t): {batch_time_m.avg:.3f} "
- f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} "
-                        f"Logit Scale Audio: {logit_scale_scalar_a:.3f} "
- f"Logit Scale Text: {logit_scale_scalar_t:.3f}"
- )
- log_data = {
- "loss": loss_m.val,
- "data_time": data_time_m.val,
- "batch_time": batch_time_m.val,
- "scale_audio": logit_scale_scalar_a,
- "scale_text": logit_scale_scalar_t,
- "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()],
- }
- else:
- logging.info(
- f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
- f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
- f"Data (t): {data_time_m.avg:.3f} "
- f"Batch (t): {batch_time_m.avg:.3f} "
- f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} "
- f"Logit Scale Audio: {logit_scale_scalar_a:.3f}"
- )
- log_data = {
- "loss": loss_m.val,
- "data_time": data_time_m.val,
- "batch_time": batch_time_m.val,
- "scale_audio": logit_scale_scalar_a,
- "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()],
- }
-
- else:
- if args.clap_mlploss:
- logging.info(
- f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
- f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
- f"Data (t): {data_time_m.avg:.3f} "
- f"Batch (t): {batch_time_m.avg:.3f} "
- f"LR: {optimizer.param_groups[0]['lr']:5f} "
-                        f"Logit Scale Audio: {logit_scale_scalar_a:.3f} "
- f"Logit Scale Text: {logit_scale_scalar_t:.3f}"
- )
-
- # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
- log_data = {
- "loss": loss_m.val,
- "data_time": data_time_m.val,
- "batch_time": batch_time_m.val,
- "scale_audio": logit_scale_scalar_a,
- "scale_text": logit_scale_scalar_t,
- "lr": optimizer.param_groups[0]["lr"],
- }
- else:
- logging.info(
- f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
- f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
- f"Data (t): {data_time_m.avg:.3f} "
- f"Batch (t): {batch_time_m.avg:.3f} "
- f"LR: {optimizer.param_groups[0]['lr']:5f} "
- f"Logit Scale Audio: {logit_scale_scalar_a:.3f}"
- )
-
- # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
- log_data = {
- "loss": loss_m.val,
- "data_time": data_time_m.val,
- "batch_time": batch_time_m.val,
- "scale_audio": logit_scale_scalar_a,
- "lr": optimizer.param_groups[0]["lr"],
- }
- for name, val in log_data.items():
- name = "train/" + name
- if tb_writer is not None:
- tb_writer.add_scalar(name, val, step)
- if args.wandb:
- assert wandb is not None, "Please install wandb."
- wandb.log({name: val, "step": step})
-
- # resetting batch / data time meters per log window
- batch_time_m.reset()
- data_time_m.reset()
- # end for
-
-
-def evaluate(model, data, epoch, args, tb_writer=None):
- metrics = {}
- if not args.parallel_eval:
- if not is_master(args):
- return metrics
- device = torch.device(args.device)
- model.eval()
-
- # CHANGE
- # zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
- # metrics.update(zero_shot_metrics)
- if is_master(args):
- print("Evaluating...")
- autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
- if args.val_dataset_names == ["Clotho", "audiocaps"]:
- # if only clotho and audiocaps are used, then we will use a different evaluation function.
-        # This is because in the Clotho and audiocaps valid and test sets, there are 5 texts for 1 audio.
- if args.parallel_eval:
- # (yusong): just a hack here. Don't use parallel eval when evaluating only clotho and audiocaps.
- raise NotImplementedError(
- "Parallel evaluation not supported for eval only Clotho and audiocaps."
- )
- val_metrics_per_dataset = evaluate_clotho_audiocaps(
- model, data, epoch, args, autocast, device, tb_writer
- )
- for m in val_metrics_per_dataset.values():
- metrics.update(m)
- if "epoch" not in metrics.keys():
- metrics.update({"epoch": epoch})
- metrics = select_top_metric_clotho_audiocaps(
- metrics, val_metrics_per_dataset, args
- )
- elif "val" in data and (
- args.val_frequency
- and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
- ):
- dataloader = data["val"].dataloader
- num_samples = 0
- samples_per_val = dataloader.num_samples
-
- # FIXME this does not scale past small eval datasets
- # all_audio_features @ all_text_features will blow up memory and compute very quickly
- eval_info = {}
- if args.clap_mlploss:
- eval_info["all"] = {
- "cumulative_loss": 0.0,
- "num_samples": 0,
- "all_audio_features": [],
- "all_text_features": [],
- "all_audio_features_mlp": [],
- "all_text_features_mlp": [],
- } # cumulative_loss = 0.0
- else:
- eval_info["all"] = {
- "cumulative_loss": 0.0,
- "num_samples": 0,
- "all_audio_features": [],
- "all_text_features": [],
-            } # cumulative_loss = 0.0
- # all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp = [], [], [], []
- with torch.no_grad():
- for i, batch in enumerate(dataloader):
-                audios = batch # contains mel_spec, waveform, and longer list
- texts = batch["text"]
- # audios = audios.to(device=device, non_blocking=True)
-
- all_names = list(
- set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]])
- )
- for name in all_names:
- if name not in eval_info.keys():
- if args.clap_mlploss:
- eval_info[name] = {
- "cumulative_loss": 0.0,
- "num_samples": 0,
- "all_audio_features": [],
- "all_text_features": [],
- "all_audio_features_mlp": [],
- "all_text_features_mlp": [],
- }
- else:
- eval_info[name] = {
- "cumulative_loss": 0.0,
- "num_samples": 0,
- "all_audio_features": [],
- "all_text_features": [],
- }
- with autocast():
- (
- audio_features,
- text_features,
- audio_features_mlp,
- text_features_mlp,
- logit_scale_a,
- logit_scale_t,
- ) = model(audios, texts, device)
-
- if args.parallel_eval:
- # multi-GPU eval
- if args.clap_mlploss:
- (
- audio_features,
- text_features,
- audio_features_mlp,
- text_features_mlp,
- ) = gather_features(
- audio_features=audio_features,
- text_features=text_features,
- audio_features_mlp=audio_features_mlp,
- text_features_mlp=text_features_mlp,
- local_loss=False,
- gather_with_grad=False,
- rank=args.rank,
- world_size=args.world_size,
- use_horovod=args.horovod,
- mlp_loss=args.clap_mlploss,
- )
- else:
- (audio_features, text_features,) = gather_features(
- audio_features=audio_features,
- text_features=text_features,
- local_loss=False,
- gather_with_grad=False,
- rank=args.rank,
- world_size=args.world_size,
- use_horovod=args.horovod,
- mlp_loss=args.clap_mlploss,
- )
-
- if is_master(args):
- num_samples += audio_features.shape[0]
- for n in [*all_names, "all"]:
- if n == "all":
- eval_info[n]["all_audio_features"].append(
- audio_features.cpu()
- )
- eval_info[n]["all_text_features"].append(
- text_features.cpu()
- )
- if args.clap_mlploss:
- eval_info[n]["all_audio_features_mlp"].append(
- audio_features_mlp.cpu()
- )
- eval_info[n]["all_text_features_mlp"].append(
- text_features_mlp.cpu()
- )
- else:
- idx = np.where(
- np.array(
- [
- "-".join(b.split("/")[-3:-1])
- for b in batch["__url__"]
- ]
- )
- == n
- )[0]
- eval_info[n]["all_audio_features"].append(
- audio_features.cpu().index_select(
- 0, torch.tensor(idx).long()
- )
- )
- eval_info[n]["all_text_features"].append(
- text_features.cpu().index_select(
- 0, torch.tensor(idx).long()
- )
- )
- if args.clap_mlploss:
- eval_info[n]["all_audio_features_mlp"].append(
- audio_features_mlp.cpu().index_select(
- 0, torch.tensor(idx).long()
- )
- )
- eval_info[n]["all_text_features_mlp"].append(
- text_features_mlp.cpu().index_select(
- 0, torch.tensor(idx).long()
- )
- )
- # print(f'eval step {i}') # (yusong): for debug
-
- # cumulative_loss += total_loss * batch_size
- # num_samples += batch_size
- if is_master(args) and (i % 100) == 0: # and i != 0:
- logging.info(
- f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
- )
- if is_master(args):
- val_metrics_per_dataset = {}
- for n in eval_info.keys():
- if args.clap_mlploss:
- metrics_single_dataset = get_metrics(
- audio_features=torch.cat(
- eval_info[n]["all_audio_features"]
- ),
- text_features=torch.cat(eval_info[n]["all_text_features"]),
- logit_scale_a=logit_scale_a.cpu(),
- audio_features_mlp=torch.cat(
- eval_info[n]["all_audio_features_mlp"]
- ),
- text_features_mlp=torch.cat(
- eval_info[n]["all_text_features_mlp"]
- ),
- logit_scale_t=logit_scale_t.cpu(),
- mlp_loss=args.clap_mlploss,
- )
- else:
- metrics_single_dataset = get_metrics(
- audio_features=torch.cat(
- eval_info[n]["all_audio_features"]
- ),
- text_features=torch.cat(eval_info[n]["all_text_features"]),
- logit_scale_a=logit_scale_a.cpu(),
- mlp_loss=args.clap_mlploss,
- )
- val_metrics_per_dataset[n] = {
- n + "/" + k: v for k, v in metrics_single_dataset.items()
- }
- metrics.update(val_metrics_per_dataset[n])
- if "epoch" not in metrics.keys():
- metrics.update({"epoch": epoch})
- if is_master(args):
- if not metrics:
- return metrics
-
- logging.info(
- f"Eval Epoch: {epoch} "
- + "\n".join(
- [
- "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in m.items()])
- for m in val_metrics_per_dataset.values()
- ]
- )
- )
-
- if args.save_logs:
- for name, val in metrics.items():
- if tb_writer is not None:
- tb_writer.add_scalar(f"val/{name}", val, epoch)
-
- with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
- f.write(json.dumps(metrics))
- f.write("\n")
-
- if args.wandb:
- assert wandb is not None, "Please install wandb."
- for name, val in metrics.items():
- wandb.log({f"val/{name}": val, "epoch": epoch})
-
- return metrics
- else:
- return metrics
-
-
-def get_metrics(
- audio_features,
- text_features,
- logit_scale_a,
- audio_features_mlp=None,
- text_features_mlp=None,
- logit_scale_t=None,
- mlp_loss=False,
-):
- metrics = {}
- if mlp_loss:
-        # Set up audio-to-text & text-to-audio similarity matrices
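-        # With mlp_loss, the logits are computed cross-wise: base audio features against MLP-projected
-        # text features (scaled by logit_scale_a), and MLP-projected audio features against base text
-        # features (scaled by logit_scale_t); the four cross-entropy terms below are then averaged.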
- a_logits_per_audio = (
- (logit_scale_a * audio_features @ text_features_mlp.t()).detach().cpu()
- )
- a_logits_per_text = a_logits_per_audio.t().detach().cpu()
- t_logits_per_audio = (
- (logit_scale_t * audio_features_mlp @ text_features.t()).detach().cpu()
- )
- t_logits_per_text = t_logits_per_audio.t().detach().cpu()
-
- labels = torch.arange(audio_features.shape[0]).long()
- # Change the loss from two terms into four terms with 2x2 combined CE loss
- total_loss = (
- F.cross_entropy(a_logits_per_audio, labels)
- + F.cross_entropy(a_logits_per_text, labels)
- + F.cross_entropy(t_logits_per_audio, labels)
- + F.cross_entropy(t_logits_per_text, labels)
- ) / 4
-
- metrics[f"cumulative_loss"] = total_loss.item()
- metrics[f"num_samples"] = audio_features.shape[0]
-
- logits = {
- "audio_to_text": (a_logits_per_audio + t_logits_per_audio) / 2,
- "text_to_audio": (a_logits_per_text + t_logits_per_text) / 2,
- }
- ground_truth = torch.arange(len(text_features)).view(-1, 1)
-
- else:
- # print("text_features", text_features)
- # print("text_features.shape", text_features.shape)
- logits_per_audio = (
- (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
- )
- logits_per_text = logits_per_audio.t().detach().cpu()
-
- labels = torch.arange(audio_features.shape[0]).long()
- # Change the loss from two terms into four terms with 2x2 combined CE loss
- total_loss = (
- F.cross_entropy(logits_per_audio, labels)
- + F.cross_entropy(logits_per_text, labels)
- ) / 2
-
- metrics[f"cumulative_loss"] = total_loss.item()
- metrics[f"num_samples"] = audio_features.shape[0]
-
- logits = {"audio_to_text": logits_per_audio, "text_to_audio": logits_per_text}
-
- ground_truth = torch.arange(len(text_features)).view(-1, 1)
-
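-    # For each direction, `preds` holds the 0-indexed rank of the matching item for every query.
-    # E.g. preds = [0, 4, 12] gives mean_rank = 6.33, median_rank = 5, R@1 = 0.33, R@5 = 0.67,
-    # R@10 = 0.67, and mAP@10 = mean(1/1, 1/5, 0) = 0.4.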
- for name, logit in logits.items():
- ranking = torch.argsort(logit, descending=True)
- preds = torch.where(ranking == ground_truth)[
- 1
- ] # (yusong) this line is slow because it uses single thread
- preds = preds.detach().cpu().numpy()
- metrics[f"{name}_mean_rank"] = preds.mean() + 1
- metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
- for k in [1, 5, 10]:
- metrics[f"{name}_R@{k}"] = np.mean(preds < k)
- # map@10
- metrics[f"{name}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0))
-
- return metrics
-
-
-def evaluate_clotho_audiocaps(
- model, data, epoch, args, autocast, device, tb_writer=None
-):
- """
- Adapted from https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py.
-    1. for text-to-audio retrieval, run the retrieval 5 times (once per caption) and average the results
-    2. for R@1, R@5, R@10 in audio-to-text retrieval, take the best rank among the 5 texts
-    3. for mAP@10 in audio-to-text retrieval:
-        3.1: sort the ranks of the 5 texts
-        3.2: exclude ranks >= 10 (0-indexed)
-        3.3: compute the mAP over the remaining ranks: np.mean(np.arange(1, len(ranks)+1) / ranks).
-        (3.3) That is, take the ranks of the 5 texts that are < 10 and assign descending numbers as ground truth.
-        (3.3) E.g.: the ground truth for the best-ranked of the 5 texts is 1, for the second-ranked 2, etc.
- """
- # TODO: (yusong) only support single GPU evaluation and only support non-mlp case for now.
- dataloader = data["val"].dataloader
- with torch.no_grad():
- eval_info = {}
- for i, batch in enumerate(dataloader):
-            audios = batch # contains mel_spec, waveform, and longer list
-
- # each item in the list has 5 texts
- if args.tmodel == "transformer":
- from open_clip import tokenize
-
- texts = [tokenize(t) for t in batch["full_text"]]
- texts = torch.cat(texts)
- else:
- from .data import tokenizer
-
- texts = [
- tokenizer(t) for t in batch["full_text"]
- ] # 5 texts for each audio
- texts = {
- k: torch.cat([t[k] for t in texts]) for k in texts[0].keys()
- } # 5 x batch
-
- # audios = audios.to(device=device, non_blocking=True)
-
- all_names = list(
- set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]])
- )
- for name in all_names:
- if name not in eval_info.keys():
- # we will not use mlp outputs even if args.clap_mlploss=True
- eval_info[name] = {
- "cumulative_loss": 0.0,
- "num_samples": 0,
- "all_audio_features": [],
- "all_text_features": [],
- }
- with autocast():
- audio_features = model(audios, None, device)
- text_features = model(None, texts, device)
- audio_features = F.normalize(audio_features, dim=-1)
- text_features = F.normalize(text_features, dim=-1)
-
- all_names = list(
- set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]])
- )
- for n in all_names:
- idx = np.where(
- np.array(
- ["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]
- )
- == n
- )[0]
- eval_info[n]["all_audio_features"].append(
- audio_features.cpu().index_select(0, torch.tensor(idx).long())
- )
- # (yusong) please double-check. This is for selecting 5 text features at once.
- # because idx is a list of indices in size of num_samples,
- # and text_features is a tensor of size (5*num_samples, dim)
- # so we need to select 5 consecutive indices at once for a single index in idx.
- eval_info[n]["all_text_features"].append(
- text_features.cpu()
- .reshape([-1, 5, text_features.shape[1]])
- .index_select(0, torch.tensor(idx).long())
- .reshape([-1, text_features.shape[1]])
- )
-
- val_metrics_all = {}
-
- for n in eval_info.keys():
- logit_scale_a, logit_scale_t = model(None, None, device)
- logit_scale_a = logit_scale_a.cpu()
-
- audio_features = torch.cat(eval_info[n]["all_audio_features"], dim=0)
- text_features = torch.cat(eval_info[n]["all_text_features"], dim=0)
-
- logits_per_audio = (
- (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
- )
- logits_per_text = logits_per_audio.t().detach().cpu()
-
- # logits_per_audio shape: [num_samples, num_samples*5]
- # logits_per_text shape: [num_samples*5, num_samples]
-
- logging.info(
- f"dataset {n}, logits_per_audio shape: {logits_per_audio.shape}, "
- f"logits_per_text shape: {logits_per_text.shape}"
- )
-
- metrics = {}
- num_samples = audio_features.shape[0]
- metrics[f"num_samples"] = num_samples
-
- # (yusong) the following code is very important, please double-check:
- # logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d]
- # logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :]
- # Those two are retrieving one of the 5 text for each audio.
- labels = torch.arange(audio_features.shape[0]).long()
- audio_to_text_loss = [
- F.cross_entropy(
- logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d],
- labels,
- )
- for d in range(5)
- ]
- text_to_audio_loss = [
- F.cross_entropy(
- logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :],
- labels,
- )
- for d in range(5)
- ]
- total_loss = (np.mean(audio_to_text_loss) + np.mean(text_to_audio_loss)) / 2
-
- metrics[f"cumulative_loss"] = total_loss.item()
-
- # text to audio: do 5 times
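- # For each of the 5 captions d, rank all audios for every text query and record the 0-based
- # rank of the paired audio; R@k is the fraction of queries with rank < k, and mAP@10 averages
- # 1 / (rank + 1), counting queries whose paired audio is ranked 10 or worse as 0.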
- pred_text = []
- for d in range(5):
- logit = logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :]
- ground_truth = torch.arange(len(logit)).view(-1, 1)
- ranking = torch.argsort(
- logit, descending=True
- ) # [num_samples, num_samples]
- preds = torch.where(ranking == ground_truth)[1]
- pred_text.append(preds.detach().cpu().numpy())
- pred_text_concat = np.concatenate(pred_text, axis=0) # [5*num_samples]
- metrics[f"text_to_audio_mean_rank"] = pred_text_concat.mean() + 1
- metrics[f"text_to_audio_median_rank"] = (
- np.floor(np.median(pred_text_concat)) + 1
- )
- for k in [1, 5, 10]:
- metrics[f"text_to_audio_R@{k}"] = np.mean(pred_text_concat < k)
- # map@10
- metrics[f"text_to_audio_mAP@10"] = np.mean(
- np.where(pred_text_concat < 10, 1 / (pred_text_concat + 1), 0.0)
- )
-
- # audio to text: take the best result
- # for audio-to-text mAP@10, sort by descending score and assign the ground-truth positions in order.
- # see https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py#L103
- # map@10
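- # Each audio has 5 paired captions: `min_pred` below is the best rank among them (used for R@k),
- # and `map_single` approximates average precision over the paired captions retrieved in the
- # top 10, with captions ranked 10 or worse contributing 0.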
- map_all = []
- pred_audio_all = []
- for d in range(num_samples):
- # logits_per_audio: [num_samples, num_samples*5]
- logit_single = logits_per_audio[d, :] # [5*num_samples]
- # Ground-truth index: [d*5, d*5+1, d*5+2, d*5+3, d*5+4]
- ranking = torch.argsort(
- logit_single, descending=True
- ) # [5*num_samples]
- # ranking: indices of the retrieved texts, sorted by descending score
- ground_truth = torch.arange(d * 5, d * 5 + 5)[None]
- all_pred = torch.where(
- torch.stack([ranking] * 5) == ground_truth.view(-1, 1)
- )[1]
- min_pred = torch.min(all_pred)
- pred_audio_all.append(min_pred.detach().cpu().numpy())
- all_pred_filter = all_pred[all_pred < 10].detach().cpu().numpy()
- # divide by 5 because there are 5 texts per audio, so texts ranked >= 10 count as 0.
- map_single = (
- np.sum(
- (np.arange(1, len(all_pred_filter) + 1) / (all_pred_filter + 1))
- )
- / 5
- )
- map_all.append(map_single)
- metrics[f"audio_to_text_mAP@10"] = np.mean(map_all)
- for k in [1, 5, 10]:
- metrics[f"audio_to_text_R@{k}"] = np.mean(np.array(pred_audio_all) < k)
-
- val_metrics_all[n] = {n + "/" + k: v for k, v in metrics.items()}
- return val_metrics_all
-
-
-def calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset):
- """
- Calculate performance for Clotho+AudioCaps for model selection.
- """
- selection_performance_all = []
- for n in val_metrics_per_dataset.keys():
- selection_performance = (
- val_metrics_per_dataset[n][f"{n}/audio_to_text_mAP@10"]
- + val_metrics_per_dataset[n][f"{n}/text_to_audio_mAP@10"]
- ) / 2
- selection_performance_all.append(selection_performance)
- return np.mean(selection_performance_all)
-
-
-def select_top_metric_clotho_audiocaps(metrics, val_metrics_per_dataset, args):
- # val_metrics_per_dataset: dict, key: dataset name, value: dict, key: metric name, value: metric value
- # metrics: dict, key: metric name, value: metric value
- # Hack: use args to save the top performance
- if not hasattr(args, "top_selection_performance"):
- selection_performance = calculate_selection_performance_clotho_audiocaps(
- val_metrics_per_dataset
- )
- # TODO: write the if and else together
- metric_update = {}
- for n in val_metrics_per_dataset.keys():
- for k in val_metrics_per_dataset[n].keys():
- metric_update[
- k.split("/")[0] + "-top" + "/" + k.split("/")[1]
- ] = val_metrics_per_dataset[n][k]
- metric_update["top_selection_performance"] = selection_performance
- metric_update["top-selection-epoch"] = metrics["epoch"]
- metrics.update(metric_update)
- args.top_metric = metric_update
- args.top_selection_performance = selection_performance
- else:
- selection_performance_new = calculate_selection_performance_clotho_audiocaps(
- val_metrics_per_dataset
- )
- selection_performance_old = args.top_selection_performance
- if selection_performance_new > selection_performance_old:
- metric_update = {}
- for n in val_metrics_per_dataset.keys():
- for k in val_metrics_per_dataset[n].keys():
- metric_update[
- k.split("/")[0] + "-top" + "/" + k.split("/")[1]
- ] = val_metrics_per_dataset[n][k]
- metric_update["top_selection_performance"] = selection_performance_new
- metric_update["top-selection-epoch"] = metrics["epoch"]
- metrics.update(metric_update)
- args.top_metric = metric_update
- args.top_selection_performance = selection_performance_new
- else:
- metrics.update(args.top_metric)
- return metrics
diff --git a/audioldm/clap/training/zero_shot.py b/audioldm/clap/training/zero_shot.py
deleted file mode 100644
index 28b8fccc1af17fc69002857a7f529ac041c374f2..0000000000000000000000000000000000000000
--- a/audioldm/clap/training/zero_shot.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# NOTE: This script is currently not supported for CLAP.
-import logging
-from contextlib import suppress
-
-import torch
-import torch.nn.functional as F
-from tqdm import tqdm
-
-from open_clip import tokenize
-from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template
-
-
-def zero_shot_classifier(model, classnames, templates, args):
- with torch.no_grad():
- zeroshot_weights = []
- for classname in tqdm(classnames):
- texts = [template(classname) for template in templates] # format with class
- texts = tokenize(texts).to(args.device) # tokenize
- if args.distributed and not args.horovod:
- class_embeddings = model.module.encode_text(texts)
- else:
- class_embeddings = model.encode_text(texts)
- class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
- class_embedding /= class_embedding.norm()
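- # the class prototype is the renormalized mean of the per-template text embeddings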
- zeroshot_weights.append(class_embedding)
- zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device)
- return zeroshot_weights
-
-
-def accuracy(output, target, topk=(1,)):
- pred = output.topk(max(topk), 1, True, True)[1].t()
- correct = pred.eq(target.view(1, -1).expand_as(pred))
- return [
- float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
- for k in topk
- ]
-
-
-def run(model, classifier, dataloader, args):
- autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
- with torch.no_grad():
- top1, top5, n = 0.0, 0.0, 0.0
- for images, target in tqdm(dataloader, unit_scale=args.batch_size):
- images = images.to(args.device)
- target = target.to(args.device)
-
- with autocast():
- # predict
- if args.distributed and not args.horovod:
- image_features = model.module.encode_image(images)
- else:
- image_features = model.encode_image(images)
- image_features = F.normalize(image_features, dim=-1)
- logits = 100.0 * image_features @ classifier
-
- # measure accuracy
- acc1, acc5 = accuracy(logits, target, topk=(1, 5))
- top1 += acc1
- top5 += acc5
- n += images.size(0)
-
- top1 = top1 / n
- top5 = top5 / n
- return top1, top5
-
-
-def zero_shot_eval(model, data, epoch, args):
- if "imagenet-val" not in data and "imagenet-v2" not in data:
- return {}
- if args.zeroshot_frequency == 0:
- return {}
- if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
- return {}
-
- logging.info("Starting zero-shot imagenet.")
-
- logging.info("Building zero-shot classifier")
- classifier = zero_shot_classifier(
- model, imagenet_classnames, openai_imagenet_template, args
- )
-
- logging.info("Using classifier")
- results = {}
- if "imagenet-val" in data:
- top1, top5 = run(model, classifier, data["imagenet-val"].dataloader, args)
- results["imagenet-zeroshot-val-top1"] = top1
- results["imagenet-zeroshot-val-top5"] = top5
- if "imagenet-v2" in data:
- top1, top5 = run(model, classifier, data["imagenet-v2"].dataloader, args)
- results["imagenetv2-zeroshot-val-top1"] = top1
- results["imagenetv2-zeroshot-val-top5"] = top5
-
- logging.info("Finished zero-shot imagenet.")
-
- return results
diff --git a/audioldm/hifigan/__init__.py b/audioldm/hifigan/__init__.py
deleted file mode 100644
index e0ae476fe58c48e998c56234a55b871beba4042d..0000000000000000000000000000000000000000
--- a/audioldm/hifigan/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .models import Generator
-
-
-class AttrDict(dict):
- def __init__(self, *args, **kwargs):
- super(AttrDict, self).__init__(*args, **kwargs)
- self.__dict__ = self
diff --git a/audioldm/hifigan/models.py b/audioldm/hifigan/models.py
deleted file mode 100644
index c4382cc39de0463f9b7c0f33f037dbc233e7cb36..0000000000000000000000000000000000000000
--- a/audioldm/hifigan/models.py
+++ /dev/null
@@ -1,174 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn import Conv1d, ConvTranspose1d
-from torch.nn.utils import weight_norm, remove_weight_norm
-
-LRELU_SLOPE = 0.1
-
-
-def init_weights(m, mean=0.0, std=0.01):
- classname = m.__class__.__name__
- if classname.find("Conv") != -1:
- m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
- return int((kernel_size * dilation - dilation) / 2)
-
-
-class ResBlock(torch.nn.Module):
- def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
- super(ResBlock, self).__init__()
- self.h = h
- self.convs1 = nn.ModuleList(
- [
- weight_norm(
- Conv1d(
- channels,
- channels,
- kernel_size,
- 1,
- dilation=dilation[0],
- padding=get_padding(kernel_size, dilation[0]),
- )
- ),
- weight_norm(
- Conv1d(
- channels,
- channels,
- kernel_size,
- 1,
- dilation=dilation[1],
- padding=get_padding(kernel_size, dilation[1]),
- )
- ),
- weight_norm(
- Conv1d(
- channels,
- channels,
- kernel_size,
- 1,
- dilation=dilation[2],
- padding=get_padding(kernel_size, dilation[2]),
- )
- ),
- ]
- )
- self.convs1.apply(init_weights)
-
- self.convs2 = nn.ModuleList(
- [
- weight_norm(
- Conv1d(
- channels,
- channels,
- kernel_size,
- 1,
- dilation=1,
- padding=get_padding(kernel_size, 1),
- )
- ),
- weight_norm(
- Conv1d(
- channels,
- channels,
- kernel_size,
- 1,
- dilation=1,
- padding=get_padding(kernel_size, 1),
- )
- ),
- weight_norm(
- Conv1d(
- channels,
- channels,
- kernel_size,
- 1,
- dilation=1,
- padding=get_padding(kernel_size, 1),
- )
- ),
- ]
- )
- self.convs2.apply(init_weights)
-
- def forward(self, x):
- for c1, c2 in zip(self.convs1, self.convs2):
- xt = F.leaky_relu(x, LRELU_SLOPE)
- xt = c1(xt)
- xt = F.leaky_relu(xt, LRELU_SLOPE)
- xt = c2(xt)
- x = xt + x
- return x
-
- def remove_weight_norm(self):
- for l in self.convs1:
- remove_weight_norm(l)
- for l in self.convs2:
- remove_weight_norm(l)
-
-
-class Generator(torch.nn.Module):
- def __init__(self, h):
- super(Generator, self).__init__()
- self.h = h
- self.num_kernels = len(h.resblock_kernel_sizes)
- self.num_upsamples = len(h.upsample_rates)
- self.conv_pre = weight_norm(
- Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
- )
- resblock = ResBlock
-
- self.ups = nn.ModuleList()
- for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
- self.ups.append(
- weight_norm(
- ConvTranspose1d(
- h.upsample_initial_channel // (2**i),
- h.upsample_initial_channel // (2 ** (i + 1)),
- k,
- u,
- padding=(k - u) // 2,
- )
- )
- )
-
- self.resblocks = nn.ModuleList()
- for i in range(len(self.ups)):
- ch = h.upsample_initial_channel // (2 ** (i + 1))
- for j, (k, d) in enumerate(
- zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
- ):
- self.resblocks.append(resblock(h, ch, k, d))
-
- self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
- self.ups.apply(init_weights)
- self.conv_post.apply(init_weights)
-
- def forward(self, x):
- x = self.conv_pre(x)
- for i in range(self.num_upsamples):
- x = F.leaky_relu(x, LRELU_SLOPE)
- x = self.ups[i](x)
- xs = None
- for j in range(self.num_kernels):
- if xs is None:
- xs = self.resblocks[i * self.num_kernels + j](x)
- else:
- xs += self.resblocks[i * self.num_kernels + j](x)
- x = xs / self.num_kernels
- x = F.leaky_relu(x)
- x = self.conv_post(x)
- x = torch.tanh(x)
-
- return x
-
- def remove_weight_norm(self):
- # print("Removing weight norm...")
- for l in self.ups:
- remove_weight_norm(l)
- for l in self.resblocks:
- l.remove_weight_norm()
- remove_weight_norm(self.conv_pre)
- remove_weight_norm(self.conv_post)
diff --git a/audioldm/hifigan/utilities.py b/audioldm/hifigan/utilities.py
deleted file mode 100644
index 47fd39ea0af181772d640feec2413cf631a75702..0000000000000000000000000000000000000000
--- a/audioldm/hifigan/utilities.py
+++ /dev/null
@@ -1,85 +0,0 @@
-import os
-import json
-
-import torch
-import numpy as np
-
-import audioldm.hifigan as hifigan
-
-HIFIGAN_16K_64 = {
- "resblock": "1",
- "num_gpus": 6,
- "batch_size": 16,
- "learning_rate": 0.0002,
- "adam_b1": 0.8,
- "adam_b2": 0.99,
- "lr_decay": 0.999,
- "seed": 1234,
- "upsample_rates": [5, 4, 2, 2, 2],
- "upsample_kernel_sizes": [16, 16, 8, 4, 4],
- "upsample_initial_channel": 1024,
- "resblock_kernel_sizes": [3, 7, 11],
- "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
- "segment_size": 8192,
- "num_mels": 64,
- "num_freq": 1025,
- "n_fft": 1024,
- "hop_size": 160,
- "win_size": 1024,
- "sampling_rate": 16000,
- "fmin": 0,
- "fmax": 8000,
- "fmax_for_loss": None,
- "num_workers": 4,
- "dist_config": {
- "dist_backend": "nccl",
- "dist_url": "tcp://localhost:54321",
- "world_size": 1,
- },
-}
-
-
-def get_available_checkpoint_keys(model, ckpt):
- print("==> Attemp to reload from %s" % ckpt)
- state_dict = torch.load(ckpt)["state_dict"]
- current_state_dict = model.state_dict()
- new_state_dict = {}
- for k in state_dict.keys():
- if (
- k in current_state_dict.keys()
- and current_state_dict[k].size() == state_dict[k].size()
- ):
- new_state_dict[k] = state_dict[k]
- else:
- print("==> WARNING: Skipping %s" % k)
- print(
- "%s out of %s keys are matched"
- % (len(new_state_dict.keys()), len(state_dict.keys()))
- )
- return new_state_dict
-
-
-def get_param_num(model):
- num_param = sum(param.numel() for param in model.parameters())
- return num_param
-
-
-def get_vocoder(config, device):
- config = hifigan.AttrDict(HIFIGAN_16K_64)
- vocoder = hifigan.Generator(config)
- vocoder.eval()
- vocoder.remove_weight_norm()
- vocoder.to(device)
- return vocoder
-
-
-def vocoder_infer(mels, vocoder, lengths=None):
- with torch.no_grad():
- wavs = vocoder(mels).squeeze(1)
-
- wavs = (wavs.cpu().numpy() * 32768).astype("int16")
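- # scale the [-1, 1] float waveform to 16-bit PCM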
-
- if lengths is not None:
- wavs = wavs[:, :lengths]
-
- return wavs
diff --git a/audioldm/latent_diffusion/__init__.py b/audioldm/latent_diffusion/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/audioldm/latent_diffusion/attention.py b/audioldm/latent_diffusion/attention.py
deleted file mode 100644
index 583dd169e7ec9502ee29faeb12689a46494838c0..0000000000000000000000000000000000000000
--- a/audioldm/latent_diffusion/attention.py
+++ /dev/null
@@ -1,468 +0,0 @@
-from inspect import isfunction
-import math
-import torch
-import torch.nn.functional as F
-from torch import nn
-from einops import rearrange
-
-from audioldm.latent_diffusion.util import checkpoint
-
-
-def exists(val):
- return val is not None
-
-
-def uniq(arr):
- return {el: True for el in arr}.keys()
-
-
-def default(val, d):
- if exists(val):
- return val
- return d() if isfunction(d) else d
-
-
-def max_neg_value(t):
- return -torch.finfo(t.dtype).max
-
-
-def init_(tensor):
- dim = tensor.shape[-1]
- std = 1 / math.sqrt(dim)
- tensor.uniform_(-std, std)
- return tensor
-
-
-# feedforward
-class GEGLU(nn.Module):
- def __init__(self, dim_in, dim_out):
- super().__init__()
- self.proj = nn.Linear(dim_in, dim_out * 2)
-
- def forward(self, x):
- x, gate = self.proj(x).chunk(2, dim=-1)
- return x * F.gelu(gate)
-
-
-class FeedForward(nn.Module):
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
- super().__init__()
- inner_dim = int(dim * mult)
- dim_out = default(dim_out, dim)
- project_in = (
- nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
- if not glu
- else GEGLU(dim, inner_dim)
- )
-
- self.net = nn.Sequential(
- project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
- )
-
- def forward(self, x):
- return self.net(x)
-
-
-def zero_module(module):
- """
- Zero out the parameters of a module and return it.
- """
- for p in module.parameters():
- p.detach().zero_()
- return module
-
-
-def Normalize(in_channels):
- return torch.nn.GroupNorm(
- num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
- )
-
-
-class LinearAttention(nn.Module):
- def __init__(self, dim, heads=4, dim_head=32):
- super().__init__()
- self.heads = heads
- hidden_dim = dim_head * heads
- self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
- self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-
- def forward(self, x):
- b, c, h, w = x.shape
- qkv = self.to_qkv(x)
- q, k, v = rearrange(
- qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
- )
- k = k.softmax(dim=-1)
- context = torch.einsum("bhdn,bhen->bhde", k, v)
- out = torch.einsum("bhde,bhdn->bhen", context, q)
- out = rearrange(
- out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
- )
- return self.to_out(out)
-
-
-class SpatialSelfAttention(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = Normalize(in_channels)
- self.q = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.k = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.v = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.proj_out = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q.shape
- q = rearrange(q, "b c h w -> b (h w) c")
- k = rearrange(k, "b c h w -> b c (h w)")
- w_ = torch.einsum("bij,bjk->bik", q, k)
-
- w_ = w_ * (int(c) ** (-0.5))
- w_ = torch.nn.functional.softmax(w_, dim=2)
-
- # attend to values
- v = rearrange(v, "b c h w -> b c (h w)")
- w_ = rearrange(w_, "b i j -> b j i")
- h_ = torch.einsum("bij,bjk->bik", v, w_)
- h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
- h_ = self.proj_out(h_)
-
- return x + h_
-
-
-class CrossAttention(nn.Module):
- """
- ### Cross Attention Layer
- This falls back to self-attention when conditional embeddings are not specified.
- """
-
- # use_flash_attention: bool = True
- use_flash_attention: bool = False
- def __init__(
- self,
- query_dim,
- context_dim=None,
- heads=8,
- dim_head=64,
- dropout=0.0,
- is_inplace: bool = True,
- ):
- # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
- """
- :param d_model: is the input embedding size
- :param n_heads: is the number of attention heads
- :param d_head: is the size of an attention head
- :param d_cond: is the size of the conditional embeddings
- :param is_inplace: specifies whether to perform the attention softmax computation inplace to
- save memory
- """
- super().__init__()
-
- self.is_inplace = is_inplace
- self.n_heads = heads
- self.d_head = dim_head
-
- # Attention scaling factor
- self.scale = dim_head**-0.5
-
- # The normal self-attention layer
- if context_dim is None:
- context_dim = query_dim
-
- # Query, key and value mappings
- d_attn = dim_head * heads
- self.to_q = nn.Linear(query_dim, d_attn, bias=False)
- self.to_k = nn.Linear(context_dim, d_attn, bias=False)
- self.to_v = nn.Linear(context_dim, d_attn, bias=False)
-
- # Final linear layer
- self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout))
-
- # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
- # Flash attention is only used if it's installed
- # and `CrossAttention.use_flash_attention` is set to `True`.
- try:
- # You can install flash attention by cloning their Github repo,
- # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
- # and then running `python setup.py install`
- from flash_attn.flash_attention import FlashAttention
-
- self.flash = FlashAttention()
- # Set the scale for scaled dot-product attention.
- self.flash.softmax_scale = self.scale
- # Set to `None` if it's not installed
- except ImportError:
- self.flash = None
-
- def forward(self, x, context=None, mask=None):
- """
- :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
- :param context: are the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
- """
-
- # If `context` is `None` we perform self-attention
- has_cond = context is not None
- if not has_cond:
- context = x
-
- # Get query, key and value vectors
- q = self.to_q(x)
- k = self.to_k(context)
- v = self.to_v(context)
-
- # Use flash attention if it's available and the head size is less than or equal to `128`
- if (
- CrossAttention.use_flash_attention
- and self.flash is not None
- and not has_cond
- and self.d_head <= 128
- ):
- return self.flash_attention(q, k, v)
- # Otherwise, fallback to normal attention
- else:
- return self.normal_attention(q, k, v)
-
- def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
- """
- #### Flash Attention
- :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
- :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
- :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
- """
-
- # Get batch size and number of elements along sequence axis (`width * height`)
- batch_size, seq_len, _ = q.shape
-
- # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
- # shape `[batch_size, seq_len, 3, n_heads * d_head]`
- qkv = torch.stack((q, k, v), dim=2)
- # Split the heads
- qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
-
- # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
- # fit this size.
- if self.d_head <= 32:
- pad = 32 - self.d_head
- elif self.d_head <= 64:
- pad = 64 - self.d_head
- elif self.d_head <= 128:
- pad = 128 - self.d_head
- else:
- raise ValueError(f"Head size ${self.d_head} too large for Flash Attention")
-
- # Pad the heads
- if pad:
- qkv = torch.cat(
- (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
- )
-
- # Compute attention
- # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
- # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
- # TODO: the input is cast to float16 here for flash attention
- out, _ = self.flash(qkv.type(torch.float16))
- # Truncate the extra head size
- out = out[:, :, :, : self.d_head].float()
- # Reshape to `[batch_size, seq_len, n_heads * d_head]`
- out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
-
- # Map to `[batch_size, height * width, d_model]` with a linear layer
- return self.to_out(out)
-
- def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
- """
- #### Normal Attention
-
- :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
- :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
- :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
- """
-
- # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
- q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32]
- k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32]
- v = v.view(*v.shape[:2], self.n_heads, -1)
-
- # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
- attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
-
- # Compute softmax
- # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
- if self.is_inplace:
- half = attn.shape[0] // 2
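- # apply softmax to each half of the batch in place to reduce peak memory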
- attn[half:] = attn[half:].softmax(dim=-1)
- attn[:half] = attn[:half].softmax(dim=-1)
- else:
- attn = attn.softmax(dim=-1)
-
- # Compute attention output
- # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
- # attn: [bs, 20, 64, 1]
- # v: [bs, 1, 20, 32]
- out = torch.einsum("bhij,bjhd->bihd", attn, v)
- # Reshape to `[batch_size, height * width, n_heads * d_head]`
- out = out.reshape(*out.shape[:2], -1)
- # Map to `[batch_size, height * width, d_model]` with a linear layer
- return self.to_out(out)
-
-
-# class CrossAttention(nn.Module):
-# def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
-# super().__init__()
-# inner_dim = dim_head * heads
-# context_dim = default(context_dim, query_dim)
-
-# self.scale = dim_head ** -0.5
-# self.heads = heads
-
-# self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-# self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-# self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
-
-# self.to_out = nn.Sequential(
-# nn.Linear(inner_dim, query_dim),
-# nn.Dropout(dropout)
-# )
-
-# def forward(self, x, context=None, mask=None):
-# h = self.heads
-
-# q = self.to_q(x)
-# context = default(context, x)
-# k = self.to_k(context)
-# v = self.to_v(context)
-
-# q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
-# sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-
-# if exists(mask):
-# mask = rearrange(mask, 'b ... -> b (...)')
-# max_neg_value = -torch.finfo(sim.dtype).max
-# mask = repeat(mask, 'b j -> (b h) () j', h=h)
-# sim.masked_fill_(~mask, max_neg_value)
-
-# # attention, what we cannot get enough of
-# attn = sim.softmax(dim=-1)
-
-# out = einsum('b i j, b j d -> b i d', attn, v)
-# out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-# return self.to_out(out)
-
-
-class BasicTransformerBlock(nn.Module):
- def __init__(
- self,
- dim,
- n_heads,
- d_head,
- dropout=0.0,
- context_dim=None,
- gated_ff=True,
- checkpoint=True,
- ):
- super().__init__()
- self.attn1 = CrossAttention(
- query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
- ) # is a self-attention
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
- self.attn2 = CrossAttention(
- query_dim=dim,
- context_dim=context_dim,
- heads=n_heads,
- dim_head=d_head,
- dropout=dropout,
- ) # is self-attn if context is none
- self.norm1 = nn.LayerNorm(dim)
- self.norm2 = nn.LayerNorm(dim)
- self.norm3 = nn.LayerNorm(dim)
- self.checkpoint = checkpoint
-
- def forward(self, x, context=None):
- if context is None:
- return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
- else:
- return checkpoint(
- self._forward, (x, context), self.parameters(), self.checkpoint
- )
-
- def _forward(self, x, context=None):
- x = self.attn1(self.norm1(x)) + x
- x = self.attn2(self.norm2(x), context=context) + x
- x = self.ff(self.norm3(x)) + x
- return x
-
-
-class SpatialTransformer(nn.Module):
- """
- Transformer block for image-like data.
- First, project the input (aka embedding)
- and reshape to b, t, d.
- Then apply standard transformer blocks.
- Finally, reshape back to an image.
- """
-
- def __init__(
- self,
- in_channels,
- n_heads,
- d_head,
- depth=1,
- dropout=0.0,
- context_dim=None,
- no_context=False,
- ):
- super().__init__()
-
- if no_context:
- context_dim = None
-
- self.in_channels = in_channels
- inner_dim = n_heads * d_head
- self.norm = Normalize(in_channels)
-
- self.proj_in = nn.Conv2d(
- in_channels, inner_dim, kernel_size=1, stride=1, padding=0
- )
-
- self.transformer_blocks = nn.ModuleList(
- [
- BasicTransformerBlock(
- inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
- )
- for d in range(depth)
- ]
- )
-
- self.proj_out = zero_module(
- nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
- )
-
- def forward(self, x, context=None):
- # note: if no context is given, cross-attention defaults to self-attention
- b, c, h, w = x.shape
- x_in = x
- x = self.norm(x)
- x = self.proj_in(x)
- x = rearrange(x, "b c h w -> b (h w) c")
- for block in self.transformer_blocks:
- x = block(x, context=context)
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
- x = self.proj_out(x)
- return x + x_in
diff --git a/audioldm/latent_diffusion/ddim.py b/audioldm/latent_diffusion/ddim.py
deleted file mode 100644
index 57ee8d302c77cb09bd73ef803ef9e715098feafc..0000000000000000000000000000000000000000
--- a/audioldm/latent_diffusion/ddim.py
+++ /dev/null
@@ -1,377 +0,0 @@
-"""SAMPLING ONLY."""
-
-import torch
-import numpy as np
-from tqdm import tqdm
-
-from audioldm.latent_diffusion.util import (
- make_ddim_sampling_parameters,
- make_ddim_timesteps,
- noise_like,
- extract_into_tensor,
-)
-import gradio as gr
-
-class DDIMSampler(object):
- def __init__(self, model, schedule="linear", **kwargs):
- super().__init__()
- self.model = model
- self.ddpm_num_timesteps = model.num_timesteps
- self.schedule = schedule
-
- def register_buffer(self, name, attr):
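- # note: tensors are moved to CUDA unconditionally, so this sampler assumes a GPU is available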
- if type(attr) == torch.Tensor:
- if attr.device != torch.device("cuda"):
- attr = attr.to(torch.device("cuda"))
- setattr(self, name, attr)
-
- def make_schedule(
- self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
- ):
- self.ddim_timesteps = make_ddim_timesteps(
- ddim_discr_method=ddim_discretize,
- num_ddim_timesteps=ddim_num_steps,
- num_ddpm_timesteps=self.ddpm_num_timesteps,
- verbose=verbose,
- )
- alphas_cumprod = self.model.alphas_cumprod
- assert (
- alphas_cumprod.shape[0] == self.ddpm_num_timesteps
- ), "alphas have to be defined for each timestep"
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
-
- self.register_buffer("betas", to_torch(self.model.betas))
- self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
- self.register_buffer(
- "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
- )
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.register_buffer(
- "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
- )
- self.register_buffer(
- "sqrt_one_minus_alphas_cumprod",
- to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
- )
- self.register_buffer(
- "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
- )
- self.register_buffer(
- "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
- )
- self.register_buffer(
- "sqrt_recipm1_alphas_cumprod",
- to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
- )
-
- # ddim sampling parameters
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
- alphacums=alphas_cumprod.cpu(),
- ddim_timesteps=self.ddim_timesteps,
- eta=ddim_eta,
- verbose=verbose,
- )
- self.register_buffer("ddim_sigmas", ddim_sigmas)
- self.register_buffer("ddim_alphas", ddim_alphas)
- self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
- self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
- sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
- (1 - self.alphas_cumprod_prev)
- / (1 - self.alphas_cumprod)
- * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
- )
- self.register_buffer(
- "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
- )
-
- @torch.no_grad()
- def sample(
- self,
- S,
- batch_size,
- shape,
- conditioning=None,
- callback=None,
- normals_sequence=None,
- img_callback=None,
- quantize_x0=False,
- eta=0.0,
- mask=None,
- x0=None,
- temperature=1.0,
- noise_dropout=0.0,
- score_corrector=None,
- corrector_kwargs=None,
- verbose=True,
- x_T=None,
- log_every_t=100,
- unconditional_guidance_scale=1.0,
- unconditional_conditioning=None,
- # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
- **kwargs,
- ):
- if conditioning is not None:
- if isinstance(conditioning, dict):
- cbs = conditioning[list(conditioning.keys())[0]].shape[0]
- if cbs != batch_size:
- print(
- f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
- )
- else:
- if conditioning.shape[0] != batch_size:
- print(
- f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
- )
-
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
- # sampling
- C, H, W = shape
- size = (batch_size, C, H, W)
- samples, intermediates = self.ddim_sampling(
- conditioning,
- size,
- callback=callback,
- img_callback=img_callback,
- quantize_denoised=quantize_x0,
- mask=mask,
- x0=x0,
- ddim_use_original_steps=False,
- noise_dropout=noise_dropout,
- temperature=temperature,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- x_T=x_T,
- log_every_t=log_every_t,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- )
- return samples, intermediates
-
- @torch.no_grad()
- def ddim_sampling(
- self,
- cond,
- shape,
- x_T=None,
- ddim_use_original_steps=False,
- callback=None,
- timesteps=None,
- quantize_denoised=False,
- mask=None,
- x0=None,
- img_callback=None,
- log_every_t=100,
- temperature=1.0,
- noise_dropout=0.0,
- score_corrector=None,
- corrector_kwargs=None,
- unconditional_guidance_scale=1.0,
- unconditional_conditioning=None,
- ):
- device = self.model.betas.device
- b = shape[0]
- if x_T is None:
- img = torch.randn(shape, device=device)
- else:
- img = x_T
-
- if timesteps is None:
- timesteps = (
- self.ddpm_num_timesteps
- if ddim_use_original_steps
- else self.ddim_timesteps
- )
- elif timesteps is not None and not ddim_use_original_steps:
- subset_end = (
- int(
- min(timesteps / self.ddim_timesteps.shape[0], 1)
- * self.ddim_timesteps.shape[0]
- )
- - 1
- )
- timesteps = self.ddim_timesteps[:subset_end]
-
- intermediates = {"x_inter": [img], "pred_x0": [img]}
- time_range = (
- reversed(range(0, timesteps))
- if ddim_use_original_steps
- else np.flip(timesteps)
- )
- total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
- # print(f"Running DDIM Sampling with {total_steps} timesteps")
-
- # iterator = gr.Progress().tqdm(time_range, desc="DDIM Sampler", total=total_steps)
- iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
-
- for i, step in enumerate(iterator):
- index = total_steps - i - 1
- ts = torch.full((b,), step, device=device, dtype=torch.long)
- if mask is not None:
- assert x0 is not None
- img_orig = self.model.q_sample(
- x0, ts
- ) # TODO deterministic forward pass?
- img = (
- img_orig * mask + (1.0 - mask) * img
- ) # In the first sampling step, img is pure gaussian noise
-
- outs = self.p_sample_ddim(
- img,
- cond,
- ts,
- index=index,
- use_original_steps=ddim_use_original_steps,
- quantize_denoised=quantize_denoised,
- temperature=temperature,
- noise_dropout=noise_dropout,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- )
- img, pred_x0 = outs
- if callback:
- callback(i)
- if img_callback:
- img_callback(pred_x0, i)
-
- if index % log_every_t == 0 or index == total_steps - 1:
- intermediates["x_inter"].append(img)
- intermediates["pred_x0"].append(pred_x0)
-
- return img, intermediates
-
- @torch.no_grad()
- def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
- # fast, but does not allow for exact reconstruction
- # t serves as an index to gather the correct alphas
- if use_original_steps:
- sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
- sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
- else:
- sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
- sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
-
- if noise is None:
- noise = torch.randn_like(x0)
-
- return (
- extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
- + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
- )
-
- @torch.no_grad()
- def decode(
- self,
- x_latent,
- cond,
- t_start,
- unconditional_guidance_scale=1.0,
- unconditional_conditioning=None,
- use_original_steps=False,
- ):
-
- timesteps = (
- np.arange(self.ddpm_num_timesteps)
- if use_original_steps
- else self.ddim_timesteps
- )
- timesteps = timesteps[:t_start]
-
- time_range = np.flip(timesteps)
- total_steps = timesteps.shape[0]
- # print(f"Running DDIM Sampling with {total_steps} timesteps")
-
- # iterator = gr.Progress().tqdm(time_range, desc="Decoding image", total=total_steps)
- iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
- x_dec = x_latent
-
- for i, step in enumerate(iterator):
- index = total_steps - i - 1
- ts = torch.full(
- (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
- )
- x_dec, _ = self.p_sample_ddim(
- x_dec,
- cond,
- ts,
- index=index,
- use_original_steps=use_original_steps,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- )
- return x_dec
-
- @torch.no_grad()
- def p_sample_ddim(
- self,
- x,
- c,
- t,
- index,
- repeat_noise=False,
- use_original_steps=False,
- quantize_denoised=False,
- temperature=1.0,
- noise_dropout=0.0,
- score_corrector=None,
- corrector_kwargs=None,
- unconditional_guidance_scale=1.0,
- unconditional_conditioning=None,
- ):
- b, *_, device = *x.shape, x.device
-
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
- e_t = self.model.apply_model(x, t, c)
- else:
- x_in = torch.cat([x] * 2)
- t_in = torch.cat([t] * 2)
- c_in = torch.cat([unconditional_conditioning, c])
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
- # When unconditional_guidance_scale == 1: only the conditional e_t
- # When unconditional_guidance_scale == 0: only the unconditional prediction
- # When unconditional_guidance_scale > 1: extrapolate past e_t, away from the unconditional prediction (stronger text guidance)
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-
- if score_corrector is not None:
- assert self.model.parameterization == "eps"
- e_t = score_corrector.modify_score(
- self.model, e_t, x, t, c, **corrector_kwargs
- )
-
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
- alphas_prev = (
- self.model.alphas_cumprod_prev
- if use_original_steps
- else self.ddim_alphas_prev
- )
- sqrt_one_minus_alphas = (
- self.model.sqrt_one_minus_alphas_cumprod
- if use_original_steps
- else self.ddim_sqrt_one_minus_alphas
- )
- sigmas = (
- self.model.ddim_sigmas_for_original_num_steps
- if use_original_steps
- else self.ddim_sigmas
- )
- # select parameters corresponding to the currently considered timestep
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full(
- (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
- )
-
- # current prediction for x_0
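- # DDIM update: x_{t-1} = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t^2) * e_t + sigma_t * noise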
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
- if quantize_denoised:
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
- # direction pointing to x_t
- dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.0:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise # TODO
- return x_prev, pred_x0
diff --git a/audioldm/latent_diffusion/ddpm.py b/audioldm/latent_diffusion/ddpm.py
deleted file mode 100644
index ffca031c27d413698adee5a58547b7d0ea4069c3..0000000000000000000000000000000000000000
--- a/audioldm/latent_diffusion/ddpm.py
+++ /dev/null
@@ -1,441 +0,0 @@
-"""
-wild mixture of
-https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
-https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
-https://github.com/CompVis/taming-transformers
--- merci
-"""
-import sys
-import os
-
-import torch
-import torch.nn as nn
-import numpy as np
-from contextlib import contextmanager
-from functools import partial
-from tqdm import tqdm
-
-from audioldm.utils import exists, default, count_params, instantiate_from_config
-from audioldm.latent_diffusion.ema import LitEma
-from audioldm.latent_diffusion.util import (
- make_beta_schedule,
- extract_into_tensor,
- noise_like,
-)
-import soundfile as sf
-import os
-
-
-__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
-
-
-def disabled_train(self, mode=True):
- """Overwrite model.train with this function to make sure train/eval mode
- does not change anymore."""
- return self
-
-
-def uniform_on_device(r1, r2, shape, device):
- return (r1 - r2) * torch.rand(*shape, device=device) + r2
-
-
-class DiffusionWrapper(nn.Module):
- def __init__(self, diff_model_config, conditioning_key):
- super().__init__()
- self.diffusion_model = instantiate_from_config(diff_model_config)
- self.conditioning_key = conditioning_key
- assert self.conditioning_key in [
- None,
- "concat",
- "crossattn",
- "hybrid",
- "adm",
- "film",
- ]
-
- def forward(
- self, x, t, c_concat: list = None, c_crossattn: list = None, c_film: list = None
- ):
- x = x.contiguous()
- t = t.contiguous()
-
- if self.conditioning_key is None:
- out = self.diffusion_model(x, t)
- elif self.conditioning_key == "concat":
- xc = torch.cat([x] + c_concat, dim=1)
- out = self.diffusion_model(xc, t)
- elif self.conditioning_key == "crossattn":
- cc = torch.cat(c_crossattn, 1)
- out = self.diffusion_model(x, t, context=cc)
- elif self.conditioning_key == "hybrid":
- xc = torch.cat([x] + c_concat, dim=1)
- cc = torch.cat(c_crossattn, 1)
- out = self.diffusion_model(xc, t, context=cc)
- elif (
- self.conditioning_key == "film"
- ): # The condition is assumed to be a global token, which will pass through a linear layer and be added to the time embedding for FiLM conditioning
- cc = c_film[0].squeeze(1) # only has one token
- out = self.diffusion_model(x, t, y=cc)
- elif self.conditioning_key == "adm":
- cc = c_crossattn[0]
- out = self.diffusion_model(x, t, y=cc)
- else:
- raise NotImplementedError()
-
- return out
-
-
-class DDPM(nn.Module):
- # classic DDPM with Gaussian diffusion, in image space
- def __init__(
- self,
- unet_config,
- timesteps=1000,
- beta_schedule="linear",
- loss_type="l2",
- ckpt_path=None,
- ignore_keys=[],
- load_only_unet=False,
- monitor="val/loss",
- use_ema=True,
- first_stage_key="image",
- latent_t_size=256,
- latent_f_size=16,
- channels=3,
- log_every_t=100,
- clip_denoised=True,
- linear_start=1e-4,
- linear_end=2e-2,
- cosine_s=8e-3,
- given_betas=None,
- original_elbo_weight=0.0,
- v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
- l_simple_weight=1.0,
- conditioning_key=None,
- parameterization="eps", # all assuming fixed variance schedules
- scheduler_config=None,
- use_positional_encodings=False,
- learn_logvar=False,
- logvar_init=0.0,
- ):
- super().__init__()
- assert parameterization in [
- "eps",
- "x0",
- ], 'currently only supporting "eps" and "x0"'
- self.parameterization = parameterization
- self.state = None
- # print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
- self.cond_stage_model = None
- self.clip_denoised = clip_denoised
- self.log_every_t = log_every_t
- self.first_stage_key = first_stage_key
-
- self.latent_t_size = latent_t_size
- self.latent_f_size = latent_f_size
-
- self.channels = channels
- self.use_positional_encodings = use_positional_encodings
- self.model = DiffusionWrapper(unet_config, conditioning_key)
- count_params(self.model, verbose=True)
- self.use_ema = use_ema
- if self.use_ema:
- self.model_ema = LitEma(self.model)
- # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
-
- self.use_scheduler = scheduler_config is not None
- if self.use_scheduler:
- self.scheduler_config = scheduler_config
-
- self.v_posterior = v_posterior
- self.original_elbo_weight = original_elbo_weight
- self.l_simple_weight = l_simple_weight
-
- if monitor is not None:
- self.monitor = monitor
-
- self.register_schedule(
- given_betas=given_betas,
- beta_schedule=beta_schedule,
- timesteps=timesteps,
- linear_start=linear_start,
- linear_end=linear_end,
- cosine_s=cosine_s,
- )
-
- self.loss_type = loss_type
-
- self.learn_logvar = learn_logvar
- self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
- if self.learn_logvar:
- self.logvar = nn.Parameter(self.logvar, requires_grad=True)
- else:
- self.logvar = nn.Parameter(self.logvar, requires_grad=False)
-
- self.logger_save_dir = None
- self.logger_project = None
- self.logger_version = None
- self.label_indices_total = None
- # Default metric values so checkpoint selection does not fail when a metric is missing
- self.metrics_buffer = {
- "val/kullback_leibler_divergence_sigmoid": 15.0,
- "val/kullback_leibler_divergence_softmax": 10.0,
- "val/psnr": 0.0,
- "val/ssim": 0.0,
- "val/inception_score_mean": 1.0,
- "val/inception_score_std": 0.0,
- "val/kernel_inception_distance_mean": 0.0,
- "val/kernel_inception_distance_std": 0.0,
- "val/frechet_inception_distance": 133.0,
- "val/frechet_audio_distance": 32.0,
- }
- self.initial_learning_rate = None
-
- def get_log_dir(self):
- if (
- self.logger_save_dir is None
- and self.logger_project is None
- and self.logger_version is None
- ):
- return os.path.join(
- self.logger.save_dir, self.logger._project, self.logger.version
- )
- else:
- return os.path.join(
- self.logger_save_dir, self.logger_project, self.logger_version
- )
-
- def set_log_dir(self, save_dir, project, version):
- self.logger_save_dir = save_dir
- self.logger_project = project
- self.logger_version = version
-
- def register_schedule(
- self,
- given_betas=None,
- beta_schedule="linear",
- timesteps=1000,
- linear_start=1e-4,
- linear_end=2e-2,
- cosine_s=8e-3,
- ):
- if exists(given_betas):
- betas = given_betas
- else:
- betas = make_beta_schedule(
- beta_schedule,
- timesteps,
- linear_start=linear_start,
- linear_end=linear_end,
- cosine_s=cosine_s,
- )
- alphas = 1.0 - betas
- alphas_cumprod = np.cumprod(alphas, axis=0)
- alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
-
- (timesteps,) = betas.shape
- self.num_timesteps = int(timesteps)
- self.linear_start = linear_start
- self.linear_end = linear_end
- assert (
- alphas_cumprod.shape[0] == self.num_timesteps
- ), "alphas have to be defined for each timestep"
-
- to_torch = partial(torch.tensor, dtype=torch.float32)
-
- self.register_buffer("betas", to_torch(betas))
- self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
- self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
- self.register_buffer(
- "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
- )
- self.register_buffer(
- "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
- )
- self.register_buffer(
- "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
- )
- self.register_buffer(
- "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
- )
-
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = (1 - self.v_posterior) * betas * (
- 1.0 - alphas_cumprod_prev
- ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.register_buffer("posterior_variance", to_torch(posterior_variance))
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.register_buffer(
- "posterior_log_variance_clipped",
- to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
- )
- self.register_buffer(
- "posterior_mean_coef1",
- to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
- )
- self.register_buffer(
- "posterior_mean_coef2",
- to_torch(
- (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
- ),
- )
-
- if self.parameterization == "eps":
- lvlb_weights = self.betas**2 / (
- 2
- * self.posterior_variance
- * to_torch(alphas)
- * (1 - self.alphas_cumprod)
- )
- elif self.parameterization == "x0":
- lvlb_weights = (
- 0.5
- * np.sqrt(torch.Tensor(alphas_cumprod))
- / (2.0 * 1 - torch.Tensor(alphas_cumprod))
- )
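- # note: due to Python operator precedence the denominator evaluates to (2.0 - alphas_cumprod)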
- else:
- raise NotImplementedError("mu not supported")
- # TODO how to choose this term
- lvlb_weights[0] = lvlb_weights[1]
- self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
- assert not torch.isnan(self.lvlb_weights).any()
-
- @contextmanager
- def ema_scope(self, context=None):
- if self.use_ema:
- self.model_ema.store(self.model.parameters())
- self.model_ema.copy_to(self.model)
- if context is not None:
- # print(f"{context}: Switched to EMA weights")
- pass
- try:
- yield None
- finally:
- if self.use_ema:
- self.model_ema.restore(self.model.parameters())
- if context is not None:
- # print(f"{context}: Restored training weights")
- pass
-
- def q_mean_variance(self, x_start, t):
- """
- Get the distribution q(x_t | x_0).
- :param x_start: the [N x C x ...] tensor of noiseless inputs.
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
- :return: A tuple (mean, variance, log_variance), all of x_start's shape.
- """
- mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
- variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
- log_variance = extract_into_tensor(
- self.log_one_minus_alphas_cumprod, t, x_start.shape
- )
- return mean, variance, log_variance
-
- def predict_start_from_noise(self, x_t, t, noise):
- return (
- extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
- * noise
- )
-
- def q_posterior(self, x_start, x_t, t):
- posterior_mean = (
- extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
- + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
- )
- posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
- posterior_log_variance_clipped = extract_into_tensor(
- self.posterior_log_variance_clipped, t, x_t.shape
- )
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
- def p_mean_variance(self, x, t, clip_denoised: bool):
- model_out = self.model(x, t)
- if self.parameterization == "eps":
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
- elif self.parameterization == "x0":
- x_recon = model_out
- if clip_denoised:
- x_recon.clamp_(-1.0, 1.0)
-
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
- x_start=x_recon, x_t=x, t=t
- )
- return model_mean, posterior_variance, posterior_log_variance
-
- @torch.no_grad()
- def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
- b, *_, device = *x.shape, x.device
- model_mean, _, model_log_variance = self.p_mean_variance(
- x=x, t=t, clip_denoised=clip_denoised
- )
- noise = noise_like(x.shape, device, repeat_noise)
- # no noise when t == 0
- nonzero_mask = (
- (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous()
- )
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
-
- @torch.no_grad()
- def p_sample_loop(self, shape, return_intermediates=False):
- device = self.betas.device
- b = shape[0]
- img = torch.randn(shape, device=device)
- intermediates = [img]
- for i in tqdm(
- reversed(range(0, self.num_timesteps)),
- desc="Sampling t",
- total=self.num_timesteps,
- ):
- img = self.p_sample(
- img,
- torch.full((b,), i, device=device, dtype=torch.long),
- clip_denoised=self.clip_denoised,
- )
- if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
- intermediates.append(img)
- if return_intermediates:
- return img, intermediates
- return img
-
- @torch.no_grad()
- def sample(self, batch_size=16, return_intermediates=False):
- channels = self.channels
- shape = (batch_size, channels, self.latent_t_size, self.latent_f_size)
- return self.p_sample_loop(shape, return_intermediates=return_intermediates)
-
- def q_sample(self, x_start, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x_start))
- return (
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
- + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
- * noise
- )
-
- def forward(self, x, *args, **kwargs):
- t = torch.randint(
- 0, self.num_timesteps, (x.shape[0],), device=self.device
- ).long()
- return self.p_losses(x, t, *args, **kwargs)
-
- def get_input(self, batch, k):
- # fbank, log_magnitudes_stft, label_indices, fname, waveform, clip_label, text = batch
- fbank, log_magnitudes_stft, label_indices, fname, waveform, text = batch
- ret = {}
-
- ret["fbank"] = (
- fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float()
- )
- ret["stft"] = log_magnitudes_stft.to(
- memory_format=torch.contiguous_format
- ).float()
- # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float()
- ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float()
- ret["text"] = list(text)
- ret["fname"] = fname
-
- return ret[k]
diff --git a/audioldm/latent_diffusion/ema.py b/audioldm/latent_diffusion/ema.py
deleted file mode 100644
index 192b012186bab3d8a5380bc9b891da8eef0fd9fa..0000000000000000000000000000000000000000
--- a/audioldm/latent_diffusion/ema.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import torch
-from torch import nn
-
-class LitEma(nn.Module):
- def __init__(self, model, decay=0.9999, use_num_upates=True):
- super().__init__()
- if decay < 0.0 or decay > 1.0:
- raise ValueError("Decay must be between 0 and 1")
-
- self.m_name2s_name = {}
- self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
- self.register_buffer(
- "num_updates",
- torch.tensor(0, dtype=torch.int)
- if use_num_upates
- else torch.tensor(-1, dtype=torch.int),
- )
-
- for name, p in model.named_parameters():
- if p.requires_grad:
- # remove as '.'-character is not allowed in buffers
- s_name = name.replace(".", "")
- self.m_name2s_name.update({name: s_name})
- self.register_buffer(s_name, p.clone().detach().data)
-
- self.collected_params = []
-
- def forward(self, model):
- decay = self.decay
-
- if self.num_updates >= 0:
- self.num_updates += 1
- decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
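- # warm-up: the effective decay ramps from roughly 0.18 after the first update toward the
- # configured decay, so early parameter values do not dominate the moving average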
-
- one_minus_decay = 1.0 - decay
-
- with torch.no_grad():
- m_param = dict(model.named_parameters())
- shadow_params = dict(self.named_buffers())
-
- for key in m_param:
- if m_param[key].requires_grad:
- sname = self.m_name2s_name[key]
- shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
- shadow_params[sname].sub_(
- one_minus_decay * (shadow_params[sname] - m_param[key])
- )
- else:
- assert key not in self.m_name2s_name
-
- def copy_to(self, model):
- m_param = dict(model.named_parameters())
- shadow_params = dict(self.named_buffers())
- for key in m_param:
- if m_param[key].requires_grad:
- m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
- else:
- assert key not in self.m_name2s_name
-
- def store(self, parameters):
- """
- Save the current parameters for restoring later.
- Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- temporarily stored.
- """
- self.collected_params = [param.clone() for param in parameters]
-
- def restore(self, parameters):
- """
- Restore the parameters stored with the `store` method.
- Useful to validate the model with EMA parameters without affecting the
- original optimization process. Store the parameters before the
- `copy_to` method. After validation (or model saving), use this to
- restore the former parameters.
- Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- updated with the stored parameters.
- """
- for c_param, param in zip(self.collected_params, parameters):
- param.data.copy_(c_param.data)
diff --git a/audioldm/latent_diffusion/openaimodel.py b/audioldm/latent_diffusion/openaimodel.py
deleted file mode 100644
index 831d7aafb36bba16888e4389153979a6c13639f5..0000000000000000000000000000000000000000
--- a/audioldm/latent_diffusion/openaimodel.py
+++ /dev/null
@@ -1,1069 +0,0 @@
-from abc import abstractmethod
-import math
-
-import numpy as np
-import torch as th
-import torch.nn as nn
-import torch.nn.functional as F
-
-from audioldm.latent_diffusion.util import (
- checkpoint,
- conv_nd,
- linear,
- avg_pool_nd,
- zero_module,
- normalization,
- timestep_embedding,
-)
-from audioldm.latent_diffusion.attention import SpatialTransformer
-
-
-# dummy replace
-def convert_module_to_f16(x):
- pass
-
-
-def convert_module_to_f32(x):
- pass
-
-
-## go
-class AttentionPool2d(nn.Module):
- """
- Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
- """
-
- def __init__(
- self,
- spacial_dim: int,
- embed_dim: int,
- num_heads_channels: int,
- output_dim: int = None,
- ):
- super().__init__()
- self.positional_embedding = nn.Parameter(
- th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
- )
- self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
- self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
- self.num_heads = embed_dim // num_heads_channels
- self.attention = QKVAttention(self.num_heads)
-
- def forward(self, x):
- b, c, *_spatial = x.shape
- x = x.reshape(b, c, -1).contiguous() # NC(HW)
- x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
- x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
- x = self.qkv_proj(x)
- x = self.attention(x)
- x = self.c_proj(x)
- return x[:, :, 0]
-
-
-class TimestepBlock(nn.Module):
- """
- Any module where forward() takes timestep embeddings as a second argument.
- """
-
- @abstractmethod
- def forward(self, x, emb):
- """
- Apply the module to `x` given `emb` timestep embeddings.
- """
-
-
-class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
- """
- A sequential module that passes timestep embeddings to the children that
- support it as an extra input.
- """
-
- def forward(self, x, emb, context=None):
- for layer in self:
- if isinstance(layer, TimestepBlock):
- x = layer(x, emb)
- elif isinstance(layer, SpatialTransformer):
- x = layer(x, context)
- else:
- x = layer(x)
- return x
-
-
-class Upsample(nn.Module):
- """
- An upsampling layer with an optional convolution.
- :param channels: channels in the inputs and outputs.
- :param use_conv: a bool determining if a convolution is applied.
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
- upsampling occurs in the inner-two dimensions.
- """
-
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
- super().__init__()
- self.channels = channels
- self.out_channels = out_channels or channels
- self.use_conv = use_conv
- self.dims = dims
- if use_conv:
- self.conv = conv_nd(
- dims, self.channels, self.out_channels, 3, padding=padding
- )
-
- def forward(self, x):
- assert x.shape[1] == self.channels
- if self.dims == 3:
- x = F.interpolate(
- x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
- )
- else:
- x = F.interpolate(x, scale_factor=2, mode="nearest")
- if self.use_conv:
- x = self.conv(x)
- return x
-
-
-class TransposedUpsample(nn.Module):
- "Learned 2x upsampling without padding"
-
- def __init__(self, channels, out_channels=None, ks=5):
- super().__init__()
- self.channels = channels
- self.out_channels = out_channels or channels
-
- self.up = nn.ConvTranspose2d(
- self.channels, self.out_channels, kernel_size=ks, stride=2
- )
-
- def forward(self, x):
- return self.up(x)
-
-
-class Downsample(nn.Module):
- """
- A downsampling layer with an optional convolution.
- :param channels: channels in the inputs and outputs.
- :param use_conv: a bool determining if a convolution is applied.
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
- downsampling occurs in the inner-two dimensions.
- """
-
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
- super().__init__()
- self.channels = channels
- self.out_channels = out_channels or channels
- self.use_conv = use_conv
- self.dims = dims
- stride = 2 if dims != 3 else (1, 2, 2)
- if use_conv:
- self.op = conv_nd(
- dims,
- self.channels,
- self.out_channels,
- 3,
- stride=stride,
- padding=padding,
- )
- else:
- assert self.channels == self.out_channels
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
-
- def forward(self, x):
- assert x.shape[1] == self.channels
- return self.op(x)
-
-
-class ResBlock(TimestepBlock):
- """
- A residual block that can optionally change the number of channels.
- :param channels: the number of input channels.
- :param emb_channels: the number of timestep embedding channels.
- :param dropout: the rate of dropout.
- :param out_channels: if specified, the number of out channels.
- :param use_conv: if True and out_channels is specified, use a spatial
- convolution instead of a smaller 1x1 convolution to change the
- channels in the skip connection.
- :param dims: determines if the signal is 1D, 2D, or 3D.
- :param use_checkpoint: if True, use gradient checkpointing on this module.
- :param up: if True, use this block for upsampling.
- :param down: if True, use this block for downsampling.
- """
-
- def __init__(
- self,
- channels,
- emb_channels,
- dropout,
- out_channels=None,
- use_conv=False,
- use_scale_shift_norm=False,
- dims=2,
- use_checkpoint=False,
- up=False,
- down=False,
- ):
- super().__init__()
- self.channels = channels
- self.emb_channels = emb_channels
- self.dropout = dropout
- self.out_channels = out_channels or channels
- self.use_conv = use_conv
- self.use_checkpoint = use_checkpoint
- self.use_scale_shift_norm = use_scale_shift_norm
-
- self.in_layers = nn.Sequential(
- normalization(channels),
- nn.SiLU(),
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
- )
-
- self.updown = up or down
-
- if up:
- self.h_upd = Upsample(channels, False, dims)
- self.x_upd = Upsample(channels, False, dims)
- elif down:
- self.h_upd = Downsample(channels, False, dims)
- self.x_upd = Downsample(channels, False, dims)
- else:
- self.h_upd = self.x_upd = nn.Identity()
-
- self.emb_layers = nn.Sequential(
- nn.SiLU(),
- linear(
- emb_channels,
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
- ),
- )
- self.out_layers = nn.Sequential(
- normalization(self.out_channels),
- nn.SiLU(),
- nn.Dropout(p=dropout),
- zero_module(
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
- ),
- )
-
- if self.out_channels == channels:
- self.skip_connection = nn.Identity()
- elif use_conv:
- self.skip_connection = conv_nd(
- dims, channels, self.out_channels, 3, padding=1
- )
- else:
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
-
- def forward(self, x, emb):
- """
- Apply the block to a Tensor, conditioned on a timestep embedding.
- :param x: an [N x C x ...] Tensor of features.
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
- :return: an [N x C x ...] Tensor of outputs.
- """
- return checkpoint(
- self._forward, (x, emb), self.parameters(), self.use_checkpoint
- )
-
- def _forward(self, x, emb):
- if self.updown:
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
- h = in_rest(x)
- h = self.h_upd(h)
- x = self.x_upd(x)
- h = in_conv(h)
- else:
- h = self.in_layers(x)
- emb_out = self.emb_layers(emb).type(h.dtype)
- while len(emb_out.shape) < len(h.shape):
- emb_out = emb_out[..., None]
- if self.use_scale_shift_norm:
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
- scale, shift = th.chunk(emb_out, 2, dim=1)
- h = out_norm(h) * (1 + scale) + shift
- h = out_rest(h)
- else:
- h = h + emb_out
- h = self.out_layers(h)
- return self.skip_connection(x) + h
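# Toy illustration of the use_scale_shift_norm branch above: the projected
# timestep embedding is split into (scale, shift) and applied as FiLM after the
# normalisation, h = norm(h) * (1 + scale) + shift, rather than added to h.
# The shapes below are made up for the example.
import torch
import torch.nn.functional as F

h = torch.randn(2, 8, 16, 16)                  # feature map [N, C, H, W]
emb_out = torch.randn(2, 16, 1, 1)             # emb_layers output, 2 * C channels
scale, shift = torch.chunk(emb_out, 2, dim=1)  # each [2, 8, 1, 1]
h_film = F.group_norm(h, num_groups=8) * (1 + scale) + shift
print(h_film.shape)                            # torch.Size([2, 8, 16, 16])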
-
-
-class AttentionBlock(nn.Module):
- """
- An attention block that allows spatial positions to attend to each other.
- Originally ported from here, but adapted to the N-d case.
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
- """
-
- def __init__(
- self,
- channels,
- num_heads=1,
- num_head_channels=-1,
- use_checkpoint=False,
- use_new_attention_order=False,
- ):
- super().__init__()
- self.channels = channels
- if num_head_channels == -1:
- self.num_heads = num_heads
- else:
- assert (
- channels % num_head_channels == 0
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
- self.num_heads = channels // num_head_channels
- self.use_checkpoint = use_checkpoint
- self.norm = normalization(channels)
- self.qkv = conv_nd(1, channels, channels * 3, 1)
- if use_new_attention_order:
- # split qkv before split heads
- self.attention = QKVAttention(self.num_heads)
- else:
- # split heads before split qkv
- self.attention = QKVAttentionLegacy(self.num_heads)
-
- self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
-
- def forward(self, x):
- return checkpoint(
- self._forward, (x,), self.parameters(), True
- ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
- # return pt_checkpoint(self._forward, x) # pytorch
-
- def _forward(self, x):
- b, c, *spatial = x.shape
- x = x.reshape(b, c, -1).contiguous()
- qkv = self.qkv(self.norm(x)).contiguous()
- h = self.attention(qkv).contiguous()
- h = self.proj_out(h).contiguous()
- return (x + h).reshape(b, c, *spatial).contiguous()
-
-
-def count_flops_attn(model, _x, y):
- """
- A counter for the `thop` package to count the operations in an
- attention operation.
- Meant to be used like:
- macs, params = thop.profile(
- model,
- inputs=(inputs, timestamps),
- custom_ops={QKVAttention: QKVAttention.count_flops},
- )
- """
- b, c, *spatial = y[0].shape
- num_spatial = int(np.prod(spatial))
- # We perform two matmuls with the same number of ops.
- # The first computes the weight matrix, the second computes
- # the combination of the value vectors.
- matmul_ops = 2 * b * (num_spatial**2) * c
- model.total_ops += th.DoubleTensor([matmul_ops])
-
-
-class QKVAttentionLegacy(nn.Module):
- """
- A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
- """
-
- def __init__(self, n_heads):
- super().__init__()
- self.n_heads = n_heads
-
- def forward(self, qkv):
- """
- Apply QKV attention.
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
- :return: an [N x (H * C) x T] tensor after attention.
- """
- bs, width, length = qkv.shape
- assert width % (3 * self.n_heads) == 0
- ch = width // (3 * self.n_heads)
- q, k, v = (
- qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1)
- )
- scale = 1 / math.sqrt(math.sqrt(ch))
- weight = th.einsum(
- "bct,bcs->bts", q * scale, k * scale
- ) # More stable with f16 than dividing afterwards
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
- a = th.einsum("bts,bcs->bct", weight, v)
- return a.reshape(bs, -1, length).contiguous()
-
- @staticmethod
- def count_flops(model, _x, y):
- return count_flops_attn(model, _x, y)
-
-
-class QKVAttention(nn.Module):
- """
- A module which performs QKV attention and splits in a different order.
- """
-
- def __init__(self, n_heads):
- super().__init__()
- self.n_heads = n_heads
-
- def forward(self, qkv):
- """
- Apply QKV attention.
- :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
- :return: an [N x (H * C) x T] tensor after attention.
- """
- bs, width, length = qkv.shape
- assert width % (3 * self.n_heads) == 0
- ch = width // (3 * self.n_heads)
- q, k, v = qkv.chunk(3, dim=1)
- scale = 1 / math.sqrt(math.sqrt(ch))
- weight = th.einsum(
- "bct,bcs->bts",
- (q * scale).view(bs * self.n_heads, ch, length),
- (k * scale).view(bs * self.n_heads, ch, length),
- ) # More stable with f16 than dividing afterwards
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
- a = th.einsum(
- "bts,bcs->bct",
- weight,
- v.reshape(bs * self.n_heads, ch, length).contiguous(),
- )
- return a.reshape(bs, -1, length).contiguous()
-
- @staticmethod
- def count_flops(model, _x, y):
- return count_flops_attn(model, _x, y)
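# Numerical check of the fp16-friendly scaling used in both attention classes
# above: multiplying q and k each by ch ** -0.25 is algebraically the same as
# the usual q @ k^T / sqrt(ch), but keeps the intermediate products smaller.
# The shapes are arbitrary toy values.
import math
import torch

ch, length = 64, 16
q, k = torch.randn(1, ch, length), torch.randn(1, ch, length)
scale = 1 / math.sqrt(math.sqrt(ch))
w_split = torch.einsum("bct,bcs->bts", q * scale, k * scale)
w_plain = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)
assert torch.allclose(w_split, w_plain, atol=1e-5)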
-
-
-class UNetModel(nn.Module):
- """
- The full UNet model with attention and timestep embedding.
- :param in_channels: channels in the input Tensor.
- :param model_channels: base channel count for the model.
- :param out_channels: channels in the output Tensor.
- :param num_res_blocks: number of residual blocks per downsample.
- :param attention_resolutions: a collection of downsample rates at which
- attention will take place. May be a set, list, or tuple.
- For example, if this contains 4, then at 4x downsampling, attention
- will be used.
- :param dropout: the dropout probability.
- :param channel_mult: channel multiplier for each level of the UNet.
- :param conv_resample: if True, use learned convolutions for upsampling and
- downsampling.
- :param dims: determines if the signal is 1D, 2D, or 3D.
- :param num_classes: if specified (as an int), then this model will be
- class-conditional with `num_classes` classes.
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
- :param num_heads: the number of attention heads in each attention layer.
- :param num_head_channels: if specified, ignore num_heads and instead use
- a fixed channel width per attention head.
- :param num_heads_upsample: works with num_heads to set a different number
- of heads for upsampling. Deprecated.
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
- :param resblock_updown: use residual blocks for up/downsampling.
- :param use_new_attention_order: use a different attention pattern for potentially
- increased efficiency.
- """
-
- def __init__(
- self,
- image_size,
- in_channels,
- model_channels,
- out_channels,
- num_res_blocks,
- attention_resolutions,
- dropout=0,
- channel_mult=(1, 2, 4, 8),
- conv_resample=True,
- dims=2,
- num_classes=None,
- extra_film_condition_dim=None,
- use_checkpoint=False,
- use_fp16=False,
- num_heads=-1,
- num_head_channels=-1,
- num_heads_upsample=-1,
- use_scale_shift_norm=False,
- extra_film_use_concat=False, # If True, concatenate the extra FiLM condition with the time embedding; otherwise add them
- resblock_updown=False,
- use_new_attention_order=False,
- use_spatial_transformer=False, # custom transformer support
- transformer_depth=1, # custom transformer support
- context_dim=None, # custom transformer support
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
- legacy=True,
- ):
- super().__init__()
- if num_heads_upsample == -1:
- num_heads_upsample = num_heads
-
- if num_heads == -1:
- assert (
- num_head_channels != -1
- ), "Either num_heads or num_head_channels has to be set"
-
- if num_head_channels == -1:
- assert (
- num_heads != -1
- ), "Either num_heads or num_head_channels has to be set"
-
- self.image_size = image_size
- self.in_channels = in_channels
- self.model_channels = model_channels
- self.out_channels = out_channels
- self.num_res_blocks = num_res_blocks
- self.attention_resolutions = attention_resolutions
- self.dropout = dropout
- self.channel_mult = channel_mult
- self.conv_resample = conv_resample
- self.num_classes = num_classes
- self.extra_film_condition_dim = extra_film_condition_dim
- self.use_checkpoint = use_checkpoint
- self.dtype = th.float16 if use_fp16 else th.float32
- self.num_heads = num_heads
- self.num_head_channels = num_head_channels
- self.num_heads_upsample = num_heads_upsample
- self.predict_codebook_ids = n_embed is not None
- self.extra_film_use_concat = extra_film_use_concat
- time_embed_dim = model_channels * 4
- self.time_embed = nn.Sequential(
- linear(model_channels, time_embed_dim),
- nn.SiLU(),
- linear(time_embed_dim, time_embed_dim),
- )
-
- assert not (
- self.num_classes is not None and self.extra_film_condition_dim is not None
- ), "As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim."
-
- if self.num_classes is not None:
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
-
- self.use_extra_film_by_concat = (
- self.extra_film_condition_dim is not None and self.extra_film_use_concat
- )
- self.use_extra_film_by_addition = (
- self.extra_film_condition_dim is not None and not self.extra_film_use_concat
- )
-
- if self.extra_film_condition_dim is not None:
- self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim)
- # print("+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. " % self.extra_film_condition_dim)
- # if(self.use_extra_film_by_concat):
- # print("\t By concatenation with time embedding")
- # elif(self.use_extra_film_by_addition):
- # print("\t By addition with time embedding")
-
- if use_spatial_transformer and (
- self.use_extra_film_by_concat or self.use_extra_film_by_addition
- ):
- # print("+ Spatial transformer will only be used as self-attention. Because you have choose to use film as your global condition.")
- spatial_transformer_no_context = True
- else:
- spatial_transformer_no_context = False
-
- if use_spatial_transformer and not spatial_transformer_no_context:
- assert (
- context_dim is not None
- ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
-
- if context_dim is not None and not spatial_transformer_no_context:
- assert (
- use_spatial_transformer
- ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
- from omegaconf.listconfig import ListConfig
-
- if type(context_dim) == ListConfig:
- context_dim = list(context_dim)
-
- self.input_blocks = nn.ModuleList(
- [
- TimestepEmbedSequential(
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
- )
- ]
- )
- self._feature_size = model_channels
- input_block_chans = [model_channels]
- ch = model_channels
- ds = 1
- for level, mult in enumerate(channel_mult):
- for _ in range(num_res_blocks):
- layers = [
- ResBlock(
- ch,
- time_embed_dim
- if (not self.use_extra_film_by_concat)
- else time_embed_dim * 2,
- dropout,
- out_channels=mult * model_channels,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- )
- ]
- ch = mult * model_channels
- if ds in attention_resolutions:
- if num_head_channels == -1:
- dim_head = ch // num_heads
- else:
- num_heads = ch // num_head_channels
- dim_head = num_head_channels
- if legacy:
- dim_head = (
- ch // num_heads
- if use_spatial_transformer
- else num_head_channels
- )
- layers.append(
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=dim_head,
- use_new_attention_order=use_new_attention_order,
- )
- if not use_spatial_transformer
- else SpatialTransformer(
- ch,
- num_heads,
- dim_head,
- depth=transformer_depth,
- context_dim=context_dim,
- no_context=spatial_transformer_no_context,
- )
- )
- self.input_blocks.append(TimestepEmbedSequential(*layers))
- self._feature_size += ch
- input_block_chans.append(ch)
- if level != len(channel_mult) - 1:
- out_ch = ch
- self.input_blocks.append(
- TimestepEmbedSequential(
- ResBlock(
- ch,
- time_embed_dim
- if (not self.use_extra_film_by_concat)
- else time_embed_dim * 2,
- dropout,
- out_channels=out_ch,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- down=True,
- )
- if resblock_updown
- else Downsample(
- ch, conv_resample, dims=dims, out_channels=out_ch
- )
- )
- )
- ch = out_ch
- input_block_chans.append(ch)
- ds *= 2
- self._feature_size += ch
-
- if num_head_channels == -1:
- dim_head = ch // num_heads
- else:
- num_heads = ch // num_head_channels
- dim_head = num_head_channels
- if legacy:
- # num_heads = 1
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
- self.middle_block = TimestepEmbedSequential(
- ResBlock(
- ch,
- time_embed_dim
- if (not self.use_extra_film_by_concat)
- else time_embed_dim * 2,
- dropout,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- ),
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=dim_head,
- use_new_attention_order=use_new_attention_order,
- )
- if not use_spatial_transformer
- else SpatialTransformer(
- ch,
- num_heads,
- dim_head,
- depth=transformer_depth,
- context_dim=context_dim,
- no_context=spatial_transformer_no_context,
- ),
- ResBlock(
- ch,
- time_embed_dim
- if (not self.use_extra_film_by_concat)
- else time_embed_dim * 2,
- dropout,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- ),
- )
- self._feature_size += ch
-
- self.output_blocks = nn.ModuleList([])
- for level, mult in list(enumerate(channel_mult))[::-1]:
- for i in range(num_res_blocks + 1):
- ich = input_block_chans.pop()
- layers = [
- ResBlock(
- ch + ich,
- time_embed_dim
- if (not self.use_extra_film_by_concat)
- else time_embed_dim * 2,
- dropout,
- out_channels=model_channels * mult,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- )
- ]
- ch = model_channels * mult
- if ds in attention_resolutions:
- if num_head_channels == -1:
- dim_head = ch // num_heads
- else:
- num_heads = ch // num_head_channels
- dim_head = num_head_channels
- if legacy:
- # num_heads = 1
- dim_head = (
- ch // num_heads
- if use_spatial_transformer
- else num_head_channels
- )
- layers.append(
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads_upsample,
- num_head_channels=dim_head,
- use_new_attention_order=use_new_attention_order,
- )
- if not use_spatial_transformer
- else SpatialTransformer(
- ch,
- num_heads,
- dim_head,
- depth=transformer_depth,
- context_dim=context_dim,
- no_context=spatial_transformer_no_context,
- )
- )
- if level and i == num_res_blocks:
- out_ch = ch
- layers.append(
- ResBlock(
- ch,
- time_embed_dim
- if (not self.use_extra_film_by_concat)
- else time_embed_dim * 2,
- dropout,
- out_channels=out_ch,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- up=True,
- )
- if resblock_updown
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
- )
- ds //= 2
- self.output_blocks.append(TimestepEmbedSequential(*layers))
- self._feature_size += ch
-
- self.out = nn.Sequential(
- normalization(ch),
- nn.SiLU(),
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
- )
- if self.predict_codebook_ids:
- self.id_predictor = nn.Sequential(
- normalization(ch),
- conv_nd(dims, model_channels, n_embed, 1),
- # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
- )
-
- self.shape_reported = False
-
- def convert_to_fp16(self):
- """
- Convert the torso of the model to float16.
- """
- self.input_blocks.apply(convert_module_to_f16)
- self.middle_block.apply(convert_module_to_f16)
- self.output_blocks.apply(convert_module_to_f16)
-
- def convert_to_fp32(self):
- """
- Convert the torso of the model to float32.
- """
- self.input_blocks.apply(convert_module_to_f32)
- self.middle_block.apply(convert_module_to_f32)
- self.output_blocks.apply(convert_module_to_f32)
-
- def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
- """
- Apply the model to an input batch.
- :param x: an [N x C x ...] Tensor of inputs.
- :param timesteps: a 1-D batch of timesteps.
- :param context: conditioning plugged in via crossattn
- :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional
- :return: an [N x C x ...] Tensor of outputs.
- """
- if not self.shape_reported:
- # print("The shape of UNet input is", x.size())
- self.shape_reported = True
-
- assert (y is not None) == (
- self.num_classes is not None or self.extra_film_condition_dim is not None
- ), "must specify y if and only if the model is class-conditional or film embedding conditional"
- hs = []
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
- emb = self.time_embed(t_emb)
-
- if self.num_classes is not None:
- assert y.shape == (x.shape[0],)
- emb = emb + self.label_emb(y)
-
- if self.use_extra_film_by_addition:
- emb = emb + self.film_emb(y)
- elif self.use_extra_film_by_concat:
- emb = th.cat([emb, self.film_emb(y)], dim=-1)
-
- h = x.type(self.dtype)
- for module in self.input_blocks:
- h = module(h, emb, context)
- hs.append(h)
- h = self.middle_block(h, emb, context)
- for module in self.output_blocks:
- h = th.cat([h, hs.pop()], dim=1)
- h = module(h, emb, context)
- h = h.type(x.dtype)
- if self.predict_codebook_ids:
- return self.id_predictor(h)
- else:
- return self.out(h)
-
-
-class EncoderUNetModel(nn.Module):
- """
- The half UNet model with attention and timestep embedding.
- For usage, see UNet.
- """
-
- def __init__(
- self,
- image_size,
- in_channels,
- model_channels,
- out_channels,
- num_res_blocks,
- attention_resolutions,
- dropout=0,
- channel_mult=(1, 2, 4, 8),
- conv_resample=True,
- dims=2,
- use_checkpoint=False,
- use_fp16=False,
- num_heads=1,
- num_head_channels=-1,
- num_heads_upsample=-1,
- use_scale_shift_norm=False,
- resblock_updown=False,
- use_new_attention_order=False,
- pool="adaptive",
- *args,
- **kwargs,
- ):
- super().__init__()
-
- if num_heads_upsample == -1:
- num_heads_upsample = num_heads
-
- self.in_channels = in_channels
- self.model_channels = model_channels
- self.out_channels = out_channels
- self.num_res_blocks = num_res_blocks
- self.attention_resolutions = attention_resolutions
- self.dropout = dropout
- self.channel_mult = channel_mult
- self.conv_resample = conv_resample
- self.use_checkpoint = use_checkpoint
- self.dtype = th.float16 if use_fp16 else th.float32
- self.num_heads = num_heads
- self.num_head_channels = num_head_channels
- self.num_heads_upsample = num_heads_upsample
-
- time_embed_dim = model_channels * 4
- self.time_embed = nn.Sequential(
- linear(model_channels, time_embed_dim),
- nn.SiLU(),
- linear(time_embed_dim, time_embed_dim),
- )
-
- self.input_blocks = nn.ModuleList(
- [
- TimestepEmbedSequential(
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
- )
- ]
- )
- self._feature_size = model_channels
- input_block_chans = [model_channels]
- ch = model_channels
- ds = 1
- for level, mult in enumerate(channel_mult):
- for _ in range(num_res_blocks):
- layers = [
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- out_channels=mult * model_channels,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- )
- ]
- ch = mult * model_channels
- if ds in attention_resolutions:
- layers.append(
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=num_head_channels,
- use_new_attention_order=use_new_attention_order,
- )
- )
- self.input_blocks.append(TimestepEmbedSequential(*layers))
- self._feature_size += ch
- input_block_chans.append(ch)
- if level != len(channel_mult) - 1:
- out_ch = ch
- self.input_blocks.append(
- TimestepEmbedSequential(
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- out_channels=out_ch,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- down=True,
- )
- if resblock_updown
- else Downsample(
- ch, conv_resample, dims=dims, out_channels=out_ch
- )
- )
- )
- ch = out_ch
- input_block_chans.append(ch)
- ds *= 2
- self._feature_size += ch
-
- self.middle_block = TimestepEmbedSequential(
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- ),
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=num_head_channels,
- use_new_attention_order=use_new_attention_order,
- ),
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- ),
- )
- self._feature_size += ch
- self.pool = pool
- if pool == "adaptive":
- self.out = nn.Sequential(
- normalization(ch),
- nn.SiLU(),
- nn.AdaptiveAvgPool2d((1, 1)),
- zero_module(conv_nd(dims, ch, out_channels, 1)),
- nn.Flatten(),
- )
- elif pool == "attention":
- assert num_head_channels != -1
- self.out = nn.Sequential(
- normalization(ch),
- nn.SiLU(),
- AttentionPool2d(
- (image_size // ds), ch, num_head_channels, out_channels
- ),
- )
- elif pool == "spatial":
- self.out = nn.Sequential(
- nn.Linear(self._feature_size, 2048),
- nn.ReLU(),
- nn.Linear(2048, self.out_channels),
- )
- elif pool == "spatial_v2":
- self.out = nn.Sequential(
- nn.Linear(self._feature_size, 2048),
- normalization(2048),
- nn.SiLU(),
- nn.Linear(2048, self.out_channels),
- )
- else:
- raise NotImplementedError(f"Unexpected {pool} pooling")
-
- def convert_to_fp16(self):
- """
- Convert the torso of the model to float16.
- """
- self.input_blocks.apply(convert_module_to_f16)
- self.middle_block.apply(convert_module_to_f16)
-
- def convert_to_fp32(self):
- """
- Convert the torso of the model to float32.
- """
- self.input_blocks.apply(convert_module_to_f32)
- self.middle_block.apply(convert_module_to_f32)
-
- def forward(self, x, timesteps):
- """
- Apply the model to an input batch.
- :param x: an [N x C x ...] Tensor of inputs.
- :param timesteps: a 1-D batch of timesteps.
- :return: an [N x K] Tensor of outputs.
- """
- emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-
- results = []
- h = x.type(self.dtype)
- for module in self.input_blocks:
- h = module(h, emb)
- if self.pool.startswith("spatial"):
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
- h = self.middle_block(h, emb)
- if self.pool.startswith("spatial"):
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
- h = th.cat(results, axis=-1)
- return self.out(h)
- else:
- h = h.type(x.dtype)
- return self.out(h)
diff --git a/audioldm/latent_diffusion/util.py b/audioldm/latent_diffusion/util.py
deleted file mode 100644
index 8b289f6aa7f22a070870d8a706f944dc8547e936..0000000000000000000000000000000000000000
--- a/audioldm/latent_diffusion/util.py
+++ /dev/null
@@ -1,295 +0,0 @@
-# adapted from
-# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
-# and
-# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
-# and
-# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
-#
-# thanks!
-
-
-import os
-import math
-import torch
-import torch.nn as nn
-import numpy as np
-from einops import repeat
-
-from audioldm.utils import instantiate_from_config
-
-
-def make_beta_schedule(
- schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
-):
- if schedule == "linear":
- betas = (
- torch.linspace(
- linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
- )
- ** 2
- )
-
- elif schedule == "cosine":
- timesteps = (
- torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
- )
- alphas = timesteps / (1 + cosine_s) * np.pi / 2
- alphas = torch.cos(alphas).pow(2)
- alphas = alphas / alphas[0]
- betas = 1 - alphas[1:] / alphas[:-1]
- betas = np.clip(betas, a_min=0, a_max=0.999)
-
- elif schedule == "sqrt_linear":
- betas = torch.linspace(
- linear_start, linear_end, n_timestep, dtype=torch.float64
- )
- elif schedule == "sqrt":
- betas = (
- torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
- ** 0.5
- )
- else:
- raise ValueError(f"schedule '{schedule}' unknown.")
- return betas.numpy()
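# Quick look at the "linear" schedule built above (assumes this module's
# make_beta_schedule is in scope): betas interpolate between sqrt(linear_start)
# and sqrt(linear_end) before squaring, and the cumulative alphas decay to ~0.
import numpy as np

betas = make_beta_schedule("linear", n_timestep=1000)
alphas_cumprod = np.cumprod(1.0 - betas)
print(betas[0], betas[-1])              # ~1e-4 ... ~2e-2
print(alphas_cumprod[[0, 499, 999]])    # signal fraction shrinking towards zero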
-
-
-def make_ddim_timesteps(
- ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
-):
- if ddim_discr_method == "uniform":
- c = num_ddpm_timesteps // num_ddim_timesteps
- ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
- elif ddim_discr_method == "quad":
- ddim_timesteps = (
- (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
- ).astype(int)
- else:
- raise NotImplementedError(
- f'There is no ddim discretization method called "{ddim_discr_method}"'
- )
-
- # assert ddim_timesteps.shape[0] == num_ddim_timesteps
- # add one to get the final alpha values right (the ones from first scale to data during sampling)
- steps_out = ddim_timesteps + 1
- if verbose:
- print(f"Selected timesteps for ddim sampler: {steps_out}")
- return steps_out
-
-
-def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
- # select alphas for computing the variance schedule
- alphas = alphacums[ddim_timesteps]
- alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
-
- # according to the formula provided in https://arxiv.org/abs/2010.02502
- sigmas = eta * np.sqrt(
- (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
- )
- if verbose:
- print(
- f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
- )
- print(
- f"For the chosen value of eta, which is {eta}, "
- f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
- )
- return sigmas, alphas, alphas_prev
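# Worked example for the two helpers above (assumes they are in scope together
# with make_beta_schedule): the "uniform" rule subsamples 1000 DDPM steps into
# 200 DDIM steps 1, 6, 11, ..., and eta = 0 zeroes every sigma, i.e. fully
# deterministic DDIM sampling.
import numpy as np

ddim_steps = make_ddim_timesteps("uniform", num_ddim_timesteps=200,
                                 num_ddpm_timesteps=1000, verbose=False)
alphacums = np.cumprod(1.0 - make_beta_schedule("linear", 1000))
sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
    alphacums, ddim_steps, eta=0.0, verbose=False
)
print(ddim_steps[:3])        # [ 1  6 11]
print(float(sigmas.max()))   # 0.0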
-
-
-def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
- """
- Create a beta schedule that discretizes the given alpha_t_bar function,
- which defines the cumulative product of (1-beta) over time from t = [0,1].
- :param num_diffusion_timesteps: the number of betas to produce.
- :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
- produces the cumulative product of (1-beta) up to that
- part of the diffusion process.
- :param max_beta: the maximum beta to use; use values lower than 1 to
- prevent singularities.
- """
- betas = []
- for i in range(num_diffusion_timesteps):
- t1 = i / num_diffusion_timesteps
- t2 = (i + 1) / num_diffusion_timesteps
- betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
- return np.array(betas)
-
-
-def extract_into_tensor(a, t, x_shape):
- b, *_ = t.shape
- out = a.gather(-1, t).contiguous()
- return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous()
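# Small usage sketch (assumes extract_into_tensor from above): it gathers one
# schedule coefficient per batch element and reshapes it so it broadcasts over
# the remaining dimensions of x. Toy values below.
import torch

a = torch.linspace(0.0, 1.0, 1000)    # e.g. sqrt(alphas_cumprod) per timestep
t = torch.tensor([0, 500, 999])       # one timestep per batch element
coef = extract_into_tensor(a, t, x_shape=(3, 1, 256, 16))
print(coef.shape)                     # torch.Size([3, 1, 1, 1])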
-
-
-def checkpoint(func, inputs, params, flag):
- """
- Evaluate a function without caching intermediate activations, allowing for
- reduced memory at the expense of extra compute in the backward pass.
- :param func: the function to evaluate.
- :param inputs: the argument sequence to pass to `func`.
- :param params: a sequence of parameters `func` depends on but does not
- explicitly take as arguments.
- :param flag: if False, disable gradient checkpointing.
- """
- if flag:
- args = tuple(inputs) + tuple(params)
- return CheckpointFunction.apply(func, len(inputs), *args)
- else:
- return func(*inputs)
-
-
-class CheckpointFunction(torch.autograd.Function):
- @staticmethod
- def forward(ctx, run_function, length, *args):
- ctx.run_function = run_function
- ctx.input_tensors = list(args[:length])
- ctx.input_params = list(args[length:])
-
- with torch.no_grad():
- output_tensors = ctx.run_function(*ctx.input_tensors)
- return output_tensors
-
- @staticmethod
- def backward(ctx, *output_grads):
- ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
- with torch.enable_grad():
- # Fixes a bug where the first op in run_function modifies the
- # Tensor storage in place, which is not allowed for detach()'d
- # Tensors.
- shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
- output_tensors = ctx.run_function(*shallow_copies)
- input_grads = torch.autograd.grad(
- output_tensors,
- ctx.input_tensors + ctx.input_params,
- output_grads,
- allow_unused=True,
- )
- del ctx.input_tensors
- del ctx.input_params
- del output_tensors
- return (None, None) + input_grads
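# Minimal usage sketch for checkpoint/CheckpointFunction above (toy module,
# hypothetical names): the wrapped function is re-run during backward instead of
# caching activations, and gradients still reach both the inputs and the extra
# parameters passed alongside them.
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
loss = checkpoint(lambda t: layer(t).sum(), (x,), layer.parameters(), flag=True)
loss.backward()
print(x.grad.shape, layer.weight.grad.shape)   # both gradients are populated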
-
-
-def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
- """
- Create sinusoidal timestep embeddings.
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
- These may be fractional.
- :param dim: the dimension of the output.
- :param max_period: controls the minimum frequency of the embeddings.
- :return: an [N x dim] Tensor of positional embeddings.
- """
- if not repeat_only:
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period)
- * torch.arange(start=0, end=half, dtype=torch.float32)
- / half
- ).to(device=timesteps.device)
- args = timesteps[:, None].float() * freqs[None]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat(
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
- )
- else:
- embedding = repeat(timesteps, "b -> b d", d=dim)
- return embedding
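# Sanity check for timestep_embedding above (assumes it is in scope): each
# timestep is mapped to concatenated cosines and sines at geometrically spaced
# frequencies, so t = 0 gives ones in the cosine half and zeros in the sine half.
import torch

emb = timestep_embedding(torch.tensor([0, 10, 100]), dim=8)
print(emb.shape)   # torch.Size([3, 8])
print(emb[0])      # tensor([1., 1., 1., 1., 0., 0., 0., 0.])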
-
-
-def zero_module(module):
- """
- Zero out the parameters of a module and return it.
- """
- for p in module.parameters():
- p.detach().zero_()
- return module
-
-
-def scale_module(module, scale):
- """
- Scale the parameters of a module and return it.
- """
- for p in module.parameters():
- p.detach().mul_(scale)
- return module
-
-
-def mean_flat(tensor):
- """
- Take the mean over all non-batch dimensions.
- """
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
-
-
-def normalization(channels):
- """
- Make a standard normalization layer.
- :param channels: number of input channels.
- :return: an nn.Module for normalization.
- """
- return GroupNorm32(32, channels)
-
-
-# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
-class SiLU(nn.Module):
- def forward(self, x):
- return x * torch.sigmoid(x)
-
-
-class GroupNorm32(nn.GroupNorm):
- def forward(self, x):
- return super().forward(x.float()).type(x.dtype)
-
-
-def conv_nd(dims, *args, **kwargs):
- """
- Create a 1D, 2D, or 3D convolution module.
- """
- if dims == 1:
- return nn.Conv1d(*args, **kwargs)
- elif dims == 2:
- return nn.Conv2d(*args, **kwargs)
- elif dims == 3:
- return nn.Conv3d(*args, **kwargs)
- raise ValueError(f"unsupported dimensions: {dims}")
-
-
-def linear(*args, **kwargs):
- """
- Create a linear module.
- """
- return nn.Linear(*args, **kwargs)
-
-
-def avg_pool_nd(dims, *args, **kwargs):
- """
- Create a 1D, 2D, or 3D average pooling module.
- """
- if dims == 1:
- return nn.AvgPool1d(*args, **kwargs)
- elif dims == 2:
- return nn.AvgPool2d(*args, **kwargs)
- elif dims == 3:
- return nn.AvgPool3d(*args, **kwargs)
- raise ValueError(f"unsupported dimensions: {dims}")
-
-
-class HybridConditioner(nn.Module):
- def __init__(self, c_concat_config, c_crossattn_config):
- super().__init__()
- self.concat_conditioner = instantiate_from_config(c_concat_config)
- self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
-
- def forward(self, c_concat, c_crossattn):
- c_concat = self.concat_conditioner(c_concat)
- c_crossattn = self.crossattn_conditioner(c_crossattn)
- return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
-
-
-def noise_like(shape, device, repeat=False):
- repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
- shape[0], *((1,) * (len(shape) - 1))
- )
- noise = lambda: torch.randn(shape, device=device)
- return repeat_noise() if repeat else noise()
diff --git a/audioldm/ldm.py b/audioldm/ldm.py
deleted file mode 100644
index b0392e28404c315e5d8ca5ede571da386f5d4b42..0000000000000000000000000000000000000000
--- a/audioldm/ldm.py
+++ /dev/null
@@ -1,715 +0,0 @@
-import os
-
-import torch
-import numpy as np
-from tqdm import tqdm
-from audioldm.utils import default, instantiate_from_config, save_wave
-from audioldm.latent_diffusion.ddpm import DDPM
-from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution
-from audioldm.latent_diffusion.util import noise_like
-from audioldm.latent_diffusion.ddim import DDIMSampler
-import os
-
-def disabled_train(self, mode=True):
- """Overwrite model.train with this function to make sure train/eval mode
- does not change anymore."""
- return self
-
-class LatentDiffusion(DDPM):
- """main class"""
-
- def __init__(
- self,
- device="cuda",
- first_stage_config=None,
- cond_stage_config=None,
- num_timesteps_cond=None,
- cond_stage_key="image",
- cond_stage_trainable=False,
- concat_mode=True,
- cond_stage_forward=None,
- conditioning_key=None,
- scale_factor=1.0,
- scale_by_std=False,
- base_learning_rate=None,
- *args,
- **kwargs,
- ):
- self.device = device
- self.learning_rate = base_learning_rate
- self.num_timesteps_cond = default(num_timesteps_cond, 1)
- self.scale_by_std = scale_by_std
- assert self.num_timesteps_cond <= kwargs["timesteps"]
- # for backwards compatibility after implementation of DiffusionWrapper
- if conditioning_key is None:
- conditioning_key = "concat" if concat_mode else "crossattn"
- if cond_stage_config == "__is_unconditional__":
- conditioning_key = None
- ckpt_path = kwargs.pop("ckpt_path", None)
- ignore_keys = kwargs.pop("ignore_keys", [])
- super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
- self.concat_mode = concat_mode
- self.cond_stage_trainable = cond_stage_trainable
- self.cond_stage_key = cond_stage_key
- self.cond_stage_key_orig = cond_stage_key
- try:
- self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
- except:
- self.num_downs = 0
- if not scale_by_std:
- self.scale_factor = scale_factor
- else:
- self.register_buffer("scale_factor", torch.tensor(scale_factor))
- self.instantiate_first_stage(first_stage_config)
- self.instantiate_cond_stage(cond_stage_config)
- self.cond_stage_forward = cond_stage_forward
- self.clip_denoised = False
-
- def make_cond_schedule(
- self,
- ):
- self.cond_ids = torch.full(
- size=(self.num_timesteps,),
- fill_value=self.num_timesteps - 1,
- dtype=torch.long,
- )
- ids = torch.round(
- torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
- ).long()
- self.cond_ids[: self.num_timesteps_cond] = ids
-
- def register_schedule(
- self,
- given_betas=None,
- beta_schedule="linear",
- timesteps=1000,
- linear_start=1e-4,
- linear_end=2e-2,
- cosine_s=8e-3,
- ):
- super().register_schedule(
- given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
- )
-
- self.shorten_cond_schedule = self.num_timesteps_cond > 1
- if self.shorten_cond_schedule:
- self.make_cond_schedule()
-
- def instantiate_first_stage(self, config):
- model = instantiate_from_config(config)
- self.first_stage_model = model.eval()
- self.first_stage_model.train = disabled_train
- for param in self.first_stage_model.parameters():
- param.requires_grad = False
-
- def instantiate_cond_stage(self, config):
- if not self.cond_stage_trainable:
- if config == "__is_first_stage__":
- print("Using first stage also as cond stage.")
- self.cond_stage_model = self.first_stage_model
- elif config == "__is_unconditional__":
- print(f"Training {self.__class__.__name__} as an unconditional model.")
- self.cond_stage_model = None
- # self.be_unconditional = True
- else:
- model = instantiate_from_config(config)
- self.cond_stage_model = model.eval()
- self.cond_stage_model.train = disabled_train
- for param in self.cond_stage_model.parameters():
- param.requires_grad = False
- else:
- assert config != "__is_first_stage__"
- assert config != "__is_unconditional__"
- model = instantiate_from_config(config)
- self.cond_stage_model = model
- self.cond_stage_model = self.cond_stage_model.to(self.device)
-
- def get_first_stage_encoding(self, encoder_posterior):
- if isinstance(encoder_posterior, DiagonalGaussianDistribution):
- z = encoder_posterior.sample()
- elif isinstance(encoder_posterior, torch.Tensor):
- z = encoder_posterior
- else:
- raise NotImplementedError(
- f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
- )
- return self.scale_factor * z
-
- def get_learned_conditioning(self, c):
- if self.cond_stage_forward is None:
- if hasattr(self.cond_stage_model, "encode") and callable(
- self.cond_stage_model.encode
- ):
- c = self.cond_stage_model.encode(c)
- if isinstance(c, DiagonalGaussianDistribution):
- c = c.mode()
- else:
- if len(c) == 1:
- c = self.cond_stage_model([c[0], c[0]])
- c = c[0:1]
- else:
- c = self.cond_stage_model(c)
- else:
- assert hasattr(self.cond_stage_model, self.cond_stage_forward)
- c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
- return c
-
- @torch.no_grad()
- def get_input(
- self,
- batch,
- k,
- return_first_stage_encode=True,
- return_first_stage_outputs=False,
- force_c_encode=False,
- cond_key=None,
- return_original_cond=False,
- bs=None,
- ):
- x = super().get_input(batch, k)
-
- if bs is not None:
- x = x[:bs]
-
- x = x.to(self.device)
-
- if return_first_stage_encode:
- encoder_posterior = self.encode_first_stage(x)
- z = self.get_first_stage_encoding(encoder_posterior).detach()
- else:
- z = None
-
- if self.model.conditioning_key is not None:
- if cond_key is None:
- cond_key = self.cond_stage_key
- if cond_key != self.first_stage_key:
- if cond_key in ["caption", "coordinates_bbox"]:
- xc = batch[cond_key]
- elif cond_key == "class_label":
- xc = batch
- else:
- # [bs, 1, 527]
- xc = super().get_input(batch, cond_key)
- if type(xc) == torch.Tensor:
- xc = xc.to(self.device)
- else:
- xc = x
- if not self.cond_stage_trainable or force_c_encode:
- if isinstance(xc, dict) or isinstance(xc, list):
- c = self.get_learned_conditioning(xc)
- else:
- c = self.get_learned_conditioning(xc.to(self.device))
- else:
- c = xc
-
- if bs is not None:
- c = c[:bs]
-
- else:
- c = None
- xc = None
- if self.use_positional_encodings:
- pos_x, pos_y = self.compute_latent_shifts(batch)
- c = {"pos_x": pos_x, "pos_y": pos_y}
- out = [z, c]
- if return_first_stage_outputs:
- xrec = self.decode_first_stage(z)
- out.extend([x, xrec])
- if return_original_cond:
- out.append(xc)
- return out
-
- @torch.no_grad()
- def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
- if predict_cids:
- if z.dim() == 4:
- z = torch.argmax(z.exp(), dim=1).long()
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
- z = rearrange(z, "b h w c -> b c h w").contiguous()
-
- z = 1.0 / self.scale_factor * z
- return self.first_stage_model.decode(z)
-
- def mel_spectrogram_to_waveform(self, mel):
- # Mel: [bs, 1, t-steps, fbins]
- if len(mel.size()) == 4:
- mel = mel.squeeze(1)
- mel = mel.permute(0, 2, 1)
- waveform = self.first_stage_model.vocoder(mel)
- waveform = waveform.cpu().detach().numpy()
- return waveform
-
- @torch.no_grad()
- def encode_first_stage(self, x):
- return self.first_stage_model.encode(x)
-
- def apply_model(self, x_noisy, t, cond, return_ids=False):
-
- if isinstance(cond, dict):
- # hybrid case, cond is expected to be a dict
- pass
- else:
- if not isinstance(cond, list):
- cond = [cond]
- if self.model.conditioning_key == "concat":
- key = "c_concat"
- elif self.model.conditioning_key == "crossattn":
- key = "c_crossattn"
- else:
- key = "c_film"
-
- cond = {key: cond}
-
- x_recon = self.model(x_noisy, t, **cond)
-
- if isinstance(x_recon, tuple) and not return_ids:
- return x_recon[0]
- else:
- return x_recon
-
- def p_mean_variance(
- self,
- x,
- c,
- t,
- clip_denoised: bool,
- return_codebook_ids=False,
- quantize_denoised=False,
- return_x0=False,
- score_corrector=None,
- corrector_kwargs=None,
- ):
- t_in = t
- model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
-
- if score_corrector is not None:
- assert self.parameterization == "eps"
- model_out = score_corrector.modify_score(
- self, model_out, x, t, c, **corrector_kwargs
- )
-
- if return_codebook_ids:
- model_out, logits = model_out
-
- if self.parameterization == "eps":
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
- elif self.parameterization == "x0":
- x_recon = model_out
- else:
- raise NotImplementedError()
-
- if clip_denoised:
- x_recon.clamp_(-1.0, 1.0)
- if quantize_denoised:
- x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
- x_start=x_recon, x_t=x, t=t
- )
- if return_codebook_ids:
- return model_mean, posterior_variance, posterior_log_variance, logits
- elif return_x0:
- return model_mean, posterior_variance, posterior_log_variance, x_recon
- else:
- return model_mean, posterior_variance, posterior_log_variance
-
- @torch.no_grad()
- def p_sample(
- self,
- x,
- c,
- t,
- clip_denoised=False,
- repeat_noise=False,
- return_codebook_ids=False,
- quantize_denoised=False,
- return_x0=False,
- temperature=1.0,
- noise_dropout=0.0,
- score_corrector=None,
- corrector_kwargs=None,
- ):
- b, *_, device = *x.shape, x.device
- outputs = self.p_mean_variance(
- x=x,
- c=c,
- t=t,
- clip_denoised=clip_denoised,
- return_codebook_ids=return_codebook_ids,
- quantize_denoised=quantize_denoised,
- return_x0=return_x0,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- )
- if return_codebook_ids:
- raise DeprecationWarning("Support dropped.")
- model_mean, _, model_log_variance, logits = outputs
- elif return_x0:
- model_mean, _, model_log_variance, x0 = outputs
- else:
- model_mean, _, model_log_variance = outputs
-
- noise = noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.0:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- # no noise when t == 0
- nonzero_mask = (
- (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous()
- )
-
- if return_codebook_ids:
- return model_mean + nonzero_mask * (
- 0.5 * model_log_variance
- ).exp() * noise, logits.argmax(dim=1)
- if return_x0:
- return (
- model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
- x0,
- )
- else:
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
-
- @torch.no_grad()
- def progressive_denoising(
- self,
- cond,
- shape,
- verbose=True,
- callback=None,
- quantize_denoised=False,
- img_callback=None,
- mask=None,
- x0=None,
- temperature=1.0,
- noise_dropout=0.0,
- score_corrector=None,
- corrector_kwargs=None,
- batch_size=None,
- x_T=None,
- start_T=None,
- log_every_t=None,
- ):
- if not log_every_t:
- log_every_t = self.log_every_t
- timesteps = self.num_timesteps
- if batch_size is not None:
- b = batch_size if batch_size is not None else shape[0]
- shape = [batch_size] + list(shape)
- else:
- b = batch_size = shape[0]
- if x_T is None:
- img = torch.randn(shape, device=self.device)
- else:
- img = x_T
- intermediates = []
- if cond is not None:
- if isinstance(cond, dict):
- cond = {
- key: cond[key][:batch_size]
- if not isinstance(cond[key], list)
- else list(map(lambda x: x[:batch_size], cond[key]))
- for key in cond
- }
- else:
- cond = (
- [c[:batch_size] for c in cond]
- if isinstance(cond, list)
- else cond[:batch_size]
- )
-
- if start_T is not None:
- timesteps = min(timesteps, start_T)
- iterator = (
- tqdm(
- reversed(range(0, timesteps)),
- desc="Progressive Generation",
- total=timesteps,
- )
- if verbose
- else reversed(range(0, timesteps))
- )
- if type(temperature) == float:
- temperature = [temperature] * timesteps
-
- for i in iterator:
- ts = torch.full((b,), i, device=self.device, dtype=torch.long)
- if self.shorten_cond_schedule:
- assert self.model.conditioning_key != "hybrid"
- tc = self.cond_ids[ts].to(cond.device)
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
-
- img, x0_partial = self.p_sample(
- img,
- cond,
- ts,
- clip_denoised=self.clip_denoised,
- quantize_denoised=quantize_denoised,
- return_x0=True,
- temperature=temperature[i],
- noise_dropout=noise_dropout,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- )
- if mask is not None:
- assert x0 is not None
- img_orig = self.q_sample(x0, ts)
- img = img_orig * mask + (1.0 - mask) * img
-
- if i % log_every_t == 0 or i == timesteps - 1:
- intermediates.append(x0_partial)
- if callback:
- callback(i)
- if img_callback:
- img_callback(img, i)
- return img, intermediates
-
- @torch.no_grad()
- def p_sample_loop(
- self,
- cond,
- shape,
- return_intermediates=False,
- x_T=None,
- verbose=True,
- callback=None,
- timesteps=None,
- quantize_denoised=False,
- mask=None,
- x0=None,
- img_callback=None,
- start_T=None,
- log_every_t=None,
- ):
-
- if not log_every_t:
- log_every_t = self.log_every_t
- device = self.betas.device
- b = shape[0]
- if x_T is None:
- img = torch.randn(shape, device=device)
- else:
- img = x_T
-
- intermediates = [img]
- if timesteps is None:
- timesteps = self.num_timesteps
-
- if start_T is not None:
- timesteps = min(timesteps, start_T)
- iterator = (
- tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps)
- if verbose
- else reversed(range(0, timesteps))
- )
-
- if mask is not None:
- assert x0 is not None
- assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
-
- for i in iterator:
- ts = torch.full((b,), i, device=device, dtype=torch.long)
- if self.shorten_cond_schedule:
- assert self.model.conditioning_key != "hybrid"
- tc = self.cond_ids[ts].to(cond.device)
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
-
- img = self.p_sample(
- img,
- cond,
- ts,
- clip_denoised=self.clip_denoised,
- quantize_denoised=quantize_denoised,
- )
- if mask is not None:
- img_orig = self.q_sample(x0, ts)
- img = img_orig * mask + (1.0 - mask) * img
-
- if i % log_every_t == 0 or i == timesteps - 1:
- intermediates.append(img)
- if callback:
- callback(i)
- if img_callback:
- img_callback(img, i)
-
- if return_intermediates:
- return img, intermediates
- return img
-
- @torch.no_grad()
- def sample(
- self,
- cond,
- batch_size=16,
- return_intermediates=False,
- x_T=None,
- verbose=True,
- timesteps=None,
- quantize_denoised=False,
- mask=None,
- x0=None,
- shape=None,
- **kwargs,
- ):
- if shape is None:
- shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size)
- if cond is not None:
- if isinstance(cond, dict):
- cond = {
- key: cond[key][:batch_size]
- if not isinstance(cond[key], list)
- else list(map(lambda x: x[:batch_size], cond[key]))
- for key in cond
- }
- else:
- cond = (
- [c[:batch_size] for c in cond]
- if isinstance(cond, list)
- else cond[:batch_size]
- )
- return self.p_sample_loop(
- cond,
- shape,
- return_intermediates=return_intermediates,
- x_T=x_T,
- verbose=verbose,
- timesteps=timesteps,
- quantize_denoised=quantize_denoised,
- mask=mask,
- x0=x0,
- **kwargs,
- )
-
- @torch.no_grad()
- def sample_log(
- self,
- cond,
- batch_size,
- ddim,
- ddim_steps,
- unconditional_guidance_scale=1.0,
- unconditional_conditioning=None,
- use_plms=False,
- mask=None,
- **kwargs,
- ):
-
- if mask is not None:
- shape = (self.channels, mask.size()[-2], mask.size()[-1])
- else:
- shape = (self.channels, self.latent_t_size, self.latent_f_size)
-
- intermediate = None
- if ddim and not use_plms:
- # print("Use ddim sampler")
-
- ddim_sampler = DDIMSampler(self)
- samples, intermediates = ddim_sampler.sample(
- ddim_steps,
- batch_size,
- shape,
- cond,
- verbose=False,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- mask=mask,
- **kwargs,
- )
-
- else:
- # print("Use DDPM sampler")
- samples, intermediates = self.sample(
- cond=cond,
- batch_size=batch_size,
- return_intermediates=True,
- unconditional_guidance_scale=unconditional_guidance_scale,
- mask=mask,
- unconditional_conditioning=unconditional_conditioning,
- **kwargs,
- )
-
- return samples, intermediate
-
-
- @torch.no_grad()
- def generate_sample(
- self,
- batchs,
- ddim_steps=200,
- ddim_eta=1.0,
- x_T=None,
- n_candidate_gen_per_text=1,
- unconditional_guidance_scale=1.0,
- unconditional_conditioning=None,
- name="waveform",
- use_plms=False,
- save=False,
- **kwargs,
- ):
- # Generate n_candidate_gen_per_text times and select the best
- # Batch: audio, text, fnames
- assert x_T is None
- try:
- batchs = iter(batchs)
- except TypeError:
- raise ValueError("The first input argument should be an iterable object")
-
- if use_plms:
- assert ddim_steps is not None
- use_ddim = ddim_steps is not None
- # waveform_save_path = os.path.join(self.get_log_dir(), name)
- # os.makedirs(waveform_save_path, exist_ok=True)
- # print("Waveform save path: ", waveform_save_path)
-
- with self.ema_scope("Generate"):
- for batch in batchs:
- z, c = self.get_input(
- batch,
- self.first_stage_key,
- return_first_stage_outputs=False,
- force_c_encode=True,
- return_original_cond=False,
- bs=None,
- )
- text = super().get_input(batch, "text")
-
- # Generate multiple samples
- batch_size = z.shape[0] * n_candidate_gen_per_text
- c = torch.cat([c] * n_candidate_gen_per_text, dim=0)
- text = text * n_candidate_gen_per_text
-
- if unconditional_guidance_scale != 1.0:
- unconditional_conditioning = (
- self.cond_stage_model.get_unconditional_condition(batch_size)
- )
-
- samples, _ = self.sample_log(
- cond=c,
- batch_size=batch_size,
- x_T=x_T,
- ddim=use_ddim,
- ddim_steps=ddim_steps,
- eta=ddim_eta,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- use_plms=use_plms,
- )
-
- mel = self.decode_first_stage(samples)
-
- waveform = self.mel_spectrogram_to_waveform(mel)
-
- if(waveform.shape[0] > 1):
- similarity = self.cond_stage_model.cos_similarity(
- torch.FloatTensor(waveform).squeeze(1), text
- )
-
- best_index = []
- for i in range(z.shape[0]):
- candidates = similarity[i :: z.shape[0]]
- max_index = torch.argmax(candidates).item()
- best_index.append(i + max_index * z.shape[0])
-
- waveform = waveform[best_index]
- # print("Similarity between generated audio and text", similarity)
- # print("Choose the following indexes:", best_index)
-
- return waveform
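The deleted `generate_sample` above keeps, for each prompt, the candidate waveform whose text-audio similarity is highest; the stride indexing `similarity[i :: z.shape[0]]` works because the batch is tiled `n_candidate_gen_per_text` times along dim 0. A minimal, self-contained sketch of just that selection step (the helper name and the toy similarity tensor are illustrative, not from the repo):

```python
import torch

def pick_best_candidates(similarity: torch.Tensor, batch_size: int, n_candidates: int) -> list:
    """Given per-waveform similarity scores laid out as the prompt batch
    repeated n_candidates times, return the best candidate index per prompt."""
    assert similarity.numel() == batch_size * n_candidates
    best_index = []
    for i in range(batch_size):
        # candidates for prompt i live at positions i, i+B, i+2B, ...
        candidates = similarity[i::batch_size]
        max_index = torch.argmax(candidates).item()
        best_index.append(i + max_index * batch_size)
    return best_index

# toy check: 2 prompts x 3 candidates
sim = torch.tensor([0.1, 0.9, 0.5, 0.2, 0.8, 0.3])
print(pick_best_candidates(sim, batch_size=2, n_candidates=3))  # [4, 1]
```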
diff --git a/audioldm/pipeline.py b/audioldm/pipeline.py
deleted file mode 100644
index a45db86865400e28b006dd3eebd873126a856fa0..0000000000000000000000000000000000000000
--- a/audioldm/pipeline.py
+++ /dev/null
@@ -1,92 +0,0 @@
-
-
-import os
-
-import argparse
-import yaml
-import torch
-
-from audioldm import LatentDiffusion, seed_everything
-from audioldm.utils import default_audioldm_config
-
-
-import time
-
-def make_batch_for_text_to_audio(text, batchsize=1):
- if batchsize < 1:
- print("Warning: batchsize must be at least 1; falling back to batchsize = 1.")
- batchsize = 1
- text = [text] * batchsize
- fbank = torch.zeros((batchsize, 1024, 64)) # Not used, here to keep the code format
- stft = torch.zeros((batchsize, 1024, 512)) # Not used
- waveform = torch.zeros((batchsize, 160000)) # Not used
- fname = [""] * batchsize # Not used
- batch = (
- fbank,
- stft,
- None,
- fname,
- waveform,
- text,
- )
- return batch
-
-
-
-def build_model(
- ckpt_path=None,
- config=None,
- model_name="audioldm-s-full"
-):
- print("Load AudioLDM: %s" % model_name)
-
- resume_from_checkpoint = "ckpt/%s.ckpt" % model_name
-
- # if(ckpt_path is None):
- # ckpt_path = get_metadata()[model_name]["path"]
-
- # if(not os.path.exists(ckpt_path)):
- # download_checkpoint(model_name)
-
- if(torch.cuda.is_available()):
- device = torch.device("cuda:0")
- else:
- device = torch.device("cpu")
-
- if(config is not None):
- assert type(config) is str
- config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
- else:
- config = default_audioldm_config(model_name)
-
- # Use text as condition instead of using waveform during training
- config["model"]["params"]["device"] = device
- config["model"]["params"]["cond_stage_key"] = "text"
-
- # No normalization here
- latent_diffusion = LatentDiffusion(**config["model"]["params"])
-
- checkpoint = torch.load(resume_from_checkpoint, map_location=device)
- latent_diffusion.load_state_dict(checkpoint["state_dict"])
-
- latent_diffusion.eval()
- latent_diffusion = latent_diffusion.to(device)
-
- latent_diffusion.cond_stage_model.embed_mode = "text"
- return latent_diffusion
-
-def duration_to_latent_t_size(duration):
- return int(duration * 25.6)
-
-def text_to_audio(latent_diffusion, text, seed=42, duration=10, batchsize=1, guidance_scale=2.5, n_candidate_gen_per_text=3, config=None):
- seed_everything(int(seed))
- batch = make_batch_for_text_to_audio(text, batchsize=batchsize)
-
- latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
- with torch.no_grad():
- waveform = latent_diffusion.generate_sample(
- [batch],
- unconditional_guidance_scale=guidance_scale,
- n_candidate_gen_per_text=n_candidate_gen_per_text,
- duration=duration
- )
- return waveform
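For reference, the deleted `duration_to_latent_t_size` maps requested seconds to latent time steps at 25.6 steps per second, which lines up with the default config in `audioldm/utils.py` below (a 1024-frame, roughly 10-second mel window encoded to `latent_t_size: 256`). A small sketch of the arithmetic:

```python
def duration_to_latent_t_size(duration: float) -> int:
    # the 1024-frame (~10 s) mel window encodes to 256 latent steps,
    # i.e. roughly 25.6 latent time steps per requested second
    return int(duration * 25.6)

for seconds in (2.5, 5, 10):
    print(seconds, "s ->", duration_to_latent_t_size(seconds), "latent time steps")
# 2.5 s -> 64, 5 s -> 128, 10 s -> 256 (the default latent_t_size)
```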
diff --git a/audioldm/utils.py b/audioldm/utils.py
deleted file mode 100644
index 0468a2973b705af739483c196a91185500d6a8da..0000000000000000000000000000000000000000
--- a/audioldm/utils.py
+++ /dev/null
@@ -1,174 +0,0 @@
-import importlib
-
-from inspect import isfunction
-
-import os
-import soundfile as sf
-
-def seed_everything(seed):
- import random, os
- import numpy as np
- import torch
-
- random.seed(seed)
- os.environ['PYTHONHASHSEED'] = str(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- torch.cuda.manual_seed(seed)
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = True
-
-def save_wave(waveform, savepath, name="outwav"):
- if type(name) is not list:
- name = [name] * waveform.shape[0]
-
- for i in range(waveform.shape[0]):
- path = os.path.join(
- savepath,
- "%s_%s.wav"
- % (
- os.path.basename(name[i])
- if (not ".wav" in name[i])
- else os.path.basename(name[i]).split(".")[0],
- i,
- ),
- )
- sf.write(path, waveform[i, 0], samplerate=16000)
-
-def exists(x):
- return x is not None
-
-
-def default(val, d):
- if exists(val):
- return val
- return d() if isfunction(d) else d
-
-
-def count_params(model, verbose=False):
- total_params = sum(p.numel() for p in model.parameters())
- if verbose:
- print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
- return total_params
-
-
-def get_obj_from_str(string, reload=False):
- module, cls = string.rsplit(".", 1)
- if reload:
- module_imp = importlib.import_module(module)
- importlib.reload(module_imp)
- return getattr(importlib.import_module(module, package=None), cls)
-
-
-def instantiate_from_config(config):
- if not "target" in config:
- if config == "__is_first_stage__":
- return None
- elif config == "__is_unconditional__":
- return None
- raise KeyError("Expected key `target` to instantiate.")
- return get_obj_from_str(config["target"])(**config.get("params", dict()))
-
-def default_audioldm_config(model_name="audioldm-s-full"):
- basic_config = {
- "wave_file_save_path": "./output",
- "id": {
- "version": "v1",
- "name": "default",
- "root": "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/AudioLDM-python/config/default/latent_diffusion.yaml",
- },
- "preprocessing": {
- "audio": {"sampling_rate": 16000, "max_wav_value": 32768},
- "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024},
- "mel": {
- "n_mel_channels": 64,
- "mel_fmin": 0,
- "mel_fmax": 8000,
- "freqm": 0,
- "timem": 0,
- "blur": False,
- "mean": -4.63,
- "std": 2.74,
- "target_length": 1024,
- },
- },
- "model": {
- "device": "cuda",
- "target": "audioldm.pipline.LatentDiffusion",
- "params": {
- "base_learning_rate": 5e-06,
- "linear_start": 0.0015,
- "linear_end": 0.0195,
- "num_timesteps_cond": 1,
- "log_every_t": 200,
- "timesteps": 1000,
- "first_stage_key": "fbank",
- "cond_stage_key": "waveform",
- "latent_t_size": 256,
- "latent_f_size": 16,
- "channels": 8,
- "cond_stage_trainable": True,
- "conditioning_key": "film",
- "monitor": "val/loss_simple_ema",
- "scale_by_std": True,
- "unet_config": {
- "target": "audioldm.latent_diffusion.openaimodel.UNetModel",
- "params": {
- "image_size": 64,
- "extra_film_condition_dim": 512,
- "extra_film_use_concat": True,
- "in_channels": 8,
- "out_channels": 8,
- "model_channels": 128,
- "attention_resolutions": [8, 4, 2],
- "num_res_blocks": 2,
- "channel_mult": [1, 2, 3, 5],
- "num_head_channels": 32,
- "use_spatial_transformer": True,
- },
- },
- "first_stage_config": {
- "base_learning_rate": 4.5e-05,
- "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL",
- "params": {
- "monitor": "val/rec_loss",
- "image_key": "fbank",
- "subband": 1,
- "embed_dim": 8,
- "time_shuffle": 1,
- "ddconfig": {
- "double_z": True,
- "z_channels": 8,
- "resolution": 256,
- "downsample_time": False,
- "in_channels": 1,
- "out_ch": 1,
- "ch": 128,
- "ch_mult": [1, 2, 4],
- "num_res_blocks": 2,
- "attn_resolutions": [],
- "dropout": 0.0,
- },
- },
- },
- "cond_stage_config": {
- "target": "audioldm.clap.encoders.CLAPAudioEmbeddingClassifierFreev2",
- "params": {
- "key": "waveform",
- "sampling_rate": 16000,
- "embed_mode": "audio",
- "unconditional_prob": 0.1,
- },
- },
- },
- },
- }
-
- if("-l-" in model_name):
- basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 256
- basic_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = 64
- elif("-m-" in model_name):
- basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 192
- basic_config["model"]["params"]["cond_stage_config"]["params"]["amodel"] = "HTSAT-base" # This model use a larger HTAST
-
- return basic_config
\ No newline at end of file
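The `get_obj_from_str` / `instantiate_from_config` pair above implements the usual `target`/`params` convention: import the dotted class path, then call it with the params dict. A runnable sketch of the same pattern, using a standard-library class in place of an audioldm target (the `fractions.Fraction` config is purely illustrative):

```python
import importlib

def get_obj_from_str(string: str):
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config: dict):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

# illustrative config using a standard-library class instead of an audioldm class
cfg = {"target": "fractions.Fraction", "params": {"numerator": 3, "denominator": 4}}
print(instantiate_from_config(cfg))  # 3/4
```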
diff --git a/audioldm/variational_autoencoder/__init__.py b/audioldm/variational_autoencoder/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/audioldm/variational_autoencoder/autoencoder.py b/audioldm/variational_autoencoder/autoencoder.py
deleted file mode 100644
index cfbdceba0e1171b052f797885530bacd0f3c73d5..0000000000000000000000000000000000000000
--- a/audioldm/variational_autoencoder/autoencoder.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import torch
-import torch.nn as nn
-from audioldm.latent_diffusion.ema import *
-from audioldm.variational_autoencoder.modules import Encoder, Decoder
-from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution
-
-from audioldm.hifigan.utilities import get_vocoder, vocoder_infer
-
-class AutoencoderKL(nn.Module):
- def __init__(
- self,
- ddconfig=None,
- lossconfig=None,
- image_key="fbank",
- embed_dim=None,
- time_shuffle=1,
- subband=1,
- ckpt_path=None,
- reload_from_ckpt=None,
- ignore_keys=[],
- colorize_nlabels=None,
- monitor=None,
- base_learning_rate=1e-5,
- ):
- super().__init__()
-
- self.encoder = Encoder(**ddconfig)
- self.decoder = Decoder(**ddconfig)
-
- self.subband = int(subband)
-
- if self.subband > 1:
- print("Use subband decomposition %s" % self.subband)
-
- self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
-
- self.vocoder = get_vocoder(None, "cpu")
- self.embed_dim = embed_dim
-
- if monitor is not None:
- self.monitor = monitor
-
- self.time_shuffle = time_shuffle
- self.reload_from_ckpt = reload_from_ckpt
- self.reloaded = False
- self.mean, self.std = None, None
- self.image_key = image_key  # used by freq_split_subband / freq_merge_subband
- self.flag_first_run = True  # used by forward() to log the latent size once
-
- def encode(self, x):
- # x = self.time_shuffle_operation(x)
- x = self.freq_split_subband(x)
- h = self.encoder(x)
- moments = self.quant_conv(h)
- posterior = DiagonalGaussianDistribution(moments)
- return posterior
-
- def decode(self, z):
- z = self.post_quant_conv(z)
- dec = self.decoder(z)
- dec = self.freq_merge_subband(dec)
- return dec
-
- def decode_to_waveform(self, dec):
- dec = dec.squeeze(1).permute(0, 2, 1)
- wav_reconstruction = vocoder_infer(dec, self.vocoder)
- return wav_reconstruction
-
- def forward(self, input, sample_posterior=True):
- posterior = self.encode(input)
- if sample_posterior:
- z = posterior.sample()
- else:
- z = posterior.mode()
-
- if self.flag_first_run:
- print("Latent size: ", z.size())
- self.flag_first_run = False
-
- dec = self.decode(z)
-
- return dec, posterior
-
- def freq_split_subband(self, fbank):
- if self.subband == 1 or self.image_key != "stft":
- return fbank
-
- bs, ch, tstep, fbins = fbank.size()
-
- assert fbank.size(-1) % self.subband == 0
- assert ch == 1
-
- return (
- fbank.squeeze(1)
- .reshape(bs, tstep, self.subband, fbins // self.subband)
- .permute(0, 2, 1, 3)
- )
-
- def freq_merge_subband(self, subband_fbank):
- if self.subband == 1 or self.image_key != "stft":
- return subband_fbank
- assert subband_fbank.size(1) == self.subband # Channel dimension
- bs, sub_ch, tstep, fbins = subband_fbank.size()
- return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)
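The subband helpers in the deleted `AutoencoderKL` only act when `subband > 1` and the input is an STFT feature; they reshape `(bs, 1, T, F)` into `(bs, subband, T, F // subband)` and back. A standalone sketch (shapes and the random test tensor are illustrative) checking that the merge inverts the split:

```python
import torch

def freq_split_subband(fbank: torch.Tensor, subband: int) -> torch.Tensor:
    # (bs, 1, T, F) -> (bs, subband, T, F // subband)
    bs, ch, tstep, fbins = fbank.size()
    assert ch == 1 and fbins % subband == 0
    return (
        fbank.squeeze(1)
        .reshape(bs, tstep, subband, fbins // subband)
        .permute(0, 2, 1, 3)
    )

def freq_merge_subband(subband_fbank: torch.Tensor) -> torch.Tensor:
    # (bs, subband, T, F // subband) -> (bs, 1, T, F)
    bs, sub_ch, tstep, fbins = subband_fbank.size()
    return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)

x = torch.randn(2, 1, 1024, 512)           # a batch of STFT-like features
split = freq_split_subband(x, subband=4)   # torch.Size([2, 4, 1024, 128])
merged = freq_merge_subband(split)         # torch.Size([2, 1, 1024, 512])
assert torch.equal(merged, x)              # the merge undoes the split
```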
diff --git a/audioldm/variational_autoencoder/distributions.py b/audioldm/variational_autoencoder/distributions.py
deleted file mode 100644
index 58eb535e7769f402169ddff77ee45c96ba3650d9..0000000000000000000000000000000000000000
--- a/audioldm/variational_autoencoder/distributions.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import torch
-import numpy as np
-
-
-class AbstractDistribution:
- def sample(self):
- raise NotImplementedError()
-
- def mode(self):
- raise NotImplementedError()
-
-
-class DiracDistribution(AbstractDistribution):
- def __init__(self, value):
- self.value = value
-
- def sample(self):
- return self.value
-
- def mode(self):
- return self.value
-
-
-class DiagonalGaussianDistribution(object):
- def __init__(self, parameters, deterministic=False):
- self.parameters = parameters
- self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
- self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
- self.deterministic = deterministic
- self.std = torch.exp(0.5 * self.logvar)
- self.var = torch.exp(self.logvar)
- if self.deterministic:
- self.var = self.std = torch.zeros_like(self.mean).to(
- device=self.parameters.device
- )
-
- def sample(self):
- x = self.mean + self.std * torch.randn(self.mean.shape).to(
- device=self.parameters.device
- )
- return x
-
- def kl(self, other=None):
- if self.deterministic:
- return torch.Tensor([0.0])
- else:
- if other is None:
- return 0.5 * torch.mean(
- torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
- dim=[1, 2, 3],
- )
- else:
- return 0.5 * torch.mean(
- torch.pow(self.mean - other.mean, 2) / other.var
- + self.var / other.var
- - 1.0
- - self.logvar
- + other.logvar,
- dim=[1, 2, 3],
- )
-
- def nll(self, sample, dims=[1, 2, 3]):
- if self.deterministic:
- return torch.Tensor([0.0])
- logtwopi = np.log(2.0 * np.pi)
- return 0.5 * torch.sum(
- logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
- dim=dims,
- )
-
- def mode(self):
- return self.mean
-
-
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
- Compute the KL divergence between two gaussians.
- Shapes are automatically broadcasted, so batches can be compared to
- scalars, among other use cases.
- """
- tensor = None
- for obj in (mean1, logvar1, mean2, logvar2):
- if isinstance(obj, torch.Tensor):
- tensor = obj
- break
- assert tensor is not None, "at least one argument must be a Tensor"
-
- # Force variances to be Tensors. Broadcasting helps convert scalars to
- # Tensors, but it does not work for torch.exp().
- logvar1, logvar2 = [
- x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
- for x in (logvar1, logvar2)
- ]
-
- return 0.5 * (
- -1.0
- + logvar2
- - logvar1
- + torch.exp(logvar1 - logvar2)
- + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
- )
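As a sanity check on the math above, the `other=None` branch of `DiagonalGaussianDistribution.kl()` is the standard KL divergence to a unit Gaussian, averaged over the non-batch dimensions. A short sketch comparing that closed form against `torch.distributions` (the tensor shapes are arbitrary):

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mean = torch.randn(2, 8, 16, 4)
logvar = torch.randn(2, 8, 16, 4).clamp(-30.0, 20.0)
var, std = logvar.exp(), (0.5 * logvar).exp()

# closed form used by DiagonalGaussianDistribution.kl() when other is None
kl_closed = 0.5 * torch.mean(mean.pow(2) + var - 1.0 - logvar, dim=[1, 2, 3])

# reference: per-element KL( N(mean, std) || N(0, 1) ), same reduction
kl_ref = kl_divergence(Normal(mean, std), Normal(0.0, 1.0)).mean(dim=[1, 2, 3])

assert torch.allclose(kl_closed, kl_ref, atol=1e-5)
```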
diff --git a/audioldm/variational_autoencoder/modules.py b/audioldm/variational_autoencoder/modules.py
deleted file mode 100644
index 6b2c3dca2d168fb5fbaff5acc4b5a06280a496a7..0000000000000000000000000000000000000000
--- a/audioldm/variational_autoencoder/modules.py
+++ /dev/null
@@ -1,1064 +0,0 @@
-# pytorch_diffusion + derived encoder decoder
-import math
-import torch
-import torch.nn as nn
-import numpy as np
-from einops import rearrange
-
-from audioldm.utils import instantiate_from_config
-from audioldm.latent_diffusion.attention import LinearAttention
-
-def get_timestep_embedding(timesteps, embedding_dim):
- """
- Build sinusoidal timestep embeddings. This matches the implementation in
- Denoising Diffusion Probabilistic Models (and tensor2tensor/Fairseq), but
- differs slightly from the description in Section 3.5 of "Attention Is All You Need".
- """
- assert len(timesteps.shape) == 1
-
- half_dim = embedding_dim // 2
- emb = math.log(10000) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
- emb = emb.to(device=timesteps.device)
- emb = timesteps.float()[:, None] * emb[None, :]
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
- return emb
-
-def nonlinearity(x):
- # swish
- return x * torch.sigmoid(x)
-
-
-def Normalize(in_channels, num_groups=32):
- return torch.nn.GroupNorm(
- num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
- )
-
-
-class Upsample(nn.Module):
- def __init__(self, in_channels, with_conv):
- super().__init__()
- self.with_conv = with_conv
- if self.with_conv:
- self.conv = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=3, stride=1, padding=1
- )
-
- def forward(self, x):
- x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
- if self.with_conv:
- x = self.conv(x)
- return x
-
-
-class UpsampleTimeStride4(nn.Module):
- def __init__(self, in_channels, with_conv):
- super().__init__()
- self.with_conv = with_conv
- if self.with_conv:
- self.conv = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=5, stride=1, padding=2
- )
-
- def forward(self, x):
- x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
- if self.with_conv:
- x = self.conv(x)
- return x
-
-
-class Downsample(nn.Module):
- def __init__(self, in_channels, with_conv):
- super().__init__()
- self.with_conv = with_conv
- if self.with_conv:
- # Do time downsampling here
- # no asymmetric padding in torch conv, must do it ourselves
- self.conv = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=3, stride=2, padding=0
- )
-
- def forward(self, x):
- if self.with_conv:
- pad = (0, 1, 0, 1)
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
- x = self.conv(x)
- else:
- x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
- return x
-
-
-class DownsampleTimeStride4(nn.Module):
- def __init__(self, in_channels, with_conv):
- super().__init__()
- self.with_conv = with_conv
- if self.with_conv:
- # Do time downsampling here
- # no asymmetric padding in torch conv, must do it ourselves
- self.conv = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
- )
-
- def forward(self, x):
- if self.with_conv:
- pad = (0, 1, 0, 1)
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
- x = self.conv(x)
- else:
- x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
- return x
-
-
-class ResnetBlock(nn.Module):
- def __init__(
- self,
- *,
- in_channels,
- out_channels=None,
- conv_shortcut=False,
- dropout,
- temb_channels=512,
- ):
- super().__init__()
- self.in_channels = in_channels
- out_channels = in_channels if out_channels is None else out_channels
- self.out_channels = out_channels
- self.use_conv_shortcut = conv_shortcut
-
- self.norm1 = Normalize(in_channels)
- self.conv1 = torch.nn.Conv2d(
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
- )
- if temb_channels > 0:
- self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
- self.norm2 = Normalize(out_channels)
- self.dropout = torch.nn.Dropout(dropout)
- self.conv2 = torch.nn.Conv2d(
- out_channels, out_channels, kernel_size=3, stride=1, padding=1
- )
- if self.in_channels != self.out_channels:
- if self.use_conv_shortcut:
- self.conv_shortcut = torch.nn.Conv2d(
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
- )
- else:
- self.nin_shortcut = torch.nn.Conv2d(
- in_channels, out_channels, kernel_size=1, stride=1, padding=0
- )
-
- def forward(self, x, temb):
- h = x
- h = self.norm1(h)
- h = nonlinearity(h)
- h = self.conv1(h)
-
- if temb is not None:
- h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
-
- h = self.norm2(h)
- h = nonlinearity(h)
- h = self.dropout(h)
- h = self.conv2(h)
-
- if self.in_channels != self.out_channels:
- if self.use_conv_shortcut:
- x = self.conv_shortcut(x)
- else:
- x = self.nin_shortcut(x)
-
- return x + h
-
-
-class LinAttnBlock(LinearAttention):
- """to match AttnBlock usage"""
-
- def __init__(self, in_channels):
- super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
-
-
-class AttnBlock(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = Normalize(in_channels)
- self.q = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.k = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.v = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
- self.proj_out = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
- )
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q.shape
- q = q.reshape(b, c, h * w).contiguous()
- q = q.permute(0, 2, 1).contiguous() # b,hw,c
- k = k.reshape(b, c, h * w).contiguous() # b,c,hw
- w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
- w_ = w_ * (int(c) ** (-0.5))
- w_ = torch.nn.functional.softmax(w_, dim=2)
-
- # attend to values
- v = v.reshape(b, c, h * w).contiguous()
- w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
- h_ = torch.bmm(
- v, w_
- ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
- h_ = h_.reshape(b, c, h, w).contiguous()
-
- h_ = self.proj_out(h_)
-
- return x + h_
-
-
-def make_attn(in_channels, attn_type="vanilla"):
- assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
- # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
- if attn_type == "vanilla":
- return AttnBlock(in_channels)
- elif attn_type == "none":
- return nn.Identity(in_channels)
- else:
- return LinAttnBlock(in_channels)
-
-
-class Model(nn.Module):
- def __init__(
- self,
- *,
- ch,
- out_ch,
- ch_mult=(1, 2, 4, 8),
- num_res_blocks,
- attn_resolutions,
- dropout=0.0,
- resamp_with_conv=True,
- in_channels,
- resolution,
- use_timestep=True,
- use_linear_attn=False,
- attn_type="vanilla",
- ):
- super().__init__()
- if use_linear_attn:
- attn_type = "linear"
- self.ch = ch
- self.temb_ch = self.ch * 4
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.in_channels = in_channels
-
- self.use_timestep = use_timestep
- if self.use_timestep:
- # timestep embedding
- self.temb = nn.Module()
- self.temb.dense = nn.ModuleList(
- [
- torch.nn.Linear(self.ch, self.temb_ch),
- torch.nn.Linear(self.temb_ch, self.temb_ch),
- ]
- )
-
- # downsampling
- self.conv_in = torch.nn.Conv2d(
- in_channels, self.ch, kernel_size=3, stride=1, padding=1
- )
-
- curr_res = resolution
- in_ch_mult = (1,) + tuple(ch_mult)
- self.down = nn.ModuleList()
- for i_level in range(self.num_resolutions):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_in = ch * in_ch_mult[i_level]
- block_out = ch * ch_mult[i_level]
- for i_block in range(self.num_res_blocks):
- block.append(
- ResnetBlock(
- in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- )
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(make_attn(block_in, attn_type=attn_type))
- down = nn.Module()
- down.block = block
- down.attn = attn
- if i_level != self.num_resolutions - 1:
- down.downsample = Downsample(block_in, resamp_with_conv)
- curr_res = curr_res // 2
- self.down.append(down)
-
- # middle
- self.mid = nn.Module()
- self.mid.block_1 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
- self.mid.block_2 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
-
- # upsampling
- self.up = nn.ModuleList()
- for i_level in reversed(range(self.num_resolutions)):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_out = ch * ch_mult[i_level]
- skip_in = ch * ch_mult[i_level]
- for i_block in range(self.num_res_blocks + 1):
- if i_block == self.num_res_blocks:
- skip_in = ch * in_ch_mult[i_level]
- block.append(
- ResnetBlock(
- in_channels=block_in + skip_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- )
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(make_attn(block_in, attn_type=attn_type))
- up = nn.Module()
- up.block = block
- up.attn = attn
- if i_level != 0:
- up.upsample = Upsample(block_in, resamp_with_conv)
- curr_res = curr_res * 2
- self.up.insert(0, up) # prepend to get consistent order
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(
- block_in, out_ch, kernel_size=3, stride=1, padding=1
- )
-
- def forward(self, x, t=None, context=None):
- # assert x.shape[2] == x.shape[3] == self.resolution
- if context is not None:
- # assume aligned context, cat along channel axis
- x = torch.cat((x, context), dim=1)
- if self.use_timestep:
- # timestep embedding
- assert t is not None
- temb = get_timestep_embedding(t, self.ch)
- temb = self.temb.dense[0](temb)
- temb = nonlinearity(temb)
- temb = self.temb.dense[1](temb)
- else:
- temb = None
-
- # downsampling
- hs = [self.conv_in(x)]
- for i_level in range(self.num_resolutions):
- for i_block in range(self.num_res_blocks):
- h = self.down[i_level].block[i_block](hs[-1], temb)
- if len(self.down[i_level].attn) > 0:
- h = self.down[i_level].attn[i_block](h)
- hs.append(h)
- if i_level != self.num_resolutions - 1:
- hs.append(self.down[i_level].downsample(hs[-1]))
-
- # middle
- h = hs[-1]
- h = self.mid.block_1(h, temb)
- h = self.mid.attn_1(h)
- h = self.mid.block_2(h, temb)
-
- # upsampling
- for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks + 1):
- h = self.up[i_level].block[i_block](
- torch.cat([h, hs.pop()], dim=1), temb
- )
- if len(self.up[i_level].attn) > 0:
- h = self.up[i_level].attn[i_block](h)
- if i_level != 0:
- h = self.up[i_level].upsample(h)
-
- # end
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- return h
-
- def get_last_layer(self):
- return self.conv_out.weight
-
-
-class Encoder(nn.Module):
- def __init__(
- self,
- *,
- ch,
- out_ch,
- ch_mult=(1, 2, 4, 8),
- num_res_blocks,
- attn_resolutions,
- dropout=0.0,
- resamp_with_conv=True,
- in_channels,
- resolution,
- z_channels,
- double_z=True,
- use_linear_attn=False,
- attn_type="vanilla",
- downsample_time_stride4_levels=[],
- **ignore_kwargs,
- ):
- super().__init__()
- if use_linear_attn:
- attn_type = "linear"
- self.ch = ch
- self.temb_ch = 0
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.in_channels = in_channels
- self.downsample_time_stride4_levels = downsample_time_stride4_levels
-
- if len(self.downsample_time_stride4_levels) > 0:
- assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
- "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
- % str(self.num_resolutions)
- )
-
- # downsampling
- self.conv_in = torch.nn.Conv2d(
- in_channels, self.ch, kernel_size=3, stride=1, padding=1
- )
-
- curr_res = resolution
- in_ch_mult = (1,) + tuple(ch_mult)
- self.in_ch_mult = in_ch_mult
- self.down = nn.ModuleList()
- for i_level in range(self.num_resolutions):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_in = ch * in_ch_mult[i_level]
- block_out = ch * ch_mult[i_level]
- for i_block in range(self.num_res_blocks):
- block.append(
- ResnetBlock(
- in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- )
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(make_attn(block_in, attn_type=attn_type))
- down = nn.Module()
- down.block = block
- down.attn = attn
- if i_level != self.num_resolutions - 1:
- if i_level in self.downsample_time_stride4_levels:
- down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
- else:
- down.downsample = Downsample(block_in, resamp_with_conv)
- curr_res = curr_res // 2
- self.down.append(down)
-
- # middle
- self.mid = nn.Module()
- self.mid.block_1 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
- self.mid.block_2 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(
- block_in,
- 2 * z_channels if double_z else z_channels,
- kernel_size=3,
- stride=1,
- padding=1,
- )
-
- def forward(self, x):
- # timestep embedding
- temb = None
- # downsampling
- hs = [self.conv_in(x)]
- for i_level in range(self.num_resolutions):
- for i_block in range(self.num_res_blocks):
- h = self.down[i_level].block[i_block](hs[-1], temb)
- if len(self.down[i_level].attn) > 0:
- h = self.down[i_level].attn[i_block](h)
- hs.append(h)
- if i_level != self.num_resolutions - 1:
- hs.append(self.down[i_level].downsample(hs[-1]))
-
- # middle
- h = hs[-1]
- h = self.mid.block_1(h, temb)
- h = self.mid.attn_1(h)
- h = self.mid.block_2(h, temb)
-
- # end
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- return h
-
-
-class Decoder(nn.Module):
- def __init__(
- self,
- *,
- ch,
- out_ch,
- ch_mult=(1, 2, 4, 8),
- num_res_blocks,
- attn_resolutions,
- dropout=0.0,
- resamp_with_conv=True,
- in_channels,
- resolution,
- z_channels,
- give_pre_end=False,
- tanh_out=False,
- use_linear_attn=False,
- downsample_time_stride4_levels=[],
- attn_type="vanilla",
- **ignorekwargs,
- ):
- super().__init__()
- if use_linear_attn:
- attn_type = "linear"
- self.ch = ch
- self.temb_ch = 0
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.in_channels = in_channels
- self.give_pre_end = give_pre_end
- self.tanh_out = tanh_out
- self.downsample_time_stride4_levels = downsample_time_stride4_levels
-
- if len(self.downsample_time_stride4_levels) > 0:
- assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
- "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
- % str(self.num_resolutions)
- )
-
- # compute in_ch_mult, block_in and curr_res at lowest res
- in_ch_mult = (1,) + tuple(ch_mult)
- block_in = ch * ch_mult[self.num_resolutions - 1]
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
- self.z_shape = (1, z_channels, curr_res, curr_res)
- # print("Working with z of shape {} = {} dimensions.".format(
- # self.z_shape, np.prod(self.z_shape)))
-
- # z to block_in
- self.conv_in = torch.nn.Conv2d(
- z_channels, block_in, kernel_size=3, stride=1, padding=1
- )
-
- # middle
- self.mid = nn.Module()
- self.mid.block_1 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
- self.mid.block_2 = ResnetBlock(
- in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
-
- # upsampling
- self.up = nn.ModuleList()
- for i_level in reversed(range(self.num_resolutions)):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_out = ch * ch_mult[i_level]
- for i_block in range(self.num_res_blocks + 1):
- block.append(
- ResnetBlock(
- in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- )
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(make_attn(block_in, attn_type=attn_type))
- up = nn.Module()
- up.block = block
- up.attn = attn
- if i_level != 0:
- if i_level - 1 in self.downsample_time_stride4_levels:
- up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
- else:
- up.upsample = Upsample(block_in, resamp_with_conv)
- curr_res = curr_res * 2
- self.up.insert(0, up) # prepend to get consistent order
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(
- block_in, out_ch, kernel_size=3, stride=1, padding=1
- )
-
- def forward(self, z):
- # assert z.shape[1:] == self.z_shape[1:]
- self.last_z_shape = z.shape
-
- # timestep embedding
- temb = None
-
- # z to block_in
- h = self.conv_in(z)
-
- # middle
- h = self.mid.block_1(h, temb)
- h = self.mid.attn_1(h)
- h = self.mid.block_2(h, temb)
-
- # upsampling
- for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks + 1):
- h = self.up[i_level].block[i_block](h, temb)
- if len(self.up[i_level].attn) > 0:
- h = self.up[i_level].attn[i_block](h)
- if i_level != 0:
- h = self.up[i_level].upsample(h)
-
- # end
- if self.give_pre_end:
- return h
-
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- if self.tanh_out:
- h = torch.tanh(h)
- return h
-
-
-class SimpleDecoder(nn.Module):
- def __init__(self, in_channels, out_channels, *args, **kwargs):
- super().__init__()
- self.model = nn.ModuleList(
- [
- nn.Conv2d(in_channels, in_channels, 1),
- ResnetBlock(
- in_channels=in_channels,
- out_channels=2 * in_channels,
- temb_channels=0,
- dropout=0.0,
- ),
- ResnetBlock(
- in_channels=2 * in_channels,
- out_channels=4 * in_channels,
- temb_channels=0,
- dropout=0.0,
- ),
- ResnetBlock(
- in_channels=4 * in_channels,
- out_channels=2 * in_channels,
- temb_channels=0,
- dropout=0.0,
- ),
- nn.Conv2d(2 * in_channels, in_channels, 1),
- Upsample(in_channels, with_conv=True),
- ]
- )
- # end
- self.norm_out = Normalize(in_channels)
- self.conv_out = torch.nn.Conv2d(
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
- )
-
- def forward(self, x):
- for i, layer in enumerate(self.model):
- if i in [1, 2, 3]:
- x = layer(x, None)
- else:
- x = layer(x)
-
- h = self.norm_out(x)
- h = nonlinearity(h)
- x = self.conv_out(h)
- return x
-
-
-class UpsampleDecoder(nn.Module):
- def __init__(
- self,
- in_channels,
- out_channels,
- ch,
- num_res_blocks,
- resolution,
- ch_mult=(2, 2),
- dropout=0.0,
- ):
- super().__init__()
- # upsampling
- self.temb_ch = 0
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- block_in = in_channels
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
- self.res_blocks = nn.ModuleList()
- self.upsample_blocks = nn.ModuleList()
- for i_level in range(self.num_resolutions):
- res_block = []
- block_out = ch * ch_mult[i_level]
- for i_block in range(self.num_res_blocks + 1):
- res_block.append(
- ResnetBlock(
- in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout,
- )
- )
- block_in = block_out
- self.res_blocks.append(nn.ModuleList(res_block))
- if i_level != self.num_resolutions - 1:
- self.upsample_blocks.append(Upsample(block_in, True))
- curr_res = curr_res * 2
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(
- block_in, out_channels, kernel_size=3, stride=1, padding=1
- )
-
- def forward(self, x):
- # upsampling
- h = x
- for k, i_level in enumerate(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks + 1):
- h = self.res_blocks[i_level][i_block](h, None)
- if i_level != self.num_resolutions - 1:
- h = self.upsample_blocks[k](h)
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- return h
-
-
-class LatentRescaler(nn.Module):
- def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
- super().__init__()
- # residual block, interpolate, residual block
- self.factor = factor
- self.conv_in = nn.Conv2d(
- in_channels, mid_channels, kernel_size=3, stride=1, padding=1
- )
- self.res_block1 = nn.ModuleList(
- [
- ResnetBlock(
- in_channels=mid_channels,
- out_channels=mid_channels,
- temb_channels=0,
- dropout=0.0,
- )
- for _ in range(depth)
- ]
- )
- self.attn = AttnBlock(mid_channels)
- self.res_block2 = nn.ModuleList(
- [
- ResnetBlock(
- in_channels=mid_channels,
- out_channels=mid_channels,
- temb_channels=0,
- dropout=0.0,
- )
- for _ in range(depth)
- ]
- )
-
- self.conv_out = nn.Conv2d(
- mid_channels,
- out_channels,
- kernel_size=1,
- )
-
- def forward(self, x):
- x = self.conv_in(x)
- for block in self.res_block1:
- x = block(x, None)
- x = torch.nn.functional.interpolate(
- x,
- size=(
- int(round(x.shape[2] * self.factor)),
- int(round(x.shape[3] * self.factor)),
- ),
- )
- x = self.attn(x).contiguous()
- for block in self.res_block2:
- x = block(x, None)
- x = self.conv_out(x)
- return x
-
-
-class MergedRescaleEncoder(nn.Module):
- def __init__(
- self,
- in_channels,
- ch,
- resolution,
- out_ch,
- num_res_blocks,
- attn_resolutions,
- dropout=0.0,
- resamp_with_conv=True,
- ch_mult=(1, 2, 4, 8),
- rescale_factor=1.0,
- rescale_module_depth=1,
- ):
- super().__init__()
- intermediate_chn = ch * ch_mult[-1]
- self.encoder = Encoder(
- in_channels=in_channels,
- num_res_blocks=num_res_blocks,
- ch=ch,
- ch_mult=ch_mult,
- z_channels=intermediate_chn,
- double_z=False,
- resolution=resolution,
- attn_resolutions=attn_resolutions,
- dropout=dropout,
- resamp_with_conv=resamp_with_conv,
- out_ch=None,
- )
- self.rescaler = LatentRescaler(
- factor=rescale_factor,
- in_channels=intermediate_chn,
- mid_channels=intermediate_chn,
- out_channels=out_ch,
- depth=rescale_module_depth,
- )
-
- def forward(self, x):
- x = self.encoder(x)
- x = self.rescaler(x)
- return x
-
-
-class MergedRescaleDecoder(nn.Module):
- def __init__(
- self,
- z_channels,
- out_ch,
- resolution,
- num_res_blocks,
- attn_resolutions,
- ch,
- ch_mult=(1, 2, 4, 8),
- dropout=0.0,
- resamp_with_conv=True,
- rescale_factor=1.0,
- rescale_module_depth=1,
- ):
- super().__init__()
- tmp_chn = z_channels * ch_mult[-1]
- self.decoder = Decoder(
- out_ch=out_ch,
- z_channels=tmp_chn,
- attn_resolutions=attn_resolutions,
- dropout=dropout,
- resamp_with_conv=resamp_with_conv,
- in_channels=None,
- num_res_blocks=num_res_blocks,
- ch_mult=ch_mult,
- resolution=resolution,
- ch=ch,
- )
- self.rescaler = LatentRescaler(
- factor=rescale_factor,
- in_channels=z_channels,
- mid_channels=tmp_chn,
- out_channels=tmp_chn,
- depth=rescale_module_depth,
- )
-
- def forward(self, x):
- x = self.rescaler(x)
- x = self.decoder(x)
- return x
-
-
-class Upsampler(nn.Module):
- def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
- super().__init__()
- assert out_size >= in_size
- num_blocks = int(np.log2(out_size // in_size)) + 1
- factor_up = 1.0 + (out_size % in_size)
- print(
- f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
- )
- self.rescaler = LatentRescaler(
- factor=factor_up,
- in_channels=in_channels,
- mid_channels=2 * in_channels,
- out_channels=in_channels,
- )
- self.decoder = Decoder(
- out_ch=out_channels,
- resolution=out_size,
- z_channels=in_channels,
- num_res_blocks=2,
- attn_resolutions=[],
- in_channels=None,
- ch=in_channels,
- ch_mult=[ch_mult for _ in range(num_blocks)],
- )
-
- def forward(self, x):
- x = self.rescaler(x)
- x = self.decoder(x)
- return x
-
-
-class Resize(nn.Module):
- def __init__(self, in_channels=None, learned=False, mode="bilinear"):
- super().__init__()
- self.with_conv = learned
- self.mode = mode
- if self.with_conv:
- print(
- f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
- )
- raise NotImplementedError()
- assert in_channels is not None
- # no asymmetric padding in torch conv, must do it ourselves
- self.conv = torch.nn.Conv2d(
- in_channels, in_channels, kernel_size=4, stride=2, padding=1
- )
-
- def forward(self, x, scale_factor=1.0):
- if scale_factor == 1.0:
- return x
- else:
- x = torch.nn.functional.interpolate(
- x, mode=self.mode, align_corners=False, scale_factor=scale_factor
- )
- return x
-
-
-class FirstStagePostProcessor(nn.Module):
- def __init__(
- self,
- ch_mult: list,
- in_channels,
- pretrained_model: nn.Module = None,
- reshape=False,
- n_channels=None,
- dropout=0.0,
- pretrained_config=None,
- ):
- super().__init__()
- if pretrained_config is None:
- assert (
- pretrained_model is not None
- ), 'Either "pretrained_model" or "pretrained_config" must not be None'
- self.pretrained_model = pretrained_model
- else:
- assert (
- pretrained_config is not None
- ), 'Either "pretrained_model" or "pretrained_config" must not be None'
- self.instantiate_pretrained(pretrained_config)
-
- self.do_reshape = reshape
-
- if n_channels is None:
- n_channels = self.pretrained_model.encoder.ch
-
- self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
- self.proj = nn.Conv2d(
- in_channels, n_channels, kernel_size=3, stride=1, padding=1
- )
-
- blocks = []
- downs = []
- ch_in = n_channels
- for m in ch_mult:
- blocks.append(
- ResnetBlock(
- in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
- )
- )
- ch_in = m * n_channels
- downs.append(Downsample(ch_in, with_conv=False))
-
- self.model = nn.ModuleList(blocks)
- self.downsampler = nn.ModuleList(downs)
-
- def instantiate_pretrained(self, config):
- model = instantiate_from_config(config)
- self.pretrained_model = model.eval()
- # self.pretrained_model.train = False
- for param in self.pretrained_model.parameters():
- param.requires_grad = False
-
- @torch.no_grad()
- def encode_with_pretrained(self, x):
- c = self.pretrained_model.encode(x)
- if isinstance(c, DiagonalGaussianDistribution):
- c = c.mode()
- return c
-
- def forward(self, x):
- z_fs = self.encode_with_pretrained(x)
- z = self.proj_norm(z_fs)
- z = self.proj(z)
- z = nonlinearity(z)
-
- for submodel, downmodel in zip(self.model, self.downsampler):
- z = submodel(z, temb=None)
- z = downmodel(z)
-
- if self.do_reshape:
- z = rearrange(z, "b c h w -> b (h w) c")
- return z
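Tying the deleted `Encoder`/`Decoder` back to the default config in `audioldm/utils.py`: downsampling happens at every level except the last, so `ch_mult: [1, 2, 4]` halves both time and frequency twice. A small sketch (the helper name is illustrative) showing how the 1024x64 mel input maps to the 8x256x16 latent:

```python
def latent_shape(ch_mult, embed_dim, mel_frames, mel_bins):
    # the Encoder downsamples at every level except the last,
    # halving both the time and frequency axes each time
    n_down = len(ch_mult) - 1
    return (embed_dim, mel_frames // 2**n_down, mel_bins // 2**n_down)

# default AudioLDM VAE settings: ch_mult [1, 2, 4], embed_dim 8, 1024x64 mel input
print(latent_shape([1, 2, 4], embed_dim=8, mel_frames=1024, mel_bins=64))
# (8, 256, 16) -- matching channels=8, latent_t_size=256, latent_f_size=16
```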
diff --git a/ckpt/audioldm-m-full.ckpt b/ckpt/audioldm-m-full.ckpt
deleted file mode 100644
index a9cd836fd57622ef7d8779f3e1fcd7d0b0bd0a29..0000000000000000000000000000000000000000
--- a/ckpt/audioldm-m-full.ckpt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:936914a388905e1fc179c148a41a2b1552dba322ce474160b1cfa0f01ac26f8f
-size 4571683377
diff --git a/ckpt/audioldm-m-text-ft.ckpt b/ckpt/audioldm-m-text-ft.ckpt
deleted file mode 100644
index f97680915ef7ca756042d736938975dc23e69e39..0000000000000000000000000000000000000000
--- a/ckpt/audioldm-m-text-ft.ckpt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d77d5a61785af82012edb8a72158d52592ac7c76d7f6ed51a048ec2dec8d5eca
-size 4571676474
diff --git a/ckpt/audioldm-s-text-ft.ckpt b/ckpt/audioldm-s-text-ft.ckpt
deleted file mode 100644
index 171cbed5e1a076e809b0a525e8c88651164e46c8..0000000000000000000000000000000000000000
--- a/ckpt/audioldm-s-text-ft.ckpt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:62075af09973ab50f158b213acc60347a330737c1d827c5dead61af60bfea706
-size 2558980807
diff --git a/ckpt/ldm_trimmed.ckpt b/ckpt/ldm_trimmed.ckpt
deleted file mode 100644
index 5925f8aca1f0bf7be45276a956f056ac6c2aa84e..0000000000000000000000000000000000000000
--- a/ckpt/ldm_trimmed.ckpt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:8f8d1923410622be823279b61967d27a2df3fd03ddd764afb298e7c20ef8877d
-size 2558947469
diff --git a/requirements.txt b/requirements.txt
index 93174b5dd3fad9851e3963c0aca028c811675d50..2637b318ea649c1c45532155c716e2449264d81e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,17 +1,4 @@
git+https://github.com/huggingface/diffusers.git
+git+https://github.com/huggingface/transformers.git
--extra-index-url https://download.pytorch.org/whl/cu113
-torch
-scipy
-torchaudio>=0.13.0
-torchvision>=0.14.0
-tqdm
-pyyaml
-einops
-numpy<=1.23.5
-soundfile
-librosa
-pandas
-# transformers
-torchlibrosa
-transformers
-ftfy
\ No newline at end of file
+torch >= 2.0
\ No newline at end of file
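The trimmed requirements leave only `diffusers`, `transformers` (both installed from git) and `torch >= 2.0`. A hedged sketch, not part of the repo, for checking that an environment satisfies the torch floor before running the Space:

```python
import torch

major = int(torch.__version__.split(".")[0])
assert major >= 2, f"torch >= 2.0 is required, found {torch.__version__}"
assert hasattr(torch, "compile"), "this torch build does not expose torch.compile"
print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
```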