hungchiayu1 commited on
Commit
ffead1e
1 Parent(s): e7af757

initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +5 -5
  2. app.py +140 -0
  3. audioldm/__init__.py +8 -0
  4. audioldm/__main__.py +183 -0
  5. audioldm/__pycache__/__init__.cpython-310.pyc +0 -0
  6. audioldm/__pycache__/__init__.cpython-39.pyc +0 -0
  7. audioldm/__pycache__/ldm.cpython-310.pyc +0 -0
  8. audioldm/__pycache__/ldm.cpython-39.pyc +0 -0
  9. audioldm/__pycache__/pipeline.cpython-310.pyc +0 -0
  10. audioldm/__pycache__/pipeline.cpython-39.pyc +0 -0
  11. audioldm/__pycache__/utils.cpython-310.pyc +0 -0
  12. audioldm/__pycache__/utils.cpython-39.pyc +0 -0
  13. audioldm/audio/__init__.py +2 -0
  14. audioldm/audio/__pycache__/__init__.cpython-310.pyc +0 -0
  15. audioldm/audio/__pycache__/__init__.cpython-39.pyc +0 -0
  16. audioldm/audio/__pycache__/audio_processing.cpython-310.pyc +0 -0
  17. audioldm/audio/__pycache__/audio_processing.cpython-39.pyc +0 -0
  18. audioldm/audio/__pycache__/mix.cpython-39.pyc +0 -0
  19. audioldm/audio/__pycache__/stft.cpython-310.pyc +0 -0
  20. audioldm/audio/__pycache__/stft.cpython-39.pyc +0 -0
  21. audioldm/audio/__pycache__/tools.cpython-310.pyc +0 -0
  22. audioldm/audio/__pycache__/tools.cpython-39.pyc +0 -0
  23. audioldm/audio/__pycache__/torch_tools.cpython-39.pyc +0 -0
  24. audioldm/audio/audio_processing.py +100 -0
  25. audioldm/audio/stft.py +186 -0
  26. audioldm/audio/tools.py +85 -0
  27. audioldm/hifigan/__init__.py +7 -0
  28. audioldm/hifigan/__pycache__/__init__.cpython-310.pyc +0 -0
  29. audioldm/hifigan/__pycache__/__init__.cpython-39.pyc +0 -0
  30. audioldm/hifigan/__pycache__/models.cpython-310.pyc +0 -0
  31. audioldm/hifigan/__pycache__/models.cpython-39.pyc +0 -0
  32. audioldm/hifigan/__pycache__/utilities.cpython-310.pyc +0 -0
  33. audioldm/hifigan/__pycache__/utilities.cpython-39.pyc +0 -0
  34. audioldm/hifigan/models.py +174 -0
  35. audioldm/hifigan/utilities.py +86 -0
  36. audioldm/latent_diffusion/__init__.py +0 -0
  37. audioldm/latent_diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
  38. audioldm/latent_diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
  39. audioldm/latent_diffusion/__pycache__/attention.cpython-310.pyc +0 -0
  40. audioldm/latent_diffusion/__pycache__/attention.cpython-39.pyc +0 -0
  41. audioldm/latent_diffusion/__pycache__/ddim.cpython-310.pyc +0 -0
  42. audioldm/latent_diffusion/__pycache__/ddim.cpython-39.pyc +0 -0
  43. audioldm/latent_diffusion/__pycache__/ddpm.cpython-310.pyc +0 -0
  44. audioldm/latent_diffusion/__pycache__/ddpm.cpython-39.pyc +0 -0
  45. audioldm/latent_diffusion/__pycache__/ema.cpython-310.pyc +0 -0
  46. audioldm/latent_diffusion/__pycache__/ema.cpython-39.pyc +0 -0
  47. audioldm/latent_diffusion/__pycache__/openaimodel.cpython-39.pyc +0 -0
  48. audioldm/latent_diffusion/__pycache__/util.cpython-310.pyc +0 -0
  49. audioldm/latent_diffusion/__pycache__/util.cpython-39.pyc +0 -0
  50. audioldm/latent_diffusion/attention.py +469 -0
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Tango2
3
- emoji: 👀
4
- colorFrom: blue
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.26.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
+ title: Tango
3
+ emoji: 🐠
4
+ colorFrom: indigo
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 3.28.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import torch
4
+ import wavio
5
+ from tqdm import tqdm
6
+ from huggingface_hub import snapshot_download
7
+ from models import AudioDiffusion, DDPMScheduler
8
+ from audioldm.audio.stft import TacotronSTFT
9
+ from audioldm.variational_autoencoder import AutoencoderKL
10
+ from gradio import Markdown
11
+
12
+ class Tango:
13
+ def __init__(self, name="declare-lab/tango2", device="cuda:0"):
14
+
15
+ path = snapshot_download(repo_id=name)
16
+
17
+ vae_config = json.load(open("{}/vae_config.json".format(path)))
18
+ stft_config = json.load(open("{}/stft_config.json".format(path)))
19
+ main_config = json.load(open("{}/main_config.json".format(path)))
20
+
21
+ self.vae = AutoencoderKL(**vae_config).to(device)
22
+ self.stft = TacotronSTFT(**stft_config).to(device)
23
+ self.model = AudioDiffusion(**main_config).to(device)
24
+
25
+ vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location=device)
26
+ stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location=device)
27
+ main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location=device)
28
+
29
+ self.vae.load_state_dict(vae_weights)
30
+ self.stft.load_state_dict(stft_weights)
31
+ self.model.load_state_dict(main_weights)
32
+
33
+ print ("Successfully loaded checkpoint from:", name)
34
+
35
+ self.vae.eval()
36
+ self.stft.eval()
37
+ self.model.eval()
38
+
39
+ self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder="scheduler")
40
+
41
+ def chunks(self, lst, n):
42
+ """ Yield successive n-sized chunks from a list. """
43
+ for i in range(0, len(lst), n):
44
+ yield lst[i:i + n]
45
+
46
+ def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
47
+ """ Genrate audio for a single prompt string. """
48
+ with torch.no_grad():
49
+ latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
50
+ mel = self.vae.decode_first_stage(latents)
51
+ wave = self.vae.decode_to_waveform(mel)
52
+ return wave[0]
53
+
54
+ def generate_for_batch(self, prompts, steps=200, guidance=3, samples=1, batch_size=8, disable_progress=True):
55
+ """ Genrate audio for a list of prompt strings. """
56
+ outputs = []
57
+ for k in tqdm(range(0, len(prompts), batch_size)):
58
+ batch = prompts[k: k+batch_size]
59
+ with torch.no_grad():
60
+ latents = self.model.inference(batch, self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
61
+ mel = self.vae.decode_first_stage(latents)
62
+ wave = self.vae.decode_to_waveform(mel)
63
+ outputs += [item for item in wave]
64
+ if samples == 1:
65
+ return outputs
66
+ else:
67
+ return list(self.chunks(outputs, samples))
68
+
69
+ # Initialize TANGO
70
+ if torch.cuda.is_available():
71
+ tango = Tango()
72
+ else:
73
+ tango = Tango(device="cpu")
74
+
75
+ def gradio_generate(prompt, steps, guidance):
76
+ output_wave = tango.generate(prompt, steps, guidance)
77
+ # output_filename = f"{prompt.replace(' ', '_')}_{steps}_{guidance}"[:250] + ".wav"
78
+ output_filename = "temp.wav"
79
+ wavio.write(output_filename, output_wave, rate=16000, sampwidth=2)
80
+
81
+ return output_filename
82
+
83
+ # description_text = """
84
+ # <p><a href="https://huggingface.co/spaces/declare-lab/tango/blob/main/app.py?duplicate=true"> <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> For faster inference without waiting in queue, you may duplicate the space and upgrade to a GPU in the settings. <br/><br/>
85
+ # Generate audio using TANGO by providing a text prompt.
86
+ # <br/><br/>Limitations: TANGO is trained on the small AudioCaps dataset so it may not generate good audio \
87
+ # samples related to concepts that it has not seen in training (e.g. singing). For the same reason, TANGO \
88
+ # is not always able to finely control its generations over textual control prompts. For example, \
89
+ # the generations from TANGO for prompts Chopping tomatoes on a wooden table and Chopping potatoes \
90
+ # on a metal table are very similar. \
91
+ # <br/><br/>We are currently training another version of TANGO on larger datasets to enhance its generalization, \
92
+ # compositional and controllable generation ability.
93
+ # <br/><br/>We recommend using a guidance scale of 3. The default number of steps is set to 100. More steps generally lead to better quality of generated audios but will take longer.
94
+ # <br/><br/>
95
+ # <h1> ChatGPT-enhanced audio generation</h1>
96
+ # <br/>
97
+ # As TANGO consists of an instruction-tuned LLM, it is able to process complex sound descriptions allowing us to provide more detailed instructions to improve the generation quality.
98
+ # For example, ``A boat is moving on the sea'' vs ``The sound of the water lapping against the hull of the boat or splashing as you move through the waves''. The latter is obtained by prompting ChatGPT to explain the sound generated when a boat moves on the sea.
99
+ # Using this ChatGPT-generated description of the sound, TANGO provides superior results.
100
+ # <p/>
101
+ # """
102
+ description_text = ""
103
+ # Gradio input and output components
104
+ input_text = gr.Textbox(lines=2, label="Prompt")
105
+ output_audio = gr.Audio(label="Generated Audio", type="filepath")
106
+ denoising_steps = gr.Slider(minimum=100, maximum=200, value=100, step=1, label="Steps", interactive=True)
107
+ guidance_scale = gr.Slider(minimum=1, maximum=10, value=3, step=0.1, label="Guidance Scale", interactive=True)
108
+
109
+ # Gradio interface
110
+ gr_interface = gr.Interface(
111
+ fn=gradio_generate,
112
+ inputs=[input_text, denoising_steps, guidance_scale],
113
+ outputs=[output_audio],
114
+ title="TANGO2: Aligning Diffusion-based Text-to-Audio Generative Models through Direct Preference Optimization",
115
+ description=description_text,
116
+ allow_flagging=False,
117
+ examples=[
118
+ ["A lady is singing a song with a kid"],
119
+ ["The sound of the water lapping against the hull of the boat or splashing as you move through the waves"],
120
+ ["An audience cheering and clapping"],
121
+ ["Rolling thunder with lightning strikes"],
122
+ ["Gentle water stream, birds chirping and sudden gun shot"],
123
+ ["A car engine revving"],
124
+ ["A dog barking"],
125
+ ["A cat meowing"],
126
+ ["Wooden table tapping sound while water pouring"],
127
+ ["Emergency sirens wailing"],
128
+ ["two gunshots followed by birds flying away while chirping"],
129
+ ["Whistling with birds chirping"],
130
+ ["A person snoring"],
131
+ ["Motor vehicles are driving with loud engines and a person whistles"],
132
+ ["People cheering in a stadium while thunder and lightning strikes"],
133
+ ["A helicopter is in flight"],
134
+ ["A dog barking and a man talking and a racing car passes by"],
135
+ ],
136
+ cache_examples=False, # Turn on to cache.
137
+ )
138
+
139
+ # Launch Gradio app
140
+ gr_interface.launch()
audioldm/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from .ldm import LatentDiffusion
2
+ from .utils import seed_everything, save_wave, get_time, get_duration
3
+ from .pipeline import *
4
+
5
+
6
+
7
+
8
+
audioldm/__main__.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ import os
3
+ from audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time, round_up_duration, get_duration
4
+ import argparse
5
+
6
+ CACHE_DIR = os.getenv(
7
+ "AUDIOLDM_CACHE_DIR",
8
+ os.path.join(os.path.expanduser("~"), ".cache/audioldm"))
9
+
10
+ parser = argparse.ArgumentParser()
11
+
12
+ parser.add_argument(
13
+ "--mode",
14
+ type=str,
15
+ required=False,
16
+ default="generation",
17
+ help="generation: text-to-audio generation; transfer: style transfer",
18
+ choices=["generation", "transfer"]
19
+ )
20
+
21
+ parser.add_argument(
22
+ "-t",
23
+ "--text",
24
+ type=str,
25
+ required=False,
26
+ default="",
27
+ help="Text prompt to the model for audio generation",
28
+ )
29
+
30
+ parser.add_argument(
31
+ "-f",
32
+ "--file_path",
33
+ type=str,
34
+ required=False,
35
+ default=None,
36
+ help="(--mode transfer): Original audio file for style transfer; Or (--mode generation): the guidance audio file for generating simialr audio",
37
+ )
38
+
39
+ parser.add_argument(
40
+ "--transfer_strength",
41
+ type=float,
42
+ required=False,
43
+ default=0.5,
44
+ help="A value between 0 and 1. 0 means original audio without transfer, 1 means completely transfer to the audio indicated by text",
45
+ )
46
+
47
+ parser.add_argument(
48
+ "-s",
49
+ "--save_path",
50
+ type=str,
51
+ required=False,
52
+ help="The path to save model output",
53
+ default="./output",
54
+ )
55
+
56
+ parser.add_argument(
57
+ "--model_name",
58
+ type=str,
59
+ required=False,
60
+ help="The checkpoint you gonna use",
61
+ default="audioldm-s-full",
62
+ choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2"]
63
+ )
64
+
65
+ parser.add_argument(
66
+ "-ckpt",
67
+ "--ckpt_path",
68
+ type=str,
69
+ required=False,
70
+ help="The path to the pretrained .ckpt model",
71
+ default=None,
72
+ )
73
+
74
+ parser.add_argument(
75
+ "-b",
76
+ "--batchsize",
77
+ type=int,
78
+ required=False,
79
+ default=1,
80
+ help="Generate how many samples at the same time",
81
+ )
82
+
83
+ parser.add_argument(
84
+ "--ddim_steps",
85
+ type=int,
86
+ required=False,
87
+ default=200,
88
+ help="The sampling step for DDIM",
89
+ )
90
+
91
+ parser.add_argument(
92
+ "-gs",
93
+ "--guidance_scale",
94
+ type=float,
95
+ required=False,
96
+ default=2.5,
97
+ help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
98
+ )
99
+
100
+ parser.add_argument(
101
+ "-dur",
102
+ "--duration",
103
+ type=float,
104
+ required=False,
105
+ default=10.0,
106
+ help="The duration of the samples",
107
+ )
108
+
109
+ parser.add_argument(
110
+ "-n",
111
+ "--n_candidate_gen_per_text",
112
+ type=int,
113
+ required=False,
114
+ default=3,
115
+ help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
116
+ )
117
+
118
+ parser.add_argument(
119
+ "--seed",
120
+ type=int,
121
+ required=False,
122
+ default=42,
123
+ help="Change this value (any integer number) will lead to a different generation result.",
124
+ )
125
+
126
+ args = parser.parse_args()
127
+
128
+ if(args.ckpt_path is not None):
129
+ print("Warning: ckpt_path has no effect after version 0.0.20.")
130
+
131
+ assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5"
132
+
133
+ mode = args.mode
134
+ if(mode == "generation" and args.file_path is not None):
135
+ mode = "generation_audio_to_audio"
136
+ if(len(args.text) > 0):
137
+ print("Warning: You have specified the --file_path. --text will be ignored")
138
+ args.text = ""
139
+
140
+ save_path = os.path.join(args.save_path, mode)
141
+
142
+ if(args.file_path is not None):
143
+ save_path = os.path.join(save_path, os.path.basename(args.file_path.split(".")[0]))
144
+
145
+ text = args.text
146
+ random_seed = args.seed
147
+ duration = args.duration
148
+ guidance_scale = args.guidance_scale
149
+ n_candidate_gen_per_text = args.n_candidate_gen_per_text
150
+
151
+ os.makedirs(save_path, exist_ok=True)
152
+ audioldm = build_model(model_name=args.model_name)
153
+
154
+ if(args.mode == "generation"):
155
+ waveform = text_to_audio(
156
+ audioldm,
157
+ text,
158
+ args.file_path,
159
+ random_seed,
160
+ duration=duration,
161
+ guidance_scale=guidance_scale,
162
+ ddim_steps=args.ddim_steps,
163
+ n_candidate_gen_per_text=n_candidate_gen_per_text,
164
+ batchsize=args.batchsize,
165
+ )
166
+
167
+ elif(args.mode == "transfer"):
168
+ assert args.file_path is not None
169
+ assert os.path.exists(args.file_path), "The original audio file \'%s\' for style transfer does not exist." % args.file_path
170
+ waveform = style_transfer(
171
+ audioldm,
172
+ text,
173
+ args.file_path,
174
+ args.transfer_strength,
175
+ random_seed,
176
+ duration=duration,
177
+ guidance_scale=guidance_scale,
178
+ ddim_steps=args.ddim_steps,
179
+ batchsize=args.batchsize,
180
+ )
181
+ waveform = waveform[:,None,:]
182
+
183
+ save_wave(waveform, save_path, name="%s_%s" % (get_time(), text))
audioldm/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (315 Bytes). View file
 
audioldm/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (322 Bytes). View file
 
audioldm/__pycache__/ldm.cpython-310.pyc ADDED
Binary file (16.1 kB). View file
 
audioldm/__pycache__/ldm.cpython-39.pyc ADDED
Binary file (16 kB). View file
 
audioldm/__pycache__/pipeline.cpython-310.pyc ADDED
Binary file (6.63 kB). View file
 
audioldm/__pycache__/pipeline.cpython-39.pyc ADDED
Binary file (6.54 kB). View file
 
audioldm/__pycache__/utils.cpython-310.pyc ADDED
Binary file (8.01 kB). View file
 
audioldm/__pycache__/utils.cpython-39.pyc ADDED
Binary file (7.35 kB). View file
 
audioldm/audio/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .tools import wav_to_fbank, read_wav_file
2
+ from .stft import TacotronSTFT
audioldm/audio/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (253 Bytes). View file
 
audioldm/audio/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (260 Bytes). View file
 
audioldm/audio/__pycache__/audio_processing.cpython-310.pyc ADDED
Binary file (2.78 kB). View file
 
audioldm/audio/__pycache__/audio_processing.cpython-39.pyc ADDED
Binary file (2.78 kB). View file
 
audioldm/audio/__pycache__/mix.cpython-39.pyc ADDED
Binary file (1.7 kB). View file
 
audioldm/audio/__pycache__/stft.cpython-310.pyc ADDED
Binary file (4.98 kB). View file
 
audioldm/audio/__pycache__/stft.cpython-39.pyc ADDED
Binary file (4.99 kB). View file
 
audioldm/audio/__pycache__/tools.cpython-310.pyc ADDED
Binary file (2.18 kB). View file
 
audioldm/audio/__pycache__/tools.cpython-39.pyc ADDED
Binary file (2.19 kB). View file
 
audioldm/audio/__pycache__/torch_tools.cpython-39.pyc ADDED
Binary file (3.79 kB). View file
 
audioldm/audio/audio_processing.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import librosa.util as librosa_util
4
+ from scipy.signal import get_window
5
+
6
+
7
+ def window_sumsquare(
8
+ window,
9
+ n_frames,
10
+ hop_length,
11
+ win_length,
12
+ n_fft,
13
+ dtype=np.float32,
14
+ norm=None,
15
+ ):
16
+ """
17
+ # from librosa 0.6
18
+ Compute the sum-square envelope of a window function at a given hop length.
19
+
20
+ This is used to estimate modulation effects induced by windowing
21
+ observations in short-time fourier transforms.
22
+
23
+ Parameters
24
+ ----------
25
+ window : string, tuple, number, callable, or list-like
26
+ Window specification, as in `get_window`
27
+
28
+ n_frames : int > 0
29
+ The number of analysis frames
30
+
31
+ hop_length : int > 0
32
+ The number of samples to advance between frames
33
+
34
+ win_length : [optional]
35
+ The length of the window function. By default, this matches `n_fft`.
36
+
37
+ n_fft : int > 0
38
+ The length of each analysis frame.
39
+
40
+ dtype : np.dtype
41
+ The data type of the output
42
+
43
+ Returns
44
+ -------
45
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
46
+ The sum-squared envelope of the window function
47
+ """
48
+ if win_length is None:
49
+ win_length = n_fft
50
+
51
+ n = n_fft + hop_length * (n_frames - 1)
52
+ x = np.zeros(n, dtype=dtype)
53
+
54
+ # Compute the squared window at the desired length
55
+ win_sq = get_window(window, win_length, fftbins=True)
56
+ win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
57
+ win_sq = librosa_util.pad_center(win_sq, n_fft)
58
+
59
+ # Fill the envelope
60
+ for i in range(n_frames):
61
+ sample = i * hop_length
62
+ x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
63
+ return x
64
+
65
+
66
+ def griffin_lim(magnitudes, stft_fn, n_iters=30):
67
+ """
68
+ PARAMS
69
+ ------
70
+ magnitudes: spectrogram magnitudes
71
+ stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
72
+ """
73
+
74
+ angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
75
+ angles = angles.astype(np.float32)
76
+ angles = torch.autograd.Variable(torch.from_numpy(angles))
77
+ signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
78
+
79
+ for i in range(n_iters):
80
+ _, angles = stft_fn.transform(signal)
81
+ signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
82
+ return signal
83
+
84
+
85
+ def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
86
+ """
87
+ PARAMS
88
+ ------
89
+ C: compression factor
90
+ """
91
+ return normalize_fun(torch.clamp(x, min=clip_val) * C)
92
+
93
+
94
+ def dynamic_range_decompression(x, C=1):
95
+ """
96
+ PARAMS
97
+ ------
98
+ C: compression factor used to compress
99
+ """
100
+ return torch.exp(x) / C
audioldm/audio/stft.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from scipy.signal import get_window
5
+ from librosa.util import pad_center, tiny
6
+ from librosa.filters import mel as librosa_mel_fn
7
+
8
+ from audioldm.audio.audio_processing import (
9
+ dynamic_range_compression,
10
+ dynamic_range_decompression,
11
+ window_sumsquare,
12
+ )
13
+
14
+
15
+ class STFT(torch.nn.Module):
16
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
17
+
18
+ def __init__(self, filter_length, hop_length, win_length, window="hann"):
19
+ super(STFT, self).__init__()
20
+ self.filter_length = filter_length
21
+ self.hop_length = hop_length
22
+ self.win_length = win_length
23
+ self.window = window
24
+ self.forward_transform = None
25
+ scale = self.filter_length / self.hop_length
26
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
27
+
28
+ cutoff = int((self.filter_length / 2 + 1))
29
+ fourier_basis = np.vstack(
30
+ [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
31
+ )
32
+
33
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
34
+ inverse_basis = torch.FloatTensor(
35
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :]
36
+ )
37
+
38
+ if window is not None:
39
+ assert filter_length >= win_length
40
+ # get window and zero center pad it to filter_length
41
+ fft_window = get_window(window, win_length, fftbins=True)
42
+ fft_window = pad_center(fft_window, filter_length)
43
+ fft_window = torch.from_numpy(fft_window).float()
44
+
45
+ # window the bases
46
+ forward_basis *= fft_window
47
+ inverse_basis *= fft_window
48
+
49
+ self.register_buffer("forward_basis", forward_basis.float())
50
+ self.register_buffer("inverse_basis", inverse_basis.float())
51
+
52
+ def transform(self, input_data):
53
+ device = self.forward_basis.device
54
+ input_data = input_data.to(device)
55
+
56
+ num_batches = input_data.size(0)
57
+ num_samples = input_data.size(1)
58
+
59
+ self.num_samples = num_samples
60
+
61
+ # similar to librosa, reflect-pad the input
62
+ input_data = input_data.view(num_batches, 1, num_samples)
63
+ input_data = F.pad(
64
+ input_data.unsqueeze(1),
65
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
66
+ mode="reflect",
67
+ )
68
+ input_data = input_data.squeeze(1)
69
+
70
+ forward_transform = F.conv1d(
71
+ input_data,
72
+ torch.autograd.Variable(self.forward_basis, requires_grad=False),
73
+ stride=self.hop_length,
74
+ padding=0,
75
+ )#.cpu()
76
+
77
+ cutoff = int((self.filter_length / 2) + 1)
78
+ real_part = forward_transform[:, :cutoff, :]
79
+ imag_part = forward_transform[:, cutoff:, :]
80
+
81
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
82
+ phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
83
+
84
+ return magnitude, phase
85
+
86
+ def inverse(self, magnitude, phase):
87
+ device = self.forward_basis.device
88
+ magnitude, phase = magnitude.to(device), phase.to(device)
89
+
90
+ recombine_magnitude_phase = torch.cat(
91
+ [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
92
+ )
93
+
94
+ inverse_transform = F.conv_transpose1d(
95
+ recombine_magnitude_phase,
96
+ torch.autograd.Variable(self.inverse_basis, requires_grad=False),
97
+ stride=self.hop_length,
98
+ padding=0,
99
+ )
100
+
101
+ if self.window is not None:
102
+ window_sum = window_sumsquare(
103
+ self.window,
104
+ magnitude.size(-1),
105
+ hop_length=self.hop_length,
106
+ win_length=self.win_length,
107
+ n_fft=self.filter_length,
108
+ dtype=np.float32,
109
+ )
110
+ # remove modulation effects
111
+ approx_nonzero_indices = torch.from_numpy(
112
+ np.where(window_sum > tiny(window_sum))[0]
113
+ )
114
+ window_sum = torch.autograd.Variable(
115
+ torch.from_numpy(window_sum), requires_grad=False
116
+ )
117
+ window_sum = window_sum
118
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
119
+ approx_nonzero_indices
120
+ ]
121
+
122
+ # scale by hop ratio
123
+ inverse_transform *= float(self.filter_length) / self.hop_length
124
+
125
+ inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
126
+ inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
127
+
128
+ return inverse_transform
129
+
130
+ def forward(self, input_data):
131
+ self.magnitude, self.phase = self.transform(input_data)
132
+ reconstruction = self.inverse(self.magnitude, self.phase)
133
+ return reconstruction
134
+
135
+
136
+ class TacotronSTFT(torch.nn.Module):
137
+ def __init__(
138
+ self,
139
+ filter_length,
140
+ hop_length,
141
+ win_length,
142
+ n_mel_channels,
143
+ sampling_rate,
144
+ mel_fmin,
145
+ mel_fmax,
146
+ ):
147
+ super(TacotronSTFT, self).__init__()
148
+ self.n_mel_channels = n_mel_channels
149
+ self.sampling_rate = sampling_rate
150
+ self.stft_fn = STFT(filter_length, hop_length, win_length)
151
+ mel_basis = librosa_mel_fn(
152
+ sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
153
+ )
154
+ mel_basis = torch.from_numpy(mel_basis).float()
155
+ self.register_buffer("mel_basis", mel_basis)
156
+
157
+ def spectral_normalize(self, magnitudes, normalize_fun):
158
+ output = dynamic_range_compression(magnitudes, normalize_fun)
159
+ return output
160
+
161
+ def spectral_de_normalize(self, magnitudes):
162
+ output = dynamic_range_decompression(magnitudes)
163
+ return output
164
+
165
+ def mel_spectrogram(self, y, normalize_fun=torch.log):
166
+ """Computes mel-spectrograms from a batch of waves
167
+ PARAMS
168
+ ------
169
+ y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
170
+
171
+ RETURNS
172
+ -------
173
+ mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
174
+ """
175
+ assert torch.min(y.data) >= -1, torch.min(y.data)
176
+ assert torch.max(y.data) <= 1, torch.max(y.data)
177
+
178
+ magnitudes, phases = self.stft_fn.transform(y)
179
+ magnitudes = magnitudes.data
180
+ mel_output = torch.matmul(self.mel_basis, magnitudes)
181
+ mel_output = self.spectral_normalize(mel_output, normalize_fun)
182
+ energy = torch.norm(magnitudes, dim=1)
183
+
184
+ log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
185
+
186
+ return mel_output, log_magnitudes, energy
audioldm/audio/tools.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torchaudio
4
+
5
+
6
+ def get_mel_from_wav(audio, _stft):
7
+ audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
8
+ audio = torch.autograd.Variable(audio, requires_grad=False)
9
+ melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
10
+ melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
11
+ log_magnitudes_stft = (
12
+ torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32)
13
+ )
14
+ energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
15
+ return melspec, log_magnitudes_stft, energy
16
+
17
+
18
+ def _pad_spec(fbank, target_length=1024):
19
+ n_frames = fbank.shape[0]
20
+ p = target_length - n_frames
21
+ # cut and pad
22
+ if p > 0:
23
+ m = torch.nn.ZeroPad2d((0, 0, 0, p))
24
+ fbank = m(fbank)
25
+ elif p < 0:
26
+ fbank = fbank[0:target_length, :]
27
+
28
+ if fbank.size(-1) % 2 != 0:
29
+ fbank = fbank[..., :-1]
30
+
31
+ return fbank
32
+
33
+
34
+ def pad_wav(waveform, segment_length):
35
+ waveform_length = waveform.shape[-1]
36
+ assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
37
+ if segment_length is None or waveform_length == segment_length:
38
+ return waveform
39
+ elif waveform_length > segment_length:
40
+ return waveform[:segment_length]
41
+ elif waveform_length < segment_length:
42
+ temp_wav = np.zeros((1, segment_length))
43
+ temp_wav[:, :waveform_length] = waveform
44
+ return temp_wav
45
+
46
+ def normalize_wav(waveform):
47
+ waveform = waveform - np.mean(waveform)
48
+ waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
49
+ return waveform * 0.5
50
+
51
+
52
+ def read_wav_file(filename, segment_length):
53
+ # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
54
+ waveform, sr = torchaudio.load(filename) # Faster!!!
55
+ waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
56
+ waveform = waveform.numpy()[0, ...]
57
+ waveform = normalize_wav(waveform)
58
+ waveform = waveform[None, ...]
59
+ waveform = pad_wav(waveform, segment_length)
60
+
61
+ waveform = waveform / np.max(np.abs(waveform))
62
+ waveform = 0.5 * waveform
63
+
64
+ return waveform
65
+
66
+
67
+ def wav_to_fbank(filename, target_length=1024, fn_STFT=None):
68
+ assert fn_STFT is not None
69
+
70
+ # mixup
71
+ waveform = read_wav_file(filename, target_length * 160) # hop size is 160
72
+
73
+ waveform = waveform[0, ...]
74
+ waveform = torch.FloatTensor(waveform)
75
+
76
+ fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
77
+
78
+ fbank = torch.FloatTensor(fbank.T)
79
+ log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)
80
+
81
+ fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
82
+ log_magnitudes_stft, target_length
83
+ )
84
+
85
+ return fbank, log_magnitudes_stft, waveform
audioldm/hifigan/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .models import Generator
2
+
3
+
4
+ class AttrDict(dict):
5
+ def __init__(self, *args, **kwargs):
6
+ super(AttrDict, self).__init__(*args, **kwargs)
7
+ self.__dict__ = self
audioldm/hifigan/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (569 Bytes). View file
 
