MilesCranmer commited on
Commit
d423f0c
·
1 Parent(s): 215a692

Refactor table env generator

Browse files
Files changed (2) hide show
  1. pysr/export_latex.py +7 -7
  2. pysr/sr.py +4 -9
pysr/export_latex.py CHANGED
@@ -23,7 +23,7 @@ def to_latex(expr, prec=3, full_prec=True, **settings):
23
  return printer.doprint(expr)
24
 
25
 
26
- def generate_top_of_latex_table(columns=["equation", "complexity", "loss"]):
27
  margins = "".join([("l" if col == "equation" else "c") for col in columns])
28
  column_map = {
29
  "complexity": "Complexity",
@@ -32,7 +32,7 @@ def generate_top_of_latex_table(columns=["equation", "complexity", "loss"]):
32
  "score": "Score",
33
  }
34
  columns = [column_map[col] for col in columns]
35
- latex_table_pieces = [
36
  r"\begin{table}[h]",
37
  r"\begin{center}",
38
  r"\begin{tabular}{@{}" + margins + r"@{}}",
@@ -40,14 +40,14 @@ def generate_top_of_latex_table(columns=["equation", "complexity", "loss"]):
40
  " & ".join(columns) + r" \\",
41
  r"\midrule",
42
  ]
43
- return "\n".join(latex_table_pieces)
44
 
45
-
46
- def generate_bottom_of_latex_table():
47
- latex_table_pieces = [
48
  r"\bottomrule",
49
  r"\end{tabular}",
50
  r"\end{center}",
51
  r"\end{table}",
52
  ]
53
- return "\n".join(latex_table_pieces)
 
 
 
 
23
  return printer.doprint(expr)
24
 
25
 
26
+ def generate_table_environment(columns=["equation", "complexity", "loss"]):
27
  margins = "".join([("l" if col == "equation" else "c") for col in columns])
28
  column_map = {
29
  "complexity": "Complexity",
 
32
  "score": "Score",
33
  }
34
  columns = [column_map[col] for col in columns]
35
+ top_pieces = [
36
  r"\begin{table}[h]",
37
  r"\begin{center}",
38
  r"\begin{tabular}{@{}" + margins + r"@{}}",
 
40
  " & ".join(columns) + r" \\",
41
  r"\midrule",
42
  ]
 
43
 
44
+ bottom_pieces = [
 
 
45
  r"\bottomrule",
46
  r"\end{tabular}",
47
  r"\end{center}",
48
  r"\end{table}",
49
  ]
50
+ top_latex_table = "\n".join(top_pieces)
51
+ bottom_latex_table = "\n".join(bottom_pieces)
52
+
53
+ return top_latex_table, bottom_latex_table
pysr/sr.py CHANGED
@@ -27,11 +27,7 @@ from .julia_helpers import (
27
  import_error_string,
28
  )
29
  from .export_numpy import CallableEquation
30
- from .export_latex import (
31
- to_latex,
32
- generate_top_of_latex_table,
33
- generate_bottom_of_latex_table,
34
- )
35
  from .deprecated import make_deprecated_kwargs_for_pysr_regressor
36
 
37
 
@@ -2037,8 +2033,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2037
  else:
2038
  indices = list(range(len(self.equations_)))
2039
 
2040
- latex_table_top = generate_top_of_latex_table(columns)
2041
- latex_table_bottom = generate_bottom_of_latex_table()
2042
 
2043
  equations = self.equations_
2044
 
@@ -2092,9 +2087,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2092
  all_latex_table_str.append(
2093
  "\n".join(
2094
  [
2095
- latex_table_top,
2096
  *latex_table_content,
2097
- latex_table_bottom,
2098
  ]
2099
  )
2100
  )
 
27
  import_error_string,
28
  )
29
  from .export_numpy import CallableEquation
30
+ from .export_latex import to_latex, generate_table_environment
 
 
 
 
31
  from .deprecated import make_deprecated_kwargs_for_pysr_regressor
32
 
33
 
 
2033
  else:
2034
  indices = list(range(len(self.equations_)))
2035
 
2036
+ latex_top, latex_bottom = generate_table_environment(columns)
 
2037
 
2038
  equations = self.equations_
2039
 
 
2087
  all_latex_table_str.append(
2088
  "\n".join(
2089
  [
2090
+ latex_top,
2091
  *latex_table_content,
2092
+ latex_bottom,
2093
  ]
2094
  )
2095
  )