yjhuangcd committed on
Commit
9965bf6
·
1 Parent(s): 9708aee

First commit

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50):
  1. README.md +136 -3
  2. compute_std.py +54 -0
  3. datasets/README.md +20 -0
  4. datasets/all_midi.csv +0 -0
  5. datasets/chunk_midi.py +72 -0
  6. datasets/filter_class.py +38 -0
  7. datasets/piano_roll_all.py +139 -0
  8. datasets/select_midi.py +74 -0
  9. diff_collage/README.md +3 -0
  10. diff_collage/__init__.py +5 -0
  11. diff_collage/avg_circle.py +64 -0
  12. diff_collage/avg_long.py +40 -0
  13. diff_collage/condind_circle.py +190 -0
  14. diff_collage/condind_long.py +147 -0
  15. diff_collage/generic_sampler.py +113 -0
  16. diff_collage/loss_helper.py +41 -0
  17. diff_collage/w_img.py +79 -0
  18. diff_collage/w_loss.py +433 -0
  19. environment.yml +282 -0
  20. guided_diffusion/__init__.py +3 -0
  21. guided_diffusion/condition_functions.py +174 -0
  22. guided_diffusion/dist_util.py +104 -0
  23. guided_diffusion/dit.py +983 -0
  24. guided_diffusion/embed_datasets.py +161 -0
  25. guided_diffusion/fp16_util.py +237 -0
  26. guided_diffusion/gaussian_diffusion.py +1400 -0
  27. guided_diffusion/logger.py +521 -0
  28. guided_diffusion/losses.py +77 -0
  29. guided_diffusion/midi_util.py +291 -0
  30. guided_diffusion/nn.py +170 -0
  31. guided_diffusion/pr_datasets_all.py +183 -0
  32. guided_diffusion/resample.py +154 -0
  33. guided_diffusion/respace.py +128 -0
  34. guided_diffusion/script_util.py +531 -0
  35. guided_diffusion/train_util.py +475 -0
  36. guided_diffusion/unet.py +906 -0
  37. load_utils.py +31 -0
  38. music_evaluation/README.md +22 -0
  39. music_evaluation/convert_to_wav.py +42 -0
  40. music_evaluation/demo.ipynb +0 -0
  41. music_evaluation/fad.py +38 -0
  42. music_evaluation/figaro/chord_recognition.py +247 -0
  43. music_evaluation/figaro/constants.py +47 -0
  44. music_evaluation/figaro/evaluate.py +268 -0
  45. music_evaluation/figaro/input_representation.py +655 -0
  46. music_evaluation/figaro/vocab.py +166 -0
  47. music_evaluation/mgeval/__init__.py +0 -0
  48. music_evaluation/mgeval/__init__.pyc +0 -0
  49. music_evaluation/mgeval/core.py +644 -0
  50. music_evaluation/mgeval/core.pyc +0 -0
README.md CHANGED
@@ -1,3 +1,136 @@
- ---
- license: mit
- ---
+ # Symbolic Music Generation with Non-Differentiable Rule Guided Diffusion
+
+ This is the codebase for the paper [Symbolic Music Generation with Non-Differentiable Rule Guided Diffusion](https://arxiv.org/abs/2402.14285).
+
+ We introduce a symbolic music generator based on diffusion models guided by non-differentiable rules, drawing inspiration from stochastic control. For music demos, please visit our [project website](https://scg-rule-guided-music.github.io/).
+
+ <img align="center" src="rule_guided_music_gen.png" width="750">
+
+ ## Set up the environment
+
+ - Put the pretrained VAE checkpoint under `taming-transformers/checkpoints`.
+ - Create the conda environment: `conda env create -f environment.yml`
+ - Activate the environment: `conda activate guided`
+
+ ## Download Pretrained Checkpoints
+ - Pretrained VAE checkpoint (in `trained_models/VAE`): place it at `taming-transformers/checkpoints/all_onset/epoch_14.ckpt`.
+ - Pretrained diffusion model checkpoint (in `trained_models/diffusion`): place it at `loggings/checkpoints/ema_0.9999_1200000.pt`.
+ - Pretrained classifiers for each rule (in `trained_models/classifier`): place them under `loggings/classifier/`.
+
+ ## Rule Guided Generation
+ All configs for rule guidance are stored in `scripts/configs/`. `cond_demo` contains the configs we used to generate the demos for composer co-creation, `cond_table` contains the configs we used to create the table in the paper, and `edit` contains the configs for editing an existing excerpt.
+ For instance, to guide diffusion models on all of the rules simultaneously, using both SCG and classifier guidance, use the config `scripts/configs/cond_table/all/scg_classifier_all.yml`. The results will be saved in `loggings/cond_table/all/scg_classifier_all`.
+
+ The config file contains the following fields:
+ | Field | Description |
+ |-----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------|
+ | `target_rules` | Specify the desired attributes here. New rules can be added by writing rule programs in `music_rule_guidance/music_rules.py` and updating `rule_maps`. |
+ | `guidance` | Hyper-parameters for guidance, including the classifier config for classifier guidance and when to start or stop using guidance. |
+ | `scg` | Hyper-parameters for stochastic control guidance (SCG, ours). |
+ | `sampling` | Hyper-parameters for diffusion model sampling. Options include using DDIM or sampling longer sequences with `diff_collage`. |
+
+ To run the rule-guided sampling code, use the following script:
+ ```
+ python sample_rule.py \
+   --config_path <config_file> \
+   --batch_size 4 \
+   --num_samples 20 \
+   --data_dir <data_dir> \
+   --model DiTRotary_XL_8 \
+   --model_path loggings/checkpoints/ema_0.9999_1200000.pt \
+   --image_size 128 16 \
+   --in_channels 4 \
+   --scale_factor 1.2465 \
+   --class_cond True \
+   --num_classes 3 \
+   --class_label 1
+ ```
+ The meaning of each hyper-parameter is as follows:
+ | Hyper-parameter | Description |
+ |-------------------|-------------------------------------------------------------------------------------------------------------------|
+ | `config_path` | Path to the configuration file, e.g., `scripts/configs/cond_demo/demo1.yml`. |
+ | `batch_size` | Batch size for generation. Default: 4. |
+ | `num_samples` | Total number of samples to generate. Default: 20. |
+ | `data_dir` | Optional: directory where the data is stored. Used to extract rule labels from existing music excerpts. Not needed if target rule labels are given (leave it at the default value in that case). |
+ | `model` | Model backbone for the diffusion model. Default: DiTRotary_XL_8. |
+ | `model_path` | Path to the pretrained diffusion model. |
+ | `image_size` | Size of the generated piano roll in latent space (for 10.24 s, the size is 128x16). |
+ | `in_channels` | Number of channels in the latent space of the pretrained VAE. Default: 4. |
+ | `scale_factor` | 1 / std of the latents. You can use `compute_std.py` to compute it for a pretrained VAE. Default: 1.2465 (computed for the VAE checkpoint that we provide). |
+ | `class_cond` | Whether to condition on music genre (datasets: Maestro, Muscore, and Pop) for generation. Default: True. |
+ | `num_classes` | Number of classes (datasets). We trained on 3 datasets. |
+ | `class_label` | 0 for Maestro (classical performance), 1 for Muscore (classical sheet music), 2 for Pop. |
+
+
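+ As a quick sanity check on `image_size`: the latent size follows from the piano-roll size and the VAE. A minimal sketch, assuming the `kl/f8-all-onset` VAE downsamples both the time and pitch axes by a factor of 8 (variable names below are illustrative):
+ ```python
+ fs = 100                           # piano-roll time resolution (frames per second), see --fs
+ pr_frames, pr_pitches = 1024, 128  # piano-roll size of a 10.24 s excerpt
+ vae_downsample = 8                 # assumed from the f8 VAE config
+ print(pr_frames / fs)                                             # 10.24 seconds
+ print(pr_frames // vae_downsample, pr_pitches // vae_downsample)  # 128 16 -> --image_size 128 16
+ ```
+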
+ To guide on new rules beyond the ones we considered (pitch histogram, note density, and chord progression),
+ add the rule function to `music_rule_guidance/music_rules.py` and register it in `FUNC_DICT` in `rule_maps.py`.
+ In addition, pick a loss function for the newly added rule and add it to `LOSS_DICT` in `rule_maps.py`.
+ You can then use the key in `FUNC_DICT` as a `target_rules` entry in the config file, as sketched below.
+
+
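+ A minimal sketch of what this registration could look like, assuming `FUNC_DICT` and `LOSS_DICT` are plain dictionaries keyed by rule name; the rule `mean_velocity`, its input convention, and the chosen loss are illustrative placeholders rather than code from this repository:
+ ```python
+ # music_rule_guidance/music_rules.py -- hypothetical new rule
+ import torch
+
+ def mean_velocity(piano_roll):
+     """Placeholder rule: average velocity over sounding cells of a batch of piano rolls."""
+     mask = piano_roll > 0
+     return (piano_roll * mask).sum(dim=(-2, -1)) / mask.sum(dim=(-2, -1)).clamp(min=1)
+
+ # rule_maps.py -- register the rule and a matching loss (the key name is your choice)
+ FUNC_DICT["mean_velocity"] = mean_velocity
+ LOSS_DICT["mean_velocity"] = torch.nn.functional.mse_loss
+ ```
+ With this in place, `mean_velocity` can be listed under `target_rules` in a config file.
+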
+ This framework also supports editing an existing excerpt:
+ ```
+ python scripts/edit.py \
+   --config_path scripts/configs/edit_table/nd_500_num16.yml \
+   --batch_size 2 \
+   --num_samples 20 \
+   --data_dir <data_dir> \
+   --model DiTRotary_XL_8 \
+   --model_path loggings/checkpoints/ema_0.9999_1200000.pt \
+   --image_size 128 16 \
+   --in_channels 4 \
+   --scale_factor 1.2465 \
+   --class_cond True \
+   --num_classes 3 \
+   --class_label 2
+ ```
+
+
+ ## Train diffusion model for music generation
+ To train a diffusion model for symbolic music generation, use the following script:
+ ```
+ mpiexec -n 8 python scripts/train_dit.py \
+   --dir <loggings/save_dir> \
+   --data_dir <datasets/data_dir> \
+   --model DiTRotary_XL_8 \
+   --image_size 128 16 \
+   --in_channels 4 \
+   --batch_size 32 \
+   --encode_rep 4 \
+   --shift_size 4 \
+   --pr_image_size 2560 \
+   --microbatch_encode -1 \
+   --class_cond True \
+   --num_classes 3 \
+   --scale_factor <scale_factor> \
+   --fs 100 \
+   --save_interval 10000 \
+   --resume <dir to the last saved model ckpt>
+ ```
+ The meaning of each hyper-parameter is as follows:
+ | Hyper-parameter | Description |
+ |-------------------|-------------------------------------------------------------------------------------------------------------------|
+ | `mpiexec -n 8` | Multi-GPU training, using 8 GPUs. |
+ | `embed_model_name`| VAE config; the default is `kl/f8-all-onset`. |
+ | `embed_model_ckpt`| Directory of the VAE checkpoint. |
+ | `dir` | Directory to save diffusion checkpoints and generated samples. |
+ | `data_dir` | Where you store your piano roll data. |
+ | `model` | Diffusion model name (config), e.g., `DiTRotary_XL_8`: a DiT XL model with 1D patch_size=8 (seq_len=256). |
+ | `image_size` | Latent space size (for 10.24 s, the size is 128x16). |
+ | `in_channels` | Number of latent-space channels (default: 4). |
+ | `batch_size` | Batch size on each GPU. The effective batch size is batch_size * num_GPUs; aim for an effective batch size of 256. |
+ | `encode_rep` | How many excerpts to create from one long sequence. `batch_size` needs to be greater than or equal to `encode_rep`. Default: 4. |
+ | `shift_size` | Time shift between successive music excerpts from a long sequence. Default: 4. |
+ | `pr_image_size` | Length of a long sequence; it needs to be compatible with `encode_rep` and `shift_size`. For example, for `encode_rep=4` and `shift_size=4`, the excerpts created from a long sequence are 1-8, 5-12, 9-16 and 13-20, so `pr_image_size=20x128=2560` (see the sketch below). |
+ | `class_cond` | Train with class conditioning (score(x, y), where y is the class). |
+ | `num_classes` | Number of classes in your conditioning, e.g., 3: 0 for Maestro, 1 for Muscore, 2 for Pop. |
+ | `scale_factor` | 1 / std of the latents. You can use `compute_std.py` to compute it for a pretrained VAE. |
+ | `fs` | Time resolution is 1 / fs. |
+ | `save_interval` | Frequency of saving checkpoints, e.g., every 10k steps. |
+
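+ A quick sanity check of the example values above (the helper below is a hypothetical illustration, not a function in this repository; one "segment" here is a 1.28 s piano-roll chunk of width 128):
+ ```python
+ def required_pr_image_size(encode_rep=4, shift_size=4, excerpt_segments=8, segment_width=128):
+     # encode_rep excerpts of 10.24 s (8 segments each), shifted by shift_size segments,
+     # must fit inside one long sequence: 8 + 4 * (4 - 1) = 20 segments.
+     total_segments = excerpt_segments + shift_size * (encode_rep - 1)
+     return total_segments * segment_width  # 20 * 128 = 2560
+
+ effective_batch = 32 * 8  # batch_size per GPU * number of GPUs = 256, the target above
+ print(required_pr_image_size(), effective_batch)  # 2560 256
+ ```
+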
+ ## References
+ This repository is based on [openai/guided-diffusion](https://github.com/openai/guided-diffusion), with modifications to the data representation, the guidance algorithm, and the architecture.
+ - The VAE architecture is adapted from [taming-transformers](https://github.com/CompVis/taming-transformers).
+ - The DiT architecture is adapted from [DiT](https://github.com/facebookresearch/DiT).
+ - Music evaluation code is adapted from [mgeval](https://github.com/RichardYang40148/mgeval) and [figaro](https://github.com/dvruette/figaro).
+ - MIDI to piano roll representation is adapted from [pretty_midi](https://github.com/craffel/pretty-midi).
compute_std.py ADDED
@@ -0,0 +1,54 @@
1
+ import os
2
+ import torch
3
+ from load_utils import load_model
4
+ from guided_diffusion import dist_util
5
+ from guided_diffusion.gaussian_diffusion import _encode, _decode
6
+ from guided_diffusion.pr_datasets_all import load_data
7
+ from tqdm import tqdm
8
+ from guided_diffusion.midi_util import visualize_full_piano_roll, save_piano_roll_midi
9
+ from music_rule_guidance import music_rules
10
+ import matplotlib.pyplot as plt
11
+ import warnings
12
+ warnings.filterwarnings("ignore")
13
+ plt.rcParams["figure.figsize"] = (20,3)
14
+ plt.rcParams['figure.dpi'] = 300
15
+ plt.rcParams['savefig.dpi'] = 300
16
+
17
+
18
+ MODEL_NAME = 'kl/f8-all-onset'
19
+ MODEL_CKPT = 'taming-transformers/checkpoints/all_onset/epoch_14.ckpt'
20
+
21
+ TOTAL_BATCH = 256
22
+
23
+
24
+ def main():
25
+
26
+ data = load_data(
27
+ data_dir='datasets/all-len-40-gap-16-no-empty_train.csv',
28
+ batch_size=32,
29
+ class_cond=True,
30
+ image_size=1024,
31
+ deterministic=False,
32
+ fs=100,
33
+ )
34
+ embed_model = load_model(MODEL_NAME, MODEL_CKPT)
35
+ del embed_model.loss
36
+ embed_model.to(dist_util.dev())
37
+ embed_model.eval()
38
+
39
+ z_list = []
40
+ with torch.no_grad():
41
+ for _ in tqdm(range(TOTAL_BATCH)):
42
+ batch, cond = next(data)
43
+ batch = batch.to(dist_util.dev())
44
+ enc = _encode(batch, embed_model, scale_factor=1.)
45
+ z_list.append(enc.cpu())
46
+ latents = torch.concat(z_list, dim=0)
47
+ scale_factor = 1. / latents.flatten().std().item()
48
+ print(f"scale_factor: {scale_factor}")
49
+ print("done")
50
+
51
+
52
+
53
+ if __name__ == "__main__":
54
+ main()
datasets/README.md ADDED
@@ -0,0 +1,20 @@
+ # Creating data representation for symbolic music
+
+ This directory contains instructions and scripts for creating the training dataset.
+ Note that you do not need to prepare a dataset if you only want to generate music and already have target rule labels in mind.
+ You will need to prepare a dataset if you want to train a model, or if you want to extract rule labels from existing music excerpts.
+
+ We train our diffusion model on three datasets: [Maestro](https://magenta.tensorflow.org/datasets/maestro#v300) (classical piano performance), Muscore (crawled from the Muscore website), and Pop ([Pop1k7](https://drive.google.com/file/d/1qw_tVUntblIg4lW16vbpjLXVndkVtgDe/view) and [Pop909](https://github.com/music-x-lab/POP909-Dataset)).
+ Download the data and put the MIDI files into the corresponding folders: `maestro`, `muscore`, and `pop`.
+
+ Then run `piano_roll_all.py` to create piano roll excerpts from the datasets (see the loading sketch at the end of this README for the exact array layout).
+
+ The above script creates piano roll excerpts of 1.28 s.
+ To create 10.24 s excerpts for training, return to the main folder and run `rearrange_pr_data.py` to concatenate the shorter piano rolls into longer ones.
+ The processed data will be saved in `datasets/all-len-40-gap-16-no-empty` by default, along with two csv files,
+ `all-len-40-gap-16-no-empty_train.csv` and `all-len-40-gap-16-no-empty_test.csv`, that list the filenames.
+
+ If you want to extract rule labels from the piano rolls and condition on a specific dataset, create a csv file for each dataset using:
+ ```
+ python filter_class.py --file_path all-len-40-gap-16-no-empty_test.csv --class_label <class label>
+ ```
+ This saves a filtered csv next to the original, with `_cls_<class label>` appended to its name, following `filter_class.py`.
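+
+ To check what the preprocessing produces, here is a small loading sketch (the file name is a placeholder; the array layout follows `get_full_piano_roll` in `piano_roll_all.py`, which stacks the note, onset, and sustain-pedal rolls):
+ ```python
+ import numpy as np
+
+ # One 1.28 s excerpt: 3 channels (velocities, onsets, pedal) x 128 pitches x 128 frames at fs=100.
+ roll = np.load("all-128-fs100/train/<some_excerpt>_0.npy")
+ print(roll.shape, roll.dtype)  # (3, 128, 128) uint8
+ ```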
datasets/all_midi.csv ADDED
The diff for this file is too large to render. See raw diff
 
datasets/chunk_midi.py ADDED
@@ -0,0 +1,72 @@
1
+ import os
2
+ import pretty_midi
3
+ import argparse
4
+
5
+ def chunk_midi(input_path, output_dir, chunk_length=10.24):
6
+ # Ensure the output directory exists
7
+ if not os.path.exists(output_dir):
8
+ os.makedirs(output_dir)
9
+
10
+ for midi_file_name in os.listdir(input_path):
11
+ if not (midi_file_name.endswith('.midi') or midi_file_name.endswith('.mid')):
12
+ continue # Skip non-midi files
13
+
14
+ full_path = os.path.join(input_path, midi_file_name)
15
+ try:
16
+ midi_data = pretty_midi.PrettyMIDI(full_path)
17
+ except Exception as e:
18
+ print(f"Error processing {midi_file_name}: {e}")
19
+ continue # Skip to the next file if an error occurs
20
+
21
+ end_time = midi_data.get_end_time() # Get end time directly with pretty_midi
22
+ num_chunks = int(end_time // chunk_length) + (1 if end_time % chunk_length > 0 else 0)
23
+
24
+ base_name, file_extension = os.path.splitext(midi_file_name)
25
+
26
+ for i in range(num_chunks):
27
+ start_time = i * chunk_length
28
+ segment_end_time = min((i + 1) * chunk_length, midi_data.get_end_time())
29
+
30
+ # Create a new MIDI object for each chunk
31
+ chunk_midi_data = pretty_midi.PrettyMIDI()
32
+
33
+ # Merge non-drum instruments into a single instrument
34
+ merged_instrument = pretty_midi.Instrument(program=0, is_drum=False)
35
+
36
+ for instrument in midi_data.instruments:
37
+ if not instrument.is_drum:
38
+ for note in instrument.notes:
39
+ if start_time <= note.start < segment_end_time:
40
+ # Shift the note start and end times to start at 0
41
+ new_note = pretty_midi.Note(
42
+ velocity=note.velocity,
43
+ pitch=note.pitch,
44
+ start=note.start - start_time,
45
+ end=note.end - start_time
46
+ )
47
+ merged_instrument.notes.append(new_note)
48
+ else:
49
+ # If it's a drum instrument, just adjust the note times and append it
50
+ new_drum_instrument = pretty_midi.Instrument(program=instrument.program, is_drum=True, name=instrument.name)
51
+ new_drum_instrument.notes = [note for note in instrument.notes if start_time <= note.start < segment_end_time]
52
+ for note in new_drum_instrument.notes:
53
+ note.start -= start_time
54
+ note.end -= start_time
55
+ chunk_midi_data.instruments.append(new_drum_instrument)
56
+
57
+ # Add the merged instrument to the MIDI object
58
+ chunk_midi_data.instruments.append(merged_instrument)
59
+
60
+ # Save the chunk with the same extension as the original file
61
+ new_midi_name = "{}_{}{}".format(base_name, i, file_extension)
62
+ chunk_midi_data.write(os.path.join(output_dir, new_midi_name))
63
+
64
+
65
+ if __name__ == "__main__":
66
+ parser = argparse.ArgumentParser(description="Chunk MIDI files into specified lengths.")
67
+ parser.add_argument("--input_path", type=str, help="Path to the directory containing the MIDI files to chunk.")
68
+ parser.add_argument("--output_dir", type=str, help="Path to the directory where the chunked MIDI files will be saved.")
69
+ parser.add_argument("--chunk_length", type=float, default=10.24, help="length to chunk the midi file to (s).")
70
+ args = parser.parse_args()
71
+
72
+ chunk_midi(args.input_path, args.output_dir, chunk_length=args.chunk_length)
datasets/filter_class.py ADDED
@@ -0,0 +1,38 @@
1
+ import pandas as pd
2
+ import argparse
3
+
4
+ def filter_and_save_csv(file_path, class_label):
5
+ """
6
+ Filters a CSV file to keep only the rows where the 'classes' column equals the specified class label.
7
+ Saves the filtered DataFrame to a new CSV file.
8
+
9
+ :param file_path: Path to the original CSV file.
10
+ :param class_label: The class label to filter by.
11
+ """
12
+ # Read the CSV file
13
+ df = pd.read_csv(file_path)
14
+
15
+ # Filter out rows where 'classes' equals the specified class_label
16
+ filtered_df = df[df['classes'] == class_label]
17
+
18
+ # Save the filtered DataFrame to a new CSV file
19
+ # The new file name is the original file name with '_cls_<class_label>' appended before the file extension
20
+ new_file_path = file_path.replace('.csv', f'_cls_{class_label}.csv')
21
+ filtered_df.to_csv(new_file_path, index=False)
22
+
23
+ print(f"Filtered CSV saved as: {new_file_path}")
24
+
25
+ def main():
26
+ # Set up the argument parser
27
+ parser = argparse.ArgumentParser(description="Filter a CSV file by class and save to a new file.")
28
+ parser.add_argument("--file_path", type=str, help="Path to the original CSV file")
29
+ parser.add_argument("--class_label", type=int, help="The class label to filter by")
30
+
31
+ # Parse arguments
32
+ args = parser.parse_args()
33
+
34
+ # Call the function with the provided arguments
35
+ filter_and_save_csv(args.file_path, args.class_label)
36
+
37
+ if __name__ == "__main__":
38
+ main()
datasets/piano_roll_all.py ADDED
@@ -0,0 +1,139 @@
1
+ import os
2
+ import torch
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torchvision import transforms, utils
5
+ import pretty_midi
6
+ import pandas as pd
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ import math
10
+ from music_rule_guidance.music_rules import MAX_PIANO, MIN_PIANO
11
+
12
+ import matplotlib.pyplot as plt
13
+ plt.rcParams["figure.figsize"] = (6,3)
14
+ plt.rcParams['figure.dpi'] = 300
15
+ plt.rcParams['savefig.dpi'] = 300
16
+
17
+ CC_SUSTAIN_PEDAL = 64
18
+
19
+
20
+ def split_csv(csv_path='merged_midi.csv'):
21
+ # separate training validation testing files
22
+ df = pd.read_csv(csv_path)
23
+ save_name = csv_path[:csv_path.rfind('.csv')]
24
+ for split in ['train', 'validation', 'test']:
25
+ path = os.path.join(save_name, split + '.csv')
26
+ df_sub = df[df.split == split]
27
+ df_sub.to_csv(path, index=False)
28
+ return
29
+
30
+
31
+ def quantize_pedal(value, num_bins=8):
32
+ """Quantize an integer value from 0 to 127 into 8 bins and return the center value of the bin."""
33
+ if value < 0 or value > 127:
34
+ raise ValueError("Value should be between 0 and 127")
35
+ # Determine bin size
36
+ bin_size = 128 // num_bins # 16
37
+ # Quantize the value
38
+ bin_index = value // bin_size
39
+ bin_center = bin_size * bin_index + bin_size // 2
40
+ # Handle edge case for the last bin
41
+ if bin_center > 127:
42
+ bin_center = 127
43
+ return bin_center
44
+
45
+
46
+ def get_full_piano_roll(midi_data, fs, show=False):
47
+ # do not process sustain pedal
48
+ piano_roll, onset_roll = midi_data.get_piano_roll(fs=fs, pedal_threshold=None, onset=True)
49
+ # save pedal roll explicitly
50
+ pedal_roll = np.zeros_like(piano_roll)
51
+ # process pedal
52
+ for instru in midi_data.instruments:
53
+ pedal_changes = [_e for _e in instru.control_changes if _e.number == CC_SUSTAIN_PEDAL]
54
+ for cc in pedal_changes:
55
+ time_now = int(cc.time * fs)
56
+ if time_now < pedal_roll.shape[-1]:
57
+ # need to distinguish control_change 0 and background 0, with quantize 0-16 will be 8
58
+ # in muscore files, 0 immediately followed by 127, need to shift by one column
59
+ if pedal_roll[MIN_PIANO, time_now] != 0. and abs(pedal_roll[MIN_PIANO, time_now] - cc.value) > 64:
60
+ # use shift 2 here to prevent missing change when using interpolation augmentation
61
+ pedal_roll[MIN_PIANO:MAX_PIANO + 1, min(time_now + 2, pedal_roll.shape[-1] - 1)] = quantize_pedal(cc.value)
62
+ else:
63
+ pedal_roll[MIN_PIANO:MAX_PIANO + 1, time_now] = quantize_pedal(cc.value)
64
+ full_roll = np.concatenate((piano_roll[None], onset_roll[None], pedal_roll[None]), axis=0)
65
+ if show:
66
+ plt.imshow(piano_roll[::-1, :1024], vmin=0, vmax=127)
67
+ plt.show()
68
+ plt.imshow(pedal_roll[::-1, :1024], vmin=0, vmax=127)
69
+ plt.show()
70
+ return full_roll
71
+
72
+
73
+ def preprocess_midi(target='merged', csv_path='merged_midi.csv', fs=100., image_size=128, overlap=False, show=False):
74
+ # get piano roll from midi file
75
+ df = pd.read_csv(csv_path)
76
+ total_pieces = len(df)
77
+ if not os.path.exists(target):
78
+ os.makedirs(target)
79
+ for split in ['train', 'test']:
80
+ path = os.path.join(target, split)
81
+ if not os.path.exists(path):
82
+ os.makedirs(path)
83
+ for i in tqdm(range(total_pieces)):
84
+ midi_filename = df.midi_filename[i]
85
+ split = df.split[i]
86
+ dataset = df.dataset[i]
87
+ path = os.path.join(target, split)
88
+ midi_data = pretty_midi.PrettyMIDI(os.path.join(dataset, midi_filename))
89
+ full_roll = get_full_piano_roll(midi_data, fs=fs, show=show)
90
+ for j in range(0, full_roll.shape[-1], image_size):
91
+ if j + image_size <= full_roll.shape[-1]:
92
+ full_roll_excerpt = full_roll[:, :, j:j + image_size]
93
+ else:
94
+ full_roll_excerpt = np.zeros((3, full_roll.shape[1], image_size)) # 2x128ximage_size
95
+ full_roll_excerpt[:, :, : full_roll.shape[-1] - j] = full_roll[:, :, j:]
96
+ empty_roll = math.isclose(full_roll_excerpt.max(), 0.)
97
+ if not empty_roll:
98
+ # Find the last '/' in the string
99
+ last_slash_index = midi_filename.rfind('/')
100
+ # Find the '.npy' in the string
101
+ dot_mid_index = midi_filename.rfind('.mid')
102
+ # Extract the substring between last '/' and '.mid'
103
+ save_name = midi_filename[last_slash_index + 1:dot_mid_index]
104
+ full_roll_excerpt = full_roll_excerpt.astype(np.uint8)
105
+ np.save(os.path.join(path, save_name + '_' + str(j // image_size) + '.npy'), full_roll_excerpt)
106
+ # save with dataset name for VAE duplicate file names
107
+ # np.save(os.path.join(path, dataset + '_' + save_name + '_' + str(j // image_size) + '.npy'), full_roll_excerpt)
108
+ if overlap:
109
+ for j in range(image_size//2, full_roll.shape[-1], image_size): # overlap with image_size//2
110
+ if j + image_size <= full_roll.shape[-1]:
111
+ full_roll_excerpt = full_roll[:, :, j:j + image_size]
112
+ else:
113
+ full_roll_excerpt = np.zeros((3, full_roll.shape[1], image_size))
114
+ full_roll_excerpt[:, :, : full_roll.shape[-1] - j] = full_roll[:, :, j:]
115
+ empty_roll = math.isclose(full_roll_excerpt.max(), 0.)
116
+ if not empty_roll:
117
+ last_slash_index = midi_filename.rfind('/')
118
+ dot_mid_index = midi_filename.rfind('.mid')
119
+ save_name = midi_filename[last_slash_index + 1:dot_mid_index]
120
+ full_roll_excerpt = full_roll_excerpt.astype(np.uint8)
121
+ np.save(os.path.join(path, 'shift_' + save_name + '_' + str(j // image_size) + '.npy'), full_roll_excerpt)
122
+ # save with dataset name for VAE duplicate file names
123
+ # np.save(os.path.join(path, dataset + '_' + 'shift_' + save_name + '_' + str(j // image_size) + '.npy'), full_roll_excerpt)
124
+ return
125
+
126
+
127
+ def main():
128
+ # create fs=100 1.28s datasets without overlap (can be rearranged)
129
+ preprocess_midi(target='all-128-fs100', csv_path='all_midi.csv', fs=100, image_size=128, overlap=False, show=False)
130
+ # create fs=100 2.56s datasets with overlap (used for vae training), when load in, need to select 1.28s from 2.56s
131
+ # preprocess_midi(target='all-256-overlap-fs100', csv_path='all_midi.csv', fs=100, image_size=256, overlap=True,
132
+ # show=False)
133
+ # create fs=12.5 (0.08s) for pixel space diffusion model, rearrangement with length 2
134
+ # preprocess_midi(target='all-128-fs12.5', csv_path='all_midi.csv', fs=12.5, image_size=128, overlap=False,
135
+ # show=False)
136
+
137
+
138
+ if __name__ == "__main__":
139
+ main()
datasets/select_midi.py ADDED
@@ -0,0 +1,74 @@
1
+ import os
2
+ import pretty_midi
3
+ import argparse
4
+ import random
5
+
6
+ def select_midi(input_path, output_dir, select_length=10.24):
7
+ # Ensure the output directory exists
8
+ if not os.path.exists(output_dir):
9
+ os.makedirs(output_dir)
10
+
11
+ for midi_file_name in os.listdir(input_path):
12
+ if not (midi_file_name.endswith('.midi') or midi_file_name.endswith('.mid')):
13
+ continue # Skip non-midi files
14
+
15
+ full_path = os.path.join(input_path, midi_file_name)
16
+ try:
17
+ midi_data = pretty_midi.PrettyMIDI(full_path)
18
+ except Exception as e:
19
+ print(f"Error processing {midi_file_name}: {e}")
20
+ continue # Skip to the next file if an error occurs
21
+
22
+ end_time = midi_data.get_end_time() # Get end time directly with pretty_midi
23
+
24
+ if select_length > end_time:
25
+ print("Segment length is longer than the MIDI file duration.")
26
+ continue
27
+
28
+ start_time = random.uniform(0, end_time - select_length)
29
+ segment_end_time = start_time + select_length
30
+
31
+ # Create a new MIDI object for each chunk
32
+ chunk_midi_data = pretty_midi.PrettyMIDI()
33
+
34
+ # Merge non-drum instruments into a single instrument
35
+ merged_instrument = pretty_midi.Instrument(program=0, is_drum=False)
36
+
37
+ for instrument in midi_data.instruments:
38
+ if not instrument.is_drum:
39
+ for note in instrument.notes:
40
+ if start_time <= note.start < segment_end_time:
41
+ # Shift the note start and end times to start at 0
42
+ new_note = pretty_midi.Note(
43
+ velocity=note.velocity,
44
+ pitch=note.pitch,
45
+ start=note.start - start_time,
46
+ end=note.end - start_time
47
+ )
48
+ merged_instrument.notes.append(new_note)
49
+ else:
50
+ # If it's a drum instrument, just adjust the note times and append it
51
+ new_drum_instrument = pretty_midi.Instrument(program=instrument.program, is_drum=True,
52
+ name=instrument.name)
53
+ new_drum_instrument.notes = [note for note in instrument.notes if
54
+ start_time <= note.start < segment_end_time]
55
+ for note in new_drum_instrument.notes:
56
+ note.start -= start_time
57
+ note.end -= start_time
58
+ chunk_midi_data.instruments.append(new_drum_instrument)
59
+
60
+ # Add the merged instrument to the MIDI object
61
+ chunk_midi_data.instruments.append(merged_instrument)
62
+
63
+ # Save the chunk with the same name as the original file
64
+ chunk_midi_data.write(os.path.join(output_dir, midi_file_name))
65
+
66
+
67
+ if __name__ == "__main__":
68
+ parser = argparse.ArgumentParser(description="Chunk MIDI files into specified lengths.")
69
+ parser.add_argument("--input_path", type=str, help="Path to the directory containing the MIDI files to chunk.")
70
+ parser.add_argument("--output_dir", type=str, help="Path to the directory where the chunked MIDI files will be saved.")
71
+ parser.add_argument("--select_length", type=float, default=10.24, help="length to chunk the midi file to (s).")
72
+ args = parser.parse_args()
73
+
74
+ select_midi(args.input_path, args.output_dir, select_length=args.select_length)
diff_collage/README.md ADDED
@@ -0,0 +1,3 @@
+ # Diff Collage
+
+ This is an implementation of the [DiffCollage](https://arxiv.org/abs/2303.17076) paper. We use DiffCollage to generate long sequences that follow a certain structure (e.g., a loop).
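+
+ The core idea, as used in `condind_long.py` below, is that the score of a long sequence is assembled from the scores of overlapping windows, with the doubly counted overlap subtracted. A minimal sketch of this composition for two windows (time axis last; the helper name and default sizes are illustrative, not the module's API):
+ ```python
+ import torch
+
+ def compose_eps_two_windows(eps_fn, x_long, W=128, O=64):
+     # DiffCollage with two overlapping windows: the joint density factorizes as
+     # p(x) ~ p(x[..., :W]) * p(x[..., W-O:]) / p(x[..., W-O:W]), so the eps (score)
+     # predictions are summed over the windows and the eps of the shared overlap
+     # is subtracted (cf. CondIndSimple, which also evaluates the model on the overlap alone).
+     x1, x2 = x_long[..., :W], x_long[..., W - O:]   # two overlapping windows
+     eps1, eps2 = eps_fn(x1), eps_fn(x2)             # per-window predictions
+     eps_overlap = eps_fn(x1[..., -O:])              # prediction for the overlap alone
+     out = torch.zeros_like(x_long)
+     out[..., :W] += eps1
+     out[..., W - O:] += eps2
+     out[..., W - O:W] -= eps_overlap                # remove the double-counted factor
+     return out
+ ```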
diff_collage/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .w_loss import *
2
+ from .generic_sampler import *
3
+ from .condind_long import CondIndSimple
4
+ from .condind_circle import CondIndCircle
5
+ from .avg_long import AvgLong
diff_collage/avg_circle.py ADDED
@@ -0,0 +1,64 @@
1
+ import numpy as np
2
+ import torch as th
3
+ from einops import rearrange
4
+
5
+ from .generic_sampler import SimpleWork
6
+ from .w_img import split_wimg, avg_merge_wimg
7
+
8
+ class AvgCircle(SimpleWork):
9
+ def __init__(self, shape, eps_scalar_t_fn, num_img, overlap_size=32):
10
+ c, h, w = shape
11
+ self.base_img_w = w
12
+ self.overlap_size = overlap_size
13
+ self.num_img = num_img
14
+ final_img_w = w * num_img - self.overlap_size * num_img
15
+ super().__init__((c, h, final_img_w), self.get_eps_t_fn(eps_scalar_t_fn))
16
+
17
+ def get_eps_t_fn(self, eps_scalar_t_fn):
18
+ def eps_t_fn(long_x, scalar_t, enable_grad=False):
19
+ shift = np.random.randint(self.base_img_w)
20
+ long_x = th.cat(
21
+ [
22
+ long_x[:,:,:,shift:],
23
+ long_x[:,:,:,:shift]
24
+ ],
25
+ dim=-1
26
+ )
27
+
28
+ x = th.cat(
29
+ [
30
+ long_x,
31
+ long_x[:,:,:,:self.overlap_size]
32
+ ],
33
+ dim=-1,
34
+ )
35
+ xs, _overlap = split_wimg(x, self.num_img, rtn_overlap=True)
36
+ assert _overlap == self.overlap_size
37
+ full_eps = eps_scalar_t_fn(xs, scalar_t, enable_grad) # #((b,n), c, h, w)
38
+
39
+ eps = avg_merge_wimg(full_eps, self.overlap_size, n=self.num_img)
40
+ eps = th.cat(
41
+ [
42
+ (eps[:,:,:,:self.overlap_size] + eps[:,:,:,-self.overlap_size:])/2.0,
43
+ eps[:,:,:,self.overlap_size:-self.overlap_size]
44
+ ],
45
+ dim=-1
46
+ )
47
+ assert eps.shape == long_x.shape
48
+ return th.cat(
49
+ [
50
+ eps[:,:,:,-shift:],
51
+ eps[:,:,:,:-shift],
52
+ ],
53
+ dim=-1
54
+ )
55
+ # return eps
56
+
57
+ return eps_t_fn
58
+
59
+ def x0_fn(self, xt, scalar_t, enable_grad=False):
60
+ cur_eps = self.eps_scalar_t_fn(xt, scalar_t, enable_grad)
61
+ x0 = xt - scalar_t * cur_eps
62
+ return x0, {}, {
63
+ "x0": x0.cpu()
64
+ }
diff_collage/avg_long.py ADDED
@@ -0,0 +1,40 @@
1
+ import torch as th
2
+ from einops import rearrange
3
+
4
+ from .generic_sampler import SimpleWork
5
+ from .w_img import split_wimg, avg_merge_wimg
6
+
7
+ class AvgLong(SimpleWork):
8
+ def __init__(self, shape, eps_scalar_t_fn, num_img, overlap_size=32):
9
+ c, h, w = shape
10
+ assert overlap_size == w // 2
11
+ self.overlap_size = overlap_size
12
+ self.num_img = num_img
13
+ final_img_w = w * num_img - self.overlap_size * (num_img - 1)
14
+ super().__init__((c, h, final_img_w), self.get_eps_t_fn(eps_scalar_t_fn))
15
+
16
+ def loss(self, x):
17
+ x1, x2 = x[:-1], x[1:]
18
+ return th.sum(
19
+ (th.abs(x1[:, :, :, -self.overlap_size :] - x2[:, :, :, : self.overlap_size])) ** 2,
20
+ dim=(1, 2, 3),
21
+ )
22
+
23
+ def get_eps_t_fn(self, eps_scalar_t_fn):
24
+ def eps_t_fn(long_x, scalar_t, y=None):
25
+ xs = split_wimg(long_x, self.num_img, rtn_overlap=False)
26
+ if y is not None:
27
+ y = y.repeat_interleave(self.num_img)
28
+ scalar_t = scalar_t.repeat_interleave(self.num_img)
29
+ full_eps = eps_scalar_t_fn(xs, scalar_t, y=y) #((b,n), c, h, w)
30
+ full_eps = rearrange(
31
+ full_eps,
32
+ "(b n) c h w -> n b c h w", n = self.num_img
33
+ )
34
+
35
+ whole_eps = rearrange(
36
+ full_eps,
37
+ "n b c h w -> (b n) c h w"
38
+ )
39
+ return avg_merge_wimg(whole_eps, self.overlap_size, n=self.num_img, is_avg=False)
40
+ return eps_t_fn
diff_collage/condind_circle.py ADDED
@@ -0,0 +1,190 @@
1
+ import torch as th
2
+ from einops import rearrange
3
+
4
+ from .generic_sampler import SimpleWork
5
+ from .w_img import split_wimg, avg_merge_wimg
6
+
7
+ class CondIndCircle(SimpleWork):
8
+ def __init__(self, shape, eps_scalar_t_fn, num_img, overlap_size=32):
9
+ c, h, w = shape
10
+ assert overlap_size == w // 2
11
+ self.overlap_size = overlap_size
12
+ self.num_img = num_img
13
+ final_img_w = w * num_img - self.overlap_size * num_img
14
+ super().__init__((c, h, final_img_w), self.get_eps_t_fn(eps_scalar_t_fn))
15
+
16
+ def circle_split(self, in_x):
17
+ long_x = th.cat(
18
+ [
19
+ in_x,
20
+ in_x[:,:,:,:self.overlap_size],
21
+ ],
22
+ dim=-1
23
+ )
24
+ xs = split_wimg(long_x, self.num_img, rtn_overlap=False)
25
+ return xs
26
+
27
+ def circle_merge(self, xs, overlap_size=None):
28
+ if overlap_size is None:
29
+ overlap_size = self.overlap_size
30
+ long_xs = avg_merge_wimg(xs, overlap_size, n=self.num_img, is_avg=True)
31
+ return th.cat(
32
+ [
33
+ (
34
+ long_xs[:,:,:,:overlap_size] + long_xs[:,:,:,-overlap_size:]
35
+ ) / 2.0,
36
+ long_xs[:,:,:,overlap_size:-overlap_size]
37
+ ],
38
+ dim=-1
39
+ )
40
+
41
+ def get_eps_t_fn(self, eps_scalar_t_fn):
42
+ def eps_t_fn(in_x, scalar_t, y=None):
43
+ long_x = th.cat(
44
+ [
45
+ in_x,
46
+ in_x[:,:,:,:self.overlap_size],
47
+ ],
48
+ dim=-1
49
+ )
50
+ xs = split_wimg(long_x, self.num_img, rtn_overlap=False)
51
+ if y is not None:
52
+ y = y.repeat_interleave(self.num_img)
53
+ scalar_t = scalar_t.repeat_interleave(self.num_img)
54
+ full_eps = eps_scalar_t_fn(xs, scalar_t, y=y) #((b,n), c, h, w)
55
+ full_eps = rearrange(
56
+ full_eps,
57
+ "(b n) c h w -> n b c h w", n = self.num_img
58
+ )
59
+
60
+ # calculate half eps
61
+ half_eps = eps_scalar_t_fn(xs[:,:,:,-self.overlap_size:], scalar_t, y=y) #((b,n), c, h, w//2)
62
+ half_eps = rearrange(
63
+ half_eps,
64
+ "(b n) c h w -> n b c h w", n = self.num_img
65
+ )
66
+
67
+ half_eps[-1]=0
68
+
69
+ full_eps[:,:,:,:,-self.overlap_size:] = full_eps[:,:,:,:,-self.overlap_size:] - half_eps
70
+ whole_eps = rearrange(
71
+ full_eps,
72
+ "n b c h w -> (b n) c h w"
73
+ )
74
+ long_eps = avg_merge_wimg(whole_eps, self.overlap_size, n=self.num_img, is_avg=False)
75
+ return th.cat(
76
+ [
77
+ (
78
+ long_eps[:,:,:,:self.overlap_size] + long_eps[:,:,:,-self.overlap_size:]
79
+ ) / 2.0,
80
+ long_eps[:,:,:,self.overlap_size:-self.overlap_size]
81
+ ],
82
+ dim=-1
83
+ )
84
+ return eps_t_fn
85
+
86
+
87
+ class CondIndCircleSR(SimpleWork):
88
+ def __init__(self, shape, eps_scalar_t_fn, num_img, low_res, overlap_size=32):
89
+ c, h, w = shape
90
+ assert overlap_size == w // 2
91
+ self.overlap_size = overlap_size
92
+ self.low_overlap_size = low_res.shape[-2] // 2
93
+ self.num_img = num_img
94
+ final_img_w = w * num_img - self.overlap_size * num_img
95
+ assert low_res.shape[-1] == self.low_overlap_size * num_img
96
+
97
+ self.square_fn = self.get_square_sr_fn(eps_scalar_t_fn, low_res)
98
+ self.half_fn = self.get_half_sr_fn(eps_scalar_t_fn, low_res)
99
+
100
+ super().__init__((c, h, final_img_w), self.get_eps_t_fn())
101
+
102
+ def circle_split(self, in_x, overlap_size=None):
103
+ if overlap_size is None:
104
+ overlap_size = self.overlap_size
105
+ long_x = th.cat(
106
+ [
107
+ in_x,
108
+ in_x[:,:,:,:overlap_size],
109
+ ],
110
+ dim=-1
111
+ )
112
+ xs = split_wimg(long_x, self.num_img, rtn_overlap=False)
113
+ return xs
114
+
115
+ def circle_merge(self, xs, overlap_size=None):
116
+ if overlap_size is None:
117
+ overlap_size = self.overlap_size
118
+ long_xs = avg_merge_wimg(xs, overlap_size, n=self.num_img, is_avg=True)
119
+ return th.cat(
120
+ [
121
+ (
122
+ long_xs[:,:,:,:overlap_size] + long_xs[:,:,:,-overlap_size:]
123
+ ) / 2.0,
124
+ long_xs[:,:,:,overlap_size:-overlap_size]
125
+ ],
126
+ dim=-1
127
+ )
128
+
129
+ def get_square_sr_fn(self, eps_fn, low_res):
130
+ low_res = self.circle_split(low_res, self.low_overlap_size)
131
+ def _fn(_x, _t, enable_grad):
132
+ context = th.enable_grad if enable_grad else th.no_grad
133
+ with context():
134
+ vec_t = th.ones(_x.shape[0]).cuda() * _t
135
+ rtn = eps_fn(_x, vec_t, low_res)
136
+ rtn = rearrange(
137
+ rtn,
138
+ "(b n) c h w -> n b c h w", n = self.num_img
139
+ )
140
+ return rtn
141
+ return _fn
142
+
143
+ def get_half_sr_fn(self, eps_fn, low_res):
144
+ low_res = self.circle_split(low_res, self.low_overlap_size)
145
+ def _fn(_x, _t, enable_grad):
146
+ context = th.enable_grad if enable_grad else th.no_grad
147
+ with context():
148
+ vec_t = th.ones(_x.shape[0]).cuda() * _t
149
+ half_eps = eps_fn(_x[:,:,:,-self.overlap_size:], vec_t, low_res[:,:,:,-self.low_overlap_size:])
150
+ half_eps = rearrange(
151
+ half_eps,
152
+ "(b n) c h w -> n b c h w", n = self.num_img
153
+ )
154
+
155
+ half_eps[-1]=0
156
+ return half_eps
157
+ return _fn
158
+
159
+ def get_eps_t_fn(self):
160
+ def eps_t_fn(in_x, scalar_t, enable_grad=False):
161
+ long_x = th.cat(
162
+ [
163
+ in_x,
164
+ in_x[:,:,:,:self.overlap_size],
165
+ ],
166
+ dim=-1
167
+ )
168
+ xs = split_wimg(long_x, self.num_img, rtn_overlap=False)
169
+
170
+ # full eps
171
+ full_eps = self.square_fn(xs, scalar_t, enable_grad)
172
+ # calculate half eps
173
+ half_eps = self.half_fn(xs, scalar_t, enable_grad)
174
+
175
+ full_eps[:,:,:,:,-self.overlap_size:] = full_eps[:,:,:,:,-self.overlap_size:] - half_eps
176
+ whole_eps = rearrange(
177
+ full_eps,
178
+ "n b c h w -> (b n) c h w"
179
+ )
180
+ long_eps = avg_merge_wimg(whole_eps, self.overlap_size, n=self.num_img, is_avg=False)
181
+ return th.cat(
182
+ [
183
+ (
184
+ long_eps[:,:,:,:self.overlap_size] + long_eps[:,:,:,-self.overlap_size:]
185
+ ) / 2.0,
186
+ long_eps[:,:,:,self.overlap_size:-self.overlap_size]
187
+ ],
188
+ dim=-1
189
+ )
190
+ return eps_t_fn
diff_collage/condind_long.py ADDED
@@ -0,0 +1,147 @@
1
+ import torch
2
+ import torch as th
3
+ from einops import rearrange
4
+
5
+ from .generic_sampler import SimpleWork
6
+ from .w_img import split_wimg, avg_merge_wimg
7
+
8
+ class CondIndSimple(SimpleWork):
9
+ def __init__(self, shape, eps_scalar_t_fn, num_img, overlap_size=32):
10
+ c, h, w = shape
11
+ assert overlap_size == w // 2
12
+ self.overlap_size = overlap_size
13
+ self.num_img = num_img
14
+ final_img_w = w * num_img - self.overlap_size * (num_img - 1)
15
+ super().__init__((c, h, final_img_w), self.get_eps_t_fn(eps_scalar_t_fn))
16
+
17
+ def loss(self, x):
18
+ x1, x2 = x[:-1], x[1:]
19
+ return th.sum(
20
+ (th.abs(x1[:, :, :, -self.overlap_size :] - x2[:, :, :, : self.overlap_size])) ** 2,
21
+ dim=(1, 2, 3),
22
+ )
23
+
24
+ def get_eps_t_fn(self, eps_scalar_t_fn):
25
+ def eps_t_fn(long_x, scalar_t, y=None):
26
+ xs = split_wimg(long_x, self.num_img, rtn_overlap=False)
27
+ if y is not None:
28
+ y = y.repeat_interleave(self.num_img)
29
+ scalar_t = scalar_t.repeat_interleave(self.num_img)
30
+ full_eps = eps_scalar_t_fn(xs, scalar_t, y=y) #((b,n), c, h, w)
31
+ full_eps = rearrange(
32
+ full_eps,
33
+ "(b n) c h w -> n b c h w", n = self.num_img
34
+ )
35
+
36
+ # calculate half eps
37
+ half_eps = eps_scalar_t_fn(xs[:,:,:,-self.overlap_size:], scalar_t, y=y) #((b,n), c, h, w//2)
38
+ half_eps = rearrange(
39
+ half_eps,
40
+ "(b n) c h w -> n b c h w", n = self.num_img
41
+ )
42
+
43
+ half_eps[-1]=0
44
+
45
+ full_eps[:,:,:,:,-self.overlap_size:] = full_eps[:,:,:,:,-self.overlap_size:] - half_eps
46
+ whole_eps = rearrange(
47
+ full_eps,
48
+ "n b c h w -> (b n) c h w"
49
+ )
50
+ return avg_merge_wimg(whole_eps, self.overlap_size, n=self.num_img, is_avg=False)
51
+ return eps_t_fn
52
+
53
+
54
+
55
+ class CondIndSR(SimpleWork):
56
+ def __init__(self, shape, eps_scalar_t_fn, num_img, low_res, overlap_size=128):
57
+ c, h, w = shape
58
+ assert overlap_size == w // 2
59
+ self.overlap_size = overlap_size
60
+ self.low_overlap_size = low_res.shape[-2] // 2
61
+ self.num_img = num_img
62
+ final_img_w = w * num_img - self.overlap_size * (num_img - 1)
63
+ assert low_res.shape[-1] == self.low_overlap_size * (num_img + 1)
64
+
65
+ self.square_fn = self.get_square_sr_fn(eps_scalar_t_fn, low_res)
66
+ self.half_fn = self.get_half_sr_fn(eps_scalar_t_fn, low_res)
67
+
68
+ super().__init__((c, h, final_img_w), self.get_eps_t_fn())
69
+
70
+ def get_square_sr_fn(self, eps_fn, low_res):
71
+ low_res = split_wimg(low_res, self.num_img, False)
72
+ def _fn(_x, _t, enable_grad):
73
+ context = th.enable_grad if enable_grad else th.no_grad
74
+ with context():
75
+ vec_t = th.ones(_x.shape[0]).cuda() * _t
76
+ rtn = eps_fn(_x, vec_t, low_res)
77
+ rtn = rearrange(
78
+ rtn,
79
+ "(b n) c h w -> n b c h w", n = self.num_img
80
+ )
81
+ return rtn
82
+ return _fn
83
+
84
+ def get_half_sr_fn(self, eps_fn, low_res):
85
+ low_res = split_wimg(low_res, self.num_img, False)
86
+ def _fn(_x, _t, enable_grad):
87
+ context = th.enable_grad if enable_grad else th.no_grad
88
+ with context():
89
+ vec_t = th.ones(_x.shape[0]).cuda() * _t
90
+ half_eps = eps_fn(_x[:,:,:,-self.overlap_size:], vec_t, low_res[:,:,:,-self.low_overlap_size:])
91
+ half_eps = rearrange(
92
+ half_eps,
93
+ "(b n) c h w -> n b c h w", n = self.num_img
94
+ )
95
+
96
+ half_eps[-1]=0
97
+ return half_eps
98
+ return _fn
99
+
100
+ def get_eps_t_fn(self):
101
+ def eps_t_fn(in_x, scalar_t, enable_grad=False):
102
+ xs = split_wimg(in_x, self.num_img, rtn_overlap=False)
103
+
104
+ # full eps
105
+ full_eps = self.square_fn(xs, scalar_t, enable_grad)
106
+ # calculate half eps
107
+ half_eps = self.half_fn(xs, scalar_t, enable_grad)
108
+
109
+ full_eps[:,:,:,:,-self.overlap_size:] = full_eps[:,:,:,:,-self.overlap_size:] - half_eps
110
+ whole_eps = rearrange(
111
+ full_eps,
112
+ "n b c h w -> (b n) c h w"
113
+ )
114
+ out_eps = avg_merge_wimg(whole_eps, self.overlap_size, n=self.num_img, is_avg=False)
115
+ return out_eps
116
+ return eps_t_fn
117
+
118
+
119
+
120
+
121
+ # class CondIndLong(SimpleWork):
122
+ # def __init__(self, shape, eps_scalar_t_fn, overlap_size=32):
123
+ # super().__init__(shape, eps_scalar_t_fn)
124
+ # self.overlap_size = overlap_size
125
+
126
+ # def loss(self, x):
127
+ # x1, x2 = x[:-1], x[1:]
128
+ # return th.sum(
129
+ # (th.abs(x1[:, :, :, -self.overlap_size :] - x2[:, :, :, : self.overlap_size])) ** 2,
130
+ # dim=(1, 2, 3),
131
+ # )
132
+
133
+ # def generate_xT(self, n):
134
+ # white_noise = th.randn((n , *self.shape)).cuda()
135
+ # return self.noise(white_noise, None) * 80.0
136
+
137
+ # def noise(self, xt, scalar_t):
138
+ # del scalar_t
139
+ # noise = th.randn_like(xt)
140
+ # b, _, _, w = xt.shape
141
+ # final_img_w = w * b - self.overlap_size * (b - 1)
142
+ # noise = rearrange(noise, "(t n) c h w -> t c h (n w)", t=1)[:, :, :, :final_img_w]
143
+ # noise = split_wimg(noise, b, rtn_overlap=False)
144
+ # return noise
145
+
146
+ # def merge(self, xs):
147
+ # return avg_merge_wimg(xs, self.overlap_size)
diff_collage/generic_sampler.py ADDED
@@ -0,0 +1,113 @@
1
+ from collections import defaultdict
2
+
3
+ import math
4
+ import numpy as np
5
+ import torch as th
6
+ from tqdm import tqdm
7
+
8
+ __all__ = [
9
+ "generic_sampler",
10
+ "SimpleWork",
11
+ ]
12
+
13
+
14
+ def batch_mul(a, b): # pylint: disable=invalid-name
15
+ return th.einsum("a...,a...->a...", a, b)
16
+
17
+ class SimpleWork:
18
+ def __init__(self, shape, eps_scalar_t_fn):
19
+ self.shape = shape
20
+ self.eps_scalar_t_fn = eps_scalar_t_fn
21
+
22
+ def generate_xT(self, n):
23
+ return 80.0 * th.randn((n , *self.shape)).cuda()
24
+
25
+ def x0_fn(self, xt, scalar_t, y=None):
26
+ cur_eps = self.eps_scalar_t_fn(xt, scalar_t, y=y)
27
+ x0 = xt - scalar_t * cur_eps
28
+ x0 = th.clip(x0, -1,1)
29
+ return x0, {}, {"x0": x0.cpu()}
30
+
31
+ def noise(self, xt, scalar_t):
32
+ del scalar_t
33
+ return th.randn_like(xt)
34
+
35
+ def rev_ts(self, n_step, ts_order):
36
+ _rev_ts = th.pow(
37
+ th.linspace(
38
+ np.power(80.0, 1.0 / ts_order),
39
+ np.power(1e-3, 1.0 / ts_order),
40
+ n_step + 1
41
+ ),
42
+ ts_order
43
+ )
44
+ return _rev_ts.cuda()
45
+
46
+ def generic_sampler( # pylint: disable=too-many-locals
47
+ x,
48
+ rev_ts,
49
+ noise_fn,
50
+ x0_pred_fn,
51
+ xt_lgv_fn=None,
52
+ s_churn = 0.0,
53
+ before_step_fn=None,
54
+ end_fn=None, # to do???
55
+ is_tqdm=True,
56
+ is_traj=True,
57
+ ):
58
+ measure_loss = defaultdict(list)
59
+ traj = defaultdict(list)
60
+ if callable(x):
61
+ x = x()
62
+ if traj:
63
+ traj["xt"].append(x.cpu())
64
+
65
+ s_t_min = 0.05
66
+ s_t_max = 50.0
67
+ s_noise = 1.003
68
+ eta = min(s_churn / len(rev_ts), math.sqrt(2.0) - 1)
69
+
70
+ loop = zip(rev_ts[:-1], rev_ts[1:])
71
+ if is_tqdm:
72
+ loop = tqdm(loop)
73
+
74
+ running_x = x
75
+ for cur_t, next_t in loop:
76
+ # cur_x = traj["xt"][-1].clone().to("cuda")
77
+ cur_x = running_x
78
+ if cur_t < s_t_max and cur_t > s_t_min:
79
+ hat_cur_t = cur_t + eta * cur_t
80
+ cur_noise = noise_fn(cur_x, cur_t)
81
+ cur_x = cur_x + s_noise * cur_noise * th.sqrt(hat_cur_t ** 2 - cur_t ** 2)
82
+ cur_t = hat_cur_t
83
+
84
+ if before_step_fn is not None:
85
+ # TODO: may change the callabck
86
+ cur_x = before_step_fn(cur_x, cur_t)
87
+
88
+ x0, loss_info, traj_info = x0_pred_fn(cur_x, cur_t)
89
+ epsilon_1 = (cur_x - x0) / cur_t
90
+
91
+ xt_next = x0 + next_t * epsilon_1
92
+
93
+ x0, loss_info, traj_info = x0_pred_fn(xt_next, next_t)
94
+ epsilon_2 = (xt_next - x0) / next_t
95
+
96
+ xt_next = cur_x + (next_t - cur_t) * (epsilon_1 + epsilon_2) / 2
97
+
98
+ running_x = xt_next
99
+
100
+ if is_traj:
101
+ for key, value in loss_info.items():
102
+ measure_loss[key].append(value)
103
+
104
+ for key, value in traj_info.items():
105
+ traj[key].append(value)
106
+ traj["xt"].append(running_x.to("cpu").detach())
107
+
108
+ if xt_lgv_fn:
109
+ raise RuntimeError("Not implemented")
110
+
111
+ if is_traj:
112
+ return traj, measure_loss
113
+ return running_x
diff_collage/loss_helper.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch as th
2
+ from .generic_sampler import batch_mul
3
+
4
+ def get_x0_grad_pred_fn(raw_net_model, cond_loss_fn, weight_fn, x0_update, thres_t):
5
+ def fn(xt, scalar_t):
6
+ xt = xt.requires_grad_(True)
7
+ x0_pred = raw_net_model(xt, scalar_t)
8
+
9
+ loss_info = {
10
+ "raw_x0": cond_loss_fn(x0_pred.detach()).cpu(),
11
+ }
12
+ traj_info = {
13
+ "t": scalar_t,
14
+ }
15
+ if scalar_t < thres_t:
16
+ x0_cor = x0_pred.detach()
17
+ else:
18
+ pred_loss = cond_loss_fn(x0_pred)
19
+ grad_term = th.autograd.grad(pred_loss.sum(), xt)[0]
20
+ weights = weight_fn(x0_pred, grad_term, cond_loss_fn)
21
+ x0_cor = (x0_pred - batch_mul(weights, grad_term)).detach()
22
+ loss_info["weight"] = weights.detach().cpu()
23
+ traj_info["grad"] = grad_term.detach().cpu()
24
+
25
+ if x0_update:
26
+ x0 = x0_update(x0_cor, scalar_t)
27
+ else:
28
+ x0 = x0_cor
29
+
30
+ loss_info["cor_x0"] = cond_loss_fn(x0_cor.detach()).cpu()
31
+ loss_info["x0"] = cond_loss_fn(x0.detach()).cpu()
32
+ traj_info.update({
33
+ "raw_x0": x0_pred.detach().cpu(),
34
+ "cor_x0": x0_cor.detach().cpu(),
35
+ "x0": x0.detach().cpu(),
36
+ }
37
+ )
38
+ return x0_cor, loss_info, traj_info
39
+
40
+
41
+ return fn
diff_collage/w_img.py ADDED
@@ -0,0 +1,79 @@
1
+ import torch as th
2
+ from einops import rearrange
3
+
4
+ __all__ = [
5
+ "split_wimg",
6
+ ]
7
+
8
+ def split_wimg(wimg, n_img, rtn_overlap=True):
9
+ if wimg.ndim == 3:
10
+ wimg = wimg[None]
11
+ _, _, h, w = wimg.shape
12
+ base_len = 128 # todo: hard code 128 here (the length of the latents)
13
+ overlap_size = (n_img * base_len - w) // (n_img - 1)
14
+ assert n_img * base_len - overlap_size * (n_img - 1) == w
15
+
16
+ img = th.nn.functional.unfold(wimg, kernel_size=(h, base_len), stride=base_len - overlap_size) #(B, block, n_img)
17
+ img = rearrange(
18
+ img,
19
+ "b (c h w) n -> (b n) c h w", h=h, w=base_len
20
+ )
21
+
22
+ if rtn_overlap:
23
+ return img , overlap_size
24
+ return img
25
+
26
+ def avg_merge_wimg(imgs, overlap_size, n=None, is_avg=True):
27
+ b, _, h, w = imgs.shape
28
+ if n == None:
29
+ n = b
30
+ unfold_img = rearrange(
31
+ imgs,
32
+ "(b n) c h w -> b (c h w) n", n = n
33
+ )
34
+ img = th.nn.functional.fold(
35
+ unfold_img,
36
+ (h, n * w - (n-1) * overlap_size),
37
+ kernel_size = (h, w),
38
+ stride = w - overlap_size
39
+ )
40
+ if is_avg:
41
+ counter = th.nn.functional.fold(
42
+ th.ones_like(unfold_img),
43
+ (h, n * w - (n-1) * overlap_size),
44
+ kernel_size = (h, w),
45
+ stride = w - overlap_size
46
+ )
47
+ return img / counter
48
+ return img
49
+
50
+ # legacy code use naive implementation
51
+
52
+ def split_wimg_legacy(himg, n_img, rtn_overlap=True):
53
+ if himg.ndim == 3:
54
+ himg = himg[None]
55
+ _, _, h, w = himg.shape
56
+ overlap_size = (n_img * h - w) // (n_img - 1)
57
+ assert n_img * h - overlap_size * (n_img - 1) == w
58
+ himg = himg[0]
59
+ rtn_img = [himg[:, :, :h]]
60
+ for i in range(n_img - 1):
61
+ rtn_img.append(himg[:, :, (h - overlap_size) * (i + 1) : h + (h - overlap_size) * (i + 1)])
62
+ if rtn_overlap:
63
+ return th.stack(rtn_img), overlap_size
64
+ return th.stack(rtn_img)
65
+
66
+ def avg_merge_wimg_legacy(imgs, overlap_size):
67
+ _, _, _, w = imgs.shape
68
+ rtn_img = [imgs[0]]
69
+ for cur_img in imgs[1:]:
70
+ rtn_img.append(cur_img[:, :, overlap_size:])
71
+ first_img = th.cat(rtn_img, dim=-1)
72
+
73
+ rtn_img = []
74
+ for cur_img in imgs[:-1]:
75
+ rtn_img.append(cur_img[:, :, : w - overlap_size])
76
+ rtn_img.append(imgs[-1])
77
+ second_img = th.cat(rtn_img, dim=-1)
78
+
79
+ return (first_img + second_img) / 2.0
diff_collage/w_loss.py ADDED
@@ -0,0 +1,433 @@
1
+ import math
2
+ import torch as th
3
+ from einops import rearrange
4
+ import numpy as np
5
+
6
+ from .generic_sampler import batch_mul
7
+
8
+
9
+ def split_wimg(himg, n_img, rtn_overlap=True):
10
+ if himg.ndim == 3:
11
+ himg = himg[None]
12
+ _, _, h, w = himg.shape
13
+ overlap_size = (n_img * h - w) // (n_img - 1)
14
+ assert n_img * h - overlap_size * (n_img - 1) == w
15
+ himg = himg[0]
16
+ rtn_img = [himg[:, :, :h]]
17
+ for i in range(n_img - 1):
18
+ rtn_img.append(himg[:, :, (h - overlap_size) * (i + 1) : h + (h - overlap_size) * (i + 1)])
19
+ if rtn_overlap:
20
+ return th.stack(rtn_img), overlap_size
21
+ return th.stack(rtn_img)
22
+
23
+
24
+ def merge_wimg(imgs, overlap_size):
25
+ _, _, _, w = imgs.shape
26
+ rtn_img = [imgs[0]]
27
+ for cur_img in imgs[1:]:
28
+ rtn_img.append(cur_img[:, :, overlap_size:])
29
+ first_img = th.cat(rtn_img, dim=-1)
30
+
31
+ rtn_img = []
32
+ for cur_img in imgs[:-1]:
33
+ rtn_img.append(cur_img[:, :, : w - overlap_size])
34
+ rtn_img.append(imgs[-1])
35
+ second_img = th.cat(rtn_img, dim=-1)
36
+
37
+ return (first_img + second_img) / 2.0
38
+
39
+
40
+ def get_x0_pred_fn(raw_net_model, cond_loss_fn, weight_fn, x0_fn, thres_t, init_fn=None):
41
+ def fn(xt, scalar_t):
42
+ if init_fn is not None:
43
+ xt = init_fn(xt, scalar_t)
44
+ xt = xt.requires_grad_(True)
45
+ x0_pred = raw_net_model(xt, scalar_t)
46
+
47
+ loss_info = {
48
+ "raw_x0": cond_loss_fn(x0_pred.detach()).cpu(),
49
+ }
50
+ traj_info = {
51
+ "t": scalar_t,
52
+ }
53
+ if scalar_t < thres_t:
54
+ x0_cor = x0_pred.detach()
55
+ else:
56
+ pred_loss = cond_loss_fn(x0_pred)
57
+ grad_term = th.autograd.grad(pred_loss.sum(), xt)[0]
58
+ weights = weight_fn(x0_pred, grad_term, cond_loss_fn)
59
+ x0_cor = (x0_pred - batch_mul(weights, grad_term)).detach()
60
+ loss_info["weight"] = weights.detach().cpu()
61
+ traj_info["grad"] = grad_term.detach().cpu()
62
+
63
+ if x0_fn:
64
+ x0 = x0_fn(x0_cor, scalar_t)
65
+ else:
66
+ x0 = x0_cor
67
+
68
+ loss_info["cor_x0"] = cond_loss_fn(x0_cor.detach()).cpu()
69
+ loss_info["x0"] = cond_loss_fn(x0.detach()).cpu()
70
+ traj_info.update({
71
+ "raw_x0": x0_pred.detach().cpu(),
72
+ "cor_x0": x0_cor.detach().cpu(),
73
+ "x0": x0.detach().cpu(),
74
+ }
75
+ )
76
+ return x0_cor, loss_info, traj_info
77
+
78
+ return fn
79
+
80
+
81
+ def simple_noise(cur_t, xt):
82
+ del cur_t
83
+ return th.randn_like(xt)
84
+
85
+
86
+ def get_fix_weight_fn(fix_weight):
87
+ def weight_fn(xs, grads, *args):
88
+ del grads, args
89
+ return th.ones(xs.shape[0]).to(xs) * fix_weight
90
+
91
+ return weight_fn
92
+
93
+
94
+ class SeqWorker:
95
+ def __init__(self, overlap_size=10, src_img=None):
96
+ self.overlap_size = overlap_size
97
+ self.src_img = src_img
98
+
99
+ def loss(self, x):
100
+ return th.sum(
101
+ (th.abs(self.src_img[:, :, :, -self.overlap_size :] - x[:, :, :, : self.overlap_size]))
102
+ ** 2,
103
+ dim=(1, 2, 3),
104
+ )
105
+
106
+ def x0_replace(self, x0):
107
+ rtn_x0 = x0.clone()
108
+ rtn_x0[:, :, :, : self.overlap_size] = self.src_img[:, :, :, -self.overlap_size :]
109
+ return x0
110
+
111
+ def optimal_weight_fn(self, x0, grads, *args, ratio=1.0):
112
+ del args
113
+ overlap_size = self.overlap_size
114
+ # argmin_{w} (delta_pixel - w * delta_pixel)^2
115
+ delta_pixel = x0[:, :, :, :overlap_size] - self.src_img[:, :, :, -overlap_size:]
116
+ delta_grads = grads[:, :, :, :overlap_size]
117
+ num = th.sum(delta_pixel * delta_grads).item()
118
+ denum = th.sum(delta_grads * delta_grads).item()
119
+ _optimal_weight = num / denum
120
+ if math.isnan(_optimal_weight):
121
+ print(denum)
122
+ raise RuntimeError("nan for weights")
123
+
124
+ return ratio * _optimal_weight * th.ones(x0.shape[0]).to(x0)
125
+
126
+
127
+ class CircleWorker:
128
+ def __init__(self, overlap_size=10, adam_num_iter=100):
129
+ self.overlap_size = overlap_size
130
+ self.adam_num_iter = adam_num_iter
131
+
132
+
133
+ def get_match_patch(self, x):
134
+ tail = x[:, :, :, -self.overlap_size :]
135
+ head = x[:, :, :, : self.overlap_size]
136
+ tail = th.roll(tail, 1, 0)
137
+ return tail, head
138
+
139
+ def loss(self, x):
140
+ tail, head = self.get_match_patch(x)
141
+ return th.sum(
142
+ (tail - head)**2,
143
+ dim=(1, 2, 3),
144
+ )
145
+
146
+ def split_noise(self, cur_t, xt):
147
+ noise = simple_noise(cur_t, xt)
148
+ b, _, _, w = xt.shape
149
+ final_img_w = w * b - self.overlap_size * b
150
+ noise = rearrange(noise, "(t n) c h w -> t c h (n w)", t=1)[:, :, :, :final_img_w]
151
+ noise = th.cat([noise, noise[:,:,:, :self.overlap_size]], dim=-1)
152
+ noise, _ = split_wimg(noise, b)
153
+ return noise
154
+
155
+ def merge_circle_image(self, xt):
156
+ merged_long_img = merge_wimg(xt, self.overlap_size)
157
+ return th.cat(
158
+ [
159
+ (merged_long_img[:,:,:self.overlap_size] + merged_long_img[:,:,-self.overlap_size:]) / 2.0,
160
+ merged_long_img[:,:,self.overlap_size:-self.overlap_size],
161
+ ],
162
+ dim=-1
163
+ )
164
+
165
+ def split_circle_image(self, merged_long_img, n):
166
+ imgs,_ = split_wimg(
167
+ th.cat(
168
+ [
169
+ merged_long_img,
170
+ merged_long_img[:,:,:self.overlap_size],
171
+ ],
172
+ dim = -1,
173
+ ),
174
+ n
175
+ )
176
+ return imgs
177
+
178
+
179
+ def optimal_weight_fn(self, xs, grads, *args):
180
+ del args
181
+ # argmin_{w} (delta_pixel - w * delta_pixel)^2
182
+ tail, head = self.get_match_patch(xs)
183
+ delta_pixel = tail - head
184
+ tail, head = self.get_match_patch(grads)
185
+ delta_grads = tail - head
186
+
187
+ num = th.sum(delta_pixel * delta_grads).item()
188
+ denum = th.sum(delta_grads * delta_grads).item()
189
+ _optimal_weight = num / denum
190
+ return _optimal_weight * th.ones(xs.shape[0]).to(xs)
191
+
192
+ def adam_grad_weight(self, x0, grad_term, cond_loss_fn):
193
+ init_weight = self.optimal_weight_fn(x0, grad_term)
194
+ grad_term = grad_term.detach()
195
+ x0 = x0.detach()
196
+ with th.enable_grad():
197
+ weights = init_weight.requires_grad_()
198
+ optimizer = th.optim.Adam(
199
+ [
200
+ weights,
201
+ ],
202
+ lr=1e-2,
203
+ )
204
+
205
+ def _loss(w):
206
+ cor_x0 = x0 - batch_mul(w, grad_term)
207
+ return cond_loss_fn(cor_x0).sum()
208
+
209
+ for _ in range(self.adam_num_iter):
210
+ optimizer.zero_grad()
211
+ _cur_loss = _loss(weights)
212
+ _cur_loss.backward()
213
+ optimizer.step()
214
+ return weights
215
+
216
+ # TODO:
217
+ def x0_replace(self, x0, scalar_t, thres_t):
218
+ if scalar_t > thres_t:
219
+ merge_x0 = merge_wimg(x0, self.overlap_size)
220
+ return split_wimg(merge_x0, x0.shape[0])[0]
221
+ else:
222
+ return x0
223
+
224
+
225
+ class ParaWorker:
226
+ def __init__(self, overlap_size=10, adam_num_iter=100):
227
+ self.overlap_size = overlap_size
228
+ self.adam_num_iter = adam_num_iter
229
+
230
+ def loss(self, x):
231
+ x1, x2 = x[:-1], x[1:]
232
+ return th.sum(
233
+ (th.abs(x1[:, :, :, -self.overlap_size :] - x2[:, :, :, : self.overlap_size])) ** 2,
234
+ dim=(1, 2, 3),
235
+ )
236
+
237
+ def split_noise(self, xt, cur_t):
238
+ noise = simple_noise(cur_t, xt)
239
+ b, _, _, w = xt.shape
240
+ final_img_w = w * b - self.overlap_size * (b - 1)
241
+ noise = rearrange(noise, "(t n) c h w -> t c h (n w)", t=1)[:, :, :, :final_img_w]
242
+ noise, _ = split_wimg(noise, b)
243
+ return noise
244
+
245
+ def optimal_weight_fn(self, xs, grads, *args):
246
+ del args
247
+ overlap_size = self.overlap_size
248
+ # argmin_{w} (delta_pixel - w * delta_pixel)^2
249
+ delta_pixel = xs[:-1, :, :, -overlap_size:] - xs[1:, :, :, :overlap_size]
250
+ delta_grads = grads[:-1, :, :, -overlap_size:] - grads[1:, :, :, :overlap_size]
251
+ num = th.sum(delta_pixel * delta_grads).item()
252
+ denum = th.sum(delta_grads * delta_grads).item()
253
+ _optimal_weight = num / denum
254
+ return _optimal_weight * th.ones(xs.shape[0]).to(xs)
255
+
256
+ def adam_grad_weight(self, x0, grad_term, cond_loss_fn):
257
+ init_weight = self.optimal_weight_fn(x0, grad_term)
258
+ grad_term = grad_term.detach()
259
+ x0 = x0.detach()
260
+ with th.enable_grad():
261
+ weights = init_weight.requires_grad_()
262
+ optimizer = th.optim.Adam(
263
+ [
264
+ weights,
265
+ ],
266
+ lr=1e-2,
267
+ )
268
+
269
+ def _loss(w):
270
+ cor_x0 = x0 - batch_mul(w, grad_term)
271
+ return cond_loss_fn(cor_x0).sum()
272
+
273
+ for _ in range(self.adam_num_iter):
274
+ optimizer.zero_grad()
275
+ _cur_loss = _loss(weights)
276
+ _cur_loss.backward()
277
+ optimizer.step()
278
+ return weights
279
+
280
+ def x0_replace(self, x0, scalar_t, thres_t):
281
+ if scalar_t > thres_t:
282
+ merge_x0 = merge_wimg(x0, self.overlap_size)
283
+ return split_wimg(merge_x0, x0.shape[0])[0]
284
+ else:
285
+ return x0
286
+
287
+ class ParaWorkerC(ParaWorker):
288
+ def __init__(self, src_img, mask_img, inpaint_w = 1.0, overlap_size=10, adam_num_iter=100):
289
+ self.src_img = src_img
290
+ self.inpaint_w = inpaint_w
291
+ self.mask_img = mask_img # 1 indicate masked given pixels
292
+ super().__init__(overlap_size, adam_num_iter)
293
+
294
+ def loss(self, x):
295
+ if x.shape[0] == 1:
296
+ return th.sum(
297
+ th.sum(
298
+ th.square(self.src_img[:,:,:,:x.shape[-1]] - x), dim=(0,1)
299
+ ) * self.mask_img[:,:x.shape[-1]]
300
+ )
301
+ else:
302
+ consistent_loss = super().loss(x)
303
+ # merge image
304
+ merge_x = merge_wimg(x, self.overlap_size)
305
+
306
+ inpainting_loss = th.sum(
307
+ th.sum(
308
+ th.square(self.src_img[:,:,:,:merge_x.shape[-1]] - merge_x), dim=(0,1)
309
+ ) * self.mask_img[:,:merge_x.shape[-1]]
310
+ )
311
+
312
+ return consistent_loss + inpainting_loss / (x.shape[-1] - 1)
313
+
314
+ def x0_replace(self, x0, scalar_t, thres_t):
315
+ if scalar_t > thres_t:
316
+ merge_x = merge_wimg(x0, self.overlap_size)
317
+ src_img = self.src_img[:,:,:,:merge_x.shape[-1]]
318
+ mask_img = self.mask_img[:,:merge_x.shape[-1]]
319
+ merge_x = th.where(mask_img[None,None], src_img, merge_x)
320
+ return split_wimg(merge_x, x0.shape[0])[0]
321
+ else:
322
+ return x0
323
+
324
+
325
+ class SplitMergeOp:
326
+ def __init__(self, avg_overlap=32):
327
+ self.avg_overlap = avg_overlap
328
+ self.cur_overlap_int = None
329
+
330
+ def sample(self, n):
331
+ # lower_coef = 3 / 4.0
332
+ _lower_bound = self.avg_overlap - 6
333
+ base_overlap = np.ones(n) * _lower_bound
334
+
335
+ total_ball = (self.avg_overlap - _lower_bound) * n
336
+ random_number = np.random.randint(0, total_ball - n, n-1)
337
+ random_number = np.sort(random_number)
338
+ balls = np.append(random_number, total_ball - n) - np.insert(random_number, 0, 0) + np.ones(n) + base_overlap
339
+
340
+ assert np.sum(balls) == n * self.avg_overlap
341
+
342
+ # TODO: FIXME
343
+ balls = np.ones(n) * self.avg_overlap
344
+
345
+ return balls.astype(int)
346
+
347
+ def reset(self, n):
348
+ self.cur_overlap_int = self.sample(n)
349
+
350
+ def split(self, img, n, img_w=64):
351
+ assert img.ndim == 3
352
+ # assert img.shape[-1] > (n-1) * self.avg_overlap
353
+ assert (n-1) == self.cur_overlap_int.shape[0]
354
+
355
+ assert (n-1) * self.avg_overlap + img.shape[-1] == n * img_w
356
+
357
+ cur_idx = 0
358
+ imgs = []
359
+ for cur_overlap in self.cur_overlap_int:
360
+ imgs.append(img[:,:,cur_idx:cur_idx + img_w])
361
+ cur_idx = cur_idx + img_w - cur_overlap
362
+ imgs.append(img[:,:,cur_idx:])
363
+ return th.stack(imgs)
364
+
365
+ def merge(self, imgs):
366
+ b = imgs.shape[0]
367
+ img_size = imgs.shape[-1]
368
+ assert b - 1 == self.cur_overlap_int.shape[0]
369
+ img_width = b * imgs.shape[-1] - np.sum(self.cur_overlap_int)
370
+ wimg = th.zeros((3, imgs.shape[-2], img_width)).to(imgs)
371
+ ncnt = th.zeros(img_width).to(imgs)
372
+ cur_idx = 0
373
+ for i_th, cur_img in enumerate(imgs):
374
+ wimg[:,:,cur_idx:cur_idx + img_size] += cur_img
375
+ ncnt[cur_idx:cur_idx + img_size] += 1.0
376
+ if i_th < b -1:
377
+ cur_idx = cur_idx + img_size - self.cur_overlap_int[i_th]
378
+ return wimg / ncnt[None,None,:]
379
+
380
+
381
+ class ParaWorkerFix:
382
+ def __init__(self, overlap_size=10, adam_num_iter=100):
383
+ self.overlap_size = overlap_size
384
+ self.adam_num_iter = adam_num_iter
385
+ self.op = SplitMergeOp(overlap_size)
386
+
387
+ def loss(self, x):
388
+ avg_x = self.op.split(
389
+ self.op.merge(x), x.shape[0], x.shape[-1]
390
+ )
391
+ return th.sum(
392
+ (x - avg_x) ** 2,
393
+ dim=(1, 2, 3),
394
+ )
395
+
396
+ def split_noise(self, cur_t, xt):
397
+ noise = simple_noise(cur_t, xt)
398
+ b, _, _, w = xt.shape
399
+ final_img_w = w * b - self.overlap_size * (b - 1)
400
+ noise = rearrange(noise, "(t n) c h w -> t c h (n w)", t=1)[:, :, :, :final_img_w][0]
401
+ noise = self.op.split(noise, b, w)
402
+ return noise
403
+
404
+ def adam_grad_weight(self, x0, grad_term, cond_loss_fn):
405
+ init_weight = th.ones(x0.shape[0]).to(x0)
406
+ grad_term = grad_term.detach()
407
+ x0 = x0.detach()
408
+ with th.enable_grad():
409
+ weights = init_weight.requires_grad_()
410
+ optimizer = th.optim.Adam(
411
+ [
412
+ weights,
413
+ ],
414
+ lr=1e-2,
415
+ )
416
+
417
+ def _loss(w):
418
+ cor_x0 = x0 - batch_mul(w, grad_term)
419
+ return cond_loss_fn(cor_x0).sum()
420
+
421
+ for _ in range(self.adam_num_iter):
422
+ optimizer.zero_grad()
423
+ _cur_loss = _loss(weights)
424
+ _cur_loss.backward()
425
+ optimizer.step()
426
+ return weights
427
+
428
+ def x0_replace(self, x0, scalar_t, thres_t):
429
+ if scalar_t > thres_t:
430
+ merge_x0 = self.op.merge(x0)
431
+ return self.op.split(merge_x0, x0.shape[0], x0.shape[-1])
432
+ else:
433
+ return x0
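
The `split_wimg` / `merge_wimg` helpers above are the split-and-average primitives the diff-collage workers build on. A minimal round-trip sketch, assuming the helpers are imported from `diff_collage.w_loss` as added in this commit; the tensor sizes are illustrative only:

```python
import torch as th
from diff_collage.w_loss import split_wimg, merge_wimg  # module path as added in this commit

# A wide image whose width satisfies w = n*h - (n-1)*overlap.
n, h, overlap = 3, 64, 16
w = n * h - (n - 1) * overlap
himg = th.randn(1, 4, h, w)           # 1 x C x H x W

chunks, ov = split_wimg(himg, n)      # (n, C, H, H) overlapping square chunks
assert ov == overlap
merged = merge_wimg(chunks, ov)       # (C, H, W); overlapping columns are averaged
assert th.allclose(merged, himg[0])
```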
environment.yml ADDED
@@ -0,0 +1,282 @@
1
+ name: guided
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - conda-forge
6
+ - defaults
7
+ dependencies:
8
+ - _libgcc_mutex=0.1=main
9
+ - _openmp_mutex=5.1=1_gnu
10
+ - blas=1.0=mkl
11
+ - brotlipy=0.7.0=py39h27cfd23_1003
12
+ - bzip2=1.0.8=h7b6447c_0
13
+ - ca-certificates=2023.08.22=h06a4308_0
14
+ - certifi=2023.7.22=py39h06a4308_0
15
+ - cffi=1.15.1=py39h5eee18b_3
16
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
17
+ - cryptography=39.0.1=py39h9ce1e76_2
18
+ - cuda-cudart=11.7.99=0
19
+ - cuda-cupti=11.7.101=0
20
+ - cuda-libraries=11.7.1=0
21
+ - cuda-nvrtc=11.7.99=0
22
+ - cuda-nvtx=11.7.91=0
23
+ - cuda-runtime=11.7.1=0
24
+ - ffmpeg=4.3=hf484d3e_0
25
+ - filelock=3.9.0=py39h06a4308_0
26
+ - freetype=2.12.1=h4a9f257_0
27
+ - giflib=5.2.1=h5eee18b_3
28
+ - gmp=6.2.1=h295c915_3
29
+ - gmpy2=2.1.2=py39heeb90bb_0
30
+ - gnutls=3.6.15=he1e5248_0
31
+ - idna=3.4=py39h06a4308_0
32
+ - intel-openmp=2023.1.0=hdb19cb5_46305
33
+ - jinja2=3.1.2=py39h06a4308_0
34
+ - jpeg=9e=h5eee18b_1
35
+ - lame=3.100=h7b6447c_0
36
+ - lcms2=2.12=h3be6417_0
37
+ - ld_impl_linux-64=2.38=h1181459_1
38
+ - lerc=3.0=h295c915_0
39
+ - libcublas=11.10.3.66=0
40
+ - libcufft=10.7.2.124=h4fbf590_0
41
+ - libcufile=1.6.1.9=0
42
+ - libcurand=10.3.2.106=0
43
+ - libcusolver=11.4.0.1=0
44
+ - libcusparse=11.7.4.91=0
45
+ - libdeflate=1.17=h5eee18b_0
46
+ - libffi=3.4.4=h6a678d5_0
47
+ - libgcc-ng=11.2.0=h1234567_1
48
+ - libgfortran-ng=7.5.0=h14aa051_20
49
+ - libgfortran4=7.5.0=h14aa051_20
50
+ - libgomp=11.2.0=h1234567_1
51
+ - libiconv=1.16=h7f8727e_2
52
+ - libidn2=2.3.4=h5eee18b_0
53
+ - libnpp=11.7.4.75=0
54
+ - libnvjpeg=11.8.0.2=0
55
+ - libpng=1.6.39=h5eee18b_0
56
+ - libstdcxx-ng=11.2.0=h1234567_1
57
+ - libtasn1=4.19.0=h5eee18b_0
58
+ - libtiff=4.5.0=h6a678d5_2
59
+ - libunistring=0.9.10=h27cfd23_0
60
+ - libwebp=1.2.4=h11a3e52_1
61
+ - libwebp-base=1.2.4=h5eee18b_1
62
+ - lz4-c=1.9.4=h6a678d5_0
63
+ - markupsafe=2.1.1=py39h7f8727e_0
64
+ - mkl=2023.1.0=h6d00ec8_46342
65
+ - mkl-service=2.4.0=py39h5eee18b_1
66
+ - mkl_fft=1.3.6=py39h417a72b_1
67
+ - mkl_random=1.2.2=py39h417a72b_1
68
+ - mpc=1.1.0=h10f8cd9_1
69
+ - mpfr=4.0.2=hb69a4c5_1
70
+ - mpi=1.0=openmpi
71
+ - mpi4py=3.1.4=py39h3e5f7c9_0
72
+ - mpmath=1.2.1=py39h06a4308_0
73
+ - ncurses=6.4=h6a678d5_0
74
+ - nettle=3.7.3=hbbd107a_1
75
+ - networkx=2.8.4=py39h06a4308_1
76
+ - numpy=1.24.3=py39hf6e8229_1
77
+ - numpy-base=1.24.3=py39h060ed82_1
78
+ - openh264=2.1.1=h4ff587b_0
79
+ - openmpi=4.0.4=hdf1f1ad_0
80
+ - openssl=3.0.12=h7f8727e_0
81
+ - pillow=9.4.0=py39h6a678d5_0
82
+ - pip=23.1.2=py39h06a4308_0
83
+ - pycparser=2.21=pyhd3eb1b0_0
84
+ - pyopenssl=23.0.0=py39h06a4308_0
85
+ - pysocks=1.7.1=py39h06a4308_0
86
+ - python=3.9.16=h955ad1f_3
87
+ - pytorch=2.0.1=py3.9_cuda11.7_cudnn8.5.0_0
88
+ - pytorch-cuda=11.7=h778d358_5
89
+ - pytorch-mutex=1.0=cuda
90
+ - readline=8.2=h5eee18b_0
91
+ - requests=2.31.0=pyhd8ed1ab_0
92
+ - setuptools=67.8.0=py39h06a4308_0
93
+ - six=1.16.0=pyhd3eb1b0_1
94
+ - sqlite=3.41.2=h5eee18b_0
95
+ - sympy=1.11.1=py39h06a4308_0
96
+ - tbb=2021.8.0=hdb19cb5_0
97
+ - tk=8.6.12=h1ccaba5_0
98
+ - torchaudio=2.0.2=py39_cu117
99
+ - torchtext=0.6.0=py_1
100
+ - torchtriton=2.0.0=py39
101
+ - torchvision=0.15.2=py39_cu117
102
+ - tqdm=4.65.0=py39hb070fc8_0
103
+ - typing_extensions=4.6.3=py39h06a4308_0
104
+ - urllib3=1.26.16=py39h06a4308_0
105
+ - wheel=0.38.4=py39h06a4308_0
106
+ - xz=5.4.2=h5eee18b_0
107
+ - zlib=1.2.13=h5eee18b_0
108
+ - zstd=1.5.5=hc292b87_0
109
+ - pip:
110
+ - absl-py==1.4.0
111
+ - accelerate==0.21.0
112
+ - anyio==4.0.0
113
+ - appdirs==1.4.4
114
+ - argon2-cffi==23.1.0
115
+ - argon2-cffi-bindings==21.2.0
116
+ - arrow==1.3.0
117
+ - asttokens==2.4.1
118
+ - async-lru==2.0.4
119
+ - attrs==23.1.0
120
+ - awscli==1.29.84
121
+ - babel==2.13.1
122
+ - beautifulsoup4==4.12.2
123
+ - bleach==6.1.0
124
+ - blobfile==2.0.2
125
+ - boltons==23.0.0
126
+ - botocore==1.31.84
127
+ - cachetools==5.3.1
128
+ - chardet==5.1.0
129
+ - clean-fid==0.1.35
130
+ - click==8.1.3
131
+ - clip-anytorch==2.5.2
132
+ - colorama==0.4.4
133
+ - comm==0.1.4
134
+ - contextlib2==21.6.0
135
+ - contourpy==1.1.0
136
+ - cycler==0.11.0
137
+ - debugpy==1.8.0
138
+ - decorator==5.1.1
139
+ - defusedxml==0.7.1
140
+ - docker-pycreds==0.4.0
141
+ - docutils==0.16
142
+ - einops==0.6.1
143
+ - exceptiongroup==1.1.3
144
+ - executing==2.0.1
145
+ - fastjsonschema==2.18.1
146
+ - fonttools==4.40.0
147
+ - fqdn==1.5.1
148
+ - fsspec==2023.6.0
149
+ - ftfy==6.1.1
150
+ - future==0.18.3
151
+ - gitdb==4.0.10
152
+ - gitpython==3.1.31
153
+ - google-auth==2.21.0
154
+ - google-auth-oauthlib==1.0.0
155
+ - grpcio==1.56.0
156
+ - huggingface-hub==0.15.1
157
+ - imageio==2.31.1
158
+ - importlib-metadata==6.7.0
159
+ - importlib-resources==5.12.0
160
+ - ipykernel==6.26.0
161
+ - ipython==8.17.2
162
+ - ipython-genutils==0.2.0
163
+ - ipywidgets==8.1.1
164
+ - isoduration==20.11.0
165
+ - jedi==0.19.1
166
+ - jmespath==1.0.1
167
+ - joblib==1.3.1
168
+ - json5==0.9.14
169
+ - jsonmerge==1.9.2
170
+ - jsonpickle==3.0.1
171
+ - jsonpointer==2.4
172
+ - jsonschema==4.19.0
173
+ - jsonschema-specifications==2023.7.1
174
+ - jupyter==1.0.0
175
+ - jupyter-client==8.5.0
176
+ - jupyter-console==6.6.3
177
+ - jupyter-core==5.5.0
178
+ - jupyter-events==0.8.0
179
+ - jupyter-lsp==2.2.0
180
+ - jupyter-server==2.9.1
181
+ - jupyter-server-terminals==0.4.4
182
+ - jupyterlab==4.0.8
183
+ - jupyterlab-pygments==0.2.2
184
+ - jupyterlab-server==2.25.0
185
+ - jupyterlab-widgets==3.0.9
186
+ - k-diffusion==0.0.16
187
+ - kiwisolver==1.4.4
188
+ - kornia==0.7.0
189
+ - lazy-loader==0.3
190
+ - lmdb==1.4.1
191
+ - lxml==4.9.2
192
+ - markdown==3.4.3
193
+ - matplotlib==3.7.1
194
+ - matplotlib-inline==0.1.6
195
+ - mido==1.2.10
196
+ - mistune==3.0.2
197
+ - ml-collections==0.1.1
198
+ - more-itertools==10.0.0
199
+ - music21==8.3.0
200
+ - nbclient==0.8.0
201
+ - nbconvert==7.10.0
202
+ - nbformat==5.9.2
203
+ - nest-asyncio==1.5.8
204
+ - notebook==7.0.6
205
+ - notebook-shim==0.2.3
206
+ - oauthlib==3.2.2
207
+ - omegaconf==2.0.0
208
+ - overrides==7.4.0
209
+ - packaging==23.1
210
+ - pandas==2.0.2
211
+ - pandocfilters==1.5.0
212
+ - parso==0.8.3
213
+ - pathtools==0.1.2
214
+ - pexpect==4.8.0
215
+ - platformdirs==3.11.0
216
+ - prometheus-client==0.18.0
217
+ - prompt-toolkit==3.0.39
218
+ - protobuf==4.23.3
219
+ - psutil==5.9.5
220
+ - ptyprocess==0.7.0
221
+ - pure-eval==0.2.2
222
+ - pyasn1==0.5.0
223
+ - pyasn1-modules==0.3.0
224
+ - pycryptodomex==3.18.0
225
+ - pygments==2.16.1
226
+ - pyparsing==3.1.0
227
+ - python-dateutil==2.8.2
228
+ - python-json-logger==2.0.7
229
+ - pytorch-lightning==1.0.8
230
+ - pytz==2023.3
231
+ - pywavelets==1.4.1
232
+ - pyyaml==6.0
233
+ - pyzmq==25.1.1
234
+ - qtconsole==5.4.4
235
+ - qtpy==2.4.1
236
+ - referencing==0.30.2
237
+ - regex==2023.8.8
238
+ - requests-oauthlib==1.3.1
239
+ - resize-right==0.0.2
240
+ - rfc3339-validator==0.1.4
241
+ - rfc3986-validator==0.1.1
242
+ - rotary-embedding-torch==0.3.2
243
+ - rpds-py==0.9.2
244
+ - rsa==4.7.2
245
+ - s3transfer==0.7.0
246
+ - safetensors==0.3.1
247
+ - scikit-image==0.21.0
248
+ - scikit-learn==1.3.2
249
+ - scipy==1.11.2
250
+ - seaborn==0.13.0
251
+ - send2trash==1.8.2
252
+ - sentry-sdk==1.25.1
253
+ - setproctitle==1.3.2
254
+ - smmap==5.0.0
255
+ - sniffio==1.3.0
256
+ - soupsieve==2.5
257
+ - stack-data==0.6.3
258
+ - tensorboard==2.13.0
259
+ - tensorboard-data-server==0.7.1
260
+ - terminado==0.17.1
261
+ - threadpoolctl==3.2.0
262
+ - tifffile==2023.8.12
263
+ - timm==0.9.2
264
+ - tinycss2==1.2.1
265
+ - tomli==2.0.1
266
+ - torchdata==0.6.1
267
+ - torchdiffeq==0.2.3
268
+ - torchsde==0.2.5
269
+ - tornado==6.3.3
270
+ - traitlets==5.13.0
271
+ - trampoline==0.1.2
272
+ - types-python-dateutil==2.8.19.14
273
+ - tzdata==2023.3
274
+ - uri-template==1.3.0
275
+ - wandb==0.15.4
276
+ - wcwidth==0.2.6
277
+ - webcolors==1.13
278
+ - webencodings==0.5.1
279
+ - websocket-client==1.6.4
280
+ - werkzeug==2.3.6
281
+ - widgetsnbextension==4.0.9
282
+ - zipp==3.15.0
guided_diffusion/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """
2
+ Codebase for "Improved Denoising Diffusion Probabilistic Models".
3
+ """
guided_diffusion/condition_functions.py ADDED
@@ -0,0 +1,174 @@
1
+ import argparse
2
+ import os
3
+
4
+ import numpy as np
5
+ import torch as th
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from .pr_datasets_all import FUNC_DICT
10
+ import matplotlib.pyplot as plt
11
+
12
+ plt.rcParams["figure.figsize"] = (20, 3)
13
+ plt.rcParams['figure.dpi'] = 300
14
+ plt.rcParams['savefig.dpi'] = 300
15
+
16
+
17
+ def model_fn(x, t, y=None, rule=None,
18
+ model=nn.Identity(), num_classes=3, class_cond=True, cfg=False, w=0.):
19
+ # y has to be composer, rule is a dummy input
20
+ y_null = th.tensor([num_classes] * x.shape[0], device=x.device)
21
+ if class_cond:
22
+ if cfg:
23
+ return (1 + w) * model(x, t, y) - w * model(x, t, y_null)
24
+ else:
25
+ return model(x, t, y)
26
+ else:
27
+ return model(x, t, y_null)
28
+
29
+
30
+ def dc_model_fn(x, t, y=None, rule=None,
31
+ model=nn.Identity(), num_classes=3, class_cond=True, cfg=False, w=0.):
32
+ # diffcollage score function takes in 4 x pitch x time
33
+ x = x.permute(0, 1, 3, 2)
34
+ y_null = th.tensor([num_classes] * x.shape[0], device=x.device)
35
+ if class_cond:
36
+ if cfg:
37
+ eps = (1 + w) * model(x, t, y) - w * model(x, t, y_null)
38
+ return eps.permute(0, 1, 3, 2) # need to return 4 x time x pitch
39
+ else:
40
+ return model(x, t, y).permute(0, 1, 3, 2)
41
+ else:
42
+ return model(x, t, y_null).permute(0, 1, 3, 2)
43
+
44
+
45
+ # y is a dummy input for cond_fn, rule is the real input
46
+ def grad_nn_zt_xentropy(x, y=None, rule=None, classifier=nn.Identity()):
47
+ # Xentropy cond_fn
48
+ assert rule is not None
49
+ t = th.zeros(x.shape[0], device=x.device)
50
+ with th.enable_grad():
51
+ x_in = x.detach().requires_grad_(True)
52
+ logits = classifier(x_in, t)
53
+ log_probs = F.log_softmax(logits, dim=-1)
54
+ selected = log_probs[range(len(logits)), rule.view(-1)]
55
+ return th.autograd.grad(selected.sum(), x_in)[0]
56
+
57
+
58
+ def grad_nn_zt_mse(x, t, y=None, rule=None, classifier_scale=10., classifier=nn.Identity()):
59
+ assert rule is not None
60
+ with th.enable_grad():
61
+ x_in = x.detach().requires_grad_(True)
62
+ logits = classifier(x_in, t)
63
+ log_probs = - F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
64
+ return th.autograd.grad(log_probs.sum(), x_in)[0] * classifier_scale
65
+
66
+
67
+ def grad_nn_zt_chord(x, t, y=None, rule=None, classifier_scale=10., classifier=nn.Identity(), both=False):
68
+ assert rule is not None
69
+ with th.enable_grad():
70
+ x_in = x.detach().requires_grad_(True)
71
+ key_logits, chord_logits = classifier(x_in, t)
72
+ if both:
73
+ rule_key = rule[:, :1]
74
+ rule_chord = rule[:, 1:]
75
+ rule_chord = rule_chord.reshape(-1)
76
+ chord_logits = chord_logits.reshape(-1, chord_logits.shape[-1])
77
+ key_log_probs = - F.cross_entropy(key_logits, rule_key, reduction="none")
78
+ chord_log_probs = - F.cross_entropy(chord_logits, rule_chord, reduction="none")
79
+ chord_log_probs = chord_log_probs.reshape(x_in.shape[0], -1).mean(dim=-1)
80
+ log_probs = key_log_probs + chord_log_probs
81
+ else:
82
+ rule = rule.reshape(-1)
83
+ chord_logits = chord_logits.reshape(-1, chord_logits.shape[-1])
84
+ log_probs = - F.cross_entropy(chord_logits, rule, reduction="none")
85
+ return th.autograd.grad(log_probs.sum(), x_in)[0] * classifier_scale
86
+
87
+
88
+ def nn_z0_chord_dummy(x, t, y=None, rule=None, classifier_scale=0.1, classifier=nn.Identity(), both=False):
89
+ # classifier_scale is equivalent to step_size
90
+ t = th.zeros(x.shape[0], device=x.device)
91
+ key_logits, chord_logits = classifier(x, t)
92
+ if both:
93
+ rule_key = rule[:, :1]
94
+ rule_chord = rule[:, 1:]
95
+ rule_chord = rule_chord.reshape(-1)
96
+ chord_logits = chord_logits.reshape(-1, chord_logits.shape[-1])
97
+ key_log_probs = - F.cross_entropy(key_logits, rule_key, reduction="none")
98
+ chord_log_probs = - F.cross_entropy(chord_logits, rule_chord, reduction="none")
99
+ chord_log_probs = chord_log_probs.reshape(x.shape[0], -1).mean(dim=-1)
100
+ log_probs = key_log_probs + chord_log_probs
101
+ else:
102
+ rule = rule.reshape(-1)
103
+ chord_logits = chord_logits.reshape(-1, chord_logits.shape[-1])
104
+ log_probs = - F.cross_entropy(chord_logits, rule, reduction="none")
105
+ log_probs = log_probs.reshape(x.shape[0], -1).mean(dim=-1)
106
+ return log_probs * classifier_scale
107
+
108
+
109
+ def nn_z0_mse_dummy(x, t, y=None, rule=None, classifier_scale=0.1, classifier=nn.Identity()):
110
+ # mse cond_fn, t is a dummy variable b/c wrap_model in respace
111
+ assert rule is not None
112
+ t = th.zeros(x.shape[0], device=x.device)
113
+ logits = classifier(x, t)
114
+ log_probs = - F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
115
+ return log_probs * classifier_scale
116
+
117
+
118
+ def nn_z0_mse(x, rule=None, classifier=nn.Identity()):
119
+ # mse cond_fn, t is a dummy variable b/c wrap_model in respace
120
+ t = th.zeros(x.shape[0], device=x.device)
121
+ logits = classifier(x, t)
122
+ log_probs = - F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
123
+ return log_probs
124
+
125
+
126
+ def rule_x0_mse_dummy(x, t, y=None, rule=None, rule_name='pitch_hist'):
127
+ # use differentiable rule to differentiate through rule(x_0), t is a dummy variable b/c wrap_model in respace
128
+ logits = FUNC_DICT[rule_name](x)
129
+ log_probs = - F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
130
+ return log_probs
131
+
132
+
133
+ def rule_x0_mse(x, rule=None, rule_name='pitch_hist', soft=False):
134
+ # soften non-differentiable rule to differentiate through rule(x_0)
135
+ # softening doesn't seem to work in practice, so callers never pass `soft` and it stays False
136
+ logits = FUNC_DICT[rule_name](x, soft=soft)
137
+ log_probs = - F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
138
+ return log_probs
139
+
140
+
141
+ class _WrappedFn:
142
+ def __init__(self, fn):
143
+ self.fn = fn
144
+
145
+ def __call__(self, x, t, y=None, rule=None):
146
+ return self.fn(x, t, y, rule)
147
+
148
+
149
+ function_map = {
150
+ "grad_nn_zt_xentropy": grad_nn_zt_xentropy,
151
+ "grad_nn_zt_mse": grad_nn_zt_mse,
152
+ "grad_nn_zt_chord": grad_nn_zt_chord,
153
+ "nn_z0_chord_dummy": nn_z0_chord_dummy,
154
+ "nn_z0_mse_dummy": nn_z0_mse_dummy,
155
+ "nn_z0_mse": nn_z0_mse,
156
+ "rule_x0_mse_dummy": rule_x0_mse_dummy,
157
+ "rule_x0_mse": rule_x0_mse
158
+ }
159
+
160
+
161
+ def composite_nn_zt(x, t, y=None, rule=None, fns=None, classifier_scales=None, classifiers=None, rule_names=None):
162
+ num_classifiers = len(classifiers)
163
+ out = 0
164
+ for i in range(num_classifiers):
165
+ out += function_map[fns[i]](x, t, y=y, rule=rule[rule_names[i]],
166
+ classifier_scale=classifier_scales[i], classifier=classifiers[i])
167
+ return out
168
+
169
+
170
+ def composite_rule(x, t, y=None, rule=None, fns=None, classifier_scales=None, rule_names=None):
171
+ out = 0
172
+ for i in range(len(fns)):
173
+ out += function_map[fns[i]](x, t, y=y, rule=rule[rule_names[i]], rule_name=rule_names[i]) * classifier_scales[i]
174
+ return out
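
`composite_rule` (and its neural counterpart `composite_nn_zt`) just sums the scale-weighted log-likelihood terms of the guidance functions looked up in `function_map`. A minimal sketch of assembling such a composite guidance function with `functools.partial`; the rule names and target shapes here are hypothetical placeholders, since the real keys come from `pr_datasets_all.FUNC_DICT`:

```python
from functools import partial
import torch as th
from guided_diffusion.condition_functions import composite_rule  # module path as added in this commit

# Hypothetical rule targets keyed by rule name (real keys come from FUNC_DICT).
rule = {"pitch_hist": th.rand(8, 12), "note_density": th.rand(8, 16)}

cond_fn = partial(
    composite_rule,
    fns=["rule_x0_mse_dummy", "rule_x0_mse_dummy"],
    classifier_scales=[1.0, 0.5],
    rule_names=["pitch_hist", "note_density"],
)
# log_prob = cond_fn(x0_pred, t, rule=rule)  # one scale-weighted MSE log-prob term per rule
```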
guided_diffusion/dist_util.py ADDED
@@ -0,0 +1,104 @@
1
+ """
2
+ Helpers for distributed training.
3
+ """
4
+
5
+ import io
6
+ import os
7
+ import socket
8
+
9
+ import blobfile as bf
10
+ from mpi4py import MPI
11
+ import torch as th
12
+ import torch.distributed as dist
13
+
14
+ # Change this to reflect your cluster layout.
15
+ # The GPU for a given rank is (rank % GPUS_PER_NODE).
16
+ GPUS_PER_NODE = 2
17
+
18
+ SETUP_RETRY_COUNT = 3
19
+
20
+
21
+ def setup_dist(port=None):
22
+ """
23
+ Set up a distributed process group.
24
+ For NGC, set port = "8023"
25
+ """
26
+ if dist.is_initialized():
27
+ return
28
+ if not os.environ.get("CUDA_VISIBLE_DEVICES"):
29
+ os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
30
+
31
+ comm = MPI.COMM_WORLD
32
+ backend = "gloo" if not th.cuda.is_available() else "nccl"
33
+
34
+ if backend == "gloo":
35
+ hostname = "localhost"
36
+ else:
37
+ hostname = socket.gethostbyname(socket.getfqdn())
38
+ if port is not None:
39
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
40
+ else:
41
+ os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
42
+ os.environ["RANK"] = str(comm.rank)
43
+ os.environ["WORLD_SIZE"] = str(comm.size)
44
+
45
+ if port is not None:
46
+ os.environ["MASTER_PORT"] = port
47
+ else:
48
+ port = comm.bcast(_find_free_port(), root=0)
49
+ os.environ["MASTER_PORT"] = str(port)
50
+ dist.init_process_group(backend=backend, init_method="env://")
51
+ th.cuda.set_device(comm.rank) # need to run on hpc
52
+
53
+ return comm
54
+
55
+
56
+ def dev():
57
+ """
58
+ Get the device to use for torch.distributed.
59
+ """
60
+ if th.cuda.is_available():
61
+ return th.device(f"cuda")
62
+ return th.device("cpu")
63
+
64
+
65
+ def load_state_dict(path, **kwargs):
66
+ """
67
+ Load a PyTorch file without redundant fetches across MPI ranks.
68
+ """
69
+ chunk_size = 2 ** 30 # MPI has a relatively small size limit
70
+ if MPI.COMM_WORLD.Get_rank() == 0:
71
+ with bf.BlobFile(path, "rb") as f:
72
+ data = f.read()
73
+ num_chunks = len(data) // chunk_size
74
+ if len(data) % chunk_size:
75
+ num_chunks += 1
76
+ MPI.COMM_WORLD.bcast(num_chunks)
77
+ for i in range(0, len(data), chunk_size):
78
+ MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
79
+ else:
80
+ num_chunks = MPI.COMM_WORLD.bcast(None)
81
+ data = bytes()
82
+ for _ in range(num_chunks):
83
+ data += MPI.COMM_WORLD.bcast(None)
84
+
85
+ return th.load(io.BytesIO(data), **kwargs)
86
+
87
+
88
+ def sync_params(params):
89
+ """
90
+ Synchronize a sequence of Tensors across ranks from rank 0.
91
+ """
92
+ for p in params:
93
+ with th.no_grad():
94
+ dist.broadcast(p, 0)
95
+
96
+
97
+ def _find_free_port():
98
+ try:
99
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
100
+ s.bind(("", 0))
101
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
102
+ return s.getsockname()[1]
103
+ finally:
104
+ s.close()
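
`setup_dist` uses MPI only for the rendezvous (rank, world size, master address/port) and then hands control to `torch.distributed`. A minimal single-node sketch, assuming the scripts are launched with `mpiexec`; the script name is a placeholder:

```python
# Launched as e.g.:  mpiexec -n 2 python some_train_script.py   (placeholder script name)
import torch as th
from guided_diffusion import dist_util

dist_util.setup_dist()            # gloo on CPU, nccl on GPU; rank/world size come from MPI
device = dist_util.dev()          # cuda if available, else cpu
x = th.zeros(3, device=device)
dist_util.sync_params([x])        # broadcast tensors from rank 0 to all ranks
# state = dist_util.load_state_dict("path/to/ckpt.pt", map_location="cpu")  # rank 0 reads, others get it via bcast
```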
guided_diffusion/dit.py ADDED
@@ -0,0 +1,983 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from rotary_embedding_torch import RotaryEmbedding
16
+ from torch.jit import Final
17
+ import numpy as np
18
+ import math
19
+ from timm.models.vision_transformer import Attention, Mlp
20
+ from timm.models.vision_transformer_relpos import RelPosAttention
21
+ from timm.layers import Format, nchw_to, to_2tuple, _assert, RelPosBias, use_fused_attn
22
+ from typing import Callable, List, Optional, Tuple, Union
23
+ from functools import partial
24
+
25
+ def modulate(x, shift, scale):
26
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
27
+
28
+
29
+ #################################################################################
30
+ # Embedding Layers for Timesteps and Class Labels #
31
+ #################################################################################
32
+
33
+ class TimestepEmbedder(nn.Module):
34
+ """
35
+ Embeds scalar timesteps into vector representations.
36
+ """
37
+ def __init__(self, hidden_size, frequency_embedding_size=256):
38
+ super().__init__()
39
+ self.mlp = nn.Sequential(
40
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
41
+ nn.SiLU(),
42
+ nn.Linear(hidden_size, hidden_size, bias=True),
43
+ )
44
+ self.frequency_embedding_size = frequency_embedding_size
45
+
46
+ @staticmethod
47
+ def timestep_embedding(t, dim, max_period=10000):
48
+ """
49
+ Create sinusoidal timestep embeddings.
50
+ :param t: a 1-D Tensor of N indices, one per batch element.
51
+ These may be fractional.
52
+ :param dim: the dimension of the output.
53
+ :param max_period: controls the minimum frequency of the embeddings.
54
+ :return: an (N, D) Tensor of positional embeddings.
55
+ """
56
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
57
+ half = dim // 2
58
+ freqs = torch.exp(
59
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
60
+ ).to(device=t.device)
61
+ args = t[:, None].float() * freqs[None]
62
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
63
+ if dim % 2:
64
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
65
+ return embedding
66
+
67
+ def forward(self, t):
68
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
69
+ t_emb = self.mlp(t_freq)
70
+ return t_emb
71
+
72
+
73
+ class LabelEmbedder(nn.Module):
74
+ """
75
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
76
+ """
77
+ def __init__(self, num_classes, hidden_size, dropout_prob):
78
+ super().__init__()
79
+ use_cfg_embedding = dropout_prob > 0
80
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
81
+ self.num_classes = num_classes
82
+ self.dropout_prob = dropout_prob
83
+
84
+ def token_drop(self, labels, force_drop_ids=None):
85
+ """
86
+ Drops labels to enable classifier-free guidance.
87
+ """
88
+ if force_drop_ids is None:
89
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
90
+ else:
91
+ drop_ids = force_drop_ids == 1
92
+ labels = torch.where(drop_ids, self.num_classes, labels)
93
+ return labels
94
+
95
+ def forward(self, labels, train, force_drop_ids=None):
96
+ use_dropout = self.dropout_prob > 0
97
+ if (train and use_dropout) or (force_drop_ids is not None):
98
+ labels = self.token_drop(labels, force_drop_ids)
99
+ embeddings = self.embedding_table(labels)
100
+ return embeddings
101
+
102
+
103
+ #################################################################################
104
+ # Embedding Layers for Patches that Support H != W #
105
+ #################################################################################
106
+
107
+ class PatchEmbed(nn.Module):
108
+ """ 2D Image to Patch Embedding
109
+ """
110
+ output_fmt: Format
111
+
112
+ def __init__(
113
+ self,
114
+ img_size: Optional[Union[int, tuple, list]] = 224,
115
+ patch_size: Union[int, tuple, list] = 16,
116
+ in_chans: int = 3,
117
+ embed_dim: int = 768,
118
+ norm_layer: Optional[Callable] = None,
119
+ flatten: bool = True,
120
+ output_fmt: Optional[str] = None,
121
+ bias: bool = True,
122
+ strict_img_size: bool = True,
123
+ ):
124
+ super().__init__()
125
+ self.patch_size = to_2tuple(patch_size)
126
+ if img_size is not None:
127
+ if isinstance(img_size, int):
128
+ self.img_size = to_2tuple(img_size)
129
+ elif len(img_size) == 1:
130
+ self.img_size = to_2tuple(img_size[0])
131
+ else:
132
+ self.img_size = img_size
133
+ self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
134
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
135
+ else:
136
+ self.img_size = None
137
+ self.grid_size = None
138
+ self.num_patches = None
139
+
140
+ if output_fmt is not None:
141
+ self.flatten = False
142
+ self.output_fmt = Format(output_fmt)
143
+ else:
144
+ # flatten spatial dim and transpose to channels last, kept for bwd compat
145
+ self.flatten = flatten
146
+ self.output_fmt = Format.NCHW
147
+ self.strict_img_size = strict_img_size
148
+
149
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
150
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
151
+
152
+ def forward(self, x):
153
+ B, C, H, W = x.shape
154
+ if self.img_size is not None:
155
+ if self.strict_img_size:
156
+ _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
157
+ _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
158
+ else:
159
+ _assert(
160
+ H % self.patch_size[0] == 0,
161
+ f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
162
+ )
163
+ _assert(
164
+ W % self.patch_size[1] == 0,
165
+ f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
166
+ )
167
+
168
+ x = self.proj(x)
169
+ if self.flatten:
170
+ x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
171
+ elif self.output_fmt != Format.NCHW:
172
+ x = nchw_to(x, self.output_fmt)
173
+ x = self.norm(x)
174
+ return x
175
+
176
+
177
+ class FlattenNorm(nn.Module):
178
+ """ Flatten 2D Image to a vector
179
+ """
180
+
181
+ def __init__(
182
+ self,
183
+ img_size: Optional[Union[int, tuple, list]] = 224,
184
+ embed_dim: int = 768,
185
+ norm_layer: Optional[Callable] = None,
186
+ ):
187
+ super().__init__()
188
+ self.num_patches = max(img_size)
189
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
190
+ # todo: hard code 64 and hidden_dim for now
191
+ self.MLP = nn.Sequential(nn.Linear(64, 256), nn.SiLU(), nn.Linear(256, embed_dim))
192
+
193
+ def forward(self, x):
194
+ x = x.permute(0, 2, 1, 3).flatten(2) # B x 4 x 128 x 16 -> B x 128 x 4 x 16 -> B x 128 x 64
195
+ x = self.MLP(x) # B x 128 x 768
196
+ x = self.norm(x)
197
+ return x
198
+
199
+
200
+ class FlattenPatchify1D(nn.Module):
201
+ """ Flatten 2D Image to a vector with pitch per token
202
+ """
203
+
204
+ def __init__(
205
+ self,
206
+ in_channels: int = 4,
207
+ img_size: Optional[Union[int, tuple, list]] = 224,
208
+ embed_dim: int = 768,
209
+ patch_size: int = 8,
210
+ norm_layer: Optional[Callable] = None,
211
+ ):
212
+ super().__init__()
213
+ # dummy, is not needed by the rotary model, but needed for REL and DiT
214
+ self.num_patches = img_size[0] * img_size[1] // patch_size # img_size: 128x16
215
+ self.patch_size = patch_size
216
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
217
+ self.MLP = nn.Sequential(nn.Linear(in_channels * patch_size, 256), nn.SiLU(), nn.Linear(256, embed_dim))
218
+
219
+ def forward(self, x):
220
+ x = x.permute(0, 2, 3, 1) # B x c x 128 x 16 -> B x 128 x 16 x c
221
+ b, n_time, n_pitch, c = x.shape
222
+ num_patches = n_time * n_pitch // self.patch_size
223
+ # B x 128 x 16 x 4 -> B x (128 x 16 / 8) x (4 * 8)
224
+ x = x.reshape(b, num_patches, -1)
225
+ x = self.MLP(x) # B x 256 x 768
226
+ x = self.norm(x)
227
+ return x
228
+
229
+
230
+ #################################################################################
231
+ # Core DiT Model #
232
+ #################################################################################
233
+
234
+ class RotaryAttention(nn.Module):
235
+ fused_attn: Final[bool]
236
+
237
+ def __init__(
238
+ self,
239
+ dim,
240
+ num_heads=8,
241
+ qkv_bias=False,
242
+ qk_norm=False,
243
+ attn_drop=0.,
244
+ proj_drop=0.,
245
+ norm_layer=nn.LayerNorm,
246
+ rotary_emb=None,
247
+ ):
248
+ super().__init__()
249
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
250
+ self.num_heads = num_heads
251
+ self.head_dim = dim // num_heads
252
+ self.scale = self.head_dim ** -0.5
253
+ self.fused_attn = use_fused_attn()
254
+ self.rotary_emb = rotary_emb
255
+
256
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
257
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
258
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
259
+ self.attn_drop = nn.Dropout(attn_drop)
260
+ self.proj = nn.Linear(dim, dim)
261
+ self.proj_drop = nn.Dropout(proj_drop)
262
+
263
+ def forward(self, x):
264
+ B, N, C = x.shape
265
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
266
+ q, k, v = qkv.unbind(0)
267
+ q, k = self.q_norm(q), self.k_norm(k)
268
+
269
+ if self.rotary_emb is not None:
270
+ q = self.rotary_emb.rotate_queries_or_keys(q)
271
+ k = self.rotary_emb.rotate_queries_or_keys(k)
272
+
273
+ if self.fused_attn:
274
+ x = F.scaled_dot_product_attention(
275
+ q, k, v,
276
+ dropout_p=self.attn_drop.p,
277
+ )
278
+ else:
279
+ q = q * self.scale
280
+ attn = q @ k.transpose(-2, -1)
281
+ attn = attn.softmax(dim=-1)
282
+ attn = self.attn_drop(attn)
283
+ x = attn @ v
284
+
285
+ x = x.transpose(1, 2).reshape(B, N, C)
286
+ x = self.proj(x)
287
+ x = self.proj_drop(x)
288
+ return x
289
+
290
+
291
+ class DiTBlock(nn.Module):
292
+ """
293
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
294
+ """
295
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
296
+ super().__init__()
297
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
298
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
299
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
300
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
301
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
302
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
303
+ self.adaLN_modulation = nn.Sequential(
304
+ nn.SiLU(),
305
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
306
+ )
307
+
308
+ def forward(self, x, c):
309
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
310
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
311
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
312
+ return x
313
+
314
+
315
+ class DiTBlockRotary(nn.Module):
316
+ """
317
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning & rotary attention.
318
+ """
319
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, rotary_emb=None, **block_kwargs):
320
+ super().__init__()
321
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
322
+ self.attn = RotaryAttention(hidden_size, num_heads=num_heads, qkv_bias=True, rotary_emb=rotary_emb, **block_kwargs)
323
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
324
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
325
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
326
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
327
+ self.adaLN_modulation = nn.Sequential(
328
+ nn.SiLU(),
329
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
330
+ )
331
+
332
+ def forward(self, x, c):
333
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
334
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
335
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
336
+ return x
337
+
338
+
339
+ class FinalLayer(nn.Module):
340
+ """
341
+ The final layer of DiT.
342
+ """
343
+ def __init__(self, hidden_size, patch_size, out_channels):
344
+ super().__init__()
345
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
346
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
347
+ self.adaLN_modulation = nn.Sequential(
348
+ nn.SiLU(),
349
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
350
+ )
351
+
352
+ def forward(self, x, c):
353
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
354
+ x = modulate(self.norm_final(x), shift, scale)
355
+ x = self.linear(x)
356
+ return x
357
+
358
+
359
+ class FinalLayerPatch1D(nn.Module):
360
+ """
361
+ The final layer of DiT with 1d Patchify.
362
+ """
363
+ def __init__(self, hidden_size, out_channels, patch_size_1d=16):
364
+ super().__init__()
365
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
366
+ self.linear = nn.Linear(hidden_size, patch_size_1d*out_channels, bias=True)
367
+ self.adaLN_modulation = nn.Sequential(
368
+ nn.SiLU(),
369
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
370
+ )
371
+
372
+ def forward(self, x, c):
373
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
374
+ x = modulate(self.norm_final(x), shift, scale)
375
+ x = self.linear(x)
376
+ return x
377
+
378
+
379
+ class DiT(nn.Module):
380
+ """
381
+ Diffusion model with a Transformer backbone.
382
+ """
383
+ def __init__(
384
+ self,
385
+ input_size=32,
386
+ patch_size=2,
387
+ in_channels=3,
388
+ hidden_size=1152,
389
+ depth=28,
390
+ num_heads=16,
391
+ mlp_ratio=4.0,
392
+ class_dropout_prob=0.1,
393
+ num_classes=9, # cluster composers into 9 groups
394
+ learn_sigma=True,
395
+ patchify=True,
396
+ ):
397
+ super().__init__()
398
+ self.learn_sigma = learn_sigma
399
+ self.in_channels = in_channels
400
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
401
+ self.patch_size = patch_size
402
+ self.num_heads = num_heads
403
+ self.input_size = input_size
404
+ self.patchify = patchify
405
+
406
+ if patchify:
407
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
408
+ else:
409
+ self.x_embedder = FlattenNorm(input_size, hidden_size)
410
+ self.t_embedder = TimestepEmbedder(hidden_size)
411
+ self.num_classes = num_classes
412
+ if self.num_classes:
413
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
414
+ num_patches = self.x_embedder.num_patches
415
+ # Will use fixed sin-cos embedding:
416
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
417
+
418
+ self.blocks = nn.ModuleList([
419
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
420
+ ])
421
+ if patchify:
422
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
423
+ else:
424
+ self.final_layer = FinalLayerPatch1D(hidden_size, self.out_channels, patch_size)
425
+ self.initialize_weights()
426
+
427
+ def initialize_weights(self):
428
+ # Initialize transformer layers:
429
+ def _basic_init(module):
430
+ if isinstance(module, nn.Linear):
431
+ torch.nn.init.xavier_uniform_(module.weight)
432
+ if module.bias is not None:
433
+ nn.init.constant_(module.bias, 0)
434
+ self.apply(_basic_init)
435
+
436
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
437
+ if self.patchify:
438
+ if isinstance(self.input_size, int) or len(self.input_size) == 1:
439
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5), int(self.x_embedder.num_patches ** 0.5))
440
+ else:
441
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size[0], self.x_embedder.grid_size[1])
442
+ else:
443
+ # 1D position encoding
444
+ pos_embed = get_1d_sincos_pos_embed_from_grid(self.pos_embed.shape[-1],
445
+ np.arange(self.x_embedder.num_patches, dtype=np.float32))
446
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
447
+
448
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
449
+ if self.patchify:
450
+ w = self.x_embedder.proj.weight.data
451
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
452
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
453
+
454
+ # Initialize label embedding table:
455
+ if self.num_classes:
456
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
457
+
458
+ # Initialize timestep embedding MLP:
459
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
460
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
461
+
462
+ # Zero-out adaLN modulation layers in DiT blocks:
463
+ for block in self.blocks:
464
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
465
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
466
+
467
+ # Zero-out output layers:
468
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
469
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
470
+ nn.init.constant_(self.final_layer.linear.weight, 0)
471
+ nn.init.constant_(self.final_layer.linear.bias, 0)
472
+
473
+ def unpatchify(self, x):
474
+ """
475
+ x: (N, T, patch_size**2 * C)
476
+ imgs: (N, H, W, C)
477
+ """
478
+ c = self.out_channels
479
+ p = self.x_embedder.patch_size[0]
480
+ if isinstance(self.input_size, int) or len(self.input_size) == 1:
481
+ h = w = int(x.shape[1] ** 0.5)
482
+ assert h * w == x.shape[1]
483
+ else:
484
+ h = self.input_size[0] // self.patch_size
485
+ w = self.input_size[1] // self.patch_size
486
+
487
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
488
+ x = torch.einsum('nhwpqc->nchpwq', x)
489
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
490
+ return imgs
491
+
492
+ def unflatten(self, x):
493
+ c = self.out_channels
494
+ x = x.reshape(shape=(x.shape[0], x.shape[1], c, -1))
495
+ imgs = x.permute(0, 2, 1, 3)
496
+ return imgs
497
+
498
+ def forward(self, x, t, y=None):
499
+ """
500
+ Forward pass of DiT.
501
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
502
+ t: (N,) tensor of diffusion timesteps
503
+ y: (N,) tensor of class labels
504
+ """
505
+ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
506
+ c = self.t_embedder(t) # (N, D)
507
+ if self.num_classes and y is not None:
508
+ y = self.y_embedder(y, self.training) # (N, D)
509
+ c = c + y # (N, D)
510
+ for block in self.blocks:
511
+ x = block(x, c) # (N, T, D)
512
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
513
+ if self.patchify:
514
+ x = self.unpatchify(x) # (N, out_channels, H, W)
515
+ else:
516
+ x = self.unflatten(x)
517
+ return x
518
+
519
+ def forward_with_cfg(self, x, t, y, cfg_scale):
520
+ """
521
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
522
+ """
523
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
524
+ half = x[: len(x) // 2]
525
+ combined = torch.cat([half, half], dim=0)
526
+ model_out = self.forward(combined, t, y)
527
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
528
+ # three channels by default. The standard approach to cfg applies it to all channels.
529
+ # This can be done by uncommenting the following line and commenting-out the line following that.
530
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
531
+ eps, rest = model_out[:, :3], model_out[:, 3:]
532
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
533
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
534
+ eps = torch.cat([half_eps, half_eps], dim=0)
535
+ return torch.cat([eps, rest], dim=1)
536
+
537
+
538
+ class DiTRotary(nn.Module):
539
+ """
540
+ Diffusion model with a Transformer backbone, with rotary position embedding.
541
+ Use 1D position encoding, patchify is set to False
542
+ """
543
+
544
+ def __init__(
545
+ self,
546
+ input_size=32,
547
+ patch_size=8, # patch size for 1D patchify
548
+ in_channels=3,
549
+ hidden_size=1152,
550
+ depth=28,
551
+ num_heads=16,
552
+ mlp_ratio=4.0,
553
+ class_dropout_prob=0.1,
554
+ num_classes=9, # cluster composers into 9 groups
555
+ learn_sigma=True,
556
+ ):
557
+ super().__init__()
558
+ self.learn_sigma = learn_sigma
559
+ self.in_channels = in_channels
560
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
561
+ self.patch_size = patch_size
562
+ self.num_heads = num_heads
563
+ self.input_size = input_size
564
+
565
+ self.x_embedder = FlattenPatchify1D(in_channels, input_size, hidden_size, patch_size)
566
+ self.t_embedder = TimestepEmbedder(hidden_size)
567
+ self.num_classes = num_classes
568
+ if self.num_classes:
569
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
570
+
571
+ rotary_dim = int(hidden_size // num_heads * 0.5) # 0.5 is rotary percentage in multihead rope, by default 0.5
572
+ self.rotary_emb = RotaryEmbedding(rotary_dim)
573
+ self.blocks = nn.ModuleList([
574
+ DiTBlockRotary(hidden_size, num_heads, mlp_ratio=mlp_ratio, rotary_emb=self.rotary_emb) for _ in range(depth)
575
+ ])
576
+ self.final_layer = FinalLayerPatch1D(hidden_size, self.out_channels, patch_size_1d=self.patch_size)
577
+ self.initialize_weights()
578
+
579
+ def initialize_weights(self):
580
+ # Initialize transformer layers:
581
+ def _basic_init(module):
582
+ if isinstance(module, nn.Linear):
583
+ torch.nn.init.xavier_uniform_(module.weight)
584
+ if module.bias is not None:
585
+ nn.init.constant_(module.bias, 0)
586
+
587
+ self.apply(_basic_init)
588
+
589
+ # Initialize label embedding table:
590
+ if self.num_classes:
591
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
592
+
593
+ # Initialize timestep embedding MLP:
594
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
595
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
596
+
597
+ # Zero-out adaLN modulation layers in DiT blocks:
598
+ for block in self.blocks:
599
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
600
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
601
+
602
+ # Zero-out output layers:
603
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
604
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
605
+ nn.init.constant_(self.final_layer.linear.weight, 0)
606
+ nn.init.constant_(self.final_layer.linear.bias, 0)
607
+
608
+ def unpatchify(self, x):
609
+ """
610
+ x: (N, T, img_size[1] / patch_size * C)
611
+ imgs: (N, H, W, C)
612
+ """
613
+ # input_size[1] is the pitch dimension, should always be the same
614
+ x = x.reshape(shape=(x.shape[0], -1, self.input_size[1], self.out_channels))
615
+ imgs = x.permute(0, 3, 1, 2)
616
+ return imgs
617
+
618
+ def forward(self, x, t, y=None):
619
+ """
620
+ Forward pass of DiT.
621
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
622
+ t: (N,) tensor of diffusion timesteps
623
+ y: (N,) tensor of class labels
624
+ """
625
+ x = self.x_embedder(x) # (N, T, D), where T = H * W / patch_size
626
+ c = self.t_embedder(t) # (N, D)
627
+ if self.num_classes and y is not None:
628
+ y = self.y_embedder(y, self.training) # (N, D)
629
+ c = c + y # (N, D)
630
+ for block in self.blocks:
631
+ x = block(x, c) # (N, T, D)
632
+ x = self.final_layer(x, c) # (N, T, patch_size * out_channels)
633
+ x = self.unpatchify(x)
634
+ return x
635
+
636
+
637
+ class DiT_classifier(nn.Module):
638
+ """
639
+ Classifier used in classifier guidance.
640
+ """
641
+ def __init__(
642
+ self,
643
+ input_size=32,
644
+ patch_size=2,
645
+ in_channels=3,
646
+ hidden_size=1152,
647
+ depth=28,
648
+ num_heads=16,
649
+ mlp_ratio=4.0,
650
+ num_classes=9,
651
+ patchify=True,
652
+ ):
653
+ super().__init__()
654
+ self.in_channels = in_channels
655
+ self.patch_size = patch_size
656
+ self.num_heads = num_heads
657
+ self.input_size = input_size
658
+ self.patchify = patchify
659
+
660
+ if patchify:
661
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
662
+ else:
663
+ self.x_embedder = FlattenNorm(input_size, hidden_size)
664
+ self.t_embedder = TimestepEmbedder(hidden_size)
665
+ self.num_classes = num_classes
666
+ num_patches = self.x_embedder.num_patches
667
+ # Will use fixed sin-cos embedding:
668
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
669
+
670
+ self.blocks = nn.ModuleList([
671
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
672
+ ])
673
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size), requires_grad=True)
674
+ self.norm = nn.LayerNorm(hidden_size)
675
+ self.classifier_head = nn.Sequential(nn.Linear(hidden_size, hidden_size//4),
676
+ nn.SiLU(), nn.Linear(hidden_size//4, self.num_classes))
677
+ self.initialize_weights()
678
+
679
+ def initialize_weights(self):
680
+ # Initialize transformer layers:
681
+ def _basic_init(module):
682
+ if isinstance(module, nn.Linear):
683
+ torch.nn.init.xavier_uniform_(module.weight)
684
+ if module.bias is not None:
685
+ nn.init.constant_(module.bias, 0)
686
+ self.apply(_basic_init)
687
+
688
+ if self.patchify:
689
+ if isinstance(self.input_size, int) or len(self.input_size) == 1:
690
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5), int(self.x_embedder.num_patches ** 0.5))
691
+ else:
692
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size[0], self.x_embedder.grid_size[1])
693
+ else:
694
+ # 1D position encoding
695
+ pos_embed = get_1d_sincos_pos_embed_from_grid(self.pos_embed.shape[-1],
696
+ np.arange(self.x_embedder.num_patches, dtype=np.float32))
697
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
698
+
699
+ # Initialize class token
700
+ nn.init.normal_(self.cls_token, std=1e-6)
701
+
702
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
703
+ if self.patchify:
704
+ w = self.x_embedder.proj.weight.data
705
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
706
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
707
+
708
+ # Initialize timestep embedding MLP:
709
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
710
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
711
+
712
+ # Zero-out adaLN modulation layers in DiT blocks:
713
+ for block in self.blocks:
714
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
715
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
716
+
717
+ def forward(self, x, t):
718
+ """
719
+ Forward pass of the DiT classifier.
720
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
721
+ t: (N,) tensor of diffusion timesteps
722
723
+ """
724
+ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
725
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
726
+ c = self.t_embedder(t) # (N, D)
727
+ for block in self.blocks:
728
+ x = block(x, c) # (N, T, D)
729
+ x = x[:, 0, :] # (N, D)
730
+ x = self.norm(x)
731
+ x = self.classifier_head(x) # (N, num_classes)
732
+ return x
733
+
734
+
735
+ class DiTRotaryClassifier(nn.Module):
736
+ """
737
+ Classifier with a Transformer backbone and rotary position embedding, used for classifier guidance.
738
+ Uses 1D position encoding via FlattenPatchify1D (no 2D patchify).
739
+ """
740
+
741
+ def __init__(
742
+ self,
743
+ input_size=32,
744
+ patch_size=8, # patch size for 1D patchify
745
+ in_channels=3,
746
+ hidden_size=1152,
747
+ depth=28,
748
+ num_heads=16,
749
+ mlp_ratio=4.0,
750
+ num_classes=9, # cluster composers into 9 groups
751
+ chord=False,
752
+ ):
753
+ super().__init__()
754
+ self.in_channels = in_channels
755
+ self.patch_size = patch_size
756
+ self.num_heads = num_heads
757
+ self.input_size = input_size
758
+ self.chord = chord
759
+ self.hidden_size = hidden_size
760
+
761
+ self.x_embedder = FlattenPatchify1D(in_channels, input_size, hidden_size, patch_size)
762
+ self.t_embedder = TimestepEmbedder(hidden_size)
763
+ self.num_classes = num_classes
764
+
765
+ rotary_dim = int(hidden_size // num_heads * 0.5) # 0.5 is rotary percentage in multihead rope, by default 0.5
766
+ self.rotary_emb = RotaryEmbedding(rotary_dim)
767
+ self.blocks = nn.ModuleList([
768
+ DiTBlockRotary(hidden_size, num_heads, mlp_ratio=mlp_ratio, rotary_emb=self.rotary_emb) for _ in range(depth)
769
+ ])
770
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size), requires_grad=True)
771
+ self.norm = nn.LayerNorm(hidden_size)
772
+ self.classifier_head = nn.Sequential(nn.Linear(hidden_size, hidden_size//4),
773
+ nn.SiLU(), nn.Linear(hidden_size//4, self.num_classes))
774
+ if self.chord:
775
+ self.norm_key = nn.LayerNorm(hidden_size)
776
+ # predict key also: 24 major and minor keys + null
777
+ self.classifier_head_key = nn.Sequential(nn.Linear(hidden_size, hidden_size//4),
778
+ nn.SiLU(), nn.Linear(hidden_size//4, 25))
779
+ self.initialize_weights()
780
+
781
+ def initialize_weights(self):
782
+ # Initialize transformer layers:
783
+ def _basic_init(module):
784
+ if isinstance(module, nn.Linear):
785
+ torch.nn.init.xavier_uniform_(module.weight)
786
+ if module.bias is not None:
787
+ nn.init.constant_(module.bias, 0)
788
+
789
+ self.apply(_basic_init)
790
+
791
+ # Initialize class token
792
+ nn.init.normal_(self.cls_token, std=1e-6)
793
+
794
+ # Initialize timestep embedding MLP:
795
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
796
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
797
+
798
+ # Zero-out adaLN modulation layers in DiT blocks:
799
+ for block in self.blocks:
800
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
801
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
802
+
803
+ def forward(self, x, t, y=None):
804
+ """
805
+ Forward pass of the rotary DiT classifier.
806
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
807
+ t: (N,) tensor of diffusion timesteps
808
+ y: (N,) tensor of class labels
809
+ """
810
+ if self.chord:
811
+ n_token = x.shape[2] // x.shape[3]
812
+ x = self.x_embedder(x) # (N, T, D), where T = H * W / patch_size
813
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
814
+ c = self.t_embedder(t) # (N, D)
815
+ for block in self.blocks:
816
+ x = block(x, c) # (N, T, D)
817
+ if self.chord:
818
+ x_key = x[:, 0, :]
819
+ x_key = self.norm_key(x_key)
820
+ key = self.classifier_head_key(x_key)
821
+ x_chord = x[:, 1:, :]
822
+ x_chord = x_chord.reshape(shape=[x.shape[0], n_token, -1, self.hidden_size])
823
+ x_chord = x_chord.mean(dim=-2)
824
+ x_chord = self.norm(x_chord)
825
+ chord = self.classifier_head(x_chord)
826
+ return key, chord
827
+ else:
828
+ x = x[:, 0, :] # (N, D)
829
+ x = self.norm(x)
830
+ x = self.classifier_head(x) # (N, num_classes)
831
+ return x
832
+
833
+
834
+ #################################################################################
835
+ # Sine/Cosine Positional Embedding Functions #
836
+ #################################################################################
837
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
838
+
839
+ def get_2d_sincos_pos_embed(embed_dim, grid_size_h, grid_size_w, cls_token=False, extra_tokens=0):
840
+ """
841
+ grid_size_h, grid_size_w: ints giving the grid height and width
842
+ return:
843
+ pos_embed: [grid_size_h*grid_size_w, embed_dim] or [extra_tokens+grid_size_h*grid_size_w, embed_dim] (w/ or w/o extra tokens)
844
+ """
845
+ grid_h = np.arange(grid_size_h, dtype=np.float32)
846
+ grid_w = np.arange(grid_size_w, dtype=np.float32)
847
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
848
+ grid = np.stack(grid, axis=0)
849
+
850
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
851
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
852
+ if cls_token and extra_tokens > 0:
853
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
854
+ return pos_embed
855
+
856
+
857
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
858
+ assert embed_dim % 2 == 0
859
+
860
+ # use half of dimensions to encode grid_h
861
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
862
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
863
+
864
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
865
+ return emb
866
+
867
+
868
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
869
+ """
870
+ embed_dim: output dimension for each position
871
+ pos: a list of positions to be encoded: size (M,)
872
+ out: (M, D)
873
+ """
874
+ assert embed_dim % 2 == 0
875
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
876
+ omega /= embed_dim / 2.
877
+ omega = 1. / 10000**omega # (D/2,)
878
+
879
+ pos = pos.reshape(-1) # (M,)
880
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
881
+
882
+ emb_sin = np.sin(out) # (M, D/2)
883
+ emb_cos = np.cos(out) # (M, D/2)
884
+
885
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
886
+ return emb
887
+
888
+
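For reference, a minimal sketch (not part of the commit) of what the 1D helper above produces, assuming this file is importable as guided_diffusion.dit:

import numpy as np
from guided_diffusion.dit import get_1d_sincos_pos_embed_from_grid

pos = np.arange(4, dtype=np.float32)              # four positions
emb = get_1d_sincos_pos_embed_from_grid(8, pos)   # embed_dim must be even
assert emb.shape == (4, 8)                        # (M, D): sin half, then cos half
assert np.allclose(emb[0, :4], 0.0)               # position 0 -> all sines are 0
assert np.allclose(emb[0, 4:], 1.0)               # position 0 -> all cosines are 1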
889
+ #################################################################################
890
+ # DiT Configs #
891
+ #################################################################################
892
+
893
+ def DiT_XL_2(**kwargs):
894
+ return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
895
+
896
+ def DiT_XL_4(**kwargs):
897
+ return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
898
+
899
+ def DiTRotary_XL_16(**kwargs):
900
+ return DiTRotary(depth=28, hidden_size=1152, patch_size=16, num_heads=16, **kwargs)
901
+
902
+ def DiTRotary_XL_8(**kwargs):
903
+ return DiTRotary(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
904
+
905
+ def DiT_XL_8(**kwargs):
906
+ return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
907
+
908
+ def DiT_L_2(**kwargs):
909
+ return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
910
+
911
+ def DiT_L_4(**kwargs):
912
+ return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
913
+
914
+ def DiT_L_8(**kwargs):
915
+ return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
916
+
917
+ def DiT_B_2(**kwargs):
918
+ return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
919
+
920
+ def DiT_B_4(**kwargs):
921
+ return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
922
+
923
+ def DiTRotary_B_16(**kwargs): # seq_len = 128 = 128 * 16/16
924
+ return DiTRotary(depth=12, hidden_size=768, patch_size=16, num_heads=12, **kwargs)
925
+
926
+ def DiTRotary_B_8(**kwargs): # seq_len = 256 = 128 * 16/8
927
+ return DiTRotary(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
928
+
929
+ def DiT_B_8(**kwargs):
930
+ return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
931
+
932
+ def DiT_B_4_classifier(**kwargs):
933
+ return DiT_classifier(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
934
+
935
+ def DiT_B_8_classifier(**kwargs):
936
+ return DiT_classifier(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
937
+
938
+ def DiTRotary_B_8_classifier(**kwargs):
939
+ return DiTRotaryClassifier(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
940
+
941
+ def DiT_S_2(**kwargs):
942
+ return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
943
+
944
+ def DiT_S_2_classifier(**kwargs):
945
+ return DiT_classifier(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
946
+
947
+ def DiTRotary_S_8_classifier(**kwargs):
948
+ return DiTRotaryClassifier(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
949
+
950
+ def DiTRotary_S_8_chord_classifier(**kwargs):
951
+ return DiTRotaryClassifier(depth=12, hidden_size=384, patch_size=8, num_heads=6, chord=True, **kwargs)
952
+
953
+ def DiT_XS_2_classifier(**kwargs):
954
+ return DiT_classifier(depth=4, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
955
+
956
+ def DiTRotary_XS_8_classifier(**kwargs):
957
+ return DiTRotaryClassifier(depth=4, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
958
+
959
+ def DiT_S_4(**kwargs):
960
+ return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
961
+
962
+ def DiT_S_4_classifier(**kwargs):
963
+ return DiT_classifier(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
964
+
965
+ def DiT_S_8(**kwargs):
966
+ return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
967
+
968
+
969
+ DiT_models = {
970
+ 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
971
+ 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
972
+ 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
973
+ 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
974
+ 'DiTRotary_B_16': DiTRotary_B_16, 'DiTRotary_B_8': DiTRotary_B_8,
975
+ 'DiTRotary_XL_16': DiTRotary_XL_16, 'DiTRotary_XL_8': DiTRotary_XL_8,
976
+ 'DiT-B/4-cls': DiT_B_4_classifier, 'DiT-B/8-cls': DiT_B_8_classifier,
977
+ 'DiT-S/4-cls': DiT_S_4_classifier, 'DiT-S/2-cls': DiT_S_2_classifier,
978
+ 'DiT-XS/2-cls': DiT_XS_2_classifier,
979
+ 'DiTRotary-XS/8-cls': DiTRotary_XS_8_classifier,
980
+ 'DiTRotary-S/8-cls': DiTRotary_S_8_classifier,
981
+ 'DiTRotary-S/8-chord-cls': DiTRotary_S_8_chord_classifier,
982
+ 'DiTRotary-B/8-cls': DiTRotary_B_8_classifier,
983
+ }
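A hypothetical usage sketch of the registry above (not part of the commit). The keyword arguments are assumptions based on the classifier constructors in this file, so the exact names may differ for the generator classes:

from guided_diffusion.dit import DiT_models

# Look up a factory by name; extra keyword arguments are forwarded to the constructor.
model_fn = DiT_models['DiTRotary_B_8']
model = model_fn(input_size=(128, 16), in_channels=4, num_classes=3)   # illustrative values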
guided_diffusion/embed_datasets.py ADDED
@@ -0,0 +1,161 @@
1
+ import math
2
+ import random
3
+ import os
4
+ import pandas as pd
5
+ import re
6
+ from PIL import Image
7
+ import blobfile as bf
8
+ from mpi4py import MPI
9
+ import numpy as np
10
+ from torch.utils.data import DataLoader, Dataset
11
+
12
+ CLUSTERS = {'Balakirev': 0,
13
+ 'Bartholdy': 0,
14
+ 'Bizet': 0,
15
+ 'Brahms': 0,
16
+ 'Busoni': 0,
17
+ 'Chopin': 0,
18
+ 'Grieg': 0,
19
+ 'Horowitz': 0,
20
+ 'Liszt': 0,
21
+ 'Mendelssohn': 0,
22
+ 'Moszkowski': 0,
23
+ 'Paganini': 0,
24
+ 'Saint-Saens': 0,
25
+ 'Schubert': 0,
26
+ 'Schumann': 0,
27
+ 'Strauss': 0,
28
+ 'Tchaikovsky': 0,
29
+ 'Wagner': 0,
30
+ 'Beethoven': 1,
31
+ 'Bach': 2,
32
+ 'Handel': 2,
33
+ 'Purcell': 2,
34
+ 'Barber': 3,
35
+ 'Bartok': 3,
36
+ 'Hindemith': 3,
37
+ 'Ligeti': 3,
38
+ 'Messiaen': 3,
39
+ 'Mussorgsky': 3,
40
+ 'Myaskovsky': 3,
41
+ 'Prokofiev': 3,
42
+ 'Schnittke': 3,
43
+ 'Schonberg': 3,
44
+ 'Shostakovich': 3,
45
+ 'Stravinsky': 3,
46
+ 'Debussy': 4,
47
+ 'Ravel': 4,
48
+ 'Clementi': 5,
49
+ 'Haydn': 5,
50
+ 'Mozart': 5,
51
+ 'Pachelbel': 5,
52
+ 'Scarlatti': 5,
53
+ 'Rachmaninoff': 6,
54
+ 'Scriabin': 6,
55
+ 'Gershwin': 7,
56
+ 'Kapustin': 7
57
+ }
58
+
59
+
60
+ def extract_string(file_name):
61
+ if 'loc' not in file_name:
62
+ ind = [i.start() for i in re.finditer('_', file_name)][-1]
63
+ name = file_name[:ind]
64
+ else:
65
+ name = file_name.split('loc')[0][:-1]
66
+ return name
67
+
68
+
69
+ def find_composer(name, df):
70
+ compound_composer = df.loc[df['simple_midi_name'] == name]['canonical_composer'].item()
71
+ composer = compound_composer.split(' / ')[0].split(' ')[-1] # take the last name of the first composer
72
+ result = CLUSTERS.setdefault(composer, 8) # default cluster is everyone else (8)
73
+ return result
74
+
75
+
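A small sketch (not part of the commit) of how these two helpers map a chunk filename to a composer cluster. The filename and the single-row DataFrame are made up for illustration; the real lookup table comes from maestro-v3.0.0.csv:

import pandas as pd

fname = 'ORIG_MID--AUDIO_01_Track03_loc_0.npy'
name = extract_string(fname)                     # strips the trailing '_loc_...' suffix
df = pd.DataFrame({'simple_midi_name': [name],
                   'canonical_composer': ['Frédéric Chopin']})
print(find_composer(name, df))                   # -> 0, the cluster containing 'Chopin'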
76
+ def load_data(
77
+ *,
78
+ data_dir,
79
+ batch_size,
80
+ class_cond=False,
81
+ deterministic=False,
82
+ ):
83
+ """
84
+ For a dataset, create a generator over (images, kwargs) pairs.
85
+
86
+ Each image is an NCHW float tensor, and the kwargs dict contains zero or
87
+ more keys, each of which maps to a batched Tensor of its own.
88
+ The kwargs dict can be used for class labels, in which case the key is "y"
89
+ and the values are integer tensors of class labels.
90
+
91
+ :param data_dir: a dataset directory.
92
+ :param batch_size: the batch size of each returned pair.
93
94
+ :param class_cond: if True, include a "y" key in returned dicts for class
95
+ label. If classes are not available and this is true, an
96
+ exception will be raised.
97
+ :param deterministic: if True, yield results in a deterministic order.
98
100
+ """
101
+ if not data_dir:
102
+ raise ValueError("unspecified data directory")
103
+ all_files = _list_image_files(data_dir)
104
+ classes = None
105
+ if class_cond:
106
+ # find the composer
107
+ parent_dir = os.path.join(*data_dir.split('/')[:-1])
108
+ if data_dir[0] == '/':
109
+ parent_dir = '/' + parent_dir
110
+ df = pd.read_csv(os.path.join(parent_dir, 'maestro-v3.0.0.csv'))
111
+ df['simple_midi_name'] = [midi_name[5:-5] for midi_name in df['midi_filename']]
112
+ all_file_names = bf.listdir(data_dir)
113
+ extracted_names = [extract_string(file_name) for file_name in all_file_names]
114
+ classes = [find_composer(name, df) for name in extracted_names]
115
+
116
+ dataset = ImageDataset(
117
+ all_files,
118
+ classes=classes,
119
+ shard=MPI.COMM_WORLD.Get_rank(),
120
+ num_shards=MPI.COMM_WORLD.Get_size(),
121
+ )
122
+ if deterministic:
123
+ loader = DataLoader(
124
+ dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
125
+ )
126
+ else:
127
+ loader = DataLoader(
128
+ dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
129
+ )
130
+ while True:
131
+ yield from loader
132
+
133
+
134
+ def _list_image_files(data_dir):
135
+ dirs = bf.listdir(data_dir)
136
+ return [data_dir + '/' + d for d in dirs]
137
+
138
+
139
+ class ImageDataset(Dataset):
140
+ def __init__(
141
+ self,
142
+ image_paths,
143
+ classes=None,
144
+ shard=0,
145
+ num_shards=1,
146
+ ):
147
+ super().__init__()
148
+ self.local_images = image_paths[shard:][::num_shards]
149
+ self.local_classes = None if classes is None else classes[shard:][::num_shards]
150
+
151
+ def __len__(self):
152
+ return len(self.local_images)
153
+
154
+ def __getitem__(self, idx):
155
+ path = self.local_images[idx]
156
+ arr = np.load(path)
157
+ out_dict = {}
158
+ if self.local_classes is not None:
159
+ out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
160
+ return arr, out_dict
161
+
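Hypothetical usage of the generator above (not part of the commit); the directory is a placeholder, and class_cond=True additionally expects maestro-v3.0.0.csv one level above data_dir:

data = load_data(data_dir='datasets/latent_chunks', batch_size=16, class_cond=False)
batch, cond = next(data)   # batch: stacked arrays from np.load; cond: {} (or {'y': ...} when class_cond=True)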
guided_diffusion/fp16_util.py ADDED
@@ -0,0 +1,237 @@
1
+ """
2
+ Helpers to train with 16-bit precision.
3
+ """
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9
+
10
+ from . import logger
11
+
12
+ INITIAL_LOG_LOSS_SCALE = 20.0
13
+
14
+
15
+ def convert_module_to_f16(l):
16
+ """
17
+ Convert primitive modules to float16.
18
+ """
19
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
20
+ l.weight.data = l.weight.data.half()
21
+ if l.bias is not None:
22
+ l.bias.data = l.bias.data.half()
23
+
24
+
25
+ def convert_module_to_f32(l):
26
+ """
27
+ Convert primitive modules to float32, undoing convert_module_to_f16().
28
+ """
29
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
30
+ l.weight.data = l.weight.data.float()
31
+ if l.bias is not None:
32
+ l.bias.data = l.bias.data.float()
33
+
34
+
35
+ def make_master_params(param_groups_and_shapes):
36
+ """
37
+ Copy model parameters into a (differently-shaped) list of full-precision
38
+ parameters.
39
+ """
40
+ master_params = []
41
+ for param_group, shape in param_groups_and_shapes:
42
+ master_param = nn.Parameter(
43
+ _flatten_dense_tensors(
44
+ [param.detach().float() for (_, param) in param_group]
45
+ ).view(shape)
46
+ )
47
+ master_param.requires_grad = True
48
+ master_params.append(master_param)
49
+ return master_params
50
+
51
+
52
+ def model_grads_to_master_grads(param_groups_and_shapes, master_params):
53
+ """
54
+ Copy the gradients from the model parameters into the master parameters
55
+ from make_master_params().
56
+ """
57
+ for master_param, (param_group, shape) in zip(
58
+ master_params, param_groups_and_shapes
59
+ ):
60
+ master_param.grad = _flatten_dense_tensors(
61
+ [param_grad_or_zeros(param) for (_, param) in param_group]
62
+ ).view(shape)
63
+
64
+
65
+ def master_params_to_model_params(param_groups_and_shapes, master_params):
66
+ """
67
+ Copy the master parameter data back into the model parameters.
68
+ """
69
+ # Without copying to a list, if a generator is passed, this will
70
+ # silently not copy any parameters.
71
+ for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
72
+ for (_, param), unflat_master_param in zip(
73
+ param_group, unflatten_master_params(param_group, master_param.view(-1))
74
+ ):
75
+ param.detach().copy_(unflat_master_param)
76
+
77
+
78
+ def unflatten_master_params(param_group, master_param):
79
+ return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
80
+
81
+
82
+ def get_param_groups_and_shapes(named_model_params):
83
+ named_model_params = list(named_model_params)
84
+ scalar_vector_named_params = (
85
+ [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
86
+ (-1),
87
+ )
88
+ matrix_named_params = (
89
+ [(n, p) for (n, p) in named_model_params if p.ndim > 1],
90
+ (1, -1),
91
+ )
92
+ return [scalar_vector_named_params, matrix_named_params]
93
+
94
+
95
+ def master_params_to_state_dict(
96
+ model, param_groups_and_shapes, master_params, use_fp16
97
+ ):
98
+ if use_fp16:
99
+ state_dict = model.state_dict()
100
+ for master_param, (param_group, _) in zip(
101
+ master_params, param_groups_and_shapes
102
+ ):
103
+ for (name, _), unflat_master_param in zip(
104
+ param_group, unflatten_master_params(param_group, master_param.view(-1))
105
+ ):
106
+ assert name in state_dict
107
+ state_dict[name] = unflat_master_param
108
+ else:
109
+ state_dict = model.state_dict()
110
+ for i, (name, _value) in enumerate(model.named_parameters()):
111
+ assert name in state_dict
112
+ state_dict[name] = master_params[i]
113
+ return state_dict
114
+
115
+
116
+ def state_dict_to_master_params(model, state_dict, use_fp16):
117
+ if use_fp16:
118
+ named_model_params = [
119
+ (name, state_dict[name]) for name, _ in model.named_parameters()
120
+ ]
121
+ param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
122
+ master_params = make_master_params(param_groups_and_shapes)
123
+ else:
124
+ master_params = [state_dict[name] for name, _ in model.named_parameters()]
125
+ return master_params
126
+
127
+
128
+ def zero_master_grads(master_params):
129
+ for param in master_params:
130
+ param.grad = None
131
+
132
+
133
+ def zero_grad(model_params):
134
+ for param in model_params:
135
+ # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
136
+ if param.grad is not None:
137
+ param.grad.detach_()
138
+ param.grad.zero_()
139
+
140
+
141
+ def param_grad_or_zeros(param):
142
+ if param.grad is not None:
143
+ return param.grad.data.detach()
144
+ else:
145
+ return th.zeros_like(param)
146
+
147
+
148
+ class MixedPrecisionTrainer:
149
+ def __init__(
150
+ self,
151
+ *,
152
+ model,
153
+ use_fp16=False,
154
+ fp16_scale_growth=1e-3,
155
+ initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
156
+ ):
157
+ self.model = model
158
+ self.use_fp16 = use_fp16
159
+ self.fp16_scale_growth = fp16_scale_growth
160
+
161
+ self.model_params = list(self.model.parameters())
162
+ self.master_params = self.model_params
163
+ self.param_groups_and_shapes = None
164
+ self.lg_loss_scale = initial_lg_loss_scale
165
+
166
+ if self.use_fp16:
167
+ self.param_groups_and_shapes = get_param_groups_and_shapes(
168
+ self.model.named_parameters()
169
+ )
170
+ self.master_params = make_master_params(self.param_groups_and_shapes)
171
+ self.model.convert_to_fp16()
172
+
173
+ def zero_grad(self):
174
+ zero_grad(self.model_params)
175
+
176
+ def backward(self, loss: th.Tensor):
177
+ if self.use_fp16:
178
+ loss_scale = 2 ** self.lg_loss_scale
179
+ (loss * loss_scale).backward()
180
+ else:
181
+ loss.backward()
182
+
183
+ def optimize(self, opt: th.optim.Optimizer):
184
+ if self.use_fp16:
185
+ return self._optimize_fp16(opt)
186
+ else:
187
+ return self._optimize_normal(opt)
188
+
189
+ def _optimize_fp16(self, opt: th.optim.Optimizer):
190
+ logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
191
+ model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
192
+ grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
193
+ if check_overflow(grad_norm):
194
+ self.lg_loss_scale -= 1
195
+ logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
196
+ zero_master_grads(self.master_params)
197
+ return False
198
+
199
+ logger.logkv_mean("grad_norm", grad_norm)
200
+ logger.logkv_mean("param_norm", param_norm)
201
+
202
+ for p in self.master_params:
203
+ p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
204
+ opt.step()
205
+ zero_master_grads(self.master_params)
206
+ master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
207
+ self.lg_loss_scale += self.fp16_scale_growth
208
+ return True
209
+
210
+ def _optimize_normal(self, opt: th.optim.Optimizer):
211
+ grad_norm, param_norm = self._compute_norms()
212
+ logger.logkv_mean("grad_norm", grad_norm)
213
+ logger.logkv_mean("param_norm", param_norm)
214
+ opt.step()
215
+ return True
216
+
217
+ def _compute_norms(self, grad_scale=1.0):
218
+ grad_norm = 0.0
219
+ param_norm = 0.0
220
+ for p in self.master_params:
221
+ with th.no_grad():
222
+ param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
223
+ if p.grad is not None:
224
+ grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
225
+ return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
226
+
227
+ def master_params_to_state_dict(self, master_params):
228
+ return master_params_to_state_dict(
229
+ self.model, self.param_groups_and_shapes, master_params, self.use_fp16
230
+ )
231
+
232
+ def state_dict_to_master_params(self, state_dict):
233
+ return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
234
+
235
+
236
+ def check_overflow(value):
237
+ return (value == float("inf")) or (value == -float("inf")) or (value != value)
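A minimal sketch (not part of the commit) of one optimization step with the helper above, using a toy model and loss; with use_fp16=True the model would additionally need a convert_to_fp16() method:

import torch as th

toy_model = th.nn.Linear(8, 8)
trainer = MixedPrecisionTrainer(model=toy_model, use_fp16=False)
opt = th.optim.AdamW(trainer.master_params, lr=1e-4)

batch = th.randn(4, 8)
loss = ((toy_model(batch) - batch) ** 2).mean()
trainer.zero_grad()
trainer.backward(loss)
took_step = trainer.optimize(opt)   # returns False only when an fp16 overflow is detected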
guided_diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,1400 @@
1
+ """
2
+ This code started out as a PyTorch port of Ho et al's diffusion models:
3
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
4
+
5
+ Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
6
+ """
7
+
8
+ import enum
9
+ import os
10
+
11
+ import math
12
+
13
+ import numpy as np
14
+ import torch as th
15
+
16
+ from .nn import mean_flat
17
+ from .losses import normal_kl, discretized_gaussian_log_likelihood
18
+ from .midi_util import save_piano_roll_midi
19
+ from music_rule_guidance.rule_maps import FUNC_DICT, LOSS_DICT
20
+ from collections import defaultdict
21
+ import torch.nn.functional as F
22
+ import multiprocessing
23
+ from functools import partial
24
+
25
+ import matplotlib.pyplot as plt
26
+ plt.rcParams["figure.figsize"] = (20, 3)
27
+ plt.rcParams['figure.dpi'] = 300
28
+ plt.rcParams['savefig.dpi'] = 300
29
+
30
+
31
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
32
+ """
33
+ Get a pre-defined beta schedule for the given name.
34
+
35
+ The beta schedule library consists of beta schedules which remain similar
36
+ in the limit of num_diffusion_timesteps.
37
+ Beta schedules may be added, but should not be removed or changed once
38
+ they are committed to maintain backwards compatibility.
39
+ """
40
+ if schedule_name == "linear":
41
+ # Linear schedule from Ho et al, extended to work for any number of
42
+ # diffusion steps.
43
+ scale = 1000 / num_diffusion_timesteps
44
+ beta_start = scale * 0.0001
45
+ beta_end = scale * 0.02
46
+ return np.linspace(
47
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
48
+ )
49
+ elif schedule_name == "cosine":
50
+ return betas_for_alpha_bar(
51
+ num_diffusion_timesteps,
52
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
53
+ )
54
+ elif schedule_name == 'stable-diffusion':
55
+ scale = 1000 / num_diffusion_timesteps
56
+ beta_start = scale * math.sqrt(0.00085)
57
+ beta_end = scale * math.sqrt(0.012)
58
+ return np.linspace(
59
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
60
+ ) ** 2
61
+ else:
62
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
63
+
64
+
65
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
66
+ """
67
+ Create a beta schedule that discretizes the given alpha_t_bar function,
68
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
69
+
70
+ :param num_diffusion_timesteps: the number of betas to produce.
71
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
72
+ produces the cumulative product of (1-beta) up to that
73
+ part of the diffusion process.
74
+ :param max_beta: the maximum beta to use; use values lower than 1 to
75
+ prevent singularities.
76
+ """
77
+ betas = []
78
+ for i in range(num_diffusion_timesteps):
79
+ t1 = i / num_diffusion_timesteps
80
+ t2 = (i + 1) / num_diffusion_timesteps
81
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
82
+ return np.array(betas)
83
+
84
+
85
+ class ModelMeanType(enum.Enum):
86
+ """
87
+ Which type of output the model predicts.
88
+ """
89
+
90
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
91
+ START_X = enum.auto() # the model predicts x_0
92
+ EPSILON = enum.auto() # the model predicts epsilon
93
+
94
+
95
+ class ModelVarType(enum.Enum):
96
+ """
97
+ What is used as the model's output variance.
98
+
99
+ The LEARNED_RANGE option has been added to allow the model to predict
100
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
101
+ """
102
+
103
+ LEARNED = enum.auto()
104
+ FIXED_SMALL = enum.auto()
105
+ FIXED_LARGE = enum.auto()
106
+ LEARNED_RANGE = enum.auto()
107
+
108
+
109
+ class LossType(enum.Enum):
110
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
111
+ RESCALED_MSE = (
112
+ enum.auto()
113
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
114
+ KL = enum.auto() # use the variational lower-bound
115
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
116
+
117
+ def is_vb(self):
118
+ return self == LossType.KL or self == LossType.RESCALED_KL
119
+
120
+
121
+ class GaussianDiffusion:
122
+ """
123
+ Utilities for training and sampling diffusion models.
124
+
125
+ Ported directly from here, and then adapted over time to further experimentation.
126
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
127
+
128
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
129
+ starting at T and going to 1.
130
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
131
+ :param model_var_type: a ModelVarType determining how variance is output.
132
+ :param loss_type: a LossType determining the loss function to use.
133
+ :param rescale_timesteps: if True, pass floating point timesteps into the
134
+ model so that they are always scaled like in the
135
+ original paper (0 to 1000).
136
+ """
137
+
138
+ def __init__(
139
+ self,
140
+ *,
141
+ betas,
142
+ model_mean_type,
143
+ model_var_type,
144
+ loss_type,
145
+ rescale_timesteps=False,
146
+ ):
147
+ self.model_mean_type = model_mean_type
148
+ self.model_var_type = model_var_type
149
+ self.loss_type = loss_type
150
+ self.rescale_timesteps = rescale_timesteps
151
+
152
+ # Use float64 for accuracy.
153
+ betas = np.array(betas, dtype=np.float64)
154
+ self.betas = betas
155
+ assert len(betas.shape) == 1, "betas must be 1-D"
156
+ assert (betas > 0).all() and (betas <= 1).all()
157
+
158
+ self.num_timesteps = int(betas.shape[0])
159
+
160
+ alphas = 1.0 - betas
161
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
162
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
163
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
164
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
165
+
166
+ # calculations for diffusion q(x_t | x_{t-1}) and others
167
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
168
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
169
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
170
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
171
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
172
+
173
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
174
+ self.posterior_variance = (
175
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
176
+ )
177
+ # log calculation clipped because the posterior variance is 0 at the
178
+ # beginning of the diffusion chain.
179
+ self.posterior_log_variance_clipped = np.log(
180
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
181
+ )
182
+ self.posterior_mean_coef1 = (
183
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
184
+ )
185
+ self.posterior_mean_coef2 = (
186
+ (1.0 - self.alphas_cumprod_prev)
187
+ * np.sqrt(alphas)
188
+ / (1.0 - self.alphas_cumprod)
189
+ )
190
+
191
+ def q_mean_variance(self, x_start, t):
192
+ """
193
+ Get the distribution q(x_t | x_0).
194
+
195
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
196
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
197
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
198
+ """
199
+ mean = (
200
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
201
+ )
202
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
203
+ log_variance = _extract_into_tensor(
204
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
205
+ )
206
+ return mean, variance, log_variance
207
+
208
+ def q_sample(self, x_start, t, noise=None):
209
+ """
210
+ Diffuse the data for a given number of diffusion steps.
211
+
212
+ In other words, sample from q(x_t | x_0).
213
+
214
+ :param x_start: the initial data batch.
215
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
216
+ :param noise: if specified, the split-out normal noise.
217
+ :return: A noisy version of x_start.
218
+ """
219
+ if noise is None:
220
+ noise = th.randn_like(x_start)
221
+ assert noise.shape == x_start.shape
222
+ return (
223
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
224
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
225
+ * noise
226
+ )
227
+
228
+ def q_posterior_mean_variance(self, x_start, x_t, t):
229
+ """
230
+ Compute the mean and variance of the diffusion posterior:
231
+
232
+ q(x_{t-1} | x_t, x_0)
233
+
234
+ """
235
+ assert x_start.shape == x_t.shape
236
+ posterior_mean = (
237
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
238
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
239
+ )
240
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
241
+ posterior_log_variance_clipped = _extract_into_tensor(
242
+ self.posterior_log_variance_clipped, t, x_t.shape
243
+ )
244
+ assert (
245
+ posterior_mean.shape[0]
246
+ == posterior_variance.shape[0]
247
+ == posterior_log_variance_clipped.shape[0]
248
+ == x_start.shape[0]
249
+ )
250
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
251
+
252
+ def p_mean_variance(
253
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None,
254
+ cond_fn=None, embed_model=None, edit_kwargs=None,
255
+ ):
256
+ """
257
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
258
+ the initial x, x_0.
259
+
260
+ :param model: the model, which takes a signal and a batch of timesteps
261
+ as input.
262
+ :param x: the [N x C x ...] tensor at time t.
263
+ :param t: a 1-D Tensor of timesteps.
264
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
265
+ :param denoised_fn: if not None, a function which applies to the
266
+ x_start prediction before it is used to sample. Applies before
267
+ clip_denoised.
268
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
269
+ pass to the model. This can be used for conditioning.
270
+ :param cond_fn: log p(y|x), to maximize
271
+ :param embed_model: contains encoder and decoder
272
+ :param edit_kwargs: replacement-based conditioning
273
+ :return: a dict with the following keys:
274
+ - 'mean': the model mean output.
275
+ - 'variance': the model variance output.
276
+ - 'log_variance': the log of 'variance'.
277
+ - 'pred_xstart': the prediction for x_0.
278
+ """
279
+ def process_xstart(x):
280
+ if denoised_fn is not None:
281
+ x = denoised_fn(x)
282
+ if clip_denoised:
283
+ return x.clamp(-1, 1)
284
+ return x
285
+
286
+ if model_kwargs is None:
287
+ model_kwargs = {}
288
+
289
+ B, C = x.shape[:2]
290
+ assert t.shape == (B,)
291
+ model_output = model(x, self._scale_timesteps(t), **model_kwargs)
292
+
293
+ if edit_kwargs is not None:
294
+ pred_xstart = process_xstart(
295
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
296
+ )
297
+ replaced_x0 = edit_kwargs["mask"] * edit_kwargs["gt"] + (1 - edit_kwargs["mask"]) * pred_xstart
298
+ model_output = self._predict_eps_from_xstart(x_t=x, t=t, pred_xstart=replaced_x0)
299
+
300
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
301
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
302
+ model_output, model_var_values = th.split(model_output, C, dim=1)
303
+ if self.model_var_type == ModelVarType.LEARNED:
304
+ model_log_variance = model_var_values
305
+ model_variance = th.exp(model_log_variance)
306
+ else:
307
+ min_log = _extract_into_tensor(
308
+ self.posterior_log_variance_clipped, t, x.shape
309
+ )
310
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
311
+ # The model_var_values is [-1, 1] for [min_var, max_var].
312
+ frac = (model_var_values + 1) / 2
313
+ model_log_variance = frac * max_log + (1 - frac) * min_log
314
+ model_variance = th.exp(model_log_variance)
315
+ else:
316
+ model_variance, model_log_variance = {
317
+ # for fixedlarge, we set the initial (log-)variance like so
318
+ # to get a better decoder log likelihood.
319
+ ModelVarType.FIXED_LARGE: (
320
+ np.append(self.posterior_variance[1], self.betas[1:]),
321
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
322
+ ),
323
+ ModelVarType.FIXED_SMALL: (
324
+ self.posterior_variance,
325
+ self.posterior_log_variance_clipped,
326
+ ),
327
+ }[self.model_var_type]
328
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
329
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
330
+
331
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
332
+ pred_xstart = process_xstart(
333
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
334
+ )
335
+ model_mean = model_output
336
+ elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
337
+ if self.model_mean_type == ModelMeanType.START_X:
338
+ pred_xstart = process_xstart(model_output)
339
+ else:
340
+ pred_xstart = process_xstart(
341
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
342
+ )
343
+ model_mean, _, _ = self.q_posterior_mean_variance(
344
+ x_start=pred_xstart, x_t=x, t=t
345
+ )
346
+ else:
347
+ raise NotImplementedError(self.model_mean_type)
348
+
349
+ assert (
350
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
351
+ )
352
+ return {
353
+ "mean": model_mean,
354
+ "variance": model_variance,
355
+ "log_variance": model_log_variance,
356
+ "pred_xstart": pred_xstart,
357
+ }
358
+
359
+ def _predict_xstart_from_eps(self, x_t, t, eps):
360
+ assert x_t.shape == eps.shape
361
+ return (
362
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
363
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
364
+ )
365
+
366
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
367
+ assert x_t.shape == xprev.shape
368
+ return ( # (xprev - coef2*x_t) / coef1
369
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
370
+ - _extract_into_tensor(
371
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
372
+ )
373
+ * x_t
374
+ )
375
+
376
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
377
+ return (
378
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
379
+ - pred_xstart
380
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
381
+
382
+ def _scale_timesteps(self, t):
383
+ if self.rescale_timesteps:
384
+ return t.float() * (1000.0 / self.num_timesteps)
385
+ return t
386
+
387
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None, guidance_kwargs=None,
388
+ model=None, embed_model=None, edit_kwargs=None, scale_factor=1.,
389
+ record=False):
390
+ """
391
+ Compute the mean for the previous step, given a function cond_fn that
392
+ computes the gradient of a conditional log probability with respect to
393
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
394
+ condition on y.
395
+
396
+ If dps=True, use diffusion posterior sampling, cond_fn is log p(y|x_0)
397
+ instead of the grad of it. Need to use model (eps) and embed_model.
398
+
399
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
400
+ """
401
+ dps = True if guidance_kwargs.method == 'dps' else False
402
+ if not dps:
403
+ if edit_kwargs is None:
404
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
405
+ new_mean = (
406
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
407
+ )
408
+ else:
409
+ # only compute gradient on editable latents, since rule is only on editable length
410
+ x = x[:, :, edit_kwargs["l_start"]:edit_kwargs["l_end"], :]
411
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
412
+ new_mean = p_mean_var["mean"].float()
413
+ new_mean[:, :, edit_kwargs["l_start"]:edit_kwargs["l_end"], :] += (
414
+ p_mean_var["variance"] * gradient.float())
415
+ else:
416
+ assert model is not None
417
+ step_size = guidance_kwargs.step_size
418
+ with th.enable_grad():
419
+ xt = x.detach().requires_grad_(True)
420
+ eps = model(xt, self._scale_timesteps(t), **model_kwargs)
421
+ pred_xstart = self._predict_xstart_from_eps(xt, t, eps)
422
+ # If vae is not None, and not dps_nn, i.e. using dps rule
423
+ if embed_model is not None and not guidance_kwargs.nn:
424
+ pred_xstart = _decode(pred_xstart, embed_model, scale_factor=scale_factor)
425
+ if record:
426
+ pred_xstart.retain_grad()
427
+ if edit_kwargs is not None:
428
+ # only check condition on the editable part
429
+ pred_xstart = pred_xstart[:, :, edit_kwargs["l_start"]:edit_kwargs["l_end"], :]
430
+ log_probs = cond_fn(pred_xstart, self._scale_timesteps(t), **model_kwargs)
431
+ gradient = th.autograd.grad(log_probs.sum(), xt)[0]
432
+
433
+ # check if x_0 space works
434
+ if record:
435
+ pred_xstart_up = pred_xstart + pred_xstart.grad
436
+ log_probs_up = cond_fn(pred_xstart_up, self._scale_timesteps(t), **model_kwargs)
437
+ # record gradient difference
438
+ cur_grad_diff = (self.prev_gradient_single - gradient).reshape(x.shape[0], -1).norm(dim=-1)
439
+ prev_gradient_norm = self.prev_gradient_single.reshape(x.shape[0], -1).norm(dim=-1)
440
+ if prev_gradient_norm.mean() > 1e-5:
441
+ self.grad_norm.append(prev_gradient_norm.mean().item())
442
+ cur_grad_diff = cur_grad_diff / prev_gradient_norm
443
+ self.gradient_diff.append(cur_grad_diff.mean().item())
444
+ self.prev_gradient_single = gradient
445
+ self.log_probs.append((log_probs.mean().item()))
446
+
447
+ gradient = gradient / th.sqrt(-log_probs.view(x.shape[0], 1, 1, 1) + 1e-12)
448
+ # gradient = gradient / (-log_probs.view(x.shape[0], 1, 1, 1) + 1e-12)
449
+
450
+ if edit_kwargs is None:
451
+ new_mean = (
452
+ p_mean_var["mean"].float() + step_size * gradient.float()
453
+ )
454
+ else:
455
+ new_mean = p_mean_var["mean"].float()
456
+ new_mean[:, :, edit_kwargs["l_start"]:edit_kwargs["l_end"], :] += step_size * gradient.float()
457
+
458
+ # check whether moved towards good direction om z space
459
+ if record:
460
+ eps = model(xt + step_size * gradient.float(), self._scale_timesteps(t), **model_kwargs)
461
+ pred_xstart_2 = self._predict_xstart_from_eps(xt, t, eps)
462
+ pred_xstart_2 = _decode(pred_xstart_2, embed_model, scale_factor=scale_factor)
463
+ log_probs_2 = cond_fn(pred_xstart_2, self._scale_timesteps(t), **model_kwargs)
464
+
465
+ return new_mean
466
+
467
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
468
+ """
469
+ Compute what the p_mean_variance output would have been, should the
470
+ model's score function be conditioned by cond_fn.
471
+
472
+ See condition_mean() for details on cond_fn.
473
+
474
+ Unlike condition_mean(), this instead uses the conditioning strategy
475
+ from Song et al (2020).
476
+ """
477
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
478
+
479
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
480
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
481
+ x, self._scale_timesteps(t), **model_kwargs
482
+ )
483
+
484
+ out = p_mean_var.copy()
485
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
486
+ out["mean"], _, _ = self.q_posterior_mean_variance(
487
+ x_start=out["pred_xstart"], x_t=x, t=t
488
+ )
489
+ return out
490
+
491
+ def scg_sample(self,
492
+ model,
493
+ t,
494
+ mean_pred,
495
+ g_coeff,
496
+ embed_model,
497
+ scale_factor,
498
+ model_kwargs=None,
499
+ scg_kwargs=None,
500
+ edit_kwargs=None,
501
+ dc_kwargs=None,
502
+ record=False,
503
+ record_freq=100):
504
+ """
505
+ Sample N x_{t-1} from x_t and select the best one.
506
+ """
507
+ # mean_pred = p_mean_var["mean"]
508
+ # g_coeff = th.exp(0.5 * p_mean_var["log_variance"])
509
+ num_samples = scg_kwargs["num_samples"]
510
+ sample = mean_pred.unsqueeze(dim=0)
511
+ sample = sample.expand(num_samples, *mean_pred.shape).contiguous()
512
+ noise = th.randn_like(sample)
513
+ sample = sample + g_coeff * noise
514
+ sample = sample.view(-1, *mean_pred.shape[1:])
515
+ t = t.repeat(num_samples)
516
+ # it's fine to use different target for different samples, expand and repeat match with each other (012012)
517
+ cloned_model_kwargs = {"y": model_kwargs["y"].repeat(num_samples)}
518
+ eps = model(sample, self._scale_timesteps(t), **cloned_model_kwargs)
519
+ pred_xstart = self._predict_xstart_from_eps(sample, t, eps)
520
+ if edit_kwargs is not None:
521
+ # only decode editable part
522
+ pred_xstart = pred_xstart[:, :, edit_kwargs["l_start"]:edit_kwargs["l_end"], :]
523
+ if embed_model is not None:
524
+ pred_xstart = _decode(pred_xstart, embed_model, scale_factor=scale_factor)
525
+
526
+ if dc_kwargs is None or dc_kwargs.base <= 0:
527
+ if record:
528
+ # create dictionary to record the loss for each rule
529
+ each_loss = {}
530
+ # work with multiple rules, model_kwargs["rule"] is a dict that contains rule_name: target
531
+ total_log_prob = 0
532
+ for rule_name, rule_target in model_kwargs["rule"].items():
533
+ gen_rule = _extract_rule(rule_name, pred_xstart)
534
+ y_ = rule_target.repeat(num_samples, 1)
535
+ log_prob = - LOSS_DICT[rule_name](gen_rule, y_)
536
+ if record:
537
+ each_loss[rule_name] = -log_prob.view(num_samples, -1)
538
+ total_log_prob += log_prob * scg_kwargs.get(rule_name, 1.)
539
+ total_log_prob = total_log_prob.view(num_samples, -1)
540
+ max_ind = total_log_prob.argmax(dim=0)
541
+
542
+ # softmax (need to reweight to get unit var otherwise goes to empty rolls)
543
+ # weight = F.softmax(total_log_prob * 1., dim=0)
544
+ # var = (weight ** 2).sum(dim=0)
545
+ # avg_noise = (noise * weight[..., None, None, None]).sum(dim=0) / th.sqrt(var)[..., None, None, None]
546
+ # # not adding dw
547
+ # sample = mean_pred + g_coeff * avg_noise
548
+ # # add dw
549
+ # dw = th.randn_like(p_mean_var["mean"])
550
+ # sample = mean_pred + g_coeff * (avg_noise + dw)
551
+
552
+ # take argmax
553
+ sample = sample.view(num_samples, *mean_pred.shape)
554
+ sample = sample[max_ind, th.arange(mean_pred.shape[0])]
555
+
556
+ # take argmax, and add dw
557
+ # noise = noise.view(num_samples, *p_mean_var["mean"].shape)
558
+ # best_noise = noise[max_ind, th.arange(p_mean_var["mean"].shape[0])]
559
+ # dw = th.randn_like(p_mean_var["mean"])
560
+ # sample = p_mean_var["mean"] + th.exp(0.5 * p_mean_var["log_variance"]) * (best_noise + dw)
561
+
562
+ else:
563
+ # Assuming base length in x0 is only controlled by the corresponding location in xt
564
+ # (doesn't hold, but maybe can approximate because of cond ind)
565
+ sample = sample.view(num_samples, *mean_pred.shape)
566
+ sub_samples = []
567
+ total_length = pred_xstart.shape[-1]
568
+ start_inds = th.arange(0, total_length, dc_kwargs.base*8)
569
+ rule_base = dc_kwargs.base // 16 # number of rules under the base length
570
+ for i, start_ind in enumerate(start_inds):
571
+ end_ind = min(start_ind+dc_kwargs.base*8, total_length)
572
+ pred_xstart_cur = pred_xstart[:, :, :, start_ind: end_ind]
573
+ total_log_prob = 0
574
+ for rule_name, rule_target in model_kwargs["rule"].items():
575
+ gen_rule = _extract_rule(rule_name, pred_xstart_cur)
576
+ if rule_name == 'note_density':
577
+ half = rule_target.shape[-1] // 2
578
+ vt_nd_target = rule_target[:, :half][:, i*rule_base: min((i+1)*rule_base, half)]
579
+ hr_nd_target = rule_target[:, half:][:, i*rule_base: min((i+1)*rule_base, half)]
580
+ rule_target = th.concat((vt_nd_target, hr_nd_target), dim=-1)
581
+ elif 'chord' in rule_name:
582
+ rule_length = rule_target.shape[-1]
583
+ rule_target = rule_target[:, i*rule_base: min((i+1)*rule_base, rule_length)]
584
+ y_ = rule_target.repeat(num_samples, 1)
585
+ log_prob = - LOSS_DICT[rule_name](gen_rule, y_)
586
+ total_log_prob += log_prob * scg_kwargs.get(rule_name, 1.)
587
+ total_log_prob = total_log_prob.view(num_samples, -1)
588
+ max_ind = total_log_prob.argmax(dim=0)
589
+ # take argmax on num_sample x batch_size x 4 x 256 x 16
590
+ sub_sample = sample[max_ind, th.arange(mean_pred.shape[0]), :, start_ind//8: end_ind//8]
591
+ sub_samples.append(sub_sample)
592
+ sample = th.concat(sub_samples, dim=-2)
593
+
594
+ if record:
595
+ for rule_name, loss in each_loss.items():
596
+ current_loss = loss[max_ind, th.arange(mean_pred.shape[0])][0].item()
597
+ self.each_loss[rule_name].append((t[0].item(), current_loss))
598
+ max_log_prob = total_log_prob[max_ind, th.arange(mean_pred.shape[0])][0].item()
599
+ # record log_prob
600
+ self.log_probs.append((t[0].item(), max_log_prob))
601
+ # record loss std
602
+ self.loss_std.append((t[0].item(), total_log_prob.std().item()))
603
+ # record loss range
604
+ self.loss_range.append((t[0].item(), (max_log_prob - total_log_prob.min()).abs().item()))
605
+ # record gradient difference
606
+ noise = noise.view(num_samples, *mean_pred.shape)
607
+ gradient = noise[max_ind, th.arange(mean_pred.shape[0])]
608
+ cur_grad_diff = (self.prev_gradient_single - gradient).reshape(sample.shape[0], -1).norm(dim=-1)
609
+ prev_gradient_norm = self.prev_gradient_single.reshape(sample.shape[0], -1).norm(dim=-1)
610
+ if prev_gradient_norm.mean() > 1e-5:
611
+ self.grad_norm.append(prev_gradient_norm.mean().item())
612
+ cur_grad_diff = cur_grad_diff / prev_gradient_norm
613
+ self.gradient_diff.append(cur_grad_diff.mean().item())
614
+ self.prev_gradient_single = gradient
615
+ if (t[0] + 1) % record_freq == 0:
616
+ pred_xstart = pred_xstart.view(num_samples, -1, *pred_xstart.shape[1:])
617
+ pred_xstart = pred_xstart[max_ind, th.arange(mean_pred.shape[0])]
618
+ pred_xstart[pred_xstart <= -0.95] = -1. # heuristic thresholding the background
619
+ pred_xstart = ((pred_xstart + 1) * 63.5).clamp(0, 127).to(th.uint8)
620
+ self.inter_piano_rolls.append(pred_xstart.cpu())
621
+
622
+ # plot loss distribution
623
+ if len(model_kwargs["rule"].keys()) <= 1:
624
+ plt.figure(figsize=(4, 3))
625
+ total_log_prob = total_log_prob.view(-1).cpu()
626
+ plt.bar(range(len(total_log_prob)), -total_log_prob)
627
+ plt.xlabel('choice')
628
+ plt.ylabel('loss')
629
+ plt.title(f't={t[0]+1}')
630
+ plt.tight_layout()
631
+ plt.savefig(f'loggings/debug/t={t[0]+1}.png')
632
+ plt.show()
633
+ return sample
634
+
635
+ def p_sample(
636
+ self,
637
+ model,
638
+ x,
639
+ t,
640
+ clip_denoised=True,
641
+ denoised_fn=None,
642
+ cond_fn=None,
643
+ model_kwargs=None,
644
+ embed_model=None,
645
+ scale_factor=1.,
646
+ guidance_kwargs=None,
647
+ scg_kwargs=None,
648
+ edit_kwargs=None,
649
+ record=False,
650
+ ):
651
+ """
652
+ Sample x_{t-1} from the model at the given timestep.
653
+
654
+ :param model: the model to sample from.
655
+ :param x: the current tensor at x_{t-1}.
656
+ :param t: the value of t, starting at 0 for the first diffusion step.
657
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
658
+ :param denoised_fn: if not None, a function which applies to the
659
+ x_start prediction before it is used to sample.
660
+ :param cond_fn: if not None, this is a gradient function that acts
661
+ similarly to the model.
662
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
663
+ pass to the model. This can be used for conditioning.
664
+ :return: a dict containing the following keys:
665
+ - 'sample': a random sample from the model.
666
+ - 'pred_xstart': a prediction of x_0.
667
+ """
668
+ if guidance_kwargs is not None:
669
+ if guidance_kwargs.schedule:
670
+ t_start = guidance_kwargs.t_start
671
+ t_end = guidance_kwargs.t_end
672
+ interval = guidance_kwargs.interval
673
+ use_guidance = guide_schedule(t, t_start, t_end, interval)
674
+ else:
675
+ use_guidance = True
676
+ else:
677
+ use_guidance = False
678
+ out = self.p_mean_variance(
679
+ model,
680
+ x,
681
+ t,
682
+ clip_denoised=clip_denoised,
683
+ denoised_fn=denoised_fn,
684
+ model_kwargs=model_kwargs,
685
+ cond_fn=cond_fn,
686
+ embed_model=embed_model,
687
+ edit_kwargs=edit_kwargs,
688
+ )
689
+
690
+ # if SCG guidance is used, the schedule applies only to the SCG sampling step
691
+ if cond_fn is not None and (use_guidance or scg_kwargs is not None):
692
+ out["mean"] = self.condition_mean(
693
+ cond_fn, out, x, t, model_kwargs=model_kwargs,
694
+ guidance_kwargs=guidance_kwargs, model=model, embed_model=embed_model,
695
+ edit_kwargs=edit_kwargs, scale_factor=scale_factor
696
+ )
697
+
698
+ if scg_kwargs is None:
699
+ noise = th.randn_like(x)
700
+ nonzero_mask = (
701
+ (t > self.t_end).float().view(-1, *([1] * (len(x.shape) - 1)))
702
+ ) # no noise when t == t_end (0 if not early stopping)
703
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
704
+
705
+ else: # scg search (greedy)
706
+ if t[0] > self.t_end:
707
+ mean_pred = out["mean"]
708
+ g_coeff = th.exp(0.5 * out["log_variance"])
709
+ if use_guidance:
710
+ dc_kwargs = getattr(guidance_kwargs, 'dc', None)
711
+ sample = self.scg_sample(model, t, mean_pred, g_coeff, embed_model, scale_factor,
712
+ model_kwargs=model_kwargs, scg_kwargs=scg_kwargs,
713
+ edit_kwargs=edit_kwargs, dc_kwargs=dc_kwargs, record=record)
714
+ else:
715
+ sample = mean_pred + g_coeff * th.randn_like(x)
716
+ if record:
717
+ eps = model(sample, self._scale_timesteps(t), **model_kwargs)
718
+ pred_xstart = self._predict_xstart_from_eps(sample, t, eps)
719
+ pred_xstart = _decode(pred_xstart, embed_model, scale_factor=scale_factor)
720
+ if len(model_kwargs["rule"].keys()) <= 1:
721
+ # only record when guiding with a single rule, to save time
722
+ total_log_prob = 0
723
+ for rule_name, rule_target in model_kwargs["rule"].items():
724
+ gen_rule = _extract_rule(rule_name, pred_xstart)
725
+ log_prob = - LOSS_DICT[rule_name](gen_rule, rule_target)
726
+ total_log_prob += log_prob.mean().item() * scg_kwargs.get(rule_name, 1.)
727
+ self.log_probs.append((t[0].item(), total_log_prob))
728
+ if (t[0] + 1) % 100 == 0:
729
+ pred_xstart[pred_xstart <= -0.95] = -1. # heuristically threshold the background
730
+ pred_xstart = ((pred_xstart + 1) * 63.5).clamp(0, 127).to(th.uint8)
731
+ self.inter_piano_rolls.append(pred_xstart.cpu())
732
+ else:
733
+ sample = out["mean"]
734
+
735
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
736
+
737
+ def p_sample_loop(
738
+ self,
739
+ model,
740
+ shape,
741
+ noise=None,
742
+ clip_denoised=True,
743
+ denoised_fn=None,
744
+ t_end=0,
745
+ cond_fn=None,
746
+ model_kwargs=None,
747
+ device=None,
748
+ progress=False,
749
+ embed_model=None,
750
+ scale_factor=1.,
751
+ guidance_kwargs=None,
752
+ scg_kwargs=None,
753
+ edit_kwargs=None,
754
+ record=False,
755
+ ):
756
+ """
757
+ Generate samples from the model.
758
+
759
+ :param model: the model module.
760
+ :param shape: the shape of the samples, (N, C, H, W).
761
+ :param noise: if specified, the noise from the encoder to sample.
762
+ Should be of the same shape as `shape`.
763
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
764
+ :param denoised_fn: if not None, a function which applies to the
765
+ x_start prediction before it is used to sample.
766
+ :param t_end: early stopping for the sampling process
767
+ :param cond_fn: if not None, this is a gradient function that acts
768
+ similarly to the model.
769
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
770
+ pass to the model. This can be used for conditioning.
771
+ :param device: if specified, the device to create the samples on.
772
+ If not specified, use a model parameter's device.
773
+ :param progress: if True, show a tqdm progress bar.
774
+ :return: a non-differentiable batch of samples.
775
+ """
776
+ final = None
777
+ self.t_end = t_end
778
+ if record:
779
+ self.prev_gradient_single = th.zeros(shape, device=device)
780
+ self.gradient_diff = []
781
+ self.grad_norm = []
782
+ self.log_probs = []
783
+ # record loss for each rule
784
+ self.each_loss = defaultdict(list)
785
+ self.inter_piano_rolls = []
786
+ self.loss_std = []
787
+ self.loss_range = []
788
+ for sample in self.p_sample_loop_progressive(
789
+ model,
790
+ shape,
791
+ noise=noise,
792
+ clip_denoised=clip_denoised,
793
+ denoised_fn=denoised_fn,
794
+ t_end=t_end,
795
+ cond_fn=cond_fn,
796
+ model_kwargs=model_kwargs,
797
+ device=device,
798
+ progress=progress,
799
+ embed_model=embed_model,
800
+ scale_factor=scale_factor,
801
+ guidance_kwargs=guidance_kwargs,
802
+ scg_kwargs=scg_kwargs,
803
+ edit_kwargs=edit_kwargs,
804
+ record=record,
805
+ ):
806
+ final = sample
807
+ return final["sample"]
808
+
809
+ def p_sample_loop_progressive(
810
+ self,
811
+ model,
812
+ shape,
813
+ noise=None,
814
+ clip_denoised=True,
815
+ denoised_fn=None,
816
+ t_end=0,
817
+ cond_fn=None,
818
+ model_kwargs=None,
819
+ device=None,
820
+ progress=False,
821
+ embed_model=None,
822
+ scale_factor=1.,
823
+ guidance_kwargs=None,
824
+ scg_kwargs=None,
825
+ edit_kwargs=None,
826
+ record=False,
827
+ ):
828
+ """
829
+ Generate samples from the model and yield intermediate samples from
830
+ each timestep of diffusion.
831
+
832
+ Arguments are the same as p_sample_loop().
833
+ Returns a generator over dicts, where each dict is the return value of
834
+ p_sample().
835
+ """
836
+ if device is None:
837
+ device = next(model.parameters()).device
838
+ assert isinstance(shape, (tuple, list))
839
+ if noise is not None:
840
+ img = noise
841
+ elif edit_kwargs is not None:
842
+ t = th.tensor([edit_kwargs["noise_level"]-1] * shape[0], device=device)
843
+ alpha_cumprod = _extract_into_tensor(self.alphas_cumprod, t, shape)
844
+ img = th.sqrt(alpha_cumprod) * edit_kwargs["gt"] + th.sqrt((1 - alpha_cumprod)) * th.randn(*shape, device=device)
845
+ else:
846
+ img = th.randn(*shape, device=device)
847
+ indices = list(range(self.num_timesteps))[::-1]
848
+ if t_end:
849
+ indices = indices[:-t_end]
850
+ if edit_kwargs is not None:
851
+ t_start = self.num_timesteps - edit_kwargs["noise_level"]
852
+ indices = indices[t_start:]
853
+
854
+ if progress:
855
+ # Lazy import so that we don't depend on tqdm.
856
+ from tqdm.auto import tqdm
857
+
858
+ indices = tqdm(indices)
859
+
860
+ for i in indices:
861
+ t = th.tensor([i] * shape[0], device=device)
862
+ with th.no_grad():
863
+ out = self.p_sample(
864
+ model,
865
+ img,
866
+ t,
867
+ clip_denoised=clip_denoised,
868
+ denoised_fn=denoised_fn,
869
+ cond_fn=cond_fn,
870
+ model_kwargs=model_kwargs,
871
+ embed_model=embed_model,
872
+ scale_factor=scale_factor,
873
+ guidance_kwargs=guidance_kwargs,
874
+ scg_kwargs=scg_kwargs,
875
+ edit_kwargs=edit_kwargs,
876
+ record=record,
877
+ )
878
+ yield out
879
+ img = out["sample"]
880
+
881
+ def ddim_sample(
882
+ self,
883
+ model,
884
+ x,
885
+ t,
886
+ clip_denoised=True,
887
+ denoised_fn=None,
888
+ cond_fn=None,
889
+ model_kwargs=None,
890
+ eta=0.0,
891
+ embed_model=None,
892
+ scale_factor=1.,
893
+ guidance_kwargs=None,
894
+ edit_kwargs=None,
895
+ scg_kwargs=None,
896
+ record=False,
897
+ ):
898
+ """
899
+ Sample x_{t-1} from the model using DDIM.
900
+
901
+ Same usage as p_sample().
902
+ """
903
+ if guidance_kwargs is not None:
904
+ if guidance_kwargs.schedule:
905
+ t_start = guidance_kwargs.t_start
906
+ t_end = guidance_kwargs.t_end
907
+ interval = guidance_kwargs.interval
908
+ use_guidance = guide_schedule(t, t_start, t_end, interval)
909
+ else:
910
+ use_guidance = True
911
+ else:
912
+ use_guidance = False
913
+ out = self.p_mean_variance(
914
+ model,
915
+ x,
916
+ t,
917
+ clip_denoised=clip_denoised,
918
+ denoised_fn=denoised_fn,
919
+ model_kwargs=model_kwargs,
920
+ cond_fn=cond_fn,
921
+ embed_model=embed_model,
922
+ edit_kwargs=edit_kwargs,
923
+ )
924
+ if cond_fn is not None and use_guidance:
925
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
926
+
927
+ # Usually our model outputs epsilon, but we re-derive it
928
+ # in case we used x_start or x_prev prediction.
929
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
930
+
931
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
932
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
933
+ sigma = (
934
+ eta
935
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
936
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
937
+ )
938
+ # Equation 12.
939
+ mean_pred = (
940
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
941
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
942
+ )
943
+ if scg_kwargs is None:
944
+ noise = th.randn_like(x)
945
+ nonzero_mask = (
946
+ (t != self.t_end).float().view(-1, *([1] * (len(x.shape) - 1)))
947
+ ) # no noise when t == t_end (0 if not early stopping)
948
+ sample = mean_pred + nonzero_mask * sigma * noise
949
+ else:
950
+ if t[0] > self.t_end:
951
+ g_coeff = sigma
952
+ if use_guidance: # tune according to ddim steps
953
+ dc_kwargs = getattr(guidance_kwargs, 'dc', None)
954
+ sample = self.scg_sample(self._wrap_model(model), t, mean_pred, g_coeff, embed_model, scale_factor,
955
+ model_kwargs=model_kwargs, scg_kwargs=scg_kwargs, edit_kwargs=edit_kwargs,
956
+ dc_kwargs=dc_kwargs, record=record, record_freq=10)
957
+ else:
958
+ sample = mean_pred + g_coeff * th.randn_like(x)
959
+ if record:
960
+ eps = self._wrap_model(model)(sample, self._scale_timesteps(t), **model_kwargs)
961
+ pred_xstart = self._predict_xstart_from_eps(sample, t, eps)
962
+ pred_xstart = _decode(pred_xstart, embed_model, scale_factor=scale_factor)
963
+ total_log_prob = 0
964
+ for rule_name, rule_target in model_kwargs["rule"].items():
965
+ gen_rule = _extract_rule(rule_name, pred_xstart)
966
+ log_prob = - LOSS_DICT[rule_name](gen_rule, rule_target)
967
+ total_log_prob += log_prob.mean().item() * scg_kwargs.get(rule_name, 1.)
968
+ self.log_probs.append((t[0].item(), total_log_prob))
969
+
970
+ if (t[0] + 1) % 10 == 0:
971
+ pred_xstart[pred_xstart <= -0.95] = -1. # heuristically threshold the background
972
+ pred_xstart = ((pred_xstart + 1) * 63.5).clamp(0, 127).to(th.uint8)
973
+ self.inter_piano_rolls.append(pred_xstart.cpu())
974
+ else:
975
+ sample = mean_pred
976
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
977
+
978
+ def ddim_reverse_sample(
979
+ self,
980
+ model,
981
+ x,
982
+ t,
983
+ clip_denoised=True,
984
+ denoised_fn=None,
985
+ model_kwargs=None,
986
+ eta=0.0,
987
+ ):
988
+ """
989
+ Sample x_{t+1} from the model using DDIM reverse ODE.
990
+ """
991
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
992
+ out = self.p_mean_variance(
993
+ model,
994
+ x,
995
+ t,
996
+ clip_denoised=clip_denoised,
997
+ denoised_fn=denoised_fn,
998
+ model_kwargs=model_kwargs,
999
+ )
1000
+ # Usually our model outputs epsilon, but we re-derive it
1001
+ # in case we used x_start or x_prev prediction.
1002
+ eps = (
1003
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
1004
+ - out["pred_xstart"]
1005
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
1006
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
1007
+
1008
+ # Equation 12. reversed
1009
+ mean_pred = (
1010
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
1011
+ + th.sqrt(1 - alpha_bar_next) * eps
1012
+ )
1013
+
1014
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
1015
+
1016
+ def ddim_sample_loop(
1017
+ self,
1018
+ model,
1019
+ shape,
1020
+ noise=None,
1021
+ clip_denoised=True,
1022
+ denoised_fn=None,
1023
+ t_end=0,
1024
+ cond_fn=None,
1025
+ model_kwargs=None,
1026
+ device=None,
1027
+ progress=False,
1028
+ eta=0.0,
1029
+ embed_model=None,
1030
+ scale_factor=1.,
1031
+ guidance_kwargs=None,
1032
+ scg_kwargs=None,
1033
+ edit_kwargs=None,
1034
+ record=False,
1035
+ ):
1036
+ """
1037
+ Generate samples from the model using DDIM.
1038
+
1039
+ Same usage as p_sample_loop().
1040
+ """
1041
+ final = None
1042
+ self.t_end = t_end
1043
+ if record:
1044
+ self.prev_gradient_single = th.zeros(shape, device=device)
1045
+ self.gradient_diff = []
1046
+ self.grad_norm = []
1047
+ self.log_probs = []
1048
+ self.inter_piano_rolls = []
1049
+ self.loss_std = []
1050
+ self.loss_range = []
1051
+ for sample in self.ddim_sample_loop_progressive(
1052
+ model,
1053
+ shape,
1054
+ noise=noise,
1055
+ clip_denoised=clip_denoised,
1056
+ denoised_fn=denoised_fn,
1057
+ t_end=t_end,
1058
+ cond_fn=cond_fn,
1059
+ model_kwargs=model_kwargs,
1060
+ device=device,
1061
+ progress=progress,
1062
+ eta=eta,
1063
+ embed_model=embed_model,
1064
+ scale_factor=scale_factor,
1065
+ guidance_kwargs=guidance_kwargs,
1066
+ scg_kwargs=scg_kwargs,
1067
+ edit_kwargs=edit_kwargs,
1068
+ record=record,
1069
+ ):
1070
+ final = sample
1071
+ return final["sample"]
1072
+
1073
+ def ddim_sample_loop_progressive(
1074
+ self,
1075
+ model,
1076
+ shape,
1077
+ noise=None,
1078
+ clip_denoised=True,
1079
+ denoised_fn=None,
1080
+ t_end=0,
1081
+ cond_fn=None,
1082
+ model_kwargs=None,
1083
+ device=None,
1084
+ progress=False,
1085
+ eta=0.0,
1086
+ embed_model=None,
1087
+ scale_factor=1.,
1088
+ guidance_kwargs=None,
1089
+ scg_kwargs=None,
1090
+ edit_kwargs=None,
1091
+ record=False,
1092
+ ):
1093
+ """
1094
+ Use DDIM to sample from the model and yield intermediate samples from
1095
+ each timestep of DDIM.
1096
+
1097
+ Same usage as p_sample_loop_progressive().
1098
+ """
1099
+ if device is None:
1100
+ device = next(model.parameters()).device
1101
+ assert isinstance(shape, (tuple, list))
1102
+ if noise is not None:
1103
+ img = noise
1104
+ elif edit_kwargs is not None:
1105
+ t = th.tensor([edit_kwargs["noise_level"]-1] * shape[0], device=device)
1106
+ alpha_cumprod = _extract_into_tensor(self.alphas_cumprod, t, shape)
1107
+ img = th.sqrt(alpha_cumprod) * edit_kwargs["gt"] + th.sqrt((1 - alpha_cumprod)) * th.randn(*shape, device=device)
1108
+ else:
1109
+ img = th.randn(*shape, device=device)
1110
+ indices = list(range(self.num_timesteps))[::-1]
1111
+ if t_end:
1112
+ indices = indices[:-t_end]
1113
+ if edit_kwargs is not None:
1114
+ t_start = self.num_timesteps - edit_kwargs["noise_level"]
1115
+ indices = indices[t_start:]
1116
+
1117
+ if progress:
1118
+ # Lazy import so that we don't depend on tqdm.
1119
+ from tqdm.auto import tqdm
1120
+
1121
+ indices = tqdm(indices)
1122
+
1123
+ for i in indices:
1124
+ t = th.tensor([i] * shape[0], device=device)
1125
+ with th.no_grad():
1126
+ out = self.ddim_sample(
1127
+ model,
1128
+ img,
1129
+ t,
1130
+ clip_denoised=clip_denoised,
1131
+ denoised_fn=denoised_fn,
1132
+ cond_fn=cond_fn,
1133
+ model_kwargs=model_kwargs,
1134
+ eta=eta,
1135
+ embed_model=embed_model,
1136
+ scale_factor=scale_factor,
1137
+ guidance_kwargs=guidance_kwargs,
1138
+ scg_kwargs=scg_kwargs,
1139
+ edit_kwargs=edit_kwargs,
1140
+ record=record,
1141
+ )
1142
+ yield out
1143
+ img = out["sample"]
1144
+
1145
+ def _vb_terms_bpd(
1146
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
1147
+ ):
1148
+ """
1149
+ Get a term for the variational lower-bound.
1150
+
1151
+ The resulting units are bits (rather than nats, as one might expect).
1152
+ This allows for comparison to other papers.
1153
+
1154
+ :return: a dict with the following keys:
1155
+ - 'output': a shape [N] tensor of NLLs or KLs.
1156
+ - 'pred_xstart': the x_0 predictions.
1157
+ """
1158
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
1159
+ x_start=x_start, x_t=x_t, t=t
1160
+ )
1161
+ out = self.p_mean_variance(
1162
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
1163
+ )
1164
+ kl = normal_kl(
1165
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
1166
+ )
1167
+ kl = mean_flat(kl) / np.log(2.0)
1168
+
1169
+ decoder_nll = -discretized_gaussian_log_likelihood(
1170
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
1171
+ )
1172
+ assert decoder_nll.shape == x_start.shape
1173
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
1174
+
1175
+ # At the first timestep return the decoder NLL,
1176
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
1177
+ output = th.where((t == 0), decoder_nll, kl)
1178
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
1179
+
1180
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
1181
+ """
1182
+ Compute training losses for a single timestep.
1183
+
1184
+ :param model: the model to evaluate loss on.
1185
+ :param x_start: the [N x C x ...] tensor of inputs.
1186
+ :param t: a batch of timestep indices.
1187
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
1188
+ pass to the model. This can be used for conditioning.
1189
+ :param noise: if specified, the specific Gaussian noise to try to remove.
1190
+ :return: a dict with the key "loss" containing a tensor of shape [N].
1191
+ Some mean or variance settings may also have other keys.
1192
+ """
1193
+ if model_kwargs is None:
1194
+ model_kwargs = {}
1195
+ if noise is None:
1196
+ noise = th.randn_like(x_start)
1197
+ x_t = self.q_sample(x_start, t, noise=noise)
1198
+
1199
+ terms = {}
1200
+
1201
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
1202
+ terms["loss"] = self._vb_terms_bpd(
1203
+ model=model,
1204
+ x_start=x_start,
1205
+ x_t=x_t,
1206
+ t=t,
1207
+ clip_denoised=False,
1208
+ model_kwargs=model_kwargs,
1209
+ )["output"]
1210
+ if self.loss_type == LossType.RESCALED_KL:
1211
+ terms["loss"] *= self.num_timesteps
1212
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
1213
+ model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
1214
+
1215
+ if self.model_var_type in [
1216
+ ModelVarType.LEARNED,
1217
+ ModelVarType.LEARNED_RANGE,
1218
+ ]:
1219
+ B, C = x_t.shape[:2]
1220
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
1221
+ model_output, model_var_values = th.split(model_output, C, dim=1)
1222
+ # Learn the variance using the variational bound, but don't let
1223
+ # it affect our mean prediction.
1224
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
1225
+ terms["vb"] = self._vb_terms_bpd(
1226
+ model=lambda *args, r=frozen_out: r,
1227
+ x_start=x_start,
1228
+ x_t=x_t,
1229
+ t=t,
1230
+ clip_denoised=False,
1231
+ )["output"]
1232
+ if self.loss_type == LossType.RESCALED_MSE:
1233
+ # Divide by 1000 for equivalence with initial implementation.
1234
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
1235
+ terms["vb"] *= self.num_timesteps / 1000.0
1236
+
1237
+ target = {
1238
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
1239
+ x_start=x_start, x_t=x_t, t=t
1240
+ )[0],
1241
+ ModelMeanType.START_X: x_start,
1242
+ ModelMeanType.EPSILON: noise,
1243
+ }[self.model_mean_type]
1244
+ assert model_output.shape == target.shape == x_start.shape
1245
+ terms["mse"] = mean_flat((target - model_output) ** 2)
1246
+ if "vb" in terms:
1247
+ terms["loss"] = terms["mse"] + terms["vb"]
1248
+ else:
1249
+ terms["loss"] = terms["mse"]
1250
+ else:
1251
+ raise NotImplementedError(self.loss_type)
1252
+
1253
+ return terms
1254
+
1255
+ def _prior_bpd(self, x_start):
1256
+ """
1257
+ Get the prior KL term for the variational lower-bound, measured in
1258
+ bits-per-dim.
1259
+
1260
+ This term can't be optimized, as it only depends on the encoder.
1261
+
1262
+ :param x_start: the [N x C x ...] tensor of inputs.
1263
+ :return: a batch of [N] KL values (in bits), one per batch element.
1264
+ """
1265
+ batch_size = x_start.shape[0]
1266
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1267
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1268
+ kl_prior = normal_kl(
1269
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
1270
+ )
1271
+ return mean_flat(kl_prior) / np.log(2.0)
1272
+
1273
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
1274
+ """
1275
+ Compute the entire variational lower-bound, measured in bits-per-dim,
1276
+ as well as other related quantities.
1277
+
1278
+ :param model: the model to evaluate loss on.
1279
+ :param x_start: the [N x C x ...] tensor of inputs.
1280
+ :param clip_denoised: if True, clip denoised samples.
1281
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
1282
+ pass to the model. This can be used for conditioning.
1283
+
1284
+ :return: a dict containing the following keys:
1285
+ - total_bpd: the total variational lower-bound, per batch element.
1286
+ - prior_bpd: the prior term in the lower-bound.
1287
+ - vb: an [N x T] tensor of terms in the lower-bound.
1288
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
1289
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
1290
+ """
1291
+ device = x_start.device
1292
+ batch_size = x_start.shape[0]
1293
+
1294
+ vb = []
1295
+ xstart_mse = []
1296
+ mse = []
1297
+ for t in list(range(self.num_timesteps))[::-1]:
1298
+ t_batch = th.tensor([t] * batch_size, device=device)
1299
+ noise = th.randn_like(x_start)
1300
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
1301
+ # Calculate VLB term at the current timestep
1302
+ with th.no_grad():
1303
+ out = self._vb_terms_bpd(
1304
+ model,
1305
+ x_start=x_start,
1306
+ x_t=x_t,
1307
+ t=t_batch,
1308
+ clip_denoised=clip_denoised,
1309
+ model_kwargs=model_kwargs,
1310
+ )
1311
+ vb.append(out["output"])
1312
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
1313
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
1314
+ mse.append(mean_flat((eps - noise) ** 2))
1315
+
1316
+ vb = th.stack(vb, dim=1)
1317
+ xstart_mse = th.stack(xstart_mse, dim=1)
1318
+ mse = th.stack(mse, dim=1)
1319
+
1320
+ prior_bpd = self._prior_bpd(x_start)
1321
+ total_bpd = vb.sum(dim=1) + prior_bpd
1322
+ return {
1323
+ "total_bpd": total_bpd,
1324
+ "prior_bpd": prior_bpd,
1325
+ "vb": vb,
1326
+ "xstart_mse": xstart_mse,
1327
+ "mse": mse,
1328
+ }
1329
+
1330
+
1331
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
1332
+ """
1333
+ Extract values from a 1-D numpy array for a batch of indices.
1334
+
1335
+ :param arr: the 1-D numpy array.
1336
+ :param timesteps: a tensor of indices into the array to extract.
1337
+ :param broadcast_shape: a larger shape of K dimensions with the batch
1338
+ dimension equal to the length of timesteps.
1339
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
1340
+ """
1341
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
1342
+ while len(res.shape) < len(broadcast_shape):
1343
+ res = res[..., None]
1344
+ return res.expand(broadcast_shape)
1345
+
1346
+
1347
+ def _decode(pred_zstart, embed_model, scale_factor=1., threshold=False):
1348
+ image_size_h = pred_zstart.shape[-2]
1349
+ image_size_w = pred_zstart.shape[-1]
1350
+ pred_zstart = pred_zstart / scale_factor
1351
+ sample = pred_zstart.permute(0, 1, 3, 2)
1352
+ sample = th.chunk(sample, image_size_h // image_size_w, dim=-1) # B x C x H x W
1353
+ sample = th.concat(sample, dim=0) # 1st second for all batch, 2nd second for all batch, ...
1354
+ sample = embed_model.decode(sample)
1355
+ pred_xstart = th.concat(th.chunk(sample, image_size_h // image_size_w, dim=0), dim=-1)
1356
+ if threshold:
1357
+ pred_xstart[pred_xstart <= -0.95] = -1. # heuristically threshold the background
1358
+ return pred_xstart
1359
+
1360
+
1361
+ def _extract_rule(rule_name, pred_xstart):
1362
+ device = pred_xstart.device
1363
+ if 'chord' in rule_name:
1364
+ # Split tensor batch into smaller batches
1365
+ num_processes = 4
1366
+ pred_xstart = pred_xstart.cpu()
1367
+ pred_xstart_split = th.chunk(pred_xstart, num_processes)
1368
+ # rule_func = partial(FUNC_DICT[rule_name], given_key="C major") # todo: hard code key here
1369
+ rule_func = FUNC_DICT[rule_name]
1370
+ with multiprocessing.Pool(processes=num_processes) as pool:
1371
+ gen_rule = pool.map(rule_func, pred_xstart_split)
1372
+ # Combine results
1373
+ if len(gen_rule[0].shape) == 1: # batch_size * branching_factor < 4
1374
+ gen_rule = [item.unsqueeze(dim=0) for item in gen_rule]
1375
+ gen_rule = th.concat(gen_rule, dim=0).to(device)
1376
+
1377
+ else:
1378
+ gen_rule = FUNC_DICT[rule_name](pred_xstart)
1379
+ return gen_rule
1380
+
1381
+
1382
+ def _encode(pred_xstart, embed_model, scale_factor=1.):
1383
+ image_size_h = pred_xstart.shape[-2]
1384
+ image_size_w = pred_xstart.shape[-1]
1385
+ seq_len = image_size_w // image_size_h
1386
+ micro = th.chunk(pred_xstart, seq_len, dim=-1) # B x C x H x W
1387
+ micro = th.concat(micro, dim=0) # 1st second for all batch, 2nd second for all batch, ...
1388
+ micro = embed_model.encode_save(micro, range_fix=False)
1389
+ if micro.shape[1] == 8:
1390
+ z, _ = th.chunk(micro, 2, dim=1)
1391
+ else:
1392
+ z = micro
1393
+ z = th.concat(th.chunk(z, seq_len, dim=0), dim=-1)
1394
+ z = z.permute(0, 1, 3, 2)
1395
+ return z * scale_factor
1396
+
1397
+
1398
+ def guide_schedule(t, t_start=750, t_end=0, interval=1):
1399
+ flag = t_start > t[0] >= t_end and (t[0] + 1) % interval == 0
1400
+ return flag
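`guide_schedule` gates guidance to the timestep window `t_end <= t < t_start` and to every `interval`-th step. A quick sanity check of that behaviour (the window values below are illustrative, not the configured defaults):

```python
import torch as th

t = th.tensor([499, 499])  # current diffusion step for a batch of 2
print(bool(guide_schedule(t, t_start=750, t_end=0, interval=1)))    # True: 750 > 499 >= 0
print(bool(guide_schedule(t, t_start=400, t_end=0, interval=1)))    # False: 499 is outside the window
print(bool(guide_schedule(t, t_start=750, t_end=0, interval=10)))   # True: (499 + 1) % 10 == 0
print(bool(guide_schedule(th.tensor([498, 498]), t_start=750, t_end=0, interval=10)))  # False
```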
guided_diffusion/logger.py ADDED
@@ -0,0 +1,521 @@
1
+ """
2
+ Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3
+ https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import shutil
9
+ import os.path as osp
10
+ import json
11
+ import time
12
+ import datetime
13
+ import tempfile
14
+ import warnings
15
+ from collections import defaultdict
16
+ from contextlib import contextmanager
17
+
18
+ import wandb
19
+
20
+ DEBUG = 10
21
+ INFO = 20
22
+ WARN = 30
23
+ ERROR = 40
24
+
25
+ DISABLED = 50
26
+
27
+
28
+ class KVWriter(object):
29
+ def writekvs(self, kvs):
30
+ raise NotImplementedError
31
+
32
+
33
+ class SeqWriter(object):
34
+ def writeseq(self, seq):
35
+ raise NotImplementedError
36
+
37
+
38
+ class HumanOutputFormat(KVWriter, SeqWriter):
39
+ def __init__(self, filename_or_file):
40
+ if isinstance(filename_or_file, str):
41
+ self.file = open(filename_or_file, "wt")
42
+ self.own_file = True
43
+ else:
44
+ assert hasattr(filename_or_file, "read"), (
45
+ "expected file or str, got %s" % filename_or_file
46
+ )
47
+ self.file = filename_or_file
48
+ self.own_file = False
49
+
50
+ def writekvs(self, kvs):
51
+ # Create strings for printing
52
+ key2str = {}
53
+ for (key, val) in sorted(kvs.items()):
54
+ if hasattr(val, "__float__"):
55
+ valstr = "%-8.3g" % val
56
+ else:
57
+ valstr = str(val)
58
+ key2str[self._truncate(key)] = self._truncate(valstr)
59
+
60
+ # Find max widths
61
+ if len(key2str) == 0:
62
+ print("WARNING: tried to write empty key-value dict")
63
+ return
64
+ else:
65
+ keywidth = max(map(len, key2str.keys()))
66
+ valwidth = max(map(len, key2str.values()))
67
+
68
+ # Write out the data
69
+ dashes = "-" * (keywidth + valwidth + 7)
70
+ lines = [dashes]
71
+ for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
72
+ lines.append(
73
+ "| %s%s | %s%s |"
74
+ % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
75
+ )
76
+ lines.append(dashes)
77
+ self.file.write("\n".join(lines) + "\n")
78
+
79
+ # Flush the output to the file
80
+ self.file.flush()
81
+
82
+ def _truncate(self, s):
83
+ maxlen = 30
84
+ return s[: maxlen - 3] + "..." if len(s) > maxlen else s
85
+
86
+ def writeseq(self, seq):
87
+ seq = list(seq)
88
+ for (i, elem) in enumerate(seq):
89
+ self.file.write(elem)
90
+ if i < len(seq) - 1: # add space unless this is the last one
91
+ self.file.write(" ")
92
+ self.file.write("\n")
93
+ self.file.flush()
94
+
95
+ def close(self):
96
+ if self.own_file:
97
+ self.file.close()
98
+
99
+
100
+ class JSONOutputFormat(KVWriter):
101
+ def __init__(self, filename):
102
+ self.file = open(filename, "wt")
103
+
104
+ def writekvs(self, kvs):
105
+ for k, v in sorted(kvs.items()):
106
+ if hasattr(v, "dtype"):
107
+ kvs[k] = float(v)
108
+ self.file.write(json.dumps(kvs) + "\n")
109
+ self.file.flush()
110
+
111
+ def close(self):
112
+ self.file.close()
113
+
114
+
115
+ class CSVOutputFormat(KVWriter):
116
+ def __init__(self, filename):
117
+ self.file = open(filename, "w+t")
118
+ self.keys = []
119
+ self.sep = ","
120
+
121
+ def writekvs(self, kvs):
122
+ # Add our current row to the history
123
+ extra_keys = list(kvs.keys() - self.keys)
124
+ extra_keys.sort()
125
+ if extra_keys:
126
+ self.keys.extend(extra_keys)
127
+ self.file.seek(0)
128
+ lines = self.file.readlines()
129
+ self.file.seek(0)
130
+ for (i, k) in enumerate(self.keys):
131
+ if i > 0:
132
+ self.file.write(",")
133
+ self.file.write(k)
134
+ self.file.write("\n")
135
+ for line in lines[1:]:
136
+ self.file.write(line[:-1])
137
+ self.file.write(self.sep * len(extra_keys))
138
+ self.file.write("\n")
139
+ for (i, k) in enumerate(self.keys):
140
+ if i > 0:
141
+ self.file.write(",")
142
+ v = kvs.get(k)
143
+ if v is not None:
144
+ self.file.write(str(v))
145
+ self.file.write("\n")
146
+ self.file.flush()
147
+
148
+ def close(self):
149
+ self.file.close()
150
+
151
+
152
+ class TensorBoardOutputFormat(KVWriter):
153
+ """
154
+ Dumps key/value pairs into TensorBoard's numeric format.
155
+ """
156
+
157
+ def __init__(self, dir):
158
+ os.makedirs(dir, exist_ok=True)
159
+ self.dir = dir
160
+ self.step = 1
161
+ prefix = "events"
162
+ path = osp.join(osp.abspath(dir), prefix)
163
+ import tensorflow as tf
164
+ from tensorflow.python import pywrap_tensorflow
165
+ from tensorflow.core.util import event_pb2
166
+ from tensorflow.python.util import compat
167
+
168
+ self.tf = tf
169
+ self.event_pb2 = event_pb2
170
+ self.pywrap_tensorflow = pywrap_tensorflow
171
+ self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
172
+
173
+ def writekvs(self, kvs):
174
+ def summary_val(k, v):
175
+ kwargs = {"tag": k, "simple_value": float(v)}
176
+ return self.tf.Summary.Value(**kwargs)
177
+
178
+ summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
179
+ event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
180
+ event.step = (
181
+ self.step
182
+ ) # is there any reason why you'd want to specify the step?
183
+ self.writer.WriteEvent(event)
184
+ self.writer.Flush()
185
+ self.step += 1
186
+
187
+ def close(self):
188
+ if self.writer:
189
+ self.writer.Close()
190
+ self.writer = None
191
+
192
+
193
+ class WandbOutputFormat(KVWriter):
194
+ def __init__(self, args):
195
+ wandb.init(project=args.project, config=vars(args))
196
+
197
+ def writekvs(self, kvs):
198
+ step = int(kvs["step"])
199
+ wandb.log(kvs, step=step)
200
+
201
+ def close(self):
202
+ pass
203
+
204
+
205
+ def make_output_format(format, ev_dir, args, log_suffix=""):
206
+ os.makedirs(ev_dir, exist_ok=True)
207
+ if format == "stdout":
208
+ return HumanOutputFormat(sys.stdout)
209
+ elif format == "log":
210
+ return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
211
+ elif format == "json":
212
+ return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
213
+ elif format == "csv":
214
+ return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
215
+ elif format == "tensorboard":
216
+ return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
217
+ elif format == "wandb":
218
+ return WandbOutputFormat(args)
219
+ else:
220
+ raise ValueError("Unknown format specified: %s" % (format,))
221
+
222
+
223
+ # ================================================================
224
+ # API
225
+ # ================================================================
226
+
227
+
228
+ def logkv(key, val):
229
+ """
230
+ Log a value of some diagnostic
231
+ Call this once for each diagnostic quantity, each iteration
232
+ If called many times, last value will be used.
233
+ """
234
+ get_current().logkv(key, val)
235
+
236
+
237
+ def logkv_mean(key, val):
238
+ """
239
+ The same as logkv(), but if called many times, values averaged.
240
+ """
241
+ get_current().logkv_mean(key, val)
242
+
243
+
244
+ def logkvs(d):
245
+ """
246
+ Log a dictionary of key-value pairs
247
+ """
248
+ for (k, v) in d.items():
249
+ logkv(k, v)
250
+
251
+
252
+ def dumpkvs():
253
+ """
254
+ Write all of the diagnostics from the current iteration
255
+ """
256
+ return get_current().dumpkvs()
257
+
258
+
259
+ def getkvs():
260
+ return get_current().name2val
261
+
262
+
263
+ def log(*args, level=INFO):
264
+ """
265
+ Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
266
+ """
267
+ get_current().log(*args, level=level)
268
+
269
+
270
+ def debug(*args):
271
+ log(*args, level=DEBUG)
272
+
273
+
274
+ def info(*args):
275
+ log(*args, level=INFO)
276
+
277
+
278
+ def warn(*args):
279
+ log(*args, level=WARN)
280
+
281
+
282
+ def error(*args):
283
+ log(*args, level=ERROR)
284
+
285
+
286
+ def set_level(level):
287
+ """
288
+ Set logging threshold on current logger.
289
+ """
290
+ get_current().set_level(level)
291
+
292
+
293
+ def set_comm(comm):
294
+ get_current().set_comm(comm)
295
+
296
+
297
+ def get_dir():
298
+ """
299
+ Get directory that log files are being written to.
300
+ will be None if there is no output directory (i.e., if you didn't call start)
301
+ """
302
+ return get_current().get_dir()
303
+
304
+
305
+ record_tabular = logkv
306
+ dump_tabular = dumpkvs
307
+
308
+
309
+ @contextmanager
310
+ def profile_kv(scopename):
311
+ logkey = "wait_" + scopename
312
+ tstart = time.time()
313
+ try:
314
+ yield
315
+ finally:
316
+ get_current().name2val[logkey] += time.time() - tstart
317
+
318
+
319
+ def profile(n):
320
+ """
321
+ Usage:
322
+ @profile("my_func")
323
+ def my_func(): code
324
+ """
325
+
326
+ def decorator_with_name(func):
327
+ def func_wrapper(*args, **kwargs):
328
+ with profile_kv(n):
329
+ return func(*args, **kwargs)
330
+
331
+ return func_wrapper
332
+
333
+ return decorator_with_name
334
+
335
+
336
+ # ================================================================
337
+ # Backend
338
+ # ================================================================
339
+
340
+
341
+ def get_current():
342
+ if Logger.CURRENT is None:
343
+ _configure_default_logger()
344
+
345
+ return Logger.CURRENT
346
+
347
+
348
+ class Logger(object):
349
+ DEFAULT = None # A logger with no output files. (See right below class definition)
350
+ # So that you can still log to the terminal without setting up any output files
351
+ CURRENT = None # Current logger being used by the free functions above
352
+
353
+ def __init__(self, dir, output_formats, comm=None):
354
+ self.name2val = defaultdict(float) # values this iteration
355
+ self.name2cnt = defaultdict(int)
356
+ self.level = INFO
357
+ self.dir = dir
358
+ self.output_formats = output_formats
359
+ self.comm = comm
360
+
361
+ # Logging API, forwarded
362
+ # ----------------------------------------
363
+ def logkv(self, key, val):
364
+ self.name2val[key] = val
365
+
366
+ def logkv_mean(self, key, val):
367
+ oldval, cnt = self.name2val[key], self.name2cnt[key]
368
+ self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
369
+ self.name2cnt[key] = cnt + 1
370
+
371
+ def dumpkvs(self):
372
+ if self.comm is None:
373
+ d = self.name2val
374
+ else:
375
+ d = mpi_weighted_mean(
376
+ self.comm,
377
+ {
378
+ name: (val, self.name2cnt.get(name, 1))
379
+ for (name, val) in self.name2val.items()
380
+ },
381
+ )
382
+ if self.comm.rank != 0:
383
+ d["dummy"] = 1 # so we don't get a warning about empty dict
384
+ out = d.copy() # Return the dict for unit testing purposes
385
+ for fmt in self.output_formats:
386
+ if isinstance(fmt, KVWriter):
387
+ fmt.writekvs(d)
388
+ self.name2val.clear()
389
+ self.name2cnt.clear()
390
+ return out
391
+
392
+ def log(self, *args, level=INFO):
393
+ if self.level <= level:
394
+ self._do_log(args)
395
+
396
+ # Configuration
397
+ # ----------------------------------------
398
+ def set_level(self, level):
399
+ self.level = level
400
+
401
+ def set_comm(self, comm):
402
+ self.comm = comm
403
+
404
+ def get_dir(self):
405
+ return self.dir
406
+
407
+ def close(self):
408
+ for fmt in self.output_formats:
409
+ fmt.close()
410
+
411
+ # Misc
412
+ # ----------------------------------------
413
+ def _do_log(self, args):
414
+ for fmt in self.output_formats:
415
+ if isinstance(fmt, SeqWriter):
416
+ fmt.writeseq(map(str, args))
417
+
418
+
419
+ def get_rank_without_mpi_import():
420
+ # check environment variables here instead of importing mpi4py
421
+ # to avoid calling MPI_Init() when this module is imported
422
+ for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
423
+ if varname in os.environ:
424
+ return int(os.environ[varname])
425
+ return 0
426
+
427
+
428
+ def mpi_weighted_mean(comm, local_name2valcount):
429
+ """
430
+ Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
431
+ Perform a weighted average over dicts that are each on a different node
432
+ Input: local_name2valcount: dict mapping key -> (value, count)
433
+ Returns: key -> mean
434
+ """
435
+ all_name2valcount = comm.gather(local_name2valcount)
436
+ if comm.rank == 0:
437
+ name2sum = defaultdict(float)
438
+ name2count = defaultdict(float)
439
+ for n2vc in all_name2valcount:
440
+ for (name, (val, count)) in n2vc.items():
441
+ try:
442
+ val = float(val)
443
+ except ValueError:
444
+ if comm.rank == 0:
445
+ warnings.warn(
446
+ "WARNING: tried to compute mean on non-float {}={}".format(
447
+ name, val
448
+ )
449
+ )
450
+ else:
451
+ name2sum[name] += val * count
452
+ name2count[name] += count
453
+ return {name: name2sum[name] / name2count[name] for name in name2sum}
454
+ else:
455
+ return {}
456
+
457
+
458
+ def configure(args=None, format_strs=None, comm=None, log_suffix=""):
459
+ """
460
+ If comm is provided, average all numerical stats across that comm
461
+ """
462
+ dir = args.dir
463
+ if dir is not None:
464
+ if "loggings" not in dir: # save under cur dir
465
+ dir = osp.join("loggings", dir)
466
+ else:
467
+ if dir is None:
468
+ dir = os.getenv("OPENAI_LOGDIR")
469
+ if dir is None:
470
+ dir = osp.join(
471
+ # tempfile.gettempdir(),
472
+ "loggings",
473
+ datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
474
+ )
475
+ assert isinstance(dir, str)
476
+ dir = os.path.expanduser(dir)
477
+ os.makedirs(os.path.expanduser(dir), exist_ok=True)
478
+ if args.training:
479
+ # make dir for samples and checkpoints if training the model
480
+ os.makedirs(os.path.expanduser(osp.join(dir, "samples")), exist_ok=True)
481
+ os.makedirs(os.path.expanduser(osp.join(dir, "checkpoints")), exist_ok=True)
482
+
483
+ rank = get_rank_without_mpi_import()
484
+ if rank > 0:
485
+ log_suffix = log_suffix + "-rank%03i" % rank
486
+
487
+ if format_strs is None:
488
+ if rank == 0:
489
+ format_strs = os.getenv("OPENAI_LOG_FORMAT", "wandb,stdout,log,csv").split(",")
490
+ else:
491
+ format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
492
+ format_strs = filter(None, format_strs)
493
+ output_formats = [make_output_format(f, dir, args, log_suffix) for f in format_strs]
494
+
495
+ Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
496
+ if output_formats:
497
+ log("Logging to %s" % dir)
498
+
499
+
500
+ def _configure_default_logger():
501
+ configure()
502
+ Logger.DEFAULT = Logger.CURRENT
503
+
504
+
505
+ def reset():
506
+ if Logger.CURRENT is not Logger.DEFAULT:
507
+ Logger.CURRENT.close()
508
+ Logger.CURRENT = Logger.DEFAULT
509
+ log("Reset logger")
510
+
511
+
512
+ @contextmanager
513
+ def scoped_configure(dir=None, format_strs=None, comm=None):
514
+ prevlogger = Logger.CURRENT
515
+ configure(dir=dir, format_strs=format_strs, comm=comm)
516
+ try:
517
+ yield
518
+ finally:
519
+ Logger.CURRENT.close()
520
+ Logger.CURRENT = prevlogger
521
+
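A minimal usage sketch for this logger (it assumes an `args` namespace exposing the fields `configure` reads: `dir`, `training`, and `project` when the wandb format is enabled; the values below are placeholders):

```python
from types import SimpleNamespace
from guided_diffusion import logger

args = SimpleNamespace(dir="demo_run", training=False, project="demo")
logger.configure(args=args, format_strs=["stdout", "csv"])  # skip wandb for a quick local run

for step in range(3):
    logger.logkv("step", step)
    logger.logkv_mean("loss", 0.1 * step)  # running mean if logged several times per iteration
    logger.dumpkvs()                       # flushes one row to stdout and progress.csv
```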
guided_diffusion/losses.py ADDED
@@ -0,0 +1,77 @@
1
+ """
2
+ Helpers for various likelihood-based losses. These are ported from the original
3
+ Ho et al. diffusion models codebase:
4
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5
+ """
6
+
7
+ import numpy as np
8
+
9
+ import torch as th
10
+
11
+
12
+ def normal_kl(mean1, logvar1, mean2, logvar2):
13
+ """
14
+ Compute the KL divergence between two gaussians.
15
+
16
+ Shapes are automatically broadcasted, so batches can be compared to
17
+ scalars, among other use cases.
18
+ """
19
+ tensor = None
20
+ for obj in (mean1, logvar1, mean2, logvar2):
21
+ if isinstance(obj, th.Tensor):
22
+ tensor = obj
23
+ break
24
+ assert tensor is not None, "at least one argument must be a Tensor"
25
+
26
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
27
+ # Tensors, but it does not work for th.exp().
28
+ logvar1, logvar2 = [
29
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30
+ for x in (logvar1, logvar2)
31
+ ]
32
+
33
+ return 0.5 * (
34
+ -1.0
35
+ + logvar2
36
+ - logvar1
37
+ + th.exp(logvar1 - logvar2)
38
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39
+ )
40
+
41
+
42
+ def approx_standard_normal_cdf(x):
43
+ """
44
+ A fast approximation of the cumulative distribution function of the
45
+ standard normal.
46
+ """
47
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48
+
49
+
50
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51
+ """
52
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
53
+ given image.
54
+
55
+ :param x: the target images. It is assumed that this was uint8 values,
56
+ rescaled to the range [-1, 1].
57
+ :param means: the Gaussian mean Tensor.
58
+ :param log_scales: the Gaussian log stddev Tensor.
59
+ :return: a tensor like x of log probabilities (in nats).
60
+ """
61
+ assert x.shape == means.shape == log_scales.shape
62
+ centered_x = x - means
63
+ inv_stdv = th.exp(-log_scales)
64
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65
+ cdf_plus = approx_standard_normal_cdf(plus_in)
66
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67
+ cdf_min = approx_standard_normal_cdf(min_in)
68
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70
+ cdf_delta = cdf_plus - cdf_min
71
+ log_probs = th.where(
72
+ x < -0.999,
73
+ log_cdf_plus,
74
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75
+ )
76
+ assert log_probs.shape == x.shape
77
+ return log_probs
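These are the standard DDPM likelihood helpers: `normal_kl` is the elementwise KL divergence between two diagonal Gaussians, and `discretized_gaussian_log_likelihood` scores data quantised to 256 levels in [-1, 1]. A small shape-check sketch with synthetic inputs:

```python
import torch as th
from guided_diffusion.losses import normal_kl, discretized_gaussian_log_likelihood

x = (th.randint(0, 256, (2, 1, 8, 8)).float() / 127.5) - 1.0  # fake uint8 data rescaled to [-1, 1]
means = th.zeros_like(x)
log_scales = th.full_like(x, -2.0)

kl = normal_kl(means, 2 * log_scales, th.zeros_like(x), th.zeros_like(x))  # KL(N(0, e^-4) || N(0, 1))
nll = -discretized_gaussian_log_likelihood(x, means=means, log_scales=log_scales)
print(kl.shape, nll.shape)  # both elementwise: torch.Size([2, 1, 8, 8])
```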
guided_diffusion/midi_util.py ADDED
@@ -0,0 +1,291 @@
1
+ import os
2
+ import math
3
+ import torch
4
+ import numpy as np
5
+ import pandas as pd
6
+ import pretty_midi
7
+ import matplotlib as mpl
8
+ import matplotlib.pyplot as plt
9
+ from . import dist_util
10
+ import yaml
11
+ from types import SimpleNamespace
12
+ from music_rule_guidance.piano_roll_to_chord import piano_roll_to_pretty_midi, KEY_DICT, IND2KEY
13
+ from music_rule_guidance.rule_maps import FUNC_DICT, LOSS_DICT
14
+ from music_rule_guidance.music_rules import MAX_PIANO, MIN_PIANO
15
+
16
+ plt.rcParams['figure.dpi'] = 300
17
+ plt.rcParams['savefig.dpi'] = 300
18
+
19
+ # bounds to compute classes for note density (nd) editing
20
+ VERTICAL_ND_BOUNDS = [1.29, 2.7578125, 3.61, 4.4921875, 5.28125, 6.1171875, 7.22]
21
+ VERTICAL_ND_CENTER = [0.56, 2.0239, 3.1839, 4.0511, 4.8867, 5.6992, 6.6686, 7.77]
22
+ HORIZONTAL_ND_BOUNDS = [1.8, 2.6, 3.2, 3.6, 4.4, 4.8, 5.8]
23
+ HORIZONTAL_ND_CENTER = [1.4, 2.2000, 2.9, 3.4, 4.0, 4.6, 5.3, 6.3]
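+ CC_SUSTAIN_PEDAL = 64  # assumed: standard MIDI CC number for the sustain pedal; used below but not defined or imported in this excerpt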
24
+
25
+
26
+ def dict_to_obj(d):
27
+ if isinstance(d, list):
28
+ d = [dict_to_obj(x) if isinstance(x, dict) else x for x in d]
29
+ if not isinstance(d, dict):
30
+ return d
31
+ return SimpleNamespace(**{k: dict_to_obj(v) for k, v in d.items()})
32
+
33
+
34
+ def load_config(filename):
35
+ with open(filename, 'r') as file:
36
+ data = yaml.safe_load(file)
37
+ # Convert the dictionary to an object
38
+ data_obj = dict_to_obj(data)
39
+ return data_obj
40
+
41
+
42
+ @torch.no_grad()
43
+ def decode_sample_for_midi(sample, embed_model=None, scale_factor=1., threshold=-0.95):
44
+ # decode latent samples to a long piano roll of [0, 127]
45
+ sample = sample / scale_factor
46
+
47
+ if embed_model is not None:
48
+ image_size_h = sample.shape[-2]
49
+ image_size_w = sample.shape[-1]
50
+ if image_size_h > image_size_w: # latents are transposed (raster-column order); pixel-space rolls need no permute
51
+ sample = sample.permute(0, 1, 3, 2) # vertical axis means pitch after transpose
52
+ num_latents = sample.shape[-1] // sample.shape[-2]
53
+ if image_size_h >= image_size_w:
54
+ sample = torch.chunk(sample, num_latents, dim=-1) # B x C x H x W
55
+ sample = torch.concat(sample, dim=0) # 1st second for all batch, 2nd second for all batch, ...
56
+ sample = embed_model.decode(sample)
57
+ if image_size_h >= image_size_w:
58
+ sample = torch.concat(torch.chunk(sample, num_latents, dim=0), dim=-1)
59
+
60
+ sample[sample <= threshold] = -1. # heuristically threshold the background
61
+ sample = ((sample + 1) * 63.5).clamp(0, 127).to(torch.uint8)
62
+ sample = sample.permute(0, 2, 3, 1)
63
+ sample = sample.contiguous()
64
+ return sample
65
+
66
+
67
+ def save_piano_roll_midi(sample, save_dir, fs=100, y=None, save_piano_roll=False, save_ind=0):
68
+ # input shape: B x 128 (pitch) x time (no pedal) or B x 2 (pedal) x 128 x time (with pedal)
69
+ fig_size = sample.shape[-1] // 128 * 3
70
+ plt.rcParams["figure.figsize"] = (fig_size, 3)
71
+ pedal = True if len(sample.shape) == 4 else False
72
+ onset = True if sample.shape[1] == 3 else False
73
+ for i in range(sample.shape[0]):
74
+ cur_sample = sample[i]
75
+ if cur_sample.shape[-1] < 5000 and save_piano_roll: # do not save piano rolls that are too long
76
+ if pedal:
77
+ plt.imshow(cur_sample[0, ::-1], vmin=0, vmax=127)
78
+ else:
79
+ plt.imshow(cur_sample[::-1], vmin=0, vmax=127)
80
+ plt.savefig(os.path.join(save_dir, "prsample_" + str(i) + ".png"))
81
+ if onset:
82
+ # add onset for first column
83
+ first_column = cur_sample[0, :, 0]
84
+ first_onset_pitch = first_column.nonzero()[0]
85
+ cur_sample[1, first_onset_pitch, 0] = 127
86
+ cur_sample = cur_sample.astype(np.float32)
87
+ pm = piano_roll_to_pretty_midi(cur_sample, fs=fs)
88
+ if y is not None:
89
+ save_name = 'sample_' + str(i + save_ind) + '_y_' + str(y[i].item()) + '.midi'
90
+ else:
91
+ save_name = 'sample_' + str(i + save_ind) + '.midi'
92
+ pm.write(os.path.join(save_dir, save_name))
93
+ return
94
+
95
+
96
+ def eval_rule_loss(generated_samples, target_rules):
97
+ results = {}
98
+ batch_size = generated_samples.shape[0]
99
+ for rule_name, rule_target in target_rules.items():
100
+ rule_target_list = rule_target.tolist()
101
+ if batch_size == 1:
102
+ rule_target_list = [rule_target_list]
103
+ results[rule_name + '.target_rule'] = rule_target_list
104
+ rule_target = rule_target.to(generated_samples.device)
105
+ if 'chord' in rule_name:
106
+ gen_rule, key, corr = FUNC_DICT[rule_name](generated_samples, return_key=True)
107
+ key_strings = [IND2KEY[key_ind] for key_ind in key]
108
+ loss = LOSS_DICT[rule_name](gen_rule, rule_target)
109
+ mean_loss, std_loss, gen_rule, loss = loss.mean(), loss.std(), gen_rule.tolist(), loss.tolist()
110
+ if batch_size == 1:
111
+ gen_rule = [gen_rule]
112
+ results[rule_name + '.gen_rule'] = gen_rule
113
+ results[rule_name + '.key_str'] = key_strings
114
+ results[rule_name + '.key_corr'] = corr
115
+ results[rule_name + '.loss'] = loss
116
+ else:
117
+ gen_rule = FUNC_DICT[rule_name](generated_samples)
118
+ loss = LOSS_DICT[rule_name](gen_rule, rule_target)
119
+ mean_loss, std_loss, gen_rule, loss = loss.mean(), loss.std(), gen_rule.tolist(), loss.tolist()
120
+ if batch_size == 1:
121
+ gen_rule = [gen_rule]
122
+ results[rule_name + '.gen_rule'] = gen_rule
123
+ results[rule_name + '.loss'] = loss
124
+ return pd.DataFrame(results)
125
+
126
+
127
+ def compute_rule(generated_samples, orig_samples, target_rules):
128
+ results = {}
129
+ batch_size = generated_samples.shape[0]
130
+ for rule_name in target_rules:
131
+ rule_target = FUNC_DICT[rule_name](orig_samples)
132
+ rule_target_list = rule_target.tolist()
133
+ if batch_size == 1:
134
+ rule_target_list = [rule_target_list]
135
+ results[rule_name + '.target_rule'] = rule_target_list
136
+ rule_target = rule_target.to(generated_samples.device)
137
+ if rule_name == 'chord_progression':
138
+ gen_rule, key, corr = FUNC_DICT[rule_name](generated_samples, return_key=True)
139
+ key_strings = [IND2KEY[key_ind] for key_ind in key]
140
+ loss = LOSS_DICT[rule_name](gen_rule, rule_target)
141
+ mean_loss, std_loss, gen_rule, loss = loss.mean(), loss.std(), gen_rule.tolist(), loss.tolist()
142
+ if batch_size == 1:
143
+ gen_rule = [gen_rule]
144
+ results[rule_name + '.gen_rule'] = gen_rule
145
+ results[rule_name + '.key_str'] = key_strings
146
+ results[rule_name + '.key_corr'] = corr
147
+ results[rule_name + '.loss'] = loss
148
+ else:
149
+ gen_rule = FUNC_DICT[rule_name](generated_samples)
150
+ loss = LOSS_DICT[rule_name](gen_rule, rule_target)
151
+ mean_loss, std_loss, gen_rule, loss = loss.mean(), loss.std(), gen_rule.tolist(), loss.tolist()
152
+ if batch_size == 1:
153
+ gen_rule = [gen_rule]
154
+ results[rule_name + '.gen_rule'] = gen_rule
155
+ results[rule_name + '.loss'] = loss
156
+ return pd.DataFrame(results)
157
+
158
+
159
+ def visualize_piano_roll(piano_roll):
160
+ """
161
+ Assumes the piano roll has shape B x 1 x 128 x 1024 with values in [-1, 1], on GPU.
162
+ Visualizes the first 256 and last 256 time steps with a gap in between.
163
+ """
164
+ piano_roll = torch.flip(piano_roll, [2])
165
+ piano_roll = (piano_roll + 1) / 2.
166
+ vis_length = 256
167
+ gap = 80
168
+ plt.rcParams["figure.figsize"] = (12, 3)
169
+ data = torch.zeros(128, vis_length * 2 + gap)
170
+ data[:, :vis_length] = piano_roll[0, 0, :, :vis_length]
171
+ data[:, -vis_length:] = piano_roll[0, 0, :, -vis_length:]
172
+ data_clone = data.clone()
173
+ # make it look thicker
174
+ data[1:, :] = data[1:, :] + data_clone[:-1, :]
175
+ data[2:, :] = data[2:, :] + data_clone[:-2, :]
176
+ data = data.cpu().numpy()
177
+ plt.imshow(data, cmap=mpl.colormaps['Blues'])
178
+ ax = plt.gca() # gca stands for 'get current axis'
179
+ for edge, spine in ax.spines.items():
180
+ spine.set_linewidth(2) # Adjust the value as per your requirement
181
+ plt.grid(color='gray', linestyle='-', linewidth=2., alpha=0.5, which='both', axis='x')
182
+ plt.xticks(
183
+ np.concatenate((np.arange(0, vis_length + 1, 128), np.arange(vis_length + gap, vis_length * 2 + gap, 128))))
184
+ # plt.savefig('piano_roll_example.png', bbox_inches='tight', pad_inches=0.1, dpi=300)
185
+ plt.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)
186
+ plt.tight_layout()
187
+ plt.show()
188
+
189
+ plt.rcParams["figure.figsize"] = (3, 3)
190
+ for i in range(2):
191
+ plt.imshow(data[:, i*128: (i+1)*128], cmap=mpl.colormaps['Blues'])
192
+ ax = plt.gca()
193
+ for edge, spine in ax.spines.items():
194
+ spine.set_linewidth(2)
195
+ plt.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)
196
+ plt.tight_layout()
197
+ plt.show()
198
+
199
+ for i in range(-2, 0):
200
+ if (i+1)*128 < 0:
201
+ plt.imshow(data[:, i*128: (i+1)*128], cmap=mpl.colormaps['Blues'])
202
+ else:
203
+ plt.imshow(data[:, i*128:], cmap=mpl.colormaps['Blues'])
204
+ ax = plt.gca()
205
+ for edge, spine in ax.spines.items():
206
+ spine.set_linewidth(2)
207
+ plt.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)
208
+ plt.tight_layout()
209
+ plt.show()
210
+
211
+ return
212
+
213
+
214
+ def visualize_full_piano_roll(midi_file_name, fs=100):
215
+ """
216
+ Visualize the full piano roll from a MIDI file.
217
+ """
218
+ midi_data = pretty_midi.PrettyMIDI(midi_file_name)
219
+ # do not process sustain pedal
220
+ piano_roll = torch.tensor(midi_data.get_piano_roll(fs=fs, pedal_threshold=None))
221
+ data = torch.flip(piano_roll, [0])
222
+ plt.rcParams["figure.figsize"] = (12, 3)
223
+ # data_clone = data.clone()
224
+ # # make it look thicker
225
+ # data[1:, :] = data[1:, :] + data_clone[:-1, :]
226
+ # data[2:, :] = data[2:, :] + data_clone[:-2, :]
227
+ data = data.cpu().numpy()
228
+ plt.imshow(data, cmap=mpl.colormaps['Blues'])
229
+ ax = plt.gca() # gca stands for 'get current axis'
230
+ for edge, spine in ax.spines.items():
231
+ spine.set_linewidth(2) # Adjust the value as per your requirement
232
+ plt.grid(color='gray', linestyle='-', linewidth=2., alpha=0.5, which='both', axis='x')
233
+ plt.xticks(np.arange(0, piano_roll.shape[1], 128))
234
+ # plt.savefig('piano_roll_example.png', bbox_inches='tight', pad_inches=0.1, dpi=300)
235
+ plt.tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)
236
+ plt.tight_layout()
237
+ plt.show()
238
+ return
239
+
240
+
241
+ def plot_record(vals, title, save_dir):
242
+ ts = [item[0] for item in vals]
243
+ log_probs = [item[1] for item in vals]
244
+ plt.plot(ts, log_probs)
245
+ plt.gca().invert_xaxis()
246
+ plt.title(title)
247
+ plt.savefig(save_dir + '/' + title + '.png')
248
+ plt.show()
249
+ return
250
+
251
+
252
+ def quantize_pedal(value, num_bins=8):
253
+ """Quantize an integer value from 0 to 127 into 8 bins and return the center value of the bin."""
254
+ if value < 0 or value > 127:
255
+ raise ValueError("Value should be between 0 and 127")
256
+ # Determine bin size
257
+ bin_size = 128 // num_bins # 16
258
+ # Quantize the value
259
+ bin_index = value // bin_size
260
+ bin_center = bin_size * bin_index + bin_size // 2
261
+ # Handle edge case for the last bin
262
+ if bin_center > 127:
263
+ bin_center = 127
264
+ return bin_center
265
+
266
+
267
+ def get_full_piano_roll(midi_data, fs, show=False):
268
+ # do not process sustain pedal
269
+ piano_roll, onset_roll = midi_data.get_piano_roll(fs=fs, pedal_threshold=None, onset=True)
270
+ # save pedal roll explicitly
271
+ pedal_roll = np.zeros_like(piano_roll)
272
+ # process pedal
273
+ for instru in midi_data.instruments:
274
+ pedal_changes = [_e for _e in instru.control_changes if _e.number == CC_SUSTAIN_PEDAL]
275
+ for cc in pedal_changes:
276
+ time_now = int(cc.time * fs)
277
+ if time_now < pedal_roll.shape[-1]:
278
+ # need to distinguish a control_change value of 0 from the background 0: after quantization the 0-15 bin becomes 8
279
+ # in MuseScore files, a 0 can be immediately followed by 127, so the new value is shifted to a later column
280
+ if pedal_roll[MIN_PIANO, time_now] != 0. and abs(pedal_roll[MIN_PIANO, time_now] - cc.value) > 64:
281
+ # use shift 2 here to prevent missing change when using interpolation augmentation
282
+ pedal_roll[MIN_PIANO:MAX_PIANO + 1, min(time_now + 2, pedal_roll.shape[-1] - 1)] = quantize_pedal(cc.value)
283
+ else:
284
+ pedal_roll[MIN_PIANO:MAX_PIANO + 1, time_now] = quantize_pedal(cc.value)
285
+ full_roll = np.concatenate((piano_roll[None], onset_roll[None], pedal_roll[None]), axis=0)
286
+ if show:
287
+ plt.imshow(piano_roll[::-1, :1024], vmin=0, vmax=127)
288
+ plt.show()
289
+ plt.imshow(pedal_roll[::-1, :1024], vmin=0, vmax=127)
290
+ plt.show()
291
+ return full_roll
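
As a quick sanity check of the pedal handling above, `quantize_pedal` maps raw CC64 values onto the centres of 8 equal bins (8, 24, ..., 120). A minimal sketch, assuming the `guided` environment is active and this module imports as `guided_diffusion.midi_util`:

```python
from guided_diffusion.midi_util import quantize_pedal

# 8 bins of width 16; each value is replaced by its bin centre.
assert quantize_pedal(0) == 8      # bin 0 -> 8, so a written pedal-off stays distinguishable from background 0
assert quantize_pedal(100) == 104  # 100 // 16 = 6 -> 6 * 16 + 8 = 104
assert quantize_pedal(127) == 120  # last bin (112-127) -> centre 120
```
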
guided_diffusion/nn.py ADDED
@@ -0,0 +1,170 @@
1
+ """
2
+ Various utilities for neural networks.
3
+ """
4
+
5
+ import math
6
+
7
+ import torch as th
8
+ import torch.nn as nn
9
+
10
+
11
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12
+ class SiLU(nn.Module):
13
+ def forward(self, x):
14
+ return x * th.sigmoid(x)
15
+
16
+
17
+ class GroupNorm32(nn.GroupNorm):
18
+ def forward(self, x):
19
+ return super().forward(x.float()).type(x.dtype)
20
+
21
+
22
+ def conv_nd(dims, *args, **kwargs):
23
+ """
24
+ Create a 1D, 2D, or 3D convolution module.
25
+ """
26
+ if dims == 1:
27
+ return nn.Conv1d(*args, **kwargs)
28
+ elif dims == 2:
29
+ return nn.Conv2d(*args, **kwargs)
30
+ elif dims == 3:
31
+ return nn.Conv3d(*args, **kwargs)
32
+ raise ValueError(f"unsupported dimensions: {dims}")
33
+
34
+
35
+ def linear(*args, **kwargs):
36
+ """
37
+ Create a linear module.
38
+ """
39
+ return nn.Linear(*args, **kwargs)
40
+
41
+
42
+ def avg_pool_nd(dims, *args, **kwargs):
43
+ """
44
+ Create a 1D, 2D, or 3D average pooling module.
45
+ """
46
+ if dims == 1:
47
+ return nn.AvgPool1d(*args, **kwargs)
48
+ elif dims == 2:
49
+ return nn.AvgPool2d(*args, **kwargs)
50
+ elif dims == 3:
51
+ return nn.AvgPool3d(*args, **kwargs)
52
+ raise ValueError(f"unsupported dimensions: {dims}")
53
+
54
+
55
+ def update_ema(target_params, source_params, rate=0.99):
56
+ """
57
+ Update target parameters to be closer to those of source parameters using
58
+ an exponential moving average.
59
+
60
+ :param target_params: the target parameter sequence.
61
+ :param source_params: the source parameter sequence.
62
+ :param rate: the EMA rate (closer to 1 means slower).
63
+ """
64
+ for targ, src in zip(target_params, source_params):
65
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66
+
67
+
68
+ def zero_module(module):
69
+ """
70
+ Zero out the parameters of a module and return it.
71
+ """
72
+ for p in module.parameters():
73
+ p.detach().zero_()
74
+ return module
75
+
76
+
77
+ def scale_module(module, scale):
78
+ """
79
+ Scale the parameters of a module and return it.
80
+ """
81
+ for p in module.parameters():
82
+ p.detach().mul_(scale)
83
+ return module
84
+
85
+
86
+ def mean_flat(tensor):
87
+ """
88
+ Take the mean over all non-batch dimensions.
89
+ """
90
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
91
+
92
+
93
+ def normalization(channels):
94
+ """
95
+ Make a standard normalization layer.
96
+
97
+ :param channels: number of input channels.
98
+ :return: an nn.Module for normalization.
99
+ """
100
+ return GroupNorm32(32, channels)
101
+
102
+
103
+ def timestep_embedding(timesteps, dim, max_period=10000):
104
+ """
105
+ Create sinusoidal timestep embeddings.
106
+
107
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
108
+ These may be fractional.
109
+ :param dim: the dimension of the output.
110
+ :param max_period: controls the minimum frequency of the embeddings.
111
+ :return: an [N x dim] Tensor of positional embeddings.
112
+ """
113
+ half = dim // 2
114
+ freqs = th.exp(
115
+ -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116
+ ).to(device=timesteps.device)
117
+ args = timesteps[:, None].float() * freqs[None]
118
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119
+ if dim % 2:
120
+ embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121
+ return embedding
122
+
123
+
124
+ def checkpoint(func, inputs, params, flag):
125
+ """
126
+ Evaluate a function without caching intermediate activations, allowing for
127
+ reduced memory at the expense of extra compute in the backward pass.
128
+
129
+ :param func: the function to evaluate.
130
+ :param inputs: the argument sequence to pass to `func`.
131
+ :param params: a sequence of parameters `func` depends on but does not
132
+ explicitly take as arguments.
133
+ :param flag: if False, disable gradient checkpointing.
134
+ """
135
+ if flag:
136
+ args = tuple(inputs) + tuple(params)
137
+ return CheckpointFunction.apply(func, len(inputs), *args)
138
+ else:
139
+ return func(*inputs)
140
+
141
+
142
+ class CheckpointFunction(th.autograd.Function):
143
+ @staticmethod
144
+ def forward(ctx, run_function, length, *args):
145
+ ctx.run_function = run_function
146
+ ctx.input_tensors = list(args[:length])
147
+ ctx.input_params = list(args[length:])
148
+ with th.no_grad():
149
+ output_tensors = ctx.run_function(*ctx.input_tensors)
150
+ return output_tensors
151
+
152
+ @staticmethod
153
+ def backward(ctx, *output_grads):
154
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155
+ with th.enable_grad():
156
+ # Fixes a bug where the first op in run_function modifies the
157
+ # Tensor storage in place, which is not allowed for detach()'d
158
+ # Tensors.
159
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160
+ output_tensors = ctx.run_function(*shallow_copies)
161
+ input_grads = th.autograd.grad(
162
+ output_tensors,
163
+ ctx.input_tensors + ctx.input_params,
164
+ output_grads,
165
+ allow_unused=True,
166
+ )
167
+ del ctx.input_tensors
168
+ del ctx.input_params
169
+ del output_tensors
170
+ return (None, None) + input_grads
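
A small usage sketch for the helpers above (shapes only; the exact values depend on `max_period`), assuming the package imports as `guided_diffusion.nn`:

```python
import torch as th
from guided_diffusion.nn import timestep_embedding, mean_flat

t = th.tensor([0, 250, 500, 999])      # one (possibly fractional) timestep per batch element
emb = timestep_embedding(t, dim=128)   # cos terms then sin terms, concatenated
assert emb.shape == (4, 128)

x = th.randn(4, 3, 8, 8)
assert mean_flat(x).shape == (4,)      # mean over all non-batch dimensions
```
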
guided_diffusion/pr_datasets_all.py ADDED
@@ -0,0 +1,183 @@
1
+ import math
2
+ import random
3
+ import os
4
+ import pandas as pd
5
+ import csv
6
+ import re
7
+ from PIL import Image
8
+ import blobfile as bf
9
+ from mpi4py import MPI
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch.utils.data import DataLoader, Dataset
14
+
15
+ from music_rule_guidance import music_rules
16
+ from music_rule_guidance.rule_maps import FUNC_DICT
17
+
18
+ import matplotlib.pyplot as plt
19
+ plt.rcParams["figure.figsize"] = (6,3)
20
+ plt.rcParams['figure.dpi'] = 300
21
+ plt.rcParams['savefig.dpi'] = 300
22
+
23
+ # This file loads the merged dataset, with y holding the dataset label
24
+
25
+
26
+ def load_data(
27
+ *,
28
+ data_dir,
29
+ batch_size,
30
+ class_cond=False,
31
+ deterministic=False,
32
+ image_size=1024,
33
+ rule=None,
34
+ ):
35
+ """
36
+ For a dataset, create a generator over (images, kwargs) pairs.
37
+
38
+ Each image is an NCHW float tensor, and the kwargs dict contains zero or
39
+ more keys, each of which maps to a batched Tensor of its own.
40
+ The kwargs dict can be used for class labels, in which case the key is "y"
41
+ and the values are integer tensors of class labels.
42
+
43
+ :param data_dir: the csv file that contains all the data paths and classes.
44
+ :param batch_size: the batch size of each returned pair.
45
+ :param image_size: the size to which images are resized.
46
+ :param class_cond: if True, include a "y" key in returned dicts for class
47
+ label. If classes are not available and this is true, an
48
+ exception will be raised.
49
+ :param deterministic: if True, yield results in a deterministic order.
50
+ :param rule: a str that contains the name of the rule
51
+ """
52
+
53
+ df = pd.read_csv(data_dir)
54
+ all_files = df['midi_filename'].tolist()
55
+ classes = None
56
+ if class_cond:
57
+ classes = df['classes'].tolist()
58
+ if deterministic:
59
+ dataset = ImageDataset(
60
+ all_files,
61
+ classes=classes,
62
+ shard=MPI.COMM_WORLD.Get_rank(),
63
+ num_shards=MPI.COMM_WORLD.Get_size(),
64
+ image_size=image_size,
65
+ rule=rule,
66
+ pitch_shift=False,
67
+ time_stretch=False,
68
+ )
69
+ else:
70
+ dataset = ImageDataset(
71
+ all_files,
72
+ classes=classes,
73
+ shard=MPI.COMM_WORLD.Get_rank(),
74
+ num_shards=MPI.COMM_WORLD.Get_size(),
75
+ image_size=image_size,
76
+ rule=rule,
77
+ )
78
+ if deterministic:
79
+ loader = DataLoader(
80
+ dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
81
+ )
82
+ else:
83
+ loader = DataLoader(
84
+ dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
85
+ )
86
+ while True:
87
+ yield from loader
88
+
89
+
90
+ def key_shift(x, k):
91
+ # apply shift on both notes and onset
92
+ # x sample (batch x 3 x pitch x time)
93
+ # k number of pitches to shift
94
+ # only apply on (batch x 2 x pitch x time) because no key shift on pedal
95
+
96
+ pitches_and_onsets = x[:, :2, :, :]
97
+ pedals = x[:, 2:, :, :]
98
+
99
+ if k > 0:
100
+ pitches_and_onsets = torch.cat((pitches_and_onsets[:, :, k:, :], pitches_and_onsets[:, :, 0:k, :]), dim=2)
101
+ elif k < 0:
102
+ pitches_and_onsets = torch.cat((pitches_and_onsets[:, :, -k:, :], pitches_and_onsets[:, :, 0:-k, :]), dim=2)
103
+
104
+ x = torch.cat((pitches_and_onsets, pedals), dim=1)
105
+ return music_rules.piano_like(x)
106
+
107
+
108
+ class ImageDataset(Dataset):
109
+ def __init__(
110
+ self,
111
+ image_paths,
112
+ classes=None,
113
+ rule=None,
114
+ shard=0,
115
+ num_shards=1,
116
+ image_size=1024,
117
+ pitch_shift=True,
118
+ time_stretch=True,
119
+ ):
120
+ super().__init__()
121
+ self.local_images = image_paths[shard:][::num_shards]
122
+ self.local_classes = None if classes is None else classes[shard:][::num_shards]
123
+ self.rule = rule
124
+ self.pitch_shift = pitch_shift
125
+ self.time_stretch = time_stretch
126
+ self.image_size = image_size
127
+
128
+ def __len__(self):
129
+ return len(self.local_images)
130
+
131
+ def __getitem__(self, idx):
132
+ path = self.local_images[idx]
133
+ arr = np.load(path)[np.newaxis] # 1 x 3 x 128 x time (piano roll, onset, pedal)
134
+ arr = arr.astype(np.float32) / 63.5 - 1
135
+ arr = torch.from_numpy(arr)
136
+
137
+ if self.time_stretch: # apply for both notes and pedal
138
+ pr_len = int(np.random.uniform(0.95, 1.05) * self.image_size)
139
+ start = np.random.randint(arr.shape[-1] - pr_len)
140
+ arr = arr[:, :, :, start:start+pr_len]
141
+ if pr_len < self.image_size: # stretching, prevent duplicating onsets
142
+ piano_pedal = arr[:, [0, 2], :, :]
143
+ piano_pedal = F.interpolate(piano_pedal, size=(128, self.image_size), mode='nearest')
144
+ onset_raw = arr[:, 1:2, :, :]
145
+ ind_a2b = (torch.arange(self.image_size)/self.image_size*pr_len).int()
146
+ ind = ind_a2b.diff().nonzero().squeeze() + 1
147
+ zero_tensor = torch.tensor([0])
148
+ ind = torch.concat((zero_tensor, ind))
149
+ onset = -torch.ones(1, 1, 128, self.image_size)
150
+ onset[:, :, :, ind] = onset_raw
151
+ arr = torch.concat((piano_pedal[:, :1, :, :], onset, piano_pedal[:, 1:, :, :]), dim=1)
152
+ if pr_len > self.image_size: # compressing: re-add onsets that may have been dropped, keep durations
153
+ arr = F.interpolate(arr, size=(128, self.image_size), mode='nearest')
154
+ piano = arr[:, :1, :, :]
155
+ first_column = piano[:, :, :, :1]
156
+ padded_piano = torch.concat((first_column, piano), dim=-1)
157
+ onset_online = torch.diff(padded_piano, dim=-1)
158
+ mask = onset_online > 0
159
+ arr[:, 1:2, :, :][mask] = 1
160
+ else:
161
+ arr = arr[:, :, :, :self.image_size]
162
+ if self.pitch_shift: # only apply for notes
163
+ k = np.random.randint(-6, 7) # generate randint from -6 to +6
164
+ arr = key_shift(arr, k)
165
+
166
+ arr = music_rules.piano_like(arr) # also set pedal roll to be 0 for non-piano pitches (match VAE training)
167
+
168
+ out_dict = {}
169
+ if self.rule is not None:
170
+ if 'chord' in self.rule: # predict chord and key jointly
171
+ chord, key, _ = FUNC_DICT[self.rule](arr, return_key=True)
172
+ out_dict["chord"] = chord
173
+ out_dict["key"] = np.array(key, dtype=np.int64)
174
+ else:
175
+ out_dict[self.rule] = FUNC_DICT[self.rule](arr)
176
+ if self.local_classes is not None:
177
+ out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
178
+ # debug
179
+ # out_dict["path"] = path
180
+ # Remove the extra dimensions to get back a 3D tensor: 2x128x128
181
+ arr = arr.squeeze(0)
182
+ return arr, out_dict
183
+
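
A hedged sketch of how this loader is typically driven. The csv path and batch size are illustrative; `load_data` expects a csv whose `midi_filename` column points at the chunked piano-roll `.npy` files (see the scripts under `datasets/`), and it runs under MPI:

```python
from guided_diffusion.pr_datasets_all import load_data

data = load_data(
    data_dir="datasets/all_midi.csv",  # illustrative; needs 'midi_filename' (and 'classes' if class_cond)
    batch_size=4,
    class_cond=True,
    image_size=1024,
    rule=None,                         # or a key of FUNC_DICT to attach rule labels
)
batch, cond = next(data)               # batch: B x 3 x 128 x image_size, roughly in [-1, 1]; cond may hold "y"
```
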
guided_diffusion/resample.py ADDED
@@ -0,0 +1,154 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import numpy as np
4
+ import torch as th
5
+ import torch.distributed as dist
6
+
7
+
8
+ def create_named_schedule_sampler(name, diffusion):
9
+ """
10
+ Create a ScheduleSampler from a library of pre-defined samplers.
11
+
12
+ :param name: the name of the sampler.
13
+ :param diffusion: the diffusion object to sample for.
14
+ """
15
+ if name == "uniform":
16
+ return UniformSampler(diffusion)
17
+ elif name == "loss-second-moment":
18
+ return LossSecondMomentResampler(diffusion)
19
+ else:
20
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
21
+
22
+
23
+ class ScheduleSampler(ABC):
24
+ """
25
+ A distribution over timesteps in the diffusion process, intended to reduce
26
+ variance of the objective.
27
+
28
+ By default, samplers perform unbiased importance sampling, in which the
29
+ objective's mean is unchanged.
30
+ However, subclasses may override sample() to change how the resampled
31
+ terms are reweighted, allowing for actual changes in the objective.
32
+ """
33
+
34
+ @abstractmethod
35
+ def weights(self):
36
+ """
37
+ Get a numpy array of weights, one per diffusion step.
38
+
39
+ The weights needn't be normalized, but must be positive.
40
+ """
41
+
42
+ def sample(self, batch_size, device):
43
+ """
44
+ Importance-sample timesteps for a batch.
45
+
46
+ :param batch_size: the number of timesteps.
47
+ :param device: the torch device to save to.
48
+ :return: a tuple (timesteps, weights):
49
+ - timesteps: a tensor of timestep indices.
50
+ - weights: a tensor of weights to scale the resulting losses.
51
+ """
52
+ w = self.weights()
53
+ p = w / np.sum(w)
54
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
55
+ indices = th.from_numpy(indices_np).long().to(device)
56
+ weights_np = 1 / (len(p) * p[indices_np])
57
+ weights = th.from_numpy(weights_np).float().to(device)
58
+ return indices, weights
59
+
60
+
61
+ class UniformSampler(ScheduleSampler):
62
+ def __init__(self, diffusion):
63
+ self.diffusion = diffusion
64
+ self._weights = np.ones([diffusion.num_timesteps])
65
+
66
+ def weights(self):
67
+ return self._weights
68
+
69
+
70
+ class LossAwareSampler(ScheduleSampler):
71
+ def update_with_local_losses(self, local_ts, local_losses):
72
+ """
73
+ Update the reweighting using losses from a model.
74
+
75
+ Call this method from each rank with a batch of timesteps and the
76
+ corresponding losses for each of those timesteps.
77
+ This method will perform synchronization to make sure all of the ranks
78
+ maintain the exact same reweighting.
79
+
80
+ :param local_ts: an integer Tensor of timesteps.
81
+ :param local_losses: a 1D Tensor of losses.
82
+ """
83
+ batch_sizes = [
84
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
85
+ for _ in range(dist.get_world_size())
86
+ ]
87
+ dist.all_gather(
88
+ batch_sizes,
89
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
90
+ )
91
+
92
+ # Pad all_gather batches to be the maximum batch size.
93
+ batch_sizes = [x.item() for x in batch_sizes]
94
+ max_bs = max(batch_sizes)
95
+
96
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
97
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
98
+ dist.all_gather(timestep_batches, local_ts)
99
+ dist.all_gather(loss_batches, local_losses)
100
+ timesteps = [
101
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
102
+ ]
103
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
104
+ self.update_with_all_losses(timesteps, losses)
105
+
106
+ @abstractmethod
107
+ def update_with_all_losses(self, ts, losses):
108
+ """
109
+ Update the reweighting using losses from a model.
110
+
111
+ Sub-classes should override this method to update the reweighting
112
+ using losses from the model.
113
+
114
+ This method directly updates the reweighting without synchronizing
115
+ between workers. It is called by update_with_local_losses from all
116
+ ranks with identical arguments. Thus, it should have deterministic
117
+ behavior to maintain state across workers.
118
+
119
+ :param ts: a list of int timesteps.
120
+ :param losses: a list of float losses, one per timestep.
121
+ """
122
+
123
+
124
+ class LossSecondMomentResampler(LossAwareSampler):
125
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
126
+ self.diffusion = diffusion
127
+ self.history_per_term = history_per_term
128
+ self.uniform_prob = uniform_prob
129
+ self._loss_history = np.zeros(
130
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
131
+ )
132
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)
133
+
134
+ def weights(self):
135
+ if not self._warmed_up():
136
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
137
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138
+ weights /= np.sum(weights)
139
+ weights *= 1 - self.uniform_prob
140
+ weights += self.uniform_prob / len(weights)
141
+ return weights
142
+
143
+ def update_with_all_losses(self, ts, losses):
144
+ for t, loss in zip(ts, losses):
145
+ if self._loss_counts[t] == self.history_per_term:
146
+ # Shift out the oldest loss term.
147
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
148
+ self._loss_history[t, -1] = loss
149
+ else:
150
+ self._loss_history[t, self._loss_counts[t]] = loss
151
+ self._loss_counts[t] += 1
152
+
153
+ def _warmed_up(self):
154
+ return (self._loss_counts == self.history_per_term).all()
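
For reference, a sketch of how a schedule sampler is used in the training loop; the diffusion object below is built with this repo's `create_gaussian_diffusion` defaults:

```python
import torch as th
from guided_diffusion.script_util import create_gaussian_diffusion
from guided_diffusion.resample import create_named_schedule_sampler

diffusion = create_gaussian_diffusion(steps=1000)   # defaults otherwise
sampler = create_named_schedule_sampler("uniform", diffusion)
t, weights = sampler.sample(batch_size=8, device=th.device("cpu"))
# t: 8 timestep indices in [0, 1000); with the uniform sampler the
# importance weights are all 1, so the training objective is unchanged.
```
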
guided_diffusion/respace.py ADDED
@@ -0,0 +1,128 @@
1
+ import numpy as np
2
+ import torch as th
3
+
4
+ from .gaussian_diffusion import GaussianDiffusion
5
+
6
+
7
+ def space_timesteps(num_timesteps, section_counts):
8
+ """
9
+ Create a list of timesteps to use from an original diffusion process,
10
+ given the number of timesteps we want to take from equally-sized portions
11
+ of the original process.
12
+
13
+ For example, if there are 300 timesteps and the section counts are [10,15,20]
14
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
15
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
16
+
17
+ If the stride is a string starting with "ddim", then the fixed striding
18
+ from the DDIM paper is used, and only one section is allowed.
19
+
20
+ :param num_timesteps: the number of diffusion steps in the original
21
+ process to divide up.
22
+ :param section_counts: either a list of numbers, or a string containing
23
+ comma-separated numbers, indicating the step count
24
+ per section. As a special case, use "ddimN" where N
25
+ is a number of steps to use the striding from the
26
+ DDIM paper.
27
+ :return: a set of diffusion steps from the original process to use.
28
+ """
29
+ if isinstance(section_counts, str):
30
+ if section_counts.startswith("ddim"):
31
+ desired_count = int(section_counts[len("ddim") :])
32
+ for i in range(1, num_timesteps):
33
+ if len(range(0, num_timesteps, i)) == desired_count:
34
+ return set(range(0, num_timesteps, i))
35
+ raise ValueError(
36
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
37
+ )
38
+ section_counts = [int(x) for x in section_counts.split(",")]
39
+ size_per = num_timesteps // len(section_counts)
40
+ extra = num_timesteps % len(section_counts)
41
+ start_idx = 0
42
+ all_steps = []
43
+ for i, section_count in enumerate(section_counts):
44
+ size = size_per + (1 if i < extra else 0)
45
+ if size < section_count:
46
+ raise ValueError(
47
+ f"cannot divide section of {size} steps into {section_count}"
48
+ )
49
+ if section_count <= 1:
50
+ frac_stride = 1
51
+ else:
52
+ frac_stride = (size - 1) / (section_count - 1)
53
+ cur_idx = 0.0
54
+ taken_steps = []
55
+ for _ in range(section_count):
56
+ taken_steps.append(start_idx + round(cur_idx))
57
+ cur_idx += frac_stride
58
+ all_steps += taken_steps
59
+ start_idx += size
60
+ return set(all_steps)
61
+
62
+
63
+ class SpacedDiffusion(GaussianDiffusion):
64
+ """
65
+ A diffusion process which can skip steps in a base diffusion process.
66
+
67
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
68
+ original diffusion process to retain.
69
+ :param kwargs: the kwargs to create the base diffusion process.
70
+ """
71
+
72
+ def __init__(self, use_timesteps, **kwargs):
73
+ self.use_timesteps = set(use_timesteps)
74
+ self.timestep_map = []
75
+ self.original_num_steps = len(kwargs["betas"])
76
+
77
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
78
+ last_alpha_cumprod = 1.0
79
+ new_betas = []
80
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
81
+ if i in self.use_timesteps:
82
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
83
+ last_alpha_cumprod = alpha_cumprod
84
+ self.timestep_map.append(i)
85
+ kwargs["betas"] = np.array(new_betas)
86
+ super().__init__(**kwargs)
87
+
88
+ def p_mean_variance(
89
+ self, model, *args, **kwargs
90
+ ): # pylint: disable=signature-differs
91
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
92
+
93
+ def training_losses(
94
+ self, model, *args, **kwargs
95
+ ): # pylint: disable=signature-differs
96
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
97
+
98
+ def condition_mean(self, cond_fn, *args, **kwargs):
99
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
100
+
101
+ def condition_score(self, cond_fn, *args, **kwargs):
102
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
103
+
104
+ def _wrap_model(self, model):
105
+ if isinstance(model, _WrappedModel):
106
+ return model
107
+ return _WrappedModel(
108
+ model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
109
+ )
110
+
111
+ def _scale_timesteps(self, t):
112
+ # Scaling is done by the wrapped model.
113
+ return t
114
+
115
+
116
+ class _WrappedModel:
117
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
118
+ self.model = model
119
+ self.timestep_map = timestep_map
120
+ self.rescale_timesteps = rescale_timesteps
121
+ self.original_num_steps = original_num_steps
122
+
123
+ def __call__(self, x, ts, **kwargs):
124
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
125
+ new_ts = map_tensor[ts]
126
+ if self.rescale_timesteps:
127
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
128
+ return self.model(x, new_ts, **kwargs)
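
The respacing helper is easiest to see on small examples; both cases below follow directly from the `space_timesteps` docstring:

```python
from guided_diffusion.respace import space_timesteps

# DDIM-style: a 1000-step process strided down to exactly 100 retained steps.
assert len(space_timesteps(1000, "ddim100")) == 100

# Three equal sections of a 300-step process, re-strided to 10/15/20 steps each.
assert len(space_timesteps(300, [10, 15, 20])) == 45
```
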
guided_diffusion/script_util.py ADDED
@@ -0,0 +1,531 @@
1
+ import argparse
2
+ import inspect
3
+ import torch.nn.functional as F
4
+
5
+ from music_rule_guidance import music_rules
6
+ from . import gaussian_diffusion as gd
7
+ from .respace import SpacedDiffusion, space_timesteps
8
+ from .unet import SuperResModel, UNetModel, EncoderUNetModel
9
+
10
+ NUM_CLASSES = 3 # number of datasets
11
+
12
+
13
+ def diffusion_defaults():
14
+ """
15
+ Defaults for image and classifier training.
16
+ """
17
+ return dict(
18
+ learn_sigma=False,
19
+ diffusion_steps=1000,
20
+ noise_schedule="linear",
21
+ timestep_respacing="",
22
+ use_kl=False,
23
+ predict_xstart=False,
24
+ rescale_timesteps=False,
25
+ rescale_learned_sigmas=False,
26
+ )
27
+
28
+
29
+ def classifier_defaults():
30
+ """
31
+ Defaults for classifier models.
32
+ """
33
+ return dict(
34
+ image_size=64,
35
+ in_channels=3,
36
+ classifier_use_fp16=False,
37
+ classifier_width=128,
38
+ classifier_depth=2,
39
+ classifier_attention_resolutions="32,16,8", # 16
40
+ classifier_use_scale_shift_norm=True, # False
41
+ classifier_resblock_updown=True, # False
42
+ classifier_pool="attention",
43
+ num_classes=3,
44
+ chord=False,
45
+ )
46
+
47
+
48
+ def model_and_diffusion_image_defaults():
49
+ """
50
+ Defaults for image training.
51
+ """
52
+ res = dict(
53
+ image_size=64,
54
+ in_channels=3,
55
+ num_channels=128,
56
+ num_res_blocks=2,
57
+ num_heads=4,
58
+ num_heads_upsample=-1,
59
+ num_head_channels=-1,
60
+ attention_resolutions="32,16,8",
61
+ channel_mult="",
62
+ dropout=0.0,
63
+ class_cond=False,
64
+ use_checkpoint=False,
65
+ use_scale_shift_norm=True,
66
+ resblock_updown=False,
67
+ use_fp16=False,
68
+ use_new_attention_order=False,
69
+ )
70
+ res.update(diffusion_defaults())
71
+ return res
72
+
73
+
74
+ def model_and_diffusion_defaults():
75
+ """
76
+ Defaults for piano roll training.
77
+ """
78
+ res = dict(
79
+ image_size=128,
80
+ in_channels=1,
81
+ num_channels=128,
82
+ num_res_blocks=2,
83
+ num_heads=4,
84
+ num_heads_upsample=-1,
85
+ num_head_channels=-1,
86
+ attention_resolutions="32,16,8",
87
+ channel_mult="",
88
+ dropout=0.0,
89
+ class_cond=False,
90
+ use_checkpoint=False,
91
+ use_scale_shift_norm=True,
92
+ resblock_updown=False,
93
+ use_fp16=False,
94
+ use_new_attention_order=False,
95
+ )
96
+ res.update(diffusion_defaults())
97
+ return res
98
+
99
+
100
+ def classifier_and_diffusion_defaults():
101
+ res = classifier_defaults()
102
+ res.update(diffusion_defaults())
103
+ return res
104
+
105
+
106
+ def create_diffusion(
107
+ learn_sigma,
108
+ diffusion_steps,
109
+ noise_schedule,
110
+ timestep_respacing,
111
+ use_kl,
112
+ predict_xstart,
113
+ rescale_timesteps,
114
+ rescale_learned_sigmas,
115
+ ):
116
+ diffusion = create_gaussian_diffusion(
117
+ steps=diffusion_steps,
118
+ learn_sigma=learn_sigma,
119
+ noise_schedule=noise_schedule,
120
+ use_kl=use_kl,
121
+ predict_xstart=predict_xstart,
122
+ rescale_timesteps=rescale_timesteps,
123
+ rescale_learned_sigmas=rescale_learned_sigmas,
124
+ timestep_respacing=timestep_respacing,
125
+ )
126
+ return diffusion
127
+
128
+
129
+ def create_model_and_diffusion(
130
+ image_size,
131
+ in_channels,
132
+ class_cond,
133
+ learn_sigma,
134
+ num_channels,
135
+ num_res_blocks,
136
+ channel_mult,
137
+ num_heads,
138
+ num_head_channels,
139
+ num_heads_upsample,
140
+ attention_resolutions,
141
+ dropout,
142
+ diffusion_steps,
143
+ noise_schedule,
144
+ timestep_respacing,
145
+ use_kl,
146
+ predict_xstart,
147
+ rescale_timesteps,
148
+ rescale_learned_sigmas,
149
+ use_checkpoint,
150
+ use_scale_shift_norm,
151
+ resblock_updown,
152
+ use_fp16,
153
+ use_new_attention_order,
154
+ ):
155
+ model = create_model(
156
+ image_size,
157
+ num_channels,
158
+ num_res_blocks,
159
+ in_channels,
160
+ channel_mult=channel_mult,
161
+ learn_sigma=learn_sigma,
162
+ class_cond=class_cond,
163
+ use_checkpoint=use_checkpoint,
164
+ attention_resolutions=attention_resolutions,
165
+ num_heads=num_heads,
166
+ num_head_channels=num_head_channels,
167
+ num_heads_upsample=num_heads_upsample,
168
+ use_scale_shift_norm=use_scale_shift_norm,
169
+ dropout=dropout,
170
+ resblock_updown=resblock_updown,
171
+ use_fp16=use_fp16,
172
+ use_new_attention_order=use_new_attention_order,
173
+ )
174
+ diffusion = create_gaussian_diffusion(
175
+ steps=diffusion_steps,
176
+ learn_sigma=learn_sigma,
177
+ noise_schedule=noise_schedule,
178
+ use_kl=use_kl,
179
+ predict_xstart=predict_xstart,
180
+ rescale_timesteps=rescale_timesteps,
181
+ rescale_learned_sigmas=rescale_learned_sigmas,
182
+ timestep_respacing=timestep_respacing,
183
+ )
184
+ return model, diffusion
185
+
186
+
187
+ def create_model(
188
+ image_size,
189
+ num_channels,
190
+ num_res_blocks,
191
+ in_channels=3,
192
+ channel_mult="",
193
+ learn_sigma=False,
194
+ class_cond=False,
195
+ use_checkpoint=False,
196
+ attention_resolutions="16",
197
+ num_heads=4,
198
+ num_head_channels=-1,
199
+ num_heads_upsample=-1,
200
+ use_scale_shift_norm=False,
201
+ dropout=0,
202
+ resblock_updown=False,
203
+ use_fp16=False,
204
+ use_new_attention_order=False,
205
+ ):
206
+ image_size = image_size[-1] # if H != W, use W as image_size
207
+ if channel_mult == "":
208
+ if image_size == 512:
209
+ channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
210
+ elif image_size == 256:
211
+ channel_mult = (1, 1, 2, 2, 4, 4)
212
+ elif image_size == 128:
213
+ channel_mult = (1, 1, 2, 3, 4)
214
+ elif image_size == 64:
215
+ channel_mult = (1, 2, 3, 4)
216
+ elif image_size == 32:
217
+ channel_mult = (1, 2, 2, 2)
218
+ elif image_size == 16:
219
+ channel_mult = (1, 2, 2)
220
+ else:
221
+ raise ValueError(f"unsupported image size: {image_size}")
222
+ else:
223
+ channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
224
+
225
+ attention_ds = []
226
+ for res in attention_resolutions.split(","):
227
+ attention_ds.append(image_size // int(res))
228
+
229
+ return UNetModel(
230
+ image_size=image_size,
231
+ in_channels=in_channels,
232
+ model_channels=num_channels,
233
+ out_channels=(in_channels if not learn_sigma else 2*in_channels),
234
+ num_res_blocks=num_res_blocks,
235
+ attention_resolutions=tuple(attention_ds),
236
+ dropout=dropout,
237
+ channel_mult=channel_mult,
238
+ num_classes=(NUM_CLASSES if class_cond else None),
239
+ use_checkpoint=use_checkpoint,
240
+ use_fp16=use_fp16,
241
+ num_heads=num_heads,
242
+ num_head_channels=num_head_channels,
243
+ num_heads_upsample=num_heads_upsample,
244
+ use_scale_shift_norm=use_scale_shift_norm,
245
+ resblock_updown=resblock_updown,
246
+ use_new_attention_order=use_new_attention_order,
247
+ )
248
+
249
+
250
+ def create_classifier_and_diffusion(
251
+ image_size,
252
+ in_channels,
253
+ classifier_use_fp16,
254
+ classifier_width,
255
+ classifier_depth,
256
+ classifier_attention_resolutions,
257
+ classifier_use_scale_shift_norm,
258
+ classifier_resblock_updown,
259
+ classifier_pool,
260
+ learn_sigma,
261
+ diffusion_steps,
262
+ noise_schedule,
263
+ timestep_respacing,
264
+ use_kl,
265
+ predict_xstart,
266
+ rescale_timesteps,
267
+ rescale_learned_sigmas,
268
+ num_classes,
269
+ chord,
270
+ ):
271
+ classifier = create_classifier(
272
+ image_size,
273
+ in_channels,
274
+ classifier_use_fp16,
275
+ classifier_width,
276
+ classifier_depth,
277
+ classifier_attention_resolutions,
278
+ classifier_use_scale_shift_norm,
279
+ classifier_resblock_updown,
280
+ classifier_pool,
281
+ num_classes,
282
+ chord,
283
+ )
284
+ diffusion = create_gaussian_diffusion(
285
+ steps=diffusion_steps,
286
+ learn_sigma=learn_sigma,
287
+ noise_schedule=noise_schedule,
288
+ use_kl=use_kl,
289
+ predict_xstart=predict_xstart,
290
+ rescale_timesteps=rescale_timesteps,
291
+ rescale_learned_sigmas=rescale_learned_sigmas,
292
+ timestep_respacing=timestep_respacing,
293
+ )
294
+ return classifier, diffusion
295
+
296
+
297
+ def create_classifier(
298
+ image_size,
299
+ in_channels,
300
+ classifier_use_fp16,
301
+ classifier_width,
302
+ classifier_depth,
303
+ classifier_attention_resolutions,
304
+ classifier_use_scale_shift_norm,
305
+ classifier_resblock_updown,
306
+ classifier_pool,
307
+ num_classes,
308
+ chord,
309
+ ):
310
+ image_size = image_size[-1] # if H != W, use W as image_size
311
+ if image_size == 512:
312
+ channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
313
+ elif image_size == 256:
314
+ channel_mult = (1, 1, 2, 2, 4, 4)
315
+ elif image_size == 128:
316
+ channel_mult = (1, 1, 2, 3, 4)
317
+ elif image_size == 64:
318
+ channel_mult = (1, 2, 3, 4)
319
+ elif image_size == 16: # debug data load in
320
+ channel_mult = (1, 2, 2)
321
+ else:
322
+ raise ValueError(f"unsupported image size: {image_size}")
323
+
324
+ attention_ds = []
325
+ for res in classifier_attention_resolutions.split(","):
326
+ attention_ds.append(image_size // int(res))
327
+
328
+ return EncoderUNetModel(
329
+ image_size=image_size,
330
+ in_channels=in_channels,
331
+ model_channels=classifier_width,
332
+ out_channels=num_classes,
333
+ num_res_blocks=classifier_depth,
334
+ attention_resolutions=tuple(attention_ds),
335
+ channel_mult=channel_mult,
336
+ use_fp16=classifier_use_fp16,
337
+ num_head_channels=64,
338
+ use_scale_shift_norm=classifier_use_scale_shift_norm,
339
+ resblock_updown=classifier_resblock_updown,
340
+ pool=classifier_pool,
341
+ chord=chord,
342
+ )
343
+
344
+
345
+ def sr_model_and_diffusion_defaults():
346
+ res = model_and_diffusion_defaults()
347
+ res["large_size"] = 256
348
+ res["small_size"] = 64
349
+ arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
350
+ for k in res.copy().keys():
351
+ if k not in arg_names:
352
+ del res[k]
353
+ return res
354
+
355
+
356
+ def sr_create_model_and_diffusion(
357
+ large_size,
358
+ small_size,
359
+ class_cond,
360
+ learn_sigma,
361
+ num_channels,
362
+ num_res_blocks,
363
+ num_heads,
364
+ num_head_channels,
365
+ num_heads_upsample,
366
+ attention_resolutions,
367
+ dropout,
368
+ diffusion_steps,
369
+ noise_schedule,
370
+ timestep_respacing,
371
+ use_kl,
372
+ predict_xstart,
373
+ rescale_timesteps,
374
+ rescale_learned_sigmas,
375
+ use_checkpoint,
376
+ use_scale_shift_norm,
377
+ resblock_updown,
378
+ use_fp16,
379
+ ):
380
+ model = sr_create_model(
381
+ large_size,
382
+ small_size,
383
+ num_channels,
384
+ num_res_blocks,
385
+ learn_sigma=learn_sigma,
386
+ class_cond=class_cond,
387
+ use_checkpoint=use_checkpoint,
388
+ attention_resolutions=attention_resolutions,
389
+ num_heads=num_heads,
390
+ num_head_channels=num_head_channels,
391
+ num_heads_upsample=num_heads_upsample,
392
+ use_scale_shift_norm=use_scale_shift_norm,
393
+ dropout=dropout,
394
+ resblock_updown=resblock_updown,
395
+ use_fp16=use_fp16,
396
+ )
397
+ diffusion = create_gaussian_diffusion(
398
+ steps=diffusion_steps,
399
+ learn_sigma=learn_sigma,
400
+ noise_schedule=noise_schedule,
401
+ use_kl=use_kl,
402
+ predict_xstart=predict_xstart,
403
+ rescale_timesteps=rescale_timesteps,
404
+ rescale_learned_sigmas=rescale_learned_sigmas,
405
+ timestep_respacing=timestep_respacing,
406
+ )
407
+ return model, diffusion
408
+
409
+
410
+ def sr_create_model(
411
+ large_size,
412
+ small_size,
413
+ num_channels,
414
+ num_res_blocks,
415
+ learn_sigma,
416
+ class_cond,
417
+ use_checkpoint,
418
+ attention_resolutions,
419
+ num_heads,
420
+ num_head_channels,
421
+ num_heads_upsample,
422
+ use_scale_shift_norm,
423
+ dropout,
424
+ resblock_updown,
425
+ use_fp16,
426
+ ):
427
+ _ = small_size # hack to prevent unused variable
428
+
429
+ if large_size == 512:
430
+ channel_mult = (1, 1, 2, 2, 4, 4)
431
+ elif large_size == 256:
432
+ channel_mult = (1, 1, 2, 2, 4, 4)
433
+ elif large_size == 64:
434
+ channel_mult = (1, 2, 3, 4)
435
+ else:
436
+ raise ValueError(f"unsupported large size: {large_size}")
437
+
438
+ attention_ds = []
439
+ for res in attention_resolutions.split(","):
440
+ attention_ds.append(large_size // int(res))
441
+
442
+ return SuperResModel(
443
+ image_size=large_size,
444
+ in_channels=3,
445
+ model_channels=num_channels,
446
+ out_channels=(3 if not learn_sigma else 6),
447
+ num_res_blocks=num_res_blocks,
448
+ attention_resolutions=tuple(attention_ds),
449
+ dropout=dropout,
450
+ channel_mult=channel_mult,
451
+ num_classes=(NUM_CLASSES if class_cond else None),
452
+ use_checkpoint=use_checkpoint,
453
+ num_heads=num_heads,
454
+ num_head_channels=num_head_channels,
455
+ num_heads_upsample=num_heads_upsample,
456
+ use_scale_shift_norm=use_scale_shift_norm,
457
+ resblock_updown=resblock_updown,
458
+ use_fp16=use_fp16,
459
+ )
460
+
461
+
462
+ def create_gaussian_diffusion(
463
+ *,
464
+ steps=1000,
465
+ learn_sigma=False,
466
+ sigma_small=False,
467
+ noise_schedule="linear",
468
+ use_kl=False,
469
+ predict_xstart=False,
470
+ rescale_timesteps=False,
471
+ rescale_learned_sigmas=False,
472
+ timestep_respacing="",
473
+ ):
474
+ betas = gd.get_named_beta_schedule(noise_schedule, steps)
475
+ if use_kl:
476
+ loss_type = gd.LossType.RESCALED_KL
477
+ elif rescale_learned_sigmas:
478
+ loss_type = gd.LossType.RESCALED_MSE
479
+ else:
480
+ loss_type = gd.LossType.MSE
481
+ if not timestep_respacing:
482
+ timestep_respacing = [steps]
483
+ return SpacedDiffusion(
484
+ use_timesteps=space_timesteps(steps, timestep_respacing),
485
+ betas=betas,
486
+ model_mean_type=(
487
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
488
+ ),
489
+ model_var_type=(
490
+ (
491
+ gd.ModelVarType.FIXED_LARGE
492
+ if not sigma_small
493
+ else gd.ModelVarType.FIXED_SMALL
494
+ )
495
+ if not learn_sigma
496
+ else gd.ModelVarType.LEARNED_RANGE
497
+ ),
498
+ loss_type=loss_type,
499
+ rescale_timesteps=rescale_timesteps,
500
+ )
501
+
502
+
503
+ def add_dict_to_argparser(parser, default_dict):
504
+ for k, v in default_dict.items():
505
+ v_type = type(v)
506
+ if v is None:
507
+ v_type = str
508
+ elif isinstance(v, bool):
509
+ v_type = str2bool
510
+ if k == 'image_size':
511
+ parser.add_argument(f"--{k}", nargs='+', default=v, type=v_type)
512
+ else:
513
+ parser.add_argument(f"--{k}", default=v, type=v_type)
514
+
515
+
516
+ def args_to_dict(args, keys):
517
+ return {k: getattr(args, k) for k in keys}
518
+
519
+
520
+ def str2bool(v):
521
+ """
522
+ https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
523
+ """
524
+ if isinstance(v, bool):
525
+ return v
526
+ if v.lower() in ("yes", "true", "t", "y", "1"):
527
+ return True
528
+ elif v.lower() in ("no", "false", "f", "n", "0"):
529
+ return False
530
+ else:
531
+ raise argparse.ArgumentTypeError("boolean value expected")
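
A hedged sketch of how the training/sampling scripts are expected to wire these defaults into `argparse` (the flag values are illustrative):

```python
import argparse
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    add_dict_to_argparser,
    args_to_dict,
)

defaults = model_and_diffusion_defaults()
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)

# image_size accepts one or more values (e.g. H W); booleans go through str2bool.
args = parser.parse_args(["--image_size", "128", "1024", "--use_fp16", "True"])
kwargs = args_to_dict(args, defaults.keys())  # ready for create_model_and_diffusion(**kwargs)
```
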
guided_diffusion/train_util.py ADDED
@@ -0,0 +1,475 @@
1
+ import copy
2
+ import functools
3
+ import os
4
+ import os.path as osp
5
+ import numpy as np
6
+ import math
7
+
8
+ import blobfile as bf
9
+ import torch as th
10
+ import torch.nn.functional as F
11
+ import torch.distributed as dist
12
+ from torch.nn.parallel.distributed import DistributedDataParallel as DDP
13
+ from torch.optim import AdamW
14
+
15
+ from . import dist_util, midi_util, logger
16
+ from .fp16_util import MixedPrecisionTrainer
17
+ from .nn import update_ema
18
+ from .resample import LossAwareSampler, UniformSampler
19
+ from taming.modules.distributions.distributions import DiagonalGaussianDistribution
20
+
21
+ # For ImageNet experiments, this was a good default value.
22
+ # We found that the lg_loss_scale quickly climbed to
23
+ # 20-21 within the first ~1K steps of training.
24
+ INITIAL_LOG_LOSS_SCALE = 20.0
25
+
26
+
27
+ class TrainLoop:
28
+ def __init__(
29
+ self,
30
+ *,
31
+ model,
32
+ eval_model,
33
+ diffusion,
34
+ data,
35
+ batch_size,
36
+ microbatch,
37
+ lr,
38
+ ema_rate,
39
+ log_interval,
40
+ save_interval,
41
+ resume_checkpoint,
42
+ embed_model=None,
43
+ use_fp16=False,
44
+ fp16_scale_growth=1e-3,
45
+ schedule_sampler=None,
46
+ weight_decay=0.0,
47
+ lr_anneal_steps=0,
48
+ eval_data=None,
49
+ eval_interval=-1,
50
+ eval_sample_batch_size=16,
51
+ total_num_gpus=1, # number of GPUs the training run uses; used to distribute classes across GPUs
52
+ eval_sample_use_ddim=True,
53
+ eval_sample_clip_denoised=True,
54
+ in_channels=1,
55
+ fs=100,
56
+ pedal=False, # whether to decode with pedal as the second channel
57
+ scale_factor=1.,
58
+ num_classes=0, # whether to use class_cond in sampling
59
+ microbatch_encode=-1,
60
+ encode_rep=4,
61
+ shift_size=4, # shift size when generating time-shifted samples from an encoding
62
+ ):
63
+ self.model = model
64
+ self.eval_model = eval_model
65
+ self.embed_model = embed_model
66
+ self.scale_factor = scale_factor
67
+ self.diffusion = diffusion
68
+ self.data = data
69
+ self.batch_size = batch_size
70
+ self.microbatch = microbatch if microbatch > 0 else batch_size
71
+ self.microbatch_encode = microbatch_encode
72
+ self.encode_rep = encode_rep
73
+ self.batch_size = self.batch_size // self.encode_rep # effective batch size
74
+ self.microbatch = self.microbatch // self.encode_rep
75
+ self.shift_size = shift_size # need to be compatible with encode_rep
76
+ self.lr = lr
77
+ self.ema_rate = (
78
+ [ema_rate]
79
+ if isinstance(ema_rate, float)
80
+ else [float(x) for x in ema_rate.split(",")]
81
+ )
82
+ self.log_interval = log_interval
83
+ self.save_interval = save_interval
84
+ self.resume_checkpoint = resume_checkpoint
85
+ self.use_fp16 = use_fp16
86
+ self.fp16_scale_growth = fp16_scale_growth
87
+ self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
88
+ self.weight_decay = weight_decay
89
+ self.lr_anneal_steps = lr_anneal_steps
90
+ # eval
91
+ self.eval_data = eval_data
92
+ self.eval_interval = eval_interval
93
+ self.total_num_gpus = total_num_gpus
94
+ self.eval_sample_batch_size = eval_sample_batch_size // self.total_num_gpus
95
+ self.eval_sample_use_ddim = eval_sample_use_ddim
96
+ self.eval_sample_clip_denoised = eval_sample_clip_denoised
97
+ self.in_channels = in_channels
98
+ self.fs = fs
99
+ self.pedal = pedal
100
+ self.num_classes = num_classes
101
+
102
+ self.step = 0
103
+ self.resume_step = 0
104
+ self.global_batch = self.batch_size * dist.get_world_size()
105
+
106
+ self.sync_cuda = th.cuda.is_available()
107
+
108
+ self._load_and_sync_parameters()
109
+ self.mp_trainer = MixedPrecisionTrainer(
110
+ model=self.model,
111
+ use_fp16=self.use_fp16,
112
+ fp16_scale_growth=fp16_scale_growth,
113
+ )
114
+
115
+ self.opt = AdamW(
116
+ self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
117
+ )
118
+ if self.resume_step:
119
+ self._load_optimizer_state()
120
+ # Model was resumed, either due to a restart or a checkpoint
121
+ # being specified at the command line.
122
+ self.ema_params = [
123
+ self._load_ema_parameters(rate) for rate in self.ema_rate
124
+ ]
125
+ else:
126
+ self.ema_params = [
127
+ copy.deepcopy(self.mp_trainer.master_params)
128
+ for _ in range(len(self.ema_rate))
129
+ ]
130
+
131
+ if th.cuda.is_available():
132
+ self.use_ddp = True
133
+ self.ddp_model = DDP(
134
+ self.model,
135
+ device_ids=[dist_util.dev()],
136
+ output_device=dist_util.dev(),
137
+ broadcast_buffers=False,
138
+ bucket_cap_mb=128,
139
+ find_unused_parameters=False,
140
+ )
141
+ else:
142
+ if dist.get_world_size() > 1:
143
+ logger.warn(
144
+ "Distributed training requires CUDA. "
145
+ "Gradients will not be synchronized properly!"
146
+ )
147
+ self.use_ddp = False
148
+ self.ddp_model = self.model
149
+
150
+ def _load_and_sync_parameters(self):
151
+ resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
152
+
153
+ if resume_checkpoint:
154
+ self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
155
+ logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
156
+ self.model.load_state_dict(
157
+ dist_util.load_state_dict(
158
+ resume_checkpoint, map_location=dist_util.dev()
159
+ )
160
+ )
161
+
162
+ dist_util.sync_params(self.model.parameters())
163
+
164
+ def _load_ema_parameters(self, rate):
165
+ ema_params = copy.deepcopy(self.mp_trainer.master_params)
166
+
167
+ main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
168
+ ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
169
+ if ema_checkpoint:
170
+ logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
171
+ state_dict = dist_util.load_state_dict(
172
+ ema_checkpoint, map_location=dist_util.dev()
173
+ )
174
+ ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)
175
+
176
+ dist_util.sync_params(ema_params)
177
+ return ema_params
178
+
179
+ def _load_optimizer_state(self):
180
+ main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
181
+ opt_checkpoint = bf.join(
182
+ bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
183
+ )
184
+ if bf.exists(opt_checkpoint):
185
+ logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
186
+ state_dict = dist_util.load_state_dict(
187
+ opt_checkpoint, map_location=dist_util.dev()
188
+ )
189
+ self.opt.load_state_dict(state_dict)
190
+
191
+ def run_loop(self):
192
+ while (
193
+ not self.lr_anneal_steps
194
+ or self.step + self.resume_step < self.lr_anneal_steps
195
+ ):
196
+ batch, cond = next(self.data)
197
+ dist.barrier()
198
+ self.run_step(batch, cond)
199
+ if self.eval_data is not None and self.step % self.eval_interval == 0:
200
+ batch_eval, cond_eval = next(self.eval_data)
201
+ self.run_step_eval(batch_eval, cond_eval)
202
+ if self.step % self.log_interval == 0:
203
+ logger.dumpkvs()
204
+ if self.step % self.save_interval == 0 and self.step != 0:
205
+ self.save()
206
+ # Run for a finite amount of time in integration tests.
207
+ if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
208
+ return
209
+ self.step += 1
210
+ # Save the last checkpoint if it wasn't already saved.
211
+ if (self.step - 1) % self.save_interval != 0:
212
+ self.save()
213
+
214
+ def run_step(self, batch, cond):
215
+ self.forward_backward(batch, cond)
216
+ took_step = self.mp_trainer.optimize(self.opt)
217
+ if took_step:
218
+ self._update_ema()
219
+ self._anneal_lr()
220
+ self.log_step()
221
+
222
+ def run_step_eval(self, batch, cond):
223
+ with th.no_grad():
224
+ # load in ema_params for eval_model in cpu, then move to gpu
225
+ # only use the first ema rate if there are multiple ema rate
226
+ ema_state_dict = self.mp_trainer.master_params_to_state_dict(self.ema_params[0])
227
+ # ema_state_dict_cpu = {k: v.cpu() for k, v in ema_state_dict.items()}
228
+ # self.eval_model.load_state_dict(ema_state_dict_cpu)
229
+ self.eval_model.load_state_dict(ema_state_dict)
230
+ # self.eval_model.to(dist_util.dev())
231
+ if self.use_fp16:
232
+ self.eval_model.convert_to_fp16()
233
+ self.eval_model.eval()
234
+ for i in range(0, batch.shape[0], self.microbatch):
235
+ micro = batch[i: i + self.microbatch].to(dist_util.dev())
236
+ if self.embed_model is not None:
237
+ micro = get_kl_input(micro, microbatch=self.microbatch_encode,
238
+ model=self.embed_model, scale_factor=self.scale_factor,
239
+ shift_size=self.shift_size)
240
+ micro_cond = {
241
+ k: v[i: i + self.microbatch].repeat_interleave(self.encode_rep).to(dist_util.dev())
242
+ for k, v in cond.items()
243
+ }
244
+ t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
245
+
246
+ compute_losses = functools.partial(
247
+ self.diffusion.training_losses,
248
+ self.eval_model,
249
+ micro,
250
+ t,
251
+ model_kwargs=micro_cond,
252
+ )
253
+ losses = compute_losses()
254
+ log_loss_dict(
255
+ self.diffusion, t, {'eval_'+k: v * weights for k, v in losses.items()}
256
+ )
257
+ if self.eval_sample_batch_size > 0 and self.step != 0:
258
+ # if True:
259
+ model_kwargs = {}
260
+ if self.num_classes > 0:
261
+ # classes = th.randint(
262
+ # low=0, high=self.num_classes, size=(self.eval_sample_batch_size,), device=dist_util.dev()
263
+ # )
264
+ # balance generated classes
265
+ rank = dist.get_rank()
266
+ samples_per_class = math.ceil(self.eval_sample_batch_size * self.total_num_gpus / self.num_classes)
267
+ label_start = rank * self.eval_sample_batch_size // samples_per_class
268
+ label_end = math.ceil((rank + 1) * self.eval_sample_batch_size / samples_per_class)
269
+ classes = th.arange(label_start, label_end, dtype=th.int, device=dist_util.dev()).repeat_interleave(samples_per_class)
270
+ model_kwargs["y"] = classes[:self.eval_sample_batch_size]
271
+ all_images = []
272
+ all_labels = []
273
+ image_size_h = micro.shape[-2]
274
+ image_size_w = micro.shape[-1]
275
+ sample_fn = (
276
+ self.diffusion.p_sample_loop if not self.eval_sample_use_ddim else self.diffusion.ddim_sample_loop
277
+ )
278
+ sample = sample_fn(
279
+ self.eval_model,
280
+ (self.eval_sample_batch_size, self.in_channels, image_size_h, image_size_w),
281
+ # (4, self.in_channels, image_size_h, image_size_w),
282
+ clip_denoised=self.eval_sample_clip_denoised,
283
+ model_kwargs=model_kwargs,
284
+ progress=True
285
+ )
286
+ ##### debug
287
+ # sample = micro
288
+ sample = midi_util.decode_sample_for_midi(sample, embed_model=self.embed_model,
289
+ scale_factor=self.scale_factor, threshold=-0.95)
290
+
291
+ gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
292
+ dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
293
+ all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
294
+ if self.num_classes > 0:
295
+ gathered_labels = [
296
+ th.zeros_like(model_kwargs["y"]) for _ in range(dist.get_world_size())
297
+ ]
298
+ dist.all_gather(gathered_labels, model_kwargs["y"])
299
+ all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
300
+
301
+ arr = np.concatenate(all_images, axis=0)
302
+ if arr.shape[-1] == 1: # no pedal, need shape B x 128 x 1024
303
+ arr = arr.squeeze(axis=-1)
304
+ else: # with pedal, need shape: B x 2 x 128 x 1024
305
+ arr = arr.transpose(0, 3, 1, 2)
306
+ if self.num_classes > 0:
307
+ label_arr = np.concatenate(all_labels, axis=0)
308
+ save_dir = osp.join(get_blob_logdir(), "samples", "iter_" + str(self.step + self.resume_step))
309
+ os.makedirs(os.path.expanduser(save_dir), exist_ok=True)
310
+ if dist.get_rank() == 0:
311
+ if self.num_classes > 0:
312
+ midi_util.save_piano_roll_midi(arr, save_dir, self.fs, y=label_arr)
313
+ else:
314
+ midi_util.save_piano_roll_midi(arr, save_dir, self.fs)
315
+ dist.barrier()
316
+ # # put the model on cpu to prepare for next loading
317
+ # self.eval_model.to("cpu")
318
+
319
+ def forward_backward(self, batch, cond):
320
+ self.mp_trainer.zero_grad()
321
+ for i in range(0, batch.shape[0], self.microbatch):
322
+ micro = batch[i : i + self.microbatch].to(dist_util.dev())
323
+ if self.embed_model is not None:
324
+ micro = get_kl_input(micro, microbatch=self.microbatch_encode,
325
+ model=self.embed_model, scale_factor=self.scale_factor,
326
+ shift_size=self.shift_size)
327
+ micro_cond = {
328
+ k: v[i : i + self.microbatch].repeat_interleave(self.encode_rep).to(dist_util.dev())
329
+ for k, v in cond.items()
330
+ }
331
+ last_batch = (i + self.microbatch) >= self.batch_size
332
+ t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
333
+
334
+ compute_losses = functools.partial(
335
+ self.diffusion.training_losses,
336
+ self.ddp_model,
337
+ micro,
338
+ t,
339
+ model_kwargs=micro_cond,
340
+ )
341
+
342
+ if last_batch or not self.use_ddp:
343
+ losses = compute_losses()
344
+ else:
345
+ with self.ddp_model.no_sync():
346
+ losses = compute_losses()
347
+
348
+ if isinstance(self.schedule_sampler, LossAwareSampler):
349
+ self.schedule_sampler.update_with_local_losses(
350
+ t, losses["loss"].detach()
351
+ )
352
+
353
+ loss = (losses["loss"] * weights).mean()
354
+ log_loss_dict(
355
+ self.diffusion, t, {k: v * weights for k, v in losses.items()}
356
+ )
357
+ self.mp_trainer.backward(loss)
358
+ # # keep gpu mem constant?
359
+ # del losses
360
+
361
+ def _update_ema(self):
362
+ for rate, params in zip(self.ema_rate, self.ema_params):
363
+ update_ema(params, self.mp_trainer.master_params, rate=rate)
364
+
365
+ def _anneal_lr(self):
366
+ if not self.lr_anneal_steps:
367
+ return
368
+ frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
369
+ lr = self.lr * (1 - frac_done)
370
+ for param_group in self.opt.param_groups:
371
+ param_group["lr"] = lr
372
+
373
+ def log_step(self):
374
+ logger.logkv("step", self.step + self.resume_step)
375
+ logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
376
+
377
+ def save(self):
378
+ def save_checkpoint(rate, params):
379
+ state_dict = self.mp_trainer.master_params_to_state_dict(params)
380
+ if dist.get_rank() == 0:
381
+ logger.log(f"saving model {rate}...")
382
+ if not rate:
383
+ filename = f"model{(self.step+self.resume_step):06d}.pt"
384
+ else:
385
+ filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
386
+ with bf.BlobFile(bf.join(get_blob_logdir(), "checkpoints", filename), "wb") as f:
387
+ th.save(state_dict, f)
388
+
389
+ save_checkpoint(0, self.mp_trainer.master_params)
390
+ for rate, params in zip(self.ema_rate, self.ema_params):
391
+ save_checkpoint(rate, params)
392
+
393
+ if dist.get_rank() == 0:
394
+ with bf.BlobFile(
395
+ bf.join(get_blob_logdir(), "checkpoints", f"opt{(self.step+self.resume_step):06d}.pt"),
396
+ "wb",
397
+ ) as f:
398
+ th.save(self.opt.state_dict(), f)
399
+
400
+ dist.barrier()
401
+
402
+
403
+ @th.no_grad()
404
+ def get_kl_input(batch, microbatch=-1, model=None, scale_factor=1., recombine=True, shift_size=4):
405
+ # here microbatch should be outer microbatch // encode_rep
406
+ if microbatch < 0:
407
+ microbatch = batch.shape[0]
408
+ full_z = []
409
+ image_size_h = batch.shape[-2]
410
+ image_size_w = batch.shape[-1]
411
+ seq_len = image_size_w // image_size_h
412
+ for i in range(0, batch.shape[0], microbatch):
413
+ micro = batch[i : i + microbatch].to(dist_util.dev())
414
+ # encode each 1s and concatenate
415
+ micro = th.chunk(micro, seq_len, dim=-1) # B x C x H x W
416
+ micro = th.concat(micro, dim=0) # 1st second for all batch, 2nd second for all batch, ...
417
+ micro = model.encode_save(micro, range_fix=False)
418
+ posterior = DiagonalGaussianDistribution(micro)
419
+ # z = posterior.sample()
420
+ z = posterior.mode()
421
+ z = th.concat(th.chunk(z, seq_len, dim=0), dim=-1)
422
+ z = z.permute(0, 1, 3, 2)
423
+ full_z.append(z)
424
+ full_z = th.concat(full_z, dim=0) # B x 4 x (15x16), 16
425
+ if recombine: # if not using microbatch, then need to use recombination of tokens
426
+ # unfold: dimension, size, step
427
+ full_z = full_z.unfold(2, 8*16, 16*shift_size).permute(0, 2, 1, 4, 3) # (B x encode_rep) x 4 x 128 x 16
428
+ full_z = full_z.contiguous().view(-1, 4, 8*16, 16) # B x 4 x 128 x 16
429
+ return (full_z * scale_factor).detach()
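+ # Shape sketch for the unfold-based recombination above (an illustrative
+ # comment, assuming 4 latent channels, 16 latent rows per one-second chunk,
+ # seq_len = 15 chunks and shift_size = 4; the real values come from the data
+ # and the VAE config):
+ #   full_z = th.randn(2, 4, 15 * 16, 16)                  # B x 4 x 240 x 16
+ #   windows = full_z.unfold(2, 8 * 16, 16 * 4)            # B x 4 x 2 x 16 x 128
+ #   segs = windows.permute(0, 2, 1, 4, 3).contiguous()    # B x 2 x 4 x 128 x 16
+ #   segs = segs.view(-1, 4, 8 * 16, 16)                   # (B*2) x 4 x 128 x 16
+ # i.e. each item is split into overlapping windows of 8 one-second chunks,
+ # advanced by shift_size chunks at a time.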
430
+
431
+
432
+ def parse_resume_step_from_filename(filename):
433
+ """
434
+ Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
435
+ checkpoint's number of steps.
436
+ """
437
+ split = filename.split("model")
438
+ if len(split) < 2:
439
+ return 0
440
+ split1 = split[-1].split(".")[0]
441
+ try:
442
+ return int(split1)
443
+ except ValueError:
444
+ return 0
445
+
446
+
447
+ def get_blob_logdir():
448
+ # You can change this to be a separate path to save checkpoints to
449
+ # a blobstore or some external drive.
450
+ return logger.get_dir()
451
+
452
+
453
+ def find_resume_checkpoint():
454
+ # On your infrastructure, you may want to override this to automatically
455
+ # discover the latest checkpoint on your blob storage, etc.
456
+ return None
457
+
458
+
459
+ def find_ema_checkpoint(main_checkpoint, step, rate):
460
+ if main_checkpoint is None:
461
+ return None
462
+ filename = f"ema_{rate}_{(step):06d}.pt"
463
+ path = bf.join(bf.dirname(main_checkpoint), filename)
464
+ if bf.exists(path):
465
+ return path
466
+ return None
467
+
468
+
469
+ def log_loss_dict(diffusion, ts, losses):
470
+ for key, values in losses.items():
471
+ logger.logkv_mean(key, values.mean().item())
472
+ # Log the quantiles (four quartiles, in particular).
473
+ for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
474
+ quartile = int(4 * sub_t / diffusion.num_timesteps)
475
+ logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
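+
+ # Example of the quartile bucketing above (illustrative): with 1000 diffusion
+ # steps, a sampled t = 250 gives int(4 * 250 / 1000) = 1, so its loss is
+ # averaged into "loss_q1", while t = 999 lands in "loss_q3".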
guided_diffusion/unet.py ADDED
@@ -0,0 +1,906 @@
1
+ from abc import abstractmethod
2
+
3
+ import math
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .fp16_util import convert_module_to_f16, convert_module_to_f32
11
+ from .nn import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ )
20
+
21
+
22
+ class AttentionPool2d(nn.Module):
23
+ """
24
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ spacial_dim: int,
30
+ embed_dim: int,
31
+ num_heads_channels: int,
32
+ output_dim: int = None,
33
+ chord: bool = False,
34
+ ):
35
+ super().__init__()
36
+ self.positional_embedding = nn.Parameter(
37
+ th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
38
+ )
39
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
40
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
41
+ self.chord = chord
42
+ if chord:
43
+ self.c_proj_key = conv_nd(1, embed_dim, 25, 1)
44
+ self.num_heads = embed_dim // num_heads_channels
45
+ self.attention = QKVAttention(self.num_heads)
46
+
47
+ def forward(self, x):
48
+ b, c, *_spatial = x.shape
49
+ x = x.reshape(b, c, -1) # NC(HW)
50
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
51
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
52
+ x = self.qkv_proj(x)
53
+ x = self.attention(x)
54
+ if self.chord:
55
+ x_key = self.c_proj_key(x)
56
+ key = x_key[:, :, 0]
57
+ x_chord = self.c_proj(x)[:, :, 1:]
58
+ chord = x_chord.reshape(b, -1, *_spatial).mean(dim=2).permute(0, 2, 1)
59
+ return key, chord
60
+ else:
61
+ x = self.c_proj(x)
62
+ return x[:, :, 0]
63
+
64
+
65
+ class TimestepBlock(nn.Module):
66
+ """
67
+ Any module where forward() takes timestep embeddings as a second argument.
68
+ """
69
+
70
+ @abstractmethod
71
+ def forward(self, x, emb):
72
+ """
73
+ Apply the module to `x` given `emb` timestep embeddings.
74
+ """
75
+
76
+
77
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
78
+ """
79
+ A sequential module that passes timestep embeddings to the children that
80
+ support it as an extra input.
81
+ """
82
+
83
+ def forward(self, x, emb):
84
+ for layer in self:
85
+ if isinstance(layer, TimestepBlock):
86
+ x = layer(x, emb)
87
+ else:
88
+ x = layer(x)
89
+ return x
90
+
91
+
92
+ class Upsample(nn.Module):
93
+ """
94
+ An upsampling layer with an optional convolution.
95
+
96
+ :param channels: channels in the inputs and outputs.
97
+ :param use_conv: a bool determining if a convolution is applied.
98
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
99
+ upsampling occurs in the inner-two dimensions.
100
+ """
101
+
102
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
103
+ super().__init__()
104
+ self.channels = channels
105
+ self.out_channels = out_channels or channels
106
+ self.use_conv = use_conv
107
+ self.dims = dims
108
+ if use_conv:
109
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
110
+
111
+ def forward(self, x):
112
+ assert x.shape[1] == self.channels
113
+ if self.dims == 3:
114
+ x = F.interpolate(
115
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
116
+ )
117
+ else:
118
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
119
+ if self.use_conv:
120
+ x = self.conv(x)
121
+ return x
122
+
123
+
124
+ class Downsample(nn.Module):
125
+ """
126
+ A downsampling layer with an optional convolution.
127
+
128
+ :param channels: channels in the inputs and outputs.
129
+ :param use_conv: a bool determining if a convolution is applied.
130
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
131
+ downsampling occurs in the inner-two dimensions.
132
+ """
133
+
134
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
135
+ super().__init__()
136
+ self.channels = channels
137
+ self.out_channels = out_channels or channels
138
+ self.use_conv = use_conv
139
+ self.dims = dims
140
+ stride = 2 if dims != 3 else (1, 2, 2)
141
+ if use_conv:
142
+ self.op = conv_nd(
143
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
144
+ )
145
+ else:
146
+ assert self.channels == self.out_channels
147
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
148
+
149
+ def forward(self, x):
150
+ assert x.shape[1] == self.channels
151
+ return self.op(x)
152
+
153
+
154
+ class ResBlock(TimestepBlock):
155
+ """
156
+ A residual block that can optionally change the number of channels.
157
+
158
+ :param channels: the number of input channels.
159
+ :param emb_channels: the number of timestep embedding channels.
160
+ :param dropout: the rate of dropout.
161
+ :param out_channels: if specified, the number of out channels.
162
+ :param use_conv: if True and out_channels is specified, use a spatial
163
+ convolution instead of a smaller 1x1 convolution to change the
164
+ channels in the skip connection.
165
+ :param dims: determines if the signal is 1D, 2D, or 3D.
166
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
167
+ :param up: if True, use this block for upsampling.
168
+ :param down: if True, use this block for downsampling.
169
+ """
170
+
171
+ def __init__(
172
+ self,
173
+ channels,
174
+ emb_channels,
175
+ dropout,
176
+ out_channels=None,
177
+ use_conv=False,
178
+ use_scale_shift_norm=False,
179
+ dims=2,
180
+ use_checkpoint=False,
181
+ up=False,
182
+ down=False,
183
+ ):
184
+ super().__init__()
185
+ self.channels = channels
186
+ self.emb_channels = emb_channels
187
+ self.dropout = dropout
188
+ self.out_channels = out_channels or channels
189
+ self.use_conv = use_conv
190
+ self.use_checkpoint = use_checkpoint
191
+ self.use_scale_shift_norm = use_scale_shift_norm
192
+
193
+ self.in_layers = nn.Sequential(
194
+ normalization(channels),
195
+ nn.SiLU(),
196
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
197
+ )
198
+
199
+ self.updown = up or down
200
+
201
+ if up:
202
+ self.h_upd = Upsample(channels, False, dims)
203
+ self.x_upd = Upsample(channels, False, dims)
204
+ elif down:
205
+ self.h_upd = Downsample(channels, False, dims)
206
+ self.x_upd = Downsample(channels, False, dims)
207
+ else:
208
+ self.h_upd = self.x_upd = nn.Identity()
209
+
210
+ self.emb_layers = nn.Sequential(
211
+ nn.SiLU(),
212
+ linear(
213
+ emb_channels,
214
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
215
+ ),
216
+ )
217
+ self.out_layers = nn.Sequential(
218
+ normalization(self.out_channels),
219
+ nn.SiLU(),
220
+ nn.Dropout(p=dropout),
221
+ zero_module(
222
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
223
+ ),
224
+ )
225
+
226
+ if self.out_channels == channels:
227
+ self.skip_connection = nn.Identity()
228
+ elif use_conv:
229
+ self.skip_connection = conv_nd(
230
+ dims, channels, self.out_channels, 3, padding=1
231
+ )
232
+ else:
233
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
234
+
235
+ def forward(self, x, emb):
236
+ """
237
+ Apply the block to a Tensor, conditioned on a timestep embedding.
238
+
239
+ :param x: an [N x C x ...] Tensor of features.
240
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
241
+ :return: an [N x C x ...] Tensor of outputs.
242
+ """
243
+ return checkpoint(
244
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
245
+ )
246
+
247
+ def _forward(self, x, emb):
248
+ if self.updown:
249
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
250
+ h = in_rest(x)
251
+ h = self.h_upd(h)
252
+ x = self.x_upd(x)
253
+ h = in_conv(h)
254
+ else:
255
+ h = self.in_layers(x)
256
+ emb_out = self.emb_layers(emb).type(h.dtype)
257
+ while len(emb_out.shape) < len(h.shape):
258
+ emb_out = emb_out[..., None]
259
+ if self.use_scale_shift_norm:
260
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
261
+ scale, shift = th.chunk(emb_out, 2, dim=1)
262
+ h = out_norm(h) * (1 + scale) + shift
263
+ h = out_rest(h)
264
+ else:
265
+ h = h + emb_out
266
+ h = self.out_layers(h)
267
+ return self.skip_connection(x) + h
268
+
269
+
270
+ class AttentionBlock(nn.Module):
271
+ """
272
+ An attention block that allows spatial positions to attend to each other.
273
+
274
+ Originally ported from here, but adapted to the N-d case.
275
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
276
+ """
277
+
278
+ def __init__(
279
+ self,
280
+ channels,
281
+ num_heads=1,
282
+ num_head_channels=-1,
283
+ use_checkpoint=False,
284
+ use_new_attention_order=False,
285
+ ):
286
+ super().__init__()
287
+ self.channels = channels
288
+ if num_head_channels == -1:
289
+ self.num_heads = num_heads
290
+ else:
291
+ assert (
292
+ channels % num_head_channels == 0
293
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
294
+ self.num_heads = channels // num_head_channels
295
+ self.use_checkpoint = use_checkpoint
296
+ self.norm = normalization(channels)
297
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
298
+ if use_new_attention_order:
299
+ # split qkv before split heads
300
+ self.attention = QKVAttention(self.num_heads)
301
+ else:
302
+ # split heads before split qkv
303
+ self.attention = QKVAttentionLegacy(self.num_heads)
304
+
305
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
306
+
307
+ def forward(self, x):
308
+ return checkpoint(self._forward, (x,), self.parameters(), True)
309
+
310
+ def _forward(self, x):
311
+ b, c, *spatial = x.shape
312
+ x = x.reshape(b, c, -1)
313
+ qkv = self.qkv(self.norm(x))
314
+ h = self.attention(qkv)
315
+ h = self.proj_out(h)
316
+ return (x + h).reshape(b, c, *spatial)
317
+
318
+
319
+ def count_flops_attn(model, _x, y):
320
+ """
321
+ A counter for the `thop` package to count the operations in an
322
+ attention operation.
323
+ Meant to be used like:
324
+ macs, params = thop.profile(
325
+ model,
326
+ inputs=(inputs, timestamps),
327
+ custom_ops={QKVAttention: QKVAttention.count_flops},
328
+ )
329
+ """
330
+ b, c, *spatial = y[0].shape
331
+ num_spatial = int(np.prod(spatial))
332
+ # We perform two matmuls with the same number of ops.
333
+ # The first computes the weight matrix, the second computes
334
+ # the combination of the value vectors.
335
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
336
+ model.total_ops += th.DoubleTensor([matmul_ops])
337
+
338
+
339
+ class QKVAttentionLegacy(nn.Module):
340
+ """
341
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
342
+ """
343
+
344
+ def __init__(self, n_heads):
345
+ super().__init__()
346
+ self.n_heads = n_heads
347
+
348
+ def forward(self, qkv):
349
+ """
350
+ Apply QKV attention.
351
+
352
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
353
+ :return: an [N x (H * C) x T] tensor after attention.
354
+ """
355
+ bs, width, length = qkv.shape
356
+ assert width % (3 * self.n_heads) == 0
357
+ ch = width // (3 * self.n_heads)
358
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
359
+ scale = 1 / math.sqrt(math.sqrt(ch))
360
+ weight = th.einsum(
361
+ "bct,bcs->bts", q * scale, k * scale
362
+ ) # More stable with f16 than dividing afterwards
363
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
364
+ a = th.einsum("bts,bcs->bct", weight, v)
365
+ return a.reshape(bs, -1, length)
366
+
367
+ @staticmethod
368
+ def count_flops(model, _x, y):
369
+ return count_flops_attn(model, _x, y)
370
+
371
+
372
+ class QKVAttention(nn.Module):
373
+ """
374
+ A module which performs QKV attention and splits in a different order.
375
+ """
376
+
377
+ def __init__(self, n_heads):
378
+ super().__init__()
379
+ self.n_heads = n_heads
380
+
381
+ def forward(self, qkv):
382
+ """
383
+ Apply QKV attention.
384
+
385
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
386
+ :return: an [N x (H * C) x T] tensor after attention.
387
+ """
388
+ bs, width, length = qkv.shape
389
+ assert width % (3 * self.n_heads) == 0
390
+ ch = width // (3 * self.n_heads)
391
+ q, k, v = qkv.chunk(3, dim=1)
392
+ scale = 1 / math.sqrt(math.sqrt(ch))
393
+ weight = th.einsum(
394
+ "bct,bcs->bts",
395
+ (q * scale).view(bs * self.n_heads, ch, length),
396
+ (k * scale).view(bs * self.n_heads, ch, length),
397
+ ) # More stable with f16 than dividing afterwards
398
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
399
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
400
+ return a.reshape(bs, -1, length)
401
+
402
+ @staticmethod
403
+ def count_flops(model, _x, y):
404
+ return count_flops_attn(model, _x, y)
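+
+ # Shape example for QKVAttention (an illustrative sketch, not used anywhere in
+ # this module): with n_heads = 2 and 8 channels per head, the packed qkv tensor
+ # carries 3 * 2 * 8 = 48 channels and the output keeps 2 * 8 = 16 channels.
+ #   attn = QKVAttention(n_heads=2)
+ #   qkv = th.randn(4, 48, 10)   # [N x (3*H*C) x T]
+ #   out = attn(qkv)             # [N x (H*C) x T] -> torch.Size([4, 16, 10])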
405
+
406
+
407
+ class UNetModel(nn.Module):
408
+ """
409
+ The full UNet model with attention and timestep embedding.
410
+
411
+ :param in_channels: channels in the input Tensor.
412
+ :param model_channels: base channel count for the model.
413
+ :param out_channels: channels in the output Tensor.
414
+ :param num_res_blocks: number of residual blocks per downsample.
415
+ :param attention_resolutions: a collection of downsample rates at which
416
+ attention will take place. May be a set, list, or tuple.
417
+ For example, if this contains 4, then at 4x downsampling, attention
418
+ will be used.
419
+ :param dropout: the dropout probability.
420
+ :param channel_mult: channel multiplier for each level of the UNet.
421
+ :param conv_resample: if True, use learned convolutions for upsampling and
422
+ downsampling.
423
+ :param dims: determines if the signal is 1D, 2D, or 3D.
424
+ :param num_classes: if specified (as an int), then this model will be
425
+ class-conditional with `num_classes` classes.
426
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
427
+ :param num_heads: the number of attention heads in each attention layer.
428
+ :param num_head_channels: if specified, ignore num_heads and instead use
429
+ a fixed channel width per attention head.
430
+ :param num_heads_upsample: works with num_heads to set a different number
431
+ of heads for upsampling. Deprecated.
432
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
433
+ :param resblock_updown: use residual blocks for up/downsampling.
434
+ :param use_new_attention_order: use a different attention pattern for potentially
435
+ increased efficiency.
436
+ """
437
+
438
+ def __init__(
439
+ self,
440
+ image_size,
441
+ in_channels,
442
+ model_channels,
443
+ out_channels,
444
+ num_res_blocks,
445
+ attention_resolutions,
446
+ dropout=0,
447
+ channel_mult=(1, 2, 4, 8),
448
+ conv_resample=True,
449
+ dims=2,
450
+ num_classes=None,
451
+ use_checkpoint=False,
452
+ use_fp16=False,
453
+ num_heads=1,
454
+ num_head_channels=-1,
455
+ num_heads_upsample=-1,
456
+ use_scale_shift_norm=False,
457
+ resblock_updown=False,
458
+ use_new_attention_order=False,
459
+ ):
460
+ super().__init__()
461
+
462
+ if num_heads_upsample == -1:
463
+ num_heads_upsample = num_heads
464
+
465
+ self.image_size = image_size
466
+ self.in_channels = in_channels
467
+ self.model_channels = model_channels
468
+ self.out_channels = out_channels
469
+ self.num_res_blocks = num_res_blocks
470
+ self.attention_resolutions = attention_resolutions
471
+ self.dropout = dropout
472
+ self.channel_mult = channel_mult
473
+ self.conv_resample = conv_resample
474
+ self.num_classes = num_classes
475
+ self.use_checkpoint = use_checkpoint
476
+ self.dtype = th.float16 if use_fp16 else th.float32
477
+ self.num_heads = num_heads
478
+ self.num_head_channels = num_head_channels
479
+ self.num_heads_upsample = num_heads_upsample
480
+
481
+ time_embed_dim = model_channels * 4
482
+ self.time_embed = nn.Sequential(
483
+ linear(model_channels, time_embed_dim),
484
+ nn.SiLU(),
485
+ linear(time_embed_dim, time_embed_dim),
486
+ )
487
+
488
+ if self.num_classes is not None:
489
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
490
+
491
+ ch = input_ch = int(channel_mult[0] * model_channels)
492
+ self.input_blocks = nn.ModuleList(
493
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
494
+ )
495
+ self._feature_size = ch
496
+ input_block_chans = [ch]
497
+ ds = 1
498
+ for level, mult in enumerate(channel_mult):
499
+ for _ in range(num_res_blocks):
500
+ layers = [
501
+ ResBlock(
502
+ ch,
503
+ time_embed_dim,
504
+ dropout,
505
+ out_channels=int(mult * model_channels),
506
+ dims=dims,
507
+ use_checkpoint=use_checkpoint,
508
+ use_scale_shift_norm=use_scale_shift_norm,
509
+ )
510
+ ]
511
+ ch = int(mult * model_channels)
512
+ if ds in attention_resolutions:
513
+ layers.append(
514
+ AttentionBlock(
515
+ ch,
516
+ use_checkpoint=use_checkpoint,
517
+ num_heads=num_heads,
518
+ num_head_channels=num_head_channels,
519
+ use_new_attention_order=use_new_attention_order,
520
+ )
521
+ )
522
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
523
+ self._feature_size += ch
524
+ input_block_chans.append(ch)
525
+ if level != len(channel_mult) - 1:
526
+ out_ch = ch
527
+ self.input_blocks.append(
528
+ TimestepEmbedSequential(
529
+ ResBlock(
530
+ ch,
531
+ time_embed_dim,
532
+ dropout,
533
+ out_channels=out_ch,
534
+ dims=dims,
535
+ use_checkpoint=use_checkpoint,
536
+ use_scale_shift_norm=use_scale_shift_norm,
537
+ down=True,
538
+ )
539
+ if resblock_updown
540
+ else Downsample(
541
+ ch, conv_resample, dims=dims, out_channels=out_ch
542
+ )
543
+ )
544
+ )
545
+ ch = out_ch
546
+ input_block_chans.append(ch)
547
+ ds *= 2
548
+ self._feature_size += ch
549
+
550
+ self.middle_block = TimestepEmbedSequential(
551
+ ResBlock(
552
+ ch,
553
+ time_embed_dim,
554
+ dropout,
555
+ dims=dims,
556
+ use_checkpoint=use_checkpoint,
557
+ use_scale_shift_norm=use_scale_shift_norm,
558
+ ),
559
+ AttentionBlock(
560
+ ch,
561
+ use_checkpoint=use_checkpoint,
562
+ num_heads=num_heads,
563
+ num_head_channels=num_head_channels,
564
+ use_new_attention_order=use_new_attention_order,
565
+ ),
566
+ ResBlock(
567
+ ch,
568
+ time_embed_dim,
569
+ dropout,
570
+ dims=dims,
571
+ use_checkpoint=use_checkpoint,
572
+ use_scale_shift_norm=use_scale_shift_norm,
573
+ ),
574
+ )
575
+ self._feature_size += ch
576
+
577
+ self.output_blocks = nn.ModuleList([])
578
+ for level, mult in list(enumerate(channel_mult))[::-1]:
579
+ for i in range(num_res_blocks + 1):
580
+ ich = input_block_chans.pop()
581
+ layers = [
582
+ ResBlock(
583
+ ch + ich,
584
+ time_embed_dim,
585
+ dropout,
586
+ out_channels=int(model_channels * mult),
587
+ dims=dims,
588
+ use_checkpoint=use_checkpoint,
589
+ use_scale_shift_norm=use_scale_shift_norm,
590
+ )
591
+ ]
592
+ ch = int(model_channels * mult)
593
+ if ds in attention_resolutions:
594
+ layers.append(
595
+ AttentionBlock(
596
+ ch,
597
+ use_checkpoint=use_checkpoint,
598
+ num_heads=num_heads_upsample,
599
+ num_head_channels=num_head_channels,
600
+ use_new_attention_order=use_new_attention_order,
601
+ )
602
+ )
603
+ if level and i == num_res_blocks:
604
+ out_ch = ch
605
+ layers.append(
606
+ ResBlock(
607
+ ch,
608
+ time_embed_dim,
609
+ dropout,
610
+ out_channels=out_ch,
611
+ dims=dims,
612
+ use_checkpoint=use_checkpoint,
613
+ use_scale_shift_norm=use_scale_shift_norm,
614
+ up=True,
615
+ )
616
+ if resblock_updown
617
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
618
+ )
619
+ ds //= 2
620
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
621
+ self._feature_size += ch
622
+
623
+ self.out = nn.Sequential(
624
+ normalization(ch),
625
+ nn.SiLU(),
626
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
627
+ )
628
+
629
+ def convert_to_fp16(self):
630
+ """
631
+ Convert the torso of the model to float16.
632
+ """
633
+ self.input_blocks.apply(convert_module_to_f16)
634
+ self.middle_block.apply(convert_module_to_f16)
635
+ self.output_blocks.apply(convert_module_to_f16)
636
+
637
+ def convert_to_fp32(self):
638
+ """
639
+ Convert the torso of the model to float32.
640
+ """
641
+ self.input_blocks.apply(convert_module_to_f32)
642
+ self.middle_block.apply(convert_module_to_f32)
643
+ self.output_blocks.apply(convert_module_to_f32)
644
+
645
+ def forward(self, x, timesteps, y=None):
646
+ """
647
+ Apply the model to an input batch.
648
+
649
+ :param x: an [N x C x ...] Tensor of inputs.
650
+ :param timesteps: a 1-D batch of timesteps.
651
+ :param y: an [N] Tensor of labels, if class-conditional.
652
+ :return: an [N x C x ...] Tensor of outputs.
653
+ """
654
+ assert (y is not None) == (
655
+ self.num_classes is not None
656
+ ), "must specify y if and only if the model is class-conditional"
657
+
658
+ hs = []
659
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
660
+
661
+ if self.num_classes is not None:
662
+ assert y.shape == (x.shape[0],)
663
+ emb = emb + self.label_emb(y)
664
+
665
+ h = x.type(self.dtype)
666
+ for module in self.input_blocks:
667
+ h = module(h, emb)
668
+ hs.append(h)
669
+ h = self.middle_block(h, emb)
670
+ for module in self.output_blocks:
671
+ h = th.cat([h, hs.pop()], dim=1)
672
+ h = module(h, emb)
673
+ h = h.type(x.dtype)
674
+ return self.out(h)
675
+
676
+
677
+ class SuperResModel(UNetModel):
678
+ """
679
+ A UNetModel that performs super-resolution.
680
+
681
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
682
+ """
683
+
684
+ def __init__(self, image_size, in_channels, *args, **kwargs):
685
+ super().__init__(image_size, in_channels * 2, *args, **kwargs)
686
+
687
+ def forward(self, x, timesteps, low_res=None, **kwargs):
688
+ _, _, new_height, new_width = x.shape
689
+ upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
690
+ x = th.cat([x, upsampled], dim=1)
691
+ return super().forward(x, timesteps, **kwargs)
692
+
693
+
694
+ class EncoderUNetModel(nn.Module):
695
+ """
696
+ The half UNet model with attention and timestep embedding.
697
+
698
+ For usage, see UNet.
699
+ """
700
+
701
+ def __init__(
702
+ self,
703
+ image_size,
704
+ in_channels,
705
+ model_channels,
706
+ out_channels,
707
+ num_res_blocks,
708
+ attention_resolutions,
709
+ dropout=0,
710
+ channel_mult=(1, 2, 4, 8),
711
+ conv_resample=True,
712
+ dims=2,
713
+ use_checkpoint=False,
714
+ use_fp16=False,
715
+ num_heads=1,
716
+ num_head_channels=-1,
717
+ num_heads_upsample=-1,
718
+ use_scale_shift_norm=False,
719
+ resblock_updown=False,
720
+ use_new_attention_order=False,
721
+ pool="adaptive",
722
+ chord=False,
723
+ ):
724
+ super().__init__()
725
+
726
+ if num_heads_upsample == -1:
727
+ num_heads_upsample = num_heads
728
+
729
+ self.in_channels = in_channels
730
+ self.model_channels = model_channels
731
+ self.out_channels = out_channels
732
+ self.num_res_blocks = num_res_blocks
733
+ self.attention_resolutions = attention_resolutions
734
+ self.dropout = dropout
735
+ self.channel_mult = channel_mult
736
+ self.conv_resample = conv_resample
737
+ self.use_checkpoint = use_checkpoint
738
+ self.dtype = th.float16 if use_fp16 else th.float32
739
+ self.num_heads = num_heads
740
+ self.num_head_channels = num_head_channels
741
+ self.num_heads_upsample = num_heads_upsample
742
+
743
+ time_embed_dim = model_channels * 4
744
+ self.time_embed = nn.Sequential(
745
+ linear(model_channels, time_embed_dim),
746
+ nn.SiLU(),
747
+ linear(time_embed_dim, time_embed_dim),
748
+ )
749
+
750
+ ch = int(channel_mult[0] * model_channels)
751
+ self.input_blocks = nn.ModuleList(
752
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
753
+ )
754
+ self._feature_size = ch
755
+ input_block_chans = [ch]
756
+ ds = 1
757
+ for level, mult in enumerate(channel_mult):
758
+ for _ in range(num_res_blocks):
759
+ layers = [
760
+ ResBlock(
761
+ ch,
762
+ time_embed_dim,
763
+ dropout,
764
+ out_channels=int(mult * model_channels),
765
+ dims=dims,
766
+ use_checkpoint=use_checkpoint,
767
+ use_scale_shift_norm=use_scale_shift_norm,
768
+ )
769
+ ]
770
+ ch = int(mult * model_channels)
771
+ if ds in attention_resolutions:
772
+ layers.append(
773
+ AttentionBlock(
774
+ ch,
775
+ use_checkpoint=use_checkpoint,
776
+ num_heads=num_heads,
777
+ num_head_channels=num_head_channels,
778
+ use_new_attention_order=use_new_attention_order,
779
+ )
780
+ )
781
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
782
+ self._feature_size += ch
783
+ input_block_chans.append(ch)
784
+ if level != len(channel_mult) - 1:
785
+ out_ch = ch
786
+ self.input_blocks.append(
787
+ TimestepEmbedSequential(
788
+ ResBlock(
789
+ ch,
790
+ time_embed_dim,
791
+ dropout,
792
+ out_channels=out_ch,
793
+ dims=dims,
794
+ use_checkpoint=use_checkpoint,
795
+ use_scale_shift_norm=use_scale_shift_norm,
796
+ down=True,
797
+ )
798
+ if resblock_updown
799
+ else Downsample(
800
+ ch, conv_resample, dims=dims, out_channels=out_ch
801
+ )
802
+ )
803
+ )
804
+ ch = out_ch
805
+ input_block_chans.append(ch)
806
+ ds *= 2
807
+ self._feature_size += ch
808
+
809
+ self.middle_block = TimestepEmbedSequential(
810
+ ResBlock(
811
+ ch,
812
+ time_embed_dim,
813
+ dropout,
814
+ dims=dims,
815
+ use_checkpoint=use_checkpoint,
816
+ use_scale_shift_norm=use_scale_shift_norm,
817
+ ),
818
+ AttentionBlock(
819
+ ch,
820
+ use_checkpoint=use_checkpoint,
821
+ num_heads=num_heads,
822
+ num_head_channels=num_head_channels,
823
+ use_new_attention_order=use_new_attention_order,
824
+ ),
825
+ ResBlock(
826
+ ch,
827
+ time_embed_dim,
828
+ dropout,
829
+ dims=dims,
830
+ use_checkpoint=use_checkpoint,
831
+ use_scale_shift_norm=use_scale_shift_norm,
832
+ ),
833
+ )
834
+ self._feature_size += ch
835
+ self.pool = pool
836
+ if pool == "adaptive":
837
+ self.out = nn.Sequential(
838
+ normalization(ch),
839
+ nn.SiLU(),
840
+ nn.AdaptiveAvgPool2d((1, 1)),
841
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
842
+ nn.Flatten(),
843
+ )
844
+ elif pool == "attention":
845
+ assert num_head_channels != -1
846
+ self.out = nn.Sequential(
847
+ normalization(ch),
848
+ nn.SiLU(),
849
+ AttentionPool2d(
850
+ (image_size // ds), ch, num_head_channels, out_channels, chord
851
+ ),
852
+ )
853
+ elif pool == "spatial":
854
+ self.out = nn.Sequential(
855
+ nn.Linear(self._feature_size, 2048),
856
+ nn.ReLU(),
857
+ nn.Linear(2048, self.out_channels),
858
+ )
859
+ elif pool == "spatial_v2":
860
+ self.out = nn.Sequential(
861
+ nn.Linear(self._feature_size, 2048),
862
+ normalization(2048),
863
+ nn.SiLU(),
864
+ nn.Linear(2048, self.out_channels),
865
+ )
866
+ else:
867
+ raise NotImplementedError(f"Unexpected {pool} pooling")
868
+
869
+ def convert_to_fp16(self):
870
+ """
871
+ Convert the torso of the model to float16.
872
+ """
873
+ self.input_blocks.apply(convert_module_to_f16)
874
+ self.middle_block.apply(convert_module_to_f16)
875
+
876
+ def convert_to_fp32(self):
877
+ """
878
+ Convert the torso of the model to float32.
879
+ """
880
+ self.input_blocks.apply(convert_module_to_f32)
881
+ self.middle_block.apply(convert_module_to_f32)
882
+
883
+ def forward(self, x, timesteps):
884
+ """
885
+ Apply the model to an input batch.
886
+
887
+ :param x: an [N x C x ...] Tensor of inputs.
888
+ :param timesteps: a 1-D batch of timesteps.
889
+ :return: an [N x K] Tensor of outputs.
890
+ """
891
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
892
+
893
+ results = []
894
+ h = x.type(self.dtype)
895
+ for module in self.input_blocks:
896
+ h = module(h, emb)
897
+ if self.pool.startswith("spatial"):
898
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
899
+ h = self.middle_block(h, emb)
900
+ if self.pool.startswith("spatial"):
901
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
902
+ h = th.cat(results, axis=-1)
903
+ return self.out(h)
904
+ else:
905
+ h = h.type(x.dtype)
906
+ return self.out(h)
load_utils.py ADDED
@@ -0,0 +1,31 @@
1
+ import importlib
2
+ import torch
3
+ from omegaconf import OmegaConf
4
+
5
+
6
+ def get_obj_from_str(string, reload=False):
7
+ module, cls = string.rsplit(".", 1)
8
+ if reload:
9
+ module_imp = importlib.import_module(module)
10
+ importlib.reload(module_imp)
11
+ return getattr(importlib.import_module(module, package=None), cls)
12
+
13
+
14
+ def instantiate_from_config(config):
15
+ if "target" not in config:
16
+ raise KeyError("Expected key `target` to instantiate.")
17
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
18
+
19
+
20
+ def load_model(name, ckpt):
21
+ config = OmegaConf.load(f'taming-transformers/configs/pr/{name}.yaml')
22
+ model = instantiate_from_config(config.model)
23
+ model.init_from_ckpt(ckpt) # load_state_dict(mc['state_dict'])
24
+ model.eval()
25
+ return model
26
+
27
+
28
+ def load_data(name):
29
+ config = OmegaConf.load(f'taming-transformers/configs/pr/{name}.yaml')
30
+ data = instantiate_from_config(config.data)
31
+ return data
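+
+
+ # Example usage (an illustrative sketch: the config name must match a YAML file
+ # under taming-transformers/configs/pr/, and the checkpoint path below is a
+ # placeholder for your own trained or downloaded VAE checkpoint):
+ #
+ #   model = load_model('your_config_name', 'path/to/vae_checkpoint.ckpt')
+ #   data = load_data('your_config_name')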
music_evaluation/README.md ADDED
@@ -0,0 +1,22 @@
1
+ # Music Evaluation
2
+
3
+ Adapted from this GitHub repository [mgeval](https://github.com/RichardYang40148/mgeval)
4
+
5
+ Removed the mgeval modules that required Python 2.
6
+
7
+ # Packages:
8
+ scipy, numpy, seaborn, pretty_midi, scikit-learn, python 3
9
+
10
+ # Usage:
11
+
12
+ ```
13
+ python music_evaluator.py --set1dir /path/to/your/ground-truth/data/ --set2dir /path/to/your/generated-sample/ --outdir output-dir --num_sample number-of-samples-to-evaluate
14
+ ```
15
+
16
+
17
+ # Output
18
+ All outputs are written to the specified output directory (relative to the current folder), including plots and statistics.txt.
19
+
20
+ Check out the result folder for an example.
21
+
22
+ You can either run music_evaluator.py or demo.ipynb for evaluation.
music_evaluation/convert_to_wav.py ADDED
@@ -0,0 +1,42 @@
1
+ from midi2audio import FluidSynth
2
+ import os
3
+ import sys
4
+
5
+
6
+ # This program converts a folder of .midi files to a folder of .wav files
7
+ # Need to download FluidSynth and midi2audio packages
8
+ #
9
+ # Usage:
10
+ # python convert_to_wav.py midi_dir wav_dir
11
+ # More info about FluidSynth: https://github.com/FluidSynth/fluidsynth
12
+ # More info about midi2audio: https://github.com/bzamecnik/midi2audio
13
+ # Need Sound fonts to run this program: https://sites.google.com/site/soundfonts4u/
14
+ # The sound font used in this program: https://drive.google.com/file/d/1nvTy62-wHGnZ6CKYuPNAiGlKLtWg9Ir9/view?usp=sharing
15
+
16
+ def convert_midi_to_audio(input_dir, output_dir, fs):
17
+ # sound_font_path = os.path.join(os.getcwd(), "Dore Mark's NY S&S Model B-v5.2.sf2")
18
+ # fs = FluidSynth(sound_font_path)
19
+ os.chdir(input_dir)
20
+ filenames = os.listdir(input_dir)
21
+ for midi_file in filenames:
22
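+ # note: assumes a '.midi' extension (5 characters); files ending in '.mid' would need [:-4]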
+ filename = midi_file[:-5]
23
+ filename = filename + ".wav"
24
+ output_file = os.path.join(output_dir, filename)
25
+ fs.midi_to_audio(midi_file, output_file)
26
+
27
+ return
28
+
29
+
30
+ if __name__ == '__main__':
31
+ sound_font_path = os.path.join(os.getcwd(), "Dore Mark's NY S&S Model B-v5.2.sf2")
32
+ fs = FluidSynth(sound_font_path)
33
+ # fs.midi_to_audio('MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav_0.midi', 'output.wav')
34
+
35
+ output_dir = sys.argv[2]
36
+ os.makedirs(output_dir, exist_ok=True)
37
+ current_path = os.getcwd()
38
+ output_dir = os.path.join(current_path, output_dir)
39
+
40
+ input_dir = sys.argv[1]
41
+ input_dir = os.path.join(current_path, input_dir)
42
+ convert_midi_to_audio(input_dir, output_dir, fs)
music_evaluation/demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
music_evaluation/fad.py ADDED
@@ -0,0 +1,38 @@
1
+ from frechet_audio_distance import FrechetAudioDistance
2
+ import sys
3
+
4
+ # Compute FAD distance between the ground truth dataset and sample dataset
5
+ # Can be slow on the first run, depending on dataset size and hardware
6
+ # Saves embeddings to .npy files so later runs are much faster
7
+ #
8
+ # Usage: python fad.py background_dir_path eval_dir_path
9
+ # Feel free to change the embedding paths in the code
10
+ # More info about FrechetAudioDistance: https://github.com/gudgud96/frechet-audio-distance
11
+
12
+ if __name__ == "__main__":
13
+ # to use `vggish`
14
+ frechet = FrechetAudioDistance(
15
+ model_name="vggish",
16
+ use_pca=False,
17
+ use_activation=False,
18
+ verbose=False
19
+ )
20
+ # # to use `PANN`
21
+ # frechet = FrechetAudioDistance(
22
+ # model_name="pann",
23
+ # use_pca=False,
24
+ # use_activation=False,
25
+ # verbose=False
26
+ # )
27
+
28
+ background_dir = sys.argv[1]
29
+ eval_dir = sys.argv[2]
30
+
31
+ background_embds_path = "./ground_truth_embeddings.npy"
32
+ eval_embds_path = "./eval_embeddings.npy"
33
+
34
+ fad_score = frechet.score(background_dir, eval_dir,
35
+ background_embds_path=background_embds_path,
36
+ eval_embds_path=eval_embds_path,dtype="float32")
37
+
38
+ print(fad_score)
music_evaluation/figaro/chord_recognition.py ADDED
@@ -0,0 +1,247 @@
1
+ import numpy as np
2
+
3
+ class MIDIChord(object):
4
+ def __init__(self, pm):
5
+ self.pm = pm
6
+ # define pitch classes
7
+ self.PITCH_CLASSES = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
8
+ # define chord maps (required)
9
+ self.CHORD_MAPS = {'maj': [0, 4],
10
+ 'min': [0, 3],
11
+ 'dim': [0, 3, 6],
12
+ 'aug': [0, 4, 8],
13
+ 'dom7': [0, 4, 10],
14
+ 'maj7': [0, 4, 11],
15
+ 'min7': [0, 3, 10]}
16
+ # define chord insiders (+10)
17
+ self.CHORD_INSIDERS = {'maj': [7],
18
+ 'min': [7],
19
+ 'dim': [9],
20
+ 'aug': [],
21
+ 'dom7': [7],
22
+ 'maj7': [7],
23
+ 'min7': [7]}
24
+ # define chord outsiders (-1)
25
+ self.CHORD_OUTSIDERS_1 = {'maj': [2, 5, 9],
26
+ 'min': [2, 5, 8],
27
+ 'dim': [2, 5, 10],
28
+ 'aug': [2, 5, 9],
29
+ 'dom7': [2, 5, 9],
30
+ 'maj7': [2, 5, 9],
31
32
+ 'min7': [2, 5, 8]}
33
+ # define chord outsiders (-2)
34
+ self.CHORD_OUTSIDERS_2 = {'maj': [1, 3, 6, 8, 10, 11],
35
+ 'min': [1, 4, 6, 9, 11],
36
+ 'dim': [1, 4, 7, 8, 11],
37
+ 'aug': [1, 3, 6, 7, 10],
38
+ 'dom7': [1, 3, 6, 8, 11],
39
+ 'maj7': [1, 3, 6, 8, 10],
40
+ 'min7': [1, 4, 6, 9, 11]}
41
+
42
+ def sequencing(self, chroma):
43
+ candidates = {}
44
+ for index in range(len(chroma)):
45
+ if chroma[index]:
46
+ root_note = index
47
+ _chroma = np.roll(chroma, -root_note)
48
+ sequence = np.where(_chroma == 1)[0]
49
+ candidates[root_note] = list(sequence)
50
+ return candidates
51
+
52
+ def scoring(self, candidates):
53
+ scores = {}
54
+ qualities = {}
55
+ for root_note, sequence in candidates.items():
56
+ if 3 not in sequence and 4 not in sequence:
57
+ scores[root_note] = -100
58
+ qualities[root_note] = 'None'
59
+ elif 3 in sequence and 4 in sequence:
60
+ scores[root_note] = -100
61
+ qualities[root_note] = 'None'
62
+ else:
63
+ # decide quality
64
+ if 3 in sequence:
65
+ if 6 in sequence:
66
+ quality = 'dim'
67
+ else:
68
+ if 10 in sequence:
69
+ quality = 'min7'
70
+ else:
71
+ quality = 'min'
72
+ elif 4 in sequence:
73
+ if 8 in sequence:
74
+ quality = 'aug'
75
+ else:
76
+ if 10 in sequence:
77
+ quality = 'dom7'
78
+ elif 11 in sequence:
79
+ quality = 'maj7'
80
+ else:
81
+ quality = 'maj'
82
+ # decide score
83
+ maps = self.CHORD_MAPS.get(quality)
84
+ _notes = [n for n in sequence if n not in maps]
85
+ score = 0
86
+ for n in _notes:
87
+ if n in self.CHORD_OUTSIDERS_1.get(quality):
88
+ score -= 1
89
+ elif n in self.CHORD_OUTSIDERS_2.get(quality):
90
+ score -= 2
91
+ elif n in self.CHORD_INSIDERS.get(quality):
92
+ score += 10
93
+ scores[root_note] = score
94
+ qualities[root_note] = quality
95
+ return scores, qualities
96
+
97
+ def find_chord(self, chroma, threshold=10):
98
+ chroma = np.sum(chroma, axis=1)
99
+ chroma = np.array([1 if c > threshold else 0 for c in chroma])
100
+ if np.sum(chroma) == 0:
101
+ return 'N', 'N', 'N', 10
102
+ else:
103
+ candidates = self.sequencing(chroma=chroma)
104
+ scores, qualities = self.scoring(candidates=candidates)
105
+ # bass note
106
+ sorted_notes = []
107
+ for i, v in enumerate(chroma):
108
+ if v > 0:
109
+ sorted_notes.append(int(i%12))
110
+ bass_note = sorted_notes[0]
111
+ # root note
112
+ __root_note = []
113
+ _max = max(scores.values())
114
+ for _root_note, score in scores.items():
115
+ if score == _max:
116
+ __root_note.append(_root_note)
117
+ if len(__root_note) == 1:
118
+ root_note = __root_note[0]
119
+ else:
120
+ #TODO: what should i do
121
+ for n in sorted_notes:
122
+ if n in __root_note:
123
+ root_note = n
124
+ break
125
+ # quality
126
+ quality = qualities.get(root_note)
127
+ sequence = candidates.get(root_note)
128
+ # score
129
+ score = scores.get(root_note)
130
+ return self.PITCH_CLASSES[root_note], quality, self.PITCH_CLASSES[bass_note], score
131
+
132
+ def greedy(self, candidates, max_tick, min_length):
133
+ chords = []
134
+ # start from 0
135
+ start_tick = 0
136
+ while start_tick < max_tick:
137
+ _candidates = candidates.get(start_tick)
138
+ _candidates = sorted(_candidates.items(), key=lambda x: (x[1][-1], x[0]))
139
+ # choose
140
+ end_tick, (root_note, quality, bass_note, _) = _candidates[-1]
141
+ if root_note == bass_note:
142
+ chord = '{}:{}'.format(root_note, quality)
143
+ else:
144
+ chord = '{}:{}/{}'.format(root_note, quality, bass_note)
145
+ chords.append([start_tick, end_tick, chord])
146
+ start_tick = end_tick
147
+ # remove :None
148
+ temp = chords
149
+ while ':None' in temp[0][-1]:
150
+ try:
151
+ temp[1][0] = temp[0][0]
152
+ del temp[0]
153
+ except:
154
+ print('NO CHORD')
155
+ return []
156
+ temp2 = []
157
+ for chord in temp:
158
+ if ':None' not in chord[-1]:
159
+ temp2.append(chord)
160
+ else:
161
+ temp2[-1][1] = chord[1]
162
+ return temp2
163
+
164
+ def dynamic(self, candidates, max_tick, min_length):
165
+ # store index of best chord at each position
166
+ chords = [None for i in range(max_tick + 1)]
167
+ # store score of best chords at each position
168
+ scores = np.zeros(max_tick + 1)
169
+ scores[1:].fill(np.NINF)
170
+
171
+ start_tick = 0
172
+ while start_tick < max_tick:
173
+ if start_tick in candidates:
174
+ for i, (end_tick, candidate) in enumerate(candidates.get(start_tick).items()):
175
+ root_note, quality, bass_note, score = candidate
176
+ # if this candidate is best yet, update scores and chords
177
+ if scores[end_tick] < scores[start_tick] + score:
178
+ scores[end_tick] = scores[start_tick] + score
179
+ if root_note == bass_note:
180
+ chord = '{}:{}'.format(root_note, quality)
181
+ else:
182
+ chord = '{}:{}/{}'.format(root_note, quality, bass_note)
183
+ chords[end_tick] = (start_tick, end_tick, chord)
184
+ start_tick += 1
185
+ # Read the best path
186
+ start_tick = len(chords) - 1
187
+ results = []
188
+ while start_tick > 0:
189
+ chord = chords[start_tick]
190
+ start_tick = chord[0]
191
+ results.append(chord)
192
+
193
+ return list(reversed(results))
194
+
195
+ def dedupe(self, chords):
196
+ if len(chords) == 0:
197
+ return []
198
+ deduped = []
199
+ start, end, chord = chords[0]
200
+ for (curr, next) in zip(chords[:-1], chords[1:]):
201
+ if chord == next[2]:
202
+ end = next[1]
203
+ else:
204
+ deduped.append([start, end, chord])
205
+ start, end, chord = next
206
+ deduped.append([start, end, chord])
207
+ return deduped
208
+
209
+ def get_candidates(self, chroma, max_tick, intervals=[1, 2, 3, 4]):
210
+ candidates = {}
211
+ for interval in intervals:
212
+ for start_beat in range(max_tick):
213
+ # set target pianoroll
214
+ end_beat = start_beat + interval
215
+ if end_beat > max_tick:
216
+ end_beat = max_tick
217
+ _chroma = chroma[:, start_beat:end_beat]
218
+ # find chord
219
+ root_note, quality, bass_note, score = self.find_chord(chroma=_chroma)
220
+ # save
221
+ if start_beat not in candidates:
222
+ candidates[start_beat] = {}
223
+ candidates[start_beat][end_beat] = (root_note, quality, bass_note, score)
224
+ else:
225
+ if end_beat not in candidates[start_beat]:
226
+ candidates[start_beat][end_beat] = (root_note, quality, bass_note, score)
227
+ return candidates
228
+
229
+ def extract(self):
230
+ # read
231
+ beats = self.pm.get_beats()
232
+ chroma = self.pm.get_chroma(times=beats)
233
+ # get lots of candidates
234
+ candidates = self.get_candidates(chroma, max_tick=len(beats))
235
+
236
+ # decode the chord sequence with dynamic programming (self.greedy is a simpler alternative)
237
+ chords = self.dynamic(candidates=candidates,
238
+ max_tick=len(beats),
239
+ min_length=1)
240
+ chords = self.dedupe(chords)
241
+ for chord in chords:
242
+ chord[0] = beats[chord[0]]
243
+ if chord[1] >= len(beats):
244
+ chord[1] = self.pm.get_end_time()
245
+ else:
246
+ chord[1] = beats[chord[1]]
247
+ return chords
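+
+
+ # Example usage (an illustrative sketch; 'song.mid' is a placeholder path):
+ #
+ #   import pretty_midi
+ #   pm = pretty_midi.PrettyMIDI('song.mid')
+ #   chords = MIDIChord(pm).extract()
+ #   # -> entries like [start_time, end_time, 'C:maj'] or [..., 'G:dom7/B']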
music_evaluation/figaro/constants.py ADDED
@@ -0,0 +1,47 @@
1
+ import numpy as np
2
+
3
+ # parameters for input representation
4
+ DEFAULT_POS_PER_QUARTER = 12
5
+ DEFAULT_VELOCITY_BINS = np.linspace(0, 128, 32+1, dtype=np.int32)
6
+ DEFAULT_DURATION_BINS = np.sort(np.concatenate([
7
+ np.arange(1, 13), # smallest possible units up to 1 quarter
8
+ np.arange(12, 24, 3)[1:], # 16th notes up to 1 bar
9
+ np.arange(13, 24, 4)[1:], # triplets up to 1 bar
10
+ np.arange(24, 48, 6), # 8th notes up to 2 bars
11
+ np.arange(48, 4*48, 12), # quarter notes up to 8 bars
12
+ np.arange(4*48, 16*48+1, 24) # half notes up to 16 bars
13
+ ]))
14
+ DEFAULT_TEMPO_BINS = np.linspace(0, 240, 32+1, dtype=np.int32)
15
+ DEFAULT_NOTE_DENSITY_BINS = np.linspace(0, 12, 32+1)
16
+ DEFAULT_MEAN_VELOCITY_BINS = np.linspace(0, 128, 32+1)
17
+ DEFAULT_MEAN_PITCH_BINS = np.linspace(0, 128, 32+1)
18
+ DEFAULT_MEAN_DURATION_BINS = np.logspace(0, 7, 32+1, base=2) # log space between 1 and 128 positions (~2.5 bars)
19
+
20
+ # parameters for output
21
+ DEFAULT_RESOLUTION = 480
22
+
23
+ # maximum length of a single bar is 3*4 = 12 beats
24
+ MAX_BAR_LENGTH = 3
25
+ # maximum number of bars in a piece is 512 (this covers almost all sequences)
26
+ MAX_N_BARS = 512
27
+
28
+ PAD_TOKEN = '<pad>'
29
+ UNK_TOKEN = '<unk>'
30
+ BOS_TOKEN = '<bos>'
31
+ EOS_TOKEN = '<eos>'
32
+ MASK_TOKEN = '<mask>'
33
+
34
+ TIME_SIGNATURE_KEY = 'Time Signature'
35
+ BAR_KEY = 'Bar'
36
+ POSITION_KEY = 'Position'
37
+ INSTRUMENT_KEY = 'Instrument'
38
+ PITCH_KEY = 'Pitch'
39
+ VELOCITY_KEY = 'Velocity'
40
+ DURATION_KEY = 'Duration'
41
+ TEMPO_KEY = 'Tempo'
42
+ CHORD_KEY = 'Chord'
43
+
44
+ NOTE_DENSITY_KEY = 'Note Density'
45
+ MEAN_PITCH_KEY = 'Mean Pitch'
46
+ MEAN_VELOCITY_KEY = 'Mean Velocity'
47
+ MEAN_DURATION_KEY = 'Mean Duration'
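+
+ # Quantization sketch (illustrative only; the exact bin lookup used by
+ # input_representation.py may differ):
+ #
+ #   velocity = 93
+ #   vel_bin = int(np.searchsorted(DEFAULT_VELOCITY_BINS, velocity, side='right') - 1)   # -> 23
+ #   tempo = 121.5
+ #   tempo_bin = int(np.searchsorted(DEFAULT_TEMPO_BINS, tempo, side='right') - 1)       # -> 16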
music_evaluation/figaro/evaluate.py ADDED
@@ -0,0 +1,268 @@
1
+ import os, glob
2
+ from statistics import NormalDist
3
+ import pandas as pd
4
+ import numpy as np
5
+
6
+ import input_representation as ir
7
+
8
+ SAMPLE_DIR = os.getenv('SAMPLE_DIR', './samples')
9
+ OUT_FILE = os.getenv('OUT_FILE', './metrics.csv')
10
+ MAX_SAMPLES = int(os.getenv('MAX_SAMPLES', 1024))
11
+ # used to find the base file name when multiple samples are generated per ground-truth file
12
+ SPLIT_STR = os.getenv('SPLIT_STR', None)
13
+ POST_STR = os.getenv('POST_STR', None)
14
+
15
+ METRICS = [
16
+ 'inst_prec', 'inst_rec', 'inst_f1',
17
+ 'chord_prec', 'chord_rec', 'chord_f1',
18
+ 'time_sig_acc',
19
+ 'note_dens_oa', 'pitch_oa', 'velocity_oa', 'duration_oa',
20
+ 'chroma_crossent', 'chroma_kldiv', 'chroma_sim',
21
+ 'groove_crossent', 'groove_kldiv', 'groove_sim',
22
+ ]
23
+
24
+ DF_KEYS = ['id', 'original', 'sample'] + METRICS
25
+
26
+ keys = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
27
+ qualities = ['maj', 'min', 'dim', 'aug', 'dom7', 'maj7', 'min7', 'None']
28
+ CHORDS = [f"{k}:{q}" for k in keys for q in qualities] + ['N:N']
29
+
30
+ def get_group_id(file):
31
+ # change this depending on name of generated samples
32
+ name = os.path.basename(file)
33
+ return name.split('.')[0]
34
+
35
+ def get_base_name(file):
36
+ base_file_name = os.path.basename(file)
37
+ base_name = base_file_name.split(SPLIT_STR)[0]
38
+ if POST_STR is not None:
39
+ gt_name = base_name + POST_STR
40
+ else:
41
+ gt_name = base_name
42
+ return gt_name
43
+
44
+ def get_file_groups(path, max_samples=MAX_SAMPLES):
45
+ # change this depending on file structure of generated samples
46
+ files = glob.glob(os.path.join(path, '*.mid'), recursive=True) + glob.glob(os.path.join(path, '*.midi'), recursive=True)
47
+ assert len(files), f"provided directory was empty: {path}"
48
+
49
+ samples = sorted(files)
50
+ origs = sorted([os.path.join(path, 'gt', get_base_name(file)) for file in files])
51
+ pairs = list(zip(origs, samples))
52
+
53
+ pairs = list(filter(lambda pair: os.path.exists(pair[0]), pairs))
54
+ if max_samples > 0:
55
+ pairs = pairs[:max_samples]
56
+
57
+ groups = dict()
58
+ for orig, sample in pairs:
59
+ sample_id = get_group_id(sample)
60
+ orig_id = get_group_id(orig)
61
+ if orig_id not in groups:
62
+ groups[orig_id] = list()
63
+ groups[orig_id].append((orig, sample))
64
+
65
+ return list(groups.values())
66
+
67
+ def read_file(file):
68
+ with open(file, 'r') as f:
69
+ events = f.read().split('\n')
70
+ events = [e for e in events if e]
71
+ return events
72
+
73
+ def get_chord_groups(desc):
74
+ bars = [1 if 'Bar_' in item else 0 for item in desc]
75
+ bar_ids = np.cumsum(bars) - 1
76
+ groups = [[] for _ in range(bar_ids[-1] + 1)]
77
+ for i, item in enumerate(desc):
78
+ if 'Chord_' in item:
79
+ chord = item.split('_')[-1]
80
+ groups[bar_ids[i]].append(chord)
81
+ return groups
82
+
83
+ def instruments(events):
84
+ insts = [128 if item.instrument == 'drum' else int(item.instrument) for item in events[1:-1] if item.name == 'Note']
85
+ insts = np.bincount(insts, minlength=129)
86
+ return (insts > 0).astype(int)
87
+
88
+ def chords(events):
89
+ chords = [CHORDS.index(item) for item in events]
90
+ chords = np.bincount(chords, minlength=129)
91
+ return (chords > 0).astype(int)
92
+
93
+ def chroma(events):
94
+ pitch_classes = [item.pitch % 12 for item in events[1:-1] if item.name == 'Note' and item.instrument != 'drum']
95
+ if len(pitch_classes):
96
+ count = np.bincount(pitch_classes, minlength=12)
97
+ count = count / np.sqrt(np.sum(count ** 2))
98
+ else:
99
+ count = np.array([1/12] * 12)
100
+ return count
101
+
102
+ def groove(events, start=0, pos_per_bar=48, ticks_per_bar=1920):
103
+ flags = np.linspace(start, start + ticks_per_bar, pos_per_bar, endpoint=False)
104
+ onsets = [item.start for item in events[1:-1] if item.name == 'Note']
105
+ positions = [np.argmin(np.abs(flags - beat)) for beat in onsets]
106
+ if len(positions):
107
+ count = np.bincount(positions, minlength=pos_per_bar)
108
+ count = np.convolve(count, [1, 4, 1], 'same')
109
+ count = count / np.sqrt(np.sum(count ** 2))
110
+ else:
111
+ count = np.array([1/pos_per_bar] * pos_per_bar)
112
+ return count
113
+
114
+ def multi_class_accuracy(y_true, y_pred):
115
+ tp = ((y_true == 1) & (y_pred == 1)).sum()
116
+ p = tp / y_pred.sum()
117
+ r = tp / y_true.sum()
118
+ if p + r > 0:
119
+ f1 = 2*p*r / (p + r)
120
+ else:
121
+ f1 = 0
122
+ return p, r, f1
123
+
124
+ def cross_entropy(p_true, p_pred, eps=1e-8):
125
+ return -np.sum(p_true * np.log(p_pred + eps)) / len(p_true)
126
+
127
+ def kl_divergence(p_true, p_pred, eps=1e-8):
128
+ return np.sum(p_true * (np.log(p_true + eps) - np.log(p_pred + eps))) / len(p_true)
129
+
130
+ def cosine_sim(p_true, p_pred):
131
+ return np.sum(p_true * p_pred)
132
+
133
+ def sliding_window_metrics(items, start, end, window=1920, step=480, ticks_per_beat=480):
134
+ glob_start, glob_end = start, end
135
+ notes = [item for item in items if item.name == 'Note']
136
+ starts = np.arange(glob_start, glob_end - window, step=step)
137
+
138
+ groups = []
139
+ start_idx, end_idx = 0, 0
140
+ for start in starts:
141
+ while notes[start_idx].start < start:
142
+ start_idx += 1
143
+ while end_idx < len(notes) and notes[end_idx].start < start + window:
144
+ end_idx += 1
145
+
146
+ groups.append([start] + notes[start_idx:end_idx] + [start + window])
147
+ return groups
148
+
149
+ def meta_stats(group, ticks_per_beat=480):
150
+ start, end = group[0], group[-1]
151
+ ns = [item for item in group[1:-1] if item.name == 'Note']
152
+ ns_ = [note for note in ns if note.instrument != 'drum']
153
+ pitches = [note.pitch for note in ns_]
154
+ vels = [note.velocity for note in ns_]
155
+ durs = [(note.end - note.start) / ticks_per_beat for note in ns_]
156
+
157
+ return {
158
+ 'note_density': len(ns) / ((end - start) / ticks_per_beat),
159
+ 'pitch_mean': np.mean(pitches) if len(pitches) else np.nan,
160
+ 'velocity_mean': np.mean(vels) if len(vels) else np.nan,
161
+ 'duration_mean': np.mean(durs) if len(durs) else np.nan,
162
+ 'pitch_std': np.std(pitches) if len(pitches) else np.nan,
163
+ 'velocity_std': np.std(vels) if len(vels) else np.nan,
164
+ 'duration_std': np.std(durs) if len(durs) else np.nan,
165
+ }
166
+
167
+ def overlapping_area(mu1, sigma1, mu2, sigma2, eps=0.01):
168
+ sigma1, sigma2 = max(eps, sigma1), max(eps, sigma2)
169
+ return NormalDist(mu=mu1, sigma=sigma1).overlap(NormalDist(mu=mu2, sigma=sigma2))
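+
+ # Overlapping-area intuition (illustrative): identical note statistics give an
+ # OA of 1.0, while well-separated distributions approach 0.
+ #   overlapping_area(60.0, 5.0, 60.0, 5.0)   # -> 1.0
+ #   overlapping_area(60.0, 5.0, 90.0, 5.0)   # -> roughly 0.003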
170
+
171
+
172
+
173
+ def main():
174
+ file_groups = get_file_groups(SAMPLE_DIR)
175
+
176
+ metrics = pd.DataFrame()
177
+ for sample_id, group in enumerate(file_groups):
178
+ print(f"[info] Group {sample_id + 1}/{len(file_groups)}")
179
+ micro_metrics = pd.DataFrame()
180
+ for orig_file, sample_file in group:
181
+ print(f"original: {orig_file.split('/')[-1]} | sample: {sample_file.split('/')[-1]}")
182
+ orig = ir.InputRepresentation(orig_file)
183
+ sample = ir.InputRepresentation(sample_file)
184
+
185
+ orig_desc, sample_desc = orig.get_description(), sample.get_description()
186
+ if len(orig_desc) == 0 or len(sample_desc) == 0:
187
+ print("[warning] empty sample! skipping")
188
+ continue
189
+
190
+ chord_groups1 = get_chord_groups(orig_desc)
191
+ chord_groups2 = get_chord_groups(sample_desc)
192
+
193
+ note_density_gt = []
194
+
195
+ for g1, g2, cg1, cg2 in zip(orig.groups, sample.groups, chord_groups1, chord_groups2):
196
+ row = pd.DataFrame([{ 'id': sample_id, 'original': orig_file.split('/')[-1], 'sample': sample_file.split('/')[-1]}])
197
+
198
+ meta1, meta2 = meta_stats(g1, ticks_per_beat=orig.pm.resolution), meta_stats(g2, ticks_per_beat=sample.pm.resolution)
199
+ row['pitch_oa'] = overlapping_area(meta1['pitch_mean'], meta1['pitch_std'], meta2['pitch_mean'], meta2['pitch_std'])
200
+ row['velocity_oa'] = overlapping_area(meta1['velocity_mean'], meta1['velocity_std'], meta2['velocity_mean'], meta2['velocity_std'])
201
+ row['duration_oa'] = overlapping_area(meta1['duration_mean'], meta1['duration_std'], meta2['duration_mean'], meta2['duration_std'])
202
+ row['note_density_abs_err'] = np.abs(meta1['note_density'] - meta2['note_density'])
203
+ row['mean_pitch_abs_err'] = np.abs(meta1['pitch_mean'] - meta2['pitch_mean'])
204
+ row['mean_velocity_abs_err'] = np.abs(meta1['velocity_mean'] - meta2['velocity_mean'])
205
+ row['mean_duration_abs_err'] = np.abs(meta1['duration_mean'] - meta2['duration_mean'])
206
+ note_density_gt.append(meta1['note_density'])
207
+
208
+ ts1, ts2 = orig._get_time_signature(g1[0]), sample._get_time_signature(g2[0])
209
+ ts1, ts2 = f"{ts1.numerator}/{ts1.denominator}", f"{ts2.numerator}/{ts2.denominator}"
210
+ row['time_sig_acc'] = 1 if ts1 == ts2 else 0
211
+
212
+ inst1, inst2 = instruments(g1), instruments(g2)
213
+ prec, rec, f1 = multi_class_accuracy(inst1, inst2)
214
+ row['inst_prec'] = prec
215
+ row['inst_rec'] = rec
216
+ row['inst_f1'] = f1
217
+
218
+ chords1, chords2 = chords(cg1), chords(cg2)
219
+ prec, rec, f1 = multi_class_accuracy(chords1, chords2)
220
+ row['chord_prec'] = prec
221
+ row['chord_rec'] = rec
222
+ row['chord_f1'] = f1
223
+
224
+ c1, c2 = chroma(g1), chroma(g2)
225
+ row['chroma_crossent'] = cross_entropy(c1, c2)
226
+ row['chroma_kldiv'] = kl_divergence(c1, c2)
227
+ row['chroma_sim'] = cosine_sim(c1, c2)
228
+
229
+ ppb = max(orig._get_positions_per_bar(g1[0]), sample._get_positions_per_bar(g2[0]))
230
+ tpb = max(orig._get_ticks_per_bar(g1[0]), sample._get_ticks_per_bar(g2[0]))
231
+ r1 = groove(g1, start=g1[0], pos_per_bar=ppb, ticks_per_bar=tpb)
232
+ r2 = groove(g2, start=g2[0], pos_per_bar=ppb, ticks_per_bar=tpb)
233
+ row['groove_crossent'] = cross_entropy(r1, r2)
234
+ row['groove_kldiv'] = kl_divergence(r1, r2)
235
+ row['groove_sim'] = cosine_sim(r1, r2)
236
+
237
+ micro_metrics = pd.concat([micro_metrics, row], ignore_index=True)
238
+ if len(micro_metrics) == 0:
239
+ continue
240
+
241
+ nd_mean = np.mean(note_density_gt)
242
+ micro_metrics['note_density_nsq_err'] = micro_metrics['note_density_abs_err']**2 / nd_mean**2
243
+
244
+ metrics = pd.concat([metrics, micro_metrics], ignore_index=True)
245
+
246
+ micro_avg = micro_metrics.mean(numeric_only=True)
247
+ print("[info] Group {}: inst_f1={:.2f} | chord_f1={:.2f} | pitch_oa={:.2f} | vel_oa={:.2f} | dur_oa={:.2f} | chroma_sim={:.2f} | groove_sim={:.2f}".format(
248
+ sample_id+1, micro_avg['inst_f1'], micro_avg['chord_f1'], micro_avg['pitch_oa'], micro_avg['velocity_oa'], micro_avg['duration_oa'], micro_avg['chroma_sim'], micro_avg['groove_sim']
249
+ ))
250
+
251
+ os.makedirs(os.path.dirname(OUT_FILE), exist_ok=True)
252
+ # metrics.to_csv(OUT_FILE, index=False)
253
+
254
+ summary_keys = ['inst_f1', 'chord_f1', 'time_sig_acc', 'pitch_oa', 'velocity_oa', 'duration_oa', 'chroma_sim', 'groove_sim']
255
+ summary = metrics[summary_keys + ['id']].groupby('id').mean().mean()
256
+
257
+ nsq_err = metrics.groupby('id')['note_density_nsq_err'].mean()
258
+ summary['note_density_nrmse'] = np.sqrt(nsq_err).mean()
259
+
260
+ print('***** SUMMARY *****')
261
+ print(summary)
262
+
263
+ summary.to_frame().T.to_csv(OUT_FILE, index=False)
264
+
265
+ print("done")
266
+
267
+ if __name__ == '__main__':
268
+ main()
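The block above scores each generated bar against its ground-truth bar by fitting a normal distribution to per-bar pitch, velocity and duration statistics and measuring how much the two densities overlap. Below is a minimal, self-contained sketch of that overlapping-area idea; the numeric inputs are invented for illustration only.

```python
# Standalone sketch of the overlapping-area (OA) metric used in evaluate.py:
# fit a normal distribution to a per-bar feature of the original and of the
# generated bar, then measure the overlap of the two densities (1.0 = identical).
from statistics import NormalDist

def overlapping_area(mu1, sigma1, mu2, sigma2, eps=0.01):
    # Clamp the standard deviations so NormalDist never receives sigma == 0.
    sigma1, sigma2 = max(eps, sigma1), max(eps, sigma2)
    return NormalDist(mu=mu1, sigma=sigma1).overlap(NormalDist(mu=mu2, sigma=sigma2))

# Invented example values: mean/std of pitches in one original vs. one generated bar.
print(overlapping_area(64.0, 5.0, 66.0, 4.0))  # high overlap  -> similar pitch content
print(overlapping_area(64.0, 5.0, 80.0, 4.0))  # low overlap   -> very different pitch content
```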
music_evaluation/figaro/input_representation.py ADDED
@@ -0,0 +1,655 @@
+ from chord_recognition import MIDIChord
2
+ import numpy as np
3
+ import pretty_midi
4
+
5
+ from vocab import RemiVocab
6
+
7
+ from constants import (
8
+ EOS_TOKEN,
9
+ # vocab keys
10
+ TIME_SIGNATURE_KEY,
11
+ BAR_KEY,
12
+ POSITION_KEY,
13
+ INSTRUMENT_KEY,
14
+ PITCH_KEY,
15
+ VELOCITY_KEY,
16
+ DURATION_KEY,
17
+ TEMPO_KEY,
18
+ CHORD_KEY,
19
+ NOTE_DENSITY_KEY,
20
+ MEAN_PITCH_KEY,
21
+ MEAN_VELOCITY_KEY,
22
+ MEAN_DURATION_KEY,
23
+ # discretization parameters
24
+ DEFAULT_POS_PER_QUARTER,
25
+ DEFAULT_VELOCITY_BINS,
26
+ DEFAULT_DURATION_BINS,
27
+ DEFAULT_TEMPO_BINS,
28
+ DEFAULT_NOTE_DENSITY_BINS,
29
+ DEFAULT_MEAN_VELOCITY_BINS,
30
+ DEFAULT_MEAN_PITCH_BINS,
31
+ DEFAULT_MEAN_DURATION_BINS,
32
+ DEFAULT_RESOLUTION
33
+ )
34
+
35
+ # define "Item" for general storage
36
+ class Item(object):
37
+ def __init__(self, name, start, end, velocity=None, pitch=None, instrument=None):
38
+ self.name = name
39
+ self.start = start
40
+ self.end = end
41
+ self.velocity = velocity
42
+ self.pitch = pitch
43
+ self.instrument = instrument
44
+
45
+ def __repr__(self):
46
+ return 'Item(name={}, start={}, end={}, velocity={}, pitch={}, instrument={})'.format(
47
+ self.name, self.start, self.end, self.velocity, self.pitch, self.instrument)
48
+
49
+ # define "Event" for event storage
50
+ class Event(object):
51
+ def __init__(self, name, time, value, text):
52
+ self.name = name
53
+ self.time = time
54
+ self.value = value
55
+ self.text = text
56
+
57
+ def __repr__(self):
58
+ return 'Event(name={}, time={}, value={}, text={})'.format(
59
+ self.name, self.time, self.value, self.text)
60
+
61
+ class InputRepresentation():
62
+ def version():
63
+ return 'v4'
64
+
65
+ def __init__(self, file, do_extract_chords=True, strict=False):
66
+ if isinstance(file, pretty_midi.PrettyMIDI):
67
+ self.pm = file
68
+ else:
69
+ self.pm = pretty_midi.PrettyMIDI(file)
70
+
71
+ if strict and len(self.pm.time_signature_changes) == 0:
72
+ raise ValueError("Invalid MIDI file: No time signature defined")
73
+
74
+ self.resolution = self.pm.resolution
75
+
76
+ self.note_items = None
77
+ self.tempo_items = None
78
+ self.chords = None
79
+ self.groups = None
80
+
81
+ self._read_items()
82
+ self._quantize_items()
83
+ if do_extract_chords:
84
+ self.extract_chords()
85
+ self._group_items()
86
+
87
+ if strict and len(self.note_items) == 0:
88
+ raise ValueError("Invalid MIDI file: No notes found, empty file.")
89
+
90
+ # read notes and tempo changes from midi (assume there is only one track)
91
+ def _read_items(self):
92
+ # note
93
+ self.note_items = []
94
+ for instrument in self.pm.instruments:
95
+ pedal_events = [event for event in instrument.control_changes if event.number == 64]
96
+ pedal_pressed = False
97
+ start = None
98
+ pedals = []
99
+ for e in pedal_events:
100
+ if e.value >= 64 and not pedal_pressed:
101
+ pedal_pressed = True
102
+ start = e.time
103
+ elif e.value < 64 and pedal_pressed:
104
+ pedal_pressed = False
105
+ pedals.append(Item(name='Pedal', start=start, end=e.time))
106
+ start = e.time
107
+
108
+ notes = instrument.notes
109
+ notes.sort(key=lambda x: (x.start, x.pitch))
110
+
111
+ if instrument.is_drum:
112
+ instrument_name = 'drum'
113
+ else:
114
+ instrument_name = instrument.program
115
+
116
+ pedal_idx = 0
117
+ for note in notes:
118
+ pedal_candidates = [(i + pedal_idx, pedal) for i, pedal in enumerate(pedals[pedal_idx:]) if note.end >= pedal.start and note.start < pedal.end]
119
+ if len(pedal_candidates) > 0:
120
+ pedal_idx = pedal_candidates[0][0]
121
+ pedal = pedal_candidates[-1][1]
122
+ else:
123
+ pedal = Item(name='Pedal', start=0, end=0)
124
+
125
+ self.note_items.append(Item(
126
+ name='Note',
127
+ start=self.pm.time_to_tick(note.start),
128
+ end=self.pm.time_to_tick(max(note.end, pedal.end)),
129
+ velocity=note.velocity,
130
+ pitch=note.pitch,
131
+ instrument=instrument_name))
132
+ self.note_items.sort(key=lambda x: (x.start, x.pitch))
133
+ # tempo
134
+ self.tempo_items = []
135
+ times, tempi = self.pm.get_tempo_changes()
136
+ for time, tempo in zip(times, tempi):
137
+ self.tempo_items.append(Item(
138
+ name='Tempo',
139
+ start=time,
140
+ end=None,
141
+ velocity=None,
142
+ pitch=int(tempo)))
143
+ self.tempo_items.sort(key=lambda x: x.start)
144
+ # expand to all beat
145
+ max_tick = self.pm.time_to_tick(self.pm.get_end_time())
146
+ existing_ticks = {item.start: item.pitch for item in self.tempo_items}
147
+ wanted_ticks = np.arange(0, max_tick+1, DEFAULT_RESOLUTION)
148
+ output = []
149
+ for tick in wanted_ticks:
150
+ if tick in existing_ticks:
151
+ output.append(Item(
152
+ name='Tempo',
153
+ start=self.pm.time_to_tick(tick),
154
+ end=None,
155
+ velocity=None,
156
+ pitch=existing_ticks[tick]))
157
+ else:
158
+ output.append(Item(
159
+ name='Tempo',
160
+ start=self.pm.time_to_tick(tick),
161
+ end=None,
162
+ velocity=None,
163
+ pitch=output[-1].pitch))
164
+ self.tempo_items = output
165
+
166
+ # quantize items
167
+ def _quantize_items(self):
168
+ ticks = self.resolution / DEFAULT_POS_PER_QUARTER
169
+ # grid
170
+ end_tick = self.pm.time_to_tick(self.pm.get_end_time())
171
+ grids = np.arange(0, max(self.resolution, end_tick), ticks)
172
+ # process
173
+ for item in self.note_items:
174
+ index = np.searchsorted(grids, item.start, side='right')
175
+ if index > 0:
176
+ index -= 1
177
+ shift = round(grids[index]) - item.start
178
+ item.start += shift
179
+ item.end += shift
180
+
181
+ def get_end_tick(self):
182
+ return self.pm.time_to_tick(self.pm.get_end_time())
183
+
184
+ # extract chord
185
+ def extract_chords(self):
186
+ end_tick = self.pm.time_to_tick(self.pm.get_end_time())
187
+ if end_tick < self.resolution:
188
+ # If sequence is shorter than 1/4th note, it's probably empty
189
+ self.chords = []
190
+ return self.chords
191
+ method = MIDIChord(self.pm)
192
+ chords = method.extract()
193
+ output = []
194
+ for chord in chords:
195
+ output.append(Item(
196
+ name='Chord',
197
+ start=self.pm.time_to_tick(chord[0]),
198
+ end=self.pm.time_to_tick(chord[1]),
199
+ velocity=None,
200
+ pitch=chord[2].split('/')[0]))
201
+ if len(output) == 0 or output[0].start > 0:
202
+ if len(output) == 0:
203
+ end = self.pm.time_to_tick(self.pm.get_end_time())
204
+ else:
205
+ end = output[0].start
206
+ output.append(Item(
207
+ name='Chord',
208
+ start=0,
209
+ end=end,
210
+ velocity=None,
211
+ pitch='N:N'
212
+ ))
213
+ self.chords = output
214
+ return self.chords
215
+
216
+ # group items
217
+ def _group_items(self):
218
+ if self.chords:
219
+ items = self.chords + self.tempo_items + self.note_items
220
+ else:
221
+ items = self.tempo_items + self.note_items
222
+
223
+ def _get_key(item):
224
+ type_priority = {
225
+ 'Chord': 0,
226
+ 'Tempo': 1,
227
+ 'Note': 2
228
+ }
229
+ return (
230
+ item.start, # order by time
231
+ type_priority[item.name], # chord events first, then tempo events, then note events
232
+ -1 if item.instrument == 'drum' else item.instrument, # order by instrument
233
+ item.pitch # order by note pitch
234
+ )
235
+
236
+ items.sort(key=_get_key)
237
+ downbeats = self.pm.get_downbeats()
238
+ downbeats = np.concatenate([downbeats, [self.pm.get_end_time()]])
239
+ self.groups = []
240
+ for db1, db2 in zip(downbeats[:-1], downbeats[1:]):
241
+ db1, db2 = self.pm.time_to_tick(db1), self.pm.time_to_tick(db2)
242
+ insiders = []
243
+ for item in items:
244
+ if (item.start >= db1) and (item.start < db2):
245
+ insiders.append(item)
246
+ overall = [db1] + insiders + [db2]
247
+ self.groups.append(overall)
248
+
249
+ # Trim empty groups from the beginning and end
250
+ for idx in [0, -1]:
251
+ while len(self.groups) > 0:
252
+ group = self.groups[idx]
253
+ notes = [item for item in group[1:-1] if item.name == 'Note']
254
+ if len(notes) == 0:
255
+ self.groups.pop(idx)
256
+ else:
257
+ break
258
+
259
+ return self.groups
260
+
261
+ def _get_time_signature(self, start):
262
+ # This method assumes that time signature changes don't happen within a bar
263
+ # which is a convention that commonly holds
264
+ time_sig = None
265
+ for curr_sig, next_sig in zip(self.pm.time_signature_changes[:-1], self.pm.time_signature_changes[1:]):
266
+ if self.pm.time_to_tick(curr_sig.time) <= start and self.pm.time_to_tick(next_sig.time) > start:
267
+ time_sig = curr_sig
268
+ break
269
+ if time_sig is None:
270
+ time_sig = self.pm.time_signature_changes[-1]
271
+ return time_sig
272
+
273
+ def _get_ticks_per_bar(self, start):
274
+ time_sig = self._get_time_signature(start)
275
+ quarters_per_bar = 4 * time_sig.numerator / time_sig.denominator
276
+ return self.pm.resolution * quarters_per_bar
277
+
278
+ def _get_positions_per_bar(self, start=None, time_sig=None):
279
+ if time_sig is None:
280
+ time_sig = self._get_time_signature(start)
281
+ quarters_per_bar = 4 * time_sig.numerator / time_sig.denominator
282
+ positions_per_bar = int(DEFAULT_POS_PER_QUARTER * quarters_per_bar)
283
+ return positions_per_bar
284
+
285
+ def tick_to_position(self, tick):
286
+ return round(tick / self.pm.resolution * DEFAULT_POS_PER_QUARTER)
287
+
288
+ # item to event
289
+ def get_remi_events(self):
290
+ events = []
291
+ n_downbeat = 0
292
+ current_chord = None
293
+ current_tempo = None
294
+ for i in range(len(self.groups)):
295
+ bar_st, bar_et = self.groups[i][0], self.groups[i][-1]
296
+ n_downbeat += 1
297
+ positions_per_bar = self._get_positions_per_bar(bar_st)
298
+ if positions_per_bar <= 0:
299
+ raise ValueError('Invalid REMI file: There must be at least 1 position per bar.')
300
+
301
+ events.append(Event(
302
+ name=BAR_KEY,
303
+ time=None,
304
+ value='{}'.format(n_downbeat),
305
+ text='{}'.format(n_downbeat)))
306
+
307
+ time_sig = self._get_time_signature(bar_st)
308
+ events.append(Event(
309
+ name=TIME_SIGNATURE_KEY,
310
+ time=None,
311
+ value='{}/{}'.format(time_sig.numerator, time_sig.denominator),
312
+ text='{}/{}'.format(time_sig.numerator, time_sig.denominator)
313
+ ))
314
+
315
+ if current_chord is not None:
316
+ events.append(Event(
317
+ name=POSITION_KEY,
318
+ time=0,
319
+ value='{}'.format(0),
320
+ text='{}/{}'.format(1, positions_per_bar)))
321
+ events.append(Event(
322
+ name=CHORD_KEY,
323
+ time=current_chord.start,
324
+ value=current_chord.pitch,
325
+ text='{}'.format(current_chord.pitch)))
326
+
327
+ if current_tempo is not None:
328
+ events.append(Event(
329
+ name=POSITION_KEY,
330
+ time=0,
331
+ value='{}'.format(0),
332
+ text='{}/{}'.format(1, positions_per_bar)))
333
+ tempo = current_tempo.pitch
334
+ index = np.argmin(abs(DEFAULT_TEMPO_BINS-tempo))
335
+ events.append(Event(
336
+ name=TEMPO_KEY,
337
+ time=current_tempo.start,
338
+ value=index,
339
+ text='{}/{}'.format(tempo, DEFAULT_TEMPO_BINS[index])))
340
+
341
+ quarters_per_bar = 4 * time_sig.numerator / time_sig.denominator
342
+ ticks_per_bar = self.pm.resolution * quarters_per_bar
343
+ flags = np.linspace(bar_st, bar_st + ticks_per_bar, positions_per_bar, endpoint=False)
344
+ for item in self.groups[i][1:-1]:
345
+ # position
346
+ index = np.argmin(abs(flags-item.start))
347
+ pos_event = Event(
348
+ name=POSITION_KEY,
349
+ time=item.start,
350
+ value='{}'.format(index),
351
+ text='{}/{}'.format(index+1, positions_per_bar))
352
+
353
+ if item.name == 'Note':
354
+ events.append(pos_event)
355
+ # instrument
356
+ if item.instrument == 'drum':
357
+ name = 'drum'
358
+ else:
359
+ name = pretty_midi.program_to_instrument_name(item.instrument)
360
+ events.append(Event(
361
+ name=INSTRUMENT_KEY,
362
+ time=item.start,
363
+ value=name,
364
+ text='{}'.format(name)))
365
+ # pitch
366
+ events.append(Event(
367
+ name=PITCH_KEY,
368
+ time=item.start,
369
+ value='drum_{}'.format(item.pitch) if name == 'drum' else item.pitch,
370
+ text='{}'.format(pretty_midi.note_number_to_name(item.pitch))))
371
+ # velocity
372
+ velocity_index = np.argmin(abs(DEFAULT_VELOCITY_BINS - item.velocity))
373
+ events.append(Event(
374
+ name=VELOCITY_KEY,
375
+ time=item.start,
376
+ value=velocity_index,
377
+ text='{}/{}'.format(item.velocity, DEFAULT_VELOCITY_BINS[velocity_index])))
378
+ # duration
379
+ duration = self.tick_to_position(item.end - item.start)
380
+ index = np.argmin(abs(DEFAULT_DURATION_BINS-duration))
381
+ events.append(Event(
382
+ name=DURATION_KEY,
383
+ time=item.start,
384
+ value=index,
385
+ text='{}/{}'.format(duration, DEFAULT_DURATION_BINS[index])))
386
+ elif item.name == 'Chord':
387
+ if current_chord is None or item.pitch != current_chord.pitch:
388
+ events.append(pos_event)
389
+ events.append(Event(
390
+ name=CHORD_KEY,
391
+ time=item.start,
392
+ value=item.pitch,
393
+ text='{}'.format(item.pitch)))
394
+ current_chord = item
395
+ elif item.name == 'Tempo':
396
+ if current_tempo is None or item.pitch != current_tempo.pitch:
397
+ events.append(pos_event)
398
+ tempo = item.pitch
399
+ index = np.argmin(abs(DEFAULT_TEMPO_BINS-tempo))
400
+ events.append(Event(
401
+ name=TEMPO_KEY,
402
+ time=item.start,
403
+ value=index,
404
+ text='{}/{}'.format(tempo, DEFAULT_TEMPO_BINS[index])))
405
+ current_tempo = item
406
+
407
+ return [f'{e.name}_{e.value}' for e in events]
408
+
409
+ def get_description(self,
410
+ omit_time_sig=False,
411
+ omit_instruments=False,
412
+ omit_chords=False,
413
+ omit_meta=False):
414
+ events = []
415
+ n_downbeat = 0
416
+ current_chord = None
417
+
418
+ for i in range(len(self.groups)):
419
+ bar_st, bar_et = self.groups[i][0], self.groups[i][-1]
420
+ n_downbeat += 1
421
+ time_sig = self._get_time_signature(bar_st)
422
+ positions_per_bar = self._get_positions_per_bar(time_sig=time_sig)
423
+ if positions_per_bar <= 0:
424
+ raise ValueError('Invalid REMI file: There must be at least 1 position in each bar.')
425
+
426
+ events.append(Event(
427
+ name=BAR_KEY,
428
+ time=None,
429
+ value='{}'.format(n_downbeat),
430
+ text='{}'.format(n_downbeat)))
431
+
432
+ if not omit_time_sig:
433
+ events.append(Event(
434
+ name=TIME_SIGNATURE_KEY,
435
+ time=None,
436
+ value='{}/{}'.format(time_sig.numerator, time_sig.denominator),
437
+ text='{}/{}'.format(time_sig.numerator, time_sig.denominator),
438
+ ))
439
+
440
+ if not omit_meta:
441
+ notes = [item for item in self.groups[i][1:-1] if item.name == 'Note']
442
+ n_notes = len(notes)
443
+ velocities = np.array([item.velocity for item in notes])
444
+ pitches = np.array([item.pitch for item in notes])
445
+ durations = np.array([item.end - item.start for item in notes])
446
+
447
+ note_density = n_notes/positions_per_bar
448
+ index = np.argmin(abs(DEFAULT_NOTE_DENSITY_BINS-note_density))
449
+ events.append(Event(
450
+ name=NOTE_DENSITY_KEY,
451
+ time=None,
452
+ value=index,
453
+ text='{:.2f}/{:.2f}'.format(note_density, DEFAULT_NOTE_DENSITY_BINS[index])
454
+ ))
455
+
456
+ # will be NaN if there are no notes
+ mean_velocity = velocities.mean() if len(velocities) > 0 else np.nan
+ index = np.argmin(abs(DEFAULT_MEAN_VELOCITY_BINS-mean_velocity))
+ events.append(Event(
+ name=MEAN_VELOCITY_KEY,
+ time=None,
+ value=index if not np.isnan(mean_velocity) else 'NaN',
+ text='{:.2f}/{:.2f}'.format(mean_velocity, DEFAULT_MEAN_VELOCITY_BINS[index])
+ ))
+
+ # will be NaN if there are no notes
+ mean_pitch = pitches.mean() if len(pitches) > 0 else np.nan
+ index = np.argmin(abs(DEFAULT_MEAN_PITCH_BINS-mean_pitch))
+ events.append(Event(
+ name=MEAN_PITCH_KEY,
+ time=None,
+ value=index if not np.isnan(mean_pitch) else 'NaN',
+ text='{:.2f}/{:.2f}'.format(mean_pitch, DEFAULT_MEAN_PITCH_BINS[index])
+ ))
+
+ # will be NaN if there are no notes
+ mean_duration = durations.mean() if len(durations) > 0 else np.nan
+ index = np.argmin(abs(DEFAULT_MEAN_DURATION_BINS-mean_duration))
+ events.append(Event(
+ name=MEAN_DURATION_KEY,
+ time=None,
+ value=index if not np.isnan(mean_duration) else 'NaN',
+ text='{:.2f}/{:.2f}'.format(mean_duration, DEFAULT_MEAN_DURATION_BINS[index])
+ ))
485
+
486
+ if not omit_instruments:
487
+ instruments = set([item.instrument for item in notes])
488
+ for instrument in instruments:
489
+ instrument = pretty_midi.program_to_instrument_name(instrument) if instrument != 'drum' else 'drum'
490
+ events.append(Event(
491
+ name=INSTRUMENT_KEY,
492
+ time=None,
493
+ value=instrument,
494
+ text=instrument
495
+ ))
496
+
497
+ if not omit_chords:
498
+ chords = [item for item in self.groups[i][1:-1] if item.name == 'Chord']
499
+ if len(chords) == 0 and current_chord is not None:
500
+ chords = [current_chord]
501
+ elif len(chords) > 0:
502
+ if chords[0].start > bar_st and current_chord is not None:
503
+ chords.insert(0, current_chord)
504
+ current_chord = chords[-1]
505
+
506
+ for chord in chords:
507
+ events.append(Event(
508
+ name=CHORD_KEY,
509
+ time=None,
510
+ value=chord.pitch,
511
+ text='{}'.format(chord.pitch)
512
+ ))
513
+
514
+ return [f'{e.name}_{e.value}' for e in events]
515
+
516
+
517
+ #############################################################################################
518
+ # WRITE MIDI
519
+ #############################################################################################
520
+
521
+ def remi2midi(events, bpm=120, time_signature=(4, 4), polyphony_limit=16):
522
+ vocab = RemiVocab()
523
+
524
+ def _get_time(bar, position, bpm=120, positions_per_bar=48):
525
+ abs_position = bar*positions_per_bar + position
526
+ beat = abs_position / DEFAULT_POS_PER_QUARTER
527
+ return beat/bpm*60
528
+
529
+ def _get_time(reference, bar, pos):
530
+ time_sig = reference['time_sig']
531
+ num, denom = time_sig.numerator, time_sig.denominator
532
+ # Quarters per bar, assuming 4 quarters per whole note
533
+ qpb = 4 * num / denom
534
+ ref_pos = reference['pos']
535
+ d_bars = bar - ref_pos[0]
536
+ d_pos = (pos - ref_pos[1]) + d_bars*qpb*DEFAULT_POS_PER_QUARTER
537
+ d_quarters = d_pos / DEFAULT_POS_PER_QUARTER
538
+ # Convert quarters to seconds
539
+ dt = d_quarters / reference['tempo'] * 60
540
+ return reference['time'] + dt
541
+
542
+ # time_sigs = [event.split('_')[-1].split('/') for event in events if f"{TIME_SIGNATURE_KEY}_" in event]
543
+ # time_sigs = [(int(num), int(denom)) for num, denom in time_sigs]
544
+
545
+ tempo_changes = [event for event in events if f"{TEMPO_KEY}_" in event]
546
+ if len(tempo_changes) > 0:
547
+ bpm = DEFAULT_TEMPO_BINS[int(tempo_changes[0].split('_')[-1])]
548
+
549
+ pm = pretty_midi.PrettyMIDI(initial_tempo=bpm)
550
+ num, denom = time_signature
551
+ pm.time_signature_changes.append(pretty_midi.TimeSignature(num, denom, 0))
552
+ current_time_sig = pm.time_signature_changes[0]
553
+
554
+ instruments = {}
555
+
556
+ # Use implicit timeline: keep track of last tempo/time signature change event
557
+ # and calculate time difference relative to that
558
+ last_tl_event = {
559
+ 'time': 0,
560
+ 'pos': (0, 0),
561
+ 'time_sig': current_time_sig,
562
+ 'tempo': bpm
563
+ }
564
+
565
+ bar = -1
566
+ n_notes = 0
567
+ polyphony_control = {}
568
+ for i, event in enumerate(events):
569
+ if event == EOS_TOKEN:
570
+ break
571
+
572
+ if not bar in polyphony_control:
573
+ polyphony_control[bar] = {}
574
+
575
+ if f"{BAR_KEY}_" in events[i]:
576
+ # Next bar is starting
577
+ bar += 1
578
+ polyphony_control[bar] = {}
579
+
580
+ if i+1 < len(events) and f"{TIME_SIGNATURE_KEY}_" in events[i+1]:
581
+ num, denom = events[i+1].split('_')[-1].split('/')
582
+ num, denom = int(num), int(denom)
583
+ current_time_sig = last_tl_event['time_sig']
584
+ if num != current_time_sig.numerator or denom != current_time_sig.denominator:
585
+ time = _get_time(last_tl_event, bar, 0)
586
+ time_sig = pretty_midi.TimeSignature(num, denom, time)
587
+ pm.time_signature_changes.append(time_sig)
588
+ last_tl_event['time'] = time
589
+ last_tl_event['pos'] = (bar, 0)
590
+ last_tl_event['time_sig'] = time_sig
591
+
592
+ elif i+1 < len(events) and \
593
+ f"{POSITION_KEY}_" in events[i] and \
594
+ f"{TEMPO_KEY}_" in events[i+1]:
595
+ position = int(events[i].split('_')[-1])
596
+ tempo_idx = int(events[i+1].split('_')[-1])
597
+ tempo = DEFAULT_TEMPO_BINS[tempo_idx]
598
+
599
+ if tempo != last_tl_event['tempo']:
600
+ time = _get_time(last_tl_event, bar, position)
601
+ last_tl_event['time'] = time
602
+ last_tl_event['pos'] = (bar, position)
603
+ last_tl_event['tempo'] = tempo
604
+
605
+ elif i+4 < len(events) and \
606
+ f"{POSITION_KEY}_" in events[i] and \
607
+ f"{INSTRUMENT_KEY}_" in events[i+1] and \
608
+ f"{PITCH_KEY}_" in events[i+2] and \
609
+ f"{VELOCITY_KEY}_" in events[i+3] and \
610
+ f"{DURATION_KEY}_" in events[i+4]:
611
+ # get position
612
+ position = int(events[i].split('_')[-1])
613
+ if not position in polyphony_control[bar]:
614
+ polyphony_control[bar][position] = {}
615
+
616
+ # get instrument
617
+ instrument_name = events[i+1].split('_')[-1]
618
+ if instrument_name not in polyphony_control[bar][position]:
619
+ polyphony_control[bar][position][instrument_name] = 0
620
+ elif polyphony_control[bar][position][instrument_name] >= polyphony_limit:
621
+ # If number of notes exceeds polyphony limit, omit this note
622
+ continue
623
+
624
+ if instrument_name not in instruments:
625
+ if instrument_name == 'drum':
626
+ instrument = pretty_midi.Instrument(0, is_drum=True)
627
+ else:
628
+ program = pretty_midi.instrument_name_to_program(instrument_name)
629
+ instrument = pretty_midi.Instrument(program)
630
+ instruments[instrument_name] = instrument
631
+ else:
632
+ instrument = instruments[instrument_name]
633
+
634
+ # get pitch
635
+ pitch = int(events[i+2].split('_')[-1])
636
+ # get velocity
637
+ velocity_index = int(events[i+3].split('_')[-1])
638
+ velocity = min(127, DEFAULT_VELOCITY_BINS[velocity_index])
639
+ # get duration
640
+ duration_index = int(events[i+4].split('_')[-1])
641
+ duration = DEFAULT_DURATION_BINS[duration_index]
642
+ # create note and add to instrument
643
+ start = _get_time(last_tl_event, bar, position)
644
+ end = _get_time(last_tl_event, bar, position + duration)
645
+ note = pretty_midi.Note(velocity=velocity,
646
+ pitch=pitch,
647
+ start=start,
648
+ end=end)
649
+ instrument.notes.append(note)
650
+ n_notes += 1
651
+ polyphony_control[bar][position][instrument_name] += 1
652
+
653
+ for instrument in instruments.values():
654
+ pm.instruments.append(instrument)
655
+ return pm
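A minimal usage sketch of the class and decoder defined above. The module name `input_representation` and the MIDI file path `example.mid` are assumptions made only for this illustration.

```python
# Hypothetical round trip: MIDI -> REMI events -> MIDI (module name and paths assumed).
from input_representation import InputRepresentation, remi2midi

rep = InputRepresentation('example.mid')   # parse, quantize, extract chords, group items by bar
events = rep.get_remi_events()             # flat token strings (bar, position, instrument, pitch, ...)
description = rep.get_description()        # bar-level meta tokens (time signature, densities, chords, ...)

pm = remi2midi(events)                     # decode the REMI token sequence back to a PrettyMIDI object
pm.write('reconstructed.mid')
print(len(events), 'REMI events,', len(description), 'description tokens')
```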
music_evaluation/figaro/vocab.py ADDED
@@ -0,0 +1,166 @@
+ import pretty_midi
2
+ from collections import Counter
3
+ import torchtext
4
+ from torch import Tensor
5
+
6
+ from constants import (
7
+ DEFAULT_VELOCITY_BINS,
8
+ DEFAULT_DURATION_BINS,
9
+ DEFAULT_TEMPO_BINS,
10
+ DEFAULT_POS_PER_QUARTER,
11
+ DEFAULT_NOTE_DENSITY_BINS,
12
+ DEFAULT_MEAN_VELOCITY_BINS,
13
+ DEFAULT_MEAN_PITCH_BINS,
14
+ DEFAULT_MEAN_DURATION_BINS
15
+ )
16
+
17
+
18
+ from constants import (
19
+ MAX_BAR_LENGTH,
20
+ MAX_N_BARS,
21
+
22
+ PAD_TOKEN,
23
+ UNK_TOKEN,
24
+ BOS_TOKEN,
25
+ EOS_TOKEN,
26
+ MASK_TOKEN,
27
+
28
+ TIME_SIGNATURE_KEY,
29
+ BAR_KEY,
30
+ POSITION_KEY,
31
+ INSTRUMENT_KEY,
32
+ PITCH_KEY,
33
+ VELOCITY_KEY,
34
+ DURATION_KEY,
35
+ TEMPO_KEY,
36
+ CHORD_KEY,
37
+
38
+ NOTE_DENSITY_KEY,
39
+ MEAN_PITCH_KEY,
40
+ MEAN_VELOCITY_KEY,
41
+ MEAN_DURATION_KEY,
42
+ )
43
+
44
+
45
+
46
+ class Tokens:
47
+ def get_instrument_tokens(key=INSTRUMENT_KEY):
48
+ tokens = [f'{key}_{pretty_midi.program_to_instrument_name(i)}' for i in range(128)]
49
+ tokens.append(f'{key}_drum')
50
+ return tokens
51
+
52
+ def get_chord_tokens(key=CHORD_KEY, qualities = ['maj', 'min', 'dim', 'aug', 'dom7', 'maj7', 'min7', 'None']):
53
+ pitch_classes = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
54
+
55
+ chords = [f'{root}:{quality}' for root in pitch_classes for quality in qualities]
56
+ chords.append('N:N')
57
+
58
+ tokens = [f'{key}_{chord}' for chord in chords]
59
+ return tokens
60
+
61
+ def get_time_signature_tokens(key=TIME_SIGNATURE_KEY):
62
+ denominators = [2, 4, 8, 16]
63
+ time_sigs = [f'{p}/{q}' for q in denominators for p in range(1, MAX_BAR_LENGTH*q + 1)]
64
+ tokens = [f'{key}_{time_sig}' for time_sig in time_sigs]
65
+ return tokens
66
+
67
+ def get_midi_tokens(
68
+ instrument_key=INSTRUMENT_KEY,
69
+ time_signature_key=TIME_SIGNATURE_KEY,
70
+ pitch_key=PITCH_KEY,
71
+ velocity_key=VELOCITY_KEY,
72
+ duration_key=DURATION_KEY,
73
+ tempo_key=TEMPO_KEY,
74
+ bar_key=BAR_KEY,
75
+ position_key=POSITION_KEY
76
+ ):
77
+ instrument_tokens = Tokens.get_instrument_tokens(instrument_key)
78
+
79
+ pitch_tokens = [f'{pitch_key}_{i}' for i in range(128)] + [f'{pitch_key}_drum_{i}' for i in range(128)]
80
+ velocity_tokens = [f'{velocity_key}_{i}' for i in range(len(DEFAULT_VELOCITY_BINS))]
81
+ duration_tokens = [f'{duration_key}_{i}' for i in range(len(DEFAULT_DURATION_BINS))]
82
+ tempo_tokens = [f'{tempo_key}_{i}' for i in range(len(DEFAULT_TEMPO_BINS))]
83
+ bar_tokens = [f'{bar_key}_{i}' for i in range(MAX_N_BARS)]
84
+ position_tokens = [f'{position_key}_{i}' for i in range(MAX_BAR_LENGTH*4*DEFAULT_POS_PER_QUARTER)]
85
+
86
+ time_sig_tokens = Tokens.get_time_signature_tokens(time_signature_key)
87
+
88
+ return (
89
+ time_sig_tokens +
90
+ tempo_tokens +
91
+ instrument_tokens +
92
+ pitch_tokens +
93
+ velocity_tokens +
94
+ duration_tokens +
95
+ bar_tokens +
96
+ position_tokens
97
+ )
98
+
99
+ class Vocab:
100
+ def __init__(self, counter, specials=[PAD_TOKEN, UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, MASK_TOKEN], unk_token=UNK_TOKEN):
101
+ self.vocab = torchtext.vocab.vocab(counter)
102
+
103
+ self.specials = specials
104
+ for i, token in enumerate(self.specials):
105
+ self.vocab.insert_token(token, i)
106
+
107
+ if unk_token in specials:
108
+ self.vocab.set_default_index(self.vocab.get_stoi()[unk_token])
109
+
110
+ def to_i(self, token):
111
+ return self.vocab.get_stoi()[token]
112
+
113
+ def to_s(self, idx):
114
+ if idx >= len(self.vocab):
115
+ return UNK_TOKEN
116
+ else:
117
+ return self.vocab.get_itos()[idx]
118
+
119
+ def __len__(self):
120
+ return len(self.vocab)
121
+
122
+ def encode(self, seq):
123
+ return self.vocab(seq)
124
+
125
+ def decode(self, seq):
126
+ if isinstance(seq, Tensor):
127
+ seq = seq.numpy()
128
+ return self.vocab.lookup_tokens(seq)
129
+
130
+
131
+ class RemiVocab(Vocab):
132
+ def __init__(self):
133
+ midi_tokens = Tokens.get_midi_tokens()
134
+ chord_tokens = Tokens.get_chord_tokens()
135
+
136
+ self.tokens = midi_tokens + chord_tokens
137
+
138
+ counter = Counter(self.tokens)
139
+ super().__init__(counter)
140
+
141
+
142
+ class DescriptionVocab(Vocab):
143
+ def __init__(self):
144
+ time_sig_tokens = Tokens.get_time_signature_tokens()
145
+ instrument_tokens = Tokens.get_instrument_tokens()
146
+ chord_tokens = Tokens.get_chord_tokens()
147
+
148
+ bar_tokens = [f'Bar_{i}' for i in range(MAX_N_BARS)]
149
+ density_tokens = [f'{NOTE_DENSITY_KEY}_{i}' for i in range(len(DEFAULT_NOTE_DENSITY_BINS))]
150
+ velocity_tokens = [f'{MEAN_VELOCITY_KEY}_{i}' for i in range(len(DEFAULT_MEAN_VELOCITY_BINS))]
151
+ pitch_tokens = [f'{MEAN_PITCH_KEY}_{i}' for i in range(len(DEFAULT_MEAN_PITCH_BINS))]
152
+ duration_tokens = [f'{MEAN_DURATION_KEY}_{i}' for i in range(len(DEFAULT_MEAN_DURATION_BINS))]
153
+
154
+ self.tokens = (
155
+ time_sig_tokens +
156
+ instrument_tokens +
157
+ chord_tokens +
158
+ density_tokens +
159
+ velocity_tokens +
160
+ pitch_tokens +
161
+ duration_tokens +
162
+ bar_tokens
163
+ )
164
+
165
+ counter = Counter(self.tokens)
166
+ super().__init__(counter)
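A small sketch of the vocabulary round trip, assuming `vocab.py` is importable as `vocab`. The first few indices are reserved for the special tokens, so the lookup below starts after them; exact token spellings come from `constants.py`, which is not shown in this diff.

```python
# Hypothetical encode/decode round trip with the REMI vocabulary (module name assumed).
from vocab import RemiVocab

vocab = RemiVocab()
# Look up a few real tokens directly from the vocabulary instead of guessing their spellings.
tokens = vocab.decode(list(range(5, 10)))
ids = vocab.encode(tokens)
print(len(vocab), 'tokens in vocab')
print(tokens, '->', ids)
```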
music_evaluation/mgeval/__init__.py ADDED
File without changes
music_evaluation/mgeval/__init__.pyc ADDED
Binary file (105 Bytes).
 
music_evaluation/mgeval/core.py ADDED
@@ -0,0 +1,644 @@
+ # coding:utf-8
2
+ """core.py
3
+ Include feature extractor and musically informed objective measures.
4
+ """
5
+ import pretty_midi
6
+ import numpy as np
7
+ import sys
8
+ import os
9
+ import statistics
10
+ # import midi
11
+ import glob
12
+ import math
13
+
14
+
15
+ # feature extractor
16
+ def extract_feature(_file):
17
+ """
18
+ This function extracts two midi feature:
19
+ pretty_midi object: https://github.com/craffel/pretty-midi
20
+ midi_pattern: https://github.com/vishnubob/python-midi
21
+
22
+ Returns:
23
+ dict(pretty_midi: pretty_midi object,
24
+ midi_pattern: midi pattern contains a list of tracks)
25
+ """
26
+ feature = {'pretty_midi': pretty_midi.PrettyMIDI(_file)}
27
+ # 'midi_pattern': midi.read_midifile(_file)}
28
+ return feature
29
+
30
+
31
+ # musically informed objective measures.
32
+ class metrics(object):
33
+ def total_used_pitch(self, feature):
34
+ """
35
+ total_used_pitch (Pitch count): The number of different pitches within a sample.
36
+
37
+ Returns:
38
+ 'used_pitch': pitch count, scalar for each sample.
39
+ """
40
+ try:
41
+ instrument = feature['pretty_midi'].instruments[0]
42
+ piano_roll = instrument.get_piano_roll(fs=100)
43
+ sum_notes = np.sum(piano_roll, axis=1)
44
+ used_pitch = np.sum(sum_notes > 0)
45
+ return used_pitch
46
+ except:
47
+ return 0 # empty piano roll
48
+
49
+ def mean_note_velocity(self, feature):
50
+ """
51
+ mean_var_note_velocity: The velocity of different notes within a sample.
52
+
53
+ Returns:
54
+ note_velocity: the average velocity and the variance of velocity, 2 scalar for each sample.
55
+ """
56
+ try:
57
+ instrument = feature['pretty_midi'].instruments[0]
58
+ velocity = []
59
+ for note in instrument.notes:
60
+ velocity.append(note.velocity)
61
+
62
+ mean_velocity = statistics.mean(velocity)
63
+ if len(velocity) > 1:
64
+ variance_velocity = statistics.variance(velocity)
65
+ else:
66
+ variance_velocity = 0
67
+
68
+ return mean_velocity
69
+ except:
70
+ return 0 # empty piano roll
71
+
72
+ def mean_note_duration(self, feature):
73
+ """
74
+ mean_var_note_duration: The duration of different notes within a sample.
75
+
76
+ Returns:
77
+ note_duration: the average duration and the variance of velocity, 2 scalar for each sample.
78
+ """
79
+ try:
80
+ instrument = feature['pretty_midi'].instruments[0]
81
+ duration = []
82
+ for note in instrument.notes:
83
+ d = note.end - note.start
84
+ duration.append(d)
85
+
86
+ mean_duration = statistics.mean(duration)
87
+ if len(duration) > 1:
88
+ variance_duration = statistics.variance(duration)
89
+ else:
90
+ variance_duration = 0
91
+
92
+ return mean_duration
93
+ except:
94
+ return 0 # empty piano roll
95
+
96
+ def note_density(self, feature):
97
+ """
98
+ note_density: the density of note within a sample.
99
+
100
+ Returns:
101
+ note_density: the density of notes, 1 scalar for each sample.
102
+ """
103
+
104
+ # instrument = feature['pretty_midi'].instruments[0]
105
+ total_notes = sum(len(instrument.notes) for instrument in feature['pretty_midi'].instruments)
106
+ total_duration = feature['pretty_midi'].get_end_time()
107
+
108
+ # Calculate the note density
109
+ if total_duration > 0:
110
+ note_density = total_notes / total_duration
111
+ else:
112
+ note_density = 0
113
+
114
+ return note_density
115
+
116
+ # def bar_used_pitch(self, feature, track_num=1, num_bar=None):
117
+ # """
118
+ # bar_used_pitch (Pitch count per bar)
119
+
120
+ # Args:
121
+ # 'track_num' : specify the track number in the midi pattern, default is 1 (the second track).
122
+ # 'num_bar': specify the number of bars in the midi pattern, if set as None, round to the number of complete bar.
123
+
124
+ # Returns:
125
+ # 'used_pitch': with shape of [num_bar,1]
126
+ # """
127
+ # pattern = feature['midi_pattern']
128
+ # pattern.make_ticks_abs()
129
+ # resolution = pattern.resolution
130
+ # for i in range(0, len(pattern[track_num])):
131
+ # if type(pattern[track_num][i]) == midi.events.TimeSignatureEvent:
132
+ # time_sig = pattern[track_num][i].data
133
+ # bar_length = time_sig[0] * resolution * 4 / 2**(time_sig[1])
134
+ # if num_bar is None:
135
+ # num_bar = int(round(float(pattern[track_num][-1].tick) / bar_length))
136
+ # used_notes = np.zeros((num_bar, 1))
137
+ # else:
138
+ # used_notes = np.zeros((num_bar, 1))
139
+
140
+ # elif type(pattern[track_num][i]) == midi.events.NoteOnEvent and pattern[track_num][i].data[1] != 0:
141
+ # if 'time_sig' not in locals(): # set default bar length as 4 beat
142
+ # bar_length = 4 * resolution
143
+ # time_sig = [4, 2, 24, 8]
144
+
145
+ # if num_bar is None:
146
+ # num_bar = int(round(float(pattern[track_num][-1].tick) / bar_length))
147
+ # used_notes = np.zeros((num_bar, 1))
148
+ # used_notes[pattern[track_num][i].tick / bar_length] += 1
149
+ # else:
150
+ # used_notes = np.zeros((num_bar, 1))
151
+ # used_notes[pattern[track_num][i].tick / bar_length] += 1
152
+ # note_list = []
153
+ # note_list.append(pattern[track_num][i].data[0])
154
+
155
+ # else:
156
+ # for j in range(0, num_bar):
157
+ # if 'note_list'in locals():
158
+ # pass
159
+ # else:
160
+ # note_list = []
161
+ # note_list.append(pattern[track_num][i].data[0])
162
+ # idx = pattern[track_num][i].tick / bar_length
163
+ # if idx >= num_bar:
164
+ # continue
165
+ # used_notes[idx] += 1
166
+ # # used_notes[pattern[track_num][i].tick / bar_length] += 1
167
+
168
+ # used_pitch = np.zeros((num_bar, 1))
169
+ # current_note = 0
170
+ # for i in range(0, num_bar):
171
+ # used_pitch[i] = len(set(note_list[current_note:current_note + int(used_notes[i][0])]))
172
+ # current_note += int(used_notes[i][0])
173
+
174
+ # return used_pitch
175
+
176
+ # def total_used_note(self, feature, track_num=1):
177
+ # """
178
+ # total_used_note (Note count): The number of used notes.
179
+ # As opposed to the pitch count, the note count does not contain pitch information but is a rhythm-related feature.
180
+
181
+ # Args:
182
+ # 'track_num' : specify the track number in the midi pattern, default is 1 (the second track).
183
+
184
+ # Returns:
185
+ # 'used_notes': a scalar for each sample.
186
+ # """
187
+ # pattern = feature['midi_pattern']
188
+ # used_notes = 0
189
+ # for i in range(0, len(pattern[track_num])):
190
+ # if type(pattern[track_num][i]) == midi.events.NoteOnEvent and pattern[track_num][i].data[1] != 0:
191
+ # used_notes += 1
192
+ # return used_notes
193
+
194
+ # def bar_used_note(self, feature, track_num=1, num_bar=None):
195
+ # """
196
+ # bar_used_note (Note count per bar).
197
+
198
+ # Args:
199
+ # 'track_num' : specify the track number in the midi pattern, default is 1 (the second track).
200
+ # 'num_bar': specify the number of bars in the midi pattern, if set as None, round to the number of complete bar.
201
+
202
+ # Returns:
203
+ # 'used_notes': with shape of [num_bar, 1]
204
+ # """
205
+ # pattern = feature['midi_pattern']
206
+ # pattern.make_ticks_abs()
207
+ # resolution = pattern.resolution
208
+ # for i in range(0, len(pattern[track_num])):
209
+ # if type(pattern[track_num][i]) == midi.events.TimeSignatureEvent:
210
+ # time_sig = pattern[track_num][i].data
211
+ # bar_length = time_sig[track_num] * resolution * 4 / 2**(time_sig[1])
212
+ # if num_bar is None:
213
+ # num_bar = int(round(float(pattern[track_num][-1].tick) / bar_length))
214
+ # used_notes = np.zeros((num_bar, 1))
215
+ # else:
216
+ # used_notes = np.zeros((num_bar, 1))
217
+
218
+ # elif type(pattern[track_num][i]) == midi.events.NoteOnEvent and pattern[track_num][i].data[1] != 0:
219
+ # if 'time_sig' not in locals(): # set default bar length as 4 beat
220
+ # bar_length = 4 * resolution
221
+ # time_sig = [4, 2, 24, 8]
222
+
223
+ # if num_bar is None:
224
+ # num_bar = int(round(float(pattern[track_num][-1].tick) / bar_length))
225
+ # used_notes = np.zeros((num_bar, 1))
226
+ # used_notes[pattern[track_num][i].tick / bar_length] += 1
227
+ # else:
228
+ # used_notes = np.zeros((num_bar, 1))
229
+ # used_notes[pattern[track_num][i].tick / bar_length] += 1
230
+
231
+ # else:
232
+ # idx = pattern[track_num][i].tick / bar_length
233
+ # if idx >= num_bar:
234
+ # continue
235
+ # used_notes[idx] += 1
236
+ # return used_notes
237
+
238
+ def total_pitch_class_histogram(self, feature):
239
+ """
240
+ total_pitch_class_histogram (Pitch class histogram):
241
+ The pitch class histogram is an octave-independent representation of the pitch content with a dimensionality of 12 for a chromatic scale.
242
+ In our case, it represents to the octave-independent chromatic quantization of the frequency continuum.
243
+
244
+ Returns:
245
+ 'histogram': histrogram of 12 pitch, with weighted duration shape 12
246
+ """
247
+ # print(feature['pretty_midi'].instruments)
248
+ histogram = np.zeros(12)
249
+ try:
250
+ piano_roll = feature['pretty_midi'].instruments[0].get_piano_roll(fs=100)
251
+ # piano_roll = feature['pretty_midi'].get_piano_roll(fs=100)
252
+
253
+ for i in range(0, 128):
254
+ pitch_class = i % 12
255
+ histogram[pitch_class] += np.sum(piano_roll, axis=1)[i]
256
+ histogram = histogram / sum(histogram)
257
+ return histogram
258
+ except:
259
+ return histogram
260
+
261
+ def bar_pitch_class_histogram(self, feature, track_num=1, num_bar=None, bpm=120):
262
+ """
263
+ bar_pitch_class_histogram (Pitch class histogram per bar):
264
+
265
+ Args:
266
+ 'bpm' : specify the assigned speed in bpm, default is 120 bpm.
267
+ 'num_bar': specify the number of bars in the midi pattern, if set as None, round to the number of complete bar.
268
+ 'track_num' : specify the track number in the midi pattern, default is 1 (the second track).
269
+
270
+ Returns:
271
+ 'histogram': with shape of [num_bar, 12]
272
+ """
273
+
274
+ # todo: deal with more than one time signature cases
275
+ pm_object = feature['pretty_midi']
276
+ if num_bar is None:
277
+ numer = pm_object.time_signature_changes[-1].numerator
278
+ deno = pm_object.time_signature_changes[-1].denominator
279
+ bar_length = 60. / bpm * numer * 4 / deno * 100
280
+ piano_roll = pm_object.instruments[track_num].get_piano_roll(fs=100)
281
+ piano_roll = np.transpose(piano_roll, (1, 0))
282
+ actual_bar = len(piano_roll) / bar_length
283
+ num_bar = int(round(actual_bar))
284
+ bar_length = int(round(bar_length))
285
+ else:
286
+ numer = pm_object.time_signature_changes[-1].numerator
287
+ deno = pm_object.time_signature_changes[-1].denominator
288
+ bar_length = 60. / bpm * numer * 4 / deno * 100
289
+ piano_roll = pm_object.instruments[track_num].get_piano_roll(fs=100)
290
+ piano_roll = np.transpose(piano_roll, (1, 0))
291
+ actual_bar = len(piano_roll) / bar_length
292
+ bar_length = int(math.ceil(bar_length))
293
+
294
+ if actual_bar > num_bar:
295
+ mod = np.mod(len(piano_roll), bar_length * 128)
296
+ piano_roll = piano_roll[:-np.mod(len(piano_roll), bar_length)].reshape((num_bar, -1, 128)) # make exact bar
297
+ elif actual_bar == num_bar:
298
+ piano_roll = piano_roll.reshape((num_bar, -1, 128))
299
+ else:
300
+ piano_roll = np.pad(piano_roll, ((0, int(num_bar * bar_length - len(piano_roll))), (0, 0)), mode='constant',
301
+ constant_values=0)
302
+ piano_roll = piano_roll.reshape((num_bar, -1, 128))
303
+
304
+ bar_histogram = np.zeros((num_bar, 12))
305
+ for i in range(0, num_bar):
306
+ histogram = np.zeros(12)
307
+ for j in range(0, 128):
308
+ pitch_class = j % 12
309
+ histogram[pitch_class] += np.sum(piano_roll[i], axis=0)[j]
310
+ if sum(histogram) != 0:
311
+ bar_histogram[i] = histogram / sum(histogram)
312
+ else:
313
+ bar_histogram[i] = np.zeros(12)
314
+ return bar_histogram
315
+
316
+ def pitch_class_transition_matrix(self, feature, normalize=0):
317
+ """
318
+ pitch_class_transition_matrix (Pitch class transition matrix):
319
+ The transition of pitch classes contains useful information for tasks such as key detection, chord recognition, or genre pattern recognition.
320
+ The two-dimensional pitch class transition matrix is a histogram-like representation computed by counting the pitch transitions for each (ordered) pair of notes.
321
+
322
+ Args:
323
+ 'normalize' : If set to 0, return transition without normalization.
324
+ If set to 1, normalizae by row.
325
+ If set to 2, normalize by entire matrix sum.
326
+ Returns:
327
+ 'transition_matrix': shape of [12, 12], transition_matrix of 12 x 12.
328
+ """
329
+ pm_object = feature['pretty_midi']
330
+ transition_matrix = pm_object.get_pitch_class_transition_matrix()
331
+
332
+ if normalize == 0:
333
+ return transition_matrix
334
+
335
+ elif normalize == 1:
336
+ sums = np.sum(transition_matrix, axis=1)
337
+ sums[sums == 0] = 1
338
+ return transition_matrix / sums.reshape(-1, 1)
339
+
340
+ elif normalize == 2:
341
+ return transition_matrix / sum(sum(transition_matrix))
342
+
343
+ else:
344
+ print("invalid normalization mode, return unnormalized matrix")
345
+ return transition_matrix
346
+
347
+ def pitch_range(self, feature):
348
+ """
349
+ pitch_range (Pitch range):
350
+ The pitch range is calculated by subtraction of the highest and lowest used pitch in semitones.
351
+
352
+ Returns:
353
+ 'p_range': a scalar for each sample.
354
+ """
355
+ try:
356
+ piano_roll = feature['pretty_midi'].instruments[0].get_piano_roll(fs=100)
357
+ pitch_index = np.where(np.sum(piano_roll, axis=1) > 0)
358
+ p_range = np.max(pitch_index) - np.min(pitch_index)
359
+ return p_range
360
+ except:
361
+ return 0 # empty piano roll
362
+
363
+ # def avg_pitch_shift(self, feature, track_num=1):
364
+ # """
365
+ # avg_pitch_shift (Average pitch interval):
366
+ # Average value of the interval between two consecutive pitches in semitones.
367
+
368
+ # Args:
369
+ # 'track_num' : specify the track number in the midi pattern, default is 1 (the second track).
370
+
371
+ # Returns:
372
+ # 'pitch_shift': a scalar for each sample.
373
+ # """
374
+ # pattern = feature['midi_pattern']
375
+ # pattern.make_ticks_abs()
376
+ # resolution = pattern.resolution
377
+ # total_used_note = self.total_used_note(feature, track_num=track_num)
378
+ # d_note = np.zeros((max(total_used_note - 1, 0)))
379
+ # # if total_used_note == 0:
380
+ # # return 0
381
+ # # d_note = np.zeros((total_used_note - 1))
382
+ # current_note = 0
383
+ # counter = 0
384
+ # for i in range(0, len(pattern[track_num])):
385
+ # if type(pattern[track_num][i]) == midi.events.NoteOnEvent and pattern[track_num][i].data[1] != 0:
386
+ # if counter != 0:
387
+ # d_note[counter - 1] = current_note - pattern[track_num][i].data[0]
388
+ # current_note = pattern[track_num][i].data[0]
389
+ # counter += 1
390
+ # else:
391
+ # current_note = pattern[track_num][i].data[0]
392
+ # counter += 1
393
+ # pitch_shift = np.mean(abs(d_note))
394
+ # return pitch_shift
395
+
396
+ def avg_IOI(self, feature):
397
+ """
398
+ avg_IOI (Average inter-onset-interval):
399
+ To calculate the inter-onset-interval in the symbolic music domain, we find the time between two consecutive notes.
400
+
401
+ Returns:
402
+ 'avg_ioi': a scalar for each sample.
403
+ """
404
+ try:
405
+ tmp = feature['pretty_midi'].instruments[0]
406
+ pm_object = feature['pretty_midi']
407
+ onset = pm_object.get_onsets()
408
+ ioi = np.diff(onset)
409
+ avg_ioi = np.mean(ioi)
410
+ return avg_ioi
411
+ except:
412
+ return 0 # empty piano roll
413
+
414
+ # def note_length_hist(self, feature, track_num=1, normalize=True, pause_event=False):
415
+ # """
416
+ # note_length_hist (Note length histogram):
417
+ # To extract the note length histogram, we first define a set of allowable beat length classes:
418
+ # [full, half, quarter, 8th, 16th, dot half, dot quarter, dot 8th, dot 16th, half note triplet, quarter note triplet, 8th note triplet].
419
+ # The pause_event option, when activated, will double the vector size to represent the same lengths for rests.
420
+ # The classification of each event is performed by dividing the basic unit into the length of (barlength)/96, and each note length is quantized to the closest length category.
421
+
422
+ # Args:
423
+ # 'track_num' : specify the track number in the midi pattern, default is 1 (the second track).
424
+ # 'normalize' : If true, normalize by vector sum.
425
+ # 'pause_event' : when activated, will double the vector size to represent the same lengths for rests.
426
+
427
+ # Returns:
428
+ # 'note_length_hist': The output vector has a length of either 12 (or 24 when pause_event is True).
429
+ # """
430
+
431
+ # pattern = feature['midi_pattern']
432
+ # if pause_event is False:
433
+ # note_length_hist = np.zeros((12))
434
+ # pattern.make_ticks_abs()
435
+ # resolution = pattern.resolution
436
+ # # basic unit: bar_length/96
437
+ # for i in range(0, len(pattern[track_num])):
438
+ # if type(pattern[track_num][i]) == midi.events.TimeSignatureEvent:
439
+ # time_sig = pattern[track_num][i].data
440
+ # bar_length = time_sig[track_num] * resolution * 4 / 2**(time_sig[1])
441
+ # elif type(pattern[track_num][i]) == midi.events.NoteOnEvent and pattern[track_num][i].data[1] != 0:
442
+ # if 'time_sig' not in locals(): # set default bar length as 4 beat
443
+ # bar_length = 4 * resolution
444
+ # time_sig = [4, 2, 24, 8]
445
+ # unit = bar_length / 96.
446
+ # hist_list = [unit * 96, unit * 48, unit * 24, unit * 12, unit * 6, unit * 72, unit * 36, unit * 18, unit * 9, unit * 32, unit * 16, unit * 8]
447
+ # current_tick = pattern[track_num][i].tick
448
+ # current_note = pattern[track_num][i].data[0]
449
+ # # find next note off
450
+ # for j in range(i, len(pattern[track_num])):
451
+ # if type(pattern[track_num][j]) == midi.events.NoteOffEvent or (type(pattern[track_num][j]) == midi.events.NoteOnEvent and pattern[track_num][j].data[1] == 0):
452
+ # if pattern[track_num][j].data[0] == current_note:
453
+
454
+ # note_length = pattern[track_num][j].tick - current_tick
455
+ # distance = np.abs(np.array(hist_list) - note_length)
456
+ # idx = distance.argmin()
457
+ # note_length_hist[idx] += 1
458
+ # break
459
+ # else:
460
+ # note_length_hist = np.zeros((24))
461
+ # pattern.make_ticks_abs()
462
+ # resolution = pattern.resolution
463
+ # # basic unit: bar_length/96
464
+ # for i in range(0, len(pattern[track_num])):
465
+ # if type(pattern[track_num][i]) == midi.events.TimeSignatureEvent:
466
+ # time_sig = pattern[track_num][i].data
467
+ # bar_length = time_sig[track_num] * resolution * 4 / 2**(time_sig[1])
468
+ # elif type(pattern[track_num][i]) == midi.events.NoteOnEvent and pattern[track_num][i].data[1] != 0:
469
+ # check_previous_off = True
470
+ # if 'time_sig' not in locals(): # set default bar length as 4 beat
471
+ # bar_length = 4 * resolution
472
+ # time_sig = [4, 2, 24, 8]
473
+ # unit = bar_length / 96.
474
+ # tol = 3. * unit
475
+ # hist_list = [unit * 96, unit * 48, unit * 24, unit * 12, unit * 6, unit * 72, unit * 36, unit * 18, unit * 9, unit * 32, unit * 16, unit * 8]
476
+ # current_tick = pattern[track_num][i].tick
477
+ # current_note = pattern[track_num][i].data[0]
478
+ # # find next note off
479
+ # for j in range(i, len(pattern[track_num])):
480
+ # # find next note off
481
+ # if type(pattern[track_num][j]) == midi.events.NoteOffEvent or (type(pattern[track_num][j]) == midi.events.NoteOnEvent and pattern[track_num][j].data[1] == 0):
482
+ # if pattern[track_num][j].data[0] == current_note:
483
+
484
+ # note_length = pattern[track_num][j].tick - current_tick
485
+ # distance = np.abs(np.array(hist_list) - note_length)
486
+ # idx = distance.argmin()
487
+ # note_length_hist[idx] += 1
488
+ # break
489
+ # else:
490
+ # if pattern[track_num][j].tick == current_tick:
491
+ # check_previous_off = False
492
+
493
+ # # find previous note off/on
494
+ # if check_previous_off is True:
495
+ # for j in range(i - 1, 0, -1):
496
+ # if type(pattern[track_num][j]) == midi.events.NoteOnEvent and pattern[track_num][j].data[1] != 0:
497
+ # break
498
+
499
+ # elif type(pattern[track_num][j]) == midi.events.NoteOffEvent or (type(pattern[track_num][j]) == midi.events.NoteOnEvent and pattern[track_num][j].data[1] == 0):
500
+
501
+ # note_length = current_tick - pattern[track_num][j].tick
502
+ # distance = np.abs(np.array(hist_list) - note_length)
503
+ # idx = distance.argmin()
504
+ # if distance[idx] < tol:
505
+ # note_length_hist[idx + 12] += 1
506
+ # break
507
+
508
+ # if normalize is False:
509
+ # return note_length_hist
510
+
511
+ # elif normalize is True:
512
+
513
+ # return note_length_hist / np.sum(note_length_hist)
514
+
515
+ # def note_length_transition_matrix(self, feature, track_num=1, normalize=0, pause_event=False):
516
+ # """
517
+ # note_length_transition_matrix (Note length transition matrix):
518
+ # Similar to the pitch class transition matrix, the note length tran- sition matrix provides useful information for rhythm description.
519
+
520
+ # Args:
521
+ # 'track_num' : specify the track number in the midi pattern, default is 1 (the second track).
522
+ # 'normalize' : If true, normalize by vector sum.
523
+ # 'pause_event' : when activated, will double the vector size to represent the same lengths for rests.
524
+
525
+ # 'normalize' : If set to 0, return transition without normalization.
526
+ # If set to 1, normalizae by row.
527
+ # If set to 2, normalize by entire matrix sum.
528
+
529
+ # Returns:
530
+ # 'transition_matrix': The output feature dimension is 12 × 12 (or 24 x 24 when pause_event is True).
531
+ # """
532
+ # pattern = feature['midi_pattern']
533
+ # if pause_event is False:
534
+ # transition_matrix = np.zeros((12, 12))
535
+ # pattern.make_ticks_abs()
536
+ # resolution = pattern.resolution
537
+ # idx = None
538
+ # # basic unit: bar_length/96
539
+ # for i in range(0, len(pattern[track_num])):
540
+ # if type(pattern[track_num][i]) == midi.events.TimeSignatureEvent:
541
+ # time_sig = pattern[track_num][i].data
542
+ # bar_length = time_sig[track_num] * resolution * 4 / 2**(time_sig[1])
543
+ # elif type(pattern[track_num][i]) == midi.events.NoteOnEvent and pattern[track_num][i].data[1] != 0:
544
+ # if 'time_sig' not in locals(): # set default bar length as 4 beat
545
+ # bar_length = 4 * resolution
546
+ # time_sig = [4, 2, 24, 8]
547
+ # unit = bar_length / 96.
548
+ # hist_list = [unit * 96, unit * 48, unit * 24, unit * 12, unit * 6, unit * 72, unit * 36, unit * 18, unit * 9, unit * 32, unit * 16, unit * 8]
549
+ # current_tick = pattern[track_num][i].tick
550
+ # current_note = pattern[track_num][i].data[0]
551
+ # # find note off
552
+ # for j in range(i, len(pattern[track_num])):
553
+ # if type(pattern[track_num][j]) == midi.events.NoteOffEvent or (type(pattern[track_num][j]) == midi.events.NoteOnEvent and pattern[track_num][j].data[1] == 0):
554
+ # if pattern[track_num][j].data[0] == current_note:
555
+ # note_length = pattern[track_num][j].tick - current_tick
556
+ # distance = np.abs(np.array(hist_list) - note_length)
557
+
558
+ # last_idx = idx
559
+ # idx = distance.argmin()
560
+ # if last_idx is not None:
561
+ # transition_matrix[last_idx][idx] += 1
562
+ # break
563
+ # else:
564
+ # transition_matrix = np.zeros((24, 24))
565
+ # pattern.make_ticks_abs()
566
+ # resolution = pattern.resolution
567
+ # idx = None
568
+ # # basic unit: bar_length/96
569
+ # for i in range(0, len(pattern[track_num])):
570
+ # if type(pattern[track_num][i]) == midi.events.TimeSignatureEvent:
571
+ # time_sig = pattern[track_num][i].data
572
+ # bar_length = time_sig[track_num] * resolution * 4 / 2**(time_sig[1])
573
+ # elif type(pattern[track_num][i]) == midi.events.NoteOnEvent and pattern[track_num][i].data[1] != 0:
574
+ # check_previous_off = True
575
+ # if 'time_sig' not in locals(): # set default bar length as 4 beat
576
+ # bar_length = 4 * resolution
577
+ # time_sig = [4, 2, 24, 8]
578
+ # unit = bar_length / 96.
579
+ # tol = 3. * unit
580
+ # hist_list = [unit * 96, unit * 48, unit * 24, unit * 12, unit * 6, unit * 72, unit * 36, unit * 18, unit * 9, unit * 32, unit * 16, unit * 8]
581
+ # current_tick = pattern[track_num][i].tick
582
+ # current_note = pattern[track_num][i].data[0]
583
+ # # find next note off
584
+ # for j in range(i, len(pattern[track_num])):
585
+ # # find next note off
586
+ # if type(pattern[track_num][j]) == midi.events.NoteOffEvent or (type(pattern[track_num][j]) == midi.events.NoteOnEvent and pattern[track_num][j].data[1] == 0):
587
+ # if pattern[track_num][j].data[0] == current_note:
588
+
589
+ # note_length = pattern[track_num][j].tick - current_tick
590
+ # distance = np.abs(np.array(hist_list) - note_length)
591
+ # last_idx = idx
592
+ # idx = distance.argmin()
593
+ # if last_idx is not None:
594
+ # transition_matrix[last_idx][idx] += 1
595
+ # break
596
+ # else:
597
+ # if pattern[track_num][j].tick == current_tick:
598
+ # check_previous_off = False
599
+
600
+ # # find previous note off/on
601
+ # if check_previous_off is True:
602
+ # for j in range(i - 1, 0, -1):
603
+ # if type(pattern[track_num][j]) == midi.events.NoteOnEvent and pattern[track_num][j].data[1] != 0:
604
+ # break
605
+
606
+ # elif type(pattern[track_num][j]) == midi.events.NoteOffEvent or (type(pattern[track_num][j]) == midi.events.NoteOnEvent and pattern[track_num][j].data[1] == 0):
607
+
608
+ # note_length = current_tick - pattern[track_num][j].tick
609
+ # distance = np.abs(np.array(hist_list) - note_length)
610
+
611
+ # last_idx = idx
612
+ # idx = distance.argmin()
613
+ # if last_idx is not None:
614
+ # if distance[idx] < tol:
615
+ # idx = last_idx
616
+ # transition_matrix[last_idx][idx + 12] += 1
617
+ # break
618
+
619
+ # if normalize == 0:
620
+ # return transition_matrix
621
+
622
+ # elif normalize == 1:
623
+
624
+ # sums = np.sum(transition_matrix, axis=1)
625
+ # sums[sums == 0] = 1
626
+ # return transition_matrix / sums.reshape(-1, 1)
627
+
628
+ # elif normalize == 2:
629
+
630
+ # return transition_matrix / sum(sum(transition_matrix))
631
+
632
+ # else:
633
+ # print "invalid normalization mode, return unnormalized matrix"
634
+ # return transition_matrix
635
+
636
+ # def chord_dependency(self, feature, bar_chord, bpm=120, num_bar=None, track_num=1):
637
+ # pm_object = feature['pretty_midi']
638
+ # # compare bar chroma with chord chroma. calculate the ecludian
639
+ # bar_pitch_class_histogram = self.bar_pitch_class_histogram(pm_object, bpm=bpm, num_bar=num_bar, track_num=track_num)
640
+ # dist = np.zeros((len(bar_pitch_class_histogram)))
641
+ # for i in range((len(bar_pitch_class_histogram))):
642
+ # dist[i] = np.linalg.norm(bar_pitch_class_histogram[i] - bar_chord[i])
643
+ # average_dist = np.mean(dist)
644
+ # return average_dist
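A minimal sketch of computing a few of the mgeval features defined above for a single MIDI file; the file path is an assumption made for illustration.

```python
# Hypothetical usage of the mgeval feature extractor and metrics (file path assumed).
from mgeval import core

feature = core.extract_feature('example.mid')   # {'pretty_midi': pretty_midi.PrettyMIDI(...)}
m = core.metrics()

print('pitch count:      ', m.total_used_pitch(feature))
print('note density:     ', m.note_density(feature))
print('pitch range:      ', m.pitch_range(feature))
print('pitch class hist: ', m.total_pitch_class_histogram(feature))
print('average IOI:      ', m.avg_IOI(feature))
```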
music_evaluation/mgeval/core.pyc ADDED
Binary file (18.2 kB).