ymzhang319 committed
Commit 7f2690b
1 parent: 8c104ce
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +276 -0
  2. configs/auffusion/vocoder/config.json +37 -0
  3. configs/train/train_semantic_adapter.yaml +54 -0
  4. configs/train/train_temporal_adapter.yaml +48 -0
  5. environment.yaml +24 -0
  6. foleycrafter/data/dataset.py +175 -0
  7. foleycrafter/data/video_transforms.py +400 -0
  8. foleycrafter/models/adapters/attention_processor.py +653 -0
  9. foleycrafter/models/adapters/ip_adapter.py +217 -0
  10. foleycrafter/models/adapters/resampler.py +158 -0
  11. foleycrafter/models/adapters/transformer.py +327 -0
  12. foleycrafter/models/adapters/utils.py +81 -0
  13. foleycrafter/models/auffusion/attention.py +669 -0
  14. foleycrafter/models/auffusion/attention_processor.py +0 -0
  15. foleycrafter/models/auffusion/dual_transformer_2d.py +156 -0
  16. foleycrafter/models/auffusion/loaders/ip_adapter.py +520 -0
  17. foleycrafter/models/auffusion/loaders/unet.py +1100 -0
  18. foleycrafter/models/auffusion/resnet.py +685 -0
  19. foleycrafter/models/auffusion/transformer_2d.py +460 -0
  20. foleycrafter/models/auffusion/unet_2d_blocks.py +0 -0
  21. foleycrafter/models/auffusion_unet.py +1260 -0
  22. foleycrafter/models/specvqgan/data/greatesthit.py +993 -0
  23. foleycrafter/models/specvqgan/data/impactset.py +778 -0
  24. foleycrafter/models/specvqgan/data/transforms.py +685 -0
  25. foleycrafter/models/specvqgan/data/utils.py +265 -0
  26. foleycrafter/models/specvqgan/models/av_cond_transformer.py +528 -0
  27. foleycrafter/models/specvqgan/models/cond_transformer.py +455 -0
  28. foleycrafter/models/specvqgan/models/vqgan.py +397 -0
  29. foleycrafter/models/specvqgan/modules/diffusionmodules/model.py +999 -0
  30. foleycrafter/models/specvqgan/modules/discriminator/model.py +295 -0
  31. foleycrafter/models/specvqgan/modules/losses/__init__.py +7 -0
  32. foleycrafter/models/specvqgan/modules/losses/lpaps.py +152 -0
  33. foleycrafter/models/specvqgan/modules/losses/vggishish/configs/melception.yaml +24 -0
  34. foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish.yaml +34 -0
  35. foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish_gh.yaml +25 -0
  36. foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish_gh_action.yaml +25 -0
  37. foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish_gh_material.yaml +25 -0
  38. foleycrafter/models/specvqgan/modules/losses/vggishish/dataset.py +295 -0
  39. foleycrafter/models/specvqgan/modules/losses/vggishish/logger.py +90 -0
  40. foleycrafter/models/specvqgan/modules/losses/vggishish/loss.py +41 -0
  41. foleycrafter/models/specvqgan/modules/losses/vggishish/metrics.py +69 -0
  42. foleycrafter/models/specvqgan/modules/losses/vggishish/model.py +77 -0
  43. foleycrafter/models/specvqgan/modules/losses/vggishish/predict.py +90 -0
  44. foleycrafter/models/specvqgan/modules/losses/vggishish/predict_gh.py +66 -0
  45. foleycrafter/models/specvqgan/modules/losses/vggishish/train_melception.py +241 -0
  46. foleycrafter/models/specvqgan/modules/losses/vggishish/train_vggishish.py +199 -0
  47. foleycrafter/models/specvqgan/modules/losses/vggishish/train_vggishish_gh.py +218 -0
  48. foleycrafter/models/specvqgan/modules/losses/vggishish/transforms.py +98 -0
  49. foleycrafter/models/specvqgan/modules/losses/vqperceptual.py +209 -0
  50. foleycrafter/models/specvqgan/modules/misc/class_cond.py +21 -0
app.py ADDED
@@ -0,0 +1,276 @@
1
+ import torch
2
+ import torchvision
3
+
4
+ import os
5
+ import os.path as osp
6
+ import random
7
+ from argparse import ArgumentParser
8
+ from datetime import datetime
9
+
10
+ import gradio as gr
11
+
12
+ from foleycrafter.utils.util import build_foleycrafter, read_frames_with_moviepy
13
+ from foleycrafter.pipelines.auffusion_pipeline import denormalize_spectrogram
14
+ from foleycrafter.pipelines.auffusion_pipeline import Generator
15
+ from foleycrafter.models.time_detector.model import VideoOnsetNet
16
+ from foleycrafter.models.specvqgan.onset_baseline.utils import torch_utils
17
+
18
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
19
+ from huggingface_hub import snapshot_download
20
+ from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
21
+
22
+ import soundfile as sf
23
+ from moviepy.editor import AudioFileClip, VideoFileClip
24
+ os.environ['GRADIO_TEMP_DIR'] = './tmp'
25
+
26
+ sample_idx = 0
27
+ scheduler_dict = {
28
+ "DDIM": DDIMScheduler,
29
+ "Euler": EulerDiscreteScheduler,
30
+ "PNDM": PNDMScheduler,
31
+ }
32
+
33
+ css = """
34
+ .toolbutton {
35
+ margin-bottom: 0em;
36
+ max-width: 2.5em;
37
+ min-width: 2.5em !important;
38
+ height: 2.5em;
39
+ }
40
+ """
41
+
42
+ parser = ArgumentParser()
43
+ parser.add_argument("--config", type=str, default="example/config/base.yaml")
44
+ parser.add_argument("--server-name", type=str, default="0.0.0.0")
45
+ parser.add_argument("--port", type=int, default=11451)
46
+ parser.add_argument("--share", action="store_true")
47
+
48
+ parser.add_argument("--save-path", default="samples")
49
+
50
+ args = parser.parse_args()
51
+
52
+
53
+ N_PROMPT = (
54
+ ""
55
+ )
56
+
57
+ class FoleyController:
58
+ def __init__(self):
59
+ # config dirs
60
+ self.basedir = os.getcwd()
61
+ self.model_dir = os.path.join(self.basedir, "models")
62
+ self.savedir = os.path.join(self.basedir, args.save_path, datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
63
+ self.savedir_sample = os.path.join(self.savedir, "sample")
64
+ os.makedirs(self.savedir, exist_ok=True)
65
+
66
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
67
+
68
+ self.pipeline = None
69
+
70
+ self.loaded = False
71
+
72
+ self.load_model()
73
+
74
+ def load_model(self):
75
+ gr.Info("Start Load Models...")
76
+ print("Start Load Models...")
77
+
78
+ # download ckpt
79
+ pretrained_model_name_or_path = 'auffusion/auffusion-full-no-adapter'
80
+ if not os.path.isdir(pretrained_model_name_or_path):
81
+ pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path, local_dir='models/auffusion')
82
+
83
+ fc_ckpt = 'ymzhang319/FoleyCrafter'
84
+ if not os.path.isdir(fc_ckpt):
85
+ fc_ckpt = snapshot_download(fc_ckpt, local_dir='models/')
86
+
87
+ # set model config
88
+ temporal_ckpt_path = osp.join(self.model_dir, 'temporal_adapter.ckpt')
89
+
90
+ # load vocoder
91
+ vocoder_config_path = "./models/auffusion"
92
+ self.vocoder = Generator.from_pretrained(
93
+ vocoder_config_path,
94
+ subfolder="vocoder").to(self.device)
95
+
96
+ # load time detector
97
+ time_detector_ckpt = osp.join(self.model_dir, 'timestamp_detector.pth.tar')
98
+ time_detector = VideoOnsetNet(False)
99
+ self.time_detector, _ = torch_utils.load_model(time_detector_ckpt, time_detector, strict=True, device=self.device)
100
+
101
+ self.pipeline = build_foleycrafter().to(self.device)
102
+ ckpt = torch.load(temporal_ckpt_path)
103
+
104
+ # load temporal adapter
105
+ if 'state_dict' in ckpt.keys():
106
+ ckpt = ckpt['state_dict']
107
+ load_gligen_ckpt = {}
108
+ for key, value in ckpt.items():
109
+ if key.startswith('module.'):
110
+ load_gligen_ckpt[key[len('module.'):]] = value
111
+ else:
112
+ load_gligen_ckpt[key] = value
113
+ m, u = self.pipeline.controlnet.load_state_dict(load_gligen_ckpt, strict=False)
114
+ print(f"### Control Net missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
115
+
116
+ self.image_processor = CLIPImageProcessor()
117
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained('h94/IP-Adapter', subfolder='models/image_encoder').to(self.device)
118
+
119
+ self.pipeline.load_ip_adapter(fc_ckpt, subfolder='semantic', weight_name='semantic_adapter.bin', image_encoder_folder=None)
120
+
121
+ gr.Info("Load Finish!")
122
+ print("Load Finish!")
123
+ self.loaded = True
124
+
125
+ return "Load"
126
+
127
+ def foley(
128
+ self,
129
+ input_video,
130
+ prompt_textbox,
131
+ negative_prompt_textbox,
132
+ ip_adapter_scale,
133
+ temporal_scale,
134
+ sampler_dropdown,
135
+ sample_step_slider,
136
+ cfg_scale_slider,
137
+ seed_textbox,
138
+ ):
139
+
140
+ vision_transform_list = [
141
+ torchvision.transforms.Resize((128, 128)),
142
+ torchvision.transforms.CenterCrop((112, 112)),
143
+ torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
144
+ ]
145
+ video_transform = torchvision.transforms.Compose(vision_transform_list)
146
+ if not self.loaded:
147
+ raise gr.Error("Error with loading model")
148
+ generator = torch.Generator()
149
+ if seed_textbox != "":
150
+ torch.manual_seed(int(seed_textbox))
151
+ generator.manual_seed(int(seed_textbox))
152
+ max_frame_nums = 15
153
+ frames, duration = read_frames_with_moviepy(input_video, max_frame_nums=max_frame_nums)
154
+ if duration >= 10:
155
+ duration = 10
156
+ time_frames = torch.FloatTensor(frames).permute(0, 3, 1, 2)
157
+ time_frames = video_transform(time_frames)
158
+ time_frames = {'frames': time_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)}
159
+ preds = self.time_detector(time_frames)
160
+ preds = torch.sigmoid(preds)
161
+
162
+ # duration
163
+ time_condition = [-1 if preds[0][int(i / (1024 / 10 * duration) * max_frame_nums)] < 0.5 else 1 for i in range(int(1024 / 10 * duration))]
164
+ time_condition = time_condition + [-1] * (1024 - len(time_condition))
165
+ # w -> b c h w
166
+ time_condition = torch.FloatTensor(time_condition).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(1, 1, 256, 1)
167
+
168
+ images = self.image_processor(images=frames, return_tensors="pt").to(self.device)
169
+ image_embeddings = self.image_encoder(**images).image_embeds
170
+ image_embeddings = torch.mean(image_embeddings, dim=0, keepdim=True).unsqueeze(0).unsqueeze(0)
171
+ neg_image_embeddings = torch.zeros_like(image_embeddings)
172
+ image_embeddings = torch.cat([neg_image_embeddings, image_embeddings], dim=1)
173
+ self.pipeline.set_ip_adapter_scale(ip_adapter_scale)
174
+ sample = self.pipeline(
175
+ prompt=prompt_textbox,
176
+ negative_prompt=negative_prompt_textbox,
177
+ ip_adapter_image_embeds=image_embeddings,
178
+ image=time_condition,
179
+ controlnet_conditioning_scale=float(temporal_scale),
180
+ num_inference_steps=sample_step_slider,
181
+ height=256,
182
+ width=1024,
183
+ output_type="pt",
184
+ generator=generator,
185
+ )
186
+ name = 'output'
187
+ audio_img = sample.images[0]
188
+ audio = denormalize_spectrogram(audio_img)
189
+ audio = self.vocoder.inference(audio, lengths=160000)[0]
190
+ audio_save_path = osp.join(self.savedir_sample, 'audio')
191
+ os.makedirs(audio_save_path, exist_ok=True)
192
+ audio = audio[:int(duration * 16000)]
193
+
194
+ save_path = osp.join(audio_save_path, f'{name}.wav')
195
+ sf.write(save_path, audio, 16000)
196
+
197
+ audio = AudioFileClip(osp.join(audio_save_path, f'{name}.wav'))
198
+ video = VideoFileClip(input_video)
199
+ audio = audio.subclip(0, duration)
200
+ video.audio = audio
201
+ video = video.subclip(0, duration)
202
+ video.write_videofile(osp.join(self.savedir_sample, f'{name}.mp4'))
203
+ save_sample_path = os.path.join(self.savedir_sample, f"{name}.mp4")
204
+
205
+ return save_sample_path
206
+
207
+ controller = FoleyController()
208
+
209
+ def ui():
210
+ with gr.Blocks(css=css) as demo:
211
+ gr.HTML(
212
+ "<div align='center'><font size='6'>FoleyCrafter: Bring Silent Videos to Life with Lifelike and Synchronized Sounds</font></div>"
213
+ )
214
+ with gr.Row():
215
+ gr.Markdown(
216
+ "<div align='center'><font size='5'><a href='https://foleycrafter.github.io/'>Project Page</a> &ensp;" # noqa
217
+ "<a href='https://arxiv.org/abs/xxxx.xxxxx/'>Paper</a> &ensp;"
218
+ "<a href='https://github.com/open-mmlab/foleycrafter'>Code</a> &ensp;"
219
+ "<a href='https://huggingface.co/spaces/ymzhang319/FoleyCrafter'>Demo</a> </font></div>"
220
+ )
221
+
222
+ with gr.Column(variant="panel"):
223
+ with gr.Row(equal_height=False):
224
+ with gr.Column():
225
+ with gr.Row():
226
+ init_img = gr.Video(label="Input Video")
227
+ with gr.Row():
228
+ prompt_textbox = gr.Textbox(value='', label="Prompt", lines=1)
229
+ with gr.Row():
230
+ negative_prompt_textbox = gr.Textbox(value=N_PROMPT, label="Negative prompt", lines=1)
231
+
232
+ with gr.Row():
233
+ sampler_dropdown = gr.Dropdown(
234
+ label="Sampling method",
235
+ choices=list(scheduler_dict.keys()),
236
+ value=list(scheduler_dict.keys())[0],
237
+ )
238
+ sample_step_slider = gr.Slider(
239
+ label="Sampling steps", value=25, minimum=10, maximum=100, step=1
240
+ )
241
+
242
+ cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)
243
+ ip_adapter_scale = gr.Slider(label="Visual Content Scale", value=1.0, minimum=0, maximum=1)
244
+ temporal_scale = gr.Slider(label="Temporal Align Scale", value=0., minimum=0., maximum=1.0)
245
+
246
+ with gr.Row():
247
+ seed_textbox = gr.Textbox(label="Seed", value=42)
248
+ seed_button = gr.Button(value="\U0001f3b2", elem_classes="toolbutton")
249
+ seed_button.click(fn=lambda: random.randint(1, 10**8), outputs=[seed_textbox], queue=False)
250
+
251
+ generate_button = gr.Button(value="Generate", variant="primary")
252
+
253
+ result_video = gr.Video(label="Generated Audio", interactive=False)
254
+
255
+ generate_button.click(
256
+ fn=controller.foley,
257
+ inputs=[
258
+ init_img,
259
+ prompt_textbox,
260
+ negative_prompt_textbox,
261
+ ip_adapter_scale,
262
+ temporal_scale,
263
+ sampler_dropdown,
264
+ sample_step_slider,
265
+ cfg_scale_slider,
266
+ seed_textbox,
267
+ ],
268
+ outputs=[result_video],
269
+ )
270
+
271
+ return demo
272
+
273
+ if __name__ == "__main__":
274
+ demo = ui()
275
+ demo.queue(3)
276
+ demo.launch(server_name=args.server_name, server_port=args.port, share=args.share)
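The demo is normally launched with python app.py, using the --server-name, --port, and --share flags defined above. The sketch below is an untested illustration of driving the same pipeline without the Gradio UI by calling FoleyController.foley directly; the input video path is hypothetical, and importing app also executes its module-level argument parsing and model loading.

# Untested sketch: programmatic use of the controller defined in app.py.
from app import controller  # triggers checkpoint download and model loading

output_mp4 = controller.foley(
    input_video="examples/case.mp4",   # hypothetical local video path
    prompt_textbox="",
    negative_prompt_textbox="",
    ip_adapter_scale=1.0,              # "Visual Content Scale" in the UI
    temporal_scale=0.2,                # "Temporal Align Scale" in the UI
    sampler_dropdown="DDIM",
    sample_step_slider=25,
    cfg_scale_slider=7.5,
    seed_textbox="42",
)
print(output_mp4)  # path of the generated .mp4 with the synthesized audio track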
configs/auffusion/vocoder/config.json ADDED
@@ -0,0 +1,37 @@
+ {
+     "resblock": "1",
+     "num_gpus": 0,
+     "batch_size": 16,
+     "learning_rate": 0.0002,
+     "adam_b1": 0.8,
+     "adam_b2": 0.99,
+     "lr_decay": 0.999,
+     "seed": 1234,
+
+     "upsample_rates": [5,4,4,2],
+     "upsample_kernel_sizes": [11,8,8,4],
+     "upsample_initial_channel": 512,
+     "resblock_kernel_sizes": [3,7,11],
+     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+     "segment_size": 5120,
+     "num_mels": 256,
+     "num_freq": 2049,
+     "n_fft": 2048,
+     "hop_size": 160,
+     "win_size": 1024,
+
+     "sampling_rate": 16000,
+
+     "fmin": 0,
+     "fmax": null,
+     "fmax_for_loss": null,
+
+     "num_workers": 4,
+
+     "dist_config": {
+         "dist_backend": "nccl",
+         "dist_url": "tcp://localhost:54321",
+         "world_size": 1
+     }
+ }
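For orientation, these vocoder settings are consistent with how app.py uses the spectrograms: with hop_size 160 at a 16 kHz sampling_rate, each spectrogram frame covers 10 ms, so the 1024-frame-wide spectrogram generated there spans roughly 10.24 s, matching the 10-second duration cap and the lengths=160000 passed to the vocoder. A quick back-of-the-envelope check:

hop_size, sampling_rate, spec_width = 160, 16000, 1024  # values from this config and app.py
seconds_per_frame = hop_size / sampling_rate            # 0.01 s per spectrogram frame
print(spec_width * seconds_per_frame)                   # ~10.24 s of audio per generated spectrogram
print(160000 / sampling_rate)                           # 10.0 s requested from the vocoder in app.py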
configs/train/train_semantic_adapter.yaml ADDED
@@ -0,0 +1,54 @@
+ output_dir: "outputs"
+
+ pretrained_model_path: ""
+
+ motion_module_path: "models/mm_sd_v15_v2.ckpt"
+
+ train_data:
+   csv_path: "./curated.csv"
+   audio_fps: 48000
+   audio_size: 480000
+
+ validation_data:
+   prompts:
+     - "./data/input/lighthouse.png"
+     - "./data/input/guitar.png"
+     - "./data/input/lion.png"
+     - "./data/input/gun.png"
+   num_inference_steps: 25
+   guidance_scale: 7.5
+   sample_size: 512
+
+ trainable_modules:
+   - 'to_k_ip'
+   - 'to_v_ip'
+
+ audio_unet_checkpoint_path: ""
+
+ learning_rate: 1.0e-4
+ train_batch_size: 1 # max for mixed
+ gradient_accumulation_steps: 1
+
+ max_train_epoch: -1
+ max_train_steps: 200000
+ checkpointing_epochs: 4000
+ checkpointing_steps: 500
+
+ validation_steps: 3000
+ validation_steps_tuple: [2, 50, 300, 1000]
+
+ global_seed: 42
+ mixed_precision_training: true
+
+ is_debug: False
+
+ resume_ckpt: ""
+
+ # params for adapter
+ init_from_ip_adapter: false
+
+ always_null_text: false
+
+ reverse_null_text_prob: true
+
+ frame_wise_condition: true
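The training entry point that consumes this file is not among the 50 files shown here, so the access pattern below is only an assumed sketch; it loads the YAML with OmegaConf (pinned in environment.yaml) and reads a few of the fields defined above.

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/train/train_semantic_adapter.yaml")
print(cfg.learning_rate)               # 0.0001
print(list(cfg.trainable_modules))     # ['to_k_ip', 'to_v_ip'] -> only the IP-Adapter K/V projections are unfrozen
print(cfg.validation_data.prompts[0])  # ./data/input/lighthouse.png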
configs/train/train_temporal_adapter.yaml ADDED
@@ -0,0 +1,48 @@
+ output_dir: "outputs"
+
+ pretrained_model_path: ""
+
+ motion_module_path: "models/mm_sd_v15_v2.ckpt"
+
+ train_data:
+   csv_path: "./curated.csv"
+   audio_fps: 48000
+   audio_size: 480000
+
+ validation_data:
+   prompts:
+     - "./data/input/lighthouse.png"
+     - "./data/input/guitar.png"
+     - "./data/input/lion.png"
+     - "./data/input/gun.png"
+   num_inference_steps: 25
+   guidance_scale: 7.5
+   sample_size: 512
+
+ trainable_modules:
+   - 'time_conv_in.'
+   - 'conv_in.'
+
+ video_unet_checkpoint_path: "models/vggsound_unet.ckpt"
+ audio_unet_checkpoint_path: ""
+
+ learning_rate: 5.0e-5
+ train_batch_size: 1 # max for mixed
+ gradient_accumulation_steps: 1
+
+ max_train_epoch: -1
+ max_train_steps: 500000
+ checkpointing_epochs: 4000
+ checkpointing_steps: 500
+
+ validation_steps: 3000
+ validation_steps_tuple: [2, 300, 1000]
+
+ global_seed: 42
+ mixed_precision_training: true
+
+ is_debug: False
+
+ resume_ckpt: ""
+
+ zero_no_label_mel: false
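This config mirrors the semantic one but unfreezes the 'time_conv_in.' and 'conv_in.' prefixes and starts from models/vggsound_unet.ckpt. How the training script matches trainable_modules against parameter names is not visible in this commit; the snippet below only illustrates the conventional substring-matching pattern on a stand-in module.

import torch.nn as nn

# stand-in module; the real video UNet is not part of this view
model = nn.ModuleDict({
    "time_conv_in": nn.Conv2d(4, 4, 3),
    "conv_in": nn.Conv2d(4, 4, 3),
    "mid_block": nn.Conv2d(4, 4, 3),
})
trainable_modules = ("time_conv_in.", "conv_in.")

for name, param in model.named_parameters():
    param.requires_grad = any(key in name for key in trainable_modules)
    print(name, param.requires_grad)  # only the *conv_in* parameters stay trainable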
environment.yaml ADDED
@@ -0,0 +1,24 @@
+ name: foleycrafter
+ channels:
+   - pytorch
+   - nvidia
+ dependencies:
+   - python=3.10
+   - pytorch=2.2.0
+   - torchvision=0.17.0
+   - pytorch-cuda=11.8
+   - pip
+   - pip:
+     - diffusers==0.25.1
+     - transformers==4.30.2
+     - xformers
+     - imageio==2.33.1
+     - decord==0.6.0
+     - einops
+     - omegaconf
+     - safetensors
+     - gradio
+     - tqdm==4.66.1
+     - soundfile==0.12.1
+     - wandb
+     - moviepy==1.0.3
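The environment above is created with the standard conda workflow (conda env create -f environment.yaml, then conda activate foleycrafter). An illustrative post-install check for the pinned torch / CUDA combination:

import torch
import torchvision

print(torch.__version__)         # expected to start with 2.2.0
print(torchvision.__version__)   # expected to start with 0.17.0
print(torch.version.cuda)        # 11.8 to match the pytorch-cuda pin
print(torch.cuda.is_available())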
foleycrafter/data/dataset.py ADDED
@@ -0,0 +1,175 @@
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from torch.utils.data.dataset import Dataset
4
+ import torch.distributed as dist
5
+ import torchaudio
6
+ import torchvision
7
+ import torchvision.io
8
+
9
+ import os, io, csv, math, random
10
+ import os.path as osp
11
+ from pathlib import Path
12
+ import numpy as np
13
+ import pandas as pd
14
+ from einops import rearrange
15
+ import glob
16
+
17
+ from decord import VideoReader, AudioReader
18
+ import decord
19
+ from copy import deepcopy
20
+ import pickle
21
+
22
+ from petrel_client.client import Client
23
+ import sys
24
+ sys.path.append('./')
25
+ from foleycrafter.data import video_transforms
26
+
27
+ from foleycrafter.utils.util import \
28
+ random_audio_video_clip, get_full_indices, video_tensor_to_np, get_video_frames
29
+ from foleycrafter.utils.spec_to_mel import wav_tensor_to_fbank, read_wav_file_io, load_audio, normalize_wav, pad_wav
30
+ from foleycrafter.utils.converter import get_mel_spectrogram_from_audio, pad_spec, normalize, normalize_spectrogram
31
+
32
+ def zero_rank_print(s):
33
+ if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s, flush=True)
34
+
35
+ @torch.no_grad()
36
+ def get_mel(audio_data, audio_cfg):
37
+ # mel shape: (n_mels, T)
38
+ mel = torchaudio.transforms.MelSpectrogram(
39
+ sample_rate=audio_cfg["sample_rate"],
40
+ n_fft=audio_cfg["window_size"],
41
+ win_length=audio_cfg["window_size"],
42
+ hop_length=audio_cfg["hop_size"],
43
+ center=True,
44
+ pad_mode="reflect",
45
+ power=2.0,
46
+ norm=None,
47
+ onesided=True,
48
+ n_mels=64,
49
+ f_min=audio_cfg["fmin"],
50
+ f_max=audio_cfg["fmax"],
51
+ ).to(audio_data.device)
52
+ mel = mel(audio_data)
53
+ # we use log mel spectrogram as input
54
+ mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
55
+ return mel # (n_mels, T)
56
+
57
+ def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
58
+ """
59
+ PARAMS
60
+ ------
61
+ C: compression factor
62
+ """
63
+ return normalize_fun(torch.clamp(x, min=clip_val) * C)
64
+
65
+ class CPU_Unpickler(pickle.Unpickler):
66
+ def find_class(self, module, name):
67
+ if module == 'torch.storage' and name == '_load_from_bytes':
68
+ return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
69
+ else:
70
+ return super().find_class(module, name)
71
+
72
+ class AudioSetStrong(Dataset):
73
+ # read feature and audio
74
+ def __init__(
75
+ self,
76
+ ):
77
+ super().__init__()
78
+ self.data_path = 'data/AudioSetStrong/train/feature'
79
+ self.data_list = list(self._client.list(self.data_path))
80
+ self.length = len(self.data_list)
81
+ # get video feature
82
+ self.video_path = 'data/AudioSetStrong/train/video'
83
+ vision_transform_list = [
84
+ transforms.Resize((128, 128)),
85
+ transforms.CenterCrop((112, 112)),
86
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
87
+ ]
88
+ self.video_transform = transforms.Compose(vision_transform_list)
89
+
90
+ def get_batch(self, idx):
91
+ embeds = self.data_list[idx]
92
+ mel = embeds['mel']
93
+ save_bsz = mel.shape[0]
94
+ audio_info = embeds['audio_info']
95
+ text_embeds = embeds['text_embeds']
96
+
97
+ # audio_info['label_list'] = np.array(audio_info['label_list'])
98
+ audio_info_array = np.array(audio_info['label_list'])
99
+ prompts = []
100
+ for i in range(save_bsz):
101
+ prompts.append(', '.join(audio_info_array[i, :audio_info['event_num'][i]].tolist()))
102
+ # import ipdb; ipdb.set_trace()
103
+ # read videos
104
+ videos = None
105
+ for video_name in audio_info['audio_name']:
106
+ video_bytes = self._client.Get(osp.join(self.video_path, video_name+'.mp4'))
107
+ video_bytes = io.BytesIO(video_bytes)
108
+ video_reader = VideoReader(video_bytes)
109
+ video = video_reader.get_batch(get_full_indices(video_reader)).asnumpy()
110
+ video = get_video_frames(video, 150)
111
+ video = torch.from_numpy(video).permute(0, 3, 1, 2).contiguous().float()
112
+ video = self.video_transform(video)
113
+ video = video.unsqueeze(0)
114
+ if videos is None:
115
+ videos = video
116
+ else:
117
+ videos = torch.cat([videos, video], dim=0)
118
+ # video = torch.from_numpy(video).permute(0, 3, 1, 2).contiguous()
119
+ assert videos is not None, 'no video read'
120
+
121
+ return mel, audio_info, text_embeds, prompts, videos
122
+
123
+ def __len__(self):
124
+ return self.length
125
+
126
+ def __getitem__(self, idx):
127
+ while True:
128
+ try:
129
+ mel, audio_info, text_embeds, prompts, videos = self.get_batch(idx)
130
+ break
131
+ except Exception as e:
132
+ zero_rank_print(' >>> load error <<<')
133
+ idx = random.randint(0, self.length-1)
134
+ sample = dict(mel=mel, audio_info=audio_info, text_embeds=text_embeds, prompts=prompts, videos=videos)
135
+ return sample
136
+
137
+ class VGGSound(Dataset):
138
+ # read feature and audio
139
+ def __init__(
140
+ self,
141
+ ):
142
+ super().__init__()
143
+ self.data_path = 'data/VGGSound/train/video'
144
+ self.visual_data_path = 'data/VGGSound/train/feature'
145
+ self.embeds_list = glob.glob(f'{self.data_path}/*.pt')
146
+ self.visual_list = glob.glob(f'{self.visual_data_path}/*.pt')
147
+ self.length = len(self.embeds_list)
148
+
149
+ def get_batch(self, idx):
150
+ embeds = torch.load(self.embeds_list[idx], map_location='cpu')
151
+ visual_embeds = torch.load(self.visual_list[idx], map_location='cpu')
152
+
153
+ # audio_embeds = embeds['audio_embeds']
154
+ visual_embeds = visual_embeds['visual_embeds']
155
+ video_name = embeds['video_name']
156
+ text = embeds['text']
157
+ mel = embeds['mel']
158
+
159
+ audio = mel
160
+
161
+ return visual_embeds, audio, text
162
+
163
+ def __len__(self):
164
+ return self.length
165
+
166
+ def __getitem__(self, idx):
167
+ while True:
168
+ try:
169
+ visual_embeds, audio, text = self.get_batch(idx)
170
+ break
171
+ except Exception as e:
172
+ zero_rank_print('load error')
173
+ idx = random.randint(0, self.length-1)
174
+ sample = dict(visual_embeds=visual_embeds, audio=audio, text=text)
175
+ return sample
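A minimal sketch of consuming the VGGSound dataset above with a standard DataLoader. It assumes the precomputed *.pt feature files exist under the hard-coded data/VGGSound/train paths, that their tensors share consistent shapes so the default collate works, and that petrel_client (imported at module level) is installed; batch size and worker count are illustrative.

from torch.utils.data import DataLoader
from foleycrafter.data.dataset import VGGSound

dataset = VGGSound()
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

batch = next(iter(loader))
# each item is the dict returned by __getitem__: visual_embeds, audio (mel), text
print(batch["visual_embeds"].shape, batch["audio"].shape, len(batch["text"]))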
foleycrafter/data/video_transforms.py ADDED
@@ -0,0 +1,400 @@
1
+ import torch
2
+ import random
3
+ import numbers
4
+ from torchvision.transforms import RandomCrop, RandomResizedCrop
5
+
6
+ def _is_tensor_video_clip(clip):
7
+ if not torch.is_tensor(clip):
8
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
9
+
10
+ if not clip.ndimension() == 4:
11
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
12
+
13
+ return True
14
+
15
+
16
+ def crop(clip, i, j, h, w):
17
+ """
18
+ Args:
19
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
20
+ """
21
+ if len(clip.size()) != 4:
22
+ raise ValueError("clip should be a 4D tensor")
23
+ return clip[..., i : i + h, j : j + w]
24
+
25
+
26
+ def resize(clip, target_size, interpolation_mode):
27
+ if len(target_size) != 2:
28
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
29
+ return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
30
+
31
+ def resize_scale(clip, target_size, interpolation_mode):
32
+ if len(target_size) != 2:
33
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
34
+ _, _, H, W = clip.shape
35
+ scale_ = target_size[0] / min(H, W)
36
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
37
+
38
+
39
+ def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
40
+ """
41
+ Do spatial cropping and resizing to the video clip
42
+ Args:
43
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
44
+ i (int): i in (i,j) i.e coordinates of the upper left corner.
45
+ j (int): j in (i,j) i.e coordinates of the upper left corner.
46
+ h (int): Height of the cropped region.
47
+ w (int): Width of the cropped region.
48
+ size (tuple(int, int)): height and width of resized clip
49
+ Returns:
50
+ clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
51
+ """
52
+ if not _is_tensor_video_clip(clip):
53
+ raise ValueError("clip should be a 4D torch.tensor")
54
+ clip = crop(clip, i, j, h, w)
55
+ clip = resize(clip, size, interpolation_mode)
56
+ return clip
57
+
58
+
59
+ def center_crop(clip, crop_size):
60
+ if not _is_tensor_video_clip(clip):
61
+ raise ValueError("clip should be a 4D torch.tensor")
62
+ h, w = clip.size(-2), clip.size(-1)
63
+ th, tw = crop_size
64
+ if h < th or w < tw:
65
+ raise ValueError("height and width must be no smaller than crop_size")
66
+
67
+ i = int(round((h - th) / 2.0))
68
+ j = int(round((w - tw) / 2.0))
69
+ return crop(clip, i, j, th, tw)
70
+
71
+ def random_shift_crop(clip):
72
+ '''
73
+ Slide along the long edge, with the short edge as crop size
74
+ '''
75
+ if not _is_tensor_video_clip(clip):
76
+ raise ValueError("clip should be a 4D torch.tensor")
77
+ h, w = clip.size(-2), clip.size(-1)
78
+
79
+ if h <= w:
80
+ long_edge = w
81
+ short_edge = h
82
+ else:
83
+ long_edge = h
84
+ short_edge = w
85
+
86
+ th, tw = short_edge, short_edge
87
+
88
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
89
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
90
+ return crop(clip, i, j, th, tw)
91
+
92
+
93
+ def to_tensor(clip):
94
+ """
95
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
96
+ permute the dimensions of clip tensor
97
+ Args:
98
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
99
+ Return:
100
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
101
+ """
102
+ _is_tensor_video_clip(clip)
103
+ if not clip.dtype == torch.uint8:
104
+ raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
105
+ # return clip.float().permute(3, 0, 1, 2) / 255.0
106
+ return clip.float() / 255.0
107
+
108
+
109
+ def normalize(clip, mean, std, inplace=False):
110
+ """
111
+ Args:
112
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
113
+ mean (tuple): pixel RGB mean. Size is (3)
114
+ std (tuple): pixel standard deviation. Size is (3)
115
+ Returns:
116
+ normalized clip (torch.tensor): Size is (T, C, H, W)
117
+ """
118
+ if not _is_tensor_video_clip(clip):
119
+ raise ValueError("clip should be a 4D torch.tensor")
120
+ if not inplace:
121
+ clip = clip.clone()
122
+ mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
+ std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
125
+ clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
126
+ return clip
127
+
128
+
129
+ def hflip(clip):
130
+ """
131
+ Args:
132
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
133
+ Returns:
134
+ flipped clip (torch.tensor): Size is (T, C, H, W)
135
+ """
136
+ if not _is_tensor_video_clip(clip):
137
+ raise ValueError("clip should be a 4D torch.tensor")
138
+ return clip.flip(-1)
139
+
140
+
141
+ class RandomCropVideo:
142
+ def __init__(self, size):
143
+ if isinstance(size, numbers.Number):
144
+ self.size = (int(size), int(size))
145
+ else:
146
+ self.size = size
147
+
148
+ def __call__(self, clip):
149
+ """
150
+ Args:
151
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
152
+ Returns:
153
+ torch.tensor: randomly cropped video clip.
154
+ size is (T, C, OH, OW)
155
+ """
156
+ i, j, h, w = self.get_params(clip)
157
+ return crop(clip, i, j, h, w)
158
+
159
+ def get_params(self, clip):
160
+ h, w = clip.shape[-2:]
161
+ th, tw = self.size
162
+
163
+ if h < th or w < tw:
164
+ raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
165
+
166
+ if w == tw and h == th:
167
+ return 0, 0, h, w
168
+
169
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
170
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
171
+
172
+ return i, j, th, tw
173
+
174
+ def __repr__(self) -> str:
175
+ return f"{self.__class__.__name__}(size={self.size})"
176
+
177
+
178
+ class UCFCenterCropVideo:
179
+ def __init__(
180
+ self,
181
+ size,
182
+ interpolation_mode="bilinear",
183
+ ):
184
+ if isinstance(size, tuple):
185
+ if len(size) != 2:
186
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
187
+ self.size = size
188
+ else:
189
+ self.size = (size, size)
190
+
191
+ self.interpolation_mode = interpolation_mode
192
+
193
+
194
+ def __call__(self, clip):
195
+ """
196
+ Args:
197
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
198
+ Returns:
199
+ torch.tensor: scale resized / center cropped video clip.
200
+ size is (T, C, crop_size, crop_size)
201
+ """
202
+ clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
203
+ clip_center_crop = center_crop(clip_resize, self.size)
204
+ return clip_center_crop
205
+
206
+ def __repr__(self) -> str:
207
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
208
+
209
+ class KineticsRandomCropResizeVideo:
210
+ '''
211
+ Slide along the long edge, with the short edge as crop size, and resize to the desired size.
212
+ '''
213
+ def __init__(
214
+ self,
215
+ size,
216
+ interpolation_mode="bilinear",
217
+ ):
218
+ if isinstance(size, tuple):
219
+ if len(size) != 2:
220
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
221
+ self.size = size
222
+ else:
223
+ self.size = (size, size)
224
+
225
+ self.interpolation_mode = interpolation_mode
226
+
227
+ def __call__(self, clip):
228
+ clip_random_crop = random_shift_crop(clip)
229
+ clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
230
+ return clip_resize
231
+
232
+
233
+ class CenterCropVideo:
234
+ def __init__(
235
+ self,
236
+ size,
237
+ interpolation_mode="bilinear",
238
+ ):
239
+ if isinstance(size, tuple):
240
+ if len(size) != 2:
241
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
242
+ self.size = size
243
+ else:
244
+ self.size = (size, size)
245
+
246
+ self.interpolation_mode = interpolation_mode
247
+
248
+
249
+ def __call__(self, clip):
250
+ """
251
+ Args:
252
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
253
+ Returns:
254
+ torch.tensor: center cropped video clip.
255
+ size is (T, C, crop_size, crop_size)
256
+ """
257
+ clip_center_crop = center_crop(clip, self.size)
258
+ return clip_center_crop
259
+
260
+ def __repr__(self) -> str:
261
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
262
+
263
+
264
+ class NormalizeVideo:
265
+ """
266
+ Normalize the video clip by mean subtraction and division by standard deviation
267
+ Args:
268
+ mean (3-tuple): pixel RGB mean
269
+ std (3-tuple): pixel RGB standard deviation
270
+ inplace (boolean): whether do in-place normalization
271
+ """
272
+
273
+ def __init__(self, mean, std, inplace=False):
274
+ self.mean = mean
275
+ self.std = std
276
+ self.inplace = inplace
277
+
278
+ def __call__(self, clip):
279
+ """
280
+ Args:
281
+ clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
282
+ """
283
+ return normalize(clip, self.mean, self.std, self.inplace)
284
+
285
+ def __repr__(self) -> str:
286
+ return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
287
+
288
+
289
+ class ToTensorVideo:
290
+ """
291
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
292
+ permute the dimensions of clip tensor
293
+ """
294
+
295
+ def __init__(self):
296
+ pass
297
+
298
+ def __call__(self, clip):
299
+ """
300
+ Args:
301
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
302
+ Return:
303
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
304
+ """
305
+ return to_tensor(clip)
306
+
307
+ def __repr__(self) -> str:
308
+ return self.__class__.__name__
309
+
310
+
311
+ class RandomHorizontalFlipVideo:
312
+ """
313
+ Flip the video clip along the horizontal direction with a given probability
314
+ Args:
315
+ p (float): probability of the clip being flipped. Default value is 0.5
316
+ """
317
+
318
+ def __init__(self, p=0.5):
319
+ self.p = p
320
+
321
+ def __call__(self, clip):
322
+ """
323
+ Args:
324
+ clip (torch.tensor): Size is (T, C, H, W)
325
+ Return:
326
+ clip (torch.tensor): Size is (T, C, H, W)
327
+ """
328
+ if random.random() < self.p:
329
+ clip = hflip(clip)
330
+ return clip
331
+
332
+ def __repr__(self) -> str:
333
+ return f"{self.__class__.__name__}(p={self.p})"
334
+
335
+ # ------------------------------------------------------------
336
+ # --------------------- Sampling ---------------------------
337
+ # ------------------------------------------------------------
338
+ class TemporalRandomCrop(object):
339
+ """Temporally crop the given frame indices at a random location.
340
+
341
+ Args:
342
+ size (int): Desired length of frames will be seen in the model.
343
+ """
344
+
345
+ def __init__(self, size):
346
+ self.size = size
347
+
348
+ def __call__(self, total_frames):
349
+ rand_end = max(0, total_frames - self.size - 1)
350
+ begin_index = random.randint(0, rand_end)
351
+ end_index = min(begin_index + self.size, total_frames)
352
+ return begin_index, end_index
353
+
354
+
355
+ if __name__ == '__main__':
356
+ from torchvision import transforms
357
+ import torchvision.io as io
358
+ import numpy as np
359
+ from torchvision.utils import save_image
360
+ import os
361
+
362
+ vframes, aframes, info = io.read_video(
363
+ filename='./v_Archery_g01_c03.avi',
364
+ pts_unit='sec',
365
+ output_format='TCHW'
366
+ )
367
+
368
+ trans = transforms.Compose([
369
+ ToTensorVideo(),
370
+ RandomHorizontalFlipVideo(),
371
+ UCFCenterCropVideo(512),
372
+ # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
373
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
374
+ ])
375
+
376
+ target_video_len = 32
377
+ frame_interval = 1
378
+ total_frames = len(vframes)
379
+ print(total_frames)
380
+
381
+ temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
382
+
383
+
384
+ # Sampling video frames
385
+ start_frame_ind, end_frame_ind = temporal_sample(total_frames)
386
+ # print(start_frame_ind)
387
+ # print(end_frame_ind)
388
+ assert end_frame_ind - start_frame_ind >= target_video_len
389
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
390
+
391
+ select_vframes = vframes[frame_indice]
392
+
393
+ select_vframes_trans = trans(select_vframes)
394
+
395
+ select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
396
+
397
+ io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
398
+
399
+ for i in range(target_video_len):
400
+ save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True, value_range=(-1, 1))
foleycrafter/models/adapters/attention_processor.py ADDED
@@ -0,0 +1,653 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Union
5
+ from einops import rearrange, repeat
6
+
7
+ from diffusers.utils import logging
8
+ from foleycrafter.models.adapters.ip_adapter import MLPProjModel
9
+
10
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
11
+
12
+ class AttnProcessor(nn.Module):
13
+ r"""
14
+ Default processor for performing attention-related computations.
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ hidden_size=None,
20
+ cross_attention_dim=None,
21
+ ):
22
+ super().__init__()
23
+
24
+ def __call__(
25
+ self,
26
+ attn,
27
+ hidden_states,
28
+ encoder_hidden_states=None,
29
+ attention_mask=None,
30
+ temb=None,
31
+ ):
32
+ residual = hidden_states
33
+
34
+ if attn.spatial_norm is not None:
35
+ hidden_states = attn.spatial_norm(hidden_states, temb)
36
+
37
+ input_ndim = hidden_states.ndim
38
+
39
+ if input_ndim == 4:
40
+ batch_size, channel, height, width = hidden_states.shape
41
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
42
+
43
+ batch_size, sequence_length, _ = (
44
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
45
+ )
46
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
47
+
48
+ if attn.group_norm is not None:
49
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
50
+
51
+ query = attn.to_q(hidden_states)
52
+
53
+ if encoder_hidden_states is None:
54
+ encoder_hidden_states = hidden_states
55
+ elif attn.norm_cross:
56
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
57
+
58
+ key = attn.to_k(encoder_hidden_states)
59
+ value = attn.to_v(encoder_hidden_states)
60
+
61
+ query = attn.head_to_batch_dim(query)
62
+ key = attn.head_to_batch_dim(key)
63
+ value = attn.head_to_batch_dim(value)
64
+
65
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
66
+ hidden_states = torch.bmm(attention_probs, value)
67
+ hidden_states = attn.batch_to_head_dim(hidden_states)
68
+
69
+ # linear proj
70
+ hidden_states = attn.to_out[0](hidden_states)
71
+ # dropout
72
+ hidden_states = attn.to_out[1](hidden_states)
73
+
74
+ if input_ndim == 4:
75
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
76
+
77
+ if attn.residual_connection:
78
+ hidden_states = hidden_states + residual
79
+
80
+ hidden_states = hidden_states / attn.rescale_output_factor
81
+
82
+ return hidden_states
83
+
84
+
85
+ class IPAttnProcessor(nn.Module):
86
+ r"""
87
+ Attention processor for IP-Adapter.
88
+ Args:
89
+ hidden_size (`int`):
90
+ The hidden size of the attention layer.
91
+ cross_attention_dim (`int`):
92
+ The number of channels in the `encoder_hidden_states`.
93
+ scale (`float`, defaults to 1.0):
94
+ the weight scale of image prompt.
95
+ num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
96
+ The context length of the image features.
97
+ """
98
+
99
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
100
+ super().__init__()
101
+
102
+ self.hidden_size = hidden_size
103
+ self.cross_attention_dim = cross_attention_dim
104
+ self.scale = scale
105
+ self.num_tokens = num_tokens
106
+
107
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
108
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
109
+
110
+ def __call__(
111
+ self,
112
+ attn,
113
+ hidden_states,
114
+ encoder_hidden_states=None,
115
+ attention_mask=None,
116
+ temb=None,
117
+ ):
118
+ residual = hidden_states
119
+
120
+ if attn.spatial_norm is not None:
121
+ hidden_states = attn.spatial_norm(hidden_states, temb)
122
+
123
+ input_ndim = hidden_states.ndim
124
+
125
+ if input_ndim == 4:
126
+ batch_size, channel, height, width = hidden_states.shape
127
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
128
+
129
+ batch_size, sequence_length, _ = (
130
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
131
+ )
132
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
133
+
134
+ if attn.group_norm is not None:
135
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
136
+
137
+ query = attn.to_q(hidden_states)
138
+
139
+ if encoder_hidden_states is None:
140
+ encoder_hidden_states = hidden_states
141
+ else:
142
+ # get encoder_hidden_states, ip_hidden_states
143
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
144
+ encoder_hidden_states, ip_hidden_states = (
145
+ encoder_hidden_states[:, :end_pos, :],
146
+ encoder_hidden_states[:, end_pos:, :],
147
+ )
148
+ if attn.norm_cross:
149
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
150
+
151
+ key = attn.to_k(encoder_hidden_states)
152
+ value = attn.to_v(encoder_hidden_states)
153
+
154
+ query = attn.head_to_batch_dim(query)
155
+ key = attn.head_to_batch_dim(key)
156
+ value = attn.head_to_batch_dim(value)
157
+
158
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
159
+ hidden_states = torch.bmm(attention_probs, value)
160
+ hidden_states = attn.batch_to_head_dim(hidden_states)
161
+
162
+ # for ip-adapter
163
+ ip_key = self.to_k_ip(ip_hidden_states)
164
+ ip_value = self.to_v_ip(ip_hidden_states)
165
+
166
+ ip_key = attn.head_to_batch_dim(ip_key)
167
+ ip_value = attn.head_to_batch_dim(ip_value)
168
+
169
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
170
+ self.attn_map = ip_attention_probs
171
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
172
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
173
+
174
+ hidden_states = hidden_states + self.scale * ip_hidden_states
175
+
176
+ # linear proj
177
+ hidden_states = attn.to_out[0](hidden_states)
178
+ # dropout
179
+ hidden_states = attn.to_out[1](hidden_states)
180
+
181
+ if input_ndim == 4:
182
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
183
+
184
+ if attn.residual_connection:
185
+ hidden_states = hidden_states + residual
186
+
187
+ hidden_states = hidden_states / attn.rescale_output_factor
188
+
189
+ return hidden_states
190
+
191
+
192
+ class AttnProcessor2_0(torch.nn.Module):
193
+ r"""
194
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
195
+ """
196
+
197
+ def __init__(
198
+ self,
199
+ hidden_size=None,
200
+ cross_attention_dim=None,
201
+ ):
202
+ super().__init__()
203
+ if not hasattr(F, "scaled_dot_product_attention"):
204
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
205
+
206
+ def __call__(
207
+ self,
208
+ attn,
209
+ hidden_states,
210
+ encoder_hidden_states=None,
211
+ attention_mask=None,
212
+ temb=None,
213
+ ):
214
+ residual = hidden_states
215
+
216
+ if attn.spatial_norm is not None:
217
+ hidden_states = attn.spatial_norm(hidden_states, temb)
218
+
219
+ input_ndim = hidden_states.ndim
220
+
221
+ if input_ndim == 4:
222
+ batch_size, channel, height, width = hidden_states.shape
223
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
224
+
225
+ batch_size, sequence_length, _ = (
226
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
227
+ )
228
+
229
+ if attention_mask is not None:
230
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
231
+ # scaled_dot_product_attention expects attention_mask shape to be
232
+ # (batch, heads, source_length, target_length)
233
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
234
+
235
+ if attn.group_norm is not None:
236
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
237
+
238
+ query = attn.to_q(hidden_states)
239
+
240
+ if encoder_hidden_states is None:
241
+ encoder_hidden_states = hidden_states
242
+ elif attn.norm_cross:
243
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
244
+
245
+ key = attn.to_k(encoder_hidden_states)
246
+ value = attn.to_v(encoder_hidden_states)
247
+
248
+ inner_dim = key.shape[-1]
249
+ head_dim = inner_dim // attn.heads
250
+
251
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
252
+
253
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
254
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
255
+
256
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
257
+ # TODO: add support for attn.scale when we move to Torch 2.1
258
+ hidden_states = F.scaled_dot_product_attention(
259
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
260
+ )
261
+
262
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
263
+ hidden_states = hidden_states.to(query.dtype)
264
+
265
+ # linear proj
266
+ hidden_states = attn.to_out[0](hidden_states)
267
+ # dropout
268
+ hidden_states = attn.to_out[1](hidden_states)
269
+
270
+ if input_ndim == 4:
271
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
272
+
273
+ if attn.residual_connection:
274
+ hidden_states = hidden_states + residual
275
+
276
+ hidden_states = hidden_states / attn.rescale_output_factor
277
+
278
+ return hidden_states
279
+
280
+ class AttnProcessor2_0WithProjection(torch.nn.Module):
281
+ r"""
282
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
283
+ """
284
+
285
+ def __init__(
286
+ self,
287
+ hidden_size=None,
288
+ cross_attention_dim=None,
289
+ ):
290
+ super().__init__()
291
+ if not hasattr(F, "scaled_dot_product_attention"):
292
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
293
+ self.before_proj_size = 1024
294
+ self.after_proj_size = 768
295
+ self.visual_proj = nn.Linear(self.before_proj_size, self.after_proj_size)
296
+
297
+ def __call__(
298
+ self,
299
+ attn,
300
+ hidden_states,
301
+ encoder_hidden_states=None,
302
+ attention_mask=None,
303
+ temb=None,
304
+ ):
305
+ residual = hidden_states
306
+ # encoder_hidden_states = self.visual_proj(encoder_hidden_states)
307
+
308
+ if attn.spatial_norm is not None:
309
+ hidden_states = attn.spatial_norm(hidden_states, temb)
310
+
311
+ input_ndim = hidden_states.ndim
312
+
313
+ if input_ndim == 4:
314
+ batch_size, channel, height, width = hidden_states.shape
315
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
316
+
317
+ batch_size, sequence_length, _ = (
318
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
319
+ )
320
+
321
+ if attention_mask is not None:
322
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
323
+ # scaled_dot_product_attention expects attention_mask shape to be
324
+ # (batch, heads, source_length, target_length)
325
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
326
+
327
+ if attn.group_norm is not None:
328
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
329
+
330
+ query = attn.to_q(hidden_states)
331
+
332
+ if encoder_hidden_states is None:
333
+ encoder_hidden_states = hidden_states
334
+ elif attn.norm_cross:
335
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
336
+
337
+ key = attn.to_k(encoder_hidden_states)
338
+ value = attn.to_v(encoder_hidden_states)
339
+
340
+ inner_dim = key.shape[-1]
341
+ head_dim = inner_dim // attn.heads
342
+
343
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
344
+
345
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
346
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
347
+
348
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
349
+ # TODO: add support for attn.scale when we move to Torch 2.1
350
+ hidden_states = F.scaled_dot_product_attention(
351
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
352
+ )
353
+
354
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
355
+ hidden_states = hidden_states.to(query.dtype)
356
+
357
+ # linear proj
358
+ hidden_states = attn.to_out[0](hidden_states)
359
+ # dropout
360
+ hidden_states = attn.to_out[1](hidden_states)
361
+
362
+ if input_ndim == 4:
363
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
364
+
365
+ if attn.residual_connection:
366
+ hidden_states = hidden_states + residual
367
+
368
+ hidden_states = hidden_states / attn.rescale_output_factor
369
+
370
+ return hidden_states
371
+
372
+ class IPAttnProcessor2_0(torch.nn.Module):
373
+ r"""
374
+ Attention processor for IP-Adapter for PyTorch 2.0.
375
+ Args:
376
+ hidden_size (`int`):
377
+ The hidden size of the attention layer.
378
+ cross_attention_dim (`int`):
379
+ The number of channels in the `encoder_hidden_states`.
380
+ scale (`float`, defaults to 1.0):
381
+ the weight scale of image prompt.
382
+ num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
383
+ The context length of the image features.
384
+ """
385
+
386
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
387
+ super().__init__()
388
+
389
+ if not hasattr(F, "scaled_dot_product_attention"):
390
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
391
+
392
+ self.hidden_size = hidden_size
393
+ self.cross_attention_dim = cross_attention_dim
394
+ self.scale = scale
395
+ self.num_tokens = num_tokens
396
+
397
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
398
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
399
+
400
+ def __call__(
401
+ self,
402
+ attn,
403
+ hidden_states,
404
+ encoder_hidden_states=None,
405
+ attention_mask=None,
406
+ temb=None,
407
+ ):
408
+ residual = hidden_states
409
+
410
+ if attn.spatial_norm is not None:
411
+ hidden_states = attn.spatial_norm(hidden_states, temb)
412
+
413
+ input_ndim = hidden_states.ndim
414
+
415
+ if input_ndim == 4:
416
+ batch_size, channel, height, width = hidden_states.shape
417
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
418
+
419
+ batch_size, sequence_length, _ = (
420
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
421
+ )
422
+
423
+ if attention_mask is not None:
424
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
425
+ # scaled_dot_product_attention expects attention_mask shape to be
426
+ # (batch, heads, source_length, target_length)
427
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
428
+
429
+ if attn.group_norm is not None:
430
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
431
+
432
+ query = attn.to_q(hidden_states)
433
+
434
+ if encoder_hidden_states is None:
435
+ encoder_hidden_states = hidden_states
436
+ else:
437
+ # get encoder_hidden_states, ip_hidden_states
438
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
439
+ encoder_hidden_states, ip_hidden_states = (
440
+ encoder_hidden_states[:, :end_pos, :],
441
+ encoder_hidden_states[:, end_pos:, :],
442
+ )
443
+ if attn.norm_cross:
444
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
445
+
446
+ key = attn.to_k(encoder_hidden_states)
447
+ value = attn.to_v(encoder_hidden_states)
448
+
449
+ inner_dim = key.shape[-1]
450
+ head_dim = inner_dim // attn.heads
451
+
452
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
453
+
454
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
455
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
456
+
457
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
458
+ # TODO: add support for attn.scale when we move to Torch 2.1
459
+ hidden_states = F.scaled_dot_product_attention(
460
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
461
+ )
462
+
463
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
464
+ hidden_states = hidden_states.to(query.dtype)
465
+
466
+ # for ip-adapter
467
+ ip_key = self.to_k_ip(ip_hidden_states)
468
+ ip_value = self.to_v_ip(ip_hidden_states)
469
+
470
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
471
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
472
+
473
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
474
+ # TODO: add support for attn.scale when we move to Torch 2.1
475
+ ip_hidden_states = F.scaled_dot_product_attention(
476
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
477
+ )
478
+ with torch.no_grad():
479
+ self.attn_map = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1)
480
+ #print(self.attn_map.shape)
481
+
482
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
483
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
484
+
485
+ hidden_states = hidden_states + self.scale * ip_hidden_states
486
+
487
+ # linear proj
488
+ hidden_states = attn.to_out[0](hidden_states)
489
+ # dropout
490
+ hidden_states = attn.to_out[1](hidden_states)
491
+
492
+ if input_ndim == 4:
493
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
494
+
495
+ if attn.residual_connection:
496
+ hidden_states = hidden_states + residual
497
+
498
+ hidden_states = hidden_states / attn.rescale_output_factor
499
+
500
+ return hidden_states
501
+
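For reference, a minimal wiring sketch for the IP-Adapter processor above, assuming a standard diffusers `UNet2DConditionModel`; the base checkpoint, scale, and token count are illustrative only, and the plain `AttnProcessor2_0` is taken from diffusers rather than this file.

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0
from foleycrafter.models.adapters.attention_processor import IPAttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no image tokens); attn2 is cross-attention
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    if cross_attention_dim is None:
        attn_procs[name] = AttnProcessor2_0()
    else:
        attn_procs[name] = IPAttnProcessor2_0(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=4
        )
unet.set_attn_processor(attn_procs)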
502
+ ## for controlnet
503
+ class CNAttnProcessor:
504
+ r"""
505
+ Default processor for performing attention-related computations.
506
+ """
507
+
508
+ def __init__(self, num_tokens=4):
509
+ self.num_tokens = num_tokens
510
+
511
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
512
+ residual = hidden_states
513
+
514
+ if attn.spatial_norm is not None:
515
+ hidden_states = attn.spatial_norm(hidden_states, temb)
516
+
517
+ input_ndim = hidden_states.ndim
518
+
519
+ if input_ndim == 4:
520
+ batch_size, channel, height, width = hidden_states.shape
521
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
522
+
523
+ batch_size, sequence_length, _ = (
524
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
525
+ )
526
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
527
+
528
+ if attn.group_norm is not None:
529
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
530
+
531
+ query = attn.to_q(hidden_states)
532
+
533
+ if encoder_hidden_states is None:
534
+ encoder_hidden_states = hidden_states
535
+ else:
536
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
537
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
538
+ if attn.norm_cross:
539
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
540
+
541
+ key = attn.to_k(encoder_hidden_states)
542
+ value = attn.to_v(encoder_hidden_states)
543
+
544
+ query = attn.head_to_batch_dim(query)
545
+ key = attn.head_to_batch_dim(key)
546
+ value = attn.head_to_batch_dim(value)
547
+
548
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
549
+ hidden_states = torch.bmm(attention_probs, value)
550
+ hidden_states = attn.batch_to_head_dim(hidden_states)
551
+
552
+ # linear proj
553
+ hidden_states = attn.to_out[0](hidden_states)
554
+ # dropout
555
+ hidden_states = attn.to_out[1](hidden_states)
556
+
557
+ if input_ndim == 4:
558
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
559
+
560
+ if attn.residual_connection:
561
+ hidden_states = hidden_states + residual
562
+
563
+ hidden_states = hidden_states / attn.rescale_output_factor
564
+
565
+ return hidden_states
566
+
567
+
568
+ class CNAttnProcessor2_0:
569
+ r"""
570
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
571
+ """
572
+
573
+ def __init__(self, num_tokens=4):
574
+ if not hasattr(F, "scaled_dot_product_attention"):
575
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
576
+ self.num_tokens = num_tokens
577
+
578
+ def __call__(
579
+ self,
580
+ attn,
581
+ hidden_states,
582
+ encoder_hidden_states=None,
583
+ attention_mask=None,
584
+ temb=None,
585
+ ):
586
+ residual = hidden_states
587
+
588
+ if attn.spatial_norm is not None:
589
+ hidden_states = attn.spatial_norm(hidden_states, temb)
590
+
591
+ input_ndim = hidden_states.ndim
592
+
593
+ if input_ndim == 4:
594
+ batch_size, channel, height, width = hidden_states.shape
595
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
596
+
597
+ batch_size, sequence_length, _ = (
598
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
599
+ )
600
+
601
+ if attention_mask is not None:
602
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
603
+ # scaled_dot_product_attention expects attention_mask shape to be
604
+ # (batch, heads, source_length, target_length)
605
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
606
+
607
+ if attn.group_norm is not None:
608
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
609
+
610
+ query = attn.to_q(hidden_states)
611
+
612
+ if encoder_hidden_states is None:
613
+ encoder_hidden_states = hidden_states
614
+ else:
615
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
616
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
617
+ if attn.norm_cross:
618
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
619
+
620
+ key = attn.to_k(encoder_hidden_states)
621
+ value = attn.to_v(encoder_hidden_states)
622
+
623
+ inner_dim = key.shape[-1]
624
+ head_dim = inner_dim // attn.heads
625
+
626
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
627
+
628
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
629
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
630
+
631
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
632
+ # TODO: add support for attn.scale when we move to Torch 2.1
633
+ hidden_states = F.scaled_dot_product_attention(
634
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
635
+ )
636
+
637
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
638
+ hidden_states = hidden_states.to(query.dtype)
639
+
640
+ # linear proj
641
+ hidden_states = attn.to_out[0](hidden_states)
642
+ # dropout
643
+ hidden_states = attn.to_out[1](hidden_states)
644
+
645
+ if input_ndim == 4:
646
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
647
+
648
+ if attn.residual_connection:
649
+ hidden_states = hidden_states + residual
650
+
651
+ hidden_states = hidden_states / attn.rescale_output_factor
652
+
653
+ return hidden_states
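Note on the ControlNet variants above: they simply drop the trailing `num_tokens` image-prompt tokens so the ControlNet branch attends to text tokens only. A tiny illustration of the split, with made-up sizes (77 CLIP text tokens plus 4 image tokens):

import torch

encoder_hidden_states = torch.randn(2, 77 + 4, 768)   # text tokens followed by image-prompt tokens
end_pos = encoder_hidden_states.shape[1] - 4           # 81 - 4 = 77
text_tokens = encoder_hidden_states[:, :end_pos, :]    # what CNAttnProcessor keeps
ip_tokens = encoder_hidden_states[:, end_pos:, :]      # what IPAttnProcessor routes to to_k_ip / to_v_ip
assert text_tokens.shape[1] == 77 and ip_tokens.shape[1] == 4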
foleycrafter/models/adapters/ip_adapter.py ADDED
@@ -0,0 +1,217 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import numpy as np
5
+
6
+ import os
7
+ from typing import List
8
+
9
+ from diffusers import StableDiffusionPipeline
10
+ from diffusers.pipelines.controlnet import MultiControlNetModel
11
+ from PIL import Image
12
+ from safetensors import safe_open
13
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
14
+
15
+ from foleycrafter.models.adapters.resampler import Resampler
16
+ from foleycrafter.models.adapters.utils import is_torch2_available
17
+
18
+ class IPAdapter(torch.nn.Module):
19
+ """IP-Adapter"""
20
+ def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
21
+ super().__init__()
22
+ self.unet = unet
23
+ self.image_proj_model = image_proj_model
24
+ self.adapter_modules = adapter_modules
25
+
26
+ if ckpt_path is not None:
27
+ self.load_from_checkpoint(ckpt_path)
28
+
29
+ def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
30
+ ip_tokens = self.image_proj_model(image_embeds)
31
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
32
+ # Predict the noise residual
33
+ noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
34
+ return noise_pred
35
+
36
+ def load_from_checkpoint(self, ckpt_path: str):
37
+ # Calculate original checksums
38
+ orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
39
+ orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
40
+
41
+ state_dict = torch.load(ckpt_path, map_location="cpu")
42
+
43
+ # Load state dict for image_proj_model and adapter_modules
44
+ self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
45
+ self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
46
+
47
+ # Calculate new checksums
48
+ new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
49
+ new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
50
+
51
+ # Verify if the weights have changed
52
+ assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
53
+ assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
54
+
55
+ print(f"Successfully loaded weights from checkpoint {ckpt_path}")
56
+
57
+ class VideoProjModel(torch.nn.Module):
58
+ """Projection Model"""
59
+
60
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=1, video_frame=50):
61
+ super().__init__()
62
+
63
+ self.cross_attention_dim = cross_attention_dim
64
+ self.clip_extra_context_tokens = clip_extra_context_tokens
65
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
66
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
67
+
68
+ self.video_frame = video_frame
69
+
70
+ def forward(self, image_embeds):
71
+ embeds = image_embeds
72
+ clip_extra_context_tokens = self.proj(embeds)
73
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
74
+ return clip_extra_context_tokens
75
+
76
+ class ImageProjModel(torch.nn.Module):
77
+ """Projection Model"""
78
+
79
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
80
+ super().__init__()
81
+
82
+ self.cross_attention_dim = cross_attention_dim
83
+ self.clip_extra_context_tokens = clip_extra_context_tokens
84
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
85
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
86
+
87
+ def forward(self, image_embeds):
88
+ embeds = image_embeds
89
+ clip_extra_context_tokens = self.proj(embeds).reshape(
90
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
91
+ )
92
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
93
+ return clip_extra_context_tokens
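Shape walkthrough for the projection above, with illustrative sizes: a 1024-dim pooled image embedding becomes `clip_extra_context_tokens` context tokens of width `cross_attention_dim`.

proj = ImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
tokens = proj(torch.randn(2, 1024))   # Linear -> (2, 3072) -> reshape -> (2, 4, 768) -> LayerNorm
assert tokens.shape == (2, 4, 768)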
94
+
95
+
96
+ class MLPProjModel(torch.nn.Module):
97
+ """SD model with image prompt"""
98
+ def zero_initialize(module):
99
+ for param in module.parameters():
100
+ param.data.zero_()
101
+
102
+ def zero_initialize_last_layer(module):
103
+ last_layer = None
104
+ for module_name, layer in module.named_modules():
105
+ if isinstance(layer, torch.nn.Linear):
106
+ last_layer = layer
107
+
108
+ if last_layer is not None:
109
+ last_layer.weight.data.zero_()
110
+ last_layer.bias.data.zero_()
111
+
112
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
113
+
114
+ super().__init__()
115
+
116
+ self.proj = torch.nn.Sequential(
117
+ torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
118
+ torch.nn.GELU(),
119
+ torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
120
+ torch.nn.LayerNorm(cross_attention_dim)
121
+ )
122
+ # zero initialize the last layer
123
+ # self.zero_initialize_last_layer()
124
+
125
+ def forward(self, image_embeds):
126
+ clip_extra_context_tokens = self.proj(image_embeds)
127
+ return clip_extra_context_tokens
128
+
129
+ class V2AMapperMLP(torch.nn.Module):
130
+ def __init__(self, cross_attention_dim=512, clip_embeddings_dim=512, mult=4):
131
+ super().__init__()
132
+ self.proj = torch.nn.Sequential(
133
+ torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim * mult),
134
+ torch.nn.GELU(),
135
+ torch.nn.Linear(clip_embeddings_dim * mult, cross_attention_dim),
136
+ torch.nn.LayerNorm(cross_attention_dim)
137
+ )
138
+
139
+ def forward(self, image_embeds):
140
+ clip_extra_context_tokens = self.proj(image_embeds)
141
+ return clip_extra_context_tokens
142
+
143
+ class TimeProjModel(torch.nn.Module):
144
+ def __init__(self, positive_len, out_dim, feature_type="text-only", frame_nums:int=64):
145
+ super().__init__()
146
+ self.positive_len = positive_len
147
+ self.out_dim = out_dim
148
+
149
+ self.position_dim = frame_nums
150
+
151
+ if isinstance(out_dim, tuple):
152
+ out_dim = out_dim[0]
153
+
154
+ if feature_type == "text-only":
155
+ self.linears = nn.Sequential(
156
+ nn.Linear(self.positive_len + self.position_dim, 512),
157
+ nn.SiLU(),
158
+ nn.Linear(512, 512),
159
+ nn.SiLU(),
160
+ nn.Linear(512, out_dim),
161
+ )
162
+ self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
163
+
164
+ elif feature_type == "text-image":
165
+ self.linears_text = nn.Sequential(
166
+ nn.Linear(self.positive_len + self.position_dim, 512),
167
+ nn.SiLU(),
168
+ nn.Linear(512, 512),
169
+ nn.SiLU(),
170
+ nn.Linear(512, out_dim),
171
+ )
172
+ self.linears_image = nn.Sequential(
173
+ nn.Linear(self.positive_len + self.position_dim, 512),
174
+ nn.SiLU(),
175
+ nn.Linear(512, 512),
176
+ nn.SiLU(),
177
+ nn.Linear(512, out_dim),
178
+ )
179
+ self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
180
+ self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
181
+
182
+ # self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
183
+
184
+ def forward(
185
+ self,
186
+ boxes,
187
+ masks,
188
+ positive_embeddings=None,
189
+ ):
190
+ masks = masks.unsqueeze(-1)
191
+
192
+ # # embedding position (it may include padding as placeholder)
193
+ # xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
194
+
195
+ # # learnable null embedding
196
+ # xyxy_null = self.null_position_feature.view(1, 1, -1)
197
+
198
+ # # replace padding with learnable null embedding
199
+ # xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
200
+
201
+ time_embeds = boxes
202
+
203
+ # positionet with text only information
204
+ if positive_embeddings is not None:
205
+ # learnable null embedding
206
+ positive_null = self.null_positive_feature.view(1, 1, -1)
207
+
208
+ # replace padding with learnable null embedding
209
+ positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
210
+
211
+ objs = self.linears(torch.cat([positive_embeddings, time_embeds], dim=-1))
212
+
213
+ # positionet with text and image information
214
+ else:
215
+ raise NotImplementedError
216
+
217
+ return objs
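Illustrative shapes for TimeProjModel, assuming `feature_type="text-only"`: `boxes` carries a per-token temporal signal of width `frame_nums`, `masks` flags padded tokens (padding falls back to the learned null feature), and `positive_embeddings` is a semantic feature of width `positive_len`. All values below are made up for the example.

model = TimeProjModel(positive_len=768, out_dim=768, frame_nums=64)
boxes = torch.rand(2, 4, 64)                  # (batch, num_tokens, frame_nums)
masks = torch.ones(2, 4)                      # 1 = real token, 0 = padding
positive_embeddings = torch.randn(2, 4, 768)  # (batch, num_tokens, positive_len)
objs = model(boxes, masks, positive_embeddings=positive_embeddings)
assert objs.shape == (2, 4, 768)              # concatenated (768 + 64) features mapped to out_dim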
foleycrafter/models/adapters/resampler.py ADDED
@@ -0,0 +1,158 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from einops.layers.torch import Rearrange
10
+
11
+
12
+ # FFN
13
+ def FeedForward(dim, mult=4):
14
+ inner_dim = int(dim * mult)
15
+ return nn.Sequential(
16
+ nn.LayerNorm(dim),
17
+ nn.Linear(dim, inner_dim, bias=False),
18
+ nn.GELU(),
19
+ nn.Linear(inner_dim, dim, bias=False),
20
+ )
21
+
22
+
23
+ def reshape_tensor(x, heads):
24
+ bs, length, width = x.shape
25
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26
+ x = x.view(bs, length, heads, -1)
27
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28
+ x = x.transpose(1, 2)
29
+ # reshape keeps (bs, n_heads, length, dim_per_head) but makes the tensor contiguous
30
+ x = x.reshape(bs, heads, length, -1)
31
+ return x
32
+
33
+
34
+ class PerceiverAttention(nn.Module):
35
+ def __init__(self, *, dim, dim_head=64, heads=8):
36
+ super().__init__()
37
+ self.scale = dim_head**-0.5
38
+ self.dim_head = dim_head
39
+ self.heads = heads
40
+ inner_dim = dim_head * heads
41
+
42
+ self.norm1 = nn.LayerNorm(dim)
43
+ self.norm2 = nn.LayerNorm(dim)
44
+
45
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
46
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
48
+
49
+ def forward(self, x, latents):
50
+ """
51
+ Args:
52
+ x (torch.Tensor): image features
53
+ shape (b, n1, D)
54
+ latents (torch.Tensor): latent features
55
+ shape (b, n2, D)
56
+ """
57
+ x = self.norm1(x)
58
+ latents = self.norm2(latents)
59
+
60
+ b, l, _ = latents.shape
61
+
62
+ q = self.to_q(latents)
63
+ kv_input = torch.cat((x, latents), dim=-2)
64
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65
+
66
+ q = reshape_tensor(q, self.heads)
67
+ k = reshape_tensor(k, self.heads)
68
+ v = reshape_tensor(v, self.heads)
69
+
70
+ # attention
71
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74
+ out = weight @ v
75
+
76
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77
+
78
+ return self.to_out(out)
79
+
80
+
81
+ class Resampler(nn.Module):
82
+ def __init__(
83
+ self,
84
+ dim=1024,
85
+ depth=8,
86
+ dim_head=64,
87
+ heads=16,
88
+ num_queries=8,
89
+ embedding_dim=768,
90
+ output_dim=1024,
91
+ ff_mult=4,
92
+ max_seq_len: int = 257, # CLIP tokens + CLS token
93
+ apply_pos_emb: bool = False,
94
+ num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
95
+ ):
96
+ super().__init__()
97
+ self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
98
+
99
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
100
+
101
+ self.proj_in = nn.Linear(embedding_dim, dim)
102
+
103
+ self.proj_out = nn.Linear(dim, output_dim)
104
+ self.norm_out = nn.LayerNorm(output_dim)
105
+
106
+ self.to_latents_from_mean_pooled_seq = (
107
+ nn.Sequential(
108
+ nn.LayerNorm(dim),
109
+ nn.Linear(dim, dim * num_latents_mean_pooled),
110
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
111
+ )
112
+ if num_latents_mean_pooled > 0
113
+ else None
114
+ )
115
+
116
+ self.layers = nn.ModuleList([])
117
+ for _ in range(depth):
118
+ self.layers.append(
119
+ nn.ModuleList(
120
+ [
121
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
122
+ FeedForward(dim=dim, mult=ff_mult),
123
+ ]
124
+ )
125
+ )
126
+
127
+ def forward(self, x):
128
+ if self.pos_emb is not None:
129
+ n, device = x.shape[1], x.device
130
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
131
+ x = x + pos_emb
132
+
133
+ latents = self.latents.repeat(x.size(0), 1, 1)
134
+
135
+ x = self.proj_in(x)
136
+
137
+ if self.to_latents_from_mean_pooled_seq:
138
+ meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
139
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
140
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
141
+
142
+ for attn, ff in self.layers:
143
+ latents = attn(x, latents) + latents
144
+ latents = ff(latents) + latents
145
+
146
+ latents = self.proj_out(latents)
147
+ return self.norm_out(latents)
148
+
149
+
150
+ def masked_mean(t, *, dim, mask=None):
151
+ if mask is None:
152
+ return t.mean(dim=dim)
153
+
154
+ denom = mask.sum(dim=dim, keepdim=True)
155
+ mask = rearrange(mask, "b n -> b n 1")
156
+ masked_t = t.masked_fill(~mask, 0.0)
157
+
158
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
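Usage sketch for the Resampler: a fixed set of learned query latents cross-attends to a longer sequence of image features and is projected to `output_dim`. The 257x1280 input below is only an example (e.g. patch tokens from a large CLIP vision tower).

import torch

resampler = Resampler(dim=1024, depth=8, dim_head=64, heads=16,
                      num_queries=8, embedding_dim=1280, output_dim=1024)
image_feats = torch.randn(2, 257, 1280)   # (batch, sequence, embedding_dim)
latents = resampler(image_feats)
assert latents.shape == (2, 8, 1024)      # (batch, num_queries, output_dim)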
foleycrafter/models/adapters/transformer.py ADDED
@@ -0,0 +1,327 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.utils.checkpoint
4
+
5
+ from typing import Any, Optional, Tuple, Union
6
+
7
+ class Attention(nn.Module):
8
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
9
+
10
+ def __init__(self, hidden_size, num_attention_heads, attention_head_dim, attention_dropout=0.0):
11
+ super().__init__()
12
+ self.embed_dim = hidden_size
13
+ self.num_heads = num_attention_heads
14
+ self.head_dim = attention_head_dim
15
+
16
+ self.scale = self.head_dim**-0.5
17
+ self.dropout = attention_dropout
18
+
19
+ self.inner_dim = self.head_dim * self.num_heads
20
+
21
+ self.k_proj = nn.Linear(self.embed_dim, self.inner_dim)
22
+ self.v_proj = nn.Linear(self.embed_dim, self.inner_dim)
23
+ self.q_proj = nn.Linear(self.embed_dim, self.inner_dim)
24
+ self.out_proj = nn.Linear(self.inner_dim, self.embed_dim)
25
+
26
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
27
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
28
+
29
+ def forward(
30
+ self,
31
+ hidden_states: torch.Tensor,
32
+ attention_mask: Optional[torch.Tensor] = None,
33
+ causal_attention_mask: Optional[torch.Tensor] = None,
34
+ output_attentions: Optional[bool] = False,
35
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
36
+ """Input shape: Batch x Time x Channel"""
37
+
38
+ bsz, tgt_len, embed_dim = hidden_states.size()
39
+
40
+ # get query proj
41
+ query_states = self.q_proj(hidden_states) * self.scale
42
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
43
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
44
+
45
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
46
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
47
+ key_states = key_states.view(*proj_shape)
48
+ value_states = value_states.view(*proj_shape)
49
+
50
+ src_len = key_states.size(1)
51
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
52
+
53
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
54
+ raise ValueError(
55
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
56
+ f" {attn_weights.size()}"
57
+ )
58
+
59
+ # apply the causal_attention_mask first
60
+ if causal_attention_mask is not None:
61
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
62
+ raise ValueError(
63
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
64
+ f" {causal_attention_mask.size()}"
65
+ )
66
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
67
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
68
+
69
+ if attention_mask is not None:
70
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
71
+ raise ValueError(
72
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
73
+ )
74
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
75
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
76
+
77
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
78
+
79
+ if output_attentions:
80
+ # this operation is a bit awkward, but it's required to
81
+ # make sure that attn_weights keeps its gradient.
82
+ # In order to do so, attn_weights has to be reshaped
83
+ # twice and then reused in the following computation
84
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
85
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
86
+ else:
87
+ attn_weights_reshaped = None
88
+
89
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
90
+
91
+ attn_output = torch.bmm(attn_probs, value_states)
92
+
93
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
94
+ raise ValueError(
95
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
96
+ f" {attn_output.size()}"
97
+ )
98
+
99
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
100
+ attn_output = attn_output.transpose(1, 2)
101
+ attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)
102
+
103
+ attn_output = self.out_proj(attn_output)
104
+
105
+ return attn_output, attn_weights_reshaped
106
+
107
+
108
+ class MLP(nn.Module):
109
+ def __init__(self, hidden_size, intermediate_size, mult=4):
110
+ super().__init__()
111
+ self.activation_fn = nn.SiLU()
112
+ self.fc1 = nn.Linear(hidden_size, intermediate_size * mult)
113
+ self.fc2 = nn.Linear(intermediate_size * mult, hidden_size)
114
+
115
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
116
+ hidden_states = self.fc1(hidden_states)
117
+ hidden_states = self.activation_fn(hidden_states)
118
+ hidden_states = self.fc2(hidden_states)
119
+ return hidden_states
120
+
121
+ class Transformer(nn.Module):
122
+ def __init__(self, depth=12):
123
+ super().__init__()
124
+ self.layers = nn.ModuleList([TransformerBlock() for _ in range(depth)])
125
+ def forward(
126
+ self,
127
+ hidden_states: torch.Tensor,
128
+ attention_mask: torch.Tensor=None,
129
+ causal_attention_mask: torch.Tensor=None,
130
+ output_attentions: Optional[bool] = False,
131
+ ) -> Tuple[torch.FloatTensor]:
132
+ """
133
+ Args:
134
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
135
+ attention_mask (`torch.FloatTensor`): attention mask of size
136
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
137
+ `(config.encoder_attention_heads,)`.
138
+ output_attentions (`bool`, *optional*):
139
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
140
+ returned tensors for more detail.
141
+ """
142
+ for layer in self.layers:
143
+ hidden_states = layer(
144
+ hidden_states=hidden_states,
145
+ attention_mask=attention_mask,
146
+ causal_attention_mask=causal_attention_mask,
147
+ output_attentions=output_attentions,
148
+ )
149
+
150
+ return hidden_states
151
+
152
+ class TransformerBlock(nn.Module):
153
+ def __init__(self, hidden_size=512, num_attention_heads=12, attention_head_dim=64, attention_dropout=0.0, dropout=0.0, eps=1e-5):
154
+ super().__init__()
155
+ self.embed_dim = hidden_size
156
+ self.self_attn = Attention(hidden_size=hidden_size, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim)
157
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
158
+ self.mlp = MLP(hidden_size=hidden_size, intermediate_size=hidden_size)
159
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)
160
+
161
+ def forward(
162
+ self,
163
+ hidden_states: torch.Tensor,
164
+ attention_mask: torch.Tensor=None,
165
+ causal_attention_mask: torch.Tensor=None,
166
+ output_attentions: Optional[bool] = False,
167
+ ) -> Tuple[torch.FloatTensor]:
168
+ """
169
+ Args:
170
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
171
+ attention_mask (`torch.FloatTensor`): attention mask of size
172
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
173
+ `(config.encoder_attention_heads,)`.
174
+ output_attentions (`bool`, *optional*):
175
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
176
+ returned tensors for more detail.
177
+ """
178
+ residual = hidden_states
179
+
180
+ hidden_states = self.layer_norm1(hidden_states)
181
+ hidden_states, attn_weights = self.self_attn(
182
+ hidden_states=hidden_states,
183
+ attention_mask=attention_mask,
184
+ causal_attention_mask=causal_attention_mask,
185
+ output_attentions=output_attentions,
186
+ )
187
+ hidden_states = residual + hidden_states
188
+
189
+ residual = hidden_states
190
+ hidden_states = self.layer_norm2(hidden_states)
191
+ hidden_states = self.mlp(hidden_states)
192
+ hidden_states = residual + hidden_states
193
+
194
+ outputs = (hidden_states,)
195
+
196
+ if output_attentions:
197
+ outputs += (attn_weights,)
198
+
199
+ return outputs[0]
200
+
201
+ class DiffusionTransformerBlock(nn.Module):
202
+ def __init__(self, hidden_size=512, num_attention_heads=12, attention_head_dim=64, attention_dropout=0.0, dropout=0.0, eps=1e-5):
203
+ super().__init__()
204
+ self.embed_dim = hidden_size
205
+ self.self_attn = Attention(hidden_size=hidden_size, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim)
206
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
207
+ self.mlp = MLP(hidden_size=hidden_size, intermediate_size=hidden_size)
208
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)
209
+ self.output_token = nn.Parameter(torch.randn(1, hidden_size))
210
+
211
+ def forward(
212
+ self,
213
+ hidden_states: torch.Tensor,
214
+ attention_mask: torch.Tensor=None,
215
+ causal_attention_mask: torch.Tensor=None,
216
+ output_attentions: Optional[bool] = False,
217
+ ) -> Tuple[torch.FloatTensor]:
218
+ """
219
+ Args:
220
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
221
+ attention_mask (`torch.FloatTensor`): attention mask of size
222
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
223
+ `(config.encoder_attention_heads,)`.
224
+ output_attentions (`bool`, *optional*):
225
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
226
+ returned tensors for more detail.
227
+ """
228
+ output_token = self.output_token.unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)
229
+ hidden_states = torch.cat([output_token, hidden_states], dim=1)
230
+ residual = hidden_states
231
+
232
+ hidden_states = self.layer_norm1(hidden_states)
233
+ hidden_states, attn_weights = self.self_attn(
234
+ hidden_states=hidden_states,
235
+ attention_mask=attention_mask,
236
+ causal_attention_mask=causal_attention_mask,
237
+ output_attentions=output_attentions,
238
+ )
239
+ hidden_states = residual + hidden_states
240
+
241
+ residual = hidden_states
242
+ hidden_states = self.layer_norm2(hidden_states)
243
+ hidden_states = self.mlp(hidden_states)
244
+ hidden_states = residual + hidden_states
245
+
246
+ outputs = (hidden_states,)
247
+
248
+ if output_attentions:
249
+ outputs += (attn_weights,)
250
+
251
+ return outputs[0][:,0:1,...]
252
+
253
+ class V2AMapperMLP(nn.Module):
254
+ def __init__(self, input_dim=512, output_dim=512, expansion_rate=4):
255
+ super().__init__()
256
+ self.linear = nn.Linear(input_dim, input_dim * expansion_rate)
257
+ self.silu = nn.SiLU()
258
+ self.layer_norm = nn.LayerNorm(input_dim * expansion_rate)
259
+ self.linear2 = nn.Linear(input_dim * expansion_rate, output_dim)
260
+
261
+ def forward(self, x):
262
+
263
+ x = self.linear(x)
264
+ x = self.silu(x)
265
+ x = self.layer_norm(x)
266
+ x = self.linear2(x)
267
+
268
+ return x
269
+
270
+ class ImageProjModel(torch.nn.Module):
271
+ """Projection Model"""
272
+
273
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
274
+ super().__init__()
275
+
276
+ self.cross_attention_dim = cross_attention_dim
277
+ self.clip_extra_context_tokens = clip_extra_context_tokens
278
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
279
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
280
+
281
+ self.zero_initialize_last_layer()
282
+
283
+ def zero_initialize_last_layer(module):
284
+ last_layer = None
285
+ for module_name, layer in module.named_modules():
286
+ if isinstance(layer, torch.nn.Linear):
287
+ last_layer = layer
288
+
289
+ if last_layer is not None:
290
+ last_layer.weight.data.zero_()
291
+ last_layer.bias.data.zero_()
292
+
293
+ def forward(self, image_embeds):
294
+ embeds = image_embeds
295
+ clip_extra_context_tokens = self.proj(embeds).reshape(
296
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
297
+ )
298
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
299
+ return clip_extra_context_tokens
300
+
301
+ class VisionAudioAdapter(torch.nn.Module):
302
+ def __init__(
303
+ self,
304
+ embedding_size=768,
305
+ expand_dim=4,
306
+ token_num=4,
307
+ ):
308
+ super().__init__()
309
+
310
+ self.mapper = V2AMapperMLP(
311
+ embedding_size,
312
+ embedding_size,
313
+ expansion_rate=expand_dim,
314
+ )
315
+
316
+ self.proj = ImageProjModel(
317
+ cross_attention_dim=embedding_size,
318
+ clip_embeddings_dim=embedding_size,
319
+ clip_extra_context_tokens=token_num,
320
+ )
321
+
322
+ def forward(self, image_embeds):
323
+ image_embeds = self.mapper(image_embeds)
324
+ image_embeds = self.proj(image_embeds)
325
+ return image_embeds
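Usage sketch for VisionAudioAdapter: a pooled visual embedding is refined by the V2A mapper MLP and then expanded into `token_num` cross-attention tokens. Because this file's ImageProjModel zero-initializes its last linear layer, the tokens start at zero and grow during training; sizes below are the defaults.

import torch

adapter = VisionAudioAdapter(embedding_size=768, expand_dim=4, token_num=4)
visual_embeds = torch.randn(2, 768)
tokens = adapter(visual_embeds)
assert tokens.shape == (2, 4, 768)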
326
+
327
+
foleycrafter/models/adapters/utils.py ADDED
@@ -0,0 +1,81 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ attn_maps = {}
7
+ def hook_fn(name):
8
+ def forward_hook(module, input, output):
9
+ if hasattr(module.processor, "attn_map"):
10
+ attn_maps[name] = module.processor.attn_map
11
+ del module.processor.attn_map
12
+
13
+ return forward_hook
14
+
15
+ def register_cross_attention_hook(unet):
16
+ for name, module in unet.named_modules():
17
+ if name.split('.')[-1].startswith('attn2'):
18
+ module.register_forward_hook(hook_fn(name))
19
+
20
+ return unet
21
+
22
+ def upscale(attn_map, target_size):
23
+ attn_map = torch.mean(attn_map, dim=0)
24
+ attn_map = attn_map.permute(1,0)
25
+ temp_size = None
26
+
27
+ for i in range(0,5):
28
+ scale = 2 ** i
29
+ if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64:
30
+ temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))
31
+ break
32
+
33
+ assert temp_size is not None, "temp_size cannot be None"
34
+
35
+ attn_map = attn_map.view(attn_map.shape[0], *temp_size)
36
+
37
+ attn_map = F.interpolate(
38
+ attn_map.unsqueeze(0).to(dtype=torch.float32),
39
+ size=target_size,
40
+ mode='bilinear',
41
+ align_corners=False
42
+ )[0]
43
+
44
+ attn_map = torch.softmax(attn_map, dim=0)
45
+ return attn_map
46
+ def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
47
+
48
+ idx = 0 if instance_or_negative else 1
49
+ net_attn_maps = []
50
+
51
+ for name, attn_map in attn_maps.items():
52
+ attn_map = attn_map.cpu() if detach else attn_map
53
+ attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
54
+ attn_map = upscale(attn_map, image_size)
55
+ net_attn_maps.append(attn_map)
56
+
57
+ net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
58
+
59
+ return net_attn_maps
60
+
61
+ def attnmaps2images(net_attn_maps):
62
+
63
+ #total_attn_scores = 0
64
+ images = []
65
+
66
+ for attn_map in net_attn_maps:
67
+ attn_map = attn_map.cpu().numpy()
68
+ #total_attn_scores += attn_map.mean().item()
69
+
70
+ normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
71
+ normalized_attn_map = normalized_attn_map.astype(np.uint8)
72
+ #print("norm: ", normalized_attn_map.shape)
73
+ image = Image.fromarray(normalized_attn_map)
74
+
75
+ #image = fix_save_attn_map(attn_map)
76
+ images.append(image)
77
+
78
+ #print(total_attn_scores)
79
+ return images
80
+ def is_torch2_available():
81
+ return hasattr(F, "scaled_dot_product_attention")
foleycrafter/models/auffusion/attention.py ADDED
@@ -0,0 +1,669 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils import USE_PEFT_BACKEND
21
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
22
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
23
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
24
+ from diffusers.models.lora import LoRACompatibleLinear
25
+ from diffusers.models.normalization import\
26
+ AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
27
+
28
+ from foleycrafter.models.auffusion.attention_processor import Attention
29
+
30
+ def _chunked_feed_forward(
31
+ ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
32
+ ):
33
+ # "feed_forward_chunk_size" can be used to save memory
34
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
35
+ raise ValueError(
36
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
37
+ )
38
+
39
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
40
+ if lora_scale is None:
41
+ ff_output = torch.cat(
42
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
43
+ dim=chunk_dim,
44
+ )
45
+ else:
46
+ # TODO(Patrick): LoRA scale can be removed once PEFT refactor is complete
47
+ ff_output = torch.cat(
48
+ [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
49
+ dim=chunk_dim,
50
+ )
51
+
52
+ return ff_output
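Chunking only changes peak memory, not the result, since the feed-forward acts independently on each position; a quick equivalence check with an arbitrary point-wise module (values illustrative):

import torch
from torch import nn

ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
hidden_states = torch.randn(2, 8, 64)
full = ff(hidden_states)
chunked = _chunked_feed_forward(ff, hidden_states, chunk_dim=1, chunk_size=4)
assert torch.allclose(full, chunked, atol=1e-6)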
53
+
54
+
55
+ @maybe_allow_in_graph
56
+ class GatedSelfAttentionDense(nn.Module):
57
+ r"""
58
+ A gated self-attention dense layer that combines visual features and object features.
59
+
60
+ Parameters:
61
+ query_dim (`int`): The number of channels in the query.
62
+ context_dim (`int`): The number of channels in the context.
63
+ n_heads (`int`): The number of heads to use for attention.
64
+ d_head (`int`): The number of channels in each head.
65
+ """
66
+
67
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
68
+ super().__init__()
69
+
70
+ # we need a linear projection since we need cat visual feature and obj feature
71
+ self.linear = nn.Linear(context_dim, query_dim)
72
+
73
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
74
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
75
+
76
+ self.norm1 = nn.LayerNorm(query_dim)
77
+ self.norm2 = nn.LayerNorm(query_dim)
78
+
79
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
80
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
81
+
82
+ self.enabled = True
83
+
84
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
85
+ if not self.enabled:
86
+ return x
87
+
88
+ n_visual = x.shape[1]
89
+ objs = self.linear(objs)
90
+
91
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
92
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
93
+
94
+ return x
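Because `alpha_attn` and `alpha_dense` start at zero, `tanh(0) = 0` and the gated block is an exact identity at initialization; the injected object features only take effect as the gates are learned. Sizes below are illustrative.

import torch

gsa = GatedSelfAttentionDense(query_dim=320, context_dim=768, n_heads=8, d_head=40)
x = torch.randn(2, 64, 320)     # visual tokens
objs = torch.randn(2, 4, 768)   # conditioning tokens (e.g. GLIGEN-style objects)
out = gsa(x, objs)
assert out.shape == x.shape and torch.allclose(out, x)   # identity at init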
95
+
96
+
97
+ @maybe_allow_in_graph
98
+ class BasicTransformerBlock(nn.Module):
99
+ r"""
100
+ A basic Transformer block.
101
+
102
+ Parameters:
103
+ dim (`int`): The number of channels in the input and output.
104
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
105
+ attention_head_dim (`int`): The number of channels in each head.
106
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
107
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
108
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
109
+ num_embeds_ada_norm (:
110
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
111
+ attention_bias (:
112
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
113
+ only_cross_attention (`bool`, *optional*):
114
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
115
+ double_self_attention (`bool`, *optional*):
116
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
117
+ upcast_attention (`bool`, *optional*):
118
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
119
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
120
+ Whether to use learnable elementwise affine parameters for normalization.
121
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
122
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
123
+ final_dropout (`bool` *optional*, defaults to False):
124
+ Whether to apply a final dropout after the last feed-forward layer.
125
+ attention_type (`str`, *optional*, defaults to `"default"`):
126
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
127
+ positional_embeddings (`str`, *optional*, defaults to `None`):
128
+ The type of positional embeddings to apply to.
129
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
130
+ The maximum number of positional embeddings to apply.
131
+ """
132
+
133
+ def __init__(
134
+ self,
135
+ dim: int,
136
+ num_attention_heads: int,
137
+ attention_head_dim: int,
138
+ dropout=0.0,
139
+ cross_attention_dim: Optional[int] = None,
140
+ activation_fn: str = "geglu",
141
+ num_embeds_ada_norm: Optional[int] = None,
142
+ attention_bias: bool = False,
143
+ only_cross_attention: bool = False,
144
+ double_self_attention: bool = False,
145
+ upcast_attention: bool = False,
146
+ norm_elementwise_affine: bool = True,
147
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
148
+ norm_eps: float = 1e-5,
149
+ final_dropout: bool = False,
150
+ attention_type: str = "default",
151
+ positional_embeddings: Optional[str] = None,
152
+ num_positional_embeddings: Optional[int] = None,
153
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
154
+ ada_norm_bias: Optional[int] = None,
155
+ ff_inner_dim: Optional[int] = None,
156
+ ff_bias: bool = True,
157
+ attention_out_bias: bool = True,
158
+ ):
159
+ super().__init__()
160
+ self.only_cross_attention = only_cross_attention
161
+
162
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
163
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
164
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
165
+ self.use_layer_norm = norm_type == "layer_norm"
166
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
167
+
168
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
169
+ raise ValueError(
170
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
171
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
172
+ )
173
+
174
+ if positional_embeddings and (num_positional_embeddings is None):
175
+ raise ValueError(
176
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
177
+ )
178
+
179
+ if positional_embeddings == "sinusoidal":
180
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
181
+ else:
182
+ self.pos_embed = None
183
+
184
+ # Define 3 blocks. Each block has its own normalization layer.
185
+ # 1. Self-Attn
186
+ if self.use_ada_layer_norm:
187
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
188
+ elif self.use_ada_layer_norm_zero:
189
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
190
+ elif self.use_ada_layer_norm_continuous:
191
+ self.norm1 = AdaLayerNormContinuous(
192
+ dim,
193
+ ada_norm_continous_conditioning_embedding_dim,
194
+ norm_elementwise_affine,
195
+ norm_eps,
196
+ ada_norm_bias,
197
+ "rms_norm",
198
+ )
199
+ else:
200
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
201
+
202
+ self.attn1 = Attention(
203
+ query_dim=dim,
204
+ heads=num_attention_heads,
205
+ dim_head=attention_head_dim,
206
+ dropout=dropout,
207
+ bias=attention_bias,
208
+ cross_attention_dim=cross_attention_dim if (only_cross_attention and not double_self_attention) else None,
209
+ upcast_attention=upcast_attention,
210
+ out_bias=attention_out_bias,
211
+ )
212
+
213
+ # 2. Cross-Attn
214
+ if cross_attention_dim is not None or double_self_attention:
215
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
216
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
217
+ # the second cross attention block.
218
+ if self.use_ada_layer_norm:
219
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
220
+ elif self.use_ada_layer_norm_continuous:
221
+ self.norm2 = AdaLayerNormContinuous(
222
+ dim,
223
+ ada_norm_continous_conditioning_embedding_dim,
224
+ norm_elementwise_affine,
225
+ norm_eps,
226
+ ada_norm_bias,
227
+ "rms_norm",
228
+ )
229
+ else:
230
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
231
+
232
+ self.attn2 = Attention(
233
+ query_dim=dim,
234
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
235
+ heads=num_attention_heads,
236
+ dim_head=attention_head_dim,
237
+ dropout=dropout,
238
+ bias=attention_bias,
239
+ upcast_attention=upcast_attention,
240
+ out_bias=attention_out_bias,
241
+ ) # is self-attn if encoder_hidden_states is none
242
+ else:
243
+ self.norm2 = None
244
+ self.attn2 = None
245
+
246
+ # 3. Feed-forward
247
+ if self.use_ada_layer_norm_continuous:
248
+ self.norm3 = AdaLayerNormContinuous(
249
+ dim,
250
+ ada_norm_continous_conditioning_embedding_dim,
251
+ norm_elementwise_affine,
252
+ norm_eps,
253
+ ada_norm_bias,
254
+ "layer_norm",
255
+ )
256
+ elif not self.use_ada_layer_norm_single:
257
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
258
+
259
+ self.ff = FeedForward(
260
+ dim,
261
+ dropout=dropout,
262
+ activation_fn=activation_fn,
263
+ final_dropout=final_dropout,
264
+ inner_dim=ff_inner_dim,
265
+ bias=ff_bias,
266
+ )
267
+
268
+ # 4. Fuser
269
+ if attention_type == "gated" or attention_type == "gated-text-image":
270
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
271
+
272
+ # 5. Scale-shift for PixArt-Alpha.
273
+ if self.use_ada_layer_norm_single:
274
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
275
+
276
+ # let chunk size default to None
277
+ self._chunk_size = None
278
+ self._chunk_dim = 0
279
+
280
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
281
+ # Sets chunk feed-forward
282
+ self._chunk_size = chunk_size
283
+ self._chunk_dim = dim
284
+
285
+ def forward(
286
+ self,
287
+ hidden_states: torch.FloatTensor,
288
+ attention_mask: Optional[torch.FloatTensor] = None,
289
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
290
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
291
+ timestep: Optional[torch.LongTensor] = None,
292
+ cross_attention_kwargs: Dict[str, Any] = None,
293
+ class_labels: Optional[torch.LongTensor] = None,
294
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
295
+ ) -> torch.FloatTensor:
296
+ # Notice that normalization is always applied before the real computation in the following blocks.
297
+ # 0. Self-Attention
298
+ batch_size = hidden_states.shape[0]
299
+
300
+ if self.use_ada_layer_norm:
301
+ norm_hidden_states = self.norm1(hidden_states, timestep)
302
+ elif self.use_ada_layer_norm_zero:
303
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
304
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
305
+ )
306
+ elif self.use_layer_norm:
307
+ norm_hidden_states = self.norm1(hidden_states)
308
+ elif self.use_ada_layer_norm_continuous:
309
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
310
+ elif self.use_ada_layer_norm_single:
311
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
312
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
313
+ ).chunk(6, dim=1)
314
+ norm_hidden_states = self.norm1(hidden_states)
315
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
316
+ norm_hidden_states = norm_hidden_states.squeeze(1)
317
+ else:
318
+ raise ValueError("Incorrect norm used")
319
+
320
+ if self.pos_embed is not None:
321
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
322
+
323
+ # 1. Retrieve lora scale.
324
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
325
+
326
+ # 2. Prepare GLIGEN inputs
327
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
328
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
329
+
330
+ attn_output = self.attn1(
331
+ norm_hidden_states,
332
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
333
+ attention_mask=attention_mask,
334
+ **cross_attention_kwargs,
335
+ )
336
+ if self.use_ada_layer_norm_zero:
337
+ attn_output = gate_msa.unsqueeze(1) * attn_output
338
+ elif self.use_ada_layer_norm_single:
339
+ attn_output = gate_msa * attn_output
340
+
341
+ hidden_states = attn_output + hidden_states
342
+ if hidden_states.ndim == 4:
343
+ hidden_states = hidden_states.squeeze(1)
344
+
345
+ # 2.5 GLIGEN Control
346
+ if gligen_kwargs is not None:
347
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
348
+
349
+ # 3. Cross-Attention
350
+ if self.attn2 is not None:
351
+ if self.use_ada_layer_norm:
352
+ norm_hidden_states = self.norm2(hidden_states, timestep)
353
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
354
+ norm_hidden_states = self.norm2(hidden_states)
355
+ elif self.use_ada_layer_norm_single:
356
+ # For PixArt norm2 isn't applied here:
357
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
358
+ norm_hidden_states = hidden_states
359
+ elif self.use_ada_layer_norm_continuous:
360
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
361
+ else:
362
+ raise ValueError("Incorrect norm")
363
+
364
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
365
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
366
+
367
+ attn_output = self.attn2(
368
+ norm_hidden_states,
369
+ encoder_hidden_states=encoder_hidden_states,
370
+ attention_mask=encoder_attention_mask,
371
+ **cross_attention_kwargs,
372
+ )
373
+ hidden_states = attn_output + hidden_states
374
+
375
+ # 4. Feed-forward
376
+ if self.use_ada_layer_norm_continuous:
377
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
378
+ elif not self.use_ada_layer_norm_single:
379
+ norm_hidden_states = self.norm3(hidden_states)
380
+
381
+ if self.use_ada_layer_norm_zero:
382
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
383
+
384
+ if self.use_ada_layer_norm_single:
385
+ norm_hidden_states = self.norm2(hidden_states)
386
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
387
+
388
+ if self._chunk_size is not None:
389
+ # "feed_forward_chunk_size" can be used to save memory
390
+ ff_output = _chunked_feed_forward(
391
+ self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
392
+ )
393
+ else:
394
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
395
+
396
+ if self.use_ada_layer_norm_zero:
397
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
398
+ elif self.use_ada_layer_norm_single:
399
+ ff_output = gate_mlp * ff_output
400
+
401
+ hidden_states = ff_output + hidden_states
402
+ if hidden_states.ndim == 4:
403
+ hidden_states = hidden_states.squeeze(1)
404
+
405
+ return hidden_states
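Usage sketch for the block above with feed-forward chunking enabled; `chunk_size` must divide the sequence length (see `_chunked_feed_forward`). Dimensions are illustrative.

import torch

block = BasicTransformerBlock(dim=320, num_attention_heads=8, attention_head_dim=40,
                              cross_attention_dim=768)
block.set_chunk_feed_forward(chunk_size=256, dim=1)
hidden_states = torch.randn(2, 1024, 320)          # (batch, tokens, dim)
encoder_hidden_states = torch.randn(2, 77, 768)    # text/conditioning tokens
out = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
assert out.shape == hidden_states.shape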
406
+
407
+
408
+ @maybe_allow_in_graph
409
+ class TemporalBasicTransformerBlock(nn.Module):
410
+ r"""
411
+ A basic Transformer block for video like data.
412
+
413
+ Parameters:
414
+ dim (`int`): The number of channels in the input and output.
415
+ time_mix_inner_dim (`int`): The number of channels for temporal attention.
416
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
417
+ attention_head_dim (`int`): The number of channels in each head.
418
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
419
+ """
420
+
421
+ def __init__(
422
+ self,
423
+ dim: int,
424
+ time_mix_inner_dim: int,
425
+ num_attention_heads: int,
426
+ attention_head_dim: int,
427
+ cross_attention_dim: Optional[int] = None,
428
+ ):
429
+ super().__init__()
430
+ self.is_res = dim == time_mix_inner_dim
431
+
432
+ self.norm_in = nn.LayerNorm(dim)
433
+
434
+ # Define 3 blocks. Each block has its own normalization layer.
435
+ # 1. Self-Attn
436
+ self.norm_in = nn.LayerNorm(dim)
437
+ self.ff_in = FeedForward(
438
+ dim,
439
+ dim_out=time_mix_inner_dim,
440
+ activation_fn="geglu",
441
+ )
442
+
443
+ self.norm1 = nn.LayerNorm(time_mix_inner_dim)
444
+ self.attn1 = Attention(
445
+ query_dim=time_mix_inner_dim,
446
+ heads=num_attention_heads,
447
+ dim_head=attention_head_dim,
448
+ cross_attention_dim=None,
449
+ )
450
+
451
+ # 2. Cross-Attn
452
+ if cross_attention_dim is not None:
453
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
454
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
455
+ # the second cross attention block.
456
+ self.norm2 = nn.LayerNorm(time_mix_inner_dim)
457
+ self.attn2 = Attention(
458
+ query_dim=time_mix_inner_dim,
459
+ cross_attention_dim=cross_attention_dim,
460
+ heads=num_attention_heads,
461
+ dim_head=attention_head_dim,
462
+ ) # is self-attn if encoder_hidden_states is none
463
+ else:
464
+ self.norm2 = None
465
+ self.attn2 = None
466
+
467
+ # 3. Feed-forward
468
+ self.norm3 = nn.LayerNorm(time_mix_inner_dim)
469
+ self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
470
+
471
+ # let chunk size default to None
472
+ self._chunk_size = None
473
+ self._chunk_dim = None
474
+
475
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
476
+ # Sets chunk feed-forward
477
+ self._chunk_size = chunk_size
478
+ # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
479
+ self._chunk_dim = 1
480
+
481
+ def forward(
482
+ self,
483
+ hidden_states: torch.FloatTensor,
484
+ num_frames: int,
485
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
486
+ ) -> torch.FloatTensor:
487
+ # Notice that normalization is always applied before the real computation in the following blocks.
488
+ # 0. Self-Attention
489
+ batch_size = hidden_states.shape[0]
490
+
491
+ batch_frames, seq_length, channels = hidden_states.shape
492
+ batch_size = batch_frames // num_frames
493
+
494
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
495
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
496
+ hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
497
+
498
+ residual = hidden_states
499
+ hidden_states = self.norm_in(hidden_states)
500
+
501
+ if self._chunk_size is not None:
502
+ hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
503
+ else:
504
+ hidden_states = self.ff_in(hidden_states)
505
+
506
+ if self.is_res:
507
+ hidden_states = hidden_states + residual
508
+
509
+ norm_hidden_states = self.norm1(hidden_states)
510
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
511
+ hidden_states = attn_output + hidden_states
512
+
513
+ # 3. Cross-Attention
514
+ if self.attn2 is not None:
515
+ norm_hidden_states = self.norm2(hidden_states)
516
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
517
+ hidden_states = attn_output + hidden_states
518
+
519
+ # 4. Feed-forward
520
+ norm_hidden_states = self.norm3(hidden_states)
521
+
522
+ if self._chunk_size is not None:
523
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
524
+ else:
525
+ ff_output = self.ff(norm_hidden_states)
526
+
527
+ if self.is_res:
528
+ hidden_states = ff_output + hidden_states
529
+ else:
530
+ hidden_states = ff_output
531
+
532
+ hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
533
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
534
+ hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
535
+
536
+ return hidden_states
537
+
538
+
539
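The reshapes at the start and end of `TemporalBasicTransformerBlock.forward` above move the frame axis into the attention (token) position, so self-attention mixes frames at each spatial location. A stand-alone sketch of that round trip, with illustrative sizes:

```python
# Frame/space reshape round trip used for temporal attention; sizes are assumptions.
import torch

num_frames, batch_size, seq_len, channels = 8, 2, 16, 32
x = torch.randn(batch_size * num_frames, seq_len, channels)

# (B*F, S, C) -> (B, F, S, C) -> (B, S, F, C) -> (B*S, F, C)
y = x.reshape(batch_size, num_frames, seq_len, channels)
y = y.permute(0, 2, 1, 3).reshape(batch_size * seq_len, num_frames, channels)

# ... temporal self-attention would run over dim=1 (the frame axis) here ...

# invert: (B*S, F, C) -> (B, S, F, C) -> (B, F, S, C) -> (B*F, S, C)
z = y.reshape(batch_size, seq_len, num_frames, channels)
z = z.permute(0, 2, 1, 3).reshape(batch_size * num_frames, seq_len, channels)

assert torch.equal(x, z)  # the round trip is lossless
```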
+ class SkipFFTransformerBlock(nn.Module):
540
+ def __init__(
541
+ self,
542
+ dim: int,
543
+ num_attention_heads: int,
544
+ attention_head_dim: int,
545
+ kv_input_dim: int,
546
+ kv_input_dim_proj_use_bias: bool,
547
+ dropout=0.0,
548
+ cross_attention_dim: Optional[int] = None,
549
+ attention_bias: bool = False,
550
+ attention_out_bias: bool = True,
551
+ ):
552
+ super().__init__()
553
+ if kv_input_dim != dim:
554
+ self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
555
+ else:
556
+ self.kv_mapper = None
557
+
558
+ self.norm1 = RMSNorm(dim, 1e-06)
559
+
560
+ self.attn1 = Attention(
561
+ query_dim=dim,
562
+ heads=num_attention_heads,
563
+ dim_head=attention_head_dim,
564
+ dropout=dropout,
565
+ bias=attention_bias,
566
+ cross_attention_dim=cross_attention_dim,
567
+ out_bias=attention_out_bias,
568
+ )
569
+
570
+ self.norm2 = RMSNorm(dim, 1e-06)
571
+
572
+ self.attn2 = Attention(
573
+ query_dim=dim,
574
+ cross_attention_dim=cross_attention_dim,
575
+ heads=num_attention_heads,
576
+ dim_head=attention_head_dim,
577
+ dropout=dropout,
578
+ bias=attention_bias,
579
+ out_bias=attention_out_bias,
580
+ )
581
+
582
+ def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
583
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
584
+
585
+ if self.kv_mapper is not None:
586
+ encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
587
+
588
+ norm_hidden_states = self.norm1(hidden_states)
589
+
590
+ attn_output = self.attn1(
591
+ norm_hidden_states,
592
+ encoder_hidden_states=encoder_hidden_states,
593
+ **cross_attention_kwargs,
594
+ )
595
+
596
+ hidden_states = attn_output + hidden_states
597
+
598
+ norm_hidden_states = self.norm2(hidden_states)
599
+
600
+ attn_output = self.attn2(
601
+ norm_hidden_states,
602
+ encoder_hidden_states=encoder_hidden_states,
603
+ **cross_attention_kwargs,
604
+ )
605
+
606
+ hidden_states = attn_output + hidden_states
607
+
608
+ return hidden_states
609
+
610
+
611
+ class FeedForward(nn.Module):
612
+ r"""
613
+ A feed-forward layer.
614
+
615
+ Parameters:
616
+ dim (`int`): The number of channels in the input.
617
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
618
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
619
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
620
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
621
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
622
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
623
+ """
624
+
625
+ def __init__(
626
+ self,
627
+ dim: int,
628
+ dim_out: Optional[int] = None,
629
+ mult: int = 4,
630
+ dropout: float = 0.0,
631
+ activation_fn: str = "geglu",
632
+ final_dropout: bool = False,
633
+ inner_dim=None,
634
+ bias: bool = True,
635
+ ):
636
+ super().__init__()
637
+ if inner_dim is None:
638
+ inner_dim = int(dim * mult)
639
+ dim_out = dim_out if dim_out is not None else dim
640
+ linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
641
+
642
+ if activation_fn == "gelu":
643
+ act_fn = GELU(dim, inner_dim, bias=bias)
644
+ if activation_fn == "gelu-approximate":
645
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
646
+ elif activation_fn == "geglu":
647
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
648
+ elif activation_fn == "geglu-approximate":
649
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
650
+
651
+ self.net = nn.ModuleList([])
652
+ # project in
653
+ self.net.append(act_fn)
654
+ # project dropout
655
+ self.net.append(nn.Dropout(dropout))
656
+ # project out
657
+ self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
658
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
659
+ if final_dropout:
660
+ self.net.append(nn.Dropout(dropout))
661
+
662
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
663
+ compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
664
+ for module in self.net:
665
+ if isinstance(module, compatible_cls):
666
+ hidden_states = module(hidden_states, scale)
667
+ else:
668
+ hidden_states = module(hidden_states)
669
+ return hidden_states
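For context on the `_chunk_size` path used by `set_chunk_feed_forward` above: because the feed-forward is applied independently per token, the sequence axis can be processed in chunks with an identical result and a lower peak memory. A minimal sketch with a stand-in MLP (sizes are assumptions):

```python
# Chunked position-wise feed-forward; the module and shapes are illustrative only.
import torch
import torch.nn as nn

ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(2, 1024, 64)                     # (batch, seq_len, channels)

chunk_size, chunk_dim = 256, 1                   # seq_len must be divisible by chunk_size
num_chunks = x.shape[chunk_dim] // chunk_size
out = torch.cat([ff(c) for c in x.chunk(num_chunks, dim=chunk_dim)], dim=chunk_dim)

assert torch.allclose(out, ff(x), atol=1e-6)     # same result, smaller peak activations
```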
foleycrafter/models/auffusion/attention_processor.py ADDED
The diff for this file is too large to render. See raw diff
 
foleycrafter/models/auffusion/dual_transformer_2d.py ADDED
@@ -0,0 +1,156 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional
15
+
16
+ from torch import nn
17
+
18
+ from foleycrafter.models.auffusion.transformer_2d \
19
+ import Transformer2DModel, Transformer2DModelOutput
20
+
21
+
22
+ class DualTransformer2DModel(nn.Module):
23
+ """
24
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
25
+
26
+ Parameters:
27
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
28
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
29
+ in_channels (`int`, *optional*):
30
+ Pass if the input is continuous. The number of channels in the input and output.
31
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
32
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
33
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
34
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
35
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
36
+ `ImagePositionalEmbeddings`.
37
+ num_vector_embeds (`int`, *optional*):
38
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
39
+ Includes the class for the masked latent pixel.
40
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
41
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
42
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
43
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
44
+ up to but not more steps than `num_embeds_ada_norm`.
45
+ attention_bias (`bool`, *optional*):
46
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ num_attention_heads: int = 16,
52
+ attention_head_dim: int = 88,
53
+ in_channels: Optional[int] = None,
54
+ num_layers: int = 1,
55
+ dropout: float = 0.0,
56
+ norm_num_groups: int = 32,
57
+ cross_attention_dim: Optional[int] = None,
58
+ attention_bias: bool = False,
59
+ sample_size: Optional[int] = None,
60
+ num_vector_embeds: Optional[int] = None,
61
+ activation_fn: str = "geglu",
62
+ num_embeds_ada_norm: Optional[int] = None,
63
+ ):
64
+ super().__init__()
65
+ self.transformers = nn.ModuleList(
66
+ [
67
+ Transformer2DModel(
68
+ num_attention_heads=num_attention_heads,
69
+ attention_head_dim=attention_head_dim,
70
+ in_channels=in_channels,
71
+ num_layers=num_layers,
72
+ dropout=dropout,
73
+ norm_num_groups=norm_num_groups,
74
+ cross_attention_dim=cross_attention_dim,
75
+ attention_bias=attention_bias,
76
+ sample_size=sample_size,
77
+ num_vector_embeds=num_vector_embeds,
78
+ activation_fn=activation_fn,
79
+ num_embeds_ada_norm=num_embeds_ada_norm,
80
+ )
81
+ for _ in range(2)
82
+ ]
83
+ )
84
+
85
+ # Variables that can be set by a pipeline:
86
+
87
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
88
+ self.mix_ratio = 0.5
89
+
90
+ # The shape of `encoder_hidden_states` is expected to be
91
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
92
+ self.condition_lengths = [77, 257]
93
+
94
+ # Which transformer to use to encode which condition.
95
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
96
+ self.transformer_index_for_condition = [1, 0]
97
+
98
+ def forward(
99
+ self,
100
+ hidden_states,
101
+ encoder_hidden_states,
102
+ timestep=None,
103
+ attention_mask=None,
104
+ cross_attention_kwargs=None,
105
+ return_dict: bool = True,
106
+ ):
107
+ """
108
+ Args:
109
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
110
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
111
+ hidden_states.
112
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
113
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
114
+ self-attention.
115
+ timestep ( `torch.long`, *optional*):
116
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
117
+ attention_mask (`torch.FloatTensor`, *optional*):
118
+ Optional attention mask to be applied in Attention.
119
+ cross_attention_kwargs (`dict`, *optional*):
120
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
121
+ `self.processor` in
122
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
123
+ return_dict (`bool`, *optional*, defaults to `True`):
124
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
125
+
126
+ Returns:
127
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
128
+ [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
129
+ returning a tuple, the first element is the sample tensor.
130
+ """
131
+ input_states = hidden_states
132
+
133
+ encoded_states = []
134
+ tokens_start = 0
135
+ # attention_mask is not used yet
136
+ for i in range(2):
137
+ # for each of the two transformers, pass the corresponding condition tokens
138
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
139
+ transformer_index = self.transformer_index_for_condition[i]
140
+ encoded_state = self.transformers[transformer_index](
141
+ input_states,
142
+ encoder_hidden_states=condition_state,
143
+ timestep=timestep,
144
+ cross_attention_kwargs=cross_attention_kwargs,
145
+ return_dict=False,
146
+ )[0]
147
+ encoded_states.append(encoded_state - input_states)
148
+ tokens_start += self.condition_lengths[i]
149
+
150
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
151
+ output_states = output_states + input_states
152
+
153
+ if not return_dict:
154
+ return (output_states,)
155
+
156
+ return Transformer2DModelOutput(sample=output_states)
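The forward pass above slices `encoder_hidden_states` by `condition_lengths`, routes each slice to one of the two transformers via `transformer_index_for_condition`, keeps only each branch's residual, and blends the residuals with `mix_ratio`. A self-contained sketch of that control flow with stand-in transformer functions (all names and sizes here are illustrative):

```python
# Dual-branch conditioning mix; the "transformers" are stand-in callables.
import torch

condition_lengths = [77, 257]
transformer_index_for_condition = [1, 0]
mix_ratio = 0.5

hidden = torch.randn(1, 4, 64)                        # stand-in hidden states
cond = torch.randn(1, sum(condition_lengths), 64)     # concatenated conditions
transformers = [
    lambda h, c: h + 0.1 * c.mean(dim=1, keepdim=True),
    lambda h, c: h + 0.2 * c.mean(dim=1, keepdim=True),
]

residuals, start = [], 0
for i, length in enumerate(condition_lengths):
    condition_state = cond[:, start:start + length]
    out = transformers[transformer_index_for_condition[i]](hidden, condition_state)
    residuals.append(out - hidden)                    # keep only the branch residual
    start += length

output = hidden + mix_ratio * residuals[0] + (1 - mix_ratio) * residuals[1]
```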
foleycrafter/models/auffusion/loaders/ip_adapter.py ADDED
@@ -0,0 +1,520 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from pathlib import Path
16
+ from typing import Dict, List, Optional, Union
17
+
18
+ import torch
19
+ from huggingface_hub.utils import validate_hf_hub_args
20
+ from safetensors import safe_open
21
+
22
+ from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
23
+ from diffusers.utils import (
24
+ _get_model_file,
25
+ is_accelerate_available,
26
+ is_torch_version,
27
+ is_transformers_available,
28
+ logging,
29
+ )
30
+
31
+
32
+ if is_transformers_available():
33
+ from transformers import (
34
+ CLIPImageProcessor,
35
+ CLIPVisionModelWithProjection,
36
+ )
37
+
38
+ from diffusers.models.attention_processor import (
39
+ IPAdapterAttnProcessor,
40
+ )
41
+
42
+ from foleycrafter.models.auffusion.attention_processor import IPAdapterAttnProcessor2_0, VPTemporalAdapterAttnProcessor2_0
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+
47
+ class IPAdapterMixin:
48
+ """Mixin for handling IP Adapters."""
49
+
50
+ @validate_hf_hub_args
51
+ def load_ip_adapter(
52
+ self,
53
+ pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
54
+ subfolder: Union[str, List[str]],
55
+ weight_name: Union[str, List[str]],
56
+ image_encoder_folder: Optional[str] = "image_encoder",
57
+ **kwargs,
58
+ ):
59
+ """
60
+ Parameters:
61
+ pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
62
+ Can be either:
63
+
64
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
65
+ the Hub.
66
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
67
+ with [`ModelMixin.save_pretrained`].
68
+ - A [torch state
69
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
70
+ subfolder (`str` or `List[str]`):
71
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
72
+ If a list is passed, it should have the same length as `weight_name`.
73
+ weight_name (`str` or `List[str]`):
74
+ The name of the weight file to load. If a list is passed, it should have the same length as
75
+ `weight_name`.
76
+ image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
77
+ The subfolder location of the image encoder within a larger model repository on the Hub or locally.
78
+ Pass `None` to not load the image encoder. If the image encoder is located in a folder inside `subfolder`,
79
+ you only need to pass the name of the folder that contains image encoder weights, e.g. `image_encoder_folder="image_encoder"`.
80
+ If the image encoder is located in a folder other than `subfolder`, you should pass the path to the folder that contains image encoder weights,
81
+ for example, `image_encoder_folder="different_subfolder/image_encoder"`.
82
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
83
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
84
+ is not used.
85
+ force_download (`bool`, *optional*, defaults to `False`):
86
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
87
+ cached versions if they exist.
88
+ resume_download (`bool`, *optional*, defaults to `False`):
89
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
90
+ incompletely downloaded files are deleted.
91
+ proxies (`Dict[str, str]`, *optional*):
92
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
93
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
94
+ local_files_only (`bool`, *optional*, defaults to `False`):
95
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
96
+ won't be downloaded from the Hub.
97
+ token (`str` or *bool*, *optional*):
98
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
99
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
100
+ revision (`str`, *optional*, defaults to `"main"`):
101
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
102
+ allowed by Git.
103
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
104
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
105
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
106
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
107
+ argument to `True` will raise an error.
108
+ """
109
+
110
+ # handle the list inputs for multiple IP Adapters
111
+ if not isinstance(weight_name, list):
112
+ weight_name = [weight_name]
113
+
114
+ if not isinstance(pretrained_model_name_or_path_or_dict, list):
115
+ pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
116
+ if len(pretrained_model_name_or_path_or_dict) == 1:
117
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
118
+
119
+ if not isinstance(subfolder, list):
120
+ subfolder = [subfolder]
121
+ if len(subfolder) == 1:
122
+ subfolder = subfolder * len(weight_name)
123
+
124
+ if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
125
+ raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
126
+
127
+ if len(weight_name) != len(subfolder):
128
+ raise ValueError("`weight_name` and `subfolder` must have the same length.")
129
+
130
+ # Load the main state dict first.
131
+ cache_dir = kwargs.pop("cache_dir", None)
132
+ force_download = kwargs.pop("force_download", False)
133
+ resume_download = kwargs.pop("resume_download", False)
134
+ proxies = kwargs.pop("proxies", None)
135
+ local_files_only = kwargs.pop("local_files_only", None)
136
+ token = kwargs.pop("token", None)
137
+ revision = kwargs.pop("revision", None)
138
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
139
+
140
+ if low_cpu_mem_usage and not is_accelerate_available():
141
+ low_cpu_mem_usage = False
142
+ logger.warning(
143
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
144
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
145
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
146
+ " install accelerate\n```\n."
147
+ )
148
+
149
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
150
+ raise NotImplementedError(
151
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
152
+ " `low_cpu_mem_usage=False`."
153
+ )
154
+
155
+ user_agent = {
156
+ "file_type": "attn_procs_weights",
157
+ "framework": "pytorch",
158
+ }
159
+ state_dicts = []
160
+ for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
161
+ pretrained_model_name_or_path_or_dict, weight_name, subfolder
162
+ ):
163
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
164
+ model_file = _get_model_file(
165
+ pretrained_model_name_or_path_or_dict,
166
+ weights_name=weight_name,
167
+ cache_dir=cache_dir,
168
+ force_download=force_download,
169
+ resume_download=resume_download,
170
+ proxies=proxies,
171
+ local_files_only=local_files_only,
172
+ token=token,
173
+ revision=revision,
174
+ subfolder=subfolder,
175
+ user_agent=user_agent,
176
+ )
177
+ if weight_name.endswith(".safetensors"):
178
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
179
+ with safe_open(model_file, framework="pt", device="cpu") as f:
180
+ for key in f.keys():
181
+ if key.startswith("image_proj."):
182
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
183
+ elif key.startswith("ip_adapter."):
184
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
185
+ else:
186
+ state_dict = torch.load(model_file, map_location="cpu")
187
+ else:
188
+ state_dict = pretrained_model_name_or_path_or_dict
189
+
190
+ keys = list(state_dict.keys())
191
+ if keys != ["image_proj", "ip_adapter"]:
192
+ raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
193
+
194
+ state_dicts.append(state_dict)
195
+
196
+ # load CLIP image encoder here if it has not been registered to the pipeline yet
197
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
198
+ if image_encoder_folder is not None:
199
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
200
+ logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
201
+ if image_encoder_folder.count("/") == 0:
202
+ image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
203
+ else:
204
+ image_encoder_subfolder = Path(image_encoder_folder).as_posix()
205
+
206
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
207
+ pretrained_model_name_or_path_or_dict,
208
+ subfolder=image_encoder_subfolder,
209
+ low_cpu_mem_usage=low_cpu_mem_usage,
210
+ ).to(self.device, dtype=self.dtype)
211
+ self.register_modules(image_encoder=image_encoder)
212
+ else:
213
+ raise ValueError(
214
+ "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
215
+ )
216
+ else:
217
+ logger.warning(
218
+ "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
219
+ "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
220
+ )
221
+
222
+ # create feature extractor if it has not been registered to the pipeline yet
223
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
224
+ feature_extractor = CLIPImageProcessor()
225
+ self.register_modules(feature_extractor=feature_extractor)
226
+
227
+ # load ip-adapter into unet
228
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
229
+ unet._load_ip_adapter_weights(state_dicts)
230
+
231
+ def set_ip_adapter_scale(self, scale):
232
+ """
233
+ Sets the conditioning scale between text and image.
234
+
235
+ Example:
236
+
237
+ ```py
238
+ pipeline.set_ip_adapter_scale(0.5)
239
+ ```
240
+ """
241
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
242
+ for attn_processor in unet.attn_processors.values():
243
+ if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
244
+ if not isinstance(scale, list):
245
+ scale = [scale] * len(attn_processor.scale)
246
+ if len(attn_processor.scale) != len(scale):
247
+ raise ValueError(
248
+ f"`scale` should be a list of same length as the number if ip-adapters "
249
+ f"Expected {len(attn_processor.scale)} but got {len(scale)}."
250
+ )
251
+ attn_processor.scale = scale
252
+
253
+ def unload_ip_adapter(self):
254
+ """
255
+ Unloads the IP Adapter weights
256
+
257
+ Examples:
258
+
259
+ ```python
260
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
261
+ >>> pipeline.unload_ip_adapter()
262
+ >>> ...
263
+ ```
264
+ """
265
+ # remove CLIP image encoder
266
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
267
+ self.image_encoder = None
268
+ self.register_to_config(image_encoder=[None, None])
269
+
270
+ # remove feature extractor only when safety_checker is None as safety_checker uses
271
+ # the feature_extractor later
272
+ if not hasattr(self, "safety_checker"):
273
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
274
+ self.feature_extractor = None
275
+ self.register_to_config(feature_extractor=[None, None])
276
+
277
+ # remove hidden encoder
278
+ self.unet.encoder_hid_proj = None
279
+ self.config.encoder_hid_dim_type = None
280
+
281
+ # restore original Unet attention processors layers
282
+ self.unet.set_default_attn_processor()
283
+
284
+
285
+ class VPAdapterMixin:
286
+ """Mixin for handling IP Adapters."""
287
+
288
+ @validate_hf_hub_args
289
+ def load_ip_adapter(
290
+ self,
291
+ pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
292
+ subfolder: Union[str, List[str]],
293
+ weight_name: Union[str, List[str]],
294
+ image_encoder_folder: Optional[str] = "image_encoder",
295
+ **kwargs,
296
+ ):
297
+ """
298
+ Parameters:
299
+ pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
300
+ Can be either:
301
+
302
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
303
+ the Hub.
304
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
305
+ with [`ModelMixin.save_pretrained`].
306
+ - A [torch state
307
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
308
+ subfolder (`str` or `List[str]`):
309
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
310
+ If a list is passed, it should have the same length as `weight_name`.
311
+ weight_name (`str` or `List[str]`):
312
+ The name of the weight file to load. If a list is passed, it should have the same length as
313
+ `weight_name`.
314
+ image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
315
+ The subfolder location of the image encoder within a larger model repository on the Hub or locally.
316
+ Pass `None` to not load the image encoder. If the image encoder is located in a folder inside `subfolder`,
317
+ you only need to pass the name of the folder that contains image encoder weights, e.g. `image_encoder_folder="image_encoder"`.
318
+ If the image encoder is located in a folder other than `subfolder`, you should pass the path to the folder that contains image encoder weights,
319
+ for example, `image_encoder_folder="different_subfolder/image_encoder"`.
320
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
321
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
322
+ is not used.
323
+ force_download (`bool`, *optional*, defaults to `False`):
324
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
325
+ cached versions if they exist.
326
+ resume_download (`bool`, *optional*, defaults to `False`):
327
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
328
+ incompletely downloaded files are deleted.
329
+ proxies (`Dict[str, str]`, *optional*):
330
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
331
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
332
+ local_files_only (`bool`, *optional*, defaults to `False`):
333
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
334
+ won't be downloaded from the Hub.
335
+ token (`str` or *bool*, *optional*):
336
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
337
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
338
+ revision (`str`, *optional*, defaults to `"main"`):
339
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
340
+ allowed by Git.
341
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
342
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
343
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
344
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
345
+ argument to `True` will raise an error.
346
+ """
347
+
348
+ # handle the list inputs for multiple IP Adapters
349
+ if not isinstance(weight_name, list):
350
+ weight_name = [weight_name]
351
+
352
+ if not isinstance(pretrained_model_name_or_path_or_dict, list):
353
+ pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
354
+ if len(pretrained_model_name_or_path_or_dict) == 1:
355
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
356
+
357
+ if not isinstance(subfolder, list):
358
+ subfolder = [subfolder]
359
+ if len(subfolder) == 1:
360
+ subfolder = subfolder * len(weight_name)
361
+
362
+ if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
363
+ raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
364
+
365
+ if len(weight_name) != len(subfolder):
366
+ raise ValueError("`weight_name` and `subfolder` must have the same length.")
367
+
368
+ # Load the main state dict first.
369
+ cache_dir = kwargs.pop("cache_dir", None)
370
+ force_download = kwargs.pop("force_download", False)
371
+ resume_download = kwargs.pop("resume_download", False)
372
+ proxies = kwargs.pop("proxies", None)
373
+ local_files_only = kwargs.pop("local_files_only", None)
374
+ token = kwargs.pop("token", None)
375
+ revision = kwargs.pop("revision", None)
376
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
377
+
378
+ if low_cpu_mem_usage and not is_accelerate_available():
379
+ low_cpu_mem_usage = False
380
+ logger.warning(
381
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
382
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
383
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
384
+ " install accelerate\n```\n."
385
+ )
386
+
387
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
388
+ raise NotImplementedError(
389
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
390
+ " `low_cpu_mem_usage=False`."
391
+ )
392
+
393
+ user_agent = {
394
+ "file_type": "attn_procs_weights",
395
+ "framework": "pytorch",
396
+ }
397
+ state_dicts = []
398
+ for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
399
+ pretrained_model_name_or_path_or_dict, weight_name, subfolder
400
+ ):
401
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
402
+ model_file = _get_model_file(
403
+ pretrained_model_name_or_path_or_dict,
404
+ weights_name=weight_name,
405
+ cache_dir=cache_dir,
406
+ force_download=force_download,
407
+ resume_download=resume_download,
408
+ proxies=proxies,
409
+ local_files_only=local_files_only,
410
+ token=token,
411
+ revision=revision,
412
+ subfolder=subfolder,
413
+ user_agent=user_agent,
414
+ )
415
+ if weight_name.endswith(".safetensors"):
416
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
417
+ with safe_open(model_file, framework="pt", device="cpu") as f:
418
+ for key in f.keys():
419
+ if key.startswith("image_proj."):
420
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
421
+ elif key.startswith("ip_adapter."):
422
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
423
+ else:
424
+ state_dict = torch.load(model_file, map_location="cpu")
425
+ else:
426
+ state_dict = pretrained_model_name_or_path_or_dict
427
+
428
+ keys = list(state_dict.keys())
429
+ if keys != ["image_proj", "ip_adapter"]:
430
+ raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
431
+
432
+ state_dicts.append(state_dict)
433
+
434
+ # load CLIP image encoder here if it has not been registered to the pipeline yet
435
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
436
+ if image_encoder_folder is not None:
437
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
438
+ logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
439
+ if image_encoder_folder.count("/") == 0:
440
+ image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
441
+ else:
442
+ image_encoder_subfolder = Path(image_encoder_folder).as_posix()
443
+
444
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
445
+ pretrained_model_name_or_path_or_dict,
446
+ subfolder=image_encoder_subfolder,
447
+ low_cpu_mem_usage=low_cpu_mem_usage,
448
+ ).to(self.device, dtype=self.dtype)
449
+ self.register_modules(image_encoder=image_encoder)
450
+ else:
451
+ raise ValueError(
452
+ "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
453
+ )
454
+ else:
455
+ logger.warning(
456
+ "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
457
+ "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
458
+ )
459
+
460
+ # create feature extractor if it has not been registered to the pipeline yet
461
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
462
+ feature_extractor = CLIPImageProcessor()
463
+ self.register_modules(feature_extractor=feature_extractor)
464
+
465
+ # load ip-adapter into unet
466
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
467
+ unet._load_ip_adapter_weights_VPAdapter(state_dicts)
468
+
469
+ def set_ip_adapter_scale(self, scale):
470
+ """
471
+ Sets the conditioning scale between text and image.
472
+
473
+ Example:
474
+
475
+ ```py
476
+ pipeline.set_ip_adapter_scale(0.5)
477
+ ```
478
+ """
479
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
480
+ for attn_processor in unet.attn_processors.values():
481
+ if isinstance(attn_processor, (IPAdapterAttnProcessor, VPTemporalAdapterAttnProcessor2_0)):
482
+ if not isinstance(scale, list):
483
+ scale = [scale] * len(attn_processor.scale)
484
+ if len(attn_processor.scale) != len(scale):
485
+ raise ValueError(
486
+ f"`scale` should be a list of same length as the number if ip-adapters "
487
+ f"Expected {len(attn_processor.scale)} but got {len(scale)}."
488
+ )
489
+ attn_processor.scale = scale
490
+
491
+ def unload_ip_adapter(self):
492
+ """
493
+ Unloads the IP Adapter weights
494
+
495
+ Examples:
496
+
497
+ ```python
498
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
499
+ >>> pipeline.unload_ip_adapter()
500
+ >>> ...
501
+ ```
502
+ """
503
+ # remove CLIP image encoder
504
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
505
+ self.image_encoder = None
506
+ self.register_to_config(image_encoder=[None, None])
507
+
508
+ # remove feature extractor only when safety_checker is None as safety_checker uses
509
+ # the feature_extractor later
510
+ if not hasattr(self, "safety_checker"):
511
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
512
+ self.feature_extractor = None
513
+ self.register_to_config(feature_extractor=[None, None])
514
+
515
+ # remove hidden encoder
516
+ self.unet.encoder_hid_proj = None
517
+ self.config.encoder_hid_dim_type = None
518
+
519
+ # restore original Unet attention processors layers
520
+ self.unet.set_default_attn_processor()
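Both mixins above build each adapter's `state_dict` by bucketing flat `.safetensors` keys under `image_proj` and `ip_adapter`. A tiny self-contained sketch of that step, using an in-memory dict as a stand-in for the `safe_open(...)` handle (tensor shapes are placeholders):

```python
# Key bucketing performed by load_ip_adapter for .safetensors checkpoints.
import torch

flat_checkpoint = {                                   # stand-in for safe_open(...)
    "image_proj.proj.weight": torch.zeros(4, 4),
    "ip_adapter.1.to_k_ip.weight": torch.zeros(4, 4),
}

state_dict = {"image_proj": {}, "ip_adapter": {}}
for key, tensor in flat_checkpoint.items():
    if key.startswith("image_proj."):
        state_dict["image_proj"][key.replace("image_proj.", "")] = tensor
    elif key.startswith("ip_adapter."):
        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = tensor

assert list(state_dict.keys()) == ["image_proj", "ip_adapter"]
```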
foleycrafter/models/auffusion/loaders/unet.py ADDED
@@ -0,0 +1,1100 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import inspect
15
+ import os
16
+ from collections import defaultdict
17
+ from contextlib import nullcontext
18
+ from functools import partial
19
+ from typing import Callable, Dict, List, Optional, Union, Tuple
20
+
21
+ import safetensors
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from huggingface_hub.utils import validate_hf_hub_args
25
+ from torch import nn
26
+
27
+ from diffusers.models.embeddings import ImageProjection, MLPProjection, Resampler
28
+ from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
29
+ from diffusers.utils import (
30
+ USE_PEFT_BACKEND,
31
+ _get_model_file,
32
+ delete_adapter_layers,
33
+ is_accelerate_available,
34
+ logging,
35
+ is_torch_version,
36
+ set_adapter_layers,
37
+ set_weights_and_activate_adapters,
38
+ )
39
+ from diffusers.loaders.utils import AttnProcsLayers
40
+
41
+ from foleycrafter.models.adapters.ip_adapter import VideoProjModel
42
+ from foleycrafter.models.auffusion.attention_processor import IPAdapterAttnProcessor2_0, VPTemporalAdapterAttnProcessor2_0, AttnProcessor2_0
43
+
44
+
45
+ if is_accelerate_available():
46
+ from accelerate import init_empty_weights
47
+ from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+ class VPAdapterImageProjection(nn.Module):
52
+ def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
53
+ super().__init__()
54
+ self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
55
+
56
+ def forward(self, image_embeds: List[torch.FloatTensor]):
57
+ projected_image_embeds = []
58
+
59
+ # currently, we accept `image_embeds` as
60
+ # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
61
+ # 2. a list of `n` tensors where `n` is the number of ip-adapters; each tensor can have shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
62
+ if not isinstance(image_embeds, list):
63
+ deprecation_message = (
64
+ "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
65
+ " Please make sure to update your script to pass `image_embeds` as a list of tensors to supress this warning."
66
+ )
67
+ image_embeds = [image_embeds.unsqueeze(1)]
68
+
69
+ if len(image_embeds) != len(self.image_projection_layers):
70
+ raise ValueError(
71
+ f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
72
+ )
73
+
74
+ for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
75
+ image_embed = image_embed.squeeze(1)
76
+ batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
77
+ image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
78
+ image_embed = image_projection_layer(image_embed)
79
+ image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
80
+
81
+ projected_image_embeds.append(image_embed)
82
+
83
+ return projected_image_embeds
84
+
85
+ class MultiIPAdapterImageProjection(nn.Module):
86
+ def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
87
+ super().__init__()
88
+ self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
89
+
90
+ def forward(self, image_embeds: List[torch.FloatTensor]):
91
+ projected_image_embeds = []
92
+
93
+ # currently, we accept `image_embeds` as
94
+ # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
95
+ # 2. a list of `n` tensors where `n` is the number of ip-adapters; each tensor can have shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
96
+ if not isinstance(image_embeds, list):
97
+ deprecation_message = (
98
+ "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
99
+ " Please make sure to update your script to pass `image_embeds` as a list of tensors to supress this warning."
100
+ )
101
+ image_embeds = [image_embeds.unsqueeze(1)]
102
+
103
+ if len(image_embeds) != len(self.image_projection_layers):
104
+ raise ValueError(
105
+ f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
106
+ )
107
+
108
+ for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
109
+ batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
110
+ image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
111
+ image_embed = image_projection_layer(image_embed)
112
+ image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
113
+
114
+ projected_image_embeds.append(image_embed)
115
+
116
+ return projected_image_embeds
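Both projection wrappers above flatten the `(batch_size, num_images, ...)` leading dimensions before calling each projection layer and restore them afterwards. A minimal sketch of that reshape with a stand-in linear projection (all sizes are assumptions):

```python
# Flatten -> project -> restore pattern used by the image projection wrappers.
import torch
import torch.nn as nn

batch_size, num_images, embed_dim, out_dim = 2, 3, 8, 4
image_embed = torch.randn(batch_size, num_images, embed_dim)
projection = nn.Linear(embed_dim, out_dim)            # stand-in projection layer

flat = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
projected = projection(flat)
projected = projected.reshape((batch_size, num_images) + projected.shape[1:])

assert projected.shape == (batch_size, num_images, out_dim)
```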
117
+
118
+
119
+ TEXT_ENCODER_NAME = "text_encoder"
120
+ UNET_NAME = "unet"
121
+
122
+ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
123
+ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
124
+
125
+ CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
126
+ CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
127
+
128
+
129
+ class UNet2DConditionLoadersMixin:
130
+ """
131
+ Load LoRA layers into a [`UNet2DCondtionModel`].
132
+ """
133
+
134
+ text_encoder_name = TEXT_ENCODER_NAME
135
+ unet_name = UNET_NAME
136
+
137
+ @validate_hf_hub_args
138
+ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
139
+ r"""
140
+ Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
141
+ defined in
142
+ [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
143
+ and be a `torch.nn.Module` class.
144
+
145
+ Parameters:
146
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
147
+ Can be either:
148
+
149
+ - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
150
+ the Hub.
151
+ - A path to a directory (for example `./my_model_directory`) containing the model weights saved
152
+ with [`ModelMixin.save_pretrained`].
153
+ - A [torch state
154
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
155
+
156
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
157
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
158
+ is not used.
159
+ force_download (`bool`, *optional*, defaults to `False`):
160
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
161
+ cached versions if they exist.
162
+ resume_download (`bool`, *optional*, defaults to `False`):
163
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
164
+ incompletely downloaded files are deleted.
165
+ proxies (`Dict[str, str]`, *optional*):
166
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
167
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
168
+ local_files_only (`bool`, *optional*, defaults to `False`):
169
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
170
+ won't be downloaded from the Hub.
171
+ token (`str` or *bool*, *optional*):
172
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
173
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
174
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
175
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
176
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
177
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
178
+ argument to `True` will raise an error.
179
+ revision (`str`, *optional*, defaults to `"main"`):
180
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
181
+ allowed by Git.
182
+ subfolder (`str`, *optional*, defaults to `""`):
183
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
184
+ mirror (`str`, *optional*):
185
+ Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
186
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
187
+ information.
188
+
189
+ Example:
190
+
191
+ ```py
192
+ from diffusers import AutoPipelineForText2Image
193
+ import torch
194
+
195
+ pipeline = AutoPipelineForText2Image.from_pretrained(
196
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
197
+ ).to("cuda")
198
+ pipeline.unet.load_attn_procs(
199
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
200
+ )
201
+ ```
202
+ """
203
+ from diffusers.models.attention_processor import CustomDiffusionAttnProcessor
204
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
205
+
206
+ cache_dir = kwargs.pop("cache_dir", None)
207
+ force_download = kwargs.pop("force_download", False)
208
+ resume_download = kwargs.pop("resume_download", False)
209
+ proxies = kwargs.pop("proxies", None)
210
+ local_files_only = kwargs.pop("local_files_only", None)
211
+ token = kwargs.pop("token", None)
212
+ revision = kwargs.pop("revision", None)
213
+ subfolder = kwargs.pop("subfolder", None)
214
+ weight_name = kwargs.pop("weight_name", None)
215
+ use_safetensors = kwargs.pop("use_safetensors", None)
216
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
217
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
218
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
219
+ network_alphas = kwargs.pop("network_alphas", None)
220
+
221
+ _pipeline = kwargs.pop("_pipeline", None)
222
+
223
+ is_network_alphas_none = network_alphas is None
224
+
225
+ allow_pickle = False
226
+
227
+ if use_safetensors is None:
228
+ use_safetensors = True
229
+ allow_pickle = True
230
+
231
+ user_agent = {
232
+ "file_type": "attn_procs_weights",
233
+ "framework": "pytorch",
234
+ }
235
+
236
+ if low_cpu_mem_usage and not is_accelerate_available():
237
+ low_cpu_mem_usage = False
238
+ logger.warning(
239
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
240
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
241
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
242
+ " install accelerate\n```\n."
243
+ )
244
+
245
+ model_file = None
246
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
247
+ # Let's first try to load .safetensors weights
248
+ if (use_safetensors and weight_name is None) or (
249
+ weight_name is not None and weight_name.endswith(".safetensors")
250
+ ):
251
+ try:
252
+ model_file = _get_model_file(
253
+ pretrained_model_name_or_path_or_dict,
254
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
255
+ cache_dir=cache_dir,
256
+ force_download=force_download,
257
+ resume_download=resume_download,
258
+ proxies=proxies,
259
+ local_files_only=local_files_only,
260
+ token=token,
261
+ revision=revision,
262
+ subfolder=subfolder,
263
+ user_agent=user_agent,
264
+ )
265
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
266
+ except IOError as e:
267
+ if not allow_pickle:
268
+ raise e
269
+ # try loading non-safetensors weights
270
+ pass
271
+ if model_file is None:
272
+ model_file = _get_model_file(
273
+ pretrained_model_name_or_path_or_dict,
274
+ weights_name=weight_name or LORA_WEIGHT_NAME,
275
+ cache_dir=cache_dir,
276
+ force_download=force_download,
277
+ resume_download=resume_download,
278
+ proxies=proxies,
279
+ local_files_only=local_files_only,
280
+ token=token,
281
+ revision=revision,
282
+ subfolder=subfolder,
283
+ user_agent=user_agent,
284
+ )
285
+ state_dict = torch.load(model_file, map_location="cpu")
286
+ else:
287
+ state_dict = pretrained_model_name_or_path_or_dict
288
+
289
+ # fill attn processors
290
+ lora_layers_list = []
291
+
292
+ is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND
293
+ is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
294
+
295
+ if is_lora:
296
+ # correct keys
297
+ state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)
298
+
299
+ if network_alphas is not None:
300
+ network_alphas_keys = list(network_alphas.keys())
301
+ used_network_alphas_keys = set()
302
+
303
+ lora_grouped_dict = defaultdict(dict)
304
+ mapped_network_alphas = {}
305
+
306
+ all_keys = list(state_dict.keys())
307
+ for key in all_keys:
308
+ value = state_dict.pop(key)
309
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
310
+ lora_grouped_dict[attn_processor_key][sub_key] = value
311
+
312
+ # Create another `mapped_network_alphas` dictionary so that we can properly map them.
313
+ if network_alphas is not None:
314
+ for k in network_alphas_keys:
315
+ if k.replace(".alpha", "") in key:
316
+ mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
317
+ used_network_alphas_keys.add(k)
318
+
319
+ if not is_network_alphas_none:
320
+ if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
321
+ raise ValueError(
322
+ f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
323
+ )
324
+
325
+ if len(state_dict) > 0:
326
+ raise ValueError(
327
+ f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
328
+ )
329
+
330
+ for key, value_dict in lora_grouped_dict.items():
331
+ attn_processor = self
332
+ for sub_key in key.split("."):
333
+ attn_processor = getattr(attn_processor, sub_key)
334
+
335
+ # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
336
+ # or add_{k,v,q,out_proj}_proj_lora layers.
337
+ rank = value_dict["lora.down.weight"].shape[0]
338
+
339
+ if isinstance(attn_processor, LoRACompatibleConv):
340
+ in_features = attn_processor.in_channels
341
+ out_features = attn_processor.out_channels
342
+ kernel_size = attn_processor.kernel_size
343
+
344
+ ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
345
+ with ctx():
346
+ lora = LoRAConv2dLayer(
347
+ in_features=in_features,
348
+ out_features=out_features,
349
+ rank=rank,
350
+ kernel_size=kernel_size,
351
+ stride=attn_processor.stride,
352
+ padding=attn_processor.padding,
353
+ network_alpha=mapped_network_alphas.get(key),
354
+ )
355
+ elif isinstance(attn_processor, LoRACompatibleLinear):
356
+ ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
357
+ with ctx():
358
+ lora = LoRALinearLayer(
359
+ attn_processor.in_features,
360
+ attn_processor.out_features,
361
+ rank,
362
+ mapped_network_alphas.get(key),
363
+ )
364
+ else:
365
+ raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
366
+
367
+ value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
368
+ lora_layers_list.append((attn_processor, lora))
369
+
370
+ if low_cpu_mem_usage:
371
+ device = next(iter(value_dict.values())).device
372
+ dtype = next(iter(value_dict.values())).dtype
373
+ load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
374
+ else:
375
+ lora.load_state_dict(value_dict)
376
+
377
+ elif is_custom_diffusion:
378
+ attn_processors = {}
379
+ custom_diffusion_grouped_dict = defaultdict(dict)
380
+ for key, value in state_dict.items():
381
+ if len(value) == 0:
382
+ custom_diffusion_grouped_dict[key] = {}
383
+ else:
384
+ if "to_out" in key:
385
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
386
+ else:
387
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
388
+ custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
389
+
390
+ for key, value_dict in custom_diffusion_grouped_dict.items():
391
+ if len(value_dict) == 0:
392
+ attn_processors[key] = CustomDiffusionAttnProcessor(
393
+ train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
394
+ )
395
+ else:
396
+ cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
397
+ hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
398
+ train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
399
+ attn_processors[key] = CustomDiffusionAttnProcessor(
400
+ train_kv=True,
401
+ train_q_out=train_q_out,
402
+ hidden_size=hidden_size,
403
+ cross_attention_dim=cross_attention_dim,
404
+ )
405
+ attn_processors[key].load_state_dict(value_dict)
406
+ elif USE_PEFT_BACKEND:
407
+ # In that case we have nothing to do as loading the adapter weights is already handled above by `set_peft_model_state_dict`
408
+ # on the Unet
409
+ pass
410
+ else:
411
+ raise ValueError(
412
+ f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
413
+ )
414
+
415
+ # <Unsafe code
416
+ # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
417
+ # Now we remove any existing hooks to the pipeline components.
418
+ is_model_cpu_offload = False
419
+ is_sequential_cpu_offload = False
420
+
421
+ # For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_weights_into_unet`
422
+ if not USE_PEFT_BACKEND:
423
+ if _pipeline is not None:
424
+ for _, component in _pipeline.components.items():
425
+ if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
426
+ is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
427
+ is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
428
+
429
+ logger.info(
430
+ "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
431
+ )
432
+ remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
433
+
434
+ # only custom diffusion needs to set attn processors
435
+ if is_custom_diffusion:
436
+ self.set_attn_processor(attn_processors)
437
+
438
+ # set lora layers
439
+ for target_module, lora_layer in lora_layers_list:
440
+ target_module.set_lora_layer(lora_layer)
441
+
442
+ self.to(dtype=self.dtype, device=self.device)
443
+
444
+ # Offload back.
445
+ if is_model_cpu_offload:
446
+ _pipeline.enable_model_cpu_offload()
447
+ elif is_sequential_cpu_offload:
448
+ _pipeline.enable_sequential_cpu_offload()
449
+ # Unsafe code />
450
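The grouping step above splits each flat LoRA key on its last three components to find the target module; a minimal, self-contained sketch of that bookkeeping (the module path and tensor shapes below are made up for illustration):

```py
from collections import defaultdict

import torch

# Hypothetical flat LoRA state dict in the "<module path>.lora.{down,up}.weight" layout
state_dict = {
    "down_blocks.0.attentions.0.proj_in.lora.down.weight": torch.zeros(4, 320),
    "down_blocks.0.attentions.0.proj_in.lora.up.weight": torch.zeros(320, 4),
}

lora_grouped_dict = defaultdict(dict)
for key, value in state_dict.items():
    # everything before the last three dot-separated components names the target module
    module_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
    lora_grouped_dict[module_key][sub_key] = value

for module_key, value_dict in lora_grouped_dict.items():
    rank = value_dict["lora.down.weight"].shape[0]  # same rank lookup as in the loader above
    print(module_key, rank)  # down_blocks.0.attentions.0.proj_in 4
```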
+
451
+ def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
452
+ is_new_lora_format = all(
453
+ key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
454
+ )
455
+ if is_new_lora_format:
456
+ # Strip the `"unet"` prefix.
457
+ is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
458
+ if is_text_encoder_present:
459
+ warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
460
+ logger.warn(warn_message)
461
+ unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
462
+ state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
463
+
464
+ # change processor format to 'pure' LoRACompatibleLinear format
465
+ if any("processor" in k.split(".") for k in state_dict.keys()):
466
+
467
+ def format_to_lora_compatible(key):
468
+ if "processor" not in key.split("."):
469
+ return key
470
+ return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora")
471
+
472
+ state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()}
473
+
474
+ if network_alphas is not None:
475
+ network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()}
476
+ return state_dict, network_alphas
477
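A standalone sketch of what `format_to_lora_compatible` does to legacy `processor`-style keys (the example keys are illustrative only):

```py
def format_to_lora_compatible(key: str) -> str:
    # same renaming as in `convert_state_dict_legacy_attn_format` above
    if "processor" not in key.split("."):
        return key
    return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora")


legacy_keys = [
    "mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight",
    "mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_out_lora.up.weight",
]
for k in legacy_keys:
    print(format_to_lora_compatible(k))
# mid_block.attentions.0.transformer_blocks.0.attn1.to_q.lora.down.weight
# mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.lora.up.weight
```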
+
478
+ def save_attn_procs(
479
+ self,
480
+ save_directory: Union[str, os.PathLike],
481
+ is_main_process: bool = True,
482
+ weight_name: str = None,
483
+ save_function: Callable = None,
484
+ safe_serialization: bool = True,
485
+ **kwargs,
486
+ ):
487
+ r"""
488
+ Save attention processor layers to a directory so that they can be reloaded with the
489
+ [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
490
+
491
+ Arguments:
492
+ save_directory (`str` or `os.PathLike`):
493
+ Directory to save an attention processor to (will be created if it doesn't exist).
494
+ is_main_process (`bool`, *optional*, defaults to `True`):
495
+ Whether the process calling this is the main process or not. Useful during distributed training when you
496
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
497
+ process to avoid race conditions.
498
+ save_function (`Callable`):
499
+ The function to use to save the state dictionary. Useful during distributed training when you need to
500
+ replace `torch.save` with another method. Can be configured with the environment variable
501
+ `DIFFUSERS_SAVE_MODE`.
502
+ safe_serialization (`bool`, *optional*, defaults to `True`):
503
+ Whether to save the model using `safetensors` or with `pickle`.
504
+
505
+ Example:
506
+
507
+ ```py
508
+ import torch
509
+ from diffusers import DiffusionPipeline
510
+
511
+ pipeline = DiffusionPipeline.from_pretrained(
512
+ "CompVis/stable-diffusion-v1-4",
513
+ torch_dtype=torch.float16,
514
+ ).to("cuda")
515
+ pipeline.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
516
+ pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
517
+ ```
518
+ """
519
+ from diffusers.models.attention_processor import (
520
+ CustomDiffusionAttnProcessor,
521
+ CustomDiffusionAttnProcessor2_0,
522
+ CustomDiffusionXFormersAttnProcessor,
523
+ )
524
+
525
+ if os.path.isfile(save_directory):
526
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
527
+ return
528
+
529
+ if save_function is None:
530
+ if safe_serialization:
531
+
532
+ def save_function(weights, filename):
533
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
534
+
535
+ else:
536
+ save_function = torch.save
537
+
538
+ os.makedirs(save_directory, exist_ok=True)
539
+
540
+ is_custom_diffusion = any(
541
+ isinstance(
542
+ x,
543
+ (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
544
+ )
545
+ for (_, x) in self.attn_processors.items()
546
+ )
547
+ if is_custom_diffusion:
548
+ model_to_save = AttnProcsLayers(
549
+ {
550
+ y: x
551
+ for (y, x) in self.attn_processors.items()
552
+ if isinstance(
553
+ x,
554
+ (
555
+ CustomDiffusionAttnProcessor,
556
+ CustomDiffusionAttnProcessor2_0,
557
+ CustomDiffusionXFormersAttnProcessor,
558
+ ),
559
+ )
560
+ }
561
+ )
562
+ state_dict = model_to_save.state_dict()
563
+ for name, attn in self.attn_processors.items():
564
+ if len(attn.state_dict()) == 0:
565
+ state_dict[name] = {}
566
+ else:
567
+ model_to_save = AttnProcsLayers(self.attn_processors)
568
+ state_dict = model_to_save.state_dict()
569
+
570
+ if weight_name is None:
571
+ if safe_serialization:
572
+ weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
573
+ else:
574
+ weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
575
+
576
+ # Save the model
577
+ save_function(state_dict, os.path.join(save_directory, weight_name))
578
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
579
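A minimal sketch of the two serialization paths selected by `save_function` above, using a toy state dict and a hypothetical output directory:

```py
import os

import safetensors.torch
import torch

state_dict = {"to_k_lora.down.weight": torch.zeros(4, 320)}  # toy weights
save_directory = "toy-attn-procs"  # hypothetical path
os.makedirs(save_directory, exist_ok=True)

# safe_serialization=True -> safetensors with the "pt" format tag
safetensors.torch.save_file(
    state_dict, os.path.join(save_directory, "pytorch_lora_weights.safetensors"), metadata={"format": "pt"}
)
# safe_serialization=False -> plain pickled checkpoint via torch.save
torch.save(state_dict, os.path.join(save_directory, "pytorch_lora_weights.bin"))
```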
+
580
+ def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
581
+ self.lora_scale = lora_scale
582
+ self._safe_fusing = safe_fusing
583
+ self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
584
+
585
+ def _fuse_lora_apply(self, module, adapter_names=None):
586
+ if not USE_PEFT_BACKEND:
587
+ if hasattr(module, "_fuse_lora"):
588
+ module._fuse_lora(self.lora_scale, self._safe_fusing)
589
+
590
+ if adapter_names is not None:
591
+ raise ValueError(
592
+ "The `adapter_names` argument is not supported in your environment. Please switch"
593
+ " to PEFT backend to use this argument by installing latest PEFT and transformers."
594
+ " `pip install -U peft transformers`"
595
+ )
596
+ else:
597
+ from peft.tuners.tuners_utils import BaseTunerLayer
598
+
599
+ merge_kwargs = {"safe_merge": self._safe_fusing}
600
+
601
+ if isinstance(module, BaseTunerLayer):
602
+ if self.lora_scale != 1.0:
603
+ module.scale_layer(self.lora_scale)
604
+
605
+ # For BC with previous PEFT versions, we need to check the signature
606
+ # of the `merge` method to see if it supports the `adapter_names` argument.
607
+ supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
608
+ if "adapter_names" in supported_merge_kwargs:
609
+ merge_kwargs["adapter_names"] = adapter_names
610
+ elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
611
+ raise ValueError(
612
+ "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
613
+ " to the latest version of PEFT. `pip install -U peft`"
614
+ )
615
+
616
+ module.merge(**merge_kwargs)
617
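A hypothetical usage sketch of `fuse_lora`/`unfuse_lora` on a diffusers pipeline whose UNet uses this mixin (the repository and weight names are illustrative, not part of this codebase):

```py
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors")

pipeline.unet.fuse_lora(lora_scale=0.8)  # merge the LoRA deltas into the base weights
image = pipeline("a cinematic photo of a lighthouse").images[0]
pipeline.unet.unfuse_lora()  # restore the original, unfused weights
```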
+
618
+ def unfuse_lora(self):
619
+ self.apply(self._unfuse_lora_apply)
620
+
621
+ def _unfuse_lora_apply(self, module):
622
+ if not USE_PEFT_BACKEND:
623
+ if hasattr(module, "_unfuse_lora"):
624
+ module._unfuse_lora()
625
+ else:
626
+ from peft.tuners.tuners_utils import BaseTunerLayer
627
+
628
+ if isinstance(module, BaseTunerLayer):
629
+ module.unmerge()
630
+
631
+ def set_adapters(
632
+ self,
633
+ adapter_names: Union[List[str], str],
634
+ weights: Optional[Union[List[float], float]] = None,
635
+ ):
636
+ """
637
+ Set the currently active adapters for use in the UNet.
638
+
639
+ Args:
640
+ adapter_names (`List[str]` or `str`):
641
+ The names of the adapters to use.
642
+ weights (`Union[List[float], float]`, *optional*):
643
+ The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
644
+ adapters.
645
+
646
+ Example:
647
+
648
+ ```py
649
+ from diffusers import AutoPipelineForText2Image
650
+ import torch
651
+
652
+ pipeline = AutoPipelineForText2Image.from_pretrained(
653
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
654
+ ).to("cuda")
655
+ pipeline.load_lora_weights(
656
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
657
+ )
658
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
659
+ pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
660
+ ```
661
+ """
662
+ if not USE_PEFT_BACKEND:
663
+ raise ValueError("PEFT backend is required for `set_adapters()`.")
664
+
665
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
666
+
667
+ if weights is None:
668
+ weights = [1.0] * len(adapter_names)
669
+ elif isinstance(weights, float):
670
+ weights = [weights] * len(adapter_names)
671
+
672
+ if len(adapter_names) != len(weights):
673
+ raise ValueError(
674
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
675
+ )
676
+
677
+ set_weights_and_activate_adapters(self, adapter_names, weights)
678
+
679
+ def disable_lora(self):
680
+ """
681
+ Disable the UNet's active LoRA layers.
682
+
683
+ Example:
684
+
685
+ ```py
686
+ from diffusers import AutoPipelineForText2Image
687
+ import torch
688
+
689
+ pipeline = AutoPipelineForText2Image.from_pretrained(
690
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
691
+ ).to("cuda")
692
+ pipeline.load_lora_weights(
693
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
694
+ )
695
+ pipeline.disable_lora()
696
+ ```
697
+ """
698
+ if not USE_PEFT_BACKEND:
699
+ raise ValueError("PEFT backend is required for this method.")
700
+ set_adapter_layers(self, enabled=False)
701
+
702
+ def enable_lora(self):
703
+ """
704
+ Enable the UNet's active LoRA layers.
705
+
706
+ Example:
707
+
708
+ ```py
709
+ from diffusers import AutoPipelineForText2Image
710
+ import torch
711
+
712
+ pipeline = AutoPipelineForText2Image.from_pretrained(
713
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
714
+ ).to("cuda")
715
+ pipeline.load_lora_weights(
716
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
717
+ )
718
+ pipeline.enable_lora()
719
+ ```
720
+ """
721
+ if not USE_PEFT_BACKEND:
722
+ raise ValueError("PEFT backend is required for this method.")
723
+ set_adapter_layers(self, enabled=True)
724
+
725
+ def delete_adapters(self, adapter_names: Union[List[str], str]):
726
+ """
727
+ Delete an adapter's LoRA layers from the UNet.
728
+
729
+ Args:
730
+ adapter_names (`Union[List[str], str]`):
731
+ The names (single string or list of strings) of the adapter to delete.
732
+
733
+ Example:
734
+
735
+ ```py
736
+ from diffusers import AutoPipelineForText2Image
737
+ import torch
738
+
739
+ pipeline = AutoPipelineForText2Image.from_pretrained(
740
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
741
+ ).to("cuda")
742
+ pipeline.load_lora_weights(
743
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
744
+ )
745
+ pipeline.delete_adapters("cinematic")
746
+ ```
747
+ """
748
+ if not USE_PEFT_BACKEND:
749
+ raise ValueError("PEFT backend is required for this method.")
750
+
751
+ if isinstance(adapter_names, str):
752
+ adapter_names = [adapter_names]
753
+
754
+ for adapter_name in adapter_names:
755
+ delete_adapter_layers(self, adapter_name)
756
+
757
+ # Pop also the corresponding adapter from the config
758
+ if hasattr(self, "peft_config"):
759
+ self.peft_config.pop(adapter_name, None)
760
+
761
+ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
762
+ if low_cpu_mem_usage:
763
+ if is_accelerate_available():
764
+ from accelerate import init_empty_weights
765
+
766
+ else:
767
+ low_cpu_mem_usage = False
768
+ logger.warning(
769
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
770
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
771
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
772
+ " install accelerate\n```\n."
773
+ )
774
+
775
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
776
+ raise NotImplementedError(
777
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
778
+ " `low_cpu_mem_usage=False`."
779
+ )
780
+
781
+ updated_state_dict = {}
782
+ image_projection = None
783
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
784
+
785
+ if "proj.weight" in state_dict:
786
+ # IP-Adapter
787
+ num_image_text_embeds = 4
788
+ clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
789
+ cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
790
+
791
+ with init_context():
792
+ image_projection = ImageProjection(
793
+ cross_attention_dim=cross_attention_dim,
794
+ image_embed_dim=clip_embeddings_dim,
795
+ num_image_text_embeds=num_image_text_embeds,
796
+ )
797
+
798
+ for key, value in state_dict.items():
799
+ diffusers_name = key.replace("proj", "image_embeds")
800
+ updated_state_dict[diffusers_name] = value
801
+
802
+ if not low_cpu_mem_usage:
803
+ image_projection.load_state_dict(updated_state_dict)
804
+ else:
805
+ load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
806
+
807
+ return image_projection
808
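The shape arithmetic above for the plain IP-Adapter projection, written out on dummy tensors (the CLIP and cross-attention widths are assumed for illustration, not taken from this repository):

```py
import torch

clip_embeddings_dim = 1024   # assumed CLIP image-embedding width
cross_attention_dim = 768    # assumed UNet cross-attention width
num_image_text_embeds = 4    # the plain IP-Adapter maps one image to 4 tokens

proj_weight = torch.zeros(num_image_text_embeds * cross_attention_dim, clip_embeddings_dim)
state_dict = {"proj.weight": proj_weight}

# the same two shape reads performed in `_convert_ip_adapter_image_proj_to_diffusers`
assert state_dict["proj.weight"].shape[-1] == clip_embeddings_dim
assert state_dict["proj.weight"].shape[0] // num_image_text_embeds == cross_attention_dim
```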
+
809
+ # def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, multi_frames_condition):
810
+ # updated_state_dict = {}
811
+ # image_projection = None
812
+
813
+ # if "proj.weight" in state_dict:
814
+ # # IP-Adapter
815
+ # # NOTE: adapt for multi-frame
816
+ # num_image_text_embeds = 4
817
+ # clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
818
+ # cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
819
+ # # cross_attention_dim = state_dict["proj.weight"].shape[0]
820
+
821
+ # if not multi_frames_condition:
822
+ # image_projection = ImageProjection(
823
+ # cross_attention_dim=cross_attention_dim,
824
+ # image_embed_dim=clip_embeddings_dim,
825
+ # num_image_text_embeds=num_image_text_embeds,
826
+ # )
827
+ # else:
828
+ # num_image_text_embeds = 50
829
+ # cross_attention_dim = state_dict["proj.weight"].shape[0]
830
+ # image_projection = VideoProjModel(
831
+ # cross_attention_dim=cross_attention_dim,
832
+ # clip_embeddings_dim=clip_embeddings_dim,
833
+ # clip_extra_context_tokens=1,
834
+ # video_frame=num_image_text_embeds,
835
+ # )
836
+
837
+ # for key, value in state_dict.items():
838
+ # if not multi_frames_condition:
839
+ # diffusers_name = key.replace("proj", "image_embeds")
840
+ # else:
841
+ # diffusers_name = key
842
+ # updated_state_dict[diffusers_name] = value
843
+
844
+ # elif "proj.3.weight" in state_dict:
845
+ # # IP-Adapter Full
846
+ # clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
847
+ # cross_attention_dim = state_dict["proj.3.weight"].shape[0]
848
+
849
+ # image_projection = MLPProjection(
850
+ # cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
851
+ # )
852
+
853
+ # for key, value in state_dict.items():
854
+ # diffusers_name = key.replace("proj.0", "ff.net.0.proj")
855
+ # diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
856
+ # diffusers_name = diffusers_name.replace("proj.3", "norm")
857
+ # updated_state_dict[diffusers_name] = value
858
+
859
+ # else:
860
+ # # IP-Adapter Plus
861
+ # num_image_text_embeds = state_dict["latents"].shape[1]
862
+ # embed_dims = state_dict["proj_in.weight"].shape[1]
863
+ # output_dims = state_dict["proj_out.weight"].shape[0]
864
+ # hidden_dims = state_dict["latents"].shape[2]
865
+ # heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
866
+
867
+ # image_projection = Resampler(
868
+ # embed_dims=embed_dims,
869
+ # output_dims=output_dims,
870
+ # hidden_dims=hidden_dims,
871
+ # heads=heads,
872
+ # num_queries=num_image_text_embeds,
873
+ # )
874
+
875
+ # for key, value in state_dict.items():
876
+ # diffusers_name = key.replace("0.to", "2.to")
877
+ # diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
878
+ # diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
879
+ # diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
880
+ # diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")
881
+
882
+ # if "norm1" in diffusers_name:
883
+ # updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
884
+ # elif "norm2" in diffusers_name:
885
+ # updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
886
+ # elif "to_kv" in diffusers_name:
887
+ # v_chunk = value.chunk(2, dim=0)
888
+ # updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
889
+ # updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
890
+ # elif "to_out" in diffusers_name:
891
+ # updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
892
+ # else:
893
+ # updated_state_dict[diffusers_name] = value
894
+
895
+ # image_projection.load_state_dict(updated_state_dict)
896
+ # return image_projection
897
+
898
+ def _convert_ip_adapter_attn_to_diffusers_VPAdapter(self, state_dicts, low_cpu_mem_usage=False):
899
+ from diffusers.models.attention_processor import (
900
+ AttnProcessor,
901
+ IPAdapterAttnProcessor,
902
+ )
903
+
904
+ if low_cpu_mem_usage:
905
+ if is_accelerate_available():
906
+ from accelerate import init_empty_weights
907
+
908
+ else:
909
+ low_cpu_mem_usage = False
910
+ logger.warning(
911
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
912
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
913
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
914
+ " install accelerate\n```\n."
915
+ )
916
+
917
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
918
+ raise NotImplementedError(
919
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
920
+ " `low_cpu_mem_usage=False`."
921
+ )
922
+
923
+ # set ip-adapter cross-attention processors & load state_dict
924
+ attn_procs = {}
925
+ key_id = 1
926
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
927
+ for name in self.attn_processors.keys():
928
+ cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
929
+ if name.startswith("mid_block"):
930
+ hidden_size = self.config.block_out_channels[-1]
931
+ elif name.startswith("up_blocks"):
932
+ block_id = int(name[len("up_blocks.")])
933
+ hidden_size = list(reversed(self.config.block_out_channels))[block_id]
934
+ elif name.startswith("down_blocks"):
935
+ block_id = int(name[len("down_blocks.")])
936
+ hidden_size = self.config.block_out_channels[block_id]
937
+
938
+ if cross_attention_dim is None or "motion_modules" in name or 'fuser' in name:
939
+ attn_processor_class = (
940
+ AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
941
+ )
942
+ attn_procs[name] = attn_processor_class()
943
+ else:
944
+ attn_processor_class = (
945
+ VPTemporalAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
946
+ )
947
+ num_image_text_embeds = []
948
+ for state_dict in state_dicts:
949
+ if "proj.weight" in state_dict["image_proj"]:
950
+ # IP-Adapter
951
+ num_image_text_embeds += [4]
952
+ elif "proj.3.weight" in state_dict["image_proj"]:
953
+ # IP-Adapter Full Face
954
+ num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token
955
+ else:
956
+ # IP-Adapter Plus
957
+ num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
958
+
959
+ with init_context():
960
+ attn_procs[name] = attn_processor_class(
961
+ hidden_size=hidden_size,
962
+ cross_attention_dim=cross_attention_dim,
963
+ scale=1.0,
964
+ num_tokens=num_image_text_embeds,
965
+ )
966
+
967
+ value_dict = {}
968
+ for i, state_dict in enumerate(state_dicts):
969
+ value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
970
+ value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
971
+
972
+ if not low_cpu_mem_usage:
973
+ attn_procs[name].load_state_dict(value_dict)
974
+ else:
975
+ device = next(iter(value_dict.values())).device
976
+ dtype = next(iter(value_dict.values())).dtype
977
+ load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
978
+
979
+ key_id += 2
980
+
981
+ return attn_procs
982
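A self-contained sketch of the `hidden_size` lookup used in both converters, with an SD-1.x-style `block_out_channels` assumed for illustration:

```py
block_out_channels = [320, 640, 1280, 1280]  # assumed, SD 1.x-style UNet config

def hidden_size_for(name: str) -> int:
    if name.startswith("mid_block"):
        return block_out_channels[-1]
    if name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        return list(reversed(block_out_channels))[block_id]
    if name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        return block_out_channels[block_id]
    raise ValueError(f"unexpected attention processor name: {name}")

print(hidden_size_for("down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor"))  # 1280
print(hidden_size_for("up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor"))    # 320
```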
+
983
+ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
984
+ from diffusers.models.attention_processor import (
985
+ AttnProcessor,
986
+ IPAdapterAttnProcessor,
987
+ )
988
+
989
+ if low_cpu_mem_usage:
990
+ if is_accelerate_available():
991
+ from accelerate import init_empty_weights
992
+
993
+ else:
994
+ low_cpu_mem_usage = False
995
+ logger.warning(
996
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
997
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
998
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
999
+ " install accelerate\n```\n."
1000
+ )
1001
+
1002
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
1003
+ raise NotImplementedError(
1004
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
1005
+ " `low_cpu_mem_usage=False`."
1006
+ )
1007
+
1008
+ # set ip-adapter cross-attention processors & load state_dict
1009
+ attn_procs = {}
1010
+ key_id = 1
1011
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
1012
+ for name in self.attn_processors.keys():
1013
+ cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
1014
+ if name.startswith("mid_block"):
1015
+ hidden_size = self.config.block_out_channels[-1]
1016
+ elif name.startswith("up_blocks"):
1017
+ block_id = int(name[len("up_blocks.")])
1018
+ hidden_size = list(reversed(self.config.block_out_channels))[block_id]
1019
+ elif name.startswith("down_blocks"):
1020
+ block_id = int(name[len("down_blocks.")])
1021
+ hidden_size = self.config.block_out_channels[block_id]
1022
+
1023
+ if cross_attention_dim is None or "motion_modules" in name or 'fuser' in name:
1024
+ attn_processor_class = (
1025
+ AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
1026
+ )
1027
+ attn_procs[name] = attn_processor_class()
1028
+ else:
1029
+ attn_processor_class = (
1030
+ IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
1031
+ )
1032
+ num_image_text_embeds = []
1033
+ for state_dict in state_dicts:
1034
+ if "proj.weight" in state_dict["image_proj"]:
1035
+ # IP-Adapter
1036
+ num_image_text_embeds += [4]
1037
+ elif "proj.3.weight" in state_dict["image_proj"]:
1038
+ # IP-Adapter Full Face
1039
+ num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token
1040
+ else:
1041
+ # IP-Adapter Plus
1042
+ num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
1043
+
1044
+ with init_context():
1045
+ attn_procs[name] = attn_processor_class(
1046
+ hidden_size=hidden_size,
1047
+ cross_attention_dim=cross_attention_dim,
1048
+ scale=1.0,
1049
+ num_tokens=num_image_text_embeds,
1050
+ )
1051
+
1052
+ value_dict = {}
1053
+ for i, state_dict in enumerate(state_dicts):
1054
+ value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
1055
+ value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
1056
+
1057
+ if not low_cpu_mem_usage:
1058
+ attn_procs[name].load_state_dict(value_dict)
1059
+ else:
1060
+ device = next(iter(value_dict.values())).device
1061
+ dtype = next(iter(value_dict.values())).dtype
1062
+ load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
1063
+
1064
+ key_id += 2
1065
+
1066
+ return attn_procs
1067
+
1068
+ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
1069
+ attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
1070
+ self.set_attn_processor(attn_procs)
1071
+
1072
+ # convert IP-Adapter Image Projection layers to diffusers
1073
+ image_projection_layers = []
1074
+ for state_dict in state_dicts:
1075
+ image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
1076
+ state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
1077
+ )
1078
+ image_projection_layers.append(image_projection_layer)
1079
+
1080
+ self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
1081
+ self.config.encoder_hid_dim_type = "ip_image_proj"
1082
+
1083
+ self.to(dtype=self.dtype, device=self.device)
1084
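A sketch of the per-adapter dictionary layout `_load_ip_adapter_weights` consumes; the tensor shapes and the `proj.bias` key are assumptions for illustration, and a real checkpoint carries one `to_k_ip`/`to_v_ip` pair per cross-attention layer:

```py
import torch

cross_attention_dim, clip_embeddings_dim = 768, 1024  # assumed widths

state_dicts = [
    {
        "image_proj": {
            "proj.weight": torch.zeros(4 * cross_attention_dim, clip_embeddings_dim),
            "proj.bias": torch.zeros(4 * cross_attention_dim),
        },
        "ip_adapter": {
            "1.to_k_ip.weight": torch.zeros(320, cross_attention_dim),
            "1.to_v_ip.weight": torch.zeros(320, cross_attention_dim),
            # ... continues with "3.", "5.", ... because `key_id` above advances by 2
        },
    }
]
```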
+
1085
+ def _load_ip_adapter_weights_VPAdapter(self, state_dicts, low_cpu_mem_usage=False):
1086
+ attn_procs = self._convert_ip_adapter_attn_to_diffusers_VPAdapter(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
1087
+ self.set_attn_processor(attn_procs)
1088
+
1089
+ # convert IP-Adapter Image Projection layers to diffusers
1090
+ image_projection_layers = []
1091
+ for state_dict in state_dicts:
1092
+ image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
1093
+ state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
1094
+ )
1095
+ image_projection_layers.append(image_projection_layer)
1096
+
1097
+ self.encoder_hid_proj = VPAdapterImageProjection(image_projection_layers)
1098
+ self.config.encoder_hid_dim_type = "ip_image_proj"
1099
+
1100
+ self.to(dtype=self.dtype, device=self.device)
foleycrafter/models/auffusion/resnet.py ADDED
@@ -0,0 +1,685 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusers.utils import USE_PEFT_BACKEND
24
+ from diffusers.models.activations import get_activation
25
+ from diffusers.models.downsampling import ( # noqa
26
+ Downsample1D,
27
+ Downsample2D,
28
+ FirDownsample2D,
29
+ KDownsample2D,
30
+ downsample_2d,
31
+ )
32
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
33
+ from diffusers.models.normalization import AdaGroupNorm
34
+ from diffusers.models.upsampling import ( # noqa
35
+ FirUpsample2D,
36
+ KUpsample2D,
37
+ Upsample1D,
38
+ Upsample2D,
39
+ upfirdn2d_native,
40
+ upsample_2d,
41
+ )
42
+ from foleycrafter.models.auffusion.attention_processor import SpatialNorm
43
+
44
+
45
+ class ResnetBlock2D(nn.Module):
46
+ r"""
47
+ A Resnet block.
48
+
49
+ Parameters:
50
+ in_channels (`int`): The number of channels in the input.
51
+ out_channels (`int`, *optional*, default to be `None`):
52
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
53
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
54
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
55
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
56
+ groups_out (`int`, *optional*, default to None):
57
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
58
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
59
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
60
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
61
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
62
+ "ada_group" for a stronger conditioning with scale and shift.
63
+ kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
64
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
65
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
66
+ use_in_shortcut (`bool`, *optional*, default to `True`):
67
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
68
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
69
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
70
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
71
+ `conv_shortcut` output.
72
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
73
+ If None, same as `out_channels`.
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ *,
79
+ in_channels: int,
80
+ out_channels: Optional[int] = None,
81
+ conv_shortcut: bool = False,
82
+ dropout: float = 0.0,
83
+ temb_channels: int = 512,
84
+ groups: int = 32,
85
+ groups_out: Optional[int] = None,
86
+ pre_norm: bool = True,
87
+ eps: float = 1e-6,
88
+ non_linearity: str = "swish",
89
+ skip_time_act: bool = False,
90
+ time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
91
+ kernel: Optional[torch.FloatTensor] = None,
92
+ output_scale_factor: float = 1.0,
93
+ use_in_shortcut: Optional[bool] = None,
94
+ up: bool = False,
95
+ down: bool = False,
96
+ conv_shortcut_bias: bool = True,
97
+ conv_2d_out_channels: Optional[int] = None,
98
+ ):
99
+ super().__init__()
100
+ self.pre_norm = pre_norm
101
+ self.pre_norm = True
102
+ self.in_channels = in_channels
103
+ out_channels = in_channels if out_channels is None else out_channels
104
+ self.out_channels = out_channels
105
+ self.use_conv_shortcut = conv_shortcut
106
+ self.up = up
107
+ self.down = down
108
+ self.output_scale_factor = output_scale_factor
109
+ self.time_embedding_norm = time_embedding_norm
110
+ self.skip_time_act = skip_time_act
111
+
112
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
113
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
114
+
115
+ if groups_out is None:
116
+ groups_out = groups
117
+
118
+ if self.time_embedding_norm == "ada_group":
119
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
120
+ elif self.time_embedding_norm == "spatial":
121
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
122
+ else:
123
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
124
+
125
+ self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
126
+
127
+ if temb_channels is not None:
128
+ if self.time_embedding_norm == "default":
129
+ self.time_emb_proj = linear_cls(temb_channels, out_channels)
130
+ elif self.time_embedding_norm == "scale_shift":
131
+ self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
132
+ elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
133
+ self.time_emb_proj = None
134
+ else:
135
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
136
+ else:
137
+ self.time_emb_proj = None
138
+
139
+ if self.time_embedding_norm == "ada_group":
140
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
141
+ elif self.time_embedding_norm == "spatial":
142
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
143
+ else:
144
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
145
+
146
+ self.dropout = torch.nn.Dropout(dropout)
147
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
148
+ self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
149
+
150
+ self.nonlinearity = get_activation(non_linearity)
151
+
152
+ self.upsample = self.downsample = None
153
+ if self.up:
154
+ if kernel == "fir":
155
+ fir_kernel = (1, 3, 3, 1)
156
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
157
+ elif kernel == "sde_vp":
158
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
159
+ else:
160
+ self.upsample = Upsample2D(in_channels, use_conv=False)
161
+ elif self.down:
162
+ if kernel == "fir":
163
+ fir_kernel = (1, 3, 3, 1)
164
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
165
+ elif kernel == "sde_vp":
166
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
167
+ else:
168
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
169
+
170
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
171
+
172
+ self.conv_shortcut = None
173
+ if self.use_in_shortcut:
174
+ self.conv_shortcut = conv_cls(
175
+ in_channels,
176
+ conv_2d_out_channels,
177
+ kernel_size=1,
178
+ stride=1,
179
+ padding=0,
180
+ bias=conv_shortcut_bias,
181
+ )
182
+
183
+ def forward(
184
+ self,
185
+ input_tensor: torch.FloatTensor,
186
+ temb: torch.FloatTensor,
187
+ scale: float = 1.0,
188
+ ) -> torch.FloatTensor:
189
+ hidden_states = input_tensor
190
+
191
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
192
+ hidden_states = self.norm1(hidden_states, temb)
193
+ else:
194
+ hidden_states = self.norm1(hidden_states)
195
+
196
+ hidden_states = self.nonlinearity(hidden_states)
197
+
198
+ if self.upsample is not None:
199
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
200
+ if hidden_states.shape[0] >= 64:
201
+ input_tensor = input_tensor.contiguous()
202
+ hidden_states = hidden_states.contiguous()
203
+ input_tensor = (
204
+ self.upsample(input_tensor, scale=scale)
205
+ if isinstance(self.upsample, Upsample2D)
206
+ else self.upsample(input_tensor)
207
+ )
208
+ hidden_states = (
209
+ self.upsample(hidden_states, scale=scale)
210
+ if isinstance(self.upsample, Upsample2D)
211
+ else self.upsample(hidden_states)
212
+ )
213
+ elif self.downsample is not None:
214
+ input_tensor = (
215
+ self.downsample(input_tensor, scale=scale)
216
+ if isinstance(self.downsample, Downsample2D)
217
+ else self.downsample(input_tensor)
218
+ )
219
+ hidden_states = (
220
+ self.downsample(hidden_states, scale=scale)
221
+ if isinstance(self.downsample, Downsample2D)
222
+ else self.downsample(hidden_states)
223
+ )
224
+
225
+ hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
226
+
227
+ if self.time_emb_proj is not None:
228
+ if not self.skip_time_act:
229
+ temb = self.nonlinearity(temb)
230
+ temb = (
231
+ self.time_emb_proj(temb, scale)[:, :, None, None]
232
+ if not USE_PEFT_BACKEND
233
+ # NOTE: maybe we could use a different prompt at different timesteps
234
+ else self.time_emb_proj(temb)[:, :, None, None]
235
+ )
236
+
237
+ if temb is not None and self.time_embedding_norm == "default":
238
+ hidden_states = hidden_states + temb
239
+
240
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
241
+ hidden_states = self.norm2(hidden_states, temb)
242
+ else:
243
+ hidden_states = self.norm2(hidden_states)
244
+
245
+ if temb is not None and self.time_embedding_norm == "scale_shift":
246
+ scale, shift = torch.chunk(temb, 2, dim=1)
247
+ hidden_states = hidden_states * (1 + scale) + shift
248
+
249
+ hidden_states = self.nonlinearity(hidden_states)
250
+
251
+ hidden_states = self.dropout(hidden_states)
252
+ hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
253
+
254
+ if self.conv_shortcut is not None:
255
+ input_tensor = (
256
+ self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
257
+ )
258
+
259
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
260
+
261
+ return output_tensor
262
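A tensor-level sketch of the `scale_shift` branch above: the time embedding is projected to `2 * out_channels`, chunked, and applied after `norm2` (shapes are illustrative):

```py
import torch

batch, channels, height, width = 2, 8, 16, 16
hidden_states = torch.randn(batch, channels, height, width)
temb = torch.randn(batch, 2 * channels)[:, :, None, None]  # after time_emb_proj, with broadcast dims added

scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
print(hidden_states.shape)  # torch.Size([2, 8, 16, 16])
```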
+
263
+
264
+ # unet_rl.py
265
+ def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
266
+ if len(tensor.shape) == 2:
267
+ return tensor[:, :, None]
268
+ if len(tensor.shape) == 3:
269
+ return tensor[:, :, None, :]
270
+ elif len(tensor.shape) == 4:
271
+ return tensor[:, :, 0, :]
272
+ else:
273
+ raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
274
+
275
+
276
+ class Conv1dBlock(nn.Module):
277
+ """
278
+ Conv1d --> GroupNorm --> Mish
279
+
280
+ Parameters:
281
+ inp_channels (`int`): Number of input channels.
282
+ out_channels (`int`): Number of output channels.
283
+ kernel_size (`int` or `tuple`): Size of the convolving kernel.
284
+ n_groups (`int`, default `8`): Number of groups to separate the channels into.
285
+ activation (`str`, defaults to `mish`): Name of the activation function.
286
+ """
287
+
288
+ def __init__(
289
+ self,
290
+ inp_channels: int,
291
+ out_channels: int,
292
+ kernel_size: Union[int, Tuple[int, int]],
293
+ n_groups: int = 8,
294
+ activation: str = "mish",
295
+ ):
296
+ super().__init__()
297
+
298
+ self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
299
+ self.group_norm = nn.GroupNorm(n_groups, out_channels)
300
+ self.mish = get_activation(activation)
301
+
302
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
303
+ intermediate_repr = self.conv1d(inputs)
304
+ intermediate_repr = rearrange_dims(intermediate_repr)
305
+ intermediate_repr = self.group_norm(intermediate_repr)
306
+ intermediate_repr = rearrange_dims(intermediate_repr)
307
+ output = self.mish(intermediate_repr)
308
+ return output
309
+
310
+
311
+ # unet_rl.py
312
+ class ResidualTemporalBlock1D(nn.Module):
313
+ """
314
+ Residual 1D block with temporal convolutions.
315
+
316
+ Parameters:
317
+ inp_channels (`int`): Number of input channels.
318
+ out_channels (`int`): Number of output channels.
319
+ embed_dim (`int`): Embedding dimension.
320
+ kernel_size (`int` or `tuple`): Size of the convolving kernel.
321
+ activation (`str`, defaults to `mish`): Name of the activation function to use.
322
+ """
323
+
324
+ def __init__(
325
+ self,
326
+ inp_channels: int,
327
+ out_channels: int,
328
+ embed_dim: int,
329
+ kernel_size: Union[int, Tuple[int, int]] = 5,
330
+ activation: str = "mish",
331
+ ):
332
+ super().__init__()
333
+ self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
334
+ self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
335
+
336
+ self.time_emb_act = get_activation(activation)
337
+ self.time_emb = nn.Linear(embed_dim, out_channels)
338
+
339
+ self.residual_conv = (
340
+ nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
341
+ )
342
+
343
+ def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
344
+ """
345
+ Args:
346
+ inputs : [ batch_size x inp_channels x horizon ]
347
+ t : [ batch_size x embed_dim ]
348
+
349
+ returns:
350
+ out : [ batch_size x out_channels x horizon ]
351
+ """
352
+ t = self.time_emb_act(t)
353
+ t = self.time_emb(t)
354
+ out = self.conv_in(inputs) + rearrange_dims(t)
355
+ out = self.conv_out(out)
356
+ return out + self.residual_conv(inputs)
357
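How `rearrange_dims` lets the projected time embedding broadcast over the horizon in the residual block above (shapes are illustrative):

```py
import torch

batch, out_channels, horizon = 2, 8, 16
conv_out = torch.randn(batch, out_channels, horizon)  # output of conv_in
t = torch.randn(batch, out_channels)                  # output of time_emb

out = conv_out + t[:, :, None]  # rearrange_dims on a 2D tensor appends a horizon axis
print(out.shape)  # torch.Size([2, 8, 16])
```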
+
358
+
359
+ class TemporalConvLayer(nn.Module):
360
+ """
361
+ Temporal convolutional layer that can be used for video (a sequence of images) input. Code mostly copied from:
362
+ https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
363
+
364
+ Parameters:
365
+ in_dim (`int`): Number of input channels.
366
+ out_dim (`int`): Number of output channels.
367
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
368
+ """
369
+
370
+ def __init__(
371
+ self,
372
+ in_dim: int,
373
+ out_dim: Optional[int] = None,
374
+ dropout: float = 0.0,
375
+ norm_num_groups: int = 32,
376
+ ):
377
+ super().__init__()
378
+ out_dim = out_dim or in_dim
379
+ self.in_dim = in_dim
380
+ self.out_dim = out_dim
381
+
382
+ # conv layers
383
+ self.conv1 = nn.Sequential(
384
+ nn.GroupNorm(norm_num_groups, in_dim),
385
+ nn.SiLU(),
386
+ nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
387
+ )
388
+ self.conv2 = nn.Sequential(
389
+ nn.GroupNorm(norm_num_groups, out_dim),
390
+ nn.SiLU(),
391
+ nn.Dropout(dropout),
392
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
393
+ )
394
+ self.conv3 = nn.Sequential(
395
+ nn.GroupNorm(norm_num_groups, out_dim),
396
+ nn.SiLU(),
397
+ nn.Dropout(dropout),
398
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
399
+ )
400
+ self.conv4 = nn.Sequential(
401
+ nn.GroupNorm(norm_num_groups, out_dim),
402
+ nn.SiLU(),
403
+ nn.Dropout(dropout),
404
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
405
+ )
406
+
407
+ # zero out the last layer params, so the conv block is the identity at initialization
408
+ nn.init.zeros_(self.conv4[-1].weight)
409
+ nn.init.zeros_(self.conv4[-1].bias)
410
+
411
+ def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
412
+ hidden_states = (
413
+ hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
414
+ )
415
+
416
+ identity = hidden_states
417
+ hidden_states = self.conv1(hidden_states)
418
+ hidden_states = self.conv2(hidden_states)
419
+ hidden_states = self.conv3(hidden_states)
420
+ hidden_states = self.conv4(hidden_states)
421
+
422
+ hidden_states = identity + hidden_states
423
+
424
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
425
+ (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
426
+ )
427
+ return hidden_states
428
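The reshaping at the top and bottom of `TemporalConvLayer.forward`, isolated as a round-trip check on a dummy tensor (shapes are illustrative):

```py
import torch

batch, num_frames, channels, height, width = 2, 4, 8, 16, 16
x = torch.randn(batch * num_frames, channels, height, width)

# (batch*frames, C, H, W) -> (batch, C, frames, H, W) for the Conv3d stack
folded = x[None, :].reshape((-1, num_frames) + x.shape[1:]).permute(0, 2, 1, 3, 4)
print(folded.shape)  # torch.Size([2, 8, 4, 16, 16])

# ... and back to the stacked frame layout afterwards
unfolded = folded.permute(0, 2, 1, 3, 4).reshape((folded.shape[0] * folded.shape[2], -1) + folded.shape[3:])
torch.testing.assert_close(unfolded, x)
```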
+
429
+
430
+ class TemporalResnetBlock(nn.Module):
431
+ r"""
432
+ A Resnet block.
433
+
434
+ Parameters:
435
+ in_channels (`int`): The number of channels in the input.
436
+ out_channels (`int`, *optional*, default to be `None`):
437
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
438
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
439
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
440
+ """
441
+
442
+ def __init__(
443
+ self,
444
+ in_channels: int,
445
+ out_channels: Optional[int] = None,
446
+ temb_channels: int = 512,
447
+ eps: float = 1e-6,
448
+ ):
449
+ super().__init__()
450
+ self.in_channels = in_channels
451
+ out_channels = in_channels if out_channels is None else out_channels
452
+ self.out_channels = out_channels
453
+
454
+ kernel_size = (3, 1, 1)
455
+ padding = [k // 2 for k in kernel_size]
456
+
457
+ self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True)
458
+ self.conv1 = nn.Conv3d(
459
+ in_channels,
460
+ out_channels,
461
+ kernel_size=kernel_size,
462
+ stride=1,
463
+ padding=padding,
464
+ )
465
+
466
+ if temb_channels is not None:
467
+ self.time_emb_proj = nn.Linear(temb_channels, out_channels)
468
+ else:
469
+ self.time_emb_proj = None
470
+
471
+ self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True)
472
+
473
+ self.dropout = torch.nn.Dropout(0.0)
474
+ self.conv2 = nn.Conv3d(
475
+ out_channels,
476
+ out_channels,
477
+ kernel_size=kernel_size,
478
+ stride=1,
479
+ padding=padding,
480
+ )
481
+
482
+ self.nonlinearity = get_activation("silu")
483
+
484
+ self.use_in_shortcut = self.in_channels != out_channels
485
+
486
+ self.conv_shortcut = None
487
+ if self.use_in_shortcut:
488
+ self.conv_shortcut = nn.Conv3d(
489
+ in_channels,
490
+ out_channels,
491
+ kernel_size=1,
492
+ stride=1,
493
+ padding=0,
494
+ )
495
+
496
+ def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
497
+ hidden_states = input_tensor
498
+
499
+ hidden_states = self.norm1(hidden_states)
500
+ hidden_states = self.nonlinearity(hidden_states)
501
+ hidden_states = self.conv1(hidden_states)
502
+
503
+ if self.time_emb_proj is not None:
504
+ temb = self.nonlinearity(temb)
505
+ temb = self.time_emb_proj(temb)[:, :, :, None, None]
506
+ temb = temb.permute(0, 2, 1, 3, 4)
507
+ hidden_states = hidden_states + temb
508
+
509
+ hidden_states = self.norm2(hidden_states)
510
+ hidden_states = self.nonlinearity(hidden_states)
511
+ hidden_states = self.dropout(hidden_states)
512
+ hidden_states = self.conv2(hidden_states)
513
+
514
+ if self.conv_shortcut is not None:
515
+ input_tensor = self.conv_shortcut(input_tensor)
516
+
517
+ output_tensor = input_tensor + hidden_states
518
+
519
+ return output_tensor
520
+
521
+
522
+ # VideoResBlock
523
+ class SpatioTemporalResBlock(nn.Module):
524
+ r"""
525
+ A SpatioTemporal Resnet block.
526
+
527
+ Parameters:
528
+ in_channels (`int`): The number of channels in the input.
529
+ out_channels (`int`, *optional*, default to be `None`):
530
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
531
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
532
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resnet.
533
+ temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet.
534
+ merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing.
535
+ merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
536
+ The merge strategy to use for the temporal mixing.
537
+ switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
538
+ If `True`, switch the spatial and temporal mixing.
539
+ """
540
+
541
+ def __init__(
542
+ self,
543
+ in_channels: int,
544
+ out_channels: Optional[int] = None,
545
+ temb_channels: int = 512,
546
+ eps: float = 1e-6,
547
+ temporal_eps: Optional[float] = None,
548
+ merge_factor: float = 0.5,
549
+ merge_strategy="learned_with_images",
550
+ switch_spatial_to_temporal_mix: bool = False,
551
+ ):
552
+ super().__init__()
553
+
554
+ self.spatial_res_block = ResnetBlock2D(
555
+ in_channels=in_channels,
556
+ out_channels=out_channels,
557
+ temb_channels=temb_channels,
558
+ eps=eps,
559
+ )
560
+
561
+ self.temporal_res_block = TemporalResnetBlock(
562
+ in_channels=out_channels if out_channels is not None else in_channels,
563
+ out_channels=out_channels if out_channels is not None else in_channels,
564
+ temb_channels=temb_channels,
565
+ eps=temporal_eps if temporal_eps is not None else eps,
566
+ )
567
+
568
+ self.time_mixer = AlphaBlender(
569
+ alpha=merge_factor,
570
+ merge_strategy=merge_strategy,
571
+ switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
572
+ )
573
+
574
+ def forward(
575
+ self,
576
+ hidden_states: torch.FloatTensor,
577
+ temb: Optional[torch.FloatTensor] = None,
578
+ image_only_indicator: Optional[torch.Tensor] = None,
579
+ ):
580
+ num_frames = image_only_indicator.shape[-1]
581
+ hidden_states = self.spatial_res_block(hidden_states, temb)
582
+
583
+ batch_frames, channels, height, width = hidden_states.shape
584
+ batch_size = batch_frames // num_frames
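+ # the spatial block sees frames folded into the batch; the next two reshapes unfold them into (batch, channels, frames, height, width) for the temporal block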
585
+
586
+ hidden_states_mix = (
587
+ hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
588
+ )
589
+ hidden_states = (
590
+ hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
591
+ )
592
+
593
+ if temb is not None:
594
+ temb = temb.reshape(batch_size, num_frames, -1)
595
+
596
+ hidden_states = self.temporal_res_block(hidden_states, temb)
597
+ hidden_states = self.time_mixer(
598
+ x_spatial=hidden_states_mix,
599
+ x_temporal=hidden_states,
600
+ image_only_indicator=image_only_indicator,
601
+ )
602
+
603
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
604
+ return hidden_states
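A minimal usage sketch of the block above, for orientation only; the import path, channel count and frame count below are illustrative assumptions, not values taken from this commit.

import torch
from foleycrafter.models.auffusion.resnet import SpatioTemporalResBlock  # assumed import path for this file

batch_size, num_frames, channels = 2, 16, 320
block = SpatioTemporalResBlock(in_channels=channels, out_channels=channels, temb_channels=1280)

hidden_states = torch.randn(batch_size * num_frames, channels, 32, 32)  # frames folded into the batch
temb = torch.randn(batch_size * num_frames, 1280)                       # one time embedding per frame
image_only_indicator = torch.zeros(batch_size, num_frames)              # 0 = video frame, 1 = still image

out = block(hidden_states, temb, image_only_indicator)                  # (batch*frames, channels, 32, 32)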
605
+
606
+
607
+ class AlphaBlender(nn.Module):
608
+ r"""
609
+ A module to blend spatial and temporal features.
610
+
611
+ Parameters:
612
+ alpha (`float`): The initial value of the blending factor.
613
+ merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
614
+ The merge strategy to use for the temporal mixing.
615
+ switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
616
+ If `True`, switch the spatial and temporal mixing.
617
+ """
618
+
619
+ strategies = ["learned", "fixed", "learned_with_images"]
620
+
621
+ def __init__(
622
+ self,
623
+ alpha: float,
624
+ merge_strategy: str = "learned_with_images",
625
+ switch_spatial_to_temporal_mix: bool = False,
626
+ ):
627
+ super().__init__()
628
+ self.merge_strategy = merge_strategy
629
+ self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix # For TemporalVAE
630
+
631
+ if merge_strategy not in self.strategies:
632
+ raise ValueError(f"merge_strategy needs to be in {self.strategies}")
633
+
634
+ if self.merge_strategy == "fixed":
635
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
636
+ elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
637
+ self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
638
+ else:
639
+ raise ValueError(f"Unknown merge strategy {self.merge_strategy}")
640
+
641
+ def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
642
+ if self.merge_strategy == "fixed":
643
+ alpha = self.mix_factor
644
+
645
+ elif self.merge_strategy == "learned":
646
+ alpha = torch.sigmoid(self.mix_factor)
647
+
648
+ elif self.merge_strategy == "learned_with_images":
649
+ if image_only_indicator is None:
650
+ raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy")
651
+
652
+ alpha = torch.where(
653
+ image_only_indicator.bool(),
654
+ torch.ones(1, 1, device=image_only_indicator.device),
655
+ torch.sigmoid(self.mix_factor)[..., None],
656
+ )
657
+
658
+ # (batch, channel, frames, height, width)
659
+ if ndims == 5:
660
+ alpha = alpha[:, None, :, None, None]
661
+ # (batch*frames, height*width, channels)
662
+ elif ndims == 3:
663
+ alpha = alpha.reshape(-1)[:, None, None]
664
+ else:
665
+ raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")
666
+
667
+ else:
668
+ raise NotImplementedError
669
+
670
+ return alpha
671
+
672
+ def forward(
673
+ self,
674
+ x_spatial: torch.Tensor,
675
+ x_temporal: torch.Tensor,
676
+ image_only_indicator: Optional[torch.Tensor] = None,
677
+ ) -> torch.Tensor:
678
+ alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
679
+ alpha = alpha.to(x_spatial.dtype)
680
+
681
+ if self.switch_spatial_to_temporal_mix:
682
+ alpha = 1.0 - alpha
683
+
684
+ x = alpha * x_spatial + (1.0 - alpha) * x_temporal
685
+ return x
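For illustration, a hypothetical call showing the blending rule alpha * x_spatial + (1 - alpha) * x_temporal on 5D features; the import path and shapes here are assumptions, not part of the commit.

import torch
from foleycrafter.models.auffusion.resnet import AlphaBlender  # assumed import path

blender = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
x_spatial = torch.randn(2, 320, 16, 32, 32)   # (batch, channels, frames, height, width)
x_temporal = torch.randn(2, 320, 16, 32, 32)
image_only_indicator = torch.zeros(2, 16)     # all zeros: every frame uses the learned sigmoid(mix_factor)
blended = blender(x_spatial, x_temporal, image_only_indicator)  # same shape as the inputs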
foleycrafter/models/auffusion/transformer_2d.py ADDED
@@ -0,0 +1,460 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
24
+ from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection
25
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+ from diffusers.models.normalization import AdaLayerNormSingle
28
+
29
+ from foleycrafter.models.auffusion.attention import BasicTransformerBlock
30
+
31
+ @dataclass
32
+ class Transformer2DModelOutput(BaseOutput):
33
+ """
34
+ The output of [`Transformer2DModel`].
35
+
36
+ Args:
37
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
38
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
39
+ distributions for the unnoised latent pixels.
40
+ """
41
+
42
+ sample: torch.FloatTensor
43
+
44
+ class Transformer2DModel(ModelMixin, ConfigMixin):
45
+ """
46
+ A 2D Transformer model for image-like data.
47
+
48
+ Parameters:
49
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
50
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
51
+ in_channels (`int`, *optional*):
52
+ The number of channels in the input and output (specify if the input is **continuous**).
53
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
54
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
55
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
56
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
57
+ This is fixed during training since it is used to learn a number of position embeddings.
58
+ num_vector_embeds (`int`, *optional*):
59
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
60
+ Includes the class for the masked latent pixel.
61
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
62
+ num_embeds_ada_norm ( `int`, *optional*):
63
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
64
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
65
+ added to the hidden states.
66
+
67
+ During inference, you can denoise for up to but not more than `num_embeds_ada_norm` steps.
68
+ attention_bias (`bool`, *optional*):
69
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
70
+ """
71
+
72
+ _supports_gradient_checkpointing = True
73
+
74
+ @register_to_config
75
+ def __init__(
76
+ self,
77
+ num_attention_heads: int = 16,
78
+ attention_head_dim: int = 88,
79
+ in_channels: Optional[int] = None,
80
+ out_channels: Optional[int] = None,
81
+ num_layers: int = 1,
82
+ dropout: float = 0.0,
83
+ norm_num_groups: int = 32,
84
+ cross_attention_dim: Optional[int] = None,
85
+ attention_bias: bool = False,
86
+ sample_size: Optional[int] = None,
87
+ num_vector_embeds: Optional[int] = None,
88
+ patch_size: Optional[int] = None,
89
+ activation_fn: str = "geglu",
90
+ num_embeds_ada_norm: Optional[int] = None,
91
+ use_linear_projection: bool = False,
92
+ only_cross_attention: bool = False,
93
+ double_self_attention: bool = False,
94
+ upcast_attention: bool = False,
95
+ norm_type: str = "layer_norm",
96
+ norm_elementwise_affine: bool = True,
97
+ norm_eps: float = 1e-5,
98
+ attention_type: str = "default",
99
+ caption_channels: Optional[int] = None,
100
+ ):
101
+ super().__init__()
102
+ self.use_linear_projection = use_linear_projection
103
+ self.num_attention_heads = num_attention_heads
104
+ self.attention_head_dim = attention_head_dim
105
+ inner_dim = num_attention_heads * attention_head_dim
106
+
107
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
108
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
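+ # when the PEFT backend is unavailable, fall back to LoRA-compatible layers so a LoRA `scale` can still be passed through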
109
+
110
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
111
+ # Define whether input is continuous or discrete depending on configuration
112
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
113
+ self.is_input_vectorized = num_vector_embeds is not None
114
+ self.is_input_patches = in_channels is not None and patch_size is not None
115
+
116
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
117
+ deprecation_message = (
118
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
119
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
120
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
121
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
122
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
123
+ )
124
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
125
+ norm_type = "ada_norm"
126
+
127
+ if self.is_input_continuous and self.is_input_vectorized:
128
+ raise ValueError(
129
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
130
+ " sure that either `in_channels` or `num_vector_embeds` is None."
131
+ )
132
+ elif self.is_input_vectorized and self.is_input_patches:
133
+ raise ValueError(
134
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
135
+ " sure that either `num_vector_embeds` or `num_patches` is None."
136
+ )
137
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
138
+ raise ValueError(
139
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
140
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
141
+ )
142
+
143
+ # 2. Define input layers
144
+ if self.is_input_continuous:
145
+ self.in_channels = in_channels
146
+
147
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
148
+ if use_linear_projection:
149
+ self.proj_in = linear_cls(in_channels, inner_dim)
150
+ else:
151
+ self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
152
+ elif self.is_input_vectorized:
153
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
154
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
155
+
156
+ self.height = sample_size
157
+ self.width = sample_size
158
+ self.num_vector_embeds = num_vector_embeds
159
+ self.num_latent_pixels = self.height * self.width
160
+
161
+ self.latent_image_embedding = ImagePositionalEmbeddings(
162
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
163
+ )
164
+ elif self.is_input_patches:
165
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
166
+
167
+ self.height = sample_size
168
+ self.width = sample_size
169
+
170
+ self.patch_size = patch_size
171
+ interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
172
+ interpolation_scale = max(interpolation_scale, 1)
173
+ self.pos_embed = PatchEmbed(
174
+ height=sample_size,
175
+ width=sample_size,
176
+ patch_size=patch_size,
177
+ in_channels=in_channels,
178
+ embed_dim=inner_dim,
179
+ interpolation_scale=interpolation_scale,
180
+ )
181
+
182
+ # 3. Define transformers blocks
183
+ self.transformer_blocks = nn.ModuleList(
184
+ [
185
+ # NOTE: uses the BasicTransformerBlock from foleycrafter.models.auffusion.attention rather than the diffusers one
186
+ BasicTransformerBlock(
187
+ inner_dim,
188
+ num_attention_heads,
189
+ attention_head_dim,
190
+ dropout=dropout,
191
+ cross_attention_dim=cross_attention_dim,
192
+ activation_fn=activation_fn,
193
+ num_embeds_ada_norm=num_embeds_ada_norm,
194
+ attention_bias=attention_bias,
195
+ only_cross_attention=only_cross_attention,
196
+ double_self_attention=double_self_attention,
197
+ upcast_attention=upcast_attention,
198
+ norm_type=norm_type,
199
+ norm_elementwise_affine=norm_elementwise_affine,
200
+ norm_eps=norm_eps,
201
+ attention_type=attention_type,
202
+ )
203
+ for d in range(num_layers)
204
+ ]
205
+ )
206
+
207
+ # 4. Define output layers
208
+ self.out_channels = in_channels if out_channels is None else out_channels
209
+ if self.is_input_continuous:
210
+ # TODO: should use out_channels for continuous projections
211
+ if use_linear_projection:
212
+ self.proj_out = linear_cls(inner_dim, in_channels)
213
+ else:
214
+ self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
215
+ elif self.is_input_vectorized:
216
+ self.norm_out = nn.LayerNorm(inner_dim)
217
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
218
+ elif self.is_input_patches and norm_type != "ada_norm_single":
219
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
220
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
221
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
222
+ elif self.is_input_patches and norm_type == "ada_norm_single":
223
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
224
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
225
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
226
+
227
+ # 5. PixArt-Alpha blocks.
228
+ self.adaln_single = None
229
+ self.use_additional_conditions = False
230
+ if norm_type == "ada_norm_single":
231
+ self.use_additional_conditions = self.config.sample_size == 128
232
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
233
+ # additional conditions until we find better name
234
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
235
+
236
+ self.caption_projection = None
237
+ if caption_channels is not None:
238
+ self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
239
+
240
+ self.gradient_checkpointing = False
241
+
242
+ def _set_gradient_checkpointing(self, module, value=False):
243
+ if hasattr(module, "gradient_checkpointing"):
244
+ module.gradient_checkpointing = value
245
+
246
+ def forward(
247
+ self,
248
+ hidden_states: torch.Tensor,
249
+ encoder_hidden_states: Optional[torch.Tensor] = None,
250
+ timestep: Optional[torch.LongTensor] = None,
251
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
252
+ class_labels: Optional[torch.LongTensor] = None,
253
+ cross_attention_kwargs: Dict[str, Any] = None,
254
+ attention_mask: Optional[torch.Tensor] = None,
255
+ encoder_attention_mask: Optional[torch.Tensor] = None,
256
+ return_dict: bool = True,
257
+ ):
258
+ """
259
+ The [`Transformer2DModel`] forward method.
260
+
261
+ Args:
262
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
263
+ Input `hidden_states`.
264
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
265
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
266
+ self-attention.
267
+ timestep ( `torch.LongTensor`, *optional*):
268
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
269
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
270
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
271
+ `AdaLayerZeroNorm`.
272
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
273
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
274
+ `self.processor` in
275
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
276
+ attention_mask ( `torch.Tensor`, *optional*):
277
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
278
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
279
+ negative values to the attention scores corresponding to "discard" tokens.
280
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
281
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
282
+
283
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
284
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
285
+
286
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
287
+ above. This bias will be added to the cross-attention scores.
288
+ return_dict (`bool`, *optional*, defaults to `True`):
289
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
290
+ tuple.
291
+
292
+ Returns:
293
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
294
+ `tuple` where the first element is the sample tensor.
295
+ """
296
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
297
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
298
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
299
+ # expects mask of shape:
300
+ # [batch, key_tokens]
301
+ # adds singleton query_tokens dimension:
302
+ # [batch, 1, key_tokens]
303
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
304
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
305
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
306
+ if attention_mask is not None and attention_mask.ndim == 2:
307
+ # assume that mask is expressed as:
308
+ # (1 = keep, 0 = discard)
309
+ # convert mask into a bias that can be added to attention scores:
310
+ # (keep = +0, discard = -10000.0)
311
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
312
+ attention_mask = attention_mask.unsqueeze(1)
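+ # e.g. a keep/discard mask [1, 1, 0] becomes the additive bias [0.0, 0.0, -10000.0] with shape (batch, 1, key_tokens)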
313
+
314
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
315
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
316
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
317
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
318
+
319
+ # Retrieve lora scale.
320
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
321
+
322
+ # 1. Input
323
+ if self.is_input_continuous:
324
+ batch, _, height, width = hidden_states.shape
325
+ inner_dim = hidden_states.shape[1]
326
+ residual = hidden_states
327
+
328
+ hidden_states = self.norm(hidden_states)
329
+ if not self.use_linear_projection:
330
+ hidden_states = (
331
+ self.proj_in(hidden_states, scale=lora_scale)
332
+ if not USE_PEFT_BACKEND
333
+ else self.proj_in(hidden_states)
334
+ )
335
+ inner_dim = hidden_states.shape[1]
336
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
337
+ else:
338
+ inner_dim = hidden_states.shape[1]
339
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
340
+ hidden_states = (
341
+ self.proj_in(hidden_states, scale=lora_scale)
342
+ if not USE_PEFT_BACKEND
343
+ else self.proj_in(hidden_states)
344
+ )
345
+
346
+ elif self.is_input_vectorized:
347
+ hidden_states = self.latent_image_embedding(hidden_states)
348
+ elif self.is_input_patches:
349
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
350
+ self.height, self.width = height, width
351
+ hidden_states = self.pos_embed(hidden_states)
352
+
353
+ if self.adaln_single is not None:
354
+ if self.use_additional_conditions and added_cond_kwargs is None:
355
+ raise ValueError(
356
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
357
+ )
358
+ batch_size = hidden_states.shape[0]
359
+ timestep, embedded_timestep = self.adaln_single(
360
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
361
+ )
362
+
363
+ if self.caption_projection is not None:
364
+ batch_size = hidden_states.shape[0]
365
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
366
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
367
+ # 2. Blocks
368
+ for block in self.transformer_blocks:
369
+ if self.training and self.gradient_checkpointing:
370
+
371
+ def create_custom_forward(module, return_dict=None):
372
+ def custom_forward(*inputs):
373
+ if return_dict is not None:
374
+ return module(*inputs, return_dict=return_dict)
375
+ else:
376
+ return module(*inputs)
377
+
378
+ return custom_forward
379
+
380
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
381
+ hidden_states = torch.utils.checkpoint.checkpoint(
382
+ create_custom_forward(block),
383
+ hidden_states,
384
+ attention_mask,
385
+ encoder_hidden_states,
386
+ encoder_attention_mask,
387
+ timestep,
388
+ cross_attention_kwargs,
389
+ class_labels,
390
+ **ckpt_kwargs,
391
+ )
392
+ else:
393
+ hidden_states = block(
394
+ hidden_states,
395
+ attention_mask=attention_mask,
396
+ encoder_hidden_states=encoder_hidden_states,
397
+ encoder_attention_mask=encoder_attention_mask,
398
+ timestep=timestep,
399
+ cross_attention_kwargs=cross_attention_kwargs,
400
+ class_labels=class_labels,
401
+ )
402
+
403
+ # 3. Output
404
+ if self.is_input_continuous:
405
+ if not self.use_linear_projection:
406
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
407
+ hidden_states = (
408
+ self.proj_out(hidden_states, scale=lora_scale)
409
+ if not USE_PEFT_BACKEND
410
+ else self.proj_out(hidden_states)
411
+ )
412
+ else:
413
+ hidden_states = (
414
+ self.proj_out(hidden_states, scale=lora_scale)
415
+ if not USE_PEFT_BACKEND
416
+ else self.proj_out(hidden_states)
417
+ )
418
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
419
+
420
+ output = hidden_states + residual
421
+ elif self.is_input_vectorized:
422
+ hidden_states = self.norm_out(hidden_states)
423
+ logits = self.out(hidden_states)
424
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
425
+ logits = logits.permute(0, 2, 1)
426
+
427
+ # log(p(x_0))
428
+ output = F.log_softmax(logits.double(), dim=1).float()
429
+
430
+ if self.is_input_patches:
431
+ if self.config.norm_type != "ada_norm_single":
432
+ conditioning = self.transformer_blocks[0].norm1.emb(
433
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
434
+ )
435
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
436
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
437
+ hidden_states = self.proj_out_2(hidden_states)
438
+ elif self.config.norm_type == "ada_norm_single":
439
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
440
+ hidden_states = self.norm_out(hidden_states)
441
+ # Modulation
442
+ hidden_states = hidden_states * (1 + scale) + shift
443
+ hidden_states = self.proj_out(hidden_states)
444
+ hidden_states = hidden_states.squeeze(1)
445
+
446
+ # unpatchify
447
+ if self.adaln_single is None:
448
+ height = width = int(hidden_states.shape[1] ** 0.5)
449
+ hidden_states = hidden_states.reshape(
450
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
451
+ )
452
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
453
+ output = hidden_states.reshape(
454
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
455
+ )
456
+
457
+ if not return_dict:
458
+ return (output,)
459
+
460
+ return Transformer2DModelOutput(sample=output)
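A minimal smoke test of the continuous-input path of this class; the head count, channel sizes and sequence length below are illustrative assumptions rather than the configuration used by the pretrained auffusion weights.

import torch
from foleycrafter.models.auffusion.transformer_2d import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,   # inner_dim = 8 * 40 = 320
    in_channels=320,
    num_layers=1,
    cross_attention_dim=768,
)
sample = torch.randn(1, 320, 32, 32)             # continuous feature map
encoder_hidden_states = torch.randn(1, 77, 768)  # e.g. text-encoder states
out = model(sample, encoder_hidden_states=encoder_hidden_states).sample  # (1, 320, 32, 32)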
foleycrafter/models/auffusion/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
foleycrafter/models/auffusion_unet.py ADDED
@@ -0,0 +1,1260 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.utils.import_utils import is_xformers_available, is_torch_version
23
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
24
+ from diffusers.models.activations import get_activation
25
+ # from diffusers import StableDiffusionGLIGENPipeline
26
+ from diffusers.models.attention_processor import (
27
+ ADDED_KV_ATTENTION_PROCESSORS,
28
+ CROSS_ATTENTION_PROCESSORS,
29
+ Attention,
30
+ AttentionProcessor,
31
+ AttnAddedKVProcessor,
32
+ AttnProcessor,
33
+ XFormersAttnProcessor,
34
+ )
35
+ from diffusers.models.embeddings import (
36
+ GaussianFourierProjection,
37
+ ImageHintTimeEmbedding,
38
+ ImageProjection,
39
+ ImageTimeEmbedding,
40
+ PositionNet,
41
+ TextImageProjection,
42
+ TextImageTimeEmbedding,
43
+ TextTimeEmbedding,
44
+ TimestepEmbedding,
45
+ Timesteps,
46
+ )
47
+ from diffusers.models.modeling_utils import ModelMixin
48
+
49
+ from foleycrafter.models.auffusion.unet_2d_blocks import (
50
+ UNetMidBlock2D,
51
+ UNetMidBlock2DCrossAttn,
52
+ UNetMidBlock2DSimpleCrossAttn,
53
+ get_down_block,
54
+ get_up_block,
55
+ )
56
+
57
+ from foleycrafter.models.auffusion.attention_processor import AttnProcessor2_0
59
+ from foleycrafter.models.adapters.ip_adapter import TimeProjModel
60
+ from foleycrafter.models.auffusion.loaders.unet import UNet2DConditionLoadersMixin
61
+
62
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
63
+
64
+
65
+ @dataclass
66
+ class UNet2DConditionOutput(BaseOutput):
67
+ """
68
+ The output of [`UNet2DConditionModel`].
69
+
70
+ Args:
71
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
72
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
73
+ """
74
+
75
+ sample: torch.FloatTensor = None
76
+
77
+
78
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
79
+ r"""
80
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
81
+ shaped output.
82
+
83
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
84
+ for all models (such as downloading or saving).
85
+
86
+ Parameters:
87
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
88
+ Height and width of input/output sample.
89
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
90
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
91
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
92
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
93
+ Whether to flip the sin to cos in the time embedding.
94
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
95
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
96
+ The tuple of downsample blocks to use.
97
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
98
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
99
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
100
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
101
+ The tuple of upsample blocks to use.
102
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
103
+ Whether to include self-attention in the basic transformer blocks, see
104
+ [`~models.attention.BasicTransformerBlock`].
105
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
106
+ The tuple of output channels for each block.
107
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
108
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
109
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
110
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
111
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
112
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
113
+ If `None`, normalization and activation layers are skipped in post-processing.
114
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
115
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
116
+ The dimension of the cross attention features.
117
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
118
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
119
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
120
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
121
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
122
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
123
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
124
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
125
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
126
+ encoder_hid_dim (`int`, *optional*, defaults to None):
127
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
128
+ dimension to `cross_attention_dim`.
129
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
130
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
131
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
132
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
133
+ num_attention_heads (`int`, *optional*):
134
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
135
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
136
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
137
+ class_embed_type (`str`, *optional*, defaults to `None`):
138
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
139
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
140
+ addition_embed_type (`str`, *optional*, defaults to `None`):
141
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
142
+ "text". "text" will use the `TextTimeEmbedding` layer.
143
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
144
+ Dimension for the timestep embeddings.
145
+ num_class_embeds (`int`, *optional*, defaults to `None`):
146
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
147
+ class conditioning with `class_embed_type` equal to `None`.
148
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
149
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
150
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
151
+ An optional override for the dimension of the projected time embedding.
152
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
153
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
154
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
155
+ timestep_post_act (`str`, *optional*, defaults to `None`):
156
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
157
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
158
+ The dimension of `cond_proj` layer in the timestep embedding.
159
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
163
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
164
+ embeddings with the class embeddings.
165
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
166
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
167
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
168
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
169
+ otherwise.
170
+ """
171
+
172
+ _supports_gradient_checkpointing = True
173
+
174
+ @register_to_config
175
+ def __init__(
176
+ self,
177
+ sample_size: Optional[int] = None,
178
+ in_channels: int = 4,
179
+ out_channels: int = 4,
180
+ center_input_sample: bool = False,
181
+ flip_sin_to_cos: bool = True,
182
+ freq_shift: int = 0,
183
+ down_block_types: Tuple[str] = (
184
+ "CrossAttnDownBlock2D",
185
+ "CrossAttnDownBlock2D",
186
+ "CrossAttnDownBlock2D",
187
+ "DownBlock2D",
188
+ ),
189
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
190
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
191
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
192
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
193
+ layers_per_block: Union[int, Tuple[int]] = 2,
194
+ downsample_padding: int = 1,
195
+ mid_block_scale_factor: float = 1,
196
+ dropout: float = 0.0,
197
+ act_fn: str = "silu",
198
+ norm_num_groups: Optional[int] = 32,
199
+ norm_eps: float = 1e-5,
200
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
201
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
202
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
203
+ encoder_hid_dim: Optional[int] = None,
204
+ encoder_hid_dim_type: Optional[str] = None,
205
+ attention_head_dim: Union[int, Tuple[int]] = 8,
206
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
207
+ dual_cross_attention: bool = False,
208
+ use_linear_projection: bool = False,
209
+ class_embed_type: Optional[str] = None,
210
+ addition_embed_type: Optional[str] = None,
211
+ addition_time_embed_dim: Optional[int] = None,
212
+ num_class_embeds: Optional[int] = None,
213
+ upcast_attention: bool = False,
214
+ resnet_time_scale_shift: str = "default",
215
+ resnet_skip_time_act: bool = False,
216
+ resnet_out_scale_factor: float = 1.0,
217
+ time_embedding_type: str = "positional",
218
+ time_embedding_dim: Optional[int] = None,
219
+ time_embedding_act_fn: Optional[str] = None,
220
+ timestep_post_act: Optional[str] = None,
221
+ time_cond_proj_dim: Optional[int] = None,
222
+ conv_in_kernel: int = 3,
223
+ conv_out_kernel: int = 3,
224
+ projection_class_embeddings_input_dim: Optional[int] = None,
225
+ attention_type: str = "default",
226
+ class_embeddings_concat: bool = False,
227
+ mid_block_only_cross_attention: Optional[bool] = None,
228
+ cross_attention_norm: Optional[str] = None,
229
+ addition_embed_type_num_heads=64,
230
+
231
+ # param for joint
232
+ video_feature_dim: tuple=(320, 640, 1280, 1280),
233
+ video_cross_attn_dim: int=1024,
234
+ video_frame_nums: int=16,
235
+ ):
236
+ super().__init__()
237
+
238
+ self.sample_size = sample_size
239
+
240
+ if num_attention_heads is not None:
241
+ raise ValueError(
242
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
243
+ )
244
+
245
+ # If `num_attention_heads` is not defined (which is the case for most models)
246
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
247
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
248
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
249
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
250
+ # which is why we correct for the naming here.
251
+ num_attention_heads = num_attention_heads or attention_head_dim
252
+
253
+ # Check inputs
254
+ if len(down_block_types) != len(up_block_types):
255
+ raise ValueError(
256
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
257
+ )
258
+
259
+ if len(block_out_channels) != len(down_block_types):
260
+ raise ValueError(
261
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
262
+ )
263
+
264
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
265
+ raise ValueError(
266
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
267
+ )
268
+
269
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
270
+ raise ValueError(
271
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
272
+ )
273
+
274
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
275
+ raise ValueError(
276
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
277
+ )
278
+
279
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
280
+ raise ValueError(
281
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
282
+ )
283
+
284
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
285
+ raise ValueError(
286
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
287
+ )
288
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
289
+ for layer_number_per_block in transformer_layers_per_block:
290
+ if isinstance(layer_number_per_block, list):
291
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
292
+
293
+ # input
294
+ conv_in_padding = (conv_in_kernel - 1) // 2
295
+ self.conv_in = nn.Conv2d(
296
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
297
+ )
298
+
299
+ # time
300
+ if time_embedding_type == "fourier":
301
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
302
+ if time_embed_dim % 2 != 0:
303
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
304
+ self.time_proj = GaussianFourierProjection(
305
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
306
+ )
307
+ timestep_input_dim = time_embed_dim
308
+ elif time_embedding_type == "positional":
309
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
310
+
311
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
312
+ timestep_input_dim = block_out_channels[0]
313
+ else:
314
+ raise ValueError(
315
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
316
+ )
317
+
318
+ self.time_embedding = TimestepEmbedding(
319
+ timestep_input_dim,
320
+ time_embed_dim,
321
+ act_fn=act_fn,
322
+ post_act_fn=timestep_post_act,
323
+ cond_proj_dim=time_cond_proj_dim,
324
+ )
325
+
326
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
327
+ encoder_hid_dim_type = "text_proj"
328
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
329
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
330
+
331
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
332
+ raise ValueError(
333
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
334
+ )
335
+
336
+ if encoder_hid_dim_type == "text_proj":
337
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
338
+ elif encoder_hid_dim_type == "text_image_proj":
339
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
340
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
341
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
342
+ self.encoder_hid_proj = TextImageProjection(
343
+ text_embed_dim=encoder_hid_dim,
344
+ image_embed_dim=cross_attention_dim,
345
+ cross_attention_dim=cross_attention_dim,
346
+ )
347
+ elif encoder_hid_dim_type == "image_proj":
348
+ # Kandinsky 2.2
349
+ self.encoder_hid_proj = ImageProjection(
350
+ image_embed_dim=encoder_hid_dim,
351
+ cross_attention_dim=cross_attention_dim,
352
+ )
353
+ elif encoder_hid_dim_type is not None:
354
+ raise ValueError(
355
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
356
+ )
357
+ else:
358
+ self.encoder_hid_proj = None
359
+
360
+ # class embedding
361
+ if class_embed_type is None and num_class_embeds is not None:
362
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
363
+ elif class_embed_type == "timestep":
364
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
365
+ elif class_embed_type == "identity":
366
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
367
+ elif class_embed_type == "projection":
368
+ if projection_class_embeddings_input_dim is None:
369
+ raise ValueError(
370
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
371
+ )
372
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
373
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
374
+ # 2. it projects from an arbitrary input dimension.
375
+ #
376
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
377
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
378
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
379
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
380
+ elif class_embed_type == "simple_projection":
381
+ if projection_class_embeddings_input_dim is None:
382
+ raise ValueError(
383
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
384
+ )
385
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
386
+ else:
387
+ self.class_embedding = None
388
+
389
+ if addition_embed_type == "text":
390
+ if encoder_hid_dim is not None:
391
+ text_time_embedding_from_dim = encoder_hid_dim
392
+ else:
393
+ text_time_embedding_from_dim = cross_attention_dim
394
+
395
+ self.add_embedding = TextTimeEmbedding(
396
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
397
+ )
398
+ elif addition_embed_type == "text_image":
399
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
400
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
401
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
402
+ self.add_embedding = TextImageTimeEmbedding(
403
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
404
+ )
405
+ elif addition_embed_type == "text_time":
406
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
407
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
408
+ elif addition_embed_type == "image":
409
+ # Kandinsky 2.2
410
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
411
+ elif addition_embed_type == "image_hint":
412
+ # Kandinsky 2.2 ControlNet
413
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
414
+ elif addition_embed_type is not None:
415
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
416
+
417
+ if time_embedding_act_fn is None:
418
+ self.time_embed_act = None
419
+ else:
420
+ self.time_embed_act = get_activation(time_embedding_act_fn)
421
+
422
+ self.down_blocks = nn.ModuleList([])
423
+ self.up_blocks = nn.ModuleList([])
424
+
425
+ if isinstance(only_cross_attention, bool):
426
+ if mid_block_only_cross_attention is None:
427
+ mid_block_only_cross_attention = only_cross_attention
428
+
429
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
430
+
431
+ if mid_block_only_cross_attention is None:
432
+ mid_block_only_cross_attention = False
433
+
434
+ if isinstance(num_attention_heads, int):
435
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
436
+
437
+ if isinstance(attention_head_dim, int):
438
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
439
+
440
+ if isinstance(cross_attention_dim, int):
441
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
442
+
443
+ if isinstance(layers_per_block, int):
444
+ layers_per_block = [layers_per_block] * len(down_block_types)
445
+
446
+ if isinstance(transformer_layers_per_block, int):
447
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
448
+
449
+ if class_embeddings_concat:
450
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
451
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
452
+ # regular time embeddings
453
+ blocks_time_embed_dim = time_embed_dim * 2
454
+ else:
455
+ blocks_time_embed_dim = time_embed_dim
456
+
457
+ # down
458
+ output_channel = block_out_channels[0]
459
+ for i, down_block_type in enumerate(down_block_types):
460
+ input_channel = output_channel
461
+ output_channel = block_out_channels[i]
462
+ is_final_block = i == len(block_out_channels) - 1
463
+
464
+ down_block = get_down_block(
465
+ down_block_type,
466
+ num_layers=layers_per_block[i],
467
+ transformer_layers_per_block=transformer_layers_per_block[i],
468
+ in_channels=input_channel,
469
+ out_channels=output_channel,
470
+ temb_channels=blocks_time_embed_dim,
471
+ add_downsample=not is_final_block,
472
+ resnet_eps=norm_eps,
473
+ resnet_act_fn=act_fn,
474
+ resnet_groups=norm_num_groups,
475
+ cross_attention_dim=cross_attention_dim[i],
476
+ num_attention_heads=num_attention_heads[i],
477
+ downsample_padding=downsample_padding,
478
+ dual_cross_attention=dual_cross_attention,
479
+ use_linear_projection=use_linear_projection,
480
+ only_cross_attention=only_cross_attention[i],
481
+ upcast_attention=upcast_attention,
482
+ resnet_time_scale_shift=resnet_time_scale_shift,
483
+ attention_type=attention_type,
484
+ resnet_skip_time_act=resnet_skip_time_act,
485
+ resnet_out_scale_factor=resnet_out_scale_factor,
486
+ cross_attention_norm=cross_attention_norm,
487
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
488
+ dropout=dropout,
489
+ )
490
+ self.down_blocks.append(down_block)
491
+
492
+ # mid
493
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
494
+ self.mid_block = UNetMidBlock2DCrossAttn(
495
+ transformer_layers_per_block=transformer_layers_per_block[-1],
496
+ in_channels=block_out_channels[-1],
497
+ temb_channels=blocks_time_embed_dim,
498
+ dropout=dropout,
499
+ resnet_eps=norm_eps,
500
+ resnet_act_fn=act_fn,
501
+ output_scale_factor=mid_block_scale_factor,
502
+ resnet_time_scale_shift=resnet_time_scale_shift,
503
+ cross_attention_dim=cross_attention_dim[-1],
504
+ num_attention_heads=num_attention_heads[-1],
505
+ resnet_groups=norm_num_groups,
506
+ dual_cross_attention=dual_cross_attention,
507
+ use_linear_projection=use_linear_projection,
508
+ upcast_attention=upcast_attention,
509
+ attention_type=attention_type,
510
+ )
511
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
512
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
513
+ in_channels=block_out_channels[-1],
514
+ temb_channels=blocks_time_embed_dim,
515
+ dropout=dropout,
516
+ resnet_eps=norm_eps,
517
+ resnet_act_fn=act_fn,
518
+ output_scale_factor=mid_block_scale_factor,
519
+ cross_attention_dim=cross_attention_dim[-1],
520
+ attention_head_dim=attention_head_dim[-1],
521
+ resnet_groups=norm_num_groups,
522
+ resnet_time_scale_shift=resnet_time_scale_shift,
523
+ skip_time_act=resnet_skip_time_act,
524
+ only_cross_attention=mid_block_only_cross_attention,
525
+ cross_attention_norm=cross_attention_norm,
526
+ )
527
+ elif mid_block_type == "UNetMidBlock2D":
528
+ self.mid_block = UNetMidBlock2D(
529
+ in_channels=block_out_channels[-1],
530
+ temb_channels=blocks_time_embed_dim,
531
+ dropout=dropout,
532
+ num_layers=0,
533
+ resnet_eps=norm_eps,
534
+ resnet_act_fn=act_fn,
535
+ output_scale_factor=mid_block_scale_factor,
536
+ resnet_groups=norm_num_groups,
537
+ resnet_time_scale_shift=resnet_time_scale_shift,
538
+ add_attention=False,
539
+ )
540
+ elif mid_block_type is None:
541
+ self.mid_block = None
542
+ else:
543
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
544
+
545
+ # count how many layers upsample the images
546
+ self.num_upsamplers = 0
547
+
548
+ # up
549
+ reversed_block_out_channels = list(reversed(block_out_channels))
550
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
551
+ reversed_layers_per_block = list(reversed(layers_per_block))
552
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
553
+ reversed_transformer_layers_per_block = (
554
+ list(reversed(transformer_layers_per_block))
555
+ if reverse_transformer_layers_per_block is None
556
+ else reverse_transformer_layers_per_block
557
+ )
558
+ only_cross_attention = list(reversed(only_cross_attention))
559
+
560
+ output_channel = reversed_block_out_channels[0]
561
+ for i, up_block_type in enumerate(up_block_types):
562
+ is_final_block = i == len(block_out_channels) - 1
563
+
564
+ prev_output_channel = output_channel
565
+ output_channel = reversed_block_out_channels[i]
566
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
567
+
568
+ # add upsample block for all BUT final layer
569
+ if not is_final_block:
570
+ add_upsample = True
571
+ self.num_upsamplers += 1
572
+ else:
573
+ add_upsample = False
574
+
575
+ up_block = get_up_block(
576
+ up_block_type,
577
+ num_layers=reversed_layers_per_block[i] + 1,
578
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
579
+ in_channels=input_channel,
580
+ out_channels=output_channel,
581
+ prev_output_channel=prev_output_channel,
582
+ temb_channels=blocks_time_embed_dim,
583
+ add_upsample=add_upsample,
584
+ resnet_eps=norm_eps,
585
+ resnet_act_fn=act_fn,
586
+ resolution_idx=i,
587
+ resnet_groups=norm_num_groups,
588
+ cross_attention_dim=reversed_cross_attention_dim[i],
589
+ num_attention_heads=reversed_num_attention_heads[i],
590
+ dual_cross_attention=dual_cross_attention,
591
+ use_linear_projection=use_linear_projection,
592
+ only_cross_attention=only_cross_attention[i],
593
+ upcast_attention=upcast_attention,
594
+ resnet_time_scale_shift=resnet_time_scale_shift,
595
+ attention_type=attention_type,
596
+ resnet_skip_time_act=resnet_skip_time_act,
597
+ resnet_out_scale_factor=resnet_out_scale_factor,
598
+ cross_attention_norm=cross_attention_norm,
599
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
600
+ dropout=dropout,
601
+ )
602
+ self.up_blocks.append(up_block)
603
+ prev_output_channel = output_channel
604
+
605
+ # out
606
+ if norm_num_groups is not None:
607
+ self.conv_norm_out = nn.GroupNorm(
608
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
609
+ )
610
+
611
+ self.conv_act = get_activation(act_fn)
612
+
613
+ else:
614
+ self.conv_norm_out = None
615
+ self.conv_act = None
616
+
617
+ conv_out_padding = (conv_out_kernel - 1) // 2
618
+ self.conv_out = nn.Conv2d(
619
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
620
+ )
621
+
622
+ if attention_type in ["gated", "gated-text-image"]:
623
+ positive_len = 768
624
+ if isinstance(cross_attention_dim, int):
625
+ positive_len = cross_attention_dim
626
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
627
+ positive_len = cross_attention_dim[0]
628
+
629
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
630
+ self.position_net = TimeProjModel(
631
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
632
+ )
633
+
634
+ # additional settings
635
+ self.video_feature_dim = video_feature_dim
636
+ self.cross_attention_dim = cross_attention_dim
637
+ self.video_cross_attn_dim = video_cross_attn_dim
638
+ self.video_frame_nums = video_frame_nums
639
+
640
+ self.multi_frames_condition = False
641
+
642
+ def load_attention(self):
643
+ attn_dict = {}
644
+ for name in self.attn_processors.keys():
645
+ # if self-attention, save feature
646
+ if name.endswith("attn1.processor"):
647
+ if is_xformers_available():
648
+ attn_dict[name] = XFormersAttnProcessor()
649
+ else:
650
+ attn_dict[name] = AttnProcessor()
651
+ else:
652
+ attn_dict[name] = AttnProcessor2_0()
653
+ self.set_attn_processor(attn_dict)
654
+
655
+ def get_writer_feature(self):
656
+ return self.attn_feature_writer.get_cross_attention_feature()
657
+
658
+ def clear_writer_feature(self):
659
+ self.attn_feature_writer.clear_cross_attention_feature()
660
+
661
+ def disable_feature_adapters(self):
662
+ raise NotImplementedError
663
+
664
+ def set_reader_feature(self, features:list):
665
+ return self.attn_feature_reader.set_cross_attention_feature(features)
666
+
667
+ @property
668
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
669
+ r"""
670
+ Returns:
671
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
672
+ indexed by their weight names.
673
+ """
674
+ # set recursively
675
+ processors = {}
676
+
677
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
678
+ if hasattr(module, "get_processor"):
679
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
680
+
681
+ for sub_name, child in module.named_children():
682
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
683
+
684
+ return processors
685
+
686
+ for name, module in self.named_children():
687
+ fn_recursive_add_processors(name, module, processors)
688
+
689
+ return processors
690
+
691
+ def set_attn_processor(
692
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
693
+ ):
694
+ r"""
695
+ Sets the attention processor to use to compute attention.
696
+
697
+ Parameters:
698
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
699
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
700
+ for **all** `Attention` layers.
701
+
702
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
703
+ processor. This is strongly recommended when setting trainable attention processors.
704
+
705
+ """
706
+ count = len(self.attn_processors.keys())
707
+
708
+ if isinstance(processor, dict) and len(processor) != count:
709
+ raise ValueError(
710
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
711
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
712
+ )
713
+
714
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
715
+ if hasattr(module, "set_processor"):
716
+ if not isinstance(processor, dict):
717
+ module.set_processor(processor, _remove_lora=_remove_lora)
718
+ else:
719
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
720
+
721
+ for sub_name, child in module.named_children():
722
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
723
+
724
+ for name, module in self.named_children():
725
+ fn_recursive_attn_processor(name, module, processor)
726
+
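+ # Illustrative usage sketch (editor note, not part of the original commit): `set_attn_processor`
+ # accepts either a single processor applied to every attention layer, or a dict keyed by the
+ # processor path; `unet` below is a hypothetical instance of this model:
+ #   unet.set_attn_processor(AttnProcessor2_0())
+ #   unet.set_attn_processor({name: AttnProcessor2_0() for name in unet.attn_processors})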
727
+ def set_default_attn_processor(self):
728
+ """
729
+ Disables custom attention processors and sets the default attention implementation.
730
+ """
731
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
732
+ processor = AttnAddedKVProcessor()
733
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
734
+ processor = AttnProcessor()
735
+ else:
736
+ raise ValueError(
737
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
738
+ )
739
+
740
+ self.set_attn_processor(processor, _remove_lora=True)
741
+
742
+ def set_attention_slice(self, slice_size):
743
+ r"""
744
+ Enable sliced attention computation.
745
+
746
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
747
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
748
+
749
+ Args:
750
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
751
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
752
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
753
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
754
+ must be a multiple of `slice_size`.
755
+ """
756
+ sliceable_head_dims = []
757
+
758
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
759
+ if hasattr(module, "set_attention_slice"):
760
+ sliceable_head_dims.append(module.sliceable_head_dim)
761
+
762
+ for child in module.children():
763
+ fn_recursive_retrieve_sliceable_dims(child)
764
+
765
+ # retrieve number of attention layers
766
+ for module in self.children():
767
+ fn_recursive_retrieve_sliceable_dims(module)
768
+
769
+ num_sliceable_layers = len(sliceable_head_dims)
770
+
771
+ if slice_size == "auto":
772
+ # half the attention head size is usually a good trade-off between
773
+ # speed and memory
774
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
775
+ elif slice_size == "max":
776
+ # make smallest slice possible
777
+ slice_size = num_sliceable_layers * [1]
778
+
779
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
780
+
781
+ if len(slice_size) != len(sliceable_head_dims):
782
+ raise ValueError(
783
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
784
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
785
+ )
786
+
787
+ for i in range(len(slice_size)):
788
+ size = slice_size[i]
789
+ dim = sliceable_head_dims[i]
790
+ if size is not None and size > dim:
791
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
792
+
793
+ # Recursively walk through all the children.
794
+ # Any children which exposes the set_attention_slice method
795
+ # gets the message
796
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
797
+ if hasattr(module, "set_attention_slice"):
798
+ module.set_attention_slice(slice_size.pop())
799
+
800
+ for child in module.children():
801
+ fn_recursive_set_attention_slice(child, slice_size)
802
+
803
+ reversed_slice_size = list(reversed(slice_size))
804
+ for module in self.children():
805
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
806
+
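+ # Illustrative usage sketch (editor note, not part of the original commit): with sliceable
+ # attention layers of head dim 8, `unet.set_attention_slice("auto")` halves the head dim and so
+ # computes each attention in two chunks (slice size 4 per layer), while
+ # `unet.set_attention_slice("max")` uses slice size 1, trading speed for the lowest peak memory.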
807
+ def _set_gradient_checkpointing(self, module, value=False):
808
+ if hasattr(module, "gradient_checkpointing"):
809
+ module.gradient_checkpointing = value
810
+
811
+ def enable_freeu(self, s1, s2, b1, b2):
812
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
813
+
814
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
815
+
816
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
817
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
818
+
819
+ Args:
820
+ s1 (`float`):
821
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
822
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
823
+ s2 (`float`):
824
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
825
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
826
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
827
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
828
+ """
829
+ for i, upsample_block in enumerate(self.up_blocks):
830
+ setattr(upsample_block, "s1", s1)
831
+ setattr(upsample_block, "s2", s2)
832
+ setattr(upsample_block, "b1", b1)
833
+ setattr(upsample_block, "b2", b2)
834
+
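+ # Illustrative usage sketch (editor note, not part of the original commit); the values below are
+ # placeholders rather than recommended settings -- see the official FreeU repository for tuned ones:
+ #   unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
+ #   unet.disable_freeu()  # restores the default behaviour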
835
+ def disable_freeu(self):
836
+ """Disables the FreeU mechanism."""
837
+ freeu_keys = {"s1", "s2", "b1", "b2"}
838
+ for i, upsample_block in enumerate(self.up_blocks):
839
+ for k in freeu_keys:
840
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
841
+ setattr(upsample_block, k, None)
842
+
843
+ def fuse_qkv_projections(self):
844
+ """
845
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
846
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
847
+
848
+ <Tip warning={true}>
849
+
850
+ This API is 🧪 experimental.
851
+
852
+ </Tip>
853
+ """
854
+ self.original_attn_processors = None
855
+
856
+ for _, attn_processor in self.attn_processors.items():
857
+ if "Added" in str(attn_processor.__class__.__name__):
858
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
859
+
860
+ self.original_attn_processors = self.attn_processors
861
+
862
+ for module in self.modules():
863
+ if isinstance(module, Attention):
864
+ module.fuse_projections(fuse=True)
865
+
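+ # Illustrative usage sketch (editor note, not part of the original commit):
+ #   unet.fuse_qkv_projections()    # fuse q/k/v (self-attn) and k/v (cross-attn) projections
+ #   ...                            # run inference
+ #   unet.unfuse_qkv_projections()  # restore the original attention processors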
866
+ def unfuse_qkv_projections(self):
867
+ """Disables the fused QKV projection if enabled.
868
+
869
+ <Tip warning={true}>
870
+
871
+ This API is 🧪 experimental.
872
+
873
+ </Tip>
874
+
875
+ """
876
+ if self.original_attn_processors is not None:
877
+ self.set_attn_processor(self.original_attn_processors)
878
+
879
+ def forward(
880
+ self,
881
+ sample: torch.FloatTensor,
882
+ timestep: Union[torch.Tensor, float, int],
883
+ encoder_hidden_states: torch.Tensor,
884
+ class_labels: Optional[torch.Tensor] = None,
885
+ timestep_cond: Optional[torch.Tensor] = None,
886
+ attention_mask: Optional[torch.Tensor] = None,
887
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
888
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
889
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
890
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
891
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
892
+ encoder_attention_mask: Optional[torch.Tensor] = None,
893
+ return_dict: bool = True,
894
+ ) -> Union[UNet2DConditionOutput, Tuple]:
895
+ # import ipdb; ipdb.set_trace()
896
+ r"""
897
+ The [`UNet2DConditionModel`] forward method.
898
+
899
+ Args:
900
+ sample (`torch.FloatTensor`):
901
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
902
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
903
+ encoder_hidden_states (`torch.FloatTensor`):
904
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
905
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
906
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
907
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
908
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
909
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
910
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
911
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
912
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
913
+ negative values to the attention scores corresponding to "discard" tokens.
914
+ cross_attention_kwargs (`dict`, *optional*):
915
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
916
+ `self.processor` in
917
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
918
+ added_cond_kwargs: (`dict`, *optional*):
919
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
920
+ are passed along to the UNet blocks.
921
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
922
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
923
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
924
+ A tensor that if specified is added to the residual of the middle unet block.
925
+ encoder_attention_mask (`torch.Tensor`):
926
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
927
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
928
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
929
+ return_dict (`bool`, *optional*, defaults to `True`):
930
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
931
+ tuple.
932
+ cross_attention_kwargs (`dict`, *optional*):
933
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
934
+ added_cond_kwargs: (`dict`, *optional*):
935
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
936
+ are passed along to the UNet blocks.
937
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
938
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
939
+ example from ControlNet side model(s)
940
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
941
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
942
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
943
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
944
+
945
+ Returns:
946
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
947
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
948
+ a `tuple` is returned where the first element is the sample tensor.
949
+ """
950
+ # By default samples have to be at least a multiple of the overall upsampling factor.
951
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
952
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
953
+ # on the fly if necessary.
954
+ default_overall_up_factor = 2**self.num_upsamplers
955
+
956
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
957
+ forward_upsample_size = False
958
+ upsample_size = None
959
+
960
+ for dim in sample.shape[-2:]:
961
+ if dim % default_overall_up_factor != 0:
962
+ # Forward upsample size to force interpolation output size.
963
+ forward_upsample_size = True
964
+ break
965
+
966
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
967
+ # expects mask of shape:
968
+ # [batch, key_tokens]
969
+ # adds singleton query_tokens dimension:
970
+ # [batch, 1, key_tokens]
971
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
972
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
973
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
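+ # e.g. (illustrative) a mask [1, 1, 0] over three key tokens becomes the bias [0.0, 0.0, -10000.0]
+ # below, leaving kept tokens untouched while effectively zeroing attention to discarded ones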
974
+ if attention_mask is not None:
975
+ # assume that mask is expressed as:
976
+ # (1 = keep, 0 = discard)
977
+ # convert mask into a bias that can be added to attention scores:
978
+ # (keep = +0, discard = -10000.0)
979
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
980
+ attention_mask = attention_mask.unsqueeze(1)
981
+
982
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
983
+ if encoder_attention_mask is not None:
984
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
985
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
986
+
987
+ # 0. center input if necessary
988
+ if self.config.center_input_sample:
989
+ sample = 2 * sample - 1.0
990
+
991
+ # 1. time
992
+ timesteps = timestep
993
+ if not torch.is_tensor(timesteps):
994
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
995
+ # This would be a good case for the `match` statement (Python 3.10+)
996
+ is_mps = sample.device.type == "mps"
997
+ if isinstance(timestep, float):
998
+ dtype = torch.float32 if is_mps else torch.float64
999
+ else:
1000
+ dtype = torch.int32 if is_mps else torch.int64
1001
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1002
+ elif len(timesteps.shape) == 0:
1003
+ timesteps = timesteps[None].to(sample.device)
1004
+
1005
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1006
+ timesteps = timesteps.expand(sample.shape[0])
1007
+
1008
+ t_emb = self.time_proj(timesteps)
1009
+
1010
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1011
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1012
+ # there might be better ways to encapsulate this.
1013
+ t_emb = t_emb.to(dtype=sample.dtype)
1014
+
1015
+ emb = self.time_embedding(t_emb, timestep_cond)
1016
+ aug_emb = None
1017
+
1018
+ if self.class_embedding is not None:
1019
+ if class_labels is None:
1020
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
1021
+
1022
+ if self.config.class_embed_type == "timestep":
1023
+ class_labels = self.time_proj(class_labels)
1024
+
1025
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1026
+ # there might be better ways to encapsulate this.
1027
+ class_labels = class_labels.to(dtype=sample.dtype)
1028
+
1029
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1030
+
1031
+ if self.config.class_embeddings_concat:
1032
+ emb = torch.cat([emb, class_emb], dim=-1)
1033
+ else:
1034
+ emb = emb + class_emb
1035
+
1036
+ if self.config.addition_embed_type == "text":
1037
+ aug_emb = self.add_embedding(encoder_hidden_states)
1038
+ elif self.config.addition_embed_type == "text_image":
1039
+ # Kandinsky 2.1 - style
1040
+ if "image_embeds" not in added_cond_kwargs:
1041
+ raise ValueError(
1042
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1043
+ )
1044
+
1045
+ image_embs = added_cond_kwargs.get("image_embeds")
1046
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1047
+ aug_emb = self.add_embedding(text_embs, image_embs)
1048
+ elif self.config.addition_embed_type == "text_time":
1049
+ # SDXL - style
1050
+ if "text_embeds" not in added_cond_kwargs:
1051
+ raise ValueError(
1052
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1053
+ )
1054
+ text_embeds = added_cond_kwargs.get("text_embeds")
1055
+ if "time_ids" not in added_cond_kwargs:
1056
+ raise ValueError(
1057
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1058
+ )
1059
+ time_ids = added_cond_kwargs.get("time_ids")
1060
+ time_embeds = self.add_time_proj(time_ids.flatten())
1061
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1062
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1063
+ add_embeds = add_embeds.to(emb.dtype)
1064
+ aug_emb = self.add_embedding(add_embeds)
1065
+ elif self.config.addition_embed_type == "image":
1066
+ # Kandinsky 2.2 - style
1067
+ if "image_embeds" not in added_cond_kwargs:
1068
+ raise ValueError(
1069
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1070
+ )
1071
+ image_embs = added_cond_kwargs.get("image_embeds")
1072
+ aug_emb = self.add_embedding(image_embs)
1073
+ elif self.config.addition_embed_type == "image_hint":
1074
+ # Kandinsky 2.2 - style
1075
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1076
+ raise ValueError(
1077
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1078
+ )
1079
+ image_embs = added_cond_kwargs.get("image_embeds")
1080
+ hint = added_cond_kwargs.get("hint")
1081
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1082
+ sample = torch.cat([sample, hint], dim=1)
1083
+
1084
+ emb = emb + aug_emb if aug_emb is not None else emb
1085
+
1086
+ if self.time_embed_act is not None:
1087
+ emb = self.time_embed_act(emb)
1088
+
1089
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1090
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1091
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1092
+ # Kandinsky 2.1 - style
1093
+ if "image_embeds" not in added_cond_kwargs:
1094
+ raise ValueError(
1095
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1096
+ )
1097
+
1098
+ image_embeds = added_cond_kwargs.get("image_embeds")
1099
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1100
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1101
+ # Kandinsky 2.2 - style
1102
+ if "image_embeds" not in added_cond_kwargs:
1103
+ raise ValueError(
1104
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1105
+ )
1106
+ image_embeds = added_cond_kwargs.get("image_embeds")
1107
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1108
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1109
+ if "image_embeds" not in added_cond_kwargs:
1110
+ raise ValueError(
1111
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1112
+ )
1113
+ image_embeds = added_cond_kwargs.get("image_embeds")
1114
+ image_embeds = self.encoder_hid_proj(image_embeds)
1115
+ if isinstance(image_embeds, list):
1116
+ image_embeds = [image_embed.to(encoder_hidden_states.dtype) for image_embed in image_embeds]
1117
+ else:
1118
+ image_embeds = image_embeds.to(encoder_hidden_states.dtype)
1119
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
1120
+ # encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
1121
+ # import ipdb; ipdb.set_trace()
1122
+ # 2. pre-process
1123
+ sample = self.conv_in(sample)
1124
+
1125
+ # 2.5 GLIGEN position net
1126
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1127
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1128
+ gligen_args = cross_attention_kwargs.pop("gligen")
1129
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1130
+
1131
+ # 3. down
1132
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1133
+ if USE_PEFT_BACKEND:
1134
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1135
+ scale_lora_layers(self, lora_scale)
1136
+
1137
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1138
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1139
+ is_adapter = down_intrablock_additional_residuals is not None
1140
+ # maintain backward compatibility for legacy usage, where
1141
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1142
+ # but can only use one or the other
1143
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1144
+ deprecate(
1145
+ "T2I should not use down_block_additional_residuals",
1146
+ "1.3.0",
1147
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1148
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1149
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1150
+ standard_warn=False,
1151
+ )
1152
+ down_intrablock_additional_residuals = down_block_additional_residuals
1153
+ is_adapter = True
1154
+ # import ipdb; ipdb.set_trace()
1155
+ down_block_res_samples = (sample,)
1156
+ for downsample_block in self.down_blocks:
1157
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1158
+ # For t2i-adapter CrossAttnDownBlock2D
1159
+ additional_residuals = {}
1160
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1161
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1162
+
1163
+ sample, res_samples = downsample_block(
1164
+ hidden_states=sample,
1165
+ temb=emb,
1166
+ encoder_hidden_states=encoder_hidden_states,
1167
+ attention_mask=attention_mask,
1168
+ cross_attention_kwargs=cross_attention_kwargs,
1169
+ encoder_attention_mask=encoder_attention_mask,
1170
+ **additional_residuals,
1171
+ )
1172
+ # import ipdb; ipdb.set_trace()
1173
+ else:
1174
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
1175
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1176
+ sample += down_intrablock_additional_residuals.pop(0)
1177
+
1178
+ down_block_res_samples += res_samples
1179
+
1180
+ if is_controlnet:
1181
+ new_down_block_res_samples = ()
1182
+
1183
+ for down_block_res_sample, down_block_additional_residual in zip(
1184
+ down_block_res_samples, down_block_additional_residuals
1185
+ ):
1186
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1187
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1188
+
1189
+ down_block_res_samples = new_down_block_res_samples
1190
+ # 4. mid
1191
+ if self.mid_block is not None:
1192
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1193
+ sample = self.mid_block(
1194
+ sample,
1195
+ emb,
1196
+ encoder_hidden_states=encoder_hidden_states,
1197
+ attention_mask=attention_mask,
1198
+ cross_attention_kwargs=cross_attention_kwargs,
1199
+ encoder_attention_mask=encoder_attention_mask,
1200
+ )
1201
+ else:
1202
+ sample = self.mid_block(sample, emb)
1203
+
1204
+ # To support T2I-Adapter-XL
1205
+ if (
1206
+ is_adapter
1207
+ and len(down_intrablock_additional_residuals) > 0
1208
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1209
+ ):
1210
+ sample += down_intrablock_additional_residuals.pop(0)
1211
+
1212
+ if is_controlnet:
1213
+ sample = sample + mid_block_additional_residual
1214
+ # import ipdb; ipdb.set_trace()
1215
+ # 5. up
1216
+ for i, upsample_block in enumerate(self.up_blocks):
1217
+ is_final_block = i == len(self.up_blocks) - 1
1218
+
1219
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1220
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1221
+
1222
+ # if we have not reached the final block and need to forward the
1223
+ # upsample size, we do it here
1224
+ if not is_final_block and forward_upsample_size:
1225
+ upsample_size = down_block_res_samples[-1].shape[2:]
1226
+
1227
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1228
+ sample = upsample_block(
1229
+ hidden_states=sample,
1230
+ temb=emb,
1231
+ res_hidden_states_tuple=res_samples,
1232
+ encoder_hidden_states=encoder_hidden_states,
1233
+ cross_attention_kwargs=cross_attention_kwargs,
1234
+ upsample_size=upsample_size,
1235
+ attention_mask=attention_mask,
1236
+ encoder_attention_mask=encoder_attention_mask,
1237
+ )
1238
+ else:
1239
+ sample = upsample_block(
1240
+ hidden_states=sample,
1241
+ temb=emb,
1242
+ res_hidden_states_tuple=res_samples,
1243
+ upsample_size=upsample_size,
1244
+ scale=lora_scale,
1245
+ )
1246
+ # import ipdb; ipdb.set_trace()
1247
+ # 6. post-process
1248
+ if self.conv_norm_out:
1249
+ sample = self.conv_norm_out(sample)
1250
+ sample = self.conv_act(sample)
1251
+ sample = self.conv_out(sample)
1252
+
1253
+ if USE_PEFT_BACKEND:
1254
+ # remove `lora_scale` from each PEFT layer
1255
+ unscale_lora_layers(self, lora_scale)
1256
+
1257
+ if not return_dict:
1258
+ return (sample,)
1259
+ # import ipdb; ipdb.set_trace()
1260
+ return UNet2DConditionOutput(sample=sample)
foleycrafter/models/specvqgan/data/greatesthit.py ADDED
@@ -0,0 +1,993 @@
1
+ from matplotlib import collections
2
+ import json
3
+ import os
4
+ import copy
5
+ import matplotlib.pyplot as plt
6
+ import torch
7
+ from torchvision import transforms
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+ from random import sample
11
+ import torchaudio
12
+ import logging
13
+ import collections
14
+ from glob import glob
15
+ import sys
16
+ import albumentations
17
+ import soundfile
+ # editor addition: the names below (math.ceil, random.uniform/shuffle/seed, Image.open) are used in
+ # this module and may otherwise only be available via the star import from transforms further down
+ import math
+ import random
+ from PIL import Image
18
+
19
+ sys.path.insert(0, '.') # nopep8
20
+ from train import instantiate_from_config
21
+ from foleycrafter.models.specvqgan.data.transforms import *
22
+
23
+ torchaudio.set_audio_backend("sox_io")
24
+ logger = logging.getLogger(f'main.{__name__}')
25
+
26
+ SR = 22050
27
+ FPS = 15
28
+ MAX_SAMPLE_ITER = 10
29
+
30
+ def non_negative(x): return int(np.round(max(0, x), 0))
31
+
32
+ def rms(x): return np.sqrt(np.mean(x**2))
33
+
34
+ def get_GH_data_identifier(video_name, start_idx, split='_'):
35
+ if isinstance(start_idx, str):
36
+ return video_name + split + start_idx
37
+ elif isinstance(start_idx, int):
38
+ return video_name + split + str(start_idx)
39
+ else:
40
+ raise NotImplementedError
41
+
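+ # e.g. (illustrative, hypothetical names) get_GH_data_identifier('vid0001', 44100) -> 'vid0001_44100'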
42
+
43
+ class Crop(object):
44
+
45
+ def __init__(self, cropped_shape=None, random_crop=False):
46
+ self.cropped_shape = cropped_shape
47
+ if cropped_shape is not None:
48
+ mel_num, spec_len = cropped_shape
49
+ if random_crop:
50
+ self.cropper = albumentations.RandomCrop
51
+ else:
52
+ self.cropper = albumentations.CenterCrop
53
+ self.preprocessor = albumentations.Compose([self.cropper(mel_num, spec_len)])
54
+ else:
55
+ self.preprocessor = lambda **kwargs: kwargs
56
+
57
+ def __call__(self, item):
58
+ item['image'] = self.preprocessor(image=item['image'])['image']
59
+ if 'cond_image' in item.keys():
60
+ item['cond_image'] = self.preprocessor(image=item['cond_image'])['image']
61
+ return item
62
+
63
+ class CropImage(Crop):
64
+ def __init__(self, *crop_args):
65
+ super().__init__(*crop_args)
66
+
67
+ class CropFeats(Crop):
68
+ def __init__(self, *crop_args):
69
+ super().__init__(*crop_args)
70
+
71
+ def __call__(self, item):
72
+ item['feature'] = self.preprocessor(image=item['feature'])['image']
73
+ return item
74
+
75
+ class CropCoords(Crop):
76
+ def __init__(self, *crop_args):
77
+ super().__init__(*crop_args)
78
+
79
+ def __call__(self, item):
80
+ item['coord'] = self.preprocessor(image=item['coord'])['image']
81
+ return item
82
+
83
+ class ResampleFrames(object):
84
+ def __init__(self, feat_sample_size, times_to_repeat_after_resample=None):
85
+ self.feat_sample_size = feat_sample_size
86
+ self.times_to_repeat_after_resample = times_to_repeat_after_resample
87
+
88
+ def __call__(self, item):
89
+ feat_len = item['feature'].shape[0]
90
+
91
+ ## resample
92
+ assert feat_len >= self.feat_sample_size
93
+ # evenly spaced points (abcdefghkl -> aoooofoooo)
94
+ idx = np.linspace(0, feat_len, self.feat_sample_size, dtype=int, endpoint=False)  # np.int was removed in NumPy >= 1.24
95
+ # xoooo xoooo -> ooxoo ooxoo
96
+ shift = feat_len // (self.feat_sample_size + 1)
97
+ idx = idx + shift
98
+
99
+ ## repeat after resampling (abc -> aaaabbbbcccc)
100
+ if self.times_to_repeat_after_resample is not None and self.times_to_repeat_after_resample > 1:
101
+ idx = np.repeat(idx, self.times_to_repeat_after_resample)
102
+
103
+ item['feature'] = item['feature'][idx, :]
104
+ return item
105
+
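+ # Worked example (illustrative): with feat_len=10 and feat_sample_size=2, np.linspace gives
+ # [0, 5], the shift is 10 // 3 = 3, so idx = [3, 8]; with times_to_repeat_after_resample=2
+ # this becomes [3, 3, 8, 8].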
106
+
107
+ class GreatestHitSpecs(torch.utils.data.Dataset):
108
+
109
+ def __init__(self, split, spec_dir_path, spec_len, random_crop, mel_num,
110
+ spec_crop_len, L=2.0, rand_shift=False, spec_transforms=None, splits_path='./data',
111
+ meta_path='./data/info_r2plus1d_dim1024_15fps.json'):
112
+ super().__init__()
113
+ self.split = split
114
+ self.specs_dir = spec_dir_path
115
+ self.spec_transforms = spec_transforms
116
+ self.splits_path = splits_path
117
+ self.meta_path = meta_path
118
+ self.spec_len = spec_len
119
+ self.rand_shift = rand_shift
120
+ self.L = L
121
+ self.spec_take_first = int(math.ceil(860 * (L / 10.) / 32) * 32)
122
+ self.spec_take_first = 860 if self.spec_take_first > 860 else self.spec_take_first
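+ # e.g. (illustrative) L=2.0 -> ceil(860 * 0.2 / 32) * 32 = 192 spectrogram frames, i.e. the
+ # 2-second crop rounded up to a multiple of 32 and capped at the full 860 frames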
123
+
124
+ greatesthit_meta = json.load(open(self.meta_path, 'r'))
125
+ unique_classes = sorted(list(set(ht for ht in greatesthit_meta['hit_type'])))
126
+ self.label2target = {label: target for target, label in enumerate(unique_classes)}
127
+ self.target2label = {target: label for label, target in self.label2target.items()}
128
+ self.video_idx2label = {
129
+ get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
130
+ greatesthit_meta['hit_type'][i] for i in range(len(greatesthit_meta['video_name']))
131
+ }
132
+ self.available_video_hit = list(self.video_idx2label.keys())
133
+ self.video_idx2path = {
134
+ vh: os.path.join(self.specs_dir,
135
+ vh.replace('_', '_denoised_') + '_' + self.video_idx2label[vh].replace(' ', '_') +'_mel.npy')
136
+ for vh in self.available_video_hit
137
+ }
138
+ self.video_idx2idx = {
139
+ get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
140
+ i for i in range(len(greatesthit_meta['video_name']))
141
+ }
142
+
143
+ split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
144
+ if not os.path.exists(split_clip_ids_path):
145
+ raise NotImplementedError()
146
+ clip_video_hit = json.load(open(split_clip_ids_path, 'r'))
147
+ self.dataset = clip_video_hit
148
+ spec_crop_len = self.spec_take_first if self.spec_take_first <= spec_crop_len else spec_crop_len
149
+ self.spec_transforms = transforms.Compose([
150
+ CropImage([mel_num, spec_crop_len], random_crop),
151
+ # transforms.RandomApply([FrequencyMasking(freq_mask_param=20)], p=0),
152
+ # transforms.RandomApply([TimeMasking(time_mask_param=int(32 * self.L))], p=0)
153
+ ])
154
+
155
+ self.video2indexes = {}
156
+ for video_idx in self.dataset:
157
+ video, start_idx = video_idx.split('_')
158
+ if video not in self.video2indexes.keys():
159
+ self.video2indexes[video] = []
160
+ self.video2indexes[video].append(start_idx)
161
+ for video in self.video2indexes.keys():
162
+ if len(self.video2indexes[video]) == 1: # given video contains only one hit
163
+ self.dataset.remove(
164
+ get_GH_data_identifier(video, self.video2indexes[video][0])
165
+ )
166
+
167
+ def __len__(self):
168
+ return len(self.dataset)
169
+
170
+ def __getitem__(self, idx):
171
+ item = {}
172
+
173
+ video_idx = self.dataset[idx]
174
+ spec_path = self.video_idx2path[video_idx]
175
+ spec = np.load(spec_path) # (80, 860)
176
+
177
+ if self.rand_shift:
178
+ shift = random.uniform(0, 0.5)
179
+ spec_shift = int(shift * spec.shape[1] // 10)
180
+ # Since only the first second is used
181
+ spec = np.roll(spec, -spec_shift, 1)
182
+
183
+ # concat spec outside dataload
184
+ item['image'] = 2 * spec - 1 # (80, 860)
185
+ item['image'] = item['image'][:, :self.spec_take_first]
186
+ item['file_path'] = spec_path
187
+
188
+ item['label'] = self.video_idx2label[video_idx]
189
+ item['target'] = self.label2target[item['label']]
190
+
191
+ if self.spec_transforms is not None:
192
+ item = self.spec_transforms(item)
193
+
194
+ return item
195
+
196
+
197
+ class GreatestHitSpecsTrain(GreatestHitSpecs):
198
+ def __init__(self, specs_dataset_cfg):
199
+ super().__init__('train', **specs_dataset_cfg)
200
+
201
+ class GreatestHitSpecsValidation(GreatestHitSpecs):
202
+ def __init__(self, specs_dataset_cfg):
203
+ super().__init__('val', **specs_dataset_cfg)
204
+
205
+ class GreatestHitSpecsTest(GreatestHitSpecs):
206
+ def __init__(self, specs_dataset_cfg):
207
+ super().__init__('test', **specs_dataset_cfg)
208
+
209
+
210
+
211
+ class GreatestHitWave(torch.utils.data.Dataset):
212
+
213
+ def __init__(self, split, wav_dir, random_crop, mel_num, spec_crop_len, spec_len,
214
+ L=2.0, splits_path='./data', rand_shift=True,
215
+ data_path='data/greatesthit/greatesthit-process-resized'):
216
+ super().__init__()
217
+ self.split = split
218
+ self.wav_dir = wav_dir
219
+ self.splits_path = splits_path
220
+ self.data_path = data_path
221
+ self.L = L
222
+ self.rand_shift = rand_shift
223
+
224
+ split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
225
+ if not os.path.exists(split_clip_ids_path):
226
+ raise NotImplementedError()
227
+ clip_video_hit = json.load(open(split_clip_ids_path, 'r'))
228
+
229
+ video_name = list(set([vidx.split('_')[0] for vidx in clip_video_hit]))
230
+
231
+ self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) // 2 for v in video_name}
232
+ self.left_over = int(FPS * L + 1)
233
+ self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_denoised_resampled.wav') for v in video_name}
234
+ self.dataset = clip_video_hit
235
+
236
+ self.video2indexes = {}
237
+ for video_idx in self.dataset:
238
+ video, start_idx = video_idx.split('_')
239
+ if video not in self.video2indexes.keys():
240
+ self.video2indexes[video] = []
241
+ self.video2indexes[video].append(start_idx)
242
+ for video in self.video2indexes.keys():
243
+ if len(self.video2indexes[video]) == 1: # given video contains only one hit
244
+ self.dataset.remove(
245
+ get_GH_data_identifier(video, self.video2indexes[video][0])
246
+ )
247
+
248
+ self.wav_transforms = transforms.Compose([
249
+ MakeMono(),
250
+ Padding(target_len=int(SR * self.L)),
251
+ ])
252
+
253
+ def __len__(self):
254
+ return len(self.dataset)
255
+
256
+ def __getitem__(self, idx):
257
+ item = {}
258
+ video_idx = self.dataset[idx]
259
+ video, start_idx = video_idx.split('_')
260
+ start_idx = int(start_idx)
261
+ if self.rand_shift:
262
+ shift = int(random.uniform(-0.5, 0.5) * SR)
263
+ start_idx = non_negative(start_idx + shift)
264
+
265
+ wave_path = self.video_audio_path[video]
266
+ wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_idx)
267
+ assert sr == SR
268
+ wav = self.wav_transforms(wav)
269
+
270
+ item['image'] = wav # (44100,)
271
+ # item['wav'] = wav
272
+ item['file_path_wav_'] = wave_path
273
+
274
+ item['label'] = 'None'
275
+ item['target'] = 'None'
276
+
277
+ return item
278
+
279
+
280
+ class GreatestHitWaveTrain(GreatestHitWave):
281
+ def __init__(self, specs_dataset_cfg):
282
+ super().__init__('train', **specs_dataset_cfg)
283
+
284
+ class GreatestHitWaveValidation(GreatestHitWave):
285
+ def __init__(self, specs_dataset_cfg):
286
+ super().__init__('val', **specs_dataset_cfg)
287
+
288
+ class GreatestHitWaveTest(GreatestHitWave):
289
+ def __init__(self, specs_dataset_cfg):
290
+ super().__init__('test', **specs_dataset_cfg)
291
+
292
+
293
+ class CondGreatestHitSpecsCondOnImage(torch.utils.data.Dataset):
294
+
295
+ def __init__(self, split, specs_dir, spec_len, feat_len, feat_depth, feat_crop_len, random_crop, mel_num, spec_crop_len,
296
+ vqgan_L=10.0, L=1.0, rand_shift=False, spec_transforms=None, frame_transforms=None, splits_path='./data',
297
+ meta_path='./data/info_r2plus1d_dim1024_15fps.json', frame_path='data/greatesthit/greatesthit_processed',
298
+ p_outside_cond=0., p_audio_aug=0.5):
299
+ super().__init__()
300
+ self.split = split
301
+ self.specs_dir = specs_dir
302
+ self.spec_transforms = spec_transforms
303
+ self.frame_transforms = frame_transforms
304
+ self.splits_path = splits_path
305
+ self.meta_path = meta_path
306
+ self.frame_path = frame_path
307
+ self.feat_len = feat_len
308
+ self.feat_depth = feat_depth
309
+ self.feat_crop_len = feat_crop_len
310
+ self.spec_len = spec_len
311
+ self.rand_shift = rand_shift
312
+ self.L = L
313
+ self.spec_take_first = int(math.ceil(860 * (vqgan_L / 10.) / 32) * 32)
314
+ self.spec_take_first = 860 if self.spec_take_first > 860 else self.spec_take_first
315
+ self.p_outside_cond = torch.tensor(p_outside_cond)
316
+
317
+ greatesthit_meta = json.load(open(self.meta_path, 'r'))
318
+ unique_classes = sorted(list(set(ht for ht in greatesthit_meta['hit_type'])))
319
+ self.label2target = {label: target for target, label in enumerate(unique_classes)}
320
+ self.target2label = {target: label for label, target in self.label2target.items()}
321
+ self.video_idx2label = {
322
+ get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
323
+ greatesthit_meta['hit_type'][i] for i in range(len(greatesthit_meta['video_name']))
324
+ }
325
+ self.available_video_hit = list(self.video_idx2label.keys())
326
+ self.video_idx2path = {
327
+ vh: os.path.join(self.specs_dir,
328
+ vh.replace('_', '_denoised_') + '_' + self.video_idx2label[vh].replace(' ', '_') +'_mel.npy')
329
+ for vh in self.available_video_hit
330
+ }
331
+ for value in self.video_idx2path.values():
332
+ assert os.path.exists(value)
333
+ self.video_idx2idx = {
334
+ get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
335
+ i for i in range(len(greatesthit_meta['video_name']))
336
+ }
337
+
338
+ split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
339
+ if not os.path.exists(split_clip_ids_path):
340
+ self.make_split_files()
341
+ clip_video_hit = json.load(open(split_clip_ids_path, 'r'))
342
+ self.dataset = clip_video_hit
343
+ spec_crop_len = self.spec_take_first if self.spec_take_first <= spec_crop_len else spec_crop_len
344
+ self.spec_transforms = transforms.Compose([
345
+ CropImage([mel_num, spec_crop_len], random_crop),
346
+ # transforms.RandomApply([FrequencyMasking(freq_mask_param=20)], p=p_audio_aug),
347
+ # transforms.RandomApply([TimeMasking(time_mask_param=int(32 * self.L))], p=p_audio_aug)
348
+ ])
349
+ if self.frame_transforms is None:
350
+ self.frame_transforms = transforms.Compose([
351
+ Resize3D(128),
352
+ RandomResizedCrop3D(112, scale=(0.5, 1.0)),
353
+ RandomHorizontalFlip3D(),
354
+ ColorJitter3D(brightness=0.1, saturation=0.1),
355
+ ToTensor3D(),
356
+ Normalize3D(mean=[0.485, 0.456, 0.406],
357
+ std=[0.229, 0.224, 0.225]),
358
+ ])
359
+
360
+ self.video2indexes = {}
361
+ for video_idx in self.dataset:
362
+ video, start_idx = video_idx.split('_')
363
+ if video not in self.video2indexes.keys():
364
+ self.video2indexes[video] = []
365
+ self.video2indexes[video].append(start_idx)
366
+ for video in self.video2indexes.keys():
367
+ if len(self.video2indexes[video]) == 1: # given video contains only one hit
368
+ self.dataset.remove(
369
+ get_GH_data_identifier(video, self.video2indexes[video][0])
370
+ )
371
+
372
+ clip_classes = [self.label2target[self.video_idx2label[vh]] for vh in clip_video_hit]
373
+ class2count = collections.Counter(clip_classes)
374
+ self.class_counts = torch.tensor([class2count[cls] for cls in range(len(class2count))])
375
+ if self.L != 1.0:
376
+ print(split, L)
377
+ self.validate_data()
378
+ self.video2indexes = {}
379
+ for video_idx in self.dataset:
380
+ video, start_idx = video_idx.split('_')
381
+ if video not in self.video2indexes.keys():
382
+ self.video2indexes[video] = []
383
+ self.video2indexes[video].append(start_idx)
384
+
385
+ def __len__(self):
386
+ return len(self.dataset)
387
+
388
+ def __getitem__(self, idx):
389
+ item = {}
390
+
391
+ try:
392
+ video_idx = self.dataset[idx]
393
+ spec_path = self.video_idx2path[video_idx]
394
+ spec = np.load(spec_path) # (80, 860)
395
+
396
+ video, start_idx = video_idx.split('_')
397
+ frame_path = os.path.join(self.frame_path, video, 'frames')
398
+ start_frame_idx = non_negative(FPS * int(start_idx)/SR)
399
+ end_frame_idx = non_negative(start_frame_idx + FPS * self.L)
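+ # e.g. (illustrative) start_idx=44100 audio samples at SR=22050 is 2.0 s, i.e. frame 30 at FPS=15;
+ # with L=2.0 the clip then spans frames [30, 60)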
400
+
401
+ if self.rand_shift:
402
+ shift = random.uniform(0, 0.5)
403
+ spec_shift = int(shift * spec.shape[1] // 10)
404
+ # Since only the first second is used
405
+ spec = np.roll(spec, -spec_shift, 1)
406
+ start_frame_idx += int(FPS * shift)
407
+ end_frame_idx += int(FPS * shift)
408
+
409
+ frames = [Image.open(os.path.join(
410
+ frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
411
+ range(start_frame_idx, end_frame_idx)]
412
+
413
+ # Sample condition
414
+ if torch.all(torch.bernoulli(self.p_outside_cond) == 1.):
415
+ # Sample condition from outside video
416
+ all_idx = set(list(range(len(self.dataset))))
417
+ all_idx.remove(idx)
418
+ cond_video_idx = self.dataset[sample(list(all_idx), k=1)[0]]  # random.sample no longer accepts sets (Python >= 3.11)
419
+ cond_video, cond_start_idx = cond_video_idx.split('_')
420
+ else:
421
+ cond_video = video
422
+ video_hits_idx = copy.copy(self.video2indexes[video])
423
+ video_hits_idx.remove(start_idx)
424
+ cond_start_idx = sample(video_hits_idx, k=1)[0]
425
+ cond_video_idx = get_GH_data_identifier(cond_video, cond_start_idx)
426
+
427
+ cond_spec_path = self.video_idx2path[cond_video_idx]
428
+ cond_spec = np.load(cond_spec_path) # (80, 860)
429
+
430
+ cond_video, cond_start_idx = cond_video_idx.split('_')
431
+ cond_frame_path = os.path.join(self.frame_path, cond_video, 'frames')
432
+ cond_start_frame_idx = non_negative(FPS * int(cond_start_idx)/SR)
433
+ cond_end_frame_idx = non_negative(cond_start_frame_idx + FPS * self.L)
434
+
435
+ if self.rand_shift:
436
+ cond_shift = random.uniform(0, 0.5)
437
+ cond_spec_shift = int(cond_shift * cond_spec.shape[1] // 10)
438
+ # Since only the first second is used
439
+ cond_spec = np.roll(cond_spec, -cond_spec_shift, 1)
440
+ cond_start_frame_idx += int(FPS * cond_shift)
441
+ cond_end_frame_idx += int(FPS * cond_shift)
442
+
443
+ cond_frames = [Image.open(os.path.join(
444
+ cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
445
+ range(cond_start_frame_idx, cond_end_frame_idx)]
446
+
447
+ # concat spec outside dataload
448
+ item['image'] = 2 * spec - 1 # (80, 860)
449
+ item['cond_image'] = 2 * cond_spec - 1 # (80, 860)
450
+ item['image'] = item['image'][:, :self.spec_take_first]
451
+ item['cond_image'] = item['cond_image'][:, :self.spec_take_first]
452
+ item['file_path_specs_'] = spec_path
453
+ item['file_path_cond_specs_'] = cond_spec_path
454
+
455
+ if self.frame_transforms is not None:
456
+ cond_frames = self.frame_transforms(cond_frames)
457
+ frames = self.frame_transforms(frames)
458
+
459
+ item['feature'] = np.stack(cond_frames + frames, axis=0) # (30 * L, 112, 112, 3)
460
+ item['file_path_feats_'] = (frame_path, start_frame_idx)
461
+ item['file_path_cond_feats_'] = (cond_frame_path, cond_start_frame_idx)
462
+
463
+ item['label'] = self.video_idx2label[video_idx]
464
+ item['target'] = self.label2target[item['label']]
465
+
466
+ if self.spec_transforms is not None:
467
+ item = self.spec_transforms(item)
468
+ except Exception:
469
+ print(sys.exc_info()[2])
470
+ print('!!!!!!!!!!!!!!!!!!!!', video_idx, cond_video_idx)
471
+ print('!!!!!!!!!!!!!!!!!!!!', end_frame_idx, cond_end_frame_idx)
472
+ exit(1)
473
+
474
+ return item
475
+
476
+
477
+ def validate_data(self):
478
+ original_len = len(self.dataset)
479
+ valid_dataset = []
480
+ for video_idx in tqdm(self.dataset):
481
+ video, start_idx = video_idx.split('_')
482
+ frame_path = os.path.join(self.frame_path, video, 'frames')
483
+ start_frame_idx = non_negative(FPS * int(start_idx)/SR)
484
+ end_frame_idx = non_negative(start_frame_idx + FPS * (self.L + 0.6))
485
+ if os.path.exists(os.path.join(frame_path, f'frame{end_frame_idx:0>6d}.jpg')):
486
+ valid_dataset.append(video_idx)
487
+ else:
488
+ self.video2indexes[video].remove(start_idx)
489
+ for video_idx in valid_dataset:
490
+ video, start_idx = video_idx.split('_')
491
+ if len(self.video2indexes[video]) == 1:
492
+ valid_dataset.remove(video_idx)
493
+ if original_len != len(valid_dataset):
494
+ print(f'Validated dataset with enough frames: {len(valid_dataset)}')
495
+ self.dataset = valid_dataset
496
+ split_clip_ids_path = os.path.join(self.splits_path, f'greatesthit_{self.split}_{self.L:.2f}.json')
497
+ if not os.path.exists(split_clip_ids_path):
498
+ with open(split_clip_ids_path, 'w') as f:
499
+ json.dump(valid_dataset, f)
500
+
501
+
502
+ def make_split_files(self, ratio=[0.85, 0.1, 0.05]):
503
+ random.seed(1337)
504
+ print(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')
505
+ # The downloaded videos (some went missing on YouTube and no longer available)
506
+ available_mel_paths = set(glob(os.path.join(self.specs_dir, '*_mel.npy')))
507
+ self.available_video_hit = [vh for vh in self.available_video_hit if self.video_idx2path[vh] in available_mel_paths]
508
+
509
+ all_video = list(self.video2indexes.keys())
510
+
511
+ print(f'The number of clips available after download: {len(self.available_video_hit)}')
512
+ print(f'The number of videos available after download: {len(all_video)}')
513
+
514
+ available_idx = list(range(len(all_video)))
515
+ random.shuffle(available_idx)
516
+ assert abs(sum(ratio) - 1.) < 1e-6  # tolerate floating-point rounding in the split ratios
517
+ cut_train = int(ratio[0] * len(all_video))
518
+ cut_test = cut_train + int(ratio[1] * len(all_video))
519
+
520
+ train_idx = available_idx[:cut_train]
521
+ test_idx = available_idx[cut_train:cut_test]
522
+ valid_idx = available_idx[cut_test:]
523
+
524
+ train_video = [all_video[i] for i in train_idx]
525
+ test_video = [all_video[i] for i in test_idx]
526
+ valid_video = [all_video[i] for i in valid_idx]
527
+
528
+ train_video_hit = []
529
+ for v in train_video:
530
+ train_video_hit += [get_GH_data_identifier(v, hit_idx) for hit_idx in self.video2indexes[v]]
531
+ test_video_hit = []
532
+ for v in test_video:
533
+ test_video_hit += [get_GH_data_identifier(v, hit_idx) for hit_idx in self.video2indexes[v]]
534
+ valid_video_hit = []
535
+ for v in valid_video:
536
+ valid_video_hit += [get_GH_data_identifier(v, hit_idx) for hit_idx in self.video2indexes[v]]
537
+
538
+ # mix train and valid for better validation loss
539
+ mixed = train_video_hit + valid_video_hit
540
+ random.shuffle(mixed)
541
+ split = int(len(mixed) * ratio[0] / (ratio[0] + ratio[2]))
542
+ train_video_hit = mixed[:split]
543
+ valid_video_hit = mixed[split:]
544
+
545
+ with open(os.path.join(self.splits_path, 'greatesthit_train.json'), 'w') as train_file,\
546
+ open(os.path.join(self.splits_path, 'greatesthit_test.json'), 'w') as test_file,\
547
+ open(os.path.join(self.splits_path, 'greatesthit_valid.json'), 'w') as valid_file:
548
+ json.dump(train_video_hit, train_file)
549
+ json.dump(test_video_hit, test_file)
550
+ json.dump(valid_video_hit, valid_file)
551
+
552
+ print(f'Put {len(train_idx)} clips to the train set and saved it to ./data/greatesthit_train.json')
553
+ print(f'Put {len(test_idx)} clips to the test set and saved it to ./data/greatesthit_test.json')
554
+ print(f'Put {len(valid_idx)} clips to the valid set and saved it to ./data/greatesthit_valid.json')
555
+
556
+
557
+ class CondGreatestHitSpecsCondOnImageTrain(CondGreatestHitSpecsCondOnImage):
558
+ def __init__(self, dataset_cfg):
559
+ train_transforms = transforms.Compose([
560
+ Resize3D(256),
561
+ RandomResizedCrop3D(224, scale=(0.5, 1.0)),
562
+ RandomHorizontalFlip3D(),
563
+ ColorJitter3D(brightness=0.1, saturation=0.1),
564
+ ToTensor3D(),
565
+ Normalize3D(mean=[0.485, 0.456, 0.406],
566
+ std=[0.229, 0.224, 0.225]),
567
+ ])
568
+ super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)
569
+
570
+ class CondGreatestHitSpecsCondOnImageValidation(CondGreatestHitSpecsCondOnImage):
571
+ def __init__(self, dataset_cfg):
572
+ valid_transforms = transforms.Compose([
573
+ Resize3D(256),
574
+ CenterCrop3D(224),
575
+ ToTensor3D(),
576
+ Normalize3D(mean=[0.485, 0.456, 0.406],
577
+ std=[0.229, 0.224, 0.225]),
578
+ ])
579
+ super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)
580
+
581
+ class CondGreatestHitSpecsCondOnImageTest(CondGreatestHitSpecsCondOnImage):
582
+ def __init__(self, dataset_cfg):
583
+ test_transforms = transforms.Compose([
584
+ Resize3D(256),
585
+ CenterCrop3D(224),
586
+ ToTensor3D(),
587
+ Normalize3D(mean=[0.485, 0.456, 0.406],
588
+ std=[0.229, 0.224, 0.225]),
589
+ ])
590
+ super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
591
+
592
+
593
+ class CondGreatestHitWaveCondOnImage(torch.utils.data.Dataset):
594
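+ # Waveform dataset that pairs an L-second target clip (audio + frames) with a conditioning clip taken from the same video or, with probability p_outside_cond, from a different video.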
+
595
+ def __init__(self, split, wav_dir, spec_len, random_crop, mel_num, spec_crop_len,
596
+ L=2.0, frame_transforms=None, splits_path='./data',
597
+ data_path='data/greatesthit/greatesthit-process-resized',
598
+ p_outside_cond=0., p_audio_aug=0.5, rand_shift=True):
599
+ super().__init__()
600
+ self.split = split
601
+ self.wav_dir = wav_dir
602
+ self.frame_transforms = frame_transforms
603
+ self.splits_path = splits_path
604
+ self.data_path = data_path
605
+ self.spec_len = spec_len
606
+ self.L = L
607
+ self.rand_shift = rand_shift
608
+ self.p_outside_cond = torch.tensor(p_outside_cond)
609
+
610
+ split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
611
+ if not os.path.exists(split_clip_ids_path):
612
+ raise NotImplementedError()
613
+ clip_video_hit = json.load(open(split_clip_ids_path, 'r'))
614
+
615
+ video_name = list(set([vidx.split('_')[0] for vidx in clip_video_hit]))
616
+
617
+ self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames')))//2 for v in video_name}
618
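+ # number of frames reserved at the end of each video so that a full L-second window always fits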
+ self.left_over = int(FPS * L + 1)
619
+ self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_denoised_resampled.wav') for v in video_name}
620
+ self.dataset = clip_video_hit
621
+
622
+ self.video2indexes = {}
623
+ for video_idx in self.dataset:
624
+ video, start_idx = video_idx.split('_')
625
+ if video not in self.video2indexes.keys():
626
+ self.video2indexes[video] = []
627
+ self.video2indexes[video].append(start_idx)
628
+ for video in self.video2indexes.keys():
629
+ if len(self.video2indexes[video]) == 1: # given video contains only one hit
630
+ self.dataset.remove(
631
+ get_GH_data_identifier(video, self.video2indexes[video][0])
632
+ )
633
+
634
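+ # audio preprocessing: downmix to mono and zero-pad to exactly SR * L samples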
+ self.wav_transforms = transforms.Compose([
635
+ MakeMono(),
636
+ Padding(target_len=int(SR * self.L)),
637
+ ])
638
+ if self.frame_transforms is None:
639
+ self.frame_transforms = transforms.Compose([
640
+ Resize3D(256),
641
+ RandomResizedCrop3D(224, scale=(0.5, 1.0)),
642
+ RandomHorizontalFlip3D(),
643
+ ColorJitter3D(brightness=0.1, saturation=0.1),
644
+ ToTensor3D(),
645
+ Normalize3D(mean=[0.485, 0.456, 0.406],
646
+ std=[0.229, 0.224, 0.225]),
647
+ ])
648
+
649
+ def __len__(self):
650
+ return len(self.dataset)
651
+
652
+ def __getitem__(self, idx):
653
+ item = {}
654
+ video_idx = self.dataset[idx]
655
+ video, start_idx = video_idx.split('_')
656
+ start_idx = int(start_idx)
657
+ frame_path = os.path.join(self.data_path, video, 'frames')
658
+ start_frame_idx = non_negative(FPS * int(start_idx)/SR)
659
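+ # optionally jitter the window start by up to +/-0.5 s, clamping so the clip stays inside the video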
+ if self.rand_shift:
660
+ shift = random.uniform(-0.5, 0.5)
661
+ start_frame_idx = non_negative(start_frame_idx + int(FPS * shift))
662
+ start_idx = non_negative(start_idx + int(SR * shift))
663
+ if start_frame_idx > self.video_frame_cnt[video] - self.left_over:
664
+ start_frame_idx = self.video_frame_cnt[video] - self.left_over
665
+ start_idx = non_negative(SR * (start_frame_idx / FPS))
666
+
667
+ end_frame_idx = non_negative(start_frame_idx + FPS * self.L)
668
+
669
+ # target
670
+ wave_path = self.video_audio_path[video]
671
+ frames = [Image.open(os.path.join(
672
+ frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
673
+ range(start_frame_idx, end_frame_idx)]
674
+ wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_idx)
675
+ assert sr == SR
676
+ wav = self.wav_transforms(wav)
677
+
678
+ # cond
679
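+ # with probability p_outside_cond take the conditioning clip from another video; otherwise use a different hit of the same video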
+ if torch.all(torch.bernoulli(self.p_outside_cond) == 1.):
680
+ all_idx = set(list(range(len(self.dataset))))
681
+ all_idx.remove(idx)
682
+ cond_video_idx = self.dataset[sample(list(all_idx), k=1)[0]] # sample() needs a sequence, not a set
683
+ cond_video, cond_start_idx = cond_video_idx.split('_')
684
+ else:
685
+ cond_video = video
686
+ video_hits_idx = copy.copy(self.video2indexes[video])
687
+ if str(start_idx) in video_hits_idx:
688
+ video_hits_idx.remove(str(start_idx))
689
+ cond_start_idx = sample(video_hits_idx, k=1)[0]
690
+ cond_video_idx = get_GH_data_identifier(cond_video, cond_start_idx)
691
+
692
+ cond_video, cond_start_idx = cond_video_idx.split('_')
693
+ cond_start_idx = int(cond_start_idx)
694
+ cond_frame_path = os.path.join(self.data_path, cond_video, 'frames')
695
+ cond_start_frame_idx = non_negative(FPS * int(cond_start_idx)/SR)
696
+ cond_wave_path = self.video_audio_path[cond_video]
697
+
698
+ if self.rand_shift:
699
+ cond_shift = random.uniform(-0.5, 0.5)
700
+ cond_start_frame_idx = non_negative(cond_start_frame_idx + int(FPS * cond_shift))
701
+ cond_start_idx = non_negative(cond_start_idx + int(cond_shift * SR)) # use the conditioning clip's own shift, not the target clip's
702
+ if cond_start_frame_idx > self.video_frame_cnt[cond_video] - self.left_over:
703
+ cond_start_frame_idx = self.video_frame_cnt[cond_video] - self.left_over
704
+ cond_start_idx = non_negative(SR * (cond_start_frame_idx / FPS))
705
+ cond_end_frame_idx = non_negative(cond_start_frame_idx + FPS * self.L)
706
+
707
+ cond_frames = [Image.open(os.path.join(
708
+ cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
709
+ range(cond_start_frame_idx, cond_end_frame_idx)]
710
+ cond_wav, _ = soundfile.read(cond_wave_path, frames=int(SR * self.L), start=cond_start_idx)
711
+ cond_wav = self.wav_transforms(cond_wav)
712
+
713
+ item['image'] = wav # (44100,)
714
+ item['cond_image'] = cond_wav # (44100,)
715
+ item['file_path_wav_'] = wave_path
716
+ item['file_path_cond_wav_'] = cond_wave_path
717
+
718
+ if self.frame_transforms is not None:
719
+ cond_frames = self.frame_transforms(cond_frames)
720
+ frames = self.frame_transforms(frames)
721
+
722
+ item['feature'] = np.stack(cond_frames + frames, axis=0) # (30 * L, 112, 112, 3)
723
+ item['file_path_feats_'] = (frame_path, start_idx)
724
+ item['file_path_cond_feats_'] = (cond_frame_path, cond_start_idx)
725
+
726
+ item['label'] = 'None'
727
+ item['target'] = 'None'
728
+
729
+ return item
730
+
731
+ def validate_data(self):
732
+ raise NotImplementedError()
733
+
734
+ def make_split_files(self, ratio=[0.85, 0.1, 0.05]):
735
+ random.seed(1337)
736
+ print(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')
737
+
738
+ all_video = sorted(os.listdir(self.data_path))
739
+ print(f'The number of videos available after download: {len(all_video)}')
740
+
741
+ available_idx = list(range(len(all_video)))
742
+ random.shuffle(available_idx)
743
+ assert sum(ratio) == 1.
744
+ cut_train = int(ratio[0] * len(all_video))
745
+ cut_test = cut_train + int(ratio[1] * len(all_video))
746
+
747
+ train_idx = available_idx[:cut_train]
748
+ test_idx = available_idx[cut_train:cut_test]
749
+ valid_idx = available_idx[cut_test:]
750
+
751
+ train_video = [all_video[i] for i in train_idx]
752
+ test_video = [all_video[i] for i in test_idx]
753
+ valid_video = [all_video[i] for i in valid_idx]
754
+
755
+ with open(os.path.join(self.splits_path, 'greatesthit_video_train.json'), 'w') as train_file,\
756
+ open(os.path.join(self.splits_path, 'greatesthit_video_test.json'), 'w') as test_file,\
757
+ open(os.path.join(self.splits_path, 'greatesthit_video_valid.json'), 'w') as valid_file:
758
+ json.dump(train_video, train_file)
759
+ json.dump(test_video, test_file)
760
+ json.dump(valid_video, valid_file)
761
+
762
+ print(f'Put {len(train_idx)} videos to the train set and saved it to ./data/greatesthit_video_train.json')
763
+ print(f'Put {len(test_idx)} videos to the test set and saved it to ./data/greatesthit_video_test.json')
764
+ print(f'Put {len(valid_idx)} videos to the valid set and saved it to ./data/greatesthit_video_valid.json')
765
+
766
+
767
+ class CondGreatestHitWaveCondOnImageTrain(CondGreatestHitWaveCondOnImage):
768
+ def __init__(self, dataset_cfg):
769
+ train_transforms = transforms.Compose([
770
+ Resize3D(128),
771
+ RandomResizedCrop3D(112, scale=(0.5, 1.0)),
772
+ RandomHorizontalFlip3D(),
773
+ ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
774
+ ToTensor3D(),
775
+ Normalize3D(mean=[0.485, 0.456, 0.406],
776
+ std=[0.229, 0.224, 0.225]),
777
+ ])
778
+ super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)
779
+
780
+ class CondGreatestHitWaveCondOnImageValidation(CondGreatestHitWaveCondOnImage):
781
+ def __init__(self, dataset_cfg):
782
+ valid_transforms = transforms.Compose([
783
+ Resize3D(128),
784
+ CenterCrop3D(112),
785
+ ToTensor3D(),
786
+ Normalize3D(mean=[0.485, 0.456, 0.406],
787
+ std=[0.229, 0.224, 0.225]),
788
+ ])
789
+ super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)
790
+
791
+ class CondGreatestHitWaveCondOnImageTest(CondGreatestHitWaveCondOnImage):
792
+ def __init__(self, dataset_cfg):
793
+ test_transforms = transforms.Compose([
794
+ Resize3D(128),
795
+ CenterCrop3D(112),
796
+ ToTensor3D(),
797
+ Normalize3D(mean=[0.485, 0.456, 0.406],
798
+ std=[0.229, 0.224, 0.225]),
799
+ ])
800
+ super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
801
+
802
+
803
+
804
+ class GreatestHitWaveCondOnImage(torch.utils.data.Dataset):
805
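+ # Single-clip variant: returns only the target waveform and its frames, without an extra audio-visual conditioning clip.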
+
806
+ def __init__(self, split, wav_dir, spec_len, random_crop, mel_num, spec_crop_len,
807
+ L=2.0, frame_transforms=None, splits_path='./data',
808
+ data_path='data/greatesthit/greatesthit-process-resized',
809
+ p_outside_cond=0., p_audio_aug=0.5, rand_shift=True):
810
+ super().__init__()
811
+ self.split = split
812
+ self.wav_dir = wav_dir
813
+ self.frame_transforms = frame_transforms
814
+ self.splits_path = splits_path
815
+ self.data_path = data_path
816
+ self.spec_len = spec_len
817
+ self.L = L
818
+ self.rand_shift = rand_shift
819
+ self.p_outside_cond = torch.tensor(p_outside_cond)
820
+
821
+ split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
822
+ if not os.path.exists(split_clip_ids_path):
823
+ raise NotImplementedError()
824
+ clip_video_hit = json.load(open(split_clip_ids_path, 'r'))
825
+
826
+ video_name = list(set([vidx.split('_')[0] for vidx in clip_video_hit]))
827
+
828
+ self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames')))//2 for v in video_name}
829
+ self.left_over = int(FPS * L + 1)
830
+ self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_denoised_resampled.wav') for v in video_name}
831
+ self.dataset = clip_video_hit
832
+
833
+ self.video2indexes = {}
834
+ for video_idx in self.dataset:
835
+ video, start_idx = video_idx.split('_')
836
+ if video not in self.video2indexes.keys():
837
+ self.video2indexes[video] = []
838
+ self.video2indexes[video].append(start_idx)
839
+ for video in self.video2indexes.keys():
840
+ if len(self.video2indexes[video]) == 1: # given video contains only one hit
841
+ self.dataset.remove(
842
+ get_GH_data_identifier(video, self.video2indexes[video][0])
843
+ )
844
+
845
+ self.wav_transforms = transforms.Compose([
846
+ MakeMono(),
847
+ Padding(target_len=int(SR * self.L)),
848
+ ])
849
+ if self.frame_transforms is None:
850
+ self.frame_transforms = transforms.Compose([
851
+ Resize3D(256),
852
+ RandomResizedCrop3D(224, scale=(0.5, 1.0)),
853
+ RandomHorizontalFlip3D(),
854
+ ColorJitter3D(brightness=0.1, saturation=0.1),
855
+ ToTensor3D(),
856
+ Normalize3D(mean=[0.485, 0.456, 0.406],
857
+ std=[0.229, 0.224, 0.225]),
858
+ ])
859
+
860
+ def __len__(self):
861
+ return len(self.dataset)
862
+
863
+ def __getitem__(self, idx):
864
+ item = {}
865
+ video_idx = self.dataset[idx]
866
+ video, start_idx = video_idx.split('_')
867
+ start_idx = int(start_idx)
868
+ frame_path = os.path.join(self.data_path, video, 'frames')
869
+ start_frame_idx = non_negative(FPS * int(start_idx)/SR)
870
+ if self.rand_shift:
871
+ shift = random.uniform(-0.5, 0.5)
872
+ start_frame_idx = non_negative(start_frame_idx + int(FPS * shift))
873
+ start_idx = non_negative(start_idx + int(SR * shift))
874
+ if start_frame_idx > self.video_frame_cnt[video] - self.left_over:
875
+ start_frame_idx = self.video_frame_cnt[video] - self.left_over
876
+ start_idx = non_negative(SR * (start_frame_idx / FPS))
877
+
878
+ end_frame_idx = non_negative(start_frame_idx + FPS * self.L)
879
+
880
+ # target
881
+ wave_path = self.video_audio_path[video]
882
+ frames = [Image.open(os.path.join(
883
+ frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
884
+ range(start_frame_idx, end_frame_idx)]
885
+ wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_idx)
886
+ assert sr == SR
887
+ wav = self.wav_transforms(wav)
888
+
889
+ item['image'] = wav # (44100,)
890
+ item['file_path_wav_'] = wave_path
891
+
892
+ if self.frame_transforms is not None:
893
+ frames = self.frame_transforms(frames)
894
+
895
+ item['feature'] = torch.stack(frames, dim=0) # (15 * L, 112, 112, 3)
896
+ item['file_path_feats_'] = (frame_path, start_idx)
897
+
898
+ item['label'] = 'None'
899
+ item['target'] = 'None'
900
+
901
+ return item
902
+
903
+ def validate_data(self):
904
+ raise NotImplementedError()
905
+
906
+ def make_split_files(self, ratio=[0.85, 0.1, 0.05]):
907
+ random.seed(1337)
908
+ print(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')
909
+
910
+ all_video = sorted(os.listdir(self.data_path))
911
+ print(f'The number of videos available after download: {len(all_video)}')
912
+
913
+ available_idx = list(range(len(all_video)))
914
+ random.shuffle(available_idx)
915
+ assert sum(ratio) == 1.
916
+ cut_train = int(ratio[0] * len(all_video))
917
+ cut_test = cut_train + int(ratio[1] * len(all_video))
918
+
919
+ train_idx = available_idx[:cut_train]
920
+ test_idx = available_idx[cut_train:cut_test]
921
+ valid_idx = available_idx[cut_test:]
922
+
923
+ train_video = [all_video[i] for i in train_idx]
924
+ test_video = [all_video[i] for i in test_idx]
925
+ valid_video = [all_video[i] for i in valid_idx]
926
+
927
+ with open(os.path.join(self.splits_path, 'greatesthit_video_train.json'), 'w') as train_file,\
928
+ open(os.path.join(self.splits_path, 'greatesthit_video_test.json'), 'w') as test_file,\
929
+ open(os.path.join(self.splits_path, 'greatesthit_video_valid.json'), 'w') as valid_file:
930
+ json.dump(train_video, train_file)
931
+ json.dump(test_video, test_file)
932
+ json.dump(valid_video, valid_file)
933
+
934
+ print(f'Put {len(train_idx)} videos to the train set and saved it to ./data/greatesthit_video_train.json')
935
+ print(f'Put {len(test_idx)} videos to the test set and saved it to ./data/greatesthit_video_test.json')
936
+ print(f'Put {len(valid_idx)} videos to the valid set and saved it to ./data/greatesthit_video_valid.json')
937
+
938
+
939
+ class GreatestHitWaveCondOnImageTrain(GreatestHitWaveCondOnImage):
940
+ def __init__(self, dataset_cfg):
941
+ train_transforms = transforms.Compose([
942
+ Resize3D(128),
943
+ RandomResizedCrop3D(112, scale=(0.5, 1.0)),
944
+ RandomHorizontalFlip3D(),
945
+ ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
946
+ ToTensor3D(),
947
+ Normalize3D(mean=[0.485, 0.456, 0.406],
948
+ std=[0.229, 0.224, 0.225]),
949
+ ])
950
+ super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)
951
+
952
+ class GreatestHitWaveCondOnImageValidation(GreatestHitWaveCondOnImage):
953
+ def __init__(self, dataset_cfg):
954
+ valid_transforms = transforms.Compose([
955
+ Resize3D(128),
956
+ CenterCrop3D(112),
957
+ ToTensor3D(),
958
+ Normalize3D(mean=[0.485, 0.456, 0.406],
959
+ std=[0.229, 0.224, 0.225]),
960
+ ])
961
+ super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)
962
+
963
+ class GreatestHitWaveCondOnImageTest(GreatestHitWaveCondOnImage):
964
+ def __init__(self, dataset_cfg):
965
+ test_transforms = transforms.Compose([
966
+ Resize3D(128),
967
+ CenterCrop3D(112),
968
+ ToTensor3D(),
969
+ Normalize3D(mean=[0.485, 0.456, 0.406],
970
+ std=[0.229, 0.224, 0.225]),
971
+ ])
972
+ super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
973
+
974
+
975
+ def draw_spec(spec, dest, cmap='magma'):
976
+ plt.imshow(spec, cmap=cmap, origin='lower')
977
+ plt.axis('off')
978
+ plt.savefig(dest, bbox_inches='tight', pad_inches=0., dpi=300)
979
+ plt.close()
980
+
981
+ if __name__ == '__main__':
982
+ import sys
983
+
984
+ from omegaconf import OmegaConf
985
+
986
+ # cfg = OmegaConf.load('configs/greatesthit_transformer_with_vNet_randshift_2s_GH_vqgan_no_earlystop.yaml')
987
+ cfg = OmegaConf.load('configs/greatesthit_codebook.yaml')
988
+ data = instantiate_from_config(cfg.data)
989
+ data.prepare_data()
990
+ data.setup()
991
+ print(len(data.datasets['train']))
992
+ print(data.datasets['train'][24])
993
+
foleycrafter/models/specvqgan/data/impactset.py ADDED
@@ -0,0 +1,778 @@
1
+ import json
2
+ import os
3
+ import matplotlib.pyplot as plt
4
+ import torch
5
+ from torchvision import transforms
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from random import sample
9
+ import torchaudio
10
+ import logging
11
+ from glob import glob
12
+ import sys
13
+ import soundfile
14
+ import copy
15
+ import csv
16
+ import noisereduce as nr
17
+
18
+ sys.path.insert(0, '.') # nopep8
19
+ from train import instantiate_from_config
20
+ from foleycrafter.models.specvqgan.data.transforms import *
21
+
22
+ torchaudio.set_audio_backend("sox_io")
23
+ logger = logging.getLogger(f'main.{__name__}')
24
+
25
+ SR = 22050
26
+ FPS = 15
27
+ MAX_SAMPLE_ITER = 10
28
+
29
+ def non_negative(x): return int(np.round(max(0, x), 0))
30
+
31
+ def rms(x): return np.sqrt(np.mean(x**2))
32
+
33
+ def get_GH_data_identifier(video_name, start_idx, split='_'):
34
+ if isinstance(start_idx, str):
35
+ return video_name + split + start_idx
36
+ elif isinstance(start_idx, int):
37
+ return video_name + split + str(start_idx)
38
+ else:
39
+ raise NotImplementedError
40
+
41
+ def draw_spec(spec, dest, cmap='magma'):
42
+ plt.imshow(spec, cmap=cmap, origin='lower')
43
+ plt.axis('off')
44
+ plt.savefig(dest, bbox_inches='tight', pad_inches=0., dpi=300)
45
+ plt.close()
46
+
47
+ def convert_to_decibel(arr):
48
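+ # 20 * log10(|x| / ref); the 1e-4 offset keeps log10 finite for silent samples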
+ ref = 1
49
+ return 20 * np.log10(abs(arr + 1e-4) / ref)
50
+
51
+ class ResampleFrames(object):
52
+ def __init__(self, feat_sample_size, times_to_repeat_after_resample=None):
53
+ self.feat_sample_size = feat_sample_size
54
+ self.times_to_repeat_after_resample = times_to_repeat_after_resample
55
+
56
+ def __call__(self, item):
57
+ feat_len = item['feature'].shape[0]
58
+
59
+ ## resample
60
+ assert feat_len >= self.feat_sample_size
61
+ # evenly spaced points (abcdefghkl -> aoooofoooo)
62
+ idx = np.linspace(0, feat_len, self.feat_sample_size, dtype=int, endpoint=False) # np.int was removed in NumPy 1.24
63
+ # xoooo xoooo -> ooxoo ooxoo
64
+ shift = feat_len // (self.feat_sample_size + 1)
65
+ idx = idx + shift
66
+
67
+ ## repeat after resampling (abc -> aaaabbbbcccc)
68
+ if self.times_to_repeat_after_resample is not None and self.times_to_repeat_after_resample > 1:
69
+ idx = np.repeat(idx, self.times_to_repeat_after_resample)
70
+
71
+ item['feature'] = item['feature'][idx, :]
72
+ return item
73
+
74
+
75
+ class ImpactSetWave(torch.utils.data.Dataset):
76
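+ # Samples an L-second waveform from a CountixAV clip, retrying up to MAX_SAMPLE_ITER times to find a segment whose mean level is above -40 dB.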
+
77
+ def __init__(self, split, random_crop, mel_num, spec_crop_len,
78
+ L=2.0, denoise=False, splits_path='./data',
79
+ data_path='data/ImpactSet/impactset-proccess-resize'):
80
+ super().__init__()
81
+ self.split = split
82
+ self.splits_path = splits_path
83
+ self.data_path = data_path
84
+ self.L = L
85
+ self.denoise = denoise
86
+
87
+ video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
88
+ if not os.path.exists(video_name_split_path):
89
+ self.make_split_files()
90
+ video_name = json.load(open(video_name_split_path, 'r'))
91
+ self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
92
+ self.left_over = int(FPS * L + 1)
93
+ self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
94
+ self.dataset = video_name
95
+
96
+ self.wav_transforms = transforms.Compose([
97
+ MakeMono(),
98
+ Padding(target_len=int(SR * self.L)),
99
+ ])
100
+
101
+ self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)
102
+
103
+ def __len__(self):
104
+ return len(self.dataset)
105
+
106
+ def __getitem__(self, idx):
107
+ item = {}
108
+ video = self.dataset[idx]
109
+
110
+ available_frame_idx = self.video_frame_cnt[video] - self.left_over
111
+ wav = None
112
+ spec = None
113
+ max_db = -np.inf
114
+ wave_path = ''
115
+ cur_wave_path = self.video_audio_path[video]
116
+ if self.denoise:
117
+ cur_wave_path = cur_wave_path.replace('.wav', '_denoised.wav')
118
+ for _ in range(MAX_SAMPLE_ITER):
119
+ start_idx = torch.randint(0, available_frame_idx, (1,)).tolist()[0]
120
+ # target
121
+ start_t = (start_idx + 0.5) / FPS
122
+ start_audio_idx = non_negative(start_t * SR)
123
+
124
+ cur_wav, _ = soundfile.read(cur_wave_path, frames=int(SR * self.L), start=start_audio_idx)
125
+
126
+ decibel = convert_to_decibel(cur_wav)
127
+ if float(np.mean(decibel)) > max_db:
128
+ wav = cur_wav
129
+ wave_path = cur_wave_path
130
+ max_db = float(np.mean(decibel))
131
+ if max_db >= -40:
132
+ break
133
+
134
+ # print(max_db)
135
+ wav = self.wav_transforms(wav)
136
+ item['image'] = wav # (80, 173)
137
+ # item['wav'] = wav
138
+ item['file_path_wav_'] = wave_path
139
+
140
+ item['label'] = 'None'
141
+ item['target'] = 'None'
142
+
143
+ return item
144
+
145
+ def make_split_files(self):
146
+ raise NotImplementedError
147
+
148
+ class ImpactSetWaveTrain(ImpactSetWave):
149
+ def __init__(self, specs_dataset_cfg):
150
+ super().__init__('train', **specs_dataset_cfg)
151
+
152
+ class ImpactSetWaveValidation(ImpactSetWave):
153
+ def __init__(self, specs_dataset_cfg):
154
+ super().__init__('val', **specs_dataset_cfg)
155
+
156
+ class ImpactSetWaveTest(ImpactSetWave):
157
+ def __init__(self, specs_dataset_cfg):
158
+ super().__init__('test', **specs_dataset_cfg)
159
+
160
+
161
+ class ImpactSetSpec(torch.utils.data.Dataset):
162
+
163
+ def __init__(self, split, random_crop, mel_num, spec_crop_len,
164
+ L=2.0, denoise=False, splits_path='./data',
165
+ data_path='data/ImpactSet/impactset-proccess-resize'):
166
+ super().__init__()
167
+ self.split = split
168
+ self.splits_path = splits_path
169
+ self.data_path = data_path
170
+ self.L = L
171
+ self.denoise = denoise
172
+
173
+ video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
174
+ if not os.path.exists(video_name_split_path):
175
+ self.make_split_files()
176
+ video_name = json.load(open(video_name_split_path, 'r'))
177
+ self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
178
+ self.left_over = int(FPS * L + 1)
179
+ self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
180
+ self.dataset = video_name
181
+
182
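+ # waveform -> log-mel spectrogram pipeline: STFT, 80-bin mel scale, log compression, rescaling to [0, 1], and trimming to 173 time steps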
+ self.wav_transforms = transforms.Compose([
183
+ MakeMono(),
184
+ SpectrogramTorchAudio(nfft=1024, hoplen=1024//4, spec_power=1),
185
+ MelScaleTorchAudio(sr=SR, stft=513, fmin=125, fmax=7600, nmels=80),
186
+ LowerThresh(1e-5),
187
+ Log10(),
188
+ Multiply(20),
189
+ Subtract(20),
190
+ Add(100),
191
+ Divide(100),
192
+ Clip(0, 1.0),
193
+ TrimSpec(173),
194
+ ])
195
+
196
+ self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)
197
+
198
+ def __len__(self):
199
+ return len(self.dataset)
200
+
201
+ def __getitem__(self, idx):
202
+ item = {}
203
+ video = self.dataset[idx]
204
+
205
+ available_frame_idx = self.video_frame_cnt[video] - self.left_over
206
+ wav = None
207
+ spec = None
208
+ max_rms = -np.inf
209
+ wave_path = ''
210
+ cur_wave_path = self.video_audio_path[video]
211
+ if self.denoise:
212
+ cur_wave_path = cur_wave_path.replace('.wav', '_denoised.wav')
213
+ for _ in range(MAX_SAMPLE_ITER):
214
+ start_idx = torch.randint(0, available_frame_idx, (1,)).tolist()[0]
215
+ # target
216
+ start_t = (start_idx + 0.5) / FPS
217
+ start_audio_idx = non_negative(start_t * SR)
218
+
219
+ cur_wav, _ = soundfile.read(cur_wave_path, frames=int(SR * self.L), start=start_audio_idx)
220
+
221
+ if self.wav_transforms is not None:
222
+ spec_tensor = self.wav_transforms(torch.tensor(cur_wav).float())
223
+ cur_spec = spec_tensor.numpy()
224
+ # zeros padding if not enough spec t steps
225
+ if cur_spec.shape[1] < 173:
226
+ pad = np.zeros((80, 173), dtype=cur_spec.dtype)
227
+ pad[:, :cur_spec.shape[1]] = cur_spec
228
+ cur_spec = pad
229
+ rms_val = rms(cur_spec)
230
+ if rms_val > max_rms:
231
+ wav = cur_wav
232
+ spec = cur_spec
233
+ wave_path = cur_wave_path
234
+ max_rms = rms_val
235
+ # print(rms_val)
236
+ if max_rms >= 0.1:
237
+ break
238
+
239
+ item['image'] = 2 * spec - 1 # (80, 173)
240
+ # item['wav'] = wav
241
+ item['file_path_wav_'] = wave_path
242
+
243
+ item['label'] = 'None'
244
+ item['target'] = 'None'
245
+
246
+ if self.spec_transforms is not None:
247
+ item = self.spec_transforms(item)
248
+ return item
249
+
250
+ def make_split_files(self):
251
+ raise NotImplementedError
252
+
253
+ class ImpactSetSpecTrain(ImpactSetSpec):
254
+ def __init__(self, specs_dataset_cfg):
255
+ super().__init__('train', **specs_dataset_cfg)
256
+
257
+ class ImpactSetSpecValidation(ImpactSetSpec):
258
+ def __init__(self, specs_dataset_cfg):
259
+ super().__init__('val', **specs_dataset_cfg)
260
+
261
+ class ImpactSetSpecTest(ImpactSetSpec):
262
+ def __init__(self, specs_dataset_cfg):
263
+ super().__init__('test', **specs_dataset_cfg)
264
+
265
+
266
+
267
+ class ImpactSetWaveTestTime(torch.utils.data.Dataset):
268
+
269
+ def __init__(self, split, random_crop, mel_num, spec_crop_len,
270
+ L=2.0, denoise=False, splits_path='./data',
271
+ data_path='data/ImpactSet/impactset-proccess-resize'):
272
+ super().__init__()
273
+ self.split = split
274
+ self.splits_path = splits_path
275
+ self.data_path = data_path
276
+ self.L = L
277
+ self.denoise = denoise
278
+
279
+ self.video_list = glob('data/ImpactSet/RawVideos/StockVideo_sound/*.wav') + [
280
+ 'data/ImpactSet/RawVideos/YouTube-impact-ccl/1_ckbCU5aQs/1_ckbCU5aQs_0013_0016_resize.wav',
281
+ 'data/ImpactSet/RawVideos/YouTube-impact-ccl/GFmuVBiwz6k/GFmuVBiwz6k_0034_0054_resize.wav',
282
+ 'data/ImpactSet/RawVideos/YouTube-impact-ccl/OsPcY316h1M/OsPcY316h1M_0000_0005_resize.wav',
283
+ 'data/ImpactSet/RawVideos/YouTube-impact-ccl/SExIpBIBj_k/SExIpBIBj_k_0009_0019_resize.wav',
284
+ 'data/ImpactSet/RawVideos/YouTube-impact-ccl/S6TkbV4B4QI/S6TkbV4B4QI_0028_0036_resize.wav',
285
+ 'data/ImpactSet/RawVideos/YouTube-impact-ccl/2Ld24pPIn3k/2Ld24pPIn3k_0005_0011_resize.wav',
286
+ 'data/ImpactSet/RawVideos/YouTube-impact-ccl/6d1YS7fdBK4/6d1YS7fdBK4_0007_0019_resize.wav',
287
+ 'data/ImpactSet/RawVideos/YouTube-impact-ccl/JnBsmJgEkiw/JnBsmJgEkiw_0008_0016_resize.wav',
288
+ 'data/ImpactSet/RawVideos/YouTube-impact-ccl/xcUyiXt0gjo/xcUyiXt0gjo_0015_0021_resize.wav',
289
+ 'data/ImpactSet/RawVideos/YouTube-impact-ccl/4DRFJnZjpMM/4DRFJnZjpMM_0000_0010_resize.wav'
290
+ ] + glob('data/ImpactSet/RawVideos/self_recorded/*_resize.wav')
291
+
292
+ self.wav_transforms = transforms.Compose([
293
+ MakeMono(),
294
+ SpectrogramTorchAudio(nfft=1024, hoplen=1024//4, spec_power=1),
295
+ MelScaleTorchAudio(sr=SR, stft=513, fmin=125, fmax=7600, nmels=80),
296
+ LowerThresh(1e-5),
297
+ Log10(),
298
+ Multiply(20),
299
+ Subtract(20),
300
+ Add(100),
301
+ Divide(100),
302
+ Clip(0, 1.0),
303
+ TrimSpec(173),
304
+ ])
305
+ self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)
306
+
307
+ def __len__(self):
308
+ return len(self.video_list)
309
+
310
+ def __getitem__(self, idx):
311
+ item = {}
312
+
313
+ wave_path = self.video_list[idx]
314
+
315
+ wav, _ = soundfile.read(wave_path)
316
+ start_idx = random.randint(0, min(4, wav.shape[0] - int(SR * self.L)))
317
+ wav = wav[start_idx:start_idx+int(SR * self.L)]
318
+
319
+ if self.denoise:
320
+ if len(wav.shape) == 1:
321
+ wav = wav[None, :]
322
+ wav = nr.reduce_noise(y=wav, sr=SR, n_fft=1024, hop_length=1024//4)
323
+ wav = wav.squeeze()
324
+ if self.wav_transforms is not None:
325
+ spec_tensor = self.wav_transforms(torch.tensor(wav).float())
326
+ spec = spec_tensor.numpy()
327
+ if spec.shape[1] < 173:
328
+ pad = np.zeros((80, 173), dtype=spec.dtype)
329
+ pad[:, :spec.shape[1]] = spec
330
+ spec = pad
331
+
332
+ item['image'] = 2 * spec - 1 # (80, 173)
333
+ # item['wav'] = wav
334
+ item['file_path_wav_'] = wave_path
335
+
336
+ item['label'] = 'None'
337
+ item['target'] = 'None'
338
+
339
+ if self.spec_transforms is not None:
340
+ item = self.spec_transforms(item)
341
+ return item
342
+
343
+ def make_split_files(self):
344
+ raise NotImplementedError
345
+
346
+ class ImpactSetWaveTestTimeTrain(ImpactSetWaveTestTime):
347
+ def __init__(self, specs_dataset_cfg):
348
+ super().__init__('train', **specs_dataset_cfg)
349
+
350
+ class ImpactSetWaveTestTimeValidation(ImpactSetWaveTestTime):
351
+ def __init__(self, specs_dataset_cfg):
352
+ super().__init__('val', **specs_dataset_cfg)
353
+
354
+ class ImpactSetWaveTestTimeTest(ImpactSetWaveTestTime):
355
+ def __init__(self, specs_dataset_cfg):
356
+ super().__init__('test', **specs_dataset_cfg)
357
+
358
+
359
+ class ImpactSetWaveWithSilent(torch.utils.data.Dataset):
360
+
361
+ def __init__(self, split, random_crop, mel_num, spec_crop_len,
362
+ L=2.0, denoise=False, splits_path='./data',
363
+ data_path='data/ImpactSet/impactset-proccess-resize'):
364
+ super().__init__()
365
+ self.split = split
366
+ self.splits_path = splits_path
367
+ self.data_path = data_path
368
+ self.L = L
369
+ self.denoise = denoise
370
+
371
+ video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
372
+ if not os.path.exists(video_name_split_path):
373
+ self.make_split_files()
374
+ video_name = json.load(open(video_name_split_path, 'r'))
375
+ self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
376
+ self.left_over = int(FPS * L + 1)
377
+ self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
378
+ self.dataset = video_name
379
+
380
+ self.wav_transforms = transforms.Compose([
381
+ MakeMono(),
382
+ Padding(target_len=int(SR * self.L)),
383
+ ])
384
+
385
+ self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)
386
+
387
+ def __len__(self):
388
+ return len(self.dataset)
389
+
390
+ def __getitem__(self, idx):
391
+ item = {}
392
+ video = self.dataset[idx]
393
+
394
+ available_frame_idx = self.video_frame_cnt[video] - self.left_over
395
+ wave_path = self.video_audio_path[video]
396
+ if self.denoise:
397
+ wave_path = wave_path.replace('.wav', '_denoised.wav')
398
+ start_idx = torch.randint(0, available_frame_idx, (1,)).tolist()[0]
399
+ # target
400
+ start_t = (start_idx + 0.5) / FPS
401
+ start_audio_idx = non_negative(start_t * SR)
402
+
403
+ wav, _ = soundfile.read(wave_path, frames=int(SR * self.L), start=start_audio_idx)
404
+
405
+ wav = self.wav_transforms(wav)
406
+
407
+ item['image'] = wav # (44100,)
408
+ # item['wav'] = wav
409
+ item['file_path_wav_'] = wave_path
410
+
411
+ item['label'] = 'None'
412
+ item['target'] = 'None'
413
+ return item
414
+
415
+ def make_split_files(self):
416
+ raise NotImplementedError
417
+
418
+ class ImpactSetWaveWithSilentTrain(ImpactSetWaveWithSilent):
419
+ def __init__(self, specs_dataset_cfg):
420
+ super().__init__('train', **specs_dataset_cfg)
421
+
422
+ class ImpactSetWaveWithSilentValidation(ImpactSetWaveWithSilent):
423
+ def __init__(self, specs_dataset_cfg):
424
+ super().__init__('val', **specs_dataset_cfg)
425
+
426
+ class ImpactSetWaveWithSilentTest(ImpactSetWaveWithSilent):
427
+ def __init__(self, specs_dataset_cfg):
428
+ super().__init__('test', **specs_dataset_cfg)
429
+
430
+
431
+ class ImpactSetWaveCondOnImage(torch.utils.data.Dataset):
432
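+ # Conditional CountixAV dataset: the target window is drawn from the annotated repetition interval, and the conditioning clip comes from the same video or, with probability p_outside_cond, from another video of the same class.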
+
433
+ def __init__(self, split,
434
+ L=2.0, frame_transforms=None, denoise=False, splits_path='./data',
435
+ data_path='data/ImpactSet/impactset-proccess-resize',
436
+ p_outside_cond=0.):
437
+ super().__init__()
438
+ self.split = split
439
+ self.splits_path = splits_path
440
+ self.frame_transforms = frame_transforms
441
+ self.data_path = data_path
442
+ self.L = L
443
+ self.denoise = denoise
444
+ self.p_outside_cond = torch.tensor(p_outside_cond)
445
+
446
+ video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
447
+ if not os.path.exists(video_name_split_path):
448
+ self.make_split_files()
449
+ video_name = json.load(open(video_name_split_path, 'r'))
450
+ self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
451
+ self.left_over = int(FPS * L + 1)
452
+ for v, cnt in self.video_frame_cnt.items():
453
+ if cnt - (3*self.left_over) <= 0:
454
+ video_name.remove(v)
455
+ self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
456
+ self.dataset = video_name
457
+
458
+ video_timing_split_path = os.path.join(splits_path, f'countixAV_{split}_timing.json')
459
+ self.video_timing = json.load(open(video_timing_split_path, 'r'))
460
+ self.video_timing = {v: [int(float(t) * FPS) for t in ts] for v, ts in self.video_timing.items()}
461
+
462
+ if split != 'test':
463
+ video_class_path = os.path.join(splits_path, f'countixAV_{split}_class.json')
464
+ if not os.path.exists(video_class_path):
465
+ self.make_video_class()
466
+ self.video_class = json.load(open(video_class_path, 'r'))
467
+ self.class2video = {}
468
+ for v, c in self.video_class.items():
469
+ if c not in self.class2video.keys():
470
+ self.class2video[c] = []
471
+ self.class2video[c].append(v)
472
+
473
+ self.wav_transforms = transforms.Compose([
474
+ MakeMono(),
475
+ Padding(target_len=int(SR * self.L)),
476
+ ])
477
+ if self.frame_transforms == None:
478
+ self.frame_transforms = transforms.Compose([
479
+ Resize3D(128),
480
+ RandomResizedCrop3D(112, scale=(0.5, 1.0)),
481
+ RandomHorizontalFlip3D(),
482
+ ColorJitter3D(brightness=0.1, saturation=0.1),
483
+ ToTensor3D(),
484
+ Normalize3D(mean=[0.485, 0.456, 0.406],
485
+ std=[0.229, 0.224, 0.225]),
486
+ ])
487
+
488
+ def make_video_class(self):
489
+ meta_path = f'data/ImpactSet/data-info/CountixAV_{self.split}.csv'
490
+ video_class = {}
491
+ with open(meta_path, 'r') as f:
492
+ reader = csv.reader(f)
493
+ for i, row in enumerate(reader):
494
+ if i == 0:
495
+ continue
496
+ vid, k_st, k_et = row[:3]
497
+ video_name = f'{vid}_{int(k_st):0>4d}_{int(k_et):0>4d}'
498
+ if video_name not in self.dataset:
499
+ continue
500
+ video_class[video_name] = row[-1]
501
+ with open(os.path.join(self.splits_path, f'countixAV_{self.split}_class.json'), 'w') as f:
502
+ json.dump(video_class, f)
503
+
504
+ def __len__(self):
505
+ return len(self.dataset)
506
+
507
+ def __getitem__(self, idx):
508
+ item = {}
509
+ video = self.dataset[idx]
510
+
511
+ available_frame_idx = self.video_frame_cnt[video] - self.left_over
512
+ rep_start_idx, rep_end_idx = self.video_timing[video]
513
+ rep_end_idx = min(available_frame_idx, rep_end_idx)
514
+ if available_frame_idx <= rep_start_idx + self.L * FPS:
515
+ idx_set = list(range(0, available_frame_idx))
516
+ else:
517
+ idx_set = list(range(rep_start_idx, rep_end_idx))
518
+ start_idx = sample(idx_set, k=1)[0]
519
+
520
+ wave_path = self.video_audio_path[video]
521
+ if self.denoise:
522
+ wave_path = wave_path.replace('.wav', '_denoised.wav')
523
+
524
+ # target
525
+ start_t = (start_idx + 0.5) / FPS
526
+ end_idx = non_negative(start_idx + FPS * self.L)
527
+ start_audio_idx = non_negative(start_t * SR)
528
+ wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_audio_idx)
529
+ assert sr == SR
530
+ wav = self.wav_transforms(wav)
531
+ frame_path = os.path.join(self.data_path, video, 'frames')
532
+ frames = [Image.open(os.path.join(
533
+ frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
534
+ range(start_idx, end_idx)]
535
+
536
+ if torch.all(torch.bernoulli(self.p_outside_cond) == 1.) and self.split != 'test':
537
+ # outside from the same class
538
+ cur_class = self.video_class[video]
539
+ tmp_video = copy.copy(self.class2video[cur_class])
540
+ if len(tmp_video) > 1:
541
+ # if only 1 video in the class, use itself
542
+ tmp_video.remove(video)
543
+ cond_video = sample(tmp_video, k=1)[0]
544
+ cond_available_frame_idx = self.video_frame_cnt[cond_video] - self.left_over
545
+ cond_start_idx = torch.randint(0, cond_available_frame_idx, (1,)).tolist()[0]
546
+ else:
547
+ cond_video = video
548
+ idx_set = list(range(0, start_idx)) + list(range(end_idx, available_frame_idx))
549
+ cond_start_idx = random.sample(idx_set, k=1)[0]
550
+
551
+ cond_end_idx = non_negative(cond_start_idx + FPS * self.L)
552
+ cond_start_t = (cond_start_idx + 0.5) / FPS
553
+ cond_audio_idx = non_negative(cond_start_t * SR)
554
+ cond_frame_path = os.path.join(self.data_path, cond_video, 'frames')
555
+ cond_wave_path = self.video_audio_path[cond_video]
556
+
557
+ cond_frames = [Image.open(os.path.join(
558
+ cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
559
+ range(cond_start_idx, cond_end_idx)]
560
+ cond_wav, sr = soundfile.read(cond_wave_path, frames=int(SR * self.L), start=cond_audio_idx)
561
+ assert sr == SR
562
+ cond_wav = self.wav_transforms(cond_wav)
563
+
564
+ item['image'] = wav # (44100,)
565
+ item['cond_image'] = cond_wav # (44100,)
566
+ item['file_path_wav_'] = wave_path
567
+ item['file_path_cond_wav_'] = cond_wave_path
568
+
569
+ if self.frame_transforms is not None:
570
+ cond_frames = self.frame_transforms(cond_frames)
571
+ frames = self.frame_transforms(frames)
572
+
573
+ item['feature'] = np.stack(cond_frames + frames, axis=0) # (30 * L, 112, 112, 3)
574
+ item['file_path_feats_'] = (frame_path, start_idx)
575
+ item['file_path_cond_feats_'] = (cond_frame_path, cond_start_idx)
576
+
577
+ item['label'] = 'None'
578
+ item['target'] = 'None'
579
+
580
+ return item
581
+
582
+ def make_split_files(self):
583
+ raise NotImplementedError
584
+
585
+
586
+ class ImpactSetWaveCondOnImageTrain(ImpactSetWaveCondOnImage):
587
+ def __init__(self, dataset_cfg):
588
+ train_transforms = transforms.Compose([
589
+ Resize3D(128),
590
+ RandomResizedCrop3D(112, scale=(0.5, 1.0)),
591
+ RandomHorizontalFlip3D(),
592
+ ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
593
+ ToTensor3D(),
594
+ Normalize3D(mean=[0.485, 0.456, 0.406],
595
+ std=[0.229, 0.224, 0.225]),
596
+ ])
597
+ super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)
598
+
599
+ class ImpactSetWaveCondOnImageValidation(ImpactSetWaveCondOnImage):
600
+ def __init__(self, dataset_cfg):
601
+ valid_transforms = transforms.Compose([
602
+ Resize3D(128),
603
+ CenterCrop3D(112),
604
+ ToTensor3D(),
605
+ Normalize3D(mean=[0.485, 0.456, 0.406],
606
+ std=[0.229, 0.224, 0.225]),
607
+ ])
608
+ super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)
609
+
610
+ class ImpactSetWaveCondOnImageTest(ImpactSetWaveCondOnImage):
611
+ def __init__(self, dataset_cfg):
612
+ test_transforms = transforms.Compose([
613
+ Resize3D(128),
614
+ CenterCrop3D(112),
615
+ ToTensor3D(),
616
+ Normalize3D(mean=[0.485, 0.456, 0.406],
617
+ std=[0.229, 0.224, 0.225]),
618
+ ])
619
+ super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
620
+
621
+
622
+
623
+ class ImpactSetCleanWaveCondOnImage(ImpactSetWaveCondOnImage):
624
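+ # Variant that samples target and conditioning windows from precomputed "clean" hit timings (countixAV_{split}_timing_processed_0.20.json).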
+ def __init__(self, split, L=2, frame_transforms=None, denoise=False, splits_path='./data', data_path='data/ImpactSet/impactset-proccess-resize', p_outside_cond=0):
625
+ super().__init__(split, L, frame_transforms, denoise, splits_path, data_path, p_outside_cond)
626
+ pred_timing_path = f'data/countixAV_{split}_timing_processed_0.20.json'
627
+ assert os.path.exists(pred_timing_path)
628
+ self.pred_timing = json.load(open(pred_timing_path, 'r'))
629
+
630
+ self.dataset = []
631
+ for v, ts in self.pred_timing.items():
632
+ if v in self.video_audio_path.keys():
633
+ for t in ts:
634
+ self.dataset.append([v, t])
635
+
636
+ def __getitem__(self, idx):
637
+ item = {}
638
+ video, start_t = self.dataset[idx]
639
+ available_frame_idx = self.video_frame_cnt[video] - self.left_over
640
+ available_timing = (available_frame_idx + 0.5) / FPS
641
+ start_t = float(start_t)
642
+ start_t = min(start_t, available_timing)
643
+
644
+ start_idx = non_negative(start_t * FPS - 0.5)
645
+
646
+ wave_path = self.video_audio_path[video]
647
+ if self.denoise:
648
+ wave_path = wave_path.replace('.wav', '_denoised.wav')
649
+
650
+ # target
651
+ end_idx = non_negative(start_idx + FPS * self.L)
652
+ start_audio_idx = non_negative(start_t * SR)
653
+ wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_audio_idx)
654
+ assert sr == SR
655
+ wav = self.wav_transforms(wav)
656
+ frame_path = os.path.join(self.data_path, video, 'frames')
657
+ frames = [Image.open(os.path.join(
658
+ frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
659
+ range(start_idx, end_idx)]
660
+
661
+ if torch.all(torch.bernoulli(self.p_outside_cond) == 1.):
662
+ other_video = list(self.pred_timing.keys())
663
+ other_video.remove(video)
664
+ cond_video = sample(other_video, k=1)[0]
665
+ cond_available_frame_idx = self.video_frame_cnt[cond_video] - self.left_over
666
+ cond_available_timing = (cond_available_frame_idx + 0.5) / FPS
667
+ else:
668
+ cond_video = video
669
+ cond_available_timing = available_timing
670
+
671
+ cond_start_t = sample(self.pred_timing[cond_video], k=1)[0]
672
+ cond_start_t = float(cond_start_t)
673
+ cond_start_t = min(cond_start_t, cond_available_timing)
674
+ cond_start_idx = non_negative(cond_start_t * FPS - 0.5)
675
+ cond_end_idx = non_negative(cond_start_idx + FPS * self.L)
676
+ cond_audio_idx = non_negative(cond_start_t * SR)
677
+ cond_frame_path = os.path.join(self.data_path, cond_video, 'frames')
678
+ cond_wave_path = self.video_audio_path[cond_video]
679
+
680
+ cond_frames = [Image.open(os.path.join(
681
+ cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
682
+ range(cond_start_idx, cond_end_idx)]
683
+ cond_wav, sr = soundfile.read(cond_wave_path, frames=int(SR * self.L), start=cond_audio_idx)
684
+ assert sr == SR
685
+ cond_wav = self.wav_transforms(cond_wav)
686
+
687
+ item['image'] = wav # (44100,)
688
+ item['cond_image'] = cond_wav # (44100,)
689
+ item['file_path_wav_'] = wave_path
690
+ item['file_path_cond_wav_'] = cond_wave_path
691
+
692
+ if self.frame_transforms is not None:
693
+ cond_frames = self.frame_transforms(cond_frames)
694
+ frames = self.frame_transforms(frames)
695
+
696
+ item['feature'] = np.stack(cond_frames + frames, axis=0) # (30 * L, 112, 112, 3)
697
+ item['file_path_feats_'] = (frame_path, start_idx)
698
+ item['file_path_cond_feats_'] = (cond_frame_path, cond_start_idx)
699
+
700
+ item['label'] = 'None'
701
+ item['target'] = 'None'
702
+
703
+ return item
704
+
705
+
706
+ class ImpactSetCleanWaveCondOnImageTrain(ImpactSetCleanWaveCondOnImage):
707
+ def __init__(self, dataset_cfg):
708
+ train_transforms = transforms.Compose([
709
+ Resize3D(128),
710
+ RandomResizedCrop3D(112, scale=(0.5, 1.0)),
711
+ RandomHorizontalFlip3D(),
712
+ ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
713
+ ToTensor3D(),
714
+ Normalize3D(mean=[0.485, 0.456, 0.406],
715
+ std=[0.229, 0.224, 0.225]),
716
+ ])
717
+ super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)
718
+
719
+ class ImpactSetCleanWaveCondOnImageValidation(ImpactSetCleanWaveCondOnImage):
720
+ def __init__(self, dataset_cfg):
721
+ valid_transforms = transforms.Compose([
722
+ Resize3D(128),
723
+ CenterCrop3D(112),
724
+ ToTensor3D(),
725
+ Normalize3D(mean=[0.485, 0.456, 0.406],
726
+ std=[0.229, 0.224, 0.225]),
727
+ ])
728
+ super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)
729
+
730
+ class ImpactSetCleanWaveCondOnImageTest(ImpactSetCleanWaveCondOnImage):
731
+ def __init__(self, dataset_cfg):
732
+ test_transforms = transforms.Compose([
733
+ Resize3D(128),
734
+ CenterCrop3D(112),
735
+ ToTensor3D(),
736
+ Normalize3D(mean=[0.485, 0.456, 0.406],
737
+ std=[0.229, 0.224, 0.225]),
738
+ ])
739
+ super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
740
+
741
+
742
+ if __name__ == '__main__':
743
+ import sys
744
+
745
+ from omegaconf import OmegaConf
746
+ cfg = OmegaConf.load('configs/countixAV_transformer_denoise_clean.yaml')
747
+ data = instantiate_from_config(cfg.data)
748
+ data.prepare_data()
749
+ data.setup()
750
+
751
+ print(data.datasets['train'])
752
+ print(len(data.datasets['train']))
753
+ # print(data.datasets['train'][24])
754
+ exit()
755
+
756
+ stats = []
757
+ torch.manual_seed(0)
758
+ np.random.seed(0)
759
+ random.seed(0)
760
+ for k in range(1):
761
+ x = np.arange(SR * 2)
762
+ for i in tqdm(range(len(data.datasets['train']))):
763
+ wav = data.datasets['train'][i]['wav']
764
+ spec = data.datasets['train'][i]['image']
765
+ spec = 0.5 * (spec + 1)
766
+ spec_rms = rms(spec)
767
+ stats.append(float(spec_rms))
768
+ # plt.plot(x, wav)
769
+ # plt.ylim(-1, 1)
770
+ # plt.savefig(f'tmp/th0.1_wav_e_{k}_{i}_{mean_val:.3f}_{spec_rms:.3f}.png')
771
+ # plt.close()
772
+ # plt.cla()
773
+ soundfile.write(f'tmp/wav_e_{k}_{i}_{spec_rms:.3f}.wav', wav, SR)
774
+ draw_spec(spec, f'tmp/wav_spec_e_{k}_{i}_{spec_rms:.3f}.png')
775
+ if i == 100:
776
+ break
777
+ # plt.hist(stats, bins=50)
778
+ # plt.savefig(f'tmp/rms_spec_stats.png')
foleycrafter/models/specvqgan/data/transforms.py ADDED
@@ -0,0 +1,685 @@
1
+ import torch
2
+ import torchaudio
3
+ import torchaudio.functional
4
+ from torchvision import transforms
5
+ import torchvision.transforms.functional as F
6
+ import torch.nn as nn
7
+ from PIL import Image
8
+ import numpy as np
9
+ import math
10
+ import random
11
+ import soundfile
12
+ import os
13
+ import librosa
14
+ import albumentations
15
+ from torch_pitch_shift import *
16
+
17
+ SR = 22050
18
+
19
+ class ResizeShortSide(object):
20
+ def __init__(self, size):
21
+ super().__init__()
22
+ self.size = size
23
+
24
+ def __call__(self, x):
25
+ '''
26
+ x must be PIL.Image
27
+ '''
28
+ w, h = x.size
29
+ short_side = min(w, h)
30
+ w_target = int((w / short_side) * self.size)
31
+ h_target = int((h / short_side) * self.size)
32
+ return x.resize((w_target, h_target))
33
+
34
+
35
+ class Crop(object):
36
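+ # Crops item['image'] (and item['cond_image'] if present) to [mel_num, spec_len]; uses a random crop when random_crop is True, otherwise a center crop.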
+ def __init__(self, cropped_shape=None, random_crop=False):
37
+ self.cropped_shape = cropped_shape
38
+ if cropped_shape is not None:
39
+ mel_num, spec_len = cropped_shape
40
+ if random_crop:
41
+ self.cropper = albumentations.RandomCrop
42
+ else:
43
+ self.cropper = albumentations.CenterCrop
44
+ self.preprocessor = albumentations.Compose([self.cropper(mel_num, spec_len)])
45
+ else:
46
+ self.preprocessor = lambda **kwargs: kwargs
47
+
48
+ def __call__(self, item):
49
+ item['image'] = self.preprocessor(image=item['image'])['image']
50
+ if 'cond_image' in item.keys():
51
+ item['cond_image'] = self.preprocessor(image=item['cond_image'])['image']
52
+ return item
53
+
54
+ class CropImage(Crop):
55
+ def __init__(self, *crop_args):
56
+ super().__init__(*crop_args)
57
+
58
+ class CropFeats(Crop):
59
+ def __init__(self, *crop_args):
60
+ super().__init__(*crop_args)
61
+
62
+ def __call__(self, item):
63
+ item['feature'] = self.preprocessor(image=item['feature'])['image']
64
+ return item
65
+
66
+ class CropCoords(Crop):
67
+ def __init__(self, *crop_args):
68
+ super().__init__(*crop_args)
69
+
70
+ def __call__(self, item):
71
+ item['coord'] = self.preprocessor(image=item['coord'])['image']
72
+ return item
73
+
74
+
75
+ class RandomResizedCrop3D(nn.Module):
76
+ """Crop the given series of images to random size and aspect ratio.
77
+ The image can be a PIL Images or a Tensor, in which case it is expected
78
+ to have [N, ..., H, W] shape, where ... means an arbitrary number of leading dimensions
79
+
80
+ A crop of random size (default: of 0.08 to 1.0) of the original size and a random
81
+ aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
82
+ is finally resized to given size.
83
+ This is popularly used to train the Inception networks.
84
+
85
+ Args:
86
+ size (int or sequence): expected output size of each edge. If size is an
87
+ int instead of sequence like (h, w), a square output size ``(size, size)`` is
88
+ made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
89
+ scale (tuple of float): range of size of the origin size cropped
90
+ ratio (tuple of float): range of aspect ratio of the origin aspect ratio cropped.
91
+ interpolation (int): Desired interpolation enum defined by `filters`_.
92
+ Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR``
93
+ and ``PIL.Image.BICUBIC`` are supported.
94
+ """
95
+
96
+ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=transforms.InterpolationMode.BILINEAR):
97
+ super().__init__()
98
+ if isinstance(size, tuple) and len(size) == 2:
99
+ self.size = size
100
+ else:
101
+ self.size = (size, size)
102
+
103
+ self.interpolation = interpolation
104
+ self.scale = scale
105
+ self.ratio = ratio
106
+
107
+ @staticmethod
108
+ def get_params(img, scale, ratio):
109
+ """Get parameters for ``crop`` for a random sized crop.
110
+
111
+ Args:
112
+ img (PIL Image or Tensor): Input image.
113
+ scale (list): range of scale of the origin size cropped
114
+ ratio (list): range of aspect ratio of the origin aspect ratio cropped
115
+
116
+ Returns:
117
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
118
+ sized crop.
119
+ """
120
+ width, height = img.size
121
+ area = height * width
122
+
123
+ for _ in range(10):
124
+ target_area = area * \
125
+ torch.empty(1).uniform_(scale[0], scale[1]).item()
126
+ log_ratio = torch.log(torch.tensor(ratio))
127
+ aspect_ratio = torch.exp(
128
+ torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
129
+ ).item()
130
+
131
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
132
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
133
+
134
+ if 0 < w <= width and 0 < h <= height:
135
+ i = torch.randint(0, height - h + 1, size=(1,)).item()
136
+ j = torch.randint(0, width - w + 1, size=(1,)).item()
137
+ return i, j, h, w
138
+
139
+ # Fallback to central crop
140
+ in_ratio = float(width) / float(height)
141
+ if in_ratio < min(ratio):
142
+ w = width
143
+ h = int(round(w / min(ratio)))
144
+ elif in_ratio > max(ratio):
145
+ h = height
146
+ w = int(round(h * max(ratio)))
147
+ else: # whole image
148
+ w = width
149
+ h = height
150
+ i = (height - h) // 2
151
+ j = (width - w) // 2
152
+ return i, j, h, w
153
+
154
+ def forward(self, imgs):
155
+ """
156
+ Args:
157
+ img (PIL Image or Tensor): Image to be cropped and resized.
158
+
159
+ Returns:
160
+ PIL Image or Tensor: Randomly cropped and resized image.
161
+ """
162
+ i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio)
163
+ return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation) for img in imgs]
164
+
165
+
166
+ class Resize3D(object):
167
+ def __init__(self, size):
168
+ super().__init__()
169
+ self.size = size
170
+
171
+ def __call__(self, imgs):
172
+ '''
173
+ x must be PIL.Image
174
+ '''
175
+ return [x.resize((self.size, self.size)) for x in imgs]
176
+
177
+
178
+ class RandomHorizontalFlip3D(object):
179
+ def __init__(self, p=0.5):
180
+ super().__init__()
181
+ self.p = p
182
+
183
+ def __call__(self, imgs):
184
+ '''
185
+ x must be PIL.Image
186
+ '''
187
+ if np.random.rand() < self.p:
188
+ return [x.transpose(Image.FLIP_LEFT_RIGHT) for x in imgs]
189
+ else:
190
+ return imgs
191
+
192
+
193
+ class ColorJitter3D(torch.nn.Module):
194
+ """Randomly change the brightness, contrast, saturation and hue of an image.
195
+
196
+ Args:
197
+ brightness (float or tuple of float (min, max)): How much to jitter brightness.
198
+ brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
199
+ or the given [min, max]. Should be non negative numbers.
200
+ contrast (float or tuple of float (min, max)): How much to jitter contrast.
201
+ contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
202
+ or the given [min, max]. Should be non negative numbers.
203
+ saturation (float or tuple of float (min, max)): How much to jitter saturation.
204
+ saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
205
+ or the given [min, max]. Should be non negative numbers.
206
+ hue (float or tuple of float (min, max)): How much to jitter hue.
207
+ hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
208
+ Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
209
+ """
210
+
211
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
212
+ super().__init__()
213
+ self.brightness = (1-brightness, 1+brightness)
214
+ self.contrast = (1-contrast, 1+contrast)
215
+ self.saturation = (1-saturation, 1+saturation)
216
+ self.hue = (0-hue, 0+hue)
217
+
218
+ @staticmethod
219
+ def get_params(brightness, contrast, saturation, hue):
220
+ """Get a randomized transform to be applied on image.
221
+
222
+ Arguments are same as that of __init__.
223
+
224
+ Returns:
225
+ Transform which randomly adjusts brightness, contrast and
226
+ saturation in a random order.
227
+ """
228
+ tfs = []
229
+
230
+ if brightness is not None:
231
+ brightness_factor = random.uniform(brightness[0], brightness[1])
232
+ tfs.append(transforms.Lambda(
233
+ lambda img: F.adjust_brightness(img, brightness_factor)))
234
+
235
+ if contrast is not None:
236
+ contrast_factor = random.uniform(contrast[0], contrast[1])
237
+ tfs.append(transforms.Lambda(
238
+ lambda img: F.adjust_contrast(img, contrast_factor)))
239
+
240
+ if saturation is not None:
241
+ saturation_factor = random.uniform(saturation[0], saturation[1])
242
+ tfs.append(transforms.Lambda(
243
+ lambda img: F.adjust_saturation(img, saturation_factor)))
244
+
245
+ if hue is not None:
246
+ hue_factor = random.uniform(hue[0], hue[1])
247
+ tfs.append(transforms.Lambda(
248
+ lambda img: F.adjust_hue(img, hue_factor)))
249
+
250
+ random.shuffle(tfs)
251
+ transform = transforms.Compose(tfs)
252
+
253
+ return transform
254
+
255
+ def forward(self, imgs):
256
+ """
257
+ Args:
258
+ img (PIL Image or Tensor): Input image.
259
+
260
+ Returns:
261
+ PIL Image or Tensor: Color jittered image.
262
+ """
263
+ transform = self.get_params(
264
+ self.brightness, self.contrast, self.saturation, self.hue)
265
+ return [transform(img) for img in imgs]
266
+
267
+
268
+ class ToTensor3D(object):
269
+ def __init__(self):
270
+ super().__init__()
271
+
272
+ def __call__(self, imgs):
273
+ '''
274
+ x must be PIL.Image
275
+ '''
276
+ return [F.to_tensor(img) for img in imgs]
277
+
278
+
279
+ class Normalize3D(object):
280
+ def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False):
281
+ super().__init__()
282
+ self.mean = mean
283
+ self.std = std
284
+ self.inplace = inplace
285
+
286
+ def __call__(self, imgs):
287
+ '''
288
+ x must be PIL.Image
289
+ '''
290
+ return [F.normalize(img, self.mean, self.std, self.inplace) for img in imgs]
291
+
292
+
293
+ class CenterCrop3D(object):
294
+ def __init__(self, size):
295
+ super().__init__()
296
+ self.size = size
297
+
298
+ def __call__(self, imgs):
299
+ '''
300
+ x must be PIL.Image
301
+ '''
302
+ return [F.center_crop(img, self.size) for img in imgs]
303
+
304
+
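The per-frame transforms above chain through transforms.Compose on a list of PIL frames; a sketch of a typical clip pipeline (the values are illustrative assumptions):

    clip_transform = transforms.Compose([
        Resize3D(256),
        RandomResizedCrop3D(224, scale=(0.5, 1.0)),
        RandomHorizontalFlip3D(p=0.5),
        ColorJitter3D(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        ToTensor3D(),
        Normalize3D(),                   # ImageNet mean/std defaults
    ])
    frames = clip_transform(frames)      # list of float tensors, each 3 x 224 x 224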
305
+ class FrequencyMasking(object):
306
+ def __init__(self, freq_mask_param: int, iid_masks: bool = False):
307
+ super().__init__()
308
+ self.masking = torchaudio.transforms.FrequencyMasking(freq_mask_param, iid_masks)
309
+
310
+ def __call__(self, item):
311
+ if 'cond_image' in item.keys():
312
+ batched_spec = torch.stack(
313
+ [torch.tensor(item['image']), torch.tensor(item['cond_image'])], dim=0
314
+ )[:, None] # (2, 1, H, W)
315
+ masked = self.masking(batched_spec).numpy()
316
+ item['image'] = masked[0, 0]
317
+ item['cond_image'] = masked[1, 0]
318
+ elif 'image' in item.keys():
319
+ inp = torch.tensor(item['image'])
320
+ item['image'] = self.masking(inp).numpy()
321
+ else:
322
+ raise NotImplementedError()
323
+ return item
324
+
325
+
326
+ class TimeMasking(object):
327
+ def __init__(self, time_mask_param: int, iid_masks: bool = False):
328
+ super().__init__()
329
+ self.masking = torchaudio.transforms.TimeMasking(time_mask_param, iid_masks)
330
+
331
+ def __call__(self, item):
332
+ if 'cond_image' in item.keys():
333
+ batched_spec = torch.stack(
334
+ [torch.tensor(item['image']), torch.tensor(item['cond_image'])], dim=0
335
+ )[:, None] # (2, 1, H, W)
336
+ masked = self.masking(batched_spec).numpy()
337
+ item['image'] = masked[0, 0]
338
+ item['cond_image'] = masked[1, 0]
339
+ elif 'image' in item.keys():
340
+ inp = torch.tensor(item['image'])
341
+ item['image'] = self.masking(inp).numpy()
342
+ else:
343
+ raise NotImplementedError()
344
+ return item
345
+
346
+
347
+ class PitchShift(nn.Module):
348
+
349
+ def __init__(self, up=12, down=-12, sample_rate=SR):
350
+ super().__init__()
351
+ self.range = (down, up)
352
+ self.sr = sample_rate
353
+
354
+ def forward(self, x):
355
+ assert len(x.shape) == 2
356
+ x = x[:, None, :]
357
+ ratio = float(random.randint(self.range[0], self.range[1]) / 12.)
358
+ shifted = pitch_shift(x, ratio, self.sr)
359
+ return shifted.squeeze()
360
+
361
+
362
+ class MelSpectrogram(object):
363
+ def __init__(self, sr, nfft, fmin, fmax, nmels, hoplen, spec_power, inverse=False):
364
+ self.sr = sr
365
+ self.nfft = nfft
366
+ self.fmin = fmin
367
+ self.fmax = fmax
368
+ self.nmels = nmels
369
+ self.hoplen = hoplen
370
+ self.spec_power = spec_power
371
+ self.inverse = inverse
372
+
373
+ self.mel_basis = librosa.filters.mel(sr=sr, n_fft=nfft, fmin=fmin, fmax=fmax, n_mels=nmels)
374
+
375
+ def __call__(self, x):
376
+ x = x.numpy()
377
+ if self.inverse:
378
+ spec = librosa.feature.inverse.mel_to_stft(
379
+ x, sr=self.sr, n_fft=self.nfft, fmin=self.fmin, fmax=self.fmax, power=self.spec_power
380
+ )
381
+ wav = librosa.griffinlim(spec, hop_length=self.hoplen)
382
+ return torch.FloatTensor(wav)
383
+ else:
384
+ spec = np.abs(librosa.stft(x, n_fft=self.nfft, hop_length=self.hoplen)) ** self.spec_power
385
+ mel_spec = np.dot(self.mel_basis, spec)
386
+ return torch.FloatTensor(mel_spec)
387
+
388
+ class SpectrogramTorchAudio(object):
389
+ def __init__(self, nfft, hoplen, spec_power, inverse=False):
390
+ self.nfft = nfft
391
+ self.hoplen = hoplen
392
+ self.spec_power = spec_power
393
+ self.inverse = inverse
394
+
395
+ self.spec_trans = torchaudio.transforms.Spectrogram(
396
+ n_fft=self.nfft,
397
+ hop_length=self.hoplen,
398
+ power=self.spec_power,
399
+ )
400
+ self.inv_spec_trans = torchaudio.transforms.GriffinLim(
401
+ n_fft=self.nfft,
402
+ hop_length=self.hoplen,
403
+ power=self.spec_power,
404
+ )
405
+
406
+ def __call__(self, x):
407
+ if self.inverse:
408
+ wav = self.inv_spec_trans(x)
409
+ return wav
410
+ else:
411
+ spec = torch.abs(self.spec_trans(x))
412
+ return spec
413
+
414
+
415
+ class MelScaleTorchAudio(object):
416
+ def __init__(self, sr, stft, fmin, fmax, nmels, inverse=False):
417
+ self.sr = sr
418
+ self.stft = stft
419
+ self.fmin = fmin
420
+ self.fmax = fmax
421
+ self.nmels = nmels
422
+ self.inverse = inverse
423
+
424
+ self.mel_trans = torchaudio.transforms.MelScale(
425
+ n_mels=self.nmels,
426
+ sample_rate=self.sr,
427
+ f_min=self.fmin,
428
+ f_max=self.fmax,
429
+ n_stft=self.stft,
430
+ norm='slaney'
431
+ )
432
+ self.inv_mel_trans = torchaudio.transforms.InverseMelScale(
433
+ n_mels=self.nmels,
434
+ sample_rate=self.sr,
435
+ f_min=self.fmin,
436
+ f_max=self.fmax,
437
+ n_stft=self.stft,
438
+ norm='slaney'
439
+ )
440
+
441
+ def __call__(self, x):
442
+ if self.inverse:
443
+ spec = self.inv_mel_trans(x)
444
+ return spec
445
+ else:
446
+ mel_spec = self.mel_trans(x)
447
+ return mel_spec
448
+
449
+ class Padding(object):
450
+ def __init__(self, target_len, inverse=False):
451
+ self.target_len=int(target_len)
452
+ self.inverse = inverse
453
+
454
+ def __call__(self, x):
455
+ if self.inverse:
456
+ return x
457
+ else:
458
+ x = x.squeeze()
459
+ if x.shape[0] < self.target_len:
460
+ pad = torch.zeros((self.target_len,), dtype=x.dtype, device=x.device)
461
+ pad[:x.shape[0]] = x
462
+ x = pad
463
+ elif x.shape[0] > self.target_len:
464
+ raise NotImplementedError()
465
+ return x
466
+
467
+ class MakeMono(object):
468
+ def __init__(self, inverse=False):
469
+ self.inverse = inverse
470
+
471
+ def __call__(self, x):
472
+ if self.inverse:
473
+ return x
474
+ else:
475
+ x = x.squeeze()
476
+ if len(x.shape) == 1:
477
+ return torch.FloatTensor(x)
478
+ elif len(x.shape) == 2:
479
+ target_dim = int(torch.argmin(torch.tensor(x.shape)))
480
+ return torch.mean(x, dim=target_dim)
481
+ else:
482
+ raise NotImplementedError
483
+
484
+ class LowerThresh(object):
485
+ def __init__(self, min_val, inverse=False):
486
+ self.min_val = torch.tensor(min_val)
487
+ self.inverse = inverse
488
+
489
+ def __call__(self, x):
490
+ if self.inverse:
491
+ return x
492
+ else:
493
+ return torch.maximum(self.min_val, x)
494
+
495
+ class Add(object):
496
+ def __init__(self, val, inverse=False):
497
+ self.inverse = inverse
498
+ self.val = val
499
+
500
+ def __call__(self, x):
501
+ if self.inverse:
502
+ return x - self.val
503
+ else:
504
+ return x + self.val
505
+
506
+ class Subtract(Add):
507
+ def __init__(self, val, inverse=False):
508
+ self.inverse = inverse
509
+ self.val = val
510
+
511
+ def __call__(self, x):
512
+ if self.inverse:
513
+ return x + self.val
514
+ else:
515
+ return x - self.val
516
+
517
+ class Multiply(object):
518
+ def __init__(self, val, inverse=False) -> None:
519
+ self.val = val
520
+ self.inverse = inverse
521
+
522
+ def __call__(self, x):
523
+ if self.inverse:
524
+ return x / self.val
525
+ else:
526
+ return x * self.val
527
+
528
+ class Divide(Multiply):
529
+ def __init__(self, val, inverse=False):
530
+ self.inverse = inverse
531
+ self.val = val
532
+
533
+ def __call__(self, x):
534
+ if self.inverse:
535
+ return x * self.val
536
+ else:
537
+ return x / self.val
538
+
539
+
540
+ class Log10(object):
541
+ def __init__(self, inverse=False):
542
+ self.inverse = inverse
543
+
544
+ def __call__(self, x):
545
+ if self.inverse:
546
+ return 10 ** x
547
+ else:
548
+ return torch.log10(x)
549
+
550
+ class Clip(object):
551
+ def __init__(self, min_val, max_val, inverse=False):
552
+ self.min_val = min_val
553
+ self.max_val = max_val
554
+ self.inverse = inverse
555
+
556
+ def __call__(self, x):
557
+ if self.inverse:
558
+ return x
559
+ else:
560
+ return torch.clip(x, self.min_val, self.max_val)
561
+
562
+ class TrimSpec(object):
563
+ def __init__(self, max_len, inverse=False):
564
+ self.max_len = max_len
565
+ self.inverse = inverse
566
+
567
+ def __call__(self, x):
568
+ if self.inverse:
569
+ return x
570
+ else:
571
+ return x[:, :self.max_len]
572
+
573
+ class MaxNorm(object):
574
+ def __init__(self, inverse=False):
575
+ self.inverse = inverse
576
+ self.eps = 1e-10
577
+
578
+ def __call__(self, x):
579
+ if self.inverse:
580
+ return x
581
+ else:
582
+ return x / (x.max() + self.eps)
583
+
584
+
585
+ class NormalizeAudio(object):
586
+ def __init__(self, inverse=False, desired_rms=0.1, eps=1e-4):
587
+ self.inverse = inverse
588
+ self.desired_rms = desired_rms
589
+ self.eps = torch.tensor(eps)
590
+
591
+ def __call__(self, x):
592
+ if self.inverse:
593
+ return x
594
+ else:
595
+ rms = torch.maximum(self.eps, torch.sqrt(torch.mean(x**2)))
596
+ x = x * (self.desired_rms / rms)
597
+ x[x > 1.] = 1.
598
+ x[x < -1.] = -1.
599
+ return x
600
+
601
+
602
+ class RandomNormalizeAudio(object):
603
+ def __init__(self, inverse=False, rms_range=[0.05, 0.2], eps=1e-4):
604
+ self.inverse = inverse
605
+ self.rms_low, self.rms_high = rms_range
606
+ self.eps = torch.tensor(eps)
607
+
608
+ def __call__(self, x):
609
+ if self.inverse:
610
+ return x
611
+ else:
612
+ rms = torch.maximum(self.eps, torch.sqrt(torch.mean(x**2)))
613
+ desired_rms = (torch.rand(1) * (self.rms_high - self.rms_low)) + self.rms_low
614
+ x = x * (desired_rms / rms)
615
+ x[x > 1.] = 1.
616
+ x[x < -1.] = -1.
617
+ return x
618
+
619
+
620
+ class MakeDouble(nn.Module):
621
+ def __init__(self):
622
+ super().__init__()
623
+
624
+ def forward(self, x):
625
+ return x.to(torch.double)
626
+
627
+
628
+ class MakeFloat(nn.Module):
629
+ def __init__(self):
630
+ super().__init__()
631
+
632
+ def forward(self, x):
633
+ return x.to(torch.float)
634
+
635
+
636
+ class Wave2Spectrogram(nn.Module):
637
+ def __init__(self, mel_num, spec_crop_len):
638
+ super().__init__()
639
+ self.trans = transforms.Compose([
640
+ LowerThresh(1e-5),
641
+ Log10(),
642
+ Multiply(20),
643
+ Subtract(20),
644
+ Add(100),
645
+ Divide(100),
646
+ Clip(0, 1.0),
647
+ TrimSpec(173),
648
+ transforms.CenterCrop((mel_num, spec_crop_len))
649
+ ])
650
+
651
+ def forward(self, x):
652
+ return self.trans(x)
653
+
654
+
655
+
656
+ TRANSFORMS = transforms.Compose([
657
+ SpectrogramTorchAudio(nfft=1024, hoplen=1024//4, spec_power=1),
658
+ MelScaleTorchAudio(sr=22050, stft=513, fmin=125, fmax=7600, nmels=80),
659
+ LowerThresh(1e-5),
660
+ Log10(),
661
+ Multiply(20),
662
+ Subtract(20),
663
+ Add(100),
664
+ Divide(100),
665
+ Clip(0, 1.0),
666
+ ])
667
+
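A quick sketch of what TRANSFORMS produces (shapes are assumptions for a 22050 Hz mono clip):

    wav = torch.randn(22050 * 2)   # ~2 s of audio at 22.05 kHz
    mel = TRANSFORMS(wav)          # (80, T) mel spectrogram with values clipped to [0, 1]
    print(mel.shape, mel.min().item(), mel.max().item())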
668
+ def get_spectrogram_torch(audio_path, save_dir, length, save_results=True):
669
+ wav, _ = soundfile.read(audio_path)
670
+ wav = torch.FloatTensor(wav)
671
+ y = torch.zeros(length)
672
+ if wav.shape[0] < length:
673
+ y[:len(wav)] = wav
674
+ else:
675
+ y = wav[:length]
676
+
677
+ mel_spec = TRANSFORMS(y).numpy()
678
+ y = y.numpy()
679
+ if save_results:
680
+ os.makedirs(save_dir, exist_ok=True)
681
+ audio_name = os.path.basename(audio_path).split('.')[0]
682
+ np.save(os.path.join(save_dir, audio_name + '_mel.npy'), mel_spec)
683
+ np.save(os.path.join(save_dir, audio_name + '_audio.npy'), y)
684
+ else:
685
+ return y, mel_spec
foleycrafter/models/specvqgan/data/utils.py ADDED
@@ -0,0 +1,265 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import json
7
+ from random import shuffle, choice, sample
8
+
9
+ from moviepy.editor import VideoFileClip
10
+ import librosa
11
+ from scipy import signal
12
+ from scipy.io import wavfile
13
+ import torchaudio
14
+ torchaudio.set_audio_backend("sox_io")
15
+
16
+ INTERVAL = 1000
17
+
18
+ # discard
19
+ stft = torchaudio.transforms.MelSpectrogram(
20
+ sample_rate=16000, hop_length=161, n_mels=64).cuda()
21
+
22
+
23
+ def log10(x): return torch.log(x)/torch.log(torch.tensor(10.))
24
+
25
+
26
+ def norm_range(x, min_val, max_val):
27
+ return 2.*(x - min_val)/float(max_val - min_val) - 1.
28
+
29
+
30
+ def normalize_spec(spec, spec_min, spec_max):
31
+ return norm_range(spec, spec_min, spec_max)
32
+
33
+
34
+ def db_from_amp(x, cuda=False):
35
+ # rescale the audio
36
+ if cuda:
37
+ return 20. * log10(torch.max(torch.tensor(1e-5).to('cuda'), x.float()))
38
+ else:
39
+ return 20. * log10(torch.max(torch.tensor(1e-5), x.float()))
40
+
41
+
42
+ def audio_stft(audio, stft=stft):
43
+ # We'll apply stft to the audio samples to convert it to a HxW matrix
44
+ N, C, A = audio.size()
45
+ audio = audio.view(N * C, A)
46
+ spec = stft(audio)
47
+ spec = spec.transpose(-1, -2)
48
+ spec = db_from_amp(spec, cuda=True)
49
+ spec = normalize_spec(spec, -100., 100.)
50
+ _, T, F = spec.size()
51
+ spec = spec.view(N, C, T, F)
52
+ return spec
53
+
54
+
55
+ # discard
56
+ # def get_spec(
57
+ # wavs,
58
+ # sample_rate=16000,
59
+ # use_volume_jittering=False,
60
+ # center=False,
61
+ # ):
62
+ # # Volume jittering - scale volume by factor in range (0.9, 1.1)
63
+ # if use_volume_jittering:
64
+ # wavs = [wav * np.random.uniform(0.9, 1.1) for wav in wavs]
65
+ # if center:
66
+ # wavs = [center_only(wav) for wav in wavs]
67
+
68
+ # # Convert to log filterbank
69
+ # specs = [logfbank(
70
+ # wav,
71
+ # sample_rate,
72
+ # winlen=0.009,
73
+ # winstep=0.005, # if num_sec==1 else 0.01,
74
+ # nfilt=256,
75
+ # nfft=1024
76
+ # ).astype('float32').T for wav in wavs]
77
+
78
+ # # Convert to 32-bit float and expand dim
79
+ # specs = np.stack(specs, axis=0)
80
+ # specs = np.expand_dims(specs, 1)
81
+ # specs = torch.as_tensor(specs) # Nx1xFxT
82
+
83
+ # return specs
84
+
85
+
86
+ def center_only(audio, sr=16000, L=1.0):
87
+ # center_wav = np.arange(0, L, L/(0.5*sr)) ** 2
88
+ # center_wav = np.concatenate([center_wav, center_wav[::-1]])
89
+ # center_wav[L*sr//2:3*L*sr//4] = 1
90
+ # only take 0.3 sec audio
91
+ center_wav = np.zeros(int(L * sr))
92
+ center_wav[int(0.4*L*sr):int(0.7*L*sr)] = 1
93
+
94
+ return audio * center_wav
95
+
96
+ def get_spec_librosa(
97
+ wavs,
98
+ sample_rate=16000,
99
+ use_volume_jittering=False,
100
+ center=False,
101
+ ):
102
+ # Volume jittering - scale volume by factor in range (0.9, 1.1)
103
+ if use_volume_jittering:
104
+ wavs = [wav * np.random.uniform(0.9, 1.1) for wav in wavs]
105
+ if center:
106
+ wavs = [center_only(wav) for wav in wavs]
107
+
108
+ # Convert to log filterbank
109
+ specs = [librosa.feature.melspectrogram(
110
+ y=wav,
111
+ sr=sample_rate,
112
+ n_fft=400,
113
+ hop_length=126,
114
+ n_mels=128,
115
+ ).astype('float32') for wav in wavs]
116
+
117
+ # Convert to 32-bit float and expand dim
118
+ specs = [librosa.power_to_db(spec) for spec in specs]
119
+ specs = np.stack(specs, axis=0)
120
+ specs = np.expand_dims(specs, 1)
121
+ specs = torch.as_tensor(specs) # Nx1xFxT
122
+
123
+ return specs
124
+
125
+
126
+ def calcEuclideanDistance_Mat(X, Y):
127
+ """
128
+ Inputs:
129
+ - X: A numpy array of shape (N, F)
130
+ - Y: A numpy array of shape (M, F)
131
+
132
+ Returns:
133
+ A numpy array D of shape (N, M) where D[i, j] is the Euclidean distance
134
+ between X[i] and Y[j].
135
+ """
136
+ return ((torch.sum(X ** 2, axis=1, keepdims=True)) + (torch.sum(Y ** 2, axis=1, keepdims=True)).T - 2 * X @ Y.T) ** 0.5
137
+
138
+
139
+ def calcEuclideanDistance(x1, x2):
140
+ return torch.sum((x1 - x2)**2, dim=1)**0.5
141
+
142
+
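A small sanity sketch for the matrix form above (the torch.cdist comparison is an assumption used only for checking):

    X, Y = torch.randn(5, 16), torch.randn(7, 16)
    D = calcEuclideanDistance_Mat(X, Y)                  # (5, 7) pairwise distances
    assert torch.allclose(D, torch.cdist(X, Y), atol=1e-4)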
143
+ def split_data(in_list, portion=(0.9, 0.95), is_shuffle=True):
144
+ if is_shuffle:
145
+ shuffle(in_list)
146
+ if type(in_list) == str:
147
+ with open(in_list) as l:
148
+ fw_list = json.load(l)
149
+ elif type(in_list) == list:
150
+ fw_list = in_list
151
+ else:
152
+ print(type(in_list))
153
+ raise TypeError('Invalid input list type')
154
+ c1, c2 = int(len(fw_list) * portion[0]), int(len(fw_list) * portion[1])
155
+ tr_list, va_list, te_list = fw_list[:c1], fw_list[c1:c2], fw_list[c2:]
156
+ print(
157
+ f'==> train set: {len(tr_list)}, validation set: {len(va_list)}, test set: {len(te_list)}')
158
+ return tr_list, va_list, te_list
159
+
160
+
161
+ def load_one_clip(video_path):
162
+ v = VideoFileClip(video_path)
163
+ fps = int(v.fps)
164
+ frames = [f for f in v.iter_frames()][:-1]
165
+ frame_cnt = len(frames)
166
+ frame_length = 1000./fps
167
+ total_length = int(1000 * (frame_cnt / fps))
168
+
169
+ a = v.audio
170
+ sr = a.fps
171
+ a = np.array([fa for fa in a.iter_frames()])
172
+ a = librosa.resample(a, sr, 48000)
173
+ if len(a.shape) > 1:
174
+ a = np.mean(a, axis=1)
175
+
176
+ while True:
177
+ idx = np.random.choice(np.arange(frame_cnt - 1), 1)[0]
178
+ frame_clip = frames[idx]
179
+ start_time = int(idx * frame_length + 0.5 * frame_length - 500)
180
+ end_time = start_time + INTERVAL
181
+ if start_time < 0 or end_time > total_length:
182
+ continue
183
+ wave_clip = a[48 * start_time: 48 * end_time]
184
+ if wave_clip.shape[0] != 48000:
185
+ continue
186
+ break
187
+ return frame_clip, wave_clip
188
+
189
+
190
+ def resize_frame(frame):
191
+ H, W = frame.size
192
+ short_edge = min(H, W)
193
+ scale = 256 / short_edge
194
+ H_tar, W_tar = int(np.round(H * scale)), int(np.round(W * scale))
195
+ return frame.resize((H_tar, W_tar))
196
+
197
+
198
+ def get_spectrogram(wave, amp_jitter, amp_jitter_range, log_scale=True, sr=48000):
199
+ # random clip-level amplitude jittering
200
+ if amp_jitter:
201
+ amplified = wave * np.random.uniform(*amp_jitter_range)
202
+ if wave.dtype == np.int16:
203
+ amplified[amplified >= 32767] = 32767
204
+ amplified[amplified <= -32768] = -32768
205
+ wave = amplified.astype('int16')
206
+ elif wave.dtype == np.float32 or wave.dtype == np.float64:
207
+ amplified[amplified >= 1] = 1
208
+ amplified[amplified <= -1] = -1
209
+
210
+ # fr, ts, spectrogram = signal.spectrogram(wave[:48000], fs=sr, nperseg=480, noverlap=240, nfft=512)
211
+ # spectrogram = librosa.feature.melspectrogram(S=spectrogram, n_mels=257) # Try log-mel spectrogram?
212
+ spectrogram = librosa.feature.melspectrogram(
213
+ y=wave[:48000], sr=sr, hop_length=240, win_length=480, n_mels=257)
214
+ if log_scale:
215
+ spectrogram = librosa.power_to_db(spectrogram, ref=np.max)
216
+ assert spectrogram.shape[0] == 257
217
+
218
+ return spectrogram
219
+
220
+
221
+ def cropAudio(audio, sr, f_idx, fps=10, length=1., left_shift=0):
222
+ time_per_frame = 1./fps
223
+ assert audio.shape[0] > sr * length
224
+ start_time = f_idx * time_per_frame - left_shift
225
+ start_time = 0 if start_time < 0 else start_time
226
+ start_idx = int(np.round(sr * start_time))
227
+ end_idx = int(np.round(start_idx + (sr * length)))
228
+ if end_idx > audio.shape[0]:
229
+ end_idx = audio.shape[0]
230
+ start_idx = int(end_idx - (sr * length))
231
+ try:
232
+ assert audio[start_idx:end_idx].shape[0] == sr * length
233
+ except:
234
+ print(audio.shape, start_idx, end_idx, end_idx - start_idx)
235
+ exit(1)
236
+ return audio[start_idx:end_idx]
237
+
238
+
239
+ def pick_async_frame_idx(idx, total_frames, fps=10, gap=2.0, length=1.0, cnt=1):
240
+ assert idx < total_frames - fps * length
241
+ lower_bound = idx - int((length + gap) * fps)
242
+ upper_bound = idx + int((length + gap) * fps)
243
+ proposal = list(range(0, lower_bound)) + \
244
+ list(range(upper_bound, int(total_frames - fps * length)))
245
+ # assert len(proposal) >= cnt
246
+ avail_cnt = len(proposal)
247
+ try:
248
+ for i in range(cnt - avail_cnt):
249
+ proposal.append(proposal[i % avail_cnt])
250
+ except Exception as e:
251
+ print(idx, total_frames, proposal)
252
+ raise e
253
+ return sample(proposal, k=cnt)
254
+
255
+
256
+ def adjust_learning_rate(optimizer, epoch, args):
257
+ """Decay the learning rate based on schedule"""
258
+ lr = args.lr
259
+ if args.cos: # cosine lr schedule
260
+ lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epoch))
261
+ else: # stepwise lr schedule
262
+ for milestone in args.schedule:
263
+ lr *= 0.1 if epoch >= milestone else 1.
264
+ for param_group in optimizer.param_groups:
265
+ param_group['lr'] = lr
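A minimal sketch of the cosine branch above (the args fields mirror what the function reads; the values are assumptions):

    from argparse import Namespace

    model = nn.Linear(4, 4)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    args = Namespace(lr=0.1, cos=True, epoch=100, schedule=[])
    adjust_learning_rate(opt, epoch=50, args=args)
    print(opt.param_groups[0]['lr'])   # 0.05: halfway through the cosine decay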
foleycrafter/models/specvqgan/models/av_cond_transformer.py ADDED
@@ -0,0 +1,528 @@
1
+ import sys
2
+
3
+ import pytorch_lightning as pl
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torchvision import transforms
8
+ import torchaudio
9
+ from omegaconf.listconfig import ListConfig
10
+
11
+ sys.path.insert(0, '.') # nopep8
12
+ from foleycrafter.models.specvqgan.modules.transformer.mingpt import (GPTClass, GPTFeats, GPTFeatsClass)
13
+ from foleycrafter.models.specvqgan.data.transforms import Wave2Spectrogram, PitchShift, NormalizeAudio
14
+ from train import instantiate_from_config
15
+
16
+ SR = 22050
17
+
18
+ def disabled_train(self, mode=True):
19
+ """Overwrite model.train with this function to make sure train/eval mode
20
+ does not change anymore."""
21
+ return self
22
+
23
+
24
+ class Net2NetTransformerAVCond(pl.LightningModule):
25
+ def __init__(self, transformer_config, first_stage_config,
26
+ cond_stage_config,
27
+ drop_condition=False, drop_video=False, drop_cond_video=False,
28
+ first_stage_permuter_config=None, cond_stage_permuter_config=None,
29
+ ckpt_path=None, ignore_keys=[],
30
+ first_stage_key="image",
31
+ cond_first_stage_key="cond_image",
32
+ cond_stage_key="depth",
33
+ downsample_cond_size=-1,
34
+ pkeep=1.0,
35
+ clip=30,
36
+ p_audio_aug=0.5,
37
+ p_pitch_shift=0.,
38
+ p_normalize=0.,
39
+ mel_num=80,
40
+ spec_crop_len=160):
41
+
42
+ super().__init__()
43
+ self.init_first_stage_from_ckpt(first_stage_config)
44
+ self.init_cond_stage_from_ckpt(cond_stage_config)
45
+ if first_stage_permuter_config is None:
46
+ first_stage_permuter_config = {"target": "foleycrafter.models.specvqgan.modules.transformer.permuter.Identity"}
47
+ if cond_stage_permuter_config is None:
48
+ cond_stage_permuter_config = {"target": "foleycrafter.models.specvqgan.modules.transformer.permuter.Identity"}
49
+ self.first_stage_permuter = instantiate_from_config(config=first_stage_permuter_config)
50
+ self.cond_stage_permuter = instantiate_from_config(config=cond_stage_permuter_config)
51
+ self.transformer = instantiate_from_config(config=transformer_config)
52
+
53
+ self.wav_transforms = nn.Sequential(
54
+ transforms.RandomApply([NormalizeAudio()], p=p_normalize),
55
+ transforms.RandomApply([PitchShift()], p=p_pitch_shift),
56
+ torchaudio.transforms.Spectrogram(
57
+ n_fft=1024,
58
+ hop_length=1024//4,
59
+ power=1,
60
+ ),
61
+ # transforms.RandomApply([
62
+ # torchaudio.transforms.FrequencyMasking(freq_mask_param=40, iid_masks=False)
63
+ # ], p=p_audio_aug),
64
+ # transforms.RandomApply([
65
+ # torchaudio.transforms.TimeMasking(time_mask_param=int(32 * 2), iid_masks=False)
66
+ # ], p=p_audio_aug),
67
+ torchaudio.transforms.MelScale(
68
+ n_mels=80,
69
+ sample_rate=SR,
70
+ f_min=125,
71
+ f_max=7600,
72
+ n_stft=513,
73
+ norm='slaney'
74
+ ),
75
+ Wave2Spectrogram(mel_num, spec_crop_len),
76
+ )
77
+ ignore_keys = ['wav_transforms']
78
+
79
+ if ckpt_path is not None:
80
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
81
+ self.drop_condition = drop_condition
82
+ self.drop_video = drop_video
83
+ self.drop_cond_video = drop_cond_video
84
+ print(f'>>> Feature setting: all cond: {self.drop_condition}, video: {self.drop_video}, cond video: {self.drop_cond_video}')
85
+ self.first_stage_key = first_stage_key
86
+ self.cond_first_stage_key = cond_first_stage_key
87
+ self.cond_stage_key = cond_stage_key
88
+ self.downsample_cond_size = downsample_cond_size
89
+ self.pkeep = pkeep
90
+ self.clip = clip
91
+ print('>>> model init done.')
92
+
93
+ def init_from_ckpt(self, path, ignore_keys=list()):
94
+ sd = torch.load(path, map_location="cpu")["state_dict"]
95
+ for k in sd.keys():
96
+ for ik in ignore_keys:
97
+ if k.startswith(ik):
98
+ self.print("Deleting key {} from state_dict.".format(k))
99
+ del sd[k]
100
+ self.load_state_dict(sd, strict=False)
101
+ print(f"Restored from {path}")
102
+
103
+ def init_first_stage_from_ckpt(self, config):
104
+ model = instantiate_from_config(config)
105
+ model = model.eval()
106
+ model.train = disabled_train
107
+ self.first_stage_model = model
108
+
109
+ def init_cond_stage_from_ckpt(self, config):
110
+ model = instantiate_from_config(config)
111
+ model = model.eval()
112
+ model.train = disabled_train
113
+ self.cond_stage_model = model
114
+
115
+ def forward(self, x, c, xp):
116
+ # one step to produce the logits
117
+ _, z_indices = self.encode_to_z(x) # VQ-GAN encoding
118
+ _, zp_indices = self.encode_to_z(xp)
119
+ _, c_indices = self.encode_to_c(c) # Conv1-1 down dim + col-major permuter
120
+ z_indices = z_indices[:, :self.clip]
121
+ zp_indices = zp_indices[:, :self.clip]
122
+ if not self.drop_condition:
123
+ z_indices = torch.cat([zp_indices, z_indices], dim=1)
124
+
125
+ if self.training and self.pkeep < 1.0:
126
+ mask = torch.bernoulli(self.pkeep * torch.ones(z_indices.shape, device=z_indices.device))
127
+ mask = mask.round().to(dtype=torch.int64)
128
+ r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
129
+ a_indices = mask*z_indices+(1-mask)*r_indices
130
+ else:
131
+ a_indices = z_indices
132
+
133
+ # target includes all sequence elements (no need to handle first one
134
+ # differently because we are conditioning)
135
+ if self.drop_condition:
136
+ target = z_indices
137
+ else:
138
+ target = z_indices[:, self.clip:]
139
+
140
+ # in the case we do not want to encode condition anyhow (e.g. inputs are features)
141
+ if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)):
142
+ # make the prediction
143
+ logits, _, _ = self.transformer(z_indices[:, :-1], c)
144
+ # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
145
+ if isinstance(self.transformer, GPTFeatsClass):
146
+ cond_size = c['feature'].size(-1) + c['target'].size(-1)
147
+ else:
148
+ cond_size = c.size(-1)
149
+ if self.drop_condition:
150
+ logits = logits[:, cond_size-1:]
151
+ else:
152
+ logits = logits[:, cond_size-1:][:, self.clip:]
153
+ else:
154
+ cz_indices = torch.cat((c_indices, a_indices), dim=1)
155
+ # make the prediction
156
+ logits, _, _ = self.transformer(cz_indices[:, :-1])
157
+ # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
158
+ logits = logits[:, c_indices.shape[1]-1:]
159
+
160
+ return logits, target
161
+
162
+ def top_k_logits(self, logits, k):
163
+ v, ix = torch.topk(logits, k)
164
+ out = logits.clone()
165
+ out[out < v[..., [-1]]] = -float('Inf')
166
+ return out
167
+
168
+ @torch.no_grad()
169
+ def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
170
+ callback=lambda k: None):
171
+ x = x if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)) else torch.cat((c, x), dim=1)
172
+ block_size = self.transformer.get_block_size()
173
+ assert not self.transformer.training
174
+ if self.pkeep <= 0.0:
175
+ raise NotImplementedError('Implement for GPTFeatsCLass')
176
+ raise NotImplementedError('Implement for GPTFeats')
177
+ raise NotImplementedError('Implement for GPTClass')
178
+ raise NotImplementedError('also the model outputs attention')
179
+ # one pass suffices since input is pure noise anyway
180
+ assert len(x.shape)==2
181
+ # noise_shape = (x.shape[0], steps-1)
182
+ # noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
183
+ noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
184
+ x = torch.cat((x,noise),dim=1)
185
+ logits, _ = self.transformer(x)
186
+ # take all logits for now and scale by temp
187
+ logits = logits / temperature
188
+ # optionally crop probabilities to only the top k options
189
+ if top_k is not None:
190
+ logits = self.top_k_logits(logits, top_k)
191
+ # apply softmax to convert to probabilities
192
+ probs = F.softmax(logits, dim=-1)
193
+ # sample from the distribution or take the most likely
194
+ if sample:
195
+ shape = probs.shape
196
+ probs = probs.reshape(shape[0]*shape[1],shape[2])
197
+ ix = torch.multinomial(probs, num_samples=1)
198
+ probs = probs.reshape(shape[0],shape[1],shape[2])
199
+ ix = ix.reshape(shape[0],shape[1])
200
+ else:
201
+ _, ix = torch.topk(probs, k=1, dim=-1)
202
+ # cut off conditioning
203
+ x = ix[:, c.shape[1]-1:]
204
+ else:
205
+ for k in range(steps):
206
+ callback(k)
207
+ if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)):
208
+ # if the assert is removed, you need to make sure that the combined length is not longer than block_size
209
+ if isinstance(self.transformer, GPTFeatsClass):
210
+ cond_size = c['feature'].size(-1) + c['target'].size(-1)
211
+ else:
212
+ cond_size = c.size(-1)
213
+ assert x.size(1) + cond_size <= block_size
214
+
215
+ x_cond = x
216
+ c_cond = c
217
+ logits, _, att = self.transformer(x_cond, c_cond)
218
+ else:
219
+ assert x.size(1) <= block_size # make sure model can see conditioning
220
+ x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
221
+ logits, _, att = self.transformer(x_cond)
222
+ # pluck the logits at the final step and scale by temperature
223
+ logits = logits[:, -1, :] / temperature
224
+ # optionally crop probabilities to only the top k options
225
+ if top_k is not None:
226
+ logits = self.top_k_logits(logits, top_k)
227
+ # apply softmax to convert to probabilities
228
+ probs = F.softmax(logits, dim=-1)
229
+ # sample from the distribution or take the most likely
230
+ if sample:
231
+ ix = torch.multinomial(probs, num_samples=1)
232
+ else:
233
+ _, ix = torch.topk(probs, k=1, dim=-1)
234
+ # append to the sequence and continue
235
+ x = torch.cat((x, ix), dim=1)
236
+ # cut off conditioning
237
+ x = x if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)) else x[:, c.shape[1]:]
238
+ return x, att.detach().cpu()
239
+
240
+ @torch.no_grad()
241
+ def encode_to_z(self, x):
242
+ quant_z, _, info = self.first_stage_model.encode(x)
243
+ indices = info[2].view(quant_z.shape[0], -1)
244
+ indices = self.first_stage_permuter(indices)
245
+ return quant_z, indices
246
+
247
+ @torch.no_grad()
248
+ def encode_to_c(self, c):
249
+ if self.downsample_cond_size > -1:
250
+ c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
251
+ quant_c, _, info = self.cond_stage_model.encode(c)
252
+ if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)):
253
+ # these are not indices but raw features or a class
254
+ indices = info[2]
255
+ else:
256
+ indices = info[2].view(quant_c.shape[0], -1)
257
+ indices = self.cond_stage_permuter(indices)
258
+ return quant_c, indices
259
+
260
+ @torch.no_grad()
261
+ def decode_to_img(self, index, zshape, stage='first'):
262
+ if stage == 'first':
263
+ index = self.first_stage_permuter(index, reverse=True)
264
+ elif stage == 'cond':
265
+ print('in cond stage in decode_to_img which is unexpected ')
266
+ index = self.cond_stage_permuter(index, reverse=True)
267
+ else:
268
+ raise NotImplementedError
269
+
270
+ bhwc = (zshape[0], zshape[2], zshape[3], zshape[1])
271
+ quant_z = self.first_stage_model.quantize.get_codebook_entry(index.reshape(-1), shape=bhwc)
272
+ x = self.first_stage_model.decode(quant_z)
273
+ return x
274
+
275
+ @torch.no_grad()
276
+ def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
277
+ log = dict()
278
+
279
+ N = 4
280
+ if lr_interface:
281
+ x, c, xp = self.get_xcxp(batch, N, diffuse=False, upsample_factor=8)
282
+ else:
283
+ x, c, xp = self.get_xcxp(batch, N)
284
+ x = x.to(device=self.device)
285
+ xp = xp.to(device=self.device)
286
+ # c = c.to(device=self.device)
287
+ if isinstance(c, dict):
288
+ c = {k: v.to(self.device) for k, v in c.items()}
289
+ else:
290
+ c = c.to(self.device)
291
+
292
+ quant_z, z_indices = self.encode_to_z(x)
293
+ quant_zp, zp_indices = self.encode_to_z(xp)
294
+ quant_c, c_indices = self.encode_to_c(c) # output can be features or a single class or a featcls dict
295
+ z_indices_rec = z_indices.clone()
296
+ zp_indices_clip = zp_indices[:, :self.clip]
297
+ z_indices_clip = z_indices[:, :self.clip]
298
+
299
+ # create a "half" sample
300
+ z_start_indices = z_indices_clip[:, :z_indices_clip.shape[1]//2]
301
+ if self.drop_condition:
302
+ steps = z_indices_clip.shape[1]-z_start_indices.shape[1]
303
+ else:
304
+ z_start_indices = torch.cat([zp_indices_clip, z_start_indices], dim=-1)
305
+ steps = 2*z_indices_clip.shape[1]-z_start_indices.shape[1]
306
+ index_sample, att_half = self.sample(z_start_indices, c_indices,
307
+ steps=steps,
308
+ temperature=temperature if temperature is not None else 1.0,
309
+ sample=True,
310
+ top_k=top_k if top_k is not None else 100,
311
+ callback=callback if callback is not None else lambda k: None)
312
+ if self.drop_condition:
313
+ z_indices_rec[:, :self.clip] = index_sample
314
+ else:
315
+ z_indices_rec[:, :self.clip] = index_sample[:, self.clip:]
316
+ x_sample = self.decode_to_img(z_indices_rec, quant_z.shape)
317
+
318
+ # sample
319
+ z_start_indices = z_indices_clip[:, :0]
320
+ if not self.drop_condition:
321
+ z_start_indices = torch.cat([zp_indices_clip, z_start_indices], dim=-1)
322
+ index_sample, att_nopix = self.sample(z_start_indices, c_indices,
323
+ steps=z_indices_clip.shape[1],
324
+ temperature=temperature if temperature is not None else 1.0,
325
+ sample=True,
326
+ top_k=top_k if top_k is not None else 100,
327
+ callback=callback if callback is not None else lambda k: None)
328
+ if self.drop_condition:
329
+ z_indices_rec[:, :self.clip] = index_sample
330
+ else:
331
+ z_indices_rec[:, :self.clip] = index_sample[:, self.clip:]
332
+ x_sample_nopix = self.decode_to_img(z_indices_rec, quant_z.shape)
333
+
334
+ # det sample
335
+ z_start_indices = z_indices_clip[:, :0]
336
+ if not self.drop_condition:
337
+ z_start_indices = torch.cat([zp_indices_clip, z_start_indices], dim=-1)
338
+ index_sample, att_det = self.sample(z_start_indices, c_indices,
339
+ steps=z_indices_clip.shape[1],
340
+ sample=False,
341
+ callback=callback if callback is not None else lambda k: None)
342
+ if self.drop_condition:
343
+ z_indices_rec[:, :self.clip] = index_sample
344
+ else:
345
+ z_indices_rec[:, :self.clip] = index_sample[:, self.clip:]
346
+ x_sample_det = self.decode_to_img(z_indices_rec, quant_z.shape)
347
+
348
+ # reconstruction
349
+ x_rec = self.decode_to_img(z_indices, quant_z.shape)
350
+
351
+ log["inputs"] = x
352
+ log["reconstructions"] = x_rec
353
+
354
+ if isinstance(self.cond_stage_key, str):
355
+ cond_is_not_image = self.cond_stage_key != "image"
356
+ cond_has_segmentation = self.cond_stage_key == "segmentation"
357
+ elif isinstance(self.cond_stage_key, ListConfig):
358
+ cond_is_not_image = 'image' not in self.cond_stage_key
359
+ cond_has_segmentation = 'segmentation' in self.cond_stage_key
360
+ else:
361
+ raise NotImplementedError
362
+
363
+ if cond_is_not_image:
364
+ cond_rec = self.cond_stage_model.decode(quant_c)
365
+ if cond_has_segmentation:
366
+ # get image from segmentation mask
367
+ num_classes = cond_rec.shape[1]
368
+
369
+ c = torch.argmax(c, dim=1, keepdim=True)
370
+ c = F.one_hot(c, num_classes=num_classes)
371
+ c = c.squeeze(1).permute(0, 3, 1, 2).float()
372
+ c = self.cond_stage_model.to_rgb(c)
373
+
374
+ cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
375
+ cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
376
+ cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
377
+ cond_rec = self.cond_stage_model.to_rgb(cond_rec)
378
+ log["conditioning_rec"] = cond_rec
379
+ log["conditioning"] = c
380
+
381
+ log["samples_half"] = x_sample
382
+ log["samples_nopix"] = x_sample_nopix
383
+ log["samples_det"] = x_sample_det
384
+ log["att_half"] = att_half
385
+ log["att_nopix"] = att_nopix
386
+ log["att_det"] = att_det
387
+ return log
388
+
389
+ def spec_transform(self, batch):
390
+ wav = batch[self.first_stage_key]
391
+ wav_cond = batch[self.cond_first_stage_key]
392
+ N = wav.shape[0]
393
+ wav_cat = torch.cat([wav, wav_cond], dim=0)
394
+ self.wav_transforms.to(wav_cat.device)
395
+ spec = self.wav_transforms(wav_cat.to(torch.float32))
396
+ batch[self.first_stage_key] = 2 * spec[:N] - 1
397
+ batch[self.cond_first_stage_key] = 2 * spec[N:] - 1
398
+ return batch
399
+
400
+ def get_input(self, key, batch):
401
+ if isinstance(key, str):
402
+ # if batch[key] is 1D; else the batch[key] is 2D
403
+ if key in ['feature', 'target']:
404
+ if self.drop_condition or self.drop_cond_video:
405
+ cond_size = batch[key].shape[1] // 2
406
+ batch[key] = batch[key][:, cond_size:]
407
+ x = self.cond_stage_model.get_input(
408
+ batch, key, drop_cond=(self.drop_condition or self.drop_cond_video)
409
+ )
410
+ else:
411
+ x = batch[key]
412
+ if len(x.shape) == 3:
413
+ x = x[..., None]
414
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
415
+ if x.dtype == torch.double:
416
+ x = x.float()
417
+ elif isinstance(key, ListConfig):
418
+ x = self.cond_stage_model.get_input(batch, key)
419
+ for k, v in x.items():
420
+ if v.dtype == torch.double:
421
+ x[k] = v.float()
422
+ return x
423
+
424
+ def get_xcxp(self, batch, N=None):
425
+ if len(batch[self.first_stage_key].shape) == 2:
426
+ batch = self.spec_transform(batch)
427
+ x = self.get_input(self.first_stage_key, batch)
428
+ c = self.get_input(self.cond_stage_key, batch)
429
+ xp = self.get_input(self.cond_first_stage_key, batch)
430
+ if N is not None:
431
+ x = x[:N]
432
+ xp = xp[:N]
433
+ if isinstance(self.cond_stage_key, ListConfig):
434
+ c = {k: v[:N] for k, v in c.items()}
435
+ else:
436
+ c = c[:N]
437
+ # Drop additional information during training
438
+ if self.drop_condition:
439
+ xp[:] = 0
440
+ if self.drop_video:
441
+ c[:] = 0
442
+ return x, c, xp
443
+
444
+ def shared_step(self, batch, batch_idx):
445
+ x, c, xp = self.get_xcxp(batch)
446
+ logits, target = self(x, c, xp)
447
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
448
+ return loss
449
+
450
+ def training_step(self, batch, batch_idx):
451
+ loss = self.shared_step(batch, batch_idx)
452
+ self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
453
+ return loss
454
+
455
+ def validation_step(self, batch, batch_idx):
456
+ loss = self.shared_step(batch, batch_idx)
457
+ self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
458
+ return loss
459
+
460
+ def configure_optimizers(self):
461
+ """
462
+ Following minGPT:
463
+ This long function is unfortunately doing something very simple and is being very defensive:
464
+ We are separating out all parameters of the model into two buckets: those that will experience
465
+ weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
466
+ We are then returning the PyTorch optimizer object.
467
+ """
468
+ # separate out all parameters to those that will and won't experience regularizing weight decay
469
+ decay = set()
470
+ no_decay = set()
471
+ whitelist_weight_modules = (torch.nn.Linear, )
472
+
473
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, torch.nn.Conv1d, torch.nn.LSTM, torch.nn.GRU)
474
+ for mn, m in self.transformer.named_modules():
475
+ for pn, p in m.named_parameters():
476
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
477
+
478
+ if pn.endswith('bias'):
479
+ # all biases will not be decayed
480
+ no_decay.add(fpn)
481
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
482
+ # weights of whitelist modules will be weight decayed
483
+ decay.add(fpn)
484
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
485
+ # weights of blacklist modules will NOT be weight decayed
486
+ no_decay.add(fpn)
487
+ elif ('weight' in pn or 'bias' in pn) and isinstance(m, (torch.nn.LSTM, torch.nn.GRU)):
488
+ no_decay.add(fpn)
489
+
490
+ # special case the position embedding parameter in the root GPT module as not decayed
491
+ no_decay.add('pos_emb')
492
+
493
+ # validate that we considered every parameter
494
+ param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
495
+ inter_params = decay & no_decay
496
+ union_params = decay | no_decay
497
+ assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
498
+ assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
499
+ % (str(param_dict.keys() - union_params), )
500
+
501
+ # create the pytorch optimizer object
502
+ optim_groups = [
503
+ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
504
+ {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
505
+ ]
506
+ optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
507
+ return optimizer
508
+
509
+
510
+ if __name__ == '__main__':
511
+ from omegaconf import OmegaConf
512
+
513
+ cfg_image = OmegaConf.load('./configs/vggsound_transformer.yaml')
514
+ cfg_image.model.params.first_stage_config.params.ckpt_path = './logs/2021-05-19T22-16-54_vggsound_codebook/checkpoints/last.ckpt'
515
+
516
+ transformer_cfg = cfg_image.model.params.transformer_config
517
+ first_stage_cfg = cfg_image.model.params.first_stage_config
518
+ cond_stage_cfg = cfg_image.model.params.cond_stage_config
519
+ permuter_cfg = cfg_image.model.params.permuter_config
520
+ transformer = Net2NetTransformerAVCond(
521
+ transformer_cfg, first_stage_cfg, cond_stage_cfg, permuter_cfg
522
+ )
523
+
524
+ c = torch.rand(2, 2048, 212)
525
+ x = torch.rand(2, 1, 80, 848)
526
+
527
+ logits, target = transformer(x, c)
528
+ print(logits.shape, target.shape)
foleycrafter/models/specvqgan/models/cond_transformer.py ADDED
@@ -0,0 +1,455 @@
1
+ import sys
2
+
3
+ import pytorch_lightning as pl
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from omegaconf.listconfig import ListConfig
8
+ from torchvision import transforms
9
+ from foleycrafter.models.specvqgan.data.transforms import Wave2Spectrogram
10
+ import torchaudio
11
+
12
+ sys.path.insert(0, '.') # nopep8
13
+ from foleycrafter.models.specvqgan.modules.transformer.mingpt import (GPTClass, GPTFeats, GPTFeatsClass)
14
+ from train import instantiate_from_config
15
+
16
+
17
+ def disabled_train(self, mode=True):
18
+ """Overwrite model.train with this function to make sure train/eval mode
19
+ does not change anymore."""
20
+ return self
21
+
22
+
23
+ class Net2NetTransformer(pl.LightningModule):
24
+ def __init__(self, transformer_config, first_stage_config,
25
+ cond_stage_config,
26
+ first_stage_permuter_config=None, cond_stage_permuter_config=None,
27
+ ckpt_path=None, ignore_keys=[],
28
+ first_stage_key="image",
29
+ cond_stage_key="depth",
30
+ downsample_cond_size=-1,
31
+ pkeep=1.0,
32
+ mel_num=80,
33
+ spec_crop_len=160):
34
+
35
+ super().__init__()
36
+ self.init_first_stage_from_ckpt(first_stage_config)
37
+ self.init_cond_stage_from_ckpt(cond_stage_config)
38
+ if first_stage_permuter_config is None:
39
+ first_stage_permuter_config = {"target": "foleycrafter.models.specvqgan.modules.transformer.permuter.Identity"}
40
+ if cond_stage_permuter_config is None:
41
+ cond_stage_permuter_config = {"target": "foleycrafter.models.specvqgan.modules.transformer.permuter.Identity"}
42
+ self.first_stage_permuter = instantiate_from_config(config=first_stage_permuter_config)
43
+ self.cond_stage_permuter = instantiate_from_config(config=cond_stage_permuter_config)
44
+ self.transformer = instantiate_from_config(config=transformer_config)
45
+
46
+ self.wav_transforms = nn.Sequential(
47
+ torchaudio.transforms.Spectrogram(
48
+ n_fft=1024,
49
+ hop_length=1024//4,
50
+ power=1,
51
+ ),
52
+ torchaudio.transforms.MelScale(
53
+ n_mels=80,
54
+ sample_rate=22050,
55
+ f_min=125,
56
+ f_max=7600,
57
+ n_stft=513,
58
+ norm='slaney'
59
+ ),
60
+ Wave2Spectrogram(mel_num, spec_crop_len),
61
+ )
62
+ ignore_keys = ['wav_transforms']
63
+
64
+ if ckpt_path is not None:
65
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
66
+ self.first_stage_key = first_stage_key
67
+ self.cond_stage_key = cond_stage_key
68
+ self.downsample_cond_size = downsample_cond_size
69
+ self.pkeep = pkeep
70
+ print('>>> model init done.')
71
+
72
+ def init_from_ckpt(self, path, ignore_keys=list()):
73
+ sd = torch.load(path, map_location="cpu")["state_dict"]
74
+ for k in sd.keys():
75
+ for ik in ignore_keys:
76
+ if k.startswith(ik):
77
+ self.print("Deleting key {} from state_dict.".format(k))
78
+ del sd[k]
79
+ self.load_state_dict(sd, strict=False)
80
+ print(f"Restored from {path}")
81
+
82
+ def init_first_stage_from_ckpt(self, config):
83
+ model = instantiate_from_config(config)
84
+ model = model.eval()
85
+ model.train = disabled_train
86
+ self.first_stage_model = model
87
+
88
+ def init_cond_stage_from_ckpt(self, config):
89
+ model = instantiate_from_config(config)
90
+ model = model.eval()
91
+ model.train = disabled_train
92
+ self.cond_stage_model = model
93
+
94
+ def forward(self, x, c):
95
+ # one step to produce the logits
96
+ _, z_indices = self.encode_to_z(x)
97
+ _, c_indices = self.encode_to_c(c)
98
+
99
+ if self.training and self.pkeep < 1.0:
100
+ mask = torch.bernoulli(self.pkeep * torch.ones(z_indices.shape, device=z_indices.device))
101
+ mask = mask.round().to(dtype=torch.int64)
102
+ r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
103
+ a_indices = mask*z_indices+(1-mask)*r_indices
104
+ else:
105
+ a_indices = z_indices
106
+
107
+ # target includes all sequence elements (no need to handle first one
108
+ # differently because we are conditioning)
109
+ target = z_indices
110
+
111
+ # in the case we do not want to encode condition anyhow (e.g. inputs are features)
112
+ if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)):
113
+ # make the prediction
114
+ logits, _, _ = self.transformer(z_indices[:, :-1], c)
115
+ # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
116
+ if isinstance(self.transformer, GPTFeatsClass):
117
+ cond_size = c['feature'].size(-1) + c['target'].size(-1)
118
+ else:
119
+ cond_size = c.size(-1)
120
+ logits = logits[:, cond_size-1:]
121
+ else:
122
+ cz_indices = torch.cat((c_indices, a_indices), dim=1)
123
+ # make the prediction
124
+ logits, _, _ = self.transformer(cz_indices[:, :-1])
125
+ # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
126
+ logits = logits[:, c_indices.shape[1]-1:]
127
+
128
+ return logits, target
129
+
130
+ def top_k_logits(self, logits, k):
131
+ v, ix = torch.topk(logits, k)
132
+ out = logits.clone()
133
+ out[out < v[..., [-1]]] = -float('Inf')
134
+ return out
135
+
136
+ @torch.no_grad()
137
+ def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
138
+ callback=lambda k: None):
139
+ x = x if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)) else torch.cat((c, x), dim=1)
140
+ block_size = self.transformer.get_block_size()
141
+ assert not self.transformer.training
142
+ if self.pkeep <= 0.0:
143
+ raise NotImplementedError('Implement for GPTFeatsCLass')
144
+ raise NotImplementedError('Implement for GPTFeats')
145
+ raise NotImplementedError('Implement for GPTClass')
146
+ raise NotImplementedError('also the model outputs attention')
147
+ # one pass suffices since input is pure noise anyway
148
+ assert len(x.shape)==2
149
+ # noise_shape = (x.shape[0], steps-1)
150
+ # noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
151
+ noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
152
+ x = torch.cat((x,noise),dim=1)
153
+ logits, _ = self.transformer(x)
154
+ # take all logits for now and scale by temp
155
+ logits = logits / temperature
156
+ # optionally crop probabilities to only the top k options
157
+ if top_k is not None:
158
+ logits = self.top_k_logits(logits, top_k)
159
+ # apply softmax to convert to probabilities
160
+ probs = F.softmax(logits, dim=-1)
161
+ # sample from the distribution or take the most likely
162
+ if sample:
163
+ shape = probs.shape
164
+ probs = probs.reshape(shape[0]*shape[1],shape[2])
165
+ ix = torch.multinomial(probs, num_samples=1)
166
+ probs = probs.reshape(shape[0],shape[1],shape[2])
167
+ ix = ix.reshape(shape[0],shape[1])
168
+ else:
169
+ _, ix = torch.topk(probs, k=1, dim=-1)
170
+ # cut off conditioning
171
+ x = ix[:, c.shape[1]-1:]
172
+ else:
173
+ for k in range(steps):
174
+ callback(k)
175
+ if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)):
176
+ # if the assert is removed, you need to make sure that the combined length is not longer than block_size
177
+ if isinstance(self.transformer, GPTFeatsClass):
178
+ cond_size = c['feature'].size(-1) + c['target'].size(-1)
179
+ else:
180
+ cond_size = c.size(-1)
181
+ assert x.size(1) + cond_size <= block_size
182
+
183
+ x_cond = x
184
+ c_cond = c
185
+ logits, _, att = self.transformer(x_cond, c_cond)
186
+ else:
187
+ assert x.size(1) <= block_size # make sure model can see conditioning
188
+ x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
189
+ logits, _, att = self.transformer(x_cond)
190
+ # pluck the logits at the final step and scale by temperature
191
+ logits = logits[:, -1, :] / temperature
192
+ # optionally crop probabilities to only the top k options
193
+ if top_k is not None:
194
+ logits = self.top_k_logits(logits, top_k)
195
+ # apply softmax to convert to probabilities
196
+ probs = F.softmax(logits, dim=-1)
197
+ # sample from the distribution or take the most likely
198
+ if sample:
199
+ ix = torch.multinomial(probs, num_samples=1)
200
+ else:
201
+ _, ix = torch.topk(probs, k=1, dim=-1)
202
+ # append to the sequence and continue
203
+ x = torch.cat((x, ix), dim=1)
204
+ # cut off conditioning
205
+ x = x if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)) else x[:, c.shape[1]:]
206
+ return x, att.detach().cpu()
207
+
208
+ @torch.no_grad()
209
+ def encode_to_z(self, x):
210
+ quant_z, _, info = self.first_stage_model.encode(x)
211
+ indices = info[2].view(quant_z.shape[0], -1)
212
+ indices = self.first_stage_permuter(indices)
213
+ return quant_z, indices
214
+
215
+ @torch.no_grad()
216
+ def encode_to_c(self, c):
217
+ if self.downsample_cond_size > -1:
218
+ c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
219
+ quant_c, _, info = self.cond_stage_model.encode(c)
220
+ if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)):
221
+ # these are not indices but raw features or a class
222
+ indices = info[2]
223
+ else:
224
+ indices = info[2].view(quant_c.shape[0], -1)
225
+ indices = self.cond_stage_permuter(indices)
226
+ return quant_c, indices
227
+
228
+ @torch.no_grad()
229
+ def decode_to_img(self, index, zshape, stage='first'):
230
+ if stage == 'first':
231
+ index = self.first_stage_permuter(index, reverse=True)
232
+ elif stage == 'cond':
233
+ print('in cond stage in decode_to_img which is unexpected')
234
+ index = self.cond_stage_permuter(index, reverse=True)
235
+ else:
236
+ raise NotImplementedError
237
+
238
+ bhwc = (zshape[0], zshape[2], zshape[3], zshape[1])
239
+ quant_z = self.first_stage_model.quantize.get_codebook_entry(index.reshape(-1), shape=bhwc)
240
+ x = self.first_stage_model.decode(quant_z)
241
+ return x
242
+
243
+ @torch.no_grad()
244
+ def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
245
+ log = dict()
246
+
247
+ N = 4
248
+ if lr_interface:
249
+ x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
250
+ else:
251
+ x, c = self.get_xc(batch, N)
252
+ x = x.to(device=self.device)
253
+ # c = c.to(device=self.device)
254
+ if isinstance(c, dict):
255
+ c = {k: v.to(self.device) for k, v in c.items()}
256
+ else:
257
+ c = c.to(self.device)
258
+
259
+ quant_z, z_indices = self.encode_to_z(x)
260
+ quant_c, c_indices = self.encode_to_c(c) # output can be features or a single class or a featcls dict
261
+
262
+ # create a "half" sample
263
+ z_start_indices = z_indices[:, :z_indices.shape[1]//2]
264
+ index_sample, att_half = self.sample(z_start_indices, c_indices,
265
+ steps=z_indices.shape[1]-z_start_indices.shape[1],
266
+ temperature=temperature if temperature is not None else 1.0,
267
+ sample=True,
268
+ top_k=top_k if top_k is not None else 100,
269
+ callback=callback if callback is not None else lambda k: None)
270
+ x_sample = self.decode_to_img(index_sample, quant_z.shape)
271
+
272
+ # sample
273
+ z_start_indices = z_indices[:, :0]
274
+ index_sample, att_nopix = self.sample(z_start_indices, c_indices,
275
+ steps=z_indices.shape[1],
276
+ temperature=temperature if temperature is not None else 1.0,
277
+ sample=True,
278
+ top_k=top_k if top_k is not None else 100,
279
+ callback=callback if callback is not None else lambda k: None)
280
+ x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
281
+
282
+ # det sample
283
+ z_start_indices = z_indices[:, :0]
284
+ index_sample, att_det = self.sample(z_start_indices, c_indices,
285
+ steps=z_indices.shape[1],
286
+ sample=False,
287
+ callback=callback if callback is not None else lambda k: None)
288
+ x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
289
+
290
+ # reconstruction
291
+ x_rec = self.decode_to_img(z_indices, quant_z.shape)
292
+
293
+ log["inputs"] = x
294
+ log["reconstructions"] = x_rec
295
+
296
+ if isinstance(self.cond_stage_key, str):
297
+ cond_is_not_image = self.cond_stage_key != "image"
298
+ cond_has_segmentation = self.cond_stage_key == "segmentation"
299
+ elif isinstance(self.cond_stage_key, ListConfig):
300
+ cond_is_not_image = 'image' not in self.cond_stage_key
301
+ cond_has_segmentation = 'segmentation' in self.cond_stage_key
302
+ else:
303
+ raise NotImplementedError
304
+
305
+ if cond_is_not_image:
306
+ cond_rec = self.cond_stage_model.decode(quant_c)
307
+ if cond_has_segmentation:
308
+ # get image from segmentation mask
309
+ num_classes = cond_rec.shape[1]
310
+
311
+ c = torch.argmax(c, dim=1, keepdim=True)
312
+ c = F.one_hot(c, num_classes=num_classes)
313
+ c = c.squeeze(1).permute(0, 3, 1, 2).float()
314
+ c = self.cond_stage_model.to_rgb(c)
315
+
316
+ cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
317
+ cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
318
+ cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
319
+ cond_rec = self.cond_stage_model.to_rgb(cond_rec)
320
+ log["conditioning_rec"] = cond_rec
321
+ log["conditioning"] = c
322
+
323
+ log["samples_half"] = x_sample
324
+ log["samples_nopix"] = x_sample_nopix
325
+ log["samples_det"] = x_sample_det
326
+ log["att_half"] = att_half
327
+ log["att_nopix"] = att_nopix
328
+ log["att_det"] = att_det
329
+ return log
330
+
331
+ def spec_transform(self, batch):
332
+ wav = batch[self.first_stage_key]
333
+ N = wav.shape[0]
334
+ self.wav_transforms.to(wav.device)
335
+ spec = self.wav_transforms(wav.to(torch.float32))
336
+ batch[self.first_stage_key] = 2 * spec[:N] - 1
337
+ return batch
338
+
339
+ def get_input(self, key, batch):
340
+ if isinstance(key, str):
341
+ # 'feature'/'target' entries are fetched via the cond stage model; any other key is a 2D spectrogram taken from the batch directly
342
+ if key in ['feature', 'target']:
343
+ x = self.cond_stage_model.get_input(batch, key)
344
+ else:
345
+ x = batch[key]
346
+ if len(x.shape) == 3:
347
+ x = x[..., None]
348
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
349
+ if x.dtype == torch.double:
350
+ x = x.float()
351
+ elif isinstance(key, ListConfig):
352
+ x = self.cond_stage_model.get_input(batch, key)
353
+ for k, v in x.items():
354
+ if v.dtype == torch.double:
355
+ x[k] = v.float()
356
+ return x
357
+
358
+ def get_xc(self, batch, N=None):
359
+ if len(batch[self.first_stage_key].shape) == 2:
360
+ batch = self.spec_transform(batch)
361
+ x = self.get_input(self.first_stage_key, batch)
362
+ c = self.get_input(self.cond_stage_key, batch)
363
+ if N is not None:
364
+ x = x[:N]
365
+ if isinstance(self.cond_stage_key, ListConfig):
366
+ c = {k: v[:N] for k, v in c.items()}
367
+ else:
368
+ c = c[:N]
369
+ return x, c
370
+
371
+ def shared_step(self, batch, batch_idx):
372
+ x, c = self.get_xc(batch)
373
+ logits, target = self(x, c)
374
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
375
+ return loss
376
+
377
+ def training_step(self, batch, batch_idx):
378
+ loss = self.shared_step(batch, batch_idx)
379
+ self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
380
+ return loss
381
+
382
+ def validation_step(self, batch, batch_idx):
383
+ loss = self.shared_step(batch, batch_idx)
384
+ self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
385
+ return loss
386
+
387
+ def configure_optimizers(self):
388
+ """
389
+ Following minGPT:
390
+ This long function is unfortunately doing something very simple and is being very defensive:
391
+ We are separating out all parameters of the model into two buckets: those that will experience
392
+ weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
393
+ We are then returning the PyTorch optimizer object.
394
+ """
395
+ # separate out all parameters to those that will and won't experience regularizing weight decay
396
+ decay = set()
397
+ no_decay = set()
398
+ whitelist_weight_modules = (torch.nn.Linear, )
399
+
400
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, torch.nn.Conv1d, torch.nn.LSTM, torch.nn.GRU)
401
+ for mn, m in self.transformer.named_modules():
402
+ for pn, p in m.named_parameters():
403
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
404
+
405
+ if pn.endswith('bias'):
406
+ # all biases will not be decayed
407
+ no_decay.add(fpn)
408
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
409
+ # weights of whitelist modules will be weight decayed
410
+ decay.add(fpn)
411
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
412
+ # weights of blacklist modules will NOT be weight decayed
413
+ no_decay.add(fpn)
414
+ elif ('weight' in pn or 'bias' in pn) and isinstance(m, (torch.nn.LSTM, torch.nn.GRU)):
415
+ no_decay.add(fpn)
416
+
417
+ # special case the position embedding parameter in the root GPT module as not decayed
418
+ no_decay.add('pos_emb')
419
+
420
+ # validate that we considered every parameter
421
+ param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
422
+ inter_params = decay & no_decay
423
+ union_params = decay | no_decay
424
+ assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
425
+ assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
426
+ % (str(param_dict.keys() - union_params), )
427
+
428
+ # create the pytorch optimizer object
429
+ optim_groups = [
430
+ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
431
+ {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
432
+ ]
433
+ optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
434
+ return optimizer
435
+
436
+
437
+ if __name__ == '__main__':
438
+ from omegaconf import OmegaConf
439
+
440
+ cfg_image = OmegaConf.load('./configs/vggsound_transformer.yaml')
441
+ cfg_image.model.params.first_stage_config.params.ckpt_path = './logs/2021-05-19T22-16-54_vggsound_codebook/checkpoints/last.ckpt'
442
+
443
+ transformer_cfg = cfg_image.model.params.transformer_config
444
+ first_stage_cfg = cfg_image.model.params.first_stage_config
445
+ cond_stage_cfg = cfg_image.model.params.cond_stage_config
446
+ permuter_cfg = cfg_image.model.params.permuter_config
447
+ transformer = Net2NetTransformer(
448
+ transformer_cfg, first_stage_cfg, cond_stage_cfg, permuter_cfg
449
+ )
450
+
451
+ c = torch.rand(2, 2048, 212)
452
+ x = torch.rand(2, 1, 80, 160)
453
+
454
+ logits, target = transformer(x, c)
455
+ print(logits.shape, target.shape)
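The autoregressive `sample` loop above relies on `self.top_k_logits`, which is defined earlier in this file and not shown in this hunk. For reference, a minimal sketch consistent with the usual minGPT-style helper (the exact body in the repository may differ) could look like this for 2D `(batch, vocab)` logits:

    @staticmethod
    def top_k_logits(logits, k):
        # keep only the k largest logits per row; everything else is set to -inf
        # and therefore receives zero probability after the softmax in sample()
        v, _ = torch.topk(logits, k)
        out = logits.clone()
        out[out < v[..., [-1]]] = -float('inf')
        return out

With `sample=True`, `torch.multinomial` then only ever draws from those top-k tokens.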
foleycrafter/models/specvqgan/models/vqgan.py ADDED
@@ -0,0 +1,397 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchaudio
4
+ from torchvision import transforms
5
+ import torch.nn.functional as F
6
+ import pytorch_lightning as pl
7
+
8
+ import sys
9
+ import math
10
+ sys.path.insert(0, '.') # nopep8
11
+ from train import instantiate_from_config
12
+ from foleycrafter.models.specvqgan.data.transforms import Wave2Spectrogram, NormalizeAudio
13
+
14
+ from foleycrafter.models.specvqgan.modules.diffusionmodules.model import Encoder, Decoder, Encoder1d, Decoder1d
15
+ from foleycrafter.models.specvqgan.modules.vqvae.quantize import VectorQuantizer, VectorQuantizer1d
16
+
17
+
18
+ class VQModel(pl.LightningModule):
19
+ def __init__(self,
20
+ ddconfig,
21
+ lossconfig,
22
+ n_embed,
23
+ embed_dim,
24
+ ckpt_path=None,
25
+ ignore_keys=[],
26
+ image_key="image",
27
+ colorize_nlabels=None,
28
+ monitor=None,
29
+ L=10.,
30
+ mel_num=80,
31
+ spec_crop_len=160,
32
+ normalize=False,
33
+ freeze_encoder=False,
34
+ ):
35
+ super().__init__()
36
+ self.image_key = image_key
37
+ # we need this one for compatibility in train.ImageLogger.log_img if statement
38
+ self.first_stage_key = image_key
39
+ self.encoder = Encoder(**ddconfig)
40
+ self.decoder = Decoder(**ddconfig)
41
+ self.loss = instantiate_from_config(lossconfig)
42
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
43
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
44
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
45
+
46
+ aug_list = [
47
+ torchaudio.transforms.Spectrogram(
48
+ n_fft=1024,
49
+ hop_length=1024//4,
50
+ power=1,
51
+ ),
52
+ torchaudio.transforms.MelScale(
53
+ n_mels=80,
54
+ sample_rate=22050,
55
+ f_min=125,
56
+ f_max=7600,
57
+ n_stft=513,
58
+ norm='slaney'
59
+ ),
60
+ Wave2Spectrogram(mel_num, spec_crop_len),
61
+ ]
62
+ if normalize:
63
+ aug_list = [transforms.RandomApply([NormalizeAudio()], p=1. if normalize else 0.)] + aug_list
64
+
65
+ if not freeze_encoder:
66
+ self.wav_transforms = nn.Sequential(*aug_list)
67
+ ignore_keys += ['first_stage_model.wav_transforms', 'wav_transforms']
68
+
69
+ if ckpt_path is not None:
70
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
71
+ if colorize_nlabels is not None:
72
+ assert type(colorize_nlabels)==int
73
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
74
+ if monitor is not None:
75
+ self.monitor = monitor
76
+ self.used_codes = []
77
+ self.counts = [0 for _ in range(self.quantize.n_e)]
78
+
79
+ if freeze_encoder:
80
+ for param in self.encoder.parameters():
81
+ param.requires_grad = False
82
+ for param in self.quantize.parameters():
83
+ param.requires_grad = False
84
+ for param in self.quant_conv.parameters():
85
+ param.requires_grad = False
86
+
87
+ def init_from_ckpt(self, path, ignore_keys=list()):
88
+ sd = torch.load(path, map_location="cpu")["state_dict"]
89
+ keys = list(sd.keys())
90
+ for k in keys:
91
+ for ik in ignore_keys:
92
+ if k.startswith(ik):
93
+ print("Deleting key {} from state_dict.".format(k))
94
+ del sd[k]
95
+ self.load_state_dict(sd, strict=False)
96
+ print(f"Restored from {path}")
97
+
98
+ def encode(self, x):
99
+ h = self.encoder(x) # 2d: (B, 256, 16, 16) <- (B, 3, 256, 256)
100
+ h = self.quant_conv(h) # 2d: (B, 256, 16, 16)
101
+ quant, emb_loss, info = self.quantize(h) # (B, 256, 16, 16), (), ((), (768, 1024), (768, 1))
102
+ if not self.training:
103
+ self.counts = [info[2].squeeze().tolist().count(i) + self.counts[i] for i in range(self.quantize.n_e)]
104
+ return quant, emb_loss, info
105
+
106
+ def decode(self, quant):
107
+ quant = self.post_quant_conv(quant)
108
+ dec = self.decoder(quant)
109
+ return dec
110
+
111
+ def decode_code(self, code_b):
112
+ quant_b = self.quantize.embed_code(code_b)
113
+ dec = self.decode(quant_b)
114
+ return dec
115
+
116
+ def forward(self, input):
117
+ quant, diff, _ = self.encode(input)
118
+ dec = self.decode(quant)
119
+ return dec, diff
120
+
121
+ def get_input(self, batch, k):
122
+ x = batch[k]
123
+ if len(x.shape) == 2:
124
+ x = self.spec_trans(x)
125
+ if len(x.shape) == 3:
126
+ x = x[..., None]
127
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
128
+ return x.float()
129
+
130
+ def spec_trans(self, wav):
131
+ self.wav_transforms.to(wav.device)
132
+ spec = self.wav_transforms(wav.to(torch.float32))
133
+ return 2 * spec - 1
134
+
135
+ def training_step(self, batch, batch_idx, optimizer_idx):
136
+ x = self.get_input(batch, self.image_key)
137
+ xrec, qloss = self(x)
138
+
139
+ if optimizer_idx == 0:
140
+ # autoencode
141
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
142
+ last_layer=self.get_last_layer(), split="train")
143
+
144
+ self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
145
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
146
+ return aeloss
147
+
148
+ if optimizer_idx == 1:
149
+ # discriminator
150
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151
+ last_layer=self.get_last_layer(), split="train")
152
+ self.log("train/disc_loss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
153
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
154
+ return discloss
155
+
156
+ def validation_step(self, batch, batch_idx):
157
+ if batch_idx == 0 and self.global_step != 0 and sum(self.counts) > 0:
158
+ zero_hit_codes = len([1 for count in self.counts if count == 0])
159
+ used_codes = []
160
+ for c, count in enumerate(self.counts):
161
+ used_codes.extend([c] * count)
162
+ self.logger.experiment.add_histogram('val/code_hits', torch.tensor(used_codes), self.global_step)
163
+ self.logger.experiment.add_scalar('val/zero_hit_codes', zero_hit_codes, self.global_step)
164
+ self.counts = [0 for _ in range(self.quantize.n_e)]
165
+ x = self.get_input(batch, self.image_key)
166
+ xrec, qloss = self(x)
167
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
168
+ last_layer=self.get_last_layer(), split="val")
169
+
170
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
171
+ last_layer=self.get_last_layer(), split="val")
172
+ rec_loss = log_dict_ae['val/rec_loss']
173
+ self.log('val/rec_loss', rec_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
174
+ self.log('val/aeloss', aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
175
+ self.log_dict(log_dict_ae)
176
+ self.log_dict(log_dict_disc)
177
+ return self.log_dict
178
+
179
+ def configure_optimizers(self):
180
+ lr = self.learning_rate
181
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
182
+ list(self.decoder.parameters()) +
183
+ list(self.quantize.parameters()) +
184
+ list(self.quant_conv.parameters()) +
185
+ list(self.post_quant_conv.parameters()),
186
+ lr=lr, betas=(0.5, 0.9))
187
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
188
+ lr=lr, betas=(0.5, 0.9))
189
+ return [opt_ae, opt_disc], []
190
+
191
+ def get_last_layer(self):
192
+ return self.decoder.conv_out.weight
193
+
194
+ def log_images(self, batch, **kwargs):
195
+ log = dict()
196
+ x = self.get_input(batch, self.image_key)
197
+ x = x.to(self.device)
198
+ xrec, _ = self(x)
199
+ if x.shape[1] > 3:
200
+ # colorize with random projection
201
+ assert xrec.shape[1] > 3
202
+ x = self.to_rgb(x)
203
+ xrec = self.to_rgb(xrec)
204
+ log["inputs"] = x
205
+ log["reconstructions"] = xrec
206
+ return log
207
+
208
+ def to_rgb(self, x):
209
+ assert self.image_key == "segmentation"
210
+ if not hasattr(self, "colorize"):
211
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
212
+ x = F.conv2d(x, weight=self.colorize)
213
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
214
+ return x
215
+
216
+
217
+ class VQModel1d(VQModel):
218
+ def __init__(self, ddconfig, lossconfig, n_embed, embed_dim, ckpt_path=None, ignore_keys=[],
219
+ image_key='feature', colorize_nlabels=None, monitor=None):
220
+ # ckpt_path is not passed to super() because otherwise it would try to load a 1D checkpoint into the 2D parent model
221
+ super().__init__(ddconfig, lossconfig, n_embed, embed_dim)
222
+ self.image_key = image_key
223
+ # we need this one for compatibility in train.ImageLogger.log_img if statement
224
+ self.first_stage_key = image_key
225
+ self.encoder = Encoder1d(**ddconfig)
226
+ self.decoder = Decoder1d(**ddconfig)
227
+ self.loss = instantiate_from_config(lossconfig)
228
+ self.quantize = VectorQuantizer1d(n_embed, embed_dim, beta=0.25)
229
+ self.quant_conv = torch.nn.Conv1d(ddconfig['z_channels'], embed_dim, 1)
230
+ self.post_quant_conv = torch.nn.Conv1d(embed_dim, ddconfig['z_channels'], 1)
231
+ if ckpt_path is not None:
232
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
233
+ if colorize_nlabels is not None:
234
+ assert type(colorize_nlabels)==int
235
+ self.register_buffer('colorize', torch.randn(3, colorize_nlabels, 1, 1))
236
+ if monitor is not None:
237
+ self.monitor = monitor
238
+
239
+ def get_input(self, batch, k):
240
+ x = batch[k]
241
+ if self.image_key == 'feature':
242
+ x = x.permute(0, 2, 1)
243
+ elif self.image_key == 'image':
244
+ x = x.unsqueeze(1)
245
+ x = x.to(memory_format=torch.contiguous_format)
246
+ return x.float()
247
+
248
+ def forward(self, input):
249
+ if self.image_key == 'image':
250
+ input = input.squeeze(1)
251
+ quant, diff, _ = self.encode(input)
252
+ dec = self.decode(quant)
253
+ if self.image_key == 'image':
254
+ dec = dec.unsqueeze(1)
255
+ return dec, diff
256
+
257
+ def log_images(self, batch, **kwargs):
258
+ if self.image_key == 'image':
259
+ log = dict()
260
+ x = self.get_input(batch, self.image_key)
261
+ x = x.to(self.device)
262
+ xrec, _ = self(x)
263
+ if x.shape[1] > 3:
264
+ # colorize with random projection
265
+ assert xrec.shape[1] > 3
266
+ x = self.to_rgb(x)
267
+ xrec = self.to_rgb(xrec)
268
+ log['inputs'] = x
269
+ log['reconstructions'] = xrec
270
+ return log
271
+ else:
272
+ raise NotImplementedError('1d input should be treated differently')
273
+
274
+ def to_rgb(self, batch, **kwargs):
275
+ raise NotImplementedError('1d input should be treated differently')
276
+
277
+
278
+ class VQSegmentationModel(VQModel):
279
+ def __init__(self, n_labels, *args, **kwargs):
280
+ super().__init__(*args, **kwargs)
281
+ self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
282
+
283
+ def configure_optimizers(self):
284
+ lr = self.learning_rate
285
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
286
+ list(self.decoder.parameters())+
287
+ list(self.quantize.parameters())+
288
+ list(self.quant_conv.parameters())+
289
+ list(self.post_quant_conv.parameters()),
290
+ lr=lr, betas=(0.5, 0.9))
291
+ return opt_ae
292
+
293
+ def training_step(self, batch, batch_idx):
294
+ x = self.get_input(batch, self.image_key)
295
+ xrec, qloss = self(x)
296
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
297
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
298
+ return aeloss
299
+
300
+ def validation_step(self, batch, batch_idx):
301
+ x = self.get_input(batch, self.image_key)
302
+ xrec, qloss = self(x)
303
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
304
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
305
+ total_loss = log_dict_ae["val/total_loss"]
306
+ self.log("val/total_loss", total_loss,
307
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
308
+ return aeloss
309
+
310
+ @torch.no_grad()
311
+ def log_images(self, batch, **kwargs):
312
+ log = dict()
313
+ x = self.get_input(batch, self.image_key)
314
+ x = x.to(self.device)
315
+ xrec, _ = self(x)
316
+ if x.shape[1] > 3:
317
+ # colorize with random projection
318
+ assert xrec.shape[1] > 3
319
+ # convert logits to indices
320
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
321
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
322
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
323
+ x = self.to_rgb(x)
324
+ xrec = self.to_rgb(xrec)
325
+ log["inputs"] = x
326
+ log["reconstructions"] = xrec
327
+ return log
328
+
329
+
330
+ class VQNoDiscModel(VQModel):
331
+ def __init__(self,
332
+ ddconfig,
333
+ lossconfig,
334
+ n_embed,
335
+ embed_dim,
336
+ ckpt_path=None,
337
+ ignore_keys=[],
338
+ image_key="image",
339
+ colorize_nlabels=None
340
+ ):
341
+ super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
342
+ ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
343
+ colorize_nlabels=colorize_nlabels)
344
+
345
+ def training_step(self, batch, batch_idx):
346
+ x = self.get_input(batch, self.image_key)
347
+ xrec, qloss = self(x)
348
+ # autoencode
349
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
350
+ output = pl.TrainResult(minimize=aeloss)
351
+ output.log("train/aeloss", aeloss,
352
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
353
+ output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
354
+ return output
355
+
356
+ def validation_step(self, batch, batch_idx):
357
+ x = self.get_input(batch, self.image_key)
358
+ xrec, qloss = self(x)
359
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
360
+ rec_loss = log_dict_ae["val/rec_loss"]
361
+ output = pl.EvalResult(checkpoint_on=rec_loss)
362
+ output.log("val/rec_loss", rec_loss,
363
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
364
+ output.log("val/aeloss", aeloss,
365
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
366
+ output.log_dict(log_dict_ae)
367
+
368
+ return output
369
+
370
+ def configure_optimizers(self):
371
+ optimizer = torch.optim.Adam(list(self.encoder.parameters()) +
372
+ list(self.decoder.parameters()) +
373
+ list(self.quantize.parameters()) +
374
+ list(self.quant_conv.parameters()) +
375
+ list(self.post_quant_conv.parameters()),
376
+ lr=self.learning_rate, betas=(0.5, 0.9))
377
+ return optimizer
378
+
379
+
380
+ if __name__ == '__main__':
381
+ from omegaconf import OmegaConf
382
+ from train import instantiate_from_config
383
+
384
+ image_key = 'image'
385
+ cfg_audio = OmegaConf.load('./configs/vggsound_codebook.yaml')
386
+ model = VQModel(cfg_audio.model.params.ddconfig,
387
+ cfg_audio.model.params.lossconfig,
388
+ cfg_audio.model.params.n_embed,
389
+ cfg_audio.model.params.embed_dim,
390
+ image_key='image')
391
+ batch = {
392
+ 'image': torch.rand((4, 80, 848)),
393
+ 'file_path_': ['data/vggsound/mel123.npy', 'data/vggsound/mel123.npy', 'data/vggsound/mel123.npy'],
394
+ 'class': [1, 1, 1],
395
+ }
396
+ xrec, qloss = model(model.get_input(batch, image_key))
397
+ print(xrec.shape, qloss.shape)
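The `__main__` example above feeds a precomputed mel spectrogram of shape (4, 80, 848). When the batch entry is a raw 2D waveform instead, `get_input` detects the 2D shape and routes it through `spec_trans` (the `wav_transforms` pipeline followed by `2 * spec - 1`). A rough usage sketch; the 22050 Hz sample rate, the 10 s length and the resulting `(B, 1, mel_num, spec_crop_len)` shape are assumptions based on the transform defaults above, not something this diff pins down:

    # hypothetical raw-audio batch: (B, num_samples) float32 waveforms
    wav_batch = {'image': torch.rand(4, 22050 * 10)}
    x = model.get_input(wav_batch, 'image')   # 2D input -> spec_trans -> mel scaled by 2 * spec - 1
    xrec, qloss = model(x)
    print(x.shape, xrec.shape)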
foleycrafter/models/specvqgan/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,999 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+
7
+
8
+ def get_timestep_embedding(timesteps, embedding_dim):
9
+ """
10
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
11
+ From Fairseq.
12
+ Build sinusoidal embeddings.
13
+ This matches the implementation in tensor2tensor, but differs slightly
14
+ from the description in Section 3.5 of "Attention Is All You Need".
15
+ """
16
+ assert len(timesteps.shape) == 1
17
+
18
+ half_dim = embedding_dim // 2
19
+ emb = math.log(10000) / (half_dim - 1)
20
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
21
+ emb = emb.to(device=timesteps.device)
22
+ emb = timesteps.float()[:, None] * emb[None, :]
23
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
24
+ if embedding_dim % 2 == 1: # zero pad
25
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
26
+ return emb
27
+
28
+
29
+ def nonlinearity(x):
30
+ # swish
31
+ return x*torch.sigmoid(x)
32
+
33
+
34
+ def Normalize(in_channels):
35
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
36
+
37
+ class Upsample(nn.Module):
38
+ def __init__(self, in_channels, with_conv):
39
+ super().__init__()
40
+ self.with_conv = with_conv
41
+ if self.with_conv:
42
+ self.conv = torch.nn.Conv2d(in_channels,
43
+ in_channels,
44
+ kernel_size=3,
45
+ stride=1,
46
+ padding=1)
47
+
48
+ def forward(self, x):
49
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
50
+ if self.with_conv:
51
+ x = self.conv(x)
52
+ return x
53
+
54
+ class Upsample1d(Upsample):
55
+ def __init__(self, in_channels, with_conv):
56
+ super().__init__(in_channels, with_conv)
57
+ if self.with_conv:
58
+ self.conv = torch.nn.Conv1d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
59
+
60
+ class Downsample(nn.Module):
61
+ def __init__(self, in_channels, with_conv):
62
+ super().__init__()
63
+ self.with_conv = with_conv
64
+ if self.with_conv:
65
+ # no asymmetric padding in torch conv, must do it ourselves
66
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
67
+ self.pad = (0, 1, 0, 1)
68
+ else:
69
+ self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
70
+
71
+ def forward(self, x):
72
+ if self.with_conv: # bp: check self.avgpool and self.pad
73
+ x = torch.nn.functional.pad(x, self.pad, mode="constant", value=0)
74
+ x = self.conv(x)
75
+ else:
76
+ x = self.avg_pool(x)
77
+ return x
78
+
79
+ class Downsample1d(Downsample):
80
+
81
+ def __init__(self, in_channels, with_conv):
82
+ super().__init__(in_channels, with_conv)
83
+ if self.with_conv:
84
+ # no asymmetric padding in torch conv, must do it ourselves
85
+ # TODO: can we replace it just with conv2d with padding 1?
86
+ self.conv = torch.nn.Conv1d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
87
+ self.pad = (1, 1)
88
+ else:
89
+ self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2)
90
+
91
+
92
+ class ResnetBlock(nn.Module):
93
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
94
+ dropout, temb_channels=512):
95
+ super().__init__()
96
+ self.in_channels = in_channels
97
+ out_channels = in_channels if out_channels is None else out_channels
98
+ self.out_channels = out_channels
99
+ self.use_conv_shortcut = conv_shortcut
100
+
101
+ self.norm1 = Normalize(in_channels)
102
+ self.conv1 = torch.nn.Conv2d(in_channels,
103
+ out_channels,
104
+ kernel_size=3,
105
+ stride=1,
106
+ padding=1)
107
+ if temb_channels > 0:
108
+ self.temb_proj = torch.nn.Linear(temb_channels,
109
+ out_channels)
110
+ self.norm2 = Normalize(out_channels)
111
+ self.dropout = torch.nn.Dropout(dropout)
112
+ self.conv2 = torch.nn.Conv2d(out_channels,
113
+ out_channels,
114
+ kernel_size=3,
115
+ stride=1,
116
+ padding=1)
117
+ if self.in_channels != self.out_channels:
118
+ if self.use_conv_shortcut:
119
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
120
+ out_channels,
121
+ kernel_size=3,
122
+ stride=1,
123
+ padding=1)
124
+ else:
125
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
126
+ out_channels,
127
+ kernel_size=1,
128
+ stride=1,
129
+ padding=0)
130
+
131
+ def forward(self, x, temb):
132
+ h = x
133
+ h = self.norm1(h)
134
+ h = nonlinearity(h)
135
+ h = self.conv1(h)
136
+
137
+ if temb is not None:
138
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
139
+
140
+ h = self.norm2(h)
141
+ h = nonlinearity(h)
142
+ h = self.dropout(h)
143
+ h = self.conv2(h)
144
+
145
+ if self.in_channels != self.out_channels:
146
+ if self.use_conv_shortcut:
147
+ x = self.conv_shortcut(x)
148
+ else:
149
+ x = self.nin_shortcut(x)
150
+
151
+ return x+h
152
+
153
+ class ResnetBlock1d(ResnetBlock):
154
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
155
+ dropout, temb_channels=512):
156
+ super().__init__(in_channels=in_channels, out_channels=out_channels,
157
+ conv_shortcut=conv_shortcut, dropout=dropout, temb_channels=temb_channels)
158
+ # redefining different elements (forward is going to be the same as in ResnetBlock)
159
+ if temb_channels > 0:
160
+ raise NotImplementedError('go to ResnetBlock and figure out how to deal with it in forward')
161
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
162
+
163
+ self.conv1 = torch.nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
164
+ self.conv2 = torch.nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
165
+ if self.in_channels != self.out_channels:
166
+ if self.use_conv_shortcut:
167
+ self.conv_shortcut = torch.nn.Conv1d(in_channels, out_channels, kernel_size=3,
168
+ stride=1, padding=1)
169
+ else:
170
+ self.nin_shortcut = torch.nn.Conv1d(in_channels, out_channels, kernel_size=1,
171
+ stride=1, padding=0)
172
+
173
+
174
+ class AttnBlock(nn.Module):
175
+ def __init__(self, in_channels):
176
+ super().__init__()
177
+ self.in_channels = in_channels
178
+
179
+ self.norm = Normalize(in_channels)
180
+ self.q = torch.nn.Conv2d(in_channels,
181
+ in_channels,
182
+ kernel_size=1,
183
+ stride=1,
184
+ padding=0)
185
+ self.k = torch.nn.Conv2d(in_channels,
186
+ in_channels,
187
+ kernel_size=1,
188
+ stride=1,
189
+ padding=0)
190
+ self.v = torch.nn.Conv2d(in_channels,
191
+ in_channels,
192
+ kernel_size=1,
193
+ stride=1,
194
+ padding=0)
195
+ self.proj_out = torch.nn.Conv2d(in_channels,
196
+ in_channels,
197
+ kernel_size=1,
198
+ stride=1,
199
+ padding=0)
200
+
201
+
202
+ def forward(self, x):
203
+ h_ = x
204
+ h_ = self.norm(h_)
205
+ q = self.q(h_)
206
+ k = self.k(h_)
207
+ v = self.v(h_)
208
+
209
+ # compute attention
210
+ b,c,h,w = q.shape
211
+ q = q.reshape(b,c,h*w)
212
+ q = q.permute(0,2,1) # b,hw,c
213
+ k = k.reshape(b,c,h*w) # b,c,hw
214
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
215
+ w_ = w_ * (int(c)**(-0.5))
216
+ w_ = torch.nn.functional.softmax(w_, dim=2)
217
+
218
+ # attend to values
219
+ v = v.reshape(b,c,h*w)
220
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
221
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
222
+ h_ = h_.reshape(b,c,h,w)
223
+
224
+ h_ = self.proj_out(h_)
225
+
226
+ return x+h_
227
+
228
+ class AttnBlock1d(nn.Module):
229
+
230
+ def __init__(self, in_channels):
231
+ super().__init__()
232
+ self.in_channels = in_channels
233
+
234
+ self.norm = Normalize(in_channels)
235
+ self.q = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
236
+ self.k = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
237
+ self.v = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
238
+ self.proj_out = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
239
+
240
+ def forward(self, x):
241
+ h_ = x
242
+ h_ = self.norm(h_)
243
+ q = self.q(h_)
244
+ k = self.k(h_)
245
+ v = self.v(h_)
246
+
247
+ # compute attention
248
+ b, c, t = q.shape
249
+ q = q.permute(0, 2, 1) # b,t,c
250
+ w_ = torch.bmm(q, k) # b,t,t w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
251
+ w_ = w_ * (int(c) ** (-0.5))
252
+ w_ = torch.nn.functional.softmax(w_, dim=2)
253
+
254
+ # attend to values
255
+ w_ = w_.permute(0, 2, 1) # b,t,t (first t of k, second of q)
256
+ h_ = torch.bmm(v, w_) # b,c,t (t of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
257
+
258
+ h_ = self.proj_out(h_)
259
+
260
+ return x + h_
261
+
262
+
263
+ class Model(nn.Module):
264
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
265
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
266
+ resolution, use_timestep=True):
267
+ super().__init__()
268
+ self.ch = ch
269
+ self.temb_ch = self.ch*4
270
+ self.num_resolutions = len(ch_mult)
271
+ self.num_res_blocks = num_res_blocks
272
+ self.resolution = resolution
273
+ self.in_channels = in_channels
274
+
275
+ self.use_timestep = use_timestep
276
+ if self.use_timestep:
277
+ # timestep embedding
278
+ self.temb = nn.Module()
279
+ self.temb.dense = nn.ModuleList([
280
+ torch.nn.Linear(self.ch,
281
+ self.temb_ch),
282
+ torch.nn.Linear(self.temb_ch,
283
+ self.temb_ch),
284
+ ])
285
+
286
+ # downsampling
287
+ self.conv_in = torch.nn.Conv2d(in_channels,
288
+ self.ch,
289
+ kernel_size=3,
290
+ stride=1,
291
+ padding=1)
292
+
293
+ curr_res = resolution
294
+ in_ch_mult = (1,)+tuple(ch_mult)
295
+ self.down = nn.ModuleList()
296
+ for i_level in range(self.num_resolutions):
297
+ block = nn.ModuleList()
298
+ attn = nn.ModuleList()
299
+ block_in = ch*in_ch_mult[i_level]
300
+ block_out = ch*ch_mult[i_level]
301
+ for i_block in range(self.num_res_blocks):
302
+ block.append(ResnetBlock(in_channels=block_in,
303
+ out_channels=block_out,
304
+ temb_channels=self.temb_ch,
305
+ dropout=dropout))
306
+ block_in = block_out
307
+ if curr_res in attn_resolutions:
308
+ attn.append(AttnBlock(block_in))
309
+ down = nn.Module()
310
+ down.block = block
311
+ down.attn = attn
312
+ if i_level != self.num_resolutions-1:
313
+ down.downsample = Downsample(block_in, resamp_with_conv)
314
+ curr_res = curr_res // 2
315
+ self.down.append(down)
316
+
317
+ # middle
318
+ self.mid = nn.Module()
319
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
320
+ out_channels=block_in,
321
+ temb_channels=self.temb_ch,
322
+ dropout=dropout)
323
+ self.mid.attn_1 = AttnBlock(block_in)
324
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
325
+ out_channels=block_in,
326
+ temb_channels=self.temb_ch,
327
+ dropout=dropout)
328
+
329
+ # upsampling
330
+ self.up = nn.ModuleList()
331
+ for i_level in reversed(range(self.num_resolutions)):
332
+ block = nn.ModuleList()
333
+ attn = nn.ModuleList()
334
+ block_out = ch*ch_mult[i_level]
335
+ skip_in = ch*ch_mult[i_level]
336
+ for i_block in range(self.num_res_blocks+1):
337
+ if i_block == self.num_res_blocks:
338
+ skip_in = ch*in_ch_mult[i_level]
339
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
340
+ out_channels=block_out,
341
+ temb_channels=self.temb_ch,
342
+ dropout=dropout))
343
+ block_in = block_out
344
+ if curr_res in attn_resolutions:
345
+ attn.append(AttnBlock(block_in))
346
+ up = nn.Module()
347
+ up.block = block
348
+ up.attn = attn
349
+ if i_level != 0:
350
+ up.upsample = Upsample(block_in, resamp_with_conv)
351
+ curr_res = curr_res * 2
352
+ self.up.insert(0, up) # prepend to get consistent order
353
+
354
+ # end
355
+ self.norm_out = Normalize(block_in)
356
+ self.conv_out = torch.nn.Conv2d(block_in,
357
+ out_ch,
358
+ kernel_size=3,
359
+ stride=1,
360
+ padding=1)
361
+
362
+
363
+ def forward(self, x, t=None):
364
+ #assert x.shape[2] == x.shape[3] == self.resolution
365
+
366
+ if self.use_timestep:
367
+ # timestep embedding
368
+ assert t is not None
369
+ temb = get_timestep_embedding(t, self.ch)
370
+ temb = self.temb.dense[0](temb)
371
+ temb = nonlinearity(temb)
372
+ temb = self.temb.dense[1](temb)
373
+ else:
374
+ temb = None
375
+
376
+ # downsampling
377
+ hs = [self.conv_in(x)]
378
+ for i_level in range(self.num_resolutions):
379
+ for i_block in range(self.num_res_blocks):
380
+ h = self.down[i_level].block[i_block](hs[-1], temb)
381
+ if len(self.down[i_level].attn) > 0:
382
+ h = self.down[i_level].attn[i_block](h)
383
+ hs.append(h)
384
+ if i_level != self.num_resolutions-1:
385
+ hs.append(self.down[i_level].downsample(hs[-1]))
386
+
387
+ # middle
388
+ h = hs[-1]
389
+ h = self.mid.block_1(h, temb)
390
+ h = self.mid.attn_1(h)
391
+ h = self.mid.block_2(h, temb)
392
+
393
+ # upsampling
394
+ for i_level in reversed(range(self.num_resolutions)):
395
+ for i_block in range(self.num_res_blocks+1):
396
+ h = self.up[i_level].block[i_block](
397
+ torch.cat([h, hs.pop()], dim=1), temb)
398
+ if len(self.up[i_level].attn) > 0:
399
+ h = self.up[i_level].attn[i_block](h)
400
+ if i_level != 0:
401
+ h = self.up[i_level].upsample(h)
402
+
403
+ # end
404
+ h = self.norm_out(h)
405
+ h = nonlinearity(h)
406
+ h = self.conv_out(h)
407
+ return h
408
+
409
+
410
+ class Encoder(nn.Module):
411
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
412
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
413
+ resolution, z_channels, double_z=True, **ignore_kwargs):
414
+ super().__init__()
415
+ self.ch = ch
416
+ self.temb_ch = 0
417
+ self.num_resolutions = len(ch_mult)
418
+ self.num_res_blocks = num_res_blocks
419
+ self.resolution = resolution
420
+ self.in_channels = in_channels
421
+
422
+ # downsampling
423
+ self.conv_in = torch.nn.Conv2d(in_channels,
424
+ self.ch,
425
+ kernel_size=3,
426
+ stride=1,
427
+ padding=1)
428
+
429
+ curr_res = resolution
430
+ in_ch_mult = (1,)+tuple(ch_mult)
431
+ self.down = nn.ModuleList()
432
+ for i_level in range(self.num_resolutions):
433
+ block = nn.ModuleList()
434
+ attn = nn.ModuleList()
435
+ block_in = ch*in_ch_mult[i_level]
436
+ block_out = ch*ch_mult[i_level]
437
+ for i_block in range(self.num_res_blocks):
438
+ block.append(ResnetBlock(in_channels=block_in,
439
+ out_channels=block_out,
440
+ temb_channels=self.temb_ch,
441
+ dropout=dropout))
442
+ block_in = block_out
443
+ if curr_res in attn_resolutions:
444
+ attn.append(AttnBlock(block_in))
445
+ down = nn.Module()
446
+ down.block = block
447
+ down.attn = attn
448
+ if i_level != self.num_resolutions-1:
449
+ down.downsample = Downsample(block_in, resamp_with_conv)
450
+ curr_res = curr_res // 2
451
+ self.down.append(down)
452
+
453
+ # middle
454
+ self.mid = nn.Module()
455
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
456
+ out_channels=block_in,
457
+ temb_channels=self.temb_ch,
458
+ dropout=dropout)
459
+ self.mid.attn_1 = AttnBlock(block_in)
460
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
461
+ out_channels=block_in,
462
+ temb_channels=self.temb_ch,
463
+ dropout=dropout)
464
+
465
+ # end
466
+ self.norm_out = Normalize(block_in)
467
+ self.conv_out = torch.nn.Conv2d(block_in,
468
+ 2*z_channels if double_z else z_channels,
469
+ kernel_size=3,
470
+ stride=1,
471
+ padding=1)
472
+
473
+
474
+ def forward(self, x):
475
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
476
+
477
+ # timestep embedding
478
+ temb = None
479
+
480
+ # downsampling
481
+ hs = [self.conv_in(x)]
482
+ for i_level in range(self.num_resolutions):
483
+ for i_block in range(self.num_res_blocks):
484
+ h = self.down[i_level].block[i_block](hs[-1], temb)
485
+ if len(self.down[i_level].attn) > 0:
486
+ h = self.down[i_level].attn[i_block](h)
487
+ hs.append(h)
488
+ if i_level != self.num_resolutions-1:
489
+ hs.append(self.down[i_level].downsample(hs[-1]))
490
+
491
+ # middle
492
+ h = hs[-1]
493
+ h = self.mid.block_1(h, temb)
494
+ h = self.mid.attn_1(h)
495
+ h = self.mid.block_2(h, temb)
496
+
497
+ # end
498
+ h = self.norm_out(h)
499
+ h = nonlinearity(h)
500
+ h = self.conv_out(h)
501
+ return h
502
+
503
+ class Encoder1d(Encoder):
504
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
505
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
506
+ resolution, z_channels, double_z=True, **ignore_kwargs):
507
+ super().__init__(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
508
+ attn_resolutions=attn_resolutions, dropout=dropout,
509
+ resamp_with_conv=resamp_with_conv,
510
+ in_channels=in_channels, resolution=resolution, z_channels=z_channels,
511
+ double_z=double_z, **ignore_kwargs)
512
+ self.ch = ch
513
+ self.temb_ch = 0
514
+ self.num_resolutions = len(ch_mult)
515
+ self.num_res_blocks = num_res_blocks
516
+ self.resolution = resolution
517
+ self.in_channels = in_channels
518
+
519
+ # downsampling
520
+ self.conv_in = torch.nn.Conv1d(in_channels,
521
+ self.ch,
522
+ kernel_size=3,
523
+ stride=1,
524
+ padding=1)
525
+
526
+ curr_res = resolution
527
+ in_ch_mult = (1,)+tuple(ch_mult)
528
+ self.down = nn.ModuleList()
529
+ for i_level in range(self.num_resolutions):
530
+ block = nn.ModuleList()
531
+ attn = nn.ModuleList()
532
+ block_in = ch*in_ch_mult[i_level]
533
+ block_out = ch*ch_mult[i_level]
534
+ for i_block in range(self.num_res_blocks):
535
+ block.append(ResnetBlock1d(in_channels=block_in,
536
+ out_channels=block_out,
537
+ temb_channels=self.temb_ch,
538
+ dropout=dropout))
539
+ block_in = block_out
540
+ if curr_res in attn_resolutions:
541
+ attn.append(AttnBlock1d(block_in))
542
+ down = nn.Module()
543
+ down.block = block
544
+ down.attn = attn
545
+ if i_level != self.num_resolutions-1:
546
+ down.downsample = Downsample1d(block_in, resamp_with_conv)
547
+ curr_res = curr_res // 2
548
+ self.down.append(down)
549
+
550
+ # middle
551
+ self.mid = nn.Module()
552
+ self.mid.block_1 = ResnetBlock1d(in_channels=block_in,
553
+ out_channels=block_in,
554
+ temb_channels=self.temb_ch,
555
+ dropout=dropout)
556
+ self.mid.attn_1 = AttnBlock1d(block_in)
557
+ self.mid.block_2 = ResnetBlock1d(in_channels=block_in,
558
+ out_channels=block_in,
559
+ temb_channels=self.temb_ch,
560
+ dropout=dropout)
561
+
562
+ # end
563
+ self.norm_out = Normalize(block_in)
564
+ self.conv_out = torch.nn.Conv1d(block_in,
565
+ 2*z_channels if double_z else z_channels,
566
+ kernel_size=3,
567
+ stride=1,
568
+ padding=1)
569
+
570
+
571
+ class Decoder(nn.Module):
572
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
573
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
574
+ resolution, z_channels, give_pre_end=False, **ignorekwargs):
575
+ super().__init__()
576
+ self.ch = ch
577
+ self.temb_ch = 0
578
+ self.num_resolutions = len(ch_mult)
579
+ self.num_res_blocks = num_res_blocks
580
+ self.resolution = resolution
581
+ self.in_channels = in_channels
582
+ self.give_pre_end = give_pre_end
583
+
584
+ # compute in_ch_mult, block_in and curr_res at lowest res
585
+ in_ch_mult = (1,)+tuple(ch_mult)
586
+ block_in = ch*ch_mult[self.num_resolutions-1]
587
+ curr_res = resolution // 2**(self.num_resolutions-1)
588
+ # self.z_shape = (1,z_channels,curr_res,curr_res)
589
+ # print("Working with z of shape {} = {} dimensions.".format(
590
+ # self.z_shape, np.prod(self.z_shape)))
591
+
592
+ # z to block_in
593
+ self.conv_in = torch.nn.Conv2d(z_channels,
594
+ block_in,
595
+ kernel_size=3,
596
+ stride=1,
597
+ padding=1)
598
+
599
+ # middle
600
+ self.mid = nn.Module()
601
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
602
+ out_channels=block_in,
603
+ temb_channels=self.temb_ch,
604
+ dropout=dropout)
605
+ self.mid.attn_1 = AttnBlock(block_in)
606
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
607
+ out_channels=block_in,
608
+ temb_channels=self.temb_ch,
609
+ dropout=dropout)
610
+
611
+ # upsampling
612
+ self.up = nn.ModuleList()
613
+ for i_level in reversed(range(self.num_resolutions)):
614
+ block = nn.ModuleList()
615
+ attn = nn.ModuleList()
616
+ block_out = ch*ch_mult[i_level]
617
+ for i_block in range(self.num_res_blocks+1):
618
+ block.append(ResnetBlock(in_channels=block_in,
619
+ out_channels=block_out,
620
+ temb_channels=self.temb_ch,
621
+ dropout=dropout))
622
+ block_in = block_out
623
+ if curr_res in attn_resolutions:
624
+ attn.append(AttnBlock(block_in))
625
+ up = nn.Module()
626
+ up.block = block
627
+ up.attn = attn
628
+ if i_level != 0:
629
+ up.upsample = Upsample(block_in, resamp_with_conv)
630
+ curr_res = curr_res * 2
631
+ self.up.insert(0, up) # prepend to get consistent order
632
+
633
+ # end
634
+ self.norm_out = Normalize(block_in)
635
+ self.conv_out = torch.nn.Conv2d(block_in,
636
+ out_ch,
637
+ kernel_size=3,
638
+ stride=1,
639
+ padding=1)
640
+
641
+ def forward(self, z):
642
+ #assert z.shape[1:] == self.z_shape[1:]
643
+ self.last_z_shape = z.shape
644
+
645
+ # timestep embedding
646
+ temb = None
647
+
648
+ # z to block_in
649
+ h = self.conv_in(z)
650
+
651
+ # middle
652
+ h = self.mid.block_1(h, temb)
653
+ h = self.mid.attn_1(h)
654
+ h = self.mid.block_2(h, temb)
655
+
656
+ # upsampling
657
+ for i_level in reversed(range(self.num_resolutions)):
658
+ for i_block in range(self.num_res_blocks+1):
659
+ h = self.up[i_level].block[i_block](h, temb)
660
+ if len(self.up[i_level].attn) > 0:
661
+ h = self.up[i_level].attn[i_block](h)
662
+ if i_level != 0:
663
+ h = self.up[i_level].upsample(h)
664
+
665
+ # end
666
+ if self.give_pre_end:
667
+ return h
668
+
669
+ h = self.norm_out(h)
670
+ h = nonlinearity(h)
671
+ h = self.conv_out(h)
672
+ return h
673
+
674
+ class Decoder1d(Decoder):
675
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
676
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
677
+ resolution, z_channels, give_pre_end=False, **ignorekwargs):
678
+ super().__init__(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
679
+ attn_resolutions=attn_resolutions, dropout=dropout,
680
+ resamp_with_conv=resamp_with_conv,
681
+ in_channels=in_channels, resolution=resolution, z_channels=z_channels,
682
+ give_pre_end=give_pre_end, **ignorekwargs)
683
+ self.ch = ch
684
+ self.temb_ch = 0
685
+ self.num_resolutions = len(ch_mult)
686
+ self.num_res_blocks = num_res_blocks
687
+ self.resolution = resolution
688
+ self.in_channels = in_channels
689
+ self.give_pre_end = give_pre_end
690
+
691
+ # compute in_ch_mult, block_in and curr_res at lowest res
692
+ in_ch_mult = (1,) + tuple(ch_mult)
693
+ block_in = ch * ch_mult[self.num_resolutions-1]
694
+ curr_res = resolution // 2**(self.num_resolutions-1)
695
+ # self.z_shape = (1,z_channels,curr_res,curr_res)
696
+ # print("Working with z of shape {} = {} dimensions.".format(
697
+ # self.z_shape, np.prod(self.z_shape)))
698
+
699
+ # z to block_in
700
+ self.conv_in = torch.nn.Conv1d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
701
+
702
+ # middle
703
+ self.mid = nn.Module()
704
+ self.mid.block_1 = ResnetBlock1d(in_channels=block_in, out_channels=block_in,
705
+ temb_channels=self.temb_ch, dropout=dropout)
706
+ self.mid.attn_1 = AttnBlock1d(block_in)
707
+ self.mid.block_2 = ResnetBlock1d(in_channels=block_in, out_channels=block_in,
708
+ temb_channels=self.temb_ch, dropout=dropout)
709
+
710
+ # upsampling
711
+ self.up = nn.ModuleList()
712
+ for i_level in reversed(range(self.num_resolutions)):
713
+ block = nn.ModuleList()
714
+ attn = nn.ModuleList()
715
+ block_out = ch * ch_mult[i_level]
716
+ for i_block in range(self.num_res_blocks+1):
717
+ block.append(ResnetBlock1d(in_channels=block_in, out_channels=block_out,
718
+ temb_channels=self.temb_ch, dropout=dropout))
719
+ block_in = block_out
720
+ if curr_res in attn_resolutions:
721
+ attn.append(AttnBlock1d(block_in))
722
+ up = nn.Module()
723
+ up.block = block
724
+ up.attn = attn
725
+ if i_level != 0:
726
+ up.upsample = Upsample1d(block_in, resamp_with_conv)
727
+ curr_res = curr_res * 2
728
+ self.up.insert(0, up) # prepend to get consistent order
729
+
730
+ # end
731
+ self.norm_out = Normalize(block_in)
732
+ self.conv_out = torch.nn.Conv1d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
733
+
734
+
735
+ class VUNet(nn.Module):
736
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
737
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
738
+ in_channels, c_channels,
739
+ resolution, z_channels, use_timestep=False, **ignore_kwargs):
740
+ super().__init__()
741
+ self.ch = ch
742
+ self.temb_ch = self.ch*4
743
+ self.num_resolutions = len(ch_mult)
744
+ self.num_res_blocks = num_res_blocks
745
+ self.resolution = resolution
746
+
747
+ self.use_timestep = use_timestep
748
+ if self.use_timestep:
749
+ # timestep embedding
750
+ self.temb = nn.Module()
751
+ self.temb.dense = nn.ModuleList([
752
+ torch.nn.Linear(self.ch,
753
+ self.temb_ch),
754
+ torch.nn.Linear(self.temb_ch,
755
+ self.temb_ch),
756
+ ])
757
+
758
+ # downsampling
759
+ self.conv_in = torch.nn.Conv2d(c_channels,
760
+ self.ch,
761
+ kernel_size=3,
762
+ stride=1,
763
+ padding=1)
764
+
765
+ curr_res = resolution
766
+ in_ch_mult = (1,)+tuple(ch_mult)
767
+ self.down = nn.ModuleList()
768
+ for i_level in range(self.num_resolutions):
769
+ block = nn.ModuleList()
770
+ attn = nn.ModuleList()
771
+ block_in = ch*in_ch_mult[i_level]
772
+ block_out = ch*ch_mult[i_level]
773
+ for i_block in range(self.num_res_blocks):
774
+ block.append(ResnetBlock(in_channels=block_in,
775
+ out_channels=block_out,
776
+ temb_channels=self.temb_ch,
777
+ dropout=dropout))
778
+ block_in = block_out
779
+ if curr_res in attn_resolutions:
780
+ attn.append(AttnBlock(block_in))
781
+ down = nn.Module()
782
+ down.block = block
783
+ down.attn = attn
784
+ if i_level != self.num_resolutions-1:
785
+ down.downsample = Downsample(block_in, resamp_with_conv)
786
+ curr_res = curr_res // 2
787
+ self.down.append(down)
788
+
789
+ self.z_in = torch.nn.Conv2d(z_channels,
790
+ block_in,
791
+ kernel_size=1,
792
+ stride=1,
793
+ padding=0)
794
+ # middle
795
+ self.mid = nn.Module()
796
+ self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
797
+ out_channels=block_in,
798
+ temb_channels=self.temb_ch,
799
+ dropout=dropout)
800
+ self.mid.attn_1 = AttnBlock(block_in)
801
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
802
+ out_channels=block_in,
803
+ temb_channels=self.temb_ch,
804
+ dropout=dropout)
805
+
806
+ # upsampling
807
+ self.up = nn.ModuleList()
808
+ for i_level in reversed(range(self.num_resolutions)):
809
+ block = nn.ModuleList()
810
+ attn = nn.ModuleList()
811
+ block_out = ch*ch_mult[i_level]
812
+ skip_in = ch*ch_mult[i_level]
813
+ for i_block in range(self.num_res_blocks+1):
814
+ if i_block == self.num_res_blocks:
815
+ skip_in = ch*in_ch_mult[i_level]
816
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
817
+ out_channels=block_out,
818
+ temb_channels=self.temb_ch,
819
+ dropout=dropout))
820
+ block_in = block_out
821
+ if curr_res in attn_resolutions:
822
+ attn.append(AttnBlock(block_in))
823
+ up = nn.Module()
824
+ up.block = block
825
+ up.attn = attn
826
+ if i_level != 0:
827
+ up.upsample = Upsample(block_in, resamp_with_conv)
828
+ curr_res = curr_res * 2
829
+ self.up.insert(0, up) # prepend to get consistent order
830
+
831
+ # end
832
+ self.norm_out = Normalize(block_in)
833
+ self.conv_out = torch.nn.Conv2d(block_in,
834
+ out_ch,
835
+ kernel_size=3,
836
+ stride=1,
837
+ padding=1)
838
+
839
+
840
+ def forward(self, x, z, t=None):  # t is only needed when use_timestep is True
841
+ #assert x.shape[2] == x.shape[3] == self.resolution
842
+
843
+ if self.use_timestep:
844
+ # timestep embedding
845
+ assert t is not None
846
+ temb = get_timestep_embedding(t, self.ch)
847
+ temb = self.temb.dense[0](temb)
848
+ temb = nonlinearity(temb)
849
+ temb = self.temb.dense[1](temb)
850
+ else:
851
+ temb = None
852
+
853
+ # downsampling
854
+ hs = [self.conv_in(x)]
855
+ for i_level in range(self.num_resolutions):
856
+ for i_block in range(self.num_res_blocks):
857
+ h = self.down[i_level].block[i_block](hs[-1], temb)
858
+ if len(self.down[i_level].attn) > 0:
859
+ h = self.down[i_level].attn[i_block](h)
860
+ hs.append(h)
861
+ if i_level != self.num_resolutions-1:
862
+ hs.append(self.down[i_level].downsample(hs[-1]))
863
+
864
+ # middle
865
+ h = hs[-1]
866
+ z = self.z_in(z)
867
+ h = torch.cat((h,z),dim=1)
868
+ h = self.mid.block_1(h, temb)
869
+ h = self.mid.attn_1(h)
870
+ h = self.mid.block_2(h, temb)
871
+
872
+ # upsampling
873
+ for i_level in reversed(range(self.num_resolutions)):
874
+ for i_block in range(self.num_res_blocks+1):
875
+ h = self.up[i_level].block[i_block](
876
+ torch.cat([h, hs.pop()], dim=1), temb)
877
+ if len(self.up[i_level].attn) > 0:
878
+ h = self.up[i_level].attn[i_block](h)
879
+ if i_level != 0:
880
+ h = self.up[i_level].upsample(h)
881
+
882
+ # end
883
+ h = self.norm_out(h)
884
+ h = nonlinearity(h)
885
+ h = self.conv_out(h)
886
+ return h
887
+
888
+
889
+ class SimpleDecoder(nn.Module):
890
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
891
+ super().__init__()
892
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
893
+ ResnetBlock(in_channels=in_channels,
894
+ out_channels=2 * in_channels,
895
+ temb_channels=0, dropout=0.0),
896
+ ResnetBlock(in_channels=2 * in_channels,
897
+ out_channels=4 * in_channels,
898
+ temb_channels=0, dropout=0.0),
899
+ ResnetBlock(in_channels=4 * in_channels,
900
+ out_channels=2 * in_channels,
901
+ temb_channels=0, dropout=0.0),
902
+ nn.Conv2d(2*in_channels, in_channels, 1),
903
+ Upsample(in_channels, with_conv=True)])
904
+ # end
905
+ self.norm_out = Normalize(in_channels)
906
+ self.conv_out = torch.nn.Conv2d(in_channels,
907
+ out_channels,
908
+ kernel_size=3,
909
+ stride=1,
910
+ padding=1)
911
+
912
+ def forward(self, x):
913
+ for i, layer in enumerate(self.model):
914
+ if i in [1,2,3]:
915
+ x = layer(x, None)
916
+ else:
917
+ x = layer(x)
918
+
919
+ h = self.norm_out(x)
920
+ h = nonlinearity(h)
921
+ x = self.conv_out(h)
922
+ return x
923
+
924
+
925
+ class UpsampleDecoder(nn.Module):
926
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
927
+ ch_mult=(2,2), dropout=0.0):
928
+ super().__init__()
929
+ # upsampling
930
+ self.temb_ch = 0
931
+ self.num_resolutions = len(ch_mult)
932
+ self.num_res_blocks = num_res_blocks
933
+ block_in = in_channels
934
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
935
+ self.res_blocks = nn.ModuleList()
936
+ self.upsample_blocks = nn.ModuleList()
937
+ for i_level in range(self.num_resolutions):
938
+ res_block = []
939
+ block_out = ch * ch_mult[i_level]
940
+ for i_block in range(self.num_res_blocks + 1):
941
+ res_block.append(ResnetBlock(in_channels=block_in,
942
+ out_channels=block_out,
943
+ temb_channels=self.temb_ch,
944
+ dropout=dropout))
945
+ block_in = block_out
946
+ self.res_blocks.append(nn.ModuleList(res_block))
947
+ if i_level != self.num_resolutions - 1:
948
+ self.upsample_blocks.append(Upsample(block_in, True))
949
+ curr_res = curr_res * 2
950
+
951
+ # end
952
+ self.norm_out = Normalize(block_in)
953
+ self.conv_out = torch.nn.Conv2d(block_in,
954
+ out_channels,
955
+ kernel_size=3,
956
+ stride=1,
957
+ padding=1)
958
+
959
+ def forward(self, x):
960
+ # upsampling
961
+ h = x
962
+ for k, i_level in enumerate(range(self.num_resolutions)):
963
+ for i_block in range(self.num_res_blocks + 1):
964
+ h = self.res_blocks[i_level][i_block](h, None)
965
+ if i_level != self.num_resolutions - 1:
966
+ h = self.upsample_blocks[k](h)
967
+ h = self.norm_out(h)
968
+ h = nonlinearity(h)
969
+ h = self.conv_out(h)
970
+ return h
971
+
972
+
973
+ if __name__ == '__main__':
974
+ ddconfig = {
975
+ 'ch': 128,
976
+ 'num_res_blocks': 2,
977
+ 'dropout': 0.0,
978
+ 'z_channels': 256,
979
+ 'double_z': False,
980
+ }
981
+
982
+ # Audio example ##
983
+ ddconfig['in_channels'] = 1
984
+ ddconfig['resolution'] = 848
985
+ ddconfig['attn_resolutions'] = [53]
986
+ ddconfig['ch_mult'] = [1, 1, 2, 2, 4]
987
+ ddconfig['out_ch'] = 1
988
+ # input
989
+ inputs = torch.rand(4, 1, 80, 848)
990
+ print('Input:', inputs.shape)
991
+ # Encoder
992
+ encoder = Encoder(**ddconfig)
993
+ enc_outs = encoder(inputs)
994
+ print('Encoder out:', enc_outs.shape)
995
+ # Decoder
996
+ decoder = Decoder(**ddconfig)
997
+ quant_outs = torch.rand(4, 256, 5, 53)
998
+ dec_outs = decoder(quant_outs)
999
+ print('Decoder out:', dec_outs.shape)
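For reference, the random quant_outs fed to the Decoder above stand in for the vector-quantized encoder output used in the full SpecVQGAN pipeline. Below is a minimal, illustrative nearest-codebook lookup; NaiveQuantizer and its sizes are assumptions and do not reproduce this repository's actual VectorQuantizer:

import torch
import torch.nn as nn

class NaiveQuantizer(nn.Module):
    # toy codebook lookup: map each spatial feature vector to its nearest code
    def __init__(self, n_codes=1024, code_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(n_codes, code_dim)

    def forward(self, z):
        b, c, h, w = z.shape
        flat = z.permute(0, 2, 3, 1).reshape(-1, c)       # (B*H*W, C)
        dists = torch.cdist(flat, self.embedding.weight)  # distance to every code
        idx = dists.argmin(dim=1)                         # nearest code per vector
        z_q = self.embedding(idx).view(b, h, w, c).permute(0, 3, 1, 2)
        return z_q, idx.view(b, h, w)

quantizer = NaiveQuantizer()
z_q, codes = quantizer(torch.rand(4, 256, 5, 53))
print(z_q.shape, codes.shape)  # (4, 256, 5, 53) and (4, 5, 53)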
foleycrafter/models/specvqgan/modules/discriminator/model.py ADDED
@@ -0,0 +1,295 @@
1
+ import functools
2
+ import torch.nn as nn
3
+
4
+
5
+ class ActNorm(nn.Module):
6
+ def __init__(self, num_features, logdet=False, affine=True,
7
+ allow_reverse_init=False):
8
+ assert affine
9
+ super().__init__()
10
+ self.logdet = logdet
11
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
12
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
13
+ self.allow_reverse_init = allow_reverse_init
14
+
15
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
16
+
17
+ def initialize(self, input):
18
+ with torch.no_grad():
19
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
20
+ mean = (
21
+ flatten.mean(1)
22
+ .unsqueeze(1)
23
+ .unsqueeze(2)
24
+ .unsqueeze(3)
25
+ .permute(1, 0, 2, 3)
26
+ )
27
+ std = (
28
+ flatten.std(1)
29
+ .unsqueeze(1)
30
+ .unsqueeze(2)
31
+ .unsqueeze(3)
32
+ .permute(1, 0, 2, 3)
33
+ )
34
+
35
+ self.loc.data.copy_(-mean)
36
+ self.scale.data.copy_(1 / (std + 1e-6))
37
+
38
+ def forward(self, input, reverse=False):
39
+ if reverse:
40
+ return self.reverse(input)
41
+ if len(input.shape) == 2:
42
+ input = input[:, :, None, None]
43
+ squeeze = True
44
+ else:
45
+ squeeze = False
46
+
47
+ _, _, height, width = input.shape
48
+
49
+ if self.training and self.initialized.item() == 0:
50
+ self.initialize(input)
51
+ self.initialized.fill_(1)
52
+
53
+ h = self.scale * (input + self.loc)
54
+
55
+ if squeeze:
56
+ h = h.squeeze(-1).squeeze(-1)
57
+
58
+ if self.logdet:
59
+ log_abs = torch.log(torch.abs(self.scale))
60
+ logdet = height * width * torch.sum(log_abs)
61
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
62
+ return h, logdet
63
+
64
+ return h
65
+
66
+ def reverse(self, output):
67
+ if self.training and self.initialized.item() == 0:
68
+ if not self.allow_reverse_init:
69
+ raise RuntimeError(
70
+ "Initializing ActNorm in reverse direction is "
71
+ "disabled by default. Use allow_reverse_init=True to enable."
72
+ )
73
+ else:
74
+ self.initialize(output)
75
+ self.initialized.fill_(1)
76
+
77
+ if len(output.shape) == 2:
78
+ output = output[:, :, None, None]
79
+ squeeze = True
80
+ else:
81
+ squeeze = False
82
+
83
+ h = output / self.scale - self.loc
84
+
85
+ if squeeze:
86
+ h = h.squeeze(-1).squeeze(-1)
87
+ return h
88
+
89
+ def weights_init(m):
90
+ classname = m.__class__.__name__
91
+ if classname.find('Conv') != -1:
92
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
93
+ elif classname.find('BatchNorm') != -1:
94
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
95
+ nn.init.constant_(m.bias.data, 0)
96
+
97
+
98
+ class NLayerDiscriminator(nn.Module):
99
+ """Defines a PatchGAN discriminator as in Pix2Pix
100
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
101
+ """
102
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
103
+ """Construct a PatchGAN discriminator
104
+ Parameters:
105
+ input_nc (int) -- the number of channels in input images
106
+ ndf (int) -- the number of filters in the last conv layer
107
+ n_layers (int) -- the number of conv layers in the discriminator
108
+ norm_layer -- normalization layer
109
+ """
110
+ super(NLayerDiscriminator, self).__init__()
111
+ if not use_actnorm:
112
+ norm_layer = nn.BatchNorm2d
113
+ else:
114
+ norm_layer = ActNorm
115
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
116
+ use_bias = norm_layer.func != nn.BatchNorm2d
117
+ else:
118
+ use_bias = norm_layer != nn.BatchNorm2d
119
+
120
+ kw = 4
121
+ padw = 1
122
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
123
+ nf_mult = 1
124
+ nf_mult_prev = 1
125
+ for n in range(1, n_layers): # gradually increase the number of filters
126
+ nf_mult_prev = nf_mult
127
+ nf_mult = min(2 ** n, 8)
128
+ sequence += [
129
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
130
+ norm_layer(ndf * nf_mult),
131
+ nn.LeakyReLU(0.2, True)
132
+ ]
133
+
134
+ nf_mult_prev = nf_mult
135
+ nf_mult = min(2 ** n_layers, 8)
136
+ sequence += [
137
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
138
+ norm_layer(ndf * nf_mult),
139
+ nn.LeakyReLU(0.2, True)
140
+ ]
141
+ # output 1 channel prediction map
142
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
143
+ self.main = nn.Sequential(*sequence)
144
+
145
+ def forward(self, input):
146
+ """Standard forward."""
147
+ return self.main(input)
148
+
149
+ class NLayerDiscriminator1dFeats(NLayerDiscriminator):
150
+ """Defines a PatchGAN discriminator as in Pix2Pix
151
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
152
+ """
153
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
154
+ """Construct a PatchGAN discriminator
155
+ Parameters:
156
+ input_nc (int) -- the number of channels in input feats
157
+ ndf (int) -- the number of filters in the last conv layer
158
+ n_layers (int) -- the number of conv layers in the discriminator
159
+ norm_layer -- normalization layer
160
+ """
161
+ super().__init__(input_nc=input_nc, ndf=64, n_layers=n_layers, use_actnorm=use_actnorm)
162
+
163
+ if not use_actnorm:
164
+ norm_layer = nn.BatchNorm1d
165
+ else:
166
+ norm_layer = ActNorm
167
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm has affine parameters
168
+ use_bias = norm_layer.func != nn.BatchNorm1d
169
+ else:
170
+ use_bias = norm_layer != nn.BatchNorm1d
171
+
172
+ kw = 4
173
+ padw = 1
174
+ sequence = [nn.Conv1d(input_nc, input_nc//2, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
175
+ nf_mult = input_nc//2
176
+ nf_mult_prev = 1
177
+ for n in range(1, n_layers): # gradually decrease the number of filters
178
+ nf_mult_prev = nf_mult
179
+ nf_mult = max(nf_mult_prev // (2 ** n), 8)
180
+ sequence += [
181
+ nn.Conv1d(nf_mult_prev, nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
182
+ norm_layer(nf_mult),
183
+ nn.LeakyReLU(0.2, True)
184
+ ]
185
+
186
+ nf_mult_prev = nf_mult
187
+ nf_mult = max(nf_mult_prev // (2 ** n), 8)
188
+ sequence += [
189
+ nn.Conv1d(nf_mult_prev, nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
190
+ norm_layer(nf_mult),
191
+ nn.LeakyReLU(0.2, True)
192
+ ]
193
+ nf_mult_prev = nf_mult
194
+ nf_mult = max(nf_mult_prev // (2 ** n), 8)
195
+ sequence += [
196
+ nn.Conv1d(nf_mult_prev, nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
197
+ norm_layer(nf_mult),
198
+ nn.LeakyReLU(0.2, True)
199
+ ]
200
+ # output 1 channel prediction map
201
+ sequence += [nn.Conv1d(nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
202
+ self.main = nn.Sequential(*sequence)
203
+
204
+
205
+ class NLayerDiscriminator1dSpecs(NLayerDiscriminator):
206
+ """Defines a PatchGAN discriminator as in Pix2Pix
207
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
208
+ """
209
+ def __init__(self, input_nc=80, ndf=64, n_layers=3, use_actnorm=False):
210
+ """Construct a PatchGAN discriminator
211
+ Parameters:
212
+ input_nc (int) -- the number of channels in input specs
213
+ ndf (int) -- the number of filters in the last conv layer
214
+ n_layers (int) -- the number of conv layers in the discriminator
215
+ norm_layer -- normalization layer
216
+ """
217
+ super().__init__(input_nc=input_nc, ndf=64, n_layers=n_layers, use_actnorm=use_actnorm)
218
+
219
+ if not use_actnorm:
220
+ norm_layer = nn.BatchNorm1d
221
+ else:
222
+ norm_layer = ActNorm
223
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm has affine parameters
224
+ use_bias = norm_layer.func != nn.BatchNorm1d
225
+ else:
226
+ use_bias = norm_layer != nn.BatchNorm1d
227
+
228
+ kw = 4
229
+ padw = 1
230
+ sequence = [nn.Conv1d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
231
+ nf_mult = 1
232
+ nf_mult_prev = 1
233
+ for n in range(1, n_layers): # gradually decrease the number of filters
234
+ nf_mult_prev = nf_mult
235
+ nf_mult = min(2 ** n, 8)
236
+ sequence += [
237
+ nn.Conv1d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
238
+ norm_layer(ndf * nf_mult),
239
+ nn.LeakyReLU(0.2, True)
240
+ ]
241
+
242
+ nf_mult_prev = nf_mult
243
+ nf_mult = min(2 ** n_layers, 8)
244
+ sequence += [
245
+ nn.Conv1d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
246
+ norm_layer(ndf * nf_mult),
247
+ nn.LeakyReLU(0.2, True)
248
+ ]
249
+ # output 1 channel prediction map
250
+ sequence += [nn.Conv1d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
251
+ self.main = nn.Sequential(*sequence)
252
+
253
+ def forward(self, input):
254
+ """Standard forward."""
255
+ # input arrives as (B, 1, F, T); squeeze to (B, F, T) so the mel bins act as Conv1d channels
256
+ input = input.squeeze(1)
257
+ input = self.main(input)
258
+ return input
259
+
260
+
261
+ if __name__ == '__main__':
262
+ import torch
263
+
264
+ ## FEATURES
265
+ disc_in_channels = 2048
266
+ disc_num_layers = 2
267
+ use_actnorm = False
268
+ disc_ndf = 64
269
+ discriminator = NLayerDiscriminator1dFeats(input_nc=disc_in_channels, n_layers=disc_num_layers,
270
+ use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init)
271
+ inputs = torch.rand((6, 2048, 212))
272
+ outputs = discriminator(inputs)
273
+ print(outputs.shape)
274
+
275
+ ## AUDIO
276
+ disc_in_channels = 1
277
+ disc_num_layers = 3
278
+ use_actnorm = False
279
+ disc_ndf = 64
280
+ discriminator = NLayerDiscriminator(input_nc=disc_in_channels, n_layers=disc_num_layers,
281
+ use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init)
282
+ inputs = torch.rand((6, 1, 80, 848))
283
+ outputs = discriminator(inputs)
284
+ print(outputs.shape)
285
+
286
+ ## IMAGE
287
+ disc_in_channels = 3
288
+ disc_num_layers = 3
289
+ use_actnorm = False
290
+ disc_ndf = 64
291
+ discriminator = NLayerDiscriminator(input_nc=disc_in_channels, n_layers=disc_num_layers,
292
+ use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init)
293
+ inputs = torch.rand((6, 3, 256, 256))
294
+ outputs = discriminator(inputs)
295
+ print(outputs.shape)
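The patch logits returned by these discriminators are consumed by an adversarial loss defined elsewhere (the repo's losses/vqperceptual.py). As a hedged sketch of one common formulation, the hinge loss below shows how such prediction maps are typically turned into discriminator and generator objectives; it is not necessarily the exact loss used here:

import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real, logits_fake):
    # penalize real patches scored below +1 and fake patches scored above -1
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)

def hinge_g_loss(logits_fake):
    # the generator tries to push fake patch scores up
    return -torch.mean(logits_fake)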
foleycrafter/models/specvqgan/modules/losses/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from foleycrafter.models.specvqgan.modules.losses.vqperceptual import DummyLoss
2
+
3
+ # work around relative-import issues by putting the vggishish package on sys.path
4
+ import os
5
+ import sys
6
+ path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'vggishish')
7
+ sys.path.append(path)
foleycrafter/models/specvqgan/modules/losses/lpaps.py ADDED
@@ -0,0 +1,152 @@
1
+ """
2
+ Based on https://github.com/CompVis/taming-transformers/blob/52720829/taming/modules/losses/lpips.py
3
+ Adapted for spectrograms by Vladimir Iashin (v-iashin)
4
+ """
5
+ from collections import namedtuple
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ import sys
12
+ sys.path.insert(0, '.') # nopep8
13
+ from foleycrafter.models.specvqgan.modules.losses.vggishish.model import VGGishish
14
+ from foleycrafter.models.specvqgan.util import get_ckpt_path
15
+
16
+
17
+ class LPAPS(nn.Module):
18
+ # Learned perceptual metric
19
+ def __init__(self, use_dropout=True):
20
+ super().__init__()
21
+ self.scaling_layer = ScalingLayer()
22
+ self.chns = [64, 128, 256, 512, 512] # vggish16 features
23
+ self.net = vggishish16(pretrained=True, requires_grad=False)
24
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
25
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
26
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
27
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
28
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
29
+ self.load_from_pretrained()
30
+ for param in self.parameters():
31
+ param.requires_grad = False
32
+
33
+ def load_from_pretrained(self, name="vggishish_lpaps"):
34
+ ckpt = get_ckpt_path(name, "specvqgan/modules/autoencoder/lpaps")
35
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
36
+ print("loaded pretrained LPAPS loss from {}".format(ckpt))
37
+
38
+ @classmethod
39
+ def from_pretrained(cls, name="vggishish_lpaps"):
40
+ if name != "vggishish_lpaps":
41
+ raise NotImplementedError
42
+ model = cls()
43
+ ckpt = get_ckpt_path(name)
44
+ model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
45
+ return model
46
+
47
+ def forward(self, input, target):
48
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
49
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
50
+ feats0, feats1, diffs = {}, {}, {}
51
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
52
+ for kk in range(len(self.chns)):
53
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
54
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
55
+
56
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
57
+ val = res[0]
58
+ for l in range(1, len(self.chns)):
59
+ val += res[l]
60
+ return val
61
+
62
+ class ScalingLayer(nn.Module):
63
+ def __init__(self):
64
+ super(ScalingLayer, self).__init__()
65
+ # we also use get_ckpt_path to download the normalization stats
66
+ stat_path = get_ckpt_path('vggishish_mean_std_melspec_10s_22050hz', 'specvqgan/modules/autoencoder/lpaps')
67
+ # for images normalization is done over the channel dim; for spectrograms we normalize over the frequency dim
68
+ means, stds = np.loadtxt(stat_path, dtype=np.float32).T
69
+ # the means and stds are computed for inputs in [0, 1], but specvqgan expects [-1, 1]:
70
+ means = 2 * means - 1
71
+ stds = 2 * stds
72
+ # input is expected to be (B, 1, F, T)
73
+ self.register_buffer('shift', torch.from_numpy(means)[None, None, :, None])
74
+ self.register_buffer('scale', torch.from_numpy(stds)[None, None, :, None])
75
+
76
+ def forward(self, inp):
77
+ return (inp - self.shift) / self.scale
78
+
79
+
80
+ class NetLinLayer(nn.Module):
81
+ """ A single linear layer which does a 1x1 conv """
82
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
83
+ super(NetLinLayer, self).__init__()
84
+ layers = [nn.Dropout(), ] if (use_dropout) else []
85
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
86
+ self.model = nn.Sequential(*layers)
87
+
88
+ class vggishish16(torch.nn.Module):
89
+ def __init__(self, requires_grad=False, pretrained=True):
90
+ super().__init__()
91
+ vgg_pretrained_features = self.vggishish16(pretrained=pretrained).features
92
+ self.slice1 = torch.nn.Sequential()
93
+ self.slice2 = torch.nn.Sequential()
94
+ self.slice3 = torch.nn.Sequential()
95
+ self.slice4 = torch.nn.Sequential()
96
+ self.slice5 = torch.nn.Sequential()
97
+ self.N_slices = 5
98
+ for x in range(4):
99
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
100
+ for x in range(4, 9):
101
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
102
+ for x in range(9, 16):
103
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
104
+ for x in range(16, 23):
105
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
106
+ for x in range(23, 30):
107
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
108
+ if not requires_grad:
109
+ for param in self.parameters():
110
+ param.requires_grad = False
111
+
112
+ def forward(self, X):
113
+ h = self.slice1(X)
114
+ h_relu1_2 = h
115
+ h = self.slice2(h)
116
+ h_relu2_2 = h
117
+ h = self.slice3(h)
118
+ h_relu3_3 = h
119
+ h = self.slice4(h)
120
+ h_relu4_3 = h
121
+ h = self.slice5(h)
122
+ h_relu5_3 = h
123
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
124
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
125
+ return out
126
+
127
+ def vggishish16(self, pretrained: bool = True) -> VGGishish:
128
+ # loading vggishish pretrained on vggsound
129
+ num_classes_vggsound = 309
130
+ conv_layers = [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512]
131
+ model = VGGishish(conv_layers, use_bn=False, num_classes=num_classes_vggsound)
132
+ if pretrained:
133
+ ckpt_path = get_ckpt_path('vggishish_lpaps', "specvqgan/modules/autoencoder/lpaps")
134
+ ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))
135
+ model.load_state_dict(ckpt, strict=False)
136
+ return model
137
+
138
+ def normalize_tensor(x, eps=1e-10):
139
+ norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
140
+ return x / (norm_factor+eps)
141
+
142
+ def spatial_average(x, keepdim=True):
143
+ return x.mean([2, 3], keepdim=keepdim)
144
+
145
+
146
+ if __name__ == '__main__':
147
+ inputs = torch.rand((16, 1, 80, 848))
148
+ reconstructions = torch.rand((16, 1, 80, 848))
149
+ lpips = LPAPS().eval()
150
+ loss_p = lpips(inputs.contiguous(), reconstructions.contiguous())
151
+ # (16, 1, 1, 1)
152
+ print(loss_p.shape)
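LPAPS returns one perceptual distance per sample (shape (B, 1, 1, 1)), so it combines naturally with a pixel-wise reconstruction term. The mixing below is only an illustrative sketch; perceptual_weight and the exact combination are assumptions, not values taken from this repository:

import torch

def reconstruction_loss(x, x_rec, lpaps, perceptual_weight=1.0):
    # L1 term per element plus a broadcast per-sample perceptual term
    rec = torch.abs(x.contiguous() - x_rec.contiguous())
    p = lpaps(x.contiguous(), x_rec.contiguous())  # (B, 1, 1, 1)
    return (rec + perceptual_weight * p).mean()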
foleycrafter/models/specvqgan/modules/losses/vggishish/configs/melception.yaml ADDED
@@ -0,0 +1,24 @@
1
+ seed: 1337
2
+ log_code_state: True
3
+ # patterns to ignore when backing up the code folder
4
+ patterns_to_ignore: ['logs', '.git', '__pycache__', 'data', 'checkpoints', '*.pt']
5
+
6
+ # data:
7
+ mels_path: '/home/nvme/data/vggsound/features/melspec_10s_22050hz/'
8
+ spec_shape: [80, 860]
9
+ cropped_size: [80, 848]
10
+ random_crop: False
11
+
12
+ # train:
13
+ device: 'cuda:0'
14
+ batch_size: 8
15
+ num_workers: 0
16
+ optimizer: adam
17
+ betas: [0.9, 0.999]
18
+ momentum: 0.9
19
+ learning_rate: 3e-4
20
+ weight_decay: 0
21
+ num_epochs: 100
22
+ patience: 3
23
+ logdir: './logs'
24
+ cls_weights_in_loss: False
foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish.yaml ADDED
@@ -0,0 +1,34 @@
1
+ seed: 1337
2
+ log_code_state: True
3
+ # patterns to ignore when backing up the code folder
4
+ patterns_to_ignore: ['logs', '.git', '__pycache__']
5
+
6
+ # data:
7
+ mels_path: '/home/nvme/data/vggsound/features/melspec_10s_22050hz/'
8
+ spec_shape: [80, 860]
9
+ cropped_size: [80, 848]
10
+ random_crop: False
11
+
12
+ # model:
13
+ # original vgg family except for MP is missing at the end
14
+ # 'vggish': [64, 'MP', 128, 'MP', 256, 256, 'MP', 512, 512]
15
+ # 'vgg11': [64, 'MP', 128, 'MP', 256, 256, 'MP', 512, 512, 'MP', 512, 512],
16
+ # 'vgg13': [64, 64, 'MP', 128, 128, 'MP', 256, 256, 'MP', 512, 512, 'MP', 512, 512],
17
+ # 'vgg16': [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512],
18
+ # 'vgg19': [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 256, 'MP', 512, 512, 512, 512, 'MP', 512, 512, 512, 512],
19
+ conv_layers: [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512]
20
+ use_bn: False
21
+
22
+ # train:
23
+ device: 'cuda:0'
24
+ batch_size: 32
25
+ num_workers: 0
26
+ optimizer: adam
27
+ betas: [0.9, 0.999]
28
+ momentum: 0.9
29
+ learning_rate: 3e-4
30
+ weight_decay: 0.0001
31
+ num_epochs: 100
32
+ patience: 3
33
+ logdir: './logs'
34
+ cls_weights_in_loss: False
foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish_gh.yaml ADDED
@@ -0,0 +1,25 @@
1
+ seed: 1337
2
+ log_code_state: True
3
+ patterns_to_ignore: ['logs', '.git', '__pycache__']
4
+
5
+ mels_path: '/home/duyxxd/SpecVQGAN/data/greatesthit/melspec_10s_22050hz'
6
+ batch_size: 32
7
+ num_workers: 8
8
+ device: 'cuda:0'
9
+ conv_layers: [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512]
10
+ use_bn: False
11
+ optimizer: adam
12
+ learning_rate: 1e-4
13
+ betas: [0.9, 0.999]
14
+ cropped_size: [80, 160]
15
+ momentum: 0.9
16
+ weight_decay: 1e-4
17
+ cls_weights_in_loss: False
18
+ num_epochs: 100
19
+ patience: 20
20
+ logdir: '/home/duyxxd/SpecVQGAN/logs'
21
+ exp_name: 'mix'
22
+ action_only: False
23
+ material_only: False
24
+
25
+ load_model: /home/duyxxd/SpecVQGAN/logs/vggishish16.pt
foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish_gh_action.yaml ADDED
@@ -0,0 +1,25 @@
1
+ seed: 1337
2
+ log_code_state: True
3
+ patterns_to_ignore: ['logs', '.git', '__pycache__']
4
+
5
+ mels_path: '/home/duyxxd/SpecVQGAN/data/greatesthit/melspec_10s_22050hz'
6
+ batch_size: 32
7
+ num_workers: 8
8
+ device: 'cuda:0'
9
+ conv_layers: [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512]
10
+ use_bn: False
11
+ optimizer: adam
12
+ learning_rate: 1e-4
13
+ betas: [0.9, 0.999]
14
+ cropped_size: [80, 160]
15
+ momentum: 0.9
16
+ weight_decay: 1e-4
17
+ cls_weights_in_loss: False
18
+ num_epochs: 20
19
+ patience: 20
20
+ logdir: '/home/duyxxd/SpecVQGAN/logs'
21
+ exp_name: 'action'
22
+ action_only: True
23
+ material_only: False
24
+
25
+ load_model: /home/duyxxd/SpecVQGAN/logs/vggishish16.pt
foleycrafter/models/specvqgan/modules/losses/vggishish/configs/vggish_gh_material.yaml ADDED
@@ -0,0 +1,25 @@
1
+ seed: 1337
2
+ log_code_state: True
3
+ patterns_to_ignore: ['logs', '.git', '__pycache__']
4
+
5
+ mels_path: '/home/duyxxd/SpecVQGAN/data/greatesthit/melspec_10s_22050hz'
6
+ batch_size: 32
7
+ num_workers: 8
8
+ device: 'cuda:0'
9
+ conv_layers: [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512]
10
+ use_bn: False
11
+ optimizer: adam
12
+ learning_rate: 1e-4
13
+ betas: [0.9, 0.999]
14
+ cropped_size: [80, 160]
15
+ momentum: 0.9
16
+ weight_decay: 1e-4
17
+ cls_weights_in_loss: False
18
+ num_epochs: 20
19
+ patience: 20
20
+ logdir: '/home/duyxxd/SpecVQGAN/logs'
21
+ exp_name: 'material'
22
+ action_only: False
23
+ material_only: True
24
+
25
+ load_model: /home/duyxxd/SpecVQGAN/logs/vggishish16.pt
foleycrafter/models/specvqgan/modules/losses/vggishish/dataset.py ADDED
@@ -0,0 +1,295 @@
1
+ import collections
2
+ import csv
3
+ import logging
4
+ import os
5
+ import random
6
+ import math
7
+ import json
8
+ from glob import glob
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torchvision
14
+
15
+ logger = logging.getLogger(f'main.{__name__}')
16
+
17
+
18
+ class VGGSound(torch.utils.data.Dataset):
19
+
20
+ def __init__(self, split, specs_dir, transforms=None, splits_path='./data', meta_path='./data/vggsound.csv'):
21
+ super().__init__()
22
+ self.split = split
23
+ self.specs_dir = specs_dir
24
+ self.transforms = transforms
25
+ self.splits_path = splits_path
26
+ self.meta_path = meta_path
27
+
28
+ vggsound_meta = list(csv.reader(open(meta_path), quotechar='"'))
29
+ unique_classes = sorted(list(set(row[2] for row in vggsound_meta)))
30
+ self.label2target = {label: target for target, label in enumerate(unique_classes)}
31
+ self.target2label = {target: label for label, target in self.label2target.items()}
32
+ self.video2target = {row[0]: self.label2target[row[2]] for row in vggsound_meta}
33
+
34
+ split_clip_ids_path = os.path.join(splits_path, f'vggsound_{split}_partial.txt')
35
+ print('Using split file:', split_clip_ids_path)
36
+ if not os.path.exists(split_clip_ids_path):
37
+ self.make_split_files()
38
+ clip_ids_with_timestamp = open(split_clip_ids_path).read().splitlines()
39
+ clip_paths = [os.path.join(specs_dir, v + '_mel.npy') for v in clip_ids_with_timestamp]
40
+ self.dataset = clip_paths
41
+ # self.dataset = clip_paths[:10000] # overfit one batch
42
+
43
+ # 'zyTX_1BXKDE_16000_26000'[:11] -> 'zyTX_1BXKDE'
44
+ vid_classes = [self.video2target[Path(path).stem[:11]] for path in self.dataset]
45
+ class2count = collections.Counter(vid_classes)
46
+ self.class_counts = torch.tensor([class2count[cls] for cls in range(len(class2count))])
47
+ # self.sample_weights = [len(self.dataset) / class2count[self.video2target[Path(path).stem[:11]]] for path in self.dataset]
48
+
49
+ def __getitem__(self, idx):
50
+ item = {}
51
+
52
+ spec_path = self.dataset[idx]
53
+ # 'zyTX_1BXKDE_16000_26000' -> 'zyTX_1BXKDE'
54
+ video_name = Path(spec_path).stem[:11]
55
+
56
+ item['input'] = np.load(spec_path)
57
+ item['input_path'] = spec_path
58
+
59
+ # if self.split in ['train', 'valid']:
60
+ item['target'] = self.video2target[video_name]
61
+ item['label'] = self.target2label[item['target']]
62
+
63
+ if self.transforms is not None:
64
+ item = self.transforms(item)
65
+
66
+ return item
67
+
68
+ def __len__(self):
69
+ return len(self.dataset)
70
+
71
+ def make_split_files(self):
72
+ random.seed(1337)
73
+ logger.info(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')
74
+ # The downloaded videos (some went missing on YouTube and no longer available)
75
+ available_vid_paths = sorted(glob(os.path.join(self.specs_dir, '*_mel.npy')))
76
+ logger.info(f'The number of clips available after download: {len(available_vid_paths)}')
77
+
78
+ # original (full) train and test sets
79
+ vggsound_meta = list(csv.reader(open(self.meta_path), quotechar='"'))
80
+ train_vids = {row[0] for row in vggsound_meta if row[3] == 'train'}
81
+ test_vids = {row[0] for row in vggsound_meta if row[3] == 'test'}
82
+ logger.info(f'The number of videos in vggsound train set: {len(train_vids)}')
83
+ logger.info(f'The number of videos in vggsound test set: {len(test_vids)}')
84
+
85
+ # class counts in test set. We would like to have the same distribution in valid
86
+ unique_classes = sorted(list(set(row[2] for row in vggsound_meta)))
87
+ label2target = {label: target for target, label in enumerate(unique_classes)}
88
+ video2target = {row[0]: label2target[row[2]] for row in vggsound_meta}
89
+ test_vid_classes = [video2target[vid] for vid in test_vids]
90
+ test_target2count = collections.Counter(test_vid_classes)
91
+
92
+ # now given the counts from test set, sample the same count for validation and the rest leave in train
93
+ train_vids_wo_valid, valid_vids = set(), set()
94
+ for target, label in enumerate(label2target.keys()):
95
+ class_train_vids = [vid for vid in train_vids if video2target[vid] == target]
96
+ random.shuffle(class_train_vids)
97
+ count = test_target2count[target]
98
+ valid_vids.update(class_train_vids[:count])
99
+ train_vids_wo_valid.update(class_train_vids[count:])
100
+
101
+ # make file with a list of available test videos (each video should contain timestamps as well)
102
+ train_i = valid_i = test_i = 0
103
+ with open(os.path.join(self.splits_path, 'vggsound_train.txt'), 'w') as train_file, \
104
+ open(os.path.join(self.splits_path, 'vggsound_valid.txt'), 'w') as valid_file, \
105
+ open(os.path.join(self.splits_path, 'vggsound_test.txt'), 'w') as test_file:
106
+ for path in available_vid_paths:
107
+ path = path.replace('_mel.npy', '')
108
+ vid_name = Path(path).name
109
+ # 'zyTX_1BXKDE_16000_26000'[:11] -> 'zyTX_1BXKDE'
110
+ if vid_name[:11] in train_vids_wo_valid:
111
+ train_file.write(vid_name + '\n')
112
+ train_i += 1
113
+ elif vid_name[:11] in valid_vids:
114
+ valid_file.write(vid_name + '\n')
115
+ valid_i += 1
116
+ elif vid_name[:11] in test_vids:
117
+ test_file.write(vid_name + '\n')
118
+ test_i += 1
119
+ else:
120
+ raise Exception(f'Clip {vid_name} is neither in train, valid nor test. Strange.')
121
+
122
+ logger.info(f'Put {train_i} clips to the train set and saved it to ./data/vggsound_train.txt')
123
+ logger.info(f'Put {valid_i} clips to the valid set and saved it to ./data/vggsound_valid.txt')
124
+ logger.info(f'Put {test_i} clips to the test set and saved it to ./data/vggsound_test.txt')
125
+
126
+
127
+ def get_GH_data_identifier(video_name, start_idx, split='_'):
128
+ if isinstance(start_idx, str):
129
+ return video_name + split + start_idx
130
+ elif isinstance(start_idx, int):
131
+ return video_name + split + str(start_idx)
132
+ else:
133
+ raise NotImplementedError
134
+
135
+
136
+ class GreatestHit(torch.utils.data.Dataset):
137
+
138
+ def __init__(self, split, spec_dir_path, spec_transform=None, L=2.0, action_only=False,
139
+ material_only=False, splits_path='/home/duyxxd/SpecVQGAN/data',
140
+ meta_path='/home/duyxxd/SpecVQGAN/data/info_r2plus1d_dim1024_15fps.json'):
141
+ super().__init__()
142
+ self.split = split
143
+ self.specs_dir = spec_dir_path
144
+ self.splits_path = splits_path
145
+ self.meta_path = meta_path
146
+ self.spec_transform = spec_transform
147
+ self.L = L
148
+ self.spec_take_first = int(math.ceil(860 * (L / 10.) / 32) * 32)
149
+ self.spec_take_first = 860 if self.spec_take_first > 860 else self.spec_take_first
150
+ self.spec_take_first = 173  # 2-second clips yield only 173 mel frames, so crop to a fixed length
151
+
152
+ greatesthit_meta = json.load(open(self.meta_path, 'r'))
153
+ self.video_idx2label = {
154
+ get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
155
+ greatesthit_meta['hit_type'][i] for i in range(len(greatesthit_meta['video_name']))
156
+ }
157
+ self.available_video_hit = list(self.video_idx2label.keys())
158
+ self.video_idx2path = {
159
+ vh: os.path.join(self.specs_dir,
160
+ vh.replace('_', '_denoised_') + '_' + self.video_idx2label[vh].replace(' ', '_') +'_mel.npy')
161
+ for vh in self.available_video_hit
162
+ }
163
+ self.video_idx2idx = {
164
+ get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
165
+ i for i in range(len(greatesthit_meta['video_name']))
166
+ }
167
+
168
+ split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}_2.00_single_type_only.json')
169
+ if not os.path.exists(split_clip_ids_path):
170
+ raise NotImplementedError()
171
+ clip_video_hit = json.load(open(split_clip_ids_path, 'r'))
172
+ self.dataset = list(clip_video_hit.keys())
173
+ if action_only:
174
+ self.video_idx2label = {k: v.split(' ')[1] for k, v in clip_video_hit.items()}
175
+ elif material_only:
176
+ self.video_idx2label = {k: v.split(' ')[0] for k, v in clip_video_hit.items()}
177
+ else:
178
+ self.video_idx2label = clip_video_hit
179
+
180
+
181
+ self.video2indexes = {}
182
+ for video_idx in self.dataset:
183
+ video, start_idx = video_idx.split('_')
184
+ if video not in self.video2indexes.keys():
185
+ self.video2indexes[video] = []
186
+ self.video2indexes[video].append(start_idx)
187
+ for video in self.video2indexes.keys():
188
+ if len(self.video2indexes[video]) == 1: # given video contains only one hit
189
+ self.dataset.remove(
190
+ get_GH_data_identifier(video, self.video2indexes[video][0])
191
+ )
192
+
193
+ vid_classes = list(self.video_idx2label.values())
194
+ unique_classes = sorted(list(set(vid_classes)))
195
+ self.label2target = {label: target for target, label in enumerate(unique_classes)}
196
+ if action_only:
197
+ label2target_fix = {'hit': 0, 'scratch': 1}
198
+ elif material_only:
199
+ label2target_fix = {'carpet': 0, 'ceramic': 1, 'cloth': 2, 'dirt': 3, 'drywall': 4, 'glass': 5, 'grass': 6, 'gravel': 7, 'leaf': 8, 'metal': 9, 'paper': 10, 'plastic': 11, 'plastic-bag': 12, 'rock': 13, 'tile': 14, 'water': 15, 'wood': 16}
200
+ else:
201
+ label2target_fix = {'carpet hit': 0, 'carpet scratch': 1, 'ceramic hit': 2, 'ceramic scratch': 3, 'cloth hit': 4, 'cloth scratch': 5, 'dirt hit': 6, 'dirt scratch': 7, 'drywall hit': 8, 'drywall scratch': 9, 'glass hit': 10, 'glass scratch': 11, 'grass hit': 12, 'grass scratch': 13, 'gravel hit': 14, 'gravel scratch': 15, 'leaf hit': 16, 'leaf scratch': 17, 'metal hit': 18, 'metal scratch': 19, 'paper hit': 20, 'paper scratch': 21, 'plastic hit': 22, 'plastic scratch': 23, 'plastic-bag hit': 24, 'plastic-bag scratch': 25, 'rock hit': 26, 'rock scratch': 27, 'tile hit': 28, 'tile scratch': 29, 'water hit': 30, 'water scratch': 31, 'wood hit': 32, 'wood scratch': 33}
202
+ for k in self.label2target.keys():
203
+ assert k in label2target_fix.keys()
204
+ self.label2target = label2target_fix
205
+ self.target2label = {target: label for label, target in self.label2target.items()}
206
+ class2count = collections.Counter(vid_classes)
207
+ self.class_counts = torch.tensor([class2count[cls] for cls in range(len(class2count))])
208
+ print(self.label2target)
209
+ print(len(vid_classes), len(class2count), class2count)
210
+
211
+ def __len__(self):
212
+ return len(self.dataset)
213
+
214
+ def __getitem__(self, idx):
215
+ item = {}
216
+
217
+ video_idx = self.dataset[idx]
218
+ spec_path = self.video_idx2path[video_idx]
219
+ spec = np.load(spec_path) # (80, 860)
220
+
221
+ # concat spec outside dataload
222
+ item['input'] = 2 * spec - 1 # (80, 860)
223
+ item['input'] = item['input'][:, :self.spec_take_first] # (80, 173) (since 2sec audio can only generate 173)
224
+ item['file_path'] = spec_path
225
+
226
+ item['label'] = self.video_idx2label[video_idx]
227
+ item['target'] = self.label2target[item['label']]
228
+
229
+ if self.spec_transform is not None:
230
+ item = self.spec_transform(item)
231
+
232
+ return item
233
+
234
+
235
+
236
+ class AMT_test(torch.utils.data.Dataset):
237
+
238
+ def __init__(self, spec_dir_path, spec_transform=None, action_only=False, material_only=False):
239
+ super().__init__()
240
+ self.specs_dir = spec_dir_path
241
+ self.spec_transform = spec_transform
242
+ self.spec_take_first = 173
243
+
244
+ self.dataset = sorted([os.path.join(self.specs_dir, f) for f in os.listdir(self.specs_dir)])
245
+ if action_only:
246
+ self.label2target = {'hit': 0, 'scratch': 1}
247
+ elif material_only:
248
+ self.label2target = {'carpet': 0, 'ceramic': 1, 'cloth': 2, 'dirt': 3, 'drywall': 4, 'glass': 5, 'grass': 6, 'gravel': 7, 'leaf': 8, 'metal': 9, 'paper': 10, 'plastic': 11, 'plastic-bag': 12, 'rock': 13, 'tile': 14, 'water': 15, 'wood': 16}
249
+ else:
250
+ self.label2target = {'carpet hit': 0, 'carpet scratch': 1, 'ceramic hit': 2, 'ceramic scratch': 3, 'cloth hit': 4, 'cloth scratch': 5, 'dirt hit': 6, 'dirt scratch': 7, 'drywall hit': 8, 'drywall scratch': 9, 'glass hit': 10, 'glass scratch': 11, 'grass hit': 12, 'grass scratch': 13, 'gravel hit': 14, 'gravel scratch': 15, 'leaf hit': 16, 'leaf scratch': 17, 'metal hit': 18, 'metal scratch': 19, 'paper hit': 20, 'paper scratch': 21, 'plastic hit': 22, 'plastic scratch': 23, 'plastic-bag hit': 24, 'plastic-bag scratch': 25, 'rock hit': 26, 'rock scratch': 27, 'tile hit': 28, 'tile scratch': 29, 'water hit': 30, 'water scratch': 31, 'wood hit': 32, 'wood scratch': 33}
251
+ self.target2label = {v: k for k, v in self.label2target.items()}
252
+
253
+ def __len__(self):
254
+ return len(self.dataset)
255
+
256
+ def __getitem__(self, idx):
257
+ item = {}
258
+
259
+ spec_path = self.dataset[idx]
260
+ spec = np.load(spec_path) # (80, 860)
261
+
262
+ # concat spec outside dataload
263
+ item['input'] = 2 * spec - 1 # (80, 860)
264
+ item['input'] = item['input'][:, :self.spec_take_first] # (80, 173) (since 2sec audio can only generate 173)
265
+ item['file_path'] = spec_path
266
+
267
+ if self.spec_transform is not None:
268
+ item = self.spec_transform(item)
269
+
270
+ return item
271
+
272
+
273
+ if __name__ == '__main__':
274
+ from transforms import Crop, StandardNormalizeAudio, ToTensor
275
+ specs_path = '/home/nvme/data/vggsound/features/melspec_10s_22050hz/'
276
+
277
+ transforms = torchvision.transforms.transforms.Compose([
278
+ StandardNormalizeAudio(specs_path),
279
+ ToTensor(),
280
+ Crop([80, 848]),
281
+ ])
282
+
283
+ datasets = {
284
+ 'train': VGGSound('train', specs_path, transforms),
285
+ 'valid': VGGSound('valid', specs_path, transforms),
286
+ 'test': VGGSound('test', specs_path, transforms),
287
+ }
288
+
289
+ print(datasets['train'][0])
290
+ print(datasets['valid'][0])
291
+ print(datasets['test'][0])
292
+
293
+ print(datasets['train'].class_counts)
294
+ print(datasets['valid'].class_counts)
295
+ print(datasets['test'].class_counts)
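The commented-out sample_weights line in VGGSound hints at class-balanced sampling. One possible way to wire the exposed class_counts and video2target into a balanced loader is sketched below; balanced_vggsound_loader and its defaults are illustrative, not part of the repo:

import torch
from pathlib import Path
from torch.utils.data import DataLoader, WeightedRandomSampler

def balanced_vggsound_loader(dataset, batch_size=32, num_workers=0):
    # weight each clip by the inverse frequency of its class
    per_class_w = 1.0 / dataset.class_counts.float()
    sample_targets = [dataset.video2target[Path(p).stem[:11]] for p in dataset.dataset]
    sample_weights = per_class_w[torch.tensor(sample_targets)]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers)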
foleycrafter/models/specvqgan/modules/losses/vggishish/logger.py ADDED
@@ -0,0 +1,90 @@
1
+ import logging
2
+ import os
3
+ import time
4
+ from shutil import copytree, ignore_patterns
5
+
6
+ import torch
7
+ from omegaconf import OmegaConf
8
+ from torch.utils.tensorboard import SummaryWriter, summary
9
+
10
+
11
+ class LoggerWithTBoard(SummaryWriter):
12
+
13
+ def __init__(self, cfg):
14
+ # current time stamp and experiment log directory
15
+ self.start_time = time.strftime('%y-%m-%dT%H-%M-%S', time.localtime())
16
+ if cfg.exp_name is not None:
17
+ self.logdir = os.path.join(cfg.logdir, self.start_time + f'_{cfg.exp_name}')
18
+ else:
19
+ self.logdir = os.path.join(cfg.logdir, self.start_time)
20
+ # init tboard
21
+ super().__init__(self.logdir)
22
+ # backup the cfg
23
+ OmegaConf.save(cfg, os.path.join(self.log_dir, 'cfg.yaml'))
24
+ # backup the code state
25
+ if cfg.log_code_state:
26
+ dest_dir = os.path.join(self.logdir, 'code')
27
+ copytree(os.getcwd(), dest_dir, ignore=ignore_patterns(*cfg.patterns_to_ignore))
28
+
29
+ # init logger which handles printing and logging mostly same things to the log file
30
+ self.print_logger = logging.getLogger('main')
31
+ self.print_logger.setLevel(logging.INFO)
32
+ msgfmt = '[%(levelname)s] %(asctime)s - %(name)s \n %(message)s'
33
+ datefmt = '%d %b %Y %H:%M:%S'
34
+ formatter = logging.Formatter(msgfmt, datefmt)
35
+ # stdout
36
+ sh = logging.StreamHandler()
37
+ sh.setLevel(logging.DEBUG)
38
+ sh.setFormatter(formatter)
39
+ self.print_logger.addHandler(sh)
40
+ # log file
41
+ fh = logging.FileHandler(os.path.join(self.log_dir, 'log.txt'))
42
+ fh.setLevel(logging.INFO)
43
+ fh.setFormatter(formatter)
44
+ self.print_logger.addHandler(fh)
45
+
46
+ self.print_logger.info(f'Saving logs and checkpoints @ {self.logdir}')
47
+
48
+ def log_param_num(self, model):
49
+ param_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
50
+ self.print_logger.info(f'The number of parameters: {param_num/1e+6:.3f} mil')
51
+ self.add_scalar('num_params', param_num, 0)
52
+ return param_num
53
+
54
+ def log_iter_loss(self, loss, iter, phase):
55
+ self.add_scalar(f'{phase}/loss_iter', loss, iter)
56
+
57
+ def log_epoch_loss(self, loss, epoch, phase):
58
+ self.add_scalar(f'{phase}/loss', loss, epoch)
59
+ self.print_logger.info(f'{phase} ({epoch}): loss {loss:.3f};')
60
+
61
+ def log_epoch_metrics(self, metrics_dict, epoch, phase):
62
+ for metric, val in metrics_dict.items():
63
+ self.add_scalar(f'{phase}/{metric}', val, epoch)
64
+ metrics_dict = {k: round(v, 4) for k, v in metrics_dict.items()}
65
+ self.print_logger.info(f'{phase} ({epoch}) metrics: {metrics_dict};')
66
+
67
+ def log_test_metrics(self, metrics_dict, hparams_dict, best_epoch):
68
+ allowed_types = (int, float, str, bool, torch.Tensor)
69
+ hparams_dict = {k: v for k, v in hparams_dict.items() if isinstance(v, allowed_types)}
70
+ metrics_dict = {f'test/{k}': round(v, 4) for k, v in metrics_dict.items()}
71
+ exp, ssi, sei = summary.hparams(hparams_dict, metrics_dict)
72
+ self.file_writer.add_summary(exp)
73
+ self.file_writer.add_summary(ssi)
74
+ self.file_writer.add_summary(sei)
75
+ for k, v in metrics_dict.items():
76
+ self.add_scalar(k, v, best_epoch)
77
+ self.print_logger.info(f'test ({best_epoch}) metrics: {metrics_dict};')
78
+
79
+ def log_best_model(self, model, loss, epoch, optimizer, metrics_dict):
80
+ model_name = model.__class__.__name__
81
+ self.best_model_path = os.path.join(self.logdir, f'{model_name}-{self.start_time}.pt')
82
+ checkpoint = {
83
+ 'loss': loss,
84
+ 'metrics': metrics_dict,
85
+ 'epoch': epoch,
86
+ 'optimizer': optimizer.state_dict(),
87
+ 'model': model.state_dict(),
88
+ }
89
+ torch.save(checkpoint, self.best_model_path)
90
+ self.print_logger.info(f'Saved model in {self.best_model_path}')
foleycrafter/models/specvqgan/modules/losses/vggishish/loss.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.optim as optim
5
+
6
+ class WeightedCrossEntropy(nn.CrossEntropyLoss):
7
+
8
+ def __init__(self, weights, **pytorch_ce_loss_args) -> None:
9
+ super().__init__(reduction='none', **pytorch_ce_loss_args)
10
+ self.weights = weights
11
+
12
+ def __call__(self, outputs, targets, to_weight=True):
13
+ loss = super().__call__(outputs, targets)
14
+ if to_weight:
15
+ return (loss * self.weights[targets]).sum() / self.weights[targets].sum()
16
+ else:
17
+ return loss.mean()
18
+
19
+
20
+ if __name__ == '__main__':
21
+ x = torch.randn(10, 5)
22
+ target = torch.randint(0, 5, (10,))
23
+ weights = torch.tensor([1., 2., 3., 4., 5.])
24
+
25
+ # criterion_weighted = nn.CrossEntropyLoss(weight=weights)
26
+ # loss_weighted = criterion_weighted(x, target)
27
+
28
+ # criterion_weighted_manual = nn.CrossEntropyLoss(reduction='none')
29
+ # loss_weighted_manual = criterion_weighted_manual(x, target)
30
+ # print(loss_weighted, loss_weighted_manual.mean())
31
+ # loss_weighted_manual = (loss_weighted_manual * weights[target]).sum() / weights[target].sum()
32
+ # print(loss_weighted, loss_weighted_manual)
33
+ # print(torch.allclose(loss_weighted, loss_weighted_manual))
34
+
35
+ pytorch_weighted = nn.CrossEntropyLoss(weight=weights)
36
+ pytorch_unweighted = nn.CrossEntropyLoss()
37
+ custom = WeightedCrossEntropy(weights)
38
+
39
+ assert torch.allclose(pytorch_weighted(x, target), custom(x, target, to_weight=True))
40
+ assert torch.allclose(pytorch_unweighted(x, target), custom(x, target, to_weight=False))
41
+ print(custom(x, target, to_weight=True), custom(x, target, to_weight=False))
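Elsewhere in these scripts the class weights are derived as the inverse of a dataset's class_counts. A hedged usage sketch of WeightedCrossEntropy with such weights follows; the counts below are made up for illustration:

import torch

class_counts = torch.tensor([120, 45, 300, 80, 55])  # e.g. dataset.class_counts
weights = 1.0 / class_counts.float()
criterion = WeightedCrossEntropy(weights)

logits = torch.randn(8, 5)
targets = torch.randint(0, 5, (8,))
print(criterion(logits, targets, to_weight=True), criterion(logits, targets, to_weight=False))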
foleycrafter/models/specvqgan/modules/losses/vggishish/metrics.py ADDED
@@ -0,0 +1,69 @@
1
+ import logging
2
+
3
+ import numpy as np
4
+ import scipy
5
+ import torch
6
+ from sklearn.metrics import average_precision_score, roc_auc_score
7
+
8
+ logger = logging.getLogger(f'main.{__name__}')
9
+
10
+ def metrics(targets, outputs, topk=(1, 5)):
11
+ """
12
+ Adapted from https://github.com/hche11/VGGSound/blob/master/utils.py
13
+
14
+ Calculate statistics including mAP, AUC, and d-prime.
15
+ Args:
16
+ output: 2d tensors, (dataset_size, classes_num) - before softmax
17
+ target: 1d tensors, (dataset_size, )
18
+ topk: tuple
19
+ Returns:
20
+ metric_dict: a dict of metrics
21
+ """
22
+ metrics_dict = dict()
23
+
24
+ num_cls = outputs.shape[-1]
25
+
26
+ # accuracy@k
27
+ _, preds = torch.topk(outputs, k=max(topk), dim=1)
28
+ correct_for_maxtopk = preds == targets.view(-1, 1).expand_as(preds)
29
+ for k in topk:
30
+ metrics_dict[f'accuracy_{k}'] = float(correct_for_maxtopk[:, :k].sum() / correct_for_maxtopk.shape[0])
31
+
32
+ # avg precision, average roc_auc, and dprime
33
+ targets = torch.nn.functional.one_hot(targets, num_classes=num_cls)
34
+
35
+ # ids of the predicted classes (same as softmax)
36
+ targets_pred = torch.softmax(outputs, dim=1)
37
+
38
+ targets = targets.numpy()
39
+ targets_pred = targets_pred.numpy()
40
+
41
+ # one-vs-rest
42
+ avg_p = [average_precision_score(targets[:, c], targets_pred[:, c], average=None) for c in range(num_cls)]
43
+ try:
44
+ roc_aucs = [roc_auc_score(targets[:, c], targets_pred[:, c], average=None) for c in range(num_cls)]
45
+ except ValueError:
46
+ logger.warning('Weird... Some classes never occured in targets. Do not trust the metrics.')
47
+ roc_aucs = np.array([0.5])
48
+ avg_p = np.array([0])
49
+
50
+ metrics_dict['mAP'] = np.mean(avg_p)
51
+ metrics_dict['mROCAUC'] = np.mean(roc_aucs)
52
+ # Percent point function (ppf) (inverse of cdf — percentiles).
53
+ metrics_dict['dprime'] = scipy.stats.norm().ppf(metrics_dict['mROCAUC']) * np.sqrt(2)
54
+
55
+ return metrics_dict
56
+
57
+
58
+ if __name__ == '__main__':
59
+ targets = torch.tensor([3, 3, 1, 2, 1, 0])
60
+ outputs = torch.tensor([
61
+ [1.2, 1.3, 1.1, 1.5],
62
+ [1.3, 1.4, 1.0, 1.1],
63
+ [1.5, 1.1, 1.4, 1.3],
64
+ [1.0, 1.2, 1.4, 1.5],
65
+ [1.2, 1.3, 1.1, 1.1],
66
+ [1.2, 1.1, 1.1, 1.1],
67
+ ]).float()
68
+ metrics_dict = metrics(targets, outputs, topk=(1, 3))
69
+ print(metrics_dict)
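As a quick, illustrative sanity check (not part of the repo), the accuracy_1 reported by metrics should agree with a plain argmax comparison:

import torch

targets = torch.tensor([3, 3, 1, 2, 1, 0])
outputs = torch.randn(6, 4)
acc1 = (outputs.argmax(dim=1) == targets).float().mean().item()
print(acc1, metrics(targets, outputs, topk=(1, 3))['accuracy_1'])  # should match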
foleycrafter/models/specvqgan/modules/losses/vggishish/model.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class VGGishish(nn.Module):
6
+
7
+ def __init__(self, conv_layers, use_bn, num_classes):
8
+ '''
9
+ Mostly from
10
+ https://pytorch.org/vision/0.8/_modules/torchvision/models/vgg.html
11
+ '''
12
+ super().__init__()
13
+ layers = []
14
+ in_channels = 1
15
+
16
+ # a list of channels with 'MP' (maxpool) from config
17
+ for v in conv_layers:
18
+ if v == 'MP':
19
+ layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
20
+ else:
21
+ conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, stride=1)
22
+ if use_bn:
23
+ layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
24
+ else:
25
+ layers += [conv2d, nn.ReLU(inplace=True)]
26
+ in_channels = v
27
+ self.features = nn.Sequential(*layers)
28
+
29
+ self.avgpool = nn.AdaptiveAvgPool2d((5, 10))
30
+
31
+ self.flatten = nn.Flatten()
32
+ self.classifier = nn.Sequential(
33
+ nn.Linear(512 * 5 * 10, 4096),
34
+ nn.ReLU(True),
35
+ nn.Linear(4096, 4096),
36
+ nn.ReLU(True),
37
+ nn.Linear(4096, num_classes)
38
+ )
39
+
40
+ # weight init
41
+ self.reset_parameters()
42
+
43
+ def forward(self, x):
44
+ # adding channel dim for conv2d (B, 1, F, T) <-
45
+ x = x.unsqueeze(1)
46
+ # backbone (B, 512, 5, 53) <- (B, 1, 80, 860)
47
+ x = self.features(x)
48
+ # adaptive avg pooling (B, 512, 5, 10) <- (B, 512, 5, 53); there is no MP at the end of this VGG variant
49
+ x = self.avgpool(x)
50
+ # flatten
51
+ x = self.flatten(x)
52
+ # classify
53
+ x = self.classifier(x)
54
+ return x
55
+
56
+ def reset_parameters(self):
57
+ for m in self.modules():
58
+ if isinstance(m, nn.Conv2d):
59
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
60
+ if m.bias is not None:
61
+ nn.init.constant_(m.bias, 0)
62
+ elif isinstance(m, nn.BatchNorm2d):
63
+ nn.init.constant_(m.weight, 1)
64
+ nn.init.constant_(m.bias, 0)
65
+ elif isinstance(m, nn.Linear):
66
+ nn.init.normal_(m.weight, 0, 0.01)
67
+ nn.init.constant_(m.bias, 0)
68
+
69
+
70
+ if __name__ == '__main__':
71
+ num_classes = 309
72
+ inputs = torch.rand(3, 80, 848)
73
+ conv_layers = [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512]
74
+ # conv_layers = [64, 'MP', 128, 'MP', 256, 256, 'MP', 512, 512, 'MP']
75
+ model = VGGishish(conv_layers, use_bn=False, num_classes=num_classes)
76
+ outputs = model(inputs)
77
+ print(outputs.shape)
foleycrafter/models/specvqgan/modules/losses/vggishish/predict.py ADDED
@@ -0,0 +1,90 @@
1
+ import os
2
+ from torch.utils.data import DataLoader
3
+ import torchvision
4
+ from tqdm import tqdm
5
+ from dataset import VGGSound
6
+ import torch
7
+ import torch.nn as nn
8
+ from metrics import metrics
9
+ from omegaconf import OmegaConf
10
+ from model import VGGishish
11
+ from transforms import Crop, StandardNormalizeAudio, ToTensor
12
+
13
+
14
+ if __name__ == '__main__':
15
+ cfg_cli = OmegaConf.from_cli()
16
+ print(cfg_cli.config)
17
+ cfg_yml = OmegaConf.load(cfg_cli.config)
18
+ # the latter arguments are prioritized
19
+ cfg = OmegaConf.merge(cfg_yml, cfg_cli)
20
+ OmegaConf.set_readonly(cfg, True)
21
+ print(OmegaConf.to_yaml(cfg))
22
+
23
+ # logger = LoggerWithTBoard(cfg)
24
+ transforms = [
25
+ StandardNormalizeAudio(cfg.mels_path),
26
+ ToTensor(),
27
+ ]
28
+ if cfg.cropped_size not in [None, 'None', 'none']:
29
+ transforms.append(Crop(cfg.cropped_size))
30
+ transforms = torchvision.transforms.transforms.Compose(transforms)
31
+
32
+ datasets = {
33
+ 'test': VGGSound('test', cfg.mels_path, transforms),
34
+ }
35
+
36
+ loaders = {
37
+ 'test': DataLoader(datasets['test'], batch_size=cfg.batch_size,
38
+ num_workers=cfg.num_workers, pin_memory=True)
39
+ }
40
+
41
+ device = torch.device(cfg.device if torch.cuda.is_available() else 'cpu')
42
+ model = VGGishish(cfg.conv_layers, cfg.use_bn, num_classes=len(datasets['test'].target2label))
43
+ model = model.to(device)
44
+
45
+ optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
46
+ criterion = nn.CrossEntropyLoss()
47
+
48
+ # loading the best model
49
+ folder_name = os.path.split(cfg.config)[0].split('/')[-1]
50
+ print(folder_name)
51
+ ckpt = torch.load(f'./logs/{folder_name}/vggishish-{folder_name}.pt', map_location='cpu')
52
+ model.load_state_dict(ckpt['model'])
53
+ print((f'The model was trained for {ckpt["epoch"]} epochs. Loss: {ckpt["loss"]:.4f}'))
54
+
55
+ # Testing the model
56
+ model.eval()
57
+ running_loss = 0
58
+ preds_from_each_batch = []
59
+ targets_from_each_batch = []
60
+
61
+ for i, batch in enumerate(tqdm(loaders['test'])):
62
+ inputs = batch['input'].to(device)
63
+ targets = batch['target'].to(device)
64
+
65
+ # zero the parameter gradients
66
+ optimizer.zero_grad()
67
+
68
+ # forward + backward + optimize
69
+ with torch.set_grad_enabled(False):
70
+ outputs = model(inputs)
71
+ loss = criterion(outputs, targets)
72
+
73
+ # loss
74
+ running_loss += loss.item()
75
+
76
+ # for metrics calculation later on
77
+ preds_from_each_batch += [outputs.detach().cpu()]
78
+ targets_from_each_batch += [targets.cpu()]
79
+
80
+ # logging metrics
81
+ preds_from_each_batch = torch.cat(preds_from_each_batch)
82
+ targets_from_each_batch = torch.cat(targets_from_each_batch)
83
+ test_metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch)
84
+ test_metrics_dict['avg_loss'] = running_loss / len(loaders['test'])
85
+ test_metrics_dict['param_num'] = sum(p.numel() for p in model.parameters() if p.requires_grad)
86
+
87
+ # TODO: I have no idea why tboard doesn't keep metrics (hparams) in a tensorboard when
88
+ # I run this experiment from cli: `python main.py config=./configs/vggish.yaml`
89
+ # while when I run it in vscode debugger the metrics are present in the tboard (weird)
90
+ print(test_metrics_dict)
foleycrafter/models/specvqgan/modules/losses/vggishish/predict_gh.py ADDED
@@ -0,0 +1,66 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ from torch.utils.data import DataLoader
5
+ import torchvision
6
+ from tqdm import tqdm
7
+ from dataset import GreatestHit, AMT_test
8
+ import torch
9
+ import torch.nn as nn
10
+ from metrics import metrics
11
+ from omegaconf import OmegaConf
12
+ from model import VGGishish
13
+ from transforms import Crop, StandardNormalizeAudio, ToTensor
14
+
15
+
16
+ if __name__ == '__main__':
17
+ cfg_cli = sys.argv[1]
18
+ target_path = sys.argv[2]
19
+ model_path = sys.argv[3]
20
+ cfg_yml = OmegaConf.load(cfg_cli)
21
+ # the latter arguments are prioritized
22
+ cfg = cfg_yml
23
+ OmegaConf.set_readonly(cfg, True)
24
+ # print(OmegaConf.to_yaml(cfg))
25
+
26
+ device = torch.device(cfg.device if torch.cuda.is_available() else 'cpu')
27
+ transforms = [
28
+ StandardNormalizeAudio(cfg.mels_path),
29
+ ]
30
+ if cfg.cropped_size not in [None, 'None', 'none']:
31
+ transforms.append(Crop(cfg.cropped_size))
32
+ transforms.append(ToTensor())
33
+ transforms = torchvision.transforms.transforms.Compose(transforms)
34
+
35
+ testset = AMT_test(target_path, transforms, action_only=cfg.action_only, material_only=cfg.material_only)
36
+ loader = DataLoader(testset, batch_size=cfg.batch_size,
37
+ num_workers=cfg.num_workers, pin_memory=True)
38
+
39
+ model = VGGishish(cfg.conv_layers, cfg.use_bn, num_classes=len(testset.label2target))
40
+ ckpt = torch.load(model_path)['model']
41
+ model.load_state_dict(ckpt, strict=True)
42
+ model = model.to(device)
43
+
44
+ model.eval()
45
+
46
+ if cfg.cls_weights_in_loss:
47
+ weights = 1 / testset.class_counts
48
+ else:
49
+ weights = torch.ones(len(testset.label2target))
50
+
51
+ preds_from_each_batch = []
52
+ file_path_from_each_batch = []
53
+ for batch in tqdm(loader):
54
+ inputs = batch['input'].to(device)
55
+ file_path = batch['file_path']
56
+ with torch.set_grad_enabled(False):
57
+ outputs = model(inputs)
58
+ # for metrics calculation later on
59
+ preds_from_each_batch += [outputs.detach().cpu()]
60
+ file_path_from_each_batch += file_path
61
+ preds_from_each_batch = torch.cat(preds_from_each_batch)
62
+ _, preds = torch.topk(preds_from_each_batch, k=1)
63
+ pred_dict = {fp: int(p.item()) for fp, p in zip(file_path_from_each_batch, preds)}
64
+ mel_parent_dir = os.path.dirname(list(pred_dict.keys())[0])
65
+ pred_list = [pred_dict[os.path.join(mel_parent_dir, f'{i}.npy')] for i in range(len(pred_dict))]
66
+ json.dump(pred_list, open(target_path + f'_{cfg.exp_name}_preds.json', 'w'))
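The script writes its predictions as a JSON list ordered by the integer index encoded in each mel filename ('{i}.npy'). Below is a small self-contained sketch of that top-1 extraction and ordered dump, using hypothetical paths and random logits rather than real model outputs.

    import json, os
    import torch

    # hypothetical logits for mel files 0.npy .. 3.npy
    logits = torch.randn(4, 10)
    file_paths = [os.path.join('/tmp/mels', f'{i}.npy') for i in range(4)]

    _, top1 = torch.topk(logits, k=1)                               # (4, 1) class indices
    pred_dict = {fp: int(p.item()) for fp, p in zip(file_paths, top1)}
    pred_list = [pred_dict[os.path.join('/tmp/mels', f'{i}.npy')] for i in range(len(pred_dict))]
    json.dump(pred_list, open('/tmp/preds.json', 'w'))              # hypothetical output path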
foleycrafter/models/specvqgan/modules/losses/vggishish/train_melception.py ADDED
@@ -0,0 +1,241 @@
1
+ import random
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchvision
6
+ from omegaconf import OmegaConf
7
+ from torch.utils.data.dataloader import DataLoader
8
+ from torchvision.models.inception import BasicConv2d, Inception3
9
+ from tqdm import tqdm
10
+
11
+ from dataset import VGGSound
12
+ from logger import LoggerWithTBoard
13
+ from loss import WeightedCrossEntropy
14
+ from metrics import metrics
15
+ from transforms import Crop, StandardNormalizeAudio, ToTensor
16
+
17
+
18
+ # TODO: refactor ./evaluation/feature_extractors/melception.py to handle this class as well.
19
+ # So far couldn't do it because of the difference in outputs
20
+ class Melception(Inception3):
21
+
22
+ def __init__(self, num_classes, **kwargs):
23
+ # inception = Melception(num_classes=309)
24
+ super().__init__(num_classes=num_classes, **kwargs)
25
+ # the same as https://github.com/pytorch/vision/blob/5339e63148/torchvision/models/inception.py#L95
26
+ # but for 1-channel input instead of RGB.
27
+ self.Conv2d_1a_3x3 = BasicConv2d(1, 32, kernel_size=3, stride=2)
28
+ # also, since the height of the mel spec is 80 (vs 299 in RGB), we remove all max pooling from Inception
29
+ self.maxpool1 = torch.nn.Identity()
30
+ self.maxpool2 = torch.nn.Identity()
31
+
32
+ def forward(self, x):
33
+ x = x.unsqueeze(1)
34
+ return super().forward(x)
35
+
36
+ def train_inception_scorer(cfg):
37
+ logger = LoggerWithTBoard(cfg)
38
+
39
+ random.seed(cfg.seed)
40
+ np.random.seed(cfg.seed)
41
+ torch.manual_seed(cfg.seed)
42
+ torch.cuda.manual_seed_all(cfg.seed)
43
+ # makes iterations faster (in this case 30%) if your inputs are of a fixed size
44
+ # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
45
+ torch.backends.cudnn.benchmark = True
46
+
47
+ meta_path = './data/vggsound.csv'
48
+ train_ids_path = './data/vggsound_train.txt'
49
+ cache_path = './data/'
50
+ splits_path = cache_path
51
+
52
+ transforms = [
53
+ StandardNormalizeAudio(cfg.mels_path, train_ids_path, cache_path),
54
+ ]
55
+ if cfg.cropped_size not in [None, 'None', 'none']:
56
+ logger.print_logger.info(f'Using cropping {cfg.cropped_size}')
57
+ transforms.append(Crop(cfg.cropped_size))
58
+ transforms.append(ToTensor())
59
+ transforms = torchvision.transforms.transforms.Compose(transforms)
60
+
61
+ datasets = {
62
+ 'train': VGGSound('train', cfg.mels_path, transforms, splits_path, meta_path),
63
+ 'valid': VGGSound('valid', cfg.mels_path, transforms, splits_path, meta_path),
64
+ 'test': VGGSound('test', cfg.mels_path, transforms, splits_path, meta_path),
65
+ }
66
+
67
+ loaders = {
68
+ 'train': DataLoader(datasets['train'], batch_size=cfg.batch_size, shuffle=True, drop_last=True,
69
+ num_workers=cfg.num_workers, pin_memory=True),
70
+ 'valid': DataLoader(datasets['valid'], batch_size=cfg.batch_size,
71
+ num_workers=cfg.num_workers, pin_memory=True),
72
+ 'test': DataLoader(datasets['test'], batch_size=cfg.batch_size,
73
+ num_workers=cfg.num_workers, pin_memory=True),
74
+ }
75
+
76
+ device = torch.device(cfg.device if torch.cuda.is_available() else 'cpu')
77
+
78
+ model = Melception(num_classes=len(datasets['train'].target2label))
79
+ model = model.to(device)
80
+ param_num = logger.log_param_num(model)
81
+
82
+ if cfg.optimizer == 'adam':
83
+ optimizer = torch.optim.Adam(
84
+ model.parameters(), lr=cfg.learning_rate, betas=cfg.betas, weight_decay=cfg.weight_decay)
85
+ elif cfg.optimizer == 'sgd':
86
+ optimizer = torch.optim.SGD(
87
+ model.parameters(), lr=cfg.learning_rate, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
88
+ else:
89
+ raise NotImplementedError
90
+
91
+ if cfg.cls_weights_in_loss:
92
+ weights = 1 / datasets['train'].class_counts
93
+ else:
94
+ weights = torch.ones(len(datasets['train'].target2label))
95
+ criterion = WeightedCrossEntropy(weights.to(device))
96
+
97
+ # loop over the train and validation multiple times (typical PT boilerplate)
98
+ no_change_epochs = 0
99
+ best_valid_loss = float('inf')
100
+ early_stop_triggered = False
101
+
102
+ for epoch in range(cfg.num_epochs):
103
+
104
+ for phase in ['train', 'valid']:
105
+ if phase == 'train':
106
+ model.train()
107
+ else:
108
+ model.eval()
109
+
110
+ running_loss = 0
111
+ preds_from_each_batch = []
112
+ targets_from_each_batch = []
113
+
114
+ prog_bar = tqdm(loaders[phase], f'{phase} ({epoch})', ncols=0)
115
+ for i, batch in enumerate(prog_bar):
116
+ inputs = batch['input'].to(device)
117
+ targets = batch['target'].to(device)
118
+
119
+ # zero the parameter gradients
120
+ optimizer.zero_grad()
121
+
122
+ # forward + backward + optimize
123
+ with torch.set_grad_enabled(phase == 'train'):
124
+ # inception v3
125
+ if phase == 'train':
126
+ outputs, aux_outputs = model(inputs)
127
+ loss1 = criterion(outputs, targets)
128
+ loss2 = criterion(aux_outputs, targets)
129
+ loss = loss1 + 0.4*loss2
130
+ loss = criterion(outputs, targets, to_weight=True)
131
+ else:
132
+ outputs = model(inputs)
133
+ loss = criterion(outputs, targets, to_weight=False)
134
+
135
+ if phase == 'train':
136
+ loss.backward()
137
+ optimizer.step()
138
+
139
+ # loss
140
+ running_loss += loss.item()
141
+
142
+ # for metrics calculation later on
143
+ preds_from_each_batch += [outputs.detach().cpu()]
144
+ targets_from_each_batch += [targets.cpu()]
145
+
146
+ # iter logging
147
+ if i % 50 == 0:
148
+ logger.log_iter_loss(loss.item(), epoch*len(loaders[phase])+i, phase)
149
+ # tracks loss in the tqdm progress bar
150
+ prog_bar.set_postfix(loss=loss.item())
151
+
152
+ # logging loss
153
+ epoch_loss = running_loss / len(loaders[phase])
154
+ logger.log_epoch_loss(epoch_loss, epoch, phase)
155
+
156
+ # logging metrics
157
+ preds_from_each_batch = torch.cat(preds_from_each_batch)
158
+ targets_from_each_batch = torch.cat(targets_from_each_batch)
159
+ metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch)
160
+ logger.log_epoch_metrics(metrics_dict, epoch, phase)
161
+
162
+ # Early stopping
163
+ if phase == 'valid':
164
+ if epoch_loss < best_valid_loss:
165
+ no_change_epochs = 0
166
+ best_valid_loss = epoch_loss
167
+ logger.log_best_model(model, epoch_loss, epoch, optimizer, metrics_dict)
168
+ else:
169
+ no_change_epochs += 1
170
+ logger.print_logger.info(
171
+ f"Valid loss hasn't changed for {no_change_epochs} epochs (patience: {cfg.patience})"
172
+ )
173
+ if no_change_epochs >= cfg.patience:
174
+ early_stop_triggered = True
175
+
176
+ if early_stop_triggered:
177
+ logger.print_logger.info(f'Training is early stopped @ {epoch}')
178
+ break
179
+
180
+ logger.print_logger.info('Finished Training')
181
+
182
+ # loading the best model
183
+ ckpt = torch.load(logger.best_model_path)
184
+ model.load_state_dict(ckpt['model'])
185
+ logger.print_logger.info(f'Loading the best model from {logger.best_model_path}')
186
+ logger.print_logger.info((f'The model was trained for {ckpt["epoch"]} epochs. Loss: {ckpt["loss"]:.4f}'))
187
+
188
+ # Testing the model
189
+ model.eval()
190
+ running_loss = 0
191
+ preds_from_each_batch = []
192
+ targets_from_each_batch = []
193
+
194
+ for i, batch in enumerate(loaders['test']):
195
+ inputs = batch['input'].to(device)
196
+ targets = batch['target'].to(device)
197
+
198
+ # zero the parameter gradients
199
+ optimizer.zero_grad()
200
+
201
+ # forward + backward + optimize
202
+ with torch.set_grad_enabled(False):
203
+ outputs = model(inputs)
204
+ loss = criterion(outputs, targets, to_weight=False)
205
+
206
+ # loss
207
+ running_loss += loss.item()
208
+
209
+ # for metrics calculation later on
210
+ preds_from_each_batch += [outputs.detach().cpu()]
211
+ targets_from_each_batch += [targets.cpu()]
212
+
213
+ # logging metrics
214
+ preds_from_each_batch = torch.cat(preds_from_each_batch)
215
+ targets_from_each_batch = torch.cat(targets_from_each_batch)
216
+ test_metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch)
217
+ test_metrics_dict['avg_loss'] = running_loss / len(loaders['test'])
218
+ test_metrics_dict['param_num'] = param_num
219
+ # TODO: hparam metrics are not written to TensorBoard when this experiment is launched from the
220
+ # CLI (`python train_melception.py config=./configs/vggish.yaml`), yet they are logged correctly
221
+ # when it is run under the VS Code debugger; the cause is still unclear
222
+ logger.log_test_metrics(test_metrics_dict, dict(cfg), ckpt['epoch'])
223
+
224
+ logger.print_logger.info('Finished the experiment')
225
+
226
+
227
+ if __name__ == '__main__':
228
+ # input = torch.rand(16, 1, 80, 848)
229
+ # output, aux = inception(input)
230
+ # print(output.shape, aux.shape)
231
+ # Expected input size: (3, 299, 299) in RGB -> (1, 80, 848) in Mel Spec
232
+ # train_inception_scorer()
233
+
234
+ cfg_cli = OmegaConf.from_cli()
235
+ cfg_yml = OmegaConf.load(cfg_cli.config)
236
+ # the latter arguments are prioritized
237
+ cfg = OmegaConf.merge(cfg_yml, cfg_cli)
238
+ OmegaConf.set_readonly(cfg, True)
239
+ print(OmegaConf.to_yaml(cfg))
240
+
241
+ train_inception_scorer(cfg)
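Melception above reuses torchvision's Inception3 with a 1-channel stem and the max-pooling layers replaced by identities so that 80-bin mel spectrograms fit through the network. The following is only a shape smoke test, rebuilt inline with the same tweaks (assuming a reasonably recent torchvision, where Inception3 defaults to aux_logits=True and returns (logits, aux_logits) in train mode).

    import torch
    from torchvision.models.inception import BasicConv2d, Inception3

    # stand-alone equivalent of the Melception tweaks, for a quick shape check
    model = Inception3(num_classes=309)
    model.Conv2d_1a_3x3 = BasicConv2d(1, 32, kernel_size=3, stride=2)   # 1-channel stem
    model.maxpool1 = torch.nn.Identity()                                # keep the 80-bin "height"
    model.maxpool2 = torch.nn.Identity()
    model.train()

    x = torch.rand(2, 1, 80, 848)      # (B, 1, mel_bins, time)
    out, aux = model(x)                # InceptionOutputs unpacks to (logits, aux_logits)
    print(out.shape, aux.shape)        # torch.Size([2, 309]) torch.Size([2, 309])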
foleycrafter/models/specvqgan/modules/losses/vggishish/train_vggishish.py ADDED
@@ -0,0 +1,199 @@
1
+ from loss import WeightedCrossEntropy
2
+ import random
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torchvision
7
+ from omegaconf import OmegaConf
8
+ from torch.utils.data.dataloader import DataLoader
9
+ from tqdm import tqdm
10
+
11
+ from dataset import VGGSound
12
+ from transforms import Crop, StandardNormalizeAudio, ToTensor
13
+ from logger import LoggerWithTBoard
14
+ from metrics import metrics
15
+ from model import VGGishish
16
+
17
+ if __name__ == "__main__":
18
+ cfg_cli = OmegaConf.from_cli()
19
+ cfg_yml = OmegaConf.load(cfg_cli.config)
20
+ # the latter arguments are prioritized
21
+ cfg = OmegaConf.merge(cfg_yml, cfg_cli)
22
+ OmegaConf.set_readonly(cfg, True)
23
+ print(OmegaConf.to_yaml(cfg))
24
+
25
+ logger = LoggerWithTBoard(cfg)
26
+
27
+ random.seed(cfg.seed)
28
+ np.random.seed(cfg.seed)
29
+ torch.manual_seed(cfg.seed)
30
+ torch.cuda.manual_seed_all(cfg.seed)
31
+ # makes iterations faster (in this case 30%) if your inputs are of a fixed size
32
+ # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
33
+ torch.backends.cudnn.benchmark = True
34
+
35
+ transforms = [
36
+ StandardNormalizeAudio(cfg.mels_path),
37
+ ]
38
+ if cfg.cropped_size not in [None, 'None', 'none']:
39
+ logger.print_logger.info(f'Using cropping {cfg.cropped_size}')
40
+ transforms.append(Crop(cfg.cropped_size))
41
+ transforms.append(ToTensor())
42
+ transforms = torchvision.transforms.transforms.Compose(transforms)
43
+
44
+ datasets = {
45
+ 'train': VGGSound('train', cfg.mels_path, transforms),
46
+ 'valid': VGGSound('valid', cfg.mels_path, transforms),
47
+ 'test': VGGSound('test', cfg.mels_path, transforms),
48
+ }
49
+
50
+ loaders = {
51
+ 'train': DataLoader(datasets['train'], batch_size=cfg.batch_size, shuffle=True, drop_last=True,
52
+ num_workers=cfg.num_workers, pin_memory=True),
53
+ 'valid': DataLoader(datasets['valid'], batch_size=cfg.batch_size,
54
+ num_workers=cfg.num_workers, pin_memory=True),
55
+ 'test': DataLoader(datasets['test'], batch_size=cfg.batch_size,
56
+ num_workers=cfg.num_workers, pin_memory=True),
57
+ }
58
+
59
+ device = torch.device(cfg.device if torch.cuda.is_available() else 'cpu')
60
+
61
+ model = VGGishish(cfg.conv_layers, cfg.use_bn, num_classes=len(datasets['train'].target2label))
62
+ model = model.to(device)
63
+ param_num = logger.log_param_num(model)
64
+
65
+ if cfg.optimizer == 'adam':
66
+ optimizer = torch.optim.Adam(
67
+ model.parameters(), lr=cfg.learning_rate, betas=cfg.betas, weight_decay=cfg.weight_decay)
68
+ elif cfg.optimizer == 'sgd':
69
+ optimizer = torch.optim.SGD(
70
+ model.parameters(), lr=cfg.learning_rate, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
71
+ else:
72
+ raise NotImplementedError
73
+
74
+ if cfg.cls_weights_in_loss:
75
+ weights = 1 / datasets['train'].class_counts
76
+ else:
77
+ weights = torch.ones(len(datasets['train'].target2label))
78
+ criterion = WeightedCrossEntropy(weights.to(device))
79
+
80
+ # loop over the train and validation multiple times (typical PT boilerplate)
81
+ no_change_epochs = 0
82
+ best_valid_loss = float('inf')
83
+ early_stop_triggered = False
84
+
85
+ for epoch in range(cfg.num_epochs):
86
+
87
+ for phase in ['train', 'valid']:
88
+ if phase == 'train':
89
+ model.train()
90
+ else:
91
+ model.eval()
92
+
93
+ running_loss = 0
94
+ preds_from_each_batch = []
95
+ targets_from_each_batch = []
96
+
97
+ prog_bar = tqdm(loaders[phase], f'{phase} ({epoch})', ncols=0)
98
+ for i, batch in enumerate(prog_bar):
99
+ inputs = batch['input'].to(device)
100
+ targets = batch['target'].to(device)
101
+
102
+ # zero the parameter gradients
103
+ optimizer.zero_grad()
104
+
105
+ # forward + backward + optimize
106
+ with torch.set_grad_enabled(phase == 'train'):
107
+ outputs = model(inputs)
108
+ loss = criterion(outputs, targets, to_weight=phase == 'train')
109
+
110
+ if phase == 'train':
111
+ loss.backward()
112
+ optimizer.step()
113
+
114
+ # loss
115
+ running_loss += loss.item()
116
+
117
+ # for metrics calculation later on
118
+ preds_from_each_batch += [outputs.detach().cpu()]
119
+ targets_from_each_batch += [targets.cpu()]
120
+
121
+ # iter logging
122
+ if i % 50 == 0:
123
+ logger.log_iter_loss(loss.item(), epoch*len(loaders[phase])+i, phase)
124
+ # tracks loss in the tqdm progress bar
125
+ prog_bar.set_postfix(loss=loss.item())
126
+
127
+ # logging loss
128
+ epoch_loss = running_loss / len(loaders[phase])
129
+ logger.log_epoch_loss(epoch_loss, epoch, phase)
130
+
131
+ # logging metrics
132
+ preds_from_each_batch = torch.cat(preds_from_each_batch)
133
+ targets_from_each_batch = torch.cat(targets_from_each_batch)
134
+ metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch)
135
+ logger.log_epoch_metrics(metrics_dict, epoch, phase)
136
+
137
+ # Early stopping
138
+ if phase == 'valid':
139
+ if epoch_loss < best_valid_loss:
140
+ no_change_epochs = 0
141
+ best_valid_loss = epoch_loss
142
+ logger.log_best_model(model, epoch_loss, epoch, optimizer, metrics_dict)
143
+ else:
144
+ no_change_epochs += 1
145
+ logger.print_logger.info(
146
+ f"Valid loss hasn't changed for {no_change_epochs} epochs (patience: {cfg.patience})"
147
+ )
148
+ if no_change_epochs >= cfg.patience:
149
+ early_stop_triggered = True
150
+
151
+ if early_stop_triggered:
152
+ logger.print_logger.info(f'Training is early stopped @ {epoch}')
153
+ break
154
+
155
+ logger.print_logger.info('Finished Training')
156
+
157
+ # loading the best model
158
+ ckpt = torch.load(logger.best_model_path)
159
+ model.load_state_dict(ckpt['model'])
160
+ logger.print_logger.info(f'Loading the best model from {logger.best_model_path}')
161
+ logger.print_logger.info((f'The model was trained for {ckpt["epoch"]} epochs. Loss: {ckpt["loss"]:.4f}'))
162
+
163
+ # Testing the model
164
+ model.eval()
165
+ running_loss = 0
166
+ preds_from_each_batch = []
167
+ targets_from_each_batch = []
168
+
169
+ for i, batch in enumerate(loaders['test']):
170
+ inputs = batch['input'].to(device)
171
+ targets = batch['target'].to(device)
172
+
173
+ # zero the parameter gradients
174
+ optimizer.zero_grad()
175
+
176
+ # forward + backward + optimize
177
+ with torch.set_grad_enabled(False):
178
+ outputs = model(inputs)
179
+ loss = criterion(outputs, targets, to_weight=False)
180
+
181
+ # loss
182
+ running_loss += loss.item()
183
+
184
+ # for metrics calculation later on
185
+ preds_from_each_batch += [outputs.detach().cpu()]
186
+ targets_from_each_batch += [targets.cpu()]
187
+
188
+ # logging metrics
189
+ preds_from_each_batch = torch.cat(preds_from_each_batch)
190
+ targets_from_each_batch = torch.cat(targets_from_each_batch)
191
+ test_metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch)
192
+ test_metrics_dict['avg_loss'] = running_loss / len(loaders['test'])
193
+ test_metrics_dict['param_num'] = param_num
194
+ # TODO: hparam metrics are not written to TensorBoard when this experiment is launched from the
195
+ # CLI (`python train_vggishish.py config=./configs/vggish.yaml`), yet they are logged correctly
196
+ # when it is run under the VS Code debugger; the cause is still unclear
197
+ logger.log_test_metrics(test_metrics_dict, dict(cfg), ckpt['epoch'])
198
+
199
+ logger.print_logger.info('Finished the experiment')
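Training uses class-weighted cross-entropy while validation and testing switch the weighting off via the to_weight flag. The repository's WeightedCrossEntropy lives in loss.py and is not reproduced here; the snippet below is only a sketch of how such a switchable weighted loss could look.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SwitchableWeightedCE(nn.Module):
        # hypothetical stand-in for loss.WeightedCrossEntropy, not the repository's code
        def __init__(self, weights: torch.Tensor):
            super().__init__()
            self.register_buffer('weights', weights)

        def forward(self, logits, targets, to_weight: bool = True):
            w = self.weights if to_weight else None
            return F.cross_entropy(logits, targets, weight=w)

    # crit = SwitchableWeightedCE(torch.ones(10))
    # crit(torch.randn(4, 10), torch.randint(0, 10, (4,)), to_weight=False)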
foleycrafter/models/specvqgan/modules/losses/vggishish/train_vggishish_gh.py ADDED
@@ -0,0 +1,218 @@
1
+ from loss import WeightedCrossEntropy
2
+ import random
3
+ import os
4
+ import sys
5
+ import json
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torchvision
10
+ from omegaconf import OmegaConf
11
+ from torch.utils.data.dataloader import DataLoader
12
+ from tqdm import tqdm
13
+
14
+ from dataset import GreatestHit, AMT_test
15
+ from transforms import Crop, StandardNormalizeAudio, ToTensor
16
+ from logger import LoggerWithTBoard
17
+ from metrics import metrics
18
+ from model import VGGishish
19
+
20
+
21
+ if __name__ == "__main__":
22
+ cfg_cli = sys.argv[1]
23
+ cfg_yml = OmegaConf.load(cfg_cli)
24
+ # only the YAML config is used here; there are no CLI overrides to merge
25
+ cfg = cfg_yml
26
+ OmegaConf.set_readonly(cfg, True)
27
+ print(OmegaConf.to_yaml(cfg))
28
+
29
+ logger = LoggerWithTBoard(cfg)
30
+
31
+ random.seed(cfg.seed)
32
+ np.random.seed(cfg.seed)
33
+ torch.manual_seed(cfg.seed)
34
+ torch.cuda.manual_seed_all(cfg.seed)
35
+ # makes iterations faster (in this case 30%) if your inputs are of a fixed size
36
+ # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
37
+ torch.backends.cudnn.benchmark = True
38
+
39
+ transforms = [
40
+ StandardNormalizeAudio(cfg.mels_path),
41
+ ]
42
+ if cfg.cropped_size not in [None, 'None', 'none']:
43
+ logger.print_logger.info(f'Using cropping {cfg.cropped_size}')
44
+ transforms.append(Crop(cfg.cropped_size))
45
+ transforms.append(ToTensor())
46
+ transforms = torchvision.transforms.transforms.Compose(transforms)
47
+
48
+ datasets = {
49
+ 'train': GreatestHit('train', cfg.mels_path, transforms, action_only=cfg.action_only, material_only=cfg.material_only),
50
+ 'valid': GreatestHit('valid', cfg.mels_path, transforms, action_only=cfg.action_only, material_only=cfg.material_only),
51
+ 'test': GreatestHit('test', cfg.mels_path, transforms, action_only=cfg.action_only, material_only=cfg.material_only),
52
+ }
53
+
54
+ loaders = {
55
+ 'train': DataLoader(datasets['train'], batch_size=cfg.batch_size, shuffle=True, drop_last=True,
56
+ num_workers=cfg.num_workers, pin_memory=True),
57
+ 'valid': DataLoader(datasets['valid'], batch_size=cfg.batch_size,
58
+ num_workers=cfg.num_workers, pin_memory=True),
59
+ 'test': DataLoader(datasets['test'], batch_size=cfg.batch_size,
60
+ num_workers=cfg.num_workers, pin_memory=True),
61
+ }
62
+
63
+ device = torch.device(cfg.device if torch.cuda.is_available() else 'cpu')
64
+
65
+ model = VGGishish(cfg.conv_layers, cfg.use_bn, num_classes=len(datasets['train'].label2target))
66
+ model = model.to(device)
67
+ if cfg.load_model is not None:
68
+ state_dict = torch.load(cfg.load_model, map_location=device)['model']
69
+ target_dict = {}
70
+ # ignore the last layer
71
+ for key, v in state_dict.items():
72
+ # ignore classifier
73
+ if 'classifier' not in key:
74
+ target_dict[key] = v
75
+ model.load_state_dict(target_dict, strict=False)
76
+ param_num = logger.log_param_num(model)
77
+
78
+ if cfg.optimizer == 'adam':
79
+ optimizer = torch.optim.Adam(
80
+ model.parameters(), lr=cfg.learning_rate, betas=cfg.betas, weight_decay=cfg.weight_decay)
81
+ elif cfg.optimizer == 'sgd':
82
+ optimizer = torch.optim.SGD(
83
+ model.parameters(), lr=cfg.learning_rate, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
84
+ else:
85
+ raise NotImplementedError
86
+
87
+ if cfg.cls_weights_in_loss:
88
+ weights = 1 / datasets['train'].class_counts
89
+ else:
90
+ weights = torch.ones(len(datasets['train'].label2target))
91
+ criterion = WeightedCrossEntropy(weights.to(device))
92
+
93
+ # loop over the train and validation multiple times (typical PT boilerplate)
94
+ no_change_epochs = 0
95
+ best_valid_loss = float('inf')
96
+ early_stop_triggered = False
97
+
98
+ for epoch in range(cfg.num_epochs):
99
+
100
+ for phase in ['train', 'valid']:
101
+ if phase == 'train':
102
+ model.train()
103
+ else:
104
+ model.eval()
105
+
106
+ running_loss = 0
107
+ preds_from_each_batch = []
108
+ targets_from_each_batch = []
109
+
110
+ prog_bar = tqdm(loaders[phase], f'{phase} ({epoch})', ncols=0)
111
+ for i, batch in enumerate(prog_bar):
112
+ inputs = batch['input'].to(device)
113
+ targets = batch['target'].to(device)
114
+
115
+ # zero the parameter gradients
116
+ optimizer.zero_grad()
117
+
118
+ # forward + backward + optimize
119
+ with torch.set_grad_enabled(phase == 'train'):
120
+ outputs = model(inputs)
121
+ loss = criterion(outputs, targets, to_weight=phase == 'train')
122
+
123
+ if phase == 'train':
124
+ loss.backward()
125
+ optimizer.step()
126
+
127
+ # loss
128
+ running_loss += loss.item()
129
+
130
+ # for metrics calculation later on
131
+ preds_from_each_batch += [outputs.detach().cpu()]
132
+ targets_from_each_batch += [targets.cpu()]
133
+
134
+ # iter logging
135
+ if i % 50 == 0:
136
+ logger.log_iter_loss(loss.item(), epoch*len(loaders[phase])+i, phase)
137
+ # tracks loss in the tqdm progress bar
138
+ prog_bar.set_postfix(loss=loss.item())
139
+
140
+ # logging loss
141
+ epoch_loss = running_loss / len(loaders[phase])
142
+ logger.log_epoch_loss(epoch_loss, epoch, phase)
143
+
144
+ # logging metrics
145
+ preds_from_each_batch = torch.cat(preds_from_each_batch)
146
+ targets_from_each_batch = torch.cat(targets_from_each_batch)
147
+ if cfg.action_only:
148
+ metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch, topk=(1,))
149
+ else:
150
+ metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch, topk=(1, 5))
151
+ logger.log_epoch_metrics(metrics_dict, epoch, phase)
152
+
153
+ # Early stopping
154
+ if phase == 'valid':
155
+ if epoch_loss < best_valid_loss:
156
+ no_change_epochs = 0
157
+ best_valid_loss = epoch_loss
158
+ logger.log_best_model(model, epoch_loss, epoch, optimizer, metrics_dict)
159
+ else:
160
+ no_change_epochs += 1
161
+ logger.print_logger.info(
162
+ f'Valid loss hasnt changed for {no_change_epochs} patience: {cfg.patience}'
163
+ )
164
+ if no_change_epochs >= cfg.patience:
165
+ early_stop_triggered = True
166
+
167
+ if early_stop_triggered:
168
+ logger.print_logger.info(f'Training is early stopped @ {epoch}')
169
+ break
170
+
171
+ logger.print_logger.info('Finished Training')
172
+
173
+ # loading the best model
174
+ ckpt = torch.load(logger.best_model_path)
175
+ model.load_state_dict(ckpt['model'])
176
+ logger.print_logger.info(f'Loading the best model from {logger.best_model_path}')
177
+ logger.print_logger.info((f'The model was trained for {ckpt["epoch"]} epochs. Loss: {ckpt["loss"]:.4f}'))
178
+
179
+ # Testing the model
180
+ model.eval()
181
+ running_loss = 0
182
+ preds_from_each_batch = []
183
+ targets_from_each_batch = []
184
+
185
+ for i, batch in enumerate(loaders['test']):
186
+ inputs = batch['input'].to(device)
187
+ targets = batch['target'].to(device)
188
+
189
+ # zero the parameter gradients
190
+ optimizer.zero_grad()
191
+
192
+ # forward + backward + optimize
193
+ with torch.set_grad_enabled(False):
194
+ outputs = model(inputs)
195
+ loss = criterion(outputs, targets, to_weight=False)
196
+
197
+ # loss
198
+ running_loss += loss.item()
199
+
200
+ # for metrics calculation later on
201
+ preds_from_each_batch += [outputs.detach().cpu()]
202
+ targets_from_each_batch += [targets.cpu()]
203
+
204
+ # logging metrics
205
+ preds_from_each_batch = torch.cat(preds_from_each_batch)
206
+ targets_from_each_batch = torch.cat(targets_from_each_batch)
207
+ if cfg.action_only:
208
+ test_metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch, topk=(1,))
209
+ else:
210
+ test_metrics_dict = metrics(targets_from_each_batch, preds_from_each_batch, topk=(1, 5))
211
+ test_metrics_dict['avg_loss'] = running_loss / len(loaders['test'])
212
+ test_metrics_dict['param_num'] = param_num
213
+ # TODO: I have no idea why tboard doesn't keep metrics (hparams) when
214
+ # I run this experiment from cli: `python train_vggishish.py config=./configs/vggish.yaml`
215
+ # while when I run it in vscode debugger the metrics are logger (wtf)
216
+ logger.log_test_metrics(test_metrics_dict, dict(cfg), ckpt['epoch'])
217
+
218
+ logger.print_logger.info('Finished the experiment')
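This fine-tuning script initializes from a pretrained VGGishish checkpoint but drops the classifier weights so that a head with a different class count can be trained. The toy sketch below shows the same partial, non-strict state-dict loading pattern on hypothetical module names, not the repository's model.

    import torch
    import torch.nn as nn

    # toy stand-in: any model whose final layer is named 'classifier'
    model = nn.Sequential()
    model.add_module('features', nn.Linear(8, 16))
    model.add_module('classifier', nn.Linear(16, 5))     # new head with a different class count

    pretrained = {'features.weight': torch.zeros(16, 8), 'features.bias': torch.zeros(16),
                  'classifier.weight': torch.zeros(3, 16), 'classifier.bias': torch.zeros(3)}

    # keep everything except the old classifier, then load non-strictly
    filtered = {k: v for k, v in pretrained.items() if 'classifier' not in k}
    missing, unexpected = model.load_state_dict(filtered, strict=False)
    print(missing)    # ['classifier.weight', 'classifier.bias'] stay randomly initialized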
foleycrafter/models/specvqgan/modules/losses/vggishish/transforms.py ADDED
@@ -0,0 +1,98 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import albumentations
6
+ import numpy as np
7
+ import torch
8
+ from tqdm import tqdm
9
+
10
+ logger = logging.getLogger(f'main.{__name__}')
11
+
12
+
13
+ class StandardNormalizeAudio(object):
14
+ '''
15
+ Frequency-wise normalization
16
+ '''
17
+ def __init__(self, specs_dir, train_ids_path='./data/vggsound_train.txt', cache_path='./data/'):
18
+ self.specs_dir = specs_dir
19
+ self.train_ids_path = train_ids_path
20
+ # making the stats filename to match the specs dir name
21
+ self.cache_path = os.path.join(cache_path, f'train_means_stds_{Path(specs_dir).stem}.txt')
22
+ logger.info('Assuming that the input stats are calculated using preprocessed spectrograms (log)')
23
+ self.train_stats = self.calculate_or_load_stats()
24
+
25
+ def __call__(self, item):
26
+ # just to generalize the input handling; useful for FID/IS evaluation and for training other stuff
27
+ if isinstance(item, dict):
28
+ if 'input' in item:
29
+ input_key = 'input'
30
+ elif 'image' in item:
31
+ input_key = 'image'
32
+ else:
33
+ raise NotImplementedError
34
+ item[input_key] = (item[input_key] - self.train_stats['means']) / self.train_stats['stds']
35
+ elif isinstance(item, torch.Tensor):
36
+ # broadcasts np.ndarray (80, 1) to (1, 80, 1) because item is torch.Tensor (B, 80, T)
37
+ item = (item - self.train_stats['means']) / self.train_stats['stds']
38
+ else:
39
+ raise NotImplementedError
40
+ return item
41
+
42
+ def calculate_or_load_stats(self):
43
+ try:
44
+ # (F, 2)
45
+ train_stats = np.loadtxt(self.cache_path)
46
+ means, stds = train_stats.T
47
+ logger.info('Trying to load train stats for Standard Normalization of inputs')
48
+ except OSError:
49
+ logger.info('Could not find the precalculated stats for Standard Normalization. Calculating...')
50
+ train_vid_ids = open(self.train_ids_path)
51
+ specs_paths = [os.path.join(self.specs_dir, f'{i.rstrip()}_mel.npy') for i in train_vid_ids]
52
+ means = [None] * len(specs_paths)
53
+ stds = [None] * len(specs_paths)
54
+ for i, path in enumerate(tqdm(specs_paths)):
55
+ spec = np.load(path)
56
+ means[i] = spec.mean(axis=1)
57
+ stds[i] = spec.std(axis=1)
58
+ # (F) <- (num_files, F)
59
+ means = np.array(means).mean(axis=0)
60
+ stds = np.array(stds).mean(axis=0)
61
+ # saving in two columns
62
+ np.savetxt(self.cache_path, np.vstack([means, stds]).T, fmt='%0.8f')
63
+ means = means.reshape(-1, 1)
64
+ stds = stds.reshape(-1, 1)
65
+ return {'means': means, 'stds': stds}
66
+
67
+ class ToTensor(object):
68
+
69
+ def __call__(self, item):
70
+ item['input'] = torch.from_numpy(item['input']).float()
71
+ if 'target' in item:
72
+ item['target'] = torch.tensor(item['target'])
73
+ return item
74
+
75
+ class Crop(object):
76
+
77
+ def __init__(self, cropped_shape=None, random_crop=False):
78
+ self.cropped_shape = cropped_shape
79
+ if cropped_shape is not None:
80
+ mel_num, spec_len = cropped_shape
81
+ if random_crop:
82
+ self.cropper = albumentations.RandomCrop
83
+ else:
84
+ self.cropper = albumentations.CenterCrop
85
+ self.preprocessor = albumentations.Compose([self.cropper(mel_num, spec_len)])
86
+ else:
87
+ self.preprocessor = lambda **kwargs: kwargs
88
+
89
+ def __call__(self, item):
90
+ item['input'] = self.preprocessor(image=item['input'])['image']
91
+ return item
92
+
93
+
94
+ if __name__ == '__main__':
95
+ cropper = Crop([80, 848])
96
+ item = {'input': torch.rand([80, 860])}
97
+ outputs = cropper(item)
98
+ print(outputs['input'].shape)
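StandardNormalizeAudio normalizes each mel bin with per-frequency means and standard deviations (computed over the training set and cached to a text file). The NumPy sketch below only illustrates the frequency-wise broadcasting, with statistics taken from a single random spectrogram for simplicity.

    import numpy as np

    spec = np.random.rand(80, 848).astype(np.float32)    # (mel_bins, time)
    means = spec.mean(axis=1, keepdims=True)             # (80, 1) per-frequency statistics
    stds = spec.std(axis=1, keepdims=True)
    normed = (spec - means) / stds                       # broadcasts along the time axis
    print(normed.mean(axis=1).round(4)[:3], normed.std(axis=1).round(4)[:3])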
foleycrafter/models/specvqgan/modules/losses/vqperceptual.py ADDED
@@ -0,0 +1,209 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import sys
5
+
6
+ sys.path.insert(0, '.') # nopep8
7
+ from foleycrafter.models.specvqgan.modules.discriminator.model import (NLayerDiscriminator, NLayerDiscriminator1dFeats,
8
+ NLayerDiscriminator1dSpecs,
9
+ weights_init)
10
+ from foleycrafter.models.specvqgan.modules.losses.lpaps import LPAPS
11
+
12
+
13
+ class DummyLoss(nn.Module):
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+
18
+ def adopt_weight(weight, global_step, threshold=0, value=0.):
19
+ if global_step < threshold:
20
+ weight = value
21
+ return weight
22
+
23
+
24
+ def hinge_d_loss(logits_real, logits_fake):
25
+ loss_real = torch.mean(F.relu(1. - logits_real))
26
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
27
+ d_loss = 0.5 * (loss_real + loss_fake)
28
+ return d_loss
29
+
30
+
31
+ def vanilla_d_loss(logits_real, logits_fake):
32
+ d_loss = 0.5 * (
33
+ torch.mean(torch.nn.functional.softplus(-logits_real)) +
34
+ torch.mean(torch.nn.functional.softplus(logits_fake)))
35
+ return d_loss
36
+
37
+
38
+ class VQLPAPSWithDiscriminator(nn.Module):
39
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
40
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
41
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
42
+ disc_ndf=64, disc_loss="hinge", min_adapt_weight=0.0, max_adapt_weight=1e4):
43
+ super().__init__()
44
+ assert disc_loss in ["hinge", "vanilla"]
45
+ self.codebook_weight = codebook_weight
46
+ self.pixel_weight = pixelloss_weight
47
+ self.perceptual_loss = LPAPS().eval()
48
+ self.perceptual_weight = perceptual_weight
49
+
50
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
51
+ n_layers=disc_num_layers,
52
+ use_actnorm=use_actnorm,
53
+ ndf=disc_ndf
54
+ ).apply(weights_init)
55
+ self.discriminator_iter_start = disc_start
56
+ if disc_loss == "hinge":
57
+ self.disc_loss = hinge_d_loss
58
+ elif disc_loss == "vanilla":
59
+ self.disc_loss = vanilla_d_loss
60
+ else:
61
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
62
+ print(f"VQLPAPSWithDiscriminator running with {disc_loss} loss.")
63
+ self.disc_factor = disc_factor
64
+ self.discriminator_weight = disc_weight
65
+ self.disc_conditional = disc_conditional
66
+ self.min_adapt_weight = min_adapt_weight
67
+ self.max_adapt_weight = max_adapt_weight
68
+
69
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
70
+ if last_layer is not None:
71
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
72
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
73
+ else:
74
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
75
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
76
+
77
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
78
+ d_weight = torch.clamp(d_weight, self.min_adapt_weight, self.max_adapt_weight).detach()
79
+ d_weight = d_weight * self.discriminator_weight
80
+ return d_weight
81
+
82
+ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
83
+ global_step, last_layer=None, cond=None, split="train"):
84
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
85
+ if self.perceptual_weight > 0:
86
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
87
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
88
+ else:
89
+ p_loss = torch.tensor([0.0])
90
+
91
+ nll_loss = rec_loss
92
+ # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
93
+ nll_loss = torch.mean(nll_loss)
94
+
95
+ # now the GAN part
96
+ if optimizer_idx == 0:
97
+ # generator update
98
+ if cond is None:
99
+ assert not self.disc_conditional
100
+ logits_fake = self.discriminator(reconstructions.contiguous())
101
+ else:
102
+ assert self.disc_conditional
103
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
104
+ g_loss = -torch.mean(logits_fake)
105
+
106
+ try:
107
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
108
+ except RuntimeError:
109
+ assert not self.training
110
+ d_weight = torch.tensor(0.0)
111
+
112
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
113
+ loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
114
+
115
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
116
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
117
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
118
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
119
+ "{}/p_loss".format(split): p_loss.detach().mean(),
120
+ "{}/d_weight".format(split): d_weight.detach(),
121
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
122
+ "{}/g_loss".format(split): g_loss.detach().mean(),
123
+ }
124
+ return loss, log
125
+
126
+ if optimizer_idx == 1:
127
+ # second pass for discriminator update
128
+ if cond is None:
129
+ logits_real = self.discriminator(inputs.contiguous().detach())
130
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
131
+ else:
132
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
133
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
134
+
135
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
136
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
137
+
138
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
139
+ "{}/logits_real".format(split): logits_real.detach().mean(),
140
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
141
+ }
142
+ return d_loss, log
143
+
144
+
145
+ class VQLPAPSWithDiscriminator1dFeats(VQLPAPSWithDiscriminator):
146
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
147
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
148
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
149
+ disc_ndf=64, disc_loss="hinge", min_adapt_weight=0.0, max_adapt_weight=1e4):
150
+ super().__init__(disc_start=disc_start, codebook_weight=codebook_weight,
151
+ pixelloss_weight=pixelloss_weight, disc_num_layers=disc_num_layers,
152
+ disc_in_channels=disc_in_channels, disc_factor=disc_factor, disc_weight=disc_weight,
153
+ perceptual_weight=perceptual_weight, use_actnorm=use_actnorm,
154
+ disc_conditional=disc_conditional, disc_ndf=disc_ndf, disc_loss=disc_loss,
155
+ min_adapt_weight=min_adapt_weight, max_adapt_weight=max_adapt_weight)
156
+
157
+ self.discriminator = NLayerDiscriminator1dFeats(input_nc=disc_in_channels, n_layers=disc_num_layers,
158
+ use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init)
159
+
160
+ class VQLPAPSWithDiscriminator1dSpecs(VQLPAPSWithDiscriminator):
161
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
162
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
163
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
164
+ disc_ndf=64, disc_loss="hinge", min_adapt_weight=0.0, max_adapt_weight=1e4):
165
+ super().__init__(disc_start=disc_start, codebook_weight=codebook_weight,
166
+ pixelloss_weight=pixelloss_weight, disc_num_layers=disc_num_layers,
167
+ disc_in_channels=disc_in_channels, disc_factor=disc_factor, disc_weight=disc_weight,
168
+ perceptual_weight=perceptual_weight, use_actnorm=use_actnorm,
169
+ disc_conditional=disc_conditional, disc_ndf=disc_ndf, disc_loss=disc_loss,
170
+ min_adapt_weight=min_adapt_weight, max_adapt_weight=max_adapt_weight)
171
+
172
+ self.discriminator = NLayerDiscriminator1dSpecs(input_nc=disc_in_channels, n_layers=disc_num_layers,
173
+ use_actnorm=use_actnorm, ndf=disc_ndf).apply(weights_init)
174
+
175
+
176
+ if __name__ == '__main__':
177
+ from foleycrafter.models.specvqgan.modules.diffusionmodules.model import Decoder, Decoder1d
178
+
179
+ optimizer_idx = 0
180
+ loss_config = {
181
+ 'disc_conditional': False,
182
+ 'disc_start': 30001,
183
+ 'disc_weight': 0.8,
184
+ 'codebook_weight': 1.0,
185
+ }
186
+ ddconfig = {
187
+ 'ch': 128,
188
+ 'num_res_blocks': 2,
189
+ 'dropout': 0.0,
190
+ 'z_channels': 256,
191
+ 'double_z': False,
192
+ }
193
+ qloss = torch.rand(1, requires_grad=True)
194
+
195
+ ## AUDIO
196
+ loss_config['disc_in_channels'] = 1
197
+ ddconfig['in_channels'] = 1
198
+ ddconfig['resolution'] = 848
199
+ ddconfig['attn_resolutions'] = [53]
200
+ ddconfig['out_ch'] = 1
201
+ ddconfig['ch_mult'] = [1, 1, 2, 2, 4]
202
+ decoder = Decoder(**ddconfig)
203
+ loss = VQLPAPSWithDiscriminator(**loss_config)
204
+ x = torch.rand(16, 1, 80, 848)
205
+ # subtract something that depends on decoder.conv_out so the reconstruction stays in the autograd graph
206
+ xrec = torch.rand(16, 1, 80, 848) - decoder.conv_out(torch.rand(16, 128, 80, 848)).mean()
207
+ aeloss, log_dict_ae = loss(qloss, x, xrec, optimizer_idx, global_step=0, last_layer=decoder.conv_out.weight)
208
+ print(aeloss)
209
+ print(log_dict_ae)
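calculate_adaptive_weight scales the GAN term so that its gradient magnitude at the last decoder layer matches that of the reconstruction loss. The standalone sketch below reproduces the same computation on a toy parameter, detached from the decoder above.

    import torch

    # toy "last layer": gradients of two scalar losses w.r.t. the same parameter
    w = torch.randn(4, requires_grad=True)
    nll_loss = (w ** 2).sum()
    g_loss = (3 * w).sum()

    nll_grads = torch.autograd.grad(nll_loss, w, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, w, retain_graph=True)[0]
    d_weight = (torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)).clamp(0.0, 1e4).detach()
    print(d_weight)   # scales the GAN term so its gradient roughly matches the reconstruction term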
foleycrafter/models/specvqgan/modules/misc/class_cond.py ADDED
@@ -0,0 +1,21 @@
1
+ import torch
2
+
3
+ class ClassOnlyStage(object):
4
+ def __init__(self):
5
+ pass
6
+
7
+ def eval(self):
8
+ return self
9
+
10
+ def encode(self, c):
11
+ """fake vqmodel interface because self.cond_stage_model should have something
12
+ similar to coord.py but even more `dummy`"""
13
+ # assert 0.0 <= c.min() and c.max() <= 1.0
14
+ info = None, None, c
15
+ return c, None, info
16
+
17
+ def decode(self, c):
18
+ return c
19
+
20
+ def get_input(self, batch, k):
21
+ return batch[k].unsqueeze(1).to(memory_format=torch.contiguous_format)
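ClassOnlyStage mimics the encode/decode interface of a VQ conditioning stage while simply passing class indices through. A brief usage sketch, assuming the module is importable from the repository root:

    import torch
    from foleycrafter.models.specvqgan.modules.misc.class_cond import ClassOnlyStage

    stage = ClassOnlyStage().eval()
    c = torch.tensor([3, 7, 1])                    # class indices for a batch of 3
    quant, diff, info = stage.encode(c)            # passthrough: quant is c, info[2] is c
    print(stage.decode(quant))                     # tensor([3, 7, 1])
    batch = {'label': torch.tensor([5, 2])}
    print(stage.get_input(batch, 'label').shape)   # torch.Size([2, 1])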