haoheliu sanchit-gandhi HF staff commited on
Commit
3c4ad5e
•
1 Parent(s): 5cda646

Swap to HF diffusers (#258)

Browse files

- Swap to HF diffusers (d03edfafdf371af210d3a8c062f61da2683d3e7c)
- camera ready (f1daa6039ead8b008c9f9a424f88cc6ce5d7ae1e)


Co-authored-by: Sanchit Gandhi <sanchit-gandhi@users.noreply.huggingface.co>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +1 -0
  2. app.py +196 -228
  3. audioldm/__init__.py +0 -3
  4. audioldm/audio/__init__.py +0 -0
  5. audioldm/audio/audio_processing.py +0 -100
  6. audioldm/audio/stft.py +0 -180
  7. audioldm/audio/tools.py +0 -33
  8. audioldm/clap/__init__.py +0 -0
  9. audioldm/clap/encoders.py +0 -170
  10. audioldm/clap/open_clip/__init__.py +0 -25
  11. audioldm/clap/open_clip/bert.py +0 -40
  12. audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz +0 -3
  13. audioldm/clap/open_clip/factory.py +0 -277
  14. audioldm/clap/open_clip/feature_fusion.py +0 -192
  15. audioldm/clap/open_clip/htsat.py +0 -1308
  16. audioldm/clap/open_clip/linear_probe.py +0 -66
  17. audioldm/clap/open_clip/loss.py +0 -398
  18. audioldm/clap/open_clip/model.py +0 -936
  19. audioldm/clap/open_clip/model_configs/HTSAT-base.json +0 -23
  20. audioldm/clap/open_clip/model_configs/HTSAT-large.json +0 -23
  21. audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json +0 -23
  22. audioldm/clap/open_clip/model_configs/HTSAT-tiny.json +0 -23
  23. audioldm/clap/open_clip/model_configs/PANN-10.json +0 -23
  24. audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json +0 -23
  25. audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json +0 -23
  26. audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json +0 -23
  27. audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json +0 -23
  28. audioldm/clap/open_clip/model_configs/PANN-14.json +0 -23
  29. audioldm/clap/open_clip/model_configs/PANN-6.json +0 -23
  30. audioldm/clap/open_clip/model_configs/RN101-quickgelu.json +0 -22
  31. audioldm/clap/open_clip/model_configs/RN101.json +0 -21
  32. audioldm/clap/open_clip/model_configs/RN50-quickgelu.json +0 -22
  33. audioldm/clap/open_clip/model_configs/RN50.json +0 -21
  34. audioldm/clap/open_clip/model_configs/RN50x16.json +0 -21
  35. audioldm/clap/open_clip/model_configs/RN50x4.json +0 -21
  36. audioldm/clap/open_clip/model_configs/ViT-B-16.json +0 -16
  37. audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json +0 -17
  38. audioldm/clap/open_clip/model_configs/ViT-B-32.json +0 -16
  39. audioldm/clap/open_clip/model_configs/ViT-L-14.json +0 -16
  40. audioldm/clap/open_clip/openai.py +0 -156
  41. audioldm/clap/open_clip/pann_model.py +0 -703
  42. audioldm/clap/open_clip/pretrained.py +0 -167
  43. audioldm/clap/open_clip/timm_model.py +0 -112
  44. audioldm/clap/open_clip/tokenizer.py +0 -197
  45. audioldm/clap/open_clip/transform.py +0 -45
  46. audioldm/clap/open_clip/utils.py +0 -361
  47. audioldm/clap/open_clip/version.py +0 -1
  48. audioldm/clap/training/__init__.py +0 -0
  49. audioldm/clap/training/audioset_textmap.npy +0 -3
  50. audioldm/clap/training/data.py +0 -977
README.md CHANGED
@@ -8,6 +8,7 @@ sdk_version: 3.27.0
8
  app_file: app.py
9
  pinned: false
10
  license: bigscience-openrail-m
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
8
  app_file: app.py
9
  pinned: false
10
  license: bigscience-openrail-m
11
+ duplicated_from: haoheliu/audioldm-text-to-audio-generation
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,197 +1,139 @@
1
  import gradio as gr
2
- import numpy as np
3
- from audioldm import text_to_audio, build_model
4
  from share_btn import community_icon_html, loading_icon_html, share_js
5
 
6
- model_id="haoheliu/AudioLDM-S-Full"
7
 
8
- audioldm = None
9
- current_model_name = None
10
 
11
- # def predict(input, history=[]):
12
- # # tokenize the new input sentence
13
- # new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
 
 
 
 
14
 
15
- # # append the new user input tokens to the chat history
16
- # bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
 
 
17
 
18
- # # generate a response
19
- # history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
 
20
 
21
- # # convert the tokens to text, and then split the responses into lines
22
- # response = tokenizer.decode(history[0]).split("<|endoftext|>")
23
- # response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list
24
- # return response, history
25
-
26
- def text2audio(text, duration, guidance_scale, random_seed, n_candidates, model_name="audioldm-m-text-ft"):
27
- global audioldm, current_model_name
28
-
29
- if audioldm is None or model_name != current_model_name:
30
- audioldm=build_model(model_name=model_name)
31
- current_model_name = model_name
32
-
33
- # print(text, length, guidance_scale)
34
- waveform = text_to_audio(
35
- latent_diffusion=audioldm,
36
- text=text,
37
- seed=random_seed,
38
- duration=duration,
39
  guidance_scale=guidance_scale,
40
- n_candidate_gen_per_text=int(n_candidates),
41
- ) # [bs, 1, samples]
42
- waveform = [
43
- gr.make_waveform((16000, wave[0]), bg_image="bg.png") for wave in waveform
44
- ]
45
- # waveform = [(16000, np.random.randn(16000)), (16000, np.random.randn(16000))]
46
- if(len(waveform) == 1):
47
- waveform = waveform[0]
48
- return waveform
 
 
 
49
 
50
- # iface = gr.Interface(fn=text2audio, inputs=[
51
- # gr.Textbox(value="A man is speaking in a huge room", max_lines=1),
52
- # gr.Slider(2.5, 10, value=5, step=2.5),
53
- # gr.Slider(0, 5, value=2.5, step=0.5),
54
- # gr.Number(value=42)
55
- # ], outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")],
56
- # allow_flagging="never"
57
- # )
58
- # iface.launch(share=True)
59
 
60
 
61
  css = """
62
  a {
63
- color: inherit;
64
- text-decoration: underline;
65
- }
66
- .gradio-container {
67
  font-family: 'IBM Plex Sans', sans-serif;
68
- }
69
- .gr-button {
70
- color: white;
71
- border-color: #000000;
72
- background: #000000;
73
- }
74
- input[type='range'] {
75
  accent-color: #000000;
76
- }
77
- .dark input[type='range'] {
78
  accent-color: #dfdfdf;
79
- }
80
- .container {
81
- max-width: 730px;
82
- margin: auto;
83
- padding-top: 1.5rem;
84
- }
85
- #gallery {
86
- min-height: 22rem;
87
- margin-bottom: 15px;
88
- margin-left: auto;
89
- margin-right: auto;
90
- border-bottom-right-radius: .5rem !important;
91
- border-bottom-left-radius: .5rem !important;
92
- }
93
- #gallery>div>.h-full {
94
  min-height: 20rem;
95
- }
96
- .details:hover {
97
  text-decoration: underline;
98
- }
99
- .gr-button {
100
  white-space: nowrap;
101
- }
102
- .gr-button:focus {
103
- border-color: rgb(147 197 253 / var(--tw-border-opacity));
104
- outline: none;
105
- box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
106
- --tw-border-opacity: 1;
107
- --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
108
- --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
109
- --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
110
- --tw-ring-opacity: .5;
111
- }
112
- #advanced-btn {
113
- font-size: .7rem !important;
114
- line-height: 19px;
115
- margin-top: 12px;
116
- margin-bottom: 12px;
117
- padding: 2px 8px;
118
  border-radius: 14px !important;
119
- }
120
- #advanced-options {
121
  margin-bottom: 20px;
122
- }
123
- .footer {
124
- margin-bottom: 45px;
125
- margin-top: 35px;
126
- text-align: center;
127
- border-bottom: 1px solid #e5e5e5;
128
- }
129
- .footer>p {
130
- font-size: .8rem;
131
- display: inline-block;
132
- padding: 0 10px;
133
- transform: translateY(10px);
134
- background: white;
135
- }
136
- .dark .footer {
137
  border-color: #303030;
138
- }
139
- .dark .footer>p {
140
  background: #0b0f19;
141
- }
142
- .acknowledgments h4{
143
- margin: 1.25em 0 .25em 0;
144
- font-weight: bold;
145
- font-size: 115%;
146
- }
147
- #container-advanced-btns{
148
- display: flex;
149
- flex-wrap: wrap;
150
- justify-content: space-between;
151
- align-items: center;
152
- }
153
- .animate-spin {
154
  animation: spin 1s linear infinite;
155
- }
156
- @keyframes spin {
157
  from {
158
  transform: rotate(0deg);
159
- }
160
- to {
161
  transform: rotate(360deg);
162
  }
163
- }
164
- #share-btn-container {
165
- display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
166
- margin-top: 10px;
167
- margin-left: auto;
168
- }
169
- #share-btn {
170
- all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
171
- }
172
- #share-btn * {
173
  all: unset;
174
- }
175
- #share-btn-container div:nth-child(-n+2){
176
- width: auto !important;
177
- min-height: 0px !important;
178
- }
179
- #share-btn-container .wrap {
180
  display: none !important;
181
- }
182
- .gr-form{
183
  flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
184
- }
185
- #prompt-container{
186
  gap: 0;
187
- }
188
- #generated_id{
189
  min-height: 700px
190
- }
191
- #setting_id{
192
- margin-bottom: 12px;
193
- text-align: center;
194
- font-weight: 900;
195
  }
196
  """
197
  iface = gr.Blocks(css=css)
@@ -202,56 +144,72 @@ with iface:
202
  <div style="text-align: center; max-width: 700px; margin: 0 auto;">
203
  <div
204
  style="
205
- display: inline-flex;
206
- align-items: center;
207
- gap: 0.8rem;
208
- font-size: 1.75rem;
209
  "
210
  >
211
  <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
212
  AudioLDM: Text-to-Audio Generation with Latent Diffusion Models
213
  </h1>
214
- </div>
215
- <p style="margin-bottom: 10px; font-size: 94%">
216
- <a href="https://arxiv.org/abs/2301.12503">[Paper]</a> <a href="https://audioldm.github.io/">[Project page]</a>
 
217
  </p>
218
  </div>
219
  """
220
  )
221
- gr.HTML("""
222
- <h1 style="font-weight: 900; margin-bottom: 7px;">
223
- AudioLDM: Text-to-Audio Generation with Latent Diffusion Models
224
- </h1>
225
- <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
226
- <br/>
227
- <a href="https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation?duplicate=true">
228
- <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
229
- <p/>
230
- """)
231
  with gr.Group():
232
  with gr.Box():
233
- ############# Input
234
- textbox = gr.Textbox(value="A hammer is hitting a wooden surface", max_lines=1, label="Input your text here. Your text is important for the audio quality. Please ensure it is descriptive by using more adjectives.", elem_id="prompt-in")
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  with gr.Accordion("Click to modify detailed configurations", open=False):
237
- seed = gr.Number(value=45, label="Change this value (any integer number) will lead to a different generation result.")
238
- duration = gr.Slider(2.5, 10, value=5, step=2.5, label="Duration (seconds)")
239
- guidance_scale = gr.Slider(0, 4, value=2.5, step=0.5, label="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)")
240
- n_candidates = gr.Slider(1, 3, value=3, step=1, label="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation")
241
- # model_name = gr.Dropdown(
242
- # ["audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full","audioldm-s-full-v2", "audioldm-s-full", "audioldm-l-full"], value="audioldm-m-full", label="Choose the model to use. audioldm-m-text-ft and audioldm-s-text-ft are recommanded. -s- means small, -m- means medium and -l- means large",
243
- # )
244
- ############# Output
245
- # outputs=gr.Audio(label="Output", type="numpy")
246
- outputs=gr.Video(label="Output", elem_id="output-video")
247
-
248
- # with gr.Group(elem_id="container-advanced-btns"):
249
- # # advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
250
- # with gr.Group(elem_id="share-btn-container"):
251
- # community_icon = gr.HTML(community_icon_html, visible=False)
252
- # loading_icon = gr.HTML(loading_icon_html, visible=False)
253
- # share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
254
- # outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")]
 
 
 
 
 
 
255
  btn = gr.Button("Submit").style(full_width=True)
256
 
257
  with gr.Group(elem_id="share-btn-container", visible=False):
@@ -259,51 +217,61 @@ with iface:
259
  loading_icon = gr.HTML(loading_icon_html)
260
  share_button = gr.Button("Share to community", elem_id="share-btn")
261
 
262
- # btn.click(text2audio, inputs=[
263
- # textbox, duration, guidance_scale, seed, n_candidates, model_name], outputs=[outputs])
264
- btn.click(text2audio, inputs=[
265
- textbox, duration, guidance_scale, seed, n_candidates], outputs=[outputs])
266
-
 
267
  share_button.click(None, [], [], _js=share_js)
268
- gr.HTML('''
 
269
  <div class="footer" style="text-align: center; max-width: 700px; margin: 0 auto;">
270
- <p>Follow the latest update of AudioLDM on our<a href="https://github.com/haoheliu/AudioLDM" style="text-decoration: underline;" target="_blank"> Github repo</a>
271
- </p>
272
- <br>
273
- <p>Model by <a href="https://twitter.com/LiuHaohe" style="text-decoration: underline;" target="_blank">Haohe Liu</a></p>
274
- <br>
275
  </div>
276
- ''')
277
- gr.Examples([
278
- ["A hammer is hitting a wooden surface", 5, 2.5, 45, 3, "audioldm-m-full"],
279
- ["Peaceful and calming ambient music with singing bowl and other instruments.", 5, 2.5, 45, 3, "audioldm-m-full"],
280
- ["A man is speaking in a small room.", 5, 2.5, 45, 3, "audioldm-m-full"],
281
- ["A female is speaking followed by footstep sound", 5, 2.5, 45, 3, "audioldm-m-full"],
282
- ["Wooden table tapping sound followed by water pouring sound.", 5, 2.5, 45, 3, "audioldm-m-full"],
283
- ],
 
 
284
  fn=text2audio,
285
- # inputs=[textbox, duration, guidance_scale, seed, n_candidates, model_name],
286
- inputs=[textbox, duration, guidance_scale, seed, n_candidates],
287
  outputs=[outputs],
288
  cache_examples=True,
289
  )
290
- gr.HTML('''
291
- <div class="acknowledgements">
292
- <p>Essential Tricks for Enhancing the Quality of Your Generated Audio</p>
293
- <p>1. Try to use more adjectives to describe your sound. For example: "A man is speaking clearly and slowly in a large room" is better than "A man is speaking". This can make sure AudioLDM understands what you want.</p>
294
- <p>2. Try to use different random seeds, which can affect the generation quality significantly sometimes.</p>
295
- <p>3. It's better to use general terms like 'man' or 'woman' instead of specific names for individuals or abstract objects that humans may not be familiar with, such as 'mummy'.</p>
296
- </div>
297
- ''')
 
 
 
 
298
  with gr.Accordion("Additional information", open=False):
299
  gr.HTML(
300
  """
301
  <div class="acknowledgments">
302
- <p> We build the model with data from <a href="http://research.google.com/audioset/">AudioSet</a>, <a href="https://freesound.org/">Freesound</a> and <a href="https://sound-effects.bbcrewind.co.uk/">BBC Sound Effect library</a>. We share this demo based on the <a href="https://assets.publishing.service.gov.uk/government/uploads/system/uploads/attachment_data/file/375954/Research.pdf">UK copyright exception</a> of data for academic research. </p>
 
 
 
 
 
303
  </div>
304
  """
305
  )
306
  # <p>This demo is strictly for research demo purpose only. For commercial use please <a href="haoheliu@gmail.com">contact us</a>.</p>
307
 
308
  iface.queue(max_size=10).launch(debug=True)
309
- # iface.launch(debug=True, share=True)
1
  import gradio as gr
2
+ import torch
3
+ from diffusers import AudioLDMPipeline
4
  from share_btn import community_icon_html, loading_icon_html, share_js
5
 
6
+ from transformers import AutoProcessor, ClapModel
7
 
 
 
8
 
9
+ # make Space compatible with CPU duplicates
10
+ if torch.cuda.is_available():
11
+ device = "cuda"
12
+ torch_dtype = torch.float16
13
+ else:
14
+ device = "cpu"
15
+ torch_dtype = torch.float32
16
 
17
+ # load the diffusers pipeline
18
+ repo_id = "cvssp/audioldm-m-full"
19
+ pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device)
20
+ pipe.unet = torch.compile(pipe.unet)
21
 
22
+ # CLAP model (only required for automatic scoring)
23
+ clap_model = ClapModel.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full").to(device)
24
+ processor = AutoProcessor.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full")
25
 
26
+ generator = torch.Generator(device)
27
+
28
+
29
+ def text2audio(text, negative_prompt, duration, guidance_scale, random_seed, n_candidates):
30
+ if text is None:
31
+ raise gr.Error("Please provide a text input.")
32
+
33
+ waveforms = pipe(
34
+ text,
35
+ audio_length_in_s=duration,
 
 
 
 
 
 
 
 
36
  guidance_scale=guidance_scale,
37
+ negative_prompt=negative_prompt,
38
+ num_waveforms_per_prompt=n_candidates if n_candidates else 1,
39
+ generator=generator.manual_seed(int(random_seed)),
40
+ )["audios"]
41
+
42
+ if waveforms.shape[0] > 1:
43
+ waveform = score_waveforms(text, waveforms)
44
+ else:
45
+ waveform = waveforms[0]
46
+
47
+ return gr.make_waveform((16000, waveform), bg_image="bg.png")
48
+
49
 
50
+ def score_waveforms(text, waveforms):
51
+ inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True)
52
+ inputs = {key: inputs[key].to(device) for key in inputs}
53
+ with torch.no_grad():
54
+ logits_per_text = clap_model(**inputs).logits_per_text # this is the audio-text similarity score
55
+ probs = logits_per_text.softmax(dim=-1) # we can take the softmax to get the label probabilities
56
+ most_probable = torch.argmax(probs) # and now select the most likely audio waveform
57
+ waveform = waveforms[most_probable]
58
+ return waveform
59
 
60
 
61
  css = """
62
  a {
63
+ color: inherit; text-decoration: underline;
64
+ } .gradio-container {
 
 
65
  font-family: 'IBM Plex Sans', sans-serif;
66
+ } .gr-button {
67
+ color: white; border-color: #000000; background: #000000;
68
+ } input[type='range'] {
 
 
 
 
69
  accent-color: #000000;
70
+ } .dark input[type='range'] {
 
71
  accent-color: #dfdfdf;
72
+ } .container {
73
+ max-width: 730px; margin: auto; padding-top: 1.5rem;
74
+ } #gallery {
75
+ min-height: 22rem; margin-bottom: 15px; margin-left: auto; margin-right: auto; border-bottom-right-radius:
76
+ .5rem !important; border-bottom-left-radius: .5rem !important;
77
+ } #gallery>div>.h-full {
 
 
 
 
 
 
 
 
 
