sanchit-gandhi HF staff committed on
Commit
f1daa60
•
1 Parent(s): d03edfa

camera ready

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +1 -0
  2. audioldm/__init__.py +0 -3
  3. audioldm/audio/__init__.py +0 -0
  4. audioldm/audio/audio_processing.py +0 -100
  5. audioldm/audio/stft.py +0 -180
  6. audioldm/audio/tools.py +0 -33
  7. audioldm/clap/__init__.py +0 -0
  8. audioldm/clap/encoders.py +0 -170
  9. audioldm/clap/open_clip/__init__.py +0 -25
  10. audioldm/clap/open_clip/bert.py +0 -40
  11. audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz +0 -3
  12. audioldm/clap/open_clip/factory.py +0 -277
  13. audioldm/clap/open_clip/feature_fusion.py +0 -192
  14. audioldm/clap/open_clip/htsat.py +0 -1308
  15. audioldm/clap/open_clip/linear_probe.py +0 -66
  16. audioldm/clap/open_clip/loss.py +0 -398
  17. audioldm/clap/open_clip/model.py +0 -936
  18. audioldm/clap/open_clip/model_configs/HTSAT-base.json +0 -23
  19. audioldm/clap/open_clip/model_configs/HTSAT-large.json +0 -23
  20. audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json +0 -23
  21. audioldm/clap/open_clip/model_configs/HTSAT-tiny.json +0 -23
  22. audioldm/clap/open_clip/model_configs/PANN-10.json +0 -23
  23. audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json +0 -23
  24. audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json +0 -23
  25. audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json +0 -23
  26. audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json +0 -23
  27. audioldm/clap/open_clip/model_configs/PANN-14.json +0 -23
  28. audioldm/clap/open_clip/model_configs/PANN-6.json +0 -23
  29. audioldm/clap/open_clip/model_configs/RN101-quickgelu.json +0 -22
  30. audioldm/clap/open_clip/model_configs/RN101.json +0 -21
  31. audioldm/clap/open_clip/model_configs/RN50-quickgelu.json +0 -22
  32. audioldm/clap/open_clip/model_configs/RN50.json +0 -21
  33. audioldm/clap/open_clip/model_configs/RN50x16.json +0 -21
  34. audioldm/clap/open_clip/model_configs/RN50x4.json +0 -21
  35. audioldm/clap/open_clip/model_configs/ViT-B-16.json +0 -16
  36. audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json +0 -17
  37. audioldm/clap/open_clip/model_configs/ViT-B-32.json +0 -16
  38. audioldm/clap/open_clip/model_configs/ViT-L-14.json +0 -16
  39. audioldm/clap/open_clip/openai.py +0 -156
  40. audioldm/clap/open_clip/pann_model.py +0 -703
  41. audioldm/clap/open_clip/pretrained.py +0 -167
  42. audioldm/clap/open_clip/timm_model.py +0 -112
  43. audioldm/clap/open_clip/tokenizer.py +0 -197
  44. audioldm/clap/open_clip/transform.py +0 -45
  45. audioldm/clap/open_clip/utils.py +0 -361
  46. audioldm/clap/open_clip/version.py +0 -1
  47. audioldm/clap/training/__init__.py +0 -0
  48. audioldm/clap/training/audioset_textmap.npy +0 -3
  49. audioldm/clap/training/data.py +0 -977
  50. audioldm/clap/training/distributed.py +0 -150
README.md CHANGED
@@ -8,6 +8,7 @@ sdk_version: 3.27.0
8
  app_file: app.py
9
  pinned: false
10
  license: bigscience-openrail-m
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
8
  app_file: app.py
9
  pinned: false
10
  license: bigscience-openrail-m
11
+ duplicated_from: haoheliu/audioldm-text-to-audio-generation
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
audioldm/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from .ldm import LatentDiffusion
2
- from .utils import seed_everything
3
- from .pipeline import *
 
 
 