audioldm/hifigan/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (574 Bytes). View file
 
audioldm/hifigan/__pycache__/models.cpython-310.pyc ADDED
Binary file (3.73 kB). View file
 
audioldm/hifigan/__pycache__/models.cpython-39.pyc ADDED
Binary file (3.73 kB). View file
 
audioldm/hifigan/__pycache__/utilities.cpython-310.pyc ADDED
Binary file (2.48 kB). View file
 
audioldm/hifigan/__pycache__/utilities.cpython-39.pyc ADDED
Binary file (2.37 kB). View file
 
audioldm/hifigan/models.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import Conv1d, ConvTranspose1d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm
6
+
7
+ LRELU_SLOPE = 0.1
8
+
9
+
10
+ def init_weights(m, mean=0.0, std=0.01):
11
+ classname = m.__class__.__name__
12
+ if classname.find("Conv") != -1:
13
+ m.weight.data.normal_(mean, std)
14
+
15
+
16
+ def get_padding(kernel_size, dilation=1):
17
+ return int((kernel_size * dilation - dilation) / 2)
18
+
19
+
20
+ class ResBlock(torch.nn.Module):
21
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
22
+ super(ResBlock, self).__init__()
23
+ self.h = h
24
+ self.convs1 = nn.ModuleList(
25
+ [
26
+ weight_norm(
27
+ Conv1d(
28
+ channels,
29
+ channels,
30
+ kernel_size,
31
+ 1,
32
+ dilation=dilation[0],
33
+ padding=get_padding(kernel_size, dilation[0]),
34
+ )
35
+ ),
36
+ weight_norm(
37
+ Conv1d(
38
+ channels,
39
+ channels,
40
+ kernel_size,
41
+ 1,
42
+ dilation=dilation[1],
43
+ padding=get_padding(kernel_size, dilation[1]),
44
+ )
45
+ ),
46
+ weight_norm(
47
+ Conv1d(
48
+ channels,
49
+ channels,
50
+ kernel_size,
51
+ 1,
52
+ dilation=dilation[2],
53
+ padding=get_padding(kernel_size, dilation[2]),
54
+ )
55
+ ),
56
+ ]
57
+ )
58
+ self.convs1.apply(init_weights)
59
+
60
+ self.convs2 = nn.ModuleList(
61
+ [
62
+ weight_norm(
63
+ Conv1d(
64
+ channels,
65
+ channels,
66
+ kernel_size,
67
+ 1,
68
+ dilation=1,
69
+ padding=get_padding(kernel_size, 1),
70
+ )
71
+ ),
72
+ weight_norm(
73
+ Conv1d(
74
+ channels,
75
+ channels,
76
+ kernel_size,
77
+ 1,
78
+ dilation=1,
79
+ padding=get_padding(kernel_size, 1),
80
+ )
81
+ ),
82
+ weight_norm(
83
+ Conv1d(
84
+ channels,
85
+ channels,
86
+ kernel_size,
87
+ 1,
88
+ dilation=1,
89
+ padding=get_padding(kernel_size, 1),
90
+ )
91
+ ),
92
+ ]
93
+ )
94
+ self.convs2.apply(init_weights)
95
+
96
+ def forward(self, x):
97
+ for c1, c2 in zip(self.convs1, self.convs2):
98
+ xt = F.leaky_relu(x, LRELU_SLOPE)
99
+ xt = c1(xt)
100
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
101
+ xt = c2(xt)
102
+ x = xt + x
103
+ return x
104
+
105
+ def remove_weight_norm(self):
106
+ for l in self.convs1:
107
+ remove_weight_norm(l)
108
+ for l in self.convs2:
109
+ remove_weight_norm(l)
110
+
111
+
112
+ class Generator(torch.nn.Module):
113
+ def __init__(self, h):
114
+ super(Generator, self).__init__()
115
+ self.h = h
116
+ self.num_kernels = len(h.resblock_kernel_sizes)
117
+ self.num_upsamples = len(h.upsample_rates)
118
+ self.conv_pre = weight_norm(
119
+ Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
120
+ )
121
+ resblock = ResBlock
122
+
123
+ self.ups = nn.ModuleList()
124
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
125
+ self.ups.append(
126
+ weight_norm(
127
+ ConvTranspose1d(
128
+ h.upsample_initial_channel // (2**i),
129
+ h.upsample_initial_channel // (2 ** (i + 1)),
130
+ k,
131
+ u,
132
+ padding=(k - u) // 2,
133
+ )
134
+ )
135
+ )
136
+
137
+ self.resblocks = nn.ModuleList()
138
+ for i in range(len(self.ups)):
139
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
140
+ for j, (k, d) in enumerate(
141
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
142
+ ):
143
+ self.resblocks.append(resblock(h, ch, k, d))
144
+
145
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
146
+ self.ups.apply(init_weights)
147
+ self.conv_post.apply(init_weights)
148
+
149
+ def forward(self, x):
150
+ x = self.conv_pre(x)
151
+ for i in range(self.num_upsamples):
152
+ x = F.leaky_relu(x, LRELU_SLOPE)
153
+ x = self.ups[i](x)
154
+ xs = None
155
+ for j in range(self.num_kernels):
156
+ if xs is None:
157
+ xs = self.resblocks[i * self.num_kernels + j](x)
158
+ else:
159
+ xs += self.resblocks[i * self.num_kernels + j](x)
160
+ x = xs / self.num_kernels
161
+ x = F.leaky_relu(x)
162
+ x = self.conv_post(x)
163
+ x = torch.tanh(x)
164
+
165
+ return x
166
+
167
+ def remove_weight_norm(self):
168
+ # print("Removing weight norm...")
169
+ for l in self.ups:
170
+ remove_weight_norm(l)
171
+ for l in self.resblocks:
172
+ l.remove_weight_norm()
173
+ remove_weight_norm(self.conv_pre)
174
+ remove_weight_norm(self.conv_post)
audioldm/hifigan/utilities.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ import torch
5
+ import numpy as np
6
+
7
+ import audioldm.hifigan as hifigan
8
+
9
+ HIFIGAN_16K_64 = {
10
+ "resblock": "1",
11
+ "num_gpus": 6,
12
+ "batch_size": 16,
13
+ "learning_rate": 0.0002,
14
+ "adam_b1": 0.8,
15
+ "adam_b2": 0.99,
16
+ "lr_decay": 0.999,
17
+ "seed": 1234,
18
+ "upsample_rates": [5, 4, 2, 2, 2],
19
+ "upsample_kernel_sizes": [16, 16, 8, 4, 4],
20
+ "upsample_initial_channel": 1024,
21
+ "resblock_kernel_sizes": [3, 7, 11],
22
+ "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
23
+ "segment_size": 8192,
24
+ "num_mels": 64,
25
+ "num_freq": 1025,
26
+ "n_fft": 1024,
27
+ "hop_size": 160,
28
+ "win_size": 1024,
29
+ "sampling_rate": 16000,
30
+ "fmin": 0,
31
+ "fmax": 8000,
32
+ "fmax_for_loss": None,
33
+ "num_workers": 4,
34
+ "dist_config": {
35
+ "dist_backend": "nccl",
36
+ "dist_url": "tcp://localhost:54321",
37
+ "world_size": 1,
38
+ },
39
+ }
40
+
41
+
42
+ def get_available_checkpoint_keys(model, ckpt):
43
+ print("==> Attemp to reload from %s" % ckpt)
44
+ state_dict = torch.load(ckpt)["state_dict"]
45
+ current_state_dict = model.state_dict()
46
+ new_state_dict = {}
47
+ for k in state_dict.keys():
48
+ if (
49
+ k in current_state_dict.keys()
50
+ and current_state_dict[k].size() == state_dict[k].size()
51
+ ):
52
+ new_state_dict[k] = state_dict[k]
53
+ else:
54
+ print("==> WARNING: Skipping %s" % k)
55
+ print(
56
+ "%s out of %s keys are matched"
57
+ % (len(new_state_dict.keys()), len(state_dict.keys()))
58
+ )
59
+ return new_state_dict
60
+
61
+
62
+ def get_param_num(model):
63
+ num_param = sum(param.numel() for param in model.parameters())
64
+ return num_param
65
+
66
+
67
+ def get_vocoder(config, device):
68
+ config = hifigan.AttrDict(HIFIGAN_16K_64)
69
+ vocoder = hifigan.Generator(config)
70
+ vocoder.eval()
71
+ vocoder.remove_weight_norm()
72
+ vocoder.to(device)
73
+ return vocoder
74
+
75
+
76
+ def vocoder_infer(mels, vocoder, lengths=None):
77
+ vocoder.eval()
78
+ with torch.no_grad():
79
+ wavs = vocoder(mels).squeeze(1)
80
+
81
+ wavs = (wavs.cpu().numpy() * 32768).astype("int16")
82
+
83
+ if lengths is not None:
84
+ wavs = wavs[:, :lengths]
85
+
86
+ return wavs
audioldm/latent_diffusion/__init__.py ADDED
File without changes
audioldm/latent_diffusion/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (157 Bytes). View file
 