78
  min-height: 20rem;
79
+ } .details:hover {
 
80
  text-decoration: underline;
81
+ } .gr-button {
 
82
  white-space: nowrap;
83
+ } .gr-button:focus {
84
+ border-color: rgb(147 197 253 / var(--tw-border-opacity)); outline: none; box-shadow:
85
+ var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); --tw-border-opacity: 1;
86
+ --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width)
87
+ var(--tw-ring-offset-color); --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px
88
+ var(--tw-ring-offset-width)) var(--tw-ring-color); --tw-ring-color: rgb(191 219 254 /
89
+ var(--tw-ring-opacity)); --tw-ring-opacity: .5;
90
+ } #advanced-btn {
91
+ font-size: .7rem !important; line-height: 19px; margin-top: 12px; margin-bottom: 12px; padding: 2px 8px;
 
 
 
 
 
 
 
 
92
  border-radius: 14px !important;
93
+ } #advanced-options {
 
94
  margin-bottom: 20px;
95
+ } .footer {
96
+ margin-bottom: 45px; margin-top: 35px; text-align: center; border-bottom: 1px solid #e5e5e5;
97
+ } .footer>p {
98
+ font-size: .8rem; display: inline-block; padding: 0 10px; transform: translateY(10px); background: white;
99
+ } .dark .footer {
 
 
 
 
 
 
 
 
 
 
100
  border-color: #303030;
101
+ } .dark .footer>p {
 
102
  background: #0b0f19;
103
+ } .acknowledgments h4{
104
+ margin: 1.25em 0 .25em 0; font-weight: bold; font-size: 115%;
105
+ } #container-advanced-btns{
106
+ display: flex; flex-wrap: wrap; justify-content: space-between; align-items: center;
107
+ } .animate-spin {
 
 
 
 
 
 
 
 
108
  animation: spin 1s linear infinite;
109
+ } @keyframes spin {
 
110
  from {
111
  transform: rotate(0deg);
112
+ } to {
 
113
  transform: rotate(360deg);
114
  }
115
+ } #share-btn-container {
116
+ display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color:
117
+ #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
118
+ margin-top: 10px; margin-left: auto;
119
+ } #share-btn {
120
+ all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif;
121
+ margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem
122
+ !important;right:0;
123
+ } #share-btn * {
 