audioldm/audio/__init__.py DELETED
File without changes
audioldm/audio/audio_processing.py DELETED
@@ -1,100 +0,0 @@
1
- import torch
2
- import numpy as np
3
- import librosa.util as librosa_util
4
- from scipy.signal import get_window
5
-
6
-
7
- def window_sumsquare(
8
- window,
9
- n_frames,
10
- hop_length,
11
- win_length,
12
- n_fft,
13
- dtype=np.float32,
14
- norm=None,
15
- ):
16
- """
17
- # from librosa 0.6
18
- Compute the sum-square envelope of a window function at a given hop length.
19
-
20
- This is used to estimate modulation effects induced by windowing
21
- observations in short-time fourier transforms.
22
-
23
- Parameters
24
- ----------
25
- window : string, tuple, number, callable, or list-like
26
- Window specification, as in `get_window`
27
-
28
- n_frames : int > 0
29
- The number of analysis frames
30
-
31
- hop_length : int > 0
32
- The number of samples to advance between frames
33
-
34
- win_length : [optional]
35
- The length of the window function. By default, this matches `n_fft`.
36
-
37
- n_fft : int > 0
38
- The length of each analysis frame.
39
-
40
- dtype : np.dtype
41
- The data type of the output
42
-
43
- Returns
44
- -------
45
- wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
46
- The sum-squared envelope of the window function
47
- """
48
- if win_length is None:
49
- win_length = n_fft
50
-
51
- n = n_fft + hop_length * (n_frames - 1)
52
- x = np.zeros(n, dtype=dtype)
53
-
54
- # Compute the squared window at the desired length
55
- win_sq = get_window(window, win_length, fftbins=True)
56
- win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
57
- win_sq = librosa_util.pad_center(win_sq, n_fft)
58
-
59
- # Fill the envelope
60
- for i in range(n_frames):
61
- sample = i * hop_length
62
- x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
63
- return x
64
-
65
-
66
- def griffin_lim(magnitudes, stft_fn, n_iters=30):
67
- """
68
- PARAMS
69
- ------
70
- magnitudes: spectrogram magnitudes
71
- stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
72
- """
73
-
74
- angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
75
- angles = angles.astype(np.float32)
76
- angles = torch.autograd.Variable(torch.from_numpy(angles))
77
- signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
78
-
79
- for i in range(n_iters):
80
- _, angles = stft_fn.transform(signal)
81
- signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
82
- return signal
83
-
84
-
85
- def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
86
- """
87
- PARAMS
88
- ------
89
- C: compression factor
90
- """
91
- return normalize_fun(torch.clamp(x, min=clip_val) * C)
92
-
93
-
94
- def dynamic_range_decompression(x, C=1):
95
- """
96
- PARAMS
97
- ------
98
- C: compression factor used to compress
99
- """
100
- return torch.exp(x) / C
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/audio/stft.py DELETED
@@ -1,180 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import numpy as np
4
- from scipy.signal import get_window
5
- from librosa.util import pad_center, tiny
6
- from librosa.filters import mel as librosa_mel_fn
7
-
8
- from audioldm.audio.audio_processing import (
9
- dynamic_range_compression,
10
- dynamic_range_decompression,
11
- window_sumsquare,
12
- )
13
-
14
-
15
- class STFT(torch.nn.Module):
16
- """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
17
-
18
- def __init__(self, filter_length, hop_length, win_length, window="hann"):
19
- super(STFT, self).__init__()
20
- self.filter_length = filter_length
21
- self.hop_length = hop_length
22
- self.win_length = win_length
23
- self.window = window
24
- self.forward_transform = None
25
- scale = self.filter_length / self.hop_length
26
- fourier_basis = np.fft.fft(np.eye(self.filter_length))
27
-
28
- cutoff = int((self.filter_length / 2 + 1))
29
- fourier_basis = np.vstack(
30
- [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
31
- )
32
-
33
- forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
34
- inverse_basis = torch.FloatTensor(
35
- np.linalg.pinv(scale * fourier_basis).T[:, None, :]
36
- )
37
-
38
- if window is not None:
39
- assert filter_length >= win_length
40
- # get window and zero center pad it to filter_length
41
- fft_window = get_window(window, win_length, fftbins=True)
42
- fft_window = pad_center(fft_window, filter_length)
43
- fft_window = torch.from_numpy(fft_window).float()
44
-
45
- # window the bases
46
- forward_basis *= fft_window
47
- inverse_basis *= fft_window
48
-
49
- self.register_buffer("forward_basis", forward_basis.float())
50
- self.register_buffer("inverse_basis", inverse_basis.float())
51
-
52
- def transform(self, input_data):
53
- num_batches = input_data.size(0)
54
- num_samples = input_data.size(1)
55
-
56
- self.num_samples = num_samples
57
-
58
- # similar to librosa, reflect-pad the input
59
- input_data = input_data.view(num_batches, 1, num_samples)
60
- input_data = F.pad(
61
- input_data.unsqueeze(1),
62
- (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
63
- mode="reflect",
64
- )
65
- input_data = input_data.squeeze(1)
66
-
67
- forward_transform = F.conv1d(
68
- input_data,
69
- torch.autograd.Variable(self.forward_basis, requires_grad=False),
70
- stride=self.hop_length,
71
- padding=0,
72
- ).cpu()
73
-
74
- cutoff = int((self.filter_length / 2) + 1)
75
- real_part = forward_transform[:, :cutoff, :]
76
- imag_part = forward_transform[:, cutoff:, :]
77
-
78
- magnitude = torch.sqrt(real_part**2 + imag_part**2)
79
- phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
80
-
81
- return magnitude, phase
82
-
83
- def inverse(self, magnitude, phase):
84
- recombine_magnitude_phase = torch.cat(
85
- [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
86
- )
87
-
88
- inverse_transform = F.conv_transpose1d(
89
- recombine_magnitude_phase,
90
- torch.autograd.Variable(self.inverse_basis, requires_grad=False),
91
- stride=self.hop_length,
92
- padding=0,
93
- )
94
-
95
- if self.window is not None:
96
- window_sum = window_sumsquare(
97
- self.window,
98
- magnitude.size(-1),
99
- hop_length=self.hop_length,
100
- win_length=self.win_length,
101
- n_fft=self.filter_length,
102
- dtype=np.float32,
103
- )
104
- # remove modulation effects
105
- approx_nonzero_indices = torch.from_numpy(
106
- np.where(window_sum > tiny(window_sum))[0]
107
- )
108
- window_sum = torch.autograd.Variable(
109
- torch.from_numpy(window_sum), requires_grad=False
110
- )
111
- window_sum = window_sum
112
- inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
113
- approx_nonzero_indices
114
- ]
115
-
116
- # scale by hop ratio
117
- inverse_transform *= float(self.filter_length) / self.hop_length
118
-
119
- inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
120
- inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
121
-
122
- return inverse_transform
123
-
124
- def forward(self, input_data):
125
- self.magnitude, self.phase = self.transform(input_data)
126
- reconstruction = self.inverse(self.magnitude, self.phase)
127
- return reconstruction
128
-
129
-
130
- class TacotronSTFT(torch.nn.Module):
131
- def __init__(
132
- self,
133
- filter_length,
134
- hop_length,
135
- win_length,
136
- n_mel_channels,
137
- sampling_rate,
138
- mel_fmin,
139
- mel_fmax,
140
- ):
141
- super(TacotronSTFT, self).__init__()
142
- self.n_mel_channels = n_mel_channels
143
- self.sampling_rate = sampling_rate
144
- self.stft_fn = STFT(filter_length, hop_length, win_length)
145
- mel_basis = librosa_mel_fn(
146
- sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
147
- )
148
- mel_basis = torch.from_numpy(mel_basis).float()
149
- self.register_buffer("mel_basis", mel_basis)
150
-
151
- def spectral_normalize(self, magnitudes, normalize_fun):
152
- output = dynamic_range_compression(magnitudes, normalize_fun)
153
- return output
154
-
155
- def spectral_de_normalize(self, magnitudes):
156
- output = dynamic_range_decompression(magnitudes)
157
- return output
158
-
159
- def mel_spectrogram(self, y, normalize_fun=torch.log):
160
- """Computes mel-spectrograms from a batch of waves
161
- PARAMS
162
- ------
163
- y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
164
-
165
- RETURNS
166
- -------
167
- mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
168
- """
169
- assert torch.min(y.data) >= -1, torch.min(y.data)
170
- assert torch.max(y.data) <= 1, torch.max(y.data)
171
-
172
- magnitudes, phases = self.stft_fn.transform(y)
173
- magnitudes = magnitudes.data
174
- mel_output = torch.matmul(self.mel_basis, magnitudes)
175
- mel_output = self.spectral_normalize(mel_output, normalize_fun)
176
- energy = torch.norm(magnitudes, dim=1)
177
-
178
- log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
179
-
180
- return mel_output, log_magnitudes, energy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/audio/tools.py DELETED
@@ -1,33 +0,0 @@
1
- import torch
2
- import numpy as np
3
-
4
-
5
- def get_mel_from_wav(audio, _stft):
6
- audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
7
- audio = torch.autograd.Variable(audio, requires_grad=False)
8
- melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
9
- melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
10
- log_magnitudes_stft = (
11
- torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32)
12
- )
13
- energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
14
- return melspec, log_magnitudes_stft, energy
15
-
16
-
17
- # def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
18
- # mel = torch.stack([mel])
19
- # mel_decompress = _stft.spectral_de_normalize(mel)
20
- # mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
21
- # spec_from_mel_scaling = 1000
22
- # spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
23
- # spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
24
- # spec_from_mel = spec_from_mel * spec_from_mel_scaling
25
-
26
- # audio = griffin_lim(
27
- # torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters
28
- # )
29
-
30
- # audio = audio.squeeze()
31
- # audio = audio.cpu().numpy()
32
- # audio_path = out_filename
33
- # write(audio_path, _stft.sampling_rate, audio)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/__init__.py DELETED
File without changes
audioldm/clap/encoders.py DELETED
@@ -1,170 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from audioldm.clap.open_clip import create_model
4
- from audioldm.clap.training.data import get_audio_features
5
- import torchaudio
6
- from transformers import RobertaTokenizer
7
- import torch.nn.functional as F
8
-
9
-
10
- class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
11
- def __init__(
12
- self,
13
- pretrained_path="",
14
- key="class",
15
- sampling_rate=16000,
16
- embed_mode="audio",
17
- amodel = "HTSAT-tiny",
18
- unconditional_prob=0.1,
19
- random_mute=False,
20
- max_random_mute_portion=0.5,
21
- training_mode=True,
22
- ):
23
- super().__init__()
24
-
25
- self.key = key
26
- self.device = "cpu"
27
- self.precision = "fp32"
28
- self.amodel = amodel
29
- self.tmodel = "roberta" # the best text encoder in our training
30
- self.enable_fusion = False # False if you do not want to use the fusion model
31
- self.fusion_type = "aff_2d"
32
- self.pretrained = pretrained_path
33
- self.embed_mode = embed_mode
34
- self.embed_mode_orig = embed_mode
35
- self.sampling_rate = sampling_rate
36
- self.unconditional_prob = unconditional_prob
37
- self.random_mute = random_mute
38
- self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
39
- self.max_random_mute_portion = max_random_mute_portion
40
- self.training_mode = training_mode
41
- self.model, self.model_cfg = create_model(
42
- self.amodel,
43
- self.tmodel,
44
- self.pretrained,
45
- precision=self.precision,
46
- device=self.device,
47
- enable_fusion=self.enable_fusion,
48
- fusion_type=self.fusion_type,
49
- )
50
- for p in self.model.parameters():
51
- p.requires_grad = False
52
-
53
- self.model.eval()
54
-
55
- def get_unconditional_condition(self, batchsize):
56
- self.unconditional_token = self.model.get_text_embedding(
57
- self.tokenizer(["", ""])
58
- )[0:1]
59
- return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)
60
-
61
- def batch_to_list(self, batch):
62
- ret = []
63
- for i in range(batch.size(0)):
64
- ret.append(batch[i])
65
- return ret
66
-
67
- def make_decision(self, probability):
68
- if float(torch.rand(1)) < probability:
69
- return True
70
- else:
71
- return False
72
-
73
- def random_uniform(self, start, end):
74
- val = torch.rand(1).item()
75
- return start + (end - start) * val
76
-
77
- def _random_mute(self, waveform):
78
- # waveform: [bs, t-steps]
79
- t_steps = waveform.size(-1)
80
- for i in range(waveform.size(0)):
81
- mute_size = int(
82
- self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
83
- )
84
- mute_start = int(self.random_uniform(0, t_steps - mute_size))
85
- waveform[i, mute_start : mute_start + mute_size] = 0
86
- return waveform
87
-
88
- def cos_similarity(self, waveform, text):
89
- # waveform: [bs, t_steps]
90
- with torch.no_grad():
91
- self.embed_mode = "audio"
92
- audio_emb = self(waveform.cuda())
93
- self.embed_mode = "text"
94
- text_emb = self(text)
95
- similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
96
- return similarity.squeeze()
97
-
98
- def forward(self, batch, key=None):
99
- # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0
100
- # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0
101
- if self.model.training == True and not self.training_mode:
102
- print(
103
- "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters."
104
- )
105
- self.model, self.model_cfg = create_model(
106
- self.amodel,
107
- self.tmodel,
108
- self.pretrained,
109
- precision=self.precision,
110
- device="cuda",
111
- enable_fusion=self.enable_fusion,
112
- fusion_type=self.fusion_type,
113
- )
114
- for p in self.model.parameters():
115
- p.requires_grad = False
116
- self.model.eval()
117
-
118
- # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
119
- if self.embed_mode == "audio":
120
- with torch.no_grad():
121
- audio_dict_list = []
122
- assert (
123
- self.sampling_rate == 16000
124
- ), "We only support 16000 sampling rate"
125
- if self.random_mute:
126
- batch = self._random_mute(batch)
127
- # batch: [bs, 1, t-samples]
128
- batch = torchaudio.functional.resample(
129
- batch, orig_freq=self.sampling_rate, new_freq=48000
130
- )
131
- for waveform in self.batch_to_list(batch):
132
- audio_dict = {}
133
- audio_dict = get_audio_features(
134
- audio_dict,
135
- waveform,
136
- 480000,
137
- data_truncating="fusion",
138
- data_filling="repeatpad",
139
- audio_cfg=self.model_cfg["audio_cfg"],
140
- )
141
- audio_dict_list.append(audio_dict)
142
- # [bs, 512]
143
- embed = self.model.get_audio_embedding(audio_dict_list)
144
- elif self.embed_mode == "text":
145
- with torch.no_grad():
146
- # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
147
- text_data = self.tokenizer(batch)
148
- embed = self.model.get_text_embedding(text_data)
149
-
150
- embed = embed.unsqueeze(1)
151
- self.unconditional_token = self.model.get_text_embedding(
152
- self.tokenizer(["", ""])
153
- )[0:1]
154
-
155
- for i in range(embed.size(0)):
156
- if self.make_decision(self.unconditional_prob):
157
- embed[i] = self.unconditional_token
158
-
159
- # [bs, 1, 512]
160
- return embed.detach()
161
-
162
- def tokenizer(self, text):
163
- result = self.tokenize(
164
- text,
165
- padding="max_length",
166
- truncation=True,
167
- max_length=512,
168
- return_tensors="pt",
169
- )
170
- return {k: v.squeeze(0) for k, v in result.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/__init__.py DELETED
@@ -1,25 +0,0 @@
1
- from .factory import (
2
- list_models,
3
- create_model,
4
- create_model_and_transforms,
5
- add_model_config,
6
- )
7
- from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
8
- from .model import (
9
- CLAP,
10
- CLAPTextCfg,
11
- CLAPVisionCfg,
12
- CLAPAudioCfp,
13
- convert_weights_to_fp16,
14
- trace_model,
15
- )
16
- from .openai import load_openai_model, list_openai_models
17
- from .pretrained import (
18
- list_pretrained,
19
- list_pretrained_tag_models,
20
- list_pretrained_model_tags,
21
- get_pretrained_url,
22
- download_pretrained,
23
- )
24
- from .tokenizer import SimpleTokenizer, tokenize
25
- from .transform import image_transform
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/bert.py DELETED
@@ -1,40 +0,0 @@
1
- from transformers import BertTokenizer, BertModel
2
-
3
- tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
4
- model = BertModel.from_pretrained("bert-base-uncased")
5
- text = "Replace me by any text you'd like."
6
-
7
-
8
- def bert_embeddings(text):
9
- # text = "Replace me by any text you'd like."
10
- encoded_input = tokenizer(text, return_tensors="pt")
11
- output = model(**encoded_input)
12
- return output
13
-
14
-
15
- from transformers import RobertaTokenizer, RobertaModel
16
-
17
- tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
18
- model = RobertaModel.from_pretrained("roberta-base")
19
- text = "Replace me by any text you'd like."
20
-
21
-
22
- def Roberta_embeddings(text):
23
- # text = "Replace me by any text you'd like."
24
- encoded_input = tokenizer(text, return_tensors="pt")
25
- output = model(**encoded_input)
26
- return output
27
-
28
-
29
- from transformers import BartTokenizer, BartModel
30
-
31
- tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
32
- model = BartModel.from_pretrained("facebook/bart-base")
33
- text = "Replace me by any text you'd like."
34
-
35
-
36
- def bart_embeddings(text):
37
- # text = "Replace me by any text you'd like."
38
- encoded_input = tokenizer(text, return_tensors="pt")
39
- output = model(**encoded_input)
40
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
- size 1356917
 
 
 
audioldm/clap/open_clip/factory.py DELETED
@@ -1,277 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- import pathlib
5
- import re
6
- from copy import deepcopy
7
- from pathlib import Path
8
-
9
- import torch
10
-
11
- from .model import CLAP, convert_weights_to_fp16
12
- from .openai import load_openai_model
13
- from .pretrained import get_pretrained_url, download_pretrained
14
- from .transform import image_transform
15
-
16
- _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
17
- _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
18
-
19
-
20
- def _natural_key(string_):
21
- return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
22
-
23
-
24
- def _rescan_model_configs():
25
- global _MODEL_CONFIGS
26
-
27
- config_ext = (".json",)
28
- config_files = []
29
- for config_path in _MODEL_CONFIG_PATHS:
30
- if config_path.is_file() and config_path.suffix in config_ext:
31
- config_files.append(config_path)
32
- elif config_path.is_dir():
33
- for ext in config_ext:
34
- config_files.extend(config_path.glob(f"*{ext}"))
35
-
36
- for cf in config_files:
37
- if os.path.basename(cf)[0] == ".":
38
- continue # Ignore hidden files
39
-
40
- with open(cf, "r") as f:
41
- model_cfg = json.load(f)
42
- if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
43
- _MODEL_CONFIGS[cf.stem] = model_cfg
44
-
45
- _MODEL_CONFIGS = {
46
- k: v
47
- for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
48
- }
49
-
50
-
51
- _rescan_model_configs() # initial populate of model config registry
52
-
53
-
54
- def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
55
- checkpoint = torch.load(checkpoint_path, map_location=map_location)
56
- if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
57
- state_dict = checkpoint["state_dict"]
58
- else:
59
- state_dict = checkpoint
60
- if skip_params:
61
- if next(iter(state_dict.items()))[0].startswith("module"):
62
- state_dict = {k[7:]: v for k, v in state_dict.items()}
63
- # for k in state_dict:
64
- # if k.startswith('transformer'):
65
- # v = state_dict.pop(k)
66
- # state_dict['text_branch.' + k[12:]] = v
67
- return state_dict
68
-
69
-
70
- def create_model(
71
- amodel_name: str,
72
- tmodel_name: str,
73
- pretrained: str = "",
74
- precision: str = "fp32",
75
- device: torch.device = torch.device("cpu"),
76
- jit: bool = False,
77
- force_quick_gelu: bool = False,
78
- openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
79
- skip_params=True,
80
- pretrained_audio: str = "",
81
- pretrained_text: str = "",
82
- enable_fusion: bool = False,
83
- fusion_type: str = "None"
84
- # pretrained_image: bool = False,
85
- ):
86
- amodel_name = amodel_name.replace(
87
- "/", "-"
88
- ) # for callers using old naming with / in ViT names
89
- pretrained_orig = pretrained
90
- pretrained = pretrained.lower()
91
- if pretrained == "openai":
92
- if amodel_name in _MODEL_CONFIGS:
93
- logging.info(f"Loading {amodel_name} model config.")
94
- model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
95
- else:
96
- logging.error(
97
- f"Model config for {amodel_name} not found; available models {list_models()}."
98
- )
99
- raise RuntimeError(f"Model config for {amodel_name} not found.")
100
-
101
- logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
102
- # Hard Code in model name
103
- model_cfg["text_cfg"]["model_type"] = tmodel_name
104
- model = load_openai_model(
105
- "ViT-B-16",
106
- model_cfg,
107
- device=device,
108
- jit=jit,
109
- cache_dir=openai_model_cache_dir,
110
- enable_fusion=enable_fusion,
111
- fusion_type=fusion_type,
112
- )
113
- # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
114
- if precision == "amp" or precision == "fp32":
115
- model = model.float()
116
- else:
117
- if amodel_name in _MODEL_CONFIGS:
118
- logging.info(f"Loading {amodel_name} model config.")
119
- model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
120
- else:
121
- logging.error(
122
- f"Model config for {amodel_name} not found; available models {list_models()}."
123
- )
124
- raise RuntimeError(f"Model config for {amodel_name} not found.")
125
-
126
- if force_quick_gelu:
127
- # override for use of QuickGELU on non-OpenAI transformer models
128
- model_cfg["quick_gelu"] = True
129
-
130
- # if pretrained_image:
131
- # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
132
- # # pretrained weight loading for timm models set via vision_cfg
133
- # model_cfg['vision_cfg']['timm_model_pretrained'] = True
134
- # else:
135
- # assert False, 'pretrained image towers currently only supported for timm models'
136
- model_cfg["text_cfg"]["model_type"] = tmodel_name
137
- model_cfg["enable_fusion"] = enable_fusion
138
- model_cfg["fusion_type"] = fusion_type
139
- model = CLAP(**model_cfg)
140
-
141
- if pretrained:
142
- checkpoint_path = ""
143
- url = get_pretrained_url(amodel_name, pretrained)
144
- if url:
145
- checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
146
- elif os.path.exists(pretrained_orig):
147
- checkpoint_path = pretrained_orig
148
- if checkpoint_path:
149
- logging.info(
150
- f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
151
- )
152
- ckpt = load_state_dict(checkpoint_path, skip_params=True)
153
- model.load_state_dict(ckpt)
154
- param_names = [n for n, p in model.named_parameters()]
155
- # for n in param_names:
156
- # print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
157
- else:
158
- logging.warning(
159
- f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
160
- )
161
- raise RuntimeError(
162
- f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
163
- )
164
-
165
- if pretrained_audio:
166
- if amodel_name.startswith("PANN"):
167
- if "Cnn14_mAP" in pretrained_audio: # official checkpoint
168
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
169
- audio_ckpt = audio_ckpt["model"]
170
- keys = list(audio_ckpt.keys())
171
- for key in keys:
172
- if (
173
- "spectrogram_extractor" not in key
174
- and "logmel_extractor" not in key
175
- ):
176
- v = audio_ckpt.pop(key)
177
- audio_ckpt["audio_branch." + key] = v
178
- elif os.path.basename(pretrained_audio).startswith(
179
- "PANN"
180
- ): # checkpoint trained via HTSAT codebase
181
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
182
- audio_ckpt = audio_ckpt["state_dict"]
183
- keys = list(audio_ckpt.keys())
184
- for key in keys:
185
- if key.startswith("sed_model"):
186
- v = audio_ckpt.pop(key)
187
- audio_ckpt["audio_branch." + key[10:]] = v
188
- elif os.path.basename(pretrained_audio).startswith(
189
- "finetuned"
190
- ): # checkpoint trained via linear probe codebase
191
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
192
- else:
193
- raise ValueError("Unknown audio checkpoint")
194
- elif amodel_name.startswith("HTSAT"):
195
- if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
196
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
197
- audio_ckpt = audio_ckpt["state_dict"]
198
- keys = list(audio_ckpt.keys())
199
- for key in keys:
200
- if key.startswith("sed_model") and (
201
- "spectrogram_extractor" not in key
202
- and "logmel_extractor" not in key
203
- ):
204
- v = audio_ckpt.pop(key)
205
- audio_ckpt["audio_branch." + key[10:]] = v
206
- elif os.path.basename(pretrained_audio).startswith(
207
- "HTSAT"
208
- ): # checkpoint trained via HTSAT codebase
209
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
210
- audio_ckpt = audio_ckpt["state_dict"]
211
- keys = list(audio_ckpt.keys())
212
- for key in keys:
213
- if key.startswith("sed_model"):
214
- v = audio_ckpt.pop(key)
215
- audio_ckpt["audio_branch." + key[10:]] = v
216
- elif os.path.basename(pretrained_audio).startswith(
217
- "finetuned"
218
- ): # checkpoint trained via linear probe codebase
219
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
220
- else:
221
- raise ValueError("Unknown audio checkpoint")
222
- else:
223
- raise f"this audio encoder pretrained checkpoint is not support"
224
-
225
- model.load_state_dict(audio_ckpt, strict=False)
226
- logging.info(
227
- f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
228
- )
229
- param_names = [n for n, p in model.named_parameters()]
230
- for n in param_names:
231
- print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
232
-
233
- model.to(device=device)
234
- if precision == "fp16":
235
- assert device.type != "cpu"
236
- convert_weights_to_fp16(model)
237
-
238
- if jit:
239
- model = torch.jit.script(model)
240
-
241
- return model, model_cfg
242
-
243
-
244
- def create_model_and_transforms(
245
- model_name: str,
246
- pretrained: str = "",
247
- precision: str = "fp32",
248
- device: torch.device = torch.device("cpu"),
249
- jit: bool = False,
250
- force_quick_gelu: bool = False,
251
- # pretrained_image: bool = False,
252
- ):
253
- model = create_model(
254
- model_name,
255
- pretrained,
256
- precision,
257
- device,
258
- jit,
259
- force_quick_gelu=force_quick_gelu,
260
- # pretrained_image=pretrained_image
261
- )
262
- preprocess_train = image_transform(model.visual.image_size, is_train=True)
263
- preprocess_val = image_transform(model.visual.image_size, is_train=False)
264
- return model, preprocess_train, preprocess_val
265
-
266
-
267
- def list_models():
268
- """enumerate available model architectures based on config files"""
269
- return list(_MODEL_CONFIGS.keys())
270
-
271
-
272
- def add_model_config(path):
273
- """add model config path or file and update registry"""
274
- if not isinstance(path, Path):
275
- path = Path(path)
276
- _MODEL_CONFIG_PATHS.append(path)
277
- _rescan_model_configs()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/feature_fusion.py DELETED
@@ -1,192 +0,0 @@
1
- """
2
- Feature Fusion for Variable-Length Data Processing
3
- AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4
- According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
5
- """
6
-
7
- import torch
8
- import torch.nn as nn
9
-
10
-
11
class DAF(nn.Module):
    """
    Direct-add fusion (DirectAddFuse): fuse two inputs by element-wise addition.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, residual):
        """Return ``x + residual``."""
        fused = x + residual
        return fused
