yjhuangcd
committed on
Commit · 9965bf6
1 Parent(s): 9708aee
First commit
This view is limited to 50 files because it contains too many changes. See raw diff.
- README.md +136 -3
- compute_std.py +54 -0
- datasets/README.md +20 -0
- datasets/all_midi.csv +0 -0
- datasets/chunk_midi.py +72 -0
- datasets/filter_class.py +38 -0
- datasets/piano_roll_all.py +139 -0
- datasets/select_midi.py +74 -0
- diff_collage/README.md +3 -0
- diff_collage/__init__.py +5 -0
- diff_collage/avg_circle.py +64 -0
- diff_collage/avg_long.py +40 -0
- diff_collage/condind_circle.py +190 -0
- diff_collage/condind_long.py +147 -0
- diff_collage/generic_sampler.py +113 -0
- diff_collage/loss_helper.py +41 -0
- diff_collage/w_img.py +79 -0
- diff_collage/w_loss.py +433 -0
- environment.yml +282 -0
- guided_diffusion/__init__.py +3 -0
- guided_diffusion/condition_functions.py +174 -0
- guided_diffusion/dist_util.py +104 -0
- guided_diffusion/dit.py +983 -0
- guided_diffusion/embed_datasets.py +161 -0
- guided_diffusion/fp16_util.py +237 -0
- guided_diffusion/gaussian_diffusion.py +1400 -0
- guided_diffusion/logger.py +521 -0
- guided_diffusion/losses.py +77 -0
- guided_diffusion/midi_util.py +291 -0
- guided_diffusion/nn.py +170 -0
- guided_diffusion/pr_datasets_all.py +183 -0
- guided_diffusion/resample.py +154 -0
- guided_diffusion/respace.py +128 -0
- guided_diffusion/script_util.py +531 -0
- guided_diffusion/train_util.py +475 -0
- guided_diffusion/unet.py +906 -0
- load_utils.py +31 -0
- music_evaluation/README.md +22 -0
- music_evaluation/convert_to_wav.py +42 -0
- music_evaluation/demo.ipynb +0 -0
- music_evaluation/fad.py +38 -0
- music_evaluation/figaro/chord_recognition.py +247 -0
- music_evaluation/figaro/constants.py +47 -0
- music_evaluation/figaro/evaluate.py +268 -0
- music_evaluation/figaro/input_representation.py +655 -0
- music_evaluation/figaro/vocab.py +166 -0
- music_evaluation/mgeval/__init__.py +0 -0
- music_evaluation/mgeval/__init__.pyc +0 -0
- music_evaluation/mgeval/core.py +644 -0
- music_evaluation/mgeval/core.pyc +0 -0
README.md
CHANGED
@@ -1,3 +1,136 @@
# Symbolic Music Generation with Non-Differentiable Rule Guided Diffusion

This is the codebase for the paper: [Symbolic Music Generation with Non-Differentiable Rule Guided Diffusion](https://arxiv.org/abs/2402.14285).

We introduce a symbolic music generator with non-differentiable rule-guided diffusion models, drawing inspiration from stochastic control. For music demos, please visit our [project website](https://scg-rule-guided-music.github.io/).

<img align="center" src="rule_guided_music_gen.png" width="750">

## Set up the environment

- Put the pretrained VAE checkpoint under `taming-transformers/checkpoints`.
- Create the conda virtual environment via: `conda env create -f environment.yml`
- Activate the virtual environment: `conda activate guided`

## Download Pretrained Checkpoints
- Pretrained VAE checkpoint under `trained_models/VAE`: put it at `taming-transformers/checkpoints/all_onset/epoch_14.ckpt`.
- Pretrained diffusion model checkpoint under `trained_models/diffusion`: put it at `loggings/checkpoints/ema_0.9999_1200000.pt`.
- Pretrained classifiers for each rule under `trained_models/classifier`: put them under `loggings/classifier/`.

## Rule Guided Generation
All the rule-guidance configs are stored in `scripts/configs/`. `cond_demo` contains the configs we used to generate the demos for composer co-creation, `cond_table` contains the configs we used to create the table in the paper, and `edit` contains the configs for editing an existing excerpt.
For instance, to guide diffusion models on all of the rules simultaneously using both SCG and classifier guidance, use the config `scripts/configs/cond_table/all/scg_classifier_all.yml`. The results will be saved in `loggings/cond_table/all/scg_classifier_all`.

The config file contains the following fields:
| Field | Description |
|-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `target_rules` | Specify the desired attributes here. New rules can be added by writing rule programs in `music_rule_guidance/music_rules.py` and updating `rule_maps`. |
| `guidance` | Hyper-parameters for guidance, including the classifier config for classifier guidance and when to start or stop using guidance. |
| `scg` | Hyper-parameters for stochastic control guidance (SCG, ours). |
| `sampling` | Hyper-parameters for diffusion model sampling. Options include using DDIM or sampling longer sequences with `diff_collage`. |

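For orientation, a minimal config sketch might look as follows. The four top-level fields match the table above; the specific rule key, classifier path, and the key names and values inside each field are illustrative assumptions, not copies of the released configs (see `scripts/configs/` for the real ones).

```
# Hypothetical sketch: top-level fields follow the table above;
# the keys and values inside each field are illustrative assumptions.
target_rules:
  note_density: 8.0                 # assumed rule key and target label
guidance:
  classifier_config: loggings/classifier/note_density.yml  # assumed path
  start_step: 1000                  # assumed: when to start using guidance
  stop_step: 0                      # assumed: when to stop using guidance
scg:
  num_candidates: 16                # assumed: SCG branches sampled per step
sampling:
  use_ddim: false                   # assumed flag
  diff_collage: false               # assumed flag
```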
To run the rule-guided sampling code, use the following script:
```
python sample_rule.py \
  --config_path <config_file> \
  --batch_size 4 \
  --num_samples 20 \
  --data_dir <data_dir> \
  --model DiTRotary_XL_8 \
  --model_path loggings/checkpoints/ema_0.9999_1200000.pt \
  --image_size 128 16 \
  --in_channels 4 \
  --scale_factor 1.2465 \
  --class_cond True \
  --num_classes 3 \
  --class_label 1
```
The meaning of each hyper-parameter is listed as follows:
| Hyper-parameter | Description |
|-------------------|-------------------------------------------------------------------------------------------------------------------|
| `config_path` | Path to the configuration file, e.g., `scripts/configs/cond_demo/demo1.yml`. |
| `batch_size` | Batch size for generation. Default: 4. |
| `num_samples` | How many samples to generate in total. Default: 20. |
| `data_dir` | Optional: directory where data is stored. Used to extract rule labels from existing music excerpts. Not needed if target rule labels are given (leave at the default value in that case). |
| `model` | Model backbone for the diffusion model. Default: DiTRotary_XL_8. |
| `model_path` | Path to the pretrained diffusion model. |
| `image_size` | Size of the generated piano roll in latent space (for 10.24 s, the size is 128x16). |
| `in_channels` | Number of channels in the latent space of the pretrained VAE model. Default: 4. |
| `scale_factor` | 1 / std of the latents. You can use `compute_std.py` to compute it for a pretrained VAE. Default: 1.2465 (computed for the VAE checkpoint that we provide). |
| `class_cond` | Whether to condition on music genre (datasets: Maestro, Muscore and Pop) for generation. Default: True. |
| `num_classes` | Number of classes (datasets). We trained on 3 datasets. |
| `class_label` | 0 for Maestro (classical performance), 1 for Muscore (classical sheet music), 2 for Pop. |

To guide on new rules in addition to those we considered (pitch histogram, note density and chord progression),
add the rule function to `music_rule_guidance/music_rules.py` and register it in `FUNC_DICT` in `rule_maps.py`.
In addition, you need to pick a loss function for the newly added rule and add it to `LOSS_DICT` in `rule_maps.py`.
Then you can use the key in `FUNC_DICT` for `target_rules` in the config file, as sketched below.
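As a rough sketch (the real signatures live in `rule_maps.py` and `music_rules.py`; the function shapes below are assumptions based on the description above), registering a new rule could look like:

```
# Hypothetical sketch -- check rule_maps.py for the actual conventions.
import torch

def note_range(piano_roll: torch.Tensor) -> torch.Tensor:
    """Assumed rule function: span between highest and lowest active pitch."""
    active = piano_roll.sum(dim=-1).nonzero()
    return (active.max() - active.min()).float()

def mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Assumed loss pairing for the new rule."""
    return (pred - target) ** 2

# In rule_maps.py (dict names from the text above; exact structure assumed):
FUNC_DICT = {"note_range": note_range}
LOSS_DICT = {"note_range": mse_loss}
```

The key `note_range` would then be usable under `target_rules` in a config file.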

This framework also supports editing an existing excerpt:
```
python scripts/edit.py \
  --config_path scripts/configs/edit_table/nd_500_num16.yml \
  --batch_size 2 \
  --num_samples 20 \
  --data_dir <data_dir> \
  --model DiTRotary_XL_8 \
  --model_path loggings/checkpoints/ema_0.9999_1200000.pt \
  --image_size 128 16 \
  --in_channels 4 \
  --scale_factor 1.2465 \
  --class_cond True \
  --num_classes 3 \
  --class_label 2
```

## Train diffusion model for music generation
To train a diffusion model for symbolic music generation, use the following script:
```
mpiexec -n 8 python scripts/train_dit.py \
  --dir <loggings/save_dir> \
  --data_dir <datasets/data_dir> \
  --model DiTRotary_XL_8 \
  --image_size 128 16 \
  --in_channels 4 \
  --batch_size 32 \
  --encode_rep 4 \
  --shift_size 4 \
  --pr_image_size 2560 \
  --microbatch_encode -1 \
  --class_cond True \
  --num_classes 3 \
  --scale_factor <scale_factor> \
  --fs 100 \
  --save_interval 10000 \
  --resume <dir to the last saved model ckpt>
```
The meaning of each hyper-parameter is listed as follows:
| Hyper-parameter | Description |
|-------------------|-------------------------------------------------------------------------------------------------------------------|
| `mpiexec -n 8` | Multi-GPU training, using 8 GPUs. |
| `embed_model_name`| VAE config; default is `kl/f8-all-onset`. |
| `embed_model_ckpt`| Directory of the VAE checkpoint. |
| `dir` | Directory to save diffusion checkpoints and generated samples. |
| `data_dir` | Where you store your piano roll data. |
| `model` | Diffusion model name (config), e.g., `DiTRotary_XL_8`: a DiT XL model with 1D patch_size=8 (seq_len=256). |
| `image_size` | Latent space size (for 10.24 s, the size is 128x16). |
| `in_channels` | Number of latent space channels (default is 4). |
| `batch_size` | Batch size on each GPU. The effective batch size is batch_size * num_GPUs; aim for an effective batch size of 256. |
| `encode_rep` | How many excerpts to create from a long sequence. `batch_size` needs to be greater than or equal to `encode_rep`. Default: 4. |
| `shift_size` | Time shift between successive music excerpts from a long sequence. Default: 4. |
| `pr_image_size` | Length of a long sequence; needs to be compatible with `encode_rep` and `shift_size`. For example, for `encode_rep=4` and `shift_size=4`, the excerpts created from a long sequence are 1-8, 5-12, 9-16 and 13-20. Therefore `pr_image_size = 20 x 128 = 2560` (see the worked example after the table). |
| `class_cond` | Train with class conditioning (score(x,y), where y is the class). |
| `num_classes` | Number of classes in your conditioning, e.g., 3: 0 for Maestro, 1 for Muscore, 2 for Pop. |
| `scale_factor` | 1 / std of the latents. You can use `compute_std.py` to compute it for a pretrained VAE. |
| `fs` | Time resolution is 1 / fs. |
| `save_interval` | Frequency of saving checkpoints, e.g., every 10k steps. |

+
## References
|
132 |
+
This repository is based on [openai/guided-diffusion](https://github.com/openai/guided-diffusion), with modifications for data representation, guidance algorithm and architecture improvements.
|
133 |
+
- The VAE architecture is modified upon [taming-transformers](https://github.com/CompVis/taming-transformers).
|
134 |
+
- The DiT architecture is modified upon [DiT](https://github.com/facebookresearch/DiT).
|
135 |
+
- Music evaluation code is adapted from [mgeval](https://github.com/RichardYang40148/mgeval) and [figaro](https://github.com/dvruette/figaro).
|
136 |
+
- MIDI to piano roll representation is adapted from [pretty_midi](https://github.com/craffel/pretty-midi).
|
compute_std.py
ADDED
@@ -0,0 +1,54 @@
import os
import torch
from load_utils import load_model
from guided_diffusion import dist_util
from guided_diffusion.gaussian_diffusion import _encode, _decode
from guided_diffusion.pr_datasets_all import load_data
from tqdm import tqdm
from guided_diffusion.midi_util import visualize_full_piano_roll, save_piano_roll_midi
from music_rule_guidance import music_rules
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
plt.rcParams["figure.figsize"] = (20, 3)
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300


MODEL_NAME = 'kl/f8-all-onset'
MODEL_CKPT = 'taming-transformers/checkpoints/all_onset/epoch_14.ckpt'

TOTAL_BATCH = 256


def main():

    data = load_data(
        data_dir='datasets/all-len-40-gap-16-no-empty_train.csv',
        batch_size=32,
        class_cond=True,
        image_size=1024,
        deterministic=False,
        fs=100,
    )
    embed_model = load_model(MODEL_NAME, MODEL_CKPT)
    del embed_model.loss
    embed_model.to(dist_util.dev())
    embed_model.eval()

    z_list = []
    with torch.no_grad():
        for _ in tqdm(range(TOTAL_BATCH)):
            batch, cond = next(data)
            batch = batch.to(dist_util.dev())
            enc = _encode(batch, embed_model, scale_factor=1.)
            z_list.append(enc.cpu())
    latents = torch.concat(z_list, dim=0)
    scale_factor = 1. / latents.flatten().std().item()
    print(f"scale_factor: {scale_factor}")
    print("done")


if __name__ == "__main__":
    main()
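A quick sanity check of the printed factor, assuming `_encode` multiplies the latents by `scale_factor` (as its use above with `scale_factor=1.` suggests): rescaling the collected latents by the result should give approximately unit standard deviation.

```
# Could be appended to main() above as a sanity check (assumption as noted):
scaled = latents * scale_factor
print(scaled.flatten().std())  # should be ~1.0
```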
datasets/README.md
ADDED
@@ -0,0 +1,20 @@
# Creating data representation for symbolic music

This directory contains instructions and scripts for creating the training dataset.
Note that you do not need to prepare a dataset if you want to generate music and already have target rule labels in mind.
You need to prepare a dataset if you want to train a model or extract rule labels from existing music excerpts.

We train our diffusion model on three datasets: [Maestro](https://magenta.tensorflow.org/datasets/maestro#v300) (classical piano performance), Muscore (crawled from the Muscore website), and Pop ([Pop1k7](https://drive.google.com/file/d/1qw_tVUntblIg4lW16vbpjLXVndkVtgDe/view) and [Pop909](https://github.com/music-x-lab/POP909-Dataset)).
You can download the data and put the MIDI files into the corresponding folder: `maestro`, `muscore` and `pop`.

Then run `piano_roll_all.py` to create piano roll excerpts from the dataset.

The above script creates piano roll excerpts of 1.28 s.
To create music of 10.24 s for training, return to the main folder and run `rearrange_pr_data.py` to concatenate shorter piano rolls into longer ones.
The processed data will be saved in `datasets/all-len-40-gap-16-no-empty` by default, along with two CSV files,
`all-len-40-gap-16-no-empty_train.csv` and `all-len-40-gap-16-no-empty_test.csv`, that list the filenames.

If you want to extract rule labels from the piano rolls and condition on a specific dataset, you need to create a CSV file for each dataset using:
```
python filter_class.py --file_path all-len-40-gap-16-no-empty_test.csv --class_label <class label>
```
datasets/all_midi.csv
ADDED
The diff for this file is too large to render.
See raw diff
datasets/chunk_midi.py
ADDED
@@ -0,0 +1,72 @@
import os
import pretty_midi
import argparse

def chunk_midi(input_path, output_dir, chunk_length=10.24):
    # Ensure the output directory exists
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for midi_file_name in os.listdir(input_path):
        if not (midi_file_name.endswith('.midi') or midi_file_name.endswith('.mid')):
            continue  # Skip non-MIDI files

        full_path = os.path.join(input_path, midi_file_name)
        try:
            midi_data = pretty_midi.PrettyMIDI(full_path)
        except Exception as e:
            print(f"Error processing {midi_file_name}: {e}")
            continue  # Skip to the next file if an error occurs

        end_time = midi_data.get_end_time()  # Get end time directly with pretty_midi
        num_chunks = int(end_time // chunk_length) + (1 if end_time % chunk_length > 0 else 0)

        base_name, file_extension = os.path.splitext(midi_file_name)

        for i in range(num_chunks):
            start_time = i * chunk_length
            segment_end_time = min((i + 1) * chunk_length, midi_data.get_end_time())

            # Create a new MIDI object for each chunk
            chunk_midi_data = pretty_midi.PrettyMIDI()

            # Merge non-drum instruments into a single instrument
            merged_instrument = pretty_midi.Instrument(program=0, is_drum=False)

            for instrument in midi_data.instruments:
                if not instrument.is_drum:
                    for note in instrument.notes:
                        if start_time <= note.start < segment_end_time:
                            # Shift the note start and end times to start at 0
                            new_note = pretty_midi.Note(
                                velocity=note.velocity,
                                pitch=note.pitch,
                                start=note.start - start_time,
                                end=note.end - start_time
                            )
                            merged_instrument.notes.append(new_note)
                else:
                    # Drum instrument: copy the notes into a new instrument so the
                    # original MIDI data is not shifted in place for later chunks
                    new_drum_instrument = pretty_midi.Instrument(program=instrument.program, is_drum=True, name=instrument.name)
                    new_drum_instrument.notes = [
                        pretty_midi.Note(velocity=note.velocity, pitch=note.pitch,
                                         start=note.start - start_time, end=note.end - start_time)
                        for note in instrument.notes if start_time <= note.start < segment_end_time
                    ]
                    chunk_midi_data.instruments.append(new_drum_instrument)

            # Add the merged instrument to the MIDI object
            chunk_midi_data.instruments.append(merged_instrument)

            # Save the chunk with the same extension as the original file
            new_midi_name = "{}_{}{}".format(base_name, i, file_extension)
            chunk_midi_data.write(os.path.join(output_dir, new_midi_name))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Chunk MIDI files into specified lengths.")
    parser.add_argument("--input_path", type=str, help="Path to the directory containing the MIDI files to chunk.")
    parser.add_argument("--output_dir", type=str, help="Path to the directory where the chunked MIDI files will be saved.")
    parser.add_argument("--chunk_length", type=float, default=10.24, help="Length to chunk the MIDI files to (s).")
    args = parser.parse_args()

    chunk_midi(args.input_path, args.output_dir, chunk_length=args.chunk_length)
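For example, to split a folder of MIDI files into 10.24 s chunks (the folder names here are illustrative; `maestro` matches the layout described in `datasets/README.md`):

```
python datasets/chunk_midi.py --input_path maestro --output_dir maestro_chunks --chunk_length 10.24
```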
datasets/filter_class.py
ADDED
@@ -0,0 +1,38 @@
import pandas as pd
import argparse

def filter_and_save_csv(file_path, class_label):
    """
    Filters a CSV file to keep only the rows where the 'classes' column equals the specified class label.
    Saves the filtered DataFrame to a new CSV file.

    :param file_path: Path to the original CSV file.
    :param class_label: The class label to filter by.
    """
    # Read the CSV file
    df = pd.read_csv(file_path)

    # Keep only the rows where 'classes' equals the specified class_label
    filtered_df = df[df['classes'] == class_label]

    # Save the filtered DataFrame to a new CSV file.
    # The new file name is the original file name with '_cls_<class_label>' appended before the file extension.
    new_file_path = file_path.replace('.csv', f'_cls_{class_label}.csv')
    filtered_df.to_csv(new_file_path, index=False)

    print(f"Filtered CSV saved as: {new_file_path}")

def main():
    # Set up the argument parser
    parser = argparse.ArgumentParser(description="Filter a CSV file by class and save to a new file.")
    parser.add_argument("--file_path", type=str, help="Path to the original CSV file")
    parser.add_argument("--class_label", type=int, help="The class label to filter by")

    # Parse arguments
    args = parser.parse_args()

    # Call the function with the provided arguments
    filter_and_save_csv(args.file_path, args.class_label)

if __name__ == "__main__":
    main()
datasets/piano_roll_all.py
ADDED
@@ -0,0 +1,139 @@
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import pretty_midi
import pandas as pd
import numpy as np
from tqdm import tqdm
import math
from music_rule_guidance.music_rules import MAX_PIANO, MIN_PIANO

import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (6, 3)
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300

CC_SUSTAIN_PEDAL = 64


def split_csv(csv_path='merged_midi.csv'):
    # Separate training / validation / testing files
    df = pd.read_csv(csv_path)
    save_name = csv_path[:csv_path.rfind('.csv')]
    for split in ['train', 'validation', 'test']:
        path = os.path.join(save_name, split + '.csv')
        df_sub = df[df.split == split]
        df_sub.to_csv(path, index=False)
    return


def quantize_pedal(value, num_bins=8):
    """Quantize an integer value from 0 to 127 into num_bins bins and return the center value of the bin."""
    if value < 0 or value > 127:
        raise ValueError("Value should be between 0 and 127")
    # Determine bin size
    bin_size = 128 // num_bins  # 16
    # Quantize the value
    bin_index = value // bin_size
    bin_center = bin_size * bin_index + bin_size // 2
    # Handle edge case for the last bin
    if bin_center > 127:
        bin_center = 127
    return bin_center


def get_full_piano_roll(midi_data, fs, show=False):
    # Do not process the sustain pedal here
    piano_roll, onset_roll = midi_data.get_piano_roll(fs=fs, pedal_threshold=None, onset=True)
    # Save the pedal roll explicitly
    pedal_roll = np.zeros_like(piano_roll)
    # Process the pedal
    for instru in midi_data.instruments:
        pedal_changes = [_e for _e in instru.control_changes if _e.number == CC_SUSTAIN_PEDAL]
        for cc in pedal_changes:
            time_now = int(cc.time * fs)
            if time_now < pedal_roll.shape[-1]:
                # Need to distinguish a control_change of 0 from the background 0; with quantization, 0-16 becomes 8.
                # In Muscore files, 0 is immediately followed by 127, so shift by one column.
                if pedal_roll[MIN_PIANO, time_now] != 0. and abs(pedal_roll[MIN_PIANO, time_now] - cc.value) > 64:
                    # Use a shift of 2 here to avoid missing the change when using interpolation augmentation
                    pedal_roll[MIN_PIANO:MAX_PIANO + 1, min(time_now + 2, pedal_roll.shape[-1] - 1)] = quantize_pedal(cc.value)
                else:
                    pedal_roll[MIN_PIANO:MAX_PIANO + 1, time_now] = quantize_pedal(cc.value)
    full_roll = np.concatenate((piano_roll[None], onset_roll[None], pedal_roll[None]), axis=0)
    if show:
        plt.imshow(piano_roll[::-1, :1024], vmin=0, vmax=127)
        plt.show()
        plt.imshow(pedal_roll[::-1, :1024], vmin=0, vmax=127)
        plt.show()
    return full_roll


def preprocess_midi(target='merged', csv_path='merged_midi.csv', fs=100., image_size=128, overlap=False, show=False):
    # Get piano rolls from the MIDI files
    df = pd.read_csv(csv_path)
    total_pieces = len(df)
    if not os.path.exists(target):
        os.makedirs(target)
    for split in ['train', 'test']:
        path = os.path.join(target, split)
        if not os.path.exists(path):
            os.makedirs(path)
    for i in tqdm(range(total_pieces)):
        midi_filename = df.midi_filename[i]
        split = df.split[i]
        dataset = df.dataset[i]
        path = os.path.join(target, split)
        midi_data = pretty_midi.PrettyMIDI(os.path.join(dataset, midi_filename))
        full_roll = get_full_piano_roll(midi_data, fs=fs, show=show)
        for j in range(0, full_roll.shape[-1], image_size):
            if j + image_size <= full_roll.shape[-1]:
                full_roll_excerpt = full_roll[:, :, j:j + image_size]
            else:
                full_roll_excerpt = np.zeros((3, full_roll.shape[1], image_size))  # 3 x 128 x image_size
                full_roll_excerpt[:, :, : full_roll.shape[-1] - j] = full_roll[:, :, j:]
            empty_roll = math.isclose(full_roll_excerpt.max(), 0.)
            if not empty_roll:
                # Find the last '/' in the string
                last_slash_index = midi_filename.rfind('/')
                # Find the '.mid' in the string
                dot_mid_index = midi_filename.rfind('.mid')
                # Extract the substring between the last '/' and '.mid'
                save_name = midi_filename[last_slash_index + 1:dot_mid_index]
                full_roll_excerpt = full_roll_excerpt.astype(np.uint8)
                np.save(os.path.join(path, save_name + '_' + str(j // image_size) + '.npy'), full_roll_excerpt)
                # Save with the dataset name to avoid duplicate file names for VAE training
                # np.save(os.path.join(path, dataset + '_' + save_name + '_' + str(j // image_size) + '.npy'), full_roll_excerpt)
        if overlap:
            for j in range(image_size // 2, full_roll.shape[-1], image_size):  # overlap with image_size // 2
                if j + image_size <= full_roll.shape[-1]:
                    full_roll_excerpt = full_roll[:, :, j:j + image_size]
                else:
                    full_roll_excerpt = np.zeros((3, full_roll.shape[1], image_size))
                    full_roll_excerpt[:, :, : full_roll.shape[-1] - j] = full_roll[:, :, j:]
                empty_roll = math.isclose(full_roll_excerpt.max(), 0.)
                if not empty_roll:
                    last_slash_index = midi_filename.rfind('/')
                    dot_mid_index = midi_filename.rfind('.mid')
                    save_name = midi_filename[last_slash_index + 1:dot_mid_index]
                    full_roll_excerpt = full_roll_excerpt.astype(np.uint8)
                    np.save(os.path.join(path, 'shift_' + save_name + '_' + str(j // image_size) + '.npy'), full_roll_excerpt)
                    # Save with the dataset name to avoid duplicate file names for VAE training
                    # np.save(os.path.join(path, dataset + '_' + 'shift_' + save_name + '_' + str(j // image_size) + '.npy'), full_roll_excerpt)
    return


def main():
    # Create fs=100 1.28 s datasets without overlap (can be rearranged)
    preprocess_midi(target='all-128-fs100', csv_path='all_midi.csv', fs=100, image_size=128, overlap=False, show=False)
    # Create fs=100 2.56 s datasets with overlap (used for VAE training); when loading, select 1.28 s from 2.56 s
    # preprocess_midi(target='all-256-overlap-fs100', csv_path='all_midi.csv', fs=100, image_size=256, overlap=True,
    #                 show=False)
    # Create fs=12.5 (0.08 s) for the pixel-space diffusion model, rearrangement with length 2
    # preprocess_midi(target='all-128-fs12.5', csv_path='all_midi.csv', fs=12.5, image_size=128, overlap=False,
    #                 show=False)


if __name__ == "__main__":
    main()
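As a quick check of the pedal quantization above (num_bins=8, so bin_size=16):

```
# quantize_pedal(100): bin_index = 100 // 16 = 6, center = 6*16 + 8 = 104
assert quantize_pedal(100) == 104
# quantize_pedal(127): bin_index = 7, center = 7*16 + 8 = 120 (<= 127, no clamp)
assert quantize_pedal(127) == 120
```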
datasets/select_midi.py
ADDED
@@ -0,0 +1,74 @@
import os
import pretty_midi
import argparse
import random

def select_midi(input_path, output_dir, select_length=10.24):
    # Ensure the output directory exists
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for midi_file_name in os.listdir(input_path):
        if not (midi_file_name.endswith('.midi') or midi_file_name.endswith('.mid')):
            continue  # Skip non-MIDI files

        full_path = os.path.join(input_path, midi_file_name)
        try:
            midi_data = pretty_midi.PrettyMIDI(full_path)
        except Exception as e:
            print(f"Error processing {midi_file_name}: {e}")
            continue  # Skip to the next file if an error occurs

        end_time = midi_data.get_end_time()  # Get end time directly with pretty_midi

        if select_length > end_time:
            print("Segment length is longer than the MIDI file duration.")
            continue

        start_time = random.uniform(0, end_time - select_length)
        segment_end_time = start_time + select_length

        # Create a new MIDI object for the selected segment
        chunk_midi_data = pretty_midi.PrettyMIDI()

        # Merge non-drum instruments into a single instrument
        merged_instrument = pretty_midi.Instrument(program=0, is_drum=False)

        for instrument in midi_data.instruments:
            if not instrument.is_drum:
                for note in instrument.notes:
                    if start_time <= note.start < segment_end_time:
                        # Shift the note start and end times to start at 0
                        new_note = pretty_midi.Note(
                            velocity=note.velocity,
                            pitch=note.pitch,
                            start=note.start - start_time,
                            end=note.end - start_time
                        )
                        merged_instrument.notes.append(new_note)
            else:
                # If it's a drum instrument, just adjust the note times and append it
                new_drum_instrument = pretty_midi.Instrument(program=instrument.program, is_drum=True,
                                                             name=instrument.name)
                new_drum_instrument.notes = [note for note in instrument.notes if
                                             start_time <= note.start < segment_end_time]
                for note in new_drum_instrument.notes:
                    note.start -= start_time
                    note.end -= start_time
                chunk_midi_data.instruments.append(new_drum_instrument)

        # Add the merged instrument to the MIDI object
        chunk_midi_data.instruments.append(merged_instrument)

        # Save the segment with the same name as the original file
        chunk_midi_data.write(os.path.join(output_dir, midi_file_name))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Select a random fixed-length segment from each MIDI file.")
    parser.add_argument("--input_path", type=str, help="Path to the directory containing the MIDI files.")
    parser.add_argument("--output_dir", type=str, help="Path to the directory where the selected segments will be saved.")
    parser.add_argument("--select_length", type=float, default=10.24, help="Length of the segment to select (s).")
    args = parser.parse_args()

    select_midi(args.input_path, args.output_dir, select_length=args.select_length)
diff_collage/README.md
ADDED
@@ -0,0 +1,3 @@
# Diff Collage

This is an implementation of the [DiffCollage](https://arxiv.org/abs/2303.17076) paper. We use DiffCollage to generate long sequences that follow a certain structure (e.g., a loop).
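A minimal usage sketch, assuming a CUDA device and a wrapped diffusion model that predicts epsilon given `(x, t, y)`; the class and sampler signatures follow the files in this directory, while the dummy `eps_fn` and all shapes are illustrative stand-ins:

```
import torch as th
from diff_collage import CondIndSimple, generic_sampler

def eps_fn(x, scalar_t, y=None):
    # Stand-in for a wrapped diffusion model; returns dummy noise here.
    return th.randn_like(x)

# Four 128-wide segments with half-width (64) overlap -> total width 320.
work = CondIndSimple(shape=(4, 16, 128), eps_scalar_t_fn=eps_fn, num_img=4, overlap_size=64)
traj, losses = generic_sampler(
    work.generate_xT(1),                 # start from sigma = 80 noise
    rev_ts=work.rev_ts(50, ts_order=7),  # decreasing EDM-style time schedule
    noise_fn=work.noise,
    x0_pred_fn=work.x0_fn,
)
long_sample = traj["xt"][-1]             # (1, 4, 16, 320)
```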
diff_collage/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .w_loss import *
from .generic_sampler import *
from .condind_long import CondIndSimple
from .condind_circle import CondIndCircle
from .avg_long import AvgLong
diff_collage/avg_circle.py
ADDED
@@ -0,0 +1,64 @@
import numpy as np
import torch as th
from einops import rearrange

from .generic_sampler import SimpleWork
from .w_img import split_wimg, avg_merge_wimg

class AvgCircle(SimpleWork):
    def __init__(self, shape, eps_scalar_t_fn, num_img, overlap_size=32):
        c, h, w = shape
        self.base_img_w = w
        self.overlap_size = overlap_size
        self.num_img = num_img
        final_img_w = w * num_img - self.overlap_size * num_img
        super().__init__((c, h, final_img_w), self.get_eps_t_fn(eps_scalar_t_fn))

    def get_eps_t_fn(self, eps_scalar_t_fn):
        def eps_t_fn(long_x, scalar_t, enable_grad=False):
            # Randomly rotate the circle so the seam location varies between calls
            shift = np.random.randint(self.base_img_w)
            long_x = th.cat(
                [
                    long_x[:, :, :, shift:],
                    long_x[:, :, :, :shift]
                ],
                dim=-1
            )

            # Wrap the first overlap around the end to close the circle
            x = th.cat(
                [
                    long_x,
                    long_x[:, :, :, :self.overlap_size]
                ],
                dim=-1,
            )
            xs, _overlap = split_wimg(x, self.num_img, rtn_overlap=True)
            assert _overlap == self.overlap_size
            full_eps = eps_scalar_t_fn(xs, scalar_t, enable_grad)  # ((b n), c, h, w)

            eps = avg_merge_wimg(full_eps, self.overlap_size, n=self.num_img)
            eps = th.cat(
                [
                    (eps[:, :, :, :self.overlap_size] + eps[:, :, :, -self.overlap_size:]) / 2.0,
                    eps[:, :, :, self.overlap_size:-self.overlap_size]
                ],
                dim=-1
            )
            assert eps.shape == long_x.shape
            # Undo the random rotation
            return th.cat(
                [
                    eps[:, :, :, -shift:],
                    eps[:, :, :, :-shift],
                ],
                dim=-1
            )

        return eps_t_fn

    def x0_fn(self, xt, scalar_t, enable_grad=False):
        cur_eps = self.eps_scalar_t_fn(xt, scalar_t, enable_grad)
        x0 = xt - scalar_t * cur_eps
        return x0, {}, {
            "x0": x0.cpu()
        }
diff_collage/avg_long.py
ADDED
@@ -0,0 +1,40 @@
import torch as th
from einops import rearrange

from .generic_sampler import SimpleWork
from .w_img import split_wimg, avg_merge_wimg

class AvgLong(SimpleWork):
    def __init__(self, shape, eps_scalar_t_fn, num_img, overlap_size=32):
        c, h, w = shape
        assert overlap_size == w // 2
        self.overlap_size = overlap_size
        self.num_img = num_img
        final_img_w = w * num_img - self.overlap_size * (num_img - 1)
        super().__init__((c, h, final_img_w), self.get_eps_t_fn(eps_scalar_t_fn))

    def loss(self, x):
        x1, x2 = x[:-1], x[1:]
        return th.sum(
            (th.abs(x1[:, :, :, -self.overlap_size:] - x2[:, :, :, :self.overlap_size])) ** 2,
            dim=(1, 2, 3),
        )

    def get_eps_t_fn(self, eps_scalar_t_fn):
        def eps_t_fn(long_x, scalar_t, y=None):
            xs = split_wimg(long_x, self.num_img, rtn_overlap=False)
            if y is not None:
                y = y.repeat_interleave(self.num_img)
            scalar_t = scalar_t.repeat_interleave(self.num_img)
            full_eps = eps_scalar_t_fn(xs, scalar_t, y=y)  # ((b n), c, h, w)
            full_eps = rearrange(
                full_eps,
                "(b n) c h w -> n b c h w", n=self.num_img
            )

            whole_eps = rearrange(
                full_eps,
                "n b c h w -> (b n) c h w"
            )
            return avg_merge_wimg(whole_eps, self.overlap_size, n=self.num_img, is_avg=False)
        return eps_t_fn
diff_collage/condind_circle.py
ADDED
@@ -0,0 +1,190 @@
import torch as th
from einops import rearrange

from .generic_sampler import SimpleWork
from .w_img import split_wimg, avg_merge_wimg

class CondIndCircle(SimpleWork):
    def __init__(self, shape, eps_scalar_t_fn, num_img, overlap_size=32):
        c, h, w = shape
        assert overlap_size == w // 2
        self.overlap_size = overlap_size
        self.num_img = num_img
        final_img_w = w * num_img - self.overlap_size * num_img
        super().__init__((c, h, final_img_w), self.get_eps_t_fn(eps_scalar_t_fn))

    def circle_split(self, in_x):
        long_x = th.cat(
            [
                in_x,
                in_x[:, :, :, :self.overlap_size],
            ],
            dim=-1
        )
        xs = split_wimg(long_x, self.num_img, rtn_overlap=False)
        return xs

    def circle_merge(self, xs, overlap_size=None):
        if overlap_size is None:
            overlap_size = self.overlap_size
        long_xs = avg_merge_wimg(xs, overlap_size, n=self.num_img, is_avg=True)
        return th.cat(
            [
                (
                    long_xs[:, :, :, :overlap_size] + long_xs[:, :, :, -overlap_size:]
                ) / 2.0,
                long_xs[:, :, :, overlap_size:-overlap_size]
            ],
            dim=-1
        )

    def get_eps_t_fn(self, eps_scalar_t_fn):
        def eps_t_fn(in_x, scalar_t, y=None):
            long_x = th.cat(
                [
                    in_x,
                    in_x[:, :, :, :self.overlap_size],
                ],
                dim=-1
            )
            xs = split_wimg(long_x, self.num_img, rtn_overlap=False)
            if y is not None:
                y = y.repeat_interleave(self.num_img)
            scalar_t = scalar_t.repeat_interleave(self.num_img)
            full_eps = eps_scalar_t_fn(xs, scalar_t, y=y)  # ((b n), c, h, w)
            full_eps = rearrange(
                full_eps,
                "(b n) c h w -> n b c h w", n=self.num_img
            )

            # calculate eps on the shared half windows
            half_eps = eps_scalar_t_fn(xs[:, :, :, -self.overlap_size:], scalar_t, y=y)  # ((b n), c, h, w//2)
            half_eps = rearrange(
                half_eps,
                "(b n) c h w -> n b c h w", n=self.num_img
            )

            half_eps[-1] = 0

            # subtract the overlap eps, following the DiffCollage factorization
            full_eps[:, :, :, :, -self.overlap_size:] = full_eps[:, :, :, :, -self.overlap_size:] - half_eps
            whole_eps = rearrange(
                full_eps,
                "n b c h w -> (b n) c h w"
            )
            long_eps = avg_merge_wimg(whole_eps, self.overlap_size, n=self.num_img, is_avg=False)
            return th.cat(
                [
                    (
                        long_eps[:, :, :, :self.overlap_size] + long_eps[:, :, :, -self.overlap_size:]
                    ) / 2.0,
                    long_eps[:, :, :, self.overlap_size:-self.overlap_size]
                ],
                dim=-1
            )
        return eps_t_fn


class CondIndCircleSR(SimpleWork):
    def __init__(self, shape, eps_scalar_t_fn, num_img, low_res, overlap_size=32):
        c, h, w = shape
        assert overlap_size == w // 2
        self.overlap_size = overlap_size
        self.low_overlap_size = low_res.shape[-2] // 2
        self.num_img = num_img
        final_img_w = w * num_img - self.overlap_size * num_img
        assert low_res.shape[-1] == self.low_overlap_size * num_img

        self.square_fn = self.get_square_sr_fn(eps_scalar_t_fn, low_res)
        self.half_fn = self.get_half_sr_fn(eps_scalar_t_fn, low_res)

        super().__init__((c, h, final_img_w), self.get_eps_t_fn())

    def circle_split(self, in_x, overlap_size=None):
        if overlap_size is None:
            overlap_size = self.overlap_size
        long_x = th.cat(
            [
                in_x,
                in_x[:, :, :, :overlap_size],
            ],
            dim=-1
        )
        xs = split_wimg(long_x, self.num_img, rtn_overlap=False)
        return xs

    def circle_merge(self, xs, overlap_size=None):
        if overlap_size is None:
            overlap_size = self.overlap_size
        long_xs = avg_merge_wimg(xs, overlap_size, n=self.num_img, is_avg=True)
        return th.cat(
            [
                (
                    long_xs[:, :, :, :overlap_size] + long_xs[:, :, :, -overlap_size:]
                ) / 2.0,
                long_xs[:, :, :, overlap_size:-overlap_size]
            ],
            dim=-1
        )

    def get_square_sr_fn(self, eps_fn, low_res):
        low_res = self.circle_split(low_res, self.low_overlap_size)
        def _fn(_x, _t, enable_grad):
            context = th.enable_grad if enable_grad else th.no_grad
            with context():
                vec_t = th.ones(_x.shape[0]).cuda() * _t
                rtn = eps_fn(_x, vec_t, low_res)
                rtn = rearrange(
                    rtn,
                    "(b n) c h w -> n b c h w", n=self.num_img
                )
            return rtn
        return _fn

    def get_half_sr_fn(self, eps_fn, low_res):
        low_res = self.circle_split(low_res, self.low_overlap_size)
        def _fn(_x, _t, enable_grad):
            context = th.enable_grad if enable_grad else th.no_grad
            with context():
                vec_t = th.ones(_x.shape[0]).cuda() * _t
                half_eps = eps_fn(_x[:, :, :, -self.overlap_size:], vec_t, low_res[:, :, :, -self.low_overlap_size:])
                half_eps = rearrange(
                    half_eps,
                    "(b n) c h w -> n b c h w", n=self.num_img
                )

                half_eps[-1] = 0
            return half_eps
        return _fn

    def get_eps_t_fn(self):
        def eps_t_fn(in_x, scalar_t, enable_grad=False):
            long_x = th.cat(
                [
                    in_x,
                    in_x[:, :, :, :self.overlap_size],
                ],
                dim=-1
            )
            xs = split_wimg(long_x, self.num_img, rtn_overlap=False)

            # full eps
            full_eps = self.square_fn(xs, scalar_t, enable_grad)
            # calculate half eps
            half_eps = self.half_fn(xs, scalar_t, enable_grad)

            full_eps[:, :, :, :, -self.overlap_size:] = full_eps[:, :, :, :, -self.overlap_size:] - half_eps
            whole_eps = rearrange(
                full_eps,
                "n b c h w -> (b n) c h w"
            )
            long_eps = avg_merge_wimg(whole_eps, self.overlap_size, n=self.num_img, is_avg=False)
            return th.cat(
                [
                    (
                        long_eps[:, :, :, :self.overlap_size] + long_eps[:, :, :, -self.overlap_size:]
                    ) / 2.0,
                    long_eps[:, :, :, self.overlap_size:-self.overlap_size]
                ],
                dim=-1
            )
        return eps_t_fn
diff_collage/condind_long.py
ADDED
@@ -0,0 +1,147 @@
import torch
import torch as th
from einops import rearrange

from .generic_sampler import SimpleWork
from .w_img import split_wimg, avg_merge_wimg

class CondIndSimple(SimpleWork):
    def __init__(self, shape, eps_scalar_t_fn, num_img, overlap_size=32):
        c, h, w = shape
        assert overlap_size == w // 2
        self.overlap_size = overlap_size
        self.num_img = num_img
        final_img_w = w * num_img - self.overlap_size * (num_img - 1)
        super().__init__((c, h, final_img_w), self.get_eps_t_fn(eps_scalar_t_fn))

    def loss(self, x):
        x1, x2 = x[:-1], x[1:]
        return th.sum(
            (th.abs(x1[:, :, :, -self.overlap_size:] - x2[:, :, :, :self.overlap_size])) ** 2,
            dim=(1, 2, 3),
        )

    def get_eps_t_fn(self, eps_scalar_t_fn):
        def eps_t_fn(long_x, scalar_t, y=None):
            xs = split_wimg(long_x, self.num_img, rtn_overlap=False)
            if y is not None:
                y = y.repeat_interleave(self.num_img)
            scalar_t = scalar_t.repeat_interleave(self.num_img)
            full_eps = eps_scalar_t_fn(xs, scalar_t, y=y)  # ((b n), c, h, w)
            full_eps = rearrange(
                full_eps,
                "(b n) c h w -> n b c h w", n=self.num_img
            )

            # calculate eps on the shared half windows
            half_eps = eps_scalar_t_fn(xs[:, :, :, -self.overlap_size:], scalar_t, y=y)  # ((b n), c, h, w//2)
            half_eps = rearrange(
                half_eps,
                "(b n) c h w -> n b c h w", n=self.num_img
            )

            half_eps[-1] = 0

            # subtract the overlap eps, following the DiffCollage factorization
            full_eps[:, :, :, :, -self.overlap_size:] = full_eps[:, :, :, :, -self.overlap_size:] - half_eps
            whole_eps = rearrange(
                full_eps,
                "n b c h w -> (b n) c h w"
            )
            return avg_merge_wimg(whole_eps, self.overlap_size, n=self.num_img, is_avg=False)
        return eps_t_fn


class CondIndSR(SimpleWork):
    def __init__(self, shape, eps_scalar_t_fn, num_img, low_res, overlap_size=128):
        c, h, w = shape
        assert overlap_size == w // 2
        self.overlap_size = overlap_size
        self.low_overlap_size = low_res.shape[-2] // 2
        self.num_img = num_img
        final_img_w = w * num_img - self.overlap_size * (num_img - 1)
        assert low_res.shape[-1] == self.low_overlap_size * (num_img + 1)

        self.square_fn = self.get_square_sr_fn(eps_scalar_t_fn, low_res)
        self.half_fn = self.get_half_sr_fn(eps_scalar_t_fn, low_res)

        super().__init__((c, h, final_img_w), self.get_eps_t_fn())

    def get_square_sr_fn(self, eps_fn, low_res):
        low_res = split_wimg(low_res, self.num_img, False)
        def _fn(_x, _t, enable_grad):
            context = th.enable_grad if enable_grad else th.no_grad
            with context():
                vec_t = th.ones(_x.shape[0]).cuda() * _t
                rtn = eps_fn(_x, vec_t, low_res)
                rtn = rearrange(
                    rtn,
                    "(b n) c h w -> n b c h w", n=self.num_img
                )
            return rtn
        return _fn

    def get_half_sr_fn(self, eps_fn, low_res):
        low_res = split_wimg(low_res, self.num_img, False)
        def _fn(_x, _t, enable_grad):
            context = th.enable_grad if enable_grad else th.no_grad
            with context():
                vec_t = th.ones(_x.shape[0]).cuda() * _t
                half_eps = eps_fn(_x[:, :, :, -self.overlap_size:], vec_t, low_res[:, :, :, -self.low_overlap_size:])
                half_eps = rearrange(
                    half_eps,
                    "(b n) c h w -> n b c h w", n=self.num_img
                )

                half_eps[-1] = 0
            return half_eps
        return _fn

    def get_eps_t_fn(self):
        def eps_t_fn(in_x, scalar_t, enable_grad=False):
            xs = split_wimg(in_x, self.num_img, rtn_overlap=False)

            # full eps
            full_eps = self.square_fn(xs, scalar_t, enable_grad)
            # calculate half eps
            half_eps = self.half_fn(xs, scalar_t, enable_grad)

            full_eps[:, :, :, :, -self.overlap_size:] = full_eps[:, :, :, :, -self.overlap_size:] - half_eps
            whole_eps = rearrange(
                full_eps,
                "n b c h w -> (b n) c h w"
            )
            out_eps = avg_merge_wimg(whole_eps, self.overlap_size, n=self.num_img, is_avg=False)
            return out_eps
        return eps_t_fn


# class CondIndLong(SimpleWork):
#     def __init__(self, shape, eps_scalar_t_fn, overlap_size=32):
#         super().__init__(shape, eps_scalar_t_fn)
#         self.overlap_size = overlap_size

#     def loss(self, x):
#         x1, x2 = x[:-1], x[1:]
#         return th.sum(
#             (th.abs(x1[:, :, :, -self.overlap_size:] - x2[:, :, :, :self.overlap_size])) ** 2,
#             dim=(1, 2, 3),
#         )

#     def generate_xT(self, n):
#         white_noise = th.randn((n, *self.shape)).cuda()
#         return self.noise(white_noise, None) * 80.0

#     def noise(self, xt, scalar_t):
#         del scalar_t
#         noise = th.randn_like(xt)
#         b, _, _, w = xt.shape
#         final_img_w = w * b - self.overlap_size * (b - 1)
#         noise = rearrange(noise, "(t n) c h w -> t c h (n w)", t=1)[:, :, :, :final_img_w]
#         noise = split_wimg(noise, b, rtn_overlap=False)
#         return noise

#     def merge(self, xs):
#         return avg_merge_wimg(xs, self.overlap_size)
diff_collage/generic_sampler.py
ADDED
@@ -0,0 +1,113 @@
from collections import defaultdict

import math
import numpy as np
import torch as th
from tqdm import tqdm

__all__ = [
    "generic_sampler",
    "SimpleWork",
]


def batch_mul(a, b):  # pylint: disable=invalid-name
    return th.einsum("a...,a...->a...", a, b)

class SimpleWork:
    def __init__(self, shape, eps_scalar_t_fn):
        self.shape = shape
        self.eps_scalar_t_fn = eps_scalar_t_fn

    def generate_xT(self, n):
        return 80.0 * th.randn((n, *self.shape)).cuda()

    def x0_fn(self, xt, scalar_t, y=None):
        cur_eps = self.eps_scalar_t_fn(xt, scalar_t, y=y)
        x0 = xt - scalar_t * cur_eps
        x0 = th.clip(x0, -1, 1)
        return x0, {}, {"x0": x0.cpu()}

    def noise(self, xt, scalar_t):
        del scalar_t
        return th.randn_like(xt)

    def rev_ts(self, n_step, ts_order):
        _rev_ts = th.pow(
            th.linspace(
                np.power(80.0, 1.0 / ts_order),
                np.power(1e-3, 1.0 / ts_order),
                n_step + 1
            ),
            ts_order
        )
        return _rev_ts.cuda()

def generic_sampler(  # pylint: disable=too-many-locals
    x,
    rev_ts,
    noise_fn,
    x0_pred_fn,
    xt_lgv_fn=None,
    s_churn=0.0,
    before_step_fn=None,
    end_fn=None,  # TODO
    is_tqdm=True,
    is_traj=True,
):
    measure_loss = defaultdict(list)
    traj = defaultdict(list)
    if callable(x):
        x = x()
    if is_traj:  # was `if traj:`, which is always False for an empty defaultdict
        traj["xt"].append(x.cpu())

    s_t_min = 0.05
    s_t_max = 50.0
    s_noise = 1.003
    eta = min(s_churn / len(rev_ts), math.sqrt(2.0) - 1)

    loop = zip(rev_ts[:-1], rev_ts[1:])
    if is_tqdm:
        loop = tqdm(loop)

    running_x = x
    for cur_t, next_t in loop:
        cur_x = running_x
        # Stochastic churn: temporarily raise the noise level (EDM-style)
        if cur_t < s_t_max and cur_t > s_t_min:
            hat_cur_t = cur_t + eta * cur_t
            cur_noise = noise_fn(cur_x, cur_t)
            cur_x = cur_x + s_noise * cur_noise * th.sqrt(hat_cur_t ** 2 - cur_t ** 2)
            cur_t = hat_cur_t

        if before_step_fn is not None:
            cur_x = before_step_fn(cur_x, cur_t)

        # Heun's second-order update: average the slopes at cur_t and next_t
        x0, loss_info, traj_info = x0_pred_fn(cur_x, cur_t)
        epsilon_1 = (cur_x - x0) / cur_t

        xt_next = x0 + next_t * epsilon_1

        x0, loss_info, traj_info = x0_pred_fn(xt_next, next_t)
        epsilon_2 = (xt_next - x0) / next_t

        xt_next = cur_x + (next_t - cur_t) * (epsilon_1 + epsilon_2) / 2

        running_x = xt_next

        if is_traj:
            for key, value in loss_info.items():
                measure_loss[key].append(value)

            for key, value in traj_info.items():
                traj[key].append(value)
            traj["xt"].append(running_x.to("cpu").detach())

    if xt_lgv_fn:
        raise RuntimeError("Not implemented")

    if is_traj:
        return traj, measure_loss
    return running_x
diff_collage/loss_helper.py
ADDED
@@ -0,0 +1,41 @@
import torch as th
from .generic_sampler import batch_mul

def get_x0_grad_pred_fn(raw_net_model, cond_loss_fn, weight_fn, x0_update, thres_t):
    def fn(xt, scalar_t):
        xt = xt.requires_grad_(True)
        x0_pred = raw_net_model(xt, scalar_t)

        loss_info = {
            "raw_x0": cond_loss_fn(x0_pred.detach()).cpu(),
        }
        traj_info = {
            "t": scalar_t,
        }
        if scalar_t < thres_t:
            x0_cor = x0_pred.detach()
        else:
            pred_loss = cond_loss_fn(x0_pred)
            grad_term = th.autograd.grad(pred_loss.sum(), xt)[0]
            weights = weight_fn(x0_pred, grad_term, cond_loss_fn)
            x0_cor = (x0_pred - batch_mul(weights, grad_term)).detach()
            loss_info["weight"] = weights.detach().cpu()
            traj_info["grad"] = grad_term.detach().cpu()

        if x0_update:
            x0 = x0_update(x0_cor, scalar_t)
        else:
            x0 = x0_cor

        loss_info["cor_x0"] = cond_loss_fn(x0_cor.detach()).cpu()
        loss_info["x0"] = cond_loss_fn(x0.detach()).cpu()
        traj_info.update({
            "raw_x0": x0_pred.detach().cpu(),
            "cor_x0": x0_cor.detach().cpu(),
            "x0": x0.detach().cpu(),
        })
        return x0_cor, loss_info, traj_info

    return fn
diff_collage/w_img.py
ADDED
@@ -0,0 +1,79 @@
import torch as th
from einops import rearrange

__all__ = [
    "split_wimg",
]


def split_wimg(wimg, n_img, rtn_overlap=True):
    """Split a wide image into n_img overlapping patches of width base_len."""
    if wimg.ndim == 3:
        wimg = wimg[None]
    _, _, h, w = wimg.shape
    base_len = 128  # TODO: hard-coded 128 (the length of the latents)
    overlap_size = (n_img * base_len - w) // (n_img - 1)
    assert n_img * base_len - overlap_size * (n_img - 1) == w

    img = th.nn.functional.unfold(
        wimg, kernel_size=(h, base_len), stride=base_len - overlap_size
    )  # (B, block, n_img)
    img = rearrange(img, "b (c h w) n -> (b n) c h w", h=h, w=base_len)

    if rtn_overlap:
        return img, overlap_size
    return img


def avg_merge_wimg(imgs, overlap_size, n=None, is_avg=True):
    """Merge overlapping patches back into a wide image, averaging the overlaps."""
    b, _, h, w = imgs.shape
    if n is None:
        n = b
    unfold_img = rearrange(imgs, "(b n) c h w -> b (c h w) n", n=n)
    img = th.nn.functional.fold(
        unfold_img,
        (h, n * w - (n - 1) * overlap_size),
        kernel_size=(h, w),
        stride=w - overlap_size,
    )
    if is_avg:
        # count how many patches contribute to each pixel, then normalize
        counter = th.nn.functional.fold(
            th.ones_like(unfold_img),
            (h, n * w - (n - 1) * overlap_size),
            kernel_size=(h, w),
            stride=w - overlap_size,
        )
        return img / counter
    return img


# legacy code using a naive implementation

def split_wimg_legacy(himg, n_img, rtn_overlap=True):
    if himg.ndim == 3:
        himg = himg[None]
    _, _, h, w = himg.shape
    overlap_size = (n_img * h - w) // (n_img - 1)
    assert n_img * h - overlap_size * (n_img - 1) == w
    himg = himg[0]
    rtn_img = [himg[:, :, :h]]
    for i in range(n_img - 1):
        rtn_img.append(himg[:, :, (h - overlap_size) * (i + 1) : h + (h - overlap_size) * (i + 1)])
    if rtn_overlap:
        return th.stack(rtn_img), overlap_size
    return th.stack(rtn_img)


def avg_merge_wimg_legacy(imgs, overlap_size):
    _, _, _, w = imgs.shape
    # first variant: keep patch 0 whole, append the non-overlapping tails
    rtn_img = [imgs[0]]
    for cur_img in imgs[1:]:
        rtn_img.append(cur_img[:, :, overlap_size:])
    first_img = th.cat(rtn_img, dim=-1)

    # second variant: keep the last patch whole, prepend the non-overlapping heads
    rtn_img = []
    for cur_img in imgs[:-1]:
        rtn_img.append(cur_img[:, :, : w - overlap_size])
    rtn_img.append(imgs[-1])
    second_img = th.cat(rtn_img, dim=-1)

    return (first_img + second_img) / 2.0
diff_collage/w_loss.py
ADDED
@@ -0,0 +1,433 @@
import math
import torch as th
from einops import rearrange
import numpy as np

from .generic_sampler import batch_mul


def split_wimg(himg, n_img, rtn_overlap=True):
    if himg.ndim == 3:
        himg = himg[None]
    _, _, h, w = himg.shape
    overlap_size = (n_img * h - w) // (n_img - 1)
    assert n_img * h - overlap_size * (n_img - 1) == w
    himg = himg[0]
    rtn_img = [himg[:, :, :h]]
    for i in range(n_img - 1):
        rtn_img.append(himg[:, :, (h - overlap_size) * (i + 1) : h + (h - overlap_size) * (i + 1)])
    if rtn_overlap:
        return th.stack(rtn_img), overlap_size
    return th.stack(rtn_img)


def merge_wimg(imgs, overlap_size):
    _, _, _, w = imgs.shape
    rtn_img = [imgs[0]]
    for cur_img in imgs[1:]:
        rtn_img.append(cur_img[:, :, overlap_size:])
    first_img = th.cat(rtn_img, dim=-1)

    rtn_img = []
    for cur_img in imgs[:-1]:
        rtn_img.append(cur_img[:, :, : w - overlap_size])
    rtn_img.append(imgs[-1])
    second_img = th.cat(rtn_img, dim=-1)

    return (first_img + second_img) / 2.0


def get_x0_pred_fn(raw_net_model, cond_loss_fn, weight_fn, x0_fn, thres_t, init_fn=None):
    def fn(xt, scalar_t):
        if init_fn is not None:
            xt = init_fn(xt, scalar_t)
        xt = xt.requires_grad_(True)
        x0_pred = raw_net_model(xt, scalar_t)

        loss_info = {
            "raw_x0": cond_loss_fn(x0_pred.detach()).cpu(),
        }
        traj_info = {
            "t": scalar_t,
        }
        if scalar_t < thres_t:
            x0_cor = x0_pred.detach()
        else:
            pred_loss = cond_loss_fn(x0_pred)
            grad_term = th.autograd.grad(pred_loss.sum(), xt)[0]
            weights = weight_fn(x0_pred, grad_term, cond_loss_fn)
            x0_cor = (x0_pred - batch_mul(weights, grad_term)).detach()
            loss_info["weight"] = weights.detach().cpu()
            traj_info["grad"] = grad_term.detach().cpu()

        if x0_fn:
            x0 = x0_fn(x0_cor, scalar_t)
        else:
            x0 = x0_cor

        loss_info["cor_x0"] = cond_loss_fn(x0_cor.detach()).cpu()
        loss_info["x0"] = cond_loss_fn(x0.detach()).cpu()
        traj_info.update({
            "raw_x0": x0_pred.detach().cpu(),
            "cor_x0": x0_cor.detach().cpu(),
            "x0": x0.detach().cpu(),
        })
        return x0_cor, loss_info, traj_info

    return fn


def simple_noise(cur_t, xt):
    del cur_t
    return th.randn_like(xt)


def get_fix_weight_fn(fix_weight):
    def weight_fn(xs, grads, *args):
        del grads, args
        return th.ones(xs.shape[0]).to(xs) * fix_weight

    return weight_fn


class SeqWorker:
    def __init__(self, overlap_size=10, src_img=None):
        self.overlap_size = overlap_size
        self.src_img = src_img

    def loss(self, x):
        # squared error between the tail of the source image and the head of x
        return th.sum(
            (th.abs(self.src_img[:, :, :, -self.overlap_size :] - x[:, :, :, : self.overlap_size]))
            ** 2,
            dim=(1, 2, 3),
        )

    def x0_replace(self, x0):
        rtn_x0 = x0.clone()
        rtn_x0[:, :, :, : self.overlap_size] = self.src_img[:, :, :, -self.overlap_size :]
        return rtn_x0  # bug fix: previously returned the unmodified x0

    def optimal_weight_fn(self, x0, grads, *args, ratio=1.0):
        del args
        overlap_size = self.overlap_size
        # argmin_{w} (delta_pixel - w * delta_grads)^2
        delta_pixel = x0[:, :, :, :overlap_size] - self.src_img[:, :, :, -overlap_size:]
        delta_grads = grads[:, :, :, :overlap_size]
        num = th.sum(delta_pixel * delta_grads).item()
        denum = th.sum(delta_grads * delta_grads).item()
        _optimal_weight = num / denum
        if math.isnan(_optimal_weight):
            print(denum)
            raise RuntimeError("nan for weights")

        return ratio * _optimal_weight * th.ones(x0.shape[0]).to(x0)

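# --- editor's aside (not part of w_loss.py): a numeric sanity check of the
# closed form above. optimal_weight_fn solves argmin_w ||dp - w * dg||^2,
# whose minimizer is w* = <dp, dg> / <dg, dg>; toy tensors stand in for
# delta_pixel and delta_grads.
import torch as th

dp = th.randn(2, 4, 16, 10)   # stands in for delta_pixel
dg = th.randn(2, 4, 16, 10)   # stands in for delta_grads
w_star = th.sum(dp * dg) / th.sum(dg * dg)
f = lambda w: th.sum((dp - w * dg) ** 2)
assert f(w_star) <= f(w_star + 1e-3) and f(w_star) <= f(w_star - 1e-3)
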
class CircleWorker:
    def __init__(self, overlap_size=10, adam_num_iter=100):
        self.overlap_size = overlap_size
        self.adam_num_iter = adam_num_iter

    def get_match_patch(self, x):
        tail = x[:, :, :, -self.overlap_size :]
        head = x[:, :, :, : self.overlap_size]
        tail = th.roll(tail, 1, 0)  # match each tail with the next segment's head, cyclically
        return tail, head

    def loss(self, x):
        tail, head = self.get_match_patch(x)
        return th.sum(
            (tail - head) ** 2,
            dim=(1, 2, 3),
        )

    def split_noise(self, cur_t, xt):
        noise = simple_noise(cur_t, xt)
        b, _, _, w = xt.shape
        final_img_w = w * b - self.overlap_size * b
        noise = rearrange(noise, "(t n) c h w -> t c h (n w)", t=1)[:, :, :, :final_img_w]
        noise = th.cat([noise, noise[:, :, :, : self.overlap_size]], dim=-1)
        noise, _ = split_wimg(noise, b)
        return noise

    def merge_circle_image(self, xt):
        merged_long_img = merge_wimg(xt, self.overlap_size)
        return th.cat(
            [
                (merged_long_img[:, :, : self.overlap_size] + merged_long_img[:, :, -self.overlap_size :]) / 2.0,
                merged_long_img[:, :, self.overlap_size : -self.overlap_size],
            ],
            dim=-1,
        )

    def split_circle_image(self, merged_long_img, n):
        imgs, _ = split_wimg(
            th.cat(
                [
                    merged_long_img,
                    merged_long_img[:, :, : self.overlap_size],
                ],
                dim=-1,
            ),
            n,
        )
        return imgs

    def optimal_weight_fn(self, xs, grads, *args):
        del args
        # argmin_{w} (delta_pixel - w * delta_grads)^2
        tail, head = self.get_match_patch(xs)
        delta_pixel = tail - head
        tail, head = self.get_match_patch(grads)
        delta_grads = tail - head

        num = th.sum(delta_pixel * delta_grads).item()
        denum = th.sum(delta_grads * delta_grads).item()
        _optimal_weight = num / denum
        return _optimal_weight * th.ones(xs.shape[0]).to(xs)

    def adam_grad_weight(self, x0, grad_term, cond_loss_fn):
        init_weight = self.optimal_weight_fn(x0, grad_term)
        grad_term = grad_term.detach()
        x0 = x0.detach()
        with th.enable_grad():
            weights = init_weight.requires_grad_()
            optimizer = th.optim.Adam(
                [
                    weights,
                ],
                lr=1e-2,
            )

            def _loss(w):
                cor_x0 = x0 - batch_mul(w, grad_term)
                return cond_loss_fn(cor_x0).sum()

            for _ in range(self.adam_num_iter):
                optimizer.zero_grad()
                _cur_loss = _loss(weights)
                _cur_loss.backward()
                optimizer.step()
        return weights

    # TODO:
    def x0_replace(self, x0, scalar_t, thres_t):
        if scalar_t > thres_t:
            merge_x0 = merge_wimg(x0, self.overlap_size)
            return split_wimg(merge_x0, x0.shape[0])[0]
        else:
            return x0


class ParaWorker:
    def __init__(self, overlap_size=10, adam_num_iter=100):
        self.overlap_size = overlap_size
        self.adam_num_iter = adam_num_iter

    def loss(self, x):
        x1, x2 = x[:-1], x[1:]
        return th.sum(
            (th.abs(x1[:, :, :, -self.overlap_size :] - x2[:, :, :, : self.overlap_size])) ** 2,
            dim=(1, 2, 3),
        )

    def split_noise(self, xt, cur_t):
        noise = simple_noise(cur_t, xt)
        b, _, _, w = xt.shape
        final_img_w = w * b - self.overlap_size * (b - 1)
        noise = rearrange(noise, "(t n) c h w -> t c h (n w)", t=1)[:, :, :, :final_img_w]
        noise, _ = split_wimg(noise, b)
        return noise

    def optimal_weight_fn(self, xs, grads, *args):
        del args
        overlap_size = self.overlap_size
        # argmin_{w} (delta_pixel - w * delta_grads)^2
        delta_pixel = xs[:-1, :, :, -overlap_size:] - xs[1:, :, :, :overlap_size]
        delta_grads = grads[:-1, :, :, -overlap_size:] - grads[1:, :, :, :overlap_size]
        num = th.sum(delta_pixel * delta_grads).item()
        denum = th.sum(delta_grads * delta_grads).item()
        _optimal_weight = num / denum
        return _optimal_weight * th.ones(xs.shape[0]).to(xs)

    def adam_grad_weight(self, x0, grad_term, cond_loss_fn):
        init_weight = self.optimal_weight_fn(x0, grad_term)
        grad_term = grad_term.detach()
        x0 = x0.detach()
        with th.enable_grad():
            weights = init_weight.requires_grad_()
            optimizer = th.optim.Adam(
                [
                    weights,
                ],
                lr=1e-2,
            )

            def _loss(w):
                cor_x0 = x0 - batch_mul(w, grad_term)
                return cond_loss_fn(cor_x0).sum()

            for _ in range(self.adam_num_iter):
                optimizer.zero_grad()
                _cur_loss = _loss(weights)
                _cur_loss.backward()
                optimizer.step()
        return weights

    def x0_replace(self, x0, scalar_t, thres_t):
        if scalar_t > thres_t:
            merge_x0 = merge_wimg(x0, self.overlap_size)
            return split_wimg(merge_x0, x0.shape[0])[0]
        else:
            return x0


class ParaWorkerC(ParaWorker):
    def __init__(self, src_img, mask_img, inpaint_w=1.0, overlap_size=10, adam_num_iter=100):
        self.src_img = src_img
        self.inpaint_w = inpaint_w
        self.mask_img = mask_img  # 1 indicates masked given pixels
        super().__init__(overlap_size, adam_num_iter)

    def loss(self, x):
        if x.shape[0] == 1:
            return th.sum(
                th.sum(
                    th.square(self.src_img[:, :, :, : x.shape[-1]] - x), dim=(0, 1)
                ) * self.mask_img[:, : x.shape[-1]]
            )
        else:
            consistent_loss = super().loss(x)
            # merge image
            merge_x = merge_wimg(x, self.overlap_size)

            inpainting_loss = th.sum(
                th.sum(
                    th.square(self.src_img[:, :, :, : merge_x.shape[-1]] - merge_x), dim=(0, 1)
                ) * self.mask_img[:, : merge_x.shape[-1]]
            )

            return consistent_loss + inpainting_loss / (x.shape[-1] - 1)

    def x0_replace(self, x0, scalar_t, thres_t):
        if scalar_t > thres_t:
            merge_x = merge_wimg(x0, self.overlap_size)
            src_img = self.src_img[:, :, :, : merge_x.shape[-1]]
            mask_img = self.mask_img[:, : merge_x.shape[-1]]
            merge_x = th.where(mask_img[None, None], src_img, merge_x)
            return split_wimg(merge_x, x0.shape[0])[0]
        else:
            return x0


class SplitMergeOp:
    def __init__(self, avg_overlap=32):
        self.avg_overlap = avg_overlap
        self.cur_overlap_int = None

    def sample(self, n):
        # sample n overlap sizes summing to n * avg_overlap, each at least avg_overlap - 6
        _lower_bound = self.avg_overlap - 6
        base_overlap = np.ones(n) * _lower_bound

        total_ball = (self.avg_overlap - _lower_bound) * n
        random_number = np.random.randint(0, total_ball - n, n - 1)
        random_number = np.sort(random_number)
        balls = np.append(random_number, total_ball - n) - np.insert(random_number, 0, 0) + np.ones(n) + base_overlap

        assert np.sum(balls) == n * self.avg_overlap

        # TODO: FIXME randomized overlaps are currently disabled; fall back to the fixed average
        balls = np.ones(n) * self.avg_overlap

        return balls.astype(int)  # np.int was removed in NumPy 1.24; use the builtin

    def reset(self, n):
        self.cur_overlap_int = self.sample(n)

    def split(self, img, n, img_w=64):
        assert img.ndim == 3
        # assert img.shape[-1] > (n-1) * self.avg_overlap
        assert (n - 1) == self.cur_overlap_int.shape[0]

        assert (n - 1) * self.avg_overlap + img.shape[-1] == n * img_w

        cur_idx = 0
        imgs = []
        for cur_overlap in self.cur_overlap_int:
            imgs.append(img[:, :, cur_idx : cur_idx + img_w])
            cur_idx = cur_idx + img_w - cur_overlap
        imgs.append(img[:, :, cur_idx:])
        return th.stack(imgs)

    def merge(self, imgs):
        b = imgs.shape[0]
        img_size = imgs.shape[-1]
        assert b - 1 == self.cur_overlap_int.shape[0]
        img_width = b * imgs.shape[-1] - np.sum(self.cur_overlap_int)
        wimg = th.zeros((3, imgs.shape[-2], img_width)).to(imgs)
        ncnt = th.zeros(img_width).to(imgs)
        cur_idx = 0
        for i_th, cur_img in enumerate(imgs):
            wimg[:, :, cur_idx : cur_idx + img_size] += cur_img
            ncnt[cur_idx : cur_idx + img_size] += 1.0
            if i_th < b - 1:
                cur_idx = cur_idx + img_size - self.cur_overlap_int[i_th]
        return wimg / ncnt[None, None, :]


class ParaWorkerFix:
    def __init__(self, overlap_size=10, adam_num_iter=100):
        self.overlap_size = overlap_size
        self.adam_num_iter = adam_num_iter
        self.op = SplitMergeOp(overlap_size)

    def loss(self, x):
        avg_x = self.op.split(
            self.op.merge(x), x.shape[0], x.shape[-1]
        )
        return th.sum(
            (x - avg_x) ** 2,
            dim=(1, 2, 3),
        )

    def split_noise(self, cur_t, xt):
        noise = simple_noise(cur_t, xt)
        b, _, _, w = xt.shape
        final_img_w = w * b - self.overlap_size * (b - 1)
        noise = rearrange(noise, "(t n) c h w -> t c h (n w)", t=1)[:, :, :, :final_img_w][0]
        noise = self.op.split(noise, b, w)
        return noise

    def adam_grad_weight(self, x0, grad_term, cond_loss_fn):
        init_weight = th.ones(x0.shape[0]).to(x0)
        grad_term = grad_term.detach()
        x0 = x0.detach()
        with th.enable_grad():
            weights = init_weight.requires_grad_()
            optimizer = th.optim.Adam(
                [
                    weights,
                ],
                lr=1e-2,
            )

            def _loss(w):
                cor_x0 = x0 - batch_mul(w, grad_term)
                return cond_loss_fn(cor_x0).sum()

            for _ in range(self.adam_num_iter):
                optimizer.zero_grad()
                _cur_loss = _loss(weights)
                _cur_loss.backward()
                optimizer.step()
        return weights

    def x0_replace(self, x0, scalar_t, thres_t):
        if scalar_t > thres_t:
            merge_x0 = self.op.merge(x0)
            return self.op.split(merge_x0, x0.shape[0], x0.shape[-1])
        else:
            return x0
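To make the consistency objective concrete, here is a hedged toy example of ParaWorker.loss on random stand-in segments; it returns one mismatch value per adjacent pair.

import torch as th
from diff_collage.w_loss import ParaWorker

worker = ParaWorker(overlap_size=4)
segs = th.randn(3, 4, 16, 32)              # 3 segments to be stitched
segs[1, :, :, :4] = segs[0, :, :, -4:]     # make pair (0, 1) perfectly consistent
mismatch = worker.loss(segs)               # shape (2,): one value per adjacent pair
assert mismatch[0].abs() < 1e-6            # the aligned pair has ~zero loss
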
environment.yml
ADDED
@@ -0,0 +1,282 @@
name: guided
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - blas=1.0=mkl
  - brotlipy=0.7.0=py39h27cfd23_1003
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.08.22=h06a4308_0
  - certifi=2023.7.22=py39h06a4308_0
  - cffi=1.15.1=py39h5eee18b_3
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - cryptography=39.0.1=py39h9ce1e76_2
  - cuda-cudart=11.7.99=0
  - cuda-cupti=11.7.101=0
  - cuda-libraries=11.7.1=0
  - cuda-nvrtc=11.7.99=0
  - cuda-nvtx=11.7.91=0
  - cuda-runtime=11.7.1=0
  - ffmpeg=4.3=hf484d3e_0
  - filelock=3.9.0=py39h06a4308_0
  - freetype=2.12.1=h4a9f257_0
  - giflib=5.2.1=h5eee18b_3
  - gmp=6.2.1=h295c915_3
  - gmpy2=2.1.2=py39heeb90bb_0
  - gnutls=3.6.15=he1e5248_0
  - idna=3.4=py39h06a4308_0
  - intel-openmp=2023.1.0=hdb19cb5_46305
  - jinja2=3.1.2=py39h06a4308_0
  - jpeg=9e=h5eee18b_1
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libcublas=11.10.3.66=0
  - libcufft=10.7.2.124=h4fbf590_0
  - libcufile=1.6.1.9=0
  - libcurand=10.3.2.106=0
  - libcusolver=11.4.0.1=0
  - libcusparse=11.7.4.91=0
  - libdeflate=1.17=h5eee18b_0
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgfortran-ng=7.5.0=h14aa051_20
  - libgfortran4=7.5.0=h14aa051_20
  - libgomp=11.2.0=h1234567_1
  - libiconv=1.16=h7f8727e_2
  - libidn2=2.3.4=h5eee18b_0
  - libnpp=11.7.4.75=0
  - libnvjpeg=11.8.0.2=0
  - libpng=1.6.39=h5eee18b_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.5.0=h6a678d5_2
  - libunistring=0.9.10=h27cfd23_0
  - libwebp=1.2.4=h11a3e52_1
  - libwebp-base=1.2.4=h5eee18b_1
  - lz4-c=1.9.4=h6a678d5_0
  - markupsafe=2.1.1=py39h7f8727e_0
  - mkl=2023.1.0=h6d00ec8_46342
  - mkl-service=2.4.0=py39h5eee18b_1
  - mkl_fft=1.3.6=py39h417a72b_1
  - mkl_random=1.2.2=py39h417a72b_1
  - mpc=1.1.0=h10f8cd9_1
  - mpfr=4.0.2=hb69a4c5_1
  - mpi=1.0=openmpi
  - mpi4py=3.1.4=py39h3e5f7c9_0
  - mpmath=1.2.1=py39h06a4308_0
  - ncurses=6.4=h6a678d5_0
  - nettle=3.7.3=hbbd107a_1
  - networkx=2.8.4=py39h06a4308_1
  - numpy=1.24.3=py39hf6e8229_1
  - numpy-base=1.24.3=py39h060ed82_1
  - openh264=2.1.1=h4ff587b_0
  - openmpi=4.0.4=hdf1f1ad_0
  - openssl=3.0.12=h7f8727e_0
  - pillow=9.4.0=py39h6a678d5_0
  - pip=23.1.2=py39h06a4308_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pyopenssl=23.0.0=py39h06a4308_0
  - pysocks=1.7.1=py39h06a4308_0
  - python=3.9.16=h955ad1f_3
  - pytorch=2.0.1=py3.9_cuda11.7_cudnn8.5.0_0
  - pytorch-cuda=11.7=h778d358_5
  - pytorch-mutex=1.0=cuda
  - readline=8.2=h5eee18b_0
  - requests=2.31.0=pyhd8ed1ab_0
  - setuptools=67.8.0=py39h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.41.2=h5eee18b_0
  - sympy=1.11.1=py39h06a4308_0
  - tbb=2021.8.0=hdb19cb5_0
  - tk=8.6.12=h1ccaba5_0
  - torchaudio=2.0.2=py39_cu117
  - torchtext=0.6.0=py_1
  - torchtriton=2.0.0=py39
  - torchvision=0.15.2=py39_cu117
  - tqdm=4.65.0=py39hb070fc8_0
  - typing_extensions=4.6.3=py39h06a4308_0
  - urllib3=1.26.16=py39h06a4308_0
  - wheel=0.38.4=py39h06a4308_0
  - xz=5.4.2=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - zstd=1.5.5=hc292b87_0
  - pip:
    - absl-py==1.4.0
    - accelerate==0.21.0
    - anyio==4.0.0
    - appdirs==1.4.4
    - argon2-cffi==23.1.0
    - argon2-cffi-bindings==21.2.0
    - arrow==1.3.0
    - asttokens==2.4.1
    - async-lru==2.0.4
    - attrs==23.1.0
    - awscli==1.29.84
    - babel==2.13.1
    - beautifulsoup4==4.12.2
    - bleach==6.1.0
    - blobfile==2.0.2
    - boltons==23.0.0
    - botocore==1.31.84
    - cachetools==5.3.1
    - chardet==5.1.0
    - clean-fid==0.1.35
    - click==8.1.3
    - clip-anytorch==2.5.2
    - colorama==0.4.4
    - comm==0.1.4
    - contextlib2==21.6.0
    - contourpy==1.1.0
    - cycler==0.11.0
    - debugpy==1.8.0
    - decorator==5.1.1
    - defusedxml==0.7.1
    - docker-pycreds==0.4.0
    - docutils==0.16
    - einops==0.6.1
    - exceptiongroup==1.1.3
    - executing==2.0.1
    - fastjsonschema==2.18.1
    - fonttools==4.40.0
    - fqdn==1.5.1
    - fsspec==2023.6.0
    - ftfy==6.1.1
    - future==0.18.3
    - gitdb==4.0.10
    - gitpython==3.1.31
    - google-auth==2.21.0
    - google-auth-oauthlib==1.0.0
    - grpcio==1.56.0
    - huggingface-hub==0.15.1
    - imageio==2.31.1
    - importlib-metadata==6.7.0
    - importlib-resources==5.12.0
    - ipykernel==6.26.0
    - ipython==8.17.2
    - ipython-genutils==0.2.0
    - ipywidgets==8.1.1
    - isoduration==20.11.0
    - jedi==0.19.1
    - jmespath==1.0.1
    - joblib==1.3.1
    - json5==0.9.14
    - jsonmerge==1.9.2
    - jsonpickle==3.0.1
    - jsonpointer==2.4
    - jsonschema==4.19.0
    - jsonschema-specifications==2023.7.1
    - jupyter==1.0.0
    - jupyter-client==8.5.0
    - jupyter-console==6.6.3
    - jupyter-core==5.5.0
    - jupyter-events==0.8.0
    - jupyter-lsp==2.2.0
    - jupyter-server==2.9.1
    - jupyter-server-terminals==0.4.4
    - jupyterlab==4.0.8
    - jupyterlab-pygments==0.2.2
    - jupyterlab-server==2.25.0
    - jupyterlab-widgets==3.0.9
    - k-diffusion==0.0.16
    - kiwisolver==1.4.4
    - kornia==0.7.0
    - lazy-loader==0.3
    - lmdb==1.4.1
    - lxml==4.9.2
    - markdown==3.4.3
    - matplotlib==3.7.1
    - matplotlib-inline==0.1.6
    - mido==1.2.10
    - mistune==3.0.2
    - ml-collections==0.1.1
    - more-itertools==10.0.0
    - music21==8.3.0
    - nbclient==0.8.0
    - nbconvert==7.10.0
    - nbformat==5.9.2
    - nest-asyncio==1.5.8
    - notebook==7.0.6
    - notebook-shim==0.2.3
    - oauthlib==3.2.2
    - omegaconf==2.0.0
    - overrides==7.4.0
    - packaging==23.1
    - pandas==2.0.2
    - pandocfilters==1.5.0
    - parso==0.8.3
    - pathtools==0.1.2
    - pexpect==4.8.0
    - platformdirs==3.11.0
    - prometheus-client==0.18.0
    - prompt-toolkit==3.0.39
    - protobuf==4.23.3
    - psutil==5.9.5
    - ptyprocess==0.7.0
    - pure-eval==0.2.2
    - pyasn1==0.5.0
    - pyasn1-modules==0.3.0
    - pycryptodomex==3.18.0
    - pygments==2.16.1
    - pyparsing==3.1.0
    - python-dateutil==2.8.2
    - python-json-logger==2.0.7
    - pytorch-lightning==1.0.8
    - pytz==2023.3
    - pywavelets==1.4.1
    - pyyaml==6.0
    - pyzmq==25.1.1
    - qtconsole==5.4.4
    - qtpy==2.4.1
    - referencing==0.30.2
    - regex==2023.8.8
    - requests-oauthlib==1.3.1
    - resize-right==0.0.2
    - rfc3339-validator==0.1.4
    - rfc3986-validator==0.1.1
    - rotary-embedding-torch==0.3.2
    - rpds-py==0.9.2
    - rsa==4.7.2
    - s3transfer==0.7.0
    - safetensors==0.3.1
    - scikit-image==0.21.0
    - scikit-learn==1.3.2
    - scipy==1.11.2
    - seaborn==0.13.0
    - send2trash==1.8.2
    - sentry-sdk==1.25.1
    - setproctitle==1.3.2
    - smmap==5.0.0
    - sniffio==1.3.0
    - soupsieve==2.5
    - stack-data==0.6.3
    - tensorboard==2.13.0
    - tensorboard-data-server==0.7.1
    - terminado==0.17.1
    - threadpoolctl==3.2.0
    - tifffile==2023.8.12
    - timm==0.9.2
    - tinycss2==1.2.1
    - tomli==2.0.1
    - torchdata==0.6.1
    - torchdiffeq==0.2.3
    - torchsde==0.2.5
    - tornado==6.3.3
    - traitlets==5.13.0
    - trampoline==0.1.2
    - types-python-dateutil==2.8.19.14
    - tzdata==2023.3
    - uri-template==1.3.0
    - wandb==0.15.4
    - wcwidth==0.2.6
    - webcolors==1.13
    - webencodings==0.5.1
    - websocket-client==1.6.4
    - werkzeug==2.3.6
    - widgetsnbextension==4.0.9
    - zipp==3.15.0
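After `conda env create -f environment.yml` and `conda activate guided`, an optional sanity check that the pinned PyTorch/CUDA stack resolved; this snippet is illustrative, not part of the repo.

import torch
print(torch.__version__)          # expect 2.0.1
print(torch.version.cuda)         # expect 11.7
print(torch.cuda.is_available())  # True on a CUDA-capable machine
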
guided_diffusion/__init__.py
ADDED
@@ -0,0 +1,3 @@
"""
Codebase for "Improved Denoising Diffusion Probabilistic Models".
"""
guided_diffusion/condition_functions.py
ADDED
@@ -0,0 +1,174 @@
import argparse
import os

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .pr_datasets_all import FUNC_DICT
import matplotlib.pyplot as plt

plt.rcParams["figure.figsize"] = (20, 3)
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300


def model_fn(x, t, y=None, rule=None,
             model=nn.Identity(), num_classes=3, class_cond=True, cfg=False, w=0.):
    # y has to be the composer; rule is a dummy input
    y_null = th.tensor([num_classes] * x.shape[0], device=x.device)
    if class_cond:
        if cfg:
            # classifier-free guidance: mix conditional and unconditional predictions
            return (1 + w) * model(x, t, y) - w * model(x, t, y_null)
        else:
            return model(x, t, y)
    else:
        return model(x, t, y_null)


def dc_model_fn(x, t, y=None, rule=None,
                model=nn.Identity(), num_classes=3, class_cond=True, cfg=False, w=0.):
    # the DiffCollage score function takes in 4 x pitch x time
    x = x.permute(0, 1, 3, 2)
    y_null = th.tensor([num_classes] * x.shape[0], device=x.device)
    if class_cond:
        if cfg:
            eps = (1 + w) * model(x, t, y) - w * model(x, t, y_null)
            return eps.permute(0, 1, 3, 2)  # need to return 4 x time x pitch
        else:
            return model(x, t, y).permute(0, 1, 3, 2)
    else:
        return model(x, t, y_null).permute(0, 1, 3, 2)


# y is a dummy input for cond_fn; rule is the real input
def grad_nn_zt_xentropy(x, y=None, rule=None, classifier=nn.Identity()):
    # cross-entropy cond_fn
    assert rule is not None
    t = th.zeros(x.shape[0], device=x.device)
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = F.log_softmax(logits, dim=-1)
        selected = log_probs[range(len(logits)), rule.view(-1)]
        return th.autograd.grad(selected.sum(), x_in)[0]


def grad_nn_zt_mse(x, t, y=None, rule=None, classifier_scale=10., classifier=nn.Identity()):
    assert rule is not None
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = - F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
        return th.autograd.grad(log_probs.sum(), x_in)[0] * classifier_scale


def grad_nn_zt_chord(x, t, y=None, rule=None, classifier_scale=10., classifier=nn.Identity(), both=False):
    assert rule is not None
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        key_logits, chord_logits = classifier(x_in, t)
        if both:
            rule_key = rule[:, :1]
            rule_chord = rule[:, 1:]
            rule_chord = rule_chord.reshape(-1)
            chord_logits = chord_logits.reshape(-1, chord_logits.shape[-1])
            key_log_probs = - F.cross_entropy(key_logits, rule_key, reduction="none")
            chord_log_probs = - F.cross_entropy(chord_logits, rule_chord, reduction="none")
            chord_log_probs = chord_log_probs.reshape(x_in.shape[0], -1).mean(dim=-1)
            log_probs = key_log_probs + chord_log_probs
        else:
            rule = rule.reshape(-1)
            chord_logits = chord_logits.reshape(-1, chord_logits.shape[-1])
            log_probs = - F.cross_entropy(chord_logits, rule, reduction="none")
        return th.autograd.grad(log_probs.sum(), x_in)[0] * classifier_scale


def nn_z0_chord_dummy(x, t, y=None, rule=None, classifier_scale=0.1, classifier=nn.Identity(), both=False):
    # classifier_scale is equivalent to step_size
    t = th.zeros(x.shape[0], device=x.device)
    key_logits, chord_logits = classifier(x, t)
    if both:
        rule_key = rule[:, :1]
        rule_chord = rule[:, 1:]
        rule_chord = rule_chord.reshape(-1)
        chord_logits = chord_logits.reshape(-1, chord_logits.shape[-1])
        key_log_probs = - F.cross_entropy(key_logits, rule_key, reduction="none")
        chord_log_probs = - F.cross_entropy(chord_logits, rule_chord, reduction="none")
        chord_log_probs = chord_log_probs.reshape(x.shape[0], -1).mean(dim=-1)
        log_probs = key_log_probs + chord_log_probs
    else:
        rule = rule.reshape(-1)
        chord_logits = chord_logits.reshape(-1, chord_logits.shape[-1])
        log_probs = - F.cross_entropy(chord_logits, rule, reduction="none")
        log_probs = log_probs.reshape(x.shape[0], -1).mean(dim=-1)
    return log_probs * classifier_scale


def nn_z0_mse_dummy(x, t, y=None, rule=None, classifier_scale=0.1, classifier=nn.Identity()):
    # MSE cond_fn; t is a dummy variable because of wrap_model in respace
    assert rule is not None
    t = th.zeros(x.shape[0], device=x.device)
    logits = classifier(x, t)
    log_probs = - F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
    return log_probs * classifier_scale


def nn_z0_mse(x, rule=None, classifier=nn.Identity()):
    # MSE cond_fn evaluated at t = 0
    t = th.zeros(x.shape[0], device=x.device)
    logits = classifier(x, t)
    log_probs = - F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
    return log_probs


def rule_x0_mse_dummy(x, t, y=None, rule=None, rule_name='pitch_hist'):
    # use a differentiable rule to differentiate through rule(x_0); t is a dummy variable because of wrap_model in respace
    logits = FUNC_DICT[rule_name](x)
    log_probs = - F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
    return log_probs


def rule_x0_mse(x, rule=None, rule_name='pitch_hist', soft=False):
    # soften a non-differentiable rule to differentiate through rule(x_0);
    # softening did not work well in practice, so soft is always left False
    logits = FUNC_DICT[rule_name](x, soft=soft)
    log_probs = - F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
    return log_probs


class _WrappedFn:
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x, t, y=None, rule=None):
        return self.fn(x, t, y, rule)


function_map = {
    "grad_nn_zt_xentropy": grad_nn_zt_xentropy,
    "grad_nn_zt_mse": grad_nn_zt_mse,
    "grad_nn_zt_chord": grad_nn_zt_chord,
    "nn_z0_chord_dummy": nn_z0_chord_dummy,
    "nn_z0_mse_dummy": nn_z0_mse_dummy,
    "nn_z0_mse": nn_z0_mse,
    "rule_x0_mse_dummy": rule_x0_mse_dummy,
    "rule_x0_mse": rule_x0_mse,
}


def composite_nn_zt(x, t, y=None, rule=None, fns=None, classifier_scales=None, classifiers=None, rule_names=None):
    num_classifiers = len(classifiers)
    out = 0
    for i in range(num_classifiers):
        out += function_map[fns[i]](x, t, y=y, rule=rule[rule_names[i]],
                                    classifier_scale=classifier_scales[i], classifier=classifiers[i])
    return out


def composite_rule(x, t, y=None, rule=None, fns=None, classifier_scales=None, rule_names=None):
    out = 0
    for i in range(len(fns)):
        out += function_map[fns[i]](x, t, y=y, rule=rule[rule_names[i]], rule_name=rule_names[i]) * classifier_scales[i]
    return out
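A small standalone sketch of the classifier-free guidance combination used in model_fn above, eps = (1 + w) * eps_cond - w * eps_uncond; the score model here is a placeholder.

import torch as th

def fake_score(x, t, y):
    # placeholder for the class-conditional noise-prediction model
    shift = 0.0 if y is None else 0.01 * y.view(-1, 1, 1, 1).float()
    return 0.1 * x + shift

x = th.randn(2, 4, 128, 16)
t = th.full((2,), 10.0)
y = th.tensor([0, 1])
w = 4.0
eps_guided = (1 + w) * fake_score(x, t, y) - w * fake_score(x, t, None)
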
guided_diffusion/dist_util.py
ADDED
@@ -0,0 +1,104 @@
"""
Helpers for distributed training.
"""

import io
import os
import socket

import blobfile as bf
from mpi4py import MPI
import torch as th
import torch.distributed as dist

# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 2

SETUP_RETRY_COUNT = 3


def setup_dist(port=None):
    """
    Set up a distributed process group.
    For NGC, set port = "8023".
    """
    if dist.is_initialized():
        return
    if not os.environ.get("CUDA_VISIBLE_DEVICES"):
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"

    comm = MPI.COMM_WORLD
    backend = "gloo" if not th.cuda.is_available() else "nccl"

    if backend == "gloo":
        hostname = "localhost"
    else:
        hostname = socket.gethostbyname(socket.getfqdn())
    if port is not None:
        os.environ["MASTER_ADDR"] = "127.0.0.1"
    else:
        os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
    os.environ["RANK"] = str(comm.rank)
    os.environ["WORLD_SIZE"] = str(comm.size)

    if port is not None:
        os.environ["MASTER_PORT"] = port
    else:
        port = comm.bcast(_find_free_port(), root=0)
        os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group(backend=backend, init_method="env://")
    th.cuda.set_device(comm.rank)  # needed to run on HPC

    return comm


def dev():
    """
    Get the device to use for torch.distributed.
    """
    if th.cuda.is_available():
        return th.device("cuda")
    return th.device("cpu")


def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file without redundant fetches across MPI ranks.
    """
    chunk_size = 2 ** 30  # MPI has a relatively small size limit
    if MPI.COMM_WORLD.Get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            data = f.read()
        num_chunks = len(data) // chunk_size
        if len(data) % chunk_size:
            num_chunks += 1
        MPI.COMM_WORLD.bcast(num_chunks)
        for i in range(0, len(data), chunk_size):
            MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
    else:
        num_chunks = MPI.COMM_WORLD.bcast(None)
        data = bytes()
        for _ in range(num_chunks):
            data += MPI.COMM_WORLD.bcast(None)

    return th.load(io.BytesIO(data), **kwargs)


def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with th.no_grad():
            dist.broadcast(p, 0)


def _find_free_port():
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]
    finally:
        s.close()
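Typical usage of dist_util in a training entry point (a hedged sketch with a placeholder checkpoint path; launch with e.g. `mpiexec -n 2 python train.py`).

from guided_diffusion import dist_util

dist_util.setup_dist()
device = dist_util.dev()
state = dist_util.load_state_dict("path/to/checkpoint.pt", map_location="cpu")
# model.load_state_dict(state); model.to(device)
# dist_util.sync_params(model.parameters())
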
guided_diffusion/dit.py
ADDED
@@ -0,0 +1,983 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
from rotary_embedding_torch import RotaryEmbedding
from torch.jit import Final
import numpy as np
import math
from timm.models.vision_transformer import Attention, Mlp
from timm.models.vision_transformer_relpos import RelPosAttention
from timm.layers import Format, nchw_to, to_2tuple, _assert, RelPosBias, use_fused_attn
from typing import Callable, List, Optional, Tuple, Union
from functools import partial


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb

73 |
+
class LabelEmbedder(nn.Module):
|
74 |
+
"""
|
75 |
+
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
76 |
+
"""
|
77 |
+
def __init__(self, num_classes, hidden_size, dropout_prob):
|
78 |
+
super().__init__()
|
79 |
+
use_cfg_embedding = dropout_prob > 0
|
80 |
+
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
|
81 |
+
self.num_classes = num_classes
|
82 |
+
self.dropout_prob = dropout_prob
|
83 |
+
|
84 |
+
def token_drop(self, labels, force_drop_ids=None):
|
85 |
+
"""
|
86 |
+
Drops labels to enable classifier-free guidance.
|
87 |
+
"""
|
88 |
+
if force_drop_ids is None:
|
89 |
+
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
|
90 |
+
else:
|
91 |
+
drop_ids = force_drop_ids == 1
|
92 |
+
labels = torch.where(drop_ids, self.num_classes, labels)
|
93 |
+
return labels
|
94 |
+
|
95 |
+
def forward(self, labels, train, force_drop_ids=None):
|
96 |
+
use_dropout = self.dropout_prob > 0
|
97 |
+
if (train and use_dropout) or (force_drop_ids is not None):
|
98 |
+
labels = self.token_drop(labels, force_drop_ids)
|
99 |
+
embeddings = self.embedding_table(labels)
|
100 |
+
return embeddings
|
101 |
+
|
102 |
+
|
103 |
+
#################################################################################
|
104 |
+
# Embedding Layers for Patches that Support H != W #
|
105 |
+
#################################################################################
|
106 |
+
|
107 |
+
class PatchEmbed(nn.Module):
|
108 |
+
""" 2D Image to Patch Embedding
|
109 |
+
"""
|
110 |
+
output_fmt: Format
|
111 |
+
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
img_size: Optional[Union[int, tuple, list]] = 224,
|
115 |
+
patch_size: Union[int, tuple, list] = 16,
|
116 |
+
in_chans: int = 3,
|
117 |
+
embed_dim: int = 768,
|
118 |
+
norm_layer: Optional[Callable] = None,
|
119 |
+
flatten: bool = True,
|
120 |
+
output_fmt: Optional[str] = None,
|
121 |
+
bias: bool = True,
|
122 |
+
strict_img_size: bool = True,
|
123 |
+
):
|
124 |
+
super().__init__()
|
125 |
+
self.patch_size = to_2tuple(patch_size)
|
126 |
+
if img_size is not None:
|
127 |
+
if isinstance(img_size, int):
|
128 |
+
self.img_size = to_2tuple(img_size)
|
129 |
+
elif len(img_size) == 1:
|
130 |
+
self.img_size = to_2tuple(img_size[0])
|
131 |
+
else:
|
132 |
+
self.img_size = img_size
|
133 |
+
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
|
134 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
135 |
+
else:
|
136 |
+
self.img_size = None
|
137 |
+
self.grid_size = None
|
138 |
+
self.num_patches = None
|
139 |
+
|
140 |
+
if output_fmt is not None:
|
141 |
+
self.flatten = False
|
142 |
+
self.output_fmt = Format(output_fmt)
|
143 |
+
else:
|
144 |
+
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
145 |
+
self.flatten = flatten
|
146 |
+
self.output_fmt = Format.NCHW
|
147 |
+
self.strict_img_size = strict_img_size
|
148 |
+
|
149 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
150 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
151 |
+
|
152 |
+
def forward(self, x):
|
153 |
+
B, C, H, W = x.shape
|
154 |
+
if self.img_size is not None:
|
155 |
+
if self.strict_img_size:
|
156 |
+
_assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
|
157 |
+
_assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
|
158 |
+
else:
|
159 |
+
_assert(
|
160 |
+
H % self.patch_size[0] == 0,
|
161 |
+
f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
|
162 |
+
)
|
163 |
+
_assert(
|
164 |
+
W % self.patch_size[1] == 0,
|
165 |
+
f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
|
166 |
+
)
|
167 |
+
|
168 |
+
x = self.proj(x)
|
169 |
+
if self.flatten:
|
170 |
+
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
171 |
+
elif self.output_fmt != Format.NCHW:
|
172 |
+
x = nchw_to(x, self.output_fmt)
|
173 |
+
x = self.norm(x)
|
174 |
+
return x
|
175 |
+
|
176 |
+
|
177 |
+
class FlattenNorm(nn.Module):
|
178 |
+
""" Flatten 2D Image to a vector
|
179 |
+
"""
|
180 |
+
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
img_size: Optional[Union[int, tuple, list]] = 224,
|
184 |
+
embed_dim: int = 768,
|
185 |
+
norm_layer: Optional[Callable] = None,
|
186 |
+
):
|
187 |
+
super().__init__()
|
188 |
+
self.num_patches = max(img_size)
|
189 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
190 |
+
# todo: hard code 64 and hidden_dim for now
|
191 |
+
self.MLP = nn.Sequential(nn.Linear(64, 256), nn.SiLU(), nn.Linear(256, embed_dim))
|
192 |
+
|
193 |
+
def forward(self, x):
|
194 |
+
x = x.permute(0, 2, 1, 3).flatten(2) # B x 4 x 128 x 16 -> B x 128 x 4 x 16 - > B x 128 x 64
|
195 |
+
x = self.MLP(x) # B x 128 x 768
|
196 |
+
x = self.norm(x)
|
197 |
+
return x
|
198 |
+
|
199 |
+
|
200 |
+
class FlattenPatchify1D(nn.Module):
|
201 |
+
""" Flatten 2D Image to a vector with pitch per token
|
202 |
+
"""
|
203 |
+
|
204 |
+
def __init__(
|
205 |
+
self,
|
206 |
+
in_channels: int = 4,
|
207 |
+
img_size: Optional[Union[int, tuple, list]] = 224,
|
208 |
+
embed_dim: int = 768,
|
209 |
+
patch_size: int = 8,
|
210 |
+
norm_layer: Optional[Callable] = None,
|
211 |
+
):
|
212 |
+
super().__init__()
|
213 |
+
# dummy, is not needed by the rotary model, but needed for REL and DiT
|
214 |
+
self.num_patches = img_size[0] * img_size[1] // patch_size # img_size: 128x16
|
215 |
+
self.patch_size = patch_size
|
216 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
217 |
+
self.MLP = nn.Sequential(nn.Linear(in_channels * patch_size, 256), nn.SiLU(), nn.Linear(256, embed_dim))
|
218 |
+
|
219 |
+
def forward(self, x):
|
220 |
+
x = x.permute(0, 2, 3, 1) # B x c x 128 x 16 -> B x 128 x 16 x c
|
221 |
+
b, n_time, n_pitch, c = x.shape
|
222 |
+
num_patches = n_time * n_pitch // self.patch_size
|
223 |
+
# B x 128 x 16 x 4 -> B x (128 x 16 / 8) x (4 * 8)
|
224 |
+
x = x.reshape(b, num_patches, -1)
|
225 |
+
x = self.MLP(x) # B x 256 x 768
|
226 |
+
x = self.norm(x)
|
227 |
+
return x
|
228 |
+
|
229 |
+
|
230 |
+
#################################################################################
|
231 |
+
# Core DiT Model #
|
232 |
+
#################################################################################
|
233 |
+
|
234 |
+
class RotaryAttention(nn.Module):
|
235 |
+
fused_attn: Final[bool]
|
236 |
+
|
237 |
+
def __init__(
|
238 |
+
self,
|
239 |
+
dim,
|
240 |
+
num_heads=8,
|
241 |
+
qkv_bias=False,
|
242 |
+
qk_norm=False,
|
243 |
+
attn_drop=0.,
|
244 |
+
proj_drop=0.,
|
245 |
+
norm_layer=nn.LayerNorm,
|
246 |
+
rotary_emb=None,
|
247 |
+
):
|
248 |
+
super().__init__()
|
249 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
250 |
+
self.num_heads = num_heads
|
251 |
+
self.head_dim = dim // num_heads
|
252 |
+
self.scale = self.head_dim ** -0.5
|
253 |
+
self.fused_attn = use_fused_attn()
|
254 |
+
self.rotary_emb = rotary_emb
|
255 |
+
|
256 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
257 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
258 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
259 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
260 |
+
self.proj = nn.Linear(dim, dim)
|
261 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
B, N, C = x.shape
|
265 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
266 |
+
q, k, v = qkv.unbind(0)
|
267 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
268 |
+
|
269 |
+
if self.rotary_emb is not None:
|
270 |
+
q = self.rotary_emb.rotate_queries_or_keys(q)
|
271 |
+
k = self.rotary_emb.rotate_queries_or_keys(k)
|
272 |
+
|
273 |
+
if self.fused_attn:
|
274 |
+
x = F.scaled_dot_product_attention(
|
275 |
+
q, k, v,
|
276 |
+
dropout_p=self.attn_drop.p,
|
277 |
+
)
|
278 |
+
else:
|
279 |
+
q = q * self.scale
|
280 |
+
attn = q @ k.transpose(-2, -1)
|
281 |
+
attn = attn.softmax(dim=-1)
|
282 |
+
attn = self.attn_drop(attn)
|
283 |
+
x = attn @ v
|
284 |
+
|
285 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
286 |
+
x = self.proj(x)
|
287 |
+
x = self.proj_drop(x)
|
288 |
+
return x
|
289 |
+
|
290 |
+
|
291 |
+
class DiTBlock(nn.Module):
|
292 |
+
"""
|
293 |
+
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
294 |
+
"""
|
295 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
|
296 |
+
super().__init__()
|
297 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
298 |
+
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
|
299 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
300 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
301 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
302 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
|
303 |
+
self.adaLN_modulation = nn.Sequential(
|
304 |
+
nn.SiLU(),
|
305 |
+
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
306 |
+
)
|
307 |
+
|
308 |
+
def forward(self, x, c):
|
309 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
|
310 |
+
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
|
311 |
+
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
312 |
+
return x
|
313 |
+
|
314 |
+
|
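# A note on the adaLN-Zero conditioning above (a sketch, not part of the original file).
# It assumes the usual DiT helper defined earlier in this module:
#
#   modulate(x, shift, scale) = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
#
# Because the final linear layer of adaLN_modulation is zero-initialized (see
# initialize_weights below), gate_msa and gate_mlp start at 0, so every block is the
# identity map at initialization and the conditioned residual branches fade in as
# training proceeds.
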
class DiTBlockRotary(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning & rotary attention.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, rotary_emb=None, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = RotaryAttention(hidden_size, num_heads=num_heads, qkv_bias=True, rotary_emb=rotary_emb, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class FinalLayerPatch1D(nn.Module):
    """
    The final layer of DiT with 1D patchify.
    """
    def __init__(self, hidden_size, out_channels, patch_size_1d=16):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size_1d * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=3,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=9,  # cluster composers into 9 groups
        learn_sigma=True,
        patchify=True,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.input_size = input_size
        self.patchify = patchify

        if patchify:
            self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        else:
            self.x_embedder = FlattenNorm(input_size, hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.num_classes = num_classes
        if self.num_classes:
            self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        if patchify:
            self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        else:
            self.final_layer = FinalLayerPatch1D(hidden_size, self.out_channels, patch_size)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        if self.patchify:
            if isinstance(self.input_size, int) or len(self.input_size) == 1:
                pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5), int(self.x_embedder.num_patches ** 0.5))
            else:
                pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size[0], self.x_embedder.grid_size[1])
        else:
            # 1D position encoding
            pos_embed = get_1d_sincos_pos_embed_from_grid(self.pos_embed.shape[-1],
                                                          np.arange(self.x_embedder.num_patches, dtype=np.float32))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        if self.patchify:
            w = self.x_embedder.proj.weight.data
            nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
            nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        if self.num_classes:
            nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        if isinstance(self.input_size, int) or len(self.input_size) == 1:
            h = w = int(x.shape[1] ** 0.5)
            assert h * w == x.shape[1]
        else:
            h = self.input_size[0] // self.patch_size
            w = self.input_size[1] // self.patch_size

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def unflatten(self, x):
        c = self.out_channels
        x = x.reshape(shape=(x.shape[0], x.shape[1], c, -1))
        imgs = x.permute(0, 2, 1, 3)
        return imgs

    def forward(self, x, t, y=None):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        c = self.t_embedder(t)                   # (N, D)
        if self.num_classes and y is not None:
            y = self.y_embedder(y, self.training)  # (N, D)
            c = c + y                              # (N, D)
        for block in self.blocks:
            x = block(x, c)                      # (N, T, D)
        x = self.final_layer(x, c)               # (N, T, patch_size ** 2 * out_channels)
        if self.patchify:
            x = self.unpatchify(x)               # (N, out_channels, H, W)
        else:
            x = self.unflatten(x)
        return x

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

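# A minimal usage sketch (not part of the original file; the sizes are illustrative for
# 128x16 piano-roll latents with 4 channels):
#
#   model = DiT(input_size=(128, 16), patch_size=2, in_channels=4,
#               hidden_size=384, depth=12, num_heads=6, num_classes=9)
#   x = torch.randn(2, 4, 128, 16)    # noisy latents
#   t = torch.randint(0, 1000, (2,))  # diffusion timesteps
#   y = torch.randint(0, 9, (2,))     # composer-cluster labels
#   out = model(x, t, y)              # (2, 8, 128, 16): predicted eps plus variance channels
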
class DiTRotary(nn.Module):
    """
    Diffusion model with a Transformer backbone and rotary position embedding.
    Uses 1D position encoding; patchify is set to False.
    """

    def __init__(
        self,
        input_size=32,
        patch_size=8,  # patch size for 1D patchify
        in_channels=3,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=9,  # cluster composers into 9 groups
        learn_sigma=True,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.input_size = input_size

        self.x_embedder = FlattenPatchify1D(in_channels, input_size, hidden_size, patch_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.num_classes = num_classes
        if self.num_classes:
            self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)

        rotary_dim = int(hidden_size // num_heads * 0.5)  # 0.5 is the rotary percentage in multi-head RoPE, 0.5 by default
        self.rotary_emb = RotaryEmbedding(rotary_dim)
        self.blocks = nn.ModuleList([
            DiTBlockRotary(hidden_size, num_heads, mlp_ratio=mlp_ratio, rotary_emb=self.rotary_emb) for _ in range(depth)
        ])
        self.final_layer = FinalLayerPatch1D(hidden_size, self.out_channels, patch_size_1d=self.patch_size)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize label embedding table:
        if self.num_classes:
            nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        x: (N, T, img_size[1] / patch_size * C)
        imgs: (N, H, W, C)
        """
        # input_size[1] is the pitch dimension, should always be the same
        x = x.reshape(shape=(x.shape[0], -1, self.input_size[1], self.out_channels))
        imgs = x.permute(0, 3, 1, 2)
        return imgs

    def forward(self, x, t, y=None):
        """
        Forward pass of DiTRotary.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        x = self.x_embedder(x)      # (N, T, D), where T = H * W / patch_size
        c = self.t_embedder(t)      # (N, D)
        if self.num_classes and y is not None:
            y = self.y_embedder(y, self.training)  # (N, D)
            c = c + y                              # (N, D)
        for block in self.blocks:
            x = block(x, c)         # (N, T, D)
        x = self.final_layer(x, c)  # (N, T, patch_size * out_channels)
        x = self.unpatchify(x)
        return x


class DiT_classifier(nn.Module):
    """
    Classifier used in classifier guidance.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=3,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        num_classes=9,
        patchify=True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.input_size = input_size
        self.patchify = patchify

        if patchify:
            self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        else:
            self.x_embedder = FlattenNorm(input_size, hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.num_classes = num_classes
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size), requires_grad=True)
        self.norm = nn.LayerNorm(hidden_size)
        self.classifier_head = nn.Sequential(nn.Linear(hidden_size, hidden_size // 4),
                                             nn.SiLU(), nn.Linear(hidden_size // 4, self.num_classes))
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        if self.patchify:
            if isinstance(self.input_size, int) or len(self.input_size) == 1:
                pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5), int(self.x_embedder.num_patches ** 0.5))
            else:
                pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size[0], self.x_embedder.grid_size[1])
        else:
            # 1D position encoding
            pos_embed = get_1d_sincos_pos_embed_from_grid(self.pos_embed.shape[-1],
                                                          np.arange(self.x_embedder.num_patches, dtype=np.float32))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize class token
        nn.init.normal_(self.cls_token, std=1e-6)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        if self.patchify:
            w = self.x_embedder.proj.weight.data
            nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
            nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, x, t):
        """
        Forward pass of the DiT classifier.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        """
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        c = self.t_embedder(t)        # (N, D)
        for block in self.blocks:
            x = block(x, c)           # (N, T, D)
        x = x[:, 0, :]                # (N, D)
        x = self.norm(x)
        x = self.classifier_head(x)   # (N, num_classes)
        return x


class DiTRotaryClassifier(nn.Module):
    """
    Classifier with a Transformer backbone and rotary position embedding.
    Uses 1D position encoding; patchify is set to False.
    """

    def __init__(
        self,
        input_size=32,
        patch_size=8,  # patch size for 1D patchify
        in_channels=3,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        num_classes=9,  # cluster composers into 9 groups
        chord=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.input_size = input_size
        self.chord = chord
        self.hidden_size = hidden_size

        self.x_embedder = FlattenPatchify1D(in_channels, input_size, hidden_size, patch_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.num_classes = num_classes

        rotary_dim = int(hidden_size // num_heads * 0.5)  # 0.5 is the rotary percentage in multi-head RoPE, 0.5 by default
        self.rotary_emb = RotaryEmbedding(rotary_dim)
        self.blocks = nn.ModuleList([
            DiTBlockRotary(hidden_size, num_heads, mlp_ratio=mlp_ratio, rotary_emb=self.rotary_emb) for _ in range(depth)
        ])
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size), requires_grad=True)
        self.norm = nn.LayerNorm(hidden_size)
        self.classifier_head = nn.Sequential(nn.Linear(hidden_size, hidden_size // 4),
                                             nn.SiLU(), nn.Linear(hidden_size // 4, self.num_classes))
        if self.chord:
            self.norm_key = nn.LayerNorm(hidden_size)
            # predict key also: 24 major and minor keys + null
            self.classifier_head_key = nn.Sequential(nn.Linear(hidden_size, hidden_size // 4),
                                                     nn.SiLU(), nn.Linear(hidden_size // 4, 25))
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize class token
        nn.init.normal_(self.cls_token, std=1e-6)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, x, t, y=None):
        """
        Forward pass of the rotary DiT classifier.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        if self.chord:
            n_token = x.shape[2] // x.shape[3]
        x = self.x_embedder(x)  # (N, T, D), where T = H * W / patch_size
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        c = self.t_embedder(t)  # (N, D)
        for block in self.blocks:
            x = block(x, c)     # (N, T, D)
        if self.chord:
            x_key = x[:, 0, :]
            x_key = self.norm_key(x_key)
            key = self.classifier_head_key(x_key)
            x_chord = x[:, 1:, :]
            x_chord = x_chord.reshape(shape=[x.shape[0], n_token, -1, self.hidden_size])
            x_chord = x_chord.mean(dim=-2)
            x_chord = self.norm(x_chord)
            chord = self.classifier_head(x_chord)
            return key, chord
        else:
            x = x[:, 0, :]               # (N, D)
            x = self.norm(x)
            x = self.classifier_head(x)  # (N, num_classes)
            return x


#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py

def get_2d_sincos_pos_embed(embed_dim, grid_size_h, grid_size_w, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size_h, dtype=np.float32)
    grid_w = np.arange(grid_size_w, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

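# A quick shape check for the helpers above (a sketch, not part of the original file).
# For a (128, 16) input with patch_size=2 the patch grid is 64 x 8, so:
#
#   pe = get_2d_sincos_pos_embed(embed_dim=768, grid_size_h=64, grid_size_w=8)
#   # pe.shape == (512, 768); half of each vector encodes one grid axis, half the other
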
#################################################################################
#                                  DiT Configs                                  #
#################################################################################

def DiT_XL_2(**kwargs):
    return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)

def DiT_XL_4(**kwargs):
    return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)

def DiTRotary_XL_16(**kwargs):
    return DiTRotary(depth=28, hidden_size=1152, patch_size=16, num_heads=16, **kwargs)

def DiTRotary_XL_8(**kwargs):
    return DiTRotary(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)

def DiT_XL_8(**kwargs):
    return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)

def DiT_L_2(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)

def DiT_L_4(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)

def DiT_L_8(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)

def DiT_B_2(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)

def DiT_B_4(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)

def DiTRotary_B_16(**kwargs):  # seq_len = 128 = 128 * 16 / 16
    return DiTRotary(depth=12, hidden_size=768, patch_size=16, num_heads=12, **kwargs)

def DiTRotary_B_8(**kwargs):  # seq_len = 256 = 128 * 16 / 8
    return DiTRotary(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)

def DiT_B_8(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)

def DiT_B_4_classifier(**kwargs):
    return DiT_classifier(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)

def DiT_B_8_classifier(**kwargs):
    return DiT_classifier(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)

def DiTRotary_B_8_classifier(**kwargs):
    return DiTRotaryClassifier(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)

def DiT_S_2(**kwargs):
    return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)

def DiT_S_2_classifier(**kwargs):
    return DiT_classifier(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)

def DiTRotary_S_8_classifier(**kwargs):
    return DiTRotaryClassifier(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)

def DiTRotary_S_8_chord_classifier(**kwargs):
    return DiTRotaryClassifier(depth=12, hidden_size=384, patch_size=8, num_heads=6, chord=True, **kwargs)

def DiT_XS_2_classifier(**kwargs):
    return DiT_classifier(depth=4, hidden_size=384, patch_size=2, num_heads=6, **kwargs)

def DiTRotary_XS_8_classifier(**kwargs):
    return DiTRotaryClassifier(depth=4, hidden_size=384, patch_size=8, num_heads=6, **kwargs)

def DiT_S_4(**kwargs):
    return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)

def DiT_S_4_classifier(**kwargs):
    return DiT_classifier(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)

def DiT_S_8(**kwargs):
    return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)


DiT_models = {
    'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
    'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
    'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
    'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
    'DiTRotary_B_16': DiTRotary_B_16, 'DiTRotary_B_8': DiTRotary_B_8,
    'DiTRotary_XL_16': DiTRotary_XL_16, 'DiTRotary_XL_8': DiTRotary_XL_8,
    'DiT-B/4-cls': DiT_B_4_classifier, 'DiT-B/8-cls': DiT_B_8_classifier,
    'DiT-S/4-cls': DiT_S_4_classifier, 'DiT-S/2-cls': DiT_S_2_classifier,
    'DiT-XS/2-cls': DiT_XS_2_classifier,
    'DiTRotary-XS/8-cls': DiTRotary_XS_8_classifier,
    'DiTRotary-S/8-cls': DiTRotary_S_8_classifier,
    'DiTRotary-S/8-chord-cls': DiTRotary_S_8_chord_classifier,
    'DiTRotary-B/8-cls': DiTRotary_B_8_classifier,
}
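# A lookup sketch (not part of the original file; the keyword values are illustrative).
# DiT_models maps config-string names to constructors, so scripts can pick an architecture
# by name and pass only the data-dependent arguments:
#
#   model_fn = DiT_models['DiTRotary_B_8']
#   model = model_fn(input_size=(128, 16), in_channels=4, num_classes=9)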
guided_diffusion/embed_datasets.py ADDED @@ -0,0 +1,161 @@
import math
import random
import os
import pandas as pd
import re
from PIL import Image
import blobfile as bf
from mpi4py import MPI
import numpy as np
from torch.utils.data import DataLoader, Dataset

CLUSTERS = {'Balakirev': 0,
            'Bartholdy': 0,
            'Bizet': 0,
            'Brahms': 0,
            'Busoni': 0,
            'Chopin': 0,
            'Grieg': 0,
            'Horowitz': 0,
            'Liszt': 0,
            'Mendelssohn': 0,
            'Moszkowski': 0,
            'Paganini': 0,
            'Saint-Saens': 0,
            'Schubert': 0,
            'Schumann': 0,
            'Strauss': 0,
            'Tchaikovsky': 0,
            'Wagner': 0,
            'Beethoven': 1,
            'Bach': 2,
            'Handel': 2,
            'Purcell': 2,
            'Barber': 3,
            'Bartok': 3,
            'Hindemith': 3,
            'Ligeti': 3,
            'Messiaen': 3,
            'Mussorgsky': 3,
            'Myaskovsky': 3,
            'Prokofiev': 3,
            'Schnittke': 3,
            'Schonberg': 3,
            'Shostakovich': 3,
            'Stravinsky': 3,
            'Debussy': 4,
            'Ravel': 4,
            'Clementi': 5,
            'Haydn': 5,
            'Mozart': 5,
            'Pachelbel': 5,
            'Scarlatti': 5,
            'Rachmaninoff': 6,
            'Scriabin': 6,
            'Gershwin': 7,
            'Kapustin': 7
            }


def extract_string(file_name):
    if 'loc' not in file_name:
        ind = [i.start() for i in re.finditer('_', file_name)][-1]
        name = file_name[:ind]
    else:
        name = file_name.split('loc')[0][:-1]
    return name


def find_composer(name, df):
    compound_composer = df.loc[df['simple_midi_name'] == name]['canonical_composer'].item()
    composer = compound_composer.split(' / ')[0].split(' ')[-1]  # take the last name of the first composer
    result = CLUSTERS.setdefault(composer, 8)  # default cluster is everyone else (8)
    return result


def load_data(
    *,
    data_dir,
    batch_size,
    class_cond=False,
    deterministic=False,
):
    """
    For a dataset, create a generator over (images, kwargs) pairs.

    Each image is an NCHW float tensor, and the kwargs dict contains zero or
    more keys, each of which maps to a batched Tensor of its own.
    The kwargs dict can be used for class labels, in which case the key is "y"
    and the values are integer tensors of class labels.

    :param data_dir: a dataset directory.
    :param batch_size: the batch size of each returned pair.
    :param class_cond: if True, include a "y" key in returned dicts for class
                       labels. If classes are not available and this is true, an
                       exception will be raised.
    :param deterministic: if True, yield results in a deterministic order.
    """
    if not data_dir:
        raise ValueError("unspecified data directory")
    all_files = _list_image_files(data_dir)
    classes = None
    if class_cond:
        # find the composer
        parent_dir = os.path.join(*data_dir.split('/')[:-1])
        if data_dir[0] == '/':
            parent_dir = '/' + parent_dir
        df = pd.read_csv(os.path.join(parent_dir, 'maestro-v3.0.0.csv'))
        df['simple_midi_name'] = [midi_name[5:-5] for midi_name in df['midi_filename']]
        all_file_names = bf.listdir(data_dir)
        extracted_names = [extract_string(file_name) for file_name in all_file_names]
        classes = [find_composer(name, df) for name in extracted_names]

    dataset = ImageDataset(
        all_files,
        classes=classes,
        shard=MPI.COMM_WORLD.Get_rank(),
        num_shards=MPI.COMM_WORLD.Get_size(),
    )
    if deterministic:
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
        )
    else:
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
        )
    while True:
        yield from loader


def _list_image_files(data_dir):
    dirs = bf.listdir(data_dir)
    return [data_dir + '/' + d for d in dirs]


class ImageDataset(Dataset):
    def __init__(
        self,
        image_paths,
        classes=None,
        shard=0,
        num_shards=1,
    ):
        super().__init__()
        self.local_images = image_paths[shard:][::num_shards]
        self.local_classes = None if classes is None else classes[shard:][::num_shards]

    def __len__(self):
        return len(self.local_images)

    def __getitem__(self, idx):
        path = self.local_images[idx]
        arr = np.load(path)
        out_dict = {}
        if self.local_classes is not None:
            out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
        return arr, out_dict
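# A minimal usage sketch (not part of the original file; the directory name is illustrative).
# load_data returns an infinite generator, so training code pulls batches with next()
# rather than iterating over epochs:
#
#   data = load_data(data_dir="datasets/train_latents", batch_size=16, class_cond=True)
#   batch, cond = next(data)  # batch: (16, C, H, W) tensor, cond["y"]: (16,) int64 labels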
guided_diffusion/fp16_util.py ADDED @@ -0,0 +1,237 @@
"""
Helpers to train with 16-bit precision.
"""

import numpy as np
import torch as th
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from . import logger

INITIAL_LOG_LOSS_SCALE = 20.0


def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        if l.bias is not None:
            l.bias.data = l.bias.data.half()


def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        if l.bias is not None:
            l.bias.data = l.bias.data.float()


def make_master_params(param_groups_and_shapes):
    """
    Copy model parameters into a (differently-shaped) list of full-precision
    parameters.
    """
    master_params = []
    for param_group, shape in param_groups_and_shapes:
        master_param = nn.Parameter(
            _flatten_dense_tensors(
                [param.detach().float() for (_, param) in param_group]
            ).view(shape)
        )
        master_param.requires_grad = True
        master_params.append(master_param)
    return master_params


def model_grads_to_master_grads(param_groups_and_shapes, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    for master_param, (param_group, shape) in zip(
        master_params, param_groups_and_shapes
    ):
        master_param.grad = _flatten_dense_tensors(
            [param_grad_or_zeros(param) for (_, param) in param_group]
        ).view(shape)


def master_params_to_model_params(param_groups_and_shapes, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    # Without copying to a list, if a generator is passed, this will
    # silently not copy any parameters.
    for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
        for (_, param), unflat_master_param in zip(
            param_group, unflatten_master_params(param_group, master_param.view(-1))
        ):
            param.detach().copy_(unflat_master_param)


def unflatten_master_params(param_group, master_param):
    return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])


def get_param_groups_and_shapes(named_model_params):
    named_model_params = list(named_model_params)
    scalar_vector_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
        (-1),
    )
    matrix_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim > 1],
        (1, -1),
    )
    return [scalar_vector_named_params, matrix_named_params]


def master_params_to_state_dict(
    model, param_groups_and_shapes, master_params, use_fp16
):
    if use_fp16:
        state_dict = model.state_dict()
        for master_param, (param_group, _) in zip(
            master_params, param_groups_and_shapes
        ):
            for (name, _), unflat_master_param in zip(
                param_group, unflatten_master_params(param_group, master_param.view(-1))
            ):
                assert name in state_dict
                state_dict[name] = unflat_master_param
    else:
        state_dict = model.state_dict()
        for i, (name, _value) in enumerate(model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
    return state_dict


def state_dict_to_master_params(model, state_dict, use_fp16):
    if use_fp16:
        named_model_params = [
            (name, state_dict[name]) for name, _ in model.named_parameters()
        ]
        param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
        master_params = make_master_params(param_groups_and_shapes)
    else:
        master_params = [state_dict[name] for name, _ in model.named_parameters()]
    return master_params


def zero_master_grads(master_params):
    for param in master_params:
        param.grad = None


def zero_grad(model_params):
    for param in model_params:
        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
        if param.grad is not None:
            param.grad.detach_()
            param.grad.zero_()


def param_grad_or_zeros(param):
    if param.grad is not None:
        return param.grad.data.detach()
    else:
        return th.zeros_like(param)


class MixedPrecisionTrainer:
    def __init__(
        self,
        *,
        model,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
    ):
        self.model = model
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.param_groups_and_shapes = None
        self.lg_loss_scale = initial_lg_loss_scale

        if self.use_fp16:
            self.param_groups_and_shapes = get_param_groups_and_shapes(
                self.model.named_parameters()
            )
            self.master_params = make_master_params(self.param_groups_and_shapes)
            self.model.convert_to_fp16()

    def zero_grad(self):
        zero_grad(self.model_params)

    def backward(self, loss: th.Tensor):
        if self.use_fp16:
            loss_scale = 2 ** self.lg_loss_scale
            (loss * loss_scale).backward()
        else:
            loss.backward()

    def optimize(self, opt: th.optim.Optimizer):
        if self.use_fp16:
            return self._optimize_fp16(opt)
        else:
            return self._optimize_normal(opt)

    def _optimize_fp16(self, opt: th.optim.Optimizer):
        logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
        model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
        grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
        if check_overflow(grad_norm):
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            zero_master_grads(self.master_params)
            return False

        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)

        for p in self.master_params:
            p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        opt.step()
        zero_master_grads(self.master_params)
        master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth
        return True

    def _optimize_normal(self, opt: th.optim.Optimizer):
        grad_norm, param_norm = self._compute_norms()
        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)
        opt.step()
        return True

    def _compute_norms(self, grad_scale=1.0):
        grad_norm = 0.0
        param_norm = 0.0
        for p in self.master_params:
            with th.no_grad():
                param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
                if p.grad is not None:
                    grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
        return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)

    def master_params_to_state_dict(self, master_params):
        return master_params_to_state_dict(
            self.model, self.param_groups_and_shapes, master_params, self.use_fp16
        )

    def state_dict_to_master_params(self, state_dict):
        return state_dict_to_master_params(self.model, state_dict, self.use_fp16)


def check_overflow(value):
    return (value == float("inf")) or (value == -float("inf")) or (value != value)
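# A sketch of the intended training-step pattern (not part of the original file; the model
# and loss are placeholders). With use_fp16=True, backward() scales the loss by
# 2 ** lg_loss_scale, and optimize() unscales the gradients, skips the step and shrinks
# the scale on overflow, and otherwise grows the scale by fp16_scale_growth:
#
#   trainer = MixedPrecisionTrainer(model=model, use_fp16=True)
#   opt = th.optim.AdamW(trainer.master_params, lr=1e-4)
#   trainer.zero_grad()
#   trainer.backward(loss)
#   took_step = trainer.optimize(opt)  # False if the scaled gradients overflowed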
guided_diffusion/gaussian_diffusion.py ADDED @@ -0,0 +1,1400 @@
"""
This code started out as a PyTorch port of Ho et al's diffusion models:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py

Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
"""

import enum
import os

import math

import numpy as np
import torch as th

from .nn import mean_flat
from .losses import normal_kl, discretized_gaussian_log_likelihood
from .midi_util import save_piano_roll_midi
from music_rule_guidance.rule_maps import FUNC_DICT, LOSS_DICT
from collections import defaultdict
import torch.nn.functional as F
import multiprocessing
from functools import partial

import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (20, 3)
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    elif schedule_name == 'stable-diffusion':
        scale = 1000 / num_diffusion_timesteps
        beta_start = scale * math.sqrt(0.00085)
        beta_end = scale * math.sqrt(0.012)
        return np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        ) ** 2
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)
83 |
+
|
84 |
+
|
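
# ---------------------------------------------------------------------------
# Editorial sketch (not part of the upstream file): a minimal check of how the
# two schedule helpers above fit together. `get_named_beta_schedule` returns a
# length-T array of betas, and `betas_for_alpha_bar` derives each beta from a
# continuous alpha_bar(t) via beta_i = 1 - alpha_bar(t2) / alpha_bar(t1).
def _example_beta_schedules(num_steps=1000):
    betas = get_named_beta_schedule("cosine", num_steps)
    assert betas.shape == (num_steps,)
    assert (betas > 0).all() and (betas <= 0.999).all()
    # The running product of (1 - beta) recovers a discretization of alpha_bar.
    alpha_bar = np.cumprod(1.0 - betas)
    return betas, alpha_bar
# ---------------------------------------------------------------------------
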
class ModelMeanType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}
    START_X = enum.auto()  # the model predicts x_0
    EPSILON = enum.auto()  # the model predicts epsilon


class ModelVarType(enum.Enum):
    """
    What is used as the model's output variance.

    The LEARNED_RANGE option has been added to allow the model to predict
    values between FIXED_SMALL and FIXED_LARGE, making its job easier.
    """

    LEARNED = enum.auto()
    FIXED_SMALL = enum.auto()
    FIXED_LARGE = enum.auto()
    LEARNED_RANGE = enum.auto()


class LossType(enum.Enum):
    MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
    RESCALED_MSE = (
        enum.auto()
    )  # use raw MSE loss (with RESCALED_KL when learning variances)
    KL = enum.auto()  # use the variational lower-bound
    RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB

    def is_vb(self):
        return self == LossType.KL or self == LossType.RESCALED_KL
class GaussianDiffusion:
    """
    Utilities for training and sampling diffusion models.

    Ported directly from here, and then adapted over time to further experimentation.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42

    :param betas: a 1-D numpy array of betas for each diffusion timestep,
                  starting at T and going to 1.
    :param model_mean_type: a ModelMeanType determining what the model outputs.
    :param model_var_type: a ModelVarType determining how variance is output.
    :param loss_type: a LossType determining the loss function to use.
    :param rescale_timesteps: if True, pass floating point timesteps into the
                              model so that they are always scaled like in the
                              original paper (0 to 1000).
    """

    def __init__(
        self,
        *,
        betas,
        model_mean_type,
        model_var_type,
        loss_type,
        rescale_timesteps=False,
    ):
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type
        self.rescale_timesteps = rescale_timesteps

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )
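    # Editorial note (not in the upstream file): with the buffers above, the
    # Gaussian posterior q(x_{t-1} | x_t, x_0) has mean
    #     posterior_mean_coef1[t] * x_0 + posterior_mean_coef2[t] * x_t
    # and variance posterior_variance[t]. A minimal construction sketch:
    #     diffusion = GaussianDiffusion(
    #         betas=get_named_beta_schedule("linear", 1000),
    #         model_mean_type=ModelMeanType.EPSILON,
    #         model_var_type=ModelVarType.FIXED_SMALL,
    #         loss_type=LossType.MSE,
    #     )
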
    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).

        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        )
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = _extract_into_tensor(
            self.log_one_minus_alphas_cumprod, t, x_start.shape
        )
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)

        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
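    # Editorial sketch (not in the upstream file): q_sample implements the
    # closed form x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.
    # For a batch of latents z0 of shape [N, C, H, W], forward noising is:
    #     t = th.randint(0, diffusion.num_timesteps, (z0.shape[0],), device=z0.device)
    #     eps = th.randn_like(z0)
    #     zt = diffusion.q_sample(z0, t, noise=eps)
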
    def p_mean_variance(
        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None,
        cond_fn=None, embed_model=None, edit_kwargs=None,
    ):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param cond_fn: log p(y|x), to maximize
        :param embed_model: contains encoder and decoder
        :param edit_kwargs: replacement-based conditioning
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        def process_xstart(x):
            if denoised_fn is not None:
                x = denoised_fn(x)
            if clip_denoised:
                return x.clamp(-1, 1)
            return x

        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]
        assert t.shape == (B,)
        model_output = model(x, self._scale_timesteps(t), **model_kwargs)

        if edit_kwargs is not None:
            pred_xstart = process_xstart(
                self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
            )
            replaced_x0 = edit_kwargs["mask"] * edit_kwargs["gt"] + (1 - edit_kwargs["mask"]) * pred_xstart
            model_output = self._predict_eps_from_xstart(x_t=x, t=t, pred_xstart=replaced_x0)

        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            if self.model_var_type == ModelVarType.LEARNED:
                model_log_variance = model_var_values
                model_variance = th.exp(model_log_variance)
            else:
                min_log = _extract_into_tensor(
                    self.posterior_log_variance_clipped, t, x.shape
                )
                max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
                # The model_var_values is [-1, 1] for [min_var, max_var].
                frac = (model_var_values + 1) / 2
                model_log_variance = frac * max_log + (1 - frac) * min_log
                model_variance = th.exp(model_log_variance)
        else:
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                ModelVarType.FIXED_LARGE: (
                    np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                ),
                ModelVarType.FIXED_SMALL: (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        if self.model_mean_type == ModelMeanType.PREVIOUS_X:
            pred_xstart = process_xstart(
                self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
            )
            model_mean = model_output
        elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
            if self.model_mean_type == ModelMeanType.START_X:
                pred_xstart = process_xstart(model_output)
            else:
                pred_xstart = process_xstart(
                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
                )
            model_mean, _, _ = self.q_posterior_mean_variance(
                x_start=pred_xstart, x_t=x, t=t
            )
        else:
            raise NotImplementedError(self.model_mean_type)

        assert (
            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        )
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
        }
    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        assert x_t.shape == xprev.shape
        return (  # (xprev - coef2*x_t) / coef1
            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
            - _extract_into_tensor(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
            )
            * x_t
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _scale_timesteps(self, t):
        if self.rescale_timesteps:
            return t.float() * (1000.0 / self.num_timesteps)
        return t
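    # Editorial note (not in the upstream file): the _predict_* helpers above
    # are inverses of the same identity
    #     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
    # rewritten with the precomputed buffers as
    #     x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps.
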
    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None, guidance_kwargs=None,
                       model=None, embed_model=None, edit_kwargs=None, scale_factor=1.,
                       record=False):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        If dps=True, use diffusion posterior sampling; cond_fn is then log p(y|x_0)
        instead of its gradient. This path needs the model (eps) and embed_model.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        dps = True if guidance_kwargs.method == 'dps' else False
        if not dps:
            if edit_kwargs is None:
                gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
                new_mean = (
                    p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
                )
            else:
                # only compute the gradient on editable latents, since the rule only covers the editable length
                x = x[:, :, edit_kwargs["l_start"]:edit_kwargs["l_end"], :]
                gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
                new_mean = p_mean_var["mean"].float()
                new_mean[:, :, edit_kwargs["l_start"]:edit_kwargs["l_end"], :] += (
                    p_mean_var["variance"] * gradient.float())
        else:
            assert model is not None
            step_size = guidance_kwargs.step_size
            with th.enable_grad():
                xt = x.detach().requires_grad_(True)
                eps = model(xt, self._scale_timesteps(t), **model_kwargs)
                pred_xstart = self._predict_xstart_from_eps(xt, t, eps)
                # If the vae is not None, and not dps_nn, i.e. using a dps rule
                if embed_model is not None and not guidance_kwargs.nn:
                    pred_xstart = _decode(pred_xstart, embed_model, scale_factor=scale_factor)
                if record:
                    pred_xstart.retain_grad()
                if edit_kwargs is not None:
                    # only check the condition on the editable part
                    pred_xstart = pred_xstart[:, :, edit_kwargs["l_start"]:edit_kwargs["l_end"], :]
                log_probs = cond_fn(pred_xstart, self._scale_timesteps(t), **model_kwargs)
                gradient = th.autograd.grad(log_probs.sum(), xt)[0]

            # check if the x_0 space works
            if record:
                pred_xstart_up = pred_xstart + pred_xstart.grad
                log_probs_up = cond_fn(pred_xstart_up, self._scale_timesteps(t), **model_kwargs)
                # record gradient difference
                cur_grad_diff = (self.prev_gradient_single - gradient).reshape(x.shape[0], -1).norm(dim=-1)
                prev_gradient_norm = self.prev_gradient_single.reshape(x.shape[0], -1).norm(dim=-1)
                if prev_gradient_norm.mean() > 1e-5:
                    self.grad_norm.append(prev_gradient_norm.mean().item())
                    cur_grad_diff = cur_grad_diff / prev_gradient_norm
                    self.gradient_diff.append(cur_grad_diff.mean().item())
                self.prev_gradient_single = gradient
                self.log_probs.append((log_probs.mean().item()))

            gradient = gradient / th.sqrt(-log_probs.view(x.shape[0], 1, 1, 1) + 1e-12)
            # gradient = gradient / (-log_probs.view(x.shape[0], 1, 1, 1) + 1e-12)

            if edit_kwargs is None:
                new_mean = (
                    p_mean_var["mean"].float() + step_size * gradient.float()
                )
            else:
                new_mean = p_mean_var["mean"].float()
                new_mean[:, :, edit_kwargs["l_start"]:edit_kwargs["l_end"], :] += step_size * gradient.float()

            # check whether we moved in a good direction in z space
            if record:
                eps = model(xt + step_size * gradient.float(), self._scale_timesteps(t), **model_kwargs)
                pred_xstart_2 = self._predict_xstart_from_eps(xt, t, eps)
                pred_xstart_2 = _decode(pred_xstart_2, embed_model, scale_factor=scale_factor)
                log_probs_2 = cond_fn(pred_xstart_2, self._scale_timesteps(t), **model_kwargs)

        return new_mean
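    # Editorial sketch (not in the upstream file): the cond_fn contract used
    # above. In the non-DPS branch it must return grad_x log p(y | x_t); in the
    # DPS branch it returns log p(y | x_0) itself and this method differentiates
    # through the decoded x_0 prediction. A hypothetical rule-based example:
    #     def cond_fn(x0, t, **kwargs):
    #         rule_name = "note_density"   # hypothetical choice of rule key
    #         gen_rule = FUNC_DICT[rule_name](x0)
    #         return -LOSS_DICT[rule_name](gen_rule, kwargs["rule"][rule_name])
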
    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.

        See condition_mean() for details on cond_fn.

        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
            x, self._scale_timesteps(t), **model_kwargs
        )

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(
            x_start=out["pred_xstart"], x_t=x, t=t
        )
        return out
    def scg_sample(self,
                   model,
                   t,
                   mean_pred,
                   g_coeff,
                   embed_model,
                   scale_factor,
                   model_kwargs=None,
                   scg_kwargs=None,
                   edit_kwargs=None,
                   dc_kwargs=None,
                   record=False,
                   record_freq=100):
        """
        Sample N x_{t-1} from x_t and select the best one.
        """
        # mean_pred = p_mean_var["mean"]
        # g_coeff = th.exp(0.5 * p_mean_var["log_variance"])
        num_samples = scg_kwargs["num_samples"]
        sample = mean_pred.unsqueeze(dim=0)
        sample = sample.expand(num_samples, *mean_pred.shape).contiguous()
        noise = th.randn_like(sample)
        sample = sample + g_coeff * noise
        sample = sample.view(-1, *mean_pred.shape[1:])
        t = t.repeat(num_samples)
        # it's fine to use a different target per sample: expand and repeat interleave consistently (012012)
        cloned_model_kwargs = {"y": model_kwargs["y"].repeat(num_samples)}
        eps = model(sample, self._scale_timesteps(t), **cloned_model_kwargs)
        pred_xstart = self._predict_xstart_from_eps(sample, t, eps)
        if edit_kwargs is not None:
            # only decode the editable part
            pred_xstart = pred_xstart[:, :, edit_kwargs["l_start"]:edit_kwargs["l_end"], :]
        if embed_model is not None:
            pred_xstart = _decode(pred_xstart, embed_model, scale_factor=scale_factor)

        if dc_kwargs is None or dc_kwargs.base <= 0:
            if record:
                # create a dictionary to record the loss for each rule
                each_loss = {}
            # work with multiple rules, model_kwargs["rule"] is a dict that contains rule_name: target
            total_log_prob = 0
            for rule_name, rule_target in model_kwargs["rule"].items():
                gen_rule = _extract_rule(rule_name, pred_xstart)
                y_ = rule_target.repeat(num_samples, 1)
                log_prob = - LOSS_DICT[rule_name](gen_rule, y_)
                if record:
                    each_loss[rule_name] = -log_prob.view(num_samples, -1)
                total_log_prob += log_prob * scg_kwargs.get(rule_name, 1.)
            total_log_prob = total_log_prob.view(num_samples, -1)
            max_ind = total_log_prob.argmax(dim=0)

            # softmax (need to reweight to get unit var, otherwise goes to empty rolls)
            # weight = F.softmax(total_log_prob * 1., dim=0)
            # var = (weight ** 2).sum(dim=0)
            # avg_noise = (noise * weight[..., None, None, None]).sum(dim=0) / th.sqrt(var)[..., None, None, None]
            # # not adding dw
            # sample = mean_pred + g_coeff * avg_noise
            # # add dw
            # dw = th.randn_like(p_mean_var["mean"])
            # sample = mean_pred + g_coeff * (avg_noise + dw)

            # take argmax
            sample = sample.view(num_samples, *mean_pred.shape)
            sample = sample[max_ind, th.arange(mean_pred.shape[0])]

            # take argmax, and add dw
            # noise = noise.view(num_samples, *p_mean_var["mean"].shape)
            # best_noise = noise[max_ind, th.arange(p_mean_var["mean"].shape[0])]
            # dw = th.randn_like(p_mean_var["mean"])
            # sample = p_mean_var["mean"] + th.exp(0.5 * p_mean_var["log_variance"]) * (best_noise + dw)

        else:
            # Assuming base length in x0 is only controlled by the corresponding location in xt
            # (doesn't hold, but maybe can approximate because of cond ind)
            sample = sample.view(num_samples, *mean_pred.shape)
            sub_samples = []
            total_length = pred_xstart.shape[-1]
            start_inds = th.arange(0, total_length, dc_kwargs.base*8)
            rule_base = dc_kwargs.base // 16  # number of rules under the base length
            for i, start_ind in enumerate(start_inds):
                end_ind = min(start_ind+dc_kwargs.base*8, total_length)
                pred_xstart_cur = pred_xstart[:, :, :, start_ind: end_ind]
                total_log_prob = 0
                for rule_name, rule_target in model_kwargs["rule"].items():
                    gen_rule = _extract_rule(rule_name, pred_xstart_cur)
                    if rule_name == 'note_density':
                        half = rule_target.shape[-1] // 2
                        vt_nd_target = rule_target[:, :half][:, i*rule_base: min((i+1)*rule_base, half)]
                        hr_nd_target = rule_target[:, half:][:, i*rule_base: min((i+1)*rule_base, half)]
                        rule_target = th.concat((vt_nd_target, hr_nd_target), dim=-1)
                    elif 'chord' in rule_name:
                        rule_length = rule_target.shape[-1]
                        rule_target = rule_target[:, i*rule_base: min((i+1)*rule_base, rule_length)]
                    y_ = rule_target.repeat(num_samples, 1)
                    log_prob = - LOSS_DICT[rule_name](gen_rule, y_)
                    total_log_prob += log_prob * scg_kwargs.get(rule_name, 1.)
                total_log_prob = total_log_prob.view(num_samples, -1)
                max_ind = total_log_prob.argmax(dim=0)
                # take argmax on num_sample x batch_size x 4 x 256 x 16
                sub_sample = sample[max_ind, th.arange(mean_pred.shape[0]), :, start_ind//8: end_ind//8]
                sub_samples.append(sub_sample)
            sample = th.concat(sub_samples, dim=-2)

        if record:
            for rule_name, loss in each_loss.items():
                current_loss = loss[max_ind, th.arange(mean_pred.shape[0])][0].item()
                self.each_loss[rule_name].append((t[0].item(), current_loss))
            max_log_prob = total_log_prob[max_ind, th.arange(mean_pred.shape[0])][0].item()
            # record log_prob
            self.log_probs.append((t[0].item(), max_log_prob))
            # record loss std
            self.loss_std.append((t[0].item(), total_log_prob.std().item()))
            # record loss range
            self.loss_range.append((t[0].item(), (max_log_prob - total_log_prob.min()).abs().item()))
            # record gradient difference
            noise = noise.view(num_samples, *mean_pred.shape)
            gradient = noise[max_ind, th.arange(mean_pred.shape[0])]
            cur_grad_diff = (self.prev_gradient_single - gradient).reshape(sample.shape[0], -1).norm(dim=-1)
            prev_gradient_norm = self.prev_gradient_single.reshape(sample.shape[0], -1).norm(dim=-1)
            if prev_gradient_norm.mean() > 1e-5:
                self.grad_norm.append(prev_gradient_norm.mean().item())
                cur_grad_diff = cur_grad_diff / prev_gradient_norm
                self.gradient_diff.append(cur_grad_diff.mean().item())
            self.prev_gradient_single = gradient
            if (t[0] + 1) % record_freq == 0:
                pred_xstart = pred_xstart.view(num_samples, -1, *pred_xstart.shape[1:])
                pred_xstart = pred_xstart[max_ind, th.arange(mean_pred.shape[0])]
                pred_xstart[pred_xstart <= -0.95] = -1.  # heuristically threshold the background
                pred_xstart = ((pred_xstart + 1) * 63.5).clamp(0, 127).to(th.uint8)
                self.inter_piano_rolls.append(pred_xstart.cpu())

            # plot the loss distribution
            if len(model_kwargs["rule"].keys()) <= 1:
                plt.figure(figsize=(4, 3))
                total_log_prob = total_log_prob.view(-1).cpu()
                plt.bar(range(len(total_log_prob)), -total_log_prob)
                plt.xlabel('choice')
                plt.ylabel('loss')
                plt.title(f't={t[0]+1}')
                plt.tight_layout()
                plt.savefig(f'loggings/debug/t={t[0]+1}.png')
                plt.show()
        return sample
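    # Editorial note (not in the upstream file): scg_sample is best-of-N search
    # over the stochastic reverse transition. Schematically, per step t:
    #     candidates  x_{t-1}^(i) = mean_pred + g_coeff * eps_i,  i = 1..num_samples
    #     score_i     = sum_r  w_r * log p_r( y_r | x0_hat(x_{t-1}^(i)) )
    #     keep        argmax_i score_i   (per batch element)
    # where each rule r scores the decoded x_0 prediction against its target.
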
    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        embed_model=None,
        scale_factor=1.,
        guidance_kwargs=None,
        scg_kwargs=None,
        edit_kwargs=None,
        record=False,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.

        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        if guidance_kwargs is not None:
            if guidance_kwargs.schedule:
                t_start = guidance_kwargs.t_start
                t_end = guidance_kwargs.t_end
                interval = guidance_kwargs.interval
                use_guidance = guide_schedule(t, t_start, t_end, interval)
            else:
                use_guidance = True
        else:
            use_guidance = False
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
            cond_fn=cond_fn,
            embed_model=embed_model,
            edit_kwargs=edit_kwargs,
        )

        # if scg guidance is used, the schedule only applies to scg sampling
        if cond_fn is not None and (use_guidance or scg_kwargs is not None):
            out["mean"] = self.condition_mean(
                cond_fn, out, x, t, model_kwargs=model_kwargs,
                guidance_kwargs=guidance_kwargs, model=model, embed_model=embed_model,
                edit_kwargs=edit_kwargs, scale_factor=scale_factor
            )

        if scg_kwargs is None:
            noise = th.randn_like(x)
            nonzero_mask = (
                (t > self.t_end).float().view(-1, *([1] * (len(x.shape) - 1)))
            )  # no noise when t == t_end (0 if not early stopping)
            sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise

        else:  # scg search (greedy)
            if t[0] > self.t_end:
                mean_pred = out["mean"]
                g_coeff = th.exp(0.5 * out["log_variance"])
                if use_guidance:
                    dc_kwargs = getattr(guidance_kwargs, 'dc', None)
                    sample = self.scg_sample(model, t, mean_pred, g_coeff, embed_model, scale_factor,
                                             model_kwargs=model_kwargs, scg_kwargs=scg_kwargs,
                                             edit_kwargs=edit_kwargs, dc_kwargs=dc_kwargs, record=record)
                else:
                    sample = mean_pred + g_coeff * th.randn_like(x)
                    if record:
                        eps = model(sample, self._scale_timesteps(t), **model_kwargs)
                        pred_xstart = self._predict_xstart_from_eps(sample, t, eps)
                        pred_xstart = _decode(pred_xstart, embed_model, scale_factor=scale_factor)
                        if len(model_kwargs["rule"].keys()) <= 1:
                            # only record for individual rules to save time
                            total_log_prob = 0
                            for rule_name, rule_target in model_kwargs["rule"].items():
                                gen_rule = _extract_rule(rule_name, pred_xstart)
                                log_prob = - LOSS_DICT[rule_name](gen_rule, rule_target)
                                total_log_prob += log_prob.mean().item() * scg_kwargs.get(rule_name, 1.)
                            self.log_probs.append((t[0].item(), total_log_prob))
                            if (t[0] + 1) % 100 == 0:
                                pred_xstart[pred_xstart <= -0.95] = -1.  # heuristically threshold the background
                                pred_xstart = ((pred_xstart + 1) * 63.5).clamp(0, 127).to(th.uint8)
                                self.inter_piano_rolls.append(pred_xstart.cpu())
            else:
                sample = out["mean"]

        return {"sample": sample, "pred_xstart": out["pred_xstart"]}
    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        t_end=0,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        embed_model=None,
        scale_factor=1.,
        guidance_kwargs=None,
        scg_kwargs=None,
        edit_kwargs=None,
        record=False,
    ):
        """
        Generate samples from the model.

        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param t_end: early stopping for the sampling process
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :return: a non-differentiable batch of samples.
        """
        final = None
        self.t_end = t_end
        if record:
            self.prev_gradient_single = th.zeros(shape, device=device)
            self.gradient_diff = []
            self.grad_norm = []
            self.log_probs = []
            # record loss for each rule
            self.each_loss = defaultdict(list)
            self.inter_piano_rolls = []
            self.loss_std = []
            self.loss_range = []
        for sample in self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            t_end=t_end,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            embed_model=embed_model,
            scale_factor=scale_factor,
            guidance_kwargs=guidance_kwargs,
            scg_kwargs=scg_kwargs,
            edit_kwargs=edit_kwargs,
            record=record,
        ):
            final = sample
        return final["sample"]
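    # Editorial sketch (not in the upstream file): a minimal ancestral-sampling
    # call, assuming an epsilon-prediction model on 4 x 256 x 16 latents:
    #     samples = diffusion.p_sample_loop(
    #         model,
    #         (batch_size, 4, 256, 16),
    #         model_kwargs={"y": class_labels},
    #         progress=True,
    #     )
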
    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        t_end=0,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        embed_model=None,
        scale_factor=1.,
        guidance_kwargs=None,
        scg_kwargs=None,
        edit_kwargs=None,
        record=False,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.

        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        elif edit_kwargs is not None:
            t = th.tensor([edit_kwargs["noise_level"]-1] * shape[0], device=device)
            alpha_cumprod = _extract_into_tensor(self.alphas_cumprod, t, shape)
            img = th.sqrt(alpha_cumprod) * edit_kwargs["gt"] + th.sqrt((1 - alpha_cumprod)) * th.randn(*shape, device=device)
        else:
            img = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]
        if t_end:
            indices = indices[:-t_end]
        if edit_kwargs is not None:
            t_start = self.num_timesteps - edit_kwargs["noise_level"]
            indices = indices[t_start:]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.p_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    embed_model=embed_model,
                    scale_factor=scale_factor,
                    guidance_kwargs=guidance_kwargs,
                    scg_kwargs=scg_kwargs,
                    edit_kwargs=edit_kwargs,
                    record=record,
                )
                yield out
                img = out["sample"]
    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
        embed_model=None,
        scale_factor=1.,
        guidance_kwargs=None,
        edit_kwargs=None,
        scg_kwargs=None,
        record=False,
    ):
        """
        Sample x_{t-1} from the model using DDIM.

        Same usage as p_sample().
        """
        if guidance_kwargs is not None:
            if guidance_kwargs.schedule:
                t_start = guidance_kwargs.t_start
                t_end = guidance_kwargs.t_end
                interval = guidance_kwargs.interval
                use_guidance = guide_schedule(t, t_start, t_end, interval)
            else:
                use_guidance = True
        else:
            use_guidance = False
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
            cond_fn=cond_fn,
            embed_model=embed_model,
            edit_kwargs=edit_kwargs,
        )
        if cond_fn is not None and use_guidance:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        if scg_kwargs is None:
            noise = th.randn_like(x)
            nonzero_mask = (
                (t != self.t_end).float().view(-1, *([1] * (len(x.shape) - 1)))
            )  # no noise when t == t_end (0 if not early stopping)
            sample = mean_pred + nonzero_mask * sigma * noise
        else:
            if t[0] > self.t_end:
                g_coeff = sigma
                if use_guidance:  # tune according to ddim steps
                    dc_kwargs = getattr(guidance_kwargs, 'dc', None)
                    sample = self.scg_sample(self._wrap_model(model), t, mean_pred, g_coeff, embed_model, scale_factor,
                                             model_kwargs=model_kwargs, scg_kwargs=scg_kwargs, edit_kwargs=edit_kwargs,
                                             dc_kwargs=dc_kwargs, record=record, record_freq=10)
                else:
                    sample = mean_pred + g_coeff * th.randn_like(x)
                    if record:
                        eps = self._wrap_model(model)(sample, self._scale_timesteps(t), **model_kwargs)
                        pred_xstart = self._predict_xstart_from_eps(sample, t, eps)
                        pred_xstart = _decode(pred_xstart, embed_model, scale_factor=scale_factor)
                        total_log_prob = 0
                        for rule_name, rule_target in model_kwargs["rule"].items():
                            gen_rule = _extract_rule(rule_name, pred_xstart)
                            log_prob = - LOSS_DICT[rule_name](gen_rule, rule_target)
                            total_log_prob += log_prob.mean().item() * scg_kwargs.get(rule_name, 1.)
                        self.log_probs.append((t[0].item(), total_log_prob))

                        if (t[0] + 1) % 10 == 0:
                            pred_xstart[pred_xstart <= -0.95] = -1.  # heuristically threshold the background
                            pred_xstart = ((pred_xstart + 1) * 63.5).clamp(0, 127).to(th.uint8)
                            self.inter_piano_rolls.append(pred_xstart.cpu())
            else:
                sample = mean_pred
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}
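    # Editorial note (not in the upstream file): in the update above, eta = 0
    # gives the deterministic DDIM sampler, while eta = 1 makes sigma equal to
    # the ancestral posterior std, recovering DDPM-like stochastic sampling
    # (see the "Equation 12" reference from Song et al (2020) in the comment).
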
    def ddim_reverse_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t+1} from the model using DDIM reverse ODE.
        """
        assert eta == 0.0, "Reverse ODE only for deterministic path"
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
            - out["pred_xstart"]
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

        # Equation 12. reversed
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_next)
            + th.sqrt(1 - alpha_bar_next) * eps
        )

        return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
    def ddim_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        t_end=0,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
        embed_model=None,
        scale_factor=1.,
        guidance_kwargs=None,
        scg_kwargs=None,
        edit_kwargs=None,
        record=False,
    ):
        """
        Generate samples from the model using DDIM.

        Same usage as p_sample_loop().
        """
        final = None
        self.t_end = t_end
        if record:
            self.prev_gradient_single = th.zeros(shape, device=device)
            self.gradient_diff = []
            self.grad_norm = []
            self.log_probs = []
            self.inter_piano_rolls = []
            self.loss_std = []
            self.loss_range = []
        for sample in self.ddim_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            t_end=t_end,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            eta=eta,
            embed_model=embed_model,
            scale_factor=scale_factor,
            guidance_kwargs=guidance_kwargs,
            scg_kwargs=scg_kwargs,
            edit_kwargs=edit_kwargs,
            record=record,
        ):
            final = sample
        return final["sample"]
    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        t_end=0,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
        embed_model=None,
        scale_factor=1.,
        guidance_kwargs=None,
        scg_kwargs=None,
        edit_kwargs=None,
        record=False,
    ):
        """
        Use DDIM to sample from the model and yield intermediate samples from
        each timestep of DDIM.

        Same usage as p_sample_loop_progressive().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        elif edit_kwargs is not None:
            t = th.tensor([edit_kwargs["noise_level"]-1] * shape[0], device=device)
            alpha_cumprod = _extract_into_tensor(self.alphas_cumprod, t, shape)
            img = th.sqrt(alpha_cumprod) * edit_kwargs["gt"] + th.sqrt((1 - alpha_cumprod)) * th.randn(*shape, device=device)
        else:
            img = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]
        if t_end:
            indices = indices[:-t_end]
        if edit_kwargs is not None:
            t_start = self.num_timesteps - edit_kwargs["noise_level"]
            indices = indices[t_start:]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.ddim_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                    embed_model=embed_model,
                    scale_factor=scale_factor,
                    guidance_kwargs=guidance_kwargs,
                    scg_kwargs=scg_kwargs,
                    edit_kwargs=edit_kwargs,
                    record=record,
                )
                yield out
                img = out["sample"]
    def _vb_terms_bpd(
        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
    ):
        """
        Get a term for the variational lower-bound.

        The resulting units are bits (rather than nats, as one might expect).
        This allows for comparison to other papers.

        :return: a dict with the following keys:
                 - 'output': a shape [N] tensor of NLLs or KLs.
                 - 'pred_xstart': the x_0 predictions.
        """
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        out = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        kl = normal_kl(
            true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
        )
        kl = mean_flat(kl) / np.log(2.0)

        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
        )
        assert decoder_nll.shape == x_start.shape
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        output = th.where((t == 0), decoder_nll, kl)
        return {"output": output, "pred_xstart": out["pred_xstart"]}

    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
        """
        Compute training losses for a single timestep.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise=noise)

        terms = {}

        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
            terms["loss"] = self._vb_terms_bpd(
                model=model,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )["output"]
            if self.loss_type == LossType.RESCALED_KL:
                terms["loss"] *= self.num_timesteps
        elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
            model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)

            if self.model_var_type in [
                ModelVarType.LEARNED,
                ModelVarType.LEARNED_RANGE,
            ]:
                B, C = x_t.shape[:2]
                assert model_output.shape == (B, C * 2, *x_t.shape[2:])
                model_output, model_var_values = th.split(model_output, C, dim=1)
                # Learn the variance using the variational bound, but don't let
                # it affect our mean prediction.
                frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
                terms["vb"] = self._vb_terms_bpd(
                    model=lambda *args, r=frozen_out: r,
                    x_start=x_start,
                    x_t=x_t,
                    t=t,
                    clip_denoised=False,
                )["output"]
                if self.loss_type == LossType.RESCALED_MSE:
                    # Divide by 1000 for equivalence with initial implementation.
                    # Without a factor of 1/1000, the VB term hurts the MSE term.
                    terms["vb"] *= self.num_timesteps / 1000.0

            target = {
                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                    x_start=x_start, x_t=x_t, t=t
                )[0],
                ModelMeanType.START_X: x_start,
                ModelMeanType.EPSILON: noise,
            }[self.model_mean_type]
            assert model_output.shape == target.shape == x_start.shape
            terms["mse"] = mean_flat((target - model_output) ** 2)
            if "vb" in terms:
                terms["loss"] = terms["mse"] + terms["vb"]
            else:
                terms["loss"] = terms["mse"]
        else:
            raise NotImplementedError(self.loss_type)

        return terms
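    # Editorial sketch (not in the upstream file): one training step with the
    # MSE objective, assuming an optimizer and an epsilon-prediction model:
    #     t = th.randint(0, diffusion.num_timesteps, (x_start.shape[0],), device=x_start.device)
    #     losses = diffusion.training_losses(model, x_start, t)
    #     losses["loss"].mean().backward()
    #     optimizer.step()
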
    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.

        This term can't be optimized, as it only depends on the encoder.

        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(
            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
        )
        return mean_flat(kl_prior) / np.log(2.0)

    def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
        """
        Compute the entire variational lower-bound, measured in bits-per-dim,
        as well as other related quantities.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param clip_denoised: if True, clip denoised samples.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.

        :return: a dict containing the following keys:
                 - total_bpd: the total variational lower-bound, per batch element.
                 - prior_bpd: the prior term in the lower-bound.
                 - vb: an [N x T] tensor of terms in the lower-bound.
                 - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
                 - mse: an [N x T] tensor of epsilon MSEs for each timestep.
        """
        device = x_start.device
        batch_size = x_start.shape[0]

        vb = []
        xstart_mse = []
        mse = []
        for t in list(range(self.num_timesteps))[::-1]:
            t_batch = th.tensor([t] * batch_size, device=device)
            noise = th.randn_like(x_start)
            x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
            # Calculate VLB term at the current timestep
            with th.no_grad():
                out = self._vb_terms_bpd(
                    model,
                    x_start=x_start,
                    x_t=x_t,
                    t=t_batch,
                    clip_denoised=clip_denoised,
                    model_kwargs=model_kwargs,
                )
            vb.append(out["output"])
            xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
            eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
            mse.append(mean_flat((eps - noise) ** 2))

        vb = th.stack(vb, dim=1)
        xstart_mse = th.stack(xstart_mse, dim=1)
        mse = th.stack(mse, dim=1)

        prior_bpd = self._prior_bpd(x_start)
        total_bpd = vb.sum(dim=1) + prior_bpd
        return {
            "total_bpd": total_bpd,
            "prior_bpd": prior_bpd,
            "vb": vb,
            "xstart_mse": xstart_mse,
            "mse": mse,
        }
def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)
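# Editorial sketch (not part of the upstream file): _extract_into_tensor picks
# one scalar per batch element and reshapes it for broadcasting, e.g.:
#     arr = np.linspace(0.0, 1.0, 1000)
#     t = th.tensor([0, 499, 999])
#     out = _extract_into_tensor(arr, t, (3, 4, 256, 16))  # shape [3, 4, 256, 16]
#     # out[i] is filled with arr[t[i]] everywhere.
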
def _decode(pred_zstart, embed_model, scale_factor=1., threshold=False):
    image_size_h = pred_zstart.shape[-2]
    image_size_w = pred_zstart.shape[-1]
    pred_zstart = pred_zstart / scale_factor
    sample = pred_zstart.permute(0, 1, 3, 2)
    sample = th.chunk(sample, image_size_h // image_size_w, dim=-1)  # B x C x H x W
    sample = th.concat(sample, dim=0)  # 1st second for all batch, 2nd second for all batch, ...
    sample = embed_model.decode(sample)
    pred_xstart = th.concat(th.chunk(sample, image_size_h // image_size_w, dim=0), dim=-1)
    if threshold:
        pred_xstart[pred_xstart <= -0.95] = -1.  # heuristically threshold the background
    return pred_xstart


def _extract_rule(rule_name, pred_xstart):
    device = pred_xstart.device
    if 'chord' in rule_name:
        # Split the tensor batch into smaller batches
        num_processes = 4
        pred_xstart = pred_xstart.cpu()
        pred_xstart_split = th.chunk(pred_xstart, num_processes)
        # rule_func = partial(FUNC_DICT[rule_name], given_key="C major")  # todo: hard code key here
        rule_func = FUNC_DICT[rule_name]
        with multiprocessing.Pool(processes=num_processes) as pool:
            gen_rule = pool.map(rule_func, pred_xstart_split)
        # Combine results
        if len(gen_rule[0].shape) == 1:  # batch_size * branching_factor < 4
            gen_rule = [item.unsqueeze(dim=0) for item in gen_rule]
        gen_rule = th.concat(gen_rule, dim=0).to(device)

    else:
        gen_rule = FUNC_DICT[rule_name](pred_xstart)
    return gen_rule


def _encode(pred_xstart, embed_model, scale_factor=1.):
    image_size_h = pred_xstart.shape[-2]
    image_size_w = pred_xstart.shape[-1]
    seq_len = image_size_w // image_size_h
    micro = th.chunk(pred_xstart, seq_len, dim=-1)  # B x C x H x W
    micro = th.concat(micro, dim=0)  # 1st second for all batch, 2nd second for all batch, ...
    micro = embed_model.encode_save(micro, range_fix=False)
    if micro.shape[1] == 8:
        z, _ = th.chunk(micro, 2, dim=1)
    else:
        z = micro
    z = th.concat(th.chunk(z, seq_len, dim=0), dim=-1)
    z = z.permute(0, 1, 3, 2)
    return z * scale_factor


def guide_schedule(t, t_start=750, t_end=0, interval=1):
    flag = t_start > t[0] >= t_end and (t[0] + 1) % interval == 0
    return flag
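# ---------------------------------------------------------------------------
# Editorial end-of-file sketch (not part of the upstream file): how the pieces
# above compose for rule-guided generation, under assumed shapes and a
# hypothetical rule target:
#     diffusion = GaussianDiffusion(
#         betas=get_named_beta_schedule("linear", 1000),
#         model_mean_type=ModelMeanType.EPSILON,
#         model_var_type=ModelVarType.LEARNED_RANGE,
#         loss_type=LossType.RESCALED_MSE,
#     )
#     samples = diffusion.p_sample_loop(
#         model, (1, 4, 256, 16),
#         model_kwargs={"y": y, "rule": {"note_density": target}},
#         embed_model=vae, scale_factor=scale_factor,
#         scg_kwargs={"num_samples": 16},
#     )
# ---------------------------------------------------------------------------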
guided_diffusion/logger.py
ADDED
@@ -0,0 +1,521 @@
"""
Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
"""

import os
import sys
import shutil
import os.path as osp
import json
import time
import datetime
import tempfile
import warnings
from collections import defaultdict
from contextlib import contextmanager

import wandb

DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40

DISABLED = 50


class KVWriter(object):
    def writekvs(self, kvs):
        raise NotImplementedError


class SeqWriter(object):
    def writeseq(self, seq):
        raise NotImplementedError


class HumanOutputFormat(KVWriter, SeqWriter):
    def __init__(self, filename_or_file):
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, "wt")
            self.own_file = True
        else:
            assert hasattr(filename_or_file, "read"), (
                "expected file or str, got %s" % filename_or_file
            )
            self.file = filename_or_file
            self.own_file = False

    def writekvs(self, kvs):
        # Create strings for printing
        key2str = {}
        for (key, val) in sorted(kvs.items()):
            if hasattr(val, "__float__"):
                valstr = "%-8.3g" % val
            else:
                valstr = str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # Find max widths
        if len(key2str) == 0:
            print("WARNING: tried to write empty key-value dict")
            return
        else:
            keywidth = max(map(len, key2str.keys()))
            valwidth = max(map(len, key2str.values()))

        # Write out the data
        dashes = "-" * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
            lines.append(
                "| %s%s | %s%s |"
                % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
            )
        lines.append(dashes)
        self.file.write("\n".join(lines) + "\n")

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, s):
        maxlen = 30
        return s[: maxlen - 3] + "..." if len(s) > maxlen else s

    def writeseq(self, seq):
        seq = list(seq)
        for (i, elem) in enumerate(seq):
            self.file.write(elem)
            if i < len(seq) - 1:  # add space unless this is the last one
                self.file.write(" ")
        self.file.write("\n")
        self.file.flush()

    def close(self):
        if self.own_file:
            self.file.close()


class JSONOutputFormat(KVWriter):
    def __init__(self, filename):
        self.file = open(filename, "wt")

    def writekvs(self, kvs):
        for k, v in sorted(kvs.items()):
            if hasattr(v, "dtype"):
                kvs[k] = float(v)
        self.file.write(json.dumps(kvs) + "\n")
        self.file.flush()

    def close(self):
        self.file.close()


class CSVOutputFormat(KVWriter):
    def __init__(self, filename):
        self.file = open(filename, "w+t")
        self.keys = []
        self.sep = ","

    def writekvs(self, kvs):
        # Add our current row to the history
        extra_keys = list(kvs.keys() - self.keys)
        extra_keys.sort()
        if extra_keys:
            self.keys.extend(extra_keys)
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            for (i, k) in enumerate(self.keys):
                if i > 0:
                    self.file.write(",")
                self.file.write(k)
            self.file.write("\n")
            for line in lines[1:]:
                self.file.write(line[:-1])
                self.file.write(self.sep * len(extra_keys))
                self.file.write("\n")
        for (i, k) in enumerate(self.keys):
            if i > 0:
                self.file.write(",")
            v = kvs.get(k)
            if v is not None:
                self.file.write(str(v))
        self.file.write("\n")
        self.file.flush()

    def close(self):
        self.file.close()


class TensorBoardOutputFormat(KVWriter):
    """
    Dumps key/value pairs into TensorBoard's numeric format.
    """

    def __init__(self, dir):
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self.step = 1
        prefix = "events"
        path = osp.join(osp.abspath(dir), prefix)
        import tensorflow as tf
        from tensorflow.python import pywrap_tensorflow
        from tensorflow.core.util import event_pb2
        from tensorflow.python.util import compat

        self.tf = tf
        self.event_pb2 = event_pb2
        self.pywrap_tensorflow = pywrap_tensorflow
        self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))

    def writekvs(self, kvs):
        def summary_val(k, v):
            kwargs = {"tag": k, "simple_value": float(v)}
            return self.tf.Summary.Value(**kwargs)

        summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
        event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
        event.step = (
            self.step
        )  # is there any reason why you'd want to specify the step?
        self.writer.WriteEvent(event)
        self.writer.Flush()
        self.step += 1

    def close(self):
        if self.writer:
            self.writer.Close()
            self.writer = None


class WandbOutputFormat(KVWriter):
    def __init__(self, args):
        wandb.init(project=args.project, config=vars(args))

    def writekvs(self, kvs):
        step = int(kvs["step"])
        wandb.log(kvs, step=step)

    def close(self):
        pass


def make_output_format(format, ev_dir, args, log_suffix=""):
    os.makedirs(ev_dir, exist_ok=True)
    if format == "stdout":
        return HumanOutputFormat(sys.stdout)
    elif format == "log":
        return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
    elif format == "json":
        return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
    elif format == "csv":
        return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
    elif format == "tensorboard":
        return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
    elif format == "wandb":
        return WandbOutputFormat(args)
    else:
        raise ValueError("Unknown format specified: %s" % (format,))


# ================================================================
# API
# ================================================================


def logkv(key, val):
    """
    Log a value of some diagnostic
    Call this once for each diagnostic quantity, each iteration
    If called many times, last value will be used.
    """
    get_current().logkv(key, val)


def logkv_mean(key, val):
    """
    The same as logkv(), but if called many times, values averaged.
    """
    get_current().logkv_mean(key, val)


def logkvs(d):
    """
    Log a dictionary of key-value pairs
    """
    for (k, v) in d.items():
        logkv(k, v)


def dumpkvs():
    """
    Write all of the diagnostics from the current iteration
    """
    return get_current().dumpkvs()


def getkvs():
    return get_current().name2val


def log(*args, level=INFO):
    """
    Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
    """
    get_current().log(*args, level=level)


def debug(*args):
    log(*args, level=DEBUG)


def info(*args):
    log(*args, level=INFO)


def warn(*args):
    log(*args, level=WARN)


def error(*args):
    log(*args, level=ERROR)


def set_level(level):
    """
    Set logging threshold on current logger.
    """
    get_current().set_level(level)


def set_comm(comm):
    get_current().set_comm(comm)


def get_dir():
    """
    Get directory that log files are being written to.
    will be None if there is no output directory (i.e., if you didn't call start)
    """
    return get_current().get_dir()


record_tabular = logkv
dump_tabular = dumpkvs


@contextmanager
def profile_kv(scopename):
    logkey = "wait_" + scopename
    tstart = time.time()
    try:
        yield
    finally:
        get_current().name2val[logkey] += time.time() - tstart


def profile(n):
    """
    Usage:
    @profile("my_func")
    def my_func(): code
    """

    def decorator_with_name(func):
        def func_wrapper(*args, **kwargs):
            with profile_kv(n):
                return func(*args, **kwargs)

        return func_wrapper

    return decorator_with_name


# ================================================================
# Backend
# ================================================================


def get_current():
    if Logger.CURRENT is None:
        _configure_default_logger()

    return Logger.CURRENT


class Logger(object):
    DEFAULT = None  # A logger with no output files. (See right below class definition)
    # So that you can still log to the terminal without setting up any output files
    CURRENT = None  # Current logger being used by the free functions above

    def __init__(self, dir, output_formats, comm=None):
        self.name2val = defaultdict(float)  # values this iteration
        self.name2cnt = defaultdict(int)
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats
        self.comm = comm

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        oldval, cnt = self.name2val[key], self.name2cnt[key]
        self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
        self.name2cnt[key] = cnt + 1

    def dumpkvs(self):
        if self.comm is None:
            d = self.name2val
        else:
            d = mpi_weighted_mean(
                self.comm,
                {
                    name: (val, self.name2cnt.get(name, 1))
                    for (name, val) in self.name2val.items()
                },
            )
            if self.comm.rank != 0:
                d["dummy"] = 1  # so we don't get a warning about empty dict
        out = d.copy()  # Return the dict for unit testing purposes
        for fmt in self.output_formats:
            if isinstance(fmt, KVWriter):
                fmt.writekvs(d)
        self.name2val.clear()
        self.name2cnt.clear()
        return out

    def log(self, *args, level=INFO):
        if self.level <= level:
            self._do_log(args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        self.level = level

    def set_comm(self, comm):
        self.comm = comm

    def get_dir(self):
        return self.dir

    def close(self):
        for fmt in self.output_formats:
            fmt.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        for fmt in self.output_formats:
            if isinstance(fmt, SeqWriter):
                fmt.writeseq(map(str, args))


def get_rank_without_mpi_import():
    # check environment variables here instead of importing mpi4py
    # to avoid calling MPI_Init() when this module is imported
    for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
        if varname in os.environ:
            return int(os.environ[varname])
    return 0


def mpi_weighted_mean(comm, local_name2valcount):
    """
    Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
    Perform a weighted average over dicts that are each on a different node
    Input: local_name2valcount: dict mapping key -> (value, count)
    Returns: key -> mean
    """
    all_name2valcount = comm.gather(local_name2valcount)
    if comm.rank == 0:
        name2sum = defaultdict(float)
        name2count = defaultdict(float)
        for n2vc in all_name2valcount:
            for (name, (val, count)) in n2vc.items():
                try:
                    val = float(val)
                except ValueError:
                    if comm.rank == 0:
                        warnings.warn(
                            "WARNING: tried to compute mean on non-float {}={}".format(
                                name, val
                            )
                        )
                else:
                    name2sum[name] += val * count
                    name2count[name] += count
        return {name: name2sum[name] / name2count[name] for name in name2sum}
    else:
        return {}


def configure(args=None, format_strs=None, comm=None, log_suffix=""):
    """
    If comm is provided, average all numerical stats across that comm
    """
    dir = args.dir
    if dir is not None:
        if "loggings" not in dir:  # save under cur dir
            dir = osp.join("loggings", dir)
    else:
        if dir is None:
            dir = os.getenv("OPENAI_LOGDIR")
        if dir is None:
            dir = osp.join(
                # tempfile.gettempdir(),
                "loggings",
                datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
            )
    assert isinstance(dir, str)
    dir = os.path.expanduser(dir)
    os.makedirs(os.path.expanduser(dir), exist_ok=True)
    if args.training:
        # make dir for samples and checkpoints if training the model
        os.makedirs(os.path.expanduser(osp.join(dir, "samples")), exist_ok=True)
        os.makedirs(os.path.expanduser(osp.join(dir, "checkpoints")), exist_ok=True)

    rank = get_rank_without_mpi_import()
    if rank > 0:
        log_suffix = log_suffix + "-rank%03i" % rank

    if format_strs is None:
        if rank == 0:
            format_strs = os.getenv("OPENAI_LOG_FORMAT", "wandb,stdout,log,csv").split(",")
        else:
            format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
    format_strs = filter(None, format_strs)
    output_formats = [make_output_format(f, dir, args, log_suffix) for f in format_strs]

    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
    if output_formats:
        log("Logging to %s" % dir)


def _configure_default_logger():
    configure()
    Logger.DEFAULT = Logger.CURRENT


def reset():
    if Logger.CURRENT is not Logger.DEFAULT:
        Logger.CURRENT.close()
        Logger.CURRENT = Logger.DEFAULT
        log("Reset logger")


@contextmanager
def scoped_configure(dir=None, format_strs=None, comm=None):
    # NOTE: stale relative to configure() above, which now takes `args` rather
    # than `dir`; calling this helper as-is would raise a TypeError.
    prevlogger = Logger.CURRENT
    configure(dir=dir, format_strs=format_strs, comm=comm)
    try:
        yield
    finally:
        Logger.CURRENT.close()
        Logger.CURRENT = prevlogger
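A minimal usage sketch for this logger, assuming an `args` namespace carrying the fields `configure` reads (`dir` and `training`; `project` is only needed for the `wandb` format, which is skipped here):

```python
from types import SimpleNamespace
from guided_diffusion import logger

args = SimpleNamespace(dir="demo_run", training=False)  # assumed fields read by configure
logger.configure(args=args, format_strs=["stdout", "csv"])
for step in range(3):
    logger.logkv("step", step)
    logger.logkv_mean("loss", 0.1 * step)  # averaged if logged several times per iteration
    logger.dumpkvs()  # flush this iteration's key/values to every writer
```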
guided_diffusion/losses.py
ADDED
@@ -0,0 +1,77 @@
"""
Helpers for various likelihood-based losses. These are ported from the original
Ho et al. diffusion models codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
"""

import numpy as np

import torch as th


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.

    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, th.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for th.exp().
    logvar1, logvar2 = [
        x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    )


def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.

    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = th.where(
        x < -0.999,
        log_cdf_plus,
        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs
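A quick numerical check of `normal_kl`: for two identical standard normals the divergence is zero, and shifting one mean by 1 gives 0.5 nats, matching the closed form 0.5 * (−1 + logσ₂² − logσ₁² + σ₁²/σ₂² + (μ₁ − μ₂)²/σ₂²):

```python
import torch as th

print(normal_kl(th.tensor(0.0), th.tensor(0.0), 0.0, 0.0))  # tensor(0.)
print(normal_kl(th.tensor(0.0), th.tensor(0.0), 1.0, 0.0))  # tensor(0.5000)
```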
guided_diffusion/midi_util.py
ADDED
@@ -0,0 +1,291 @@
import os
import math
import torch
import numpy as np
import pandas as pd
import pretty_midi
import matplotlib as mpl
import matplotlib.pyplot as plt
from . import dist_util
import yaml
from types import SimpleNamespace
from music_rule_guidance.piano_roll_to_chord import piano_roll_to_pretty_midi, KEY_DICT, IND2KEY
from music_rule_guidance.rule_maps import FUNC_DICT, LOSS_DICT
from music_rule_guidance.music_rules import MAX_PIANO, MIN_PIANO

plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300

# MIDI control-change number for the sustain pedal (CC 64); referenced by
# get_full_piano_roll below but not defined elsewhere in this file
CC_SUSTAIN_PEDAL = 64

# bounds to compute classes for nd editing
VERTICAL_ND_BOUNDS = [1.29, 2.7578125, 3.61, 4.4921875, 5.28125, 6.1171875, 7.22]
VERTICAL_ND_CENTER = [0.56, 2.0239, 3.1839, 4.0511, 4.8867, 5.6992, 6.6686, 7.77]
HORIZONTAL_ND_BOUNDS = [1.8, 2.6, 3.2, 3.6, 4.4, 4.8, 5.8]
HORIZONTAL_ND_CENTER = [1.4, 2.2000, 2.9, 3.4, 4.0, 4.6, 5.3, 6.3]


def dict_to_obj(d):
    if isinstance(d, list):
        d = [dict_to_obj(x) if isinstance(x, dict) else x for x in d]
    if not isinstance(d, dict):
        return d
    return SimpleNamespace(**{k: dict_to_obj(v) for k, v in d.items()})


def load_config(filename):
    with open(filename, 'r') as file:
        data = yaml.safe_load(file)
    # Convert the dictionary to an object
    data_obj = dict_to_obj(data)
    return data_obj


@torch.no_grad()
def decode_sample_for_midi(sample, embed_model=None, scale_factor=1., threshold=-0.95):
    # decode latent samples to a long piano roll of [0, 127]
    sample = sample / scale_factor

    if embed_model is not None:
        image_size_h = sample.shape[-2]
        image_size_w = sample.shape[-1]
        if image_size_h > image_size_w:  # transposed for raster col, don't need to permute for pixel space
            sample = sample.permute(0, 1, 3, 2)  # vertical axis means pitch after transpose
        num_latents = sample.shape[-1] // sample.shape[-2]
        if image_size_h >= image_size_w:
            sample = torch.chunk(sample, num_latents, dim=-1)  # B x C x H x W
            sample = torch.concat(sample, dim=0)  # 1st second for all batch, 2nd second for all batch, ...
        sample = embed_model.decode(sample)
        if image_size_h >= image_size_w:
            sample = torch.concat(torch.chunk(sample, num_latents, dim=0), dim=-1)

    sample[sample <= threshold] = -1.  # heuristic thresholding the background
    sample = ((sample + 1) * 63.5).clamp(0, 127).to(torch.uint8)
    sample = sample.permute(0, 2, 3, 1)
    sample = sample.contiguous()
    return sample


def save_piano_roll_midi(sample, save_dir, fs=100, y=None, save_piano_roll=False, save_ind=0):
    # input shape: B x 128 (pitch) x time (no pedal) or B x 2 (pedal) x 128 x time (with pedal)
    fig_size = sample.shape[-1] // 128 * 3
    plt.rcParams["figure.figsize"] = (fig_size, 3)
    pedal = True if len(sample.shape) == 4 else False
    onset = True if sample.shape[1] == 3 else False
    for i in range(sample.shape[0]):
        cur_sample = sample[i]
        if cur_sample.shape[-1] < 5000 and save_piano_roll:  # do not save piano rolls that are too long
            if pedal:
                plt.imshow(cur_sample[0, ::-1], vmin=0, vmax=127)
            else:
                plt.imshow(cur_sample[::-1], vmin=0, vmax=127)
            plt.savefig(os.path.join(save_dir, "prsample_" + str(i) + ".png"))
        if onset:
            # add onset for first column
            first_column = cur_sample[0, :, 0]
            first_onset_pitch = first_column.nonzero()[0]
            cur_sample[1, first_onset_pitch, 0] = 127
        cur_sample = cur_sample.astype(np.float32)
        pm = piano_roll_to_pretty_midi(cur_sample, fs=fs)
        if y is not None:
            save_name = 'sample_' + str(i + save_ind) + '_y_' + str(y[i].item()) + '.midi'
        else:
            save_name = 'sample_' + str(i + save_ind) + '.midi'
        pm.write(os.path.join(save_dir, save_name))
    return


def eval_rule_loss(generated_samples, target_rules):
    results = {}
    batch_size = generated_samples.shape[0]
    for rule_name, rule_target in target_rules.items():
        rule_target_list = rule_target.tolist()
        if batch_size == 1:
            rule_target_list = [rule_target_list]
        results[rule_name + '.target_rule'] = rule_target_list
        rule_target = rule_target.to(generated_samples.device)
        if 'chord' in rule_name:
            gen_rule, key, corr = FUNC_DICT[rule_name](generated_samples, return_key=True)
            key_strings = [IND2KEY[key_ind] for key_ind in key]
            loss = LOSS_DICT[rule_name](gen_rule, rule_target)
            mean_loss, std_loss, gen_rule, loss = loss.mean(), loss.std(), gen_rule.tolist(), loss.tolist()
            if batch_size == 1:
                gen_rule = [gen_rule]
            results[rule_name + '.gen_rule'] = gen_rule
            results[rule_name + '.key_str'] = key_strings
            results[rule_name + '.key_corr'] = corr
            results[rule_name + '.loss'] = loss
        else:
            gen_rule = FUNC_DICT[rule_name](generated_samples)
            loss = LOSS_DICT[rule_name](gen_rule, rule_target)
            mean_loss, std_loss, gen_rule, loss = loss.mean(), loss.std(), gen_rule.tolist(), loss.tolist()
            if batch_size == 1:
                gen_rule = [gen_rule]
            results[rule_name + '.gen_rule'] = gen_rule
            results[rule_name + '.loss'] = loss
    return pd.DataFrame(results)


def compute_rule(generated_samples, orig_samples, target_rules):
    results = {}
    batch_size = generated_samples.shape[0]
    for rule_name in target_rules:
        rule_target = FUNC_DICT[rule_name](orig_samples)
        rule_target_list = rule_target.tolist()
        if batch_size == 1:
            rule_target_list = [rule_target_list]
        results[rule_name + '.target_rule'] = rule_target_list
        rule_target = rule_target.to(generated_samples.device)
        if rule_name == 'chord_progression':
            gen_rule, key, corr = FUNC_DICT[rule_name](generated_samples, return_key=True)
            key_strings = [IND2KEY[key_ind] for key_ind in key]
            loss = LOSS_DICT[rule_name](gen_rule, rule_target)
            mean_loss, std_loss, gen_rule, loss = loss.mean(), loss.std(), gen_rule.tolist(), loss.tolist()
            if batch_size == 1:
                gen_rule = [gen_rule]
            results[rule_name + '.gen_rule'] = gen_rule
            results[rule_name + '.key_str'] = key_strings
            results[rule_name + '.key_corr'] = corr
            results[rule_name + '.loss'] = loss
        else:
            gen_rule = FUNC_DICT[rule_name](generated_samples)
            loss = LOSS_DICT[rule_name](gen_rule, rule_target)
            mean_loss, std_loss, gen_rule, loss = loss.mean(), loss.std(), gen_rule.tolist(), loss.tolist()
            if batch_size == 1:
                gen_rule = [gen_rule]
            results[rule_name + '.gen_rule'] = gen_rule
            results[rule_name + '.loss'] = loss
    return pd.DataFrame(results)


def visualize_piano_roll(piano_roll):
    """
    Assuming piano roll has shape Bx1x128x1024, and the values are between [-1, 1], on gpu.
    Visualize with some gap in between the first 256 and the last 256 columns.
    """
    piano_roll = torch.flip(piano_roll, [2])
    piano_roll = (piano_roll + 1) / 2.
    vis_length = 256
    gap = 80
    plt.rcParams["figure.figsize"] = (12, 3)
    data = torch.zeros(128, vis_length * 2 + gap)
    data[:, :vis_length] = piano_roll[0, 0, :, :vis_length]
    data[:, -vis_length:] = piano_roll[0, 0, :, -vis_length:]
    data_clone = data.clone()
    # make it look thicker
    data[1:, :] = data[1:, :] + data_clone[:-1, :]
    data[2:, :] = data[2:, :] + data_clone[:-2, :]
    data = data.cpu().numpy()
    plt.imshow(data, cmap=mpl.colormaps['Blues'])
    ax = plt.gca()  # gca stands for 'get current axis'
    for edge, spine in ax.spines.items():
        spine.set_linewidth(2)  # Adjust the value as per your requirement
    plt.grid(color='gray', linestyle='-', linewidth=2., alpha=0.5, which='both', axis='x')
    plt.xticks(
        np.concatenate((np.arange(0, vis_length + 1, 128), np.arange(vis_length + gap, vis_length * 2 + gap, 128))))
    # plt.savefig('piano_roll_example.png', bbox_inches='tight', pad_inches=0.1, dpi=300)
    plt.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)
    plt.tight_layout()
    plt.show()

    plt.rcParams["figure.figsize"] = (3, 3)
    for i in range(2):
        plt.imshow(data[:, i*128: (i+1)*128], cmap=mpl.colormaps['Blues'])
        ax = plt.gca()
        for edge, spine in ax.spines.items():
            spine.set_linewidth(2)
        plt.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)
        plt.tight_layout()
        plt.show()

    for i in range(-2, 0):
        if (i+1)*128 < 0:
            plt.imshow(data[:, i*128: (i+1)*128], cmap=mpl.colormaps['Blues'])
        else:
            plt.imshow(data[:, i*128:], cmap=mpl.colormaps['Blues'])
        ax = plt.gca()
        for edge, spine in ax.spines.items():
            spine.set_linewidth(2)
        plt.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)
        plt.tight_layout()
        plt.show()

    return


def visualize_full_piano_roll(midi_file_name, fs=100):
    """
    Visualize full piano roll from midi file
    """
    midi_data = pretty_midi.PrettyMIDI(midi_file_name)
    # do not process sustain pedal
    piano_roll = torch.tensor(midi_data.get_piano_roll(fs=fs, pedal_threshold=None))
    data = torch.flip(piano_roll, [0])
    plt.rcParams["figure.figsize"] = (12, 3)
    # data_clone = data.clone()
    # # make it look thicker
    # data[1:, :] = data[1:, :] + data_clone[:-1, :]
    # data[2:, :] = data[2:, :] + data_clone[:-2, :]
    data = data.cpu().numpy()
    plt.imshow(data, cmap=mpl.colormaps['Blues'])
    ax = plt.gca()  # gca stands for 'get current axis'
    for edge, spine in ax.spines.items():
        spine.set_linewidth(2)  # Adjust the value as per your requirement
    plt.grid(color='gray', linestyle='-', linewidth=2., alpha=0.5, which='both', axis='x')
    plt.xticks(np.arange(0, piano_roll.shape[1], 128))
    # plt.savefig('piano_roll_example.png', bbox_inches='tight', pad_inches=0.1, dpi=300)
    plt.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)
    plt.tight_layout()
    plt.show()
    return


def plot_record(vals, title, save_dir):
    ts = [item[0] for item in vals]
    log_probs = [item[1] for item in vals]
    plt.plot(ts, log_probs)
    plt.gca().invert_xaxis()
    plt.title(title)
    plt.savefig(save_dir + '/' + title + '.png')
    plt.show()
    return


def quantize_pedal(value, num_bins=8):
    """Quantize an integer value from 0 to 127 into 8 bins and return the center value of the bin."""
    if value < 0 or value > 127:
        raise ValueError("Value should be between 0 and 127")
    # Determine bin size
    bin_size = 128 // num_bins  # 16
    # Quantize the value
    bin_index = value // bin_size
    bin_center = bin_size * bin_index + bin_size // 2
    # Handle edge case for the last bin
    if bin_center > 127:
        bin_center = 127
    return bin_center


def get_full_piano_roll(midi_data, fs, show=False):
    # do not process sustain pedal
    piano_roll, onset_roll = midi_data.get_piano_roll(fs=fs, pedal_threshold=None, onset=True)
    # save pedal roll explicitly
    pedal_roll = np.zeros_like(piano_roll)
    # process pedal
    for instru in midi_data.instruments:
        pedal_changes = [_e for _e in instru.control_changes if _e.number == CC_SUSTAIN_PEDAL]
        for cc in pedal_changes:
            time_now = int(cc.time * fs)
            if time_now < pedal_roll.shape[-1]:
                # need to distinguish control_change 0 and background 0, with quantize 0-16 will be 8
                # in MuseScore files, 0 immediately followed by 127, need to shift by one column
                if pedal_roll[MIN_PIANO, time_now] != 0. and abs(pedal_roll[MIN_PIANO, time_now] - cc.value) > 64:
                    # use shift 2 here to prevent missing change when using interpolation augmentation
                    pedal_roll[MIN_PIANO:MAX_PIANO + 1, min(time_now + 2, pedal_roll.shape[-1] - 1)] = quantize_pedal(cc.value)
                else:
                    pedal_roll[MIN_PIANO:MAX_PIANO + 1, time_now] = quantize_pedal(cc.value)
    full_roll = np.concatenate((piano_roll[None], onset_roll[None], pedal_roll[None]), axis=0)
    if show:
        plt.imshow(piano_roll[::-1, :1024], vmin=0, vmax=127)
        plt.show()
        plt.imshow(pedal_roll[::-1, :1024], vmin=0, vmax=127)
        plt.show()
    return full_roll
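For concreteness, `quantize_pedal` with the default `num_bins=8` maps CC values into bins of width 16 and returns the bin centers 8, 24, ..., 120:

```python
for v in (0, 15, 16, 64, 127):
    print(v, "->", quantize_pedal(v))
# 0 -> 8, 15 -> 8, 16 -> 24, 64 -> 72, 127 -> 120
```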
guided_diffusion/nn.py
ADDED
@@ -0,0 +1,170 @@
"""
Various utilities for neural networks.
"""

import math

import torch as th
import torch.nn as nn


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * th.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(th.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with th.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with th.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = th.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
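A small shape-and-value check for `timestep_embedding`: with `dim=4` there are two frequencies, cosines first, so `t = 0` maps to `[1, 1, 0, 0]`:

```python
import torch as th

emb = timestep_embedding(th.tensor([0, 10]), dim=4)
print(emb.shape)  # torch.Size([2, 4])
print(emb[0])     # tensor([1., 1., 0., 0.]) -- cos, cos, sin, sin at t = 0
```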
guided_diffusion/pr_datasets_all.py
ADDED
@@ -0,0 +1,183 @@
import math
import random
import os
import pandas as pd
import csv
import re
from PIL import Image
import blobfile as bf
from mpi4py import MPI
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from music_rule_guidance import music_rules
from music_rule_guidance.rule_maps import FUNC_DICT

import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (6,3)
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300

# This file loads the merged dataset, with y carrying its dataset info


def load_data(
    *,
    data_dir,
    batch_size,
    class_cond=False,
    deterministic=False,
    image_size=1024,
    rule=None,
):
    """
    For a dataset, create a generator over (images, kwargs) pairs.

    Each images is an NCHW float tensor, and the kwargs dict contains zero or
    more keys, each of which map to a batched Tensor of their own.
    The kwargs dict can be used for class labels, in which case the key is "y"
    and the values are integer tensors of class labels.

    :param data_dir: the csv file that contains all the data paths and classes.
    :param batch_size: the batch size of each returned pair.
    :param image_size: the size to which images are resized.
    :param class_cond: if True, include a "y" key in returned dicts for class
                       label. If classes are not available and this is true, an
                       exception will be raised.
    :param deterministic: if True, yield results in a deterministic order.
    :param rule: a str that contains the name of the rule
    """

    df = pd.read_csv(data_dir)
    all_files = df['midi_filename'].tolist()
    classes = None
    if class_cond:
        classes = df['classes'].tolist()
    if deterministic:
        dataset = ImageDataset(
            all_files,
            classes=classes,
            shard=MPI.COMM_WORLD.Get_rank(),
            num_shards=MPI.COMM_WORLD.Get_size(),
            image_size=image_size,
            rule=rule,
            pitch_shift=False,
            time_stretch=False,
        )
    else:
        dataset = ImageDataset(
            all_files,
            classes=classes,
            shard=MPI.COMM_WORLD.Get_rank(),
            num_shards=MPI.COMM_WORLD.Get_size(),
            image_size=image_size,
            rule=rule,
        )
    if deterministic:
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
        )
    else:
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
        )
    while True:
        yield from loader


def key_shift(x, k):
    # apply shift on both notes and onset
    # x sample (batch x 3 x pitch x time)
    # k number of pitches to shift
    # only apply on (batch x 2 x pitch x time) because no key shift on pedal

    pitches_and_onsets = x[:, :2, :, :]
    pedals = x[:, 2:, :, :]

    if k > 0:
        pitches_and_onsets = torch.cat((pitches_and_onsets[:, :, k:, :], pitches_and_onsets[:, :, 0:k, :]), dim=2)
    elif k < 0:
        pitches_and_onsets = torch.cat((pitches_and_onsets[:, :, -k:, :], pitches_and_onsets[:, :, 0:-k, :]), dim=2)

    x = torch.cat((pitches_and_onsets, pedals), dim=1)
    return music_rules.piano_like(x)


class ImageDataset(Dataset):
    def __init__(
        self,
        image_paths,
        classes=None,
        rule=None,
        shard=0,
        num_shards=1,
        image_size=1024,
        pitch_shift=True,
        time_stretch=True,
    ):
        super().__init__()
        self.local_images = image_paths[shard:][::num_shards]
        self.local_classes = None if classes is None else classes[shard:][::num_shards]
        self.rule = rule
        self.pitch_shift = pitch_shift
        self.time_stretch = time_stretch
        self.image_size = image_size

    def __len__(self):
        return len(self.local_images)

    def __getitem__(self, idx):
        path = self.local_images[idx]
        arr = np.load(path)[np.newaxis]  # 1 x 2 x 128 x time
        arr = arr.astype(np.float32) / 63.5 - 1
        arr = torch.from_numpy(arr)

        if self.time_stretch:  # apply for both notes and pedal
            pr_len = int(np.random.uniform(0.95, 1.05) * self.image_size)
            start = np.random.randint(arr.shape[-1] - pr_len)
            arr = arr[:, :, :, start:start+pr_len]
            if pr_len < self.image_size:  # stretching, prevent duplicating onsets
                piano_pedal = arr[:, [0, 2], :, :]
                piano_pedal = F.interpolate(piano_pedal, size=(128, self.image_size), mode='nearest')
                onset_raw = arr[:, 1:2, :, :]
                ind_a2b = (torch.arange(self.image_size)/self.image_size*pr_len).int()
                ind = ind_a2b.diff().nonzero().squeeze() + 1
                zero_tensor = torch.tensor([0])
                ind = torch.concat((zero_tensor, ind))
                onset = -torch.ones(1, 1, 128, self.image_size)
                onset[:, :, :, ind] = onset_raw
                arr = torch.concat((piano_pedal[:, :1, :, :], onset, piano_pedal[:, 1:, :, :]), dim=1)
            if pr_len > self.image_size:  # compressing, add onset if happen to drop onsets and keep durations
                arr = F.interpolate(arr, size=(128, self.image_size), mode='nearest')
                piano = arr[:, :1, :, :]
                first_column = piano[:, :, :, :1]
                padded_piano = torch.concat((first_column, piano), dim=-1)
                onset_online = torch.diff(padded_piano, dim=-1)
                mask = onset_online > 0
                arr[:, 1:2, :, :][mask] = 1
        else:
            arr = arr[:, :, :, :self.image_size]
        if self.pitch_shift:  # only apply for notes
            k = np.random.randint(-6, 7)  # generate randint from -6 to +6
            arr = key_shift(arr, k)

        arr = music_rules.piano_like(arr)  # also set pedal roll to be 0 for non-piano pitches (match VAE training)

        out_dict = {}
        if self.rule is not None:
            if 'chord' in self.rule:  # predict chord and key jointly
                chord, key, _ = FUNC_DICT[self.rule](arr, return_key=True)
                out_dict["chord"] = chord
                out_dict["key"] = np.array(key, dtype=np.int64)
            else:
                out_dict[self.rule] = FUNC_DICT[self.rule](arr)
        if self.local_classes is not None:
            out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
        # debug
        # out_dict["path"] = path
        # Remove the extra dimensions to get back a 3D tensor: 2x128x128
        arr = arr.squeeze(0)
        return arr, out_dict
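A usage sketch for the loader; the CSV path and rule name below are assumptions (the CSV must provide a `midi_filename` column pointing at the preprocessed `.npy` rolls, a `classes` column when `class_cond` is set, and the rule must be a key of `FUNC_DICT`):

```python
data = load_data(
    data_dir="datasets/all_midi.csv",  # assumed to hold midi_filename/classes columns
    batch_size=4,
    class_cond=True,
    image_size=1024,
    rule="note_density",  # assumed rule name; must exist in FUNC_DICT
)
batch, cond = next(data)  # piano-roll tensors in [-1, 1] plus "y" and the rule targets
```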
guided_diffusion/resample.py
ADDED
@@ -0,0 +1,154 @@
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch as th
|
5 |
+
import torch.distributed as dist
|
6 |
+
|
7 |
+
|
8 |
+
def create_named_schedule_sampler(name, diffusion):
|
9 |
+
"""
|
10 |
+
Create a ScheduleSampler from a library of pre-defined samplers.
|
11 |
+
|
12 |
+
:param name: the name of the sampler.
|
13 |
+
:param diffusion: the diffusion object to sample for.
|
14 |
+
"""
|
15 |
+
if name == "uniform":
|
16 |
+
return UniformSampler(diffusion)
|
17 |
+
elif name == "loss-second-moment":
|
18 |
+
return LossSecondMomentResampler(diffusion)
|
19 |
+
else:
|
20 |
+
raise NotImplementedError(f"unknown schedule sampler: {name}")
|
21 |
+
|
22 |
+
|
23 |
+
class ScheduleSampler(ABC):
|
24 |
+
"""
|
25 |
+
A distribution over timesteps in the diffusion process, intended to reduce
|
26 |
+
variance of the objective.
|
27 |
+
|
28 |
+
By default, samplers perform unbiased importance sampling, in which the
|
29 |
+
objective's mean is unchanged.
|
30 |
+
However, subclasses may override sample() to change how the resampled
|
31 |
+
terms are reweighted, allowing for actual changes in the objective.
|
32 |
+
"""
|
33 |
+
|
34 |
+
@abstractmethod
|
35 |
+
def weights(self):
|
36 |
+
"""
|
37 |
+
Get a numpy array of weights, one per diffusion step.
|
38 |
+
|
39 |
+
The weights needn't be normalized, but must be positive.
|
40 |
+
"""
|
41 |
+
|
42 |
+
def sample(self, batch_size, device):
|
43 |
+
"""
|
44 |
+
Importance-sample timesteps for a batch.
|
45 |
+
|
46 |
+
:param batch_size: the number of timesteps.
|
47 |
+
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights


class UniformSampler(ScheduleSampler):
    def __init__(self, diffusion):
        self.diffusion = diffusion
        self._weights = np.ones([diffusion.num_timesteps])

    def weights(self):
        return self._weights


class LossAwareSampler(ScheduleSampler):
    def update_with_local_losses(self, local_ts, local_losses):
        """
        Update the reweighting using losses from a model.

        Call this method from each rank with a batch of timesteps and the
        corresponding losses for each of those timesteps.
        This method will perform synchronization to make sure all of the ranks
        maintain the exact same reweighting.

        :param local_ts: an integer Tensor of timesteps.
        :param local_losses: a 1D Tensor of losses.
        """
        batch_sizes = [
            th.tensor([0], dtype=th.int32, device=local_ts.device)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(
            batch_sizes,
            th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
        )

        # Pad all_gather batches to be the maximum batch size.
        batch_sizes = [x.item() for x in batch_sizes]
        max_bs = max(batch_sizes)

        timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
        loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
        dist.all_gather(timestep_batches, local_ts)
        dist.all_gather(loss_batches, local_losses)
        timesteps = [
            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
        ]
        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
        self.update_with_all_losses(timesteps, losses)

    @abstractmethod
    def update_with_all_losses(self, ts, losses):
        """
        Update the reweighting using losses from a model.

        Sub-classes should override this method to update the reweighting
        using losses from the model.

        This method directly updates the reweighting without synchronizing
        between workers. It is called by update_with_local_losses from all
        ranks with identical arguments. Thus, it should have deterministic
        behavior to maintain state across workers.

        :param ts: a list of int timesteps.
        :param losses: a list of float losses, one per timestep.
        """


class LossSecondMomentResampler(LossAwareSampler):
    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        self.diffusion = diffusion
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)

    def weights(self):
        if not self._warmed_up():
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        weights /= np.sum(weights)
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # Shift out the oldest loss term.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        return (self._loss_counts == self.history_per_term).all()
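Note: a minimal usage sketch for the samplers above (illustrative, not part of this commit; `diffusion` is a placeholder for an object exposing `num_timesteps`, as built elsewhere in this repo):

# Hypothetical sketch: draw importance-sampled timesteps, then feed per-example
# losses back so the resampler can upweight timesteps with large loss second moments.
sampler = LossSecondMomentResampler(diffusion)
t, weights = sampler.sample(batch_size=8, device=th.device("cpu"))
fake_losses = th.rand(8)  # stand-in for real per-example diffusion losses
sampler.update_with_all_losses(t.tolist(), fake_losses.tolist())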
guided_diffusion/respace.py
ADDED
@@ -0,0 +1,128 @@
import numpy as np
import torch as th

from .gaussian_diffusion import GaussianDiffusion


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {num_timesteps} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)


class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)

    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)

    def _wrap_model(self, model):
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(
            model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
        )

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t


class _WrappedModel:
    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        if self.rescale_timesteps:
            new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
        return self.model(x, new_ts, **kwargs)
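Note: a small standalone sanity sketch of how `space_timesteps` behaves (illustrative, not part of this commit):

# Splitting 300 original steps into sections that retain 10, 15 and 20 steps.
steps = space_timesteps(300, [10, 15, 20])
assert len(steps) == 45
# DDIM-style fixed striding: keep 50 evenly strided steps out of 1000.
assert space_timesteps(1000, "ddim50") == set(range(0, 1000, 20))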
guided_diffusion/script_util.py
ADDED
@@ -0,0 +1,531 @@
import argparse
import inspect
import torch.nn.functional as F

from music_rule_guidance import music_rules
from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps
from .unet import SuperResModel, UNetModel, EncoderUNetModel

NUM_CLASSES = 3  # number of datasets


def diffusion_defaults():
    """
    Defaults for image and classifier training.
    """
    return dict(
        learn_sigma=False,
        diffusion_steps=1000,
        noise_schedule="linear",
        timestep_respacing="",
        use_kl=False,
        predict_xstart=False,
        rescale_timesteps=False,
        rescale_learned_sigmas=False,
    )


def classifier_defaults():
    """
    Defaults for classifier models.
    """
    return dict(
        image_size=64,
        in_channels=3,
        classifier_use_fp16=False,
        classifier_width=128,
        classifier_depth=2,
        classifier_attention_resolutions="32,16,8",  # 16
        classifier_use_scale_shift_norm=True,  # False
        classifier_resblock_updown=True,  # False
        classifier_pool="attention",
        num_classes=3,
        chord=False,
    )


def model_and_diffusion_image_defaults():
    """
    Defaults for image training.
    """
    res = dict(
        image_size=64,
        in_channels=3,
        num_channels=128,
        num_res_blocks=2,
        num_heads=4,
        num_heads_upsample=-1,
        num_head_channels=-1,
        attention_resolutions="32,16,8",
        channel_mult="",
        dropout=0.0,
        class_cond=False,
        use_checkpoint=False,
        use_scale_shift_norm=True,
        resblock_updown=False,
        use_fp16=False,
        use_new_attention_order=False,
    )
    res.update(diffusion_defaults())
    return res


def model_and_diffusion_defaults():
    """
    Defaults for piano roll training.
    """
    res = dict(
        image_size=128,
        in_channels=1,
        num_channels=128,
        num_res_blocks=2,
        num_heads=4,
        num_heads_upsample=-1,
        num_head_channels=-1,
        attention_resolutions="32,16,8",
        channel_mult="",
        dropout=0.0,
        class_cond=False,
        use_checkpoint=False,
        use_scale_shift_norm=True,
        resblock_updown=False,
        use_fp16=False,
        use_new_attention_order=False,
    )
    res.update(diffusion_defaults())
    return res


def classifier_and_diffusion_defaults():
    res = classifier_defaults()
    res.update(diffusion_defaults())
    return res


def create_diffusion(
    learn_sigma,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
):
    diffusion = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )
    return diffusion


def create_model_and_diffusion(
    image_size,
    in_channels,
    class_cond,
    learn_sigma,
    num_channels,
    num_res_blocks,
    channel_mult,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    use_checkpoint,
    use_scale_shift_norm,
    resblock_updown,
    use_fp16,
    use_new_attention_order,
):
    model = create_model(
        image_size,
        num_channels,
        num_res_blocks,
        in_channels,
        channel_mult=channel_mult,
        learn_sigma=learn_sigma,
        class_cond=class_cond,
        use_checkpoint=use_checkpoint,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
        use_new_attention_order=use_new_attention_order,
    )
    diffusion = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )
    return model, diffusion


def create_model(
    image_size,
    num_channels,
    num_res_blocks,
    in_channels=3,
    channel_mult="",
    learn_sigma=False,
    class_cond=False,
    use_checkpoint=False,
    attention_resolutions="16",
    num_heads=4,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=False,
    dropout=0,
    resblock_updown=False,
    use_fp16=False,
    use_new_attention_order=False,
):
    image_size = image_size[-1]  # if H != W, use W as image_size
    if channel_mult == "":
        if image_size == 512:
            channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
        elif image_size == 256:
            channel_mult = (1, 1, 2, 2, 4, 4)
        elif image_size == 128:
            channel_mult = (1, 1, 2, 3, 4)
        elif image_size == 64:
            channel_mult = (1, 2, 3, 4)
        elif image_size == 32:
            channel_mult = (1, 2, 2, 2)
        elif image_size == 16:
            channel_mult = (1, 2, 2)
        else:
            raise ValueError(f"unsupported image size: {image_size}")
    else:
        channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))

    attention_ds = []
    for res in attention_resolutions.split(","):
        attention_ds.append(image_size // int(res))

    return UNetModel(
        image_size=image_size,
        in_channels=in_channels,
        model_channels=num_channels,
        out_channels=(in_channels if not learn_sigma else 2 * in_channels),
        num_res_blocks=num_res_blocks,
        attention_resolutions=tuple(attention_ds),
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond else None),
        use_checkpoint=use_checkpoint,
        use_fp16=use_fp16,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_new_attention_order=use_new_attention_order,
    )


def create_classifier_and_diffusion(
    image_size,
    in_channels,
    classifier_use_fp16,
    classifier_width,
    classifier_depth,
    classifier_attention_resolutions,
    classifier_use_scale_shift_norm,
    classifier_resblock_updown,
    classifier_pool,
    learn_sigma,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    num_classes,
    chord,
):
    classifier = create_classifier(
        image_size,
        in_channels,
        classifier_use_fp16,
        classifier_width,
        classifier_depth,
        classifier_attention_resolutions,
        classifier_use_scale_shift_norm,
        classifier_resblock_updown,
        classifier_pool,
        num_classes,
        chord,
    )
    diffusion = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )
    return classifier, diffusion


def create_classifier(
    image_size,
    in_channels,
    classifier_use_fp16,
    classifier_width,
    classifier_depth,
    classifier_attention_resolutions,
    classifier_use_scale_shift_norm,
    classifier_resblock_updown,
    classifier_pool,
    num_classes,
    chord,
):
    image_size = image_size[-1]  # if H != W, use W as image_size
    if image_size == 512:
        channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
    elif image_size == 256:
        channel_mult = (1, 1, 2, 2, 4, 4)
    elif image_size == 128:
        channel_mult = (1, 1, 2, 3, 4)
    elif image_size == 64:
        channel_mult = (1, 2, 3, 4)
    elif image_size == 16:  # debug data load in
        channel_mult = (1, 2, 2)
    else:
        raise ValueError(f"unsupported image size: {image_size}")

    attention_ds = []
    for res in classifier_attention_resolutions.split(","):
        attention_ds.append(image_size // int(res))

    return EncoderUNetModel(
        image_size=image_size,
        in_channels=in_channels,
        model_channels=classifier_width,
        out_channels=num_classes,
        num_res_blocks=classifier_depth,
        attention_resolutions=tuple(attention_ds),
        channel_mult=channel_mult,
        use_fp16=classifier_use_fp16,
        num_head_channels=64,
        use_scale_shift_norm=classifier_use_scale_shift_norm,
        resblock_updown=classifier_resblock_updown,
        pool=classifier_pool,
        chord=chord,
    )


def sr_model_and_diffusion_defaults():
    res = model_and_diffusion_defaults()
    res["large_size"] = 256
    res["small_size"] = 64
    arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
    for k in res.copy().keys():
        if k not in arg_names:
            del res[k]
    return res


def sr_create_model_and_diffusion(
    large_size,
    small_size,
    class_cond,
    learn_sigma,
    num_channels,
    num_res_blocks,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    use_checkpoint,
    use_scale_shift_norm,
    resblock_updown,
    use_fp16,
):
    model = sr_create_model(
        large_size,
        small_size,
        num_channels,
        num_res_blocks,
        learn_sigma=learn_sigma,
        class_cond=class_cond,
        use_checkpoint=use_checkpoint,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
    )
    diffusion = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )
    return model, diffusion


def sr_create_model(
    large_size,
    small_size,
    num_channels,
    num_res_blocks,
    learn_sigma,
    class_cond,
    use_checkpoint,
    attention_resolutions,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    use_scale_shift_norm,
    dropout,
    resblock_updown,
    use_fp16,
):
    _ = small_size  # hack to prevent unused variable

    if large_size == 512:
        channel_mult = (1, 1, 2, 2, 4, 4)
    elif large_size == 256:
        channel_mult = (1, 1, 2, 2, 4, 4)
    elif large_size == 64:
        channel_mult = (1, 2, 3, 4)
    else:
        raise ValueError(f"unsupported large size: {large_size}")

    attention_ds = []
    for res in attention_resolutions.split(","):
        attention_ds.append(large_size // int(res))

    return SuperResModel(
        image_size=large_size,
        in_channels=3,
        model_channels=num_channels,
        out_channels=(3 if not learn_sigma else 6),
        num_res_blocks=num_res_blocks,
        attention_resolutions=tuple(attention_ds),
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond else None),
        use_checkpoint=use_checkpoint,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
    )


def create_gaussian_diffusion(
    *,
    steps=1000,
    learn_sigma=False,
    sigma_small=False,
    noise_schedule="linear",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
    timestep_respacing="",
):
    betas = gd.get_named_beta_schedule(noise_schedule, steps)
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE
    if not timestep_respacing:
        timestep_respacing = [steps]
    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
        ),
        model_var_type=(
            (
                gd.ModelVarType.FIXED_LARGE
                if not sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )
            if not learn_sigma
            else gd.ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
    )


def add_dict_to_argparser(parser, default_dict):
    for k, v in default_dict.items():
        v_type = type(v)
        if v is None:
            v_type = str
        elif isinstance(v, bool):
            v_type = str2bool
        if k == 'image_size':
            parser.add_argument(f"--{k}", nargs='+', default=v, type=v_type)
        else:
            parser.add_argument(f"--{k}", default=v, type=v_type)


def args_to_dict(args, keys):
    return {k: getattr(args, k) for k in keys}


def str2bool(v):
    """
    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("boolean value expected")
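Note: a hedged wiring sketch (not part of this commit) showing how the defaults above are intended to flow through argparse into `create_model_and_diffusion`; it assumes the module-level imports (e.g. `music_rule_guidance`) resolve in your environment:

parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, model_and_diffusion_defaults())
# image_size is registered with nargs='+' so height and width can differ on the CLI;
# create_model then uses the last value (the width) to pick channel_mult.
args = parser.parse_args(["--image_size", "128", "128"])
model, diffusion = create_model_and_diffusion(
    **args_to_dict(args, model_and_diffusion_defaults().keys())
)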
guided_diffusion/train_util.py
ADDED
@@ -0,0 +1,475 @@
import copy
import functools
import os
import os.path as osp
import numpy as np
import math

import blobfile as bf
import torch as th
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW

from . import dist_util, midi_util, logger
from .fp16_util import MixedPrecisionTrainer
from .nn import update_ema
from .resample import LossAwareSampler, UniformSampler
from taming.modules.distributions.distributions import DiagonalGaussianDistribution

# For ImageNet experiments, this was a good default value.
# We found that the lg_loss_scale quickly climbed to
# 20-21 within the first ~1K steps of training.
INITIAL_LOG_LOSS_SCALE = 20.0


class TrainLoop:
    def __init__(
        self,
        *,
        model,
        eval_model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        embed_model=None,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
        eval_data=None,
        eval_interval=-1,
        eval_sample_batch_size=16,
        total_num_gpus=1,  # number of GPUs training runs on; used to distribute classes across GPUs
        eval_sample_use_ddim=True,
        eval_sample_clip_denoised=True,
        in_channels=1,
        fs=100,
        pedal=False,  # whether to decode with pedal as the second channel
        scale_factor=1.,
        num_classes=0,  # >0 enables class-conditional sampling
        microbatch_encode=-1,
        encode_rep=4,
        shift_size=4,  # shift size when generating time-shifted samples from an encoding
    ):
        self.model = model
        self.eval_model = eval_model
        self.embed_model = embed_model
        self.scale_factor = scale_factor
        self.diffusion = diffusion
        self.data = data
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.microbatch_encode = microbatch_encode
        self.encode_rep = encode_rep
        self.batch_size = self.batch_size // self.encode_rep  # effective batch size
        self.microbatch = self.microbatch // self.encode_rep
        self.shift_size = shift_size  # needs to be compatible with encode_rep
        self.lr = lr
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        )
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps
        # eval
        self.eval_data = eval_data
        self.eval_interval = eval_interval
        self.total_num_gpus = total_num_gpus
        self.eval_sample_batch_size = eval_sample_batch_size // self.total_num_gpus
        self.eval_sample_use_ddim = eval_sample_use_ddim
        self.eval_sample_clip_denoised = eval_sample_clip_denoised
        self.in_channels = in_channels
        self.fs = fs
        self.pedal = pedal
        self.num_classes = num_classes

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist.get_world_size()

        self.sync_cuda = th.cuda.is_available()

        self._load_and_sync_parameters()
        self.mp_trainer = MixedPrecisionTrainer(
            model=self.model,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
        )

        self.opt = AdamW(
            self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
        )
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.mp_trainer.master_params)
                for _ in range(len(self.ema_rate))
            ]

        if th.cuda.is_available():
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!"
                )
            self.use_ddp = False
            self.ddp_model = self.model

    def _load_and_sync_parameters(self):
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint

        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
            self.model.load_state_dict(
                dist_util.load_state_dict(
                    resume_checkpoint, map_location=dist_util.dev()
                )
            )

        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self, rate):
        ema_params = copy.deepcopy(self.mp_trainer.master_params)

        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
            state_dict = dist_util.load_state_dict(
                ema_checkpoint, map_location=dist_util.dev()
            )
            ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)

        dist_util.sync_params(ema_params)
        return ema_params

    def _load_optimizer_state(self):
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev()
            )
            self.opt.load_state_dict(state_dict)

    def run_loop(self):
        while (
            not self.lr_anneal_steps
            or self.step + self.resume_step < self.lr_anneal_steps
        ):
            batch, cond = next(self.data)
            dist.barrier()
            self.run_step(batch, cond)
            if self.eval_data is not None and self.step % self.eval_interval == 0:
                batch_eval, cond_eval = next(self.eval_data)
                self.run_step_eval(batch_eval, cond_eval)
            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            if self.step % self.save_interval == 0 and self.step != 0:
                self.save()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                    return
            self.step += 1
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch, cond):
        self.forward_backward(batch, cond)
        took_step = self.mp_trainer.optimize(self.opt)
        if took_step:
            self._update_ema()
        self._anneal_lr()
        self.log_step()

    def run_step_eval(self, batch, cond):
        with th.no_grad():
            # Load EMA params into eval_model; only the first EMA rate is used
            # if there are multiple rates.
            ema_state_dict = self.mp_trainer.master_params_to_state_dict(self.ema_params[0])
            self.eval_model.load_state_dict(ema_state_dict)
            if self.use_fp16:
                self.eval_model.convert_to_fp16()
            self.eval_model.eval()
            for i in range(0, batch.shape[0], self.microbatch):
                micro = batch[i: i + self.microbatch].to(dist_util.dev())
                if self.embed_model is not None:
                    micro = get_kl_input(micro, microbatch=self.microbatch_encode,
                                         model=self.embed_model, scale_factor=self.scale_factor,
                                         shift_size=self.shift_size)
                micro_cond = {
                    k: v[i: i + self.microbatch].repeat_interleave(self.encode_rep).to(dist_util.dev())
                    for k, v in cond.items()
                }
                t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

                compute_losses = functools.partial(
                    self.diffusion.training_losses,
                    self.eval_model,
                    micro,
                    t,
                    model_kwargs=micro_cond,
                )
                losses = compute_losses()
                log_loss_dict(
                    self.diffusion, t, {'eval_' + k: v * weights for k, v in losses.items()}
                )
            if self.eval_sample_batch_size > 0 and self.step != 0:
                model_kwargs = {}
                if self.num_classes > 0:
                    # Balance generated classes across ranks.
                    rank = dist.get_rank()
                    samples_per_class = math.ceil(self.eval_sample_batch_size * self.total_num_gpus / self.num_classes)
                    label_start = rank * self.eval_sample_batch_size // samples_per_class
                    label_end = math.ceil((rank + 1) * self.eval_sample_batch_size / samples_per_class)
                    classes = th.arange(label_start, label_end, dtype=th.int, device=dist_util.dev()).repeat_interleave(samples_per_class)
                    model_kwargs["y"] = classes[:self.eval_sample_batch_size]
                all_images = []
                all_labels = []
                image_size_h = micro.shape[-2]
                image_size_w = micro.shape[-1]
                sample_fn = (
                    self.diffusion.p_sample_loop if not self.eval_sample_use_ddim else self.diffusion.ddim_sample_loop
                )
                sample = sample_fn(
                    self.eval_model,
                    (self.eval_sample_batch_size, self.in_channels, image_size_h, image_size_w),
                    clip_denoised=self.eval_sample_clip_denoised,
                    model_kwargs=model_kwargs,
                    progress=True,
                )
                sample = midi_util.decode_sample_for_midi(sample, embed_model=self.embed_model,
                                                          scale_factor=self.scale_factor, threshold=-0.95)

                gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
                dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
                all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
                if self.num_classes > 0:
                    gathered_labels = [
                        th.zeros_like(model_kwargs["y"]) for _ in range(dist.get_world_size())
                    ]
                    dist.all_gather(gathered_labels, model_kwargs["y"])
                    all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])

                arr = np.concatenate(all_images, axis=0)
                if arr.shape[-1] == 1:  # no pedal, need shape B x 128 x 1024
                    arr = arr.squeeze(axis=-1)
                else:  # with pedal, need shape B x 2 x 128 x 1024
                    arr = arr.transpose(0, 3, 1, 2)
                if self.num_classes > 0:
                    label_arr = np.concatenate(all_labels, axis=0)
                save_dir = osp.join(get_blob_logdir(), "samples", "iter_" + str(self.step + self.resume_step))
                os.makedirs(os.path.expanduser(save_dir), exist_ok=True)
                if dist.get_rank() == 0:
                    if self.num_classes > 0:
                        midi_util.save_piano_roll_midi(arr, save_dir, self.fs, y=label_arr)
                    else:
                        midi_util.save_piano_roll_midi(arr, save_dir, self.fs)
                dist.barrier()

    def forward_backward(self, batch, cond):
        self.mp_trainer.zero_grad()
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            if self.embed_model is not None:
                micro = get_kl_input(micro, microbatch=self.microbatch_encode,
                                     model=self.embed_model, scale_factor=self.scale_factor,
                                     shift_size=self.shift_size)
            micro_cond = {
                k: v[i : i + self.microbatch].repeat_interleave(self.encode_rep).to(dist_util.dev())
                for k, v in cond.items()
            }
            last_batch = (i + self.microbatch) >= self.batch_size
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                model_kwargs=micro_cond,
            )

            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            self.mp_trainer.backward(loss)

    def _update_ema(self):
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.mp_trainer.master_params, rate=rate)

    def _anneal_lr(self):
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)

    def save(self):
        def save_checkpoint(rate, params):
            state_dict = self.mp_trainer.master_params_to_state_dict(params)
            if dist.get_rank() == 0:
                logger.log(f"saving model {rate}...")
                if not rate:
                    filename = f"model{(self.step+self.resume_step):06d}.pt"
                else:
                    filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), "checkpoints", filename), "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, self.mp_trainer.master_params)
        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)

        if dist.get_rank() == 0:
            with bf.BlobFile(
                bf.join(get_blob_logdir(), "checkpoints", f"opt{(self.step+self.resume_step):06d}.pt"),
                "wb",
            ) as f:
                th.save(self.opt.state_dict(), f)

        dist.barrier()


@th.no_grad()
def get_kl_input(batch, microbatch=-1, model=None, scale_factor=1., recombine=True, shift_size=4):
    # Here microbatch should be the outer microbatch // encode_rep.
    if microbatch < 0:
        microbatch = batch.shape[0]
    full_z = []
    image_size_h = batch.shape[-2]
    image_size_w = batch.shape[-1]
    seq_len = image_size_w // image_size_h
    for i in range(0, batch.shape[0], microbatch):
        micro = batch[i : i + microbatch].to(dist_util.dev())
        # Encode each 1-second chunk and concatenate.
        micro = th.chunk(micro, seq_len, dim=-1)  # B x C x H x W
        micro = th.concat(micro, dim=0)  # 1st second for all batch, 2nd second for all batch, ...
        micro = model.encode_save(micro, range_fix=False)
        posterior = DiagonalGaussianDistribution(micro)
        z = posterior.mode()
        z = th.concat(th.chunk(z, seq_len, dim=0), dim=-1)
        z = z.permute(0, 1, 3, 2)
        full_z.append(z)
    full_z = th.concat(full_z, dim=0)  # B x 4 x (15x16) x 16
    if recombine:  # if not using microbatch, then need to use recombination of tokens
        # unfold args: dimension, size, step
        full_z = full_z.unfold(2, 8 * 16, 16 * shift_size).permute(0, 2, 1, 4, 3)  # (B x encode_rep) x 4 x 128 x 16
        full_z = full_z.contiguous().view(-1, 4, 8 * 16, 16)  # B x 4 x 128 x 16
    return (full_z * scale_factor).detach()


def parse_resume_step_from_filename(filename):
    """
    Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
    checkpoint's number of steps.
    """
    split = filename.split("model")
    if len(split) < 2:
        return 0
    split1 = split[-1].split(".")[0]
    try:
        return int(split1)
    except ValueError:
        return 0


def get_blob_logdir():
    # You can change this to be a separate path to save checkpoints to
    # a blobstore or some external drive.
    return logger.get_dir()


def find_resume_checkpoint():
    # On your infrastructure, you may want to override this to automatically
    # discover the latest checkpoint on your blob storage, etc.
    return None


def find_ema_checkpoint(main_checkpoint, step, rate):
    if main_checkpoint is None:
        return None
    filename = f"ema_{rate}_{(step):06d}.pt"
    path = bf.join(bf.dirname(main_checkpoint), filename)
    if bf.exists(path):
        return path
    return None


def log_loss_dict(diffusion, ts, losses):
    for key, values in losses.items():
        logger.logkv_mean(key, values.mean().item())
        # Log the quantiles (four quartiles, in particular).
        for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
            quartile = int(4 * sub_t / diffusion.num_timesteps)
            logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
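Note: a quick sanity sketch (not part of this commit) for the checkpoint-name parser and the linear LR anneal used by `TrainLoop._anneal_lr`:

assert parse_resume_step_from_filename("path/to/model012345.pt") == 12345
assert parse_resume_step_from_filename("path/to/opt012345.pt") == 0  # no "model" prefix, so no step
lr0, step, anneal_steps = 1e-4, 2500, 10000
lr = lr0 * (1 - step / anneal_steps)  # 25% through training -> 0.75 * lr0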
guided_diffusion/unet.py
ADDED
@@ -0,0 +1,906 @@
from abc import abstractmethod

import math

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .fp16_util import convert_module_to_f16, convert_module_to_f32
from .nn import (
    checkpoint,
    conv_nd,
    linear,
    avg_pool_nd,
    zero_module,
    normalization,
    timestep_embedding,
)


class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
        chord: bool = False,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.chord = chord
        if chord:
            self.c_proj_key = conv_nd(1, embed_dim, 25, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        if self.chord:
            x_key = self.c_proj_key(x)
            key = x_key[:, :, 0]
            x_chord = self.c_proj(x)[:, :, 1:]
            chord = x_chord.reshape(b, -1, *_spatial).mean(dim=2).permute(0, 2, 1)
            return key, chord
        else:
            x = self.c_proj(x)
            return x[:, :, 0]


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
|
266 |
+
h = self.out_layers(h)
|
267 |
+
return self.skip_connection(x) + h
|
268 |
+
|
269 |
+
|
270 |
+
class AttentionBlock(nn.Module):
|
271 |
+
"""
|
272 |
+
An attention block that allows spatial positions to attend to each other.
|
273 |
+
|
274 |
+
Originally ported from here, but adapted to the N-d case.
|
275 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
276 |
+
"""
|
277 |
+
|
278 |
+
def __init__(
|
279 |
+
self,
|
280 |
+
channels,
|
281 |
+
num_heads=1,
|
282 |
+
num_head_channels=-1,
|
283 |
+
use_checkpoint=False,
|
284 |
+
use_new_attention_order=False,
|
285 |
+
):
|
286 |
+
super().__init__()
|
287 |
+
self.channels = channels
|
288 |
+
if num_head_channels == -1:
|
289 |
+
self.num_heads = num_heads
|
290 |
+
else:
|
291 |
+
assert (
|
292 |
+
channels % num_head_channels == 0
|
293 |
+
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
294 |
+
self.num_heads = channels // num_head_channels
|
295 |
+
self.use_checkpoint = use_checkpoint
|
296 |
+
self.norm = normalization(channels)
|
297 |
+
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
298 |
+
if use_new_attention_order:
|
299 |
+
# split qkv before split heads
|
300 |
+
self.attention = QKVAttention(self.num_heads)
|
301 |
+
else:
|
302 |
+
# split heads before split qkv
|
303 |
+
self.attention = QKVAttentionLegacy(self.num_heads)
|
304 |
+
|
305 |
+
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
306 |
+
|
307 |
+
def forward(self, x):
|
308 |
+
return checkpoint(self._forward, (x,), self.parameters(), True)
|
309 |
+
|
310 |
+
def _forward(self, x):
|
311 |
+
b, c, *spatial = x.shape
|
312 |
+
x = x.reshape(b, c, -1)
|
313 |
+
qkv = self.qkv(self.norm(x))
|
314 |
+
h = self.attention(qkv)
|
315 |
+
h = self.proj_out(h)
|
316 |
+
return (x + h).reshape(b, c, *spatial)
|
317 |
+
|
318 |
+
|
319 |
+
def count_flops_attn(model, _x, y):
|
320 |
+
"""
|
321 |
+
A counter for the `thop` package to count the operations in an
|
322 |
+
attention operation.
|
323 |
+
Meant to be used like:
|
324 |
+
macs, params = thop.profile(
|
325 |
+
model,
|
326 |
+
inputs=(inputs, timestamps),
|
327 |
+
custom_ops={QKVAttention: QKVAttention.count_flops},
|
328 |
+
)
|
329 |
+
"""
|
330 |
+
b, c, *spatial = y[0].shape
|
331 |
+
num_spatial = int(np.prod(spatial))
|
332 |
+
# We perform two matmuls with the same number of ops.
|
333 |
+
# The first computes the weight matrix, the second computes
|
334 |
+
# the combination of the value vectors.
|
335 |
+
matmul_ops = 2 * b * (num_spatial ** 2) * c
|
336 |
+
model.total_ops += th.DoubleTensor([matmul_ops])
|
337 |
+
|
338 |
+
|
339 |
+
class QKVAttentionLegacy(nn.Module):
|
340 |
+
"""
|
341 |
+
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
|
342 |
+
"""
|
343 |
+
|
344 |
+
def __init__(self, n_heads):
|
345 |
+
super().__init__()
|
346 |
+
self.n_heads = n_heads
|
347 |
+
|
348 |
+
def forward(self, qkv):
|
349 |
+
"""
|
350 |
+
Apply QKV attention.
|
351 |
+
|
352 |
+
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
353 |
+
:return: an [N x (H * C) x T] tensor after attention.
|
354 |
+
"""
|
355 |
+
bs, width, length = qkv.shape
|
356 |
+
assert width % (3 * self.n_heads) == 0
|
357 |
+
ch = width // (3 * self.n_heads)
|
358 |
+
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
359 |
+
scale = 1 / math.sqrt(math.sqrt(ch))
|
360 |
+
weight = th.einsum(
|
361 |
+
"bct,bcs->bts", q * scale, k * scale
|
362 |
+
) # More stable with f16 than dividing afterwards
|
363 |
+
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
364 |
+
a = th.einsum("bts,bcs->bct", weight, v)
|
365 |
+
return a.reshape(bs, -1, length)
|
366 |
+
|
367 |
+
@staticmethod
|
368 |
+
def count_flops(model, _x, y):
|
369 |
+
return count_flops_attn(model, _x, y)
|
370 |
+
|
371 |
+
|
372 |
+
class QKVAttention(nn.Module):
|
373 |
+
"""
|
374 |
+
A module which performs QKV attention and splits in a different order.
|
375 |
+
"""
|
376 |
+
|
377 |
+
def __init__(self, n_heads):
|
378 |
+
super().__init__()
|
379 |
+
self.n_heads = n_heads
|
380 |
+
|
381 |
+
def forward(self, qkv):
|
382 |
+
"""
|
383 |
+
Apply QKV attention.
|
384 |
+
|
385 |
+
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
|
386 |
+
:return: an [N x (H * C) x T] tensor after attention.
|
387 |
+
"""
|
388 |
+
bs, width, length = qkv.shape
|
389 |
+
assert width % (3 * self.n_heads) == 0
|
390 |
+
ch = width // (3 * self.n_heads)
|
391 |
+
q, k, v = qkv.chunk(3, dim=1)
|
392 |
+
scale = 1 / math.sqrt(math.sqrt(ch))
|
393 |
+
weight = th.einsum(
|
394 |
+
"bct,bcs->bts",
|
395 |
+
(q * scale).view(bs * self.n_heads, ch, length),
|
396 |
+
(k * scale).view(bs * self.n_heads, ch, length),
|
397 |
+
) # More stable with f16 than dividing afterwards
|
398 |
+
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
399 |
+
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
400 |
+
return a.reshape(bs, -1, length)
|
401 |
+
|
402 |
+
@staticmethod
|
403 |
+
def count_flops(model, _x, y):
|
404 |
+
return count_flops_attn(model, _x, y)
|
405 |
+
|
406 |
+
|
407 |
+
class UNetModel(nn.Module):
|
408 |
+
"""
|
409 |
+
The full UNet model with attention and timestep embedding.
|
410 |
+
|
411 |
+
:param in_channels: channels in the input Tensor.
|
412 |
+
:param model_channels: base channel count for the model.
|
413 |
+
:param out_channels: channels in the output Tensor.
|
414 |
+
:param num_res_blocks: number of residual blocks per downsample.
|
415 |
+
:param attention_resolutions: a collection of downsample rates at which
|
416 |
+
attention will take place. May be a set, list, or tuple.
|
417 |
+
For example, if this contains 4, then at 4x downsampling, attention
|
418 |
+
will be used.
|
419 |
+
:param dropout: the dropout probability.
|
420 |
+
:param channel_mult: channel multiplier for each level of the UNet.
|
421 |
+
:param conv_resample: if True, use learned convolutions for upsampling and
|
422 |
+
downsampling.
|
423 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
424 |
+
:param num_classes: if specified (as an int), then this model will be
|
425 |
+
class-conditional with `num_classes` classes.
|
426 |
+
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
427 |
+
:param num_heads: the number of attention heads in each attention layer.
|
428 |
+
:param num_heads_channels: if specified, ignore num_heads and instead use
|
429 |
+
a fixed channel width per attention head.
|
430 |
+
:param num_heads_upsample: works with num_heads to set a different number
|
431 |
+
of heads for upsampling. Deprecated.
|
432 |
+
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
433 |
+
:param resblock_updown: use residual blocks for up/downsampling.
|
434 |
+
:param use_new_attention_order: use a different attention pattern for potentially
|
435 |
+
increased efficiency.
|
436 |
+
"""
|
437 |
+
|
438 |
+
def __init__(
|
439 |
+
self,
|
440 |
+
image_size,
|
441 |
+
in_channels,
|
442 |
+
model_channels,
|
443 |
+
out_channels,
|
444 |
+
num_res_blocks,
|
445 |
+
attention_resolutions,
|
446 |
+
dropout=0,
|
447 |
+
channel_mult=(1, 2, 4, 8),
|
448 |
+
conv_resample=True,
|
449 |
+
dims=2,
|
450 |
+
num_classes=None,
|
451 |
+
use_checkpoint=False,
|
452 |
+
use_fp16=False,
|
453 |
+
num_heads=1,
|
454 |
+
num_head_channels=-1,
|
455 |
+
num_heads_upsample=-1,
|
456 |
+
use_scale_shift_norm=False,
|
457 |
+
resblock_updown=False,
|
458 |
+
use_new_attention_order=False,
|
459 |
+
):
|
460 |
+
super().__init__()
|
461 |
+
|
462 |
+
if num_heads_upsample == -1:
|
463 |
+
num_heads_upsample = num_heads
|
464 |
+
|
465 |
+
self.image_size = image_size
|
466 |
+
self.in_channels = in_channels
|
467 |
+
self.model_channels = model_channels
|
468 |
+
self.out_channels = out_channels
|
469 |
+
self.num_res_blocks = num_res_blocks
|
470 |
+
self.attention_resolutions = attention_resolutions
|
471 |
+
self.dropout = dropout
|
472 |
+
self.channel_mult = channel_mult
|
473 |
+
self.conv_resample = conv_resample
|
474 |
+
self.num_classes = num_classes
|
475 |
+
self.use_checkpoint = use_checkpoint
|
476 |
+
self.dtype = th.float16 if use_fp16 else th.float32
|
477 |
+
self.num_heads = num_heads
|
478 |
+
self.num_head_channels = num_head_channels
|
479 |
+
self.num_heads_upsample = num_heads_upsample
|
480 |
+
|
481 |
+
time_embed_dim = model_channels * 4
|
482 |
+
self.time_embed = nn.Sequential(
|
483 |
+
linear(model_channels, time_embed_dim),
|
484 |
+
nn.SiLU(),
|
485 |
+
linear(time_embed_dim, time_embed_dim),
|
486 |
+
)
|
487 |
+
|
488 |
+
if self.num_classes is not None:
|
489 |
+
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
490 |
+
|
491 |
+
ch = input_ch = int(channel_mult[0] * model_channels)
|
492 |
+
self.input_blocks = nn.ModuleList(
|
493 |
+
[TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
|
494 |
+
)
|
495 |
+
self._feature_size = ch
|
496 |
+
input_block_chans = [ch]
|
497 |
+
ds = 1
|
498 |
+
for level, mult in enumerate(channel_mult):
|
499 |
+
for _ in range(num_res_blocks):
|
500 |
+
layers = [
|
501 |
+
ResBlock(
|
502 |
+
ch,
|
503 |
+
time_embed_dim,
|
504 |
+
dropout,
|
505 |
+
out_channels=int(mult * model_channels),
|
506 |
+
dims=dims,
|
507 |
+
use_checkpoint=use_checkpoint,
|
508 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
509 |
+
)
|
510 |
+
]
|
511 |
+
ch = int(mult * model_channels)
|
512 |
+
if ds in attention_resolutions:
|
513 |
+
layers.append(
|
514 |
+
AttentionBlock(
|
515 |
+
ch,
|
516 |
+
use_checkpoint=use_checkpoint,
|
517 |
+
num_heads=num_heads,
|
518 |
+
num_head_channels=num_head_channels,
|
519 |
+
use_new_attention_order=use_new_attention_order,
|
520 |
+
)
|
521 |
+
)
|
522 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
523 |
+
self._feature_size += ch
|
524 |
+
input_block_chans.append(ch)
|
525 |
+
if level != len(channel_mult) - 1:
|
526 |
+
out_ch = ch
|
527 |
+
self.input_blocks.append(
|
528 |
+
TimestepEmbedSequential(
|
529 |
+
ResBlock(
|
530 |
+
ch,
|
531 |
+
time_embed_dim,
|
532 |
+
dropout,
|
533 |
+
out_channels=out_ch,
|
534 |
+
dims=dims,
|
535 |
+
use_checkpoint=use_checkpoint,
|
536 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
537 |
+
down=True,
|
538 |
+
)
|
539 |
+
if resblock_updown
|
540 |
+
else Downsample(
|
541 |
+
ch, conv_resample, dims=dims, out_channels=out_ch
|
542 |
+
)
|
543 |
+
)
|
544 |
+
)
|
545 |
+
ch = out_ch
|
546 |
+
input_block_chans.append(ch)
|
547 |
+
ds *= 2
|
548 |
+
self._feature_size += ch
|
549 |
+
|
550 |
+
self.middle_block = TimestepEmbedSequential(
|
551 |
+
ResBlock(
|
552 |
+
ch,
|
553 |
+
time_embed_dim,
|
554 |
+
dropout,
|
555 |
+
dims=dims,
|
556 |
+
use_checkpoint=use_checkpoint,
|
557 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
558 |
+
),
|
559 |
+
AttentionBlock(
|
560 |
+
ch,
|
561 |
+
use_checkpoint=use_checkpoint,
|
562 |
+
num_heads=num_heads,
|
563 |
+
num_head_channels=num_head_channels,
|
564 |
+
use_new_attention_order=use_new_attention_order,
|
565 |
+
),
|
566 |
+
ResBlock(
|
567 |
+
ch,
|
568 |
+
time_embed_dim,
|
569 |
+
dropout,
|
570 |
+
dims=dims,
|
571 |
+
use_checkpoint=use_checkpoint,
|
572 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
573 |
+
),
|
574 |
+
)
|
575 |
+
self._feature_size += ch
|
576 |
+
|
577 |
+
self.output_blocks = nn.ModuleList([])
|
578 |
+
for level, mult in list(enumerate(channel_mult))[::-1]:
|
579 |
+
for i in range(num_res_blocks + 1):
|
580 |
+
ich = input_block_chans.pop()
|
581 |
+
layers = [
|
582 |
+
ResBlock(
|
583 |
+
ch + ich,
|
584 |
+
time_embed_dim,
|
585 |
+
dropout,
|
586 |
+
out_channels=int(model_channels * mult),
|
587 |
+
dims=dims,
|
588 |
+
use_checkpoint=use_checkpoint,
|
589 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
590 |
+
)
|
591 |
+
]
|
592 |
+
ch = int(model_channels * mult)
|
593 |
+
if ds in attention_resolutions:
|
594 |
+
layers.append(
|
595 |
+
AttentionBlock(
|
596 |
+
ch,
|
597 |
+
use_checkpoint=use_checkpoint,
|
598 |
+
num_heads=num_heads_upsample,
|
599 |
+
num_head_channels=num_head_channels,
|
600 |
+
use_new_attention_order=use_new_attention_order,
|
601 |
+
)
|
602 |
+
)
|
603 |
+
if level and i == num_res_blocks:
|
604 |
+
out_ch = ch
|
605 |
+
layers.append(
|
606 |
+
ResBlock(
|
607 |
+
ch,
|
608 |
+
time_embed_dim,
|
609 |
+
dropout,
|
610 |
+
out_channels=out_ch,
|
611 |
+
dims=dims,
|
612 |
+
use_checkpoint=use_checkpoint,
|
613 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
614 |
+
up=True,
|
615 |
+
)
|
616 |
+
if resblock_updown
|
617 |
+
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
618 |
+
)
|
619 |
+
ds //= 2
|
620 |
+
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
621 |
+
self._feature_size += ch
|
622 |
+
|
623 |
+
self.out = nn.Sequential(
|
624 |
+
normalization(ch),
|
625 |
+
nn.SiLU(),
|
626 |
+
zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
|
627 |
+
)
|
628 |
+
|
629 |
+
def convert_to_fp16(self):
|
630 |
+
"""
|
631 |
+
Convert the torso of the model to float16.
|
632 |
+
"""
|
633 |
+
self.input_blocks.apply(convert_module_to_f16)
|
634 |
+
self.middle_block.apply(convert_module_to_f16)
|
635 |
+
self.output_blocks.apply(convert_module_to_f16)
|
636 |
+
|
637 |
+
def convert_to_fp32(self):
|
638 |
+
"""
|
639 |
+
Convert the torso of the model to float32.
|
640 |
+
"""
|
641 |
+
self.input_blocks.apply(convert_module_to_f32)
|
642 |
+
self.middle_block.apply(convert_module_to_f32)
|
643 |
+
self.output_blocks.apply(convert_module_to_f32)
|
644 |
+
|
645 |
+
def forward(self, x, timesteps, y=None):
|
646 |
+
"""
|
647 |
+
Apply the model to an input batch.
|
648 |
+
|
649 |
+
:param x: an [N x C x ...] Tensor of inputs.
|
650 |
+
:param timesteps: a 1-D batch of timesteps.
|
651 |
+
:param y: an [N] Tensor of labels, if class-conditional.
|
652 |
+
:return: an [N x C x ...] Tensor of outputs.
|
653 |
+
"""
|
654 |
+
assert (y is not None) == (
|
655 |
+
self.num_classes is not None
|
656 |
+
), "must specify y if and only if the model is class-conditional"
|
657 |
+
|
658 |
+
hs = []
|
659 |
+
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
660 |
+
|
661 |
+
if self.num_classes is not None:
|
662 |
+
assert y.shape == (x.shape[0],)
|
663 |
+
emb = emb + self.label_emb(y)
|
664 |
+
|
665 |
+
h = x.type(self.dtype)
|
666 |
+
for module in self.input_blocks:
|
667 |
+
h = module(h, emb)
|
668 |
+
hs.append(h)
|
669 |
+
h = self.middle_block(h, emb)
|
670 |
+
for module in self.output_blocks:
|
671 |
+
h = th.cat([h, hs.pop()], dim=1)
|
672 |
+
h = module(h, emb)
|
673 |
+
h = h.type(x.dtype)
|
674 |
+
return self.out(h)
|
675 |
+
|
676 |
+
|
677 |
+
class SuperResModel(UNetModel):
|
678 |
+
"""
|
679 |
+
A UNetModel that performs super-resolution.
|
680 |
+
|
681 |
+
Expects an extra kwarg `low_res` to condition on a low-resolution image.
|
682 |
+
"""
|
683 |
+
|
684 |
+
def __init__(self, image_size, in_channels, *args, **kwargs):
|
685 |
+
super().__init__(image_size, in_channels * 2, *args, **kwargs)
|
686 |
+
|
687 |
+
def forward(self, x, timesteps, low_res=None, **kwargs):
|
688 |
+
_, _, new_height, new_width = x.shape
|
689 |
+
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
|
690 |
+
x = th.cat([x, upsampled], dim=1)
|
691 |
+
return super().forward(x, timesteps, **kwargs)
|
692 |
+
|
693 |
+
|
694 |
+
class EncoderUNetModel(nn.Module):
|
695 |
+
"""
|
696 |
+
The half UNet model with attention and timestep embedding.
|
697 |
+
|
698 |
+
For usage, see UNet.
|
699 |
+
"""
|
700 |
+
|
701 |
+
def __init__(
|
702 |
+
self,
|
703 |
+
image_size,
|
704 |
+
in_channels,
|
705 |
+
model_channels,
|
706 |
+
out_channels,
|
707 |
+
num_res_blocks,
|
708 |
+
attention_resolutions,
|
709 |
+
dropout=0,
|
710 |
+
channel_mult=(1, 2, 4, 8),
|
711 |
+
conv_resample=True,
|
712 |
+
dims=2,
|
713 |
+
use_checkpoint=False,
|
714 |
+
use_fp16=False,
|
715 |
+
num_heads=1,
|
716 |
+
num_head_channels=-1,
|
717 |
+
num_heads_upsample=-1,
|
718 |
+
use_scale_shift_norm=False,
|
719 |
+
resblock_updown=False,
|
720 |
+
use_new_attention_order=False,
|
721 |
+
pool="adaptive",
|
722 |
+
chord=False,
|
723 |
+
):
|
724 |
+
super().__init__()
|
725 |
+
|
726 |
+
if num_heads_upsample == -1:
|
727 |
+
num_heads_upsample = num_heads
|
728 |
+
|
729 |
+
self.in_channels = in_channels
|
730 |
+
self.model_channels = model_channels
|
731 |
+
self.out_channels = out_channels
|
732 |
+
self.num_res_blocks = num_res_blocks
|
733 |
+
self.attention_resolutions = attention_resolutions
|
734 |
+
self.dropout = dropout
|
735 |
+
self.channel_mult = channel_mult
|
736 |
+
self.conv_resample = conv_resample
|
737 |
+
self.use_checkpoint = use_checkpoint
|
738 |
+
self.dtype = th.float16 if use_fp16 else th.float32
|
739 |
+
self.num_heads = num_heads
|
740 |
+
self.num_head_channels = num_head_channels
|
741 |
+
self.num_heads_upsample = num_heads_upsample
|
742 |
+
|
743 |
+
time_embed_dim = model_channels * 4
|
744 |
+
self.time_embed = nn.Sequential(
|
745 |
+
linear(model_channels, time_embed_dim),
|
746 |
+
nn.SiLU(),
|
747 |
+
linear(time_embed_dim, time_embed_dim),
|
748 |
+
)
|
749 |
+
|
750 |
+
ch = int(channel_mult[0] * model_channels)
|
751 |
+
self.input_blocks = nn.ModuleList(
|
752 |
+
[TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
|
753 |
+
)
|
754 |
+
self._feature_size = ch
|
755 |
+
input_block_chans = [ch]
|
756 |
+
ds = 1
|
757 |
+
for level, mult in enumerate(channel_mult):
|
758 |
+
for _ in range(num_res_blocks):
|
759 |
+
layers = [
|
760 |
+
ResBlock(
|
761 |
+
ch,
|
762 |
+
time_embed_dim,
|
763 |
+
dropout,
|
764 |
+
out_channels=int(mult * model_channels),
|
765 |
+
dims=dims,
|
766 |
+
use_checkpoint=use_checkpoint,
|
767 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
768 |
+
)
|
769 |
+
]
|
770 |
+
ch = int(mult * model_channels)
|
771 |
+
if ds in attention_resolutions:
|
772 |
+
layers.append(
|
773 |
+
AttentionBlock(
|
774 |
+
ch,
|
775 |
+
use_checkpoint=use_checkpoint,
|
776 |
+
num_heads=num_heads,
|
777 |
+
num_head_channels=num_head_channels,
|
778 |
+
use_new_attention_order=use_new_attention_order,
|
779 |
+
)
|
780 |
+
)
|
781 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
782 |
+
self._feature_size += ch
|
783 |
+
input_block_chans.append(ch)
|
784 |
+
if level != len(channel_mult) - 1:
|
785 |
+
out_ch = ch
|
786 |
+
self.input_blocks.append(
|
787 |
+
TimestepEmbedSequential(
|
788 |
+
ResBlock(
|
789 |
+
ch,
|
790 |
+
time_embed_dim,
|
791 |
+
dropout,
|
792 |
+
out_channels=out_ch,
|
793 |
+
dims=dims,
|
794 |
+
use_checkpoint=use_checkpoint,
|
795 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
796 |
+
down=True,
|
797 |
+
)
|
798 |
+
if resblock_updown
|
799 |
+
else Downsample(
|
800 |
+
ch, conv_resample, dims=dims, out_channels=out_ch
|
801 |
+
)
|
802 |
+
)
|
803 |
+
)
|
804 |
+
ch = out_ch
|
805 |
+
input_block_chans.append(ch)
|
806 |
+
ds *= 2
|
807 |
+
self._feature_size += ch
|
808 |
+
|
809 |
+
self.middle_block = TimestepEmbedSequential(
|
810 |
+
ResBlock(
|
811 |
+
ch,
|
812 |
+
time_embed_dim,
|
813 |
+
dropout,
|
814 |
+
dims=dims,
|
815 |
+
use_checkpoint=use_checkpoint,
|
816 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
817 |
+
),
|
818 |
+
AttentionBlock(
|
819 |
+
ch,
|
820 |
+
use_checkpoint=use_checkpoint,
|
821 |
+
num_heads=num_heads,
|
822 |
+
num_head_channels=num_head_channels,
|
823 |
+
use_new_attention_order=use_new_attention_order,
|
824 |
+
),
|
825 |
+
ResBlock(
|
826 |
+
ch,
|
827 |
+
time_embed_dim,
|
828 |
+
dropout,
|
829 |
+
dims=dims,
|
830 |
+
use_checkpoint=use_checkpoint,
|
831 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
832 |
+
),
|
833 |
+
)
|
834 |
+
self._feature_size += ch
|
835 |
+
self.pool = pool
|
836 |
+
if pool == "adaptive":
|
837 |
+
self.out = nn.Sequential(
|
838 |
+
normalization(ch),
|
839 |
+
nn.SiLU(),
|
840 |
+
nn.AdaptiveAvgPool2d((1, 1)),
|
841 |
+
zero_module(conv_nd(dims, ch, out_channels, 1)),
|
842 |
+
nn.Flatten(),
|
843 |
+
)
|
844 |
+
elif pool == "attention":
|
845 |
+
assert num_head_channels != -1
|
846 |
+
self.out = nn.Sequential(
|
847 |
+
normalization(ch),
|
848 |
+
nn.SiLU(),
|
849 |
+
AttentionPool2d(
|
850 |
+
(image_size // ds), ch, num_head_channels, out_channels, chord
|
851 |
+
),
|
852 |
+
)
|
853 |
+
elif pool == "spatial":
|
854 |
+
self.out = nn.Sequential(
|
855 |
+
nn.Linear(self._feature_size, 2048),
|
856 |
+
nn.ReLU(),
|
857 |
+
nn.Linear(2048, self.out_channels),
|
858 |
+
)
|
859 |
+
elif pool == "spatial_v2":
|
860 |
+
self.out = nn.Sequential(
|
861 |
+
nn.Linear(self._feature_size, 2048),
|
862 |
+
normalization(2048),
|
863 |
+
nn.SiLU(),
|
864 |
+
nn.Linear(2048, self.out_channels),
|
865 |
+
)
|
866 |
+
else:
|
867 |
+
raise NotImplementedError(f"Unexpected {pool} pooling")
|
868 |
+
|
869 |
+
def convert_to_fp16(self):
|
870 |
+
"""
|
871 |
+
Convert the torso of the model to float16.
|
872 |
+
"""
|
873 |
+
self.input_blocks.apply(convert_module_to_f16)
|
874 |
+
self.middle_block.apply(convert_module_to_f16)
|
875 |
+
|
876 |
+
def convert_to_fp32(self):
|
877 |
+
"""
|
878 |
+
Convert the torso of the model to float32.
|
879 |
+
"""
|
880 |
+
self.input_blocks.apply(convert_module_to_f32)
|
881 |
+
self.middle_block.apply(convert_module_to_f32)
|
882 |
+
|
883 |
+
def forward(self, x, timesteps):
|
884 |
+
"""
|
885 |
+
Apply the model to an input batch.
|
886 |
+
|
887 |
+
:param x: an [N x C x ...] Tensor of inputs.
|
888 |
+
:param timesteps: a 1-D batch of timesteps.
|
889 |
+
:return: an [N x K] Tensor of outputs.
|
890 |
+
"""
|
891 |
+
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
892 |
+
|
893 |
+
results = []
|
894 |
+
h = x.type(self.dtype)
|
895 |
+
for module in self.input_blocks:
|
896 |
+
h = module(h, emb)
|
897 |
+
if self.pool.startswith("spatial"):
|
898 |
+
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
899 |
+
h = self.middle_block(h, emb)
|
900 |
+
if self.pool.startswith("spatial"):
|
901 |
+
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
902 |
+
h = th.cat(results, axis=-1)
|
903 |
+
return self.out(h)
|
904 |
+
else:
|
905 |
+
h = h.type(x.dtype)
|
906 |
+
return self.out(h)
|
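
To make the `UNetModel` interface above concrete, here is a minimal shape-check sketch. The hyperparameters are illustrative assumptions for a quick smoke test, not the configuration this repo's checkpoints were trained with:

```python
# Minimal smoke test for UNetModel; hyperparameters below are assumptions.
import torch as th
from guided_diffusion.unet import UNetModel

model = UNetModel(
    image_size=32,
    in_channels=4,
    model_channels=64,
    out_channels=4,
    num_res_blocks=2,
    attention_resolutions=(4,),  # attend at 4x downsampling
    channel_mult=(1, 2, 4),
)
x = th.randn(2, 4, 32, 32)       # [N x C x H x W] batch (e.g. VAE latents)
t = th.randint(0, 1000, (2,))    # one diffusion timestep per sample
eps = model(x, t)                # model output has the same shape as x
assert eps.shape == x.shape
```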
load_utils.py
ADDED
@@ -0,0 +1,31 @@
import importlib
import torch
from omegaconf import OmegaConf


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def load_model(name, ckpt):
    config = OmegaConf.load(f'taming-transformers/configs/pr/{name}.yaml')
    model = instantiate_from_config(config.model)
    model.init_from_ckpt(ckpt)  # load_state_dict(mc['state_dict'])
    model.eval()
    return model


def load_data(name):
    config = OmegaConf.load(f'taming-transformers/configs/pr/{name}.yaml')
    data = instantiate_from_config(config.data)
    return data
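
For reference, a sketch of how these helpers are meant to be called; the config name `pr_vae` and checkpoint path below are hypothetical placeholders, so substitute the actual YAML name under `taming-transformers/configs/pr/` and your checkpoint:

```python
# Hypothetical usage of load_utils; "pr_vae" and the .ckpt path are placeholders.
from load_utils import load_model, load_data

vae = load_model('pr_vae', 'taming-transformers/checkpoints/pr_vae.ckpt')  # eval-mode model
data = load_data('pr_vae')  # data module built from the same config's `data` section
```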
music_evaluation/README.md
ADDED
@@ -0,0 +1,22 @@
# Music Evaluation

Adapted from the GitHub repository [mgeval](https://github.com/RichardYang40148/mgeval).

All mgeval components that required Python 2 have been removed.

# Packages
scipy, numpy, seaborn, pretty_midi, scikit-learn, python 3

# Usage

```
python music_evaluator.py --set1dir /path/to/your/ground-truth/data/ --set2dir /path/to/your/generated-sample/ --outdir output-dir --num_sample number-of-samples-to-evaluate
```

# Output
All outputs are written to the output-dir directory in the current folder, including plots and statistics.txt.

Check out the result folder for an example.

You can run either music_evaluator.py or demo.ipynb for evaluation.
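
For example, a concrete invocation might look like the following; the directory names and sample count are placeholders, not paths shipped with this repo:

```
python music_evaluator.py --set1dir data/ground_truth/ --set2dir samples/generated/ --outdir results/ --num_sample 100
```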
music_evaluation/convert_to_wav.py
ADDED
@@ -0,0 +1,42 @@
from midi2audio import FluidSynth
import os
import sys


# This program converts a folder of .midi files to a folder of .wav files
# Need to download the FluidSynth and midi2audio packages
#
# Usage:
# python convert_to_wav.py midi_dir wav_dir
# More info about FluidSynth: https://github.com/FluidSynth/fluidsynth
# More info about midi2audio: https://github.com/bzamecnik/midi2audio
# Need sound fonts to run this program: https://sites.google.com/site/soundfonts4u/
# The sound font used in this program: https://drive.google.com/file/d/1nvTy62-wHGnZ6CKYuPNAiGlKLtWg9Ir9/view?usp=sharing

def convert_midi_to_audio(input_dir, output_dir, fs):
    # sound_font_path = os.path.join(os.getcwd(), "Dore Mark's NY S&S Model B-v5.2.sf2")
    # fs = FluidSynth(sound_font_path)
    os.chdir(input_dir)
    filenames = os.listdir(input_dir)
    for midi_file in filenames:
        filename = midi_file[:-5]
        filename = filename + ".wav"
        output_file = os.path.join(output_dir, filename)
        fs.midi_to_audio(midi_file, output_file)

    return


if __name__ == '__main__':
    sound_font_path = os.path.join(os.getcwd(), "Dore Mark's NY S&S Model B-v5.2.sf2")
    fs = FluidSynth(sound_font_path)
    # fs.midi_to_audio('MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav_0.midi', 'output.wav')

    output_dir = sys.argv[2]
    os.makedirs(output_dir, exist_ok=True)
    current_path = os.getcwd()
    output_dir = os.path.join(current_path, output_dir)

    input_dir = sys.argv[1]
    input_dir = os.path.join(current_path, input_dir)
    convert_midi_to_audio(input_dir, output_dir, fs)
music_evaluation/demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
music_evaluation/fad.py
ADDED
@@ -0,0 +1,38 @@
from frechet_audio_distance import FrechetAudioDistance
import sys

# Compute the FAD distance between the ground truth dataset and the sample dataset
# Pretty slow, depending on hardware speed
# Saves embeddings.npy for fast reuse later
#
# Usage: python fad.py background_dir_path eval_dir_path
# Feel free to change the embedding paths in the code
# More info about FrechetAudioDistance: https://github.com/gudgud96/frechet-audio-distance

if __name__ == "__main__":
    # to use `vggish`
    frechet = FrechetAudioDistance(
        model_name="vggish",
        use_pca=False,
        use_activation=False,
        verbose=False
    )
    # # to use `PANN`
    # frechet = FrechetAudioDistance(
    #     model_name="pann",
    #     use_pca=False,
    #     use_activation=False,
    #     verbose=False
    # )

    background_dir = sys.argv[1]
    eval_dir = sys.argv[2]

    background_embds_path = "./ground_truth_embeddings.npy"
    eval_embds_path = "./eval_embeddings.npy"

    fad_score = frechet.score(background_dir, eval_dir,
                              background_embds_path=background_embds_path,
                              eval_embds_path=eval_embds_path, dtype="float32")

    print(fad_score)
music_evaluation/figaro/chord_recognition.py
ADDED
@@ -0,0 +1,247 @@
import numpy as np

class MIDIChord(object):
    def __init__(self, pm):
        self.pm = pm
        # define pitch classes
        self.PITCH_CLASSES = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
        # define chord maps (required)
        self.CHORD_MAPS = {'maj': [0, 4],
                           'min': [0, 3],
                           'dim': [0, 3, 6],
                           'aug': [0, 4, 8],
                           'dom7': [0, 4, 10],
                           'maj7': [0, 4, 11],
                           'min7': [0, 3, 10]}
        # define chord insiders (+10)
        self.CHORD_INSIDERS = {'maj': [7],
                               'min': [7],
                               'dim': [9],
                               'aug': [],
                               'dom7': [7],
                               'maj7': [7],
                               'min7': [7]}
        # define chord outsiders (-1)
        self.CHORD_OUTSIDERS_1 = {'maj': [2, 5, 9],
                                  'min': [2, 5, 8],
                                  'dim': [2, 5, 10],
                                  'aug': [2, 5, 9],
                                  'dom7': [2, 5, 9],
                                  'maj7': [2, 5, 9],
                                  'min7': [2, 5, 8]}
        # define chord outsiders (-2)
        self.CHORD_OUTSIDERS_2 = {'maj': [1, 3, 6, 8, 10, 11],
                                  'min': [1, 4, 6, 9, 11],
                                  'dim': [1, 4, 7, 8, 11],
                                  'aug': [1, 3, 6, 7, 10],
                                  'dom7': [1, 3, 6, 8, 11],
                                  'maj7': [1, 3, 6, 8, 10],
                                  'min7': [1, 4, 6, 9, 11]}

    def sequencing(self, chroma):
        candidates = {}
        for index in range(len(chroma)):
            if chroma[index]:
                root_note = index
                _chroma = np.roll(chroma, -root_note)
                sequence = np.where(_chroma == 1)[0]
                candidates[root_note] = list(sequence)
        return candidates

    def scoring(self, candidates):
        scores = {}
        qualities = {}
        for root_note, sequence in candidates.items():
            if 3 not in sequence and 4 not in sequence:
                scores[root_note] = -100
                qualities[root_note] = 'None'
            elif 3 in sequence and 4 in sequence:
                scores[root_note] = -100
                qualities[root_note] = 'None'
            else:
                # decide quality
                if 3 in sequence:
                    if 6 in sequence:
                        quality = 'dim'
                    else:
                        if 10 in sequence:
                            quality = 'min7'
                        else:
                            quality = 'min'
                elif 4 in sequence:
                    if 8 in sequence:
                        quality = 'aug'
                    else:
                        if 10 in sequence:
                            quality = 'dom7'
                        elif 11 in sequence:
                            quality = 'maj7'
                        else:
                            quality = 'maj'
                # decide score
                maps = self.CHORD_MAPS.get(quality)
                _notes = [n for n in sequence if n not in maps]
                score = 0
                for n in _notes:
                    if n in self.CHORD_OUTSIDERS_1.get(quality):
                        score -= 1
                    elif n in self.CHORD_OUTSIDERS_2.get(quality):
                        score -= 2
                    elif n in self.CHORD_INSIDERS.get(quality):
                        score += 10
                scores[root_note] = score
                qualities[root_note] = quality
        return scores, qualities

    def find_chord(self, chroma, threshold=10):
        chroma = np.sum(chroma, axis=1)
        chroma = np.array([1 if c > threshold else 0 for c in chroma])
        if np.sum(chroma) == 0:
            return 'N', 'N', 'N', 10
        else:
            candidates = self.sequencing(chroma=chroma)
            scores, qualities = self.scoring(candidates=candidates)
            # bass note
            sorted_notes = []
            for i, v in enumerate(chroma):
                if v > 0:
                    sorted_notes.append(int(i % 12))
            bass_note = sorted_notes[0]
            # root note
            __root_note = []
            _max = max(scores.values())
            for _root_note, score in scores.items():
                if score == _max:
                    __root_note.append(_root_note)
            if len(__root_note) == 1:
                root_note = __root_note[0]
            else:
                # TODO: what should i do
                for n in sorted_notes:
                    if n in __root_note:
                        root_note = n
                        break
            # quality
            quality = qualities.get(root_note)
            sequence = candidates.get(root_note)
            # score
            score = scores.get(root_note)
            return self.PITCH_CLASSES[root_note], quality, self.PITCH_CLASSES[bass_note], score

    def greedy(self, candidates, max_tick, min_length):
        chords = []
        # start from 0
        start_tick = 0
        while start_tick < max_tick:
            _candidates = candidates.get(start_tick)
            _candidates = sorted(_candidates.items(), key=lambda x: (x[1][-1], x[0]))
            # choose
            end_tick, (root_note, quality, bass_note, _) = _candidates[-1]
            if root_note == bass_note:
                chord = '{}:{}'.format(root_note, quality)
            else:
                chord = '{}:{}/{}'.format(root_note, quality, bass_note)
            chords.append([start_tick, end_tick, chord])
            start_tick = end_tick
        # remove :None
        temp = chords
        while ':None' in temp[0][-1]:
            try:
                temp[1][0] = temp[0][0]
                del temp[0]
            except:
                print('NO CHORD')
                return []
        temp2 = []
        for chord in temp:
            if ':None' not in chord[-1]:
                temp2.append(chord)
            else:
                temp2[-1][1] = chord[1]
        return temp2

    def dynamic(self, candidates, max_tick, min_length):
        # store index of best chord at each position
        chords = [None for i in range(max_tick + 1)]
        # store score of best chords at each position
        scores = np.zeros(max_tick + 1)
        scores[1:].fill(np.NINF)

        start_tick = 0
        while start_tick < max_tick:
            if start_tick in candidates:
                for i, (end_tick, candidate) in enumerate(candidates.get(start_tick).items()):
                    root_note, quality, bass_note, score = candidate
                    # if this candidate is best yet, update scores and chords
                    if scores[end_tick] < scores[start_tick] + score:
                        scores[end_tick] = scores[start_tick] + score
                        if root_note == bass_note:
                            chord = '{}:{}'.format(root_note, quality)
                        else:
                            chord = '{}:{}/{}'.format(root_note, quality, bass_note)
                        chords[end_tick] = (start_tick, end_tick, chord)
            start_tick += 1
        # Read the best path
        start_tick = len(chords) - 1
        results = []
        while start_tick > 0:
            chord = chords[start_tick]
            start_tick = chord[0]
            results.append(chord)

        return list(reversed(results))

    def dedupe(self, chords):
        if len(chords) == 0:
            return []
        deduped = []
        start, end, chord = chords[0]
        for (curr, next) in zip(chords[:-1], chords[1:]):
            if chord == next[2]:
                end = next[1]
            else:
                deduped.append([start, end, chord])
                start, end, chord = next
        deduped.append([start, end, chord])
        return deduped

    def get_candidates(self, chroma, max_tick, intervals=[1, 2, 3, 4]):
        candidates = {}
        for interval in intervals:
            for start_beat in range(max_tick):
                # set target pianoroll
                end_beat = start_beat + interval
                if end_beat > max_tick:
                    end_beat = max_tick
                _chroma = chroma[:, start_beat:end_beat]
                # find chord
                root_note, quality, bass_note, score = self.find_chord(chroma=_chroma)
                # save
                if start_beat not in candidates:
                    candidates[start_beat] = {}
                    candidates[start_beat][end_beat] = (root_note, quality, bass_note, score)
                else:
                    if end_beat not in candidates[start_beat]:
                        candidates[start_beat][end_beat] = (root_note, quality, bass_note, score)
        return candidates

    def extract(self):
        # read
        beats = self.pm.get_beats()
        chroma = self.pm.get_chroma(times=beats)
        # get lots of candidates
        candidates = self.get_candidates(chroma, max_tick=len(beats))

        # dynamic programming over chord candidates (the greedy variant above is kept for reference)
        chords = self.dynamic(candidates=candidates,
                              max_tick=len(beats),
                              min_length=1)
        chords = self.dedupe(chords)
        for chord in chords:
            chord[0] = beats[chord[0]]
            if chord[1] >= len(beats):
                chord[1] = self.pm.get_end_time()
            else:
                chord[1] = beats[chord[1]]
        return chords
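
A small worked example of the scoring logic above: a chroma window whose active pitch classes are {C, E, G} yields root C with quality `maj`, because root 0 gives intervals [0, 4, 7] (7 is a "chord insider", +10), which outscores every other candidate root:

```python
# Toy check of MIDIChord.find_chord; find_chord() never touches self.pm,
# so passing pm=None is safe here.
import numpy as np
from chord_recognition import MIDIChord

chroma = np.zeros((12, 12))
chroma[[0, 4, 7], :] = 1          # C, E, G active on all 12 beats (row sums exceed threshold=10)
rec = MIDIChord(pm=None)
print(rec.find_chord(chroma))     # ('C', 'maj', 'C', 10)
```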
music_evaluation/figaro/constants.py
ADDED
@@ -0,0 +1,47 @@
import numpy as np

# parameters for input representation
DEFAULT_POS_PER_QUARTER = 12
DEFAULT_VELOCITY_BINS = np.linspace(0, 128, 32+1, dtype=np.int32)
DEFAULT_DURATION_BINS = np.sort(np.concatenate([
    np.arange(1, 13),             # smallest possible units up to 1 quarter
    np.arange(12, 24, 3)[1:],     # 16th notes up to 1 bar
    np.arange(13, 24, 4)[1:],     # triplets up to 1 bar
    np.arange(24, 48, 6),         # 8th notes up to 2 bars
    np.arange(48, 4*48, 12),      # quarter notes up to 8 bars
    np.arange(4*48, 16*48+1, 24)  # half notes up to 16 bars
]))
DEFAULT_TEMPO_BINS = np.linspace(0, 240, 32+1, dtype=np.int32)
DEFAULT_NOTE_DENSITY_BINS = np.linspace(0, 12, 32+1)
DEFAULT_MEAN_VELOCITY_BINS = np.linspace(0, 128, 32+1)
DEFAULT_MEAN_PITCH_BINS = np.linspace(0, 128, 32+1)
DEFAULT_MEAN_DURATION_BINS = np.logspace(0, 7, 32+1, base=2)  # log space between 1 and 128 positions (~2.5 bars)

# parameters for output
DEFAULT_RESOLUTION = 480

# maximum length of a single bar is 3*4 = 12 beats
MAX_BAR_LENGTH = 3
# maximum number of bars in a piece is 512 (this covers almost all sequences)
MAX_N_BARS = 512

PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
BOS_TOKEN = '<bos>'
EOS_TOKEN = '<eos>'
MASK_TOKEN = '<mask>'

TIME_SIGNATURE_KEY = 'Time Signature'
BAR_KEY = 'Bar'
POSITION_KEY = 'Position'
INSTRUMENT_KEY = 'Instrument'
PITCH_KEY = 'Pitch'
VELOCITY_KEY = 'Velocity'
DURATION_KEY = 'Duration'
TEMPO_KEY = 'Tempo'
CHORD_KEY = 'Chord'

NOTE_DENSITY_KEY = 'Note Density'
MEAN_PITCH_KEY = 'Mean Pitch'
MEAN_VELOCITY_KEY = 'Mean Velocity'
MEAN_DURATION_KEY = 'Mean Duration'
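
One plausible way these bin tables get consumed (an illustration under that assumption, not a quote from this repo's tokenizer) is nearest-bin quantization of raw values into token indices:

```python
# Hypothetical sketch: snap a raw value to its nearest bin index.
import numpy as np
from constants import DEFAULT_DURATION_BINS, DEFAULT_TEMPO_BINS

def to_bin(value, bins):
    # index of the closest bin center
    return int(np.argmin(np.abs(bins - value)))

print(to_bin(17, DEFAULT_DURATION_BINS))  # nearest duration bin index
print(to_bin(121, DEFAULT_TEMPO_BINS))    # nearest tempo bin index (~120 BPM)
```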
music_evaluation/figaro/evaluate.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, glob
|
2 |
+
from statistics import NormalDist
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import input_representation as ir
|
7 |
+
|
8 |
+
SAMPLE_DIR = os.getenv('SAMPLE_DIR', './samples')
|
9 |
+
OUT_FILE = os.getenv('OUT_FILE', './metrics.csv')
|
10 |
+
MAX_SAMPLES = int(os.getenv('MAX_SAMPLES', 1024))
|
11 |
+
# use to find base file name when generate multiple files for 1 gt file
|
12 |
+
SPLIT_STR = os.getenv('SPLIT_STR', None)
|
13 |
+
POST_STR = os.getenv('POST_STR', None)
|
14 |
+
|
15 |
+
METRICS = [
|
16 |
+
'inst_prec', 'inst_rec', 'inst_f1',
|
17 |
+
'chord_prec', 'chord_rec', 'chord_f1',
|
18 |
+
'time_sig_acc',
|
19 |
+
'note_dens_oa', 'pitch_oa', 'velocity_oa', 'duration_oa',
|
20 |
+
'chroma_crossent', 'chroma_kldiv', 'chroma_sim',
|
21 |
+
'groove_crossent', 'groove_kldiv', 'groove_sim',
|
22 |
+
]
|
23 |
+
|
24 |
+
DF_KEYS = ['id', 'original', 'sample'] + METRICS
|
25 |
+
|
26 |
+
keys = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
|
27 |
+
qualities = ['maj', 'min', 'dim', 'aug', 'dom7', 'maj7', 'min7', 'None']
|
28 |
+
CHORDS = [f"{k}:{q}" for k in keys for q in qualities] + ['N:N']
|
29 |
+
|
30 |
+
def get_group_id(file):
|
31 |
+
# change this depending on name of generated samples
|
32 |
+
name = os.path.basename(file)
|
33 |
+
return name.split('.')[0]
|
34 |
+
|
35 |
+
def get_base_name(file):
|
36 |
+
base_file_name = os.path.basename(file)
|
37 |
+
base_name = base_file_name.split(SPLIT_STR)[0]
|
38 |
+
if POST_STR is not None:
|
39 |
+
gt_name = base_name + POST_STR
|
40 |
+
else:
|
41 |
+
gt_name = base_name
|
42 |
+
return gt_name
|
43 |
+
|
44 |
+
def get_file_groups(path, max_samples=MAX_SAMPLES):
|
45 |
+
# change this depending on file structure of generated samples
|
46 |
+
files = glob.glob(os.path.join(path, '*.mid'), recursive=True) + glob.glob(os.path.join(path, '*.midi'), recursive=True)
|
47 |
+
assert len(files), f"provided directory was empty: {path}"
|
48 |
+
|
49 |
+
samples = sorted(files)
|
50 |
+
origs = sorted([os.path.join(path, 'gt', get_base_name(file)) for file in files])
|
51 |
+
pairs = list(zip(origs, samples))
|
52 |
+
|
53 |
+
pairs = list(filter(lambda pair: os.path.exists(pair[0]), pairs))
|
54 |
+
if max_samples > 0:
|
55 |
+
pairs = pairs[:max_samples]
|
56 |
+
|
57 |
+
groups = dict()
|
58 |
+
for orig, sample in pairs:
|
59 |
+
sample_id = get_group_id(sample)
|
60 |
+
orig_id = get_group_id(orig)
|
61 |
+
if orig_id not in groups:
|
62 |
+
groups[orig_id] = list()
|
63 |
+
groups[orig_id].append((orig, sample))
|
64 |
+
|
65 |
+
return list(groups.values())
|
66 |
+
|
67 |
+
def read_file(file):
|
68 |
+
with open(file, 'r') as f:
|
69 |
+
events = f.read().split('\n')
|
70 |
+
events = [e for e in events if e]
|
71 |
+
return events
|
72 |
+
|
73 |
+
def get_chord_groups(desc):
|
74 |
+
bars = [1 if 'Bar_' in item else 0 for item in desc]
|
75 |
+
bar_ids = np.cumsum(bars) - 1
|
76 |
+
groups = [[] for _ in range(bar_ids[-1] + 1)]
|
77 |
+
for i, item in enumerate(desc):
|
78 |
+
if 'Chord_' in item:
|
79 |
+
chord = item.split('_')[-1]
|
80 |
+
groups[bar_ids[i]].append(chord)
|
81 |
+
return groups
|
82 |
+
|
83 |
+
def instruments(events):
|
84 |
+
insts = [128 if item.instrument == 'drum' else int(item.instrument) for item in events[1:-1] if item.name == 'Note']
|
85 |
+
insts = np.bincount(insts, minlength=129)
|
86 |
+
return (insts > 0).astype(int)
|
87 |
+
|
88 |
+
def chords(events):
|
89 |
+
chords = [CHORDS.index(item) for item in events]
|
90 |
+
chords = np.bincount(chords, minlength=129)
|
91 |
+
return (chords > 0).astype(int)
|
92 |
+
|
93 |
+
def chroma(events):
|
94 |
+
pitch_classes = [item.pitch % 12 for item in events[1:-1] if item.name == 'Note' and item.instrument != 'drum']
|
95 |
+
if len(pitch_classes):
|
96 |
+
count = np.bincount(pitch_classes, minlength=12)
|
97 |
+
count = count / np.sqrt(np.sum(count ** 2))
|
98 |
+
else:
|
99 |
+
count = np.array([1/12] * 12)
|
100 |
+
return count
|
101 |
+
|
102 |
+
def groove(events, start=0, pos_per_bar=48, ticks_per_bar=1920):
|
103 |
+
flags = np.linspace(start, start + ticks_per_bar, pos_per_bar, endpoint=False)
|
104 |
+
onsets = [item.start for item in events[1:-1] if item.name == 'Note']
|
105 |
+
positions = [np.argmin(np.abs(flags - beat)) for beat in onsets]
|
106 |
+
if len(positions):
|
107 |
+
count = np.bincount(positions, minlength=pos_per_bar)
|
108 |
+
count = np.convolve(count, [1, 4, 1], 'same')
|
109 |
+
count = count / np.sqrt(np.sum(count ** 2))
|
110 |
+
else:
|
111 |
+
count = np.array([1/pos_per_bar] * pos_per_bar)
|
112 |
+
return count
|
113 |
+
|
114 |
+
def multi_class_accuracy(y_true, y_pred):
|
115 |
+
tp = ((y_true == 1) & (y_pred == 1)).sum()
|
116 |
+
p = tp / y_pred.sum()
|
117 |
+
r = tp / y_true.sum()
|
118 |
+
if p + r > 0:
|
119 |
+
f1 = 2*p*r / (p + r)
|
120 |
+
else:
|
121 |
+
f1 = 0
|
122 |
+
return p, r, f1
|
123 |
+
|
124 |
+
def cross_entropy(p_true, p_pred, eps=1e-8):
|
125 |
+
return -np.sum(p_true * np.log(p_pred + eps)) / len(p_true)
|
126 |
+
|
127 |
+
def kl_divergence(p_true, p_pred, eps=1e-8):
|
128 |
+
return np.sum(p_true * (np.log(p_true + eps) - np.log(p_pred + eps))) / len(p_true)
|
129 |
+
|
130 |
+
def cosine_sim(p_true, p_pred):
|
131 |
+
return np.sum(p_true * p_pred)
|
132 |
+
|
133 |
+
def sliding_window_metrics(items, start, end, window=1920, step=480, ticks_per_beat=480):
|
134 |
+
glob_start, glob_end = start, end
|
135 |
+
notes = [item for item in items if item.name == 'Note']
|
136 |
+
starts = np.arange(glob_start, glob_end - window, step=step)
|
137 |
+
|
138 |
+
groups = []
|
139 |
+
start_idx, end_idx = 0, 0
|
140 |
+
for start in starts:
|
141 |
+
while notes[start_idx].start < start:
|
142 |
+
start_idx += 1
|
143 |
+
while end_idx < len(notes) and notes[end_idx].start < start + window:
|
144 |
+
end_idx += 1
|
145 |
+
|
146 |
+
groups.append([start] + notes[start_idx:end_idx] + [start + window])
|
147 |
+
return groups
|
148 |
+
|
149 |
+
def meta_stats(group, ticks_per_beat=480):
|
150 |
+
start, end = group[0], group[-1]
|
151 |
+
ns = [item for item in group[1:-1] if item.name == 'Note']
|
152 |
+
ns_ = [note for note in ns if note.instrument != 'drum']
|
153 |
+
pitches = [note.pitch for note in ns_]
|
154 |
+
vels = [note.velocity for note in ns_]
|
155 |
+
durs = [(note.end - note.start) / ticks_per_beat for note in ns_]
|
156 |
+
|
157 |
+
return {
|
158 |
+
'note_density': len(ns) / ((end - start) / ticks_per_beat),
|
159 |
+
'pitch_mean': np.mean(pitches) if len(pitches) else np.nan,
|
160 |
+
'velocity_mean': np.mean(vels) if len(vels) else np.nan,
|
161 |
+
'duration_mean': np.mean(durs) if len(durs) else np.nan,
|
162 |
+
'pitch_std': np.std(pitches) if len(pitches) else np.nan,
|
163 |
+
'velocity_std': np.std(vels) if len(vels) else np.nan,
|
164 |
+
'duration_std': np.std(durs) if len(durs) else np.nan,
|
165 |
+
}
|
166 |
+
|
167 |
+
def overlapping_area(mu1, sigma1, mu2, sigma2, eps=0.01):
|
168 |
+
sigma1, sigma2 = max(eps, sigma1), max(eps, sigma2)
|
169 |
+
return NormalDist(mu=mu1, sigma=sigma1).overlap(NormalDist(mu=mu2, sigma=sigma2))
|
170 |
+
|
171 |
+
|
172 |
+
|
173 |
+
def main():
|
174 |
+
file_groups = get_file_groups(SAMPLE_DIR)
|
175 |
+
|
176 |
+
metrics = pd.DataFrame()
|
177 |
+
for sample_id, group in enumerate(file_groups):
|
178 |
+
print(f"[info] Group {sample_id + 1}/{len(file_groups)}")
|
179 |
+
micro_metrics = pd.DataFrame()
|
180 |
+
for orig_file, sample_file in group:
|
181 |
+
print(f"original: {orig_file.split('/')[-1]} | sample: {sample_file.split('/')[-1]}")
|
182 |
+
orig = ir.InputRepresentation(orig_file)
|
183 |
+
sample = ir.InputRepresentation(sample_file)
|
184 |
+
|
185 |
+
orig_desc, sample_desc = orig.get_description(), sample.get_description()
|
186 |
+
if len(orig_desc) == 0 or len(sample_desc) == 0:
|
187 |
+
print("[warning] empty sample! skipping")
|
188 |
+
continue
|
189 |
+
|
190 |
+
chord_groups1 = get_chord_groups(orig_desc)
|
191 |
+
chord_groups2 = get_chord_groups(sample_desc)
|
192 |
+
|
193 |
+
note_density_gt = []
|
194 |
+
|
195 |
+
for g1, g2, cg1, cg2 in zip(orig.groups, sample.groups, chord_groups1, chord_groups2):
|
196 |
+
            row = pd.DataFrame([{'id': sample_id, 'original': orig_file.split('/')[-1], 'sample': sample_file.split('/')[-1]}])

            meta1, meta2 = meta_stats(g1, ticks_per_beat=orig.pm.resolution), meta_stats(g2, ticks_per_beat=sample.pm.resolution)
            row['pitch_oa'] = overlapping_area(meta1['pitch_mean'], meta1['pitch_std'], meta2['pitch_mean'], meta2['pitch_std'])
            row['velocity_oa'] = overlapping_area(meta1['velocity_mean'], meta1['velocity_std'], meta2['velocity_mean'], meta2['velocity_std'])
            row['duration_oa'] = overlapping_area(meta1['duration_mean'], meta1['duration_std'], meta2['duration_mean'], meta2['duration_std'])
            row['note_density_abs_err'] = np.abs(meta1['note_density'] - meta2['note_density'])
            row['mean_pitch_abs_err'] = np.abs(meta1['pitch_mean'] - meta2['pitch_mean'])
            row['mean_velocity_abs_err'] = np.abs(meta1['velocity_mean'] - meta2['velocity_mean'])
            row['mean_duration_abs_err'] = np.abs(meta1['duration_mean'] - meta2['duration_mean'])
            note_density_gt.append(meta1['note_density'])

            ts1, ts2 = orig._get_time_signature(g1[0]), sample._get_time_signature(g2[0])
            ts1, ts2 = f"{ts1.numerator}/{ts1.denominator}", f"{ts2.numerator}/{ts2.denominator}"
            row['time_sig_acc'] = 1 if ts1 == ts2 else 0

            inst1, inst2 = instruments(g1), instruments(g2)
            prec, rec, f1 = multi_class_accuracy(inst1, inst2)
            row['inst_prec'] = prec
            row['inst_rec'] = rec
            row['inst_f1'] = f1

            chords1, chords2 = chords(cg1), chords(cg2)
            prec, rec, f1 = multi_class_accuracy(chords1, chords2)
            row['chord_prec'] = prec
            row['chord_rec'] = rec
            row['chord_f1'] = f1

            c1, c2 = chroma(g1), chroma(g2)
            row['chroma_crossent'] = cross_entropy(c1, c2)
            row['chroma_kldiv'] = kl_divergence(c1, c2)
            row['chroma_sim'] = cosine_sim(c1, c2)

            ppb = max(orig._get_positions_per_bar(g1[0]), sample._get_positions_per_bar(g2[0]))
            tpb = max(orig._get_ticks_per_bar(g1[0]), sample._get_ticks_per_bar(g2[0]))
            r1 = groove(g1, start=g1[0], pos_per_bar=ppb, ticks_per_bar=tpb)
            r2 = groove(g2, start=g2[0], pos_per_bar=ppb, ticks_per_bar=tpb)
            row['groove_crossent'] = cross_entropy(r1, r2)
            row['groove_kldiv'] = kl_divergence(r1, r2)
            row['groove_sim'] = cosine_sim(r1, r2)

            micro_metrics = pd.concat([micro_metrics, row], ignore_index=True)
        if len(micro_metrics) == 0:
            continue

        nd_mean = np.mean(note_density_gt)
        micro_metrics['note_density_nsq_err'] = micro_metrics['note_density_abs_err']**2 / nd_mean**2

        metrics = pd.concat([metrics, micro_metrics], ignore_index=True)

        micro_avg = micro_metrics.mean(numeric_only=True)
        print("[info] Group {}: inst_f1={:.2f} | chord_f1={:.2f} | pitch_oa={:.2f} | vel_oa={:.2f} | dur_oa={:.2f} | chroma_sim={:.2f} | groove_sim={:.2f}".format(
            sample_id + 1, micro_avg['inst_f1'], micro_avg['chord_f1'], micro_avg['pitch_oa'], micro_avg['velocity_oa'], micro_avg['duration_oa'], micro_avg['chroma_sim'], micro_avg['groove_sim']
        ))

    os.makedirs(os.path.dirname(OUT_FILE), exist_ok=True)
    # metrics.to_csv(OUT_FILE, index=False)

    summary_keys = ['inst_f1', 'chord_f1', 'time_sig_acc', 'pitch_oa', 'velocity_oa', 'duration_oa', 'chroma_sim', 'groove_sim']
    summary = metrics[summary_keys + ['id']].groupby('id').mean().mean()

    nsq_err = metrics.groupby('id')['note_density_nsq_err'].mean()
    summary['note_density_nrmse'] = np.sqrt(nsq_err).mean()

    print('***** SUMMARY *****')
    print(summary)

    summary.to_frame().T.to_csv(OUT_FILE, index=False)

    print("done")

if __name__ == '__main__':
    main()
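For context, `overlapping_area` (defined earlier in `evaluate.py`, outside this hunk) scores how well two per-bar feature distributions match. A minimal sketch of the usual definition, the shared area under two Gaussians fitted to the ground-truth and generated features, is given below; the closed form is an assumption for illustration, not necessarily the exact implementation in this file.

```python
import numpy as np
from scipy.stats import norm

def overlapping_area_sketch(mu1, std1, mu2, std2):
    """Hypothetical stand-in for overlapping_area(): the area shared by
    N(mu1, std1^2) and N(mu2, std2^2)."""
    if mu2 < mu1:  # order the two Gaussians by mean
        mu1, std1, mu2, std2 = mu2, std2, mu1, std1
    std1, std2 = max(std1, 1e-6), max(std2, 1e-6)  # guard against zero spread
    if std1 == std2:
        c = (mu1 + mu2) / 2  # equal spreads: densities cross at the midpoint
    else:
        # solve for the crossing point of the two density functions
        a = 1 / std1**2 - 1 / std2**2
        b = 2 * (mu2 / std2**2 - mu1 / std1**2)
        d = (mu1 / std1)**2 - (mu2 / std2)**2 - 2 * np.log(std2 / std1)
        roots = np.roots([a, b, d])
        c = roots[np.argmin(np.abs(roots - (mu1 + mu2) / 2))].real
    # overlap = right tail of the lower Gaussian + left tail of the upper one
    return norm.sf(c, mu1, std1) + norm.cdf(c, mu2, std2)

print(overlapping_area_sketch(60, 5, 60, 5))  # 1.0: identical distributions
print(overlapping_area_sketch(60, 5, 72, 5))  # well below 1: little overlap
```

Identical distributions give an overlap of 1 and the score decays toward 0 as they drift apart, which is why higher `pitch_oa`, `velocity_oa`, and `duration_oa` values indicate a closer match.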
music_evaluation/figaro/input_representation.py
ADDED
@@ -0,0 +1,655 @@
from chord_recognition import MIDIChord
import numpy as np
import pretty_midi

from vocab import RemiVocab

from constants import (
    EOS_TOKEN,
    # vocab keys
    TIME_SIGNATURE_KEY,
    BAR_KEY,
    POSITION_KEY,
    INSTRUMENT_KEY,
    PITCH_KEY,
    VELOCITY_KEY,
    DURATION_KEY,
    TEMPO_KEY,
    CHORD_KEY,
    NOTE_DENSITY_KEY,
    MEAN_PITCH_KEY,
    MEAN_VELOCITY_KEY,
    MEAN_DURATION_KEY,
    # discretization parameters
    DEFAULT_POS_PER_QUARTER,
    DEFAULT_VELOCITY_BINS,
    DEFAULT_DURATION_BINS,
    DEFAULT_TEMPO_BINS,
    DEFAULT_NOTE_DENSITY_BINS,
    DEFAULT_MEAN_VELOCITY_BINS,
    DEFAULT_MEAN_PITCH_BINS,
    DEFAULT_MEAN_DURATION_BINS,
    DEFAULT_RESOLUTION
)

# define "Item" for general storage
class Item(object):
    def __init__(self, name, start, end, velocity=None, pitch=None, instrument=None):
        self.name = name
        self.start = start
        self.end = end
        self.velocity = velocity
        self.pitch = pitch
        self.instrument = instrument

    def __repr__(self):
        return 'Item(name={}, start={}, end={}, velocity={}, pitch={}, instrument={})'.format(
            self.name, self.start, self.end, self.velocity, self.pitch, self.instrument)

# define "Event" for event storage
class Event(object):
    def __init__(self, name, time, value, text):
        self.name = name
        self.time = time
        self.value = value
        self.text = text

    def __repr__(self):
        return 'Event(name={}, time={}, value={}, text={})'.format(
            self.name, self.time, self.value, self.text)

class InputRepresentation():
    @staticmethod
    def version():
        return 'v4'

    def __init__(self, file, do_extract_chords=True, strict=False):
        if isinstance(file, pretty_midi.PrettyMIDI):
            self.pm = file
        else:
            self.pm = pretty_midi.PrettyMIDI(file)

        if strict and len(self.pm.time_signature_changes) == 0:
            raise ValueError("Invalid MIDI file: No time signature defined")

        self.resolution = self.pm.resolution

        self.note_items = None
        self.tempo_items = None
        self.chords = None
        self.groups = None

        self._read_items()
        self._quantize_items()
        if do_extract_chords:
            self.extract_chords()
        self._group_items()

        if strict and len(self.note_items) == 0:
            raise ValueError("Invalid MIDI file: No notes found, empty file.")

    # read notes and tempo changes from midi (assume there is only one track)
    def _read_items(self):
        # notes
        self.note_items = []
        for instrument in self.pm.instruments:
            pedal_events = [event for event in instrument.control_changes if event.number == 64]
            pedal_pressed = False
            start = None
            pedals = []
            for e in pedal_events:
                if e.value >= 64 and not pedal_pressed:
                    pedal_pressed = True
                    start = e.time
                elif e.value < 64 and pedal_pressed:
                    pedal_pressed = False
                    pedals.append(Item(name='Pedal', start=start, end=e.time))
                    start = e.time

            notes = instrument.notes
            notes.sort(key=lambda x: (x.start, x.pitch))

            if instrument.is_drum:
                instrument_name = 'drum'
            else:
                instrument_name = instrument.program

            pedal_idx = 0
            for note in notes:
                pedal_candidates = [(i + pedal_idx, pedal) for i, pedal in enumerate(pedals[pedal_idx:]) if note.end >= pedal.start and note.start < pedal.end]
                if len(pedal_candidates) > 0:
                    pedal_idx = pedal_candidates[0][0]
                    pedal = pedal_candidates[-1][1]
                else:
                    pedal = Item(name='Pedal', start=0, end=0)

                self.note_items.append(Item(
                    name='Note',
                    start=self.pm.time_to_tick(note.start),
                    end=self.pm.time_to_tick(max(note.end, pedal.end)),
                    velocity=note.velocity,
                    pitch=note.pitch,
                    instrument=instrument_name))
        self.note_items.sort(key=lambda x: (x.start, x.pitch))
        # tempo
        self.tempo_items = []
        times, tempi = self.pm.get_tempo_changes()
        for time, tempo in zip(times, tempi):
            self.tempo_items.append(Item(
                name='Tempo',
                start=time,
                end=None,
                velocity=None,
                pitch=int(tempo)))
        self.tempo_items.sort(key=lambda x: x.start)
        # expand to all beats
        max_tick = self.pm.time_to_tick(self.pm.get_end_time())
        existing_ticks = {item.start: item.pitch for item in self.tempo_items}
        wanted_ticks = np.arange(0, max_tick + 1, DEFAULT_RESOLUTION)
        output = []
        for tick in wanted_ticks:
            if tick in existing_ticks:
                output.append(Item(
                    name='Tempo',
                    start=self.pm.time_to_tick(tick),
                    end=None,
                    velocity=None,
                    pitch=existing_ticks[tick]))
            else:
                output.append(Item(
                    name='Tempo',
                    start=self.pm.time_to_tick(tick),
                    end=None,
                    velocity=None,
                    pitch=output[-1].pitch))
        self.tempo_items = output

    # quantize items
    def _quantize_items(self):
        ticks = self.resolution / DEFAULT_POS_PER_QUARTER
        # grid
        end_tick = self.pm.time_to_tick(self.pm.get_end_time())
        grids = np.arange(0, max(self.resolution, end_tick), ticks)
        # process
        for item in self.note_items:
            index = np.searchsorted(grids, item.start, side='right')
            if index > 0:
                index -= 1
            shift = round(grids[index]) - item.start
            item.start += shift
            item.end += shift
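As an aside, the grid arithmetic in `_quantize_items` above can be checked in isolation. A minimal sketch with invented numbers (480 ticks per quarter note, `DEFAULT_POS_PER_QUARTER = 12`, so a 40-tick grid): each note start is snapped to the grid point at or before it, and the whole note is shifted rigidly.

```python
import numpy as np

# Illustrative values only: 480 ticks per quarter, 12 grid positions per quarter.
resolution, pos_per_quarter = 480, 12
step = resolution / pos_per_quarter          # 40 ticks between grid points
grids = np.arange(0, 4 * resolution, step)   # grid covering one 4/4 bar

start = 1735                                 # raw note-on tick
index = np.searchsorted(grids, start, side='right') - 1
shift = round(grids[index]) - start          # grids[index] == 1720, so shift == -15
print(start + shift)                         # 1720: snapped to the preceding grid point
```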
    def get_end_tick(self):
        return self.pm.time_to_tick(self.pm.get_end_time())

    # extract chords
    def extract_chords(self):
        end_tick = self.pm.time_to_tick(self.pm.get_end_time())
        if end_tick < self.resolution:
            # If the sequence is shorter than a quarter note, it's probably empty
            self.chords = []
            return self.chords
        method = MIDIChord(self.pm)
        chords = method.extract()
        output = []
        for chord in chords:
            output.append(Item(
                name='Chord',
                start=self.pm.time_to_tick(chord[0]),
                end=self.pm.time_to_tick(chord[1]),
                velocity=None,
                pitch=chord[2].split('/')[0]))
        if len(output) == 0 or output[0].start > 0:
            if len(output) == 0:
                end = self.pm.time_to_tick(self.pm.get_end_time())
            else:
                end = output[0].start
            output.append(Item(
                name='Chord',
                start=0,
                end=end,
                velocity=None,
                pitch='N:N'
            ))
        self.chords = output
        return self.chords

    # group items
    def _group_items(self):
        if self.chords:
            items = self.chords + self.tempo_items + self.note_items
        else:
            items = self.tempo_items + self.note_items

        def _get_key(item):
            type_priority = {
                'Chord': 0,
                'Tempo': 1,
                'Note': 2
            }
            return (
                item.start,  # order by time
                type_priority[item.name],  # chord events first, then tempo events, then note events
                -1 if item.instrument == 'drum' else item.instrument,  # order by instrument
                item.pitch  # order by note pitch
            )

        items.sort(key=_get_key)
        downbeats = self.pm.get_downbeats()
        downbeats = np.concatenate([downbeats, [self.pm.get_end_time()]])
        self.groups = []
        for db1, db2 in zip(downbeats[:-1], downbeats[1:]):
            db1, db2 = self.pm.time_to_tick(db1), self.pm.time_to_tick(db2)
            insiders = []
            for item in items:
                if (item.start >= db1) and (item.start < db2):
                    insiders.append(item)
            overall = [db1] + insiders + [db2]
            self.groups.append(overall)

        # Trim empty groups from the beginning and end
        for idx in [0, -1]:
            while len(self.groups) > 0:
                group = self.groups[idx]
                notes = [item for item in group[1:-1] if item.name == 'Note']
                if len(notes) == 0:
                    self.groups.pop(idx)
                else:
                    break

        return self.groups

    def _get_time_signature(self, start):
        # This method assumes that time signature changes don't happen within a bar,
        # which is a convention that commonly holds
        time_sig = None
        for curr_sig, next_sig in zip(self.pm.time_signature_changes[:-1], self.pm.time_signature_changes[1:]):
            if self.pm.time_to_tick(curr_sig.time) <= start and self.pm.time_to_tick(next_sig.time) > start:
                time_sig = curr_sig
                break
        if time_sig is None:
            time_sig = self.pm.time_signature_changes[-1]
        return time_sig

    def _get_ticks_per_bar(self, start):
        time_sig = self._get_time_signature(start)
        quarters_per_bar = 4 * time_sig.numerator / time_sig.denominator
        return self.pm.resolution * quarters_per_bar

    def _get_positions_per_bar(self, start=None, time_sig=None):
        if time_sig is None:
            time_sig = self._get_time_signature(start)
        quarters_per_bar = 4 * time_sig.numerator / time_sig.denominator
        positions_per_bar = int(DEFAULT_POS_PER_QUARTER * quarters_per_bar)
        return positions_per_bar

    def tick_to_position(self, tick):
        return round(tick / self.pm.resolution * DEFAULT_POS_PER_QUARTER)

    # item to event
    def get_remi_events(self):
        events = []
        n_downbeat = 0
        current_chord = None
        current_tempo = None
        for i in range(len(self.groups)):
            bar_st, bar_et = self.groups[i][0], self.groups[i][-1]
            n_downbeat += 1
            positions_per_bar = self._get_positions_per_bar(bar_st)
            if positions_per_bar <= 0:
                raise ValueError('Invalid REMI file: There must be at least 1 position per bar.')

            events.append(Event(
                name=BAR_KEY,
                time=None,
                value='{}'.format(n_downbeat),
                text='{}'.format(n_downbeat)))

            time_sig = self._get_time_signature(bar_st)
            events.append(Event(
                name=TIME_SIGNATURE_KEY,
                time=None,
                value='{}/{}'.format(time_sig.numerator, time_sig.denominator),
                text='{}/{}'.format(time_sig.numerator, time_sig.denominator)
            ))

            if current_chord is not None:
                events.append(Event(
                    name=POSITION_KEY,
                    time=0,
                    value='{}'.format(0),
                    text='{}/{}'.format(1, positions_per_bar)))
                events.append(Event(
                    name=CHORD_KEY,
                    time=current_chord.start,
                    value=current_chord.pitch,
                    text='{}'.format(current_chord.pitch)))

            if current_tempo is not None:
                events.append(Event(
                    name=POSITION_KEY,
                    time=0,
                    value='{}'.format(0),
                    text='{}/{}'.format(1, positions_per_bar)))
                tempo = current_tempo.pitch
                index = np.argmin(abs(DEFAULT_TEMPO_BINS - tempo))
                events.append(Event(
                    name=TEMPO_KEY,
                    time=current_tempo.start,
                    value=index,
                    text='{}/{}'.format(tempo, DEFAULT_TEMPO_BINS[index])))

            quarters_per_bar = 4 * time_sig.numerator / time_sig.denominator
            ticks_per_bar = self.pm.resolution * quarters_per_bar
            flags = np.linspace(bar_st, bar_st + ticks_per_bar, positions_per_bar, endpoint=False)
            for item in self.groups[i][1:-1]:
                # position
                index = np.argmin(abs(flags - item.start))
                pos_event = Event(
                    name=POSITION_KEY,
                    time=item.start,
                    value='{}'.format(index),
                    text='{}/{}'.format(index + 1, positions_per_bar))

                if item.name == 'Note':
                    events.append(pos_event)
                    # instrument
                    if item.instrument == 'drum':
                        name = 'drum'
                    else:
                        name = pretty_midi.program_to_instrument_name(item.instrument)
                    events.append(Event(
                        name=INSTRUMENT_KEY,
                        time=item.start,
                        value=name,
                        text='{}'.format(name)))
                    # pitch
                    events.append(Event(
                        name=PITCH_KEY,
                        time=item.start,
                        value='drum_{}'.format(item.pitch) if name == 'drum' else item.pitch,
                        text='{}'.format(pretty_midi.note_number_to_name(item.pitch))))
                    # velocity
                    velocity_index = np.argmin(abs(DEFAULT_VELOCITY_BINS - item.velocity))
                    events.append(Event(
                        name=VELOCITY_KEY,
                        time=item.start,
                        value=velocity_index,
                        text='{}/{}'.format(item.velocity, DEFAULT_VELOCITY_BINS[velocity_index])))
                    # duration
                    duration = self.tick_to_position(item.end - item.start)
                    index = np.argmin(abs(DEFAULT_DURATION_BINS - duration))
                    events.append(Event(
                        name=DURATION_KEY,
                        time=item.start,
                        value=index,
                        text='{}/{}'.format(duration, DEFAULT_DURATION_BINS[index])))
                elif item.name == 'Chord':
                    if current_chord is None or item.pitch != current_chord.pitch:
                        events.append(pos_event)
                        events.append(Event(
                            name=CHORD_KEY,
                            time=item.start,
                            value=item.pitch,
                            text='{}'.format(item.pitch)))
                        current_chord = item
                elif item.name == 'Tempo':
                    if current_tempo is None or item.pitch != current_tempo.pitch:
                        events.append(pos_event)
                        tempo = item.pitch
                        index = np.argmin(abs(DEFAULT_TEMPO_BINS - tempo))
                        events.append(Event(
                            name=TEMPO_KEY,
                            time=item.start,
                            value=index,
                            text='{}/{}'.format(tempo, DEFAULT_TEMPO_BINS[index])))
                        current_tempo = item

        return [f'{e.name}_{e.value}' for e in events]

    def get_description(self,
                        omit_time_sig=False,
                        omit_instruments=False,
                        omit_chords=False,
                        omit_meta=False):
        events = []
        n_downbeat = 0
        current_chord = None

        for i in range(len(self.groups)):
            bar_st, bar_et = self.groups[i][0], self.groups[i][-1]
            n_downbeat += 1
            time_sig = self._get_time_signature(bar_st)
            positions_per_bar = self._get_positions_per_bar(time_sig=time_sig)
            if positions_per_bar <= 0:
                raise ValueError('Invalid REMI file: There must be at least 1 position in each bar.')

            events.append(Event(
                name=BAR_KEY,
                time=None,
                value='{}'.format(n_downbeat),
                text='{}'.format(n_downbeat)))

            if not omit_time_sig:
                events.append(Event(
                    name=TIME_SIGNATURE_KEY,
                    time=None,
                    value='{}/{}'.format(time_sig.numerator, time_sig.denominator),
                    text='{}/{}'.format(time_sig.numerator, time_sig.denominator),
                ))

            # hoisted out of the meta block so the instrument block below
            # also works when omit_meta is True
            notes = [item for item in self.groups[i][1:-1] if item.name == 'Note']

            if not omit_meta:
                n_notes = len(notes)
                velocities = np.array([item.velocity for item in notes])
                pitches = np.array([item.pitch for item in notes])
                durations = np.array([item.end - item.start for item in notes])

                note_density = n_notes / positions_per_bar
                index = np.argmin(abs(DEFAULT_NOTE_DENSITY_BINS - note_density))
                events.append(Event(
                    name=NOTE_DENSITY_KEY,
                    time=None,
                    value=index,
                    text='{:.2f}/{:.2f}'.format(note_density, DEFAULT_NOTE_DENSITY_BINS[index])
                ))

                # will be NaN if there are no notes
                mean_velocity = velocities.mean() if len(velocities) > 0 else np.nan
                index = np.argmin(abs(DEFAULT_MEAN_VELOCITY_BINS - mean_velocity))
                events.append(Event(
                    name=MEAN_VELOCITY_KEY,
                    time=None,
                    value=index if not np.isnan(mean_velocity) else 'NaN',
                    text='{:.2f}/{:.2f}'.format(mean_velocity, DEFAULT_MEAN_VELOCITY_BINS[index])
                ))

                # will be NaN if there are no notes
                mean_pitch = pitches.mean() if len(pitches) > 0 else np.nan
                index = np.argmin(abs(DEFAULT_MEAN_PITCH_BINS - mean_pitch))
                events.append(Event(
                    name=MEAN_PITCH_KEY,
                    time=None,
                    value=index if not np.isnan(mean_pitch) else 'NaN',
                    text='{:.2f}/{:.2f}'.format(mean_pitch, DEFAULT_MEAN_PITCH_BINS[index])
                ))

                # will be NaN if there are no notes
                mean_duration = durations.mean() if len(durations) > 0 else np.nan
                index = np.argmin(abs(DEFAULT_MEAN_DURATION_BINS - mean_duration))
                events.append(Event(
                    name=MEAN_DURATION_KEY,
                    time=None,
                    value=index if not np.isnan(mean_duration) else 'NaN',
                    text='{:.2f}/{:.2f}'.format(mean_duration, DEFAULT_MEAN_DURATION_BINS[index])
                ))

            if not omit_instruments:
                instruments = set([item.instrument for item in notes])
                for instrument in instruments:
                    instrument = pretty_midi.program_to_instrument_name(instrument) if instrument != 'drum' else 'drum'
                    events.append(Event(
                        name=INSTRUMENT_KEY,
                        time=None,
                        value=instrument,
                        text=instrument
                    ))

            if not omit_chords:
                chords = [item for item in self.groups[i][1:-1] if item.name == 'Chord']
                if len(chords) == 0 and current_chord is not None:
                    chords = [current_chord]
                elif len(chords) > 0:
                    if chords[0].start > bar_st and current_chord is not None:
                        chords.insert(0, current_chord)
                    current_chord = chords[-1]

                for chord in chords:
                    events.append(Event(
                        name=CHORD_KEY,
                        time=None,
                        value=chord.pitch,
                        text='{}'.format(chord.pitch)
                    ))

        return [f'{e.name}_{e.value}' for e in events]

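Before the MIDI-writing half of the file below, a minimal usage sketch of the class above; the input path is a placeholder and the exact token spellings depend on the key constants in `constants.py`.

```python
# Minimal usage sketch; 'song.mid' is a hypothetical input path.
rep = InputRepresentation('song.mid', do_extract_chords=True)
remi_events = rep.get_remi_events()  # flat REMI token strings for the whole piece
description = rep.get_description()  # bar-level description tokens (meta stats, chords)
print(remi_events[:10])
print(description[:10])
```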
#############################################################################################
# WRITE MIDI
#############################################################################################

def remi2midi(events, bpm=120, time_signature=(4, 4), polyphony_limit=16):
    vocab = RemiVocab()

    # unused earlier version of _get_time, shadowed by the definition below:
    # def _get_time(bar, position, bpm=120, positions_per_bar=48):
    #     abs_position = bar * positions_per_bar + position
    #     beat = abs_position / DEFAULT_POS_PER_QUARTER
    #     return beat / bpm * 60

    def _get_time(reference, bar, pos):
        time_sig = reference['time_sig']
        num, denom = time_sig.numerator, time_sig.denominator
        # Quarters per bar, assuming 4 quarters per whole note
        qpb = 4 * num / denom
        ref_pos = reference['pos']
        d_bars = bar - ref_pos[0]
        d_pos = (pos - ref_pos[1]) + d_bars * qpb * DEFAULT_POS_PER_QUARTER
        d_quarters = d_pos / DEFAULT_POS_PER_QUARTER
        # Convert quarters to seconds
        dt = d_quarters / reference['tempo'] * 60
        return reference['time'] + dt

    # time_sigs = [event.split('_')[-1].split('/') for event in events if f"{TIME_SIGNATURE_KEY}_" in event]
    # time_sigs = [(int(num), int(denom)) for num, denom in time_sigs]

    tempo_changes = [event for event in events if f"{TEMPO_KEY}_" in event]
    if len(tempo_changes) > 0:
        bpm = DEFAULT_TEMPO_BINS[int(tempo_changes[0].split('_')[-1])]

    pm = pretty_midi.PrettyMIDI(initial_tempo=bpm)
    num, denom = time_signature
    pm.time_signature_changes.append(pretty_midi.TimeSignature(num, denom, 0))
    current_time_sig = pm.time_signature_changes[0]

    instruments = {}

    # Use an implicit timeline: keep track of the last tempo/time signature change event
    # and calculate time differences relative to that
    last_tl_event = {
        'time': 0,
        'pos': (0, 0),
        'time_sig': current_time_sig,
        'tempo': bpm
    }

    bar = -1
    n_notes = 0
    polyphony_control = {}
    for i, event in enumerate(events):
        if event == EOS_TOKEN:
            break

        if bar not in polyphony_control:
            polyphony_control[bar] = {}

        if f"{BAR_KEY}_" in events[i]:
            # Next bar is starting
            bar += 1
            polyphony_control[bar] = {}

            if i + 1 < len(events) and f"{TIME_SIGNATURE_KEY}_" in events[i + 1]:
                num, denom = events[i + 1].split('_')[-1].split('/')
                num, denom = int(num), int(denom)
                current_time_sig = last_tl_event['time_sig']
                if num != current_time_sig.numerator or denom != current_time_sig.denominator:
                    time = _get_time(last_tl_event, bar, 0)
                    time_sig = pretty_midi.TimeSignature(num, denom, time)
                    pm.time_signature_changes.append(time_sig)
                    last_tl_event['time'] = time
                    last_tl_event['pos'] = (bar, 0)
                    last_tl_event['time_sig'] = time_sig

        elif i + 1 < len(events) and \
                f"{POSITION_KEY}_" in events[i] and \
                f"{TEMPO_KEY}_" in events[i + 1]:
            position = int(events[i].split('_')[-1])
            tempo_idx = int(events[i + 1].split('_')[-1])
            tempo = DEFAULT_TEMPO_BINS[tempo_idx]

            if tempo != last_tl_event['tempo']:
                time = _get_time(last_tl_event, bar, position)
                last_tl_event['time'] = time
                last_tl_event['pos'] = (bar, position)
                last_tl_event['tempo'] = tempo

        elif i + 4 < len(events) and \
                f"{POSITION_KEY}_" in events[i] and \
                f"{INSTRUMENT_KEY}_" in events[i + 1] and \
                f"{PITCH_KEY}_" in events[i + 2] and \
                f"{VELOCITY_KEY}_" in events[i + 3] and \
                f"{DURATION_KEY}_" in events[i + 4]:
            # get position
            position = int(events[i].split('_')[-1])
            if position not in polyphony_control[bar]:
                polyphony_control[bar][position] = {}

            # get instrument
            instrument_name = events[i + 1].split('_')[-1]
            if instrument_name not in polyphony_control[bar][position]:
                polyphony_control[bar][position][instrument_name] = 0
            elif polyphony_control[bar][position][instrument_name] >= polyphony_limit:
                # If the number of notes exceeds the polyphony limit, omit this note
                continue

            if instrument_name not in instruments:
                if instrument_name == 'drum':
                    instrument = pretty_midi.Instrument(0, is_drum=True)
                else:
                    program = pretty_midi.instrument_name_to_program(instrument_name)
                    instrument = pretty_midi.Instrument(program)
                instruments[instrument_name] = instrument
            else:
                instrument = instruments[instrument_name]

            # get pitch
            pitch = int(events[i + 2].split('_')[-1])
            # get velocity
            velocity_index = int(events[i + 3].split('_')[-1])
            velocity = min(127, DEFAULT_VELOCITY_BINS[velocity_index])
            # get duration
            duration_index = int(events[i + 4].split('_')[-1])
            duration = DEFAULT_DURATION_BINS[duration_index]
            # create the note and add it to its instrument
            start = _get_time(last_tl_event, bar, position)
            end = _get_time(last_tl_event, bar, position + duration)
            note = pretty_midi.Note(velocity=velocity,
                                    pitch=pitch,
                                    start=start,
                                    end=end)
            instrument.notes.append(note)
            n_notes += 1
            polyphony_control[bar][position][instrument_name] += 1

    for instrument in instruments.values():
        pm.instruments.append(instrument)
    return pm
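A round-trip sanity check built from the two halves of this file might look like the following sketch (paths are placeholders).

```python
# Hypothetical round trip: MIDI -> REMI events -> MIDI.
rep = InputRepresentation('song.mid')       # placeholder input path
events = rep.get_remi_events()
pm = remi2midi(events, polyphony_limit=16)  # rebuild a pretty_midi.PrettyMIDI
pm.write('song_roundtrip.mid')              # write a standard MIDI file
```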
music_evaluation/figaro/vocab.py
ADDED
@@ -0,0 +1,166 @@
import pretty_midi
from collections import Counter
import torchtext
from torch import Tensor

from constants import (
    DEFAULT_VELOCITY_BINS,
    DEFAULT_DURATION_BINS,
    DEFAULT_TEMPO_BINS,
    DEFAULT_POS_PER_QUARTER,
    DEFAULT_NOTE_DENSITY_BINS,
    DEFAULT_MEAN_VELOCITY_BINS,
    DEFAULT_MEAN_PITCH_BINS,
    DEFAULT_MEAN_DURATION_BINS
)

from constants import (
    MAX_BAR_LENGTH,
    MAX_N_BARS,

    PAD_TOKEN,
    UNK_TOKEN,
    BOS_TOKEN,
    EOS_TOKEN,
    MASK_TOKEN,

    TIME_SIGNATURE_KEY,
    BAR_KEY,
    POSITION_KEY,
    INSTRUMENT_KEY,
    PITCH_KEY,
    VELOCITY_KEY,
    DURATION_KEY,
    TEMPO_KEY,
    CHORD_KEY,

    NOTE_DENSITY_KEY,
    MEAN_PITCH_KEY,
    MEAN_VELOCITY_KEY,
    MEAN_DURATION_KEY,
)


class Tokens:
    @staticmethod
    def get_instrument_tokens(key=INSTRUMENT_KEY):
        tokens = [f'{key}_{pretty_midi.program_to_instrument_name(i)}' for i in range(128)]
        tokens.append(f'{key}_drum')
        return tokens

    @staticmethod
    def get_chord_tokens(key=CHORD_KEY, qualities=['maj', 'min', 'dim', 'aug', 'dom7', 'maj7', 'min7', 'None']):
        pitch_classes = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']

        chords = [f'{root}:{quality}' for root in pitch_classes for quality in qualities]
        chords.append('N:N')

        tokens = [f'{key}_{chord}' for chord in chords]
        return tokens

    @staticmethod
    def get_time_signature_tokens(key=TIME_SIGNATURE_KEY):
        denominators = [2, 4, 8, 16]
        time_sigs = [f'{p}/{q}' for q in denominators for p in range(1, MAX_BAR_LENGTH * q + 1)]
        tokens = [f'{key}_{time_sig}' for time_sig in time_sigs]
        return tokens

    @staticmethod
    def get_midi_tokens(
        instrument_key=INSTRUMENT_KEY,
        time_signature_key=TIME_SIGNATURE_KEY,
        pitch_key=PITCH_KEY,
        velocity_key=VELOCITY_KEY,
        duration_key=DURATION_KEY,
        tempo_key=TEMPO_KEY,
        bar_key=BAR_KEY,
        position_key=POSITION_KEY
    ):
        instrument_tokens = Tokens.get_instrument_tokens(instrument_key)

        pitch_tokens = [f'{pitch_key}_{i}' for i in range(128)] + [f'{pitch_key}_drum_{i}' for i in range(128)]
        velocity_tokens = [f'{velocity_key}_{i}' for i in range(len(DEFAULT_VELOCITY_BINS))]
        duration_tokens = [f'{duration_key}_{i}' for i in range(len(DEFAULT_DURATION_BINS))]
        tempo_tokens = [f'{tempo_key}_{i}' for i in range(len(DEFAULT_TEMPO_BINS))]
        bar_tokens = [f'{bar_key}_{i}' for i in range(MAX_N_BARS)]
        position_tokens = [f'{position_key}_{i}' for i in range(MAX_BAR_LENGTH * 4 * DEFAULT_POS_PER_QUARTER)]

        time_sig_tokens = Tokens.get_time_signature_tokens(time_signature_key)

        return (
            time_sig_tokens +
            tempo_tokens +
            instrument_tokens +
            pitch_tokens +
            velocity_tokens +
            duration_tokens +
            bar_tokens +
            position_tokens
        )

class Vocab:
    def __init__(self, counter, specials=[PAD_TOKEN, UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, MASK_TOKEN], unk_token=UNK_TOKEN):
        self.vocab = torchtext.vocab.vocab(counter)

        self.specials = specials
        for i, token in enumerate(self.specials):
            self.vocab.insert_token(token, i)

        if unk_token in specials:
            self.vocab.set_default_index(self.vocab.get_stoi()[unk_token])

    def to_i(self, token):
        return self.vocab.get_stoi()[token]

    def to_s(self, idx):
        if idx >= len(self.vocab):
            return UNK_TOKEN
        else:
            return self.vocab.get_itos()[idx]

    def __len__(self):
        return len(self.vocab)

    def encode(self, seq):
        return self.vocab(seq)

    def decode(self, seq):
        if isinstance(seq, Tensor):
            seq = seq.numpy()
        return self.vocab.lookup_tokens(seq)


class RemiVocab(Vocab):
    def __init__(self):
        midi_tokens = Tokens.get_midi_tokens()
        chord_tokens = Tokens.get_chord_tokens()

        self.tokens = midi_tokens + chord_tokens

        counter = Counter(self.tokens)
        super().__init__(counter)


class DescriptionVocab(Vocab):
    def __init__(self):
        time_sig_tokens = Tokens.get_time_signature_tokens()
        instrument_tokens = Tokens.get_instrument_tokens()
        chord_tokens = Tokens.get_chord_tokens()

        bar_tokens = [f'Bar_{i}' for i in range(MAX_N_BARS)]
        density_tokens = [f'{NOTE_DENSITY_KEY}_{i}' for i in range(len(DEFAULT_NOTE_DENSITY_BINS))]
        velocity_tokens = [f'{MEAN_VELOCITY_KEY}_{i}' for i in range(len(DEFAULT_MEAN_VELOCITY_BINS))]
        pitch_tokens = [f'{MEAN_PITCH_KEY}_{i}' for i in range(len(DEFAULT_MEAN_PITCH_BINS))]
        duration_tokens = [f'{MEAN_DURATION_KEY}_{i}' for i in range(len(DEFAULT_MEAN_DURATION_BINS))]

        self.tokens = (
            time_sig_tokens +
            instrument_tokens +
            chord_tokens +
            density_tokens +
            velocity_tokens +
            pitch_tokens +
            duration_tokens +
            bar_tokens
        )

        counter = Counter(self.tokens)
        super().__init__(counter)
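A quick usage sketch for the vocabularies above; the token strings shown assume `BAR_KEY == 'Bar'` in `constants.py`, so treat them as illustrative.

```python
vocab = RemiVocab()
ids = vocab.encode(['Bar_1', 'Bar_2'])  # token strings -> integer ids
print(ids)
print(vocab.decode(ids))                # back to ['Bar_1', 'Bar_2']
print(len(vocab))                       # vocabulary size, incl. special tokens
```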
music_evaluation/mgeval/__init__.py
ADDED
File without changes
music_evaluation/mgeval/__init__.pyc
ADDED
Binary file (105 Bytes)
music_evaluation/mgeval/core.py
ADDED
@@ -0,0 +1,644 @@
# coding:utf-8
"""core.py
Includes the feature extractor and musically informed objective measures.
"""
import pretty_midi
import numpy as np
import sys
import os
import statistics
# import midi
import glob
import math


# feature extractor
def extract_feature(_file):
    """
    This function extracts two midi features:
    pretty_midi object: https://github.com/craffel/pretty-midi
    midi_pattern: https://github.com/vishnubob/python-midi

    Returns:
        dict(pretty_midi: pretty_midi object,
             midi_pattern: midi pattern containing a list of tracks)
    """
    feature = {'pretty_midi': pretty_midi.PrettyMIDI(_file)}
    # 'midi_pattern': midi.read_midifile(_file)}
    return feature


# musically informed objective measures
class metrics(object):
    def total_used_pitch(self, feature):
        """
        total_used_pitch (Pitch count): The number of different pitches within a sample.

        Returns:
        'used_pitch': pitch count, one scalar for each sample.
        """
        try:
            instrument = feature['pretty_midi'].instruments[0]
            piano_roll = instrument.get_piano_roll(fs=100)
            sum_notes = np.sum(piano_roll, axis=1)
            used_pitch = np.sum(sum_notes > 0)
            return used_pitch
        except Exception:
            return 0  # empty piano roll

    def mean_note_velocity(self, feature):
        """
        mean_note_velocity: The average velocity of the notes within a sample.

        Returns:
        'mean_velocity': the average note velocity, one scalar for each sample.
        """
        try:
            instrument = feature['pretty_midi'].instruments[0]
            velocity = []
            for note in instrument.notes:
                velocity.append(note.velocity)

            mean_velocity = statistics.mean(velocity)
            # variance is computed for completeness but not returned
            if len(velocity) > 1:
                variance_velocity = statistics.variance(velocity)
            else:
                variance_velocity = 0

            return mean_velocity
        except Exception:
            return 0  # empty piano roll

    def mean_note_duration(self, feature):
        """
        mean_note_duration: The average duration of the notes within a sample.

        Returns:
        'mean_duration': the average note duration, one scalar for each sample.
        """
        try:
            instrument = feature['pretty_midi'].instruments[0]
            duration = []
            for note in instrument.notes:
                d = note.end - note.start
                duration.append(d)

            mean_duration = statistics.mean(duration)
            # variance is computed for completeness but not returned
            if len(duration) > 1:
                variance_duration = statistics.variance(duration)
            else:
                variance_duration = 0

            return mean_duration
        except Exception:
            return 0  # empty piano roll

    def note_density(self, feature):
        """
        note_density: The density of notes within a sample.

        Returns:
        'note_density': notes per second over the whole file, one scalar for each sample.
        """

        # instrument = feature['pretty_midi'].instruments[0]
        total_notes = sum(len(instrument.notes) for instrument in feature['pretty_midi'].instruments)
        total_duration = feature['pretty_midi'].get_end_time()

        # Calculate the note density
        if total_duration > 0:
            note_density = total_notes / total_duration
        else:
            note_density = 0

        return note_density

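The metrics above are driven through the feature dict returned by `extract_feature`; a minimal sketch (the MIDI path is a placeholder):

```python
# Minimal usage sketch; 'sample.mid' is a hypothetical path.
evaluator = metrics()
feat = extract_feature('sample.mid')
print(evaluator.total_used_pitch(feat))  # distinct pitches in the first instrument
print(evaluator.note_density(feat))      # notes per second over the whole file
```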
    # def bar_used_pitch(self, feature, track_num=1, num_bar=None):
    #     """
    #     bar_used_pitch (Pitch count per bar)
    #
    #     Args:
    #     'track_num' : specify the track number in the midi pattern, default is 1 (the second track).
    #     'num_bar': specify the number of bars in the midi pattern; if set as None, round to the number of complete bars.
    #
    #     Returns:
    #     'used_pitch': with shape of [num_bar, 1]
    #     """
    #     pattern = feature['midi_pattern']
    #     pattern.make_ticks_abs()
    #     resolution = pattern.resolution
    #     for i in range(0, len(pattern[track_num])):
    #         if type(pattern[track_num][i]) == midi.events.TimeSignatureEvent:
    #             time_sig = pattern[track_num][i].data
    #             bar_length = time_sig[0] * resolution * 4 / 2**(time_sig[1])
    #             if num_bar is None:
    #                 num_bar = int(round(float(pattern[track_num][-1].tick) / bar_length))
    #                 used_notes = np.zeros((num_bar, 1))
    #             else:
    #                 used_notes = np.zeros((num_bar, 1))
    #
    #         elif type(pattern[track_num][i]) == midi.events.NoteOnEvent and pattern[track_num][i].data[1] != 0:
    #             if 'time_sig' not in locals():  # set default bar length as 4 beats
    #                 bar_length = 4 * resolution
    #                 time_sig = [4, 2, 24, 8]
    #
    #             if num_bar is None:
    #                 num_bar = int(round(float(pattern[track_num][-1].tick) / bar_length))
    #                 used_notes = np.zeros((num_bar, 1))
    #                 used_notes[pattern[track_num][i].tick / bar_length] += 1
    #             else:
    #                 used_notes = np.zeros((num_bar, 1))
    #                 used_notes[pattern[track_num][i].tick / bar_length] += 1
    #                 note_list = []
    #                 note_list.append(pattern[track_num][i].data[0])
    #
    #         else:
    #             for j in range(0, num_bar):
    #                 if 'note_list' in locals():
    #                     pass
    #                 else:
    #                     note_list = []
    #             note_list.append(pattern[track_num][i].data[0])
    #             idx = pattern[track_num][i].tick / bar_length
    #             if idx >= num_bar:
    #                 continue
    #             used_notes[idx] += 1
    #             # used_notes[pattern[track_num][i].tick / bar_length] += 1
    #
    #     used_pitch = np.zeros((num_bar, 1))
    #     current_note = 0
    #     for i in range(0, num_bar):
    #         used_pitch[i] = len(set(note_list[current_note:current_note + int(used_notes[i][0])]))
    #         current_note += int(used_notes[i][0])
    #
    #     return used_pitch

    # def total_used_note(self, feature, track_num=1):
    #     """
    #     total_used_note (Note count): The number of used notes.
    #     As opposed to the pitch count, the note count does not contain pitch information but is a rhythm-related feature.
    #
    #     Args:
    #     'track_num' : specify the track number in the midi pattern, default is 1 (the second track).
    #
    #     Returns:
    #     'used_notes': a scalar for each sample.
    #     """
    #     pattern = feature['midi_pattern']
    #     used_notes = 0
    #     for i in range(0, len(pattern[track_num])):
    #         if type(pattern[track_num][i]) == midi.events.NoteOnEvent and pattern[track_num][i].data[1] != 0:
    #             used_notes += 1
    #     return used_notes

    # def bar_used_note(self, feature, track_num=1, num_bar=None):
    #     """
    #     bar_used_note (Note count per bar).
    #
    #     Args:
    #     'track_num' : specify the track number in the midi pattern, default is 1 (the second track).
    #     'num_bar': specify the number of bars in the midi pattern; if set as None, round to the number of complete bars.
    #
    #     Returns:
    #     'used_notes': with shape of [num_bar, 1]
    #     """
    #     pattern = feature['midi_pattern']
    #     pattern.make_ticks_abs()
    #     resolution = pattern.resolution
    #     for i in range(0, len(pattern[track_num])):
    #         if type(pattern[track_num][i]) == midi.events.TimeSignatureEvent:
    #             time_sig = pattern[track_num][i].data
    #             bar_length = time_sig[track_num] * resolution * 4 / 2**(time_sig[1])
    #             if num_bar is None:
    #                 num_bar = int(round(float(pattern[track_num][-1].tick) / bar_length))
    #                 used_notes = np.zeros((num_bar, 1))
    #             else:
    #                 used_notes = np.zeros((num_bar, 1))
    #
    #         elif type(pattern[track_num][i]) == midi.events.NoteOnEvent and pattern[track_num][i].data[1] != 0:
    #             if 'time_sig' not in locals():  # set default bar length as 4 beats
    #                 bar_length = 4 * resolution
    #                 time_sig = [4, 2, 24, 8]
    #
    #             if num_bar is None:
    #                 num_bar = int(round(float(pattern[track_num][-1].tick) / bar_length))
    #                 used_notes = np.zeros((num_bar, 1))
    #                 used_notes[pattern[track_num][i].tick / bar_length] += 1
    #             else:
    #                 used_notes = np.zeros((num_bar, 1))
    #                 used_notes[pattern[track_num][i].tick / bar_length] += 1
    #
    #         else:
    #             idx = pattern[track_num][i].tick / bar_length
    #             if idx >= num_bar:
    #                 continue
    #             used_notes[idx] += 1
    #     return used_notes

    def total_pitch_class_histogram(self, feature):
        """
        total_pitch_class_histogram (Pitch class histogram):
        The pitch class histogram is an octave-independent representation of the pitch content with a dimensionality of 12 for a chromatic scale.
        In our case, it represents the octave-independent chromatic quantization of the frequency continuum.

        Returns:
        'histogram': histogram of the 12 pitch classes, weighted by duration, shape 12.
        """
        # print(feature['pretty_midi'].instruments)
        histogram = np.zeros(12)
        try:
            piano_roll = feature['pretty_midi'].instruments[0].get_piano_roll(fs=100)
            # piano_roll = feature['pretty_midi'].get_piano_roll(fs=100)

            for i in range(0, 128):
                pitch_class = i % 12
                histogram[pitch_class] += np.sum(piano_roll, axis=1)[i]
            histogram = histogram / sum(histogram)
            return histogram
        except Exception:
            return histogram

    def bar_pitch_class_histogram(self, feature, track_num=1, num_bar=None, bpm=120):
        """
        bar_pitch_class_histogram (Pitch class histogram per bar):

        Args:
        'bpm' : specify the assigned speed in bpm, default is 120 bpm.
        'num_bar': specify the number of bars in the midi pattern; if set as None, round to the number of complete bars.
        'track_num' : specify the track number in the midi pattern, default is 1 (the second track).

        Returns:
        'histogram': with shape of [num_bar, 12]
        """

        # todo: deal with cases that have more than one time signature
        pm_object = feature['pretty_midi']
        if num_bar is None:
            numer = pm_object.time_signature_changes[-1].numerator
            deno = pm_object.time_signature_changes[-1].denominator
            bar_length = 60. / bpm * numer * 4 / deno * 100
            piano_roll = pm_object.instruments[track_num].get_piano_roll(fs=100)
            piano_roll = np.transpose(piano_roll, (1, 0))
            actual_bar = len(piano_roll) / bar_length
            num_bar = int(round(actual_bar))
            bar_length = int(round(bar_length))
        else:
            numer = pm_object.time_signature_changes[-1].numerator
            deno = pm_object.time_signature_changes[-1].denominator
            bar_length = 60. / bpm * numer * 4 / deno * 100
            piano_roll = pm_object.instruments[track_num].get_piano_roll(fs=100)
            piano_roll = np.transpose(piano_roll, (1, 0))
            actual_bar = len(piano_roll) / bar_length
            bar_length = int(math.ceil(bar_length))

        if actual_bar > num_bar:
            mod = np.mod(len(piano_roll), bar_length * 128)  # (unused)
            piano_roll = piano_roll[:-np.mod(len(piano_roll), bar_length)].reshape((num_bar, -1, 128))  # make exact bars
        elif actual_bar == num_bar:
            piano_roll = piano_roll.reshape((num_bar, -1, 128))
        else:
            piano_roll = np.pad(piano_roll, ((0, int(num_bar * bar_length - len(piano_roll))), (0, 0)), mode='constant',
                                constant_values=0)
            piano_roll = piano_roll.reshape((num_bar, -1, 128))

        bar_histogram = np.zeros((num_bar, 12))
        for i in range(0, num_bar):
            histogram = np.zeros(12)
            for j in range(0, 128):
                pitch_class = j % 12
                histogram[pitch_class] += np.sum(piano_roll[i], axis=0)[j]
            if sum(histogram) != 0:
                bar_histogram[i] = histogram / sum(histogram)
            else:
                bar_histogram[i] = np.zeros(12)
        return bar_histogram

    def pitch_class_transition_matrix(self, feature, normalize=0):
        """
        pitch_class_transition_matrix (Pitch class transition matrix):
        The transition of pitch classes contains useful information for tasks such as key detection, chord recognition, or genre pattern recognition.
        The two-dimensional pitch class transition matrix is a histogram-like representation computed by counting the pitch transitions for each (ordered) pair of notes.

        Args:
        'normalize' : If set to 0, return the transition matrix without normalization.
                      If set to 1, normalize by row.
                      If set to 2, normalize by the entire matrix sum.
        Returns:
        'transition_matrix': shape of [12, 12], a 12 x 12 transition matrix.
        """
        pm_object = feature['pretty_midi']
        transition_matrix = pm_object.get_pitch_class_transition_matrix()

        if normalize == 0:
            return transition_matrix

        elif normalize == 1:
            sums = np.sum(transition_matrix, axis=1)
            sums[sums == 0] = 1
            return transition_matrix / sums.reshape(-1, 1)

        elif normalize == 2:
            return transition_matrix / sum(sum(transition_matrix))

        else:
            print("invalid normalization mode, return unnormalized matrix")
            return transition_matrix

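As an aside, the three `normalize` modes of `pitch_class_transition_matrix` reduce to the following toy computation (counts invented):

```python
import numpy as np

tm = np.array([[2., 2.],
               [0., 4.]])            # toy "transition count" matrix
row_sums = tm.sum(axis=1, keepdims=True)
row_sums[row_sums == 0] = 1          # guard empty rows, as in mode 1
print(tm / row_sums)                 # mode 1: each row sums to 1
print(tm / tm.sum())                 # mode 2: the whole matrix sums to 1
```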
def pitch_range(self, feature):
|
348 |
+
"""
|
349 |
+
pitch_range (Pitch range):
|
350 |
+
The pitch range is calculated by subtraction of the highest and lowest used pitch in semitones.
|
351 |
+
|
352 |
+
Returns:
|
353 |
+
'p_range': a scalar for each sample.
|
354 |
+
"""
|
355 |
+
try:
|
356 |
+
piano_roll = feature['pretty_midi'].instruments[0].get_piano_roll(fs=100)
|
357 |
+
pitch_index = np.where(np.sum(piano_roll, axis=1) > 0)
|
358 |
+
p_range = np.max(pitch_index) - np.min(pitch_index)
|
359 |
+
return p_range
|
360 |
+
except:
|
361 |
+
return 0 # empty piano roll
|
362 |
+
|
363 |
+
# def avg_pitch_shift(self, feature, track_num=1):
|
364 |
+
# """
|
365 |
+
# avg_pitch_shift (Average pitch interval):
|
366 |
+
# Average value of the interval between two consecutive pitches in semitones.
|
367 |
+
|
368 |
+
# Args:
|
369 |
+
# 'track_num' : specify the track number in the midi pattern, default is 1 (the second track).
|
370 |
+
|
371 |
+
# Returns:
|
372 |
+
# 'pitch_shift': a scalar for each sample.
|
373 |
+
# """
|
374 |
+
# pattern = feature['midi_pattern']
|
375 |
+
# pattern.make_ticks_abs()
|
376 |
+
# resolution = pattern.resolution
|
377 |
+
# total_used_note = self.total_used_note(feature, track_num=track_num)
|
378 |
+
# d_note = np.zeros((max(total_used_note - 1, 0)))
|
379 |
+
# # if total_used_note == 0:
|
380 |
+
# # return 0
|
381 |
+
# # d_note = np.zeros((total_used_note - 1))
|
382 |
+
# current_note = 0
|
383 |
+
# counter = 0
|
384 |
+
# for i in range(0, len(pattern[track_num])):
|
385 |
+
# if type(pattern[track_num][i]) == midi.events.NoteOnEvent and pattern[track_num][i].data[1] != 0:
|
386 |
+
# if counter != 0:
|
387 |
+
# d_note[counter - 1] = current_note - pattern[track_num][i].data[0]
|
388 |
+
# current_note = pattern[track_num][i].data[0]
|
389 |
+
# counter += 1
|
390 |
+
# else:
|
391 |
+
# current_note = pattern[track_num][i].data[0]
|
392 |
+
# counter += 1
|
393 |
+
# pitch_shift = np.mean(abs(d_note))
|
394 |
+
# return pitch_shift
|
395 |
+
|
396 |
+
def avg_IOI(self, feature):
|
397 |
+
"""
|
398 |
+
avg_IOI (Average inter-onset-interval):
|
399 |
+
To calculate the inter-onset-interval in the symbolic music domain, we find the time between two consecutive notes.
|
400 |
+
|
401 |
+
Returns:
|
402 |
+
'avg_ioi': a scalar for each sample.
|
403 |
+
"""
|
404 |
+
try:
|
405 |
+
tmp = feature['pretty_midi'].instruments[0]
|
406 |
+
pm_object = feature['pretty_midi']
|
407 |
+
onset = pm_object.get_onsets()
|
408 |
+
ioi = np.diff(onset)
|
409 |
+
avg_ioi = np.mean(ioi)
|
410 |
+
return avg_ioi
|
411 |
+
except:
|
412 |
+
return 0 # empty piano roll
|
413 |
+
|
414 |
+
    # def note_length_hist(self, feature, track_num=1, normalize=True, pause_event=False):
    #     """
    #     note_length_hist (Note length histogram):
    #     To extract the note length histogram, we first define a set of allowable beat length classes:
    #     [full, half, quarter, 8th, 16th, dot half, dot quarter, dot 8th, dot 16th, half note triplet, quarter note triplet, 8th note triplet].
    #     The pause_event option, when activated, will double the vector size to represent the same lengths for rests.
    #     Each event is classified with a basic unit of (bar length)/96, and each note length is quantized to the closest length class.
    #
    #     Args:
    #     'track_num' : specify the track number in the midi pattern, default is 1 (the second track).
    #     'normalize' : if True, normalize by the vector sum.
    #     'pause_event' : when activated, will double the vector size to represent the same lengths for rests.
    #
    #     Returns:
    #     'note_length_hist': a vector of length 12 (or 24 when pause_event is True).
    #     """
    #     pattern = feature['midi_pattern']
    #     if pause_event is False:
    #         note_length_hist = np.zeros((12))
    #         pattern.make_ticks_abs()
    #         resolution = pattern.resolution
    #         # basic unit: bar_length/96
    #         for i in range(0, len(pattern[track_num])):
    #             if type(pattern[track_num][i]) == midi.events.TimeSignatureEvent:
    #                 time_sig = pattern[track_num][i].data
    #                 # time_sig[0] is the numerator of the time signature
    #                 bar_length = time_sig[0] * resolution * 4 / 2**(time_sig[1])
    #             elif type(pattern[track_num][i]) == midi.events.NoteOnEvent and pattern[track_num][i].data[1] != 0:
    #                 if 'time_sig' not in locals():  # set default bar length as 4 beats
    #                     bar_length = 4 * resolution
    #                     time_sig = [4, 2, 24, 8]
    #                 unit = bar_length / 96.
    #                 hist_list = [unit * 96, unit * 48, unit * 24, unit * 12, unit * 6, unit * 72, unit * 36, unit * 18, unit * 9, unit * 32, unit * 16, unit * 8]
    #                 current_tick = pattern[track_num][i].tick
    #                 current_note = pattern[track_num][i].data[0]
    #                 # find next note off
    #                 for j in range(i, len(pattern[track_num])):
    #                     if type(pattern[track_num][j]) == midi.events.NoteOffEvent or (type(pattern[track_num][j]) == midi.events.NoteOnEvent and pattern[track_num][j].data[1] == 0):
    #                         if pattern[track_num][j].data[0] == current_note:
    #                             note_length = pattern[track_num][j].tick - current_tick
    #                             distance = np.abs(np.array(hist_list) - note_length)
    #                             idx = distance.argmin()
    #                             note_length_hist[idx] += 1
    #                             break
    #     else:
    #         note_length_hist = np.zeros((24))
    #         pattern.make_ticks_abs()
    #         resolution = pattern.resolution
    #         # basic unit: bar_length/96
    #         for i in range(0, len(pattern[track_num])):
    #             if type(pattern[track_num][i]) == midi.events.TimeSignatureEvent:
    #                 time_sig = pattern[track_num][i].data
    #                 bar_length = time_sig[0] * resolution * 4 / 2**(time_sig[1])
    #             elif type(pattern[track_num][i]) == midi.events.NoteOnEvent and pattern[track_num][i].data[1] != 0:
    #                 check_previous_off = True
    #                 if 'time_sig' not in locals():  # set default bar length as 4 beats
    #                     bar_length = 4 * resolution
    #                     time_sig = [4, 2, 24, 8]
    #                 unit = bar_length / 96.
    #                 tol = 3. * unit
    #                 hist_list = [unit * 96, unit * 48, unit * 24, unit * 12, unit * 6, unit * 72, unit * 36, unit * 18, unit * 9, unit * 32, unit * 16, unit * 8]
    #                 current_tick = pattern[track_num][i].tick
    #                 current_note = pattern[track_num][i].data[0]
    #                 # find next note off
    #                 for j in range(i, len(pattern[track_num])):
    #                     if type(pattern[track_num][j]) == midi.events.NoteOffEvent or (type(pattern[track_num][j]) == midi.events.NoteOnEvent and pattern[track_num][j].data[1] == 0):
    #                         if pattern[track_num][j].data[0] == current_note:
    #                             note_length = pattern[track_num][j].tick - current_tick
    #                             distance = np.abs(np.array(hist_list) - note_length)
    #                             idx = distance.argmin()
    #                             note_length_hist[idx] += 1
    #                             break
    #                     else:
    #                         if pattern[track_num][j].tick == current_tick:
    #                             check_previous_off = False
    #
    #                 # find previous note off/on
    #                 if check_previous_off is True:
    #                     for j in range(i - 1, 0, -1):
    #                         if type(pattern[track_num][j]) == midi.events.NoteOnEvent and pattern[track_num][j].data[1] != 0:
    #                             break
    #                         elif type(pattern[track_num][j]) == midi.events.NoteOffEvent or (type(pattern[track_num][j]) == midi.events.NoteOnEvent and pattern[track_num][j].data[1] == 0):
    #                             note_length = current_tick - pattern[track_num][j].tick
    #                             distance = np.abs(np.array(hist_list) - note_length)
    #                             idx = distance.argmin()
    #                             if distance[idx] < tol:
    #                                 note_length_hist[idx + 12] += 1
    #                             break
    #
    #     if normalize is False:
    #         return note_length_hist
    #     elif normalize is True:
    #         return note_length_hist / np.sum(note_length_hist)

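    # A minimal Python 3 sketch of the note-length histogram commented out
    # above, using pretty_midi instead of the Python-2-only `midi` package.
    # It assumes a single known tempo and expresses the same 12 length
    # classes in beats (full, half, quarter, 8th, 16th, the dotted variants,
    # and the triplets); the method name `note_length_hist_pm` is our own
    # addition, not part of upstream mgeval.
    def note_length_hist_pm(self, feature, tempo=120.0, normalize=True):
        classes = np.array([4., 2., 1., 0.5, 0.25,    # plain lengths in beats
                            3., 1.5, 0.75, 0.375,     # dotted lengths
                            4. / 3, 2. / 3, 1. / 3])  # triplets
        beat = 60.0 / tempo  # seconds per beat
        hist = np.zeros(len(classes))
        for inst in feature['pretty_midi'].instruments:
            for note in inst.notes:
                length_in_beats = (note.end - note.start) / beat
                # quantize each note length to the closest length class
                hist[np.abs(classes - length_in_beats).argmin()] += 1
        if normalize and hist.sum() > 0:
            return hist / hist.sum()
        return hist
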
    # def note_length_transition_matrix(self, feature, track_num=1, normalize=0, pause_event=False):
    #     """
    #     note_length_transition_matrix (Note length transition matrix):
    #     Similar to the pitch class transition matrix, the note length transition matrix provides useful information for rhythm description.
    #
    #     Args:
    #     'track_num' : specify the track number in the midi pattern, default is 1 (the second track).
    #     'pause_event' : when activated, will double the vector size to represent the same lengths for rests.
    #     'normalize' : If set to 0, return the transition matrix without normalization.
    #                   If set to 1, normalize by row.
    #                   If set to 2, normalize by the entire matrix sum.
    #
    #     Returns:
    #     'transition_matrix': a 12 x 12 matrix (or 24 x 24 when pause_event is True).
    #     """
    #     pattern = feature['midi_pattern']
    #     if pause_event is False:
    #         transition_matrix = np.zeros((12, 12))
    #         pattern.make_ticks_abs()
    #         resolution = pattern.resolution
    #         idx = None
    #         # basic unit: bar_length/96
    #         for i in range(0, len(pattern[track_num])):
    #             if type(pattern[track_num][i]) == midi.events.TimeSignatureEvent:
    #                 time_sig = pattern[track_num][i].data
    #                 # time_sig[0] is the numerator of the time signature
    #                 bar_length = time_sig[0] * resolution * 4 / 2**(time_sig[1])
    #             elif type(pattern[track_num][i]) == midi.events.NoteOnEvent and pattern[track_num][i].data[1] != 0:
    #                 if 'time_sig' not in locals():  # set default bar length as 4 beats
    #                     bar_length = 4 * resolution
    #                     time_sig = [4, 2, 24, 8]
    #                 unit = bar_length / 96.
    #                 hist_list = [unit * 96, unit * 48, unit * 24, unit * 12, unit * 6, unit * 72, unit * 36, unit * 18, unit * 9, unit * 32, unit * 16, unit * 8]
    #                 current_tick = pattern[track_num][i].tick
    #                 current_note = pattern[track_num][i].data[0]
    #                 # find note off
    #                 for j in range(i, len(pattern[track_num])):
    #                     if type(pattern[track_num][j]) == midi.events.NoteOffEvent or (type(pattern[track_num][j]) == midi.events.NoteOnEvent and pattern[track_num][j].data[1] == 0):
    #                         if pattern[track_num][j].data[0] == current_note:
    #                             note_length = pattern[track_num][j].tick - current_tick
    #                             distance = np.abs(np.array(hist_list) - note_length)
    #                             last_idx = idx
    #                             idx = distance.argmin()
    #                             if last_idx is not None:
    #                                 transition_matrix[last_idx][idx] += 1
    #                             break
    #     else:
    #         transition_matrix = np.zeros((24, 24))
    #         pattern.make_ticks_abs()
    #         resolution = pattern.resolution
    #         idx = None
    #         # basic unit: bar_length/96
    #         for i in range(0, len(pattern[track_num])):
    #             if type(pattern[track_num][i]) == midi.events.TimeSignatureEvent:
    #                 time_sig = pattern[track_num][i].data
    #                 bar_length = time_sig[0] * resolution * 4 / 2**(time_sig[1])
    #             elif type(pattern[track_num][i]) == midi.events.NoteOnEvent and pattern[track_num][i].data[1] != 0:
    #                 check_previous_off = True
    #                 if 'time_sig' not in locals():  # set default bar length as 4 beats
    #                     bar_length = 4 * resolution
    #                     time_sig = [4, 2, 24, 8]
    #                 unit = bar_length / 96.
    #                 tol = 3. * unit
    #                 hist_list = [unit * 96, unit * 48, unit * 24, unit * 12, unit * 6, unit * 72, unit * 36, unit * 18, unit * 9, unit * 32, unit * 16, unit * 8]
    #                 current_tick = pattern[track_num][i].tick
    #                 current_note = pattern[track_num][i].data[0]
    #                 # find next note off
    #                 for j in range(i, len(pattern[track_num])):
    #                     if type(pattern[track_num][j]) == midi.events.NoteOffEvent or (type(pattern[track_num][j]) == midi.events.NoteOnEvent and pattern[track_num][j].data[1] == 0):
    #                         if pattern[track_num][j].data[0] == current_note:
    #                             note_length = pattern[track_num][j].tick - current_tick
    #                             distance = np.abs(np.array(hist_list) - note_length)
    #                             last_idx = idx
    #                             idx = distance.argmin()
    #                             if last_idx is not None:
    #                                 transition_matrix[last_idx][idx] += 1
    #                             break
    #                     else:
    #                         if pattern[track_num][j].tick == current_tick:
    #                             check_previous_off = False
    #
    #                 # find previous note off/on
    #                 if check_previous_off is True:
    #                     for j in range(i - 1, 0, -1):
    #                         if type(pattern[track_num][j]) == midi.events.NoteOnEvent and pattern[track_num][j].data[1] != 0:
    #                             break
    #                         elif type(pattern[track_num][j]) == midi.events.NoteOffEvent or (type(pattern[track_num][j]) == midi.events.NoteOnEvent and pattern[track_num][j].data[1] == 0):
    #                             note_length = current_tick - pattern[track_num][j].tick
    #                             distance = np.abs(np.array(hist_list) - note_length)
    #                             last_idx = idx
    #                             idx = distance.argmin()
    #                             if last_idx is not None:
    #                                 if distance[idx] < tol:
    #                                     idx = last_idx
    #                                     transition_matrix[last_idx][idx + 12] += 1
    #                             break
    #
    #     if normalize == 0:
    #         return transition_matrix
    #     elif normalize == 1:
    #         sums = np.sum(transition_matrix, axis=1)
    #         sums[sums == 0] = 1
    #         return transition_matrix / sums.reshape(-1, 1)
    #     elif normalize == 2:
    #         return transition_matrix / sum(sum(transition_matrix))
    #     else:
    #         print("invalid normalization mode, return unnormalized matrix")
    #         return transition_matrix

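    # The three normalization modes used by the commented-out transition
    # matrix above, as a standalone sketch (the helper name
    # `normalize_transition` is our own addition, not part of upstream mgeval):
    @staticmethod
    def normalize_transition(transition_matrix, normalize=0):
        if normalize == 0:  # raw counts
            return transition_matrix
        if normalize == 1:  # each row sums to 1
            sums = np.sum(transition_matrix, axis=1)
            sums[sums == 0] = 1  # avoid division by zero on empty rows
            return transition_matrix / sums.reshape(-1, 1)
        if normalize == 2:  # the whole matrix sums to 1
            total = np.sum(transition_matrix)
            return transition_matrix / total if total > 0 else transition_matrix
        print("invalid normalization mode, returning unnormalized matrix")
        return transition_matrix
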
    # def chord_dependency(self, feature, bar_chord, bpm=120, num_bar=None, track_num=1):
    #     pm_object = feature['pretty_midi']
    #     # compare bar chroma with chord chroma and calculate the euclidean distance
    #     bar_pitch_class_histogram = self.bar_pitch_class_histogram(pm_object, bpm=bpm, num_bar=num_bar, track_num=track_num)
    #     dist = np.zeros((len(bar_pitch_class_histogram)))
    #     for i in range(len(bar_pitch_class_histogram)):
    #         dist[i] = np.linalg.norm(bar_pitch_class_histogram[i] - bar_chord[i])
    #     average_dist = np.mean(dist)
    #     return average_dist
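    # chord_dependency above reduces to the mean euclidean distance between
    # two aligned lists of 12-dimensional chroma vectors; a standalone
    # sketch (the helper name `mean_chroma_distance` is our own addition):
    @staticmethod
    def mean_chroma_distance(bar_chromas, chord_chromas):
        return float(np.mean([np.linalg.norm(np.asarray(a) - np.asarray(b))
                              for a, b in zip(bar_chromas, chord_chromas)]))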
music_evaluation/mgeval/core.pyc
ADDED
Binary file (18.2 kB)