from transformers import ( | |
AutoConfig | |
) | |
from .custom_model import CustomModelForAudioClassification | |
def make_model(hparams):
    """Build a ``CustomModelForAudioClassification`` from hyperparameters.

    The base configuration is loaded with ``AutoConfig.from_pretrained`` from
    the ``HF_MODEL_PATH`` hyperparameter. The task-specific settings in
    ``config_overrides`` are then copied onto that config before the model is
    constructed.

    Args:
        hparams: Either an object exposing the hyperparameters as attributes
            (e.g. an ``argparse.Namespace``) or a mapping keyed by the
            upper-case hyperparameter names.

    Returns:
        A newly constructed ``CustomModelForAudioClassification``.

    Raises:
        KeyError: if any required hyperparameter is missing. This is checked
            before the base config is loaded, so a bad setup fails fast and
            reports every missing name at once.
    """
    from collections.abc import Mapping

    # hparams key -> attribute name set on the loaded config, in the order
    # the attributes are assigned.
    config_overrides = {
        'MAX_DURATION': 'max_duration',
        'SAMPLING_RATE': 'sampling_rate',
        'OUTPUT_HIDDEN_STATES': 'output_hidden_states',
        'CLASSIFIER_NAME': 'classifier_name',
        'CLASSIFIER_PROJ_SIZE': 'classifier_proj_size',
        'NUM_LABELS': 'num_labels',
        'LABEL_WEIGHTS': 'label_weights',
        'LOSS': 'lossname',
    }

    # vars() only works on objects with a __dict__; plain mappings are used
    # as-is. Both are only read from here, never mutated.
    params = hparams if isinstance(hparams, Mapping) else vars(hparams)

    required = ['HF_MODEL_PATH', *config_overrides]
    missing = [key for key in required if key not in params]
    if missing:
        raise KeyError(
            f"make_model: missing hyperparameter(s): {', '.join(missing)}"
        )

    config = AutoConfig.from_pretrained(params['HF_MODEL_PATH'])
    for key, attr in config_overrides.items():
        setattr(config, attr, params[key])
    return CustomModelForAudioClassification(config)