cyrusyc commited on
Commit
cb1fb61
·
1 Parent(s): a787930

allow custom calculator with warning

Browse files
Files changed (1) hide show
  1. mlip_arena/tasks/utils.py +7 -4
mlip_arena/tasks/utils.py CHANGED
@@ -21,7 +21,7 @@ from pprint import pformat
21
 
22
 
23
  def get_calculator(
24
- calculator_name: str | MLIPEnum,
25
  calculator_kwargs: dict | None,
26
  dispersion: bool = False,
27
  dispersion_kwargs: dict | None = None,
@@ -30,7 +30,7 @@ def get_calculator(
30
  """Get a calculator with optional dispersion correction."""
31
  device = device or str(get_freer_device())
32
 
33
- logger.info(f"Using device: {device}")
34
 
35
  calculator_kwargs = calculator_kwargs or {}
36
 
@@ -39,10 +39,13 @@ def get_calculator(
39
  calc = calculator_name.value(**calculator_kwargs)
40
  elif isinstance(calculator_name, str) and hasattr(MLIPEnum, calculator_name):
41
  calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
 
 
 
42
  else:
43
  raise ValueError(f"Invalid calculator: {calculator_name}")
44
 
45
- logger.info(f"Using calculator: {calc}")
46
  if calculator_kwargs:
47
  logger.info(pformat(calculator_kwargs))
48
 
@@ -58,7 +61,7 @@ def get_calculator(
58
  )
59
  calc = SumCalculator([calc, disp_calc])
60
 
61
- logger.info(f"Using dispersion: {disp_calc}")
62
  if dispersion_kwargs:
63
  logger.info(pformat(dispersion_kwargs))
64
 
 
21
 
22
 
23
  def get_calculator(
24
+ calculator_name: str | MLIPEnum | Calculator,
25
  calculator_kwargs: dict | None,
26
  dispersion: bool = False,
27
  dispersion_kwargs: dict | None = None,
 
30
  """Get a calculator with optional dispersion correction."""
31
  device = device or str(get_freer_device())
32
 
33
+ logger.info("Using device: %s", device)
34
 
35
  calculator_kwargs = calculator_kwargs or {}
36
 
 
39
  calc = calculator_name.value(**calculator_kwargs)
40
  elif isinstance(calculator_name, str) and hasattr(MLIPEnum, calculator_name):
41
  calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
42
+ elif isinstance(calculator_name, Calculator):
43
+ logger.warning("Using custom calculator: {calculator_name}")
44
+ calc = calculator_name
45
  else:
46
  raise ValueError(f"Invalid calculator: {calculator_name}")
47
 
48
+ logger.info("Using calculator: %s", calc)
49
  if calculator_kwargs:
50
  logger.info(pformat(calculator_kwargs))
51
 
 
61
  )
62
  calc = SumCalculator([calc, disp_calc])
63
 
64
+ logger.info("Using dispersion: %s", disp_calc)
65
  if dispersion_kwargs:
66
  logger.info(pformat(dispersion_kwargs))
67