File size: 801 Bytes
8f96165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from transformers import (
    AutoConfig
)
from .custom_model import CustomModelForAudioClassification

def make_model(hparams):
    """Return a model instance built from the provided hyperparameters.

    A base configuration is loaded with ``AutoConfig.from_pretrained`` from
    ``HF_MODEL_PATH``. Task-specific attributes are then overridden from the
    remaining hyperparameters, and the result is passed to
    ``CustomModelForAudioClassification``.

    Args:
        hparams: Either an attribute-style object such as an
            ``argparse.Namespace`` (read via ``vars``) or a mapping such as a
            ``dict``, keyed by upper-case names. Required keys:
            ``HF_MODEL_PATH``, ``MAX_DURATION``, ``SAMPLING_RATE``,
            ``OUTPUT_HIDDEN_STATES``, ``CLASSIFIER_NAME``,
            ``CLASSIFIER_PROJ_SIZE``, ``NUM_LABELS``, ``LABEL_WEIGHTS`` and
            ``LOSS``.

    Returns:
        A ``CustomModelForAudioClassification`` constructed from the
        overridden config.

    Raises:
        KeyError: If a required hyperparameter is missing.
        TypeError: If ``hparams`` is neither a mapping nor an object that
            ``vars()`` accepts (i.e. one without a ``__dict__``).
    """
    from collections.abc import Mapping

    # vars() raises TypeError on a plain dict, so accept mappings as-is and
    # only fall back to vars() for attribute-style objects (e.g. Namespace).
    params = hparams if isinstance(hparams, Mapping) else vars(hparams)

    # Load the base configuration from the configured model path.
    config = AutoConfig.from_pretrained(params['HF_MODEL_PATH'])

    # (config attribute, hyperparameter key) pairs, applied in this order.
    # The order matches the original hand-written assignments.
    overrides = (
        ('max_duration', 'MAX_DURATION'),
        ('sampling_rate', 'SAMPLING_RATE'),
        ('output_hidden_states', 'OUTPUT_HIDDEN_STATES'),
        ('classifier_name', 'CLASSIFIER_NAME'),
        ('classifier_proj_size', 'CLASSIFIER_PROJ_SIZE'),
        ('num_labels', 'NUM_LABELS'),
        ('label_weights', 'LABEL_WEIGHTS'),
        ('lossname', 'LOSS'),
    )
    for attr, key in overrides:
        setattr(config, attr, params[key])

    # NOTE(review): only the config is passed here. Whether pretrained
    # weights from HF_MODEL_PATH are loaded depends on
    # CustomModelForAudioClassification's constructor — confirm.
    return CustomModelForAudioClassification(config)