supun9 committed
Commit b7f4dbe
1 Parent(s): c4e415d

Upload 5 files

Files changed (5)
  1. README.md +1 -12
  2. audio_train.py +236 -0
  3. collator.py +38 -0
  4. crema.py +73 -0
  5. requirements.txt +101 -0
README.md CHANGED
@@ -1,12 +1 @@
- ---
- title: Audio Sentiment Analysis
- emoji: 🏃
- colorFrom: indigo
- colorTo: red
- sdk: gradio
- sdk_version: 3.23.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Placeholder for zip data

audio_train.py ADDED
@@ -0,0 +1,236 @@
+ import os
+ import logging
+ import librosa
+
+ import wandb
+ import numpy as np
+
+ from datasets import DatasetDict, load_dataset, load_metric
+ from transformers import (
+     HubertForSequenceClassification,
+     PretrainedConfig,
+     Trainer,
+     TrainingArguments,
+     Wav2Vec2FeatureExtractor,
+ )
+ from utils import collator
+
+ logging.basicConfig(
+     format="%(asctime)s | %(levelname)s: %(message)s", level=logging.INFO
+ )
+
+ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
+ NUM_LABELS = 6
+
+
+ USER = "XXXX"  # TODO: replace with your username
+ WANDB_PROJECT = "XXXXX"  # TODO: replace with your project name
+ wandb.init(entity=USER, project=WANDB_PROJECT)
+
+
+ # PROCESS THE DATASET TO THE FORMAT EXPECTED BY THE MODEL FOR TRAINING
+ PreTrainedFeatureExtractor = "SequenceFeatureExtractor"  # noqa: F821
+
+ INPUT_FIELD = "input_values"
+ LABEL_FIELD = "labels"
+
+
+ def prepare_dataset(batch, feature_extractor: PreTrainedFeatureExtractor):
+     audio_arr = batch["array"]
+     input = feature_extractor(
+         audio_arr, sampling_rate=16000, padding=True, return_tensors="pt"
+     )
+
+     batch[INPUT_FIELD] = input.input_values[0]
+     batch[LABEL_FIELD] = batch[
+         "label"
+     ]  # colname MUST be labels as Trainer will look for it by default
+
+     return batch
+
+
+ model_id = "facebook/hubert-base-ls960"
+ MODELS_DIR = os.path.join(PROJECT_ROOT, "models")
+
+ extractor_path = (
+     model_id
+     if len(os.listdir(MODELS_DIR)) == 0
+     else os.path.join(MODELS_DIR, "feature_extractor")
+ )
+ model_path = (
+     model_id
+     if len(os.listdir(MODELS_DIR)) == 0
+     else os.path.join(MODELS_DIR, "pretrained_model")
+ )
+
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(extractor_path)
+
+ config = PretrainedConfig.from_pretrained(model_path, num_labels=NUM_LABELS)
+ hubert_model = HubertForSequenceClassification.from_pretrained(
+     model_path,
+     config=config,  # because we need to update num_labels as per our dataset
+     ignore_mismatched_sizes=True,  # to avoid classifier size mismatch from from_pretrained.
+ )
+
+
+ # FREEZE LAYERS
+
+ # freeze all layers to begin with
+ for param in hubert_model.parameters():
+     param.requires_grad = False
+
+ layers_freeze_num = 2
+ n_layers = (
+     4 + layers_freeze_num * 16
+ )  # 4 refers to projector and classifier's weights and biases; each encoder layer has 16 parameter tensors.
+ for name, param in list(hubert_model.named_parameters())[-n_layers:]:
+     param.requires_grad = True
+
+ # # freeze model weights for all layers except projector and classifier
+ # for name, param in hubert_model.named_parameters():
+ #     if any(ext in name for ext in ["projector", "classifier"]):
+ #         param.requires_grad = True
+
+
+ trainer_config = {
+     "OUTPUT_DIR": "results",
+     "TRAIN_EPOCHS": 5,
+     "TRAIN_BATCH_SIZE": 32,
+     "EVAL_BATCH_SIZE": 32,
+     "GRADIENT_ACCUMULATION_STEPS": 4,
+     "WARMUP_STEPS": 500,
+     "DECAY": 0.01,
+     "LOGGING_STEPS": 10,
+     "MODEL_DIR": "models/audio-model",
+     "LR": 1e-3,
+ }
+
+
+ dataset_config = {
+     "LOADING_SCRIPT_FILES": os.path.join(PROJECT_ROOT, "src/data/crema.py"),
+     "CONFIG_NAME": "clean",
+     "DATA_DIR": os.path.join(PROJECT_ROOT, "data/archive.zip"),
+     "CACHE_DIR": os.path.join(PROJECT_ROOT, "cache_crema"),
+ }
+
+
+ ds = load_dataset(
+     dataset_config["LOADING_SCRIPT_FILES"],
+     dataset_config["CONFIG_NAME"],
+     cache_dir=dataset_config["CACHE_DIR"],
+     data_dir=dataset_config["DATA_DIR"],
+ )
+
+
+ # CONVERTING RAW AUDIO TO ARRAYS
+ ds = ds.map(
+     lambda x: {"array": librosa.load(x["file"], sr=16000, mono=False)[0]},
+     num_proc=2,
+ )
+
+
+ # LABEL TO ID
+ ds = ds.class_encode_column("label")
+
+
+ # ds["train"] = ds["train"].select(range(2500))
+ wandb.log({"dataset_size": len(ds["train"])})
+
+
+ # APPLY THE DATA PREP USING FEATURE EXTRACTOR TO ALL EXAMPLES
+ ds = ds.map(
+     prepare_dataset,
+     fn_kwargs={"feature_extractor": feature_extractor},
+     # num_proc=4,
+ )
+ logging.info("Finished extracting features from audio arrays.")
+
+
+ # INTRODUCE TRAIN TEST VAL SPLITS
+
+ # 90% train, 10% test + validation
+ train_testvalid = ds["train"].train_test_split(shuffle=True, test_size=0.1)
+ # Split the 10% test + valid in half test, half valid
+ test_valid = train_testvalid["test"].train_test_split(test_size=0.5)
+ # gather everyone if you want to have a single DatasetDict
+ ds = DatasetDict(
+     {
+         "train": train_testvalid["train"],
+         "test": test_valid["test"],
+         "val": test_valid["train"],
+     }
+ )
+
+
+ # DEFINE DATA COLLATOR - TO PAD TRAINING BATCHES DYNAMICALLY
+ data_collator = collator.DataCollatorCTCWithPadding(
+     processor=feature_extractor, padding=True
+ )
+
+
+ # Fine-Tuning with Trainer
+ training_args = TrainingArguments(
+     output_dir=os.path.join(
+         PROJECT_ROOT, trainer_config["OUTPUT_DIR"]
+     ),  # output directory
+     gradient_accumulation_steps=trainer_config[
+         "GRADIENT_ACCUMULATION_STEPS"
+     ],  # accumulate the gradients before running optimization step
+     num_train_epochs=trainer_config["TRAIN_EPOCHS"],  # total number of training epochs
+     per_device_train_batch_size=trainer_config[
+         "TRAIN_BATCH_SIZE"
+     ],  # batch size per device during training
+     per_device_eval_batch_size=trainer_config[
+         "EVAL_BATCH_SIZE"
+     ],  # batch size for evaluation
+     warmup_steps=trainer_config[
+         "WARMUP_STEPS"
+     ],  # number of warmup steps for learning rate scheduler
+     weight_decay=trainer_config["DECAY"],  # strength of weight decay
+     logging_steps=trainer_config["LOGGING_STEPS"],
+     evaluation_strategy="epoch",  # report metric at end of each epoch
+     report_to="wandb",  # enable logging to W&B
+     learning_rate=trainer_config["LR"],  # default = 5e-5
+ )
+
+
+ def compute_metrics(eval_pred):
+     # DEFINE EVALUATION METRIC
+     compute_accuracy_metric = load_metric("accuracy")
+     logits, labels = eval_pred
+     predictions = np.argmax(logits, axis=-1)
+     return compute_accuracy_metric.compute(predictions=predictions, references=labels)
+
+
+ # START TRAINING
+ trainer = Trainer(
+     model=hubert_model,  # the instantiated 🤗 Transformers model to be trained
+     args=training_args,  # training arguments, defined above
+     data_collator=data_collator,
+     train_dataset=ds["train"],  # training dataset
+     eval_dataset=ds["val"],  # evaluation dataset
+     compute_metrics=compute_metrics,
+ )
+
+
+ trainer.train()
+
+ # TO RESUME TRAINING FROM CHECKPOINT
+ # trainer.train("results/checkpoint-2000")
+
+ # VALIDATION SET RESULTS
+ logging.info("Eval Set Result: {}".format(trainer.evaluate()))
+
+ # TEST RESULTS
+ test_results = trainer.predict(ds["test"])
+ logging.info("Test Set Result: {}".format(test_results.metrics))
+ wandb.log({"test_accuracy": test_results.metrics["test_accuracy"]})
+
+ trainer.save_model(os.path.join(PROJECT_ROOT, trainer_config["MODEL_DIR"]))
+
+ # logging trained models to wandb
+ wandb.save(
+     os.path.join(PROJECT_ROOT, trainer_config["MODEL_DIR"], "*"),
+     base_path=os.path.dirname(trainer_config["MODEL_DIR"]),
+     policy="end",
+ )