audioldm/latent_diffusion/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (164 Bytes). View file
 
audioldm/latent_diffusion/__pycache__/attention.cpython-310.pyc ADDED
Binary file (11.4 kB). View file
 
audioldm/latent_diffusion/__pycache__/attention.cpython-39.pyc ADDED
Binary file (11.4 kB). View file
 
audioldm/latent_diffusion/__pycache__/ddim.cpython-310.pyc ADDED
Binary file (7.2 kB). View file
 
audioldm/latent_diffusion/__pycache__/ddim.cpython-39.pyc ADDED
Binary file (7.11 kB). View file
 
audioldm/latent_diffusion/__pycache__/ddpm.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
audioldm/latent_diffusion/__pycache__/ddpm.cpython-39.pyc ADDED
Binary file (11 kB). View file
 
audioldm/latent_diffusion/__pycache__/ema.cpython-310.pyc ADDED
Binary file (3.01 kB). View file
 
audioldm/latent_diffusion/__pycache__/ema.cpython-39.pyc ADDED
Binary file (3 kB). View file
 
audioldm/latent_diffusion/__pycache__/openaimodel.cpython-39.pyc ADDED
Binary file (23.7 kB). View file
 
audioldm/latent_diffusion/__pycache__/util.cpython-310.pyc ADDED
Binary file (9.53 kB). View file
 