124
  all: unset;
125
+ } #share-btn-container div:nth-child(-n+2){
126
+ width: auto !important; min-height: 0px !important;
127
+ } #share-btn-container .wrap {
 
 
 
128
  display: none !important;
129
+ } .gr-form{
 
130
  flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
131
+ } #prompt-container{
 
132
  gap: 0;
133
+ } #generated_id{
 
134
  min-height: 700px
135
+ } #setting_id{
136
+ margin-bottom: 12px; text-align: center; font-weight: 900;
 
 
 
137
  }
138
  """
139
  iface = gr.Blocks(css=css)
144
  <div style="text-align: center; max-width: 700px; margin: 0 auto;">
145
  <div
146
  style="
147
+ display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
 
 
 
148
  "
149
  >
150
  <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
151
  AudioLDM: Text-to-Audio Generation with Latent Diffusion Models
152
  </h1>
153
+ </div> <p style="margin-bottom: 10px; font-size: 94%">
154
+ <a href="https://arxiv.org/abs/2301.12503">[Paper]</a> <a href="https://audioldm.github.io/">[Project
155
+ page]</a> <a href="https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm">[🧨
156
+ Diffusers]</a>
157
  </p>
158
  </div>
159
  """
160
  )
161
+ gr.HTML(
162
+ """
163
+ <p>This is the demo for AudioLDM, powered by 🧨 Diffusers. Demo uses the checkpoint <a
164
+ href="https://huggingface.co/cvssp/audioldm-m-full"> audioldm-m-full </a>. For faster inference without waiting in
165
+ queue, you may duplicate the space and upgrade to a GPU in the settings. <br/> <a
166
+ href="https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation?duplicate=true"> <img
167
+ style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> <p/>
168
+ """
169
+ )
170
+
171
  with gr.Group():
