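"""Fine-tunes a wav2vec2 audio classification model with the Hugging Face Trainer.

Builds a waveform preprocessing + wav2vec2 feature-extraction front end, wraps the
project datasets for the Trainer API, and trains a classifier over the configured
dance classes.
"""
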
from typing import Any

import pytorch_lightning as pl
from torch.utils.data import random_split
from transformers import (
    AutoModelForAudioClassification,
    AutoProcessor,
    Trainer,
    TrainingArguments,
)

from preprocessing.dataset import HuggingFaceDatasetWrapper, get_datasets
from preprocessing.pipelines import WaveformTrainingPipeline

from .utils import compute_hf_metrics, get_id_label_mapping

MODEL_CHECKPOINT = "yuval6967/wav2vec2-base-finetuned-gtzan"
PROCESSOR_CHECKPOINT = "facebook/wav2vec2-base"


class Wav2VecFeatureExtractor:
    """Applies the waveform training pipeline, then the wav2vec2 processor."""

    def __init__(self) -> None:
        self.waveform_pipeline = WaveformTrainingPipeline()
        self.feature_extractor = AutoProcessor.from_pretrained(PROCESSOR_CHECKPOINT)

    def __call__(self, waveform) -> Any:
        waveform = self.waveform_pipeline(waveform)
        return self.feature_extractor(waveform.squeeze(0), sampling_rate=16000)

    def __getattr__(self, attr):
        # Delegate any other attribute access to the underlying processor.
        return getattr(self.feature_extractor, attr)
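
# Example usage (illustrative sketch; assumes the waveform pipeline accepts a mono
# waveform tensor of shape (1, num_samples) and yields 16 kHz audio for the processor):
#   extractor = Wav2VecFeatureExtractor()
#   features = extractor(torch.randn(1, 16000))  # wav2vec2 input features for a 1 s clip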


def train_huggingface(config: dict):
    TARGET_CLASSES = config["dance_ids"]
    DEVICE = config["device"]
    SEED = config["seed"]
    OUTPUT_DIR = "models/weights/wav2vec2"
    batch_size = config["data_module"]["batch_size"]
    epochs = config["trainer"]["min_epochs"]
    test_proportion = config["data_module"].get("test_proportion", 0.2)
    pl.seed_everything(SEED, workers=True)

    feature_extractor = Wav2VecFeatureExtractor()
    dataset = get_datasets(config["datasets"], feature_extractor)
    dataset = HuggingFaceDatasetWrapper(dataset)
    id2label, label2id = get_id_label_mapping(TARGET_CLASSES)

    # Split the wrapped dataset into train/test partitions.
    train_proportion = 1 - test_proportion
    train_ds, test_ds = random_split(dataset, [train_proportion, test_proportion])

    # Load the fine-tuned checkpoint and resize the classification head to the target classes.
    model = AutoModelForAudioClassification.from_pretrained(
        MODEL_CHECKPOINT,
        num_labels=len(TARGET_CLASSES),
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
    ).to(DEVICE)

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=3e-5,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=5,
        gradient_checkpointing=True,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        push_to_hub=False,
        use_mps_device=DEVICE == "mps",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        compute_metrics=compute_hf_metrics,
    )
    trainer.train()
    return model
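

# Illustrative config layout expected by train_huggingface (a sketch based on the keys
# read above; the concrete values and dataset spec are placeholders, not project defaults):
#
#   config = {
#       "dance_ids": [...],                     # target class labels
#       "device": "cuda",                       # or "cpu" / "mps"
#       "seed": 42,
#       "datasets": {...},                      # forwarded to get_datasets()
#       "data_module": {"batch_size": 8, "test_proportion": 0.2},
#       "trainer": {"min_epochs": 5},
#   }
#   model = train_huggingface(config)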