"""Configuration and loader for the behaviour classification model."""
import argparse

import torch

from .make_model import make_model

# Settings handed to ``make_model``. Their exact meaning is defined by
# ``make_model``; the names suggest a wav2vec2 XLSR backbone with a
# "multilevel" classifier head — NOTE(review): confirm against make_model.
hparams_dict = {
    'HF_MODEL_PATH': 'facebook/wav2vec2-large-xlsr-53',
    'DATASET': 'recanvo',
    'MAX_DURATION': 4,
    'SAMPLING_RATE': 16_000,
    'OUTPUT_HIDDEN_STATES': True,
    'CLASSIFIER_NAME': 'multilevel',
    'CLASSIFIER_PROJ_SIZE': 256,
    'NUM_LABELS': 3,
    # NOTE(review): a single weight while NUM_LABELS is 3 — confirm how
    # make_model / the loss consumes this list.
    'LABEL_WEIGHTS': [1.0],
    'LOSS': 'cross-entropy',
    'GPU_ID': 0,
    'RETURN_RAW_ARRAY': False,
}

# Attribute-style view of the settings above, as expected by make_model.
hparams = argparse.Namespace(**hparams_dict)


def get_behaviour_model(classifier_weights_path, device):
    """Build the model and load trained classifier weights into it.

    The full model is constructed from the module-level ``hparams``. Only
    the parameters of its ``classifier`` sub-module are restored from
    ``classifier_weights_path``. The model is switched to evaluation mode
    before it is returned.

    Args:
        classifier_weights_path: Path (or file-like object) accepted by
            ``torch.load``; expected to hold a state dict for
            ``model.classifier``.
        device: Passed to ``torch.load`` as ``map_location``, so the loaded
            tensors are placed there.

    Returns:
        The model returned by ``make_model``, with classifier weights loaded
        and ``eval()`` already called.
    """
    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints (consider weights_only=True on recent torch versions).
    classifier_weights = torch.load(classifier_weights_path, map_location=device)

    # NOTE(review): the model itself is not moved to `device` here;
    # presumably make_model handles placement (e.g. via GPU_ID) — confirm.
    behaviour_model = make_model(hparams)
    behaviour_model.classifier.load_state_dict(classifier_weights)

    # Disable training-only behaviour before handing the model out.
    behaviour_model.eval()
    return behaviour_model