21
-
22
-
23
class iAFF(nn.Module):
    """
    Iterative attentional feature fusion (iAFF).

    Two inputs are blended with a learned sigmoid gate computed from a
    per-position ("local") branch and a pooled ("global") branch; the blend is
    then refined by a second gating stage.

    Args:
        channels: Number of channels of both inputs.
        r: Channel reduction ratio of the bottleneck inside each branch.
        type: ``"1D"`` or ``"2D"``; selects Conv/BatchNorm/AdaptiveAvgPool of
            that dimensionality.

    Raises:
        ValueError: If ``type`` is neither ``"1D"`` nor ``"2D"``.
    """

    def __init__(self, channels=64, r=4, type="2D"):
        super().__init__()
        inter_channels = int(channels // r)

        if type == "1D":
            conv, norm, pool = nn.Conv1d, nn.BatchNorm1d, nn.AdaptiveAvgPool1d
        elif type == "2D":
            conv, norm, pool = nn.Conv2d, nn.BatchNorm2d, nn.AdaptiveAvgPool2d
        else:
            # Raising a plain string is itself a TypeError in Python 3, which
            # hid the real problem; raise a proper exception instead.
            raise ValueError(
                f"the type is not supported: {type!r} (expected '1D' or '2D')"
            )

        def bottleneck(global_context):
            # conv -> norm -> ReLU -> conv -> norm, optionally preceded by
            # global average pooling. Layer order is kept exactly as before so
            # that state_dict keys stay the same.
            layers = [pool(1)] if global_context else []
            layers += [
                conv(channels, inter_channels, kernel_size=1, stride=1, padding=0),
                norm(inter_channels),
                nn.ReLU(inplace=True),
                conv(inter_channels, channels, kernel_size=1, stride=1, padding=0),
                norm(channels),
            ]
            return nn.Sequential(*layers)

        # First-stage local / global attention.
        self.local_att = bottleneck(False)
        self.global_att = bottleneck(True)
        # Second-stage local / global attention.
        self.local_att2 = bottleneck(False)
        self.global_att2 = bottleneck(True)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, residual):
        """Fuse ``x`` and ``residual``; the result has the shape of ``x``."""
        flag = False
        xa = x + residual
        if xa.size(0) == 1:
            # A batch of one is duplicated before the attention branches and
            # the copy is dropped again at the end -- presumably so BatchNorm
            # after global pooling sees more than one value per channel.
            xa = torch.cat([xa, xa], dim=0)
            flag = True
        xl = self.local_att(xa)
        xg = self.global_att(xa)
        xlg = xl + xg
        wei = self.sigmoid(xlg)
        xi = x * wei + residual * (1 - wei)

        xl2 = self.local_att2(xi)
        # NOTE(review): the second stage reuses ``global_att``; ``global_att2``
        # is constructed but never called here. Left unchanged on purpose,
        # since switching branches would change outputs of trained checkpoints.
        xg2 = self.global_att(xi)
        xlg2 = xl2 + xg2
        wei2 = self.sigmoid(xlg2)
        xo = x * wei2 + residual * (1 - wei2)
        if flag:
            xo = xo[0].unsqueeze(0)
        return xo
