File size: 801 Bytes
8f96165 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
from transformers import (
AutoConfig
)
from .custom_model import CustomModelForAudioClassification
def make_model(hparams):
    """Return a model instance based on the provided hyperparameters.

    Loads a Hugging Face config from ``HF_MODEL_PATH``, copies the
    task-specific hyperparameters onto it, and passes the resulting config
    to ``CustomModelForAudioClassification``.

    Args:
        hparams: Either a mapping (e.g. a ``dict``) or a namespace-style
            object whose attributes are exposed through ``vars()`` (e.g. an
            ``argparse.Namespace``). It must provide the keys
            ``HF_MODEL_PATH``, ``MAX_DURATION``, ``SAMPLING_RATE``,
            ``OUTPUT_HIDDEN_STATES``, ``CLASSIFIER_NAME``,
            ``CLASSIFIER_PROJ_SIZE``, ``NUM_LABELS``, ``LABEL_WEIGHTS``
            and ``LOSS``.

    Returns:
        The ``CustomModelForAudioClassification`` built from the config.
        NOTE(review): the model is constructed from the config only;
        whether any pretrained weights get loaded is decided inside
        ``CustomModelForAudioClassification`` -- confirm there.

    Raises:
        KeyError: If a required hyperparameter is missing.
        TypeError: If ``hparams`` is neither a mapping nor an object that
            ``vars()`` accepts.
    """
    from collections.abc import Mapping

    # vars() raises TypeError on a plain dict, so mappings are used directly
    # and only namespace-style objects are converted.
    params = hparams if isinstance(hparams, Mapping) else vars(hparams)

    config = AutoConfig.from_pretrained(params['HF_MODEL_PATH'])

    # (config attribute, hyperparameter key) pairs, applied in order.
    overrides = (
        ('max_duration', 'MAX_DURATION'),
        ('sampling_rate', 'SAMPLING_RATE'),
        ('output_hidden_states', 'OUTPUT_HIDDEN_STATES'),
        ('classifier_name', 'CLASSIFIER_NAME'),
        ('classifier_proj_size', 'CLASSIFIER_PROJ_SIZE'),
        ('num_labels', 'NUM_LABELS'),
        ('label_weights', 'LABEL_WEIGHTS'),
        ('lossname', 'LOSS'),
    )
    for attr_name, hparam_key in overrides:
        setattr(config, attr_name, params[hparam_key])

    return CustomModelForAudioClassification(config)