crlandsc committed on
Commit
d3378e2
1 Parent(s): f055a16

initial commit

app.py ADDED
@@ -0,0 +1,161 @@
+ # Imports
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ import torch
+ import torchaudio
+ from torch import nn
+ import pytorch_lightning as pl
+ from ema_pytorch import EMA
+ import yaml
+ from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
+
+
+ # Load configs
+ def load_configs(config_path):
+     with open(config_path, 'r') as file:
+         config = yaml.safe_load(file)
+     pl_configs = config['model']
+     model_configs = config['model']['model']
+     return pl_configs, model_configs
+
+ # Plot mel spectrogram
+ def plot_mel_spectrogram(sample, sr):
+     transform = torchaudio.transforms.MelSpectrogram(
+         sample_rate=sr,
+         n_fft=1024,
+         hop_length=512,
+         n_mels=80,
+         center=True,
+         norm="slaney",
+     )
+
+     spectrogram = transform(torch.mean(sample, dim=0))  # downmix and calculate spectrogram
+     spectrogram = torchaudio.functional.amplitude_to_DB(spectrogram, 1.0, 1e-10, 80.0)
+
+     # Plot the Mel spectrogram
+     fig = plt.figure(figsize=(7, 4))
+     plt.imshow(spectrogram, aspect='auto', origin='lower')
+     plt.colorbar(format='%+2.0f dB')
+     plt.xlabel('Frame')
+     plt.ylabel('Mel Bin')
+     plt.title('Mel Spectrogram')
+     plt.tight_layout()
+
+     return fig
+
+ # Define PyTorch Lightning model
+ class Model(pl.LightningModule):
+     def __init__(
+         self,
+         lr: float,
+         lr_beta1: float,
+         lr_beta2: float,
+         lr_eps: float,
+         lr_weight_decay: float,
+         ema_beta: float,
+         ema_power: float,
+         model: nn.Module,
+     ):
+         super().__init__()
+         self.lr = lr
+         self.lr_beta1 = lr_beta1
+         self.lr_beta2 = lr_beta2
+         self.lr_eps = lr_eps
+         self.lr_weight_decay = lr_weight_decay
+         self.model = model
+         self.model_ema = EMA(self.model, beta=ema_beta, power=ema_power)
+
+ # Instantiate model (must match the model that was trained)
+ def load_model(model_configs, pl_configs) -> nn.Module:
+     # Diffusion model
+     model = DiffusionModel(
+         net_t=UNetV0,  # The model type used for diffusion (U-Net V0 in this case)
+         in_channels=model_configs['in_channels'],  # U-Net: number of input/output (audio) channels
+         channels=model_configs['channels'],  # U-Net: channels at each layer
+         factors=model_configs['factors'],  # U-Net: downsampling and upsampling factors at each layer
+         items=model_configs['items'],  # U-Net: number of repeating items at each layer
+         attentions=model_configs['attentions'],  # U-Net: attention enabled/disabled at each layer
+         attention_heads=model_configs['attention_heads'],  # U-Net: number of attention heads per attention item
+         attention_features=model_configs['attention_features'],  # U-Net: number of attention features per attention item
+         diffusion_t=VDiffusion,  # The diffusion method used
+         sampler_t=VSampler  # The diffusion sampler used
+     )
+
+     # PyTorch Lightning model
+     model = Model(
+         lr=pl_configs['lr'],
+         lr_beta1=pl_configs['lr_beta1'],
+         lr_beta2=pl_configs['lr_beta2'],
+         lr_eps=pl_configs['lr_eps'],
+         lr_weight_decay=pl_configs['lr_weight_decay'],
+         ema_beta=pl_configs['ema_beta'],
+         ema_power=pl_configs['ema_power'],
+         model=model
+     )
+
+     return model
+
+ # Assign to GPU if available
+ def assign_to_gpu(model):
+     if torch.cuda.is_available():
+         model = model.to('cuda')
+     print(f"Device: {model.device}")
+     return model
+
+ # Load model checkpoint
+ def load_checkpoint(model, ckpt_path) -> None:
+     checkpoint = torch.load(ckpt_path, map_location='cpu')['state_dict']
+     model.load_state_dict(checkpoint)  # should output "<All keys matched successfully>"
+
+
+ # Generate samples
+ def generate_samples(model_name, num_samples, num_steps, duration=32768):
+     # Load the checkpoint for the selected model
+     ckpt_path = models[model_name]
+     load_checkpoint(model, ckpt_path)
+
+     with torch.no_grad():
+         all_samples = torch.zeros(2, 0)  # initialize all samples
+         for i in range(num_samples):
+             noise = torch.randn((1, 2, int(duration)), device=model.device)  # [batch_size, in_channels, length]
+             generated_sample = model.model_ema.ema_model.sample(noise, num_steps=num_steps).squeeze(0).cpu()  # Suggested num_steps 10-100
+
+             # Concatenate all samples
+             all_samples = torch.concat((all_samples, generated_sample), dim=1)
+
+     torch.cuda.empty_cache()
+
+     fig = plot_mel_spectrogram(all_samples, sr)
+     plt.title(f"{model_name} Mel Spectrogram")
+
+     return (sr, all_samples.cpu().detach().numpy().T), fig  # (sample rate, audio), plot
+
+ # Load model & configs
+ sr = 44100  # sampling rate
+ config_path = "saved_models/config.yaml"  # config path
+ pl_configs, model_configs = load_configs(config_path)
+ model = load_model(model_configs, pl_configs)
+ model = assign_to_gpu(model)
+
+ models = {
+     "Kicks": "saved_models/kicks/kicks_v7.ckpt",
+     "Snares": "saved_models/snares/snares_v0.ckpt",
+     "Hi-hats": "saved_models/hihats/hihats_v2.ckpt",
+     "Percussion": "saved_models/percussion/percussion_v0.ckpt"
+ }
+
+ demo = gr.Interface(
+     generate_samples,
+     inputs=[
+         gr.Dropdown(choices=list(models.keys()), value=list(models.keys())[0], label="Model"),
+         gr.Slider(1, 25, step=1, label="Number of Samples to Generate", value=1),
+         gr.Slider(1, 100, step=1, label="Number of Diffusion Steps", value=10)
+     ],
+     outputs=[
+         gr.Audio(label="Generated Audio Sample"),
+         gr.Plot(label="Generated Audio Spectrogram")
+     ]
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
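
A minimal sketch of driving the same pipeline from a Python session instead of the Gradio UI, assuming this repository is checked out locally with the LFS checkpoints pulled so that saved_models/ resolves; soundfile is an extra dependency introduced here only to write the result to disk:

# Sketch: non-UI usage (assumption: run from the repo root with checkpoints present).
import soundfile as sf  # assumed extra dependency, only for saving the output
from app import generate_samples  # importing app loads the model and configs at module level

# Same entry point the Gradio demo calls: 2 kick samples, 50 diffusion steps.
(rate, audio), fig = generate_samples("Kicks", num_samples=2, num_steps=50)

sf.write("kicks_preview.wav", audio, rate)  # audio is shaped [samples, channels]
fig.savefig("kicks_preview_mel.png")        # mel spectrogram of the concatenated output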
requirements.txt ADDED
@@ -0,0 +1,34 @@
+ torch>=2.0
+ torchaudio>=2.0
+ pytorch-lightning==1.7.7
+ python-dotenv
+ hydra-core
+ hydra-colorlog
+ wandb
+ auraloss
+ yt-dlp
+ datasets
+ pyloudnorm
+ einops
+ omegaconf
+ rich
+ plotly
+ librosa
+ transformers
+ eng-to-ipa
+ ema-pytorch
+ py7zr
+ notebook
+ matplotlib
+ ipykernel
+ gradio
+
+ # k-diffusion
+ # v-diffusion-pytorch
+
+ audio-diffusion-pytorch==0.1.3
+ audio-encoders-pytorch
+ audio-data-pytorch
+ quantizer-pytorch
+ difformer-pytorch
+ a-transformers-pytorch
saved_models/config.yaml ADDED
@@ -0,0 +1,131 @@
+ seed: 12345
+ train: true
+ ignore_warnings: true
+ print_config: false
+ work_dir: ${hydra:runtime.cwd}
+ logs_dir: ${work_dir}${oc.env:DIR_LOGS}
+ data_dir: ${work_dir}${oc.env:DIR_DATA}
+ ckpt_dir: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
+ module: main.module_base
+ batch_size: 1
+ accumulate_grad_batches: 32
+ num_workers: 8
+ sampling_rate: 44100
+ length: 32768
+ channels: 2
+ log_every_n_steps: 1000
+ model:
+   _target_: ${module}.Model
+   lr: 0.0001
+   lr_beta1: 0.95
+   lr_beta2: 0.999
+   lr_eps: 1.0e-06
+   lr_weight_decay: 0.001
+   ema_beta: 0.995
+   ema_power: 0.7
+   model:
+     _target_: main.DiffusionModel
+     net_t:
+       _target_: ${module}.UNetT
+     in_channels: 2
+     channels:
+     - 32
+     - 32
+     - 64
+     - 64
+     - 128
+     - 128
+     - 256
+     - 256
+     factors:
+     - 1
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     items:
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 4
+     - 4
+     attentions:
+     - 0
+     - 0
+     - 0
+     - 0
+     - 0
+     - 1
+     - 1
+     - 1
+     attention_heads: 8
+     attention_features: 64
+ datamodule:
+   _target_: main.module_base.Datamodule
+   dataset:
+     _target_: audio_data_pytorch.WAVDataset
+     path: ./data/wav_dataset/kicks
+     recursive: true
+     sample_rate: ${sampling_rate}
+     transforms:
+       _target_: audio_data_pytorch.AllTransform
+       crop_size: ${length}
+       stereo: true
+       source_rate: ${sampling_rate}
+       target_rate: ${sampling_rate}
+       loudness: -20
+   val_split: 0.05
+   batch_size: ${batch_size}
+   num_workers: ${num_workers}
+   pin_memory: true
+ callbacks:
+   rich_progress_bar:
+     _target_: pytorch_lightning.callbacks.RichProgressBar
+   model_checkpoint:
+     _target_: pytorch_lightning.callbacks.ModelCheckpoint
+     monitor: valid_loss
+     save_top_k: 1
+     save_last: true
+     mode: min
+     verbose: false
+     dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
+     filename: '{epoch:02d}-{valid_loss:.3f}'
+   model_summary:
+     _target_: pytorch_lightning.callbacks.RichModelSummary
+     max_depth: 2
+   audio_samples_logger:
+     _target_: main.module_base.SampleLogger
+     num_items: 4
+     channels: ${channels}
+     sampling_rate: ${sampling_rate}
+     length: ${length}
+     sampling_steps:
+     - 50
+     use_ema_model: true
+ loggers:
+   wandb:
+     _target_: pytorch_lightning.loggers.wandb.WandbLogger
+     project: ${oc.env:WANDB_PROJECT}
+     entity: ${oc.env:WANDB_ENTITY}
+     name: kicks_v7
+     job_type: train
+     group: ''
+     save_dir: ${logs_dir}
+ trainer:
+   _target_: pytorch_lightning.Trainer
+   gpus: 1
+   precision: 16
+   accelerator: gpu
+   min_epochs: 0
+   max_epochs: -1
+   enable_model_summary: false
+   log_every_n_steps: 1
+   check_val_every_n_epoch: null
+   val_check_interval: ${log_every_n_steps}
+   accumulate_grad_batches: ${accumulate_grad_batches}
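
As a cross-check of how app.py consumes this file: load_configs reads the top-level model block for the Lightning/EMA hyperparameters and the nested model.model block for the U-Net architecture, so those keys must stay aligned with the keyword arguments passed to Model and DiffusionModel. A small sketch of that mapping, assuming the file sits at saved_models/config.yaml as in app.py:

# Sketch: the two config levels used by app.py (mirrors load_configs/load_model).
import yaml

with open("saved_models/config.yaml") as f:
    config = yaml.safe_load(f)

pl_configs = config["model"]              # lr, lr_beta1, ..., ema_beta, ema_power -> Model(...)
model_configs = config["model"]["model"]  # in_channels, channels, factors, ... -> DiffusionModel(...)

assert model_configs["in_channels"] == 2
assert len(model_configs["channels"]) == len(model_configs["factors"]) == 8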
saved_models/hihats/hihats_v2.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fc7245d3d5617bb3a76dcc8534d9cee25030c3986fa80502f19ec3506a68d05c
+ size 509086593
saved_models/kicks/kicks_v7.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1f3511269e10edc889cfd50393fd5228cdfb069185afc9d92263cef548a18482
+ size 509086593
saved_models/percussion/percussion_v0.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f8fe5dc0295738995cb74892a7d70a074abdfd2c7e887951a2bc9814ec9acfaf
+ size 509086593
saved_models/snares/snares_v0.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d2f906655666200635267c3a92ff87631f4bb4ef94bf087cfee3e2611da9b30b
+ size 509086593
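
The four .ckpt entries above are Git LFS pointer files, so a plain download of the repository yields only these stubs rather than the ~509 MB checkpoints that load_checkpoint expects. A hedged sketch of fetching one of them with huggingface_hub; the repo_id shown is a placeholder, not taken from this commit:

# Sketch: resolve an LFS checkpoint to a real local file.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="your-username/your-space",  # hypothetical placeholder; use the actual repo id
    repo_type="space",                   # assumption: the files live in a Space repo
    filename="saved_models/kicks/kicks_v7.ckpt",
)
print(ckpt_path)  # local cache path that can be passed to load_checkpoint(model, ckpt_path)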