File size: 1,394 Bytes
f257e58
9a5df63
 
f257e58
 
9a5df63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f257e58
 
8f218cc
0cb1353
f257e58
 
6d5ddcb
8f218cc
f257e58
8f218cc
f257e58
 
 
 
 
 
 
 
 
6d5ddcb
f257e58
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""Functions to help export PySR equations to LaTeX."""
import sympy
from sympy.printing.latex import LatexPrinter


class PreciseLatexPrinter(LatexPrinter):
    """SymPy ``LatexPrinter`` that prints floats with a reduced precision.

    Parameters
    ----------
    settings : dict or None
        Passed unchanged to ``LatexPrinter.__init__``.
    prec : int, default 3
        Decimal-digit precision handed to ``sympy.Float`` when a float is
        printed.
    """

    def __init__(self, settings=None, prec=3):
        super().__init__(settings)
        self.prec = prec

    def _print_Float(self, expr):
        # Re-create the float at `self.prec` digits, then let the parent
        # printer do the actual LaTeX formatting.
        return super()._print_Float(sympy.Float(expr, self.prec))


def to_latex(expr, prec=3, **settings):
    """Render a SymPy expression as LaTeX with floats printed at ``prec`` digits.

    Extra keyword arguments become the printer's settings. When no keyword
    arguments are given, the printer is constructed with ``settings=None``.
    """
    printer_settings = settings if settings else None
    return PreciseLatexPrinter(settings=printer_settings, prec=prec).doprint(expr)


def generate_top_of_latex_table(columns=("Equation", "Complexity", "Loss")):
    """Return the opening lines of a LaTeX table using booktabs rules.

    Args:
        columns: Iterable of column header strings. A column named exactly
            ``"Equation"`` is left-aligned; every other column is centered.

    Returns:
        A newline-joined string that opens the ``table``, ``center`` and
        ``tabular`` environments, followed by ``\\toprule``, the header row
        and ``\\midrule``. Close it with ``generate_bottom_of_latex_table``.
    """
    # Materialize once: the columns are walked twice below (alignment spec
    # and header row). A one-shot iterator such as a generator would
    # otherwise be exhausted by the first pass, leaving the header row empty.
    columns = list(columns)
    margins = "".join("l" if col == "Equation" else "c" for col in columns)
    latex_table_pieces = [
        r"\begin{table}[h]",
        r"\begin{center}",
        # `@{}` on both sides removes the outer column padding.
        r"\begin{tabular}{@{}" + margins + r"@{}}",
        r"\toprule",
        " & ".join(columns) + r" \\",
        r"\midrule",
    ]
    return "\n".join(latex_table_pieces)


def generate_bottom_of_latex_table():
    """Return the closing lines of the LaTeX table.

    Emits ``\\bottomrule`` and then ends the ``tabular``, ``center`` and
    ``table`` environments, in that order, one per line.
    """
    closing_lines = (
        r"\bottomrule",
        r"\end{tabular}",
        r"\end{center}",
        r"\end{table}",
    )
    return "\n".join(closing_lines)