131
-
132
-
133
class AFF(nn.Module):
    """
    Attentional feature fusion (AFF).

    Two inputs are blended with a learned sigmoid gate computed from a
    per-position ("local") branch and a pooled ("global") branch.

    Args:
        channels: Number of channels of both inputs.
        r: Channel reduction ratio of the bottleneck inside each branch.
        type: ``"1D"`` or ``"2D"``; selects Conv/BatchNorm/AdaptiveAvgPool of
            that dimensionality.

    Raises:
        ValueError: If ``type`` is neither ``"1D"`` nor ``"2D"``.
    """

    def __init__(self, channels=64, r=4, type="2D"):
        super().__init__()
        inter_channels = int(channels // r)

        if type == "1D":
            conv, norm, pool = nn.Conv1d, nn.BatchNorm1d, nn.AdaptiveAvgPool1d
        elif type == "2D":
            conv, norm, pool = nn.Conv2d, nn.BatchNorm2d, nn.AdaptiveAvgPool2d
        else:
            # Raising a plain string is itself a TypeError in Python 3, which
            # hid the real problem; raise a proper exception instead.
            raise ValueError(
                f"the type is not supported: {type!r} (expected '1D' or '2D')"
            )

        def bottleneck(global_context):
            # conv -> norm -> ReLU -> conv -> norm, optionally preceded by
            # global average pooling. Layer order is kept exactly as before so
            # that state_dict keys stay the same.
            layers = [pool(1)] if global_context else []
            layers += [
                conv(channels, inter_channels, kernel_size=1, stride=1, padding=0),
                norm(inter_channels),
                nn.ReLU(inplace=True),
                conv(inter_channels, channels, kernel_size=1, stride=1, padding=0),
                norm(channels),
            ]
            return nn.Sequential(*layers)

        self.local_att = bottleneck(False)
        self.global_att = bottleneck(True)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, residual):
        """Fuse ``x`` and ``residual``; the result has the shape of ``x``."""
        flag = False
        xa = x + residual
        if xa.size(0) == 1:
            # A batch of one is duplicated before the attention branches and
            # the copy is dropped again at the end -- presumably so BatchNorm
            # after global pooling sees more than one value per channel.
            xa = torch.cat([xa, xa], dim=0)
            flag = True
        xl = self.local_att(xa)
        xg = self.global_att(xa)
        xlg = xl + xg
        wei = self.sigmoid(xlg)
        # The factor 2 means a neutral gate (0.5) yields x + residual.
        xo = 2 * x * wei + 2 * residual * (1 - wei)
        if flag:
            xo = xo[0].unsqueeze(0)
        return xo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm/clap/open_clip/htsat.py DELETED
