John Waidhofer commited on
Commit
4078103
1 Parent(s): 6193575

updated to gtzan

Browse files
Files changed (1) hide show
  1. models/wav2vec2.py +5 -6
models/wav2vec2.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any
3
  import pytorch_lightning as pl
4
  from torch.utils.data import random_split
5
  from transformers import AutoFeatureExtractor
6
- from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer
7
 
8
  from preprocessing.dataset import (
9
  HuggingFaceDatasetWrapper,
@@ -13,20 +13,18 @@ from preprocessing.pipelines import WaveformTrainingPipeline
13
 
14
  from .utils import get_id_label_mapping, compute_hf_metrics
15
 
16
- MODEL_CHECKPOINT = "m3hrdadfi/wav2vec2-base-100k-voxpopuli-gtzan-music"
17
 
18
 
19
  class Wav2VecFeatureExtractor:
20
  def __init__(self) -> None:
21
  self.waveform_pipeline = WaveformTrainingPipeline()
22
- self.feature_extractor = AutoFeatureExtractor.from_pretrained(
23
- MODEL_CHECKPOINT,
24
- )
25
 
26
  def __call__(self, waveform) -> Any:
27
  waveform = self.waveform_pipeline(waveform)
28
  return self.feature_extractor(
29
- waveform.squeeze(0), sampling_rate=self.feature_extractor.sampling_rate
30
  )
31
 
32
  def __getattr__(self, attr):
@@ -64,6 +62,7 @@ def train_huggingface(config: dict):
64
  learning_rate=3e-5,
65
  per_device_train_batch_size=batch_size,
66
  gradient_accumulation_steps=5,
 
67
  per_device_eval_batch_size=batch_size,
68
  num_train_epochs=epochs,
69
  warmup_ratio=0.1,
 
3
  import pytorch_lightning as pl
4
  from torch.utils.data import random_split
5
  from transformers import AutoFeatureExtractor
6
+ from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer, AutoProcessor
7
 
8
  from preprocessing.dataset import (
9
  HuggingFaceDatasetWrapper,
 
13
 
14
  from .utils import get_id_label_mapping, compute_hf_metrics
15
 
16
+ MODEL_CHECKPOINT = "yuval6967/wav2vec2-base-finetuned-gtzan"
17
 
18
 
19
  class Wav2VecFeatureExtractor:
20
  def __init__(self) -> None:
21
  self.waveform_pipeline = WaveformTrainingPipeline()
22
+ self.feature_extractor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")
 
 
23
 
24
  def __call__(self, waveform) -> Any:
25
  waveform = self.waveform_pipeline(waveform)
26
  return self.feature_extractor(
27
+ waveform.squeeze(0), sampling_rate=16000
28
  )
29
 
30
  def __getattr__(self, attr):
 
62
  learning_rate=3e-5,
63
  per_device_train_batch_size=batch_size,
64
  gradient_accumulation_steps=5,
65
+ gradient_checkpointing=True,
66
  per_device_eval_batch_size=batch_size,
67
  num_train_epochs=epochs,
68
  warmup_ratio=0.1,