|
import argparse |
|
import torch |
|
from .make_model import make_model |
|
|
|
|
|
# Fixed hyperparameters for the behaviour model. They are wrapped in an
# argparse.Namespace below so they can be read as attributes
# (e.g. ``hparams.NUM_LABELS``); ``hparams`` is what gets passed to
# ``make_model``.
hparams_dict = {
    'HF_MODEL_PATH': 'facebook/wav2vec2-large-xlsr-53',
    'DATASET': 'recanvo',
    'MAX_DURATION': 4,  # presumably seconds -- TODO confirm against make_model
    'SAMPLING_RATE': 16_000,  # presumably Hz -- TODO confirm
    'OUTPUT_HIDDEN_STATES': True,
    'CLASSIFIER_NAME': 'multilevel',
    'CLASSIFIER_PROJ_SIZE': 256,
    'NUM_LABELS': 3,
    # NOTE(review): a single weight is listed while NUM_LABELS is 3 --
    # confirm this is intended by whatever consumes LABEL_WEIGHTS.
    'LABEL_WEIGHTS': [1.0],
    'LOSS': 'cross-entropy',
    'GPU_ID': 0,
    'RETURN_RAW_ARRAY': False,
}

# Attribute-style view of the same settings.
hparams = argparse.Namespace(**hparams_dict)
|
|
|
def get_behaviour_model(classifier_weights_path, device):
    """Build the behaviour model and load its classifier weights.

    The model is created with ``make_model(hparams)`` using the module-level
    ``hparams``. Only the ``classifier`` sub-module receives weights from the
    checkpoint; the rest of the model is left as ``make_model`` returned it.

    Args:
        classifier_weights_path: Path (or file-like object) handed to
            ``torch.load``; expected to contain a state dict compatible
            with ``model.classifier``.
        device: Passed to ``torch.load`` as ``map_location``, i.e. where the
            loaded tensors are placed. This function does not itself move
            the model to ``device``.

    Returns:
        The model produced by ``make_model``, with classifier weights loaded
        and switched to evaluation mode.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources (consider ``weights_only=True`` if the installed
    # torch version supports it).
    state_dict = torch.load(classifier_weights_path, map_location=device)

    model = make_model(hparams)
    model.classifier.load_state_dict(state_dict)

    # Evaluation mode: disables training-only behaviour such as dropout.
    model.eval()

    return model