JuanJoseMV's picture
load classifier weights
3f5f788
raw
history blame contribute delete
742 Bytes
import argparse
import torch
from .make_model import make_model
# Fixed hyperparameters passed to ``make_model`` when the model is rebuilt in
# ``get_behaviour_model``. NOTE(review): presumably these must match the
# configuration the saved classifier weights were trained with — confirm
# before changing any value.
hparams_dict = dict(
    HF_MODEL_PATH='facebook/wav2vec2-large-xlsr-53',
    DATASET='recanvo',
    MAX_DURATION=4,  # unit not visible here (seconds?) — TODO confirm
    SAMPLING_RATE=16_000,  # presumably Hz
    OUTPUT_HIDDEN_STATES=True,
    CLASSIFIER_NAME='multilevel',
    CLASSIFIER_PROJ_SIZE=256,
    NUM_LABELS=3,
    # NOTE(review): a single weight while NUM_LABELS is 3 — confirm this is
    # intended (it may only matter for the training loss, not inference).
    LABEL_WEIGHTS=[1.0],
    LOSS='cross-entropy',
    GPU_ID=0,
    RETURN_RAW_ARRAY=False,
)

# Attribute-style view of the same settings (e.g. ``hparams.NUM_LABELS``).
hparams = argparse.Namespace(**hparams_dict)
def get_behaviour_model(classifier_weights_path, device):
    """Build the behaviour model and load trained classifier-head weights.

    The full model is constructed from the module-level ``hparams`` via
    ``make_model``; only ``model.classifier`` is then overwritten with the
    weights stored at ``classifier_weights_path``.

    Args:
        classifier_weights_path: Path (or file-like object) accepted by
            ``torch.load``, holding a state dict for ``model.classifier``.
        device: Target device (e.g. ``"cpu"``, ``"cuda:0"`` or a
            ``torch.device``). The checkpoint tensors are mapped onto it while
            loading, and the returned model is moved onto it.

    Returns:
        The model, placed on ``device`` and switched to evaluation mode.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict by default) if the
            checkpoint's keys or shapes do not match ``model.classifier``.
    """
    # NOTE(review): torch.load unpickles arbitrary Python objects unless
    # weights_only=True is passed (that became the default only in PyTorch
    # 2.6). Only load checkpoints from trusted sources, or pass
    # weights_only=True once the minimum supported PyTorch version allows it.
    state_dict = torch.load(classifier_weights_path, map_location=device)

    model = make_model(hparams)
    model.classifier.load_state_dict(state_dict)

    # load_state_dict copies values into the existing parameters and does not
    # change where they live, so map_location alone never places the model on
    # ``device``. Move it explicitly so the argument applies to the result.
    model = model.to(device)
    model.eval()
    return model