Files changed (2)
  1. app.py +189 -61
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,4 +1,10 @@
 import torch
+import numpy as np
+import gradio as gr
+from scipy import signal
+from diffusers.utils import logging
+
+logging.set_verbosity_error()
 from diffusers.loaders import AttnProcsLayers
 from transformers import CLIPTextModel, CLIPTokenizer
 from modules.beats.BEATs import BEATs, BEATsConfig
@@ -6,9 +12,21 @@ from modules.AudioToken.embedder import FGAEmbedder
 from diffusers import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers import StableDiffusionPipeline
-import numpy as np
-import gradio as gr
-from scipy import signal
+from diffusers import (
+    DDPMScheduler,
+    DDIMScheduler,
+    PNDMScheduler,
+    LMSDiscreteScheduler,
+    EulerDiscreteScheduler,
+    EulerAncestralDiscreteScheduler,
+    DPMSolverMultistepScheduler,
+    DPMSolverSinglestepScheduler,
+    DEISMultistepScheduler,
+    UniPCMultistepScheduler,
+    HeunDiscreteScheduler,
+    KDPM2AncestralDiscreteScheduler,
+    KDPM2DiscreteScheduler,
+)


 class AudioTokenWrapper(torch.nn.Module):
@@ -19,27 +37,62 @@ class AudioTokenWrapper(torch.nn.Module):
         lora,
         device,
     ):
-
         super().__init__()
+        self.repo_id = repo_id
         # Load scheduler and models
+        self.ddpm = DDPMScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
+        self.ddim = DDIMScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
+        self.pndm = PNDMScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
+        self.lms = LMSDiscreteScheduler.from_pretrained(
+            self.repo_id, subfolder="scheduler"
+        )
+        self.euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(
+            self.repo_id, subfolder="scheduler"
+        )
+        self.euler = EulerDiscreteScheduler.from_pretrained(
+            self.repo_id, subfolder="scheduler"
+        )
+        self.dpm = DPMSolverMultistepScheduler.from_pretrained(
+            self.repo_id, subfolder="scheduler"
+        )
+        self.dpms = DPMSolverSinglestepScheduler.from_pretrained(
+            self.repo_id, subfolder="scheduler"
+        )
+        self.deis = DEISMultistepScheduler.from_pretrained(
+            self.repo_id, subfolder="scheduler"
+        )
+        self.unipc = UniPCMultistepScheduler.from_pretrained(
+            self.repo_id, subfolder="scheduler"
+        )
+        self.heun = HeunDiscreteScheduler.from_pretrained(
+            self.repo_id, subfolder="scheduler"
+        )
+        self.kdpm2_anc = KDPM2AncestralDiscreteScheduler.from_pretrained(
+            self.repo_id, subfolder="scheduler"
+        )
+        self.kdpm2 = KDPM2DiscreteScheduler.from_pretrained(
+            self.repo_id, subfolder="scheduler"
+        )
+
         self.tokenizer = CLIPTokenizer.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", subfolder="tokenizer"
+            self.repo_id, subfolder="tokenizer"
         )
         self.text_encoder = CLIPTextModel.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", revision=None
+            self.repo_id, subfolder="text_encoder", revision=None
        )
         self.unet = UNet2DConditionModel.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", subfolder="unet", revision=None
+            self.repo_id, subfolder="unet", revision=None
         )
         self.vae = AutoencoderKL.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", subfolder="vae", revision=None
+            self.repo_id, subfolder="vae", revision=None
         )

         checkpoint = torch.load(
-            'models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')
-        cfg = BEATsConfig(checkpoint['cfg'])
+            "models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt"
+        )
+        cfg = BEATsConfig(checkpoint["cfg"])
         self.aud_encoder = BEATs(cfg)
-        self.aud_encoder.load_state_dict(checkpoint['model'])
+        self.aud_encoder.load_state_dict(checkpoint["model"])
         self.aud_encoder.predictor = None
         input_size = 768 * 3
         self.embedder = FGAEmbedder(input_size=input_size, output_size=768)
@@ -53,48 +106,88 @@ class AudioTokenWrapper(torch.nn.Module):
             # Set correct lora layers
             lora_attn_procs = {}
             for name in self.unet.attn_processors.keys():
-                cross_attention_dim = None if name.endswith(
-                    "attn1.processor") else self.unet.config.cross_attention_dim
+                cross_attention_dim = (
+                    None
+                    if name.endswith("attn1.processor")
+                    else self.unet.config.cross_attention_dim
+                )
                 if name.startswith("mid_block"):
                     hidden_size = self.unet.config.block_out_channels[-1]
                 elif name.startswith("up_blocks"):
                     block_id = int(name[len("up_blocks.")])
-                    hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
+                    hidden_size = list(reversed(self.unet.config.block_out_channels))[
+                        block_id
+                    ]
                 elif name.startswith("down_blocks"):
                     block_id = int(name[len("down_blocks.")])
                     hidden_size = self.unet.config.block_out_channels[block_id]

-                lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size,
-                                                          cross_attention_dim=cross_attention_dim)
+                lora_attn_procs[name] = LoRAAttnProcessor(
+                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+                )

             self.unet.set_attn_processor(lora_attn_procs)
             self.lora_layers = AttnProcsLayers(self.unet.attn_processors)
             self.lora_layers.eval()
-            lora_layers_learned_embeds = 'models/lora_layers_learned_embeds.bin'
-            self.lora_layers.load_state_dict(torch.load(lora_layers_learned_embeds, map_location=device))
+            lora_layers_learned_embeds = "models/lora_layers_learned_embeds.bin"
+            self.lora_layers.load_state_dict(
+                torch.load(lora_layers_learned_embeds, map_location=device)
+            )
             self.unet.load_attn_procs(lora_layers_learned_embeds)

         self.embedder.eval()
-        embedder_learned_embeds = 'models/embedder_learned_embeds.bin'
-        self.embedder.load_state_dict(torch.load(embedder_learned_embeds, map_location=device))
+        embedder_learned_embeds = "models/embedder_learned_embeds.bin"
+        self.embedder.load_state_dict(
+            torch.load(embedder_learned_embeds, map_location=device)
+        )

-        self.placeholder_token = '<*>'
+        self.placeholder_token = "<*>"
         num_added_tokens = self.tokenizer.add_tokens(self.placeholder_token)
         if num_added_tokens == 0:
             raise ValueError(
                 f"The tokenizer already contains the token {self.placeholder_token}. Please pass a different"
                 " `placeholder_token` that is not already in the tokenizer."
             )
-        self.placeholder_token_id = self.tokenizer.convert_tokens_to_ids(self.placeholder_token)
+        self.placeholder_token_id = self.tokenizer.convert_tokens_to_ids(
+            self.placeholder_token
+        )
         # Resize the token embeddings as we are adding new special tokens to the tokenizer
         self.text_encoder.resize_token_embeddings(len(self.tokenizer))


-def greet(audio):
+def greet(audio, steps=25, scheduler="ddpm"):
     sample_rate, audio = audio
-    audio = audio.astype(np.float32, order='C') / 32768.0
+    audio = audio.astype(np.float32, order="C") / 32768.0
     desired_sample_rate = 16000

+    match scheduler:
+        case "ddpm":
+            use_sched = model.ddpm
+        case "ddim":
+            use_sched = model.ddim
+        case "pndm":
+            use_sched = model.pndm
+        case "lms":
+            use_sched = model.lms
+        case "euler_anc":
+            use_sched = model.euler_anc
+        case "euler":
+            use_sched = model.euler
+        case "dpm":
+            use_sched = model.dpm
+        case "dpms":
+            use_sched = model.dpms
+        case "deis":
+            use_sched = model.deis
+        case "unipc":
+            use_sched = model.unipc
+        case "heun":
+            use_sched = model.heun
+        case "kdpm2_anc":
+            use_sched = model.kdpm2_anc
+        case "kdpm2":
+            use_sched = model.kdpm2
+
     if audio.ndim == 2:
         audio = audio.sum(axis=1) / 2

@@ -109,54 +202,89 @@ def greet(audio):
     audio = signal.resample(audio, new_length)

     weight_dtype = torch.float32
-    prompt = 'a photo of <*>'
+    prompt = "a photo of <*>"

-    audio_values = torch.unsqueeze(torch.tensor(audio), dim=0).to(device).to(dtype=weight_dtype)
+    audio_values = (
+        torch.unsqueeze(torch.tensor(audio), dim=0).to(device).to(dtype=weight_dtype)
+    )
     if audio_values.ndim == 1:
         audio_values = torch.unsqueeze(audio_values, dim=0)
-    aud_features = model.aud_encoder.extract_features(audio_values)[1]
-    audio_token = model.embedder(aud_features)

-    token_embeds = model.text_encoder.get_input_embeddings().weight.data
-    token_embeds[model.placeholder_token_id] = audio_token.clone()
+    # i dont know why but this seems mandatory for deterministic results
+    with torch.no_grad():
+        aud_features = model.aud_encoder.extract_features(audio_values)[1]
+        audio_token = model.embedder(aud_features)
+        token_embeds = model.text_encoder.get_input_embeddings().weight.data

+    token_embeds[model.placeholder_token_id] = audio_token.clone()
+    generator = torch.Generator(device=device)
+    generator.manual_seed(23229249375547)  # no reason this can't be input by the user!
     pipeline = StableDiffusionPipeline.from_pretrained(
-        "CompVis/stable-diffusion-v1-4",
+        pretrained_model_name_or_path=model.repo_id,
         tokenizer=model.tokenizer,
         text_encoder=model.text_encoder,
         vae=model.vae,
         unet=model.unet,
+        scheduler=use_sched,
+        safety_checker=None,
     ).to(device)
-    image = pipeline(prompt, num_inference_steps=40, guidance_scale=7.5).images[0]
+    pipeline.enable_attention_slicing()
+    if torch.cuda.is_available():
+        pipeline.enable_xformers_memory_efficient_attention()
+
+    # print(f"taking {steps} steps using the {scheduler} scheduler")
+    image = pipeline(
+        prompt, num_inference_steps=steps, guidance_scale=8.5, generator=generator
+    ).images[0]
     return image


-if __name__ == "__main__":
-
-    lora = False
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    model = AudioTokenWrapper(lora, device)
-
-    description = """<p>
-    This is a demo of <a href='https://pages.cs.huji.ac.il/adiyoss-lab/AudioToken' target='_blank'>AudioToken: Adaptation of Text-Conditioned Diffusion Models for Audio-to-Image Generation</a>.<br><br>
-    A novel method utilizing latent diffusion models trained for text-to-image-generation to generate images conditioned on audio recordings. Using a pre-trained audio encoding model, the proposed method encodes audio into a new token, which can be considered as an adaptation layer between the audio and text representations.<br><br>
-    For more information, please see the original <a href='https://arxiv.org/abs/2305.13050' target='_blank'>paper</a> and <a href='https://github.com/guyyariv/AudioToken' target='_blank'>repo</a>.
-    </p>"""
-
-    examples = [
-        # ["assets/train.wav"],
-        ["assets/dog barking.wav"],
-        ["assets/airplane taking off.wav"],
-        # ["assets/electric guitar.wav"],
-        # ["assets/female sings.wav"],
-    ]
-
-    demo = gr.Interface(
-        fn=greet,
-        inputs="audio",
-        outputs="image",
-        title='AudioToken',
-        description=description,
-        examples=examples
-    )
-    demo.launch()
+lora = False
+repo_id = "philz1337/reliberate"
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+model = AudioTokenWrapper(lora, device)
+model = model.to(device)
+description = """<p>
+This is a demo of <a href='https://pages.cs.huji.ac.il/adiyoss-lab/AudioToken' target='_blank'>AudioToken: Adaptation of Text-Conditioned Diffusion Models for Audio-to-Image Generation</a>.<br><br>
+A novel method utilizing latent diffusion models trained for text-to-image-generation to generate images conditioned on audio recordings. Using a pre-trained audio encoding model, the proposed method encodes audio into a new token, which can be considered as an adaptation layer between the audio and text representations.<br><br>
+For more information, please see the original <a href='https://arxiv.org/abs/2305.13050' target='_blank'>paper</a> and <a href='https://github.com/guyyariv/AudioToken' target='_blank'>repo</a>.
+</p>"""
+
+examples = [
+    # ["assets/train.wav"],
+    # ["assets/dog barking.wav"],
+    # ["assets/airplane taking off.wav"],
+    # ["assets/electric guitar.wav"],
+    # ["assets/female sings.wav"],
+]
+
+my_demo = gr.Interface(
+    fn=greet,
+    inputs=[
+        "audio",
+        gr.Slider(value=25, step=1, label="diffusion steps"),
+        gr.Dropdown(
+            choices=[
+                "ddim",
+                "ddpm",
+                "pndm",
+                "lms",
+                "euler_anc",
+                "euler",
+                "dpm",
+                "dpms",
+                "deis",
+                "unipc",
+                "heun",
+                "kdpm2_anc",
+                "kdpm2",
+            ],
+            value="unipc",
+        ),
+    ],
+    outputs="image",
+    title="AudioToken",
+    description=description,
+    examples=examples,
+)
+my_demo.launch()
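
Note: the new greet(audio, steps, scheduler) entry point can also be exercised without the Gradio UI. The lines below are a minimal sketch, not part of this PR; they assume greet and the module-level model from app.py are already available in the session (e.g. app.py executed with my_demo.launch() skipped), and the wav path is purely illustrative.

    from scipy.io import wavfile

    # Gradio's "audio" input passes greet() a (sample_rate, int16 ndarray) tuple,
    # so the same shape is used when calling the function directly.
    sample_rate, data = wavfile.read("assets/dog barking.wav")  # illustrative path
    image = greet((sample_rate, data), steps=30, scheduler="unipc")
    image.save("dog_barking.png")  # PIL image returned by the pipeline
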
requirements.txt CHANGED
@@ -10,3 +10,4 @@ pandas
 torchaudio
 datasets
 scipy
+xformers --pre
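
Note: the xformers requirement pairs with the enable_xformers_memory_efficient_attention() call added in app.py, which is only reached when CUDA is available. Below is a hedged sketch, not part of this PR, of a fallback in case the pre-release wheel cannot be installed on a given platform; enable_fast_attention is a hypothetical helper name and pipe stands in for the StableDiffusionPipeline built inside greet().

    import importlib.util

    import torch
    from diffusers import StableDiffusionPipeline


    def enable_fast_attention(pipe: StableDiffusionPipeline) -> None:
        # Use xformers only when the optional dependency is importable and a GPU
        # is present; otherwise fall back to diffusers' built-in attention slicing.
        if torch.cuda.is_available() and importlib.util.find_spec("xformers") is not None:
            pipe.enable_xformers_memory_efficient_attention()
        else:
            pipe.enable_attention_slicing()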