collator.py ADDED
@@ -0,0 +1,38 @@
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Union
+
+ import torch
+ from transformers import Wav2Vec2Processor
+
+ INPUT_FIELD = "input_values"
+ LABEL_FIELD = "labels"
+
+
+ @dataclass
+ class DataCollatorCTCWithPadding:
+     processor: Wav2Vec2Processor
+     padding: Union[bool, str] = True
+     max_length: Optional[int] = None
+     max_length_labels: Optional[int] = None
+     pad_to_multiple_of: Optional[int] = None
+     pad_to_multiple_of_labels: Optional[int] = None
+
+     def __call__(
+         self, examples: List[Dict[str, Union[List[int], torch.Tensor]]]
+     ) -> Dict[str, torch.Tensor]:
+
+         input_features = [
+             {INPUT_FIELD: example[INPUT_FIELD]} for example in examples
+         ]  # example is basically row0, row1, etc...
+         labels = [example[LABEL_FIELD] for example in examples]
+
+         batch = self.processor.pad(
+             input_features,
+             padding=self.padding,
+             max_length=self.max_length,
+             pad_to_multiple_of=self.pad_to_multiple_of,
+             return_tensors="pt",
+         )
+         batch[LABEL_FIELD] = torch.tensor(labels)
+
+         return batch