@@ -1,1308 +0,0 @@
1
- # Ke Chen
2
- # knutchen@ucsd.edu
3
- # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
- # Some layers designed on the model
5
- # below codes are based and referred from https://github.com/microsoft/Swin-Transformer
6
- # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- from itertools import repeat
12
- import collections.abc
13
- import math
14
- import warnings
15
-
16
- from torch.nn.init import _calculate_fan_in_and_fan_out
17
- import torch.utils.checkpoint as checkpoint
18
-
19
- import random
20
-
21
- from torchlibrosa.stft import Spectrogram, LogmelFilterBank
22
- from torchlibrosa.augmentation import SpecAugmentation
23
-
24
- from itertools import repeat
25
- from .utils import do_mixup, interpolate
26
-
27
- from .feature_fusion import iAFF, AFF, DAF
28
-
29
- # from PyTorch internals
30
- def _ntuple(n):
31
- def parse(x):
32
- if isinstance(x, collections.abc.Iterable):
33
- return x
34
- return tuple(repeat(x, n))
35
-
36
- return parse
37
-
38
-
39
- to_1tuple = _ntuple(1)
40
- to_2tuple = _ntuple(2)
41
- to_3tuple = _ntuple(3)
42
- to_4tuple = _ntuple(4)
43
- to_ntuple = _ntuple
44
-
45
-
46
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Each sample along dim 0 is either zeroed as a whole or kept and rescaled by
    ``1 / (1 - drop_prob)``.

    Args:
        x: Input tensor; dim 0 is treated as the per-sample dimension.
        drop_prob: Probability of dropping a sample. ``0.0`` or ``None``
            disables dropping.
        training: Dropping is only applied when this is true.

    Returns:
        ``x`` itself when dropping is disabled, otherwise a new tensor of the
        same shape.
    """
    # ``None`` (the default of the DropPath module) is treated like 0.0 instead
    # of crashing on ``1 - None``.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.0:
        # Everything is dropped; dividing by keep_prob would produce inf/NaN.
        return torch.zeros_like(x)
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    output = x.div(keep_prob) * random_tensor
    return output
64
-
65
-
66
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: Probability of dropping a sample; ``None`` or ``0`` makes
            the module an identity.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # With the default drop_prob=None the arithmetic in drop_path would
        # fail in training mode, so short-circuit when there is nothing to drop.
        if not self.drop_prob or not self.training:
            return x
        return drop_path(x, self.drop_prob, self.training)
75
-
76
-
77
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding.

    The input is projected to ``embed_dim`` channels with a strided
    convolution. With ``enable_fusion`` and one of the 2D fusion types, input
    channel 0 is embedded by ``proj`` and, for the samples selected by
    ``longer_idx``, the remaining channels are embedded by ``mel_conv2d`` and
    merged in through ``fusion_model``.

    Args:
        img_size: Expected input height/width (int or pair).
        patch_size: Convolution kernel size (int or pair).
        in_chans: Number of input channels.
        embed_dim: Number of output channels.
        norm_layer: Optional callable building a norm layer from ``embed_dim``.
        flatten: If true, output is flattened from BCHW to BNC.
        patch_stride: Convolution stride (int or pair).
        enable_fusion: Enables the fusion-specific layers and code paths.
        fusion_type: ``"channel_map"``, ``"daf_2d"``, ``"aff_2d"``,
            ``"iaff_2d"`` or anything else for no fusion-specific handling.
    """

    # Fusion types that use the separate ``mel_conv2d`` + ``fusion_model`` path.
    _FUSION_2D_TYPES = ("daf_2d", "aff_2d", "iaff_2d")

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        patch_stride=16,
        enable_fusion=False,
        fusion_type="None",
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patch_stride = to_2tuple(patch_stride)
        self.img_size = img_size
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.grid_size = (
            img_size[0] // patch_stride[0],
            img_size[1] // patch_stride[1],
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        padding = (
            (patch_size[0] - patch_stride[0]) // 2,
            (patch_size[1] - patch_stride[1]) // 2,
        )

        # "channel_map" fusion feeds four times as many input channels to proj.
        channel_map = self.enable_fusion and self.fusion_type == "channel_map"
        self.proj = nn.Conv2d(
            in_chans * 4 if channel_map else in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_stride,
            padding=padding,
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

        if self.enable_fusion and self.fusion_type in self._FUSION_2D_TYPES:
            # Kernel and stride are three times wider along the last axis
            # than those of ``proj``.
            self.mel_conv2d = nn.Conv2d(
                in_chans,
                embed_dim,
                kernel_size=(patch_size[0], patch_size[1] * 3),
                stride=(patch_stride[0], patch_stride[1] * 3),
                padding=padding,
            )
            if self.fusion_type == "daf_2d":
                self.fusion_model = DAF()
            elif self.fusion_type == "aff_2d":
                self.fusion_model = AFF(channels=embed_dim, type="2D")
            elif self.fusion_type == "iaff_2d":
                self.fusion_model = iAFF(channels=embed_dim, type="2D")

    def forward(self, x, longer_idx=None):
        """Embed ``x`` (B, C, H, W).

        ``longer_idx`` indexes the batch entries whose channels ``1:`` are
        fused in; it is only used on the 2D-fusion path, and ``None`` or an
        empty index skips the fusion step.
        """
        if self.enable_fusion and self.fusion_type in self._FUSION_2D_TYPES:
            global_x = x[:, 0:1, :, :]

            # global processing
            B, C, H, W = global_x.shape
            assert (
                H == self.img_size[0] and W == self.img_size[1]
            ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
            global_x = self.proj(global_x)
            TW = global_x.size(-1)
            # The default longer_idx=None used to crash on len(None).
            if longer_idx is not None and len(longer_idx) > 0:
                # local processing
                local_x = x[longer_idx, 1:, :, :].contiguous()
                B, C, H, W = local_x.shape
                local_x = local_x.view(B * C, 1, H, W)
                local_x = self.mel_conv2d(local_x)
                local_x = local_x.view(
                    B, C, local_x.size(1), local_x.size(2), local_x.size(3)
                )
                # Lay the per-channel outputs side by side along the last axis.
                local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3)
                TB, TC, TH, _ = local_x.size()
                if local_x.size(-1) < TW:
                    # Zero-pad to the width of the global embedding; match the
                    # dtype of local_x so half-precision inputs stay consistent.
                    local_x = torch.cat(
                        [
                            local_x,
                            torch.zeros(
                                (TB, TC, TH, TW - local_x.size(-1)),
                                dtype=local_x.dtype,
                                device=global_x.device,
                            ),
                        ],
                        dim=-1,
                    )
                else:
                    local_x = local_x[:, :, :, :TW]

                global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x)
            x = global_x
        else:
            B, C, H, W = x.shape
            assert (
                H == self.img_size[0] and W == self.img_size[1]
            ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
            x = self.proj(x)

        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x
202
-
203
-
204
class Mlp(nn.Module):
    """Two-layer MLP as used in Vision Transformer, MLP-Mixer and related networks.

    Computes ``drop(fc2(drop(act(fc1(x)))))``. ``hidden_features`` and
    ``out_features`` fall back to ``in_features`` when left unset.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        output_width = out_features or in_features
        hidden_width = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
230
-
231
-
232
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with truncated-normal samples and return it.

    Values are drawn uniformly in CDF space and mapped through the inverse
    normal CDF, then clamped to ``[a, b]``. No gradient is recorded.
    Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    """
    sqrt_two = math.sqrt(2.0)

    def norm_cdf(x):
        # Standard normal cumulative distribution function.
        return (1.0 + math.erf(x / sqrt_two)) / 2.0

    mean_far_outside = (mean < a - 2 * std) or (mean > b + 2 * std)
    if mean_far_outside:
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # CDF values at the two cutoffs.
        cdf_low = norm_cdf((a - mean) / std)
        cdf_high = norm_cdf((b - mean) / std)

        # Uniform samples in [2*cdf_low - 1, 2*cdf_high - 1] ...
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1)
        # ... become truncated standard-normal samples via the inverse erf.
        tensor.erfinv_()

        # Rescale to the requested std and shift to the requested mean.
        tensor.mul_(std * sqrt_two)
        tensor.add_(mean)

        # Guard against values landing just outside the range.
        tensor.clamp_(min=a, max=b)
        return tensor
268
-
269
-
270
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fill ``tensor`` in place with samples from a truncated normal distribution.

    Samples follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. The sampling method works best when
    :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same ``tensor``, filled.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b)
    return filled
