hungchiayu1
commited on
Commit
·
ffead1e
1
Parent(s):
e7af757
initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +5 -5
- app.py +140 -0
- audioldm/__init__.py +8 -0
- audioldm/__main__.py +183 -0
- audioldm/__pycache__/__init__.cpython-310.pyc +0 -0
- audioldm/__pycache__/__init__.cpython-39.pyc +0 -0
- audioldm/__pycache__/ldm.cpython-310.pyc +0 -0
- audioldm/__pycache__/ldm.cpython-39.pyc +0 -0
- audioldm/__pycache__/pipeline.cpython-310.pyc +0 -0
- audioldm/__pycache__/pipeline.cpython-39.pyc +0 -0
- audioldm/__pycache__/utils.cpython-310.pyc +0 -0
- audioldm/__pycache__/utils.cpython-39.pyc +0 -0
- audioldm/audio/__init__.py +2 -0
- audioldm/audio/__pycache__/__init__.cpython-310.pyc +0 -0
- audioldm/audio/__pycache__/__init__.cpython-39.pyc +0 -0
- audioldm/audio/__pycache__/audio_processing.cpython-310.pyc +0 -0
- audioldm/audio/__pycache__/audio_processing.cpython-39.pyc +0 -0
- audioldm/audio/__pycache__/mix.cpython-39.pyc +0 -0
- audioldm/audio/__pycache__/stft.cpython-310.pyc +0 -0
- audioldm/audio/__pycache__/stft.cpython-39.pyc +0 -0
- audioldm/audio/__pycache__/tools.cpython-310.pyc +0 -0
- audioldm/audio/__pycache__/tools.cpython-39.pyc +0 -0
- audioldm/audio/__pycache__/torch_tools.cpython-39.pyc +0 -0
- audioldm/audio/audio_processing.py +100 -0
- audioldm/audio/stft.py +186 -0
- audioldm/audio/tools.py +85 -0
- audioldm/hifigan/__init__.py +7 -0
- audioldm/hifigan/__pycache__/__init__.cpython-310.pyc +0 -0
- audioldm/hifigan/__pycache__/__init__.cpython-39.pyc +0 -0
- audioldm/hifigan/__pycache__/models.cpython-310.pyc +0 -0
- audioldm/hifigan/__pycache__/models.cpython-39.pyc +0 -0
- audioldm/hifigan/__pycache__/utilities.cpython-310.pyc +0 -0
- audioldm/hifigan/__pycache__/utilities.cpython-39.pyc +0 -0
- audioldm/hifigan/models.py +174 -0
- audioldm/hifigan/utilities.py +86 -0
- audioldm/latent_diffusion/__init__.py +0 -0
- audioldm/latent_diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/attention.cpython-310.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/attention.cpython-39.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/ddim.cpython-310.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/ddim.cpython-39.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/ddpm.cpython-310.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/ddpm.cpython-39.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/ema.cpython-310.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/ema.cpython-39.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/openaimodel.cpython-39.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/util.cpython-310.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/util.cpython-39.pyc +0 -0
- audioldm/latent_diffusion/attention.py +469 -0
README.md
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
|
|
1 |
---
|
2 |
+
title: Tango
|
3 |
+
emoji: 🐠
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: pink
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.28.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
app.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
import torch
|
4 |
+
import wavio
|
5 |
+
from tqdm import tqdm
|
6 |
+
from huggingface_hub import snapshot_download
|
7 |
+
from models import AudioDiffusion, DDPMScheduler
|
8 |
+
from audioldm.audio.stft import TacotronSTFT
|
9 |
+
from audioldm.variational_autoencoder import AutoencoderKL
|
10 |
+
from gradio import Markdown
|
11 |
+
|
12 |
+
class Tango:
|
13 |
+
def __init__(self, name="declare-lab/tango2", device="cuda:0"):
|
14 |
+
|
15 |
+
path = snapshot_download(repo_id=name)
|
16 |
+
|
17 |
+
vae_config = json.load(open("{}/vae_config.json".format(path)))
|
18 |
+
stft_config = json.load(open("{}/stft_config.json".format(path)))
|
19 |
+
main_config = json.load(open("{}/main_config.json".format(path)))
|
20 |
+
|
21 |
+
self.vae = AutoencoderKL(**vae_config).to(device)
|
22 |
+
self.stft = TacotronSTFT(**stft_config).to(device)
|
23 |
+
self.model = AudioDiffusion(**main_config).to(device)
|
24 |
+
|
25 |
+
vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location=device)
|
26 |
+
stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location=device)
|
27 |
+
main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location=device)
|
28 |
+
|
29 |
+
self.vae.load_state_dict(vae_weights)
|
30 |
+
self.stft.load_state_dict(stft_weights)
|
31 |
+
self.model.load_state_dict(main_weights)
|
32 |
+
|
33 |
+
print ("Successfully loaded checkpoint from:", name)
|
34 |
+
|
35 |
+
self.vae.eval()
|
36 |
+
self.stft.eval()
|
37 |
+
self.model.eval()
|
38 |
+
|
39 |
+
self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder="scheduler")
|
40 |
+
|
41 |
+
def chunks(self, lst, n):
|
42 |
+
""" Yield successive n-sized chunks from a list. """
|
43 |
+
for i in range(0, len(lst), n):
|
44 |
+
yield lst[i:i + n]
|
45 |
+
|
46 |
+
def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
|
47 |
+
""" Genrate audio for a single prompt string. """
|
48 |
+
with torch.no_grad():
|
49 |
+
latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
|
50 |
+
mel = self.vae.decode_first_stage(latents)
|
51 |
+
wave = self.vae.decode_to_waveform(mel)
|
52 |
+
return wave[0]
|
53 |
+
|
54 |
+
def generate_for_batch(self, prompts, steps=200, guidance=3, samples=1, batch_size=8, disable_progress=True):
|
55 |
+
""" Genrate audio for a list of prompt strings. """
|
56 |
+
outputs = []
|
57 |
+
for k in tqdm(range(0, len(prompts), batch_size)):
|
58 |
+
batch = prompts[k: k+batch_size]
|
59 |
+
with torch.no_grad():
|
60 |
+
latents = self.model.inference(batch, self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
|
61 |
+
mel = self.vae.decode_first_stage(latents)
|
62 |
+
wave = self.vae.decode_to_waveform(mel)
|
63 |
+
outputs += [item for item in wave]
|
64 |
+
if samples == 1:
|
65 |
+
return outputs
|
66 |
+
else:
|
67 |
+
return list(self.chunks(outputs, samples))
|
68 |
+
|
69 |
+
# Initialize TANGO
|
70 |
+
if torch.cuda.is_available():
|
71 |
+
tango = Tango()
|
72 |
+
else:
|
73 |
+
tango = Tango(device="cpu")
|
74 |
+
|
75 |
+
def gradio_generate(prompt, steps, guidance):
|
76 |
+
output_wave = tango.generate(prompt, steps, guidance)
|
77 |
+
# output_filename = f"{prompt.replace(' ', '_')}_{steps}_{guidance}"[:250] + ".wav"
|
78 |
+
output_filename = "temp.wav"
|
79 |
+
wavio.write(output_filename, output_wave, rate=16000, sampwidth=2)
|
80 |
+
|
81 |
+
return output_filename
|
82 |
+
|
83 |
+
# description_text = """
|
84 |
+
# <p><a href="https://huggingface.co/spaces/declare-lab/tango/blob/main/app.py?duplicate=true"> <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> For faster inference without waiting in queue, you may duplicate the space and upgrade to a GPU in the settings. <br/><br/>
|
85 |
+
# Generate audio using TANGO by providing a text prompt.
|
86 |
+
# <br/><br/>Limitations: TANGO is trained on the small AudioCaps dataset so it may not generate good audio \
|
87 |
+
# samples related to concepts that it has not seen in training (e.g. singing). For the same reason, TANGO \
|
88 |
+
# is not always able to finely control its generations over textual control prompts. For example, \
|
89 |
+
# the generations from TANGO for prompts Chopping tomatoes on a wooden table and Chopping potatoes \
|
90 |
+
# on a metal table are very similar. \
|
91 |
+
# <br/><br/>We are currently training another version of TANGO on larger datasets to enhance its generalization, \
|
92 |
+
# compositional and controllable generation ability.
|
93 |
+
# <br/><br/>We recommend using a guidance scale of 3. The default number of steps is set to 100. More steps generally lead to better quality of generated audios but will take longer.
|
94 |
+
# <br/><br/>
|
95 |
+
# <h1> ChatGPT-enhanced audio generation</h1>
|
96 |
+
# <br/>
|
97 |
+
# As TANGO consists of an instruction-tuned LLM, it is able to process complex sound descriptions allowing us to provide more detailed instructions to improve the generation quality.
|
98 |
+
# For example, ``A boat is moving on the sea'' vs ``The sound of the water lapping against the hull of the boat or splashing as you move through the waves''. The latter is obtained by prompting ChatGPT to explain the sound generated when a boat moves on the sea.
|
99 |
+
# Using this ChatGPT-generated description of the sound, TANGO provides superior results.
|
100 |
+
# <p/>
|
101 |
+
# """
|
102 |
+
description_text = ""
|
103 |
+
# Gradio input and output components
|
104 |
+
input_text = gr.Textbox(lines=2, label="Prompt")
|
105 |
+
output_audio = gr.Audio(label="Generated Audio", type="filepath")
|
106 |
+
denoising_steps = gr.Slider(minimum=100, maximum=200, value=100, step=1, label="Steps", interactive=True)
|
107 |
+
guidance_scale = gr.Slider(minimum=1, maximum=10, value=3, step=0.1, label="Guidance Scale", interactive=True)
|
108 |
+
|
109 |
+
# Gradio interface
|
110 |
+
gr_interface = gr.Interface(
|
111 |
+
fn=gradio_generate,
|
112 |
+
inputs=[input_text, denoising_steps, guidance_scale],
|
113 |
+
outputs=[output_audio],
|
114 |
+
title="TANGO2: Aligning Diffusion-based Text-to-Audio Generative Models through Direct Preference Optimization",
|
115 |
+
description=description_text,
|
116 |
+
allow_flagging=False,
|
117 |
+
examples=[
|
118 |
+
["A lady is singing a song with a kid"],
|
119 |
+
["The sound of the water lapping against the hull of the boat or splashing as you move through the waves"],
|
120 |
+
["An audience cheering and clapping"],
|
121 |
+
["Rolling thunder with lightning strikes"],
|
122 |
+
["Gentle water stream, birds chirping and sudden gun shot"],
|
123 |
+
["A car engine revving"],
|
124 |
+
["A dog barking"],
|
125 |
+
["A cat meowing"],
|
126 |
+
["Wooden table tapping sound while water pouring"],
|
127 |
+
["Emergency sirens wailing"],
|
128 |
+
["two gunshots followed by birds flying away while chirping"],
|
129 |
+
["Whistling with birds chirping"],
|
130 |
+
["A person snoring"],
|
131 |
+
["Motor vehicles are driving with loud engines and a person whistles"],
|
132 |
+
["People cheering in a stadium while thunder and lightning strikes"],
|
133 |
+
["A helicopter is in flight"],
|
134 |
+
["A dog barking and a man talking and a racing car passes by"],
|
135 |
+
],
|
136 |
+
cache_examples=False, # Turn on to cache.
|
137 |
+
)
|
138 |
+
|
139 |
+
# Launch Gradio app
|
140 |
+
gr_interface.launch()
|
audioldm/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .ldm import LatentDiffusion
|
2 |
+
from .utils import seed_everything, save_wave, get_time, get_duration
|
3 |
+
from .pipeline import *
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
|
audioldm/__main__.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
import os
|
3 |
+
from audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time, round_up_duration, get_duration
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
CACHE_DIR = os.getenv(
|
7 |
+
"AUDIOLDM_CACHE_DIR",
|
8 |
+
os.path.join(os.path.expanduser("~"), ".cache/audioldm"))
|
9 |
+
|
10 |
+
parser = argparse.ArgumentParser()
|
11 |
+
|
12 |
+
parser.add_argument(
|
13 |
+
"--mode",
|
14 |
+
type=str,
|
15 |
+
required=False,
|
16 |
+
default="generation",
|
17 |
+
help="generation: text-to-audio generation; transfer: style transfer",
|
18 |
+
choices=["generation", "transfer"]
|
19 |
+
)
|
20 |
+
|
21 |
+
parser.add_argument(
|
22 |
+
"-t",
|
23 |
+
"--text",
|
24 |
+
type=str,
|
25 |
+
required=False,
|
26 |
+
default="",
|
27 |
+
help="Text prompt to the model for audio generation",
|
28 |
+
)
|
29 |
+
|
30 |
+
parser.add_argument(
|
31 |
+
"-f",
|
32 |
+
"--file_path",
|
33 |
+
type=str,
|
34 |
+
required=False,
|
35 |
+
default=None,
|
36 |
+
help="(--mode transfer): Original audio file for style transfer; Or (--mode generation): the guidance audio file for generating simialr audio",
|
37 |
+
)
|
38 |
+
|
39 |
+
parser.add_argument(
|
40 |
+
"--transfer_strength",
|
41 |
+
type=float,
|
42 |
+
required=False,
|
43 |
+
default=0.5,
|
44 |
+
help="A value between 0 and 1. 0 means original audio without transfer, 1 means completely transfer to the audio indicated by text",
|
45 |
+
)
|
46 |
+
|
47 |
+
parser.add_argument(
|
48 |
+
"-s",
|
49 |
+
"--save_path",
|
50 |
+
type=str,
|
51 |
+
required=False,
|
52 |
+
help="The path to save model output",
|
53 |
+
default="./output",
|
54 |
+
)
|
55 |
+
|
56 |
+
parser.add_argument(
|
57 |
+
"--model_name",
|
58 |
+
type=str,
|
59 |
+
required=False,
|
60 |
+
help="The checkpoint you gonna use",
|
61 |
+
default="audioldm-s-full",
|
62 |
+
choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2"]
|
63 |
+
)
|
64 |
+
|
65 |
+
parser.add_argument(
|
66 |
+
"-ckpt",
|
67 |
+
"--ckpt_path",
|
68 |
+
type=str,
|
69 |
+
required=False,
|
70 |
+
help="The path to the pretrained .ckpt model",
|
71 |
+
default=None,
|
72 |
+
)
|
73 |
+
|
74 |
+
parser.add_argument(
|
75 |
+
"-b",
|
76 |
+
"--batchsize",
|
77 |
+
type=int,
|
78 |
+
required=False,
|
79 |
+
default=1,
|
80 |
+
help="Generate how many samples at the same time",
|
81 |
+
)
|
82 |
+
|
83 |
+
parser.add_argument(
|
84 |
+
"--ddim_steps",
|
85 |
+
type=int,
|
86 |
+
required=False,
|
87 |
+
default=200,
|
88 |
+
help="The sampling step for DDIM",
|
89 |
+
)
|
90 |
+
|
91 |
+
parser.add_argument(
|
92 |
+
"-gs",
|
93 |
+
"--guidance_scale",
|
94 |
+
type=float,
|
95 |
+
required=False,
|
96 |
+
default=2.5,
|
97 |
+
help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
|
98 |
+
)
|
99 |
+
|
100 |
+
parser.add_argument(
|
101 |
+
"-dur",
|
102 |
+
"--duration",
|
103 |
+
type=float,
|
104 |
+
required=False,
|
105 |
+
default=10.0,
|
106 |
+
help="The duration of the samples",
|
107 |
+
)
|
108 |
+
|
109 |
+
parser.add_argument(
|
110 |
+
"-n",
|
111 |
+
"--n_candidate_gen_per_text",
|
112 |
+
type=int,
|
113 |
+
required=False,
|
114 |
+
default=3,
|
115 |
+
help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
|
116 |
+
)
|
117 |
+
|
118 |
+
parser.add_argument(
|
119 |
+
"--seed",
|
120 |
+
type=int,
|
121 |
+
required=False,
|
122 |
+
default=42,
|
123 |
+
help="Change this value (any integer number) will lead to a different generation result.",
|
124 |
+
)
|
125 |
+
|
126 |
+
args = parser.parse_args()
|
127 |
+
|
128 |
+
if(args.ckpt_path is not None):
|
129 |
+
print("Warning: ckpt_path has no effect after version 0.0.20.")
|
130 |
+
|
131 |
+
assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5"
|
132 |
+
|
133 |
+
mode = args.mode
|
134 |
+
if(mode == "generation" and args.file_path is not None):
|
135 |
+
mode = "generation_audio_to_audio"
|
136 |
+
if(len(args.text) > 0):
|
137 |
+
print("Warning: You have specified the --file_path. --text will be ignored")
|
138 |
+
args.text = ""
|
139 |
+
|
140 |
+
save_path = os.path.join(args.save_path, mode)
|
141 |
+
|
142 |
+
if(args.file_path is not None):
|
143 |
+
save_path = os.path.join(save_path, os.path.basename(args.file_path.split(".")[0]))
|
144 |
+
|
145 |
+
text = args.text
|
146 |
+
random_seed = args.seed
|
147 |
+
duration = args.duration
|
148 |
+
guidance_scale = args.guidance_scale
|
149 |
+
n_candidate_gen_per_text = args.n_candidate_gen_per_text
|
150 |
+
|
151 |
+
os.makedirs(save_path, exist_ok=True)
|
152 |
+
audioldm = build_model(model_name=args.model_name)
|
153 |
+
|
154 |
+
if(args.mode == "generation"):
|
155 |
+
waveform = text_to_audio(
|
156 |
+
audioldm,
|
157 |
+
text,
|
158 |
+
args.file_path,
|
159 |
+
random_seed,
|
160 |
+
duration=duration,
|
161 |
+
guidance_scale=guidance_scale,
|
162 |
+
ddim_steps=args.ddim_steps,
|
163 |
+
n_candidate_gen_per_text=n_candidate_gen_per_text,
|
164 |
+
batchsize=args.batchsize,
|
165 |
+
)
|
166 |
+
|
167 |
+
elif(args.mode == "transfer"):
|
168 |
+
assert args.file_path is not None
|
169 |
+
assert os.path.exists(args.file_path), "The original audio file \'%s\' for style transfer does not exist." % args.file_path
|
170 |
+
waveform = style_transfer(
|
171 |
+
audioldm,
|
172 |
+
text,
|
173 |
+
args.file_path,
|
174 |
+
args.transfer_strength,
|
175 |
+
random_seed,
|
176 |
+
duration=duration,
|
177 |
+
guidance_scale=guidance_scale,
|
178 |
+
ddim_steps=args.ddim_steps,
|
179 |
+
batchsize=args.batchsize,
|
180 |
+
)
|
181 |
+
waveform = waveform[:,None,:]
|
182 |
+
|
183 |
+
save_wave(waveform, save_path, name="%s_%s" % (get_time(), text))
|
audioldm/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (315 Bytes). View file
|
|
audioldm/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (322 Bytes). View file
|
|
audioldm/__pycache__/ldm.cpython-310.pyc
ADDED
Binary file (16.1 kB). View file
|
|
audioldm/__pycache__/ldm.cpython-39.pyc
ADDED
Binary file (16 kB). View file
|
|
audioldm/__pycache__/pipeline.cpython-310.pyc
ADDED
Binary file (6.63 kB). View file
|
|
audioldm/__pycache__/pipeline.cpython-39.pyc
ADDED
Binary file (6.54 kB). View file
|
|
audioldm/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (8.01 kB). View file
|
|
audioldm/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (7.35 kB). View file
|
|
audioldm/audio/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .tools import wav_to_fbank, read_wav_file
|
2 |
+
from .stft import TacotronSTFT
|
audioldm/audio/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (253 Bytes). View file
|
|
audioldm/audio/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (260 Bytes). View file
|
|
audioldm/audio/__pycache__/audio_processing.cpython-310.pyc
ADDED
Binary file (2.78 kB). View file
|
|
audioldm/audio/__pycache__/audio_processing.cpython-39.pyc
ADDED
Binary file (2.78 kB). View file
|
|
audioldm/audio/__pycache__/mix.cpython-39.pyc
ADDED
Binary file (1.7 kB). View file
|
|
audioldm/audio/__pycache__/stft.cpython-310.pyc
ADDED
Binary file (4.98 kB). View file
|
|
audioldm/audio/__pycache__/stft.cpython-39.pyc
ADDED
Binary file (4.99 kB). View file
|
|
audioldm/audio/__pycache__/tools.cpython-310.pyc
ADDED
Binary file (2.18 kB). View file
|
|
audioldm/audio/__pycache__/tools.cpython-39.pyc
ADDED
Binary file (2.19 kB). View file
|
|
audioldm/audio/__pycache__/torch_tools.cpython-39.pyc
ADDED
Binary file (3.79 kB). View file
|
|
audioldm/audio/audio_processing.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import librosa.util as librosa_util
|
4 |
+
from scipy.signal import get_window
|
5 |
+
|
6 |
+
|
7 |
+
def window_sumsquare(
|
8 |
+
window,
|
9 |
+
n_frames,
|
10 |
+
hop_length,
|
11 |
+
win_length,
|
12 |
+
n_fft,
|
13 |
+
dtype=np.float32,
|
14 |
+
norm=None,
|
15 |
+
):
|
16 |
+
"""
|
17 |
+
# from librosa 0.6
|
18 |
+
Compute the sum-square envelope of a window function at a given hop length.
|
19 |
+
|
20 |
+
This is used to estimate modulation effects induced by windowing
|
21 |
+
observations in short-time fourier transforms.
|
22 |
+
|
23 |
+
Parameters
|
24 |
+
----------
|
25 |
+
window : string, tuple, number, callable, or list-like
|
26 |
+
Window specification, as in `get_window`
|
27 |
+
|
28 |
+
n_frames : int > 0
|
29 |
+
The number of analysis frames
|
30 |
+
|
31 |
+
hop_length : int > 0
|
32 |
+
The number of samples to advance between frames
|
33 |
+
|
34 |
+
win_length : [optional]
|
35 |
+
The length of the window function. By default, this matches `n_fft`.
|
36 |
+
|
37 |
+
n_fft : int > 0
|
38 |
+
The length of each analysis frame.
|
39 |
+
|
40 |
+
dtype : np.dtype
|
41 |
+
The data type of the output
|
42 |
+
|
43 |
+
Returns
|
44 |
+
-------
|
45 |
+
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
|
46 |
+
The sum-squared envelope of the window function
|
47 |
+
"""
|
48 |
+
if win_length is None:
|
49 |
+
win_length = n_fft
|
50 |
+
|
51 |
+
n = n_fft + hop_length * (n_frames - 1)
|
52 |
+
x = np.zeros(n, dtype=dtype)
|
53 |
+
|
54 |
+
# Compute the squared window at the desired length
|
55 |
+
win_sq = get_window(window, win_length, fftbins=True)
|
56 |
+
win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
|
57 |
+
win_sq = librosa_util.pad_center(win_sq, n_fft)
|
58 |
+
|
59 |
+
# Fill the envelope
|
60 |
+
for i in range(n_frames):
|
61 |
+
sample = i * hop_length
|
62 |
+
x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
|
63 |
+
return x
|
64 |
+
|
65 |
+
|
66 |
+
def griffin_lim(magnitudes, stft_fn, n_iters=30):
|
67 |
+
"""
|
68 |
+
PARAMS
|
69 |
+
------
|
70 |
+
magnitudes: spectrogram magnitudes
|
71 |
+
stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
|
72 |
+
"""
|
73 |
+
|
74 |
+
angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
|
75 |
+
angles = angles.astype(np.float32)
|
76 |
+
angles = torch.autograd.Variable(torch.from_numpy(angles))
|
77 |
+
signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
|
78 |
+
|
79 |
+
for i in range(n_iters):
|
80 |
+
_, angles = stft_fn.transform(signal)
|
81 |
+
signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
|
82 |
+
return signal
|
83 |
+
|
84 |
+
|
85 |
+
def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
|
86 |
+
"""
|
87 |
+
PARAMS
|
88 |
+
------
|
89 |
+
C: compression factor
|
90 |
+
"""
|
91 |
+
return normalize_fun(torch.clamp(x, min=clip_val) * C)
|
92 |
+
|
93 |
+
|
94 |
+
def dynamic_range_decompression(x, C=1):
|
95 |
+
"""
|
96 |
+
PARAMS
|
97 |
+
------
|
98 |
+
C: compression factor used to compress
|
99 |
+
"""
|
100 |
+
return torch.exp(x) / C
|
audioldm/audio/stft.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import numpy as np
|
4 |
+
from scipy.signal import get_window
|
5 |
+
from librosa.util import pad_center, tiny
|
6 |
+
from librosa.filters import mel as librosa_mel_fn
|
7 |
+
|
8 |
+
from audioldm.audio.audio_processing import (
|
9 |
+
dynamic_range_compression,
|
10 |
+
dynamic_range_decompression,
|
11 |
+
window_sumsquare,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
class STFT(torch.nn.Module):
|
16 |
+
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
17 |
+
|
18 |
+
def __init__(self, filter_length, hop_length, win_length, window="hann"):
|
19 |
+
super(STFT, self).__init__()
|
20 |
+
self.filter_length = filter_length
|
21 |
+
self.hop_length = hop_length
|
22 |
+
self.win_length = win_length
|
23 |
+
self.window = window
|
24 |
+
self.forward_transform = None
|
25 |
+
scale = self.filter_length / self.hop_length
|
26 |
+
fourier_basis = np.fft.fft(np.eye(self.filter_length))
|
27 |
+
|
28 |
+
cutoff = int((self.filter_length / 2 + 1))
|
29 |
+
fourier_basis = np.vstack(
|
30 |
+
[np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
|
31 |
+
)
|
32 |
+
|
33 |
+
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
|
34 |
+
inverse_basis = torch.FloatTensor(
|
35 |
+
np.linalg.pinv(scale * fourier_basis).T[:, None, :]
|
36 |
+
)
|
37 |
+
|
38 |
+
if window is not None:
|
39 |
+
assert filter_length >= win_length
|
40 |
+
# get window and zero center pad it to filter_length
|
41 |
+
fft_window = get_window(window, win_length, fftbins=True)
|
42 |
+
fft_window = pad_center(fft_window, filter_length)
|
43 |
+
fft_window = torch.from_numpy(fft_window).float()
|
44 |
+
|
45 |
+
# window the bases
|
46 |
+
forward_basis *= fft_window
|
47 |
+
inverse_basis *= fft_window
|
48 |
+
|
49 |
+
self.register_buffer("forward_basis", forward_basis.float())
|
50 |
+
self.register_buffer("inverse_basis", inverse_basis.float())
|
51 |
+
|
52 |
+
def transform(self, input_data):
|
53 |
+
device = self.forward_basis.device
|
54 |
+
input_data = input_data.to(device)
|
55 |
+
|
56 |
+
num_batches = input_data.size(0)
|
57 |
+
num_samples = input_data.size(1)
|
58 |
+
|
59 |
+
self.num_samples = num_samples
|
60 |
+
|
61 |
+
# similar to librosa, reflect-pad the input
|
62 |
+
input_data = input_data.view(num_batches, 1, num_samples)
|
63 |
+
input_data = F.pad(
|
64 |
+
input_data.unsqueeze(1),
|
65 |
+
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
|
66 |
+
mode="reflect",
|
67 |
+
)
|
68 |
+
input_data = input_data.squeeze(1)
|
69 |
+
|
70 |
+
forward_transform = F.conv1d(
|
71 |
+
input_data,
|
72 |
+
torch.autograd.Variable(self.forward_basis, requires_grad=False),
|
73 |
+
stride=self.hop_length,
|
74 |
+
padding=0,
|
75 |
+
)#.cpu()
|
76 |
+
|
77 |
+
cutoff = int((self.filter_length / 2) + 1)
|
78 |
+
real_part = forward_transform[:, :cutoff, :]
|
79 |
+
imag_part = forward_transform[:, cutoff:, :]
|
80 |
+
|
81 |
+
magnitude = torch.sqrt(real_part**2 + imag_part**2)
|
82 |
+
phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
|
83 |
+
|
84 |
+
return magnitude, phase
|
85 |
+
|
86 |
+
def inverse(self, magnitude, phase):
|
87 |
+
device = self.forward_basis.device
|
88 |
+
magnitude, phase = magnitude.to(device), phase.to(device)
|
89 |
+
|
90 |
+
recombine_magnitude_phase = torch.cat(
|
91 |
+
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
|
92 |
+
)
|
93 |
+
|
94 |
+
inverse_transform = F.conv_transpose1d(
|
95 |
+
recombine_magnitude_phase,
|
96 |
+
torch.autograd.Variable(self.inverse_basis, requires_grad=False),
|
97 |
+
stride=self.hop_length,
|
98 |
+
padding=0,
|
99 |
+
)
|
100 |
+
|
101 |
+
if self.window is not None:
|
102 |
+
window_sum = window_sumsquare(
|
103 |
+
self.window,
|
104 |
+
magnitude.size(-1),
|
105 |
+
hop_length=self.hop_length,
|
106 |
+
win_length=self.win_length,
|
107 |
+
n_fft=self.filter_length,
|
108 |
+
dtype=np.float32,
|
109 |
+
)
|
110 |
+
# remove modulation effects
|
111 |
+
approx_nonzero_indices = torch.from_numpy(
|
112 |
+
np.where(window_sum > tiny(window_sum))[0]
|
113 |
+
)
|
114 |
+
window_sum = torch.autograd.Variable(
|
115 |
+
torch.from_numpy(window_sum), requires_grad=False
|
116 |
+
)
|
117 |
+
window_sum = window_sum
|
118 |
+
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
|
119 |
+
approx_nonzero_indices
|
120 |
+
]
|
121 |
+
|
122 |
+
# scale by hop ratio
|
123 |
+
inverse_transform *= float(self.filter_length) / self.hop_length
|
124 |
+
|
125 |
+
inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
|
126 |
+
inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
|
127 |
+
|
128 |
+
return inverse_transform
|
129 |
+
|
130 |
+
def forward(self, input_data):
|
131 |
+
self.magnitude, self.phase = self.transform(input_data)
|
132 |
+
reconstruction = self.inverse(self.magnitude, self.phase)
|
133 |
+
return reconstruction
|
134 |
+
|
135 |
+
|
136 |
+
class TacotronSTFT(torch.nn.Module):
|
137 |
+
def __init__(
|
138 |
+
self,
|
139 |
+
filter_length,
|
140 |
+
hop_length,
|
141 |
+
win_length,
|
142 |
+
n_mel_channels,
|
143 |
+
sampling_rate,
|
144 |
+
mel_fmin,
|
145 |
+
mel_fmax,
|
146 |
+
):
|
147 |
+
super(TacotronSTFT, self).__init__()
|
148 |
+
self.n_mel_channels = n_mel_channels
|
149 |
+
self.sampling_rate = sampling_rate
|
150 |
+
self.stft_fn = STFT(filter_length, hop_length, win_length)
|
151 |
+
mel_basis = librosa_mel_fn(
|
152 |
+
sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
|
153 |
+
)
|
154 |
+
mel_basis = torch.from_numpy(mel_basis).float()
|
155 |
+
self.register_buffer("mel_basis", mel_basis)
|
156 |
+
|
157 |
+
def spectral_normalize(self, magnitudes, normalize_fun):
|
158 |
+
output = dynamic_range_compression(magnitudes, normalize_fun)
|
159 |
+
return output
|
160 |
+
|
161 |
+
def spectral_de_normalize(self, magnitudes):
|
162 |
+
output = dynamic_range_decompression(magnitudes)
|
163 |
+
return output
|
164 |
+
|
165 |
+
def mel_spectrogram(self, y, normalize_fun=torch.log):
|
166 |
+
"""Computes mel-spectrograms from a batch of waves
|
167 |
+
PARAMS
|
168 |
+
------
|
169 |
+
y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
|
170 |
+
|
171 |
+
RETURNS
|
172 |
+
-------
|
173 |
+
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
|
174 |
+
"""
|
175 |
+
assert torch.min(y.data) >= -1, torch.min(y.data)
|
176 |
+
assert torch.max(y.data) <= 1, torch.max(y.data)
|
177 |
+
|
178 |
+
magnitudes, phases = self.stft_fn.transform(y)
|
179 |
+
magnitudes = magnitudes.data
|
180 |
+
mel_output = torch.matmul(self.mel_basis, magnitudes)
|
181 |
+
mel_output = self.spectral_normalize(mel_output, normalize_fun)
|
182 |
+
energy = torch.norm(magnitudes, dim=1)
|
183 |
+
|
184 |
+
log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
|
185 |
+
|
186 |
+
return mel_output, log_magnitudes, energy
|
audioldm/audio/tools.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import torchaudio
|
4 |
+
|
5 |
+
|
6 |
+
def get_mel_from_wav(audio, _stft):
|
7 |
+
audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
|
8 |
+
audio = torch.autograd.Variable(audio, requires_grad=False)
|
9 |
+
melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
|
10 |
+
melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
|
11 |
+
log_magnitudes_stft = (
|
12 |
+
torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32)
|
13 |
+
)
|
14 |
+
energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
|
15 |
+
return melspec, log_magnitudes_stft, energy
|
16 |
+
|
17 |
+
|
18 |
+
def _pad_spec(fbank, target_length=1024):
|
19 |
+
n_frames = fbank.shape[0]
|
20 |
+
p = target_length - n_frames
|
21 |
+
# cut and pad
|
22 |
+
if p > 0:
|
23 |
+
m = torch.nn.ZeroPad2d((0, 0, 0, p))
|
24 |
+
fbank = m(fbank)
|
25 |
+
elif p < 0:
|
26 |
+
fbank = fbank[0:target_length, :]
|
27 |
+
|
28 |
+
if fbank.size(-1) % 2 != 0:
|
29 |
+
fbank = fbank[..., :-1]
|
30 |
+
|
31 |
+
return fbank
|
32 |
+
|
33 |
+
|
34 |
+
def pad_wav(waveform, segment_length):
|
35 |
+
waveform_length = waveform.shape[-1]
|
36 |
+
assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
|
37 |
+
if segment_length is None or waveform_length == segment_length:
|
38 |
+
return waveform
|
39 |
+
elif waveform_length > segment_length:
|
40 |
+
return waveform[:segment_length]
|
41 |
+
elif waveform_length < segment_length:
|
42 |
+
temp_wav = np.zeros((1, segment_length))
|
43 |
+
temp_wav[:, :waveform_length] = waveform
|
44 |
+
return temp_wav
|
45 |
+
|
46 |
+
def normalize_wav(waveform):
|
47 |
+
waveform = waveform - np.mean(waveform)
|
48 |
+
waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
|
49 |
+
return waveform * 0.5
|
50 |
+
|
51 |
+
|
52 |
+
def read_wav_file(filename, segment_length):
|
53 |
+
# waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
|
54 |
+
waveform, sr = torchaudio.load(filename) # Faster!!!
|
55 |
+
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
|
56 |
+
waveform = waveform.numpy()[0, ...]
|
57 |
+
waveform = normalize_wav(waveform)
|
58 |
+
waveform = waveform[None, ...]
|
59 |
+
waveform = pad_wav(waveform, segment_length)
|
60 |
+
|
61 |
+
waveform = waveform / np.max(np.abs(waveform))
|
62 |
+
waveform = 0.5 * waveform
|
63 |
+
|
64 |
+
return waveform
|
65 |
+
|
66 |
+
|
67 |
+
def wav_to_fbank(filename, target_length=1024, fn_STFT=None):
|
68 |
+
assert fn_STFT is not None
|
69 |
+
|
70 |
+
# mixup
|
71 |
+
waveform = read_wav_file(filename, target_length * 160) # hop size is 160
|
72 |
+
|
73 |
+
waveform = waveform[0, ...]
|
74 |
+
waveform = torch.FloatTensor(waveform)
|
75 |
+
|
76 |
+
fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
|
77 |
+
|
78 |
+
fbank = torch.FloatTensor(fbank.T)
|
79 |
+
log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)
|
80 |
+
|
81 |
+
fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
|
82 |
+
log_magnitudes_stft, target_length
|
83 |
+
)
|
84 |
+
|
85 |
+
return fbank, log_magnitudes_stft, waveform
|
audioldm/hifigan/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .models import Generator
|
2 |
+
|
3 |
+
|
4 |
+
class AttrDict(dict):
|
5 |
+
def __init__(self, *args, **kwargs):
|
6 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
7 |
+
self.__dict__ = self
|
audioldm/hifigan/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (569 Bytes). View file
|
|
audioldm/hifigan/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (574 Bytes). View file
|
|
audioldm/hifigan/__pycache__/models.cpython-310.pyc
ADDED
Binary file (3.73 kB). View file
|
|
audioldm/hifigan/__pycache__/models.cpython-39.pyc
ADDED
Binary file (3.73 kB). View file
|
|
audioldm/hifigan/__pycache__/utilities.cpython-310.pyc
ADDED
Binary file (2.48 kB). View file
|
|
audioldm/hifigan/__pycache__/utilities.cpython-39.pyc
ADDED
Binary file (2.37 kB). View file
|
|
audioldm/hifigan/models.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.nn import Conv1d, ConvTranspose1d
|
5 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
6 |
+
|
7 |
+
LRELU_SLOPE = 0.1
|
8 |
+
|
9 |
+
|
10 |
+
def init_weights(m, mean=0.0, std=0.01):
|
11 |
+
classname = m.__class__.__name__
|
12 |
+
if classname.find("Conv") != -1:
|
13 |
+
m.weight.data.normal_(mean, std)
|
14 |
+
|
15 |
+
|
16 |
+
def get_padding(kernel_size, dilation=1):
|
17 |
+
return int((kernel_size * dilation - dilation) / 2)
|
18 |
+
|
19 |
+
|
20 |
+
class ResBlock(torch.nn.Module):
|
21 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
22 |
+
super(ResBlock, self).__init__()
|
23 |
+
self.h = h
|
24 |
+
self.convs1 = nn.ModuleList(
|
25 |
+
[
|
26 |
+
weight_norm(
|
27 |
+
Conv1d(
|
28 |
+
channels,
|
29 |
+
channels,
|
30 |
+
kernel_size,
|
31 |
+
1,
|
32 |
+
dilation=dilation[0],
|
33 |
+
padding=get_padding(kernel_size, dilation[0]),
|
34 |
+
)
|
35 |
+
),
|
36 |
+
weight_norm(
|
37 |
+
Conv1d(
|
38 |
+
channels,
|
39 |
+
channels,
|
40 |
+
kernel_size,
|
41 |
+
1,
|
42 |
+
dilation=dilation[1],
|
43 |
+
padding=get_padding(kernel_size, dilation[1]),
|
44 |
+
)
|
45 |
+
),
|
46 |
+
weight_norm(
|
47 |
+
Conv1d(
|
48 |
+
channels,
|
49 |
+
channels,
|
50 |
+
kernel_size,
|
51 |
+
1,
|
52 |
+
dilation=dilation[2],
|
53 |
+
padding=get_padding(kernel_size, dilation[2]),
|
54 |
+
)
|
55 |
+
),
|
56 |
+
]
|
57 |
+
)
|
58 |
+
self.convs1.apply(init_weights)
|
59 |
+
|
60 |
+
self.convs2 = nn.ModuleList(
|
61 |
+
[
|
62 |
+
weight_norm(
|
63 |
+
Conv1d(
|
64 |
+
channels,
|
65 |
+
channels,
|
66 |
+
kernel_size,
|
67 |
+
1,
|
68 |
+
dilation=1,
|
69 |
+
padding=get_padding(kernel_size, 1),
|
70 |
+
)
|
71 |
+
),
|
72 |
+
weight_norm(
|
73 |
+
Conv1d(
|
74 |
+
channels,
|
75 |
+
channels,
|
76 |
+
kernel_size,
|
77 |
+
1,
|
78 |
+
dilation=1,
|
79 |
+
padding=get_padding(kernel_size, 1),
|
80 |
+
)
|
81 |
+
),
|
82 |
+
weight_norm(
|
83 |
+
Conv1d(
|
84 |
+
channels,
|
85 |
+
channels,
|
86 |
+
kernel_size,
|
87 |
+
1,
|
88 |
+
dilation=1,
|
89 |
+
padding=get_padding(kernel_size, 1),
|
90 |
+
)
|
91 |
+
),
|
92 |
+
]
|
93 |
+
)
|
94 |
+
self.convs2.apply(init_weights)
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
98 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
99 |
+
xt = c1(xt)
|
100 |
+
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
101 |
+
xt = c2(xt)
|
102 |
+
x = xt + x
|
103 |
+
return x
|
104 |
+
|
105 |
+
def remove_weight_norm(self):
|
106 |
+
for l in self.convs1:
|
107 |
+
remove_weight_norm(l)
|
108 |
+
for l in self.convs2:
|
109 |
+
remove_weight_norm(l)
|
110 |
+
|
111 |
+
|
112 |
+
class Generator(torch.nn.Module):
|
113 |
+
def __init__(self, h):
|
114 |
+
super(Generator, self).__init__()
|
115 |
+
self.h = h
|
116 |
+
self.num_kernels = len(h.resblock_kernel_sizes)
|
117 |
+
self.num_upsamples = len(h.upsample_rates)
|
118 |
+
self.conv_pre = weight_norm(
|
119 |
+
Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
|
120 |
+
)
|
121 |
+
resblock = ResBlock
|
122 |
+
|
123 |
+
self.ups = nn.ModuleList()
|
124 |
+
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
125 |
+
self.ups.append(
|
126 |
+
weight_norm(
|
127 |
+
ConvTranspose1d(
|
128 |
+
h.upsample_initial_channel // (2**i),
|
129 |
+
h.upsample_initial_channel // (2 ** (i + 1)),
|
130 |
+
k,
|
131 |
+
u,
|
132 |
+
padding=(k - u) // 2,
|
133 |
+
)
|
134 |
+
)
|
135 |
+
)
|
136 |
+
|
137 |
+
self.resblocks = nn.ModuleList()
|
138 |
+
for i in range(len(self.ups)):
|
139 |
+
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
140 |
+
for j, (k, d) in enumerate(
|
141 |
+
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
|
142 |
+
):
|
143 |
+
self.resblocks.append(resblock(h, ch, k, d))
|
144 |
+
|
145 |
+
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
146 |
+
self.ups.apply(init_weights)
|
147 |
+
self.conv_post.apply(init_weights)
|
148 |
+
|
149 |
+
def forward(self, x):
|
150 |
+
x = self.conv_pre(x)
|
151 |
+
for i in range(self.num_upsamples):
|
152 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
153 |
+
x = self.ups[i](x)
|
154 |
+
xs = None
|
155 |
+
for j in range(self.num_kernels):
|
156 |
+
if xs is None:
|
157 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
158 |
+
else:
|
159 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
160 |
+
x = xs / self.num_kernels
|
161 |
+
x = F.leaky_relu(x)
|
162 |
+
x = self.conv_post(x)
|
163 |
+
x = torch.tanh(x)
|
164 |
+
|
165 |
+
return x
|
166 |
+
|
167 |
+
def remove_weight_norm(self):
|
168 |
+
# print("Removing weight norm...")
|
169 |
+
for l in self.ups:
|
170 |
+
remove_weight_norm(l)
|
171 |
+
for l in self.resblocks:
|
172 |
+
l.remove_weight_norm()
|
173 |
+
remove_weight_norm(self.conv_pre)
|
174 |
+
remove_weight_norm(self.conv_post)
|
audioldm/hifigan/utilities.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
import audioldm.hifigan as hifigan
|
8 |
+
|
9 |
+
HIFIGAN_16K_64 = {
|
10 |
+
"resblock": "1",
|
11 |
+
"num_gpus": 6,
|
12 |
+
"batch_size": 16,
|
13 |
+
"learning_rate": 0.0002,
|
14 |
+
"adam_b1": 0.8,
|
15 |
+
"adam_b2": 0.99,
|
16 |
+
"lr_decay": 0.999,
|
17 |
+
"seed": 1234,
|
18 |
+
"upsample_rates": [5, 4, 2, 2, 2],
|
19 |
+
"upsample_kernel_sizes": [16, 16, 8, 4, 4],
|
20 |
+
"upsample_initial_channel": 1024,
|
21 |
+
"resblock_kernel_sizes": [3, 7, 11],
|
22 |
+
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
23 |
+
"segment_size": 8192,
|
24 |
+
"num_mels": 64,
|
25 |
+
"num_freq": 1025,
|
26 |
+
"n_fft": 1024,
|
27 |
+
"hop_size": 160,
|
28 |
+
"win_size": 1024,
|
29 |
+
"sampling_rate": 16000,
|
30 |
+
"fmin": 0,
|
31 |
+
"fmax": 8000,
|
32 |
+
"fmax_for_loss": None,
|
33 |
+
"num_workers": 4,
|
34 |
+
"dist_config": {
|
35 |
+
"dist_backend": "nccl",
|
36 |
+
"dist_url": "tcp://localhost:54321",
|
37 |
+
"world_size": 1,
|
38 |
+
},
|
39 |
+
}
|
40 |
+
|
41 |
+
|
42 |
+
def get_available_checkpoint_keys(model, ckpt):
|
43 |
+
print("==> Attemp to reload from %s" % ckpt)
|
44 |
+
state_dict = torch.load(ckpt)["state_dict"]
|
45 |
+
current_state_dict = model.state_dict()
|
46 |
+
new_state_dict = {}
|
47 |
+
for k in state_dict.keys():
|
48 |
+
if (
|
49 |
+
k in current_state_dict.keys()
|
50 |
+
and current_state_dict[k].size() == state_dict[k].size()
|
51 |
+
):
|
52 |
+
new_state_dict[k] = state_dict[k]
|
53 |
+
else:
|
54 |
+
print("==> WARNING: Skipping %s" % k)
|
55 |
+
print(
|
56 |
+
"%s out of %s keys are matched"
|
57 |
+
% (len(new_state_dict.keys()), len(state_dict.keys()))
|
58 |
+
)
|
59 |
+
return new_state_dict
|
60 |
+
|
61 |
+
|
62 |
+
def get_param_num(model):
|
63 |
+
num_param = sum(param.numel() for param in model.parameters())
|
64 |
+
return num_param
|
65 |
+
|
66 |
+
|
67 |
+
def get_vocoder(config, device):
|
68 |
+
config = hifigan.AttrDict(HIFIGAN_16K_64)
|
69 |
+
vocoder = hifigan.Generator(config)
|
70 |
+
vocoder.eval()
|
71 |
+
vocoder.remove_weight_norm()
|
72 |
+
vocoder.to(device)
|
73 |
+
return vocoder
|
74 |
+
|
75 |
+
|
76 |
+
def vocoder_infer(mels, vocoder, lengths=None):
|
77 |
+
vocoder.eval()
|
78 |
+
with torch.no_grad():
|
79 |
+
wavs = vocoder(mels).squeeze(1)
|
80 |
+
|
81 |
+
wavs = (wavs.cpu().numpy() * 32768).astype("int16")
|
82 |
+
|
83 |
+
if lengths is not None:
|
84 |
+
wavs = wavs[:, :lengths]
|
85 |
+
|
86 |
+
return wavs
|
audioldm/latent_diffusion/__init__.py
ADDED
File without changes
|
audioldm/latent_diffusion/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (157 Bytes). View file
|
|
audioldm/latent_diffusion/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (164 Bytes). View file
|
|
audioldm/latent_diffusion/__pycache__/attention.cpython-310.pyc
ADDED
Binary file (11.4 kB). View file
|
|
audioldm/latent_diffusion/__pycache__/attention.cpython-39.pyc
ADDED
Binary file (11.4 kB). View file
|
|
audioldm/latent_diffusion/__pycache__/ddim.cpython-310.pyc
ADDED
Binary file (7.2 kB). View file
|
|
audioldm/latent_diffusion/__pycache__/ddim.cpython-39.pyc
ADDED
Binary file (7.11 kB). View file
|
|
audioldm/latent_diffusion/__pycache__/ddpm.cpython-310.pyc
ADDED
Binary file (11.1 kB). View file
|
|
audioldm/latent_diffusion/__pycache__/ddpm.cpython-39.pyc
ADDED
Binary file (11 kB). View file
|
|
audioldm/latent_diffusion/__pycache__/ema.cpython-310.pyc
ADDED
Binary file (3.01 kB). View file
|
|
audioldm/latent_diffusion/__pycache__/ema.cpython-39.pyc
ADDED
Binary file (3 kB). View file
|
|
audioldm/latent_diffusion/__pycache__/openaimodel.cpython-39.pyc
ADDED
Binary file (23.7 kB). View file
|
|
audioldm/latent_diffusion/__pycache__/util.cpython-310.pyc
ADDED
Binary file (9.53 kB). View file
|
|
audioldm/latent_diffusion/__pycache__/util.cpython-39.pyc
ADDED
Binary file (9.6 kB). View file
|
|
audioldm/latent_diffusion/attention.py
ADDED
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from inspect import isfunction
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import nn
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
+
from audioldm.latent_diffusion.util import checkpoint
|
9 |
+
|
10 |
+
|
11 |
+
def exists(val):
|
12 |
+
return val is not None
|
13 |
+
|
14 |
+
|
15 |
+
def uniq(arr):
|
16 |
+
return {el: True for el in arr}.keys()
|
17 |
+
|
18 |
+
|
19 |
+
def default(val, d):
|
20 |
+
if exists(val):
|
21 |
+
return val
|
22 |
+
return d() if isfunction(d) else d
|
23 |
+
|
24 |
+
|
25 |
+
def max_neg_value(t):
|
26 |
+
return -torch.finfo(t.dtype).max
|
27 |
+
|
28 |
+
|
29 |
+
def init_(tensor):
|
30 |
+
dim = tensor.shape[-1]
|
31 |
+
std = 1 / math.sqrt(dim)
|
32 |
+
tensor.uniform_(-std, std)
|
33 |
+
return tensor
|
34 |
+
|
35 |
+
|
36 |
+
# feedforward
|
37 |
+
class GEGLU(nn.Module):
|
38 |
+
def __init__(self, dim_in, dim_out):
|
39 |
+
super().__init__()
|
40 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
44 |
+
return x * F.gelu(gate)
|
45 |
+
|
46 |
+
|
47 |
+
class FeedForward(nn.Module):
|
48 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
49 |
+
super().__init__()
|
50 |
+
inner_dim = int(dim * mult)
|
51 |
+
dim_out = default(dim_out, dim)
|
52 |
+
project_in = (
|
53 |
+
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
54 |
+
if not glu
|
55 |
+
else GEGLU(dim, inner_dim)
|
56 |
+
)
|
57 |
+
|
58 |
+
self.net = nn.Sequential(
|
59 |
+
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
60 |
+
)
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
return self.net(x)
|
64 |
+
|
65 |
+
|
66 |
+
def zero_module(module):
|
67 |
+
"""
|
68 |
+
Zero out the parameters of a module and return it.
|
69 |
+
"""
|
70 |
+
for p in module.parameters():
|
71 |
+
p.detach().zero_()
|
72 |
+
return module
|
73 |
+
|
74 |
+
|
75 |
+
def Normalize(in_channels):
|
76 |
+
return torch.nn.GroupNorm(
|
77 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
78 |
+
)
|
79 |
+
|
80 |
+
|
81 |
+
class LinearAttention(nn.Module):
|
82 |
+
def __init__(self, dim, heads=4, dim_head=32):
|
83 |
+
super().__init__()
|
84 |
+
self.heads = heads
|
85 |
+
hidden_dim = dim_head * heads
|
86 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
87 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
b, c, h, w = x.shape
|
91 |
+
qkv = self.to_qkv(x)
|
92 |
+
q, k, v = rearrange(
|
93 |
+
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
|
94 |
+
)
|
95 |
+
k = k.softmax(dim=-1)
|
96 |
+
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
97 |
+
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
98 |
+
out = rearrange(
|
99 |
+
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
|
100 |
+
)
|
101 |
+
return self.to_out(out)
|
102 |
+
|
103 |
+
|
104 |
+
class SpatialSelfAttention(nn.Module):
|
105 |
+
def __init__(self, in_channels):
|
106 |
+
super().__init__()
|
107 |
+
self.in_channels = in_channels
|
108 |
+
|
109 |
+
self.norm = Normalize(in_channels)
|
110 |
+
self.q = torch.nn.Conv2d(
|
111 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
112 |
+
)
|
113 |
+
self.k = torch.nn.Conv2d(
|
114 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
115 |
+
)
|
116 |
+
self.v = torch.nn.Conv2d(
|
117 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
118 |
+
)
|
119 |
+
self.proj_out = torch.nn.Conv2d(
|
120 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
121 |
+
)
|
122 |
+
|
123 |
+
def forward(self, x):
|
124 |
+
h_ = x
|
125 |
+
h_ = self.norm(h_)
|
126 |
+
q = self.q(h_)
|
127 |
+
k = self.k(h_)
|
128 |
+
v = self.v(h_)
|
129 |
+
|
130 |
+
# compute attention
|
131 |
+
b, c, h, w = q.shape
|
132 |
+
q = rearrange(q, "b c h w -> b (h w) c")
|
133 |
+
k = rearrange(k, "b c h w -> b c (h w)")
|
134 |
+
w_ = torch.einsum("bij,bjk->bik", q, k)
|
135 |
+
|
136 |
+
w_ = w_ * (int(c) ** (-0.5))
|
137 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
138 |
+
|
139 |
+
# attend to values
|
140 |
+
v = rearrange(v, "b c h w -> b c (h w)")
|
141 |
+
w_ = rearrange(w_, "b i j -> b j i")
|
142 |
+
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
143 |
+
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
144 |
+
h_ = self.proj_out(h_)
|
145 |
+
|
146 |
+
return x + h_
|
147 |
+
|
148 |
+
|
149 |
+
class CrossAttention(nn.Module):
|
150 |
+
"""
|
151 |
+
### Cross Attention Layer
|
152 |
+
This falls-back to self-attention when conditional embeddings are not specified.
|
153 |
+
"""
|
154 |
+
|
155 |
+
# use_flash_attention: bool = True
|
156 |
+
use_flash_attention: bool = False
|
157 |
+
|
158 |
+
def __init__(
|
159 |
+
self,
|
160 |
+
query_dim,
|
161 |
+
context_dim=None,
|
162 |
+
heads=8,
|
163 |
+
dim_head=64,
|
164 |
+
dropout=0.0,
|
165 |
+
is_inplace: bool = True,
|
166 |
+
):
|
167 |
+
# def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
|
168 |
+
"""
|
169 |
+
:param d_model: is the input embedding size
|
170 |
+
:param n_heads: is the number of attention heads
|
171 |
+
:param d_head: is the size of a attention head
|
172 |
+
:param d_cond: is the size of the conditional embeddings
|
173 |
+
:param is_inplace: specifies whether to perform the attention softmax computation inplace to
|
174 |
+
save memory
|
175 |
+
"""
|
176 |
+
super().__init__()
|
177 |
+
|
178 |
+
self.is_inplace = is_inplace
|
179 |
+
self.n_heads = heads
|
180 |
+
self.d_head = dim_head
|
181 |
+
|
182 |
+
# Attention scaling factor
|
183 |
+
self.scale = dim_head**-0.5
|
184 |
+
|
185 |
+
# The normal self-attention layer
|
186 |
+
if context_dim is None:
|
187 |
+
context_dim = query_dim
|
188 |
+
|
189 |
+
# Query, key and value mappings
|
190 |
+
d_attn = dim_head * heads
|
191 |
+
self.to_q = nn.Linear(query_dim, d_attn, bias=False)
|
192 |
+
self.to_k = nn.Linear(context_dim, d_attn, bias=False)
|
193 |
+
self.to_v = nn.Linear(context_dim, d_attn, bias=False)
|
194 |
+
|
195 |
+
# Final linear layer
|
196 |
+
self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout))
|
197 |
+
|
198 |
+
# Setup [flash attention](https://github.com/HazyResearch/flash-attention).
|
199 |
+
# Flash attention is only used if it's installed
|
200 |
+
# and `CrossAttention.use_flash_attention` is set to `True`.
|
201 |
+
try:
|
202 |
+
# You can install flash attention by cloning their Github repo,
|
203 |
+
# [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
|
204 |
+
# and then running `python setup.py install`
|
205 |
+
from flash_attn.flash_attention import FlashAttention
|
206 |
+
|
207 |
+
self.flash = FlashAttention()
|
208 |
+
# Set the scale for scaled dot-product attention.
|
209 |
+
self.flash.softmax_scale = self.scale
|
210 |
+
# Set to `None` if it's not installed
|
211 |
+
except ImportError:
|
212 |
+
self.flash = None
|
213 |
+
|
214 |
+
def forward(self, x, context=None, mask=None):
|
215 |
+
"""
|
216 |
+
:param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
|
217 |
+
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
|
218 |
+
"""
|
219 |
+
|
220 |
+
# If `cond` is `None` we perform self attention
|
221 |
+
has_cond = context is not None
|
222 |
+
if not has_cond:
|
223 |
+
context = x
|
224 |
+
|
225 |
+
# Get query, key and value vectors
|
226 |
+
q = self.to_q(x)
|
227 |
+
k = self.to_k(context)
|
228 |
+
v = self.to_v(context)
|
229 |
+
|
230 |
+
# Use flash attention if it's available and the head size is less than or equal to `128`
|
231 |
+
if (
|
232 |
+
CrossAttention.use_flash_attention
|
233 |
+
and self.flash is not None
|
234 |
+
and not has_cond
|
235 |
+
and self.d_head <= 128
|
236 |
+
):
|
237 |
+
return self.flash_attention(q, k, v)
|
238 |
+
# Otherwise, fallback to normal attention
|
239 |
+
else:
|
240 |
+
return self.normal_attention(q, k, v)
|
241 |
+
|
242 |
+
def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
243 |
+
"""
|
244 |
+
#### Flash Attention
|
245 |
+
:param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
246 |
+
:param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
247 |
+
:param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
248 |
+
"""
|
249 |
+
|
250 |
+
# Get batch size and number of elements along sequence axis (`width * height`)
|
251 |
+
batch_size, seq_len, _ = q.shape
|
252 |
+
|
253 |
+
# Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
|
254 |
+
# shape `[batch_size, seq_len, 3, n_heads * d_head]`
|
255 |
+
qkv = torch.stack((q, k, v), dim=2)
|
256 |
+
# Split the heads
|
257 |
+
qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
|
258 |
+
|
259 |
+
# Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
|
260 |
+
# fit this size.
|
261 |
+
if self.d_head <= 32:
|
262 |
+
pad = 32 - self.d_head
|
263 |
+
elif self.d_head <= 64:
|
264 |
+
pad = 64 - self.d_head
|
265 |
+
elif self.d_head <= 128:
|
266 |
+
pad = 128 - self.d_head
|
267 |
+
else:
|
268 |
+
raise ValueError(f"Head size ${self.d_head} too large for Flash Attention")
|
269 |
+
|
270 |
+
# Pad the heads
|
271 |
+
if pad:
|
272 |
+
qkv = torch.cat(
|
273 |
+
(qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
|
274 |
+
)
|
275 |
+
|
276 |
+
# Compute attention
|
277 |
+
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
|
278 |
+
# This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
|
279 |
+
# TODO here I add the dtype changing
|
280 |
+
out, _ = self.flash(qkv.type(torch.float16))
|
281 |
+
# Truncate the extra head size
|
282 |
+
out = out[:, :, :, : self.d_head].float()
|
283 |
+
# Reshape to `[batch_size, seq_len, n_heads * d_head]`
|
284 |
+
out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
|
285 |
+
|
286 |
+
# Map to `[batch_size, height * width, d_model]` with a linear layer
|
287 |
+
return self.to_out(out)
|
288 |
+
|
289 |
+
def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
290 |
+
"""
|
291 |
+
#### Normal Attention
|
292 |
+
|
293 |
+
:param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
294 |
+
:param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
295 |
+
:param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
296 |
+
"""
|
297 |
+
|
298 |
+
# Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
|
299 |
+
q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32]
|
300 |
+
k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32]
|
301 |
+
v = v.view(*v.shape[:2], self.n_heads, -1)
|
302 |
+
|
303 |
+
# Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
|
304 |
+
attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
|
305 |
+
|
306 |
+
# Compute softmax
|
307 |
+
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
|
308 |
+
if self.is_inplace:
|
309 |
+
half = attn.shape[0] // 2
|
310 |
+
attn[half:] = attn[half:].softmax(dim=-1)
|
311 |
+
attn[:half] = attn[:half].softmax(dim=-1)
|
312 |
+
else:
|
313 |
+
attn = attn.softmax(dim=-1)
|
314 |
+
|
315 |
+
# Compute attention output
|
316 |
+
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
|
317 |
+
# attn: [bs, 20, 64, 1]
|
318 |
+
# v: [bs, 1, 20, 32]
|
319 |
+
out = torch.einsum("bhij,bjhd->bihd", attn, v)
|
320 |
+
# Reshape to `[batch_size, height * width, n_heads * d_head]`
|
321 |
+
out = out.reshape(*out.shape[:2], -1)
|
322 |
+
# Map to `[batch_size, height * width, d_model]` with a linear layer
|
323 |
+
return self.to_out(out)
|
324 |
+
|
325 |
+
|
326 |
+
# class CrossAttention(nn.Module):
|
327 |
+
# def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
328 |
+
# super().__init__()
|
329 |
+
# inner_dim = dim_head * heads
|
330 |
+
# context_dim = default(context_dim, query_dim)
|
331 |
+
|
332 |
+
# self.scale = dim_head ** -0.5
|
333 |
+
# self.heads = heads
|
334 |
+
|
335 |
+
# self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
336 |
+
# self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
337 |
+
# self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
338 |
+
|
339 |
+
# self.to_out = nn.Sequential(
|
340 |
+
# nn.Linear(inner_dim, query_dim),
|
341 |
+
# nn.Dropout(dropout)
|
342 |
+
# )
|
343 |
+
|
344 |
+
# def forward(self, x, context=None, mask=None):
|
345 |
+
# h = self.heads
|
346 |
+
|
347 |
+
# q = self.to_q(x)
|
348 |
+
# context = default(context, x)
|
349 |
+
# k = self.to_k(context)
|
350 |
+
# v = self.to_v(context)
|
351 |
+
|
352 |
+
# q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
353 |
+
|
354 |
+
# sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
355 |
+
|
356 |
+
# if exists(mask):
|
357 |
+
# mask = rearrange(mask, 'b ... -> b (...)')
|
358 |
+
# max_neg_value = -torch.finfo(sim.dtype).max
|
359 |
+
# mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
360 |
+
# sim.masked_fill_(~mask, max_neg_value)
|
361 |
+
|
362 |
+
# # attention, what we cannot get enough of
|
363 |
+
# attn = sim.softmax(dim=-1)
|
364 |
+
|
365 |
+
# out = einsum('b i j, b j d -> b i d', attn, v)
|
366 |
+
# out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
367 |
+
# return self.to_out(out)
|
368 |
+
|
369 |
+
|
370 |
+
class BasicTransformerBlock(nn.Module):
|
371 |
+
def __init__(
|
372 |
+
self,
|
373 |
+
dim,
|
374 |
+
n_heads,
|
375 |
+
d_head,
|
376 |
+
dropout=0.0,
|
377 |
+
context_dim=None,
|
378 |
+
gated_ff=True,
|
379 |
+
checkpoint=True,
|
380 |
+
):
|
381 |
+
super().__init__()
|
382 |
+
self.attn1 = CrossAttention(
|
383 |
+
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
384 |
+
) # is a self-attention
|
385 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
386 |
+
self.attn2 = CrossAttention(
|
387 |
+
query_dim=dim,
|
388 |
+
context_dim=context_dim,
|
389 |
+
heads=n_heads,
|
390 |
+
dim_head=d_head,
|
391 |
+
dropout=dropout,
|
392 |
+
) # is self-attn if context is none
|
393 |
+
self.norm1 = nn.LayerNorm(dim)
|
394 |
+
self.norm2 = nn.LayerNorm(dim)
|
395 |
+
self.norm3 = nn.LayerNorm(dim)
|
396 |
+
self.checkpoint = checkpoint
|
397 |
+
|
398 |
+
def forward(self, x, context=None):
|
399 |
+
if context is None:
|
400 |
+
return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
|
401 |
+
else:
|
402 |
+
return checkpoint(
|
403 |
+
self._forward, (x, context), self.parameters(), self.checkpoint
|
404 |
+
)
|
405 |
+
|
406 |
+
def _forward(self, x, context=None):
|
407 |
+
x = self.attn1(self.norm1(x)) + x
|
408 |
+
x = self.attn2(self.norm2(x), context=context) + x
|
409 |
+
x = self.ff(self.norm3(x)) + x
|
410 |
+
return x
|
411 |
+
|
412 |
+
|
413 |
+
class SpatialTransformer(nn.Module):
|
414 |
+
"""
|
415 |
+
Transformer block for image-like data.
|
416 |
+
First, project the input (aka embedding)
|
417 |
+
and reshape to b, t, d.
|
418 |
+
Then apply standard transformer action.
|
419 |
+
Finally, reshape to image
|
420 |
+
"""
|
421 |
+
|
422 |
+
def __init__(
|
423 |
+
self,
|
424 |
+
in_channels,
|
425 |
+
n_heads,
|
426 |
+
d_head,
|
427 |
+
depth=1,
|
428 |
+
dropout=0.0,
|
429 |
+
context_dim=None,
|
430 |
+
no_context=False,
|
431 |
+
):
|
432 |
+
super().__init__()
|
433 |
+
|
434 |
+
if no_context:
|
435 |
+
context_dim = None
|
436 |
+
|
437 |
+
self.in_channels = in_channels
|
438 |
+
inner_dim = n_heads * d_head
|
439 |
+
self.norm = Normalize(in_channels)
|
440 |
+
|
441 |
+
self.proj_in = nn.Conv2d(
|
442 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
443 |
+
)
|
444 |
+
|
445 |
+
self.transformer_blocks = nn.ModuleList(
|
446 |
+
[
|
447 |
+
BasicTransformerBlock(
|
448 |
+
inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
|
449 |
+
)
|
450 |
+
for d in range(depth)
|
451 |
+
]
|
452 |
+
)
|
453 |
+
|
454 |
+
self.proj_out = zero_module(
|
455 |
+
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
456 |
+
)
|
457 |
+
|
458 |
+
def forward(self, x, context=None):
|
459 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
460 |
+
b, c, h, w = x.shape
|
461 |
+
x_in = x
|
462 |
+
x = self.norm(x)
|
463 |
+
x = self.proj_in(x)
|
464 |
+
x = rearrange(x, "b c h w -> b (h w) c")
|
465 |
+
for block in self.transformer_blocks:
|
466 |
+
x = block(x, context=context)
|
467 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
468 |
+
x = self.proj_out(x)
|
469 |
+
return x + x_in
|