crema.py ADDED
@@ -0,0 +1,73 @@
+ # Lint as: python3
+ """CREMA-D dataset."""
+
+ import os
+ from typing import Union
+
+ import datasets
+ import pandas as pd
+
+ _DESCRIPTION = """\
+ CREMA-D is a data set of 7,442 original clips from 91 actors.
+ These clips were from 48 male and 43 female actors between the ages of 20 and 74
+ coming from a variety of races and ethnicities (African American, Asian,
+ Caucasian, Hispanic, and Unspecified). Actors spoke from a selection of 12
+ sentences. The sentences were presented using one of six different emotions
+ (Anger, Disgust, Fear, Happy, Neutral, and Sad) and four different emotion
+ levels (Low, Medium, High, and Unspecified).
+ """
+
+ _HOMEPAGE = "https://github.com/CheyneyComputerScience/CREMA-D"
+
+ DATA_DIR = {"train": "AudioWAV"}
+
+
+ class Crema(datasets.GeneratorBasedBuilder):
+     """CREMA-D dataset."""
+
+     DEFAULT_WRITER_BATCH_SIZE = 256
+     BUILDER_CONFIGS = [datasets.BuilderConfig(name="clean", description="Train Set.")]
+
+     def _info(self):
+         return datasets.DatasetInfo(
+             description=_DESCRIPTION,
+             features=datasets.Features(
+                 {"file": datasets.Value("string"), "label": datasets.Value("string")}
+             ),
+             supervised_keys=("file", "label"),
+             homepage=_HOMEPAGE,
+         )
+
+     def _split_generators(
+         self, dl_manager: datasets.utils.download_manager.DownloadManager
+     ):
+         data_dir = dl_manager.extract(self.config.data_dir)
+         if self.config.name == "clean":
+             train_splits = [
+                 datasets.SplitGenerator(
+                     name="train", gen_kwargs={"files": data_dir, "name": "train"}
+                 )
+             ]
+
+         return train_splits
+
+     def _generate_examples(self, files: Union[str, os.PathLike], name: str):
+         """Generate examples from a CREMA-D unzipped directory."""
+         key = 0
+         examples = list()
+
+         audio_dir = os.path.join(files, DATA_DIR[name])
+
+         if not os.path.exists(audio_dir):
+             raise FileNotFoundError
+         else:
+             for file in os.listdir(audio_dir):
+                 res = dict()
+                 res["file"] = "{}".format(os.path.join(audio_dir, file))
+                 res["label"] = file.split("_")[-2]
+                 examples.append(res)
+
+         for example in examples:
+             yield key, {**example}
+             key += 1
+         examples = []

