PySR / pysr /export_latex.py
MilesCranmer's picture
test: fix mypy errors with sympy
e84bed4 unverified
raw
history blame
5.13 kB
"""Functions to help export PySR equations to LaTeX."""
from typing import List, Optional, Tuple
import pandas as pd
import sympy # type: ignore
from sympy.printing.latex import LatexPrinter # type: ignore
class PreciseLatexPrinter(LatexPrinter):
    """SymPy LaTeX printer that re-creates floats at a custom precision.

    Every float is rebuilt with ``sympy.Float(expr, prec)`` before being
    handed to the stock ``LatexPrinter`` float formatting.
    """

    def __init__(self, settings=None, prec=3):
        """Store ``prec`` after the base printer has consumed ``settings``."""
        super().__init__(settings)
        self.prec = prec

    def _print_Float(self, expr):
        """Print ``expr`` after re-creating it with precision ``self.prec``."""
        return super()._print_Float(sympy.Float(expr, self.prec))
def sympy2latex(expr, prec=3, full_prec=True, **settings) -> str:
    """Render a sympy expression as a LaTeX string using ``PreciseLatexPrinter``.

    ``full_prec`` is forwarded to the printer as a setting next to any extra
    keyword ``settings``; ``prec`` is the float precision given to the printer.
    """
    printer_settings = {**settings, "full_prec": full_prec}
    latex_printer = PreciseLatexPrinter(settings=printer_settings, prec=prec)
    return str(latex_printer.doprint(expr))
def generate_table_environment(
    columns: Optional[List[str]] = None,
) -> Tuple[str, str]:
    """Build the opening and closing LaTeX of a booktabs-style table.

    Args:
        columns: Column keys in display order. Each key must be one of
            ``"equation"``, ``"complexity"``, ``"loss"`` or ``"score"``.
            When omitted (or ``None``) the columns are
            ``["equation", "complexity", "loss"]``.

    Returns:
        A ``(top, bottom)`` pair of strings. ``top`` opens the ``table``,
        ``center`` and ``tabular`` environments and contains the header row;
        ``bottom`` closes them again. Table rows belong between the two.

    Raises:
        KeyError: If a column key has no known header title.
    """
    # A ``None`` sentinel avoids sharing one mutable list object as the
    # default value across calls.
    if columns is None:
        columns = ["equation", "complexity", "loss"]

    header_titles = {
        "complexity": "Complexity",
        "loss": "Loss",
        "equation": "Equation",
        "score": "Score",
    }
    headers = [header_titles[col] for col in columns]

    # One centered column per requested key.
    alignment = "c" * len(columns)

    top_latex_table = "\n".join(
        [
            r"\begin{table}[h]",
            r"\begin{center}",
            r"\begin{tabular}{@{}" + alignment + r"@{}}",
            r"\toprule",
            " & ".join(headers) + r" \\",
            r"\midrule",
        ]
    )
    bottom_latex_table = "\n".join(
        [
            r"\bottomrule",
            r"\end{tabular}",
            r"\end{center}",
            r"\end{table}",
        ]
    )
    return top_latex_table, bottom_latex_table