172
  with gr.Box():
173
+ textbox = gr.Textbox(
174
+ value="A hammer is hitting a wooden surface",
175
+ max_lines=1,
176
+ label="Input text",
177
+ info="Your text is important for the audio quality. Please ensure it is descriptive by using more adjectives.",
178
+ elem_id="prompt-in",
179
+ )
180
+ negative_textbox = gr.Textbox(
181
+ value="low quality, average quality",
182
+ max_lines=1,
183
+ label="Negative prompt",
184
+ info="Enter a negative prompt not to guide the audio generation. Selecting appropriate negative prompts can improve the audio quality significantly.",
185
+ elem_id="prompt-in",
186
+ )
187
 
188
  with gr.Accordion("Click to modify detailed configurations", open=False):
189
+ seed = gr.Number(
190
+ value=45,
191
+ label="Seed",
192
+ info="Change this value (any integer number) will lead to a different generation result.",
193
+ )
194
+ duration = gr.Slider(2.5, 10, value=5, step=2.5, label="Duration (seconds)")
195
+ guidance_scale = gr.Slider(
196
+ 0,
197
+ 4,
198
+ value=2.5,
199
+ step=0.5,
200
+ label="Guidance scale",
201
+ info="Large => better quality and relevancy to text; Small => better diversity",
202
+ )
203
+ n_candidates = gr.Slider(
204
+ 1,
205
+ 3,
206
+ value=3,
207
+ step=1,
208
+ label="Number waveforms to generate",
209
+ info="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
210
+ )
211
+
212
+ outputs = gr.Video(label="Output", elem_id="output-video")
213
  btn = gr.Button("Submit").style(full_width=True)
214
 
215
  with gr.Group(elem_id="share-btn-container", visible=False):
217
  loading_icon = gr.HTML(loading_icon_html)
218
  share_button = gr.Button("Share to community", elem_id="share-btn")
219
 
220
+ btn.click(
221
+ text2audio,
222
+ inputs=[textbox, negative_textbox, duration, guidance_scale, seed, n_candidates],
223
+ outputs=[outputs],
224
+ )
225
+
226
  share_button.click(None, [], [], _js=share_js)
227
+ gr.HTML(
228
+ """
229
  <div class="footer" style="text-align: center; max-width: 700px; margin: 0 auto;">
230
+ <p>Follow the latest update of AudioLDM on our<a href="https://github.com/haoheliu/AudioLDM"
231
+ style="text-decoration: underline;" target="_blank"> Github repo</a> </p> <br> <p>Model by <a
232
+ href="https://twitter.com/LiuHaohe" style="text-decoration: underline;" target="_blank">Haohe
233
+ Liu</a>. Code and demo by 🤗 Hugging Face.</p> <br>
 
234
  </div>
235
+ """
236
+ )
237
+ gr.Examples(
238
+ [
239
+ ["A hammer is hitting a wooden surface", "low quality, average quality", 5, 2.5, 45, 3],
240
+ ["Peaceful and calming ambient music with singing bowl and other instruments.", "low quality, average quality", 5, 2.5, 45, 3],
241
+ ["A man is speaking in a small room.", "low quality, average quality", 5, 2.5, 45, 3],
242
+ ["A female is speaking followed by footstep sound", "low quality, average quality", 5, 2.5, 45, 3],
243
+ ["Wooden table tapping sound followed by water pouring sound.", "low quality, average quality", 5, 2.5, 45, 3],
244
+ ],
245
  fn=text2audio,
246
+ inputs=[textbox, negative_textbox, duration, guidance_scale, seed, n_candidates],
 
247
  outputs=[outputs],
248
  cache_examples=True,
249
  )
250
+ gr.HTML(
251
+ """
252
+ <div class="acknowledgements"> <p>Essential Tricks for Enhancing the Quality of Your Generated
253
+ Audio</p> <p>1. Try to use more adjectives to describe your sound. For example: "A man is speaking
254
+ clearly and slowly in a large room" is better than "A man is speaking". This can make sure AudioLDM
255
+ understands what you want.</p> <p>2. Try to use different random seeds, which can affect the generation
256
+ quality significantly sometimes.</p> <p>3. It's better to use general terms like 'man' or 'woman'
257
+ instead of specific names for individuals or abstract objects that humans may not be familiar with,
258
+ such as 'mummy'.</p> <p>4. Using a negative prompt to not guide the diffusion process can improve the
259
+ audio quality significantly. Try using negative prompts like 'low quality'.</p> </div>
260
+ """
261
+ )
262
  with gr.Accordion("Additional information", open=False):
263
  gr.HTML(
264
  """
265
  <div class="acknowledgments">
266
+ <p> We build the model with data from <a href="http://research.google.com/audioset/">AudioSet</a>,
267
+ <a href="https://freesound.org/">Freesound</a> and <a
268
+ href="https://sound-effects.bbcrewind.co.uk/">BBC Sound Effect library</a>. We share this demo
269
+ based on the <a
270
+ href="https://assets.publishing.service.gov.uk/government/uploads/system/uploads/attachment_data/file/375954/Research.pdf">UK
271
+ copyright exception</a> of data for academic research. </p>
272
  </div>
273
  """
274
  )
275
  # <p>This demo is strictly for research demo purpose only. For commercial use please <a href="haoheliu@gmail.com">contact us</a>.</p>
276
 
277
  iface.queue(max_size=10).launch(debug=True)
 
audioldm/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from .ldm import LatentDiffusion
2
- from .utils import seed_everything
3
- from .pipeline import *
 
 
 