requirements.txt ADDED
@@ -0,0 +1,101 @@
+ aiohttp==3.8.1
+ aiosignal==1.2.0
+ appdirs==1.4.4
+ appnope==0.1.3
+ asttokens==2.0.8
+ async-timeout==4.0.2
+ attrs==22.1.0
+ audioread==3.0.0
+ backcall==0.2.0
+ black==22.6.0
+ certifi==2022.6.15
+ cffi==1.15.1
+ charset-normalizer==2.1.1
+ click==8.1.3
+ datasets==2.4.0
+ debugpy==1.6.3
+ decorator==5.1.1
+ dill==0.3.5.1
+ docker-pycreds==0.4.0
+ entrypoints==0.4
+ executing==0.10.0
+ filelock==3.8.0
+ flake8==5.0.4
+ frozenlist==1.3.1
+ fsspec==2022.7.1
+ gitdb==4.0.9
+ GitPython==3.1.27
+ huggingface-hub==0.8.1
+ idna==3.3
+ ipykernel==6.15.1
+ ipython==8.4.0
+ ipywidgets==8.0.1
+ jedi==0.18.1
+ joblib==1.1.0
+ jupyter-client==7.3.4
+ jupyter-core==4.11.1
+ jupyterlab-widgets==3.0.2
+ librosa==0.8.1
+ llvmlite==0.39.0
+ matplotlib-inline==0.1.6
+ mccabe==0.7.0
+ multidict==6.0.2
+ multiprocess==0.70.13
+ mypy-extensions==0.4.3
+ nest-asyncio==1.5.5
+ numba==0.56.0
+ numpy==1.22.0
+ packaging==21.3
+ pandas==1.4.3
+ parso==0.8.3
+ pathspec==0.9.0
+ pathtools==0.1.2
+ pexpect==4.8.0
+ pickleshare==0.7.5
+ platformdirs==2.5.2
+ pooch==1.6.0
+ promise==2.3
+ prompt-toolkit==3.0.30
+ protobuf==3.20.1
+ psutil==5.9.1
+ ptyprocess==0.7.0
+ pure-eval==0.2.2
+ pyarrow==9.0.0
+ pycodestyle==2.9.1
+ pycparser==2.21
+ pyflakes==2.5.0
+ Pygments==2.13.0
+ pyparsing==3.0.9
+ python-dateutil==2.8.2
+ pytz==2022.2.1
+ PyYAML==6.0
+ pyzmq==23.2.1
+ regex==2022.8.17
+ requests==2.28.1
+ resampy==0.4.0
+ responses==0.18.0
+ scikit-learn==1.1.2
+ scipy==1.9.0
+ sentry-sdk==1.9.5
+ setproctitle==1.3.2
+ shortuuid==1.0.9
+ six==1.16.0
+ sklearn==0.0
+ smmap==5.0.0
+ SoundFile==0.10.3.post1
+ stack-data==0.4.0
+ threadpoolctl==3.1.0
+ tokenizers==0.12.1
+ tomli==2.0.1
+ torch==1.12.1
+ tornado==6.2
+ tqdm==4.64.0
+ traitlets==5.3.0
+ transformers==4.21.2
+ typing_extensions==4.3.0
+ urllib3==1.26.12
+ wandb==0.13.2
+ wcwidth==0.2.5
+ widgetsnbextension==4.0.2
+ xxhash==3.0.0
+ yarl==1.8.1