File size: 1,585 Bytes
139b8d0
976f8d8
139b8d0
3876e75
618a3f8
 
 
 
92eb30b
618a3f8
 
3555cfd
3876e75
01ec39f
3876e75
01ec39f
f145620
3876e75
 
2a98f83
d8d6e2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3876e75
139b8d0
 
 
3555cfd
 
92eb30b
3555cfd
 
2a98f83
 
3555cfd
2a98f83
98fa83e
92eb30b
98fa83e
2a98f83
8685680
 
 
 
 
 
 
 
 
92eb30b
 
2a98f83
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import warnings

import click

from ..test import (
    get_runtests_cli,
    runtests,
    runtests_jax,
    runtests_startup,
    runtests_torch,
)


@click.group("pysr")
@click.pass_context
def pysr(context):
    # Root ``pysr`` command group; subcommands attach via ``@pysr.command(...)``.
    # ``context`` is injected by ``@click.pass_context`` but is not used here
    # (the previous ``ctx = context`` assignment was dead code).
    # No docstring on purpose: click would turn it into this group's --help text.
    pass


@pysr.command("install", help="DEPRECATED (dependencies are now installed at import).")
@click.option("-p", "julia_project", "--project", default=None, type=str)
@click.option("-q", "--quiet", is_flag=True, default=False, help="Disable logging.")
@click.option("--precompile", "precompile", flag_value=True, default=None)
@click.option("--no-precompile", "precompile", flag_value=False, default=None)
def _install(julia_project, quiet, precompile):
    """Emit a deprecation warning; the ``install`` subcommand performs no work.

    Every option is still declared (and ignored) so that existing invocations
    such as ``pysr install --project X --no-precompile`` keep parsing.
    """
    deprecation_message = (
        "This command is deprecated. "
        "Julia dependencies are now installed at first import."
    )
    warnings.warn(deprecation_message)


TEST_OPTIONS = {"main", "jax", "torch", "cli", "startup"}


@pysr.command("test")
@click.argument("tests", nargs=1)
def _tests(tests):
    """Run parts of the PySR test suite.

    Choose from main, jax, torch, cli, and startup. You can give multiple tests, separated by commas.
    """
    # Zero-argument runner for each name in TEST_OPTIONS (keep the two in sync).
    # The CLI suite is obtained lazily so get_runtests_cli() is only called
    # when "cli" is actually requested.
    runners = {
        "main": runtests,
        "jax": runtests_jax,
        "torch": runtests_torch,
        "cli": lambda: get_runtests_cli()(),
        "startup": runtests_startup,
    }

    # Accept "main, jax" and trailing commas: strip whitespace, drop empty names.
    requested = [name.strip() for name in tests.split(",")]
    requested = [name for name in requested if name]

    # Report unknown names up front, so a typo is not discovered only after
    # earlier (possibly long-running) suites have already finished.
    for name in requested:
        if name not in TEST_OPTIONS:
            warnings.warn(f"Invalid test {name}. Skipping.")

    # Run the valid suites in the order the user gave them.
    for name in requested:
        if name in TEST_OPTIONS:
            runners[name]()