def sympy2latextable(
    equations: pd.DataFrame,
    indices: Optional[List[int]] = None,
    precision: int = 3,
    columns: Optional[List[str]] = None,
    max_equation_length: int = 50,
    output_variable_name: str = "y",
) -> str:
    """Generate a booktabs-style LaTeX table for a single set of equations.

    Args:
        equations: DataFrame with one equation per row. Only the fields
            needed for the requested ``columns`` are read:
            ``"sympy_format"`` for ``"equation"``, and the fields
            ``"complexity"``, ``"loss"`` and ``"score"`` for the columns of
            the same name.
        indices: Positional row numbers (as used by ``DataFrame.iloc``) to
            include, in output order. ``None`` selects every row.
        precision: Float precision forwarded to ``sympy2latex`` as ``prec``.
        columns: Column keys in display order. When omitted (or ``None``)
            the columns are ``["equation", "complexity", "loss", "score"]``.
        max_equation_length: Equations whose LaTeX source has at least this
            many characters are wrapped in a ``minipage`` containing a
            ``dmath*`` environment instead of inline math.
        output_variable_name: Left-hand side written in front of each
            equation.

    Returns:
        The complete table as a single LaTeX string.

    Raises:
        ValueError: If a column key is not recognised while building a row.
            Note that the header construction in
            ``generate_table_environment`` runs first and may reject an
            unknown key on its own.
    """
    assert isinstance(equations, pd.DataFrame)

    # ``None`` sentinel instead of a shared mutable default list.
    if columns is None:
        columns = ["equation", "complexity", "loss", "score"]

    latex_top, latex_bottom = generate_table_environment(columns)

    if indices is None:
        # Rows are fetched with ``.iloc`` (positional), so the default must be
        # positions rather than index labels; the two only coincide for a
        # default RangeIndex.
        indices = list(range(len(equations)))

    latex_table_content = []
    for i in indices:
        row = equations.iloc[i]
        row_pieces = []
        # Each field is read only when its column was requested, so a frame
        # lacking e.g. "score" still works when "score" is not displayed.
        for col in columns:
            if col == "equation":
                latex_equation = sympy2latex(row["sympy_format"], prec=precision)
                if len(latex_equation) < max_equation_length:
                    row_pieces.append(
                        "$" + output_variable_name + " = " + latex_equation + "$"
                    )
                else:
                    # Long equation: let dmath* (breqn) break it across lines.
                    broken_latex_equation = " ".join(
                        [
                            r"\begin{minipage}{0.8\linewidth}",
                            r"\vspace{-1em}",
                            r"\begin{dmath*}",
                            output_variable_name + " = " + latex_equation,
                            r"\end{dmath*}",
                            r"\end{minipage}",
                        ]
                    )
                    row_pieces.append(broken_latex_equation)
            elif col == "complexity":
                row_pieces.append("$" + str(row["complexity"]) + "$")
            elif col in ("loss", "score"):
                value = sympy2latex(sympy.Float(row[col]), prec=precision)
                row_pieces.append("$" + value + "$")
            else:
                raise ValueError(f"Unknown column: {col}")

        latex_table_content.append(" & ".join(row_pieces) + r" \\")

    return "\n".join([latex_top, *latex_table_content, latex_bottom])
def sympy2multilatextable(
    equations: List[pd.DataFrame],
    indices: Optional[List[List[int]]] = None,
    precision: int = 3,
    columns: Optional[List[str]] = None,
    output_variable_names: Optional[List[str]] = None,
) -> str:
    """Generate multiple LaTeX tables for a list of equation sets.

    Args:
        equations: One DataFrame per output; each is rendered with
            ``sympy2latextable``.
        indices: Per-table row selections, aligned with ``equations``. When
            ``None`` or empty, every table uses all of its rows.
        precision: Float precision forwarded to ``sympy2latextable``.
        columns: Column keys shared by all tables. When omitted (or ``None``)
            the columns are ``["equation", "complexity", "loss", "score"]``.
        output_variable_names: Left-hand-side names, aligned with
            ``equations``. When ``None``, table ``i`` uses ``y_{i}``.

    Returns:
        All tables joined by a blank line.
    """
    # Resolve the default here so the callee always receives a real list,
    # without relying on a shared mutable default object.
    if columns is None:
        columns = ["equation", "complexity", "loss", "score"]

    latex_tables = [
        sympy2latextable(
            table_equations,
            (None if not indices else indices[i]),
            precision=precision,
            columns=columns,
            output_variable_name=(
                "y_{" + str(i) + "}"
                if output_variable_names is None
                else output_variable_names[i]
            ),
        )
        for i, table_equations in enumerate(equations)
    ]

    return "\n\n".join(latex_tables)
def with_preamble(table_string: str) -> str:
    """Prefix a table with the ``\\usepackage`` lines it relies on.

    The result lists the ``breqn`` and ``booktabs`` packages, then a ``...``
    placeholder line surrounded by blank lines, then ``table_string``.
    """
    header = "\n".join(
        (
            r"\usepackage{breqn}",
            r"\usepackage{booktabs}",
            "",
            "...",
            "",
        )
    )
    return header + "\n" + table_string