audioldm/audio/__init__.py DELETED
File without changes
audioldm/audio/audio_processing.py DELETED
@@ -1,100 +0,0 @@
1
- import torch
2
- import numpy as np
3
- import librosa.util as librosa_util
4
- from scipy.signal import get_window
5
-
6
-
7
- def window_sumsquare(
8
- window,
9
- n_frames,
10
- hop_length,
11
- win_length,
12
- n_fft,
13
- dtype=np.float32,
14
- norm=None,
15
- ):
16
- """
17
- # from librosa 0.6
18
- Compute the sum-square envelope of a window function at a given hop length.
19
-
20
- This is used to estimate modulation effects induced by windowing
21
- observations in short-time fourier transforms.
22
-
23
- Parameters
24
- ----------
25
- window : string, tuple, number, callable, or list-like
26
- Window specification, as in `get_window`
27
-
28
- n_frames : int > 0
29
- The number of analysis frames
30
-
31
- hop_length : int > 0
32
- The number of samples to advance between frames
33
-
34
- win_length : [optional]
35
- The length of the window function. By default, this matches `n_fft`.
36
-
37
- n_fft : int > 0
38
- The length of each analysis frame.
39
-
40
- dtype : np.dtype
41
- The data type of the output
42
-
43
- Returns
44
- -------
45
- wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
46
- The sum-squared envelope of the window function
47
- """
48
- if win_length is None:
49
- win_length = n_fft
50
-
51
- n = n_fft + hop_length * (n_frames - 1)
52
- x = np.zeros(n, dtype=dtype)
53
-
54
- # Compute the squared window at the desired length
55
- win_sq = get_window(window, win_length, fftbins=True)
56
- win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
57
- win_sq = librosa_util.pad_center(win_sq, n_fft)
58
-
59
- # Fill the envelope
60
- for i in range(n_frames):
61
- sample = i * hop_length
62
- x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
63
- return x
64
-
65
-
66
- def griffin_lim(magnitudes, stft_fn, n_iters=30):
67
- """
68
- PARAMS
69
- ------
70
- magnitudes: spectrogram magnitudes
71
- stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
72
- """
73
-
74
- angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
75
- angles = angles.astype(np.float32)
76
- angles = torch.autograd.Variable(torch.from_numpy(angles))
77
- signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
78
-
79
- for i in range(n_iters):
80
- _, angles = stft_fn.transform(signal)
81
- signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
82
- return signal
83
-
84
-
85
- def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
86
- """
87
- PARAMS
88
- ------
89
- C: compression factor
90
- """
91
- return normalize_fun(torch.clamp(x, min=clip_val) * C)
92
-
93
-
94
- def dynamic_range_decompression(x, C=1):
95
- """
96
- PARAMS
97
- ------
98
- C: compression factor used to compress
99
- """
100
- return torch.exp(x) / C
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/audio/stft.py DELETED
@@ -1,180 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import numpy as np
4
- from scipy.signal import get_window
5
- from librosa.util import pad_center, tiny
6
- from librosa.filters import mel as librosa_mel_fn
7
-
8
- from audioldm.audio.audio_processing import (
9
- dynamic_range_compression,
10
- dynamic_range_decompression,
11
- window_sumsquare,
12
- )
13
-
14
-
15
- class STFT(torch.nn.Module):
16
- """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
17
-
18
- def __init__(self, filter_length, hop_length, win_length, window="hann"):
19
- super(STFT, self).__init__()
20
- self.filter_length = filter_length
21
- self.hop_length = hop_length
22
- self.win_length = win_length
23
- self.window = window
24
- self.forward_transform = None
25
- scale = self.filter_length / self.hop_length
26
- fourier_basis = np.fft.fft(np.eye(self.filter_length))
27
-
28
- cutoff = int((self.filter_length / 2 + 1))
29
- fourier_basis = np.vstack(
30
- [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
31
- )
32
-
33
- forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
34
- inverse_basis = torch.FloatTensor(
35
- np.linalg.pinv(scale * fourier_basis).T[:, None, :]
36
- )
37
-
38
- if window is not None:
39
- assert filter_length >= win_length
40
- # get window and zero center pad it to filter_length
41
- fft_window = get_window(window, win_length, fftbins=True)
42
- fft_window = pad_center(fft_window, filter_length)
43
- fft_window = torch.from_numpy(fft_window).float()
44
-
45
- # window the bases
46
- forward_basis *= fft_window
47
- inverse_basis *= fft_window
48
-
49
- self.register_buffer("forward_basis", forward_basis.float())
50
- self.register_buffer("inverse_basis", inverse_basis.float())
51
-
52
- def transform(self, input_data):
53
- num_batches = input_data.size(0)
54
- num_samples = input_data.size(1)
55
-
56
- self.num_samples = num_samples
57
-
58
- # similar to librosa, reflect-pad the input
59
- input_data = input_data.view(num_batches, 1, num_samples)
60
- input_data = F.pad(
61
- input_data.unsqueeze(1),
62
- (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
63
- mode="reflect",
64
- )
65
- input_data = input_data.squeeze(1)
66
-
67
- forward_transform = F.conv1d(
68
- input_data,
69
- torch.autograd.Variable(self.forward_basis, requires_grad=False),
70
- stride=self.hop_length,
71
- padding=0,
72
- ).cpu()
73
-
74
- cutoff = int((self.filter_length / 2) + 1)
75
- real_part = forward_transform[:, :cutoff, :]
76
- imag_part = forward_transform[:, cutoff:, :]
77
-
78
- magnitude = torch.sqrt(real_part**2 + imag_part**2)
79
- phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
80
-
81
- return magnitude, phase
82
-
83
- def inverse(self, magnitude, phase):
84
- recombine_magnitude_phase = torch.cat(
85
- [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
86
- )
87
-
88
- inverse_transform = F.conv_transpose1d(
89
- recombine_magnitude_phase,
90
- torch.autograd.Variable(self.inverse_basis, requires_grad=False),
91
- stride=self.hop_length,
92
- padding=0,
93
- )
94
-
95
- if self.window is not None:
96
- window_sum = window_sumsquare(
97
- self.window,
98
- magnitude.size(-1),
99
- hop_length=self.hop_length,
100
- win_length=self.win_length,
101
- n_fft=self.filter_length,
102
- dtype=np.float32,
103
- )
104
- # remove modulation effects
105
- approx_nonzero_indices = torch.from_numpy(
106
- np.where(window_sum > tiny(window_sum))[0]
107
- )
108
- window_sum = torch.autograd.Variable(
109
- torch.from_numpy(window_sum), requires_grad=False
110
- )
111
- window_sum = window_sum
112
- inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
113
- approx_nonzero_indices
114
- ]
115
-
116
- # scale by hop ratio
117
- inverse_transform *= float(self.filter_length) / self.hop_length
118
-
119
- inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
120
- inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
121
-
122
- return inverse_transform
123
-
124
- def forward(self, input_data):
125
- self.magnitude, self.phase = self.transform(input_data)
126
- reconstruction = self.inverse(self.magnitude, self.phase)
127
- return reconstruction
128
-
129
-
130
- class TacotronSTFT(torch.nn.Module):
131
- def __init__(
132
- self,
133
- filter_length,
134
- hop_length,
135
- win_length,
136
- n_mel_channels,
137
- sampling_rate,
138
- mel_fmin,
139
- mel_fmax,
140
- ):
141
- super(TacotronSTFT, self).__init__()
142
- self.n_mel_channels = n_mel_channels
143
- self.sampling_rate = sampling_rate
144
- self.stft_fn = STFT(filter_length, hop_length, win_length)
145
- mel_basis = librosa_mel_fn(
146
- sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
147
- )
148
- mel_basis = torch.from_numpy(mel_basis).float()
149
- self.register_buffer("mel_basis", mel_basis)
150
-
151
- def spectral_normalize(self, magnitudes, normalize_fun):
152
- output = dynamic_range_compression(magnitudes, normalize_fun)
153
- return output
154
-
155
- def spectral_de_normalize(self, magnitudes):
156
- output = dynamic_range_decompression(magnitudes)
157
- return output
158
-
159
- def mel_spectrogram(self, y, normalize_fun=torch.log):
160
- """Computes mel-spectrograms from a batch of waves
161
- PARAMS
162
- ------
163
- y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
164
-
165
- RETURNS
166
- -------
167
- mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
168
- """
169
- assert torch.min(y.data) >= -1, torch.min(y.data)
170
- assert torch.max(y.data) <= 1, torch.max(y.data)
171
-
172
- magnitudes, phases = self.stft_fn.transform(y)
173
- magnitudes = magnitudes.data
174
- mel_output = torch.matmul(self.mel_basis, magnitudes)
175
- mel_output = self.spectral_normalize(mel_output, normalize_fun)
176
- energy = torch.norm(magnitudes, dim=1)
177
-
178
- log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
179
-
180
- return mel_output, log_magnitudes, energy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/audio/tools.py DELETED
@@ -1,33 +0,0 @@
1
- import torch
2
- import numpy as np
3
-
4
-
5
- def get_mel_from_wav(audio, _stft):
6
- audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
7
- audio = torch.autograd.Variable(audio, requires_grad=False)
8
- melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
9
- melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
10
- log_magnitudes_stft = (
11
- torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32)
12
- )
13
- energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
14
- return melspec, log_magnitudes_stft, energy
15
-
16
-
17
- # def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
18
- # mel = torch.stack([mel])
19
- # mel_decompress = _stft.spectral_de_normalize(mel)
20
- # mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
21
- # spec_from_mel_scaling = 1000
22
- # spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
23
- # spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
24
- # spec_from_mel = spec_from_mel * spec_from_mel_scaling
25
-
26
- # audio = griffin_lim(
27
- # torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters
28
- # )
29
-
30
- # audio = audio.squeeze()
31
- # audio = audio.cpu().numpy()
32
- # audio_path = out_filename
33
- # write(audio_path, _stft.sampling_rate, audio)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/__init__.py DELETED
File without changes
audioldm/clap/encoders.py DELETED
@@ -1,170 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from audioldm.clap.open_clip import create_model
4
- from audioldm.clap.training.data import get_audio_features
5
- import torchaudio
6
- from transformers import RobertaTokenizer
7
- import torch.nn.functional as F
8
-
9
-
10
- class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
11
- def __init__(
12
- self,
13
- pretrained_path="",
14
- key="class",
15
- sampling_rate=16000,
16
- embed_mode="audio",
17
- amodel = "HTSAT-tiny",
18
- unconditional_prob=0.1,
19
- random_mute=False,
20
- max_random_mute_portion=0.5,
21
- training_mode=True,
22
- ):
23
- super().__init__()
24
-
25
- self.key = key
26
- self.device = "cpu"
27
- self.precision = "fp32"
28
- self.amodel = amodel
29
- self.tmodel = "roberta" # the best text encoder in our training
30
- self.enable_fusion = False # False if you do not want to use the fusion model
31
- self.fusion_type = "aff_2d"
32
- self.pretrained = pretrained_path
33
- self.embed_mode = embed_mode
34
- self.embed_mode_orig = embed_mode
35
- self.sampling_rate = sampling_rate
36
- self.unconditional_prob = unconditional_prob
37
- self.random_mute = random_mute
38
- self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
39
- self.max_random_mute_portion = max_random_mute_portion
40
- self.training_mode = training_mode
41
- self.model, self.model_cfg = create_model(
42
- self.amodel,
43
- self.tmodel,
44
- self.pretrained,
45
- precision=self.precision,
46
- device=self.device,
47
- enable_fusion=self.enable_fusion,
48
- fusion_type=self.fusion_type,
49
- )
50
- for p in self.model.parameters():
51
- p.requires_grad = False
52
-
53
- self.model.eval()
54
-
55
- def get_unconditional_condition(self, batchsize):
56
- self.unconditional_token = self.model.get_text_embedding(
57
- self.tokenizer(["", ""])
58
- )[0:1]
59
- return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)
60
-
61
- def batch_to_list(self, batch):
62
- ret = []
63
- for i in range(batch.size(0)):
64
- ret.append(batch[i])
65
- return ret
66
-
67
- def make_decision(self, probability):
68
- if float(torch.rand(1)) < probability:
69
- return True
70
- else:
71
- return False
72
-
73
- def random_uniform(self, start, end):
74
- val = torch.rand(1).item()
75
- return start + (end - start) * val
76
-
77
- def _random_mute(self, waveform):
78
- # waveform: [bs, t-steps]
79
- t_steps = waveform.size(-1)
80
- for i in range(waveform.size(0)):
81
- mute_size = int(
82
- self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
83
- )
84
- mute_start = int(self.random_uniform(0, t_steps - mute_size))
85
- waveform[i, mute_start : mute_start + mute_size] = 0
86
- return waveform
87
-
88
- def cos_similarity(self, waveform, text):
89
- # waveform: [bs, t_steps]
90
- with torch.no_grad():
91
- self.embed_mode = "audio"
92
- audio_emb = self(waveform.cuda())
93
- self.embed_mode = "text"
94
- text_emb = self(text)
95
- similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
96
- return similarity.squeeze()
97
-
98
- def forward(self, batch, key=None):
99
- # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0
100
- # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0
101
- if self.model.training == True and not self.training_mode:
102
- print(
103
- "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters."
104
- )
105
- self.model, self.model_cfg = create_model(
106
- self.amodel,
107
- self.tmodel,
108
- self.pretrained,
109
- precision=self.precision,
110
- device="cuda",
111
- enable_fusion=self.enable_fusion,
112
- fusion_type=self.fusion_type,
113
- )
114
- for p in self.model.parameters():
115
- p.requires_grad = False
116
- self.model.eval()
117
-
118
- # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
119
- if self.embed_mode == "audio":
120
- with torch.no_grad():
121
- audio_dict_list = []
122
- assert (
123
- self.sampling_rate == 16000
124
- ), "We only support 16000 sampling rate"
125
- if self.random_mute:
126
- batch = self._random_mute(batch)
127
- # batch: [bs, 1, t-samples]
128
- batch = torchaudio.functional.resample(
129
- batch, orig_freq=self.sampling_rate, new_freq=48000
130
- )
131
- for waveform in self.batch_to_list(batch):
132
- audio_dict = {}
133
- audio_dict = get_audio_features(
134
- audio_dict,
135
- waveform,
136
- 480000,
137
- data_truncating="fusion",
138
- data_filling="repeatpad",
139
- audio_cfg=self.model_cfg["audio_cfg"],
140
- )
141
- audio_dict_list.append(audio_dict)
142
- # [bs, 512]
143
- embed = self.model.get_audio_embedding(audio_dict_list)
144
- elif self.embed_mode == "text":
145
- with torch.no_grad():
146
- # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
147
- text_data = self.tokenizer(batch)
148
- embed = self.model.get_text_embedding(text_data)
149
-
150
- embed = embed.unsqueeze(1)
151
- self.unconditional_token = self.model.get_text_embedding(
152
- self.tokenizer(["", ""])
153
- )[0:1]
154
-
155
- for i in range(embed.size(0)):
156
- if self.make_decision(self.unconditional_prob):
157
- embed[i] = self.unconditional_token
158
-
159
- # [bs, 1, 512]
160
- return embed.detach()
161
-
162
- def tokenizer(self, text):
163
- result = self.tokenize(
164
- text,
165
- padding="max_length",
166
- truncation=True,
167
- max_length=512,
168
- return_tensors="pt",
169
- )
170
- return {k: v.squeeze(0) for k, v in result.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/__init__.py DELETED
@@ -1,25 +0,0 @@
1
- from .factory import (
2
- list_models,
3
- create_model,
4
- create_model_and_transforms,
5
- add_model_config,
6
- )
7
- from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
8
- from .model import (
9
- CLAP,
10
- CLAPTextCfg,
11
- CLAPVisionCfg,
12
- CLAPAudioCfp,
13
- convert_weights_to_fp16,
14
- trace_model,
15
- )
16
- from .openai import load_openai_model, list_openai_models
17
- from .pretrained import (
18
- list_pretrained,
19
- list_pretrained_tag_models,
20
- list_pretrained_model_tags,
21
- get_pretrained_url,
22
- download_pretrained,
23
- )
24
- from .tokenizer import SimpleTokenizer, tokenize
25
- from .transform import image_transform
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/bert.py DELETED
@@ -1,40 +0,0 @@
1
- from transformers import BertTokenizer, BertModel
2
-
3
- tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
4
- model = BertModel.from_pretrained("bert-base-uncased")
5
- text = "Replace me by any text you'd like."
6
-
7
-
8
- def bert_embeddings(text):
9
- # text = "Replace me by any text you'd like."
10
- encoded_input = tokenizer(text, return_tensors="pt")
11
- output = model(**encoded_input)
12
- return output
13
-
14
-
15
- from transformers import RobertaTokenizer, RobertaModel
16
-
17
- tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
18
- model = RobertaModel.from_pretrained("roberta-base")
19
- text = "Replace me by any text you'd like."
20
-
21
-
22
- def Roberta_embeddings(text):
23
- # text = "Replace me by any text you'd like."
24
- encoded_input = tokenizer(text, return_tensors="pt")
25
- output = model(**encoded_input)
26
- return output
27
-
28
-
29
- from transformers import BartTokenizer, BartModel
30
-
31
- tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
32
- model = BartModel.from_pretrained("facebook/bart-base")
33
- text = "Replace me by any text you'd like."
34
-
35
-
36
- def bart_embeddings(text):
37
- # text = "Replace me by any text you'd like."
38
- encoded_input = tokenizer(text, return_tensors="pt")
39
- output = model(**encoded_input)
40
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
- size 1356917
 
 
 
audioldm/clap/open_clip/factory.py DELETED
@@ -1,277 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- import pathlib
5
- import re
6
- from copy import deepcopy
7
- from pathlib import Path
8
-
9
- import torch
10
-
11
- from .model import CLAP, convert_weights_to_fp16
12
- from .openai import load_openai_model
13
- from .pretrained import get_pretrained_url, download_pretrained
14
- from .transform import image_transform
15
-
16
- _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
17
- _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
18
-
19
-
20
- def _natural_key(string_):
21
- return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
22
-
23
-
24
- def _rescan_model_configs():
25
- global _MODEL_CONFIGS
26
-
27
- config_ext = (".json",)
28
- config_files = []
29
- for config_path in _MODEL_CONFIG_PATHS:
30
- if config_path.is_file() and config_path.suffix in config_ext:
31
- config_files.append(config_path)
32
- elif config_path.is_dir():
33
- for ext in config_ext:
34
- config_files.extend(config_path.glob(f"*{ext}"))
35
-
36
- for cf in config_files:
37
- if os.path.basename(cf)[0] == ".":
38
- continue # Ignore hidden files
39
-
40
- with open(cf, "r") as f:
41
- model_cfg = json.load(f)
42
- if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
43
- _MODEL_CONFIGS[cf.stem] = model_cfg
44
-
45
- _MODEL_CONFIGS = {
46
- k: v
47
- for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
48
- }
49
-
50
-
51
- _rescan_model_configs() # initial populate of model config registry
52
-
53
-
54
- def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
55
- checkpoint = torch.load(checkpoint_path, map_location=map_location)
56
- if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
57
- state_dict = checkpoint["state_dict"]
58
- else:
59
- state_dict = checkpoint
60
- if skip_params:
61
- if next(iter(state_dict.items()))[0].startswith("module"):
62
- state_dict = {k[7:]: v for k, v in state_dict.items()}
63
- # for k in state_dict:
64
- # if k.startswith('transformer'):
65
- # v = state_dict.pop(k)
66
- # state_dict['text_branch.' + k[12:]] = v
67
- return state_dict
68
-
69
-
70
- def create_model(
71
- amodel_name: str,
72
- tmodel_name: str,
73
- pretrained: str = "",
74
- precision: str = "fp32",
75
- device: torch.device = torch.device("cpu"),
76
- jit: bool = False,
77
- force_quick_gelu: bool = False,
78
- openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
79
- skip_params=True,
80
- pretrained_audio: str = "",
81
- pretrained_text: str = "",
82
- enable_fusion: bool = False,
83
- fusion_type: str = "None"
84
- # pretrained_image: bool = False,
85
- ):
86
- amodel_name = amodel_name.replace(
87
- "/", "-"
88
- ) # for callers using old naming with / in ViT names
89
- pretrained_orig = pretrained
90
- pretrained = pretrained.lower()
91
- if pretrained == "openai":
92
- if amodel_name in _MODEL_CONFIGS:
93
- logging.info(f"Loading {amodel_name} model config.")
94
- model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
95
- else:
96
- logging.error(
97
- f"Model config for {amodel_name} not found; available models {list_models()}."
98
- )
99
- raise RuntimeError(f"Model config for {amodel_name} not found.")
100
-
101
- logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
102
- # Hard Code in model name
103
- model_cfg["text_cfg"]["model_type"] = tmodel_name
104
- model = load_openai_model(
105
- "ViT-B-16",
106
- model_cfg,
107
- device=device,
108
- jit=jit,
109
- cache_dir=openai_model_cache_dir,
110
- enable_fusion=enable_fusion,
111
- fusion_type=fusion_type,
112
- )
113
- # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
114
- if precision == "amp" or precision == "fp32":
115
- model = model.float()
116
- else:
117
- if amodel_name in _MODEL_CONFIGS:
118
- logging.info(f"Loading {amodel_name} model config.")
119
- model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
120
- else:
121
- logging.error(
122
- f"Model config for {amodel_name} not found; available models {list_models()}."
123
- )
124
- raise RuntimeError(f"Model config for {amodel_name} not found.")
125
-
126
- if force_quick_gelu:
127
- # override for use of QuickGELU on non-OpenAI transformer models
128
- model_cfg["quick_gelu"] = True
129
-
130
- # if pretrained_image:
131
- # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
132
- # # pretrained weight loading for timm models set via vision_cfg
133
- # model_cfg['vision_cfg']['timm_model_pretrained'] = True
134
- # else:
135
- # assert False, 'pretrained image towers currently only supported for timm models'
136
- model_cfg["text_cfg"]["model_type"] = tmodel_name
137
- model_cfg["enable_fusion"] = enable_fusion
138
- model_cfg["fusion_type"] = fusion_type
139
- model = CLAP(**model_cfg)
140
-
141
- if pretrained:
142
- checkpoint_path = ""
143
- url = get_pretrained_url(amodel_name, pretrained)
144
- if url:
145
- checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
146
- elif os.path.exists(pretrained_orig):
147
- checkpoint_path = pretrained_orig
148
- if checkpoint_path:
149
- logging.info(
150
- f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
151
- )
152
- ckpt = load_state_dict(checkpoint_path, skip_params=True)
153
- model.load_state_dict(ckpt)
154
- param_names = [n for n, p in model.named_parameters()]
155
- # for n in param_names:
156
- # print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
157
- else:
158
- logging.warning(
159
- f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
160
- )
161
- raise RuntimeError(
162
- f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
163
- )
164
-
165
- if pretrained_audio:
166
- if amodel_name.startswith("PANN"):
167
- if "Cnn14_mAP" in pretrained_audio: # official checkpoint
168
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
169
- audio_ckpt = audio_ckpt["model"]
170
- keys = list(audio_ckpt.keys())
171
- for key in keys:
172
- if (
173
- "spectrogram_extractor" not in key
174
- and "logmel_extractor" not in key
175
- ):
176
- v = audio_ckpt.pop(key)
177
- audio_ckpt["audio_branch." + key] = v
178
- elif os.path.basename(pretrained_audio).startswith(
179
- "PANN"
180
- ): # checkpoint trained via HTSAT codebase
181
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
182
- audio_ckpt = audio_ckpt["state_dict"]
183
- keys = list(audio_ckpt.keys())
184
- for key in keys:
185
- if key.startswith("sed_model"):
186
- v = audio_ckpt.pop(key)
187
- audio_ckpt["audio_branch." + key[10:]] = v
188
- elif os.path.basename(pretrained_audio).startswith(
189
- "finetuned"
190
- ): # checkpoint trained via linear probe codebase
191
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
192
- else:
193
- raise ValueError("Unknown audio checkpoint")
194
- elif amodel_name.startswith("HTSAT"):
195
- if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
196
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
197
- audio_ckpt = audio_ckpt["state_dict"]
198
- keys = list(audio_ckpt.keys())
199
- for key in keys:
200
- if key.startswith("sed_model") and (
201
- "spectrogram_extractor" not in key
202
- and "logmel_extractor" not in key
203
- ):
204
- v = audio_ckpt.pop(key)
205
- audio_ckpt["audio_branch." + key[10:]] = v
206
- elif os.path.basename(pretrained_audio).startswith(
207
- "HTSAT"
208
- ): # checkpoint trained via HTSAT codebase
209
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
210
- audio_ckpt = audio_ckpt["state_dict"]
211
- keys = list(audio_ckpt.keys())
212
- for key in keys:
213
- if key.startswith("sed_model"):
214
- v = audio_ckpt.pop(key)
215
- audio_ckpt["audio_branch." + key[10:]] = v
216
- elif os.path.basename(pretrained_audio).startswith(
217
- "finetuned"
218
- ): # checkpoint trained via linear probe codebase
219
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
220
- else:
221
- raise ValueError("Unknown audio checkpoint")
222
- else:
223
- raise f"this audio encoder pretrained checkpoint is not support"
224
-
225
- model.load_state_dict(audio_ckpt, strict=False)
226
- logging.info(
227
- f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
228
- )
229
- param_names = [n for n, p in model.named_parameters()]
230
- for n in param_names:
231
- print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
232
-
233
- model.to(device=device)
234
- if precision == "fp16":
235
- assert device.type != "cpu"
236
- convert_weights_to_fp16(model)
237
-
238
- if jit:
239
- model = torch.jit.script(model)
240
-
241
- return model, model_cfg
242
-
243
-
244
- def create_model_and_transforms(
245
- model_name: str,
246
- pretrained: str = "",
247
- precision: str = "fp32",
248
- device: torch.device = torch.device("cpu"),
249
- jit: bool = False,
250
- force_quick_gelu: bool = False,
251
- # pretrained_image: bool = False,
252
- ):
253
- model = create_model(
254
- model_name,
255
- pretrained,
256
- precision,
257
- device,
258
- jit,
259
- force_quick_gelu=force_quick_gelu,
260
- # pretrained_image=pretrained_image
261
- )
262
- preprocess_train = image_transform(model.visual.image_size, is_train=True)
263
- preprocess_val = image_transform(model.visual.image_size, is_train=False)
264
- return model, preprocess_train, preprocess_val
265
-
266
-
267
- def list_models():
268
- """enumerate available model architectures based on config files"""
269
- return list(_MODEL_CONFIGS.keys())
270
-
271
-
272
- def add_model_config(path):
273
- """add model config path or file and update registry"""
274
- if not isinstance(path, Path):
275
- path = Path(path)
276
- _MODEL_CONFIG_PATHS.append(path)
277
- _rescan_model_configs()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/feature_fusion.py DELETED
@@ -1,192 +0,0 @@
1
- """
2
- Feature Fusion for Varible-Length Data Processing
3
- AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4
- According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
5
- """
6
-
7
- import torch
8
- import torch.nn as nn
9
-
10
-
11
- class DAF(nn.Module):
12
- """
13
- 直接相加 DirectAddFuse
14
- """
15
-
16
- def __init__(self):
17
- super(DAF, self).__init__()
18
-
19
- def forward(self, x, residual):
20
- return x + residual
21
-
22
-
23
- class iAFF(nn.Module):
24
- """
25
- 多特征融合 iAFF
26
- """
27
-
28
- def __init__(self, channels=64, r=4, type="2D"):
29
- super(iAFF, self).__init__()
30
- inter_channels = int(channels // r)
31
-
32
- if type == "1D":
33
- # 本地注意力
34
- self.local_att = nn.Sequential(
35
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
36
- nn.BatchNorm1d(inter_channels),
37
- nn.ReLU(inplace=True),
38
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
39
- nn.BatchNorm1d(channels),
40
- )
41
-
42
- # 全局注意力
43
- self.global_att = nn.Sequential(
44
- nn.AdaptiveAvgPool1d(1),
45
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
46
- nn.BatchNorm1d(inter_channels),
47
- nn.ReLU(inplace=True),
48
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
49
- nn.BatchNorm1d(channels),
50
- )
51
-
52
- # 第二次本地注意力
53
- self.local_att2 = nn.Sequential(
54
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
55
- nn.BatchNorm1d(inter_channels),
56
- nn.ReLU(inplace=True),
57
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
58
- nn.BatchNorm1d(channels),
59
- )
60
- # 第二次全局注意力
61
- self.global_att2 = nn.Sequential(
62
- nn.AdaptiveAvgPool1d(1),
63
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
64
- nn.BatchNorm1d(inter_channels),
65
- nn.ReLU(inplace=True),
66
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
67
- nn.BatchNorm1d(channels),
68
- )
69
- elif type == "2D":
70
- # 本地注意力
71
- self.local_att = nn.Sequential(
72
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
73
- nn.BatchNorm2d(inter_channels),
74
- nn.ReLU(inplace=True),
75
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
76
- nn.BatchNorm2d(channels),
77
- )
78
-
79
- # 全局注意力
80
- self.global_att = nn.Sequential(
81
- nn.AdaptiveAvgPool2d(1),
82
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
83
- nn.BatchNorm2d(inter_channels),
84
- nn.ReLU(inplace=True),
85
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
86
- nn.BatchNorm2d(channels),
87
- )
88
-
89
- # 第二次本地注意力
90
- self.local_att2 = nn.Sequential(
91
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
92
- nn.BatchNorm2d(inter_channels),
93
- nn.ReLU(inplace=True),
94
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
95
- nn.BatchNorm2d(channels),
96
- )
97
- # 第二次全局注意力
98
- self.global_att2 = nn.Sequential(
99
- nn.AdaptiveAvgPool2d(1),
100
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
101
- nn.BatchNorm2d(inter_channels),
102
- nn.ReLU(inplace=True),
103
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
104
- nn.BatchNorm2d(channels),
105
- )
106
- else:
107
- raise f"the type is not supported"
108
-
109
- self.sigmoid = nn.Sigmoid()
110
-
111
- def forward(self, x, residual):
112
- flag = False
113
- xa = x + residual
114
- if xa.size(0) == 1:
115
- xa = torch.cat([xa, xa], dim=0)
116
- flag = True
117
- xl = self.local_att(xa)
118
- xg = self.global_att(xa)
119
- xlg = xl + xg
120
- wei = self.sigmoid(xlg)
121
- xi = x * wei + residual * (1 - wei)
122
-
123
- xl2 = self.local_att2(xi)
124
- xg2 = self.global_att(xi)
125
- xlg2 = xl2 + xg2
126
- wei2 = self.sigmoid(xlg2)
127
- xo = x * wei2 + residual * (1 - wei2)
128
- if flag:
129
- xo = xo[0].unsqueeze(0)
130
- return xo
131
-
132
-
133
- class AFF(nn.Module):
134
- """
135
- 多特征融合 AFF
136
- """
137
-
138
- def __init__(self, channels=64, r=4, type="2D"):
139
- super(AFF, self).__init__()
140
- inter_channels = int(channels // r)
141
-
142
- if type == "1D":
143
- self.local_att = nn.Sequential(
144
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
145
- nn.BatchNorm1d(inter_channels),
146
- nn.ReLU(inplace=True),
147
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
148
- nn.BatchNorm1d(channels),
149
- )
150
- self.global_att = nn.Sequential(
151
- nn.AdaptiveAvgPool1d(1),
152
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
153
- nn.BatchNorm1d(inter_channels),
154
- nn.ReLU(inplace=True),
155
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
156
- nn.BatchNorm1d(channels),
157
- )
158
- elif type == "2D":
159
- self.local_att = nn.Sequential(
160
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
161
- nn.BatchNorm2d(inter_channels),
162
- nn.ReLU(inplace=True),
163
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
164
- nn.BatchNorm2d(channels),
165
- )
166
- self.global_att = nn.Sequential(
167
- nn.AdaptiveAvgPool2d(1),
168
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
169
- nn.BatchNorm2d(inter_channels),
170
- nn.ReLU(inplace=True),
171
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
172
- nn.BatchNorm2d(channels),
173
- )
174
- else:
175
- raise f"the type is not supported."
176
-
177
- self.sigmoid = nn.Sigmoid()
178
-
179
- def forward(self, x, residual):
180
- flag = False
181
- xa = x + residual
182
- if xa.size(0) == 1:
183
- xa = torch.cat([xa, xa], dim=0)
184
- flag = True
185
- xl = self.local_att(xa)
186
- xg = self.global_att(xa)
187
- xlg = xl + xg
188
- wei = self.sigmoid(xlg)
189
- xo = 2 * x * wei + 2 * residual * (1 - wei)
190
- if flag:
191
- xo = xo[0].unsqueeze(0)
192
- return xo