audioldm/latent_diffusion/__pycache__/util.cpython-39.pyc ADDED
Binary file (9.6 kB). View file
 
audioldm/latent_diffusion/attention.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from einops import rearrange
7
+
8
+ from audioldm.latent_diffusion.util import checkpoint
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def uniq(arr):
16
+ return {el: True for el in arr}.keys()
17
+
18
+
19
+ def default(val, d):
20
+ if exists(val):
21
+ return val
22
+ return d() if isfunction(d) else d
23
+
24
+
25
+ def max_neg_value(t):
26
+ return -torch.finfo(t.dtype).max
27
+
28
+
29
+ def init_(tensor):
30
+ dim = tensor.shape[-1]
31
+ std = 1 / math.sqrt(dim)
32
+ tensor.uniform_(-std, std)
33
+ return tensor
34
+
35
+
36
+ # feedforward
37
+ class GEGLU(nn.Module):
38
+ def __init__(self, dim_in, dim_out):
39
+ super().__init__()
40
+ self.proj = nn.Linear(dim_in, dim_out * 2)
41
+
42
+ def forward(self, x):
43
+ x, gate = self.proj(x).chunk(2, dim=-1)
44
+ return x * F.gelu(gate)
45
+
46
+
47
+ class FeedForward(nn.Module):
48
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
49
+ super().__init__()
50
+ inner_dim = int(dim * mult)
51
+ dim_out = default(dim_out, dim)
52
+ project_in = (
53
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
54
+ if not glu
55
+ else GEGLU(dim, inner_dim)
56
+ )
57
+
58
+ self.net = nn.Sequential(
59
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
60
+ )
61
+
62
+ def forward(self, x):
63
+ return self.net(x)
64
+
65
+
66
+ def zero_module(module):
67
+ """
68
+ Zero out the parameters of a module and return it.
69
+ """
70
+ for p in module.parameters():
71
+ p.detach().zero_()
72
+ return module
73
+
74
+
75
+ def Normalize(in_channels):
76
+ return torch.nn.GroupNorm(
77
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
78
+ )
79
+
80
+
81
+ class LinearAttention(nn.Module):
82
+ def __init__(self, dim, heads=4, dim_head=32):
83
+ super().__init__()
84
+ self.heads = heads
85
+ hidden_dim = dim_head * heads
86
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
87
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
88
+
89
+ def forward(self, x):
90
+ b, c, h, w = x.shape
91
+ qkv = self.to_qkv(x)
92
+ q, k, v = rearrange(
93
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
94
+ )
95
+ k = k.softmax(dim=-1)
96
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
97
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
98
+ out = rearrange(
99
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
100
+ )
101
+ return self.to_out(out)
102
+
103
+
104
+ class SpatialSelfAttention(nn.Module):
105
+ def __init__(self, in_channels):
106
+ super().__init__()
107
+ self.in_channels = in_channels
108
+
109
+ self.norm = Normalize(in_channels)
110
+ self.q = torch.nn.Conv2d(
111
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
112
+ )
113
+ self.k = torch.nn.Conv2d(
114
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
115
+ )
116
+ self.v = torch.nn.Conv2d(
117
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
118
+ )
119
+ self.proj_out = torch.nn.Conv2d(
120
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
121
+ )
122
+
123
+ def forward(self, x):
124
+ h_ = x
125
+ h_ = self.norm(h_)
126
+ q = self.q(h_)
127
+ k = self.k(h_)
128
+ v = self.v(h_)
129
+
130
+ # compute attention
131
+ b, c, h, w = q.shape
132
+ q = rearrange(q, "b c h w -> b (h w) c")
133
+ k = rearrange(k, "b c h w -> b c (h w)")
134
+ w_ = torch.einsum("bij,bjk->bik", q, k)
135
+
136
+ w_ = w_ * (int(c) ** (-0.5))
137
+ w_ = torch.nn.functional.softmax(w_, dim=2)
138
+
139
+ # attend to values
140
+ v = rearrange(v, "b c h w -> b c (h w)")
141
+ w_ = rearrange(w_, "b i j -> b j i")
142
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
143
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
144
+ h_ = self.proj_out(h_)
145
+
146
+ return x + h_
147
+
148
+
149
+ class CrossAttention(nn.Module):
150
+ """
151
+ ### Cross Attention Layer
152
+ This falls-back to self-attention when conditional embeddings are not specified.
153
+ """
154
+
155
+ # use_flash_attention: bool = True
156
+ use_flash_attention: bool = False
157
+
158
+ def __init__(
159
+ self,
160
+ query_dim,
161
+ context_dim=None,
162
+ heads=8,
163
+ dim_head=64,
164
+ dropout=0.0,
165
+ is_inplace: bool = True,
166
+ ):
167
+ # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
168
+ """
169
+ :param d_model: is the input embedding size
170
+ :param n_heads: is the number of attention heads
171
+ :param d_head: is the size of a attention head
172
+ :param d_cond: is the size of the conditional embeddings
173
+ :param is_inplace: specifies whether to perform the attention softmax computation inplace to
174
+ save memory
175
+ """
176
+ super().__init__()
177
+
178
+ self.is_inplace = is_inplace
179
+ self.n_heads = heads
180
+ self.d_head = dim_head
181
+
182
+ # Attention scaling factor
183
+ self.scale = dim_head**-0.5
184
+
185
+ # The normal self-attention layer
186
+ if context_dim is None:
187
+ context_dim = query_dim
188
+
189
+ # Query, key and value mappings
190
+ d_attn = dim_head * heads
191
+ self.to_q = nn.Linear(query_dim, d_attn, bias=False)
192
+ self.to_k = nn.Linear(context_dim, d_attn, bias=False)
193
+ self.to_v = nn.Linear(context_dim, d_attn, bias=False)
194
+
195
+ # Final linear layer
196
+ self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout))
197
+
198
+ # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
199
+ # Flash attention is only used if it's installed
200
+ # and `CrossAttention.use_flash_attention` is set to `True`.
201
+ try:
202
+ # You can install flash attention by cloning their Github repo,
203
+ # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
204
+ # and then running `python setup.py install`
205
+ from flash_attn.flash_attention import FlashAttention
206
+
207
+ self.flash = FlashAttention()
208
+ # Set the scale for scaled dot-product attention.
209
+ self.flash.softmax_scale = self.scale
210
+ # Set to `None` if it's not installed
211
+ except ImportError:
212
+ self.flash = None
213
+
214
+ def forward(self, x, context=None, mask=None):
215
+ """
216
+ :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
217
+ :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
218
+ """
219
+
220
+ # If `cond` is `None` we perform self attention
221
+ has_cond = context is not None
222
+ if not has_cond:
223
+ context = x
224
+
225
+ # Get query, key and value vectors
226
+ q = self.to_q(x)
227
+ k = self.to_k(context)
228
+ v = self.to_v(context)
229
+
230
+ # Use flash attention if it's available and the head size is less than or equal to `128`
231
+ if (
232
+ CrossAttention.use_flash_attention
233
+ and self.flash is not None
234
+ and not has_cond
235
+ and self.d_head <= 128
236
+ ):
237
+ return self.flash_attention(q, k, v)
238
+ # Otherwise, fallback to normal attention
239
+ else:
240
+ return self.normal_attention(q, k, v)
241
+
242
+ def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
243
+ """
244
+ #### Flash Attention
245
+ :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
246
+ :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
247
+ :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
248
+ """
249
+
250
+ # Get batch size and number of elements along sequence axis (`width * height`)
251
+ batch_size, seq_len, _ = q.shape
252
+
253
+ # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
254
+ # shape `[batch_size, seq_len, 3, n_heads * d_head]`
255
+ qkv = torch.stack((q, k, v), dim=2)
256
+ # Split the heads
257
+ qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
258
+
259
+ # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
260
+ # fit this size.
261
+ if self.d_head <= 32:
262
+ pad = 32 - self.d_head
263
+ elif self.d_head <= 64:
264
+ pad = 64 - self.d_head
265
+ elif self.d_head <= 128:
266
+ pad = 128 - self.d_head
267
+ else:
268
+ raise ValueError(f"Head size ${self.d_head} too large for Flash Attention")
269
+
270
+ # Pad the heads
271
+ if pad:
272
+ qkv = torch.cat(
273
+ (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
274
+ )
275
+
276
+ # Compute attention
277
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
278
+ # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
279
+ # TODO here I add the dtype changing
280
+ out, _ = self.flash(qkv.type(torch.float16))
281
+ # Truncate the extra head size
282
+ out = out[:, :, :, : self.d_head].float()
283
+ # Reshape to `[batch_size, seq_len, n_heads * d_head]`
284
+ out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
285
+
286
+ # Map to `[batch_size, height * width, d_model]` with a linear layer
287
+ return self.to_out(out)
288
+
289
+ def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
290
+ """
291
+ #### Normal Attention
292
+
293
+ :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
294
+ :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
295
+ :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
296
+ """
297
+
298
+ # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
299
+ q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32]
300
+ k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32]
301
+ v = v.view(*v.shape[:2], self.n_heads, -1)
302
+
303
+ # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
304
+ attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
305
+
306
+ # Compute softmax
307
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
308
+ if self.is_inplace:
309
+ half = attn.shape[0] // 2
310
+ attn[half:] = attn[half:].softmax(dim=-1)
311
+ attn[:half] = attn[:half].softmax(dim=-1)
312
+ else:
313
+ attn = attn.softmax(dim=-1)
314
+
315
+ # Compute attention output
316
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
317
+ # attn: [bs, 20, 64, 1]
318
+ # v: [bs, 1, 20, 32]
319
+ out = torch.einsum("bhij,bjhd->bihd", attn, v)
320
+ # Reshape to `[batch_size, height * width, n_heads * d_head]`
321
+ out = out.reshape(*out.shape[:2], -1)
322
+ # Map to `[batch_size, height * width, d_model]` with a linear layer
323
+ return self.to_out(out)
324
+
325
+
326
+ # class CrossAttention(nn.Module):
327
+ # def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
328
+ # super().__init__()
329
+ # inner_dim = dim_head * heads
330
+ # context_dim = default(context_dim, query_dim)
331
+
332
+ # self.scale = dim_head ** -0.5
333
+ # self.heads = heads
334
+
335
+ # self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
336
+ # self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
337
+ # self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
338
+
339
+ # self.to_out = nn.Sequential(
340
+ # nn.Linear(inner_dim, query_dim),
341
+ # nn.Dropout(dropout)
342
+ # )
343
+
344
+ # def forward(self, x, context=None, mask=None):
345
+ # h = self.heads
346
+
347
+ # q = self.to_q(x)
348
+ # context = default(context, x)
349
+ # k = self.to_k(context)
350
+ # v = self.to_v(context)
351
+
352
+ # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
353
+
354
+ # sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
355
+
356
+ # if exists(mask):
357
+ # mask = rearrange(mask, 'b ... -> b (...)')
358
+ # max_neg_value = -torch.finfo(sim.dtype).max
359
+ # mask = repeat(mask, 'b j -> (b h) () j', h=h)
360
+ # sim.masked_fill_(~mask, max_neg_value)
361
+
362
+ # # attention, what we cannot get enough of
363
+ # attn = sim.softmax(dim=-1)
364
+
365
+ # out = einsum('b i j, b j d -> b i d', attn, v)
366
+ # out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
367
+ # return self.to_out(out)
368
+
369
+
370
+ class BasicTransformerBlock(nn.Module):
371
+ def __init__(
372
+ self,
373
+ dim,
374
+ n_heads,
375
+ d_head,
376
+ dropout=0.0,
377
+ context_dim=None,
378
+ gated_ff=True,
379
+ checkpoint=True,
380
+ ):
381
+ super().__init__()
382
+ self.attn1 = CrossAttention(
383
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
384
+ ) # is a self-attention
385
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
386
+ self.attn2 = CrossAttention(
387
+ query_dim=dim,
388
+ context_dim=context_dim,
389
+ heads=n_heads,
390
+ dim_head=d_head,
391
+ dropout=dropout,
392
+ ) # is self-attn if context is none
393
+ self.norm1 = nn.LayerNorm(dim)
394
+ self.norm2 = nn.LayerNorm(dim)
395
+ self.norm3 = nn.LayerNorm(dim)
396
+ self.checkpoint = checkpoint
397
+
398
+ def forward(self, x, context=None):
399
+ if context is None:
400
+ return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
401
+ else:
402
+ return checkpoint(
403
+ self._forward, (x, context), self.parameters(), self.checkpoint
404
+ )
405
+
406
+ def _forward(self, x, context=None):
407
+ x = self.attn1(self.norm1(x)) + x
408
+ x = self.attn2(self.norm2(x), context=context) + x
409
+ x = self.ff(self.norm3(x)) + x
410
+ return x
411
+
412
+
413
+ class SpatialTransformer(nn.Module):
414
+ """
415
+ Transformer block for image-like data.
416
+ First, project the input (aka embedding)
417
+ and reshape to b, t, d.
418
+ Then apply standard transformer action.
419
+ Finally, reshape to image
420
+ """
421
+
422
+ def __init__(
423
+ self,
424
+ in_channels,
425
+ n_heads,
426
+ d_head,
427
+ depth=1,
428
+ dropout=0.0,
429
+ context_dim=None,
430
+ no_context=False,
431
+ ):
432
+ super().__init__()
433
+
434
+ if no_context:
435
+ context_dim = None
436
+
437
+ self.in_channels = in_channels
438
+ inner_dim = n_heads * d_head
439
+ self.norm = Normalize(in_channels)
440
+
441
+ self.proj_in = nn.Conv2d(
442
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
443
+ )
444
+
445
+ self.transformer_blocks = nn.ModuleList(
446
+ [
447
+ BasicTransformerBlock(
448
+ inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
449
+ )
450
+ for d in range(depth)
451
+ ]
452
+ )
453
+
454
+ self.proj_out = zero_module(
455
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
456
+ )
457
+
458
+ def forward(self, x, context=None):
459
+ # note: if no context is given, cross-attention defaults to self-attention
460
+ b, c, h, w = x.shape
461
+ x_in = x
462
+ x = self.norm(x)
463
+ x = self.proj_in(x)
464
+ x = rearrange(x, "b c h w -> b (h w) c")
465
+ for block in self.transformer_blocks:
466
+ x = block(x, context=context)
467
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
468
+ x = self.proj_out(x)
469
+ return x + x_in