File size: 1,266 Bytes
519d358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Script to evaluate pretrained models.

from argparse import ArgumentParser
import logging
import sys

import torch

from demucs import train, pretrained, evaluate


def main():
    """Command-line entry point: evaluate a pre-trained model and print the results.

    Steps, in order:

    1. Restrict torch to a single thread and send INFO-level logs to stderr.
    2. Parse the command line: the model selection flags registered by
       `pretrained.add_model_flags`, plus any number of positional
       `overrides` (e.g. `test.shifts=2`).
    3. Build the experiment from those overrides with `train.main.get_xp`
       and create the solver while that experiment is entered.
    4. Swap the solver's model for the selected pre-trained one, moved to the
       solver's device and put in eval mode.
    5. Run `evaluate.evaluate` with gradients disabled and print what it
       returns.

    Takes no arguments (everything is read from `sys.argv`) and returns None.
    """
    torch.set_num_threads(1)
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)

    # Command line: model selection flags + free-form config overrides.
    description = "Evaluate pre-trained models or bags of models on MusDB."
    cli = ArgumentParser("tools.test_pretrained", description=description)
    pretrained.add_model_flags(cli)
    cli.add_argument(
        'overrides',
        nargs='*',
        help='Extra overrides, e.g. test.shifts=2.',
    )
    options = cli.parse_args()

    # Only the solver construction happens with the experiment entered.
    experiment = train.main.get_xp(options.overrides)
    with experiment.enter():
        solver = train.get_solver(experiment.cfg)

    # Replace whatever model the solver built with the pre-trained one.
    solver.model = pretrained.get_model_from_args(options).to(solver.device)
    solver.model.eval()

    # Evaluation is inference only, so gradient tracking is switched off.
    with torch.no_grad():
        scores = evaluate.evaluate(solver, experiment.cfg.test.sdr)
    print(scores)


# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()