genevera committed
Commit acfa0fb
1 Parent(s): b561bb5

allow user to set steps, pick scheduler, and make "gradio app.py" work

Files changed (1)
  1. app.py +110 -40
app.py CHANGED
@@ -6,11 +6,25 @@ from modules.AudioToken.embedder import FGAEmbedder
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers import StableDiffusionPipeline
+from diffusers import (
+    DDPMScheduler,
+    DDIMScheduler,
+    PNDMScheduler,
+    LMSDiscreteScheduler,
+    EulerDiscreteScheduler,
+    EulerAncestralDiscreteScheduler,
+    DPMSolverMultistepScheduler,
+    DPMSolverSinglestepScheduler,
+    DEISMultistepScheduler,
+    UniPCMultistepScheduler,
+    HeunDiscreteScheduler,
+    KDPM2AncestralDiscreteScheduler,
+    KDPM2DiscreteScheduler,
+)
import numpy as np
import gradio as gr
from scipy import signal

-
class AudioTokenWrapper(torch.nn.Module):
    """Simple wrapper module for Stable Diffusion that holds all the models together"""

@@ -22,17 +36,33 @@ class AudioTokenWrapper(torch.nn.Module):

        super().__init__()
        # Load scheduler and models
+        self.ddpm = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
+        self.ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler")
+        self.pndm = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler")
+        self.lms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+        self.euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+        self.euler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+        self.dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")
+        self.dpms = DPMSolverSinglestepScheduler.from_pretrained(repo_id, subfolder="scheduler")
+        self.deis = DEISMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")
+        self.unipc = UniPCMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")
+        self.heun = HeunDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+        self.kdpm2_anc = KDPM2AncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+        self.kdpm2 = KDPM2DiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+
+
+
        self.tokenizer = CLIPTokenizer.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", subfolder="tokenizer"
+            repo_id, subfolder="tokenizer"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", revision=None
+            repo_id, subfolder="text_encoder", revision=None
        )
        self.unet = UNet2DConditionModel.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", subfolder="unet", revision=None
+            repo_id, subfolder="unet", revision=None
        )
        self.vae = AutoencoderKL.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", subfolder="vae", revision=None
+            repo_id, subfolder="vae", revision=None
        )

        checkpoint = torch.load(
@@ -90,11 +120,39 @@ class AudioTokenWrapper(torch.nn.Module):
        self.text_encoder.resize_token_embeddings(len(self.tokenizer))


-def greet(audio):
+def greet(audio, steps=25, scheduler="ddpm"):
    sample_rate, audio = audio
    audio = audio.astype(np.float32, order='C') / 32768.0
    desired_sample_rate = 16000

+    match scheduler:
+        case "ddpm":
+            use_sched = model.ddpm
+        case "ddim":
+            use_sched = model.ddim
+        case "pndm":
+            use_sched = model.pndm
+        case "lms":
+            use_sched = model.lms
+        case "euler_anc":
+            use_sched = model.euler_anc
+        case "euler":
+            use_sched = model.euler
+        case "dpm":
+            use_sched = model.dpm
+        case "dpms":
+            use_sched = model.dpms
+        case "deis":
+            use_sched = model.deis
+        case "unipc":
+            use_sched = model.unipc
+        case "heun":
+            use_sched = model.heun
+        case "kdpm2_anc":
+            use_sched = model.kdpm2_anc
+        case "kdpm2":
+            use_sched = model.kdpm2
+
    if audio.ndim == 2:
        audio = audio.sum(axis=1) / 2

@@ -114,49 +172,61 @@ def greet(audio):
    audio_values = torch.unsqueeze(torch.tensor(audio), dim=0).to(device).to(dtype=weight_dtype)
    if audio_values.ndim == 1:
        audio_values = torch.unsqueeze(audio_values, dim=0)
-    aud_features = model.aud_encoder.extract_features(audio_values)[1]
-    audio_token = model.embedder(aud_features)
+    with torch.no_grad():
+        torch.cuda.empty_cache()
+        aud_features = model.aud_encoder.extract_features(audio_values)[1]
+        audio_token = model.embedder(aud_features)

    token_embeds = model.text_encoder.get_input_embeddings().weight.data
    token_embeds[model.placeholder_token_id] = audio_token.clone()
-
+    g_gpu = torch.Generator(device='cuda')
+    g_gpu.manual_seed(23029249075547) # no reason this can't be input by the user!
    pipeline = StableDiffusionPipeline.from_pretrained(
-        "CompVis/stable-diffusion-v1-4",
+        "philz1337/reliberate",
        tokenizer=model.tokenizer,
        text_encoder=model.text_encoder,
        vae=model.vae,
        unet=model.unet,
+        scheduler=use_sched,
+        safety_checker=None,
    ).to(device)
-    image = pipeline(prompt, num_inference_steps=40, guidance_scale=7.5).images[0]
+    pipeline.enable_attention_slicing()
+    # pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+    # pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    print(f"taking {steps} steps using the {scheduler} scheduler")
+    image = pipeline(prompt, num_inference_steps=steps, guidance_scale=8.5, generator=g_gpu).images[0]
    return image


-if __name__ == "__main__":
-
-    lora = False
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    model = AudioTokenWrapper(lora, device)
-
-    description = """<p>
-    This is a demo of <a href='https://pages.cs.huji.ac.il/adiyoss-lab/AudioToken' target='_blank'>AudioToken: Adaptation of Text-Conditioned Diffusion Models for Audio-to-Image Generation</a>.<br><br>
-    A novel method utilizing latent diffusion models trained for text-to-image-generation to generate images conditioned on audio recordings. Using a pre-trained audio encoding model, the proposed method encodes audio into a new token, which can be considered as an adaptation layer between the audio and text representations.<br><br>
-    For more information, please see the original <a href='https://arxiv.org/abs/2305.13050' target='_blank'>paper</a> and <a href='https://github.com/guyyariv/AudioToken' target='_blank'>repo</a>.
-    </p>"""
-
-    examples = [
-        # ["assets/train.wav"],
-        ["assets/dog barking.wav"],
-        ["assets/airplane taking off.wav"],
-        # ["assets/electric guitar.wav"],
-        # ["assets/female sings.wav"],
-    ]
-
-    demo = gr.Interface(
-        fn=greet,
-        inputs="audio",
-        outputs="image",
-        title='AudioToken',
-        description=description,
-        examples=examples
-    )
-    demo.launch()
+lora = False
+repo_id = "philz1337/reliberate"
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+model = AudioTokenWrapper(lora, device)
+model = model.to(device)
+description = """<p>
+This is a demo of <a href='https://pages.cs.huji.ac.il/adiyoss-lab/AudioToken' target='_blank'>AudioToken: Adaptation of Text-Conditioned Diffusion Models for Audio-to-Image Generation</a>.<br><br>
+A novel method utilizing latent diffusion models trained for text-to-image-generation to generate images conditioned on audio recordings. Using a pre-trained audio encoding model, the proposed method encodes audio into a new token, which can be considered as an adaptation layer between the audio and text representations.<br><br>
+For more information, please see the original <a href='https://arxiv.org/abs/2305.13050' target='_blank'>paper</a> and <a href='https://github.com/guyyariv/AudioToken' target='_blank'>repo</a>.
+</p>"""
+
+examples = [
+    # ["assets/train.wav"],
+    # ["assets/dog barking.wav"],
+    # ["assets/airplane taking off.wav"],
+    # ["assets/electric guitar.wav"],
+    # ["assets/female sings.wav"],
+]
+
+my_demo = gr.Interface(
+    fn=greet,
+    inputs=[
+        "audio",
+        gr.Slider(value=25,step=1,label="diffusion steps"),
+        gr.Dropdown(choices=["ddim","ddpm","pndm","lms","euler_anc","euler","dpm","dpms","deis","unipc","heun","kdpm2_anc","kdpm2"],value="unipc"),
+    ],
+    outputs="image",
+    title='AudioToken',
+    description=description,
+    examples=examples
+)
+my_demo.launch()
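
The commit message says users can now set the step count and pick a scheduler. As a quick way to exercise the updated greet(audio, steps, scheduler) outside the Gradio UI, a minimal sketch follows. It is not part of the commit: scipy.io.wavfile is an assumed loading helper (app.py itself only imports scipy.signal), "assets/dog barking.wav" is the file referenced in the (commented-out) examples list, the output filename is arbitrary, and a CUDA device is assumed since greet() seeds a torch.Generator(device='cuda').

# Minimal sketch (not part of this commit): call greet() the way Gradio's "audio"
# input would, i.e. with a (sample_rate, int16 numpy array) tuple.
from scipy.io import wavfile  # assumption: loading helper, not imported by app.py itself

sample_rate, data = wavfile.read("assets/dog barking.wav")  # 16-bit PCM, matching the /32768.0 scaling in greet()
image = greet((sample_rate, data), steps=30, scheduler="unipc")  # any name from the gr.Dropdown choices works
image.save("audiotoken_test.png")  # the StableDiffusionPipeline returns a PIL image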