from typing import Any

import pytorch_lightning as pl
from torch.utils.data import random_split
from transformers import (
    AutoModelForAudioClassification,
    AutoProcessor,
    Trainer,
    TrainingArguments,
)

from preprocessing.dataset import HuggingFaceDatasetWrapper, get_datasets
from preprocessing.pipelines import WaveformTrainingPipeline

from .utils import compute_hf_metrics, get_id_label_mapping
# Fine-tuned wav2vec2 checkpoint used as the starting point for training.
MODEL_CHECKPOINT = "yuval6967/wav2vec2-base-finetuned-gtzan"
# Base checkpoint whose processor handles waveform feature extraction.
PROCESSOR_CHECKPOINT = "facebook/wav2vec2-base"
class Wav2VecFeatureExtractor:
    """Chains the project's waveform preprocessing with the wav2vec2 processor."""

    def __init__(self) -> None:
        self.waveform_pipeline = WaveformTrainingPipeline()
        self.feature_extractor = AutoProcessor.from_pretrained(PROCESSOR_CHECKPOINT)

    def __call__(self, waveform) -> Any:
        waveform = self.waveform_pipeline(waveform)
        # wav2vec2 expects mono 16 kHz audio.
        return self.feature_extractor(waveform.squeeze(0), sampling_rate=16000)

    def __getattr__(self, attr):
        # Delegate unknown attribute lookups to the underlying processor so
        # this wrapper can stand in for it wherever a processor is expected.
        return getattr(self.feature_extractor, attr)
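
# Illustrative use (assumes WaveformTrainingPipeline yields a (1, n) float
# waveform tensor at 16 kHz; the result is a dict-like BatchFeature whose
# "input_values" entry feeds the model):
#   extractor = Wav2VecFeatureExtractor()
#   features = extractor(torch.randn(1, 16_000))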
def train_huggingface(config: dict):
    TARGET_CLASSES = config["dance_ids"]
    DEVICE = config["device"]
    SEED = config["seed"]
    OUTPUT_DIR = "models/weights/wav2vec2"
    batch_size = config["data_module"]["batch_size"]
    epochs = config["trainer"]["min_epochs"]
    test_proportion = config["data_module"].get("test_proportion", 0.2)

    pl.seed_everything(SEED, workers=True)
    feature_extractor = Wav2VecFeatureExtractor()
    dataset = get_datasets(config["datasets"], feature_extractor)
    dataset = HuggingFaceDatasetWrapper(dataset)
    id2label, label2id = get_id_label_mapping(TARGET_CLASSES)

    # Fractional split lengths require torch >= 1.13.
    train_proportion = 1 - test_proportion
    train_ds, test_ds = random_split(dataset, [train_proportion, test_proportion])
    model = AutoModelForAudioClassification.from_pretrained(
        MODEL_CHECKPOINT,
        num_labels=len(TARGET_CLASSES),
        label2id=label2id,
        id2label=id2label,
        # The checkpoint was fine-tuned on a different label set (GTZAN genres),
        # so let from_pretrained reinitialize the classification head.
        ignore_mismatched_sizes=True,
    ).to(DEVICE)
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=3e-5,
        per_device_train_batch_size=batch_size,
        # Effective train batch size is batch_size * 5; gradient checkpointing
        # trades recomputation for lower memory use.
        gradient_accumulation_steps=5,
        gradient_checkpointing=True,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        push_to_hub=False,
        use_mps_device=DEVICE == "mps",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        compute_metrics=compute_hf_metrics,
    )
    trainer.train()
    return model
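
# Illustrative call site (assumed, not part of the original source). The keys
# mirror the lookups in train_huggingface; the values and class names are
# placeholders, and the "datasets" entry must match what get_datasets expects:
#
#   config = {
#       "dance_ids": ["waltz", "tango"],  # hypothetical class names
#       "device": "cpu",
#       "seed": 42,
#       "datasets": {...},  # project-specific dataset spec
#       "data_module": {"batch_size": 8, "test_proportion": 0.2},
#       "trainer": {"min_epochs": 3},
#   }
#   model = train_huggingface(config)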