Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- MuCodec/libs/rvq/__pycache__/descript_quantize3.cpython-312.pyc +0 -0
- MuCodec/models/__pycache__/attention.cpython-310.pyc +0 -0
- MuCodec/models/__pycache__/attention.cpython-312.pyc +0 -0
- MuCodec/models/__pycache__/transformer_2d_flow.cpython-310.pyc +0 -0
- MuCodec/models/__pycache__/transformer_2d_flow.cpython-312.pyc +0 -0
- MuCodec/muq_dev/__pycache__/test.cpython-310.pyc +0 -0
- MuCodec/muq_dev/__pycache__/test.cpython-312.pyc +0 -0
- MuCodec/muq_dev/muq_fairseq/data/__init__.py +1 -0
- MuCodec/muq_dev/muq_fairseq/data/__pycache__/__init__.cpython-310.pyc +0 -0
- MuCodec/muq_dev/muq_fairseq/data/__pycache__/ark_dataset.cpython-310.pyc +0 -0
- MuCodec/muq_dev/muq_fairseq/data/__pycache__/mert_dataset.cpython-310.pyc +0 -0
- MuCodec/muq_dev/muq_fairseq/data/ark_dataset.py +71 -0
- MuCodec/muq_dev/muq_fairseq/data/mert_dataset.py +295 -0
- MuCodec/muq_dev/muq_fairseq/data/utils/data_utils.py +535 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/__init__.py +1 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/__pycache__/__init__.cpython-310.pyc +0 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/__pycache__/muq_model.cpython-310.pyc +0 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/model/__init__.py +2 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/model/__pycache__/__init__.cpython-310.pyc +0 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/model/__pycache__/muq.cpython-310.pyc +0 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/model/__pycache__/rvq.cpython-310.pyc +0 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/model/__pycache__/rvq_muq.cpython-310.pyc +0 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/model/muq.py +520 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/model/pred_ark_target_with_model.py +151 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/model/rvq.py +459 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/model/rvq_muq.py +394 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/model/w2v2_config.json +113 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/modules/__init__.py +2 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/modules/__pycache__/__init__.cpython-310.pyc +0 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/modules/__pycache__/conv.cpython-310.pyc +0 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/modules/__pycache__/features.cpython-310.pyc +0 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/modules/__pycache__/random_quantizer.cpython-310.pyc +0 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/modules/conv.py +77 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/modules/features.py +67 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/modules/flash_conformer.py +2114 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/modules/random_quantizer.py +68 -0
- MuCodec/muq_dev/muq_fairseq/models/muq/muq_model.py +139 -0
- MuCodec/muq_dev/muq_fairseq/tasks/__pycache__/muq_pretraining.cpython-310.pyc +0 -0
- MuCodec/muq_dev/muq_fairseq/tasks/muq_pretraining.py +354 -0
- MuCodec/tools/__pycache__/get_melvaehifigan48k.cpython-310.pyc +0 -0
- MuCodec/tools/__pycache__/torch_tools.cpython-310.pyc +0 -0
- MuCodec/tools/__pycache__/torch_tools.cpython-312.pyc +0 -0
- checkpoints/Qwen3-0.6B/.gitattributes +36 -0
- checkpoints/Qwen3-0.6B/LICENSE +202 -0
- checkpoints/Qwen3-0.6B/README.md +301 -0
- checkpoints/Qwen3-0.6B/config.json +33 -0
- checkpoints/Qwen3-0.6B/generation_config.json +13 -0
- checkpoints/Qwen3-0.6B/merges.txt +0 -0
- checkpoints/Qwen3-0.6B/tokenizer_config.json +239 -0
- checkpoints/Qwen3-0.6B/vocab.json +0 -0
MuCodec/libs/rvq/__pycache__/descript_quantize3.cpython-312.pyc
ADDED
|
Binary file (16.1 kB). View file
|
|
|
MuCodec/models/__pycache__/attention.cpython-310.pyc
ADDED
|
Binary file (16.3 kB). View file
|
|
|
MuCodec/models/__pycache__/attention.cpython-312.pyc
ADDED
|
Binary file (25.6 kB). View file
|
|
|
MuCodec/models/__pycache__/transformer_2d_flow.cpython-310.pyc
ADDED
|
Binary file (17.9 kB). View file
|
|
|
MuCodec/models/__pycache__/transformer_2d_flow.cpython-312.pyc
ADDED
|
Binary file (26.9 kB). View file
|
|
|
MuCodec/muq_dev/__pycache__/test.cpython-310.pyc
ADDED
|
Binary file (866 Bytes). View file
|
|
|
MuCodec/muq_dev/__pycache__/test.cpython-312.pyc
ADDED
|
Binary file (1.19 kB). View file
|
|
|
MuCodec/muq_dev/muq_fairseq/data/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .mert_dataset import MERTDataset
|
MuCodec/muq_dev/muq_fairseq/data/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (219 Bytes). View file
|
|
|
MuCodec/muq_dev/muq_fairseq/data/__pycache__/ark_dataset.cpython-310.pyc
ADDED
|
Binary file (2.35 kB). View file
|
|
|
MuCodec/muq_dev/muq_fairseq/data/__pycache__/mert_dataset.cpython-310.pyc
ADDED
|
Binary file (9.85 kB). View file
|
|
|
MuCodec/muq_dev/muq_fairseq/data/ark_dataset.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from fairseq.data.audio.raw_audio_dataset import RawAudioDataset
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
try:
|
| 7 |
+
import kaldiio
|
| 8 |
+
except:
|
| 9 |
+
kaldiio = None
|
| 10 |
+
import warnings
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ArkDataset(RawAudioDataset):
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
wav_scp,
|
| 19 |
+
dur_scp,
|
| 20 |
+
sr = 24000,
|
| 21 |
+
max_dur = 20,
|
| 22 |
+
num_buckets=0,
|
| 23 |
+
normalize=False,
|
| 24 |
+
):
|
| 25 |
+
super().__init__(
|
| 26 |
+
sample_rate=sr,
|
| 27 |
+
max_sample_size=max_dur*sr,
|
| 28 |
+
min_sample_size=1200,
|
| 29 |
+
shuffle=True,
|
| 30 |
+
pad=True,
|
| 31 |
+
normalize=normalize,
|
| 32 |
+
compute_mask=False,
|
| 33 |
+
)
|
| 34 |
+
self.sr = sr
|
| 35 |
+
self.max_dur = max_dur
|
| 36 |
+
self.normalize = normalize
|
| 37 |
+
|
| 38 |
+
logger.info("Loading Kaldi scp files from {}".format(wav_scp))
|
| 39 |
+
|
| 40 |
+
self.wav_data = kaldiio.load_scp(wav_scp)
|
| 41 |
+
self.keys = list(self.wav_data.keys())
|
| 42 |
+
dur_data = {}
|
| 43 |
+
keys_set = set(self.keys)
|
| 44 |
+
|
| 45 |
+
with open(dur_scp, 'r') as f:
|
| 46 |
+
for line in f:
|
| 47 |
+
line = line.strip().split()
|
| 48 |
+
if line[0] in keys_set:
|
| 49 |
+
dur_data[line[0]] = float(line[-1])
|
| 50 |
+
self.sizes = [int(dur_data[k]*self.sr/100) for k in self.keys]
|
| 51 |
+
|
| 52 |
+
logger.info("Loading Kaldi scp files done")
|
| 53 |
+
|
| 54 |
+
self.dataset_len = len(self.keys)
|
| 55 |
+
self.set_bucket_info(num_buckets)
|
| 56 |
+
|
| 57 |
+
def __len__(self):
|
| 58 |
+
return self.dataset_len
|
| 59 |
+
|
| 60 |
+
def __getitem__(self, idx):
|
| 61 |
+
pass
|
| 62 |
+
|
| 63 |
+
def size(self, idx):
|
| 64 |
+
pass
|
| 65 |
+
|
| 66 |
+
def postprocess(self, wav):
|
| 67 |
+
pass
|
| 68 |
+
|
| 69 |
+
def collater(self, samples):
|
| 70 |
+
pass
|
| 71 |
+
|
MuCodec/muq_dev/muq_fairseq/data/mert_dataset.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import itertools
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
from typing import Any, List, Optional, Union
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
from typing import Tuple
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from fairseq.data import data_utils
|
| 17 |
+
from fairseq.data.fairseq_dataset import FairseqDataset
|
| 18 |
+
from fairseq.data.audio.audio_utils import (
|
| 19 |
+
parse_path,
|
| 20 |
+
read_from_stored_zip,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
import math
|
| 24 |
+
import io
|
| 25 |
+
import torchaudio
|
| 26 |
+
# this is in the user_dir
|
| 27 |
+
from nnAudio import features as nnAudioFeatures
|
| 28 |
+
|
| 29 |
+
# from tqdm import tqdm
|
| 30 |
+
import tqdm
|
| 31 |
+
import json
|
| 32 |
+
import random
|
| 33 |
+
import traceback
|
| 34 |
+
from einops import rearrange
|
| 35 |
+
# from scripts.prepare_codecs_from_manifest import *
|
| 36 |
+
|
| 37 |
+
logger = logging.getLogger(__name__)
|
| 38 |
+
|
| 39 |
+
class model_cqt_pred(torch.nn.Module):
|
| 40 |
+
def __init__(self, n_bins=84, sr=16000, freq=50):
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.epsilon=1e-10
|
| 43 |
+
# Getting Mel Spectrogram on the fly
|
| 44 |
+
self.spec_layer = nnAudioFeatures.cqt.CQT(sr=sr, hop_length=sr//freq, fmin=32.7,
|
| 45 |
+
fmax=None, n_bins=n_bins, bins_per_octave=n_bins//7,
|
| 46 |
+
filter_scale=1, norm=1, window='hann', center=True,
|
| 47 |
+
pad_mode='constant', trainable=False,
|
| 48 |
+
output_format='Magnitude', verbose=True)
|
| 49 |
+
|
| 50 |
+
# self.fc = nn.Linear(input_dim, n_bins)
|
| 51 |
+
|
| 52 |
+
# self.criterion = nn.MSELoss()
|
| 53 |
+
self.forward_dict = {
|
| 54 |
+
# 'masked_transformer_output': self.plain_forward
|
| 55 |
+
'compute_cqt': self.compute_cqt
|
| 56 |
+
}
|
| 57 |
+
def compute_cqt(self, x):
|
| 58 |
+
'''
|
| 59 |
+
convert waveform to CQT -> [batch, bins, len] -> transpose
|
| 60 |
+
'''
|
| 61 |
+
# align with the padding of HuBERT model,
|
| 62 |
+
# the truncation is calculated by bruteforce search since the nnAudio padding strategy and fairseq models are different
|
| 63 |
+
# x = x[..., :-560]
|
| 64 |
+
return torch.transpose(self.spec_layer(x), -1, -2)
|
| 65 |
+
|
| 66 |
+
def forward(self, x, forward_type='masked_transformer_output'):
|
| 67 |
+
'''
|
| 68 |
+
take input from transformer hidden states: [batch, len_seq, channel]
|
| 69 |
+
output: [batch, len_seq, n_bins]
|
| 70 |
+
'''
|
| 71 |
+
|
| 72 |
+
return self.forward_dict[forward_type](x)
|
| 73 |
+
|
| 74 |
+
def load_audio_by_json(json_path, max_keep, min_keep, tgt_sample_rate, clip_secs=5):
|
| 75 |
+
# read json file
|
| 76 |
+
print(json_path)
|
| 77 |
+
datas = []
|
| 78 |
+
inds = []
|
| 79 |
+
sizes = []
|
| 80 |
+
with open(json_path) as fp:
|
| 81 |
+
for ind,line in enumerate(fp):
|
| 82 |
+
data = json.loads(line)
|
| 83 |
+
if 'duration' in data and min_keep is not None and tgt_sample_rate*data['duration'] < min_keep:
|
| 84 |
+
continue
|
| 85 |
+
datas.append(data)
|
| 86 |
+
inds.append(ind)
|
| 87 |
+
# sz = int(data['duration'] * data['sample_rate'])
|
| 88 |
+
if clip_secs > 0:
|
| 89 |
+
sz = int(tgt_sample_rate * clip_secs)
|
| 90 |
+
else:
|
| 91 |
+
sz = int(tgt_sample_rate * data['duration'])
|
| 92 |
+
sizes.append(sz)
|
| 93 |
+
tot = ind + 1
|
| 94 |
+
return datas,inds,tot,sizes
|
| 95 |
+
def load_audio(manifest_path, max_keep, min_keep):
|
| 96 |
+
pass
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def load_label(label_path, inds, tot):
|
| 100 |
+
pass
|
| 101 |
+
|
| 102 |
+
def load_numpy_label(label_path, inds, tot):
|
| 103 |
+
labels = np.load(label_path, mmap_mode='r')
|
| 104 |
+
assert (labels.shape[0] == tot), f"number of labels does not match ({labels.shape[0]} != {tot})"
|
| 105 |
+
return labels
|
| 106 |
+
|
| 107 |
+
def verify_label_lengths(
|
| 108 |
+
audio_sizes,
|
| 109 |
+
audio_rate,
|
| 110 |
+
label_path,
|
| 111 |
+
label_rate,
|
| 112 |
+
inds,
|
| 113 |
+
tot,
|
| 114 |
+
tol=0.1, # tolerance in seconds
|
| 115 |
+
):
|
| 116 |
+
pass
|
| 117 |
+
|
| 118 |
+
class Read_and_PadCrop_Normalized_T(torch.nn.Module):
|
| 119 |
+
def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
|
| 120 |
+
|
| 121 |
+
super().__init__()
|
| 122 |
+
|
| 123 |
+
self.n_samples = n_samples
|
| 124 |
+
self.sample_rate = sample_rate
|
| 125 |
+
self.randomize = randomize
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def __call__(self, filename: str, duration: float, cur_sample_rate: int, fixed_offset_duration=None) -> Tuple[torch.Tensor, float, float, int, int]:
|
| 129 |
+
pass
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class MERTDataset(FairseqDataset):
|
| 133 |
+
def __init__(
|
| 134 |
+
self,
|
| 135 |
+
manifest_path: str,
|
| 136 |
+
sample_rate: float,
|
| 137 |
+
label_paths: List[str],
|
| 138 |
+
label_rates: Union[List[float], float], # -1 for sequence labels
|
| 139 |
+
pad_list: List[str],
|
| 140 |
+
eos_list: List[str],
|
| 141 |
+
label_scp_path: Optional[str] = None,
|
| 142 |
+
label_scp_clip_duration: float = -1,
|
| 143 |
+
label_processors: Optional[List[Any]] = None,
|
| 144 |
+
max_keep_sample_size: Optional[int] = None,
|
| 145 |
+
min_keep_sample_size: Optional[int] = None,
|
| 146 |
+
max_sample_size: Optional[int] = None,
|
| 147 |
+
shuffle: bool = True,
|
| 148 |
+
pad_audio: bool = False,
|
| 149 |
+
normalize: bool = False,
|
| 150 |
+
store_labels: bool = True,
|
| 151 |
+
npmemmap: bool = False,
|
| 152 |
+
random_crop: bool = False,
|
| 153 |
+
single_target: bool = False,
|
| 154 |
+
augmentation_effects: List[str] = [],
|
| 155 |
+
augmentation_probs: List[float] = [],
|
| 156 |
+
inbatch_noise_augment_len_range: List[int] = [8000, 24000],
|
| 157 |
+
inbatch_noise_augment_number_range: List[int] = [1, 3],
|
| 158 |
+
inbatch_noise_augment_volume: float = 1.0,
|
| 159 |
+
cqt_prediction_bin: int = -1,
|
| 160 |
+
dataset_len:int = 128*3000,
|
| 161 |
+
clip_secs = 5,
|
| 162 |
+
):
|
| 163 |
+
self.sample_rate = sample_rate
|
| 164 |
+
self.shuffle = shuffle
|
| 165 |
+
self.random_crop = random_crop
|
| 166 |
+
self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path,max_keep_sample_size,min_keep_sample_size, self.sample_rate, clip_secs)
|
| 167 |
+
self.inds = inds
|
| 168 |
+
|
| 169 |
+
self.num_labels = len(label_paths)
|
| 170 |
+
self.pad_list = pad_list
|
| 171 |
+
self.eos_list = eos_list
|
| 172 |
+
self.label_processors = label_processors
|
| 173 |
+
self.single_target = single_target
|
| 174 |
+
self.label_rates = (
|
| 175 |
+
[label_rates for _ in range(len(label_paths))]
|
| 176 |
+
if isinstance(label_rates, float)
|
| 177 |
+
else label_rates
|
| 178 |
+
)
|
| 179 |
+
self.store_labels = store_labels
|
| 180 |
+
self.npmemmap = npmemmap
|
| 181 |
+
self.label_scp_path = label_scp_path
|
| 182 |
+
self.label_scp_clip_duration = label_scp_clip_duration
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
if self.label_scp_path is not None:
|
| 186 |
+
from kaldiio import load_scp
|
| 187 |
+
self.label_scp = load_scp(self.label_scp_path)
|
| 188 |
+
|
| 189 |
+
# self.dataset_len = dataset_len
|
| 190 |
+
self.dataset_len = len(self.datas)
|
| 191 |
+
logger.info('preparing labels')
|
| 192 |
+
logger.info('========dataset len: {}=========='.format(self.dataset_len))
|
| 193 |
+
if store_labels:
|
| 194 |
+
if self.npmemmap:
|
| 195 |
+
self.label_list = [load_numpy_label(p+'.npy', inds, tot) for p in label_paths]
|
| 196 |
+
else:
|
| 197 |
+
self.label_list = [load_label(p, inds, tot) for p in label_paths]
|
| 198 |
+
else:
|
| 199 |
+
self.label_paths = label_paths
|
| 200 |
+
# self.label_offsets_list = [
|
| 201 |
+
# load_label_offset(p, inds, tot) for p in label_paths
|
| 202 |
+
# ]
|
| 203 |
+
assert label_processors is None or len(label_processors) == self.num_labels
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
self.max_sample_size = (
|
| 207 |
+
max_sample_size if max_sample_size is not None else sys.maxsize
|
| 208 |
+
)
|
| 209 |
+
self.pad_audio = pad_audio
|
| 210 |
+
self.normalize = normalize
|
| 211 |
+
logger.info(
|
| 212 |
+
f"pad_audio={pad_audio}, random_crop={random_crop}, "
|
| 213 |
+
f"normalize={normalize}, max_sample_size={self.max_sample_size}"
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
self.augmentation_effects = augmentation_effects
|
| 217 |
+
self.augmentation_probs = augmentation_probs
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
self.inbatch_noise_augment_len_range = inbatch_noise_augment_len_range
|
| 221 |
+
self.inbatch_noise_augment_number_range = inbatch_noise_augment_number_range
|
| 222 |
+
self.inbatch_noise_augment_volume = inbatch_noise_augment_volume
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
self.cqt_prediction_bin = cqt_prediction_bin
|
| 226 |
+
if self.cqt_prediction_bin > 0:
|
| 227 |
+
self.encoder_cqt_model = model_cqt_pred(n_bins=self.cqt_prediction_bin)
|
| 228 |
+
logger.info('preparing cqt loss objective in dataloader with cpu')
|
| 229 |
+
|
| 230 |
+
self.epoch = -1
|
| 231 |
+
|
| 232 |
+
self.reader = Read_and_PadCrop_Normalized_T(n_samples=clip_secs*sample_rate if clip_secs>0 else None, sample_rate = self.sample_rate)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
@property
|
| 237 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
| 238 |
+
pass
|
| 239 |
+
def set_epoch(self, epoch):
|
| 240 |
+
pass
|
| 241 |
+
|
| 242 |
+
def inbatch_noise_augment(self,
|
| 243 |
+
target_audio: torch.Tensor, target_audio_idx: int ,
|
| 244 |
+
batch_audios: torch.Tensor, # [bsz, audio_lengths]
|
| 245 |
+
noise_len_min: int, noise_len_max: int,
|
| 246 |
+
n_noise_min: int, n_noise_max: int,
|
| 247 |
+
noise_vol: float = 1.0):
|
| 248 |
+
pass
|
| 249 |
+
|
| 250 |
+
def get_audio_by_slice(self,index):
|
| 251 |
+
pass
|
| 252 |
+
def get_audio(self, index):
|
| 253 |
+
pass
|
| 254 |
+
|
| 255 |
+
def get_label(self, index, label_idx):
|
| 256 |
+
pass
|
| 257 |
+
|
| 258 |
+
def get_labels(self, index):
|
| 259 |
+
pass
|
| 260 |
+
|
| 261 |
+
def __getitem__(self, i):
|
| 262 |
+
pass
|
| 263 |
+
|
| 264 |
+
def __len__(self):
|
| 265 |
+
return self.dataset_len
|
| 266 |
+
|
| 267 |
+
def crop_to_max_size(self, wav, target_size):
|
| 268 |
+
pass
|
| 269 |
+
|
| 270 |
+
def collater(self, samples):
|
| 271 |
+
pass
|
| 272 |
+
|
| 273 |
+
def collater_audio(self, audios, audio_size):
|
| 274 |
+
pass
|
| 275 |
+
|
| 276 |
+
def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
|
| 277 |
+
pass
|
| 278 |
+
|
| 279 |
+
def collater_seq_label(self, targets, pad):
|
| 280 |
+
pass
|
| 281 |
+
|
| 282 |
+
def collater_label(self, targets_by_label, audio_size, audio_starts):
|
| 283 |
+
pass
|
| 284 |
+
|
| 285 |
+
def num_tokens(self, index):
|
| 286 |
+
pass
|
| 287 |
+
|
| 288 |
+
def size(self, index):
|
| 289 |
+
pass
|
| 290 |
+
|
| 291 |
+
def ordered_indices(self):
|
| 292 |
+
pass
|
| 293 |
+
|
| 294 |
+
def postprocess(self, wav, cur_sample_rate):
|
| 295 |
+
pass
|
MuCodec/muq_dev/muq_fairseq/data/utils/data_utils.py
ADDED
|
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import math
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from typing import Optional, Tuple
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def compute_mask_indices(
|
| 20 |
+
shape: Tuple[int, int],
|
| 21 |
+
padding_mask: Optional[torch.Tensor],
|
| 22 |
+
mask_prob: float,
|
| 23 |
+
mask_length: int,
|
| 24 |
+
mask_type: str = "static",
|
| 25 |
+
mask_other: float = 0.0,
|
| 26 |
+
min_masks: int = 0,
|
| 27 |
+
no_overlap: bool = False,
|
| 28 |
+
min_space: int = 0,
|
| 29 |
+
require_same_masks: bool = True,
|
| 30 |
+
mask_dropout: float = 0.0,
|
| 31 |
+
add_masks: bool = False,
|
| 32 |
+
seed: Optional[int] = None,
|
| 33 |
+
epoch: Optional[int] = None,
|
| 34 |
+
indices: Optional[torch.Tensor] = None,
|
| 35 |
+
idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
|
| 36 |
+
num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
|
| 37 |
+
) -> np.ndarray:
|
| 38 |
+
"""
|
| 39 |
+
Computes random mask spans for a given shape
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
shape: the the shape for which to compute masks.
|
| 43 |
+
should be of size 2 where first element is batch size and 2nd is timesteps
|
| 44 |
+
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
| 45 |
+
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
| 46 |
+
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
| 47 |
+
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
| 48 |
+
mask_type: how to compute mask lengths
|
| 49 |
+
static = fixed size
|
| 50 |
+
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
| 51 |
+
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
| 52 |
+
poisson = sample from possion distribution with lambda = mask length
|
| 53 |
+
min_masks: minimum number of masked spans
|
| 54 |
+
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
| 55 |
+
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
| 56 |
+
require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
|
| 57 |
+
mask_dropout: randomly dropout this percentage of masks in each example
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
bsz, all_sz = shape
|
| 61 |
+
mask = np.full((bsz, all_sz), False)
|
| 62 |
+
|
| 63 |
+
if num_mask_ver == 1:
|
| 64 |
+
all_num_mask = int(
|
| 65 |
+
# add a random number for probabilistic rounding
|
| 66 |
+
mask_prob * all_sz / float(mask_length)
|
| 67 |
+
+ np.random.rand()
|
| 68 |
+
)
|
| 69 |
+
all_num_mask = max(min_masks, all_num_mask)
|
| 70 |
+
|
| 71 |
+
mask_idcs = []
|
| 72 |
+
for i in range(bsz):
|
| 73 |
+
if seed is not None and epoch is not None and indices is not None:
|
| 74 |
+
seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
|
| 75 |
+
else:
|
| 76 |
+
seed_i = None
|
| 77 |
+
|
| 78 |
+
rng = np.random.default_rng(seed_i)
|
| 79 |
+
|
| 80 |
+
if padding_mask is not None:
|
| 81 |
+
sz = all_sz - padding_mask[i].long().sum().item()
|
| 82 |
+
assert sz >= 0, sz
|
| 83 |
+
else:
|
| 84 |
+
sz = all_sz
|
| 85 |
+
|
| 86 |
+
if num_mask_ver == 1:
|
| 87 |
+
if padding_mask is not None:
|
| 88 |
+
num_mask = int(
|
| 89 |
+
# add a random number for probabilistic rounding
|
| 90 |
+
mask_prob * sz / float(mask_length)
|
| 91 |
+
+ np.random.rand()
|
| 92 |
+
)
|
| 93 |
+
num_mask = max(min_masks, num_mask)
|
| 94 |
+
else:
|
| 95 |
+
num_mask = all_num_mask
|
| 96 |
+
elif num_mask_ver == 2:
|
| 97 |
+
num_mask = int(
|
| 98 |
+
# add a random number for probabilistic rounding
|
| 99 |
+
mask_prob * sz / float(mask_length)
|
| 100 |
+
+ rng.random()
|
| 101 |
+
)
|
| 102 |
+
num_mask = max(min_masks, num_mask)
|
| 103 |
+
else:
|
| 104 |
+
raise ValueError()
|
| 105 |
+
|
| 106 |
+
if mask_type == "static":
|
| 107 |
+
lengths = np.full(num_mask, mask_length)
|
| 108 |
+
elif mask_type == "uniform":
|
| 109 |
+
lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
| 110 |
+
elif mask_type == "normal":
|
| 111 |
+
lengths = rng.normal(mask_length, mask_other, size=num_mask)
|
| 112 |
+
lengths = [max(1, int(round(x))) for x in lengths]
|
| 113 |
+
elif mask_type == "poisson":
|
| 114 |
+
lengths = rng.poisson(mask_length, size=num_mask)
|
| 115 |
+
lengths = [int(round(x)) for x in lengths]
|
| 116 |
+
else:
|
| 117 |
+
raise Exception("unknown mask selection " + mask_type)
|
| 118 |
+
|
| 119 |
+
if sum(lengths) == 0:
|
| 120 |
+
if mask_type == "static":
|
| 121 |
+
raise ValueError(f"this should never happens")
|
| 122 |
+
else:
|
| 123 |
+
lengths = [min(mask_length, sz - 1)]
|
| 124 |
+
|
| 125 |
+
if no_overlap:
|
| 126 |
+
mask_idc = []
|
| 127 |
+
|
| 128 |
+
def arrange(s, e, length, keep_length):
|
| 129 |
+
span_start = rng.randint(s, e - length)
|
| 130 |
+
mask_idc.extend(span_start + i for i in range(length))
|
| 131 |
+
|
| 132 |
+
new_parts = []
|
| 133 |
+
if span_start - s - min_space >= keep_length:
|
| 134 |
+
new_parts.append((s, span_start - min_space + 1))
|
| 135 |
+
if e - span_start - length - min_space > keep_length:
|
| 136 |
+
new_parts.append((span_start + length + min_space, e))
|
| 137 |
+
return new_parts
|
| 138 |
+
|
| 139 |
+
parts = [(0, sz)]
|
| 140 |
+
min_length = min(lengths)
|
| 141 |
+
for length in sorted(lengths, reverse=True):
|
| 142 |
+
lens = np.fromiter(
|
| 143 |
+
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
| 144 |
+
np.int,
|
| 145 |
+
)
|
| 146 |
+
l_sum = np.sum(lens)
|
| 147 |
+
if l_sum == 0:
|
| 148 |
+
break
|
| 149 |
+
probs = lens / np.sum(lens)
|
| 150 |
+
c = rng.choice(len(parts), p=probs)
|
| 151 |
+
s, e = parts.pop(c)
|
| 152 |
+
parts.extend(arrange(s, e, length, min_length))
|
| 153 |
+
mask_idc = np.asarray(mask_idc)
|
| 154 |
+
else:
|
| 155 |
+
if idc_select_ver == 1:
|
| 156 |
+
min_len = min(lengths)
|
| 157 |
+
if sz - min_len <= num_mask:
|
| 158 |
+
min_len = sz - num_mask - 1
|
| 159 |
+
mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
|
| 160 |
+
elif idc_select_ver == 2:
|
| 161 |
+
mask_idc = rng.choice(sz, num_mask, replace=False)
|
| 162 |
+
else:
|
| 163 |
+
raise ValueError()
|
| 164 |
+
|
| 165 |
+
mask_idc = np.asarray(
|
| 166 |
+
[
|
| 167 |
+
mask_idc[j] + offset
|
| 168 |
+
for j in range(len(mask_idc))
|
| 169 |
+
for offset in range(lengths[j])
|
| 170 |
+
]
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
mask_idc = np.unique(mask_idc[mask_idc < sz])
|
| 174 |
+
if len(mask_idc) >= sz:
|
| 175 |
+
raise ValueError(
|
| 176 |
+
(
|
| 177 |
+
f"the entire sequence is masked. "
|
| 178 |
+
f"sz={sz}; mask_idc[mask_idc]; "
|
| 179 |
+
f"index={indices[i] if indices is not None else None}"
|
| 180 |
+
)
|
| 181 |
+
)
|
| 182 |
+
mask_idcs.append(mask_idc)
|
| 183 |
+
|
| 184 |
+
target_len = None
|
| 185 |
+
if require_same_masks:
|
| 186 |
+
if add_masks:
|
| 187 |
+
target_len = max([len(m) for m in mask_idcs])
|
| 188 |
+
else:
|
| 189 |
+
target_len = min([len(m) for m in mask_idcs])
|
| 190 |
+
|
| 191 |
+
for i, mask_idc in enumerate(mask_idcs):
|
| 192 |
+
if target_len is not None and len(mask_idc) > target_len:
|
| 193 |
+
mask_idc = rng.choice(mask_idc, target_len, replace=False)
|
| 194 |
+
|
| 195 |
+
mask[i, mask_idc] = True
|
| 196 |
+
|
| 197 |
+
if target_len is not None and len(mask_idc) < target_len:
|
| 198 |
+
unmasked = np.flatnonzero(~mask[i])
|
| 199 |
+
to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
|
| 200 |
+
mask[i, to_mask] = True
|
| 201 |
+
|
| 202 |
+
if mask_dropout > 0:
|
| 203 |
+
masked = np.flatnonzero(mask[i])
|
| 204 |
+
num_holes = np.rint(len(masked) * mask_dropout).astype(int)
|
| 205 |
+
to_drop = rng.choice(masked, num_holes, replace=False)
|
| 206 |
+
mask[i, to_drop] = False
|
| 207 |
+
|
| 208 |
+
return mask
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def compute_block_mask_2d(
|
| 212 |
+
shape: Tuple[int, int],
|
| 213 |
+
mask_prob: float,
|
| 214 |
+
mask_length: int,
|
| 215 |
+
mask_prob_adjust: float = 0,
|
| 216 |
+
inverse_mask: bool = False,
|
| 217 |
+
require_same_masks: bool = True,
|
| 218 |
+
expand_adjcent: bool = False,
|
| 219 |
+
mask_dropout: float = 0,
|
| 220 |
+
non_overlapping: bool = False,
|
| 221 |
+
img_shape: tuple = None, # For the situation when d[0] != d[1], especially in audio spce ways
|
| 222 |
+
flexible_mask: bool = False,
|
| 223 |
+
) -> torch.Tensor:
|
| 224 |
+
|
| 225 |
+
assert mask_length > 1
|
| 226 |
+
|
| 227 |
+
B, L = shape
|
| 228 |
+
|
| 229 |
+
d = (int(L**0.5),int(L**0.5))
|
| 230 |
+
|
| 231 |
+
if img_shape:
|
| 232 |
+
d = (img_shape[0],img_shape[1])
|
| 233 |
+
|
| 234 |
+
if flexible_mask:
|
| 235 |
+
index = np.random.randint(0,3)
|
| 236 |
+
block_size_options = np.array([(6, 4), (5, 5), (8, 3)])
|
| 237 |
+
block_size = block_size_options[index]
|
| 238 |
+
|
| 239 |
+
if inverse_mask:
|
| 240 |
+
mask_prob = 1 - mask_prob
|
| 241 |
+
|
| 242 |
+
if flexible_mask:
|
| 243 |
+
mask = torch.zeros((B, d[0], d[1]))
|
| 244 |
+
mask_inds = torch.randint(
|
| 245 |
+
0,
|
| 246 |
+
L,
|
| 247 |
+
size=(
|
| 248 |
+
B,
|
| 249 |
+
int(
|
| 250 |
+
L
|
| 251 |
+
* ((mask_prob + mask_prob_adjust) / (block_size[0]*block_size[1]))
|
| 252 |
+
* (1 + mask_dropout)
|
| 253 |
+
),
|
| 254 |
+
),
|
| 255 |
+
)
|
| 256 |
+
mask.view(B, -1).scatter_(1, mask_inds, 1)
|
| 257 |
+
centers = mask.nonzero(as_tuple=True)
|
| 258 |
+
|
| 259 |
+
inds = ([], [], [])
|
| 260 |
+
|
| 261 |
+
offset = mask_length // 2
|
| 262 |
+
for i in range(block_size[0]):
|
| 263 |
+
for j in range(block_size[1]):
|
| 264 |
+
k1 = i - offset
|
| 265 |
+
k2 = j - offset
|
| 266 |
+
inds[0].append(centers[0])
|
| 267 |
+
inds[1].append(centers[1] + k1)
|
| 268 |
+
inds[2].append(centers[2] + k2)
|
| 269 |
+
|
| 270 |
+
i0 = torch.cat(inds[0])
|
| 271 |
+
i1 = torch.cat(inds[1]).clamp_(min=0, max=d[0] - 1)
|
| 272 |
+
i2 = torch.cat(inds[2]).clamp_(min=0, max=d[1] - 1)
|
| 273 |
+
|
| 274 |
+
mask[(i0, i1, i2)] = 1
|
| 275 |
+
|
| 276 |
+
elif non_overlapping:
|
| 277 |
+
sz = math.ceil(d[0] / mask_length)
|
| 278 |
+
inp_len = sz * sz
|
| 279 |
+
|
| 280 |
+
inp = torch.zeros((B, 1, sz, sz))
|
| 281 |
+
w = torch.ones((1, 1, mask_length, mask_length))
|
| 282 |
+
|
| 283 |
+
mask_inds = torch.multinomial(
|
| 284 |
+
1 - inp.view(B, -1),
|
| 285 |
+
int(inp_len * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
|
| 286 |
+
replacement=False,
|
| 287 |
+
)
|
| 288 |
+
inp.view(B, -1).scatter_(1, mask_inds, 1)
|
| 289 |
+
|
| 290 |
+
mask = torch.nn.functional.conv_transpose2d(inp, w, stride=mask_length).squeeze(
|
| 291 |
+
1
|
| 292 |
+
)
|
| 293 |
+
if mask.size(-1) > d[0]:
|
| 294 |
+
mask = mask[..., :d, :d]
|
| 295 |
+
else:
|
| 296 |
+
mask = torch.zeros((B, d[0], d[1]))
|
| 297 |
+
mask_inds = torch.randint(
|
| 298 |
+
0,
|
| 299 |
+
L,
|
| 300 |
+
size=(
|
| 301 |
+
B,
|
| 302 |
+
int(
|
| 303 |
+
L
|
| 304 |
+
* ((mask_prob + mask_prob_adjust) / mask_length**2)
|
| 305 |
+
* (1 + mask_dropout)
|
| 306 |
+
),
|
| 307 |
+
),
|
| 308 |
+
)
|
| 309 |
+
mask.view(B, -1).scatter_(1, mask_inds, 1)
|
| 310 |
+
centers = mask.nonzero(as_tuple=True)
|
| 311 |
+
|
| 312 |
+
inds = ([], [], [])
|
| 313 |
+
|
| 314 |
+
offset = mask_length // 2
|
| 315 |
+
for i in range(mask_length):
|
| 316 |
+
for j in range(mask_length):
|
| 317 |
+
k1 = i - offset
|
| 318 |
+
k2 = j - offset
|
| 319 |
+
inds[0].append(centers[0])
|
| 320 |
+
inds[1].append(centers[1] + k1)
|
| 321 |
+
inds[2].append(centers[2] + k2)
|
| 322 |
+
|
| 323 |
+
i0 = torch.cat(inds[0])
|
| 324 |
+
i1 = torch.cat(inds[1]).clamp_(min=0, max=d[0] - 1)
|
| 325 |
+
i2 = torch.cat(inds[2]).clamp_(min=0, max=d[1] - 1)
|
| 326 |
+
|
| 327 |
+
mask[(i0, i1, i2)] = 1
|
| 328 |
+
|
| 329 |
+
def get_nbs(b, m, w):
|
| 330 |
+
all_nbs = torch.nn.functional.conv2d(m.unsqueeze(1), w, padding="same")
|
| 331 |
+
all_nbs = all_nbs.clamp_max_(1).view(b, -1)
|
| 332 |
+
return all_nbs
|
| 333 |
+
|
| 334 |
+
if require_same_masks and expand_adjcent:
|
| 335 |
+
w = torch.zeros((1, 1, 3, 3))
|
| 336 |
+
w[..., 0, 1] = 1
|
| 337 |
+
w[..., 2, 1] = 1
|
| 338 |
+
w[..., 1, 0] = 1
|
| 339 |
+
w[..., 1, 2] = 1
|
| 340 |
+
|
| 341 |
+
all_nbs = get_nbs(B, mask, w)
|
| 342 |
+
|
| 343 |
+
mask = mask.reshape(B, -1)
|
| 344 |
+
|
| 345 |
+
if require_same_masks:
|
| 346 |
+
n_masks = mask.sum(dim=-1)
|
| 347 |
+
final_target_len = int(L * (mask_prob))
|
| 348 |
+
target_len = int(final_target_len * (1 + mask_dropout))
|
| 349 |
+
|
| 350 |
+
for i in range(len(mask)):
|
| 351 |
+
n = n_masks[i]
|
| 352 |
+
m = mask[i]
|
| 353 |
+
r = 0
|
| 354 |
+
while expand_adjcent and n < target_len:
|
| 355 |
+
if r == 0:
|
| 356 |
+
nbs = all_nbs[i]
|
| 357 |
+
else:
|
| 358 |
+
nbs = get_nbs(1, m.view(1, d[0], d[1]), w).flatten()
|
| 359 |
+
|
| 360 |
+
cands = (1 - m + nbs) > 1
|
| 361 |
+
cand_sz = int(cands.sum().item())
|
| 362 |
+
|
| 363 |
+
assert cand_sz > 0, f"{nbs} {cand_sz}"
|
| 364 |
+
|
| 365 |
+
to_mask = torch.multinomial(
|
| 366 |
+
cands.float(), min(cand_sz, int(target_len - n)), replacement=False
|
| 367 |
+
)
|
| 368 |
+
m[to_mask] = 1
|
| 369 |
+
assert to_mask.numel() > 0
|
| 370 |
+
n += to_mask.numel()
|
| 371 |
+
r += 1
|
| 372 |
+
|
| 373 |
+
if n > final_target_len:
|
| 374 |
+
to_unmask = torch.multinomial(
|
| 375 |
+
m, int(n - final_target_len), replacement=False
|
| 376 |
+
)
|
| 377 |
+
m[to_unmask] = 0
|
| 378 |
+
elif n < final_target_len:
|
| 379 |
+
to_mask = torch.multinomial(
|
| 380 |
+
(1 - m), int(final_target_len - n), replacement=False
|
| 381 |
+
)
|
| 382 |
+
m[to_mask] = 1
|
| 383 |
+
|
| 384 |
+
if inverse_mask:
|
| 385 |
+
mask = 1 - mask
|
| 386 |
+
|
| 387 |
+
return mask
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def compute_block_mask_1d(
|
| 391 |
+
shape: Tuple[int, int],
|
| 392 |
+
mask_prob: float,
|
| 393 |
+
mask_length: int,
|
| 394 |
+
mask_prob_adjust: float = 0,
|
| 395 |
+
inverse_mask: bool = False,
|
| 396 |
+
require_same_masks: bool = True,
|
| 397 |
+
expand_adjcent: bool = False,
|
| 398 |
+
mask_dropout: float = 0,
|
| 399 |
+
non_overlapping: bool = False,
|
| 400 |
+
) -> torch.Tensor:
|
| 401 |
+
|
| 402 |
+
B, L = shape
|
| 403 |
+
|
| 404 |
+
if inverse_mask:
|
| 405 |
+
mask_prob = 1 - mask_prob
|
| 406 |
+
|
| 407 |
+
if non_overlapping:
|
| 408 |
+
sz = math.ceil(L / mask_length)
|
| 409 |
+
|
| 410 |
+
inp = torch.zeros((B, 1, sz))
|
| 411 |
+
w = torch.ones((1, 1, mask_length))
|
| 412 |
+
|
| 413 |
+
mask_inds = torch.multinomial(
|
| 414 |
+
1 - inp.view(B, -1),
|
| 415 |
+
int(sz * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
|
| 416 |
+
replacement=False,
|
| 417 |
+
)
|
| 418 |
+
inp.view(B, -1).scatter_(1, mask_inds, 1)
|
| 419 |
+
|
| 420 |
+
mask = torch.nn.functional.conv_transpose1d(inp, w, stride=mask_length).squeeze(
|
| 421 |
+
1
|
| 422 |
+
)
|
| 423 |
+
if mask.size(-1) > L:
|
| 424 |
+
mask = mask[..., :L]
|
| 425 |
+
|
| 426 |
+
else:
|
| 427 |
+
mask = torch.zeros((B, L))
|
| 428 |
+
mask_inds = torch.randint(
|
| 429 |
+
0,
|
| 430 |
+
L,
|
| 431 |
+
size=(
|
| 432 |
+
B,
|
| 433 |
+
int(
|
| 434 |
+
L
|
| 435 |
+
* ((mask_prob + mask_prob_adjust) / mask_length)
|
| 436 |
+
* (1 + mask_dropout)
|
| 437 |
+
),
|
| 438 |
+
),
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
mask.view(B, -1).scatter_(1, mask_inds, 1)
|
| 442 |
+
centers = mask.nonzero(as_tuple=True)
|
| 443 |
+
|
| 444 |
+
inds = ([], [])
|
| 445 |
+
|
| 446 |
+
offset = mask_length // 2
|
| 447 |
+
for i in range(mask_length):
|
| 448 |
+
k1 = i - offset
|
| 449 |
+
inds[0].append(centers[0])
|
| 450 |
+
inds[1].append(centers[1] + k1)
|
| 451 |
+
|
| 452 |
+
i0 = torch.cat(inds[0])
|
| 453 |
+
i1 = torch.cat(inds[1]).clamp_(min=0, max=L - 1)
|
| 454 |
+
|
| 455 |
+
mask[(i0, i1)] = 1
|
| 456 |
+
|
| 457 |
+
def get_nbs(b, m, w):
|
| 458 |
+
all_nbs = torch.nn.functional.conv1d(m.unsqueeze(1), w, padding="same")
|
| 459 |
+
all_nbs = all_nbs.clamp_max_(1).view(b, -1)
|
| 460 |
+
return all_nbs
|
| 461 |
+
|
| 462 |
+
if require_same_masks and expand_adjcent:
|
| 463 |
+
w = torch.ones((1, 1, 3))
|
| 464 |
+
w[..., 1] = 0
|
| 465 |
+
all_nbs = get_nbs(B, mask, w)
|
| 466 |
+
|
| 467 |
+
mask = mask.view(B, -1)
|
| 468 |
+
|
| 469 |
+
if require_same_masks:
|
| 470 |
+
n_masks = mask.sum(dim=-1)
|
| 471 |
+
final_target_len = int(L * (mask_prob))
|
| 472 |
+
target_len = int(final_target_len * (1 + mask_dropout))
|
| 473 |
+
|
| 474 |
+
for i in range(len(mask)):
|
| 475 |
+
n = n_masks[i]
|
| 476 |
+
m = mask[i]
|
| 477 |
+
r = 0
|
| 478 |
+
while expand_adjcent and n < target_len:
|
| 479 |
+
if r == 0:
|
| 480 |
+
nbs = all_nbs[i]
|
| 481 |
+
else:
|
| 482 |
+
nbs = get_nbs(1, m.unsqueeze(0), w).squeeze(0)
|
| 483 |
+
|
| 484 |
+
cands = (1 - m + nbs) > 1
|
| 485 |
+
cand_sz = int(cands.sum().item())
|
| 486 |
+
|
| 487 |
+
assert cand_sz > 0, f"{nbs} {cand_sz}"
|
| 488 |
+
|
| 489 |
+
to_mask = torch.multinomial(
|
| 490 |
+
cands.float(), min(cand_sz, int(target_len - n)), replacement=False
|
| 491 |
+
)
|
| 492 |
+
m[to_mask] = 1
|
| 493 |
+
assert to_mask.numel() > 0
|
| 494 |
+
n += to_mask.numel()
|
| 495 |
+
r += 1
|
| 496 |
+
|
| 497 |
+
if n > final_target_len:
|
| 498 |
+
to_unmask = torch.multinomial(
|
| 499 |
+
m, int(n - final_target_len), replacement=False
|
| 500 |
+
)
|
| 501 |
+
m[to_unmask] = 0
|
| 502 |
+
elif n < final_target_len:
|
| 503 |
+
to_mask = torch.multinomial(
|
| 504 |
+
(1 - m), int(final_target_len - n), replacement=False
|
| 505 |
+
)
|
| 506 |
+
m[to_mask] = 1
|
| 507 |
+
|
| 508 |
+
if inverse_mask:
|
| 509 |
+
mask = 1 - mask
|
| 510 |
+
|
| 511 |
+
return mask
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def get_buckets(sizes, num_buckets):
|
| 515 |
+
buckets = np.unique(
|
| 516 |
+
np.percentile(
|
| 517 |
+
sizes,
|
| 518 |
+
np.linspace(0, 100, num_buckets + 1),
|
| 519 |
+
interpolation="lower",
|
| 520 |
+
)[1:]
|
| 521 |
+
)
|
| 522 |
+
return buckets
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
def get_bucketed_sizes(orig_sizes, buckets):
|
| 526 |
+
sizes = np.copy(orig_sizes)
|
| 527 |
+
assert np.min(sizes) >= 0
|
| 528 |
+
start_val = -1
|
| 529 |
+
for end_val in buckets:
|
| 530 |
+
mask = (sizes > start_val) & (sizes <= end_val)
|
| 531 |
+
sizes[mask] = end_val
|
| 532 |
+
start_val = end_val
|
| 533 |
+
return sizes
|
| 534 |
+
|
| 535 |
+
|
MuCodec/muq_dev/muq_fairseq/models/muq/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .muq_model import *
|
MuCodec/muq_dev/muq_fairseq/models/muq/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (203 Bytes). View file
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/__pycache__/muq_model.cpython-310.pyc
ADDED
|
Binary file (4.96 kB). View file
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/model/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
MuCodec/muq_dev/muq_fairseq/models/muq/model/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (183 Bytes). View file
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/model/__pycache__/muq.cpython-310.pyc
ADDED
|
Binary file (15.8 kB). View file
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/model/__pycache__/rvq.cpython-310.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/model/__pycache__/rvq_muq.cpython-310.pyc
ADDED
|
Binary file (13.1 kB). View file
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/model/muq.py
ADDED
|
@@ -0,0 +1,520 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
import os
|
| 7 |
+
from fairseq.data.data_utils import compute_mask_indices
|
| 8 |
+
from fairseq.models.wav2vec.wav2vec2 import ConvFeatureExtractionModel
|
| 9 |
+
from fairseq.modules import LayerNorm
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from ..modules.random_quantizer import RandomProjectionQuantizer
|
| 13 |
+
from ..modules.features import MelSTFT
|
| 14 |
+
from ..modules.conv import Conv2dSubsampling
|
| 15 |
+
except:
|
| 16 |
+
import sys, os
|
| 17 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
| 18 |
+
from modules.random_quantizer import RandomProjectionQuantizer
|
| 19 |
+
from modules.features import MelSTFT
|
| 20 |
+
from modules.conv import Conv2dSubsampling
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class MuQ(nn.Module):
|
| 24 |
+
"""
|
| 25 |
+
MuQ
|
| 26 |
+
|
| 27 |
+
Input: 128-band mel spectrogram
|
| 28 |
+
Frontend: 2-layer Residual convolution
|
| 29 |
+
Backend: 12-layer Conformer
|
| 30 |
+
Quantizer: a codebook for mel spectrogram
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
num_codebooks=1,
|
| 36 |
+
codebook_dim=16,
|
| 37 |
+
codebook_size=4096,
|
| 38 |
+
features=["melspec_2048"],
|
| 39 |
+
hop_length=240,
|
| 40 |
+
n_mels=128,
|
| 41 |
+
conv_dim=512,
|
| 42 |
+
encoder_dim=1024,
|
| 43 |
+
encoder_depth=12,
|
| 44 |
+
mask_hop=0.4,
|
| 45 |
+
mask_prob=0.6,
|
| 46 |
+
is_flash=False,
|
| 47 |
+
stat_path=None, #"./data/fma_stats.json",
|
| 48 |
+
model_path=None, #"./data/pretrained_fma.pt",
|
| 49 |
+
w2v2_config_path=None, #"facebook/wav2vec2-conformer-rope-large-960h-ft",
|
| 50 |
+
use_rvq_target=False,
|
| 51 |
+
use_vq_target=False,
|
| 52 |
+
rvq_ckpt_path=None,
|
| 53 |
+
recon_loss_ratio=None,
|
| 54 |
+
label_rate=25,
|
| 55 |
+
use_hubert_masking_strategy=False,
|
| 56 |
+
use_hubert_featurizer=False,
|
| 57 |
+
hubert_conv_feature_layers="[(512,10,5)] + [(512,3,2)] * 3 + [(512,3,3)] + [(512,2,2)] * 2",
|
| 58 |
+
use_hubert_nce_loss=False,
|
| 59 |
+
hubert_final_dim=256,
|
| 60 |
+
rvq_n_codebooks=8,
|
| 61 |
+
rvq_multi_layer_num=1,
|
| 62 |
+
use_encodec_target=False,
|
| 63 |
+
):
|
| 64 |
+
super(MuQ, self).__init__()
|
| 65 |
+
|
| 66 |
+
# global variables
|
| 67 |
+
self.hop_length = hop_length
|
| 68 |
+
self.mask_hop = mask_hop
|
| 69 |
+
self.mask_prob = mask_prob
|
| 70 |
+
self.num_codebooks = num_codebooks
|
| 71 |
+
self.codebook_size = codebook_size
|
| 72 |
+
self.features = features
|
| 73 |
+
self.recon_loss_ratio = recon_loss_ratio
|
| 74 |
+
self.n_fold = int(100//label_rate)
|
| 75 |
+
self.label_rate = label_rate
|
| 76 |
+
self.use_hubert_masking_strategy = use_hubert_masking_strategy
|
| 77 |
+
self.use_hubert_featurizer = use_hubert_featurizer
|
| 78 |
+
self.use_hubert_nce_loss = use_hubert_nce_loss
|
| 79 |
+
|
| 80 |
+
# load feature mean / std stats
|
| 81 |
+
import os
|
| 82 |
+
if stat_path is not None and os.path.exists(stat_path):
|
| 83 |
+
with open(stat_path, "r") as f:
|
| 84 |
+
self.stat = json.load(f)
|
| 85 |
+
else:
|
| 86 |
+
# print("No stats file found at `{}`, use default from msd.".format(stat_path))
|
| 87 |
+
self.stat = {"spec_256_cnt": 14394344256, "spec_256_mean": -23.34296658431829, "spec_256_std": 26.189295587132637, "spec_512_cnt": 28677104448, "spec_512_mean": -21.31267396860235, "spec_512_std": 26.52644536245769, "spec_1024_cnt": 57242624832, "spec_1024_mean": -18.852271129208273, "spec_1024_std": 26.443154583585663, "spec_2048_cnt": 114373665600, "spec_2048_mean": -15.638743433896792, "spec_2048_std": 26.115825961611545, "spec_4096_cnt": 228635747136, "spec_4096_mean": -11.715532502794836, "spec_4096_std": 25.763972210234062, "melspec_256_cnt": 14282760192, "melspec_256_mean": -26.962600400166156, "melspec_256_std": 36.13614100912126, "melspec_512_cnt": 14282760192, "melspec_512_mean": -9.108344167718862, "melspec_512_std": 24.71910937988429, "melspec_1024_cnt": 14282760192, "melspec_1024_mean": 0.37302579246531126, "melspec_1024_std": 18.684082325919388, "melspec_2048_cnt": 14282760192, "melspec_2048_mean": 6.768444971712967, "melspec_2048_std": 18.417922652295623, "melspec_4096_cnt": 14282760192, "melspec_4096_mean": 13.617164614990036, "melspec_4096_std": 18.08552130124525, "cqt_cnt": 9373061376, "cqt_mean": 0.46341379757927165, "cqt_std": 0.9543998080910191, "mfcc_256_cnt": 1339008768, "mfcc_256_mean": -11.681755459447485, "mfcc_256_std": 29.183186444668316, "mfcc_512_cnt": 1339008768, "mfcc_512_mean": -2.540581461792183, "mfcc_512_std": 31.93752185832081, "mfcc_1024_cnt": 1339008768, "mfcc_1024_mean": 6.606636263169779, "mfcc_1024_std": 34.151644801729624, "mfcc_2048_cnt": 1339008768, "mfcc_2048_mean": 5.281600844245184, "mfcc_2048_std": 33.12784541220003, "mfcc_4096_cnt": 1339008768, "mfcc_4096_mean": 4.7616569480166095, "mfcc_4096_std": 32.61458906894133, "chromagram_256_cnt": 1339008768, "chromagram_256_mean": 55.15596556703181, "chromagram_256_std": 73.91858278719991, "chromagram_512_cnt": 1339008768, "chromagram_512_mean": 175.73092252759895, "chromagram_512_std": 248.48485148525953, "chromagram_1024_cnt": 1339008768, "chromagram_1024_mean": 589.2947481634608, "chromagram_1024_std": 913.857929063196, "chromagram_2048_cnt": 1339008768, "chromagram_2048_mean": 2062.286388327397, "chromagram_2048_std": 3458.92657915397, "chromagram_4096_cnt": 1339008768, "chromagram_4096_mean": 7673.039107997085, "chromagram_4096_std": 13009.883158267234}
|
| 88 |
+
|
| 89 |
+
# feature extractor
|
| 90 |
+
self.preprocessor_melspec_2048 = MelSTFT(
|
| 91 |
+
n_fft=2048, hop_length=hop_length, is_db=True
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# random quantizer
|
| 95 |
+
self.use_rvq_target = use_rvq_target
|
| 96 |
+
self.use_vq_target = use_vq_target
|
| 97 |
+
self.use_encodec_target = use_encodec_target
|
| 98 |
+
|
| 99 |
+
seed = 142
|
| 100 |
+
if self.use_rvq_like_target:
|
| 101 |
+
if use_rvq_target:
|
| 102 |
+
try:
|
| 103 |
+
from .rvq_muq import ResidualVectorQuantize
|
| 104 |
+
except:
|
| 105 |
+
import sys, os
|
| 106 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 107 |
+
from rvq_muq import ResidualVectorQuantize
|
| 108 |
+
|
| 109 |
+
inp_dim = 128*self.n_fold
|
| 110 |
+
self.rvq = ResidualVectorQuantize(
|
| 111 |
+
input_dim = inp_dim,
|
| 112 |
+
n_codebooks = rvq_n_codebooks,
|
| 113 |
+
codebook_size = 1024,
|
| 114 |
+
codebook_dim = 16,
|
| 115 |
+
quantizer_dropout = 0.0,
|
| 116 |
+
use_multi_layer_num = rvq_multi_layer_num,
|
| 117 |
+
)
|
| 118 |
+
elif use_vq_target:
|
| 119 |
+
try:
|
| 120 |
+
from .rvq_muq import VectorQuantize
|
| 121 |
+
except:
|
| 122 |
+
import sys, os
|
| 123 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 124 |
+
from rvq_muq import VectorQuantize
|
| 125 |
+
|
| 126 |
+
self.rvq = VectorQuantize(
|
| 127 |
+
input_dim = 128*self.n_fold,
|
| 128 |
+
codebook_size = 1024,
|
| 129 |
+
codebook_dim = 8,
|
| 130 |
+
stale_tolerance = 1000,
|
| 131 |
+
mfcc_clustering = False
|
| 132 |
+
)
|
| 133 |
+
elif use_encodec_target:
|
| 134 |
+
from encodec import EncodecModel
|
| 135 |
+
self.rvq = EncodecModel.encodec_model_24khz()
|
| 136 |
+
self.rvq.set_target_bandwidth(6.0)
|
| 137 |
+
for param in self.rvq.parameters():
|
| 138 |
+
param.requires_grad = False
|
| 139 |
+
|
| 140 |
+
import os
|
| 141 |
+
if rvq_ckpt_path is not None and os.path.exists(rvq_ckpt_path):
|
| 142 |
+
state_dict = torch.load(rvq_ckpt_path, map_location="cpu")
|
| 143 |
+
self.rvq.load_state_dict(state_dict)
|
| 144 |
+
else:
|
| 145 |
+
print(f'Checkpoint for rvq `{rvq_ckpt_path}` not found. Using random initialization.')
|
| 146 |
+
else:
|
| 147 |
+
for feature in self.features:
|
| 148 |
+
for i in range(num_codebooks):
|
| 149 |
+
setattr(
|
| 150 |
+
self,
|
| 151 |
+
f"quantizer_{feature}", # _{i}
|
| 152 |
+
RandomProjectionQuantizer(
|
| 153 |
+
n_mels * self.n_fold, codebook_dim, codebook_size, seed=seed + i
|
| 154 |
+
),
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
if use_hubert_masking_strategy:
|
| 158 |
+
self.mask_emb = nn.Parameter(
|
| 159 |
+
torch.FloatTensor(encoder_dim).uniform_()
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
if use_hubert_featurizer:
|
| 163 |
+
feature_enc_layers = eval(hubert_conv_feature_layers) # noqa
|
| 164 |
+
hubert_feat_embed = feature_enc_layers[-1][0]
|
| 165 |
+
self.hubert_feature_extractor = ConvFeatureExtractionModel(
|
| 166 |
+
conv_layers=feature_enc_layers,
|
| 167 |
+
dropout=0.0,
|
| 168 |
+
mode='default', #cfg.extractor_mode,
|
| 169 |
+
conv_bias=False, #cfg.conv_bias,
|
| 170 |
+
)
|
| 171 |
+
self.post_extract_proj = (
|
| 172 |
+
nn.Linear(hubert_feat_embed, encoder_dim)
|
| 173 |
+
if hubert_feat_embed != encoder_dim
|
| 174 |
+
else None
|
| 175 |
+
)
|
| 176 |
+
self.layer_norm = LayerNorm(hubert_feat_embed)
|
| 177 |
+
else:
|
| 178 |
+
# two residual convolution layers + one projection layer
|
| 179 |
+
strides_factory = {
|
| 180 |
+
4: [2, 2],
|
| 181 |
+
2: [2, 1]
|
| 182 |
+
}
|
| 183 |
+
self.conv = Conv2dSubsampling(
|
| 184 |
+
1, conv_dim, encoder_dim, strides=strides_factory.get(self.n_fold), n_bands=n_mels
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
# Conformer
|
| 188 |
+
if is_flash:
|
| 189 |
+
from modules.flash_conformer import (
|
| 190 |
+
Wav2Vec2ConformerEncoder,
|
| 191 |
+
Wav2Vec2ConformerConfig,
|
| 192 |
+
)
|
| 193 |
+
else:
|
| 194 |
+
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
|
| 195 |
+
Wav2Vec2ConformerEncoder,
|
| 196 |
+
Wav2Vec2ConformerConfig,
|
| 197 |
+
)
|
| 198 |
+
import os
|
| 199 |
+
if w2v2_config_path is None or not os.path.exists(w2v2_config_path):
|
| 200 |
+
w2v2_config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "w2v2_config.json")
|
| 201 |
+
print("load w2v2 config from:", w2v2_config_path)
|
| 202 |
+
config = Wav2Vec2ConformerConfig.from_pretrained(
|
| 203 |
+
w2v2_config_path
|
| 204 |
+
)
|
| 205 |
+
config.num_hidden_layers = encoder_depth
|
| 206 |
+
config.hidden_size = encoder_dim
|
| 207 |
+
|
| 208 |
+
self.conformer = Wav2Vec2ConformerEncoder(config)
|
| 209 |
+
|
| 210 |
+
if self.use_hubert_nce_loss:
|
| 211 |
+
self.label_embs_concat = nn.Parameter(
|
| 212 |
+
torch.FloatTensor(codebook_size, hubert_final_dim)
|
| 213 |
+
) # embeddings of codes
|
| 214 |
+
nn.init.uniform_(self.label_embs_concat)
|
| 215 |
+
self.linear = nn.Linear(encoder_dim, hubert_final_dim) # final_proj
|
| 216 |
+
else:
|
| 217 |
+
# projection
|
| 218 |
+
self.linear = nn.Linear(encoder_dim, codebook_size) # N_SubSpec=8
|
| 219 |
+
|
| 220 |
+
# reconstruct melspec
|
| 221 |
+
if self.recon_loss_ratio is not None and self.recon_loss_ratio > 0:
|
| 222 |
+
self.recon_proj = nn.Linear(encoder_dim, n_mels * self.n_fold)
|
| 223 |
+
self.recon_loss = nn.MSELoss()
|
| 224 |
+
|
| 225 |
+
# loss function
|
| 226 |
+
self.loss = nn.CrossEntropyLoss()
|
| 227 |
+
|
| 228 |
+
# cls token (used for sequence classification)
|
| 229 |
+
random.seed(seed)
|
| 230 |
+
self.cls_token = nn.Parameter(torch.randn(encoder_dim))
|
| 231 |
+
|
| 232 |
+
# load model
|
| 233 |
+
if model_path:
|
| 234 |
+
S = torch.load(model_path)["state_dict"]
|
| 235 |
+
SS = {k[6:]: v for k, v in S.items()}
|
| 236 |
+
SS['quantizer_melspec_2048.random_projection'] = SS['quantizer_melspec_2048_0.random_projection']
|
| 237 |
+
SS['quantizer_melspec_2048.codebook'] = SS['quantizer_melspec_2048_0.codebook']
|
| 238 |
+
del SS['quantizer_melspec_2048_0.random_projection']
|
| 239 |
+
del SS['quantizer_melspec_2048_0.codebook']
|
| 240 |
+
unmatch = self.load_state_dict(SS, strict=False)
|
| 241 |
+
if len(unmatch.missing_keys) > 0:
|
| 242 |
+
print(f'Missing keys: {unmatch.missing_keys}')
|
| 243 |
+
|
| 244 |
+
@property
|
| 245 |
+
def use_rvq_like_target(self):
|
| 246 |
+
return self.use_rvq_target or self.use_vq_target or self.use_encodec_target
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def apply_hubert_mask(self, x, padding_mask=None, target_list=None):
|
| 250 |
+
B, T, C = x.shape
|
| 251 |
+
if self.mask_prob > 0:
|
| 252 |
+
mask_length = int(self.mask_hop / (1/self.label_rate))
|
| 253 |
+
mask_indices = compute_mask_indices(
|
| 254 |
+
(B, T),
|
| 255 |
+
padding_mask,
|
| 256 |
+
self.mask_prob,
|
| 257 |
+
mask_length, # self.mask_length,
|
| 258 |
+
"static", #self.mask_selection,
|
| 259 |
+
0, #self.mask_other,
|
| 260 |
+
min_masks=2,
|
| 261 |
+
no_overlap=False, #self.no_mask_overlap,
|
| 262 |
+
min_space=1, #self.mask_min_space,
|
| 263 |
+
)
|
| 264 |
+
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
| 265 |
+
x[mask_indices] = self.mask_emb
|
| 266 |
+
mask_indices = torch.nonzero(mask_indices)
|
| 267 |
+
else:
|
| 268 |
+
mask_indices = None
|
| 269 |
+
|
| 270 |
+
return x, mask_indices
|
| 271 |
+
|
| 272 |
+
def masking(self, x, attention_mask=None):
|
| 273 |
+
"""random masking of 400ms with given probability"""
|
| 274 |
+
if self.use_hubert_masking_strategy:
|
| 275 |
+
return x, None
|
| 276 |
+
mx = x.clone()
|
| 277 |
+
b, t = mx.shape
|
| 278 |
+
len_masking_raw = int(24000 * self.mask_hop) # 9600 = 24000 * 0.4
|
| 279 |
+
len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop) # 10 = 25Hz * 0.4
|
| 280 |
+
|
| 281 |
+
# get random mask indices
|
| 282 |
+
start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
|
| 283 |
+
time_domain_masked_indices = torch.nonzero(
|
| 284 |
+
start_indices.repeat_interleave(len_masking_raw, dim=1)
|
| 285 |
+
)
|
| 286 |
+
token_domain_masked_indices = torch.nonzero(
|
| 287 |
+
start_indices.repeat_interleave(len_masking_token, dim=1)
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
# mask with random values
|
| 291 |
+
masking_noise = (
|
| 292 |
+
torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
|
| 293 |
+
) # 0 mean 0.1 std
|
| 294 |
+
mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)
|
| 295 |
+
|
| 296 |
+
return mx, token_domain_masked_indices
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
@torch.no_grad()
|
| 300 |
+
def preprocessing(self, x, features):
|
| 301 |
+
"""extract classic audio features"""
|
| 302 |
+
# check precision
|
| 303 |
+
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
|
| 304 |
+
precision = 16
|
| 305 |
+
else:
|
| 306 |
+
precision = 32
|
| 307 |
+
|
| 308 |
+
out = {}
|
| 309 |
+
for key in features:
|
| 310 |
+
layer = getattr(self, "preprocessor_%s" % key)
|
| 311 |
+
layer.to(x.device)
|
| 312 |
+
dtype = x.dtype
|
| 313 |
+
out[key] = layer.float()(x.float())[..., :-1]
|
| 314 |
+
if precision == 16:
|
| 315 |
+
out[key] = out[key].half()
|
| 316 |
+
if out[key].dtype != dtype:
|
| 317 |
+
out[key].to(dtype=dtype)
|
| 318 |
+
return out
|
| 319 |
+
|
| 320 |
+
def encoder(self, x, *, attention_mask=None, is_features_only=False):
|
| 321 |
+
"""2-layer conv + w2v-conformer"""
|
| 322 |
+
if not self.use_hubert_featurizer:
|
| 323 |
+
x = self.conv(x) # [3, 128, 3000] -> [3, 750, 1024]
|
| 324 |
+
if self.training and self.use_hubert_masking_strategy and not is_features_only:
|
| 325 |
+
x, mask_indices = self.apply_hubert_mask(x)
|
| 326 |
+
else:
|
| 327 |
+
mask_indices = None
|
| 328 |
+
if attention_mask is None:
|
| 329 |
+
out = self.conformer(x, output_hidden_states=True)
|
| 330 |
+
else:
|
| 331 |
+
attention_mask = attention_mask.bool()
|
| 332 |
+
skip_n = int(attention_mask.size(-1) / x.size(1))
|
| 333 |
+
attention_mask = attention_mask[:, ::skip_n]
|
| 334 |
+
attention_mask = attention_mask[:, :x.size(1)]
|
| 335 |
+
out = self.conformer(x, attention_mask=attention_mask, output_hidden_states=True)
|
| 336 |
+
hidden_emb = out["hidden_states"]
|
| 337 |
+
last_emb = out["last_hidden_state"]
|
| 338 |
+
logits = self.linear(last_emb)
|
| 339 |
+
interval = self.codebook_size
|
| 340 |
+
logits = {
|
| 341 |
+
key: logits[:, :, i * interval : (i + 1) * interval]
|
| 342 |
+
for i, key in enumerate(self.features)
|
| 343 |
+
}
|
| 344 |
+
return logits, hidden_emb, mask_indices
|
| 345 |
+
|
| 346 |
+
@torch.no_grad()
|
| 347 |
+
def normalize(self, x):
|
| 348 |
+
"""normalize the input audio to have zero mean unit variance"""
|
| 349 |
+
for key in x.keys():
|
| 350 |
+
x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key] # {'melspec_2048_cnt': 14282760192, 'melspec_2048_mean': 6.768444971712967}
|
| 351 |
+
return x
|
| 352 |
+
|
| 353 |
+
@torch.no_grad()
|
| 354 |
+
def rearrange(self, x):
|
| 355 |
+
"""rearrange the batch to flatten every 4 steps"""
|
| 356 |
+
for key in x.keys():
|
| 357 |
+
if key == "chromagram":
|
| 358 |
+
x[key] = rearrange(x[key], "b f t -> b t f")
|
| 359 |
+
else:
|
| 360 |
+
x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=self.n_fold)
|
| 361 |
+
return x
|
| 362 |
+
|
| 363 |
+
def get_rvq_codes(self, inp, raw_wav):
|
| 364 |
+
if self.use_rvq_target:
|
| 365 |
+
quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = self.rvq(inp)
|
| 366 |
+
return codes
|
| 367 |
+
if self.use_vq_target:
|
| 368 |
+
quantized_prompt_embeds, commitment_loss, codebook_loss, codes, _ = self.rvq(inp)
|
| 369 |
+
return codes.unsqueeze(1)
|
| 370 |
+
if self.use_encodec_target:
|
| 371 |
+
encoded_frames = self.rvq.encode(raw_wav.unsqueeze(1)) #list, B,[ 8,T ]
|
| 372 |
+
codes = torch.cat([encoded[0].detach() for encoded in encoded_frames], dim=-1)
|
| 373 |
+
if self.label_rate == 25:
|
| 374 |
+
codes = codes[:, :, ::3]
|
| 375 |
+
return codes
|
| 376 |
+
|
| 377 |
+
@torch.no_grad()
|
| 378 |
+
def tokenize(self, x, raw_wav):
|
| 379 |
+
out = {}
|
| 380 |
+
for key in x.keys():
|
| 381 |
+
if self.use_rvq_like_target:
|
| 382 |
+
self.rvq.eval()
|
| 383 |
+
inp = x[key].permute((0, 2, 1))
|
| 384 |
+
codes = self.get_rvq_codes(inp, raw_wav)
|
| 385 |
+
out[key] = torch.cat([codes[:, idx, ...] for idx in range(int(self.codebook_size//1024))], dim=-1) # (when use freq mask)->[Batch, N_SubSpec, SeqLen=8*750]
|
| 386 |
+
else:
|
| 387 |
+
layer = getattr(self, "quantizer_%s" % key)
|
| 388 |
+
out[key] = layer(x[key])
|
| 389 |
+
return out
|
| 390 |
+
|
| 391 |
+
def to_spec_wise_quad(self, x):
|
| 392 |
+
Batch, QuadSpec, Time = x.shape
|
| 393 |
+
SubSpec, N_SubSpec = 16, 8
|
| 394 |
+
assert 4 * SubSpec * N_SubSpec == QuadSpec == 4*128
|
| 395 |
+
x = rearrange(x, "b (q n s) t -> b (q s) (n t)", q=4, n=N_SubSpec, s=SubSpec)
|
| 396 |
+
return x # [Batch, SubSpec=16, N_SubSpec*Time=8*100Hz]
|
| 397 |
+
|
| 398 |
+
def get_targets(self, x, label=None):
|
| 399 |
+
if self.use_encodec_target:
|
| 400 |
+
raw_x = x.clone()
|
| 401 |
+
else:
|
| 402 |
+
raw_x = None
|
| 403 |
+
x = self.preprocessing(x, features=self.features) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}}
|
| 404 |
+
x = self.normalize(x)
|
| 405 |
+
x = self.rearrange(x) # -> {'melspec_2048': Tensor{Size([3, 750, 512]) cuda:0 f32}}
|
| 406 |
+
melspec = x['melspec_2048']
|
| 407 |
+
if label is None:
|
| 408 |
+
target_tokens = self.tokenize(x, raw_x) # -> {'melspec_2048': Tensor{Size([3, 750]) cuda:0 i64}}
|
| 409 |
+
else:
|
| 410 |
+
# print("use_target from label")
|
| 411 |
+
target_tokens = {'melspec_2048': rearrange(label, "b n s -> b (n s)").long()}
|
| 412 |
+
return target_tokens, melspec
|
| 413 |
+
|
| 414 |
+
def get_predictions(self, x, *, mask=None, attention_mask=None, return_new_mask=False, is_features_only=False):
|
| 415 |
+
# preprocessing
|
| 416 |
+
if not self.use_hubert_featurizer:
|
| 417 |
+
x = self.preprocessing(x, features=["melspec_2048"])
|
| 418 |
+
x = self.normalize(x) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}}
|
| 419 |
+
else:
|
| 420 |
+
features = self.hubert_feature_extractor(x)
|
| 421 |
+
features = self.layer_norm(features.transpose(1, 2))
|
| 422 |
+
if self.post_extract_proj is not None:
|
| 423 |
+
features = self.post_extract_proj(features)
|
| 424 |
+
x = {"melspec_2048": features}
|
| 425 |
+
|
| 426 |
+
# encoding
|
| 427 |
+
logits, hidden_emb, new_mask = self.encoder(x["melspec_2048"], attention_mask=attention_mask, is_features_only=is_features_only)
|
| 428 |
+
|
| 429 |
+
if return_new_mask:
|
| 430 |
+
return logits, hidden_emb, mask if new_mask is None else new_mask
|
| 431 |
+
else:
|
| 432 |
+
return logits, hidden_emb
|
| 433 |
+
|
| 434 |
+
def get_latent(self, x, layer_ix=12):
|
| 435 |
+
_, hidden_states = self.get_predictions(x)
|
| 436 |
+
emb = hidden_states[layer_ix]
|
| 437 |
+
return emb
|
| 438 |
+
|
| 439 |
+
def compute_nce(self, x, pos, negs):
|
| 440 |
+
neg_is_pos = (pos == negs).all(-1)
|
| 441 |
+
pos = pos.unsqueeze(0)
|
| 442 |
+
targets = torch.cat([pos, negs], dim=0)
|
| 443 |
+
|
| 444 |
+
logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
|
| 445 |
+
logits /= 0.1
|
| 446 |
+
if neg_is_pos.any():
|
| 447 |
+
logits[1:][neg_is_pos] = float("-inf")
|
| 448 |
+
logits = logits.transpose(0, 1) # (num_x, num_cls+1)
|
| 449 |
+
return logits
|
| 450 |
+
|
| 451 |
+
def compute_hubert_nce_loss(self, proj_xs, targets):
|
| 452 |
+
|
| 453 |
+
label_embs_list = self.label_embs_concat.split(self.codebook_size, 0) # (self.num_classes, 0)
|
| 454 |
+
|
| 455 |
+
def compute_pred(proj_x, target, label_embs):
|
| 456 |
+
# compute logits for the i-th label set
|
| 457 |
+
y = torch.index_select(label_embs, 0, target.long())
|
| 458 |
+
negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
|
| 459 |
+
return self.compute_nce(proj_x, y, negs)
|
| 460 |
+
|
| 461 |
+
logit_list = [
|
| 462 |
+
compute_pred(proj_x, t, label_embs_list[i])
|
| 463 |
+
for i, (proj_x, t) in enumerate(zip(proj_xs, targets))
|
| 464 |
+
]
|
| 465 |
+
|
| 466 |
+
return sum(logit_list)
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def get_loss(self, logits, target_tokens, masked_indices):
|
| 470 |
+
losses = {}
|
| 471 |
+
accuracies = {}
|
| 472 |
+
for key in logits.keys():
|
| 473 |
+
if not self.use_rvq_like_target:
|
| 474 |
+
masked_logits = logits[key][tuple(masked_indices.t())]
|
| 475 |
+
masked_tokens = target_tokens[key][tuple(masked_indices.t())]
|
| 476 |
+
else:
|
| 477 |
+
Batch, SeqLen, N_Codebook_x_CodebookSize = logits[key].shape # CodebookSize=4096
|
| 478 |
+
Batch, N_Codebook_x_SeqLen = target_tokens[key].shape # N_Codebook*SeqLen=4*750
|
| 479 |
+
N_Codebook = int(N_Codebook_x_SeqLen // SeqLen)
|
| 480 |
+
# print("not use_virtual, n codebook = ", N_Codebook)
|
| 481 |
+
target_tokens[key] = rearrange(target_tokens[key], "b (n s) -> b s n", n=N_Codebook) # Batch, SeqLen=750, N_Codebook=4
|
| 482 |
+
masked_logits = logits[key][tuple(masked_indices.t())]
|
| 483 |
+
masked_tokens = target_tokens[key][tuple(masked_indices.t())]
|
| 484 |
+
masked_logits = rearrange(masked_logits, "b (n c) -> (b n) c", n=N_Codebook)
|
| 485 |
+
masked_tokens = rearrange(masked_tokens, "b n -> (b n)", n=N_Codebook)
|
| 486 |
+
|
| 487 |
+
if self.use_hubert_nce_loss:
|
| 488 |
+
losses[key] = self.compute_hubert_nce_loss(masked_logits, masked_tokens)
|
| 489 |
+
else:
|
| 490 |
+
losses[key] = self.loss(masked_logits, masked_tokens)
|
| 491 |
+
accuracies[key] = (
|
| 492 |
+
torch.sum(masked_logits.argmax(-1) == masked_tokens)
|
| 493 |
+
/ masked_tokens.numel()
|
| 494 |
+
)
|
| 495 |
+
return losses, accuracies
|
| 496 |
+
|
| 497 |
+
def get_recon_loss(self, last_hidden_emb, melspec, masked_indices):
|
| 498 |
+
pred_melspec = self.recon_proj(last_hidden_emb[tuple(masked_indices.t())])
|
| 499 |
+
target_melspec = melspec[tuple(masked_indices.t())]
|
| 500 |
+
recon_loss = self.recon_loss(pred_melspec, target_melspec)
|
| 501 |
+
return recon_loss
|
| 502 |
+
|
| 503 |
+
def forward(self, x, attention_mask=None, label=None):
|
| 504 |
+
dtype = x.dtype
|
| 505 |
+
# get target feature tokens
|
| 506 |
+
target_tokens, melspec = self.get_targets(x, label=label)
|
| 507 |
+
|
| 508 |
+
# masking
|
| 509 |
+
x, masked_indices = self.masking(x, attention_mask=attention_mask)
|
| 510 |
+
|
| 511 |
+
# forward
|
| 512 |
+
logits, hidden_emb, masked_indices = self.get_predictions(x, mask=masked_indices, attention_mask=attention_mask, return_new_mask=True)
|
| 513 |
+
|
| 514 |
+
# get loss
|
| 515 |
+
losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)
|
| 516 |
+
|
| 517 |
+
if self.recon_loss_ratio:
|
| 518 |
+
losses["recon_loss"] = self.get_recon_loss(hidden_emb[-1], melspec, masked_indices) * self.recon_loss_ratio
|
| 519 |
+
|
| 520 |
+
return logits, hidden_emb, losses, accuracies
|
MuCodec/muq_dev/muq_fairseq/models/muq/model/pred_ark_target_with_model.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch
|
| 4 |
+
import sys, os
|
| 5 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 6 |
+
from rvq_musicfm import PreprocessorWithModel, ResidualVectorQuantize
|
| 7 |
+
|
| 8 |
+
class RVQ(nn.Module):
|
| 9 |
+
def __init__(self,
|
| 10 |
+
model_config,
|
| 11 |
+
rvq_ckpt_path,
|
| 12 |
+
preprocess,
|
| 13 |
+
):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.rvq = ResidualVectorQuantize(**model_config)
|
| 16 |
+
if rvq_ckpt_path is not None:
|
| 17 |
+
self.rvq.load_state_dict(torch.load(rvq_ckpt_path, map_location='cpu'))
|
| 18 |
+
self.preprocess = preprocess
|
| 19 |
+
|
| 20 |
+
def get_targets(self, x):
|
| 21 |
+
self.rvq.eval()
|
| 22 |
+
x = self.preprocess(x)
|
| 23 |
+
quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = self.rvq(x)
|
| 24 |
+
return codes.permute(1,0,2)
|
| 25 |
+
|
| 26 |
+
@torch.no_grad()
|
| 27 |
+
def encode_wavs(self, wavs):
|
| 28 |
+
wavs = wavs[..., :int((wavs.shape[-1]//320)*320)]
|
| 29 |
+
return self.get_targets(wavs)
|
| 30 |
+
|
| 31 |
+
def This_Music_ModelTarget_Config():
|
| 32 |
+
config = dict(
|
| 33 |
+
model = dict(
|
| 34 |
+
input_dim = 1024,
|
| 35 |
+
n_codebooks = 8,
|
| 36 |
+
codebook_size = 1024,
|
| 37 |
+
codebook_dim = 16,
|
| 38 |
+
quantizer_dropout = 0.0,
|
| 39 |
+
),
|
| 40 |
+
train = dict(
|
| 41 |
+
batch_size = 32,
|
| 42 |
+
num_workers = 6,
|
| 43 |
+
valid_interval = 10,
|
| 44 |
+
save_interval = 100,
|
| 45 |
+
max_updates = 500000,
|
| 46 |
+
lr = 1e-4,
|
| 47 |
+
# device = 'cuda:1',
|
| 48 |
+
loss = 'commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()',
|
| 49 |
+
preprocess = PreprocessorWithModel(
|
| 50 |
+
model_dir= 'path/to/muq_fairseq',
|
| 51 |
+
checkpoint_dir='path/to/muq_m4a_75K.pt',
|
| 52 |
+
use_layer_idx=9,
|
| 53 |
+
)
|
| 54 |
+
),
|
| 55 |
+
pred = dict(
|
| 56 |
+
rvq_ckpt_path='path/to/runs/Aug07_18-09-24_ts-828fa13e58384d0bba4144fda78ecc92-launcher/ckpt/RVQ_8100.pth',
|
| 57 |
+
sr=24000,
|
| 58 |
+
data_jsonl_path='path/to/data/music4all/train.json',
|
| 59 |
+
save_target_dir= 'path/to/data/music4all_ark/reiter_musicssl_m4a',
|
| 60 |
+
),
|
| 61 |
+
)
|
| 62 |
+
return config
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
CLEN = 30
|
| 66 |
+
N_GPU_PER = 8
|
| 67 |
+
N_NODE = 4
|
| 68 |
+
|
| 69 |
+
def parse_lr(wave_length, sr):
|
| 70 |
+
n_step = int( wave_length // (sr*CLEN) )
|
| 71 |
+
if n_step == 0:
|
| 72 |
+
n_step = 1
|
| 73 |
+
print('wave_length: ', wave_length, 'sr: ', sr, 'n_step: ', n_step)
|
| 74 |
+
starts = torch.arange(n_step) * CLEN * sr
|
| 75 |
+
left_rights = torch.stack((starts, starts+CLEN*sr)).T
|
| 76 |
+
return left_rights[:10, ...]
|
| 77 |
+
|
| 78 |
+
@torch.no_grad()
|
| 79 |
+
def main(index, rank):
|
| 80 |
+
device = f'cuda:{rank}'
|
| 81 |
+
config = This_Music_ModelTarget_Config()
|
| 82 |
+
preprocess = config['train']['preprocess']
|
| 83 |
+
model = RVQ(
|
| 84 |
+
model_config = config['model'],
|
| 85 |
+
rvq_ckpt_path = config['pred']['rvq_ckpt_path'],
|
| 86 |
+
preprocess = preprocess
|
| 87 |
+
).to(device)
|
| 88 |
+
model.eval()
|
| 89 |
+
sr = config['pred']['sr']
|
| 90 |
+
|
| 91 |
+
fname_nobase = os.path.basename(config['pred']['data_jsonl_path']).split('.')[0]
|
| 92 |
+
scp_dir = os.path.join(config['pred']['save_target_dir'], 'scp')
|
| 93 |
+
ark_dir = os.path.join(config['pred']['save_target_dir'], 'ark')
|
| 94 |
+
os.makedirs(scp_dir, exist_ok=True)
|
| 95 |
+
os.makedirs(ark_dir, exist_ok=True)
|
| 96 |
+
|
| 97 |
+
scp_path = os.path.join(scp_dir, f'{fname_nobase}.{index}_{rank}.scp')
|
| 98 |
+
ark_path = os.path.join(ark_dir, f'{fname_nobase}.{index}_{rank}.ark')
|
| 99 |
+
|
| 100 |
+
from kaldiio import WriteHelper
|
| 101 |
+
|
| 102 |
+
with open(config['pred']['data_jsonl_path']) as f:
|
| 103 |
+
lines = f.readlines()
|
| 104 |
+
|
| 105 |
+
print("Total:", len(lines))
|
| 106 |
+
|
| 107 |
+
from tqdm import tqdm
|
| 108 |
+
import json
|
| 109 |
+
import librosa
|
| 110 |
+
import time
|
| 111 |
+
from einops import rearrange
|
| 112 |
+
import numpy as np
|
| 113 |
+
|
| 114 |
+
# lines = lines[(index*N_GPU_PER+rank)::(N_GPU_PER*N_NODE)]
|
| 115 |
+
|
| 116 |
+
with WriteHelper(f'ark,scp:{ark_path},{scp_path}') as writer:
|
| 117 |
+
for idx, line in tqdm(enumerate(lines)):
|
| 118 |
+
try:
|
| 119 |
+
if idx % (N_GPU_PER*N_NODE) != (index*N_GPU_PER+rank):
|
| 120 |
+
continue
|
| 121 |
+
item = json.loads(line)
|
| 122 |
+
path = item['path']
|
| 123 |
+
wave, _ = librosa.load(path, sr=sr)
|
| 124 |
+
wave = torch.from_numpy(wave)
|
| 125 |
+
wave_length = wave.shape[-1]
|
| 126 |
+
if wave_length < sr*CLEN:
|
| 127 |
+
continue
|
| 128 |
+
left_rights = parse_lr(wave_length, sr)
|
| 129 |
+
lr = left_rights.tolist()
|
| 130 |
+
wavs = torch.stack(
|
| 131 |
+
[wave[l:r] for l,r in lr]
|
| 132 |
+
).to(device)
|
| 133 |
+
targets = model.encode_wavs(wavs) # [Codebook=8, N_Steps, Feature]
|
| 134 |
+
|
| 135 |
+
final_target = rearrange(targets, "c n f -> n (c f)").cpu().numpy().astype(np.int32)
|
| 136 |
+
for j in range(final_target.shape[0]):
|
| 137 |
+
writer(f'{idx}:{j}', final_target[j])
|
| 138 |
+
except Exception as e:
|
| 139 |
+
print(e)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
if __name__ == '__main__':
|
| 143 |
+
import sys
|
| 144 |
+
index = int(sys.argv[1])
|
| 145 |
+
import multiprocessing
|
| 146 |
+
pool = multiprocessing.Pool(processes=N_GPU_PER)
|
| 147 |
+
for rank in range(8):
|
| 148 |
+
pool.apply_async(main, (index, rank))
|
| 149 |
+
pool.close()
|
| 150 |
+
pool.join()
|
| 151 |
+
print("Done.")
|
MuCodec/muq_dev/muq_fairseq/models/muq/model/rvq.py
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from typing import Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from torch.nn.utils import weight_norm
|
| 10 |
+
|
| 11 |
+
def WNConv1d(*args, **kwargs):
|
| 12 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class VectorQuantize(nn.Module):
|
| 16 |
+
"""
|
| 17 |
+
Implementation of VQ similar to Karpathy's repo:
|
| 18 |
+
https://github.com/karpathy/deep-vector-quantization
|
| 19 |
+
Additionally uses following tricks from Improved VQGAN
|
| 20 |
+
(https://arxiv.org/pdf/2110.04627.pdf):
|
| 21 |
+
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
| 22 |
+
for improved codebook usage
|
| 23 |
+
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
| 24 |
+
improves training stability
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 1000, mfcc_clustering=False, n_layer=1):
|
| 28 |
+
super().__init__()
|
| 29 |
+
self.codebook_size = codebook_size
|
| 30 |
+
self.codebook_dim = codebook_dim
|
| 31 |
+
self.mfcc_clustering = mfcc_clustering
|
| 32 |
+
|
| 33 |
+
ProjClass = nn.Identity if mfcc_clustering else WNConv1d
|
| 34 |
+
if n_layer==1:
|
| 35 |
+
self.in_proj = ProjClass(input_dim, codebook_dim, kernel_size=1)
|
| 36 |
+
self.out_proj = ProjClass(codebook_dim, input_dim, kernel_size=1)
|
| 37 |
+
elif n_layer >= 2:
|
| 38 |
+
ndim_hidden = 128
|
| 39 |
+
self.in_proj = nn.Sequential(
|
| 40 |
+
ProjClass(input_dim, ndim_hidden, kernel_size=1),
|
| 41 |
+
*[nn.Sequential(nn.ReLU(), ProjClass(ndim_hidden, ndim_hidden, kernel_size=1),) for _ in range(n_layer-2)],
|
| 42 |
+
nn.ReLU(),
|
| 43 |
+
ProjClass(ndim_hidden, codebook_dim, kernel_size=1)
|
| 44 |
+
)
|
| 45 |
+
self.out_proj = nn.Sequential(
|
| 46 |
+
ProjClass(codebook_dim, ndim_hidden, kernel_size=1),
|
| 47 |
+
nn.ReLU(),
|
| 48 |
+
*[nn.Sequential(ProjClass(ndim_hidden, ndim_hidden, kernel_size=1), nn.ReLU()) for _ in range(n_layer-2)],
|
| 49 |
+
ProjClass(ndim_hidden, input_dim, kernel_size=1),
|
| 50 |
+
)
|
| 51 |
+
self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
| 52 |
+
self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
|
| 53 |
+
self.stale_tolerance = stale_tolerance
|
| 54 |
+
|
| 55 |
+
def forward(self, z):
|
| 56 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
| 57 |
+
the corresponding codebook vectors
|
| 58 |
+
|
| 59 |
+
Parameters
|
| 60 |
+
----------
|
| 61 |
+
z : Tensor[B x D x T]
|
| 62 |
+
|
| 63 |
+
Returns
|
| 64 |
+
-------
|
| 65 |
+
Tensor[B x D x T]
|
| 66 |
+
Quantized continuous representation of input
|
| 67 |
+
Tensor[1]
|
| 68 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
| 69 |
+
entries
|
| 70 |
+
Tensor[1]
|
| 71 |
+
Codebook loss to update the codebook
|
| 72 |
+
Tensor[B x T]
|
| 73 |
+
Codebook indices (quantized discrete representation of input)
|
| 74 |
+
Tensor[B x D x T]
|
| 75 |
+
Projected latents (continuous representation of input before quantization)
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
| 79 |
+
|
| 80 |
+
z_e = self.in_proj(z) # z_e : (B x D x T)
|
| 81 |
+
z_q, indices = self.decode_latents(z_e)
|
| 82 |
+
|
| 83 |
+
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
| 84 |
+
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
| 85 |
+
|
| 86 |
+
z_q = (
|
| 87 |
+
z_e + (z_q - z_e).detach()
|
| 88 |
+
) # noop in forward pass, straight-through gradient estimator in backward pass
|
| 89 |
+
|
| 90 |
+
z_q = self.out_proj(z_q)
|
| 91 |
+
|
| 92 |
+
return z_q, commitment_loss, codebook_loss, indices, z_e
|
| 93 |
+
|
| 94 |
+
def embed_code(self, embed_id):
|
| 95 |
+
return F.embedding(embed_id, self.codebook.weight)
|
| 96 |
+
|
| 97 |
+
def decode_code(self, embed_id):
|
| 98 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
| 99 |
+
|
| 100 |
+
def decode_latents(self, latents):
|
| 101 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
| 102 |
+
codebook = self.codebook.weight # codebook: (N x D)
|
| 103 |
+
|
| 104 |
+
# L2 normalize encodings and codebook (ViT-VQGAN)
|
| 105 |
+
encodings = F.normalize(encodings)
|
| 106 |
+
codebook = F.normalize(codebook)
|
| 107 |
+
|
| 108 |
+
# Compute euclidean distance with codebook
|
| 109 |
+
dist = (
|
| 110 |
+
encodings.pow(2).sum(1, keepdim=True)
|
| 111 |
+
- 2 * encodings @ codebook.t()
|
| 112 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
| 113 |
+
)
|
| 114 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
| 115 |
+
z_q = self.decode_code(indices)
|
| 116 |
+
|
| 117 |
+
if(self.training):
|
| 118 |
+
onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
|
| 119 |
+
stale_codes = (onehots.sum(0).sum(0) == 0).float()
|
| 120 |
+
self.stale_counter = self.stale_counter * stale_codes + stale_codes
|
| 121 |
+
|
| 122 |
+
# random replace codes that haven't been used for a while
|
| 123 |
+
replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
|
| 124 |
+
if replace_code.sum(-1) > 0:
|
| 125 |
+
print("Replace {} codes".format(replace_code.sum(-1)))
|
| 126 |
+
random_input_idx = torch.randperm(encodings.shape[0])
|
| 127 |
+
random_input = encodings[random_input_idx].view(encodings.shape)
|
| 128 |
+
if random_input.shape[0] < self.codebook_size:
|
| 129 |
+
random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
|
| 130 |
+
random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
|
| 131 |
+
|
| 132 |
+
self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
|
| 133 |
+
self.stale_counter = self.stale_counter * (1 - replace_code)
|
| 134 |
+
|
| 135 |
+
return z_q, indices
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class ResidualVectorQuantize(nn.Module):
|
| 139 |
+
"""
|
| 140 |
+
Introduced in SoundStream: An end2end neural audio codec
|
| 141 |
+
https://arxiv.org/abs/2107.03312
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
def __init__(
|
| 145 |
+
self,
|
| 146 |
+
input_dim: int = 512,
|
| 147 |
+
n_codebooks: int = 9,
|
| 148 |
+
codebook_size: int = 1024,
|
| 149 |
+
codebook_dim: Union[int, list] = 8,
|
| 150 |
+
quantizer_dropout: float = 0.0,
|
| 151 |
+
stale_tolerance: int = 100,
|
| 152 |
+
use_multi_layer_num:int = 1,
|
| 153 |
+
):
|
| 154 |
+
super().__init__()
|
| 155 |
+
if isinstance(codebook_dim, int):
|
| 156 |
+
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
| 157 |
+
|
| 158 |
+
self.n_codebooks = n_codebooks
|
| 159 |
+
self.codebook_dim = codebook_dim
|
| 160 |
+
self.codebook_size = codebook_size
|
| 161 |
+
|
| 162 |
+
self.quantizers = nn.ModuleList(
|
| 163 |
+
[
|
| 164 |
+
VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance, n_layer=use_multi_layer_num)
|
| 165 |
+
for i in range(n_codebooks)
|
| 166 |
+
]
|
| 167 |
+
)
|
| 168 |
+
self.quantizer_dropout = quantizer_dropout
|
| 169 |
+
|
| 170 |
+
def forward(self, z, n_quantizers: int = None):
|
| 171 |
+
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
| 172 |
+
the corresponding codebook vectors
|
| 173 |
+
Parameters
|
| 174 |
+
----------
|
| 175 |
+
z : Tensor[B x D x T]
|
| 176 |
+
n_quantizers : int, optional
|
| 177 |
+
No. of quantizers to use
|
| 178 |
+
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
| 179 |
+
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
| 180 |
+
when in training mode, and a random number of quantizers is used.
|
| 181 |
+
Returns
|
| 182 |
+
-------
|
| 183 |
+
dict
|
| 184 |
+
A dictionary with the following keys:
|
| 185 |
+
|
| 186 |
+
"z" : Tensor[B x D x T]
|
| 187 |
+
Quantized continuous representation of input
|
| 188 |
+
"codes" : Tensor[B x N x T]
|
| 189 |
+
Codebook indices for each codebook
|
| 190 |
+
(quantized discrete representation of input)
|
| 191 |
+
"latents" : Tensor[B x N*D x T]
|
| 192 |
+
Projected latents (continuous representation of input before quantization)
|
| 193 |
+
"vq/commitment_loss" : Tensor[1]
|
| 194 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
| 195 |
+
entries
|
| 196 |
+
"vq/codebook_loss" : Tensor[1]
|
| 197 |
+
Codebook loss to update the codebook
|
| 198 |
+
"""
|
| 199 |
+
z_q = 0
|
| 200 |
+
residual = z
|
| 201 |
+
commitment_loss = 0
|
| 202 |
+
codebook_loss = 0
|
| 203 |
+
|
| 204 |
+
codebook_indices = []
|
| 205 |
+
latents = []
|
| 206 |
+
|
| 207 |
+
if n_quantizers is None:
|
| 208 |
+
n_quantizers = self.n_codebooks
|
| 209 |
+
if self.training:
|
| 210 |
+
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
|
| 211 |
+
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
|
| 212 |
+
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
| 213 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
| 214 |
+
n_quantizers = n_quantizers.to(z.device)
|
| 215 |
+
else:
|
| 216 |
+
n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1
|
| 217 |
+
n_quantizers = n_quantizers.to(z.device)
|
| 218 |
+
|
| 219 |
+
for i, quantizer in enumerate(self.quantizers):
|
| 220 |
+
# if self.training is False and i >= n_quantizers:
|
| 221 |
+
# break
|
| 222 |
+
|
| 223 |
+
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
| 224 |
+
residual
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
# Create mask to apply quantizer dropout
|
| 228 |
+
mask = (
|
| 229 |
+
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
| 230 |
+
)
|
| 231 |
+
z_q = z_q + z_q_i * mask[:, None, None]
|
| 232 |
+
residual = residual - z_q_i
|
| 233 |
+
|
| 234 |
+
# Sum losses
|
| 235 |
+
commitment_loss += (commitment_loss_i * mask).mean()
|
| 236 |
+
codebook_loss += (codebook_loss_i * mask).mean()
|
| 237 |
+
|
| 238 |
+
codebook_indices.append(indices_i)
|
| 239 |
+
latents.append(z_e_i)
|
| 240 |
+
|
| 241 |
+
codes = torch.stack(codebook_indices, dim=1)
|
| 242 |
+
latents = torch.cat(latents, dim=1)
|
| 243 |
+
|
| 244 |
+
encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
|
| 245 |
+
|
| 246 |
+
return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1
|
| 247 |
+
|
| 248 |
+
def from_codes(self, codes: torch.Tensor):
|
| 249 |
+
"""Given the quantized codes, reconstruct the continuous representation
|
| 250 |
+
Parameters
|
| 251 |
+
----------
|
| 252 |
+
codes : Tensor[B x N x T]
|
| 253 |
+
Quantized discrete representation of input
|
| 254 |
+
Returns
|
| 255 |
+
-------
|
| 256 |
+
Tensor[B x D x T]
|
| 257 |
+
Quantized continuous representation of input
|
| 258 |
+
"""
|
| 259 |
+
z_q = 0.0
|
| 260 |
+
z_p = []
|
| 261 |
+
n_codebooks = codes.shape[1]
|
| 262 |
+
for i in range(n_codebooks):
|
| 263 |
+
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
| 264 |
+
z_p.append(z_p_i)
|
| 265 |
+
|
| 266 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
| 267 |
+
z_q = z_q + z_q_i
|
| 268 |
+
return z_q, torch.cat(z_p, dim=1), codes
|
| 269 |
+
|
| 270 |
+
def from_latents(self, latents: torch.Tensor):
|
| 271 |
+
"""Given the unquantized latents, reconstruct the
|
| 272 |
+
continuous representation after quantization.
|
| 273 |
+
|
| 274 |
+
Parameters
|
| 275 |
+
----------
|
| 276 |
+
latents : Tensor[B x N x T]
|
| 277 |
+
Continuous representation of input after projection
|
| 278 |
+
|
| 279 |
+
Returns
|
| 280 |
+
-------
|
| 281 |
+
Tensor[B x D x T]
|
| 282 |
+
Quantized representation of full-projected space
|
| 283 |
+
Tensor[B x D x T]
|
| 284 |
+
Quantized representation of latent space
|
| 285 |
+
"""
|
| 286 |
+
z_q = 0
|
| 287 |
+
z_p = []
|
| 288 |
+
codes = []
|
| 289 |
+
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
| 290 |
+
|
| 291 |
+
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
|
| 292 |
+
0
|
| 293 |
+
]
|
| 294 |
+
for i in range(n_codebooks):
|
| 295 |
+
j, k = dims[i], dims[i + 1]
|
| 296 |
+
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
| 297 |
+
z_p.append(z_p_i)
|
| 298 |
+
codes.append(codes_i)
|
| 299 |
+
|
| 300 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
| 301 |
+
z_q = z_q + z_q_i
|
| 302 |
+
|
| 303 |
+
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
| 304 |
+
|
| 305 |
+
from torch.utils.data import Dataset, DataLoader
|
| 306 |
+
import json, traceback
|
| 307 |
+
import torchaudio
|
| 308 |
+
import math
|
| 309 |
+
|
| 310 |
+
from typing import List, Tuple, Dict, Any
|
| 311 |
+
|
| 312 |
+
CLIPSECS = 5
|
| 313 |
+
def load_audio_by_json(json_path, max_keep, min_keep, tgt_sample_rate):
|
| 314 |
+
# read json file
|
| 315 |
+
print(json_path)
|
| 316 |
+
datas = []
|
| 317 |
+
inds = []
|
| 318 |
+
sizes = []
|
| 319 |
+
with open(json_path) as fp:
|
| 320 |
+
for ind,line in enumerate(fp):
|
| 321 |
+
data = json.loads(line)
|
| 322 |
+
datas.append(data)
|
| 323 |
+
inds.append(ind)
|
| 324 |
+
# sz = int(data['duration'] * data['sample_rate'])
|
| 325 |
+
sz = int(tgt_sample_rate * CLIPSECS)
|
| 326 |
+
sizes.append(sz)
|
| 327 |
+
tot = ind + 1
|
| 328 |
+
return datas,inds,tot,sizes
|
| 329 |
+
|
| 330 |
+
class Read_and_PadCrop_Normalized_T(torch.nn.Module):
|
| 331 |
+
def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
|
| 332 |
+
|
| 333 |
+
super().__init__()
|
| 334 |
+
|
| 335 |
+
self.n_samples = n_samples
|
| 336 |
+
self.sample_rate = sample_rate
|
| 337 |
+
self.randomize = randomize
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
|
| 341 |
+
if(duration<(float(self.n_samples)/self.sample_rate+1)):
|
| 342 |
+
# print(duration,(float(self.n_samples)/self.sample_rate+1))
|
| 343 |
+
chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
|
| 344 |
+
t_start = 0.
|
| 345 |
+
t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
|
| 346 |
+
offset = 0
|
| 347 |
+
# print('c1:',chunk.shape)
|
| 348 |
+
else:
|
| 349 |
+
offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
|
| 350 |
+
t_start = offset / float(cur_sample_rate) / duration
|
| 351 |
+
t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
|
| 352 |
+
chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
|
| 353 |
+
# print('offset:',offset)
|
| 354 |
+
# print('c0:',chunk.shape)
|
| 355 |
+
# Pad with silence if necessary.
|
| 356 |
+
if(chunk.shape[0]>1):
|
| 357 |
+
chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
|
| 358 |
+
else:
|
| 359 |
+
chunk = chunk[[0],:].float()
|
| 360 |
+
if(cur_sample_rate!=self.sample_rate):
|
| 361 |
+
# print('a:',cur_sample_rate,chunk.shape)
|
| 362 |
+
chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
|
| 363 |
+
# print('b:',self.sample_rate,chunk.shape)
|
| 364 |
+
if chunk.shape[-1] < self.n_samples:
|
| 365 |
+
chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
|
| 366 |
+
else:
|
| 367 |
+
chunk = chunk[:,0:self.n_samples]
|
| 368 |
+
seconds_start = math.floor(offset / cur_sample_rate)
|
| 369 |
+
seconds_total = math.floor(duration)
|
| 370 |
+
|
| 371 |
+
return (
|
| 372 |
+
chunk,
|
| 373 |
+
t_start,
|
| 374 |
+
t_end,
|
| 375 |
+
seconds_start,
|
| 376 |
+
seconds_total
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
class RVQDataset(Dataset):
|
| 380 |
+
def __init__(
|
| 381 |
+
self,
|
| 382 |
+
manifest_path: str,
|
| 383 |
+
sample_rate: float,
|
| 384 |
+
normalize: bool = False,
|
| 385 |
+
):
|
| 386 |
+
self.sample_rate = sample_rate
|
| 387 |
+
self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path, None, None, self.sample_rate)
|
| 388 |
+
self.dataset_len = len(self.datas)
|
| 389 |
+
|
| 390 |
+
self.reader = Read_and_PadCrop_Normalized_T(n_samples=CLIPSECS*sample_rate,sample_rate = self.sample_rate)
|
| 391 |
+
self.normalize = normalize
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def __getitem__(self, i):
|
| 395 |
+
# WORLD_SIZE = int(torch.distributed.get_world_size())
|
| 396 |
+
# WORLD_RANK = int(torch.distributed.get_rank())
|
| 397 |
+
# np.random.seed(1337 + self.epoch * WORLD_SIZE + WORLD_RANK + i)
|
| 398 |
+
# index = random.randint(0,len(self.sizes) - 1)
|
| 399 |
+
index = i
|
| 400 |
+
item = None
|
| 401 |
+
while item is None:
|
| 402 |
+
try:
|
| 403 |
+
wav = self.get_audio_by_slice(index)
|
| 404 |
+
# labels = self.get_labels(index)
|
| 405 |
+
# labels = None
|
| 406 |
+
# item = {"id": index, "source": wav, "label_list": labels}
|
| 407 |
+
item = {"id": index, "source": wav}
|
| 408 |
+
except Exception as e:
|
| 409 |
+
# print(e)
|
| 410 |
+
traceback.print_exc()
|
| 411 |
+
print(f'skip damaged data {index}')
|
| 412 |
+
index = np.random.randint(0,len(self.sizes)-1)
|
| 413 |
+
return item
|
| 414 |
+
|
| 415 |
+
def __len__(self):
|
| 416 |
+
return self.dataset_len
|
| 417 |
+
|
| 418 |
+
def get_audio_by_slice(self,index):
|
| 419 |
+
|
| 420 |
+
wav_path = self.datas[index]['path']
|
| 421 |
+
# print(wav_path)
|
| 422 |
+
audio_info = torchaudio.info(wav_path)
|
| 423 |
+
origin_sample_rate = audio_info.sample_rate
|
| 424 |
+
origin_duration = audio_info.num_frames / origin_sample_rate
|
| 425 |
+
|
| 426 |
+
wav, *ignored = self.reader(wav_path, origin_duration,origin_sample_rate)
|
| 427 |
+
wav = wav.float()
|
| 428 |
+
|
| 429 |
+
# _path, slice_ptr = parse_path(wav_path)
|
| 430 |
+
# original way
|
| 431 |
+
# if len(slice_ptr) == 0:
|
| 432 |
+
# wav, cur_sample_rate = sf.read(_path)
|
| 433 |
+
# else:
|
| 434 |
+
# assert _path.endswith(".zip")
|
| 435 |
+
# data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
|
| 436 |
+
# f = io.BytesIO(data)
|
| 437 |
+
# wav, cur_sample_rate = sf.read(f)
|
| 438 |
+
# wav = torch.from_numpy(wav).float()
|
| 439 |
+
# print(wav.shape)
|
| 440 |
+
wav = wav.permute(1,0)
|
| 441 |
+
wav = self.postprocess(wav, self.sample_rate)
|
| 442 |
+
# print(wav.shape)
|
| 443 |
+
|
| 444 |
+
# wav = wav.squeeze(0)
|
| 445 |
+
return wav
|
| 446 |
+
|
| 447 |
+
def postprocess(self, wav, cur_sample_rate):
|
| 448 |
+
if wav.dim() == 2:
|
| 449 |
+
wav = wav.mean(-1)
|
| 450 |
+
assert wav.dim() == 1, wav.dim()
|
| 451 |
+
|
| 452 |
+
if cur_sample_rate != self.sample_rate:
|
| 453 |
+
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
|
| 454 |
+
|
| 455 |
+
if self.normalize:
|
| 456 |
+
with torch.no_grad():
|
| 457 |
+
wav = F.layer_norm(wav, wav.shape)
|
| 458 |
+
return wav
|
| 459 |
+
|
MuCodec/muq_dev/muq_fairseq/models/muq/model/rvq_muq.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try:
|
| 2 |
+
from .rvq import *
|
| 3 |
+
except:
|
| 4 |
+
import sys, os
|
| 5 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 6 |
+
from rvq import *
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from ..modules.random_quantizer import RandomProjectionQuantizer
|
| 10 |
+
from ..modules.features import MelSTFT
|
| 11 |
+
from ..modules.conv import Conv2dSubsampling
|
| 12 |
+
except:
|
| 13 |
+
import sys, os
|
| 14 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
| 15 |
+
from modules.random_quantizer import RandomProjectionQuantizer
|
| 16 |
+
from modules.features import MelSTFT
|
| 17 |
+
from modules.conv import Conv2dSubsampling
|
| 18 |
+
|
| 19 |
+
import fairseq
|
| 20 |
+
|
| 21 |
+
CLIPSECS = 5 # 5 for rvq, 30 for model
|
| 22 |
+
|
| 23 |
+
class RVQDataset(Dataset):
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
manifest_path: str,
|
| 27 |
+
sample_rate: float,
|
| 28 |
+
normalize: bool = False,
|
| 29 |
+
):
|
| 30 |
+
self.sample_rate = sample_rate
|
| 31 |
+
self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path, None, None, self.sample_rate)
|
| 32 |
+
self.dataset_len = len(self.datas)
|
| 33 |
+
|
| 34 |
+
self.reader = Read_and_PadCrop_Normalized_T(n_samples=CLIPSECS*sample_rate,sample_rate = self.sample_rate)
|
| 35 |
+
self.normalize = normalize
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def __getitem__(self, i):
|
| 39 |
+
# WORLD_SIZE = int(torch.distributed.get_world_size())
|
| 40 |
+
# WORLD_RANK = int(torch.distributed.get_rank())
|
| 41 |
+
# np.random.seed(1337 + self.epoch * WORLD_SIZE + WORLD_RANK + i)
|
| 42 |
+
# index = random.randint(0,len(self.sizes) - 1)
|
| 43 |
+
index = i
|
| 44 |
+
item = None
|
| 45 |
+
while item is None:
|
| 46 |
+
try:
|
| 47 |
+
wav = self.get_audio_by_slice(index)
|
| 48 |
+
item = {"id": index, "source": wav}
|
| 49 |
+
except Exception as e:
|
| 50 |
+
# print(e)
|
| 51 |
+
traceback.print_exc()
|
| 52 |
+
print(f'skip damaged data {index}')
|
| 53 |
+
index = np.random.randint(0,len(self.sizes)-1)
|
| 54 |
+
return item
|
| 55 |
+
|
| 56 |
+
def __len__(self):
|
| 57 |
+
return self.dataset_len
|
| 58 |
+
|
| 59 |
+
def get_audio_by_slice(self,index):
|
| 60 |
+
|
| 61 |
+
wav_path = self.datas[index]['path']
|
| 62 |
+
audio_info = torchaudio.info(wav_path)
|
| 63 |
+
origin_sample_rate = audio_info.sample_rate
|
| 64 |
+
origin_duration = audio_info.num_frames / origin_sample_rate
|
| 65 |
+
|
| 66 |
+
wav, *ignored = self.reader(wav_path, origin_duration,origin_sample_rate)
|
| 67 |
+
wav = wav.float()
|
| 68 |
+
|
| 69 |
+
# _path, slice_ptr = parse_path(wav_path)
|
| 70 |
+
# original way
|
| 71 |
+
# if len(slice_ptr) == 0:
|
| 72 |
+
# wav, cur_sample_rate = sf.read(_path)
|
| 73 |
+
# else:
|
| 74 |
+
# assert _path.endswith(".zip")
|
| 75 |
+
# data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
|
| 76 |
+
# f = io.BytesIO(data)
|
| 77 |
+
# wav, cur_sample_rate = sf.read(f)
|
| 78 |
+
# wav = torch.from_numpy(wav).float()
|
| 79 |
+
# print(wav.shape)
|
| 80 |
+
wav = wav.permute(1,0)
|
| 81 |
+
wav = self.postprocess(wav, self.sample_rate)
|
| 82 |
+
# print(wav.shape)
|
| 83 |
+
|
| 84 |
+
# wav = wav.squeeze(0)
|
| 85 |
+
return wav
|
| 86 |
+
|
| 87 |
+
def postprocess(self, wav, cur_sample_rate):
|
| 88 |
+
if wav.dim() == 2:
|
| 89 |
+
wav = wav.mean(-1)
|
| 90 |
+
assert wav.dim() == 1, wav.dim()
|
| 91 |
+
|
| 92 |
+
if cur_sample_rate != self.sample_rate:
|
| 93 |
+
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
|
| 94 |
+
|
| 95 |
+
if self.normalize:
|
| 96 |
+
with torch.no_grad():
|
| 97 |
+
wav = F.layer_norm(wav, wav.shape)
|
| 98 |
+
return wav
|
| 99 |
+
|
| 100 |
+
class Preprocessor(nn.Module):
|
| 101 |
+
def __init__(self,
|
| 102 |
+
codebook_dim=16,
|
| 103 |
+
codebook_size=4096,
|
| 104 |
+
hop_length=240,
|
| 105 |
+
n_mels=128,
|
| 106 |
+
stat_path=None,
|
| 107 |
+
is_spec_wise=False,
|
| 108 |
+
s=4,
|
| 109 |
+
) -> None:
|
| 110 |
+
super().__init__()
|
| 111 |
+
|
| 112 |
+
self.features=["melspec_2048"]
|
| 113 |
+
self.s = s
|
| 114 |
+
|
| 115 |
+
# load feature mean / std stats
|
| 116 |
+
import os
|
| 117 |
+
if stat_path is not None and os.path.exists(stat_path):
|
| 118 |
+
with open(stat_path, "r") as f:
|
| 119 |
+
self.stat = json.load(f)
|
| 120 |
+
else:
|
| 121 |
+
# print("No stats file found at `{}`, use default from msd.".format(stat_path))
|
| 122 |
+
self.stat = {"spec_256_cnt": 14394344256, "spec_256_mean": -23.34296658431829, "spec_256_std": 26.189295587132637, "spec_512_cnt": 28677104448, "spec_512_mean": -21.31267396860235, "spec_512_std": 26.52644536245769, "spec_1024_cnt": 57242624832, "spec_1024_mean": -18.852271129208273, "spec_1024_std": 26.443154583585663, "spec_2048_cnt": 114373665600, "spec_2048_mean": -15.638743433896792, "spec_2048_std": 26.115825961611545, "spec_4096_cnt": 228635747136, "spec_4096_mean": -11.715532502794836, "spec_4096_std": 25.763972210234062, "melspec_256_cnt": 14282760192, "melspec_256_mean": -26.962600400166156, "melspec_256_std": 36.13614100912126, "melspec_512_cnt": 14282760192, "melspec_512_mean": -9.108344167718862, "melspec_512_std": 24.71910937988429, "melspec_1024_cnt": 14282760192, "melspec_1024_mean": 0.37302579246531126, "melspec_1024_std": 18.684082325919388, "melspec_2048_cnt": 14282760192, "melspec_2048_mean": 6.768444971712967, "melspec_2048_std": 18.417922652295623, "melspec_4096_cnt": 14282760192, "melspec_4096_mean": 13.617164614990036, "melspec_4096_std": 18.08552130124525, "cqt_cnt": 9373061376, "cqt_mean": 0.46341379757927165, "cqt_std": 0.9543998080910191, "mfcc_256_cnt": 1339008768, "mfcc_256_mean": -11.681755459447485, "mfcc_256_std": 29.183186444668316, "mfcc_512_cnt": 1339008768, "mfcc_512_mean": -2.540581461792183, "mfcc_512_std": 31.93752185832081, "mfcc_1024_cnt": 1339008768, "mfcc_1024_mean": 6.606636263169779, "mfcc_1024_std": 34.151644801729624, "mfcc_2048_cnt": 1339008768, "mfcc_2048_mean": 5.281600844245184, "mfcc_2048_std": 33.12784541220003, "mfcc_4096_cnt": 1339008768, "mfcc_4096_mean": 4.7616569480166095, "mfcc_4096_std": 32.61458906894133, "chromagram_256_cnt": 1339008768, "chromagram_256_mean": 55.15596556703181, "chromagram_256_std": 73.91858278719991, "chromagram_512_cnt": 1339008768, "chromagram_512_mean": 175.73092252759895, "chromagram_512_std": 248.48485148525953, "chromagram_1024_cnt": 1339008768, "chromagram_1024_mean": 589.2947481634608, "chromagram_1024_std": 913.857929063196, "chromagram_2048_cnt": 1339008768, "chromagram_2048_mean": 2062.286388327397, "chromagram_2048_std": 3458.92657915397, "chromagram_4096_cnt": 1339008768, "chromagram_4096_mean": 7673.039107997085, "chromagram_4096_std": 13009.883158267234}
|
| 123 |
+
|
| 124 |
+
# feature extractor
|
| 125 |
+
self.preprocessor_melspec_2048 = MelSTFT(
|
| 126 |
+
n_fft=2048, hop_length=hop_length, is_db=True
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
self.is_spec_wise = is_spec_wise
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@torch.no_grad()
|
| 133 |
+
def normalize(self, x):
|
| 134 |
+
"""normalize the input audio to have zero mean unit variance"""
|
| 135 |
+
for key in x.keys():
|
| 136 |
+
x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key] # {'melspec_2048_cnt': 14282760192, 'melspec_2048_mean': 6.768444971712967}
|
| 137 |
+
return x
|
| 138 |
+
|
| 139 |
+
@torch.no_grad()
|
| 140 |
+
def rearrange(self, x):
|
| 141 |
+
"""rearrange the batch to flatten every 4 steps"""
|
| 142 |
+
for key in x.keys():
|
| 143 |
+
if key == "chromagram":
|
| 144 |
+
x[key] = rearrange(x[key], "b f t -> b t f")
|
| 145 |
+
else:
|
| 146 |
+
x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=self.s)
|
| 147 |
+
return x
|
| 148 |
+
|
| 149 |
+
@torch.no_grad()
|
| 150 |
+
def preprocessing(self, x, features):
|
| 151 |
+
"""extract classic audio features"""
|
| 152 |
+
# check precision
|
| 153 |
+
if x.dtype == torch.float16:
|
| 154 |
+
precision = 16
|
| 155 |
+
else:
|
| 156 |
+
precision = 32
|
| 157 |
+
|
| 158 |
+
out = {}
|
| 159 |
+
for key in features:
|
| 160 |
+
layer = getattr(self, "preprocessor_%s" % key)
|
| 161 |
+
out[key] = layer.float()(x.float())[..., :-1]
|
| 162 |
+
if precision == 16:
|
| 163 |
+
out[key] = out[key].half()
|
| 164 |
+
return out
|
| 165 |
+
|
| 166 |
+
@torch.no_grad()
|
| 167 |
+
def tokenize(self, x):
|
| 168 |
+
out = {}
|
| 169 |
+
for key in x.keys():
|
| 170 |
+
layer = getattr(self, "quantizer_%s" % key)
|
| 171 |
+
out[key] = layer(x[key])
|
| 172 |
+
return out
|
| 173 |
+
|
| 174 |
+
def to_spec_wise(self, x):
|
| 175 |
+
Batch, Spec, Time = x.shape
|
| 176 |
+
SubSpec, N_SubSpec = 16, 8
|
| 177 |
+
assert SubSpec * N_SubSpec == Spec == 128
|
| 178 |
+
x = rearrange(x, "b (n s) t -> b s (n t)", n=N_SubSpec, s=SubSpec)
|
| 179 |
+
return x # [Batch, SubSpec=16, N_SubSpec*Time=8*100Hz]
|
| 180 |
+
|
| 181 |
+
@torch.no_grad()
|
| 182 |
+
def __call__(self, x):
|
| 183 |
+
x = self.preprocessing(x, features=self.features) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}}
|
| 184 |
+
x = self.normalize(x)
|
| 185 |
+
if self.is_spec_wise:
|
| 186 |
+
x = {k:self.to_spec_wise(v) for k,v in x.items()}
|
| 187 |
+
x = self.rearrange(x) # -> {'melspec_2048': Tensor{Size([3, 750, 512]) cuda:0 f32}}
|
| 188 |
+
return x['melspec_2048'].permute((0, 2, 1))
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class CQTPreprocessor(nn.Module):
|
| 192 |
+
def __init__(self,
|
| 193 |
+
sr=24000,
|
| 194 |
+
hop=960,
|
| 195 |
+
nb=84,
|
| 196 |
+
to_db = True,
|
| 197 |
+
) -> None:
|
| 198 |
+
super().__init__()
|
| 199 |
+
|
| 200 |
+
from nnAudio.features.cqt import CQT
|
| 201 |
+
import torchaudio
|
| 202 |
+
self.cqt_fn = CQT(
|
| 203 |
+
sr=sr,
|
| 204 |
+
hop_length=hop,
|
| 205 |
+
n_bins=nb,
|
| 206 |
+
fmin=32.7 if nb == 84 else 27.5, # 84 or 88
|
| 207 |
+
bins_per_octave=12,
|
| 208 |
+
filter_scale=1,
|
| 209 |
+
norm=1,
|
| 210 |
+
window='hann',
|
| 211 |
+
center=True,
|
| 212 |
+
pad_mode='constant',
|
| 213 |
+
trainable=False,
|
| 214 |
+
output_format='Magnitude',
|
| 215 |
+
verbose=True,
|
| 216 |
+
)
|
| 217 |
+
if to_db:
|
| 218 |
+
self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
|
| 219 |
+
else:
|
| 220 |
+
self.amplitude_to_db = lambda x:x
|
| 221 |
+
|
| 222 |
+
@torch.no_grad()
|
| 223 |
+
def __call__(self, x):
|
| 224 |
+
return self.amplitude_to_db(self.cqt_fn(x))
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
from dataclasses import dataclass
|
| 228 |
+
|
| 229 |
+
@dataclass
|
| 230 |
+
class UserDirModule:
|
| 231 |
+
user_dir: str
|
| 232 |
+
|
| 233 |
+
def load_model(model_dir, checkpoint_dir):
|
| 234 |
+
'''Load Fairseq SSL model'''
|
| 235 |
+
|
| 236 |
+
if model_dir is not None:
|
| 237 |
+
model_path = UserDirModule(model_dir)
|
| 238 |
+
fairseq.utils.import_user_module(model_path)
|
| 239 |
+
|
| 240 |
+
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_dir], strict=False)
|
| 241 |
+
model = model[0]
|
| 242 |
+
|
| 243 |
+
return model
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class PreprocessorWithModel(nn.Module):
|
| 248 |
+
def __init__(self, model_dir, checkpoint_dir, use_layer_idx=9) -> None:
|
| 249 |
+
super().__init__()
|
| 250 |
+
self.model = load_model(model_dir=model_dir, checkpoint_dir=checkpoint_dir)
|
| 251 |
+
self.model.eval()
|
| 252 |
+
self.use_layer_idx = use_layer_idx
|
| 253 |
+
|
| 254 |
+
def forward(self, x):
|
| 255 |
+
with torch.no_grad():
|
| 256 |
+
self.model.eval()
|
| 257 |
+
res = self.model(x, features_only = True)
|
| 258 |
+
layer_results = res['layer_results']
|
| 259 |
+
return layer_results[self.use_layer_idx].permute(0,2,1)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def Music_Mel_Target_Config():
|
| 264 |
+
config = dict(
|
| 265 |
+
train_dataset = dict(
|
| 266 |
+
manifest_path = 'path/to/data/music4all/train.json',
|
| 267 |
+
sample_rate = 24000,
|
| 268 |
+
normalize = False,
|
| 269 |
+
),
|
| 270 |
+
valid_dataset = dict(
|
| 271 |
+
manifest_path = 'path/to/data/music4all/valid.json',
|
| 272 |
+
sample_rate = 24000,
|
| 273 |
+
normalize = False,
|
| 274 |
+
),
|
| 275 |
+
model = dict(
|
| 276 |
+
input_dim = 128*4,
|
| 277 |
+
n_codebooks = 8,
|
| 278 |
+
codebook_size = 1024,
|
| 279 |
+
codebook_dim = 16,
|
| 280 |
+
quantizer_dropout = 0.0,
|
| 281 |
+
),
|
| 282 |
+
train = dict(
|
| 283 |
+
batch_size = 32,
|
| 284 |
+
num_workers = 6,
|
| 285 |
+
valid_interval = 10,
|
| 286 |
+
save_interval = 100,
|
| 287 |
+
max_updates = 500000,
|
| 288 |
+
lr = 1e-4,
|
| 289 |
+
device = 'cuda:0',
|
| 290 |
+
loss = 'commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()',
|
| 291 |
+
preprocess = Preprocessor()
|
| 292 |
+
)
|
| 293 |
+
)
|
| 294 |
+
return config
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def main(config):
|
| 298 |
+
train_dataset = RVQDataset(**config['train_dataset'])
|
| 299 |
+
if config['valid_dataset']['manifest_path'] is None:
|
| 300 |
+
# split train and valid dataset
|
| 301 |
+
from torch.utils.data import random_split
|
| 302 |
+
train_dataset, valid_dataset = random_split(
|
| 303 |
+
train_dataset, lengths=[len(train_dataset) - 500, 500]
|
| 304 |
+
)
|
| 305 |
+
else:
|
| 306 |
+
valid_dataset = RVQDataset(**config['valid_dataset'])
|
| 307 |
+
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=config['train']['batch_size'], drop_last=True, num_workers=config['train']['num_workers'])
|
| 308 |
+
valid_dataloader = DataLoader(valid_dataset, shuffle=False, batch_size=config['train']['batch_size'], drop_last=True, num_workers=config['train']['num_workers'])
|
| 309 |
+
model = ResidualVectorQuantize(**config['model'])
|
| 310 |
+
|
| 311 |
+
device = config['train']['device']
|
| 312 |
+
preprocess = config['train']['preprocess'].to(device)
|
| 313 |
+
model = model.to(device)
|
| 314 |
+
|
| 315 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=config['train']['lr'])
|
| 316 |
+
cur_updates = 0
|
| 317 |
+
is_running = True
|
| 318 |
+
result = {}
|
| 319 |
+
from tqdm import tqdm
|
| 320 |
+
from tensorboardX import SummaryWriter
|
| 321 |
+
writer = SummaryWriter()
|
| 322 |
+
from collections import defaultdict
|
| 323 |
+
import os
|
| 324 |
+
from logging import getLogger
|
| 325 |
+
logger = getLogger()
|
| 326 |
+
|
| 327 |
+
while is_running:
|
| 328 |
+
results = defaultdict(lambda:0)
|
| 329 |
+
for item in tqdm(train_dataloader, desc='train'):
|
| 330 |
+
wavs = item['source']
|
| 331 |
+
optimizer.zero_grad()
|
| 332 |
+
wavs = wavs.to(device)
|
| 333 |
+
x = preprocess(wavs)
|
| 334 |
+
model.train()
|
| 335 |
+
quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = model(x)
|
| 336 |
+
loss = eval(config['train']['loss'])
|
| 337 |
+
loss.backward()
|
| 338 |
+
optimizer.step()
|
| 339 |
+
|
| 340 |
+
results['loss/train'] += loss.item()
|
| 341 |
+
results['commitment_loss/train'] += commitment_loss.item()
|
| 342 |
+
results['codebook_loss/train'] += codebook_loss.item()
|
| 343 |
+
results['rvq_usage/train'] += rvq_usage.float().mean().item()
|
| 344 |
+
|
| 345 |
+
if cur_updates % config['train']['valid_interval'] == 0:
|
| 346 |
+
model.eval()
|
| 347 |
+
with torch.no_grad():
|
| 348 |
+
for item in tqdm(valid_dataloader, desc='valid'):
|
| 349 |
+
wavs = item['source']
|
| 350 |
+
wavs = wavs.to(device)
|
| 351 |
+
x = preprocess(wavs)
|
| 352 |
+
quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = model(x)
|
| 353 |
+
valid_loss = eval(config['train']['loss'])
|
| 354 |
+
|
| 355 |
+
results['loss/valid'] += valid_loss.item()
|
| 356 |
+
results['commitment_loss/valid'] += commitment_loss.item()
|
| 357 |
+
results['codebook_loss/valid'] += codebook_loss.item()
|
| 358 |
+
results['rvq_usage/valid'] += rvq_usage.float().mean().item()
|
| 359 |
+
|
| 360 |
+
results['cur_updates'] = cur_updates
|
| 361 |
+
results['loss/train'] /= config['train']['valid_interval']
|
| 362 |
+
results['commitment_loss/train'] /= config['train']['valid_interval']
|
| 363 |
+
results['codebook_loss/train'] /= config['train']['valid_interval']
|
| 364 |
+
results['rvq_usage/train'] /= config['train']['valid_interval']
|
| 365 |
+
|
| 366 |
+
results['loss/valid'] /= len(valid_dataloader)
|
| 367 |
+
results['commitment_loss/valid'] /= len(valid_dataloader)
|
| 368 |
+
results['codebook_loss/valid'] /= len(valid_dataloader)
|
| 369 |
+
results['rvq_usage/valid'] /= len(valid_dataloader)
|
| 370 |
+
|
| 371 |
+
print('')
|
| 372 |
+
logger.info(str(results))
|
| 373 |
+
for k,v in results.items():
|
| 374 |
+
writer.add_scalar(k, v, cur_updates)
|
| 375 |
+
|
| 376 |
+
results.clear()
|
| 377 |
+
|
| 378 |
+
if cur_updates % config['train']['save_interval'] == 0:
|
| 379 |
+
os.makedirs(f'{writer.logdir}/ckpt/', exist_ok=True)
|
| 380 |
+
logger.info(f'saving checkpoint to {writer.logdir}/ckpt/RVQ_{cur_updates}.pth')
|
| 381 |
+
torch.save(model.state_dict(), f'{writer.logdir}/ckpt/RVQ_{cur_updates}.pth')
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
if cur_updates < config['train']['max_updates']:
|
| 385 |
+
cur_updates += 1
|
| 386 |
+
else:
|
| 387 |
+
is_running = False
|
| 388 |
+
break
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
if __name__ == '__main__':
|
| 393 |
+
config = Music_Mel_Target_Config()
|
| 394 |
+
main(config)
|
MuCodec/muq_dev/muq_fairseq/models/muq/model/w2v2_config.json
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation_dropout": 0.1,
|
| 3 |
+
"adapter_kernel_size": 3,
|
| 4 |
+
"adapter_stride": 2,
|
| 5 |
+
"add_adapter": false,
|
| 6 |
+
"apply_spec_augment": true,
|
| 7 |
+
"architectures": [
|
| 8 |
+
"Wav2Vec2ConformerForCTC"
|
| 9 |
+
],
|
| 10 |
+
"attention_dropout": 0.1,
|
| 11 |
+
"bos_token_id": 1,
|
| 12 |
+
"classifier_proj_size": 256,
|
| 13 |
+
"codevector_dim": 768,
|
| 14 |
+
"conformer_conv_dropout": 0.1,
|
| 15 |
+
"contrastive_logits_temperature": 0.1,
|
| 16 |
+
"conv_bias": true,
|
| 17 |
+
"conv_depthwise_kernel_size": 31,
|
| 18 |
+
"conv_dim": [
|
| 19 |
+
512,
|
| 20 |
+
512,
|
| 21 |
+
512,
|
| 22 |
+
512,
|
| 23 |
+
512,
|
| 24 |
+
512,
|
| 25 |
+
512
|
| 26 |
+
],
|
| 27 |
+
"conv_kernel": [
|
| 28 |
+
10,
|
| 29 |
+
3,
|
| 30 |
+
3,
|
| 31 |
+
3,
|
| 32 |
+
3,
|
| 33 |
+
2,
|
| 34 |
+
2
|
| 35 |
+
],
|
| 36 |
+
"conv_stride": [
|
| 37 |
+
5,
|
| 38 |
+
2,
|
| 39 |
+
2,
|
| 40 |
+
2,
|
| 41 |
+
2,
|
| 42 |
+
2,
|
| 43 |
+
2
|
| 44 |
+
],
|
| 45 |
+
"ctc_loss_reduction": "sum",
|
| 46 |
+
"ctc_zero_infinity": false,
|
| 47 |
+
"diversity_loss_weight": 0.1,
|
| 48 |
+
"do_stable_layer_norm": true,
|
| 49 |
+
"eos_token_id": 2,
|
| 50 |
+
"feat_extract_activation": "gelu",
|
| 51 |
+
"feat_extract_dropout": 0.0,
|
| 52 |
+
"feat_extract_norm": "layer",
|
| 53 |
+
"feat_proj_dropout": 0.1,
|
| 54 |
+
"feat_quantizer_dropout": 0.0,
|
| 55 |
+
"final_dropout": 0.1,
|
| 56 |
+
"gradient_checkpointing": false,
|
| 57 |
+
"hidden_act": "swish",
|
| 58 |
+
"hidden_dropout": 0.1,
|
| 59 |
+
"hidden_dropout_prob": 0.1,
|
| 60 |
+
"hidden_size": 1024,
|
| 61 |
+
"initializer_range": 0.02,
|
| 62 |
+
"intermediate_size": 4096,
|
| 63 |
+
"layer_norm_eps": 1e-05,
|
| 64 |
+
"layerdrop": 0.0,
|
| 65 |
+
"mask_feature_length": 10,
|
| 66 |
+
"mask_feature_min_masks": 0,
|
| 67 |
+
"mask_feature_prob": 0.0,
|
| 68 |
+
"mask_time_length": 10,
|
| 69 |
+
"mask_time_min_masks": 2,
|
| 70 |
+
"mask_time_prob": 0.05,
|
| 71 |
+
"max_source_positions": 5000,
|
| 72 |
+
"model_type": "wav2vec2-conformer",
|
| 73 |
+
"num_adapter_layers": 3,
|
| 74 |
+
"num_attention_heads": 16,
|
| 75 |
+
"num_codevector_groups": 2,
|
| 76 |
+
"num_codevectors_per_group": 320,
|
| 77 |
+
"num_conv_pos_embedding_groups": 16,
|
| 78 |
+
"num_conv_pos_embeddings": 128,
|
| 79 |
+
"num_feat_extract_layers": 7,
|
| 80 |
+
"num_hidden_layers": 24,
|
| 81 |
+
"num_negatives": 100,
|
| 82 |
+
"output_hidden_size": 1024,
|
| 83 |
+
"pad_token_id": 0,
|
| 84 |
+
"position_embeddings_type": "rotary",
|
| 85 |
+
"proj_codevector_dim": 768,
|
| 86 |
+
"rotary_embedding_base": 10000,
|
| 87 |
+
"tdnn_dilation": [
|
| 88 |
+
1,
|
| 89 |
+
2,
|
| 90 |
+
3,
|
| 91 |
+
1,
|
| 92 |
+
1
|
| 93 |
+
],
|
| 94 |
+
"tdnn_dim": [
|
| 95 |
+
512,
|
| 96 |
+
512,
|
| 97 |
+
512,
|
| 98 |
+
512,
|
| 99 |
+
1500
|
| 100 |
+
],
|
| 101 |
+
"tdnn_kernel": [
|
| 102 |
+
5,
|
| 103 |
+
3,
|
| 104 |
+
3,
|
| 105 |
+
1,
|
| 106 |
+
1
|
| 107 |
+
],
|
| 108 |
+
"torch_dtype": "float32",
|
| 109 |
+
"transformers_version": "4.19.0.dev0",
|
| 110 |
+
"use_weighted_layer_sum": false,
|
| 111 |
+
"vocab_size": 32,
|
| 112 |
+
"xvector_output_dim": 512
|
| 113 |
+
}
|
MuCodec/muq_dev/muq_fairseq/models/muq/modules/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
MuCodec/muq_dev/muq_fairseq/models/muq/modules/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (185 Bytes). View file
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/modules/__pycache__/conv.cpython-310.pyc
ADDED
|
Binary file (2.72 kB). View file
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/modules/__pycache__/features.cpython-310.pyc
ADDED
|
Binary file (2.14 kB). View file
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/modules/__pycache__/random_quantizer.cpython-310.pyc
ADDED
|
Binary file (1.98 kB). View file
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/modules/conv.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Res2dModule(nn.Module):
|
| 6 |
+
def __init__(self, idim, odim, stride=(2, 2)):
|
| 7 |
+
super(Res2dModule, self).__init__()
|
| 8 |
+
self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
|
| 9 |
+
self.bn1 = nn.BatchNorm2d(odim)
|
| 10 |
+
self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
|
| 11 |
+
self.bn2 = nn.BatchNorm2d(odim)
|
| 12 |
+
self.relu = nn.ReLU()
|
| 13 |
+
|
| 14 |
+
# residual
|
| 15 |
+
self.diff = False
|
| 16 |
+
if (idim != odim) or (stride[0] > 1):
|
| 17 |
+
self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
|
| 18 |
+
self.bn3 = nn.BatchNorm2d(odim)
|
| 19 |
+
self.diff = True
|
| 20 |
+
|
| 21 |
+
def forward(self, x):
|
| 22 |
+
out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
|
| 23 |
+
if self.diff:
|
| 24 |
+
x = self.bn3(self.conv3(x))
|
| 25 |
+
out = x + out
|
| 26 |
+
out = self.relu(out)
|
| 27 |
+
return out
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Conv2dSubsampling(nn.Module):
|
| 31 |
+
"""Convolutional 2D subsampling (to 1/4 length).
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
idim (int): Input dimension.
|
| 35 |
+
hdim (int): Hidden dimension.
|
| 36 |
+
odim (int): Output dimension.
|
| 37 |
+
strides (list): Sizes of strides.
|
| 38 |
+
n_bands (int): Number of frequency bands.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64):
|
| 42 |
+
"""Construct an Conv2dSubsampling object."""
|
| 43 |
+
super(Conv2dSubsampling, self).__init__()
|
| 44 |
+
|
| 45 |
+
self.conv = nn.Sequential(
|
| 46 |
+
Res2dModule(idim, hdim, (2, strides[0])),
|
| 47 |
+
Res2dModule(hdim, hdim, (2, strides[1])),
|
| 48 |
+
)
|
| 49 |
+
self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
"""Subsample x.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
x (torch.Tensor): Input tensor (#batch, idim, time).
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
| 59 |
+
where time' = time // 4.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
if x.dim() == 3:
|
| 63 |
+
x = x.unsqueeze(1) # (b, c, f, t)
|
| 64 |
+
x = self.conv(x)
|
| 65 |
+
x = rearrange(x, "b c f t -> b t (c f)")
|
| 66 |
+
x = self.linear(x)
|
| 67 |
+
return x
|
| 68 |
+
|
| 69 |
+
if __name__ == '__main__':
|
| 70 |
+
import torch
|
| 71 |
+
conv_dim, encoder_dim = 512, 1024
|
| 72 |
+
conv = Conv2dSubsampling(
|
| 73 |
+
1, conv_dim, encoder_dim, strides=[2, 1], n_bands=128
|
| 74 |
+
)
|
| 75 |
+
inp = torch.randn((1, 128, 3000))
|
| 76 |
+
out = conv(inp)
|
| 77 |
+
print(out.shape)
|
MuCodec/muq_dev/muq_fairseq/models/muq/modules/features.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torchaudio
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MelSTFT(nn.Module):
|
| 7 |
+
def __init__(
|
| 8 |
+
self,
|
| 9 |
+
sample_rate=24000,
|
| 10 |
+
n_fft=2048,
|
| 11 |
+
hop_length=240,
|
| 12 |
+
n_mels=128,
|
| 13 |
+
is_db=False,
|
| 14 |
+
):
|
| 15 |
+
super(MelSTFT, self).__init__()
|
| 16 |
+
|
| 17 |
+
# spectrogram
|
| 18 |
+
self.mel_stft = torchaudio.transforms.MelSpectrogram(
|
| 19 |
+
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# amplitude to decibel
|
| 23 |
+
self.is_db = is_db
|
| 24 |
+
if is_db:
|
| 25 |
+
self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
|
| 26 |
+
|
| 27 |
+
def forward(self, waveform):
|
| 28 |
+
if self.is_db:
|
| 29 |
+
return self.amplitude_to_db(self.mel_stft(waveform))
|
| 30 |
+
else:
|
| 31 |
+
return self.mel_stft(waveform)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class CQTPreprocessor(nn.Module):
|
| 35 |
+
def __init__(self,
|
| 36 |
+
sr=24000,
|
| 37 |
+
hop=960,
|
| 38 |
+
nb=84,
|
| 39 |
+
to_db = True,
|
| 40 |
+
) -> None:
|
| 41 |
+
super().__init__()
|
| 42 |
+
|
| 43 |
+
from nnAudio.features.cqt import CQT
|
| 44 |
+
import torchaudio
|
| 45 |
+
self.cqt_fn = CQT(
|
| 46 |
+
sr=sr,
|
| 47 |
+
hop_length=hop,
|
| 48 |
+
n_bins=nb,
|
| 49 |
+
fmin=32.7 if nb == 84 else 27.5, # 84 or 88
|
| 50 |
+
bins_per_octave=12,
|
| 51 |
+
filter_scale=1,
|
| 52 |
+
norm=1,
|
| 53 |
+
window='hann',
|
| 54 |
+
center=True,
|
| 55 |
+
pad_mode='constant',
|
| 56 |
+
trainable=False,
|
| 57 |
+
output_format='Magnitude',
|
| 58 |
+
verbose=True,
|
| 59 |
+
)
|
| 60 |
+
if to_db:
|
| 61 |
+
self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
|
| 62 |
+
else:
|
| 63 |
+
self.amplitude_to_db = lambda x:x
|
| 64 |
+
|
| 65 |
+
@torch.no_grad()
|
| 66 |
+
def __call__(self, x):
|
| 67 |
+
return self.amplitude_to_db(self.cqt_fn(x))
|
MuCodec/muq_dev/muq_fairseq/models/muq/modules/flash_conformer.py
ADDED
|
@@ -0,0 +1,2114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
""" PyTorch Wav2Vec2-Conformer model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch.utils.checkpoint
|
| 24 |
+
from torch import nn
|
| 25 |
+
from torch.nn import CrossEntropyLoss
|
| 26 |
+
from torch.nn import functional as F
|
| 27 |
+
|
| 28 |
+
from transformers.activations import ACT2FN
|
| 29 |
+
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
| 30 |
+
from transformers.modeling_outputs import (
|
| 31 |
+
BaseModelOutput,
|
| 32 |
+
CausalLMOutput,
|
| 33 |
+
SequenceClassifierOutput,
|
| 34 |
+
TokenClassifierOutput,
|
| 35 |
+
Wav2Vec2BaseModelOutput,
|
| 36 |
+
XVectorOutput,
|
| 37 |
+
)
|
| 38 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 39 |
+
from transformers.utils import (
|
| 40 |
+
ModelOutput,
|
| 41 |
+
add_code_sample_docstrings,
|
| 42 |
+
add_start_docstrings,
|
| 43 |
+
add_start_docstrings_to_model_forward,
|
| 44 |
+
logging,
|
| 45 |
+
replace_return_docstrings,
|
| 46 |
+
)
|
| 47 |
+
from transformers.models.wav2vec2_conformer.configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
logger = logging.get_logger(__name__)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
_HIDDEN_STATES_START_POSITION = 2
|
| 54 |
+
|
| 55 |
+
# General docstring
|
| 56 |
+
_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
|
| 57 |
+
|
| 58 |
+
# Base docstring
|
| 59 |
+
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
|
| 60 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
|
| 61 |
+
|
| 62 |
+
# CTC docstring
|
| 63 |
+
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
|
| 64 |
+
_CTC_EXPECTED_LOSS = 64.21
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
| 68 |
+
"facebook/wav2vec2-conformer-rel-pos-large",
|
| 69 |
+
# See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@dataclass
|
| 74 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
|
| 75 |
+
class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
|
| 76 |
+
"""
|
| 77 |
+
Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
|
| 81 |
+
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
|
| 82 |
+
paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss.
|
| 83 |
+
projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
|
| 84 |
+
Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
|
| 85 |
+
projected quantized states.
|
| 86 |
+
projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
|
| 87 |
+
Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
|
| 88 |
+
target vectors for contrastive loss.
|
| 89 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 90 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 91 |
+
shape `(batch_size, sequence_length, hidden_size)`.
|
| 92 |
+
|
| 93 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 94 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 95 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 96 |
+
sequence_length)`.
|
| 97 |
+
|
| 98 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 99 |
+
heads.
|
| 100 |
+
contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
|
| 101 |
+
The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
|
| 102 |
+
diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
|
| 103 |
+
The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
loss: Optional[torch.FloatTensor] = None
|
| 107 |
+
projected_states: torch.FloatTensor = None
|
| 108 |
+
projected_quantized_states: torch.FloatTensor = None
|
| 109 |
+
codevector_perplexity: torch.FloatTensor = None
|
| 110 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 111 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 112 |
+
contrastive_loss: Optional[torch.FloatTensor] = None
|
| 113 |
+
diversity_loss: Optional[torch.FloatTensor] = None
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
|
| 117 |
+
def _compute_mask_indices(
|
| 118 |
+
shape: Tuple[int, int],
|
| 119 |
+
mask_prob: float,
|
| 120 |
+
mask_length: int,
|
| 121 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 122 |
+
min_masks: int = 0,
|
| 123 |
+
) -> np.ndarray:
|
| 124 |
+
"""
|
| 125 |
+
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
|
| 126 |
+
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
|
| 127 |
+
CPU as part of the preprocessing during training.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
|
| 131 |
+
the first element is the batch size and the second element is the length of the axis to span.
|
| 132 |
+
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
|
| 133 |
+
independently generated mask spans of length `mask_length` is computed by
|
| 134 |
+
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
|
| 135 |
+
actual percentage will be smaller.
|
| 136 |
+
mask_length: size of the mask
|
| 137 |
+
min_masks: minimum number of masked spans
|
| 138 |
+
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
|
| 139 |
+
each batch dimension.
|
| 140 |
+
"""
|
| 141 |
+
batch_size, sequence_length = shape
|
| 142 |
+
|
| 143 |
+
if mask_length < 1:
|
| 144 |
+
raise ValueError("`mask_length` has to be bigger than 0.")
|
| 145 |
+
|
| 146 |
+
if mask_length > sequence_length:
|
| 147 |
+
raise ValueError(
|
| 148 |
+
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
|
| 149 |
+
f" and `sequence_length`: {sequence_length}`"
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# epsilon is used for probabilistic rounding
|
| 153 |
+
epsilon = np.random.rand(1).item()
|
| 154 |
+
|
| 155 |
+
def compute_num_masked_span(input_length):
|
| 156 |
+
"""Given input length, compute how many spans should be masked"""
|
| 157 |
+
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
|
| 158 |
+
num_masked_span = max(num_masked_span, min_masks)
|
| 159 |
+
|
| 160 |
+
# make sure num masked span <= sequence_length
|
| 161 |
+
if num_masked_span * mask_length > sequence_length:
|
| 162 |
+
num_masked_span = sequence_length // mask_length
|
| 163 |
+
|
| 164 |
+
# make sure num_masked span is also <= input_length - (mask_length - 1)
|
| 165 |
+
if input_length - (mask_length - 1) < num_masked_span:
|
| 166 |
+
num_masked_span = max(input_length - (mask_length - 1), 0)
|
| 167 |
+
|
| 168 |
+
return num_masked_span
|
| 169 |
+
|
| 170 |
+
# compute number of masked spans in batch
|
| 171 |
+
input_lengths = (
|
| 172 |
+
attention_mask.sum(-1).detach().tolist()
|
| 173 |
+
if attention_mask is not None
|
| 174 |
+
else [sequence_length for _ in range(batch_size)]
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# SpecAugment mask to fill
|
| 178 |
+
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
|
| 179 |
+
spec_aug_mask_idxs = []
|
| 180 |
+
|
| 181 |
+
max_num_masked_span = compute_num_masked_span(sequence_length)
|
| 182 |
+
|
| 183 |
+
if max_num_masked_span == 0:
|
| 184 |
+
return spec_aug_mask
|
| 185 |
+
|
| 186 |
+
for input_length in input_lengths:
|
| 187 |
+
# compute num of masked spans for this input
|
| 188 |
+
num_masked_span = compute_num_masked_span(input_length)
|
| 189 |
+
|
| 190 |
+
# get random indices to mask
|
| 191 |
+
spec_aug_mask_idx = np.random.choice(
|
| 192 |
+
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# pick first sampled index that will serve as a dummy index to pad vector
|
| 196 |
+
# to ensure same dimension for all batches due to probabilistic rounding
|
| 197 |
+
# Picking first sample just pads those vectors twice.
|
| 198 |
+
if len(spec_aug_mask_idx) == 0:
|
| 199 |
+
# this case can only happen if `input_length` is strictly smaller then
|
| 200 |
+
# `sequence_length` in which case the last token has to be a padding
|
| 201 |
+
# token which we can use as a dummy mask id
|
| 202 |
+
dummy_mask_idx = sequence_length - 1
|
| 203 |
+
else:
|
| 204 |
+
dummy_mask_idx = spec_aug_mask_idx[0]
|
| 205 |
+
|
| 206 |
+
spec_aug_mask_idx = np.concatenate(
|
| 207 |
+
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
|
| 208 |
+
)
|
| 209 |
+
spec_aug_mask_idxs.append(spec_aug_mask_idx)
|
| 210 |
+
|
| 211 |
+
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
|
| 212 |
+
|
| 213 |
+
# expand masked indices to masked spans
|
| 214 |
+
spec_aug_mask_idxs = np.broadcast_to(
|
| 215 |
+
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
|
| 216 |
+
)
|
| 217 |
+
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
|
| 218 |
+
|
| 219 |
+
# add offset to the starting indexes so that indexes now create a span
|
| 220 |
+
offsets = np.arange(mask_length)[None, None, :]
|
| 221 |
+
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
|
| 222 |
+
batch_size, max_num_masked_span * mask_length
|
| 223 |
+
)
|
| 224 |
+
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
|
| 225 |
+
|
| 226 |
+
# ensure that we cannot have indices larger than sequence_length
|
| 227 |
+
if spec_aug_mask_idxs.max() > sequence_length - 1:
|
| 228 |
+
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
|
| 229 |
+
|
| 230 |
+
# scatter indices to mask
|
| 231 |
+
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
| 232 |
+
|
| 233 |
+
return spec_aug_mask
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
|
| 237 |
+
def _sample_negative_indices(
|
| 238 |
+
features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
|
| 239 |
+
):
|
| 240 |
+
"""
|
| 241 |
+
Sample `num_negatives` vectors from feature vectors.
|
| 242 |
+
"""
|
| 243 |
+
batch_size, sequence_length = features_shape
|
| 244 |
+
|
| 245 |
+
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
|
| 246 |
+
sequence_length_range = np.arange(sequence_length)
|
| 247 |
+
|
| 248 |
+
# get `num_negatives` random vector indices from the same utterance
|
| 249 |
+
sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
|
| 250 |
+
|
| 251 |
+
mask_time_indices = (
|
| 252 |
+
mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
for batch_idx in range(batch_size):
|
| 256 |
+
high = mask_time_indices[batch_idx].sum() - 1
|
| 257 |
+
mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
|
| 258 |
+
|
| 259 |
+
feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
|
| 260 |
+
sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
|
| 261 |
+
# avoid sampling the same positive vector, but keep the distribution uniform
|
| 262 |
+
sampled_indices[sampled_indices >= feature_indices] += 1
|
| 263 |
+
|
| 264 |
+
# remap to actual indices
|
| 265 |
+
sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
|
| 266 |
+
|
| 267 |
+
# correct for batch size
|
| 268 |
+
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
|
| 269 |
+
|
| 270 |
+
return sampled_negative_indices
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 274 |
+
class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
|
| 275 |
+
def __init__(self, config, layer_id=0):
|
| 276 |
+
super().__init__()
|
| 277 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 278 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 279 |
+
|
| 280 |
+
self.conv = nn.Conv1d(
|
| 281 |
+
self.in_conv_dim,
|
| 282 |
+
self.out_conv_dim,
|
| 283 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 284 |
+
stride=config.conv_stride[layer_id],
|
| 285 |
+
bias=config.conv_bias,
|
| 286 |
+
)
|
| 287 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 288 |
+
|
| 289 |
+
def forward(self, hidden_states):
|
| 290 |
+
hidden_states = self.conv(hidden_states)
|
| 291 |
+
hidden_states = self.activation(hidden_states)
|
| 292 |
+
return hidden_states
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 296 |
+
class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
|
| 297 |
+
def __init__(self, config, layer_id=0):
|
| 298 |
+
super().__init__()
|
| 299 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 300 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 301 |
+
|
| 302 |
+
self.conv = nn.Conv1d(
|
| 303 |
+
self.in_conv_dim,
|
| 304 |
+
self.out_conv_dim,
|
| 305 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 306 |
+
stride=config.conv_stride[layer_id],
|
| 307 |
+
bias=config.conv_bias,
|
| 308 |
+
)
|
| 309 |
+
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
|
| 310 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 311 |
+
|
| 312 |
+
def forward(self, hidden_states):
|
| 313 |
+
hidden_states = self.conv(hidden_states)
|
| 314 |
+
|
| 315 |
+
hidden_states = hidden_states.transpose(-2, -1)
|
| 316 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 317 |
+
hidden_states = hidden_states.transpose(-2, -1)
|
| 318 |
+
|
| 319 |
+
hidden_states = self.activation(hidden_states)
|
| 320 |
+
return hidden_states
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 324 |
+
class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
|
| 325 |
+
def __init__(self, config, layer_id=0):
|
| 326 |
+
super().__init__()
|
| 327 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 328 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 329 |
+
|
| 330 |
+
self.conv = nn.Conv1d(
|
| 331 |
+
self.in_conv_dim,
|
| 332 |
+
self.out_conv_dim,
|
| 333 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 334 |
+
stride=config.conv_stride[layer_id],
|
| 335 |
+
bias=config.conv_bias,
|
| 336 |
+
)
|
| 337 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 338 |
+
|
| 339 |
+
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
|
| 340 |
+
|
| 341 |
+
def forward(self, hidden_states):
|
| 342 |
+
hidden_states = self.conv(hidden_states)
|
| 343 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 344 |
+
hidden_states = self.activation(hidden_states)
|
| 345 |
+
return hidden_states
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
|
| 349 |
+
class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
|
| 350 |
+
def __init__(self, config):
|
| 351 |
+
super().__init__()
|
| 352 |
+
self.conv = nn.Conv1d(
|
| 353 |
+
config.hidden_size,
|
| 354 |
+
config.hidden_size,
|
| 355 |
+
kernel_size=config.num_conv_pos_embeddings,
|
| 356 |
+
padding=config.num_conv_pos_embeddings // 2,
|
| 357 |
+
groups=config.num_conv_pos_embedding_groups,
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
if is_deepspeed_zero3_enabled():
|
| 361 |
+
import deepspeed
|
| 362 |
+
|
| 363 |
+
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
|
| 364 |
+
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
| 365 |
+
deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
|
| 366 |
+
deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
|
| 367 |
+
else:
|
| 368 |
+
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
| 369 |
+
|
| 370 |
+
self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
|
| 371 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 372 |
+
|
| 373 |
+
def forward(self, hidden_states):
|
| 374 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 375 |
+
|
| 376 |
+
hidden_states = self.conv(hidden_states)
|
| 377 |
+
hidden_states = self.padding(hidden_states)
|
| 378 |
+
hidden_states = self.activation(hidden_states)
|
| 379 |
+
|
| 380 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 381 |
+
return hidden_states
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
|
| 385 |
+
"""Rotary positional embedding
|
| 386 |
+
Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
|
| 387 |
+
"""
|
| 388 |
+
|
| 389 |
+
def __init__(self, config):
|
| 390 |
+
super().__init__()
|
| 391 |
+
dim = config.hidden_size // config.num_attention_heads
|
| 392 |
+
base = config.rotary_embedding_base
|
| 393 |
+
|
| 394 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 395 |
+
self.register_buffer("inv_freq", inv_freq)
|
| 396 |
+
self.cached_sequence_length = None
|
| 397 |
+
self.cached_rotary_positional_embedding = None
|
| 398 |
+
|
| 399 |
+
def forward(self, hidden_states):
|
| 400 |
+
sequence_length = hidden_states.shape[1]
|
| 401 |
+
|
| 402 |
+
if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
|
| 403 |
+
return self.cached_rotary_positional_embedding
|
| 404 |
+
|
| 405 |
+
self.cached_sequence_length = sequence_length
|
| 406 |
+
time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
|
| 407 |
+
freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
|
| 408 |
+
embeddings = torch.cat((freqs, freqs), dim=-1)
|
| 409 |
+
|
| 410 |
+
cos_embeddings = embeddings.cos()[:, None, None, :]
|
| 411 |
+
sin_embeddings = embeddings.sin()[:, None, None, :]
|
| 412 |
+
self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
|
| 413 |
+
return self.cached_rotary_positional_embedding
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
|
| 417 |
+
"""Relative positional encoding module."""
|
| 418 |
+
|
| 419 |
+
def __init__(self, config):
|
| 420 |
+
super().__init__()
|
| 421 |
+
self.max_len = config.max_source_positions
|
| 422 |
+
self.d_model = config.hidden_size
|
| 423 |
+
self.pe = None
|
| 424 |
+
self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
|
| 425 |
+
|
| 426 |
+
def extend_pe(self, x):
|
| 427 |
+
# Reset the positional encodings
|
| 428 |
+
if self.pe is not None:
|
| 429 |
+
# self.pe contains both positive and negative parts
|
| 430 |
+
# the length of self.pe is 2 * input_len - 1
|
| 431 |
+
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
| 432 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
| 433 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
| 434 |
+
return
|
| 435 |
+
# Suppose `i` is the position of query vector and `j` is the
|
| 436 |
+
# position of key vector. We use positive relative positions when keys
|
| 437 |
+
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
| 438 |
+
pe_positive = torch.zeros(x.size(1), self.d_model)
|
| 439 |
+
pe_negative = torch.zeros(x.size(1), self.d_model)
|
| 440 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
| 441 |
+
div_term = torch.exp(
|
| 442 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
|
| 443 |
+
)
|
| 444 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
| 445 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
| 446 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
| 447 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
| 448 |
+
|
| 449 |
+
# Reverse the order of positive indices and concat both positive and
|
| 450 |
+
# negative indices. This is used to support the shifting trick
|
| 451 |
+
# as in https://arxiv.org/abs/1901.02860
|
| 452 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
| 453 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
| 454 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
| 455 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
| 456 |
+
|
| 457 |
+
def forward(self, hidden_states: torch.Tensor):
|
| 458 |
+
self.extend_pe(hidden_states)
|
| 459 |
+
start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
|
| 460 |
+
end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
|
| 461 |
+
relative_position_embeddings = self.pe[:, start_idx:end_idx]
|
| 462 |
+
|
| 463 |
+
return relative_position_embeddings
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 467 |
+
class Wav2Vec2ConformerSamePadLayer(nn.Module):
|
| 468 |
+
def __init__(self, num_conv_pos_embeddings):
|
| 469 |
+
super().__init__()
|
| 470 |
+
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
|
| 471 |
+
|
| 472 |
+
def forward(self, hidden_states):
|
| 473 |
+
if self.num_pad_remove > 0:
|
| 474 |
+
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
|
| 475 |
+
return hidden_states
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
|
| 479 |
+
class Wav2Vec2ConformerFeatureEncoder(nn.Module):
|
| 480 |
+
"""Construct the features from raw audio waveform"""
|
| 481 |
+
|
| 482 |
+
def __init__(self, config):
|
| 483 |
+
super().__init__()
|
| 484 |
+
|
| 485 |
+
if config.feat_extract_norm == "group":
|
| 486 |
+
conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [
|
| 487 |
+
Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1)
|
| 488 |
+
for i in range(config.num_feat_extract_layers - 1)
|
| 489 |
+
]
|
| 490 |
+
elif config.feat_extract_norm == "layer":
|
| 491 |
+
conv_layers = [
|
| 492 |
+
Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
|
| 493 |
+
]
|
| 494 |
+
else:
|
| 495 |
+
raise ValueError(
|
| 496 |
+
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
|
| 497 |
+
)
|
| 498 |
+
self.conv_layers = nn.ModuleList(conv_layers)
|
| 499 |
+
self.gradient_checkpointing = False
|
| 500 |
+
self._requires_grad = True
|
| 501 |
+
|
| 502 |
+
def _freeze_parameters(self):
|
| 503 |
+
for param in self.parameters():
|
| 504 |
+
param.requires_grad = False
|
| 505 |
+
self._requires_grad = False
|
| 506 |
+
|
| 507 |
+
def forward(self, input_values):
|
| 508 |
+
hidden_states = input_values[:, None]
|
| 509 |
+
|
| 510 |
+
# make sure hidden_states require grad for gradient_checkpointing
|
| 511 |
+
if self._requires_grad and self.training:
|
| 512 |
+
hidden_states.requires_grad = True
|
| 513 |
+
|
| 514 |
+
for conv_layer in self.conv_layers:
|
| 515 |
+
if self._requires_grad and self.gradient_checkpointing and self.training:
|
| 516 |
+
|
| 517 |
+
def create_custom_forward(module):
|
| 518 |
+
def custom_forward(*inputs):
|
| 519 |
+
return module(*inputs)
|
| 520 |
+
|
| 521 |
+
return custom_forward
|
| 522 |
+
|
| 523 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 524 |
+
create_custom_forward(conv_layer),
|
| 525 |
+
hidden_states,
|
| 526 |
+
)
|
| 527 |
+
else:
|
| 528 |
+
hidden_states = conv_layer(hidden_states)
|
| 529 |
+
|
| 530 |
+
return hidden_states
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
|
| 534 |
+
class Wav2Vec2ConformerFeatureProjection(nn.Module):
|
| 535 |
+
def __init__(self, config):
|
| 536 |
+
super().__init__()
|
| 537 |
+
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
|
| 538 |
+
self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
|
| 539 |
+
self.dropout = nn.Dropout(config.feat_proj_dropout)
|
| 540 |
+
|
| 541 |
+
def forward(self, hidden_states):
|
| 542 |
+
# non-projected hidden states are needed for quantization
|
| 543 |
+
norm_hidden_states = self.layer_norm(hidden_states)
|
| 544 |
+
hidden_states = self.projection(norm_hidden_states)
|
| 545 |
+
hidden_states = self.dropout(hidden_states)
|
| 546 |
+
return hidden_states, norm_hidden_states
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
|
| 550 |
+
class Wav2Vec2ConformerFeedForward(nn.Module):
|
| 551 |
+
def __init__(self, config):
|
| 552 |
+
super().__init__()
|
| 553 |
+
self.intermediate_dropout = nn.Dropout(config.activation_dropout)
|
| 554 |
+
|
| 555 |
+
self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 556 |
+
if isinstance(config.hidden_act, str):
|
| 557 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 558 |
+
else:
|
| 559 |
+
self.intermediate_act_fn = config.hidden_act
|
| 560 |
+
|
| 561 |
+
self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 562 |
+
self.output_dropout = nn.Dropout(config.hidden_dropout)
|
| 563 |
+
|
| 564 |
+
def forward(self, hidden_states):
|
| 565 |
+
hidden_states = self.intermediate_dense(hidden_states)
|
| 566 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 567 |
+
hidden_states = self.intermediate_dropout(hidden_states)
|
| 568 |
+
|
| 569 |
+
hidden_states = self.output_dense(hidden_states)
|
| 570 |
+
hidden_states = self.output_dropout(hidden_states)
|
| 571 |
+
return hidden_states
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
class Wav2Vec2ConformerConvolutionModule(nn.Module):
|
| 575 |
+
"""Convolution block used in the conformer block"""
|
| 576 |
+
|
| 577 |
+
def __init__(self, config):
|
| 578 |
+
super().__init__()
|
| 579 |
+
if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
|
| 580 |
+
raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
|
| 581 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size)
|
| 582 |
+
self.pointwise_conv1 = torch.nn.Conv1d(
|
| 583 |
+
config.hidden_size,
|
| 584 |
+
2 * config.hidden_size,
|
| 585 |
+
kernel_size=1,
|
| 586 |
+
stride=1,
|
| 587 |
+
padding=0,
|
| 588 |
+
bias=False,
|
| 589 |
+
)
|
| 590 |
+
self.glu = torch.nn.GLU(dim=1)
|
| 591 |
+
self.depthwise_conv = torch.nn.Conv1d(
|
| 592 |
+
config.hidden_size,
|
| 593 |
+
config.hidden_size,
|
| 594 |
+
config.conv_depthwise_kernel_size,
|
| 595 |
+
stride=1,
|
| 596 |
+
padding=(config.conv_depthwise_kernel_size - 1) // 2,
|
| 597 |
+
groups=config.hidden_size,
|
| 598 |
+
bias=False,
|
| 599 |
+
)
|
| 600 |
+
self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
|
| 601 |
+
self.activation = ACT2FN[config.hidden_act]
|
| 602 |
+
self.pointwise_conv2 = torch.nn.Conv1d(
|
| 603 |
+
config.hidden_size,
|
| 604 |
+
config.hidden_size,
|
| 605 |
+
kernel_size=1,
|
| 606 |
+
stride=1,
|
| 607 |
+
padding=0,
|
| 608 |
+
bias=False,
|
| 609 |
+
)
|
| 610 |
+
self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
|
| 611 |
+
|
| 612 |
+
def forward(self, hidden_states):
|
| 613 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 614 |
+
# exchange the temporal dimension and the feature dimension
|
| 615 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 616 |
+
|
| 617 |
+
# GLU mechanism
|
| 618 |
+
# => (batch, 2*channel, dim)
|
| 619 |
+
hidden_states = self.pointwise_conv1(hidden_states)
|
| 620 |
+
# => (batch, channel, dim)
|
| 621 |
+
hidden_states = self.glu(hidden_states)
|
| 622 |
+
|
| 623 |
+
# 1D Depthwise Conv
|
| 624 |
+
hidden_states = self.depthwise_conv(hidden_states)
|
| 625 |
+
hidden_states = self.batch_norm(hidden_states)
|
| 626 |
+
hidden_states = self.activation(hidden_states)
|
| 627 |
+
|
| 628 |
+
hidden_states = self.pointwise_conv2(hidden_states)
|
| 629 |
+
hidden_states = self.dropout(hidden_states)
|
| 630 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 631 |
+
return hidden_states
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
class Wav2Vec2ConformerSelfAttention(nn.Module):
|
| 635 |
+
"""Construct an Wav2Vec2ConformerSelfAttention object.
|
| 636 |
+
Can be enhanced with rotary or relative position embeddings.
|
| 637 |
+
"""
|
| 638 |
+
|
| 639 |
+
def __init__(self, config):
|
| 640 |
+
super().__init__()
|
| 641 |
+
|
| 642 |
+
self.head_size = config.hidden_size // config.num_attention_heads
|
| 643 |
+
self.num_heads = config.num_attention_heads
|
| 644 |
+
self.position_embeddings_type = config.position_embeddings_type
|
| 645 |
+
|
| 646 |
+
self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
|
| 647 |
+
self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
|
| 648 |
+
self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
|
| 649 |
+
self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
|
| 650 |
+
|
| 651 |
+
self.dropout = nn.Dropout(p=config.attention_dropout)
|
| 652 |
+
self.dropout_p = config.attention_dropout
|
| 653 |
+
|
| 654 |
+
self.is_causal = config.is_causal
|
| 655 |
+
|
| 656 |
+
if self.position_embeddings_type == "relative":
|
| 657 |
+
# linear transformation for positional encoding
|
| 658 |
+
self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
| 659 |
+
# these two learnable bias are used in matrix c and matrix d
|
| 660 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 661 |
+
self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
|
| 662 |
+
self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
|
| 663 |
+
|
| 664 |
+
def forward(
|
| 665 |
+
self,
|
| 666 |
+
hidden_states: torch.Tensor,
|
| 667 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 668 |
+
relative_position_embeddings: Optional[torch.Tensor] = None,
|
| 669 |
+
output_attentions: bool = False,
|
| 670 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 671 |
+
# self-attention mechanism
|
| 672 |
+
batch_size, sequence_length, hidden_size = hidden_states.size()
|
| 673 |
+
|
| 674 |
+
# make sure query/key states can be != value states
|
| 675 |
+
query_key_states = hidden_states
|
| 676 |
+
value_states = hidden_states
|
| 677 |
+
|
| 678 |
+
if self.position_embeddings_type == "rotary":
|
| 679 |
+
if relative_position_embeddings is None:
|
| 680 |
+
raise ValueError(
|
| 681 |
+
"`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
|
| 682 |
+
)
|
| 683 |
+
query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
|
| 684 |
+
|
| 685 |
+
# project query_key_states and value_states
|
| 686 |
+
query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
|
| 687 |
+
key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
|
| 688 |
+
value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
|
| 689 |
+
|
| 690 |
+
# => (batch, head, time1, d_k)
|
| 691 |
+
query = query.transpose(1, 2)
|
| 692 |
+
key = key.transpose(1, 2)
|
| 693 |
+
value = value.transpose(1, 2)
|
| 694 |
+
|
| 695 |
+
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
|
| 696 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=self.dropout_p, is_causal=self.is_causal)
|
| 697 |
+
probs = None
|
| 698 |
+
|
| 699 |
+
# # apply attention_mask if necessary
|
| 700 |
+
# if attention_mask is not None:
|
| 701 |
+
# scores = scores + attention_mask
|
| 702 |
+
|
| 703 |
+
# # => (batch, head, time1, time2)
|
| 704 |
+
# probs = torch.softmax(scores, dim=-1)
|
| 705 |
+
# probs = self.dropout(probs)
|
| 706 |
+
|
| 707 |
+
# # => (batch, head, time1, d_k)
|
| 708 |
+
# hidden_states = torch.matmul(probs, value)
|
| 709 |
+
|
| 710 |
+
# => (batch, time1, hidden_size)
|
| 711 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
|
| 712 |
+
hidden_states = self.linear_out(hidden_states)
|
| 713 |
+
|
| 714 |
+
return hidden_states, probs
|
| 715 |
+
|
| 716 |
+
def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
|
| 717 |
+
batch_size, sequence_length, hidden_size = hidden_states.size()
|
| 718 |
+
hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
|
| 719 |
+
|
| 720 |
+
cos = relative_position_embeddings[0, :sequence_length, ...]
|
| 721 |
+
sin = relative_position_embeddings[1, :sequence_length, ...]
|
| 722 |
+
|
| 723 |
+
# rotate hidden_states with rotary embeddings
|
| 724 |
+
hidden_states = hidden_states.transpose(0, 1)
|
| 725 |
+
rotated_states_begin = hidden_states[..., : self.head_size // 2]
|
| 726 |
+
rotated_states_end = hidden_states[..., self.head_size // 2 :]
|
| 727 |
+
rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
|
| 728 |
+
hidden_states = (hidden_states * cos) + (rotated_states * sin)
|
| 729 |
+
hidden_states = hidden_states.transpose(0, 1)
|
| 730 |
+
|
| 731 |
+
hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
|
| 732 |
+
|
| 733 |
+
return hidden_states
|
| 734 |
+
|
| 735 |
+
def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
|
| 736 |
+
# 1. project positional embeddings
|
| 737 |
+
# => (batch, head, 2*time1-1, d_k)
|
| 738 |
+
proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
|
| 739 |
+
proj_relative_position_embeddings = proj_relative_position_embeddings.view(
|
| 740 |
+
relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
|
| 741 |
+
)
|
| 742 |
+
proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
|
| 743 |
+
proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
|
| 744 |
+
|
| 745 |
+
# 2. Add bias to query
|
| 746 |
+
# => (batch, head, time1, d_k)
|
| 747 |
+
query = query.transpose(1, 2)
|
| 748 |
+
q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
|
| 749 |
+
q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
|
| 750 |
+
|
| 751 |
+
# 3. attention score: first compute matrix a and matrix c
|
| 752 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 753 |
+
# => (batch, head, time1, time2)
|
| 754 |
+
scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
|
| 755 |
+
|
| 756 |
+
# 4. then compute matrix b and matrix d
|
| 757 |
+
# => (batch, head, time1, 2*time1-1)
|
| 758 |
+
scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
|
| 759 |
+
|
| 760 |
+
# 5. shift matrix b and matrix d
|
| 761 |
+
zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
|
| 762 |
+
scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
|
| 763 |
+
scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
|
| 764 |
+
scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
|
| 765 |
+
scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
|
| 766 |
+
scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
|
| 767 |
+
|
| 768 |
+
# 6. sum matrices
|
| 769 |
+
# => (batch, head, time1, time2)
|
| 770 |
+
scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
|
| 771 |
+
|
| 772 |
+
return scores
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
class Wav2Vec2ConformerEncoderLayer(nn.Module):
|
| 776 |
+
"""Conformer block based on https://arxiv.org/abs/2005.08100."""
|
| 777 |
+
|
| 778 |
+
def __init__(self, config):
|
| 779 |
+
super().__init__()
|
| 780 |
+
embed_dim = config.hidden_size
|
| 781 |
+
dropout = config.attention_dropout
|
| 782 |
+
|
| 783 |
+
# Feed-forward 1
|
| 784 |
+
self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
|
| 785 |
+
self.ffn1 = Wav2Vec2ConformerFeedForward(config)
|
| 786 |
+
|
| 787 |
+
# Self-Attention
|
| 788 |
+
self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
|
| 789 |
+
self.self_attn_dropout = torch.nn.Dropout(dropout)
|
| 790 |
+
self.self_attn = Wav2Vec2ConformerSelfAttention(config)
|
| 791 |
+
|
| 792 |
+
# Conformer Convolution
|
| 793 |
+
self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
|
| 794 |
+
|
| 795 |
+
# Feed-forward 2
|
| 796 |
+
self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
|
| 797 |
+
self.ffn2 = Wav2Vec2ConformerFeedForward(config)
|
| 798 |
+
self.final_layer_norm = nn.LayerNorm(embed_dim)
|
| 799 |
+
|
| 800 |
+
def forward(
|
| 801 |
+
self,
|
| 802 |
+
hidden_states,
|
| 803 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 804 |
+
relative_position_embeddings: Optional[torch.Tensor] = None,
|
| 805 |
+
output_attentions: bool = False,
|
| 806 |
+
):
|
| 807 |
+
hidden_states = hidden_states
|
| 808 |
+
|
| 809 |
+
# 1. Feed-Forward 1 layer
|
| 810 |
+
residual = hidden_states
|
| 811 |
+
hidden_states = self.ffn1_layer_norm(hidden_states)
|
| 812 |
+
hidden_states = self.ffn1(hidden_states)
|
| 813 |
+
hidden_states = hidden_states * 0.5 + residual
|
| 814 |
+
residual = hidden_states
|
| 815 |
+
|
| 816 |
+
# 2. Self-Attention layer
|
| 817 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 818 |
+
hidden_states, attn_weigts = self.self_attn(
|
| 819 |
+
hidden_states=hidden_states,
|
| 820 |
+
attention_mask=attention_mask,
|
| 821 |
+
relative_position_embeddings=relative_position_embeddings,
|
| 822 |
+
output_attentions=output_attentions,
|
| 823 |
+
)
|
| 824 |
+
hidden_states = self.self_attn_dropout(hidden_states)
|
| 825 |
+
hidden_states = hidden_states + residual
|
| 826 |
+
|
| 827 |
+
# 3. Convolutional Layer
|
| 828 |
+
residual = hidden_states
|
| 829 |
+
hidden_states = self.conv_module(hidden_states)
|
| 830 |
+
hidden_states = residual + hidden_states
|
| 831 |
+
|
| 832 |
+
# 4. Feed-Forward 2 Layer
|
| 833 |
+
residual = hidden_states
|
| 834 |
+
hidden_states = self.ffn2_layer_norm(hidden_states)
|
| 835 |
+
hidden_states = self.ffn2(hidden_states)
|
| 836 |
+
hidden_states = hidden_states * 0.5 + residual
|
| 837 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
| 838 |
+
|
| 839 |
+
return hidden_states, attn_weigts
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
class Wav2Vec2ConformerEncoder(nn.Module):
|
| 843 |
+
def __init__(self, config, is_causal=False):
|
| 844 |
+
super().__init__()
|
| 845 |
+
config.is_causal = is_causal
|
| 846 |
+
self.config = config
|
| 847 |
+
|
| 848 |
+
if config.position_embeddings_type == "relative":
|
| 849 |
+
self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
|
| 850 |
+
elif config.position_embeddings_type == "rotary":
|
| 851 |
+
self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
|
| 852 |
+
else:
|
| 853 |
+
self.embed_positions = None
|
| 854 |
+
|
| 855 |
+
self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
|
| 856 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 857 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
| 858 |
+
self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
| 859 |
+
self.gradient_checkpointing = False
|
| 860 |
+
|
| 861 |
+
def forward(
|
| 862 |
+
self,
|
| 863 |
+
hidden_states,
|
| 864 |
+
attention_mask=None,
|
| 865 |
+
output_attentions=False,
|
| 866 |
+
output_hidden_states=False,
|
| 867 |
+
return_dict=True,
|
| 868 |
+
):
|
| 869 |
+
all_hidden_states = () if output_hidden_states else None
|
| 870 |
+
all_self_attentions = () if output_attentions else None
|
| 871 |
+
|
| 872 |
+
if attention_mask is not None:
|
| 873 |
+
# make sure padded tokens output 0
|
| 874 |
+
hidden_states[~attention_mask] = 0.0
|
| 875 |
+
|
| 876 |
+
# extend attention_mask
|
| 877 |
+
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
|
| 878 |
+
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
|
| 879 |
+
attention_mask = attention_mask.expand(
|
| 880 |
+
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
|
| 881 |
+
)
|
| 882 |
+
|
| 883 |
+
hidden_states = self.dropout(hidden_states)
|
| 884 |
+
|
| 885 |
+
if self.embed_positions is not None:
|
| 886 |
+
relative_position_embeddings = self.embed_positions(hidden_states)
|
| 887 |
+
else:
|
| 888 |
+
relative_position_embeddings = None
|
| 889 |
+
|
| 890 |
+
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
| 891 |
+
|
| 892 |
+
for i, layer in enumerate(self.layers):
|
| 893 |
+
if output_hidden_states:
|
| 894 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 895 |
+
|
| 896 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
| 897 |
+
dropout_probability = np.random.uniform(0, 1)
|
| 898 |
+
|
| 899 |
+
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
| 900 |
+
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
| 901 |
+
# under deepspeed zero3 all gpus must run in sync
|
| 902 |
+
if self.gradient_checkpointing and self.training:
|
| 903 |
+
# create gradient checkpointing function
|
| 904 |
+
def create_custom_forward(module):
|
| 905 |
+
def custom_forward(*inputs):
|
| 906 |
+
return module(*inputs, output_attentions)
|
| 907 |
+
|
| 908 |
+
return custom_forward
|
| 909 |
+
|
| 910 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 911 |
+
create_custom_forward(layer),
|
| 912 |
+
hidden_states,
|
| 913 |
+
attention_mask,
|
| 914 |
+
relative_position_embeddings,
|
| 915 |
+
)
|
| 916 |
+
else:
|
| 917 |
+
layer_outputs = layer(
|
| 918 |
+
hidden_states,
|
| 919 |
+
attention_mask=attention_mask,
|
| 920 |
+
relative_position_embeddings=relative_position_embeddings,
|
| 921 |
+
output_attentions=output_attentions,
|
| 922 |
+
)
|
| 923 |
+
hidden_states = layer_outputs[0]
|
| 924 |
+
|
| 925 |
+
if skip_the_layer:
|
| 926 |
+
layer_outputs = (None, None)
|
| 927 |
+
|
| 928 |
+
if output_attentions:
|
| 929 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 930 |
+
|
| 931 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 932 |
+
if output_hidden_states:
|
| 933 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 934 |
+
|
| 935 |
+
if not return_dict:
|
| 936 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
| 937 |
+
return BaseModelOutput(
|
| 938 |
+
last_hidden_state=hidden_states,
|
| 939 |
+
hidden_states=all_hidden_states,
|
| 940 |
+
attentions=all_self_attentions,
|
| 941 |
+
)
|
| 942 |
+
|
| 943 |
+
|
| 944 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
|
| 945 |
+
class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
|
| 946 |
+
"""
|
| 947 |
+
Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
|
| 948 |
+
GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
|
| 949 |
+
"""
|
| 950 |
+
|
| 951 |
+
def __init__(self, config):
|
| 952 |
+
super().__init__()
|
| 953 |
+
self.num_groups = config.num_codevector_groups
|
| 954 |
+
self.num_vars = config.num_codevectors_per_group
|
| 955 |
+
|
| 956 |
+
if config.codevector_dim % self.num_groups != 0:
|
| 957 |
+
raise ValueError(
|
| 958 |
+
f"`config.codevector_dim {config.codevector_dim} must be divisible "
|
| 959 |
+
f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
|
| 960 |
+
)
|
| 961 |
+
|
| 962 |
+
# storage for codebook variables (codewords)
|
| 963 |
+
self.codevectors = nn.Parameter(
|
| 964 |
+
torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
|
| 965 |
+
)
|
| 966 |
+
self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
|
| 967 |
+
|
| 968 |
+
# can be decayed for training
|
| 969 |
+
self.temperature = 2
|
| 970 |
+
|
| 971 |
+
@staticmethod
|
| 972 |
+
def _compute_perplexity(probs, mask=None):
|
| 973 |
+
if mask is not None:
|
| 974 |
+
mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
|
| 975 |
+
probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
|
| 976 |
+
marginal_probs = probs.sum(dim=0) / mask.sum()
|
| 977 |
+
else:
|
| 978 |
+
marginal_probs = probs.mean(dim=0)
|
| 979 |
+
|
| 980 |
+
perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
|
| 981 |
+
return perplexity
|
| 982 |
+
|
| 983 |
+
def forward(self, hidden_states, mask_time_indices=None):
|
| 984 |
+
batch_size, sequence_length, hidden_size = hidden_states.shape
|
| 985 |
+
|
| 986 |
+
# project to codevector dim
|
| 987 |
+
hidden_states = self.weight_proj(hidden_states)
|
| 988 |
+
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
|
| 989 |
+
|
| 990 |
+
if self.training:
|
| 991 |
+
# sample code vector probs via gumbel in differentiateable way
|
| 992 |
+
codevector_probs = nn.functional.gumbel_softmax(
|
| 993 |
+
hidden_states.float(), tau=self.temperature, hard=True
|
| 994 |
+
).type_as(hidden_states)
|
| 995 |
+
|
| 996 |
+
# compute perplexity
|
| 997 |
+
codevector_soft_dist = torch.softmax(
|
| 998 |
+
hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
|
| 999 |
+
)
|
| 1000 |
+
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
|
| 1001 |
+
else:
|
| 1002 |
+
# take argmax in non-differentiable way
|
| 1003 |
+
# comptute hard codevector distribution (one hot)
|
| 1004 |
+
codevector_idx = hidden_states.argmax(dim=-1)
|
| 1005 |
+
codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
|
| 1006 |
+
-1, codevector_idx.view(-1, 1), 1.0
|
| 1007 |
+
)
|
| 1008 |
+
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
|
| 1009 |
+
|
| 1010 |
+
perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
|
| 1011 |
+
|
| 1012 |
+
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
|
| 1013 |
+
# use probs to retrieve codevectors
|
| 1014 |
+
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
|
| 1015 |
+
codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
|
| 1016 |
+
codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
|
| 1017 |
+
|
| 1018 |
+
return codevectors, perplexity
|
| 1019 |
+
|
| 1020 |
+
|
| 1021 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
|
| 1022 |
+
class Wav2Vec2ConformerAdapter(nn.Module):
|
| 1023 |
+
def __init__(self, config):
|
| 1024 |
+
super().__init__()
|
| 1025 |
+
|
| 1026 |
+
# feature dim might need to be down-projected
|
| 1027 |
+
if config.output_hidden_size != config.hidden_size:
|
| 1028 |
+
self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
|
| 1029 |
+
self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
|
| 1030 |
+
else:
|
| 1031 |
+
self.proj = self.proj_layer_norm = None
|
| 1032 |
+
|
| 1033 |
+
self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers))
|
| 1034 |
+
self.layerdrop = config.layerdrop
|
| 1035 |
+
|
| 1036 |
+
def forward(self, hidden_states):
|
| 1037 |
+
# down project hidden_states if necessary
|
| 1038 |
+
if self.proj is not None and self.proj_layer_norm is not None:
|
| 1039 |
+
hidden_states = self.proj(hidden_states)
|
| 1040 |
+
hidden_states = self.proj_layer_norm(hidden_states)
|
| 1041 |
+
|
| 1042 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 1043 |
+
|
| 1044 |
+
for layer in self.layers:
|
| 1045 |
+
layerdrop_prob = np.random.random()
|
| 1046 |
+
if not self.training or (layerdrop_prob > self.layerdrop):
|
| 1047 |
+
hidden_states = layer(hidden_states)
|
| 1048 |
+
|
| 1049 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 1050 |
+
return hidden_states
|
| 1051 |
+
|
| 1052 |
+
|
| 1053 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 1054 |
+
class Wav2Vec2ConformerAdapterLayer(nn.Module):
|
| 1055 |
+
def __init__(self, config):
|
| 1056 |
+
super().__init__()
|
| 1057 |
+
self.conv = nn.Conv1d(
|
| 1058 |
+
config.output_hidden_size,
|
| 1059 |
+
2 * config.output_hidden_size,
|
| 1060 |
+
config.adapter_kernel_size,
|
| 1061 |
+
stride=config.adapter_stride,
|
| 1062 |
+
padding=1,
|
| 1063 |
+
)
|
| 1064 |
+
|
| 1065 |
+
def forward(self, hidden_states):
|
| 1066 |
+
hidden_states = self.conv(hidden_states)
|
| 1067 |
+
hidden_states = nn.functional.glu(hidden_states, dim=1)
|
| 1068 |
+
|
| 1069 |
+
return hidden_states
|
| 1070 |
+
|
| 1071 |
+
|
| 1072 |
+
class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
|
| 1073 |
+
"""
|
| 1074 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 1075 |
+
models.
|
| 1076 |
+
"""
|
| 1077 |
+
|
| 1078 |
+
config_class = Wav2Vec2ConformerConfig
|
| 1079 |
+
base_model_prefix = "wav2vec2_conformer"
|
| 1080 |
+
main_input_name = "input_values"
|
| 1081 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
| 1082 |
+
supports_gradient_checkpointing = True
|
| 1083 |
+
|
| 1084 |
+
def _init_weights(self, module):
|
| 1085 |
+
"""Initialize the weights"""
|
| 1086 |
+
# Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
|
| 1087 |
+
if isinstance(module, Wav2Vec2ConformerForPreTraining):
|
| 1088 |
+
module.project_hid.reset_parameters()
|
| 1089 |
+
module.project_q.reset_parameters()
|
| 1090 |
+
module.project_hid._is_hf_initialized = True
|
| 1091 |
+
module.project_q._is_hf_initialized = True
|
| 1092 |
+
# gumbel softmax requires special init
|
| 1093 |
+
elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
|
| 1094 |
+
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
|
| 1095 |
+
module.weight_proj.bias.data.zero_()
|
| 1096 |
+
nn.init.uniform_(module.codevectors)
|
| 1097 |
+
elif isinstance(module, Wav2Vec2ConformerSelfAttention):
|
| 1098 |
+
if hasattr(module, "pos_bias_u"):
|
| 1099 |
+
nn.init.xavier_uniform_(module.pos_bias_u)
|
| 1100 |
+
if hasattr(module, "pos_bias_v"):
|
| 1101 |
+
nn.init.xavier_uniform_(module.pos_bias_v)
|
| 1102 |
+
elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
|
| 1103 |
+
nn.init.normal_(
|
| 1104 |
+
module.conv.weight,
|
| 1105 |
+
mean=0,
|
| 1106 |
+
std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
|
| 1107 |
+
)
|
| 1108 |
+
nn.init.constant_(module.conv.bias, 0)
|
| 1109 |
+
elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
|
| 1110 |
+
k = math.sqrt(1 / module.projection.in_features)
|
| 1111 |
+
nn.init.uniform_(module.projection.weight, a=-k, b=k)
|
| 1112 |
+
nn.init.uniform_(module.projection.bias, a=-k, b=k)
|
| 1113 |
+
elif isinstance(module, nn.Linear):
|
| 1114 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 1115 |
+
|
| 1116 |
+
if module.bias is not None:
|
| 1117 |
+
module.bias.data.zero_()
|
| 1118 |
+
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
|
| 1119 |
+
module.bias.data.zero_()
|
| 1120 |
+
module.weight.data.fill_(1.0)
|
| 1121 |
+
elif isinstance(module, nn.Conv1d):
|
| 1122 |
+
nn.init.kaiming_normal_(module.weight)
|
| 1123 |
+
|
| 1124 |
+
if module.bias is not None:
|
| 1125 |
+
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
|
| 1126 |
+
nn.init.uniform_(module.bias, a=-k, b=k)
|
| 1127 |
+
|
| 1128 |
+
def _get_feat_extract_output_lengths(
|
| 1129 |
+
self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
|
| 1130 |
+
):
|
| 1131 |
+
"""
|
| 1132 |
+
Computes the output length of the convolutional layers
|
| 1133 |
+
"""
|
| 1134 |
+
|
| 1135 |
+
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
|
| 1136 |
+
|
| 1137 |
+
def _conv_out_length(input_length, kernel_size, stride):
|
| 1138 |
+
# 1D convolutional layer output length formula taken
|
| 1139 |
+
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
| 1140 |
+
return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
|
| 1141 |
+
|
| 1142 |
+
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
|
| 1143 |
+
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
| 1144 |
+
|
| 1145 |
+
if add_adapter:
|
| 1146 |
+
for _ in range(self.config.num_adapter_layers):
|
| 1147 |
+
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
|
| 1148 |
+
|
| 1149 |
+
return input_lengths
|
| 1150 |
+
|
| 1151 |
+
def _get_feature_vector_attention_mask(
|
| 1152 |
+
self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
|
| 1153 |
+
):
|
| 1154 |
+
# Effectively attention_mask.sum(-1), but not inplace to be able to run
|
| 1155 |
+
# on inference mode.
|
| 1156 |
+
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
|
| 1157 |
+
|
| 1158 |
+
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
|
| 1159 |
+
output_lengths = output_lengths.to(torch.long)
|
| 1160 |
+
|
| 1161 |
+
batch_size = attention_mask.shape[0]
|
| 1162 |
+
|
| 1163 |
+
attention_mask = torch.zeros(
|
| 1164 |
+
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
|
| 1165 |
+
)
|
| 1166 |
+
# these two operations makes sure that all values before the output lengths idxs are attended to
|
| 1167 |
+
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
|
| 1168 |
+
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
| 1169 |
+
return attention_mask
|
| 1170 |
+
|
| 1171 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 1172 |
+
if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
|
| 1173 |
+
module.gradient_checkpointing = value
|
| 1174 |
+
|
| 1175 |
+
|
| 1176 |
+
WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
|
| 1177 |
+
Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
|
| 1178 |
+
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
|
| 1179 |
+
Auli.
|
| 1180 |
+
|
| 1181 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 1182 |
+
library implements for all its model (such as downloading or saving etc.).
|
| 1183 |
+
|
| 1184 |
+
This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
|
| 1185 |
+
regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
|
| 1186 |
+
|
| 1187 |
+
Parameters:
|
| 1188 |
+
config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
|
| 1189 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 1190 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 1191 |
+
"""
|
| 1192 |
+
|
| 1193 |
+
|
| 1194 |
+
WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
|
| 1195 |
+
Args:
|
| 1196 |
+
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
| 1197 |
+
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
|
| 1198 |
+
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
|
| 1199 |
+
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
|
| 1200 |
+
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
|
| 1201 |
+
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1202 |
+
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
|
| 1203 |
+
1]`:
|
| 1204 |
+
|
| 1205 |
+
- 1 for tokens that are **not masked**,
|
| 1206 |
+
- 0 for tokens that are **masked**.
|
| 1207 |
+
|
| 1208 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 1209 |
+
|
| 1210 |
+
<Tip warning={true}>
|
| 1211 |
+
|
| 1212 |
+
`attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
|
| 1213 |
+
True`. For all models whose processor has `config.return_attention_mask == False`, such as
|
| 1214 |
+
[wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
|
| 1215 |
+
`attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
|
| 1216 |
+
such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
|
| 1217 |
+
that these models also yield slightly different results depending on whether `input_values` is padded or
|
| 1218 |
+
not.
|
| 1219 |
+
|
| 1220 |
+
</Tip>
|
| 1221 |
+
|
| 1222 |
+
output_attentions (`bool`, *optional*):
|
| 1223 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1224 |
+
tensors for more detail.
|
| 1225 |
+
output_hidden_states (`bool`, *optional*):
|
| 1226 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1227 |
+
more detail.
|
| 1228 |
+
return_dict (`bool`, *optional*):
|
| 1229 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1230 |
+
"""
|
| 1231 |
+
|
| 1232 |
+
|
| 1233 |
+
@add_start_docstrings(
|
| 1234 |
+
"The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
|
| 1235 |
+
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
| 1236 |
+
)
|
| 1237 |
+
class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
|
| 1238 |
+
def __init__(self, config: Wav2Vec2ConformerConfig):
|
| 1239 |
+
super().__init__(config)
|
| 1240 |
+
self.config = config
|
| 1241 |
+
self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
|
| 1242 |
+
self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
|
| 1243 |
+
|
| 1244 |
+
# model only needs masking vector if mask prob is > 0.0
|
| 1245 |
+
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
|
| 1246 |
+
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
|
| 1247 |
+
|
| 1248 |
+
self.encoder = Wav2Vec2ConformerEncoder(config)
|
| 1249 |
+
|
| 1250 |
+
self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
|
| 1251 |
+
|
| 1252 |
+
# Initialize weights and apply final processing
|
| 1253 |
+
self.post_init()
|
| 1254 |
+
|
| 1255 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
|
| 1256 |
+
def freeze_feature_encoder(self):
|
| 1257 |
+
"""
|
| 1258 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1259 |
+
not be updated during training.
|
| 1260 |
+
"""
|
| 1261 |
+
self.feature_extractor._freeze_parameters()
|
| 1262 |
+
|
| 1263 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
|
| 1264 |
+
def _mask_hidden_states(
|
| 1265 |
+
self,
|
| 1266 |
+
hidden_states: torch.FloatTensor,
|
| 1267 |
+
mask_time_indices: Optional[torch.FloatTensor] = None,
|
| 1268 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 1269 |
+
):
|
| 1270 |
+
"""
|
| 1271 |
+
Masks extracted features along time axis and/or along feature axis according to
|
| 1272 |
+
[SpecAugment](https://arxiv.org/abs/1904.08779).
|
| 1273 |
+
"""
|
| 1274 |
+
|
| 1275 |
+
# `config.apply_spec_augment` can set masking to False
|
| 1276 |
+
if not getattr(self.config, "apply_spec_augment", True):
|
| 1277 |
+
return hidden_states
|
| 1278 |
+
|
| 1279 |
+
# generate indices & apply SpecAugment along time axis
|
| 1280 |
+
batch_size, sequence_length, hidden_size = hidden_states.size()
|
| 1281 |
+
|
| 1282 |
+
if mask_time_indices is not None:
|
| 1283 |
+
# apply SpecAugment along time axis with given mask_time_indices
|
| 1284 |
+
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
| 1285 |
+
elif self.config.mask_time_prob > 0 and self.training:
|
| 1286 |
+
mask_time_indices = _compute_mask_indices(
|
| 1287 |
+
(batch_size, sequence_length),
|
| 1288 |
+
mask_prob=self.config.mask_time_prob,
|
| 1289 |
+
mask_length=self.config.mask_time_length,
|
| 1290 |
+
attention_mask=attention_mask,
|
| 1291 |
+
min_masks=self.config.mask_time_min_masks,
|
| 1292 |
+
)
|
| 1293 |
+
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
|
| 1294 |
+
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
| 1295 |
+
|
| 1296 |
+
if self.config.mask_feature_prob > 0 and self.training:
|
| 1297 |
+
# generate indices & apply SpecAugment along feature axis
|
| 1298 |
+
mask_feature_indices = _compute_mask_indices(
|
| 1299 |
+
(batch_size, hidden_size),
|
| 1300 |
+
mask_prob=self.config.mask_feature_prob,
|
| 1301 |
+
mask_length=self.config.mask_feature_length,
|
| 1302 |
+
min_masks=self.config.mask_feature_min_masks,
|
| 1303 |
+
)
|
| 1304 |
+
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
|
| 1305 |
+
mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
|
| 1306 |
+
hidden_states[mask_feature_indices] = 0
|
| 1307 |
+
|
| 1308 |
+
return hidden_states
|
| 1309 |
+
|
| 1310 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 1311 |
+
@add_code_sample_docstrings(
|
| 1312 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1313 |
+
output_type=Wav2Vec2BaseModelOutput,
|
| 1314 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1315 |
+
modality="audio",
|
| 1316 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
| 1317 |
+
)
|
| 1318 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
|
| 1319 |
+
def forward(
|
| 1320 |
+
self,
|
| 1321 |
+
input_values: Optional[torch.Tensor],
|
| 1322 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1323 |
+
mask_time_indices: Optional[torch.FloatTensor] = None,
|
| 1324 |
+
output_attentions: Optional[bool] = None,
|
| 1325 |
+
output_hidden_states: Optional[bool] = None,
|
| 1326 |
+
return_dict: Optional[bool] = None,
|
| 1327 |
+
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
|
| 1328 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1329 |
+
output_hidden_states = (
|
| 1330 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1331 |
+
)
|
| 1332 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1333 |
+
|
| 1334 |
+
extract_features = self.feature_extractor(input_values)
|
| 1335 |
+
extract_features = extract_features.transpose(1, 2)
|
| 1336 |
+
|
| 1337 |
+
if attention_mask is not None:
|
| 1338 |
+
# compute reduced attention_mask corresponding to feature vectors
|
| 1339 |
+
attention_mask = self._get_feature_vector_attention_mask(
|
| 1340 |
+
extract_features.shape[1], attention_mask, add_adapter=False
|
| 1341 |
+
)
|
| 1342 |
+
|
| 1343 |
+
hidden_states, extract_features = self.feature_projection(extract_features)
|
| 1344 |
+
hidden_states = self._mask_hidden_states(
|
| 1345 |
+
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
|
| 1346 |
+
)
|
| 1347 |
+
|
| 1348 |
+
encoder_outputs = self.encoder(
|
| 1349 |
+
hidden_states,
|
| 1350 |
+
attention_mask=attention_mask,
|
| 1351 |
+
output_attentions=output_attentions,
|
| 1352 |
+
output_hidden_states=output_hidden_states,
|
| 1353 |
+
return_dict=return_dict,
|
| 1354 |
+
)
|
| 1355 |
+
|
| 1356 |
+
hidden_states = encoder_outputs[0]
|
| 1357 |
+
|
| 1358 |
+
if self.adapter is not None:
|
| 1359 |
+
hidden_states = self.adapter(hidden_states)
|
| 1360 |
+
|
| 1361 |
+
if not return_dict:
|
| 1362 |
+
return (hidden_states, extract_features) + encoder_outputs[1:]
|
| 1363 |
+
|
| 1364 |
+
return Wav2Vec2BaseModelOutput(
|
| 1365 |
+
last_hidden_state=hidden_states,
|
| 1366 |
+
extract_features=extract_features,
|
| 1367 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 1368 |
+
attentions=encoder_outputs.attentions,
|
| 1369 |
+
)
|
| 1370 |
+
|
| 1371 |
+
|
| 1372 |
+
@add_start_docstrings(
|
| 1373 |
+
"""Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
|
| 1374 |
+
)
|
| 1375 |
+
class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
|
| 1376 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
| 1377 |
+
def __init__(self, config: Wav2Vec2ConformerConfig):
|
| 1378 |
+
super().__init__(config)
|
| 1379 |
+
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
| 1380 |
+
self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
|
| 1381 |
+
|
| 1382 |
+
self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
|
| 1383 |
+
|
| 1384 |
+
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
|
| 1385 |
+
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
|
| 1386 |
+
|
| 1387 |
+
# Initialize weights and apply final processing
|
| 1388 |
+
self.post_init()
|
| 1389 |
+
|
| 1390 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
|
| 1391 |
+
def set_gumbel_temperature(self, temperature: int):
|
| 1392 |
+
"""
|
| 1393 |
+
Set the Gumbel softmax temperature to a given value. Only necessary for training
|
| 1394 |
+
"""
|
| 1395 |
+
self.quantizer.temperature = temperature
|
| 1396 |
+
|
| 1397 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
| 1398 |
+
def freeze_feature_encoder(self):
|
| 1399 |
+
"""
|
| 1400 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1401 |
+
not be updated during training.
|
| 1402 |
+
"""
|
| 1403 |
+
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
| 1404 |
+
|
| 1405 |
+
@staticmethod
|
| 1406 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
|
| 1407 |
+
def compute_contrastive_logits(
|
| 1408 |
+
target_features: torch.FloatTensor,
|
| 1409 |
+
negative_features: torch.FloatTensor,
|
| 1410 |
+
predicted_features: torch.FloatTensor,
|
| 1411 |
+
temperature: int = 0.1,
|
| 1412 |
+
):
|
| 1413 |
+
"""
|
| 1414 |
+
Compute logits for contrastive loss based using cosine similarity as the distance measure between
|
| 1415 |
+
`[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
|
| 1416 |
+
"""
|
| 1417 |
+
target_features = torch.cat([target_features, negative_features], dim=0)
|
| 1418 |
+
|
| 1419 |
+
logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
|
| 1420 |
+
target_features
|
| 1421 |
+
)
|
| 1422 |
+
|
| 1423 |
+
# apply temperature
|
| 1424 |
+
logits = logits / temperature
|
| 1425 |
+
return logits
|
| 1426 |
+
|
| 1427 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 1428 |
+
@replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
| 1429 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
|
| 1430 |
+
def forward(
|
| 1431 |
+
self,
|
| 1432 |
+
input_values: Optional[torch.Tensor],
|
| 1433 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1434 |
+
mask_time_indices: Optional[torch.BoolTensor] = None,
|
| 1435 |
+
sampled_negative_indices: Optional[torch.BoolTensor] = None,
|
| 1436 |
+
output_attentions: Optional[bool] = None,
|
| 1437 |
+
output_hidden_states: Optional[bool] = None,
|
| 1438 |
+
return_dict: Optional[bool] = None,
|
| 1439 |
+
) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
|
| 1440 |
+
r"""
|
| 1441 |
+
mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1442 |
+
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
|
| 1443 |
+
masked extracted features in *config.proj_codevector_dim* space.
|
| 1444 |
+
sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
|
| 1445 |
+
Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
|
| 1446 |
+
Required input for pre-training.
|
| 1447 |
+
|
| 1448 |
+
Returns:
|
| 1449 |
+
|
| 1450 |
+
Example:
|
| 1451 |
+
|
| 1452 |
+
```python
|
| 1453 |
+
>>> import torch
|
| 1454 |
+
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
|
| 1455 |
+
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
|
| 1456 |
+
... _compute_mask_indices,
|
| 1457 |
+
... _sample_negative_indices,
|
| 1458 |
+
... )
|
| 1459 |
+
>>> from datasets import load_dataset
|
| 1460 |
+
|
| 1461 |
+
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
|
| 1462 |
+
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
|
| 1463 |
+
|
| 1464 |
+
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
| 1465 |
+
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
|
| 1466 |
+
|
| 1467 |
+
>>> # compute masked indices
|
| 1468 |
+
>>> batch_size, raw_sequence_length = input_values.shape
|
| 1469 |
+
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
|
| 1470 |
+
>>> mask_time_indices = _compute_mask_indices(
|
| 1471 |
+
... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
|
| 1472 |
+
... )
|
| 1473 |
+
>>> sampled_negative_indices = _sample_negative_indices(
|
| 1474 |
+
... features_shape=(batch_size, sequence_length),
|
| 1475 |
+
... num_negatives=model.config.num_negatives,
|
| 1476 |
+
... mask_time_indices=mask_time_indices,
|
| 1477 |
+
... )
|
| 1478 |
+
>>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
|
| 1479 |
+
>>> sampled_negative_indices = torch.tensor(
|
| 1480 |
+
... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
|
| 1481 |
+
... )
|
| 1482 |
+
|
| 1483 |
+
>>> with torch.no_grad():
|
| 1484 |
+
... outputs = model(input_values, mask_time_indices=mask_time_indices)
|
| 1485 |
+
|
| 1486 |
+
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
|
| 1487 |
+
>>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
|
| 1488 |
+
|
| 1489 |
+
>>> # show that cosine similarity is much higher than random
|
| 1490 |
+
>>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
|
| 1491 |
+
tensor(True)
|
| 1492 |
+
|
| 1493 |
+
>>> # for contrastive loss training model should be put into train mode
|
| 1494 |
+
>>> model = model.train()
|
| 1495 |
+
>>> loss = model(
|
| 1496 |
+
... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
|
| 1497 |
+
... ).loss
|
| 1498 |
+
```"""
|
| 1499 |
+
|
| 1500 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1501 |
+
|
| 1502 |
+
if mask_time_indices is not None:
|
| 1503 |
+
mask_time_indices = mask_time_indices.to(torch.bool)
|
| 1504 |
+
|
| 1505 |
+
outputs = self.wav2vec2_conformer(
|
| 1506 |
+
input_values,
|
| 1507 |
+
attention_mask=attention_mask,
|
| 1508 |
+
output_attentions=output_attentions,
|
| 1509 |
+
output_hidden_states=output_hidden_states,
|
| 1510 |
+
mask_time_indices=mask_time_indices,
|
| 1511 |
+
return_dict=return_dict,
|
| 1512 |
+
)
|
| 1513 |
+
|
| 1514 |
+
# 1. project all transformed features (including masked) to final vq dim
|
| 1515 |
+
transformer_features = self.project_hid(outputs[0])
|
| 1516 |
+
|
| 1517 |
+
# 2. quantize all (unmasked) extracted features and project to final vq dim
|
| 1518 |
+
extract_features = self.dropout_features(outputs[1])
|
| 1519 |
+
|
| 1520 |
+
if attention_mask is not None:
|
| 1521 |
+
# compute reduced attention_mask correponding to feature vectors
|
| 1522 |
+
attention_mask = self._get_feature_vector_attention_mask(
|
| 1523 |
+
extract_features.shape[1], attention_mask, add_adapter=False
|
| 1524 |
+
)
|
| 1525 |
+
|
| 1526 |
+
quantized_features, codevector_perplexity = self.quantizer(
|
| 1527 |
+
extract_features, mask_time_indices=mask_time_indices
|
| 1528 |
+
)
|
| 1529 |
+
quantized_features = self.project_q(quantized_features)
|
| 1530 |
+
|
| 1531 |
+
loss = contrastive_loss = diversity_loss = None
|
| 1532 |
+
if sampled_negative_indices is not None:
|
| 1533 |
+
batch_size, sequence_length, hidden_size = quantized_features.shape
|
| 1534 |
+
|
| 1535 |
+
# for training, we sample negatives
|
| 1536 |
+
# 3. sample K negatives (distractors) quantized states for contrastive loss
|
| 1537 |
+
# if attention_mask is passed, make sure that padded feature vectors cannot be sampled
|
| 1538 |
+
# sample negative quantized vectors BTC => (BxT)C
|
| 1539 |
+
negative_quantized_features = quantized_features.view(-1, hidden_size)[
|
| 1540 |
+
sampled_negative_indices.long().view(-1)
|
| 1541 |
+
]
|
| 1542 |
+
negative_quantized_features = negative_quantized_features.view(
|
| 1543 |
+
batch_size, sequence_length, -1, hidden_size
|
| 1544 |
+
).permute(2, 0, 1, 3)
|
| 1545 |
+
|
| 1546 |
+
# 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
|
| 1547 |
+
# of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
|
| 1548 |
+
logits = self.compute_contrastive_logits(
|
| 1549 |
+
quantized_features[None, :],
|
| 1550 |
+
negative_quantized_features,
|
| 1551 |
+
transformer_features,
|
| 1552 |
+
self.config.contrastive_logits_temperature,
|
| 1553 |
+
)
|
| 1554 |
+
|
| 1555 |
+
# 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
|
| 1556 |
+
# its cosine similarity will be masked
|
| 1557 |
+
neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
|
| 1558 |
+
|
| 1559 |
+
if neg_is_pos.any():
|
| 1560 |
+
logits[1:][neg_is_pos] = float("-inf")
|
| 1561 |
+
|
| 1562 |
+
# 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
|
| 1563 |
+
# -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
|
| 1564 |
+
logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
|
| 1565 |
+
target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
|
| 1566 |
+
|
| 1567 |
+
contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
|
| 1568 |
+
# 7. compute diversity loss: \mathbf{L}_d
|
| 1569 |
+
num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
|
| 1570 |
+
diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
|
| 1571 |
+
|
| 1572 |
+
# 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
|
| 1573 |
+
loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
|
| 1574 |
+
|
| 1575 |
+
if not return_dict:
|
| 1576 |
+
if loss is not None:
|
| 1577 |
+
return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
|
| 1578 |
+
return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
|
| 1579 |
+
|
| 1580 |
+
return Wav2Vec2ConformerForPreTrainingOutput(
|
| 1581 |
+
loss=loss,
|
| 1582 |
+
projected_states=transformer_features,
|
| 1583 |
+
projected_quantized_states=quantized_features,
|
| 1584 |
+
codevector_perplexity=codevector_perplexity,
|
| 1585 |
+
hidden_states=outputs.hidden_states,
|
| 1586 |
+
attentions=outputs.attentions,
|
| 1587 |
+
contrastive_loss=contrastive_loss,
|
| 1588 |
+
diversity_loss=diversity_loss,
|
| 1589 |
+
)
|
| 1590 |
+
|
| 1591 |
+
|
| 1592 |
+
@add_start_docstrings(
|
| 1593 |
+
"""Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
|
| 1594 |
+
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
| 1595 |
+
)
|
| 1596 |
+
class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
|
| 1597 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
| 1598 |
+
def __init__(self, config):
|
| 1599 |
+
super().__init__(config)
|
| 1600 |
+
|
| 1601 |
+
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
| 1602 |
+
self.dropout = nn.Dropout(config.final_dropout)
|
| 1603 |
+
|
| 1604 |
+
if config.vocab_size is None:
|
| 1605 |
+
raise ValueError(
|
| 1606 |
+
f"You are trying to instantiate {self.__class__} with a configuration that "
|
| 1607 |
+
"does not define the vocabulary size of the language model head. Please "
|
| 1608 |
+
"instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
|
| 1609 |
+
"or define `vocab_size` of your model's configuration."
|
| 1610 |
+
)
|
| 1611 |
+
output_hidden_size = (
|
| 1612 |
+
config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
|
| 1613 |
+
)
|
| 1614 |
+
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
| 1615 |
+
|
| 1616 |
+
# Initialize weights and apply final processing
|
| 1617 |
+
self.post_init()
|
| 1618 |
+
|
| 1619 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
| 1620 |
+
def freeze_feature_encoder(self):
|
| 1621 |
+
"""
|
| 1622 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1623 |
+
not be updated during training.
|
| 1624 |
+
"""
|
| 1625 |
+
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
| 1626 |
+
|
| 1627 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 1628 |
+
@add_code_sample_docstrings(
|
| 1629 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1630 |
+
output_type=CausalLMOutput,
|
| 1631 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1632 |
+
expected_output=_CTC_EXPECTED_OUTPUT,
|
| 1633 |
+
expected_loss=_CTC_EXPECTED_LOSS,
|
| 1634 |
+
)
|
| 1635 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
| 1636 |
+
def forward(
|
| 1637 |
+
self,
|
| 1638 |
+
input_values: Optional[torch.Tensor],
|
| 1639 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1640 |
+
output_attentions: Optional[bool] = None,
|
| 1641 |
+
output_hidden_states: Optional[bool] = None,
|
| 1642 |
+
return_dict: Optional[bool] = None,
|
| 1643 |
+
labels: Optional[torch.Tensor] = None,
|
| 1644 |
+
) -> Union[Tuple, CausalLMOutput]:
|
| 1645 |
+
r"""
|
| 1646 |
+
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
|
| 1647 |
+
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
|
| 1648 |
+
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
|
| 1649 |
+
All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
| 1650 |
+
config.vocab_size - 1]`.
|
| 1651 |
+
"""
|
| 1652 |
+
|
| 1653 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1654 |
+
|
| 1655 |
+
outputs = self.wav2vec2_conformer(
|
| 1656 |
+
input_values,
|
| 1657 |
+
attention_mask=attention_mask,
|
| 1658 |
+
output_attentions=output_attentions,
|
| 1659 |
+
output_hidden_states=output_hidden_states,
|
| 1660 |
+
return_dict=return_dict,
|
| 1661 |
+
)
|
| 1662 |
+
|
| 1663 |
+
hidden_states = outputs[0]
|
| 1664 |
+
hidden_states = self.dropout(hidden_states)
|
| 1665 |
+
|
| 1666 |
+
logits = self.lm_head(hidden_states)
|
| 1667 |
+
|
| 1668 |
+
loss = None
|
| 1669 |
+
if labels is not None:
|
| 1670 |
+
if labels.max() >= self.config.vocab_size:
|
| 1671 |
+
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
| 1672 |
+
|
| 1673 |
+
# retrieve loss input_lengths from attention_mask
|
| 1674 |
+
attention_mask = (
|
| 1675 |
+
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
| 1676 |
+
)
|
| 1677 |
+
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
| 1678 |
+
|
| 1679 |
+
# assuming that padded tokens are filled with -100
|
| 1680 |
+
# when not being attended to
|
| 1681 |
+
labels_mask = labels >= 0
|
| 1682 |
+
target_lengths = labels_mask.sum(-1)
|
| 1683 |
+
flattened_targets = labels.masked_select(labels_mask)
|
| 1684 |
+
|
| 1685 |
+
# ctc_loss doesn't support fp16
|
| 1686 |
+
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
|
| 1687 |
+
|
| 1688 |
+
with torch.backends.cudnn.flags(enabled=False):
|
| 1689 |
+
loss = nn.functional.ctc_loss(
|
| 1690 |
+
log_probs,
|
| 1691 |
+
flattened_targets,
|
| 1692 |
+
input_lengths,
|
| 1693 |
+
target_lengths,
|
| 1694 |
+
blank=self.config.pad_token_id,
|
| 1695 |
+
reduction=self.config.ctc_loss_reduction,
|
| 1696 |
+
zero_infinity=self.config.ctc_zero_infinity,
|
| 1697 |
+
)
|
| 1698 |
+
|
| 1699 |
+
if not return_dict:
|
| 1700 |
+
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1701 |
+
return ((loss,) + output) if loss is not None else output
|
| 1702 |
+
|
| 1703 |
+
return CausalLMOutput(
|
| 1704 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
| 1705 |
+
)
|
| 1706 |
+
|
| 1707 |
+
|
| 1708 |
+
@add_start_docstrings(
|
| 1709 |
+
"""
|
| 1710 |
+
Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
|
| 1711 |
+
tasks like SUPERB Keyword Spotting.
|
| 1712 |
+
""",
|
| 1713 |
+
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
| 1714 |
+
)
|
| 1715 |
+
class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
|
| 1716 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
| 1717 |
+
def __init__(self, config):
|
| 1718 |
+
super().__init__(config)
|
| 1719 |
+
|
| 1720 |
+
if hasattr(config, "add_adapter") and config.add_adapter:
|
| 1721 |
+
raise ValueError(
|
| 1722 |
+
"Sequence classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
|
| 1723 |
+
)
|
| 1724 |
+
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
| 1725 |
+
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
| 1726 |
+
if config.use_weighted_layer_sum:
|
| 1727 |
+
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
| 1728 |
+
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
| 1729 |
+
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
|
| 1730 |
+
|
| 1731 |
+
# Initialize weights and apply final processing
|
| 1732 |
+
self.post_init()
|
| 1733 |
+
|
| 1734 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
| 1735 |
+
def freeze_feature_encoder(self):
|
| 1736 |
+
"""
|
| 1737 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1738 |
+
not be updated during training.
|
| 1739 |
+
"""
|
| 1740 |
+
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
| 1741 |
+
|
| 1742 |
+
def freeze_base_model(self):
|
| 1743 |
+
"""
|
| 1744 |
+
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 1745 |
+
be updated during training. Only the classification head will be updated.
|
| 1746 |
+
"""
|
| 1747 |
+
for param in self.wav2vec2_conformer.parameters():
|
| 1748 |
+
param.requires_grad = False
|
| 1749 |
+
|
| 1750 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 1751 |
+
@add_code_sample_docstrings(
|
| 1752 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1753 |
+
output_type=SequenceClassifierOutput,
|
| 1754 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1755 |
+
modality="audio",
|
| 1756 |
+
)
|
| 1757 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
|
| 1758 |
+
def forward(
|
| 1759 |
+
self,
|
| 1760 |
+
input_values: Optional[torch.Tensor],
|
| 1761 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1762 |
+
output_attentions: Optional[bool] = None,
|
| 1763 |
+
output_hidden_states: Optional[bool] = None,
|
| 1764 |
+
return_dict: Optional[bool] = None,
|
| 1765 |
+
labels: Optional[torch.Tensor] = None,
|
| 1766 |
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
| 1767 |
+
r"""
|
| 1768 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1769 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1770 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1771 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1772 |
+
"""
|
| 1773 |
+
|
| 1774 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1775 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 1776 |
+
|
| 1777 |
+
outputs = self.wav2vec2_conformer(
|
| 1778 |
+
input_values,
|
| 1779 |
+
attention_mask=attention_mask,
|
| 1780 |
+
output_attentions=output_attentions,
|
| 1781 |
+
output_hidden_states=output_hidden_states,
|
| 1782 |
+
return_dict=return_dict,
|
| 1783 |
+
)
|
| 1784 |
+
|
| 1785 |
+
if self.config.use_weighted_layer_sum:
|
| 1786 |
+
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
| 1787 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
| 1788 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 1789 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 1790 |
+
else:
|
| 1791 |
+
hidden_states = outputs[0]
|
| 1792 |
+
|
| 1793 |
+
hidden_states = self.projector(hidden_states)
|
| 1794 |
+
if attention_mask is None:
|
| 1795 |
+
pooled_output = hidden_states.mean(dim=1)
|
| 1796 |
+
else:
|
| 1797 |
+
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
|
| 1798 |
+
hidden_states[~padding_mask] = 0.0
|
| 1799 |
+
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
| 1800 |
+
|
| 1801 |
+
logits = self.classifier(pooled_output)
|
| 1802 |
+
|
| 1803 |
+
loss = None
|
| 1804 |
+
if labels is not None:
|
| 1805 |
+
loss_fct = CrossEntropyLoss()
|
| 1806 |
+
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
| 1807 |
+
|
| 1808 |
+
if not return_dict:
|
| 1809 |
+
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1810 |
+
return ((loss,) + output) if loss is not None else output
|
| 1811 |
+
|
| 1812 |
+
return SequenceClassifierOutput(
|
| 1813 |
+
loss=loss,
|
| 1814 |
+
logits=logits,
|
| 1815 |
+
hidden_states=outputs.hidden_states,
|
| 1816 |
+
attentions=outputs.attentions,
|
| 1817 |
+
)
|
| 1818 |
+
|
| 1819 |
+
|
| 1820 |
+
@add_start_docstrings(
|
| 1821 |
+
"""
|
| 1822 |
+
Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
|
| 1823 |
+
""",
|
| 1824 |
+
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
| 1825 |
+
)
|
| 1826 |
+
class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
|
| 1827 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
|
| 1828 |
+
def __init__(self, config):
|
| 1829 |
+
super().__init__(config)
|
| 1830 |
+
|
| 1831 |
+
if hasattr(config, "add_adapter") and config.add_adapter:
|
| 1832 |
+
raise ValueError(
|
| 1833 |
+
"Audio frame classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
|
| 1834 |
+
)
|
| 1835 |
+
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
| 1836 |
+
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
| 1837 |
+
if config.use_weighted_layer_sum:
|
| 1838 |
+
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
| 1839 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1840 |
+
self.num_labels = config.num_labels
|
| 1841 |
+
|
| 1842 |
+
self.init_weights()
|
| 1843 |
+
|
| 1844 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
| 1845 |
+
def freeze_feature_encoder(self):
|
| 1846 |
+
"""
|
| 1847 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1848 |
+
not be updated during training.
|
| 1849 |
+
"""
|
| 1850 |
+
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
| 1851 |
+
|
| 1852 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
|
| 1853 |
+
def freeze_base_model(self):
|
| 1854 |
+
"""
|
| 1855 |
+
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 1856 |
+
be updated during training. Only the classification head will be updated.
|
| 1857 |
+
"""
|
| 1858 |
+
for param in self.wav2vec2_conformer.parameters():
|
| 1859 |
+
param.requires_grad = False
|
| 1860 |
+
|
| 1861 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 1862 |
+
@add_code_sample_docstrings(
|
| 1863 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1864 |
+
output_type=TokenClassifierOutput,
|
| 1865 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1866 |
+
modality="audio",
|
| 1867 |
+
)
|
| 1868 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
|
| 1869 |
+
def forward(
|
| 1870 |
+
self,
|
| 1871 |
+
input_values: Optional[torch.Tensor],
|
| 1872 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1873 |
+
labels: Optional[torch.Tensor] = None,
|
| 1874 |
+
output_attentions: Optional[bool] = None,
|
| 1875 |
+
output_hidden_states: Optional[bool] = None,
|
| 1876 |
+
return_dict: Optional[bool] = None,
|
| 1877 |
+
) -> Union[Tuple, TokenClassifierOutput]:
|
| 1878 |
+
r"""
|
| 1879 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1880 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1881 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1882 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1883 |
+
"""
|
| 1884 |
+
|
| 1885 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1886 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 1887 |
+
|
| 1888 |
+
outputs = self.wav2vec2_conformer(
|
| 1889 |
+
input_values,
|
| 1890 |
+
attention_mask=attention_mask,
|
| 1891 |
+
output_attentions=output_attentions,
|
| 1892 |
+
output_hidden_states=output_hidden_states,
|
| 1893 |
+
return_dict=return_dict,
|
| 1894 |
+
)
|
| 1895 |
+
|
| 1896 |
+
if self.config.use_weighted_layer_sum:
|
| 1897 |
+
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
| 1898 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
| 1899 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 1900 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 1901 |
+
else:
|
| 1902 |
+
hidden_states = outputs[0]
|
| 1903 |
+
|
| 1904 |
+
logits = self.classifier(hidden_states)
|
| 1905 |
+
|
| 1906 |
+
loss = None
|
| 1907 |
+
if labels is not None:
|
| 1908 |
+
loss_fct = CrossEntropyLoss()
|
| 1909 |
+
loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
|
| 1910 |
+
|
| 1911 |
+
if not return_dict:
|
| 1912 |
+
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1913 |
+
return output
|
| 1914 |
+
|
| 1915 |
+
return TokenClassifierOutput(
|
| 1916 |
+
loss=loss,
|
| 1917 |
+
logits=logits,
|
| 1918 |
+
hidden_states=outputs.hidden_states,
|
| 1919 |
+
attentions=outputs.attentions,
|
| 1920 |
+
)
|
| 1921 |
+
|
| 1922 |
+
|
| 1923 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
|
| 1924 |
+
class AMSoftmaxLoss(nn.Module):
|
| 1925 |
+
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
|
| 1926 |
+
super(AMSoftmaxLoss, self).__init__()
|
| 1927 |
+
self.scale = scale
|
| 1928 |
+
self.margin = margin
|
| 1929 |
+
self.num_labels = num_labels
|
| 1930 |
+
self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
|
| 1931 |
+
self.loss = nn.CrossEntropyLoss()
|
| 1932 |
+
|
| 1933 |
+
def forward(self, hidden_states, labels):
|
| 1934 |
+
labels = labels.flatten()
|
| 1935 |
+
weight = nn.functional.normalize(self.weight, dim=0)
|
| 1936 |
+
hidden_states = nn.functional.normalize(hidden_states, dim=1)
|
| 1937 |
+
cos_theta = torch.mm(hidden_states, weight)
|
| 1938 |
+
psi = cos_theta - self.margin
|
| 1939 |
+
|
| 1940 |
+
onehot = nn.functional.one_hot(labels, self.num_labels)
|
| 1941 |
+
logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
|
| 1942 |
+
loss = self.loss(logits, labels)
|
| 1943 |
+
|
| 1944 |
+
return loss
|
| 1945 |
+
|
| 1946 |
+
|
| 1947 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
|
| 1948 |
+
class TDNNLayer(nn.Module):
|
| 1949 |
+
def __init__(self, config, layer_id=0):
|
| 1950 |
+
super().__init__()
|
| 1951 |
+
self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
|
| 1952 |
+
self.out_conv_dim = config.tdnn_dim[layer_id]
|
| 1953 |
+
self.kernel_size = config.tdnn_kernel[layer_id]
|
| 1954 |
+
self.dilation = config.tdnn_dilation[layer_id]
|
| 1955 |
+
|
| 1956 |
+
self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
|
| 1957 |
+
self.activation = nn.ReLU()
|
| 1958 |
+
|
| 1959 |
+
def forward(self, hidden_states):
|
| 1960 |
+
hidden_states = hidden_states.unsqueeze(1)
|
| 1961 |
+
hidden_states = nn.functional.unfold(
|
| 1962 |
+
hidden_states,
|
| 1963 |
+
(self.kernel_size, self.in_conv_dim),
|
| 1964 |
+
stride=(1, self.in_conv_dim),
|
| 1965 |
+
dilation=(self.dilation, 1),
|
| 1966 |
+
)
|
| 1967 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 1968 |
+
hidden_states = self.kernel(hidden_states)
|
| 1969 |
+
|
| 1970 |
+
hidden_states = self.activation(hidden_states)
|
| 1971 |
+
return hidden_states
|
| 1972 |
+
|
| 1973 |
+
|
| 1974 |
+
@add_start_docstrings(
|
| 1975 |
+
"""
|
| 1976 |
+
Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
|
| 1977 |
+
""",
|
| 1978 |
+
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
| 1979 |
+
)
|
| 1980 |
+
class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
|
| 1981 |
+
def __init__(self, config):
|
| 1982 |
+
super().__init__(config)
|
| 1983 |
+
|
| 1984 |
+
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
| 1985 |
+
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
| 1986 |
+
if config.use_weighted_layer_sum:
|
| 1987 |
+
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
| 1988 |
+
self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
|
| 1989 |
+
|
| 1990 |
+
tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
|
| 1991 |
+
self.tdnn = nn.ModuleList(tdnn_layers)
|
| 1992 |
+
|
| 1993 |
+
self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
|
| 1994 |
+
self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
|
| 1995 |
+
|
| 1996 |
+
self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
|
| 1997 |
+
|
| 1998 |
+
self.init_weights()
|
| 1999 |
+
|
| 2000 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
| 2001 |
+
def freeze_feature_encoder(self):
|
| 2002 |
+
"""
|
| 2003 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 2004 |
+
not be updated during training.
|
| 2005 |
+
"""
|
| 2006 |
+
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
| 2007 |
+
|
| 2008 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
|
| 2009 |
+
def freeze_base_model(self):
|
| 2010 |
+
"""
|
| 2011 |
+
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 2012 |
+
be updated during training. Only the classification head will be updated.
|
| 2013 |
+
"""
|
| 2014 |
+
for param in self.wav2vec2_conformer.parameters():
|
| 2015 |
+
param.requires_grad = False
|
| 2016 |
+
|
| 2017 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
|
| 2018 |
+
def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
|
| 2019 |
+
"""
|
| 2020 |
+
Computes the output length of the TDNN layers
|
| 2021 |
+
"""
|
| 2022 |
+
|
| 2023 |
+
def _conv_out_length(input_length, kernel_size, stride):
|
| 2024 |
+
# 1D convolutional layer output length formula taken
|
| 2025 |
+
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
| 2026 |
+
return (input_length - kernel_size) // stride + 1
|
| 2027 |
+
|
| 2028 |
+
for kernel_size in self.config.tdnn_kernel:
|
| 2029 |
+
input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
|
| 2030 |
+
|
| 2031 |
+
return input_lengths
|
| 2032 |
+
|
| 2033 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 2034 |
+
@add_code_sample_docstrings(
|
| 2035 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 2036 |
+
output_type=XVectorOutput,
|
| 2037 |
+
config_class=_CONFIG_FOR_DOC,
|
| 2038 |
+
modality="audio",
|
| 2039 |
+
)
|
| 2040 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
|
| 2041 |
+
def forward(
|
| 2042 |
+
self,
|
| 2043 |
+
input_values: Optional[torch.Tensor],
|
| 2044 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 2045 |
+
output_attentions: Optional[bool] = None,
|
| 2046 |
+
output_hidden_states: Optional[bool] = None,
|
| 2047 |
+
return_dict: Optional[bool] = None,
|
| 2048 |
+
labels: Optional[torch.Tensor] = None,
|
| 2049 |
+
) -> Union[Tuple, XVectorOutput]:
|
| 2050 |
+
r"""
|
| 2051 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 2052 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 2053 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 2054 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 2055 |
+
"""
|
| 2056 |
+
|
| 2057 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 2058 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 2059 |
+
|
| 2060 |
+
outputs = self.wav2vec2_conformer(
|
| 2061 |
+
input_values,
|
| 2062 |
+
attention_mask=attention_mask,
|
| 2063 |
+
output_attentions=output_attentions,
|
| 2064 |
+
output_hidden_states=output_hidden_states,
|
| 2065 |
+
return_dict=return_dict,
|
| 2066 |
+
)
|
| 2067 |
+
|
| 2068 |
+
if self.config.use_weighted_layer_sum:
|
| 2069 |
+
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
| 2070 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
| 2071 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 2072 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 2073 |
+
else:
|
| 2074 |
+
hidden_states = outputs[0]
|
| 2075 |
+
|
| 2076 |
+
hidden_states = self.projector(hidden_states)
|
| 2077 |
+
|
| 2078 |
+
for tdnn_layer in self.tdnn:
|
| 2079 |
+
hidden_states = tdnn_layer(hidden_states)
|
| 2080 |
+
|
| 2081 |
+
# Statistic Pooling
|
| 2082 |
+
if attention_mask is None:
|
| 2083 |
+
mean_features = hidden_states.mean(dim=1)
|
| 2084 |
+
std_features = hidden_states.std(dim=1)
|
| 2085 |
+
else:
|
| 2086 |
+
feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
|
| 2087 |
+
tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
|
| 2088 |
+
mean_features = []
|
| 2089 |
+
std_features = []
|
| 2090 |
+
for i, length in enumerate(tdnn_output_lengths):
|
| 2091 |
+
mean_features.append(hidden_states[i, :length].mean(dim=0))
|
| 2092 |
+
std_features.append(hidden_states[i, :length].std(dim=0))
|
| 2093 |
+
mean_features = torch.stack(mean_features)
|
| 2094 |
+
std_features = torch.stack(std_features)
|
| 2095 |
+
statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
|
| 2096 |
+
|
| 2097 |
+
output_embeddings = self.feature_extractor(statistic_pooling)
|
| 2098 |
+
logits = self.classifier(output_embeddings)
|
| 2099 |
+
|
| 2100 |
+
loss = None
|
| 2101 |
+
if labels is not None:
|
| 2102 |
+
loss = self.objective(logits, labels)
|
| 2103 |
+
|
| 2104 |
+
if not return_dict:
|
| 2105 |
+
output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 2106 |
+
return ((loss,) + output) if loss is not None else output
|
| 2107 |
+
|
| 2108 |
+
return XVectorOutput(
|
| 2109 |
+
loss=loss,
|
| 2110 |
+
logits=logits,
|
| 2111 |
+
embeddings=output_embeddings,
|
| 2112 |
+
hidden_states=outputs.hidden_states,
|
| 2113 |
+
attentions=outputs.attentions,
|
| 2114 |
+
)
|
MuCodec/muq_dev/muq_fairseq/models/muq/modules/random_quantizer.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn, einsum
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class RandomProjectionQuantizer(nn.Module):
|
| 7 |
+
"""
|
| 8 |
+
Random projection and codebook lookup module
|
| 9 |
+
|
| 10 |
+
Some code is borrowed from:
|
| 11 |
+
https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
|
| 12 |
+
But I did normalization using pre-computed global mean & variance instead of using layer norm.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
input_dim,
|
| 18 |
+
codebook_dim,
|
| 19 |
+
codebook_size,
|
| 20 |
+
seed=142,
|
| 21 |
+
):
|
| 22 |
+
super().__init__()
|
| 23 |
+
|
| 24 |
+
# random seed
|
| 25 |
+
torch.manual_seed(seed)
|
| 26 |
+
|
| 27 |
+
# randomly initialized projection
|
| 28 |
+
random_projection = torch.empty(input_dim, codebook_dim)
|
| 29 |
+
nn.init.xavier_normal_(random_projection)
|
| 30 |
+
self.register_buffer("random_projection", random_projection)
|
| 31 |
+
|
| 32 |
+
# randomly initialized codebook
|
| 33 |
+
codebook = torch.empty(codebook_size, codebook_dim)
|
| 34 |
+
nn.init.normal_(codebook)
|
| 35 |
+
self.register_buffer("codebook", codebook)
|
| 36 |
+
|
| 37 |
+
def codebook_lookup(self, x):
|
| 38 |
+
# reshape
|
| 39 |
+
b = x.shape[0]
|
| 40 |
+
x = rearrange(x, "b n e -> (b n) e")
|
| 41 |
+
|
| 42 |
+
# L2 normalization
|
| 43 |
+
normalized_x = nn.functional.normalize(x, dim=1, p=2)
|
| 44 |
+
normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)
|
| 45 |
+
|
| 46 |
+
# compute distances
|
| 47 |
+
distances = torch.cdist(normalized_codebook, normalized_x)
|
| 48 |
+
|
| 49 |
+
# get nearest
|
| 50 |
+
nearest_indices = torch.argmin(distances, dim=0)
|
| 51 |
+
|
| 52 |
+
# reshape
|
| 53 |
+
xq = rearrange(nearest_indices, "(b n) -> b n", b=b)
|
| 54 |
+
|
| 55 |
+
return xq
|
| 56 |
+
|
| 57 |
+
@torch.no_grad()
|
| 58 |
+
def forward(self, x):
|
| 59 |
+
# always eval
|
| 60 |
+
self.eval()
|
| 61 |
+
|
| 62 |
+
# random projection [batch, length, input_dim] -> [batch, length, codebook_dim]
|
| 63 |
+
x = einsum("b n d, d e -> b n e", x, self.random_projection)
|
| 64 |
+
|
| 65 |
+
# codebook lookup
|
| 66 |
+
xq = self.codebook_lookup(x)
|
| 67 |
+
|
| 68 |
+
return xq
|
MuCodec/muq_dev/muq_fairseq/models/muq/muq_model.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try:
|
| 2 |
+
from .model.muq import MuQ
|
| 3 |
+
except:
|
| 4 |
+
import sys, os
|
| 5 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 6 |
+
from model.muq import MuQ
|
| 7 |
+
try:
|
| 8 |
+
from fairseq.fairseq.dataclass import FairseqDataclass
|
| 9 |
+
from fairseq.fairseq.models import BaseFairseqModel, register_model
|
| 10 |
+
from fairseq.fairseq.tasks.fairseq_task import FairseqTask
|
| 11 |
+
except:
|
| 12 |
+
from fairseq.dataclass import FairseqDataclass
|
| 13 |
+
from fairseq.models import BaseFairseqModel, register_model
|
| 14 |
+
from fairseq.tasks.fairseq_task import FairseqTask
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from typing import List, Tuple, Optional
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from logging import getLogger
|
| 21 |
+
|
| 22 |
+
logger = getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class MuQConfig(FairseqDataclass):
|
| 26 |
+
label_rate:int = field(default=25)
|
| 27 |
+
num_codebooks:int = field(default=1)
|
| 28 |
+
codebook_dim:int = field(default=16)
|
| 29 |
+
codebook_size:int = field(default=4096)
|
| 30 |
+
features:List[str] = field(default_factory=lambda:["melspec_2048"])
|
| 31 |
+
hop_length:int = field(default=240)
|
| 32 |
+
n_mels:int = field(default=128)
|
| 33 |
+
conv_dim:int = field(default=512)
|
| 34 |
+
encoder_dim:int = field(default=1024)
|
| 35 |
+
encoder_depth:int = field(default=12)
|
| 36 |
+
mask_hop:float = field(default=0.4)
|
| 37 |
+
mask_prob:float = field(default=0.6)
|
| 38 |
+
is_flash:bool = field(default=False)
|
| 39 |
+
stat_path:Optional[str] = field(default=None)
|
| 40 |
+
model_path:Optional[str] = field(default=None)
|
| 41 |
+
w2v2_config_path:Optional[str] = field(default=None)
|
| 42 |
+
use_rvq_target:bool = field(default=False)
|
| 43 |
+
use_vq_target:bool = field(default=False)
|
| 44 |
+
rvq_ckpt_path: Optional[str] = field(default=None)
|
| 45 |
+
recon_loss_ratio: Optional[float] = field(default=None)
|
| 46 |
+
resume_checkpoint: Optional[str] = None
|
| 47 |
+
use_hubert_masking_strategy:bool = field(default=False)
|
| 48 |
+
use_hubert_featurizer:bool = field(default=False)
|
| 49 |
+
hubert_conv_feature_layers:str = field(default_factory=lambda:"[(512,10,5)] + [(512,3,2)] * 3 + [(512,3,3)] + [(512,2,2)] * 2")
|
| 50 |
+
rvq_n_codebooks:int = field(default=8)
|
| 51 |
+
rvq_multi_layer_num:int = field(default=1)
|
| 52 |
+
use_encodec_target:bool = field(default=False)
|
| 53 |
+
|
| 54 |
+
SAMPLE_RATE = 24_000
|
| 55 |
+
|
| 56 |
+
@register_model("muq", dataclass=MuQConfig)
|
| 57 |
+
class MuQModel(BaseFairseqModel):
|
| 58 |
+
def __init__(self, cfg: MuQConfig, task_cfg: FairseqTask):
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.cfg = cfg
|
| 61 |
+
self.model = MuQ(
|
| 62 |
+
num_codebooks=cfg.num_codebooks,
|
| 63 |
+
codebook_dim=cfg.codebook_dim,
|
| 64 |
+
codebook_size=cfg.codebook_size,
|
| 65 |
+
features=cfg.features,
|
| 66 |
+
n_mels=cfg.n_mels,
|
| 67 |
+
conv_dim=cfg.conv_dim,
|
| 68 |
+
encoder_dim=cfg.encoder_dim,
|
| 69 |
+
encoder_depth=cfg.encoder_depth,
|
| 70 |
+
mask_hop=cfg.mask_hop,
|
| 71 |
+
mask_prob=cfg.mask_prob,
|
| 72 |
+
is_flash=cfg.is_flash,
|
| 73 |
+
stat_path=cfg.stat_path,
|
| 74 |
+
model_path=cfg.model_path,
|
| 75 |
+
w2v2_config_path=cfg.w2v2_config_path,
|
| 76 |
+
use_rvq_target=cfg.use_rvq_target,
|
| 77 |
+
use_vq_target=cfg.use_vq_target,
|
| 78 |
+
rvq_ckpt_path=cfg.rvq_ckpt_path,
|
| 79 |
+
recon_loss_ratio=cfg.recon_loss_ratio,
|
| 80 |
+
label_rate=cfg.label_rate,
|
| 81 |
+
use_hubert_masking_strategy=cfg.use_hubert_masking_strategy,
|
| 82 |
+
use_hubert_featurizer=cfg.use_hubert_featurizer,
|
| 83 |
+
hubert_conv_feature_layers=cfg.hubert_conv_feature_layers,
|
| 84 |
+
rvq_n_codebooks=cfg.rvq_n_codebooks,
|
| 85 |
+
rvq_multi_layer_num=cfg.rvq_multi_layer_num,
|
| 86 |
+
use_encodec_target=cfg.use_encodec_target,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
def forward(
|
| 90 |
+
self,
|
| 91 |
+
source: torch.Tensor, # B,L
|
| 92 |
+
features_only: bool = False,
|
| 93 |
+
label = None, # pre-extracted labeks, dim is [Batch, N_Codebook, SeqLen]
|
| 94 |
+
**kwargs,
|
| 95 |
+
):
|
| 96 |
+
source = source[..., :int((source.shape[-1]//(SAMPLE_RATE//self.cfg.label_rate))*(SAMPLE_RATE//self.cfg.label_rate)) ]
|
| 97 |
+
if features_only:
|
| 98 |
+
if 'attention_mask' in kwargs:
|
| 99 |
+
attention_mask = kwargs['attention_mask']
|
| 100 |
+
elif 'padding_mask' in kwargs:
|
| 101 |
+
attention_mask = ~kwargs['padding_mask'].bool()
|
| 102 |
+
else:
|
| 103 |
+
attention_mask = None
|
| 104 |
+
_, hidden_states = self.model.get_predictions(source, attention_mask=attention_mask, is_features_only=True)
|
| 105 |
+
result = {
|
| 106 |
+
"layer_results": hidden_states
|
| 107 |
+
}
|
| 108 |
+
return result
|
| 109 |
+
else:
|
| 110 |
+
result = {}
|
| 111 |
+
logits, hidden_emb, losses, accuracies = self.model(source, label=label)
|
| 112 |
+
result["losses"] = losses
|
| 113 |
+
result["accuracies"] = accuracies
|
| 114 |
+
result["logits"] = logits
|
| 115 |
+
result["hidden_emb"] = hidden_emb
|
| 116 |
+
for k, v in losses.items():
|
| 117 |
+
result[k] = v
|
| 118 |
+
return result
|
| 119 |
+
|
| 120 |
+
@classmethod
|
| 121 |
+
def build_model(cls, cfg: MuQConfig, task: FairseqTask):
|
| 122 |
+
"""Build a new model instance."""
|
| 123 |
+
|
| 124 |
+
model = MuQModel(cfg, task.cfg)
|
| 125 |
+
import numpy as np
|
| 126 |
+
s = 0
|
| 127 |
+
for param in model.parameters():
|
| 128 |
+
s += np.product(param.size())
|
| 129 |
+
# print('# of parameters: '+str(s/1024.0/1024.0))
|
| 130 |
+
|
| 131 |
+
if cfg.get("resume_checkpoint", None):
|
| 132 |
+
print("Loading checkpoint from {}".format(cfg.resume_checkpoint))
|
| 133 |
+
model.load_state_dict(torch.load(cfg.resume_checkpoint)['model'], strict=False)
|
| 134 |
+
|
| 135 |
+
return model
|
| 136 |
+
|
| 137 |
+
def get_losses(self, result, batch):
|
| 138 |
+
return result['losses']
|
| 139 |
+
|
MuCodec/muq_dev/muq_fairseq/tasks/__pycache__/muq_pretraining.cpython-310.pyc
ADDED
|
Binary file (9.93 kB). View file
|
|
|
MuCodec/muq_dev/muq_fairseq/tasks/muq_pretraining.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the LICENSE file in
|
| 5 |
+
# the root directory of this source tree. An additional grant of patent rights
|
| 6 |
+
# can be found in the PATENTS file in the same directory.
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
from typing import Dict, List, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from fairseq.data import Dictionary, HubertDataset
|
| 18 |
+
from fairseq.dataclass.configs import FairseqDataclass
|
| 19 |
+
from fairseq.tasks import register_task
|
| 20 |
+
from fairseq.tasks.fairseq_task import FairseqTask
|
| 21 |
+
from omegaconf import MISSING
|
| 22 |
+
|
| 23 |
+
from ..data.mert_dataset import MERTDataset
|
| 24 |
+
from ..data.ark_dataset import ArkDataset
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class LabelEncoder(object):
|
| 30 |
+
def __init__(self, dictionary: Dictionary) -> None:
|
| 31 |
+
self.dictionary = dictionary
|
| 32 |
+
|
| 33 |
+
def __call__(self, label: str) -> List[str]:
|
| 34 |
+
# encode_line return a torch.IntTensor, should be all 1 for vanila HuBERT
|
| 35 |
+
return self.dictionary.encode_line(
|
| 36 |
+
label,
|
| 37 |
+
append_eos=False,
|
| 38 |
+
add_if_not_exist=False,
|
| 39 |
+
)
|
| 40 |
+
class PaddedNumpyLabelEncoder(object):
|
| 41 |
+
def __init__(self):
|
| 42 |
+
# self.dictionary = dictionary
|
| 43 |
+
pass
|
| 44 |
+
|
| 45 |
+
def __call__(self, label):
|
| 46 |
+
t = torch.IntTensor(np.asarray(label))
|
| 47 |
+
t = t[t>=0] # remove padded -1 values at the end
|
| 48 |
+
return t
|
| 49 |
+
|
| 50 |
+
@dataclass
|
| 51 |
+
class MuQPretrainingConfig(FairseqDataclass):
|
| 52 |
+
data: str = field(default=MISSING, metadata={"help": "path to data directory"})
|
| 53 |
+
sharding_data: int = field(
|
| 54 |
+
default=-1,
|
| 55 |
+
metadata={
|
| 56 |
+
"help": "set this para >1 to use sharding dataset to prevent OOM"
|
| 57 |
+
"prepare data tsv and label files by adding postfix for sharding 64 like:"
|
| 58 |
+
"train_28_64.tsv and train_28_64.encodec_6"
|
| 59 |
+
},
|
| 60 |
+
)
|
| 61 |
+
load_random_data_shard: bool = field(
|
| 62 |
+
default=True,
|
| 63 |
+
metadata={
|
| 64 |
+
"help": "whether to laod shards randomly or in order when use sharding_data"
|
| 65 |
+
},
|
| 66 |
+
)
|
| 67 |
+
fine_tuning: bool = field(
|
| 68 |
+
default=False, metadata={"help": "set to true if fine-tuning Hubert"}
|
| 69 |
+
)
|
| 70 |
+
labels: List[str] = field(
|
| 71 |
+
default_factory=lambda: ["ltr"],
|
| 72 |
+
metadata={
|
| 73 |
+
"help": (
|
| 74 |
+
"extension of the label files to load, frame-level labels for"
|
| 75 |
+
" pre-training, and sequence-level label for fine-tuning"
|
| 76 |
+
)
|
| 77 |
+
},
|
| 78 |
+
)
|
| 79 |
+
label_dir: Optional[str] = field(
|
| 80 |
+
default=None,
|
| 81 |
+
metadata={
|
| 82 |
+
"help": "if set, looks for labels in this directory instead",
|
| 83 |
+
},
|
| 84 |
+
)
|
| 85 |
+
label_scp_path: Optional[str] = field(
|
| 86 |
+
default=None,
|
| 87 |
+
metadata={
|
| 88 |
+
'help': 'if set, load label from scp file'
|
| 89 |
+
}
|
| 90 |
+
)
|
| 91 |
+
label_scp_clip_duration: float = field(
|
| 92 |
+
default=-1,
|
| 93 |
+
metadata={
|
| 94 |
+
'help': 'clip duration for loading scp label. if set to -1, this will not make effect.'
|
| 95 |
+
}
|
| 96 |
+
)
|
| 97 |
+
label_rate: float = field(
|
| 98 |
+
default=-1.0,
|
| 99 |
+
metadata={"help": "label frame rate. -1.0 for sequence label"},
|
| 100 |
+
)
|
| 101 |
+
sample_rate: int = field(
|
| 102 |
+
default=16_000,
|
| 103 |
+
metadata={
|
| 104 |
+
"help": "target sample rate. audio files will be up/down "
|
| 105 |
+
"sampled to this rate"
|
| 106 |
+
},
|
| 107 |
+
)
|
| 108 |
+
normalize: bool = field(
|
| 109 |
+
default=False,
|
| 110 |
+
metadata={"help": "if set, normalizes input to have 0 mean and unit variance"},
|
| 111 |
+
)
|
| 112 |
+
enable_padding: bool = field(
|
| 113 |
+
default=False,
|
| 114 |
+
metadata={"help": "pad shorter samples instead of cropping"},
|
| 115 |
+
)
|
| 116 |
+
max_keep_size: Optional[int] = field(
|
| 117 |
+
default=None,
|
| 118 |
+
metadata={"help": "exclude sample longer than this"},
|
| 119 |
+
)
|
| 120 |
+
max_sample_size: Optional[int] = field(
|
| 121 |
+
default=None,
|
| 122 |
+
metadata={"help": "max sample size to crop to for batching"},
|
| 123 |
+
)
|
| 124 |
+
min_sample_size: Optional[int] = field(
|
| 125 |
+
default=None,
|
| 126 |
+
metadata={"help": "min sample size to crop to for batching"},
|
| 127 |
+
)
|
| 128 |
+
single_target: Optional[bool] = field(
|
| 129 |
+
default=False,
|
| 130 |
+
metadata={
|
| 131 |
+
"help": "if set, AddTargetDatasets outputs same keys " "as AddTargetDataset"
|
| 132 |
+
},
|
| 133 |
+
)
|
| 134 |
+
random_crop: Optional[bool] = field(
|
| 135 |
+
default=True,
|
| 136 |
+
metadata={"help": "always crop from the beginning if false"},
|
| 137 |
+
)
|
| 138 |
+
pad_audio: Optional[bool] = field(
|
| 139 |
+
default=False,
|
| 140 |
+
metadata={"help": "pad audio to the longest one in the batch if true"},
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
store_labels: Optional[bool] = field(
|
| 144 |
+
default=False,
|
| 145 |
+
metadata={"help": "whether to load all of the label into memory"},
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
numpy_memmap_label: Optional[bool] = field(
|
| 149 |
+
default=False,
|
| 150 |
+
metadata={"help": "whether the label file is saved as a numpy file, each line is ended with padding -1"},
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
augmentation_effects: Optional[str] = field(
|
| 154 |
+
default="[]",
|
| 155 |
+
metadata={
|
| 156 |
+
"help": (
|
| 157 |
+
"a list of effects that might apply to the audios"
|
| 158 |
+
"example: \"['random_mute', 'random_Gaussian', 'reverse_polarity']\" "
|
| 159 |
+
"supported: random_mute,"
|
| 160 |
+
"todo: "
|
| 161 |
+
)
|
| 162 |
+
},
|
| 163 |
+
)
|
| 164 |
+
augmentation_probs: Optional[str] = field(
|
| 165 |
+
default="[]",
|
| 166 |
+
metadata={
|
| 167 |
+
"help": (
|
| 168 |
+
"the corresponding probabilities for the data augmentation effects"
|
| 169 |
+
"example: \"[0.1, 0.5, 0.8]\" "
|
| 170 |
+
"the sum is not necessarily need to be 1.0, and multiple effects can be applied to the same audio"
|
| 171 |
+
)
|
| 172 |
+
},
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# inbatch_noise_augment_len_range: Optional[List[int]] = field(
|
| 176 |
+
# default_factory=lambda: [8000, 24000],
|
| 177 |
+
# default = [8000, 24000],
|
| 178 |
+
inbatch_noise_augment_len_range: Optional[str] = field(
|
| 179 |
+
default = "[8000, 24000]",
|
| 180 |
+
metadata={
|
| 181 |
+
"help": (
|
| 182 |
+
"the range of length of the mix-up noise augmentation, unit in smaples"
|
| 183 |
+
)
|
| 184 |
+
},
|
| 185 |
+
)
|
| 186 |
+
# inbatch_noise_augment_number_range: Optional[List[int]] = field(
|
| 187 |
+
# default_factory=lambda: [1, 3],
|
| 188 |
+
# default = [1, 3],
|
| 189 |
+
inbatch_noise_augment_number_range: Optional[str] = field(
|
| 190 |
+
default = "[1, 3]",
|
| 191 |
+
metadata={
|
| 192 |
+
"help": (
|
| 193 |
+
"the range of numbers of the mix-up noise augmentation"
|
| 194 |
+
)
|
| 195 |
+
},
|
| 196 |
+
)
|
| 197 |
+
inbatch_noise_augment_volume: float = field(
|
| 198 |
+
default = 1.0,
|
| 199 |
+
metadata={
|
| 200 |
+
"help": (
|
| 201 |
+
"the coefficient used to modify the volume of the noise audios wavs"
|
| 202 |
+
)
|
| 203 |
+
},
|
| 204 |
+
)
|
| 205 |
+
dynamic_crops: Optional[str] = field(
|
| 206 |
+
default="[]",
|
| 207 |
+
metadata={
|
| 208 |
+
"help": (
|
| 209 |
+
"used to set the maximum audio length setting, for training"
|
| 210 |
+
"example: \"[1, 2, 3, 4, 5, 10]\" "
|
| 211 |
+
)
|
| 212 |
+
},
|
| 213 |
+
)
|
| 214 |
+
dynamic_crops_epoches: Optional[str] = field(
|
| 215 |
+
default="[]",
|
| 216 |
+
metadata={
|
| 217 |
+
"help": (
|
| 218 |
+
"used to set training epoches of changing the maximum audio length"
|
| 219 |
+
"example: \"[1, 10, 20, 40, 80, 160,]\" "
|
| 220 |
+
"then len need to be equal to len(dynamic_crops)"
|
| 221 |
+
)
|
| 222 |
+
},
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
cqt_loss_bin_dataloader: Optional[int] = field(
|
| 226 |
+
default=-1,
|
| 227 |
+
metadata={
|
| 228 |
+
"help": (
|
| 229 |
+
"use this parameter to prepare cqt prediction objective in dataloader"
|
| 230 |
+
)
|
| 231 |
+
},
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
clip_secs: int = field(
|
| 235 |
+
default=5,
|
| 236 |
+
metadata={
|
| 237 |
+
"help": "clip secs for each audio"
|
| 238 |
+
}
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
dataset_shuffle: bool = field(
|
| 242 |
+
default=True,
|
| 243 |
+
metadata={
|
| 244 |
+
"help": (
|
| 245 |
+
"dataset shuffle when sample a batch"
|
| 246 |
+
)
|
| 247 |
+
},
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
@register_task("muq_pretraining", dataclass=MuQPretrainingConfig)
|
| 252 |
+
class MuQPretrainingTask(FairseqTask):
|
| 253 |
+
|
| 254 |
+
cfg: MuQPretrainingConfig
|
| 255 |
+
|
| 256 |
+
def __init__(
|
| 257 |
+
self,
|
| 258 |
+
cfg: MuQPretrainingConfig,
|
| 259 |
+
) -> None:
|
| 260 |
+
super().__init__(cfg)
|
| 261 |
+
|
| 262 |
+
logger.info(f"current directory is {os.getcwd()}")
|
| 263 |
+
logger.info(f"MuQPretrainingTask Config {cfg}")
|
| 264 |
+
|
| 265 |
+
self.cfg = cfg
|
| 266 |
+
self.fine_tuning = cfg.fine_tuning
|
| 267 |
+
|
| 268 |
+
if cfg.fine_tuning:
|
| 269 |
+
self.state.add_factory("target_dictionary", self.load_dictionaries)
|
| 270 |
+
else:
|
| 271 |
+
self.state.add_factory("dictionaries", self.load_dictionaries)
|
| 272 |
+
|
| 273 |
+
self.blank_symbol = "<s>"
|
| 274 |
+
|
| 275 |
+
# use eval() to pass list parameters, skirt the fairseq/torch error: Can't pickle <enum 'Choices'>: attribute lookup Choices on fairseq.dataclass.constants failed
|
| 276 |
+
self.augmentation_effects = eval(self.cfg.augmentation_effects)
|
| 277 |
+
self.augmentation_probs = eval(self.cfg.augmentation_probs)
|
| 278 |
+
if len(self.augmentation_effects) > 0:
|
| 279 |
+
assert len(self.augmentation_effects) == len(self.augmentation_probs)
|
| 280 |
+
logger.info(f"Applying audio augmentation {self.augmentation_effects}, probabilities: {self.augmentation_probs}")
|
| 281 |
+
|
| 282 |
+
self.inbatch_noise_augment_number_range = eval(self.cfg.inbatch_noise_augment_number_range)
|
| 283 |
+
self.inbatch_noise_augment_len_range = eval(self.cfg.inbatch_noise_augment_len_range)
|
| 284 |
+
|
| 285 |
+
self.max_sample_size = self.cfg.max_sample_size
|
| 286 |
+
|
| 287 |
+
self.dynamic_crops = eval(self.cfg.dynamic_crops)
|
| 288 |
+
self.dynamic_crops_epoches = eval(self.cfg.dynamic_crops_epoches)
|
| 289 |
+
assert len(self.dynamic_crops) == len(self.dynamic_crops_epoches)
|
| 290 |
+
if len(self.dynamic_crops) > 0:
|
| 291 |
+
assert self.dynamic_crops_epoches[0] == 1
|
| 292 |
+
|
| 293 |
+
self.cqt_loss_bin_dataloader = self.cfg.cqt_loss_bin_dataloader
|
| 294 |
+
|
| 295 |
+
self.numpy_memmap_label = self.cfg.numpy_memmap_label
|
| 296 |
+
self.store_labels = self.cfg.store_labels
|
| 297 |
+
if self.numpy_memmap_label:
|
| 298 |
+
assert self.store_labels
|
| 299 |
+
|
| 300 |
+
@property
|
| 301 |
+
def source_dictionary(self) -> Optional[Dictionary]:
|
| 302 |
+
return None
|
| 303 |
+
|
| 304 |
+
@property
|
| 305 |
+
def target_dictionary(self) -> Optional[Dictionary]:
|
| 306 |
+
return self.state.target_dictionary
|
| 307 |
+
|
| 308 |
+
@property
|
| 309 |
+
def dictionaries(self) -> List[Dictionary]:
|
| 310 |
+
return self.state.dictionaries
|
| 311 |
+
|
| 312 |
+
@classmethod
|
| 313 |
+
def setup_task(
|
| 314 |
+
cls, cfg: MuQPretrainingConfig, **kwargs
|
| 315 |
+
) -> "MuQPretrainingTask":
|
| 316 |
+
return cls(cfg)
|
| 317 |
+
|
| 318 |
+
def load_dictionaries(self):
|
| 319 |
+
label_dir = self.cfg.data if (self.cfg.label_dir is None or self.cfg.label_dir == '') else self.cfg.label_dir
|
| 320 |
+
print(label_dir)
|
| 321 |
+
dictionaries = [
|
| 322 |
+
Dictionary.load(f"{label_dir}/dict.{label}.txt")
|
| 323 |
+
for label in self.cfg.labels
|
| 324 |
+
]
|
| 325 |
+
return dictionaries[0] if self.cfg.fine_tuning else dictionaries
|
| 326 |
+
|
| 327 |
+
def get_label_dir(self) -> str:
|
| 328 |
+
if self.cfg.label_dir is None or self.cfg.label_dir=='':
|
| 329 |
+
return self.cfg.data
|
| 330 |
+
return self.cfg.label_dir
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def is_force_load_dataset(self, epoch, training_restore=False):
|
| 334 |
+
# find the threshold that holds epoch \in [threshold, next_threshold)
|
| 335 |
+
return (epoch in self.dynamic_crops_epoches) or training_restore or (self.cfg.sharding_data > 1)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def set_dynamic_crop_max_sample(self, epoch):
|
| 339 |
+
pass
|
| 340 |
+
|
| 341 |
+
def load_dataset(self, split: str, **kwargs) -> None:
|
| 342 |
+
pass
|
| 343 |
+
|
| 344 |
+
def load_dataset_ark(self, split, **kwargs):
|
| 345 |
+
pass
|
| 346 |
+
|
| 347 |
+
def load_dataset_mert(self, split: str, **kwargs) -> None:
|
| 348 |
+
pass
|
| 349 |
+
|
| 350 |
+
def max_positions(self) -> Tuple[int, int]:
|
| 351 |
+
return (sys.maxsize, sys.maxsize)
|
| 352 |
+
|
| 353 |
+
def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array:
|
| 354 |
+
return indices
|
MuCodec/tools/__pycache__/get_melvaehifigan48k.cpython-310.pyc
ADDED
|
Binary file (35.6 kB). View file
|
|
|
MuCodec/tools/__pycache__/torch_tools.cpython-310.pyc
ADDED
|
Binary file (2.74 kB). View file
|
|
|
MuCodec/tools/__pycache__/torch_tools.cpython-312.pyc
ADDED
|
Binary file (4.48 kB). View file
|
|
|
checkpoints/Qwen3-0.6B/.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
checkpoints/Qwen3-0.6B/LICENSE
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
Apache License
|
| 3 |
+
Version 2.0, January 2004
|
| 4 |
+
http://www.apache.org/licenses/
|
| 5 |
+
|
| 6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 7 |
+
|
| 8 |
+
1. Definitions.
|
| 9 |
+
|
| 10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 12 |
+
|
| 13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 14 |
+
the copyright owner that is granting the License.
|
| 15 |
+
|
| 16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 17 |
+
other entities that control, are controlled by, or are under common
|
| 18 |
+
control with that entity. For the purposes of this definition,
|
| 19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 20 |
+
direction or management of such entity, whether by contract or
|
| 21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 23 |
+
|
| 24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 25 |
+
exercising permissions granted by this License.
|
| 26 |
+
|
| 27 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 28 |
+
including but not limited to software source code, documentation
|
| 29 |
+
source, and configuration files.
|
| 30 |
+
|
| 31 |
+
"Object" form shall mean any form resulting from mechanical
|
| 32 |
+
transformation or translation of a Source form, including but
|
| 33 |
+
not limited to compiled object code, generated documentation,
|
| 34 |
+
and conversions to other media types.
|
| 35 |
+
|
| 36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 37 |
+
Object form, made available under the License, as indicated by a
|
| 38 |
+
copyright notice that is included in or attached to the work
|
| 39 |
+
(an example is provided in the Appendix below).
|
| 40 |
+
|
| 41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 42 |
+
form, that is based on (or derived from) the Work and for which the
|
| 43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 45 |
+
of this License, Derivative Works shall not include works that remain
|
| 46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 47 |
+
the Work and Derivative Works thereof.
|
| 48 |
+
|
| 49 |
+
"Contribution" shall mean any work of authorship, including
|
| 50 |
+
the original version of the Work and any modifications or additions
|
| 51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 55 |
+
means any form of electronic, verbal, or written communication sent
|
| 56 |
+
to the Licensor or its representatives, including but not limited to
|
| 57 |
+
communication on electronic mailing lists, source code control systems,
|
| 58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 60 |
+
excluding communication that is conspicuously marked or otherwise
|
| 61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 62 |
+
|
| 63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 65 |
+
subsequently incorporated within the Work.
|
| 66 |
+
|
| 67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 72 |
+
Work and such Derivative Works in Source or Object form.
|
| 73 |
+
|
| 74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 77 |
+
(except as stated in this section) patent license to make, have made,
|
| 78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 79 |
+
where such license applies only to those patent claims licensable
|
| 80 |
+
by such Contributor that are necessarily infringed by their
|
| 81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 83 |
+
institute patent litigation against any entity (including a
|
| 84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 85 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 86 |
+
or contributory patent infringement, then any patent licenses
|
| 87 |
+
granted to You under this License for that Work shall terminate
|
| 88 |
+
as of the date such litigation is filed.
|
| 89 |
+
|
| 90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 91 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 92 |
+
modifications, and in Source or Object form, provided that You
|
| 93 |
+
meet the following conditions:
|
| 94 |
+
|
| 95 |
+
(a) You must give any other recipients of the Work or
|
| 96 |
+
Derivative Works a copy of this License; and
|
| 97 |
+
|
| 98 |
+
(b) You must cause any modified files to carry prominent notices
|
| 99 |
+
stating that You changed the files; and
|
| 100 |
+
|
| 101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 102 |
+
that You distribute, all copyright, patent, trademark, and
|
| 103 |
+
attribution notices from the Source form of the Work,
|
| 104 |
+
excluding those notices that do not pertain to any part of
|
| 105 |
+
the Derivative Works; and
|
| 106 |
+
|
| 107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 108 |
+
distribution, then any Derivative Works that You distribute must
|
| 109 |
+
include a readable copy of the attribution notices contained
|
| 110 |
+
within such NOTICE file, excluding those notices that do not
|
| 111 |
+
pertain to any part of the Derivative Works, in at least one
|
| 112 |
+
of the following places: within a NOTICE text file distributed
|
| 113 |
+
as part of the Derivative Works; within the Source form or
|
| 114 |
+
documentation, if provided along with the Derivative Works; or,
|
| 115 |
+
within a display generated by the Derivative Works, if and
|
| 116 |
+
wherever such third-party notices normally appear. The contents
|
| 117 |
+
of the NOTICE file are for informational purposes only and
|
| 118 |
+
do not modify the License. You may add Your own attribution
|
| 119 |
+
notices within Derivative Works that You distribute, alongside
|
| 120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 121 |
+
that such additional attribution notices cannot be construed
|
| 122 |
+
as modifying the License.
|
| 123 |
+
|
| 124 |
+
You may add Your own copyright statement to Your modifications and
|
| 125 |
+
may provide additional or different license terms and conditions
|
| 126 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 127 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 128 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 129 |
+
the conditions stated in this License.
|
| 130 |
+
|
| 131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 133 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 134 |
+
this License, without any additional terms or conditions.
|
| 135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 136 |
+
the terms of any separate license agreement you may have executed
|
| 137 |
+
with Licensor regarding such Contributions.
|
| 138 |
+
|
| 139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 141 |
+
except as required for reasonable and customary use in describing the
|
| 142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 143 |
+
|
| 144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 145 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 148 |
+
implied, including, without limitation, any warranties or conditions
|
| 149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 151 |
+
appropriateness of using or redistributing the Work and assume any
|
| 152 |
+
risks associated with Your exercise of permissions under this License.
|
| 153 |
+
|
| 154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 155 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 156 |
+
unless required by applicable law (such as deliberate and grossly
|
| 157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 158 |
+
liable to You for damages, including any direct, indirect, special,
|
| 159 |
+
incidental, or consequential damages of any character arising as a
|
| 160 |
+
result of this License or out of the use or inability to use the
|
| 161 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 162 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 163 |
+
other commercial damages or losses), even if such Contributor
|
| 164 |
+
has been advised of the possibility of such damages.
|
| 165 |
+
|
| 166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 169 |
+
or other liability obligations and/or rights consistent with this
|
| 170 |
+
License. However, in accepting such obligations, You may act only
|
| 171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 172 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 173 |
+
defend, and hold each Contributor harmless for any liability
|
| 174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 175 |
+
of your accepting any such warranty or additional liability.
|
| 176 |
+
|
| 177 |
+
END OF TERMS AND CONDITIONS
|
| 178 |
+
|
| 179 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 180 |
+
|
| 181 |
+
To apply the Apache License to your work, attach the following
|
| 182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 183 |
+
replaced with your own identifying information. (Don't include
|
| 184 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 185 |
+
comment syntax for the file format. We also recommend that a
|
| 186 |
+
file or class name and description of purpose be included on the
|
| 187 |
+
same "printed page" as the copyright notice for easier
|
| 188 |
+
identification within third-party archives.
|
| 189 |
+
|
| 190 |
+
Copyright 2024 Alibaba Cloud
|
| 191 |
+
|
| 192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 193 |
+
you may not use this file except in compliance with the License.
|
| 194 |
+
You may obtain a copy of the License at
|
| 195 |
+
|
| 196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 197 |
+
|
| 198 |
+
Unless required by applicable law or agreed to in writing, software
|
| 199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 201 |
+
See the License for the specific language governing permissions and
|
| 202 |
+
limitations under the License.
|
checkpoints/Qwen3-0.6B/README.md
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: transformers
|
| 3 |
+
license: apache-2.0
|
| 4 |
+
license_link: https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/LICENSE
|
| 5 |
+
pipeline_tag: text-generation
|
| 6 |
+
base_model:
|
| 7 |
+
- Qwen/Qwen3-0.6B-Base
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# Qwen3-0.6B
|
| 11 |
+
<a href="https://chat.qwen.ai/" target="_blank" style="margin: 2px;">
|
| 12 |
+
<img alt="Chat" src="https://img.shields.io/badge/%F0%9F%92%9C%EF%B8%8F%20Qwen%20Chat%20-536af5" style="display: inline-block; vertical-align: middle;"/>
|
| 13 |
+
</a>
|
| 14 |
+
|
| 15 |
+
## Qwen3 Highlights
|
| 16 |
+
|
| 17 |
+
Qwen3 is the latest generation of large language models in Qwen series, offering a comprehensive suite of dense and mixture-of-experts (MoE) models. Built upon extensive training, Qwen3 delivers groundbreaking advancements in reasoning, instruction-following, agent capabilities, and multilingual support, with the following key features:
|
| 18 |
+
|
| 19 |
+
- **Uniquely support of seamless switching between thinking mode** (for complex logical reasoning, math, and coding) and **non-thinking mode** (for efficient, general-purpose dialogue) **within single model**, ensuring optimal performance across various scenarios.
|
| 20 |
+
- **Significantly enhancement in its reasoning capabilities**, surpassing previous QwQ (in thinking mode) and Qwen2.5 instruct models (in non-thinking mode) on mathematics, code generation, and commonsense logical reasoning.
|
| 21 |
+
- **Superior human preference alignment**, excelling in creative writing, role-playing, multi-turn dialogues, and instruction following, to deliver a more natural, engaging, and immersive conversational experience.
|
| 22 |
+
- **Expertise in agent capabilities**, enabling precise integration with external tools in both thinking and unthinking modes and achieving leading performance among open-source models in complex agent-based tasks.
|
| 23 |
+
- **Support of 100+ languages and dialects** with strong capabilities for **multilingual instruction following** and **translation**.
|
| 24 |
+
|
| 25 |
+
## Model Overview
|
| 26 |
+
|
| 27 |
+
**Qwen3-0.6B** has the following features:
|
| 28 |
+
- Type: Causal Language Models
|
| 29 |
+
- Training Stage: Pretraining & Post-training
|
| 30 |
+
- Number of Parameters: 0.6B
|
| 31 |
+
- Number of Paramaters (Non-Embedding): 0.44B
|
| 32 |
+
- Number of Layers: 28
|
| 33 |
+
- Number of Attention Heads (GQA): 16 for Q and 8 for KV
|
| 34 |
+
- Context Length: 32,768
|
| 35 |
+
|
| 36 |
+
For more details, including benchmark evaluation, hardware requirements, and inference performance, please refer to our [blog](https://qwenlm.github.io/blog/qwen3/), [GitHub](https://github.com/QwenLM/Qwen3), and [Documentation](https://qwen.readthedocs.io/en/latest/).
|
| 37 |
+
|
| 38 |
+
> [!TIP]
|
| 39 |
+
> If you encounter significant endless repetitions, please refer to the [Best Practices](#best-practices) section for optimal sampling parameters, and set the ``presence_penalty`` to 1.5.
|
| 40 |
+
|
| 41 |
+
## Quickstart
|
| 42 |
+
|
| 43 |
+
The code of Qwen3 has been in the latest Hugging Face `transformers` and we advise you to use the latest version of `transformers`.
|
| 44 |
+
|
| 45 |
+
With `transformers<4.51.0`, you will encounter the following error:
|
| 46 |
+
```
|
| 47 |
+
KeyError: 'qwen3'
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
The following contains a code snippet illustrating how to use the model generate content based on given inputs.
|
| 51 |
+
```python
|
| 52 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 53 |
+
|
| 54 |
+
model_name = "Qwen/Qwen3-0.6B"
|
| 55 |
+
|
| 56 |
+
# load the tokenizer and the model
|
| 57 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 58 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 59 |
+
model_name,
|
| 60 |
+
torch_dtype="auto",
|
| 61 |
+
device_map="auto"
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# prepare the model input
|
| 65 |
+
prompt = "Give me a short introduction to large language model."
|
| 66 |
+
messages = [
|
| 67 |
+
{"role": "user", "content": prompt}
|
| 68 |
+
]
|
| 69 |
+
text = tokenizer.apply_chat_template(
|
| 70 |
+
messages,
|
| 71 |
+
tokenize=False,
|
| 72 |
+
add_generation_prompt=True,
|
| 73 |
+
enable_thinking=True # Switches between thinking and non-thinking modes. Default is True.
|
| 74 |
+
)
|
| 75 |
+
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
| 76 |
+
|
| 77 |
+
# conduct text completion
|
| 78 |
+
generated_ids = model.generate(
|
| 79 |
+
**model_inputs,
|
| 80 |
+
max_new_tokens=32768
|
| 81 |
+
)
|
| 82 |
+
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
|
| 83 |
+
|
| 84 |
+
# parsing thinking content
|
| 85 |
+
try:
|
| 86 |
+
# rindex finding 151668 (</think>)
|
| 87 |
+
index = len(output_ids) - output_ids[::-1].index(151668)
|
| 88 |
+
except ValueError:
|
| 89 |
+
index = 0
|
| 90 |
+
|
| 91 |
+
thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
|
| 92 |
+
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
|
| 93 |
+
|
| 94 |
+
print("thinking content:", thinking_content)
|
| 95 |
+
print("content:", content)
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
For deployment, you can use `sglang>=0.4.6.post1` or `vllm>=0.8.5` or to create an OpenAI-compatible API endpoint:
|
| 99 |
+
- SGLang:
|
| 100 |
+
```shell
|
| 101 |
+
python -m sglang.launch_server --model-path Qwen/Qwen3-0.6B --reasoning-parser qwen3
|
| 102 |
+
```
|
| 103 |
+
- vLLM:
|
| 104 |
+
```shell
|
| 105 |
+
vllm serve Qwen/Qwen3-0.6B --enable-reasoning --reasoning-parser deepseek_r1
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
For local use, applications such as Ollama, LMStudio, MLX-LM, llama.cpp, and KTransformers have also supported Qwen3.
|
| 109 |
+
|
| 110 |
+
## Switching Between Thinking and Non-Thinking Mode
|
| 111 |
+
|
| 112 |
+
> [!TIP]
|
| 113 |
+
> The `enable_thinking` switch is also available in APIs created by SGLang and vLLM.
|
| 114 |
+
> Please refer to our documentation for [SGLang](https://qwen.readthedocs.io/en/latest/deployment/sglang.html#thinking-non-thinking-modes) and [vLLM](https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes) users.
|
| 115 |
+
|
| 116 |
+
### `enable_thinking=True`
|
| 117 |
+
|
| 118 |
+
By default, Qwen3 has thinking capabilities enabled, similar to QwQ-32B. This means the model will use its reasoning abilities to enhance the quality of generated responses. For example, when explicitly setting `enable_thinking=True` or leaving it as the default value in `tokenizer.apply_chat_template`, the model will engage its thinking mode.
|
| 119 |
+
|
| 120 |
+
```python
|
| 121 |
+
text = tokenizer.apply_chat_template(
|
| 122 |
+
messages,
|
| 123 |
+
tokenize=False,
|
| 124 |
+
add_generation_prompt=True,
|
| 125 |
+
enable_thinking=True # True is the default value for enable_thinking
|
| 126 |
+
)
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
In this mode, the model will generate think content wrapped in a `<think>...</think>` block, followed by the final response.
|
| 130 |
+
|
| 131 |
+
> [!NOTE]
|
| 132 |
+
> For thinking mode, use `Temperature=0.6`, `TopP=0.95`, `TopK=20`, and `MinP=0` (the default setting in `generation_config.json`). **DO NOT use greedy decoding**, as it can lead to performance degradation and endless repetitions. For more detailed guidance, please refer to the [Best Practices](#best-practices) section.
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
### `enable_thinking=False`
|
| 136 |
+
|
| 137 |
+
We provide a hard switch to strictly disable the model's thinking behavior, aligning its functionality with the previous Qwen2.5-Instruct models. This mode is particularly useful in scenarios where disabling thinking is essential for enhancing efficiency.
|
| 138 |
+
|
| 139 |
+
```python
|
| 140 |
+
text = tokenizer.apply_chat_template(
|
| 141 |
+
messages,
|
| 142 |
+
tokenize=False,
|
| 143 |
+
add_generation_prompt=True,
|
| 144 |
+
enable_thinking=False # Setting enable_thinking=False disables thinking mode
|
| 145 |
+
)
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
In this mode, the model will not generate any think content and will not include a `<think>...</think>` block.
|
| 149 |
+
|
| 150 |
+
> [!NOTE]
|
| 151 |
+
> For non-thinking mode, we suggest using `Temperature=0.7`, `TopP=0.8`, `TopK=20`, and `MinP=0`. For more detailed guidance, please refer to the [Best Practices](#best-practices) section.
|
| 152 |
+
|
| 153 |
+
### Advanced Usage: Switching Between Thinking and Non-Thinking Modes via User Input
|
| 154 |
+
|
| 155 |
+
We provide a soft switch mechanism that allows users to dynamically control the model's behavior when `enable_thinking=True`. Specifically, you can add `/think` and `/no_think` to user prompts or system messages to switch the model's thinking mode from turn to turn. The model will follow the most recent instruction in multi-turn conversations.
|
| 156 |
+
|
| 157 |
+
Here is an example of a multi-turn conversation:
|
| 158 |
+
|
| 159 |
+
```python
|
| 160 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 161 |
+
|
| 162 |
+
class QwenChatbot:
|
| 163 |
+
def __init__(self, model_name="Qwen/Qwen3-0.6B"):
|
| 164 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 165 |
+
self.model = AutoModelForCausalLM.from_pretrained(model_name)
|
| 166 |
+
self.history = []
|
| 167 |
+
|
| 168 |
+
def generate_response(self, user_input):
|
| 169 |
+
messages = self.history + [{"role": "user", "content": user_input}]
|
| 170 |
+
|
| 171 |
+
text = self.tokenizer.apply_chat_template(
|
| 172 |
+
messages,
|
| 173 |
+
tokenize=False,
|
| 174 |
+
add_generation_prompt=True
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
inputs = self.tokenizer(text, return_tensors="pt")
|
| 178 |
+
response_ids = self.model.generate(**inputs, max_new_tokens=32768)[0][len(inputs.input_ids[0]):].tolist()
|
| 179 |
+
response = self.tokenizer.decode(response_ids, skip_special_tokens=True)
|
| 180 |
+
|
| 181 |
+
# Update history
|
| 182 |
+
self.history.append({"role": "user", "content": user_input})
|
| 183 |
+
self.history.append({"role": "assistant", "content": response})
|
| 184 |
+
|
| 185 |
+
return response
|
| 186 |
+
|
| 187 |
+
# Example Usage
|
| 188 |
+
if __name__ == "__main__":
|
| 189 |
+
chatbot = QwenChatbot()
|
| 190 |
+
|
| 191 |
+
# First input (without /think or /no_think tags, thinking mode is enabled by default)
|
| 192 |
+
user_input_1 = "How many r's in strawberries?"
|
| 193 |
+
print(f"User: {user_input_1}")
|
| 194 |
+
response_1 = chatbot.generate_response(user_input_1)
|
| 195 |
+
print(f"Bot: {response_1}")
|
| 196 |
+
print("----------------------")
|
| 197 |
+
|
| 198 |
+
# Second input with /no_think
|
| 199 |
+
user_input_2 = "Then, how many r's in blueberries? /no_think"
|
| 200 |
+
print(f"User: {user_input_2}")
|
| 201 |
+
response_2 = chatbot.generate_response(user_input_2)
|
| 202 |
+
print(f"Bot: {response_2}")
|
| 203 |
+
print("----------------------")
|
| 204 |
+
|
| 205 |
+
# Third input with /think
|
| 206 |
+
user_input_3 = "Really? /think"
|
| 207 |
+
print(f"User: {user_input_3}")
|
| 208 |
+
response_3 = chatbot.generate_response(user_input_3)
|
| 209 |
+
print(f"Bot: {response_3}")
|
| 210 |
+
```
|
| 211 |
+
|
| 212 |
+
> [!NOTE]
|
| 213 |
+
> For API compatibility, when `enable_thinking=True`, regardless of whether the user uses `/think` or `/no_think`, the model will always output a block wrapped in `<think>...</think>`. However, the content inside this block may be empty if thinking is disabled.
|
| 214 |
+
> When `enable_thinking=False`, the soft switches are not valid. Regardless of any `/think` or `/no_think` tags input by the user, the model will not generate think content and will not include a `<think>...</think>` block.
|
| 215 |
+
|
| 216 |
+
## Agentic Use
|
| 217 |
+
|
| 218 |
+
Qwen3 excels in tool calling capabilities. We recommend using [Qwen-Agent](https://github.com/QwenLM/Qwen-Agent) to make the best use of agentic ability of Qwen3. Qwen-Agent encapsulates tool-calling templates and tool-calling parsers internally, greatly reducing coding complexity.
|
| 219 |
+
|
| 220 |
+
To define the available tools, you can use the MCP configuration file, use the integrated tool of Qwen-Agent, or integrate other tools by yourself.
|
| 221 |
+
```python
|
| 222 |
+
from qwen_agent.agents import Assistant
|
| 223 |
+
|
| 224 |
+
# Define LLM
|
| 225 |
+
llm_cfg = {
|
| 226 |
+
'model': 'Qwen3-0.6B',
|
| 227 |
+
|
| 228 |
+
# Use the endpoint provided by Alibaba Model Studio:
|
| 229 |
+
# 'model_type': 'qwen_dashscope',
|
| 230 |
+
# 'api_key': os.getenv('DASHSCOPE_API_KEY'),
|
| 231 |
+
|
| 232 |
+
# Use a custom endpoint compatible with OpenAI API:
|
| 233 |
+
'model_server': 'http://localhost:8000/v1', # api_base
|
| 234 |
+
'api_key': 'EMPTY',
|
| 235 |
+
|
| 236 |
+
# Other parameters:
|
| 237 |
+
# 'generate_cfg': {
|
| 238 |
+
# # Add: When the response content is `<think>this is the thought</think>this is the answer;
|
| 239 |
+
# # Do not add: When the response has been separated by reasoning_content and content.
|
| 240 |
+
# 'thought_in_content': True,
|
| 241 |
+
# },
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
# Define Tools
|
| 245 |
+
tools = [
|
| 246 |
+
{'mcpServers': { # You can specify the MCP configuration file
|
| 247 |
+
'time': {
|
| 248 |
+
'command': 'uvx',
|
| 249 |
+
'args': ['mcp-server-time', '--local-timezone=Asia/Shanghai']
|
| 250 |
+
},
|
| 251 |
+
"fetch": {
|
| 252 |
+
"command": "uvx",
|
| 253 |
+
"args": ["mcp-server-fetch"]
|
| 254 |
+
}
|
| 255 |
+
}
|
| 256 |
+
},
|
| 257 |
+
'code_interpreter', # Built-in tools
|
| 258 |
+
]
|
| 259 |
+
|
| 260 |
+
# Define Agent
|
| 261 |
+
bot = Assistant(llm=llm_cfg, function_list=tools)
|
| 262 |
+
|
| 263 |
+
# Streaming generation
|
| 264 |
+
messages = [{'role': 'user', 'content': 'https://qwenlm.github.io/blog/ Introduce the latest developments of Qwen'}]
|
| 265 |
+
for responses in bot.run(messages=messages):
|
| 266 |
+
pass
|
| 267 |
+
print(responses)
|
| 268 |
+
```
|
| 269 |
+
|
| 270 |
+
## Best Practices
|
| 271 |
+
|
| 272 |
+
To achieve optimal performance, we recommend the following settings:
|
| 273 |
+
|
| 274 |
+
1. **Sampling Parameters**:
|
| 275 |
+
- For thinking mode (`enable_thinking=True`), use `Temperature=0.6`, `TopP=0.95`, `TopK=20`, and `MinP=0`. **DO NOT use greedy decoding**, as it can lead to performance degradation and endless repetitions.
|
| 276 |
+
- For non-thinking mode (`enable_thinking=False`), we suggest using `Temperature=0.7`, `TopP=0.8`, `TopK=20`, and `MinP=0`.
|
| 277 |
+
- For supported frameworks, you can adjust the `presence_penalty` parameter between 0 and 2 to reduce endless repetitions. However, using a higher value may occasionally result in language mixing and a slight decrease in model performance.
|
| 278 |
+
|
| 279 |
+
2. **Adequate Output Length**: We recommend using an output length of 32,768 tokens for most queries. For benchmarking on highly complex problems, such as those found in math and programming competitions, we suggest setting the max output length to 38,912 tokens. This provides the model with sufficient space to generate detailed and comprehensive responses, thereby enhancing its overall performance.
|
| 280 |
+
|
| 281 |
+
3. **Standardize Output Format**: We recommend using prompts to standardize model outputs when benchmarking.
|
| 282 |
+
- **Math Problems**: Include "Please reason step by step, and put your final answer within \boxed{}." in the prompt.
|
| 283 |
+
- **Multiple-Choice Questions**: Add the following JSON structure to the prompt to standardize responses: "Please show your choice in the `answer` field with only the choice letter, e.g., `"answer": "C"`."
|
| 284 |
+
|
| 285 |
+
4. **No Thinking Content in History**: In multi-turn conversations, the historical model output should only include the final output part and does not need to include the thinking content. It is implemented in the provided chat template in Jinja2. However, for frameworks that do not directly use the Jinja2 chat template, it is up to the developers to ensure that the best practice is followed.
|
| 286 |
+
|
| 287 |
+
### Citation
|
| 288 |
+
|
| 289 |
+
If you find our work helpful, feel free to give us a cite.
|
| 290 |
+
|
| 291 |
+
```
|
| 292 |
+
@misc{qwen3technicalreport,
|
| 293 |
+
title={Qwen3 Technical Report},
|
| 294 |
+
author={Qwen Team},
|
| 295 |
+
year={2025},
|
| 296 |
+
eprint={2505.09388},
|
| 297 |
+
archivePrefix={arXiv},
|
| 298 |
+
primaryClass={cs.CL},
|
| 299 |
+
url={https://arxiv.org/abs/2505.09388},
|
| 300 |
+
}
|
| 301 |
+
```
|
checkpoints/Qwen3-0.6B/config.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Qwen3ForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 151643,
|
| 8 |
+
"eos_token_id": 151645,
|
| 9 |
+
"head_dim": 128,
|
| 10 |
+
"hidden_act": "silu",
|
| 11 |
+
"hidden_size": 1024,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 3072,
|
| 14 |
+
"max_position_embeddings": 40960,
|
| 15 |
+
"max_window_layers": 28,
|
| 16 |
+
"model_type": "qwen3",
|
| 17 |
+
"num_attention_heads": 16,
|
| 18 |
+
"num_hidden_layers": 28,
|
| 19 |
+
"num_key_value_heads": 8,
|
| 20 |
+
"rms_norm_eps": 1e-06,
|
| 21 |
+
"rope_scaling": null,
|
| 22 |
+
"rope_theta": 1000000,
|
| 23 |
+
"sliding_window": null,
|
| 24 |
+
"tie_word_embeddings": true,
|
| 25 |
+
"torch_dtype": "bfloat16",
|
| 26 |
+
"transformers_version": "4.51.0",
|
| 27 |
+
"use_cache": true,
|
| 28 |
+
"use_sliding_window": false,
|
| 29 |
+
"vocab_size": 151936,
|
| 30 |
+
"magel_chord_dropout_trigger_prob": 0.6,
|
| 31 |
+
"magel_structure_dropout_trigger_prob": 0.6,
|
| 32 |
+
"magel_num_audio_token": 16384
|
| 33 |
+
}
|
checkpoints/Qwen3-0.6B/generation_config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 151643,
|
| 3 |
+
"do_sample": true,
|
| 4 |
+
"eos_token_id": [
|
| 5 |
+
151645,
|
| 6 |
+
151643
|
| 7 |
+
],
|
| 8 |
+
"pad_token_id": 151643,
|
| 9 |
+
"temperature": 0.6,
|
| 10 |
+
"top_k": 20,
|
| 11 |
+
"top_p": 0.95,
|
| 12 |
+
"transformers_version": "4.51.0"
|
| 13 |
+
}
|
checkpoints/Qwen3-0.6B/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
checkpoints/Qwen3-0.6B/tokenizer_config.json
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_prefix_space": false,
|
| 4 |
+
"added_tokens_decoder": {
|
| 5 |
+
"151643": {
|
| 6 |
+
"content": "<|endoftext|>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": false,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false,
|
| 11 |
+
"special": true
|
| 12 |
+
},
|
| 13 |
+
"151644": {
|
| 14 |
+
"content": "<|im_start|>",
|
| 15 |
+
"lstrip": false,
|
| 16 |
+
"normalized": false,
|
| 17 |
+
"rstrip": false,
|
| 18 |
+
"single_word": false,
|
| 19 |
+
"special": true
|
| 20 |
+
},
|
| 21 |
+
"151645": {
|
| 22 |
+
"content": "<|im_end|>",
|
| 23 |
+
"lstrip": false,
|
| 24 |
+
"normalized": false,
|
| 25 |
+
"rstrip": false,
|
| 26 |
+
"single_word": false,
|
| 27 |
+
"special": true
|
| 28 |
+
},
|
| 29 |
+
"151646": {
|
| 30 |
+
"content": "<|object_ref_start|>",
|
| 31 |
+
"lstrip": false,
|
| 32 |
+
"normalized": false,
|
| 33 |
+
"rstrip": false,
|
| 34 |
+
"single_word": false,
|
| 35 |
+
"special": true
|
| 36 |
+
},
|
| 37 |
+
"151647": {
|
| 38 |
+
"content": "<|object_ref_end|>",
|
| 39 |
+
"lstrip": false,
|
| 40 |
+
"normalized": false,
|
| 41 |
+
"rstrip": false,
|
| 42 |
+
"single_word": false,
|
| 43 |
+
"special": true
|
| 44 |
+
},
|
| 45 |
+
"151648": {
|
| 46 |
+
"content": "<|box_start|>",
|
| 47 |
+
"lstrip": false,
|
| 48 |
+
"normalized": false,
|
| 49 |
+
"rstrip": false,
|
| 50 |
+
"single_word": false,
|
| 51 |
+
"special": true
|
| 52 |
+
},
|
| 53 |
+
"151649": {
|
| 54 |
+
"content": "<|box_end|>",
|
| 55 |
+
"lstrip": false,
|
| 56 |
+
"normalized": false,
|
| 57 |
+
"rstrip": false,
|
| 58 |
+
"single_word": false,
|
| 59 |
+
"special": true
|
| 60 |
+
},
|
| 61 |
+
"151650": {
|
| 62 |
+
"content": "<|quad_start|>",
|
| 63 |
+
"lstrip": false,
|
| 64 |
+
"normalized": false,
|
| 65 |
+
"rstrip": false,
|
| 66 |
+
"single_word": false,
|
| 67 |
+
"special": true
|
| 68 |
+
},
|
| 69 |
+
"151651": {
|
| 70 |
+
"content": "<|quad_end|>",
|
| 71 |
+
"lstrip": false,
|
| 72 |
+
"normalized": false,
|
| 73 |
+
"rstrip": false,
|
| 74 |
+
"single_word": false,
|
| 75 |
+
"special": true
|
| 76 |
+
},
|
| 77 |
+
"151652": {
|
| 78 |
+
"content": "<|vision_start|>",
|
| 79 |
+
"lstrip": false,
|
| 80 |
+
"normalized": false,
|
| 81 |
+
"rstrip": false,
|
| 82 |
+
"single_word": false,
|
| 83 |
+
"special": true
|
| 84 |
+
},
|
| 85 |
+
"151653": {
|
| 86 |
+
"content": "<|vision_end|>",
|
| 87 |
+
"lstrip": false,
|
| 88 |
+
"normalized": false,
|
| 89 |
+
"rstrip": false,
|
| 90 |
+
"single_word": false,
|
| 91 |
+
"special": true
|
| 92 |
+
},
|
| 93 |
+
"151654": {
|
| 94 |
+
"content": "<|vision_pad|>",
|
| 95 |
+
"lstrip": false,
|
| 96 |
+
"normalized": false,
|
| 97 |
+
"rstrip": false,
|
| 98 |
+
"single_word": false,
|
| 99 |
+
"special": true
|
| 100 |
+
},
|
| 101 |
+
"151655": {
|
| 102 |
+
"content": "<|image_pad|>",
|
| 103 |
+
"lstrip": false,
|
| 104 |
+
"normalized": false,
|
| 105 |
+
"rstrip": false,
|
| 106 |
+
"single_word": false,
|
| 107 |
+
"special": true
|
| 108 |
+
},
|
| 109 |
+
"151656": {
|
| 110 |
+
"content": "<|video_pad|>",
|
| 111 |
+
"lstrip": false,
|
| 112 |
+
"normalized": false,
|
| 113 |
+
"rstrip": false,
|
| 114 |
+
"single_word": false,
|
| 115 |
+
"special": true
|
| 116 |
+
},
|
| 117 |
+
"151657": {
|
| 118 |
+
"content": "<tool_call>",
|
| 119 |
+
"lstrip": false,
|
| 120 |
+
"normalized": false,
|
| 121 |
+
"rstrip": false,
|
| 122 |
+
"single_word": false,
|
| 123 |
+
"special": false
|
| 124 |
+
},
|
| 125 |
+
"151658": {
|
| 126 |
+
"content": "</tool_call>",
|
| 127 |
+
"lstrip": false,
|
| 128 |
+
"normalized": false,
|
| 129 |
+
"rstrip": false,
|
| 130 |
+
"single_word": false,
|
| 131 |
+
"special": false
|
| 132 |
+
},
|
| 133 |
+
"151659": {
|
| 134 |
+
"content": "<|fim_prefix|>",
|
| 135 |
+
"lstrip": false,
|
| 136 |
+
"normalized": false,
|
| 137 |
+
"rstrip": false,
|
| 138 |
+
"single_word": false,
|
| 139 |
+
"special": false
|
| 140 |
+
},
|
| 141 |
+
"151660": {
|
| 142 |
+
"content": "<|fim_middle|>",
|
| 143 |
+
"lstrip": false,
|
| 144 |
+
"normalized": false,
|
| 145 |
+
"rstrip": false,
|
| 146 |
+
"single_word": false,
|
| 147 |
+
"special": false
|
| 148 |
+
},
|
| 149 |
+
"151661": {
|
| 150 |
+
"content": "<|fim_suffix|>",
|
| 151 |
+
"lstrip": false,
|
| 152 |
+
"normalized": false,
|
| 153 |
+
"rstrip": false,
|
| 154 |
+
"single_word": false,
|
| 155 |
+
"special": false
|
| 156 |
+
},
|
| 157 |
+
"151662": {
|
| 158 |
+
"content": "<|fim_pad|>",
|
| 159 |
+
"lstrip": false,
|
| 160 |
+
"normalized": false,
|
| 161 |
+
"rstrip": false,
|
| 162 |
+
"single_word": false,
|
| 163 |
+
"special": false
|
| 164 |
+
},
|
| 165 |
+
"151663": {
|
| 166 |
+
"content": "<|repo_name|>",
|
| 167 |
+
"lstrip": false,
|
| 168 |
+
"normalized": false,
|
| 169 |
+
"rstrip": false,
|
| 170 |
+
"single_word": false,
|
| 171 |
+
"special": false
|
| 172 |
+
},
|
| 173 |
+
"151664": {
|
| 174 |
+
"content": "<|file_sep|>",
|
| 175 |
+
"lstrip": false,
|
| 176 |
+
"normalized": false,
|
| 177 |
+
"rstrip": false,
|
| 178 |
+
"single_word": false,
|
| 179 |
+
"special": false
|
| 180 |
+
},
|
| 181 |
+
"151665": {
|
| 182 |
+
"content": "<tool_response>",
|
| 183 |
+
"lstrip": false,
|
| 184 |
+
"normalized": false,
|
| 185 |
+
"rstrip": false,
|
| 186 |
+
"single_word": false,
|
| 187 |
+
"special": false
|
| 188 |
+
},
|
| 189 |
+
"151666": {
|
| 190 |
+
"content": "</tool_response>",
|
| 191 |
+
"lstrip": false,
|
| 192 |
+
"normalized": false,
|
| 193 |
+
"rstrip": false,
|
| 194 |
+
"single_word": false,
|
| 195 |
+
"special": false
|
| 196 |
+
},
|
| 197 |
+
"151667": {
|
| 198 |
+
"content": "<think>",
|
| 199 |
+
"lstrip": false,
|
| 200 |
+
"normalized": false,
|
| 201 |
+
"rstrip": false,
|
| 202 |
+
"single_word": false,
|
| 203 |
+
"special": false
|
| 204 |
+
},
|
| 205 |
+
"151668": {
|
| 206 |
+
"content": "</think>",
|
| 207 |
+
"lstrip": false,
|
| 208 |
+
"normalized": false,
|
| 209 |
+
"rstrip": false,
|
| 210 |
+
"single_word": false,
|
| 211 |
+
"special": false
|
| 212 |
+
}
|
| 213 |
+
},
|
| 214 |
+
"additional_special_tokens": [
|
| 215 |
+
"<|im_start|>",
|
| 216 |
+
"<|im_end|>",
|
| 217 |
+
"<|object_ref_start|>",
|
| 218 |
+
"<|object_ref_end|>",
|
| 219 |
+
"<|box_start|>",
|
| 220 |
+
"<|box_end|>",
|
| 221 |
+
"<|quad_start|>",
|
| 222 |
+
"<|quad_end|>",
|
| 223 |
+
"<|vision_start|>",
|
| 224 |
+
"<|vision_end|>",
|
| 225 |
+
"<|vision_pad|>",
|
| 226 |
+
"<|image_pad|>",
|
| 227 |
+
"<|video_pad|>"
|
| 228 |
+
],
|
| 229 |
+
"bos_token": null,
|
| 230 |
+
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
|
| 231 |
+
"clean_up_tokenization_spaces": false,
|
| 232 |
+
"eos_token": "<|im_end|>",
|
| 233 |
+
"errors": "replace",
|
| 234 |
+
"model_max_length": 131072,
|
| 235 |
+
"pad_token": "<|endoftext|>",
|
| 236 |
+
"split_special_tokens": false,
|
| 237 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 238 |
+
"unk_token": null
|
| 239 |
+
}
|
checkpoints/Qwen3-0.6B/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|