Epsilon617 committed on
Commit
283e8f1
1 Parent(s): 2861940
Files changed (41)
  1. MERT-v1-95M/README.md +121 -0
  2. MERT-v1-95M/__pycache__/configuration_MERT.cpython-310.pyc +0 -0
  3. MERT-v1-95M/__pycache__/modeling_MERT.cpython-310.pyc +0 -0
  4. MERT-v1-95M/config.json +85 -0
  5. MERT-v1-95M/configuration_MERT.py +138 -0
  6. MERT-v1-95M/modeling_MERT.py +409 -0
  7. MERT-v1-95M/preprocessor_config.json +9 -0
  8. MERT-v1-95M/pytorch_model.bin +3 -0
  9. Prediction_Head/MTGGenre_head.py +48 -0
  10. Prediction_Head/MTGGenre_id2class.json +1 -0
  11. Prediction_Head/__MACOSX/._best-layer-MERT-v1-95M +0 -0
  12. Prediction_Head/__pycache__/MTGGenre_head.cpython-310.pyc +0 -0
  13. Prediction_Head/best-layer-MERT-v1-95M.zip +3 -0
  14. Prediction_Head/best-layer-MERT-v1-95M/EMO.ckpt +3 -0
  15. Prediction_Head/best-layer-MERT-v1-95M/EMO.id2class.json +1 -0
  16. Prediction_Head/best-layer-MERT-v1-95M/GS.ckpt +3 -0
  17. Prediction_Head/best-layer-MERT-v1-95M/GS.id2class.json +1 -0
  18. Prediction_Head/best-layer-MERT-v1-95M/GTZAN.ckpt +3 -0
  19. Prediction_Head/best-layer-MERT-v1-95M/GTZAN.id2class.json +0 -0
  20. Prediction_Head/best-layer-MERT-v1-95M/MTGGenre.ckpt +3 -0
  21. Prediction_Head/best-layer-MERT-v1-95M/MTGGenre.id2class.json +1 -0
  22. Prediction_Head/best-layer-MERT-v1-95M/MTGInstrument.ckpt +3 -0
  23. Prediction_Head/best-layer-MERT-v1-95M/MTGInstrument.id2class.json +1 -0
  24. Prediction_Head/best-layer-MERT-v1-95M/MTGMood.ckpt +3 -0
  25. Prediction_Head/best-layer-MERT-v1-95M/MTGMood.id2class.json +1 -0
  26. Prediction_Head/best-layer-MERT-v1-95M/MTGTop50.ckpt +3 -0
  27. Prediction_Head/best-layer-MERT-v1-95M/MTGTop50.id2class.json +1 -0
  28. Prediction_Head/best-layer-MERT-v1-95M/MTT.id2class.json +0 -0
  29. Prediction_Head/best-layer-MERT-v1-95M/NSynthI.ckpt +3 -0
  30. Prediction_Head/best-layer-MERT-v1-95M/NSynthI.id2class.json +1 -0
  31. Prediction_Head/best-layer-MERT-v1-95M/NSynthP.ckpt +3 -0
  32. Prediction_Head/best-layer-MERT-v1-95M/NSynthP.id2class.json +1 -0
  33. Prediction_Head/best-layer-MERT-v1-95M/VocalSetS.ckpt +3 -0
  34. Prediction_Head/best-layer-MERT-v1-95M/VocalSetS.id2class.json +1 -0
  35. Prediction_Head/best-layer-MERT-v1-95M/VocalSetT.ckpt +3 -0
  36. Prediction_Head/best-layer-MERT-v1-95M/VocalSetT.id2class.json +1 -0
  37. Prediction_Head/best_MTGGenre.ckpt +3 -0
  38. README.md +5 -5
  39. __pycache__/app.cpython-310.pyc +0 -0
  40. app.py +207 -0
  41. requirements.txt +88 -0
MERT-v1-95M/README.md ADDED
@@ -0,0 +1,121 @@
1
+ ---
2
+ license: mit
3
+ inference: false
4
+ tags:
5
+ - music
6
+ ---
7
+
8
+ # Introduction to our series work
9
+
10
+ The development log of our Music Audio Pre-training (m-a-p) model family:
11
+ - 17/03/2023: we release two advanced music understanding models, [MERT-v1-95M](https://huggingface.co/m-a-p/MERT-v1-95M) and [MERT-v1-330M](https://huggingface.co/m-a-p/MERT-v1-330M), trained with a new paradigm and dataset. They outperform the previous models and generalize better to more tasks.
12
+ - 14/03/2023: we retrained the MERT-v0 model with an open-source-only music dataset: [MERT-v0-public](https://huggingface.co/m-a-p/MERT-v0-public)
13
+ - 29/12/2022: a music understanding model [MERT-v0](https://huggingface.co/m-a-p/MERT-v0) trained with the **MLM** paradigm, which performs better on downstream tasks.
14
+ - 29/10/2022: a pre-trained MIR model [music2vec](https://huggingface.co/m-a-p/music2vec-v1) trained with the **BYOL** paradigm.
15
+
16
+
17
+
18
+ Here is a table for quick model selection:
19
+
20
+ | Name | Pre-train Paradigm | Training Data (hour) | Pre-train Context (second) | Model Size | Transformer Layer-Dimension | Feature Rate | Sample Rate | Release Date |
21
+ | ------------------------------------------------------------ | ------------------ | -------------------- | ---------------------------- | ---------- | --------------------------- | ------------ | ----------- | ------------ |
22
+ | [MERT-v1-330M](https://huggingface.co/m-a-p/MERT-v1-330M) | MLM | 160K | 5 | 330M | 24-1024 | 75 Hz | 24 kHz | 17/03/2023 |
23
+ | [MERT-v1-95M](https://huggingface.co/m-a-p/MERT-v1-95M) | MLM | 20K | 5 | 95M | 12-768 | 75 Hz | 24 kHz | 17/03/2023 |
24
+ | [MERT-v0-public](https://huggingface.co/m-a-p/MERT-v0-public) | MLM | 900 | 5 | 95M | 12-768 | 50 Hz | 16 kHz | 14/03/2023 |
25
+ | [MERT-v0](https://huggingface.co/m-a-p/MERT-v0) | MLM | 1000 | 5 | 95M | 12-768 | 50 Hz | 16 kHz | 29/12/2022 |
26
+ | [music2vec-v1](https://huggingface.co/m-a-p/music2vec-v1) | BYOL | 1000 | 30 | 95M | 12-768 | 50 Hz | 16 kHz | 30/10/2022 |
27
+
28
+ ## Explanation
29
+
30
+ The m-a-p models share a similar model architecture; the main difference is the pre-training paradigm. Beyond that, there are several technical details to be aware of before use:
31
+
32
+ - **Model Size**: the number of parameters loaded into memory. Please select the size appropriate for your hardware.
33
+ - **Transformer Layer-Dimension**: The number of transformer layers and the corresponding feature dimension that our model can output. This is highlighted because features extracted from **different layers can perform differently depending on the task**.
34
+ - **Feature Rate**: Given a 1-second audio input, the number of feature frames output by the model (a short sketch relating these numbers follows this list).
35
+ - **Sample Rate**: The sampling rate of the audio the model was trained on.
36
+
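As a quick illustration of how these numbers fit together, here is a minimal sketch using the MERT-v1-95M row of the table above (the exact frame count may differ by a frame or two because of convolutional padding):

```python
# Back-of-envelope check using the MERT-v1-95M values from the table above.
sample_rate = 24_000              # 24 kHz input audio
feature_rate = 75                 # feature frames per second of audio
num_layers, hidden_dim = 12, 768  # "12-768" Transformer Layer-Dimension

clip_seconds = 5                  # pre-training context length
print(clip_seconds * feature_rate)   # ~375 feature frames for a 5-second clip
print(num_layers + 1, hidden_dim)    # 13 hidden states (embeddings + 12 layers), each 768-d
```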
37
+
38
+
39
+ # Introduction to MERT-v1
40
+
41
+ Compared to MERT-v0, we introduce several changes in the MERT-v1 pre-training:
42
+
43
+ - Changed the pseudo labels to 8 codebooks from [encodec](https://github.com/facebookresearch/encodec), which potentially provide higher-quality targets and empower our model to support music generation.
44
+ - MLM prediction with in-batch noise mixture.
45
+ - Trained with a higher audio sampling rate (24 kHz).
46
+ - Trained with more audio data (up to 160 thousand hours).
47
+ - More available model sizes: 95M and 330M.
48
+
49
+
50
+
51
+ More details will be provided in our upcoming paper.
52
+
53
+
54
+
55
+ # Model Usage
56
+
57
+ ```python
58
+ # from transformers import Wav2Vec2Processor
59
+ from transformers import Wav2Vec2FeatureExtractor
60
+ from transformers import AutoModel
61
+ import torch
62
+ from torch import nn
63
+ import torchaudio.transforms as T
64
+ from datasets import load_dataset
65
+
66
+
67
+ # loading our model weights
68
+ model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
69
+ # loading the corresponding preprocessor config
70
+ processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
71
+
72
+ # load demo audio and set processor
73
+ dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
74
+ dataset = dataset.sort("id")
75
+ sampling_rate = dataset.features["audio"].sampling_rate
76
+
77
+ resample_rate = processor.sampling_rate
78
+ # make sure the sample rates are aligned
79
+ if resample_rate != sampling_rate:
80
+ print(f'setting rate from {sampling_rate} to {resample_rate}')
81
+ resampler = T.Resample(sampling_rate, resample_rate)
82
+ else:
83
+ resampler = None
84
+
85
+ # audio file is decoded on the fly
86
+ if resampler is None:
87
+ input_audio = dataset[0]["audio"]["array"]
88
+ else:
89
+ input_audio = resampler(torch.from_numpy(dataset[0]["audio"]["array"]))
90
+
91
+ inputs = processor(input_audio, sampling_rate=resample_rate, return_tensors="pt")
92
+ with torch.no_grad():
93
+ outputs = model(**inputs, output_hidden_states=True)
94
+
95
+ # take a look at the output shape, there are 13 layers of representation
96
+ # each layer performs differently in different downstream tasks, you should choose empirically
97
+ all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
98
+ print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim]
99
+
100
+ # for utterance level classification tasks, you can simply reduce the representation in time
101
+ time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
102
+ print(time_reduced_hidden_states.shape) # [13, 768]
103
+
104
+ # you can even use a learnable weighted average representation
105
+ aggregator = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
106
+ weighted_avg_hidden_states = aggregator(time_reduced_hidden_states.unsqueeze(0)).squeeze()
107
+ print(weighted_avg_hidden_states.shape) # [768]
108
+ ```
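The snippet above processes a short clip in a single pass. For full-length songs, one common strategy (a sketch, not an official recipe; it reuses `model`, `processor`, `resample_rate`, and the 1-D `input_audio` from the snippet above) is to split the waveform into 5-second chunks, matching the pre-training context, and average the chunk-level features:

```python
# Minimal sketch for longer recordings: chunk the waveform into 5-second pieces
# and average the time-reduced features across chunks.
chunk_size = 5 * resample_rate  # 5 seconds at the processor's sampling rate
chunk_features = []
for start in range(0, len(input_audio), chunk_size):
    chunk = input_audio[start:start + chunk_size]
    if len(chunk) < resample_rate:  # skip fragments shorter than 1 second
        continue
    chunk_inputs = processor(chunk, sampling_rate=resample_rate, return_tensors="pt")
    with torch.no_grad():
        chunk_outputs = model(**chunk_inputs, output_hidden_states=True)
    # [13, time, 768] -> [13, 768]
    chunk_features.append(torch.stack(chunk_outputs.hidden_states).squeeze().mean(-2))

song_level_features = torch.stack(chunk_features).mean(0)  # [13, 768]
print(song_level_features.shape)
```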
109
+
110
+
111
+
112
+ # Citation
113
+
114
+ ```bibtex
115
+ @article{li2022large,
116
+ title={Large-Scale Pretrained Model for Self-Supervised Music Audio Representation Learning},
117
+ author={Li, Yizhi and Yuan, Ruibin and Zhang, Ge and Ma, Yinghao and Lin, Chenghua and Chen, Xingran and Ragni, Anton and Yin, Hanzhi and Hu, Zhijie and He, Haoyu and others},
118
+ year={2022}
119
+ }
120
+
121
+ ```
MERT-v1-95M/__pycache__/configuration_MERT.cpython-310.pyc ADDED
Binary file (3.37 kB).
 
MERT-v1-95M/__pycache__/modeling_MERT.cpython-310.pyc ADDED
Binary file (10.3 kB).
 
MERT-v1-95M/config.json ADDED
@@ -0,0 +1,85 @@
1
+ {
2
+ "_name_or_path": "m-a-p/MERT-v1-95M",
3
+ "activation_dropout": 0.0,
4
+ "apply_spec_augment": true,
5
+ "architectures": [
6
+ "MERTModel"
7
+ ],
8
+ "attention_dropout": 0.1,
9
+ "attention_relax": -1.0,
10
+ "auto_map": {
11
+ "AutoConfig": "configuration_MERT.MERTConfig",
12
+ "AutoModel": "modeling_MERT.MERTModel"
13
+ },
14
+ "bos_token_id": 1,
15
+ "classifier_proj_size": 256,
16
+ "conv_bias": false,
17
+ "conv_dim": [
18
+ 512,
19
+ 512,
20
+ 512,
21
+ 512,
22
+ 512,
23
+ 512,
24
+ 512
25
+ ],
26
+ "conv_kernel": [
27
+ 10,
28
+ 3,
29
+ 3,
30
+ 3,
31
+ 3,
32
+ 2,
33
+ 2
34
+ ],
35
+ "conv_stride": [
36
+ 5,
37
+ 2,
38
+ 2,
39
+ 2,
40
+ 2,
41
+ 2,
42
+ 2
43
+ ],
44
+ "ctc_loss_reduction": "sum",
45
+ "ctc_zero_infinity": false,
46
+ "deepnorm": false,
47
+ "do_stable_layer_norm": false,
48
+ "eos_token_id": 2,
49
+ "feat_extract_activation": "gelu",
50
+ "feat_extract_dropout": 0.0,
51
+ "feat_extract_norm": "group",
52
+ "feat_proj_dropout": 0.1,
53
+ "feat_proj_layer_norm": true,
54
+ "feature_extractor_cqt": false,
55
+ "feature_extractor_cqt_bins": 336,
56
+ "final_dropout": 0.1,
57
+ "gradient_checkpointing": false,
58
+ "hidden_act": "gelu",
59
+ "hidden_dropout": 0.1,
60
+ "hidden_dropout_prob": 0.1,
61
+ "hidden_size": 768,
62
+ "initializer_range": 0.02,
63
+ "intermediate_size": 3072,
64
+ "layer_norm_eps": 1e-05,
65
+ "layerdrop": 0.05,
66
+ "mask_feature_length": 10,
67
+ "mask_feature_min_masks": 0,
68
+ "mask_feature_prob": 0.0,
69
+ "mask_time_length": 10,
70
+ "mask_time_min_masks": 2,
71
+ "mask_time_prob": 0.05,
72
+ "model_type": "mert_model",
73
+ "num_attention_heads": 12,
74
+ "num_conv_pos_embedding_groups": 16,
75
+ "num_conv_pos_embeddings": 128,
76
+ "num_feat_extract_layers": 7,
77
+ "num_hidden_layers": 12,
78
+ "pad_token_id": 0,
79
+ "sample_rate": 24000,
80
+ "tokenizer_class": "Wav2Vec2CTCTokenizer",
81
+ "torch_dtype": "float32",
82
+ "transformers_version": "4.24.0",
83
+ "use_weighted_layer_sum": false,
84
+ "vocab_size": 32
85
+ }
MERT-v1-95M/configuration_MERT.py ADDED
@@ -0,0 +1,138 @@
1
+ """
2
+ MERT model configuration
3
+ """
4
+
5
+ import functools
6
+ import operator
7
+
8
+ # from ...configuration_utils import PretrainedConfig
9
+ # from ...utils import logging
10
+ from transformers.configuration_utils import PretrainedConfig
11
+ from transformers.utils import logging
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+ # TODO: use this MAP while uploading to Huggingface
16
+ # HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
17
+ # "facebook/hubert-base-ls960": "https://huggingface.co/facebook/hubert-base-ls960/resolve/main/config.json",
18
+ # # See all Hubert models at https://huggingface.co/models?filter=hubert
19
+ # }
20
+
21
+
22
+ class MERTConfig(PretrainedConfig):
23
+ r"""
24
+ """
25
+ model_type = "mert_model"
26
+
27
+ def __init__(
28
+ self,
29
+ vocab_size=32,
30
+ hidden_size=768,
31
+ num_hidden_layers=12,
32
+ num_attention_heads=12,
33
+ intermediate_size=3072,
34
+ hidden_act="gelu",
35
+ hidden_dropout=0.1,
36
+ activation_dropout=0.1,
37
+ attention_dropout=0.1,
38
+ feat_proj_layer_norm=True,
39
+ feat_proj_dropout=0.0,
40
+ final_dropout=0.1,
41
+ layerdrop=0.1,
42
+ initializer_range=0.02,
43
+ layer_norm_eps=1e-5,
44
+ feat_extract_norm="group",
45
+ feat_extract_activation="gelu",
46
+ conv_dim=(512, 512, 512, 512, 512, 512, 512),
47
+ conv_stride=(5, 2, 2, 2, 2, 2, 2),
48
+ conv_kernel=(10, 3, 3, 3, 3, 2, 2),
49
+ conv_bias=False,
50
+ num_conv_pos_embeddings=128,
51
+ num_conv_pos_embedding_groups=16,
52
+ do_stable_layer_norm=False,
53
+ apply_spec_augment=True,
54
+ mask_time_prob=0.05,
55
+ mask_time_length=10,
56
+ mask_time_min_masks=2,
57
+ mask_feature_prob=0.0,
58
+ mask_feature_length=10,
59
+ mask_feature_min_masks=0,
60
+ ctc_loss_reduction="sum",
61
+ ctc_zero_infinity=False,
62
+ use_weighted_layer_sum=False,
63
+ classifier_proj_size=256,
64
+ pad_token_id=0,
65
+ bos_token_id=1,
66
+ eos_token_id=2,
67
+ feature_extractor_cqt=False,
68
+ feature_extractor_cqt_bins=336,
69
+ deepnorm=False,
70
+ attention_relax=-1.0,
71
+ **kwargs
72
+ ):
73
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
74
+ self.hidden_size = hidden_size
75
+ self.feat_extract_norm = feat_extract_norm
76
+ self.feat_extract_activation = feat_extract_activation
77
+ self.conv_dim = list(conv_dim)
78
+ self.conv_stride = list(conv_stride)
79
+ self.conv_kernel = list(conv_kernel)
80
+ self.conv_bias = conv_bias
81
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
82
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
83
+ self.num_feat_extract_layers = len(self.conv_dim)
84
+ self.num_hidden_layers = num_hidden_layers
85
+ self.intermediate_size = intermediate_size
86
+ self.hidden_act = hidden_act
87
+ self.num_attention_heads = num_attention_heads
88
+ self.hidden_dropout = hidden_dropout
89
+ self.attention_dropout = attention_dropout
90
+ self.activation_dropout = activation_dropout
91
+ self.feat_proj_layer_norm = feat_proj_layer_norm
92
+ self.feat_proj_dropout = feat_proj_dropout
93
+ self.final_dropout = final_dropout
94
+ self.layerdrop = layerdrop
95
+ self.layer_norm_eps = layer_norm_eps
96
+ self.initializer_range = initializer_range
97
+ self.vocab_size = vocab_size
98
+ self.do_stable_layer_norm = do_stable_layer_norm
99
+ self.use_weighted_layer_sum = use_weighted_layer_sum
100
+ self.classifier_proj_size = classifier_proj_size
101
+
102
+ if (
103
+ (len(self.conv_stride) != self.num_feat_extract_layers)
104
+ or (len(self.conv_kernel) != self.num_feat_extract_layers)
105
+ or (len(self.conv_dim) != self.num_feat_extract_layers)
106
+ ):
107
+ raise ValueError(
108
+ "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
109
+ " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
110
+ f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
111
+ f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
112
+ )
113
+
114
+ # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
115
+ self.apply_spec_augment = apply_spec_augment
116
+ self.mask_time_prob = mask_time_prob
117
+ self.mask_time_length = mask_time_length
118
+ self.mask_time_min_masks = mask_time_min_masks
119
+ self.mask_feature_prob = mask_feature_prob
120
+ self.mask_feature_length = mask_feature_length
121
+ self.mask_feature_min_masks = mask_feature_min_masks
122
+
123
+ # ctc loss
124
+ self.ctc_loss_reduction = ctc_loss_reduction
125
+ self.ctc_zero_infinity = ctc_zero_infinity
126
+
127
+ # cqt feature extractor
128
+ self.feature_extractor_cqt = feature_extractor_cqt
129
+ self.feature_extractor_cqt_bins = feature_extractor_cqt_bins
130
+
131
+ # deepnorm: up-scale the weighted residual connection + down-scale the initialization of the transformer encoder
132
+ self.deepnorm = deepnorm
133
+
134
+ self.attention_relax = attention_relax
135
+
136
+ @property
137
+ def inputs_to_logits_ratio(self):
138
+ return functools.reduce(operator.mul, self.conv_stride, 1)
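For reference, a small sketch (not part of the committed file) showing what `inputs_to_logits_ratio` evaluates to with the default strides, and how it relates to the 75 Hz feature rate quoted in the model card:

```python
# Illustrative check: the default conv_stride implies a downsampling ratio of
# 5 * 2**6 = 320, so 24 kHz audio yields 24000 / 320 = 75 feature frames per second.
from functools import reduce
import operator

conv_stride = (5, 2, 2, 2, 2, 2, 2)           # default in MERTConfig
ratio = reduce(operator.mul, conv_stride, 1)  # what inputs_to_logits_ratio computes
print(ratio)          # 320
print(24000 / ratio)  # 75.0 frames per second
```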
MERT-v1-95M/modeling_MERT.py ADDED
@@ -0,0 +1,409 @@
1
+ """
2
+ MERT model definition.
3
+ We largely adapt codes from:
4
+ 1. https://github.com/huggingface/transformers/blob/main/src/transformers/models/hubert/modeling_hubert.py
5
+ 2. https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
6
+ """
7
+
8
+ from typing import Optional, Tuple, Union
9
+ from transformers.modeling_outputs import BaseModelOutput
10
+ import torch
11
+ from torch import nn
12
+
13
+ from transformers.models.hubert.modeling_hubert import (
14
+ HubertFeatureEncoder,
15
+ HubertModel,
16
+ HubertEncoderStableLayerNorm,
17
+ HubertEncoder,
18
+ HubertEncoderLayer,
19
+ HubertPositionalConvEmbedding,
20
+ HubertAttention,
21
+ HubertFeedForward,
22
+ )
23
+
24
+ try:
25
+ from nnAudio import features as nnAudioFeatures
26
+ NNAUDIO_INSTALLED=True
27
+ except ImportError:
28
+ print("WARNING: feature_extractor_cqt requires the libray 'nnAudio'")
29
+ NNAUDIO_INSTALLED=False
30
+
31
+ from .configuration_MERT import MERTConfig
32
+
33
+ class MERTFeatureProjection(nn.Module):
34
+ def __init__(self, config):
35
+ super().__init__()
36
+ self.feat_proj_layer_norm = config.feat_proj_layer_norm
37
+ self.feature_extractor_cqt = config.feature_extractor_cqt
38
+
39
+ if self.feature_extractor_cqt:
40
+ # v3 concat features
41
+ self.feature_dimension = config.conv_dim[-1] + config.feature_extractor_cqt_bins
42
+ print(f"feature dimention: {self.feature_dimension}")
43
+ else:
44
+ self.feature_dimension = config.conv_dim[-1]
45
+ if self.feat_proj_layer_norm:
46
+ self.layer_norm = nn.LayerNorm(self.feature_dimension, eps=config.layer_norm_eps)
47
+ self.projection = nn.Linear(self.feature_dimension, config.hidden_size)
48
+ self.dropout = nn.Dropout(config.feat_proj_dropout)
49
+
50
+ def forward(self, hidden_states):
51
+ # non-projected hidden states are needed for quantization
52
+ if self.feat_proj_layer_norm:
53
+ hidden_states = self.layer_norm(hidden_states)
54
+ hidden_states = self.projection(hidden_states)
55
+ hidden_states = self.dropout(hidden_states)
56
+ return hidden_states
57
+
58
+ class MERTModel(HubertModel):
59
+ # overwrite config class
60
+ config_class = MERTConfig
61
+ base_model_prefix = "mert_model"
62
+ def __init__(
63
+ self,
64
+ config: MERTConfig,
65
+ ) -> None:
66
+ """
67
+ initialize with the grandparent method HubertPreTrainedModel.__init__()
68
+ and modify HubertModel.__init__()
69
+ """
70
+ super(HubertModel, self).__init__(config)
71
+
72
+ self.config = config
73
+
74
+ self.feature_extractor = HubertFeatureEncoder(config)
75
+ self.feature_projection = MERTFeatureProjection(config) # replace the feature projection to accommodate the (optionally CQT-augmented) features
76
+
77
+ if self.config.feature_extractor_cqt:
78
+ assert NNAUDIO_INSTALLED, "ERROR: feature_extractor_cqt requires the library 'nnAudio', try again after `pip install nnAudio`"
79
+ print('initializing cqt extractor for MERT')
80
+ self.feature_extractor_cqt = nnAudioFeatures.cqt.CQT(sr=self.config.sample_rate, hop_length=self.config.sample_rate//50, fmin=32.7,
81
+ fmax=None, n_bins=self.config.feature_extractor_cqt_bins, bins_per_octave=self.config.feature_extractor_cqt_bins//7,
82
+ filter_scale=1, norm=1, window='hann', center=True,
83
+ pad_mode='constant', trainable=False,
84
+ output_format='Magnitude', verbose=True)
85
+
86
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
87
+ self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
88
+
89
+
90
+ if config.do_stable_layer_norm:
91
+ assert not config.deepnorm, "must use post-layer_norm with deepnorm"
92
+ self.encoder = HubertEncoderStableLayerNorm(config)
93
+ else:
94
+ if config.deepnorm:
95
+ self.encoder = HubertEncoder_extend(config)
96
+ else:
97
+ self.encoder = HubertEncoder(config)
98
+
99
+ # Initialize weights and apply final processing
100
+ self.post_init()
101
+
102
+ def forward(self, input_values: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, mask_time_indices: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None) -> Union[Tuple, BaseModelOutput]:
103
+
104
+ # return super().forward(input_values, attention_mask, mask_time_indices, output_attentions, output_hidden_states, return_dict)
105
+
106
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
107
+ output_hidden_states = (
108
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
109
+ )
110
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
111
+
112
+ extract_features = self.feature_extractor(input_values)
113
+ extract_features = extract_features.transpose(1, 2)
114
+
115
+ # add additional cqt features for transformer input
116
+ if self.config.feature_extractor_cqt:
117
+ features_cqt = self.feature_extractor_cqt(input_values).transpose(1, 2)
118
+ features_cqt = features_cqt[:,:extract_features.shape[1],:] # align shape
119
+ # # v2
120
+ # features_cqt = self.post_cqt_feature_proj(features_cqt)
121
+ # extract_features = self.feature_projection.layer_norm(extract_features) + self.feature_projection.layer_norm(features_cqt) #v2
122
+ # v3
123
+ extract_features = torch.cat([extract_features,features_cqt], 2)
124
+
125
+ if attention_mask is not None:
126
+ # compute reduced attention_mask corresponding to feature vectors
127
+ attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
128
+
129
+ hidden_states = self.feature_projection(extract_features)
130
+ hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
131
+
132
+ encoder_outputs = self.encoder(
133
+ hidden_states,
134
+ attention_mask=attention_mask,
135
+ output_attentions=output_attentions,
136
+ output_hidden_states=output_hidden_states,
137
+ return_dict=return_dict,
138
+ )
139
+
140
+ hidden_states = encoder_outputs[0] # take last_hidden from encoder output
141
+
142
+ if not return_dict:
143
+ return (hidden_states,) + encoder_outputs[1:]
144
+
145
+ return BaseModelOutput(
146
+ last_hidden_state=hidden_states,
147
+ hidden_states=encoder_outputs.hidden_states,
148
+ attentions=encoder_outputs.attentions,
149
+ )
150
+
151
+
152
+ class HubertEncoder_extend(HubertEncoder):
153
+ def __init__(self, config):
154
+ # super().__init__()
155
+ # call nn module initialization
156
+ nn.Module.__init__(self)
157
+ # super(HubertEncoder_extend, self).__init__()
158
+
159
+ self.config = config
160
+ self.pos_conv_embed = HubertPositionalConvEmbedding(config)
161
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
162
+ self.dropout = nn.Dropout(config.hidden_dropout)
163
+
164
+
165
+ self.layers = nn.ModuleList([HubertEncoderLayerExtend(config) for _ in range(config.num_hidden_layers)])
166
+
167
+ self.gradient_checkpointing = False
168
+
169
+ if config.deepnorm:
170
+ import math
171
+ init_scale = math.pow(8.0 * config.num_hidden_layers, 0.25)
172
+ for name, p in self.named_parameters():
173
+ if (
174
+ "feed_forward.intermediate_dense" in name
175
+ or "feed_forward.output_dense" in name
176
+ or "out_proj" in name
177
+ or "v_proj" in name
178
+ ):
179
+ p.data.div_(init_scale)
180
+
181
+ class HubertEncoderLayerExtend(HubertEncoderLayer):
182
+ def __init__(self, config):
183
+ nn.Module.__init__(self)
184
+ # super(HubertEncoderLayerExtend, self).__init__()
185
+ if config.attention_relax > 0 :
186
+ self.attention = HubertAttention_extend(
187
+ embed_dim=config.hidden_size,
188
+ num_heads=config.num_attention_heads,
189
+ dropout=config.attention_dropout,
190
+ is_decoder=False,
191
+ attention_relax=config.attention_relax,
192
+ )
193
+ else:
194
+ self.attention = HubertAttention(
195
+ embed_dim=config.hidden_size,
196
+ num_heads=config.num_attention_heads,
197
+ dropout=config.attention_dropout,
198
+ is_decoder=False,
199
+ )
200
+ self.dropout = nn.Dropout(config.hidden_dropout)
201
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
202
+ self.feed_forward = HubertFeedForward(config)
203
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
204
+
205
+ if config.deepnorm:
206
+ import math
207
+ self.residual_alpha = math.pow(2.0 * config.num_hidden_layers, 0.25)
208
+ else:
209
+ self.residual_alpha = 1.0
210
+
211
+ def residual_connection(self, x, residual):
212
+ '''
213
+ residual: input before f()
214
+ x: output of f(residual)
215
+ '''
216
+ return residual * self.residual_alpha + x
217
+
218
+ def forward(self, hidden_states, attention_mask=None, output_attentions=False):
219
+ attn_residual = hidden_states
220
+ hidden_states, attn_weights, _ = self.attention(
221
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
222
+ )
223
+ hidden_states = self.dropout(hidden_states)
224
+
225
+ # hidden_states = attn_residual + hidden_states
226
+ hidden_states = self.residual_connection(hidden_states, attn_residual)
227
+
228
+ hidden_states = self.layer_norm(hidden_states)
229
+
230
+ # hidden_states = hidden_states + self.feed_forward(hidden_states)
231
+ ffn_residual = hidden_states
232
+ hidden_states = self.feed_forward(hidden_states)
233
+ hidden_states = self.residual_connection(hidden_states, ffn_residual)
234
+
235
+ hidden_states = self.final_layer_norm(hidden_states)
236
+
237
+ outputs = (hidden_states,)
238
+
239
+ if output_attentions:
240
+ outputs += (attn_weights,)
241
+
242
+ return outputs
243
+
244
+
245
+ class HubertAttention_extend(nn.Module):
246
+ def __init__(
247
+ self,
248
+ embed_dim: int,
249
+ num_heads: int,
250
+ dropout: float = 0.0,
251
+ is_decoder: bool = False,
252
+ bias: bool = True,
253
+ attention_relax: float = -1.0,
254
+ ):
255
+ super().__init__()
256
+ # nn.Module.__init__(self)
257
+ self.embed_dim = embed_dim
258
+ self.num_heads = num_heads
259
+ self.dropout = dropout
260
+ self.head_dim = embed_dim // num_heads
261
+
262
+ if (self.head_dim * num_heads) != self.embed_dim:
263
+ raise ValueError(
264
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
265
+ f" and `num_heads`: {num_heads})."
266
+ )
267
+ self.scaling = self.head_dim**-0.5
268
+ self.is_decoder = is_decoder
269
+
270
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
271
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
272
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
273
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
274
+
275
+ if attention_relax > 0:
276
+ self.attention_relax = attention_relax
277
+
278
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
279
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
280
+
281
+ def forward(
282
+ self,
283
+ hidden_states: torch.Tensor,
284
+ key_value_states: Optional[torch.Tensor] = None,
285
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
286
+ attention_mask: Optional[torch.Tensor] = None,
287
+ layer_head_mask: Optional[torch.Tensor] = None,
288
+ output_attentions: bool = False,
289
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
290
+ """Input shape: Batch x Time x Channel"""
291
+
292
+ # if key_value_states are provided this layer is used as a cross-attention layer
293
+ # for the decoder
294
+ is_cross_attention = key_value_states is not None
295
+
296
+ bsz, tgt_len, _ = hidden_states.size()
297
+
298
+ # get query proj
299
+ query_states = self.q_proj(hidden_states) * self.scaling
300
+ # get key, value proj
301
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
302
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
303
+ # the provided `key_value_states` to support prefix tuning
304
+ if (
305
+ is_cross_attention
306
+ and past_key_value is not None
307
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
308
+ ):
309
+ # reuse k,v, cross_attentions
310
+ key_states = past_key_value[0]
311
+ value_states = past_key_value[1]
312
+ elif is_cross_attention:
313
+ # cross_attentions
314
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
315
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
316
+ elif past_key_value is not None:
317
+ # reuse k, v, self_attention
318
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
319
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
320
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
321
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
322
+ else:
323
+ # self_attention
324
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
325
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
326
+
327
+ if self.is_decoder:
328
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
329
+ # Further calls to cross_attention layer can then reuse all cross-attention
330
+ # key/value_states (first "if" case)
331
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
332
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
333
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
334
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
335
+ past_key_value = (key_states, value_states)
336
+
337
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
338
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
339
+ key_states = key_states.view(*proj_shape)
340
+ value_states = value_states.view(*proj_shape)
341
+
342
+ src_len = key_states.size(1)
343
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
344
+
345
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
346
+ raise ValueError(
347
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
348
+ f" {attn_weights.size()}"
349
+ )
350
+
351
+ if attention_mask is not None:
352
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
353
+ raise ValueError(
354
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
355
+ )
356
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
357
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
358
+
359
+ if self.attention_relax > 0:
360
+ # => (bsz, self.num_heads, tgt_len, src_len)
361
+ # attn_weights_relax = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)/self.attention_relax
362
+ # => (bsz*self.num_heads, tgt_len, src_len)
363
+ attn_weights_relax = attn_weights / self.attention_relax
364
+
365
+ # => (bsz* self.num_heads, tgt_len, 1)
366
+ attn_max_relax = torch.max(attn_weights_relax, dim=-1, keepdim=False).values.unsqueeze(2) # torch.max with dim returns (values, indices)
367
+ attn_weights = (attn_weights_relax - attn_max_relax) * self.attention_relax
368
+
369
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
370
+
371
+ if layer_head_mask is not None:
372
+ if layer_head_mask.size() != (self.num_heads,):
373
+ raise ValueError(
374
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
375
+ f" {layer_head_mask.size()}"
376
+ )
377
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
378
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
379
+
380
+ if output_attentions:
381
+ # this operation is a bit awkward, but it's required to
382
+ # make sure that attn_weights keeps its gradient.
383
+ # In order to do so, attn_weights have to be reshaped
384
+ # twice and have to be reused in the following
385
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
386
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
387
+ else:
388
+ attn_weights_reshaped = None
389
+
390
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
391
+
392
+ attn_output = torch.bmm(attn_probs, value_states)
393
+
394
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
395
+ raise ValueError(
396
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
397
+ f" {attn_output.size()}"
398
+ )
399
+
400
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
401
+ attn_output = attn_output.transpose(1, 2)
402
+
403
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
404
+ # partitioned across GPUs when using tensor-parallelism.
405
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
406
+
407
+ attn_output = self.out_proj(attn_output)
408
+
409
+ return attn_output, attn_weights_reshaped, past_key_value
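For clarity, here is a standalone sketch (not from the committed file) of the `attention_relax` rescaling used in `HubertAttention_extend`: the logits are divided by the relax factor, max-shifted, and scaled back before the softmax. In exact arithmetic this equals an ordinary softmax; the shift performed in the down-scaled range presumably guards against overflow when attention logits are computed in reduced precision.

```python
# Standalone demo of the attention_relax rescaling (an isolated sketch).
import torch

torch.manual_seed(0)
attention_relax = 32.0
attn_weights = torch.randn(2 * 12, 6, 6) * 50  # (bsz*num_heads, tgt_len, src_len)

relaxed = attn_weights / attention_relax
row_max = torch.max(relaxed, dim=-1, keepdim=True).values
stabilized = (relaxed - row_max) * attention_relax
out = torch.softmax(stabilized, dim=-1)

# Matches a plain softmax in float32, since only a per-row constant was subtracted.
print(torch.allclose(out, torch.softmax(attn_weights, dim=-1), atol=1e-6))  # True
```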
MERT-v1-95M/preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "do_normalize": false,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0,
7
+ "return_attention_mask": true,
8
+ "sampling_rate": 24000
9
+ }
MERT-v1-95M/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2b8b747f72c06e0595aeae41ae5473f4364938c6b39b2c58be38c48e6bd3fcd
3
+ size 377552987
Prediction_Head/MTGGenre_head.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+ class MLPProberBase(nn.Module):
6
+ def __init__(self, d=768, layer='all', num_outputs=87):
7
+ super().__init__()
8
+
9
+ self.hidden_layer_sizes = [512, ] # eval(self.cfg.hidden_layer_sizes)
10
+
11
+ self.num_layers = len(self.hidden_layer_sizes)
12
+
13
+ self.layer = layer
14
+
15
+ for i, ld in enumerate(self.hidden_layer_sizes):
16
+ setattr(self, f"hidden_{i}", nn.Linear(d, ld))
17
+ d = ld
18
+ self.output = nn.Linear(d, num_outputs)
19
+
20
+ self.n_tranformer_layer = 12
21
+
22
+ self.init_aggregator()
23
+
24
+
25
+ def init_aggregator(self):
26
+ """Initialize the aggregator for weighted sum over different layers of features
27
+ """
28
+ if self.layer == "all":
29
+ # use learned weights to aggregate features
30
+ self.aggregator = nn.Parameter(torch.randn((1, self.n_transformer_layer, 1)))
31
+
32
+
33
+ def forward(self, x):
34
+ """
35
+ x: (B, L, T, H)
36
+ T=#chunks, can be 1 or several chunks
37
+ """
38
+
39
+ if self.layer == "all":
40
+ weights = F.softmax(self.aggregator, dim=1)
41
+ x = (x * weights).sum(dim=1)
42
+
43
+ for i in range(self.num_layers):
44
+ x = getattr(self, f"hidden_{i}")(x)
45
+ # x = self.dropout(x)
46
+ x = F.relu(x)
47
+ output = self.output(x)
48
+ return output
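A usage sketch (not part of the committed file), mirroring how `app.py` drives this probe: layer-wise, time-averaged MERT features of shape `(batch, 12, 768)` go in and per-class logits come out.

```python
import torch
from Prediction_Head.MTGGenre_head import MLPProberBase

# 87 output classes corresponds to the MTGGenre head; other tasks use their own num_outputs.
probe = MLPProberBase(d=768, layer="all", num_outputs=87)
features = torch.randn(1, 12, 768)  # (batch, transformer layers, feature dim), time-averaged
logits = probe(features)            # learned softmax weights aggregate the 12 layers
print(logits.shape)                 # torch.Size([1, 87])
```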
Prediction_Head/MTGGenre_id2class.json ADDED
@@ -0,0 +1 @@
1
+ {"0": "genre---rock", "1": "genre---pop", "2": "genre---classical", "3": "genre---popfolk", "4": "genre---disco", "5": "genre---funk", "6": "genre---rnb", "7": "genre---ambient", "8": "genre---chillout", "9": "genre---downtempo", "10": "genre---easylistening", "11": "genre---electronic", "12": "genre---lounge", "13": "genre---triphop", "14": "genre---breakbeat", "15": "genre---techno", "16": "genre---newage", "17": "genre---jazz", "18": "genre---metal", "19": "genre---industrial", "20": "genre---instrumentalrock", "21": "genre---minimal", "22": "genre---alternative", "23": "genre---experimental", "24": "genre---drumnbass", "25": "genre---soul", "26": "genre---fusion", "27": "genre---soundtrack", "28": "genre---electropop", "29": "genre---world", "30": "genre---ethno", "31": "genre---trance", "32": "genre---orchestral", "33": "genre---grunge", "34": "genre---chanson", "35": "genre---worldfusion", "36": "genre---hiphop", "37": "genre---groove", "38": "genre---instrumentalpop", "39": "genre---blues", "40": "genre---reggae", "41": "genre---dance", "42": "genre---club", "43": "genre---punkrock", "44": "genre---folk", "45": "genre---synthpop", "46": "genre---poprock", "47": "genre---choir", "48": "genre---symphonic", "49": "genre---indie", "50": "genre---progressive", "51": "genre---acidjazz", "52": "genre---contemporary", "53": "genre---newwave", "54": "genre---dub", "55": "genre---rocknroll", "56": "genre---hard", "57": "genre---hardrock", "58": "genre---house", "59": "genre---atmospheric", "60": "genre---psychedelic", "61": "genre---improvisation", "62": "genre---country", "63": "genre---electronica", "64": "genre---rap", "65": "genre---60s", "66": "genre---70s", "67": "genre---darkambient", "68": "genre---idm", "69": "genre---latin", "70": "genre---postrock", "71": "genre---bossanova", "72": "genre---singersongwriter", "73": "genre---darkwave", "74": "genre---swing", "75": "genre---medieval", "76": "genre---celtic", "77": "genre---eurodance", "78": "genre---classicrock", "79": "genre---dubstep", "80": "genre---bluesrock", "81": "genre---edm", "82": "genre---deephouse", "83": "genre---jazzfusion", "84": "genre---alternativerock", "85": "genre---80s", "86": "genre---90s"}
Prediction_Head/__MACOSX/._best-layer-MERT-v1-95M ADDED
Binary file (220 Bytes).
 
Prediction_Head/__pycache__/MTGGenre_head.cpython-310.pyc ADDED
Binary file (1.67 kB).
 
Prediction_Head/best-layer-MERT-v1-95M.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8155db897d77d6896ba5d87e2af5cc335a3fb1dd300356185848982deacad4d
3
+ size 17025915
Prediction_Head/best-layer-MERT-v1-95M/EMO.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6cd122f800af666968e97e22ae18ea3473782ccf8f10da7783e20d35f90252cb
3
+ size 1584859
Prediction_Head/best-layer-MERT-v1-95M/EMO.id2class.json ADDED
@@ -0,0 +1 @@
1
+ {"0": "arousal", "1": "valence"}
Prediction_Head/best-layer-MERT-v1-95M/GS.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f87e9b3040a2abc59c4ea1e3ed98b27b29866584c7d3370893b6400ef536f519
3
+ size 1629851
Prediction_Head/best-layer-MERT-v1-95M/GS.id2class.json ADDED
@@ -0,0 +1 @@
1
+ {"0": "C major", "1": "Db major", "2": "D major", "3": "Eb major", "4": "E major", "5": "F major", "6": "Gb major", "7": "G major", "8": "Ab major", "9": "A major", "10": "Bb major", "11": "B major", "12": "C minor", "13": "Db minor", "14": "D minor", "15": "Eb minor", "16": "E minor", "17": "F minor", "18": "Gb minor", "19": "G minor", "20": "Ab minor", "21": "A minor", "22": "Bb minor", "23": "B minor"}
Prediction_Head/best-layer-MERT-v1-95M/GTZAN.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a78e05de1074d1b3f68868304c66a02a32e4b1915ecabbb6fe9801f3cf862d4f
3
+ size 1601115
Prediction_Head/best-layer-MERT-v1-95M/GTZAN.id2class.json ADDED
File without changes
Prediction_Head/best-layer-MERT-v1-95M/MTGGenre.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e726ad4dc706edd51cf07e6bc781b2f9882af6f85aa510badb74e6af3b0bd46
3
+ size 1759259
Prediction_Head/best-layer-MERT-v1-95M/MTGGenre.id2class.json ADDED
@@ -0,0 +1 @@
1
+ {"0": "genre---rock", "1": "genre---pop", "2": "genre---classical", "3": "genre---popfolk", "4": "genre---disco", "5": "genre---funk", "6": "genre---rnb", "7": "genre---ambient", "8": "genre---chillout", "9": "genre---downtempo", "10": "genre---easylistening", "11": "genre---electronic", "12": "genre---lounge", "13": "genre---triphop", "14": "genre---breakbeat", "15": "genre---techno", "16": "genre---newage", "17": "genre---jazz", "18": "genre---metal", "19": "genre---industrial", "20": "genre---instrumentalrock", "21": "genre---minimal", "22": "genre---alternative", "23": "genre---experimental", "24": "genre---drumnbass", "25": "genre---soul", "26": "genre---fusion", "27": "genre---soundtrack", "28": "genre---electropop", "29": "genre---world", "30": "genre---ethno", "31": "genre---trance", "32": "genre---orchestral", "33": "genre---grunge", "34": "genre---chanson", "35": "genre---worldfusion", "36": "genre---hiphop", "37": "genre---groove", "38": "genre---instrumentalpop", "39": "genre---blues", "40": "genre---reggae", "41": "genre---dance", "42": "genre---club", "43": "genre---punkrock", "44": "genre---folk", "45": "genre---synthpop", "46": "genre---poprock", "47": "genre---choir", "48": "genre---symphonic", "49": "genre---indie", "50": "genre---progressive", "51": "genre---acidjazz", "52": "genre---contemporary", "53": "genre---newwave", "54": "genre---dub", "55": "genre---rocknroll", "56": "genre---hard", "57": "genre---hardrock", "58": "genre---house", "59": "genre---atmospheric", "60": "genre---psychedelic", "61": "genre---improvisation", "62": "genre---country", "63": "genre---electronica", "64": "genre---rap", "65": "genre---60s", "66": "genre---70s", "67": "genre---darkambient", "68": "genre---idm", "69": "genre---latin", "70": "genre---postrock", "71": "genre---bossanova", "72": "genre---singersongwriter", "73": "genre---darkwave", "74": "genre---swing", "75": "genre---medieval", "76": "genre---celtic", "77": "genre---eurodance", "78": "genre---classicrock", "79": "genre---dubstep", "80": "genre---bluesrock", "81": "genre---edm", "82": "genre---deephouse", "83": "genre---jazzfusion", "84": "genre---alternativerock", "85": "genre---80s", "86": "genre---90s"}
Prediction_Head/best-layer-MERT-v1-95M/MTGInstrument.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:228871f490cff126fac56fa81aa099ce4f07eb71a95da5ae830a5ec056a495ea
3
+ size 1663127
Prediction_Head/best-layer-MERT-v1-95M/MTGInstrument.id2class.json ADDED
@@ -0,0 +1 @@
1
+ {"0": "instrument---voice", "1": "instrument---synthesizer", "2": "instrument---piano", "3": "instrument---strings", "4": "instrument---beat", "5": "instrument---guitar", "6": "instrument---cello", "7": "instrument---keyboard", "8": "instrument---trombone", "9": "instrument---clarinet", "10": "instrument---doublebass", "11": "instrument---horn", "12": "instrument---trumpet", "13": "instrument---violin", "14": "instrument---accordion", "15": "instrument---bass", "16": "instrument---computer", "17": "instrument---drummachine", "18": "instrument---drums", "19": "instrument---electricguitar", "20": "instrument---sampler", "21": "instrument---acousticguitar", "22": "instrument---harmonica", "23": "instrument---flute", "24": "instrument---pipeorgan", "25": "instrument---harp", "26": "instrument---electricpiano", "27": "instrument---oboe", "28": "instrument---saxophone", "29": "instrument---percussion", "30": "instrument---acousticbassguitar", "31": "instrument---orchestra", "32": "instrument---bongo", "33": "instrument---brass", "34": "instrument---viola", "35": "instrument---rhodes", "36": "instrument---organ", "37": "instrument---classicalguitar", "38": "instrument---bell", "39": "instrument---pad"}
Prediction_Head/best-layer-MERT-v1-95M/MTGMood.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f838a4a2015ecce024222fee9f2bba56ee64f6259455fb2f15109da69db232c4
3
+ size 1695643
Prediction_Head/best-layer-MERT-v1-95M/MTGMood.id2class.json ADDED
@@ -0,0 +1 @@
1
+ {"0": "mood/theme---background", "1": "mood/theme---film", "2": "mood/theme---melancholic", "3": "mood/theme---melodic", "4": "mood/theme---children", "5": "mood/theme---relaxing", "6": "mood/theme---documentary", "7": "mood/theme---emotional", "8": "mood/theme---space", "9": "mood/theme---love", "10": "mood/theme---drama", "11": "mood/theme---adventure", "12": "mood/theme---energetic", "13": "mood/theme---heavy", "14": "mood/theme---dark", "15": "mood/theme---calm", "16": "mood/theme---action", "17": "mood/theme---dramatic", "18": "mood/theme---epic", "19": "mood/theme---powerful", "20": "mood/theme---upbeat", "21": "mood/theme---slow", "22": "mood/theme---inspiring", "23": "mood/theme---soft", "24": "mood/theme---meditative", "25": "mood/theme---fun", "26": "mood/theme---happy", "27": "mood/theme---positive", "28": "mood/theme---romantic", "29": "mood/theme---sad", "30": "mood/theme---hopeful", "31": "mood/theme---motivational", "32": "mood/theme---deep", "33": "mood/theme---uplifting", "34": "mood/theme---ballad", "35": "mood/theme---soundscape", "36": "mood/theme---dream", "37": "mood/theme---movie", "38": "mood/theme---fast", "39": "mood/theme---nature", "40": "mood/theme---cool", "41": "mood/theme---corporate", "42": "mood/theme---travel", "43": "mood/theme---funny", "44": "mood/theme---sport", "45": "mood/theme---commercial", "46": "mood/theme---advertising", "47": "mood/theme---holiday", "48": "mood/theme---christmas", "49": "mood/theme---sexy", "50": "mood/theme---game", "51": "mood/theme---groovy", "52": "mood/theme---retro", "53": "mood/theme---summer", "54": "mood/theme---party", "55": "mood/theme---trailer"}
Prediction_Head/best-layer-MERT-v1-95M/MTGTop50.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65f4e397e69c751c416608f6ac8ac7cf17f7c1283e106c808ea8a12549f9c022
3
+ size 1683355
Prediction_Head/best-layer-MERT-v1-95M/MTGTop50.id2class.json ADDED
@@ -0,0 +1 @@
1
+ {"0": "genre---rock", "1": "genre---pop", "2": "genre---classical", "3": "instrument---voice", "4": "genre---popfolk", "5": "genre---funk", "6": "genre---ambient", "7": "genre---chillout", "8": "genre---downtempo", "9": "genre---easylistening", "10": "genre---electronic", "11": "genre---lounge", "12": "instrument---synthesizer", "13": "genre---triphop", "14": "genre---techno", "15": "genre---newage", "16": "genre---jazz", "17": "genre---metal", "18": "instrument---piano", "19": "genre---alternative", "20": "genre---experimental", "21": "genre---soundtrack", "22": "mood/theme---film", "23": "genre---world", "24": "instrument---strings", "25": "genre---trance", "26": "genre---orchestral", "27": "instrument---guitar", "28": "genre---hiphop", "29": "genre---instrumentalpop", "30": "mood/theme---relaxing", "31": "genre---reggae", "32": "mood/theme---emotional", "33": "instrument---keyboard", "34": "instrument---violin", "35": "genre---dance", "36": "instrument---bass", "37": "instrument---computer", "38": "instrument---drummachine", "39": "instrument---drums", "40": "instrument---electricguitar", "41": "genre---folk", "42": "instrument---acousticguitar", "43": "genre---poprock", "44": "genre---indie", "45": "mood/theme---energetic", "46": "mood/theme---happy", "47": "instrument---electricpiano", "48": "genre---house", "49": "genre---atmospheric"}
Prediction_Head/best-layer-MERT-v1-95M/MTT.id2class.json ADDED
File without changes
Prediction_Head/best-layer-MERT-v1-95M/NSynthI.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc8f805a0d2a63f4a32f0eeecd166da41683f25e8f5ab46405a2bb8bd53483b0
3
+ size 1603163
Prediction_Head/best-layer-MERT-v1-95M/NSynthI.id2class.json ADDED
@@ -0,0 +1 @@
1
+ {"0": "bass", "1": "brass", "2": "flute", "3": "guitar", "4": "keyboard", "5": "mallet", "6": "organ", "7": "reed", "8": "string", "9": "synth_lead", "10": "vocal"}
Prediction_Head/best-layer-MERT-v1-95M/NSynthP.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c30afb678b285ada9f5ad348919617dbd708eedf3af6fd2be6ea953c9c6f5385
3
+ size 1843227
Prediction_Head/best-layer-MERT-v1-95M/NSynthP.id2class.json ADDED
@@ -0,0 +1 @@
1
+ {"0": 1, "1": 2, "2": 3, "3": 4, "4": 5, "5": 6, "6": 7, "7": 8, "8": 9, "9": 10, "10": 11, "11": 12, "12": 13, "13": 14, "14": 15, "15": 16, "16": 17, "17": 18, "18": 19, "19": 20, "20": 21, "21": 22, "22": 23, "23": 24, "24": 25, "25": 26, "26": 27, "27": 28, "28": 29, "29": 30, "30": 31, "31": 32, "32": 33, "33": 34, "34": 35, "35": 36, "36": 37, "37": 38, "38": 39, "39": 40, "40": 41, "41": 42, "42": 43, "43": 44, "44": 45, "45": 46, "46": 47, "47": 48, "48": 49, "49": 50, "50": 51, "51": 52, "52": 53, "53": 54, "54": 55, "55": 56, "56": 57, "57": 58, "58": 59, "59": 60, "60": 61, "61": 62, "62": 63, "63": 64, "64": 65, "65": 66, "66": 67, "67": 68, "68": 69, "69": 70, "70": 71, "71": 72, "72": 73, "73": 74, "74": 75, "75": 76, "76": 77, "77": 78, "78": 79, "79": 80, "80": 81, "81": 82, "82": 83, "83": 84, "84": 85, "85": 86, "86": 87, "87": 88, "88": 89, "89": 90, "90": 91, "91": 92, "92": 93, "93": 94, "94": 95, "95": 96, "96": 97, "97": 98, "98": 99, "99": 100, "100": 101, "101": 102, "102": 103, "103": 104, "104": 105, "105": 106, "106": 107, "107": 108, "108": 109, "109": 110, "110": 111, "111": 112, "112": 113, "113": 114, "114": 115, "115": 116, "116": 117, "117": 118, "118": 119, "119": 120, "120": 121, "121": 122, "122": 123, "123": 124, "124": 125, "125": 126, "126": 127, "127": 128}
Prediction_Head/best-layer-MERT-v1-95M/VocalSetS.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aaf282b748e1e401b41dcd425f614edf46cd1b28300afb44cffa3facbc4a1d05
3
+ size 1621659
Prediction_Head/best-layer-MERT-v1-95M/VocalSetS.id2class.json ADDED
@@ -0,0 +1 @@
1
+ {"0": "f1", "1": "f2", "2": "f3", "3": "f4", "4": "f5", "5": "f6", "6": "f7", "7": "f8", "8": "f9", "9": "m1", "10": "m2", "11": "m3", "12": "m4", "13": "m5", "14": "m6", "15": "m7", "16": "m8", "17": "m9", "18": "m10", "19": "m11"}
Prediction_Head/best-layer-MERT-v1-95M/VocalSetT.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a97ed7665e288af4c9760b0a150bdd7a11d377cce4b826e99dc323b9a078a643
3
+ size 1601115
Prediction_Head/best-layer-MERT-v1-95M/VocalSetT.id2class.json ADDED
@@ -0,0 +1 @@
1
+ {"0": "belt", "1": "breathy", "2": "inhaled", "3": "lip_trill", "4": "spoken", "5": "straight", "6": "trill", "7": "trillo", "8": "vibrato", "9": "vocal_fry"}
Prediction_Head/best_MTGGenre.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83b7dcffde10a0dc7ba74341ea56dabec5c5de7cad6a0483708c80f1d893514a
3
+ size 1759067
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Music Descriptor
3
- emoji: 📚
4
- colorFrom: yellow
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 3.32.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
1
  ---
2
+ title: MusicTagging
3
+ emoji: 💻
4
+ colorFrom: gray
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 3.29.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
__pycache__/app.cpython-310.pyc ADDED
Binary file (4.91 kB).
 
app.py ADDED
@@ -0,0 +1,207 @@
1
+ import gradio as gr
2
+ #
3
+ from transformers import Wav2Vec2FeatureExtractor
4
+ from transformers import AutoModel
5
+ import torch
6
+ from torch import nn
7
+ import torchaudio
8
+ import torchaudio.transforms as T
9
+ import logging
10
+
11
+ import json
12
+ import os
13
+
14
+ import importlib
15
+ modeling_MERT = importlib.import_module("MERT-v1-95M.modeling_MERT")
16
+
17
+ from Prediction_Head.MTGGenre_head import MLPProberBase
18
+ # input cr: https://huggingface.co/spaces/thealphhamerc/audio-to-text/blob/main/app.py
19
+
20
+
21
+ logger = logging.getLogger("MERT-v1-95M-app")
22
+ logger.setLevel(logging.INFO)
23
+ ch = logging.StreamHandler()
24
+ ch.setLevel(logging.INFO)
25
+ formatter = logging.Formatter(
26
+ "%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
27
+ ch.setFormatter(formatter)
28
+ logger.addHandler(ch)
29
+
30
+
31
+
32
+ inputs = [
33
+ gr.components.Audio(type="filepath", label="Add music audio file"),
34
+ gr.inputs.Audio(source="microphone", type="filepath"),
35
+ ]
36
+ live_inputs = [
37
+ gr.Audio(source="microphone",streaming=True, type="filepath"),
38
+ ]
39
+ # outputs = [gr.components.Textbox()]
40
+ # outputs = [gr.components.Textbox(), transcription_df]
41
+ title = "One Model for All Music Understanding Tasks"
42
+ description = "An example of using the [MERT-v1-95M](https://huggingface.co/m-a-p/MERT-v1-95M) model as backbone to conduct multiple music understanding tasks with the universal represenation."
43
+ article = "The tasks include EMO, GS, MTGInstrument, MTGGenre, MTGTop50, MTGMood, NSynthI, NSynthP, VocalSetS, VocalSetT. \n\n More models can be referred at the [map organization page](https://huggingface.co/m-a-p)."
44
+ audio_examples = [
45
+ # ["input/example-1.wav"],
46
+ # ["input/example-2.wav"],
47
+ ]
48
+
49
+ # Load the model and the corresponding preprocessor config
50
+ # model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
51
+ # processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public",trust_remote_code=True)
52
+ model = modeling_MERT.MERTModel.from_pretrained("./MERT-v1-95M")
53
+ processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v1-95M")
54
+
55
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
56
+
57
+ MERT_BEST_LAYER_IDX = {
58
+ 'EMO': 5,
59
+ 'GS': 8,
60
+ 'GTZAN': 7,
61
+ 'MTGGenre': 7,
62
+ 'MTGInstrument': 'all',
63
+ 'MTGMood': 6,
64
+ 'MTGTop50': 6,
65
+ 'MTT': 'all',
66
+ 'NSynthI': 6,
67
+ 'NSynthP': 1,
68
+ 'VocalSetS': 2,
69
+ 'VocalSetT': 9,
70
+ }
71
+
86
+ CLASSIFIERS = {
87
+
88
+ }
89
+
90
+ ID2CLASS = {
91
+
92
+ }
93
+
94
+ TASKS = ['EMO','GS', 'MTGInstrument', 'MTGGenre', 'MTGTop50', 'MTGMood', 'NSynthI', 'NSynthP', 'VocalSetS', 'VocalSetT']
95
+ head_dir = './Prediction_Head/best-layer-MERT-v1-95M'
96
+ for task in TASKS:
97
+ print('loading', task)
98
+ with open(os.path.join(head_dir,f'{task}.id2class.json'), 'r') as f:
99
+ ID2CLASS[task]=json.load(f)
100
+ num_class = len(ID2CLASS[task].keys())
101
+ CLASSIFIERS[task] = MLPProberBase(d=768, layer=MERT_BEST_LAYER_IDX[task], num_outputs=num_class)
102
+ CLASSIFIERS[task].load_state_dict(torch.load(os.path.join(head_dir, f'{task}.ckpt'), map_location='cpu')['state_dict']) # load on CPU first; moved to `device` below
103
+ CLASSIFIERS[task].to(device)
104
+
105
+ model.to(device)
106
+
107
+ def model_inference(inputs):
108
+ waveform, sample_rate = torchaudio.load(inputs)
109
+
110
+ resample_rate = processor.sampling_rate
111
+
112
+ # make sure the sample rates are aligned
113
+ if resample_rate != sample_rate:
114
+ print(f'setting rate from {sample_rate} to {resample_rate}')
115
+ resampler = T.Resample(sample_rate, resample_rate)
116
+ waveform = resampler(waveform)
117
+
118
+ waveform = waveform.view(-1,) # make it (n_sample, )
119
+ model_inputs = processor(waveform, sampling_rate=resample_rate, return_tensors="pt")
120
+ model_inputs.to(device)
121
+ with torch.no_grad():
122
+ model_outputs = model(**model_inputs, output_hidden_states=True)
123
+
124
+ # take a look at the output shape, there are 13 layers of representation
125
+ # each layer performs differently in different downstream tasks, you should choose empirically
126
+ all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()[1:,:,:].unsqueeze(0)
127
+ print(all_layer_hidden_states.shape) # [1, 12 layers (embedding output dropped), Time steps, 768 feature_dim]
128
+ all_layer_hidden_states = all_layer_hidden_states.mean(dim=2)
129
+
130
+ task_output_texts = ""
131
+ for task in TASKS:
132
+ num_class = len(ID2CLASS[task].keys())
133
+ if MERT_BEST_LAYER_IDX[task] == 'all':
134
+ logits = CLASSIFIERS[task](all_layer_hidden_states) # [1, 87]
135
+ else:
136
+ logits = CLASSIFIERS[task](all_layer_hidden_states[:, MERT_BEST_LAYER_IDX[task]])
137
+ print(f'task {task} logits:', logits.shape, 'num class:', num_class)
138
+
139
+ sorted_idx = torch.argsort(logits, dim = -1, descending=True)[0] # batch =1
140
+ sorted_prob,_ = torch.sort(nn.functional.softmax(logits[0], dim=-1), dim=-1, descending=True)
141
+ # print(sorted_prob)
142
+ # print(sorted_prob.shape)
143
+
144
+ top_n_show = 3 if num_class >= 3 else num_class
145
+ task_output_texts = task_output_texts + f"TASK {task} output:\n" + "\n".join([str(ID2CLASS[task][str(sorted_idx[idx].item())])+f', probability: {sorted_prob[idx].item():.2%}' for idx in range(top_n_show)]) + '\n'
146
+ task_output_texts = task_output_texts + '----------------------\n'
147
+ # output_texts = "\n".join([id2cls[str(idx.item())].replace('genre---', '') for idx in sorted_idx[:5]])
148
+ # logger.warning(all_layer_hidden_states.shape)
149
+
150
+ # return f"device {device}, sample reprensentation: {str(all_layer_hidden_states[12, 0, :10])}"
151
+ # return f"device: {device}\n" + output_texts
152
+ return task_output_texts
153
+
154
+ def convert_audio(inputs, microphone):
155
+ if (microphone is not None):
156
+ inputs = microphone
157
+
158
+ text = model_inference(inputs)
159
+
160
+ return text
161
+
162
+ def live_convert_audio(microphone):
163
+ if (microphone is not None):
164
+ inputs = microphone
165
+
166
+ text = model_inference(inputs)
167
+
168
+ return text
169
+
170
+ audio_chunked = gr.Interface(
171
+ fn=convert_audio,
172
+ inputs=inputs,
173
+ outputs=[gr.components.Textbox()],
174
+ allow_flagging="never",
175
+ title=title,
176
+ description=description,
177
+ article=article,
178
+ examples=audio_examples,
179
+ )
180
+
181
+ live_audio_chunked = gr.Interface(
182
+ fn=live_convert_audio,
183
+ inputs=live_inputs,
184
+ outputs=[gr.components.Textbox()],
185
+ allow_flagging="never",
186
+ title=title,
187
+ description=description,
188
+ article=article,
189
+ # examples=audio_examples,
190
+ live=True,
191
+ )
192
+
193
+
194
+ demo = gr.Blocks()
195
+ with demo:
196
+ gr.TabbedInterface(
197
+ [
198
+ audio_chunked,
199
+ live_audio_chunked,
200
+ ],
201
+ [
202
+ "Audio File or Recording",
203
+ "Live Streaming Music"
204
+ ]
205
+ )
206
+ demo.queue(concurrency_count=1, max_size=5)
207
+ demo.launch(show_api=False)
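For a quick local smoke test without the Gradio UI, one could call the inference function directly (a sketch, not part of app.py; the audio path is a hypothetical placeholder):

```python
# Hypothetical local test: point the path at any audio file on disk.
if __name__ == "__main__":
    print(model_inference("input/example-1.wav"))
```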
requirements.txt ADDED
@@ -0,0 +1,88 @@
1
+ aiofiles==23.1.0
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ altair==5.0.0
5
+ anyio==3.6.2
6
+ async-timeout==4.0.2
7
+ attrs==23.1.0
8
+ certifi==2023.5.7
9
+ charset-normalizer==3.1.0
10
+ click==8.1.3
11
+ cmake==3.26.3
12
+ contourpy==1.0.7
13
+ cycler==0.11.0
14
+ fastapi==0.95.2
15
+ ffmpy==0.3.0
16
+ filelock==3.12.0
17
+ fonttools==4.39.4
18
+ frozenlist==1.3.3
19
+ fsspec==2023.5.0
20
+ gradio==3.31.0
21
+ gradio_client==0.2.5
22
+ h11==0.14.0
23
+ httpcore==0.17.1
24
+ httpx==0.24.0
25
+ huggingface-hub==0.14.1
26
+ idna==3.4
27
+ Jinja2==3.1.2
28
+ jsonschema==4.17.3
29
+ kiwisolver==1.4.4
30
+ linkify-it-py==2.0.2
31
+ lit==16.0.5
32
+ markdown-it-py==2.2.0
33
+ MarkupSafe==2.1.2
34
+ matplotlib==3.7.1
35
+ mdit-py-plugins==0.3.3
36
+ mdurl==0.1.2
37
+ mpmath==1.3.0
38
+ multidict==6.0.4
39
+ networkx==3.1
40
+ nnAudio==0.3.2
41
+ numpy==1.24.3
42
+ nvidia-cublas-cu11==11.10.3.66
43
+ nvidia-cuda-cupti-cu11==11.7.101
44
+ nvidia-cuda-nvrtc-cu11==11.7.99
45
+ nvidia-cuda-runtime-cu11==11.7.99
46
+ nvidia-cudnn-cu11==8.5.0.96
47
+ nvidia-cufft-cu11==10.9.0.58
48
+ nvidia-curand-cu11==10.2.10.91
49
+ nvidia-cusolver-cu11==11.4.0.1
50
+ nvidia-cusparse-cu11==11.7.4.91
51
+ nvidia-nccl-cu11==2.14.3
52
+ nvidia-nvtx-cu11==11.7.91
53
+ orjson==3.8.12
54
+ packaging==23.1
55
+ pandas==2.0.1
56
+ Pillow==9.5.0
57
+ pydantic==1.10.7
58
+ pydub==0.25.1
59
+ Pygments==2.15.1
60
+ pyparsing==3.0.9
61
+ pyrsistent==0.19.3
62
+ python-dateutil==2.8.2
63
+ python-multipart==0.0.6
64
+ pytz==2023.3
65
+ PyYAML==6.0
66
+ regex==2023.5.5
67
+ requests==2.30.0
68
+ scipy==1.10.1
69
+ semantic-version==2.10.0
70
+ six==1.16.0
71
+ sniffio==1.3.0
72
+ starlette==0.27.0
73
+ sympy==1.12
74
+ tokenizers==0.13.3
75
+ toolz==0.12.0
76
+ torch==2.0.1
77
+ torchaudio==2.0.2
78
+ torchvision==0.15.2
79
+ tqdm==4.65.0
80
+ transformers==4.29.2
81
+ triton==2.0.0
82
+ typing_extensions==4.5.0
83
+ tzdata==2023.3
84
+ uc-micro-py==1.0.2
85
+ urllib3==2.0.2
86
+ uvicorn==0.22.0
87
+ websockets==11.0.3
88
+ yarl==1.9.2