File size: 742 Bytes
8f96165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f5f788
 
8f96165
3f5f788
 
8f96165
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import argparse
import torch
from .make_model import make_model


# Fixed settings handed to `make_model` through the `hparams` namespace below.
# How each key is interpreted is decided inside `make_model` (not visible in
# this module), so the values are listed here without further commentary.
hparams_dict = dict(
    HF_MODEL_PATH='facebook/wav2vec2-large-xlsr-53',
    DATASET='recanvo',
    MAX_DURATION=4,
    SAMPLING_RATE=16000,
    OUTPUT_HIDDEN_STATES=True,
    CLASSIFIER_NAME='multilevel',
    CLASSIFIER_PROJ_SIZE=256,
    NUM_LABELS=3,
    LABEL_WEIGHTS=[1.0],
    LOSS='cross-entropy',
    GPU_ID=0,
    RETURN_RAW_ARRAY=False,
)

# Attribute-style view of the same settings (e.g. `hparams.NUM_LABELS`).
hparams = argparse.Namespace(**hparams_dict)

def get_behaviour_model(classifier_weights_path, device):
    """Build the model from the module-level ``hparams`` and load classifier weights.

    Args:
        classifier_weights_path: Location of the saved classifier state dict,
            passed straight to ``torch.load``.
        device: Forwarded to ``torch.load`` as ``map_location``. This function
            does not itself move the model to ``device``.
            NOTE(review): confirm whether ``make_model`` handles placement.

    Returns:
        The object returned by ``make_model(hparams)``, with its ``classifier``
        sub-module populated from the checkpoint and switched to eval mode.
    """
    # Read the checkpoint first so a bad path fails before the model is built.
    classifier_weights = torch.load(
        classifier_weights_path,
        map_location=device,
    )

    behaviour_model = make_model(hparams)
    # Only the classifier head receives the loaded weights.
    behaviour_model.classifier.load_state_dict(classifier_weights)
    behaviour_model.eval()
    return behaviour_model