waidhoferj committed
Commit 5da8010
1 Parent(s): 4078103

added hubert and updated preprocessing for wav2vec2

Files changed (2)
  1. models/hubert.py +86 -0
  2. models/wav2vec2.py +9 -5
models/hubert.py ADDED
@@ -0,0 +1,86 @@
+ import os
+ from typing import Any
+ import pytorch_lightning as pl
+ from torch.utils.data import random_split
+ from transformers import AutoFeatureExtractor
+ from transformers import (
+     AutoModelForAudioClassification,
+     TrainingArguments,
+     Trainer,
+ )
+ 
+ from preprocessing.dataset import (
+     HuggingFaceDatasetWrapper,
+     get_datasets,
+ )
+ from preprocessing.pipelines import WaveformTrainingPipeline
+ 
+ from .utils import get_id_label_mapping, compute_hf_metrics
+ 
+ MODEL_CHECKPOINT = "ntu-spml/distilhubert"
+ 
+ 
+ class HubertFeatureExtractor:
+     def __init__(self) -> None:
+         self.waveform_pipeline = WaveformTrainingPipeline()
+         self.feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
+ 
+     def __call__(self, waveform) -> Any:
+         waveform = self.waveform_pipeline(waveform)
+         return self.feature_extractor(waveform.squeeze(0), sampling_rate=16000)
+ 
+     def __getattr__(self, attr):
+         return getattr(self.feature_extractor, attr)
+ 
+ 
+ def train_huggingface(config: dict):
+     TARGET_CLASSES = config["dance_ids"]
+     DEVICE = config["device"]
+     SEED = config["seed"]
+     OUTPUT_DIR = "models/weights/hubert"
+     batch_size = config["data_module"]["batch_size"]
+     epochs = config["trainer"]["min_epochs"]
+     test_proportion = config["data_module"].get("test_proportion", 0.2)
+     pl.seed_everything(SEED, workers=True)
+     feature_extractor = HubertFeatureExtractor()
+     dataset = get_datasets(config["datasets"], feature_extractor)
+     dataset = HuggingFaceDatasetWrapper(dataset)
+     id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
+     # split the dataset into train/test partitions
+     train_proportion = 1 - test_proportion
+     train_ds, test_ds = random_split(dataset, [train_proportion, test_proportion])
+ 
+     model = AutoModelForAudioClassification.from_pretrained(
+         MODEL_CHECKPOINT,
+         num_labels=len(TARGET_CLASSES),
+         label2id=label2id,
+         id2label=id2label,
+         # ignore_mismatched_sizes=True,
+     ).to(DEVICE)
+     training_args = TrainingArguments(
+         output_dir=OUTPUT_DIR,
+         evaluation_strategy="epoch",
+         save_strategy="epoch",
+         learning_rate=5e-5,
+         per_device_train_batch_size=batch_size,
+         gradient_accumulation_steps=1,
+         gradient_checkpointing=True,
+         per_device_eval_batch_size=batch_size,
+         num_train_epochs=epochs,
+         warmup_ratio=0.1,
+         logging_steps=10,
+         load_best_model_at_end=True,
+         metric_for_best_model="accuracy",
+         push_to_hub=False,
+         use_mps_device=DEVICE == "mps",
+         fp16=True,
+     )
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_ds,
+         eval_dataset=test_ds,
+         compute_metrics=compute_hf_metrics,
+     )
+     trainer.train()
+     return model
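For reference, a minimal sketch of how train_huggingface might be invoked. Every value below is an illustrative placeholder rather than anything taken from this commit; only the key names match what the function actually reads.

from models.hubert import train_huggingface

# Hypothetical config; the real one lives elsewhere in the repo.
config = {
    "dance_ids": ["waltz", "tango", "salsa"],  # placeholder class labels
    "device": "cuda",  # or "mps" / "cpu"
    "seed": 42,
    "data_module": {"batch_size": 16, "test_proportion": 0.2},
    "trainer": {"min_epochs": 5},
    "datasets": {},  # dataset spec consumed by get_datasets; shape assumed here
}

model = train_huggingface(config)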
models/wav2vec2.py CHANGED
@@ -3,7 +3,12 @@ from typing import Any
  import pytorch_lightning as pl
  from torch.utils.data import random_split
  from transformers import AutoFeatureExtractor
- from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer, AutoProcessor
+ from transformers import (
+     AutoModelForAudioClassification,
+     TrainingArguments,
+     Trainer,
+     AutoProcessor,
+ )
  
  from preprocessing.dataset import (
      HuggingFaceDatasetWrapper,
@@ -14,18 +19,17 @@ from preprocessing.pipelines import WaveformTrainingPipeline
  from .utils import get_id_label_mapping, compute_hf_metrics
  
  MODEL_CHECKPOINT = "yuval6967/wav2vec2-base-finetuned-gtzan"
+ PROCESSOR_CHECKPOINT = "facebook/wav2vec2-base"
  
  
  class Wav2VecFeatureExtractor:
      def __init__(self) -> None:
          self.waveform_pipeline = WaveformTrainingPipeline()
-         self.feature_extractor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")
+         self.feature_extractor = AutoProcessor.from_pretrained(PROCESSOR_CHECKPOINT)
  
      def __call__(self, waveform) -> Any:
          waveform = self.waveform_pipeline(waveform)
-         return self.feature_extractor(
-             waveform.squeeze(0), sampling_rate=16000
-         )
+         return self.feature_extractor(waveform.squeeze(0), sampling_rate=16000)
  
      def __getattr__(self, attr):
          return getattr(self.feature_extractor, attr)
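Both HubertFeatureExtractor and Wav2VecFeatureExtractor rely on the same trick in __getattr__: attribute lookups that fail on the wrapper fall through to the wrapped Hugging Face extractor, so callers can treat the wrapper as a drop-in replacement. A self-contained sketch of that delegation pattern, with illustrative class names not taken from the repo:

class Proxy:
    def __init__(self, inner) -> None:
        self.inner = inner  # set normally, so lookups for it never reach __getattr__

    def __getattr__(self, attr):
        # Python calls __getattr__ only after normal attribute lookup fails,
        # so missing attributes are resolved on the wrapped object instead.
        return getattr(self.inner, attr)


class Extractor:
    sampling_rate = 16000


proxy = Proxy(Extractor())
print(proxy.sampling_rate)  # -> 16000, resolved on the wrapped Extractor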