289
-
290
-
291
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Initialise ``tensor`` in place with variance ``scale / fan``.

    Args:
        tensor: Tensor to fill; its fan-in/fan-out are computed by
            ``_calculate_fan_in_and_fan_out``.
        scale: Numerator of the target variance.
        mode: ``"fan_in"``, ``"fan_out"`` or ``"fan_avg"``; selects the
            denominator of the target variance.
        distribution: ``"truncated_normal"``, ``"normal"`` or ``"uniform"``.

    Raises:
        ValueError: If ``mode`` or ``distribution`` is not one of the values
            listed above.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # Without this branch an unknown mode left ``denom`` unbound and
        # surfaced as an UnboundLocalError further down.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        # A uniform distribution on [-bound, bound] has variance bound**2 / 3.
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
312
-
313
-
314
def lecun_normal_(tensor):
    """Initialise ``tensor`` in place via fan-in variance scaling with a truncated normal."""
    variance_scaling_(tensor, distribution="truncated_normal", mode="fan_in")
316
-
317
-
318
def window_partition(x, window_size):
    """Cut a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): side length of each window
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    tiled = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-grid axes next to each other before flattening them
    # into the leading dimension.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, channels)
332
-
333
-
334
def window_reverse(windows, window_size, H, W):
    """Reassemble windows produced by window partitioning into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): side length of each window
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    rows = H // window_size
    cols = W // window_size
    tiled = windows.view(batch, rows, cols, window_size, window_size, -1)
    # Interleave grid and in-window axes back into plain (H, W) order.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(batch, H, W, -1)
