Leon299 committed
Commit ec0bc9b · verified · 1 parent: 8337fa0

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full list.
Files changed (50)
  1. MuCodec/libs/rvq/__pycache__/descript_quantize3.cpython-312.pyc +0 -0
  2. MuCodec/models/__pycache__/attention.cpython-310.pyc +0 -0
  3. MuCodec/models/__pycache__/attention.cpython-312.pyc +0 -0
  4. MuCodec/models/__pycache__/transformer_2d_flow.cpython-310.pyc +0 -0
  5. MuCodec/models/__pycache__/transformer_2d_flow.cpython-312.pyc +0 -0
  6. MuCodec/muq_dev/__pycache__/test.cpython-310.pyc +0 -0
  7. MuCodec/muq_dev/__pycache__/test.cpython-312.pyc +0 -0
  8. MuCodec/muq_dev/muq_fairseq/data/__init__.py +1 -0
  9. MuCodec/muq_dev/muq_fairseq/data/__pycache__/__init__.cpython-310.pyc +0 -0
  10. MuCodec/muq_dev/muq_fairseq/data/__pycache__/ark_dataset.cpython-310.pyc +0 -0
  11. MuCodec/muq_dev/muq_fairseq/data/__pycache__/mert_dataset.cpython-310.pyc +0 -0
  12. MuCodec/muq_dev/muq_fairseq/data/ark_dataset.py +71 -0
  13. MuCodec/muq_dev/muq_fairseq/data/mert_dataset.py +295 -0
  14. MuCodec/muq_dev/muq_fairseq/data/utils/data_utils.py +535 -0
  15. MuCodec/muq_dev/muq_fairseq/models/muq/__init__.py +1 -0
  16. MuCodec/muq_dev/muq_fairseq/models/muq/__pycache__/__init__.cpython-310.pyc +0 -0
  17. MuCodec/muq_dev/muq_fairseq/models/muq/__pycache__/muq_model.cpython-310.pyc +0 -0
  18. MuCodec/muq_dev/muq_fairseq/models/muq/model/__init__.py +2 -0
  19. MuCodec/muq_dev/muq_fairseq/models/muq/model/__pycache__/__init__.cpython-310.pyc +0 -0
  20. MuCodec/muq_dev/muq_fairseq/models/muq/model/__pycache__/muq.cpython-310.pyc +0 -0
  21. MuCodec/muq_dev/muq_fairseq/models/muq/model/__pycache__/rvq.cpython-310.pyc +0 -0
  22. MuCodec/muq_dev/muq_fairseq/models/muq/model/__pycache__/rvq_muq.cpython-310.pyc +0 -0
  23. MuCodec/muq_dev/muq_fairseq/models/muq/model/muq.py +520 -0
  24. MuCodec/muq_dev/muq_fairseq/models/muq/model/pred_ark_target_with_model.py +151 -0
  25. MuCodec/muq_dev/muq_fairseq/models/muq/model/rvq.py +459 -0
  26. MuCodec/muq_dev/muq_fairseq/models/muq/model/rvq_muq.py +394 -0
  27. MuCodec/muq_dev/muq_fairseq/models/muq/model/w2v2_config.json +113 -0
  28. MuCodec/muq_dev/muq_fairseq/models/muq/modules/__init__.py +2 -0
  29. MuCodec/muq_dev/muq_fairseq/models/muq/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  30. MuCodec/muq_dev/muq_fairseq/models/muq/modules/__pycache__/conv.cpython-310.pyc +0 -0
  31. MuCodec/muq_dev/muq_fairseq/models/muq/modules/__pycache__/features.cpython-310.pyc +0 -0
  32. MuCodec/muq_dev/muq_fairseq/models/muq/modules/__pycache__/random_quantizer.cpython-310.pyc +0 -0
  33. MuCodec/muq_dev/muq_fairseq/models/muq/modules/conv.py +77 -0
  34. MuCodec/muq_dev/muq_fairseq/models/muq/modules/features.py +67 -0
  35. MuCodec/muq_dev/muq_fairseq/models/muq/modules/flash_conformer.py +2114 -0
  36. MuCodec/muq_dev/muq_fairseq/models/muq/modules/random_quantizer.py +68 -0
  37. MuCodec/muq_dev/muq_fairseq/models/muq/muq_model.py +139 -0
  38. MuCodec/muq_dev/muq_fairseq/tasks/__pycache__/muq_pretraining.cpython-310.pyc +0 -0
  39. MuCodec/muq_dev/muq_fairseq/tasks/muq_pretraining.py +354 -0
  40. MuCodec/tools/__pycache__/get_melvaehifigan48k.cpython-310.pyc +0 -0
  41. MuCodec/tools/__pycache__/torch_tools.cpython-310.pyc +0 -0
  42. MuCodec/tools/__pycache__/torch_tools.cpython-312.pyc +0 -0
  43. checkpoints/Qwen3-0.6B/.gitattributes +36 -0
  44. checkpoints/Qwen3-0.6B/LICENSE +202 -0
  45. checkpoints/Qwen3-0.6B/README.md +301 -0
  46. checkpoints/Qwen3-0.6B/config.json +33 -0
  47. checkpoints/Qwen3-0.6B/generation_config.json +13 -0
  48. checkpoints/Qwen3-0.6B/merges.txt +0 -0
  49. checkpoints/Qwen3-0.6B/tokenizer_config.json +239 -0
  50. checkpoints/Qwen3-0.6B/vocab.json +0 -0
MuCodec/libs/rvq/__pycache__/descript_quantize3.cpython-312.pyc ADDED
Binary file (16.1 kB).
 
MuCodec/models/__pycache__/attention.cpython-310.pyc ADDED
Binary file (16.3 kB).
 
MuCodec/models/__pycache__/attention.cpython-312.pyc ADDED
Binary file (25.6 kB).
 
MuCodec/models/__pycache__/transformer_2d_flow.cpython-310.pyc ADDED
Binary file (17.9 kB).
 
MuCodec/models/__pycache__/transformer_2d_flow.cpython-312.pyc ADDED
Binary file (26.9 kB).
 
MuCodec/muq_dev/__pycache__/test.cpython-310.pyc ADDED
Binary file (866 Bytes).
 
MuCodec/muq_dev/__pycache__/test.cpython-312.pyc ADDED
Binary file (1.19 kB).
 
MuCodec/muq_dev/muq_fairseq/data/__init__.py ADDED
@@ -0,0 +1 @@
+ from .mert_dataset import MERTDataset
MuCodec/muq_dev/muq_fairseq/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (219 Bytes).
 
MuCodec/muq_dev/muq_fairseq/data/__pycache__/ark_dataset.cpython-310.pyc ADDED
Binary file (2.35 kB).
 
MuCodec/muq_dev/muq_fairseq/data/__pycache__/mert_dataset.cpython-310.pyc ADDED
Binary file (9.85 kB).
 
MuCodec/muq_dev/muq_fairseq/data/ark_dataset.py ADDED
@@ -0,0 +1,71 @@
+ import logging
+ import torch
+ import torch.nn.functional as F
+ from fairseq.data.audio.raw_audio_dataset import RawAudioDataset
+ from typing import Tuple
+ try:
+     import kaldiio
+ except:
+     kaldiio = None
+ import warnings
+
+ logger = logging.getLogger(__name__)
+
+
+ class ArkDataset(RawAudioDataset):
+     def __init__(
+         self,
+         wav_scp,
+         dur_scp,
+         sr = 24000,
+         max_dur = 20,
+         num_buckets=0,
+         normalize=False,
+     ):
+         super().__init__(
+             sample_rate=sr,
+             max_sample_size=max_dur*sr,
+             min_sample_size=1200,
+             shuffle=True,
+             pad=True,
+             normalize=normalize,
+             compute_mask=False,
+         )
+         self.sr = sr
+         self.max_dur = max_dur
+         self.normalize = normalize
+
+         logger.info("Loading Kaldi scp files from {}".format(wav_scp))
+
+         self.wav_data = kaldiio.load_scp(wav_scp)
+         self.keys = list(self.wav_data.keys())
+         dur_data = {}
+         keys_set = set(self.keys)
+
+         with open(dur_scp, 'r') as f:
+             for line in f:
+                 line = line.strip().split()
+                 if line[0] in keys_set:
+                     dur_data[line[0]] = float(line[-1])
+         self.sizes = [int(dur_data[k]*self.sr/100) for k in self.keys]
+
+         logger.info("Loading Kaldi scp files done")
+
+         self.dataset_len = len(self.keys)
+         self.set_bucket_info(num_buckets)
+
+     def __len__(self):
+         return self.dataset_len
+
+     def __getitem__(self, idx):
+         pass
+
+     def size(self, idx):
+         pass
+
+     def postprocess(self, wav):
+         pass
+
+     def collater(self, samples):
+         pass
+
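Note on the `sizes` line above: `ArkDataset` sizes each utterance as `int(dur * sr / 100)`, which only lines up if the duration column of `dur_scp` counts 10 ms frames (100 per second) rather than seconds; treat that unit as an inference from the code. A minimal sketch of the bookkeeping, using a hypothetical utterance id and duration:

sr = 24000

# dur_scp holds one "<utt-id> ... <duration>" line per utterance; the formula
# above (dur * sr / 100) implies the last field counts 10 ms frames.
dur_line = "utt_0001 532"            # hypothetical: 532 frames * 10 ms = 5.32 s
utt_id, dur = dur_line.split()[0], float(dur_line.split()[-1])

num_samples = int(dur * sr / 100)    # 5.32 s * 24000 Hz = 127680 samples
print(utt_id, num_samples)           # utt_0001 127680
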
MuCodec/muq_dev/muq_fairseq/data/mert_dataset.py ADDED
@@ -0,0 +1,295 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import itertools
+ import logging
+ import os
+ import sys
+ from typing import Any, List, Optional, Union
+
+ import numpy as np
+ from typing import Tuple
+ import torch
+ import torch.nn.functional as F
+ from fairseq.data import data_utils
+ from fairseq.data.fairseq_dataset import FairseqDataset
+ from fairseq.data.audio.audio_utils import (
+     parse_path,
+     read_from_stored_zip,
+ )
+
+ import math
+ import io
+ import torchaudio
+ # this is in the user_dir
+ from nnAudio import features as nnAudioFeatures
+
+ # from tqdm import tqdm
+ import tqdm
+ import json
+ import random
+ import traceback
+ from einops import rearrange
+ # from scripts.prepare_codecs_from_manifest import *
+
+ logger = logging.getLogger(__name__)
+
+ class model_cqt_pred(torch.nn.Module):
+     def __init__(self, n_bins=84, sr=16000, freq=50):
+         super().__init__()
+         self.epsilon=1e-10
+         # Getting Mel Spectrogram on the fly
+         self.spec_layer = nnAudioFeatures.cqt.CQT(sr=sr, hop_length=sr//freq, fmin=32.7,
+                 fmax=None, n_bins=n_bins, bins_per_octave=n_bins//7,
+                 filter_scale=1, norm=1, window='hann', center=True,
+                 pad_mode='constant', trainable=False,
+                 output_format='Magnitude', verbose=True)
+
+         # self.fc = nn.Linear(input_dim, n_bins)
+
+         # self.criterion = nn.MSELoss()
+         self.forward_dict = {
+             # 'masked_transformer_output': self.plain_forward
+             'compute_cqt': self.compute_cqt
+         }
+     def compute_cqt(self, x):
+         '''
+         convert waveform to CQT -> [batch, bins, len] -> transpose
+         '''
+         # align with the padding of HuBERT model,
+         # the truncation is calculated by bruteforce search since the nnAudio padding strategy and fairseq models are different
+         # x = x[..., :-560]
+         return torch.transpose(self.spec_layer(x), -1, -2)
+
+     def forward(self, x, forward_type='masked_transformer_output'):
+         '''
+         take input from transformer hidden states: [batch, len_seq, channel]
+         output: [batch, len_seq, n_bins]
+         '''
+
+         return self.forward_dict[forward_type](x)
+
+ def load_audio_by_json(json_path, max_keep, min_keep, tgt_sample_rate, clip_secs=5):
+     # read json file
+     print(json_path)
+     datas = []
+     inds = []
+     sizes = []
+     with open(json_path) as fp:
+         for ind,line in enumerate(fp):
+             data = json.loads(line)
+             if 'duration' in data and min_keep is not None and tgt_sample_rate*data['duration'] < min_keep:
+                 continue
+             datas.append(data)
+             inds.append(ind)
+             # sz = int(data['duration'] * data['sample_rate'])
+             if clip_secs > 0:
+                 sz = int(tgt_sample_rate * clip_secs)
+             else:
+                 sz = int(tgt_sample_rate * data['duration'])
+             sizes.append(sz)
+     tot = ind + 1
+     return datas,inds,tot,sizes
+ def load_audio(manifest_path, max_keep, min_keep):
+     pass
+
+
+ def load_label(label_path, inds, tot):
+     pass
+
+ def load_numpy_label(label_path, inds, tot):
+     labels = np.load(label_path, mmap_mode='r')
+     assert (labels.shape[0] == tot), f"number of labels does not match ({labels.shape[0]} != {tot})"
+     return labels
+
+ def verify_label_lengths(
+     audio_sizes,
+     audio_rate,
+     label_path,
+     label_rate,
+     inds,
+     tot,
+     tol=0.1,  # tolerance in seconds
+ ):
+     pass
+
+ class Read_and_PadCrop_Normalized_T(torch.nn.Module):
+     def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
+
+         super().__init__()
+
+         self.n_samples = n_samples
+         self.sample_rate = sample_rate
+         self.randomize = randomize
+
+
+     def __call__(self, filename: str, duration: float, cur_sample_rate: int, fixed_offset_duration=None) -> Tuple[torch.Tensor, float, float, int, int]:
+         pass
+
+
+ class MERTDataset(FairseqDataset):
+     def __init__(
+         self,
+         manifest_path: str,
+         sample_rate: float,
+         label_paths: List[str],
+         label_rates: Union[List[float], float],  # -1 for sequence labels
+         pad_list: List[str],
+         eos_list: List[str],
+         label_scp_path: Optional[str] = None,
+         label_scp_clip_duration: float = -1,
+         label_processors: Optional[List[Any]] = None,
+         max_keep_sample_size: Optional[int] = None,
+         min_keep_sample_size: Optional[int] = None,
+         max_sample_size: Optional[int] = None,
+         shuffle: bool = True,
+         pad_audio: bool = False,
+         normalize: bool = False,
+         store_labels: bool = True,
+         npmemmap: bool = False,
+         random_crop: bool = False,
+         single_target: bool = False,
+         augmentation_effects: List[str] = [],
+         augmentation_probs: List[float] = [],
+         inbatch_noise_augment_len_range: List[int] = [8000, 24000],
+         inbatch_noise_augment_number_range: List[int] = [1, 3],
+         inbatch_noise_augment_volume: float = 1.0,
+         cqt_prediction_bin: int = -1,
+         dataset_len:int = 128*3000,
+         clip_secs = 5,
+     ):
+         self.sample_rate = sample_rate
+         self.shuffle = shuffle
+         self.random_crop = random_crop
+         self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path,max_keep_sample_size,min_keep_sample_size, self.sample_rate, clip_secs)
+         self.inds = inds
+
+         self.num_labels = len(label_paths)
+         self.pad_list = pad_list
+         self.eos_list = eos_list
+         self.label_processors = label_processors
+         self.single_target = single_target
+         self.label_rates = (
+             [label_rates for _ in range(len(label_paths))]
+             if isinstance(label_rates, float)
+             else label_rates
+         )
+         self.store_labels = store_labels
+         self.npmemmap = npmemmap
+         self.label_scp_path = label_scp_path
+         self.label_scp_clip_duration = label_scp_clip_duration
+
+
+         if self.label_scp_path is not None:
+             from kaldiio import load_scp
+             self.label_scp = load_scp(self.label_scp_path)
+
+         # self.dataset_len = dataset_len
+         self.dataset_len = len(self.datas)
+         logger.info('preparing labels')
+         logger.info('========dataset len: {}=========='.format(self.dataset_len))
+         if store_labels:
+             if self.npmemmap:
+                 self.label_list = [load_numpy_label(p+'.npy', inds, tot) for p in label_paths]
+             else:
+                 self.label_list = [load_label(p, inds, tot) for p in label_paths]
+         else:
+             self.label_paths = label_paths
+             # self.label_offsets_list = [
+             #     load_label_offset(p, inds, tot) for p in label_paths
+             # ]
+         assert label_processors is None or len(label_processors) == self.num_labels
+
+
+         self.max_sample_size = (
+             max_sample_size if max_sample_size is not None else sys.maxsize
+         )
+         self.pad_audio = pad_audio
+         self.normalize = normalize
+         logger.info(
+             f"pad_audio={pad_audio}, random_crop={random_crop}, "
+             f"normalize={normalize}, max_sample_size={self.max_sample_size}"
+         )
+
+         self.augmentation_effects = augmentation_effects
+         self.augmentation_probs = augmentation_probs
+
+
+         self.inbatch_noise_augment_len_range = inbatch_noise_augment_len_range
+         self.inbatch_noise_augment_number_range = inbatch_noise_augment_number_range
+         self.inbatch_noise_augment_volume = inbatch_noise_augment_volume
+
+
+         self.cqt_prediction_bin = cqt_prediction_bin
+         if self.cqt_prediction_bin > 0:
+             self.encoder_cqt_model = model_cqt_pred(n_bins=self.cqt_prediction_bin)
+             logger.info('preparing cqt loss objective in dataloader with cpu')
+
+         self.epoch = -1
+
+         self.reader = Read_and_PadCrop_Normalized_T(n_samples=clip_secs*sample_rate if clip_secs>0 else None, sample_rate = self.sample_rate)
+
+
+
+     @property
+     def can_reuse_epoch_itr_across_epochs(self):
+         pass
+     def set_epoch(self, epoch):
+         pass
+
+     def inbatch_noise_augment(self,
+                               target_audio: torch.Tensor, target_audio_idx: int,
+                               batch_audios: torch.Tensor,  # [bsz, audio_lengths]
+                               noise_len_min: int, noise_len_max: int,
+                               n_noise_min: int, n_noise_max: int,
+                               noise_vol: float = 1.0):
+         pass
+
+     def get_audio_by_slice(self,index):
+         pass
+     def get_audio(self, index):
+         pass
+
+     def get_label(self, index, label_idx):
+         pass
+
+     def get_labels(self, index):
+         pass
+
+     def __getitem__(self, i):
+         pass
+
+     def __len__(self):
+         return self.dataset_len
+
+     def crop_to_max_size(self, wav, target_size):
+         pass
+
+     def collater(self, samples):
+         pass
+
+     def collater_audio(self, audios, audio_size):
+         pass
+
+     def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
+         pass
+
+     def collater_seq_label(self, targets, pad):
+         pass
+
+     def collater_label(self, targets_by_label, audio_size, audio_starts):
+         pass
+
+     def num_tokens(self, index):
+         pass
+
+     def size(self, index):
+         pass
+
+     def ordered_indices(self):
+         pass
+
+     def postprocess(self, wav, cur_sample_rate):
+         pass
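`load_audio_by_json` above expects a JSONL manifest with one object per line; the loader only reads the 'duration' key (in seconds), so the other fields in this sketch ('path', 'sample_rate') are illustrative assumptions. A self-contained sketch of its filtering and sizing behavior:

import json

# Hypothetical manifest entries in the JSONL format load_audio_by_json parses.
manifest = [
    {"path": "audio/track_001.wav", "duration": 212.4, "sample_rate": 24000},
    {"path": "audio/track_002.wav", "duration": 3.1, "sample_rate": 24000},
]

tgt_sample_rate, min_keep, clip_secs = 24000, 120000, 5
datas, sizes = [], []
for line in (json.dumps(d) for d in manifest):
    data = json.loads(line)
    # entries shorter than min_keep samples are skipped, mirroring the loader's filter
    if tgt_sample_rate * data["duration"] < min_keep:
        continue
    datas.append(data)
    # with clip_secs > 0, every kept item is sized as a fixed-length crop
    sizes.append(int(tgt_sample_rate * clip_secs))

print(len(datas), sizes)  # 1 [120000]
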
MuCodec/muq_dev/muq_fairseq/data/utils/data_utils.py ADDED
@@ -0,0 +1,535 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import logging
+ import math
+ import numpy as np
+ import torch
+
+ from typing import Optional, Tuple
+
+
+
+ logger = logging.getLogger(__name__)
+
+
+
+ def compute_mask_indices(
+     shape: Tuple[int, int],
+     padding_mask: Optional[torch.Tensor],
+     mask_prob: float,
+     mask_length: int,
+     mask_type: str = "static",
+     mask_other: float = 0.0,
+     min_masks: int = 0,
+     no_overlap: bool = False,
+     min_space: int = 0,
+     require_same_masks: bool = True,
+     mask_dropout: float = 0.0,
+     add_masks: bool = False,
+     seed: Optional[int] = None,
+     epoch: Optional[int] = None,
+     indices: Optional[torch.Tensor] = None,
+     idc_select_ver: int = 1,  # 2 to reproduce mask_tokens_dataset
+     num_mask_ver: int = 2,  # 2 to reproduce mask_tokens_dataset
+ ) -> np.ndarray:
+     """
+     Computes random mask spans for a given shape
+
+     Args:
+         shape: the shape for which to compute masks.
+             should be of size 2 where first element is batch size and 2nd is timesteps
+         padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+         mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
+             number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
+             however due to overlaps, the actual number will be smaller (unless no_overlap is True)
+         mask_type: how to compute mask lengths
+             static = fixed size
+             uniform = sample from uniform distribution [mask_other, mask_length*2]
+             normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+             poisson = sample from poisson distribution with lambda = mask length
+         min_masks: minimum number of masked spans
+         no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
+         min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
+         require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
+         mask_dropout: randomly dropout this percentage of masks in each example
+     """
+
+     bsz, all_sz = shape
+     mask = np.full((bsz, all_sz), False)
+
+     if num_mask_ver == 1:
+         all_num_mask = int(
+             # add a random number for probabilistic rounding
+             mask_prob * all_sz / float(mask_length)
+             + np.random.rand()
+         )
+         all_num_mask = max(min_masks, all_num_mask)
+
+     mask_idcs = []
+     for i in range(bsz):
+         if seed is not None and epoch is not None and indices is not None:
+             seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
+         else:
+             seed_i = None
+
+         rng = np.random.default_rng(seed_i)
+
+         if padding_mask is not None:
+             sz = all_sz - padding_mask[i].long().sum().item()
+             assert sz >= 0, sz
+         else:
+             sz = all_sz
+
+         if num_mask_ver == 1:
+             if padding_mask is not None:
+                 num_mask = int(
+                     # add a random number for probabilistic rounding
+                     mask_prob * sz / float(mask_length)
+                     + np.random.rand()
+                 )
+                 num_mask = max(min_masks, num_mask)
+             else:
+                 num_mask = all_num_mask
+         elif num_mask_ver == 2:
+             num_mask = int(
+                 # add a random number for probabilistic rounding
+                 mask_prob * sz / float(mask_length)
+                 + rng.random()
+             )
+             num_mask = max(min_masks, num_mask)
+         else:
+             raise ValueError()
+
+         if mask_type == "static":
+             lengths = np.full(num_mask, mask_length)
+         elif mask_type == "uniform":
+             lengths = rng.integers(mask_other, mask_length * 2 + 1, size=num_mask)
+         elif mask_type == "normal":
+             lengths = rng.normal(mask_length, mask_other, size=num_mask)
+             lengths = [max(1, int(round(x))) for x in lengths]
+         elif mask_type == "poisson":
+             lengths = rng.poisson(mask_length, size=num_mask)
+             lengths = [int(round(x)) for x in lengths]
+         else:
+             raise Exception("unknown mask selection " + mask_type)
+
+         if sum(lengths) == 0:
+             if mask_type == "static":
+                 raise ValueError("this should never happen")
+             else:
+                 lengths = [min(mask_length, sz - 1)]
+
+         if no_overlap:
+             mask_idc = []
+
+             def arrange(s, e, length, keep_length):
+                 span_start = rng.integers(s, e - length)
+                 mask_idc.extend(span_start + i for i in range(length))
+
+                 new_parts = []
+                 if span_start - s - min_space >= keep_length:
+                     new_parts.append((s, span_start - min_space + 1))
+                 if e - span_start - length - min_space > keep_length:
+                     new_parts.append((span_start + length + min_space, e))
+                 return new_parts
+
+             parts = [(0, sz)]
+             min_length = min(lengths)
+             for length in sorted(lengths, reverse=True):
+                 lens = np.fromiter(
+                     (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                     int,
+                 )
+                 l_sum = np.sum(lens)
+                 if l_sum == 0:
+                     break
+                 probs = lens / np.sum(lens)
+                 c = rng.choice(len(parts), p=probs)
+                 s, e = parts.pop(c)
+                 parts.extend(arrange(s, e, length, min_length))
+             mask_idc = np.asarray(mask_idc)
+         else:
+             if idc_select_ver == 1:
+                 min_len = min(lengths)
+                 if sz - min_len <= num_mask:
+                     min_len = sz - num_mask - 1
+                 mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
+             elif idc_select_ver == 2:
+                 mask_idc = rng.choice(sz, num_mask, replace=False)
+             else:
+                 raise ValueError()
+
+             mask_idc = np.asarray(
+                 [
+                     mask_idc[j] + offset
+                     for j in range(len(mask_idc))
+                     for offset in range(lengths[j])
+                 ]
+             )
+
+         mask_idc = np.unique(mask_idc[mask_idc < sz])
+         if len(mask_idc) >= sz:
+             raise ValueError(
+                 (
+                     f"the entire sequence is masked. "
+                     f"sz={sz}; mask_idc={mask_idc}; "
+                     f"index={indices[i] if indices is not None else None}"
+                 )
+             )
+         mask_idcs.append(mask_idc)
+
+     target_len = None
+     if require_same_masks:
+         if add_masks:
+             target_len = max([len(m) for m in mask_idcs])
+         else:
+             target_len = min([len(m) for m in mask_idcs])
+
+     for i, mask_idc in enumerate(mask_idcs):
+         if target_len is not None and len(mask_idc) > target_len:
+             mask_idc = rng.choice(mask_idc, target_len, replace=False)
+
+         mask[i, mask_idc] = True
+
+         if target_len is not None and len(mask_idc) < target_len:
+             unmasked = np.flatnonzero(~mask[i])
+             to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
+             mask[i, to_mask] = True
+
+         if mask_dropout > 0:
+             masked = np.flatnonzero(mask[i])
+             num_holes = np.rint(len(masked) * mask_dropout).astype(int)
+             to_drop = rng.choice(masked, num_holes, replace=False)
+             mask[i, to_drop] = False
+
+     return mask
+
+
+ def compute_block_mask_2d(
+     shape: Tuple[int, int],
+     mask_prob: float,
+     mask_length: int,
+     mask_prob_adjust: float = 0,
+     inverse_mask: bool = False,
+     require_same_masks: bool = True,
+     expand_adjcent: bool = False,
+     mask_dropout: float = 0,
+     non_overlapping: bool = False,
+     img_shape: tuple = None,  # for the situation when d[0] != d[1], especially for audio spectrograms
+     flexible_mask: bool = False,
+ ) -> torch.Tensor:
+
+     assert mask_length > 1
+
+     B, L = shape
+
+     d = (int(L**0.5), int(L**0.5))
+
+     if img_shape:
+         d = (img_shape[0], img_shape[1])
+
+     if flexible_mask:
+         index = np.random.randint(0, 3)
+         block_size_options = np.array([(6, 4), (5, 5), (8, 3)])
+         block_size = block_size_options[index]
+
+     if inverse_mask:
+         mask_prob = 1 - mask_prob
+
+     if flexible_mask:
+         mask = torch.zeros((B, d[0], d[1]))
+         mask_inds = torch.randint(
+             0,
+             L,
+             size=(
+                 B,
+                 int(
+                     L
+                     * ((mask_prob + mask_prob_adjust) / (block_size[0] * block_size[1]))
+                     * (1 + mask_dropout)
+                 ),
+             ),
+         )
+         mask.view(B, -1).scatter_(1, mask_inds, 1)
+         centers = mask.nonzero(as_tuple=True)
+
+         inds = ([], [], [])
+
+         offset = mask_length // 2
+         for i in range(block_size[0]):
+             for j in range(block_size[1]):
+                 k1 = i - offset
+                 k2 = j - offset
+                 inds[0].append(centers[0])
+                 inds[1].append(centers[1] + k1)
+                 inds[2].append(centers[2] + k2)
+
+         i0 = torch.cat(inds[0])
+         i1 = torch.cat(inds[1]).clamp_(min=0, max=d[0] - 1)
+         i2 = torch.cat(inds[2]).clamp_(min=0, max=d[1] - 1)
+
+         mask[(i0, i1, i2)] = 1
+
+     elif non_overlapping:
+         sz = math.ceil(d[0] / mask_length)
+         inp_len = sz * sz
+
+         inp = torch.zeros((B, 1, sz, sz))
+         w = torch.ones((1, 1, mask_length, mask_length))
+
+         mask_inds = torch.multinomial(
+             1 - inp.view(B, -1),
+             int(inp_len * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
+             replacement=False,
+         )
+         inp.view(B, -1).scatter_(1, mask_inds, 1)
+
+         mask = torch.nn.functional.conv_transpose2d(inp, w, stride=mask_length).squeeze(
+             1
+         )
+         if mask.size(-1) > d[0]:
+             mask = mask[..., :d[0], :d[1]]
+     else:
+         mask = torch.zeros((B, d[0], d[1]))
+         mask_inds = torch.randint(
+             0,
+             L,
+             size=(
+                 B,
+                 int(
+                     L
+                     * ((mask_prob + mask_prob_adjust) / mask_length**2)
+                     * (1 + mask_dropout)
+                 ),
+             ),
+         )
+         mask.view(B, -1).scatter_(1, mask_inds, 1)
+         centers = mask.nonzero(as_tuple=True)
+
+         inds = ([], [], [])
+
+         offset = mask_length // 2
+         for i in range(mask_length):
+             for j in range(mask_length):
+                 k1 = i - offset
+                 k2 = j - offset
+                 inds[0].append(centers[0])
+                 inds[1].append(centers[1] + k1)
+                 inds[2].append(centers[2] + k2)
+
+         i0 = torch.cat(inds[0])
+         i1 = torch.cat(inds[1]).clamp_(min=0, max=d[0] - 1)
+         i2 = torch.cat(inds[2]).clamp_(min=0, max=d[1] - 1)
+
+         mask[(i0, i1, i2)] = 1
+
+     def get_nbs(b, m, w):
+         all_nbs = torch.nn.functional.conv2d(m.unsqueeze(1), w, padding="same")
+         all_nbs = all_nbs.clamp_max_(1).view(b, -1)
+         return all_nbs
+
+     if require_same_masks and expand_adjcent:
+         w = torch.zeros((1, 1, 3, 3))
+         w[..., 0, 1] = 1
+         w[..., 2, 1] = 1
+         w[..., 1, 0] = 1
+         w[..., 1, 2] = 1
+
+         all_nbs = get_nbs(B, mask, w)
+
+     mask = mask.reshape(B, -1)
+
+     if require_same_masks:
+         n_masks = mask.sum(dim=-1)
+         final_target_len = int(L * (mask_prob))
+         target_len = int(final_target_len * (1 + mask_dropout))
+
+         for i in range(len(mask)):
+             n = n_masks[i]
+             m = mask[i]
+             r = 0
+             while expand_adjcent and n < target_len:
+                 if r == 0:
+                     nbs = all_nbs[i]
+                 else:
+                     nbs = get_nbs(1, m.view(1, d[0], d[1]), w).flatten()
+
+                 cands = (1 - m + nbs) > 1
+                 cand_sz = int(cands.sum().item())
+
+                 assert cand_sz > 0, f"{nbs} {cand_sz}"
+
+                 to_mask = torch.multinomial(
+                     cands.float(), min(cand_sz, int(target_len - n)), replacement=False
+                 )
+                 m[to_mask] = 1
+                 assert to_mask.numel() > 0
+                 n += to_mask.numel()
+                 r += 1
+
+             if n > final_target_len:
+                 to_unmask = torch.multinomial(
+                     m, int(n - final_target_len), replacement=False
+                 )
+                 m[to_unmask] = 0
+             elif n < final_target_len:
+                 to_mask = torch.multinomial(
+                     (1 - m), int(final_target_len - n), replacement=False
+                 )
+                 m[to_mask] = 1
+
+     if inverse_mask:
+         mask = 1 - mask
+
+     return mask
+
+
+ def compute_block_mask_1d(
+     shape: Tuple[int, int],
+     mask_prob: float,
+     mask_length: int,
+     mask_prob_adjust: float = 0,
+     inverse_mask: bool = False,
+     require_same_masks: bool = True,
+     expand_adjcent: bool = False,
+     mask_dropout: float = 0,
+     non_overlapping: bool = False,
+ ) -> torch.Tensor:
+
+     B, L = shape
+
+     if inverse_mask:
+         mask_prob = 1 - mask_prob
+
+     if non_overlapping:
+         sz = math.ceil(L / mask_length)
+
+         inp = torch.zeros((B, 1, sz))
+         w = torch.ones((1, 1, mask_length))
+
+         mask_inds = torch.multinomial(
+             1 - inp.view(B, -1),
+             int(sz * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
+             replacement=False,
+         )
+         inp.view(B, -1).scatter_(1, mask_inds, 1)
+
+         mask = torch.nn.functional.conv_transpose1d(inp, w, stride=mask_length).squeeze(
+             1
+         )
+         if mask.size(-1) > L:
+             mask = mask[..., :L]
+
+     else:
+         mask = torch.zeros((B, L))
+         mask_inds = torch.randint(
+             0,
+             L,
+             size=(
+                 B,
+                 int(
+                     L
+                     * ((mask_prob + mask_prob_adjust) / mask_length)
+                     * (1 + mask_dropout)
+                 ),
+             ),
+         )
+
+         mask.view(B, -1).scatter_(1, mask_inds, 1)
+         centers = mask.nonzero(as_tuple=True)
+
+         inds = ([], [])
+
+         offset = mask_length // 2
+         for i in range(mask_length):
+             k1 = i - offset
+             inds[0].append(centers[0])
+             inds[1].append(centers[1] + k1)
+
+         i0 = torch.cat(inds[0])
+         i1 = torch.cat(inds[1]).clamp_(min=0, max=L - 1)
+
+         mask[(i0, i1)] = 1
+
+     def get_nbs(b, m, w):
+         all_nbs = torch.nn.functional.conv1d(m.unsqueeze(1), w, padding="same")
+         all_nbs = all_nbs.clamp_max_(1).view(b, -1)
+         return all_nbs
+
+     if require_same_masks and expand_adjcent:
+         w = torch.ones((1, 1, 3))
+         w[..., 1] = 0
+         all_nbs = get_nbs(B, mask, w)
+
+     mask = mask.view(B, -1)
+
+     if require_same_masks:
+         n_masks = mask.sum(dim=-1)
+         final_target_len = int(L * (mask_prob))
+         target_len = int(final_target_len * (1 + mask_dropout))
+
+         for i in range(len(mask)):
+             n = n_masks[i]
+             m = mask[i]
+             r = 0
+             while expand_adjcent and n < target_len:
+                 if r == 0:
+                     nbs = all_nbs[i]
+                 else:
+                     nbs = get_nbs(1, m.unsqueeze(0), w).squeeze(0)
+
+                 cands = (1 - m + nbs) > 1
+                 cand_sz = int(cands.sum().item())
+
+                 assert cand_sz > 0, f"{nbs} {cand_sz}"
+
+                 to_mask = torch.multinomial(
+                     cands.float(), min(cand_sz, int(target_len - n)), replacement=False
+                 )
+                 m[to_mask] = 1
+                 assert to_mask.numel() > 0
+                 n += to_mask.numel()
+                 r += 1
+
+             if n > final_target_len:
+                 to_unmask = torch.multinomial(
+                     m, int(n - final_target_len), replacement=False
+                 )
+                 m[to_unmask] = 0
+             elif n < final_target_len:
+                 to_mask = torch.multinomial(
+                     (1 - m), int(final_target_len - n), replacement=False
+                 )
+                 m[to_mask] = 1
+
+     if inverse_mask:
+         mask = 1 - mask
+
+     return mask
+
+
+ def get_buckets(sizes, num_buckets):
+     buckets = np.unique(
+         np.percentile(
+             sizes,
+             np.linspace(0, 100, num_buckets + 1),
+             interpolation="lower",
+         )[1:]
+     )
+     return buckets
+
+
+ def get_bucketed_sizes(orig_sizes, buckets):
+     sizes = np.copy(orig_sizes)
+     assert np.min(sizes) >= 0
+     start_val = -1
+     for end_val in buckets:
+         mask = (sizes > start_val) & (sizes <= end_val)
+         sizes[mask] = end_val
+         start_val = end_val
+     return sizes
+
+
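A usage sketch for `compute_mask_indices` on its default path (mask_type="static", num_mask_ver=2, idc_select_ver=1), assuming the definitions above are importable; the shapes and probabilities are illustrative wav2vec-style settings, not values taken from this repository:

import numpy as np

mask = compute_mask_indices(
    shape=(2, 100),      # batch of 2 sequences, 100 timesteps each
    padding_mask=None,
    mask_prob=0.65,      # probability a timestep starts a masked span
    mask_length=10,      # each span covers 10 timesteps
    min_masks=2,
)
assert mask.shape == (2, 100) and mask.dtype == bool
# require_same_masks=True (the default) equalizes the mask count per row
assert mask[0].sum() == mask[1].sum()
print(mask[0].astype(int))
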
MuCodec/muq_dev/muq_fairseq/models/muq/__init__.py ADDED
@@ -0,0 +1 @@
+ from .muq_model import *
MuCodec/muq_dev/muq_fairseq/models/muq/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (203 Bytes).
 
MuCodec/muq_dev/muq_fairseq/models/muq/__pycache__/muq_model.cpython-310.pyc ADDED
Binary file (4.96 kB).
 
MuCodec/muq_dev/muq_fairseq/models/muq/model/__init__.py ADDED
@@ -0,0 +1,2 @@
+
+
MuCodec/muq_dev/muq_fairseq/models/muq/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (183 Bytes).
 
MuCodec/muq_dev/muq_fairseq/models/muq/model/__pycache__/muq.cpython-310.pyc ADDED
Binary file (15.8 kB).
 
MuCodec/muq_dev/muq_fairseq/models/muq/model/__pycache__/rvq.cpython-310.pyc ADDED
Binary file (14 kB).
 
MuCodec/muq_dev/muq_fairseq/models/muq/model/__pycache__/rvq_muq.cpython-310.pyc ADDED
Binary file (13.1 kB).
 
MuCodec/muq_dev/muq_fairseq/models/muq/model/muq.py ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ import torch
4
+ from torch import nn
5
+ from einops import rearrange
6
+ import os
7
+ from fairseq.data.data_utils import compute_mask_indices
8
+ from fairseq.models.wav2vec.wav2vec2 import ConvFeatureExtractionModel
9
+ from fairseq.modules import LayerNorm
10
+
11
+ try:
12
+ from ..modules.random_quantizer import RandomProjectionQuantizer
13
+ from ..modules.features import MelSTFT
14
+ from ..modules.conv import Conv2dSubsampling
15
+ except:
16
+ import sys, os
17
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
18
+ from modules.random_quantizer import RandomProjectionQuantizer
19
+ from modules.features import MelSTFT
20
+ from modules.conv import Conv2dSubsampling
21
+
22
+
23
+ class MuQ(nn.Module):
24
+ """
25
+ MuQ
26
+
27
+ Input: 128-band mel spectrogram
28
+ Frontend: 2-layer Residual convolution
29
+ Backend: 12-layer Conformer
30
+ Quantizer: a codebook for mel spectrogram
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ num_codebooks=1,
36
+ codebook_dim=16,
37
+ codebook_size=4096,
38
+ features=["melspec_2048"],
39
+ hop_length=240,
40
+ n_mels=128,
41
+ conv_dim=512,
42
+ encoder_dim=1024,
43
+ encoder_depth=12,
44
+ mask_hop=0.4,
45
+ mask_prob=0.6,
46
+ is_flash=False,
47
+ stat_path=None, #"./data/fma_stats.json",
48
+ model_path=None, #"./data/pretrained_fma.pt",
49
+ w2v2_config_path=None, #"facebook/wav2vec2-conformer-rope-large-960h-ft",
50
+ use_rvq_target=False,
51
+ use_vq_target=False,
52
+ rvq_ckpt_path=None,
53
+ recon_loss_ratio=None,
54
+ label_rate=25,
55
+ use_hubert_masking_strategy=False,
56
+ use_hubert_featurizer=False,
57
+ hubert_conv_feature_layers="[(512,10,5)] + [(512,3,2)] * 3 + [(512,3,3)] + [(512,2,2)] * 2",
58
+ use_hubert_nce_loss=False,
59
+ hubert_final_dim=256,
60
+ rvq_n_codebooks=8,
61
+ rvq_multi_layer_num=1,
62
+ use_encodec_target=False,
63
+ ):
64
+ super(MuQ, self).__init__()
65
+
66
+ # global variables
67
+ self.hop_length = hop_length
68
+ self.mask_hop = mask_hop
69
+ self.mask_prob = mask_prob
70
+ self.num_codebooks = num_codebooks
71
+ self.codebook_size = codebook_size
72
+ self.features = features
73
+ self.recon_loss_ratio = recon_loss_ratio
74
+ self.n_fold = int(100//label_rate)
75
+ self.label_rate = label_rate
76
+ self.use_hubert_masking_strategy = use_hubert_masking_strategy
77
+ self.use_hubert_featurizer = use_hubert_featurizer
78
+ self.use_hubert_nce_loss = use_hubert_nce_loss
79
+
80
+ # load feature mean / std stats
81
+ import os
82
+ if stat_path is not None and os.path.exists(stat_path):
83
+ with open(stat_path, "r") as f:
84
+ self.stat = json.load(f)
85
+ else:
86
+ # print("No stats file found at `{}`, use default from msd.".format(stat_path))
87
+ self.stat = {"spec_256_cnt": 14394344256, "spec_256_mean": -23.34296658431829, "spec_256_std": 26.189295587132637, "spec_512_cnt": 28677104448, "spec_512_mean": -21.31267396860235, "spec_512_std": 26.52644536245769, "spec_1024_cnt": 57242624832, "spec_1024_mean": -18.852271129208273, "spec_1024_std": 26.443154583585663, "spec_2048_cnt": 114373665600, "spec_2048_mean": -15.638743433896792, "spec_2048_std": 26.115825961611545, "spec_4096_cnt": 228635747136, "spec_4096_mean": -11.715532502794836, "spec_4096_std": 25.763972210234062, "melspec_256_cnt": 14282760192, "melspec_256_mean": -26.962600400166156, "melspec_256_std": 36.13614100912126, "melspec_512_cnt": 14282760192, "melspec_512_mean": -9.108344167718862, "melspec_512_std": 24.71910937988429, "melspec_1024_cnt": 14282760192, "melspec_1024_mean": 0.37302579246531126, "melspec_1024_std": 18.684082325919388, "melspec_2048_cnt": 14282760192, "melspec_2048_mean": 6.768444971712967, "melspec_2048_std": 18.417922652295623, "melspec_4096_cnt": 14282760192, "melspec_4096_mean": 13.617164614990036, "melspec_4096_std": 18.08552130124525, "cqt_cnt": 9373061376, "cqt_mean": 0.46341379757927165, "cqt_std": 0.9543998080910191, "mfcc_256_cnt": 1339008768, "mfcc_256_mean": -11.681755459447485, "mfcc_256_std": 29.183186444668316, "mfcc_512_cnt": 1339008768, "mfcc_512_mean": -2.540581461792183, "mfcc_512_std": 31.93752185832081, "mfcc_1024_cnt": 1339008768, "mfcc_1024_mean": 6.606636263169779, "mfcc_1024_std": 34.151644801729624, "mfcc_2048_cnt": 1339008768, "mfcc_2048_mean": 5.281600844245184, "mfcc_2048_std": 33.12784541220003, "mfcc_4096_cnt": 1339008768, "mfcc_4096_mean": 4.7616569480166095, "mfcc_4096_std": 32.61458906894133, "chromagram_256_cnt": 1339008768, "chromagram_256_mean": 55.15596556703181, "chromagram_256_std": 73.91858278719991, "chromagram_512_cnt": 1339008768, "chromagram_512_mean": 175.73092252759895, "chromagram_512_std": 248.48485148525953, "chromagram_1024_cnt": 1339008768, "chromagram_1024_mean": 589.2947481634608, "chromagram_1024_std": 913.857929063196, "chromagram_2048_cnt": 1339008768, "chromagram_2048_mean": 2062.286388327397, "chromagram_2048_std": 3458.92657915397, "chromagram_4096_cnt": 1339008768, "chromagram_4096_mean": 7673.039107997085, "chromagram_4096_std": 13009.883158267234}
88
+
89
+ # feature extractor
90
+ self.preprocessor_melspec_2048 = MelSTFT(
91
+ n_fft=2048, hop_length=hop_length, is_db=True
92
+ )
93
+
94
+ # random quantizer
95
+ self.use_rvq_target = use_rvq_target
96
+ self.use_vq_target = use_vq_target
97
+ self.use_encodec_target = use_encodec_target
98
+
99
+ seed = 142
100
+ if self.use_rvq_like_target:
101
+ if use_rvq_target:
102
+ try:
103
+ from .rvq_muq import ResidualVectorQuantize
104
+ except:
105
+ import sys, os
106
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
107
+ from rvq_muq import ResidualVectorQuantize
108
+
109
+ inp_dim = 128*self.n_fold
110
+ self.rvq = ResidualVectorQuantize(
111
+ input_dim = inp_dim,
112
+ n_codebooks = rvq_n_codebooks,
113
+ codebook_size = 1024,
114
+ codebook_dim = 16,
115
+ quantizer_dropout = 0.0,
116
+ use_multi_layer_num = rvq_multi_layer_num,
117
+ )
118
+ elif use_vq_target:
119
+ try:
120
+ from .rvq_muq import VectorQuantize
121
+ except:
122
+ import sys, os
123
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
124
+ from rvq_muq import VectorQuantize
125
+
126
+ self.rvq = VectorQuantize(
127
+ input_dim = 128*self.n_fold,
128
+ codebook_size = 1024,
129
+ codebook_dim = 8,
130
+ stale_tolerance = 1000,
131
+ mfcc_clustering = False
132
+ )
133
+ elif use_encodec_target:
134
+ from encodec import EncodecModel
135
+ self.rvq = EncodecModel.encodec_model_24khz()
136
+ self.rvq.set_target_bandwidth(6.0)
137
+ for param in self.rvq.parameters():
138
+ param.requires_grad = False
139
+
140
+ import os
141
+ if rvq_ckpt_path is not None and os.path.exists(rvq_ckpt_path):
142
+ state_dict = torch.load(rvq_ckpt_path, map_location="cpu")
143
+ self.rvq.load_state_dict(state_dict)
144
+ else:
145
+ print(f'Checkpoint for rvq `{rvq_ckpt_path}` not found. Using random initialization.')
146
+ else:
147
+ for feature in self.features:
148
+ for i in range(num_codebooks):
149
+ setattr(
150
+ self,
151
+ f"quantizer_{feature}", # _{i}
152
+ RandomProjectionQuantizer(
153
+ n_mels * self.n_fold, codebook_dim, codebook_size, seed=seed + i
154
+ ),
155
+ )
156
+
157
+ if use_hubert_masking_strategy:
158
+ self.mask_emb = nn.Parameter(
159
+ torch.FloatTensor(encoder_dim).uniform_()
160
+ )
161
+
162
+ if use_hubert_featurizer:
163
+ feature_enc_layers = eval(hubert_conv_feature_layers) # noqa
164
+ hubert_feat_embed = feature_enc_layers[-1][0]
165
+ self.hubert_feature_extractor = ConvFeatureExtractionModel(
166
+ conv_layers=feature_enc_layers,
167
+ dropout=0.0,
168
+ mode='default', #cfg.extractor_mode,
169
+ conv_bias=False, #cfg.conv_bias,
170
+ )
171
+ self.post_extract_proj = (
172
+ nn.Linear(hubert_feat_embed, encoder_dim)
173
+ if hubert_feat_embed != encoder_dim
174
+ else None
175
+ )
176
+ self.layer_norm = LayerNorm(hubert_feat_embed)
177
+ else:
178
+ # two residual convolution layers + one projection layer
179
+ strides_factory = {
180
+ 4: [2, 2],
181
+ 2: [2, 1]
182
+ }
183
+ self.conv = Conv2dSubsampling(
184
+ 1, conv_dim, encoder_dim, strides=strides_factory.get(self.n_fold), n_bands=n_mels
185
+ )
186
+
187
+ # Conformer
188
+ if is_flash:
189
+ from modules.flash_conformer import (
190
+ Wav2Vec2ConformerEncoder,
191
+ Wav2Vec2ConformerConfig,
192
+ )
193
+ else:
194
+ from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
195
+ Wav2Vec2ConformerEncoder,
196
+ Wav2Vec2ConformerConfig,
197
+ )
198
+ import os
199
+ if w2v2_config_path is None or not os.path.exists(w2v2_config_path):
200
+ w2v2_config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "w2v2_config.json")
201
+ print("load w2v2 config from:", w2v2_config_path)
202
+ config = Wav2Vec2ConformerConfig.from_pretrained(
203
+ w2v2_config_path
204
+ )
205
+ config.num_hidden_layers = encoder_depth
206
+ config.hidden_size = encoder_dim
207
+
208
+ self.conformer = Wav2Vec2ConformerEncoder(config)
209
+
210
+ if self.use_hubert_nce_loss:
211
+ self.label_embs_concat = nn.Parameter(
212
+ torch.FloatTensor(codebook_size, hubert_final_dim)
213
+ ) # embeddings of codes
214
+ nn.init.uniform_(self.label_embs_concat)
215
+ self.linear = nn.Linear(encoder_dim, hubert_final_dim) # final_proj
216
+ else:
217
+ # projection
218
+ self.linear = nn.Linear(encoder_dim, codebook_size) # N_SubSpec=8
219
+
220
+ # reconstruct melspec
221
+ if self.recon_loss_ratio is not None and self.recon_loss_ratio > 0:
222
+ self.recon_proj = nn.Linear(encoder_dim, n_mels * self.n_fold)
223
+ self.recon_loss = nn.MSELoss()
224
+
225
+ # loss function
226
+ self.loss = nn.CrossEntropyLoss()
227
+
228
+ # cls token (used for sequence classification)
229
+ random.seed(seed)
230
+ self.cls_token = nn.Parameter(torch.randn(encoder_dim))
231
+
232
+ # load model
233
+ if model_path:
234
+ S = torch.load(model_path)["state_dict"]
235
+ SS = {k[6:]: v for k, v in S.items()}
236
+ SS['quantizer_melspec_2048.random_projection'] = SS['quantizer_melspec_2048_0.random_projection']
237
+ SS['quantizer_melspec_2048.codebook'] = SS['quantizer_melspec_2048_0.codebook']
238
+ del SS['quantizer_melspec_2048_0.random_projection']
239
+ del SS['quantizer_melspec_2048_0.codebook']
240
+ unmatch = self.load_state_dict(SS, strict=False)
241
+ if len(unmatch.missing_keys) > 0:
242
+ print(f'Missing keys: {unmatch.missing_keys}')
243
+
244
+ @property
245
+ def use_rvq_like_target(self):
246
+ return self.use_rvq_target or self.use_vq_target or self.use_encodec_target
247
+
248
+
249
+ def apply_hubert_mask(self, x, padding_mask=None, target_list=None):
250
+ B, T, C = x.shape
251
+ if self.mask_prob > 0:
252
+ mask_length = int(self.mask_hop / (1/self.label_rate))
253
+ mask_indices = compute_mask_indices(
254
+ (B, T),
255
+ padding_mask,
256
+ self.mask_prob,
257
+ mask_length, # self.mask_length,
258
+ "static", #self.mask_selection,
259
+ 0, #self.mask_other,
260
+ min_masks=2,
261
+ no_overlap=False, #self.no_mask_overlap,
262
+ min_space=1, #self.mask_min_space,
263
+ )
264
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
265
+ x[mask_indices] = self.mask_emb
266
+ mask_indices = torch.nonzero(mask_indices)
267
+ else:
268
+ mask_indices = None
269
+
270
+ return x, mask_indices
271
+
272
+ def masking(self, x, attention_mask=None):
273
+ """random masking of 400ms with given probability"""
274
+ if self.use_hubert_masking_strategy:
275
+ return x, None
276
+ mx = x.clone()
277
+ b, t = mx.shape
278
+ len_masking_raw = int(24000 * self.mask_hop) # 9600 = 24000 * 0.4
279
+ len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop) # 10 = 25Hz * 0.4
280
+
281
+ # get random mask indices
282
+ start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
283
+ time_domain_masked_indices = torch.nonzero(
284
+ start_indices.repeat_interleave(len_masking_raw, dim=1)
285
+ )
286
+ token_domain_masked_indices = torch.nonzero(
287
+ start_indices.repeat_interleave(len_masking_token, dim=1)
288
+ )
289
+
290
+ # mask with random values
291
+ masking_noise = (
292
+ torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
293
+ ) # 0 mean 0.1 std
294
+ mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)
295
+
296
+ return mx, token_domain_masked_indices
297
+
298
+
299
+ @torch.no_grad()
300
+ def preprocessing(self, x, features):
301
+ """extract classic audio features"""
302
+ # check precision
303
+ if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
304
+ precision = 16
305
+ else:
306
+ precision = 32
307
+
308
+ out = {}
309
+ for key in features:
310
+ layer = getattr(self, "preprocessor_%s" % key)
311
+ layer.to(x.device)
312
+ dtype = x.dtype
313
+ out[key] = layer.float()(x.float())[..., :-1]
314
+ if precision == 16:
315
+ out[key] = out[key].half()
316
+ if out[key].dtype != dtype:
317
+ out[key].to(dtype=dtype)
318
+ return out
319
+
320
+ def encoder(self, x, *, attention_mask=None, is_features_only=False):
321
+ """2-layer conv + w2v-conformer"""
322
+ if not self.use_hubert_featurizer:
323
+ x = self.conv(x) # [3, 128, 3000] -> [3, 750, 1024]
324
+ if self.training and self.use_hubert_masking_strategy and not is_features_only:
325
+ x, mask_indices = self.apply_hubert_mask(x)
326
+ else:
327
+ mask_indices = None
328
+ if attention_mask is None:
329
+ out = self.conformer(x, output_hidden_states=True)
330
+ else:
331
+ attention_mask = attention_mask.bool()
332
+ skip_n = int(attention_mask.size(-1) / x.size(1))
333
+ attention_mask = attention_mask[:, ::skip_n]
334
+ attention_mask = attention_mask[:, :x.size(1)]
335
+ out = self.conformer(x, attention_mask=attention_mask, output_hidden_states=True)
336
+ hidden_emb = out["hidden_states"]
337
+ last_emb = out["last_hidden_state"]
338
+ logits = self.linear(last_emb)
339
+ interval = self.codebook_size
340
+ logits = {
341
+ key: logits[:, :, i * interval : (i + 1) * interval]
342
+ for i, key in enumerate(self.features)
343
+ }
344
+ return logits, hidden_emb, mask_indices
345
+
346
+ @torch.no_grad()
347
+ def normalize(self, x):
348
+ """normalize the input audio to have zero mean unit variance"""
349
+ for key in x.keys():
350
+ x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key] # {'melspec_2048_cnt': 14282760192, 'melspec_2048_mean': 6.768444971712967}
351
+ return x
352
+
353
+ @torch.no_grad()
354
+ def rearrange(self, x):
355
+ """rearrange the batch to flatten every 4 steps"""
356
+ for key in x.keys():
357
+ if key == "chromagram":
358
+ x[key] = rearrange(x[key], "b f t -> b t f")
359
+ else:
360
+ x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=self.n_fold)
361
+ return x
362
+
363
+ def get_rvq_codes(self, inp, raw_wav):
364
+ if self.use_rvq_target:
365
+ quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = self.rvq(inp)
366
+ return codes
367
+ if self.use_vq_target:
368
+ quantized_prompt_embeds, commitment_loss, codebook_loss, codes, _ = self.rvq(inp)
369
+ return codes.unsqueeze(1)
370
+ if self.use_encodec_target:
371
+ encoded_frames = self.rvq.encode(raw_wav.unsqueeze(1)) #list, B,[ 8,T ]
372
+ codes = torch.cat([encoded[0].detach() for encoded in encoded_frames], dim=-1)
373
+ if self.label_rate == 25:
374
+ codes = codes[:, :, ::3]
375
+ return codes
376
+
377
+ @torch.no_grad()
378
+ def tokenize(self, x, raw_wav):
379
+ out = {}
380
+ for key in x.keys():
381
+ if self.use_rvq_like_target:
382
+ self.rvq.eval()
383
+ inp = x[key].permute((0, 2, 1))
384
+ codes = self.get_rvq_codes(inp, raw_wav)
385
+ out[key] = torch.cat([codes[:, idx, ...] for idx in range(int(self.codebook_size//1024))], dim=-1) # (when use freq mask)->[Batch, N_SubSpec, SeqLen=8*750]
386
+ else:
387
+ layer = getattr(self, "quantizer_%s" % key)
388
+ out[key] = layer(x[key])
389
+ return out
390
+
391
+ def to_spec_wise_quad(self, x):
392
+ Batch, QuadSpec, Time = x.shape
393
+ SubSpec, N_SubSpec = 16, 8
394
+ assert 4 * SubSpec * N_SubSpec == QuadSpec == 4*128
395
+ x = rearrange(x, "b (q n s) t -> b (q s) (n t)", q=4, n=N_SubSpec, s=SubSpec)
396
+ return x # [Batch, SubSpec=16, N_SubSpec*Time=8*100Hz]
397
+
398
+ def get_targets(self, x, label=None):
399
+ if self.use_encodec_target:
400
+ raw_x = x.clone()
401
+ else:
402
+ raw_x = None
403
+ x = self.preprocessing(x, features=self.features) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}}
404
+ x = self.normalize(x)
405
+ x = self.rearrange(x) # -> {'melspec_2048': Tensor{Size([3, 750, 512]) cuda:0 f32}}
406
+ melspec = x['melspec_2048']
407
+ if label is None:
408
+ target_tokens = self.tokenize(x, raw_x) # -> {'melspec_2048': Tensor{Size([3, 750]) cuda:0 i64}}
409
+ else:
410
+ # print("use_target from label")
411
+ target_tokens = {'melspec_2048': rearrange(label, "b n s -> b (n s)").long()}
412
+ return target_tokens, melspec
413
+
414
+ def get_predictions(self, x, *, mask=None, attention_mask=None, return_new_mask=False, is_features_only=False):
415
+ # preprocessing
416
+ if not self.use_hubert_featurizer:
417
+ x = self.preprocessing(x, features=["melspec_2048"])
418
+ x = self.normalize(x) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}}
419
+ else:
420
+ features = self.hubert_feature_extractor(x)
421
+ features = self.layer_norm(features.transpose(1, 2))
422
+ if self.post_extract_proj is not None:
423
+ features = self.post_extract_proj(features)
424
+ x = {"melspec_2048": features}
425
+
426
+ # encoding
427
+ logits, hidden_emb, new_mask = self.encoder(x["melspec_2048"], attention_mask=attention_mask, is_features_only=is_features_only)
428
+
429
+ if return_new_mask:
430
+ return logits, hidden_emb, mask if new_mask is None else new_mask
431
+ else:
432
+ return logits, hidden_emb
433
+
434
+ def get_latent(self, x, layer_ix=12):
435
+ _, hidden_states = self.get_predictions(x)
436
+ emb = hidden_states[layer_ix]
437
+ return emb
438
+
439
+ def compute_nce(self, x, pos, negs):
440
+ neg_is_pos = (pos == negs).all(-1)
441
+ pos = pos.unsqueeze(0)
442
+ targets = torch.cat([pos, negs], dim=0)
443
+
444
+ logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
445
+ logits /= 0.1
446
+ if neg_is_pos.any():
447
+ logits[1:][neg_is_pos] = float("-inf")
448
+ logits = logits.transpose(0, 1) # (num_x, num_cls+1)
449
+ return logits
450
+
451
+ def compute_hubert_nce_loss(self, proj_xs, targets):
452
+
453
+ label_embs_list = self.label_embs_concat.split(self.codebook_size, 0) # (self.num_classes, 0)
454
+
455
+ def compute_pred(proj_x, target, label_embs):
456
+ # compute logits for the i-th label set
457
+ y = torch.index_select(label_embs, 0, target.long())
458
+ negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
459
+ return self.compute_nce(proj_x, y, negs)
460
+
461
+ logit_list = [
462
+ compute_pred(proj_x, t, label_embs_list[i])
463
+ for i, (proj_x, t) in enumerate(zip(proj_xs, targets))
464
+ ]
465
+
466
+ return sum(logit_list)
467
+
468
+
469
+ def get_loss(self, logits, target_tokens, masked_indices):
470
+ losses = {}
471
+ accuracies = {}
472
+ for key in logits.keys():
473
+ if not self.use_rvq_like_target:
474
+ masked_logits = logits[key][tuple(masked_indices.t())]
475
+ masked_tokens = target_tokens[key][tuple(masked_indices.t())]
476
+ else:
477
+ Batch, SeqLen, N_Codebook_x_CodebookSize = logits[key].shape # CodebookSize=4096
478
+ Batch, N_Codebook_x_SeqLen = target_tokens[key].shape # N_Codebook*SeqLen=4*750
479
+ N_Codebook = int(N_Codebook_x_SeqLen // SeqLen)
480
+ # print("not use_virtual, n codebook = ", N_Codebook)
481
+ target_tokens[key] = rearrange(target_tokens[key], "b (n s) -> b s n", n=N_Codebook) # Batch, SeqLen=750, N_Codebook=4
482
+ masked_logits = logits[key][tuple(masked_indices.t())]
483
+ masked_tokens = target_tokens[key][tuple(masked_indices.t())]
484
+ masked_logits = rearrange(masked_logits, "b (n c) -> (b n) c", n=N_Codebook)
485
+ masked_tokens = rearrange(masked_tokens, "b n -> (b n)", n=N_Codebook)
486
+
487
+ if self.use_hubert_nce_loss:
488
+ losses[key] = self.compute_hubert_nce_loss(masked_logits, masked_tokens)
489
+ else:
490
+ losses[key] = self.loss(masked_logits, masked_tokens)
491
+ accuracies[key] = (
492
+ torch.sum(masked_logits.argmax(-1) == masked_tokens)
493
+ / masked_tokens.numel()
494
+ )
495
+ return losses, accuracies
496
+
497
+ def get_recon_loss(self, last_hidden_emb, melspec, masked_indices):
498
+ pred_melspec = self.recon_proj(last_hidden_emb[tuple(masked_indices.t())])
499
+ target_melspec = melspec[tuple(masked_indices.t())]
500
+ recon_loss = self.recon_loss(pred_melspec, target_melspec)
501
+ return recon_loss
502
+
503
+ def forward(self, x, attention_mask=None, label=None):
504
+ dtype = x.dtype
505
+ # get target feature tokens
506
+ target_tokens, melspec = self.get_targets(x, label=label)
507
+
508
+ # masking
509
+ x, masked_indices = self.masking(x, attention_mask=attention_mask)
510
+
511
+ # forward
512
+ logits, hidden_emb, masked_indices = self.get_predictions(x, mask=masked_indices, attention_mask=attention_mask, return_new_mask=True)
513
+
514
+ # get loss
515
+ losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)
516
+
517
+ if self.recon_loss_ratio:
518
+ losses["recon_loss"] = self.get_recon_loss(hidden_emb[-1], melspec, masked_indices) * self.recon_loss_ratio
519
+
520
+ return logits, hidden_emb, losses, accuracies
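
The RVQ-style branch of get_loss() above is easiest to verify with a standalone shape check. The sketch below is illustrative only: the sizes (Batch=2, SeqLen=750, N_Codebook=4, CodebookSize=1024) are assumptions taken from the comments in the code, not values the model fixes, and plain cross-entropy stands in for self.loss.

import torch
from einops import rearrange

B, S, N, C = 2, 750, 4, 1024
logits = torch.randn(B, S, N * C)                 # per-step logits over all N codebooks
targets = torch.randint(0, C, (B, N * S))         # flattened "(n s)" token layout

targets = rearrange(targets, "b (n s) -> b s n", n=N)    # (B, S, N)
masked_indices = torch.nonzero(torch.rand(B, S) < 0.5)   # (K, 2) rows of [batch, step]

masked_logits = logits[tuple(masked_indices.t())]        # (K, N*C)
masked_tokens = targets[tuple(masked_indices.t())]       # (K, N)

masked_logits = rearrange(masked_logits, "k (n c) -> (k n) c", n=N)
masked_tokens = rearrange(masked_tokens, "k n -> (k n)")
loss = torch.nn.functional.cross_entropy(masked_logits, masked_tokens)
print(masked_logits.shape, masked_tokens.shape, loss.item())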
MuCodec/muq_dev/muq_fairseq/models/muq/model/pred_ark_target_with_model.py ADDED
@@ -0,0 +1,151 @@
1
+ import sys
2
+ import torch.nn as nn
3
+ import torch
4
+ import os
5
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
6
+ from rvq_muq import PreprocessorWithModel, ResidualVectorQuantize # this repo ships rvq_muq.py; the original 'rvq_musicfm' module is not included
7
+
8
+ class RVQ(nn.Module):
9
+ def __init__(self,
10
+ model_config,
11
+ rvq_ckpt_path,
12
+ preprocess,
13
+ ):
14
+ super().__init__()
15
+ self.rvq = ResidualVectorQuantize(**model_config)
16
+ if rvq_ckpt_path is not None:
17
+ self.rvq.load_state_dict(torch.load(rvq_ckpt_path, map_location='cpu'))
18
+ self.preprocess = preprocess
19
+
20
+ def get_targets(self, x):
21
+ self.rvq.eval()
22
+ x = self.preprocess(x)
23
+ quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = self.rvq(x)
24
+ return codes.permute(1,0,2)
25
+
26
+ @torch.no_grad()
27
+ def encode_wavs(self, wavs):
28
+ wavs = wavs[..., :int((wavs.shape[-1]//320)*320)]
29
+ return self.get_targets(wavs)
30
+
31
+ def This_Music_ModelTarget_Config():
32
+ config = dict(
33
+ model = dict(
34
+ input_dim = 1024,
35
+ n_codebooks = 8,
36
+ codebook_size = 1024,
37
+ codebook_dim = 16,
38
+ quantizer_dropout = 0.0,
39
+ ),
40
+ train = dict(
41
+ batch_size = 32,
42
+ num_workers = 6,
43
+ valid_interval = 10,
44
+ save_interval = 100,
45
+ max_updates = 500000,
46
+ lr = 1e-4,
47
+ # device = 'cuda:1',
48
+ loss = 'commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()',
49
+ preprocess = PreprocessorWithModel(
50
+ model_dir= 'path/to/muq_fairseq',
51
+ checkpoint_dir='path/to/muq_m4a_75K.pt',
52
+ use_layer_idx=9,
53
+ )
54
+ ),
55
+ pred = dict(
56
+ rvq_ckpt_path='path/to/runs/Aug07_18-09-24_ts-828fa13e58384d0bba4144fda78ecc92-launcher/ckpt/RVQ_8100.pth',
57
+ sr=24000,
58
+ data_jsonl_path='path/to/data/music4all/train.json',
59
+ save_target_dir= 'path/to/data/music4all_ark/reiter_musicssl_m4a',
60
+ ),
61
+ )
62
+ return config
63
+
64
+
65
+ CLEN = 30
66
+ N_GPU_PER = 8
67
+ N_NODE = 4
68
+
69
+ def parse_lr(wave_length, sr):
70
+ n_step = int( wave_length // (sr*CLEN) )
71
+ if n_step == 0:
72
+ n_step = 1
73
+ print('wave_length: ', wave_length, 'sr: ', sr, 'n_step: ', n_step)
74
+ starts = torch.arange(n_step) * CLEN * sr
75
+ left_rights = torch.stack((starts, starts+CLEN*sr)).T
76
+ return left_rights[:10, ...] # cap at 10 chunks (the first 5 minutes) per track
77
+
78
+ @torch.no_grad()
79
+ def main(index, rank):
80
+ device = f'cuda:{rank}'
81
+ config = This_Music_ModelTarget_Config()
82
+ preprocess = config['train']['preprocess']
83
+ model = RVQ(
84
+ model_config = config['model'],
85
+ rvq_ckpt_path = config['pred']['rvq_ckpt_path'],
86
+ preprocess = preprocess
87
+ ).to(device)
88
+ model.eval()
89
+ sr = config['pred']['sr']
90
+
91
+ fname_nobase = os.path.basename(config['pred']['data_jsonl_path']).split('.')[0]
92
+ scp_dir = os.path.join(config['pred']['save_target_dir'], 'scp')
93
+ ark_dir = os.path.join(config['pred']['save_target_dir'], 'ark')
94
+ os.makedirs(scp_dir, exist_ok=True)
95
+ os.makedirs(ark_dir, exist_ok=True)
96
+
97
+ scp_path = os.path.join(scp_dir, f'{fname_nobase}.{index}_{rank}.scp')
98
+ ark_path = os.path.join(ark_dir, f'{fname_nobase}.{index}_{rank}.ark')
99
+
100
+ from kaldiio import WriteHelper
101
+
102
+ with open(config['pred']['data_jsonl_path']) as f:
103
+ lines = f.readlines()
104
+
105
+ print("Total:", len(lines))
106
+
107
+ from tqdm import tqdm
108
+ import json
109
+ import librosa
110
+ import time
111
+ from einops import rearrange
112
+ import numpy as np
113
+
114
+ # lines = lines[(index*N_GPU_PER+rank)::(N_GPU_PER*N_NODE)]
115
+
116
+ with WriteHelper(f'ark,scp:{ark_path},{scp_path}') as writer:
117
+ for idx, line in tqdm(enumerate(lines)):
118
+ try:
119
+ if idx % (N_GPU_PER*N_NODE) != (index*N_GPU_PER+rank):
120
+ continue
121
+ item = json.loads(line)
122
+ path = item['path']
123
+ wave, _ = librosa.load(path, sr=sr)
124
+ wave = torch.from_numpy(wave)
125
+ wave_length = wave.shape[-1]
126
+ if wave_length < sr*CLEN:
127
+ continue
128
+ left_rights = parse_lr(wave_length, sr)
129
+ lr = left_rights.tolist()
130
+ wavs = torch.stack(
131
+ [wave[l:r] for l,r in lr]
132
+ ).to(device)
133
+ targets = model.encode_wavs(wavs) # [N_Codebook=8, N_Chunks, T]
134
+
135
+ final_target = rearrange(targets, "c n f -> n (c f)").cpu().numpy().astype(np.int32)
136
+ for j in range(final_target.shape[0]):
137
+ writer(f'{idx}:{j}', final_target[j])
138
+ except Exception as e:
139
+ print(e)
140
+
141
+
142
+ if __name__ == '__main__':
143
+ import sys
144
+ index = int(sys.argv[1])
145
+ import multiprocessing
146
+ pool = multiprocessing.Pool(processes=N_GPU_PER)
147
+ for rank in range(N_GPU_PER):
148
+ pool.apply_async(main, (index, rank))
149
+ pool.close()
150
+ pool.join()
151
+ print("Done.")
MuCodec/muq_dev/muq_fairseq/models/muq/model/rvq.py ADDED
@@ -0,0 +1,459 @@
1
+
2
+ from typing import Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch.nn.utils import weight_norm
10
+
11
+ def WNConv1d(*args, **kwargs):
12
+ return weight_norm(nn.Conv1d(*args, **kwargs))
13
+
14
+
15
+ class VectorQuantize(nn.Module):
16
+ """
17
+ Implementation of VQ similar to Karpathy's repo:
18
+ https://github.com/karpathy/deep-vector-quantization
19
+ Additionally uses following tricks from Improved VQGAN
20
+ (https://arxiv.org/pdf/2110.04627.pdf):
21
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
22
+ for improved codebook usage
23
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
24
+ improves training stability
25
+ """
26
+
27
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 1000, mfcc_clustering=False, n_layer=1):
28
+ super().__init__()
29
+ self.codebook_size = codebook_size
30
+ self.codebook_dim = codebook_dim
31
+ self.mfcc_clustering = mfcc_clustering
32
+
33
+ ProjClass = nn.Identity if mfcc_clustering else WNConv1d
34
+ if n_layer==1:
35
+ self.in_proj = ProjClass(input_dim, codebook_dim, kernel_size=1)
36
+ self.out_proj = ProjClass(codebook_dim, input_dim, kernel_size=1)
37
+ elif n_layer >= 2:
38
+ ndim_hidden = 128
39
+ self.in_proj = nn.Sequential(
40
+ ProjClass(input_dim, ndim_hidden, kernel_size=1),
41
+ *[nn.Sequential(nn.ReLU(), ProjClass(ndim_hidden, ndim_hidden, kernel_size=1),) for _ in range(n_layer-2)],
42
+ nn.ReLU(),
43
+ ProjClass(ndim_hidden, codebook_dim, kernel_size=1)
44
+ )
45
+ self.out_proj = nn.Sequential(
46
+ ProjClass(codebook_dim, ndim_hidden, kernel_size=1),
47
+ nn.ReLU(),
48
+ *[nn.Sequential(ProjClass(ndim_hidden, ndim_hidden, kernel_size=1), nn.ReLU()) for _ in range(n_layer-2)],
49
+ ProjClass(ndim_hidden, input_dim, kernel_size=1),
50
+ )
51
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
52
+ self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
53
+ self.stale_tolerance = stale_tolerance
54
+
55
+ def forward(self, z):
56
+ """Quantized the input tensor using a fixed codebook and returns
57
+ the corresponding codebook vectors
58
+
59
+ Parameters
60
+ ----------
61
+ z : Tensor[B x D x T]
62
+
63
+ Returns
64
+ -------
65
+ Tensor[B x D x T]
66
+ Quantized continuous representation of input
67
+ Tensor[1]
68
+ Commitment loss to train encoder to predict vectors closer to codebook
69
+ entries
70
+ Tensor[1]
71
+ Codebook loss to update the codebook
72
+ Tensor[B x T]
73
+ Codebook indices (quantized discrete representation of input)
74
+ Tensor[B x D x T]
75
+ Projected latents (continuous representation of input before quantization)
76
+ """
77
+
78
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
79
+
80
+ z_e = self.in_proj(z) # z_e : (B x D x T)
81
+ z_q, indices = self.decode_latents(z_e)
82
+
83
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
84
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
85
+
86
+ z_q = (
87
+ z_e + (z_q - z_e).detach()
88
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
89
+
90
+ z_q = self.out_proj(z_q)
91
+
92
+ return z_q, commitment_loss, codebook_loss, indices, z_e
93
+
94
+ def embed_code(self, embed_id):
95
+ return F.embedding(embed_id, self.codebook.weight)
96
+
97
+ def decode_code(self, embed_id):
98
+ return self.embed_code(embed_id).transpose(1, 2)
99
+
100
+ def decode_latents(self, latents):
101
+ encodings = rearrange(latents, "b d t -> (b t) d")
102
+ codebook = self.codebook.weight # codebook: (N x D)
103
+
104
+ # L2 normalize encodings and codebook (ViT-VQGAN)
105
+ encodings = F.normalize(encodings)
106
+ codebook = F.normalize(codebook)
107
+
108
+ # Compute euclidean distance with codebook
109
+ dist = (
110
+ encodings.pow(2).sum(1, keepdim=True)
111
+ - 2 * encodings @ codebook.t()
112
+ + codebook.pow(2).sum(1, keepdim=True).t()
113
+ )
114
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
115
+ z_q = self.decode_code(indices)
116
+
117
+ if(self.training):
118
+ onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
119
+ stale_codes = (onehots.sum(0).sum(0) == 0).float()
120
+ self.stale_counter = self.stale_counter * stale_codes + stale_codes
121
+
122
+ # random replace codes that haven't been used for a while
123
+ replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
124
+ if replace_code.sum(-1) > 0:
125
+ print("Replace {} codes".format(replace_code.sum(-1)))
126
+ random_input_idx = torch.randperm(encodings.shape[0])
127
+ random_input = encodings[random_input_idx].view(encodings.shape)
128
+ if random_input.shape[0] < self.codebook_size:
129
+ random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
130
+ random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
131
+
132
+ self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
133
+ self.stale_counter = self.stale_counter * (1 - replace_code)
134
+
135
+ return z_q, indices
136
+
137
+
138
+ class ResidualVectorQuantize(nn.Module):
139
+ """
140
+ Introduced in SoundStream: An end2end neural audio codec
141
+ https://arxiv.org/abs/2107.03312
142
+ """
143
+
144
+ def __init__(
145
+ self,
146
+ input_dim: int = 512,
147
+ n_codebooks: int = 9,
148
+ codebook_size: int = 1024,
149
+ codebook_dim: Union[int, list] = 8,
150
+ quantizer_dropout: float = 0.0,
151
+ stale_tolerance: int = 100,
152
+ use_multi_layer_num:int = 1,
153
+ ):
154
+ super().__init__()
155
+ if isinstance(codebook_dim, int):
156
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
157
+
158
+ self.n_codebooks = n_codebooks
159
+ self.codebook_dim = codebook_dim
160
+ self.codebook_size = codebook_size
161
+
162
+ self.quantizers = nn.ModuleList(
163
+ [
164
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance, n_layer=use_multi_layer_num)
165
+ for i in range(n_codebooks)
166
+ ]
167
+ )
168
+ self.quantizer_dropout = quantizer_dropout
169
+
170
+ def forward(self, z, n_quantizers: int = None):
171
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
172
+ the corresponding codebook vectors
173
+ Parameters
174
+ ----------
175
+ z : Tensor[B x D x T]
176
+ n_quantizers : int, optional
177
+ No. of quantizers to use
178
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
179
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
180
+ when in training mode, and a random number of quantizers is used.
181
+ Returns
182
+ -------
183
+ dict
184
+ A dictionary with the following keys:
185
+
186
+ "z" : Tensor[B x D x T]
187
+ Quantized continuous representation of input
188
+ "codes" : Tensor[B x N x T]
189
+ Codebook indices for each codebook
190
+ (quantized discrete representation of input)
191
+ "latents" : Tensor[B x N*D x T]
192
+ Projected latents (continuous representation of input before quantization)
193
+ "vq/commitment_loss" : Tensor[1]
194
+ Commitment loss to train encoder to predict vectors closer to codebook
195
+ entries
196
+ "vq/codebook_loss" : Tensor[1]
197
+ Codebook loss to update the codebook
198
+ """
199
+ z_q = 0
200
+ residual = z
201
+ commitment_loss = 0
202
+ codebook_loss = 0
203
+
204
+ codebook_indices = []
205
+ latents = []
206
+
207
+ if n_quantizers is None:
208
+ n_quantizers = self.n_codebooks
209
+ if self.training:
210
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
211
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
212
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
213
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
214
+ n_quantizers = n_quantizers.to(z.device)
215
+ else:
216
+ n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1
217
+ n_quantizers = n_quantizers.to(z.device)
218
+
219
+ for i, quantizer in enumerate(self.quantizers):
220
+ # if self.training is False and i >= n_quantizers:
221
+ # break
222
+
223
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
224
+ residual
225
+ )
226
+
227
+ # Create mask to apply quantizer dropout
228
+ mask = (
229
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
230
+ )
231
+ z_q = z_q + z_q_i * mask[:, None, None]
232
+ residual = residual - z_q_i
233
+
234
+ # Sum losses
235
+ commitment_loss += (commitment_loss_i * mask).mean()
236
+ codebook_loss += (codebook_loss_i * mask).mean()
237
+
238
+ codebook_indices.append(indices_i)
239
+ latents.append(z_e_i)
240
+
241
+ codes = torch.stack(codebook_indices, dim=1)
242
+ latents = torch.cat(latents, dim=1)
243
+
244
+ encodings = F.one_hot(codes, self.codebook_size).float() # B N T codebook_size; unused here, retained from codebook-usage tracking
245
+
246
+ return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1
247
+
248
+ def from_codes(self, codes: torch.Tensor):
249
+ """Given the quantized codes, reconstruct the continuous representation
250
+ Parameters
251
+ ----------
252
+ codes : Tensor[B x N x T]
253
+ Quantized discrete representation of input
254
+ Returns
255
+ -------
256
+ Tensor[B x D x T]
257
+ Quantized continuous representation of input
258
+ """
259
+ z_q = 0.0
260
+ z_p = []
261
+ n_codebooks = codes.shape[1]
262
+ for i in range(n_codebooks):
263
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
264
+ z_p.append(z_p_i)
265
+
266
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
267
+ z_q = z_q + z_q_i
268
+ return z_q, torch.cat(z_p, dim=1), codes
269
+
270
+ def from_latents(self, latents: torch.Tensor):
271
+ """Given the unquantized latents, reconstruct the
272
+ continuous representation after quantization.
273
+
274
+ Parameters
275
+ ----------
276
+ latents : Tensor[B x N x T]
277
+ Continuous representation of input after projection
278
+
279
+ Returns
280
+ -------
281
+ Tensor[B x D x T]
282
+ Quantized representation of full-projected space
283
+ Tensor[B x D x T]
284
+ Quantized representation of latent space
285
+ """
286
+ z_q = 0
287
+ z_p = []
288
+ codes = []
289
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
290
+
291
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
292
+ 0
293
+ ]
294
+ for i in range(n_codebooks):
295
+ j, k = dims[i], dims[i + 1]
296
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
297
+ z_p.append(z_p_i)
298
+ codes.append(codes_i)
299
+
300
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
301
+ z_q = z_q + z_q_i
302
+
303
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
304
+
305
+ from torch.utils.data import Dataset, DataLoader
306
+ import json, traceback
307
+ import torchaudio
308
+ import math
309
+
310
+ from typing import List, Tuple, Dict, Any
311
+
312
+ CLIPSECS = 5
313
+ def load_audio_by_json(json_path, max_keep, min_keep, tgt_sample_rate):
314
+ # read json file
315
+ print(json_path)
316
+ datas = []
317
+ inds = []
318
+ sizes = []
319
+ with open(json_path) as fp:
320
+ for ind,line in enumerate(fp):
321
+ data = json.loads(line)
322
+ datas.append(data)
323
+ inds.append(ind)
324
+ # sz = int(data['duration'] * data['sample_rate'])
325
+ sz = int(tgt_sample_rate * CLIPSECS)
326
+ sizes.append(sz)
327
+ tot = ind + 1
328
+ return datas,inds,tot,sizes
329
+
330
+ class Read_and_PadCrop_Normalized_T(torch.nn.Module):
331
+ def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
332
+
333
+ super().__init__()
334
+
335
+ self.n_samples = n_samples
336
+ self.sample_rate = sample_rate
337
+ self.randomize = randomize
338
+
339
+
340
+ def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
341
+ if(duration<(float(self.n_samples)/self.sample_rate+1)):
342
+ # print(duration,(float(self.n_samples)/self.sample_rate+1))
343
+ chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
344
+ t_start = 0.
345
+ t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
346
+ offset = 0
347
+ # print('c1:',chunk.shape)
348
+ else:
349
+ offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
350
+ t_start = offset / float(cur_sample_rate) / duration
351
+ t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
352
+ chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
353
+ # print('offset:',offset)
354
+ # print('c0:',chunk.shape)
355
+ # Pad with silence if necessary.
356
+ if(chunk.shape[0]>1):
357
+ chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
358
+ else:
359
+ chunk = chunk[[0],:].float()
360
+ if(cur_sample_rate!=self.sample_rate):
361
+ # print('a:',cur_sample_rate,chunk.shape)
362
+ chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
363
+ # print('b:',self.sample_rate,chunk.shape)
364
+ if chunk.shape[-1] < self.n_samples:
365
+ chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
366
+ else:
367
+ chunk = chunk[:,0:self.n_samples]
368
+ seconds_start = math.floor(offset / cur_sample_rate)
369
+ seconds_total = math.floor(duration)
370
+
371
+ return (
372
+ chunk,
373
+ t_start,
374
+ t_end,
375
+ seconds_start,
376
+ seconds_total
377
+ )
378
+
379
+ class RVQDataset(Dataset):
380
+ def __init__(
381
+ self,
382
+ manifest_path: str,
383
+ sample_rate: float,
384
+ normalize: bool = False,
385
+ ):
386
+ self.sample_rate = sample_rate
387
+ self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path, None, None, self.sample_rate)
388
+ self.dataset_len = len(self.datas)
389
+
390
+ self.reader = Read_and_PadCrop_Normalized_T(n_samples=CLIPSECS*sample_rate,sample_rate = self.sample_rate)
391
+ self.normalize = normalize
392
+
393
+
394
+ def __getitem__(self, i):
395
+ # WORLD_SIZE = int(torch.distributed.get_world_size())
396
+ # WORLD_RANK = int(torch.distributed.get_rank())
397
+ # np.random.seed(1337 + self.epoch * WORLD_SIZE + WORLD_RANK + i)
398
+ # index = random.randint(0,len(self.sizes) - 1)
399
+ index = i
400
+ item = None
401
+ while item is None:
402
+ try:
403
+ wav = self.get_audio_by_slice(index)
404
+ # labels = self.get_labels(index)
405
+ # labels = None
406
+ # item = {"id": index, "source": wav, "label_list": labels}
407
+ item = {"id": index, "source": wav}
408
+ except Exception as e:
409
+ # print(e)
410
+ traceback.print_exc()
411
+ print(f'skip damaged data {index}')
412
+ index = np.random.randint(0, len(self.sizes)) # randint's high bound is exclusive
413
+ return item
414
+
415
+ def __len__(self):
416
+ return self.dataset_len
417
+
418
+ def get_audio_by_slice(self,index):
419
+
420
+ wav_path = self.datas[index]['path']
421
+ # print(wav_path)
422
+ audio_info = torchaudio.info(wav_path)
423
+ origin_sample_rate = audio_info.sample_rate
424
+ origin_duration = audio_info.num_frames / origin_sample_rate
425
+
426
+ wav, *ignored = self.reader(wav_path, origin_duration,origin_sample_rate)
427
+ wav = wav.float()
428
+
429
+ # _path, slice_ptr = parse_path(wav_path)
430
+ # original way
431
+ # if len(slice_ptr) == 0:
432
+ # wav, cur_sample_rate = sf.read(_path)
433
+ # else:
434
+ # assert _path.endswith(".zip")
435
+ # data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
436
+ # f = io.BytesIO(data)
437
+ # wav, cur_sample_rate = sf.read(f)
438
+ # wav = torch.from_numpy(wav).float()
439
+ # print(wav.shape)
440
+ wav = wav.permute(1,0)
441
+ wav = self.postprocess(wav, self.sample_rate)
442
+ # print(wav.shape)
443
+
444
+ # wav = wav.squeeze(0)
445
+ return wav
446
+
447
+ def postprocess(self, wav, cur_sample_rate):
448
+ if wav.dim() == 2:
449
+ wav = wav.mean(-1)
450
+ assert wav.dim() == 1, wav.dim()
451
+
452
+ if cur_sample_rate != self.sample_rate:
453
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
454
+
455
+ if self.normalize:
456
+ with torch.no_grad():
457
+ wav = F.layer_norm(wav, wav.shape)
458
+ return wav
459
+
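
A minimal round-trip sketch for ResidualVectorQuantize, assuming this file is importable as rvq (e.g. when run from this directory). The 8-codebook configuration mirrors the configs used elsewhere in this repo; the batch and time sizes are illustrative:

import torch
from rvq import ResidualVectorQuantize

quantizer = ResidualVectorQuantize(input_dim=1024, n_codebooks=8,
                                   codebook_size=1024, codebook_dim=16)
quantizer.eval()
z = torch.randn(2, 1024, 100)                        # (B, D, T) continuous features
with torch.no_grad():
    z_q, codes, latents, commit_loss, cb_loss, usage = quantizer(z)
    print(codes.shape)                               # torch.Size([2, 8, 100])
    z_dec, _, _ = quantizer.from_codes(codes)        # tokens back to features
    assert torch.allclose(z_q, z_dec, atol=1e-5)     # forward and decode agree in eval mode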
MuCodec/muq_dev/muq_fairseq/models/muq/model/rvq_muq.py ADDED
@@ -0,0 +1,394 @@
1
+ try:
2
+ from .rvq import *
3
+ except ImportError:
4
+ import sys, os
5
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
6
+ from rvq import *
7
+
8
+ try:
9
+ from ..modules.random_quantizer import RandomProjectionQuantizer
10
+ from ..modules.features import MelSTFT
11
+ from ..modules.conv import Conv2dSubsampling
12
+ except ImportError:
13
+ import sys, os
14
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
15
+ from modules.random_quantizer import RandomProjectionQuantizer
16
+ from modules.features import MelSTFT
17
+ from modules.conv import Conv2dSubsampling
18
+
19
+ import fairseq
20
+
21
+ CLIPSECS = 5 # 5 for rvq, 30 for model
22
+
23
+ class RVQDataset(Dataset):
24
+ def __init__(
25
+ self,
26
+ manifest_path: str,
27
+ sample_rate: float,
28
+ normalize: bool = False,
29
+ ):
30
+ self.sample_rate = sample_rate
31
+ self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path, None, None, self.sample_rate)
32
+ self.dataset_len = len(self.datas)
33
+
34
+ self.reader = Read_and_PadCrop_Normalized_T(n_samples=CLIPSECS*sample_rate,sample_rate = self.sample_rate)
35
+ self.normalize = normalize
36
+
37
+
38
+ def __getitem__(self, i):
39
+ # WORLD_SIZE = int(torch.distributed.get_world_size())
40
+ # WORLD_RANK = int(torch.distributed.get_rank())
41
+ # np.random.seed(1337 + self.epoch * WORLD_SIZE + WORLD_RANK + i)
42
+ # index = random.randint(0,len(self.sizes) - 1)
43
+ index = i
44
+ item = None
45
+ while item is None:
46
+ try:
47
+ wav = self.get_audio_by_slice(index)
48
+ item = {"id": index, "source": wav}
49
+ except Exception as e:
50
+ # print(e)
51
+ traceback.print_exc()
52
+ print(f'skip damaged data {index}')
53
+ index = np.random.randint(0, len(self.sizes)) # randint's high bound is exclusive
54
+ return item
55
+
56
+ def __len__(self):
57
+ return self.dataset_len
58
+
59
+ def get_audio_by_slice(self,index):
60
+
61
+ wav_path = self.datas[index]['path']
62
+ audio_info = torchaudio.info(wav_path)
63
+ origin_sample_rate = audio_info.sample_rate
64
+ origin_duration = audio_info.num_frames / origin_sample_rate
65
+
66
+ wav, *ignored = self.reader(wav_path, origin_duration,origin_sample_rate)
67
+ wav = wav.float()
68
+
69
+ # _path, slice_ptr = parse_path(wav_path)
70
+ # original way
71
+ # if len(slice_ptr) == 0:
72
+ # wav, cur_sample_rate = sf.read(_path)
73
+ # else:
74
+ # assert _path.endswith(".zip")
75
+ # data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
76
+ # f = io.BytesIO(data)
77
+ # wav, cur_sample_rate = sf.read(f)
78
+ # wav = torch.from_numpy(wav).float()
79
+ # print(wav.shape)
80
+ wav = wav.permute(1,0)
81
+ wav = self.postprocess(wav, self.sample_rate)
82
+ # print(wav.shape)
83
+
84
+ # wav = wav.squeeze(0)
85
+ return wav
86
+
87
+ def postprocess(self, wav, cur_sample_rate):
88
+ if wav.dim() == 2:
89
+ wav = wav.mean(-1)
90
+ assert wav.dim() == 1, wav.dim()
91
+
92
+ if cur_sample_rate != self.sample_rate:
93
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
94
+
95
+ if self.normalize:
96
+ with torch.no_grad():
97
+ wav = F.layer_norm(wav, wav.shape)
98
+ return wav
99
+
100
+ class Preprocessor(nn.Module):
101
+ def __init__(self,
102
+ codebook_dim=16,
103
+ codebook_size=4096,
104
+ hop_length=240,
105
+ n_mels=128,
106
+ stat_path=None,
107
+ is_spec_wise=False,
108
+ s=4,
109
+ ) -> None:
110
+ super().__init__()
111
+
112
+ self.features=["melspec_2048"]
113
+ self.s = s
114
+
115
+ # load feature mean / std stats
116
+ import os
117
+ if stat_path is not None and os.path.exists(stat_path):
118
+ with open(stat_path, "r") as f:
119
+ self.stat = json.load(f)
120
+ else:
121
+ # print("No stats file found at `{}`, use default from msd.".format(stat_path))
122
+ self.stat = {"spec_256_cnt": 14394344256, "spec_256_mean": -23.34296658431829, "spec_256_std": 26.189295587132637, "spec_512_cnt": 28677104448, "spec_512_mean": -21.31267396860235, "spec_512_std": 26.52644536245769, "spec_1024_cnt": 57242624832, "spec_1024_mean": -18.852271129208273, "spec_1024_std": 26.443154583585663, "spec_2048_cnt": 114373665600, "spec_2048_mean": -15.638743433896792, "spec_2048_std": 26.115825961611545, "spec_4096_cnt": 228635747136, "spec_4096_mean": -11.715532502794836, "spec_4096_std": 25.763972210234062, "melspec_256_cnt": 14282760192, "melspec_256_mean": -26.962600400166156, "melspec_256_std": 36.13614100912126, "melspec_512_cnt": 14282760192, "melspec_512_mean": -9.108344167718862, "melspec_512_std": 24.71910937988429, "melspec_1024_cnt": 14282760192, "melspec_1024_mean": 0.37302579246531126, "melspec_1024_std": 18.684082325919388, "melspec_2048_cnt": 14282760192, "melspec_2048_mean": 6.768444971712967, "melspec_2048_std": 18.417922652295623, "melspec_4096_cnt": 14282760192, "melspec_4096_mean": 13.617164614990036, "melspec_4096_std": 18.08552130124525, "cqt_cnt": 9373061376, "cqt_mean": 0.46341379757927165, "cqt_std": 0.9543998080910191, "mfcc_256_cnt": 1339008768, "mfcc_256_mean": -11.681755459447485, "mfcc_256_std": 29.183186444668316, "mfcc_512_cnt": 1339008768, "mfcc_512_mean": -2.540581461792183, "mfcc_512_std": 31.93752185832081, "mfcc_1024_cnt": 1339008768, "mfcc_1024_mean": 6.606636263169779, "mfcc_1024_std": 34.151644801729624, "mfcc_2048_cnt": 1339008768, "mfcc_2048_mean": 5.281600844245184, "mfcc_2048_std": 33.12784541220003, "mfcc_4096_cnt": 1339008768, "mfcc_4096_mean": 4.7616569480166095, "mfcc_4096_std": 32.61458906894133, "chromagram_256_cnt": 1339008768, "chromagram_256_mean": 55.15596556703181, "chromagram_256_std": 73.91858278719991, "chromagram_512_cnt": 1339008768, "chromagram_512_mean": 175.73092252759895, "chromagram_512_std": 248.48485148525953, "chromagram_1024_cnt": 1339008768, "chromagram_1024_mean": 589.2947481634608, "chromagram_1024_std": 913.857929063196, "chromagram_2048_cnt": 1339008768, "chromagram_2048_mean": 2062.286388327397, "chromagram_2048_std": 3458.92657915397, "chromagram_4096_cnt": 1339008768, "chromagram_4096_mean": 7673.039107997085, "chromagram_4096_std": 13009.883158267234}
123
+
124
+ # feature extractor
125
+ self.preprocessor_melspec_2048 = MelSTFT(
126
+ n_fft=2048, hop_length=hop_length, is_db=True
127
+ )
128
+
129
+ self.is_spec_wise = is_spec_wise
130
+
131
+
132
+ @torch.no_grad()
133
+ def normalize(self, x):
134
+ """normalize the input audio to have zero mean unit variance"""
135
+ for key in x.keys():
136
+ x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key] # {'melspec_2048_cnt': 14282760192, 'melspec_2048_mean': 6.768444971712967}
137
+ return x
138
+
139
+ @torch.no_grad()
140
+ def rearrange(self, x):
141
+ """rearrange the batch to flatten every 4 steps"""
142
+ for key in x.keys():
143
+ if key == "chromagram":
144
+ x[key] = rearrange(x[key], "b f t -> b t f")
145
+ else:
146
+ x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=self.s)
147
+ return x
148
+
149
+ @torch.no_grad()
150
+ def preprocessing(self, x, features):
151
+ """extract classic audio features"""
152
+ # check precision
153
+ if x.dtype == torch.float16:
154
+ precision = 16
155
+ else:
156
+ precision = 32
157
+
158
+ out = {}
159
+ for key in features:
160
+ layer = getattr(self, "preprocessor_%s" % key)
161
+ out[key] = layer.float()(x.float())[..., :-1]
162
+ if precision == 16:
163
+ out[key] = out[key].half()
164
+ return out
165
+
166
+ @torch.no_grad()
167
+ def tokenize(self, x):
168
+ out = {}
169
+ for key in x.keys():
170
+ layer = getattr(self, "quantizer_%s" % key)
171
+ out[key] = layer(x[key])
172
+ return out
173
+
174
+ def to_spec_wise(self, x):
175
+ Batch, Spec, Time = x.shape
176
+ SubSpec, N_SubSpec = 16, 8
177
+ assert SubSpec * N_SubSpec == Spec == 128
178
+ x = rearrange(x, "b (n s) t -> b s (n t)", n=N_SubSpec, s=SubSpec)
179
+ return x # [Batch, SubSpec=16, N_SubSpec*Time=8*100Hz]
180
+
181
+ @torch.no_grad()
182
+ def __call__(self, x):
183
+ x = self.preprocessing(x, features=self.features) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}}
184
+ x = self.normalize(x)
185
+ if self.is_spec_wise:
186
+ x = {k:self.to_spec_wise(v) for k,v in x.items()}
187
+ x = self.rearrange(x) # -> {'melspec_2048': Tensor{Size([3, 750, 512]) cuda:0 f32}}
188
+ return x['melspec_2048'].permute((0, 2, 1))
189
+
190
+
191
+ class CQTPreprocessor(nn.Module):
192
+ def __init__(self,
193
+ sr=24000,
194
+ hop=960,
195
+ nb=84,
196
+ to_db = True,
197
+ ) -> None:
198
+ super().__init__()
199
+
200
+ from nnAudio.features.cqt import CQT
201
+ import torchaudio
202
+ self.cqt_fn = CQT(
203
+ sr=sr,
204
+ hop_length=hop,
205
+ n_bins=nb,
206
+ fmin=32.7 if nb == 84 else 27.5, # 84 or 88
207
+ bins_per_octave=12,
208
+ filter_scale=1,
209
+ norm=1,
210
+ window='hann',
211
+ center=True,
212
+ pad_mode='constant',
213
+ trainable=False,
214
+ output_format='Magnitude',
215
+ verbose=True,
216
+ )
217
+ if to_db:
218
+ self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
219
+ else:
220
+ self.amplitude_to_db = lambda x:x
221
+
222
+ @torch.no_grad()
223
+ def __call__(self, x):
224
+ return self.amplitude_to_db(self.cqt_fn(x))
225
+
226
+
227
+ from dataclasses import dataclass
228
+
229
+ @dataclass
230
+ class UserDirModule:
231
+ user_dir: str
232
+
233
+ def load_model(model_dir, checkpoint_dir):
234
+ '''Load Fairseq SSL model'''
235
+
236
+ if model_dir is not None:
237
+ model_path = UserDirModule(model_dir)
238
+ fairseq.utils.import_user_module(model_path)
239
+
240
+ model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_dir], strict=False)
241
+ model = model[0]
242
+
243
+ return model
244
+
245
+
246
+
247
+ class PreprocessorWithModel(nn.Module):
248
+ def __init__(self, model_dir, checkpoint_dir, use_layer_idx=9) -> None:
249
+ super().__init__()
250
+ self.model = load_model(model_dir=model_dir, checkpoint_dir=checkpoint_dir)
251
+ self.model.eval()
252
+ self.use_layer_idx = use_layer_idx
253
+
254
+ def forward(self, x):
255
+ with torch.no_grad():
256
+ self.model.eval()
257
+ res = self.model(x, features_only = True)
258
+ layer_results = res['layer_results']
259
+ return layer_results[self.use_layer_idx].permute(0,2,1)
260
+
261
+
262
+
263
+ def Music_Mel_Target_Config():
264
+ config = dict(
265
+ train_dataset = dict(
266
+ manifest_path = 'path/to/data/music4all/train.json',
267
+ sample_rate = 24000,
268
+ normalize = False,
269
+ ),
270
+ valid_dataset = dict(
271
+ manifest_path = 'path/to/data/music4all/valid.json',
272
+ sample_rate = 24000,
273
+ normalize = False,
274
+ ),
275
+ model = dict(
276
+ input_dim = 128*4,
277
+ n_codebooks = 8,
278
+ codebook_size = 1024,
279
+ codebook_dim = 16,
280
+ quantizer_dropout = 0.0,
281
+ ),
282
+ train = dict(
283
+ batch_size = 32,
284
+ num_workers = 6,
285
+ valid_interval = 10,
286
+ save_interval = 100,
287
+ max_updates = 500000,
288
+ lr = 1e-4,
289
+ device = 'cuda:0',
290
+ loss = 'commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()',
291
+ preprocess = Preprocessor()
292
+ )
293
+ )
294
+ return config
295
+
296
+
297
+ def main(config):
298
+ train_dataset = RVQDataset(**config['train_dataset'])
299
+ if config['valid_dataset']['manifest_path'] is None:
300
+ # split train and valid dataset
301
+ from torch.utils.data import random_split
302
+ train_dataset, valid_dataset = random_split(
303
+ train_dataset, lengths=[len(train_dataset) - 500, 500]
304
+ )
305
+ else:
306
+ valid_dataset = RVQDataset(**config['valid_dataset'])
307
+ train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=config['train']['batch_size'], drop_last=True, num_workers=config['train']['num_workers'])
308
+ valid_dataloader = DataLoader(valid_dataset, shuffle=False, batch_size=config['train']['batch_size'], drop_last=True, num_workers=config['train']['num_workers'])
309
+ model = ResidualVectorQuantize(**config['model'])
310
+
311
+ device = config['train']['device']
312
+ preprocess = config['train']['preprocess'].to(device)
313
+ model = model.to(device)
314
+
315
+ optimizer = torch.optim.Adam(model.parameters(), lr=config['train']['lr'])
316
+ cur_updates = 0
317
+ is_running = True
318
+ result = {}
319
+ from tqdm import tqdm
320
+ from tensorboardX import SummaryWriter
321
+ writer = SummaryWriter()
322
+ from collections import defaultdict
323
+ import os
324
+ from logging import getLogger
325
+ logger = getLogger()
326
+
327
+ while is_running:
328
+ results = defaultdict(lambda:0)
329
+ for item in tqdm(train_dataloader, desc='train'):
330
+ wavs = item['source']
331
+ optimizer.zero_grad()
332
+ wavs = wavs.to(device)
333
+ x = preprocess(wavs)
334
+ model.train()
335
+ quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = model(x)
336
+ loss = eval(config['train']['loss']) # evaluates the config's loss expression against the locals above (x, quantized_prompt_embeds, commitment_loss, codebook_loss)
337
+ loss.backward()
338
+ optimizer.step()
339
+
340
+ results['loss/train'] += loss.item()
341
+ results['commitment_loss/train'] += commitment_loss.item()
342
+ results['codebook_loss/train'] += codebook_loss.item()
343
+ results['rvq_usage/train'] += rvq_usage.float().mean().item()
344
+
345
+ if cur_updates % config['train']['valid_interval'] == 0:
346
+ model.eval()
347
+ with torch.no_grad():
348
+ for item in tqdm(valid_dataloader, desc='valid'):
349
+ wavs = item['source']
350
+ wavs = wavs.to(device)
351
+ x = preprocess(wavs)
352
+ quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = model(x)
353
+ valid_loss = eval(config['train']['loss'])
354
+
355
+ results['loss/valid'] += valid_loss.item()
356
+ results['commitment_loss/valid'] += commitment_loss.item()
357
+ results['codebook_loss/valid'] += codebook_loss.item()
358
+ results['rvq_usage/valid'] += rvq_usage.float().mean().item()
359
+
360
+ results['cur_updates'] = cur_updates
361
+ results['loss/train'] /= config['train']['valid_interval']
362
+ results['commitment_loss/train'] /= config['train']['valid_interval']
363
+ results['codebook_loss/train'] /= config['train']['valid_interval']
364
+ results['rvq_usage/train'] /= config['train']['valid_interval']
365
+
366
+ results['loss/valid'] /= len(valid_dataloader)
367
+ results['commitment_loss/valid'] /= len(valid_dataloader)
368
+ results['codebook_loss/valid'] /= len(valid_dataloader)
369
+ results['rvq_usage/valid'] /= len(valid_dataloader)
370
+
371
+ print('')
372
+ logger.info(str(results))
373
+ for k,v in results.items():
374
+ writer.add_scalar(k, v, cur_updates)
375
+
376
+ results.clear()
377
+
378
+ if cur_updates % config['train']['save_interval'] == 0:
379
+ os.makedirs(f'{writer.logdir}/ckpt/', exist_ok=True)
380
+ logger.info(f'saving checkpoint to {writer.logdir}/ckpt/RVQ_{cur_updates}.pth')
381
+ torch.save(model.state_dict(), f'{writer.logdir}/ckpt/RVQ_{cur_updates}.pth')
382
+
383
+
384
+ if cur_updates < config['train']['max_updates']:
385
+ cur_updates += 1
386
+ else:
387
+ is_running = False
388
+ break
389
+
390
+
391
+
392
+ if __name__ == '__main__':
393
+ config = Music_Mel_Target_Config()
394
+ main(config)
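
The frame stacking inside Preprocessor is what turns a 100 fps, 128-bin mel spectrogram into the 25 fps, 512-dim features the RVQ consumes. A standalone shape walk-through under the defaults (hop_length=240 at 24 kHz, s=4); the random tensor stands in for the normalized MelSTFT output:

import torch
from einops import rearrange

sr, hop, n_mels, s = 24000, 240, 128, 4
n_frames = (sr * 30) // hop                  # 30 s clip -> 3000 mel frames at 100 fps
mel = torch.randn(3, n_mels, n_frames)       # stand-in for the normalized melspec_2048
stacked = rearrange(mel, "b f (t s) -> b t (s f)", s=s)   # flatten every 4 frames
print(stacked.shape)                         # torch.Size([3, 750, 512]): 25 fps, 512-dim
print(stacked.permute(0, 2, 1).shape)        # (3, 512, 750): the B x D x T layout __call__ returns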
MuCodec/muq_dev/muq_fairseq/models/muq/model/w2v2_config.json ADDED
@@ -0,0 +1,113 @@
1
+ {
2
+ "activation_dropout": 0.1,
3
+ "adapter_kernel_size": 3,
4
+ "adapter_stride": 2,
5
+ "add_adapter": false,
6
+ "apply_spec_augment": true,
7
+ "architectures": [
8
+ "Wav2Vec2ConformerForCTC"
9
+ ],
10
+ "attention_dropout": 0.1,
11
+ "bos_token_id": 1,
12
+ "classifier_proj_size": 256,
13
+ "codevector_dim": 768,
14
+ "conformer_conv_dropout": 0.1,
15
+ "contrastive_logits_temperature": 0.1,
16
+ "conv_bias": true,
17
+ "conv_depthwise_kernel_size": 31,
18
+ "conv_dim": [
19
+ 512,
20
+ 512,
21
+ 512,
22
+ 512,
23
+ 512,
24
+ 512,
25
+ 512
26
+ ],
27
+ "conv_kernel": [
28
+ 10,
29
+ 3,
30
+ 3,
31
+ 3,
32
+ 3,
33
+ 2,
34
+ 2
35
+ ],
36
+ "conv_stride": [
37
+ 5,
38
+ 2,
39
+ 2,
40
+ 2,
41
+ 2,
42
+ 2,
43
+ 2
44
+ ],
45
+ "ctc_loss_reduction": "sum",
46
+ "ctc_zero_infinity": false,
47
+ "diversity_loss_weight": 0.1,
48
+ "do_stable_layer_norm": true,
49
+ "eos_token_id": 2,
50
+ "feat_extract_activation": "gelu",
51
+ "feat_extract_dropout": 0.0,
52
+ "feat_extract_norm": "layer",
53
+ "feat_proj_dropout": 0.1,
54
+ "feat_quantizer_dropout": 0.0,
55
+ "final_dropout": 0.1,
56
+ "gradient_checkpointing": false,
57
+ "hidden_act": "swish",
58
+ "hidden_dropout": 0.1,
59
+ "hidden_dropout_prob": 0.1,
60
+ "hidden_size": 1024,
61
+ "initializer_range": 0.02,
62
+ "intermediate_size": 4096,
63
+ "layer_norm_eps": 1e-05,
64
+ "layerdrop": 0.0,
65
+ "mask_feature_length": 10,
66
+ "mask_feature_min_masks": 0,
67
+ "mask_feature_prob": 0.0,
68
+ "mask_time_length": 10,
69
+ "mask_time_min_masks": 2,
70
+ "mask_time_prob": 0.05,
71
+ "max_source_positions": 5000,
72
+ "model_type": "wav2vec2-conformer",
73
+ "num_adapter_layers": 3,
74
+ "num_attention_heads": 16,
75
+ "num_codevector_groups": 2,
76
+ "num_codevectors_per_group": 320,
77
+ "num_conv_pos_embedding_groups": 16,
78
+ "num_conv_pos_embeddings": 128,
79
+ "num_feat_extract_layers": 7,
80
+ "num_hidden_layers": 24,
81
+ "num_negatives": 100,
82
+ "output_hidden_size": 1024,
83
+ "pad_token_id": 0,
84
+ "position_embeddings_type": "rotary",
85
+ "proj_codevector_dim": 768,
86
+ "rotary_embedding_base": 10000,
87
+ "tdnn_dilation": [
88
+ 1,
89
+ 2,
90
+ 3,
91
+ 1,
92
+ 1
93
+ ],
94
+ "tdnn_dim": [
95
+ 512,
96
+ 512,
97
+ 512,
98
+ 512,
99
+ 1500
100
+ ],
101
+ "tdnn_kernel": [
102
+ 5,
103
+ 3,
104
+ 3,
105
+ 1,
106
+ 1
107
+ ],
108
+ "torch_dtype": "float32",
109
+ "transformers_version": "4.19.0.dev0",
110
+ "use_weighted_layer_sum": false,
111
+ "vocab_size": 32,
112
+ "xvector_output_dim": 512
113
+ }
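
This file follows the Hugging Face Wav2Vec2ConformerConfig schema, so it can be instantiated directly; a small loading sketch (path assumed relative to this directory):

from transformers import Wav2Vec2ConformerConfig

config = Wav2Vec2ConformerConfig.from_json_file("w2v2_config.json")
print(config.hidden_size)                # 1024
print(config.num_hidden_layers)          # 24
print(config.position_embeddings_type)   # "rotary"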
MuCodec/muq_dev/muq_fairseq/models/muq/modules/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+
2
+
MuCodec/muq_dev/muq_fairseq/models/muq/modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (185 Bytes).
 
MuCodec/muq_dev/muq_fairseq/models/muq/modules/__pycache__/conv.cpython-310.pyc ADDED
Binary file (2.72 kB).
 
MuCodec/muq_dev/muq_fairseq/models/muq/modules/__pycache__/features.cpython-310.pyc ADDED
Binary file (2.14 kB).
 
MuCodec/muq_dev/muq_fairseq/models/muq/modules/__pycache__/random_quantizer.cpython-310.pyc ADDED
Binary file (1.98 kB).
 
MuCodec/muq_dev/muq_fairseq/models/muq/modules/conv.py ADDED
@@ -0,0 +1,77 @@
1
+ from torch import nn
2
+ from einops import rearrange
3
+
4
+
5
+ class Res2dModule(nn.Module):
6
+ def __init__(self, idim, odim, stride=(2, 2)):
7
+ super(Res2dModule, self).__init__()
8
+ self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
9
+ self.bn1 = nn.BatchNorm2d(odim)
10
+ self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
11
+ self.bn2 = nn.BatchNorm2d(odim)
12
+ self.relu = nn.ReLU()
13
+
14
+ # residual
15
+ self.diff = False
16
+ if (idim != odim) or (stride[0] > 1):
17
+ self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
18
+ self.bn3 = nn.BatchNorm2d(odim)
19
+ self.diff = True
20
+
21
+ def forward(self, x):
22
+ out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
23
+ if self.diff:
24
+ x = self.bn3(self.conv3(x))
25
+ out = x + out
26
+ out = self.relu(out)
27
+ return out
28
+
29
+
30
+ class Conv2dSubsampling(nn.Module):
31
+ """Convolutional 2D subsampling (to 1/4 length).
32
+
33
+ Args:
34
+ idim (int): Input dimension.
35
+ hdim (int): Hidden dimension.
36
+ odim (int): Output dimension.
37
+ strides (list): Sizes of strides.
38
+ n_bands (int): Number of frequency bands.
39
+ """
40
+
41
+ def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64):
42
+ """Construct an Conv2dSubsampling object."""
43
+ super(Conv2dSubsampling, self).__init__()
44
+
45
+ self.conv = nn.Sequential(
46
+ Res2dModule(idim, hdim, (2, strides[0])),
47
+ Res2dModule(hdim, hdim, (2, strides[1])),
48
+ )
49
+ self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)
50
+
51
+ def forward(self, x):
52
+ """Subsample x.
53
+
54
+ Args:
55
+ x (torch.Tensor): Input tensor (#batch, idim, time).
56
+
57
+ Returns:
58
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
59
+ where time' = time // 4.
60
+ """
61
+
62
+ if x.dim() == 3:
63
+ x = x.unsqueeze(1) # (b, c, f, t)
64
+ x = self.conv(x)
65
+ x = rearrange(x, "b c f t -> b t (c f)")
66
+ x = self.linear(x)
67
+ return x
68
+
69
+ if __name__ == '__main__':
70
+ import torch
71
+ conv_dim, encoder_dim = 512, 1024
72
+ conv = Conv2dSubsampling(
73
+ 1, conv_dim, encoder_dim, strides=[2, 1], n_bands=128
74
+ )
75
+ inp = torch.randn((1, 128, 3000))
76
+ out = conv(inp)
77
+ print(out.shape)
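
For the demo above (strides=[2, 1], n_bands=128): each Res2dModule uses 3x3 convolutions with padding 1, so the frequency axis is halved twice (128 -> 64 -> 32) while time is halved only once (3000 -> 1500 -> 1500), and the linear layer maps 512*32 features to 1024. The printed shape should therefore be torch.Size([1, 1500, 1024]).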
MuCodec/muq_dev/muq_fairseq/models/muq/modules/features.py ADDED
@@ -0,0 +1,67 @@
1
+ import torchaudio
2
+ from torch import nn
3
+ import torch
4
+
5
+
6
+ class MelSTFT(nn.Module):
7
+ def __init__(
8
+ self,
9
+ sample_rate=24000,
10
+ n_fft=2048,
11
+ hop_length=240,
12
+ n_mels=128,
13
+ is_db=False,
14
+ ):
15
+ super(MelSTFT, self).__init__()
16
+
17
+ # spectrogram
18
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
19
+ sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
20
+ )
21
+
22
+ # amplitude to decibel
23
+ self.is_db = is_db
24
+ if is_db:
25
+ self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
26
+
27
+ def forward(self, waveform):
28
+ if self.is_db:
29
+ return self.amplitude_to_db(self.mel_stft(waveform))
30
+ else:
31
+ return self.mel_stft(waveform)
32
+
33
+
34
+ class CQTPreprocessor(nn.Module):
35
+ def __init__(self,
36
+ sr=24000,
37
+ hop=960,
38
+ nb=84,
39
+ to_db = True,
40
+ ) -> None:
41
+ super().__init__()
42
+
43
+ from nnAudio.features.cqt import CQT
44
+ import torchaudio
45
+ self.cqt_fn = CQT(
46
+ sr=sr,
47
+ hop_length=hop,
48
+ n_bins=nb,
49
+ fmin=32.7 if nb == 84 else 27.5, # 84 or 88
50
+ bins_per_octave=12,
51
+ filter_scale=1,
52
+ norm=1,
53
+ window='hann',
54
+ center=True,
55
+ pad_mode='constant',
56
+ trainable=False,
57
+ output_format='Magnitude',
58
+ verbose=True,
59
+ )
60
+ if to_db:
61
+ self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
62
+ else:
63
+ self.amplitude_to_db = lambda x:x
64
+
65
+ @torch.no_grad()
66
+ def __call__(self, x):
67
+ return self.amplitude_to_db(self.cqt_fn(x))
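
A quick check of MelSTFT's frame rate under its defaults (24 kHz with hop 240 gives 100 frames/s), assuming this file is importable as features. Note that torchaudio's centered STFT yields one extra frame, which is why Preprocessor trims with [..., :-1]:

import torch
from features import MelSTFT

mel = MelSTFT(sample_rate=24000, n_fft=2048, hop_length=240, n_mels=128, is_db=True)
wav = torch.randn(1, 24000 * 30)        # 30 s of audio
spec = mel(wav)
print(spec.shape)                       # torch.Size([1, 128, 3001]); 3000 frames after trimming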
MuCodec/muq_dev/muq_fairseq/models/muq/modules/flash_conformer.py ADDED
@@ -0,0 +1,2114 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Wav2Vec2-Conformer model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+ from torch.nn import functional as F
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutput,
32
+ CausalLMOutput,
33
+ SequenceClassifierOutput,
34
+ TokenClassifierOutput,
35
+ Wav2Vec2BaseModelOutput,
36
+ XVectorOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.utils import (
40
+ ModelOutput,
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from transformers.models.wav2vec2_conformer.configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+
53
+ _HIDDEN_STATES_START_POSITION = 2
54
+
55
+ # General docstring
56
+ _CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
57
+
58
+ # Base docstring
59
+ _CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
60
+ _EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
61
+
62
+ # CTC docstring
63
+ _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
64
+ _CTC_EXPECTED_LOSS = 64.21
65
+
66
+
67
+ WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
68
+ "facebook/wav2vec2-conformer-rel-pos-large",
69
+ # See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
70
+ ]
71
+
72
+
73
+ @dataclass
74
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
75
+ class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
76
+ """
77
+ Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
78
+
79
+ Args:
80
+ loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
81
+ Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
82
+ paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss.
83
+ projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
84
+ Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
85
+ projected quantized states.
86
+ projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
87
+ Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
88
+ target vectors for contrastive loss.
89
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
90
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
91
+ shape `(batch_size, sequence_length, hidden_size)`.
92
+
93
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
94
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
95
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
96
+ sequence_length)`.
97
+
98
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
99
+ heads.
100
+ contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
101
+ The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
102
+ diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
103
+ The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
104
+ """
105
+
106
+ loss: Optional[torch.FloatTensor] = None
107
+ projected_states: torch.FloatTensor = None
108
+ projected_quantized_states: torch.FloatTensor = None
109
+ codevector_perplexity: torch.FloatTensor = None
110
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
111
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
112
+ contrastive_loss: Optional[torch.FloatTensor] = None
113
+ diversity_loss: Optional[torch.FloatTensor] = None
114
+
115
+
116
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
+ def _compute_mask_indices(
+     shape: Tuple[int, int],
+     mask_prob: float,
+     mask_length: int,
+     attention_mask: Optional[torch.LongTensor] = None,
+     min_masks: int = 0,
+ ) -> np.ndarray:
+     """
+     Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+     ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+     CPU as part of the preprocessing during training.
+
+     Args:
+         shape: The shape for which to compute masks. This should be a tuple of size 2 where
+             the first element is the batch size and the second element is the length of the axis to span.
+         mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+             independently generated mask spans of length `mask_length` is computed by
+             `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+             actual percentage will be smaller.
+         mask_length: size of the mask
+         min_masks: minimum number of masked spans
+         attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+             each batch dimension.
+     """
+     batch_size, sequence_length = shape
+
+     if mask_length < 1:
+         raise ValueError("`mask_length` has to be bigger than 0.")
+
+     if mask_length > sequence_length:
+         raise ValueError(
+             f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+             f" and `sequence_length`: {sequence_length}"
+         )
+
+     # epsilon is used for probabilistic rounding
+     epsilon = np.random.rand(1).item()
+
+     def compute_num_masked_span(input_length):
+         """Given input length, compute how many spans should be masked"""
+         num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+         num_masked_span = max(num_masked_span, min_masks)
+
+         # make sure num masked span <= sequence_length
+         if num_masked_span * mask_length > sequence_length:
+             num_masked_span = sequence_length // mask_length
+
+         # make sure num_masked_span is also <= input_length - (mask_length - 1)
+         if input_length - (mask_length - 1) < num_masked_span:
+             num_masked_span = max(input_length - (mask_length - 1), 0)
+
+         return num_masked_span
+
+     # compute number of masked spans in batch
+     input_lengths = (
+         attention_mask.sum(-1).detach().tolist()
+         if attention_mask is not None
+         else [sequence_length for _ in range(batch_size)]
+     )
+
+     # SpecAugment mask to fill
+     spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+     spec_aug_mask_idxs = []
+
+     max_num_masked_span = compute_num_masked_span(sequence_length)
+
+     if max_num_masked_span == 0:
+         return spec_aug_mask
+
+     for input_length in input_lengths:
+         # compute num of masked spans for this input
+         num_masked_span = compute_num_masked_span(input_length)
+
+         # get random indices to mask
+         spec_aug_mask_idx = np.random.choice(
+             np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+         )
+
+         # pick first sampled index that will serve as a dummy index to pad vector
+         # to ensure same dimension for all batches due to probabilistic rounding
+         # Picking first sample just pads those vectors twice.
+         if len(spec_aug_mask_idx) == 0:
+             # this case can only happen if `input_length` is strictly smaller than
+             # `sequence_length` in which case the last token has to be a padding
+             # token which we can use as a dummy mask id
+             dummy_mask_idx = sequence_length - 1
+         else:
+             dummy_mask_idx = spec_aug_mask_idx[0]
+
+         spec_aug_mask_idx = np.concatenate(
+             [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+         )
+         spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+     spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+     # expand masked indices to masked spans
+     spec_aug_mask_idxs = np.broadcast_to(
+         spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+     )
+     spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+     # add offset to the starting indexes so that indexes now create a span
+     offsets = np.arange(mask_length)[None, None, :]
+     offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+         batch_size, max_num_masked_span * mask_length
+     )
+     spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+     # ensure that we cannot have indices larger than sequence_length
+     if spec_aug_mask_idxs.max() > sequence_length - 1:
+         spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+     # scatter indices to mask
+     np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+     return spec_aug_mask
+
+
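+ # Editor's usage sketch (illustration only; not part of the upstream API): build a
+ # SpecAugment mask for a (2, 100) feature grid, masking ~5% in spans of length 10,
+ # then convert it to the boolean torch tensor that `_mask_hidden_states` expects:
+ #
+ #     mask = _compute_mask_indices(shape=(2, 100), mask_prob=0.05, mask_length=10)
+ #     assert mask.shape == (2, 100) and mask.dtype == bool
+ #     mask_t = torch.tensor(mask, dtype=torch.bool)
+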
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
+ def _sample_negative_indices(
+     features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
+ ):
+     """
+     Sample `num_negatives` vectors from feature vectors.
+     """
+     batch_size, sequence_length = features_shape
+
+     # generate indices of the positive vectors themselves, repeat them `num_negatives` times
+     sequence_length_range = np.arange(sequence_length)
+
+     # get `num_negatives` random vector indices from the same utterance
+     sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
+
+     mask_time_indices = (
+         mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
+     )
+
+     for batch_idx in range(batch_size):
+         high = mask_time_indices[batch_idx].sum() - 1
+         mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
+
+         feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
+         sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
+         # avoid sampling the same positive vector, but keep the distribution uniform
+         sampled_indices[sampled_indices >= feature_indices] += 1
+
+         # remap to actual indices
+         sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
+
+         # correct for batch size
+         sampled_negative_indices[batch_idx] += batch_idx * sequence_length
+
+     return sampled_negative_indices
+
+
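+ # Editor's usage sketch (illustration only): negatives are drawn from the masked
+ # positions of the same utterance, and the returned indices are offset by
+ # `batch_idx * sequence_length`, i.e. they index the batch-flattened feature matrix:
+ #
+ #     neg = _sample_negative_indices((2, 100), num_negatives=100, mask_time_indices=mask)
+ #     assert neg.shape == (2, 100, 100)  # one row of negatives per time step
+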
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+ class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
+     def __init__(self, config, layer_id=0):
+         super().__init__()
+         self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+         self.out_conv_dim = config.conv_dim[layer_id]
+
+         self.conv = nn.Conv1d(
+             self.in_conv_dim,
+             self.out_conv_dim,
+             kernel_size=config.conv_kernel[layer_id],
+             stride=config.conv_stride[layer_id],
+             bias=config.conv_bias,
+         )
+         self.activation = ACT2FN[config.feat_extract_activation]
+
+     def forward(self, hidden_states):
+         hidden_states = self.conv(hidden_states)
+         hidden_states = self.activation(hidden_states)
+         return hidden_states
+
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+ class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
+     def __init__(self, config, layer_id=0):
+         super().__init__()
+         self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+         self.out_conv_dim = config.conv_dim[layer_id]
+
+         self.conv = nn.Conv1d(
+             self.in_conv_dim,
+             self.out_conv_dim,
+             kernel_size=config.conv_kernel[layer_id],
+             stride=config.conv_stride[layer_id],
+             bias=config.conv_bias,
+         )
+         self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+         self.activation = ACT2FN[config.feat_extract_activation]
+
+     def forward(self, hidden_states):
+         hidden_states = self.conv(hidden_states)
+
+         hidden_states = hidden_states.transpose(-2, -1)
+         hidden_states = self.layer_norm(hidden_states)
+         hidden_states = hidden_states.transpose(-2, -1)
+
+         hidden_states = self.activation(hidden_states)
+         return hidden_states
+
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+ class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
+     def __init__(self, config, layer_id=0):
+         super().__init__()
+         self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+         self.out_conv_dim = config.conv_dim[layer_id]
+
+         self.conv = nn.Conv1d(
+             self.in_conv_dim,
+             self.out_conv_dim,
+             kernel_size=config.conv_kernel[layer_id],
+             stride=config.conv_stride[layer_id],
+             bias=config.conv_bias,
+         )
+         self.activation = ACT2FN[config.feat_extract_activation]
+
+         self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
+
+     def forward(self, hidden_states):
+         hidden_states = self.conv(hidden_states)
+         hidden_states = self.layer_norm(hidden_states)
+         hidden_states = self.activation(hidden_states)
+         return hidden_states
+
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
+ class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.conv = nn.Conv1d(
+             config.hidden_size,
+             config.hidden_size,
+             kernel_size=config.num_conv_pos_embeddings,
+             padding=config.num_conv_pos_embeddings // 2,
+             groups=config.num_conv_pos_embedding_groups,
+         )
+
+         if is_deepspeed_zero3_enabled():
+             import deepspeed
+
+             with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
+                 self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+             deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
+             deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
+         else:
+             self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+
+         self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
+         self.activation = ACT2FN[config.feat_extract_activation]
+
+     def forward(self, hidden_states):
+         hidden_states = hidden_states.transpose(1, 2)
+
+         hidden_states = self.conv(hidden_states)
+         hidden_states = self.padding(hidden_states)
+         hidden_states = self.activation(hidden_states)
+
+         hidden_states = hidden_states.transpose(1, 2)
+         return hidden_states
+
+
+ class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
+     """Rotary positional embedding
+     Reference: https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         dim = config.hidden_size // config.num_attention_heads
+         base = config.rotary_embedding_base
+
+         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq)
+         self.cached_sequence_length = None
+         self.cached_rotary_positional_embedding = None
+
+     def forward(self, hidden_states):
+         sequence_length = hidden_states.shape[1]
+
+         if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
+             return self.cached_rotary_positional_embedding
+
+         self.cached_sequence_length = sequence_length
+         time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
+         freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
+         embeddings = torch.cat((freqs, freqs), dim=-1)
+
+         cos_embeddings = embeddings.cos()[:, None, None, :]
+         sin_embeddings = embeddings.sin()[:, None, None, :]
+         self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
+         return self.cached_rotary_positional_embedding
+
+
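+ # Editor's note (sketch): the stacked cache has shape (2, seq_len, 1, 1, head_dim)
+ # and holds the cos/sin tables. For a feature pair (x1, x2) at position t with
+ # inverse frequency theta, the rotation applied downstream is
+ #     x1' = x1 * cos(t * theta) - x2 * sin(t * theta)
+ #     x2' = x2 * cos(t * theta) + x1 * sin(t * theta)
+ # which `Wav2Vec2ConformerSelfAttention._apply_rotary_embedding` realizes via the
+ # "rotate-half" form x' = x * cos + rotate_half(x) * sin.
+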
+ class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
+     """Relative positional encoding module."""
+
+     def __init__(self, config):
+         super().__init__()
+         self.max_len = config.max_source_positions
+         self.d_model = config.hidden_size
+         self.pe = None
+         self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+
+     def extend_pe(self, x):
+         # Reset the positional encodings
+         if self.pe is not None:
+             # self.pe contains both positive and negative parts
+             # the length of self.pe is 2 * input_len - 1
+             if self.pe.size(1) >= x.size(1) * 2 - 1:
+                 if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                 return
+         # Suppose `i` is the position of the query vector and `j` is the
+         # position of the key vector. We use positive relative positions when keys
+         # are to the left (i>j) and negative relative positions otherwise (i<j).
+         pe_positive = torch.zeros(x.size(1), self.d_model)
+         pe_negative = torch.zeros(x.size(1), self.d_model)
+         position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+         div_term = torch.exp(
+             torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
+         )
+         pe_positive[:, 0::2] = torch.sin(position * div_term)
+         pe_positive[:, 1::2] = torch.cos(position * div_term)
+         pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+         pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+         # Reverse the order of positive indices and concat both positive and
+         # negative indices. This is used to support the shifting trick
+         # as in https://arxiv.org/abs/1901.02860
+         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+         pe_negative = pe_negative[1:].unsqueeze(0)
+         pe = torch.cat([pe_positive, pe_negative], dim=1)
+         self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+     def forward(self, hidden_states: torch.Tensor):
+         self.extend_pe(hidden_states)
+         start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
+         end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
+         relative_position_embeddings = self.pe[:, start_idx:end_idx]
+
+         return relative_position_embeddings
+
+
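+ # Editor's note (sketch): `self.pe` holds 2 * max_len - 1 encodings ordered from the
+ # largest positive relative offset down to the most negative one, with offset zero at
+ # the centre; `forward` slices out the 2 * seq_len - 1 entries centred on that midpoint.
+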
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer
+ class Wav2Vec2ConformerSamePadLayer(nn.Module):
+     def __init__(self, num_conv_pos_embeddings):
+         super().__init__()
+         self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
+
+     def forward(self, hidden_states):
+         if self.num_pad_remove > 0:
+             hidden_states = hidden_states[:, :, : -self.num_pad_remove]
+         return hidden_states
+
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
+ class Wav2Vec2ConformerFeatureEncoder(nn.Module):
+     """Construct the features from raw audio waveform"""
+
+     def __init__(self, config):
+         super().__init__()
+
+         if config.feat_extract_norm == "group":
+             conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [
+                 Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1)
+                 for i in range(config.num_feat_extract_layers - 1)
+             ]
+         elif config.feat_extract_norm == "layer":
+             conv_layers = [
+                 Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
+             ]
+         else:
+             raise ValueError(
+                 f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
+             )
+         self.conv_layers = nn.ModuleList(conv_layers)
+         self.gradient_checkpointing = False
+         self._requires_grad = True
+
+     def _freeze_parameters(self):
+         for param in self.parameters():
+             param.requires_grad = False
+         self._requires_grad = False
+
+     def forward(self, input_values):
+         hidden_states = input_values[:, None]
+
+         # make sure hidden_states require grad for gradient_checkpointing
+         if self._requires_grad and self.training:
+             hidden_states.requires_grad = True
+
+         for conv_layer in self.conv_layers:
+             if self._requires_grad and self.gradient_checkpointing and self.training:
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         return module(*inputs)
+
+                     return custom_forward
+
+                 hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(conv_layer),
+                     hidden_states,
+                 )
+             else:
+                 hidden_states = conv_layer(hidden_states)
+
+         return hidden_states
+
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
+ class Wav2Vec2ConformerFeatureProjection(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+         self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+         self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+     def forward(self, hidden_states):
+         # non-projected hidden states are needed for quantization
+         norm_hidden_states = self.layer_norm(hidden_states)
+         hidden_states = self.projection(norm_hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         return hidden_states, norm_hidden_states
+
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
+ class Wav2Vec2ConformerFeedForward(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.intermediate_dropout = nn.Dropout(config.activation_dropout)
+
+         self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
+         if isinstance(config.hidden_act, str):
+             self.intermediate_act_fn = ACT2FN[config.hidden_act]
+         else:
+             self.intermediate_act_fn = config.hidden_act
+
+         self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
+         self.output_dropout = nn.Dropout(config.hidden_dropout)
+
+     def forward(self, hidden_states):
+         hidden_states = self.intermediate_dense(hidden_states)
+         hidden_states = self.intermediate_act_fn(hidden_states)
+         hidden_states = self.intermediate_dropout(hidden_states)
+
+         hidden_states = self.output_dense(hidden_states)
+         hidden_states = self.output_dropout(hidden_states)
+         return hidden_states
+
+
+ class Wav2Vec2ConformerConvolutionModule(nn.Module):
+     """Convolution block used in the conformer block"""
+
+     def __init__(self, config):
+         super().__init__()
+         if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
+             raise ValueError("`config.conv_depthwise_kernel_size` should be an odd number for 'SAME' padding")
+         self.layer_norm = nn.LayerNorm(config.hidden_size)
+         self.pointwise_conv1 = torch.nn.Conv1d(
+             config.hidden_size,
+             2 * config.hidden_size,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             bias=False,
+         )
+         self.glu = torch.nn.GLU(dim=1)
+         self.depthwise_conv = torch.nn.Conv1d(
+             config.hidden_size,
+             config.hidden_size,
+             config.conv_depthwise_kernel_size,
+             stride=1,
+             padding=(config.conv_depthwise_kernel_size - 1) // 2,
+             groups=config.hidden_size,
+             bias=False,
+         )
+         self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
+         self.activation = ACT2FN[config.hidden_act]
+         self.pointwise_conv2 = torch.nn.Conv1d(
+             config.hidden_size,
+             config.hidden_size,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             bias=False,
+         )
+         self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
+
+     def forward(self, hidden_states):
+         hidden_states = self.layer_norm(hidden_states)
+         # exchange the temporal dimension and the feature dimension
+         hidden_states = hidden_states.transpose(1, 2)
+
+         # GLU mechanism
+         # => (batch, 2*channel, time)
+         hidden_states = self.pointwise_conv1(hidden_states)
+         # => (batch, channel, time)
+         hidden_states = self.glu(hidden_states)
+
+         # 1D Depthwise Conv
+         hidden_states = self.depthwise_conv(hidden_states)
+         hidden_states = self.batch_norm(hidden_states)
+         hidden_states = self.activation(hidden_states)
+
+         hidden_states = self.pointwise_conv2(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = hidden_states.transpose(1, 2)
+         return hidden_states
+
+
+ class Wav2Vec2ConformerSelfAttention(nn.Module):
+     """Construct a Wav2Vec2ConformerSelfAttention object.
+     Can be enhanced with rotary or relative position embeddings.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+
+         self.head_size = config.hidden_size // config.num_attention_heads
+         self.num_heads = config.num_attention_heads
+         self.position_embeddings_type = config.position_embeddings_type
+
+         self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
+         self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
+         self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
+         self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
+
+         self.dropout = nn.Dropout(p=config.attention_dropout)
+         self.dropout_p = config.attention_dropout
+
+         self.is_causal = config.is_causal
+
+         if self.position_embeddings_type == "relative":
+             # linear transformation for positional encoding
+             self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+             # these two learnable biases are used in matrix c and matrix d
+             # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+             self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+             self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         relative_position_embeddings: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         # self-attention mechanism
+         batch_size, sequence_length, hidden_size = hidden_states.size()
+
+         # make sure query/key states can be != value states
+         query_key_states = hidden_states
+         value_states = hidden_states
+
+         if self.position_embeddings_type == "rotary":
+             if relative_position_embeddings is None:
+                 raise ValueError(
+                     "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'`"
+                 )
+             query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
+
+         # project query_key_states and value_states
+         query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+         key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+         value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
+
+         # => (batch, head, time1, d_k)
+         query = query.transpose(1, 2)
+         key = key.transpose(1, 2)
+         value = value.transpose(1, 2)
+
+         # fused scaled-dot-product attention, restricted to the FlashAttention kernel
+         with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
+             hidden_states = F.scaled_dot_product_attention(
+                 query, key, value, attn_mask=attention_mask, dropout_p=self.dropout_p, is_causal=self.is_causal
+             )
+         # the fused kernel does not materialize attention probabilities
+         probs = None
+
+         # eager attention path replaced by the fused kernel above:
+         # # apply attention_mask if necessary
+         # if attention_mask is not None:
+         #     scores = scores + attention_mask
+
+         # # => (batch, head, time1, time2)
+         # probs = torch.softmax(scores, dim=-1)
+         # probs = self.dropout(probs)
+
+         # # => (batch, head, time1, d_k)
+         # hidden_states = torch.matmul(probs, value)
+
+         # => (batch, time1, hidden_size)
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
+         hidden_states = self.linear_out(hidden_states)
+
+         return hidden_states, probs
+
+     def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
+         batch_size, sequence_length, hidden_size = hidden_states.size()
+         hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
+
+         cos = relative_position_embeddings[0, :sequence_length, ...]
+         sin = relative_position_embeddings[1, :sequence_length, ...]
+
+         # rotate hidden_states with rotary embeddings
+         hidden_states = hidden_states.transpose(0, 1)
+         rotated_states_begin = hidden_states[..., : self.head_size // 2]
+         rotated_states_end = hidden_states[..., self.head_size // 2 :]
+         rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
+         hidden_states = (hidden_states * cos) + (rotated_states * sin)
+         hidden_states = hidden_states.transpose(0, 1)
+
+         hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
+
+         return hidden_states
+
+     def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
+         # 1. project positional embeddings
+         # => (batch, head, 2*time1-1, d_k)
+         proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
+         proj_relative_position_embeddings = proj_relative_position_embeddings.view(
+             relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
+         )
+         proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
+         proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
+
+         # 2. Add bias to query
+         # => (batch, head, time1, d_k)
+         query = query.transpose(1, 2)
+         q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+         q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+         # 3. attention score: first compute matrix a and matrix c
+         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+         # => (batch, head, time1, time2)
+         scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+         # 4. then compute matrix b and matrix d
+         # => (batch, head, time1, 2*time1-1)
+         scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
+
+         # 5. shift matrix b and matrix d
+         zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
+         scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
+         scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
+         scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
+         scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
+         scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
+
+         # 6. sum matrices
+         # => (batch, head, time1, time2)
+         scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
+
+         return scores
+
+
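+ # Editor's note (sketch): step 5 of `_apply_relative_embeddings` is the Transformer-XL
+ # "relative shift" trick (https://arxiv.org/abs/1901.02860): padding a zero column and
+ # re-viewing the (time1, 2*time1-1) position-score matrix realigns each row so that its
+ # columns index relative offsets, after which only the valid time2 positions are kept.
+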
+ class Wav2Vec2ConformerEncoderLayer(nn.Module):
+     """Conformer block based on https://arxiv.org/abs/2005.08100."""
+
+     def __init__(self, config):
+         super().__init__()
+         embed_dim = config.hidden_size
+         dropout = config.attention_dropout
+
+         # Feed-forward 1
+         self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
+         self.ffn1 = Wav2Vec2ConformerFeedForward(config)
+
+         # Self-Attention
+         self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
+         self.self_attn_dropout = torch.nn.Dropout(dropout)
+         self.self_attn = Wav2Vec2ConformerSelfAttention(config)
+
+         # Conformer Convolution
+         self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
+
+         # Feed-forward 2
+         self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
+         self.ffn2 = Wav2Vec2ConformerFeedForward(config)
+         self.final_layer_norm = nn.LayerNorm(embed_dim)
+
+     def forward(
+         self,
+         hidden_states,
+         attention_mask: Optional[torch.Tensor] = None,
+         relative_position_embeddings: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+     ):
+         # 1. Feed-Forward 1 layer
+         residual = hidden_states
+         hidden_states = self.ffn1_layer_norm(hidden_states)
+         hidden_states = self.ffn1(hidden_states)
+         hidden_states = hidden_states * 0.5 + residual
+         residual = hidden_states
+
+         # 2. Self-Attention layer
+         hidden_states = self.self_attn_layer_norm(hidden_states)
+         hidden_states, attn_weights = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             relative_position_embeddings=relative_position_embeddings,
+             output_attentions=output_attentions,
+         )
+         hidden_states = self.self_attn_dropout(hidden_states)
+         hidden_states = hidden_states + residual
+
+         # 3. Convolutional Layer
+         residual = hidden_states
+         hidden_states = self.conv_module(hidden_states)
+         hidden_states = residual + hidden_states
+
+         # 4. Feed-Forward 2 Layer
+         residual = hidden_states
+         hidden_states = self.ffn2_layer_norm(hidden_states)
+         hidden_states = self.ffn2(hidden_states)
+         hidden_states = hidden_states * 0.5 + residual
+         hidden_states = self.final_layer_norm(hidden_states)
+
+         return hidden_states, attn_weights
+
+
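+ # Editor's note (sketch): the layer follows the "macaron" structure of the Conformer
+ # paper, with half-step residuals around both feed-forward modules:
+ #     x = x + 0.5 * FFN1(LN(x))
+ #     x = x + SelfAttn(LN(x))
+ #     x = x + Conv(x)          # the conv module applies its own LayerNorm first
+ #     x = LN(x + 0.5 * FFN2(LN(x)))
+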
+ class Wav2Vec2ConformerEncoder(nn.Module):
+     def __init__(self, config, is_causal=False):
+         super().__init__()
+         config.is_causal = is_causal
+         self.config = config
+
+         if config.position_embeddings_type == "relative":
+             self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
+         elif config.position_embeddings_type == "rotary":
+             self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
+         else:
+             self.embed_positions = None
+
+         self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
+         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout)
+         self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         hidden_states,
+         attention_mask=None,
+         output_attentions=False,
+         output_hidden_states=False,
+         return_dict=True,
+     ):
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attentions = () if output_attentions else None
+
+         if attention_mask is not None:
+             # make sure padded tokens output 0
+             hidden_states[~attention_mask] = 0.0
+
+             # extend attention_mask
+             attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+             attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
+             attention_mask = attention_mask.expand(
+                 attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
+             )
+
+         hidden_states = self.dropout(hidden_states)
+
+         if self.embed_positions is not None:
+             relative_position_embeddings = self.embed_positions(hidden_states)
+         else:
+             relative_position_embeddings = None
+
+         deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
+
+         for i, layer in enumerate(self.layers):
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+
+             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+             dropout_probability = np.random.uniform(0, 1)
+
+             skip_the_layer = self.training and (dropout_probability < self.config.layerdrop)
+             if not skip_the_layer or deepspeed_zero3_is_enabled:
+                 # under deepspeed zero3 all gpus must run in sync
+                 if self.gradient_checkpointing and self.training:
+                     # create gradient checkpointing function
+                     def create_custom_forward(module):
+                         def custom_forward(*inputs):
+                             return module(*inputs, output_attentions)
+
+                         return custom_forward
+
+                     layer_outputs = torch.utils.checkpoint.checkpoint(
+                         create_custom_forward(layer),
+                         hidden_states,
+                         attention_mask,
+                         relative_position_embeddings,
+                     )
+                 else:
+                     layer_outputs = layer(
+                         hidden_states,
+                         attention_mask=attention_mask,
+                         relative_position_embeddings=relative_position_embeddings,
+                         output_attentions=output_attentions,
+                     )
+                 hidden_states = layer_outputs[0]
+
+             if skip_the_layer:
+                 layer_outputs = (None, None)
+
+             if output_attentions:
+                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+         hidden_states = self.layer_norm(hidden_states)
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+         return BaseModelOutput(
+             last_hidden_state=hidden_states,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attentions,
+         )
+
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
+ class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
+     """
+     Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
+     GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         self.num_groups = config.num_codevector_groups
+         self.num_vars = config.num_codevectors_per_group
+
+         if config.codevector_dim % self.num_groups != 0:
+             raise ValueError(
+                 f"`config.codevector_dim` {config.codevector_dim} must be divisible "
+                 f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
+             )
+
+         # storage for codebook variables (codewords)
+         self.codevectors = nn.Parameter(
+             torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
+         )
+         self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
+
+         # can be decayed for training
+         self.temperature = 2
+
+     @staticmethod
+     def _compute_perplexity(probs, mask=None):
+         if mask is not None:
+             mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
+             probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
+             marginal_probs = probs.sum(dim=0) / mask.sum()
+         else:
+             marginal_probs = probs.mean(dim=0)
+
+         perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
+         return perplexity
+
+     def forward(self, hidden_states, mask_time_indices=None):
+         batch_size, sequence_length, hidden_size = hidden_states.shape
+
+         # project to codevector dim
+         hidden_states = self.weight_proj(hidden_states)
+         hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
+
+         if self.training:
+             # sample code vector probs via gumbel in a differentiable way
+             codevector_probs = nn.functional.gumbel_softmax(
+                 hidden_states.float(), tau=self.temperature, hard=True
+             ).type_as(hidden_states)
+
+             # compute perplexity
+             codevector_soft_dist = torch.softmax(
+                 hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
+             )
+             perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
+         else:
+             # take argmax in non-differentiable way
+             # compute hard codevector distribution (one hot)
+             codevector_idx = hidden_states.argmax(dim=-1)
+             codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
+                 -1, codevector_idx.view(-1, 1), 1.0
+             )
+             codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
+
+             perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
+
+         codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
+         # use probs to retrieve codevectors
+         codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
+         codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
+         codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
+
+         return codevectors, perplexity
+
+
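+ # Editor's note (sketch): the returned perplexity is exp(entropy) of the marginal
+ # code distribution, summed over groups, so it ranges from num_groups (codebook
+ # collapse) up to num_groups * num_vars (uniform usage). Downstream, the wav2vec 2.0
+ # diversity loss (G * V - perplexity) / (G * V) pushes the quantizer toward uniform
+ # code usage.
+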
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
+ class Wav2Vec2ConformerAdapter(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+
+         # feature dim might need to be down-projected
+         if config.output_hidden_size != config.hidden_size:
+             self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
+             self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
+         else:
+             self.proj = self.proj_layer_norm = None
+
+         self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers))
+         self.layerdrop = config.layerdrop
+
+     def forward(self, hidden_states):
+         # down project hidden_states if necessary
+         if self.proj is not None and self.proj_layer_norm is not None:
+             hidden_states = self.proj(hidden_states)
+             hidden_states = self.proj_layer_norm(hidden_states)
+
+         hidden_states = hidden_states.transpose(1, 2)
+
+         for layer in self.layers:
+             layerdrop_prob = np.random.random()
+             if not self.training or (layerdrop_prob > self.layerdrop):
+                 hidden_states = layer(hidden_states)
+
+         hidden_states = hidden_states.transpose(1, 2)
+         return hidden_states
+
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
+ class Wav2Vec2ConformerAdapterLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.conv = nn.Conv1d(
+             config.output_hidden_size,
+             2 * config.output_hidden_size,
+             config.adapter_kernel_size,
+             stride=config.adapter_stride,
+             padding=1,
+         )
+
+     def forward(self, hidden_states):
+         hidden_states = self.conv(hidden_states)
+         hidden_states = nn.functional.glu(hidden_states, dim=1)
+
+         return hidden_states
+
+
+ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = Wav2Vec2ConformerConfig
+     base_model_prefix = "wav2vec2_conformer"
+     main_input_name = "input_values"
+     _keys_to_ignore_on_load_missing = [r"position_ids"]
+     supports_gradient_checkpointing = True
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         # Wav2Vec2ConformerForPreTraining's last 2 linear layers need standard Linear init.
+         if isinstance(module, Wav2Vec2ConformerForPreTraining):
+             module.project_hid.reset_parameters()
+             module.project_q.reset_parameters()
+             module.project_hid._is_hf_initialized = True
+             module.project_q._is_hf_initialized = True
+         # gumbel softmax requires special init
+         elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
+             module.weight_proj.weight.data.normal_(mean=0.0, std=1)
+             module.weight_proj.bias.data.zero_()
+             nn.init.uniform_(module.codevectors)
+         elif isinstance(module, Wav2Vec2ConformerSelfAttention):
+             if hasattr(module, "pos_bias_u"):
+                 nn.init.xavier_uniform_(module.pos_bias_u)
+             if hasattr(module, "pos_bias_v"):
+                 nn.init.xavier_uniform_(module.pos_bias_v)
+         elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
+             nn.init.normal_(
+                 module.conv.weight,
+                 mean=0,
+                 std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
+             )
+             nn.init.constant_(module.conv.bias, 0)
+         elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
+             k = math.sqrt(1 / module.projection.in_features)
+             nn.init.uniform_(module.projection.weight, a=-k, b=k)
+             nn.init.uniform_(module.projection.bias, a=-k, b=k)
+         elif isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+         elif isinstance(module, nn.Conv1d):
+             nn.init.kaiming_normal_(module.weight)
+
+             if module.bias is not None:
+                 k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+                 nn.init.uniform_(module.bias, a=-k, b=k)
+
+     def _get_feat_extract_output_lengths(
+         self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
+     ):
+         """
+         Computes the output length of the convolutional layers
+         """
+
+         add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+         def _conv_out_length(input_length, kernel_size, stride):
+             # 1D convolutional layer output length formula taken
+             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+             return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+         if add_adapter:
+             for _ in range(self.config.num_adapter_layers):
+                 input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
+         return input_lengths
+
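+     # Editor's worked example (assuming the default wav2vec2-style feature extractor,
+     # conv_kernel=(10, 3, 3, 3, 3, 2, 2) and conv_stride=(5, 2, 2, 2, 2, 2, 2)):
+     # applying floor((L - k) / s) + 1 per layer gives
+     # 16000 -> 3199 -> 1599 -> 799 -> 399 -> 199 -> 99 -> 49,
+     # i.e. one second of 16 kHz audio yields 49 feature frames (~20 ms hop).
+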
+     def _get_feature_vector_attention_mask(
+         self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
+     ):
+         # Effectively attention_mask.sum(-1), but not inplace to be able to run
+         # in inference mode.
+         non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+
+         output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+         output_lengths = output_lengths.to(torch.long)
+
+         batch_size = attention_mask.shape[0]
+
+         attention_mask = torch.zeros(
+             (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+         )
+         # these two operations make sure that all values before the output length indices are attended to
+         attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+         attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+         return attention_mask
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
+             module.gradient_checkpointing = value
+
+
+ WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
+     Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
+     Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
+     Auli.
+
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, etc.).
+
+     This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
+     regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
+
+     Parameters:
+         config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
+             Initializing with a config file does not load the weights associated with the model, only the
+             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+
+ WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
+     Args:
+         input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+             Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+             into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
+             soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
+             conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
+         attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+             1]`:
+
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+
+             [What are attention masks?](../glossary#attention-mask)
+
+             <Tip warning={true}>
+
+             `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
+             True`. For all models whose processor has `config.return_attention_mask == False`, such as
+             [wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
+             `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
+             such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
+             that these models also yield slightly different results depending on whether `input_values` is padded or
+             not.
+
+             </Tip>
+
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+
+ @add_start_docstrings(
+     "The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
+     WAV2VEC2_CONFORMER_START_DOCSTRING,
+ )
+ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
+     def __init__(self, config: Wav2Vec2ConformerConfig):
+         super().__init__(config)
+         self.config = config
+         self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
+         self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
+
+         # model only needs masking vector if mask prob is > 0.0
+         if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+             self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
+
+         self.encoder = Wav2Vec2ConformerEncoder(config)
+
+         self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
+     def freeze_feature_encoder(self):
+         """
+         Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+         not be updated during training.
+         """
+         self.feature_extractor._freeze_parameters()
+
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
+     def _mask_hidden_states(
+         self,
+         hidden_states: torch.FloatTensor,
+         mask_time_indices: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+     ):
+         """
+         Masks extracted features along time axis and/or along feature axis according to
+         [SpecAugment](https://arxiv.org/abs/1904.08779).
+         """
+
+         # `config.apply_spec_augment` can set masking to False
+         if not getattr(self.config, "apply_spec_augment", True):
+             return hidden_states
+
+         # generate indices & apply SpecAugment along time axis
+         batch_size, sequence_length, hidden_size = hidden_states.size()
+
+         if mask_time_indices is not None:
+             # apply SpecAugment along time axis with given mask_time_indices
+             hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+         elif self.config.mask_time_prob > 0 and self.training:
+             mask_time_indices = _compute_mask_indices(
+                 (batch_size, sequence_length),
+                 mask_prob=self.config.mask_time_prob,
+                 mask_length=self.config.mask_time_length,
+                 attention_mask=attention_mask,
+                 min_masks=self.config.mask_time_min_masks,
+             )
+             mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
+             hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+
+         if self.config.mask_feature_prob > 0 and self.training:
+             # generate indices & apply SpecAugment along feature axis
+             mask_feature_indices = _compute_mask_indices(
+                 (batch_size, hidden_size),
+                 mask_prob=self.config.mask_feature_prob,
+                 mask_length=self.config.mask_feature_length,
+                 min_masks=self.config.mask_feature_min_masks,
+             )
+             mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
+             mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
+             hidden_states[mask_feature_indices] = 0
+
+         return hidden_states
+
+     @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_CHECKPOINT_FOR_DOC,
+         output_type=Wav2Vec2BaseModelOutput,
+         config_class=_CONFIG_FOR_DOC,
+         modality="audio",
+         expected_output=_EXPECTED_OUTPUT_SHAPE,
+     )
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor],
+         attention_mask: Optional[torch.Tensor] = None,
+         mask_time_indices: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         extract_features = self.feature_extractor(input_values)
+         extract_features = extract_features.transpose(1, 2)
+
+         if attention_mask is not None:
+             # compute reduced attention_mask corresponding to feature vectors
+             attention_mask = self._get_feature_vector_attention_mask(
+                 extract_features.shape[1], attention_mask, add_adapter=False
+             )
+
+         hidden_states, extract_features = self.feature_projection(extract_features)
+         hidden_states = self._mask_hidden_states(
+             hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+         )
+
+         encoder_outputs = self.encoder(
+             hidden_states,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = encoder_outputs[0]
+
+         if self.adapter is not None:
+             hidden_states = self.adapter(hidden_states)
+
+         if not return_dict:
+             return (hidden_states, extract_features) + encoder_outputs[1:]
+
+         return Wav2Vec2BaseModelOutput(
+             last_hidden_state=hidden_states,
+             extract_features=extract_features,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+
+
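+ # Editor's usage sketch (random weights, illustration only):
+ #
+ #     config = Wav2Vec2ConformerConfig()
+ #     model = Wav2Vec2ConformerModel(config)
+ #     wave = torch.randn(1, 16000)          # 1 s of 16 kHz audio
+ #     out = model(wave)                     # Wav2Vec2BaseModelOutput
+ #     out.last_hidden_state.shape           # (1, 49, config.hidden_size) with defaults
+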
+ @add_start_docstrings(
+     """Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
+ )
+ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+     def __init__(self, config: Wav2Vec2ConformerConfig):
+         super().__init__(config)
+         self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+         self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
+
+         self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
+
+         self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
+         self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
+     def set_gumbel_temperature(self, temperature: int):
+         """
+         Set the Gumbel softmax temperature to a given value. Only necessary for training.
+         """
+         self.quantizer.temperature = temperature
+
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+     def freeze_feature_encoder(self):
+         """
+         Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+         not be updated during training.
+         """
+         self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+     @staticmethod
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
+     def compute_contrastive_logits(
+         target_features: torch.FloatTensor,
+         negative_features: torch.FloatTensor,
+         predicted_features: torch.FloatTensor,
+         temperature: float = 0.1,
+     ):
+         """
+         Compute logits for contrastive loss using cosine similarity as the distance measure between
+         `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
+         """
+         target_features = torch.cat([target_features, negative_features], dim=0)
+
+         logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
+             target_features
+         )
+
+         # apply temperature
+         logits = logits / temperature
+         return logits
+
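+     # Editor's note (sketch): `compute_contrastive_logits` concatenates the positive
+     # target in front of the negatives along a new leading axis, so with K negatives
+     # the returned logits have shape (1 + K, batch, seq_len); index 0 along the first
+     # axis is the positive class for the cross-entropy-based contrastive loss.
+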
+     @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor],
+         attention_mask: Optional[torch.Tensor] = None,
+         mask_time_indices: Optional[torch.BoolTensor] = None,
+         sampled_negative_indices: Optional[torch.BoolTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
+         r"""
+         mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Indices to mask extracted features for contrastive loss. When in training mode, the model learns to
+             predict masked extracted features in *config.proj_codevector_dim* space.
+         sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
+             Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
+             Required input for pre-training.
+
+         Returns:
+
+         Example:
+
+         ```python
+         >>> import torch
+         >>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
+         >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
+         ...     _compute_mask_indices,
+         ...     _sample_negative_indices,
+         ... )
+         >>> from datasets import load_dataset
+
+         >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+         >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+
+         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+         >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values  # Batch size 1
+
+         >>> # compute masked indices
+         >>> batch_size, raw_sequence_length = input_values.shape
+         >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
+         >>> mask_time_indices = _compute_mask_indices(
+         ...     shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
+         ... )
+         >>> sampled_negative_indices = _sample_negative_indices(
+         ...     features_shape=(batch_size, sequence_length),
+         ...     num_negatives=model.config.num_negatives,
+         ...     mask_time_indices=mask_time_indices,
+         ... )
+         >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
+         >>> sampled_negative_indices = torch.tensor(
+         ...     data=sampled_negative_indices, device=input_values.device, dtype=torch.long
+         ... )
+
+         >>> with torch.no_grad():
+         ...     outputs = model(input_values, mask_time_indices=mask_time_indices)
+
+         >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
+         >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
+
+         >>> # show that cosine similarity is much higher than random
+         >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
+         tensor(True)
+
+         >>> # for contrastive loss training model should be put into train mode
+         >>> model = model.train()
+         >>> loss = model(
+         ...     input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
+         ... ).loss
+         ```"""
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if mask_time_indices is not None:
+             mask_time_indices = mask_time_indices.to(torch.bool)
+
+         outputs = self.wav2vec2_conformer(
+             input_values,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             mask_time_indices=mask_time_indices,
+             return_dict=return_dict,
+         )
+
+         # 1. project all transformed features (including masked) to final vq dim
+         transformer_features = self.project_hid(outputs[0])
+
+         # 2. quantize all (unmasked) extracted features and project to final vq dim
+         extract_features = self.dropout_features(outputs[1])
+
+         if attention_mask is not None:
+             # compute reduced attention_mask corresponding to feature vectors
1522
+ attention_mask = self._get_feature_vector_attention_mask(
1523
+ extract_features.shape[1], attention_mask, add_adapter=False
1524
+ )
1525
+
1526
+ quantized_features, codevector_perplexity = self.quantizer(
1527
+ extract_features, mask_time_indices=mask_time_indices
1528
+ )
1529
+ quantized_features = self.project_q(quantized_features)
1530
+
1531
+ loss = contrastive_loss = diversity_loss = None
1532
+ if sampled_negative_indices is not None:
1533
+ batch_size, sequence_length, hidden_size = quantized_features.shape
1534
+
1535
+ # for training, we sample negatives
1536
+ # 3. sample K negatives (distractors) quantized states for contrastive loss
1537
+ # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
1538
+ # sample negative quantized vectors BTC => (BxT)C
1539
+ negative_quantized_features = quantized_features.view(-1, hidden_size)[
1540
+ sampled_negative_indices.long().view(-1)
1541
+ ]
1542
+ negative_quantized_features = negative_quantized_features.view(
1543
+ batch_size, sequence_length, -1, hidden_size
1544
+ ).permute(2, 0, 1, 3)
1545
+
1546
+ # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
1547
+ # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
1548
+ logits = self.compute_contrastive_logits(
1549
+ quantized_features[None, :],
1550
+ negative_quantized_features,
1551
+ transformer_features,
1552
+ self.config.contrastive_logits_temperature,
1553
+ )
1554
+
1555
+ # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
1556
+ # its cosine similarity will be masked
1557
+ neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
1558
+
1559
+ if neg_is_pos.any():
1560
+ logits[1:][neg_is_pos] = float("-inf")
1561
+
1562
+ # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
1563
+ # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
1564
+ logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
1565
+ target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
1566
+
1567
+ contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
1568
+ # 7. compute diversity loss: \mathbf{L}_d
1569
+ num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
1570
+ diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
1571
+
1572
+ # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
1573
+ loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
1574
+
1575
+ if not return_dict:
1576
+ if loss is not None:
1577
+ return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
1578
+ return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
1579
+
1580
+ return Wav2Vec2ConformerForPreTrainingOutput(
1581
+ loss=loss,
1582
+ projected_states=transformer_features,
1583
+ projected_quantized_states=quantized_features,
1584
+ codevector_perplexity=codevector_perplexity,
1585
+ hidden_states=outputs.hidden_states,
1586
+ attentions=outputs.attentions,
1587
+ contrastive_loss=contrastive_loss,
1588
+ diversity_loss=diversity_loss,
1589
+ )
1590
+
1591
+
1592
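For reference, steps 6–8 of the `forward` above implement the wav2vec 2.0 pre-training objective (eqs. (3)–(4) of the paper linked in the comments). With $\kappa$ = `contrastive_logits_temperature`, $\alpha$ = `diversity_loss_weight`, and $GV$ = `num_codevector_groups * num_codevectors_per_group`, the loss summed over masked time steps is:

$$
\mathcal{L} = \mathcal{L}_m + \alpha\,\mathcal{L}_d,\qquad
\mathcal{L}_m = -\log\frac{\exp\big(\mathrm{sim}(\mathbf{c}_t,\mathbf{q}_t)/\kappa\big)}{\sum_{\tilde{\mathbf{q}}\in Q_t}\exp\big(\mathrm{sim}(\mathbf{c}_t,\tilde{\mathbf{q}})/\kappa\big)},\qquad
\mathcal{L}_d = \frac{GV-\text{perplexity}}{GV}
$$

where the diversity term is scaled by the number of masked steps, exactly as in the code.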
+ @add_start_docstrings(
+     """Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+     WAV2VEC2_CONFORMER_START_DOCSTRING,
+ )
+ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+         self.dropout = nn.Dropout(config.final_dropout)
+
+         if config.vocab_size is None:
+             raise ValueError(
+                 f"You are trying to instantiate {self.__class__} with a configuration that "
+                 "does not define the vocabulary size of the language model head. Please "
+                 "instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`, "
+                 "or define `vocab_size` in your model's configuration."
+             )
+         output_hidden_size = (
+             config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+         )
+         self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+     def freeze_feature_encoder(self):
+         """
+         Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+         not be updated during training.
+         """
+         self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+     @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_CHECKPOINT_FOR_DOC,
+         output_type=CausalLMOutput,
+         config_class=_CONFIG_FOR_DOC,
+         expected_output=_CTC_EXPECTED_OUTPUT,
+         expected_loss=_CTC_EXPECTED_LOSS,
+     )
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor],
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         labels: Optional[torch.Tensor] = None,
+     ) -> Union[Tuple, CausalLMOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+             Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
+             the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+             All labels set to `-100` are ignored (masked); the loss is only computed for labels in `[0, ...,
+             config.vocab_size - 1]`.
+         """
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.wav2vec2_conformer(
+             input_values,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = outputs[0]
+         hidden_states = self.dropout(hidden_states)
+
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             if labels.max() >= self.config.vocab_size:
+                 raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+             # retrieve loss input_lengths from attention_mask
+             attention_mask = (
+                 attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
+             )
+             input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+
+             # assuming that padded tokens are filled with -100
+             # when not being attended to
+             labels_mask = labels >= 0
+             target_lengths = labels_mask.sum(-1)
+             flattened_targets = labels.masked_select(labels_mask)
+
+             # ctc_loss doesn't support fp16
+             log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+             with torch.backends.cudnn.flags(enabled=False):
+                 loss = nn.functional.ctc_loss(
+                     log_probs,
+                     flattened_targets,
+                     input_lengths,
+                     target_lengths,
+                     blank=self.config.pad_token_id,
+                     reduction=self.config.ctc_loss_reduction,
+                     zero_infinity=self.config.ctc_zero_infinity,
+                 )
+
+         if not return_dict:
+             output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+             return ((loss,) + output) if loss is not None else output
+
+         return CausalLMOutput(
+             loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+         )
+
+
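A minimal inference sketch for the CTC head; the checkpoint name is illustrative (any CTC fine-tuned Wav2Vec2Conformer checkpoint works) and `raw_audio` is assumed to be a 1-D float waveform at 16 kHz:

```python
import torch
from transformers import AutoProcessor, Wav2Vec2ConformerForCTC

# illustrative checkpoint; substitute your own fine-tuned model
name = "facebook/wav2vec2-conformer-rope-large-960h-ft"
processor = AutoProcessor.from_pretrained(name)
model = Wav2Vec2ConformerForCTC.from_pretrained(name)

inputs = processor(raw_audio, sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits        # (batch, frames, vocab_size)
pred_ids = torch.argmax(logits, dim=-1)    # greedy CTC decoding
transcription = processor.batch_decode(pred_ids)
```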
+ @add_start_docstrings(
+     """
+     Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
+     tasks like SUPERB Keyword Spotting.
+     """,
+     WAV2VEC2_CONFORMER_START_DOCSTRING,
+ )
+ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+     def __init__(self, config):
+         super().__init__(config)
+
+         if hasattr(config, "add_adapter") and config.add_adapter:
+             raise ValueError(
+                 "Sequence classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
+             )
+         self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+         num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+         if config.use_weighted_layer_sum:
+             self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+         self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+         self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+     def freeze_feature_encoder(self):
+         """
+         Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+         not be updated during training.
+         """
+         self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+     def freeze_base_model(self):
+         """
+         Calling this function will disable the gradient computation for the base model so that its parameters will not
+         be updated during training. Only the classification head will be updated.
+         """
+         for param in self.wav2vec2_conformer.parameters():
+             param.requires_grad = False
+
+     @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_CHECKPOINT_FOR_DOC,
+         output_type=SequenceClassifierOutput,
+         config_class=_CONFIG_FOR_DOC,
+         modality="audio",
+     )
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor],
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         labels: Optional[torch.Tensor] = None,
+     ) -> Union[Tuple, SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+         outputs = self.wav2vec2_conformer(
+             input_values,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         if self.config.use_weighted_layer_sum:
+             hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+             hidden_states = torch.stack(hidden_states, dim=1)
+             norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+             hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+         else:
+             hidden_states = outputs[0]
+
+         hidden_states = self.projector(hidden_states)
+         if attention_mask is None:
+             pooled_output = hidden_states.mean(dim=1)
+         else:
+             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+             hidden_states[~padding_mask] = 0.0
+             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+         logits = self.classifier(pooled_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
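A minimal sketch of the mask-aware mean pooling used above (shapes are illustrative): padded frames are zeroed out, then each example is averaged over its number of valid frames only.

```python
import torch

hidden_states = torch.randn(2, 6, 4)                      # (batch, frames, proj_size)
padding_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                             [1, 1, 1, 1, 1, 1]]).bool()  # True = real frame
hidden_states[~padding_mask] = 0.0                        # zero out padded frames
pooled = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
```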
+ @add_start_docstrings(
+     """
+     Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
+     """,
+     WAV2VEC2_CONFORMER_START_DOCSTRING,
+ )
+ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+     def __init__(self, config):
+         super().__init__(config)
+
+         if hasattr(config, "add_adapter") and config.add_adapter:
+             raise ValueError(
+                 "Audio frame classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
+             )
+         self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+         num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+         if config.use_weighted_layer_sum:
+             self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+         self.num_labels = config.num_labels
+
+         self.init_weights()
+
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+     def freeze_feature_encoder(self):
+         """
+         Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+         not be updated during training.
+         """
+         self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
+     def freeze_base_model(self):
+         """
+         Calling this function will disable the gradient computation for the base model so that its parameters will not
+         be updated during training. Only the classification head will be updated.
+         """
+         for param in self.wav2vec2_conformer.parameters():
+             param.requires_grad = False
+
+     @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_CHECKPOINT_FOR_DOC,
+         output_type=TokenClassifierOutput,
+         config_class=_CONFIG_FOR_DOC,
+         modality="audio",
+     )
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor],
+         attention_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, TokenClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+         outputs = self.wav2vec2_conformer(
+             input_values,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         if self.config.use_weighted_layer_sum:
+             hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+             hidden_states = torch.stack(hidden_states, dim=1)
+             norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+             hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+         else:
+             hidden_states = outputs[0]
+
+         logits = self.classifier(hidden_states)
+
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
+         if not return_dict:
+             output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+             return output
+
+         return TokenClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
+ class AMSoftmaxLoss(nn.Module):
+     def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
+         super(AMSoftmaxLoss, self).__init__()
+         self.scale = scale
+         self.margin = margin
+         self.num_labels = num_labels
+         self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
+         self.loss = nn.CrossEntropyLoss()
+
+     def forward(self, hidden_states, labels):
+         labels = labels.flatten()
+         weight = nn.functional.normalize(self.weight, dim=0)
+         hidden_states = nn.functional.normalize(hidden_states, dim=1)
+         cos_theta = torch.mm(hidden_states, weight)
+         psi = cos_theta - self.margin
+
+         onehot = nn.functional.one_hot(labels, self.num_labels)
+         logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
+         loss = self.loss(logits, labels)
+
+         return loss
+
+
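The AM-Softmax objective above applies an additive margin $m$ to the target-class cosine logit before scaling; with the defaults $s = 30$ and $m = 0.4$ it reads:

$$
\mathcal{L}_{\mathrm{AMS}} = -\log
\frac{e^{s(\cos\theta_{y} - m)}}{e^{s(\cos\theta_{y} - m)} + \sum_{j \neq y} e^{s\cos\theta_{j}}}
$$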
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
+ class TDNNLayer(nn.Module):
+     def __init__(self, config, layer_id=0):
+         super().__init__()
+         self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
+         self.out_conv_dim = config.tdnn_dim[layer_id]
+         self.kernel_size = config.tdnn_kernel[layer_id]
+         self.dilation = config.tdnn_dilation[layer_id]
+
+         self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
+         self.activation = nn.ReLU()
+
+     def forward(self, hidden_states):
+         hidden_states = hidden_states.unsqueeze(1)
+         hidden_states = nn.functional.unfold(
+             hidden_states,
+             (self.kernel_size, self.in_conv_dim),
+             stride=(1, self.in_conv_dim),
+             dilation=(self.dilation, 1),
+         )
+         hidden_states = hidden_states.transpose(1, 2)
+         hidden_states = self.kernel(hidden_states)
+
+         hidden_states = self.activation(hidden_states)
+         return hidden_states
+
+
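The `unfold` + `nn.Linear` pattern in `TDNNLayer` is functionally a dilated 1-D convolution over time; a rough equivalence sketch (shapes only, not a parameter-for-parameter replacement):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 50, 512)                    # (batch, time, in_conv_dim)
conv = nn.Conv1d(512, 512, kernel_size=5, dilation=1)
y = conv(x.transpose(1, 2)).transpose(1, 2)    # (batch, time - 4, out_conv_dim)
```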
+ @add_start_docstrings(
+     """
+     Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+     """,
+     WAV2VEC2_CONFORMER_START_DOCSTRING,
+ )
+ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+         num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+         if config.use_weighted_layer_sum:
+             self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+         self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
+
+         tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
+         self.tdnn = nn.ModuleList(tdnn_layers)
+
+         self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
+         self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
+
+         self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
+
+         self.init_weights()
+
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+     def freeze_feature_encoder(self):
+         """
+         Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+         not be updated during training.
+         """
+         self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
+     def freeze_base_model(self):
+         """
+         Calling this function will disable the gradient computation for the base model so that its parameters will not
+         be updated during training. Only the classification head will be updated.
+         """
+         for param in self.wav2vec2_conformer.parameters():
+             param.requires_grad = False
+
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
+     def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+         """
+         Computes the output length of the TDNN layers
+         """
+
+         def _conv_out_length(input_length, kernel_size, stride):
+             # 1D convolutional layer output length formula taken
+             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+             return (input_length - kernel_size) // stride + 1
+
+         for kernel_size in self.config.tdnn_kernel:
+             input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
+
+         return input_lengths
+
+     @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_CHECKPOINT_FOR_DOC,
+         output_type=XVectorOutput,
+         config_class=_CONFIG_FOR_DOC,
+         modality="audio",
+     )
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor],
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         labels: Optional[torch.Tensor] = None,
+     ) -> Union[Tuple, XVectorOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+         outputs = self.wav2vec2_conformer(
+             input_values,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         if self.config.use_weighted_layer_sum:
+             hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+             hidden_states = torch.stack(hidden_states, dim=1)
+             norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+             hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+         else:
+             hidden_states = outputs[0]
+
+         hidden_states = self.projector(hidden_states)
+
+         for tdnn_layer in self.tdnn:
+             hidden_states = tdnn_layer(hidden_states)
+
+         # Statistic Pooling
+         if attention_mask is None:
+             mean_features = hidden_states.mean(dim=1)
+             std_features = hidden_states.std(dim=1)
+         else:
+             feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
+             tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
+             mean_features = []
+             std_features = []
+             for i, length in enumerate(tdnn_output_lengths):
+                 mean_features.append(hidden_states[i, :length].mean(dim=0))
+                 std_features.append(hidden_states[i, :length].std(dim=0))
+             mean_features = torch.stack(mean_features)
+             std_features = torch.stack(std_features)
+         statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
+
+         output_embeddings = self.feature_extractor(statistic_pooling)
+         logits = self.classifier(output_embeddings)
+
+         loss = None
+         if labels is not None:
+             loss = self.objective(logits, labels)
+
+         if not return_dict:
+             output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
+             return ((loss,) + output) if loss is not None else output
+
+         return XVectorOutput(
+             loss=loss,
+             logits=logits,
+             embeddings=output_embeddings,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
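The statistic pooling step above reduces a variable-length frame sequence to a fixed-size utterance embedding by concatenating the temporal mean and standard deviation of the TDNN outputs:

$$
\mathbf{s}=\big[\,\mu(\mathbf{h}_{1:T})\,;\,\sigma(\mathbf{h}_{1:T})\,\big]\in\mathbb{R}^{2d},\qquad d=\texttt{tdnn\_dim}[-1]
$$

which is why `feature_extractor` takes an input of size `config.tdnn_dim[-1] * 2`.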
MuCodec/muq_dev/muq_fairseq/models/muq/modules/random_quantizer.py ADDED
@@ -0,0 +1,68 @@
+ import torch
+ from torch import nn, einsum
+ from einops import rearrange
+
+
+ class RandomProjectionQuantizer(nn.Module):
+     """
+     Random projection and codebook lookup module.
+
+     Some code is borrowed from:
+     https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
+     But normalization here uses a pre-computed global mean & variance instead of layer norm.
+     """
+
+     def __init__(
+         self,
+         input_dim,
+         codebook_dim,
+         codebook_size,
+         seed=142,
+     ):
+         super().__init__()
+
+         # random seed
+         torch.manual_seed(seed)
+
+         # randomly initialized projection
+         random_projection = torch.empty(input_dim, codebook_dim)
+         nn.init.xavier_normal_(random_projection)
+         self.register_buffer("random_projection", random_projection)
+
+         # randomly initialized codebook
+         codebook = torch.empty(codebook_size, codebook_dim)
+         nn.init.normal_(codebook)
+         self.register_buffer("codebook", codebook)
+
+     def codebook_lookup(self, x):
+         # reshape
+         b = x.shape[0]
+         x = rearrange(x, "b n e -> (b n) e")
+
+         # L2 normalization
+         normalized_x = nn.functional.normalize(x, dim=1, p=2)
+         normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)
+
+         # compute distances
+         distances = torch.cdist(normalized_codebook, normalized_x)
+
+         # get nearest
+         nearest_indices = torch.argmin(distances, dim=0)
+
+         # reshape
+         xq = rearrange(nearest_indices, "(b n) -> b n", b=b)
+
+         return xq
+
+     @torch.no_grad()
+     def forward(self, x):
+         # always eval
+         self.eval()
+
+         # random projection [batch, length, input_dim] -> [batch, length, codebook_dim]
+         x = einsum("b n d, d e -> b n e", x, self.random_projection)
+
+         # codebook lookup
+         xq = self.codebook_lookup(x)
+
+         return xq
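A minimal usage sketch of `RandomProjectionQuantizer`; the dimensions below are illustrative (the MuQ config in this repo defaults to `codebook_dim=16`, `codebook_size=4096`):

```python
import torch

quantizer = RandomProjectionQuantizer(input_dim=128, codebook_dim=16, codebook_size=4096)
feats = torch.randn(2, 250, 128)   # (batch, frames, feature_dim), e.g. mel bins
tokens = quantizer(feats)          # (2, 250) integer codebook indices, no gradients
```

Because both the projection and the codebook are frozen random buffers, the module is deterministic for a given seed and is used purely as a label generator.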
MuCodec/muq_dev/muq_fairseq/models/muq/muq_model.py ADDED
@@ -0,0 +1,139 @@
+ try:
+     from .model.muq import MuQ
+ except ImportError:
+     import sys, os
+     sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+     from model.muq import MuQ
+ try:
+     from fairseq.fairseq.dataclass import FairseqDataclass
+     from fairseq.fairseq.models import BaseFairseqModel, register_model
+     from fairseq.fairseq.tasks.fairseq_task import FairseqTask
+ except ImportError:
+     from fairseq.dataclass import FairseqDataclass
+     from fairseq.models import BaseFairseqModel, register_model
+     from fairseq.tasks.fairseq_task import FairseqTask
+
+ from dataclasses import dataclass, field
+ from typing import List, Tuple, Optional
+ import torch
+
+ from logging import getLogger
+
+ logger = getLogger(__name__)
+
+ @dataclass
+ class MuQConfig(FairseqDataclass):
+     label_rate: int = field(default=25)
+     num_codebooks: int = field(default=1)
+     codebook_dim: int = field(default=16)
+     codebook_size: int = field(default=4096)
+     features: List[str] = field(default_factory=lambda: ["melspec_2048"])
+     hop_length: int = field(default=240)
+     n_mels: int = field(default=128)
+     conv_dim: int = field(default=512)
+     encoder_dim: int = field(default=1024)
+     encoder_depth: int = field(default=12)
+     mask_hop: float = field(default=0.4)
+     mask_prob: float = field(default=0.6)
+     is_flash: bool = field(default=False)
+     stat_path: Optional[str] = field(default=None)
+     model_path: Optional[str] = field(default=None)
+     w2v2_config_path: Optional[str] = field(default=None)
+     use_rvq_target: bool = field(default=False)
+     use_vq_target: bool = field(default=False)
+     rvq_ckpt_path: Optional[str] = field(default=None)
+     recon_loss_ratio: Optional[float] = field(default=None)
+     resume_checkpoint: Optional[str] = None
+     use_hubert_masking_strategy: bool = field(default=False)
+     use_hubert_featurizer: bool = field(default=False)
+     hubert_conv_feature_layers: str = field(default_factory=lambda: "[(512,10,5)] + [(512,3,2)] * 3 + [(512,3,3)] + [(512,2,2)] * 2")
+     rvq_n_codebooks: int = field(default=8)
+     rvq_multi_layer_num: int = field(default=1)
+     use_encodec_target: bool = field(default=False)
+
+ SAMPLE_RATE = 24_000
+
+ @register_model("muq", dataclass=MuQConfig)
+ class MuQModel(BaseFairseqModel):
+     def __init__(self, cfg: MuQConfig, task_cfg: FairseqTask):
+         super().__init__()
+         self.cfg = cfg
+         self.model = MuQ(
+             num_codebooks=cfg.num_codebooks,
+             codebook_dim=cfg.codebook_dim,
+             codebook_size=cfg.codebook_size,
+             features=cfg.features,
+             n_mels=cfg.n_mels,
+             conv_dim=cfg.conv_dim,
+             encoder_dim=cfg.encoder_dim,
+             encoder_depth=cfg.encoder_depth,
+             mask_hop=cfg.mask_hop,
+             mask_prob=cfg.mask_prob,
+             is_flash=cfg.is_flash,
+             stat_path=cfg.stat_path,
+             model_path=cfg.model_path,
+             w2v2_config_path=cfg.w2v2_config_path,
+             use_rvq_target=cfg.use_rvq_target,
+             use_vq_target=cfg.use_vq_target,
+             rvq_ckpt_path=cfg.rvq_ckpt_path,
+             recon_loss_ratio=cfg.recon_loss_ratio,
+             label_rate=cfg.label_rate,
+             use_hubert_masking_strategy=cfg.use_hubert_masking_strategy,
+             use_hubert_featurizer=cfg.use_hubert_featurizer,
+             hubert_conv_feature_layers=cfg.hubert_conv_feature_layers,
+             rvq_n_codebooks=cfg.rvq_n_codebooks,
+             rvq_multi_layer_num=cfg.rvq_multi_layer_num,
+             use_encodec_target=cfg.use_encodec_target,
+         )
+
+     def forward(
+         self,
+         source: torch.Tensor,  # B,L
+         features_only: bool = False,
+         label=None,  # pre-extracted labels, dim is [Batch, N_Codebook, SeqLen]
+         **kwargs,
+     ):
+         # trim the waveform to a whole number of label frames (SAMPLE_RATE // label_rate samples per frame)
+         source = source[..., : int((source.shape[-1] // (SAMPLE_RATE // self.cfg.label_rate)) * (SAMPLE_RATE // self.cfg.label_rate))]
+         if features_only:
+             if 'attention_mask' in kwargs:
+                 attention_mask = kwargs['attention_mask']
+             elif 'padding_mask' in kwargs:
+                 attention_mask = ~kwargs['padding_mask'].bool()
+             else:
+                 attention_mask = None
+             _, hidden_states = self.model.get_predictions(source, attention_mask=attention_mask, is_features_only=True)
+             result = {
+                 "layer_results": hidden_states
+             }
+             return result
+         else:
+             result = {}
+             logits, hidden_emb, losses, accuracies = self.model(source, label=label)
+             result["losses"] = losses
+             result["accuracies"] = accuracies
+             result["logits"] = logits
+             result["hidden_emb"] = hidden_emb
+             for k, v in losses.items():
+                 result[k] = v
+             return result
+
+     @classmethod
+     def build_model(cls, cfg: MuQConfig, task: FairseqTask):
+         """Build a new model instance."""
+
+         model = MuQModel(cfg, task.cfg)
+         import numpy as np
+         s = 0
+         for param in model.parameters():
+             s += np.prod(param.size())
+         # print('# of parameters: ' + str(s / 1024.0 / 1024.0))
+
+         if cfg.get("resume_checkpoint", None):
+             print("Loading checkpoint from {}".format(cfg.resume_checkpoint))
+             model.load_state_dict(torch.load(cfg.resume_checkpoint)['model'], strict=False)
+
+         return model
+
+     def get_losses(self, result, batch):
+         return result['losses']
+
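A hedged usage sketch of `MuQModel` feature extraction; the waveform length and call pattern are illustrative and the model is assumed to be built by fairseq from a `MuQConfig`:

```python
import torch

# wav: 5 seconds of 24 kHz audio; forward() trims it to a whole number of
# label frames (SAMPLE_RATE // label_rate samples per frame)
wav = torch.randn(1, 24_000 * 5)
# model = MuQModel.build_model(cfg, task)   # built by fairseq from MuQConfig
# out = model(wav, features_only=True)
# hidden_states = out["layer_results"]      # per-layer hidden states
```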
MuCodec/muq_dev/muq_fairseq/tasks/__pycache__/muq_pretraining.cpython-310.pyc ADDED
Binary file (9.93 kB).
 
MuCodec/muq_dev/muq_fairseq/tasks/muq_pretraining.py ADDED
@@ -0,0 +1,354 @@
+ # Copyright (c) 2017-present, Facebook, Inc.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the LICENSE file in
+ # the root directory of this source tree. An additional grant of patent rights
+ # can be found in the PATENTS file in the same directory.
+
+ import logging
+ import os
+ import sys
+ from typing import Dict, List, Optional, Tuple
+
+ import numpy as np
+ import torch
+
+ from dataclasses import dataclass, field
+ from fairseq.data import Dictionary, HubertDataset
+ from fairseq.dataclass.configs import FairseqDataclass
+ from fairseq.tasks import register_task
+ from fairseq.tasks.fairseq_task import FairseqTask
+ from omegaconf import MISSING
+
+ from ..data.mert_dataset import MERTDataset
+ from ..data.ark_dataset import ArkDataset
+
+ logger = logging.getLogger(__name__)
+
+
+ class LabelEncoder(object):
+     def __init__(self, dictionary: Dictionary) -> None:
+         self.dictionary = dictionary
+
+     def __call__(self, label: str) -> List[str]:
+         # encode_line returns a torch.IntTensor; should be all 1s for vanilla HuBERT
+         return self.dictionary.encode_line(
+             label,
+             append_eos=False,
+             add_if_not_exist=False,
+         )
+
+
+ class PaddedNumpyLabelEncoder(object):
+     def __init__(self):
+         # self.dictionary = dictionary
+         pass
+
+     def __call__(self, label):
+         t = torch.IntTensor(np.asarray(label))
+         t = t[t >= 0]  # remove padded -1 values at the end
+         return t
+
+
+ @dataclass
+ class MuQPretrainingConfig(FairseqDataclass):
+     data: str = field(default=MISSING, metadata={"help": "path to data directory"})
+     sharding_data: int = field(
+         default=-1,
+         metadata={
+             "help": "set this param >1 to use a sharded dataset to prevent OOM; "
+             "prepare data tsv and label files by adding a postfix, e.g. shard 28 of 64: "
+             "train_28_64.tsv and train_28_64.encodec_6"
+         },
+     )
+     load_random_data_shard: bool = field(
+         default=True,
+         metadata={
+             "help": "whether to load shards randomly or in order when using sharding_data"
+         },
+     )
+     fine_tuning: bool = field(
+         default=False, metadata={"help": "set to true if fine-tuning Hubert"}
+     )
+     labels: List[str] = field(
+         default_factory=lambda: ["ltr"],
+         metadata={
+             "help": (
+                 "extension of the label files to load: frame-level labels for"
+                 " pre-training, and sequence-level labels for fine-tuning"
+             )
+         },
+     )
+     label_dir: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "if set, looks for labels in this directory instead",
+         },
+     )
+     label_scp_path: Optional[str] = field(
+         default=None,
+         metadata={
+             'help': 'if set, load labels from an scp file'
+         }
+     )
+     label_scp_clip_duration: float = field(
+         default=-1,
+         metadata={
+             'help': 'clip duration for loading scp labels; if set to -1, this has no effect.'
+         }
+     )
+     label_rate: float = field(
+         default=-1.0,
+         metadata={"help": "label frame rate. -1.0 for sequence labels"},
+     )
+     sample_rate: int = field(
+         default=16_000,
+         metadata={
+             "help": "target sample rate. audio files will be up/down "
+             "sampled to this rate"
+         },
+     )
+     normalize: bool = field(
+         default=False,
+         metadata={"help": "if set, normalizes input to have 0 mean and unit variance"},
+     )
+     enable_padding: bool = field(
+         default=False,
+         metadata={"help": "pad shorter samples instead of cropping"},
+     )
+     max_keep_size: Optional[int] = field(
+         default=None,
+         metadata={"help": "exclude samples longer than this"},
+     )
+     max_sample_size: Optional[int] = field(
+         default=None,
+         metadata={"help": "max sample size to crop to for batching"},
+     )
+     min_sample_size: Optional[int] = field(
+         default=None,
+         metadata={"help": "min sample size to crop to for batching"},
+     )
+     single_target: Optional[bool] = field(
+         default=False,
+         metadata={
+             "help": "if set, AddTargetDatasets outputs same keys as AddTargetDataset"
+         },
+     )
+     random_crop: Optional[bool] = field(
+         default=True,
+         metadata={"help": "always crop from the beginning if false"},
+     )
+     pad_audio: Optional[bool] = field(
+         default=False,
+         metadata={"help": "pad audio to the longest one in the batch if true"},
+     )
+
+     store_labels: Optional[bool] = field(
+         default=False,
+         metadata={"help": "whether to load all of the labels into memory"},
+     )
+
+     numpy_memmap_label: Optional[bool] = field(
+         default=False,
+         metadata={"help": "whether the label file is saved as a numpy file, each line ended with -1 padding"},
+     )
+
+     augmentation_effects: Optional[str] = field(
+         default="[]",
+         metadata={
+             "help": (
+                 "a list of effects that may be applied to the audio, "
+                 "example: \"['random_mute', 'random_Gaussian', 'reverse_polarity']\"; "
+                 "supported: random_mute"
+             )
+         },
+     )
+     augmentation_probs: Optional[str] = field(
+         default="[]",
+         metadata={
+             "help": (
+                 "the corresponding probabilities for the data augmentation effects, "
+                 "example: \"[0.1, 0.5, 0.8]\"; "
+                 "the sum does not need to be 1.0, and multiple effects can be applied to the same audio"
+             )
+         },
+     )
+
+     # inbatch_noise_augment_len_range: Optional[List[int]] = field(
+     #     default_factory=lambda: [8000, 24000],
+     #     default=[8000, 24000],
+     inbatch_noise_augment_len_range: Optional[str] = field(
+         default="[8000, 24000]",
+         metadata={
+             "help": (
+                 "the length range of the mix-up noise augmentation, in samples"
+             )
+         },
+     )
+     # inbatch_noise_augment_number_range: Optional[List[int]] = field(
+     #     default_factory=lambda: [1, 3],
+     #     default=[1, 3],
+     inbatch_noise_augment_number_range: Optional[str] = field(
+         default="[1, 3]",
+         metadata={
+             "help": (
+                 "the range of the number of mix-up noise augmentations"
+             )
+         },
+     )
+     inbatch_noise_augment_volume: float = field(
+         default=1.0,
+         metadata={
+             "help": (
+                 "the coefficient used to modify the volume of the noise audio wavs"
+             )
+         },
+     )
+     dynamic_crops: Optional[str] = field(
+         default="[]",
+         metadata={
+             "help": (
+                 "used to set the maximum audio length for training, "
+                 "example: \"[1, 2, 3, 4, 5, 10]\""
+             )
+         },
+     )
+     dynamic_crops_epoches: Optional[str] = field(
+         default="[]",
+         metadata={
+             "help": (
+                 "used to set the training epochs at which the maximum audio length changes, "
+                 "example: \"[1, 10, 20, 40, 80, 160]\"; "
+                 "its length needs to equal len(dynamic_crops)"
+             )
+         },
+     )
+
+     cqt_loss_bin_dataloader: Optional[int] = field(
+         default=-1,
+         metadata={
+             "help": (
+                 "use this parameter to prepare the cqt prediction objective in the dataloader"
+             )
+         },
+     )
+
+     clip_secs: int = field(
+         default=5,
+         metadata={
+             "help": "clip secs for each audio"
+         }
+     )
+
+     dataset_shuffle: bool = field(
+         default=True,
+         metadata={
+             "help": (
+                 "dataset shuffle when sampling a batch"
+             )
+         },
+     )
+
+
251
+ @register_task("muq_pretraining", dataclass=MuQPretrainingConfig)
252
+ class MuQPretrainingTask(FairseqTask):
253
+
254
+ cfg: MuQPretrainingConfig
255
+
256
+ def __init__(
257
+ self,
258
+ cfg: MuQPretrainingConfig,
259
+ ) -> None:
260
+ super().__init__(cfg)
261
+
262
+ logger.info(f"current directory is {os.getcwd()}")
263
+ logger.info(f"MuQPretrainingTask Config {cfg}")
264
+
265
+ self.cfg = cfg
266
+ self.fine_tuning = cfg.fine_tuning
267
+
268
+ if cfg.fine_tuning:
269
+ self.state.add_factory("target_dictionary", self.load_dictionaries)
270
+ else:
271
+ self.state.add_factory("dictionaries", self.load_dictionaries)
272
+
273
+ self.blank_symbol = "<s>"
274
+
275
+ # use eval() to pass list parameters, skirt the fairseq/torch error: Can't pickle <enum 'Choices'>: attribute lookup Choices on fairseq.dataclass.constants failed
276
+ self.augmentation_effects = eval(self.cfg.augmentation_effects)
277
+ self.augmentation_probs = eval(self.cfg.augmentation_probs)
278
+ if len(self.augmentation_effects) > 0:
279
+ assert len(self.augmentation_effects) == len(self.augmentation_probs)
280
+ logger.info(f"Applying audio augmentation {self.augmentation_effects}, probabilities: {self.augmentation_probs}")
281
+
282
+ self.inbatch_noise_augment_number_range = eval(self.cfg.inbatch_noise_augment_number_range)
283
+ self.inbatch_noise_augment_len_range = eval(self.cfg.inbatch_noise_augment_len_range)
284
+
285
+ self.max_sample_size = self.cfg.max_sample_size
286
+
287
+ self.dynamic_crops = eval(self.cfg.dynamic_crops)
288
+ self.dynamic_crops_epoches = eval(self.cfg.dynamic_crops_epoches)
289
+ assert len(self.dynamic_crops) == len(self.dynamic_crops_epoches)
290
+ if len(self.dynamic_crops) > 0:
291
+ assert self.dynamic_crops_epoches[0] == 1
292
+
293
+ self.cqt_loss_bin_dataloader = self.cfg.cqt_loss_bin_dataloader
294
+
295
+ self.numpy_memmap_label = self.cfg.numpy_memmap_label
296
+ self.store_labels = self.cfg.store_labels
297
+ if self.numpy_memmap_label:
298
+ assert self.store_labels
299
+
300
+ @property
301
+ def source_dictionary(self) -> Optional[Dictionary]:
302
+ return None
303
+
304
+ @property
305
+ def target_dictionary(self) -> Optional[Dictionary]:
306
+ return self.state.target_dictionary
307
+
308
+ @property
309
+ def dictionaries(self) -> List[Dictionary]:
310
+ return self.state.dictionaries
311
+
312
+ @classmethod
313
+ def setup_task(
314
+ cls, cfg: MuQPretrainingConfig, **kwargs
315
+ ) -> "MuQPretrainingTask":
316
+ return cls(cfg)
317
+
318
+ def load_dictionaries(self):
319
+ label_dir = self.cfg.data if (self.cfg.label_dir is None or self.cfg.label_dir == '') else self.cfg.label_dir
320
+ print(label_dir)
321
+ dictionaries = [
322
+ Dictionary.load(f"{label_dir}/dict.{label}.txt")
323
+ for label in self.cfg.labels
324
+ ]
325
+ return dictionaries[0] if self.cfg.fine_tuning else dictionaries
326
+
327
+ def get_label_dir(self) -> str:
328
+ if self.cfg.label_dir is None or self.cfg.label_dir=='':
329
+ return self.cfg.data
330
+ return self.cfg.label_dir
331
+
332
+
333
+ def is_force_load_dataset(self, epoch, training_restore=False):
334
+ # find the threshold that holds epoch \in [threshold, next_threshold)
335
+ return (epoch in self.dynamic_crops_epoches) or training_restore or (self.cfg.sharding_data > 1)
336
+
337
+
338
+ def set_dynamic_crop_max_sample(self, epoch):
339
+ pass
340
+
341
+ def load_dataset(self, split: str, **kwargs) -> None:
342
+ pass
343
+
344
+ def load_dataset_ark(self, split, **kwargs):
345
+ pass
346
+
347
+ def load_dataset_mert(self, split: str, **kwargs) -> None:
348
+ pass
349
+
350
+ def max_positions(self) -> Tuple[int, int]:
351
+ return (sys.maxsize, sys.maxsize)
352
+
353
+ def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array:
354
+ return indices
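A tiny sketch of how `PaddedNumpyLabelEncoder` behaves: trailing -1 padding is stripped before the labels are used (values illustrative):

```python
enc = PaddedNumpyLabelEncoder()
enc([5, 12, 3, -1, -1])   # -> tensor([ 5, 12,  3], dtype=torch.int32)
```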
MuCodec/tools/__pycache__/get_melvaehifigan48k.cpython-310.pyc ADDED
Binary file (35.6 kB).
 
MuCodec/tools/__pycache__/torch_tools.cpython-310.pyc ADDED
Binary file (2.74 kB).
 
MuCodec/tools/__pycache__/torch_tools.cpython-312.pyc ADDED
Binary file (4.48 kB).
 
checkpoints/Qwen3-0.6B/.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
checkpoints/Qwen3-0.6B/LICENSE ADDED
@@ -0,0 +1,202 @@
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2024 Alibaba Cloud
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
checkpoints/Qwen3-0.6B/README.md ADDED
@@ -0,0 +1,301 @@
+ ---
+ library_name: transformers
+ license: apache-2.0
+ license_link: https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/LICENSE
+ pipeline_tag: text-generation
+ base_model:
+ - Qwen/Qwen3-0.6B-Base
+ ---
+
+ # Qwen3-0.6B
+ <a href="https://chat.qwen.ai/" target="_blank" style="margin: 2px;">
+     <img alt="Chat" src="https://img.shields.io/badge/%F0%9F%92%9C%EF%B8%8F%20Qwen%20Chat%20-536af5" style="display: inline-block; vertical-align: middle;"/>
+ </a>
+
+ ## Qwen3 Highlights
+
+ Qwen3 is the latest generation of large language models in the Qwen series, offering a comprehensive suite of dense and mixture-of-experts (MoE) models. Built upon extensive training, Qwen3 delivers groundbreaking advancements in reasoning, instruction-following, agent capabilities, and multilingual support, with the following key features:
+
+ - **Unique support for seamless switching between thinking mode** (for complex logical reasoning, math, and coding) and **non-thinking mode** (for efficient, general-purpose dialogue) **within a single model**, ensuring optimal performance across various scenarios.
+ - **Significant enhancement of its reasoning capabilities**, surpassing the previous QwQ (in thinking mode) and Qwen2.5 instruct models (in non-thinking mode) on mathematics, code generation, and commonsense logical reasoning.
+ - **Superior human preference alignment**, excelling in creative writing, role-playing, multi-turn dialogues, and instruction following, delivering a more natural, engaging, and immersive conversational experience.
+ - **Expertise in agent capabilities**, enabling precise integration with external tools in both thinking and non-thinking modes and achieving leading performance among open-source models in complex agent-based tasks.
+ - **Support for 100+ languages and dialects** with strong capabilities for **multilingual instruction following** and **translation**.
+
+ ## Model Overview
+
+ **Qwen3-0.6B** has the following features:
+ - Type: Causal Language Models
+ - Training Stage: Pretraining & Post-training
+ - Number of Parameters: 0.6B
+ - Number of Parameters (Non-Embedding): 0.44B
+ - Number of Layers: 28
+ - Number of Attention Heads (GQA): 16 for Q and 8 for KV
+ - Context Length: 32,768
+
+ For more details, including benchmark evaluation, hardware requirements, and inference performance, please refer to our [blog](https://qwenlm.github.io/blog/qwen3/), [GitHub](https://github.com/QwenLM/Qwen3), and [Documentation](https://qwen.readthedocs.io/en/latest/).
+
+ > [!TIP]
+ > If you encounter significant endless repetitions, please refer to the [Best Practices](#best-practices) section for optimal sampling parameters, and set the ``presence_penalty`` to 1.5.
+
+ ## Quickstart
+
+ The code for Qwen3 has been merged into the latest Hugging Face `transformers`, and we advise you to use the latest version of `transformers`.
+
+ With `transformers<4.51.0`, you will encounter the following error:
+ ```
+ KeyError: 'qwen3'
+ ```
+
+ The following code snippet illustrates how to use the model to generate content based on given inputs.
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model_name = "Qwen/Qwen3-0.6B"
+
+ # load the tokenizer and the model
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype="auto",
+     device_map="auto"
+ )
+
+ # prepare the model input
+ prompt = "Give me a short introduction to large language model."
+ messages = [
+     {"role": "user", "content": prompt}
+ ]
+ text = tokenizer.apply_chat_template(
+     messages,
+     tokenize=False,
+     add_generation_prompt=True,
+     enable_thinking=True  # Switches between thinking and non-thinking modes. Default is True.
+ )
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+ # conduct text completion
+ generated_ids = model.generate(
+     **model_inputs,
+     max_new_tokens=32768
+ )
+ output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
+
+ # parsing thinking content
+ try:
+     # rindex finding 151668 (</think>)
+     index = len(output_ids) - output_ids[::-1].index(151668)
+ except ValueError:
+     index = 0
+
+ thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
+ content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
+
+ print("thinking content:", thinking_content)
+ print("content:", content)
+ ```
+
+ For deployment, you can use `sglang>=0.4.6.post1` or `vllm>=0.8.5` to create an OpenAI-compatible API endpoint:
+ - SGLang:
+     ```shell
+     python -m sglang.launch_server --model-path Qwen/Qwen3-0.6B --reasoning-parser qwen3
+     ```
+ - vLLM:
+     ```shell
+     vllm serve Qwen/Qwen3-0.6B --enable-reasoning --reasoning-parser deepseek_r1
+     ```
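+
+ Once the server is up, you can query it with any OpenAI-compatible client. The snippet below is a minimal sketch using the official `openai` Python package; the base URL assumes the vLLM command above with its default port 8000 (adjust it for SGLang, whose default port is 30000), and `presence_penalty` is shown only as an optional guard against endless repetitions.
+ ```python
+ from openai import OpenAI
+
+ # Assumed local endpoint started by the vLLM command above (default port 8000).
+ client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
+
+ response = client.chat.completions.create(
+     model="Qwen/Qwen3-0.6B",
+     messages=[{"role": "user", "content": "Give me a short introduction to large language models."}],
+     temperature=0.6,
+     top_p=0.95,
+     presence_penalty=1.5,  # optional: mitigates endless repetitions
+ )
+ print(response.choices[0].message.content)
+ ```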
+
+ For local use, applications such as Ollama, LMStudio, MLX-LM, llama.cpp, and KTransformers also support Qwen3.
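+ For example, with Ollama (assuming the `qwen3` tag published on the Ollama registry):
+ ```shell
+ ollama run qwen3:0.6b
+ ```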
+
+ ## Switching Between Thinking and Non-Thinking Mode
+
+ > [!TIP]
+ > The `enable_thinking` switch is also available in APIs created by SGLang and vLLM.
+ > Please refer to our documentation for [SGLang](https://qwen.readthedocs.io/en/latest/deployment/sglang.html#thinking-non-thinking-modes) and [vLLM](https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes) users.
+
+ ### `enable_thinking=True`
+
+ By default, Qwen3 has thinking capabilities enabled, similar to QwQ-32B. This means the model will use its reasoning abilities to enhance the quality of generated responses. For example, when explicitly setting `enable_thinking=True` or leaving it as the default value in `tokenizer.apply_chat_template`, the model will engage its thinking mode.
+
+ ```python
+ text = tokenizer.apply_chat_template(
+     messages,
+     tokenize=False,
+     add_generation_prompt=True,
+     enable_thinking=True  # True is the default value for enable_thinking
+ )
+ ```
+
+ In this mode, the model will generate think content wrapped in a `<think>...</think>` block, followed by the final response.
+
+ > [!NOTE]
+ > For thinking mode, use `Temperature=0.6`, `TopP=0.95`, `TopK=20`, and `MinP=0` (the default setting in `generation_config.json`). **DO NOT use greedy decoding**, as it can lead to performance degradation and endless repetitions. For more detailed guidance, please refer to the [Best Practices](#best-practices) section.
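+
+ If your serving stack does not read `generation_config.json`, these parameters can be passed explicitly. A minimal sketch, reusing `model` and `model_inputs` from the Quickstart above (`min_p` requires a reasonably recent `transformers` release):
+ ```python
+ # Thinking-mode sampling; do_sample must be True, since greedy decoding degrades quality.
+ generated_ids = model.generate(
+     **model_inputs,
+     max_new_tokens=32768,
+     do_sample=True,
+     temperature=0.6,
+     top_p=0.95,
+     top_k=20,
+     min_p=0.0,
+ )
+ ```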
+
+
+ ### `enable_thinking=False`
+
+ We provide a hard switch to strictly disable the model's thinking behavior, aligning its functionality with the previous Qwen2.5-Instruct models. This mode is particularly useful in scenarios where disabling thinking is essential for enhancing efficiency.
+
+ ```python
+ text = tokenizer.apply_chat_template(
+     messages,
+     tokenize=False,
+     add_generation_prompt=True,
+     enable_thinking=False  # Setting enable_thinking=False disables thinking mode
+ )
+ ```
+
+ In this mode, the model will not generate any think content and will not include a `<think>...</think>` block.
+
+ > [!NOTE]
+ > For non-thinking mode, we suggest using `Temperature=0.7`, `TopP=0.8`, `TopK=20`, and `MinP=0`. For more detailed guidance, please refer to the [Best Practices](#best-practices) section.
+
+ ### Advanced Usage: Switching Between Thinking and Non-Thinking Modes via User Input
+
+ We provide a soft switch mechanism that allows users to dynamically control the model's behavior when `enable_thinking=True`. Specifically, you can add `/think` and `/no_think` to user prompts or system messages to switch the model's thinking mode from turn to turn. The model will follow the most recent instruction in multi-turn conversations.
+
+ Here is an example of a multi-turn conversation:
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ class QwenChatbot:
+     def __init__(self, model_name="Qwen/Qwen3-0.6B"):
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = AutoModelForCausalLM.from_pretrained(model_name)
+         self.history = []
+
+     def generate_response(self, user_input):
+         messages = self.history + [{"role": "user", "content": user_input}]
+
+         text = self.tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=True
+         )
+
+         inputs = self.tokenizer(text, return_tensors="pt")
+         response_ids = self.model.generate(**inputs, max_new_tokens=32768)[0][len(inputs.input_ids[0]):].tolist()
+         response = self.tokenizer.decode(response_ids, skip_special_tokens=True)
+
+         # Update history
+         self.history.append({"role": "user", "content": user_input})
+         self.history.append({"role": "assistant", "content": response})
+
+         return response
+
+ # Example Usage
+ if __name__ == "__main__":
+     chatbot = QwenChatbot()
+
+     # First input (without /think or /no_think tags, thinking mode is enabled by default)
+     user_input_1 = "How many r's in strawberries?"
+     print(f"User: {user_input_1}")
+     response_1 = chatbot.generate_response(user_input_1)
+     print(f"Bot: {response_1}")
+     print("----------------------")
+
+     # Second input with /no_think
+     user_input_2 = "Then, how many r's in blueberries? /no_think"
+     print(f"User: {user_input_2}")
+     response_2 = chatbot.generate_response(user_input_2)
+     print(f"Bot: {response_2}")
+     print("----------------------")
+
+     # Third input with /think
+     user_input_3 = "Really? /think"
+     print(f"User: {user_input_3}")
+     response_3 = chatbot.generate_response(user_input_3)
+     print(f"Bot: {response_3}")
+ ```
+
+ > [!NOTE]
+ > For API compatibility, when `enable_thinking=True`, regardless of whether the user uses `/think` or `/no_think`, the model will always output a block wrapped in `<think>...</think>`. However, the content inside this block may be empty if thinking is disabled.
+ > When `enable_thinking=False`, the soft switches are not valid. Regardless of any `/think` or `/no_think` tags input by the user, the model will not generate think content and will not include a `<think>...</think>` block.
+
+ ## Agentic Use
+
+ Qwen3 excels in tool-calling capabilities. We recommend using [Qwen-Agent](https://github.com/QwenLM/Qwen-Agent) to make the best use of the agentic ability of Qwen3. Qwen-Agent encapsulates tool-calling templates and tool-calling parsers internally, greatly reducing coding complexity.
+
+ To define the available tools, you can use the MCP configuration file, use the integrated tools of Qwen-Agent, or integrate other tools yourself.
+ ```python
+ from qwen_agent.agents import Assistant
+
+ # Define LLM
+ llm_cfg = {
+     'model': 'Qwen3-0.6B',
+
+     # Use the endpoint provided by Alibaba Model Studio:
+     # 'model_type': 'qwen_dashscope',
+     # 'api_key': os.getenv('DASHSCOPE_API_KEY'),
+
+     # Use a custom endpoint compatible with OpenAI API:
+     'model_server': 'http://localhost:8000/v1',  # api_base
+     'api_key': 'EMPTY',
+
+     # Other parameters:
+     # 'generate_cfg': {
+     #     # Add: when the response content is `<think>this is the thought</think>this is the answer`;
+     #     # Do not add: when the response has been separated by reasoning_content and content.
+     #     'thought_in_content': True,
+     # },
+ }
+
+ # Define Tools
+ tools = [
+     {'mcpServers': {  # You can specify the MCP configuration file
+         'time': {
+             'command': 'uvx',
+             'args': ['mcp-server-time', '--local-timezone=Asia/Shanghai']
+         },
+         "fetch": {
+             "command": "uvx",
+             "args": ["mcp-server-fetch"]
+         }
+     }
+     },
+     'code_interpreter',  # Built-in tools
+ ]
+
+ # Define Agent
+ bot = Assistant(llm=llm_cfg, function_list=tools)
+
+ # Streaming generation
+ messages = [{'role': 'user', 'content': 'https://qwenlm.github.io/blog/ Introduce the latest developments of Qwen'}]
+ for responses in bot.run(messages=messages):
+     pass
+ print(responses)
+ ```
+
+ ## Best Practices
+
+ To achieve optimal performance, we recommend the following settings:
+
+ 1. **Sampling Parameters**:
+    - For thinking mode (`enable_thinking=True`), use `Temperature=0.6`, `TopP=0.95`, `TopK=20`, and `MinP=0`. **DO NOT use greedy decoding**, as it can lead to performance degradation and endless repetitions.
+    - For non-thinking mode (`enable_thinking=False`), we suggest using `Temperature=0.7`, `TopP=0.8`, `TopK=20`, and `MinP=0`.
+    - For supported frameworks, you can adjust the `presence_penalty` parameter between 0 and 2 to reduce endless repetitions. However, using a higher value may occasionally result in language mixing and a slight decrease in model performance.
+
+ 2. **Adequate Output Length**: We recommend using an output length of 32,768 tokens for most queries. For benchmarking on highly complex problems, such as those found in math and programming competitions, we suggest setting the max output length to 38,912 tokens. This provides the model with sufficient space to generate detailed and comprehensive responses, thereby enhancing its overall performance.
+
+ 3. **Standardize Output Format**: We recommend using prompts to standardize model outputs when benchmarking.
+    - **Math Problems**: Include "Please reason step by step, and put your final answer within \boxed{}." in the prompt.
+    - **Multiple-Choice Questions**: Add the following JSON structure to the prompt to standardize responses: "Please show your choice in the `answer` field with only the choice letter, e.g., `"answer": "C"`."
+
+ 4. **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final output part and does not need to include the thinking content. This is implemented in the provided Jinja2 chat template. However, for frameworks that do not directly use the Jinja2 chat template, it is up to the developers to ensure that this best practice is followed; a sketch of stripping the thinking content is shown after this list.
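+
+ The following is a minimal sketch (not part of the official API) of how such a framework might strip the `<think>...</think>` block from an assistant response before appending it to the history; it reuses the conventions of the `QwenChatbot` example above.
+ ```python
+ import re
+
+ def strip_thinking(response: str) -> str:
+     """Remove a leading <think>...</think> block so it is not kept in history."""
+     return re.sub(r"<think>.*?</think>", "", response, count=1, flags=re.DOTALL).lstrip("\n")
+
+ # Usage: store only the final answer in the conversation history.
+ response = "<think>\nCounting the r's...\n</think>\n\nThere are three r's in \"strawberries\"."
+ history_entry = {"role": "assistant", "content": strip_thinking(response)}
+ print(history_entry)
+ ```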
+
+ ### Citation
+
+ If you find our work helpful, feel free to cite us.
+
+ ```
+ @misc{qwen3technicalreport,
+       title={Qwen3 Technical Report},
+       author={Qwen Team},
+       year={2025},
+       eprint={2505.09388},
+       archivePrefix={arXiv},
+       primaryClass={cs.CL},
+       url={https://arxiv.org/abs/2505.09388},
+ }
+ ```
checkpoints/Qwen3-0.6B/config.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "architectures": [
+     "Qwen3ForCausalLM"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "bos_token_id": 151643,
+   "eos_token_id": 151645,
+   "head_dim": 128,
+   "hidden_act": "silu",
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "max_position_embeddings": 40960,
+   "max_window_layers": 28,
+   "model_type": "qwen3",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 28,
+   "num_key_value_heads": 8,
+   "rms_norm_eps": 1e-06,
+   "rope_scaling": null,
+   "rope_theta": 1000000,
+   "sliding_window": null,
+   "tie_word_embeddings": true,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.51.0",
+   "use_cache": true,
+   "use_sliding_window": false,
+   "vocab_size": 151936,
+   "magel_chord_dropout_trigger_prob": 0.6,
+   "magel_structure_dropout_trigger_prob": 0.6,
+   "magel_num_audio_token": 16384
+ }
checkpoints/Qwen3-0.6B/generation_config.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "bos_token_id": 151643,
+   "do_sample": true,
+   "eos_token_id": [
+     151645,
+     151643
+   ],
+   "pad_token_id": 151643,
+   "temperature": 0.6,
+   "top_k": 20,
+   "top_p": 0.95,
+   "transformers_version": "4.51.0"
+ }
checkpoints/Qwen3-0.6B/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
checkpoints/Qwen3-0.6B/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
+ {
+   "add_bos_token": false,
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "151643": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151644": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151645": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151646": {
+       "content": "<|object_ref_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151647": {
+       "content": "<|object_ref_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151648": {
+       "content": "<|box_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151649": {
+       "content": "<|box_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151650": {
+       "content": "<|quad_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151651": {
+       "content": "<|quad_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151652": {
+       "content": "<|vision_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151653": {
+       "content": "<|vision_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151654": {
+       "content": "<|vision_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151655": {
+       "content": "<|image_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151656": {
+       "content": "<|video_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151657": {
+       "content": "<tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151658": {
+       "content": "</tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151659": {
+       "content": "<|fim_prefix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151660": {
+       "content": "<|fim_middle|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151661": {
+       "content": "<|fim_suffix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151662": {
+       "content": "<|fim_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151663": {
+       "content": "<|repo_name|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151664": {
+       "content": "<|file_sep|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151665": {
+       "content": "<tool_response>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151666": {
+       "content": "</tool_response>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151667": {
+       "content": "<think>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151668": {
+       "content": "</think>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "bos_token": null,
+   "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "errors": "replace",
+   "model_max_length": 131072,
+   "pad_token": "<|endoftext|>",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null
+ }
checkpoints/Qwen3-0.6B/vocab.json ADDED
The diff for this file is too large to render. See raw diff