3v324v23 committed on
Commit
a84a65c
1 Parent(s): 28cda0c
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +1 -0
  2. README copy.md +107 -0
  3. app.py +199 -0
  4. audiocaps_test_struct.tsv +3 -0
  5. data/audiocaps_test_struct.tsv +0 -0
  6. data/musiccaps_test_16000_struct.tsv +0 -0
  7. infer.sh +20 -0
  8. ldm/__pycache__/util.cpython-38.pyc +0 -0
  9. ldm/__pycache__/util.cpython-39.pyc +0 -0
  10. ldm/data/__pycache__/joinaudiodataset_anylen.cpython-38.pyc +0 -0
  11. ldm/data/__pycache__/joinaudiodataset_struct_sample_anylen.cpython-38.pyc +0 -0
  12. ldm/data/joinaudiodataset_anylen.py +330 -0
  13. ldm/data/joinaudiodataset_struct_sample_anylen.py +380 -0
  14. ldm/data/tsv_dirs/full_data/V1_new/audiocaps_train_16000.tsv +3 -0
  15. ldm/data/tsv_dirs/full_data/V2/MACS.tsv +3 -0
  16. ldm/data/tsv_dirs/full_data/V2/WavText5K.tsv +3 -0
  17. ldm/data/tsv_dirs/full_data/V2/adobe.tsv +3 -0
  18. ldm/data/tsv_dirs/full_data/V2/audiostock.tsv +3 -0
  19. ldm/data/tsv_dirs/full_data/V2/epidemic_sound.tsv +3 -0
  20. ldm/data/tsv_dirs/full_data/caps_struct/audiocaps_train_16000_struct2.tsv +3 -0
  21. ldm/data/txt_spec_dataset.py +171 -0
  22. ldm/data/video_spec_maa2_dataset.py +837 -0
  23. ldm/lr_scheduler.py +98 -0
  24. ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
  25. ldm/models/__pycache__/autoencoder.cpython-39.pyc +0 -0
  26. ldm/models/__pycache__/autoencoder1d.cpython-38.pyc +0 -0
  27. ldm/models/autoencoder.py +503 -0
  28. ldm/models/autoencoder1d.py +517 -0
  29. ldm/models/diffusion/__init__.py +0 -0
  30. ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
  31. ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
  32. ldm/models/diffusion/__pycache__/cfm1_audio.cpython-38.pyc +0 -0
  33. ldm/models/diffusion/__pycache__/cfm1_audio.cpython-39.pyc +0 -0
  34. ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc +0 -0
  35. ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc +0 -0
  36. ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc +0 -0
  37. ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc +0 -0
  38. ldm/models/diffusion/__pycache__/ddpm_audio.cpython-38.pyc +0 -0
  39. ldm/models/diffusion/__pycache__/ddpm_audio.cpython-39.pyc +0 -0
  40. ldm/models/diffusion/__pycache__/plms.cpython-38.pyc +0 -0
  41. ldm/models/diffusion/__pycache__/plms.cpython-39.pyc +0 -0
  42. ldm/models/diffusion/audioldm.py +818 -0
  43. ldm/models/diffusion/cfm1_audio.py +312 -0
  44. ldm/models/diffusion/cfm1_audio_sampler.py +105 -0
  45. ldm/models/diffusion/classifier.py +267 -0
  46. ldm/models/diffusion/ddim.py +262 -0
  47. ldm/models/diffusion/ddpm.py +1461 -0
  48. ldm/models/diffusion/ddpm_audio.py +865 -0
  49. ldm/models/diffusion/plms.py +236 -0
  50. ldm/models/diffusion/transport/__init__.py +73 -0
.gitattributes CHANGED
@@ -32,4 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
+ *.tsv filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
README copy.md ADDED
@@ -0,0 +1,107 @@
1
+ # Make-An-Audio 3: Transforming Text into Audio via Flow-based Large Diffusion Transformers
2
+
3
+ PyTorch Implementation of [Lumina-t2x](https://arxiv.org/abs/2405.05945)
4
+
5
+ We will soon release our implementation and pretrained models as open source in this repository.
6
+
7
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2305.18474)
8
+ [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/AIGC-Audio/Lumina-Audio)
9
+ [![GitHub Stars](https://img.shields.io/github/stars/Text-to-Audio/Make-An-Audio-3?style=social)](https://github.com/Text-to-Audio/Make-An-Audio-3)
10
+
11
+ ## Use pretrained model
12
+ We provide our implementation and pretrained models as open source in this repository.
13
+
14
+ Visit our [demo page](https://make-an-audio-2.github.io/) for audio samples.
15
+ ## Quick Start
16
+ ### Pretrained Models
17
+ Simply download the weights from [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/Alpha-VLLM/Lumina-T2Music); a possible local layout for these checkpoints is sketched after the list below.
18
+ - Text Encoder: [FLAN-T5-Large](https://huggingface.co/google/flan-t5-large)
19
+ - VAE: Make-An-Audio 2, finetuned from [Make an Audio](https://github.com/Text-to-Audio/Make-An-Audio)
20
+ - Decoder: [Vocoder](https://github.com/NVIDIA/BigVGAN)
21
+ - `Music` Checkpoints: [huggingface](https://huggingface.co/Alpha-VLLM/Lumina-T2Music), `Audio` Checkpoints: [huggingface]()
22
+
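Based on the checkpoint paths that appear in `infer.sh` and `app.py` in this commit, a plausible local layout for the downloaded weights is the following (directory names are taken from those scripts; exact contents may differ):

```
useful_ckpts/
├── bigvnat/                        # BigVGAN vocoder (--vocoder-ckpt)
├── music_generation/119.ckpt       # music checkpoint (-r with the txt2music config)
└── audio_generation/324.ckpt       # audio checkpoint (-r with the txt2audio config)
```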
23
+ ### Generate audio/music from text
24
+ ```
25
+ python3 scripts/txt2audio_for_2cap_flow.py \
26
+ --outdir output_dir -r checkpoints_last.ckpt -b configs/txt2audio-cfm1-cfg-LargeDiT3.yaml --scale 3.0 \
27
+ --vocoder-ckpt useful_ckpts/bigvnat --test-dataset audiocaps
28
+ ```
29
+
30
+ ### Generate audio/music from audiocaps or musiccaps test dataset
31
+ - remember to change `config["test_dataset"]` accordingly
32
+ ```
33
+ python3 scripts/txt2audio_for_2cap_flow.py \
34
+ --outdir output_dir -r checkpoints_last.ckpt -b configs/txt2audio-cfm1-cfg-LargeDiT3.yaml --scale 3.0 \
35
+ --vocoder-ckpt useful_ckpts/bigvnat --test-dataset testset
36
+ ```
37
+
38
+ ### Generate audio/music from video
39
+ ```
40
+ python3 scripts/video2audio_flow.py \
41
+ --outdir output_dir -r checkpoints_last.ckpt -b configs/txt2audio-cfm1-cfg-LargeDiT3.yaml --scale 3.0 \
42
+ --vocoder-ckpt useful_ckpts/bigvnat --test-dataset vggsound
43
+ ```
44
+
45
+ ## Train
46
+ ### Data preparation
47
+ - We can't provide dataset download links due to copyright issues. We provide the processing code to generate mel-spectrograms, count audio durations, and generate structured captions.
48
+ - Before training, we need to collect the dataset information into a tsv file with the following columns: name (an id for each audio), dataset (which dataset the audio belongs to), audio_path (the path of the .wav file), caption (the caption of the audio), mel_path (the path of the processed mel-spectrogram file for each audio), and duration (the duration of the audio). We provide a tsv file of the audiocaps test set, audiocaps_test_struct.tsv, as a sample; a minimal construction sketch follows this list.
49
+ - We provide a tsv file of the audiocaps test set: ./audiocaps_test_16000_struct.tsv as a sample.
50
+
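A minimal sketch of how such a manifest could be assembled with pandas (the column names follow the description above; the row values are placeholders, not real data):

```python
import pandas as pd

rows = [{
    "name": "sample_0",                        # unique id for each audio
    "dataset": "audiocaps",                    # which dataset the audio belongs to
    "audio_path": "data/audio/sample_0.wav",   # path of the .wav file
    "caption": "A dog barks in the distance",  # caption of the audio
    "mel_path": "processed/sample_0_mel.npy",  # filled in after mel extraction
    "duration": 10.0,                          # filled in by preprocess/add_duration.py
}]
pd.DataFrame(rows).to_csv("tmp.tsv", sep="\t", index=False)
```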
51
+ ### Generate the melspec file of audio
52
+ Assume you already have a tsv file linking each caption to its audio_path, i.e., the tsv file contains "name", "audio_path", "dataset", and "caption" columns.
53
+ To get the mel-spectrogram of each audio, run the following command, which will save the mels in ./processed:
54
+ ```
55
+ python preprocess/mel_spec.py --tsv_path tmp.tsv --num_gpus 1 --max_duration 10
56
+ ```
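The authoritative settings live in `preprocess/mel_spec.py`; conceptually, each clip becomes an 80-bin log-mel spectrogram of shape `[80, T]` at 16 kHz, which is what the dataset code expects. A rough stand-in with librosa is sketched below (the FFT and hop parameters here are assumptions, not necessarily the repository's values):

```python
import numpy as np
import librosa

def rough_logmel(wav_path, sr=16000, n_mels=80):
    # Conceptual stand-in only -- use preprocess/mel_spec.py for the real pipeline.
    y, _ = librosa.load(wav_path, sr=sr)
    mel = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=1024, hop_length=256, n_mels=n_mels)
    logmel = np.log(np.clip(mel, a_min=1e-5, a_max=None)).astype(np.float32)  # [80, T]
    np.save(wav_path.replace(".wav", "_mel.npy"), logmel)
```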
57
+
58
+ ### Count audio duration
59
+ To count the duration of each audio clip and save the duration information in the tsv file, run the following command:
60
+ ```
61
+ python preprocess/add_duration.py --tsv_path tmp.tsv
62
+ ```
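What this step amounts to can be sketched as follows (an assumed implementation; `preprocess/add_duration.py` is the authoritative script), using `soundfile`, which the project already imports elsewhere:

```python
import pandas as pd
import soundfile

df = pd.read_csv("tmp.tsv", sep="\t")
df["duration"] = [soundfile.info(p).duration for p in df["audio_path"]]  # seconds per clip
df.to_csv("tmp.tsv", sep="\t", index=False)
```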
63
+
64
+ ### Generate structured captions from the original natural-language captions
65
+ First, obtain an API key from [OpenAI](https://openai.com/blog/openai-api); here is a [tutorial](https://www.maisieai.com/help/how-to-get-an-openai-api-key-for-chatgpt). Then set the `openai_key` variable in preprocess/n2s_by_openai.py to your key. Run the following command to add structured captions; the tsv file with structured captions will be saved as {tsv_file_name}_struct.tsv:
66
+ ```
67
+ python preprocess/n2s_by_openai.py --tsv_path tmp.tsv
68
+ ```
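For reference, `app.py` in this commit falls back to a trivial structured caption of the form `<caption& all>` when no GPT-generated structure is available, so a tsv can be filled without the OpenAI step if needed; a one-line sketch:

```python
def fallback_struct_caption(ori_caption: str) -> str:
    # Same trivial template app.py uses: treat the whole caption as a single "all" segment.
    return f"<{ori_caption}& all>"

print(fallback_struct_caption("A dog barks in the distance"))  # <A dog barks in the distance& all>
```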
69
+
70
+ ### Place tsv files
71
+ After generating structured captions, put the tsv files with structured captions into ./data/main_spec_dir, and put the tsv files without structured captions into ./data/no_struct_dir.
72
+
73
+ Then modify data.params.main_spec_dir and data.params.other_spec_dir_path accordingly in the config file configs/text2audio-ConcatDiT-ae1dnat_Skl20d2_struct2MLPanylen.yaml.
74
+
75
+ ## Train variational autoencoder
76
+ Assume we have processed several datasets and saved the .tsv files as tsv_dir/*.tsv. Set data.params.spec_dir_path to tsv_dir in the config file. Then we can train the VAE with the following command. If you don't have 8 GPUs on your machine, adjust the --gpus list to the GPU ids you have, e.g. --gpus 0,1.
77
+ ```
78
+ python main.py --base configs/research/autoencoder/autoencoder1d_kl20_natbig_r1_down2_disc2.yaml -t --gpus 0,1,2,3,4,5,6,7
79
+ ```
80
+
81
+ ## Train latent diffusion
82
+ After training the VAE, set model.params.first_stage_config.params.ckpt_path to your trained VAE checkpoint path in the config file.
83
+ Run the following command to train the diffusion model:
84
+ ```
85
+ python main.py --base configs/research/text2audio/text2audio-ConcatDiT-ae1dnat_Skl20d2_freezeFlananylen_drop.yaml -t --gpus 0,1,2,3,4,5,6,7
86
+ ```
87
+
88
+ ## Evaluation
89
+ Please refer to [Make-An-Audio](https://github.com/Text-to-Audio/Make-An-Audio?tab=readme-ov-file#evaluation)
90
+
91
+
92
+ ## Acknowledgements
93
+ This implementation uses parts of the code from the following GitHub repos:
94
+ [Make-An-Audio](https://github.com/Text-to-Audio/Make-An-Audio),
95
+ [AudioLCM](https://github.com/Text-to-Audio/AudioLCM),
96
+ [CLAP](https://github.com/LAION-AI/CLAP),
97
+ as described in our code.
98
+
99
+
100
+
101
+ ## Citations ##
102
+ If you find this code useful in your research, please consider citing:
103
+ ```bibtex
104
+ ```
105
+
106
+ ## Disclaimer ##
107
+ Any organization or individual is prohibited from using any technology mentioned in this paper to generate someone's speech without his/her consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this item, you could be in violation of copyright laws.
app.py ADDED
@@ -0,0 +1,199 @@
1
+ import spaces
2
+ import argparse, os, sys, glob
3
+ import pathlib
4
+ directory = pathlib.Path(os.getcwd())
5
+ print(directory)
6
+ sys.path.append(str(directory))
7
+ import torch
8
+ import numpy as np
9
+ from omegaconf import OmegaConf
10
+ from ldm.util import instantiate_from_config
11
+ from ldm.models.diffusion.ddim import DDIMSampler
12
+ from ldm.models.diffusion.plms import PLMSSampler
13
+ import pandas as pd
14
+ from tqdm import tqdm
15
+ import preprocess.n2s_by_openai as n2s
16
+ from vocoder.bigvgan.models import VocoderBigVGAN
17
+ import soundfile
18
+ import torchaudio, math
19
+ import gradio
20
+ import gradio as gr
21
+
22
+ def load_model_from_config(config, ckpt = None, verbose=True):
23
+ model = instantiate_from_config(config.model)
24
+ if ckpt:
25
+ print(f"Loading model from {ckpt}")
26
+ pl_sd = torch.load(ckpt, map_location="cpu")
27
+ sd = pl_sd["state_dict"]
28
+
29
+ m, u = model.load_state_dict(sd, strict=False)
30
+ if len(m) > 0 and verbose:
31
+ print("missing keys:")
32
+ print(m)
33
+ if len(u) > 0 and verbose:
34
+ print("unexpected keys:")
35
+ print(u)
36
+ else:
37
+ print(f"Note chat no ckpt is loaded !!!")
38
+
39
+ model.cuda()
40
+ model.eval()
41
+ return model
42
+
43
+
44
+ class GenSamples:
45
+ def __init__(self,opt, model,outpath,config, vocoder = None,save_mel = True,save_wav = True) -> None:
46
+ self.opt = opt
47
+ self.model = model
48
+ self.outpath = outpath
49
+ if save_wav:
50
+ assert vocoder is not None
51
+ self.vocoder = vocoder
52
+ self.save_mel = save_mel
53
+ self.save_wav = save_wav
54
+ self.channel_dim = self.model.channels
55
+ self.config = config
56
+
57
+ def gen_test_sample(self,prompt, mel_name = None,wav_name = None, gt=None, video=None):# prompt is {'ori_caption':'xxx','struct_caption':'xxx'}
58
+ uc = None
59
+ record_dicts = []
60
+ if self.opt['scale'] != 1.0:
61
+ try: # audiocaps
62
+ uc = self.model.get_learned_conditioning({'ori_caption': "",'struct_caption': ""})
63
+ except: # audioset
64
+ uc = self.model.get_learned_conditioning(prompt['ori_caption'])
65
+ for n in range(self.opt['n_iter']):
66
+ try: # audiocaps
67
+ c = self.model.get_learned_conditioning(prompt) # shape: [1,77,1280], i.e., still per-word embeddings, not yet pooled into a sentence embedding
68
+ except: # audioset
69
+ c = self.model.get_learned_conditioning(prompt['ori_caption'])
70
+
71
+ if self.channel_dim>0:
72
+ shape = [self.channel_dim, self.opt['H'], self.opt['W']] # (z_dim, 80//2^x, 848//2^x)
73
+ else:
74
+ shape = [1, self.opt['H'], self.opt['W']]
75
+
76
+ x0 = torch.randn(shape, device=self.model.device)
77
+
78
+ if self.opt['scale'] == 1: # w/o cfg
79
+ sample, _ = self.model.sample(c, 1, timesteps=self.opt['ddim_steps'], x_latent=x0)
80
+ else: # cfg
81
+ sample, _ = self.model.sample_cfg(c, self.opt['scale'], uc, 1, timesteps=self.opt['ddim_steps'], x_latent=x0)
82
+ x_samples_ddim = self.model.decode_first_stage(sample)
83
+
84
+ for idx,spec in enumerate(x_samples_ddim):
85
+ spec = spec.squeeze(0).cpu().numpy()
86
+ print(spec[0])
87
+ record_dict = {'caption':prompt['ori_caption'][0]}
88
+ if self.save_mel:
89
+ mel_path = os.path.join(self.outpath,mel_name+f'_{idx}.npy')
90
+ np.save(mel_path,spec)
91
+ record_dict['mel_path'] = mel_path
92
+ if self.save_wav:
93
+ wav = self.vocoder.vocode(spec)
94
+ wav_path = os.path.join(self.outpath,wav_name+f'_{idx}.wav')
95
+ soundfile.write(wav_path, wav, self.opt['sample_rate'])
96
+ record_dict['audio_path'] = wav_path
97
+ record_dicts.append(record_dict)
98
+
99
+ return record_dicts
100
+
101
+ @spaces.GPU(enable_queue=True)
102
+ def infer(ori_prompt, ddim_steps, scale, seed):
103
+ # np.random.seed(seed)
104
+ # torch.manual_seed(seed)
105
+ prompt = dict(ori_caption=ori_prompt,struct_caption=f'<{ori_prompt}& all>')
106
+
107
+ opt = {
108
+ 'sample_rate': 16000,
109
+ 'outdir': 'outputs/txt2music-samples',
110
+ 'ddim_steps': ddim_steps,
111
+ 'n_iter': 1,
112
+ 'H': 20,
113
+ 'W': 312,
114
+ 'scale': scale,
115
+ 'resume': 'useful_ckpts/music_generation/119.ckpt',
116
+ 'base': 'configs/txt2music-cfm1-cfg-LargeDiT3.yaml',
117
+ 'vocoder_ckpt': 'useful_ckpts/bigvnat',
118
+ }
119
+
120
+ config = OmegaConf.load(opt['base'])
121
+ model = load_model_from_config(config, opt['resume'])
122
+
123
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
124
+ model = model.to(device)
125
+ os.makedirs(opt['outdir'], exist_ok=True)
126
+ vocoder = VocoderBigVGAN(opt['vocoder_ckpt'],device)
127
+ generator = GenSamples(opt, model,opt['outdir'],config, vocoder,save_mel=False,save_wav=True)
128
+
129
+ with torch.no_grad():
130
+ with model.ema_scope():
131
+ wav_name = f'{prompt["ori_caption"].strip().replace(" ", "-")}'
132
+ generator.gen_test_sample(prompt,wav_name=wav_name)
133
+
134
+ file_path = os.path.join(opt['outdir'],wav_name+'_0.wav')
135
+ print(f"Your samples are ready and waiting four you here: \n{file_path} \nEnjoy.")
136
+ return file_path
137
+
138
+ def my_inference_function(text_prompt, ddim_steps, scale, seed):
139
+ file_path = infer(text_prompt, ddim_steps, scale, seed)
140
+ return file_path
141
+
142
+
143
+ with gr.Blocks() as demo:
144
+ with gr.Row():
145
+ gr.Markdown("## Make-An-Audio 3: Transforming Text into Audio via Flow-based Large Diffusion Transformers")
146
+
147
+ with gr.Row():
148
+ with gr.Column():
149
+ prompt = gr.Textbox(label="Prompt: Input your text here. ")
150
+ run_button = gr.Button()
151
+
152
+ with gr.Accordion("Advanced options", open=False):
153
+ ddim_steps = gr.Slider(label="ddim_steps", minimum=1,
154
+ maximum=50, value=25, step=1)
155
+ scale = gr.Slider(
156
+ label="Guidance Scale:(Large => more relevant to text but the quality may drop)", minimum=0.1, maximum=8.0, value=3.0, step=0.1
157
+ )
158
+ seed = gr.Slider(
159
+ label="Seed:Change this value (any integer number) will lead to a different generation result.",
160
+ minimum=0,
161
+ maximum=2147483647,
162
+ step=1,
163
+ value=44,
164
+ )
165
+
166
+ with gr.Column():
167
+ outaudio = gr.Audio()
168
+
169
+ run_button.click(fn=my_inference_function, inputs=[
170
+ prompt, ddim_steps, scale, seed], outputs=[outaudio])
171
+ with gr.Row():
172
+ with gr.Column():
173
+ gr.Examples(
174
+ examples = [['An amateur recording features a steel drum playing in a higher register',25,5,55],
175
+ ['An instrumental song with a caribbean feel, happy mood, and featuring steel pan music, programmed percussion, and bass',25,5,55],
176
+ ['This musical piece features a playful and emotionally melodic male vocal accompanied by piano',25,5,55],
177
+ ['A eerie yet calming experimental electronic track featuring haunting synthesizer strings and pads',25,5,55],
178
+ ['A slow tempo pop instrumental piece featuring only acoustic guitar with fingerstyle and percussive strumming techniques',25,5,55]],
179
+ inputs = [prompt, ddim_steps, scale, seed],
180
+ outputs = [outaudio]
181
+ )
182
+ with gr.Column():
183
+ pass
184
+
185
+ demo.launch()
186
+
187
+
188
+ # gradio_interface = gradio.Interface(
189
+ # fn = my_inference_function,
190
+ # inputs = "text",
191
+ # outputs = "audio"
192
+ # )
193
+ # gradio_interface.launch()
194
+ # text_prompt = 'An amateur recording features a steel drum playing in a higher register'
195
+ # # text_prompt = 'A slow tempo pop instrumental piece featuring only acoustic guitar with fingerstyle and percussive strumming techniques'
196
+ # ddim_steps=25
197
+ # scale=5.0
198
+ # seed=55
199
+ # my_inference_function(text_prompt, ddim_steps, scale, seed)
audiocaps_test_struct.tsv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36d5f93b134ee6ed8c7e75adffca2e0a378fb683e67836abd78b50153659858b
3
+ size 1306277
data/audiocaps_test_struct.tsv CHANGED
The diff for this file is too large to render. See raw diff
 
data/musiccaps_test_16000_struct.tsv CHANGED
The diff for this file is too large to render. See raw diff
 
infer.sh ADDED
@@ -0,0 +1,20 @@
1
+ # music prompt generation
2
+ python3 scripts/txt2audio_for_2cap_flow.py \
3
+ --outdir output_dir_text -r useful_ckpts/music_generation/119.ckpt -b configs/txt2music-cfm1-cfg-LargeDiT3.yaml --scale 3.0 \
4
+ --vocoder-ckpt useful_ckpts/bigvnat
5
+
6
+ # music test dataset generation
7
+ python3 scripts/txt2audio_for_2cap_flow.py \
8
+ --outdir results/music/dataset -r useful_ckpts/music_generation/119.ckpt -b configs/txt2music-cfm1-cfg-LargeDiT3.yaml --scale 3.0 \
9
+ --vocoder-ckpt useful_ckpts/bigvnat --test-dataset testset
10
+
11
+ # audio prompt generation
12
+ python3 scripts/txt2audio_for_2cap_flow.py \
13
+ --prompt 'A train running on a railroad track followed by a vehicle door closing and a man talking in the distance while a train horn honks and railroad crossing warning signals ring' \
14
+ --outdir results/auido/text -r useful_ckpts/audio_generation/324.ckpt -b configs/txt2audio-cfm1-cfg-LargeDiT3.yaml --scale 3.0 \
15
+ --vocoder-ckpt useful_ckpts/bigvnat
16
+
17
+ # audio test dataset generation
18
+ python3 scripts/txt2audio_for_2cap_flow.py \
19
+ --outdir results/auido/dataset -r useful_ckpts/audio_generation/324.ckpt -b configs/txt2audio-cfm1-cfg-LargeDiT3.yaml --scale 3.0 \
20
+ --vocoder-ckpt useful_ckpts/bigvnat --test-dataset testset
ldm/__pycache__/util.cpython-38.pyc ADDED
Binary file (5.1 kB). View file
 
ldm/__pycache__/util.cpython-39.pyc ADDED
Binary file (5.16 kB). View file
 
ldm/data/__pycache__/joinaudiodataset_anylen.cpython-38.pyc ADDED
Binary file (12.1 kB). View file
 
ldm/data/__pycache__/joinaudiodataset_struct_sample_anylen.cpython-38.pyc ADDED
Binary file (11.6 kB). View file
 
ldm/data/joinaudiodataset_anylen.py ADDED
@@ -0,0 +1,330 @@
1
+ import os
2
+ import sys
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+ from torch.utils.data.sampler import Sampler
7
+ from torch.utils.data.distributed import DistributedSampler
8
+ import torch.distributed
9
+ from typing import TypeVar, Optional, Iterator,List
10
+ import logging
11
+ import pandas as pd
12
+ import glob
13
+ import torch.distributed as dist
14
+ logger = logging.getLogger(f'main.{__name__}')
15
+
16
+ sys.path.insert(0, '.') # nopep8
17
+
18
+ class JoinManifestSpecs(torch.utils.data.Dataset):
19
+ def __init__(self, split, spec_dir_path, mel_num=80,spec_crop_len=1248,mode='pad',pad_value=-5,drop=0,**kwargs):
20
+ super().__init__()
21
+ self.split = split
22
+ self.max_batch_len = spec_crop_len
23
+ self.min_batch_len = 64
24
+ self.mel_num = mel_num
25
+ self.min_factor = 4
26
+ self.drop = drop
27
+ self.pad_value = pad_value
28
+ assert mode in ['pad','tile']
29
+ self.collate_mode = mode
30
+ # print(f"################# self.collate_mode {self.collate_mode} ##################")
31
+
32
+ manifest_files = []
33
+ for dir_path in spec_dir_path.split(','):
34
+ manifest_files += glob.glob(f'{dir_path}/*.tsv')
35
+ df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
36
+ df = pd.concat(df_list,ignore_index=True)
37
+
38
+ if split == 'train':
39
+ self.dataset = df.iloc[100:]
40
+ elif split == 'valid' or split == 'val':
41
+ self.dataset = df.iloc[:100]
42
+ elif split == 'test':
43
+ df = self.add_name_num(df)
44
+ self.dataset = df
45
+ else:
46
+ raise ValueError(f'Unknown split {split}')
47
+ self.dataset.reset_index(inplace=True)
48
+ print('dataset len:', len(self.dataset))
49
+
50
+ def add_name_num(self,df):
51
+ """each file may have different caption, we add num to filename to identify each audio-caption pair"""
52
+ name_count_dict = {}
53
+ change = []
54
+ for t in df.itertuples():
55
+ name = getattr(t,'name')
56
+ if name in name_count_dict:
57
+ name_count_dict[name] += 1
58
+ else:
59
+ name_count_dict[name] = 0
60
+ change.append((t[0],name_count_dict[name]))
61
+ for t in change:
62
+ df.loc[t[0],'name'] = df.loc[t[0],'name'] + f'_{t[1]}'
63
+ return df
64
+
65
+ def ordered_indices(self):
66
+ index2dur = self.dataset[['duration']]
67
+ index2dur = index2dur.sort_values(by='duration')
68
+ return list(index2dur.index)
69
+
70
+ def __getitem__(self, idx):
71
+ item = {}
72
+ data = self.dataset.iloc[idx]
73
+ try:
74
+ spec = np.load(data['mel_path']) # mel spec [80, 624]
75
+ except:
76
+ mel_path = data['mel_path']
77
+ print(f'corrupted:{mel_path}')
78
+ spec = np.ones((self.mel_num,self.min_batch_len)).astype(np.float32)*self.pad_value
79
+
80
+
81
+ item['image'] = spec
82
+ p = np.random.uniform(0,1)
83
+ if p > self.drop:
84
+ item["caption"] = data['caption']
85
+ else:
86
+ item["caption"] = ""
87
+ if self.split == 'test':
88
+ item['f_name'] = data['name']
89
+ # item['f_name'] = data['mel_path']
90
+ return item
91
+
92
+ def collater(self,inputs):
93
+ to_dict = {}
94
+ for l in inputs:
95
+ for k,v in l.items():
96
+ if k in to_dict:
97
+ to_dict[k].append(v)
98
+ else:
99
+ to_dict[k] = [v]
100
+ if self.collate_mode == 'pad':
101
+ to_dict['image'] = collate_1d_or_2d(to_dict['image'],pad_idx=self.pad_value,min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
102
+ elif self.collate_mode == 'tile':
103
+ to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'],min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
104
+ else:
105
+ raise NotImplementedError
106
+
107
+ return to_dict
108
+
109
+ def __len__(self):
110
+ return len(self.dataset)
111
+
112
+
113
+ class JoinSpecsTrain(JoinManifestSpecs):
114
+ def __init__(self, specs_dataset_cfg):
115
+ super().__init__('train', **specs_dataset_cfg)
116
+
117
+ class JoinSpecsValidation(JoinManifestSpecs):
118
+ def __init__(self, specs_dataset_cfg):
119
+ super().__init__('valid', **specs_dataset_cfg)
120
+
121
+ class JoinSpecsTest(JoinManifestSpecs):
122
+ def __init__(self, specs_dataset_cfg):
123
+ super().__init__('test', **specs_dataset_cfg)
124
+
125
+ class JoinSpecsDebug(JoinManifestSpecs):
126
+ def __init__(self, specs_dataset_cfg):
127
+ super().__init__('valid', **specs_dataset_cfg)
128
+ self.dataset = self.dataset.iloc[:37]
129
+
130
+ class DDPIndexBatchSampler(Sampler): # group indices of audios with similar lengths into the same batch to avoid excessive padding
131
+ def __init__(self, indices ,batch_size, num_replicas: Optional[int] = None,
132
+ rank: Optional[int] = None, shuffle: bool = True,
133
+ seed: int = 0, drop_last: bool = False) -> None:
134
+ if num_replicas is None:
135
+ if not dist.is_initialized():
136
+ # raise RuntimeError("Requires distributed package to be available")
137
+ print("Not in distributed mode")
138
+ num_replicas = 1
139
+ else:
140
+ num_replicas = dist.get_world_size()
141
+ if rank is None:
142
+ if not dist.is_initialized():
143
+ # raise RuntimeError("Requires distributed package to be available")
144
+ rank = 0
145
+ else:
146
+ rank = dist.get_rank()
147
+ if rank >= num_replicas or rank < 0:
148
+ raise ValueError(
149
+ "Invalid rank {}, rank should be in the interval"
150
+ " [0, {}]".format(rank, num_replicas - 1))
151
+ self.indices = indices
152
+ self.num_replicas = num_replicas
153
+ self.rank = rank
154
+ self.epoch = 0
155
+ self.drop_last = drop_last
156
+ self.batch_size = batch_size
157
+ self.batches = self.build_batches()
158
+ print(f"rank: {self.rank}, batches_num {len(self.batches)}")
159
+ # If the dataset length is evenly divisible by replicas, then there
160
+ # is no need to drop any data, since the dataset will be split equally.
161
+ if self.drop_last and len(self.batches) % self.num_replicas != 0:
162
+ self.batches = self.batches[:len(self.batches)//self.num_replicas*self.num_replicas]
163
+ if len(self.batches) > self.num_replicas:
164
+ self.batches = self.batches[self.rank::self.num_replicas]
165
+ else: # may happen in sanity checking
166
+ self.batches = [self.batches[0]]
167
+ print(f"after split batches_num {len(self.batches)}")
168
+ self.shuffle = shuffle
169
+ if self.shuffle:
170
+ self.batches = np.random.permutation(self.batches)
171
+ self.seed = seed
172
+
173
+ def set_epoch(self,epoch):
174
+ self.epoch = epoch
175
+ if self.shuffle:
176
+ np.random.seed(self.seed+self.epoch)
177
+ self.batches = np.random.permutation(self.batches)
178
+
179
+ def build_batches(self):
180
+ batches,batch = [],[]
181
+ for index in self.indices:
182
+ batch.append(index)
183
+ if len(batch) == self.batch_size:
184
+ batches.append(batch)
185
+ batch = []
186
+ if not self.drop_last and len(batch) > 0:
187
+ batches.append(batch)
188
+ return batches
189
+
190
+ def __iter__(self) -> Iterator[List[int]]:
191
+ for batch in self.batches:
192
+ yield batch
193
+
194
+ def __len__(self) -> int:
195
+ return len(self.batches)
196
+
197
+ def set_epoch(self, epoch: int) -> None:
198
+ r"""
199
+ Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
200
+ use a different random ordering for each epoch. Otherwise, the next iteration of this
201
+ sampler will yield the same ordering.
202
+
203
+ Args:
204
+ epoch (int): Epoch number.
205
+ """
206
+ self.epoch = epoch
207
+
208
+
209
+ def collate_1d_or_2d(values, pad_idx=0, left_pad=False, shift_right=False,min_len = None, max_len=None,min_factor=None, shift_id=1):
210
+ if len(values[0].shape) == 1:
211
+ return collate_1d(values, pad_idx, left_pad, shift_right,min_len, max_len,min_factor, shift_id)
212
+ else:
213
+ return collate_2d(values, pad_idx, left_pad, shift_right,min_len,max_len,min_factor)
214
+
215
+ def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False,min_len=None, max_len=None,min_factor=None, shift_id=1):
216
+ """Convert a list of 1d tensors into a padded 2d tensor."""
217
+ size = max(v.size(0) for v in values)
218
+ if max_len:
219
+ size = min(size,max_len)
220
+ if min_len:
221
+ size = max(size,min_len)
222
+ if min_factor and (size % min_factor!=0): # size must be a multiple of min_factor
223
+ size += (min_factor - size % min_factor)
224
+ res = values[0].new(len(values), size).fill_(pad_idx)
225
+
226
+ def copy_tensor(src, dst):
227
+ assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
228
+ if shift_right:
229
+ dst[1:] = src[:-1]
230
+ dst[0] = shift_id
231
+ else:
232
+ dst.copy_(src)
233
+
234
+ for i, v in enumerate(values):
235
+ copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
236
+ return res
237
+
238
+
239
+ def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, min_len=None,max_len=None,min_factor=None):
240
+ """Collate 2d for melspec,Convert a list of 2d tensors into a padded 3d tensor,pad in mel_length dimension.
241
+ values[0] shape: (melbins,mel_length)
242
+ """
243
+ size = max(v.shape[1] for v in values) # if max_len is None else max_len
244
+ if max_len:
245
+ size = min(size,max_len)
246
+ if min_len:
247
+ size = max(size,min_len)
248
+ if min_factor and (size % min_factor!=0): # size must be a multiple of min_factor
249
+ size += (min_factor - size % min_factor)
250
+
251
+ if isinstance(values,np.ndarray):
252
+ values = torch.FloatTensor(values)
253
+ if isinstance(values,list):
254
+ values = [torch.FloatTensor(v) for v in values]
255
+ res = torch.ones(len(values), values[0].shape[0],size).to(dtype=torch.float32)*pad_idx
256
+
257
+ def copy_tensor(src, dst):
258
+ assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
259
+ if shift_right:
260
+ dst[1:] = src[:-1]
261
+ else:
262
+ dst.copy_(src)
263
+
264
+ for i, v in enumerate(values):
265
+ copy_tensor(v[:,:size], res[i][:,size - v.shape[1]:] if left_pad else res[i][:,:v.shape[1]])
266
+ return res
267
+
268
+
269
+ def collate_1d_or_2d_tile(values, shift_right=False,min_len = None, max_len=None,min_factor=None, shift_id=1):
270
+ if len(values[0].shape) == 1:
271
+ return collate_1d_tile(values, shift_right,min_len, max_len,min_factor, shift_id)
272
+ else:
273
+ return collate_2d_tile(values, shift_right,min_len,max_len,min_factor)
274
+
275
+ def collate_1d_tile(values, shift_right=False,min_len=None, max_len=None,min_factor=None,shift_id=1):
276
+ """Convert a list of 1d tensors into a padded 2d tensor."""
277
+ size = max(v.size(0) for v in values)
278
+ if max_len:
279
+ size = min(size,max_len)
280
+ if min_len:
281
+ size = max(size,min_len)
282
+ if min_factor and (size%min_factor!=0): # size must be a multiple of min_factor
283
+ size += (min_factor - size % min_factor)
284
+ res = values[0].new(len(values), size)
285
+
286
+ def copy_tensor(src, dst):
287
+ assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
288
+ if shift_right:
289
+ dst[1:] = src[:-1]
290
+ dst[0] = shift_id
291
+ else:
292
+ dst.copy_(src)
293
+
294
+ for i, v in enumerate(values):
295
+ n_repeat = math.ceil((size + 1) / v.shape[0])
296
+ v = torch.tile(v,dims=(1,n_repeat))[:size]
297
+ copy_tensor(v, res[i])
298
+
299
+ return res
300
+
301
+
302
+ def collate_2d_tile(values, shift_right=False, min_len=None,max_len=None,min_factor=None):
303
+ """Collate 2d for melspec,Convert a list of 2d tensors into a padded 3d tensor,pad in mel_length dimension. """
304
+ size = max(v.shape[1] for v in values) # if max_len is None else max_len
305
+ if max_len:
306
+ size = min(size,max_len)
307
+ if min_len:
308
+ size = max(size,min_len)
309
+ if min_factor and (size % min_factor!=0): # size must be a multiple of min_factor
310
+ size += (min_factor - size % min_factor)
311
+
312
+ if isinstance(values,np.ndarray):
313
+ values = torch.FloatTensor(values)
314
+ if isinstance(values,list):
315
+ values = [torch.FloatTensor(v) for v in values]
316
+ res = torch.zeros(len(values), values[0].shape[0],size).to(dtype=torch.float32)
317
+
318
+ def copy_tensor(src, dst):
319
+ assert dst.numel() == src.numel()
320
+ if shift_right:
321
+ dst[1:] = src[:-1]
322
+ else:
323
+ dst.copy_(src)
324
+
325
+ for i, v in enumerate(values):
326
+ n_repeat = math.ceil((size + 1) / v.shape[1])
327
+ v = torch.tile(v,dims=(1,n_repeat))[:,:size]
328
+ copy_tensor(v, res[i])
329
+
330
+ return res
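Editor's note on `ldm/data/joinaudiodataset_anylen.py` above: `DDPIndexBatchSampler` groups the length-sorted indices returned by `ordered_indices()` into fixed-size batches so that clips of similar duration are padded together; note that the class defines `set_epoch` twice, and the second definition (which skips the reshuffle) overrides the first. A hedged sketch of how the dataset, sampler, and collater might be wired together (the configuration values are placeholders, not the repository's configs):

```python
from torch.utils.data import DataLoader

cfg = {"spec_dir_path": "ldm/data/tsv_dirs/full_data/V2", "spec_crop_len": 1248, "mode": "pad"}
dataset = JoinSpecsTrain(cfg)
sampler = DDPIndexBatchSampler(dataset.ordered_indices(), batch_size=16, shuffle=True)
loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=dataset.collater, num_workers=4)

for epoch in range(10):
    sampler.set_epoch(epoch)   # intended to reshuffle the length-sorted batches each epoch
    for batch in loader:
        mels = batch["image"]  # [B, 80, T], padded to a shared T within the batch
```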
ldm/data/joinaudiodataset_struct_sample_anylen.py ADDED
@@ -0,0 +1,380 @@
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+ import torch
5
+ from typing import TypeVar, Optional, Iterator
6
+ import logging
7
+ import pandas as pd
8
+ from ldm.data.joinaudiodataset_anylen import *
9
+ import glob
10
+ logger = logging.getLogger(f'main.{__name__}')
11
+
12
+ sys.path.insert(0, '.') # nopep8
13
+
14
+ class JoinManifestSpecs(torch.utils.data.Dataset):
15
+ def __init__(self, split, main_spec_dir_path,other_spec_dir_path, mel_num=80,mode='pad', spec_crop_len=1248,pad_value=-5,drop=0,**kwargs):
16
+ super().__init__()
17
+ self.split = split
18
+ self.max_batch_len = spec_crop_len
19
+ self.min_batch_len = 64
20
+ self.min_factor = 4
21
+ self.mel_num = mel_num
22
+ self.drop = drop
23
+ self.pad_value = pad_value
24
+ assert mode in ['pad','tile']
25
+ self.collate_mode = mode
26
+ manifest_files = []
27
+ for dir_path in main_spec_dir_path.split(','):
28
+ manifest_files += glob.glob(f'{dir_path}/*.tsv')
29
+ df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
30
+ self.df_main = pd.concat(df_list,ignore_index=True)
31
+
32
+ # manifest_files = []
33
+ # for dir_path in other_spec_dir_path.split(','):
34
+ # manifest_files += glob.glob(f'{dir_path}/*.tsv')
35
+ # df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
36
+ # self.df_other = pd.concat(df_list,ignore_index=True)
37
+ # self.df_other.reset_index(inplace=True)
38
+
39
+ if split == 'train':
40
+ self.dataset = self.df_main.iloc[100:]
41
+ elif split == 'valid' or split == 'val':
42
+ self.dataset = self.df_main.iloc[:100]
43
+ elif split == 'test':
44
+ self.df_main = self.add_name_num(self.df_main)
45
+ self.dataset = self.df_main
46
+ else:
47
+ raise ValueError(f'Unknown split {split}')
48
+ self.dataset.reset_index(inplace=True)
49
+ print('dataset len:', len(self.dataset),"drop_rate",self.drop)
50
+
51
+ def add_name_num(self,df):
52
+ """each file may have different caption, we add num to filename to identify each audio-caption pair"""
53
+ name_count_dict = {}
54
+ change = []
55
+ for t in df.itertuples():
56
+ name = getattr(t,'name')
57
+ if name in name_count_dict:
58
+ name_count_dict[name] += 1
59
+ else:
60
+ name_count_dict[name] = 0
61
+ change.append((t[0],name_count_dict[name]))
62
+ for t in change:
63
+ df.loc[t[0],'name'] = str(df.loc[t[0],'name']) + f'_{t[1]}'
64
+ return df
65
+
66
+ def ordered_indices(self):
67
+ index2dur = self.dataset[['duration']].sort_values(by='duration')
68
+ # index2dur_other = self.df_other[['duration']].sort_values(by='duration')
69
+ # other_indices = list(index2dur_other.index)
70
+ offset = len(self.dataset)
71
+ # other_indices = [x + offset for x in other_indices]
72
+ return list(index2dur.index) # ,other_indices
73
+
74
+ def collater(self,inputs):
75
+ to_dict = {}
76
+ for l in inputs:
77
+ for k,v in l.items():
78
+ if k in to_dict:
79
+ to_dict[k].append(v)
80
+ else:
81
+ to_dict[k] = [v]
82
+
83
+ if self.collate_mode == 'pad':
84
+ to_dict['image'] = collate_1d_or_2d(to_dict['image'],pad_idx=self.pad_value,min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
85
+ elif self.collate_mode == 'tile':
86
+ to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'],min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
87
+ else:
88
+ raise NotImplementedError
89
+ to_dict['caption'] = {'ori_caption':[c['ori_caption'] for c in to_dict['caption']],
90
+ 'struct_caption':[c['struct_caption'] for c in to_dict['caption']]}
91
+
92
+ return to_dict
93
+
94
+ def __getitem__(self, idx):
95
+ # if idx < len(self.dataset):
96
+ data = self.dataset.iloc[idx]
97
+ p = np.random.uniform(0,1)
98
+ if p > self.drop:
99
+ ori_caption = data['ori_cap']
100
+ struct_caption = data['caption']
101
+ else:
102
+ ori_caption = ""
103
+ struct_caption = ""
104
+ # else:
105
+ # data = self.df_other.iloc[idx-len(self.dataset)]
106
+ # p = np.random.uniform(0,1)
107
+ # if p > self.drop:
108
+ # ori_caption = data['caption']
109
+ # struct_caption = f'<{ori_caption}& all>'
110
+ # else:
111
+ # ori_caption = ""
112
+ # struct_caption = ""
113
+ item = {}
114
+ try:
115
+ if not os.path.exists(data['mel_path']):
116
+ mel_path = data['mel_path'].replace('/apdcephfs', '/apdcephfs_intern')
117
+ else:
118
+ mel_path = data['mel_path']
119
+ spec = np.load(mel_path) # mel spec [80, T]
120
+ if spec.shape[1] > self.max_batch_len:
121
+ spec = spec[:, :self.max_batch_len]
122
+ except:
123
+ mel_path = data['mel_path']
124
+ print(f'corrupted:{mel_path}')
125
+ spec = np.ones((self.mel_num,self.min_batch_len)).astype(np.float32)*self.pad_value
126
+
127
+ item['image'] = spec
128
+ item["caption"] = {"ori_caption":ori_caption,"struct_caption":struct_caption}
129
+ if self.split == 'test':
130
+ item['f_name'] = data['name']
131
+ return item
132
+
133
+ def __len__(self):
134
+ return len(self.dataset) # + len(self.df_other)
135
+
136
+
137
+ class JoinSpecsTrain(JoinManifestSpecs):
138
+ def __init__(self, specs_dataset_cfg):
139
+ super().__init__('train', **specs_dataset_cfg)
140
+
141
+ class JoinSpecsValidation(JoinManifestSpecs):
142
+ def __init__(self, specs_dataset_cfg):
143
+ super().__init__('valid', **specs_dataset_cfg)
144
+
145
+ class JoinSpecsTest(JoinManifestSpecs):
146
+ def __init__(self, specs_dataset_cfg):
147
+ super().__init__('test', **specs_dataset_cfg)
148
+
149
+
150
+ class TestManifest(torch.utils.data.Dataset):
151
+ def __init__(self, manifest, mel_num=80, mode='pad', spec_crop_len=1248, pad_value=-5, **kwargs):
152
+ super().__init__()
153
+ self.max_batch_len = spec_crop_len
154
+ self.min_batch_len = 64
155
+ self.min_factor = 4
156
+ self.mel_num = mel_num
157
+
158
+ self.pad_value = pad_value
159
+ assert mode in ['pad', 'tile']
160
+ self.collate_mode = mode
161
+
162
+ df_list = pd.read_csv(manifest, sep='\t')
163
+ self.df_main = pd.concat([df_list], ignore_index=True)
164
+ self.df_main = self.add_name_num(self.df_main)
165
+ self.dataset = self.df_main
166
+ self.dataset.reset_index(inplace=True)
167
+ print('dataset len:', len(self.dataset))
168
+
169
+ def add_name_num(self, df):
170
+ """each file may have different caption, we add num to filename to identify each audio-caption pair"""
171
+ name_count_dict = {}
172
+ change = []
173
+ for t in df.itertuples():
174
+ name = getattr(t, 'name')
175
+ if name in name_count_dict:
176
+ name_count_dict[name] += 1
177
+ else:
178
+ name_count_dict[name] = 0
179
+ change.append((t[0], name_count_dict[name]))
180
+ for t in change:
181
+ df.loc[t[0], 'name'] = str(df.loc[t[0], 'name']) + f'_{t[1]}'
182
+ return df
183
+
184
+ def ordered_indices(self):
185
+ index2dur = self.dataset[['duration']].sort_values(by='duration')
186
+ return list(index2dur.index) # ,other_indices
187
+
188
+ def collater(self, inputs):
189
+ to_dict = {}
190
+ for l in inputs:
191
+ for k, v in l.items():
192
+ if k in to_dict:
193
+ to_dict[k].append(v)
194
+ else:
195
+ to_dict[k] = [v]
196
+
197
+ if self.collate_mode == 'pad':
198
+ to_dict['image'] = collate_1d_or_2d(to_dict['image'], pad_idx=self.pad_value, min_len=self.min_batch_len,
199
+ max_len=self.max_batch_len, min_factor=self.min_factor)
200
+ elif self.collate_mode == 'tile':
201
+ to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'], min_len=self.min_batch_len,
202
+ max_len=self.max_batch_len, min_factor=self.min_factor)
203
+ else:
204
+ raise NotImplementedError
205
+ to_dict['caption'] = {'ori_caption': [c['ori_caption'] for c in to_dict['caption']],
206
+ 'struct_caption': [c['struct_caption'] for c in to_dict['caption']]}
207
+
208
+ return to_dict
209
+
210
+ def __getitem__(self, idx):
211
+ # if idx < len(self.dataset):
212
+ data = self.dataset.iloc[idx]
213
+ ori_caption = data['ori_cap']
214
+ struct_caption = data['caption']
215
+ item = {}
216
+ try:
217
+ if not os.path.exists(data['mel_path']):
218
+ mel_path = data['mel_path'].replace('/apdcephfs', '/apdcephfs_intern')
219
+ else:
220
+ mel_path = data['mel_path']
221
+ spec = np.load(mel_path) # mel spec [80, T]
222
+
223
+ if spec.shape[1] > self.max_batch_len:
224
+ spec = spec[:, :self.max_batch_len]
225
+ except:
226
+ mel_path = data['mel_path']
227
+ print(f'corrupted:{mel_path}')
228
+ spec = np.ones((self.mel_num, self.min_batch_len)).astype(np.float32) * self.pad_value
229
+
230
+ item['image'] = spec
231
+ item["caption"] = {"ori_caption": ori_caption, "struct_caption": struct_caption}
232
+ item['f_name'] = data['name']
233
+ return item
234
+
235
+ def __len__(self):
236
+ return len(self.dataset) # + len(self.df_other)
237
+
238
+
239
+
240
+ class DDPIndexBatchSampler(Sampler): # group indices of audios with similar lengths into the same batch to avoid excessive padding
241
+ def __init__(self, main_indices,batch_size, num_replicas: Optional[int] = None,
242
+ rank: Optional[int] = None, shuffle: bool = True,
243
+ seed: int = 0, drop_last: bool = False) -> None:
244
+ if num_replicas is None:
245
+ if not dist.is_initialized():
246
+ # raise RuntimeError("Requires distributed package to be available")
247
+ print("Not in distributed mode")
248
+ num_replicas = 1
249
+ else:
250
+ num_replicas = dist.get_world_size()
251
+ if rank is None:
252
+ if not dist.is_initialized():
253
+ # raise RuntimeError("Requires distributed package to be available")
254
+ rank = 0
255
+ else:
256
+ rank = dist.get_rank()
257
+ if rank >= num_replicas or rank < 0:
258
+ raise ValueError(
259
+ "Invalid rank {}, rank should be in the interval"
260
+ " [0, {}]".format(rank, num_replicas - 1))
261
+ self.main_indices = main_indices
262
+ # self.other_indices = other_indices
263
+ # self.max_index = max(self.other_indices)
264
+ self.num_replicas = num_replicas
265
+ self.rank = rank
266
+ self.epoch = 0
267
+ self.drop_last = drop_last
268
+ self.batch_size = batch_size
269
+ self.shuffle = shuffle
270
+ self.batches = self.build_batches()
271
+ self.seed = seed
272
+
273
+ def set_epoch(self,epoch):
274
+ # print("!!!!!!!!!!!set epoch is called!!!!!!!!!!!!!!")
275
+ self.epoch = epoch
276
+ if self.shuffle:
277
+ np.random.seed(self.seed+self.epoch)
278
+ self.batches = self.build_batches()
279
+
280
+ def build_batches(self):
281
+ batches,batch = [],[]
282
+ for index in self.main_indices:
283
+ batch.append(index)
284
+ if len(batch) == self.batch_size:
285
+ batches.append(batch)
286
+ batch = []
287
+ if not self.drop_last and len(batch) > 0:
288
+ batches.append(batch)
289
+ # selected_others = np.random.choice(len(self.other_indices),len(batches),replace=False)
290
+ # for index in selected_others:
291
+ # if index + self.batch_size > len(self.other_indices):
292
+ # index = len(self.other_indices) - self.batch_size
293
+ # batch = [self.other_indices[index + i] for i in range(self.batch_size)]
294
+ # batches.append(batch)
295
+ self.batches = batches
296
+ if self.shuffle:
297
+ self.batches = np.random.permutation(self.batches)
298
+ if self.rank == 0:
299
+ print(f"rank: {self.rank}, batches_num {len(self.batches)}")
300
+
301
+ if self.drop_last and len(self.batches) % self.num_replicas != 0:
302
+ self.batches = self.batches[:len(self.batches)//self.num_replicas*self.num_replicas]
303
+ if len(self.batches) >= self.num_replicas:
304
+ self.batches = self.batches[self.rank::self.num_replicas]
305
+ else: # may happen in sanity checking
306
+ self.batches = [self.batches[0]]
307
+ if self.rank == 0:
308
+ print(f"after split batches_num {len(self.batches)}")
309
+
310
+ return self.batches
311
+
312
+ def __iter__(self) -> Iterator[List[int]]:
313
+ print(f"len(self.batches):{len(self.batches)}")
314
+ for batch in self.batches:
315
+ yield batch
316
+
317
+ def __len__(self) -> int:
318
+ return len(self.batches)
319
+
320
+
321
+ class JoinManifestSpecs_Caption(JoinManifestSpecs):
322
+ def collater(self, inputs):
323
+ to_dict = {}
324
+ for l in inputs:
325
+ for k, v in l.items():
326
+ if k in to_dict:
327
+ to_dict[k].append(v)
328
+ else:
329
+ to_dict[k] = [v]
330
+
331
+ if self.collate_mode == 'pad':
332
+ to_dict['image'] = collate_1d_or_2d(to_dict['image'], pad_idx=self.pad_value, min_len=self.min_batch_len,
333
+ max_len=self.max_batch_len, min_factor=self.min_factor)
334
+ elif self.collate_mode == 'tile':
335
+ to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'], min_len=self.min_batch_len,
336
+ max_len=self.max_batch_len, min_factor=self.min_factor)
337
+ else:
338
+ raise NotImplementedError
339
+
340
+ return to_dict
341
+
342
+ def __getitem__(self, idx):
343
+ # if idx < len(self.dataset):
344
+ data = self.dataset.iloc[idx]
345
+ p = np.random.uniform(0, 1)
346
+ if p > self.drop:
347
+ caption = data['ori_cap']
348
+ else:
349
+ caption = ""
350
+ item = {}
351
+ try:
352
+ if not os.path.exists(data['mel_path']):
353
+ mel_path = data['mel_path'].replace('/apdcephfs', '/apdcephfs_intern')
354
+ else:
355
+ mel_path = data['mel_path']
356
+ spec = np.load(mel_path) # mel spec [80, T]
357
+ if spec.shape[1] > self.max_batch_len:
358
+ spec = spec[:, :self.max_batch_len]
359
+ except:
360
+ mel_path = data['mel_path']
361
+ print(f'corrupted:{mel_path}')
362
+ spec = np.ones((self.mel_num, self.min_batch_len)).astype(np.float32) * self.pad_value
363
+
364
+ item['image'] = spec
365
+ item["caption"] = caption
366
+ if self.split == 'test':
367
+ item['f_name'] = data['name']
368
+ return item
369
+
370
+ class JoinSpecsTrain_Caption(JoinManifestSpecs_Caption):
371
+ def __init__(self, specs_dataset_cfg):
372
+ super().__init__('train', **specs_dataset_cfg)
373
+
374
+ class JoinSpecsValidation_Caption(JoinManifestSpecs_Caption):
375
+ def __init__(self, specs_dataset_cfg):
376
+ super().__init__('valid', **specs_dataset_cfg)
377
+
378
+ class JoinSpecsTest_Caption(JoinManifestSpecs_Caption):
379
+ def __init__(self, specs_dataset_cfg):
380
+ super().__init__('test', **specs_dataset_cfg)
ldm/data/tsv_dirs/full_data/V1_new/audiocaps_train_16000.tsv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a34eeaf905d408e7faab9424f1742df3c1eb89e763c91ba355058b61e86c60b8
3
+ size 8042145
ldm/data/tsv_dirs/full_data/V2/MACS.tsv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7e993db5676570b42daf04a7836ad0cfdbef4d04b8a73f56a5828f864ee37f6
3
+ size 6019546
ldm/data/tsv_dirs/full_data/V2/WavText5K.tsv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:617bc20b11d6206e8735153a850b16449c484f52286dee4d7f67ed4f26bfb221
3
+ size 1145878
ldm/data/tsv_dirs/full_data/V2/adobe.tsv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da973ea2f5e2440a832c40a022e33ef03aad24fbf2da7943ba5a77d43a7100d4
3
+ size 2138832
ldm/data/tsv_dirs/full_data/V2/audiostock.tsv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cafe0c81c72b3fa1574f98fa293e4036f69f1c4b8d8cd9cb369087076482e63a
3
+ size 2028510
ldm/data/tsv_dirs/full_data/V2/epidemic_sound.tsv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc67e42c9defa98edfc2c6b23c731fafa4a22307fddfd1fb95ccfc00d0168951
3
+ size 15062608
ldm/data/tsv_dirs/full_data/caps_struct/audiocaps_train_16000_struct2.tsv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:565a506454c19ddd694cfb4b5c47a13f98e7966bce5617a7bbecec50c418257b
3
+ size 10208584
ldm/data/txt_spec_dataset.py ADDED
@@ -0,0 +1,171 @@
1
+ import csv
2
+ import os
3
+ import pickle
4
+ import sys
5
+
6
+ import numpy as np
7
+ import torch
8
+ import random
9
+ import math
10
+ import librosa
11
+ import pandas as pd
12
+ from pathlib import Path
13
+ class audio_spec_join_Dataset(torch.utils.data.Dataset):
14
+ # Only Load audio dataset: for training Stage1: Audio Npy Dataset
15
+ def __init__(self, split, dataset_name, spec_crop_len, drop=0.0):
16
+ super().__init__()
17
+
18
+ if split == "train":
19
+ self.split = "Train"
20
+
21
+ elif split == "valid" or split == 'test':
22
+ self.split = "Test"
23
+
24
+ # Default params:
25
+ self.min_duration = 2
26
+ self.spec_crop_len = spec_crop_len
27
+ self.drop = drop
28
+
29
+ print("Use Drop: {}".format(self.drop))
30
+
31
+ self.init_text2audio(dataset_name)
32
+
33
+ print('Split: {} Total Sample Num: {}'.format(split, len(self.dataset)))
34
+
35
+ if os.path.exists('/apdcephfs_intern/share_1316500/nlphuang/data/video_to_audio/vggsound/cavp/empty_vid.npz'):
36
+ self.root = '/apdcephfs_intern'
37
+ else:
38
+ self.root = '/apdcephfs'
39
+
40
+
41
+ def init_text2audio(self, dataset):
42
+
43
+ with open(dataset) as f:
44
+ reader = csv.DictReader(
45
+ f,
46
+ delimiter="\t",
47
+ quotechar=None,
48
+ doublequote=False,
49
+ lineterminator="\n",
50
+ quoting=csv.QUOTE_NONE,
51
+ )
52
+ samples = [dict(e) for e in reader]
53
+
54
+ if self.split == 'Test':
55
+ samples = samples[:100]
56
+
57
+ self.dataset = samples
58
+ print('text2audio dataset len:', len(self.dataset))
59
+
60
+ def __len__(self):
61
+ return len(self.dataset)
62
+
63
+ def load_feat(self, spec_path):
64
+ try:
65
+ spec_raw = np.load(spec_path) # mel spec [80, T]
66
+ except:
67
+ print(f'corrupted mel:{spec_path}', flush=True)
68
+ spec_raw = np.zeros((80, self.spec_crop_len), dtype=np.float32) # [C, T]
69
+
70
+ spec_len = self.spec_crop_len
71
+ if spec_raw.shape[1] < spec_len:
72
+ spec_raw = np.tile(spec_raw, math.ceil(spec_len / spec_raw.shape[1]))
73
+ spec_raw = spec_raw[:, :int(spec_len)]
74
+
75
+ return spec_raw
76
+
77
+
78
+ def __getitem__(self, idx):
79
+ data_dict = {}
80
+ data = self.dataset[idx]
81
+
82
+ p = np.random.uniform(0, 1)
83
+ if p > self.drop:
84
+ caption = {"ori_caption": data['ori_cap'], "struct_caption": data['caption']}
85
+ else:
86
+ caption = {"ori_caption": "", "struct_caption": ""}
87
+
88
+ mel_path = data['mel_path'].replace('/apdcephfs', '/apdcephfs_intern') if self.root == '/apdcephfs_intern' else data['mel_path']
89
+ spec = self.load_feat(mel_path)
90
+
91
+ data_dict['caption'] = caption
92
+ data_dict['image'] = spec # (80, 624)
93
+
94
+ return data_dict
95
+
96
+
97
+ class spec_join_Dataset_Train(audio_spec_join_Dataset):
98
+ def __init__(self, dataset_cfg):
99
+ super().__init__(split='train', **dataset_cfg)
100
+
101
+ class spec_join_Dataset_Valid(audio_spec_join_Dataset):
102
+ def __init__(self, dataset_cfg):
103
+ super().__init__(split='valid', **dataset_cfg)
104
+
105
+ class spec_join_Dataset_Test(audio_spec_join_Dataset):
106
+ def __init__(self, dataset_cfg):
107
+ super().__init__(split='test', **dataset_cfg)
108
+
109
+
110
+
111
+ class audio_spec_join_audioset_Dataset(audio_spec_join_Dataset):
112
+
113
+ # def __init__(self, split, dataset_name, root, spec_crop_len, drop=0.0):
114
+ # super().__init__(split, dataset_name, spec_crop_len, drop)
115
+ #
116
+ # self.data_dir = root
117
+ # MANIFEST_COLUMNS = ["name", "dataset", "ori_cap", "audio_path", "mel_path", "duration"]
118
+ # manifest = {c: [] for c in MANIFEST_COLUMNS}
119
+ # skip = 0
120
+ # if self.split != 'Train': return
121
+ # from preprocess.generate_manifest import save_df_to_tsv
122
+ # from tqdm import tqdm
123
+ # for idx in tqdm(range(len(self.dataset))):
124
+ # item = self.dataset[idx]
125
+ # mel_path = f'{self.data_dir}/{Path(item["name"])}_mel.npy'
126
+ # try:
127
+ # _ = np.load(mel_path)
128
+ # except:
129
+ # skip += 1
130
+ # continue
131
+ #
132
+ # manifest["name"].append(item['name'])
133
+ # manifest["dataset"].append("audioset")
134
+ # manifest["ori_cap"].append(item['ori_cap'])
135
+ # manifest["duration"].append(item['audio_path'])
136
+ # manifest["audio_path"].append(item['duration'])
137
+ # manifest["mel_path"].append(mel_path)
138
+ #
139
+ # print(f"Writing manifest to {dataset_name.replace('audioset.tsv', 'audioset_new.tsv')}..., skip: {skip}")
140
+ # save_df_to_tsv(pd.DataFrame.from_dict(manifest), f"{dataset_name.replace('audioset.tsv', 'audioset_new.tsv')}")
141
+
142
+
143
+ def __getitem__(self, idx):
144
+ data_dict = {}
145
+ data = self.dataset[idx]
146
+
147
+ p = np.random.uniform(0, 1)
148
+ if p > self.drop:
149
+ caption = data['ori_cap']
150
+ else:
151
+ caption = ""
152
+ spec = self.load_feat(data['mel_path'])
153
+
154
+ data_dict['caption'] = caption
155
+ data_dict['image'] = spec # (80, 624)
156
+
157
+ return data_dict
158
+
159
+
160
+
161
+ class spec_join_Dataset_audioset_Train(audio_spec_join_audioset_Dataset):
162
+ def __init__(self, dataset_cfg):
163
+ super().__init__(split='train', **dataset_cfg)
164
+
165
+ class spec_join_Dataset_audioset_Valid(audio_spec_join_audioset_Dataset):
166
+ def __init__(self, dataset_cfg):
167
+ super().__init__(split='valid', **dataset_cfg)
168
+
169
+ class spec_join_Dataset_audioset_Test(audio_spec_join_audioset_Dataset):
170
+ def __init__(self, dataset_cfg):
171
+ super().__init__(split='test', **dataset_cfg)
ldm/data/video_spec_maa2_dataset.py ADDED
@@ -0,0 +1,837 @@
1
+ import csv
2
+ import os
3
+ import pickle
4
+ import sys
5
+
6
+ import numpy as np
7
+ import torch
8
+ import random
9
+ import math
10
+ import librosa
11
+
12
+ class audio_video_spec_fullset_Dataset(torch.utils.data.Dataset):
13
+ # Load paired mel-spectrogram npy files and CAVP video features (VGGSound) for video-to-audio training
14
+ def __init__(self, split, dataset1, feat_type='clip', transforms=None, sr=22050, duration=10, truncate=220000, fps=21.5, drop=0.0, fix_frames=False, hop_len=256):
15
+ super().__init__()
16
+
17
+ if split == "train":
18
+ self.split = "Train"
19
+
20
+ elif split == "valid" or split == 'test':
21
+ self.split = "Test"
22
+
23
+ # Default params:
24
+ self.min_duration = 2
25
+ self.sr = sr # 22050
26
+ self.duration = duration # 10
27
+ self.truncate = truncate # 220000
28
+ self.fps = fps
29
+ self.fix_frames = fix_frames
30
+ self.hop_len = hop_len
31
+ self.drop = drop
32
+ print("Fix Frames: {}".format(self.fix_frames))
33
+ print("Use Drop: {}".format(self.drop))
34
+
35
+ # Dataset1: (VGGSound)
36
+ assert dataset1.dataset_name == "VGGSound"
37
+
38
+ # spec_dir: spectrogram path
39
+ # feat_dir: CAVP feature path
40
+ # video_dir: video path
41
+
42
+ dataset1_spec_dir = os.path.join(dataset1.data_dir, "mel_maa2", "npy")
43
+ dataset1_feat_dir = os.path.join(dataset1.data_dir, "cavp")
44
+ dataset1_video_dir = os.path.join(dataset1.video_dir, "tmp_vid")
45
+
46
+ split_txt_path = dataset1.split_txt_path
47
+ with open(os.path.join(split_txt_path, '{}.txt'.format(self.split)), "r") as f:
48
+ data_list1 = f.readlines()
49
+ data_list1 = list(map(lambda x: x.strip(), data_list1))
50
+
51
+ spec_list1 = list(map(lambda x: os.path.join(dataset1_spec_dir, x) + "_mel.npy", data_list1)) # spec
52
+ feat_list1 = list(map(lambda x: os.path.join(dataset1_feat_dir, x) + ".npz", data_list1)) # feat
53
+ video_list1 = list(map(lambda x: os.path.join(dataset1_video_dir, x) + "_new_fps_21.5_truncate_0_10.0.mp4", data_list1)) # video
54
+
55
+
56
+ # Merge Data:
57
+ self.data_list = data_list1 if self.split != "Test" else data_list1[:200]
58
+ self.spec_list = spec_list1 if self.split != "Test" else spec_list1[:200]
59
+ self.feat_list = feat_list1 if self.split != "Test" else feat_list1[:200]
60
+ self.video_list = video_list1 if self.split != "Test" else video_list1[:200]
61
+
62
+ assert len(self.data_list) == len(self.spec_list) == len(self.feat_list) == len(self.video_list)
63
+
64
+ shuffle_idx = np.random.permutation(np.arange(len(self.data_list)))
65
+ self.data_list = [self.data_list[i] for i in shuffle_idx]
66
+ self.spec_list = [self.spec_list[i] for i in shuffle_idx]
67
+ self.feat_list = [self.feat_list[i] for i in shuffle_idx]
68
+ self.video_list = [self.video_list[i] for i in shuffle_idx]
69
+
70
+ print('Split: {} Sample Num: {}'.format(split, len(self.data_list)))
71
+
72
+
73
+
74
+ def __len__(self):
75
+ return len(self.data_list)
76
+
77
+
78
+ def load_spec_and_feat(self, spec_path, video_feat_path):
79
+ """Load audio spec and video feat"""
80
+ try:
81
+ spec_raw = np.load(spec_path).astype(np.float32) # channel: 1
82
+ except:
83
+ print(f"corrupted mel: {spec_path}", flush=True)
84
+ spec_raw = np.zeros((80, 625), dtype=np.float32) # [C, T]
85
+
86
+ p = np.random.uniform(0,1)
87
+ if p > self.drop:
88
+ try:
89
+ video_feat = np.load(video_feat_path)['feat'].astype(np.float32)
90
+ except:
91
+ print(f"corrupted video: {video_feat_path}", flush=True)
92
+ video_feat = np.load(os.path.join(os.path.dirname(video_feat_path), 'empty_vid.npz'))['feat'].astype(np.float32)
93
+ else:
94
+ video_feat = np.load(os.path.join(os.path.dirname(video_feat_path), 'empty_vid.npz'))['feat'].astype(np.float32)
95
+
96
+ spec_len = self.sr * self.duration / self.hop_len
97
+ if spec_raw.shape[1] < spec_len:
98
+ spec_raw = np.tile(spec_raw, math.ceil(spec_len / spec_raw.shape[1]))
99
+ spec_raw = spec_raw[:, :int(spec_len)]
100
+
101
+ feat_len = self.fps * self.duration
102
+ if video_feat.shape[0] < feat_len:
103
+ video_feat = np.tile(video_feat, (math.ceil(feat_len / video_feat.shape[0]), 1))
104
+ video_feat = video_feat[:int(feat_len)]
105
+ return spec_raw, video_feat
106
+
107
+
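The loader above pads short clips by tiling them and then crops to a fixed length. A self-contained sketch of that tile-then-truncate step, assuming the default sr=22050, duration=10, hop_len=256 and fps=21.5; the helper name pad_to_len is hypothetical and only mirrors the logic.

```python
import math
import numpy as np

def pad_to_len(arr: np.ndarray, target_len: int, axis: int) -> np.ndarray:
    """Repeat arr along `axis` until it reaches target_len, then crop to target_len."""
    if arr.shape[axis] < target_len:
        reps = [1] * arr.ndim
        reps[axis] = math.ceil(target_len / arr.shape[axis])
        arr = np.tile(arr, reps)
    return np.take(arr, np.arange(target_len), axis=axis)

spec_len = int(22050 * 10 / 256)   # 861 mel columns for a 10 s clip
feat_len = int(21.5 * 10)          # 215 CAVP frames at 21.5 fps
spec = pad_to_len(np.zeros((80, 400), dtype=np.float32), spec_len, axis=1)
feat = pad_to_len(np.zeros((100, 512), dtype=np.float32), feat_len, axis=0)
print(spec.shape, feat.shape)      # (80, 861) (215, 512)
```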
108
+ def mix_audio_and_feat(self, spec1=None, spec2=None, video_feat1=None, video_feat2=None, video_info_dict={}, mode='single'):
109
+ """ Return Mix Spec and Mix video feat"""
110
+ if mode == "single":
111
+ # spec1:
112
+ if not self.fix_frames:
113
+ start_idx = random.randint(0, self.sr * self.duration - self.truncate - 1) # audio start
114
+ else:
115
+ start_idx = 0
116
+
117
+ start_frame = int(self.fps * start_idx / self.sr)
118
+ truncate_frame = int(self.fps * self.truncate / self.sr)
119
+
120
+ # Spec Start & Truncate:
121
+ spec_start = int(start_idx / self.hop_len)
122
+ spec_truncate = int(self.truncate / self.hop_len)
123
+
124
+ spec1 = spec1[:, spec_start : spec_start + spec_truncate]
125
+ video_feat1 = video_feat1[start_frame: start_frame + truncate_frame]
126
+
127
+ # info_dict:
128
+ video_info_dict['video_time1'] = str(start_frame) + '_' + str(start_frame+truncate_frame) # Start frame, end frame
129
+ video_info_dict['video_time2'] = ""
130
+ return spec1, video_feat1, video_info_dict
131
+
132
+ elif mode == "concat":
133
+ total_spec_len = int(self.truncate / self.hop_len)
134
+ # Random truncate length:
135
+ spec1_truncate_len = random.randint(self.min_duration * self.sr // self.hop_len, total_spec_len - self.min_duration * self.sr // self.hop_len - 1)
136
+ spec2_truncate_len = total_spec_len - spec1_truncate_len
137
+
138
+ # Sample spec clip:
139
+ spec_start1 = random.randint(0, total_spec_len - spec1_truncate_len - 1)
140
+ spec_start2 = random.randint(0, total_spec_len - spec2_truncate_len - 1)
141
+ spec_end1, spec_end2 = spec_start1 + spec1_truncate_len, spec_start2 + spec2_truncate_len
142
+
143
+ # concat spec:
144
+ spec1, spec2 = spec1[:, spec_start1 : spec_end1], spec2[:, spec_start2 : spec_end2]
145
+ concat_audio_spec = np.concatenate([spec1, spec2], axis=1)
146
+
147
+ # Concat Video Feat:
148
+ start1_frame, truncate1_frame = int(self.fps * spec_start1 * self.hop_len / self.sr), int(self.fps * spec1_truncate_len * self.hop_len / self.sr)
149
+ start2_frame, truncate2_frame = int(self.fps * spec_start2 * self.hop_len / self.sr), int(self.fps * self.truncate / self.sr) - truncate1_frame
150
+ video_feat1, video_feat2 = video_feat1[start1_frame : start1_frame + truncate1_frame], video_feat2[start2_frame : start2_frame + truncate2_frame]
151
+ concat_video_feat = np.concatenate([video_feat1, video_feat2])
152
+
153
+ video_info_dict['video_time1'] = str(start1_frame) + '_' + str(start1_frame+truncate1_frame) # Start frame, end frame
154
+ video_info_dict['video_time2'] = str(start2_frame) + '_' + str(start2_frame+truncate2_frame)
155
+ return concat_audio_spec, concat_video_feat, video_info_dict
156
+
157
+
158
+
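As a quick sanity check of the 'single' branch above, here is what the default arguments from __init__ imply for the truncated window sizes.

```python
# Worked example under the defaults sr=22050, duration=10, truncate=220000, hop_len=256, fps=21.5.
sr, duration, truncate, hop_len, fps = 22050, 10, 220000, 256, 21.5

spec_truncate = int(truncate / hop_len)       # 859 mel columns (~9.98 s of audio)
truncate_frame = int(fps * truncate / sr)     # 214 CAVP frames over the same window
max_start_idx = sr * duration - truncate - 1  # up to 499 samples of random jitter when fix_frames=False
print(spec_truncate, truncate_frame, max_start_idx)
# Note: inline comments such as "# (80, 512)" below suggest the actual configs pass
# truncate=131072 (131072 / 256 = 512) rather than the default 220000.
```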
159
+ def __getitem__(self, idx):
160
+
161
+ audio_name1 = self.data_list[idx]
162
+ spec_npy_path1 = self.spec_list[idx]
163
+ video_feat_path1 = self.feat_list[idx]
164
+ video_path1 = self.video_list[idx]
165
+
166
+ # select other video:
167
+ flag = False
168
+ if random.uniform(0, 1) < 0.5:
169
+ flag = True
170
+ random_idx = idx
171
+ while random_idx == idx:
172
+ random_idx = random.randint(0, len(self.data_list)-1)
173
+ audio_name2 = self.data_list[random_idx]
174
+ spec_npy_path2 = self.spec_list[random_idx]
175
+ video_feat_path2 = self.feat_list[random_idx]
176
+ video_path2 = self.video_list[random_idx]
177
+
178
+ # Load the Spec and Feat:
179
+ spec1, video_feat1 = self.load_spec_and_feat(spec_npy_path1, video_feat_path1)
180
+
181
+ if flag:
182
+ spec2, video_feat2 = self.load_spec_and_feat(spec_npy_path2, video_feat_path2)
183
+ video_info_dict = {'audio_name1':audio_name1, 'audio_name2': audio_name2, 'video_path1': video_path1, 'video_path2': video_path2}
184
+ mix_spec, mix_video_feat, mix_info = self.mix_audio_and_feat(spec1, spec2, video_feat1, video_feat2, video_info_dict, mode='concat')
185
+ else:
186
+ video_info_dict = {'audio_name1':audio_name1, 'audio_name2': "", 'video_path1': video_path1, 'video_path2': ""}
187
+ mix_spec, mix_video_feat, mix_info = self.mix_audio_and_feat(spec1=spec1, video_feat1=video_feat1, video_info_dict=video_info_dict, mode='single')
188
+
189
+ # print("mix spec shape:", mix_spec.shape)
190
+ # print("mix video feat:", mix_video_feat.shape)
191
+ data_dict = {}
192
+ # data_dict['mix_spec'] = mix_spec[None].repeat(3, axis=0) # TODO: change this, otherwise it cannot fit the MAA autoencoder
193
+ data_dict['mix_spec'] = mix_spec # (80, 512)
194
+ data_dict['mix_video_feat'] = mix_video_feat # (32, 512)
195
+ data_dict['mix_info_dict'] = mix_info
196
+
197
+ return data_dict
198
+
199
+
200
+
201
+ class audio_video_spec_fullset_Dataset_Train(audio_video_spec_fullset_Dataset):
202
+ def __init__(self, dataset_cfg):
203
+ super().__init__(split='train', **dataset_cfg)
204
+
205
+ class audio_video_spec_fullset_Dataset_Valid(audio_video_spec_fullset_Dataset):
206
+ def __init__(self, dataset_cfg):
207
+ super().__init__(split='valid', **dataset_cfg)
208
+
209
+ class audio_video_spec_fullset_Dataset_Test(audio_video_spec_fullset_Dataset):
210
+ def __init__(self, dataset_cfg):
211
+ super().__init__(split='test', **dataset_cfg)
212
+
213
+
214
+
215
+ class audio_video_spec_fullset_Dataset_inpaint(audio_video_spec_fullset_Dataset):
216
+
217
+ def __getitem__(self, idx):
218
+
219
+ audio_name1 = self.data_list[idx]
220
+ spec_npy_path1 = self.spec_list[idx]
221
+ video_feat_path1 = self.feat_list[idx]
222
+ video_path1 = self.video_list[idx]
223
+
224
+ # Load the Spec and Feat:
225
+ spec1, video_feat1 = self.load_spec_and_feat(spec_npy_path1, video_feat_path1)
226
+
227
+ video_info_dict = {'audio_name1': audio_name1, 'audio_name2': "", 'video_path1': video_path1, 'video_path2': ""}
228
+ mix_spec, mix_masked_spec, mix_video_feat, mix_info = self.mix_audio_and_feat(spec1=spec1, video_feat1=video_feat1, video_info_dict=video_info_dict)
229
+
230
+ # print("mix spec shape:", mix_spec.shape)
231
+ # print("mix video feat:", mix_video_feat.shape)
232
+ data_dict = {}
233
+ # data_dict['mix_spec'] = mix_spec[None].repeat(3, axis=0) # TODO: change this, otherwise it cannot fit the MAA autoencoder
234
+ data_dict['mix_spec'] = mix_spec # (80, 512)
235
+ data_dict['hybrid_feat'] = {'mix_video_feat': mix_video_feat, 'mix_spec': mix_masked_spec} # (32, 512)
236
+ data_dict['mix_info_dict'] = mix_info
237
+
238
+ return data_dict
239
+
240
+ def mix_audio_and_feat(self, spec1=None, video_feat1=None, video_info_dict={}):
241
+ """ Return Mix Spec and Mix video feat"""
242
+
243
+ # spec1:
244
+ if not self.fix_frames:
245
+ start_idx = random.randint(0, self.sr * self.duration - self.truncate - 1) # audio start
246
+ else:
247
+ start_idx = 0
248
+
249
+ start_frame = int(self.fps * start_idx / self.sr)
250
+ truncate_frame = int(self.fps * self.truncate / self.sr)
251
+
252
+ # Spec Start & Truncate:
253
+ spec_start = int(start_idx / self.hop_len)
254
+ spec_truncate = int(self.truncate / self.hop_len)
255
+
256
+ spec1 = spec1[:, spec_start: spec_start + spec_truncate]
257
+ video_feat1 = video_feat1[start_frame: start_frame + truncate_frame]
258
+
259
+ # Start masking frames:
260
+ masked_spec = random.randint(1, int(spec_truncate * 0.5 // 16)) * 16 # a multiple of 16 frames, masking at most ~50%
261
+ masked_truncate = int(masked_spec * self.hop_len)
262
+ masked_frame = int(self.fps * masked_truncate / self.sr)
263
+
264
+ start_masked_idx = random.randint(0, self.truncate - masked_truncate - 1)
265
+ start_masked_frame = int(self.fps * start_masked_idx / self.sr)
266
+ start_masked_spec = int(start_masked_idx / self.hop_len)
267
+
268
+ masked_spec1 = np.zeros((80, spec_truncate)).astype(np.float32)
269
+ masked_spec1[:] = spec1[:]
270
+ masked_spec1[:, start_masked_spec:start_masked_spec+masked_spec] = np.zeros((80, masked_spec))
271
+ video_feat1[start_masked_frame:start_masked_frame+masked_frame, :] = np.zeros((masked_frame, 512))
272
+ # info_dict:
273
+ video_info_dict['video_time1'] = str(start_frame) + '_' + str(start_frame + truncate_frame) # Start frame, end frame
274
+ video_info_dict['video_time2'] = ""
275
+ return spec1, masked_spec1, video_feat1, video_info_dict
276
+
277
+
278
+
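A worked check of the masking bounds above under the same defaults (truncate=220000, hop_len=256, fps=21.5, sr=22050); the variable names are only for illustration.

```python
sr, truncate, hop_len, fps = 22050, 220000, 256, 21.5

spec_truncate = int(truncate / hop_len)                       # 859 mel columns
max_blocks = int(spec_truncate * 0.5 // 16)                   # 26 blocks of 16 columns
max_masked_spec = max_blocks * 16                             # up to 416 columns (~48% of 859)
max_masked_frame = int(fps * max_masked_spec * hop_len / sr)  # up to 103 video frames zeroed
print(spec_truncate, max_blocks, max_masked_spec, max_masked_frame)
```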
279
+ class audio_video_spec_fullset_Dataset_inpaint_Train(audio_video_spec_fullset_Dataset_inpaint):
280
+ def __init__(self, dataset_cfg):
281
+ super().__init__(split='train', **dataset_cfg)
282
+
283
+ class audio_video_spec_fullset_Dataset_inpaint_Valid(audio_video_spec_fullset_Dataset_inpaint):
284
+ def __init__(self, dataset_cfg):
285
+ super().__init__(split='valid', **dataset_cfg)
286
+
287
+ class audio_video_spec_fullset_Dataset_inpaint_Test(audio_video_spec_fullset_Dataset_inpaint):
288
+ def __init__(self, dataset_cfg):
289
+ super().__init__(split='test', **dataset_cfg)
290
+
291
+
292
+
293
+ class audio_Dataset(torch.utils.data.Dataset):
294
+ # Only load audio: raw waveform dataset (VGGSound) for Stage-1 training
295
+ def __init__(self, split, dataset1, sr=22050, duration=10, truncate=220000, debug_num=False, fix_frames=False, hop_len=256):
296
+ super().__init__()
297
+
298
+ if split == "train":
299
+ self.split = "Train"
300
+
301
+ elif split == "valid" or split == 'test':
302
+ self.split = "Test"
303
+
304
+ # Default params:
305
+ self.min_duration = 2
306
+ self.sr = sr # 22050
307
+ self.duration = duration # 10
308
+ self.truncate = truncate # 220000
309
+ self.fix_frames = fix_frames
310
+ self.hop_len = hop_len
311
+ print("Fix Frames: {}".format(self.fix_frames))
312
+
313
+
314
+ # Dataset1: (VGGSound)
315
+ assert dataset1.dataset_name == "VGGSound"
316
+
317
+ # spec_dir: spectrogram path
318
+
319
+ # dataset1_spec_dir = os.path.join(dataset1.data_dir, "codec")
320
+ dataset1_wav_dir = os.path.join(dataset1.wav_dir, "wav")
321
+
322
+ split_txt_path = dataset1.split_txt_path
323
+ with open(os.path.join(split_txt_path, '{}.txt'.format(self.split)), "r") as f:
324
+ data_list1 = f.readlines()
325
+ data_list1 = list(map(lambda x: x.strip(), data_list1))
326
+ wav_list1 = list(map(lambda x: os.path.join(dataset1_wav_dir, x) + ".wav", data_list1)) # feat
327
+
328
+ # Merge Data:
329
+ self.data_list = data_list1
330
+ self.wav_list = wav_list1
331
+
332
+ assert len(self.data_list) == len(self.wav_list)
333
+
334
+ shuffle_idx = np.random.permutation(np.arange(len(self.data_list)))
335
+ self.data_list = [self.data_list[i] for i in shuffle_idx]
336
+ self.wav_list = [self.wav_list[i] for i in shuffle_idx]
337
+
338
+ if debug_num:
339
+ self.data_list = self.data_list[:debug_num]
340
+ self.wav_list = self.wav_list[:debug_num]
341
+
342
+ print('Split: {} Sample Num: {}'.format(split, len(self.data_list)))
343
+
344
+
345
+ def __len__(self):
346
+ return len(self.data_list)
347
+
348
+
349
+ def load_spec_and_feat(self, wav_path):
350
+ """Load the raw audio waveform"""
351
+ try:
352
+ wav_raw, sr = librosa.load(wav_path, sr=self.sr) # channel: 1
353
+ except:
354
+ print(f"corrupted wav: {wav_path}", flush=True)
355
+ wav_raw = np.zeros((160000,), dtype=np.float32) # [T]
356
+
357
+ wav_len = self.sr * self.duration
358
+ if wav_raw.shape[0] < wav_len:
359
+ wav_raw = np.tile(wav_raw, math.ceil(wav_len / wav_raw.shape[0]))
360
+ wav_raw = wav_raw[:int(wav_len)]
361
+
362
+ return wav_raw
363
+
364
+
365
+ def mix_audio_and_feat(self, wav_raw1=None, video_info_dict={}, mode='single'):
366
+ """ Return the (possibly truncated) mixed waveform"""
367
+ if mode == "single":
368
+ # spec1:
369
+ if not self.fix_frames:
370
+ start_idx = random.randint(0, self.sr * self.duration - self.truncate - 1) # audio start
371
+ else:
372
+ start_idx = 0
373
+
374
+ wav_start = start_idx
375
+ wav_truncate = self.truncate
376
+ wav_raw1 = wav_raw1[wav_start: wav_start + wav_truncate]
377
+
378
+ return wav_raw1, video_info_dict
379
+
380
+ elif mode == "concat":
381
+ total_spec_len = int(self.truncate / self.hop_len)
382
+ # Random truncate length:
383
+ spec1_truncate_len = random.randint(self.min_duration * self.sr // self.hop_len, total_spec_len - self.min_duration * self.sr // self.hop_len - 1)
384
+ spec2_truncate_len = total_spec_len - spec1_truncate_len
385
+
386
+ # Sample spec clip:
387
+ spec_start1 = random.randint(0, total_spec_len - spec1_truncate_len - 1)
388
+ spec_start2 = random.randint(0, total_spec_len - spec2_truncate_len - 1)
389
+ spec_end1, spec_end2 = spec_start1 + spec1_truncate_len, spec_start2 + spec2_truncate_len
390
+
391
+ # concat spec:
392
+ return video_info_dict
393
+
394
+
395
+ def __getitem__(self, idx):
396
+
397
+ audio_name1 = self.data_list[idx]
398
+ wav_path1 = self.wav_list[idx]
399
+ # select other video:
400
+ flag = False
401
+ if random.uniform(0, 1) < -1: # concat branch is effectively disabled (this condition is never true)
402
+ flag = True
403
+ random_idx = idx
404
+ while random_idx == idx:
405
+ random_idx = random.randint(0, len(self.data_list)-1)
406
+ audio_name2 = self.data_list[random_idx]
407
+ spec_npy_path2 = self.spec_list[random_idx]
408
+ wav_path2 = self.wav_list[random_idx]
409
+
410
+ # Load the Spec and Feat:
411
+ wav_raw1 = self.load_spec_and_feat(wav_path1)
412
+
413
+ if flag:
414
+ spec2, video_feat2 = self.load_spec_and_feat(spec_npy_path2, wav_path2)
415
+ video_info_dict = {'audio_name1':audio_name1, 'audio_name2': audio_name2}
416
+ mix_spec, mix_video_feat, mix_info = self.mix_audio_and_feat(video_info_dict, mode='concat')
417
+ else:
418
+ video_info_dict = {'audio_name1':audio_name1, 'audio_name2': ""}
419
+ mix_wav, mix_info = self.mix_audio_and_feat(wav_raw1=wav_raw1, video_info_dict=video_info_dict, mode='single')
420
+
421
+ data_dict = {}
422
+ data_dict['mix_wav'] = mix_wav # (131072,)
423
+ data_dict['mix_info_dict'] = mix_info
424
+
425
+ return data_dict
426
+
427
+
428
+ class audio_Dataset_Train(audio_Dataset):
429
+ def __init__(self, dataset_cfg):
430
+ super().__init__(split='train', **dataset_cfg)
431
+
432
+ class audio_Dataset_Test(audio_Dataset):
433
+ def __init__(self, dataset_cfg):
434
+ super().__init__(split='test', **dataset_cfg)
435
+
436
+ class audio_Dataset_Valid(audio_Dataset):
437
+ def __init__(self, dataset_cfg):
438
+ super().__init__(split='valid', **dataset_cfg)
439
+
440
+
441
+
442
+ class video_codec_Dataset(torch.utils.data.Dataset):
443
+ # Load raw waveforms paired with CAVP video features (VGGSound)
444
+ def __init__(self, split, dataset1, sr=22050, duration=10, truncate=220000, fps=21.5, debug_num=False, fix_frames=False, hop_len=256):
445
+ super().__init__()
446
+
447
+ if split == "train":
448
+ self.split = "Train"
449
+
450
+ elif split == "valid" or split == 'test':
451
+ self.split = "Test"
452
+
453
+ # Default params:
454
+ self.min_duration = 2
455
+ self.fps = fps
456
+ self.sr = sr # 22050
457
+ self.duration = duration # 10
458
+ self.truncate = truncate # 220000
459
+ self.fix_frames = fix_frames
460
+ self.hop_len = hop_len
461
+ print("Fix Frames: {}".format(self.fix_frames))
462
+
463
+
464
+ # Dataset1: (VGGSound)
465
+ assert dataset1.dataset_name == "VGGSound"
466
+
467
+ # spec_dir: spectrogram path
468
+
469
+ # dataset1_spec_dir = os.path.join(dataset1.data_dir, "codec")
470
+ dataset1_feat_dir = os.path.join(dataset1.data_dir, "cavp")
471
+ dataset1_wav_dir = os.path.join(dataset1.wav_dir, "wav")
472
+
473
+ split_txt_path = dataset1.split_txt_path
474
+ with open(os.path.join(split_txt_path, '{}.txt'.format(self.split)), "r") as f:
475
+ data_list1 = f.readlines()
476
+ data_list1 = list(map(lambda x: x.strip(), data_list1))
477
+ wav_list1 = list(map(lambda x: os.path.join(dataset1_wav_dir, x) + ".wav", data_list1)) # feat
478
+ feat_list1 = list(map(lambda x: os.path.join(dataset1_feat_dir, x) + ".npz", data_list1)) # feat
479
+
480
+ # Merge Data:
481
+ self.data_list = data_list1
482
+ self.wav_list = wav_list1
483
+ self.feat_list = feat_list1
484
+
485
+ assert len(self.data_list) == len(self.wav_list)
486
+
487
+ shuffle_idx = np.random.permutation(np.arange(len(self.data_list)))
488
+ self.data_list = [self.data_list[i] for i in shuffle_idx]
489
+ self.wav_list = [self.wav_list[i] for i in shuffle_idx]
490
+ self.feat_list = [self.feat_list[i] for i in shuffle_idx]
491
+
492
+ if debug_num:
493
+ self.data_list = self.data_list[:debug_num]
494
+ self.wav_list = self.wav_list[:debug_num]
495
+ self.feat_list = self.feat_list[:debug_num]
496
+
497
+ print('Split: {} Sample Num: {}'.format(split, len(self.data_list)))
498
+
499
+
500
+ def __len__(self):
501
+ return len(self.data_list)
502
+
503
+
504
+ def load_spec_and_feat(self, wav_path, video_feat_path):
505
+ """Load the raw waveform and the CAVP video feature"""
506
+ try:
507
+ wav_raw, sr = librosa.load(wav_path, sr=self.sr) # channel: 1
508
+ except:
509
+ print(f"corrupted wav: {wav_path}", flush=True)
510
+ wav_raw = np.zeros((160000,), dtype=np.float32) # [T]
511
+
512
+ try:
513
+ video_feat = np.load(video_feat_path)['feat'].astype(np.float32)
514
+ except:
515
+ print(f"corrupted video: {video_feat_path}", flush=True)
516
+ video_feat = np.load(os.path.join(os.path.dirname(video_feat_path), 'empty_vid.npz'))['feat'].astype(np.float32)
517
+
518
+ wav_len = self.sr * self.duration
519
+ if wav_raw.shape[0] < wav_len:
520
+ wav_raw = np.tile(wav_raw, math.ceil(wav_len / wav_raw.shape[0]))
521
+ wav_raw = wav_raw[:int(wav_len)]
522
+
523
+ feat_len = self.fps * self.duration
524
+ if video_feat.shape[0] < feat_len:
525
+ video_feat = np.tile(video_feat, (math.ceil(feat_len / video_feat.shape[0]), 1))
526
+ video_feat = video_feat[:int(feat_len)]
527
+
528
+ return wav_raw, video_feat
529
+
530
+
531
+ def mix_audio_and_feat(self, wav_raw1=None, video_feat1=None, video_info_dict={}, mode='single'):
532
+ """ Return mixed wav and video feat"""
533
+ if mode == "single":
534
+ # spec1:
535
+ if not self.fix_frames:
536
+ start_idx = random.randint(0, self.sr * self.duration - self.truncate - 1) # audio start
537
+ else:
538
+ start_idx = 0
539
+
540
+ wav_start = start_idx
541
+ wav_truncate = self.truncate
542
+ wav_raw1 = wav_raw1[wav_start: wav_start + wav_truncate]
543
+
544
+ start_frame = int(self.fps * start_idx / self.sr)
545
+ truncate_frame = int(self.fps * self.truncate / self.sr)
546
+ video_feat1 = video_feat1[start_frame: start_frame + truncate_frame]
547
+
548
+ # info_dict:
549
+ video_info_dict['video_time1'] = str(start_frame) + '_' + str(start_frame+truncate_frame) # Start frame, end frame
550
+ video_info_dict['video_time2'] = ""
551
+
552
+ return wav_raw1, video_feat1, video_info_dict
553
+
554
+ elif mode == "concat":
555
+ total_spec_len = int(self.truncate / self.hop_len)
556
+ # Random truncate length:
557
+ spec1_truncate_len = random.randint(self.min_duration * self.sr // self.hop_len, total_spec_len - self.min_duration * self.sr // self.hop_len - 1)
558
+ spec2_truncate_len = total_spec_len - spec1_truncate_len
559
+
560
+ # Sample spec clip:
561
+ spec_start1 = random.randint(0, total_spec_len - spec1_truncate_len - 1)
562
+ spec_start2 = random.randint(0, total_spec_len - spec2_truncate_len - 1)
563
+ spec_end1, spec_end2 = spec_start1 + spec1_truncate_len, spec_start2 + spec2_truncate_len
564
+
565
+ # concat spec:
566
+ return video_info_dict
567
+
568
+
569
+ def __getitem__(self, idx):
570
+
571
+ audio_name1 = self.data_list[idx]
572
+ wav_path1 = self.wav_list[idx]
573
+ video_feat_path1 = self.feat_list[idx]
574
+ # select other video:
575
+ flag = False
576
+ if random.uniform(0, 1) < -1: # concat branch is effectively disabled (this condition is never true)
577
+ flag = True
578
+ random_idx = idx
579
+ while random_idx == idx:
580
+ random_idx = random.randint(0, len(self.data_list)-1)
581
+ audio_name2 = self.data_list[random_idx]
582
+ wav_path2 = self.wav_list[random_idx]
583
+ video_feat_path2 = self.feat_list[random_idx]
584
+
585
+ # Load the Spec and Feat:
586
+ wav_raw1, video_feat1 = self.load_spec_and_feat(wav_path1, video_feat_path1)
587
+
588
+ if flag:
589
+ wav_raw2, video_feat2 = self.load_spec_and_feat(wav_path2, video_feat_path2)
590
+ video_info_dict = {'audio_name1':audio_name1, 'audio_name2': audio_name2}
591
+ mix_spec, mix_video_feat, mix_info = self.mix_audio_and_feat(video_info_dict, mode='concat')
592
+ else:
593
+ video_info_dict = {'audio_name1':audio_name1, 'audio_name2': ""}
594
+ mix_wav, mix_video_feat, mix_info = self.mix_audio_and_feat(wav_raw1=wav_raw1, video_feat1=video_feat1, video_info_dict=video_info_dict, mode='single')
595
+
596
+ data_dict = {}
597
+ data_dict['mix_wav'] = mix_wav # (131072,)
598
+ data_dict['mix_video_feat'] = mix_video_feat # (32, 512)
599
+ data_dict['mix_info_dict'] = mix_info
600
+
601
+ return data_dict
602
+
603
+
604
+ class video_codec_Dataset_Train(video_codec_Dataset):
605
+ def __init__(self, dataset_cfg):
606
+ super().__init__(split='train', **dataset_cfg)
607
+
608
+ class video_codec_Dataset_Test(video_codec_Dataset):
609
+ def __init__(self, dataset_cfg):
610
+ super().__init__(split='test', **dataset_cfg)
611
+
612
+ class video_codec_Dataset_Valid(video_codec_Dataset):
613
+ def __init__(self, dataset_cfg):
614
+ super().__init__(split='valid', **dataset_cfg)
615
+
616
+
617
+ class audio_video_spec_fullset_Audioset_Dataset(torch.utils.data.Dataset):
618
+ # Load mel spectrograms and CAVP video features from VGGSound, plus AudioSet in the training split
619
+ def __init__(self, split, dataset1, dataset2, sr=22050, duration=10, truncate=220000,
620
+ fps=21.5, drop=0.0, fix_frames=False, hop_len=256):
621
+ super().__init__()
622
+
623
+ if split == "train":
624
+ self.split = "Train"
625
+
626
+ elif split == "valid" or split == 'test':
627
+ self.split = "Test"
628
+
629
+ # Default params:
630
+ self.min_duration = 2
631
+ self.sr = sr # 22050
632
+ self.duration = duration # 10
633
+ self.truncate = truncate # 220000
634
+ self.fps = fps
635
+ self.fix_frames = fix_frames
636
+ self.hop_len = hop_len
637
+ self.drop = drop
638
+ print("Fix Frames: {}".format(self.fix_frames))
639
+ print("Use Drop: {}".format(self.drop))
640
+
641
+ # Dataset1: (VGGSound)
642
+ assert dataset1.dataset_name == "VGGSound"
643
+ assert dataset2.dataset_name == "Audioset"
644
+
645
+ # spec_dir: spectrogram path
646
+ # feat_dir: CAVP feature path
647
+ # video_dir: video path
648
+
649
+ dataset1_spec_dir = os.path.join(dataset1.data_dir, "mel_maa2", "npy")
650
+ dataset1_feat_dir = os.path.join(dataset1.data_dir, "cavp")
651
+ split_txt_path = dataset1.split_txt_path
652
+ with open(os.path.join(split_txt_path, '{}.txt'.format(self.split)), "r") as f:
653
+ data_list1 = f.readlines()
654
+ data_list1 = list(map(lambda x: x.strip(), data_list1))
655
+
656
+ spec_list1 = list(map(lambda x: os.path.join(dataset1_spec_dir, x) + "_mel.npy", data_list1)) # spec
657
+ feat_list1 = list(map(lambda x: os.path.join(dataset1_feat_dir, x) + ".npz", data_list1)) # feat
658
+
659
+ if split == "train":
660
+ dataset2_spec_dir = os.path.join(dataset2.data_dir, "mel")
661
+ dataset2_feat_dir = os.path.join(dataset2.data_dir, "cavp_renamed")
662
+ split_txt_path = dataset2.split_txt_path
663
+ with open(os.path.join(split_txt_path, '{}.txt'.format(self.split)), "r") as f:
664
+ data_list2 = f.readlines()
665
+ data_list2 = list(map(lambda x: x.strip(), data_list2))
666
+
667
+ spec_list2 = list(map(lambda x: os.path.join(dataset2_spec_dir, f'Y{x}') + "_mel.npy", data_list2)) # spec
668
+ feat_list2 = list(map(lambda x: os.path.join(dataset2_feat_dir, x) + ".npz", data_list2)) # feat
669
+
670
+ data_list1 += data_list2
671
+ spec_list1 += spec_list2
672
+ feat_list1 += feat_list2
673
+
674
+ # Merge Data:
675
+ self.data_list = data_list1 if self.split != "Test" else data_list1[:200]
676
+ self.spec_list = spec_list1 if self.split != "Test" else spec_list1[:200]
677
+ self.feat_list = feat_list1 if self.split != "Test" else feat_list1[:200]
678
+
679
+ assert len(self.data_list) == len(self.spec_list) == len(self.feat_list)
680
+
681
+ shuffle_idx = np.random.permutation(np.arange(len(self.data_list)))
682
+ self.data_list = [self.data_list[i] for i in shuffle_idx]
683
+ self.spec_list = [self.spec_list[i] for i in shuffle_idx]
684
+ self.feat_list = [self.feat_list[i] for i in shuffle_idx]
685
+
686
+ print('Split: {} Sample Num: {}'.format(split, len(self.data_list)))
687
+
688
+ # self.check(self.spec_list)
689
+
690
+ def __len__(self):
691
+ return len(self.data_list)
692
+
693
+ def check(self, feat_list):
694
+ from tqdm import tqdm
695
+ for spec_path in tqdm(feat_list):
696
+ mel = np.load(spec_path).astype(np.float32)
697
+ if mel.shape[0] != 80:
698
+ import ipdb
699
+ ipdb.set_trace()
700
+
701
+
702
+
703
+ def load_spec_and_feat(self, spec_path, video_feat_path):
704
+ """Load audio spec and video feat"""
705
+ spec_raw = np.load(spec_path).astype(np.float32) # channel: 1
706
+ if spec_raw.shape[0] != 80:
707
+ print(f"corrupted mel: {spec_path}", flush=True)
708
+ spec_raw = np.zeros((80, 625), dtype=np.float32) # [C, T]
709
+
710
+ p = np.random.uniform(0, 1)
711
+ if p > self.drop:
712
+ try:
713
+ video_feat = np.load(video_feat_path)['feat'].astype(np.float32)
714
+ except:
715
+ print(f"corrupted video: {video_feat_path}", flush=True)
716
+ video_feat = np.load(os.path.join(os.path.dirname(video_feat_path), 'empty_vid.npz'))['feat'].astype(np.float32)
717
+ else:
718
+ video_feat = np.load(os.path.join(os.path.dirname(video_feat_path), 'empty_vid.npz'))['feat'].astype(np.float32)
719
+
720
+ spec_len = self.sr * self.duration / self.hop_len
721
+ if spec_raw.shape[1] < spec_len:
722
+ spec_raw = np.tile(spec_raw, math.ceil(spec_len / spec_raw.shape[1]))
723
+ spec_raw = spec_raw[:, :int(spec_len)]
724
+
725
+ feat_len = self.fps * self.duration
726
+ if video_feat.shape[0] < feat_len:
727
+ video_feat = np.tile(video_feat, (math.ceil(feat_len / video_feat.shape[0]), 1))
728
+ video_feat = video_feat[:int(feat_len)]
729
+ return spec_raw, video_feat
730
+
731
+ def mix_audio_and_feat(self, spec1=None, spec2=None, video_feat1=None, video_feat2=None, video_info_dict={},
732
+ mode='single'):
733
+ """ Return Mix Spec and Mix video feat"""
734
+ if mode == "single":
735
+ # spec1:
736
+ if not self.fix_frames:
737
+ start_idx = random.randint(0, self.sr * self.duration - self.truncate - 1) # audio start
738
+ else:
739
+ start_idx = 0
740
+
741
+ start_frame = int(self.fps * start_idx / self.sr)
742
+ truncate_frame = int(self.fps * self.truncate / self.sr)
743
+
744
+ # Spec Start & Truncate:
745
+ spec_start = int(start_idx / self.hop_len)
746
+ spec_truncate = int(self.truncate / self.hop_len)
747
+
748
+ spec1 = spec1[:, spec_start: spec_start + spec_truncate]
749
+ video_feat1 = video_feat1[start_frame: start_frame + truncate_frame]
750
+
751
+ # info_dict:
752
+ video_info_dict['video_time1'] = str(start_frame) + '_' + str(
753
+ start_frame + truncate_frame) # Start frame, end frame
754
+ video_info_dict['video_time2'] = ""
755
+ return spec1, video_feat1, video_info_dict
756
+
757
+ elif mode == "concat":
758
+ total_spec_len = int(self.truncate / self.hop_len)
759
+ # Random truncate length:
760
+ spec1_truncate_len = random.randint(self.min_duration * self.sr // self.hop_len,
761
+ total_spec_len - self.min_duration * self.sr // self.hop_len - 1)
762
+ spec2_truncate_len = total_spec_len - spec1_truncate_len
763
+
764
+ # Sample spec clip:
765
+ spec_start1 = random.randint(0, total_spec_len - spec1_truncate_len - 1)
766
+ spec_start2 = random.randint(0, total_spec_len - spec2_truncate_len - 1)
767
+ spec_end1, spec_end2 = spec_start1 + spec1_truncate_len, spec_start2 + spec2_truncate_len
768
+
769
+ # concat spec:
770
+ spec1, spec2 = spec1[:, spec_start1: spec_end1], spec2[:, spec_start2: spec_end2]
771
+ concat_audio_spec = np.concatenate([spec1, spec2], axis=1)
772
+
773
+ # Concat Video Feat:
774
+ start1_frame, truncate1_frame = int(self.fps * spec_start1 * self.hop_len / self.sr), int(
775
+ self.fps * spec1_truncate_len * self.hop_len / self.sr)
776
+ start2_frame, truncate2_frame = int(self.fps * spec_start2 * self.hop_len / self.sr), int(
777
+ self.fps * self.truncate / self.sr) - truncate1_frame
778
+ video_feat1, video_feat2 = video_feat1[start1_frame: start1_frame + truncate1_frame], video_feat2[
779
+ start2_frame: start2_frame + truncate2_frame]
780
+ concat_video_feat = np.concatenate([video_feat1, video_feat2])
781
+
782
+ video_info_dict['video_time1'] = str(start1_frame) + '_' + str(
783
+ start1_frame + truncate1_frame) # Start frame, end frame
784
+ video_info_dict['video_time2'] = str(start2_frame) + '_' + str(start2_frame + truncate2_frame)
785
+ return concat_audio_spec, concat_video_feat, video_info_dict
786
+
787
+ def __getitem__(self, idx):
788
+
789
+ audio_name1 = self.data_list[idx]
790
+ spec_npy_path1 = self.spec_list[idx]
791
+ video_feat_path1 = self.feat_list[idx]
792
+
793
+ # select other video:
794
+ flag = False
795
+ if random.uniform(0, 1) < -1: # concat branch is effectively disabled (this condition is never true)
796
+ flag = True
797
+ random_idx = idx
798
+ while random_idx == idx:
799
+ random_idx = random.randint(0, len(self.data_list) - 1)
800
+ audio_name2 = self.data_list[random_idx]
801
+ spec_npy_path2 = self.spec_list[random_idx]
802
+ video_feat_path2 = self.feat_list[random_idx]
803
+
804
+ # Load the Spec and Feat:
805
+ spec1, video_feat1 = self.load_spec_and_feat(spec_npy_path1, video_feat_path1)
806
+
807
+ if flag:
808
+ spec2, video_feat2 = self.load_spec_and_feat(spec_npy_path2, video_feat_path2)
809
+ video_info_dict = {'audio_name1': audio_name1, 'audio_name2': audio_name2}
810
+ mix_spec, mix_video_feat, mix_info = self.mix_audio_and_feat(spec1, spec2, video_feat1, video_feat2, video_info_dict, mode='concat')
811
+ else:
812
+ video_info_dict = {'audio_name1': audio_name1, 'audio_name2': ""}
813
+ mix_spec, mix_video_feat, mix_info = self.mix_audio_and_feat(spec1=spec1, video_feat1=video_feat1, video_info_dict=video_info_dict, mode='single')
814
+
815
+ # print("mix spec shape:", mix_spec.shape)
816
+ # print("mix video feat:", mix_video_feat.shape)
817
+ data_dict = {}
818
+ data_dict['mix_spec'] = mix_spec # (80, 512)
819
+ data_dict['mix_video_feat'] = mix_video_feat # (32, 512)
820
+ data_dict['mix_info_dict'] = mix_info
821
+
822
+ return data_dict
823
+
824
+
825
+ class audio_video_spec_fullset_Audioset_Train(audio_video_spec_fullset_Audioset_Dataset):
826
+ def __init__(self, dataset_cfg):
827
+ super().__init__(split='train', **dataset_cfg)
828
+
829
+
830
+ class audio_video_spec_fullset_Audioset_Valid(audio_video_spec_fullset_Audioset_Dataset):
831
+ def __init__(self, dataset_cfg):
832
+ super().__init__(split='valid', **dataset_cfg)
833
+
834
+
835
+ class audio_video_spec_fullset_Audioset_Test(audio_video_spec_fullset_Audioset_Dataset):
836
+ def __init__(self, dataset_cfg):
837
+ super().__init__(split='test', **dataset_cfg)
ldm/lr_scheduler.py ADDED
@@ -0,0 +1,98 @@
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9
+ self.lr_warm_up_steps = warm_up_steps
10
+ self.lr_start = lr_start
11
+ self.lr_min = lr_min
12
+ self.lr_max = lr_max
13
+ self.lr_max_decay_steps = max_decay_steps
14
+ self.last_lr = 0.
15
+ self.verbosity_interval = verbosity_interval
16
+
17
+ def schedule(self, n, **kwargs):
18
+ if self.verbosity_interval > 0:
19
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20
+ if n < self.lr_warm_up_steps:
21
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22
+ self.last_lr = lr
23
+ return lr
24
+ else:
25
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26
+ t = min(t, 1.0)
27
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28
+ 1 + np.cos(t * np.pi))
29
+ self.last_lr = lr
30
+ return lr
31
+
32
+ def __call__(self, n, **kwargs):
33
+ return self.schedule(n,**kwargs)
34
+
35
+
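A hedged usage sketch (not from this repo): these schedulers return learning-rate multipliers, so they are normally wrapped in torch.optim.lr_scheduler.LambdaLR with a base LR of 1.0, as the docstring notes. All numbers are illustrative.

```python
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters(), lr=1.0)   # base_lr of 1.0

sched_fn = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000, lr_min=0.01, lr_max=1.0, lr_start=0.0, max_decay_steps=100_000)
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=sched_fn)

for step in range(3):
    opt.step()          # normally preceded by a forward/backward pass
    scheduler.step()    # multiplies the base LR by sched_fn(current_step)
```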
36
+ class LambdaWarmUpCosineScheduler2:
37
+ """
38
+ supports repeated iterations, configurable via lists
39
+ note: use with a base_lr of 1.0.
40
+ """
41
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43
+ self.lr_warm_up_steps = warm_up_steps
44
+ self.f_start = f_start
45
+ self.f_min = f_min
46
+ self.f_max = f_max
47
+ self.cycle_lengths = cycle_lengths
48
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49
+ self.last_f = 0.
50
+ self.verbosity_interval = verbosity_interval
51
+
52
+ def find_in_interval(self, n):
53
+ interval = 0
54
+ for cl in self.cum_cycles[1:]:
55
+ if n <= cl:
56
+ return interval
57
+ interval += 1
58
+
59
+ def schedule(self, n, **kwargs):
60
+ cycle = self.find_in_interval(n)
61
+ n = n - self.cum_cycles[cycle]
62
+ if self.verbosity_interval > 0:
63
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64
+ f"current cycle {cycle}")
65
+ if n < self.lr_warm_up_steps[cycle]:
66
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67
+ self.last_f = f
68
+ return f
69
+ else:
70
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71
+ t = min(t, 1.0)
72
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73
+ 1 + np.cos(t * np.pi))
74
+ self.last_f = f
75
+ return f
76
+
77
+ def __call__(self, n, **kwargs):
78
+ return self.schedule(n, **kwargs)
79
+
80
+
81
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88
+ f"current cycle {cycle}")
89
+
90
+ if n < self.lr_warm_up_steps[cycle]:
91
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92
+ self.last_f = f
93
+ return f
94
+ else:
95
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96
+ self.last_f = f
97
+ return f
98
+
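To make the two decay shapes concrete, a quick numerical check with illustrative hyperparameters (not values from any config here): the cosine variant follows a half-cosine from f_max down to f_min after warm-up, while LambdaLinearScheduler decays linearly toward f_min over the full cycle length measured from step 0, so it sits slightly below f_max right after warm-up.

```python
cos_sched = LambdaWarmUpCosineScheduler2(
    warm_up_steps=[100], f_min=[0.1], f_max=[1.0], f_start=[0.0], cycle_lengths=[10000])
lin_sched = LambdaLinearScheduler(
    warm_up_steps=[100], f_min=[0.1], f_max=[1.0], f_start=[0.0], cycle_lengths=[10000])

for n in (100, 5050, 9999):
    print(n, round(cos_sched(n), 4), round(lin_sched(n), 4))
# n=100:  1.0   vs 0.991   (end of warm-up)
# n=5050: 0.55  vs 0.5455  (mid-cycle: the cosine stays higher for longer)
# n=9999: ~0.1  vs ~0.1    (both approach f_min)
```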
ldm/models/__pycache__/autoencoder.cpython-38.pyc ADDED
Binary file (15.5 kB). View file
 
ldm/models/__pycache__/autoencoder.cpython-39.pyc ADDED
Binary file (15.5 kB). View file
 
ldm/models/__pycache__/autoencoder1d.cpython-38.pyc ADDED
Binary file (13.4 kB). View file
 
ldm/models/autoencoder.py ADDED
@@ -0,0 +1,503 @@
1
+ import os
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ import torch.nn.functional as F
5
+ from contextlib import contextmanager
6
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
7
+ from packaging import version
8
+ import numpy as np
9
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
10
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
11
+ from torch.optim.lr_scheduler import LambdaLR
12
+ from ldm.util import instantiate_from_config
+ from ldm.modules.ema import LitEma # used when use_ema=True; import path assumed from the upstream latent-diffusion layout
13
+
14
+ class VQModel(pl.LightningModule):
15
+ def __init__(self,
16
+ ddconfig,
17
+ lossconfig,
18
+ n_embed,
19
+ embed_dim,
20
+ ckpt_path=None,
21
+ ignore_keys=[],
22
+ image_key="image",
23
+ colorize_nlabels=None,
24
+ monitor=None,
25
+ batch_resize_range=None,
26
+ scheduler_config=None,
27
+ lr_g_factor=1.0,
28
+ remap=None,
29
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
30
+ use_ema=False
31
+ ):
32
+ super().__init__()
33
+ self.embed_dim = embed_dim
34
+ self.n_embed = n_embed
35
+ self.image_key = image_key
36
+ self.encoder = Encoder(**ddconfig)
37
+ self.decoder = Decoder(**ddconfig)
38
+ self.loss = instantiate_from_config(lossconfig)
39
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
40
+ remap=remap,
41
+ sane_index_shape=sane_index_shape)
42
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
43
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
44
+ if colorize_nlabels is not None:
45
+ assert type(colorize_nlabels)==int
46
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
47
+ if monitor is not None:
48
+ self.monitor = monitor
49
+ self.batch_resize_range = batch_resize_range
50
+ if self.batch_resize_range is not None:
51
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
52
+
53
+ self.use_ema = use_ema
54
+ if self.use_ema:
55
+ self.model_ema = LitEma(self)
56
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
57
+
58
+ if ckpt_path is not None:
59
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
60
+ self.scheduler_config = scheduler_config
61
+ self.lr_g_factor = lr_g_factor
62
+
63
+ @contextmanager
64
+ def ema_scope(self, context=None):
65
+ if self.use_ema:
66
+ self.model_ema.store(self.parameters())
67
+ self.model_ema.copy_to(self)
68
+ if context is not None:
69
+ print(f"{context}: Switched to EMA weights")
70
+ try:
71
+ yield None
72
+ finally:
73
+ if self.use_ema:
74
+ self.model_ema.restore(self.parameters())
75
+ if context is not None:
76
+ print(f"{context}: Restored training weights")
77
+
78
+ def init_from_ckpt(self, path, ignore_keys=list()):
79
+ sd = torch.load(path, map_location="cpu")["state_dict"]
80
+ keys = list(sd.keys())
81
+ for k in keys:
82
+ for ik in ignore_keys:
83
+ if k.startswith(ik):
84
+ print("Deleting key {} from state_dict.".format(k))
85
+ del sd[k]
86
+ missing, unexpected = self.load_state_dict(sd, strict=False)
87
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
88
+ if len(missing) > 0:
89
+ print(f"Missing Keys: {missing}")
90
+ print(f"Unexpected Keys: {unexpected}")
91
+
92
+ def on_train_batch_end(self, *args, **kwargs):
93
+ if self.use_ema:
94
+ self.model_ema(self)
95
+
96
+ def encode(self, x):
97
+ h = self.encoder(x)
98
+ h = self.quant_conv(h)
99
+ quant, emb_loss, info = self.quantize(h)
100
+ return quant, emb_loss, info
101
+
102
+ def encode_to_prequant(self, x):
103
+ h = self.encoder(x)
104
+ h = self.quant_conv(h)
105
+ return h
106
+
107
+ def decode(self, quant):
108
+ quant = self.post_quant_conv(quant)
109
+ dec = self.decoder(quant)
110
+ return dec
111
+
112
+ def decode_code(self, code_b):
113
+ quant_b = self.quantize.embed_code(code_b)
114
+ dec = self.decode(quant_b)
115
+ return dec
116
+
117
+ def forward(self, input, return_pred_indices=False):
118
+ quant, diff, (_,_,ind) = self.encode(input)
119
+ dec = self.decode(quant)
120
+ if return_pred_indices:
121
+ return dec, diff, ind
122
+ return dec, diff
123
+
124
+ def get_input(self, batch, k):
125
+ x = batch[k]
126
+ if len(x.shape) == 3:
127
+ x = x[..., None]
128
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
129
+ if self.batch_resize_range is not None:
130
+ lower_size = self.batch_resize_range[0]
131
+ upper_size = self.batch_resize_range[1]
132
+ if self.global_step <= 4:
133
+ # do the first few batches with max size to avoid later oom
134
+ new_resize = upper_size
135
+ else:
136
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
137
+ if new_resize != x.shape[2]:
138
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
139
+ x = x.detach()
140
+ return x
141
+
142
+ def training_step(self, batch, batch_idx, optimizer_idx):
143
+ # https://github.com/pytorch/pytorch/issues/37142
144
+ # try not to fool the heuristics
145
+ x = self.get_input(batch, self.image_key)
146
+ xrec, qloss, ind = self(x, return_pred_indices=True)
147
+
148
+ if optimizer_idx == 0:
149
+ # autoencode
150
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151
+ last_layer=self.get_last_layer(), split="train",
152
+ predicted_indices=ind)
153
+
154
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
155
+ return aeloss
156
+
157
+ if optimizer_idx == 1:
158
+ # discriminator
159
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
160
+ last_layer=self.get_last_layer(), split="train")
161
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
162
+ return discloss
163
+
164
+ def validation_step(self, batch, batch_idx):
165
+ log_dict = self._validation_step(batch, batch_idx)
166
+ with self.ema_scope():
167
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
168
+ return log_dict
169
+
170
+ def _validation_step(self, batch, batch_idx, suffix=""):
171
+ x = self.get_input(batch, self.image_key)
172
+ xrec, qloss, ind = self(x, return_pred_indices=True)
173
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
174
+ self.global_step,
175
+ last_layer=self.get_last_layer(),
176
+ split="val"+suffix,
177
+ predicted_indices=ind
178
+ )
179
+
180
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
181
+ self.global_step,
182
+ last_layer=self.get_last_layer(),
183
+ split="val"+suffix,
184
+ predicted_indices=ind
185
+ )
186
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
187
+ self.log(f"val{suffix}/rec_loss", rec_loss,
188
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
189
+ self.log(f"val{suffix}/aeloss", aeloss,
190
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
191
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
192
+ del log_dict_ae[f"val{suffix}/rec_loss"]
193
+ self.log_dict(log_dict_ae)
194
+ self.log_dict(log_dict_disc)
195
+ return self.log_dict
196
+
197
+ def test_step(self, batch, batch_idx):
198
+ x = self.get_input(batch, self.image_key)
199
+ xrec, qloss, ind = self(x, return_pred_indices=True)
200
+ reconstructions = (xrec + 1)/2 # to mel scale
201
+ test_ckpt_path = os.path.basename(self.trainer.tested_ckpt_path)
202
+ savedir = os.path.join(self.trainer.log_dir,f'output_imgs_{test_ckpt_path}','fake_class')
203
+ if not os.path.exists(savedir):
204
+ os.makedirs(savedir)
205
+
206
+ file_names = batch['f_name']
207
+ # print(f"reconstructions.shape:{reconstructions.shape}",file_names)
208
+ reconstructions = reconstructions.cpu().numpy().squeeze(1) # squeeze channel dim
209
+ for b in range(reconstructions.shape[0]):
210
+ vname_num_split_index = file_names[b].rfind('_')# file_names[b]:video_name+'_'+num
211
+ v_n,num = file_names[b][:vname_num_split_index],file_names[b][vname_num_split_index+1:]
212
+ save_img_path = os.path.join(savedir,f'{v_n}_sample_{num}.npy')
213
+ np.save(save_img_path,reconstructions[b])
214
+
215
+ return None
216
+
217
+ def configure_optimizers(self):
218
+ lr_d = self.learning_rate
219
+ lr_g = self.lr_g_factor*self.learning_rate
220
+ print("lr_d", lr_d)
221
+ print("lr_g", lr_g)
222
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
223
+ list(self.decoder.parameters())+
224
+ list(self.quantize.parameters())+
225
+ list(self.quant_conv.parameters())+
226
+ list(self.post_quant_conv.parameters()),
227
+ lr=lr_g, betas=(0.5, 0.9))
228
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
229
+ lr=lr_d, betas=(0.5, 0.9))
230
+
231
+ if self.scheduler_config is not None:
232
+ scheduler = instantiate_from_config(self.scheduler_config)
233
+
234
+ print("Setting up LambdaLR scheduler...")
235
+ scheduler = [
236
+ {
237
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
238
+ 'interval': 'step',
239
+ 'frequency': 1
240
+ },
241
+ {
242
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
243
+ 'interval': 'step',
244
+ 'frequency': 1
245
+ },
246
+ ]
247
+ return [opt_ae, opt_disc], scheduler
248
+ return [opt_ae, opt_disc], []
249
+
250
+ def get_last_layer(self):
251
+ return self.decoder.conv_out.weight
252
+
253
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
254
+ log = dict()
255
+ x = self.get_input(batch, self.image_key)
256
+ x = x.to(self.device)
257
+ if only_inputs:
258
+ log["inputs"] = x
259
+ return log
260
+ xrec, _ = self(x)
261
+ if x.shape[1] > 3:
262
+ # colorize with random projection
263
+ assert xrec.shape[1] > 3
264
+ x = self.to_rgb(x)
265
+ xrec = self.to_rgb(xrec)
266
+ log["inputs"] = x
267
+ log["reconstructions"] = xrec
268
+ if plot_ema:
269
+ with self.ema_scope():
270
+ xrec_ema, _ = self(x)
271
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
272
+ log["reconstructions_ema"] = xrec_ema
273
+ return log
274
+
275
+ def to_rgb(self, x):
276
+ assert self.image_key == "segmentation"
277
+ if not hasattr(self, "colorize"):
278
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
279
+ x = F.conv2d(x, weight=self.colorize)
280
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
281
+ return x
282
+
283
+
284
+ class VQModelInterface(VQModel):
285
+ def __init__(self, embed_dim, *args, **kwargs):
286
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
287
+ self.embed_dim = embed_dim
288
+
289
+ def encode(self, x): # VQModel applies quantization inside encode(); VQModelInterface defers it to decode()
290
+ h = self.encoder(x)
291
+ h = self.quant_conv(h)
292
+ return h
293
+
294
+ def decode(self, h, force_not_quantize=False):
295
+ # also go through quantization layer
296
+ if not force_not_quantize:
297
+ quant, emb_loss, info = self.quantize(h)
298
+ else:
299
+ quant = h
300
+ quant = self.post_quant_conv(quant)
301
+ dec = self.decoder(quant)
302
+ return dec
303
+
304
+
305
+ class AutoencoderKL(pl.LightningModule):
306
+ def __init__(self,
307
+ ddconfig,
308
+ lossconfig,
309
+ embed_dim,
310
+ ckpt_path=None,
311
+ ignore_keys=[],
312
+ image_key="image",
313
+ colorize_nlabels=None,
314
+ monitor=None,
315
+ ):
316
+ super().__init__()
317
+ self.to_1d = False
318
+ print(f"to_1d is {self.to_1d} in AUTOENCODER")
319
+ self.image_key = image_key
320
+ self.encoder = Encoder(**ddconfig)
321
+ self.decoder = Decoder(**ddconfig)
322
+ self.loss = instantiate_from_config(lossconfig)
323
+ assert ddconfig["double_z"]
324
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
325
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
326
+ self.embed_dim = embed_dim
327
+ if colorize_nlabels is not None:
328
+ assert type(colorize_nlabels)==int
329
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
330
+ if monitor is not None:
331
+ self.monitor = monitor
332
+ if ckpt_path is not None:
333
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
334
+ # self.automatic_optimization = False # hjw for debug
335
+
336
+ def init_from_ckpt(self, path, ignore_keys=list()):
337
+ sd = torch.load(path, map_location="cpu")["state_dict"]
338
+ keys = list(sd.keys())
339
+ for k in keys:
340
+ for ik in ignore_keys:
341
+ if k.startswith(ik):
342
+ print("Deleting key {} from state_dict.".format(k))
343
+ del sd[k]
344
+ self.load_state_dict(sd, strict=False)
345
+ print(f"Restored from {path}")
346
+
347
+ def encode(self, x):
348
+ if self.to_1d and len(x.shape)==3:
349
+ x = x.unsqueeze(1)
350
+ h = self.encoder(x)
351
+ moments = self.quant_conv(h)
352
+ if self.to_1d:
353
+ b,c,h,w = moments.shape
354
+ moments = moments.reshape(b,c*h,w)
355
+ posterior = DiagonalGaussianDistribution(moments)
356
+ return posterior
357
+
358
+ def decode(self, z):
359
+ if self.to_1d:
360
+ b,c_h,w = z.shape
361
+ c = self.post_quant_conv.in_channels
362
+ z = z.reshape(b,c,-1,w)
363
+ z = self.post_quant_conv(z)
364
+ dec = self.decoder(z)
365
+ return dec
366
+
367
+ def forward(self, input, sample_posterior=True):
368
+ posterior = self.encode(input)
369
+ if sample_posterior:
370
+ z = posterior.sample()
371
+ else:
372
+ z = posterior.mode()
373
+ dec = self.decode(z)
374
+ return dec, posterior
375
+
376
+ def get_input(self, batch, k):
377
+ x = batch[k]
378
+ if len(x.shape) == 3:
379
+ x = x[..., None]
380
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
381
+ return x
382
+
383
+ def training_step(self, batch, batch_idx, optimizer_idx):
384
+ inputs = self.get_input(batch, self.image_key)
385
+ reconstructions, posterior = self(inputs)
386
+
387
+ if optimizer_idx == 0:
388
+ # train encoder+decoder+logvar
389
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
390
+ last_layer=self.get_last_layer(), split="train")
391
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
392
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
393
+ # print(optimizer_idx,log_dict_ae)
394
+ return aeloss
395
+
396
+ if optimizer_idx == 1:
397
+ # train the discriminator
398
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
399
+ last_layer=self.get_last_layer(), split="train")
400
+
401
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
402
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
403
+ # print(optimizer_idx,log_dict_disc)
404
+ return discloss
405
+
406
+ def validation_step(self, batch, batch_idx):
407
+ inputs = self.get_input(batch, self.image_key)
408
+ reconstructions, posterior = self(inputs)
409
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
410
+ last_layer=self.get_last_layer(), split="val")
411
+
412
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
413
+ last_layer=self.get_last_layer(), split="val")
414
+
415
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
416
+ self.log_dict(log_dict_ae)
417
+ self.log_dict(log_dict_disc)
418
+ return self.log_dict
419
+
420
+ def test_step(self, batch, batch_idx):
421
+ inputs = self.get_input(batch, self.image_key)# inputs shape:(b,mel_len,T)
422
+ reconstructions, posterior = self(inputs)# reconstructions:(b,mel_len,T)
423
+ mse_loss = torch.nn.functional.mse_loss(reconstructions,inputs)
424
+ self.log('test/mse_loss',mse_loss)
425
+
426
+ test_ckpt_path = os.path.basename(self.trainer.tested_ckpt_path)
427
+ savedir = os.path.join(self.trainer.log_dir,f'output_imgs_{test_ckpt_path}','fake_class')
428
+ if batch_idx == 0:
429
+ print(f"save_path is: {savedir}")
430
+ if not os.path.exists(savedir):
431
+ os.makedirs(savedir)
432
+ print(f"save_path is: {savedir}")
433
+
434
+ file_names = batch['f_name']
435
+ # print(f"reconstructions.shape:{reconstructions.shape}",file_names)
436
+ # reconstructions = (reconstructions + 1)/2 # to mel scale
437
+ reconstructions = reconstructions.cpu().numpy().squeeze(1) # squeeze channel dim
438
+ for b in range(reconstructions.shape[0]):
439
+ vname_num_split_index = file_names[b].rfind('_')# file_names[b]:video_name+'_'+num
440
+ v_n,num = file_names[b][:vname_num_split_index],file_names[b][vname_num_split_index+1:]
441
+ save_img_path = os.path.join(savedir, f'{v_n}.npy') # f'{v_n}_sample_{num}.npy' f'{v_n}.npy'
442
+ np.save(save_img_path,reconstructions[b])
443
+
444
+ return None
445
+
446
+ def configure_optimizers(self):
447
+ lr = self.learning_rate
448
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
449
+ list(self.decoder.parameters())+
450
+ list(self.quant_conv.parameters())+
451
+ list(self.post_quant_conv.parameters()),
452
+ lr=lr, betas=(0.5, 0.9))
453
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
454
+ lr=lr, betas=(0.5, 0.9))
455
+ return [opt_ae, opt_disc], []
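The two optimizers returned above drive the usual VAE-GAN alternation: `optimizer_idx == 0` updates the encoder/decoder (plus logvar) and `optimizer_idx == 1` updates the discriminator, with the actual loss terms supplied by the configured `lossconfig` module. A rough stand-alone sketch of that alternation with toy modules and toy losses (not the real loss terms):

```python
import torch
import torch.nn.functional as F

ae, disc = torch.nn.Linear(8, 8), torch.nn.Linear(8, 1)        # toy stand-ins for the autoencoder / discriminator
opt_ae = torch.optim.Adam(ae.parameters(), lr=4.5e-6, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(disc.parameters(), lr=4.5e-6, betas=(0.5, 0.9))

x = torch.randn(16, 8)
for step in range(2):
    # optimizer_idx == 0: update the autoencoder against the current discriminator
    aeloss = -disc(ae(x)).mean()                               # toy adversarial term only
    opt_ae.zero_grad(); aeloss.backward(); opt_ae.step()

    # optimizer_idx == 1: update the discriminator on real vs. reconstructed inputs
    discloss = (F.softplus(disc(ae(x).detach())).mean()
                + F.softplus(-disc(x)).mean())
    opt_disc.zero_grad(); discloss.backward(); opt_disc.step()
```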
456
+
457
+ def get_last_layer(self):
458
+ return self.decoder.conv_out.weight
459
+
460
+ @torch.no_grad()
461
+ def log_images(self, batch, only_inputs=False, save_dir='mel_result_ae13_26_debug/fake_class', **kwargs):  # called from on_validation_batch_end in main.py
462
+ log = dict()
463
+ x = self.get_input(batch, self.image_key)
464
+ x = x.to(self.device)
465
+ if not only_inputs:
466
+ xrec, posterior = self(x)
467
+ if x.shape[1] > 3:
468
+ # colorize with random projection
469
+ assert xrec.shape[1] > 3
470
+ x = self.to_rgb(x)
471
+ xrec = self.to_rgb(xrec)
472
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
473
+ log["reconstructions"] = xrec
474
+ log["inputs"] = x
475
+ return log
476
+
477
+ def to_rgb(self, x):
478
+ assert self.image_key == "segmentation"
479
+ if not hasattr(self, "colorize"):
480
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
481
+ x = F.conv2d(x, weight=self.colorize)
482
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
483
+ return x
484
+
485
+
486
+ class IdentityFirstStage(torch.nn.Module):
487
+ def __init__(self, *args, vq_interface=False, **kwargs):
488
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
489
+ super().__init__()
490
+
491
+ def encode(self, x, *args, **kwargs):
492
+ return x
493
+
494
+ def decode(self, x, *args, **kwargs):
495
+ return x
496
+
497
+ def quantize(self, x, *args, **kwargs):
498
+ if self.vq_interface:
499
+ return x, None, [None, None, None]
500
+ return x
501
+
502
+ def forward(self, x, *args, **kwargs):
503
+ return x
ldm/models/autoencoder1d.py ADDED
@@ -0,0 +1,517 @@
1
+ """
2
+ Difference from autoencoder.py: there the VAE maps (B,1,80,T) -> (B,C,80/8,T/8); here the VAE maps (B,80,T) -> (B,80/downsample_c,T/downsample_t).
3
+ """
4
+
5
+ import os
6
+ import torch
7
+ import torch.nn as nn
8
+ import pytorch_lightning as pl
9
+ import torch.nn.functional as F
10
+ from contextlib import contextmanager
11
+ from packaging import version
12
+ import numpy as np
13
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
14
+ from torch.optim.lr_scheduler import LambdaLR
15
+ from ldm.util import instantiate_from_config
16
+
17
+
18
+ class AutoencoderKL(pl.LightningModule):
19
+ def __init__(self,
20
+ embed_dim,
21
+ ddconfig,
22
+ lossconfig,
23
+ ckpt_path=None,
24
+ ignore_keys=[],
25
+ image_key="image",
26
+ monitor=None,
27
+ ):
28
+ super().__init__()
29
+ self.image_key = image_key
30
+ self.encoder = Encoder1D(**ddconfig)
31
+ self.decoder = Decoder1D(**ddconfig)
32
+ self.loss = instantiate_from_config(lossconfig)
33
+ assert ddconfig["double_z"]
34
+ self.quant_conv = torch.nn.Conv1d(2*ddconfig["z_channels"], 2*embed_dim, 1)
35
+ self.post_quant_conv = torch.nn.Conv1d(embed_dim, ddconfig["z_channels"], 1)
36
+ self.embed_dim = embed_dim
37
+ if monitor is not None:
38
+ self.monitor = monitor
39
+ if ckpt_path is not None:
40
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
41
+
42
+ def init_from_ckpt(self, path, ignore_keys=list()):
43
+ sd = torch.load(path, map_location="cpu")["state_dict"]
44
+ keys = list(sd.keys())
45
+ for k in keys:
46
+ for ik in ignore_keys:
47
+ if k.startswith(ik):
48
+ print("Deleting key {} from state_dict.".format(k))
49
+ del sd[k]
50
+ self.load_state_dict(sd, strict=False)
51
+ print(f"AutoencoderKL Restored from {path} Done")
52
+
53
+ def encode(self, x):
54
+ h = self.encoder(x)
55
+ moments = self.quant_conv(h)
56
+ posterior = DiagonalGaussianDistribution(moments)
57
+ return posterior
58
+
59
+ def decode(self, z):
60
+ z = self.post_quant_conv(z)
61
+ dec = self.decoder(z)
62
+ return dec
63
+
64
+ def forward(self, input, sample_posterior=True):
65
+ posterior = self.encode(input)
66
+ if sample_posterior:
67
+ z = posterior.sample()
68
+ else:
69
+ z = posterior.mode()
70
+ dec = self.decode(z)
71
+ return dec, posterior
72
+
73
+ def get_input(self, batch, k):
74
+ x = batch[k]
75
+ assert len(x.shape) == 3
76
+ x = x.to(memory_format=torch.contiguous_format).float()
77
+ return x
78
+
79
+ def training_step(self, batch, batch_idx, optimizer_idx):
80
+ inputs = self.get_input(batch, self.image_key)
81
+ # print(inputs.shape)
82
+ reconstructions, posterior = self(inputs)
83
+
84
+ if optimizer_idx == 0:
85
+ # train encoder+decoder+logvar
86
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
87
+ last_layer=self.get_last_layer(), split="train")
88
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
89
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
90
+ return aeloss
91
+
92
+ if optimizer_idx == 1:
93
+ # train the discriminator
94
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
95
+ last_layer=self.get_last_layer(), split="train")
96
+
97
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
98
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
99
+ return discloss
100
+
101
+ def validation_step(self, batch, batch_idx):
102
+ inputs = self.get_input(batch, self.image_key)
103
+ reconstructions, posterior = self(inputs)
104
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
105
+ last_layer=self.get_last_layer(), split="val")
106
+
107
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
108
+ last_layer=self.get_last_layer(), split="val")
109
+
110
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
111
+ self.log_dict(log_dict_ae)
112
+ self.log_dict(log_dict_disc)
113
+ return self.log_dict
114
+
115
+ def test_step(self, batch, batch_idx):
116
+ inputs = self.get_input(batch, self.image_key)# inputs shape:(b,mel_len,T)
117
+ reconstructions, posterior = self(inputs)# reconstructions:(b,mel_len,T)
118
+ mse_loss = torch.nn.functional.mse_loss(reconstructions,inputs)
119
+ self.log('test/mse_loss',mse_loss)
120
+
121
+ test_ckpt_path = os.path.basename(self.trainer.tested_ckpt_path)
122
+ savedir = os.path.join(self.trainer.log_dir,f'output_imgs_{test_ckpt_path}','fake_class')
123
+ if batch_idx == 0:
124
+ print(f"save_path is: {savedir}")
125
+ if not os.path.exists(savedir):
126
+ os.makedirs(savedir)
127
+ print(f"save_path is: {savedir}")
128
+
129
+ file_names = batch['f_name']
130
+ # print(f"reconstructions.shape:{reconstructions.shape}",file_names)
131
+ # reconstructions = (reconstructions + 1)/2 # to mel scale
132
+ reconstructions = reconstructions.cpu().numpy()  # (b, mel_len, T); the 1D model has no channel dim to squeeze
133
+ for b in range(reconstructions.shape[0]):
134
+ vname_num_split_index = file_names[b].rfind('_')# file_names[b]:video_name+'_'+num
135
+ v_n,num = file_names[b][:vname_num_split_index],file_names[b][vname_num_split_index+1:]
136
+ save_img_path = os.path.join(savedir, f'{v_n}.npy') # f'{v_n}_sample_{num}.npy' f'{v_n}.npy'
137
+ np.save(save_img_path,reconstructions[b])
138
+
139
+ return None
140
+
141
+ def configure_optimizers(self):
142
+ lr = self.learning_rate
143
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
144
+ list(self.decoder.parameters())+
145
+ list(self.quant_conv.parameters())+
146
+ list(self.post_quant_conv.parameters()),
147
+ lr=lr, betas=(0.5, 0.9))
148
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
149
+ lr=lr, betas=(0.5, 0.9))
150
+ return [opt_ae, opt_disc], []
151
+
152
+ def get_last_layer(self):
153
+ return self.decoder.conv_out.weight
154
+
155
+ @torch.no_grad()
156
+ def log_images(self, batch, only_inputs=False, **kwargs):
157
+ log = dict()
158
+ x = self.get_input(batch, self.image_key)
159
+ x = x.to(self.device)
160
+
161
+ if not only_inputs:
162
+ xrec, posterior = self(x)
163
+ log["samples"] = self.decode(torch.randn_like(posterior.sample())).unsqueeze(1) # (b,1,H,W)
164
+ log["reconstructions"] = xrec.unsqueeze(1)
165
+ log["inputs"] = x.unsqueeze(1)
166
+ return log
167
+
168
+
169
+ def Normalize(in_channels, num_groups=32):
170
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
171
+
172
+ def nonlinearity(x):
173
+ # swish
174
+ return x*torch.sigmoid(x)
175
+
176
+ class ResnetBlock1D(nn.Module):
177
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
178
+ dropout, temb_channels=512,kernel_size = 3):
179
+ super().__init__()
180
+ self.in_channels = in_channels
181
+ out_channels = in_channels if out_channels is None else out_channels
182
+ self.out_channels = out_channels
183
+ self.use_conv_shortcut = conv_shortcut
184
+
185
+ self.norm1 = Normalize(in_channels)
186
+ self.conv1 = torch.nn.Conv1d(in_channels,
187
+ out_channels,
188
+ kernel_size=kernel_size,
189
+ stride=1,
190
+ padding=kernel_size//2)
191
+ if temb_channels > 0:
192
+ self.temb_proj = torch.nn.Linear(temb_channels,
193
+ out_channels)
194
+ self.norm2 = Normalize(out_channels)
195
+ self.dropout = torch.nn.Dropout(dropout)
196
+ self.conv2 = torch.nn.Conv1d(out_channels,
197
+ out_channels,
198
+ kernel_size=kernel_size,
199
+ stride=1,
200
+ padding=kernel_size//2)
201
+ if self.in_channels != self.out_channels:
202
+ if self.use_conv_shortcut:
203
+ self.conv_shortcut = torch.nn.Conv1d(in_channels,
204
+ out_channels,
205
+ kernel_size=kernel_size,
206
+ stride=1,
207
+ padding=kernel_size//2)
208
+ else:
209
+ self.nin_shortcut = torch.nn.Conv1d(in_channels,
210
+ out_channels,
211
+ kernel_size=1,
212
+ stride=1,
213
+ padding=0)
214
+
215
+ def forward(self, x, temb):
216
+ h = x
217
+ h = self.norm1(h)
218
+ h = nonlinearity(h)
219
+ h = self.conv1(h)
220
+
221
+ if temb is not None:
222
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None]  # 1D feature maps: broadcast over the time axis only
223
+
224
+ h = self.norm2(h)
225
+ h = nonlinearity(h)
226
+ h = self.dropout(h)
227
+ h = self.conv2(h)
228
+
229
+ if self.in_channels != self.out_channels:
230
+ if self.use_conv_shortcut:
231
+ x = self.conv_shortcut(x)
232
+ else:
233
+ x = self.nin_shortcut(x)
234
+
235
+ return x+h
236
+
237
+ class AttnBlock1D(nn.Module):
238
+ def __init__(self, in_channels):
239
+ super().__init__()
240
+ self.in_channels = in_channels
241
+
242
+ self.norm = Normalize(in_channels)
243
+ self.q = torch.nn.Conv1d(in_channels,
244
+ in_channels,
245
+ kernel_size=1)
246
+ self.k = torch.nn.Conv1d(in_channels,
247
+ in_channels,
248
+ kernel_size=1)
249
+ self.v = torch.nn.Conv1d(in_channels,
250
+ in_channels,
251
+ kernel_size=1)
252
+ self.proj_out = torch.nn.Conv1d(in_channels,
253
+ in_channels,
254
+ kernel_size=1)
255
+
256
+
257
+ def forward(self, x):
258
+ h_ = x
259
+ h_ = self.norm(h_)
260
+ q = self.q(h_)
261
+ k = self.k(h_)
262
+ v = self.v(h_)
263
+
264
+ # compute attention
265
+ b, c, t = q.shape  # Conv1d output is (batch, channels, time)
266
+ q = q.permute(0,2,1) # b,t,c
267
+ w_ = torch.bmm(q,k) # b,t,t w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
268
+ # if still 2d attn (q:b,hw,c ,k:b,c,hw -> w_:b,hw,hw)
269
+ w_ = w_ * (int(c)**(-0.5))  # scale by 1/sqrt(channels)
270
+ w_ = torch.nn.functional.softmax(w_, dim=2)
271
+
272
+ # attend to values
273
+ w_ = w_.permute(0,2,1) # b,t,t (first t of k, second of q)
274
+ h_ = torch.bmm(v,w_) # b,c,t (t of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
275
+
276
+ h_ = self.proj_out(h_)
277
+
278
+ return x+h_
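As a sanity check on the attention comments above, the bmm formulation used in `AttnBlock1D.forward` matches a plain einsum formulation of single-head self-attention over the time axis. A small sketch with toy shapes (not part of the diff):

```python
import torch

B, C, T = 2, 32, 50
q, k, v = (torch.randn(B, C, T) for _ in range(3))

# bmm path, as in AttnBlock1D.forward: w[b,i,j] = sum_c q[b,c,i] * k[b,c,j]
w = torch.softmax(torch.bmm(q.permute(0, 2, 1), k) * C ** -0.5, dim=2)   # (B, T, T)
out_bmm = torch.bmm(v, w.permute(0, 2, 1))                               # (B, C, T)

# equivalent einsum formulation of single-head self-attention over time
w2 = torch.softmax(torch.einsum('bci,bcj->bij', q, k) * C ** -0.5, dim=2)
out_einsum = torch.einsum('bcj,bij->bci', v, w2)

print(torch.allclose(out_bmm, out_einsum, atol=1e-5))   # True
```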
279
+
280
+ class Upsample1D(nn.Module):
281
+ def __init__(self, in_channels, with_conv):
282
+ super().__init__()
283
+ self.with_conv = with_conv
284
+ if self.with_conv:
285
+ self.conv = torch.nn.Conv1d(in_channels,
286
+ in_channels,
287
+ kernel_size=3,
288
+ stride=1,
289
+ padding=1)
290
+
291
+ def forward(self, x):
292
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") # support 3D tensor(B,C,T)
293
+ if self.with_conv:
294
+ x = self.conv(x)
295
+ return x
296
+
297
+
298
+ class Downsample1D(nn.Module):
299
+ def __init__(self, in_channels, with_conv):
300
+ super().__init__()
301
+ self.with_conv = with_conv
302
+ if self.with_conv:
303
+ # no asymmetric padding in torch conv, must do it ourselves
304
+ self.conv = torch.nn.Conv1d(in_channels,
305
+ in_channels,
306
+ kernel_size=3,
307
+ stride=2,
308
+ padding=0)
309
+
310
+ def forward(self, x):
311
+ if self.with_conv:
312
+ pad = (0,1)
313
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
314
+ x = self.conv(x)
315
+ else:
316
+ x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
317
+ return x
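The padding trick above can be seen in isolation: pad one frame on the right, then a stride-2, kernel-3 convolution with no built-in padding halves the temporal length. A minimal sketch with toy shapes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 16, 101)                       # (B, C, T) with odd T
x_pad = F.pad(x, (0, 1))                          # pad one frame on the right only
conv = torch.nn.Conv1d(16, 16, kernel_size=3, stride=2, padding=0)
print(conv(x_pad).shape)                          # torch.Size([1, 16, 50]) -> T goes 101 -> 50
```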
318
+
319
+ class Encoder1D(nn.Module):
320
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
321
+ attn_layers = [],down_layers = [], dropout=0.0, resamp_with_conv=True, in_channels,
322
+ z_channels, double_z=True,kernel_size=3, **ignore_kwargs):
323
+ """ out_ch is only used in decoder,not used here
324
+ """
325
+ super().__init__()
326
+ self.ch = ch
327
+ self.temb_ch = 0
328
+ self.num_layers = len(ch_mult)
329
+ self.num_res_blocks = num_res_blocks
330
+ self.in_channels = in_channels
331
+ print(f"downsample rates is {2**len(down_layers)}")
332
+ self.down_layers = down_layers
333
+ self.attn_layers = attn_layers
334
+ self.conv_in = torch.nn.Conv1d(in_channels,
335
+ self.ch,
336
+ kernel_size=kernel_size,
337
+ stride=1,
338
+ padding=kernel_size//2)
339
+
340
+ in_ch_mult = (1,)+tuple(ch_mult)
341
+ self.in_ch_mult = in_ch_mult
342
+ # downsampling
343
+ self.down = nn.ModuleList()
344
+ for i_level in range(self.num_layers):
345
+ block = nn.ModuleList()
346
+ attn = nn.ModuleList()
347
+ block_in = ch*in_ch_mult[i_level]
348
+ block_out = ch*ch_mult[i_level]
349
+ for i_block in range(self.num_res_blocks):
350
+ block.append(ResnetBlock1D(in_channels=block_in,
351
+ out_channels=block_out,
352
+ temb_channels=self.temb_ch,
353
+ dropout=dropout,
354
+ kernel_size=kernel_size))
355
+ block_in = block_out
356
+ if i_level in attn_layers:
357
+ # print(f"add attn in layer:{i_level}")
358
+ attn.append(AttnBlock1D(block_in))
359
+ down = nn.Module()
360
+ down.block = block
361
+ down.attn = attn
362
+ if i_level in down_layers:
363
+ down.downsample = Downsample1D(block_in, resamp_with_conv)
364
+ self.down.append(down)
365
+
366
+ # middle
367
+ self.mid = nn.Module()
368
+ self.mid.block_1 = ResnetBlock1D(in_channels=block_in,
369
+ out_channels=block_in,
370
+ temb_channels=self.temb_ch,
371
+ dropout=dropout,
372
+ kernel_size=kernel_size)
373
+ self.mid.attn_1 = AttnBlock1D(block_in)
374
+ self.mid.block_2 = ResnetBlock1D(in_channels=block_in,
375
+ out_channels=block_in,
376
+ temb_channels=self.temb_ch,
377
+ dropout=dropout,
378
+ kernel_size=kernel_size)
379
+
380
+ # end
381
+ self.norm_out = Normalize(block_in)# GroupNorm
382
+ self.conv_out = torch.nn.Conv1d(block_in,
383
+ 2*z_channels if double_z else z_channels,
384
+ kernel_size=kernel_size,
385
+ stride=1,
386
+ padding=kernel_size//2)
387
+
388
+ def forward(self, x):
389
+ # timestep embedding
390
+ temb = None
391
+
392
+ # downsampling
393
+ hs = [self.conv_in(x)]
394
+ for i_level in range(self.num_layers):
395
+ for i_block in range(self.num_res_blocks):
396
+ h = self.down[i_level].block[i_block](hs[-1], temb)
397
+ if len(self.down[i_level].attn) > 0:
398
+ h = self.down[i_level].attn[i_block](h)
399
+ hs.append(h)
400
+ if i_level in self.down_layers:
401
+ hs.append(self.down[i_level].downsample(hs[-1]))
402
+
403
+ # middle
404
+ h = hs[-1]
405
+ h = self.mid.block_1(h, temb)
406
+ h = self.mid.attn_1(h)
407
+ h = self.mid.block_2(h, temb)
408
+
409
+ # end
410
+ h = self.norm_out(h)
411
+ h = nonlinearity(h)
412
+ h = self.conv_out(h)
413
+ return h
414
+
415
+ class Decoder1D(nn.Module):
416
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
417
+ attn_layers = [],down_layers = [], dropout=0.0,kernel_size=3, resamp_with_conv=True, in_channels,
418
+ z_channels, give_pre_end=False, tanh_out=False, **ignorekwargs):
419
+ super().__init__()
420
+ self.ch = ch
421
+ self.temb_ch = 0
422
+ self.num_layers = len(ch_mult)
423
+ self.num_res_blocks = num_res_blocks
424
+ self.in_channels = in_channels
425
+ self.give_pre_end = give_pre_end
426
+ self.tanh_out = tanh_out
427
+ self.down_layers = [i+1 for i in down_layers] # shift each downsampling layer index by one for the decoder
428
+ print(f"upsample rates is {2**len(down_layers)}")
429
+
430
+ # compute in_ch_mult, block_in and curr_res at lowest res
431
+ in_ch_mult = (1,)+tuple(ch_mult)
432
+ block_in = ch*ch_mult[self.num_layers-1]
433
+
434
+
435
+ # z to block_in
436
+ self.conv_in = torch.nn.Conv1d(z_channels,
437
+ block_in,
438
+ kernel_size=kernel_size,
439
+ stride=1,
440
+ padding=kernel_size//2)
441
+
442
+ # middle
443
+ self.mid = nn.Module()
444
+ self.mid.block_1 = ResnetBlock1D(in_channels=block_in,
445
+ out_channels=block_in,
446
+ temb_channels=self.temb_ch,
447
+ dropout=dropout)
448
+ self.mid.attn_1 = AttnBlock1D(block_in)
449
+ self.mid.block_2 = ResnetBlock1D(in_channels=block_in,
450
+ out_channels=block_in,
451
+ temb_channels=self.temb_ch,
452
+ dropout=dropout)
453
+
454
+ # upsampling
455
+ self.up = nn.ModuleList()
456
+ for i_level in reversed(range(self.num_layers)):
457
+ block = nn.ModuleList()
458
+ attn = nn.ModuleList()
459
+ block_out = ch*ch_mult[i_level]
460
+ for i_block in range(self.num_res_blocks+1):
461
+ block.append(ResnetBlock1D(in_channels=block_in,
462
+ out_channels=block_out,
463
+ temb_channels=self.temb_ch,
464
+ dropout=dropout))
465
+ block_in = block_out
466
+ if i_level in attn_layers:
467
+ # print(f"add attn in layer:{i_level}")
468
+ attn.append(AttnBlock1D(block_in))
469
+ up = nn.Module()
470
+ up.block = block
471
+ up.attn = attn
472
+ if i_level in self.down_layers:
473
+ up.upsample = Upsample1D(block_in, resamp_with_conv)
474
+ self.up.insert(0, up) # prepend to get consistent order
475
+
476
+ # end
477
+ self.norm_out = Normalize(block_in)
478
+ self.conv_out = torch.nn.Conv1d(block_in,
479
+ out_ch,
480
+ kernel_size=kernel_size,
481
+ stride=1,
482
+ padding=kernel_size//2)
483
+
484
+ def forward(self, z):
485
+ #assert z.shape[1:] == self.z_shape[1:]
486
+ self.last_z_shape = z.shape
487
+
488
+ # timestep embedding
489
+ temb = None
490
+
491
+ # z to block_in
492
+ h = self.conv_in(z)
493
+
494
+ # middle
495
+ h = self.mid.block_1(h, temb)
496
+ h = self.mid.attn_1(h)
497
+ h = self.mid.block_2(h, temb)
498
+
499
+ # upsampling
500
+ for i_level in reversed(range(self.num_layers)):
501
+ for i_block in range(self.num_res_blocks+1):
502
+ h = self.up[i_level].block[i_block](h, temb)
503
+ if len(self.up[i_level].attn) > 0:
504
+ h = self.up[i_level].attn[i_block](h)
505
+ if i_level in self.down_layers:
506
+ h = self.up[i_level].upsample(h)
507
+
508
+ # end
509
+ if self.give_pre_end:
510
+ return h
511
+
512
+ h = self.norm_out(h)
513
+ h = nonlinearity(h)
514
+ h = self.conv_out(h)
515
+ if self.tanh_out:
516
+ h = torch.tanh(h)
517
+ return h
ldm/models/diffusion/__init__.py ADDED
File without changes
ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (177 Bytes). View file
 
ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (177 Bytes). View file
 
ldm/models/diffusion/__pycache__/cfm1_audio.cpython-38.pyc ADDED
Binary file (11 kB). View file
 
ldm/models/diffusion/__pycache__/cfm1_audio.cpython-39.pyc ADDED
Binary file (11 kB). View file
 
ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc ADDED
Binary file (7.62 kB). View file
 
ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc ADDED
Binary file (7.56 kB). View file
 
ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc ADDED
Binary file (44.4 kB). View file
 
ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc ADDED
Binary file (44.3 kB). View file
 
ldm/models/diffusion/__pycache__/ddpm_audio.cpython-38.pyc ADDED
Binary file (25.9 kB). View file
 
ldm/models/diffusion/__pycache__/ddpm_audio.cpython-39.pyc ADDED
Binary file (25.9 kB). View file
 
ldm/models/diffusion/__pycache__/plms.cpython-38.pyc ADDED
Binary file (7.38 kB). View file
 
ldm/models/diffusion/__pycache__/plms.cpython-39.pyc ADDED
Binary file (7.31 kB). View file
 
ldm/models/diffusion/audioldm.py ADDED
@@ -0,0 +1,818 @@
1
+ import os
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from audioldm.utils import default, instantiate_from_config, save_wave
7
+ from audioldm.latent_diffusion.ddpm import DDPM
8
+ from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution
9
+ from audioldm.latent_diffusion.util import noise_like
10
+ from audioldm.latent_diffusion.ddim import DDIMSampler
11
+ import os
12
+
13
+
14
+ def disabled_train(self, mode=True):
15
+ """Overwrite model.train with this function to make sure train/eval mode
16
+ does not change anymore."""
17
+ return self
18
+
19
+
20
+ class LatentDiffusion(DDPM):
21
+ """main class"""
22
+
23
+ def __init__(
24
+ self,
25
+ device="cuda",
26
+ first_stage_config=None,
27
+ cond_stage_config=None,
28
+ num_timesteps_cond=None,
29
+ cond_stage_key="image",
30
+ cond_stage_trainable=False,
31
+ concat_mode=True,
32
+ cond_stage_forward=None,
33
+ conditioning_key=None,
34
+ scale_factor=1.0,
35
+ scale_by_std=False,
36
+ base_learning_rate=None,
37
+ *args,
38
+ **kwargs,
39
+ ):
40
+ self.device = device
41
+ self.learning_rate = base_learning_rate
42
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
43
+ self.scale_by_std = scale_by_std
44
+ assert self.num_timesteps_cond <= kwargs["timesteps"]
45
+ # for backwards compatibility after implementation of DiffusionWrapper
46
+ if conditioning_key is None:
47
+ conditioning_key = "concat" if concat_mode else "crossattn"
48
+ if cond_stage_config == "__is_unconditional__":
49
+ conditioning_key = None
50
+ ckpt_path = kwargs.pop("ckpt_path", None)
51
+ ignore_keys = kwargs.pop("ignore_keys", [])
52
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
53
+ self.concat_mode = concat_mode
54
+ self.cond_stage_trainable = cond_stage_trainable
55
+ self.cond_stage_key = cond_stage_key
56
+ self.cond_stage_key_orig = cond_stage_key
57
+ try:
58
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
59
+ except:
60
+ self.num_downs = 0
61
+ if not scale_by_std:
62
+ self.scale_factor = scale_factor
63
+ else:
64
+ self.register_buffer("scale_factor", torch.tensor(scale_factor))
65
+ self.instantiate_first_stage(first_stage_config)
66
+ self.instantiate_cond_stage(cond_stage_config)
67
+ self.cond_stage_forward = cond_stage_forward
68
+ self.clip_denoised = False
69
+
70
+ def make_cond_schedule(
71
+ self,
72
+ ):
73
+ self.cond_ids = torch.full(
74
+ size=(self.num_timesteps,),
75
+ fill_value=self.num_timesteps - 1,
76
+ dtype=torch.long,
77
+ )
78
+ ids = torch.round(
79
+ torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
80
+ ).long()
81
+ self.cond_ids[: self.num_timesteps_cond] = ids
82
+
83
+ def register_schedule(
84
+ self,
85
+ given_betas=None,
86
+ beta_schedule="linear",
87
+ timesteps=1000,
88
+ linear_start=1e-4,
89
+ linear_end=2e-2,
90
+ cosine_s=8e-3,
91
+ ):
92
+ super().register_schedule(
93
+ given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
94
+ )
95
+
96
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
97
+ if self.shorten_cond_schedule:
98
+ self.make_cond_schedule()
99
+
100
+ def instantiate_first_stage(self, config):
101
+ model = instantiate_from_config(config)
102
+ self.first_stage_model = model.eval()
103
+ self.first_stage_model.train = disabled_train
104
+ for param in self.first_stage_model.parameters():
105
+ param.requires_grad = False
106
+
107
+ def instantiate_cond_stage(self, config):
108
+ if not self.cond_stage_trainable:
109
+ if config == "__is_first_stage__":
110
+ print("Using first stage also as cond stage.")
111
+ self.cond_stage_model = self.first_stage_model
112
+ elif config == "__is_unconditional__":
113
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
114
+ self.cond_stage_model = None
115
+ # self.be_unconditional = True
116
+ else:
117
+ model = instantiate_from_config(config)
118
+ self.cond_stage_model = model.eval()
119
+ self.cond_stage_model.train = disabled_train
120
+ for param in self.cond_stage_model.parameters():
121
+ param.requires_grad = False
122
+ else:
123
+ assert config != "__is_first_stage__"
124
+ assert config != "__is_unconditional__"
125
+ model = instantiate_from_config(config)
126
+ self.cond_stage_model = model
127
+ self.cond_stage_model = self.cond_stage_model.to(self.device)
128
+
129
+ def get_first_stage_encoding(self, encoder_posterior):
130
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
131
+ z = encoder_posterior.sample()
132
+ elif isinstance(encoder_posterior, torch.Tensor):
133
+ z = encoder_posterior
134
+ else:
135
+ raise NotImplementedError(
136
+ f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
137
+ )
138
+ return self.scale_factor * z
139
+
140
+ def get_learned_conditioning(self, c):
141
+ if self.cond_stage_forward is None:
142
+ if hasattr(self.cond_stage_model, "encode") and callable(
143
+ self.cond_stage_model.encode
144
+ ):
145
+ c = self.cond_stage_model.encode(c)
146
+ if isinstance(c, DiagonalGaussianDistribution):
147
+ c = c.mode()
148
+ else:
149
+ # Text input is list
150
+ if type(c) == list and len(c) == 1:
151
+ c = self.cond_stage_model([c[0], c[0]])
152
+ c = c[0:1] # (2,1,512) -> (1,1,512)
153
+ else:
154
+ c = self.cond_stage_model(c)
155
+ else:
156
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
157
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
158
+ return c
159
+
160
+ @torch.no_grad()
161
+ def get_input(
162
+ self,
163
+ batch,
164
+ k,
165
+ return_first_stage_encode=True,
166
+ return_first_stage_outputs=False,
167
+ force_c_encode=False,
168
+ cond_key=None,
169
+ return_original_cond=False,
170
+ bs=None,
171
+ ):
172
+ x = super().get_input(batch, k)# shape(b,1,T=1024,melbins=64)
173
+
174
+ if bs is not None:
175
+ x = x[:bs]
176
+
177
+ x = x.to(self.device)
178
+
179
+ if return_first_stage_encode:
180
+ encoder_posterior = self.encode_first_stage(x)
181
+ z = self.get_first_stage_encoding(encoder_posterior).detach()# z:(b,8,256,16); time and freq are each downsampled 4x while channels grow to 8, so there is hardly any overall compression
182
+ else:
183
+ z = None
184
+
185
+ if self.model.conditioning_key is not None:
186
+ if cond_key is None:
187
+ cond_key = self.cond_stage_key
188
+ if cond_key != self.first_stage_key:
189
+ if cond_key in ["caption", "coordinates_bbox"]:
190
+ xc = batch[cond_key]
191
+ elif cond_key == "class_label":
192
+ xc = batch
193
+ else:
194
+ # [bs, 1, 527]
195
+ xc = super().get_input(batch, cond_key)
196
+ if type(xc) == torch.Tensor:
197
+ xc = xc.to(self.device)
198
+ else:
199
+ xc = x
200
+ if not self.cond_stage_trainable or force_c_encode:
201
+ if isinstance(xc, dict) or isinstance(xc, list):
202
+ c = self.get_learned_conditioning(xc)
203
+ else:
204
+ c = self.get_learned_conditioning(xc.to(self.device))
205
+ else:
206
+ c = xc
207
+
208
+ if bs is not None:
209
+ c = c[:bs]
210
+
211
+ else:
212
+ c = None
213
+ xc = None
214
+ if self.use_positional_encodings:
215
+ pos_x, pos_y = self.compute_latent_shifts(batch)
216
+ c = {"pos_x": pos_x, "pos_y": pos_y}
217
+ out = [z, c]# z:(b,8,256,16)
218
+ if return_first_stage_outputs:
219
+ xrec = self.decode_first_stage(z)
220
+ out.extend([x, xrec])
221
+ if return_original_cond:
222
+ out.append(xc)
223
+ return out
224
+
225
+ @torch.no_grad()
226
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
227
+ if predict_cids:
228
+ if z.dim() == 4:
229
+ z = torch.argmax(z.exp(), dim=1).long()
230
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
231
+ z = rearrange(z, "b h w c -> b c h w").contiguous()
232
+
233
+ z = 1.0 / self.scale_factor * z
234
+ return self.first_stage_model.decode(z)
235
+
236
+ def mel_spectrogram_to_waveform(self, mel):
237
+ # Mel: [bs, 1, t-steps, fbins]
238
+ if len(mel.size()) == 4:
239
+ mel = mel.squeeze(1)
240
+ mel = mel.permute(0, 2, 1)
241
+ waveform = self.first_stage_model.vocoder(mel)
242
+ waveform = waveform.cpu().detach().numpy()
243
+ return waveform
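The reshaping above just moves the mel tensor into the (batch, n_mels, frames) layout that HiFi-GAN-style vocoders typically expect; whether the bundled vocoder uses exactly this layout is an assumption here. A toy shape check:

```python
import torch

mel = torch.randn(2, 1, 1024, 64)          # [bs, 1, t-steps, fbins], as in the comment above
mel = mel.squeeze(1).permute(0, 2, 1)      # -> [bs, fbins, t-steps]
print(mel.shape)                            # torch.Size([2, 64, 1024])
```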
244
+
245
+ @torch.no_grad()
246
+ def encode_first_stage(self, x):
247
+ return self.first_stage_model.encode(x)
248
+
249
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
250
+
251
+ if isinstance(cond, dict):
252
+ # hybrid case, cond is expected to be a dict
253
+ pass
254
+ else:
255
+ if not isinstance(cond, list):
256
+ cond = [cond]
257
+ if self.model.conditioning_key == "concat":
258
+ key = "c_concat"
259
+ elif self.model.conditioning_key == "crossattn":
260
+ key = "c_crossattn"
261
+ else:
262
+ key = "c_film"
263
+
264
+ cond = {key: cond}
265
+
266
+ x_recon = self.model(x_noisy, t, **cond)
267
+
268
+ if isinstance(x_recon, tuple) and not return_ids:
269
+ return x_recon[0]
270
+ else:
271
+ return x_recon
272
+
273
+ def p_mean_variance(
274
+ self,
275
+ x,
276
+ c,
277
+ t,
278
+ clip_denoised: bool,
279
+ return_codebook_ids=False,
280
+ quantize_denoised=False,
281
+ return_x0=False,
282
+ score_corrector=None,
283
+ corrector_kwargs=None,
284
+ ):
285
+ t_in = t
286
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
287
+
288
+ if score_corrector is not None:
289
+ assert self.parameterization == "eps"
290
+ model_out = score_corrector.modify_score(
291
+ self, model_out, x, t, c, **corrector_kwargs
292
+ )
293
+
294
+ if return_codebook_ids:
295
+ model_out, logits = model_out
296
+
297
+ if self.parameterization == "eps":
298
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
299
+ elif self.parameterization == "x0":
300
+ x_recon = model_out
301
+ else:
302
+ raise NotImplementedError()
303
+
304
+ if clip_denoised:
305
+ x_recon.clamp_(-1.0, 1.0)
306
+ if quantize_denoised:
307
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
308
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
309
+ x_start=x_recon, x_t=x, t=t
310
+ )
311
+ if return_codebook_ids:
312
+ return model_mean, posterior_variance, posterior_log_variance, logits
313
+ elif return_x0:
314
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
315
+ else:
316
+ return model_mean, posterior_variance, posterior_log_variance
317
+
318
+ @torch.no_grad()
319
+ def p_sample(
320
+ self,
321
+ x,
322
+ c,
323
+ t,
324
+ clip_denoised=False,
325
+ repeat_noise=False,
326
+ return_codebook_ids=False,
327
+ quantize_denoised=False,
328
+ return_x0=False,
329
+ temperature=1.0,
330
+ noise_dropout=0.0,
331
+ score_corrector=None,
332
+ corrector_kwargs=None,
333
+ ):
334
+ b, *_, device = *x.shape, x.device
335
+ outputs = self.p_mean_variance(
336
+ x=x,
337
+ c=c,
338
+ t=t,
339
+ clip_denoised=clip_denoised,
340
+ return_codebook_ids=return_codebook_ids,
341
+ quantize_denoised=quantize_denoised,
342
+ return_x0=return_x0,
343
+ score_corrector=score_corrector,
344
+ corrector_kwargs=corrector_kwargs,
345
+ )
346
+ if return_codebook_ids:
347
+ raise DeprecationWarning("Support dropped.")
348
+ model_mean, _, model_log_variance, logits = outputs
349
+ elif return_x0:
350
+ model_mean, _, model_log_variance, x0 = outputs
351
+ else:
352
+ model_mean, _, model_log_variance = outputs
353
+
354
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
355
+ if noise_dropout > 0.0:
356
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
357
+ # no noise when t == 0
358
+ nonzero_mask = (
359
+ (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous()
360
+ )
361
+
362
+ if return_codebook_ids:
363
+ return model_mean + nonzero_mask * (
364
+ 0.5 * model_log_variance
365
+ ).exp() * noise, logits.argmax(dim=1)
366
+ if return_x0:
367
+ return (
368
+ model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
369
+ x0,
370
+ )
371
+ else:
372
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
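The return value above is the standard ancestral-sampling update, roughly x_{t-1} = mean + exp(0.5 * logvar) * noise, with the mask zeroing the noise at the final step. A toy sketch of that masking (values are arbitrary):

```python
import torch

b = 4
t = torch.tensor([3, 2, 1, 0])                     # the last sample is at the final step t == 0
model_mean = torch.zeros(b, 1, 8, 8)
model_log_variance = torch.full((b, 1, 1, 1), -2.0)
noise = torch.randn(b, 1, 8, 8)

# same masking as above: no noise is added when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, 1, 1, 1)
x_prev = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

print(x_prev[3].abs().max())                        # tensor(0.) -- the t == 0 update is deterministic
```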
373
+
374
+ @torch.no_grad()
375
+ def progressive_denoising(
376
+ self,
377
+ cond,
378
+ shape,
379
+ verbose=True,
380
+ callback=None,
381
+ quantize_denoised=False,
382
+ img_callback=None,
383
+ mask=None,
384
+ x0=None,
385
+ temperature=1.0,
386
+ noise_dropout=0.0,
387
+ score_corrector=None,
388
+ corrector_kwargs=None,
389
+ batch_size=None,
390
+ x_T=None,
391
+ start_T=None,
392
+ log_every_t=None,
393
+ ):
394
+ if not log_every_t:
395
+ log_every_t = self.log_every_t
396
+ timesteps = self.num_timesteps
397
+ if batch_size is not None:
398
+ b = batch_size if batch_size is not None else shape[0]
399
+ shape = [batch_size] + list(shape)
400
+ else:
401
+ b = batch_size = shape[0]
402
+ if x_T is None:
403
+ img = torch.randn(shape, device=self.device)
404
+ else:
405
+ img = x_T
406
+ intermediates = []
407
+ if cond is not None:
408
+ if isinstance(cond, dict):
409
+ cond = {
410
+ key: cond[key][:batch_size]
411
+ if not isinstance(cond[key], list)
412
+ else list(map(lambda x: x[:batch_size], cond[key]))
413
+ for key in cond
414
+ }
415
+ else:
416
+ cond = (
417
+ [c[:batch_size] for c in cond]
418
+ if isinstance(cond, list)
419
+ else cond[:batch_size]
420
+ )
421
+
422
+ if start_T is not None:
423
+ timesteps = min(timesteps, start_T)
424
+ iterator = (
425
+ tqdm(
426
+ reversed(range(0, timesteps)),
427
+ desc="Progressive Generation",
428
+ total=timesteps,
429
+ )
430
+ if verbose
431
+ else reversed(range(0, timesteps))
432
+ )
433
+ if type(temperature) == float:
434
+ temperature = [temperature] * timesteps
435
+
436
+ for i in iterator:
437
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
438
+ if self.shorten_cond_schedule:
439
+ assert self.model.conditioning_key != "hybrid"
440
+ tc = self.cond_ids[ts].to(cond.device)
441
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
442
+
443
+ img, x0_partial = self.p_sample(
444
+ img,
445
+ cond,
446
+ ts,
447
+ clip_denoised=self.clip_denoised,
448
+ quantize_denoised=quantize_denoised,
449
+ return_x0=True,
450
+ temperature=temperature[i],
451
+ noise_dropout=noise_dropout,
452
+ score_corrector=score_corrector,
453
+ corrector_kwargs=corrector_kwargs,
454
+ )
455
+ if mask is not None:
456
+ assert x0 is not None
457
+ img_orig = self.q_sample(x0, ts)
458
+ img = img_orig * mask + (1.0 - mask) * img
459
+
460
+ if i % log_every_t == 0 or i == timesteps - 1:
461
+ intermediates.append(x0_partial)
462
+ if callback:
463
+ callback(i)
464
+ if img_callback:
465
+ img_callback(img, i)
466
+ return img, intermediates
467
+
468
+ @torch.no_grad()
469
+ def p_sample_loop(
470
+ self,
471
+ cond,
472
+ shape,
473
+ return_intermediates=False,
474
+ x_T=None,
475
+ verbose=True,
476
+ callback=None,
477
+ timesteps=None,
478
+ quantize_denoised=False,
479
+ mask=None,
480
+ x0=None,
481
+ img_callback=None,
482
+ start_T=None,
483
+ log_every_t=None,
484
+ ):
485
+
486
+ if not log_every_t:
487
+ log_every_t = self.log_every_t
488
+ device = self.betas.device
489
+ b = shape[0]
490
+ if x_T is None:
491
+ img = torch.randn(shape, device=device)
492
+ else:
493
+ img = x_T
494
+
495
+ intermediates = [img]
496
+ if timesteps is None:
497
+ timesteps = self.num_timesteps
498
+
499
+ if start_T is not None:
500
+ timesteps = min(timesteps, start_T)
501
+ iterator = (
502
+ tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps)
503
+ if verbose
504
+ else reversed(range(0, timesteps))
505
+ )
506
+
507
+ if mask is not None:
508
+ assert x0 is not None
509
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
510
+
511
+ for i in iterator:
512
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
513
+ if self.shorten_cond_schedule:
514
+ assert self.model.conditioning_key != "hybrid"
515
+ tc = self.cond_ids[ts].to(cond.device)
516
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
517
+
518
+ img = self.p_sample(
519
+ img,
520
+ cond,
521
+ ts,
522
+ clip_denoised=self.clip_denoised,
523
+ quantize_denoised=quantize_denoised,
524
+ )
525
+ if mask is not None:
526
+ img_orig = self.q_sample(x0, ts)
527
+ img = img_orig * mask + (1.0 - mask) * img
528
+
529
+ if i % log_every_t == 0 or i == timesteps - 1:
530
+ intermediates.append(img)
531
+ if callback:
532
+ callback(i)
533
+ if img_callback:
534
+ img_callback(img, i)
535
+
536
+ if return_intermediates:
537
+ return img, intermediates
538
+ return img
539
+
540
+ @torch.no_grad()
541
+ def sample(
542
+ self,
543
+ cond,
544
+ batch_size=16,
545
+ return_intermediates=False,
546
+ x_T=None,
547
+ verbose=True,
548
+ timesteps=None,
549
+ quantize_denoised=False,
550
+ mask=None,
551
+ x0=None,
552
+ shape=None,
553
+ **kwargs,
554
+ ):
555
+ if shape is None:
556
+ shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size)
557
+ if cond is not None:
558
+ if isinstance(cond, dict):
559
+ cond = {
560
+ key: cond[key][:batch_size]
561
+ if not isinstance(cond[key], list)
562
+ else list(map(lambda x: x[:batch_size], cond[key]))
563
+ for key in cond
564
+ }
565
+ else:
566
+ cond = (
567
+ [c[:batch_size] for c in cond]
568
+ if isinstance(cond, list)
569
+ else cond[:batch_size]
570
+ )
571
+ return self.p_sample_loop(
572
+ cond,
573
+ shape,
574
+ return_intermediates=return_intermediates,
575
+ x_T=x_T,
576
+ verbose=verbose,
577
+ timesteps=timesteps,
578
+ quantize_denoised=quantize_denoised,
579
+ mask=mask,
580
+ x0=x0,
581
+ **kwargs,
582
+ )
583
+
584
+ @torch.no_grad()
585
+ def sample_log(
586
+ self,
587
+ cond,
588
+ batch_size,
589
+ ddim,
590
+ ddim_steps,
591
+ unconditional_guidance_scale=1.0,
592
+ unconditional_conditioning=None,
593
+ use_plms=False,
594
+ mask=None,
595
+ **kwargs,
596
+ ):
597
+
598
+ if mask is not None:
599
+ shape = (self.channels, mask.size()[-2], mask.size()[-1])
600
+ else:
601
+ shape = (self.channels, self.latent_t_size, self.latent_f_size)
602
+
603
+ intermediate = None
604
+ if ddim and not use_plms:
605
+ # print("Use ddim sampler")
606
+
607
+ ddim_sampler = DDIMSampler(self)
608
+ samples, intermediates = ddim_sampler.sample(
609
+ ddim_steps,
610
+ batch_size,
611
+ shape,
612
+ cond,
613
+ verbose=False,
614
+ unconditional_guidance_scale=unconditional_guidance_scale,
615
+ unconditional_conditioning=unconditional_conditioning,
616
+ mask=mask,
617
+ **kwargs,
618
+ )
619
+
620
+ else:
621
+ # print("Use DDPM sampler")
622
+ samples, intermediates = self.sample(
623
+ cond=cond,
624
+ batch_size=batch_size,
625
+ return_intermediates=True,
626
+ unconditional_guidance_scale=unconditional_guidance_scale,
627
+ mask=mask,
628
+ unconditional_conditioning=unconditional_conditioning,
629
+ **kwargs,
630
+ )
631
+
632
+ return samples, intermediate
633
+
634
+ @torch.no_grad()
635
+ def generate_sample(
636
+ self,
637
+ batchs,
638
+ ddim_steps=200,
639
+ ddim_eta=1.0,
640
+ x_T=None,
641
+ n_candidate_gen_per_text=1,
642
+ unconditional_guidance_scale=1.0,
643
+ unconditional_conditioning=None,
644
+ name="waveform",
645
+ use_plms=False,
646
+ save=False,
647
+ **kwargs,
648
+ ):
649
+ # Generate n_candidate_gen_per_text times and select the best
650
+ # Batch: audio, text, fnames
651
+ assert x_T is None
652
+ try:
653
+ batchs = iter(batchs)
654
+ except TypeError:
655
+ raise ValueError("The first input argument should be an iterable object")
656
+
657
+ if use_plms:
658
+ assert ddim_steps is not None
659
+ use_ddim = ddim_steps is not None
660
+ # waveform_save_path = os.path.join(self.get_log_dir(), name)
661
+ # os.makedirs(waveform_save_path, exist_ok=True)
662
+ # print("Waveform save path: ", waveform_save_path)
663
+
664
+ with self.ema_scope("Generate"):
665
+ for batch in batchs:
666
+ z, c = self.get_input(
667
+ batch,
668
+ self.first_stage_key,
669
+ cond_key=self.cond_stage_key,
670
+ return_first_stage_outputs=False,
671
+ force_c_encode=True,
672
+ return_original_cond=False,
673
+ bs=None,
674
+ )
675
+ text = super().get_input(batch, "text")
676
+
677
+ # Generate multiple samples
678
+ batch_size = z.shape[0] * n_candidate_gen_per_text
679
+ c = torch.cat([c] * n_candidate_gen_per_text, dim=0)
680
+ text = text * n_candidate_gen_per_text
681
+
682
+ if unconditional_guidance_scale != 1.0:
683
+ unconditional_conditioning = (
684
+ self.cond_stage_model.get_unconditional_condition(batch_size)
685
+ )
686
+
687
+ samples, _ = self.sample_log(
688
+ cond=c,
689
+ batch_size=batch_size,
690
+ x_T=x_T,
691
+ ddim=use_ddim,
692
+ ddim_steps=ddim_steps,
693
+ eta=ddim_eta,
694
+ unconditional_guidance_scale=unconditional_guidance_scale,
695
+ unconditional_conditioning=unconditional_conditioning,
696
+ use_plms=use_plms,
697
+ )
698
+
699
+ if(torch.max(torch.abs(samples)) > 1e2):
700
+ samples = torch.clip(samples, min=-10, max=10)
701
+
702
+ mel = self.decode_first_stage(samples)
703
+
704
+ waveform = self.mel_spectrogram_to_waveform(mel)
705
+
706
+ if waveform.shape[0] > 1:
707
+ similarity = self.cond_stage_model.cos_similarity(
708
+ torch.FloatTensor(waveform).squeeze(1), text
709
+ )
710
+
711
+ best_index = []
712
+ for i in range(z.shape[0]):
713
+ candidates = similarity[i :: z.shape[0]]
714
+ max_index = torch.argmax(candidates).item()
715
+ best_index.append(i + max_index * z.shape[0])
716
+
717
+ waveform = waveform[best_index]
718
+ # print("Similarity between generated audio and text", similarity)
719
+ # print("Choose the following indexes:", best_index)
720
+
721
+ return waveform
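The best-candidate selection above relies on how the candidates are stacked: `torch.cat([c] * n_candidate_gen_per_text)` puts candidate k of prompt i at index `i + k * n_prompts`, so `similarity[i::n_prompts]` collects all candidates of prompt i. A toy sketch of the same indexing:

```python
import torch

n_prompts, n_candidates = 3, 4
# candidates are stacked as [p0_c0, p1_c0, p2_c0, p0_c1, p1_c1, ...],
# so candidate k of prompt i sits at index i + k * n_prompts
similarity = torch.rand(n_prompts * n_candidates)

best_index = []
for i in range(n_prompts):
    candidates = similarity[i::n_prompts]                        # the n_candidates scores of prompt i
    best_index.append(i + torch.argmax(candidates).item() * n_prompts)

print(best_index)   # e.g. [6, 1, 11] -- exactly one winning candidate per prompt
```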
722
+
723
+ @torch.no_grad()
724
+ def generate_sample_masked(
725
+ self,
726
+ batchs,
727
+ ddim_steps=200,
728
+ ddim_eta=1.0,
729
+ x_T=None,
730
+ n_candidate_gen_per_text=1,
731
+ unconditional_guidance_scale=1.0,
732
+ unconditional_conditioning=None,
733
+ name="waveform",
734
+ use_plms=False,
735
+ time_mask_ratio_start_and_end=(0.25, 0.75),
736
+ freq_mask_ratio_start_and_end=(0.75, 1.0),
737
+ save=False,
738
+ **kwargs,
739
+ ):
740
+ # Generate n_candidate_gen_per_text times and select the best
741
+ # Batch: audio, text, fnames
742
+ assert x_T is None
743
+ try:
744
+ batchs = iter(batchs)
745
+ except TypeError:
746
+ raise ValueError("The first input argument should be an iterable object")
747
+
748
+ if use_plms:
749
+ assert ddim_steps is not None
750
+ use_ddim = ddim_steps is not None
751
+ # waveform_save_path = os.path.join(self.get_log_dir(), name)
752
+ # os.makedirs(waveform_save_path, exist_ok=True)
753
+ # print("Waveform save path: ", waveform_save_path)
754
+
755
+ with self.ema_scope("Generate"):
756
+ for batch in batchs:
757
+ z, c = self.get_input(
758
+ batch,
759
+ self.first_stage_key,
760
+ cond_key=self.cond_stage_key,
761
+ return_first_stage_outputs=False,
762
+ force_c_encode=True,
763
+ return_original_cond=False,
764
+ bs=None,
765
+ )
766
+ text = super().get_input(batch, "text")
767
+
768
+ # Generate multiple samples
769
+ batch_size = z.shape[0] * n_candidate_gen_per_text
770
+
771
+ _, h, w = z.shape[0], z.shape[2], z.shape[3]
772
+
773
+ mask = torch.ones(batch_size, h, w).to(self.device)
774
+
775
+ mask[:, int(h * time_mask_ratio_start_and_end[0]) : int(h * time_mask_ratio_start_and_end[1]), :] = 0
776
+ mask[:, :, int(w * freq_mask_ratio_start_and_end[0]) : int(w * freq_mask_ratio_start_and_end[1])] = 0
777
+ mask = mask[:, None, ...]
778
+
779
+ c = torch.cat([c] * n_candidate_gen_per_text, dim=0)
780
+ text = text * n_candidate_gen_per_text
781
+
782
+ if unconditional_guidance_scale != 1.0:
783
+ unconditional_conditioning = (
784
+ self.cond_stage_model.get_unconditional_condition(batch_size)
785
+ )
786
+
787
+ samples, _ = self.sample_log(
788
+ cond=c,
789
+ batch_size=batch_size,
790
+ x_T=x_T,
791
+ ddim=use_ddim,
792
+ ddim_steps=ddim_steps,
793
+ eta=ddim_eta,
794
+ unconditional_guidance_scale=unconditional_guidance_scale,
795
+ unconditional_conditioning=unconditional_conditioning,
796
+ use_plms=use_plms, mask=mask, x0=torch.cat([z] * n_candidate_gen_per_text)
797
+ )
798
+
799
+ mel = self.decode_first_stage(samples)
800
+
801
+ waveform = self.mel_spectrogram_to_waveform(mel)
802
+
803
+ if waveform.shape[0] > 1:
804
+ similarity = self.cond_stage_model.cos_similarity(
805
+ torch.FloatTensor(waveform).squeeze(1), text
806
+ )
807
+
808
+ best_index = []
809
+ for i in range(z.shape[0]):
810
+ candidates = similarity[i :: z.shape[0]]
811
+ max_index = torch.argmax(candidates).item()
812
+ best_index.append(i + max_index * z.shape[0])
813
+
814
+ waveform = waveform[best_index]
815
+ # print("Similarity between generated audio and text", similarity)
816
+ # print("Choose the following indexes:", best_index)
817
+
818
+ return waveform
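The mask built in `generate_sample_masked` zeroes a central band of time steps and the top band of frequency bins, so only those regions are re-generated while the rest is kept from `x0`. A toy sketch with the default ratios:

```python
import torch

h, w = 8, 4                                            # toy latent size (t-steps, freq bins)
time_span, freq_span = (0.25, 0.75), (0.75, 1.0)       # the default ratios above
mask = torch.ones(1, h, w)
mask[:, int(h * time_span[0]):int(h * time_span[1]), :] = 0
mask[:, :, int(w * freq_span[0]):int(w * freq_span[1])] = 0
print(mask[0])   # rows 2-5 (masked time span) and the last column (masked freq band) are zero
```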
ldm/models/diffusion/cfm1_audio.py ADDED
@@ -0,0 +1,312 @@
1
+ import os
2
+ from pytorch_memlab import LineProfiler,profile
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+ from einops import rearrange, repeat
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+ from tqdm import tqdm
12
+
13
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps
14
+ from torchvision.utils import make_grid
15
+ try:
16
+ from pytorch_lightning.utilities.distributed import rank_zero_only
17
+ except:
18
+ from pytorch_lightning.utilities import rank_zero_only # torch2
19
+ from torchdyn.core import NeuralODE
20
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
21
+ from ldm.models.diffusion.ddpm_audio import LatentDiffusion_audio, disabled_train
22
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
23
+ from omegaconf import ListConfig
24
+
25
+ __conditioning_keys__ = {'concat': 'c_concat',
26
+ 'crossattn': 'c_crossattn',
27
+ 'adm': 'y'}
28
+
29
+
30
+ class CFM(LatentDiffusion_audio):
31
+
32
+ def __init__(self, **kwargs):
33
+
34
+ super(CFM, self).__init__(**kwargs)
35
+ self.sigma_min = 1e-4
36
+
37
+ def p_losses(self, x_start, cond, t, noise=None):
38
+ x1 = x_start
39
+ x0 = default(noise, lambda: torch.randn_like(x_start))
40
+ ut = x1 - (1 - self.sigma_min) * x0 # regression target; no gradient needs to flow through ut
41
+ t_unsqueeze = t.unsqueeze(1).unsqueeze(1).float() / self.num_timesteps
42
+ x_noisy = t_unsqueeze * x1 + (1. - (1 - self.sigma_min) * t_unsqueeze) * x0
43
+
44
+ model_output = self.apply_model(x_noisy, t, cond)
45
+
46
+ loss_dict = {}
47
+ prefix = 'train' if self.training else 'val'
48
+ target = ut
49
+
50
+ mean_dims = list(range(1,len(target.shape)))
51
+ loss_simple = self.get_loss(model_output, target, mean=False).mean(dim=mean_dims)
52
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
53
+
54
+ loss = loss_simple
55
+ loss = self.l_simple_weight * loss.mean()
56
+ loss_dict.update({f'{prefix}/loss': loss})
57
+
58
+ return loss, loss_dict
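For reference, `p_losses` above regresses the model output onto the conditional flow-matching velocity u_t = x1 - (1 - sigma_min) * x0 along the interpolant x_t = t * x1 + (1 - (1 - sigma_min) * t) * x0. A small numeric sketch (here t is already normalised to [0, 1], whereas the code divides the integer step by `num_timesteps`) checking that u_t is indeed the time derivative of the path:

```python
import torch

sigma_min = 1e-4
x1 = torch.randn(2, 20, 156)           # "data" latent
x0 = torch.randn_like(x1)              # noise sample

def interpolant(t):                     # same path as in p_losses, with t in [0, 1]
    return t * x1 + (1.0 - (1.0 - sigma_min) * t) * x0

ut = x1 - (1.0 - sigma_min) * x0        # regression target for the model output

# finite-difference check: d/dt interpolant(t) equals ut for every t (the path is linear in t)
t, eps = 0.3, 1e-3
fd = (interpolant(t + eps) - interpolant(t - eps)) / (2 * eps)
print(torch.allclose(fd, ut, atol=1e-3))   # True (up to float32 rounding)
```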
59
+
60
+ @torch.no_grad()
61
+ def sample(self, cond, batch_size=16, timesteps=None, shape=None, x_latent=None, t_start=None, **kwargs):
62
+ if shape is None:
63
+ if self.channels > 0:
64
+ shape = (batch_size, self.channels, self.mel_dim, self.mel_length)
65
+ else:
66
+ shape = (batch_size, self.mel_dim, self.mel_length)
67
+ if cond is not None:
68
+ if isinstance(cond, dict):
69
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
70
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
71
+ else:
72
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
73
+
74
+ neural_ode = NeuralODE(self.ode_wrapper(cond), solver='euler', sensitivity="adjoint", atol=1e-4, rtol=1e-4)
75
+ t_span = torch.linspace(0, 1, 25 if timesteps is None else timesteps)
76
+ if t_start is not None:
77
+ t_span = t_span[t_start:]
78
+
79
+ x0 = torch.randn(shape, device=self.device) if x_latent is None else x_latent
80
+ eval_points, traj = neural_ode(x0, t_span)
81
+
82
+ return traj[-1], traj
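Stripped of the torchdyn machinery, the euler-solver call above amounts to a fixed-step Euler integration of the learned vector field from t = 0 to t = 1. A hedged stand-alone sketch with a toy vector field in place of `apply_model`:

```python
import torch

def vector_field(t, x):
    # toy stand-in for Wrapper.forward, i.e. apply_model(x, t * 1000, cond)
    return -x

def euler_sample(x0, n_steps=25):
    t_span = torch.linspace(0, 1, n_steps)
    x, traj = x0, [x0]
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        x = x + (t1 - t0) * vector_field(t0, x)     # fixed-step Euler update
        traj.append(x)
    return traj[-1], torch.stack(traj)

x_start = torch.randn(2, 20, 156)                    # latent-shaped noise
x_end, traj = euler_sample(x_start)
print(x_end.shape, traj.shape)                       # torch.Size([2, 20, 156]) torch.Size([25, 2, 20, 156])
```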
83
+
84
+ def ode_wrapper(self, cond):
85
+ # self.estimator receives x, mask, mu, t, spk as arguments
86
+ return Wrapper(self, cond)
87
+
88
+ @torch.no_grad()
89
+ def sample_cfg(self, cond, unconditional_guidance_scale, unconditional_conditioning, batch_size=16, timesteps=None, shape=None, x_latent=None, t_start=None, **kwargs):
90
+ if shape is None:
91
+ if self.channels > 0:
92
+ shape = (batch_size, self.channels, self.mel_dim, self.mel_length)
93
+ else:
94
+ shape = (batch_size, self.mel_dim, self.mel_length)
95
+ if cond is not None:
96
+ if isinstance(cond, dict):
97
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
98
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
99
+ else:
100
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
101
+
102
+ neural_ode = NeuralODE(self.ode_wrapper_cfg(cond, unconditional_guidance_scale, unconditional_conditioning), solver='euler', sensitivity="adjoint", atol=1e-4, rtol=1e-4)
103
+ t_span = torch.linspace(0, 1, 25 if timesteps is None else timesteps)
104
+
105
+ if t_start is not None:
106
+ t_span = t_span[t_start:]
107
+
108
+ x0 = torch.randn(shape, device=self.device) if x_latent is None else x_latent
109
+ eval_points, traj = neural_ode(x0, t_span)
110
+
111
+ return traj[-1], traj
112
+
113
+ def ode_wrapper_cfg(self, cond, unconditional_guidance_scale, unconditional_conditioning):
114
+ # self.estimator receives x, mask, mu, t, spk as arguments
115
+ return Wrapper_cfg(self, cond, unconditional_guidance_scale, unconditional_conditioning)
116
+
117
+
118
+ @torch.no_grad()
119
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
120
+ # fast, but does not allow for exact reconstruction
121
+ # t serves as an index to gather the correct alphas
122
+ # if use_original_steps:
123
+ # sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
124
+ # sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
125
+ # else:
126
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
127
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
128
+ if noise is None:
129
+ noise = torch.randn_like(x0)
130
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
131
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
132
+
133
+
134
+ class Wrapper(nn.Module):
135
+ def __init__(self, net, cond):
136
+ super(Wrapper, self).__init__()
137
+ self.net = net
138
+ self.cond = cond
139
+
140
+ def forward(self, t, x, args):
141
+ t = torch.tensor([t * 1000] * x.shape[0], device=t.device).long()
142
+ return self.net.apply_model(x, t, self.cond)
143
+
144
+
145
+ class Wrapper_cfg(nn.Module):
146
+
147
+ def __init__(self, net, cond, unconditional_guidance_scale, unconditional_conditioning):
148
+ super(Wrapper_cfg, self).__init__()
149
+ self.net = net
150
+ self.cond = cond
151
+ self.unconditional_conditioning = unconditional_conditioning
152
+ self.unconditional_guidance_scale = unconditional_guidance_scale
153
+
154
+ def forward(self, t, x, args):
155
+ x_in = torch.cat([x] * 2)
156
+ t = torch.tensor([t * 1000] * x.shape[0], device=t.device).long()
157
+ t_in = torch.cat([t] * 2)
158
+ c_in = torch.cat([self.unconditional_conditioning, self.cond]) # c/uc shape [b,seq_len=77,dim=1024],c_in shape [b*2,seq_len,dim]
159
+ e_t_uncond, e_t = self.net.apply_model(x_in, t_in, c_in).chunk(2)
160
+ e_t = e_t_uncond + self.unconditional_guidance_scale * (e_t - e_t_uncond)
161
+ return e_t
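The last two lines above are the usual classifier-free guidance combination, e = e_uncond + s * (e_cond - e_uncond). A tiny sketch of how the scale behaves:

```python
import torch

def guided_velocity(e_uncond, e_cond, scale):
    # same combination as in Wrapper_cfg.forward
    return e_uncond + scale * (e_cond - e_uncond)

e_uncond = torch.zeros(1, 20, 156)
e_cond = torch.ones(1, 20, 156)
print(guided_velocity(e_uncond, e_cond, 1.0).mean())   # tensor(1.) -> scale 1 keeps the conditional prediction
print(guided_velocity(e_uncond, e_cond, 3.0).mean())   # tensor(3.) -> scale > 1 pushes further from the unconditional one
```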
162
+
163
+
164
+ class CFM_inpaint(CFM):
165
+
166
+ @torch.no_grad()
167
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
168
+ cond_key=None, return_original_cond=False, bs=None):
169
+ x = batch[k]
170
+ if self.channels > 0: # use 4d input
171
+ if len(x.shape) == 3:
172
+ x = x[..., None]
173
+ x = rearrange(x, 'b h w c -> b c h w')
174
+ x = x.to(memory_format=torch.contiguous_format).float()
175
+
176
+ if bs is not None:
177
+ x = x[:bs]
178
+ x = x.to(self.device)
179
+ encoder_posterior = self.encode_first_stage(x)
180
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
181
+
182
+ if self.model.conditioning_key is not None:
183
+ if cond_key is None:
184
+ cond_key = self.cond_stage_key
185
+ if cond_key != self.first_stage_key:
186
+ if cond_key in ['caption', 'coordinates_bbox', 'hybrid_feat']:
187
+ xc = batch[cond_key]
188
+ elif cond_key == 'class_label':
189
+ xc = batch
190
+ else:
191
+ xc = super().get_input(batch, cond_key).to(self.device)
192
+ else:
193
+ xc = x
194
+ ##### Testing #######
195
+ spec = xc['mix_spec'].to(self.device)
196
+ encoder_posterior = self.encode_first_stage(spec)
197
+ z_spec = self.get_first_stage_encoding(encoder_posterior).detach()
198
+ c = {"mix_spec": z_spec, "mix_video_feat": xc['mix_video_feat']}
199
+ ##### Testing #######
200
+ if bs is not None:
201
+ c = {"mix_spec": c["mix_spec"][:bs], "mix_video_feat": c['mix_video_feat'][:bs]}
202
+ # Testing #
203
+ if cond_key == 'masked_image':
204
+ mask = super().get_input(batch, "mask")
205
+ cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # [B, 1, 10, 106]
206
+ c = torch.cat((c, cc), dim=1) # [B, 5, 10, 106]
207
+ # Testing #
208
+ if self.use_positional_encodings:
209
+ pos_x, pos_y = self.compute_latent_shifts(batch)
210
+ ckey = __conditioning_keys__[self.model.conditioning_key]
211
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
212
+
213
+ else:
214
+ c = None
215
+ xc = None
216
+ if self.use_positional_encodings:
217
+ pos_x, pos_y = self.compute_latent_shifts(batch)
218
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
219
+ out = [z, c]
220
+ if return_first_stage_outputs:
221
+ xrec = self.decode_first_stage(z)
222
+ out.extend([x, xrec])
223
+ if return_original_cond:
224
+ out.append(xc)
225
+ return out
226
+
227
+
228
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
229
+
230
+ if isinstance(cond, dict):
231
+ # hybrid case, cond is expected to be a dict
232
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
233
+ cond = {key: cond}
234
+ else:
235
+ if not isinstance(cond, list):
236
+ cond = [cond]
237
+ if self.model.conditioning_key == "concat":
238
+ key = "c_concat"
239
+ elif self.model.conditioning_key == "crossattn" or self.model.conditioning_key == "hybrid_inpaint":
240
+ key = "c_crossattn"
241
+ else:
242
+ key = "c_film"
243
+ cond = {key: cond}
244
+
245
+
246
+ x_recon = self.model(x_noisy, t, **cond)
247
+
248
+ if isinstance(x_recon, tuple) and not return_ids:
249
+ return x_recon[0]
250
+ else:
251
+ return x_recon
252
+
253
+
254
+
255
+ @torch.no_grad()
256
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
257
+ quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=True,
258
+ plot_diffusion_rows=True, **kwargs):
259
+
260
+ log = dict()
261
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
262
+ return_first_stage_outputs=True,
263
+ force_c_encode=True,
264
+ return_original_cond=True,
265
+ bs=N) # z is the latent, c is the condition embedding, xc is the original condition (caption list)
266
+ N = min(x.shape[0], N)
267
+ n_row = min(x.shape[0], n_row)
268
+ log["inputs"] = x if len(x.shape)==4 else x.unsqueeze(1)
269
+ log["reconstruction"] = xrec if len(xrec.shape)==4 else xrec.unsqueeze(1)
270
+ if self.model.conditioning_key is not None:
271
+ if hasattr(self.cond_stage_model, "decode") and self.cond_stage_key != "masked_image":
272
+ xc = self.cond_stage_model.decode(c)
273
+ log["conditioning"] = xc
274
+ elif self.cond_stage_key == "masked_image":
275
+ log["mask"] = c[:, -1, :, :][:, None, :, :]
276
+ xc = self.cond_stage_model.decode(c[:, :self.cond_stage_model.embed_dim, :, :])
277
+ log["conditioning"] = xc
278
+ elif self.cond_stage_key in ["caption"]:
279
+ pass
280
+ # xc = log_txt_as_img((256, 256), batch["caption"])
281
+ # log["conditioning"] = xc
282
+ elif self.cond_stage_key == 'class_label':
283
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
284
+ log['conditioning'] = xc
285
+ elif isimage(xc):
286
+ log["conditioning"] = xc
287
+
288
+ if plot_diffusion_rows:
289
+ # get diffusion row
290
+ diffusion_row = list()
291
+ z_start = z[:n_row]
292
+ for t in range(self.num_timesteps):
293
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
294
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
295
+ t = t.to(self.device).long()
296
+ noise = torch.randn_like(z_start)
297
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
298
+ diffusion_row.append(self.decode_first_stage(z_noisy))
299
+ if len(diffusion_row[0].shape) == 3:
300
+ diffusion_row = [x.unsqueeze(1) for x in diffusion_row]
301
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
302
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
303
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
304
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
305
+ log["diffusion_row"] = diffusion_grid
306
+
307
+ if return_keys:
308
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
309
+ return log
310
+ else:
311
+ return {key: log[key] for key in return_keys}
312
+ return log
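For reference, the classifier-free guidance used by Wrapper_cfg above reduces to one batched forward pass followed by a linear blend of the two predictions. Below is a minimal, self-contained sketch of that blend; the stand-in model, the guidance scale of 3.0 and every tensor shape are illustrative assumptions, not values taken from this repository.

    import torch

    def cfg_blend(apply_model, x, t, cond, uncond, guidance_scale=3.0):
        # Run the unconditional and conditional branches in one batched call,
        # mirroring the torch.cat([...] * 2) pattern in Wrapper_cfg.forward.
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        c_in = torch.cat([uncond, cond])
        e_uncond, e_cond = apply_model(x_in, t_in, c_in).chunk(2)
        # Move the estimate away from the unconditional branch.
        return e_uncond + guidance_scale * (e_cond - e_uncond)

    # Toy check with a stand-in model that only rescales its input.
    stand_in = lambda x, t, c: 0.1 * x
    v = cfg_blend(stand_in, torch.randn(2, 20, 312), torch.zeros(2, dtype=torch.long),
                  cond=torch.randn(2, 77, 1024), uncond=torch.zeros(2, 77, 1024))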
ldm/models/diffusion/cfm1_audio_sampler.py ADDED
@@ -0,0 +1,105 @@
 
 
 
1
+ import os
2
+ from pytorch_memlab import LineProfiler,profile
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+ from einops import rearrange, repeat
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+ from tqdm import tqdm
12
+
13
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps
14
+ from torchvision.utils import make_grid
15
+ try:
16
+ from pytorch_lightning.utilities.distributed import rank_zero_only
17
+ except ImportError:
18
+ from pytorch_lightning.utilities import rank_zero_only # torch2
19
+ from torchdyn.core import NeuralODE
20
+ from ldm.models.diffusion.cfm_audio import Wrapper, Wrapper_cfg
21
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
22
+ from omegaconf import ListConfig
23
+
24
+ from ldm.util import log_txt_as_img, exists, default
25
+
26
+ class CFMSampler(object):
27
+
28
+ def __init__(self, model, num_timesteps, schedule="linear", **kwargs):
29
+ super().__init__()
30
+ self.model = model
31
+ self.ddpm_num_timesteps = model.num_timesteps
32
+ self.num_timesteps = num_timesteps
33
+ self.schedule = schedule
34
+
35
+ def register_buffer(self, name, attr):
36
+ if type(attr) == torch.Tensor:
37
+ if attr.device != torch.device("cuda"):
38
+ attr = attr.to(torch.device("cuda"))
39
+ setattr(self, name, attr)
40
+
41
+ def stochastic_encode(self, x_start, t, noise=None):
42
+ x1 = x_start
43
+ x0 = default(noise, lambda: torch.randn_like(x_start))
44
+ t_unsqueeze = 1 - t.unsqueeze(1).unsqueeze(1).float() / self.num_timesteps
45
+ x_noisy = t_unsqueeze * x1 + (1. - (1 - self.model.sigma_min) * t_unsqueeze) * x0
46
+ return x_noisy
47
+
48
+ @torch.no_grad()
49
+ def sample(self, cond, batch_size=16, timesteps=None, shape=None, x_latent=None, t_start=None, **kwargs):
50
+ if shape is None:
51
+ if self.model.channels > 0:
52
+ shape = (batch_size, self.model.channels, self.model.mel_dim, self.model.mel_length)
53
+ else:
54
+ shape = (batch_size, self.model.mel_dim, self.model.mel_length)
55
+ # if cond is not None:
56
+ # if isinstance(cond, dict):
57
+ # cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
58
+ # list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
59
+ # else:
60
+ # cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
61
+
62
+
63
+ neural_ode = NeuralODE(self.ode_wrapper(cond), solver='euler', sensitivity="adjoint", atol=1e-4, rtol=1e-4)
64
+ t_span = torch.linspace(0, 1, 25 if timesteps is None else timesteps)
65
+ if t_start is not None:
66
+ t_span = t_span[t_start:]
67
+
68
+ x0 = torch.randn(shape, device=self.model.device) if x_latent is None else x_latent
69
+ eval_points, traj = neural_ode(x0, t_span)
70
+
71
+ return traj[-1], traj
72
+
73
+ def ode_wrapper(self, cond):
74
+ # self.estimator receives x, mask, mu, t, spk as arguments
75
+ return Wrapper(self.model, cond)
76
+
77
+ @torch.no_grad()
78
+ def sample_cfg(self, cond, unconditional_guidance_scale, unconditional_conditioning, batch_size=16, timesteps=None, shape=None, x_latent=None, t_start=None, **kwargs):
79
+ if shape is None:
80
+ if self.model.channels > 0:
81
+ shape = (batch_size, self.model.channels, self.model.mel_dim, self.model.mel_length)
82
+ else:
83
+ shape = (batch_size, self.model.mel_dim, self.model.mel_length)
84
+ # if cond is not None:
85
+ # if isinstance(cond, dict):
86
+ # cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
87
+ # list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
88
+ # else:
89
+ # cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
90
+
91
+ neural_ode = NeuralODE(self.ode_wrapper_cfg(cond, unconditional_guidance_scale, unconditional_conditioning), solver='euler', sensitivity="adjoint", atol=1e-4, rtol=1e-4)
92
+ t_span = torch.linspace(0, 1, 25 if timesteps is None else timesteps)
93
+
94
+ if t_start is not None:
95
+ t_span = t_span[t_start:]
96
+
97
+ x0 = torch.randn(shape, device=self.model.device) if x_latent is None else x_latent
98
+ eval_points, traj = neural_ode(x0, t_span)
99
+
100
+ return traj[-1], traj
101
+
102
+ def ode_wrapper_cfg(self, cond, unconditional_guidance_scale, unconditional_conditioning):
103
+ # self.estimator receives x, mask, mu, t, spk as arguments
104
+ return Wrapper_cfg(self.model, cond, unconditional_guidance_scale, unconditional_conditioning)
105
+
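The interpolation in CFMSampler.stochastic_encode above is the usual conditional flow-matching bridge between noise and data. A minimal sketch, assuming sigma_min = 1e-4 and a made-up latent shape (the repository takes sigma_min from the model and the shapes from its config):

    import torch

    def cfm_interpolate(x1, s, sigma_min=1e-4, noise=None):
        # s plays the role of 1 - t / num_timesteps in stochastic_encode:
        # s = 1 returns (almost) clean data x1, s = 0 returns pure noise x0.
        x0 = torch.randn_like(x1) if noise is None else noise
        return s * x1 + (1.0 - (1.0 - sigma_min) * s) * x0

    x1 = torch.randn(4, 20, 312)         # placeholder latent shape
    x_mid = cfm_interpolate(x1, s=0.5)   # halfway along the probability path

At inference time, sample and sample_cfg integrate the learned velocity field with an Euler NeuralODE over t in [0, 1], starting from noise and ending at a latent that the model's first stage decodes back to a mel-spectrogram.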
ldm/models/diffusion/classifier.py ADDED
@@ -0,0 +1,267 @@
 
 
 
1
+ import os
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ from omegaconf import OmegaConf
5
+ from torch.nn import functional as F
6
+ from torch.optim import AdamW
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+ from copy import deepcopy
9
+ from einops import rearrange
10
+ from glob import glob
11
+ from natsort import natsorted
12
+
13
+ from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14
+ from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15
+
16
+ __models__ = {
17
+ 'class_label': EncoderUNetModel,
18
+ 'segmentation': UNetModel
19
+ }
20
+
21
+
22
+ def disabled_train(self, mode=True):
23
+ """Overwrite model.train with this function to make sure train/eval mode
24
+ does not change anymore."""
25
+ return self
26
+
27
+
28
+ class NoisyLatentImageClassifier(pl.LightningModule):
29
+
30
+ def __init__(self,
31
+ diffusion_path,
32
+ num_classes,
33
+ ckpt_path=None,
34
+ pool='attention',
35
+ label_key=None,
36
+ diffusion_ckpt_path=None,
37
+ scheduler_config=None,
38
+ weight_decay=1.e-2,
39
+ log_steps=10,
40
+ monitor='val/loss',
41
+ *args,
42
+ **kwargs):
43
+ super().__init__(*args, **kwargs)
44
+ self.num_classes = num_classes
45
+ # get latest config of diffusion model
46
+ diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47
+ self.diffusion_config = OmegaConf.load(diffusion_config).model
48
+ self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49
+ self.load_diffusion()
50
+
51
+ self.monitor = monitor
52
+ self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53
+ self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54
+ self.log_steps = log_steps
55
+
56
+ self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57
+ else self.diffusion_model.cond_stage_key
58
+
59
+ assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60
+
61
+ if self.label_key not in __models__:
62
+ raise NotImplementedError()
63
+
64
+ self.load_classifier(ckpt_path, pool)
65
+
66
+ self.scheduler_config = scheduler_config
67
+ self.use_scheduler = self.scheduler_config is not None
68
+ self.weight_decay = weight_decay
69
+
70
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71
+ sd = torch.load(path, map_location="cpu")
72
+ if "state_dict" in list(sd.keys()):
73
+ sd = sd["state_dict"]
74
+ keys = list(sd.keys())
75
+ for k in keys:
76
+ for ik in ignore_keys:
77
+ if k.startswith(ik):
78
+ print("Deleting key {} from state_dict.".format(k))
79
+ del sd[k]
80
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81
+ sd, strict=False)
82
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83
+ if len(missing) > 0:
84
+ print(f"Missing Keys: {missing}")
85
+ if len(unexpected) > 0:
86
+ print(f"Unexpected Keys: {unexpected}")
87
+
88
+ def load_diffusion(self):
89
+ model = instantiate_from_config(self.diffusion_config)
90
+ self.diffusion_model = model.eval()
91
+ self.diffusion_model.train = disabled_train
92
+ for param in self.diffusion_model.parameters():
93
+ param.requires_grad = False
94
+
95
+ def load_classifier(self, ckpt_path, pool):
96
+ model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97
+ model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98
+ model_config.out_channels = self.num_classes
99
+ if self.label_key == 'class_label':
100
+ model_config.pool = pool
101
+
102
+ self.model = __models__[self.label_key](**model_config)
103
+ if ckpt_path is not None:
104
+ print('#####################################################################')
105
+ print(f'load from ckpt "{ckpt_path}"')
106
+ print('#####################################################################')
107
+ self.init_from_ckpt(ckpt_path)
108
+
109
+ @torch.no_grad()
110
+ def get_x_noisy(self, x, t, noise=None):
111
+ noise = default(noise, lambda: torch.randn_like(x))
112
+ continuous_sqrt_alpha_cumprod = None
113
+ if self.diffusion_model.use_continuous_noise:
114
+ continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115
+ # todo: make sure t+1 is correct here
116
+
117
+ return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118
+ continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119
+
120
+ def forward(self, x_noisy, t, *args, **kwargs):
121
+ return self.model(x_noisy, t)
122
+
123
+ @torch.no_grad()
124
+ def get_input(self, batch, k):
125
+ x = batch[k]
126
+ if len(x.shape) == 3:
127
+ x = x[..., None]
128
+ x = rearrange(x, 'b h w c -> b c h w')
129
+ x = x.to(memory_format=torch.contiguous_format).float()
130
+ return x
131
+
132
+ @torch.no_grad()
133
+ def get_conditioning(self, batch, k=None):
134
+ if k is None:
135
+ k = self.label_key
136
+ assert k is not None, 'Needs to provide label key'
137
+
138
+ targets = batch[k].to(self.device)
139
+
140
+ if self.label_key == 'segmentation':
141
+ targets = rearrange(targets, 'b h w c -> b c h w')
142
+ for down in range(self.numd):
143
+ h, w = targets.shape[-2:]
144
+ targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145
+
146
+ # targets = rearrange(targets,'b c h w -> b h w c')
147
+
148
+ return targets
149
+
150
+ def compute_top_k(self, logits, labels, k, reduction="mean"):
151
+ _, top_ks = torch.topk(logits, k, dim=1)
152
+ if reduction == "mean":
153
+ return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154
+ elif reduction == "none":
155
+ return (top_ks == labels[:, None]).float().sum(dim=-1)
156
+
157
+ def on_train_epoch_start(self):
158
+ # save some memory
159
+ self.diffusion_model.model.to('cpu')
160
+
161
+ @torch.no_grad()
162
+ def write_logs(self, loss, logits, targets):
163
+ log_prefix = 'train' if self.training else 'val'
164
+ log = {}
165
+ log[f"{log_prefix}/loss"] = loss.mean()
166
+ log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167
+ logits, targets, k=1, reduction="mean"
168
+ )
169
+ log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170
+ logits, targets, k=5, reduction="mean"
171
+ )
172
+
173
+ self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174
+ self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175
+ self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176
+ lr = self.optimizers().param_groups[0]['lr']
177
+ self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178
+
179
+ def shared_step(self, batch, t=None):
180
+ x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181
+ targets = self.get_conditioning(batch)
182
+ if targets.dim() == 4:
183
+ targets = targets.argmax(dim=1)
184
+ if t is None:
185
+ t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186
+ else:
187
+ t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188
+ x_noisy = self.get_x_noisy(x, t)
189
+ logits = self(x_noisy, t)
190
+
191
+ loss = F.cross_entropy(logits, targets, reduction='none')
192
+
193
+ self.write_logs(loss.detach(), logits.detach(), targets.detach())
194
+
195
+ loss = loss.mean()
196
+ return loss, logits, x_noisy, targets
197
+
198
+ def training_step(self, batch, batch_idx):
199
+ loss, *_ = self.shared_step(batch)
200
+ return loss
201
+
202
+ def reset_noise_accs(self):
203
+ self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204
+ range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205
+
206
+ def on_validation_start(self):
207
+ self.reset_noise_accs()
208
+
209
+ @torch.no_grad()
210
+ def validation_step(self, batch, batch_idx):
211
+ loss, *_ = self.shared_step(batch)
212
+
213
+ for t in self.noisy_acc:
214
+ _, logits, _, targets = self.shared_step(batch, t)
215
+ self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216
+ self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217
+
218
+ return loss
219
+
220
+ def configure_optimizers(self):
221
+ optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222
+
223
+ if self.use_scheduler:
224
+ scheduler = instantiate_from_config(self.scheduler_config)
225
+
226
+ print("Setting up LambdaLR scheduler...")
227
+ scheduler = [
228
+ {
229
+ 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230
+ 'interval': 'step',
231
+ 'frequency': 1
232
+ }]
233
+ return [optimizer], scheduler
234
+
235
+ return optimizer
236
+
237
+ @torch.no_grad()
238
+ def log_images(self, batch, N=8, *args, **kwargs):
239
+ log = dict()
240
+ x = self.get_input(batch, self.diffusion_model.first_stage_key)
241
+ log['inputs'] = x
242
+
243
+ y = self.get_conditioning(batch)
244
+
245
+ if self.label_key == 'class_label':
246
+ y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247
+ log['labels'] = y
248
+
249
+ if ismap(y):
250
+ log['labels'] = self.diffusion_model.to_rgb(y)
251
+
252
+ for step in range(self.log_steps):
253
+ current_time = step * self.log_time_interval
254
+
255
+ _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256
+
257
+ log[f'inputs@t{current_time}'] = x_noisy
258
+
259
+ pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260
+ pred = rearrange(pred, 'b h w c -> b c h w')
261
+
262
+ log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263
+
264
+ for key in log:
265
+ log[key] = log[key][:N]
266
+
267
+ return log
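The accuracy numbers that NoisyLatentImageClassifier logs are plain top-k accuracies over the classifier logits. A small self-contained sketch of the same computation, with made-up logits and labels:

    import torch

    def top_k_accuracy(logits, labels, k):
        # Same computation as compute_top_k(..., reduction="mean") above.
        _, top_ks = torch.topk(logits, k, dim=1)
        return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()

    logits = torch.tensor([[0.1, 2.0, 0.3],
                           [1.5, 0.2, 0.1]])
    labels = torch.tensor([1, 2])
    print(top_k_accuracy(logits, labels, k=1))  # 0.5: only the first sample is correct at top-1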
ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,262 @@
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
9
+ extract_into_tensor
10
+
11
+
12
+ class DDIMSampler(object):
13
+ def __init__(self, model, schedule="linear", **kwargs):
14
+ super().__init__()
15
+ self.model = model
16
+ self.ddpm_num_timesteps = model.num_timesteps
17
+ self.schedule = schedule
18
+
19
+ def register_buffer(self, name, attr):
20
+ if type(attr) == torch.Tensor:
21
+ if attr.device != torch.device("cuda"):
22
+ attr = attr.to(torch.device("cuda"))
23
+ setattr(self, name, attr)
24
+
25
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
27
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
28
+ alphas_cumprod = self.model.alphas_cumprod
29
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
30
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
31
+
32
+ self.register_buffer('betas', to_torch(self.model.betas))
33
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
34
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
35
+
36
+ # calculations for diffusion q(x_t | x_{t-1}) and others
37
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
38
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
39
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
40
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
41
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
42
+
43
+ # ddim sampling parameters
44
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
45
+ ddim_timesteps=self.ddim_timesteps,
46
+ eta=ddim_eta,verbose=verbose)
47
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
48
+ self.register_buffer('ddim_alphas', ddim_alphas)
49
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
50
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
51
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
52
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
53
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
54
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
55
+
56
+ @torch.no_grad()
57
+ def sample(self,
58
+ S,
59
+ batch_size,
60
+ shape,
61
+ conditioning=None,
62
+ callback=None,
63
+ normals_sequence=None,
64
+ img_callback=None,
65
+ quantize_x0=False,
66
+ eta=0.,
67
+ mask=None,
68
+ x0=None,
69
+ temperature=1.,
70
+ noise_dropout=0.,
71
+ score_corrector=None,
72
+ corrector_kwargs=None,
73
+ verbose=True,
74
+ x_T=None,
75
+ log_every_t=100,
76
+ unconditional_guidance_scale=1.,
77
+ unconditional_conditioning=None,
78
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
79
+ **kwargs
80
+ ):
81
+ if conditioning is not None:
82
+ if isinstance(conditioning, dict):
83
+ ctmp = conditioning[list(conditioning.keys())[0]]
84
+ while isinstance(ctmp, list): ctmp = ctmp[0]
85
+ cbs = ctmp.shape[0]
86
+ if cbs != batch_size:
87
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
88
+ else:
89
+ if conditioning.shape[0] != batch_size:
90
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
91
+
92
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
93
+ # sampling
94
+ if len(shape)==3:
95
+ C, H, W = shape
96
+ size = (batch_size, C, H, W)
97
+ else:
98
+ C, T = shape
99
+ size = (batch_size, C, T)
100
+ # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
101
+
102
+ samples, intermediates = self.ddim_sampling(conditioning, size,
103
+ callback=callback,
104
+ img_callback=img_callback,
105
+ quantize_denoised=quantize_x0,
106
+ mask=mask, x0=x0,
107
+ ddim_use_original_steps=False,
108
+ noise_dropout=noise_dropout,
109
+ temperature=temperature,
110
+ score_corrector=score_corrector,
111
+ corrector_kwargs=corrector_kwargs,
112
+ x_T=x_T,
113
+ log_every_t=log_every_t,
114
+ unconditional_guidance_scale=unconditional_guidance_scale,
115
+ unconditional_conditioning=unconditional_conditioning,
116
+ )
117
+ return samples, intermediates
118
+
119
+ @torch.no_grad()
120
+ def ddim_sampling(self, cond, shape,
121
+ x_T=None, ddim_use_original_steps=False,
122
+ callback=None, timesteps=None, quantize_denoised=False,
123
+ mask=None, x0=None, img_callback=None, log_every_t=100,
124
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
125
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
126
+ device = self.model.betas.device
127
+ b = shape[0]
128
+ if x_T is None:
129
+ img = torch.randn(shape, device=device)
130
+ else:
131
+ img = x_T
132
+
133
+ if timesteps is None:
134
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
135
+ elif timesteps is not None and not ddim_use_original_steps:
136
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
137
+ timesteps = self.ddim_timesteps[:subset_end]
138
+
139
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
140
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
141
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
142
+
143
+ # iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
144
+
145
+ for i, step in enumerate(time_range):
146
+ index = total_steps - i - 1
147
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
148
+
149
+ if mask is not None:
150
+ assert x0 is not None
151
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
152
+ img = img_orig * mask + (1. - mask) * img
153
+
154
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
155
+ quantize_denoised=quantize_denoised, temperature=temperature,
156
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
157
+ corrector_kwargs=corrector_kwargs,
158
+ unconditional_guidance_scale=unconditional_guidance_scale,
159
+ unconditional_conditioning=unconditional_conditioning)
160
+ img, pred_x0 = outs
161
+ if callback: callback(i)
162
+ if img_callback: img_callback(pred_x0, i)
163
+
164
+ if index % log_every_t == 0 or index == total_steps - 1:
165
+ intermediates['x_inter'].append(img)
166
+ intermediates['pred_x0'].append(pred_x0)
167
+
168
+ return img, intermediates
169
+
170
+ @torch.no_grad()
171
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
172
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
173
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
174
+ b, *_, device = *x.shape, x.device
175
+
176
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
177
+ e_t = self.model.apply_model(x, t, c)
178
+ else:
179
+ x_in = torch.cat([x] * 2)
180
+ t_in = torch.cat([t] * 2)
181
+ if isinstance(c, dict):
182
+ assert isinstance(unconditional_conditioning, dict)
183
+ c_in = dict()
184
+ for k in c:
185
+ if isinstance(c[k], list):
186
+ c_in[k] = [torch.cat([
187
+ unconditional_conditioning[k][i],
188
+ c[k][i]]) for i in range(len(c[k]))]
189
+ else:
190
+ c_in[k] = torch.cat([
191
+ unconditional_conditioning[k],
192
+ c[k]])
193
+ elif isinstance(c, list):
194
+ c_in = list()
195
+ assert isinstance(unconditional_conditioning, list)
196
+ for i in range(len(c)):
197
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
198
+ else:
199
+ c_in = torch.cat([unconditional_conditioning, c]) # c/uc shape [b, seq_len=77, dim=1024]; c_in shape [b*2, seq_len, dim]
200
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
201
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
202
+
203
+ if score_corrector is not None:
204
+ assert self.model.parameterization == "eps"
205
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
206
+
207
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
208
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
209
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
210
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
211
+ # select parameters corresponding to the currently considered timestep
212
+ full_shape = (b,) + tuple([1 for dim in range(len(x.shape)-1)])
213
+ a_t = torch.full(full_shape, alphas[index], device=device)
214
+ a_prev = torch.full(full_shape, alphas_prev[index], device=device)
215
+ sigma_t = torch.full(full_shape, sigmas[index], device=device)
216
+ sqrt_one_minus_at = torch.full(full_shape, sqrt_one_minus_alphas[index],device=device)
217
+
218
+ # current prediction for x_0
219
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
220
+ if quantize_denoised:
221
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
222
+ # direction pointing to x_t
223
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
224
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
225
+ if noise_dropout > 0.:
226
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
227
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
228
+ return x_prev, pred_x0
229
+
230
+ @torch.no_grad()
231
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
232
+ # fast, but does not allow for exact reconstruction
233
+ # t serves as an index to gather the correct alphas
234
+ if use_original_steps:
235
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
236
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
237
+ else:
238
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
239
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
240
+
241
+ if noise is None:
242
+ noise = torch.randn_like(x0)
243
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
244
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
245
+
246
+ @torch.no_grad()
247
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
248
+ use_original_steps=False):
249
+
250
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
251
+ timesteps = timesteps[:t_start]
252
+
253
+ time_range = np.flip(timesteps)
254
+ total_steps = timesteps.shape[0]
255
+ x_dec = x_latent
256
+ for i, step in enumerate(time_range):
257
+ index = total_steps - i - 1
258
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
259
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
260
+ unconditional_guidance_scale=unconditional_guidance_scale,
261
+ unconditional_conditioning=unconditional_conditioning)
262
+ return x_dec
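The heart of p_sample_ddim above is a short closed-form update: predict x_0 from the noise estimate, then recombine it with a direction term and, for eta > 0, fresh noise. A minimal sketch with scalar schedule values; the alpha values, sigma_t = 0 and the latent shape below are assumptions for illustration only.

    import torch

    def ddim_step(x_t, e_t, a_t, a_prev, sigma_t=0.0):
        # Mirrors p_sample_ddim with temperature = 1 and no noise dropout.
        pred_x0 = (x_t - (1.0 - a_t) ** 0.5 * e_t) / a_t ** 0.5
        dir_xt = (1.0 - a_prev - sigma_t ** 2) ** 0.5 * e_t
        noise = sigma_t * torch.randn_like(x_t)
        return a_prev ** 0.5 * pred_x0 + dir_xt + noise, pred_x0

    x_t = torch.randn(1, 4, 10, 106)     # placeholder latent shape
    e_t = torch.randn_like(x_t)          # the model's noise prediction
    x_prev, pred_x0 = ddim_step(x_t, e_t, a_t=0.90, a_prev=0.95)  # sigma_t = 0: deterministic DDIM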
ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,1461 @@
 
 
 
 
1
+ """
2
+ wild mixture of
3
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+ import torch
9
+ import torch.nn as nn
10
+ import numpy as np
11
+ import pytorch_lightning as pl
12
+ from torch.optim.lr_scheduler import LambdaLR
13
+ from einops import rearrange, repeat
14
+ from contextlib import contextmanager
15
+ from functools import partial
16
+ from tqdm import tqdm
17
+ from torchvision.utils import make_grid
18
+ try:
19
+ from pytorch_lightning.utilities.distributed import rank_zero_only
20
+ except ImportError:
21
+ from pytorch_lightning.utilities import rank_zero_only # torch2
22
+
23
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
24
+ from ldm.modules.ema import LitEma
25
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
26
+ from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
27
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
28
+ from ldm.models.diffusion.ddim import DDIMSampler
29
+
30
+
31
+ __conditioning_keys__ = {'concat': 'c_concat',
32
+ 'crossattn': 'c_crossattn',
33
+ 'adm': 'y'}
34
+
35
+
36
+ def disabled_train(self, mode=True):
37
+ """Overwrite model.train with this function to make sure train/eval mode
38
+ does not change anymore."""
39
+ return self
40
+
41
+
42
+ def uniform_on_device(r1, r2, shape, device):
43
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
44
+
45
+
46
+ class DDPM(pl.LightningModule):
47
+ # classic DDPM with Gaussian diffusion, in image space
48
+ def __init__(self,
49
+ unet_config,
50
+ timesteps=1000,
51
+ beta_schedule="linear",
52
+ loss_type="l2",
53
+ ckpt_path=None,
54
+ ignore_keys=[],
55
+ load_only_unet=False,
56
+ monitor="val/loss",
57
+ use_ema=True,
58
+ first_stage_key="image",
59
+ image_size=256,
60
+ channels=3,
61
+ log_every_t=100,
62
+ clip_denoised=True,
63
+ linear_start=1e-4,
64
+ linear_end=2e-2,
65
+ cosine_s=8e-3,
66
+ given_betas=None,
67
+ original_elbo_weight=0.,
68
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
69
+ l_simple_weight=1.,
70
+ conditioning_key=None,
71
+ parameterization="eps", # all config files use "eps"
72
+ scheduler_config=None,
73
+ use_positional_encodings=False,
74
+ learn_logvar=False,
75
+ logvar_init=0.,
76
+ ):
77
+ super().__init__()
78
+ assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
79
+ self.parameterization = parameterization
80
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
81
+ self.cond_stage_model = None
82
+ self.clip_denoised = clip_denoised
83
+ self.log_every_t = log_every_t
84
+ self.first_stage_key = first_stage_key
85
+ self.image_size = image_size # try conv?
86
+ self.channels = channels
87
+ self.use_positional_encodings = use_positional_encodings
88
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
89
+ count_params(self.model, verbose=True)
90
+ self.use_ema = use_ema
91
+ if self.use_ema:
92
+ self.model_ema = LitEma(self.model)
93
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
94
+
95
+ self.use_scheduler = scheduler_config is not None
96
+ if self.use_scheduler:
97
+ self.scheduler_config = scheduler_config
98
+
99
+ self.v_posterior = v_posterior
100
+ self.original_elbo_weight = original_elbo_weight
101
+ self.l_simple_weight = l_simple_weight
102
+
103
+ if monitor is not None:
104
+ self.monitor = monitor
105
+ if ckpt_path is not None:
106
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
107
+
108
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
109
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
110
+
111
+ self.loss_type = loss_type
112
+
113
+ self.learn_logvar = learn_logvar
114
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
115
+ if self.learn_logvar:
116
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
117
+
118
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
119
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
120
+ if exists(given_betas):
121
+ betas = given_betas
122
+ else:
123
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
124
+ cosine_s=cosine_s)
125
+ alphas = 1. - betas
126
+ alphas_cumprod = np.cumprod(alphas, axis=0)
127
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
128
+
129
+ timesteps, = betas.shape
130
+ self.num_timesteps = int(timesteps)
131
+ self.linear_start = linear_start
132
+ self.linear_end = linear_end
133
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
134
+
135
+ to_torch = partial(torch.tensor, dtype=torch.float32)
136
+
137
+ self.register_buffer('betas', to_torch(betas))
138
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
139
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
140
+
141
+ # calculations for diffusion q(x_t | x_{t-1}) and others
142
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
143
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
144
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
145
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
146
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
147
+
148
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
149
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
150
+ 1. - alphas_cumprod) + self.v_posterior * betas
151
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
152
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
153
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
154
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
155
+ self.register_buffer('posterior_mean_coef1', to_torch(
156
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
157
+ self.register_buffer('posterior_mean_coef2', to_torch(
158
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
159
+
160
+ if self.parameterization == "eps":
161
+ lvlb_weights = self.betas ** 2 / (
162
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
163
+ elif self.parameterization == "x0":
164
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
165
+ else:
166
+ raise NotImplementedError("mu not supported")
167
+ # TODO how to choose this term
168
+ lvlb_weights[0] = lvlb_weights[1]
169
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
170
+ assert not torch.isnan(self.lvlb_weights).all()
171
+
172
+ @contextmanager
173
+ def ema_scope(self, context=None):
174
+ if self.use_ema:
175
+ self.model_ema.store(self.model.parameters())
176
+ self.model_ema.copy_to(self.model)
177
+ if context is not None:
178
+ print(f"{context}: Switched to EMA weights")
179
+ try:
180
+ yield None
181
+ finally:
182
+ if self.use_ema:
183
+ self.model_ema.restore(self.model.parameters())
184
+ if context is not None:
185
+ print(f"{context}: Restored training weights")
186
+
187
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
188
+ sd = torch.load(path, map_location="cpu")
189
+ if "state_dict" in list(sd.keys()):
190
+ sd = sd["state_dict"]
191
+ keys = list(sd.keys())
192
+ for k in keys:
193
+ for ik in ignore_keys:
194
+ if k.startswith(ik):
195
+ print("Deleting key {} from state_dict.".format(k))
196
+ del sd[k]
197
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
198
+ sd, strict=False)
199
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
200
+ if len(missing) > 0:
201
+ print(f"Missing Keys: {missing}")
202
+ if len(unexpected) > 0:
203
+ print(f"Unexpected Keys: {unexpected}")
204
+
205
+ def q_mean_variance(self, x_start, t):
206
+ """
207
+ Get the distribution q(x_t | x_0).
208
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
209
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
210
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
211
+ """
212
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
213
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
214
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
215
+ return mean, variance, log_variance
216
+
217
+ def predict_start_from_noise(self, x_t, t, noise):
218
+ return (
219
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
220
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
221
+ )
222
+
223
+ def q_posterior(self, x_start, x_t, t):
224
+ posterior_mean = (
225
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
226
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
227
+ )
228
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
229
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
230
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
231
+
232
+ def p_mean_variance(self, x, t, clip_denoised: bool):
233
+ model_out = self.model(x, t)
234
+ if self.parameterization == "eps":
235
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
236
+ elif self.parameterization == "x0":
237
+ x_recon = model_out
238
+ if clip_denoised:
239
+ x_recon.clamp_(-1., 1.)
240
+
241
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
242
+ return model_mean, posterior_variance, posterior_log_variance
243
+
244
+ @torch.no_grad()
245
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
246
+ b, *_, device = *x.shape, x.device
247
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
248
+ noise = noise_like(x.shape, device, repeat_noise)
249
+ # no noise when t == 0
250
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
251
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
252
+
253
+ @torch.no_grad()
254
+ def p_sample_loop(self, shape, return_intermediates=False):
255
+ device = self.betas.device
256
+ b = shape[0]
257
+ img = torch.randn(shape, device=device)
258
+ intermediates = [img]
259
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
260
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
261
+ clip_denoised=self.clip_denoised)
262
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
263
+ intermediates.append(img)
264
+ if return_intermediates:
265
+ return img, intermediates
266
+ return img
267
+
268
+ @torch.no_grad()
269
+ def sample(self, batch_size=16, return_intermediates=False):
270
+ image_size = self.image_size
271
+ channels = self.channels
272
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
273
+ return_intermediates=return_intermediates)
274
+
275
+ def q_sample(self, x_start, t, noise=None):
276
+ noise = default(noise, lambda: torch.randn_like(x_start))
277
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
278
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
279
+
280
+ def get_loss(self, pred, target, mean=True):
281
+ if self.loss_type == 'l1':
282
+ loss = (target - pred).abs()
283
+ if mean:
284
+ loss = loss.mean()
285
+ elif self.loss_type == 'l2':
286
+ if mean:
287
+ loss = torch.nn.functional.mse_loss(target, pred)
288
+ else:
289
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
290
+ else:
291
+ raise NotImplementedError("unknown loss type '{loss_type}'")
292
+
293
+ return loss
294
+
295
+ def p_losses(self, x_start, t, noise=None):
296
+ noise = default(noise, lambda: torch.randn_like(x_start))
297
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
298
+ model_out = self.model(x_noisy, t)
299
+
300
+ loss_dict = {}
301
+ if self.parameterization == "eps":
302
+ target = noise
303
+ elif self.parameterization == "x0":
304
+ target = x_start
305
+ else:
306
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
307
+
308
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
309
+
310
+ log_prefix = 'train' if self.training else 'val'
311
+
312
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
313
+ loss_simple = loss.mean() * self.l_simple_weight
314
+
315
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
316
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
317
+
318
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
319
+
320
+ loss_dict.update({f'{log_prefix}/loss': loss})
321
+
322
+ return loss, loss_dict
323
+
324
+ def forward(self, x, *args, **kwargs):
325
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
326
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
327
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
328
+ return self.p_losses(x, t, *args, **kwargs)
329
+
330
+ def get_input(self, batch, k):
331
+ x = batch[k]
332
+ if self.channels > 0:# use 4d input
333
+ if len(x.shape) == 3:
334
+ x = x[..., None]
335
+ x = rearrange(x, 'b h w c -> b c h w')
336
+ x = x.to(memory_format=torch.contiguous_format).float()
337
+ return x
338
+
339
+ def shared_step(self, batch):
340
+ x = self.get_input(batch, self.first_stage_key)
341
+ loss, loss_dict = self(x)
342
+ return loss, loss_dict
343
+
344
+ def training_step(self, batch, batch_idx):
345
+ loss, loss_dict = self.shared_step(batch)
346
+
347
+ self.log_dict(loss_dict, prog_bar=True,
348
+ logger=True, on_step=True, on_epoch=True)
349
+
350
+ self.log('epoch', float(self.trainer.current_epoch))
351
+ self.log("global_step", self.global_step,
352
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
353
+
354
+ if self.use_scheduler:
355
+ lr = self.optimizers().param_groups[0]['lr']
356
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
357
+
358
+ return loss
359
+
360
+ @torch.no_grad()
361
+ def validation_step(self, batch, batch_idx):
362
+ _, loss_dict_no_ema = self.shared_step(batch)
363
+ with self.ema_scope():
364
+ _, loss_dict_ema = self.shared_step(batch)
365
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
366
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True,sync_dist=True)
367
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True,sync_dist=True)
368
+
369
+ def on_train_batch_end(self, *args, **kwargs):
370
+ if self.use_ema:
371
+ self.model_ema(self.model)
372
+
373
+ def _get_rows_from_list(self, samples):
374
+ n_imgs_per_row = len(samples)
375
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
376
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
377
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
378
+ return denoise_grid
379
+
380
+ @torch.no_grad()
381
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
382
+ log = dict()
383
+ x = self.get_input(batch, self.first_stage_key)
384
+ N = min(x.shape[0], N)
385
+ n_row = min(x.shape[0], n_row)
386
+ x = x.to(self.device)[:N]
387
+ log["inputs"] = x
388
+
389
+ # get diffusion row
390
+ diffusion_row = list()
391
+ x_start = x[:n_row]
392
+
393
+ for t in range(self.num_timesteps):
394
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
395
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
396
+ t = t.to(self.device).long()
397
+ noise = torch.randn_like(x_start)
398
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
399
+ diffusion_row.append(x_noisy)
400
+
401
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
402
+
403
+ if sample:
404
+ # get denoise row
405
+ with self.ema_scope("Plotting"):
406
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
407
+
408
+ log["samples"] = samples
409
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
410
+
411
+ if return_keys:
412
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
413
+ return log
414
+ else:
415
+ return {key: log[key] for key in return_keys}
416
+ return log
417
+
418
+ def configure_optimizers(self):
419
+ lr = self.learning_rate
420
+ params = list(self.model.parameters())
421
+ if self.learn_logvar:
422
+ params = params + [self.logvar]
423
+ opt = torch.optim.AdamW(params, lr=lr)
424
+ return opt
425
+
426
+
427
+ class LatentDiffusion(DDPM):
428
+ """main class"""
429
+ def __init__(self,
430
+ first_stage_config,
431
+ cond_stage_config,
432
+ num_timesteps_cond=None,
433
+ cond_stage_key="image",# 'caption' for txt2image, 'masked_image' for inpainting
434
+ cond_stage_trainable=False,
435
+ concat_mode=True,# true for inpainting
436
+ cond_stage_forward=None,
437
+ conditioning_key=None, # 'crossattn' for txt2image, None for inpainting
438
+ scale_factor=1.0,
439
+ scale_by_std=False,
440
+ *args, **kwargs):
441
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
442
+ self.scale_by_std = scale_by_std
443
+ assert self.num_timesteps_cond <= kwargs['timesteps']
444
+ # for backwards compatibility after implementation of DiffusionWrapper
445
+ if conditioning_key is None:
446
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
447
+ if cond_stage_config == '__is_unconditional__':
448
+ conditioning_key = None
449
+ ckpt_path = kwargs.pop("ckpt_path", None)
450
+ ignore_keys = kwargs.pop("ignore_keys", [])
451
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
452
+ self.concat_mode = concat_mode
453
+ self.cond_stage_trainable = cond_stage_trainable
454
+ self.cond_stage_key = cond_stage_key
455
+ try:
456
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
457
+ except:
458
+ self.num_downs = 0
459
+ if not scale_by_std:
460
+ self.scale_factor = scale_factor
461
+ else:
462
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
463
+ self.instantiate_first_stage(first_stage_config)
464
+ self.instantiate_cond_stage(cond_stage_config)
465
+ self.cond_stage_forward = cond_stage_forward
466
+ self.clip_denoised = False
467
+ self.bbox_tokenizer = None
468
+
469
+ self.restarted_from_ckpt = False
470
+ if ckpt_path is not None:
471
+ self.init_from_ckpt(ckpt_path, ignore_keys)
472
+ self.restarted_from_ckpt = True
473
+
474
+ def make_cond_schedule(self, ):
475
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
476
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
477
+ self.cond_ids[:self.num_timesteps_cond] = ids
478
+
479
+ @rank_zero_only
480
+ @torch.no_grad()
481
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
482
+ # only for very first batch
483
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
484
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
485
+ # set rescale weight to 1./std of encodings
486
+ print("### USING STD-RESCALING ###")
487
+ x = super().get_input(batch, self.first_stage_key)
488
+ x = x.to(self.device)
489
+ encoder_posterior = self.encode_first_stage(x)
490
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
491
+ del self.scale_factor
492
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
493
+ print(f"setting self.scale_factor to {self.scale_factor}")
494
+ print("### USING STD-RESCALING ###")
495
+
496
+ def register_schedule(self,
497
+ given_betas=None, beta_schedule="linear", timesteps=1000,
498
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
499
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
500
+
501
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
502
+ if self.shorten_cond_schedule:
503
+ self.make_cond_schedule()
504
+
505
+ def instantiate_first_stage(self, config):
506
+ model = instantiate_from_config(config)
507
+ self.first_stage_model = model.eval()
508
+ self.first_stage_model.train = disabled_train
509
+ for param in self.first_stage_model.parameters():
510
+ param.requires_grad = False
511
+
512
+ def instantiate_cond_stage(self, config):
513
+ if not self.cond_stage_trainable:
514
+ if config == "__is_first_stage__":# inpaint
515
+ print("Using first stage also as cond stage.")
516
+ self.cond_stage_model = self.first_stage_model
517
+ elif config == "__is_unconditional__":
518
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
519
+ self.cond_stage_model = None
520
+ # self.be_unconditional = True
521
+ else:
522
+ model = instantiate_from_config(config)
523
+ self.cond_stage_model = model.eval()
524
+ self.cond_stage_model.train = disabled_train
525
+ for param in self.cond_stage_model.parameters():
526
+ param.requires_grad = False
527
+ else:
528
+ assert config != '__is_first_stage__'
529
+ assert config != '__is_unconditional__'
530
+ model = instantiate_from_config(config)
531
+ self.cond_stage_model = model
532
+
533
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
534
+ denoise_row = []
535
+ for zd in tqdm(samples, desc=desc):
536
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
537
+ force_not_quantize=force_no_decoder_quantization))
538
+ n_imgs_per_row = len(denoise_row)
539
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
540
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
541
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
542
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
543
+ return denoise_grid
544
+
545
+ def get_first_stage_encoding(self, encoder_posterior):
546
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
547
+ z = encoder_posterior.sample()
548
+ elif isinstance(encoder_posterior, torch.Tensor):
549
+ z = encoder_posterior
550
+ else:
551
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
552
+ return self.scale_factor * z
553
+
554
+ def get_learned_conditioning(self, c):
555
+ if self.cond_stage_forward is None:
556
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
557
+ c = self.cond_stage_model.encode(c)
558
+ if isinstance(c, DiagonalGaussianDistribution):
559
+ c = c.mode()
560
+ else:
561
+ c = self.cond_stage_model(c)
562
+ else:
563
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
564
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
565
+ return c
566
+
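# Sketch (illustrative): get_learned_conditioning() dispatches to the conditioning
# encoder, preferring its encode() method when one exists. DummyTextEncoder below is a
# hypothetical stand-in for the real cond_stage_model, only meant to show the shapes.
import torch

class DummyTextEncoder(torch.nn.Module):
    def encode(self, captions):
        # real encoders (CLIP/T5-style) return a [B, seq_len, context_dim] tensor
        return torch.randn(len(captions), 77, 1024)

encoder = DummyTextEncoder()
c = encoder.encode(["a dog barking", "rain on a tin roof"])   # used as cross-attention context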
567
+ def meshgrid(self, h, w):
568
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
569
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
570
+
571
+ arr = torch.cat([y, x], dim=-1)
572
+ return arr
573
+
574
+ def delta_border(self, h, w):
575
+ """
576
+ :param h: height
577
+ :param w: width
578
+ :return: normalized distance to image border,
579
+ with min distance = 0 at border and max dist = 0.5 at image center
580
+ """
581
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
582
+ arr = self.meshgrid(h, w) / lower_right_corner
583
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
584
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
585
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
586
+ return edge_dist
587
+
588
+ def get_weighting(self, h, w, Ly, Lx, device):
589
+ weighting = self.delta_border(h, w)
590
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
591
+ self.split_input_params["clip_max_weight"], )
592
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
593
+
594
+ if self.split_input_params["tie_braker"]:
595
+ L_weighting = self.delta_border(Ly, Lx)
596
+ L_weighting = torch.clip(L_weighting,
597
+ self.split_input_params["clip_min_tie_weight"],
598
+ self.split_input_params["clip_max_tie_weight"])
599
+
600
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
601
+ weighting = weighting * L_weighting
602
+ return weighting
603
+
604
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
605
+ """
606
+ :param x: img of size (bs, c, h, w)
607
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
608
+ """
609
+ bs, nc, h, w = x.shape
610
+
611
+ # number of crops in image
612
+ Ly = (h - kernel_size[0]) // stride[0] + 1
613
+ Lx = (w - kernel_size[1]) // stride[1] + 1
614
+
615
+ if uf == 1 and df == 1:
616
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
617
+ unfold = torch.nn.Unfold(**fold_params)
618
+
619
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
620
+
621
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
622
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
623
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
624
+
625
+ elif uf > 1 and df == 1:
626
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
627
+ unfold = torch.nn.Unfold(**fold_params)
628
+
629
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
630
+ dilation=1, padding=0,
631
+ stride=(stride[0] * uf, stride[1] * uf))
632
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
633
+
634
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
635
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
636
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
637
+
638
+ elif df > 1 and uf == 1:
639
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
640
+ unfold = torch.nn.Unfold(**fold_params)
641
+
642
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
643
+ dilation=1, padding=0,
644
+ stride=(stride[0] // df, stride[1] // df))
645
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
646
+
647
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
648
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
649
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
650
+
651
+ else:
652
+ raise NotImplementedError
653
+
654
+ return fold, unfold, normalization, weighting
655
+
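# Sketch (illustrative): the fold/unfold pair tiles an input into overlapping crops and
# later stitches them back, dividing by `normalization` (the folded per-pixel weights)
# so overlapping regions are not double-counted. With uniform weights the round trip
# is exact; get_fold_unfold() uses a border-distance weighting instead of ones.
import torch

x = torch.randn(1, 1, 8, 8)
ks, stride = (4, 4), (2, 2)
unfold = torch.nn.Unfold(kernel_size=ks, stride=stride)
fold = torch.nn.Fold(output_size=(8, 8), kernel_size=ks, stride=stride)

patches = unfold(x)                     # (1, 1*4*4, L) with L = 3*3 overlapping crops
weighting = torch.ones_like(patches)
normalization = fold(weighting)         # total weight landing on each output pixel
x_rec = fold(patches * weighting) / normalization
assert torch.allclose(x, x_rec, atol=1e-5)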
656
+ @torch.no_grad()
657
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
658
+ cond_key=None, return_original_cond=False, bs=None):
659
+ x = super().get_input(batch, k)
660
+ if bs is not None:
661
+ x = x[:bs]
662
+ x = x.to(self.device)
663
+ encoder_posterior = self.encode_first_stage(x)
664
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
665
+
666
+ if self.model.conditioning_key is not None:
667
+ if cond_key is None:
668
+ cond_key = self.cond_stage_key
669
+ if cond_key != self.first_stage_key:# cond_key is not image. for inpaint it's masked_img
670
+ if cond_key in ['caption', 'coordinates_bbox']:
671
+ xc = batch[cond_key]
672
+ elif cond_key == 'class_label':
673
+ xc = batch
674
+ else:
675
+ xc = super().get_input(batch, cond_key).to(self.device)
676
+ else:
677
+ xc = x
678
+ if not self.cond_stage_trainable or force_c_encode:
679
+ if isinstance(xc, dict) or isinstance(xc, list):
680
+ # import pudb; pudb.set_trace()
681
+ c = self.get_learned_conditioning(xc)
682
+ else:
683
+ c = self.get_learned_conditioning(xc.to(self.device))
684
+ else:
685
+ c = xc
686
+ if bs is not None:
687
+ c = c[:bs]
688
+
689
+ if self.use_positional_encodings:
690
+ pos_x, pos_y = self.compute_latent_shifts(batch)
691
+ ckey = __conditioning_keys__[self.model.conditioning_key]
692
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
693
+
694
+ else:
695
+ c = None
696
+ xc = None
697
+ if self.use_positional_encodings:
698
+ pos_x, pos_y = self.compute_latent_shifts(batch)
699
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
700
+ out = [z, c]
701
+ if return_first_stage_outputs:
702
+ xrec = self.decode_first_stage(z)
703
+ out.extend([x, xrec])
704
+ if return_original_cond:
705
+ out.append(xc)
706
+ return out
707
+
708
+ @torch.no_grad()
709
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
710
+ if predict_cids:
711
+ if z.dim() == 4:
712
+ z = torch.argmax(z.exp(), dim=1).long()
713
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
714
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
715
+
716
+ z = 1. / self.scale_factor * z
717
+
718
+ if hasattr(self, "split_input_params"):
719
+ if self.split_input_params["patch_distributed_vq"]:
720
+ ks = self.split_input_params["ks"] # eg. (128, 128)
721
+ stride = self.split_input_params["stride"] # eg. (64, 64)
722
+ uf = self.split_input_params["vqf"]
723
+ bs, nc, h, w = z.shape
724
+ if ks[0] > h or ks[1] > w:
725
+ ks = (min(ks[0], h), min(ks[1], w))
726
+ print("reducing Kernel")
727
+
728
+ if stride[0] > h or stride[1] > w:
729
+ stride = (min(stride[0], h), min(stride[1], w))
730
+ print("reducing stride")
731
+
732
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
733
+
734
+ z = unfold(z) # (bn, nc * prod(**ks), L)
735
+ # 1. Reshape to img shape
736
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
737
+
738
+ # 2. apply model loop over last dim
739
+ if isinstance(self.first_stage_model, VQModelInterface):
740
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
741
+ force_not_quantize=predict_cids or force_not_quantize)
742
+ for i in range(z.shape[-1])]
743
+ else:
744
+
745
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
746
+ for i in range(z.shape[-1])]
747
+
748
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
749
+ o = o * weighting
750
+ # Reverse 1. reshape to img shape
751
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
752
+ # stitch crops together
753
+ decoded = fold(o)
754
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
755
+ return decoded
756
+ else:
757
+ if isinstance(self.first_stage_model, VQModelInterface):
758
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
759
+ else:
760
+ return self.first_stage_model.decode(z)
761
+
762
+ else:
763
+ if isinstance(self.first_stage_model, VQModelInterface):
764
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
765
+ else:
766
+ return self.first_stage_model.decode(z)
767
+
768
+ # same as above but without decorator
769
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
770
+ if predict_cids:
771
+ if z.dim() == 4:
772
+ z = torch.argmax(z.exp(), dim=1).long()
773
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
774
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
775
+
776
+ z = 1. / self.scale_factor * z
777
+
778
+ if hasattr(self, "split_input_params"):
779
+ if self.split_input_params["patch_distributed_vq"]:
780
+ ks = self.split_input_params["ks"] # eg. (128, 128)
781
+ stride = self.split_input_params["stride"] # eg. (64, 64)
782
+ uf = self.split_input_params["vqf"]
783
+ bs, nc, h, w = z.shape
784
+ if ks[0] > h or ks[1] > w:
785
+ ks = (min(ks[0], h), min(ks[1], w))
786
+ print("reducing Kernel")
787
+
788
+ if stride[0] > h or stride[1] > w:
789
+ stride = (min(stride[0], h), min(stride[1], w))
790
+ print("reducing stride")
791
+
792
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
793
+
794
+ z = unfold(z) # (bn, nc * prod(**ks), L)
795
+ # 1. Reshape to img shape
796
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
797
+
798
+ # 2. apply model loop over last dim
799
+ if isinstance(self.first_stage_model, VQModelInterface):
800
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
801
+ force_not_quantize=predict_cids or force_not_quantize)
802
+ for i in range(z.shape[-1])]
803
+ else:
804
+
805
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
806
+ for i in range(z.shape[-1])]
807
+
808
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
809
+ o = o * weighting
810
+ # Reverse 1. reshape to img shape
811
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
812
+ # stitch crops together
813
+ decoded = fold(o)
814
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
815
+ return decoded
816
+ else:
817
+ if isinstance(self.first_stage_model, VQModelInterface):
818
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
819
+ else:
820
+ return self.first_stage_model.decode(z)
821
+
822
+ else:
823
+ if isinstance(self.first_stage_model, VQModelInterface):
824
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
825
+ else:
826
+ return self.first_stage_model.decode(z)
827
+
828
+ @torch.no_grad()
829
+ def encode_first_stage(self, x):
830
+ if hasattr(self, "split_input_params"):
831
+ if self.split_input_params["patch_distributed_vq"]:
832
+ ks = self.split_input_params["ks"] # eg. (128, 128)
833
+ stride = self.split_input_params["stride"] # eg. (64, 64)
834
+ df = self.split_input_params["vqf"]
835
+ self.split_input_params['original_image_size'] = x.shape[-2:]
836
+ bs, nc, h, w = x.shape
837
+ if ks[0] > h or ks[1] > w:
838
+ ks = (min(ks[0], h), min(ks[1], w))
839
+ print("reducing Kernel")
840
+
841
+ if stride[0] > h or stride[1] > w:
842
+ stride = (min(stride[0], h), min(stride[1], w))
843
+ print("reducing stride")
844
+
845
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
846
+ z = unfold(x) # (bn, nc * prod(**ks), L)
847
+ # Reshape to img shape
848
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
849
+
850
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
851
+ for i in range(z.shape[-1])]
852
+
853
+ o = torch.stack(output_list, axis=-1)
854
+ o = o * weighting
855
+
856
+ # Reverse reshape to img shape
857
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
858
+ # stitch crops together
859
+ decoded = fold(o)
860
+ decoded = decoded / normalization
861
+ return decoded
862
+
863
+ else:
864
+ return self.first_stage_model.encode(x)
865
+ else:
866
+ return self.first_stage_model.encode(x)
867
+
868
+ def shared_step(self, batch, **kwargs):
869
+ x, c = self.get_input(batch, self.first_stage_key)
870
+ loss = self(x, c)
871
+ return loss
872
+
873
+ def forward(self, x, c, *args, **kwargs):
874
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
875
+ if self.model.conditioning_key is not None:
876
+ assert c is not None
877
+ if self.cond_stage_trainable:# true when use text
878
+ c = self.get_learned_conditioning(c) # c: string list -> [B, T, Context_dim]
879
+ if self.shorten_cond_schedule: # TODO: drop this option
880
+ tc = self.cond_ids[t].to(self.device)
881
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
882
+ return self.p_losses(x, c, t, *args, **kwargs)
883
+
884
+ def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
885
+ def rescale_bbox(bbox):
886
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
887
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
888
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
889
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
890
+ return x0, y0, w, h
891
+
892
+ return [rescale_bbox(b) for b in bboxes]
893
+
894
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
895
+
896
+ if isinstance(cond, dict):
897
+ # hybrid case, cond is expected to be a dict
898
+ pass
899
+ else:
900
+ if not isinstance(cond, list):
901
+ cond = [cond]
902
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
903
+ cond = {key: cond}
904
+
905
+ if hasattr(self, "split_input_params"):
906
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
907
+ assert not return_ids
908
+ ks = self.split_input_params["ks"] # eg. (128, 128)
909
+ stride = self.split_input_params["stride"] # eg. (64, 64)
910
+
911
+ h, w = x_noisy.shape[-2:]
912
+
913
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
914
+
915
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
916
+ # Reshape to img shape
917
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
918
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
919
+
920
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
921
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
922
+ c_key = next(iter(cond.keys())) # get key
923
+ c = next(iter(cond.values())) # get value
924
+ assert (len(c) == 1) # todo extend to list with more than one elem
925
+ c = c[0] # get element
926
+
927
+ c = unfold(c)
928
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
929
+
930
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
931
+
932
+ elif self.cond_stage_key == 'coordinates_bbox':
933
+ assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
934
+
935
+ # assuming padding of unfold is always 0 and its dilation is always 1
936
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
937
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
938
+ # as we are operating on latents, we need the factor from the original image size to the
939
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
940
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
941
+ rescale_latent = 2 ** (num_downs)
942
+
943
+ # get top left positions of patches as conforming for the bbox tokenizer, therefore we
944
+ # need to rescale the tl patch coordinates to be in between (0,1)
945
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
946
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
947
+ for patch_nr in range(z.shape[-1])]
948
+
949
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
950
+ patch_limits = [(x_tl, y_tl,
951
+ rescale_latent * ks[0] / full_img_w,
952
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
953
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
954
+
955
+ # tokenize crop coordinates for the bounding boxes of the respective patches
956
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
957
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
958
+ print(patch_limits_tknzd[0].shape)
959
+ # cut tknzd crop position from conditioning
960
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
961
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
962
+ print(cut_cond.shape)
963
+
964
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
965
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
966
+ print(adapted_cond.shape)
967
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
968
+ print(adapted_cond.shape)
969
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
970
+ print(adapted_cond.shape)
971
+
972
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
973
+
974
+ else:
975
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
976
+
977
+ # apply model by loop over crops
978
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
979
+ assert not isinstance(output_list[0],
980
+ tuple) # todo can't deal with multiple model outputs; check this never happens
981
+
982
+ o = torch.stack(output_list, axis=-1)
983
+ o = o * weighting
984
+ # Reverse reshape to img shape
985
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
986
+ # stitch crops together
987
+ x_recon = fold(o) / normalization
988
+
989
+ else:
990
+ x_recon = self.model(x_noisy, t, **cond)
991
+
992
+ if isinstance(x_recon, tuple) and not return_ids:
993
+ return x_recon[0]
994
+ else:
995
+ return x_recon
996
+
997
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
998
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
999
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
1000
+
1001
+ def _prior_bpd(self, x_start):
1002
+ """
1003
+ Get the prior KL term for the variational lower-bound, measured in
1004
+ bits-per-dim.
1005
+ This term can't be optimized, as it only depends on the encoder.
1006
+ :param x_start: the [N x C x ...] tensor of inputs.
1007
+ :return: a batch of [N] KL values (in bits), one per batch element.
1008
+ """
1009
+ batch_size = x_start.shape[0]
1010
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1011
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1012
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
1013
+ return mean_flat(kl_prior) / np.log(2.0)
1014
+
1015
+ def p_losses(self, x_start, cond, t, noise=None):
1016
+ noise = default(noise, lambda: torch.randn_like(x_start))
1017
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
1018
+ model_output = self.apply_model(x_noisy, t, cond)
1019
+
1020
+ loss_dict = {}
1021
+ prefix = 'train' if self.training else 'val'
1022
+
1023
+ if self.parameterization == "x0":
1024
+ target = x_start
1025
+ elif self.parameterization == "eps":
1026
+ target = noise
1027
+ else:
1028
+ raise NotImplementedError()
1029
+
1030
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
1031
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
1032
+
1033
+ logvar_t = self.logvar[t].to(self.device)
1034
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
1035
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
1036
+ if self.learn_logvar:
1037
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
1038
+ loss_dict.update({'logvar': self.logvar.data.mean()})
1039
+
1040
+ loss = self.l_simple_weight * loss.mean()
1041
+
1042
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
1043
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
1044
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
1045
+ loss += (self.original_elbo_weight * loss_vlb)
1046
+ loss_dict.update({f'{prefix}/loss': loss})
1047
+
1048
+ return loss, loss_dict
1049
+
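# Note (illustrative summary of p_losses above): the optimized objective is roughly
#   loss = l_simple_weight * mean(loss_simple / exp(logvar_t) + logvar_t)
#        + original_elbo_weight * mean(lvlb_weights[t] * loss_vlb)
# where loss_simple and loss_vlb are the same per-element error against the eps/x0 target,
# just weighted differently; with logvar == 0 and original_elbo_weight == 0 this reduces
# to a plain weighted MSE.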
1050
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
1051
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
1052
+ t_in = t
1053
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
1054
+
1055
+ if score_corrector is not None:
1056
+ assert self.parameterization == "eps"
1057
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
1058
+
1059
+ if return_codebook_ids:
1060
+ model_out, logits = model_out
1061
+
1062
+ if self.parameterization == "eps":
1063
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
1064
+ elif self.parameterization == "x0":
1065
+ x_recon = model_out
1066
+ else:
1067
+ raise NotImplementedError()
1068
+
1069
+ if clip_denoised:
1070
+ x_recon.clamp_(-1., 1.)
1071
+ if quantize_denoised:
1072
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
1073
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
1074
+ if return_codebook_ids:
1075
+ return model_mean, posterior_variance, posterior_log_variance, logits
1076
+ elif return_x0:
1077
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
1078
+ else:
1079
+ return model_mean, posterior_variance, posterior_log_variance
1080
+
1081
+ @torch.no_grad()
1082
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
1083
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
1084
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
1085
+ b, *_, device = *x.shape, x.device
1086
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
1087
+ return_codebook_ids=return_codebook_ids,
1088
+ quantize_denoised=quantize_denoised,
1089
+ return_x0=return_x0,
1090
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1091
+ if return_codebook_ids:
1092
+ raise DeprecationWarning("Support dropped.")
1093
+ model_mean, _, model_log_variance, logits = outputs
1094
+ elif return_x0:
1095
+ model_mean, _, model_log_variance, x0 = outputs
1096
+ else:
1097
+ model_mean, _, model_log_variance = outputs
1098
+
1099
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
1100
+ if noise_dropout > 0.:
1101
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1102
+ # no noise when t == 0
1103
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
1104
+
1105
+ if return_codebook_ids:
1106
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
1107
+ if return_x0:
1108
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
1109
+ else:
1110
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
1111
+
1112
+ @torch.no_grad()
1113
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
1114
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
1115
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
1116
+ log_every_t=None):
1117
+ if not log_every_t:
1118
+ log_every_t = self.log_every_t # 100
1119
+ timesteps = self.num_timesteps
1120
+ if batch_size is not None:
1121
+ b = batch_size if batch_size is not None else shape[0]
1122
+ shape = [batch_size] + list(shape)
1123
+ else:
1124
+ b = batch_size = shape[0]
1125
+ if x_T is None:
1126
+ img = torch.randn(shape, device=self.device)
1127
+ else:
1128
+ img = x_T
1129
+ intermediates = []
1130
+ if cond is not None:
1131
+ if isinstance(cond, dict):
1132
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1133
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1134
+ else:
1135
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1136
+
1137
+ if start_T is not None:
1138
+ timesteps = min(timesteps, start_T)
1139
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1140
+ total=timesteps) if verbose else reversed(
1141
+ range(0, timesteps))
1142
+ if type(temperature) == float:
1143
+ temperature = [temperature] * timesteps
1144
+
1145
+ for i in iterator:
1146
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1147
+ if self.shorten_cond_schedule:
1148
+ assert self.model.conditioning_key != 'hybrid'
1149
+ tc = self.cond_ids[ts].to(cond.device)
1150
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1151
+
1152
+ img, x0_partial = self.p_sample(img, cond, ts,
1153
+ clip_denoised=self.clip_denoised,
1154
+ quantize_denoised=quantize_denoised, return_x0=True,
1155
+ temperature=temperature[i], noise_dropout=noise_dropout,
1156
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1157
+ if mask is not None:
1158
+ assert x0 is not None
1159
+ img_orig = self.q_sample(x0, ts)
1160
+ img = img_orig * mask + (1. - mask) * img
1161
+
1162
+ if i % log_every_t == 0 or i == timesteps - 1:
1163
+ intermediates.append(x0_partial)
1164
+ if callback: callback(i)
1165
+ if img_callback: img_callback(img, i)
1166
+ return img, intermediates
1167
+
1168
+ @torch.no_grad()
1169
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
1170
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1171
+ mask=None, x0=None, img_callback=None, start_T=None,
1172
+ log_every_t=None):
1173
+
1174
+ if not log_every_t:
1175
+ log_every_t = self.log_every_t
1176
+ device = self.betas.device
1177
+ b = shape[0]
1178
+ if x_T is None:
1179
+ img = torch.randn(shape, device=device)
1180
+ else:
1181
+ img = x_T
1182
+
1183
+ intermediates = [img]
1184
+ if timesteps is None:
1185
+ timesteps = self.num_timesteps
1186
+
1187
+ if start_T is not None:
1188
+ timesteps = min(timesteps, start_T)
1189
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1190
+ range(0, timesteps))
1191
+
1192
+ if mask is not None:
1193
+ assert x0 is not None
1194
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1195
+
1196
+ for i in iterator:
1197
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
1198
+ if self.shorten_cond_schedule:
1199
+ assert self.model.conditioning_key != 'hybrid'
1200
+ tc = self.cond_ids[ts].to(cond.device)
1201
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1202
+
1203
+ img = self.p_sample(img, cond, ts,
1204
+ clip_denoised=self.clip_denoised,
1205
+ quantize_denoised=quantize_denoised)
1206
+ if mask is not None:
1207
+ img_orig = self.q_sample(x0, ts)
1208
+ img = img_orig * mask + (1. - mask) * img
1209
+
1210
+ if i % log_every_t == 0 or i == timesteps - 1:
1211
+ intermediates.append(img)
1212
+ if callback: callback(i)
1213
+ if img_callback: img_callback(img, i)
1214
+
1215
+ if return_intermediates:
1216
+ return img, intermediates
1217
+ return img
1218
+
1219
+ @torch.no_grad()
1220
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1221
+ verbose=True, timesteps=None, quantize_denoised=False,
1222
+ mask=None, x0=None, shape=None,**kwargs):
1223
+ if shape is None:
1224
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
1225
+ if cond is not None:
1226
+ if isinstance(cond, dict):
1227
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1228
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1229
+ else:
1230
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1231
+ return self.p_sample_loop(cond,
1232
+ shape,
1233
+ return_intermediates=return_intermediates, x_T=x_T,
1234
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1235
+ mask=mask, x0=x0)
1236
+
1237
+ @torch.no_grad()
1238
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
1239
+
1240
+ if ddim:
1241
+ ddim_sampler = DDIMSampler(self)
1242
+ shape = (self.channels, self.image_size, self.image_size)
1243
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
1244
+ shape,cond,verbose=False,**kwargs)
1245
+
1246
+ else:
1247
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1248
+ return_intermediates=True,**kwargs)
1249
+
1250
+ return samples, intermediates
1251
+
1252
+
1253
+ @torch.no_grad()
1254
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1255
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1256
+ plot_diffusion_rows=True, **kwargs):
1257
+
1258
+ use_ddim = ddim_steps is not None
1259
+
1260
+ log = dict()
1261
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1262
+ return_first_stage_outputs=True,
1263
+ force_c_encode=True,
1264
+ return_original_cond=True,
1265
+ bs=N)
1266
+ N = min(x.shape[0], N)
1267
+ n_row = min(x.shape[0], n_row)
1268
+ log["inputs"] = x
1269
+ log["reconstruction"] = xrec
1270
+ if self.model.conditioning_key is not None:
1271
+ if hasattr(self.cond_stage_model, "decode"):
1272
+ xc = self.cond_stage_model.decode(c)
1273
+ log["conditioning"] = xc
1274
+ elif self.cond_stage_key in ["caption"]:
1275
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
1276
+ log["conditioning"] = xc
1277
+ elif self.cond_stage_key == 'class_label':
1278
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
1279
+ log['conditioning'] = xc
1280
+ elif isimage(xc):
1281
+ log["conditioning"] = xc
1282
+ if ismap(xc):
1283
+ log["original_conditioning"] = self.to_rgb(xc)
1284
+
1285
+ if plot_diffusion_rows:
1286
+ # get diffusion row
1287
+ diffusion_row = list()
1288
+ z_start = z[:n_row]
1289
+ for t in range(self.num_timesteps):
1290
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1291
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1292
+ t = t.to(self.device).long()
1293
+ noise = torch.randn_like(z_start)
1294
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1295
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1296
+
1297
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1298
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1299
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1300
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1301
+ log["diffusion_row"] = diffusion_grid
1302
+
1303
+ if sample:
1304
+ # get denoise row
1305
+ with self.ema_scope("Plotting"):
1306
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1307
+ ddim_steps=ddim_steps,eta=ddim_eta)
1308
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1309
+ x_samples = self.decode_first_stage(samples)
1310
+ log["samples"] = x_samples
1311
+ if plot_denoise_rows:
1312
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1313
+ log["denoise_row"] = denoise_grid
1314
+
1315
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1316
+ self.first_stage_model, IdentityFirstStage):
1317
+ # also display when quantizing x0 while sampling
1318
+ with self.ema_scope("Plotting Quantized Denoised"):
1319
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1320
+ ddim_steps=ddim_steps,eta=ddim_eta,
1321
+ quantize_denoised=True)
1322
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1323
+ # quantize_denoised=True)
1324
+ x_samples = self.decode_first_stage(samples.to(self.device))
1325
+ log["samples_x0_quantized"] = x_samples
1326
+
1327
+ if inpaint:
1328
+ # make a simple center square
1329
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
1330
+ mask = torch.ones(N, h, w).to(self.device)
1331
+ # zeros will be filled in
1332
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1333
+ mask = mask[:, None, ...]
1334
+ with self.ema_scope("Plotting Inpaint"):
1335
+
1336
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1337
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1338
+ x_samples = self.decode_first_stage(samples.to(self.device))
1339
+ log["samples_inpainting"] = x_samples
1340
+ log["mask"] = mask
1341
+
1342
+ # outpaint
1343
+ with self.ema_scope("Plotting Outpaint"):
1344
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1345
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1346
+ x_samples = self.decode_first_stage(samples.to(self.device))
1347
+ log["samples_outpainting"] = x_samples
1348
+
1349
+ if plot_progressive_rows:
1350
+ with self.ema_scope("Plotting Progressives"):
1351
+ img, progressives = self.progressive_denoising(c,
1352
+ shape=(self.channels, self.image_size, self.image_size),
1353
+ batch_size=N)
1354
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1355
+ log["progressive_row"] = prog_row
1356
+
1357
+ if return_keys:
1358
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1359
+ return log
1360
+ else:
1361
+ return {key: log[key] for key in return_keys}
1362
+ return log
1363
+
1364
+ def configure_optimizers(self):
1365
+ lr = self.learning_rate
1366
+ params = list(self.model.parameters())
1367
+ if self.cond_stage_trainable:
1368
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1369
+ params = params + list(self.cond_stage_model.parameters())
1370
+ if self.learn_logvar:
1371
+ print('Diffusion model optimizing logvar')
1372
+ params.append(self.logvar)
1373
+ opt = torch.optim.AdamW(params, lr=lr)
1374
+ if self.use_scheduler:
1375
+ assert 'target' in self.scheduler_config
1376
+ scheduler = instantiate_from_config(self.scheduler_config)
1377
+
1378
+ print("Setting up LambdaLR scheduler...")
1379
+ scheduler = [
1380
+ {
1381
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1382
+ 'interval': 'step',
1383
+ 'frequency': 1
1384
+ }]
1385
+ return [opt], scheduler
1386
+ return opt
1387
+
1388
+ @torch.no_grad()
1389
+ def to_rgb(self, x):
1390
+ x = x.float()
1391
+ if not hasattr(self, "colorize"):
1392
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1393
+ x = nn.functional.conv2d(x, weight=self.colorize)
1394
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1395
+ return x
1396
+
1397
+
1398
+
1399
+ class DiffusionWrapper(pl.LightningModule):
1400
+ def __init__(self, diff_model_config, conditioning_key):
1401
+ super().__init__()
1402
+ self.diffusion_model = instantiate_from_config(diff_model_config)
1403
+ self.conditioning_key = conditioning_key # 'crossattn' for txt2image, concat for inpainting
1404
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'film', 'hybrid_inpaint']
1405
+
1406
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None,c_film: list = None):
1407
+ x = x.contiguous()
1408
+ t = t.contiguous()
1409
+ """param x: tensor with shape:[B,C,mel_len,T]"""
1410
+ if self.conditioning_key is None:
1411
+ out = self.diffusion_model(x, t)
1412
+ elif self.conditioning_key == 'concat':
1413
+ xc = torch.cat([x] + c_concat, dim=1)# channel dim,x shape (b,3,64,64) c_concat shape(b,4,64,64)
1414
+ out = self.diffusion_model(xc, t)
1415
+ elif self.conditioning_key == 'crossattn':
1416
+ if isinstance(c_crossattn,list):
1417
+ cc = torch.cat(c_crossattn, 1)# [b,seq_len,dim]
1418
+ else:
1419
+ cc = c_crossattn
1420
+ out = self.diffusion_model(x, t, context=cc)
1421
+ elif self.conditioning_key == 'hybrid':# not implemented in the LatentDiffusion
1422
+ xc = torch.cat([x] + c_concat, dim=1)
1423
+ cc = torch.cat(c_crossattn, 1)
1424
+ out = self.diffusion_model(xc, t, context=cc)
1425
+ elif self.conditioning_key == 'hybrid_inpaint': # special
1426
+ cc = c_crossattn
1427
+ out = self.diffusion_model(x, t, context=cc)
1428
+ elif self.conditioning_key == "film": # The condition is assumed to be a global token, which will pass through a linear layer and be added to the time embedding for FiLM
1429
+ cc = c_film[0].squeeze(1).contiguous() # only has one token, shape (b,context_dim)
1430
+ out = self.diffusion_model(x, t, y=cc)
1431
+ elif self.conditioning_key == 'adm':
1432
+ cc = c_crossattn[0]
1433
+ out = self.diffusion_model(x, t, y=cc)
1434
+ else:
1435
+ raise NotImplementedError()
1436
+
1437
+ return out
1438
+
1439
+
1440
+ class Layout2ImgDiffusion(LatentDiffusion):
1441
+ # TODO: move all layout-specific hacks to this class
1442
+ def __init__(self, cond_stage_key, *args, **kwargs):
1443
+ assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
1444
+ super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
1445
+
1446
+ def log_images(self, batch, N=8, *args, **kwargs):
1447
+ logs = super().log_images(batch=batch, N=N, *args, **kwargs)
1448
+
1449
+ key = 'train' if self.training else 'validation'
1450
+ dset = self.trainer.datamodule.datasets[key]
1451
+ mapper = dset.conditional_builders[self.cond_stage_key]
1452
+
1453
+ bbox_imgs = []
1454
+ map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
1455
+ for tknzd_bbox in batch[self.cond_stage_key][:N]:
1456
+ bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
1457
+ bbox_imgs.append(bboximg)
1458
+
1459
+ cond_img = torch.stack(bbox_imgs, dim=0)
1460
+ logs['bbox_image'] = cond_img
1461
+ return logs
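# Note (illustrative): DiffusionWrapper routes conditioning into the UNet by key. With
# conditioning_key='crossattn' a call such as
#     out = wrapper(x_noisy, t, c_crossattn=[text_emb])    # text_emb: [B, seq_len, dim]
# becomes diffusion_model(x_noisy, t, context=text_emb), while 'concat' channel-concatenates
# the condition onto x_noisy and 'film'/'adm' pass a single global token as y.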
ldm/models/diffusion/ddpm_audio.py ADDED
@@ -0,0 +1,865 @@
1
+ import os
2
+ from pytorch_memlab import LineProfiler,profile
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+ from einops import rearrange, repeat
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+ from tqdm import tqdm
12
+ from torchvision.utils import make_grid
13
+ try:
14
+ from pytorch_lightning.utilities.distributed import rank_zero_only
15
+ except:
16
+ from pytorch_lightning.utilities import rank_zero_only # torch2
17
+
18
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
19
+ from ldm.modules.ema import LitEma
20
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
21
+ from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
22
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
23
+ from ldm.models.diffusion.ddim import DDIMSampler
24
+ from ldm.models.diffusion.ddpm import DDPM, disabled_train
25
+ from omegaconf import ListConfig
26
+
27
+ __conditioning_keys__ = {'concat': 'c_concat',
28
+ 'crossattn': 'c_crossattn',
29
+ 'adm': 'y'}
30
+
31
+
32
+ class LatentDiffusion_audio(DDPM):
33
+ """main class"""
34
+ def __init__(self,
35
+ first_stage_config,
36
+ cond_stage_config,
37
+ num_timesteps_cond=None,
38
+ mel_dim=80,
39
+ mel_length=848,
40
+ cond_stage_key="image",
41
+ cond_stage_trainable=False,
42
+ concat_mode=True,
43
+ cond_stage_forward=None,
44
+ conditioning_key=None,
45
+ scale_factor=1.0,
46
+ scale_by_std=False,
47
+ *args, **kwargs):
48
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
49
+ self.scale_by_std = scale_by_std
50
+ assert self.num_timesteps_cond <= kwargs['timesteps']
51
+ # for backwards compatibility after implementation of DiffusionWrapper
52
+ if conditioning_key is None:
53
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
54
+ if cond_stage_config == '__is_unconditional__':
55
+ conditioning_key = None
56
+ ckpt_path = kwargs.pop("ckpt_path", None)
57
+ ignore_keys = kwargs.pop("ignore_keys", [])
58
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
59
+ self.concat_mode = concat_mode
60
+ self.mel_dim = mel_dim
61
+ self.mel_length = mel_length
62
+ self.cond_stage_trainable = cond_stage_trainable
63
+ self.cond_stage_key = cond_stage_key
64
+ try:
65
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
66
+ except:
67
+ self.num_downs = 0
68
+ if not scale_by_std:
69
+ self.scale_factor = scale_factor
70
+ else:
71
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
72
+ self.instantiate_first_stage(first_stage_config)
73
+ self.instantiate_cond_stage(cond_stage_config)
74
+ self.cond_stage_forward = cond_stage_forward
75
+ self.clip_denoised = False
76
+ self.bbox_tokenizer = None
77
+
78
+ self.restarted_from_ckpt = False
79
+ if ckpt_path is not None:
80
+ self.init_from_ckpt(ckpt_path, ignore_keys)
81
+ self.restarted_from_ckpt = True
82
+
83
+ def make_cond_schedule(self, ):
84
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
85
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
86
+ self.cond_ids[:self.num_timesteps_cond] = ids
87
+
88
+ @rank_zero_only
89
+ @torch.no_grad()
90
+ def on_train_batch_start(self, batch, batch_idx):
91
+ # only for very first batch
92
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
93
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
94
+ # set rescale weight to 1./std of encodings
95
+ print("### USING STD-RESCALING ###")
96
+ x = super().get_input(batch, self.first_stage_key)
97
+ x = x.to(self.device)
98
+ encoder_posterior = self.encode_first_stage(x)
99
+ z = self.get_first_stage_encoding(encoder_posterior).detach()# get latent
100
+ del self.scale_factor
101
+ self.register_buffer('scale_factor', 1. / z.flatten().std())# 1/latent.std, get_first_stage_encoding returns self.scale_factor * latent
102
+ print(f"setting self.scale_factor to {self.scale_factor}")
103
+ print("### USING STD-RESCALING ###")
104
+
105
+ # def on_train_epoch_start(self):
106
+ # print("!!!!!!!!!!!!!!!!!!!!!!!!!!on_train_epoch_strat",self.trainer.train_dataloader.batch_sampler,hasattr(self.trainer.train_dataloader.batch_sampler,'set_epoch'))
107
+ # if hasattr(self.trainer.train_dataloader.batch_sampler,'set_epoch'):
108
+ # self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch)
109
+ # return super().on_train_epoch_start()
110
+
111
+
112
+ def register_schedule(self,
113
+ given_betas=None, beta_schedule="linear", timesteps=1000,
114
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
115
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
116
+
117
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
118
+ if self.shorten_cond_schedule:
119
+ self.make_cond_schedule()
120
+
121
+ def instantiate_first_stage(self, config):
122
+ model = instantiate_from_config(config)
123
+ self.first_stage_model = model.eval()
124
+ self.first_stage_model.train = disabled_train
125
+ for param in self.first_stage_model.parameters():
126
+ param.requires_grad = False
127
+
128
+ def instantiate_cond_stage(self, config):
129
+ if not self.cond_stage_trainable:
130
+ if config == "__is_first_stage__":
131
+ print("Using first stage also as cond stage.")
132
+ self.cond_stage_model = self.first_stage_model
133
+ elif config == "__is_unconditional__":
134
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
135
+ self.cond_stage_model = None
136
+ else:
137
+ model = instantiate_from_config(config)
138
+ self.cond_stage_model = model.eval()
139
+ self.cond_stage_model.train = disabled_train
140
+ for param in self.cond_stage_model.parameters():
141
+ param.requires_grad = False
142
+ else:
143
+ assert config != '__is_first_stage__'
144
+ assert config != '__is_unconditional__'
145
+ model = instantiate_from_config(config)
146
+ self.cond_stage_model = model
147
+
148
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
149
+ denoise_row = []
150
+ for zd in tqdm(samples, desc=desc):
151
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
152
+ force_not_quantize=force_no_decoder_quantization))
153
+ n_imgs_per_row = len(denoise_row)
154
+ if len(denoise_row[0].shape) == 3:
155
+ denoise_row = [x.unsqueeze(1) for x in denoise_row]
156
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
157
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
158
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
159
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
160
+ return denoise_grid
161
+
162
+ def get_first_stage_encoding(self, encoder_posterior):
163
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
164
+ z = encoder_posterior.sample()
165
+ elif isinstance(encoder_posterior, torch.Tensor):
166
+ z = encoder_posterior
167
+ else:
168
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
169
+ return self.scale_factor * z
170
+
171
+ #@profile
172
+ def get_learned_conditioning(self, c):
173
+ if self.cond_stage_forward is None:
174
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
175
+ c = self.cond_stage_model.encode(c)
176
+ if isinstance(c, DiagonalGaussianDistribution):
177
+ c = c.mode()
178
+ else:
179
+ c = self.cond_stage_model(c)
180
+ else:
181
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
182
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
183
+ return c
184
+
185
+
186
+ @torch.no_grad()
187
+ def get_unconditional_conditioning(self, batch_size, null_label=None):
188
+ if null_label is not None:
189
+ xc = null_label
190
+ if isinstance(xc, ListConfig):
191
+ xc = list(xc)
192
+ if isinstance(xc, dict) or isinstance(xc, list):
193
+ c = self.get_learned_conditioning(xc)
194
+ else:
195
+ if hasattr(xc, "to"):
196
+ xc = xc.to(self.device)
197
+ c = self.get_learned_conditioning(xc)
198
+ else:
199
+ if self.cond_stage_key in ["class_label", "cls"]:
200
+ xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
201
+ return self.get_learned_conditioning(xc)
202
+ else:
203
+ raise NotImplementedError("todo")
204
+ if isinstance(c, list): # in case the encoder gives us a list
205
+ for i in range(len(c)):
206
+ c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
207
+ else:
208
+ c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
209
+ return c
210
+
211
+ def meshgrid(self, h, w):
212
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
213
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
214
+
215
+ arr = torch.cat([y, x], dim=-1)
216
+ return arr
217
+
218
+ def delta_border(self, h, w):
219
+ """
220
+ :param h: height
221
+ :param w: width
222
+ :return: normalized distance to image border,
223
+ with min distance = 0 at border and max dist = 0.5 at image center
224
+ """
225
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
226
+ arr = self.meshgrid(h, w) / lower_right_corner
227
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
228
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
229
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
230
+ return edge_dist
231
+
232
+ def get_weighting(self, h, w, Ly, Lx, device):
233
+ weighting = self.delta_border(h, w)
234
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
235
+ self.split_input_params["clip_max_weight"], )
236
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
237
+
238
+ if self.split_input_params["tie_braker"]:
239
+ L_weighting = self.delta_border(Ly, Lx)
240
+ L_weighting = torch.clip(L_weighting,
241
+ self.split_input_params["clip_min_tie_weight"],
242
+ self.split_input_params["clip_max_tie_weight"])
243
+
244
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
245
+ weighting = weighting * L_weighting
246
+ return weighting
247
+
248
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
249
+ """
250
+ :param x: img of size (bs, c, h, w)
251
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
252
+ """
253
+ bs, nc, h, w = x.shape
254
+
255
+ # number of crops in image
256
+ Ly = (h - kernel_size[0]) // stride[0] + 1
257
+ Lx = (w - kernel_size[1]) // stride[1] + 1
258
+
259
+ if uf == 1 and df == 1:
260
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
261
+ unfold = torch.nn.Unfold(**fold_params)
262
+
263
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
264
+
265
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
266
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
267
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
268
+
269
+ elif uf > 1 and df == 1:
270
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
271
+ unfold = torch.nn.Unfold(**fold_params)
272
+
273
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
274
+ dilation=1, padding=0,
275
+ stride=(stride[0] * uf, stride[1] * uf))
276
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
277
+
278
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
279
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
280
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
281
+
282
+ elif df > 1 and uf == 1:
283
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
284
+ unfold = torch.nn.Unfold(**fold_params)
285
+
286
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
287
+ dilation=1, padding=0,
288
+ stride=(stride[0] // df, stride[1] // df))
289
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
290
+
291
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
292
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
293
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
294
+
295
+ else:
296
+ raise NotImplementedError
297
+
298
+ return fold, unfold, normalization, weighting
299
+
300
+ @torch.no_grad()
301
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
302
+ cond_key=None, return_original_cond=False, bs=None):
303
+ x = super().get_input(batch, k)
304
+ if bs is not None:
305
+ x = x[:bs]
306
+ x = x.to(self.device)
307
+ encoder_posterior = self.encode_first_stage(x)
308
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
309
+
310
+ if self.model.conditioning_key is not None:
311
+ if cond_key is None:
312
+ cond_key = self.cond_stage_key
313
+ if cond_key != self.first_stage_key:
314
+ if cond_key in ['caption', 'coordinates_bbox', 'hybrid_feat']:
315
+ xc = batch[cond_key]
316
+ elif cond_key == 'class_label':
317
+ xc = batch
318
+ else:
319
+ xc = super().get_input(batch, cond_key).to(self.device)
320
+ else:
321
+ xc = x
322
+ if not self.cond_stage_trainable or force_c_encode: # False
323
+ if isinstance(xc, dict) or isinstance(xc, list):
324
+ # import pudb; pudb.set_trace()
325
+ c = self.get_learned_conditioning(xc)
326
+ else:
327
+ c = self.get_learned_conditioning(xc.to(self.device))
328
+ else:
329
+ c = xc
330
+ if bs is not None:
331
+ c = c[:bs]
332
+ # Testing #
333
+ if cond_key == 'masked_image':
334
+ mask = super().get_input(batch, "mask")
335
+ cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # [B, 1, 10, 106]
336
+ c = torch.cat((c, cc), dim=1) # [B, 5, 10, 106]
337
+ # Testing #
338
+ if self.use_positional_encodings:
339
+ pos_x, pos_y = self.compute_latent_shifts(batch)
340
+ ckey = __conditioning_keys__[self.model.conditioning_key]
341
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
342
+
343
+ else:
344
+ c = None
345
+ xc = None
346
+ if self.use_positional_encodings:
347
+ pos_x, pos_y = self.compute_latent_shifts(batch)
348
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
349
+ out = [z, c]
350
+ if return_first_stage_outputs:
351
+ xrec = self.decode_first_stage(z)
352
+ out.extend([x, xrec])
353
+ if return_original_cond:
354
+ out.append(xc)
355
+ return out
356
+
357
+ @torch.no_grad()
358
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
359
+ if predict_cids:
360
+ if z.dim() == 4:
361
+ z = torch.argmax(z.exp(), dim=1).long()
362
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
363
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
364
+
365
+ z = 1. / self.scale_factor * z
366
+
367
+
368
+ if isinstance(self.first_stage_model, VQModelInterface):
369
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
370
+ else:
371
+ return self.first_stage_model.decode(z)
372
+
373
+ # same as above but without decorator
374
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
375
+ if predict_cids:
376
+ if z.dim() == 4:
377
+ z = torch.argmax(z.exp(), dim=1).long()
378
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
379
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
380
+
381
+ z = 1. / self.scale_factor * z
382
+
383
+
384
+ if isinstance(self.first_stage_model, VQModelInterface):
385
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
386
+ else:
387
+ return self.first_stage_model.decode(z)
388
+
389
+ @torch.no_grad()
390
+ def encode_first_stage(self, x):
391
+ return self.first_stage_model.encode(x)
392
+
393
+ def shared_step(self, batch, **kwargs):
394
+ x, c = self.get_input(batch, self.first_stage_key)
395
+ loss = self(x, c)
396
+ return loss
397
+
398
+ def test_step(self,batch,batch_idx):
399
+ cond = batch[self.cond_stage_key] # * self.test_repeat
400
+ cond = self.get_learned_conditioning(cond) # c: string -> [B, T, Context_dim]
401
+ batch_size = len(cond)
402
+ enc_emb = self.sample(cond,batch_size,timesteps=self.num_timesteps)# shape = [batch_size,self.channels,self.mel_dim,self.mel_length]
403
+ xrec = self.decode_first_stage(enc_emb)
404
+ # reconstructions = (xrec + 1)/2 # to mel scale
405
+ # test_ckpt_path = os.path.basename(self.trainer.tested_ckpt_path)
406
+ # savedir = os.path.join(self.trainer.log_dir,f'output_imgs_{test_ckpt_path}','fake_class')
407
+ # if not os.path.exists(savedir):
408
+ # os.makedirs(savedir)
409
+
410
+ # file_names = batch['f_name']
411
+ # nfiles = len(file_names)
412
+ # reconstructions = reconstructions.cpu().numpy().squeeze(1) # squeeze channel dim
413
+ # for k in range(reconstructions.shape[0]):
414
+ # b,repeat = k % nfiles, k // nfiles
415
+ # vname_num_split_index = file_names[b].rfind('_')# file_names[b]:video_name+'_'+num
416
+ # v_n,num = file_names[b][:vname_num_split_index],file_names[b][vname_num_split_index+1:]
417
+ # save_img_path = os.path.join(savedir,f'{v_n}_sample_{num}_{repeat}.npy')# the num_th caption, the repeat_th repetition
418
+ # np.save(save_img_path,reconstructions[b])
419
+ return None
420
+
421
+ def forward(self, x, c, *args, **kwargs):
422
+ '''
423
+ video to audio:
424
+ x (latent): [B, 256 (time), 20] c (video feat): [B, 32 (time), 512]
425
+ '''
426
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() # [B]
427
+ if self.model.conditioning_key is not None:
428
+ assert c is not None
429
+ if self.cond_stage_trainable:
430
+ c = self.get_learned_conditioning(c) # c: string -> [B, T, Context_dim]
431
+ if self.shorten_cond_schedule: # TODO: drop this option
432
+ tc = self.cond_ids[t].to(self.device)
433
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
434
+ return self.p_losses(x, c, t, *args, **kwargs)
435
+
436
+
437
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
438
+
439
+ if isinstance(cond, dict):
440
+ # hybrid case, cond is expected to be a dict
441
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
442
+ cond = {key: cond}
443
+ else:
444
+ if not isinstance(cond, list):
445
+ cond = [cond]
446
+ if self.model.conditioning_key == "concat":
447
+ key = "c_concat"
448
+ elif self.model.conditioning_key == "crossattn":
449
+ key = "c_crossattn"
450
+ else:
451
+ key = "c_film"
452
+ cond = {key: cond}
453
+
454
+
455
+ x_recon = self.model(x_noisy, t, **cond)
456
+
457
+ if isinstance(x_recon, tuple) and not return_ids:
458
+ return x_recon[0]
459
+ else:
460
+ return x_recon
461
+
462
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
463
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
464
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
465
+
466
+ def _prior_bpd(self, x_start):
467
+ """
468
+ Get the prior KL term for the variational lower-bound, measured in
469
+ bits-per-dim.
470
+ This term can't be optimized, as it only depends on the encoder.
471
+ :param x_start: the [N x C x ...] tensor of inputs.
472
+ :return: a batch of [N] KL values (in bits), one per batch element.
473
+ """
474
+ batch_size = x_start.shape[0]
475
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
476
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
477
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
478
+ return mean_flat(kl_prior) / np.log(2.0)
479
+
480
+ def p_losses(self, x_start, cond, t, noise=None):
481
+ noise = default(noise, lambda: torch.randn_like(x_start))
482
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
483
+ model_output = self.apply_model(x_noisy, t, cond)
484
+
485
+ loss_dict = {}
486
+ prefix = 'train' if self.training else 'val'
487
+
488
+ if self.parameterization == "x0":
489
+ target = x_start
490
+ elif self.parameterization == "eps":
491
+ target = noise
492
+ else:
493
+ raise NotImplementedError()
494
+
495
+ mean_dims = list(range(1,len(target.shape)))
496
+ loss_simple = self.get_loss(model_output, target, mean=False).mean(dim=mean_dims)
497
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
498
+
499
+ logvar_t = self.logvar[t.to(self.logvar.device)].to(self.device)
500
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
501
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
502
+ if self.learn_logvar:
503
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
504
+ loss_dict.update({'logvar': self.logvar.data.mean()})
505
+
506
+ loss = self.l_simple_weight * loss.mean()
507
+
508
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=mean_dims)
509
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
510
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
511
+ loss += (self.original_elbo_weight * loss_vlb)
512
+ loss_dict.update({f'{prefix}/loss': loss})
513
+
514
+ return loss, loss_dict
515
+
516
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
517
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
518
+ t_in = t
519
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
520
+
521
+ if score_corrector is not None:
522
+ assert self.parameterization == "eps"
523
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
524
+
525
+ if return_codebook_ids:
526
+ model_out, logits = model_out
527
+
528
+ if self.parameterization == "eps":
529
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
530
+ elif self.parameterization == "x0":
531
+ x_recon = model_out
532
+ else:
533
+ raise NotImplementedError()
534
+
535
+ if clip_denoised:
536
+ x_recon.clamp_(-1., 1.)
537
+ if quantize_denoised:
538
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
539
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
540
+ if return_codebook_ids:
541
+ return model_mean, posterior_variance, posterior_log_variance, logits
542
+ elif return_x0:
543
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
544
+ else:
545
+ return model_mean, posterior_variance, posterior_log_variance
546
+
547
+ @torch.no_grad()
548
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
549
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
550
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
551
+ b, *_, device = *x.shape, x.device
552
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
553
+ return_codebook_ids=return_codebook_ids,
554
+ quantize_denoised=quantize_denoised,
555
+ return_x0=return_x0,
556
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
557
+ if return_codebook_ids:
558
+ raise DeprecationWarning("Support dropped.")
559
+ model_mean, _, model_log_variance, logits = outputs
560
+ elif return_x0:
561
+ model_mean, _, model_log_variance, x0 = outputs
562
+ else:
563
+ model_mean, _, model_log_variance = outputs
564
+
565
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
566
+ if noise_dropout > 0.:
567
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
568
+ # no noise when t == 0
569
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
570
+
571
+ if return_codebook_ids:
572
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
573
+ if return_x0:
574
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
575
+ else:
576
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
577
+
578
+ @torch.no_grad()
579
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
580
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
581
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
582
+ log_every_t=None):
583
+ if not log_every_t:
584
+ log_every_t = self.log_every_t
585
+ timesteps = self.num_timesteps
586
+ if batch_size is not None:
587
+ b = batch_size if batch_size is not None else shape[0]
588
+ shape = [batch_size] + list(shape)
589
+ else:
590
+ b = batch_size = shape[0]
591
+ if x_T is None:
592
+ img = torch.randn(shape, device=self.device)
593
+ else:
594
+ img = x_T
595
+ intermediates = []
596
+ if cond is not None:
597
+ if isinstance(cond, dict):
598
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
599
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
600
+ else:
601
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
602
+
603
+ if start_T is not None:
604
+ timesteps = min(timesteps, start_T)
605
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
606
+ total=timesteps) if verbose else reversed(
607
+ range(0, timesteps))
608
+ if type(temperature) == float:
609
+ temperature = [temperature] * timesteps
610
+
611
+ for i in iterator:
612
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
613
+ if self.shorten_cond_schedule:
614
+ assert self.model.conditioning_key != 'hybrid'
615
+ tc = self.cond_ids[ts].to(cond.device)
616
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
617
+
618
+
619
+ img, x0_partial = self.p_sample(img, cond, ts,
620
+ clip_denoised=self.clip_denoised,
621
+ quantize_denoised=quantize_denoised, return_x0=True,
622
+ temperature=temperature[i], noise_dropout=noise_dropout,
623
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
624
+ if mask is not None:
625
+ assert x0 is not None
626
+ img_orig = self.q_sample(x0, ts)
627
+ img = img_orig * mask + (1. - mask) * img
628
+
629
+ if i % log_every_t == 0 or i == timesteps - 1:
630
+ intermediates.append(x0_partial)
631
+ if callback: callback(i)
632
+ if img_callback: img_callback(img, i)
633
+ return img, intermediates
634
+
635
+ @torch.no_grad()
636
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
637
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
638
+ mask=None, x0=None, img_callback=None, start_T=None,
639
+ log_every_t=None):
640
+
641
+ if not log_every_t:
642
+ log_every_t = self.log_every_t
643
+ device = self.betas.device
644
+ b = shape[0]
645
+ if x_T is None:
646
+ img = torch.randn(shape, device=device)
647
+ else:
648
+ img = x_T
649
+
650
+ intermediates = [img]
651
+ if timesteps is None:
652
+ timesteps = self.num_timesteps
653
+
654
+ if start_T is not None:
655
+ timesteps = min(timesteps, start_T)
656
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
657
+ range(0, timesteps))
658
+
659
+ if mask is not None:
660
+ assert x0 is not None
661
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
662
+
663
+ for i in iterator:
664
+ ts = torch.full((b,), i, device=device, dtype=torch.long) # num
665
+ if self.shorten_cond_schedule: # False
666
+ assert self.model.conditioning_key != 'hybrid'
667
+ tc = self.cond_ids[ts].to(cond.device)
668
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
669
+
670
+ img = self.p_sample(img, cond, ts,
671
+ clip_denoised=self.clip_denoised, # False
672
+ quantize_denoised=quantize_denoised) # False
673
+ if mask is not None: # False
674
+ img_orig = self.q_sample(x0, ts)
675
+ img = img_orig * mask + (1. - mask) * img
676
+
677
+ if i % log_every_t == 0 or i == timesteps - 1:
678
+ intermediates.append(img)
679
+ if callback: callback(i)
680
+ if img_callback: img_callback(img, i)
681
+
682
+ if return_intermediates:
683
+ return img, intermediates
684
+ return img
685
+
686
+ @torch.no_grad()
687
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
688
+ verbose=True, timesteps=None, quantize_denoised=False,
689
+ mask=None, x0=None, shape=None,**kwargs):
690
+ if shape is None:
691
+ if self.channels > 0:
692
+ shape = (batch_size, self.channels, self.mel_dim, self.mel_length)
693
+ else:
694
+ shape = (batch_size, self.mel_dim, self.mel_length)
695
+ if cond is not None:
696
+ if isinstance(cond, dict):
697
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
698
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
699
+ else:
700
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
701
+ return self.p_sample_loop(cond,
702
+ shape,
703
+ return_intermediates=return_intermediates, x_T=x_T,
704
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
705
+ mask=mask, x0=x0)
706
+
707
+ @torch.no_grad()
708
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
709
+
710
+ if ddim:
711
+ ddim_sampler = DDIMSampler(self)
712
+ shape = (self.channels, self.mel_dim, self.mel_length) if self.channels > 0 else (self.mel_dim, self.mel_length)
713
+ samples, intermediates = ddim_sampler.sample(ddim_steps,batch_size,
714
+ shape,cond,verbose=False,**kwargs)
715
+
716
+ else:
717
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
718
+ return_intermediates=True,**kwargs)
719
+
720
+ return samples, intermediates
721
+
722
+
723
+ @torch.no_grad()
724
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
725
+ quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=True,
726
+ plot_diffusion_rows=True, **kwargs):
727
+
728
+ use_ddim = ddim_steps is not None
729
+
730
+ log = dict()
731
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
732
+ return_first_stage_outputs=True,
733
+ force_c_encode=True,
734
+ return_original_cond=True,
735
+ bs=N) # z is latent,c is condition embedding, xc is condition(caption) list
736
+ N = min(x.shape[0], N)
737
+ n_row = min(x.shape[0], n_row)
738
+ log["inputs"] = x if len(x.shape)==4 else x.unsqueeze(1)
739
+ log["reconstruction"] = xrec if len(xrec.shape)==4 else xrec.unsqueeze(1)
740
+ if self.model.conditioning_key is not None:
741
+ if hasattr(self.cond_stage_model, "decode") and self.cond_stage_key != "masked_image":
742
+ xc = self.cond_stage_model.decode(c)
743
+ log["conditioning"] = xc
744
+ elif self.cond_stage_key == "masked_image":
745
+ log["mask"] = c[:, -1, :, :][:, None, :, :]
746
+ xc = self.cond_stage_model.decode(c[:, :self.cond_stage_model.embed_dim, :, :])
747
+ log["conditioning"] = xc
748
+ elif self.cond_stage_key in ["caption"]:
749
+ pass
750
+ # xc = log_txt_as_img((256, 256), batch["caption"])
751
+ # log["conditioning"] = xc
752
+ elif self.cond_stage_key == 'class_label':
753
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
754
+ log['conditioning'] = xc
755
+ elif isimage(xc):
756
+ log["conditioning"] = xc
757
+
758
+ if plot_diffusion_rows:
759
+ # get diffusion row
760
+ diffusion_row = list()
761
+ z_start = z[:n_row]
762
+ for t in range(self.num_timesteps):
763
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
764
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
765
+ t = t.to(self.device).long()
766
+ noise = torch.randn_like(z_start)
767
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
768
+ diffusion_row.append(self.decode_first_stage(z_noisy))
769
+ if len(diffusion_row[0].shape) == 3:
770
+ diffusion_row = [x.unsqueeze(1) for x in diffusion_row]
771
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
772
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
773
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
774
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
775
+ log["diffusion_row"] = diffusion_grid
776
+
777
+ if sample:
778
+ # get denoise row
779
+ with self.ema_scope("Plotting"):
780
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
781
+ ddim_steps=ddim_steps,eta=ddim_eta)
782
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
783
+ x_samples = self.decode_first_stage(samples)
784
+ log["samples"] = x_samples if len(x_samples.shape)==4 else x_samples.unsqueeze(1)
785
+ if plot_denoise_rows:
786
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
787
+ log["denoise_row"] = denoise_grid
788
+
789
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
790
+ self.first_stage_model, IdentityFirstStage):
791
+ # also display when quantizing x0 while sampling
792
+ with self.ema_scope("Plotting Quantized Denoised"):
793
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
794
+ ddim_steps=ddim_steps,eta=ddim_eta,
795
+ quantize_denoised=True)
796
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
797
+ # quantize_denoised=True)
798
+ x_samples = self.decode_first_stage(samples.to(self.device))
799
+ log["samples_x0_quantized"] = x_samples if len(x_samples.shape)==4 else x_samples.unsqueeze(1)
800
+
801
+ if inpaint:
802
+ # make a simple center square
803
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
804
+ mask = torch.ones(N, h, w).to(self.device)
805
+ # zeros will be filled in
806
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
807
+ mask = mask[:, None, ...]
808
+ with self.ema_scope("Plotting Inpaint"):
809
+
810
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
811
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
812
+ x_samples = self.decode_first_stage(samples.to(self.device))
813
+ log["samples_inpainting"] = x_samples
814
+ log["mask_inpainting"] = mask
815
+
816
+ # outpaint
817
+ mask = 1 - mask
818
+ with self.ema_scope("Plotting Outpaint"):
819
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
820
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
821
+ x_samples = self.decode_first_stage(samples.to(self.device))
822
+ log["samples_outpainting"] = x_samples
823
+ log["mask_outpainting"] = mask
824
+
825
+ if plot_progressive_rows:
826
+ with self.ema_scope("Plotting Progressives"):
827
+ shape = (self.channels, self.mel_dim, self.mel_length) if self.channels > 0 else (self.mel_dim, self.mel_length)
828
+ img, progressives = self.progressive_denoising(c,
829
+ shape=shape,
830
+ batch_size=N)
831
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
832
+ log["progressive_row"] = prog_row
833
+
834
+ if return_keys:
835
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
836
+ return log
837
+ else:
838
+ return {key: log[key] for key in return_keys}
839
+ return log
840
+
841
+ def configure_optimizers(self):
842
+ lr = self.learning_rate
843
+ params = list(self.model.parameters())
844
+ if self.cond_stage_trainable:
845
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
846
+ params = params + list(self.cond_stage_model.parameters())
847
+ if self.learn_logvar:
848
+ print('Diffusion model optimizing logvar')
849
+ params.append(self.logvar)
850
+ opt = torch.optim.AdamW(params, lr=lr)
851
+ if self.use_scheduler:
852
+ assert 'target' in self.scheduler_config
853
+ scheduler = instantiate_from_config(self.scheduler_config)
854
+
855
+ print("Setting up LambdaLR scheduler...")
856
+ scheduler = [
857
+ {
858
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
859
+ 'interval': 'step',
860
+ 'frequency': 1
861
+ }]
862
+ return [opt], scheduler
863
+ return opt
864
+
865
+
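For orientation, the `p_mean_variance`/`p_sample` pair above implements the standard eps-parameterization DDPM ancestral step: reconstruct x0 from the predicted noise, then sample from the Gaussian posterior q(x_{t-1} | x_t, x0). The following is a minimal, self-contained sketch of that arithmetic; the toy beta schedule and the zero "eps" stand-in are illustrative only and are not part of this repository.

```python
import torch

# Toy linear beta schedule; the real model registers these as buffers.
T = 1000
betas = torch.linspace(1e-4, 2e-2, T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

def ddpm_step(x_t, eps_pred, t):
    """One ancestral step x_t -> x_{t-1} under the 'eps' parameterization."""
    a_t, a_prev = alphas_cumprod[t], alphas_cumprod_prev[t]
    # predict_start_from_noise: x0 = (x_t - sqrt(1 - a_bar_t) * eps) / sqrt(a_bar_t)
    x0 = (x_t - (1 - a_t).sqrt() * eps_pred) / a_t.sqrt()
    # q_posterior: mean and variance of q(x_{t-1} | x_t, x0)
    mean = (betas[t] * a_prev.sqrt() / (1 - a_t)) * x0 \
         + ((1 - a_prev) * alphas[t].sqrt() / (1 - a_t)) * x_t
    var = betas[t] * (1 - a_prev) / (1 - a_t)
    noise = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)  # no noise at t == 0
    return mean + var.sqrt() * noise

x = torch.randn(2, 20, 256)        # e.g. a [B, C, T] audio latent
for t in reversed(range(T)):
    eps = torch.zeros_like(x)      # stand-in for apply_model(x_noisy, t, cond)
    x = ddpm_step(x, eps, t)
```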
ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,236 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+
10
+
11
+ class PLMSSampler(object):
12
+ def __init__(self, model, schedule="linear", **kwargs):
13
+ super().__init__()
14
+ self.model = model
15
+ self.ddpm_num_timesteps = model.num_timesteps
16
+ self.schedule = schedule
17
+
18
+ def register_buffer(self, name, attr):
19
+ if type(attr) == torch.Tensor:
20
+ if attr.device != torch.device("cuda"):
21
+ attr = attr.to(torch.device("cuda"))
22
+ setattr(self, name, attr)
23
+
24
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25
+ if ddim_eta != 0:
26
+ raise ValueError('ddim_eta must be 0 for PLMS')
27
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
28
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
29
+ alphas_cumprod = self.model.alphas_cumprod
30
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
31
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
32
+
33
+ self.register_buffer('betas', to_torch(self.model.betas))
34
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
36
+
37
+ # calculations for diffusion q(x_t | x_{t-1}) and others
38
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
40
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
41
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
42
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
43
+
44
+ # ddim sampling parameters
45
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
46
+ ddim_timesteps=self.ddim_timesteps,
47
+ eta=ddim_eta,verbose=verbose)
48
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
49
+ self.register_buffer('ddim_alphas', ddim_alphas)
50
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
51
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
52
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
53
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
54
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
55
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
56
+
57
+ @torch.no_grad()
58
+ def sample(self,
59
+ S,
60
+ batch_size,
61
+ shape,
62
+ conditioning=None,
63
+ callback=None,
64
+ normals_sequence=None,
65
+ img_callback=None,
66
+ quantize_x0=False,
67
+ eta=0.,
68
+ mask=None,
69
+ x0=None,
70
+ temperature=1.,
71
+ noise_dropout=0.,
72
+ score_corrector=None,
73
+ corrector_kwargs=None,
74
+ verbose=True,
75
+ x_T=None,
76
+ log_every_t=100,
77
+ unconditional_guidance_scale=1.,
78
+ unconditional_conditioning=None,
79
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
80
+ **kwargs
81
+ ):
82
+ if conditioning is not None:
83
+ if isinstance(conditioning, dict):
84
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
85
+ if cbs != batch_size:
86
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87
+ else:
88
+ if conditioning.shape[0] != batch_size:
89
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
90
+
91
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
92
+ # sampling
93
+ C, H, W = shape
94
+ size = (batch_size, C, H, W)
95
+ print(f'Data shape for PLMS sampling is {size}')
96
+
97
+ samples, intermediates = self.plms_sampling(conditioning, size,
98
+ callback=callback,
99
+ img_callback=img_callback,
100
+ quantize_denoised=quantize_x0,
101
+ mask=mask, x0=x0,
102
+ ddim_use_original_steps=False,
103
+ noise_dropout=noise_dropout,
104
+ temperature=temperature,
105
+ score_corrector=score_corrector,
106
+ corrector_kwargs=corrector_kwargs,
107
+ x_T=x_T,
108
+ log_every_t=log_every_t,
109
+ unconditional_guidance_scale=unconditional_guidance_scale,
110
+ unconditional_conditioning=unconditional_conditioning,
111
+ )
112
+ return samples, intermediates
113
+
114
+ @torch.no_grad()
115
+ def plms_sampling(self, cond, shape,
116
+ x_T=None, ddim_use_original_steps=False,
117
+ callback=None, timesteps=None, quantize_denoised=False,
118
+ mask=None, x0=None, img_callback=None, log_every_t=100,
119
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
120
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
121
+ device = self.model.betas.device
122
+ b = shape[0]
123
+ if x_T is None:
124
+ img = torch.randn(shape, device=device)
125
+ else:
126
+ img = x_T
127
+
128
+ if timesteps is None:
129
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
130
+ elif timesteps is not None and not ddim_use_original_steps:
131
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
132
+ timesteps = self.ddim_timesteps[:subset_end]
133
+
134
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
135
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
136
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
137
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
138
+
139
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
140
+ old_eps = []
141
+
142
+ for i, step in enumerate(iterator):
143
+ index = total_steps - i - 1
144
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
145
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
146
+
147
+ if mask is not None:
148
+ assert x0 is not None
149
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
150
+ img = img_orig * mask + (1. - mask) * img
151
+
152
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
153
+ quantize_denoised=quantize_denoised, temperature=temperature,
154
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
155
+ corrector_kwargs=corrector_kwargs,
156
+ unconditional_guidance_scale=unconditional_guidance_scale,
157
+ unconditional_conditioning=unconditional_conditioning,
158
+ old_eps=old_eps, t_next=ts_next)
159
+ img, pred_x0, e_t = outs
160
+ old_eps.append(e_t)
161
+ if len(old_eps) >= 4:
162
+ old_eps.pop(0)
163
+ if callback: callback(i)
164
+ if img_callback: img_callback(pred_x0, i)
165
+
166
+ if index % log_every_t == 0 or index == total_steps - 1:
167
+ intermediates['x_inter'].append(img)
168
+ intermediates['pred_x0'].append(pred_x0)
169
+
170
+ return img, intermediates
171
+
172
+ @torch.no_grad()
173
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
174
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
175
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
176
+ b, *_, device = *x.shape, x.device
177
+
178
+ def get_model_output(x, t):
179
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
180
+ e_t = self.model.apply_model(x, t, c)
181
+ else:
182
+ x_in = torch.cat([x] * 2)
183
+ t_in = torch.cat([t] * 2)
184
+ c_in = torch.cat([unconditional_conditioning, c])
185
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
186
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
187
+
188
+ if score_corrector is not None:
189
+ assert self.model.parameterization == "eps"
190
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
191
+
192
+ return e_t
193
+
194
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
195
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
196
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
197
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
198
+
199
+ def get_x_prev_and_pred_x0(e_t, index):
200
+ # select parameters corresponding to the currently considered timestep
201
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
202
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
203
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
204
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
205
+
206
+ # current prediction for x_0
207
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
208
+ if quantize_denoised:
209
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
210
+ # direction pointing to x_t
211
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
212
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
213
+ if noise_dropout > 0.:
214
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
215
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
216
+ return x_prev, pred_x0
217
+
218
+ e_t = get_model_output(x, t)
219
+ if len(old_eps) == 0:
220
+ # Pseudo Improved Euler (2nd order)
221
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
222
+ e_t_next = get_model_output(x_prev, t_next)
223
+ e_t_prime = (e_t + e_t_next) / 2
224
+ elif len(old_eps) == 1:
225
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
226
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
227
+ elif len(old_eps) == 2:
228
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
229
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
230
+ elif len(old_eps) >= 3:
231
+ # 4th order Pseudo Linear Multistep (Adams-Bashforth)
232
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
233
+
234
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
235
+
236
+ return x_prev, pred_x0, e_t
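The `old_eps` bookkeeping above is the heart of PLMS: the sampler keeps the last three eps predictions and combines them with Adams-Bashforth weights of increasing order. A small sketch of just that combination, with toy scalar values (names and values are illustrative, not part of this file):

```python
from collections import deque

def combine_eps(e_t, old_eps):
    """Adams-Bashforth combination used in p_sample_plms (orders 1-4)."""
    if len(old_eps) == 0:
        return e_t  # the sampler instead refines this with the 2nd-order "improved Euler" step
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

history = deque(maxlen=3)                   # mirrors old_eps, which is popped once it reaches length 4
for e_t in [1.00, 0.95, 0.90, 0.85, 0.80]:  # toy noise predictions from successive steps
    e_t_prime = combine_eps(e_t, list(history))
    history.append(e_t)
    print(round(e_t_prime, 4))
```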
ldm/models/diffusion/transport/__init__.py ADDED
@@ -0,0 +1,73 @@
1
+ from .transport import Transport, ModelType, WeightType, PathType, SNRType, Sampler
2
+
3
+
4
+ def create_transport(
5
+ path_type='Linear',
6
+ prediction="velocity",
7
+ loss_weight=None,
8
+ train_eps=None,
9
+ sample_eps=None,
10
+ snr_type="uniform"
11
+ ):
12
+ """function for creating Transport object
13
+ **Note**: model prediction defaults to velocity
14
+ Args:
15
+ - path_type: type of path to use; default to linear
16
+ - prediction: model prediction type; "velocity" (default), "noise", or "score"
17
+ - loss_weight: loss weighting; "velocity", "likelihood", or None for no extra weighting
18
+ - snr_type: timestep sampling distribution; "uniform" (default) or "lognorm"
20
+ - train_eps: small epsilon for avoiding instability during training
21
+ - sample_eps: small epsilon for avoiding instability during sampling
22
+ """
23
+
24
+ if prediction == "noise":
25
+ model_type = ModelType.NOISE
26
+ elif prediction == "score":
27
+ model_type = ModelType.SCORE
28
+ else:
29
+ model_type = ModelType.VELOCITY
30
+
31
+ if loss_weight == "velocity":
32
+ loss_type = WeightType.VELOCITY
33
+ elif loss_weight == "likelihood":
34
+ loss_type = WeightType.LIKELIHOOD
35
+ else:
36
+ loss_type = WeightType.NONE
37
+
38
+ if snr_type == "lognorm":
39
+ snr_type = SNRType.LOGNORM
40
+ elif snr_type == "uniform":
41
+ snr_type = SNRType.UNIFORM
42
+ else:
43
+ raise ValueError(f"Invalid snr type {snr_type}")
44
+
45
+ path_choice = {
46
+ "Linear": PathType.LINEAR,
47
+ "GVP": PathType.GVP,
48
+ "VP": PathType.VP,
49
+ }
50
+
51
+ path_type = path_choice[path_type]
52
+
53
+ if (path_type in [PathType.VP]):
54
+ train_eps = 1e-5 if train_eps is None else train_eps
55
+ sample_eps = 1e-3 if sample_eps is None else sample_eps
56
+ elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY):
57
+ train_eps = 1e-3 if train_eps is None else train_eps
58
+ sample_eps = 1e-3 if sample_eps is None else sample_eps
59
+ else: # velocity & [GVP, LINEAR] is stable everywhere
60
+ train_eps = 0
61
+ sample_eps = 0
62
+
63
+ # create flow state
64
+ state = Transport(
65
+ model_type=model_type,
66
+ path_type=path_type,
67
+ loss_type=loss_type,
68
+ train_eps=train_eps,
69
+ sample_eps=sample_eps,
70
+ snr_type=snr_type
71
+ )
72
+
73
+ return state
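A minimal usage sketch of the factory above (the call mirrors the defaults handled in `create_transport`; how the returned `Transport` object is then used for training or sampling depends on `transport.py`, which is not shown here):

```python
from ldm.models.diffusion.transport import create_transport

# Velocity prediction on the linear interpolation path with uniform timestep sampling.
transport = create_transport(
    path_type="Linear",
    prediction="velocity",
    loss_weight=None,
    snr_type="uniform",
)
# With prediction="velocity" and a Linear/GVP path, train_eps and sample_eps resolve to 0,
# i.e. the flow is treated as stable over the whole time interval.
```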