350
-
351
-
352
class WindowAttention(nn.Module):
    r"""Window based multi-head self attention (W-MSA) module with relative position bias.
    Works for both shifted and non-shifted windows (the shift is expressed
    through the optional ``mask`` passed to ``forward``).
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        win_h = window_size[0]
        win_w = window_size[1]

        # Learnable bias, one row per possible (dy, dx) offset inside a window.
        table_rows = (2 * win_h - 1) * (2 * win_w - 1)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(table_rows, num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # For every pair of tokens in a window, precompute which table row
        # holds the bias for their relative offset.
        grid = torch.stack(
            torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])
        )  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        offsets = flat[:, :, None] - flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
        offsets = offsets.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        offsets[:, :, 0] += win_h - 1  # shift to start from 0
        offsets[:, :, 1] += win_w - 1
        offsets[:, :, 0] *= 2 * win_w - 1  # row-major flattening of (dy, dx)
        self.register_buffer(
            "relative_position_index", offsets.sum(-1)
        )  # Wh*Ww, Wh*Ww

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        Returns:
            (output of shape (num_windows*B, N, C), attention weights after dropout)
        """
        B_, N, C = x.shape
        heads = self.num_heads
        qkv = self.qkv(x).reshape(B_, N, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
        # Index instead of unpacking: torchscript cannot use a tensor as a tuple.
        q = qkv[0]
        k = qkv[1]
        v = qkv[2]

        attn = (q * self.scale) @ k.transpose(-2, -1)

        tokens = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(tokens, tokens, -1)  # Wh*Ww,Wh*Ww,nH
        bias = bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + bias.unsqueeze(0)

        if mask is not None:
            # Add the per-window mask, broadcast over batch and heads.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        out = self.proj_drop(self.proj(out))
        return out, attn

    def extra_repr(self):
        return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
465
-
466
-
467
- # The model is built from Swin Transformer blocks, so pretrained Swin Transformer weights can be reused.
468
class _TransposedBatchNorm1d(nn.Module):
    """BatchNorm1d applied over the last axis of a (B, L, C) tensor."""

    def __init__(self, dim):
        super().__init__()
        self.bn = nn.BatchNorm1d(dim)

    def forward(self, x):
        # BatchNorm1d normalises dim 1, so move C there and back again.
        return self.bn(x.transpose(1, 2)).transpose(1, 2)


class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        norm_before_mlp (str, optional): "ln" (LayerNorm) or "bn" (BatchNorm1d)
            for the norm applied before the MLP. Default: "ln"
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        norm_before_mlp="ln",
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.norm_before_mlp = norm_before_mlp
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert (
            0 <= self.shift_size < self.window_size
        ), "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        if self.norm_before_mlp == "ln":
            self.norm2 = nn.LayerNorm(dim)
        elif self.norm_before_mlp == "bn":
            # Previously a lambda built a brand-new BatchNorm1d on every call,
            # so its parameters and running statistics were never registered,
            # trained, saved or moved with the module. Use a real submodule.
            # NOTE: this adds ``norm2.bn.*`` entries to the state_dict for "bn".
            self.norm2 = _TransposedBatchNorm1d(dim)
        else:
            raise NotImplementedError
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            # Label each of the 3x3 regions with a distinct id; token pairs
            # with different ids inside a window must not attend to each other.
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(
                img_mask, self.window_size
            )  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(
                attn_mask != 0, float(-100.0)
            ).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, H*W, C).

        Returns:
            (output of shape (B, H*W, C), attention weights from the attention module)
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        # assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(
                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
            )
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(
            shifted_x, self.window_size
        )  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(
            -1, self.window_size * self.window_size, C
        )  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows, attn = self.attn(
            x_windows, mask=self.attn_mask
        )  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
            )
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x, attn

    def extra_repr(self):
        return (
            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
            f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
        )
