Spaces:
Running
Running
allow custom calculator with warning
Browse files
mlip_arena/tasks/utils.py
CHANGED
@@ -21,7 +21,7 @@ from pprint import pformat
|
|
21 |
|
22 |
|
23 |
def get_calculator(
|
24 |
-
calculator_name: str | MLIPEnum,
|
25 |
calculator_kwargs: dict | None,
|
26 |
dispersion: bool = False,
|
27 |
dispersion_kwargs: dict | None = None,
|
@@ -30,7 +30,7 @@ def get_calculator(
|
|
30 |
"""Get a calculator with optional dispersion correction."""
|
31 |
device = device or str(get_freer_device())
|
32 |
|
33 |
-
logger.info(
|
34 |
|
35 |
calculator_kwargs = calculator_kwargs or {}
|
36 |
|
@@ -39,10 +39,13 @@ def get_calculator(
|
|
39 |
calc = calculator_name.value(**calculator_kwargs)
|
40 |
elif isinstance(calculator_name, str) and hasattr(MLIPEnum, calculator_name):
|
41 |
calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
|
|
|
|
|
|
|
42 |
else:
|
43 |
raise ValueError(f"Invalid calculator: {calculator_name}")
|
44 |
|
45 |
-
logger.info(
|
46 |
if calculator_kwargs:
|
47 |
logger.info(pformat(calculator_kwargs))
|
48 |
|
@@ -58,7 +61,7 @@ def get_calculator(
|
|
58 |
)
|
59 |
calc = SumCalculator([calc, disp_calc])
|
60 |
|
61 |
-
logger.info(
|
62 |
if dispersion_kwargs:
|
63 |
logger.info(pformat(dispersion_kwargs))
|
64 |
|
|
|
21 |
|
22 |
|
23 |
def get_calculator(
|
24 |
+
calculator_name: str | MLIPEnum | Calculator,
|
25 |
calculator_kwargs: dict | None,
|
26 |
dispersion: bool = False,
|
27 |
dispersion_kwargs: dict | None = None,
|
|
|
30 |
"""Get a calculator with optional dispersion correction."""
|
31 |
device = device or str(get_freer_device())
|
32 |
|
33 |
+
logger.info("Using device: %s", device)
|
34 |
|
35 |
calculator_kwargs = calculator_kwargs or {}
|
36 |
|
|
|
39 |
calc = calculator_name.value(**calculator_kwargs)
|
40 |
elif isinstance(calculator_name, str) and hasattr(MLIPEnum, calculator_name):
|
41 |
calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
|
42 |
+
elif isinstance(calculator_name, Calculator):
|
43 |
+
logger.warning("Using custom calculator: {calculator_name}")
|
44 |
+
calc = calculator_name
|
45 |
else:
|
46 |
raise ValueError(f"Invalid calculator: {calculator_name}")
|
47 |
|
48 |
+
logger.info("Using calculator: %s", calc)
|
49 |
if calculator_kwargs:
|
50 |
logger.info(pformat(calculator_kwargs))
|
51 |
|
|
|
61 |
)
|
62 |
calc = SumCalculator([calc, disp_calc])
|
63 |
|
64 |
+
logger.info("Using dispersion: %s", disp_calc)
|
65 |
if dispersion_kwargs:
|
66 |
logger.info(pformat(dispersion_kwargs))
|
67 |
|