638
-
639
-
640
class PatchMerging(nn.Module):
    r"""Patch Merging Layer.

    Every 2x2 neighbourhood of tokens is concatenated along the channel axis
    (C -> 4*C), normalised, and linearly reduced to 2*C channels.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        returns: B, H/2*W/2, 2*C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        grid = x.view(B, H, W, C)

        # The four members of each 2x2 block, in the order
        # (row 0, col 0), (row 1, col 0), (row 0, col 1), (row 1, col 1).
        quadrants = [grid[:, r::2, c::2, :] for c in (0, 1) for r in (0, 1)]
        merged = torch.cat(quadrants, -1)  # B H/2 W/2 4*C
        merged = merged.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        return self.reduction(self.norm(merged))

    def extra_repr(self):
        return f"input_resolution={self.input_resolution}, dim={self.dim}"
680
-
681
-
682
- class BasicLayer(nn.Module):
683
- """A basic Swin Transformer layer for one stage.
684
- Args:
685
- dim (int): Number of input channels.
686
- input_resolution (tuple[int]): Input resolution.
687
- depth (int): Number of blocks.
688
- num_heads (int): Number of attention heads.
689
- window_size (int): Local window size.
690
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
691
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
692
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
693
- drop (float, optional): Dropout rate. Default: 0.0
694
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
695
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
696
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
697
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
698
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
699
- """
700
-
701
- def __init__(
702
- self,
703
- dim,
704
- input_resolution,
705
- depth,
706
- num_heads,
707
- window_size,
708
- mlp_ratio=4.0,
709
- qkv_bias=True,
710
- qk_scale=None,
711
- drop=0.0,
712
- attn_drop=0.0,
713
- drop_path=0.0,
714
- norm_layer=nn.LayerNorm,
715
- downsample=None,
716
- use_checkpoint=False,
717
- norm_before_mlp="ln",
718
- ):
719
-
720
- super().__init__()
721
- self.dim = dim
722
- self.input_resolution = input_resolution
723
- self.depth = depth